From b9548a321dc57e48837db565cea11efe169a794b Mon Sep 17 00:00:00 2001
From: Etienne Kneuss <colder@php.net>
Date: Wed, 17 Jul 2013 16:23:43 +0200
Subject: [PATCH] Improve Leon's parallel search

- Search tree can be iterated over in order
- Worker pool get displayed periodically when stuck
- Make sure global caches are concurrent
---
 src/main/scala/leon/purescala/TreeOps.scala   |  40 ++--
 .../scala/leon/synthesis/ParallelSearch.scala |   9 +-
 .../leon/synthesis/search/AndOrGraph.scala    | 173 ++++++++++++++++--
 .../search/AndOrGraphParallelSearch.scala     |  29 ++-
 .../synthesis/search/AndOrGraphSearch.scala   |  57 +-----
 5 files changed, 217 insertions(+), 91 deletions(-)

diff --git a/src/main/scala/leon/purescala/TreeOps.scala b/src/main/scala/leon/purescala/TreeOps.scala
index 3e81b76cb..405055e44 100644
--- a/src/main/scala/leon/purescala/TreeOps.scala
+++ b/src/main/scala/leon/purescala/TreeOps.scala
@@ -5,6 +5,8 @@ package purescala
 
 import leon.solvers.Solver
 
