diff --git a/src/main/scala/leon/synthesis/rules/CEGISLike.scala b/src/main/scala/leon/synthesis/rules/CEGISLike.scala index 67509b4693f1870cd05c70a30a7aa1725a64cde2..f447869fd7444049d1d58f5fcd2411977ba3dc8b 100644 --- a/src/main/scala/leon/synthesis/rules/CEGISLike.scala +++ b/src/main/scala/leon/synthesis/rules/CEGISLike.scala @@ -102,6 +102,9 @@ abstract class CEGISLike(name: String) extends Rule(name) { */ private var cTree: Map[Identifier, Seq[(Identifier, Seq[Expr] => Expr, Seq[Identifier])]] = Map() + // cTree in expression form + private var cExpr: Expr = _ + // Top-level C identifiers corresponding to p.xs private var rootC: Identifier = _ @@ -192,7 +195,7 @@ abstract class CEGISLike(name: String) extends Rule(name) { } bsOrdered = bs.toSeq.sorted - setCExpr() + cExpr = setCExpr() excludedPrograms = Set() prunedPrograms = allPrograms().toSet @@ -268,21 +271,22 @@ abstract class CEGISLike(name: String) extends Rule(name) { } } - - private val cTreeFd0 = new FunDef(FreshIdentifier("cTree", alwaysShowUniqueID = true), Seq(), p.as.map(id => ValDef(id)), p.outType) + // This represents the current solution of the synthesis problem. + // It is within the image of hctx.functionContext in innerProgram. + // It should be set to the solution you want to check at each time. + // Usually it will either be cExpr or a concrete solution. + private val solutionBox = MutableExpr(NoTree(p.outType)) // The program with the body of the current function replaced by the current partial solution private val (innerProgram, origIdMap, origFdMap, origCdMap) = { val outerSolution = { new PartialSolution(hctx.search.strat, true) - .solutionAround(hctx.currentNode)(FunctionInvocation(cTreeFd0.typed, p.as.map(_.toVariable))) + .solutionAround(hctx.currentNode)(solutionBox) .getOrElse(fatalError("Unable to get outer solution")) } - val program0 = addFunDefs(hctx.program, Seq(cTreeFd0) ++ outerSolution.defs, hctx.functionContext) - - cTreeFd0.body = None + val program0 = addFunDefs(hctx.program, outerSolution.defs, hctx.functionContext) replaceFunDefs(program0){ case fd if fd == hctx.functionContext => @@ -302,9 +306,6 @@ abstract class CEGISLike(name: String) extends Rule(name) { } } - // The function which calls the synthesized expression within programCTree - private val cTreeFd = origFdMap.getOrElse(cTreeFd0, cTreeFd0) - private val outerToInner = new purescala.TreeTransformer { override def transform(id: Identifier): Identifier = origIdMap.getOrElse(id, id) override def transform(cd: ClassDef): ClassDef = origCdMap.getOrElse(cd, cd) @@ -321,12 +322,9 @@ abstract class CEGISLike(name: String) extends Rule(name) { private val innerPc = p.pc map outerExprToInnerExpr private val innerPhi = outerExprToInnerExpr(p.phi) + // Depends on the current solution private val innerSpec = outerExprToInnerExpr( - letTuple( - p.xs, - FunctionInvocation(cTreeFd0.typed, p.as.map(_.toVariable)), - p.phi - ) + letTuple(p.xs, solutionBox, p.phi) ) @@ -336,7 +334,7 @@ abstract class CEGISLike(name: String) extends Rule(name) { private var evaluator: DefaultEvaluator = _ // Updates the program with the C tree after recalculating all relevant FunDef's - private def setCExpr(): Unit = { + private def setCExpr(): Expr = { // Computes a Seq of functions corresponding to the choices made at each non-terminal of the grammar, // and an expression which calls the top-level one. @@ -385,10 +383,10 @@ abstract class CEGISLike(name: String) extends Rule(name) { val (cExpr, newFds) = computeCExpr() - cTreeFd.body = Some(cExpr) - programCTree = addFunDefs(innerProgram, newFds, cTreeFd) + programCTree = addFunDefs(innerProgram, newFds, origFdMap(hctx.functionContext)) evaluator = new DefaultEvaluator(hctx, programCTree) + cExpr //println("-- "*30) //println(programCTree.asString) //println(".. "*30) @@ -431,7 +429,6 @@ abstract class CEGISLike(name: String) extends Rule(name) { op1 == op2 } - val origImpl = cTreeFd.fullBody val outerSol = getExpr(bValues) val redundancyCheck = false @@ -444,21 +441,19 @@ abstract class CEGISLike(name: String) extends Rule(name) { return Some(false) } val innerSol = outerExprToInnerExpr(outerSol) - val cnstr = letTuple(p.xs, innerSol, innerPhi) + def withBindings(e: Expr) = p.pc.bindings.foldRight(e){ case ((id, v), bd) => let(id, outerExprToInnerExpr(v), bd) } - cTreeFd.fullBody = withBindings(innerSol) // FIXME! This shouldnt be needed... Solution around should be somehow used + solutionBox.underlying = innerSol // FIXME! This shouldnt be needed... Solution around should be somehow used timers.testForProgram.start() - val boundCnstr = withBindings(cnstr) - val res = ex match { case InExample(ins) => - evaluator.eval(boundCnstr, p.as.zip(ex.ins).toMap) + evaluator.eval(withBindings(innerSpec), p.as.zip(ex.ins).toMap) case InOutExample(ins, outs) => evaluator.eval( @@ -468,8 +463,6 @@ abstract class CEGISLike(name: String) extends Rule(name) { } timers.testForProgram.stop() - cTreeFd.fullBody = origImpl - res match { case EvaluationResults.Successful(res) => Some(res == BooleanLiteral(true)) @@ -507,7 +500,6 @@ abstract class CEGISLike(name: String) extends Rule(name) { * We keep track of CEXs generated by invalid programs and preemptively filter the rest of the programs with them. */ def validatePrograms(bss: Set[Set[Identifier]]): Either[Seq[Seq[Expr]], Stream[Solution]] = { - val origImpl = cTreeFd.fullBody var cexs = Seq[Seq[Expr]]() @@ -519,7 +511,7 @@ abstract class CEGISLike(name: String) extends Rule(name) { val innerSol = outerExprToInnerExpr(outerSol) //println(s"Testing $innerSol") //println(innerProgram) - cTreeFd.fullBody = innerSol + solutionBox.underlying = innerSol val cnstr = innerPc and letTuple(p.xs, innerSol, Not(innerPhi)) @@ -528,7 +520,6 @@ abstract class CEGISLike(name: String) extends Rule(name) { if (cexs exists (cex => eval.eval(cnstr, p.as.zip(cex).toMap).result == Some(BooleanLiteral(true)))) { debug(s"Rejected by CEX: $outerSol") excludeProgram(bs, true) - cTreeFd.fullBody = origImpl } else { //println("Solving for: "+cnstr.asString) @@ -566,7 +557,6 @@ abstract class CEGISLike(name: String) extends Rule(name) { } finally { solverf.reclaim(solver) solverf.shutdown() - cTreeFd.fullBody = origImpl } } } @@ -624,6 +614,7 @@ abstract class CEGISLike(name: String) extends Rule(name) { //println("-"*80) //println(programCTree.asString) + solutionBox.underlying = cExpr val toFind = innerPc and innerSpec //println(" --- Constraints ---") //println(" - "+toFind.asString) @@ -696,6 +687,7 @@ abstract class CEGISLike(name: String) extends Rule(name) { val solver = solverf.getNewSolver() try { + solutionBox.underlying = cExpr solver.assertCnstr(andJoin(bsOrdered.map(b => if (bs(b)) b.toVariable else Not(b.toVariable)))) solver.assertCnstr(innerPc and not(innerSpec))