diff --git a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala index f457bb4b39c266b1f45e530495187b5ee92fad5d..fa5ebb6c3c6458e14897341a7a582a5994ea2813 100644 --- a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala +++ b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala @@ -47,13 +47,13 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int } def initRC(mappings: Map[Identifier, Expr]): RC - def initGC: GC + def initGC(): GC private[this] var clpCache = Map[(Choose, Seq[Expr]), Expr]() def eval(ex: Expr, mappings: Map[Identifier, Expr]) = { try { - lastGC = Some(initGC) + lastGC = Some(initGC()) ctx.timers.evaluators.recursive.runtime.start() EvaluationResults.Successful(e(ex)(initRC(mappings), lastGC.get)) } catch { @@ -78,7 +78,7 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int case Some(v) => v case None => - throw EvalError("No value for identifier " + id.name + " in mapping.") + throw EvalError("No value for identifier " + id.asString(ctx) + " in mapping.") } case Application(caller, args) => @@ -145,8 +145,8 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int val evArgs = args map e // build a mapping for the function... - val frame = rctx.newVars(tfd.paramSubst(evArgs)) - + val frame = rctx.withNewVars(tfd.paramSubst(evArgs)) + if(tfd.hasPrecondition) { e(tfd.precondition.get)(frame, gctx) match { case BooleanLiteral(true) => diff --git a/src/main/scala/leon/purescala/DefOps.scala b/src/main/scala/leon/purescala/DefOps.scala index e7acc9a42f34f3b014ded101a86484d8747cac83..64a0dca50c0b66ceb685be331588fa703f2f62e3 100644 --- a/src/main/scala/leon/purescala/DefOps.scala +++ b/src/main/scala/leon/purescala/DefOps.scala @@ -283,14 +283,6 @@ object DefOps { fdMapCache(fd).getOrElse(fd) } - def replaceCalls(e: Expr): Expr = { - preMap { - case fi @ FunctionInvocation(TypedFunDef(fd, tps), args) => - fiMapF(fi, fdMap(fd)).map(_.setPos(fi)) - case _ => - None - }(e) - } val newP = p.copy(units = for (u <- p.units) yield { u.copy( @@ -300,7 +292,7 @@ object DefOps { df match { case f : FunDef => val newF = fdMap(f) - newF.fullBody = replaceCalls(newF.fullBody) + newF.fullBody = replaceFunCalls(newF.fullBody, fdMap, fiMapF) newF case d => d @@ -319,6 +311,15 @@ object DefOps { (newP, fdMapCache.collect{ case (ofd, Some(nfd)) => ofd -> nfd }) } + def replaceFunCalls(e: Expr, fdMapF: FunDef => FunDef, fiMapF: (FunctionInvocation, FunDef) => Option[Expr] = defaultFiMap) = { + preMap { + case fi @ FunctionInvocation(TypedFunDef(fd, tps), args) => + fiMapF(fi, fdMapF(fd)).map(_.setPos(fi)) + case _ => + None + }(e) + } + def addFunDefs(p: Program, fds: Traversable[FunDef], after: FunDef): Program = { var found = false val res = p.copy(units = for (u <- p.units) yield { diff --git a/src/main/scala/leon/repair/Repairman.scala b/src/main/scala/leon/repair/Repairman.scala index 20b95a30bb868e93db3311e14c64740ce016afac..0449334d3702fd4164274b7b494642c647db994f 100644 --- a/src/main/scala/leon/repair/Repairman.scala +++ b/src/main/scala/leon/repair/Repairman.scala @@ -83,19 +83,7 @@ class Repairman(ctx: LeonContext, initProgram: Program, fd: FunDef, verifTimeout for ((sol, i) <- solutions.zipWithIndex) { reporter.info(ASCIIHelpers.subTitle("Solution "+(i+1)+":")) val expr = sol.toSimplifiedExpr(ctx, synth.program) - reporter.info(ScalaPrinter(expr)) - } - reporter.info(ASCIIHelpers.title("In context:")) - - - for ((sol, i) <- solutions.zipWithIndex) { - reporter.info(ASCIIHelpers.subTitle("Solution "+(i+1)+":")) - val expr = sol.toSimplifiedExpr(ctx, synth.program) - val nfd = fd.duplicate - - nfd.body = Some(expr) - - reporter.info(ScalaPrinter(nfd)) + reporter.info(expr.asString(ctx)) } } } finally { diff --git a/src/main/scala/leon/synthesis/rules/CEGISLike.scala b/src/main/scala/leon/synthesis/rules/CEGISLike.scala index c0f44d23d2c02b3d2a26f95af835ffadf4063f1b..ca53ba5ae5f045bb1a451ae94508f16ecd6e42a3 100644 --- a/src/main/scala/leon/synthesis/rules/CEGISLike.scala +++ b/src/main/scala/leon/synthesis/rules/CEGISLike.scala @@ -23,13 +23,15 @@ import datagen._ import codegen.CodeGenParams import utils._ +import utils.ExpressionGrammars.{SizeBoundedGrammar, SizedLabel} +import bonsai.Generator abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { case class CegisParams( grammar: ExpressionGrammar[T], rootLabel: TypeTree => T, - maxUnfoldings: Int = 3 + maxUnfoldings: Int = 5 ) def getParams(sctx: SynthesisContext, p: Problem): CegisParams @@ -42,7 +44,8 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { val nProgramsLimit = 100000 val sctx = hctx.sctx - val ctx = sctx.context + implicit val ctx = sctx.context + // CEGIS Flags to activate or deactivate features val useOptTimeout = sctx.settings.cegisUseOptTimeout.getOrElse(true) @@ -63,9 +66,31 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { return Nil } - class NonDeterministicProgram(val p: Problem) { + class NonDeterministicProgram(val p: Problem, initTermSize: Int = 1) { + + private var termSize = 0; + + val grammar = ExpressionGrammars.SizeBoundedGrammar(params.grammar) + + def rootLabel(tpe: TypeTree) = SizedLabel(params.rootLabel(tpe), termSize) + + def xLabels = p.xs.map(x => rootLabel(x.getType)) + + var nAltsCache = Map[SizedLabel[T], Int]() + + def countAlternatives(l: SizedLabel[T]): Int = { + if (!(nAltsCache contains l)) { + val count = grammar.getProductions(l).map { + case Generator(subTrees, _) => subTrees.map(countAlternatives).product + }.sum + nAltsCache += l -> count + } + nAltsCache(l) + } - private val grammar = params.grammar + def allProgramsCount(): Int = { + xLabels.map(countAlternatives).product + } /** * Different view of the tree of expressions: @@ -84,60 +109,113 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { * (b3, H(c7, c8), Set(c7, c8)) * ) */ - private var cTree: Map[Identifier, Seq[(Identifier, Expr, Set[Identifier])]] = Map() + private var cTree: Map[Identifier, Seq[(Identifier, Seq[Expr] => Expr, Seq[Identifier])]] = Map() - /** - * Computes dependencies of c's - * - * c1 -> Set(c2, c3, c4, c5) - */ - private var cDeps: Map[Identifier, Set[Identifier]] = Map() - - /** - * Keeps track of blocked Bs and which C are affected, assuming cs are undefined: - * - * b2 -> Set(c4) - * b3 -> Set(c4) - */ - private var closedBs: Map[Identifier, Set[Identifier]] = Map() - /** - * Maps c identifiers to grammar labels - * - * Labels allows us to use grammars that are not only type-based - */ - private var labels: Map[Identifier, T] = Map() ++ p.xs.map(x => x -> params.rootLabel(x.getType)) + // C identifiers corresponding to p.xs + private var rootCs: Seq[Identifier] = Seq() private var bs: Set[Identifier] = Set() + private var bsOrdered: Seq[Identifier] = Seq() - /** - * Checks if 'b' is closed (meaning it depends on uninterpreted terms) - */ - def isBActive(b: Identifier) = !closedBs.contains(b) - def allProgramsCount(): Int = { - var nAltsCache = Map[Identifier, Int]() - - def nAltsFor(c: Identifier): Int = { - if (!(nAltsCache contains c)) { - val subs = for ((b, _, subcs) <- cTree(c) if isBActive(b)) yield { - if (subcs.isEmpty) { - 1 - } else { - subcs.toSeq.map(nAltsFor).product + + class CGenerator { + private var buffers = Map[SizedLabel[T], Stream[Identifier]]() + + private var slots = Map[SizedLabel[T], Int]().withDefaultValue(0) + + private def streamOf(t: SizedLabel[T]): Stream[Identifier] = { + FreshIdentifier(t.toString, t.getType, true) #:: streamOf(t) + } + + def rewind(): Unit = { + slots = Map[SizedLabel[T], Int]().withDefaultValue(0) + } + + def getNext(t: SizedLabel[T]) = { + if (!(buffers contains t)) { + buffers += t -> streamOf(t) + } + + val n = slots(t) + slots += t -> (n+1) + + buffers(t)(n) + } + } + + def init(): Unit = { + updateCTree() + } + + + def updateCTree(): Unit = { + def freshB() = { + val id = FreshIdentifier("B", BooleanType, true) + bs += id + id + } + + def defineCTreeFor(l: SizedLabel[T], c: Identifier): Unit = { + if (!(cTree contains c)) { + val cGen = new CGenerator() + + var alts = grammar.getProductions(l) + + val cTreeData = for (gen <- alts) yield { + val b = freshB() + + // Optimize labels + cGen.rewind() + + val subCs = for (sl <- gen.subTrees) yield { + val subC = cGen.getNext(sl) + defineCTreeFor(sl, subC) + subC } + + (b, gen.builder, subCs) } - nAltsCache += c -> subs.sum + cTree += c -> cTreeData } - nAltsCache(c) } - p.xs.map(nAltsFor).product + val cGen = new CGenerator() + + rootCs = for (l <- xLabels) yield { + val c = cGen.getNext(l) + defineCTreeFor(l, c) + c + } + + sctx.reporter.ifDebug { printer => + printer("Grammar so far:") + grammar.printProductions(printer) + } + + bsOrdered = bs.toSeq.sortBy(_.id) + + setCExpr(computeCExpr()) } + /** + * Keeps track of blocked Bs and which C are affected, assuming cs are undefined: + * + * b2 -> Set(c4) + * b3 -> Set(c4) + */ + private var closedBs: Map[Identifier, Set[Identifier]] = Map() + + /** + * Checks if 'b' is closed (meaning it depends on uninterpreted terms) + */ + def isBActive(b: Identifier) = !closedBs.contains(b) + + /** * Returns all possible assignments to Bs in order to enumerate all possible programs */ @@ -151,7 +229,7 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { var cache = Map[Identifier, Seq[Set[Identifier]]]() - def allProgramsFor(cs: Set[Identifier]): Seq[Set[Identifier]] = { + def allProgramsFor(cs: Seq[Identifier]): Seq[Set[Identifier]] = { val seqs = for (c <- cs.toSeq) yield { if (!(cache contains c)) { val subs = for ((b, _, subcs) <- cTree(c) if isBActive(b)) yield { @@ -173,100 +251,97 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { } } - allProgramsFor(p.xs.toSet) + allProgramsFor(rootCs) } - private def debugCExpr(cTree: Map[Identifier, Seq[(Identifier, Expr, Set[Identifier])]], + private def debugCTree(cTree: Map[Identifier, Seq[(Identifier, Seq[Expr] => Expr, Seq[Identifier])]], markedBs: Set[Identifier] = Set()): Unit = { println(" -- -- -- -- -- ") for ((c, alts) <- cTree) { println println(f"$c%-4s :=") - for ((b, ex, cs) <- alts ) { + for ((b, builder, cs) <- alts ) { val active = if (isBActive(b)) " " else "тип" val markS = if (markedBs(b)) Console.GREEN else "" val markE = if (markedBs(b)) Console.RESET else "" - println(f" $markS$active $b%-4s => $ex%-40s [$cs]$markE") + val ex = builder(cs.map(_.toVariable)).asString + + println(f" $markS$active ${b.asString}%-4s => $ex%-40s [${cs.map(_.asString).mkString(", ")}]$markE") } } } - private def computeCExpr(): Expr = { + private def computeCExpr(): (Expr, Seq[FunDef]) = { + var cToFd = Map[Identifier, FunDef]() - val lets = for ((c, alts) <- cTree) yield { - val activeAlts = alts.filter(a => isBActive(a._1)) + def exprOf(alt: (Identifier, Seq[Expr] => Expr, Seq[Identifier])): Expr = { + val (_, builder, cs) = alt - val expr = activeAlts.foldLeft(simplestValue(c.getType): Expr) { - case (e, (b, ex, _)) => IfExpr(b.toVariable, ex, e) - } + val e = builder(cs.map { c => + val fd = cToFd(c) + FunctionInvocation(fd.typed, fd.params.map(_.toVariable)) + }) - (c, expr) + outerExprToInnerExpr(e) } - // We order the lets base don dependencies - def defFor(c: Identifier): Expr = { - cDeps(c).filter(lets.contains).foldLeft(lets(c)) { - case (e, c) => Let(c, defFor(c), e) - } + // Define all C-def + for ((c, alts) <- cTree) yield { + cToFd += c -> new FunDef(FreshIdentifier(c.toString, alwaysShowUniqueID = true), + Seq(), + c.getType, + p.as.map(id => ValDef(id))) } - val res = tupleWrap(p.xs.map(defFor)) - - val substMap : Map[Expr,Expr] = bsOrdered.zipWithIndex.map { - case (b, i) => Variable(b) -> ArraySelect(bArrayId.toVariable, IntLiteral(i)) - }.toMap - - val simplerRes = simplifyLets(res) - - replace(substMap, simplerRes) - } + // Fill C-def bodies + for ((c, alts) <- cTree) { + val activeAlts = alts.filter(a => isBActive(a._1)) + val body = if (activeAlts.nonEmpty) { + activeAlts.init.foldLeft(exprOf(activeAlts.last)) { + case (e, alt) => IfExpr(alt._1.toVariable, exprOf(alt), e) + } + } else { + Error(c.getType, "Impossibru") + } - /** - * Information about the final Program representing CEGIS solutions at - * the current unfolding level - */ - private val outerSolution = { - val part = new PartialSolution(hctx.search.g, true) - e : Expr => part.solutionAround(hctx.currentNode)(e).getOrElse { - sctx.reporter.fatalError("Unable to create outer solution") + cToFd(c).fullBody = body } - } - private val bArrayId = FreshIdentifier("bArray", ArrayType(BooleanType), true) + // Top-level expression for rootCs + val expr = tupleWrap(rootCs.map { c => + val fd = cToFd(c) + FunctionInvocation(fd.typed, fd.params.map(_.toVariable)) + }) - private var cTreeFd = new FunDef( - FreshIdentifier("cTree", alwaysShowUniqueID = true), - Seq(), - p.outType, - p.as.map(id => ValDef(id)) - ) + (expr, cToFd.values.toSeq) + } - private var phiFd = new FunDef( - FreshIdentifier("phiFd", alwaysShowUniqueID = true), - Seq(), - BooleanType, - p.as.map(id => ValDef(id)) - ) - private var programCTree: Program = _ - // Map functions from original program to cTree program - private var fdMapCTree: Map[FunDef, FunDef] = _ + private val cTreeFd = new FunDef(FreshIdentifier("cTree", alwaysShowUniqueID = true), + Seq(), + p.outType, + p.as.map(id => ValDef(id)) + ) - private var tester: (Seq[Expr], Set[Identifier]) => EvaluationResults.Result = _ + private val phiFd = new FunDef(FreshIdentifier("phiFd", alwaysShowUniqueID = true), + Seq(), + BooleanType, + p.as.map(id => ValDef(id)) + ) - private def initializeCTreeProgram(): Unit = { - // CEGIS is solved by called cTree function (without bs yet) - val fullSol = outerSolution(FunctionInvocation(cTreeFd.typed, p.as.map(_.toVariable))) + private val (innerProgram, origFdMap) = { + val outerSolution = { + new PartialSolution(hctx.search.g, true) + .solutionAround(hctx.currentNode)(FunctionInvocation(cTreeFd.typed, p.as.map(_.toVariable))) + .getOrElse(ctx.reporter.fatalError("Unable to get outer solution")) + } - val chFd = hctx.ci.fd - val prog0 = hctx.program - - val affected = prog0.callGraph.transitiveCallers(chFd) ++ Set(chFd, cTreeFd, phiFd) ++ fullSol.defs + val program0 = addFunDefs(sctx.program, Seq(cTreeFd, phiFd) ++ outerSolution.defs, hctx.ci.fd) cTreeFd.body = None phiFd.body = Some( @@ -275,83 +350,75 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { p.phi) ) - val prog1 = addFunDefs(prog0, Seq(cTreeFd, phiFd) ++ fullSol.defs, chFd) - - val (prog2, fdMap2) = replaceFunDefs(prog1)({ - case fd if affected(fd) => - // Add the b array argument to all affected functions - val nfd = new FunDef( - fd.id.freshen, - fd.tparams, - fd.returnType, - fd.params :+ ValDef(bArrayId) - ) - nfd.copyContentFrom(fd) - nfd.copiedFrom(fd) - - if (fd == chFd) { - nfd.fullBody = replace(Map(hctx.ci.ch -> fullSol.guardedTerm), nfd.fullBody) - } + replaceFunDefs(program0){ + case fd if fd == hctx.ci.fd => + val nfd = fd.duplicate + + nfd.fullBody = postMap { + case ch if ch eq hctx.ci.ch => + Some(outerSolution.term) + + case _ => None + }(nfd.fullBody) Some(nfd) - case _ => - None - }, { - case (FunctionInvocation(old, args), newfd) if old.fd != newfd => - Some(FunctionInvocation(newfd.typed(old.tps), args :+ bArrayId.toVariable)) - case _ => + case `cTreeFd` | `phiFd` => None - }) - programCTree = prog2 - cTreeFd = fdMap2(cTreeFd) - phiFd = fdMap2(phiFd) - fdMapCTree = fdMap2 + case fd => + Some(fd.duplicate) + } + } - private def setCExpr(cTree: Expr): Unit = { + /** + * Since CEGIS works with a copy of the outer program, + * it needs to map outer function calls to inner function calls + * and vice-versa. 'inner' refers to the CEGIS-specific program, + * 'outer' refers to the actual program on which we do synthesis. + */ + private def outerExprToInnerExpr(e: Expr): Expr = { + replaceFunCalls(e, {fd => origFdMap.getOrElse(fd, fd) }) + } - cTreeFd.body = Some(preMap{ - case FunctionInvocation(TypedFunDef(fd, tps), args) if fdMapCTree contains fd => - Some(FunctionInvocation(fdMapCTree(fd).typed(tps), args :+ bArrayId.toVariable)) - case _ => - None - }(cTree)) + private val innerPc = outerExprToInnerExpr(p.pc) + private val innerPhi = outerExprToInnerExpr(p.phi) + + private var programCTree: Program = _ + private var tester: (Seq[Expr], Set[Identifier]) => EvaluationResults.Result = _ + + private def setCExpr(cTreeInfo: (Expr, Seq[FunDef])): Unit = { + val (cTree, newFds) = cTreeInfo + + cTreeFd.body = Some(cTree) + programCTree = addFunDefs(innerProgram, newFds, cTreeFd) //println("-- "*30) - //println(programCTree) + //println(programCTree.asString) //println(".. "*30) - val evaluator = new DualEvaluator(sctx.context, programCTree, CodeGenParams.default) - +// val evaluator = new DualEvaluator(sctx.context, programCTree, CodeGenParams.default) + val evaluator = new DefaultEvaluator(sctx.context, programCTree) tester = { (ins: Seq[Expr], bValues: Set[Identifier]) => - val bsValue = finiteArray(bsOrdered.map(b => BooleanLiteral(bValues(b))), None, BooleanType) - val args = ins :+ bsValue + val envMap = bs.map(b => b -> BooleanLiteral(bValues(b))).toMap - val fi = FunctionInvocation(phiFd.typed, args) + val fi = FunctionInvocation(phiFd.typed, ins) - evaluator.eval(fi, Map()) + evaluator.eval(fi, envMap) } } - private def updateCTree() { - if (programCTree eq null) { - initializeCTreeProgram() - } - - setCExpr(computeCExpr()) - } - def testForProgram(bValues: Set[Identifier])(ins: Seq[Expr]): Boolean = { tester(ins, bValues) match { case EvaluationResults.Successful(res) => res == BooleanLiteral(true) case EvaluationResults.RuntimeError(err) => + sctx.reporter.warning("RE testing CE: "+err) false case EvaluationResults.EvaluatorError(err) => @@ -362,79 +429,77 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { + // Returns the outer expression corresponding to a B-valuation def getExpr(bValues: Set[Identifier]): Expr = { def getCValue(c: Identifier): Expr = { cTree(c).find(i => bValues(i._1)).map { - case (b, ex, cs) => - val map = for (c <- cs) yield { - c -> getCValue(c) - } - - substAll(map.toMap, ex) + case (b, builder, cs) => + builder(cs.map(getCValue)) }.getOrElse { simplestValue(c.getType) } } - tupleWrap(p.xs.map(c => getCValue(c))) + tupleWrap(rootCs.map(c => getCValue(c))) } + /** + * Here we check the validity of a given program in isolation, we compute + * the corresponding expr and replace it in place of the C-tree + */ def validatePrograms(bss: Set[Set[Identifier]]): Either[Stream[Solution], Seq[Seq[Expr]]] = { - try { - val cexs = for (bs <- bss.toSeq) yield { - val sol = getExpr(bs) + val origImpl = cTreeFd.fullBody - val fullSol = outerSolution(sol) + val cexs = for (bs <- bss.toSeq) yield { + val outerSol = getExpr(bs) + val innerSol = outerExprToInnerExpr(outerSol) - val prog = addFunDefs(hctx.program, fullSol.defs, hctx.ci.fd) + cTreeFd.fullBody = innerSol - hctx.ci.ch.impl = Some(fullSol.guardedTerm) + val cnstr = and(innerPc, letTuple(p.xs, innerSol, Not(innerPhi))) - val cnstr = and(p.pc, letTuple(p.xs, sol, Not(p.phi))) - //println("Solving for: "+cnstr) + //println("Solving for: "+cnstr.asString) - val solverf = SolverFactory.default(ctx, prog).withTimeout(cexSolverTo) - val solver = solverf.getNewSolver() - try { - solver.assertCnstr(cnstr) - solver.check match { - case Some(true) => - excludeProgram(bs) - val model = solver.getModel - //println("Found counter example: ") - //for ((s, v) <- model) { - // println(" "+s.asString+" -> "+v.asString) - //} + val solverf = SolverFactory.default(ctx, innerProgram).withTimeout(cexSolverTo) + val solver = solverf.getNewSolver() + try { + solver.assertCnstr(cnstr) + solver.check match { + case Some(true) => + excludeProgram(bs) + val model = solver.getModel + //println("Found counter example: ") + //for ((s, v) <- model) { + // println(" "+s.asString+" -> "+v.asString) + //} - //val evaluator = new DefaultEvaluator(ctx, prog) - //println(evaluator.eval(cnstr, model)) + //val evaluator = new DefaultEvaluator(ctx, prog) + //println(evaluator.eval(cnstr, model)) - Some(p.as.map(a => model.getOrElse(a, simplestValue(a.getType)))) + Some(p.as.map(a => model.getOrElse(a, simplestValue(a.getType)))) - case Some(false) => - // UNSAT, valid program - return Left(Stream(Solution(BooleanLiteral(true), Set(), sol, true))) + case Some(false) => + // UNSAT, valid program + return Left(Stream(Solution(BooleanLiteral(true), Set(), outerSol, true))) - case None => - if (useOptTimeout) { - // Interpret timeout in CE search as "the candidate is valid" - sctx.reporter.info("CEGIS could not prove the validity of the resulting expression") - // Optimistic valid solution - return Left(Stream(Solution(BooleanLiteral(true), Set(), sol, false))) - } else { - None - } - } - } finally { - solver.free() - solverf.shutdown() + case None => + if (useOptTimeout) { + // Interpret timeout in CE search as "the candidate is valid" + sctx.reporter.info("CEGIS could not prove the validity of the resulting expression") + // Optimistic valid solution + return Left(Stream(Solution(BooleanLiteral(true), Set(), outerSol, false))) + } else { + None + } } + } finally { + solver.free() + solverf.shutdown() + cTreeFd.fullBody = origImpl } - - Right(cexs.flatten) - } finally { - hctx.ci.ch.impl = None } + + Right(cexs.flatten) } var excludedPrograms = ArrayBuffer[Set[Identifier]]() @@ -442,224 +507,18 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { // Explicitly remove program computed by bValues from the search space def excludeProgram(bValues: Set[Identifier]): Unit = { val bvs = bValues.filter(isBActive) - //println(f" (-) ${bvs.mkString(", ")}%-40s ("+getExpr(bvs)+")") excludedPrograms += bvs } - /** - * Shrinks the non-deterministic program to the provided set of - * alternatives only - */ - def shrinkTo(remainingBs: Set[Identifier], finalUnfolding: Boolean): Unit = { - //println("Shrinking!") - - val initialBs = remainingBs ++ (if (finalUnfolding) Set() else closedBs.keySet) - - var cParent = Map[Identifier, Identifier]() - var cOfB = Map[Identifier, Identifier]() - var underBs = Map[Identifier, Set[Identifier]]() - - for ((cparent, alts) <- cTree; - (b, _, cs) <- alts) { - - cOfB += b -> cparent - - for (cchild <- cs) { - underBs += cchild -> (underBs.getOrElse(cchild, Set()) + b) - cParent += cchild -> cparent - } - } - - def bParents(b: Identifier): Set[Identifier] = { - val parentBs = underBs.getOrElse(cOfB(b), Set()) - - Set(b) ++ parentBs.flatMap(bParents) - } - - // include parents - val keptBs = initialBs.flatMap(bParents) - - //println("Initial Bs: "+initialBs) - //println("Keeping Bs: "+keptBs) - - //debugCExpr(cTree, keptBs) - - var newCTree = Map[Identifier, Seq[(Identifier, Expr, Set[Identifier])]]() - - for ((c, alts) <- cTree) yield { - newCTree += c -> alts.filter(a => keptBs(a._1)) - } - - def removeDeadAlts(c: Identifier, deadC: Identifier) { - if (newCTree contains c) { - val alts = newCTree(c) - val newAlts = alts.filterNot(a => a._3 contains deadC) - - if (newAlts.isEmpty) { - for (cp <- cParent.get(c)) { - removeDeadAlts(cp, c) - } - newCTree -= c - } else { - newCTree += c -> newAlts - } - } - } - - //println("BETWEEN") - //debugCExpr(newCTree, keptBs) - - for ((c, alts) <- newCTree if alts.isEmpty) { - for (cp <- cParent.get(c)) { - removeDeadAlts(cp, c) - } - newCTree -= c - } - - var newCDeps = Map[Identifier, Set[Identifier]]() - - for ((c, alts) <- cTree) yield { - newCDeps += c -> alts.map(_._3).toSet.flatten - } - - cTree = newCTree - cDeps = newCDeps - closedBs = closedBs.filterKeys(keptBs) - - bs = cTree.map(_._2.map(_._1)).flatten.toSet - bsOrdered = bs.toSeq.sortBy(_.id) - - excludedPrograms = excludedPrograms.filter(_.forall(bs)) - - //debugCExpr(cTree) - updateCTree() - } - - class CGenerator { - private var buffers = Map[T, Stream[Identifier]]() - - private var slots = Map[T, Int]().withDefaultValue(0) - - private def streamOf(t: T): Stream[Identifier] = { - FreshIdentifier("c", t.getType, true) #:: streamOf(t) - } - - def reset(): Unit = { - slots = Map[T, Int]().withDefaultValue(0) - } - - def getNext(t: T) = { - if (!(buffers contains t)) { - buffers += t -> streamOf(t) - } - - val n = slots(t) - slots += t -> (n+1) - - buffers(t)(n) - } - } - def unfold(finalUnfolding: Boolean): Boolean = { - var newBs = Set[Identifier]() - var unfoldedSomething = false - - def freshB() = { - val id = FreshIdentifier("B", BooleanType, true) - newBs += id - id - } - - val unfoldBehind = if (cTree.isEmpty) { - p.xs - } else { - closedBs.flatMap(_._2).toSet - } - - closedBs = Map[Identifier, Set[Identifier]]() - - for (c <- unfoldBehind) { - var alts = grammar.getProductions(labels(c)) - - if (finalUnfolding) { - alts = alts.filter(_.subTrees.isEmpty) - } - - val cGen = new CGenerator() - - val cTreeInfos = if (alts.nonEmpty) { - for (gen <- alts) yield { - val b = freshB() - - // Optimize labels - cGen.reset() - - val cToLabel = for (t <- gen.subTrees) yield { - cGen.getNext(t) -> t - } - - - labels ++= cToLabel - - val cs = cToLabel.map(_._1) - val ex = gen.builder(cs.map(_.toVariable)) - - if (cs.nonEmpty) { - closedBs += b -> cs.toSet - } - - //println(" + "+b+" => "+c+" = "+ex) - - unfoldedSomething = true - - (b, ex, cs.toSet) - } - } else { - // Happens in final unfolding when no alts have ground terms - val b = freshB() - closedBs += b -> Set() - - Seq((b, simplestValue(c.getType), Set[Identifier]())) - } - - cTree += c -> cTreeInfos - cDeps += c -> cTreeInfos.map(_._3).toSet.flatten - } - - sctx.reporter.ifDebug { printer => - printer("Grammar so far:") - grammar.printProductions(printer) - } - - bs = bs ++ newBs - bsOrdered = bs.toSeq.sortBy(_.id) - - /** - * Close dead-ends - * - * Find 'c' that have no active alternatives, then close all 'b's that - * depend on such "dead" 'c's - */ - var deadCs = Set[Identifier]() - - for ((c, alts) <- cTree) { - if (alts.forall{ case (b, _, _) => !isBActive(b) }) { - deadCs += c - } - } - - for ((_, alts) <- cTree; (b, _, cs) <- alts) { - if ((cs & deadCs).nonEmpty) { - closedBs += (b -> closedBs.getOrElse(b, Set())) - } - } - - //debugCExpr(cTree) + termSize += 1 updateCTree() - - unfoldedSomething + true } + /** + * First phase of CEGIS: solve for potential programs (that work on at least one input) + */ def solveForTentativeProgram(): Option[Option[Set[Identifier]]] = { val solverf = SolverFactory.default(ctx, programCTree).withTimeout(exSolverTo) val solver = solverf.getNewSolver() @@ -670,16 +529,14 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { //println(phiFd.fullBody.asString(ctx)) val fixedBs = finiteArray(bsOrdered.map(_.toVariable), None, BooleanType) - val cnstrFixed = replaceFromIDs(Map(bArrayId -> fixedBs), cnstr) + val toFind = and(innerPc, cnstr) + //println(" --- Constraints ---") + //println(" - "+toFind) try { - val toFind = and(p.pc, cnstrFixed) - //println(" --- Constraints ---") - //println(" - "+toFind) + solver.assertCnstr(andJoin(bsOrdered.map(b => if (bs(b)) b.toVariable else Not(b.toVariable)))) solver.assertCnstr(toFind) - // oneOfBs - //println(" -- OneOf:") for ((c, alts) <- cTree) { val activeBs = alts.map(_._1).filter(isBActive) @@ -739,17 +596,19 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { } } + /** + * Second phase of CEGIS: verify a given program by looking for CEX inputs + */ def solveForCounterExample(bs: Set[Identifier]): Option[Option[Seq[Expr]]] = { val solverf = SolverFactory.default(ctx, programCTree).withTimeout(cexSolverTo) val solver = solverf.getNewSolver() val cnstr = FunctionInvocation(phiFd.typed, phiFd.params.map(_.id.toVariable)) - val fixedBs = finiteArray(bsOrdered.map(b => BooleanLiteral(bs(b))), None, BooleanType) - val cnstrFixed = replaceFromIDs(Map(bArrayId -> fixedBs), cnstr) try { - solver.assertCnstr(p.pc) - solver.assertCnstr(Not(cnstrFixed)) + solver.assertCnstr(andJoin(bsOrdered.map(b => if (bs(b)) b.toVariable else Not(b.toVariable)))) + solver.assertCnstr(innerPc) + solver.assertCnstr(Not(cnstr)) solver.check match { case Some(true) => @@ -781,6 +640,8 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { val sctx = hctx.sctx val ndProgram = new NonDeterministicProgram(p) + ndProgram.init() + var unfolding = 1 val maxUnfoldings = params.maxUnfoldings @@ -788,19 +649,21 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { var baseExampleInputs: ArrayBuffer[Seq[Expr]] = new ArrayBuffer[Seq[Expr]]() + sctx.reporter.ifDebug { printer => + ndProgram.grammar.printProductions(printer) + } + // We populate the list of examples with a predefined one sctx.reporter.debug("Acquiring initial list of examples") baseExampleInputs ++= p.tb.examples.map(_.ins).toSet - val pc = p.pc - - if (pc == BooleanLiteral(true)) { + if (p.pc == BooleanLiteral(true)) { baseExampleInputs += p.as.map(a => simplestValue(a.getType)) } else { val solver = sctx.newSolver.setTimeout(exSolverTo) - solver.assertCnstr(pc) + solver.assertCnstr(p.pc) try { solver.check match { @@ -823,18 +686,19 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { sctx.reporter.ifDebug { debug => baseExampleInputs.foreach { in => - debug(" - "+in.mkString(", ")) + debug(" - "+in.map(_.asString).mkString(", ")) } } + /** * We generate tests for discarding potential programs */ val inputIterator: Iterator[Seq[Expr]] = if (useVanuatoo) { - new VanuatooDataGen(sctx.context, sctx.program).generateFor(p.as, pc, 20, 3000) + new VanuatooDataGen(sctx.context, sctx.program).generateFor(p.as, p.pc, 20, 3000) } else { val evaluator = new DualEvaluator(sctx.context, sctx.program, CodeGenParams.default) - new GrammarDataGen(evaluator, ExpressionGrammars.ValueGrammar).generateFor(p.as, pc, 20, 1000) + new GrammarDataGen(evaluator, ExpressionGrammars.ValueGrammar).generateFor(p.as, p.pc, 20, 1000) } val cachedInputIterator = new Iterator[Seq[Expr]] { @@ -901,7 +765,7 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { val e = examples.next() if (!ndProgram.testForProgram(bs)(e)) { failedTestsStats(e) += 1 - sctx.reporter.debug(f" Program: ${ndProgram.getExpr(bs)}%-80s failed on: ${e.mkString(", ")}") + sctx.reporter.debug(f" Program: ${ndProgram.getExpr(bs).asString}%-80s failed on: ${e.map(_.asString).mkString(", ")}") wrongPrograms += bs prunedPrograms -= bs @@ -964,7 +828,7 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { if (nPassing < nInitial * shrinkThreshold && useShrink) { // We shrink the program to only use the bs mentionned val bssToKeep = prunedPrograms.foldLeft(Set[Identifier]())(_ ++ _) - ndProgram.shrinkTo(bssToKeep, unfolding == maxUnfoldings) + //ndProgram.shrinkTo(bssToKeep, unfolding == maxUnfoldings) } else { wrongPrograms.foreach { ndProgram.excludeProgram diff --git a/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala b/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala index 8ad0dc7f7e63e71b014358457116cd9cabbd890f..3ac83225d73b7dfacaccd183eed03bb37825e818 100644 --- a/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala +++ b/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala @@ -8,6 +8,8 @@ import bonsai._ import Helpers._ +import leon.utils.SeqUtils.sumTo + import purescala.Expressions.{Or => LeonOr, _} import purescala.Common._ import purescala.Definitions._ @@ -413,6 +415,36 @@ object ExpressionGrammars { } } + case class SizedLabel[T <% Typed](underlying: T, size: Int) extends Typed { + val getType = underlying.getType + + override def toString = underlying.toString+"|"+size+"|" + } + + case class SizeBoundedGrammar[T <% Typed](g: ExpressionGrammar[T]) extends ExpressionGrammar[SizedLabel[T]] { + def computeProductions(sl: SizedLabel[T]): Seq[Gen] = { + if (sl.size <= 0) { + Nil + } else if (sl.size == 1) { + g.getProductions(sl.underlying).filter(_.subTrees.isEmpty).map { + case Generator(subTrees, builder) => + Generator[SizedLabel[T], Expr](Nil, builder) + } + } else { + g.getProductions(sl.underlying).filter(_.subTrees.nonEmpty).flatMap { + case Generator(subTrees, builder) => + val sizes = sumTo(sl.size-1, subTrees.size) + + for (ss <- sizes) yield { + val subSizedLabels = (subTrees zip ss) map (s => SizedLabel(s._1, s._2)) + + Generator[SizedLabel[T], Expr](subSizedLabels, builder) + } + } + } + } + } + case class BoundedGrammar[T](g: ExpressionGrammar[Label[T]], bound: Int) extends ExpressionGrammar[Label[T]] { def computeProductions(l: Label[T]): Seq[Gen] = g.computeProductions(l).flatMap { case g: Generator[Label[T], Expr] => diff --git a/src/main/scala/leon/utils/SeqUtils.scala b/src/main/scala/leon/utils/SeqUtils.scala index ff1f55d43d696fe5a5f9e301517716f5bac69641..5a5e2dff3088da991ac99a6b8f1e46f759837629 100644 --- a/src/main/scala/leon/utils/SeqUtils.scala +++ b/src/main/scala/leon/utils/SeqUtils.scala @@ -31,4 +31,14 @@ object SeqUtils { result } + + def sumTo(sum: Int, arity: Int): Seq[Seq[Int]] = { + if (arity == 1) { + Seq(Seq(sum)) + } else { + (1 until sum).flatMap{ n => + sumTo(sum-n, arity-1).map( r => n +: r) + } + } + } }