+import scala.collection.concurrent.TrieMap
+
 object TreeOps {
   import Common._
   import TypeTrees._
@@ -667,20 +669,20 @@ object TreeOps {
     rec(expr, Map.empty)
   }
 
-  private var matchConverterCache = new scala.collection.mutable.HashMap[Expr,Expr]()
   /** Rewrites all pattern-matching expressions into if-then-else expressions,
    * with additional error conditions. Does not introduce additional variables.
-   * We use a cache because we can. */
+   */
+  val cacheMtITE = new TrieMap[Expr, Expr]()
+
   def matchToIfThenElse(expr: Expr) : Expr = {
-    val toRet = if(matchConverterCache.isDefinedAt(expr)) {
-      matchConverterCache(expr)
-    } else {
-      val converted = convertMatchToIfThenElse(expr)
-      matchConverterCache(expr) = converted
-      converted
+    cacheMtITE.get(expr) match {
+      case Some(res) =>
+        res
+      case None =>
+        val r = convertMatchToIfThenElse(expr)
+        cacheMtITE += expr -> r
+        r
     }
-
-    toRet
   }
 
   def conditionForPattern(in: Expr, pattern: Pattern, includeBinders: Boolean = false) : Expr = {
@@ -784,18 +786,18 @@ object TreeOps {
     searchAndReplaceDFS(rewritePM)(expr)
   }
 
-  private var mapGetConverterCache = new scala.collection.mutable.HashMap[Expr,Expr]()
   /** Rewrites all map accesses with additional error conditions. */
+  val cacheMGWC = new TrieMap[Expr, Expr]()
+
   def mapGetWithChecks(expr: Expr) : Expr = {
-    val toRet = if (mapGetConverterCache.isDefinedAt(expr)) {
-      mapGetConverterCache(expr)
-    } else {
-      val converted = convertMapGet(expr)
-      mapGetConverterCache(expr) = converted
-      converted
+    cacheMGWC.get(expr) match {
+      case Some(res) =>
+        res
+      case None =>
+        val r = convertMapGet(expr)
+        cacheMGWC += expr -> r
+        r
     }
-
-    toRet
   }
 
   private def convertMapGet(expr: Expr) : Expr = {
diff --git a/src/main/scala/leon/synthesis/ParallelSearch.scala b/src/main/scala/leon/synthesis/ParallelSearch.scala
index 57672fb4f..d9982a12d 100644
--- a/src/main/scala/leon/synthesis/ParallelSearch.scala
+++ b/src/main/scala/leon/synthesis/ParallelSearch.scala
@@ -5,7 +5,7 @@ package synthesis
 
 import synthesis.search._
 import akka.actor._
-import solvers.z3.FairZ3Solver
+import solvers.z3.{FairZ3Solver,UninterpretedZ3Solver}
 import solvers.TrivialSolver
 
 class ParallelSearch(synth: Synthesizer,
@@ -27,10 +27,13 @@ class ParallelSearch(synth: Synthesizer,
     val reporter = new SilentReporter
     val solver = new FairZ3Solver(synth.context.copy(reporter = reporter))
     solver.setProgram(synth.program)
-
     solver.initZ3
 
-    val ctx = SynthesisContext.fromSynthesizer(synth).copy(solver = solver)
+    val  simpleSolver = new UninterpretedZ3Solver(synth.context.copy(reporter = reporter))
+    simpleSolver.setProgram(synth.program)
+    simpleSolver.initZ3
+
+    val ctx = SynthesisContext.fromSynthesizer(synth).copy(solver = solver, simpleSolver = simpleSolver)
 
     synchronized {
       contexts = ctx :: contexts
diff --git a/src/main/scala/leon/synthesis/search/AndOrGraph.scala b/src/main/scala/leon/synthesis/search/AndOrGraph.scala
index 6fd4f3ca8..a6e3e19ac 100644
--- a/src/main/scala/leon/synthesis/search/AndOrGraph.scala
+++ b/src/main/scala/leon/synthesis/search/AndOrGraph.scala
@@ -2,8 +2,7 @@
 
 package leon.synthesis.search
 
-trait AOTask[S] {
-}
+trait AOTask[S] { }
 
 trait AOAndTask[S] extends AOTask[S] {
   def composeSolution(sols: List[S]): Option[S]
@@ -20,12 +19,35 @@ trait AOCostModel[AT <: AOAndTask[S], OT <: AOOrTask[S], S] {
 class AndOrGraph[AT <: AOAndTask[S], OT <: AOOrTask[S], S](val root: OT, val costModel: AOCostModel[AT, OT, S]) {
   var tree: OrTree = RootNode
 
+  object LeafOrdering extends Ordering[Leaf] {
+    def compare(a: Leaf, b: Leaf) = {
+      val diff = scala.math.Ordering.Iterable[Int].compare(a.minReachCost, b.minReachCost)
+      if (diff == 0) {
+        if (a == b) {
+          0
+        } else {
+          a.## - b.##
+        }
+      } else {
+        diff
+      }
+    }
+  }
+
+  val leaves = collection.mutable.TreeSet()(LeafOrdering)
+  leaves += RootNode
+
   trait Tree {
     val task : AOTask[S]
     val parent: Node[_]
 
     def minCost: Cost
 
+    var minReachCost = List[Int]()
+
+    def updateMinReach(reverseParent: List[Int]);
+    def removeLeaves();
+
     var isTrustworthy: Boolean = true
     var solution: Option[S] = None
     var isUnsolvable: Boolean = false
@@ -43,7 +65,22 @@ class AndOrGraph[AT <: AOAndTask[S], OT <: AOOrTask[S], S](val root: OT, val cos
 
 
   trait Leaf extends Tree {
-    def minCost = costModel.taskCost(task)
+    val minCost = costModel.taskCost(task)
+
+    var removedLeaf = false;
+
+    def updateMinReach(reverseParent: List[Int]) {
+      if (!removedLeaf) {
+        leaves -= this
+        minReachCost = (minCost.value :: reverseParent).reverse
+        leaves += this
+      }
+    }
+
+    def removeLeaves() {
+      removedLeaf = true
+      leaves -= this
+    }
   }
 
   trait Node[T <: Tree] extends Tree {
@@ -69,23 +106,75 @@ class AndOrGraph[AT <: AOAndTask[S], OT <: AOOrTask[S], S](val root: OT, val cos
           subCosts.foldLeft(costModel.taskCost(task))(_ + _)
       }
       if (minCost != old) {
-        Option(parent).foreach(_.updateMin())
+        if (parent ne null) {
+            parent.updateMin()
+        } else {
+            // Reached the root, propagate minReach up
+            updateMinReach(Nil)
+        }
+      } else {
+        // Reached boundary of update, propagate minReach up
+        updateMinReach(minReachCost.reverse.tail)
       }
     }
 
+    def updateMinReach(reverseParent: List[Int]) {
+      val rev = minCost.value :: reverseParent
+
+      minReachCost = rev.reverse
+
+      subProblems.values.foreach(_.updateMinReach(rev))
+    }
+
+    def removeLeaves() {
+      subProblems.values.foreach(_.removeLeaves())
+    }
 
     def unsolvable(l: OrTree) {
       isUnsolvable = true
+
+      this.removeLeaves()
+
       parent.unsolvable(this)
     }
 
     def expandLeaf(l: OrLeaf, succ: List[AT]) {
-      subProblems += l.task -> new OrNode(this, succ, l.task)
+      //println("[[2]] Expanding "+l.task+" to: ")
+      //for (t <- succ) {
+      //  println(" - "+t)
+      //}
+
+      //println("BEFORE: In leaves we have: ")
+      //for (i <- leaves.iterator) {
+      //  println("-> "+i.minReachCost+" == "+i.task)
+      //}
+
+      if (!l.removedLeaf) {
+        l.removeLeaves()
+
+        val orNode = new OrNode(this, succ, l.task)
+        subProblems += l.task -> orNode
+
+        updateMin()
+
+        leaves ++= orNode.andLeaves.values
+      }
+
+      //println("AFTER: In leaves we have: ")
+      //for (i <- leaves.iterator) {
+      //  println("-> "+i.minReachCost+" == "+i.task)
+      //}
     }
 
     def notifySolution(sub: OrTree, sol: S) {
       subSolutions += sub.task -> sol
 
+      sub match {
+        case l: Leaf =>
+          l.removeLeaves()
+        case _ =>
+      }
+
       if (subSolutions.size == subProblems.size) {
         task.composeSolution(subTasks.map(subSolutions)) match {
           case Some(sol) =>
@@ -113,8 +202,15 @@ class AndOrGraph[AT <: AOAndTask[S], OT <: AOOrTask[S], S](val root: OT, val cos
 
   object RootNode extends OrLeaf(null, root) {
 
+    minReachCost = List(minCost.value)
+
     override def expandWith(succ: List[AT]) {
-      tree = new OrNode(null, succ, root)
+      this.removeLeaves()
+
+      val orNode = new OrNode(null, succ, root)
+      tree = orNode
+
+      leaves ++= orNode.andLeaves.values
     }
   }
 
@@ -127,7 +223,8 @@ class AndOrGraph[AT <: AOAndTask[S], OT <: AOOrTask[S], S](val root: OT, val cos
 
 
   class OrNode(val parent: AndNode, val altTasks: List[AT], val task: OT) extends OrTree with Node[AndTree] {
-    var alternatives: Map[AT, AndTree] = altTasks.map(t => t -> new AndLeaf(this, t)).toMap
+    val andLeaves                      = altTasks.map(t => t -> new AndLeaf(this, t)).toMap
+    var alternatives: Map[AT, AndTree] = andLeaves
     var triedAlternatives              = Map[AT, AndTree]()
     var minAlternative: AndTree        = _
     var minCost                        = costModel.taskCost(task)
@@ -139,8 +236,18 @@ class AndOrGraph[AT <: AOAndTask[S], OT <: AOOrTask[S], S](val root: OT, val cos
         minAlternative = alternatives.values.minBy(_.minCost)
         val old = minCost 
         minCost        = minAlternative.minCost
+
+        //println("Updated minCost of "+task+" from "+old.value+" to "+minCost.value)
+        
         if (minCost != old) {
-          Option(parent).foreach(_.updateMin())
+          if (parent ne null) {
+            parent.updateMin()
+          } else {
+            // reached root, propagate minReach up
+            updateMinReach(Nil)
+          }
+        } else {
+          updateMinReach(minReachCost.reverse.tail)
         }
       } else {
         minAlternative = null
@@ -148,11 +255,24 @@ class AndOrGraph[AT <: AOAndTask[S], OT <: AOOrTask[S], S](val root: OT, val cos
       }
     }
 
+    def updateMinReach(reverseParent: List[Int]) {
+      val rev = minCost.value :: reverseParent
+
+      minReachCost = rev.reverse
+
+      alternatives.values.foreach(_.updateMinReach(rev))
+    }
+
+    def removeLeaves() {
+      alternatives.values.foreach(_.removeLeaves())
+    }
+
     def unsolvable(l: AndTree) {
       if (alternatives contains l.task) {
         triedAlternatives += l.task -> alternatives(l.task)
         alternatives -= l.task
 
+        l.removeLeaves()
 
         if (alternatives.isEmpty) {
           isUnsolvable = true
@@ -166,16 +286,43 @@ class AndOrGraph[AT <: AOAndTask[S], OT <: AOOrTask[S], S](val root: OT, val cos
     }
 
     def expandLeaf(l: AndLeaf, succ: List[OT]) {
-      val n = new AndNode(this, succ, l.task)
-      n.subProblems = succ.map(t => t -> new OrLeaf(n, t)).toMap
-      n.updateMin()
+      //println("[[1]] Expanding "+l.task+" to: ")
+      //for (t <- succ) {
+      //  println(" - "+t)
+      //}
+
+      //println("BEFORE: In leaves we have: ")
+      //for (i <- leaves.iterator) {
+      //  println("-> "+i.minReachCost+" == "+i.task)
+      //}
+
+      if (!l.removedLeaf) {
+        l.removeLeaves()
+
+        val n = new AndNode(this, succ, l.task)
+
+        val newLeaves = succ.map(t => t -> new OrLeaf(n, t)).toMap
+        n.subProblems = newLeaves
 
-      alternatives += l.task -> n
+        alternatives += l.task -> n
 
-      updateMin()
+        n.updateMin()
+
+        updateMin()
+
+        leaves  ++= newLeaves.values
+      }
+
+
+      //println("AFTER: In leaves we have: ")
+      //for (i <- leaves.iterator) {
+      //  println("-> "+i.minReachCost+" == "+i.task)
+      //}
     }
 
     def notifySolution(sub: AndTree, sol: S) {
+      this.removeLeaves()
+
       solution match {
         case Some(preSol) if (costModel.solutionCost(preSol) < costModel.solutionCost(sol)) =>
           isTrustworthy  = sub.isTrustworthy
diff --git a/src/main/scala/leon/synthesis/search/AndOrGraphParallelSearch.scala b/src/main/scala/leon/synthesis/search/AndOrGraphParallelSearch.scala
index c2f08f128..b2725b809 100644
--- a/src/main/scala/leon/synthesis/search/AndOrGraphParallelSearch.scala
+++ b/src/main/scala/leon/synthesis/search/AndOrGraphParallelSearch.scala
@@ -12,7 +12,8 @@ import akka.pattern.AskTimeoutException
 abstract class AndOrGraphParallelSearch[WC,
                                         AT <: AOAndTask[S],
                                         OT <: AOOrTask[S],
-                                        S](og: AndOrGraph[AT, OT, S], nWorkers: Int) extends AndOrGraphSearch[AT, OT, S](og) {
+                                        S](og: AndOrGraph[AT, OT, S],
+                                           nWorkers: Int) extends AndOrGraphSearch[AT, OT, S](og) {
 
   def initWorkerContext(w: ActorRef): WC
 
@@ -66,6 +67,19 @@ abstract class AndOrGraphParallelSearch[WC,
     case object NoTaskReady
   }
 
+  def getNextLeaves(idleWorkers: Map[ActorRef, Option[g.Leaf]], workingWorkers: Map[ActorRef, Option[g.Leaf]]): List[g.Leaf] = {
+    val processing = workingWorkers.values.flatten.toSet
+
+    val ts = System.currentTimeMillis();
+
+    val str = nextLeaves()
+      .filterNot(processing)
+      .take(idleWorkers.size)
+      .toList
+
+    str
+  }
+
   class Master extends Actor {
     import Protocol._
 
@@ -78,7 +92,7 @@ abstract class AndOrGraphParallelSearch[WC,
 
       assert(idleWorkers.size > 0)
 
-      nextLeaves(idleWorkers.size) match {
+      getNextLeaves(idleWorkers, workingWorkers) match {
         case Nil =>
           if (workingWorkers.isEmpty) {
             outer ! SearchDone
@@ -88,7 +102,6 @@ abstract class AndOrGraphParallelSearch[WC,
 
         case ls =>
           for ((w, leaf) <- idleWorkers.keySet zip ls) {
-            processing += leaf
             leaf match {
               case al: g.AndLeaf =>
                 workers += w -> Some(al)
@@ -101,6 +114,8 @@ abstract class AndOrGraphParallelSearch[WC,
       }
     }
 
+    context.setReceiveTimeout(10.seconds)
+
     def receive = {
       case BeginSearch =>
         outer = sender
@@ -130,10 +145,16 @@ abstract class AndOrGraphParallelSearch[WC,
 
       case Terminated(w) =>
         if (workers contains w) {
-          processing -= workers(w).get
           workers -= w
         }
 
+      case ReceiveTimeout =>
+        println("@ Worker status:")
+        for ((w, t) <- workers if t.isDefined) {
+          println("@  - "+w.toString+": "+t.get.task)
+        }
+
+
     }
   }
 
diff --git a/src/main/scala/leon/synthesis/search/AndOrGraphSearch.scala b/src/main/scala/leon/synthesis/search/AndOrGraphSearch.scala
index 42643d7a7..2693f41af 100644
--- a/src/main/scala/leon/synthesis/search/AndOrGraphSearch.scala
+++ b/src/main/scala/leon/synthesis/search/AndOrGraphSearch.scala
@@ -6,56 +6,13 @@ abstract class AndOrGraphSearch[AT <: AOAndTask[S],
                                 OT <: AOOrTask[S],
                                 S](val g: AndOrGraph[AT, OT, S]) {
 
-  var processing = Set[g.Leaf]()
-
-  def nextLeaves(k: Int): List[g.Leaf] = {
-    import scala.math.Ordering.Implicits._
-
-    case class WL(t: g.Leaf, costs: List[Int])
-
-    var leaves = List[WL]()
-
-    def collectFromAnd(at: g.AndTree, costs: List[Int]) {
-      val newCosts = at.minCost.value :: costs
-      if (!at.isSolved && !at.isUnsolvable) {
-        at match {
-          case l: g.Leaf =>
-            collectLeaf(WL(l, newCosts.reverse)) 
-          case a: g.AndNode =>
-            for (o <- a.subTasks.filterNot(a.subSolutions.keySet).map(a.subProblems)) {
-              collectFromOr(o, newCosts)
-            }
-        }
-      }
-    }
-
-    def collectFromOr(ot: g.OrTree, costs: List[Int]) {
-      val newCosts = ot.minCost.value :: costs
-
-      if (!ot.isSolved && !ot.isUnsolvable) {
-        ot match {
-          case l: g.Leaf =>
-            collectLeaf(WL(l, newCosts.reverse))
-          case o: g.OrNode =>
-            for (a <- o.alternatives.values) {
-              collectFromAnd(a, newCosts)
-            }
-        }
-      }
-    }
-
-    def collectLeaf(wl: WL) {
-      if (!processing(wl.t)) {
-        leaves = wl :: leaves
-      }
-    }
-
-    collectFromOr(g.tree, Nil)
-
-    leaves.sortBy(_.costs).map(_.t)
+  def nextLeaves(): Iterable[g.Leaf] = {
+    g.leaves
   }
 
-  def nextLeaf(): Option[g.Leaf] = nextLeaves(1).headOption
+  def nextLeaf(): Option[g.Leaf] = {
+    nextLeaves().headOption
+  }
 
   abstract class ExpandResult[T <: AOTask[S]]
   case class Expanded[T <: AOTask[S]](sub: List[T]) extends ExpandResult[T]
@@ -84,8 +41,6 @@ abstract class AndOrGraphSearch[AT <: AOAndTask[S],
     if (g.tree.isSolved) {
       stop()
     }
-
-    processing -= al
   }
 
   def onExpansion(ol: g.OrLeaf, res: ExpandResult[AT]) {
@@ -104,8 +59,6 @@ abstract class AndOrGraphSearch[AT <: AOAndTask[S],
     if (g.tree.isSolved) {
       stop()
     }
-
-    processing -= ol
   }
 
   def traversePathFrom(n: g.Tree, path: List[Int]): Option[g.Tree] = {
-- 
GitLab