From 087889d5202525caf0525b68857d1b60f89ca945 Mon Sep 17 00:00:00 2001
From: Etienne Kneuss <ekneuss@gmail.com>
Date: Fri, 4 Jan 2013 20:07:19 +0100
Subject: [PATCH] Accelerate CEGIS by disabling features we thought would help

CEGIS now support internal flags that can enable/disable its features:

1) Injecting Counter-Examples on top of the unsat core to drive the
   search to interesting areas. Does not help => disabled

2) Computing Unsat-Cores to strenghten the search of programs. Help in
   some cases, doesn't hurt much => enabled

3) Checking whether the formula is unsat without blockers, to unrolling
   when there is no chance of finding a solution. Does not help =>
   disable

4) Add support for function calls in CEGIS generators. This is disabled
   by default and can be enabled using --cegis:gencalls.

It seems that doing additional checks in 1) and 3) triggers FairZ3 to
unroll more, tempering with the performance of the solver.

Also, this implements some improvements in the resulting programs by
simplifying further expressions.
---
 src/main/scala/leon/purescala/TreeOps.scala   |  32 ++++-
 src/main/scala/leon/purescala/Trees.scala     |   9 +-
 src/main/scala/leon/purescala/TypeTrees.scala |   9 ++
 .../leon/solvers/z3/AbstractZ3Solver.scala    |   4 +-
 .../leon/solvers/z3/FunctionTemplate.scala    |   2 +
 .../scala/leon/synthesis/ParallelSearch.scala |   2 +-
 .../leon/synthesis/SynthesisContext.scala     |  16 ++-
 .../scala/leon/synthesis/SynthesisPhase.scala |  26 ++--
 .../scala/leon/synthesis/Synthesizer.scala    |   1 +
 .../leon/synthesis/SynthesizerOptions.scala   |   3 +-
 .../scala/leon/synthesis/rules/ADTSplit.scala |   1 -
 .../scala/leon/synthesis/rules/Cegis.scala    | 128 +++++++++++-------
 .../leon/synthesis/utils/Benchmarks.scala     |  14 +-
 .../SynthesisProblemExtractionPhase.scala     |  11 +-
 .../leon/test/synthesis/SynthesisSuite.scala  |  19 +--
 testcases/synthesis/CegisFunctions.scala      |  30 ++++
 16 files changed, 224 insertions(+), 83 deletions(-)
 create mode 100644 testcases/synthesis/CegisFunctions.scala

diff --git a/src/main/scala/leon/purescala/TreeOps.scala b/src/main/scala/leon/purescala/TreeOps.scala
index 4a5ca356c..26148a779 100644
--- a/src/main/scala/leon/purescala/TreeOps.scala
+++ b/src/main/scala/leon/purescala/TreeOps.scala
@@ -1118,7 +1118,7 @@ object TreeOps {
               p
           }
 
-          MatchExpr(scrutinee, Seq(SimpleCase(simplifyPattern(pattern), newThen), SimpleCase(WildcardPattern(None), elze)))
+          MatchExpr(scrutinee, Seq(SimpleCase(simplifyPattern(pattern), newThen), SimpleCase(WildcardPattern(None), elze))).setType(e.getType)
         } else {
           e
         }
