Skip to content
Snippets Groups Projects
Commit a9f40660 authored by Etienne Kneuss's avatar Etienne Kneuss
Browse files

Implement --parallel[=N] to specify the number of workers to use. On success,...

Implement --parallel[=N] to specify the number of workers to use. On success, shutdown immediately by halting solvers.
parent ce6ef073
No related branches found
No related tags found
No related merge requests found
...@@ -31,6 +31,7 @@ sealed abstract class LeonOptionDef { ...@@ -31,6 +31,7 @@ sealed abstract class LeonOptionDef {
val usageDesc: String val usageDesc: String
val isFlag: Boolean val isFlag: Boolean
} }
case class LeonFlagOptionDef(name: String, usageOption: String, usageDesc: String) extends LeonOptionDef { case class LeonFlagOptionDef(name: String, usageOption: String, usageDesc: String) extends LeonOptionDef {
val isFlag = true val isFlag = true
} }
...@@ -39,6 +40,10 @@ case class LeonValueOptionDef(name: String, usageOption: String, usageDesc: Stri ...@@ -39,6 +40,10 @@ case class LeonValueOptionDef(name: String, usageOption: String, usageDesc: Stri
val isFlag = false val isFlag = false
} }
case class LeonOptValueOptionDef(name: String, usageOption: String, usageDesc: String) extends LeonOptionDef {
val isFlag = false
}
object ListValue { object ListValue {
def apply(values: Seq[String]) = values.mkString(":") def apply(values: Seq[String]) = values.mkString(":")
def unapply(value: String): Option[Seq[String]] = Some(value.split(':').map(_.trim).filter(!_.isEmpty)) def unapply(value: String): Option[Seq[String]] = Some(value.split(':').map(_.trim).filter(!_.isEmpty))
......
...@@ -80,10 +80,10 @@ object Main { ...@@ -80,10 +80,10 @@ object Main {
} }
if (allOptionsMap contains leonOpt.name) { if (allOptionsMap contains leonOpt.name) {
(allOptionsMap(leonOpt.name).isFlag, leonOpt) match { (allOptionsMap(leonOpt.name), leonOpt) match {
case (true, LeonFlagOption(name)) => case (_: LeonFlagOptionDef | _: LeonOptValueOptionDef, LeonFlagOption(name)) =>
Some(leonOpt) Some(leonOpt)
case (false, LeonValueOption(name, value)) => case (_: LeonValueOptionDef | _: LeonOptValueOptionDef, LeonValueOption(name, value)) =>
Some(leonOpt) Some(leonOpt)
case _ => case _ =>
reporter.error("Invalid option usage: " + opt) reporter.error("Invalid option usage: " + opt)
......
...@@ -9,10 +9,14 @@ import solvers.TrivialSolver ...@@ -9,10 +9,14 @@ import solvers.TrivialSolver
class ParallelSearch(synth: Synthesizer, class ParallelSearch(synth: Synthesizer,
problem: Problem, problem: Problem,
rules: Set[Rule], rules: Set[Rule],
costModel: CostModel) extends AndOrGraphParallelSearch[SynthesisContext, TaskRunRule, TaskTryRules, Solution](new AndOrGraph(TaskTryRules(problem), SearchCostModel(costModel))) { costModel: CostModel,
nWorkers: Int) extends AndOrGraphParallelSearch[SynthesisContext, TaskRunRule, TaskTryRules, Solution](new AndOrGraph(TaskTryRules(problem), SearchCostModel(costModel)), nWorkers) {
import synth.reporter._ import synth.reporter._
// This is HOT shared memory, used only in stop() for shutting down solvers!
private[this] var contexts = List[SynthesisContext]()
def initWorkerContext(wr: ActorRef) = { def initWorkerContext(wr: ActorRef) = {
val reporter = new SilentReporter val reporter = new SilentReporter
val solver = new FairZ3Solver(synth.context.copy(reporter = reporter)) val solver = new FairZ3Solver(synth.context.copy(reporter = reporter))
...@@ -20,11 +24,20 @@ class ParallelSearch(synth: Synthesizer, ...@@ -20,11 +24,20 @@ class ParallelSearch(synth: Synthesizer,
solver.initZ3 solver.initZ3
SynthesisContext(solver = solver, reporter = synth.reporter, shouldStop = synth.shouldStop) val ctx = SynthesisContext(solver = solver, reporter = synth.reporter, shouldStop = synth.shouldStop)
synchronized {
contexts = ctx :: contexts
}
ctx
} }
override def stop() = { override def stop() = {
synth.shouldStop.set(true) synth.shouldStop.set(true)
for (ctx <- contexts) {
ctx.solver.halt()
}
super.stop() super.stop()
} }
......
...@@ -14,13 +14,13 @@ object SynthesisPhase extends LeonPhase[Program, Program] { ...@@ -14,13 +14,13 @@ object SynthesisPhase extends LeonPhase[Program, Program] {
val description = "Synthesis" val description = "Synthesis"
override val definedOptions : Set[LeonOptionDef] = Set( override val definedOptions : Set[LeonOptionDef] = Set(
LeonFlagOptionDef( "inplace", "--inplace", "Debug level"), LeonFlagOptionDef( "inplace", "--inplace", "Debug level"),
LeonFlagOptionDef( "parallel", "--parallel", "Parallel synthesis search"), LeonOptValueOptionDef("parallel", "--parallel[=N]", "Parallel synthesis search using N workers"),
LeonFlagOptionDef( "derivtrees", "--derivtrees", "Generate derivation trees"), LeonFlagOptionDef( "derivtrees", "--derivtrees", "Generate derivation trees"),
LeonFlagOptionDef( "firstonly", "--firstonly", "Stop as soon as one synthesis solution is found"), LeonFlagOptionDef( "firstonly", "--firstonly", "Stop as soon as one synthesis solution is found"),
LeonValueOptionDef("timeout", "--timeout=T", "Timeout after T seconds when searching for synthesis solutions .."), LeonValueOptionDef( "timeout", "--timeout=T", "Timeout after T seconds when searching for synthesis solutions .."),
LeonValueOptionDef("costmodel", "--costmodel=cm", "Use a specific cost model for this search"), LeonValueOptionDef( "costmodel", "--costmodel=cm", "Use a specific cost model for this search"),
LeonValueOptionDef("functions", "--functions=f1:f2", "Limit synthesis of choose found within f1,f2,..") LeonValueOptionDef( "functions", "--functions=f1:f2", "Limit synthesis of choose found within f1,f2,..")
) )
def run(ctx: LeonContext)(p: Program): Program = { def run(ctx: LeonContext)(p: Program): Program = {
...@@ -38,8 +38,10 @@ object SynthesisPhase extends LeonPhase[Program, Program] { ...@@ -38,8 +38,10 @@ object SynthesisPhase extends LeonPhase[Program, Program] {
for(opt <- ctx.options) opt match { for(opt <- ctx.options) opt match {
case LeonFlagOption("inplace") => case LeonFlagOption("inplace") =>
inPlace = true inPlace = true
case LeonValueOption("functions", ListValue(fs)) => case LeonValueOption("functions", ListValue(fs)) =>
options = options.copy(filterFuns = Some(fs.toSet)) options = options.copy(filterFuns = Some(fs.toSet))
case LeonValueOption("costmodel", cm) => case LeonValueOption("costmodel", cm) =>
CostModel.all.find(_.name.toLowerCase == cm.toLowerCase) match { CostModel.all.find(_.name.toLowerCase == cm.toLowerCase) match {
case Some(model) => case Some(model) =>
...@@ -52,16 +54,26 @@ object SynthesisPhase extends LeonPhase[Program, Program] { ...@@ -52,16 +54,26 @@ object SynthesisPhase extends LeonPhase[Program, Program] {
ctx.reporter.fatalError(errorMsg) ctx.reporter.fatalError(errorMsg)
} }
case v @ LeonValueOption("timeout", _) => case v @ LeonValueOption("timeout", _) =>
v.asInt(ctx).foreach { t => v.asInt(ctx).foreach { t =>
options = options.copy(timeoutMs = Some(t.toLong)) options = options.copy(timeoutMs = Some(t.toLong))
} }
case LeonFlagOption("firstonly") => case LeonFlagOption("firstonly") =>
options = options.copy(firstOnly = true) options = options.copy(firstOnly = true)
case LeonFlagOption("parallel") => case LeonFlagOption("parallel") =>
options = options.copy(parallel = true) options = options.copy(searchWorkers = 5)
case o @ LeonValueOption("parallel", nWorkers) =>
o.asInt(ctx).foreach { nWorkers =>
options = options.copy(searchWorkers = nWorkers)
}
case LeonFlagOption("derivtrees") => case LeonFlagOption("derivtrees") =>
options = options.copy(generateDerivationTrees = true) options = options.copy(generateDerivationTrees = true)
case _ => case _ =>
} }
......
...@@ -32,8 +32,8 @@ class Synthesizer(val context : LeonContext, ...@@ -32,8 +32,8 @@ class Synthesizer(val context : LeonContext,
def synthesize(): (Solution, Boolean) = { def synthesize(): (Solution, Boolean) = {
val search = if (options.parallel) { val search = if (options.searchWorkers > 1) {
new ParallelSearch(this, problem, rules, options.costModel) new ParallelSearch(this, problem, rules, options.costModel, options.searchWorkers)
} else { } else {
new SimpleSearch(this, problem, rules, options.costModel) new SimpleSearch(this, problem, rules, options.costModel)
} }
......
...@@ -4,7 +4,7 @@ package synthesis ...@@ -4,7 +4,7 @@ package synthesis
case class SynthesizerOptions( case class SynthesizerOptions(
generateDerivationTrees: Boolean = false, generateDerivationTrees: Boolean = false,
filterFuns: Option[Set[String]] = None, filterFuns: Option[Set[String]] = None,
parallel: Boolean = false, searchWorkers: Int = 1,
firstOnly: Boolean = false, firstOnly: Boolean = false,
timeoutMs: Option[Long] = None, timeoutMs: Option[Long] = None,
costModel: CostModel = NaiveCostModel costModel: CostModel = NaiveCostModel
......
...@@ -255,7 +255,9 @@ case object CEGIS extends Rule("CEGIS") { ...@@ -255,7 +255,9 @@ case object CEGIS extends Rule("CEGIS") {
result = Some(RuleSuccess(Solution(BooleanLiteral(true), Set(), Tuple(p.xs.map(valuateWithModel(mapping))).setType(tpe)))) result = Some(RuleSuccess(Solution(BooleanLiteral(true), Set(), Tuple(p.xs.map(valuateWithModel(mapping))).setType(tpe))))
case _ => case _ =>
sctx.reporter.warning("Solver returned 'UNKNOWN' in a CEGIS iteration.") if (!sctx.shouldStop.get) {
sctx.reporter.warning("Solver returned 'UNKNOWN' in a CEGIS iteration.")
}
continue = false continue = false
} }
......
...@@ -10,11 +10,10 @@ import akka.dispatch.Await ...@@ -10,11 +10,10 @@ import akka.dispatch.Await
abstract class AndOrGraphParallelSearch[WC, abstract class AndOrGraphParallelSearch[WC,
AT <: AOAndTask[S], AT <: AOAndTask[S],
OT <: AOOrTask[S], OT <: AOOrTask[S],
S](og: AndOrGraph[AT, OT, S]) extends AndOrGraphSearch[AT, OT, S](og) { S](og: AndOrGraph[AT, OT, S], nWorkers: Int) extends AndOrGraphSearch[AT, OT, S](og) {
def initWorkerContext(w: ActorRef): WC def initWorkerContext(w: ActorRef): WC
val nWorkers = 7
val timeout = 600.seconds val timeout = 600.seconds
var system: ActorSystem = _ var system: ActorSystem = _
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment