From aaed71abce62cf735a5e4e39f751869895b0e043 Mon Sep 17 00:00:00 2001
From: Etienne Kneuss <ekneuss@gmail.com>
Date: Fri, 30 Nov 2012 05:03:35 +0100
Subject: [PATCH] Confidence is now over 9000

---
 src/main/scala/leon/FunctionTemplate.scala    |   4 +-
 .../scala/leon/solvers/z3/FairZ3Solver.scala  | 105 +++++++-----------
 .../scala/leon/synthesis/rules/Cegis.scala    |  23 +++-
 3 files changed, 61 insertions(+), 71 deletions(-)

diff --git a/src/main/scala/leon/FunctionTemplate.scala b/src/main/scala/leon/FunctionTemplate.scala
index 04a717151..c7dac3029 100644
--- a/src/main/scala/leon/FunctionTemplate.scala
+++ b/src/main/scala/leon/FunctionTemplate.scala
@@ -197,7 +197,7 @@ object FunctionTemplate {
           None
       }
   
-      val activatingBool : Identifier = FreshIdentifier("a", true).setType(BooleanType)
+      val activatingBool : Identifier = FreshIdentifier("start", true).setType(BooleanType)
   
       funDef match {
         case Some(fd) => 
@@ -206,7 +206,7 @@ object FunctionTemplate {
 
         case None =>
          storeGuarded(activatingBool, false, BooleanLiteral(false))
-         val newFormula = rec(activatingBool, true, body.get)
+         val newFormula = rec(activatingBool, true, newBody.get)
          storeGuarded(activatingBool, true, newFormula)
       }
   
diff --git a/src/main/scala/leon/solvers/z3/FairZ3Solver.scala b/src/main/scala/leon/solvers/z3/FairZ3Solver.scala
index f1e11303b..90ede1bf5 100644
--- a/src/main/scala/leon/solvers/z3/FairZ3Solver.scala
+++ b/src/main/scala/leon/solvers/z3/FairZ3Solver.scala
@@ -221,7 +221,7 @@ class FairZ3Solver(context : LeonContext) extends Solver(context) with AbstractZ
     varsInVC ++= variablesOf(expandedVC)
 
     reporter.info(" - Initial unrolling...")
-    val (clauses, guards) = unrollingBank.initialUnrolling(expandedVC)
+    val (clauses, guards) = unrollingBank.scanForNewTemplates(expandedVC)
 
     val cc = toZ3Formula(And(clauses)).get
     solver.assertCnstr(cc)
@@ -579,7 +579,17 @@ class FairZ3Solver(context : LeonContext) extends Solver(context) with AbstractZ
     }
 
     def scanForNewTemplates(expr: Expr): (Seq[Expr], Seq[(Identifier, Boolean)]) = {
-      (Seq(), Seq())
+      val tmp = FunctionTemplate.mkTemplate(expr)
+
+      val allBlocks : MutableSet[(Identifier,Boolean)] = MutableSet.empty
+
+      for (((i, p), fis) <- tmp.blockers) {
+        if(registerBlocked(i, p, fis)) {
+          allBlocks += i -> p
+        }
+      }
+
+      (tmp.asClauses, allBlocks.toSeq)
     }
 
     private def treatFunctionInvocationSet(sVar : Identifier, pol : Boolean, fis : Set[FunctionInvocation]) : (Seq[Expr],Seq[(Identifier,Boolean)]) = {
@@ -590,9 +600,9 @@ class FairZ3Solver(context : LeonContext) extends Solver(context) with AbstractZ
         val temp = FunctionTemplate.mkTemplate(fi.funDef)
         val (newExprs,newBlocks) = temp.instantiate(sVar, pol, fi.args)
 
-        for(((i,p),fis) <- newBlocks) {
-          if(registerBlocked(i,p,fis)) {
-            allBlocks += ((i,p))
+        for(((i, p), fis) <- newBlocks) {
+          if(registerBlocked(i, p, fis)) {
+            allBlocks += i -> p
           }
         }
         allNewExprs = allNewExprs ++ newExprs
@@ -600,16 +610,6 @@ class FairZ3Solver(context : LeonContext) extends Solver(context) with AbstractZ
       (allNewExprs, allBlocks.toSeq)
     }
 
-    def initialUnrolling(formula : Expr) : (Seq[Expr], Seq[(Identifier,Boolean)]) = {
-      val tmp = FunctionTemplate.mkTemplate(formula)
-
-      for (((p1, p2), calls) <- tmp.blockers) {
-        registerBlocked(p1, p2, calls)
-      }
-
-      (tmp.asClauses, tmp.blockers.keySet.toSeq)
-    }
-
     def unlock(id: Identifier, pol: Boolean) : (Seq[Expr], Seq[(Identifier,Boolean)]) = {
       if(!blockMap.isDefinedAt((id,pol))) {
         (Seq.empty,Seq.empty)
@@ -624,26 +624,24 @@ class FairZ3Solver(context : LeonContext) extends Solver(context) with AbstractZ
   def getNewSolver = new solvers.IncrementalSolver {
     val solver = z3.mkSolver
 
-    private var ownStack        = List[Set[Z3AST]](Set())
-    private var assertionsStack = List[List[Expr]](Nil)
-    private var varsInVC        = Set[Identifier]()
+    private var frameGuards      = List[Z3AST](z3.mkFreshConst("frame", z3.mkBoolSort))
+    private var frameExpressions = List[List[Expr]](Nil)
+    private var varsInVC         = Set[Identifier]()
 
-    def allClausePredicates = ownStack.flatten
+    def entireFormula  = And(frameExpressions.flatten)
 
     def push() {
-      ownStack        = Set[Z3AST]() :: ownStack
-      assertionsStack = Nil :: assertionsStack
+      frameGuards      = z3.mkFreshConst("frame", z3.mkBoolSort) :: frameGuards
+      frameExpressions = Nil :: frameExpressions
     }
 
     def pop(lvl: Int = 1) {
-      val frame = ownStack.head
-      ownStack  = ownStack.tail
+      // We make sure we discard the expressions guarded by this frame
+      solver.assertCnstr(z3.mkNot(frameGuards.head))
 
-      assertionsStack = assertionsStack.tail
-
-      frame.foreach { b =>
-        solver.assertCnstr(z3.mkNot(b))
-      }
+      // Pop the frames
+      frameGuards      = frameGuards.tail
+      frameExpressions = frameExpressions.tail
     }
 
     def check: Option[Boolean] = {
@@ -655,7 +653,6 @@ class FairZ3Solver(context : LeonContext) extends Solver(context) with AbstractZ
     }
 
     val unrollingBank = new UnrollingBank()
-    var isBankInitialized = false
 
     var foundDefinitiveAnswer = false
     var definitiveAnswer : Option[Boolean] = None
@@ -666,21 +663,21 @@ class FairZ3Solver(context : LeonContext) extends Solver(context) with AbstractZ
     private var blockingSet: Set[Expr] = Set.empty
 
     def assertCnstr(expression: Expr) {
-      val b = z3.mkFreshConst("b", z3.mkBoolSort)
+      val guard = frameGuards.head
+
       varsInVC ++= variablesOf(expression)
 
-      ownStack        = (ownStack.head + b) :: ownStack.tail
-      assertionsStack = (expression :: assertionsStack.head) :: assertionsStack.tail
+      frameExpressions = (expression :: frameExpressions.head) :: frameExpressions.tail
 
-      solver.assertCnstr(z3.mkImplies(b, toZ3Formula(expression).get))
+      solver.assertCnstr(z3.mkImplies(guard, toZ3Formula(expression).get))
 
-      if (isBankInitialized) {
-        val (newClauses, newGuards) = unrollingBank.scanForNewTemplates(expression)
-        for (cl <- newClauses) {
-          solver.assertCnstr(toZ3Formula(cl).get)
-        }
-        blockingSet ++= newGuards.map(p => if(p._2) Not(Variable(p._1)) else Variable(p._1))
+      val (newClauses, newGuards) = unrollingBank.scanForNewTemplates(expression)
+ 
+      for (cl <- newClauses) {
+        solver.assertCnstr(z3.mkImplies(guard, toZ3Formula(cl).get))
       }
+
+      blockingSet ++= newGuards.map(p => if(p._2) Not(Variable(p._1)) else Variable(p._1))
     }
 
     def getModel = {
@@ -702,7 +699,7 @@ class FairZ3Solver(context : LeonContext) extends Solver(context) with AbstractZ
       }
 
       def z3CoreToCore(core: Seq[Z3AST]): Set[Expr] = {
-        val internalAssumptions = allClausePredicates.toSet
+        val internalAssumptions = frameGuards.toSet
         val userlandAssumptions = core.filterNot(internalAssumptions)
 
         userlandAssumptions.map(ast => fromZ3Formula(null, ast, None) match {
@@ -712,17 +709,8 @@ class FairZ3Solver(context : LeonContext) extends Solver(context) with AbstractZ
         }).toSet
       }
 
-      if (!isBankInitialized) {
-        reporter.info(" - Initial unrolling...")
-        val (clauses, guards) = unrollingBank.initialUnrolling(And(assertionsStack.flatten))
-
-        solver.assertCnstr(toZ3Formula(And(clauses)).get)
-        blockingSet ++= guards.map(p => if(p._2) Not(Variable(p._1)) else Variable(p._1))
-      }
-
       // these are the optional sequence of assumption literals
-      val assumptionsAsZ3: Seq[Z3AST] = allClausePredicates ++ assumptions.map(toZ3Formula(_).get)
-
+      val assumptionsAsZ3: Seq[Z3AST] = frameGuards ++ assumptions.map(toZ3Formula(_).get)
 
       var iterationsLeft : Int = if(Settings.unrollingLevel > 0) Settings.unrollingLevel else 16121984
 
@@ -732,14 +720,9 @@ class FairZ3Solver(context : LeonContext) extends Solver(context) with AbstractZ
         val blockingSetAsZ3 : Seq[Z3AST] = blockingSet.toSeq.map(toZ3Formula(_).get)
         // println("Blocking set : " + blockingSet)
 
-        solver.push()
-        for (block <- blockingSetAsZ3) {
-          solver.assertCnstr(block)
-        }
-
         reporter.info(" - Running Z3 search...")
 
-        val res = solver.checkAssumptions(assumptionsAsZ3 :_*)
+        val res = solver.checkAssumptions((blockingSetAsZ3 ++ assumptionsAsZ3) :_*)
 
         reporter.info(" - Finished search with blocked literals")
 
@@ -754,7 +737,7 @@ class FairZ3Solver(context : LeonContext) extends Solver(context) with AbstractZ
             val z3model = solver.getModel
 
             if (Settings.verifyModel && false) {
-              val (isValid, model) = validateAndDeleteModel(z3model, toCheckAgainstModels, varsInVC)
+              val (isValid, model) = validateAndDeleteModel(z3model, entireFormula, varsInVC)
 
               if (isValid) {
                 foundAnswer(Some(true), model)
@@ -777,7 +760,6 @@ class FairZ3Solver(context : LeonContext) extends Solver(context) with AbstractZ
             val core = z3CoreToCore(solver.getUnsatCore)  
 
             foundAnswer(Some(false), core = core)
-            solver.pop(1)
 
           // This branch is both for with and without unsat cores. The
           // distinction is made inside.
@@ -785,9 +767,6 @@ class FairZ3Solver(context : LeonContext) extends Solver(context) with AbstractZ
 
             val core = z3CoreToCore(solver.getUnsatCore)  
 
-            // Removes blocking literals
-            solver.pop(1)
-              
             if (!forceStop) {
               if (Settings.luckyTest) {
                 // we need the model to perform the additional test
@@ -804,10 +783,8 @@ class FairZ3Solver(context : LeonContext) extends Solver(context) with AbstractZ
                 case Some(true) =>
                   if (Settings.luckyTest && !forceStop) {
                     // we might have been lucky :D
-                    val (wereWeLucky, cleanModel) = validateAndDeleteModel(solver.getModel, toCheckAgainstModels, varsInVC)
+                    val (wereWeLucky, cleanModel) = validateAndDeleteModel(solver.getModel, entireFormula, varsInVC)
                     if(wereWeLucky) {
-                      reporter.info("Found lucky to "+solver.getAssertions.toSeq+" with "+assumptionsAsZ3)
-                      reporter.info("Stack: "+ownStack)
                       foundAnswer(Some(true), cleanModel)
                     }
                   }
diff --git a/src/main/scala/leon/synthesis/rules/Cegis.scala b/src/main/scala/leon/synthesis/rules/Cegis.scala
index 98647bbf9..13e9843a4 100644
--- a/src/main/scala/leon/synthesis/rules/Cegis.scala
+++ b/src/main/scala/leon/synthesis/rules/Cegis.scala
@@ -9,6 +9,8 @@ import purescala.TypeTrees._
 import purescala.TreeOps._
 import purescala.Extractors._
 
+import solvers.z3.FairZ3Solver
+
 case object CEGIS extends Rule("CEGIS", 150) {
   def attemptToApplyOn(sctx: SynthesisContext, p: Problem): RuleResult = {
     case class Generator(tpe: TypeTree, altBuilder: () => List[(Expr, Set[Identifier])]);
@@ -138,41 +140,51 @@ case object CEGIS extends Rule("CEGIS", 150) {
 
               var continue = true
 
+              val mainSolver: FairZ3Solver = sctx.solver.asInstanceOf[FairZ3Solver]
+
               // solver1 is used for the initial SAT queries
-              val solver1 = sctx.solver.getNewSolver
+              val solver1 = mainSolver.getNewSolver
 
               val basePhi = currentF.entireFormula
               solver1.assertCnstr(basePhi)
 
               // solver2 is used for the CE search
-              val solver2 = sctx.solver.getNewSolver
+              val solver2 = mainSolver.getNewSolver
               solver2.assertCnstr(And(currentF.pathcond :: currentF.program :: Not(currentF.phi) :: Nil))
 
               // solver3 is used for the unsatcore search
-              val solver3 = sctx.solver.getNewSolver
+              val solver3 = mainSolver.getNewSolver
               solver3.assertCnstr(And(currentF.pathcond :: currentF.program :: currentF.phi :: Nil))
 
               while (result.isEmpty && continue) {
                 //println("-"*80)
+                //println(basePhi)
+
                 //println("To satisfy: "+constrainedPhi)
                 solver1.check match {
                   case Some(true) =>
                     val satModel = solver1.getModel
 
+                    //println("Found solution: "+satModel)
                     //println("Corresponding program: "+simplifyTautologies(synth.solver)(valuateWithModelIn(currentF.program, bss, satModel)))
                     val fixedBss = And(bss.map(b => Equals(Variable(b), satModel(b))).toSeq)
                     //println("Phi with fixed sat bss: "+fixedBss)
 
                     solver2.push()
                     solver2.assertCnstr(fixedBss)
-                    //println("Formula to validate: "+counterPhi)
 
+                    //println("FORMULA: "+And(currentF.pathcond :: currentF.program :: Not(currentF.phi) :: fixedBss :: Nil))
+
+                    //println("#"*80)
                     solver2.check match {
                       case Some(true) =>
+                        //println("#"*80)
                         val invalidModel = solver2.getModel
 
                         val fixedAss = And(ass.map(a => Equals(Variable(a), invalidModel(a))).toSeq)
 
+                        //println("Found counter example: "+fixedAss)
+
                         solver3.push()
                         solver3.assertCnstr(fixedAss)
 
@@ -221,10 +233,11 @@ case object CEGIS extends Rule("CEGIS", 150) {
                         }
 
                       case Some(false) =>
+                        //println("#"*80)
+                        //println("UNSAT!")
                         //println("Sat model: "+satModel.toSeq.sortBy(_._1.toString).map{ case (id, v) => id+" -> "+v }.mkString(", "))
                         var mapping = currentF.mappings.filterKeys(satModel.mapValues(_ == BooleanLiteral(true))).values.toMap
 
-                        //println("Mapping: "+mapping)
 
                         // Resolve mapping
                         for ((c, e) <- mapping) {
-- 
GitLab