From 8c22a926245afea2d3fb6d81a94449930bf2d779 Mon Sep 17 00:00:00 2001 From: Etienne Kneuss <ekneuss@gmail.com> Date: Mon, 12 Nov 2012 02:40:11 +0100 Subject: [PATCH] Tentative implementation of multi-steps rules --- .../scala/leon/synthesis/DerivationTree.scala | 2 +- .../scala/leon/synthesis/Heuristics.scala | 14 +-- src/main/scala/leon/synthesis/Rules.scala | 45 +++++---- .../scala/leon/synthesis/SynthesisPhase.scala | 42 ++++++++- .../scala/leon/synthesis/Synthesizer.scala | 72 +++++---------- src/main/scala/leon/synthesis/Task.scala | 92 ++++++++++--------- 6 files changed, 139 insertions(+), 128 deletions(-) diff --git a/src/main/scala/leon/synthesis/DerivationTree.scala b/src/main/scala/leon/synthesis/DerivationTree.scala index 9e3064bb3..c409f4020 100644 --- a/src/main/scala/leon/synthesis/DerivationTree.scala +++ b/src/main/scala/leon/synthesis/DerivationTree.scala @@ -37,7 +37,7 @@ class DerivationTree(root: RootTask) { res append " "+node+" [ label = <<TABLE BORDER=\"0\" CELLBORDER=\"1\" CELLSPACING=\"0\"><TR><TD BORDER=\"0\">"+t.rule+"</TD></TR><TR><TD BGCOLOR=\"indianred1\">"+escapeHTML(t.problem.toString)+"</TD></TR><TR><TD BGCOLOR=\"greenyellow\">"+escapeHTML(t.solution.map(_.toString).getOrElse("?"))+"</TD></TR></TABLE>> shape = \"none\" ];\n" - for (st <- t.subSolvers.values) { + for (st <- t.steps.flatMap(_.subSolvers.values)) { res.append(" "+taskName(st)+" -> "+node+"\n") printTask(st) } diff --git a/src/main/scala/leon/synthesis/Heuristics.scala b/src/main/scala/leon/synthesis/Heuristics.scala index 8afdf96bf..6913cc9eb 100644 --- a/src/main/scala/leon/synthesis/Heuristics.scala +++ b/src/main/scala/leon/synthesis/Heuristics.scala @@ -9,10 +9,10 @@ import purescala.TypeTrees._ import purescala.Definitions._ object Heuristics { - def all(synth: Synthesizer) = Set( - new OptimisticGround(synth), - new IntInduction(synth), - new OptimisticInjection(synth) + def all = Set[Synthesizer => Rule]( + new OptimisticGround(_), +// new IntInduction(_), + new OptimisticInjection(_) ) } @@ -117,7 +117,7 @@ class IntInduction(synth: Synthesizer) extends Rule("Int Induction", synth, 80) Solution.none } - RuleDecomposed(List(subBase, subGT, subLT), onSuccess) + RuleStep(List(subBase, subGT, subLT), onSuccess) case _ => RuleInapplicable } @@ -153,7 +153,7 @@ class OptimisticInjection(synth: Synthesizer) extends Rule("Opt. Injection", syn val sub = p.copy(phi = And(newExprs)) - RuleDecomposed(List(sub), forward) + RuleStep(List(sub), forward) } else { RuleInapplicable } @@ -189,7 +189,7 @@ class SelectiveInlining(synth: Synthesizer) extends Rule("Sel. Inlining", synth, val sub = p.copy(phi = And(newExprs)) - RuleDecomposed(List(sub), forward) + RuleStep(List(sub), forward) } else { RuleInapplicable } diff --git a/src/main/scala/leon/synthesis/Rules.scala b/src/main/scala/leon/synthesis/Rules.scala index 071cad7f9..1af93b194 100644 --- a/src/main/scala/leon/synthesis/Rules.scala +++ b/src/main/scala/leon/synthesis/Rules.scala @@ -9,24 +9,31 @@ import purescala.TreeOps._ import purescala.TypeTrees._ object Rules { - def all(synth: Synthesizer) = Set( - new Unification.DecompTrivialClash(synth), - new Unification.OccursCheck(synth), - new ADTDual(synth), - new OnePoint(synth), - new Ground(synth), - new CaseSplit(synth), - new UnusedInput(synth), - new UnconstrainedOutput(synth), - new Assert(synth) + def all = Set[Synthesizer => Rule]( + new Unification.DecompTrivialClash(_), + new Unification.OccursCheck(_), + new ADTDual(_), + new OnePoint(_), + new Ground(_), + new CaseSplit(_), + new UnusedInput(_), + new UnconstrainedOutput(_), + new Assert(_) ) } -abstract class RuleResult +sealed abstract class RuleResult case object RuleInapplicable extends RuleResult case class RuleSuccess(solution: Solution) extends RuleResult -case class RuleDecomposed(subProblems: List[Problem], onSuccess: List[Solution] => Solution) extends RuleResult +case class RuleMultiSteps(subProblems: List[Problem], + steps: List[List[Solution] => List[Problem]], + onSuccess: List[Solution] => Solution) extends RuleResult +object RuleStep { + def apply(subProblems: List[Problem], onSuccess: List[Solution] => Solution) = { + RuleMultiSteps(subProblems, Nil, onSuccess) + } +} abstract class Rule(val name: String, val synth: Synthesizer, val priority: Priority) { def applyOn(task: Task): RuleResult @@ -74,7 +81,7 @@ class OnePoint(synth: Synthesizer) extends Rule("One-point", synth, 300) { case _ => Solution.none } - RuleDecomposed(List(newProblem), onSuccess) + RuleStep(List(newProblem), onSuccess) } else { RuleInapplicable } @@ -116,7 +123,7 @@ class CaseSplit(synth: Synthesizer) extends Rule("Case-Split", synth, 200) { case _ => Solution.none } - RuleDecomposed(List(sub1, sub2), onSuccess) + RuleStep(List(sub1, sub2), onSuccess) case _ => RuleInapplicable } @@ -146,7 +153,7 @@ class Assert(synth: Synthesizer) extends Rule("Assert", synth, 200) { val sub = p.copy(phi = And(others)) - RuleDecomposed(List(sub), onSuccess) + RuleStep(List(sub), onSuccess) */ RuleInapplicable } @@ -167,7 +174,7 @@ class UnusedInput(synth: Synthesizer) extends Rule("UnusedInput", synth, 100) { if (!unused.isEmpty) { val sub = p.copy(as = p.as.filterNot(unused)) - RuleDecomposed(List(sub), forward) + RuleStep(List(sub), forward) } else { RuleInapplicable } @@ -189,7 +196,7 @@ class UnconstrainedOutput(synth: Synthesizer) extends Rule("Unconstr.Output", sy Solution.none } - RuleDecomposed(List(sub), onSuccess) + RuleStep(List(sub), onSuccess) } else { RuleInapplicable } @@ -219,7 +226,7 @@ object Unification { val sub = p.copy(phi = And((exprs.toSet -- toRemove ++ toAdd.flatten).toSeq)) - RuleDecomposed(List(sub), forward) + RuleStep(List(sub), forward) } else { RuleInapplicable } @@ -273,7 +280,7 @@ class ADTDual(synth: Synthesizer) extends Rule("ADTDual", synth, 200) { if (!toRemove.isEmpty) { val sub = p.copy(phi = And((exprs.toSet -- toRemove ++ toAdd.flatten).toSeq)) - RuleDecomposed(List(sub), forward) + RuleStep(List(sub), forward) } else { RuleInapplicable } diff --git a/src/main/scala/leon/synthesis/SynthesisPhase.scala b/src/main/scala/leon/synthesis/SynthesisPhase.scala index 2f8dc774e..78103a3b1 100644 --- a/src/main/scala/leon/synthesis/SynthesisPhase.scala +++ b/src/main/scala/leon/synthesis/SynthesisPhase.scala @@ -5,9 +5,9 @@ import purescala.TreeOps._ import solvers.TrivialSolver import solvers.z3.{FairZ3Solver,UninterpretedZ3Solver} -import purescala.Trees.Expr +import purescala.Trees._ import purescala.ScalaPrinter -import purescala.Definitions.Program +import purescala.Definitions.{Program, FunDef} object SynthesisPhase extends LeonPhase[Program, Program] { val name = "Synthesis" @@ -54,8 +54,39 @@ object SynthesisPhase extends LeonPhase[Program, Program] { case _ => } - val synth = new Synthesizer(ctx.reporter, mainSolver, genTrees, filterFun.map(_.toSet), firstOnly, timeoutMs) - val solutions = synth.synthesizeAll(p) + def synthesizeAll(program: Program): Map[Choose, Solution] = { + def noop(u:Expr, u2: Expr) = u + + var solutions = Map[Choose, Solution]() + + def actOnChoose(f: FunDef)(e: Expr, a: Expr): Expr = e match { + case ch @ Choose(vars, pred) => + val xs = vars + val as = (variablesOf(pred)--xs).toList + val phi = pred + + val problem = Problem(as, BooleanLiteral(true), phi, xs) + val synth = new Synthesizer(ctx.reporter, mainSolver, problem, Rules.all ++ Heuristics.all, genTrees, filterFun.map(_.toSet), firstOnly, timeoutMs) + val sol = synth.synthesize() + + solutions += ch -> sol + + a + case _ => + a + } + + // Look for choose() + for (f <- program.definedFunctions.sortBy(_.id.toString) if f.body.isDefined) { + if (filterFun.isEmpty || filterFun.get.contains(f.id.toString)) { + treeCatamorphism(x => x, noop, actOnChoose(f), f.body.get) + } + } + + solutions + } + + val solutions = synthesizeAll(p) // Simplify expressions val simplifiers = List[Expr => Expr]( @@ -71,7 +102,7 @@ object SynthesisPhase extends LeonPhase[Program, Program] { if (inPlace) { for (file <- ctx.files) { - new FileInterface(ctx.reporter, file).updateFile(chooseToExprs.toMap) + new FileInterface(ctx.reporter, file).updateFile(chooseToExprs) } } else { for ((chs, ex) <- chooseToExprs) { @@ -86,4 +117,5 @@ object SynthesisPhase extends LeonPhase[Program, Program] { p } + } diff --git a/src/main/scala/leon/synthesis/Synthesizer.scala b/src/main/scala/leon/synthesis/Synthesizer.scala index 983b4ec6f..cb945337d 100644 --- a/src/main/scala/leon/synthesis/Synthesizer.scala +++ b/src/main/scala/leon/synthesis/Synthesizer.scala @@ -12,23 +12,32 @@ import java.io.File import collection.mutable.PriorityQueue -class Synthesizer(val r: Reporter, +class Synthesizer(val reporter: Reporter, val solver: Solver, + val problem: Problem, + val ruleConstructors: Set[Synthesizer => Rule], generateDerivationTrees: Boolean = false, filterFuns: Option[Set[String]] = None, firstOnly: Boolean = false, timeoutMs: Option[Long] = None) { + import reporter.{error,warning,info,fatalError} - import r.{error,warning,info,fatalError} - + val rules = ruleConstructors.map(_.apply(this)) var derivationCounter = 1; - def synthesize(p: Problem, rules: Set[Rule]): Solution = { + val rootTask: RootTask = new RootTask(this, problem) + + val workList = new PriorityQueue[Task]() - val workList = new PriorityQueue[Task]() - val rootTask = new RootTask(this, p) + def bestSolutionSoFar(): Solution = { + rootTask.solution.getOrElse(worstSolution) + } + val worstSolution = Solution.choose(problem) + + def synthesize(): Solution = { + workList.clear() workList += rootTask val ts = System.currentTimeMillis @@ -40,24 +49,10 @@ class Synthesizer(val r: Reporter, } } - val worstSolution = Solution.choose(p) - - def bestSolutionSoFar(): Solution = { - rootTask.solution.getOrElse(worstSolution) - } - while (!workList.isEmpty && !(firstOnly && rootTask.solution.isDefined)) { val task = workList.dequeue() - val subProblems = task.run - - // Check if solving this task has the slightest chance of improving the - // current solution - if (task.minComplexity < bestSolutionSoFar().complexity) { - for (p <- subProblems; r <- rules) yield { - workList += new Task(this, task, p, r) - } - } + task.run() if (timeoutExpired()) { warning("Timeout reached") @@ -74,36 +69,11 @@ class Synthesizer(val r: Reporter, bestSolutionSoFar() } - val rules = Rules.all(this) ++ Heuristics.all(this) - - import purescala.Trees._ - def synthesizeAll(program: Program): List[(Choose, Solution)] = { - def noop(u:Expr, u2: Expr) = u - - var solutions = List[(Choose, Solution)]() - - def actOnChoose(f: FunDef)(e: Expr, a: Expr): Expr = e match { - case ch @ Choose(vars, pred) => - val xs = vars - val as = (variablesOf(pred)--xs).toList - val phi = pred - - val sol = synthesize(Problem(as, BooleanLiteral(true), phi, xs), rules) - - solutions = (ch -> sol) :: solutions - - a - case _ => - a + def addProblems(task: Task, problems: Traversable[Problem]) { + // Check if solving this task has the slightest chance of improving the + // current solution + for (p <- problems; rule <- rules) yield { + workList += new Task(this, task, p, rule) } - - // Look for choose() - for (f <- program.definedFunctions.sortBy(_.id.toString) if f.body.isDefined) { - if (filterFuns.isEmpty || filterFuns.get.contains(f.id.toString)) { - treeCatamorphism(x => x, noop, actOnChoose(f), f.body.get) - } - } - - solutions.reverse } } diff --git a/src/main/scala/leon/synthesis/Task.scala b/src/main/scala/leon/synthesis/Task.scala index e8eba4143..a1632ec33 100644 --- a/src/main/scala/leon/synthesis/Task.scala +++ b/src/main/scala/leon/synthesis/Task.scala @@ -17,42 +17,61 @@ class Task(synth: Synthesizer, } } - var subProblems: List[Problem] = Nil - var onSuccess: List[Solution] => Solution = _ - var subSolutions : Map[Problem, Solution] = Map() - var subSolvers : Map[Problem, Task] = Map() - var solution : Option[Solution] = None - - def isBetterSolutionThan(sol: Solution, osol: Option[Solution]): Boolean = + def isBetterSolutionThan(sol: Solution, osol: Option[Solution]): Boolean = { osol match { case Some(s) => s.complexity > sol.complexity case None => true } + } + + var solution: Option[Solution] = None + + class TaskStep(val subProblems: List[Problem]) { + var subSolutions = Map[Problem, Solution]() + var subSolvers = Map[Problem, Task]() + var failures = Set[Rule]() + } + + var steps: List[TaskStep] = Nil + var nextSteps: List[List[Solution] => List[Problem]] = Nil + var onSuccess: List[Solution] => Solution = _ + + def currentStep = steps.head def partlySolvedBy(t: Task, s: Solution) { - if (isBetterSolutionThan(s, subSolutions.get(t.problem))) { - subSolutions += t.problem -> s - subSolvers += t.problem -> t + if (isBetterSolutionThan(s, currentStep.subSolutions.get(t.problem))) { + currentStep.subSolutions += t.problem -> s + currentStep.subSolvers += t.problem -> t + + if (currentStep.subSolutions.size == currentStep.subProblems.size) { + val solutions = currentStep.subProblems map currentStep.subSolutions + + if (!nextSteps.isEmpty) { + val newProblems = nextSteps.head.apply(solutions) + nextSteps = nextSteps.tail - if (subSolutions.size == subProblems.size) { - solution = Some(onSuccess(subProblems map subSolutions)) - parent.partlySolvedBy(this, solution.get) + synth.addProblems(this, newProblems) + + steps = new TaskStep(newProblems) :: steps + } else { + solution = Some(onSuccess(solutions)) + parent.partlySolvedBy(this, solution.get) + } } } } - var failures = Set[Rule]() def unsolvedBy(t: Task) { - failures += t.rule + currentStep.failures += t.rule - if (failures.size == synth.rules.size) { + if (currentStep.failures.size == synth.rules.size) { // We might want to report unsolved instead of solvedByChoose, depending // on the cases parent.partlySolvedBy(this, Solution.choose(problem)) } } - def run: List[Problem] = { + def run() { rule.applyOn(this) match { case RuleSuccess(solution) => // Solved @@ -63,14 +82,10 @@ class Task(synth: Synthesizer, println(prefix+"Got: "+problem) println(prefix+"Solved with: "+solution) - Nil - - case RuleDecomposed(subProblems, onSuccess) => - this.subProblems = subProblems - this.onSuccess = onSuccess - - val simplestSolution = onSuccess(subProblems.map(Solution.basic _)) - minComplexity = new FixedSolComplexity(parent.minComplexity.value + simplestSolution.complexity.value) + case RuleMultiSteps(subProblems, interSteps, onSuccess) => + this.steps = new TaskStep(subProblems) :: Nil + this.nextSteps = interSteps + this.onSuccess = onSuccess val prefix = "[%-20s] ".format(Option(rule).map(_.toString).getOrElse("root")) println(prefix+"Got: "+problem) @@ -79,42 +94,29 @@ class Task(synth: Synthesizer, println(prefix+" - "+p) } - subProblems + synth.addProblems(this, subProblems) case RuleInapplicable => parent.unsolvedBy(this) - Nil } } - var minComplexity: AbsSolComplexity = new FixedSolComplexity(0) - override def toString = "Applying "+rule+" on "+problem } class RootTask(synth: Synthesizer, problem: Problem) extends Task(synth, null, problem, null) { var solver: Option[Task] = None - override def partlySolvedBy(t: Task, s: Solution) { - if (isBetterSolutionThan(s, solution)) { - solution = Some(s) - solver = Some(t) - } + override def run() { + synth.addProblems(this, List(problem)) } - override def unsolvedBy(t: Task) { - failures += t.rule - - if (failures.size == synth.rules.size) { - // We might want to report unsolved instead of solvedByChoose, depending - // on the cases - solution = Some(Solution.choose(problem)) - solver = None - } + override def partlySolvedBy(t: Task, s: Solution) { + solution = Some(s) + solver = Some(t) } - override def run: List[Problem] = { - List(problem) + override def unsolvedBy(t: Task) { } } -- GitLab