From 1900e8dd48bd53adadfbbe822c1dc5ef5a20bd23 Mon Sep 17 00:00:00 2001 From: Etienne Kneuss <colder@php.net> Date: Thu, 4 Apr 2013 15:32:02 +0200 Subject: [PATCH] Implement untrusted solutions, that are then validated --- .../scala/leon/synthesis/ParallelSearch.scala | 6 +- src/main/scala/leon/synthesis/Rules.scala | 6 +- .../scala/leon/synthesis/SimpleSearch.scala | 107 +++++++++++++++++- .../scala/leon/synthesis/Synthesizer.scala | 60 +++++++++- .../leon/synthesis/search/AndOrGraph.scala | 4 + .../search/AndOrGraphParallelSearch.scala | 4 +- .../search/AndOrGraphPartialSolution.scala | 6 +- .../synthesis/search/AndOrGraphSearch.scala | 10 +- 8 files changed, 179 insertions(+), 24 deletions(-) diff --git a/src/main/scala/leon/synthesis/ParallelSearch.scala b/src/main/scala/leon/synthesis/ParallelSearch.scala index ca4346ade..dbdbe0517 100644 --- a/src/main/scala/leon/synthesis/ParallelSearch.scala +++ b/src/main/scala/leon/synthesis/ParallelSearch.scala @@ -49,13 +49,13 @@ class ParallelSearch(synth: Synthesizer, val prefix = "[%-20s] ".format(Option(t.rule).getOrElse("?")) t.app.apply(sctx) match { - case RuleSuccess(sol) => + case RuleSuccess(sol, isTrusted) => synth.synchronized { info(prefix+"Got: "+t.problem) - info(prefix+"Solved with: "+sol) + info(prefix+"Solved"+(if(isTrusted) "" else " (untrusted)")+" with: "+sol) } - ExpandSuccess(sol) + ExpandSuccess(sol, isTrusted) case RuleDecomposed(sub) => synth.synchronized { info(prefix+"Got: "+t.problem) diff --git a/src/main/scala/leon/synthesis/Rules.scala b/src/main/scala/leon/synthesis/Rules.scala index 2f51522c7..2357241c0 100644 --- a/src/main/scala/leon/synthesis/Rules.scala +++ b/src/main/scala/leon/synthesis/Rules.scala @@ -57,9 +57,9 @@ abstract class RuleInstantiation(val problem: Problem, val rule: Rule, val onSuc } sealed abstract class RuleApplicationResult -case class RuleSuccess(solution: Solution) extends RuleApplicationResult -case class RuleDecomposed(sub: List[Problem]) extends RuleApplicationResult -case object RuleApplicationImpossible extends RuleApplicationResult +case class RuleSuccess(solution: Solution, isTrusted: Boolean = true) extends RuleApplicationResult +case class RuleDecomposed(sub: List[Problem]) extends RuleApplicationResult +case object RuleApplicationImpossible extends RuleApplicationResult object RuleInstantiation { def immediateDecomp(problem: Problem, rule: Rule, sub: List[Problem], onSuccess: List[Solution] => Option[Solution], description: String) = { diff --git a/src/main/scala/leon/synthesis/SimpleSearch.scala b/src/main/scala/leon/synthesis/SimpleSearch.scala index 06fa7da53..d042fc2e1 100644 --- a/src/main/scala/leon/synthesis/SimpleSearch.scala +++ b/src/main/scala/leon/synthesis/SimpleSearch.scala @@ -1,6 +1,8 @@ package leon package synthesis +import purescala.Definitions.FunDef + import synthesis.search._ case class TaskRunRule(app: RuleInstantiation) extends AOAndTask[Solution] { @@ -48,10 +50,10 @@ class SimpleSearch(synth: Synthesizer, info(prefix+"Got: "+t.problem) t.app.apply(sctx) match { - case RuleSuccess(sol) => - info(prefix+"Solved with: "+sol) + case RuleSuccess(sol, isTrusted) => + info(prefix+"Solved"+(if(isTrusted) "" else " (untrusted)")+" with: "+sol) - ExpandSuccess(sol) + ExpandSuccess(sol, isTrusted) case RuleDecomposed(sub) => info(prefix+"Decomposed into:") for(p <- sub) { @@ -87,6 +89,101 @@ class SimpleSearch(synth: Synthesizer, } } + case class SubProgram(p: Problem, fd: FunDef, acc: Set[FunDef]) + + def programAt(n: g.Tree): SubProgram = { + import purescala.TypeTrees._ + import purescala.Common._ + import purescala.TreeOps.replace + import purescala.Trees._ + import purescala.Definitions._ + + def programFrom(from: g.AndNode, sp: SubProgram): SubProgram = { + if (from.parent.parent eq null) { + sp + } else { + val at = from.parent.parent + val res = bestProgramForAnd(at, Map(from.parent -> sp)) + programFrom(at, res) + } + } + + def bestProgramForOr(on: g.OrTree): SubProgram = { + val problem: Problem = on.task.p + + val fd = problemToFunDef(problem) + + SubProgram(problem, fd, Set(fd)) + } + + def fundefToSol(p: Problem, fd: FunDef): Solution = { + Solution(BooleanLiteral(true), Set(), FunctionInvocation(fd, p.as.map(Variable(_)))) + } + + def solToSubProgram(p: Problem, s: Solution): SubProgram = { + val fd = problemToFunDef(p) + fd.precondition = Some(s.pre) + fd.body = Some(s.term) + + SubProgram(p, fd, Set(fd)) + } + + def bestProgramForAnd(an: g.AndNode, subPrograms: Map[g.OrTree, SubProgram]): SubProgram = { + val subSubPrograms = an.subTasks.map(an.subProblems).map( ot => + subPrograms.getOrElse(ot, bestProgramForOr(ot)) + ) + + val allFd = subSubPrograms.flatMap(_.acc) + val subSolutions = subSubPrograms.map(ssp => fundefToSol(ssp.p, ssp.fd)) + + val sp = solToSubProgram(an.task.problem, an.task.composeSolution(subSolutions).get) + + sp.copy(acc = sp.acc ++ allFd) + } + + def problemToFunDef(p: Problem): FunDef = { + + val ret = if (p.xs.size == 1) { + p.xs.head.getType + } else { + TupleType(p.xs.map(_.getType)) + } + + val freshAs = p.as.map(_.freshen) + + val map = (p.as.map(Variable(_): Expr) zip freshAs.map(Variable(_): Expr)).toMap + + val res = ResultVariable().setType(ret) + + val mapPost: Map[Expr, Expr] = if (p.xs.size > 1) { + p.xs.zipWithIndex.map{ case (id, i) => + Variable(id) -> TupleSelect(res, i+1) + }.toMap + } else { + Map(Variable(p.xs.head) -> ResultVariable().setType(ret)) + } + + val fd = new FunDef(FreshIdentifier("chimp", true), ret, freshAs.map(id => VarDecl(id, id.getType))) + fd.precondition = Some(replace(map, p.pc)) + fd.postcondition = Some(replace(map++mapPost, p.phi)) + + fd + } + + n match { + case an: g.AndNode => + programFrom(an, bestProgramForAnd(an, Map.empty)) + + case on: g.OrNode => + if (on.parent ne null) { + programAt(on.parent) + } else { + bestProgramForOr(on) + } + } + } + + var shouldStop = false def searchStep() { @@ -105,7 +202,7 @@ class SimpleSearch(synth: Synthesizer, } } - def search(): Option[Solution] = { + def search(): Option[(Solution, Boolean)] = { sctx.solver.init() shouldStop = false @@ -113,7 +210,7 @@ class SimpleSearch(synth: Synthesizer, while (!g.tree.isSolved && !shouldStop) { searchStep() } - g.tree.solution + g.tree.solution.map(s => (s, g.tree.isTrustworthy)) } override def stop() { diff --git a/src/main/scala/leon/synthesis/Synthesizer.scala b/src/main/scala/leon/synthesis/Synthesizer.scala index ac2acc6f5..36cffecf0 100644 --- a/src/main/scala/leon/synthesis/Synthesizer.scala +++ b/src/main/scala/leon/synthesis/Synthesizer.scala @@ -4,9 +4,10 @@ package synthesis import purescala.Common._ import purescala.Definitions.{Program, FunDef} import purescala.TreeOps._ -import purescala.Trees.{Expr, Not} +import purescala.Trees._ import purescala.ScalaPrinter import solvers.z3._ +import solvers.TimeoutSolver import sun.misc.{Signal, SignalHandler} import solvers.Solver @@ -24,7 +25,8 @@ class Synthesizer(val context : LeonContext, val problem: Problem, val options: SynthesisOptions) { - val silentContext = context.copy(reporter = new SilentReporter) + val silentReporter = new SilentReporter + val silentContext = context.copy(reporter = silentReporter) val rules: Seq[Rule] = options.rules @@ -75,10 +77,60 @@ class Synthesizer(val context : LeonContext, } res match { - case Some(solution) => + case Some((solution, true)) => + val ssol = solution.toSimplifiedExpr(context, program) (solution, true) + case Some((sol, false)) => + val ssol = sol.toSimplifiedExpr(context, program) + reporter.info("Solution requires validation") + + val (npr, fds) = solutionToProgram(sol) + + val tsolver = new TimeoutSolver(new FairZ3Solver(silentContext), 10000L) + tsolver.setProgram(npr) + + import verification.AnalysisPhase._ + val vcs = generateVerificationConditions(reporter, npr, fds.map(_.id.name)) + val vcreport = checkVerificationConditions(silentReporter, Seq(tsolver), vcs) + + if (vcreport.totalValid == vcreport.totalConditions) { + (sol, true) + } else { + reporter.warning("Solution was invalid:") + reporter.warning(fds.map(ScalaPrinter(_)).mkString("\n\n")) + reporter.warning(vcreport.summaryString) + (new AndOrGraphPartialSolution(search.g, (task: TaskRunRule) => Solution.choose(task.problem), false).getSolution, false) + } case None => - (new AndOrGraphPartialSolution(search.g, (task: TaskRunRule) => Solution.choose(task.problem)).getSolution, false) + (new AndOrGraphPartialSolution(search.g, (task: TaskRunRule) => Solution.choose(task.problem), true).getSolution, false) } } + + // Returns the new program and the new functions generated for this + def solutionToProgram(sol: Solution): (Program, Set[FunDef]) = { + import purescala.TypeTrees.TupleType + import purescala.Definitions.VarDecl + + val mainObject = program.mainObject + + // Create new fundef for the body + val ret = TupleType(problem.xs.map(_.getType)) + val res = ResultVariable().setType(ret) + + val mapPost: Map[Expr, Expr] = + problem.xs.zipWithIndex.map{ case (id, i) => + Variable(id) -> TupleSelect(res, i+1) + }.toMap + + val fd = new FunDef(FreshIdentifier("finalTerm", true), ret, problem.as.map(id => VarDecl(id, id.getType))) + fd.precondition = Some(And(problem.pc, sol.pre)) + fd.postcondition = Some(replace(mapPost, problem.phi)) + fd.body = Some(sol.term) + + val newDefs = sol.defs + fd + + val npr = program.copy(mainObject = mainObject.copy(defs = mainObject.defs ++ newDefs)) + + (npr, newDefs) + } } diff --git a/src/main/scala/leon/synthesis/search/AndOrGraph.scala b/src/main/scala/leon/synthesis/search/AndOrGraph.scala index a289e40ca..50a7ab859 100644 --- a/src/main/scala/leon/synthesis/search/AndOrGraph.scala +++ b/src/main/scala/leon/synthesis/search/AndOrGraph.scala @@ -24,6 +24,7 @@ class AndOrGraph[AT <: AOAndTask[S], OT <: AOOrTask[S], S](val root: OT, val cos def minCost: Cost + var isTrustworthy: Boolean = true var solution: Option[S] = None var isUnsolvable: Boolean = false @@ -86,6 +87,7 @@ class AndOrGraph[AT <: AOAndTask[S], OT <: AOOrTask[S], S](val root: OT, val cos if (subSolutions.size == subProblems.size) { task.composeSolution(subTasks.map(subSolutions)) match { case Some(sol) => + isTrustworthy = subProblems.forall(_._2.isTrustworthy) solution = Some(sol) updateMin() @@ -174,12 +176,14 @@ class AndOrGraph[AT <: AOAndTask[S], OT <: AOOrTask[S], S](val root: OT, val cos def notifySolution(sub: AndTree, sol: S) { solution match { case Some(preSol) if (costModel.solutionCost(preSol) < costModel.solutionCost(sol)) => + isTrustworthy = sub.isTrustworthy solution = Some(sol) minAlternative = sub notifyParent(solution.get) case None => + isTrustworthy = sub.isTrustworthy solution = Some(sol) minAlternative = sub diff --git a/src/main/scala/leon/synthesis/search/AndOrGraphParallelSearch.scala b/src/main/scala/leon/synthesis/search/AndOrGraphParallelSearch.scala index 88f378a8e..4ff625966 100644 --- a/src/main/scala/leon/synthesis/search/AndOrGraphParallelSearch.scala +++ b/src/main/scala/leon/synthesis/search/AndOrGraphParallelSearch.scala @@ -18,7 +18,7 @@ abstract class AndOrGraphParallelSearch[WC, var system: ActorSystem = _ - def search(): Option[S] = { + def search(): Option[(S, Boolean)] = { system = ActorSystem("ParallelSearch") val master = system.actorOf(Props(new Master), name = "Master") @@ -38,7 +38,7 @@ abstract class AndOrGraphParallelSearch[WC, system = null } - g.tree.solution + g.tree.solution.map(s => (s, g.tree.isTrustworthy)) } override def stop() { diff --git a/src/main/scala/leon/synthesis/search/AndOrGraphPartialSolution.scala b/src/main/scala/leon/synthesis/search/AndOrGraphPartialSolution.scala index d4d38dd77..3f9fbdb68 100644 --- a/src/main/scala/leon/synthesis/search/AndOrGraphPartialSolution.scala +++ b/src/main/scala/leon/synthesis/search/AndOrGraphPartialSolution.scala @@ -2,7 +2,7 @@ package leon.synthesis.search class AndOrGraphPartialSolution[AT <: AOAndTask[S], OT <: AOOrTask[S], - S](val g: AndOrGraph[AT, OT, S], missing: AT => S) { + S](val g: AndOrGraph[AT, OT, S], missing: AT => S, includeUntrusted: Boolean) { def getSolution: S = { @@ -10,7 +10,7 @@ class AndOrGraphPartialSolution[AT <: AOAndTask[S], } def solveAnd(t: g.AndTree): S = { - if (t.isSolved) { + if (t.isSolved && (includeUntrusted || t.isTrustworthy)) { t.solution.get } else { t match { @@ -23,7 +23,7 @@ class AndOrGraphPartialSolution[AT <: AOAndTask[S], } def solveOr(t: g.OrTree): S = { - if (t.isSolved) { + if (t.isSolved && (includeUntrusted || t.isTrustworthy)) { t.solution.get } else { t match { diff --git a/src/main/scala/leon/synthesis/search/AndOrGraphSearch.scala b/src/main/scala/leon/synthesis/search/AndOrGraphSearch.scala index d4f5e9416..00593edef 100644 --- a/src/main/scala/leon/synthesis/search/AndOrGraphSearch.scala +++ b/src/main/scala/leon/synthesis/search/AndOrGraphSearch.scala @@ -57,20 +57,21 @@ abstract class AndOrGraphSearch[AT <: AOAndTask[S], abstract class ExpandResult[T <: AOTask[S]] case class Expanded[T <: AOTask[S]](sub: List[T]) extends ExpandResult[T] - case class ExpandSuccess[T <: AOTask[S]](sol: S) extends ExpandResult[T] + case class ExpandSuccess[T <: AOTask[S]](sol: S, isTrustworthy: Boolean) extends ExpandResult[T] case class ExpandFailure[T <: AOTask[S]]() extends ExpandResult[T] def stop() { } - def search(): Option[S] + def search(): Option[(S, Boolean)] def onExpansion(al: g.AndLeaf, res: ExpandResult[OT]) { res match { case Expanded(ls) => al.expandWith(ls) - case r @ ExpandSuccess(sol) => + case r @ ExpandSuccess(sol, isTrustworthy) => + al.isTrustworthy = isTrustworthy al.solution = Some(sol) al.parent.notifySolution(al, sol) case _ => @@ -89,7 +90,8 @@ abstract class AndOrGraphSearch[AT <: AOAndTask[S], res match { case Expanded(ls) => ol.expandWith(ls) - case r @ ExpandSuccess(sol) => + case r @ ExpandSuccess(sol, isTrustworthy) => + ol.isTrustworthy = isTrustworthy ol.solution = Some(sol) ol.parent.notifySolution(ol, sol) case _ => -- GitLab