diff --git a/src/main/scala/leon/synthesis/Cost.scala b/src/main/scala/leon/synthesis/Cost.scala index 7d989b3476c18edd1ed6a0d16a948dc8da585aa6..f8fb6d6944ef9e1f2a5e6ef473e8c32f1ede8b80 100644 --- a/src/main/scala/leon/synthesis/Cost.scala +++ b/src/main/scala/leon/synthesis/Cost.scala @@ -6,22 +6,30 @@ import purescala.TreeOps._ import synthesis.search.Cost -case class SolutionCost(s: Solution) extends Cost { - val value = { - val chooses = collectChooses(s.toExpr) - val chooseCost = chooses.foldLeft(0)((i, c) => i + ProblemCost(Problem.fromChoose(c)).value) +abstract class CostModel(name: String) { + def solutionCost(s: Solution): Cost + def problemCost(p: Problem): Cost + def ruleAppCost(r: Rule, app: RuleApplication): Cost +} + +case object NaiveCostModel extends CostModel("Naive") { + def solutionCost(s: Solution): Cost = new Cost { + val value = { + val chooses = collectChooses(s.toExpr) + val chooseCost = chooses.foldLeft(0)((i, c) => i + problemCost(Problem.fromChoose(c)).value) - formulaSize(s.toExpr) + chooseCost + formulaSize(s.toExpr) + chooseCost + } } -} -case class ProblemCost(p: Problem) extends Cost { - val value = p.xs.size -} + def problemCost(p: Problem): Cost = new Cost { + val value = p.xs.size + } -case class RuleApplicationCost(rule: Rule, app: RuleApplication) extends Cost { - val subSols = (1 to app.subProblemsCount).map {i => Solution.simplest }.toList - val simpleSol = app.onSuccess(subSols) + def ruleAppCost(r: Rule, app: RuleApplication): Cost = new Cost { + val subSols = (1 to app.subProblemsCount).map {i => Solution.simplest }.toList + val simpleSol = app.onSuccess(subSols) - val value = SolutionCost(simpleSol).value + val value = solutionCost(simpleSol).value + } } diff --git a/src/main/scala/leon/synthesis/ParallelSearch.scala b/src/main/scala/leon/synthesis/ParallelSearch.scala index 9d24d1d27ca966640d313bd805e9d793ed31dd5a..7902675f07087301624b99dd8fdcf67ed701f7d2 100644 --- a/src/main/scala/leon/synthesis/ParallelSearch.scala +++ b/src/main/scala/leon/synthesis/ParallelSearch.scala @@ -8,7 +8,8 @@ import solvers.TrivialSolver class ParallelSearch(synth: Synthesizer, problem: Problem, - rules: Set[Rule]) extends AndOrGraphParallelSearch[SynthesisContext, TaskRunRule, TaskTryRules, Solution](new AndOrGraph(TaskTryRules(problem))) { + rules: Set[Rule], + costModel: CostModel) extends AndOrGraphParallelSearch[SynthesisContext, TaskRunRule, TaskTryRules, Solution](new AndOrGraph(TaskTryRules(problem), SearchCostModel(costModel))) { import synth.reporter._ diff --git a/src/main/scala/leon/synthesis/SimpleSearch.scala b/src/main/scala/leon/synthesis/SimpleSearch.scala index b58d15cee91b71a849fb9d3a43e5d953258534ae..e619510c1f9badb871d7a29e70b775962f517480 100644 --- a/src/main/scala/leon/synthesis/SimpleSearch.scala +++ b/src/main/scala/leon/synthesis/SimpleSearch.scala @@ -4,8 +4,6 @@ package synthesis import synthesis.search._ case class TaskRunRule(problem: Problem, rule: Rule, app: RuleApplication) extends AOAndTask[Solution] { - val cost = RuleApplicationCost(rule, app) - def composeSolution(sols: List[Solution]): Solution = { app.onSuccess(sols) } @@ -14,14 +12,24 @@ case class TaskRunRule(problem: Problem, rule: Rule, app: RuleApplication) exten } case class TaskTryRules(p: Problem) extends AOOrTask[Solution] { - val cost = ProblemCost(p) - override def toString = p.toString } +case class SearchCostModel(cm: CostModel) extends AOCostModel[TaskRunRule, TaskTryRules, Solution] { + def taskCost(t: AOTask[Solution]) = t match { + case ttr: TaskRunRule => + cm.ruleAppCost(ttr.rule, ttr.app) + case trr: TaskTryRules => + cm.problemCost(trr.p) + } + + def solutionCost(s: Solution) = cm.solutionCost(s) +} + class SimpleSearch(synth: Synthesizer, problem: Problem, - rules: Set[Rule]) extends AndOrGraphSearch[TaskRunRule, TaskTryRules, Solution](new AndOrGraph(TaskTryRules(problem))) { + rules: Set[Rule], + costModel: CostModel) extends AndOrGraphSearch[TaskRunRule, TaskTryRules, Solution](new AndOrGraph(TaskTryRules(problem), SearchCostModel(costModel))) { import synth.reporter._ diff --git a/src/main/scala/leon/synthesis/Solution.scala b/src/main/scala/leon/synthesis/Solution.scala index 471d154f74f18e4bce0f8a335f54f16811eed236..602b5f7716c66b99c25ba660befe9549ce7130d7 100644 --- a/src/main/scala/leon/synthesis/Solution.scala +++ b/src/main/scala/leon/synthesis/Solution.scala @@ -5,15 +5,11 @@ import leon.purescala.Trees._ import leon.purescala.Definitions._ import leon.purescala.TreeOps._ -import synthesis.search._ - // Defines a synthesis solution of the form: // ⟨ P | T ⟩ -class Solution(val pre: Expr, val defs: Set[FunDef], val term: Expr) extends AOSolution { +class Solution(val pre: Expr, val defs: Set[FunDef], val term: Expr) { override def toString = "⟨ "+pre+" | "+defs.mkString(" ")+" "+term+" ⟩" - val cost: Cost = SolutionCost(this) - def toExpr = { val result = if (pre == BooleanLiteral(true)) { term diff --git a/src/main/scala/leon/synthesis/SynthesisPhase.scala b/src/main/scala/leon/synthesis/SynthesisPhase.scala index baf62236d14a2289cc4393500ecffdb77f0895df..ad21c6f7b074ad43860593ce81e36af31a690a41 100644 --- a/src/main/scala/leon/synthesis/SynthesisPhase.scala +++ b/src/main/scala/leon/synthesis/SynthesisPhase.scala @@ -31,30 +31,26 @@ object SynthesisPhase extends LeonPhase[Program, Program] { val uninterpretedZ3 = new UninterpretedZ3Solver(silentContext) uninterpretedZ3.setProgram(p) + var options = SynthesizerOptions() var inPlace = false - var genTrees = false - var firstOnly = false - var parallel = false - var filterFun: Option[Seq[String]] = None - var timeoutMs: Option[Long] = None for(opt <- ctx.options) opt match { case LeonFlagOption("inplace") => inPlace = true case LeonValueOption("functions", ListValue(fs)) => - filterFun = Some(fs) + options = options.copy(filterFuns = Some(fs.toSet)) case LeonValueOption("timeout", t) => try { - timeoutMs = Some(t.toLong) + options = options.copy(timeoutMs = Some(t.toLong)) } catch { case _: Throwable => } case LeonFlagOption("firstonly") => - firstOnly = true + options = options.copy(firstOnly = true) case LeonFlagOption("parallel") => - parallel = true + options = options.copy(parallel = true) case LeonFlagOption("derivtrees") => - genTrees = true + options = options.copy(generateDerivationTrees = true) case _ => } @@ -71,11 +67,7 @@ object SynthesisPhase extends LeonPhase[Program, Program] { p, problem, Rules.all ++ Heuristics.all, - genTrees, - filterFun.map(_.toSet), - parallel, - firstOnly, - timeoutMs) + options) val sol = synth.synthesize() solutions += ch -> (f, sol) @@ -87,7 +79,7 @@ object SynthesisPhase extends LeonPhase[Program, Program] { // Look for choose() for (f <- program.definedFunctions.sortBy(_.id.toString) if f.body.isDefined) { - if (filterFun.isEmpty || filterFun.get.contains(f.id.toString)) { + if (options.filterFuns.isEmpty || options.filterFuns.get.contains(f.id.toString)) { treeCatamorphism(x => x, noop, actOnChoose(f), f.body.get) } } diff --git a/src/main/scala/leon/synthesis/Synthesizer.scala b/src/main/scala/leon/synthesis/Synthesizer.scala index 78ca68e29103e5b327316a4cf1dd0aaa39391f50..fe8f2ef8fedb453d950011581c7d5e208f5ce922 100644 --- a/src/main/scala/leon/synthesis/Synthesizer.scala +++ b/src/main/scala/leon/synthesis/Synthesizer.scala @@ -20,11 +20,8 @@ class Synthesizer(val context : LeonContext, val program: Program, val problem: Problem, val rules: Set[Rule], - generateDerivationTrees: Boolean = false, - filterFuns: Option[Set[String]] = None, - parallel: Boolean = false, - firstOnly: Boolean = false, - timeoutMs: Option[Long] = None) { + val options: SynthesizerOptions) { + protected[synthesis] val reporter = context.reporter import reporter.{error,warning,info,fatalError} @@ -33,10 +30,10 @@ class Synthesizer(val context : LeonContext, def synthesize(): Solution = { - val search = if (parallel) { - new ParallelSearch(this, problem, rules) + val search = if (options.parallel) { + new ParallelSearch(this, problem, rules, options.costModel) } else { - new SimpleSearch(this, problem, rules) + new SimpleSearch(this, problem, rules, options.costModel) } val sigINT = new Signal("INT") @@ -60,8 +57,9 @@ class Synthesizer(val context : LeonContext, val diff = System.currentTimeMillis()-ts reporter.info("Finished in "+diff+"ms") - if (generateDerivationTrees) { - new AndOrGraphDotConverter(search.g, firstOnly).writeFile("derivation"+AndOrGraphDotConverterCounter.next()+".dot") + if (options.generateDerivationTrees) { + val converter = new AndOrGraphDotConverter(search.g, options.firstOnly) + converter.writeFile("derivation"+AndOrGraphDotConverterCounter.next()+".dot") } res match { diff --git a/src/main/scala/leon/synthesis/SynthesizerOptions.scala b/src/main/scala/leon/synthesis/SynthesizerOptions.scala new file mode 100644 index 0000000000000000000000000000000000000000..768c31ce901767ed4548039e881618c666a86e29 --- /dev/null +++ b/src/main/scala/leon/synthesis/SynthesizerOptions.scala @@ -0,0 +1,11 @@ +package leon +package synthesis + +case class SynthesizerOptions( + generateDerivationTrees: Boolean = false, + filterFuns: Option[Set[String]] = None, + parallel: Boolean = false, + firstOnly: Boolean = false, + timeoutMs: Option[Long] = None, + costModel: CostModel = NaiveCostModel +) diff --git a/src/main/scala/leon/synthesis/search/AndOrGraph.scala b/src/main/scala/leon/synthesis/search/AndOrGraph.scala index 297c10fa7407f8a6443f1b06d79022802c986d96..2ac6881a2bc0ed52e2787e27eb4e7e7dcf078cf2 100644 --- a/src/main/scala/leon/synthesis/search/AndOrGraph.scala +++ b/src/main/scala/leon/synthesis/search/AndOrGraph.scala @@ -1,21 +1,21 @@ package leon.synthesis.search -trait AOTask[S <: AOSolution] { - def cost: Cost +trait AOTask[S] { } -trait AOAndTask[S <: AOSolution] extends AOTask[S] { +trait AOAndTask[S] extends AOTask[S] { def composeSolution(sols: List[S]): S } -trait AOOrTask[S <: AOSolution] extends AOTask[S] { +trait AOOrTask[S] extends AOTask[S] { } -trait AOSolution { - def cost: Cost +trait AOCostModel[AT <: AOAndTask[S], OT <: AOOrTask[S], S] { + def taskCost(at: AOTask[S]): Cost + def solutionCost(s: S): Cost } -class AndOrGraph[AT <: AOAndTask[S], OT <: AOOrTask[S], S <: AOSolution](val root: OT) { +class AndOrGraph[AT <: AOAndTask[S], OT <: AOOrTask[S], S](val root: OT, val costModel: AOCostModel[AT, OT, S]) { var tree: OrTree = RootNode trait Tree { @@ -40,7 +40,7 @@ class AndOrGraph[AT <: AOAndTask[S], OT <: AOOrTask[S], S <: AOSolution](val roo trait Leaf extends Tree { - def minCost = task.cost + def minCost = costModel.taskCost(task) } trait Node[T <: Tree] extends Tree { @@ -59,11 +59,11 @@ class AndOrGraph[AT <: AOAndTask[S], OT <: AOOrTask[S], S <: AOSolution](val roo val old = minCost minCost = solution match { case Some(s) => - s.cost + costModel.solutionCost(s) case _ => val subCosts = subProblems.values.map(_.minCost) - subCosts.foldLeft(task.cost)(_ + _) + subCosts.foldLeft(costModel.taskCost(task))(_ + _) } if (minCost != old) { Option(parent).foreach(_.updateMin()) @@ -125,7 +125,7 @@ class AndOrGraph[AT <: AOAndTask[S], OT <: AOOrTask[S], S <: AOSolution](val roo class OrNode(val parent: AndNode, var alternatives: Map[AT, AndTree], val task: OT) extends OrTree with Node[AndTree] { var triedAlternatives = Map[AT, AndTree]() var minAlternative: AndTree = _ - var minCost = task.cost + var minCost = costModel.taskCost(task) def updateMin() { if (!alternatives.isEmpty) { @@ -166,7 +166,7 @@ class AndOrGraph[AT <: AOAndTask[S], OT <: AOOrTask[S], S <: AOSolution](val roo def notifySolution(sub: AndTree, sol: S) { solution match { - case Some(preSol) if (preSol.cost < sol.cost) => + case Some(preSol) if (costModel.solutionCost(preSol) < costModel.solutionCost(sol)) => solution = Some(sol) minAlternative = sub diff --git a/src/main/scala/leon/synthesis/search/AndOrGraphDotConverter.scala b/src/main/scala/leon/synthesis/search/AndOrGraphDotConverter.scala index 931e276bde7482b3ff25110d2a1ed8804f1c2cac..062ada1d740988396d7e25ca4c3bead66607440e 100644 --- a/src/main/scala/leon/synthesis/search/AndOrGraphDotConverter.scala +++ b/src/main/scala/leon/synthesis/search/AndOrGraphDotConverter.scala @@ -2,7 +2,7 @@ package leon.synthesis.search class AndOrGraphDotConverter[AT <: AOAndTask[S], OT <: AOOrTask[S], - S <: AOSolution](val g: AndOrGraph[AT, OT, S], firstOnly: Boolean) { + S](val g: AndOrGraph[AT, OT, S], firstOnly: Boolean) { private[this] var _nextID = 0 @@ -97,7 +97,7 @@ class AndOrGraphDotConverter[AT <: AOAndTask[S], (if (t.isSolved) "palegreen" else if (t.isUnsolvable) "firebrick" else "white", "") } - res append " "+name+" [ label = <<TABLE BORDER=\"0\" CELLBORDER=\"1\" CELLSPACING=\"0\"><TR><TD BORDER=\"0\">self: "+t.task.cost.value+" | tree-min: "+t.minCost.value+"</TD></TR><TR><TD BORDER=\"1\" BGCOLOR=\""+color+"\">"+escapeHTML(t.task.toString)+"</TD></TR>"; + res append " "+name+" [ label = <<TABLE BORDER=\"0\" CELLBORDER=\"1\" CELLSPACING=\"0\"><TR><TD BORDER=\"0\">self: "+g.costModel.taskCost(t.task).value+" | tree-min: "+t.minCost.value+"</TD></TR><TR><TD BORDER=\"1\" BGCOLOR=\""+color+"\">"+escapeHTML(t.task.toString)+"</TD></TR>"; if (t.isSolved) { res append "<TR><TD BGCOLOR=\""+color+"\">"+escapeHTML(t.solution.get.toString)+"</TD></TR>" diff --git a/src/main/scala/leon/synthesis/search/AndOrGraphParallelSearch.scala b/src/main/scala/leon/synthesis/search/AndOrGraphParallelSearch.scala index be7ee349832f2c4d67f8b8ea07511474cc52a051..a9a3201aed8d3694a2fd715da6b1792bc3ed0746 100644 --- a/src/main/scala/leon/synthesis/search/AndOrGraphParallelSearch.scala +++ b/src/main/scala/leon/synthesis/search/AndOrGraphParallelSearch.scala @@ -9,7 +9,7 @@ import akka.dispatch.Await abstract class AndOrGraphParallelSearch[WC, AT <: AOAndTask[S], OT <: AOOrTask[S], - S <: AOSolution](og: AndOrGraph[AT, OT, S]) extends AndOrGraphSearch[AT, OT, S](og) { + S](og: AndOrGraph[AT, OT, S]) extends AndOrGraphSearch[AT, OT, S](og) { def initWorkerContext(w: ActorRef): WC diff --git a/src/main/scala/leon/synthesis/search/AndOrGraphPartialSolution.scala b/src/main/scala/leon/synthesis/search/AndOrGraphPartialSolution.scala index 151f40ae5a0ee47f4a02d27062c550083cdbb4e7..2406bc1ca721b3bb9a1f0f1128ae062554efde1d 100644 --- a/src/main/scala/leon/synthesis/search/AndOrGraphPartialSolution.scala +++ b/src/main/scala/leon/synthesis/search/AndOrGraphPartialSolution.scala @@ -2,7 +2,7 @@ package leon.synthesis.search class AndOrGraphPartialSolution[AT <: AOAndTask[S], OT <: AOOrTask[S], - S <: AOSolution](val g: AndOrGraph[AT, OT, S], missing: AT => S) { + S](val g: AndOrGraph[AT, OT, S], missing: AT => S) { def getSolution: S = { diff --git a/src/main/scala/leon/synthesis/search/AndOrGraphSearch.scala b/src/main/scala/leon/synthesis/search/AndOrGraphSearch.scala index 01aff5047dc0cde1353b8bdded8c9e3f8fd6901b..75ee3292fff51f56d1a8a58cfad5c28b371268e0 100644 --- a/src/main/scala/leon/synthesis/search/AndOrGraphSearch.scala +++ b/src/main/scala/leon/synthesis/search/AndOrGraphSearch.scala @@ -2,7 +2,7 @@ package leon.synthesis.search abstract class AndOrGraphSearch[AT <: AOAndTask[S], OT <: AOOrTask[S], - S <: AOSolution](val g: AndOrGraph[AT, OT, S]) { + S](val g: AndOrGraph[AT, OT, S]) { var processing = Set[g.Leaf]()