From 01062258b53f5007eb1817c6951033f089a5c8b6 Mon Sep 17 00:00:00 2001 From: Etienne Kneuss <etienne.kneuss@epfl.ch> Date: Mon, 5 Jan 2015 18:10:33 +0100 Subject: [PATCH] Refactor rules, remove Heuristics/Rules --- .../scala/leon/purescala/Constructors.scala | 9 +- .../scala/leon/repair/RepairCostModel.scala | 13 -- src/main/scala/leon/repair/Repairman.scala | 6 +- .../leon/repair/rules/GuidedCloser.scala | 22 ++- .../leon/repair/rules/GuidedDecomp.scala | 12 +- src/main/scala/leon/synthesis/CostModel.scala | 18 +- .../scala/leon/synthesis/Heuristics.scala | 27 --- .../leon/synthesis/PartialSolution.scala | 2 +- src/main/scala/leon/synthesis/Problem.scala | 3 + src/main/scala/leon/synthesis/Rules.scala | 175 +++++++++--------- .../scala/leon/synthesis/SearchContext.scala | 32 ++++ src/main/scala/leon/synthesis/Solution.scala | 9 +- .../leon/synthesis/SynthesisContext.scala | 3 + .../leon/synthesis/SynthesisSettings.scala | 2 +- .../scala/leon/synthesis/Synthesizer.scala | 20 +- .../leon/synthesis/graph/DotGenerator.scala | 2 +- .../scala/leon/synthesis/graph/Graph.scala | 41 ++-- .../scala/leon/synthesis/graph/Search.scala | 14 +- .../scala/leon/synthesis/rules/ADTDual.scala | 6 +- .../{heuristics => rules}/ADTInduction.scala | 16 +- .../ADTLongInduction.scala | 16 +- .../scala/leon/synthesis/rules/ADTSplit.scala | 6 +- .../scala/leon/synthesis/rules/AsChoose.scala | 8 +- .../scala/leon/synthesis/rules/Assert.scala | 10 +- .../leon/synthesis/rules/BottomUpTegis.scala | 16 +- .../leon/synthesis/rules/CaseSplit.scala | 8 +- .../leon/synthesis/rules/CegisLike.scala | 15 +- .../leon/synthesis/rules/DetupleInput.scala | 17 +- .../leon/synthesis/rules/DetupleOutput.scala | 9 +- .../leon/synthesis/rules/Disunification.scala | 6 +- .../leon/synthesis/rules/EqualitySplit.scala | 6 +- .../synthesis/rules/EquivalentInputs.scala | 6 +- .../scala/leon/synthesis/rules/Ground.scala | 16 +- .../scala/leon/synthesis/rules/IfSplit.scala | 8 +- .../synthesis/rules/InequalitySplit.scala | 6 +- .../leon/synthesis/rules/InlineHoles.scala | 18 +- .../InnerCaseSplit.scala | 10 +- .../{heuristics => rules}/IntInduction.scala | 16 +- .../synthesis/rules/IntegerEquation.scala | 12 +- .../synthesis/rules/IntegerInequalities.scala | 17 +- .../scala/leon/synthesis/rules/OnePoint.scala | 6 +- .../synthesis/rules/OptimisticGround.scala | 14 +- .../OptimisticInjection.scala | 8 +- .../SelectiveInlining.scala | 8 +- .../leon/synthesis/rules/TegisLike.scala | 26 +-- .../synthesis/rules/UnconstrainedOutput.scala | 8 +- .../leon/synthesis/rules/Unification.scala | 11 +- .../leon/synthesis/rules/UnusedInput.scala | 4 +- .../test/synthesis/StablePrintingSuite.scala | 6 +- .../leon/test/synthesis/SynthesisSuite.scala | 46 ++--- 50 files changed, 400 insertions(+), 395 deletions(-) delete mode 100644 src/main/scala/leon/synthesis/Heuristics.scala create mode 100644 src/main/scala/leon/synthesis/SearchContext.scala rename src/main/scala/leon/synthesis/{heuristics => rules}/ADTInduction.scala (89%) rename src/main/scala/leon/synthesis/{heuristics => rules}/ADTLongInduction.scala (92%) rename src/main/scala/leon/synthesis/{heuristics => rules}/InnerCaseSplit.scala (76%) rename src/main/scala/leon/synthesis/{heuristics => rules}/IntInduction.scala (84%) rename src/main/scala/leon/synthesis/{heuristics => rules}/OptimisticInjection.scala (79%) rename src/main/scala/leon/synthesis/{heuristics => rules}/SelectiveInlining.scala (79%) diff --git a/src/main/scala/leon/purescala/Constructors.scala b/src/main/scala/leon/purescala/Constructors.scala index 920b7cba6..73bc8c16c 100644 --- a/src/main/scala/leon/purescala/Constructors.scala +++ b/src/main/scala/leon/purescala/Constructors.scala @@ -23,7 +23,12 @@ object Constructors { case Nil => body case x :: Nil => - Let(x, tupleSelect(value, 1), body) + if (value.getType == x.getType) { + // This is for cases where we build it like: letTuple(List(x), tupleWrap(List(z))) + Let(x, value, body) + } else { + Let(x, tupleSelect(value, 1), body) + } case xs => LetTuple(xs, value, body) } @@ -79,8 +84,6 @@ object Constructors { unify(from1, from2) ++ unify(to1,to2) case (FunctionType(from1, to1), FunctionType(from2, to2)) => unifyMany(to1 :: from1, to2 :: from2) - case (TupleType(bases1), TupleType(bases2)) => - unifyMany(bases1, bases2) case (c1 : ClassType, c2 : ClassType) if isSubtypeOf(c1, c2) || isSubtypeOf(c2,c1) => unifyMany(c1.tps, c2.tps) case _ => throw new java.lang.IllegalArgumentException() diff --git a/src/main/scala/leon/repair/RepairCostModel.scala b/src/main/scala/leon/repair/RepairCostModel.scala index 8525db32f..25c19cc2a 100644 --- a/src/main/scala/leon/repair/RepairCostModel.scala +++ b/src/main/scala/leon/repair/RepairCostModel.scala @@ -28,17 +28,4 @@ case class RepairCostModel(cm: CostModel) extends WrappedCostModel(cm, "Repair(" case _ => h+1 }) } - - override def rulesFor(sctx: SynthesisContext, on: OrNode) = { - val rs = cm.rulesFor(sctx, on) - - on.parent match { - case None => - GuidedDecomp +: rs - case Some(an: AndNode) if an.ri.rule == GuidedDecomp => - GuidedCloser +: rs - case _ => - rs - } - } } diff --git a/src/main/scala/leon/repair/Repairman.scala b/src/main/scala/leon/repair/Repairman.scala index 10a5c2819..e2f9bbe3c 100644 --- a/src/main/scala/leon/repair/Repairman.scala +++ b/src/main/scala/leon/repair/Repairman.scala @@ -19,7 +19,7 @@ import codegen._ import verification._ import synthesis._ import synthesis.rules._ -import synthesis.heuristics._ +import rules._ import synthesis.Witnesses._ import graph.DotGenerator import leon.utils.ASCIIHelpers.title @@ -129,8 +129,8 @@ class Repairman(ctx: LeonContext, initProgram: Program, fd: FunDef, verifTimeout functionsToIgnore = soptions0.functionsToIgnore + fd, costModel = RepairCostModel(soptions0.costModel), rules = (soptions0.rules ++ Seq( - //GuidedDecomp, - //GuidedCloser, + GuidedDecomp, + GuidedCloser, CEGLESS //TEGLESS )) diff Seq(ADTInduction, TEGIS, IntegerInequalities, IntegerEquation) diff --git a/src/main/scala/leon/repair/rules/GuidedCloser.scala b/src/main/scala/leon/repair/rules/GuidedCloser.scala index 0b74f529a..e5b3dfce0 100644 --- a/src/main/scala/leon/repair/rules/GuidedCloser.scala +++ b/src/main/scala/leon/repair/rules/GuidedCloser.scala @@ -18,9 +18,17 @@ import purescala.Constructors._ import Witnesses._ import solvers._ +import graph._ case object GuidedCloser extends NormalizingRule("Guided Closer") { - def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { + def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { + hctx.parentNode match { + case Some(an: AndNode) if an.ri.rule == GuidedDecomp => + // We proceed as usual + case _ => + return Nil + } + val TopLevelAnds(clauses) = p.ws val guides = clauses.collect { @@ -31,11 +39,11 @@ case object GuidedCloser extends NormalizingRule("Guided Closer") { // Tentative solution using e val wrappedE = if (p.xs.size == 1) Tuple(Seq(e)) else e - val simp = Simplifiers.bestEffort(sctx.context, sctx.program) _ + val simp = Simplifiers.bestEffort(hctx.context, hctx.program) _ val vc = simp(and(p.pc, letTuple(p.xs, wrappedE, not(p.phi)))) - val solver = sctx.newSolver.setTimeout(2000L) + val solver = hctx.sctx.newSolver.setTimeout(2000L) solver.assertCnstr(vc) val osol = solver.check match { @@ -43,7 +51,7 @@ case object GuidedCloser extends NormalizingRule("Guided Closer") { Some(Solution(BooleanLiteral(true), Set(), wrappedE, true)) case None => - sctx.reporter.ifDebug { printer => + hctx.reporter.ifDebug { printer => printer(vc) printer("== Unknown ==") } @@ -52,7 +60,7 @@ case object GuidedCloser extends NormalizingRule("Guided Closer") { None case _ => - sctx.reporter.ifDebug { printer => + hctx.reporter.ifDebug { printer => printer(vc) printer("== Invalid! ==") } @@ -61,9 +69,7 @@ case object GuidedCloser extends NormalizingRule("Guided Closer") { solver.free - osol.map { s => - RuleInstantiation.immediateSuccess(p, this, s) - } + osol.map { solve } } diff --git a/src/main/scala/leon/repair/rules/GuidedDecomp.scala b/src/main/scala/leon/repair/rules/GuidedDecomp.scala index 6f8f44a45..4297b616d 100644 --- a/src/main/scala/leon/repair/rules/GuidedDecomp.scala +++ b/src/main/scala/leon/repair/rules/GuidedDecomp.scala @@ -21,14 +21,18 @@ import Witnesses._ import solvers._ case object GuidedDecomp extends Rule("Guided Decomp") { - def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { + def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { + if (hctx.searchDepth > 0) { + return Nil; + } + val TopLevelAnds(clauses) = p.ws val guides = clauses.collect { case Guide(expr) => expr } - val simplify = Simplifiers.bestEffort(sctx.context, sctx.program)_ + val simplify = Simplifiers.bestEffort(hctx.context, hctx.program)_ val alts = guides.collect { case g @ IfExpr(c, thn, els) => @@ -42,7 +46,7 @@ case object GuidedDecomp extends Rule("Guided Decomp") { None } - Some(RuleInstantiation.immediateDecomp(p, this, List(sub1, sub2), onSuccess, "Guided If-Split on '"+c+"'")) + Some(decomp(List(sub1, sub2), onSuccess, s"Guided If-Split on '$c'")) case m @ MatchExpr(scrut0, _) => @@ -93,7 +97,7 @@ case object GuidedDecomp extends Rule("Guided Decomp") { )) } - Some(RuleInstantiation.immediateDecomp(p, this, subs.toList, onSuccess, "Guided Match-Split")) + Some(decomp(subs.toList, onSuccess, s"Guided Match-Split on '$scrut0'")) case e => None diff --git a/src/main/scala/leon/synthesis/CostModel.scala b/src/main/scala/leon/synthesis/CostModel.scala index c08caa410..fe8be1a7c 100644 --- a/src/main/scala/leon/synthesis/CostModel.scala +++ b/src/main/scala/leon/synthesis/CostModel.scala @@ -10,31 +10,23 @@ import purescala.TreeOps._ abstract class CostModel(val name: String) { def solution(s: Solution): Cost - def problem(p: Problem): Cost - def andNode(an: AndNode, subs: Option[Seq[Cost]]): Cost def impossible: Cost - def rulesFor(sctx: SynthesisContext, on: OrNode): Seq[Rule] = { - sctx.rules + def isImpossible(c: Cost): Boolean = { + c >= impossible } } -case class Cost(minSize: Int) extends Ordered[Cost] { - def isImpossible = minSize >= 200 - +case class Cost(val minSize: Int) extends AnyVal with Ordered[Cost] { def compare(that: Cost): Int = { this.minSize-that.minSize } def asString: String = { - if (isImpossible) { - "<!>" - } else { - f"$minSize%3d" - } + f"$minSize%3d" } } @@ -56,8 +48,6 @@ class WrappedCostModel(cm: CostModel, name: String) extends CostModel(name) { def andNode(an: AndNode, subs: Option[Seq[Cost]]): Cost = cm.andNode(an, subs) def impossible = cm.impossible - - override def rulesFor(sctx: SynthesisContext, on: OrNode) = cm.rulesFor(sctx, on) } class SizeBasedCostModel(name: String) extends CostModel(name) { diff --git a/src/main/scala/leon/synthesis/Heuristics.scala b/src/main/scala/leon/synthesis/Heuristics.scala deleted file mode 100644 index 378b0018e..000000000 --- a/src/main/scala/leon/synthesis/Heuristics.scala +++ /dev/null @@ -1,27 +0,0 @@ -/* Copyright 2009-2014 EPFL, Lausanne */ - -package leon -package synthesis - -import purescala.Trees._ -import purescala.TypeTrees.TupleType - -import heuristics._ - -object Heuristics { - def all = List[Rule]( - IntInduction, - InnerCaseSplit, - //new OptimisticInjection(_), - //new SelectiveInlining(_), - //ADTLongInduction, - ADTInduction - ) -} - -trait Heuristic { - this: Rule => - - override def toString = "H: "+name - -} diff --git a/src/main/scala/leon/synthesis/PartialSolution.scala b/src/main/scala/leon/synthesis/PartialSolution.scala index ea2da95ed..2c517203f 100644 --- a/src/main/scala/leon/synthesis/PartialSolution.scala +++ b/src/main/scala/leon/synthesis/PartialSolution.scala @@ -33,7 +33,7 @@ class PartialSolution(g: Graph, includeUntrusted: Boolean) { } if (n.isExpanded) { - val descs = on.descendents.filter(_.isClosed) + val descs = on.descendents.filterNot(_.isDeadEnd) if (descs.isEmpty) { completeProblem(on.p) } else { diff --git a/src/main/scala/leon/synthesis/Problem.scala b/src/main/scala/leon/synthesis/Problem.scala index 24d95f577..042affe42 100644 --- a/src/main/scala/leon/synthesis/Problem.scala +++ b/src/main/scala/leon/synthesis/Problem.scala @@ -16,6 +16,9 @@ import Witnesses._ // ⟦ as ⟨ ws && pc | phi ⟩ xs ⟧ case class Problem(as: List[Identifier], ws: Expr, pc: Expr, phi: Expr, xs: List[Identifier]) { + def inType = tupleTypeWrap(as.map(_.getType)) + def outType = tupleTypeWrap(xs.map(_.getType)) + override def toString = { val pcws = and(ws, pc) "⟦ "+as.mkString(";")+", "+(if (pcws != BooleanLiteral(true)) pcws+" ≺ " else "")+" ⟨ "+phi+" ⟩ "+xs.mkString(";")+" ⟧ " diff --git a/src/main/scala/leon/synthesis/Rules.scala b/src/main/scala/leon/synthesis/Rules.scala index 5e6db134a..f2df54b71 100644 --- a/src/main/scala/leon/synthesis/Rules.scala +++ b/src/main/scala/leon/synthesis/Rules.scala @@ -7,8 +7,25 @@ import purescala.Common._ import purescala.Trees._ import purescala.TypeTrees._ import purescala.TreeOps._ +import purescala.Constructors._ import rules._ +abstract class Rule(val name: String) extends RuleDSL { + def instantiateOn(implicit hctx: SearchContext, problem: Problem): Traversable[RuleInstantiation] + + val priority: RulePriority = RulePriorityDefault + + implicit val debugSection = leon.utils.DebugSectionSynthesis + + implicit val thisRule = this + + override def toString = name +} + +abstract class NormalizingRule(name: String) extends Rule(name) { + override val priority = RulePriorityNormalizing +} + object Rules { def all = List[Rule]( Unification.DecompTrivialClash, @@ -34,43 +51,65 @@ object Rules { ADTSplit, InlineHoles, IntegerEquation, - IntegerInequalities + IntegerInequalities, + IntInduction, + InnerCaseSplit, + //new OptimisticInjection(_), + //new SelectiveInlining(_), + //ADTLongInduction, + ADTInduction //AngelicHoles // @EK: Disabled now as it is explicit with withOracle { .. } ) } -abstract class SolutionBuilder(val arity: Int, val types: Seq[TypeTree]) { - def apply(sols: List[Solution]): Option[Solution] +abstract class RuleInstantiation(val description: String, + val onSuccess: SolutionBuilder = SolutionBuilderCloser()) + (implicit val problem: Problem, val rule: Rule) { + + def apply(hctx: SearchContext): RuleApplication - assert(types.size == arity) + override def toString = description } -class SolutionCombiner(arity: Int, types: Seq[TypeTree], f: List[Solution] => Option[Solution]) extends SolutionBuilder(arity, types) { - def apply(sols: List[Solution]) = { - assert(sols.size == arity) - f(sols) - } +/** + * Wrapper class for a function returning a recomposed solution from a list of + * subsolutions + * + * We also need to know the types of the expected sub-solutions to use them in + * cost-models before having real solutions. + */ +abstract class SolutionBuilder { + val types: Seq[TypeTree] + + def apply(sols: List[Solution]): Option[Solution] } -object SolutionBuilder { - val none = new SolutionBuilder(0, Seq()) { - def apply(sols: List[Solution]) = None +case class SolutionBuilderDecomp(val types: Seq[TypeTree], recomp: List[Solution] => Option[Solution]) extends SolutionBuilder { + def apply(sols: List[Solution]): Option[Solution] = { + assert(types.size == sols.size) + recomp(sols) } } -abstract class RuleInstantiation( - val problem: Problem, - val rule: Rule, - val onSuccess: SolutionBuilder, - val description: String, - val priority: RulePriority) { - - def apply(sctx: SynthesisContext): RuleApplication - - override def toString = description +/** + * Used by rules expected to close, no decomposition but maybe we already know + * the solution when instantiating + */ +case class SolutionBuilderCloser(val osol: Option[Solution] = None) extends SolutionBuilder { + val types = Nil + def apply(sols: List[Solution]) = { + assert(sols.isEmpty) + osol + } } +/** + * Results of applying rule instantiations + * + * Can either close meaning a stream of solutions are available (can be empty, + * if it failed) + */ sealed abstract class RuleApplication case class RuleClosed(solutions: Stream[Solution]) extends RuleApplication case class RuleExpanded(sub: List[Problem]) extends RuleApplication @@ -83,90 +122,46 @@ object RuleFailed { def apply(): RuleClosed = RuleClosed(Stream.empty) } +/** + * Rule priorities, which drive the instantiation order. + */ sealed abstract class RulePriority(val v: Int) extends Ordered[RulePriority] { def compare(that: RulePriority) = this.v - that.v } -case object RulePriorityDefault extends RulePriority(2) case object RulePriorityNormalizing extends RulePriority(0) case object RulePriorityHoles extends RulePriority(1) +case object RulePriorityDefault extends RulePriority(2) -object RuleInstantiation { - def immediateDecomp(problem: Problem, - rule: Rule, - sub: List[Problem], - onSuccess: List[Solution] => Option[Solution]): RuleInstantiation = { - - immediateDecomp(problem, rule, sub, onSuccess, rule.name, rule.priority) - } - - def immediateDecomp(problem: Problem, - rule: Rule, - sub: List[Problem], - onSuccess: List[Solution] => Option[Solution], - description: String): RuleInstantiation = { - immediateDecomp(problem, rule, sub, onSuccess, description, rule.priority) - } +/** + * Common utilities used by rules + */ +trait RuleDSL { + this: Rule => - def immediateDecomp(problem: Problem, - rule: Rule, - sub: List[Problem], - onSuccess: List[Solution] => Option[Solution], - description: String, - priority: RulePriority): RuleInstantiation = { - val subTypes = sub.map(p => TupleType(p.xs.map(_.getType))) + def subst(what: Tuple2[Identifier, Expr], in: Expr): Expr = replaceFromIDs(Map(what), in) + def substAll(what: Map[Identifier, Expr], in: Expr): Expr = replaceFromIDs(what, in) - new RuleInstantiation(problem, rule, new SolutionCombiner(sub.size, subTypes, onSuccess), description, priority) { - def apply(sctx: SynthesisContext) = RuleExpanded(sub) - } - } + val forward: List[Solution] => Option[Solution] = { ss => ss.headOption } - def immediateSuccess(problem: Problem, - rule: Rule, - solution: Solution): RuleInstantiation = { - immediateSuccess(problem, rule, solution, rule.priority) + def decomp(sub: List[Problem], onSuccess: List[Solution] => Option[Solution], description: String) + (implicit problem: Problem): RuleInstantiation = { - } + val subTypes = sub.map(_.outType) - def immediateSuccess(problem: Problem, - rule: Rule, - solution: Solution, - priority: RulePriority): RuleInstantiation = { - new RuleInstantiation(problem, rule, new SolutionCombiner(0, Seq(), ls => Some(solution)), "Solve with "+solution, priority) { - def apply(sctx: SynthesisContext) = RuleClosed(solution) + new RuleInstantiation(description, + SolutionBuilderDecomp(subTypes, onSuccess)) { + def apply(hctx: SearchContext) = RuleExpanded(sub) } } -} -abstract class Rule(val name: String) extends RuleHelpers { - def instantiateOn(sctx: SynthesisContext, problem: Problem): Traversable[RuleInstantiation] - - val priority: RulePriority = RulePriorityDefault - - implicit val debugSection = leon.utils.DebugSectionSynthesis - - override def toString = "R: "+name -} - -abstract class NormalizingRule(name: String) extends Rule(name) { - override val priority = RulePriorityNormalizing -} - -trait RuleHelpers { - def subst(what: Tuple2[Identifier, Expr], in: Expr): Expr = replaceFromIDs(Map(what), in) - def substAll(what: Map[Identifier, Expr], in: Expr): Expr = replaceFromIDs(what, in) - - val forward: List[Solution] => Option[Solution] = { ss => ss.headOption } - - def project(firstN: Int): List[Solution] => Option[Solution] = { - project(0 until firstN) - } + def solve(sol: Solution) + (implicit problem: Problem): RuleInstantiation = { + new RuleInstantiation(s"Solve: $sol", + SolutionBuilderCloser(Some(sol))) { + def apply(hctx: SearchContext) = RuleClosed(sol) + } - def project(ids: Seq[Int]): List[Solution] => Option[Solution] = { - case List(s) => - Some(s.project(ids)) - case _ => - None } } diff --git a/src/main/scala/leon/synthesis/SearchContext.scala b/src/main/scala/leon/synthesis/SearchContext.scala new file mode 100644 index 000000000..39f85eaee --- /dev/null +++ b/src/main/scala/leon/synthesis/SearchContext.scala @@ -0,0 +1,32 @@ +/* Copyright 2009-2014 EPFL, Lausanne */ + +package leon +package synthesis + +import graph._ + +/** + * This is context passed down rules, and include search-wise context, as well + * as current search location information + */ +case class SearchContext ( + sctx: SynthesisContext, + currentNode: Node, + search: Search +) { + val context = sctx.context + val reporter = sctx.reporter + val program = sctx.program + + + def searchDepth = { + def depthOf(n: Node): Int = n.parent match { + case Some(n2) => 1+depthOf(n2) + case None => 0 + } + + depthOf(currentNode) + } + + def parentNode: Option[Node] = currentNode.parent +} diff --git a/src/main/scala/leon/synthesis/Solution.scala b/src/main/scala/leon/synthesis/Solution.scala index 4da0c9dc7..19f44bb33 100644 --- a/src/main/scala/leon/synthesis/Solution.scala +++ b/src/main/scala/leon/synthesis/Solution.scala @@ -44,7 +44,7 @@ class Solution(val pre: Expr, val defs: Set[FunDef], val term: Expr, val isTrust term.getType match { case TupleType(ts) => val t = FreshIdentifier("t", true).setType(term.getType) - val newTerm = Let(t, term, Tuple(indices.map(i => TupleSelect(t.toVariable, i+1)))) + val newTerm = Let(t, term, tupleWrap(indices.map(i => TupleSelect(t.toVariable, i+1)))) Solution(pre, defs, newTerm) case _ => @@ -77,14 +77,15 @@ object Solution { // Generate the simplest, wrongest solution, used for complexity lowerbound def basic(p: Problem): Solution = { - new Solution(BooleanLiteral(true), Set(), Tuple(p.xs.map(id => simplestValue(id.getType)))) + simplest(p.outType) } def simplest(t: TypeTree): Solution = { new Solution(BooleanLiteral(true), Set(), simplestValue(t)) } - def failed(p: Problem): Solution = { - new Solution(BooleanLiteral(true), Set(), Error(TupleType(p.xs.map(_.getType)), "Failed")) + def UNSAT(implicit p: Problem): Solution = { + val tpe = tupleTypeWrap(p.xs.map(_.getType)) + Solution(BooleanLiteral(false), Set(), Error(tpe, p.phi+" is UNSAT!")) } } diff --git a/src/main/scala/leon/synthesis/SynthesisContext.scala b/src/main/scala/leon/synthesis/SynthesisContext.scala index c18153ccf..b6e74e2f2 100644 --- a/src/main/scala/leon/synthesis/SynthesisContext.scala +++ b/src/main/scala/leon/synthesis/SynthesisContext.scala @@ -13,6 +13,9 @@ import purescala.Common.Identifier import java.util.concurrent.atomic.AtomicBoolean +/** + * This is global information per entire search, contains necessary information + */ case class SynthesisContext( context: LeonContext, settings: SynthesisSettings, diff --git a/src/main/scala/leon/synthesis/SynthesisSettings.scala b/src/main/scala/leon/synthesis/SynthesisSettings.scala index c6a1fb1b6..90713c20b 100644 --- a/src/main/scala/leon/synthesis/SynthesisSettings.scala +++ b/src/main/scala/leon/synthesis/SynthesisSettings.scala @@ -15,7 +15,7 @@ case class SynthesisSettings( firstOnly: Boolean = false, timeoutMs: Option[Long] = None, costModel: CostModel = CostModels.default, - rules: Seq[Rule] = Rules.all ++ Heuristics.all, + rules: Seq[Rule] = Rules.all, manualSearch: Boolean = false, searchBound: Option[Int] = None, selectedSolvers: Set[String] = Set("fairz3"), diff --git a/src/main/scala/leon/synthesis/Synthesizer.scala b/src/main/scala/leon/synthesis/Synthesizer.scala index cd98f703b..cca8d8177 100644 --- a/src/main/scala/leon/synthesis/Synthesizer.scala +++ b/src/main/scala/leon/synthesis/Synthesizer.scala @@ -4,12 +4,12 @@ package leon package synthesis import purescala.Common._ -import purescala.Definitions.{Program, FunDef, ModuleDef, DefType} +import purescala.Definitions.{Program, FunDef, ModuleDef, DefType, ValDef} import purescala.TreeOps._ import purescala.Trees._ import purescala.Constructors._ import purescala.ScalaPrinter - +import purescala.TypeTrees._ import solvers._ import solvers.combinators._ import solvers.z3._ @@ -99,17 +99,21 @@ class Synthesizer(val context : LeonContext, // Returns the new program and the new functions generated for this def solutionToProgram(sol: Solution): (Program, List[FunDef]) = { - import purescala.TypeTrees.TupleType - import purescala.Definitions.ValDef // Create new fundef for the body - val ret = TupleType(problem.xs.map(_.getType)) + val ret = tupleTypeWrap(problem.xs.map(_.getType)) val res = Variable(FreshIdentifier("res").setType(ret)) val mapPost: Map[Expr, Expr] = - problem.xs.zipWithIndex.map{ case (id, i) => - Variable(id) -> TupleSelect(res, i+1) - }.toMap + if (problem.xs.size > 1) { + problem.xs.zipWithIndex.map{ case (id, i) => + Variable(id) -> tupleSelect(res, i+1) + }.toMap + } else { + problem.xs.map{ case id => + Variable(id) -> res + }.toMap + } val fd = new FunDef(FreshIdentifier(functionContext.id.name+"_final", true), Nil, ret, problem.as.map(id => ValDef(id, id.getType)), DefType.MethodDef) fd.precondition = Some(and(problem.pc, sol.pre)) diff --git a/src/main/scala/leon/synthesis/graph/DotGenerator.scala b/src/main/scala/leon/synthesis/graph/DotGenerator.scala index dbbb70d1e..4694044cc 100644 --- a/src/main/scala/leon/synthesis/graph/DotGenerator.scala +++ b/src/main/scala/leon/synthesis/graph/DotGenerator.scala @@ -87,7 +87,7 @@ class DotGenerator(g: Graph) { val color = if (n.isSolved) { "palegreen" - } else if (n.isClosed) { + } else if (n.isDeadEnd) { "firebrick" } else if(n.isExpanded) { "grey80" diff --git a/src/main/scala/leon/synthesis/graph/Graph.scala b/src/main/scala/leon/synthesis/graph/Graph.scala index 9e384d922..76e161b1b 100644 --- a/src/main/scala/leon/synthesis/graph/Graph.scala +++ b/src/main/scala/leon/synthesis/graph/Graph.scala @@ -9,7 +9,7 @@ sealed class Graph(val cm: CostModel, problem: Problem) { // Returns closed/total def getStats(from: Node = root): (Int, Int) = { - val isClosed = from.isClosed || from.isSolved + val isClosed = from.isDeadEnd || from.isSolved val self = (if (isClosed) 1 else 0, 1) if (!from.isExpanded) { @@ -24,19 +24,17 @@ sealed class Graph(val cm: CostModel, problem: Problem) { } } -sealed abstract class Node(cm: CostModel, parent: Option[Node]) { - var parents: List[Node] = parent.toList +sealed abstract class Node(cm: CostModel, val parent: Option[Node]) { + var parents: List[Node] = parent.toList var descendents: List[Node] = Nil // indicates whether this particular node has already been expanded var isExpanded: Boolean = false - def expand(sctx: SynthesisContext) + def expand(hctx: SearchContext) val p: Problem var isSolved: Boolean = false - - def onSolved(desc: Node) // Solutions this terminal generates (!= None for terminals) @@ -46,8 +44,8 @@ sealed abstract class Node(cm: CostModel, parent: Option[Node]) { // Costs var cost: Cost = computeCost() - def isClosed: Boolean = { - cost.isImpossible + def isDeadEnd: Boolean = { + cm.isImpossible(cost) } // For non-terminals, selected childs for solution @@ -67,7 +65,11 @@ sealed abstract class Node(cm: CostModel, parent: Option[Node]) { cm.impossible case Some(sols) => - sols.map { sol => cm.solution(sol) } .min + if (sols.hasDefiniteSize) { + sols.map { sol => cm.solution(sol) } .min + } else { + cm.solution(sols.head) + } case None => val costs = if (isExpanded) { @@ -96,17 +98,17 @@ class AndNode(cm: CostModel, parent: Option[Node], val ri: RuleInstantiation) ex override def toString = "\u2227 "+ri; - def expand(sctx: SynthesisContext): Unit = { + def expand(hctx: SearchContext): Unit = { require(!isExpanded) isExpanded = true - import sctx.reporter.info + import hctx.sctx.reporter.info val prefix = "[%-20s] ".format(Option(ri.rule).getOrElse("?")) info(prefix+ri.problem) - ri.apply(sctx) match { + ri.apply(hctx) match { case RuleClosed(sols) => solutions = Some(sols) selectedSolution = 0; @@ -164,20 +166,19 @@ class AndNode(cm: CostModel, parent: Option[Node], val ri: RuleInstantiation) ex } -class OrNode(cm: CostModel, val parent: Option[Node], val p: Problem) extends Node(cm, parent) { +class OrNode(cm: CostModel, parent: Option[Node], val p: Problem) extends Node(cm, parent) { override def toString = "\u2228 "+p; - def getInstantiations(sctx: SynthesisContext): List[RuleInstantiation] = { - - val rules = cm.rulesFor(sctx, this) + def getInstantiations(hctx: SearchContext): List[RuleInstantiation] = { + val rules = hctx.sctx.rules val rulesPrio = rules.groupBy(_.priority).toSeq.sortBy(_._1) for ((_, rs) <- rulesPrio) { val results = rs.flatMap{ r => - sctx.context.timers.synthesis.instantiations.get(r.toString).timed { - r.instantiateOn(sctx, p) + hctx.context.timers.synthesis.instantiations.get(r.toString).timed { + r.instantiateOn(hctx, p) } }.toList @@ -188,10 +189,10 @@ class OrNode(cm: CostModel, val parent: Option[Node], val p: Problem) extends No Nil } - def expand(sctx: SynthesisContext): Unit = { + def expand(hctx: SearchContext): Unit = { require(!isExpanded) - val ris = getInstantiations(sctx) + val ris = getInstantiations(hctx) descendents = ris.map(ri => new AndNode(cm, Some(this), ri)) selected = List() diff --git a/src/main/scala/leon/synthesis/graph/Search.scala b/src/main/scala/leon/synthesis/graph/Search.scala index 4877c475c..b9bfa13cb 100644 --- a/src/main/scala/leon/synthesis/graph/Search.scala +++ b/src/main/scala/leon/synthesis/graph/Search.scala @@ -22,11 +22,11 @@ abstract class Search(ctx: LeonContext, p: Problem, costModel: CostModel) extend n match { case an: AndNode => ctx.timers.synthesis.applications.get(an.ri.toString).timed { - an.expand(sctx) + an.expand(SearchContext(sctx, an, this)) } case on: OrNode => - on.expand(sctx) + on.expand(SearchContext(sctx, on, this)) } } } @@ -89,7 +89,7 @@ class SimpleSearch(ctx: LeonContext, p: Problem, costModel: CostModel, bound: Op def findIn(n: Node) { if (!n.isExpanded) { expansionBuffer += n - } else if (!n.isClosed) { + } else if (!n.isDeadEnd) { n match { case an: AndNode => an.descendents.foreach(findIn) @@ -184,14 +184,14 @@ class ManualSearch(ctx: LeonContext, problem: Problem, costModel: CostModel) ext super.doStep(n, sctx); // Backtrack view to a point where node is neither closed nor solved - if (n.isClosed || n.isSolved) { + if (n.isDeadEnd || n.isSolved) { var from: Node = g.root var newCd = List[Int]() - while (!from.isSolved && !from.isClosed && newCd.size < cd.size) { + while (!from.isSolved && !from.isDeadEnd && newCd.size < cd.size) { val cdElem = cd(newCd.size) from = traversePathFrom(from, List(cdElem)).get - if (!from.isSolved && !from.isClosed) { + if (!from.isSolved && !from.isDeadEnd) { newCd = cdElem :: newCd } } @@ -233,7 +233,7 @@ class ManualSearch(ctx: LeonContext, problem: Problem, costModel: CostModel) ext if (sn.isSolved) { println(solved(pathToString(sp)+" \u2508 "+displayNode(sn))) - } else if (sn.isClosed) { + } else if (sn.isDeadEnd) { println(failed(pathToString(sp)+" \u2508 "+displayNode(sn))) } else if (sn.isExpanded) { println(expanded(pathToString(sp)+" \u2508 "+displayNode(sn))) diff --git a/src/main/scala/leon/synthesis/rules/ADTDual.scala b/src/main/scala/leon/synthesis/rules/ADTDual.scala index 94a4fc1f4..f99c8916e 100644 --- a/src/main/scala/leon/synthesis/rules/ADTDual.scala +++ b/src/main/scala/leon/synthesis/rules/ADTDual.scala @@ -10,7 +10,7 @@ import purescala.Extractors._ import purescala.Constructors._ case object ADTDual extends NormalizingRule("ADTDual") { - def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { + def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { val xs = p.xs.toSet val as = p.as.toSet @@ -28,9 +28,9 @@ case object ADTDual extends NormalizingRule("ADTDual") { if (!toRemove.isEmpty) { val sub = p.copy(phi = andJoin((exprs.toSet -- toRemove ++ toAdd.flatten).toSeq)) - List(RuleInstantiation.immediateDecomp(p, this, List(sub), forward, "ADTDual")) + Some(decomp(List(sub), forward, "ADTDual")) } else { - Nil + None } } } diff --git a/src/main/scala/leon/synthesis/heuristics/ADTInduction.scala b/src/main/scala/leon/synthesis/rules/ADTInduction.scala similarity index 89% rename from src/main/scala/leon/synthesis/heuristics/ADTInduction.scala rename to src/main/scala/leon/synthesis/rules/ADTInduction.scala index 1e544d8e6..1f1091786 100644 --- a/src/main/scala/leon/synthesis/heuristics/ADTInduction.scala +++ b/src/main/scala/leon/synthesis/rules/ADTInduction.scala @@ -2,7 +2,7 @@ package leon package synthesis -package heuristics +package rules import solvers._ import purescala.Common._ @@ -13,17 +13,17 @@ import purescala.TreeOps._ import purescala.TypeTrees._ import purescala.Definitions._ -case object ADTInduction extends Rule("ADT Induction") with Heuristic { - def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { +case object ADTInduction extends Rule("ADT Induction") { + def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { val candidates = p.as.collect { - case IsTyped(origId, act: AbstractClassType) if isInductiveOn(sctx.solverFactory)(p.pc, origId) => (origId, act) + case IsTyped(origId, act: AbstractClassType) if isInductiveOn(hctx.sctx.solverFactory)(p.pc, origId) => (origId, act) } val instances = for (candidate <- candidates) yield { val (origId, ct) = candidate val oas = p.as.filterNot(_ == origId) - val resType = TupleType(p.xs.map(_.getType)) + val resType = tupleTypeWrap(p.xs.map(_.getType)) val inductOn = FreshIdentifier(origId.name, true).setType(origId.getType) val residualArgs = oas.map(id => FreshIdentifier(id.name, true).setType(id.getType)) @@ -86,7 +86,7 @@ case object ADTInduction extends Rule("ADT Induction") with Heuristic { globalPre ::= and(pre, sol.pre) val caze = CaseClassPattern(None, cct, ids.map(id => WildcardPattern(Some(id)))) - SimpleCase(caze, calls.foldLeft(sol.term){ case (t, (binders, callargs)) => LetTuple(binders, FunctionInvocation(newFun.typed, callargs), t) }) + SimpleCase(caze, calls.foldLeft(sol.term){ case (t, (binders, callargs)) => letTuple(binders, FunctionInvocation(newFun.typed, callargs), t) }) } // Might be overly picky with obviously true pre (a.is[Cons] OR a.is[Nil]) @@ -102,7 +102,7 @@ case object ADTInduction extends Rule("ADT Induction") with Heuristic { val outerPre = orJoin(globalPre) newFun.precondition = Some(funPre) - newFun.postcondition = Some((idPost, LetTuple(p.xs.toSeq, Variable(idPost), funPost))) + newFun.postcondition = Some((idPost, letTuple(p.xs.toSeq, Variable(idPost), funPost))) newFun.body = Some(matchExpr(Variable(inductOn), cases)) @@ -114,7 +114,7 @@ case object ADTInduction extends Rule("ADT Induction") with Heuristic { } } - Some(RuleInstantiation.immediateDecomp(p, this, subProblemsInfo.map(_._1).toList, onSuccess, "ADT Induction on '"+origId+"'")) + Some(decomp(subProblemsInfo.map(_._1).toList, onSuccess, s"ADT Induction on '$origId'")) } else { None } diff --git a/src/main/scala/leon/synthesis/heuristics/ADTLongInduction.scala b/src/main/scala/leon/synthesis/rules/ADTLongInduction.scala similarity index 92% rename from src/main/scala/leon/synthesis/heuristics/ADTLongInduction.scala rename to src/main/scala/leon/synthesis/rules/ADTLongInduction.scala index 36e642990..1973dd244 100644 --- a/src/main/scala/leon/synthesis/heuristics/ADTLongInduction.scala +++ b/src/main/scala/leon/synthesis/rules/ADTLongInduction.scala @@ -2,7 +2,7 @@ package leon package synthesis -package heuristics +package rules import solvers._ import purescala.Common._ @@ -13,10 +13,10 @@ import purescala.TreeOps._ import purescala.TypeTrees._ import purescala.Definitions._ -case object ADTLongInduction extends Rule("ADT Long Induction") with Heuristic { - def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { +case object ADTLongInduction extends Rule("ADT Long Induction") { + def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { val candidates = p.as.collect { - case IsTyped(origId, act @ AbstractClassType(cd, tpe)) if isInductiveOn(sctx.solverFactory)(p.pc, origId) => (origId, act) + case IsTyped(origId, act @ AbstractClassType(cd, tpe)) if isInductiveOn(hctx.sctx.solverFactory)(p.pc, origId) => (origId, act) } val instances = for (candidate <- candidates) yield { @@ -24,7 +24,7 @@ case object ADTLongInduction extends Rule("ADT Long Induction") with Heuristic { val oas = p.as.filterNot(_ == origId) - val resType = TupleType(p.xs.map(_.getType)) + val resType = tupleTypeWrap(p.xs.map(_.getType)) val inductOn = FreshIdentifier(origId.name, true).setType(origId.getType) val residualArgs = oas.map(id => FreshIdentifier(id.name, true).setType(id.getType)) @@ -135,7 +135,7 @@ case object ADTLongInduction extends Rule("ADT Long Induction") with Heuristic { val cases = for ((sol, (problem, pat, calls, pc)) <- (sols zip subProblemsInfo)) yield { globalPre ::= and(pc, sol.pre) - SimpleCase(pat, calls.foldLeft(sol.term){ case (t, (binders, callargs)) => LetTuple(binders, FunctionInvocation(newFun.typed, callargs), t) }) + SimpleCase(pat, calls.foldLeft(sol.term){ case (t, (binders, callargs)) => letTuple(binders, FunctionInvocation(newFun.typed, callargs), t) }) } // Might be overly picky with obviously true pre (a.is[Cons] OR a.is[Nil]) @@ -151,7 +151,7 @@ case object ADTLongInduction extends Rule("ADT Long Induction") with Heuristic { val outerPre = orJoin(globalPre) newFun.precondition = Some(funPre) - newFun.postcondition = Some((idPost, LetTuple(p.xs.toSeq, Variable(idPost), funPost))) + newFun.postcondition = Some((idPost, letTuple(p.xs.toSeq, Variable(idPost), funPost))) newFun.body = Some(matchExpr(Variable(inductOn), cases)) @@ -163,7 +163,7 @@ case object ADTLongInduction extends Rule("ADT Long Induction") with Heuristic { } } - Some(RuleInstantiation.immediateDecomp(p, this, subProblemsInfo.map(_._1).toList, onSuccess, "ADT Long Induction on '"+origId+"'")) + Some(decomp(subProblemsInfo.map(_._1).toList, onSuccess, s"ADT Long Induction on '$origId'")) } else { None } diff --git a/src/main/scala/leon/synthesis/rules/ADTSplit.scala b/src/main/scala/leon/synthesis/rules/ADTSplit.scala index c4a259d76..482f40564 100644 --- a/src/main/scala/leon/synthesis/rules/ADTSplit.scala +++ b/src/main/scala/leon/synthesis/rules/ADTSplit.scala @@ -14,8 +14,8 @@ import purescala.Definitions._ import solvers._ case object ADTSplit extends Rule("ADT Split.") { - def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { - val solver = SimpleSolverAPI(new TimeoutSolverFactory(sctx.solverFactory, 200L)) + def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { + val solver = SimpleSolverAPI(new TimeoutSolverFactory(hctx.sctx.solverFactory, 200L)) val candidates = p.as.collect { case IsTyped(id, act @ AbstractClassType(cd, tpes)) => @@ -80,7 +80,7 @@ case object ADTSplit extends Rule("ADT Split.") { Some(Solution(orJoin(globalPre), sols.flatMap(_.defs).toSet, matchExpr(Variable(id), cases), sols.forall(_.isTrusted))) } - Some(RuleInstantiation.immediateDecomp(p, this, subInfo.map(_._2).toList, onSuccess, "ADT Split on '"+id+"'")) + Some(decomp(subInfo.map(_._2).toList, onSuccess, s"ADT Split on '$id'")) case _ => None }}.flatten diff --git a/src/main/scala/leon/synthesis/rules/AsChoose.scala b/src/main/scala/leon/synthesis/rules/AsChoose.scala index 6d71c5341..aea600fb2 100644 --- a/src/main/scala/leon/synthesis/rules/AsChoose.scala +++ b/src/main/scala/leon/synthesis/rules/AsChoose.scala @@ -5,12 +5,8 @@ package synthesis package rules case object AsChoose extends Rule("As Choose") { - def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { - Some(new RuleInstantiation(p, this, SolutionBuilder.none, this.name, this.priority) { - def apply(sctx: SynthesisContext) = { - RuleClosed(Solution.choose(p)) - } - }) + def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { + Some(solve(Solution.choose(p))) } } diff --git a/src/main/scala/leon/synthesis/rules/Assert.scala b/src/main/scala/leon/synthesis/rules/Assert.scala index eb290799e..1671cc450 100644 --- a/src/main/scala/leon/synthesis/rules/Assert.scala +++ b/src/main/scala/leon/synthesis/rules/Assert.scala @@ -10,7 +10,7 @@ import purescala.Extractors._ import purescala.Constructors._ case object Assert extends NormalizingRule("Assert") { - def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { + def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { p.phi match { case TopLevelAnds(exprs) => val xsSet = p.xs.toSet @@ -19,20 +19,20 @@ case object Assert extends NormalizingRule("Assert") { if (!exprsA.isEmpty) { if (others.isEmpty) { - List(RuleInstantiation.immediateSuccess(p, this, Solution(andJoin(exprsA), Set(), Tuple(p.xs.map(id => simplestValue(id.getType)))))) + Some(solve(Solution(andJoin(exprsA), Set(), tupleWrap(p.xs.map(id => simplestValue(id.getType)))))) } else { val sub = p.copy(pc = andJoin(p.pc +: exprsA), phi = andJoin(others)) - List(RuleInstantiation.immediateDecomp(p, this, List(sub), { + Some(decomp(List(sub), { case (s @ Solution(pre, defs, term)) :: Nil => Some(Solution(andJoin(exprsA :+ pre), defs, term, s.isTrusted)) case _ => None }, "Assert "+andJoin(exprsA))) } } else { - Nil + None } case _ => - Nil + None } } } diff --git a/src/main/scala/leon/synthesis/rules/BottomUpTegis.scala b/src/main/scala/leon/synthesis/rules/BottomUpTegis.scala index 33947304a..75846d3ff 100644 --- a/src/main/scala/leon/synthesis/rules/BottomUpTegis.scala +++ b/src/main/scala/leon/synthesis/rules/BottomUpTegis.scala @@ -41,16 +41,17 @@ abstract class BottomUpTEGISLike[T <% Typed](name: String) extends Rule(name) { def getRootLabel(tpe: TypeTree): T - def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { - val ef = new ExamplesFinder(sctx.context, sctx.program) + def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { + val ef = new ExamplesFinder(hctx.context, hctx.program) var tests = ef.extractTests(p).collect { case io: InOutExample => (io.ins, io.outs) } if (tests.nonEmpty) { - List(new RuleInstantiation(p, this, SolutionBuilder.none, this.name, this.priority) { - def apply(sctx: SynthesisContext): RuleApplication = { + List(new RuleInstantiation(this.name) { + def apply(hctx: SearchContext): RuleApplication = { + val sctx = hctx.sctx val evalParams = CodeGenParams(maxFunctionInvocations = 2000, checkContracts = true) //val evaluator = new CodeGenEvaluator(sctx.context, sctx.program, evalParams) @@ -117,11 +118,8 @@ abstract class BottomUpTEGISLike[T <% Typed](name: String) extends Rule(name) { }) } - val (targetType, isWrapped, wrappedTests) = if (p.xs.size == 1) { - (p.xs.head.getType, false, tests.map{ case (i, o) => (i, o.head) }) - } else { - (TupleType(p.xs.map(_.getType)), true, tests.map{ case (i, o) => (i, Tuple(o)) }) - } + val targetType = tupleTypeWrap(p.xs.map(_.getType)) + val wrappedTests = tests.map { case (is, os) => (is, tupleWrap(os))} val enum = new BottomUpEnumerator[T, Expr, Expr]( grammar.getProductions, diff --git a/src/main/scala/leon/synthesis/rules/CaseSplit.scala b/src/main/scala/leon/synthesis/rules/CaseSplit.scala index 740ef8fd2..b09dbd703 100644 --- a/src/main/scala/leon/synthesis/rules/CaseSplit.scala +++ b/src/main/scala/leon/synthesis/rules/CaseSplit.scala @@ -10,16 +10,16 @@ import purescala.Extractors._ import purescala.Constructors._ case object CaseSplit extends Rule("Case-Split") { - def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { + def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { p.phi match { case Or(os) => - List(split(os, p, "Split top-level Or")) + List(split(os, "Split top-level Or")) case _ => Nil } } - def split(alts: Seq[Expr], p: Problem, description: String): RuleInstantiation = { + def split(alts: Seq[Expr], description: String)(implicit p: Problem): RuleInstantiation = { val subs = alts.map(a => Problem(p.as, p.ws, p.pc, a, p.xs)).toList val onSuccess: List[Solution] => Option[Solution] = { @@ -37,7 +37,7 @@ case object CaseSplit extends Rule("Case-Split") { None } - RuleInstantiation.immediateDecomp(p, this, subs, onSuccess, description) + decomp(subs, onSuccess, description) } } diff --git a/src/main/scala/leon/synthesis/rules/CegisLike.scala b/src/main/scala/leon/synthesis/rules/CegisLike.scala index 2df78babb..8daacad5a 100644 --- a/src/main/scala/leon/synthesis/rules/CegisLike.scala +++ b/src/main/scala/leon/synthesis/rules/CegisLike.scala @@ -39,7 +39,9 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { def getParams(sctx: SynthesisContext, p: Problem): CegisParams - def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { + def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { + + val sctx = hctx.sctx // CEGIS Flags to actiave or de-activate features val useUninterpretedProbe = sctx.settings.cegisUseUninterpretedProbe @@ -262,7 +264,7 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { substAll(map.toMap, cClauses(c)) } - Tuple(p.xs.map(c => getCValue(c))) + tupleWrap(p.xs.map(c => getCValue(c))) } @@ -414,9 +416,10 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { def css : Set[Identifier] = mappings.values.map(_._1).toSet ++ guardedTerms.flatMap(_._2) } - List(new RuleInstantiation(p, this, SolutionBuilder.none, this.name, this.priority) { - def apply(sctx: SynthesisContext): RuleApplication = { + List(new RuleInstantiation(this.name) { + def apply(hctx: SearchContext): RuleApplication = { var result: Option[RuleApplication] = None + val sctx = hctx.sctx var ass = p.as.toSet var xss = p.xs.toSet @@ -509,7 +512,7 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { def checkForPrograms(programs: Set[Set[Identifier]]): RuleApplication = { for (prog <- programs) { val expr = ndProgram.determinize(prog) - val res = Equals(Tuple(p.xs.map(Variable(_))), expr) + val res = Equals(tupleWrap(p.xs.map(Variable(_))), expr) val solver3 = sctx.newSolver.setTimeout(cexSolverTo) solver3.assertCnstr(and(pc, res, not(p.phi))) @@ -547,7 +550,7 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { var didFilterAlready = false - val tpe = TupleType(p.xs.map(_.getType)) + val tpe = tupleTypeWrap(p.xs.map(_.getType)) try { do { diff --git a/src/main/scala/leon/synthesis/rules/DetupleInput.scala b/src/main/scala/leon/synthesis/rules/DetupleInput.scala index aea676d18..06c5b5d18 100644 --- a/src/main/scala/leon/synthesis/rules/DetupleInput.scala +++ b/src/main/scala/leon/synthesis/rules/DetupleInput.scala @@ -10,10 +10,11 @@ import purescala.Common._ import purescala.TypeTrees._ import purescala.TreeOps._ import purescala.Extractors._ +import purescala.Constructors._ case object DetupleInput extends NormalizingRule("Detuple In") { - def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { + def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { def isDecomposable(id: Identifier) = id.getType match { case CaseClassType(t, _) if !t.isAbstract => true case TupleType(ts) => true @@ -33,7 +34,7 @@ case object DetupleInput extends NormalizingRule("Detuple In") { val map = (newIds.zipWithIndex).map{ case (nid, i) => nid -> TupleSelect(Variable(id), i+1) }.toMap - (newIds.toList, Tuple(newIds.map(Variable(_))), map) + (newIds.toList, tupleWrap(newIds.map(Variable(_))), map) case _ => sys.error("woot") } @@ -45,7 +46,7 @@ case object DetupleInput extends NormalizingRule("Detuple In") { var reverseMap = Map[Identifier, Expr]() - val (subAs, outerAs) = p.as.map { a => + val subAs = p.as.map { a => if (isDecomposable(a)) { val (newIds, expr, map) = decompose(a) @@ -55,11 +56,11 @@ case object DetupleInput extends NormalizingRule("Detuple In") { reverseMap ++= map - (newIds, expr) + newIds } else { - (List(a), Variable(a)) + List(a) } - }.unzip + } val newAs = subAs.flatten //sctx.reporter.warning("newOuts: " + newOuts.toString) @@ -76,9 +77,9 @@ case object DetupleInput extends NormalizingRule("Detuple In") { } - Some(RuleInstantiation.immediateDecomp(p, this, List(sub), onSuccess, this.name)) + Some(decomp(List(sub), onSuccess, s"Detuple ${reverseMap.keySet.mkString(", ")}")) } else { - Nil + None } } } diff --git a/src/main/scala/leon/synthesis/rules/DetupleOutput.scala b/src/main/scala/leon/synthesis/rules/DetupleOutput.scala index 5b5e7578e..66dfe0195 100644 --- a/src/main/scala/leon/synthesis/rules/DetupleOutput.scala +++ b/src/main/scala/leon/synthesis/rules/DetupleOutput.scala @@ -14,7 +14,7 @@ import purescala.Constructors._ case object DetupleOutput extends Rule("Detuple Out") { - def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { + def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { def isDecomposable(id: Identifier) = id.getType match { case CaseClassType(t, _) if !t.isAbstract => true case _ => false @@ -45,15 +45,14 @@ case object DetupleOutput extends Rule("Detuple Out") { val onSuccess: List[Solution] => Option[Solution] = { case List(sol) => - Some(Solution(sol.pre, sol.defs, letTuple(newOuts, sol.term, Tuple(outerOuts)), sol.isTrusted)) + Some(Solution(sol.pre, sol.defs, letTuple(newOuts, sol.term, tupleWrap(outerOuts)), sol.isTrusted)) case _ => None } - - Some(RuleInstantiation.immediateDecomp(p, this, List(sub), onSuccess, this.name)) + Some(decomp(List(sub), onSuccess, s"Detuple out ${p.xs.filter(isDecomposable).mkString(", ")}")) } else { - Nil + None } } } diff --git a/src/main/scala/leon/synthesis/rules/Disunification.scala b/src/main/scala/leon/synthesis/rules/Disunification.scala index ddd7eeca1..53c69e252 100644 --- a/src/main/scala/leon/synthesis/rules/Disunification.scala +++ b/src/main/scala/leon/synthesis/rules/Disunification.scala @@ -12,7 +12,7 @@ import purescala.Constructors._ object Disunification { case object Decomp extends Rule("Disunif. Decomp.") { - def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { + def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { val TopLevelAnds(exprs) = p.phi val (toRemove, toAdd) = exprs.collect { @@ -29,9 +29,9 @@ object Disunification { if (!toRemove.isEmpty) { val sub = p.copy(phi = orJoin((exprs.toSet -- toRemove ++ toAdd.flatten).toSeq)) - List(RuleInstantiation.immediateDecomp(p, this, List(sub), forward, this.name)) + Some(decomp(List(sub), forward, this.name)) } else { - Nil + None } } } diff --git a/src/main/scala/leon/synthesis/rules/EqualitySplit.scala b/src/main/scala/leon/synthesis/rules/EqualitySplit.scala index 41a995da9..0a3be81d4 100644 --- a/src/main/scala/leon/synthesis/rules/EqualitySplit.scala +++ b/src/main/scala/leon/synthesis/rules/EqualitySplit.scala @@ -14,8 +14,8 @@ import purescala.Constructors._ import solvers._ case object EqualitySplit extends Rule("Eq. Split") { - def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { - val solver = SimpleSolverAPI(sctx.fastSolverFactory) + def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { + val solver = SimpleSolverAPI(hctx.sctx.fastSolverFactory) val candidates = p.as.groupBy(_.getType).mapValues(_.combinations(2).filter { case List(a1, a2) => @@ -54,7 +54,7 @@ case object EqualitySplit extends Rule("Eq. Split") { None } - Some(RuleInstantiation.immediateDecomp(p, this, List(sub1, sub2), onSuccess, "Eq. Split on '"+a1+"' and '"+a2+"'")) + Some(decomp(List(sub1, sub2), onSuccess, s"Eq. Split on '$a1' and '$a2'")) case _ => None }) diff --git a/src/main/scala/leon/synthesis/rules/EquivalentInputs.scala b/src/main/scala/leon/synthesis/rules/EquivalentInputs.scala index 6b936d66b..b5c5744f6 100644 --- a/src/main/scala/leon/synthesis/rules/EquivalentInputs.scala +++ b/src/main/scala/leon/synthesis/rules/EquivalentInputs.scala @@ -11,7 +11,7 @@ import purescala.Extractors._ import purescala.Constructors._ case object EquivalentInputs extends NormalizingRule("EquivalentInputs") { - def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { + def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { val TopLevelAnds(clauses) = p.pc def discoverEquivalences(allClauses: Seq[Expr]): Seq[(Expr, Expr)] = { @@ -67,13 +67,13 @@ case object EquivalentInputs extends NormalizingRule("EquivalentInputs") { } if (substs.nonEmpty) { - val simplifier = Simplifiers.bestEffort(sctx.context, sctx.program) _ + val simplifier = Simplifiers.bestEffort(hctx.context, hctx.program) _ val sub = p.copy(ws = replaceSeq(substs, p.ws), pc = simplifier(andJoin(replaceSeq(substs, p.pc) +: postsToInject)), phi = simplifier(replaceSeq(substs, p.phi))) - List(RuleInstantiation.immediateDecomp(p, this, List(sub), forward, this.name)) + List(decomp(List(sub), forward, "Equivalent Inputs")) } else { Nil } diff --git a/src/main/scala/leon/synthesis/rules/Ground.scala b/src/main/scala/leon/synthesis/rules/Ground.scala index 9de60847a..43d2ead62 100644 --- a/src/main/scala/leon/synthesis/rules/Ground.scala +++ b/src/main/scala/leon/synthesis/rules/Ground.scala @@ -9,23 +9,23 @@ import purescala.Trees._ import purescala.TypeTrees._ import purescala.TreeOps._ import purescala.Extractors._ +import purescala.Constructors._ case object Ground extends Rule("Ground") { - def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { + def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { if (p.as.isEmpty) { - List(new RuleInstantiation(p, this, SolutionBuilder.none, this.name, this.priority) { - def apply(sctx: SynthesisContext): RuleApplication = { - val solver = SimpleSolverAPI(new TimeoutSolverFactory(sctx.solverFactory, 10000L)) + List(new RuleInstantiation(this.name) { + def apply(hctx: SearchContext): RuleApplication = { + val solver = SimpleSolverAPI(new TimeoutSolverFactory(hctx.sctx.solverFactory, 10000L)) - val tpe = TupleType(p.xs.map(_.getType)) + val tpe = tupleTypeWrap(p.xs.map(_.getType)) val result = solver.solveSAT(p.phi) match { case (Some(true), model) => - val sol = Solution(BooleanLiteral(true), Set(), Tuple(p.xs.map(valuateWithModel(model)))) + val sol = Solution(BooleanLiteral(true), Set(), tupleWrap(p.xs.map(valuateWithModel(model)))) RuleClosed(sol) case (Some(false), model) => - val sol = Solution(BooleanLiteral(false), Set(), Error(tpe, p.phi+" is UNSAT!")) - RuleClosed(sol) + RuleClosed(Solution.UNSAT(p)) case _ => RuleFailed() } diff --git a/src/main/scala/leon/synthesis/rules/IfSplit.scala b/src/main/scala/leon/synthesis/rules/IfSplit.scala index e9e34b80d..dff14901a 100644 --- a/src/main/scala/leon/synthesis/rules/IfSplit.scala +++ b/src/main/scala/leon/synthesis/rules/IfSplit.scala @@ -10,7 +10,7 @@ import purescala.Extractors._ import purescala.Constructors._ case object IfSplit extends Rule("If-Split") { - def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { + def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { val ifs = collect{ case i: IfExpr => Set(i) case _ => Set[IfExpr]() @@ -21,14 +21,14 @@ case object IfSplit extends Rule("If-Split") { ifs.flatMap { case i @ IfExpr(cond, _, _) => if ((variablesOf(cond) & xsSet).isEmpty) { - List(split(i, p, "Split If("+cond+")")) + List(split(i, s"If-Split on '$cond'")) } else { Nil } } } - def split(i: IfExpr, p: Problem, description: String): RuleInstantiation = { + def split(i: IfExpr, description: String)(implicit p: Problem): RuleInstantiation = { val subs = List( Problem(p.as, p.ws, and(p.pc, i.cond), replace(Map(i -> i.thenn), p.phi), p.xs), Problem(p.as, p.ws, and(p.pc, not(i.cond)), replace(Map(i -> i.elze), p.phi), p.xs) @@ -48,7 +48,7 @@ case object IfSplit extends Rule("If-Split") { None } - RuleInstantiation.immediateDecomp(p, this, subs, onSuccess, description) + decomp(subs, onSuccess, description) } } diff --git a/src/main/scala/leon/synthesis/rules/InequalitySplit.scala b/src/main/scala/leon/synthesis/rules/InequalitySplit.scala index ac9bd86e8..54ce95e5e 100644 --- a/src/main/scala/leon/synthesis/rules/InequalitySplit.scala +++ b/src/main/scala/leon/synthesis/rules/InequalitySplit.scala @@ -15,8 +15,8 @@ import purescala.Constructors._ import solvers._ case object InequalitySplit extends Rule("Ineq. Split.") { - def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { - val solver = SimpleSolverAPI(sctx.fastSolverFactory) + def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { + val solver = SimpleSolverAPI(hctx.sctx.fastSolverFactory) val candidates = p.as.filter(_.getType == Int32Type).combinations(2).toList.filter { case List(a1, a2) => @@ -77,7 +77,7 @@ case object InequalitySplit extends Rule("Ineq. Split.") { None } - Some(RuleInstantiation.immediateDecomp(p, this, List(subLT, subEQ, subGT), onSuccess, "Ineq. Split on '"+a1+"' and '"+a2+"'")) + Some(decomp(List(subLT, subEQ, subGT), onSuccess, s"Ineq. Split on '$a1' and '$a2'")) case _ => None }) diff --git a/src/main/scala/leon/synthesis/rules/InlineHoles.scala b/src/main/scala/leon/synthesis/rules/InlineHoles.scala index 97b10158a..ee9379449 100644 --- a/src/main/scala/leon/synthesis/rules/InlineHoles.scala +++ b/src/main/scala/leon/synthesis/rules/InlineHoles.scala @@ -21,8 +21,9 @@ import purescala.Constructors._ case object InlineHoles extends Rule("Inline-Holes") { override val priority = RulePriorityHoles - def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { + def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { // When true: withOracle gets converted into a big choose() on result. + val sctx = hctx.sctx val discreteHoles = sctx.settings.distreteHoles if (!discreteHoles) { @@ -120,7 +121,7 @@ case object InlineHoles extends Rule("Inline-Holes") { val newPhi = simplifyPaths(sfact)(and(pc, p.phi)) val newProblem1 = p.copy(phi = newPhi) - Some(RuleInstantiation.immediateDecomp(p, this, List(newProblem1), { + Some(decomp(List(newProblem1), { case List(s) if (s.pre != BooleanLiteral(false)) => Some(s) case _ => None }, "Avoid Holes")) @@ -133,11 +134,22 @@ case object InlineHoles extends Rule("Inline-Holes") { val (newXs, newPhiInlined) = inlineHoles(newPhi) val newProblem2 = p.copy(phi = newPhiInlined, xs = p.xs ::: newXs) - val rec = Some(RuleInstantiation.immediateDecomp(p, this, List(newProblem2), project(p.xs.size), "Inline Holes")) + val rec = Some(decomp(List(newProblem2), project(p.xs.size), "Inline Holes")) List(rec, avoid).flatten } else { Nil } } + + def project(firstN: Int): List[Solution] => Option[Solution] = { + project(0 until firstN) + } + + def project(ids: Seq[Int]): List[Solution] => Option[Solution] = { + case List(s) => + Some(s.project(ids)) + case _ => + None + } } diff --git a/src/main/scala/leon/synthesis/heuristics/InnerCaseSplit.scala b/src/main/scala/leon/synthesis/rules/InnerCaseSplit.scala similarity index 76% rename from src/main/scala/leon/synthesis/heuristics/InnerCaseSplit.scala rename to src/main/scala/leon/synthesis/rules/InnerCaseSplit.scala index ebfcbf4aa..637bdeeb9 100644 --- a/src/main/scala/leon/synthesis/heuristics/InnerCaseSplit.scala +++ b/src/main/scala/leon/synthesis/rules/InnerCaseSplit.scala @@ -2,15 +2,15 @@ package leon package synthesis -package heuristics +package rules import purescala.Trees._ import purescala.TreeOps._ import purescala.Extractors._ import purescala.Constructors._ -case object InnerCaseSplit extends Rule("Inner-Case-Split") with Heuristic { - def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { +case object InnerCaseSplit extends Rule("Inner-Case-Split"){ + def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { p.phi match { case Or(_) => // Inapplicable in this case, normal case-split has precedence here. @@ -29,13 +29,13 @@ case object InnerCaseSplit extends Rule("Inner-Case-Split") with Heuristic { phi match { case Or(os) => - List(rules.CaseSplit.split(os, p, "Inner case-split")) + List(rules.CaseSplit.split(os, "Inner case-split")) case And(as) => val optapp = for ((a, i) <- as.zipWithIndex) yield { a match { case Or(os) => - Some(rules.CaseSplit.split(os.map(o => andJoin(as.updated(i, o))), p, "Inner case-split")) + Some(rules.CaseSplit.split(os.map(o => andJoin(as.updated(i, o))), "Inner case-split")) case _ => None diff --git a/src/main/scala/leon/synthesis/heuristics/IntInduction.scala b/src/main/scala/leon/synthesis/rules/IntInduction.scala similarity index 84% rename from src/main/scala/leon/synthesis/heuristics/IntInduction.scala rename to src/main/scala/leon/synthesis/rules/IntInduction.scala index 8c593b54a..ae0567bd3 100644 --- a/src/main/scala/leon/synthesis/heuristics/IntInduction.scala +++ b/src/main/scala/leon/synthesis/rules/IntInduction.scala @@ -2,7 +2,7 @@ package leon package synthesis -package heuristics +package rules import purescala.Common._ import purescala.Trees._ @@ -12,11 +12,11 @@ import purescala.TreeOps._ import purescala.TypeTrees._ import purescala.Definitions._ -case object IntInduction extends Rule("Int Induction") with Heuristic { - def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { +case object IntInduction extends Rule("Int Induction") { + def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { p.as match { case List(IsTyped(origId, Int32Type)) => - val tpe = TupleType(p.xs.map(_.getType)) + val tpe = tupleTypeWrap(p.xs.map(_.getType)) val inductOn = FreshIdentifier(origId.name, true).setType(origId.getType) @@ -50,14 +50,14 @@ case object IntInduction extends Rule("Int Induction") with Heuristic { val idPost = FreshIdentifier("res").setType(tpe) newFun.precondition = Some(preIn) - newFun.postcondition = Some((idPost, LetTuple(p.xs.toSeq, Variable(idPost), p.phi))) + newFun.postcondition = Some((idPost, letTuple(p.xs.toSeq, Variable(idPost), p.phi))) newFun.body = Some( IfExpr(Equals(Variable(inductOn), IntLiteral(0)), base.toExpr, IfExpr(GreaterThan(Variable(inductOn), IntLiteral(0)), - LetTuple(postXs, FunctionInvocation(newFun.typed, Seq(Minus(Variable(inductOn), IntLiteral(1)))), gt.toExpr) - , LetTuple(postXs, FunctionInvocation(newFun.typed, Seq(Plus(Variable(inductOn), IntLiteral(1)))), lt.toExpr))) + letTuple(postXs, FunctionInvocation(newFun.typed, Seq(Minus(Variable(inductOn), IntLiteral(1)))), gt.toExpr) + , letTuple(postXs, FunctionInvocation(newFun.typed, Seq(Plus(Variable(inductOn), IntLiteral(1)))), lt.toExpr))) ) @@ -68,7 +68,7 @@ case object IntInduction extends Rule("Int Induction") with Heuristic { None } - Some(RuleInstantiation.immediateDecomp(p, this, List(subBase, subGT, subLT), onSuccess, "Int Induction on '"+origId+"'")) + Some(decomp(List(subBase, subGT, subLT), onSuccess, s"Int Induction on '$origId'")) case _ => None } diff --git a/src/main/scala/leon/synthesis/rules/IntegerEquation.scala b/src/main/scala/leon/synthesis/rules/IntegerEquation.scala index 257a48247..a2ebfb9f0 100644 --- a/src/main/scala/leon/synthesis/rules/IntegerEquation.scala +++ b/src/main/scala/leon/synthesis/rules/IntegerEquation.scala @@ -16,7 +16,7 @@ import LinearEquations.elimVariable import evaluators._ case object IntegerEquation extends Rule("Integer Equation") { - def instantiateOn(sctx: SynthesisContext, problem: Problem): Traversable[RuleInstantiation] = if(!problem.xs.exists(_.getType == Int32Type)) Nil else { + def instantiateOn(implicit hctx: SearchContext, problem: Problem): Traversable[RuleInstantiation] = if(!problem.xs.exists(_.getType == Int32Type)) Nil else { val TopLevelAnds(exprs) = problem.phi @@ -24,7 +24,7 @@ case object IntegerEquation extends Rule("Integer Equation") { var candidates: Seq[Expr] = eqs var allOthers: Seq[Expr] = others - val evaluator = new DefaultEvaluator(sctx.context, sctx.program) + val evaluator = new DefaultEvaluator(hctx.context, hctx.program) var vars: Set[Identifier] = Set() var eqxs: List[Identifier] = List() @@ -66,7 +66,7 @@ case object IntegerEquation extends Rule("Integer Equation") { None } - List(RuleInstantiation.immediateDecomp(problem, this, List(newProblem), onSuccess, this.name)) + List(decomp(List(newProblem), onSuccess, this.name)) } else { val (eqPre0, eqWitness, freshxs) = elimVariable(evaluator, eqas, normalizedEq) @@ -104,7 +104,7 @@ case object IntegerEquation extends Rule("Integer Equation") { val id2res: Map[Expr, Expr] = freshsubxs.zip(subproblemxs).map{case (id1, id2) => (Variable(id1), Variable(id2))}.toMap ++ neqxs.map(id => (Variable(id), eqSubstMap(Variable(id)))).toMap - Some(Solution(and(eqPre, freshPre), defs, simplifyArithmetic(simplifyLets(LetTuple(subproblemxs, freshTerm, replace(id2res, Tuple(problem.xs.map(Variable(_))))))), s.isTrusted)) + Some(Solution(and(eqPre, freshPre), defs, simplifyArithmetic(simplifyLets(letTuple(subproblemxs, freshTerm, replace(id2res, tupleWrap(problem.xs.map(Variable(_))))))), s.isTrusted)) } case _ => @@ -114,9 +114,9 @@ case object IntegerEquation extends Rule("Integer Equation") { if (subproblemxs.isEmpty) { // we directly solve - List(RuleInstantiation.immediateSuccess(problem, this, onSuccess(List(Solution(and(eqPre, problem.pc), Set(), Tuple(Seq())))).get)) + List(solve(onSuccess(List(Solution(and(eqPre, problem.pc), Set(), UnitLiteral()))).get)) } else { - List(RuleInstantiation.immediateDecomp(problem, this, List(newProblem), onSuccess, this.name)) + List(decomp(List(newProblem), onSuccess, this.name)) } } } diff --git a/src/main/scala/leon/synthesis/rules/IntegerInequalities.scala b/src/main/scala/leon/synthesis/rules/IntegerInequalities.scala index 3e56228e4..20da4b5ac 100644 --- a/src/main/scala/leon/synthesis/rules/IntegerInequalities.scala +++ b/src/main/scala/leon/synthesis/rules/IntegerInequalities.scala @@ -11,12 +11,13 @@ import purescala.TreeOps._ import purescala.TreeNormalizations.linearArithmeticForm import purescala.TreeNormalizations.NonLinearExpressionException import purescala.TypeTrees._ +import purescala.Constructors._ import purescala.Definitions._ import LinearEquations.elimVariable import leon.synthesis.Algebra.lcm case object IntegerInequalities extends Rule("Integer Inequalities") { - def instantiateOn(sctx: SynthesisContext, problem: Problem): Traversable[RuleInstantiation] = { + def instantiateOn(implicit hctx: SearchContext, problem: Problem): Traversable[RuleInstantiation] = { val TopLevelAnds(exprs) = problem.phi @@ -145,7 +146,7 @@ case object IntegerInequalities extends Rule("Integer Inequalities") { val constraints: List[Expr] = for((ub, uc) <- upperBounds; (lb, lc) <- lowerBounds) yield LessEquals(ceilingDiv(lb, IntLiteral(lc)), floorDiv(ub, IntLiteral(uc))) val pre = And(exprNotUsed ++ constraints) - List(RuleInstantiation.immediateSuccess(problem, this, Solution(pre, Set(), Tuple(Seq(witness))))) + List(solve(Solution(pre, Set(), tupleWrap(Seq(witness))))) } else { val involvedVariables = (upperBounds++lowerBounds).foldLeft(Set[Identifier]())((acc, t) => { @@ -181,9 +182,9 @@ case object IntegerInequalities extends Rule("Integer Inequalities") { case List(s @ Solution(pre, defs, term)) => { if(remainderIds.isEmpty) { Some(Solution(And(newPre, pre), defs, - LetTuple(subProblemxs, term, + letTuple(subProblemxs, term, Let(processedVar, witness, - Tuple(problem.xs.map(Variable(_))))), isTrusted=s.isTrusted)) + tupleWrap(problem.xs.map(Variable(_))))), isTrusted=s.isTrusted)) } else if(remainderIds.size > 1) { sys.error("TODO") } else { @@ -192,16 +193,16 @@ case object IntegerInequalities extends Rule("Integer Inequalities") { val loopCounter = Variable(FreshIdentifier("i", true).setType(Int32Type)) val concretePre = replace(Map(Variable(k) -> loopCounter), pre) val concreteTerm = replace(Map(Variable(k) -> loopCounter), term) - val returnType = TupleType(problem.xs.map(_.getType)) + val returnType = tupleTypeWrap(problem.xs.map(_.getType)) val funDef = new FunDef(FreshIdentifier("rec", true), Nil, returnType, Seq(ValDef(loopCounter.id, Int32Type)),DefType.MethodDef) val funBody = expandAndSimplifyArithmetic(IfExpr( LessThan(loopCounter, IntLiteral(0)), Error(returnType, "No solution exists"), IfExpr( concretePre, - LetTuple(subProblemxs, concreteTerm, + letTuple(subProblemxs, concreteTerm, Let(processedVar, witness, - Tuple(problem.xs.map(Variable(_)))) + tupleWrap(problem.xs.map(Variable(_)))) ), FunctionInvocation(funDef.typed, Seq(Minus(loopCounter, IntLiteral(1)))) ) @@ -215,7 +216,7 @@ case object IntegerInequalities extends Rule("Integer Inequalities") { None } - List(RuleInstantiation.immediateDecomp(problem, this, List(subProblem), onSuccess, this.name)) + List(decomp(List(subProblem), onSuccess, this.name)) } } } diff --git a/src/main/scala/leon/synthesis/rules/OnePoint.scala b/src/main/scala/leon/synthesis/rules/OnePoint.scala index 2d7c19cfe..54484206b 100644 --- a/src/main/scala/leon/synthesis/rules/OnePoint.scala +++ b/src/main/scala/leon/synthesis/rules/OnePoint.scala @@ -11,7 +11,7 @@ import purescala.Extractors._ import purescala.Constructors._ case object OnePoint extends NormalizingRule("One-point") { - def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { + def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { val TopLevelAnds(exprs) = p.phi def validOnePoint(x: Identifier, e: Expr) = { @@ -35,12 +35,12 @@ case object OnePoint extends NormalizingRule("One-point") { val onSuccess: List[Solution] => Option[Solution] = { case List(s @ Solution(pre, defs, term)) => - Some(Solution(pre, defs, letTuple(oxs, term, subst(x -> e, Tuple(p.xs.map(Variable(_))))), s.isTrusted)) + Some(Solution(pre, defs, letTuple(oxs, term, subst(x -> e, tupleWrap(p.xs.map(Variable(_))))), s.isTrusted)) case _ => None } - List(RuleInstantiation.immediateDecomp(p, this, List(newProblem), onSuccess, this.name)) + List(decomp(List(newProblem), onSuccess, s"One-point on $x = $e")) } else { Nil } diff --git a/src/main/scala/leon/synthesis/rules/OptimisticGround.scala b/src/main/scala/leon/synthesis/rules/OptimisticGround.scala index caecd54f2..bc46805ae 100644 --- a/src/main/scala/leon/synthesis/rules/OptimisticGround.scala +++ b/src/main/scala/leon/synthesis/rules/OptimisticGround.scala @@ -13,17 +13,17 @@ import purescala.Constructors._ import solvers._ case object OptimisticGround extends Rule("Optimistic Ground") { - def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { + def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { if (!p.as.isEmpty && !p.xs.isEmpty) { - val res = new RuleInstantiation(p, this, SolutionBuilder.none, this.name, this.priority) { - def apply(sctx: SynthesisContext) = { + val res = new RuleInstantiation(this.name) { + def apply(hctx: SearchContext) = { - val solver = SimpleSolverAPI(sctx.fastSolverFactory) // Optimistic ground is given a simple solver (uninterpreted) + val solver = SimpleSolverAPI(hctx.sctx.fastSolverFactory) // Optimistic ground is given a simple solver (uninterpreted) val xss = p.xs.toSet val ass = p.as.toSet - val tpe = TupleType(p.xs.map(_.getType)) + val tpe = tupleTypeWrap(p.xs.map(_.getType)) var i = 0; var maxTries = 3; @@ -48,7 +48,7 @@ case object OptimisticGround extends Rule("Optimistic Ground") { predicates = valuateWithModelIn(phi, ass, invalidModel) +: predicates case (Some(false), _) => - result = Some(RuleClosed(Solution(BooleanLiteral(true), Set(), Tuple(p.xs.map(valuateWithModel(satModel)))))) + result = Some(RuleClosed(Solution(BooleanLiteral(true), Set(), tupleWrap(p.xs.map(valuateWithModel(satModel)))))) case _ => continue = false @@ -57,7 +57,7 @@ case object OptimisticGround extends Rule("Optimistic Ground") { case (Some(false), _) => if (predicates.isEmpty) { - result = Some(RuleClosed(Solution(BooleanLiteral(false), Set(), Error(tpe, p.phi+" is UNSAT!")))) + result = Some(RuleClosed(Solution.UNSAT(p))) } else { continue = false result = None diff --git a/src/main/scala/leon/synthesis/heuristics/OptimisticInjection.scala b/src/main/scala/leon/synthesis/rules/OptimisticInjection.scala similarity index 79% rename from src/main/scala/leon/synthesis/heuristics/OptimisticInjection.scala rename to src/main/scala/leon/synthesis/rules/OptimisticInjection.scala index 36c5c0956..d6ccc416f 100644 --- a/src/main/scala/leon/synthesis/heuristics/OptimisticInjection.scala +++ b/src/main/scala/leon/synthesis/rules/OptimisticInjection.scala @@ -2,7 +2,7 @@ package leon package synthesis -package heuristics +package rules import purescala.Common._ import purescala.Trees._ @@ -12,8 +12,8 @@ import purescala.TreeOps._ import purescala.TypeTrees._ import purescala.Definitions._ -case object OptimisticInjection extends Rule("Opt. Injection") with Heuristic { - def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { +case object OptimisticInjection extends Rule("Opt. Injection") { + def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { val TopLevelAnds(exprs) = p.phi val eqfuncalls = exprs.collect{ @@ -39,7 +39,7 @@ case object OptimisticInjection extends Rule("Opt. Injection") with Heuristic { val sub = p.copy(phi = andJoin(newExprs)) - Some(RuleInstantiation.immediateDecomp(p, this, List(sub), forward)) + Some(decomp(List(sub), forward, s"Injection ${candidates.keySet.map(_._1.id).mkString(", ")}")) } else { None } diff --git a/src/main/scala/leon/synthesis/heuristics/SelectiveInlining.scala b/src/main/scala/leon/synthesis/rules/SelectiveInlining.scala similarity index 79% rename from src/main/scala/leon/synthesis/heuristics/SelectiveInlining.scala rename to src/main/scala/leon/synthesis/rules/SelectiveInlining.scala index 6b524299e..e79d857d4 100644 --- a/src/main/scala/leon/synthesis/heuristics/SelectiveInlining.scala +++ b/src/main/scala/leon/synthesis/rules/SelectiveInlining.scala @@ -2,7 +2,7 @@ package leon package synthesis -package heuristics +package rules import purescala.Common._ import purescala.Trees._ @@ -12,8 +12,8 @@ import purescala.TreeOps._ import purescala.TypeTrees._ import purescala.Definitions._ -case object SelectiveInlining extends Rule("Sel. Inlining") with Heuristic { - def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { +case object SelectiveInlining extends Rule("Sel. Inlining") { + def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { val TopLevelAnds(exprs) = p.phi val eqfuncalls = exprs.collect{ @@ -39,7 +39,7 @@ case object SelectiveInlining extends Rule("Sel. Inlining") with Heuristic { val sub = p.copy(phi = andJoin(newExprs)) - Some(RuleInstantiation.immediateDecomp(p, this, List(sub), forward)) + Some(decomp(List(sub), forward, s"Inlining ${candidates.keySet.map(_._1.id).mkString(", ")}")) } else { None } diff --git a/src/main/scala/leon/synthesis/rules/TegisLike.scala b/src/main/scala/leon/synthesis/rules/TegisLike.scala index 96a5b242e..51bbf8a6e 100644 --- a/src/main/scala/leon/synthesis/rules/TegisLike.scala +++ b/src/main/scala/leon/synthesis/rules/TegisLike.scala @@ -30,10 +30,11 @@ abstract class TEGISLike[T <% Typed](name: String) extends Rule(name) { def getParams(sctx: SynthesisContext, p: Problem): TegisParams - def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { + def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { - List(new RuleInstantiation(p, this, SolutionBuilder.none, this.name, this.priority) { - def apply(sctx: SynthesisContext): RuleApplication = { + List(new RuleInstantiation(this.name) { + def apply(hctx: SearchContext): RuleApplication = { + val sctx = hctx.sctx val params = getParams(sctx, p) val grammar = params.grammar @@ -49,11 +50,7 @@ abstract class TEGISLike[T <% Typed](name: String) extends Rule(name) { val enum = new MemoizedEnumerator[T, Expr](grammar.getProductions _) - val (targetType, isWrapped) = if (p.xs.size == 1) { - (p.xs.head.getType, false) - } else { - (TupleType(p.xs.map(_.getType)), true) - } + val targetType = tupleTypeWrap(p.xs.map(_.getType)) val timers = sctx.context.timers.synthesis.rules.tegis; @@ -68,11 +65,7 @@ abstract class TEGISLike[T <% Typed](name: String) extends Rule(name) { candidate = None timers.generating.start() allExprs.take(params.enumLimit).takeWhile(e => candidate.isEmpty).foreach { e => - val exprToTest = if (!isWrapped) { - Let(p.xs.head, e, p.phi) - } else { - letTuple(p.xs, e, p.phi) - } + val exprToTest = letTuple(p.xs, e, p.phi) sctx.reporter.debug("Got expression "+e) timers.testing.start() @@ -89,11 +82,8 @@ abstract class TEGISLike[T <% Typed](name: String) extends Rule(name) { } res }) { - if (isWrapped) { - candidate = Some(e) - } else { - candidate = Some(Tuple(Seq(e))) - } + + candidate = Some(tupleWrap(Seq(e))) } timers.testing.stop() diff --git a/src/main/scala/leon/synthesis/rules/UnconstrainedOutput.scala b/src/main/scala/leon/synthesis/rules/UnconstrainedOutput.scala index becf22f7a..e41d3f64b 100644 --- a/src/main/scala/leon/synthesis/rules/UnconstrainedOutput.scala +++ b/src/main/scala/leon/synthesis/rules/UnconstrainedOutput.scala @@ -10,7 +10,7 @@ import purescala.Extractors._ import purescala.Constructors._ case object UnconstrainedOutput extends NormalizingRule("Unconstr.Output") { - def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { + def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { val unconstr = p.xs.toSet -- variablesOf(p.phi) if (!unconstr.isEmpty) { @@ -18,16 +18,16 @@ case object UnconstrainedOutput extends NormalizingRule("Unconstr.Output") { val onSuccess: List[Solution] => Option[Solution] = { case List(s) => - val term = letTuple(sub.xs, s.term, Tuple(p.xs.map(id => if (unconstr(id)) simplestValue(id.getType) else Variable(id)))) + val term = letTuple(sub.xs, s.term, tupleWrap(p.xs.map(id => if (unconstr(id)) simplestValue(id.getType) else Variable(id)))) Some(Solution(s.pre, s.defs, term, s.isTrusted)) case _ => None } - List(RuleInstantiation.immediateDecomp(p, this, List(sub), onSuccess, this.name)) + Some(decomp(List(sub), onSuccess, s"Unconst. out ${p.xs.filter(unconstr).mkString(", ")}")) } else { - Nil + None } } } diff --git a/src/main/scala/leon/synthesis/rules/Unification.scala b/src/main/scala/leon/synthesis/rules/Unification.scala index 4a7283811..c0e865b57 100644 --- a/src/main/scala/leon/synthesis/rules/Unification.scala +++ b/src/main/scala/leon/synthesis/rules/Unification.scala @@ -12,7 +12,7 @@ import purescala.Constructors._ object Unification { case object DecompTrivialClash extends NormalizingRule("Unif Dec./Clash/Triv.") { - def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { + def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { val TopLevelAnds(exprs) = p.phi val (toRemove, toAdd) = exprs.collect { @@ -29,8 +29,7 @@ object Unification { if (!toRemove.isEmpty) { val sub = p.copy(phi = andJoin((exprs.toSet -- toRemove ++ toAdd.flatten).toSeq)) - - List(RuleInstantiation.immediateDecomp(p, this, List(sub), forward, this.name)) + List(decomp(List(sub), forward, this.name)) } else { Nil } @@ -40,7 +39,7 @@ object Unification { // This rule is probably useless; it never happens except in crafted // examples, and will be found by OptimisticGround anyway. case object OccursCheck extends NormalizingRule("Unif OccursCheck") { - def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { + def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { val TopLevelAnds(exprs) = p.phi val isImpossible = exprs.exists { @@ -53,9 +52,7 @@ object Unification { } if (isImpossible) { - val tpe = TupleType(p.xs.map(_.getType)) - - List(RuleInstantiation.immediateSuccess(p, this, Solution(BooleanLiteral(false), Set(), Error(tpe, p.phi+" is UNSAT!")))) + List(solve(Solution.UNSAT)) } else { Nil } diff --git a/src/main/scala/leon/synthesis/rules/UnusedInput.scala b/src/main/scala/leon/synthesis/rules/UnusedInput.scala index 8b2320fd7..b715078c6 100644 --- a/src/main/scala/leon/synthesis/rules/UnusedInput.scala +++ b/src/main/scala/leon/synthesis/rules/UnusedInput.scala @@ -9,13 +9,13 @@ import purescala.TreeOps._ import purescala.Extractors._ case object UnusedInput extends NormalizingRule("UnusedInput") { - def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { + def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { val unused = p.as.toSet -- variablesOf(p.phi) -- variablesOf(p.pc) if (!unused.isEmpty) { val sub = p.copy(as = p.as.filterNot(unused)) - List(RuleInstantiation.immediateDecomp(p, this, List(sub), forward, this.name)) + List(decomp(List(sub), forward, s"Unused inputs ${p.as.filter(unused).mkString(", ")}")) } else { Nil } diff --git a/src/test/scala/leon/test/synthesis/StablePrintingSuite.scala b/src/test/scala/leon/test/synthesis/StablePrintingSuite.scala index ad53849d5..ca2821c68 100644 --- a/src/test/scala/leon/test/synthesis/StablePrintingSuite.scala +++ b/src/test/scala/leon/test/synthesis/StablePrintingSuite.scala @@ -85,12 +85,14 @@ class StablePrintingSuite extends LeonTestSuite { if (j.rules.size < depth) { for ((ci, i) <- chooses.zipWithIndex if j.choosesToProcess(i) || j.choosesToProcess.isEmpty) { val sctx = SynthesisContext.fromSynthesizer(ci.synthesizer) + val search = ci.synthesizer.getSearch() + val hctx = SearchContext(sctx, search.g.root, search) val problem = ci.problem info(j.info("synthesis "+problem)) - val apps = sctx.rules flatMap { _.instantiateOn(sctx, problem)} + val apps = sctx.rules flatMap { _.instantiateOn(hctx, problem)} for (a <- apps) { - a.apply(sctx) match { + a.apply(hctx) match { case RuleClosed(sols) => case RuleExpanded(sub) => a.onSuccess(sub.map(Solution.choose(_))) match { diff --git a/src/test/scala/leon/test/synthesis/SynthesisSuite.scala b/src/test/scala/leon/test/synthesis/SynthesisSuite.scala index e12cb6dc4..73f01af8a 100644 --- a/src/test/scala/leon/test/synthesis/SynthesisSuite.scala +++ b/src/test/scala/leon/test/synthesis/SynthesisSuite.scala @@ -9,6 +9,7 @@ import leon.purescala.TreeOps._ import leon.solvers.z3._ import leon.solvers.Solver import leon.synthesis._ +import leon.synthesis.graph._ import leon.synthesis.utils._ import leon.utils.PreprocessingPhase @@ -23,7 +24,7 @@ class SynthesisSuite extends LeonTestSuite { counter } - def forProgram(title: String, opts: SynthesisSettings = SynthesisSettings())(content: String)(block: (SynthesisContext, FunDef, Problem) => Unit) { + def forProgram(title: String, opts: SynthesisSettings = SynthesisSettings())(content: String)(block: (SearchContext, FunDef, Problem) => Unit) { test("Synthesizing %3d: [%s]".format(nextInt(), title)) { val ctx = testContext.copy(settings = Settings( @@ -45,15 +46,18 @@ class SynthesisSuite extends LeonTestSuite { program, ctx.reporter) - block(sctx, f, p) + val search = new SimpleSearch(ctx, p, opts.costModel, None) + val hctx = SearchContext(sctx, search.g.root, search) + + block(hctx, f, p) } } } case class Apply(desc: String, andThen: List[Apply] = Nil) - def synthesizeWith(sctx: SynthesisContext, p: Problem, ss: Apply): Solution = { - val apps = sctx.rules flatMap { _.instantiateOn(sctx, p)} + def synthesizeWith(hctx: SearchContext, p: Problem, ss: Apply): Solution = { + val apps = hctx.sctx.rules flatMap { _.instantiateOn(hctx, p)} def matchingDesc(app: RuleInstantiation, ss: Apply): Boolean = { import java.util.regex.Pattern; @@ -63,14 +67,14 @@ class SynthesisSuite extends LeonTestSuite { apps.filter(matchingDesc(_, ss)) match { case app :: Nil => - app.apply(sctx) match { + app.apply(hctx) match { case RuleClosed(sols) => assert(sols.nonEmpty) assert(ss.andThen.isEmpty) sols.head case RuleExpanded(sub) => - val subSols = (sub zip ss.andThen) map { case (p, ss) => synthesizeWith(sctx, p, ss) } + val subSols = (sub zip ss.andThen) map { case (p, ss) => synthesizeWith(hctx, p, ss) } app.onSuccess(subSols).get } @@ -84,10 +88,10 @@ class SynthesisSuite extends LeonTestSuite { def synthesize(title: String)(program: String)(strategies: PartialFunction[String, Apply]) { forProgram(title)(program) { - case (sctx, fd, p) => + case (hctx, fd, p) => strategies.lift.apply(fd.id.toString) match { case Some(ss) => - synthesizeWith(sctx, p, ss) + synthesizeWith(hctx, p, ss) case None => assert(false, "Function "+fd.id.toString+" not found") @@ -96,16 +100,16 @@ class SynthesisSuite extends LeonTestSuite { } - def assertAllAlternativesSucceed(sctx: SynthesisContext, rr: Traversable[RuleInstantiation]) { + def assertAllAlternativesSucceed(hctx: SearchContext, rr: Traversable[RuleInstantiation]) { assert(!rr.isEmpty) for (r <- rr) { - assertRuleSuccess(sctx, r) + assertRuleSuccess(hctx, r) } } - def assertRuleSuccess(sctx: SynthesisContext, rr: RuleInstantiation): Option[Solution] = { - rr.apply(sctx) match { + def assertRuleSuccess(hctx: SearchContext, rr: RuleInstantiation): Option[Solution] = { + rr.apply(hctx) match { case RuleClosed(sols) if sols.nonEmpty => sols.headOption case _ => @@ -135,8 +139,8 @@ object Injection { } """ ) { - case (sctx, fd, p) => - assertAllAlternativesSucceed(sctx, rules.Ground.instantiateOn(sctx, p)) + case (hctx, fd, p) => + assertAllAlternativesSucceed(hctx, rules.Ground.instantiateOn(hctx, p)) } forProgram("Cegis 1")( @@ -160,8 +164,8 @@ object Injection { } """ ) { - case (sctx, fd, p) => - assertAllAlternativesSucceed(sctx, rules.CEGIS.instantiateOn(sctx, p)) + case (hctx, fd, p) => + assertAllAlternativesSucceed(hctx, rules.CEGIS.instantiateOn(hctx, p)) } forProgram("Cegis 2")( @@ -185,17 +189,17 @@ object Injection { } """ ) { - case (sctx, fd, p) => - rules.CEGIS.instantiateOn(sctx, p).head.apply(sctx) match { + case (hctx, fd, p) => + rules.CEGIS.instantiateOn(hctx, p).head.apply(hctx) match { case RuleClosed(sols) if sols.nonEmpty => assert(false, "CEGIS should have failed, but found : %s".format(sols.head)) case _ => } - rules.ADTSplit.instantiateOn(sctx, p).head.apply(sctx) match { + rules.ADTSplit.instantiateOn(hctx, p).head.apply(hctx) match { case RuleExpanded(subs) => - for (sub <- subs; alt <- rules.CEGIS.instantiateOn(sctx, sub)) { - assertRuleSuccess(sctx, alt) + for (sub <- subs; alt <- rules.CEGIS.instantiateOn(hctx, sub)) { + assertRuleSuccess(hctx, alt) } case _ => assert(false, "Woot?") -- GitLab