From b9548a321dc57e48837db565cea11efe169a794b Mon Sep 17 00:00:00 2001 From: Etienne Kneuss <colder@php.net> Date: Wed, 17 Jul 2013 16:23:43 +0200 Subject: [PATCH] Improve Leon's parallel search - Search tree can be iterated over in order - Worker pool get displayed periodically when stuck - Make sure global caches are concurrent --- src/main/scala/leon/purescala/TreeOps.scala | 40 ++-- .../scala/leon/synthesis/ParallelSearch.scala | 9 +- .../leon/synthesis/search/AndOrGraph.scala | 173 ++++++++++++++++-- .../search/AndOrGraphParallelSearch.scala | 29 ++- .../synthesis/search/AndOrGraphSearch.scala | 57 +----- 5 files changed, 217 insertions(+), 91 deletions(-) diff --git a/src/main/scala/leon/purescala/TreeOps.scala b/src/main/scala/leon/purescala/TreeOps.scala index 3e81b76cb..405055e44 100644 --- a/src/main/scala/leon/purescala/TreeOps.scala +++ b/src/main/scala/leon/purescala/TreeOps.scala @@ -5,6 +5,8 @@ package purescala import leon.solvers.Solver +import scala.collection.concurrent.TrieMap + object TreeOps { import Common._ import TypeTrees._ @@ -667,20 +669,20 @@ object TreeOps { rec(expr, Map.empty) } - private var matchConverterCache = new scala.collection.mutable.HashMap[Expr,Expr]() /** Rewrites all pattern-matching expressions into if-then-else expressions, * with additional error conditions. Does not introduce additional variables. - * We use a cache because we can. */ + */ + val cacheMtITE = new TrieMap[Expr, Expr]() + def matchToIfThenElse(expr: Expr) : Expr = { - val toRet = if(matchConverterCache.isDefinedAt(expr)) { - matchConverterCache(expr) - } else { - val converted = convertMatchToIfThenElse(expr) - matchConverterCache(expr) = converted - converted + cacheMtITE.get(expr) match { + case Some(res) => + res + case None => + val r = convertMatchToIfThenElse(expr) + cacheMtITE += expr -> r + r } - - toRet } def conditionForPattern(in: Expr, pattern: Pattern, includeBinders: Boolean = false) : Expr = { @@ -784,18 +786,18 @@ object TreeOps { searchAndReplaceDFS(rewritePM)(expr) } - private var mapGetConverterCache = new scala.collection.mutable.HashMap[Expr,Expr]() /** Rewrites all map accesses with additional error conditions. */ + val cacheMGWC = new TrieMap[Expr, Expr]() + def mapGetWithChecks(expr: Expr) : Expr = { - val toRet = if (mapGetConverterCache.isDefinedAt(expr)) { - mapGetConverterCache(expr) - } else { - val converted = convertMapGet(expr) - mapGetConverterCache(expr) = converted - converted + cacheMGWC.get(expr) match { + case Some(res) => + res + case None => + val r = convertMapGet(expr) + cacheMGWC += expr -> r + r } - - toRet } private def convertMapGet(expr: Expr) : Expr = { diff --git a/src/main/scala/leon/synthesis/ParallelSearch.scala b/src/main/scala/leon/synthesis/ParallelSearch.scala index 57672fb4f..d9982a12d 100644 --- a/src/main/scala/leon/synthesis/ParallelSearch.scala +++ b/src/main/scala/leon/synthesis/ParallelSearch.scala @@ -5,7 +5,7 @@ package synthesis import synthesis.search._ import akka.actor._ -import solvers.z3.FairZ3Solver +import solvers.z3.{FairZ3Solver,UninterpretedZ3Solver} import solvers.TrivialSolver class ParallelSearch(synth: Synthesizer, @@ -27,10 +27,13 @@ class ParallelSearch(synth: Synthesizer, val reporter = new SilentReporter val solver = new FairZ3Solver(synth.context.copy(reporter = reporter)) solver.setProgram(synth.program) - solver.initZ3 - val ctx = SynthesisContext.fromSynthesizer(synth).copy(solver = solver) + val simpleSolver = new UninterpretedZ3Solver(synth.context.copy(reporter = reporter)) + simpleSolver.setProgram(synth.program) + simpleSolver.initZ3 + + val ctx = SynthesisContext.fromSynthesizer(synth).copy(solver = solver, simpleSolver = simpleSolver) synchronized { contexts = ctx :: contexts diff --git a/src/main/scala/leon/synthesis/search/AndOrGraph.scala b/src/main/scala/leon/synthesis/search/AndOrGraph.scala index 6fd4f3ca8..a6e3e19ac 100644 --- a/src/main/scala/leon/synthesis/search/AndOrGraph.scala +++ b/src/main/scala/leon/synthesis/search/AndOrGraph.scala @@ -2,8 +2,7 @@ package leon.synthesis.search -trait AOTask[S] { -} +trait AOTask[S] { } trait AOAndTask[S] extends AOTask[S] { def composeSolution(sols: List[S]): Option[S] @@ -20,12 +19,35 @@ trait AOCostModel[AT <: AOAndTask[S], OT <: AOOrTask[S], S] { class AndOrGraph[AT <: AOAndTask[S], OT <: AOOrTask[S], S](val root: OT, val costModel: AOCostModel[AT, OT, S]) { var tree: OrTree = RootNode + object LeafOrdering extends Ordering[Leaf] { + def compare(a: Leaf, b: Leaf) = { + val diff = scala.math.Ordering.Iterable[Int].compare(a.minReachCost, b.minReachCost) + if (diff == 0) { + if (a == b) { + 0 + } else { + a.## - b.## + } + } else { + diff + } + } + } + + val leaves = collection.mutable.TreeSet()(LeafOrdering) + leaves += RootNode + trait Tree { val task : AOTask[S] val parent: Node[_] def minCost: Cost + var minReachCost = List[Int]() + + def updateMinReach(reverseParent: List[Int]); + def removeLeaves(); + var isTrustworthy: Boolean = true var solution: Option[S] = None var isUnsolvable: Boolean = false @@ -43,7 +65,22 @@ class AndOrGraph[AT <: AOAndTask[S], OT <: AOOrTask[S], S](val root: OT, val cos trait Leaf extends Tree { - def minCost = costModel.taskCost(task) + val minCost = costModel.taskCost(task) + + var removedLeaf = false; + + def updateMinReach(reverseParent: List[Int]) { + if (!removedLeaf) { + leaves -= this + minReachCost = (minCost.value :: reverseParent).reverse + leaves += this + } + } + + def removeLeaves() { + removedLeaf = true + leaves -= this + } } trait Node[T <: Tree] extends Tree { @@ -69,23 +106,75 @@ class AndOrGraph[AT <: AOAndTask[S], OT <: AOOrTask[S], S](val root: OT, val cos subCosts.foldLeft(costModel.taskCost(task))(_ + _) } if (minCost != old) { - Option(parent).foreach(_.updateMin()) + if (parent ne null) { + parent.updateMin() + } else { + // Reached the root, propagate minReach up + updateMinReach(Nil) + } + } else { + // Reached boundary of update, propagate minReach up + updateMinReach(minReachCost.reverse.tail) } } + def updateMinReach(reverseParent: List[Int]) { + val rev = minCost.value :: reverseParent + + minReachCost = rev.reverse + + subProblems.values.foreach(_.updateMinReach(rev)) + } + + def removeLeaves() { + subProblems.values.foreach(_.removeLeaves()) + } def unsolvable(l: OrTree) { isUnsolvable = true + + this.removeLeaves() + parent.unsolvable(this) } def expandLeaf(l: OrLeaf, succ: List[AT]) { - subProblems += l.task -> new OrNode(this, succ, l.task) + //println("[[2]] Expanding "+l.task+" to: ") + //for (t <- succ) { + // println(" - "+t) + //} + + //println("BEFORE: In leaves we have: ") + //for (i <- leaves.iterator) { + // println("-> "+i.minReachCost+" == "+i.task) + //} + + if (!l.removedLeaf) { + l.removeLeaves() + + val orNode = new OrNode(this, succ, l.task) + subProblems += l.task -> orNode + + updateMin() + + leaves ++= orNode.andLeaves.values + } + + //println("AFTER: In leaves we have: ") + //for (i <- leaves.iterator) { + // println("-> "+i.minReachCost+" == "+i.task) + //} } def notifySolution(sub: OrTree, sol: S) { subSolutions += sub.task -> sol + sub match { + case l: Leaf => + l.removeLeaves() + case _ => + } + if (subSolutions.size == subProblems.size) { task.composeSolution(subTasks.map(subSolutions)) match { case Some(sol) => @@ -113,8 +202,15 @@ class AndOrGraph[AT <: AOAndTask[S], OT <: AOOrTask[S], S](val root: OT, val cos object RootNode extends OrLeaf(null, root) { + minReachCost = List(minCost.value) + override def expandWith(succ: List[AT]) { - tree = new OrNode(null, succ, root) + this.removeLeaves() + + val orNode = new OrNode(null, succ, root) + tree = orNode + + leaves ++= orNode.andLeaves.values } } @@ -127,7 +223,8 @@ class AndOrGraph[AT <: AOAndTask[S], OT <: AOOrTask[S], S](val root: OT, val cos class OrNode(val parent: AndNode, val altTasks: List[AT], val task: OT) extends OrTree with Node[AndTree] { - var alternatives: Map[AT, AndTree] = altTasks.map(t => t -> new AndLeaf(this, t)).toMap + val andLeaves = altTasks.map(t => t -> new AndLeaf(this, t)).toMap + var alternatives: Map[AT, AndTree] = andLeaves var triedAlternatives = Map[AT, AndTree]() var minAlternative: AndTree = _ var minCost = costModel.taskCost(task) @@ -139,8 +236,18 @@ class AndOrGraph[AT <: AOAndTask[S], OT <: AOOrTask[S], S](val root: OT, val cos minAlternative = alternatives.values.minBy(_.minCost) val old = minCost minCost = minAlternative.minCost + + //println("Updated minCost of "+task+" from "+old.value+" to "+minCost.value) + if (minCost != old) { - Option(parent).foreach(_.updateMin()) + if (parent ne null) { + parent.updateMin() + } else { + // reached root, propagate minReach up + updateMinReach(Nil) + } + } else { + updateMinReach(minReachCost.reverse.tail) } } else { minAlternative = null @@ -148,11 +255,24 @@ class AndOrGraph[AT <: AOAndTask[S], OT <: AOOrTask[S], S](val root: OT, val cos } } + def updateMinReach(reverseParent: List[Int]) { + val rev = minCost.value :: reverseParent + + minReachCost = rev.reverse + + alternatives.values.foreach(_.updateMinReach(rev)) + } + + def removeLeaves() { + alternatives.values.foreach(_.removeLeaves()) + } + def unsolvable(l: AndTree) { if (alternatives contains l.task) { triedAlternatives += l.task -> alternatives(l.task) alternatives -= l.task + l.removeLeaves() if (alternatives.isEmpty) { isUnsolvable = true @@ -166,16 +286,43 @@ class AndOrGraph[AT <: AOAndTask[S], OT <: AOOrTask[S], S](val root: OT, val cos } def expandLeaf(l: AndLeaf, succ: List[OT]) { - val n = new AndNode(this, succ, l.task) - n.subProblems = succ.map(t => t -> new OrLeaf(n, t)).toMap - n.updateMin() + //println("[[1]] Expanding "+l.task+" to: ") + //for (t <- succ) { + // println(" - "+t) + //} + + //println("BEFORE: In leaves we have: ") + //for (i <- leaves.iterator) { + // println("-> "+i.minReachCost+" == "+i.task) + //} + + if (!l.removedLeaf) { + l.removeLeaves() + + val n = new AndNode(this, succ, l.task) + + val newLeaves = succ.map(t => t -> new OrLeaf(n, t)).toMap + n.subProblems = newLeaves - alternatives += l.task -> n + alternatives += l.task -> n - updateMin() + n.updateMin() + + updateMin() + + leaves ++= newLeaves.values + } + + + //println("AFTER: In leaves we have: ") + //for (i <- leaves.iterator) { + // println("-> "+i.minReachCost+" == "+i.task) + //} } def notifySolution(sub: AndTree, sol: S) { + this.removeLeaves() + solution match { case Some(preSol) if (costModel.solutionCost(preSol) < costModel.solutionCost(sol)) => isTrustworthy = sub.isTrustworthy diff --git a/src/main/scala/leon/synthesis/search/AndOrGraphParallelSearch.scala b/src/main/scala/leon/synthesis/search/AndOrGraphParallelSearch.scala index c2f08f128..b2725b809 100644 --- a/src/main/scala/leon/synthesis/search/AndOrGraphParallelSearch.scala +++ b/src/main/scala/leon/synthesis/search/AndOrGraphParallelSearch.scala @@ -12,7 +12,8 @@ import akka.pattern.AskTimeoutException abstract class AndOrGraphParallelSearch[WC, AT <: AOAndTask[S], OT <: AOOrTask[S], - S](og: AndOrGraph[AT, OT, S], nWorkers: Int) extends AndOrGraphSearch[AT, OT, S](og) { + S](og: AndOrGraph[AT, OT, S], + nWorkers: Int) extends AndOrGraphSearch[AT, OT, S](og) { def initWorkerContext(w: ActorRef): WC @@ -66,6 +67,19 @@ abstract class AndOrGraphParallelSearch[WC, case object NoTaskReady } + def getNextLeaves(idleWorkers: Map[ActorRef, Option[g.Leaf]], workingWorkers: Map[ActorRef, Option[g.Leaf]]): List[g.Leaf] = { + val processing = workingWorkers.values.flatten.toSet + + val ts = System.currentTimeMillis(); + + val str = nextLeaves() + .filterNot(processing) + .take(idleWorkers.size) + .toList + + str + } + class Master extends Actor { import Protocol._ @@ -78,7 +92,7 @@ abstract class AndOrGraphParallelSearch[WC, assert(idleWorkers.size > 0) - nextLeaves(idleWorkers.size) match { + getNextLeaves(idleWorkers, workingWorkers) match { case Nil => if (workingWorkers.isEmpty) { outer ! SearchDone @@ -88,7 +102,6 @@ abstract class AndOrGraphParallelSearch[WC, case ls => for ((w, leaf) <- idleWorkers.keySet zip ls) { - processing += leaf leaf match { case al: g.AndLeaf => workers += w -> Some(al) @@ -101,6 +114,8 @@ abstract class AndOrGraphParallelSearch[WC, } } + context.setReceiveTimeout(10.seconds) + def receive = { case BeginSearch => outer = sender @@ -130,10 +145,16 @@ abstract class AndOrGraphParallelSearch[WC, case Terminated(w) => if (workers contains w) { - processing -= workers(w).get workers -= w } + case ReceiveTimeout => + println("@ Worker status:") + for ((w, t) <- workers if t.isDefined) { + println("@ - "+w.toString+": "+t.get.task) + } + + } } diff --git a/src/main/scala/leon/synthesis/search/AndOrGraphSearch.scala b/src/main/scala/leon/synthesis/search/AndOrGraphSearch.scala index 42643d7a7..2693f41af 100644 --- a/src/main/scala/leon/synthesis/search/AndOrGraphSearch.scala +++ b/src/main/scala/leon/synthesis/search/AndOrGraphSearch.scala @@ -6,56 +6,13 @@ abstract class AndOrGraphSearch[AT <: AOAndTask[S], OT <: AOOrTask[S], S](val g: AndOrGraph[AT, OT, S]) { - var processing = Set[g.Leaf]() - - def nextLeaves(k: Int): List[g.Leaf] = { - import scala.math.Ordering.Implicits._ - - case class WL(t: g.Leaf, costs: List[Int]) - - var leaves = List[WL]() - - def collectFromAnd(at: g.AndTree, costs: List[Int]) { - val newCosts = at.minCost.value :: costs - if (!at.isSolved && !at.isUnsolvable) { - at match { - case l: g.Leaf => - collectLeaf(WL(l, newCosts.reverse)) - case a: g.AndNode => - for (o <- a.subTasks.filterNot(a.subSolutions.keySet).map(a.subProblems)) { - collectFromOr(o, newCosts) - } - } - } - } - - def collectFromOr(ot: g.OrTree, costs: List[Int]) { - val newCosts = ot.minCost.value :: costs - - if (!ot.isSolved && !ot.isUnsolvable) { - ot match { - case l: g.Leaf => - collectLeaf(WL(l, newCosts.reverse)) - case o: g.OrNode => - for (a <- o.alternatives.values) { - collectFromAnd(a, newCosts) - } - } - } - } - - def collectLeaf(wl: WL) { - if (!processing(wl.t)) { - leaves = wl :: leaves - } - } - - collectFromOr(g.tree, Nil) - - leaves.sortBy(_.costs).map(_.t) + def nextLeaves(): Iterable[g.Leaf] = { + g.leaves } - def nextLeaf(): Option[g.Leaf] = nextLeaves(1).headOption + def nextLeaf(): Option[g.Leaf] = { + nextLeaves().headOption + } abstract class ExpandResult[T <: AOTask[S]] case class Expanded[T <: AOTask[S]](sub: List[T]) extends ExpandResult[T] @@ -84,8 +41,6 @@ abstract class AndOrGraphSearch[AT <: AOAndTask[S], if (g.tree.isSolved) { stop() } - - processing -= al } def onExpansion(ol: g.OrLeaf, res: ExpandResult[AT]) { @@ -104,8 +59,6 @@ abstract class AndOrGraphSearch[AT <: AOAndTask[S], if (g.tree.isSolved) { stop() } - - processing -= ol } def traversePathFrom(n: g.Tree, path: List[Int]): Option[g.Tree] = { -- GitLab