From a9f406602a9273e2a6b60c4895f884625055d402 Mon Sep 17 00:00:00 2001
From: Etienne Kneuss <ekneuss@gmail.com>
Date: Thu, 3 Jan 2013 17:12:44 +0100
Subject: [PATCH] Implement --parallel[=N] to specify the number of workers to
 use. On success, shutdown immediately by halting solvers.

---
 src/main/scala/leon/LeonOption.scala          |  5 ++++
 src/main/scala/leon/Main.scala                |  6 ++--
 .../scala/leon/synthesis/ParallelSearch.scala | 17 +++++++++--
 .../scala/leon/synthesis/SynthesisPhase.scala | 28 +++++++++++++------
 .../scala/leon/synthesis/Synthesizer.scala    |  4 +--
 .../leon/synthesis/SynthesizerOptions.scala   |  2 +-
 .../scala/leon/synthesis/rules/Cegis.scala    |  4 ++-
 .../search/AndOrGraphParallelSearch.scala     |  3 +-
 8 files changed, 50 insertions(+), 19 deletions(-)

diff --git a/src/main/scala/leon/LeonOption.scala b/src/main/scala/leon/LeonOption.scala
index da28a90d5..9859fcc92 100644
--- a/src/main/scala/leon/LeonOption.scala
+++ b/src/main/scala/leon/LeonOption.scala
@@ -31,6 +31,7 @@ sealed abstract class LeonOptionDef {
   val usageDesc: String
   val isFlag: Boolean
 }
+
 case class LeonFlagOptionDef(name: String, usageOption: String, usageDesc: String) extends LeonOptionDef {
   val isFlag = true
 }
@@ -39,6 +40,10 @@ case class LeonValueOptionDef(name: String, usageOption: String, usageDesc: Stri
   val isFlag = false
 }
 
