diff --git a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala index 1c3a6fd4876751e1ffc2534a83d52110288d3040..f457bb4b39c266b1f45e530495187b5ee92fad5d 100644 --- a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala +++ b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala @@ -83,9 +83,9 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int case Application(caller, args) => e(caller) match { - case Lambda(params, body) => + case l@Lambda(params, body) => val newArgs = args.map(e) - val mapping = (params.map(_.id) zip newArgs).toMap + val mapping = l.substitutions(newArgs) e(body)(rctx.withNewVars(mapping), gctx) case f => throw EvalError("Cannot apply non-lambda function " + f) @@ -142,10 +142,10 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int } gctx.stepsLeft -= 1 - val evArgs = args.map(a => e(a)) + val evArgs = args map e // build a mapping for the function... - val frame = rctx.newVars((tfd.params.map(_.id) zip evArgs).toMap) + val frame = rctx.newVars(tfd.paramSubst(evArgs)) if(tfd.hasPrecondition) { e(tfd.precondition.get)(frame, gctx) match { diff --git a/src/main/scala/leon/evaluators/TracingEvaluator.scala b/src/main/scala/leon/evaluators/TracingEvaluator.scala index b6bb072f0c9b14f329b92d45aea9212d09320d02..d6388822cf58f7ac89bb73a069e0c3c3f3cf7225 100644 --- a/src/main/scala/leon/evaluators/TracingEvaluator.scala +++ b/src/main/scala/leon/evaluators/TracingEvaluator.scala @@ -58,7 +58,7 @@ class TracingEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int = 1000) ex val evArgs = args.map(a => e(a)) // build a mapping for the function... - val frame = new TracingRecContext((tfd.params.map(_.id) zip evArgs).toMap, rctx.tracingFrames-1) + val frame = new TracingRecContext(tfd.paramSubst(evArgs), rctx.tracingFrames-1) if(tfd.hasPrecondition) { e(tfd.precondition.get)(frame, gctx) match { diff --git a/src/main/scala/leon/purescala/Common.scala b/src/main/scala/leon/purescala/Common.scala index 16f006cd4c7ef0eb7364e05c0eeb036002aa78f6..273128e0a3f975c203e01ae3ea1d73b5d2574044 100644 --- a/src/main/scala/leon/purescala/Common.scala +++ b/src/main/scala/leon/purescala/Common.scala @@ -6,7 +6,7 @@ package purescala import utils._ import Expressions.Variable import Types._ -import Definitions.{Program, Definition} +import Definitions.Program object Common { diff --git a/src/main/scala/leon/purescala/Definitions.scala b/src/main/scala/leon/purescala/Definitions.scala index 4ecc8067659de12231d01085b0161966c3f8a7fc..eda9e8d3e65a9668885ec07927ea318c411784ee 100644 --- a/src/main/scala/leon/purescala/Definitions.scala +++ b/src/main/scala/leon/purescala/Definitions.scala @@ -454,6 +454,22 @@ object Definitions { def translated(e: Expr): Expr = instantiateType(e, typesMap, paramsMap) + def paramSubst(realArgs: Seq[Expr]) = { + require(realArgs.size == params.size) + (params map { _.id } zip realArgs).toMap + } + + def withParamSubst(realArgs: Seq[Expr], e: Expr) = { + replaceFromIDs(paramSubst(realArgs), e) + } + + def applied(realArgs: Seq[Expr]): FunctionInvocation = { + FunctionInvocation(this, realArgs) + } + + def applied: FunctionInvocation = + applied(params map { _.toVariable }) + /** * Params will return ValDefs instantiated with the correct types * For such a ValDef(id,tp) it may hold that (id.getType != tp) diff --git a/src/main/scala/leon/purescala/ExprOps.scala b/src/main/scala/leon/purescala/ExprOps.scala index ad3216436a99594babc223f25b8b1bb4e539d2de..9c72b465c58908cd1bc3b20d910c0638a9fc0600 100644 --- a/src/main/scala/leon/purescala/ExprOps.scala +++ b/src/main/scala/leon/purescala/ExprOps.scala @@ -9,7 +9,6 @@ import Definitions._ import Expressions._ import Extractors._ import Constructors._ -import DefOps._ import utils.Simplifiers import solvers._ @@ -1862,8 +1861,8 @@ object ExprOps { Let(i, e, apply(b, args)) case LetTuple(is, es, b) => letTuple(is, es, apply(b, args)) - case Lambda(params, body) => - replaceFromIDs((params.map(_.id) zip args).toMap, body) + case l@Lambda(params, body) => + l.withSubstitutions(args, body) case _ => Application(expr, args) } diff --git a/src/main/scala/leon/purescala/Expressions.scala b/src/main/scala/leon/purescala/Expressions.scala index e852d55c0e71f58bcedf0ee0742897b572969a68..03385aec2645a0945ebd549e7f22a50bca8beb4b 100644 --- a/src/main/scala/leon/purescala/Expressions.scala +++ b/src/main/scala/leon/purescala/Expressions.scala @@ -9,6 +9,7 @@ import TypeOps._ import Definitions._ import Extractors._ import Constructors._ +import ExprOps.replaceFromIDs /** AST definitions for Pure Scala. */ object Expressions { @@ -112,6 +113,13 @@ object Expressions { case class Lambda(args: Seq[ValDef], body: Expr) extends Expr { val getType = FunctionType(args.map(_.getType), body.getType).unveilUntyped + def substitutions(realArgs: Seq[Expr]) = { + require(realArgs.size == args.size) + (args map { _.id } zip realArgs).toMap + } + def withSubstitutions(realArgs: Seq[Expr], e: Expr) = { + replaceFromIDs(substitutions(realArgs), e) + } } case class Forall(args: Seq[ValDef], body: Expr) extends Expr { diff --git a/src/main/scala/leon/repair/RepairNDEvaluator.scala b/src/main/scala/leon/repair/RepairNDEvaluator.scala index 6c9337b8f05fb9e811a1ea6b3d5100f0a63cb641..d3e0df1746b1fb0cb12c95764362a58b2215bd5b 100644 --- a/src/main/scala/leon/repair/RepairNDEvaluator.scala +++ b/src/main/scala/leon/repair/RepairNDEvaluator.scala @@ -28,7 +28,7 @@ class RepairNDEvaluator(ctx: LeonContext, prog: Program, fd : FunDef, cond: Expr val evArgs = args.map(a => e(a)) // build a mapping for the function... - val frame = rctx.newVars((tfd.params.map(_.id) zip evArgs).toMap) + val frame = rctx.newVars(tfd.paramSubst(evArgs)) if(tfd.hasPrecondition) { e(tfd.precondition.get)(frame, gctx) match { diff --git a/src/main/scala/leon/repair/RepairTrackingEvaluator.scala b/src/main/scala/leon/repair/RepairTrackingEvaluator.scala index bb52a5b7c1fcabf60904244b006ee94a20aa118d..52ac89493691c893d99e1964b5e269ab183b4e4a 100644 --- a/src/main/scala/leon/repair/RepairTrackingEvaluator.scala +++ b/src/main/scala/leon/repair/RepairTrackingEvaluator.scala @@ -60,7 +60,7 @@ class RepairTrackingEvaluator(ctx: LeonContext, prog: Program) extends Recursive val evArgs = args.map(a => e(a)) // build a mapping for the function... - val frameBlamingCaller = rctx.newVars((tfd.params.map(_.id) zip evArgs).toMap) + val frameBlamingCaller = rctx.newVars(tfd.paramSubst(evArgs)) if(tfd.hasPrecondition) { e(tfd.precondition.get)(frameBlamingCaller, gctx) match { diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4QuantifiedSolver.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4QuantifiedSolver.scala index d078f1fdb90fe25f813c85754243700422565954..93730b3c4f0c28c0e4c7623b9a999be33d0c4537 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4QuantifiedSolver.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4QuantifiedSolver.scala @@ -105,7 +105,7 @@ abstract class SMTLIBCVC4QuantifiedSolver(context: LeonContext, program: Program } { val term = implies( tfd.precondition getOrElse BooleanLiteral(true), - application(post, Seq(FunctionInvocation(tfd, tfd.params map { _.toVariable}))) + application(post, Seq(tfd.applied)) ) try { sendCommand(SMTAssert(quantifiedTerm(SMTForall, term))) diff --git a/src/main/scala/leon/termination/ChainBuilder.scala b/src/main/scala/leon/termination/ChainBuilder.scala index be8b27f3cfaceba92677e96d7182aa22b81235b6..8470a974459017ed90c3d6ac69c9e4ea703dc0d0 100644 --- a/src/main/scala/leon/termination/ChainBuilder.scala +++ b/src/main/scala/leon/termination/ChainBuilder.scala @@ -33,7 +33,7 @@ final case class Chain(relations: List[Relation]) { def rec(list: List[Relation], funDef: TypedFunDef, args: Seq[Expr]): Seq[(Seq[ValDef], Expr)] = list match { case Relation(_, _, fi @ FunctionInvocation(fitfd, nextArgs), _) :: xs => val tfd = TypedFunDef(fitfd.fd, fitfd.tps.map(funDef.translated)) - val subst = (tfd.params.map(_.id) zip args).toMap + val subst = tfd.paramSubst(args) val expr = replaceFromIDs(subst, hoistIte(expandLets(matchToIfThenElse(tfd.body.get)))) val mappedArgs = nextArgs.map(e => replaceFromIDs(subst, tfd.translated(e))) diff --git a/src/main/scala/leon/verification/DefaultTactic.scala b/src/main/scala/leon/verification/DefaultTactic.scala index 3026068a141841162064e46d6e9c3a462837bb93..2333c319170b5cd401d3dd177dd3104f78606a9b 100644 --- a/src/main/scala/leon/verification/DefaultTactic.scala +++ b/src/main/scala/leon/verification/DefaultTactic.scala @@ -31,7 +31,7 @@ class DefaultTactic(vctx: VerificationContext) extends Tactic(vctx) { calls.map { case ((fi @ FunctionInvocation(tfd, args), pre), path) => - val pre2 = replaceFromIDs((tfd.params.map(_.id) zip args).toMap, pre) + val pre2 = tfd.withParamSubst(args, pre) val vc = implies(and(precOrTrue(fd), path), pre2) val fiS = sizeLimit(fi.toString, 40) VC(vc, fd, VCKinds.Info(VCKinds.Precondition, s"call $fiS"), this).setPos(fi) diff --git a/src/main/scala/leon/verification/InductionTactic.scala b/src/main/scala/leon/verification/InductionTactic.scala index 4352adc27a6e877b9607009ed534d1ba93c7e773..529b9e5fc859d8cbb8f325bcc39d2b7b1e5cd618 100644 --- a/src/main/scala/leon/verification/InductionTactic.scala +++ b/src/main/scala/leon/verification/InductionTactic.scala @@ -72,13 +72,13 @@ class InductionTactic(vctx: VerificationContext) extends DefaultTactic(vctx) { val subCases = selectors.map { sel => replace(Map(arg.toVariable -> sel), - implies(precOrTrue(fd), replace((tfd.params.map(_.toVariable) zip args).toMap, pre)) + implies(precOrTrue(fd), tfd.withParamSubst(args, pre)) ) } val vc = implies( andJoin(Seq(CaseClassInstanceOf(cct, arg.toVariable), precOrTrue(fd), path) ++ subCases), - replace((tfd.params.map(_.toVariable) zip args).toMap, pre) + tfd.withParamSubst(args, pre) ) // Crop the call to display it properly