From 206fe4714fc6524cd8c4fb9c113dceba2b8fa190 Mon Sep 17 00:00:00 2001 From: Etienne Kneuss <ekneuss@gmail.com> Date: Thu, 1 Nov 2012 19:50:29 +0100 Subject: [PATCH] Complexity needs to be carefully extracted/computed --- .../scala/leon/synthesis/Complexity.scala | 23 ++++ .../scala/leon/synthesis/DerivationTree.scala | 4 +- .../scala/leon/synthesis/Heuristics.scala | 50 ++++++-- src/main/scala/leon/synthesis/Rules.scala | 59 +++++----- src/main/scala/leon/synthesis/Solution.scala | 14 +-- .../scala/leon/synthesis/SynthesisPhase.scala | 13 ++- .../scala/leon/synthesis/Synthesizer.scala | 17 ++- src/main/scala/leon/synthesis/Task.scala | 108 +++++++----------- src/main/scala/leon/synthesis/package.scala | 4 +- 9 files changed, 172 insertions(+), 120 deletions(-) create mode 100644 src/main/scala/leon/synthesis/Complexity.scala diff --git a/src/main/scala/leon/synthesis/Complexity.scala b/src/main/scala/leon/synthesis/Complexity.scala new file mode 100644 index 000000000..f6e38ee51 --- /dev/null +++ b/src/main/scala/leon/synthesis/Complexity.scala @@ -0,0 +1,23 @@ +package leon +package synthesis + +abstract class Complexity extends Ordered[Complexity] { + def compare(that: Complexity): Int = (this.compute, that.compute) match { + case (x, y) if x < y => -1 + case (x, y) if x > y => +1 + case _ => 0 + } + + def compute : Double +} + +object Complexity { + val zero = new Complexity { + override def compute = 0 + override def toString = "0" + } + val max = new Complexity { + override def compute = 42 + override def toString = "MAX" + } +} diff --git a/src/main/scala/leon/synthesis/DerivationTree.scala b/src/main/scala/leon/synthesis/DerivationTree.scala index 927d69179..da8563815 100644 --- a/src/main/scala/leon/synthesis/DerivationTree.scala +++ b/src/main/scala/leon/synthesis/DerivationTree.scala @@ -31,10 +31,11 @@ class DerivationTree(root: RootTask) { } } - def printTask(t: SimpleTask) { + def printTask(t: Task) { val node = nameFor(t, "task"); + /* t.solverTask match { case Some(decompTask) => res append " "+node+" [ label = <<TABLE BORDER=\"0\" CELLBORDER=\"1\" CELLSPACING=\"0\"><TR><TD BORDER=\"0\">"+decompTask.rule.name+"</TD></TR><TR><TD BGCOLOR=\"indianred1\">"+escapeHTML(t.problem.toString)+"</TD></TR><TR><TD BGCOLOR=\"greenyellow\">"+escapeHTML(t.solution.map(_.toString).getOrElse("?"))+"</TD></TR></TABLE>> shape = \"none\" ];\n" @@ -46,6 +47,7 @@ class DerivationTree(root: RootTask) { case None => } + */ } diff --git a/src/main/scala/leon/synthesis/Heuristics.scala b/src/main/scala/leon/synthesis/Heuristics.scala index 2e4a3946a..266114d21 100644 --- a/src/main/scala/leon/synthesis/Heuristics.scala +++ b/src/main/scala/leon/synthesis/Heuristics.scala @@ -6,14 +6,16 @@ import purescala.Trees._ import purescala.Extractors._ import purescala.TreeOps._ import purescala.TypeTrees._ +import purescala.Definitions._ object Heuristics { def all(synth: Synthesizer) = Set( + new OptimisticGround(synth), new IntInduction(synth) ) } -class OptimisticGround(synth: Synthesizer) extends Rule("Optimistic Ground", synth, 90) { +class OptimisticGround(synth: Synthesizer) extends Rule("Optimistic Ground", synth, 9, 0) { def applyOn(task: Task): RuleResult = { val p = task.problem @@ -43,7 +45,7 @@ class OptimisticGround(synth: Synthesizer) extends Rule("Optimistic Ground", syn predicates = valuateWithModelIn(phi, ass, invalidModel) +: predicates case (Some(false), _) => - result = Some(RuleSuccess(Solution(BooleanLiteral(true), Tuple(p.xs.map(valuateWithModel(satModel))).setType(tpe)))) + result = Some(RuleSuccess(Solution(BooleanLiteral(true), Tuple(p.xs.map(valuateWithModel(satModel))).setType(tpe), cost))) case _ => result = Some(RuleInapplicable) @@ -51,7 +53,7 @@ class OptimisticGround(synth: Synthesizer) extends Rule("Optimistic Ground", syn case (Some(false), _) => if (predicates.isEmpty) { - result = Some(RuleSuccess(Solution(BooleanLiteral(false), Error(p.phi+" is UNSAT!").setType(tpe)))) + result = Some(RuleSuccess(Solution(BooleanLiteral(false), Error(p.phi+" is UNSAT!").setType(tpe), cost))) } else { result = Some(RuleInapplicable) } @@ -70,18 +72,46 @@ class OptimisticGround(synth: Synthesizer) extends Rule("Optimistic Ground", syn } -class IntInduction(synth: Synthesizer) extends Rule("Int Induction", synth, 80) { +class IntInduction(synth: Synthesizer) extends Rule("Int Induction", synth, 8, 50) { def applyOn(task: Task): RuleResult = { val p = task.problem - p.as.find(_.getType == Int32Type) match { - case Some(inductOn) => + p.as match { + case List(origId) if origId.getType == Int32Type => + val tpe = TupleType(p.xs.map(_.getType)) + + val inductOn = FreshIdentifier(origId.name, true).setType(origId.getType) - val subBase = Problem(p.as.filterNot(_ == inductOn), subst(inductOn -> IntLiteral(0), p.phi), p.xs) - // val subGT = Problem(p.as + tmpGT, And(Seq(p.phi, GreaterThan(Variable(inductOn), IntLiteral(0)), subst(inductOn -> IntLiteral(0), p.phi), p.xs) + val postXs = p.xs map (id => FreshIdentifier("r", true).setType(id.getType)) + + val postXsMap = (p.xs zip postXs).toMap.mapValues(Variable(_)) + + val newPhi = subst(origId -> Variable(inductOn), p.phi) + val postCondGT = substAll(postXsMap + (origId -> Minus(Variable(inductOn), IntLiteral(1))), p.phi) + val postCondLT = substAll(postXsMap + (origId -> Plus(Variable(inductOn), IntLiteral(1))), p.phi) + + val subBase = Problem(List(), subst(origId -> IntLiteral(0), p.phi), p.xs) + val subGT = Problem(inductOn :: postXs, And(Seq(newPhi, GreaterThan(Variable(inductOn), IntLiteral(0)), postCondGT)), p.xs) + val subLT = Problem(inductOn :: postXs, And(Seq(newPhi, LessThan(Variable(inductOn), IntLiteral(0)), postCondLT)), p.xs) + + val onSuccess: List[Solution] => Solution = { + case List(base, gt, lt) => + val newFun = new FunDef(FreshIdentifier("rec", true), tpe, Seq(VarDecl(inductOn, inductOn.getType))) + newFun.body = Some( + IfExpr(Equals(Variable(inductOn), IntLiteral(0)), + base.toExpr, + IfExpr(GreaterThan(Variable(inductOn), IntLiteral(0)), + LetTuple(postXs, FunctionInvocation(newFun, Seq(Minus(Variable(inductOn), IntLiteral(1)))), gt.toExpr) + , LetTuple(postXs, FunctionInvocation(newFun, Seq(Plus(Variable(inductOn), IntLiteral(1)))), lt.toExpr))) + ) + + Solution(BooleanLiteral(true), LetDef(newFun, FunctionInvocation(newFun, Seq(Variable(origId)))), base.cost+gt.cost+lt.cost+cost) + case _ => + Solution.none + } - RuleDecomposed(List(subBase), forward) - case None => + RuleDecomposed(List(subBase, subGT, subLT), onSuccess) + case _ => RuleInapplicable } } diff --git a/src/main/scala/leon/synthesis/Rules.scala b/src/main/scala/leon/synthesis/Rules.scala index 7fb8a3c8e..6ebaf4e4e 100644 --- a/src/main/scala/leon/synthesis/Rules.scala +++ b/src/main/scala/leon/synthesis/Rules.scala @@ -17,7 +17,8 @@ object Rules { new CaseSplit(synth), new UnusedInput(synth), new UnconstrainedOutput(synth), - new Assert(synth) + new Assert(synth), + new GiveUp(synth) ) } @@ -27,21 +28,21 @@ case class RuleSuccess(solution: Solution) extends RuleResult case class RuleDecomposed(subProblems: List[Problem], onSuccess: List[Solution] => Solution) extends RuleResult -abstract class Rule(val name: String, val synth: Synthesizer, val priority: Priority) { +abstract class Rule(val name: String, val synth: Synthesizer, val priority: Priority, val cost: Cost) { def applyOn(task: Task): RuleResult def subst(what: Tuple2[Identifier, Expr], in: Expr): Expr = replace(Map(Variable(what._1) -> what._2), in) def substAll(what: Map[Identifier, Expr], in: Expr): Expr = replace(what.map(w => Variable(w._1) -> w._2), in) - val forward: List[Solution] => Solution = { - case List(s) => s + def forward(cost: Cost): List[Solution] => Solution = { + case List(s) => Solution(s.pre, s.term, s.cost + cost) case _ => Solution.none } override def toString = name } -class OnePoint(synth: Synthesizer) extends Rule("One-point", synth, 300) { +class OnePoint(synth: Synthesizer) extends Rule("One-point", synth, 30, 0) { def applyOn(task: Task): RuleResult = { val p = task.problem @@ -64,11 +65,11 @@ class OnePoint(synth: Synthesizer) extends Rule("One-point", synth, 300) { val newProblem = Problem(p.as, subst(x -> e, And(others)), oxs) val onSuccess: List[Solution] => Solution = { - case List(Solution(pre, term)) => + case List(Solution(pre, term, c)) => if (oxs.isEmpty) { - Solution(pre, Tuple(e :: Nil)) + Solution(pre, Tuple(e :: Nil), c + cost) } else { - Solution(pre, LetTuple(oxs, term, subst(x -> e, Tuple(p.xs.map(Variable(_)))))) + Solution(pre, LetTuple(oxs, term, subst(x -> e, Tuple(p.xs.map(Variable(_))))), c + cost) } case _ => Solution.none } @@ -80,7 +81,7 @@ class OnePoint(synth: Synthesizer) extends Rule("One-point", synth, 300) { } } -class Ground(synth: Synthesizer) extends Rule("Ground", synth, 500) { +class Ground(synth: Synthesizer) extends Rule("Ground", synth, 50, 0) { def applyOn(task: Task): RuleResult = { val p = task.problem @@ -90,9 +91,9 @@ class Ground(synth: Synthesizer) extends Rule("Ground", synth, 500) { synth.solveSAT(p.phi) match { case (Some(true), model) => - RuleSuccess(Solution(BooleanLiteral(true), Tuple(p.xs.map(valuateWithModel(model))).setType(tpe))) + RuleSuccess(Solution(BooleanLiteral(true), Tuple(p.xs.map(valuateWithModel(model))).setType(tpe), cost)) case (Some(false), model) => - RuleSuccess(Solution(BooleanLiteral(false), Error(p.phi+" is UNSAT!").setType(tpe))) + RuleSuccess(Solution(BooleanLiteral(false), Error(p.phi+" is UNSAT!").setType(tpe), cost)) case _ => RuleInapplicable } @@ -102,7 +103,7 @@ class Ground(synth: Synthesizer) extends Rule("Ground", synth, 500) { } } -class CaseSplit(synth: Synthesizer) extends Rule("Case-Split", synth, 200) { +class CaseSplit(synth: Synthesizer) extends Rule("Case-Split", synth, 20, 0) { def applyOn(task: Task): RuleResult = { val p = task.problem p.phi match { @@ -111,7 +112,7 @@ class CaseSplit(synth: Synthesizer) extends Rule("Case-Split", synth, 200) { val sub2 = Problem(p.as, o2, p.xs) val onSuccess: List[Solution] => Solution = { - case List(s1, s2) => Solution(Or(s1.pre, s2.pre), IfExpr(s1.pre, s1.term, s2.term)) + case List(Solution(p1, t1, c1), Solution(p2, t2, c2)) => Solution(Or(p1, p2), IfExpr(p1, t1, t2), c1+c2+cost) case _ => Solution.none } @@ -122,7 +123,7 @@ class CaseSplit(synth: Synthesizer) extends Rule("Case-Split", synth, 200) { } } -class Assert(synth: Synthesizer) extends Rule("Assert", synth, 200) { +class Assert(synth: Synthesizer) extends Rule("Assert", synth, 20, 0) { def applyOn(task: Task): RuleResult = { val p = task.problem @@ -134,10 +135,10 @@ class Assert(synth: Synthesizer) extends Rule("Assert", synth, 200) { if (!exprsA.isEmpty) { if (others.isEmpty) { - RuleSuccess(Solution(And(exprsA), Tuple(p.xs.map(id => simplestValue(Variable(id)))))) + RuleSuccess(Solution(And(exprsA), Tuple(p.xs.map(id => simplestValue(Variable(id)))), cost)) } else { val onSuccess: List[Solution] => Solution = { - case List(s) => Solution(And(s.pre +: exprsA), s.term) + case List(s) => Solution(And(s.pre +: exprsA), s.term, s.cost + cost) case _ => Solution.none } @@ -154,7 +155,7 @@ class Assert(synth: Synthesizer) extends Rule("Assert", synth, 200) { } } -class UnusedInput(synth: Synthesizer) extends Rule("UnusedInput", synth, 100) { +class UnusedInput(synth: Synthesizer) extends Rule("UnusedInput", synth, 10, 0) { def applyOn(task: Task): RuleResult = { val p = task.problem val unused = p.as.toSet -- variablesOf(p.phi) @@ -162,14 +163,14 @@ class UnusedInput(synth: Synthesizer) extends Rule("UnusedInput", synth, 100) { if (!unused.isEmpty) { val sub = p.copy(as = p.as.filterNot(unused)) - RuleDecomposed(List(sub), forward) + RuleDecomposed(List(sub), forward(cost)) } else { RuleInapplicable } } } -class UnconstrainedOutput(synth: Synthesizer) extends Rule("Unconstr.Output", synth, 100) { +class UnconstrainedOutput(synth: Synthesizer) extends Rule("Unconstr.Output", synth, 10, 0) { def applyOn(task: Task): RuleResult = { val p = task.problem val unconstr = p.xs.toSet -- variablesOf(p.phi) @@ -179,7 +180,7 @@ class UnconstrainedOutput(synth: Synthesizer) extends Rule("Unconstr.Output", sy val onSuccess: List[Solution] => Solution = { case List(s) => - Solution(s.pre, LetTuple(sub.xs, s.term, Tuple(p.xs.map(id => if (unconstr(id)) simplestValue(Variable(id)) else Variable(id))))) + Solution(s.pre, LetTuple(sub.xs, s.term, Tuple(p.xs.map(id => if (unconstr(id)) simplestValue(Variable(id)) else Variable(id)))), s.cost + cost) case _ => Solution.none } @@ -193,7 +194,7 @@ class UnconstrainedOutput(synth: Synthesizer) extends Rule("Unconstr.Output", sy } object Unification { - class DecompTrivialClash(synth: Synthesizer) extends Rule("Unif Dec./Clash/Triv.", synth, 200) { + class DecompTrivialClash(synth: Synthesizer) extends Rule("Unif Dec./Clash/Triv.", synth, 20, 0) { def applyOn(task: Task): RuleResult = { val p = task.problem @@ -214,14 +215,14 @@ object Unification { val sub = p.copy(phi = And((exprs.toSet -- toRemove ++ toAdd.flatten).toSeq)) - RuleDecomposed(List(sub), forward) + RuleDecomposed(List(sub), forward(cost)) } else { RuleInapplicable } } } - class OccursCheck(synth: Synthesizer) extends Rule("Unif OccursCheck", synth, 200) { + class OccursCheck(synth: Synthesizer) extends Rule("Unif OccursCheck", synth, 20, 0) { def applyOn(task: Task): RuleResult = { val p = task.problem @@ -239,7 +240,7 @@ object Unification { if (isImpossible) { val tpe = TupleType(p.xs.map(_.getType)) - RuleSuccess(Solution(BooleanLiteral(false), Error(p.phi+" is UNSAT!").setType(tpe))) + RuleSuccess(Solution(BooleanLiteral(false), Error(p.phi+" is UNSAT!").setType(tpe), cost)) } else { RuleInapplicable } @@ -248,7 +249,7 @@ object Unification { } -class ADTDual(synth: Synthesizer) extends Rule("ADTDual", synth, 200) { +class ADTDual(synth: Synthesizer) extends Rule("ADTDual", synth, 20, 0) { def applyOn(task: Task): RuleResult = { val p = task.problem @@ -268,10 +269,16 @@ class ADTDual(synth: Synthesizer) extends Rule("ADTDual", synth, 200) { if (!toRemove.isEmpty) { val sub = p.copy(phi = And((exprs.toSet -- toRemove ++ toAdd.flatten).toSeq)) - RuleDecomposed(List(sub), forward) + RuleDecomposed(List(sub), forward(cost)) } else { RuleInapplicable } } } +class GiveUp(synth: Synthesizer) extends Rule("GiveUp", synth, 0, 100) { + def applyOn(task: Task): RuleResult = { + RuleSuccess(Solution.choose(task.problem, cost)) + } +} + diff --git a/src/main/scala/leon/synthesis/Solution.scala b/src/main/scala/leon/synthesis/Solution.scala index 3f7bda5f6..623530395 100644 --- a/src/main/scala/leon/synthesis/Solution.scala +++ b/src/main/scala/leon/synthesis/Solution.scala @@ -6,7 +6,7 @@ import leon.purescala.TreeOps.simplifyLets // Defines a synthesis solution of the form: // ⟨ P | T ⟩ -class Solution(val pre: Expr, val term: Expr) { +class Solution(val pre: Expr, val term: Expr, val cost: Cost) { override def toString = "⟨ "+pre+" | "+term+" ⟩" def toExpr = { @@ -19,21 +19,19 @@ class Solution(val pre: Expr, val term: Expr) { } } - def score: Score = 10 + def complexity = Complexity.zero } object Solution { - def choose(p: Problem): Solution = new Solution(BooleanLiteral(true), Choose(p.xs, p.phi)) { - override def score: Score = 0 - } + def choose(p: Problem, cost: Cost): Solution = new Solution(BooleanLiteral(true), Choose(p.xs, p.phi), cost) def none: Solution = throw new Exception("Unexpected failure to construct solution") def simplify(e: Expr) = simplifyLets(e) - def apply(pre: Expr, term: Expr) = { - new Solution(simplify(pre), simplify(term)) + def apply(pre: Expr, term: Expr, cost: Cost) = { + new Solution(simplify(pre), simplify(term), cost) } - def unapply(s: Solution): Option[(Expr, Expr)] = if (s eq null) None else Some((s.pre, s.term)) + def unapply(s: Solution): Option[(Expr, Expr, Cost)] = if (s eq null) None else Some((s.pre, s.term, s.cost)) } diff --git a/src/main/scala/leon/synthesis/SynthesisPhase.scala b/src/main/scala/leon/synthesis/SynthesisPhase.scala index 8ed9fcc66..7754d81f6 100644 --- a/src/main/scala/leon/synthesis/SynthesisPhase.scala +++ b/src/main/scala/leon/synthesis/SynthesisPhase.scala @@ -16,6 +16,7 @@ object SynthesisPhase extends LeonPhase[Program, Program] { override def definedOptions = Set( LeonFlagOptionDef( "inplace", "--inplace", "Debug level"), LeonFlagOptionDef( "derivtrees", "--derivtrees", "Generate derivation trees"), + LeonFlagOptionDef( "firstonly", "--firstonly", "Stop as soon as one synthesis solution is found"), LeonValueOptionDef("functions", "--functions=f1:f2", "Limit synthesis of choose found within f1,f2,..") ) @@ -30,6 +31,7 @@ object SynthesisPhase extends LeonPhase[Program, Program] { var inPlace = false var genTrees = false + var firstOnly = false var filterFun: Option[Seq[String]] = None for(opt <- ctx.options) opt match { @@ -37,12 +39,14 @@ object SynthesisPhase extends LeonPhase[Program, Program] { inPlace = true case LeonValueOption("functions", ListValue(fs)) => filterFun = Some(fs) + case LeonFlagOption("firstonly") => + firstOnly = true case LeonFlagOption("derivtrees") => genTrees = true case _ => } - val synth = new Synthesizer(ctx.reporter, solvers, genTrees, filterFun.map(_.toSet)) + val synth = new Synthesizer(ctx.reporter, solvers, genTrees, filterFun.map(_.toSet), firstOnly) val solutions = synth.synthesizeAll(p) @@ -52,14 +56,15 @@ object SynthesisPhase extends LeonPhase[Program, Program] { simplifyLets _, decomposeIfs _, patternMatchReconstruction _, - simplifyTautologies(uninterpretedZ3)(_) + simplifyTautologies(uninterpretedZ3)(_), + simplifyLets _ ) - val chooseToExprs = solutions.mapValues(sol => simplifiers.foldLeft(sol.toExpr){ (x, sim) => sim(x) }) + val chooseToExprs = solutions.map { case (ch, sol) => (ch, simplifiers.foldLeft(sol.toExpr){ (x, sim) => sim(x) }) } if (inPlace) { for (file <- ctx.files) { - new FileInterface(ctx.reporter, file).updateFile(chooseToExprs) + new FileInterface(ctx.reporter, file).updateFile(chooseToExprs.toMap) } } else { for ((chs, ex) <- chooseToExprs) { diff --git a/src/main/scala/leon/synthesis/Synthesizer.scala b/src/main/scala/leon/synthesis/Synthesizer.scala index 918b6584a..3a1544822 100644 --- a/src/main/scala/leon/synthesis/Synthesizer.scala +++ b/src/main/scala/leon/synthesis/Synthesizer.scala @@ -12,7 +12,12 @@ import java.io.File import collection.mutable.PriorityQueue -class Synthesizer(val r: Reporter, val solvers: List[Solver], generateDerivationTrees: Boolean, filterFuns: Option[Set[String]]) { +class Synthesizer(val r: Reporter, + val solvers: List[Solver], + generateDerivationTrees: Boolean, + filterFuns: Option[Set[String]], + firstOnly: Boolean) { + import r.{error,warning,info,fatalError} @@ -26,7 +31,7 @@ class Synthesizer(val r: Reporter, val solvers: List[Solver], generateDerivation workList += rootTask - while (!workList.isEmpty && rootTask.solution.isEmpty) { + while (!workList.isEmpty && !(firstOnly && rootTask.solution.isDefined)) { val task = workList.dequeue() val subtasks = task.run @@ -59,14 +64,14 @@ class Synthesizer(val r: Reporter, val solvers: List[Solver], generateDerivation val rules = Rules.all(this) ++ Heuristics.all(this) import purescala.Trees._ - def synthesizeAll(program: Program): Map[Choose, Solution] = { + def synthesizeAll(program: Program): List[(Choose, Solution)] = { solvers.foreach(_.setProgram(program)) def noop(u:Expr, u2: Expr) = u - var solutions = Map[Choose, Solution]() + var solutions = List[(Choose, Solution)]() def actOnChoose(f: FunDef)(e: Expr, a: Expr): Expr = e match { case ch @ Choose(vars, pred) => @@ -76,7 +81,7 @@ class Synthesizer(val r: Reporter, val solvers: List[Solver], generateDerivation val sol = synthesize(Problem(as, phi, xs), rules) - solutions += ch -> sol + solutions = (ch -> sol) :: solutions a case _ => @@ -90,6 +95,6 @@ class Synthesizer(val r: Reporter, val solvers: List[Solver], generateDerivation } } - solutions + solutions.reverse } } diff --git a/src/main/scala/leon/synthesis/Task.scala b/src/main/scala/leon/synthesis/Task.scala index ee579b1a1..e3edd27a1 100644 --- a/src/main/scala/leon/synthesis/Task.scala +++ b/src/main/scala/leon/synthesis/Task.scala @@ -1,74 +1,34 @@ package leon package synthesis -abstract class Task(val synth: Synthesizer, - val parent: Task, +class Task(synth: Synthesizer, + parent: Task, val problem: Problem, - val priority: Priority) extends Ordered[Task] { + val rule: Rule) extends Ordered[Task] { - def compare(that: Task) = this.priority - that.priority + def compare(that: Task) = this.complexity compare that.complexity - def run: List[Task] + def complexity = Complexity.max - override def toString = " Task("+priority+"): " +problem -} - -class SimpleTask(synth: Synthesizer, - override val parent: ApplyRuleTask, - problem: Problem, - priority: Priority) extends Task(synth, parent, problem, priority) { - - var solverTask: Option[ApplyRuleTask] = None - var solution: Option[Solution] = None - - def solvedBy(t: ApplyRuleTask, s: Solution) { - if (solution.map(_.score).getOrElse(-1) < s.score) { - solution = Some(s) - solverTask = Some(t) + var subProblems: List[Problem] = Nil + var onSuccess: List[Solution] => Solution = _ + var subSolutions : Map[Problem, Solution] = _ + var subSolvers : Map[Problem, Task] = _ - if (parent ne null) { - parent.partlySolvedBy(this, s) - } + def currentComplexityFor(p: Problem): Complexity = + subSolutions.get(p) match { + case Some(s) => s.complexity + case None => Complexity.max } - } - def run: List[Task] = { - synth.rules.map(r => new ApplyRuleTask(synth, this, problem, r)).toList - } - - var failed = Set[Rule]() - def unsolvedBy(t: ApplyRuleTask) = { - failed += t.rule - if (failed.size == synth.rules.size) { - val s = Solution.choose(problem) - solution = Some(s) - solverTask = None + def partlySolvedBy(t: Task, s: Solution) { + if (s.complexity < currentComplexityFor(t.problem)) { + subSolutions += t.problem -> s + subSolvers += t.problem -> t - if (parent ne null) { - parent.partlySolvedBy(this, s) - } - } - } -} - -class RootTask(synth: Synthesizer, problem: Problem) extends SimpleTask(synth, null, problem, 0) - -class ApplyRuleTask(synth: Synthesizer, - override val parent: SimpleTask, - problem: Problem, - val rule: Rule) extends Task(synth, parent, problem, rule.priority) { - - var subTasks: List[SimpleTask] = Nil - var onSuccess: List[Solution] => Solution = _ - var subSolutions : Map[SimpleTask, Solution] = _ - - def partlySolvedBy(t: SimpleTask, s: Solution) { - if (subSolutions.get(t).map(_.score).getOrElse(-1) < s.score) { - subSolutions += t -> s - - if (subSolutions.size == subTasks.size) { - val solution = onSuccess(subTasks map subSolutions) - parent.solvedBy(this, solution) + if (subSolutions.size == subProblems.size) { + val solution = onSuccess(subProblems map subSolutions) + parent.partlySolvedBy(this, solution) } } } @@ -77,19 +37,39 @@ class ApplyRuleTask(synth: Synthesizer, rule.applyOn(this) match { case RuleSuccess(solution) => // Solved - parent.solvedBy(this, solution) + parent.partlySolvedBy(this, solution) Nil case RuleDecomposed(subProblems, onSuccess) => - this.subTasks = subProblems.map(new SimpleTask(synth, this, _, 1000)) + this.subProblems = subProblems this.onSuccess = onSuccess this.subSolutions = Map() + this.subSolvers = Map() - subTasks + for (p <- subProblems; r <- synth.rules) yield { + new Task(synth, this, p, r) + } case RuleInapplicable => - parent.unsolvedBy(this) Nil } } override def toString = "Applying "+rule+" on "+problem } + +class RootTask(synth: Synthesizer, problem: Problem) extends Task(synth, null, problem, null) { + var solution: Option[Solution] = None + var solver: Option[Task] = None + + override def partlySolvedBy(t: Task, s: Solution) = { + if (s.complexity < solution.map(_.complexity).getOrElse(Complexity.max)) { + solution = Some(s) + solver = Some(t) + } + } + + override def run: List[Task] = { + for (r <- synth.rules.toList) yield { + new Task(synth, this, problem, r) + } + } +} diff --git a/src/main/scala/leon/synthesis/package.scala b/src/main/scala/leon/synthesis/package.scala index fbef6f0ea..dcdbd7073 100644 --- a/src/main/scala/leon/synthesis/package.scala +++ b/src/main/scala/leon/synthesis/package.scala @@ -1,6 +1,8 @@ package leon package object synthesis { - type Score = Int + type Cost = Int type Priority = Int + + val MAX_COST = 500 } -- GitLab