From a9f406602a9273e2a6b60c4895f884625055d402 Mon Sep 17 00:00:00 2001 From: Etienne Kneuss <ekneuss@gmail.com> Date: Thu, 3 Jan 2013 17:12:44 +0100 Subject: [PATCH] Implement --parallel[=N] to specify the number of workers to use. On success, shutdown immediately by halting solvers. --- src/main/scala/leon/LeonOption.scala | 5 ++++ src/main/scala/leon/Main.scala | 6 ++-- .../scala/leon/synthesis/ParallelSearch.scala | 17 +++++++++-- .../scala/leon/synthesis/SynthesisPhase.scala | 28 +++++++++++++------ .../scala/leon/synthesis/Synthesizer.scala | 4 +-- .../leon/synthesis/SynthesizerOptions.scala | 2 +- .../scala/leon/synthesis/rules/Cegis.scala | 4 ++- .../search/AndOrGraphParallelSearch.scala | 3 +- 8 files changed, 50 insertions(+), 19 deletions(-) diff --git a/src/main/scala/leon/LeonOption.scala b/src/main/scala/leon/LeonOption.scala index da28a90d5..9859fcc92 100644 --- a/src/main/scala/leon/LeonOption.scala +++ b/src/main/scala/leon/LeonOption.scala @@ -31,6 +31,7 @@ sealed abstract class LeonOptionDef { val usageDesc: String val isFlag: Boolean } + case class LeonFlagOptionDef(name: String, usageOption: String, usageDesc: String) extends LeonOptionDef { val isFlag = true } @@ -39,6 +40,10 @@ case class LeonValueOptionDef(name: String, usageOption: String, usageDesc: Stri val isFlag = false } +case class LeonOptValueOptionDef(name: String, usageOption: String, usageDesc: String) extends LeonOptionDef { + val isFlag = false +} + object ListValue { def apply(values: Seq[String]) = values.mkString(":") def unapply(value: String): Option[Seq[String]] = Some(value.split(':').map(_.trim).filter(!_.isEmpty)) diff --git a/src/main/scala/leon/Main.scala b/src/main/scala/leon/Main.scala index 56d66e601..9c81f0ce7 100644 --- a/src/main/scala/leon/Main.scala +++ b/src/main/scala/leon/Main.scala @@ -80,10 +80,10 @@ object Main { } if (allOptionsMap contains leonOpt.name) { - (allOptionsMap(leonOpt.name).isFlag, leonOpt) match { - case (true, LeonFlagOption(name)) => + (allOptionsMap(leonOpt.name), leonOpt) match { + case (_: LeonFlagOptionDef | _: LeonOptValueOptionDef, LeonFlagOption(name)) => Some(leonOpt) - case (false, LeonValueOption(name, value)) => + case (_: LeonValueOptionDef | _: LeonOptValueOptionDef, LeonValueOption(name, value)) => Some(leonOpt) case _ => reporter.error("Invalid option usage: " + opt) diff --git a/src/main/scala/leon/synthesis/ParallelSearch.scala b/src/main/scala/leon/synthesis/ParallelSearch.scala index a6d436509..f85917212 100644 --- a/src/main/scala/leon/synthesis/ParallelSearch.scala +++ b/src/main/scala/leon/synthesis/ParallelSearch.scala @@ -9,10 +9,14 @@ import solvers.TrivialSolver class ParallelSearch(synth: Synthesizer, problem: Problem, rules: Set[Rule], - costModel: CostModel) extends AndOrGraphParallelSearch[SynthesisContext, TaskRunRule, TaskTryRules, Solution](new AndOrGraph(TaskTryRules(problem), SearchCostModel(costModel))) { + costModel: CostModel, + nWorkers: Int) extends AndOrGraphParallelSearch[SynthesisContext, TaskRunRule, TaskTryRules, Solution](new AndOrGraph(TaskTryRules(problem), SearchCostModel(costModel)), nWorkers) { import synth.reporter._ + // This is HOT shared memory, used only in stop() for shutting down solvers! + private[this] var contexts = List[SynthesisContext]() + def initWorkerContext(wr: ActorRef) = { val reporter = new SilentReporter val solver = new FairZ3Solver(synth.context.copy(reporter = reporter)) @@ -20,11 +24,20 @@ class ParallelSearch(synth: Synthesizer, solver.initZ3 - SynthesisContext(solver = solver, reporter = synth.reporter, shouldStop = synth.shouldStop) + val ctx = SynthesisContext(solver = solver, reporter = synth.reporter, shouldStop = synth.shouldStop) + + synchronized { + contexts = ctx :: contexts + } + + ctx } override def stop() = { synth.shouldStop.set(true) + for (ctx <- contexts) { + ctx.solver.halt() + } super.stop() } diff --git a/src/main/scala/leon/synthesis/SynthesisPhase.scala b/src/main/scala/leon/synthesis/SynthesisPhase.scala index 3e03d9a14..1cb409d6e 100644 --- a/src/main/scala/leon/synthesis/SynthesisPhase.scala +++ b/src/main/scala/leon/synthesis/SynthesisPhase.scala @@ -14,13 +14,13 @@ object SynthesisPhase extends LeonPhase[Program, Program] { val description = "Synthesis" override val definedOptions : Set[LeonOptionDef] = Set( - LeonFlagOptionDef( "inplace", "--inplace", "Debug level"), - LeonFlagOptionDef( "parallel", "--parallel", "Parallel synthesis search"), - LeonFlagOptionDef( "derivtrees", "--derivtrees", "Generate derivation trees"), - LeonFlagOptionDef( "firstonly", "--firstonly", "Stop as soon as one synthesis solution is found"), - LeonValueOptionDef("timeout", "--timeout=T", "Timeout after T seconds when searching for synthesis solutions .."), - LeonValueOptionDef("costmodel", "--costmodel=cm", "Use a specific cost model for this search"), - LeonValueOptionDef("functions", "--functions=f1:f2", "Limit synthesis of choose found within f1,f2,..") + LeonFlagOptionDef( "inplace", "--inplace", "Debug level"), + LeonOptValueOptionDef("parallel", "--parallel[=N]", "Parallel synthesis search using N workers"), + LeonFlagOptionDef( "derivtrees", "--derivtrees", "Generate derivation trees"), + LeonFlagOptionDef( "firstonly", "--firstonly", "Stop as soon as one synthesis solution is found"), + LeonValueOptionDef( "timeout", "--timeout=T", "Timeout after T seconds when searching for synthesis solutions .."), + LeonValueOptionDef( "costmodel", "--costmodel=cm", "Use a specific cost model for this search"), + LeonValueOptionDef( "functions", "--functions=f1:f2", "Limit synthesis of choose found within f1,f2,..") ) def run(ctx: LeonContext)(p: Program): Program = { @@ -38,8 +38,10 @@ object SynthesisPhase extends LeonPhase[Program, Program] { for(opt <- ctx.options) opt match { case LeonFlagOption("inplace") => inPlace = true + case LeonValueOption("functions", ListValue(fs)) => options = options.copy(filterFuns = Some(fs.toSet)) + case LeonValueOption("costmodel", cm) => CostModel.all.find(_.name.toLowerCase == cm.toLowerCase) match { case Some(model) => @@ -52,16 +54,26 @@ object SynthesisPhase extends LeonPhase[Program, Program] { ctx.reporter.fatalError(errorMsg) } + case v @ LeonValueOption("timeout", _) => v.asInt(ctx).foreach { t => options = options.copy(timeoutMs = Some(t.toLong)) } + case LeonFlagOption("firstonly") => options = options.copy(firstOnly = true) + case LeonFlagOption("parallel") => - options = options.copy(parallel = true) + options = options.copy(searchWorkers = 5) + + case o @ LeonValueOption("parallel", nWorkers) => + o.asInt(ctx).foreach { nWorkers => + options = options.copy(searchWorkers = nWorkers) + } + case LeonFlagOption("derivtrees") => options = options.copy(generateDerivationTrees = true) + case _ => } diff --git a/src/main/scala/leon/synthesis/Synthesizer.scala b/src/main/scala/leon/synthesis/Synthesizer.scala index aa054774a..b15e1ec69 100644 --- a/src/main/scala/leon/synthesis/Synthesizer.scala +++ b/src/main/scala/leon/synthesis/Synthesizer.scala @@ -32,8 +32,8 @@ class Synthesizer(val context : LeonContext, def synthesize(): (Solution, Boolean) = { - val search = if (options.parallel) { - new ParallelSearch(this, problem, rules, options.costModel) + val search = if (options.searchWorkers > 1) { + new ParallelSearch(this, problem, rules, options.costModel, options.searchWorkers) } else { new SimpleSearch(this, problem, rules, options.costModel) } diff --git a/src/main/scala/leon/synthesis/SynthesizerOptions.scala b/src/main/scala/leon/synthesis/SynthesizerOptions.scala index 768c31ce9..a460f03dc 100644 --- a/src/main/scala/leon/synthesis/SynthesizerOptions.scala +++ b/src/main/scala/leon/synthesis/SynthesizerOptions.scala @@ -4,7 +4,7 @@ package synthesis case class SynthesizerOptions( generateDerivationTrees: Boolean = false, filterFuns: Option[Set[String]] = None, - parallel: Boolean = false, + searchWorkers: Int = 1, firstOnly: Boolean = false, timeoutMs: Option[Long] = None, costModel: CostModel = NaiveCostModel diff --git a/src/main/scala/leon/synthesis/rules/Cegis.scala b/src/main/scala/leon/synthesis/rules/Cegis.scala index 93898bcc8..8eafd160f 100644 --- a/src/main/scala/leon/synthesis/rules/Cegis.scala +++ b/src/main/scala/leon/synthesis/rules/Cegis.scala @@ -255,7 +255,9 @@ case object CEGIS extends Rule("CEGIS") { result = Some(RuleSuccess(Solution(BooleanLiteral(true), Set(), Tuple(p.xs.map(valuateWithModel(mapping))).setType(tpe)))) case _ => - sctx.reporter.warning("Solver returned 'UNKNOWN' in a CEGIS iteration.") + if (!sctx.shouldStop.get) { + sctx.reporter.warning("Solver returned 'UNKNOWN' in a CEGIS iteration.") + } continue = false } diff --git a/src/main/scala/leon/synthesis/search/AndOrGraphParallelSearch.scala b/src/main/scala/leon/synthesis/search/AndOrGraphParallelSearch.scala index 0f126660e..88f378a8e 100644 --- a/src/main/scala/leon/synthesis/search/AndOrGraphParallelSearch.scala +++ b/src/main/scala/leon/synthesis/search/AndOrGraphParallelSearch.scala @@ -10,11 +10,10 @@ import akka.dispatch.Await abstract class AndOrGraphParallelSearch[WC, AT <: AOAndTask[S], OT <: AOOrTask[S], - S](og: AndOrGraph[AT, OT, S]) extends AndOrGraphSearch[AT, OT, S](og) { + S](og: AndOrGraph[AT, OT, S], nWorkers: Int) extends AndOrGraphSearch[AT, OT, S](og) { def initWorkerContext(w: ActorRef): WC - val nWorkers = 7 val timeout = 600.seconds var system: ActorSystem = _ -- GitLab