From c9358c44c08a2480a1ae6a95a55a7f7afe95a371 Mon Sep 17 00:00:00 2001 From: Etienne Kneuss <etienne.kneuss@epfl.ch> Date: Fri, 7 Nov 2014 15:23:28 +0100 Subject: [PATCH] Cost Distributions become Histograms --- .../scala/leon/refactor/RepairCostModel.scala | 3 +- src/main/scala/leon/refactor/Repairman.scala | 19 +- src/main/scala/leon/synthesis/CostModel.scala | 2 + .../scala/leon/synthesis/Distribution.scala | 179 ------------------ src/main/scala/leon/synthesis/Histogram.scala | 136 +++++++++++++ .../leon/synthesis/PartialSolution.scala | 2 +- src/main/scala/leon/synthesis/Problem.scala | 8 +- .../leon/synthesis/graph/DotGenerator.scala | 14 +- .../scala/leon/synthesis/graph/Graph.scala | 42 ++-- .../scala/leon/synthesis/graph/Search.scala | 11 +- .../scala/leon/synthesis/rules/Cegis.scala | 18 +- .../leon/synthesis/rules/GuidedCloser.scala | 8 +- .../synthesis/utils/ExpressionGrammar.scala | 6 +- .../scala/leon/utils/InterruptManager.scala | 2 +- 14 files changed, 208 insertions(+), 242 deletions(-) delete mode 100644 src/main/scala/leon/synthesis/Distribution.scala create mode 100644 src/main/scala/leon/synthesis/Histogram.scala diff --git a/src/main/scala/leon/refactor/RepairCostModel.scala b/src/main/scala/leon/refactor/RepairCostModel.scala index 40dc2ad43..3cdf9af8d 100644 --- a/src/main/scala/leon/refactor/RepairCostModel.scala +++ b/src/main/scala/leon/refactor/RepairCostModel.scala @@ -11,8 +11,9 @@ case class RepairCostModel(cm: CostModel) extends CostModel(cm.name) { override def ruleAppCost(app: RuleInstantiation): Cost = { app.rule match { case rules.GuidedDecomp => 0 + case rules.GuidedCloser => 0 case rules.CEGLESS => 0 - case _ => cm.ruleAppCost(app) + case _ => 10+cm.ruleAppCost(app) } } def solutionCost(s: Solution) = cm.solutionCost(s) diff --git a/src/main/scala/leon/refactor/Repairman.scala b/src/main/scala/leon/refactor/Repairman.scala index 1c1e17844..e77deaf5c 100644 --- a/src/main/scala/leon/refactor/Repairman.scala +++ b/src/main/scala/leon/refactor/Repairman.scala @@ -21,6 +21,7 @@ import verification._ import synthesis._ import synthesis.rules._ import synthesis.heuristics._ +import graph.DotGenerator class Repairman(ctx: LeonContext, program: Program, fd: FunDef) { val reporter = ctx.reporter @@ -79,7 +80,7 @@ class Repairman(ctx: LeonContext, program: Program, fd: FunDef) { // Synthesis from the ground up val p = Problem(fd.params.map(_.id).toList, pc, spec, List(out)) val ch = Choose(List(out), spec) - //fd.body = Some(ch) + fd.body = Some(ch) val soptions0 = SynthesisPhase.processOptions(ctx); @@ -100,28 +101,30 @@ class Repairman(ctx: LeonContext, program: Program, fd: FunDef) { val expr = sol.toSimplifiedExpr(ctx, program) val (npr, fds) = synthesizer.solutionToProgram(sol) - solutions ::= (sol, expr, fds) if (!sol.isTrusted) { - getVerificationCounterExamples(fds.head, npr) match { case Some(ces) => testBank ++= ces reporter.info("Failed :(, but I learned: "+ces.mkString(" | ")) case None => - reporter.info("ZZUCCESS!") + solutions ::= (sol, expr, fds) + reporter.info(ASCIIHelpers.title("ZUCCESS!")) } } else { - reporter.info("ZZUCCESS!") + solutions ::= (sol, expr, fds) + reporter.info(ASCIIHelpers.title("ZUCCESS!")) } } + if (soptions.generateDerivationTrees) { + val dot = new DotGenerator(search.g) + dot.writeFile("derivation"+DotGenerator.nextId()+".dot") + } if (solutions.isEmpty) { - reporter.info("Trey aagggain") - repair() + reporter.info(ASCIIHelpers.title("FAILURZ!")) } else { - reporter.info(ASCIIHelpers.title("Solutions")) for (((sol, expr, fds), i) <- solutions.zipWithIndex) { reporter.info(ASCIIHelpers.subTitle("Solution "+(i+1)+":")) diff --git a/src/main/scala/leon/synthesis/CostModel.scala b/src/main/scala/leon/synthesis/CostModel.scala index b4cac6662..cba5671e9 100644 --- a/src/main/scala/leon/synthesis/CostModel.scala +++ b/src/main/scala/leon/synthesis/CostModel.scala @@ -3,6 +3,8 @@ package leon package synthesis +import graph._ + import purescala.Trees._ import purescala.TreeOps._ diff --git a/src/main/scala/leon/synthesis/Distribution.scala b/src/main/scala/leon/synthesis/Distribution.scala deleted file mode 100644 index 7c6a07040..000000000 --- a/src/main/scala/leon/synthesis/Distribution.scala +++ /dev/null @@ -1,179 +0,0 @@ -/* Copyright 2009-2014 EPFL, Lausanne */ - -package leon.synthesis - -class Distribution(val span: Int, val values: Array[Long], val total: Long) extends Ordered[Distribution] { - def and(that: Distribution): Distribution = { val res = (this, that) match { - case (d1, d2) if d1.total == 0 => - d1 - - case (d1, d2) if d2.total == 0 => - d2 - - case (d1: PointDistribution, d2: PointDistribution) => - if (d1.at + d2.at >= span) { - Distribution.empty(span) - } else { - new PointDistribution(span, d1.at+d2.at) - } - - case (d: PointDistribution, o) => - val a = Array.fill(span)(0l) - - val base = d.at - var innerTotal = 0l; - var i = d.at; - while(i < span) { - val v = o.values(i-base) - a(i) = v - innerTotal += v - i += 1 - } - - if (innerTotal == 0) { - Distribution.empty(span) - } else { - new Distribution(span, a, total) - } - - case (o, d: PointDistribution) => - val a = Array.fill(span)(0l) - - val base = d.at - var innerTotal = 0l; - var i = d.at; - while(i < span) { - val v = o.values(i-base) - a(i) = v - innerTotal += v - i += 1 - } - - if (innerTotal == 0) { - Distribution.empty(span) - } else { - new Distribution(span, a, total) - } - - case (left, right) => - if (left == right) { - left - } else { - val a = Array.fill(span)(0l) - var innerTotal = 0l; - var i = 0; - while (i < span) { - var j = 0; - while (j < span) { - if (i+j < span) { - val lv = left.values(i) - val rv = right.values(j) - - a(i+j) += lv*rv - innerTotal += lv*rv - } - j += 1 - } - i += 1 - } - - if (innerTotal == 0) { - Distribution.empty(span) - } else { - new Distribution(span, a, left.total * right.total) - } - } - } - //println("And of "+this+" and "+that+" = "+res) - res - } - - def or(that: Distribution): Distribution = (this, that) match { - case (d1, d2) if d1.total == 0 => - d2 - - case (d1, d2) if d2.total == 0 => - d1 - - case (d1: PointDistribution, d2: PointDistribution) => - if (d1.at < d2.at) { - d1 - } else { - d2 - } - - case (d1, d2) => - if (d1.weightedSum < d2.weightedSum) { - //if (d1.firstNonZero < d2.firstNonZero) { - d1 - } else { - d2 - } - } - - lazy val firstNonZero: Int = { - if (total == 0) { - span - } else { - var i = 0; - var continue = true; - while (continue && i < span) { - if (values(i) != 0l) { - continue = false - } - i += 1 - } - i - } - } - - lazy val weightedSum: Double = { - var res = 0d; - var i = 0; - while (i < span) { - res += (1d*i*values(i))/total - i += 1 - } - res - } - - override def toString: String = { - "Tot:"+total+"(at "+firstNonZero+")" - } - - def compare(that: Distribution) = { - this.firstNonZero - that.firstNonZero - } -} - -object Distribution { - def point(span: Int, at: Int) = { - if (span <= at) { - empty(span) - } else { - new PointDistribution(span, at) - } - } - - def empty(span: Int) = new Distribution(span, Array[Long](), 0l) - def uniform(span: Int, v: Long, total: Int) = { - new Distribution(span, Array.fill(span)(v), total) - } - - def uniformFrom(span: Int, from: Int, ratio: Double) = { - var i = from - val a = Array.fill(span)(0l) - while(i < span) { - a(i) = 1 - i += 1 - } - - new Distribution(span, a, ((span-from).toDouble*(1/ratio)).toInt) - } -} - -class PointDistribution(span: Int, val at: Int) extends Distribution(span, new Array[Long](span).updated(at, 1l), 1l) { - override lazy val firstNonZero: Int = { - at - } -} diff --git a/src/main/scala/leon/synthesis/Histogram.scala b/src/main/scala/leon/synthesis/Histogram.scala new file mode 100644 index 000000000..10176e97b --- /dev/null +++ b/src/main/scala/leon/synthesis/Histogram.scala @@ -0,0 +1,136 @@ +/* Copyright 2009-2014 EPFL, Lausanne */ + +package leon.synthesis + +/** + * Histogram from 0 to `bound`, each value between 0 and 1 + * hist(c) = v means we have a `v` likelihood of finding a solution of cost `c` + */ +class Histogram(val bound: Int, val values: Array[Double]) extends Ordered[Histogram] { + /** + */ + def and(that: Histogram): Histogram = { + val a = Array.fill(bound)(0d) + var i = 0; + while(i < bound) { + var j = 0; + while(j <= i) { + + val v1 = (this.values(j) * that.values(i-j)) + val v2 = a(i) + + a(i) = v1+v2 - (v1*v2) + + j += 1 + } + i += 1 + } + + new Histogram(bound, a) + } + + /** + * hist1(c) || hist2(c) == hist1(c)+hist2(c) - hist1(c)*hist2(c) + */ + def or(that: Histogram): Histogram = { + val a = Array.fill(bound)(0d) + var i = 0; + while(i < bound) { + val v1 = this.values(i) + val v2 = that.values(i) + + a(i) = v1+v2 - (v1*v2) + i += 1 + } + + new Histogram(bound, a) + } + + lazy val maxInfo = { + var max = 0d; + var argMax = -1; + var i = 0; + while(i < bound) { + if ((argMax < 0) || values(i) > max) { + argMax = i; + max = values(i) + } + i += 1; + } + (max, argMax) + } + + def isImpossible = maxInfo._1 == 0 + + /** + * Should return v<0 if `this` < `that`, that is, `this` represents better + * solutions than `that`. + */ + def compare(that: Histogram) = { + val (m1, am1) = this.maxInfo + val (m2, am2) = that.maxInfo + + if (m1 == m2) { + am1 - am2 + } else { + if (m2 < m1) { + -1 + } else if (m2 == m1) { + 0 + } else { + +1 + } + } + } + + override def toString: String = { + var printed = 0 + val info = for (i <- 0 until bound if values(i) != 0 && printed < 5) yield { + f"$i%2d -> ${values(i)}%,3f" + } + val (m,am) = maxInfo + + "H("+m+"@"+am+": "+info.mkString(", ")+")" + } + +} + +object Histogram { + def clampV(v: Double): Double = { + if (v < 0) { + 0d + } else if (v > 1) { + 1d + } else { + v + } + } + + def point(bound: Int, at: Int, v: Int) = { + if (bound <= at) { + empty(bound) + } else { + new Histogram(bound, Array.fill(bound)(0d).updated(at, clampV(v))) + } + } + + def empty(bound: Int) = { + new Histogram(bound, Array.fill(bound)(0d)) + } + + def uniform(bound: Int, v: Double) = { + uniformFrom(bound, 0, v) + } + + def uniformFrom(bound: Int, from: Int, v: Double) = { + val vSafe = clampV(v) + var i = from + val a = Array.fill(bound)(0d) + while(i < bound) { + a(i) = vSafe + i += 1 + } + + new Histogram(bound, a) + } +} diff --git a/src/main/scala/leon/synthesis/PartialSolution.scala b/src/main/scala/leon/synthesis/PartialSolution.scala index 039342d3c..2e33fd105 100644 --- a/src/main/scala/leon/synthesis/PartialSolution.scala +++ b/src/main/scala/leon/synthesis/PartialSolution.scala @@ -37,7 +37,7 @@ class PartialSolution(g: Graph, includeUntrusted: Boolean) { if (descs.isEmpty) { completeProblem(on.p) } else { - getSolutionFor(descs.minBy(_.costDist)) + getSolutionFor(descs.minBy(_.histogram)) } } else { completeProblem(on.p) diff --git a/src/main/scala/leon/synthesis/Problem.scala b/src/main/scala/leon/synthesis/Problem.scala index cc21e87a9..a9d83d80e 100644 --- a/src/main/scala/leon/synthesis/Problem.scala +++ b/src/main/scala/leon/synthesis/Problem.scala @@ -20,12 +20,14 @@ case class Problem(as: List[Identifier], pc: Expr, phi: Expr, xs: List[Identifie val ev = new DefaultEvaluator(sctx.context, sctx.program) + val safePc = removeWitnesses(sctx.program)(pc) + def isValidExample(ex: Example): Boolean = { val (mapping, cond) = ex match { case io: InOutExample => - (Map((as zip io.ins) ++ (xs zip io.outs): _*), And(pc, phi)) + (Map((as zip io.ins) ++ (xs zip io.outs): _*), And(safePc, phi)) case i => - ((as zip i.ins).toMap, pc) + ((as zip i.ins).toMap, safePc) } ev.eval(cond, mapping) match { @@ -85,7 +87,7 @@ case class Problem(as: List[Identifier], pc: Expr, phi: Expr, xs: List[Identifie case FunctionInvocation(tfd, List(in, out, FiniteMap(inouts))) if tfd.id.name == "passes" => val infos = extractIds(Tuple(Seq(in, out))) val exs = inouts.map{ case (i, o) => - val test = Tuple(Seq(i, o)) + val test = Tuple(Seq(i, o)) val ids = variablesOf(test) evaluator.eval(test, ids.map { (i: Identifier) => i -> i.toVariable }.toMap) match { case EvaluationResults.Successful(res) => res diff --git a/src/main/scala/leon/synthesis/graph/DotGenerator.scala b/src/main/scala/leon/synthesis/graph/DotGenerator.scala index 339deea13..3c74dd772 100644 --- a/src/main/scala/leon/synthesis/graph/DotGenerator.scala +++ b/src/main/scala/leon/synthesis/graph/DotGenerator.scala @@ -2,7 +2,7 @@ package leon.synthesis.graph import java.io.{File, FileWriter, BufferedWriter} -import leon.synthesis.Distribution +import leon.synthesis.Histogram class DotGenerator(g: Graph) { import g.{Node, AndNode, OrNode, RootNode} @@ -68,11 +68,11 @@ class DotGenerator(g: Graph) { res.toString } - def distrib(d: Distribution): String = { - if (d.firstNonZero == g.maxCost) { - ">max" + def hist(h: Histogram): String = { + if (h.isImpossible) { + "-/-" } else { - d.firstNonZero.toString + h.maxInfo._1+"@"+h.maxInfo._2 } } @@ -110,9 +110,9 @@ class DotGenerator(g: Graph) { //cost n match { case an: AndNode => - res append "<TR><TD BORDER=\"0\">"+escapeHTML(distrib(n.costDist)+" ("+distrib(an.selfCost))+")</TD></TR>" + res append "<TR><TD BORDER=\"0\">"+escapeHTML(hist(n.histogram)+" ("+hist(an.selfHistogram))+")</TD></TR>" case on: OrNode => - res append "<TR><TD BORDER=\"0\">"+escapeHTML(distrib(n.costDist))+"</TD></TR>" + res append "<TR><TD BORDER=\"0\">"+escapeHTML(hist(n.histogram))+"</TD></TR>" } res append "<TR><TD BORDER=\"1\" BGCOLOR=\""+color+"\">"+escapeHTML(limit(nodeDesc(n)))+"</TD></TR>"; diff --git a/src/main/scala/leon/synthesis/graph/Graph.scala b/src/main/scala/leon/synthesis/graph/Graph.scala index bea159f42..4ba9f3751 100644 --- a/src/main/scala/leon/synthesis/graph/Graph.scala +++ b/src/main/scala/leon/synthesis/graph/Graph.scala @@ -22,13 +22,13 @@ sealed class Graph(problem: Problem, costModel: CostModel) { val p: Problem // costs - var costDist: Distribution - def onNewDist(desc: Node) + var histogram: Histogram + def updateHistogram(desc: Node) var isSolved: Boolean = false def isClosed: Boolean = { - costDist.total == 0 + histogram.maxInfo._1 == 0 } def onSolved(desc: Node) @@ -52,8 +52,8 @@ sealed class Graph(problem: Problem, costModel: CostModel) { class AndNode(parent: Option[Node], val ri: RuleInstantiation) extends Node(parent) { val p = ri.problem - var selfCost = Distribution.point(maxCost, costModel.ruleAppCost(ri)) - var costDist: Distribution = Distribution.uniformFrom(maxCost, costModel.ruleAppCost(ri), 0.5) + var selfHistogram = Histogram.point(maxCost, costModel.ruleAppCost(ri), 100) + var histogram = Histogram.uniformFrom(maxCost, costModel.ruleAppCost(ri), 50) override def toString = "\u2227 "+ri; @@ -72,8 +72,8 @@ sealed class Graph(problem: Problem, costModel: CostModel) { solutions = Some(sols) selectedSolution = 0; - costDist = sols.foldLeft(Distribution.empty(maxCost)) { - (d, sol) => d or Distribution.point(maxCost, costModel.solutionCost(sol)) + histogram = sols.foldLeft(Histogram.empty(maxCost)) { + (d, sol) => d or Histogram.point(maxCost, costModel.solutionCost(sol), 100) } isSolved = sols.nonEmpty @@ -86,7 +86,7 @@ sealed class Graph(problem: Problem, costModel: CostModel) { } parents.foreach{ p => - p.onNewDist(this) + p.updateHistogram(this) if (isSolved) { p.onSolved(this) } @@ -112,18 +112,18 @@ sealed class Graph(problem: Problem, costModel: CostModel) { } } - def onNewDist(desc: Node) = { + def updateHistogram(desc: Node) = { recomputeCost() } private def recomputeCost() = { - val newCostDist = descendents.foldLeft(selfCost){ - case (c, d) => c and d.costDist + val newHistogram = descendents.foldLeft(selfHistogram){ + case (c, d) => c and d.histogram } - if (newCostDist != costDist) { - costDist = newCostDist - parents.foreach(_.onNewDist(this)) + if (newHistogram != histogram) { + histogram = newHistogram + parents.foreach(_.updateHistogram(this)) } } @@ -143,7 +143,7 @@ sealed class Graph(problem: Problem, costModel: CostModel) { } class OrNode(parent: Option[Node], val p: Problem) extends Node(parent) { - var costDist: Distribution = Distribution.uniformFrom(maxCost, costModel.problemCost(p), 0.5) + var histogram = Histogram.uniformFrom(maxCost, costModel.problemCost(p), 50) override def toString = "\u2228 "+p; @@ -171,18 +171,18 @@ sealed class Graph(problem: Problem, costModel: CostModel) { } private def recomputeCost(): Unit = { - val newCostDist = descendents.foldLeft(Distribution.empty(maxCost)){ - case (c, d) => c or d.costDist + val newHistogram = descendents.foldLeft(Histogram.empty(maxCost)){ + case (c, d) => c or d.histogram } - if (costDist != newCostDist) { - costDist = newCostDist - parents.foreach(_.onNewDist(this)) + if (histogram != newHistogram) { + histogram = newHistogram + parents.foreach(_.updateHistogram(this)) } } - def onNewDist(desc: Node): Unit = { + def updateHistogram(desc: Node): Unit = { recomputeCost() } } diff --git a/src/main/scala/leon/synthesis/graph/Search.scala b/src/main/scala/leon/synthesis/graph/Search.scala index 030d14e44..fd639221d 100644 --- a/src/main/scala/leon/synthesis/graph/Search.scala +++ b/src/main/scala/leon/synthesis/graph/Search.scala @@ -90,7 +90,7 @@ class SimpleSearch(ctx: LeonContext, p: Problem, costModel: CostModel, bound: Op case on: g.OrNode => if (on.descendents.nonEmpty) { - findIn(on.descendents.minBy(_.costDist)) + findIn(on.descendents.minBy(_.histogram)) } } } @@ -157,17 +157,18 @@ class ManualSearch(ctx: LeonContext, problem: Problem, costModel: CostModel) ext def failed(str: String) = "\u001b[31m" + str + "\u001b[0m" def solved(str: String) = "\u001b[32m" + str + "\u001b[0m" - def displayDist(d: Distribution): String = { - f"${d.firstNonZero}%3d" + def displayHistogram(h: Histogram): String = { + val (max, maxarg) = h.maxInfo + f"$max%,2f@$maxarg%2d" } def displayNode(n: Node): String = n match { case an: AndNode => val app = an.ri - s"(${displayDist(n.costDist)}) $app" + s"(${displayHistogram(n.histogram)}) $app" case on: OrNode => val p = on.p - s"(${displayDist(n.costDist)}) $p" + s"(${displayHistogram(n.histogram)}) $p" } def traversePathFrom(n: Node, prefix: List[Int]) { diff --git a/src/main/scala/leon/synthesis/rules/Cegis.scala b/src/main/scala/leon/synthesis/rules/Cegis.scala index b5997165b..0b740adf7 100644 --- a/src/main/scala/leon/synthesis/rules/Cegis.scala +++ b/src/main/scala/leon/synthesis/rules/Cegis.scala @@ -49,9 +49,9 @@ case object CEGLESS extends CEGISLike("CEGLESS") { val inputs = p.as.map(_.toVariable) - val guidedGrammar = guides.map(SimilarTo(_, inputs.toSet)).foldLeft[ExpressionGrammar](Empty)(_ || _) + val guidedGrammar = guides.map(SimilarTo(_, inputs.toSet, Set(sctx.functionContext))).foldLeft[ExpressionGrammar](Empty)(_ || _) - guidedGrammar || OneOf(inputs) + guidedGrammar || OneOf(inputs) || SafeRecCalls(sctx.program, p.pc) } } @@ -170,12 +170,11 @@ abstract class CEGISLike(name: String) extends Rule(name) { res == BooleanLiteral(true) case EvaluationResults.RuntimeError(err) => - //sctx.reporter.error("Error testing CE: "+err) false case EvaluationResults.EvaluatorError(err) => sctx.reporter.error("Error testing CE: "+err) - true + false } } else { true @@ -600,15 +599,16 @@ abstract class CEGISLike(name: String) extends Rule(name) { // We further filter the set of working programs to remove those that fail on known examples if (useCEPruning && hasInputExamples() && ndProgram.canTest()) { - for (p <- prunedPrograms if !interruptManager.isInterrupted()) { + for (bs <- prunedPrograms if !interruptManager.isInterrupted()) { var valid = true val examples = allInputExamples() while(valid && examples.hasNext) { val e = examples.next() - if (!ndProgram.testForProgram(p)(e)) { + if (!ndProgram.testForProgram(bs)(e)) { failedTestsStats(e) += 1 - wrongPrograms += p - prunedPrograms -= p + wrongPrograms += bs + prunedPrograms -= bs + valid = false; } } @@ -622,7 +622,7 @@ abstract class CEGISLike(name: String) extends Rule(name) { val nPassing = prunedPrograms.size sctx.reporter.debug("#Programs passing tests: "+nPassing) - if (nPassing == 0) { + if (nPassing == 0 || interruptManager.isInterrupted()) { skipCESearch = true; } else if (nPassing <= testUpTo) { // Immediate Test diff --git a/src/main/scala/leon/synthesis/rules/GuidedCloser.scala b/src/main/scala/leon/synthesis/rules/GuidedCloser.scala index 91b1494c7..f46e139a0 100644 --- a/src/main/scala/leon/synthesis/rules/GuidedCloser.scala +++ b/src/main/scala/leon/synthesis/rules/GuidedCloser.scala @@ -33,9 +33,7 @@ case object GuidedCloser extends NormalizingRule("Guided Closer") { val vc = simp(And(p.pc, LetTuple(p.xs, wrappedE, Not(p.phi)))) - //println(vc) - - val solver = sctx.newSolver.setTimeout(1000L) + val solver = sctx.newSolver.setTimeout(2000L) solver.assertCnstr(vc) val osol = solver.check match { @@ -47,8 +45,8 @@ case object GuidedCloser extends NormalizingRule("Guided Closer") { printer(vc) printer("== Unknown ==") } - None - //Some(Solution(BooleanLiteral(true), Set(), wrappedE, false)) + //None + Some(Solution(BooleanLiteral(true), Set(), wrappedE, false)) case _ => sctx.reporter.ifDebug { printer => diff --git a/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala b/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala index f20fa982b..5ed5990c9 100644 --- a/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala +++ b/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala @@ -112,7 +112,7 @@ object ExpressionGrammars { } } - case class SimilarTo(e: Expr, exclude: Set[Expr] = Set()) extends ExpressionGrammar { + case class SimilarTo(e: Expr, excludeExpr: Set[Expr] = Set(), excludeFCalls: Set[FunDef] = Set()) extends ExpressionGrammar { lazy val allSimilar = computeSimilar(e).groupBy(_._1).mapValues(_.map(_._2)) def computeProductions(t: TypeTree): Seq[Gen] = { @@ -121,7 +121,7 @@ object ExpressionGrammars { def computeSimilar(e : Expr) : Seq[(TypeTree, Gen)] = { - var seenSoFar = exclude; + var seenSoFar = excludeExpr; def gen(retType : TypeTree, tps : Seq[TypeTree], f : Seq[Expr] => Expr) : (TypeTree, Gen) = (bestRealType(retType), Generator[TypeTree, Expr](tps.map(bestRealType), f)) @@ -142,6 +142,8 @@ object ExpressionGrammars { val subs: Seq[(TypeTree, Gen)] = e match { case _: Terminal | _: Let | _: LetTuple | _: LetDef | _: MatchExpr => Seq() + case FunctionInvocation(TypedFunDef(fd, _), _) if excludeFCalls contains fd => + Seq() case UnaryOperator(sub, builder) => Seq( gen(tp, List(sub.getType), { case Seq(ex) => builder(ex) } ) ) ++ rec(sub) diff --git a/src/main/scala/leon/utils/InterruptManager.scala b/src/main/scala/leon/utils/InterruptManager.scala index 1aa19f9a4..abe0d927b 100644 --- a/src/main/scala/leon/utils/InterruptManager.scala +++ b/src/main/scala/leon/utils/InterruptManager.scala @@ -57,7 +57,7 @@ class InterruptManager(reporter: Reporter) { def handle(sig: Signal) { Signal.handle(sigINT, oldHandler) println - reporter.info("Aborting Leon...") + reporter.warning("Aborting Leon...") interrupt() -- GitLab