diff --git a/src/main/scala/leon/refactor/RepairCostModel.scala b/src/main/scala/leon/refactor/RepairCostModel.scala index 3cdf9af8d5d90b697e4593dbf469d1575027bdfc..4cb56277ac14b6e7fd1e0d4c37f3232523c385ad 100644 --- a/src/main/scala/leon/refactor/RepairCostModel.scala +++ b/src/main/scala/leon/refactor/RepairCostModel.scala @@ -4,20 +4,33 @@ package leon package refactor import synthesis._ +import purescala.Definitions._ import purescala.Trees._ +import purescala.DefOps._ import purescala.TreeOps._ +import purescala.Extractors._ -case class RepairCostModel(cm: CostModel) extends CostModel(cm.name) { - override def ruleAppCost(app: RuleInstantiation): Cost = { - app.rule match { - case rules.GuidedDecomp => 0 - case rules.GuidedCloser => 0 - case rules.CEGLESS => 0 - case _ => 10+cm.ruleAppCost(app) - } +case class RepairCostModel(cm: CostModel) extends WrappedCostModel(cm, "Repair("+cm.name+")") { + import graph._ + + override def andNode(an: AndNode, subs: Option[Seq[Cost]]) = { + val h = cm.andNode(an, subs).minSize + + Cost(an.ri.rule match { + case rules.GuidedDecomp => h/3 + case rules.GuidedCloser => h/3 + case rules.CEGLESS => h/2 + case _ => h + }) } - def solutionCost(s: Solution) = cm.solutionCost(s) - def problemCost(p: Problem) = cm.problemCost(p) -} + def costOfGuide(p: Problem): Int = { + val TopLevelAnds(clauses) = p.pc + val guides = clauses.collect { + case FunctionInvocation(TypedFunDef(fd, _), Seq(expr)) if fullName(fd) == "leon.lang.synthesis.guide" => expr + } + + guides.map(formulaSize(_)).sum + } +} diff --git a/src/main/scala/leon/synthesis/CostModel.scala b/src/main/scala/leon/synthesis/CostModel.scala index cba5671e94a0024d02440adcb9a405488e8a4a67..d714bbf81a80c153631946c0bf9758fa4dae8093 100644 --- a/src/main/scala/leon/synthesis/CostModel.scala +++ b/src/main/scala/leon/synthesis/CostModel.scala @@ -9,54 +9,86 @@ import purescala.Trees._ import purescala.TreeOps._ abstract class CostModel(val name: String) { - type Cost = Int + def solution(s: Solution): Cost - def solutionCost(s: Solution): Cost - def problemCost(p: Problem): Cost + def problem(p: Problem): Cost - def ruleAppCost(app: RuleInstantiation): Cost = { - val subSols = app.onSuccess.types.map {t => Solution.simplest(t) }.toList - val simpleSol = app.onSuccess(subSols) + def andNode(an: AndNode, subs: Option[Seq[Cost]]): Cost - simpleSol match { - case Some(sol) => - solutionCost(sol) - case None => - problemCost(app.problem) - } - } + def impossible: Cost } -case class ScaledCostModel(cm: CostModel, scale: Int) extends CostModel(cm.name+"/"+scale) { - def solutionCost(s: Solution): Cost = Math.max(cm.solutionCost(s)/scale, 1) - def problemCost(p: Problem): Cost = Math.max(cm.problemCost(p)/scale, 1) - override def ruleAppCost(app: RuleInstantiation): Cost = Math.max(cm.ruleAppCost(app)/scale, 1) +case class Cost(minSize: Int) extends Ordered[Cost] { + def isImpossible = minSize >= 100 + + def compare(that: Cost): Int = { + this.minSize-that.minSize + } + + def asString: String = { + if (isImpossible) { + "<!>" + } else { + f"$minSize%3d" + } + } } -object CostModel { - def default: CostModel = ScaledCostModel(WeightedBranchesCostModel, 5) +object CostModels { + def default: CostModel = WeightedBranchesCostModel def all: Set[CostModel] = Set( - ScaledCostModel(NaiveCostModel, 5), - ScaledCostModel(WeightedBranchesCostModel, 5) + NaiveCostModel, + WeightedBranchesCostModel ) } -case object NaiveCostModel extends CostModel("Naive") { - def solutionCost(s: Solution): Cost = { - val chooses = collectChooses(s.toExpr) - val chooseCost = chooses.foldLeft(0)((i, c) => i + problemCost(Problem.fromChoose(c))) +class WrappedCostModel(cm: CostModel, name: String) extends CostModel(name) { + + def solution(s: Solution): Cost = cm.solution(s) + + def problem(p: Problem): Cost = cm.problem(p) + + def andNode(an: AndNode, subs: Option[Seq[Cost]]): Cost = cm.andNode(an, subs) + + def impossible = cm.impossible +} + +class SizeBasedCostModel(name: String) extends CostModel(name) { + def solution(s: Solution) = { + Cost(formulaSize(s.toExpr)/10) + } - (formulaSize(s.toExpr) + chooseCost)/5+1 + def problem(p: Problem) = { + Cost(1) } - def problemCost(p: Problem): Cost = { - 1 + def andNode(an: AndNode, subs: Option[Seq[Cost]]) = { + + subs match { + case Some(subs) if subs.isEmpty => + impossible + + case osubs => + val app = an.ri + + val subSols = app.onSuccess.types.map {t => Solution.simplest(t) }.toList + val selfCost = app.onSuccess(subSols) match { + case Some(sol) => + solution(sol).minSize - subSols.size + case None => + 1 + } + Cost(osubs.toList.flatten.foldLeft(selfCost)(_ + _.minSize)) + } } + def impossible = Cost(100) } -case object WeightedBranchesCostModel extends CostModel("WeightedBranches") { +case object NaiveCostModel extends SizeBasedCostModel("Naive") + +case object WeightedBranchesCostModel extends SizeBasedCostModel("WeightedBranches") { def branchesCost(e: Expr): Int = { case class BC(cost: Int, nesting: Int) @@ -93,15 +125,8 @@ case object WeightedBranchesCostModel extends CostModel("WeightedBranches") { bc.cost } - def solutionCost(s: Solution): Cost = { - val chooses = collectChooses(s.toExpr) - val chooseCost = chooses.foldLeft(0)((i, c) => i + problemCost(Problem.fromChoose(c))) - - formulaSize(s.toExpr) + branchesCost(s.toExpr) + chooseCost - } - - def problemCost(p: Problem): Cost = { - p.xs.size + override def solution(s: Solution) = { + Cost(formulaSize(s.toExpr) + branchesCost(s.toExpr)) } } diff --git a/src/main/scala/leon/synthesis/Histogram.scala b/src/main/scala/leon/synthesis/Histogram.scala index e4235d281998d63cb463f1c4c04acb1acc66a813..2614704ea824066da4d6caae5171b724b7b10a42 100644 --- a/src/main/scala/leon/synthesis/Histogram.scala +++ b/src/main/scala/leon/synthesis/Histogram.scala @@ -26,7 +26,12 @@ class Histogram(val bound: Int, val values: Array[Double]) extends Ordered[Histo i += 1 } - new Histogram(bound, a) + val res = new Histogram(bound, a) + println("==== && ====") + println("this:"+ this) + println("that:"+ that) + println(" ==> "+res) + res } /** @@ -43,10 +48,15 @@ class Histogram(val bound: Int, val values: Array[Double]) extends Ordered[Histo i += 1 } - new Histogram(bound, a) + val res = new Histogram(bound, a) + println("==== || ====") + println("this:"+ this) + println("that:"+ that) + println(" ==> "+res) + res } - lazy val maxInfo = { + lazy val mode = { var max = 0d; var argMax = -1; var i = 0; @@ -60,15 +70,41 @@ class Histogram(val bound: Int, val values: Array[Double]) extends Ordered[Histo (max, argMax) } - def isImpossible = maxInfo._1 == 0 + lazy val firstNonZero = { + var i = 0; + var mini = -1 + while(i < bound && mini < 0) { + if (values(i) > 0) { + mini = i; + } + i += 1; + } + if (mini >= 0) { + (values(mini), mini) + } else { + (0d, bound) + } + } - /** - * Should return v<0 if `this` < `that`, that is, `this` represents better - * solutions than `that`. - */ - def compare(that: Histogram) = { - val (m1, am1) = this.maxInfo - val (m2, am2) = that.maxInfo + lazy val moment = { + var i = 0; + var moment = 0d; + var allV = 0d; + while(i < bound) { + val v = values(i) + moment += v*i + allV += v + i += 1 + } + + moment/allV + } + + def isImpossible = mode._1 == 0 + + def compareByMode(that: Histogram) = { + val (m1, am1) = this.mode + val (m2, am2) = that.mode if (m1 == m2) { am1 - am2 @@ -83,16 +119,80 @@ class Histogram(val bound: Int, val values: Array[Double]) extends Ordered[Histo } } + def rescaled(by: Double): Histogram = { + val a = new Array[Double](bound) + + var i = 0; + while(i < bound) { + val v = values(i) + + val nv = 1-Math.pow(1-v, by); + + a(i) = nv + + i += 1 + } + + new Histogram(bound, a) + } + + def compareByFirstNonZero(that: Histogram) = { + this.firstNonZero._2 - that.firstNonZero._2 + } + + def compareByMoment(that: Histogram) = { + this.moment - that.moment + } + + /** + * Should return v<0 if `this` < `that`, that is, `this` represents better + * solutions than `that`. + */ + def compare(that: Histogram) = { + compareByFirstNonZero(that) + } + override def toString: String = { - var printed = 0 - val info = for (i <- 0 until bound if values(i) != 0 && printed < 5) yield { - f"$i%2d -> ${values(i)}%,3f" + var lastv = -1d; + var fromi = -1; + val entries = new scala.collection.mutable.ArrayBuffer[((Int, Int), Double)]() + + + for (i <- 0 until bound) { + val v = values(i) + if (lastv < 0) { + lastv = v + fromi = i; + } + + if (lastv != v) { + entries += (fromi, i-1) -> lastv + lastv = v + fromi = i + } + } + entries += (fromi, bound-1) -> lastv + + val info = for (((from, to), v) <- entries) yield { + val k = if (from == to) { + s"$from" + } else { + s"$from..$to" + } + + f"$k -> $v%1.3f" } - val (m,am) = maxInfo - "H("+m+"@"+am+": "+info.mkString(", ")+")" + s"H($summary: ${info.mkString(", ")})" } + + def summary: String = { + //val (m, am) = maxInfo + val (m, am) = firstNonZero + + f"$m%1.4f->$am%-2d ($moment%1.3f)" + } } object Histogram { diff --git a/src/main/scala/leon/synthesis/PartialSolution.scala b/src/main/scala/leon/synthesis/PartialSolution.scala index 2e33fd105d11c9ed4ef56ee758be47efb4da42f3..ea2da95ed1eba104f8b74fc2fd7ef6b2543db7bb 100644 --- a/src/main/scala/leon/synthesis/PartialSolution.scala +++ b/src/main/scala/leon/synthesis/PartialSolution.scala @@ -20,9 +20,9 @@ class PartialSolution(g: Graph, includeUntrusted: Boolean) { getSolutionFor(g.root) } - def getSolutionFor(n: g.Node): Solution = { + def getSolutionFor(n: Node): Solution = { n match { - case on: g.OrNode => + case on: OrNode => if (on.isSolved) { val sols = on.generateSolutions() sols.find(includeSolution) match { @@ -37,12 +37,12 @@ class PartialSolution(g: Graph, includeUntrusted: Boolean) { if (descs.isEmpty) { completeProblem(on.p) } else { - getSolutionFor(descs.minBy(_.histogram)) + getSolutionFor(descs.minBy(_.cost)) } } else { completeProblem(on.p) } - case an: g.AndNode => + case an: AndNode => if (an.isSolved) { val sols = an.generateSolutions() sols.find(includeSolution) match { diff --git a/src/main/scala/leon/synthesis/SynthesisOptions.scala b/src/main/scala/leon/synthesis/SynthesisOptions.scala index a92550ac9cb0401d7c4e6dcf15bd33babb87364c..15b634eb5200a390c7bc0a085d8f6a97fada84ee 100644 --- a/src/main/scala/leon/synthesis/SynthesisOptions.scala +++ b/src/main/scala/leon/synthesis/SynthesisOptions.scala @@ -3,6 +3,8 @@ package leon package synthesis +import scala.language.existentials + case class SynthesisOptions( inPlace: Boolean = false, allSeeing: Boolean = false, @@ -11,7 +13,7 @@ case class SynthesisOptions( searchWorkers: Int = 1, firstOnly: Boolean = false, timeoutMs: Option[Long] = None, - costModel: CostModel = CostModel.default, + costModel: CostModel = CostModels.default, rules: Seq[Rule] = Rules.all ++ Heuristics.all, manualSearch: Boolean = false, searchBound: Option[Int] = None, diff --git a/src/main/scala/leon/synthesis/SynthesisPhase.scala b/src/main/scala/leon/synthesis/SynthesisPhase.scala index c58e8fb3ddcbba23259bbce5655845e53c7b5eee..d7ff62406ef077d197969afa6078f89a3d4e7421 100644 --- a/src/main/scala/leon/synthesis/SynthesisPhase.scala +++ b/src/main/scala/leon/synthesis/SynthesisPhase.scala @@ -54,7 +54,7 @@ object SynthesisPhase extends LeonPhase[Program, Program] { options = options.copy(filterFuns = Some(fs.toSet)) case LeonValueOption("costmodel", cm) => - CostModel.all.find(_.name.toLowerCase == cm.toLowerCase) match { + CostModels.all.find(_.name.toLowerCase == cm.toLowerCase) match { case Some(model) => options = options.copy(costModel = model) case None => @@ -62,8 +62,8 @@ object SynthesisPhase extends LeonPhase[Program, Program] { var errorMsg = "Unknown cost model: " + cm + "\n" + "Defined cost models: \n" - for (cm <- CostModel.all.toSeq.sortBy(_.name)) { - errorMsg += " - " + cm.name + (if(cm == CostModel.default) " (default)" else "") + "\n" + for (cm <- CostModels.all.toSeq.sortBy(_.name)) { + errorMsg += " - " + cm.name + (if(cm == CostModels.default) " (default)" else "") + "\n" } ctx.reporter.fatalError(errorMsg) diff --git a/src/main/scala/leon/synthesis/graph/DotGenerator.scala b/src/main/scala/leon/synthesis/graph/DotGenerator.scala index 3c74dd772f70fd785b7d84a116c7cc62ff114acd..dbbb70d1ef6d4326c1d3be266d3c182e97c2aa52 100644 --- a/src/main/scala/leon/synthesis/graph/DotGenerator.scala +++ b/src/main/scala/leon/synthesis/graph/DotGenerator.scala @@ -5,7 +5,6 @@ import java.io.{File, FileWriter, BufferedWriter} import leon.synthesis.Histogram class DotGenerator(g: Graph) { - import g.{Node, AndNode, OrNode, RootNode} private[this] var _nextID = 0 def freshName(prefix: String) = { @@ -68,14 +67,6 @@ class DotGenerator(g: Graph) { res.toString } - def hist(h: Histogram): String = { - if (h.isImpossible) { - "-/-" - } else { - h.maxInfo._1+"@"+h.maxInfo._2 - } - } - def limit(o: Any, length: Int = 40): String = { val str = o.toString if (str.length > length) { @@ -110,9 +101,9 @@ class DotGenerator(g: Graph) { //cost n match { case an: AndNode => - res append "<TR><TD BORDER=\"0\">"+escapeHTML(hist(n.histogram)+" ("+hist(an.selfHistogram))+")</TD></TR>" + res append "<TR><TD BORDER=\"0\">"+escapeHTML(n.cost.asString)+" ("+escapeHTML(g.cm.andNode(an, None).asString)+")</TD></TR>" case on: OrNode => - res append "<TR><TD BORDER=\"0\">"+escapeHTML(hist(n.histogram))+"</TD></TR>" + res append "<TR><TD BORDER=\"0\">"+escapeHTML(n.cost.asString)+"</TD></TR>" } res append "<TR><TD BORDER=\"1\" BGCOLOR=\""+color+"\">"+escapeHTML(limit(nodeDesc(n)))+"</TD></TR>"; diff --git a/src/main/scala/leon/synthesis/graph/Graph.scala b/src/main/scala/leon/synthesis/graph/Graph.scala index d27983f57df380862b977b0babb032d6585bd190..7e07bca199f6095adca187583483e200765baae8 100644 --- a/src/main/scala/leon/synthesis/graph/Graph.scala +++ b/src/main/scala/leon/synthesis/graph/Graph.scala @@ -4,204 +4,192 @@ package graph import leon.utils.StreamUtils.cartesianProduct -sealed class Graph(problem: Problem, costModel: CostModel) { - type Cost = Int +sealed class Graph(val cm: CostModel, problem: Problem) { + val root = new RootNode(cm, problem) - var maxCost = 100; + // Returns closed/total + def getStats(from: Node = root): (Int, Int) = { + val isClosed = from.isClosed || from.isSolved + val self = (if (isClosed) 1 else 0, 1) - val root = new RootNode(problem) + if (!from.isExpanded) { + self + } else { + from.descendents.foldLeft(self) { + case ((c,t), d) => + val (sc, st) = getStats(d) + (c+sc, t+st) + } + } + } +} - sealed abstract class Node(parent: Option[Node]) { - var parents: List[Node] = parent.toList - var descendents: List[Node] = Nil +sealed abstract class Node(cm: CostModel, parent: Option[Node]) { + var parents: List[Node] = parent.toList + var descendents: List[Node] = Nil - // indicates whether this particular node has already been expanded - var isExpanded: Boolean = false - def expand(sctx: SynthesisContext) + // indicates whether this particular node has already been expanded + var isExpanded: Boolean = false + def expand(sctx: SynthesisContext) - val p: Problem + val p: Problem - // costs - var histogram: Histogram - def updateHistogram(desc: Node) + var isSolved: Boolean = false - var isSolved: Boolean = false - def isClosed: Boolean = { - histogram.maxInfo._1 == 0 - } + def onSolved(desc: Node) - def onSolved(desc: Node) + // Solutions this terminal generates (!= None for terminals) + var solutions: Option[Stream[Solution]] = None + var selectedSolution = -1 - // Solutions this terminal generates (!= None for terminals) - var solutions: Option[Stream[Solution]] = None - var selectedSolution = -1 + // Costs + var cost: Cost = computeCost() - // For non-terminals, selected childs for solution - var selected: List[Node] = Nil + def isClosed: Boolean = { + cost.isImpossible + } - def composeSolutions(sols: List[Stream[Solution]]): Stream[Solution] + // For non-terminals, selected childs for solution + var selected: List[Node] = Nil - // Generate solutions given selection+solutions - def generateSolutions(): Stream[Solution] = { - solutions.getOrElse { - composeSolutions(selected.map(n => n.generateSolutions())) - } + def composeSolutions(sols: List[Stream[Solution]]): Stream[Solution] + + // Generate solutions given selection+solutions + def generateSolutions(): Stream[Solution] = { + solutions.getOrElse { + composeSolutions(selected.map(n => n.generateSolutions())) } } - class AndNode(parent: Option[Node], val ri: RuleInstantiation) extends Node(parent) { - val p = ri.problem - var selfHistogram = Histogram.point(maxCost, costModel.ruleAppCost(ri), 1d) - var histogram = Histogram.uniformFrom(maxCost, costModel.ruleAppCost(ri), 0.5d) + def computeCost(): Cost = solutions match { + case Some(sols) if sols.isEmpty => + cm.impossible - override def toString = "\u2227 "+ri; + case Some(sols) => + sols.map { sol => cm.solution(sol) } .min - def expand(sctx: SynthesisContext): Unit = { - require(!isExpanded) - isExpanded = true + case None => + val costs = if (isExpanded) { + Some(descendents.map { _.cost }) + } else { + None + } - import sctx.reporter.info + this match { + case an: AndNode => + cm.andNode(an, costs) - val prefix = "[%-20s] ".format(Option(ri.rule).getOrElse("?")) + case on: OrNode => + costs.map(_.min).getOrElse(cm.problem(on.p)) + } + } - info(prefix+ri.problem) + def updateCost(): Unit = { + cost = computeCost() + parents.foreach(_.updateCost()) + } +} - ri.apply(sctx) match { - case RuleClosed(sols) => - solutions = Some(sols) - selectedSolution = 0; +class AndNode(cm: CostModel, parent: Option[Node], val ri: RuleInstantiation) extends Node(cm, parent) { + val p = ri.problem - histogram = sols.foldLeft(Histogram.empty(maxCost)) { - (d, sol) => d or Histogram.point(maxCost, costModel.solutionCost(sol), 1d) - } + override def toString = "\u2227 "+ri; - isSolved = sols.nonEmpty + def expand(sctx: SynthesisContext): Unit = { + require(!isExpanded) + isExpanded = true - if (sols.isEmpty) { - info(prefix+"Failed") - } else { - val sol = sols.head - info(prefix+"Solved"+(if(sol.isTrusted) "" else " (untrusted)")+" with: "+sol+"...") - } + import sctx.reporter.info - parents.foreach{ p => - p.updateHistogram(this) - if (isSolved) { - p.onSolved(this) - } - } + val prefix = "[%-20s] ".format(Option(ri.rule).getOrElse("?")) - case RuleExpanded(probs) => - info(prefix+"Decomposed into:") - for(p <- probs) { - info(prefix+" - "+p) - } + info(prefix+ri.problem) - descendents = probs.map(p => new OrNode(Some(this), p)) + ri.apply(sctx) match { + case RuleClosed(sols) => + solutions = Some(sols) + selectedSolution = 0; - selected = descendents + updateCost() - recomputeCost() - } - } + isSolved = sols.nonEmpty - def composeSolutions(solss: List[Stream[Solution]]): Stream[Solution] = { - cartesianProduct(solss).flatMap { - sols => ri.onSuccess(sols) - } - } + if (sols.isEmpty) { + info(prefix+"Failed") + } else { + val sol = sols.head + info(prefix+"Solved"+(if(sol.isTrusted) "" else " (untrusted)")+" with: "+sol+"...") + } - def updateHistogram(desc: Node) = { - recomputeCost() - } + parents.foreach{ p => + p.updateCost() + if (isSolved) { + p.onSolved(this) + } + } - private def recomputeCost() = { - val newHistogram = descendents.foldLeft(selfHistogram){ - case (c, d) => c and d.histogram - } + case RuleExpanded(probs) => + info(prefix+"Decomposed into:") + for(p <- probs) { + info(prefix+" - "+p) + } - if (newHistogram != histogram) { - histogram = newHistogram - parents.foreach(_.updateHistogram(this)) - } - } - - private var solveds = Set[Node]() + descendents = probs.map(p => new OrNode(cm, Some(this), p)) - def onSolved(desc: Node): Unit = { - // We store everything within solveds - solveds += desc + selected = descendents - // Everything is solved correctly - if (solveds.size == descendents.size) { - isSolved = true; - parents.foreach(_.onSolved(this)) - } + updateCost() } - } - class OrNode(parent: Option[Node], val p: Problem) extends Node(parent) { - var histogram = Histogram.uniformFrom(maxCost, costModel.problemCost(p), 0.5d) - - override def toString = "\u2228 "+p; + def composeSolutions(solss: List[Stream[Solution]]): Stream[Solution] = { + cartesianProduct(solss).flatMap { + sols => ri.onSuccess(sols) + } + } - def expand(sctx: SynthesisContext): Unit = { - require(!isExpanded) + private var solveds = Set[Node]() - val ris = Rules.getInstantiations(sctx, p) + def onSolved(desc: Node): Unit = { + // We store everything within solveds + solveds += desc - descendents = ris.map(ri => new AndNode(Some(this), ri)) - selected = List() + // Everything is solved correctly + if (solveds.size == descendents.size) { + isSolved = true; + parents.foreach(_.onSolved(this)) + } + } - recomputeCost() +} - isExpanded = true - } +class OrNode(cm: CostModel, parent: Option[Node], val p: Problem) extends Node(cm, parent) { - def onSolved(desc: Node): Unit = { - isSolved = true - selected = List(desc) - parents.foreach(_.onSolved(this)) - } + override def toString = "\u2228 "+p; - def composeSolutions(solss: List[Stream[Solution]]): Stream[Solution] = { - solss.toStream.flatten - } + def expand(sctx: SynthesisContext): Unit = { + require(!isExpanded) - private def recomputeCost(): Unit = { - val newHistogram = descendents.foldLeft(Histogram.empty(maxCost)){ - case (c, d) => c or d.histogram - } + val ris = Rules.getInstantiations(sctx, p) - if (histogram != newHistogram) { - histogram = newHistogram - parents.foreach(_.updateHistogram(this)) + descendents = ris.map(ri => new AndNode(cm, Some(this), ri)) + selected = List() - } - } + updateCost() - def updateHistogram(desc: Node): Unit = { - recomputeCost() - } + isExpanded = true } - class RootNode(p: Problem) extends OrNode(None, p) - - // Returns closed/total - def getStats(from: Node = root): (Int, Int) = { - val isClosed = from.isClosed || from.isSolved - val self = (if (isClosed) 1 else 0, 1) + def onSolved(desc: Node): Unit = { + isSolved = true + selected = List(desc) + parents.foreach(_.onSolved(this)) + } - if (!from.isExpanded) { - self - } else { - from.descendents.foldLeft(self) { - case ((c,t), d) => - val (sc, st) = getStats(d) - (c+sc, t+st) - } - } + def composeSolutions(solss: List[Stream[Solution]]): Stream[Solution] = { + solss.toStream.flatten } } + +class RootNode(cm: CostModel, p: Problem) extends OrNode(cm, None, p) diff --git a/src/main/scala/leon/synthesis/graph/Search.scala b/src/main/scala/leon/synthesis/graph/Search.scala index fd639221d28c41d287327031a44da2a11b48e393..03614a19b5d1688647f068e7b5de9de352d54a54 100644 --- a/src/main/scala/leon/synthesis/graph/Search.scala +++ b/src/main/scala/leon/synthesis/graph/Search.scala @@ -11,9 +11,7 @@ import leon.utils.Interruptible import java.util.concurrent.atomic.AtomicBoolean abstract class Search(ctx: LeonContext, p: Problem, costModel: CostModel) extends Interruptible { - val g = new Graph(p, costModel); - - import g.{Node, AndNode, OrNode, RootNode} + val g = new Graph(costModel, p); def findNodeToExpandFrom(n: Node): Option[Node] @@ -76,8 +74,6 @@ abstract class Search(ctx: LeonContext, p: Problem, costModel: CostModel) extend } class SimpleSearch(ctx: LeonContext, p: Problem, costModel: CostModel, bound: Option[Int]) extends Search(ctx, p, costModel) { - import g.{Node, AndNode, OrNode, RootNode} - val expansionBuffer = ArrayBuffer[Node]() def findIn(n: Node) { @@ -85,12 +81,12 @@ class SimpleSearch(ctx: LeonContext, p: Problem, costModel: CostModel, bound: Op expansionBuffer += n } else if (!n.isClosed) { n match { - case an: g.AndNode => + case an: AndNode => an.descendents.foreach(findIn) - case on: g.OrNode => + case on: OrNode => if (on.descendents.nonEmpty) { - findIn(on.descendents.minBy(_.histogram)) + findIn(on.descendents.minBy(_.cost)) } } } @@ -116,12 +112,61 @@ class SimpleSearch(ctx: LeonContext, p: Problem, costModel: CostModel, bound: Op } class ManualSearch(ctx: LeonContext, problem: Problem, costModel: CostModel) extends Search(ctx, problem, costModel) { - import g.{Node, AndNode, OrNode, RootNode} - import ctx.reporter._ + abstract class Command + case class Cd(path: List[Int]) extends Command + case object Parent extends Command + case object Quit extends Command + case object Noop extends Command + + // Manual search state: var cd = List[Int]() - var cmdQueue = List[String]() + var cmdQueue = List[Command]() + + def getNextCommand(): Command = cmdQueue match { + case c :: cs => + cmdQueue = cs + c + + case Nil => + print("Next action? (q to quit) "+cd.mkString(" ")+" $ ") + val line = scala.io.StdIn.readLine().trim + val parts = line.split("\\s+").toList + + cmdQueue = parseCommands(parts) + getNextCommand() + } + + def parseCommands(tokens: List[String]): List[Command] = tokens match { + case "cd" :: ".." :: ts => + Parent :: parseCommands(ts) + + case "cd" :: ts => + val path = ts.takeWhile { t => t.forall(_.isDigit) } + + if (path.isEmpty) { + parseCommands(ts) + } else { + Cd(path.map(_.toInt)) :: parseCommands(ts.drop(path.size)) + } + + case "q" :: ts => + Quit :: Nil + + case Nil => + Nil + + case ts => + val path = ts.takeWhile { t => t.forall(_.isDigit) } + + if (path.isEmpty) { + error("Unknown command "+ts.head) + parseCommands(ts.tail) + } else { + Cd(path.map(_.toInt)) :: parseCommands(ts.drop(path.size)) + } + } override def doStep(n: Node, sctx: SynthesisContext) = { super.doStep(n, sctx); @@ -143,61 +188,55 @@ class ManualSearch(ctx: LeonContext, problem: Problem, costModel: CostModel) ext } } - def printGraph() { - def pathToString(path: List[Int]): String = { - val p = path.reverse.drop(cd.size) - if (p.isEmpty) { - "" - } else { - " "+p.mkString(" ") - } - } - def title(str: String) = "\u001b[1m" + str + "\u001b[0m" - def failed(str: String) = "\u001b[31m" + str + "\u001b[0m" - def solved(str: String) = "\u001b[32m" + str + "\u001b[0m" - - def displayHistogram(h: Histogram): String = { - val (max, maxarg) = h.maxInfo - f"$max%,2f@$maxarg%2d" - } + def printGraph() { + def title(str: String) = "\u001b[1m" + str + "\u001b[0m" + def failed(str: String) = "\u001b[31m" + str + "\u001b[0m" + def solved(str: String) = "\u001b[32m" + str + "\u001b[0m" + def expanded(str: String) = "\u001b[33m" + str + "\u001b[0m" def displayNode(n: Node): String = n match { case an: AndNode => val app = an.ri - s"(${displayHistogram(n.histogram)}) $app" + s"(${n.cost.asString}) ${indent(app)}" case on: OrNode => val p = on.p - s"(${displayHistogram(n.histogram)}) $p" + s"(${n.cost.asString}) ${indent(p)}" } - def traversePathFrom(n: Node, prefix: List[Int]) { - val visible = (prefix endsWith cd.reverse) + def indent(a: Any): String = { + a.toString.replaceAll("\n", "\n"+(" "*12)) + } - if (!n.isExpanded) { - if (visible) { - println(pathToString(prefix)+" \u2508 "+displayNode(n)) - } - } else if (n.isSolved) { - println(solved(pathToString(prefix)+" \u2508 "+displayNode(n))) - } else if (n.isClosed) { - println(failed(pathToString(prefix)+" \u2508 "+displayNode(n))) - } else { - if (visible) { - println(title(pathToString(prefix)+" \u2510 "+displayNode(n))) - } - } + def pathToString(cd: List[Int]): String = { + cd.map(i => f"$i%2d").mkString(" ") + } + + def displayPath(n: Node, cd: List[Int]) { + if (cd.isEmpty) { + println(title(pathToString(cd)+" \u2510 "+displayNode(n))) - if (n.isExpanded && !n.isClosed && !n.isSolved) { for ((sn, i) <- n.descendents.zipWithIndex) { - traversePathFrom(sn, i :: prefix) + val sp = cd ::: List(i) + + if (sn.isSolved) { + println(solved(pathToString(sp)+" \u2508 "+displayNode(sn))) + } else if (sn.isClosed) { + println(failed(pathToString(sp)+" \u2508 "+displayNode(sn))) + } else if (sn.isExpanded) { + println(expanded(pathToString(sp)+" \u2508 "+displayNode(sn))) + } else { + println(pathToString(sp)+" \u2508 "+displayNode(sn)) + } } + } else { + displayPath(n.descendents(cd.head), cd.tail) } } - println("-"*80) - traversePathFrom(g.root, List()) - println("-"*80) + println("-"*120) + displayPath(g.root, cd) + println("-"*120) } var continue = true @@ -206,56 +245,48 @@ class ManualSearch(ctx: LeonContext, problem: Problem, costModel: CostModel) ext if (!from.isExpanded) { Some(from) } else { - var res: Option[Node] = None - continue = true + var res: Option[Option[Node]] = None - while(continue) { + while(res.isEmpty) { printGraph() try { - print("Next action? (q to quit) "+cd.mkString(" ")+" $ ") - val line = if (cmdQueue.isEmpty) { - scala.io.StdIn.readLine() - } else { - val n = cmdQueue.head - println(n) - cmdQueue = cmdQueue.tail - n - } - if (line == "q") { - continue = false - res = None - } else if (line startsWith "cd") { - val parts = line.split("\\s+").toList - - parts match { - case List("cd") => - cd = List() - case List("cd", "..") => - if (cd.size > 0) { - cd = cd.dropRight(1) + getNextCommand() match { + case Quit => + continue = false + res = Some(None) + case Parent => + if (cd.nonEmpty) { + cd = cd.dropRight(1) + } else { + error("Already at root node") + } + + case Cd(path) => + var currentNode = from + var currentPath = cd ++ path + cd = Nil + while (currentPath.nonEmpty && currentNode.isExpanded && res.isEmpty) { + traversePathFrom(currentNode, List(currentPath.head)) match { + case Some(n) => + cd = cd ::: List(currentPath.head) + currentNode = n + currentPath = currentPath.tail + + case None => + error("Unknown path: "+path) + res = Some(None) + return None } - case "cd" :: parts => - cd = cd ::: parts.map(_.toInt) - case _ => - } + } - } else { - val parts = line.split("\\s+").toList - - val c = parts.head.toInt - cmdQueue = cmdQueue ::: parts.tail - - traversePath(cd ::: c :: Nil) match { - case Some(l) if !l.isExpanded => - res = Some(l) - cd = cd ::: c :: Nil - continue = false - case Some(n) => - cd = cd ::: c :: Nil - case None => - error("Invalid path") - } + if (currentPath.nonEmpty) { + cmdQueue = Cd(currentPath) :: cmdQueue + } + + if (!currentNode.isExpanded) { + res = Some(Some(currentNode)) + } } } catch { case e: java.lang.NumberFormatException => @@ -268,7 +299,8 @@ class ManualSearch(ctx: LeonContext, problem: Problem, costModel: CostModel) ext e.printStackTrace() } } - res + + res.get } } }