From c76506d311a2fbbb438da756ff1c9a77f44728ce Mon Sep 17 00:00:00 2001
From: Etienne Kneuss <ekneuss@gmail.com>
Date: Wed, 5 Dec 2012 22:40:55 +0100
Subject: [PATCH] Improve search for leaves to target unsolved branches, fix
 one issue regarding parallelism

---
 .../leon/solvers/z3/AbstractZ3Solver.scala    |  2 ++
 src/main/scala/leon/synthesis/CostModel.scala | 14 ++++----
 .../scala/leon/synthesis/Heuristics.scala     |  2 +-
 .../scala/leon/synthesis/ParallelSearch.scala |  8 +++--
 src/main/scala/leon/synthesis/Rules.scala     |  8 ++---
 .../scala/leon/synthesis/SimpleSearch.scala   |  2 +-
 .../scala/leon/synthesis/rules/Cegis.scala    |  2 +-
 .../synthesis/search/AndOrGraphSearch.scala   | 32 +++++++++++--------
 8 files changed, 40 insertions(+), 30 deletions(-)

diff --git a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala
index db658c4c6..202c96810 100644
--- a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala
+++ b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala
@@ -100,6 +100,7 @@ trait AbstractZ3Solver extends solvers.IncrementalSolverBuilder {
       counter = 0
 
       z3 = new Z3Context(z3cfg)
+
       prepareSorts
       prepareFunctions
 
@@ -109,6 +110,7 @@ trait AbstractZ3Solver extends solvers.IncrementalSolverBuilder {
 
   protected[leon] def restartZ3() {
     isInitialized = false
+
     initZ3()
 
     exprToZ3Id = Map.empty
diff --git a/src/main/scala/leon/synthesis/CostModel.scala b/src/main/scala/leon/synthesis/CostModel.scala
index 5d71c6ec8..447da1fb9 100644
--- a/src/main/scala/leon/synthesis/CostModel.scala
+++ b/src/main/scala/leon/synthesis/CostModel.scala
@@ -9,7 +9,13 @@ import synthesis.search.Cost
 abstract class CostModel(val name: String) {
   def solutionCost(s: Solution): Cost
   def problemCost(p: Problem): Cost
-  def ruleAppCost(r: Rule, app: RuleApplication): Cost
+
+  def ruleAppCost(r: Rule, app: RuleApplication): Cost = new Cost {
+    val subSols = (1 to app.subProblemsCount).map {i => Solution.simplest }.toList
+    val simpleSol = app.onSuccess(subSols)
+
+    val value = solutionCost(simpleSol).value
+  }
 }
 
 object CostModel {
@@ -30,10 +36,4 @@ case object NaiveCostModel extends CostModel("Naive") {
     val value = p.xs.size
   }
 
-  def ruleAppCost(r: Rule, app: RuleApplication): Cost = new Cost {
-    val subSols = (1 to app.subProblemsCount).map {i => Solution.simplest }.toList
-    val simpleSol = app.onSuccess(subSols)
-
-    val value = solutionCost(simpleSol).value
-  }
 }
diff --git a/src/main/scala/leon/synthesis/Heuristics.scala b/src/main/scala/leon/synthesis/Heuristics.scala
index decfdeffb..b03e01667 100644
--- a/src/main/scala/leon/synthesis/Heuristics.scala
+++ b/src/main/scala/leon/synthesis/Heuristics.scala
@@ -36,7 +36,7 @@ object HeuristicStep {
 
   def apply(sctx: SynthesisContext, problem: Problem, subProblems: List[Problem], onSuccess: List[Solution] => Solution) = {
     new RuleApplication(subProblems.size, onSuccess.andThen(verifyPre(sctx, problem))) {
-      def apply() = RuleDecomposed(subProblems, onSuccess)
+      def apply(sctx: SynthesisContext) = RuleDecomposed(subProblems, onSuccess)
     }
   }
 }
diff --git a/src/main/scala/leon/synthesis/ParallelSearch.scala b/src/main/scala/leon/synthesis/ParallelSearch.scala
index 7902675f0..7fa8a62de 100644
--- a/src/main/scala/leon/synthesis/ParallelSearch.scala
+++ b/src/main/scala/leon/synthesis/ParallelSearch.scala
@@ -18,13 +18,15 @@ class ParallelSearch(synth: Synthesizer,
     val solver = new FairZ3Solver(synth.context.copy(reporter = reporter))
     solver.setProgram(synth.program)
 
+    solver.initZ3
+
     SynthesisContext(solver = solver, reporter = synth.reporter)
   }
 
   def expandAndTask(ref: ActorRef, sctx: SynthesisContext)(t: TaskRunRule) = {
     val prefix = "[%-20s] ".format(Option(t.rule).getOrElse("?"))
 
-    t.app.apply() match {
+    t.app.apply(sctx) match {
       case RuleSuccess(sol) =>
         info(prefix+"Got: "+t.problem)
         info(prefix+"Solved with: "+sol)
@@ -45,7 +47,9 @@ class ParallelSearch(synth: Synthesizer,
   }
 
   def expandOrTask(ref: ActorRef, sctx: SynthesisContext)(t: TaskTryRules) = {
-    val sub = rules.flatMap ( r => r.attemptToApplyOn(sctx, t.p).alternatives.map(TaskRunRule(t.p, r, _)) )
+    val sub = rules.flatMap { r => 
+      r.attemptToApplyOn(sctx, t.p).alternatives.map(TaskRunRule(t.p, r, _))
+    }
 
     if (!sub.isEmpty) {
       Expanded(sub.toList)
diff --git a/src/main/scala/leon/synthesis/Rules.scala b/src/main/scala/leon/synthesis/Rules.scala
index 750171b1d..3d2836c89 100644
--- a/src/main/scala/leon/synthesis/Rules.scala
+++ b/src/main/scala/leon/synthesis/Rules.scala
@@ -33,7 +33,7 @@ object RuleInapplicable extends RuleResult(List())
 abstract class RuleApplication(val subProblemsCount: Int,
                                val onSuccess: List[Solution] => Solution) {
 
-  def apply(): RuleApplicationResult
+  def apply(sctx: SynthesisContext): RuleApplicationResult
 }
 
 abstract class RuleImmediateApplication extends RuleApplication(0, s => Solution.simplest)
@@ -46,7 +46,7 @@ case object RuleApplicationImpossible extends RuleApplicationResult
 object RuleFastApplication {
   def apply(sub: List[Problem], onSuccess: List[Solution] => Solution) = {
     new RuleApplication(sub.size, onSuccess) {
-      def apply() = RuleDecomposed(sub, onSuccess)
+      def apply(sctx: SynthesisContext) = RuleDecomposed(sub, onSuccess)
     }
   }
 }
@@ -54,7 +54,7 @@ object RuleFastApplication {
 object RuleFastInapplicable {
   def apply() = {
     RuleResult(List(new RuleApplication(0, ls => Solution.simplest) {
-      def apply() = RuleApplicationImpossible
+      def apply(sctx: SynthesisContext) = RuleApplicationImpossible
     }))
   }
 }
@@ -68,7 +68,7 @@ object RuleFastStep {
 object RuleFastSuccess {
   def apply(solution: Solution) = {
     RuleResult(List(new RuleApplication(0, ls => solution) {
-      def apply() = RuleSuccess(solution)
+      def apply(sctx: SynthesisContext) = RuleSuccess(solution)
     }))
   }
 }
diff --git a/src/main/scala/leon/synthesis/SimpleSearch.scala b/src/main/scala/leon/synthesis/SimpleSearch.scala
index e619510c1..d92de1d1e 100644
--- a/src/main/scala/leon/synthesis/SimpleSearch.scala
+++ b/src/main/scala/leon/synthesis/SimpleSearch.scala
@@ -38,7 +38,7 @@ class SimpleSearch(synth: Synthesizer,
   def expandAndTask(t: TaskRunRule): ExpandResult[TaskTryRules] = {
     val prefix = "[%-20s] ".format(Option(t.rule).getOrElse("?"))
 
-    t.app.apply() match {
+    t.app.apply(sctx) match {
       case RuleSuccess(sol) =>
         info(prefix+"Got: "+t.problem)
         info(prefix+"Solved with: "+sol)
diff --git a/src/main/scala/leon/synthesis/rules/Cegis.scala b/src/main/scala/leon/synthesis/rules/Cegis.scala
index 1580e81b7..d3665ef42 100644
--- a/src/main/scala/leon/synthesis/rules/Cegis.scala
+++ b/src/main/scala/leon/synthesis/rules/Cegis.scala
@@ -123,7 +123,7 @@ case object CEGIS extends Rule("CEGIS", 150) {
     val (exprsA, others) = ands.partition(e => (variablesOf(e) & xsSet).isEmpty)
     if (exprsA.isEmpty) {
       val res = new RuleImmediateApplication {
-        def apply() = {
+        def apply(sctx: SynthesisContext) = {
           var result: Option[RuleApplicationResult]   = None
 
           var ass = p.as.toSet
diff --git a/src/main/scala/leon/synthesis/search/AndOrGraphSearch.scala b/src/main/scala/leon/synthesis/search/AndOrGraphSearch.scala
index 75ee3292f..1291babe9 100644
--- a/src/main/scala/leon/synthesis/search/AndOrGraphSearch.scala
+++ b/src/main/scala/leon/synthesis/search/AndOrGraphSearch.scala
@@ -15,26 +15,30 @@ abstract class AndOrGraphSearch[AT <: AOAndTask[S],
 
     def collectFromAnd(at: g.AndTree, costs: List[Int]) {
       val newCosts = at.minCost.value :: costs
-      at match {
-        case l: g.Leaf =>
-          collectLeaf(WL(l, newCosts.reverse)) 
-        case a: g.AndNode =>
-          for (o <- (a.subProblems -- a.subSolutions.keySet).values) {
-            collectFromOr(o, newCosts)
-          }
+      if (!at.isSolved) {
+        at match {
+          case l: g.Leaf =>
+            collectLeaf(WL(l, newCosts.reverse)) 
+          case a: g.AndNode =>
+            for (o <- (a.subProblems -- a.subSolutions.keySet).values) {
+              collectFromOr(o, newCosts)
+            }
+        }
       }
     }
 
     def collectFromOr(ot: g.OrTree, costs: List[Int]) {
       val newCosts = ot.minCost.value :: costs
 
-      ot match {
-        case l: g.Leaf =>
-          collectLeaf(WL(l, newCosts.reverse))
-        case o: g.OrNode =>
-          for (a <- o.alternatives.values) {
-            collectFromAnd(a, newCosts)
-          }
+      if (!ot.isSolved) {
+        ot match {
+          case l: g.Leaf =>
+            collectLeaf(WL(l, newCosts.reverse))
+          case o: g.OrNode =>
+            for (a <- o.alternatives.values) {
+              collectFromAnd(a, newCosts)
+            }
+        }
       }
     }
 
-- 
GitLab