diff --git a/src/main/scala/leon/solvers/IncrementalSolver.scala b/src/main/scala/leon/solvers/IncrementalSolver.scala index 8a8e5de49efde3657c4da59dbd218a9b9189cd6e..968eb83905a9dea7a6c236da4b77f7b12e17adce 100644 --- a/src/main/scala/leon/solvers/IncrementalSolver.scala +++ b/src/main/scala/leon/solvers/IncrementalSolver.scala @@ -19,7 +19,7 @@ trait IncrementalSolver { def assertCnstr(expression: Expr): Unit def check: Option[Boolean] - def checkAssumptions(assumptions: Seq[Expr]): Option[Boolean] + def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] def getModel: Map[Identifier, Expr] def getUnsatCore: Set[Expr] } @@ -50,8 +50,8 @@ trait NaiveIncrementalSolver extends IncrementalSolverBuilder { solveSAT(And(allConstraints()))._1 } - def checkAssumptions(assumptions: Seq[Expr]): Option[Boolean] = { - solveSAT(And(assumptions ++ allConstraints()))._1 match { + def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] = { + solveSAT(And((assumptions ++ allConstraints()).toSeq))._1 match { case Some(true) => unsatCore = Set[Expr]() Some(true) diff --git a/src/main/scala/leon/solvers/z3/FairZ3Solver.scala b/src/main/scala/leon/solvers/z3/FairZ3Solver.scala index 0660dd26ce43fd2410786b8156449a7307a67ef6..71e26b83f9b50af253607fa5c2680623a92ebc56 100644 --- a/src/main/scala/leon/solvers/z3/FairZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/FairZ3Solver.scala @@ -694,6 +694,8 @@ class FairZ3Solver(context : LeonContext) extends Solver(context) with AbstractZ def getNewSolver = new solvers.IncrementalSolver { val solver = z3.mkSolver + private var varsInVC = Set[Identifier]() + def push() { solver.push } @@ -703,6 +705,7 @@ class FairZ3Solver(context : LeonContext) extends Solver(context) with AbstractZ } def assertCnstr(expression: Expr) { + varsInVC ++= variablesOf(expression) solver.assertCnstr(toZ3Formula(expression).get) } @@ -710,8 +713,8 @@ class FairZ3Solver(context : LeonContext) extends Solver(context) with AbstractZ solver.check } - def checkAssumptions(assumptions: Seq[Expr]): Option[Boolean] = { - solver.checkAssumptions(assumptions.map(toZ3Formula(_).get) : _*) + def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] = { + solver.checkAssumptions(assumptions.toSeq.map(toZ3Formula(_).get) : _*) } def getModel = { diff --git a/src/main/scala/leon/solvers/z3/UninterpretedZ3Solver.scala b/src/main/scala/leon/solvers/z3/UninterpretedZ3Solver.scala index 2d40fcf1da0859685f8ef39b9270d080adb71728..16f3bf92af8b53ef072025dc720a1897b48d47e5 100644 --- a/src/main/scala/leon/solvers/z3/UninterpretedZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/UninterpretedZ3Solver.scala @@ -103,9 +103,9 @@ class UninterpretedZ3Solver(context : LeonContext) extends Solver(context) with solver.check } - def checkAssumptions(assumptions: Seq[Expr]): Option[Boolean] = { + def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] = { variables ++= assumptions.flatMap(variablesOf(_)) - solver.checkAssumptions(assumptions.map(toZ3Formula(_).get) : _*) + solver.checkAssumptions(assumptions.toSeq.map(toZ3Formula(_).get) : _*) } def getModel = { diff --git a/src/main/scala/leon/synthesis/rules/Cegis.scala b/src/main/scala/leon/synthesis/rules/Cegis.scala index 41974dccd80ff7028c7f1f50ae78f425d30e279b..98647bbf92676ba434c7c9b5fd1f85abb1d5175b 100644 --- a/src/main/scala/leon/synthesis/rules/Cegis.scala +++ b/src/main/scala/leon/synthesis/rules/Cegis.scala @@ -138,36 +138,52 @@ case object CEGIS extends Rule("CEGIS", 150) { var continue = true + // solver1 is used for the initial SAT queries + val solver1 = sctx.solver.getNewSolver + + val basePhi = currentF.entireFormula + solver1.assertCnstr(basePhi) + + // solver2 is used for the CE search + val solver2 = sctx.solver.getNewSolver + solver2.assertCnstr(And(currentF.pathcond :: currentF.program :: Not(currentF.phi) :: Nil)) + + // solver3 is used for the unsatcore search + val solver3 = sctx.solver.getNewSolver + solver3.assertCnstr(And(currentF.pathcond :: currentF.program :: currentF.phi :: Nil)) + while (result.isEmpty && continue) { - val basePhi = currentF.entireFormula - val constrainedPhi = And(basePhi +: predicates) //println("-"*80) //println("To satisfy: "+constrainedPhi) - sctx.solver.solveSAT(constrainedPhi) match { - case (Some(true), satModel) => - //println("Found candidate!: "+satModel.filterKeys(bss)) + solver1.check match { + case Some(true) => + val satModel = solver1.getModel //println("Corresponding program: "+simplifyTautologies(synth.solver)(valuateWithModelIn(currentF.program, bss, satModel))) val fixedBss = And(bss.map(b => Equals(Variable(b), satModel(b))).toSeq) //println("Phi with fixed sat bss: "+fixedBss) - val counterPhi = And(Seq(currentF.pathcond, fixedBss, currentF.program, Not(currentF.phi))) + solver2.push() + solver2.assertCnstr(fixedBss) //println("Formula to validate: "+counterPhi) - sctx.solver.solveSAT(counterPhi) match { - case (Some(true), invalidModel) => - val fixedAss = And(ass.map(a => Equals(Variable(a), invalidModel(a))).toSeq) + solver2.check match { + case Some(true) => + val invalidModel = solver2.getModel + val fixedAss = And(ass.map(a => Equals(Variable(a), invalidModel(a))).toSeq) - val mustBeUnsat = And(currentF.pathcond :: currentF.program :: fixedAss :: currentF.phi :: Nil) + solver3.push() + solver3.assertCnstr(fixedAss) val bssAssumptions: Set[Expr] = bss.toSet.map { b: Identifier => satModel(b) match { case BooleanLiteral(true) => Variable(b) case BooleanLiteral(false) => Not(Variable(b)) }} - val unsatCore = sctx.solver.solveSATWithCores(mustBeUnsat, bssAssumptions) match { - case ((Some(false), _, core)) => + val unsatCore = solver3.checkAssumptions(bssAssumptions) match { + case Some(false) => + val core = solver3.getUnsatCore //println("Formula: "+mustBeUnsat) //println("Core: "+core) //println(synth.solver.solveSAT(And(mustBeUnsat +: bssAssumptions.toSeq))) @@ -183,6 +199,8 @@ case object CEGIS extends Rule("CEGIS", 150) { bssAssumptions } + solver3.pop() + val freshCss = currentF.css.map(c => c -> Variable(FreshIdentifier(c.name, true).setType(c.getType))).toMap val ceIn = ass.map(id => id -> invalidModel(id)) @@ -199,10 +217,10 @@ case object CEGIS extends Rule("CEGIS", 150) { continue = false } else { //predicates = Not(And(unsatCore.toSeq)) +: counterexemple +: predicates - predicates = Not(And(unsatCore.toSeq)) +: predicates + solver1.assertCnstr(Not(And(unsatCore.toSeq))) } - case (Some(false), _) => + case Some(false) => //println("Sat model: "+satModel.toSeq.sortBy(_._1.toString).map{ case (id, v) => id+" -> "+v }.mkString(", ")) var mapping = currentF.mappings.filterKeys(satModel.mapValues(_ == BooleanLiteral(true))).values.toMap @@ -220,7 +238,9 @@ case object CEGIS extends Rule("CEGIS", 150) { continue = false } - case (Some(false), _) => + solver2.pop() + + case Some(false) => //println("%%%% UNSAT") continue = false case _ => @@ -239,6 +259,7 @@ case object CEGIS extends Rule("CEGIS", 150) { } catch { case e: Throwable => sctx.reporter.warning("CEGIS crashed: "+e.getMessage) + e.printStackTrace RuleApplicationImpossible }