diff --git a/src/main/scala/leon/synthesis/rules/Cegis.scala b/src/main/scala/leon/synthesis/rules/Cegis.scala index 13e9843a47c8092c565c1f8e17dd327381732d81..80171e36db66787a306b93b582e241f8b5dd76e3 100644 --- a/src/main/scala/leon/synthesis/rules/Cegis.scala +++ b/src/main/scala/leon/synthesis/rules/Cegis.scala @@ -64,13 +64,14 @@ case object CEGIS extends Rule("CEGIS", 150) { p.as.filter(a => isSubtypeOf(a.getType, t)).map(id => (Variable(id) : Expr, Set[Identifier]())) } - case class TentativeFormula(pathcond: Expr, - phi: Expr, - program: Expr, - mappings: Map[Identifier, (Identifier, Expr)], - recTerms: Map[Identifier, Set[Identifier]]) { - def unroll: TentativeFormula = { - var newProgram = List[Expr]() + class TentativeFormula(val pathcond: Expr, + val phi: Expr, + var program: Expr, + var mappings: Map[Identifier, (Identifier, Expr)], + var recTerms: Map[Identifier, Set[Identifier]]) { + + def unroll: (List[Expr], Set[Identifier]) = { + var newClauses = List[Expr]() var newRecTerms = Map[Identifier, Set[Identifier]]() var newMappings = Map[Identifier, (Identifier, Expr)]() @@ -98,10 +99,14 @@ case object CEGIS extends Rule("CEGIS", 150) { Implies(Variable(bid), Equals(Variable(recId), ex)) } - newProgram = newProgram ::: pre :: cases + newClauses = newClauses ::: pre :: cases } - TentativeFormula(pathcond, phi, And(program :: newProgram), mappings ++ newMappings, newRecTerms) + program = And(program :: newClauses) + mappings = mappings ++ newMappings + recTerms = newRecTerms + + (newClauses, newRecTerms.keySet) } def bounds = recTerms.keySet.map(id => Not(Variable(id))).toList @@ -124,78 +129,74 @@ case object CEGIS extends Rule("CEGIS", 150) { var ass = p.as.toSet var xss = p.xs.toSet - var lastF = TentativeFormula(p.pc, p.phi, BooleanLiteral(true), Map(), Map() ++ p.xs.map(x => x -> Set(x))) - var currentF = lastF.unroll + val unrolling = new TentativeFormula(p.pc, p.phi, BooleanLiteral(true), Map(), Map() ++ p.xs.map(x => x -> Set(x))) var unrolings = 0 val maxUnrolings = 3 var predicates: Seq[Expr] = Seq() + + + val mainSolver: FairZ3Solver = sctx.solver.asInstanceOf[FairZ3Solver] + + // solver1 is used for the initial SAT queries + val solver1 = mainSolver.getNewSolver + solver1.assertCnstr(And(p.pc, p.phi)) + + // solver2 is used for the CE search + val solver2 = mainSolver.getNewSolver + solver2.assertCnstr(And(p.pc :: Not(p.phi) :: Nil)) + try { do { + val (clauses, bounds) = unrolling.unroll + //println("UNROLLING: "+clauses+" WITH BOUNDS "+bounds) + solver1.assertCnstr(And(clauses)) + solver2.assertCnstr(And(clauses)) + //println("="*80) //println("Was: "+lastF.entireFormula) //println("Now Trying : "+currentF.entireFormula) val tpe = TupleType(p.xs.map(_.getType)) - val bss = currentF.bss - - var continue = true + val bss = unrolling.bss - val mainSolver: FairZ3Solver = sctx.solver.asInstanceOf[FairZ3Solver] - - // solver1 is used for the initial SAT queries - val solver1 = mainSolver.getNewSolver - - val basePhi = currentF.entireFormula - solver1.assertCnstr(basePhi) - - // solver2 is used for the CE search - val solver2 = mainSolver.getNewSolver - solver2.assertCnstr(And(currentF.pathcond :: currentF.program :: Not(currentF.phi) :: Nil)) - - // solver3 is used for the unsatcore search - val solver3 = mainSolver.getNewSolver - solver3.assertCnstr(And(currentF.pathcond :: currentF.program :: currentF.phi :: Nil)) + var continue = !clauses.isEmpty while (result.isEmpty && continue) { //println("-"*80) //println(basePhi) //println("To satisfy: "+constrainedPhi) - solver1.check match { + solver1.checkAssumptions(bounds.map(id => Not(Variable(id)))) match { case Some(true) => val satModel = solver1.getModel //println("Found solution: "+satModel) //println("Corresponding program: "+simplifyTautologies(synth.solver)(valuateWithModelIn(currentF.program, bss, satModel))) - val fixedBss = And(bss.map(b => Equals(Variable(b), satModel(b))).toSeq) + //val fixedBss = And(bss.map(b => Equals(Variable(b), satModel(b))).toSeq) //println("Phi with fixed sat bss: "+fixedBss) - solver2.push() - solver2.assertCnstr(fixedBss) + val bssAssumptions: Set[Expr] = bss.map(b => satModel(b) match { + case BooleanLiteral(true) => Variable(b) + case BooleanLiteral(false) => Not(Variable(b)) + }) //println("FORMULA: "+And(currentF.pathcond :: currentF.program :: Not(currentF.phi) :: fixedBss :: Nil)) //println("#"*80) - solver2.check match { + solver2.checkAssumptions(bssAssumptions) match { case Some(true) => //println("#"*80) val invalidModel = solver2.getModel val fixedAss = And(ass.map(a => Equals(Variable(a), invalidModel(a))).toSeq) + solver1.push() + solver1.assertCnstr(fixedAss) //println("Found counter example: "+fixedAss) - solver3.push() - solver3.assertCnstr(fixedAss) - - val bssAssumptions: Set[Expr] = bss.toSet.map { b: Identifier => satModel(b) match { - case BooleanLiteral(true) => Variable(b) - case BooleanLiteral(false) => Not(Variable(b)) - }} - - val unsatCore = solver3.checkAssumptions(bssAssumptions) match { + val unsatCore = solver1.checkAssumptions(bssAssumptions) match { case Some(false) => - val core = solver3.getUnsatCore + val core = solver1.getUnsatCore //println("Formula: "+mustBeUnsat) //println("Core: "+core) //println(synth.solver.solveSAT(And(mustBeUnsat +: bssAssumptions.toSeq))) @@ -211,19 +212,21 @@ case object CEGIS extends Rule("CEGIS", 150) { bssAssumptions } - solver3.pop() + solver1.pop() - val freshCss = currentF.css.map(c => c -> Variable(FreshIdentifier(c.name, true).setType(c.getType))).toMap + val freshCss = unrolling.css.map(c => c -> Variable(FreshIdentifier(c.name, true).setType(c.getType))).toMap val ceIn = ass.map(id => id -> invalidModel(id)) - val counterexemple = substAll(freshCss ++ ceIn, And(Seq(currentF.program, currentF.phi))) + val counterexemple = substAll(freshCss ++ ceIn, And(Seq(unrolling.program, unrolling.phi))) //println("#"*80) //println(currentF.phi) //println(substAll(freshCss ++ ceIn, currentF.phi)) // Found as such as the xs break, refine predicates + solver1.assertCnstr(counterexemple) + solver2.assertCnstr(counterexemple) if (unsatCore.isEmpty) { continue = false @@ -236,7 +239,7 @@ case object CEGIS extends Rule("CEGIS", 150) { //println("#"*80) //println("UNSAT!") //println("Sat model: "+satModel.toSeq.sortBy(_._1.toString).map{ case (id, v) => id+" -> "+v }.mkString(", ")) - var mapping = currentF.mappings.filterKeys(satModel.mapValues(_ == BooleanLiteral(true))).values.toMap + var mapping = unrolling.mappings.filterKeys(satModel.mapValues(_ == BooleanLiteral(true))).values.toMap // Resolve mapping @@ -251,8 +254,6 @@ case object CEGIS extends Rule("CEGIS", 150) { continue = false } - solver2.pop() - case Some(false) => //println("%%%% UNSAT") continue = false @@ -262,10 +263,8 @@ case object CEGIS extends Rule("CEGIS", 150) { } } - lastF = currentF - currentF = currentF.unroll unrolings += 1 - } while(unrolings < maxUnrolings && lastF != currentF && result.isEmpty) + } while(unrolings < maxUnrolings && result.isEmpty) result.getOrElse(RuleApplicationImpossible)