From 8c22a926245afea2d3fb6d81a94449930bf2d779 Mon Sep 17 00:00:00 2001
From: Etienne Kneuss <ekneuss@gmail.com>
Date: Mon, 12 Nov 2012 02:40:11 +0100
Subject: [PATCH] Tentative implementation of multi-steps rules

---
 .../scala/leon/synthesis/DerivationTree.scala |  2 +-
 .../scala/leon/synthesis/Heuristics.scala     | 14 +--
 src/main/scala/leon/synthesis/Rules.scala     | 45 +++++----
 .../scala/leon/synthesis/SynthesisPhase.scala | 42 ++++++++-
 .../scala/leon/synthesis/Synthesizer.scala    | 72 +++++----------
 src/main/scala/leon/synthesis/Task.scala      | 92 ++++++++++---------
 6 files changed, 139 insertions(+), 128 deletions(-)

diff --git a/src/main/scala/leon/synthesis/DerivationTree.scala b/src/main/scala/leon/synthesis/DerivationTree.scala
index 9e3064bb3..c409f4020 100644
--- a/src/main/scala/leon/synthesis/DerivationTree.scala
+++ b/src/main/scala/leon/synthesis/DerivationTree.scala
@@ -37,7 +37,7 @@ class DerivationTree(root: RootTask)  {
 
       res append " "+node+" [ label = <<TABLE BORDER=\"0\" CELLBORDER=\"1\" CELLSPACING=\"0\"><TR><TD BORDER=\"0\">"+t.rule+"</TD></TR><TR><TD BGCOLOR=\"indianred1\">"+escapeHTML(t.problem.toString)+"</TD></TR><TR><TD BGCOLOR=\"greenyellow\">"+escapeHTML(t.solution.map(_.toString).getOrElse("?"))+"</TD></TR></TABLE>> shape = \"none\" ];\n"
 
-      for (st <- t.subSolvers.values) {
+      for (st <- t.steps.flatMap(_.subSolvers.values)) {
         res.append(" "+taskName(st)+" -> "+node+"\n")
         printTask(st)
       }
diff --git a/src/main/scala/leon/synthesis/Heuristics.scala b/src/main/scala/leon/synthesis/Heuristics.scala
index 8afdf96bf..6913cc9eb 100644
--- a/src/main/scala/leon/synthesis/Heuristics.scala
+++ b/src/main/scala/leon/synthesis/Heuristics.scala
@@ -9,10 +9,10 @@ import purescala.TypeTrees._
 import purescala.Definitions._
 
 object Heuristics {
-  def all(synth: Synthesizer) = Set(
-    new OptimisticGround(synth),
-    new IntInduction(synth),
-    new OptimisticInjection(synth)
+  def all = Set[Synthesizer => Rule](
+    new OptimisticGround(_),
+//    new IntInduction(_),
+    new OptimisticInjection(_)
   )
 }
 
@@ -117,7 +117,7 @@ class IntInduction(synth: Synthesizer) extends Rule("Int Induction", synth, 80)
             Solution.none
         }
 
-        RuleDecomposed(List(subBase, subGT, subLT), onSuccess)
+        RuleStep(List(subBase, subGT, subLT), onSuccess)
       case _ =>
         RuleInapplicable
     }
@@ -153,7 +153,7 @@ class OptimisticInjection(synth: Synthesizer) extends Rule("Opt. Injection", syn
 
       val sub = p.copy(phi = And(newExprs))
 
-      RuleDecomposed(List(sub), forward)
+      RuleStep(List(sub), forward)
     } else {
       RuleInapplicable
     }
@@ -189,7 +189,7 @@ class SelectiveInlining(synth: Synthesizer) extends Rule("Sel. Inlining", synth,
 
       val sub = p.copy(phi = And(newExprs))
 
-      RuleDecomposed(List(sub), forward)
+      RuleStep(List(sub), forward)
     } else {
       RuleInapplicable
     }
diff --git a/src/main/scala/leon/synthesis/Rules.scala b/src/main/scala/leon/synthesis/Rules.scala
index 071cad7f9..1af93b194 100644
--- a/src/main/scala/leon/synthesis/Rules.scala
+++ b/src/main/scala/leon/synthesis/Rules.scala
@@ -9,24 +9,31 @@ import purescala.TreeOps._
 import purescala.TypeTrees._
 
 object Rules {
-  def all(synth: Synthesizer) = Set(
-    new Unification.DecompTrivialClash(synth),
-    new Unification.OccursCheck(synth),
-    new ADTDual(synth),
-    new OnePoint(synth),
-    new Ground(synth),
-    new CaseSplit(synth),
-    new UnusedInput(synth),
-    new UnconstrainedOutput(synth),
-    new Assert(synth)
+  def all = Set[Synthesizer => Rule](
+    new Unification.DecompTrivialClash(_),
+    new Unification.OccursCheck(_),
+    new ADTDual(_),
+    new OnePoint(_),
+    new Ground(_),
+    new CaseSplit(_),
+    new UnusedInput(_),
+    new UnconstrainedOutput(_),
+    new Assert(_)
   )
 }
 
-abstract class RuleResult
+sealed abstract class RuleResult
 case object RuleInapplicable extends RuleResult
 case class RuleSuccess(solution: Solution) extends RuleResult
-case class RuleDecomposed(subProblems: List[Problem], onSuccess: List[Solution] => Solution) extends RuleResult
+case class RuleMultiSteps(subProblems: List[Problem],
+                          steps: List[List[Solution] => List[Problem]],
+                          onSuccess: List[Solution] => Solution) extends RuleResult
 
+object RuleStep {
+  def apply(subProblems: List[Problem], onSuccess: List[Solution] => Solution) = {
+    RuleMultiSteps(subProblems, Nil, onSuccess)
+  }
+}
 
 abstract class Rule(val name: String, val synth: Synthesizer, val priority: Priority) {
   def applyOn(task: Task): RuleResult
@@ -74,7 +81,7 @@ class OnePoint(synth: Synthesizer) extends Rule("One-point", synth, 300) {
         case _ => Solution.none
       }
 
-      RuleDecomposed(List(newProblem), onSuccess)
+      RuleStep(List(newProblem), onSuccess)
     } else {
       RuleInapplicable
     }
@@ -116,7 +123,7 @@ class CaseSplit(synth: Synthesizer) extends Rule("Case-Split", synth, 200) {
           case _ => Solution.none
         }
 
-        RuleDecomposed(List(sub1, sub2), onSuccess)
+        RuleStep(List(sub1, sub2), onSuccess)
       case _ =>
         RuleInapplicable
     }
@@ -146,7 +153,7 @@ class Assert(synth: Synthesizer) extends Rule("Assert", synth, 200) {
 
             val sub = p.copy(phi = And(others))
 
-            RuleDecomposed(List(sub), onSuccess)
+            RuleStep(List(sub), onSuccess)
             */
             RuleInapplicable
           }
@@ -167,7 +174,7 @@ class UnusedInput(synth: Synthesizer) extends Rule("UnusedInput", synth, 100) {
     if (!unused.isEmpty) {
       val sub = p.copy(as = p.as.filterNot(unused))
 
-      RuleDecomposed(List(sub), forward)
+      RuleStep(List(sub), forward)
     } else {
       RuleInapplicable
     }
@@ -189,7 +196,7 @@ class UnconstrainedOutput(synth: Synthesizer) extends Rule("Unconstr.Output", sy
           Solution.none
       }
 
-      RuleDecomposed(List(sub), onSuccess)
+      RuleStep(List(sub), onSuccess)
     } else {
       RuleInapplicable
     }
@@ -219,7 +226,7 @@ object Unification {
         val sub = p.copy(phi = And((exprs.toSet -- toRemove ++ toAdd.flatten).toSeq))
 
 
-        RuleDecomposed(List(sub), forward)
+        RuleStep(List(sub), forward)
       } else {
         RuleInapplicable
       }
@@ -273,7 +280,7 @@ class ADTDual(synth: Synthesizer) extends Rule("ADTDual", synth, 200) {
     if (!toRemove.isEmpty) {
       val sub = p.copy(phi = And((exprs.toSet -- toRemove ++ toAdd.flatten).toSeq))
 
-      RuleDecomposed(List(sub), forward)
+      RuleStep(List(sub), forward)
     } else {
       RuleInapplicable
     }
diff --git a/src/main/scala/leon/synthesis/SynthesisPhase.scala b/src/main/scala/leon/synthesis/SynthesisPhase.scala
index 2f8dc774e..78103a3b1 100644
--- a/src/main/scala/leon/synthesis/SynthesisPhase.scala
+++ b/src/main/scala/leon/synthesis/SynthesisPhase.scala
@@ -5,9 +5,9 @@ import purescala.TreeOps._
 import solvers.TrivialSolver
 import solvers.z3.{FairZ3Solver,UninterpretedZ3Solver}
 
-import purescala.Trees.Expr
+import purescala.Trees._
 import purescala.ScalaPrinter
-import purescala.Definitions.Program
+import purescala.Definitions.{Program, FunDef}
 
 object SynthesisPhase extends LeonPhase[Program, Program] {
   val name        = "Synthesis"
@@ -54,8 +54,39 @@ object SynthesisPhase extends LeonPhase[Program, Program] {
       case _ =>
     }
 
-    val synth = new Synthesizer(ctx.reporter, mainSolver, genTrees, filterFun.map(_.toSet), firstOnly, timeoutMs)
-    val solutions = synth.synthesizeAll(p)
+    def synthesizeAll(program: Program): Map[Choose, Solution] = {
+      def noop(u:Expr, u2: Expr) = u
+
+      var solutions = Map[Choose, Solution]()
+
+      def actOnChoose(f: FunDef)(e: Expr, a: Expr): Expr = e match {
+        case ch @ Choose(vars, pred) =>
+          val xs = vars
+          val as = (variablesOf(pred)--xs).toList
+          val phi = pred
+
+          val problem = Problem(as, BooleanLiteral(true), phi, xs)
+          val synth = new Synthesizer(ctx.reporter, mainSolver, problem, Rules.all ++ Heuristics.all, genTrees, filterFun.map(_.toSet), firstOnly, timeoutMs)
+          val sol = synth.synthesize()
+
+          solutions += ch -> sol
+
+          a
+        case _ =>
+          a
+      }
+
+      // Look for choose()
+      for (f <- program.definedFunctions.sortBy(_.id.toString) if f.body.isDefined) {
+        if (filterFun.isEmpty || filterFun.get.contains(f.id.toString)) {
+          treeCatamorphism(x => x, noop, actOnChoose(f), f.body.get)
+        }
+      }
+
+      solutions
+    }
+
+    val solutions = synthesizeAll(p)
 
     // Simplify expressions
     val simplifiers = List[Expr => Expr](
@@ -71,7 +102,7 @@ object SynthesisPhase extends LeonPhase[Program, Program] {
 
     if (inPlace) {
       for (file <- ctx.files) {
-        new FileInterface(ctx.reporter, file).updateFile(chooseToExprs.toMap)
+        new FileInterface(ctx.reporter, file).updateFile(chooseToExprs)
       }
     } else {
       for ((chs, ex) <- chooseToExprs) {
@@ -86,4 +117,5 @@ object SynthesisPhase extends LeonPhase[Program, Program] {
     p
   }
 
+
 }
diff --git a/src/main/scala/leon/synthesis/Synthesizer.scala b/src/main/scala/leon/synthesis/Synthesizer.scala
index 983b4ec6f..cb945337d 100644
--- a/src/main/scala/leon/synthesis/Synthesizer.scala
+++ b/src/main/scala/leon/synthesis/Synthesizer.scala
@@ -12,23 +12,32 @@ import java.io.File
 
 import collection.mutable.PriorityQueue
 
-class Synthesizer(val r: Reporter,
+class Synthesizer(val reporter: Reporter,
                   val solver: Solver,
+                  val problem: Problem,
+                  val ruleConstructors: Set[Synthesizer => Rule],
                   generateDerivationTrees: Boolean = false,
                   filterFuns: Option[Set[String]]  = None,
                   firstOnly: Boolean               = false,
                   timeoutMs: Option[Long]          = None) {
+  import reporter.{error,warning,info,fatalError}
 
-  import r.{error,warning,info,fatalError}
-
+  val rules = ruleConstructors.map(_.apply(this))
 
   var derivationCounter = 1;
 
-  def synthesize(p: Problem, rules: Set[Rule]): Solution = {
+  val rootTask: RootTask = new RootTask(this, problem)
+
+  val workList = new PriorityQueue[Task]()
 
-    val workList = new PriorityQueue[Task]()
-    val rootTask = new RootTask(this, p)
+  def bestSolutionSoFar(): Solution = {
+    rootTask.solution.getOrElse(worstSolution)
+  }
 
+  val worstSolution = Solution.choose(problem)
+
+  def synthesize(): Solution = {
+    workList.clear()
     workList += rootTask
 
     val ts = System.currentTimeMillis
@@ -40,24 +49,10 @@ class Synthesizer(val r: Reporter,
       }
     }
 
-    val worstSolution = Solution.choose(p)
-
-    def bestSolutionSoFar(): Solution = {
-      rootTask.solution.getOrElse(worstSolution)
-    }
-    
     while (!workList.isEmpty && !(firstOnly && rootTask.solution.isDefined)) {
       val task = workList.dequeue()
 
-      val subProblems = task.run
-
-      // Check if solving this task has the slightest chance of improving the
-      // current solution
-      if (task.minComplexity < bestSolutionSoFar().complexity) {
-        for (p <- subProblems; r <- rules) yield {
-          workList += new Task(this, task, p, r)
-        }
-      }
+      task.run()
 
       if (timeoutExpired()) {
         warning("Timeout reached")
@@ -74,36 +69,11 @@ class Synthesizer(val r: Reporter,
     bestSolutionSoFar()
   }
 
-  val rules = Rules.all(this) ++ Heuristics.all(this)
-
-  import purescala.Trees._
-  def synthesizeAll(program: Program): List[(Choose, Solution)] = {
-    def noop(u:Expr, u2: Expr) = u
-
-    var solutions = List[(Choose, Solution)]()
-
-    def actOnChoose(f: FunDef)(e: Expr, a: Expr): Expr = e match {
-      case ch @ Choose(vars, pred) =>
-        val xs = vars
-        val as = (variablesOf(pred)--xs).toList
-        val phi = pred
-
-        val sol = synthesize(Problem(as, BooleanLiteral(true), phi, xs), rules)
-
-        solutions = (ch -> sol) :: solutions
-
-        a
-      case _ =>
-        a
+  def addProblems(task: Task, problems: Traversable[Problem]) {
+    // Check if solving this task has the slightest chance of improving the
+    // current solution
+    for (p <- problems; rule <- rules) yield {
+      workList += new Task(this, task, p, rule)
     }
-
-    // Look for choose()
-    for (f <- program.definedFunctions.sortBy(_.id.toString) if f.body.isDefined) {
-      if (filterFuns.isEmpty || filterFuns.get.contains(f.id.toString)) {
-        treeCatamorphism(x => x, noop, actOnChoose(f), f.body.get)
-      }
-    }
-
-    solutions.reverse
   }
 }
diff --git a/src/main/scala/leon/synthesis/Task.scala b/src/main/scala/leon/synthesis/Task.scala
index e8eba4143..a1632ec33 100644
--- a/src/main/scala/leon/synthesis/Task.scala
+++ b/src/main/scala/leon/synthesis/Task.scala
@@ -17,42 +17,61 @@ class Task(synth: Synthesizer,
     }
   }
 
-  var subProblems: List[Problem]               = Nil
-  var onSuccess: List[Solution] => Solution    = _
-  var subSolutions : Map[Problem, Solution]    = Map()
-  var subSolvers : Map[Problem, Task]          = Map()
-  var solution : Option[Solution]              = None
-
-  def isBetterSolutionThan(sol: Solution, osol: Option[Solution]): Boolean =
+  def isBetterSolutionThan(sol: Solution, osol: Option[Solution]): Boolean = {
     osol match {
       case Some(s) => s.complexity > sol.complexity
       case None => true
     }
+  }
+
+  var solution: Option[Solution] = None
+
+  class TaskStep(val subProblems: List[Problem]) {
+    var subSolutions = Map[Problem, Solution]()
+    var subSolvers   = Map[Problem, Task]()
+    var failures     = Set[Rule]()
+  }
+
+  var steps: List[TaskStep]                            = Nil
+  var nextSteps: List[List[Solution] => List[Problem]] = Nil
+  var onSuccess: List[Solution] => Solution            = _
+
+  def currentStep                 = steps.head
 
   def partlySolvedBy(t: Task, s: Solution) {
-    if (isBetterSolutionThan(s, subSolutions.get(t.problem))) {
-      subSolutions += t.problem -> s
-      subSolvers   += t.problem -> t
+    if (isBetterSolutionThan(s, currentStep.subSolutions.get(t.problem))) {
+      currentStep.subSolutions += t.problem -> s
+      currentStep.subSolvers   += t.problem -> t
+
+      if (currentStep.subSolutions.size == currentStep.subProblems.size) {
+        val solutions = currentStep.subProblems map currentStep.subSolutions
+
+        if (!nextSteps.isEmpty) {
+          val newProblems = nextSteps.head.apply(solutions)
+          nextSteps = nextSteps.tail
 
-      if (subSolutions.size == subProblems.size) {
-        solution = Some(onSuccess(subProblems map subSolutions))
-        parent.partlySolvedBy(this, solution.get)
+          synth.addProblems(this, newProblems)
+
+          steps = new TaskStep(newProblems) :: steps
+        } else {
+          solution = Some(onSuccess(solutions))
+          parent.partlySolvedBy(this, solution.get)
+        }
       }
     }
   }
 
-  var failures = Set[Rule]()
   def unsolvedBy(t: Task) {
-    failures += t.rule
+    currentStep.failures += t.rule
 
-    if (failures.size == synth.rules.size) {
+    if (currentStep.failures.size == synth.rules.size) {
       // We might want to report unsolved instead of solvedByChoose, depending
       // on the cases
       parent.partlySolvedBy(this, Solution.choose(problem))
     }
   }
 
-  def run: List[Problem] = {
+  def run() {
     rule.applyOn(this) match {
       case RuleSuccess(solution) =>
         // Solved
@@ -63,14 +82,10 @@ class Task(synth: Synthesizer,
         println(prefix+"Got: "+problem)
         println(prefix+"Solved with: "+solution)
 
-        Nil
-
-      case RuleDecomposed(subProblems, onSuccess) =>
-        this.subProblems  = subProblems
-        this.onSuccess    = onSuccess
-
-        val simplestSolution = onSuccess(subProblems.map(Solution.basic _))
-        minComplexity = new FixedSolComplexity(parent.minComplexity.value + simplestSolution.complexity.value)
+      case RuleMultiSteps(subProblems, interSteps, onSuccess) =>
+        this.steps = new TaskStep(subProblems) :: Nil
+        this.nextSteps = interSteps
+        this.onSuccess = onSuccess
 
         val prefix = "[%-20s] ".format(Option(rule).map(_.toString).getOrElse("root"))
         println(prefix+"Got: "+problem)
@@ -79,42 +94,29 @@ class Task(synth: Synthesizer,
           println(prefix+" - "+p)
         }
 
-        subProblems
+        synth.addProblems(this, subProblems)
 
       case RuleInapplicable =>
         parent.unsolvedBy(this)
-        Nil
     }
   }
 
-  var minComplexity: AbsSolComplexity = new FixedSolComplexity(0)
-
   override def toString = "Applying "+rule+" on "+problem
 }
 
 class RootTask(synth: Synthesizer, problem: Problem) extends Task(synth, null, problem, null) {
   var solver: Option[Task] = None
 
-  override def partlySolvedBy(t: Task, s: Solution) {
-    if (isBetterSolutionThan(s, solution)) {
-      solution = Some(s)
-      solver   = Some(t)
-    }
+  override def run() {
+    synth.addProblems(this, List(problem))
   }
 
-  override def unsolvedBy(t: Task) {
-    failures += t.rule
-
-    if (failures.size == synth.rules.size) {
-      // We might want to report unsolved instead of solvedByChoose, depending
-      // on the cases
-      solution = Some(Solution.choose(problem))
-      solver   = None
-    }
+  override def partlySolvedBy(t: Task, s: Solution) {
+    solution = Some(s)
+    solver   = Some(t)
   }
 
-  override def run: List[Problem] = {
-    List(problem)
+  override def unsolvedBy(t: Task) {
   }
 
 }
-- 
GitLab