+case class LeonOptValueOptionDef(name: String, usageOption: String, usageDesc: String) extends LeonOptionDef {
+  val isFlag = false
+}
+
 object ListValue {
   def apply(values: Seq[String]) = values.mkString(":")
   def unapply(value: String): Option[Seq[String]] = Some(value.split(':').map(_.trim).filter(!_.isEmpty))
diff --git a/src/main/scala/leon/Main.scala b/src/main/scala/leon/Main.scala
index 56d66e601..9c81f0ce7 100644
--- a/src/main/scala/leon/Main.scala
+++ b/src/main/scala/leon/Main.scala
@@ -80,10 +80,10 @@ object Main {
       }
 
       if (allOptionsMap contains leonOpt.name) {
-        (allOptionsMap(leonOpt.name).isFlag, leonOpt) match {
-          case (true,  LeonFlagOption(name)) =>
+        (allOptionsMap(leonOpt.name), leonOpt) match {
+          case (_: LeonFlagOptionDef  | _: LeonOptValueOptionDef,  LeonFlagOption(name)) =>
             Some(leonOpt)
-          case (false, LeonValueOption(name, value)) =>
+          case (_: LeonValueOptionDef | _: LeonOptValueOptionDef, LeonValueOption(name, value)) =>
             Some(leonOpt)
           case _ =>
             reporter.error("Invalid option usage: " + opt)
diff --git a/src/main/scala/leon/synthesis/ParallelSearch.scala b/src/main/scala/leon/synthesis/ParallelSearch.scala
index a6d436509..f85917212 100644
--- a/src/main/scala/leon/synthesis/ParallelSearch.scala
+++ b/src/main/scala/leon/synthesis/ParallelSearch.scala
@@ -9,10 +9,14 @@ import solvers.TrivialSolver
 class ParallelSearch(synth: Synthesizer,
                      problem: Problem,
                      rules: Set[Rule],
-                     costModel: CostModel) extends AndOrGraphParallelSearch[SynthesisContext, TaskRunRule, TaskTryRules, Solution](new AndOrGraph(TaskTryRules(problem), SearchCostModel(costModel))) {
+                     costModel: CostModel,
+                     nWorkers: Int) extends AndOrGraphParallelSearch[SynthesisContext, TaskRunRule, TaskTryRules, Solution](new AndOrGraph(TaskTryRules(problem), SearchCostModel(costModel)), nWorkers) {
 
   import synth.reporter._
 
+  // This is HOT shared memory, used only in stop() for shutting down solvers!
+  private[this] var contexts = List[SynthesisContext]()
+
   def initWorkerContext(wr: ActorRef) = {
     val reporter = new SilentReporter
     val solver = new FairZ3Solver(synth.context.copy(reporter = reporter))
@@ -20,11 +24,20 @@ class ParallelSearch(synth: Synthesizer,
 
     solver.initZ3
 
-    SynthesisContext(solver = solver, reporter = synth.reporter, shouldStop = synth.shouldStop)
+    val ctx = SynthesisContext(solver = solver, reporter = synth.reporter, shouldStop = synth.shouldStop)
+
+    synchronized {
+      contexts = ctx :: contexts
+    }
+
+    ctx
   }
 
   override def stop() = {
     synth.shouldStop.set(true)
+    for (ctx <- contexts) {
+      ctx.solver.halt()
+    }
     super.stop()
   }
 
diff --git a/src/main/scala/leon/synthesis/SynthesisPhase.scala b/src/main/scala/leon/synthesis/SynthesisPhase.scala
index 3e03d9a14..1cb409d6e 100644
--- a/src/main/scala/leon/synthesis/SynthesisPhase.scala
+++ b/src/main/scala/leon/synthesis/SynthesisPhase.scala
@@ -14,13 +14,13 @@ object SynthesisPhase extends LeonPhase[Program, Program] {
   val description = "Synthesis"
 
   override val definedOptions : Set[LeonOptionDef] = Set(
-    LeonFlagOptionDef( "inplace",    "--inplace",         "Debug level"),
-    LeonFlagOptionDef( "parallel",   "--parallel",        "Parallel synthesis search"),
-    LeonFlagOptionDef( "derivtrees", "--derivtrees",      "Generate derivation trees"),
-    LeonFlagOptionDef( "firstonly",  "--firstonly",       "Stop as soon as one synthesis solution is found"),
-    LeonValueOptionDef("timeout",    "--timeout=T",       "Timeout after T seconds when searching for synthesis solutions .."),
-    LeonValueOptionDef("costmodel",  "--costmodel=cm",    "Use a specific cost model for this search"),
-    LeonValueOptionDef("functions",  "--functions=f1:f2", "Limit synthesis of choose found within f1,f2,..")
+    LeonFlagOptionDef(    "inplace",    "--inplace",         "Debug level"),
+    LeonOptValueOptionDef("parallel",   "--parallel[=N]",    "Parallel synthesis search using N workers"),
+    LeonFlagOptionDef(    "derivtrees", "--derivtrees",      "Generate derivation trees"),
+    LeonFlagOptionDef(    "firstonly",  "--firstonly",       "Stop as soon as one synthesis solution is found"),
+    LeonValueOptionDef(   "timeout",    "--timeout=T",       "Timeout after T seconds when searching for synthesis solutions .."),
+    LeonValueOptionDef(   "costmodel",  "--costmodel=cm",    "Use a specific cost model for this search"),
+    LeonValueOptionDef(   "functions",  "--functions=f1:f2", "Limit synthesis of choose found within f1,f2,..")
   )
 
   def run(ctx: LeonContext)(p: Program): Program = {
@@ -38,8 +38,10 @@ object SynthesisPhase extends LeonPhase[Program, Program] {
     for(opt <- ctx.options) opt match {
       case LeonFlagOption("inplace") =>
         inPlace = true
+
       case LeonValueOption("functions", ListValue(fs)) =>
         options = options.copy(filterFuns = Some(fs.toSet))
+
       case LeonValueOption("costmodel", cm) =>
         CostModel.all.find(_.name.toLowerCase == cm.toLowerCase) match {
           case Some(model) =>
@@ -52,16 +54,26 @@ object SynthesisPhase extends LeonPhase[Program, Program] {
 
             ctx.reporter.fatalError(errorMsg)
         }
+
       case v @ LeonValueOption("timeout", _) =>
         v.asInt(ctx).foreach { t =>
           options = options.copy(timeoutMs  = Some(t.toLong))
         } 
+
       case LeonFlagOption("firstonly") =>
         options = options.copy(firstOnly = true)
+
       case LeonFlagOption("parallel") =>
-        options = options.copy(parallel = true)
+        options = options.copy(searchWorkers = 5)
+
+      case o @ LeonValueOption("parallel", nWorkers) =>
+        o.asInt(ctx).foreach { nWorkers =>
+          options = options.copy(searchWorkers = nWorkers)
+        }
+
       case LeonFlagOption("derivtrees") =>
         options = options.copy(generateDerivationTrees = true)
+
       case _ =>
     }
 
diff --git a/src/main/scala/leon/synthesis/Synthesizer.scala b/src/main/scala/leon/synthesis/Synthesizer.scala
index aa054774a..b15e1ec69 100644
--- a/src/main/scala/leon/synthesis/Synthesizer.scala
+++ b/src/main/scala/leon/synthesis/Synthesizer.scala
@@ -32,8 +32,8 @@ class Synthesizer(val context : LeonContext,
 
   def synthesize(): (Solution, Boolean) = {
 
-    val search = if (options.parallel) {
-      new ParallelSearch(this, problem, rules, options.costModel)
+  val search = if (options.searchWorkers > 1) {
+      new ParallelSearch(this, problem, rules, options.costModel, options.searchWorkers)
     } else {
       new SimpleSearch(this, problem, rules, options.costModel)
     }
diff --git a/src/main/scala/leon/synthesis/SynthesizerOptions.scala b/src/main/scala/leon/synthesis/SynthesizerOptions.scala
index 768c31ce9..a460f03dc 100644
--- a/src/main/scala/leon/synthesis/SynthesizerOptions.scala
+++ b/src/main/scala/leon/synthesis/SynthesizerOptions.scala
@@ -4,7 +4,7 @@ package synthesis
 case class SynthesizerOptions(
   generateDerivationTrees: Boolean = false,
   filterFuns: Option[Set[String]]  = None,
-  parallel: Boolean                = false,
+  searchWorkers: Int               = 1,
   firstOnly: Boolean               = false,
   timeoutMs: Option[Long]          = None,
   costModel: CostModel             = NaiveCostModel
diff --git a/src/main/scala/leon/synthesis/rules/Cegis.scala b/src/main/scala/leon/synthesis/rules/Cegis.scala
index 93898bcc8..8eafd160f 100644
--- a/src/main/scala/leon/synthesis/rules/Cegis.scala
+++ b/src/main/scala/leon/synthesis/rules/Cegis.scala
@@ -255,7 +255,9 @@ case object CEGIS extends Rule("CEGIS") {
                         result = Some(RuleSuccess(Solution(BooleanLiteral(true), Set(), Tuple(p.xs.map(valuateWithModel(mapping))).setType(tpe))))
 
                       case _ =>
-                        sctx.reporter.warning("Solver returned 'UNKNOWN' in a CEGIS iteration.")
+                        if (!sctx.shouldStop.get) {
+                          sctx.reporter.warning("Solver returned 'UNKNOWN' in a CEGIS iteration.")
+                        }
                         continue = false
                     }
 
diff --git a/src/main/scala/leon/synthesis/search/AndOrGraphParallelSearch.scala b/src/main/scala/leon/synthesis/search/AndOrGraphParallelSearch.scala
index 0f126660e..88f378a8e 100644
--- a/src/main/scala/leon/synthesis/search/AndOrGraphParallelSearch.scala
+++ b/src/main/scala/leon/synthesis/search/AndOrGraphParallelSearch.scala
@@ -10,11 +10,10 @@ import akka.dispatch.Await
 abstract class AndOrGraphParallelSearch[WC,
                                         AT <: AOAndTask[S],
                                         OT <: AOOrTask[S],
-                                        S](og: AndOrGraph[AT, OT, S]) extends AndOrGraphSearch[AT, OT, S](og) {
+                                        S](og: AndOrGraph[AT, OT, S], nWorkers: Int) extends AndOrGraphSearch[AT, OT, S](og) {
 
   def initWorkerContext(w: ActorRef): WC
 
-  val nWorkers = 7
   val timeout = 600.seconds
 
   var system: ActorSystem = _
-- 
GitLab