diff --git a/src/main/scala/leon/synthesis/rules/Cegis.scala b/src/main/scala/leon/synthesis/rules/Cegis.scala index 623c870c64b83d39c774032c8893e98b8865fc28..d2d5da2500353534fca251008cde9a185d24f3ef 100644 --- a/src/main/scala/leon/synthesis/rules/Cegis.scala +++ b/src/main/scala/leon/synthesis/rules/Cegis.scala @@ -4,6 +4,7 @@ package rules import solvers.TimeoutSolver import purescala.Trees._ +import purescala.DataGen import purescala.Common._ import purescala.Definitions._ import purescala.TypeTrees._ @@ -11,6 +12,8 @@ import purescala.TreeOps._ import purescala.Extractors._ import purescala.ScalaPrinter +import scala.collection.mutable.{Map=>MutableMap} + import evaluators._ import solvers.z3.FairZ3Solver @@ -22,10 +25,15 @@ case object CEGIS extends Rule("CEGIS") { val useCEAsserts = false val useUninterpretedProbe = false val useUnsatCores = true + val useOptTimeout = true val useFunGenerators = sctx.options.cegisGenerateFunCalls val useBPaths = sctx.options.cegisUseBPaths val useCETests = sctx.options.cegisUseCETests val useCEPruning = sctx.options.cegisUseCEPruning + // Limits the number of programs CEGIS will specifically test for instead of reasonning symbolically + val testUpTo = 5 + val useBssFiltering = true + val filterThreshold = 1.0/2 val evaluator = new CodeGenEvaluator(sctx.context, sctx.program) case class Generator(tpe: TypeTree, altBuilder: () => List[(Expr, Set[Identifier])]); @@ -79,6 +87,8 @@ case object CEGIS extends Rule("CEGIS") { p.as.filter(a => isSubtypeOf(a.getType, t)).map(id => (Variable(id) : Expr, Set[Identifier]())) } + val funcCache: MutableMap[TypeTree, Seq[FunDef]] = MutableMap.empty + def funcAlternatives(t: TypeTree): List[(Expr, Set[Identifier])] = { if (useFunGenerators) { def isCandidate(fd: FunDef): Boolean = { @@ -104,11 +114,19 @@ case object CEGIS extends Rule("CEGIS") { isSubtypeOf(fd.returnType, t) && !isRecursiveCall && isNotSynthesizable } - sctx.program.definedFunctions.filter(isCandidate).map{ fd => - val ids = fd.args.map(vd => FreshIdentifier("c", true).setType(vd.getType)) + val funcs = funcCache.get(t) match { + case Some(alts) => + alts + case None => + val alts = sctx.program.definedFunctions.filter(isCandidate) + funcCache += t -> alts + alts + } - (FunctionInvocation(fd, ids.map(Variable(_))), ids.toSet) - }.toList + funcs.map{ fd => + val ids = fd.args.map(vd => FreshIdentifier("c", true).setType(vd.getType)) + (FunctionInvocation(fd, ids.map(Variable(_))), ids.toSet) + }.toList } else { Nil } @@ -187,8 +205,8 @@ case object CEGIS extends Rule("CEGIS") { res == BooleanLiteral(true) case EvaluationResults.RuntimeError(err) => - sctx.reporter.error("Error testing CE: "+err) - true + //sctx.reporter.error("Error testing CE: "+err) + false case EvaluationResults.EvaluatorError(err) => sctx.reporter.error("Error testing CE: "+err) @@ -249,10 +267,10 @@ case object CEGIS extends Rule("CEGIS") { val simplerRes = simplifyLets(res) - // println("COMPILATION RESULT: ") - // println(ScalaPrinter(simplerRes)) - // println("BSS: "+bssOrdered) - // println("FREE: "+variablesOf(simplerRes)) + //println("COMPILATION RESULT: ") + //println(ScalaPrinter(simplerRes)) + //println("BSS: "+bssOrdered) + //println("FREE: "+variablesOf(simplerRes)) def compileWithArray(): Option[(Seq[Expr], Seq[Expr]) => EvaluationResult] = { val ba = FreshIdentifier("bssArray").setType(ArrayType(BooleanType)) @@ -309,6 +327,54 @@ case object CEGIS extends Rule("CEGIS") { } + def filterFor(remainingBss: Set[Identifier]): Seq[Expr] = { + val filteredBss = remainingBss + initGuard + + // The following code is black-magic, read with caution + mappings = mappings.filterKeys(filteredBss) + guardedTerms = Map() + bTree = bTree.filterKeys(filteredBss) + bTree = bTree.mapValues(cToBs => cToBs.mapValues(bs => bs & filteredBss)) + + val filteredCss = mappings.map(_._2._1).toSet + cChildren = cChildren.filterKeys(filteredCss) + cChildren = cChildren.mapValues(css => css & filteredCss) + for (c <- filteredCss) { + if (!(cChildren contains c)) { + cChildren += c -> Set() + } + } + + // Finally, we reset the state of the evaluator + triedCompilation = false + progEvaluator = None + + // We need to regenerate clauses for each b + val pathConstraints = for ((parentGuard, cToBs) <- bTree; (c, bs) <- cToBs) yield { + val bvs = bs.toList.map(Variable(_)) + + val failedPath = Not(Variable(parentGuard)) + + val distinct = bvs.combinations(2).collect { + case List(a, b) => + Or(Not(a) :: Not(b) :: Nil) + } + + And(Seq(Or(failedPath :: bvs), Implies(failedPath, And(bvs.map(Not(_))))) ++ distinct) + } + + // Generate all the b => c = ... + val impliess = mappings.map { case (bid, (recId, ex)) => + Implies(Variable(bid), Equals(Variable(recId), ex)) + } + + //for (i <- impliess) { + // println(": "+i) + //} + + (pathConstraints ++ impliess).toSeq + } + def unroll: (List[Expr], Set[Identifier]) = { var newClauses = List[Expr]() var newGuardedTerms = Map[Identifier, Set[Identifier]]() @@ -356,7 +422,7 @@ case object CEGIS extends Rule("CEGIS") { guardedTerms = newGuardedTerms - // Finally, we reset the state of the evalautor + // Finally, we reset the state of the evaluator triedCompilation = false progEvaluator = None @@ -380,7 +446,8 @@ case object CEGIS extends Rule("CEGIS") { var unrolings = 0 val maxUnrolings = 3 - val mainSolver = new TimeoutSolver(sctx.solver, 10000L) // 10sec + val exSolver = new TimeoutSolver(sctx.solver, 3000L) // 3sec + val cexSolver = new TimeoutSolver(sctx.solver, 3000L) // 3sec var exampleInputs = Set[Seq[Expr]]() @@ -388,7 +455,7 @@ case object CEGIS extends Rule("CEGIS") { if (p.pc == BooleanLiteral(true)) { exampleInputs += p.as.map(a => simplestValue(a.getType)) } else { - val solver = mainSolver.getNewSolver + val solver = exSolver.getNewSolver solver.assertCnstr(p.pc) @@ -404,85 +471,147 @@ case object CEGIS extends Rule("CEGIS") { sctx.reporter.warning("Solver could not solve path-condition") return RuleApplicationImpossible // This is not necessary though, but probably wanted } + + } + + val discoveredInputs = DataGen.findModels(p.pc, evaluator, 20, 1000, forcedFreeVars = Some(p.as)).map{ + m => p.as.map(a => m(a)) + } + + def checkForPrograms(programs: Set[Set[Identifier]]): RuleApplicationResult = { + for (prog <- programs) { + val expr = ndProgram.determinize(prog) + val res = Equals(Tuple(p.xs.map(Variable(_))), expr) + val solver3 = cexSolver.getNewSolver + solver3.assertCnstr(And(p.pc :: res :: Not(p.phi) :: Nil)) + + solver3.check match { + case Some(false) => + return RuleSuccess(Solution(BooleanLiteral(true), Set(), expr), isTrusted = true) + case None => + return RuleSuccess(Solution(BooleanLiteral(true), Set(), expr), isTrusted = false) + case Some(true) => + // invalid program, we skip + } + } + + RuleApplicationImpossible } + // println("Generating tests..") + // println("Found: "+discoveredInputs.size) + exampleInputs ++= discoveredInputs + // Keep track of collected cores to filter programs to test var collectedCores = Set[Set[Identifier]]() + val initExClause = And(p.pc :: p.phi :: Variable(initGuard) :: Nil) + val initCExClause = And(p.pc :: Not(p.phi) :: Variable(initGuard) :: Nil) + // solver1 is used for the initial SAT queries - var solver1 = mainSolver.getNewSolver - solver1.assertCnstr(And(p.pc :: p.phi :: Variable(initGuard) :: Nil)) + var solver1 = exSolver.getNewSolver + solver1.assertCnstr(initExClause) // solver2 is used for validating a candidate program, or finding new inputs - val solver2 = mainSolver.getNewSolver - solver2.assertCnstr(And(p.pc :: Not(p.phi) :: Variable(initGuard) :: Nil)) + var solver2 = cexSolver.getNewSolver + solver2.assertCnstr(initCExClause) + var didFilterAlready = false - var allClauses = List[Expr]() + val tpe = TupleType(p.xs.map(_.getType)) try { do { var needMoreUnrolling = false - // Compute all programs that have not been excluded yet - var allPrograms: Set[Set[Identifier]] = if (useCEPruning) { - ndProgram.allPrograms.filterNot(p => collectedCores.exists(c => c.subsetOf(p))) - } else { - Set() + var bssAssumptions = Set[Identifier]() + + if (!didFilterAlready) { + val (clauses, closedBs) = ndProgram.unroll + + bssAssumptions = closedBs + + //println("UNROLLING: ") + //for (c <- clauses) { + // println(" - " + c) + //} + //println("CLOSED Bs "+closedBs) + + val clause = And(clauses) + + solver1.assertCnstr(clause) + solver2.assertCnstr(clause) } - //println("Programs: "+allPrograms.size) - //println("CEs: "+exampleInputs.size) + // Compute all programs that have not been excluded yet + var prunedPrograms: Set[Set[Identifier]] = if (useCEPruning) { + ndProgram.allPrograms.filterNot(p => collectedCores.exists(c => c.subsetOf(p))) + } else { + Set() + } + + val allPrograms = prunedPrograms.size + + //println("Programs: "+prunedPrograms.size) + //println("#Tests: "+exampleInputs.size) // We further filter the set of working programs to remove those that fail on known examples if (useCEPruning && !exampleInputs.isEmpty && ndProgram.canTest()) { - //for (ce <- exampleInputs) { - // println("CE: "+ce) - //} - - for (p <- allPrograms) { + for (p <- prunedPrograms) { if (!exampleInputs.forall(ndProgram.testForProgram(p))) { // This program failed on at least one example solver1.assertCnstr(Not(And(p.map(Variable(_)).toSeq))) - allPrograms -= p + prunedPrograms -= p } } - if (allPrograms.isEmpty) { + if (prunedPrograms.isEmpty) { needMoreUnrolling = true } - //println("Passing tests: "+allPrograms.size) + //println("Passing tests: "+prunedPrograms.size) } - //allPrograms.foreach { p => - // println("PATH: "+p) - // println("CLAUSES: "+p.flatMap( b => ndProgram.mappings.get(b).map{ case (c, ex) => c+" = "+ex}).mkString(" && ")) - //} - - val (clauses, closedBs) = ndProgram.unroll - //println("UNROLLING: ") - //for (c <- clauses) { - // println(" - " + c) - //} - //println("CLOSED Bs "+closedBs) - - val clause = And(clauses) - allClauses = clause :: allClauses + val nPassing = prunedPrograms.size + + if (nPassing == 0) { + needMoreUnrolling = true; + } else if (nPassing <= testUpTo) { + // Immediate Test + result = Some(checkForPrograms(prunedPrograms)) + } else if (((nPassing < allPrograms*filterThreshold) || didFilterAlready) && useBssFiltering) { + // We filter the Bss so that the formula we give to z3 is much smalled + val bssToKeep = prunedPrograms.foldLeft(Set[Identifier]())(_ ++ _) + //println("To Keep: "+bssToKeep.size+"/"+ndProgram.bss.size) + + // Cannot unroll normally after having filtered, so we need to + // repeat the filtering procedure at next unrolling. + didFilterAlready = true + + // Freshening solvers + solver1 = exSolver.getNewSolver + solver1.assertCnstr(initExClause) + solver2 = cexSolver.getNewSolver + solver2.assertCnstr(initCExClause) + + val clauses = ndProgram.filterFor(bssToKeep) + val clause = And(clauses) + + solver1.assertCnstr(clause) + solver2.assertCnstr(clause) + + //println("Filtered clauses:") + //for (c <- clauses) { + // println(" - " + c) + //} - solver1.assertCnstr(clause) - solver2.assertCnstr(clause) + } - val tpe = TupleType(p.xs.map(_.getType)) val bss = ndProgram.bss - if (clauses.isEmpty) { - needMoreUnrolling = true - } - while (result.isEmpty && !needMoreUnrolling && !sctx.shouldStop.get) { - solver1.checkAssumptions(closedBs.map(id => Not(Variable(id)))) match { + solver1.checkAssumptions(bssAssumptions.map(id => Not(Variable(id)))) match { case Some(true) => val satModel = solver1.getModel @@ -518,6 +647,7 @@ case object CEGIS extends Rule("CEGIS") { } if (validateWithZ3) { + //println("Looking for CE...") solver2.checkAssumptions(bssAssumptions) match { case Some(true) => //println("#"*80) @@ -535,7 +665,7 @@ case object CEGIS extends Rule("CEGIS") { // Retest whether the newly found C-E invalidates all programs if (useCEPruning && ndProgram.canTest) { - if (allPrograms.forall(p => !ndProgram.testForProgram(p)(newCE))) { + if (prunedPrograms.forall(p => !ndProgram.testForProgram(p)(newCE))) { // println("I found a killer example!") needMoreUnrolling = true } @@ -601,7 +731,14 @@ case object CEGIS extends Rule("CEGIS") { result = Some(RuleSuccess(Solution(BooleanLiteral(true), Set(), expr))) case _ => - return RuleApplicationImpossible + if (useOptTimeout) { + // Interpret timeout in CE search as "the candidate is valid" + sctx.reporter.info("CEGIS could not prove the validity of the resulting expression") + val expr = ndProgram.determinize(satModel.filter(_._2 == BooleanLiteral(true)).keySet) + result = Some(RuleSuccess(Solution(BooleanLiteral(true), Set(), expr), isTrusted = false)) + } else { + return RuleApplicationImpossible + } } } @@ -622,8 +759,8 @@ case object CEGIS extends Rule("CEGIS") { needMoreUnrolling = true case _ => - //println("%%%% WOOPS") - return RuleApplicationImpossible + // Last chance, we test first few programs + return checkForPrograms(prunedPrograms.take(testUpTo)) } }