From 06c268ab5ac99fc3737f7bd466b8c3999b87ee99 Mon Sep 17 00:00:00 2001 From: Etienne Kneuss <etienne.kneuss@epfl.ch> Date: Thu, 9 Oct 2014 17:11:06 +0200 Subject: [PATCH] Synthesis framework refactoring --- library/lang/string/String.scala | 2 + .../scala/leon/synthesis/BoundedSearch.scala | 23 -- src/main/scala/leon/synthesis/CostModel.scala | 54 +-- .../scala/leon/synthesis/Distribution.scala | 179 ++++++++ .../scala/leon/synthesis/ManualSearch.scala | 206 ---------- .../scala/leon/synthesis/ParallelSearch.scala | 70 ---- .../leon/synthesis/PartialSolution.scala | 68 ++++ src/main/scala/leon/synthesis/Rules.scala | 21 +- .../leon/synthesis/SearchCostModel.scala | 18 - .../scala/leon/synthesis/SimpleSearch.scala | 202 --------- src/main/scala/leon/synthesis/Solution.scala | 10 +- .../scala/leon/synthesis/SynthesisPhase.scala | 37 +- .../scala/leon/synthesis/Synthesizer.scala | 63 +-- src/main/scala/leon/synthesis/Task.scala | 145 ------- .../scala/leon/synthesis/TaskRunRule.scala | 18 - .../scala/leon/synthesis/TaskTryRules.scala | 10 - .../leon/synthesis/graph/DotGenerator.scala | 141 +++++++ .../scala/leon/synthesis/graph/Graph.scala | 207 ++++++++++ .../scala/leon/synthesis/graph/Search.scala | 273 +++++++++++++ .../scala/leon/synthesis/rules/AsChoose.scala | 2 +- .../scala/leon/synthesis/rules/Cegis.scala | 32 +- .../scala/leon/synthesis/rules/Ground.scala | 8 +- .../synthesis/rules/OptimisticGround.scala | 12 +- .../scala/leon/synthesis/rules/Tegis.scala | 87 ++-- .../leon/synthesis/search/AndOrGraph.scala | 383 ------------------ .../search/AndOrGraphDotConverter.scala | 129 ------ .../search/AndOrGraphParallelSearch.scala | 183 --------- .../search/AndOrGraphPartialSolution.scala | 39 -- .../synthesis/search/AndOrGraphSearch.scala | 96 ----- .../scala/leon/synthesis/search/Cost.scala | 18 - .../{ASCIITable.scala => ASCIIHelpers.scala} | 16 +- src/main/scala/leon/utils/Simplifiers.scala | 21 + src/main/scala/leon/utils/StreamUtils.scala | 98 +++++ src/main/scala/leon/utils/Timer.scala | 2 +- .../verification/VerificationReport.scala | 2 +- 35 files changed, 1180 insertions(+), 1695 deletions(-) delete mode 100644 src/main/scala/leon/synthesis/BoundedSearch.scala create mode 100644 src/main/scala/leon/synthesis/Distribution.scala delete mode 100644 src/main/scala/leon/synthesis/ManualSearch.scala delete mode 100644 src/main/scala/leon/synthesis/ParallelSearch.scala create mode 100644 src/main/scala/leon/synthesis/PartialSolution.scala delete mode 100644 src/main/scala/leon/synthesis/SearchCostModel.scala delete mode 100644 src/main/scala/leon/synthesis/SimpleSearch.scala delete mode 100644 src/main/scala/leon/synthesis/Task.scala delete mode 100644 src/main/scala/leon/synthesis/TaskRunRule.scala delete mode 100644 src/main/scala/leon/synthesis/TaskTryRules.scala create mode 100644 src/main/scala/leon/synthesis/graph/DotGenerator.scala create mode 100644 src/main/scala/leon/synthesis/graph/Graph.scala create mode 100644 src/main/scala/leon/synthesis/graph/Search.scala delete mode 100644 src/main/scala/leon/synthesis/search/AndOrGraph.scala delete mode 100644 src/main/scala/leon/synthesis/search/AndOrGraphDotConverter.scala delete mode 100644 src/main/scala/leon/synthesis/search/AndOrGraphParallelSearch.scala delete mode 100644 src/main/scala/leon/synthesis/search/AndOrGraphPartialSolution.scala delete mode 100644 src/main/scala/leon/synthesis/search/AndOrGraphSearch.scala delete mode 100644 src/main/scala/leon/synthesis/search/Cost.scala rename src/main/scala/leon/utils/{ASCIITable.scala => ASCIIHelpers.scala} (89%) create mode 100644 src/main/scala/leon/utils/StreamUtils.scala diff --git a/library/lang/string/String.scala b/library/lang/string/String.scala index b47fcded3..89c53a7ee 100644 --- a/library/lang/string/String.scala +++ b/library/lang/string/String.scala @@ -1,6 +1,8 @@ /* Copyright 2009-2014 EPFL, Lausanne */ package leon.lang.string +import leon.annotation._ +@library case class String(chars: leon.collection.List[Char]) { def +(that: String): String = { String(this.chars ++ that.chars) diff --git a/src/main/scala/leon/synthesis/BoundedSearch.scala b/src/main/scala/leon/synthesis/BoundedSearch.scala deleted file mode 100644 index 5430d05d2..000000000 --- a/src/main/scala/leon/synthesis/BoundedSearch.scala +++ /dev/null @@ -1,23 +0,0 @@ -/* Copyright 2009-2014 EPFL, Lausanne */ - -package leon -package synthesis - -class BoundedSearch(synth: Synthesizer, - problem: Problem, - costModel: CostModel, - searchBound: Int) extends SimpleSearch(synth, problem, costModel) { - - def this(synth: Synthesizer, problem: Problem, searchBound: Int) = { - this(synth, problem, synth.options.costModel, searchBound) - } - - override def searchStep() { - val (closed, total) = g.getStatus - if (total > searchBound) { - stop() - } else { - super.searchStep() - } - } -} diff --git a/src/main/scala/leon/synthesis/CostModel.scala b/src/main/scala/leon/synthesis/CostModel.scala index a133d7dee..b4cac6662 100644 --- a/src/main/scala/leon/synthesis/CostModel.scala +++ b/src/main/scala/leon/synthesis/CostModel.scala @@ -6,46 +6,50 @@ package synthesis import purescala.Trees._ import purescala.TreeOps._ -import synthesis.search.Cost - abstract class CostModel(val name: String) { + type Cost = Int + def solutionCost(s: Solution): Cost def problemCost(p: Problem): Cost - def ruleAppCost(app: RuleInstantiation): Cost = new Cost { + def ruleAppCost(app: RuleInstantiation): Cost = { val subSols = app.onSuccess.types.map {t => Solution.simplest(t) }.toList val simpleSol = app.onSuccess(subSols) - val value = simpleSol match { + simpleSol match { case Some(sol) => - solutionCost(sol).value + solutionCost(sol) case None => - problemCost(app.problem).value + problemCost(app.problem) } } } +case class ScaledCostModel(cm: CostModel, scale: Int) extends CostModel(cm.name+"/"+scale) { + def solutionCost(s: Solution): Cost = Math.max(cm.solutionCost(s)/scale, 1) + def problemCost(p: Problem): Cost = Math.max(cm.problemCost(p)/scale, 1) + override def ruleAppCost(app: RuleInstantiation): Cost = Math.max(cm.ruleAppCost(app)/scale, 1) +} + object CostModel { - def default: CostModel = WeightedBranchesCostModel + def default: CostModel = ScaledCostModel(WeightedBranchesCostModel, 5) def all: Set[CostModel] = Set( - NaiveCostModel, - WeightedBranchesCostModel + ScaledCostModel(NaiveCostModel, 5), + ScaledCostModel(WeightedBranchesCostModel, 5) ) } case object NaiveCostModel extends CostModel("Naive") { - def solutionCost(s: Solution): Cost = new Cost { - val value = { - val chooses = collectChooses(s.toExpr) - val chooseCost = chooses.foldLeft(0)((i, c) => i + problemCost(Problem.fromChoose(c)).value) + def solutionCost(s: Solution): Cost = { + val chooses = collectChooses(s.toExpr) + val chooseCost = chooses.foldLeft(0)((i, c) => i + problemCost(Problem.fromChoose(c))) - formulaSize(s.toExpr) + chooseCost - } + (formulaSize(s.toExpr) + chooseCost)/5+1 } - def problemCost(p: Problem): Cost = new Cost { - val value = p.xs.size + def problemCost(p: Problem): Cost = { + 1 } } @@ -87,19 +91,15 @@ case object WeightedBranchesCostModel extends CostModel("WeightedBranches") { bc.cost } - def solutionCost(s: Solution): Cost = new Cost { - val value = { - val chooses = collectChooses(s.toExpr) - val chooseCost = chooses.foldLeft(0)((i, c) => i + problemCost(Problem.fromChoose(c)).value) + def solutionCost(s: Solution): Cost = { + val chooses = collectChooses(s.toExpr) + val chooseCost = chooses.foldLeft(0)((i, c) => i + problemCost(Problem.fromChoose(c))) - formulaSize(s.toExpr) + branchesCost(s.toExpr) + chooseCost - } + formulaSize(s.toExpr) + branchesCost(s.toExpr) + chooseCost } - def problemCost(p: Problem): Cost = new Cost { - val value = { - p.xs.size - } + def problemCost(p: Problem): Cost = { + p.xs.size } } diff --git a/src/main/scala/leon/synthesis/Distribution.scala b/src/main/scala/leon/synthesis/Distribution.scala new file mode 100644 index 000000000..7c6a07040 --- /dev/null +++ b/src/main/scala/leon/synthesis/Distribution.scala @@ -0,0 +1,179 @@ +/* Copyright 2009-2014 EPFL, Lausanne */ + +package leon.synthesis + +class Distribution(val span: Int, val values: Array[Long], val total: Long) extends Ordered[Distribution] { + def and(that: Distribution): Distribution = { val res = (this, that) match { + case (d1, d2) if d1.total == 0 => + d1 + + case (d1, d2) if d2.total == 0 => + d2 + + case (d1: PointDistribution, d2: PointDistribution) => + if (d1.at + d2.at >= span) { + Distribution.empty(span) + } else { + new PointDistribution(span, d1.at+d2.at) + } + + case (d: PointDistribution, o) => + val a = Array.fill(span)(0l) + + val base = d.at + var innerTotal = 0l; + var i = d.at; + while(i < span) { + val v = o.values(i-base) + a(i) = v + innerTotal += v + i += 1 + } + + if (innerTotal == 0) { + Distribution.empty(span) + } else { + new Distribution(span, a, total) + } + + case (o, d: PointDistribution) => + val a = Array.fill(span)(0l) + + val base = d.at + var innerTotal = 0l; + var i = d.at; + while(i < span) { + val v = o.values(i-base) + a(i) = v + innerTotal += v + i += 1 + } + + if (innerTotal == 0) { + Distribution.empty(span) + } else { + new Distribution(span, a, total) + } + + case (left, right) => + if (left == right) { + left + } else { + val a = Array.fill(span)(0l) + var innerTotal = 0l; + var i = 0; + while (i < span) { + var j = 0; + while (j < span) { + if (i+j < span) { + val lv = left.values(i) + val rv = right.values(j) + + a(i+j) += lv*rv + innerTotal += lv*rv + } + j += 1 + } + i += 1 + } + + if (innerTotal == 0) { + Distribution.empty(span) + } else { + new Distribution(span, a, left.total * right.total) + } + } + } + //println("And of "+this+" and "+that+" = "+res) + res + } + + def or(that: Distribution): Distribution = (this, that) match { + case (d1, d2) if d1.total == 0 => + d2 + + case (d1, d2) if d2.total == 0 => + d1 + + case (d1: PointDistribution, d2: PointDistribution) => + if (d1.at < d2.at) { + d1 + } else { + d2 + } + + case (d1, d2) => + if (d1.weightedSum < d2.weightedSum) { + //if (d1.firstNonZero < d2.firstNonZero) { + d1 + } else { + d2 + } + } + + lazy val firstNonZero: Int = { + if (total == 0) { + span + } else { + var i = 0; + var continue = true; + while (continue && i < span) { + if (values(i) != 0l) { + continue = false + } + i += 1 + } + i + } + } + + lazy val weightedSum: Double = { + var res = 0d; + var i = 0; + while (i < span) { + res += (1d*i*values(i))/total + i += 1 + } + res + } + + override def toString: String = { + "Tot:"+total+"(at "+firstNonZero+")" + } + + def compare(that: Distribution) = { + this.firstNonZero - that.firstNonZero + } +} + +object Distribution { + def point(span: Int, at: Int) = { + if (span <= at) { + empty(span) + } else { + new PointDistribution(span, at) + } + } + + def empty(span: Int) = new Distribution(span, Array[Long](), 0l) + def uniform(span: Int, v: Long, total: Int) = { + new Distribution(span, Array.fill(span)(v), total) + } + + def uniformFrom(span: Int, from: Int, ratio: Double) = { + var i = from + val a = Array.fill(span)(0l) + while(i < span) { + a(i) = 1 + i += 1 + } + + new Distribution(span, a, ((span-from).toDouble*(1/ratio)).toInt) + } +} + +class PointDistribution(span: Int, val at: Int) extends Distribution(span, new Array[Long](span).updated(at, 1l), 1l) { + override lazy val firstNonZero: Int = { + at + } +} diff --git a/src/main/scala/leon/synthesis/ManualSearch.scala b/src/main/scala/leon/synthesis/ManualSearch.scala deleted file mode 100644 index b7e728d23..000000000 --- a/src/main/scala/leon/synthesis/ManualSearch.scala +++ /dev/null @@ -1,206 +0,0 @@ -/* Copyright 2009-2014 EPFL, Lausanne */ - -package leon -package synthesis - -import leon.purescala.ScalaPrinter - -class ManualSearch(synth: Synthesizer, - problem: Problem, - costModel: CostModel) extends SimpleSearch(synth, problem, costModel) { - - def this(synth: Synthesizer, problem: Problem) = { - this(synth, problem, synth.options.costModel) - } - - import synth.reporter._ - - var cd = List[Int]() - var cmdQueue = List[String]() - - def printGraph() { - def pathToString(path: List[Int]): String = { - val p = path.reverse.drop(cd.size) - if (p.isEmpty) { - "" - } else { - " "+p.mkString(" ") - } - } - - def title(str: String) = "\u001b[1m" + str + "\u001b[0m" - def failed(str: String) = "\u001b[31m" + str + "\u001b[0m" - def solved(str: String) = "\u001b[32m" + str + "\u001b[0m" - - def displayApp(app: RuleInstantiation): String = { - f"(${costModel.ruleAppCost(app).value}%3d) $app" - } - - def displayProb(p: Problem): String = { - f"(${costModel.problemCost(p).value}%3d) $p" - } - - def traversePathFrom(n: g.Tree, prefix: List[Int]) { - n match { - case l: g.AndLeaf => - if (prefix.endsWith(cd.reverse)) { - println(pathToString(prefix)+" \u2508 "+displayApp(l.task.app)) - } - case l: g.OrLeaf => - if (prefix.endsWith(cd.reverse)) { - println(pathToString(prefix)+" \u2508 "+displayProb(l.task.p)) - } - case an: g.AndNode => - if (an.isSolved) { - if (prefix.endsWith(cd.reverse)) { - println(solved(pathToString(prefix)+" \u2508 "+displayApp(an.task.app))) - } - } else { - if (prefix.endsWith(cd.reverse)) { - println(title(pathToString(prefix)+" \u2510 "+displayApp(an.task.app))) - } - for ((st, i) <- an.subTasks.zipWithIndex) { - traversePathFrom(an.subProblems(st), i :: prefix) - } - } - - case on: g.OrNode => - if (on.isSolved) { - if (prefix.endsWith(cd.reverse)) { - println(solved(pathToString(prefix)+on.task.p)) - } - } else { - if (prefix.endsWith(cd.reverse)) { - println(title(pathToString(prefix)+" \u2510 "+displayProb(on.task.p))) - } - for ((at, i) <- on.altTasks.zipWithIndex) { - if (on.triedAlternatives contains at) { - if (prefix.endsWith(cd.reverse)) { - println(failed(pathToString(i :: prefix)+" \u2508 "+displayApp(at.app))) - } - } else { - traversePathFrom(on.alternatives(at), i :: prefix) - } - } - } - } - } - - println("-"*80) - traversePathFrom(g.tree, List()) - println("-"*80) - } - - override def stop() { - super.stop() - cmdQueue = "q" :: Nil - continue = false - } - - var continue = true - - - override def nextLeaf(): Option[g.Leaf] = { - g.tree match { - case l: g.Leaf => - Some(l) - case _ => - - var res: Option[g.Leaf] = None - continue = true - - while(continue) { - printGraph() - - try { - - print("Next action? (q to quit) "+cd.mkString(" ")+" $ ") - val line = if (cmdQueue.isEmpty) { - scala.io.StdIn.readLine() - } else { - val n = cmdQueue.head - println(n) - cmdQueue = cmdQueue.tail - n - } - if (line == "q") { - continue = false - res = None - } else if (line startsWith "cd") { - val parts = line.split("\\s+").toList - - parts match { - case List("cd") => - cd = List() - case List("cd", "..") => - if (cd.size > 0) { - cd = cd.dropRight(1) - } - case "cd" :: parts => - cd = cd ::: parts.map(_.toInt) - case _ => - } - - } else if (line startsWith "p") { - val parts = line.split("\\s+").toList.tail.map(_.toInt) - traversePath(cd ::: parts) match { - case Some(n) => - println("#"*80) - println("AT:"+n.task) - val sp = programAt(n) - sp.acc.foreach(fd => println(ScalaPrinter(fd))) - println("$"*20) - println("ROOT: "+sp.fd.id) - case _ => - } - - } else { - val parts = line.split("\\s+").toList - - val c = parts.head.toInt - cmdQueue = cmdQueue ::: parts.tail - - traversePath(cd ::: c :: Nil) match { - case Some(l: g.Leaf) => - res = Some(l) - cd = cd ::: c :: Nil - continue = false - case Some(_) => - cd = cd ::: c :: Nil - case None => - error("Invalid path") - } - } - } catch { - case e: java.lang.NumberFormatException => - - case e: java.io.IOException => - continue = false - - case e: Throwable => - error("Woops: "+e.getMessage()) - e.printStackTrace() - } - } - res - } - } - - override def searchStep() { - super.searchStep() - - var continue = cd.size > 0 - while(continue) { - traversePath(cd) match { - case Some(t) if !t.isSolved => - continue = false - case Some(t) => - cd = cd.dropRight(1) - case None => - cd = cd.dropRight(1) - } - continue = continue && (cd.size > 0) - } - } - -} diff --git a/src/main/scala/leon/synthesis/ParallelSearch.scala b/src/main/scala/leon/synthesis/ParallelSearch.scala deleted file mode 100644 index b0af9cab8..000000000 --- a/src/main/scala/leon/synthesis/ParallelSearch.scala +++ /dev/null @@ -1,70 +0,0 @@ -/* Copyright 2009-2014 EPFL, Lausanne */ - -package leon -package synthesis - -import synthesis.search._ -import akka.actor._ -import solvers.z3._ - -class ParallelSearch(synth: Synthesizer, - problem: Problem, - costModel: CostModel, - nWorkers: Int) extends AndOrGraphParallelSearch[SynthesisContext, TaskRunRule, TaskTryRules, Solution](new AndOrGraph(TaskTryRules(problem), SearchCostModel(costModel)), nWorkers) { - - def this(synth: Synthesizer, problem: Problem, nWorkers: Int) = { - this(synth, problem, synth.options.costModel, nWorkers) - } - - import synth.reporter._ - - // This is HOT shared memory, used only in stop() for shutting down solvers! - private[this] var contexts = List[SynthesisContext]() - - def initWorkerContext(wr: ActorRef) = { - val ctx = SynthesisContext.fromSynthesizer(synth) - - synchronized { - contexts = ctx :: contexts - } - - ctx - } - - def expandAndTask(ref: ActorRef, sctx: SynthesisContext)(t: TaskRunRule) = { - val prefix = "[%-20s] ".format(Option(t.rule).getOrElse("?")) - - t.app.apply(sctx) match { - case RuleSuccess(sol, isTrusted) => - synth.synchronized { - info(prefix+"Got: "+t.problem) - info(prefix+"Solved"+(if(isTrusted) "" else " (untrusted)")+" with: "+sol) - } - - ExpandSuccess(sol, isTrusted) - case RuleDecomposed(sub) => - synth.synchronized { - info(prefix+"Got: "+t.problem) - info(prefix+"Decomposed into:") - for(p <- sub) { - info(prefix+" - "+p) - } - } - - Expanded(sub.map(TaskTryRules(_))) - - case RuleApplicationImpossible => - ExpandFailure() - } - } - - def expandOrTask(ref: ActorRef, sctx: SynthesisContext)(t: TaskTryRules) = { - val apps = Rules.getInstantiations(sctx, t.p) - - if (apps.nonEmpty) { - Expanded(apps.map(TaskRunRule(_))) - } else { - ExpandFailure() - } - } -} diff --git a/src/main/scala/leon/synthesis/PartialSolution.scala b/src/main/scala/leon/synthesis/PartialSolution.scala new file mode 100644 index 000000000..039342d3c --- /dev/null +++ b/src/main/scala/leon/synthesis/PartialSolution.scala @@ -0,0 +1,68 @@ +/* Copyright 2009-2014 EPFL, Lausanne */ + +package leon +package synthesis + +import graph._ + +class PartialSolution(g: Graph, includeUntrusted: Boolean) { + + def includeSolution(s: Solution) = { + includeUntrusted || s.isTrusted + } + + def completeProblem(p: Problem) = { + Solution.choose(p) + } + + + def getSolution(): Solution = { + getSolutionFor(g.root) + } + + def getSolutionFor(n: g.Node): Solution = { + n match { + case on: g.OrNode => + if (on.isSolved) { + val sols = on.generateSolutions() + sols.find(includeSolution) match { + case Some(sol) => + return sol + case _ => + } + } + + if (n.isExpanded) { + val descs = on.descendents.filter(_.isClosed) + if (descs.isEmpty) { + completeProblem(on.p) + } else { + getSolutionFor(descs.minBy(_.costDist)) + } + } else { + completeProblem(on.p) + } + case an: g.AndNode => + if (an.isSolved) { + val sols = an.generateSolutions() + sols.find(includeSolution) match { + case Some(sol) => + return sol + case _ => + } + } + + if (n.isExpanded) { + an.ri.onSuccess(n.descendents.map(getSolutionFor)) match { + case Some(sol) => + sol + + case None => + completeProblem(an.ri.problem) + } + } else { + completeProblem(an.ri.problem) + } + } + } +} diff --git a/src/main/scala/leon/synthesis/Rules.scala b/src/main/scala/leon/synthesis/Rules.scala index b1b5cd756..431eeba3b 100644 --- a/src/main/scala/leon/synthesis/Rules.scala +++ b/src/main/scala/leon/synthesis/Rules.scala @@ -77,15 +77,22 @@ abstract class RuleInstantiation( val description: String, val priority: RulePriority) { - def apply(sctx: SynthesisContext): RuleApplicationResult + def apply(sctx: SynthesisContext): RuleApplication override def toString = description } -sealed abstract class RuleApplicationResult -case class RuleSuccess(solution: Solution, isTrusted: Boolean = true) extends RuleApplicationResult -case class RuleDecomposed(sub: List[Problem]) extends RuleApplicationResult -case object RuleApplicationImpossible extends RuleApplicationResult +sealed abstract class RuleApplication +case class RuleClosed(solutions: Stream[Solution]) extends RuleApplication +case class RuleExpanded(sub: List[Problem]) extends RuleApplication + +object RuleClosed { + def apply(s: Solution): RuleClosed = RuleClosed(Stream(s)) +} + +object RuleFailed { + def apply(): RuleClosed = RuleClosed(Stream.empty) +} sealed abstract class RulePriority(val v: Int) extends Ordered[RulePriority] { def compare(that: RulePriority) = this.v - that.v @@ -121,7 +128,7 @@ object RuleInstantiation { val subTypes = sub.map(p => TupleType(p.xs.map(_.getType))) new RuleInstantiation(problem, rule, new SolutionCombiner(sub.size, subTypes, onSuccess), description, priority) { - def apply(sctx: SynthesisContext) = RuleDecomposed(sub) + def apply(sctx: SynthesisContext) = RuleExpanded(sub) } } @@ -137,7 +144,7 @@ object RuleInstantiation { solution: Solution, priority: RulePriority): RuleInstantiation = { new RuleInstantiation(problem, rule, new SolutionCombiner(0, Seq(), ls => Some(solution)), "Solve with "+solution, priority) { - def apply(sctx: SynthesisContext) = RuleSuccess(solution) + def apply(sctx: SynthesisContext) = RuleClosed(solution) } } } diff --git a/src/main/scala/leon/synthesis/SearchCostModel.scala b/src/main/scala/leon/synthesis/SearchCostModel.scala deleted file mode 100644 index a4b9a62d4..000000000 --- a/src/main/scala/leon/synthesis/SearchCostModel.scala +++ /dev/null @@ -1,18 +0,0 @@ -/* Copyright 2009-2014 EPFL, Lausanne */ - -package leon -package synthesis - -import synthesis.search._ - -case class SearchCostModel(cm: CostModel) extends AOCostModel[TaskRunRule, TaskTryRules, Solution] { - def taskCost(t: AOTask[Solution]) = t match { - case ttr: TaskRunRule => - cm.ruleAppCost(ttr.app) - case trr: TaskTryRules => - cm.problemCost(trr.p) - } - - def solutionCost(s: Solution) = cm.solutionCost(s) -} - diff --git a/src/main/scala/leon/synthesis/SimpleSearch.scala b/src/main/scala/leon/synthesis/SimpleSearch.scala deleted file mode 100644 index 35f7f1ce5..000000000 --- a/src/main/scala/leon/synthesis/SimpleSearch.scala +++ /dev/null @@ -1,202 +0,0 @@ -/* Copyright 2009-2014 EPFL, Lausanne */ - -package leon -package synthesis - -import leon.utils._ -import purescala.Definitions.FunDef -import synthesis.search._ - -class SimpleSearch(synth: Synthesizer, - problem: Problem, - costModel: CostModel) extends AndOrGraphSearch[TaskRunRule, TaskTryRules, Solution](new AndOrGraph(TaskTryRules(problem), SearchCostModel(costModel))) with Interruptible { - - def this(synth: Synthesizer, problem: Problem) = { - this(synth, problem, synth.options.costModel) - } - - import synth.reporter._ - - val sctx = SynthesisContext.fromSynthesizer(synth) - - def expandAndTask(t: TaskRunRule): ExpandResult[TaskTryRules] = { - val prefix = "[%-20s] ".format(Option(t.rule).getOrElse("?")) - - info(prefix+"Got: "+t.problem) - t.app.apply(sctx) match { - case RuleSuccess(sol, isTrusted) => - info(prefix+"Solved"+(if(isTrusted) "" else " (untrusted)")+" with: "+sol) - - ExpandSuccess(sol, isTrusted) - case RuleDecomposed(sub) => - info(prefix+"Decomposed into:") - for(p <- sub) { - info(prefix+" - "+p) - } - - Expanded(sub.map(TaskTryRules(_))) - - case RuleApplicationImpossible => - info(prefix+"Failed") - - ExpandFailure() - } - } - - def expandOrTask(t: TaskTryRules): ExpandResult[TaskRunRule] = { - val apps = Rules.getInstantiations(sctx, t.p) - - if (apps.nonEmpty) { - Expanded(apps.map(TaskRunRule(_))) - } else { - ExpandFailure() - } - } - - case class SubProgram(p: Problem, fd: FunDef, acc: Set[FunDef]) - - def programAt(n: g.Tree): SubProgram = { - import purescala.TypeTrees._ - import purescala.Common._ - import purescala.TreeOps.replace - import purescala.Trees._ - import purescala.Definitions._ - - def programFrom(from: g.AndNode, sp: SubProgram): SubProgram = { - if (from.parent.parent eq null) { - sp - } else { - val at = from.parent.parent - val res = bestProgramForAnd(at, Map(from.parent -> sp)) - programFrom(at, res) - } - } - - def bestProgramForOr(on: g.OrTree): SubProgram = { - val problem: Problem = on.task.p - - val fd = problemToFunDef(problem) - - SubProgram(problem, fd, Set(fd)) - } - - def fundefToSol(p: Problem, fd: FunDef): Solution = { - Solution(BooleanLiteral(true), Set(), FunctionInvocation(fd.typed, p.as.map(Variable(_)))) - } - - def solToSubProgram(p: Problem, s: Solution): SubProgram = { - val fd = problemToFunDef(p) - fd.precondition = Some(s.pre) - fd.body = Some(s.term) - - SubProgram(p, fd, Set(fd)) - } - - def bestProgramForAnd(an: g.AndNode, subPrograms: Map[g.OrTree, SubProgram]): SubProgram = { - val subSubPrograms = an.subTasks.map(an.subProblems).map( ot => - subPrograms.getOrElse(ot, bestProgramForOr(ot)) - ) - - val allFd = subSubPrograms.flatMap(_.acc) - val subSolutions = subSubPrograms.map(ssp => fundefToSol(ssp.p, ssp.fd)) - - val sp = solToSubProgram(an.task.problem, an.task.composeSolution(subSolutions).get) - - sp.copy(acc = sp.acc ++ allFd) - } - - def problemToFunDef(p: Problem): FunDef = { - - val ret = if (p.xs.size == 1) { - p.xs.head.getType - } else { - TupleType(p.xs.map(_.getType)) - } - - val freshAs = p.as.map(_.freshen) - - val map = (p.as.map(Variable(_): Expr) zip freshAs.map(Variable(_): Expr)).toMap - - val res = Variable(FreshIdentifier("res").setType(ret)) - - val mapPost: Map[Expr, Expr] = if (p.xs.size > 1) { - p.xs.zipWithIndex.map{ case (id, i) => - Variable(id) -> TupleSelect(res, i+1) - }.toMap - } else { - Map(Variable(p.xs.head) -> res) - } - - val fd = new FunDef(FreshIdentifier("chimp", true), Nil, ret, freshAs.map(id => ValDef(id, id.getType)),DefType.MethodDef) - fd.precondition = Some(replace(map, p.pc)) - fd.postcondition = Some((res.id, replace(map++mapPost, p.phi))) - - fd - } - - n match { - case an: g.AndNode => - programFrom(an, bestProgramForAnd(an, Map.empty)) - - case on: g.OrNode => - if (on.parent ne null) { - programAt(on.parent) - } else { - bestProgramForOr(on) - } - } - } - - - def searchStep() { - val timer = synth.context.timers.synthesis.nextLeaf.start() - val nl = nextLeaf() - timer.stop() - - nl match { - case Some(l) => - l match { - case al: g.AndLeaf => - val sub = expandAndTask(al.task) - - val timer = synth.context.timers.synthesis.expand.start() - onExpansion(al, sub) - timer.stop() - case ol: g.OrLeaf => - val sub = expandOrTask(ol.task) - - val timer = synth.context.timers.synthesis.expand.start() - onExpansion(ol, sub) - timer.stop() - } - case None => - stop() - } - } - - sctx.context.interruptManager.registerForInterrupts(this) - - def interrupt() { - stop() - } - - def recoverInterrupt() { - shouldStop = false - } - - private var shouldStop = false - - override def stop() { - super.stop() - shouldStop = true - } - - def search(): Option[(Solution, Boolean)] = { - shouldStop = false - - while (!g.tree.isSolved && !shouldStop) { - searchStep() - } - g.tree.solution.map(s => (s, g.tree.isTrustworthy)) - } -} diff --git a/src/main/scala/leon/synthesis/Solution.scala b/src/main/scala/leon/synthesis/Solution.scala index 65edb8ff8..94b80d36b 100644 --- a/src/main/scala/leon/synthesis/Solution.scala +++ b/src/main/scala/leon/synthesis/Solution.scala @@ -16,7 +16,7 @@ import leon.utils.Simplifiers // Defines a synthesis solution of the form: // ⟨ P | T ⟩ -class Solution(val pre: Expr, val defs: Set[FunDef], val term: Expr) { +class Solution(val pre: Expr, val defs: Set[FunDef], val term: Expr, val isTrusted: Boolean = true) { override def toString = "⟨ "+pre+" | "+defs.mkString(" ")+" "+term+" ⟩" def guardedTerm = { @@ -60,8 +60,8 @@ class Solution(val pre: Expr, val defs: Set[FunDef], val term: Expr) { object Solution { def simplify(e: Expr) = simplifyLets(e) - def apply(pre: Expr, defs: Set[FunDef], term: Expr) = { - new Solution(simplify(pre), defs, simplify(term)) + def apply(pre: Expr, defs: Set[FunDef], term: Expr, isTrusted: Boolean = true) = { + new Solution(simplify(pre), defs, simplify(term), isTrusted) } def unapply(s: Solution): Option[(Expr, Set[FunDef], Expr)] = if (s eq null) None else Some((s.pre, s.defs, s.term)) @@ -78,4 +78,8 @@ object Solution { def simplest(t: TypeTree): Solution = { new Solution(BooleanLiteral(true), Set(), simplestValue(t)) } + + def failed(p: Problem): Solution = { + new Solution(BooleanLiteral(true), Set(), Error("Failed").setType(TupleType(p.xs.map(_.getType)))) + } } diff --git a/src/main/scala/leon/synthesis/SynthesisPhase.scala b/src/main/scala/leon/synthesis/SynthesisPhase.scala index cf1321c95..c58e8fb3d 100644 --- a/src/main/scala/leon/synthesis/SynthesisPhase.scala +++ b/src/main/scala/leon/synthesis/SynthesisPhase.scala @@ -9,6 +9,9 @@ import purescala.Trees._ import purescala.Common._ import purescala.ScalaPrinter import purescala.Definitions.{Program, FunDef} +import leon.utils.ASCIIHelpers + +import graph._ object SynthesisPhase extends LeonPhase[Program, Program] { val name = "Synthesis" @@ -116,6 +119,7 @@ object SynthesisPhase extends LeonPhase[Program, Program] { options } + def run(ctx: LeonContext)(p: Program): Program = { val options = processOptions(ctx) @@ -132,32 +136,27 @@ object SynthesisPhase extends LeonPhase[Program, Program] { var functions = Set[FunDef]() - val results = chooses.map { ci => - val (sol, isComplete) = ci.synthesizer.synthesize() + chooses.foreach { ci => + val (search, solutions) = ci.synthesizer.validate(ci.synthesizer.synthesize()) val fd = ci.fd + if (ci.synthesizer.options.generateDerivationTrees) { + val dot = new DotGenerator(search.g) + dot.writeFile("derivation"+DotGenerator.nextId()+".dot") + } + + val (sol, _) = solutions.head + val expr = sol.toSimplifiedExpr(ctx, p) fd.body = fd.body.map(b => replace(Map(ci.source -> expr), b)) - functions += fd + } - ci -> expr - }.toMap - - if (options.inPlace) { - for (file <- ctx.files) { - new FileInterface(ctx.reporter).updateFile(file, results) - } - } else { - for (fd <- functions) { - val middle = " "+fd.id.name+" " - val remSize = (80-middle.length) - ctx.reporter.info("-"*math.floor(remSize/2).toInt+middle+"-"*math.ceil(remSize/2).toInt) - - ctx.reporter.info(ScalaPrinter(fd)) - ctx.reporter.info("") - } + for (fd <- functions) { + ctx.reporter.info(ASCIIHelpers.title(fd.id.name)) + ctx.reporter.info(ScalaPrinter(fd)) + ctx.reporter.info("") } p diff --git a/src/main/scala/leon/synthesis/Synthesizer.scala b/src/main/scala/leon/synthesis/Synthesizer.scala index 5ac84fa23..5598f59b8 100644 --- a/src/main/scala/leon/synthesis/Synthesizer.scala +++ b/src/main/scala/leon/synthesis/Synthesizer.scala @@ -15,7 +15,7 @@ import solvers.z3._ import java.io.File -import synthesis.search._ +import synthesis.graph._ class Synthesizer(val context : LeonContext, val functionContext: FunDef, @@ -25,45 +25,49 @@ class Synthesizer(val context : LeonContext, val reporter = context.reporter - def synthesize(): (Solution, Boolean) = { - - val search = if (options.manualSearch) { - new ManualSearch(this, problem) - } else if (options.searchWorkers > 1) { - new ParallelSearch(this, problem, options.searchWorkers) - } else { - options.searchBound match { - case Some(b) => - new BoundedSearch(this, problem, b) + def getSearch(): Search = { + if (options.manualSearch) { + new ManualSearch(context, problem, options.costModel) + } else if (options.searchWorkers > 1) { + ??? + //new ParallelSearch(this, problem, options.searchWorkers) + } else { + new SimpleSearch(context, problem, options.costModel, options.searchBound) + } + } - case None => - new SimpleSearch(this, problem) - } - } + def synthesize(): (Search, Stream[Solution]) = { + val s = getSearch(); val t = context.timers.synthesis.search.start() - val res = search.search() + val sctx = SynthesisContext.fromSynthesizer(this) + val sols = s.search(sctx) val diff = t.stop() reporter.info("Finished in "+diff+"ms") - if (options.generateDerivationTrees) { - val converter = new AndOrGraphDotConverter(search.g, options.firstOnly) - converter.writeFile("derivation"+AndOrGraphDotConverterCounter.next()+".dot") - } + (s, sols) + } - res match { - case Some((solution, true)) => - (solution, true) - case Some((sol, false)) => - validateSolution(search, sol, 5000L) - case None => - (new AndOrGraphPartialSolution(search.g, (task: TaskRunRule) => Solution.choose(task.problem), true).getSolution, false) + def validate(results: (Search, Stream[Solution])): (Search, Stream[(Solution, Boolean)]) = { + val (s, sols) = results + + val result = sols.map { + case sol if sol.isTrusted => + (sol, true) + case sol => + validateSolution(s, sol, 5000L) } + + (s, if (result.isEmpty) { + List((new PartialSolution(s.g, true).getSolution, false)).toStream + } else { + result + }) } - def validateSolution(search: AndOrGraphSearch[TaskRunRule, TaskTryRules, Solution], sol: Solution, timeoutMs: Long): (Solution, Boolean) = { + def validateSolution(search: Search, sol: Solution, timeoutMs: Long): (Solution, Boolean) = { import verification.AnalysisPhase._ import verification.VerificationContext @@ -87,7 +91,7 @@ class Synthesizer(val context : LeonContext, reporter.warning("Solution was invalid:") reporter.warning(fds.map(ScalaPrinter(_)).mkString("\n\n")) reporter.warning(vcreport.summaryString) - (new AndOrGraphPartialSolution(search.g, (task: TaskRunRule) => Solution.choose(task.problem), false).getSolution, false) + (new PartialSolution(search.g, false).getSolution, false) } } @@ -119,3 +123,4 @@ class Synthesizer(val context : LeonContext, (npr, newDefs) } } + diff --git a/src/main/scala/leon/synthesis/Task.scala b/src/main/scala/leon/synthesis/Task.scala deleted file mode 100644 index c1f16cba4..000000000 --- a/src/main/scala/leon/synthesis/Task.scala +++ /dev/null @@ -1,145 +0,0 @@ -/* Copyright 2009-2014 EPFL, Lausanne */ - -package leon -package synthesis - -/* -class Task(synth: Synthesizer, - val parent: Task, - val problem: Problem, - val rule: Rule) extends Ordered[Task] { - - def compare(that: Task) = { - val cproblem = -(this.problem.complexity compare that.problem.complexity) // problem DESC - - if (cproblem == 0) { - // On equal complexity, order tasks by rule priority - this.rule.priority-that.rule.priority - } else { - cproblem - } - } - - def isBetterSolutionThan(sol: Solution, osol: Option[Solution]): Boolean = { - osol match { - case Some(s) => s.complexity > sol.complexity - case None => true - } - } - - var solution: Option[Solution] = None - var solver: Option[RuleApplication] = None - - var alternatives = Traversable[RuleApplication]() - - var minComplexity: AbsSolComplexity = new FixedSolComplexity(0) - - class TaskStep(val subProblems: List[Problem]) { - var subSolutions = Map[Problem, Solution]() - var subSolvers = Map[Problem, Task]() - var failures = Set[Rule]() - } - - class RuleApplication( - val initProblems: List[Problem], - val allNextSteps: List[List[Solution] => List[Problem]], - val onSuccess: List[Solution] => (Solution, Boolean)) { - - var allSteps = List(new TaskStep(initProblems)) - var nextSteps = allNextSteps - def currentStep = allSteps.head - - def isSolvedFor(p: Problem) = allSteps.exists(_.subSolutions contains p) - - def partlySolvedBy(t: Task, s: Solution) { - if (currentStep.subProblems.contains(t.problem)) { - if (isBetterSolutionThan(s, currentStep.subSolutions.get(t.problem))) { - currentStep.subSolutions += t.problem -> s - currentStep.subSolvers += t.problem -> t - - if (currentStep.subSolutions.size == currentStep.subProblems.size) { - val solutions = currentStep.subProblems map currentStep.subSolutions - - if (!nextSteps.isEmpty) { - // Advance to the next step - val newProblems = nextSteps.head.apply(solutions) - nextSteps = nextSteps.tail - - synth.addProblems(Task.this, newProblems) - - allSteps = new TaskStep(newProblems) :: allSteps - } else { - onSuccess(solutions) match { - case (s, true) => - if (isBetterSolutionThan(s, solution)) { - solution = Some(s) - solver = Some(this) - - parent.partlySolvedBy(Task.this, s) - } - case _ => - // solution is there, but it is incomplete (precondition not strongest) - //parent.partlySolvedBy(this, Solution.choose(problem)) - } - } - } - } - } - } - - val minComplexity: AbsSolComplexity = { - val simplestSubSolutions = allNextSteps.foldLeft(initProblems.map(Solution.basic(_))){ - (sols, step) => step(sols).map(Solution.basic(_)) - } - val simplestSolution = onSuccess(simplestSubSolutions)._1 - new FixedSolComplexity(parent.minComplexity.value + simplestSolution.complexity.value) - } - } - - // Is this subproblem already fixed? - def isSolvedFor(problem: Problem): Boolean = parent.isSolvedFor(this.problem) || (alternatives.exists(_.isSolvedFor(problem))) - - def partlySolvedBy(t: Task, s: Solution) { - alternatives.foreach(_.partlySolvedBy(t, s)) - } - - def run(): List[Problem] = { - rule.applyOn(this) match { - case RuleSuccess(solution) => - // Solved - this.solution = Some(solution) - parent.partlySolvedBy(this, solution) - - Nil - - case RuleAlternatives(xs) if xs.isEmpty => - // Inapplicable - Nil - - case RuleAlternatives(steps) => - this.alternatives = steps.map( step => new RuleApplication(step.subProblems, step.interSteps, step.onSuccess) ) - - this.alternatives.flatMap(_.initProblems).toList - } - } - - override def toString = "Applying "+rule+" on "+problem -} - -class RootTask(synth: Synthesizer, problem: Problem) extends Task(synth, null, problem, null) { - var solverTask: Option[Task] = None - - override def run() = { - List(problem) - } - - override def isSolvedFor(problem: Problem) = solverTask.isDefined - - override def partlySolvedBy(t: Task, s: Solution) { - if (isBetterSolutionThan(s, solution)) { - solution = Some(s) - solverTask = Some(t) - } - } -} -*/ diff --git a/src/main/scala/leon/synthesis/TaskRunRule.scala b/src/main/scala/leon/synthesis/TaskRunRule.scala deleted file mode 100644 index d3b13f25b..000000000 --- a/src/main/scala/leon/synthesis/TaskRunRule.scala +++ /dev/null @@ -1,18 +0,0 @@ -/* Copyright 2009-2014 EPFL, Lausanne */ - -package leon -package synthesis - -import synthesis.search._ - -case class TaskRunRule(app: RuleInstantiation) extends AOAndTask[Solution] { - - val problem = app.problem - val rule = app.rule - - def composeSolution(sols: List[Solution]): Option[Solution] = { - app.onSuccess(sols) - } - - override def toString = rule.name -} diff --git a/src/main/scala/leon/synthesis/TaskTryRules.scala b/src/main/scala/leon/synthesis/TaskTryRules.scala deleted file mode 100644 index 4511fc097..000000000 --- a/src/main/scala/leon/synthesis/TaskTryRules.scala +++ /dev/null @@ -1,10 +0,0 @@ -/* Copyright 2009-2014 EPFL, Lausanne */ - -package leon -package synthesis - -import synthesis.search._ - -case class TaskTryRules(p: Problem) extends AOOrTask[Solution] { - override def toString = p.toString -} diff --git a/src/main/scala/leon/synthesis/graph/DotGenerator.scala b/src/main/scala/leon/synthesis/graph/DotGenerator.scala new file mode 100644 index 000000000..339deea13 --- /dev/null +++ b/src/main/scala/leon/synthesis/graph/DotGenerator.scala @@ -0,0 +1,141 @@ +package leon.synthesis.graph + +import java.io.{File, FileWriter, BufferedWriter} + +import leon.synthesis.Distribution + +class DotGenerator(g: Graph) { + import g.{Node, AndNode, OrNode, RootNode} + + private[this] var _nextID = 0 + def freshName(prefix: String) = { + _nextID += 1 + prefix+_nextID + } + + def writeFile(f: File): Unit = { + val out = new BufferedWriter(new FileWriter(f)) + out.write(asString) + out.close() + } + + def writeFile(path: String): Unit = writeFile(new File(path)) + + + def asString: String = { + val res = new StringBuffer() + + res append "digraph D {\n" + + // Print all nodes + val edges = collectEdges(g.root) + val nodes = edges.flatMap(e => Set(e._1, e._2)) + + var nodesToNames = Map[Node, String]() + + for (n <- nodes) { + val name = freshName("node") + + n match { + case ot: OrNode => + drawNode(res, name, ot) + case at: AndNode => + drawNode(res, name, at) + } + + nodesToNames += n -> name + } + + for ((f,t) <- edges) { + val label = f match { + case ot: OrNode => + "or" + case at: AndNode => + "" + } + + val style = if (f.selected contains t) { + ", style=\"bold\"" + } else { + "" + } + + res append " "+nodesToNames(f)+" -> "+nodesToNames(t) +" [label=\""+label+"\""+style+"]\n"; + } + + res append "}\n" + + res.toString + } + + def distrib(d: Distribution): String = { + if (d.firstNonZero == g.maxCost) { + ">max" + } else { + d.firstNonZero.toString + } + } + + def limit(o: Any, length: Int = 40): String = { + val str = o.toString + if (str.length > length) { + str.substring(0, length-3)+"..." + } else { + str + } + } + + def nodeDesc(n: Node): String = n match { + case an: AndNode => an.ri.toString + case on: OrNode => on.p.toString + } + + def drawNode(res: StringBuffer, name: String, n: Node) { + + def escapeHTML(str: String) = str.replaceAll("&", "&").replaceAll("<", "<").replaceAll(">", ">") + + val color = if (n.isSolved) { + "palegreen" + } else if (n.isClosed) { + "firebrick" + } else if(n.isExpanded) { + "grey80" + } else { + "white" + } + + + res append " "+name+" [ label = <<TABLE BORDER=\"0\" CELLBORDER=\"1\" CELLSPACING=\"0\">" + + //cost + n match { + case an: AndNode => + res append "<TR><TD BORDER=\"0\">"+escapeHTML(distrib(n.costDist)+" ("+distrib(an.selfCost))+")</TD></TR>" + case on: OrNode => + res append "<TR><TD BORDER=\"0\">"+escapeHTML(distrib(n.costDist))+"</TD></TR>" + } + + res append "<TR><TD BORDER=\"1\" BGCOLOR=\""+color+"\">"+escapeHTML(limit(nodeDesc(n)))+"</TD></TR>"; + + if (n.isSolved) { + res append "<TR><TD BGCOLOR=\""+color+"\">"+escapeHTML(limit(n.generateSolutions().head.toString))+"</TD></TR>" + } + + res append "</TABLE>>, shape = \"none\" ];\n" + + } + + private def collectEdges(from: Node): Set[(Node, Node)] = { + from.descendents.flatMap { d => + Set(from -> d) ++ collectEdges(d) + }.toSet + } +} + +object DotGenerator { + private[this] var _nextID = 0 + def nextId() = { + _nextID += 1 + _nextID + } +} diff --git a/src/main/scala/leon/synthesis/graph/Graph.scala b/src/main/scala/leon/synthesis/graph/Graph.scala new file mode 100644 index 000000000..bea159f42 --- /dev/null +++ b/src/main/scala/leon/synthesis/graph/Graph.scala @@ -0,0 +1,207 @@ +package leon +package synthesis +package graph + +import leon.utils.StreamUtils.cartesianProduct + +sealed class Graph(problem: Problem, costModel: CostModel) { + type Cost = Int + + var maxCost = 100; + + val root = new RootNode(problem) + + sealed abstract class Node(parent: Option[Node]) { + var parents: List[Node] = parent.toList + var descendents: List[Node] = Nil + + // indicates whether this particular node has already been expanded + var isExpanded: Boolean = false + def expand(sctx: SynthesisContext) + + val p: Problem + + // costs + var costDist: Distribution + def onNewDist(desc: Node) + + var isSolved: Boolean = false + + def isClosed: Boolean = { + costDist.total == 0 + } + + def onSolved(desc: Node) + + // Solutions this terminal generates (!= None for terminals) + var solutions: Option[Stream[Solution]] = None + var selectedSolution = -1 + + // For non-terminals, selected childs for solution + var selected: List[Node] = Nil + + def composeSolutions(sols: List[Stream[Solution]]): Stream[Solution] + + // Generate solutions given selection+solutions + def generateSolutions(): Stream[Solution] = { + solutions.getOrElse { + composeSolutions(selected.map(n => n.generateSolutions())) + } + } + } + + class AndNode(parent: Option[Node], val ri: RuleInstantiation) extends Node(parent) { + val p = ri.problem + var selfCost = Distribution.point(maxCost, costModel.ruleAppCost(ri)) + var costDist: Distribution = Distribution.uniformFrom(maxCost, costModel.ruleAppCost(ri), 0.5) + + override def toString = "\u2227 "+ri; + + def expand(sctx: SynthesisContext): Unit = { + require(!isExpanded) + isExpanded = true + + import sctx.reporter.info + + val prefix = "[%-20s] ".format(Option(ri.rule).getOrElse("?")) + + info(prefix+ri.problem) + + ri.apply(sctx) match { + case RuleClosed(sols) => + solutions = Some(sols) + selectedSolution = 0; + + costDist = sols.foldLeft(Distribution.empty(maxCost)) { + (d, sol) => d or Distribution.point(maxCost, costModel.solutionCost(sol)) + } + + isSolved = sols.nonEmpty + + if (sols.isEmpty) { + info(prefix+"Failed") + } else { + val sol = sols.head + info(prefix+"Solved"+(if(sol.isTrusted) "" else " (untrusted)")+" with: "+sol+"...") + } + + parents.foreach{ p => + p.onNewDist(this) + if (isSolved) { + p.onSolved(this) + } + } + + case RuleExpanded(probs) => + info(prefix+"Decomposed into:") + for(p <- probs) { + info(prefix+" - "+p) + } + + descendents = probs.map(p => new OrNode(Some(this), p)) + + selected = descendents + + recomputeCost() + } + } + + def composeSolutions(solss: List[Stream[Solution]]): Stream[Solution] = { + cartesianProduct(solss).flatMap { + sols => ri.onSuccess(sols) + } + } + + def onNewDist(desc: Node) = { + recomputeCost() + } + + private def recomputeCost() = { + val newCostDist = descendents.foldLeft(selfCost){ + case (c, d) => c and d.costDist + } + + if (newCostDist != costDist) { + costDist = newCostDist + parents.foreach(_.onNewDist(this)) + } + } + + private var solveds = Set[Node]() + + def onSolved(desc: Node): Unit = { + // We store everything within solveds + solveds += desc + + // Everything is solved correctly + if (solveds.size == descendents.size) { + isSolved = true; + parents.foreach(_.onSolved(this)) + } + } + + } + + class OrNode(parent: Option[Node], val p: Problem) extends Node(parent) { + var costDist: Distribution = Distribution.uniformFrom(maxCost, costModel.problemCost(p), 0.5) + + override def toString = "\u2228 "+p; + + def expand(sctx: SynthesisContext): Unit = { + require(!isExpanded) + + val ris = Rules.getInstantiations(sctx, p) + + descendents = ris.map(ri => new AndNode(Some(this), ri)) + selected = List() + + recomputeCost() + + isExpanded = true + } + + def onSolved(desc: Node): Unit = { + isSolved = true + selected = List(desc) + parents.foreach(_.onSolved(this)) + } + + def composeSolutions(solss: List[Stream[Solution]]): Stream[Solution] = { + solss.toStream.flatten + } + + private def recomputeCost(): Unit = { + val newCostDist = descendents.foldLeft(Distribution.empty(maxCost)){ + case (c, d) => c or d.costDist + } + + if (costDist != newCostDist) { + costDist = newCostDist + parents.foreach(_.onNewDist(this)) + + } + } + + def onNewDist(desc: Node): Unit = { + recomputeCost() + } + } + + class RootNode(p: Problem) extends OrNode(None, p) + + // Returns closed/total + def getStats(from: Node = root): (Int, Int) = { + val isClosed = from.isClosed || from.isSolved + val self = (if (isClosed) 1 else 0, 1) + + if (!from.isExpanded) { + self + } else { + from.descendents.foldLeft(self) { + case ((c,t), d) => + val (sc, st) = getStats(d) + (c+sc, t+st) + } + } + } +} diff --git a/src/main/scala/leon/synthesis/graph/Search.scala b/src/main/scala/leon/synthesis/graph/Search.scala new file mode 100644 index 000000000..030d14e44 --- /dev/null +++ b/src/main/scala/leon/synthesis/graph/Search.scala @@ -0,0 +1,273 @@ +package leon +package synthesis +package graph + +import scala.annotation.tailrec + +import leon.utils.StreamUtils.cartesianProduct + +import scala.collection.mutable.ArrayBuffer +import leon.utils.Interruptible +import java.util.concurrent.atomic.AtomicBoolean + +abstract class Search(ctx: LeonContext, p: Problem, costModel: CostModel) extends Interruptible { + val g = new Graph(p, costModel); + + import g.{Node, AndNode, OrNode, RootNode} + + def findNodeToExpandFrom(n: Node): Option[Node] + + val interrupted = new AtomicBoolean(false); + + def doStep(n: Node, sctx: SynthesisContext) = { + n.expand(sctx) + } + + @tailrec + final def searchFrom(sctx: SynthesisContext, from: Node): Boolean = { + findNodeToExpandFrom(from) match { + case Some(n) => + doStep(n, sctx) + + if (from.isSolved) { + true + } else if (interrupted.get) { + false + } else { + searchFrom(sctx, from) + } + case None => + false + } + } + + def traversePathFrom(n: Node, path: List[Int]): Option[Node] = path match { + case Nil => + Some(n) + case x :: xs => + if (n.isExpanded && n.descendents.size > x) { + traversePathFrom(n.descendents(x), xs) + } else { + None + } + } + + def traversePath(path: List[Int]): Option[Node] = { + traversePathFrom(g.root, path) + } + + def search(sctx: SynthesisContext): Stream[Solution] = { + if (searchFrom(sctx, g.root)) { + g.root.generateSolutions() + } else { + Stream.empty + } + } + + def interrupt(): Unit = { + interrupted.set(true) + } + + def recoverInterrupt(): Unit = { + interrupted.set(false) + } + + ctx.interruptManager.registerForInterrupts(this) +} + +class SimpleSearch(ctx: LeonContext, p: Problem, costModel: CostModel, bound: Option[Int]) extends Search(ctx, p, costModel) { + import g.{Node, AndNode, OrNode, RootNode} + + val expansionBuffer = ArrayBuffer[Node]() + + def findIn(n: Node) { + if (!n.isExpanded) { + expansionBuffer += n + } else if (!n.isClosed) { + n match { + case an: g.AndNode => + an.descendents.foreach(findIn) + + case on: g.OrNode => + if (on.descendents.nonEmpty) { + findIn(on.descendents.minBy(_.costDist)) + } + } + } + } + + var counter = 0; + def findNodeToExpandFrom(from: Node): Option[Node] = { + counter += 1 + if (!bound.isDefined || counter <= bound.get) { + if (expansionBuffer.isEmpty) { + findIn(from) + } + + if (expansionBuffer.nonEmpty) { + Some(expansionBuffer.remove(0)) + } else { + None + } + } else { + None + } + } +} + +class ManualSearch(ctx: LeonContext, problem: Problem, costModel: CostModel) extends Search(ctx, problem, costModel) { + import g.{Node, AndNode, OrNode, RootNode} + + import ctx.reporter._ + + var cd = List[Int]() + var cmdQueue = List[String]() + + override def doStep(n: Node, sctx: SynthesisContext) = { + super.doStep(n, sctx); + + // Backtrack view to a point where node is neither closed nor solved + if (n.isClosed || n.isSolved) { + var from: Node = g.root + var newCd = List[Int]() + + while (!from.isSolved && !from.isClosed && newCd.size < cd.size) { + val cdElem = cd(newCd.size) + from = traversePathFrom(from, List(cdElem)).get + if (!from.isSolved && !from.isClosed) { + newCd = cdElem :: newCd + } + } + + cd = newCd.reverse + } + } + + def printGraph() { + def pathToString(path: List[Int]): String = { + val p = path.reverse.drop(cd.size) + if (p.isEmpty) { + "" + } else { + " "+p.mkString(" ") + } + } + + def title(str: String) = "\u001b[1m" + str + "\u001b[0m" + def failed(str: String) = "\u001b[31m" + str + "\u001b[0m" + def solved(str: String) = "\u001b[32m" + str + "\u001b[0m" + + def displayDist(d: Distribution): String = { + f"${d.firstNonZero}%3d" + } + + def displayNode(n: Node): String = n match { + case an: AndNode => + val app = an.ri + s"(${displayDist(n.costDist)}) $app" + case on: OrNode => + val p = on.p + s"(${displayDist(n.costDist)}) $p" + } + + def traversePathFrom(n: Node, prefix: List[Int]) { + val visible = (prefix endsWith cd.reverse) + + if (!n.isExpanded) { + if (visible) { + println(pathToString(prefix)+" \u2508 "+displayNode(n)) + } + } else if (n.isSolved) { + println(solved(pathToString(prefix)+" \u2508 "+displayNode(n))) + } else if (n.isClosed) { + println(failed(pathToString(prefix)+" \u2508 "+displayNode(n))) + } else { + if (visible) { + println(title(pathToString(prefix)+" \u2510 "+displayNode(n))) + } + } + + if (n.isExpanded && !n.isClosed && !n.isSolved) { + for ((sn, i) <- n.descendents.zipWithIndex) { + traversePathFrom(sn, i :: prefix) + } + } + } + + println("-"*80) + traversePathFrom(g.root, List()) + println("-"*80) + } + + var continue = true + + def findNodeToExpandFrom(from: Node): Option[Node] = { + if (!from.isExpanded) { + Some(from) + } else { + var res: Option[Node] = None + continue = true + + while(continue) { + printGraph() + + try { + print("Next action? (q to quit) "+cd.mkString(" ")+" $ ") + val line = if (cmdQueue.isEmpty) { + scala.io.StdIn.readLine() + } else { + val n = cmdQueue.head + println(n) + cmdQueue = cmdQueue.tail + n + } + if (line == "q") { + continue = false + res = None + } else if (line startsWith "cd") { + val parts = line.split("\\s+").toList + + parts match { + case List("cd") => + cd = List() + case List("cd", "..") => + if (cd.size > 0) { + cd = cd.dropRight(1) + } + case "cd" :: parts => + cd = cd ::: parts.map(_.toInt) + case _ => + } + + } else { + val parts = line.split("\\s+").toList + + val c = parts.head.toInt + cmdQueue = cmdQueue ::: parts.tail + + traversePath(cd ::: c :: Nil) match { + case Some(l) if !l.isExpanded => + res = Some(l) + cd = cd ::: c :: Nil + continue = false + case Some(n) => + cd = cd ::: c :: Nil + case None => + error("Invalid path") + } + } + } catch { + case e: java.lang.NumberFormatException => + + case e: java.io.IOException => + continue = false + + case e: Throwable => + error("Woops: "+e.getMessage()) + e.printStackTrace() + } + } + res + } + } +} diff --git a/src/main/scala/leon/synthesis/rules/AsChoose.scala b/src/main/scala/leon/synthesis/rules/AsChoose.scala index fc43972f0..6d71c5341 100644 --- a/src/main/scala/leon/synthesis/rules/AsChoose.scala +++ b/src/main/scala/leon/synthesis/rules/AsChoose.scala @@ -8,7 +8,7 @@ case object AsChoose extends Rule("As Choose") { def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { Some(new RuleInstantiation(p, this, SolutionBuilder.none, this.name, this.priority) { def apply(sctx: SynthesisContext) = { - RuleSuccess(Solution.choose(p)) + RuleClosed(Solution.choose(p)) } }) } diff --git a/src/main/scala/leon/synthesis/rules/Cegis.scala b/src/main/scala/leon/synthesis/rules/Cegis.scala index 6109038c7..fb5c5dc26 100644 --- a/src/main/scala/leon/synthesis/rules/Cegis.scala +++ b/src/main/scala/leon/synthesis/rules/Cegis.scala @@ -451,8 +451,8 @@ case object CEGIS extends Rule("CEGIS") { } List(new RuleInstantiation(p, this, SolutionBuilder.none, this.name, this.priority) { - def apply(sctx: SynthesisContext): RuleApplicationResult = { - var result: Option[RuleApplicationResult] = None + def apply(sctx: SynthesisContext): RuleApplication = { + var result: Option[RuleApplication] = None var ass = p.as.toSet var xss = p.xs.toSet @@ -487,11 +487,11 @@ case object CEGIS extends Rule("CEGIS") { baseExampleInputs = p.as.map(a => model.getOrElse(a, simplestValue(a.getType))) +: baseExampleInputs case Some(false) => - return RuleApplicationImpossible + return RuleFailed() case None => sctx.reporter.warning("Solver could not solve path-condition") - return RuleApplicationImpossible // This is not necessary though, but probably wanted + return RuleFailed() // This is not necessary though, but probably wanted } } finally { solver.free() @@ -524,7 +524,7 @@ case object CEGIS extends Rule("CEGIS") { def allInputExamples() = baseExampleInputs.iterator ++ cachedInputIterator - def checkForPrograms(programs: Set[Set[Identifier]]): RuleApplicationResult = { + def checkForPrograms(programs: Set[Set[Identifier]]): RuleApplication = { for (prog <- programs) { val expr = ndProgram.determinize(prog) val res = Equals(Tuple(p.xs.map(Variable(_))), expr) @@ -535,9 +535,9 @@ case object CEGIS extends Rule("CEGIS") { try { solver3.check match { case Some(false) => - return RuleSuccess(Solution(BooleanLiteral(true), Set(), expr), isTrusted = true) + return RuleClosed(Solution(BooleanLiteral(true), Set(), expr, isTrusted = true)) case None => - return RuleSuccess(Solution(BooleanLiteral(true), Set(), expr), isTrusted = false) + return RuleClosed(Solution(BooleanLiteral(true), Set(), expr, isTrusted = false)) case Some(true) => // invalid program, we skip } @@ -546,7 +546,7 @@ case object CEGIS extends Rule("CEGIS") { } } - RuleApplicationImpossible + RuleFailed() } // Keep track of collected cores to filter programs to test @@ -630,7 +630,7 @@ case object CEGIS extends Rule("CEGIS") { } else if (nPassing <= testUpTo) { // Immediate Test checkForPrograms(prunedPrograms) match { - case rs: RuleSuccess => + case rs: RuleClosed => result = Some(rs) case _ => } @@ -725,7 +725,7 @@ case object CEGIS extends Rule("CEGIS") { bssAssumptions case None => - return RuleApplicationImpossible + return RuleFailed() } solver1.pop() @@ -747,16 +747,16 @@ case object CEGIS extends Rule("CEGIS") { val expr = ndProgram.determinize(satModel.filter(_._2 == BooleanLiteral(true)).keySet) - result = Some(RuleSuccess(Solution(BooleanLiteral(true), Set(), expr))) + result = Some(RuleClosed(Solution(BooleanLiteral(true), Set(), expr))) case _ => if (useOptTimeout) { // Interpret timeout in CE search as "the candidate is valid" sctx.reporter.info("CEGIS could not prove the validity of the resulting expression") val expr = ndProgram.determinize(satModel.filter(_._2 == BooleanLiteral(true)).keySet) - result = Some(RuleSuccess(Solution(BooleanLiteral(true), Set(), expr), isTrusted = false)) + result = Some(RuleClosed(Solution(BooleanLiteral(true), Set(), expr, isTrusted = false))) } else { - return RuleApplicationImpossible + return RuleFailed() } } } @@ -767,7 +767,7 @@ case object CEGIS extends Rule("CEGIS") { solver1.check match { case Some(false) => // Unsat even without blockers (under which fcalls are then uninterpreted) - return RuleApplicationImpossible + return RuleFailed() case _ => } @@ -784,13 +784,13 @@ case object CEGIS extends Rule("CEGIS") { unrolings += 1 } while(unrolings < maxUnrolings && result.isEmpty && !interruptManager.isInterrupted()) - result.getOrElse(RuleApplicationImpossible) + result.getOrElse(RuleFailed()) } catch { case e: Throwable => sctx.reporter.warning("CEGIS crashed: "+e.getMessage) e.printStackTrace - RuleApplicationImpossible + RuleFailed() } finally { solver1.free() solver2.free() diff --git a/src/main/scala/leon/synthesis/rules/Ground.scala b/src/main/scala/leon/synthesis/rules/Ground.scala index 52e5e0c13..c2bc50f51 100644 --- a/src/main/scala/leon/synthesis/rules/Ground.scala +++ b/src/main/scala/leon/synthesis/rules/Ground.scala @@ -14,7 +14,7 @@ case object Ground extends Rule("Ground") { def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { if (p.as.isEmpty) { List(new RuleInstantiation(p, this, SolutionBuilder.none, this.name, this.priority) { - def apply(sctx: SynthesisContext): RuleApplicationResult = { + def apply(sctx: SynthesisContext): RuleApplication = { val solver = SimpleSolverAPI(new TimeoutSolverFactory(sctx.solverFactory, 10000L)) val tpe = TupleType(p.xs.map(_.getType)) @@ -22,12 +22,12 @@ case object Ground extends Rule("Ground") { val result = solver.solveSAT(p.phi) match { case (Some(true), model) => val sol = Solution(BooleanLiteral(true), Set(), Tuple(p.xs.map(valuateWithModel(model))).setType(tpe)) - RuleSuccess(sol) + RuleClosed(sol) case (Some(false), model) => val sol = Solution(BooleanLiteral(false), Set(), Error(p.phi+" is UNSAT!").setType(tpe)) - RuleSuccess(sol) + RuleClosed(sol) case _ => - RuleApplicationImpossible + RuleFailed() } result diff --git a/src/main/scala/leon/synthesis/rules/OptimisticGround.scala b/src/main/scala/leon/synthesis/rules/OptimisticGround.scala index 148085a5a..d9f15a2a3 100644 --- a/src/main/scala/leon/synthesis/rules/OptimisticGround.scala +++ b/src/main/scala/leon/synthesis/rules/OptimisticGround.scala @@ -27,9 +27,9 @@ case object OptimisticGround extends Rule("Optimistic Ground") { var i = 0; var maxTries = 3; - var result: Option[RuleApplicationResult] = None - var continue = true - var predicates: Seq[Expr] = Seq() + var result: Option[RuleApplication] = None + var continue = true + var predicates: Seq[Expr] = Seq() while (result.isEmpty && i < maxTries && continue) { val phi = And(p.pc +: p.phi +: predicates) @@ -47,7 +47,7 @@ case object OptimisticGround extends Rule("Optimistic Ground") { predicates = valuateWithModelIn(phi, ass, invalidModel) +: predicates case (Some(false), _) => - result = Some(RuleSuccess(Solution(BooleanLiteral(true), Set(), Tuple(p.xs.map(valuateWithModel(satModel))).setType(tpe)))) + result = Some(RuleClosed(Solution(BooleanLiteral(true), Set(), Tuple(p.xs.map(valuateWithModel(satModel))).setType(tpe)))) case _ => continue = false @@ -56,7 +56,7 @@ case object OptimisticGround extends Rule("Optimistic Ground") { case (Some(false), _) => if (predicates.isEmpty) { - result = Some(RuleSuccess(Solution(BooleanLiteral(false), Set(), Error(p.phi+" is UNSAT!").setType(tpe)))) + result = Some(RuleClosed(Solution(BooleanLiteral(false), Set(), Error(p.phi+" is UNSAT!").setType(tpe)))) } else { continue = false result = None @@ -69,7 +69,7 @@ case object OptimisticGround extends Rule("Optimistic Ground") { i += 1 } - result.getOrElse(RuleApplicationImpossible) + result.getOrElse(RuleFailed()) } } List(res) diff --git a/src/main/scala/leon/synthesis/rules/Tegis.scala b/src/main/scala/leon/synthesis/rules/Tegis.scala index ca97ed859..15ac89ea9 100644 --- a/src/main/scala/leon/synthesis/rules/Tegis.scala +++ b/src/main/scala/leon/synthesis/rules/Tegis.scala @@ -32,7 +32,7 @@ case object TEGIS extends Rule("TEGIS") { var tests = p.getTests(sctx).map(_.ins).distinct if (tests.nonEmpty) { List(new RuleInstantiation(p, this, SolutionBuilder.none, this.name, this.priority) { - def apply(sctx: SynthesisContext): RuleApplicationResult = { + def apply(sctx: SynthesisContext): RuleApplication = { val evalParams = CodeGenParams(maxFunctionInvocations = 2000, checkContracts = true) //val evaluator = new CodeGenEvaluator(sctx.context, sctx.program, evalParams) @@ -143,54 +143,61 @@ case object TEGIS extends Rule("TEGIS") { var candidate: Option[Expr] = None var enumLimit = 10000; - var n = 1; - timers.generating.start() - allExprs.take(enumLimit).takeWhile(e => candidate.isEmpty).foreach { e => - val exprToTest = if (!isWrapped) { - Let(p.xs.head, e, p.phi) - } else { - letTuple(p.xs, e, p.phi) - } - sctx.reporter.debug("Got expression "+e) - timers.testing.start() - if (tests.forall{ case t => - val ts = System.currentTimeMillis - val res = evaluator.eval(exprToTest, p.as.zip(t).toMap) match { - case EvaluationResults.Successful(BooleanLiteral(true)) => - sctx.reporter.debug("Test "+t+" passed!") - true - case _ => - sctx.reporter.debug("Test "+t+" failed on "+e) - failStat += t -> (failStat(t) + 1) - false - } - res - }) { - if (isWrapped) { - candidate = Some(e) + def findNext(): Option[Expr] = { + candidate = None + timers.generating.start() + allExprs.take(enumLimit).takeWhile(e => candidate.isEmpty).foreach { e => + val exprToTest = if (!isWrapped) { + Let(p.xs.head, e, p.phi) } else { - candidate = Some(Tuple(Seq(e))) + letTuple(p.xs, e, p.phi) } - } - timers.testing.stop() - if (n % 50 == 0) { - tests = tests.sortBy(t => -failStat(t)) + sctx.reporter.debug("Got expression "+e) + timers.testing.start() + if (tests.forall{ case t => + val ts = System.currentTimeMillis + val res = evaluator.eval(exprToTest, p.as.zip(t).toMap) match { + case EvaluationResults.Successful(BooleanLiteral(true)) => + sctx.reporter.debug("Test "+t+" passed!") + true + case _ => + sctx.reporter.debug("Test "+t+" failed on "+e) + failStat += t -> (failStat(t) + 1) + false + } + res + }) { + if (isWrapped) { + candidate = Some(e) + } else { + candidate = Some(Tuple(Seq(e))) + } + } + timers.testing.stop() + + if (n % 50 == 0) { + tests = tests.sortBy(t => -failStat(t)) + } + n += 1 } - n += 1 - } - timers.generating.stop() + timers.generating.stop() - //println("Found candidate "+n) - //println("Compiled: "+evaluator.unit.compiledN) + candidate + } - if (candidate.isDefined) { - RuleSuccess(Solution(BooleanLiteral(true), Set(), candidate.get), isTrusted = false) - } else { - RuleApplicationImpossible + def toStream(): Stream[Solution] = { + findNext() match { + case Some(e) => + Stream.cons(Solution(BooleanLiteral(true), Set(), e, isTrusted = false), toStream()) + case None => + Stream.empty + } } + + RuleClosed(toStream()) } }) } else { diff --git a/src/main/scala/leon/synthesis/search/AndOrGraph.scala b/src/main/scala/leon/synthesis/search/AndOrGraph.scala deleted file mode 100644 index ef0ef1af9..000000000 --- a/src/main/scala/leon/synthesis/search/AndOrGraph.scala +++ /dev/null @@ -1,383 +0,0 @@ -/* Copyright 2009-2014 EPFL, Lausanne */ - -package leon.synthesis.search - -trait AOTask[S] { } - -trait AOAndTask[S] extends AOTask[S] { - def composeSolution(sols: List[S]): Option[S] -} - -trait AOOrTask[S] extends AOTask[S] { -} - -trait AOCostModel[AT <: AOAndTask[S], OT <: AOOrTask[S], S] { - def taskCost(at: AOTask[S]): Cost - def solutionCost(s: S): Cost -} - -class AndOrGraph[AT <: AOAndTask[S], OT <: AOOrTask[S], S](val root: OT, val costModel: AOCostModel[AT, OT, S]) { - var tree: OrTree = RootNode - - object LeafOrdering extends Ordering[Leaf] { - def compare(a: Leaf, b: Leaf) = { - val diff = scala.math.Ordering.Iterable[Int].compare(a.minReachCost, b.minReachCost) - if (diff == 0) { - if (a == b) { - 0 - } else { - a.## - b.## - } - } else { - diff - } - } - } - - val leaves = collection.mutable.TreeSet()(LeafOrdering) - leaves += RootNode - - trait Tree { - val task : AOTask[S] - val parent: Node[_] - - def minCost: Cost - - var minReachCost = List[Int]() - - def updateMinReach(reverseParent: List[Int]); - def removeLeaves(); - - var isTrustworthy: Boolean = true - var solution: Option[S] = None - var isUnsolvable: Boolean = false - - def isSolved: Boolean = solution.isDefined - } - - abstract class AndTree extends Tree { - override val task: AT - } - - abstract class OrTree extends Tree { - override val task: OT - } - - - trait Leaf extends Tree { - val minCost = costModel.taskCost(task) - - var removedLeaf = false; - - def updateMinReach(reverseParent: List[Int]) { - if (!removedLeaf) { - leaves -= this - minReachCost = (minCost.value :: reverseParent).reverse - leaves += this - } - } - - def removeLeaves() { - removedLeaf = true - leaves -= this - } - } - - trait Node[T <: Tree] extends Tree { - def unsolvable(l: T) - def notifySolution(sub: T, sol: S) - - } - - class AndNode(val parent: OrNode, val subTasks: List[OT], val task: AT) extends AndTree with Node[OrTree] { - var subProblems = Map[OT, OrTree]() - var subSolutions = Map[OT, S]() - - var minCost = Cost.zero - - def updateMin() { - val old = minCost - minCost = solution match { - case Some(s) => - costModel.solutionCost(s) - case _ => - val subCosts = subProblems.values.map(_.minCost) - - subCosts.foldLeft(costModel.taskCost(task))(_ + _) - } - if (minCost != old) { - if (parent ne null) { - parent.updateMin() - } else { - // Reached the root, propagate minReach up - updateMinReach(Nil) - } - } else { - // Reached boundary of update, propagate minReach up - updateMinReach(minReachCost.reverse.tail) - } - } - - def updateMinReach(reverseParent: List[Int]) { - val rev = minCost.value :: reverseParent - - minReachCost = rev.reverse - - subProblems.values.foreach(_.updateMinReach(rev)) - } - - def removeLeaves() { - subProblems.values.foreach(_.removeLeaves()) - } - - def unsolvable(l: OrTree) { - isUnsolvable = true - - this.removeLeaves() - - parent.unsolvable(this) - } - - def expandLeaf(l: OrLeaf, succ: List[AT]) { - //println("[[2]] Expanding "+l.task+" to: ") - //for (t <- succ) { - // println(" - "+t) - //} - - //println("BEFORE: In leaves we have: ") - //for (i <- leaves.iterator) { - // println("-> "+i.minReachCost+" == "+i.task) - //} - - if (!l.removedLeaf) { - l.removeLeaves() - - val orNode = new OrNode(this, succ, l.task) - subProblems += l.task -> orNode - - updateMin() - - leaves ++= orNode.andLeaves.values - } - - //println("AFTER: In leaves we have: ") - //for (i <- leaves.iterator) { - // println("-> "+i.minReachCost+" == "+i.task) - //} - } - - def notifySolution(sub: OrTree, sol: S) { - subSolutions += sub.task -> sol - - sub match { - case l: Leaf => - l.removeLeaves() - case _ => - } - - if (subSolutions.size == subProblems.size) { - task.composeSolution(subTasks.map(subSolutions)) match { - case Some(sol) => - isTrustworthy = subProblems.forall(_._2.isTrustworthy) - solution = Some(sol) - updateMin() - - notifyParent(sol) - - case None => - if (solution.isEmpty) { - unsolvable(sub) - } - } - } else { - updateMin() - } - - } - - def notifyParent(sol: S) { - Option(parent).foreach(_.notifySolution(this, sol)) - } - } - - object RootNode extends OrLeaf(null, root) { - - minReachCost = List(minCost.value) - - override def expandWith(succ: List[AT]) { - this.removeLeaves() - - val orNode = new OrNode(null, succ, root) - tree = orNode - - leaves ++= orNode.andLeaves.values - } - } - - class AndLeaf(val parent: OrNode, val task: AT) extends AndTree with Leaf { - def expandWith(succ: List[OT]) { - parent.expandLeaf(this, succ) - } - - } - - - class OrNode(val parent: AndNode, val altTasks: List[AT], val task: OT) extends OrTree with Node[AndTree] { - val andLeaves = altTasks.map(t => t -> new AndLeaf(this, t)).toMap - var alternatives: Map[AT, AndTree] = andLeaves - var triedAlternatives = Map[AT, AndTree]() - var minAlternative: AndTree = _ - var minCost = costModel.taskCost(task) - - updateMin() - - def updateMin() { - if (!alternatives.isEmpty) { - minAlternative = alternatives.values.minBy(_.minCost) - val old = minCost - minCost = minAlternative.minCost - - //println("Updated minCost of "+task+" from "+old.value+" to "+minCost.value) - - if (minCost != old) { - if (parent ne null) { - parent.updateMin() - } else { - // reached root, propagate minReach up - updateMinReach(Nil) - } - } else { - updateMinReach(minReachCost.reverse.tail) - } - } else { - minAlternative = null - minCost = Cost.zero - } - } - - def updateMinReach(reverseParent: List[Int]) { - val rev = minCost.value :: reverseParent - - minReachCost = rev.reverse - - alternatives.values.foreach(_.updateMinReach(rev)) - } - - def removeLeaves() { - alternatives.values.foreach(_.removeLeaves()) - } - - def unsolvable(l: AndTree) { - if (alternatives contains l.task) { - triedAlternatives += l.task -> alternatives(l.task) - alternatives -= l.task - - l.removeLeaves() - - if (alternatives.isEmpty) { - isUnsolvable = true - if (parent ne null) { - parent.unsolvable(this) - } - } else { - updateMin() - } - } - } - - def expandLeaf(l: AndLeaf, succ: List[OT]) { - //println("[[1]] Expanding "+l.task+" to: ") - //for (t <- succ) { - // println(" - "+t) - //} - - //println("BEFORE: In leaves we have: ") - //for (i <- leaves.iterator) { - // println("-> "+i.minReachCost+" == "+i.task) - //} - - if (!l.removedLeaf) { - l.removeLeaves() - - val n = new AndNode(this, succ, l.task) - - val newLeaves = succ.map(t => t -> new OrLeaf(n, t)).toMap - n.subProblems = newLeaves - - alternatives += l.task -> n - - n.updateMin() - - updateMin() - - leaves ++= newLeaves.values - } - - - //println("AFTER: In leaves we have: ") - //for (i <- leaves.iterator) { - // println("-> "+i.minReachCost+" == "+i.task) - //} - } - - def notifySolution(sub: AndTree, sol: S) { - this.removeLeaves() - - solution match { - case Some(preSol) if (costModel.solutionCost(preSol) < costModel.solutionCost(sol)) => - isTrustworthy = sub.isTrustworthy - solution = Some(sol) - minAlternative = sub - - notifyParent(solution.get) - - case None => - isTrustworthy = sub.isTrustworthy - solution = Some(sol) - minAlternative = sub - - notifyParent(solution.get) - - case _ => - } - } - - def notifyParent(sol: S) { - if (parent ne null) { - parent.notifySolution(this, sol) - } - } - } - - class OrLeaf(val parent: AndNode, val task: OT) extends OrTree with Leaf { - def expandWith(succ: List[AT]) { - parent.expandLeaf(this, succ) - } - } - - def getStatus: (Int, Int) = { - var total: Int = 0 - var closed: Int = 0 - - def tr(t: Tree, isParentClosed: Boolean) { - val isClosed = isParentClosed || t.isSolved || t.isUnsolvable - if (isClosed) { - closed += 1 - } - total += 1 - - t match { - case an: AndNode => - an.subProblems.values.foreach(tr(_, isClosed)) - case on: OrNode => - (on.alternatives.values ++ on.triedAlternatives.values).foreach(tr(_, isClosed)) - case _ => - } - } - - tr(tree, false) - - (closed, total) - } -} - diff --git a/src/main/scala/leon/synthesis/search/AndOrGraphDotConverter.scala b/src/main/scala/leon/synthesis/search/AndOrGraphDotConverter.scala deleted file mode 100644 index 5f180d3d2..000000000 --- a/src/main/scala/leon/synthesis/search/AndOrGraphDotConverter.scala +++ /dev/null @@ -1,129 +0,0 @@ -/* Copyright 2009-2014 EPFL, Lausanne */ - -package leon.synthesis.search - -class AndOrGraphDotConverter[AT <: AOAndTask[S], - OT <: AOOrTask[S], - S](val g: AndOrGraph[AT, OT, S], firstOnly: Boolean) { - - - private[this] var _nextID = 0 - def freshName(prefix: String) = { - _nextID += 1 - prefix+_nextID - } - - override def toString: String = { - val res = new StringBuffer() - - res append "digraph D {\n" - - // Print all nodes - val (nodes, edges) = decomposeGraph - - var nodesToNames = Map[g.Tree, String]() - - for (n <- nodes) { - val name = freshName("node") - - n match { - case ot: g.OrTree => - drawNode(res, name, ot) - case at: g.AndTree => - drawNode(res, name, at) - } - - nodesToNames += n -> name - } - - for ((f,t, isMin) <- edges) { - val label = f match { - case ot: g.OrTree => - "or" - case at: g.AndTree => - "" - } - - res append " "+nodesToNames(f)+" -> "+nodesToNames(t) +" [label=\""+label+"\""+(if (isMin) ", style=bold" else "")+"]\n"; - } - - res append "}\n" - - res.toString - } - - def decomposeGraph: (Set[g.Tree], Set[(g.Tree, g.Tree, Boolean)]) = { - var nodes = Set[g.Tree]() - var edges = Set[(g.Tree, g.Tree, Boolean)]() - - def collect(n: g.Tree, wasMin: Boolean) { - nodes += n - - n match { - case an : g.AndNode => - for (sub <- an.subProblems.values) { - edges += ((n, sub, wasMin)) - collect(sub, wasMin) - } - case on : g.OrNode => - val alternatives:Traversable[g.AndTree] = if (firstOnly) { - Option(on.minAlternative) - } else { - on.alternatives.values - } - - for (sub <- alternatives) { - val isMin = sub == on.minAlternative - edges += ((n, sub, isMin)) - collect(sub, isMin) - } - case _ => - // ignore leaves - } - } - - collect(g.tree, false) - - (nodes, edges) - } - - - def drawNode(res: StringBuffer, name: String, t: g.Tree) { - - def escapeHTML(str: String) = str.replaceAll("<", "<").replaceAll(">", ">") - - val (color, style) = t match { - case l: g.Leaf => - (if (t.isSolved) "palegreen" else if (t.isUnsolvable) "firebrick" else "white" , "dashed") - case n: g.Node[_] => - (if (t.isSolved) "palegreen" else if (t.isUnsolvable) "firebrick" else "white", "") - } - - res append " "+name+" [ label = <<TABLE BORDER=\"0\" CELLBORDER=\"1\" CELLSPACING=\"0\"><TR><TD BORDER=\"0\">self: "+g.costModel.taskCost(t.task).value+" | tree-min: "+t.minCost.value+"</TD></TR><TR><TD BORDER=\"1\" BGCOLOR=\""+color+"\">"+escapeHTML(t.task.toString)+"</TD></TR>"; - - if (t.isSolved) { - res append "<TR><TD BGCOLOR=\""+color+"\">"+escapeHTML(t.solution.get.toString)+"</TD></TR>" - } - - res append "</TABLE>>, shape = \"none\", style=\""+style+"\" ];\n" - - } - - /** Writes the graph to a file readable with GraphViz. */ - def writeFile(fname: String) { - import java.io.{BufferedWriter, FileWriter} - val out = new BufferedWriter(new FileWriter(fname)) - out.write(toString) - out.close() - } -} - -object AndOrGraphDotConverterCounter { - private var nextId = 0; - def next() = { - nextId += 1 - nextId - } -} - - diff --git a/src/main/scala/leon/synthesis/search/AndOrGraphParallelSearch.scala b/src/main/scala/leon/synthesis/search/AndOrGraphParallelSearch.scala deleted file mode 100644 index 75124787e..000000000 --- a/src/main/scala/leon/synthesis/search/AndOrGraphParallelSearch.scala +++ /dev/null @@ -1,183 +0,0 @@ -/* Copyright 2009-2014 EPFL, Lausanne */ - -package leon.synthesis.search - -import leon.utils._ -import akka.actor._ -import scala.concurrent.duration._ -import scala.concurrent.Await -import akka.util.Timeout -import akka.pattern.ask -import akka.pattern.AskTimeoutException - -abstract class AndOrGraphParallelSearch[WC, - AT <: AOAndTask[S], - OT <: AOOrTask[S], - S](og: AndOrGraph[AT, OT, S], - nWorkers: Int) extends AndOrGraphSearch[AT, OT, S](og) { - - def initWorkerContext(w: ActorRef): WC - - val timeout = 600.seconds - - var system: ActorSystem = _ - - def search(): Option[(S, Boolean)] = { - system = ActorSystem("ParallelSearch") - - val master = system.actorOf(Props(new Master), name = "Master") - - val workers = for (i <- 1 to nWorkers) yield { - system.actorOf(Props(new Worker(master)), name = "Worker"+i) - } - - try { - Await.result(master.ask(Protocol.BeginSearch)(timeout), timeout) - } catch { - case e: AskTimeoutException => - } - - if (system ne null) { - system.shutdown - system = null - } - - g.tree.solution.map(s => (s, g.tree.isTrustworthy)) - } - - override def stop() { - super.stop() - - if(system ne null) { - system.shutdown - system = null - } - } - - - object Protocol { - case object BeginSearch - case object SearchDone - - case class WorkerNew(worker: ActorRef) - case class WorkerAndTaskDone(worker: ActorRef, res: ExpandResult[OT]) - case class WorkerOrTaskDone(worker: ActorRef, res: ExpandResult[AT]) - - case class RequestAndTask(task: AT) - case class RequestOrTask(task: OT) - case object NoTaskReady - } - - def getNextLeaves(idleWorkers: Map[ActorRef, Option[g.Leaf]], workingWorkers: Map[ActorRef, Option[g.Leaf]]): List[g.Leaf] = { - val processing = workingWorkers.values.flatten.toSet - - val ts = System.currentTimeMillis(); - - val str = nextLeaves() - .filterNot(processing) - .take(idleWorkers.size) - .toList - - str - } - - class Master extends Actor { - import Protocol._ - - var outer: ActorRef = _ - - var workers = Map[ActorRef, Option[g.Leaf]]() - - def sendWork() { - val (idleWorkers, workingWorkers) = workers.partition(_._2.isEmpty) - - assert(idleWorkers.size > 0) - - getNextLeaves(idleWorkers, workingWorkers) match { - case Nil => - if (workingWorkers.isEmpty) { - outer ! SearchDone - } else { - // No work yet, waiting for results from ongoing work - } - - case ls => - for ((w, leaf) <- idleWorkers.keySet zip ls) { - leaf match { - case al: g.AndLeaf => - workers += w -> Some(al) - w ! RequestAndTask(al.task) - case ol: g.OrLeaf => - workers += w -> Some(ol) - w ! RequestOrTask(ol.task) - } - } - } - } - - context.setReceiveTimeout(10.seconds) - - def receive = { - case BeginSearch => - outer = sender - - case WorkerNew(w) => - workers += w -> None - context.watch(w) - sendWork() - - case WorkerAndTaskDone(w, res) => - workers.get(w) match { - case Some(Some(l: g.AndLeaf)) => - onExpansion(l, res) - workers += w -> None - case _ => - } - sendWork() - - case WorkerOrTaskDone(w, res) => - workers.get(w) match { - case Some(Some(l: g.OrLeaf)) => - onExpansion(l, res) - workers += w -> None - case _ => - } - sendWork() - - case Terminated(w) => - if (workers contains w) { - workers -= w - } - - case ReceiveTimeout => - println("@ Worker status:") - for ((w, t) <- workers if t.isDefined) { - println("@ - "+w.toString+": "+t.get.task) - } - - - } - } - - class Worker(master: ActorRef) extends Actor { - import Protocol._ - - val ctx = initWorkerContext(self) - - def receive = { - case RequestAndTask(at) => - val res = expandAndTask(self, ctx)(at) - master ! WorkerAndTaskDone(self, res) - - case RequestOrTask(ot) => - val res = expandOrTask(self, ctx)(ot) - master ! WorkerOrTaskDone(self, res) - } - - override def preStart() = master ! WorkerNew(self) - } - - def expandAndTask(w: ActorRef, ctx: WC)(t: AT): ExpandResult[OT] - - def expandOrTask(w: ActorRef, ctx: WC)(t: OT): ExpandResult[AT] -} diff --git a/src/main/scala/leon/synthesis/search/AndOrGraphPartialSolution.scala b/src/main/scala/leon/synthesis/search/AndOrGraphPartialSolution.scala deleted file mode 100644 index 9d13444ed..000000000 --- a/src/main/scala/leon/synthesis/search/AndOrGraphPartialSolution.scala +++ /dev/null @@ -1,39 +0,0 @@ -/* Copyright 2009-2014 EPFL, Lausanne */ - -package leon.synthesis.search - -class AndOrGraphPartialSolution[AT <: AOAndTask[S], - OT <: AOOrTask[S], - S](val g: AndOrGraph[AT, OT, S], missing: AT => S, includeUntrusted: Boolean) { - - - def getSolution: S = { - solveOr(g.tree) - } - - def solveAnd(t: g.AndTree): S = { - if (t.isSolved && (includeUntrusted || t.isTrustworthy)) { - t.solution.get - } else { - t match { - case l: g.AndLeaf => - missing(t.task) - case n: g.AndNode => - n.task.composeSolution(n.subProblems.values.map(solveOr(_)).toList).getOrElse(missing(t.task)) - } - } - } - - def solveOr(t: g.OrTree): S = { - if (t.isSolved && (includeUntrusted || t.isTrustworthy)) { - t.solution.get - } else { - t match { - case l: g.OrLeaf => - missing(l.parent.task) - case n: g.OrNode => - solveAnd(n.minAlternative) - } - } - } -} diff --git a/src/main/scala/leon/synthesis/search/AndOrGraphSearch.scala b/src/main/scala/leon/synthesis/search/AndOrGraphSearch.scala deleted file mode 100644 index 7453fe5c0..000000000 --- a/src/main/scala/leon/synthesis/search/AndOrGraphSearch.scala +++ /dev/null @@ -1,96 +0,0 @@ -/* Copyright 2009-2014 EPFL, Lausanne */ - -package leon.synthesis.search - -abstract class AndOrGraphSearch[AT <: AOAndTask[S], - OT <: AOOrTask[S], - S](val g: AndOrGraph[AT, OT, S]) { - - def nextLeaves(): Iterable[g.Leaf] = { - g.leaves - } - - def nextLeaf(): Option[g.Leaf] = { - nextLeaves().headOption - } - - abstract class ExpandResult[T <: AOTask[S]] - case class Expanded[T <: AOTask[S]](sub: List[T]) extends ExpandResult[T] - case class ExpandSuccess[T <: AOTask[S]](sol: S, isTrustworthy: Boolean) extends ExpandResult[T] - case class ExpandFailure[T <: AOTask[S]]() extends ExpandResult[T] - - def stop() { - - } - - def search(): Option[(S, Boolean)] - - def onExpansion(al: g.AndLeaf, res: ExpandResult[OT]) { - res match { - case Expanded(ls) => - al.expandWith(ls) - case r @ ExpandSuccess(sol, isTrustworthy) => - al.isTrustworthy = isTrustworthy - al.solution = Some(sol) - al.parent.notifySolution(al, sol) - case _ => - al.isUnsolvable = true - al.parent.unsolvable(al) - } - - if (g.tree.isSolved) { - stop() - } - } - - def onExpansion(ol: g.OrLeaf, res: ExpandResult[AT]) { - res match { - case Expanded(ls) => - ol.expandWith(ls) - case r @ ExpandSuccess(sol, isTrustworthy) => - ol.isTrustworthy = isTrustworthy - ol.solution = Some(sol) - ol.parent.notifySolution(ol, sol) - case _ => - ol.isUnsolvable = true - ol.parent.unsolvable(ol) - } - - if (g.tree.isSolved) { - stop() - } - } - - def traversePathFrom(n: g.Tree, path: List[Int]): Option[g.Tree] = { - n match { - case l: g.Leaf => - assert(path.isEmpty) - Some(l) - case an: g.AndNode => - path match { - case x :: xs => - traversePathFrom(an.subProblems(an.subTasks(x)), xs) - case Nil => - Some(an) - } - - case on: g.OrNode => - path match { - case x :: xs => - val t = on.altTasks(x) - if (on.triedAlternatives contains t) { - None - } else { - traversePathFrom(on.alternatives(t), xs) - } - - case Nil => - Some(on) - } - } - } - - def traversePath(path: List[Int]): Option[g.Tree] = { - traversePathFrom(g.tree, path) - } -} diff --git a/src/main/scala/leon/synthesis/search/Cost.scala b/src/main/scala/leon/synthesis/search/Cost.scala deleted file mode 100644 index 140ec6ba8..000000000 --- a/src/main/scala/leon/synthesis/search/Cost.scala +++ /dev/null @@ -1,18 +0,0 @@ -/* Copyright 2009-2014 EPFL, Lausanne */ - -package leon.synthesis.search - -trait Cost extends Ordered[Cost] { - def +(that: Cost): Cost = CostFixed(value + that.value) - - def value: Int - - def compare(that: Cost) = this.value - that.value -} - -case class CostFixed(value: Int) extends Cost - -object Cost { - val zero: Cost = new CostFixed(0) -} - diff --git a/src/main/scala/leon/utils/ASCIITable.scala b/src/main/scala/leon/utils/ASCIIHelpers.scala similarity index 89% rename from src/main/scala/leon/utils/ASCIITable.scala rename to src/main/scala/leon/utils/ASCIIHelpers.scala index 60a947033..ae2f3a418 100644 --- a/src/main/scala/leon/utils/ASCIITable.scala +++ b/src/main/scala/leon/utils/ASCIIHelpers.scala @@ -1,6 +1,6 @@ package leon.utils -object ASCIITables { +object ASCIIHelpers { case class Table(title: String, rows: Seq[TableRow] = Nil) { def +(r: TableRow): Table = this ++ Seq(r) def ++(rs: Iterable[TableRow]): Table = copy(rows = rows ++ rs) @@ -119,4 +119,18 @@ object ASCIITables { lazy val vString = v.toString } + def title(str: String, width: Int = 80): String = { + line(str, "=", width) + } + + def subTitle(str: String, width: Int = 80): String = { + line(str, "-", width) + } + + def line(str: String, sep: String, width: Int = 80): String = { + val middle = " "+str+" " + val remSize = (width-middle.length) + sep*math.floor(remSize/2).toInt+middle+sep*math.ceil(remSize/2).toInt + } + } diff --git a/src/main/scala/leon/utils/Simplifiers.scala b/src/main/scala/leon/utils/Simplifiers.scala index 3fef391d7..97ac93e8f 100644 --- a/src/main/scala/leon/utils/Simplifiers.scala +++ b/src/main/scala/leon/utils/Simplifiers.scala @@ -37,4 +37,25 @@ object Simplifiers { // Clean up ids/names (new ScopeSimplifier).transform(s) } + + def namePreservingBestEffort(ctx: LeonContext, p: Program)(e: Expr): Expr = { + val uninterpretedZ3 = SolverFactory(() => new UninterpretedZ3Solver(ctx, p)) + + val simplifiers = List[Expr => Expr]( + simplifyTautologies(uninterpretedZ3)(_), + decomposeIfs _, + rewriteTuples _, + evalGround(ctx, p), + normalizeExpression _ + ) + + val simple = { expr: Expr => + simplifiers.foldLeft(expr){ case (x, sim) => + sim(x) + } + } + + // Simplify first using stable simplifiers + fixpoint(simple, 5)(e) + } } diff --git a/src/main/scala/leon/utils/StreamUtils.scala b/src/main/scala/leon/utils/StreamUtils.scala new file mode 100644 index 000000000..a0806f5c5 --- /dev/null +++ b/src/main/scala/leon/utils/StreamUtils.scala @@ -0,0 +1,98 @@ +/* Copyright 2009-2014 EPFL, Lausanne */ + +package leon.utils + +object StreamUtils { + def cartesianProduct[T](streams : Seq[Stream[T]]) : Stream[List[T]] = { + val dimensions = streams.size + val vectorizedStreams = streams.map(new VectorizedStream(_)) + + if(dimensions == 0) + return Stream.cons(Nil, Stream.empty) + + if(streams.exists(_.isEmpty)) + return Stream.empty + + val indices = if(streams.forall(_.hasDefiniteSize)) { + val max = streams.map(_.size).max + diagCount(dimensions).take(max) + } else { + diagCount(dimensions) + } + + var allReached : Boolean = false + val bounds : Array[Int] = Array.fill(dimensions)(Int.MaxValue) + + indices.takeWhile(_ => !allReached).flatMap { indexList => + var d = 0 + var continue = true + var is = indexList + var ss = vectorizedStreams.toList + + if(indexList.sum >= bounds.max) { + allReached = true + } + + var tuple : List[T] = Nil + + while(continue && d < dimensions) { + var i = is.head + if(i > bounds(d)) { + continue = false + } else try { + // TODO can we speed up by caching the random access into + // the stream in an indexedSeq? After all, `i` increases + // slowly. + tuple = (ss.head)(i) :: tuple + is = is.tail + ss = ss.tail + d += 1 + } catch { + case e : IndexOutOfBoundsException => + bounds(d) = i - 1 + continue = false + } + } + if(continue) Some(tuple.reverse) else None + } + } + + private def diagCount(dim : Int) : Stream[List[Int]] = diag0(dim, 0) + private def diag0(dim : Int, nextSum : Int) : Stream[List[Int]] = summingTo(nextSum, dim).append(diag0(dim, nextSum + 1)) + + private def summingTo(sum : Int, n : Int) : Stream[List[Int]] = { + // assert(sum >= 0) + if(sum < 0) { + Stream.empty + } else if(n == 1) { + Stream.cons(sum :: Nil, Stream.empty) + } else { + (0 to sum).toStream.flatMap(fst => summingTo(sum - fst, n - 1).map(fst :: _)) + } + } + + private class VectorizedStream[T](initial : Stream[T]) { + private def mkException(i : Int) = new IndexOutOfBoundsException("Can't access VectorizedStream at : " + i) + private def streamHeadIndex : Int = indexed.size + private var stream : Stream[T] = initial + private var indexed : Vector[T] = Vector.empty + + def apply(index : Int) : T = { + if(index < streamHeadIndex) { + indexed(index) + } else { + val diff = index - streamHeadIndex // diff >= 0 + var i = 0 + while(i < diff) { + if(stream.isEmpty) throw mkException(index) + indexed = indexed :+ stream.head + stream = stream.tail + i += 1 + } + // The trick is *not* to read past the desired element. Leave it in the + // stream, or it will force the *following* one... + stream.headOption.getOrElse { throw mkException(index) } + } + } + } +} diff --git a/src/main/scala/leon/utils/Timer.scala b/src/main/scala/leon/utils/Timer.scala index 0bc4eae91..c896759f4 100644 --- a/src/main/scala/leon/utils/Timer.scala +++ b/src/main/scala/leon/utils/Timer.scala @@ -90,7 +90,7 @@ class TimerStorage extends Dynamic { } def outputTable(printer: String => Unit) = { - import utils.ASCIITables._ + import utils.ASCIIHelpers._ var table = Table("Timers") diff --git a/src/main/scala/leon/verification/VerificationReport.scala b/src/main/scala/leon/verification/VerificationReport.scala index 59d4ef4e6..bec5a07bf 100644 --- a/src/main/scala/leon/verification/VerificationReport.scala +++ b/src/main/scala/leon/verification/VerificationReport.scala @@ -21,7 +21,7 @@ class VerificationReport(val fvcs: Map[FunDef, List[VerificationCondition]]) { lazy val totalUnknown : Int = conditions.count(_.value == None) def summaryString : String = if(totalConditions >= 0) { - import utils.ASCIITables._ + import utils.ASCIIHelpers._ var t = Table("Verification Summary") -- GitLab