From 242ae4cdedcd07c94d4f13258a6cc15aaa9b2ff8 Mon Sep 17 00:00:00 2001 From: Etienne Kneuss <ekneuss@gmail.com> Date: Tue, 4 Dec 2012 03:38:07 +0100 Subject: [PATCH] Add cost models, make things a tad bit nicer, synthesizer options --- src/main/scala/leon/synthesis/Cost.scala | 34 ++++++++++++------- .../scala/leon/synthesis/ParallelSearch.scala | 3 +- .../scala/leon/synthesis/SimpleSearch.scala | 18 +++++++--- src/main/scala/leon/synthesis/Solution.scala | 6 +--- .../scala/leon/synthesis/SynthesisPhase.scala | 24 +++++-------- .../scala/leon/synthesis/Synthesizer.scala | 18 +++++----- .../leon/synthesis/SynthesizerOptions.scala | 11 ++++++ .../leon/synthesis/search/AndOrGraph.scala | 24 ++++++------- .../search/AndOrGraphDotConverter.scala | 4 +-- .../search/AndOrGraphParallelSearch.scala | 2 +- .../search/AndOrGraphPartialSolution.scala | 2 +- .../synthesis/search/AndOrGraphSearch.scala | 2 +- 12 files changed, 81 insertions(+), 67 deletions(-) create mode 100644 src/main/scala/leon/synthesis/SynthesizerOptions.scala diff --git a/src/main/scala/leon/synthesis/Cost.scala b/src/main/scala/leon/synthesis/Cost.scala index 7d989b347..f8fb6d694 100644 --- a/src/main/scala/leon/synthesis/Cost.scala +++ b/src/main/scala/leon/synthesis/Cost.scala @@ -6,22 +6,30 @@ import purescala.TreeOps._ import synthesis.search.Cost -case class SolutionCost(s: Solution) extends Cost { - val value = { - val chooses = collectChooses(s.toExpr) - val chooseCost = chooses.foldLeft(0)((i, c) => i + ProblemCost(Problem.fromChoose(c)).value) +abstract class CostModel(name: String) { + def solutionCost(s: Solution): Cost + def problemCost(p: Problem): Cost + def ruleAppCost(r: Rule, app: RuleApplication): Cost +} + +case object NaiveCostModel extends CostModel("Naive") { + def solutionCost(s: Solution): Cost = new Cost { + val value = { + val chooses = collectChooses(s.toExpr) + val chooseCost = chooses.foldLeft(0)((i, c) => i + problemCost(Problem.fromChoose(c)).value) - formulaSize(s.toExpr) + chooseCost + formulaSize(s.toExpr) + chooseCost + } } -} -case class ProblemCost(p: Problem) extends Cost { - val value = p.xs.size -} + def problemCost(p: Problem): Cost = new Cost { + val value = p.xs.size + } -case class RuleApplicationCost(rule: Rule, app: RuleApplication) extends Cost { - val subSols = (1 to app.subProblemsCount).map {i => Solution.simplest }.toList - val simpleSol = app.onSuccess(subSols) + def ruleAppCost(r: Rule, app: RuleApplication): Cost = new Cost { + val subSols = (1 to app.subProblemsCount).map {i => Solution.simplest }.toList + val simpleSol = app.onSuccess(subSols) - val value = SolutionCost(simpleSol).value + val value = solutionCost(simpleSol).value + } } diff --git a/src/main/scala/leon/synthesis/ParallelSearch.scala b/src/main/scala/leon/synthesis/ParallelSearch.scala index 9d24d1d27..7902675f0 100644 --- a/src/main/scala/leon/synthesis/ParallelSearch.scala +++ b/src/main/scala/leon/synthesis/ParallelSearch.scala @@ -8,7 +8,8 @@ import solvers.TrivialSolver class ParallelSearch(synth: Synthesizer, problem: Problem, - rules: Set[Rule]) extends AndOrGraphParallelSearch[SynthesisContext, TaskRunRule, TaskTryRules, Solution](new AndOrGraph(TaskTryRules(problem))) { + rules: Set[Rule], + costModel: CostModel) extends AndOrGraphParallelSearch[SynthesisContext, TaskRunRule, TaskTryRules, Solution](new AndOrGraph(TaskTryRules(problem), SearchCostModel(costModel))) { import synth.reporter._ diff --git a/src/main/scala/leon/synthesis/SimpleSearch.scala b/src/main/scala/leon/synthesis/SimpleSearch.scala index b58d15cee..e619510c1 100644 --- a/src/main/scala/leon/synthesis/SimpleSearch.scala +++ b/src/main/scala/leon/synthesis/SimpleSearch.scala @@ -4,8 +4,6 @@ package synthesis import synthesis.search._ case class TaskRunRule(problem: Problem, rule: Rule, app: RuleApplication) extends AOAndTask[Solution] { - val cost = RuleApplicationCost(rule, app) - def composeSolution(sols: List[Solution]): Solution = { app.onSuccess(sols) } @@ -14,14 +12,24 @@ case class TaskRunRule(problem: Problem, rule: Rule, app: RuleApplication) exten } case class TaskTryRules(p: Problem) extends AOOrTask[Solution] { - val cost = ProblemCost(p) - override def toString = p.toString } +case class SearchCostModel(cm: CostModel) extends AOCostModel[TaskRunRule, TaskTryRules, Solution] { + def taskCost(t: AOTask[Solution]) = t match { + case ttr: TaskRunRule => + cm.ruleAppCost(ttr.rule, ttr.app) + case trr: TaskTryRules => + cm.problemCost(trr.p) + } + + def solutionCost(s: Solution) = cm.solutionCost(s) +} + class SimpleSearch(synth: Synthesizer, problem: Problem, - rules: Set[Rule]) extends AndOrGraphSearch[TaskRunRule, TaskTryRules, Solution](new AndOrGraph(TaskTryRules(problem))) { + rules: Set[Rule], + costModel: CostModel) extends AndOrGraphSearch[TaskRunRule, TaskTryRules, Solution](new AndOrGraph(TaskTryRules(problem), SearchCostModel(costModel))) { import synth.reporter._ diff --git a/src/main/scala/leon/synthesis/Solution.scala b/src/main/scala/leon/synthesis/Solution.scala index 471d154f7..602b5f771 100644 --- a/src/main/scala/leon/synthesis/Solution.scala +++ b/src/main/scala/leon/synthesis/Solution.scala @@ -5,15 +5,11 @@ import leon.purescala.Trees._ import leon.purescala.Definitions._ import leon.purescala.TreeOps._ -import synthesis.search._ - // Defines a synthesis solution of the form: // ⟨ P | T ⟩ -class Solution(val pre: Expr, val defs: Set[FunDef], val term: Expr) extends AOSolution { +class Solution(val pre: Expr, val defs: Set[FunDef], val term: Expr) { override def toString = "⟨ "+pre+" | "+defs.mkString(" ")+" "+term+" ⟩" - val cost: Cost = SolutionCost(this) - def toExpr = { val result = if (pre == BooleanLiteral(true)) { term diff --git a/src/main/scala/leon/synthesis/SynthesisPhase.scala b/src/main/scala/leon/synthesis/SynthesisPhase.scala index baf62236d..ad21c6f7b 100644 --- a/src/main/scala/leon/synthesis/SynthesisPhase.scala +++ b/src/main/scala/leon/synthesis/SynthesisPhase.scala @@ -31,30 +31,26 @@ object SynthesisPhase extends LeonPhase[Program, Program] { val uninterpretedZ3 = new UninterpretedZ3Solver(silentContext) uninterpretedZ3.setProgram(p) + var options = SynthesizerOptions() var inPlace = false - var genTrees = false - var firstOnly = false - var parallel = false - var filterFun: Option[Seq[String]] = None - var timeoutMs: Option[Long] = None for(opt <- ctx.options) opt match { case LeonFlagOption("inplace") => inPlace = true case LeonValueOption("functions", ListValue(fs)) => - filterFun = Some(fs) + options = options.copy(filterFuns = Some(fs.toSet)) case LeonValueOption("timeout", t) => try { - timeoutMs = Some(t.toLong) + options = options.copy(timeoutMs = Some(t.toLong)) } catch { case _: Throwable => } case LeonFlagOption("firstonly") => - firstOnly = true + options = options.copy(firstOnly = true) case LeonFlagOption("parallel") => - parallel = true + options = options.copy(parallel = true) case LeonFlagOption("derivtrees") => - genTrees = true + options = options.copy(generateDerivationTrees = true) case _ => } @@ -71,11 +67,7 @@ object SynthesisPhase extends LeonPhase[Program, Program] { p, problem, Rules.all ++ Heuristics.all, - genTrees, - filterFun.map(_.toSet), - parallel, - firstOnly, - timeoutMs) + options) val sol = synth.synthesize() solutions += ch -> (f, sol) @@ -87,7 +79,7 @@ object SynthesisPhase extends LeonPhase[Program, Program] { // Look for choose() for (f <- program.definedFunctions.sortBy(_.id.toString) if f.body.isDefined) { - if (filterFun.isEmpty || filterFun.get.contains(f.id.toString)) { + if (options.filterFuns.isEmpty || options.filterFuns.get.contains(f.id.toString)) { treeCatamorphism(x => x, noop, actOnChoose(f), f.body.get) } } diff --git a/src/main/scala/leon/synthesis/Synthesizer.scala b/src/main/scala/leon/synthesis/Synthesizer.scala index 78ca68e29..fe8f2ef8f 100644 --- a/src/main/scala/leon/synthesis/Synthesizer.scala +++ b/src/main/scala/leon/synthesis/Synthesizer.scala @@ -20,11 +20,8 @@ class Synthesizer(val context : LeonContext, val program: Program, val problem: Problem, val rules: Set[Rule], - generateDerivationTrees: Boolean = false, - filterFuns: Option[Set[String]] = None, - parallel: Boolean = false, - firstOnly: Boolean = false, - timeoutMs: Option[Long] = None) { + val options: SynthesizerOptions) { + protected[synthesis] val reporter = context.reporter import reporter.{error,warning,info,fatalError} @@ -33,10 +30,10 @@ class Synthesizer(val context : LeonContext, def synthesize(): Solution = { - val search = if (parallel) { - new ParallelSearch(this, problem, rules) + val search = if (options.parallel) { + new ParallelSearch(this, problem, rules, options.costModel) } else { - new SimpleSearch(this, problem, rules) + new SimpleSearch(this, problem, rules, options.costModel) } val sigINT = new Signal("INT") @@ -60,8 +57,9 @@ class Synthesizer(val context : LeonContext, val diff = System.currentTimeMillis()-ts reporter.info("Finished in "+diff+"ms") - if (generateDerivationTrees) { - new AndOrGraphDotConverter(search.g, firstOnly).writeFile("derivation"+AndOrGraphDotConverterCounter.next()+".dot") + if (options.generateDerivationTrees) { + val converter = new AndOrGraphDotConverter(search.g, options.firstOnly) + converter.writeFile("derivation"+AndOrGraphDotConverterCounter.next()+".dot") } res match { diff --git a/src/main/scala/leon/synthesis/SynthesizerOptions.scala b/src/main/scala/leon/synthesis/SynthesizerOptions.scala new file mode 100644 index 000000000..768c31ce9 --- /dev/null +++ b/src/main/scala/leon/synthesis/SynthesizerOptions.scala @@ -0,0 +1,11 @@ +package leon +package synthesis + +case class SynthesizerOptions( + generateDerivationTrees: Boolean = false, + filterFuns: Option[Set[String]] = None, + parallel: Boolean = false, + firstOnly: Boolean = false, + timeoutMs: Option[Long] = None, + costModel: CostModel = NaiveCostModel +) diff --git a/src/main/scala/leon/synthesis/search/AndOrGraph.scala b/src/main/scala/leon/synthesis/search/AndOrGraph.scala index 297c10fa7..2ac6881a2 100644 --- a/src/main/scala/leon/synthesis/search/AndOrGraph.scala +++ b/src/main/scala/leon/synthesis/search/AndOrGraph.scala @@ -1,21 +1,21 @@ package leon.synthesis.search -trait AOTask[S <: AOSolution] { - def cost: Cost +trait AOTask[S] { } -trait AOAndTask[S <: AOSolution] extends AOTask[S] { +trait AOAndTask[S] extends AOTask[S] { def composeSolution(sols: List[S]): S } -trait AOOrTask[S <: AOSolution] extends AOTask[S] { +trait AOOrTask[S] extends AOTask[S] { } -trait AOSolution { - def cost: Cost +trait AOCostModel[AT <: AOAndTask[S], OT <: AOOrTask[S], S] { + def taskCost(at: AOTask[S]): Cost + def solutionCost(s: S): Cost } -class AndOrGraph[AT <: AOAndTask[S], OT <: AOOrTask[S], S <: AOSolution](val root: OT) { +class AndOrGraph[AT <: AOAndTask[S], OT <: AOOrTask[S], S](val root: OT, val costModel: AOCostModel[AT, OT, S]) { var tree: OrTree = RootNode trait Tree { @@ -40,7 +40,7 @@ class AndOrGraph[AT <: AOAndTask[S], OT <: AOOrTask[S], S <: AOSolution](val roo trait Leaf extends Tree { - def minCost = task.cost + def minCost = costModel.taskCost(task) } trait Node[T <: Tree] extends Tree { @@ -59,11 +59,11 @@ class AndOrGraph[AT <: AOAndTask[S], OT <: AOOrTask[S], S <: AOSolution](val roo val old = minCost minCost = solution match { case Some(s) => - s.cost + costModel.solutionCost(s) case _ => val subCosts = subProblems.values.map(_.minCost) - subCosts.foldLeft(task.cost)(_ + _) + subCosts.foldLeft(costModel.taskCost(task))(_ + _) } if (minCost != old) { Option(parent).foreach(_.updateMin()) @@ -125,7 +125,7 @@ class AndOrGraph[AT <: AOAndTask[S], OT <: AOOrTask[S], S <: AOSolution](val roo class OrNode(val parent: AndNode, var alternatives: Map[AT, AndTree], val task: OT) extends OrTree with Node[AndTree] { var triedAlternatives = Map[AT, AndTree]() var minAlternative: AndTree = _ - var minCost = task.cost + var minCost = costModel.taskCost(task) def updateMin() { if (!alternatives.isEmpty) { @@ -166,7 +166,7 @@ class AndOrGraph[AT <: AOAndTask[S], OT <: AOOrTask[S], S <: AOSolution](val roo def notifySolution(sub: AndTree, sol: S) { solution match { - case Some(preSol) if (preSol.cost < sol.cost) => + case Some(preSol) if (costModel.solutionCost(preSol) < costModel.solutionCost(sol)) => solution = Some(sol) minAlternative = sub diff --git a/src/main/scala/leon/synthesis/search/AndOrGraphDotConverter.scala b/src/main/scala/leon/synthesis/search/AndOrGraphDotConverter.scala index 931e276bd..062ada1d7 100644 --- a/src/main/scala/leon/synthesis/search/AndOrGraphDotConverter.scala +++ b/src/main/scala/leon/synthesis/search/AndOrGraphDotConverter.scala @@ -2,7 +2,7 @@ package leon.synthesis.search class AndOrGraphDotConverter[AT <: AOAndTask[S], OT <: AOOrTask[S], - S <: AOSolution](val g: AndOrGraph[AT, OT, S], firstOnly: Boolean) { + S](val g: AndOrGraph[AT, OT, S], firstOnly: Boolean) { private[this] var _nextID = 0 @@ -97,7 +97,7 @@ class AndOrGraphDotConverter[AT <: AOAndTask[S], (if (t.isSolved) "palegreen" else if (t.isUnsolvable) "firebrick" else "white", "") } - res append " "+name+" [ label = <<TABLE BORDER=\"0\" CELLBORDER=\"1\" CELLSPACING=\"0\"><TR><TD BORDER=\"0\">self: "+t.task.cost.value+" | tree-min: "+t.minCost.value+"</TD></TR><TR><TD BORDER=\"1\" BGCOLOR=\""+color+"\">"+escapeHTML(t.task.toString)+"</TD></TR>"; + res append " "+name+" [ label = <<TABLE BORDER=\"0\" CELLBORDER=\"1\" CELLSPACING=\"0\"><TR><TD BORDER=\"0\">self: "+g.costModel.taskCost(t.task).value+" | tree-min: "+t.minCost.value+"</TD></TR><TR><TD BORDER=\"1\" BGCOLOR=\""+color+"\">"+escapeHTML(t.task.toString)+"</TD></TR>"; if (t.isSolved) { res append "<TR><TD BGCOLOR=\""+color+"\">"+escapeHTML(t.solution.get.toString)+"</TD></TR>" diff --git a/src/main/scala/leon/synthesis/search/AndOrGraphParallelSearch.scala b/src/main/scala/leon/synthesis/search/AndOrGraphParallelSearch.scala index be7ee3498..a9a3201ae 100644 --- a/src/main/scala/leon/synthesis/search/AndOrGraphParallelSearch.scala +++ b/src/main/scala/leon/synthesis/search/AndOrGraphParallelSearch.scala @@ -9,7 +9,7 @@ import akka.dispatch.Await abstract class AndOrGraphParallelSearch[WC, AT <: AOAndTask[S], OT <: AOOrTask[S], - S <: AOSolution](og: AndOrGraph[AT, OT, S]) extends AndOrGraphSearch[AT, OT, S](og) { + S](og: AndOrGraph[AT, OT, S]) extends AndOrGraphSearch[AT, OT, S](og) { def initWorkerContext(w: ActorRef): WC diff --git a/src/main/scala/leon/synthesis/search/AndOrGraphPartialSolution.scala b/src/main/scala/leon/synthesis/search/AndOrGraphPartialSolution.scala index 151f40ae5..2406bc1ca 100644 --- a/src/main/scala/leon/synthesis/search/AndOrGraphPartialSolution.scala +++ b/src/main/scala/leon/synthesis/search/AndOrGraphPartialSolution.scala @@ -2,7 +2,7 @@ package leon.synthesis.search class AndOrGraphPartialSolution[AT <: AOAndTask[S], OT <: AOOrTask[S], - S <: AOSolution](val g: AndOrGraph[AT, OT, S], missing: AT => S) { + S](val g: AndOrGraph[AT, OT, S], missing: AT => S) { def getSolution: S = { diff --git a/src/main/scala/leon/synthesis/search/AndOrGraphSearch.scala b/src/main/scala/leon/synthesis/search/AndOrGraphSearch.scala index 01aff5047..75ee3292f 100644 --- a/src/main/scala/leon/synthesis/search/AndOrGraphSearch.scala +++ b/src/main/scala/leon/synthesis/search/AndOrGraphSearch.scala @@ -2,7 +2,7 @@ package leon.synthesis.search abstract class AndOrGraphSearch[AT <: AOAndTask[S], OT <: AOOrTask[S], - S <: AOSolution](val g: AndOrGraph[AT, OT, S]) { + S](val g: AndOrGraph[AT, OT, S]) { var processing = Set[g.Leaf]() -- GitLab