From d5815659f6d2d900ff31b2cc7f06688d17d43839 Mon Sep 17 00:00:00 2001 From: Etienne Kneuss <ekneuss@gmail.com> Date: Thu, 1 Nov 2012 13:28:36 +0100 Subject: [PATCH] Store sufficient information to track resolution from rootTask --- .../scala/leon/synthesis/DerivationTree.scala | 38 ++++------- src/main/scala/leon/synthesis/Rules.scala | 10 +-- .../scala/leon/synthesis/Synthesizer.scala | 34 ++-------- src/main/scala/leon/synthesis/Task.scala | 66 ++++++++++--------- 4 files changed, 55 insertions(+), 93 deletions(-) diff --git a/src/main/scala/leon/synthesis/DerivationTree.scala b/src/main/scala/leon/synthesis/DerivationTree.scala index 370b10f79..0b7d7bf38 100644 --- a/src/main/scala/leon/synthesis/DerivationTree.scala +++ b/src/main/scala/leon/synthesis/DerivationTree.scala @@ -2,21 +2,6 @@ package leon package synthesis class DerivationTree(root: RootTask) { - var store = Map[SimpleTask, Map[Problem, SimpleTask]]().withDefaultValue(Map()) - var solutions = Map[Task, Solution]() - - def recordSolutionFor(task: Task, solution: Solution) = task match { - /* - case dt: SimpleTask => - if (dt.parent ne null) { - store += dt.parent -> (store(dt.parent) + (task.problem -> dt)) - } - - solutions += dt -> solution - case _ => - */ - case _ => - } private[this] var _nextID = 0 @@ -48,22 +33,21 @@ class DerivationTree(root: RootTask) { val node = nameFor(t, "task"); - /* - res append " "+node+" [ label = <<TABLE BORDER=\"0\" CELLBORDER=\"1\" CELLSPACING=\"0\"><TR><TD BORDER=\"0\">"+t.rule.name+"</TD></TR><TR><TD BGCOLOR=\"indianred1\">"+t.problem+"</TD></TR><TR><TD BGCOLOR=\"greenyellow\">"+solutions.getOrElse(t, "?")+"</TD></TR></TABLE>> shape = \"none\" ];\n" - */ + t.solverTask match { + case Some(decompTask) => + res append " "+node+" [ label = <<TABLE BORDER=\"0\" CELLBORDER=\"1\" CELLSPACING=\"0\"><TR><TD BORDER=\"0\">"+decompTask.rule.name+"</TD></TR><TR><TD BGCOLOR=\"indianred1\">"+t.problem+"</TD></TR><TR><TD BGCOLOR=\"greenyellow\">"+t.solution.getOrElse("?")+"</TD></TR></TABLE>> shape = \"none\" ];\n" - for ((_, task) <- store(t)) { - res append nameFor(task, "task") +" -> " + " "+node+";\n" - } - - for ((_, subt) <- store(t)) { - printTask(subt) + for (t <- decompTask.subTasks) { + printTask(t) + res append nameFor(t, "task") +" -> " + " "+node+";\n" + } + + case None => } } - for (task <- store.keysIterator if task.parent eq null) { - printTask(task) - } + + printTask(root) res append "}\n" diff --git a/src/main/scala/leon/synthesis/Rules.scala b/src/main/scala/leon/synthesis/Rules.scala index 16351d9b4..be07f654c 100644 --- a/src/main/scala/leon/synthesis/Rules.scala +++ b/src/main/scala/leon/synthesis/Rules.scala @@ -153,7 +153,7 @@ class Assert(synth: Synthesizer) extends Rule("Assert", synth, 200) { } } -class UnusedInput(synth: Synthesizer) extends Rule("UnusedInput", synth, 500) { +class UnusedInput(synth: Synthesizer) extends Rule("UnusedInput", synth, 100) { def applyOn(task: Task): RuleResult = { val p = task.problem val unused = p.as.toSet -- variablesOf(p.phi) @@ -168,7 +168,7 @@ class UnusedInput(synth: Synthesizer) extends Rule("UnusedInput", synth, 500) { } } -class UnconstrainedOutput(synth: Synthesizer) extends Rule("Unconstr.Output", synth, 500) { +class UnconstrainedOutput(synth: Synthesizer) extends Rule("Unconstr.Output", synth, 100) { def applyOn(task: Task): RuleResult = { val p = task.problem val unconstr = p.xs.toSet -- variablesOf(p.phi) @@ -192,7 +192,7 @@ class UnconstrainedOutput(synth: Synthesizer) extends Rule("Unconstr.Output", sy } object Unification { - class DecompTrivialClash(synth: Synthesizer) extends Rule("Unif Dec./Clash/Triv.", synth, 300) { + class DecompTrivialClash(synth: Synthesizer) extends Rule("Unif Dec./Clash/Triv.", synth, 200) { def applyOn(task: Task): RuleResult = { val p = task.problem @@ -220,7 +220,7 @@ object Unification { } } - class OccursCheck(synth: Synthesizer) extends Rule("Unif OccursCheck", synth, 300) { + class OccursCheck(synth: Synthesizer) extends Rule("Unif OccursCheck", synth, 200) { def applyOn(task: Task): RuleResult = { val p = task.problem @@ -247,7 +247,7 @@ object Unification { } -class ADTDual(synth: Synthesizer) extends Rule("ADTDual", synth, 300) { +class ADTDual(synth: Synthesizer) extends Rule("ADTDual", synth, 200) { def applyOn(task: Task): RuleResult = { val p = task.problem diff --git a/src/main/scala/leon/synthesis/Synthesizer.scala b/src/main/scala/leon/synthesis/Synthesizer.scala index 1d75db8e0..8e146e1ec 100644 --- a/src/main/scala/leon/synthesis/Synthesizer.scala +++ b/src/main/scala/leon/synthesis/Synthesizer.scala @@ -15,25 +15,17 @@ import collection.mutable.PriorityQueue class Synthesizer(val r: Reporter, val solvers: List[Solver], generateDerivationTrees: Boolean, filterFuns: Option[Set[String]]) { import r.{error,warning,info,fatalError} - private[this] var solution: Option[Solution] = None - private[this] var derivationTree: DerivationTree = _ - - var derivationCounter = 1; def synthesize(p: Problem, rules: List[Rule]): Solution = { val workList = new PriorityQueue[Task]() val rootTask = new RootTask(this, p) + var derivationCounter = 1; workList += rootTask - solution = None - if (generateDerivationTrees) { - derivationTree = new DerivationTree(rootTask) - } - - while (!workList.isEmpty && solution.isEmpty) { + while (!workList.isEmpty && rootTask.solution.isEmpty) { val task = workList.dequeue() val subtasks = task.run @@ -41,31 +33,15 @@ class Synthesizer(val r: Reporter, val solvers: List[Solver], generateDerivation workList ++= subtasks } - if (generateDerivationTrees) { - derivationTree.toDotFile("derivation"+derivationCounter+".dot") + val deriv = new DerivationTree(rootTask) + deriv.toDotFile("derivation"+derivationCounter+".dot") derivationCounter += 1 } - solution.getOrElse(Solution.none) - } - def onTaskSucceeded(task: Task, solution: Solution) { - info(" => Solved "+task.problem+" ⊢ "+solution) - if (generateDerivationTrees) { - derivationTree.recordSolutionFor(task, solution) - } - - task match { - case rt: RootTask => - info(" SUCCESS!") - this.solution = Some(solution) - case d: ApplyRuleTask => - d.parent.succeeded(solution) - case s: SimpleTask => - s.parent.subSucceeded(task.problem, solution) - } + rootTask.solution.getOrElse(Solution.none) } def solveSAT(phi: Expr): (Option[Boolean], Map[Identifier, Expr]) = { diff --git a/src/main/scala/leon/synthesis/Task.scala b/src/main/scala/leon/synthesis/Task.scala index a81ac85a6..3584447e2 100644 --- a/src/main/scala/leon/synthesis/Task.scala +++ b/src/main/scala/leon/synthesis/Task.scala @@ -8,16 +8,6 @@ abstract class Task(val synth: Synthesizer, def compare(that: Task) = this.priority - that.priority - /* - def decompose(rule: Rule, subProblems: List[Problem], onSuccess: List[Solution] => Solution, score: Score): DecomposedTask = { - new DecomposedTask(this.synth, this.parent, this.problem, score, rule, subProblems, onSuccess) - } - - def solveUsing(rule: Rule, onSuccess: => Solution, score: Score): DecomposedTask = { - new DecomposedTask(this.synth, this.parent, this.problem, 1000, rule, Nil, { case _ => onSuccess }) - } - */ - def run: List[Task] override def toString = " Task("+priority+"): " +problem @@ -28,8 +18,18 @@ class SimpleTask(synth: Synthesizer, problem: Problem, priority: Priority) extends Task(synth, parent, problem, priority) { - def succeeded(solution: Solution) { - synth.onTaskSucceeded(this, solution) + var solverTask: Option[ApplyRuleTask] = None + var solution: Option[Solution] = None + + def solvedBy(t: ApplyRuleTask, s: Solution) { + if (solution.map(_.score).getOrElse(-1) < s.score) { + solution = Some(s) + solverTask = Some(t) + + if (parent ne null) { + parent.partlySolvedBy(this, s) + } + } } def run: List[Task] = { @@ -37,10 +37,16 @@ class SimpleTask(synth: Synthesizer, } var failed = Set[Rule]() - def notifyInapplicable(r: Rule) = { - failed += r + def unsolvedBy(t: ApplyRuleTask) = { + failed += t.rule if (failed.size == synth.rules.size) { - synth.onTaskSucceeded(this, Solution.choose(problem)) + val s = Solution.choose(problem) + solution = Some(s) + solverTask = None + + if (parent ne null) { + parent.partlySolvedBy(this, s) + } } } } @@ -52,21 +58,17 @@ class ApplyRuleTask(synth: Synthesizer, problem: Problem, val rule: Rule) extends Task(synth, parent, problem, rule.priority) { - var subProblems: List[Problem] = _ - var onSuccess: List[Solution] => Solution = _ - var subSolutions : Map[Problem, Solution] = _ - - def subSucceeded(p: Problem, s: Solution) { - assert(subProblems contains p, "Problem "+p+" is unknown to me ?!?") - - if (subSolutions.get(p).map(_.score).getOrElse(-1) <= s.score) { - subSolutions += p -> s - - if (subSolutions.size == subProblems.size) { + var subTasks: List[SimpleTask] = Nil + var onSuccess: List[Solution] => Solution = _ + var subSolutions : Map[SimpleTask, Solution] = _ - val solution = onSuccess(subProblems map subSolutions) + def partlySolvedBy(t: SimpleTask, s: Solution) { + if (subSolutions.get(t).map(_.score).getOrElse(-1) < s.score) { + subSolutions += t -> s - synth.onTaskSucceeded(this, solution) + if (subSolutions.size == subTasks.size) { + val solution = onSuccess(subTasks map subSolutions) + parent.solvedBy(this, solution) } } } @@ -75,16 +77,16 @@ class ApplyRuleTask(synth: Synthesizer, rule.applyOn(this) match { case RuleSuccess(solution) => // Solved - synth.onTaskSucceeded(this, solution) + parent.solvedBy(this, solution) Nil case RuleDecomposed(subProblems, onSuccess) => - this.subProblems = subProblems + this.subTasks = subProblems.map(new SimpleTask(synth, this, _, 1000)) this.onSuccess = onSuccess this.subSolutions = Map() - subProblems.map(new SimpleTask(synth, this, _, 42)) + subTasks case RuleInapplicable => - parent.notifyInapplicable(rule) + parent.unsolvedBy(this) Nil } } -- GitLab