diff --git a/src/main/scala/leon/synthesis/PartialSolution.scala b/src/main/scala/leon/synthesis/PartialSolution.scala index 2c517203fb0ab7c5b8c04bfd103b1086320f626c..42b2d7cfcffd9ad1b68fa13eb5939538851b6ff0 100644 --- a/src/main/scala/leon/synthesis/PartialSolution.scala +++ b/src/main/scala/leon/synthesis/PartialSolution.scala @@ -3,9 +3,14 @@ package leon package synthesis +import purescala.Trees._ +import purescala.TreeOps._ +import purescala.DefOps._ +import purescala.Common._ + import graph._ -class PartialSolution(g: Graph, includeUntrusted: Boolean) { +class PartialSolution(g: Graph, includeUntrusted: Boolean = false) { def includeSolution(s: Solution) = { includeUntrusted || s.isTrusted @@ -15,6 +20,49 @@ class PartialSolution(g: Graph, includeUntrusted: Boolean) { Solution.choose(p) } + def solutionAround(n: Node): Option[Expr => Solution] = { + def solveWith(optn: Option[Node], sol: Solution): Option[Solution] = optn match { + case None => + Some(sol) + + case Some(n) => n.parent match { + case None => + Some(sol) + + case Some(on: OrNode) => + solveWith(on.parent, sol) + + case Some(an: AndNode) => + val ssols = for (d <- an.descendents) yield { + if (d == n) { + sol + } else { + getSolutionFor(d) + } + } + + an.ri.onSuccess(ssols).flatMap { nsol => + solveWith(an.parent, nsol) + } + } + } + + val anchor = FreshIdentifier("anchor").setType(n.p.outType) + val s = Solution(BooleanLiteral(true), Set(), anchor.toVariable) + + solveWith(Some(n), s) map { + case s @ Solution(pre, defs, term) => + (e: Expr) => + Solution(replaceFromIDs(Map(anchor -> e), pre), + defs.map(preMapOnFunDef({ + case Variable(`anchor`) => Some(e) + case _ => None + })), + replaceFromIDs(Map(anchor -> e), term), + s.isTrusted) + } + } + def getSolution(): Solution = { getSolutionFor(g.root)