@@ -1197,6 +1197,36 @@ object TreeOps {
         val se = rec(e, path)
         Let(i, se, rec(b, Equals(Variable(i), se) +: path))
 
+      case MatchExpr(scrut, cases) =>
+        val rs = rec(scrut, path)
+
+        var stillPossible = true
+
+        if (cases.exists(_.hasGuard)) {
+          // unsupported for now
+          e
+        } else {
+          MatchExpr(rs, cases.flatMap { c => 
+            val patternExpr = conditionForPattern(rs, c.pattern)
+
+            if (stillPossible && !contradictedBy(patternExpr, path)) {
+
+              if (impliedBy(patternExpr, path)) {
+                stillPossible = false
+              }
+
+              c match {
+                case SimpleCase(p, rhs) =>
+                  Some(SimpleCase(p, rec(rhs, patternExpr +: path)))
+                case GuardedCase(_, _, _) =>
+                  sys.error("woot.")
+              }
+            } else {
+              None
+            }
+          })
+        }
+
       case LetTuple(is, e, b) =>
         // Similar to the Let case
         val se = rec(e, path)
diff --git a/src/main/scala/leon/purescala/Trees.scala b/src/main/scala/leon/purescala/Trees.scala
index 4e5e0fea6..98888aa5e 100644
--- a/src/main/scala/leon/purescala/Trees.scala
+++ b/src/main/scala/leon/purescala/Trees.scala
@@ -49,7 +49,9 @@ object Trees {
 
     funDef.args.zip(args).foreach { case (a, c) => typeCheck(c, a.tpe) }
   }
-  case class IfExpr(cond: Expr, then: Expr, elze: Expr) extends Expr 
+  case class IfExpr(cond: Expr, then: Expr, elze: Expr) extends Expr with FixedType {
+    val fixedType = leastUpperBound(then.getType, elze.getType).getOrElse(AnyType)
+  }
 
   case class Tuple(exprs: Seq[Expr]) extends Expr {
     val subTpes = exprs.map(_.getType)
@@ -87,7 +89,10 @@ object Trees {
     def unapply(me: MatchExpr) : Option[(Expr,Seq[MatchCase])] = if (me == null) None else Some((me.scrutinee, me.cases))
   }
 
-  class MatchExpr(val scrutinee: Expr, val cases: Seq[MatchCase]) extends Expr with ScalacPositional {
+  class MatchExpr(val scrutinee: Expr, val cases: Seq[MatchCase]) extends Expr with ScalacPositional with FixedType {
+
+    val fixedType = leastUpperBound(cases.map(_.rhs.getType)).getOrElse(AnyType)
+
     def scrutineeClassType: ClassType = scrutinee.getType.asInstanceOf[ClassType]
 
     override def equals(that: Any): Boolean = (that != null) && (that match {
diff --git a/src/main/scala/leon/purescala/TypeTrees.scala b/src/main/scala/leon/purescala/TypeTrees.scala
index 7d2897ad5..d8c673fcc 100644
--- a/src/main/scala/leon/purescala/TypeTrees.scala
+++ b/src/main/scala/leon/purescala/TypeTrees.scala
@@ -107,6 +107,15 @@ object TypeTrees {
     case _ => None
   }
 
+  def leastUpperBound(ts: Seq[TypeTree]): Option[TypeTree] = {
+    def olub(ot1: Option[TypeTree], t2: Option[TypeTree]): Option[TypeTree] = ot1 match {
+      case Some(t1) => leastUpperBound(t1, t2.get)
+      case None => None
+    }
+
+    ts.map(Some(_)).reduceLeft(olub)
+  }
+
   def isSubtypeOf(t1: TypeTree, t2: TypeTree): Boolean = {
     leastUpperBound(t1, t2) == Some(t2)
   }
diff --git a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala
index aefe9fd2c..2a14a92f5 100644
--- a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala
+++ b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala
@@ -393,7 +393,9 @@ trait AbstractZ3Solver extends solvers.IncrementalSolverBuilder {
             //   scala.sys.error("Error in formula being translated to Z3: identifier " + id + " seems to have escaped its let-definition")
             // }
 
-            assert(!this.isInstanceOf[FairZ3Solver], "Trying to convert unknown variable '"+id+"' while using FairZ3")
+            // Remove this safety check, since choose() expresions are now
+            // translated to non-unrollable variables, that end up here.
+            // assert(!this.isInstanceOf[FairZ3Solver], "Trying to convert unknown variable '"+id+"' while using FairZ3")
 
             val newAST = z3.mkFreshConst(id.uniqueName/*name*/, typeToSort(v.getType))
             z3Vars = z3Vars + (id -> newAST)
diff --git a/src/main/scala/leon/solvers/z3/FunctionTemplate.scala b/src/main/scala/leon/solvers/z3/FunctionTemplate.scala
index 6a4a96e15..2d6399660 100644
--- a/src/main/scala/leon/solvers/z3/FunctionTemplate.scala
+++ b/src/main/scala/leon/solvers/z3/FunctionTemplate.scala
@@ -153,6 +153,8 @@ object FunctionTemplate {
           }
         }
 
+        case c @ Choose(_, _) => Variable(FreshIdentifier("choose", true).setType(c.getType))
+
         case n @ NAryOperator(as, r) => r(as.map(a => rec(pathVar, pathPol, a))).setType(n.getType)
         case b @ BinaryOperator(a1, a2, r) => r(rec(pathVar, pathPol, a1), rec(pathVar, pathPol, a2)).setType(b.getType)
         case u @ UnaryOperator(a, r) => r(rec(pathVar, pathPol, a)).setType(u.getType)
diff --git a/src/main/scala/leon/synthesis/ParallelSearch.scala b/src/main/scala/leon/synthesis/ParallelSearch.scala
index 5ec24b47e..fb3f824c8 100644
--- a/src/main/scala/leon/synthesis/ParallelSearch.scala
+++ b/src/main/scala/leon/synthesis/ParallelSearch.scala
@@ -24,7 +24,7 @@ class ParallelSearch(synth: Synthesizer,
 
     solver.initZ3
 
-    val ctx = SynthesisContext(solver = solver, reporter = synth.reporter, shouldStop = synth.shouldStop)
+    val ctx = SynthesisContext.fromSynthesizer(synth).copy(solver = solver)
 
     synchronized {
       contexts = ctx :: contexts
diff --git a/src/main/scala/leon/synthesis/SynthesisContext.scala b/src/main/scala/leon/synthesis/SynthesisContext.scala
index d23f4df29..aa46f6381 100644
--- a/src/main/scala/leon/synthesis/SynthesisContext.scala
+++ b/src/main/scala/leon/synthesis/SynthesisContext.scala
@@ -2,16 +2,30 @@ package leon
 package synthesis
 
 import solvers.Solver
+import purescala.Trees._
+import purescala.Definitions.{Program, FunDef}
+import purescala.Common.Identifier
 
 import java.util.concurrent.atomic.AtomicBoolean
 
 case class SynthesisContext(
+  options: SynthesizerOptions,
+  functionContext: Option[FunDef],
+  program: Program,
   solver: Solver,
   reporter: Reporter,
   shouldStop: AtomicBoolean
 )
 
 object SynthesisContext {
-  def fromSynthesizer(synth: Synthesizer) = SynthesisContext(synth.solver, synth.reporter, new AtomicBoolean(false))
+  def fromSynthesizer(synth: Synthesizer) = {
+    SynthesisContext(
+      synth.options,
+      synth.functionContext,
+      synth.program,
+      synth.solver,
+      synth.reporter,
+      new AtomicBoolean(false))
+  }
 }
 
diff --git a/src/main/scala/leon/synthesis/SynthesisPhase.scala b/src/main/scala/leon/synthesis/SynthesisPhase.scala
index 7adb83366..ed7cf7429 100644
--- a/src/main/scala/leon/synthesis/SynthesisPhase.scala
+++ b/src/main/scala/leon/synthesis/SynthesisPhase.scala
@@ -14,13 +14,14 @@ object SynthesisPhase extends LeonPhase[Program, Program] {
   val description = "Synthesis"
 
   override val definedOptions : Set[LeonOptionDef] = Set(
-    LeonFlagOptionDef(    "inplace",    "--inplace",         "Debug level"),
-    LeonOptValueOptionDef("parallel",   "--parallel[=N]",    "Parallel synthesis search using N workers"),
-    LeonFlagOptionDef(    "derivtrees", "--derivtrees",      "Generate derivation trees"),
-    LeonFlagOptionDef(    "firstonly",  "--firstonly",       "Stop as soon as one synthesis solution is found"),
-    LeonValueOptionDef(   "timeout",    "--timeout=T",       "Timeout after T seconds when searching for synthesis solutions .."),
-    LeonValueOptionDef(   "costmodel",  "--costmodel=cm",    "Use a specific cost model for this search"),
-    LeonValueOptionDef(   "functions",  "--functions=f1:f2", "Limit synthesis of choose found within f1,f2,..")
+    LeonFlagOptionDef(    "inplace",         "--inplace",         "Debug level"),
+    LeonOptValueOptionDef("parallel",        "--parallel[=N]",    "Parallel synthesis search using N workers"),
+    LeonFlagOptionDef(    "derivtrees",      "--derivtrees",      "Generate derivation trees"),
+    LeonFlagOptionDef(    "firstonly",       "--firstonly",       "Stop as soon as one synthesis solution is found"),
+    LeonValueOptionDef(   "timeout",         "--timeout=T",       "Timeout after T seconds when searching for synthesis solutions .."),
+    LeonValueOptionDef(   "costmodel",       "--costmodel=cm",    "Use a specific cost model for this search"),
+    LeonValueOptionDef(   "functions",       "--functions=f1:f2", "Limit synthesis of choose found within f1,f2,.."),
+    LeonFlagOptionDef(    "cegis:gencalls",  "--cegis:gencalls",  "Include function calls in CEGIS generators")
   )
 
   def run(ctx: LeonContext)(p: Program): Program = {
@@ -74,6 +75,9 @@ object SynthesisPhase extends LeonPhase[Program, Program] {
           options = options.copy(searchWorkers = nWorkers)
         }
 
+      case LeonFlagOption("cegis:gencalls") =>
+        options = options.copy(cegisGenerateFunCalls = true)
+
       case LeonFlagOption("derivtrees") =>
         options = options.copy(generateDerivationTrees = true)
 
@@ -89,6 +93,7 @@ object SynthesisPhase extends LeonPhase[Program, Program] {
         case ch @ Choose(vars, pred) =>
           val problem = Problem.fromChoose(ch)
           val synth = new Synthesizer(ctx,
+                                      Some(f),
                                       mainSolver,
                                       p,
                                       problem,
@@ -117,9 +122,11 @@ object SynthesisPhase extends LeonPhase[Program, Program] {
 
     // Simplify expressions
     val simplifiers = List[Expr => Expr](
-      simplifyTautologies(uninterpretedZ3)(_), 
+      simplifyTautologies(uninterpretedZ3)(_),
       simplifyLets _,
       decomposeIfs _,
+      matchToIfThenElse _,
+      simplifyPaths(uninterpretedZ3)(_),
       patternMatchReconstruction _,
       simplifyTautologies(uninterpretedZ3)(_),
       simplifyLets _,
@@ -129,7 +136,8 @@ object SynthesisPhase extends LeonPhase[Program, Program] {
     def simplify(e: Expr): Expr = simplifiers.foldLeft(e){ (x, sim) => sim(x) }
 
     val chooseToExprs = solutions.map {
-      case (ch, (fd, sol)) => (ch, (fd, simplify(sol.toExpr)))
+      case (ch, (fd, sol)) =>
+        (ch, (fd, simplify(sol.toExpr)))
     }
 
     if (inPlace) {
diff --git a/src/main/scala/leon/synthesis/Synthesizer.scala b/src/main/scala/leon/synthesis/Synthesizer.scala
index b15e1ec69..0a402e5db 100644
--- a/src/main/scala/leon/synthesis/Synthesizer.scala
+++ b/src/main/scala/leon/synthesis/Synthesizer.scala
@@ -18,6 +18,7 @@ import synthesis.search._
 import java.util.concurrent.atomic.AtomicBoolean
 
 class Synthesizer(val context : LeonContext,
+                  val functionContext: Option[FunDef],
                   val solver: Solver,
                   val program: Program,
                   val problem: Problem,
diff --git a/src/main/scala/leon/synthesis/SynthesizerOptions.scala b/src/main/scala/leon/synthesis/SynthesizerOptions.scala
index 177d21c49..e9d2f4c94 100644
--- a/src/main/scala/leon/synthesis/SynthesizerOptions.scala
+++ b/src/main/scala/leon/synthesis/SynthesizerOptions.scala
@@ -7,5 +7,6 @@ case class SynthesizerOptions(
   searchWorkers: Int               = 1,
   firstOnly: Boolean               = false,
   timeoutMs: Option[Long]          = None,
-  costModel: CostModel             = CostModel.default
+  costModel: CostModel             = CostModel.default,
+  cegisGenerateFunCalls: Boolean   = false
 )
diff --git a/src/main/scala/leon/synthesis/rules/ADTSplit.scala b/src/main/scala/leon/synthesis/rules/ADTSplit.scala
index 61d3a9768..7235f9bb4 100644
--- a/src/main/scala/leon/synthesis/rules/ADTSplit.scala
+++ b/src/main/scala/leon/synthesis/rules/ADTSplit.scala
@@ -41,7 +41,6 @@ case object ADTSplit extends Rule("ADT Split.") {
         }
     }
 
-
     candidates.collect{ _ match {
       case Some((id, cases)) =>
         val oas = p.as.filter(_ != id)
diff --git a/src/main/scala/leon/synthesis/rules/Cegis.scala b/src/main/scala/leon/synthesis/rules/Cegis.scala
index c7f84e483..0be347929 100644
--- a/src/main/scala/leon/synthesis/rules/Cegis.scala
+++ b/src/main/scala/leon/synthesis/rules/Cegis.scala
@@ -13,6 +13,13 @@ import solvers.z3.FairZ3Solver
 
 case object CEGIS extends Rule("CEGIS") {
   def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = {
+
+    // CEGIS Flags to actiave or de-activate features
+    val useCounterExamples    = false
+    val useUninterpretedProbe = false
+    val useUnsatCores         = true
+    val useFunGenerators      = sctx.options.cegisGenerateFunCalls
+
     case class Generator(tpe: TypeTree, altBuilder: () => List[(Expr, Set[Identifier])]);
 
     var generators = Map[TypeTree, Generator]()
@@ -64,6 +71,31 @@ case object CEGIS extends Rule("CEGIS") {
       p.as.filter(a => isSubtypeOf(a.getType, t)).map(id => (Variable(id) : Expr, Set[Identifier]()))
     }
 
+    def funcAlternatives(t: TypeTree): List[(Expr, Set[Identifier])] = {
+      if (useFunGenerators) {
+        def isCandidate(fd: FunDef): Boolean = {
+          // Prevents recursive calls
+          val isRecursiveCall = sctx.functionContext match {
+            case Some(cfd) =>
+              (sctx.program.transitiveCallers(cfd) + cfd) contains fd
+
+            case None =>
+              false
+          }
+
+          isSubtypeOf(fd.returnType, t) && !isRecursiveCall
+        }
+
+        sctx.program.definedFunctions.filter(isCandidate).map{ fd =>
+          val ids = fd.args.map(vd => FreshIdentifier("c", true).setType(vd.getType))
+
+          (FunctionInvocation(fd, ids.map(Variable(_))), ids.toSet)
+        }.toList
+      } else {
+        Nil
+      }
+    }
+
     class TentativeFormula(val pathcond: Expr,
                            val phi: Expr,
                            var program: Expr,
@@ -77,7 +109,7 @@ case object CEGIS extends Rule("CEGIS") {
 
         for ((_, recIds) <- recTerms; recId <- recIds) {
           val gen  = getGenerator(recId.getType)
-          val alts = gen.altBuilder() ::: inputAlternatives(recId.getType)
+          val alts = gen.altBuilder() ::: inputAlternatives(recId.getType) ::: funcAlternatives(recId.getType)
 
           val altsWithBranches = alts.map(alt => FreshIdentifier("b", true).setType(BooleanType) -> alt)
 
@@ -122,6 +154,7 @@ case object CEGIS extends Rule("CEGIS") {
 
     val xsSet = p.xs.toSet
 
+
     val (exprsA, others) = ands.partition(e => (variablesOf(e) & xsSet).isEmpty)
     if (exprsA.isEmpty) {
       val res = new RuleInstantiation(p, this, SolutionBuilder.none) {
@@ -150,42 +183,35 @@ case object CEGIS extends Rule("CEGIS") {
           try {
             do {
               val (clauses, bounds) = unrolling.unroll
-              //println("UNROLLING: "+clauses+" WITH BOUNDS "+bounds)
-              solver1.assertCnstr(And(clauses))
-              solver2.assertCnstr(And(clauses))
+              //println("UNROLLING: ")
+              //for (c <- clauses) {
+              //  println(" - " + c)
+              //}
+              //println("BOUNDS "+bounds)
 
-              //println("="*80)
-              //println("Was: "+lastF.entireFormula)
-              //println("Now Trying : "+currentF.entireFormula)
+              val clause = And(clauses)
+              solver1.assertCnstr(clause)
+              solver2.assertCnstr(clause)
 
               val tpe = TupleType(p.xs.map(_.getType))
               val bss = unrolling.bss
 
               var continue = !clauses.isEmpty
 
-              //println("Unrolling #"+unrolings+" bss size: "+bss.size)
-
               while (result.isEmpty && continue && !sctx.shouldStop.get) {
                 //println("Looking for CE...")
                 //println("-"*80)
-                //println(basePhi)
 
-                //println("To satisfy: "+constrainedPhi)
                 solver1.checkAssumptions(bounds.map(id => Not(Variable(id)))) match {
                   case Some(true) =>
                     val satModel = solver1.getModel
 
-                    //println("Found solution: "+satModel)
-                    //println("Corresponding program: "+simplifyTautologies(synth.solver)(valuateWithModelIn(currentF.program, bss, satModel)))
-                    //val fixedBss = And(bss.map(b => Equals(Variable(b), satModel(b))).toSeq)
-                    //println("Phi with fixed sat bss: "+fixedBss)
-
                     val bssAssumptions: Set[Expr] = bss.map(b => satModel(b) match {
                       case BooleanLiteral(true)  => Variable(b)
                       case BooleanLiteral(false) => Not(Variable(b))
                     })
 
-                    //println("FORMULA: "+And(currentF.pathcond :: currentF.program :: Not(currentF.phi) :: fixedBss :: Nil))
+                    //println("Found solution: "+bssAssumptions)
 
                     //println("#"*80)
                     solver2.checkAssumptions(bssAssumptions) match {
@@ -201,22 +227,18 @@ case object CEGIS extends Rule("CEGIS") {
                         solver1.assertCnstr(fixedAss)
                         //println("Found counter example: "+fixedAss)
 
-                        val unsatCore = solver1.checkAssumptions(bssAssumptions) match {
-                          case Some(false) =>
-                            val core = solver1.getUnsatCore
-                            //println("Formula: "+mustBeUnsat)
-                            //println("Core:    "+core)
-                            //println(synth.solver.solveSAT(And(mustBeUnsat +: bssAssumptions.toSeq)))
-                            //println("maxcore: "+bssAssumptions)
-                            if (core.isEmpty) {
-                              // This happens if unrolling level is insufficient, it becomes unsat no matter what the assumptions are.
-                              //sctx.reporter.warning("Got empty core, must be unsat without assumptions!")
-                              Set()
-                            } else {
-                              core
-                            }
-                          case _ =>
-                            bssAssumptions
+                        val unsatCore = if (useUnsatCores) {
+                          solver1.checkAssumptions(bssAssumptions) match {
+                            case Some(false) =>
+                              // Core might be empty if unrolling level is
+                              // insufficient, it becomes unsat no matter what
+                              // the assumptions are.
+                              solver1.getUnsatCore
+                            case _ =>
+                              bssAssumptions
+                          }
+                        } else {
+                          bssAssumptions
                         }
 
                         solver1.pop()
@@ -224,29 +246,31 @@ case object CEGIS extends Rule("CEGIS") {
                         if (unsatCore.isEmpty) {
                           continue = false
                         } else {
+                          if (useCounterExamples) {
+                            val freshCss = unrolling.css.map(c => c -> Variable(FreshIdentifier(c.name, true).setType(c.getType))).toMap
+                            val ceIn     = ass.collect { 
+                              case id if invalidModel contains id => id -> invalidModel(id)
+                            }
 
-                          val freshCss = unrolling.css.map(c => c -> Variable(FreshIdentifier(c.name, true).setType(c.getType))).toMap
-                          val ceIn     = ass.collect { 
-                            case id if invalidModel contains id => id -> invalidModel(id)
-                          }
+                            val ceMap = (freshCss ++ ceIn)
+
+                            val counterexample = substAll(ceMap, And(Seq(unrolling.program, unrolling.phi)))
 
-                          val counterexample = substAll(freshCss ++ ceIn, And(Seq(unrolling.program, unrolling.phi)))
+                            //val And(ands) = counterexample
+                            //println("CE:")
+                            //for (a <- ands) {
+                            //  println(" - "+a)
+                            //}
 
-                          solver1.assertCnstr(counterexample)
-                          solver2.assertCnstr(counterexample)
+                            solver1.assertCnstr(counterexample)
+                          }
 
-                          //predicates = Not(And(unsatCore.toSeq)) +: counterexample +: predicates
                           solver1.assertCnstr(Not(And(unsatCore.toSeq)))
-                          solver2.assertCnstr(Not(And(unsatCore.toSeq)))
                         }
 
                       case Some(false) =>
-                        //println("#"*80)
-                        //println("UNSAT!")
-                        //println("Sat model: "+satModel.toSeq.sortBy(_._1.toString).map{ case (id, v) => id+" -> "+v }.mkString(", "))
                         var mapping = unrolling.mappings.filterKeys(satModel.mapValues(_ == BooleanLiteral(true))).values.toMap
 
-
                         // Resolve mapping
                         for ((c, e) <- mapping) {
                           mapping += c -> substAll(mapping, e)
@@ -263,7 +287,19 @@ case object CEGIS extends Rule("CEGIS") {
 
                   case Some(false) =>
                     //println("%%%% UNSAT")
+
+                    if (useUninterpretedProbe) {
+                      solver1.check match {
+                        case Some(false) =>
+                          // Unsat even without blockers (under which fcalls are then uninterpreted)
+                          result = Some(RuleApplicationImpossible)
+
+                        case _ =>
+                      }
+                    }
+
                     continue = false
+
                   case _ =>
                     //println("%%%% WOOPS")
                     continue = false
diff --git a/src/main/scala/leon/synthesis/utils/Benchmarks.scala b/src/main/scala/leon/synthesis/utils/Benchmarks.scala
index f5cd758c4..99bd387fd 100644
--- a/src/main/scala/leon/synthesis/utils/Benchmarks.scala
+++ b/src/main/scala/leon/synthesis/utils/Benchmarks.scala
@@ -80,13 +80,21 @@ object Benchmarks extends App {
 
     val pipeline = leon.plugin.ExtractionPhase andThen SynthesisProblemExtractionPhase
 
-    val (results, solver) = pipeline.run(innerCtx)(file.getPath :: Nil)
+    val (program, results) = pipeline.run(innerCtx)(file.getPath :: Nil)
 
-
-    val sctx = SynthesisContext(solver, new DefaultReporter, new java.util.concurrent.atomic.AtomicBoolean)
+    val solver = new FairZ3Solver(ctx.copy(reporter = new SilentReporter))
 
 
     for ((f, ps) <- results.toSeq.sortBy(_._1.id.toString); p <- ps) {
+      val sctx = SynthesisContext(
+        options = opts,
+        functionContext = Some(f),
+        program = program,
+        solver = solver,
+        reporter = new DefaultReporter,
+        shouldStop = new java.util.concurrent.atomic.AtomicBoolean
+      )
+
       val ts = System.currentTimeMillis
 
       val rr = rule.instantiateOn(sctx, p)
diff --git a/src/main/scala/leon/synthesis/utils/SynthesisProblemExtractionPhase.scala b/src/main/scala/leon/synthesis/utils/SynthesisProblemExtractionPhase.scala
index d875118df..c31a04a6a 100644
--- a/src/main/scala/leon/synthesis/utils/SynthesisProblemExtractionPhase.scala
+++ b/src/main/scala/leon/synthesis/utils/SynthesisProblemExtractionPhase.scala
@@ -8,16 +8,11 @@ import purescala.Definitions._
 import solvers.z3._
 import solvers.Solver
 
-object SynthesisProblemExtractionPhase extends LeonPhase[Program, (Map[FunDef, Seq[Problem]], Solver)] {
+object SynthesisProblemExtractionPhase extends LeonPhase[Program, (Program, Map[FunDef, Seq[Problem]])] {
   val name        = "Synthesis Problem Extraction"
   val description = "Synthesis Problem Extraction"
 
-  def run(ctx: LeonContext)(p: Program): (Map[FunDef, Seq[Problem]], Solver) = {
-
-     val silentContext : LeonContext = ctx.copy(reporter = new SilentReporter)
-     val mainSolver = new FairZ3Solver(silentContext)
-     mainSolver.setProgram(p)
-
+  def run(ctx: LeonContext)(p: Program): (Program, Map[FunDef, Seq[Problem]]) = {
     var results  = Map[FunDef, Seq[Problem]]()
     def noop(u:Expr, u2: Expr) = u
 
@@ -38,7 +33,7 @@ object SynthesisProblemExtractionPhase extends LeonPhase[Program, (Map[FunDef, S
       treeCatamorphism(x => x, noop, actOnChoose(f), f.body.get)
     }
 
-    (results, mainSolver)
+    (p, results)
   }
 
 }
diff --git a/src/test/scala/leon/test/synthesis/SynthesisSuite.scala b/src/test/scala/leon/test/synthesis/SynthesisSuite.scala
index 55ec29906..96d7104d7 100644
--- a/src/test/scala/leon/test/synthesis/SynthesisSuite.scala
+++ b/src/test/scala/leon/test/synthesis/SynthesisSuite.scala
@@ -21,7 +21,7 @@ class SynthesisSuite extends FunSuite {
     counter
   }
 
-  def forProgram(title: String)(content: String)(block: (Solver, FunDef, Problem) => Unit) {
+  def forProgram(title: String)(content: String)(block: (SynthesisContext, FunDef, Problem) => Unit) {
 
     val ctx = LeonContext(
       settings = Settings(
@@ -37,11 +37,16 @@ class SynthesisSuite extends FunSuite {
 
     val pipeline = leon.plugin.TemporaryInputPhase andThen leon.plugin.ExtractionPhase andThen SynthesisProblemExtractionPhase
 
-    val (results, solver) = pipeline.run(ctx)((content, Nil))
+    val (program, results) = pipeline.run(ctx)((content, Nil))
+
+    val solver = new FairZ3Solver(ctx)
+    solver.setProgram(program)
 
     for ((f, ps) <- results; p <- ps) {
       test("Synthesizing %3d: %-20s [%s]".format(nextInt(), f.id.toString, title)) {
-        block(solver, f, p)
+        val sctx = SynthesisContext(opts, Some(f), program, solver, new DefaultReporter, new java.util.concurrent.atomic.AtomicBoolean)
+
+        block(sctx, f, p)
       }
     }
   }
@@ -99,9 +104,7 @@ object Injection {
 }
     """
   ) {
-    case (solver, fd, p) =>
-      val sctx = SynthesisContext(solver, new SilentReporter, new java.util.concurrent.atomic.AtomicBoolean)
-
+    case (sctx, fd, p) =>
       assertAllAlternativesSucceed(sctx, rules.CEGIS.instantiateOn(sctx, p))
       assertFastEnough(sctx, rules.CEGIS.instantiateOn(sctx, p), 100)
   }
@@ -127,9 +130,7 @@ object Injection {
 }
     """
   ) {
-    case (solver, fd, p) =>
-      val sctx = SynthesisContext(solver, new DefaultReporter, new java.util.concurrent.atomic.AtomicBoolean)
-
+    case (sctx, fd, p) =>
       rules.CEGIS.instantiateOn(sctx, p).head.apply(sctx) match {
         case RuleSuccess(sol) =>
           assert(false, "CEGIS should have failed, but found : %s".format(sol))
diff --git a/testcases/synthesis/CegisFunctions.scala b/testcases/synthesis/CegisFunctions.scala
new file mode 100644
index 000000000..a3451d01d
--- /dev/null
+++ b/testcases/synthesis/CegisFunctions.scala
@@ -0,0 +1,30 @@
+import leon.Utils._
+
+object CegisTests {
+  sealed abstract class List
+  case class Cons(head: Int, tail: List) extends List
+  case class Nil() extends List
+
+  // proved with unrolling=0
+  def size(l: List) : Int = (l match {
+      case Nil() => 0
+      case Cons(_, t) => 1 + size(t)
+  }) ensuring(res => res >= 0)
+
+  def content(l: List): Set[Int] = l match {
+    case Nil() => Set()
+    case Cons(i, t) => Set(i) ++ content(t)
+  }
+
+  def insert(l: List, i: Int) = {
+    Cons(i, l)
+  }.ensuring(res => size(res) == size(l)+1 && content(res) == content(l) ++ Set(i))
+
+  def testInsert(l: List, i: Int) = {
+    choose { (o: List) => size(o) == size(l) + 1 }
+  }
+
+  def testDelete(l: List, i: Int) = {
+    choose { (o: List) => size(o) == size(l) - 1 }
+  }
+}
-- 
GitLab