From fd2e80df8e5c2a4cf5d107930a8b91a2bcdeb5da Mon Sep 17 00:00:00 2001 From: Nicolas Voirol <voirol.nicolas@gmail.com> Date: Thu, 14 Apr 2016 16:34:32 +0200 Subject: [PATCH] Added let bindings in leon paths --- .../scala/leon/codegen/CodeGeneration.scala | 2 +- .../scala/leon/codegen/runtime/Monitor.scala | 18 ++- .../leon/evaluators/RecursiveEvaluator.scala | 12 +- .../leon/evaluators/StreamEvaluator.scala | 8 +- src/main/scala/leon/grammars/Grammars.scala | 2 +- .../leon/grammars/SafeRecursiveCalls.scala | 3 +- .../leon/laziness/ClosurePreAsserter.scala | 2 +- .../scala/leon/purescala/Definitions.scala | 5 +- src/main/scala/leon/purescala/ExprOps.scala | 148 ++++++++++++------ .../leon/purescala/FunctionClosure.scala | 32 ++-- src/main/scala/leon/purescala/Path.scala | 108 +++++++++++++ .../scala/leon/purescala/Quantification.scala | 6 +- .../leon/purescala/SimplifierWithPaths.scala | 29 ++-- .../leon/purescala/TransformerWithPC.scala | 63 ++++---- src/main/scala/leon/repair/Repairman.scala | 3 +- src/main/scala/leon/repair/rules/Focus.scala | 35 +++-- .../leon/solvers/sygus/SygusSolver.scala | 2 +- .../solvers/unrolling/TemplateGenerator.scala | 17 +- .../solvers/unrolling/TemplateManager.scala | 6 +- .../scala/leon/solvers/z3/FairZ3Solver.scala | 2 +- .../scala/leon/synthesis/ExamplesFinder.scala | 6 +- src/main/scala/leon/synthesis/Problem.scala | 21 +-- src/main/scala/leon/synthesis/Solution.scala | 2 +- .../scala/leon/synthesis/SourceInfo.scala | 13 +- src/main/scala/leon/synthesis/Witnesses.scala | 2 +- .../scala/leon/synthesis/rules/ADTSplit.scala | 21 ++- .../scala/leon/synthesis/rules/Assert.scala | 2 +- .../leon/synthesis/rules/CEGISLike.scala | 22 +-- .../leon/synthesis/rules/DetupleInput.scala | 2 +- .../synthesis/rules/EquivalentInputs.scala | 121 ++++++++------ .../rules/GenericTypeEqualitySplit.scala | 13 +- .../scala/leon/synthesis/rules/IfSplit.scala | 4 +- .../synthesis/rules/IndependentSplit.scala | 4 +- .../synthesis/rules/InequalitySplit.scala | 21 ++- .../leon/synthesis/rules/InputSplit.scala | 12 +- .../synthesis/rules/IntroduceRecCalls.scala | 25 +-- .../synthesis/rules/OptimisticGround.scala | 4 +- .../leon/synthesis/rules/StringRender.scala | 2 +- .../leon/synthesis/rules/UnusedInput.scala | 2 +- .../synthesis/rules/unused/ADTInduction.scala | 16 +- .../rules/unused/ADTLongInduction.scala | 27 ++-- .../synthesis/rules/unused/IntInduction.scala | 8 +- .../rules/unused/IntegerEquation.scala | 8 +- .../synthesis/rules/unused/TEGISLike.scala | 6 +- .../scala/leon/synthesis/utils/Helpers.scala | 17 +- .../scala/leon/termination/ChainBuilder.scala | 29 ++-- .../leon/termination/ChainComparator.scala | 11 +- .../leon/termination/ChainProcessor.scala | 6 +- .../leon/termination/LoopProcessor.scala | 2 +- .../leon/termination/RecursionProcessor.scala | 4 +- .../leon/termination/RelationBuilder.scala | 17 +- .../leon/termination/RelationProcessor.scala | 2 +- .../scala/leon/termination/Strengthener.scala | 27 ++-- .../SerialInstrumentationPhase.scala | 6 +- .../leon/verification/DefaultTactic.scala | 2 +- .../leon/verification/InductionTactic.scala | 9 +- .../integration/grammars/SimilarToSuite.scala | 1 + 57 files changed, 597 insertions(+), 403 deletions(-) create mode 100644 src/main/scala/leon/purescala/Path.scala diff --git a/src/main/scala/leon/codegen/CodeGeneration.scala b/src/main/scala/leon/codegen/CodeGeneration.scala index d7a3b12c1..d197872cf 100644 --- a/src/main/scala/leon/codegen/CodeGeneration.scala +++ b/src/main/scala/leon/codegen/CodeGeneration.scala @@ -94,7 +94,7 @@ trait CodeGeneration { scala.reflect.NameTransformer.encode(id.uniqueName).replaceAll("\\.", "\\$") } - def defToJVMName(d: Definition): String = "Leon$CodeGen$" + idToSafeJVMName(d.id) + def defToJVMName(d: Definition): String = "Leon$CodeGen$Def$" + idToSafeJVMName(d.id) /** Retrieve the name of the underlying lazy field from a lazy field accessor method */ private[codegen] def underlyingField(lazyAccessor : String) = lazyAccessor + "$underlying" diff --git a/src/main/scala/leon/codegen/runtime/Monitor.scala b/src/main/scala/leon/codegen/runtime/Monitor.scala index 6ed6b2d67..07b5e665a 100644 --- a/src/main/scala/leon/codegen/runtime/Monitor.scala +++ b/src/main/scala/leon/codegen/runtime/Monitor.scala @@ -127,11 +127,16 @@ class StdMonitor(unit: CompilationUnit, invocationsMax: Int, bodies: ScalaMap[Id } val inputsMap = (newAs zip inputs).map { - case (id, v) => Equals(Variable(id), unit.jvmToValue(v, id.getType)) + case (id, v) => id -> unit.jvmToValue(v, id.getType) } - val expr = instantiateType(and(p.pc, p.phi), tpMap, (p.as zip newAs).toMap ++ (p.xs zip newXs)) - solver.assertCnstr(andJoin(expr +: inputsMap)) + val instTpe: Expr => Expr = { + val idMap = (p.as zip newAs).toMap ++ (p.xs zip newXs) + instantiateType(_: Expr, tpMap, idMap) + } + + val expr = p.pc map instTpe withBindings inputsMap and instTpe(p.phi) + solver.assertCnstr(expr) try { solver.check match { @@ -220,7 +225,12 @@ class StdMonitor(unit: CompilationUnit, invocationsMax: Int, bodies: ScalaMap[Id val domainMap = quantifierDomains.groupBy(_._1).mapValues(_.map(_._2).flatten) andJoin(domainMap.toSeq.map { case (id, dom) => - orJoin(dom.toSeq.map { case (path, value) => and(path, Equals(Variable(id), value)) }) + orJoin(dom.toSeq.map { case (path, value) => + // @nv: Note that we know id.getType is first-order since quantifiers can only + // range over basic types. This means equality is guaranteed well-defined + // between `id` and `value` + path and Equals(Variable(id), value) + }) }) }) diff --git a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala index 17b44ea45..4d70063fb 100644 --- a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala +++ b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala @@ -627,7 +627,10 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int val domainMap = quantifierDomains.groupBy(_._1).mapValues(_.map(_._2).flatten) andJoin(domainMap.toSeq.map { case (id, dom) => - orJoin(dom.toSeq.map { case (path, value) => and(path, Equals(Variable(id), value)) }) + orJoin(dom.toSeq.map { case (path, value) => + // @nv: Equality with variable is ok, see [[leon.codegen.runtime.Monitor]] + path and Equals(Variable(id), value) + }) }) }) @@ -735,12 +738,7 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int val solver = solverf.getNewSolver() try { - val eqs = p.as.map { - case id => - Equals(Variable(id), rctx.mappings(id)) - } - - val cnstr = andJoin(eqs ::: p.pc :: p.phi :: Nil) + val cnstr = p.pc withBindings p.as.map(id => id -> rctx.mappings(id)) and p.phi solver.assertCnstr(cnstr) solver.check match { diff --git a/src/main/scala/leon/evaluators/StreamEvaluator.scala b/src/main/scala/leon/evaluators/StreamEvaluator.scala index d36a4b739..c211ba706 100644 --- a/src/main/scala/leon/evaluators/StreamEvaluator.scala +++ b/src/main/scala/leon/evaluators/StreamEvaluator.scala @@ -184,7 +184,11 @@ class StreamEvaluator(ctx: LeonContext, prog: Program) val domainMap = quantifierDomains.groupBy(_._1).mapValues(_.map(_._2).flatten) andJoin(domainMap.toSeq.map { case (id, dom) => - orJoin(dom.toSeq.map { case (path, value) => and(path, Equals(Variable(id), value)) }) + orJoin(dom.toSeq.map { case (path, value) => + // @nv: Note that equality is allowed here because of first-order quantifiers. + // See [[leon.codegen.runtime.Monitor]] for more details. + path and Equals(Variable(id), value) + }) }) }) @@ -238,7 +242,7 @@ class StreamEvaluator(ctx: LeonContext, prog: Program) case id => Equals(Variable(id), rctx.mappings(id)) } - val cnstr = andJoin(eqs ::: p.pc :: p.phi :: Nil) + val cnstr = p.pc withBindings p.as.map(id => id -> rctx.mappings(id)) and p.phi solver.assertCnstr(cnstr) def getSolution = try { diff --git a/src/main/scala/leon/grammars/Grammars.scala b/src/main/scala/leon/grammars/Grammars.scala index 56d63a438..cc9e1279c 100644 --- a/src/main/scala/leon/grammars/Grammars.scala +++ b/src/main/scala/leon/grammars/Grammars.scala @@ -27,7 +27,7 @@ object Grammars { def default(sctx: SynthesisContext, p: Problem, extraHints: Seq[Expr] = Seq()): ExpressionGrammar = { val TopLevelAnds(ws) = p.ws val hints = ws.collect{ case Hint(e) if formulaSize(e) >= 4 => e } - default(sctx.program, p.as.map(_.toVariable) ++ hints ++ extraHints, sctx.functionContext, sctx.settings.functionsToIgnore) + default(sctx.program, p.allAs.map(_.toVariable) ++ hints ++ extraHints, sctx.functionContext, sctx.settings.functionsToIgnore) } def similarTo(e: Expr, base: ExpressionGrammar) = { diff --git a/src/main/scala/leon/grammars/SafeRecursiveCalls.scala b/src/main/scala/leon/grammars/SafeRecursiveCalls.scala index 679f6f764..fed42d36c 100644 --- a/src/main/scala/leon/grammars/SafeRecursiveCalls.scala +++ b/src/main/scala/leon/grammars/SafeRecursiveCalls.scala @@ -3,6 +3,7 @@ package leon package grammars +import purescala.Path import purescala.Types._ import purescala.Definitions._ import purescala.ExprOps._ @@ -14,7 +15,7 @@ import synthesis.utils.Helpers._ * @param ws An expression that contains the known set [[synthesis.Witnesses.Terminating]] expressions * @param pc The path condition for the generated [[Expr]] by this grammar */ -case class SafeRecursiveCalls(prog: Program, ws: Expr, pc: Expr) extends SimpleExpressionGrammar { +case class SafeRecursiveCalls(prog: Program, ws: Expr, pc: Path) extends SimpleExpressionGrammar { def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Prod] = { val calls = terminatingCalls(prog,ws, pc, Some(t), true) diff --git a/src/main/scala/leon/laziness/ClosurePreAsserter.scala b/src/main/scala/leon/laziness/ClosurePreAsserter.scala index 18398850c..48c51556c 100644 --- a/src/main/scala/leon/laziness/ClosurePreAsserter.scala +++ b/src/main/scala/leon/laziness/ClosurePreAsserter.scala @@ -67,7 +67,7 @@ class ClosurePreAsserter(p: Program) { args :+ st else args val pre2 = replaceFromIDs((target.params.map(_.id) zip nargs).toMap, pre) - val vc = Implies(And(precOrTrue(fd), path), pre2) + val vc = path withCond precOrTrue(fd) implies pre2 // create a function for each vc val lemmaid = FreshIdentifier(ccd.id.name + fd.id.name + "Lem", Untyped, true) val params = variablesOf(vc).toSeq.map(v => ValDef(v)) diff --git a/src/main/scala/leon/purescala/Definitions.scala b/src/main/scala/leon/purescala/Definitions.scala index 3f00b8e69..718a50e15 100644 --- a/src/main/scala/leon/purescala/Definitions.scala +++ b/src/main/scala/leon/purescala/Definitions.scala @@ -454,7 +454,10 @@ object Definitions { def precondition = preconditionOf(fullBody) def precondition_=(oe: Option[Expr]) = { - fullBody = withPrecondition(fullBody, oe) + fullBody = withPrecondition(fullBody, oe) + } + def precondition_=(p: Path) = { + fullBody = withPath(fullBody, p) } def precOrTrue = precondition getOrElse BooleanLiteral(true) diff --git a/src/main/scala/leon/purescala/ExprOps.scala b/src/main/scala/leon/purescala/ExprOps.scala index 9a2b46c0f..d9e312788 100644 --- a/src/main/scala/leon/purescala/ExprOps.scala +++ b/src/main/scala/leon/purescala/ExprOps.scala @@ -342,6 +342,18 @@ object ExprOps extends GenTreeOps[Expr] { variablesOf(e).isEmpty && isDeterministic(e) } + /** Returns '''true''' if the formula is simple, + * which means that it requires no special encoding for an + * unrolling solver. See implementation for what this means exactly. + */ + def isSimple(e: Expr): Boolean = !exists { + case (_: Choose) | (_: Hole) | + (_: Assert) | (_: Ensuring) | + (_: Forall) | (_: Lambda) | (_: FiniteLambda) | + (_: FunctionInvocation) | (_: Application) => true + case _ => false + } (e) + /** Returns a function which can simplify all ground expressions which appear in a program context. */ def evalGround(ctx: LeonContext, program: Program): Expr => Expr = { @@ -604,16 +616,16 @@ object ExprOps extends GenTreeOps[Expr] { * * @see [[purescala.Expressions.Pattern]] */ - def conditionForPattern(in: Expr, pattern: Pattern, includeBinders: Boolean = false): Expr = { - def bind(ob: Option[Identifier], to: Expr): Expr = { + def conditionForPattern(in: Expr, pattern: Pattern, includeBinders: Boolean = false): Path = { + def bind(ob: Option[Identifier], to: Expr): Path = { if (!includeBinders) { - BooleanLiteral(true) + Path.empty } else { - ob.map(id => Equals(Variable(id), to)).getOrElse(BooleanLiteral(true)) + ob.map(id => Path.empty withBinding (id -> to)).getOrElse(Path.empty) } } - def rec(in: Expr, pattern: Pattern): Expr = { + def rec(in: Expr, pattern: Pattern): Path = { pattern match { case WildcardPattern(ob) => bind(ob, in) @@ -622,31 +634,32 @@ object ExprOps extends GenTreeOps[Expr] { if (ct.parent.isEmpty) { bind(ob, in) } else { - and(IsInstanceOf(in, ct), bind(ob, in)) + Path(IsInstanceOf(in, ct)) merge bind(ob, in) } case CaseClassPattern(ob, cct, subps) => assert(cct.classDef.fields.size == subps.size) val pairs = cct.classDef.fields.map(_.id).toList zip subps.toList val subTests = pairs.map(p => rec(caseClassSelector(cct, in, p._1), p._2)) - val together = and(bind(ob, in) +: subTests :_*) - and(IsInstanceOf(in, cct), together) + val together = subTests.foldLeft(bind(ob, in))(_ merge _) + Path(IsInstanceOf(in, cct)) merge together case TuplePattern(ob, subps) => val TupleType(tpes) = in.getType assert(tpes.size == subps.size) val subTests = subps.zipWithIndex.map{case (p, i) => rec(tupleSelect(in, i+1, subps.size), p)} - and(bind(ob, in) +: subTests: _*) + subTests.foldLeft(bind(ob, in))(_ merge _) case up @ UnapplyPattern(ob, fd, subps) => def someCase(e: Expr) = { // In the case where unapply returns a Some, it is enough that the subpatterns match - andJoin(unwrapTuple(e, subps.size) zip subps map { case (ex, p) => rec(ex, p).setPos(p) }).setPos(e) + val subTests = unwrapTuple(e, subps.size) zip subps map { case (ex, p) => rec(ex, p) } + subTests.foldLeft(Path.empty)(_ merge _).toClause } - and(up.patternMatch(in, BooleanLiteral(false), someCase).setPos(in), bind(ob, in)) + Path(up.patternMatch(in, BooleanLiteral(false), someCase).setPos(in)) merge bind(ob, in) - case LiteralPattern(ob,lit) => - and(Equals(in,lit), bind(ob,in)) + case LiteralPattern(ob, lit) => + Path(Equals(in, lit)) merge bind(ob, in) } } @@ -697,15 +710,15 @@ object ExprOps extends GenTreeOps[Expr] { case m @ MatchExpr(scrut, cases) => // println("Rewriting the following PM: " + e) - val condsAndRhs = for(cse <- cases) yield { + val condsAndRhs = for (cse <- cases) yield { val map = mapForPattern(scrut, cse.pattern) val patCond = conditionForPattern(scrut, cse.pattern, includeBinders = false) val realCond = cse.optGuard match { - case Some(g) => and(patCond, replaceFromIDs(map, g)) + case Some(g) => patCond withCond replaceFromIDs(map, g) case None => patCond } val newRhs = replaceFromIDs(map, cse.rhs) - (realCond, newRhs) + (realCond.toClause, newRhs) } val bigIte = condsAndRhs.foldRight[Expr](Error(m.getType, "Match is non-exhaustive").copiedFrom(m))((p1, ex) => { @@ -735,26 +748,26 @@ object ExprOps extends GenTreeOps[Expr] { * @see [[purescala.ExprOps#conditionForPattern conditionForPattern]] * @see [[purescala.ExprOps#mapForPattern mapForPattern]] */ - def matchExprCaseConditions(m: MatchExpr, pathCond: List[Expr]) : Seq[List[Expr]] = { + def matchExprCaseConditions(m: MatchExpr, path: Path) : Seq[Path] = { val MatchExpr(scrut, cases) = m - var pcSoFar = pathCond - for (c <- cases) yield { + var pcSoFar = path + for (c <- cases) yield { val g = c.optGuard getOrElse BooleanLiteral(true) val cond = conditionForPattern(scrut, c.pattern, includeBinders = true) - val localCond = pcSoFar :+ cond :+ g + val localCond = pcSoFar merge (cond withCond g) // These contain no binders defined in this MatchCase val condSafe = conditionForPattern(scrut, c.pattern) - val gSafe = replaceFromIDs(mapForPattern(scrut, c.pattern),g) - pcSoFar ::= not(and(condSafe, gSafe)) + val gSafe = replaceFromIDs(mapForPattern(scrut, c.pattern), g) + pcSoFar = pcSoFar merge (condSafe withCond gSafe).negate localCond } } /** Condition to pass this match case, expressed w.r.t scrut only */ - def matchCaseCondition(scrut: Expr, c: MatchCase): Expr = { + def matchCaseCondition(scrut: Expr, c: MatchCase): Path = { val patternC = conditionForPattern(scrut, c.pattern, includeBinders = false) @@ -762,7 +775,7 @@ object ExprOps extends GenTreeOps[Expr] { case Some(g) => // guard might refer to binders val map = mapForPattern(scrut, c.pattern) - and(patternC, replaceFromIDs(map, g)) + patternC withCond replaceFromIDs(map, g) case None => patternC @@ -773,7 +786,7 @@ object ExprOps extends GenTreeOps[Expr] { * * Each case holds the conditions on other previous cases as negative. */ - def passesPathConditions(p : Passes, pathCond: List[Expr]) : Seq[List[Expr]] = { + def passesPathConditions(p: Passes, pathCond: Path) : Seq[Path] = { matchExprCaseConditions(MatchExpr(p.in, p.cases), pathCond) } @@ -990,25 +1003,23 @@ object ExprOps extends GenTreeOps[Expr] { } object CollectorWithPaths { - def apply[T](p: PartialFunction[Expr,T]): CollectorWithPaths[(T, Expr)] = new CollectorWithPaths[(T, Expr)] { - def collect(e: Expr, path: Seq[Expr]): Option[(T, Expr)] = if (!p.isDefinedAt(e)) None else { - Some(p(e) -> and(path: _*)) + def apply[T](p: PartialFunction[Expr,T]): CollectorWithPaths[(T, Path)] = new CollectorWithPaths[(T, Path)] { + def collect(e: Expr, path: Path): Option[(T, Path)] = if (!p.isDefinedAt(e)) None else { + Some(p(e) -> path) } } } trait CollectorWithPaths[T] extends TransformerWithPC with Traverser[Seq[T]] { - type C = Seq[Expr] - protected val initC : C = Nil - def register(e: Expr, path: C) = path :+ e + protected val initPath: Seq[Expr] = Nil private var results: Seq[T] = Nil - def collect(e: Expr, path: Seq[Expr]): Option[T] + def collect(e: Expr, path: Path): Option[T] - def walk(e: Expr, path: Seq[Expr]): Option[Expr] = None + def walk(e: Expr, path: Path): Option[Expr] = None - override def rec(e: Expr, path: Seq[Expr]) = { + override def rec(e: Expr, path: Path) = { collect(e, path).foreach { results :+= _ } walk(e, path) match { case Some(r) => r @@ -1018,18 +1029,18 @@ object ExprOps extends GenTreeOps[Expr] { def traverse(funDef: FunDef): Seq[T] = traverse(funDef.fullBody) - def traverse(e: Expr): Seq[T] = traverse(e, initC) + def traverse(e: Expr): Seq[T] = traverse(e, initPath) def traverse(e: Expr, init: Expr): Seq[T] = traverse(e, Seq(init)) def traverse(e: Expr, init: Seq[Expr]): Seq[T] = { results = Nil - rec(e, init) + rec(e, Path(init)) results } } - def collectWithPC[T](f: PartialFunction[Expr, T])(expr: Expr): Seq[(T, Expr)] = { + def collectWithPC[T](f: PartialFunction[Expr, T])(expr: Expr): Seq[(T, Path)] = { CollectorWithPaths(f).traverse(expr) } @@ -1193,7 +1204,7 @@ object ExprOps extends GenTreeOps[Expr] { * foo(Nil, b) and * foo(t, b) => foo(Cons(h,t), b) */ - def isInductiveOn(sf: SolverFactory[Solver])(expr: Expr, on: Identifier): Boolean = on match { + def isInductiveOn(sf: SolverFactory[Solver])(path: Path, on: Identifier): Boolean = on match { case IsTyped(origId, AbstractClassType(cd, tps)) => val toCheck = cd.knownDescendants.collect { @@ -1211,8 +1222,8 @@ object ExprOps extends GenTreeOps[Expr] { } else { val v = Variable(on) - recSelectors.map{ s => - and(isType, expr, not(replace(Map(v -> caseClassSelector(cct, v, s)), expr))) + recSelectors.map { s => + and(path and isType, not(replace(Map(v -> caseClassSelector(cct, v, s)), path.toClause))) } } }.flatten @@ -1779,6 +1790,33 @@ object ExprOps extends GenTreeOps[Expr] { * ================= */ + /** Returns whether a particular [[Expressions.Expr]] contains specification + * constructs, namely [[Expressions.Require]] and [[Expressions.Ensuring]]. + */ + def hasSpec(e: Expr): Boolean = exists { + case Require(_, _) => true + case Ensuring(_, _) => true + case _ => false + } (e) + + /** Merges the given [[Path]] into the provided [[Expressions.Expr]]. + * + * This method expects to run on a [[Definitions.FunDef.fullBody]] and merges into + * existing pre- and postconditions. + * + * @param expr The current body + * @param path The path that should be wrapped around the given body + * @see [[Expressions.Ensuring]] + * @see [[Expressions.Require]] + */ + def withPath(expr: Expr, path: Path): Expr = expr match { + case Let(i, e, b) => Let(i, e, withPath(b, path)) + case Require(pre, b) => path specs (b, pre) + case Ensuring(Require(pre, b), post) => path specs (b, pre, post) + case Ensuring(b, post) => path specs (b, post = post) + case b => path specs b + } + /** Replaces the precondition of an existing [[Expressions.Expr]] with a new one. * * If no precondition is provided, removes any existing precondition. @@ -1793,9 +1831,11 @@ object ExprOps extends GenTreeOps[Expr] { case (Some(newPre), Require(pre, b)) => req(newPre, b) case (Some(newPre), Ensuring(Require(pre, b), p)) => Ensuring(req(newPre, b), p) case (Some(newPre), Ensuring(b, p)) => Ensuring(req(newPre, b), p) + case (Some(newPre), Let(i, e, b)) if hasSpec(b) => Let(i, e, withPrecondition(b, pred)) case (Some(newPre), b) => req(newPre, b) case (None, Require(pre, b)) => b case (None, Ensuring(Require(pre, b), p)) => Ensuring(b, p) + case (None, Let(i, e, b)) if hasSpec(b) => Let(i, e, withPrecondition(b, pred)) case (None, b) => b } @@ -1809,11 +1849,13 @@ object ExprOps extends GenTreeOps[Expr] { * @see [[Expressions.Ensuring]] * @see [[Expressions.Require]] */ - def withPostcondition(expr: Expr, oie: Option[Expr]) = (oie, expr) match { - case (Some(npost), Ensuring(b, post)) => ensur(b, npost) - case (Some(npost), b) => ensur(b, npost) - case (None, Ensuring(b, p)) => b - case (None, b) => b + def withPostcondition(expr: Expr, oie: Option[Expr]): Expr = (oie, expr) match { + case (Some(npost), Ensuring(b, post)) => ensur(b, npost) + case (Some(npost), Let(i, e, b)) if hasSpec(b) => Let(i, e, withPostcondition(b, oie)) + case (Some(npost), b) => ensur(b, npost) + case (None, Ensuring(b, p)) => b + case (None, Let(i, e, b)) if hasSpec(b) => Let(i, e, withPostcondition(b, oie)) + case (None, b) => b } /** Adds a body to a specification @@ -1824,7 +1866,8 @@ object ExprOps extends GenTreeOps[Expr] { * @see [[Expressions.Ensuring]] * @see [[Expressions.Require]] */ - def withBody(expr: Expr, body: Option[Expr]) = expr match { + def withBody(expr: Expr, body: Option[Expr]): Expr = expr match { + case Let(i, e, b) if hasSpec(b) => Let(i, e, withBody(b, body)) case Require(pre, _) => Require(pre, body.getOrElse(NoTree(expr.getType))) case Ensuring(Require(pre, _), post) => Ensuring(Require(pre, body.getOrElse(NoTree(expr.getType))), post) case Ensuring(_, post) => Ensuring(body.getOrElse(NoTree(expr.getType)), post) @@ -1840,7 +1883,8 @@ object ExprOps extends GenTreeOps[Expr] { * @see [[Expressions.Ensuring]] * @see [[Expressions.Require]] */ - def withoutSpec(expr: Expr) = expr match { + def withoutSpec(expr: Expr): Option[Expr] = expr match { + case Let(i, e, b) => withoutSpec(b).map(Let(i, e, _)) case Require(pre, b) => Option(b).filterNot(_.isInstanceOf[NoTree]) case Ensuring(Require(pre, b), post) => Option(b).filterNot(_.isInstanceOf[NoTree]) case Ensuring(b, post) => Option(b).filterNot(_.isInstanceOf[NoTree]) @@ -1848,14 +1892,16 @@ object ExprOps extends GenTreeOps[Expr] { } /** Returns the precondition of an expression wrapped in Option */ - def preconditionOf(expr: Expr) = expr match { + def preconditionOf(expr: Expr): Option[Expr] = expr match { + case Let(i, e, b) => preconditionOf(b).map(Let(i, e, _)) case Require(pre, _) => Some(pre) case Ensuring(Require(pre, _), _) => Some(pre) case b => None } /** Returns the postcondition of an expression wrapped in Option */ - def postconditionOf(expr: Expr) = expr match { + def postconditionOf(expr: Expr): Option[Expr] = expr match { + case Let(i, e, b) => postconditionOf(b).map(Let(i, e, _)) case Ensuring(_, post) => Some(post) case _ => None } @@ -2047,7 +2093,7 @@ object ExprOps extends GenTreeOps[Expr] { val conds = collectWithPC { case m @ MatchExpr(scrut, cases) => - (m, orJoin(cases map (matchCaseCondition(scrut, _)))) + (m, orJoin(cases map (matchCaseCondition(scrut, _).toClause))) case e @ Error(_, _) => (e, BooleanLiteral(false)) @@ -2067,7 +2113,7 @@ object ExprOps extends GenTreeOps[Expr] { conds map { case ((e, cond), path) => - (e, implies(path, cond)) + (e, path implies cond) } } diff --git a/src/main/scala/leon/purescala/FunctionClosure.scala b/src/main/scala/leon/purescala/FunctionClosure.scala index bfce31656..ba3353e32 100644 --- a/src/main/scala/leon/purescala/FunctionClosure.scala +++ b/src/main/scala/leon/purescala/FunctionClosure.scala @@ -32,6 +32,7 @@ object FunctionClosure extends TransformationPhase { case LetDef(fd1, body) => fd1.filter(funDefs) }(fd.fullBody) } + val nestedWithPaths = (for((fds, path) <- nestedWithPathsFull; fd <- fds) yield (fd, path)).toMap val nestedFuns = nestedWithPaths.keys.toSeq @@ -39,24 +40,26 @@ object FunctionClosure extends TransformationPhase { val callGraph: Map[FunDef, Set[FunDef]] = transitiveClosure( nestedFuns.map { f => val calls = functionCallsOf(f.fullBody) collect { - case FunctionInvocation(TypedFunDef(fd, _), _) if nestedFuns.contains(fd) => - fd + case FunctionInvocation(TypedFunDef(fd, _), _) if nestedFuns.contains(fd) => fd } - val pcCalls = functionCallsOf(nestedWithPaths(f)) collect { - case FunctionInvocation(TypedFunDef(fd, _), _) if nestedFuns.contains(fd) => - fd + + val pcCalls = functionCallsOf(nestedWithPaths(f).fullClause) collect { + case FunctionInvocation(TypedFunDef(fd, _), _) if nestedFuns.contains(fd) => fd } + f -> (calls ++ pcCalls) }.toMap ) //println("nested funs: " + nestedFuns) //println("call graph: " + callGraph) - def freeVars(fd: FunDef): Set[Identifier] = variablesOf(fd.fullBody) -- fd.paramIds + def freeVars(fd: FunDef, pc: Path): Set[Identifier] = + variablesOf(fd.fullBody) ++ pc.variables -- fd.paramIds -- pc.bindings.map(_._1) // All free variables one should include. // Contains free vars of the function itself plus of all transitively called functions. // also contains free vars from PC if the PC is relevant to the fundef + /* val transFree = { def step(current: Map[FunDef, Set[Identifier]]): Map[FunDef, Set[Identifier]] = { nestedFuns.map(fd => { @@ -69,10 +72,18 @@ object FunctionClosure extends TransformationPhase { (fd, transFreeVars ++ reqPaths.flatMap(p => variablesOf(p)) -- fd.paramIds) }).toMap } + utils.fixpoint(step, -1)(nestedFuns.map(fd => (fd, freeVars(fd))).toMap) }.map(p => (p._1, p._2.toSeq)) + */ //println("free vars: " + transFree) + // All free variables one should include. + // Contains free vars of the function itself plus of all transitively called functions. + val transFree = nestedFuns.map { fd => + fd -> (callGraph(fd) + fd).flatMap( (fd2: FunDef) => freeVars(fd2, nestedWithPaths(fd2)) ).toSeq + }.toMap + // Closed functions along with a map (old var -> new var). val closed = nestedWithPaths.map { case (inner, pc) => inner -> closeFd(inner, fd, pc, transFree(inner)) @@ -97,7 +108,7 @@ object FunctionClosure extends TransformationPhase { (dummySubst +: closed.values.toSeq).foreach { case FunSubst(f, callerMap, callerTMap) => f.fullBody = preMap { - case fi@FunctionInvocation(tfd, args) if closed contains tfd.fd => + case fi @ FunctionInvocation(tfd, args) if closed contains tfd.fd => val FunSubst(newCallee, calleeMap, calleeTMap) = closed(tfd.fd) // This needs some explanation. @@ -141,7 +152,7 @@ object FunctionClosure extends TransformationPhase { ) // Takes one inner function and closes it. - private def closeFd(inner: FunDef, outer: FunDef, pc: Expr, free: Seq[Identifier]): FunSubst = { + private def closeFd(inner: FunDef, outer: FunDef, pc: Path, free: Seq[Identifier]): FunSubst = { val tpFresh = outer.tparams map { _.freshen } val tparamsMap = outer.tparams.zip(tpFresh map {_.tp}).toMap @@ -155,16 +166,15 @@ object FunctionClosure extends TransformationPhase { freshVals.map(ValDef(_)), instantiateType(inner.returnType, tparamsMap) ) - newFd.precondition = Some(and(pc, inner.precOrTrue)) val instBody = instantiateType( - newFd.fullBody, + withPath(newFd.fullBody, pc), tparamsMap, freeMap ) newFd.fullBody = preMap { - case fi@FunctionInvocation(tfd, args) if tfd.fd == inner => + case fi @ FunctionInvocation(tfd, args) if tfd.fd == inner => Some(FunctionInvocation( newFd.typed(tfd.tps ++ tpFresh.map{ _.tp }), args ++ freshVals.drop(args.length).map(Variable) diff --git a/src/main/scala/leon/purescala/Path.scala b/src/main/scala/leon/purescala/Path.scala new file mode 100644 index 000000000..6bc93410d --- /dev/null +++ b/src/main/scala/leon/purescala/Path.scala @@ -0,0 +1,108 @@ +/* Copyright 2009-2016 EPFL, Lausanne */ + +package leon +package purescala + +import Common._ +import Definitions._ +import Expressions._ +import Constructors._ +import Extractors._ +import ExprOps._ +import Types._ + +object Path { + def empty: Path = new Path(Seq.empty) + def apply(p: Expr): Path = p match { + case Let(i, e, b) => new Path(Seq(Left(i -> e))) merge apply(b) + case _ => new Path(Seq(Right(p))) + } + def apply(path: Seq[Expr]): Path = new Path(path.map(Right(_))) +} + +/** Encodes path conditions */ +class Path private[purescala]( + private[purescala] val elements: Seq[Either[(Identifier, Expr), Expr]]) { + + def withBinding(p: (Identifier, Expr)) = new Path(elements :+ Left(p)) + def withBindings(ps: Iterable[(Identifier, Expr)]) = new Path(elements ++ ps.map(Left(_))) + def withCond(e: Expr) = new Path(elements :+ Right(e)) + def withConds(es: Iterable[Expr]) = new Path(elements ++ es.map(Right(_))) + + def --(ids: Set[Identifier]) = new Path(elements.filterNot(_.left.exists(p => ids(p._1)))) + + def merge(that: Path) = new Path(elements ++ that.elements) + + def map(f: Expr => Expr) = new Path(elements.map(_.left.map { case (id, e) => id -> f(e) }.right.map(f))) + def partition(p: Expr => Boolean): (Path, Seq[Expr]) = { + val (passed, failed) = elements.partition { + case Right(e) => p(e) + case Left(_) => true + } + + (new Path(passed), failed.flatMap(_.right.toOption)) + } + + def isEmpty = elements.filter { + case Right(BooleanLiteral(true)) => false + case _ => true + }.isEmpty + + def negate: Path = { + val (outers, rest) = elements.span(_.isLeft) + new Path(outers :+ Right(not(fold[Expr](BooleanLiteral(true), let, Constructors.and(_, _))(rest)))) + } + + lazy val variables: Set[Identifier] = fold[Set[Identifier]](Set.empty, + (id, e, res) => res - id ++ variablesOf(e), (e, res) => res ++ variablesOf(e) + )(elements) + + lazy val bindings: Seq[(Identifier, Expr)] = elements.collect { case Left(p) => p } + lazy val conditions: Seq[Expr] = elements.collect { case Right(e) => e } + + private def fold[T](base: T, combineLet: (Identifier, Expr, T) => T, combineCond: (Expr, T) => T) + (elems: Seq[Either[(Identifier, Expr), Expr]]): T = elems.foldRight(base) { + case (Left((id, e)), res) => combineLet(id, e, res) + case (Right(e), res) => combineCond(e, res) + } + + private def distributiveClause(base: Expr, combine: (Expr, Expr) => Expr): Expr = { + val (outers, rest) = elements.span(_.isLeft) + val inner = fold[Expr](base, let, combine)(rest) + fold[Expr](inner, let, (_,_) => scala.sys.error("Should never happen!"))(outers) + } + + def and(base: Expr) = distributiveClause(base, Constructors.and(_, _)) + def implies(base: Expr) = distributiveClause(base, Constructors.implies(_, _)) + def specs(body: Expr, pre: Expr = BooleanLiteral(true), post: Expr = NoTree(BooleanType)) = { + val (outers, rest) = elements.span(_.isLeft) + val cond = fold[Expr](BooleanLiteral(true), let, Constructors.and(_, _))(rest) + + def wrap(e: Expr) = fold[Expr](e, let, (_, res) => res)(rest) + + val req = Require(Constructors.and(cond, wrap(pre)), wrap(body)) + val full = if (post != NoTree(BooleanType)) Ensuring(req, wrap(post)) else req + + fold[Expr](full, let, (_, _) => scala.sys.error("Should never happen!"))(outers) + } + + lazy val toClause: Expr = and(BooleanLiteral(true)) + lazy val fullClause: Expr = fold[Expr](BooleanLiteral(true), Let(_, _, _), And(_, _))(elements) + + lazy val toPath: Expr = andJoin(elements.map { + case Left((id, e)) => Equals(id.toVariable, e) + case Right(e) => e + }) + + override def equals(that: Any): Boolean = that match { + case p: Path => elements == p.elements + case _ => false + } + + override def hashCode: Int = elements.hashCode + + override def toString = asString(LeonContext.printNames) + def asString(implicit ctx: LeonContext): String = fullClause.asString + def asString(pgm: Program)(implicit ctx: LeonContext): String = fullClause.asString(pgm) +} + diff --git a/src/main/scala/leon/purescala/Quantification.scala b/src/main/scala/leon/purescala/Quantification.scala index c815f4ca5..5e10ace6b 100644 --- a/src/main/scala/leon/purescala/Quantification.scala +++ b/src/main/scala/leon/purescala/Quantification.scala @@ -54,7 +54,7 @@ object Quantification { res.filter(ms => ms.forall(m => reverseMap(m) subsetOf ms)) } - def extractQuorums(expr: Expr, quantified: Set[Identifier]): Seq[Set[(Expr, Expr, Seq[Expr])]] = { + def extractQuorums(expr: Expr, quantified: Set[Identifier]): Seq[Set[(Path, Expr, Seq[Expr])]] = { object QMatcher { def unapply(e: Expr): Option[(Expr, Seq[Expr])] = e match { case QuantificationMatcher(expr, args) => @@ -71,8 +71,8 @@ object Quantification { val matchers = allMatchers.map { case ((caller, args), path) => (path, caller, args) }.toSet extractQuorums(matchers, quantified, - (p: (Expr, Expr, Seq[Expr])) => p._3.collect { case QMatcher(e, a) => (p._1, e, a) }.toSet, - (p: (Expr, Expr, Seq[Expr])) => p._3.collect { case Variable(id) if quantified(id) => id }.toSet) + (p: (Path, Expr, Seq[Expr])) => p._3.collect { case QMatcher(e, a) => (p._1, e, a) }.toSet, + (p: (Path, Expr, Seq[Expr])) => p._3.collect { case Variable(id) if quantified(id) => id }.toSet) } object Domains { diff --git a/src/main/scala/leon/purescala/SimplifierWithPaths.scala b/src/main/scala/leon/purescala/SimplifierWithPaths.scala index 40f100fcb..b6cf9e657 100644 --- a/src/main/scala/leon/purescala/SimplifierWithPaths.scala +++ b/src/main/scala/leon/purescala/SimplifierWithPaths.scala @@ -10,15 +10,12 @@ import Extractors._ import Constructors._ import solvers._ -class SimplifierWithPaths(sf: SolverFactory[Solver], override val initC: List[Expr] = Nil) extends TransformerWithPC { - type C = List[Expr] +class SimplifierWithPaths(sf: SolverFactory[Solver], override val initPath: List[Expr] = Nil) extends TransformerWithPC { val solver = SimpleSolverAPI(sf) - protected def register(e: Expr, c: C) = e :: c - - def impliedBy(e : Expr, path : Seq[Expr]) : Boolean = try { - solver.solveVALID(implies(andJoin(path), e)) match { + def impliedBy(e: Expr, path: Path) : Boolean = try { + solver.solveVALID(path implies e) match { case Some(true) => true case _ => false } @@ -26,8 +23,8 @@ class SimplifierWithPaths(sf: SolverFactory[Solver], override val initC: List[Ex case _ : Exception => false } - def contradictedBy(e : Expr, path : Seq[Expr]) : Boolean = try { - solver.solveVALID(implies(andJoin(path), Not(e))) match { + def contradictedBy(e: Expr, path: Path) : Boolean = try { + solver.solveVALID(path implies not(e)) match { case Some(true) => true case _ => false } @@ -35,7 +32,7 @@ class SimplifierWithPaths(sf: SolverFactory[Solver], override val initC: List[Ex case _ : Exception => false } - def valid(e : Expr) : Boolean = try { + def valid(e: Expr) : Boolean = try { solver.solveVALID(e) match { case Some(true) => true case _ => false @@ -44,7 +41,7 @@ class SimplifierWithPaths(sf: SolverFactory[Solver], override val initC: List[Ex case _ : Exception => false } - def sat(e : Expr) : Boolean = try { + def sat(e: Expr) : Boolean = try { solver.solveSAT(e) match { case (Some(false),_) => false case _ => true @@ -53,7 +50,7 @@ class SimplifierWithPaths(sf: SolverFactory[Solver], override val initC: List[Ex case _ : Exception => true } - protected override def rec(e: Expr, path: C) = e match { + protected override def rec(e: Expr, path: Path) = e match { case Require(pre, body) if impliedBy(pre, path) => body @@ -86,7 +83,7 @@ class SimplifierWithPaths(sf: SolverFactory[Solver], override val initC: List[Ex case Implies(lhs, rhs) if contradictedBy(lhs, path) => BooleanLiteral(true).copiedFrom(e) - case me@MatchExpr(scrut, cases) => + case me @ MatchExpr(scrut, cases) => val rs = rec(scrut, path) var stillPossible = true @@ -94,9 +91,9 @@ class SimplifierWithPaths(sf: SolverFactory[Solver], override val initC: List[Ex val conds = matchExprCaseConditions(me, path) val newCases = cases.zip(conds).flatMap { case (cs, cond) => - if (stillPossible && sat(and(cond: _*))) { + if (stillPossible && sat(cond.toClause)) { - if (valid(and(cond: _*))) { + if (valid(cond.toClause)) { stillPossible = false } @@ -107,7 +104,8 @@ class SimplifierWithPaths(sf: SolverFactory[Solver], override val initC: List[Ex // @mk: This is quite a dirty hack. We just know matchCasePathConditions // returns the current guard as the last element. // We don't include it in the path condition when we recurse into itself. - val condWithoutGuard = cond.dropRight(1) + // @nv: baaaaaaaad!!! + val condWithoutGuard = new Path(cond.elements.dropRight(1)) val newGuard = rec(g, condWithoutGuard) if (valid(newGuard)) SimpleCase(p, rec(rhs,cond)) @@ -118,6 +116,7 @@ class SimplifierWithPaths(sf: SolverFactory[Solver], override val initC: List[Ex Seq() } } + newCases match { case List() => Error(e.getType, "Unreachable code").copiedFrom(e) diff --git a/src/main/scala/leon/purescala/TransformerWithPC.scala b/src/main/scala/leon/purescala/TransformerWithPC.scala index 9411bf68d..53e675684 100644 --- a/src/main/scala/leon/purescala/TransformerWithPC.scala +++ b/src/main/scala/leon/purescala/TransformerWithPC.scala @@ -3,47 +3,42 @@ package leon package purescala +import Common._ import Expressions._ -import ExprOps._ -import Extractors._ import Constructors._ +import Extractors._ +import ExprOps._ +import Types._ /** Traverses/ transforms expressions with path condition awareness. * - * Path condition representation is left generic (type [[C]]) + * As lets cannot be encoded as Equals due to types for which equality + * is not well-founded, path conditions reconstruct lets around the + * final condition one wishes to verify through [[Path.getClause]]. */ abstract class TransformerWithPC extends Transformer { - /** The type of the path condition */ - type C - /** The initial path condition */ - protected val initC: C - - /** Register a new expression to a path condition */ - protected def register(cond: Expr, path: C): C + protected val initPath: Seq[Expr] - protected def rec(e: Expr, path: C): Expr = e match { + protected def rec(e: Expr, path: Path): Expr = e match { case Let(i, v, b) => val se = rec(v, path) - val sb = rec(b, register(Equals(Variable(i), se), path)) + val sb = rec(b, path withBinding (i -> se)) Let(i, se, sb).copiedFrom(e) - case Ensuring(req@Require(pre, body), lam@Lambda(Seq(arg), post)) => + case Ensuring(req @ Require(pre, body), lam @ Lambda(Seq(arg), post)) => val spre = rec(pre, path) - val sbody = rec(body, register(spre, path)) - val spost = rec(post, register( - and(spre, Equals(arg.toVariable, sbody)), - path - )) + val sbody = rec(body, path withCond spre) + val spost = rec(post, path withCond spre withBinding (arg.id -> sbody)) Ensuring( Require(spre, sbody).copiedFrom(req), Lambda(Seq(arg), spost).copiedFrom(lam) ).copiedFrom(e) - case Ensuring(body, lam@Lambda(Seq(arg), post)) => + case Ensuring(body, lam @ Lambda(Seq(arg), post)) => val sbody = rec(body, path) - val spost = rec(post, register(Equals(arg.toVariable, sbody), path)) + val spost = rec(post, path withBinding (arg.id -> sbody)) Ensuring( sbody, Lambda(Seq(arg), spost).copiedFrom(lam) @@ -51,7 +46,7 @@ abstract class TransformerWithPC extends Transformer { case Require(pre, body) => val sp = rec(pre, path) - val sb = rec(body, register(sp, path)) + val sb = rec(body, path withCond pre) Require(sp, sb).copiedFrom(e) //@mk: TODO Discuss if we should include asserted predicates in the pc @@ -60,8 +55,8 @@ abstract class TransformerWithPC extends Transformer { // val sb = rec(body, register(sp, path)) // Assert(sp, err, sb).copiedFrom(e) - case p:Passes => - applyAsMatches(p,rec(_,path)) + case p: Passes => + applyAsMatches(p, rec(_,path)) case MatchExpr(scrut, cases) => val rs = rec(scrut, path) @@ -69,29 +64,27 @@ abstract class TransformerWithPC extends Transformer { var soFar = path MatchExpr(rs, cases.map { c => - val patternExprPos = conditionForPattern(rs, c.pattern, includeBinders = true) - val patternExprNeg = conditionForPattern(rs, c.pattern, includeBinders = false) + val patternPathPos = conditionForPattern(rs, c.pattern, includeBinders = true) + val patternPathNeg = conditionForPattern(rs, c.pattern, includeBinders = false) val map = mapForPattern(rs, c.pattern) val guardOrTrue = c.optGuard.getOrElse(BooleanLiteral(true)) val guardMapped = replaceFromIDs(map, guardOrTrue) - val subPath = register(and(patternExprPos, guardOrTrue), soFar) - soFar = register(not(and(patternExprNeg, guardMapped)), soFar) - - MatchCase(c.pattern, c.optGuard, rec(c.rhs,subPath)).copiedFrom(c) + val subPath = soFar merge (patternPathPos withCond guardOrTrue) + soFar = soFar merge (patternPathNeg withCond guardMapped).negate + MatchCase(c.pattern, c.optGuard, rec(c.rhs, subPath)).copiedFrom(c) }).copiedFrom(e) case IfExpr(cond, thenn, elze) => val rc = rec(cond, path) - - IfExpr(rc, rec(thenn, register(rc, path)), rec(elze, register(Not(rc), path))).copiedFrom(e) + IfExpr(rc, rec(thenn, path withCond rc), rec(elze, path withCond Not(rc))).copiedFrom(e) case And(es) => var soFar = path andJoin(for(e <- es) yield { val se = rec(e, soFar) - soFar = register(se, soFar) + soFar = soFar withCond se se }).copiedFrom(e) @@ -99,13 +92,13 @@ abstract class TransformerWithPC extends Transformer { var soFar = path orJoin(for(e <- es) yield { val se = rec(e, soFar) - soFar = register(Not(se), soFar) + soFar = soFar withCond Not(se) se }).copiedFrom(e) case i @ Implies(lhs, rhs) => val rc = rec(lhs, path) - Implies(rc, rec(rhs, register(rc, path))).copiedFrom(i) + Implies(rc, rec(rhs, path withCond rc)).copiedFrom(i) case o @ Operator(es, builder) => builder(es.map(rec(_, path))).copiedFrom(o) @@ -115,7 +108,7 @@ abstract class TransformerWithPC extends Transformer { } def transform(e: Expr): Expr = { - rec(e, initC) + rec(e, Path(initPath)) } } diff --git a/src/main/scala/leon/repair/Repairman.scala b/src/main/scala/leon/repair/Repairman.scala index 4b6d0ce83..1c49aad05 100644 --- a/src/main/scala/leon/repair/Repairman.scala +++ b/src/main/scala/leon/repair/Repairman.scala @@ -4,6 +4,7 @@ package leon package repair import leon.datagen.GrammarDataGen +import purescala.Path import purescala.Definitions._ import purescala.Expressions._ import purescala.Extractors._ @@ -183,7 +184,7 @@ class Repairman(ctx0: LeonContext, initProgram: Program, fd: FunDef, verifTimeou val guide = Guide(origBody) val pre = fd.precOrTrue - val prob = Problem.fromSpec(fd.postOrTrue, andJoin(Seq(pre, guide, term)), eb, Some(fd)) + val prob = Problem.fromSpec(fd.postOrTrue, Path(Seq(pre, guide, term)), eb, Some(fd)) val ci = SourceInfo(fd, origBody, prob) diff --git a/src/main/scala/leon/repair/rules/Focus.scala b/src/main/scala/leon/repair/rules/Focus.scala index 7ce3fd261..512fb0afd 100644 --- a/src/main/scala/leon/repair/rules/Focus.scala +++ b/src/main/scala/leon/repair/rules/Focus.scala @@ -7,6 +7,7 @@ package rules import synthesis._ import leon.evaluators._ +import purescala.Path import purescala.Expressions._ import purescala.Common._ import purescala.Types._ @@ -112,17 +113,17 @@ case object Focus extends PreprocessingRule("Focus") { // Try to focus on branches forAllTests(c, Map(), evaluator) match { case Some(true) => - val np = Problem(p.as, ws(thn), and(p.pc, c), p.phi, p.xs, p.qeb.filterIns(c)) + val np = Problem(p.as, ws(thn), p.pc withCond c, p.phi, p.xs, p.qeb.filterIns(c)) Some(decomp(List(np), termWrap(IfExpr(c, _, els), c), s"Focus on if-then")(p)) case Some(false) => - val np = Problem(p.as, ws(els), and(p.pc, not(c)), p.phi, p.xs, p.qeb.filterIns(not(c))) + val np = Problem(p.as, ws(els), p.pc withCond not(c), p.phi, p.xs, p.qeb.filterIns(not(c))) Some(decomp(List(np), termWrap(IfExpr(c, thn, _), not(c)), s"Focus on if-else")(p)) case None => // We split - val sub1 = p.copy(ws = ws(thn), pc = and(c, replace(Map(g -> thn), p.pc)), eb = p.qeb.filterIns(c)) - val sub2 = p.copy(ws = ws(els), pc = and(Not(c), replace(Map(g -> els), p.pc)), eb = p.qeb.filterIns(Not(c))) + val sub1 = p.copy(ws = ws(thn), pc = p.pc map (replace(Map(g -> thn), _)) withCond c , eb = p.qeb.filterIns(c)) + val sub2 = p.copy(ws = ws(els), pc = p.pc map (replace(Map(g -> thn), _)) withCond Not(c), eb = p.qeb.filterIns(Not(c))) val onSuccess: List[Solution] => Option[Solution] = { case List(s1, s2) => @@ -136,21 +137,21 @@ case object Focus extends PreprocessingRule("Focus") { } case MatchExpr(scrut, cases) => - var pcSoFar: Seq[Expr] = Nil + var pcSoFar = Path.empty // Generate subproblems for each match-case that fails at least one test. var casesInfos = for (c <- cases) yield { val map = mapForPattern(scrut, c.pattern) val thisCond = matchCaseCondition(scrut, c) - val cond = andJoin(pcSoFar :+ thisCond) - pcSoFar = pcSoFar :+ not(thisCond) + val cond = pcSoFar merge thisCond + pcSoFar = pcSoFar merge thisCond.negate - val subP = if (existsFailing(cond, map, evaluator)) { + val subP = if (existsFailing(cond.toClause, map, evaluator)) { val vars = map.toSeq.map(_._1) // Filter tests by the path-condition - val eb2 = p.qeb.filterIns(cond) + val eb2 = p.qeb.filterIns(cond.toClause) // Augment test with the additional variables and their valuations val ebF: (Seq[Expr] => List[Seq[Expr]]) = { (e: Seq[Expr]) => @@ -167,9 +168,9 @@ case object Focus extends PreprocessingRule("Focus") { eb2.eb } - val newPc = andJoin(cond +: vars.map { id => equality(id.toVariable, map(id)) }) + val newPc = Path.empty withBindings vars.map(id => id -> map(id)).toSeq merge cond - Some(Problem(p.as ++ vars, ws(c.rhs), and(p.pc, newPc), p.phi, p.xs, eb3)) + Some(Problem(p.as, ws(c.rhs), p.pc merge newPc, p.phi, p.xs, eb3)) } else { None } @@ -179,14 +180,14 @@ case object Focus extends PreprocessingRule("Focus") { // Check if the match might be missing a case? (we check if one test // goes to no defined cases) - val elsePc = andJoin(pcSoFar) + val elsePc = pcSoFar - if (existsFailing(elsePc, Map(), evaluator)) { + if (existsFailing(elsePc.toClause, Map(), evaluator)) { val newCase = MatchCase(WildcardPattern(None), None, NoTree(scrut.getType)) - val eb = p.qeb.filterIns(elsePc) + val eb = p.qeb.filterIns(elsePc.toClause) - val newProblem = Problem(p.as, andJoin(wss), and(p.pc, elsePc), p.phi, p.xs, eb) + val newProblem = Problem(p.as, andJoin(wss), p.pc merge elsePc, p.phi, p.xs, eb) casesInfos :+= (newCase -> (Some(newProblem), elsePc)) } @@ -210,7 +211,7 @@ case object Focus extends PreprocessingRule("Focus") { if(s.pre == BooleanLiteral(true)) { BooleanLiteral(true) } else { - and(p.pc, s.pre) + p.pc and s.pre } } @@ -241,7 +242,7 @@ case object Focus extends PreprocessingRule("Focus") { }.toList } - val np = Problem(p.as :+ id, ws(body), and(p.pc, equality(id.toVariable, value)), p.phi, p.xs, p.eb.mapIns(ebF)) + val np = Problem(p.as, ws(body), p.pc withBinding (id -> value), p.phi, p.xs, p.eb.mapIns(ebF)) Some(decomp(List(np), termWrap(Let(id, value, _)), s"Focus on let-body")(p)) diff --git a/src/main/scala/leon/solvers/sygus/SygusSolver.scala b/src/main/scala/leon/solvers/sygus/SygusSolver.scala index 816872492..f8f2c52d1 100644 --- a/src/main/scala/leon/solvers/sygus/SygusSolver.scala +++ b/src/main/scala/leon/solvers/sygus/SygusSolver.scala @@ -59,7 +59,7 @@ abstract class SygusSolver(val context: LeonContext, val program: Program, val p val synthPhi = replaceFromIDs(xToFdCall, p.phi) - val constraint = implies(p.pc, synthPhi) + val constraint = p.pc implies synthPhi emit(FunctionApplication(constraintId, Seq(toSMT(constraint)(bindings)))) diff --git a/src/main/scala/leon/solvers/unrolling/TemplateGenerator.scala b/src/main/scala/leon/solvers/unrolling/TemplateGenerator.scala index 83129798a..6d29458e2 100644 --- a/src/main/scala/leon/solvers/unrolling/TemplateGenerator.scala +++ b/src/main/scala/leon/solvers/unrolling/TemplateGenerator.scala @@ -248,15 +248,6 @@ class TemplateGenerator[T](val theories: TheoryEncoder, var lambdas = Seq[LambdaTemplate[T]]() @inline def registerLambda(lambda: LambdaTemplate[T]) : Unit = lambdas :+= lambda - def requireDecomposition(e: Expr) = { - exists{ - case (_: Choose) | (_: Forall) | (_: Lambda) | (_: FiniteLambda) => true - case (_: Assert) | (_: Ensuring) => true - case (_: FunctionInvocation) | (_: Application) => true - case _ => false - }(e) - } - def rec(pathVar: Identifier, expr: Expr): Expr = { expr match { case a @ Assert(cond, err, body) => @@ -300,14 +291,14 @@ class TemplateGenerator[T](val theories: TheoryEncoder, case p : Passes => sys.error("'Passes's should have been eliminated before generating templates.") case i @ Implies(lhs, rhs) => - if (requireDecomposition(i)) { + if (!isSimple(i)) { rec(pathVar, Or(Not(lhs), rhs)) } else { implies(rec(pathVar, lhs), rec(pathVar, rhs)) } case a @ And(parts) => - val partitions = groupWhile(parts)(!requireDecomposition(_)) + val partitions = groupWhile(parts)(isSimple) partitions.map(andJoin) match { case Seq(e) => e case seq => @@ -336,7 +327,7 @@ class TemplateGenerator[T](val theories: TheoryEncoder, } case o @ Or(parts) => - val partitions = groupWhile(parts)(!requireDecomposition(_)) + val partitions = groupWhile(parts)(isSimple) partitions.map(orJoin) match { case Seq(e) => e case seq => @@ -365,7 +356,7 @@ class TemplateGenerator[T](val theories: TheoryEncoder, } case i @ IfExpr(cond, thenn, elze) => { - if(!requireDecomposition(i)) { + if(isSimple(i)) { i } else { val newBool1 : Identifier = FreshIdentifier("b", BooleanType, true) diff --git a/src/main/scala/leon/solvers/unrolling/TemplateManager.scala b/src/main/scala/leon/solvers/unrolling/TemplateManager.scala index b5c00e403..37f1a031f 100644 --- a/src/main/scala/leon/solvers/unrolling/TemplateManager.scala +++ b/src/main/scala/leon/solvers/unrolling/TemplateManager.scala @@ -168,10 +168,10 @@ object Template { } (expr) val withPaths = CollectorWithPaths { case FreshFunction(f) => f }.traverse(e) - functions ++= withPaths.map { case (f, TopLevelAnds(paths)) => + functions ++= withPaths.map { case (f, path) => val tpe = bestRealType(f.getType).asInstanceOf[FunctionType] - val path = andJoin(paths.map(clean)) - (encodeExpr(and(Variable(b), path)), tpe, encodeExpr(f)) + val cleanPath = path.map(clean) + (encodeExpr(and(Variable(b), cleanPath.toPath)), tpe, encodeExpr(f)) } val cleanExpr = clean(e) diff --git a/src/main/scala/leon/solvers/z3/FairZ3Solver.scala b/src/main/scala/leon/solvers/z3/FairZ3Solver.scala index abf4c63d3..815f79750 100644 --- a/src/main/scala/leon/solvers/z3/FairZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/FairZ3Solver.scala @@ -173,7 +173,7 @@ class FairZ3Solver(val context: LeonContext, val program: Program) try { super.assertCnstr(expression) } catch { - case _: Unsupported => + case u: Unsupported => addError() } } diff --git a/src/main/scala/leon/synthesis/ExamplesFinder.scala b/src/main/scala/leon/synthesis/ExamplesFinder.scala index 628a9201a..0ffdbc769 100644 --- a/src/main/scala/leon/synthesis/ExamplesFinder.scala +++ b/src/main/scala/leon/synthesis/ExamplesFinder.scala @@ -78,7 +78,7 @@ class ExamplesFinder(ctx0: LeonContext, program: Program) { /** Extract examples from the passes found in expression */ def extractFromProblem(p: Problem): ExamplesBank = { - val testClusters = extractTestsOf(and(p.pc, p.phi)) + val testClusters = extractTestsOf(p.pc and p.phi) // Finally, we keep complete tests covering all as++xs val allIds = (p.as ++ p.xs).toSet @@ -102,9 +102,9 @@ class ExamplesFinder(ctx0: LeonContext, program: Program) { if(this.keepAbstractExamples) return true // TODO: Abstract interpretation here ? val (mapping, cond) = ex match { case io: InOutExample => - (Map((p.as zip io.ins) ++ (p.xs zip io.outs): _*), And(p.pc, p.phi)) + (Map((p.as zip io.ins) ++ (p.xs zip io.outs): _*), p.pc and p.phi) case i => - ((p.as zip i.ins).toMap, p.pc) + ((p.as zip i.ins).toMap, p.pc.toClause) } evaluator.eval(cond, mapping) match { diff --git a/src/main/scala/leon/synthesis/Problem.scala b/src/main/scala/leon/synthesis/Problem.scala index f60b5fc5a..e3c818354 100644 --- a/src/main/scala/leon/synthesis/Problem.scala +++ b/src/main/scala/leon/synthesis/Problem.scala @@ -3,6 +3,7 @@ package leon package synthesis +import purescala.Path import purescala.Expressions._ import purescala.ExprOps._ import purescala.Types._ @@ -21,18 +22,20 @@ import Witnesses._ * @param phi The formula on `as` and `xs` to satisfy * @param xs The list of output identifiers for which we want to compute a function */ -case class Problem(as: List[Identifier], ws: Expr, pc: Expr, phi: Expr, xs: List[Identifier], eb: ExamplesBank = ExamplesBank.empty) extends Printable { +case class Problem(as: List[Identifier], ws: Expr, pc: Path, phi: Expr, xs: List[Identifier], eb: ExamplesBank = ExamplesBank.empty) extends Printable { def inType = tupleTypeWrap(as.map(_.getType)) def outType = tupleTypeWrap(xs.map(_.getType)) + def allAs = as ++ pc.bindings.map(_._1) + def asString(implicit ctx: LeonContext): String = { - val pcws = and(ws, pc) + val pcws = pc withCond ws val ebInfo = "/"+eb.valids.size+","+eb.invalids.size+"/" s"""|⟦ ${if (as.nonEmpty) as.map(_.asString).mkString(", ") else "()"} - | ${pcws.asString} ≺ + | ${pcws.toClause.asString} ≺ | ⟨ ${phi.asString} ⟩ | ${if (xs.nonEmpty) xs.map(_.asString).mkString(", ") else "()"} |⟧ $ebInfo""".stripMargin @@ -51,7 +54,7 @@ case class Problem(as: List[Identifier], ws: Expr, pc: Expr, phi: Expr, xs: List object Problem { def fromSpec( spec: Expr, - pc: Expr = BooleanLiteral(true), + pc: Path = Path.empty, eb: ExamplesBank = ExamplesBank.empty, fd: Option[FunDef] = None ): Problem = { @@ -61,7 +64,7 @@ object Problem { }.toList val phi = application(simplifyLets(spec), xs map { _.toVariable}) - val as = (variablesOf(And(pc, phi)) -- xs).toList.sortBy(_.name) + val as = (variablesOf(phi) ++ pc.variables -- xs).toList.sortBy(_.name) val sortedAs = fd match { case None => as @@ -70,14 +73,12 @@ object Problem { as.sortBy(a => argsIndex(a)) } - val TopLevelAnds(clauses) = pc - - val (pcs, wss) = clauses.partition { - case w : Witness => false + val (pcs, wss) = pc.partition { + case w: Witness => false case _ => true } - Problem(sortedAs, andJoin(wss), andJoin(pcs), phi, xs, eb) + Problem(sortedAs, andJoin(wss), pcs, phi, xs, eb) } } diff --git a/src/main/scala/leon/synthesis/Solution.scala b/src/main/scala/leon/synthesis/Solution.scala index bddbcae12..571d42e21 100644 --- a/src/main/scala/leon/synthesis/Solution.scala +++ b/src/main/scala/leon/synthesis/Solution.scala @@ -77,7 +77,7 @@ object Solution { } def chooseComplete(p: Problem): Solution = { - new Solution(BooleanLiteral(true), Set(), Choose(Lambda(p.xs.map(ValDef(_)), and(p.pc, p.phi)))) + new Solution(BooleanLiteral(true), Set(), Choose(Lambda(p.xs.map(ValDef(_)), p.pc and p.phi))) } // Generate the simplest, wrongest solution, used for complexity lowerbound diff --git a/src/main/scala/leon/synthesis/SourceInfo.scala b/src/main/scala/leon/synthesis/SourceInfo.scala index 0e9a22fd3..0a34afb69 100644 --- a/src/main/scala/leon/synthesis/SourceInfo.scala +++ b/src/main/scala/leon/synthesis/SourceInfo.scala @@ -3,6 +3,7 @@ package leon package synthesis +import purescala.Path import purescala.Definitions._ import purescala.Constructors._ import purescala.Expressions._ @@ -13,9 +14,9 @@ case class SourceInfo(fd: FunDef, source: Expr, problem: Problem) object SourceInfo { - class ChooseCollectorWithPaths extends CollectorWithPaths[(Choose,Expr)] { - def collect(e: Expr, path: Seq[Expr]) = e match { - case c: Choose => Some(c -> and(path: _*)) + class ChooseCollectorWithPaths extends CollectorWithPaths[(Choose,Path)] { + def collect(e: Expr, path: Path) = e match { + case c: Choose => Some(c -> path) case _ => None } } @@ -55,15 +56,15 @@ object SourceInfo { val functionEb = eFinder.extractFromFunDef(fd, partition = false) for ((ch, path) <- new ChooseCollectorWithPaths().traverse(fd)) yield { - val outerEb = if (path == BooleanLiteral(true)) { + val outerEb = if (path.isEmpty) { functionEb } else { ExamplesBank.empty } - val p = Problem.fromSpec(ch.pred, and(path, term), outerEb, Some(fd)) + val p = Problem.fromSpec(ch.pred, path withCond term, outerEb, Some(fd)) - val pcEb = eFinder.generateForPC(p.as, path, ctx, 20) + val pcEb = eFinder.generateForPC(p.as, path.toClause, ctx, 20) val chooseEb = eFinder.extractFromProblem(p) val eb = (outerEb union chooseEb) union pcEb diff --git a/src/main/scala/leon/synthesis/Witnesses.scala b/src/main/scala/leon/synthesis/Witnesses.scala index 884eee9f0..b2359b1d4 100644 --- a/src/main/scala/leon/synthesis/Witnesses.scala +++ b/src/main/scala/leon/synthesis/Witnesses.scala @@ -15,7 +15,7 @@ object Witnesses { override def isSimpleExpr = true } - case class Guide(e : Expr) extends Witness { + case class Guide(e: Expr) extends Witness { def extract: Option[(Seq[Expr], Seq[Expr] => Expr)] = Some((Seq(e), (es: Seq[Expr]) => Guide(es.head))) override def printWith(implicit pctx: PrinterContext): Unit = { diff --git a/src/main/scala/leon/synthesis/rules/ADTSplit.scala b/src/main/scala/leon/synthesis/rules/ADTSplit.scala index 93e1488ae..71f564859 100644 --- a/src/main/scala/leon/synthesis/rules/ADTSplit.scala +++ b/src/main/scala/leon/synthesis/rules/ADTSplit.scala @@ -22,17 +22,14 @@ case object ADTSplit extends Rule("ADT Split.") { // don't want to split on two variables for which only one split // alternative is viable. This should be much less expensive than making // calls to a solver for each pair. - var facts = Map[Identifier, CaseClassType]() - - def addFacts(e: Expr): Unit = e match { - case Equals(Variable(a), CaseClass(cct, _)) => facts += a -> cct - case IsInstanceOf(Variable(a), cct: CaseClassType) => facts += a -> cct - case _ => - } - - val TopLevelAnds(as) = and(p.pc, p.phi) - for (e <- as) { - addFacts(e) + val facts: Map[Identifier, CaseClassType] = { + val TopLevelAnds(as) = andJoin(p.pc.conditions :+ p.phi) + val instChecks: Seq[(Identifier, CaseClassType)] = as.collect { + case IsInstanceOf(Variable(a), cct: CaseClassType) => a -> cct + case Equals(Variable(a), CaseClass(cct, _)) => a -> cct + } + val boundCcs = p.pc.bindings.collect { case (id, CaseClass(cct, _)) => id -> cct } + instChecks.toMap ++ boundCcs } val candidates = p.as.collect { @@ -74,7 +71,7 @@ case object ADTSplit extends Rule("ADT Split.") { val whole = CaseClass(cct, args.map(Variable)) val subPhi = subst(id -> whole, p.phi) - val subPC = subst(id -> whole, p.pc) + val subPC = p.pc map (subst(id -> whole, _)) val subWS = subst(id -> whole, p.ws) val eb2 = p.qeb.mapIns { inInfo => diff --git a/src/main/scala/leon/synthesis/rules/Assert.scala b/src/main/scala/leon/synthesis/rules/Assert.scala index 0acdc209e..d7dd317d7 100644 --- a/src/main/scala/leon/synthesis/rules/Assert.scala +++ b/src/main/scala/leon/synthesis/rules/Assert.scala @@ -29,7 +29,7 @@ case object Assert extends NormalizingRule("Assert") { Some(solve(Solution(pre = andJoin(exprsA), defs = Set(), term = simplestOut))) } } else { - val sub = p.copy(pc = andJoin(p.pc +: exprsA), phi = andJoin(others), eb = p.qeb.filterIns(andJoin(exprsA))) + val sub = p.copy(pc = p.pc withConds exprsA, phi = andJoin(others), eb = p.qeb.filterIns(andJoin(exprsA))) Some(decomp(List(sub), { case (s @ Solution(pre, defs, term)) :: Nil => Some(Solution(pre=andJoin(exprsA :+ pre), defs, term, s.isTrusted)) diff --git a/src/main/scala/leon/synthesis/rules/CEGISLike.scala b/src/main/scala/leon/synthesis/rules/CEGISLike.scala index 74bf7133d..933ae752b 100644 --- a/src/main/scala/leon/synthesis/rules/CEGISLike.scala +++ b/src/main/scala/leon/synthesis/rules/CEGISLike.scala @@ -330,7 +330,7 @@ abstract class CEGISLike(name: String) extends Rule(name) { */ private def outerExprToInnerExpr(e: Expr): Expr = outerToInner.transform(e)(Map.empty) - private val innerPc = outerExprToInnerExpr(p.pc) + private val innerPc = p.pc map outerExprToInnerExpr private val innerPhi = outerExprToInnerExpr(p.phi) // The program with the c-tree functions @@ -453,11 +453,11 @@ abstract class CEGISLike(name: String) extends Rule(name) { timers.testForProgram.start() val res = ex match { case InExample(ins) => - evaluator.eval(cnstr, p.as.zip(ins).toMap) + evaluator.eval(cnstr, p.as.zip(ins).toMap ++ p.pc.bindings) case InOutExample(ins, outs) => val eq = equality(innerSol, tupleWrap(outs)) - evaluator.eval(eq, p.as.zip(ins).toMap) + evaluator.eval(eq, p.as.zip(ins).toMap ++ p.pc.bindings) } timers.testForProgram.stop() @@ -521,7 +521,7 @@ abstract class CEGISLike(name: String) extends Rule(name) { //println(innerProgram) cTreeFd.fullBody = innerSol - val cnstr = and(innerPc, letTuple(p.xs, innerSol, Not(innerPhi))) + val cnstr = innerPc and letTuple(p.xs, innerSol, Not(innerPhi)) val eval = new DefaultEvaluator(hctx, innerProgram) @@ -625,7 +625,7 @@ abstract class CEGISLike(name: String) extends Rule(name) { //println("-"*80) //println(programCTree.asString) - val toFind = and(innerPc, cnstr) + val toFind = innerPc and cnstr //println(" --- Constraints ---") //println(" - "+toFind.asString) try { @@ -699,8 +699,7 @@ abstract class CEGISLike(name: String) extends Rule(name) { try { solver.assertCnstr(andJoin(bsOrdered.map(b => if (bs(b)) b.toVariable else Not(b.toVariable)))) - solver.assertCnstr(innerPc) - solver.assertCnstr(Not(cnstr)) + solver.assertCnstr(innerPc and not(cnstr)) //println("*"*80) //println(Not(cnstr)) @@ -749,7 +748,7 @@ abstract class CEGISLike(name: String) extends Rule(name) { val solverf = hctx.solverFactory val solver = solverf.getNewSolver().setTimeout(exSolverTo) - solver.assertCnstr(p.pc) + solver.assertCnstr(p.pc.toClause) try { solver.check match { @@ -791,15 +790,16 @@ abstract class CEGISLike(name: String) extends Rule(name) { case FunctionInvocation(tfd, _) if tfd.fd == hctx.functionContext => true case Choose(_) => true case _ => false - }(p.pc) + }(p.pc.toClause) + if (complicated) { Iterator() } else { if (useVanuatoo) { - new VanuatooDataGen(hctx, hctx.program).generateFor(p.as, p.pc, nTests, 3000).map(InExample) + new VanuatooDataGen(hctx, hctx.program).generateFor(p.as, p.pc.toClause, nTests, 3000).map(InExample) } else { val evaluator = new DualEvaluator(hctx, hctx.program, CodeGenParams.default) - new GrammarDataGen(evaluator, ValueGrammar).generateFor(p.as, p.pc, nTests, 1000).map(InExample) + new GrammarDataGen(evaluator, ValueGrammar).generateFor(p.as, p.pc.toClause, nTests, 1000).map(InExample) } } } diff --git a/src/main/scala/leon/synthesis/rules/DetupleInput.scala b/src/main/scala/leon/synthesis/rules/DetupleInput.scala index 771899a7c..57e85d7b0 100644 --- a/src/main/scala/leon/synthesis/rules/DetupleInput.scala +++ b/src/main/scala/leon/synthesis/rules/DetupleInput.scala @@ -69,7 +69,7 @@ case object DetupleInput extends NormalizingRule("Detuple In") { val (newIds, expr, tMap) = decompose(a) subProblem = subst(a -> expr, subProblem) - subPc = subst(a -> expr, subPc) + subPc = subPc map (subst(a -> expr, _)) subWs = subst(a -> expr, subWs) revMap += expr -> Variable(a) hints +:= Hint(expr) diff --git a/src/main/scala/leon/synthesis/rules/EquivalentInputs.scala b/src/main/scala/leon/synthesis/rules/EquivalentInputs.scala index 2ee67ad7f..90c4ecbb3 100644 --- a/src/main/scala/leon/synthesis/rules/EquivalentInputs.scala +++ b/src/main/scala/leon/synthesis/rules/EquivalentInputs.scala @@ -5,6 +5,8 @@ package synthesis package rules import leon.utils._ +import purescala.Path +import purescala.Common._ import purescala.Expressions._ import purescala.ExprOps._ import purescala.Extractors._ @@ -13,80 +15,95 @@ import purescala.Types.CaseClassType case object EquivalentInputs extends NormalizingRule("EquivalentInputs") { def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { - val TopLevelAnds(clauses) = p.pc - def discoverEquivalences(allClauses: Seq[Expr]): Seq[(Expr, Expr)] = { - val instanceOfs = allClauses.collect { - case ccio: IsInstanceOf => ccio - } + val simplifier = Simplifiers.bestEffort(hctx, hctx.program) _ - val clauses = allClauses.filterNot(instanceOfs.toSet) + var subst = Map.empty[Identifier, Expr] + var reverseSubst = Map.empty[Identifier, Expr] - val ccSubsts = for (IsInstanceOf(s, cct: CaseClassType) <- instanceOfs) yield { + var obsolete = Set.empty[Identifier] + var free = Set.empty[Identifier] - val fieldsVals = (for (f <- cct.classDef.fields) yield { - val id = f.id + def discoverEquivalences(p: Path): Path = { - clauses.collectFirst { - case Equals(e, CaseClassSelector(`cct`, `s`, `id`)) => e - case Equals(e, CaseClassSelector(`cct`, AsInstanceOf(`s`, `cct`), `id`)) => e - } + val vars = p.variables + val clauses = p.conditions + + val instanceOfs = clauses.collect { case IsInstanceOf(Variable(id), cct) if vars(id) => id -> cct }.toSet - }).flatten + val equivalences = (for ((sid, cct: CaseClassType) <- instanceOfs) yield { + val fieldVals = for (f <- cct.classDef.fields) yield { + val fid = f.id + + p.bindings.collectFirst { + case (id, CaseClassSelector(`cct`, Variable(`sid`), `fid`)) => id + case (id, CaseClassSelector(`cct`, AsInstanceOf(Variable(`sid`), `cct`), `fid`)) => id + } + } - if (fieldsVals.size == cct.fields.size) { - Some((s, CaseClass(cct, fieldsVals))) + if (fieldVals.forall(_.isDefined)) { + Some(sid -> CaseClass(cct, fieldVals.map(_.get.toVariable))) + } else if (fieldVals.exists(_.isDefined)) { + Some(sid -> CaseClass(cct, (cct.fields zip fieldVals).map { + case (_, Some(id)) => Variable(id) + case (vid, None) => Variable(vid.id.freshen) + })) } else { None } - } - - // Direct equivalences: - val directEqs = allClauses.collect { - case Equals(v1 @ Variable(a1), v2 @ Variable(a2)) if a1 != a2 => (v2, v1) - } - - ccSubsts.flatten ++ directEqs - } + }).flatten + val unbound = equivalences.flatMap(_._2.args.collect { case Variable(id) => id }) + obsolete ++= equivalences.map(_._1) + free ++= unbound - // We could discover one equivalence, which could allow us to discover - // other equivalences: We do a fixpoint with limit 5. - val substs = fixpoint({ (substs: Set[(Expr, Expr)]) => - val newClauses = substs.map{ case(e,v) => Equals(v, e) } // clauses are directed: foo = obj.f - substs ++ discoverEquivalences(clauses ++ newClauses) - }, 5)(Set()).toSeq + def replace(e: Expr) = simplifier(replaceFromIDs(equivalences.toMap, e)) + subst = subst.mapValues(replace) ++ equivalences + val reverse = equivalences.toMap.flatMap { case (id, CaseClass(cct, fields)) => + (cct.classDef.fields zip fields).map { case (vid, Variable(fieldId)) => + fieldId -> CaseClassSelector(cct, AsInstanceOf(Variable(id), cct), vid.id) + } + } - // We are replacing foo(a) with b. We inject postcondition(foo)(a, b). - val postsToInject = substs.collect { - case (FunctionInvocation(tfd, args), e) if tfd.hasPostcondition => - val Some(post) = tfd.postcondition + reverseSubst ++= reverse.mapValues(replaceFromIDs(reverseSubst, _)) - application(replaceFromIDs((tfd.params.map(_.id) zip args).toMap, post), Seq(e)) + (p -- unbound) map replace } - if (substs.nonEmpty) { - val simplifier = Simplifiers.bestEffort(hctx, hctx.program) _ - - val removedAs = substs.collect { case (Variable(from), _) => from }.toSet + // We could discover one equivalence, which could allow us to discover + // other equivalences: We do a fixpoint with limit 5. + val simplifiedPath = fixpoint({ (path: Path) => discoverEquivalences(path) }, 5)(p.pc) + + if (subst.nonEmpty) { + // XXX: must take place in this order!! obsolete & free is typically non-empty + val newAs = (p.as ++ free).distinct.filterNot(obsolete) + + val newBank = p.eb.map { ex => + val mapping = (p.as zip ex.ins).toMap + val newIns = newAs.map(a => mapping.getOrElse(a, replaceFromIDs(mapping, reverseSubst(a)))) + List(ex match { + case ioe @ InOutExample(ins, outs) => ioe.copy(ins = newIns) + case ie @ InExample(ins) => ie.copy(ins = newIns) + }) + } val sub = p.copy( - as = p.as filterNot removedAs, - ws = replaceSeq(substs, p.ws), - pc = simplifier(andJoin(replaceSeq(substs, p.pc) +: postsToInject)), - phi = simplifier(replaceSeq(substs, p.phi)), - eb = p.qeb.removeIns(removedAs) + as = newAs, + ws = replaceFromIDs(subst, p.ws), + pc = simplifiedPath, + phi = simplifier(replaceFromIDs(subst, p.phi)), + eb = newBank ) - val subst = replace( - substs.map{_.swap}.filter{ case (x,y) => formulaSize(x) > formulaSize(y) }.toMap, - _:Expr - ) - - val substString = substs.map { case (f, t) => f.asString+" -> "+t.asString } + val onSuccess = { + val reverse = subst.map(_.swap).mapValues(_.toVariable) + forwardMap(replace(reverse, _)) + } + + val substString = subst.map { case (f, t) => f.asString+" -> "+t.asString } - List(decomp(List(sub), forwardMap(subst), "Equivalent Inputs ("+substString.mkString(", ")+")")) + List(decomp(List(sub), onSuccess, "Equivalent Inputs ("+substString.mkString(", ")+")")) } else { Nil } diff --git a/src/main/scala/leon/synthesis/rules/GenericTypeEqualitySplit.scala b/src/main/scala/leon/synthesis/rules/GenericTypeEqualitySplit.scala index 8dce44d68..b9be6da6d 100644 --- a/src/main/scala/leon/synthesis/rules/GenericTypeEqualitySplit.scala +++ b/src/main/scala/leon/synthesis/rules/GenericTypeEqualitySplit.scala @@ -26,11 +26,12 @@ case object GenericTypeEqualitySplit extends Rule("Eq. Split") { case _ => Set() } - val TopLevelAnds(as) = and(p.pc, p.phi) - - val facts = as.flatMap(getFacts).toSet + val facts: Set[Set[Identifier]] = { + val TopLevelAnds(as) = andJoin(p.pc.conditions :+ p.phi) + as.toSet.flatMap(getFacts) + } - val candidates = p.as.combinations(2).collect { + val candidates = p.allAs.combinations(2).collect { case List(IsTyped(a1, TypeParameter(t1)), IsTyped(a2, TypeParameter(t2))) if t1 == t2 && !facts(Set(a1, a2)) => (a1, a2) @@ -42,12 +43,12 @@ case object GenericTypeEqualitySplit extends Rule("Eq. Split") { val v2 = Variable(a2) val subProblems = List( p.copy(as = p.as.diff(Seq(a1)), - pc = subst(a1 -> v2, p.pc), + pc = p.pc map (subst(a1 -> v2, _)), ws = subst(a1 -> v2, p.ws), phi = subst(a1 -> v2, p.phi), eb = p.qeb.filterIns(Equals(v1, v2)).removeIns(Set(a1))), - p.copy(pc = and(p.pc, not(Equals(v1, v2))), + p.copy(pc = p.pc withCond not(Equals(v1, v2)), eb = p.qeb.filterIns(not(Equals(v1, v2)))) ) diff --git a/src/main/scala/leon/synthesis/rules/IfSplit.scala b/src/main/scala/leon/synthesis/rules/IfSplit.scala index 9b97898b9..3d0c0d0ba 100644 --- a/src/main/scala/leon/synthesis/rules/IfSplit.scala +++ b/src/main/scala/leon/synthesis/rules/IfSplit.scala @@ -13,8 +13,8 @@ case object IfSplit extends Rule("If-Split") { def split(i: IfExpr, description: String): RuleInstantiation = { val subs = List( - Problem(p.as, p.ws, and(p.pc, i.cond), replace(Map(i -> i.thenn), p.phi), p.xs, p.qeb.filterIns(i.cond)), - Problem(p.as, p.ws, and(p.pc, not(i.cond)), replace(Map(i -> i.elze), p.phi), p.xs, p.qeb.filterIns(not(i.cond))) + Problem(p.as, p.ws, p.pc withCond i.cond, replace(Map(i -> i.thenn), p.phi), p.xs, p.qeb.filterIns(i.cond)), + Problem(p.as, p.ws, p.pc withCond not(i.cond), replace(Map(i -> i.elze), p.phi), p.xs, p.qeb.filterIns(not(i.cond))) ) val onSuccess: List[Solution] => Option[Solution] = { diff --git a/src/main/scala/leon/synthesis/rules/IndependentSplit.scala b/src/main/scala/leon/synthesis/rules/IndependentSplit.scala index fa8eceb3a..41d4c24b2 100644 --- a/src/main/scala/leon/synthesis/rules/IndependentSplit.scala +++ b/src/main/scala/leon/synthesis/rules/IndependentSplit.scala @@ -68,12 +68,12 @@ case object IndependentSplit extends NormalizingRule("IndependentSplit") { /**** Phase 2 ****/ - val TopLevelAnds(clauses) = and(newP.pc, newP.phi) + val TopLevelAnds(clauses) = andJoin(newP.pc.conditions :+ newP.phi) var independentClasses = Set[Set[Identifier]]() // We group connect variables together - for(c <- clauses) { + for (c <- clauses) { val vs = variablesOf(c) var newClasses = Set[Set[Identifier]]() diff --git a/src/main/scala/leon/synthesis/rules/InequalitySplit.scala b/src/main/scala/leon/synthesis/rules/InequalitySplit.scala index 54445a5d4..e4e35b226 100644 --- a/src/main/scala/leon/synthesis/rules/InequalitySplit.scala +++ b/src/main/scala/leon/synthesis/rules/InequalitySplit.scala @@ -28,8 +28,6 @@ case object InequalitySplit extends Rule("Ineq. Split.") { def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { - val TopLevelAnds(as) = and(p.pc, p.phi) - def getFacts(e: Expr): Set[Fact] = e match { case LessThan(a, b) => Set(LT(b,a), EQ(a,b)) case LessEquals(a, b) => Set(LT(b,a)) @@ -44,23 +42,26 @@ case object InequalitySplit extends Rule("Ineq. Split.") { case _ => Set() } - val facts = as flatMap getFacts + val facts: Set[Fact] = { + val TopLevelAnds(as) = andJoin(p.pc.conditions :+ p.phi) + as.toSet flatMap getFacts + } val candidates = - (p.as.map(_.toVariable).filter(_.getType == Int32Type) :+ IntLiteral(0)).combinations(2).toList ++ - (p.as.map(_.toVariable).filter(_.getType == IntegerType) :+ InfiniteIntegerLiteral(0)).combinations(2).toList + (p.allAs.map(_.toVariable).filter(_.getType == Int32Type) :+ IntLiteral(0)).combinations(2).toList ++ + (p.allAs.map(_.toVariable).filter(_.getType == IntegerType) :+ InfiniteIntegerLiteral(0)).combinations(2).toList candidates.flatMap { case List(v1, v2) => val lt = if (!facts.contains(LT(v1, v2))) { val pc = LessThan(v1, v2) - Some(pc, p.copy(pc = and(p.pc, pc), eb = p.qeb.filterIns(pc))) + Some(pc, p.copy(pc = p.pc withCond pc, eb = p.qeb.filterIns(pc))) } else None val gt = if (!facts.contains(LT(v2, v1))) { val pc = GreaterThan(v1, v2) - Some(pc, p.copy(pc = and(p.pc, pc), eb = p.qeb.filterIns(pc))) + Some(pc, p.copy(pc = p.pc withCond pc, eb = p.qeb.filterIns(pc))) } else None val eq = if (!facts.contains(EQ(v1, v2)) && !facts.contains(EQ(v2,v1))) { @@ -72,7 +73,7 @@ case object InequalitySplit extends Rule("Ineq. Split.") { } val newP = p.copy( as = p.as.diff(Seq(a1)), - pc = subst(a1 -> v2, p.pc), + pc = p.pc map (subst(a1 -> v2, _)), ws = subst(a1 -> v2, p.ws), phi = subst(a1 -> v2, p.phi), eb = p.qeb.filterIns(Equals(v1, v2)).removeIns(Set(a1)) @@ -86,9 +87,7 @@ case object InequalitySplit extends Rule("Ineq. Split.") { else { val onSuccess: List[Solution] => Option[Solution] = { sols => - val pre = orJoin(pcs.zip(sols).map { case (pc, sol) => - and(pc, sol.pre) - }) + val pre = orJoin(pcs.zip(sols).map { case (pc, sol) => and(pc, sol.pre) }) val term = pcs.zip(sols) match { case Seq((pc1, s1), (_, s2)) => diff --git a/src/main/scala/leon/synthesis/rules/InputSplit.scala b/src/main/scala/leon/synthesis/rules/InputSplit.scala index 293c9767d..a14078ec8 100644 --- a/src/main/scala/leon/synthesis/rules/InputSplit.scala +++ b/src/main/scala/leon/synthesis/rules/InputSplit.scala @@ -4,6 +4,7 @@ package leon package synthesis package rules +import purescala.Path import purescala.Expressions._ import purescala.ExprOps._ import purescala.Constructors._ @@ -11,16 +12,23 @@ import purescala.Types._ case object InputSplit extends Rule("In. Split") { def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { - p.as.filter(_.getType == BooleanType).flatMap { a => + p.allAs.filter(_.getType == BooleanType).flatMap { a => def getProblem(v: Boolean): Problem = { def replaceA(e: Expr) = replaceFromIDs(Map(a -> BooleanLiteral(v)), e) val tests = QualifiedExamplesBank(p.as, p.xs, p.qeb.filterIns(m => m(a) == BooleanLiteral(v))) + + val newPc: Path = { + val withoutA = p.pc -- Set(a) map replaceA + withoutA withConds (p.pc.bindings.find(_._1 == a).map { case (id, res) => + if (v) res else not(res) + }) + } p.copy( as = p.as.filterNot(_ == a), ws = replaceA(p.ws), - pc = replaceA(p.pc), + pc = newPc, phi = replaceA(p.phi), eb = tests.removeIns(Set(a)) ) diff --git a/src/main/scala/leon/synthesis/rules/IntroduceRecCalls.scala b/src/main/scala/leon/synthesis/rules/IntroduceRecCalls.scala index aca95e883..4463e6818 100644 --- a/src/main/scala/leon/synthesis/rules/IntroduceRecCalls.scala +++ b/src/main/scala/leon/synthesis/rules/IntroduceRecCalls.scala @@ -5,6 +5,7 @@ package synthesis package rules import evaluators.DefaultEvaluator +import purescala.Path import purescala.Definitions.Program import purescala.Extractors.TopLevelAnds import purescala.Expressions._ @@ -25,8 +26,7 @@ case object IntroduceRecCalls extends NormalizingRule("Introduce rec. calls") { } def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { - val TopLevelAnds(pcs) = p.pc - val existingCalls = pcs.collect { case Equals(_, fi: FunctionInvocation) => fi }.toSet + val existingCalls = p.pc.bindings.collect { case (_, fi: FunctionInvocation) => fi }.toSet val calls = terminatingCalls(hctx.program, p.ws, p.pc, None, false) .map(_._1).distinct.filterNot(existingCalls) @@ -35,22 +35,23 @@ case object IntroduceRecCalls extends NormalizingRule("Introduce rec. calls") { val specifyCalls = hctx.findOptionOrDefault(SynthesisPhase.optSpecifyRecCalls) - val (recs, posts) = calls.map { newCall => + val recs = calls.map { newCall => val rec = FreshIdentifier("rec", newCall.getType, alwaysShowUniqueID = true) // Assume the postcondition of recursive call - val post = if (specifyCalls) { - Equals(rec.toVariable, newCall) + val (bound, path) = if (specifyCalls) { + (true, Path.empty withBinding (rec -> newCall)) } else { - application( + (false, Path(application( newCall.tfd.withParamSubst(newCall.args, newCall.tfd.postOrTrue), Seq(rec.toVariable) - ) + ))) } - (rec, post) - }.unzip - val onSuccess = forwardMap(letTuple(recs, tupleWrap(calls), _)) + (rec, bound, path) + } + + val onSuccess = forwardMap(letTuple(recs.map(_._1), tupleWrap(calls), _)) List(new RuleInstantiation(s"Introduce recursive calls ${calls mkString ", "}", SolutionBuilderDecomp(List(p.outType), onSuccess)) { @@ -82,8 +83,8 @@ case object IntroduceRecCalls extends NormalizingRule("Introduce rec. calls") { val TopLevelAnds(ws) = p.ws try { val newProblem = p.copy( - as = p.as ++ recs, - pc = andJoin(p.pc +: posts), + as = p.as ++ recs.collect { case (r, false, _) => r }, + pc = recs.map(_._3).foldLeft(p.pc)(_ merge _), ws = andJoin(ws ++ newWs), eb = p.eb.map(mapExample) ) diff --git a/src/main/scala/leon/synthesis/rules/OptimisticGround.scala b/src/main/scala/leon/synthesis/rules/OptimisticGround.scala index 6a688b120..ce05a9b3c 100644 --- a/src/main/scala/leon/synthesis/rules/OptimisticGround.scala +++ b/src/main/scala/leon/synthesis/rules/OptimisticGround.scala @@ -31,8 +31,8 @@ case object OptimisticGround extends Rule("Optimistic Ground") { var predicates: Seq[Expr] = Seq() while (result.isEmpty && i < maxTries && continue) { - val phi = andJoin(p.pc +: p.phi +: predicates) - val notPhi = andJoin(p.pc +: not(p.phi) +: predicates) + val phi = p.pc and andJoin(p.phi +: predicates) + val notPhi = p.pc and andJoin(not(p.phi) +: predicates) //println("SOLVING " + phi + " ...") solver.solveSAT(phi) match { case (Some(true), satModel) => diff --git a/src/main/scala/leon/synthesis/rules/StringRender.scala b/src/main/scala/leon/synthesis/rules/StringRender.scala index b86894867..57b56f24a 100644 --- a/src/main/scala/leon/synthesis/rules/StringRender.scala +++ b/src/main/scala/leon/synthesis/rules/StringRender.scala @@ -229,7 +229,7 @@ case object StringRender extends Rule("StringRender") { val term = ExprOps.simplifyString(ExprOps.replaceFromIDs(initMap ++ assignment.mapValues(StringLiteral), singleTemplate)) val (finalTerm, finalDefs) = makeFunctionsUnique(term, fds.toSet) - Solution(pre=p.pc, defs=finalDefs, term=finalTerm) + Solution(BooleanLiteral(true), finalDefs, finalTerm) }) } } diff --git a/src/main/scala/leon/synthesis/rules/UnusedInput.scala b/src/main/scala/leon/synthesis/rules/UnusedInput.scala index a72dfb633..81e03a064 100644 --- a/src/main/scala/leon/synthesis/rules/UnusedInput.scala +++ b/src/main/scala/leon/synthesis/rules/UnusedInput.scala @@ -9,7 +9,7 @@ import purescala.TypeOps._ case object UnusedInput extends NormalizingRule("UnusedInput") { def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { - val unused = (p.as.toSet -- variablesOf(p.phi) -- variablesOf(p.pc) -- variablesOf(p.ws)).filter { a => + val unused = (p.as.toSet -- variablesOf(p.phi) -- p.pc.variables -- variablesOf(p.ws)).filter { a => !isParametricType(a.getType) } diff --git a/src/main/scala/leon/synthesis/rules/unused/ADTInduction.scala b/src/main/scala/leon/synthesis/rules/unused/ADTInduction.scala index ab2720fb0..0c1aa15e4 100644 --- a/src/main/scala/leon/synthesis/rules/unused/ADTInduction.scala +++ b/src/main/scala/leon/synthesis/rules/unused/ADTInduction.scala @@ -4,6 +4,7 @@ package leon package synthesis package rules.unused +import purescala.Path import purescala.Common._ import purescala.Expressions._ import purescala.Extractors._ @@ -56,7 +57,7 @@ case object ADTInduction extends Rule("ADT Induction") { // Transformation of conditions, variables and axioms to use the inner variables of the inductive function. val innerPhi = substAll(substMap, p.phi) - val innerPC = substAll(substMap, p.pc) + val innerPC = p.pc map (substAll(substMap, _)) val innerWS = substAll(substMap, p.ws) val subProblemsInfo = for (cct <- ct.knownCCDescendants) yield { @@ -80,12 +81,12 @@ case object ADTInduction extends Rule("ADT Induction") { }).flatten val subPhi = substAll(Map(inductOn -> CaseClass(cct, newIds.map(Variable))), innerPhi) - val subPC = substAll(Map(inductOn -> CaseClass(cct, newIds.map(Variable))), innerPC) + val subPC = innerPC map (substAll(Map(inductOn -> CaseClass(cct, newIds.map(Variable))), _)) val subWS = substAll(Map(inductOn -> CaseClass(cct, newIds.map(Variable))), innerWS) val subPre = IsInstanceOf(Variable(origId), cct) - val subProblem = Problem(inputs ::: residualArgs, subWS, andJoin(subPC :: postFs), subPhi, p.xs) + val subProblem = Problem(inputs ::: residualArgs, subWS, subPC withConds postFs, subPhi, p.xs) (subProblem, subPre, cct, newIds, recCalls) } @@ -111,17 +112,18 @@ case object ADTInduction extends Rule("ADT Induction") { // might only have to enforce it on solutions of base cases. None } else { - val funPre = substAll(substMap, and(p.pc, orJoin(globalPre))) + val outerPre = orJoin(globalPre) + val funPre = p.pc withCond outerPre map (substAll(substMap, _)) val funPost = substAll(substMap, p.phi) val idPost = FreshIdentifier("res", resType) - newFun.precondition = Some(funPre) + newFun.precondition = funPre newFun.postcondition = Some(Lambda(Seq(ValDef(idPost)), letTuple(p.xs.toSeq, Variable(idPost), funPost))) newFun.body = Some(matchExpr(Variable(inductOn), cases)) - Some(Solution(orJoin(globalPre), - sols.flatMap(_.defs).toSet+newFun, + Some(Solution(outerPre, + sols.flatMap(_.defs).toSet + newFun, FunctionInvocation(newFun.typed, Variable(origId) :: oas.map(Variable)), sols.forall(_.isTrusted) )) diff --git a/src/main/scala/leon/synthesis/rules/unused/ADTLongInduction.scala b/src/main/scala/leon/synthesis/rules/unused/ADTLongInduction.scala index 4ffd1a8c2..e105ef9a7 100644 --- a/src/main/scala/leon/synthesis/rules/unused/ADTLongInduction.scala +++ b/src/main/scala/leon/synthesis/rules/unused/ADTLongInduction.scala @@ -4,6 +4,7 @@ package leon package synthesis package rules.unused +import purescala.Path import purescala.Common._ import purescala.Expressions._ import purescala.Extractors._ @@ -44,10 +45,10 @@ case object ADTLongInduction extends Rule("ADT Long Induction") { case class InductCase(ids: List[Identifier], calls: List[Identifier], pattern: Pattern, - outerPC: Expr, + outerPC: Path, trMap: Map[Identifier, Expr]) - val init = InductCase(inductOn :: residualArgs, List(), WildcardPattern(Some(inductOn)), BooleanLiteral(true), Map(inductOn -> Variable(inductOn))) + val init = InductCase(inductOn :: residualArgs, List(), WildcardPattern(Some(inductOn)), Path.empty, Map(inductOn -> Variable(inductOn))) def isRec(id: Identifier) = id.getType == origId.getType @@ -82,7 +83,7 @@ case object ADTLongInduction extends Rule("ADT Long Induction") { val newMap = trMap.mapValues(v => substAll(Map(id -> CaseClass(cct, subIds.map(Variable))), v)) - InductCase(newIds, newCalls, newPattern, and(pc, IsInstanceOf(Variable(id), cct)), newMap) + InductCase(newIds, newCalls, newPattern, pc withCond IsInstanceOf(Variable(id), cct), newMap) } }).flatten } else { @@ -93,7 +94,7 @@ case object ADTLongInduction extends Rule("ADT Long Induction") { val cases = unroll(init).flatMap(unroll) val innerPhi = substAll(substMap, p.phi) - val innerPC = substAll(substMap, p.pc) + val innerPC = p.pc map (substAll(substMap, _)) val innerWS = substAll(substMap, p.ws) val subProblemsInfo = for (c <- cases) yield { @@ -103,7 +104,7 @@ case object ADTLongInduction extends Rule("ADT Long Induction") { var recCalls = Map[List[Identifier], List[Expr]]() - val subPC = substAll(trMap, innerPC) + val subPC = innerPC map (substAll(trMap, _)) val subWS = substAll(trMap, innerWS) val subPhi = substAll(trMap, innerPhi) @@ -119,7 +120,7 @@ case object ADTLongInduction extends Rule("ADT Long Induction") { recCalls += postXs -> (Variable(cid) +: residualArgs.map(id => Variable(id))) } - val subProblem = Problem(c.ids ::: postXss, subWS, andJoin(subPC :: postFs), subPhi, p.xs) + val subProblem = Problem(c.ids ::: postXss, subWS, subPC withConds postFs, subPhi, p.xs) //println(subProblem) //println(recCalls) (subProblem, pat, recCalls, pc) @@ -127,12 +128,12 @@ case object ADTLongInduction extends Rule("ADT Long Induction") { val onSuccess: List[Solution] => Option[Solution] = { case sols => - var globalPre = List[Expr]() + var globalPre = List.empty[Path] val newFun = new FunDef(FreshIdentifier("rec", alwaysShowUniqueID = true), Nil, ValDef(inductOn) +: residualArgDefs, resType) val cases = for ((sol, (problem, pat, calls, pc)) <- sols zip subProblemsInfo) yield { - globalPre ::= and(pc, sol.pre) + globalPre ::= (pc withCond sol.pre) SimpleCase(pat, calls.foldLeft(sol.term){ case (t, (binders, callargs)) => letTuple(binders, FunctionInvocation(newFun.typed, callargs), t) }) } @@ -144,18 +145,18 @@ case object ADTLongInduction extends Rule("ADT Long Induction") { // might only have to enforce it on solutions of base cases. None } else { - val funPre = substAll(substMap, and(p.pc, orJoin(globalPre))) + val outerPre = orJoin(globalPre.map(_.toClause)) + val funPre = p.pc withCond outerPre map (substAll(substMap, _)) val funPost = substAll(substMap, p.phi) val idPost = FreshIdentifier("res", resType) - val outerPre = orJoin(globalPre) - newFun.precondition = Some(funPre) + newFun.precondition = funPre newFun.postcondition = Some(Lambda(Seq(ValDef(idPost)), letTuple(p.xs.toSeq, Variable(idPost), funPost))) newFun.body = Some(matchExpr(Variable(inductOn), cases)) - Some(Solution(orJoin(globalPre), - sols.flatMap(_.defs).toSet+newFun, + Some(Solution(outerPre, + sols.flatMap(_.defs).toSet + newFun, FunctionInvocation(newFun.typed, Variable(origId) :: oas.map(Variable)), sols.forall(_.isTrusted) )) diff --git a/src/main/scala/leon/synthesis/rules/unused/IntInduction.scala b/src/main/scala/leon/synthesis/rules/unused/IntInduction.scala index 6ef209fcb..67d5d4948 100644 --- a/src/main/scala/leon/synthesis/rules/unused/IntInduction.scala +++ b/src/main/scala/leon/synthesis/rules/unused/IntInduction.scala @@ -24,14 +24,14 @@ case object IntInduction extends Rule("Int Induction") { val postXsMap = (p.xs zip postXs).toMap.mapValues(Variable) val newPhi = subst(origId -> Variable(inductOn), p.phi) - val newPc = subst(origId -> Variable(inductOn), p.pc) + val newPc = p.pc map (subst(origId -> Variable(inductOn), _)) val newWs = subst(origId -> Variable(inductOn), p.ws) val postCondGT = substAll(postXsMap + (origId -> Minus(Variable(inductOn), InfiniteIntegerLiteral(1))), p.phi) val postCondLT = substAll(postXsMap + (origId -> Plus(Variable(inductOn), InfiniteIntegerLiteral(1))), p.phi) - val subBase = Problem(List(), subst(origId -> InfiniteIntegerLiteral(0), p.ws), subst(origId -> InfiniteIntegerLiteral(0), p.pc), subst(origId -> InfiniteIntegerLiteral(0), p.phi), p.xs) - val subGT = Problem(inductOn :: postXs, newWs, and(GreaterThan(Variable(inductOn), InfiniteIntegerLiteral(0)), postCondGT, newPc), newPhi, p.xs) - val subLT = Problem(inductOn :: postXs, newWs, and(LessThan(Variable(inductOn), InfiniteIntegerLiteral(0)), postCondLT, newPc), newPhi, p.xs) + val subBase = Problem(List(), subst(origId -> InfiniteIntegerLiteral(0), p.ws), p.pc map (subst(origId -> InfiniteIntegerLiteral(0), _)), subst(origId -> InfiniteIntegerLiteral(0), p.phi), p.xs) + val subGT = Problem(inductOn :: postXs, newWs, newPc withCond and(GreaterThan(Variable(inductOn), InfiniteIntegerLiteral(0)), postCondGT), newPhi, p.xs) + val subLT = Problem(inductOn :: postXs, newWs, newPc withCond and(LessThan(Variable(inductOn), InfiniteIntegerLiteral(0)), postCondLT), newPhi, p.xs) val onSuccess: List[Solution] => Option[Solution] = { case List(base, gt, lt) => diff --git a/src/main/scala/leon/synthesis/rules/unused/IntegerEquation.scala b/src/main/scala/leon/synthesis/rules/unused/IntegerEquation.scala index fdb59b00d..a80963625 100644 --- a/src/main/scala/leon/synthesis/rules/unused/IntegerEquation.scala +++ b/src/main/scala/leon/synthesis/rules/unused/IntegerEquation.scala @@ -56,7 +56,7 @@ case object IntegerEquation extends Rule("Integer Equation") { if(normalizedEq.size == 1) { val eqPre = Equals(normalizedEq.head, IntLiteral(0)) - val newProblem = Problem(problem.as, problem.ws, and(eqPre, problem.pc), andJoin(allOthers), problem.xs) + val newProblem = Problem(problem.as, problem.ws, problem.pc withCond eqPre, andJoin(allOthers), problem.xs) val onSuccess: List[Solution] => Option[Solution] = { case List(s @ Solution(pre, defs, term)) => @@ -80,7 +80,7 @@ case object IntegerEquation extends Rule("Integer Equation") { var freshInputVariables: List[Identifier] = Nil var equivalenceConstraints: Map[Expr, Expr] = Map() val freshFormula = simplePreTransform({ - case d@Division(_, _) => { + case d @ Division(_, _) => { assert(variablesOf(d).intersect(problem.xs.toSet).isEmpty) val newVar = FreshIdentifier("d", Int32Type, true) freshInputVariables ::= newVar @@ -93,7 +93,7 @@ case object IntegerEquation extends Rule("Integer Equation") { val ys: List[Identifier] = problem.xs.filterNot(neqxs.contains(_)) val subproblemxs: List[Identifier] = freshxs ++ ys - val newProblem = Problem(problem.as ++ freshInputVariables, problem.ws, and(eqPre, problem.pc), freshFormula, subproblemxs) + val newProblem = Problem(problem.as ++ freshInputVariables, problem.ws, problem.pc withCond eqPre, freshFormula, subproblemxs) val onSuccess: List[Solution] => Option[Solution] = { case List(s @ Solution(pre, defs, term)) => { @@ -113,7 +113,7 @@ case object IntegerEquation extends Rule("Integer Equation") { if (subproblemxs.isEmpty) { // we directly solve - List(solve(onSuccess(List(Solution(and(eqPre, problem.pc), Set(), UnitLiteral()))).get)) + List(solve(onSuccess(List(Solution((problem.pc withCond eqPre).toClause, Set(), UnitLiteral()))).get)) } else { List(decomp(List(newProblem), onSuccess, this.name)) } diff --git a/src/main/scala/leon/synthesis/rules/unused/TEGISLike.scala b/src/main/scala/leon/synthesis/rules/unused/TEGISLike.scala index 16eb701f9..ee49610d0 100644 --- a/src/main/scala/leon/synthesis/rules/unused/TEGISLike.scala +++ b/src/main/scala/leon/synthesis/rules/unused/TEGISLike.scala @@ -36,15 +36,15 @@ abstract class TEGISLike(name: String) extends Rule(name) { val params = getParams(hctx, p) val grammar = params.grammar - val nTests = if (p.pc == BooleanLiteral(true)) 50 else 20 + val nTests = if (p.pc.isEmpty) 50 else 20 val useVanuatoo = hctx.settings.cegisUseVanuatoo val inputGenerator: Iterator[Seq[Expr]] = if (useVanuatoo) { - new VanuatooDataGen(hctx, hctx.program).generateFor(p.as, p.pc, nTests, 3000) + new VanuatooDataGen(hctx, hctx.program).generateFor(p.as, p.pc.toClause, nTests, 3000) } else { val evaluator = new DualEvaluator(hctx, hctx.program, CodeGenParams.default) - new GrammarDataGen(evaluator, ValueGrammar).generateFor(p.as, p.pc, nTests, 1000) + new GrammarDataGen(evaluator, ValueGrammar).generateFor(p.as, p.pc.toClause, nTests, 1000) } val gi = new GrowableIterable[Seq[Expr]](p.eb.examples.map(_.ins).distinct, inputGenerator) diff --git a/src/main/scala/leon/synthesis/utils/Helpers.scala b/src/main/scala/leon/synthesis/utils/Helpers.scala index 64a8f4d9d..0a76050a1 100644 --- a/src/main/scala/leon/synthesis/utils/Helpers.scala +++ b/src/main/scala/leon/synthesis/utils/Helpers.scala @@ -4,6 +4,7 @@ package leon package synthesis package utils +import purescala.Path import purescala.Definitions._ import purescala.Types._ import purescala.Extractors._ @@ -46,10 +47,9 @@ object Helpers { * @return A list of pairs (safe function call, holes), * where holes stand for the rest of the arguments of the function. */ - def terminatingCalls(prog: Program, ws: Expr, pc: Expr, tpe: Option[TypeTree], introduceHoles: Boolean): List[(FunctionInvocation, Option[Set[Identifier]])] = { + def terminatingCalls(prog: Program, ws: Expr, pc: Path, tpe: Option[TypeTree], introduceHoles: Boolean): List[(FunctionInvocation, Option[Set[Identifier]])] = { val TopLevelAnds(wss) = ws - val TopLevelAnds(clauses) = pc val gs: List[Terminating] = wss.toList.collect { case t : Terminating => t @@ -60,12 +60,14 @@ object Helpers { case (r: Variable) if leastUpperBound(r.getType, v.getType).isDefined => Some(r -> v) case _ => None } - + val z = InfiniteIntegerLiteral(0) val one = InfiniteIntegerLiteral(1) - val knownSmallers = clauses.collect { - case Equals(v: Variable, s@CaseClassSelector(cct, r, _)) => subExprsOf(s, v) - case Equals(s@CaseClassSelector(cct, r, _), v: Variable) => subExprsOf(s, v) + val knownSmallers = (pc.bindings.flatMap { + // @nv: used to check both Equals(id, selector) and Equals(selector, id) + case (id, s @ CaseClassSelector(cct, r, _)) => subExprsOf(s, id.toVariable) + case _ => None + } ++ pc.conditions.flatMap { case GreaterThan(v: Variable, `z`) => Some(v -> Minus(v, one)) case LessThan(`z`, v: Variable) => @@ -74,7 +76,8 @@ object Helpers { Some(v -> Plus(v, one)) case GreaterThan(`z`, v: Variable) => Some(v -> Plus(v, one)) - }.flatten.groupBy(_._1).mapValues(v => v.map(_._2)) + case _ => None + }).groupBy(_._1).mapValues(v => v.map(_._2)) def argsSmaller(e: Expr, tpe: TypeTree): Seq[Expr] = e match { case CaseClass(cct, args) => diff --git a/src/main/scala/leon/termination/ChainBuilder.scala b/src/main/scala/leon/termination/ChainBuilder.scala index daab94567..a4b90940c 100644 --- a/src/main/scala/leon/termination/ChainBuilder.scala +++ b/src/main/scala/leon/termination/ChainBuilder.scala @@ -3,11 +3,12 @@ package leon package termination -import leon.purescala.Definitions._ -import leon.purescala.Expressions._ -import leon.purescala.ExprOps._ -import leon.purescala.Constructors._ -import leon.purescala.Common._ +import purescala.Path +import purescala.Definitions._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Constructors._ +import purescala.Common._ import scala.collection.mutable.{Map => MutableMap} @@ -48,13 +49,13 @@ final case class Chain(relations: List[Relation]) { lazy val finalParams : Seq[ValDef] = inlining.last._1 - def loop(initialArgs: Seq[Identifier] = Seq.empty, finalArgs: Seq[Identifier] = Seq.empty): Seq[Expr] = { - def rec(relations: List[Relation], funDef: TypedFunDef, subst: Map[Identifier, Identifier]): Seq[Expr] = { + def loop(initialArgs: Seq[Identifier] = Seq.empty, finalArgs: Seq[Identifier] = Seq.empty): Path = { + def rec(relations: List[Relation], funDef: TypedFunDef, subst: Map[Identifier, Identifier]): Path = { val Relation(_, path, FunctionInvocation(fitfd, args), _) = relations.head val tfd = TypedFunDef(fitfd.fd, fitfd.tps.map(funDef.translated)) val translate : Expr => Expr = { - val free : Set[Identifier] = path.flatMap(variablesOf).toSet -- funDef.fd.params.map(_.id) + val free : Set[Identifier] = path.variables -- funDef.fd.params.map(_.id) val freeMapping : Map[Identifier,Identifier] = free.map(id => id -> { FreshIdentifier(id.name, funDef.translated(id.getType), true).copiedFrom(id) }).toMap @@ -65,16 +66,14 @@ final case class Chain(relations: List[Relation]) { lazy val newArgs = args.map(translate) - path.map(translate) ++ (relations.tail match { + path.map(translate) merge (relations.tail match { case Nil => - (finalArgs zip newArgs).map { case (finalArg, newArg) => Equals(finalArg.toVariable, newArg) } + Path.empty withBindings (finalArgs zip newArgs) case xs => val params = tfd.params.map(_.id) val freshParams = tfd.params.map(arg => FreshIdentifier(arg.id.name, arg.getType, true)) - val bindings = (freshParams.map(_.toVariable) zip newArgs).map(p => Equals(p._1, p._2)) - bindings ++ rec(xs, tfd, (params zip freshParams).toMap) + Path.empty withBindings (freshParams zip newArgs) merge rec(xs, tfd, (params zip freshParams).toMap) }) - } rec(relations, funDef.typed, (funDef.params.map(_.id) zip initialArgs).toMap) @@ -132,8 +131,8 @@ trait ChainBuilder extends RelationBuilder { self: Strengthener with RelationCom val constraints = relations.map(relation => relationConstraints.getOrElse(relation, { val Relation(funDef, path, FunctionInvocation(_, args), _) = relation val args0 = funDef.params.map(_.toVariable) - val constraint = if (solver.definitiveALL(implies(andJoin(path), self.softDecreasing(args0, args)))) { - if (solver.definitiveALL(implies(andJoin(path), self.sizeDecreasing(args0, args)))) { + val constraint = if (solver.definitiveALL(path implies self.softDecreasing(args0, args))) { + if (solver.definitiveALL(path implies self.sizeDecreasing(args0, args))) { StrongDecreasing } else { WeakDecreasing diff --git a/src/main/scala/leon/termination/ChainComparator.scala b/src/main/scala/leon/termination/ChainComparator.scala index a2c18e171..0a4df39fd 100644 --- a/src/main/scala/leon/termination/ChainComparator.scala +++ b/src/main/scala/leon/termination/ChainComparator.scala @@ -3,6 +3,7 @@ package leon package termination +import purescala.Path import purescala.Expressions._ import purescala.ExprOps._ import purescala.Types._ @@ -62,9 +63,9 @@ trait ChainComparator { self : StructuralSize => rec(tpe) } - def structuralDecreasing(e1: Expr, e2s: Seq[(Seq[Expr], Expr)]) : Seq[Expr] = flatTypesPowerset(e1.getType).toSeq.map { + def structuralDecreasing(e1: Expr, e2s: Seq[(Path, Expr)]): Seq[Expr] = flatTypesPowerset(e1.getType).toSeq.map { recons => andJoin(e2s.map { case (path, e2) => - implies(andJoin(path), GreaterThan(self.size(recons(e1)), self.size(recons(e2)))) + path implies GreaterThan(self.size(recons(e1)), self.size(recons(e2))) }) } @@ -192,19 +193,19 @@ trait ChainComparator { self : StructuralSize => } } - def numericConverging(e1: Expr, e2s: Seq[(Seq[Expr], Expr)], cluster: Set[Chain]) : Seq[Expr] = flatType(e1.getType).toSeq.flatMap { + def numericConverging(e1: Expr, e2s: Seq[(Path, Expr)], cluster: Set[Chain]) : Seq[Expr] = flatType(e1.getType).toSeq.flatMap { recons => recons(e1) match { case e if e.getType == IntegerType => val endpoint = numericEndpoint(e, cluster) val uppers = if (endpoint == UpperBoundEndpoint || endpoint == AnyEndpoint) { - Some(andJoin(e2s map { case (path, e2) => implies(andJoin(path), GreaterThan(e, recons(e2))) })) + Some(andJoin(e2s map { case (path, e2) => path implies GreaterThan(e, recons(e2)) })) } else { None } val lowers = if (endpoint == LowerBoundEndpoint || endpoint == AnyEndpoint) { - Some(andJoin(e2s map { case (path, e2) => implies(andJoin(path), LessThan(e, recons(e2))) })) + Some(andJoin(e2s map { case (path, e2) => path implies LessThan(e, recons(e2)) })) } else { None } diff --git a/src/main/scala/leon/termination/ChainProcessor.scala b/src/main/scala/leon/termination/ChainProcessor.scala index ef6fde2e2..1fcf89fbc 100644 --- a/src/main/scala/leon/termination/ChainProcessor.scala +++ b/src/main/scala/leon/termination/ChainProcessor.scala @@ -9,8 +9,8 @@ import purescala.Definitions._ import purescala.Constructors._ class ChainProcessor( - val checker: TerminationChecker, - val modules: ChainBuilder with ChainComparator with Strengthener with StructuralSize + val checker: TerminationChecker, + val modules: ChainBuilder with ChainComparator with Strengthener with StructuralSize ) extends Processor with Solvable { val name: String = "Chain Processor" @@ -59,7 +59,7 @@ class ChainProcessor( Some(problem.funDefs map Cleared) else { val maybeReentrant = chains.flatMap(c1 => chains.flatMap(c2 => c1 compose c2)).exists { - chain => maybeSAT(andJoin(chain.loop())) + chain => maybeSAT(chain.loop().toClause) } if (!maybeReentrant) diff --git a/src/main/scala/leon/termination/LoopProcessor.scala b/src/main/scala/leon/termination/LoopProcessor.scala index e71c52821..643c62b3a 100644 --- a/src/main/scala/leon/termination/LoopProcessor.scala +++ b/src/main/scala/leon/termination/LoopProcessor.scala @@ -34,7 +34,7 @@ class LoopProcessor(val checker: TerminationChecker, val modules: ChainBuilder w val srcTuple = tupleWrap(chain.funDef.params.map(_.toVariable)) val resTuple = tupleWrap(freshParams.map(_.toVariable)) - definitiveSATwithModel(andJoin(path :+ Equals(srcTuple, resTuple))) match { + definitiveSATwithModel(path and equality(srcTuple, resTuple)) match { case Some(model) => val args = chain.funDef.params.map(arg => model(arg.id)) val res = if (chain.relations.exists(_.inLambda)) MaybeBroken(chain.funDef, args) else Broken(chain.funDef, args) diff --git a/src/main/scala/leon/termination/RecursionProcessor.scala b/src/main/scala/leon/termination/RecursionProcessor.scala index d8d31e2a3..622723e2b 100644 --- a/src/main/scala/leon/termination/RecursionProcessor.scala +++ b/src/main/scala/leon/termination/RecursionProcessor.scala @@ -34,8 +34,8 @@ class RecursionProcessor(val checker: TerminationChecker, val rb: RelationBuilde recursive.forall({ case Relation(_, path, FunctionInvocation(_, args), _) => args(index) match { // handle case class deconstruction in match expression! - case Variable(id) => path.reverse.exists { - case Equals(Variable(vid), ccs) if vid == id => isSubtreeOf(ccs, arg.id) + case Variable(id) => path.bindings.exists { + case (vid, ccs) if vid == id => isSubtreeOf(ccs, arg.id) case _ => false } case expr => isSubtreeOf(expr, arg.id) diff --git a/src/main/scala/leon/termination/RelationBuilder.scala b/src/main/scala/leon/termination/RelationBuilder.scala index 127b64ec0..5feef247e 100644 --- a/src/main/scala/leon/termination/RelationBuilder.scala +++ b/src/main/scala/leon/termination/RelationBuilder.scala @@ -3,13 +3,14 @@ package leon package termination +import purescala._ import purescala.Expressions._ import purescala.ExprOps._ import purescala.Definitions._ import scala.collection.mutable.{Map => MutableMap} -final case class Relation(funDef: FunDef, path: Seq[Expr], call: FunctionInvocation, inLambda: Boolean) { +final case class Relation(funDef: FunDef, path: Path, call: FunctionInvocation, inLambda: Boolean) { override def toString : String = "Relation(" + funDef.id + "," + path + ", " + call.tfd.id + call.args.mkString("(",",",")") + "," + inLambda + ")" } @@ -32,7 +33,7 @@ trait RelationBuilder { self: Strengthener => val collector = new CollectorWithPaths[Relation] { var inLambda: Boolean = false - override def rec(e: Expr, path: Seq[Expr]): Expr = e match { + override def rec(e: Expr, path: Path): Expr = e match { case l : Lambda => val old = inLambda inLambda = true @@ -43,21 +44,17 @@ trait RelationBuilder { self: Strengthener => super.rec(e, path) } - def collect(e: Expr, path: Seq[Expr]): Option[Relation] = e match { + def collect(e: Expr, path: Path): Option[Relation] = e match { case fi @ FunctionInvocation(f, args) if checker.functions(f.fd) => - val flatPath = path flatMap { - case And(es) => es - case expr => Seq(expr) - } - Some(Relation(funDef, flatPath, fi, inLambda)) + Some(Relation(funDef, path, fi, inLambda)) case _ => None } - override def walk(e: Expr, path: Seq[Expr]) = e match { + override def walk(e: Expr, path: Path) = e match { case FunctionInvocation(tfd, args) => val funDef = tfd.fd Some(FunctionInvocation(tfd, (funDef.params.map(_.id) zip args) map { case (id, arg) => - rec(arg, register(self.applicationConstraint(funDef, id, arg, args), path)) + rec(arg, path withCond self.applicationConstraint(funDef, id, arg, args)) })) case _ => None } diff --git a/src/main/scala/leon/termination/RelationProcessor.scala b/src/main/scala/leon/termination/RelationProcessor.scala index 1044e1be2..c27a9a6e1 100644 --- a/src/main/scala/leon/termination/RelationProcessor.scala +++ b/src/main/scala/leon/termination/RelationProcessor.scala @@ -27,7 +27,7 @@ class RelationProcessor( funDef -> modules.getRelations(funDef).collect({ case Relation(_, path, FunctionInvocation(tfd, args), _) if problem.funSet(tfd.fd) => val args0 = funDef.params.map(_.toVariable) - def constraint(expr: Expr) = implies(andJoin(path.toSeq), expr) + def constraint(expr: Expr) = path implies expr val greaterThan = modules.sizeDecreasing(args0, args) val greaterEquals = modules.softDecreasing(args0, args) (tfd.fd, (constraint(greaterThan), constraint(greaterEquals))) diff --git a/src/main/scala/leon/termination/Strengthener.scala b/src/main/scala/leon/termination/Strengthener.scala index 75a466001..cc268d943 100644 --- a/src/main/scala/leon/termination/Strengthener.scala +++ b/src/main/scala/leon/termination/Strengthener.scala @@ -3,6 +3,7 @@ package leon package termination +import purescala.Path import purescala.Expressions._ import purescala.Types._ import purescala.ExprOps._ @@ -86,9 +87,9 @@ trait Strengthener { self : RelationComparator => for (funDef <- sortedFunDefs if !strengthenedApp(funDef) && funDef.hasBody && checker.terminates(funDef).isGuaranteed) { - val appCollector = new CollectorWithPaths[(Identifier,Expr,Seq[Expr])] { - def collect(e: Expr, path: Seq[Expr]): Option[(Identifier, Expr, Seq[Expr])] = e match { - case Application(Variable(id), args) => Some((id, andJoin(path), args)) + val appCollector = new CollectorWithPaths[(Identifier,Path,Seq[Expr])] { + def collect(e: Expr, path: Path): Option[(Identifier, Path, Seq[Expr])] = e match { + case Application(Variable(id), args) => Some((id, path, args)) case _ => None } } @@ -98,8 +99,8 @@ trait Strengthener { self : RelationComparator => val funDefArgs = funDef.params.map(_.toVariable) val allFormulas = for ((id, path, appArgs) <- applications) yield { - val soft = Implies(path, self.softDecreasing(funDefArgs, appArgs)) - val hard = Implies(path, self.sizeDecreasing(funDefArgs, appArgs)) + val soft = path implies self.softDecreasing(funDefArgs, appArgs) + val hard = path implies self.sizeDecreasing(funDefArgs, appArgs) id -> ((soft, hard)) } @@ -119,10 +120,10 @@ trait Strengthener { self : RelationComparator => val funDefHOArgs = funDef.params.map(_.id).filter(_.getType.isInstanceOf[FunctionType]).toSet - val fiCollector = new CollectorWithPaths[(Expr, Seq[Expr], Seq[(Identifier,(FunDef, Identifier))])] { - def collect(e: Expr, path: Seq[Expr]): Option[(Expr, Seq[Expr], Seq[(Identifier,(FunDef, Identifier))])] = e match { + val fiCollector = new CollectorWithPaths[(Path, Seq[Expr], Seq[(Identifier,(FunDef, Identifier))])] { + def collect(e: Expr, path: Path): Option[(Path, Seq[Expr], Seq[(Identifier,(FunDef, Identifier))])] = e match { case FunctionInvocation(tfd, args) if (funDefHOArgs intersect args.collect({ case Variable(id) => id }).toSet).nonEmpty => - Some((andJoin(path), args, (args zip tfd.fd.params).collect { + Some((path, args, (args zip tfd.fd.params).collect { case (Variable(id), vd) if funDefHOArgs(id) => id -> ((tfd.fd, vd.id)) })) case _ => None @@ -130,22 +131,22 @@ trait Strengthener { self : RelationComparator => } val invocations = fiCollector.traverse(funDef) - val id2invocations : Seq[(Identifier, ((FunDef, Identifier), Expr, Seq[Expr]))] = + val id2invocations : Seq[(Identifier, ((FunDef, Identifier), Path, Seq[Expr]))] = for { p <- invocations c <- p._3 } yield c._1 -> (c._2, p._1, p._2) - val invocationMap: Map[Identifier, Seq[((FunDef, Identifier), Expr, Seq[Expr])]] = + val invocationMap: Map[Identifier, Seq[((FunDef, Identifier), Path, Seq[Expr])]] = id2invocations.groupBy(_._1).mapValues(_.map(_._2)) - def constraint(id: Identifier, passings: Seq[((FunDef, Identifier), Expr, Seq[Expr])]): SizeConstraint = { + def constraint(id: Identifier, passings: Seq[((FunDef, Identifier), Path, Seq[Expr])]): SizeConstraint = { if (constraints.get(id) == Some(NoConstraint)) NoConstraint else if (passings.exists(p => appConstraint.get(p._1) == Some(NoConstraint))) NoConstraint else passings.foldLeft[SizeConstraint](constraints.getOrElse(id, StrongDecreasing)) { case (constraint, (key, path, args)) => - lazy val strongFormula = Implies(path, self.sizeDecreasing(funDefArgs, args)) - lazy val weakFormula = Implies(path, self.softDecreasing(funDefArgs, args)) + lazy val strongFormula = path implies self.sizeDecreasing(funDefArgs, args) + lazy val weakFormula = path implies self.softDecreasing(funDefArgs, args) (constraint, appConstraint.get(key)) match { case (_, Some(NoConstraint)) => scala.sys.error("Whaaaat!?!? This shouldn't happen...") diff --git a/src/main/scala/leon/transformations/SerialInstrumentationPhase.scala b/src/main/scala/leon/transformations/SerialInstrumentationPhase.scala index 957074b48..ada138679 100644 --- a/src/main/scala/leon/transformations/SerialInstrumentationPhase.scala +++ b/src/main/scala/leon/transformations/SerialInstrumentationPhase.scala @@ -308,11 +308,11 @@ class ExprInstrumenter(funMap: Map[FunDef, FunDef], serialInst: SerialInstrument val map = mapForPattern(scrut, cse.pattern) val patCond = conditionForPattern(scrut, cse.pattern, includeBinders = false) val realCond = cse.optGuard match { - case Some(g) => And(patCond, replaceFromIDs(map, g)) + case Some(g) => patCond withCond replaceFromIDs(map, g) case None => patCond } val newRhs = replaceFromIDs(map, cse.rhs) - (realCond, newRhs) + (realCond.toClause, newRhs) } val bigIte = condsAndRhs.foldRight[Expr]( Error(me.getType, "Match is non-exhaustive").copiedFrom(me))((p1, ex) => { @@ -496,4 +496,4 @@ abstract class Instrumenter(program: Program, si: SerialInstrumenter) { */ def instrumentMatchCase(me: MatchExpr, mc: MatchCase, caseExprCost: Expr, scrutineeCost: Expr): Expr -} \ No newline at end of file +} diff --git a/src/main/scala/leon/verification/DefaultTactic.scala b/src/main/scala/leon/verification/DefaultTactic.scala index 9905d96ad..4d03e20e3 100644 --- a/src/main/scala/leon/verification/DefaultTactic.scala +++ b/src/main/scala/leon/verification/DefaultTactic.scala @@ -31,7 +31,7 @@ class DefaultTactic(vctx: VerificationContext) extends Tactic(vctx) { calls.map { case (fi @ FunctionInvocation(tfd, args), path) => val pre = tfd.withParamSubst(args, tfd.precondition.get) - val vc = implies(path, pre) + val vc = path implies pre val fiS = sizeLimit(fi.asString, 40) VC(vc, fd, VCKinds.Info(VCKinds.Precondition, s"call $fiS")).setPos(fi) } diff --git a/src/main/scala/leon/verification/InductionTactic.scala b/src/main/scala/leon/verification/InductionTactic.scala index d1ae3b0f9..a81081bc7 100644 --- a/src/main/scala/leon/verification/InductionTactic.scala +++ b/src/main/scala/leon/verification/InductionTactic.scala @@ -65,7 +65,7 @@ class InductionTactic(vctx: VerificationContext) extends DefaultTactic(vctx) { }(body) for { - ((fi@FunctionInvocation(tfd, args), pre), path) <- calls + ((fi @ FunctionInvocation(tfd, args), pre), path) <- calls cct <- parentType.knownCCDescendants } yield { val selectors = selectorsOfParentType(parentType, cct, arg.toVariable) @@ -76,10 +76,9 @@ class InductionTactic(vctx: VerificationContext) extends DefaultTactic(vctx) { ) } - val vc = implies( - andJoin(Seq(IsInstanceOf(arg.toVariable, cct), fd.precOrTrue, path) ++ subCases), - tfd.withParamSubst(args, pre) - ) + val vc = path + .withConds(Seq(IsInstanceOf(arg.toVariable, cct), fd.precOrTrue) ++ subCases) + .implies(tfd.withParamSubst(args, pre)) // Crop the call to display it properly val fiS = sizeLimit(fi.asString, 25) diff --git a/src/test/scala/leon/integration/grammars/SimilarToSuite.scala b/src/test/scala/leon/integration/grammars/SimilarToSuite.scala index 7f4f625be..358c6c252 100644 --- a/src/test/scala/leon/integration/grammars/SimilarToSuite.scala +++ b/src/test/scala/leon/integration/grammars/SimilarToSuite.scala @@ -5,6 +5,7 @@ package leon.integration.grammars import leon._ import leon.test._ import leon.test.helpers._ +import leon.purescala.Path import leon.purescala.Common._ import leon.purescala.Definitions._ import leon.purescala.Constructors._ -- GitLab