From a3216570a1279d2d1df736ee8a66d6cb27bdc92c Mon Sep 17 00:00:00 2001 From: Manos Koukoutos <emmanouil.koukoutos@epfl.ch> Date: Sun, 1 Mar 2015 19:52:22 +0100 Subject: [PATCH] Implement Ensuring, Choose as HOFs --- .../scala/leon/codegen/CodeGeneration.scala | 17 +- .../leon/evaluators/RecursiveEvaluator.scala | 28 +-- .../leon/evaluators/TracingEvaluator.scala | 19 +- .../leon/frontends/scalac/ASTExtractors.scala | 13 +- .../frontends/scalac/CodeExtraction.scala | 34 +-- src/main/scala/leon/purescala/CallGraph.scala | 5 +- .../CompleteAbstractDefinitions.scala | 25 +- .../scala/leon/purescala/Constructors.scala | 22 ++ src/main/scala/leon/purescala/DefOps.scala | 4 +- .../scala/leon/purescala/Definitions.scala | 16 +- .../scala/leon/purescala/Extractors.scala | 68 ++---- .../leon/purescala/FunctionClosure.scala | 6 +- .../scala/leon/purescala/MethodLifting.scala | 7 +- .../scala/leon/purescala/PrettyPrinter.scala | 14 +- .../scala/leon/purescala/ScalaPrinter.scala | 2 +- .../leon/purescala/ScopeSimplifier.scala | 10 +- src/main/scala/leon/purescala/TreeOps.scala | 218 ++++-------------- src/main/scala/leon/purescala/Trees.scala | 32 ++- .../scala/leon/purescala/TypeTreeOps.scala | 10 +- .../scala/leon/repair/RepairNDEvaluator.scala | 16 +- .../leon/repair/RepairTrackingEvaluator.scala | 28 +-- src/main/scala/leon/repair/Repairman.scala | 28 +-- .../solvers/templates/TemplateGenerator.scala | 23 +- .../scala/leon/synthesis/ConvertHoles.scala | 20 +- .../leon/synthesis/ConvertWithOracles.scala | 14 +- .../scala/leon/synthesis/ExamplesFinder.scala | 3 +- src/main/scala/leon/synthesis/Problem.scala | 13 +- src/main/scala/leon/synthesis/Solution.scala | 4 +- .../scala/leon/synthesis/Synthesizer.scala | 2 +- .../leon/synthesis/rules/ADTInduction.scala | 2 +- .../synthesis/rules/ADTLongInduction.scala | 2 +- .../scala/leon/synthesis/rules/Assert.scala | 1 - .../synthesis/rules/EquivalentInputs.scala | 4 +- .../leon/synthesis/rules/IntInduction.scala | 2 +- .../leon/termination/RelationBuilder.scala | 2 +- .../SimpleTerminationChecker.scala | 12 +- .../scala/leon/termination/Strengthener.scala | 11 +- .../leon/termination/StructuralSize.scala | 4 +- src/main/scala/leon/utils/TypingPhase.scala | 13 +- .../scala/leon/utils/UnitElimination.scala | 12 +- .../leon/verification/DefaultTactic.scala | 7 +- .../leon/verification/InductionTactic.scala | 13 +- .../leon/xlang/ArrayTransformation.scala | 26 +-- .../scala/leon/xlang/EpsilonElimination.scala | 10 +- .../xlang/ImperativeCodeElimination.scala | 4 +- .../verification/purescala/valid/Acc.scala | 2 +- .../test/solvers/UnrollingSolverTests.scala | 2 +- 47 files changed, 306 insertions(+), 524 deletions(-) diff --git a/src/main/scala/leon/codegen/CodeGeneration.scala b/src/main/scala/leon/codegen/CodeGeneration.scala index d3fd0df76..95d1e0d95 100644 --- a/src/main/scala/leon/codegen/CodeGeneration.scala +++ b/src/main/scala/leon/codegen/CodeGeneration.scala @@ -180,11 +180,10 @@ trait CodeGeneration { body } - val bodyWithPost = if(funDef.hasPostcondition && params.checkContracts) { - val Some((id, post)) = funDef.postcondition - Let(id, bodyWithPre, IfExpr(post, Variable(id), Error(id.getType, "Postcondition failed")) ) - } else { - bodyWithPre + val bodyWithPost = funDef.postcondition match { + case Some(post) if params.checkContracts => + Ensuring(bodyWithPre, post).toAssert + case _ => bodyWithPre } if (params.recordInvocations) { @@ -216,8 +215,8 @@ trait CodeGeneration { case Assert(cond, oerr, body) => mkExpr(IfExpr(Not(cond), Error(body.getType, oerr.getOrElse("Assertion failed @"+e.getPos)), body), ch) - case Ensuring(body, id, post) => - mkExpr(Let(id, body, Assert(post, Some("Ensuring failed"), Variable(id))), ch) + case en@Ensuring(_, _) => + mkExpr(en.toAssert, ch) case Let(i,d,b) => mkExpr(d, ch) @@ -785,10 +784,10 @@ trait CodeGeneration { ch << InvokeSpecial(ErrorClass, constructorName, "(Ljava/lang/String;)V") ch << ATHROW - case Choose(_, _, Some(e)) => + case Choose(_, Some(e)) => mkExpr(e, ch) - case choose @ Choose(_, _, None) => + case choose @ Choose(_, None) => val prob = synthesis.Problem.fromChoose(choose) val id = runtime.ChooseEntryPoint.register(prob, this); diff --git a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala index 546548244..7bac62443 100644 --- a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala +++ b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala @@ -109,16 +109,16 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int case Assert(cond, oerr, body) => e(IfExpr(Not(cond), Error(expr.getType, oerr.getOrElse("Assertion failed @"+expr.getPos)), body)) - case en@Ensuring(body, id, post) => + case en@Ensuring(body, post) => if ( exists{ case Hole(_,_) => true case Gives(_,_) => true case _ => false }(en)) e(convertHoles(en, ctx, true)) - else - e(Let(id, body, Assert(post, Some("Ensuring failed"), Variable(id)))) - + else + e(en.toAssert) + case Error(tpe, desc) => throw RuntimeError("Error reached in evaluation: " + desc) @@ -168,14 +168,14 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int val body = tfd.body.getOrElse(rctx.mappings(tfd.id)) val callResult = e(body)(frame, gctx) - if(tfd.hasPostcondition) { - val (id, post) = tfd.postcondition.get - - e(post)(frame.withNewVar(id, callResult), gctx) match { - case BooleanLiteral(true) => - case BooleanLiteral(false) => throw RuntimeError("Postcondition violation for " + tfd.id.name + " reached in evaluation.") - case other => throw EvalError(typeErrorMsg(other, BooleanType)) - } + tfd.postcondition match { + case Some(post) => + e(application(post, Seq(callResult)))(frame, gctx) match { + case BooleanLiteral(true) => + case BooleanLiteral(false) => throw RuntimeError("Postcondition violation for " + tfd.id.name + " reached in evaluation.") + case other => throw EvalError(typeErrorMsg(other, BooleanType)) + } + case None => } callResult @@ -500,10 +500,10 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int case p : Passes => e(p.asConstraint) - case choose @ Choose(_, _, Some(impl)) => + case choose @ Choose(_, Some(impl)) => e(impl) - case choose @ Choose(_, _, None) => + case choose @ Choose(_, None) => import purescala.TreeOps.simplestValue implicit val debugSection = utils.DebugSectionSynthesis diff --git a/src/main/scala/leon/evaluators/TracingEvaluator.scala b/src/main/scala/leon/evaluators/TracingEvaluator.scala index 8321a9c5c..80757d88a 100644 --- a/src/main/scala/leon/evaluators/TracingEvaluator.scala +++ b/src/main/scala/leon/evaluators/TracingEvaluator.scala @@ -78,16 +78,15 @@ class TracingEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int = 1000) ex val body = tfd.body.getOrElse(rctx.mappings(tfd.id)) val callResult = e(body)(frame, gctx) - if(tfd.hasPostcondition) { - val (id, post) = tfd.postcondition.get - - gctx.values ::= id.toVariable.setPos(id) -> callResult - - e(post)(frame.withNewVar(id, callResult), gctx) match { - case BooleanLiteral(true) => - case BooleanLiteral(false) => throw RuntimeError("Postcondition violation for " + tfd.id.name + " reached in evaluation.") - case other => throw EvalError(typeErrorMsg(other, BooleanType)) - } + tfd.postcondition match { + case Some(post) => + + e(Application(post, Seq(callResult)))(frame, gctx) match { + case BooleanLiteral(true) => + case BooleanLiteral(false) => throw RuntimeError("Postcondition violation for " + tfd.id.name + " reached in evaluation.") + case other => throw EvalError(typeErrorMsg(other, BooleanType)) + } + case None => } (callResult, callResult) diff --git a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala b/src/main/scala/leon/frontends/scalac/ASTExtractors.scala index b37fd59a6..840dd6fa4 100644 --- a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala +++ b/src/main/scala/leon/frontends/scalac/ASTExtractors.scala @@ -136,10 +136,9 @@ trait ASTExtractors { object ExEnsuredExpression { /** Extracts the 'ensuring' contract from an expression. */ - def unapply(tree: Apply): Option[(Tree,ValDef,Tree)] = tree match { - case Apply(Select(Apply(TypeApply(ExSelected("scala", "Predef", "Ensuring"), _ :: Nil), body :: Nil), ExNamed("ensuring")), - (Function((vd @ ValDef(_, _, _, EmptyTree)) :: Nil, contractBody)) :: Nil) - => Some((body, vd, contractBody)) + def unapply(tree: Apply): Option[(Tree,Tree)] = tree match { + case Apply(Select(Apply(TypeApply(ExSelected("scala", "Predef", "Ensuring"), _ :: Nil), body :: Nil), ExNamed("ensuring")), contract :: Nil) + => Some((body, contract)) case _ => None } } @@ -455,11 +454,11 @@ trait ASTExtractors { } object ExChooseExpression { - def unapply(tree: Apply) : Option[(List[(Tree, Symbol)], Tree)] = tree match { + def unapply(tree: Apply) : Option[Tree] = tree match { case a @ Apply( TypeApply(s @ ExSymbol("leon", "lang", "synthesis", "choose"), types), - Function(vds, predicateBody) :: Nil) => - Some(((types zip vds.map(_.symbol)).toList, predicateBody)) + predicate :: Nil) => + Some(predicate) case _ => None } } diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala index ffa2524f3..8166e69f8 100644 --- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala +++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala @@ -855,7 +855,7 @@ trait CodeExtraction extends ASTExtractors { } } - funDef.postcondition.foreach { case (id, e) => + funDef.postcondition.foreach { e => if(containsLetDef(e)) { reporter.warning(e.getPos, "Function postcondition should not contain nested function definition, ignoring.") funDef.postcondition = None @@ -967,9 +967,8 @@ trait CodeExtraction extends ASTExtractors { var rest = tmpRest val res = current match { - case ExEnsuredExpression(body, resVd, contract) => - val resId = FreshIdentifier(resVd.symbol.name.toString, extractType(current)).setPos(resVd.pos).setOwner(currentFunDef) - val post = extractTree(contract)(dctx.withNewVar(resVd.symbol -> (() => Variable(resId)))) + case ExEnsuredExpression(body, contract) => + val post = extractTree(contract) val b = try { extractTree(body) @@ -978,11 +977,11 @@ trait CodeExtraction extends ASTExtractors { NoTree(toPureScalaType(current.tpe)(dctx, current.pos)) } - Ensuring(b, resId, post) + Ensuring(b, post) case t @ ExHoldsExpression(body) => val resId = FreshIdentifier("holds", BooleanType).setPos(current.pos).setOwner(currentFunDef) - val post = Variable(resId).setPos(current.pos) + val post = Lambda(Seq(LeonValDef(resId)), Variable(resId).setPos(current.pos)) val b = try { extractTree(body) @@ -991,7 +990,7 @@ trait CodeExtraction extends ASTExtractors { NoTree(toPureScalaType(current.tpe)(dctx, current.pos)) } - Ensuring(b, resId, post) + Ensuring(b, post) case ExAssertExpression(contract, oerr) => val const = extractTree(contract) @@ -1276,22 +1275,9 @@ trait CodeExtraction extends ASTExtractors { WithOracle(newOracles, cBody) - case chs @ ExChooseExpression(args, body) => - val vars = args map { case (tpt, sym) => - val aTpe = extractType(tpt) - val newID = FreshIdentifier(sym.name.toString, aTpe).setOwner(currentFunDef) - owners += (newID -> None) - newID - } - - val newVars = (args zip vars).map { - case ((_, sym), id) => - sym -> (() => Variable(id)) - } - - val cBody = extractTree(body)(dctx.withNewVars(newVars)) - - Choose(vars, cBody) + case chs @ ExChooseExpression(body) => + val cBody = extractTree(body) + Choose(cBody) case l @ ExLambdaExpression(args, body) => val vds = args map { vd => @@ -1537,7 +1523,7 @@ trait CodeExtraction extends ASTExtractors { MethodInvocation(rec, cd, fd.typed(newTps), args) case (IsTyped(rec, ft: FunctionType), _, args) => - Application(rec, args) + application(rec, args) case (IsTyped(rec, cct: CaseClassType), name, Nil) if cct.fields.exists(_.id.name == name) => diff --git a/src/main/scala/leon/purescala/CallGraph.scala b/src/main/scala/leon/purescala/CallGraph.scala index 0318f8311..a6ffa8980 100644 --- a/src/main/scala/leon/purescala/CallGraph.scala +++ b/src/main/scala/leon/purescala/CallGraph.scala @@ -60,10 +60,7 @@ class CallGraph(p: Program) { } private def scanForCalls(fd: FunDef) { - val allExprs: Iterable[Expr] = - fd.precondition ++ fd.body ++ fd.postcondition.map(_._2) - - for (e <- allExprs; (from, to) <- collect(collectCalls(fd)(_))(e)) { + for( (from, to) <- collect(collectCalls(fd)(_))(fd.fullBody) ) { _calls += (from -> to) _callees += (from -> (_callees.getOrElse(from, Set()) + to)) _callers += (to -> (_callers.getOrElse(to, Set()) + from)) diff --git a/src/main/scala/leon/purescala/CompleteAbstractDefinitions.scala b/src/main/scala/leon/purescala/CompleteAbstractDefinitions.scala index ceaf8ad37..5d67a2286 100644 --- a/src/main/scala/leon/purescala/CompleteAbstractDefinitions.scala +++ b/src/main/scala/leon/purescala/CompleteAbstractDefinitions.scala @@ -16,27 +16,12 @@ object CompleteAbstractDefinitions extends TransformationPhase { val description = "Inject fake choose-like body in abstract definitions" def apply(ctx: LeonContext, program: Program): Program = { - // First we create the appropriate functions from methods: - var mdToFds = Map[FunDef, FunDef]() - - for (u <- program.units; m <- u.modules ) { - // We remove methods from class definitions and add corresponding functions - m.defs.foreach { - case fd: FunDef if fd.body.isEmpty => - val id = FreshIdentifier("res", fd.returnType) - - if (fd.hasPostcondition) { - val (pid, post) = fd.postcondition.get - - fd.body = Some(Choose(List(id), replaceFromIDs(Map(pid -> Variable(id)), post))) - } else { - fd.body = Some(Choose(List(id), BooleanLiteral(true))) - } - - case d => - } + for (u <- program.units; m <- u.modules; fd <- m.definedFunctions; if fd.body.isEmpty) { + val post = fd.postcondition getOrElse ( + Lambda(Seq(ValDef(FreshIdentifier("res", fd.returnType))), BooleanLiteral(true)) + ) + fd.body = Some(Choose(post)) } - // Translation is in-place program } diff --git a/src/main/scala/leon/purescala/Constructors.scala b/src/main/scala/leon/purescala/Constructors.scala index e757612b1..34e634e12 100644 --- a/src/main/scala/leon/purescala/Constructors.scala +++ b/src/main/scala/leon/purescala/Constructors.scala @@ -7,6 +7,7 @@ import utils._ object Constructors { import Trees._ + import TreeOps._ import Definitions._ import TypeTreeOps._ import Common._ @@ -26,6 +27,12 @@ object Constructors { def tupleSelect(t: Expr, index: Int, originalSize: Int): Expr = tupleSelect(t, index, originalSize > 1) + def let(id: Identifier, e: Expr, bd: Expr) = { + if (variablesOf(bd) contains id) + Let(id, e, bd) + else bd + } + def letTuple(binders: Seq[Identifier], value: Expr, body: Expr) = binders match { case Nil => body @@ -232,4 +239,19 @@ object Constructors { Lambda(args, body) } } + + def application(fn: Expr, realArgs: Seq[Expr]) = fn match { + case Lambda(formalArgs, body) => + val (inline, notInline) = formalArgs.map{_.id}.zip(realArgs).partition { + case (form, _) => count{ + case Variable(`form`) => 1 + case _ => 0 + }(body) <= 1 + } + val newBody = replaceFromIDs(inline.toMap, body) + val (ids, es) = notInline.unzip + letTuple(ids, tupleWrap(es), newBody) + case _ => Application(fn, realArgs) + } + } diff --git a/src/main/scala/leon/purescala/DefOps.scala b/src/main/scala/leon/purescala/DefOps.scala index c8ea48d59..ce8785768 100644 --- a/src/main/scala/leon/purescala/DefOps.scala +++ b/src/main/scala/leon/purescala/DefOps.scala @@ -273,9 +273,7 @@ object DefOps { */ def applyOnFunDef(operation : Expr => Expr)(funDef : FunDef): FunDef = { val newFunDef = funDef.duplicate - newFunDef.body = funDef.body map operation - newFunDef.precondition = funDef.precondition map operation - newFunDef.postcondition = funDef.postcondition map { case (id, ex) => (id, operation(ex))} + newFunDef.fullBody = operation(funDef.fullBody) newFunDef } diff --git a/src/main/scala/leon/purescala/Definitions.scala b/src/main/scala/leon/purescala/Definitions.scala index 2f96aa02c..ef3505476 100644 --- a/src/main/scala/leon/purescala/Definitions.scala +++ b/src/main/scala/leon/purescala/Definitions.scala @@ -379,7 +379,7 @@ object Definitions { } def postcondition = postconditionOf(fullBody) - def postcondition_=(op: Option[(Identifier, Expr)]) = { + def postcondition_=(op: Option[Expr]) = { fullBody = withPostcondition(fullBody, op) } @@ -493,9 +493,9 @@ object Definitions { lazy val returnType: TypeTree = translated(fd.returnType) private var trCache = Map[Expr, Expr]() - private var postCache = Map[(Identifier, Expr), (Identifier, Expr)]() + private var postCache = Map[Expr, Expr]() - def body = fd.body.map { b => + def body = fd.body.map { b => trCache.getOrElse(b, { val res = translated(b) trCache += b -> res @@ -503,7 +503,7 @@ object Definitions { }) } - def precondition = fd.precondition.map { pre => + def precondition = fd.precondition.map { pre => trCache.getOrElse(pre, { val res = translated(pre) trCache += pre -> res @@ -512,11 +512,11 @@ object Definitions { } def postcondition = fd.postcondition.map { - case (id, post) if typesMap.nonEmpty => - postCache.getOrElse((id, post), { + case post if typesMap.nonEmpty => + postCache.getOrElse(post, { val nId = FreshIdentifier(id.name, translated(id.getType)).copiedFrom(id) - val res = nId -> instantiateType(post, typesMap, paramsMap + (id -> nId)) - postCache += ((id,post) -> res) + val res = instantiateType(post, typesMap, paramsMap) + postCache += (post -> res) res }) diff --git a/src/main/scala/leon/purescala/Extractors.scala b/src/main/scala/leon/purescala/Extractors.scala index ebd13edb7..ce0535337 100644 --- a/src/main/scala/leon/purescala/Extractors.scala +++ b/src/main/scala/leon/purescala/Extractors.scala @@ -42,6 +42,12 @@ object Extractors { object BinaryOperator { def unapply(expr: Expr) : Option[(Expr,Expr,(Expr,Expr)=>Expr)] = expr match { + case LetDef(fd, body) => Some((fd.fullBody, body, + (fdBd, body) => { + fd.fullBody = fdBd + LetDef(fd, body) + } + )) case Equals(t1,t2) => Some((t1,t2,Equals.apply)) case Implies(t1,t2) => Some((t1,t2, implies)) case Plus(t1,t2) => Some((t1,t2,Plus)) @@ -81,7 +87,7 @@ object Extractors { case ArraySelect(t1, t2) => Some((t1, t2, ArraySelect)) case Let(binders, e, body) => Some((e, body, (e: Expr, b: Expr) => Let(binders, e, b))) case Require(pre, body) => Some((pre, body, Require)) - case Ensuring(body, id, post) => Some((body, post, (b: Expr, p: Expr) => Ensuring(b, id, p))) + case Ensuring(body, post) => Some((body, post, (b: Expr, p: Expr) => Ensuring(b, p))) case Assert(const, oerr, body) => Some((const, body, (c: Expr, b: Expr) => Assert(c, oerr, b))) case (ex: BinaryExtractable) => ex.extract case _ => None @@ -94,6 +100,10 @@ object Extractors { object NAryOperator { def unapply(expr: Expr) : Option[(Seq[Expr],(Seq[Expr])=>Expr)] = expr match { + case Choose(pred, impl) => Some((Seq(pred) ++ impl.toSeq, { + case Seq(pred) => Choose(pred, None) + case Seq(pred, impl) => Choose(pred, Some(impl)) + })) case fi @ FunctionInvocation(fd, args) => Some((args, (as => FunctionInvocation(fd, as).setPos(fi)))) case mi @ MethodInvocation(rec, cd, tfd, args) => Some((rec +: args, (as => MethodInvocation(as.head, cd, tfd, as.tail).setPos(mi)))) case fa @ Application(caller, args) => Some((caller +: args), (as => Application(as.head, as.tail).setPos(fa))) @@ -161,60 +171,6 @@ object Extractors { Passes(in, out, newcases) }} )) - case LetDef(fd, body) => - fd.body match { - case Some(b) => - (fd.precondition, fd.postcondition) match { - case (None, None) => - Some((Seq(b, body), (as: Seq[Expr]) => { - fd.body = Some(as(0)) - LetDef(fd, as(1)) - })) - case (Some(pre), None) => - Some((Seq(b, body, pre), (as: Seq[Expr]) => { - fd.body = Some(as(0)) - fd.precondition = Some(as(2)) - LetDef(fd, as(1)) - })) - case (None, Some((pid, post))) => - Some((Seq(b, body, post), (as: Seq[Expr]) => { - fd.body = Some(as(0)) - fd.postcondition = Some((pid, as(2))) - LetDef(fd, as(1)) - })) - case (Some(pre), Some((pid, post))) => - Some((Seq(b, body, pre, post), (as: Seq[Expr]) => { - fd.body = Some(as(0)) - fd.precondition = Some(as(2)) - fd.postcondition = Some((pid, as(3))) - LetDef(fd, as(1)) - })) - } - - case None => //case no body, we still need to handle remaining cases - (fd.precondition, fd.postcondition) match { - case (None, None) => - Some((Seq(body), (as: Seq[Expr]) => { - LetDef(fd, as(0)) - })) - case (Some(pre), None) => - Some((Seq(body, pre), (as: Seq[Expr]) => { - fd.precondition = Some(as(1)) - LetDef(fd, as(0)) - })) - case (None, Some((pid, post))) => - Some((Seq(body, post), (as: Seq[Expr]) => { - fd.postcondition = Some((pid, as(1))) - LetDef(fd, as(0)) - })) - case (Some(pre), Some((pid, post))) => - Some((Seq(body, pre, post), (as: Seq[Expr]) => { - fd.precondition = Some(as(1)) - fd.postcondition = Some((pid, as(2))) - LetDef(fd, as(0)) - })) - } - } case (ex: NAryExtractable) => ex.extract case _ => None } @@ -259,6 +215,7 @@ object Extractors { object TopLevelOrs { // expr1 AND (expr2 AND (expr3 AND ..)) => List(expr1, expr2, expr3) def unapply(e: Expr): Option[Seq[Expr]] = e match { + case Let(i, e, TopLevelOrs(bs)) => Some(bs map (let(i,e,_))) case Or(exprs) => Some(exprs.flatMap(unapply(_)).flatten) case e => @@ -267,6 +224,7 @@ object Extractors { } object TopLevelAnds { // expr1 AND (expr2 AND (expr3 AND ..)) => List(expr1, expr2, expr3) def unapply(e: Expr): Option[Seq[Expr]] = e match { + case Let(i, e, TopLevelAnds(bs)) => Some(bs map (let(i,e,_))) case And(exprs) => Some(exprs.flatMap(unapply(_)).flatten) case e => diff --git a/src/main/scala/leon/purescala/FunctionClosure.scala b/src/main/scala/leon/purescala/FunctionClosure.scala index 10a30c416..10ee5d41e 100644 --- a/src/main/scala/leon/purescala/FunctionClosure.scala +++ b/src/main/scala/leon/purescala/FunctionClosure.scala @@ -77,21 +77,17 @@ object FunctionClosure extends TransformationPhase { functionClosure(newExpr, newBindedVars, freshIds, fd2FreshFd) } - val newPrecondition = simplifyLets(introduceLets(and((capturedConstraints ++ fd.precondition).toSeq :_*), fd2FreshFd)) newFunDef.precondition = if(newPrecondition == BooleanLiteral(true)) None else Some(newPrecondition) - val freshPostcondition = fd.postcondition.map{ case (id, post) => (id, introduceLets(post, fd2FreshFd)) } + val freshPostcondition = fd.postcondition.map{ post => introduceLets(post, fd2FreshFd) } newFunDef.postcondition = freshPostcondition pathConstraints = fd.precondition.getOrElse(BooleanLiteral(true)) :: pathConstraints - //val freshBody = fd.body.map(body => introduceLets(body, fd2FreshFd + (fd -> ((newFunDef, extraValDefFreshIds.map(_.toVariable)))))) val freshBody = fd.body.map(body => introduceLets(body, fd2FreshFd + (fd -> ((newFunDef, extraValDefOldIds.map(_.toVariable)))))) newFunDef.body = freshBody pathConstraints = pathConstraints.tail - //val freshRest = functionClosure(rest, bindedVars, id2freshId, fd2FreshFd + (fd -> - // ((newFunDef, extraValDefOldIds.map(id => id2freshId.get(id).getOrElse(id).toVariable))))) val freshRest = functionClosure(rest, bindedVars, id2freshId, fd2FreshFd + (fd -> ((newFunDef, extraValDefOldIds.map(_.toVariable))))) freshRest.copiedFrom(l) } diff --git a/src/main/scala/leon/purescala/MethodLifting.scala b/src/main/scala/leon/purescala/MethodLifting.scala index 4c82c31d7..1ef75b902 100644 --- a/src/main/scala/leon/purescala/MethodLifting.scala +++ b/src/main/scala/leon/purescala/MethodLifting.scala @@ -54,12 +54,7 @@ object MethodLifting extends TransformationPhase { (fd, None) } - nfd.precondition = nfd.precondition.map(removeMethodCalls(rec)) - nfd.body = nfd.body.map(removeMethodCalls(rec)) - nfd.postcondition = nfd.postcondition.map { - case (id, post) => (id, removeMethodCalls(rec)(post)) - } - + nfd.fullBody = removeMethodCalls(rec)(nfd.fullBody) nfd } diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala index 6e0195fc0..80da11205 100644 --- a/src/main/scala/leon/purescala/PrettyPrinter.scala +++ b/src/main/scala/leon/purescala/PrettyPrinter.scala @@ -204,11 +204,11 @@ class PrettyPrinter(opts: PrinterOptions, val sb: StringBuffer = new StringBuffe p"""|assert($const) |$body""" - case Ensuring(body, id, post) => + case Ensuring(body, post) => p"""|{ | $body |} ensuring { - | (${typed(id)}) => $post + | $post |}""" case Gives(s, tests) => @@ -281,12 +281,12 @@ class PrettyPrinter(opts: PrinterOptions, val sb: StringBuffer = new StringBuffe case Tuple(exprs) => p"($exprs)" case TupleSelect(t, i) => p"${t}._$i" case NoTree(tpe) => p"???($tpe)" - case Choose(vars, pred, oimpl) => + case Choose(pred, oimpl) => oimpl match { case Some(e) => - p"$e /* choose: $vars => $pred */" + p"$e /* choose: $pred */" case None => - p"choose(($vars) => $pred)" + p"choose($pred)" } case e @ Error(tpe, err) => p"""error[$tpe]("$err")""" case CaseClassInstanceOf(cct, e) => @@ -635,9 +635,9 @@ class PrettyPrinter(opts: PrinterOptions, val sb: StringBuffer = new StringBuffe p"""| |}""" - fd.postcondition.foreach { case (id, post) => + fd.postcondition.foreach { post => p"""| ensuring { - | (${typed(id)}) => $post + | $post |}""" } diff --git a/src/main/scala/leon/purescala/ScalaPrinter.scala b/src/main/scala/leon/purescala/ScalaPrinter.scala index ab9ff171f..cf58ee34d 100644 --- a/src/main/scala/leon/purescala/ScalaPrinter.scala +++ b/src/main/scala/leon/purescala/ScalaPrinter.scala @@ -26,7 +26,7 @@ class ScalaPrinter(opts: PrinterOptions, sb: StringBuffer = new StringBuffer) ex tree match { case Not(Equals(l, r)) => p"$l != $r" case Implies(l,r) => pp(or(not(l), r)) - case Choose(vars, pred, None) => p"choose((${typed(vars)}) => $pred)" + case Choose(pred, None) => p"choose($pred)" case s @ FiniteSet(rss) => { val rs = rss.toSeq s.getType match { diff --git a/src/main/scala/leon/purescala/ScopeSimplifier.scala b/src/main/scala/leon/purescala/ScopeSimplifier.scala index e319619e6..525c5f26d 100644 --- a/src/main/scala/leon/purescala/ScopeSimplifier.scala +++ b/src/main/scala/leon/purescala/ScopeSimplifier.scala @@ -51,15 +51,7 @@ class ScopeSimplifier extends Transformer { newScope = newScope.registerFunDef(fd -> newFd) - newFd.body = fd.body.map(b => rec(b, newScope)) - newFd.precondition = fd.precondition.map(pre => rec(pre, newScope)) - - newFd.postcondition = fd.postcondition.map { - case (id, post) => - val nid = genId(id, newScope) - val postScope = newScope.register(id -> nid) - (nid, rec(post, postScope)) - } + newFd.fullBody = rec(fd.fullBody, newScope) LetDef(newFd, rec(body, newScope)) diff --git a/src/main/scala/leon/purescala/TreeOps.scala b/src/main/scala/leon/purescala/TreeOps.scala index cb02eed02..5cc4796c2 100644 --- a/src/main/scala/leon/purescala/TreeOps.scala +++ b/src/main/scala/leon/purescala/TreeOps.scala @@ -359,9 +359,8 @@ object TreeOps { e match { case Variable(i) => subvs + i - case LetDef(fd,_) => subvs -- fd.params.map(_.id) -- fd.postcondition.map(_._1) + case LetDef(fd,_) => subvs -- fd.params.map(_.id) case Let(i,_,_) => subvs - i - case Choose(is,_,_) => subvs -- is case MatchLike(_, cses, _) => subvs -- (cses.map(_.pattern.binders).foldLeft(Set[Identifier]())((a, b) => a ++ b)) case Passes(_, _ , cses) => subvs -- (cses.map(_.pattern.binders).foldLeft(Set[Identifier]())((a, b) => a ++ b)) case Lambda(args, body) => subvs -- args.map(_.id) @@ -1222,10 +1221,11 @@ object TreeOps { } def traverse(funDef: FunDef): Seq[T] = { + // @mk FIXME: This seems overly compicated val precondition = funDef.precondition.map(e => matchToIfThenElse(e)).toSeq val precTs = funDef.precondition.map(e => traverse(e)).toSeq.flatten val bodyTs = funDef.body.map(e => traverse(e, precondition)).toSeq.flatten - val postTs = funDef.postcondition.map(p => traverse(p._2)).toSeq.flatten + val postTs = funDef.postcondition.map(p => traverse(p)).toSeq.flatten precTs ++ bodyTs ++ postTs } @@ -1290,7 +1290,7 @@ object TreeOps { def isDeterministic(e: Expr): Boolean = { preTraversal{ - case Choose(_, _, None) => return false + case Choose(_, None) => return false case Hole(_, _) => return false case Gives(_,_) => return false case _ => @@ -1447,31 +1447,11 @@ object TreeOps { } def fdHomo(fd1: FunDef, fd2: FunDef)(implicit map: Map[Identifier, Identifier]) = { - if (fd1.params.size == fd2.params.size && - fd1.precondition.size == fd2.precondition.size && - fd1.body.size == fd2.body.size && - fd1.postcondition.size == fd2.postcondition.size) { - - val newMap = map + - (fd1.id -> fd2.id) ++ - (fd1.params zip fd2.params).map{ case (vd1, vd2) => (vd1.id, vd2.id) } - - val preMatch = (fd1.precondition zip fd2.precondition).forall { - case (e1, e2) => isHomo(e1, e2)(newMap) - } - - val postMatch = (fd1.postcondition zip fd2.postcondition).forall { - case ((id1, e1), (id2, e2)) => isHomo(e1, e2)(newMap + (id1 -> id2)) - } - - val bodyMatch = (fd1.body zip fd2.body).forall { - case (e1, e2) => isHomo(e1, e2)(newMap) - } - - preMatch && postMatch && bodyMatch - } else { - false - + (fd1.params.size == fd2.params.size) && { + val newMap = map + + (fd1.id -> fd2.id) ++ + (fd1.params zip fd2.params).map{ case (vd1, vd2) => (vd1.id, vd2.id) } + isHomo(fd1.fullBody, fd2.fullBody)(newMap) } } @@ -1535,8 +1515,8 @@ object TreeOps { case (Variable(i1), Variable(i2)) => idHomo(i1, i2) - case (Choose(ids1, e1, _), Choose(ids2, e2, _)) => - isHomo(e1, e2)(map ++ (ids1 zip ids2)) + case (Choose(e1, _), Choose(e2, _)) => + isHomo(e1, e2) case (Let(id1, v1, e1), Let(id2, v2, e2)) => isHomo(v1, v2) && @@ -1762,15 +1742,19 @@ object TreeOps { Some(and(oe, simplePreTransform(pre)(ie))) } - def mergePost(outer: Option[(Identifier, Expr)], inner: Option[(Identifier, Expr)]): Option[(Identifier, Expr)] = (outer, inner) match { - case (None, Some((iid, ie))) => - Some((iid, simplePreTransform(pre)(ie))) + def mergePost(outer: Option[Expr], inner: Option[Expr]): Option[Expr] = (outer, inner) match { + case (None, Some(ie)) => + Some(simplePreTransform(pre)(ie)) case (Some(oe), None) => Some(oe) case (None, None) => None - case (Some((oid, oe)), Some((iid, ie))) => - Some((oid, and(oe, replaceFromIDs(Map(iid -> Variable(oid)), simplePreTransform(pre)(ie))))) + case (Some(oe), Some(ie)) => + val res = FreshIdentifier("res", fdOuter.returnType, true) + Some(Lambda(Seq(ValDef(res)), and( + application(oe, Seq(Variable(res))), + application(simplePreTransform(pre)(ie), Seq(Variable(res))) + ))) } val newFd = fdOuter.duplicate @@ -1779,7 +1763,7 @@ object TreeOps { newFd.body = fdInner.body.map(b => simplePreTransform(pre)(b)) newFd.precondition = mergePre(fdOuter.precondition, fdInner.precondition).map(simp) - newFd.postcondition = mergePost(fdOuter.postcondition, fdInner.postcondition).map{ case (id, ex) => id -> simp(ex) } + newFd.postcondition = mergePost(fdOuter.postcondition, fdInner.postcondition).map(simp) newFd } else { @@ -1810,46 +1794,45 @@ object TreeOps { */ def withPrecondition(expr: Expr, pred: Option[Expr]): Expr = (pred, expr) match { - case (Some(newPre), Require(pre, b)) => Require(newPre, b) - case (Some(newPre), Ensuring(Require(pre, b), i, p)) => Ensuring(Require(newPre, b), i, p) - case (Some(newPre), Ensuring(b, i, p)) => Ensuring(Require(newPre, b), i, p) - case (Some(newPre), b) => Require(newPre, b) - case (None, Require(pre, b)) => b - case (None, Ensuring(Require(pre, b), i, p)) => Ensuring(b, i, p) - case (None, Ensuring(b, i, p)) => Ensuring(b, i, p) - case (None, b) => b + case (Some(newPre), Require(pre, b)) => Require(newPre, b) + case (Some(newPre), Ensuring(Require(pre, b), p)) => Ensuring(Require(newPre, b), p) + case (Some(newPre), Ensuring(b, p)) => Ensuring(Require(newPre, b), p) + case (Some(newPre), b) => Require(newPre, b) + case (None, Require(pre, b)) => b + case (None, Ensuring(Require(pre, b), p)) => Ensuring(b, p) + case (None, b) => b } - def withPostcondition(expr: Expr, oie: Option[(Identifier, Expr)]) = (oie, expr) match { - case (Some((nid, npost)), Ensuring(b, id, post)) => Ensuring(b, nid, npost) - case (Some((nid, npost)), b) => Ensuring(b, nid, npost) - case (None, Ensuring(b, i, p)) => b - case (None, b) => b + def withPostcondition(expr: Expr, oie: Option[Expr]) = (oie, expr) match { + case (Some(npost), Ensuring(b, post)) => Ensuring(b, npost) + case (Some(npost), b) => Ensuring(b, npost) + case (None, Ensuring(b, p)) => b + case (None, b) => b } def withBody(expr: Expr, body: Option[Expr]) = expr match { - case Require(pre, _) => Require(pre, body.getOrElse(NoTree(expr.getType))) - case Ensuring(Require(pre, _), i, post) => Ensuring(Require(pre, body.getOrElse(NoTree(expr.getType))), i, post) - case Ensuring(_, i, post) => Ensuring(body.getOrElse(NoTree(expr.getType)), i, post) - case _ => body.getOrElse(NoTree(expr.getType)) + case Require(pre, _) => Require(pre, body.getOrElse(NoTree(expr.getType))) + case Ensuring(Require(pre, _), post) => Ensuring(Require(pre, body.getOrElse(NoTree(expr.getType))), post) + case Ensuring(_, post) => Ensuring(body.getOrElse(NoTree(expr.getType)), post) + case _ => body.getOrElse(NoTree(expr.getType)) } def withoutSpec(expr: Expr) = expr match { - case Require(pre, b) => Option(b).filterNot(_.isInstanceOf[NoTree]) - case Ensuring(Require(pre, b), i, post) => Option(b).filterNot(_.isInstanceOf[NoTree]) - case Ensuring(b, i, post) => Option(b).filterNot(_.isInstanceOf[NoTree]) - case b => Option(b).filterNot(_.isInstanceOf[NoTree]) + case Require(pre, b) => Option(b).filterNot(_.isInstanceOf[NoTree]) + case Ensuring(Require(pre, b), post) => Option(b).filterNot(_.isInstanceOf[NoTree]) + case Ensuring(b, post) => Option(b).filterNot(_.isInstanceOf[NoTree]) + case b => Option(b).filterNot(_.isInstanceOf[NoTree]) } def preconditionOf(expr: Expr) = expr match { - case Require(pre, _) => Some(pre) - case Ensuring(Require(pre, _), _, _) => Some(pre) - case b => None + case Require(pre, _) => Some(pre) + case Ensuring(Require(pre, _), _) => Some(pre) + case b => None } def postconditionOf(expr: Expr) = expr match { - case Ensuring(_, i, post) => Some((i, post)) - case _ => None + case Ensuring(_, post) => Some(post) + case _ => None } def breakDownSpecs(e : Expr) = (preconditionOf(e), withoutSpec(e), postconditionOf(e)) @@ -1914,7 +1897,7 @@ object TreeOps { case Application(caller, args) => val newArgs = args.map(rec(_, true)) val newCaller = rec(caller, false) - extract(Application(newCaller, newArgs), build) + extract(application(newCaller, newArgs), build) case FunctionInvocation(fd, args) => val newArgs = args.map(rec(_, true)) extract(FunctionInvocation(fd, newArgs), build) @@ -2013,113 +1996,6 @@ object TreeOps { @deprecated("Use exists instead", "Leon 0.2.1") def contains(e: Expr, matcher: Expr => Boolean): Boolean = exists(matcher)(e) - /** - * Eliminates tuples of arity 0 and 1. - * Used to simplify synthesis solutions - * - * Only rewrites local fundefs. - */ - @deprecated("Use purescala.Constructors.tuple* and purescala.Extractors.Unwrap* " + - "to avoid creation of tuples of size 0 and 1", "Leon 3.0.0" - ) - def rewriteTuples(expr: Expr) : Expr = { - def mapType(tt : TypeTree) : Option[TypeTree] = tt match { - case TupleType(ts) => ts.size match { - case 0 => Some(UnitType) - case 1 => Some(ts(0)) - case _ => - val tss = ts.map(mapType) - if(tss.exists(_.isDefined)) { - Some(TupleType((tss zip ts).map(p => p._1.getOrElse(p._2)))) - } else { - None - } - } - case SetType(t) => mapType(t).map(SetType(_)) - case MultisetType(t) => mapType(t).map(MultisetType(_)) - case ArrayType(t) => mapType(t).map(ArrayType(_)) - case MapType(f,t) => - val (f2,t2) = (mapType(f),mapType(t)) - if(f2.isDefined || t2.isDefined) { - Some(MapType(f2.getOrElse(f), t2.getOrElse(t))) - } else { - None - } - case ft : FunctionType => None // FIXME - - case a : AbstractClassType => None - case cct : CaseClassType => - // This is really just one big assertion. We don't rewrite class defs. - val fieldTypes = cct.fields.map(_.getType) - if(fieldTypes.exists(t => t match { - case TupleType(ts) if ts.size <= 1 => true - case _ => false - })) { - scala.sys.error("Cannot rewrite case class def that contains degenerate tuple types.") - } else { - None - } - case Untyped | BooleanType | Int32Type | IntegerType | UnitType | TypeParameter(_) => None - } - - var idMap = Map[Identifier, Identifier]() - var funDefMap = Map.empty[FunDef,FunDef] - - def fd2fd(funDef : FunDef) : FunDef = funDefMap.get(funDef) match { - case Some(fd) => fd - case None => - if(funDef.params.map(vd => mapType(vd.getType)).exists(_.isDefined)) { - scala.sys.error("Cannot rewrite function def that takes degenerate tuple arguments,") - } - val newFD = mapType(funDef.returnType) match { - case None => funDef - case Some(rt) => - val fd = new FunDef(FreshIdentifier(funDef.id.name, alwaysShowUniqueID = true), funDef.tparams, rt, funDef.params, funDef.defType) - // These will be taken care of in the recursive traversal. - fd.body = funDef.body - fd.precondition = funDef.precondition - funDef.postcondition match { - case Some((id, post)) => - val freshId = FreshIdentifier(id.name, rt, true) - idMap += id -> freshId - fd.postcondition = Some((freshId, post)) - case None => - fd.postcondition = None - } - fd - } - funDefMap = funDefMap.updated(funDef, newFD) - newFD - } - - import synthesis.Witnesses.Terminating - - def pre(e : Expr) : Expr = e match { - case Tuple(Seq()) => println("Tuple0!"); UnitLiteral() - case Variable(id) if idMap contains id => Variable(idMap(id)) - - case Error(tpe, err) => Error(mapType(tpe).getOrElse(e.getType), err).copiedFrom(e) - case Tuple(Seq(s)) => println("Tuple1!"); pre(s) - - case LetTuple(bs, v, bdy) if bs.size == 1 => - Let(bs(0), v, bdy) - - case l @ LetDef(fd, bdy) => - LetDef(fd2fd(fd), bdy) - - case FunctionInvocation(tfd, args) => - FunctionInvocation(fd2fd(tfd.fd).typed(tfd.tps), args) - - case Terminating(tfd, args) => - Terminating(fd2fd(tfd.fd).typed(tfd.tps), args) - - case _ => e - } - - simplePreTransform(pre)(expr) - } - - /* * Transforms complicated Ifs into multiple nested if blocks * It will decompose every OR clauses, and it will group AND clauses checking diff --git a/src/main/scala/leon/purescala/Trees.scala b/src/main/scala/leon/purescala/Trees.scala index 561e9c903..5773f1716 100644 --- a/src/main/scala/leon/purescala/Trees.scala +++ b/src/main/scala/leon/purescala/Trees.scala @@ -41,22 +41,40 @@ object Trees { val getType = body.getType } - case class Ensuring(body: Expr, id: Identifier, pred: Expr) extends Expr { - val getType = body.getType + case class Ensuring(body: Expr, pred: Expr) extends Expr { + val getType = pred.getType match { + case FunctionType(Seq(bodyType), BooleanType) if bodyType == body.getType => bodyType + case _ => Untyped + } + def toAssert: Expr = { + val res = FreshIdentifier("res", getType, true) + Let(res, body, Assert(application(pred, Seq(Variable(res))), Some("Postcondition failed @" + this.getPos), Variable(res))) + } } case class Assert(pred: Expr, error: Option[String], body: Expr) extends Expr { val getType = body.getType } - case class Choose(vars: List[Identifier], pred: Expr, var impl: Option[Expr] = None) extends Expr with NAryExtractable { - require(!vars.isEmpty) + case class Choose(pred: Expr, private var impl_ : Option[Expr] = None) extends Expr { + val getType = pred.getType match { + case FunctionType(from, to) if from.nonEmpty => // @mk why nonEmpty? + tupleTypeWrap(from) + case _ => + Untyped + } - val getType = tupleTypeWrap(vars.map(_.getType)) + require(impl_ forall { imp => isSubtypeOf(imp.getType, this.getType)}) - def extract = { - Some((Seq(pred)++impl, (es: Seq[Expr]) => Choose(vars, es.head, es.tail.headOption).setPos(this))) + def impl_= (newImpl: Option[Expr]) = { + require( + newImpl forall {imp => isSubtypeOf(imp.getType,this.getType)}, + newImpl.get +":" + newImpl.get.getType + " vs " + this + ":" + this.getType + ) + impl_ = newImpl } + + def impl = impl_ } /* Like vals */ diff --git a/src/main/scala/leon/purescala/TypeTreeOps.scala b/src/main/scala/leon/purescala/TypeTreeOps.scala index b61136cf0..451291d7d 100644 --- a/src/main/scala/leon/purescala/TypeTreeOps.scala +++ b/src/main/scala/leon/purescala/TypeTreeOps.scala @@ -279,9 +279,8 @@ object TypeTreeOps { val newId = freshId(id, tpeSub(id.getType)) Let(newId, srec(value), rec(idsMap + (id -> newId))(body)).copiedFrom(l) - case c @ Choose(xs, pred, oimpl) => - val newXs = xs.map(id => freshId(id, tpeSub(id.getType))) - Choose(newXs, rec(idsMap ++ (xs zip newXs))(pred), oimpl.map(srec)).copiedFrom(c) + case c @ Choose(pred, oimpl) => + Choose(rec(idsMap)(pred), oimpl.map(srec)).copiedFrom(c) case l @ Lambda(args, body) => val newArgs = args.map { arg => @@ -310,9 +309,8 @@ object TypeTreeOps { sys.error(s"Tried to substitute $tpar with $other within GenericValue $g") } - case ens @ Ensuring(body, id, pred) => - val newId = freshId(id, tpeSub(id.getType)) - Ensuring(srec(body), newId, rec(idsMap + (id -> newId))(pred)).copiedFrom(ens) + case ens @ Ensuring(body, pred) => + Ensuring(srec(body), rec(idsMap)(pred)).copiedFrom(ens) case s @ FiniteSet(elements) if elements.isEmpty => val SetType(tp) = s.getType diff --git a/src/main/scala/leon/repair/RepairNDEvaluator.scala b/src/main/scala/leon/repair/RepairNDEvaluator.scala index 4dd78d134..a2ed84fd7 100644 --- a/src/main/scala/leon/repair/RepairNDEvaluator.scala +++ b/src/main/scala/leon/repair/RepairNDEvaluator.scala @@ -47,14 +47,14 @@ class RepairNDEvaluator(ctx: LeonContext, prog: Program, fd : FunDef, cond: Expr def treat(subst : Expr => Expr) = { val callResult = e(subst(body))(frame, gctx) - if(tfd.hasPostcondition) { - val (id, post) = tfd.postcondition.get - - e(subst(post))(frame.withNewVar(id, callResult), gctx) match { - case BooleanLiteral(true) => - case BooleanLiteral(false) => throw RuntimeError("Postcondition violation for " + tfd.id.name + " reached in evaluation.") - case other => throw EvalError(typeErrorMsg(other, BooleanType)) - } + tfd.postcondition match { + case Some(post) => + e(subst(Application(post, Seq(callResult))))(frame, gctx) match { + case BooleanLiteral(true) => + case BooleanLiteral(false) => throw RuntimeError("Postcondition violation for " + tfd.id.name + " reached in evaluation.") + case other => throw EvalError(typeErrorMsg(other, BooleanType)) + } + case None => } callResult diff --git a/src/main/scala/leon/repair/RepairTrackingEvaluator.scala b/src/main/scala/leon/repair/RepairTrackingEvaluator.scala index 791b7ec01..22c4cb271 100644 --- a/src/main/scala/leon/repair/RepairTrackingEvaluator.scala +++ b/src/main/scala/leon/repair/RepairTrackingEvaluator.scala @@ -94,20 +94,20 @@ class RepairTrackingEvaluator(ctx: LeonContext, prog: Program) extends Recursive val callResult = e(body)(frameBlamingCallee, gctx) - if(tfd.hasPostcondition) { - val (id, post) = tfd.postcondition.get - - e(post)(frameBlamingCallee.withNewVar(id, callResult), gctx) match { - case BooleanLiteral(true) => - case BooleanLiteral(false) => - // Callee's fault - registerFailed(tfd.fd, evArgs) - throw RuntimeError("Postcondition violation for " + tfd.id.name + " reached in evaluation.") - case other => - // Callee's fault - registerFailed(tfd.fd, evArgs) - throw EvalError(typeErrorMsg(other, BooleanType)) - } + tfd.postcondition match { + case Some(post) => + e(Application(post, Seq(callResult)))(frameBlamingCallee, gctx) match { + case BooleanLiteral(true) => + case BooleanLiteral(false) => + // Callee's fault + registerFailed(tfd.fd, evArgs) + throw RuntimeError("Postcondition violation for " + tfd.id.name + " reached in evaluation.") + case other => + // Callee's fault + registerFailed(tfd.fd, evArgs) + throw EvalError(typeErrorMsg(other, BooleanType)) + } + case None => } callResult diff --git a/src/main/scala/leon/repair/Repairman.scala b/src/main/scala/leon/repair/Repairman.scala index 45dab972e..76656acdf 100644 --- a/src/main/scala/leon/repair/Repairman.scala +++ b/src/main/scala/leon/repair/Repairman.scala @@ -14,14 +14,14 @@ import purescala.Extractors.unwrapTuple import purescala.ScalaPrinter import evaluators._ import solvers._ -import utils._ import solvers.z3._ +import utils._ import codegen._ import verification._ import synthesis._ import synthesis.rules._ -import rules._ import synthesis.Witnesses._ +import rules._ import graph.DotGenerator import leon.utils.ASCIIHelpers.title @@ -225,13 +225,11 @@ class Repairman(ctx: LeonContext, initProgram: Program, fd: FunDef, verifTimeout val args = fd.params.map(_.id) val argsWrapped = tupleWrap(args.map(_.toVariable)) - val out = fd.postcondition.map(_._1).getOrElse(FreshIdentifier("res", fd.returnType, true)) - - val spec = fd.postcondition.map(_._2).getOrElse(BooleanLiteral(true)) + val spec = fd.postcondition.getOrElse(Lambda(Seq(ValDef(FreshIdentifier("res", fd.returnType, true))), BooleanLiteral(true))) val body = fd.body.get - val choose = Choose(List(out), spec) + val choose = Choose(spec) val evaluator = new DefaultEvaluator(ctx, program) @@ -260,23 +258,19 @@ class Repairman(ctx: LeonContext, initProgram: Program, fd: FunDef, verifTimeout soFar } - def focus(expr: Expr, env: Map[Identifier, Expr])(implicit spec: Expr, out: Identifier): (Expr, Expr) = { - val choose = Choose(List(out), spec) + def focus(expr: Expr, env: Map[Identifier, Expr])(implicit spec: Expr): (Expr, Expr) = { + val choose = Choose(spec) def testCondition(cond: Expr, inExpr: Expr => Expr) = forAllTests( - spec, - env + (out -> inExpr(not(cond))), + application(spec, Seq(inExpr(not(cond)))), + env, new RepairNDEvaluator(ctx,program,fd,cond) ) def condAsSpec(cond: Expr, inExpr: Expr => Expr) = { val newOut = FreshIdentifier("cond", BooleanType, true) - val newSpec = Let( - out, - inExpr(Variable(newOut)), - spec - ) - val (b, r) = focus(cond, env)(newSpec, newOut) + val newSpec = Lambda(Seq(ValDef(newOut)), application(spec, Seq(inExpr(Variable(newOut))))) + val (b, r) = focus(cond, env)(newSpec) (inExpr(b), r) } @@ -378,7 +372,7 @@ class Repairman(ctx: LeonContext, initProgram: Program, fd: FunDef, verifTimeout } } - focus(body, Map())(spec, out) + focus(body, Map())(spec) } private def getVerificationCounterExamples(fd: FunDef, prog: Program): VerificationResult = { diff --git a/src/main/scala/leon/solvers/templates/TemplateGenerator.scala b/src/main/scala/leon/solvers/templates/TemplateGenerator.scala index 741c750f3..1015b4a92 100644 --- a/src/main/scala/leon/solvers/templates/TemplateGenerator.scala +++ b/src/main/scala/leon/solvers/templates/TemplateGenerator.scala @@ -85,8 +85,8 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T]) { // Now the postcondition. val (condVars, exprVars, guardedExprs, lambdas) = tfd.postcondition match { - case Some((id, post)) => - val newPost : Expr = replace(Map(Variable(id) -> invocation), matchToIfThenElse(post)) + case Some(post) => + val newPost : Expr = application(matchToIfThenElse(post), Seq(invocation)) val postHolds : Expr = if(tfd.hasPrecondition) { @@ -119,7 +119,7 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T]) { private def appliedEquals(invocation: Expr, body: Expr): Expr = body match { case Lambda(args, lambdaBody) => - appliedEquals(Application(invocation, args.map(_.toVariable)), lambdaBody) + appliedEquals(application(invocation, args.map(_.toVariable)), lambdaBody) case _ => Equals(invocation, body) } @@ -190,8 +190,8 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T]) { storeGuarded(pathVar, rec(pathVar, cond)) rec(pathVar, body) - case e @ Ensuring(body, id, post) => - rec(pathVar, Let(id, body, Assert(post, None, Variable(id)))) + case e @ Ensuring(_, _) => + rec(pathVar, e.toAssert) case l @ Let(i, e : Lambda, b) => val re = rec(pathVar, e) // guaranteed variable! @@ -265,20 +265,13 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T]) { } } - case c @ Choose(ids, cond, Some(impl)) => + case c @ Choose(cond, Some(impl)) => rec(pathVar, impl) - case c @ Choose(ids, cond, None) => + case c @ Choose(cond, None) => val cid = FreshIdentifier("choose", c.getType, true) storeExpr(cid) - - val m: Map[Expr, Expr] = if (ids.size == 1) { - Map(Variable(ids.head) -> Variable(cid)) - } else { - ids.zipWithIndex.map{ case (id, i) => Variable(id) -> tupleSelect(Variable(cid), i+1, ids.size) }.toMap - } - - storeGuarded(pathVar, replace(m, cond)) + storeGuarded(pathVar, application(cond, Seq(Variable(cid)))) Variable(cid) case l @ Lambda(args, body) => diff --git a/src/main/scala/leon/synthesis/ConvertHoles.scala b/src/main/scala/leon/synthesis/ConvertHoles.scala index d0118c318..7b2450fac 100644 --- a/src/main/scala/leon/synthesis/ConvertHoles.scala +++ b/src/main/scala/leon/synthesis/ConvertHoles.scala @@ -41,7 +41,7 @@ object ConvertHoles extends LeonPhase[Program, Program] { val (pre, body, post) = breakDownSpecs(e) // Ensure that holes are not found in pre and/or post conditions - pre.foreach { + (pre ++ post).foreach { preTraversal{ case h : Hole => ctx.reporter.error("Holes are not supported in preconditions. @"+ h.getPos) @@ -49,14 +49,6 @@ object ConvertHoles extends LeonPhase[Program, Program] { } } - post.foreach { case (id, post) => - preTraversal{ - case h : Hole => - ctx.reporter.error("Holes are not supported in postconditions. @"+ h.getPos) - case _ => - }(post) - } - body match { case Some(body) => var holes = List[Identifier]() @@ -75,15 +67,15 @@ object ConvertHoles extends LeonPhase[Program, Program] { }(body) val asChoose = if (holes.nonEmpty) { - val cids = holes.map(_.freshen) + val cids: List[Identifier] = holes.map(_.freshen) val pred = post match { - case Some((id, post)) => - replaceFromIDs((holes zip cids.map(_.toVariable)).toMap, Let(id, withoutHoles, post)) + case Some(post) => + replaceFromIDs((holes zip cids.map(_.toVariable)).toMap, post) case None => - BooleanLiteral(true) + Lambda(cids.map(ValDef(_)), BooleanLiteral(true)) } - letTuple(holes, Choose(cids, pred), withoutHoles) + letTuple(holes, Choose(pred), withoutHoles) } else withoutHoles diff --git a/src/main/scala/leon/synthesis/ConvertWithOracles.scala b/src/main/scala/leon/synthesis/ConvertWithOracles.scala index f8072b71a..8b130c7ad 100644 --- a/src/main/scala/leon/synthesis/ConvertWithOracles.scala +++ b/src/main/scala/leon/synthesis/ConvertWithOracles.scala @@ -44,17 +44,17 @@ object ConvertWithOracle extends LeonPhase[Program, Program] { val body = preMap { case wo @ WithOracle(os, b) => withoutSpec(b) match { - case Some(body) => + case Some(pred) => val chooseOs = os.map(_.freshen) val pred = postconditionOf(b) match { - case Some((id, post)) => - replaceFromIDs((os zip chooseOs.map(_.toVariable)).toMap, Let(id, body, post)) + case Some(post) => + post // FIXME we need to freshen variables case None => - BooleanLiteral(true) + Lambda(chooseOs.map(ValDef(_)), BooleanLiteral(true)) } - Some(letTuple(os, Choose(chooseOs, pred), b)) + Some(letTuple(os, Choose(pred), b)) case None => None } @@ -73,12 +73,12 @@ object ConvertWithOracle extends LeonPhase[Program, Program] { } } - fd.postcondition.foreach { case (id, post) => + fd.postcondition.foreach { preTraversal{ case _: WithOracle => ctx.reporter.error("WithOracle expressions are not supported in postconditions. (function "+fd.id.asString(ctx)+")") case _ => - }(post) + } } }) diff --git a/src/main/scala/leon/synthesis/ExamplesFinder.scala b/src/main/scala/leon/synthesis/ExamplesFinder.scala index 679b6c569..91f21e4ee 100644 --- a/src/main/scala/leon/synthesis/ExamplesFinder.scala +++ b/src/main/scala/leon/synthesis/ExamplesFinder.scala @@ -23,7 +23,8 @@ class ExamplesFinder(ctx: LeonContext, program: Program) { val reporter = ctx.reporter def extractTests(fd: FunDef): (Seq[Example], Seq[Example]) = fd.postcondition match { - case Some((id, post)) => + case Some(Lambda(Seq(ValDef(id, _)), post)) => + // @mk FIXME: make this more general val tests = extractTestsOf(post) val insIds = fd.params.map(_.id).toSet diff --git a/src/main/scala/leon/synthesis/Problem.scala b/src/main/scala/leon/synthesis/Problem.scala index 042affe42..71c12af51 100644 --- a/src/main/scala/leon/synthesis/Problem.scala +++ b/src/main/scala/leon/synthesis/Problem.scala @@ -6,7 +6,7 @@ package synthesis import leon.purescala.Trees._ import leon.purescala.Definitions._ import leon.purescala.TreeOps._ -import leon.purescala.TypeTrees.TypeTree +import leon.purescala.TypeTrees._ import leon.purescala.Common._ import leon.purescala.Constructors._ import leon.purescala.Extractors._ @@ -28,13 +28,18 @@ case class Problem(as: List[Identifier], ws: Expr, pc: Expr, phi: Expr, xs: List object Problem { def fromChoose(ch: Choose, pc: Expr = BooleanLiteral(true)): Problem = { - val xs = ch.vars - val phi = simplifyLets(ch.pred) - val as = (variablesOf(And(pc, phi))--xs).toList + val xs = { + val tps = ch.pred.getType.asInstanceOf[FunctionType].from + tps map (FreshIdentifier("x", _, true)) + }.toList + + val phi = application(simplifyLets(ch.pred), xs map { _.toVariable}) + val as = (variablesOf(And(pc, phi)) -- xs).toList // FIXME do we need this at all? val TopLevelAnds(clauses) = pc + // @mk FIXME: Is this needed? val (pcs, wss) = clauses.partition { case w : Witness => false case _ => true diff --git a/src/main/scala/leon/synthesis/Solution.scala b/src/main/scala/leon/synthesis/Solution.scala index 5bf161e7b..bc93ce0c7 100644 --- a/src/main/scala/leon/synthesis/Solution.scala +++ b/src/main/scala/leon/synthesis/Solution.scala @@ -68,11 +68,11 @@ object Solution { def unapply(s: Solution): Option[(Expr, Set[FunDef], Expr)] = if (s eq null) None else Some((s.pre, s.defs, s.term)) def choose(p: Problem): Solution = { - new Solution(BooleanLiteral(true), Set(), Choose(p.xs, p.phi)) + new Solution(BooleanLiteral(true), Set(), Choose(Lambda(p.xs.map(ValDef(_)), p.phi))) } def chooseComplete(p: Problem): Solution = { - new Solution(BooleanLiteral(true), Set(), Choose(p.xs, and(p.pc, p.phi))) + new Solution(BooleanLiteral(true), Set(), Choose(Lambda(p.xs.map(ValDef(_)), and(p.pc, p.phi)))) } // Generate the simplest, wrongest solution, used for complexity lowerbound diff --git a/src/main/scala/leon/synthesis/Synthesizer.scala b/src/main/scala/leon/synthesis/Synthesizer.scala index c44e70059..6898a685f 100644 --- a/src/main/scala/leon/synthesis/Synthesizer.scala +++ b/src/main/scala/leon/synthesis/Synthesizer.scala @@ -111,7 +111,7 @@ class Synthesizer(val context : LeonContext, val fd = new FunDef(FreshIdentifier(ci.fd.id.name+"_final", alwaysShowUniqueID = true), Nil, ret, problem.as.map(ValDef(_)), DefType.MethodDef) fd.precondition = Some(and(problem.pc, sol.pre)) - fd.postcondition = Some((res.id, replace(mapPost, problem.phi))) + fd.postcondition = Some(Lambda(Seq(ValDef(res.id)), replace(mapPost, problem.phi))) fd.body = Some(sol.term) val newDefs = fd +: sol.defs.toList diff --git a/src/main/scala/leon/synthesis/rules/ADTInduction.scala b/src/main/scala/leon/synthesis/rules/ADTInduction.scala index a84797242..5dd0afe26 100644 --- a/src/main/scala/leon/synthesis/rules/ADTInduction.scala +++ b/src/main/scala/leon/synthesis/rules/ADTInduction.scala @@ -102,7 +102,7 @@ case object ADTInduction extends Rule("ADT Induction") { val outerPre = orJoin(globalPre) newFun.precondition = Some(funPre) - newFun.postcondition = Some((idPost, letTuple(p.xs.toSeq, Variable(idPost), funPost))) + newFun.postcondition = Some(Lambda(Seq(ValDef(idPost)), letTuple(p.xs.toSeq, Variable(idPost), funPost))) newFun.body = Some(matchExpr(Variable(inductOn), cases)) diff --git a/src/main/scala/leon/synthesis/rules/ADTLongInduction.scala b/src/main/scala/leon/synthesis/rules/ADTLongInduction.scala index 7fe22426f..6dcb79d5e 100644 --- a/src/main/scala/leon/synthesis/rules/ADTLongInduction.scala +++ b/src/main/scala/leon/synthesis/rules/ADTLongInduction.scala @@ -151,7 +151,7 @@ case object ADTLongInduction extends Rule("ADT Long Induction") { val outerPre = orJoin(globalPre) newFun.precondition = Some(funPre) - newFun.postcondition = Some((idPost, letTuple(p.xs.toSeq, Variable(idPost), funPost))) + newFun.postcondition = Some(Lambda(Seq(ValDef(idPost)), letTuple(p.xs.toSeq, Variable(idPost), funPost))) newFun.body = Some(matchExpr(Variable(inductOn), cases)) diff --git a/src/main/scala/leon/synthesis/rules/Assert.scala b/src/main/scala/leon/synthesis/rules/Assert.scala index 1671cc450..3e11b4d65 100644 --- a/src/main/scala/leon/synthesis/rules/Assert.scala +++ b/src/main/scala/leon/synthesis/rules/Assert.scala @@ -14,7 +14,6 @@ case object Assert extends NormalizingRule("Assert") { p.phi match { case TopLevelAnds(exprs) => val xsSet = p.xs.toSet - val (exprsA, others) = exprs.partition(e => (variablesOf(e) & xsSet).isEmpty) if (!exprsA.isEmpty) { diff --git a/src/main/scala/leon/synthesis/rules/EquivalentInputs.scala b/src/main/scala/leon/synthesis/rules/EquivalentInputs.scala index f4da2f71a..bb43d716b 100644 --- a/src/main/scala/leon/synthesis/rules/EquivalentInputs.scala +++ b/src/main/scala/leon/synthesis/rules/EquivalentInputs.scala @@ -66,9 +66,9 @@ case object EquivalentInputs extends NormalizingRule("EquivalentInputs") { // We are replacing foo(a) with b. We inject postcondition(foo)(a, b). val postsToInject = substs.collect { case (FunctionInvocation(tfd, args), e) if tfd.hasPostcondition => - val Some((id, post)) = tfd.postcondition + val Some(post) = tfd.postcondition - replaceFromIDs((tfd.params.map(_.id) zip args).toMap + (id -> e), post) + application(replaceFromIDs((tfd.params.map(_.id) zip args).toMap, post), Seq(e)) } if (substs.nonEmpty) { diff --git a/src/main/scala/leon/synthesis/rules/IntInduction.scala b/src/main/scala/leon/synthesis/rules/IntInduction.scala index 3f139099b..5a611320f 100644 --- a/src/main/scala/leon/synthesis/rules/IntInduction.scala +++ b/src/main/scala/leon/synthesis/rules/IntInduction.scala @@ -50,7 +50,7 @@ case object IntInduction extends Rule("Int Induction") { val idPost = FreshIdentifier("res", tpe) newFun.precondition = Some(preIn) - newFun.postcondition = Some((idPost, letTuple(p.xs.toSeq, Variable(idPost), p.phi))) + newFun.postcondition = Some(Lambda(Seq(ValDef(idPost)), letTuple(p.xs.toSeq, Variable(idPost), p.phi))) newFun.body = Some( IfExpr(Equals(Variable(inductOn), InfiniteIntegerLiteral(0)), diff --git a/src/main/scala/leon/termination/RelationBuilder.scala b/src/main/scala/leon/termination/RelationBuilder.scala index 801db775a..3f62a4a55 100644 --- a/src/main/scala/leon/termination/RelationBuilder.scala +++ b/src/main/scala/leon/termination/RelationBuilder.scala @@ -21,7 +21,7 @@ trait RelationBuilder { self: TerminationChecker with Strengthener => protected def funDefRelationSignature(fd: FunDef): RelationSignature = { val strengthenedCallees = self.program.callGraph.callees(fd).map(fd => fd -> strengthened(fd)) - (fd, fd.precondition, fd.body, fd.postcondition map {_._2}, self.terminates(fd).isGuaranteed, strengthenedCallees) + (fd, fd.precondition, fd.body, fd.postcondition, self.terminates(fd).isGuaranteed, strengthenedCallees) } private val relationCache : MutableMap[FunDef, (Set[Relation], RelationSignature)] = MutableMap.empty diff --git a/src/main/scala/leon/termination/SimpleTerminationChecker.scala b/src/main/scala/leon/termination/SimpleTerminationChecker.scala index 47bf1f198..caeb10e36 100644 --- a/src/main/scala/leon/termination/SimpleTerminationChecker.scala +++ b/src/main/scala/leon/termination/SimpleTerminationChecker.scala @@ -87,15 +87,11 @@ class SimpleTerminationChecker(context: LeonContext, program: Program) extends T // Now we apply a simple recipe: we check that in each (self) // call, at least one argument is of an ADT type and decreases. // Yes, it's that restrictive. - val callsOfInterest = { (oe: Option[Expr]) => - oe.map { e => - functionCallsOf( - simplifyLets( - matchToIfThenElse(e))).filter(_.tfd.fd == funDef) - } getOrElse Set.empty[FunctionInvocation] - } + val callsOfInterest = { (e: Expr) => + functionCallsOf(simplifyLets(matchToIfThenElse(e))).filter(_.tfd.fd == funDef) + } - val callsToAnalyze = callsOfInterest(funDef.body) ++ callsOfInterest(funDef.precondition) ++ callsOfInterest(funDef.postcondition map { _._2 }) + val callsToAnalyze = callsOfInterest(funDef.fullBody) val funDefArgsIDs = funDef.params.map(_.id).toSet diff --git a/src/main/scala/leon/termination/Strengthener.scala b/src/main/scala/leon/termination/Strengthener.scala index b45d0f874..d0f097d63 100644 --- a/src/main/scala/leon/termination/Strengthener.scala +++ b/src/main/scala/leon/termination/Strengthener.scala @@ -22,19 +22,20 @@ trait Strengthener { self : TerminationChecker with RelationComparator with Rela for (funDef <- sortedCallees if !strengthenedPost(funDef) && funDef.hasBody && self.terminates(funDef).isGuaranteed) { def strengthen(cmp: (Expr, Expr) => Expr): Boolean = { val old = funDef.postcondition - val (res, postcondition) = { - val (res, post) = old.getOrElse(FreshIdentifier("res", funDef.returnType) -> BooleanLiteral(true)) + val postcondition = { + val res = FreshIdentifier("res", funDef.returnType, true) + val post = old.map{application(_, Seq(Variable(res)))}.getOrElse(BooleanLiteral(true)) val args = funDef.params.map(_.toVariable) val sizePost = cmp(tupleWrap(funDef.params.map(_.toVariable)), res.toVariable) - (res, and(post, sizePost)) + Lambda(Seq(ValDef(res)), and(post, sizePost)) } - funDef.postcondition = Some(res -> postcondition) + funDef.postcondition = Some(postcondition) val prec = matchToIfThenElse(funDef.precondition.getOrElse(BooleanLiteral(true))) val body = matchToIfThenElse(funDef.body.get) val post = matchToIfThenElse(postcondition) - val formula = implies(prec, Let(res, body, post)) + val formula = implies(prec, application(post, Seq(body))) if (!solver.definitiveALL(formula)) { funDef.postcondition = old diff --git a/src/main/scala/leon/termination/StructuralSize.scala b/src/main/scala/leon/termination/StructuralSize.scala index 69c22f0a2..b255c58b6 100644 --- a/src/main/scala/leon/termination/StructuralSize.scala +++ b/src/main/scala/leon/termination/StructuralSize.scala @@ -39,10 +39,10 @@ trait StructuralSize { val body = simplifyLets(matchToIfThenElse(matchExpr(argument.toVariable, cases(argumentType)))) val postId = FreshIdentifier("res", IntegerType) - val postcondition = GreaterThan(Variable(postId), InfiniteIntegerLiteral(0)) + val postcondition = Lambda(Seq(ValDef(postId)), GreaterThan(Variable(postId), InfiniteIntegerLiteral(0))) fd.body = Some(body) - fd.postcondition = Some(postId, postcondition) + fd.postcondition = Some(postcondition) fd } } diff --git a/src/main/scala/leon/utils/TypingPhase.scala b/src/main/scala/leon/utils/TypingPhase.scala index b889f0fe4..1b9b86f26 100644 --- a/src/main/scala/leon/utils/TypingPhase.scala +++ b/src/main/scala/leon/utils/TypingPhase.scala @@ -49,15 +49,16 @@ object TypingPhase extends LeonPhase[Program, Program] { fd.postcondition = fd.returnType match { case cct : CaseClassType if cct.parent.isDefined => { - + val resId = FreshIdentifier("res", cct) fd.postcondition match { - case Some((id, p)) => - Some((id, and(CaseClassInstanceOf(cct, Variable(id)).setPos(p), p).setPos(p))) + case Some(p) => + Some(Lambda(Seq(ValDef(resId)), and( + application(p, Seq(Variable(resId))), + CaseClassInstanceOf(cct, Variable(resId)) + ).setPos(p)).setPos(p)) case None => - val resId = FreshIdentifier("res", cct) - - Some((resId, CaseClassInstanceOf(cct, Variable(resId)))) + Some(Lambda(Seq(ValDef(resId)), CaseClassInstanceOf(cct, Variable(resId)))) } } case _ => fd.postcondition diff --git a/src/main/scala/leon/utils/UnitElimination.scala b/src/main/scala/leon/utils/UnitElimination.scala index 2cca31c74..c85374400 100644 --- a/src/main/scala/leon/utils/UnitElimination.scala +++ b/src/main/scala/leon/utils/UnitElimination.scala @@ -23,13 +23,10 @@ object UnitElimination extends TransformationPhase { val newUnits = pgm.units map { u => u.copy(modules = u.modules.map { m => fun2FreshFun = Map() val allFuns = m.definedFunctions - //first introduce new signatures without Unit parameters allFuns.foreach(fd => { if(fd.returnType != UnitType && fd.params.exists(vd => vd.getType == UnitType)) { val freshFunDef = new FunDef(FreshIdentifier(fd.id.name), fd.tparams, fd.returnType, fd.params.filterNot(vd => vd.getType == UnitType), fd.defType).setPos(fd) - freshFunDef.precondition = fd.precondition //TODO: maybe removing unit from the conditions as well.. - freshFunDef.postcondition = fd.postcondition//TODO: maybe removing unit from the conditions as well.. freshFunDef.addAnnotation(fd.annotations.toSeq:_*) fun2FreshFun += (fd -> freshFunDef) } else { @@ -38,12 +35,11 @@ object UnitElimination extends TransformationPhase { }) //then apply recursively to the bodies - val newFuns = allFuns.flatMap(fd => if(fd.returnType == UnitType) Seq() else { - val newBody = fd.body.map(body => removeUnit(body)) + val newFuns = allFuns.collect{ case fd if fd.returnType != UnitType => val newFd = fun2FreshFun(fd) - newFd.body = newBody - Seq(newFd) - }) + newFd.fullBody = removeUnit(fd.fullBody) + newFd + } ModuleDef(m.id, m.definedClasses ++ newFuns, m.isStandalone ) })} diff --git a/src/main/scala/leon/verification/DefaultTactic.scala b/src/main/scala/leon/verification/DefaultTactic.scala index f6767cf2c..665cf21f0 100644 --- a/src/main/scala/leon/verification/DefaultTactic.scala +++ b/src/main/scala/leon/verification/DefaultTactic.scala @@ -18,9 +18,8 @@ class DefaultTactic(vctx: VerificationContext) extends Tactic(vctx) { def generatePostconditions(fd: FunDef): Seq[VerificationCondition] = { (fd.postcondition, fd.body) match { - case (Some((id, post)), Some(body)) => - val res = id.freshen - val vc = implies(precOrTrue(fd), Let(res, body, replace(Map(id.toVariable -> res.toVariable), post))) + case (Some(post), Some(body)) => + val vc = implies(precOrTrue(fd), application(post, Seq(body))) Seq(new VerificationCondition(vc, fd, VCPostcondition, this).setPos(post)) case _ => @@ -70,7 +69,7 @@ class DefaultTactic(vctx: VerificationContext) extends Tactic(vctx) { (a, kind, cond) case a @ Assert(cond, None, _) => (a, VCAssert, cond) // Only triggered for inner ensurings, general postconditions are handled by generatePostconditions - case a @ Ensuring(body, id, post) => (a, VCAssert, Let(id, body, post)) + case a @ Ensuring(body, post) => (a, VCAssert, Application(post, Seq(body))) }(body) calls.map { diff --git a/src/main/scala/leon/verification/InductionTactic.scala b/src/main/scala/leon/verification/InductionTactic.scala index 7137f6215..9d7a7d890 100644 --- a/src/main/scala/leon/verification/InductionTactic.scala +++ b/src/main/scala/leon/verification/InductionTactic.scala @@ -31,21 +31,20 @@ class InductionTactic(vctx: VerificationContext) extends DefaultTactic(vctx) { override def generatePostconditions(fd: FunDef): Seq[VerificationCondition] = { (fd.body, firstAbsClassDef(fd.params), fd.postcondition) match { - case (Some(b), Some((parentType, arg)), Some((id, p))) => - val post = p - val body = b - + case (Some(body), Some((parentType, arg)), Some(post)) => for (cct <- parentType.knownCCDescendents) yield { val selectors = selectorsOfParentType(parentType, cct, arg.toVariable) val subCases = selectors.map { sel => - val res = id.freshen replace(Map(arg.toVariable -> sel), - implies(precOrTrue(fd), Let(res, body, replace(Map(id.toVariable -> res.toVariable), post))) + implies(precOrTrue(fd), application(post, Seq(body))) ) } - val vc = implies(and(CaseClassInstanceOf(cct, arg.toVariable), precOrTrue(fd)), implies(andJoin(subCases), Let(id, body, post))) + val vc = implies( + and(CaseClassInstanceOf(cct, arg.toVariable), precOrTrue(fd)), + implies(andJoin(subCases), application(post, Seq(body))) + ) new VerificationCondition(vc, fd, VCPostcondition, this).setPos(fd) } diff --git a/src/main/scala/leon/xlang/ArrayTransformation.scala b/src/main/scala/leon/xlang/ArrayTransformation.scala index 513052d1e..32d64e95f 100644 --- a/src/main/scala/leon/xlang/ArrayTransformation.scala +++ b/src/main/scala/leon/xlang/ArrayTransformation.scala @@ -17,23 +17,14 @@ object ArrayTransformation extends TransformationPhase { val name = "Array Transformation" val description = "Add bound checking for array access and remove array update with side effect" - private var id2FreshId = Map[Identifier, Identifier]() - def apply(ctx: LeonContext, pgm: Program): Program = { - - id2FreshId = Map() - val allFuns = pgm.definedFunctions - allFuns.foreach(fd => { - id2FreshId = Map() - fd.precondition = fd.precondition.map(transform) - fd.body = fd.body.map(transform) - fd.postcondition = fd.postcondition.map { case (id, post) => (id, transform(post)) } + pgm.definedFunctions.foreach(fd => { + fd.fullBody = transform(fd.fullBody)(Map()) }) pgm } - - def transform(expr: Expr): Expr = (expr match { + def transform(expr: Expr)(implicit env: Map[Identifier, Identifier]): Expr = (expr match { case up@ArrayUpdate(a, i, v) => { val ra = transform(a) val ri = transform(i) @@ -45,15 +36,14 @@ object ArrayTransformation extends TransformationPhase { v.getType match { case ArrayType(_) => { val freshIdentifier = FreshIdentifier("t", i.getType) - id2FreshId += (i -> freshIdentifier) - LetVar(freshIdentifier, transform(v), transform(b)) + val newEnv = env + (i -> freshIdentifier) + LetVar(freshIdentifier, transform(v)(newEnv), transform(b)(newEnv)) } case _ => Let(i, transform(v), transform(b)) } } case v@Variable(i) => { - val freshId = id2FreshId.get(i).getOrElse(i) - Variable(freshId) + Variable(env.getOrElse(i, i)) } case LetVar(id, e, b) => { @@ -82,9 +72,7 @@ object ArrayTransformation extends TransformationPhase { matchExpr(scrutRec, csesRec).setPos(m) } case LetDef(fd, b) => { - fd.precondition = fd.precondition.map(transform) - fd.body = fd.body.map(transform) - fd.postcondition = fd.postcondition.map { case (id, post) => (id, transform(post)) } + fd.fullBody = transform(fd.fullBody) val rb = transform(b) LetDef(fd, rb) } diff --git a/src/main/scala/leon/xlang/EpsilonElimination.scala b/src/main/scala/leon/xlang/EpsilonElimination.scala index fec576c08..e90c53707 100644 --- a/src/main/scala/leon/xlang/EpsilonElimination.scala +++ b/src/main/scala/leon/xlang/EpsilonElimination.scala @@ -22,12 +22,12 @@ object EpsilonElimination extends TransformationPhase { allFuns.foreach(fd => fd.body.map(body => { val newBody = postMap{ case eps@Epsilon(pred, tpe) => - val freshName = FreshIdentifier("epsilon") - val newFunDef = new FunDef(freshName, Nil, tpe, Seq(), DefType.MethodDef) - val epsilonVar = EpsilonVariable(eps.getPos, tpe) - val resId = FreshIdentifier("res", tpe) + val freshName = FreshIdentifier("epsilon") + val newFunDef = new FunDef(freshName, Nil, tpe, Seq(), DefType.MethodDef) + val epsilonVar = EpsilonVariable(eps.getPos, tpe) + val resId = FreshIdentifier("res", tpe) val postcondition = replace(Map(epsilonVar -> Variable(resId)), pred) - newFunDef.postcondition = Some((resId, postcondition)) + newFunDef.postcondition = Some(Lambda(Seq(ValDef(resId)), postcondition)) Some(LetDef(newFunDef, FunctionInvocation(newFunDef.typed, Seq()))) case _ => diff --git a/src/main/scala/leon/xlang/ImperativeCodeElimination.scala b/src/main/scala/leon/xlang/ImperativeCodeElimination.scala index 67faab2e3..e2c3af95f 100644 --- a/src/main/scala/leon/xlang/ImperativeCodeElimination.scala +++ b/src/main/scala/leon/xlang/ImperativeCodeElimination.scala @@ -180,7 +180,7 @@ object ImperativeCodeElimination extends LeonPhase[Program, (Program, Set[FunDef val invariantPostcondition: Option[Expr] = wh.invariant.map(expr => replace(modifiedVars2ResultVars, expr)) whileFunDef.precondition = invariantPrecondition whileFunDef.postcondition = trivialPostcondition.map(expr => - (resVar.id, and(expr, invariantPostcondition match { + Lambda(Seq(ValDef(resVar.id)), and(expr, invariantPostcondition match { case Some(e) => e case None => BooleanLiteral(true) }))) @@ -237,7 +237,7 @@ object ImperativeCodeElimination extends LeonPhase[Program, (Program, Set[FunDef val (bodyRes, bodyScope, bodyFun) = toFunction(b) (bodyRes, (b2: Expr) => LetDef(newFd, bodyScope(b2)).copiedFrom(expr), bodyFun) } - case c @ Choose(ids, b, _) => { + case c @ Choose(b, _) => { //Recall that Choose cannot mutate variables from the scope (c, (b2: Expr) => b2, Map()) } diff --git a/src/test/resources/regression/verification/purescala/valid/Acc.scala b/src/test/resources/regression/verification/purescala/valid/Acc.scala index 681f75059..6dce9c294 100644 --- a/src/test/resources/regression/verification/purescala/valid/Acc.scala +++ b/src/test/resources/regression/verification/purescala/valid/Acc.scala @@ -7,7 +7,7 @@ object Acc { def putAside(x: BigInt, a: Acc): Acc = { require (x > 0 && notRed(a) && a.checking >= x) - Acc(a.checking - x, a.savings + x) + Acc(a.checking - x, a.savings + x) } ensuring { r => notRed(r) && sameTotal(a, r) } diff --git a/src/test/scala/leon/test/solvers/UnrollingSolverTests.scala b/src/test/scala/leon/test/solvers/UnrollingSolverTests.scala index d451a4cdc..8fb31f280 100644 --- a/src/test/scala/leon/test/solvers/UnrollingSolverTests.scala +++ b/src/test/scala/leon/test/solvers/UnrollingSolverTests.scala @@ -19,7 +19,7 @@ class UnrollingSolverTests extends LeonTestSuite { Plus(Variable(fx), FunctionInvocation(fDef.typed, Seq(Minus(Variable(fx), InfiniteIntegerLiteral(1))))), InfiniteIntegerLiteral(1) )) - fDef.postcondition = Some(fres -> GreaterThan(Variable(fres), InfiniteIntegerLiteral(0))) + fDef.postcondition = Some(Lambda(Seq(ValDef(fres)), GreaterThan(Variable(fres), InfiniteIntegerLiteral(0)))) private val program = Program( FreshIdentifier("Minimal"), -- GitLab