diff --git a/src/main/scala/leon/purescala/Definitions.scala b/src/main/scala/leon/purescala/Definitions.scala index fefbde0456bd65ea01d7e08c13230d402d417b5e..6467226a0a788866891cad6ec183a89569bc5a66 100644 --- a/src/main/scala/leon/purescala/Definitions.scala +++ b/src/main/scala/leon/purescala/Definitions.scala @@ -403,7 +403,9 @@ object Definitions { def directlyNestedFuns = directlyNestedFunDefs(fullBody) def subDefinitions = params ++ tparams ++ directlyNestedFuns.toList - /* Duplication */ + /** Duplication of this [[FunDef]]. + * @note This will not replace recursive function calls + */ def duplicate( id: Identifier = this.id.freshen, tparams: Seq[TypeParameterDef] = this.tparams, diff --git a/src/main/scala/leon/purescala/FunctionClosure.scala b/src/main/scala/leon/purescala/FunctionClosure.scala index e86fef5af7074e20d516ef43c95b9ce40965867b..20dbd1eeced9c951221665b2a3136b2225681c86 100644 --- a/src/main/scala/leon/purescala/FunctionClosure.scala +++ b/src/main/scala/leon/purescala/FunctionClosure.scala @@ -17,9 +17,15 @@ object FunctionClosure extends TransformationPhase { override val name: String = "Function Closure" override val description: String = "Closing function with its scoping variables" + /** Takes a FunDef and returns a Seq of all internal FunDef's contained in fd in closed form + * (and fd itself, without inned FunDef's). + * + * The strategy is as follows: Remove one layer of nested FunDef's, then call + * close recursively on the new functions. + */ private def close(fd: FunDef): Seq[FunDef] = { - // Directly neste functions with their p.c. + // Directly nested functions with their p.c. val nestedWithPaths = { val funDefs = directlyNestedFunDefs(fd.fullBody) collectWithPC { @@ -54,7 +60,7 @@ object FunctionClosure extends TransformationPhase { case (inner, pc) => inner -> step(inner, fd, pc, transFree(inner)) } - // Remove LetDefs + // Remove LetDefs from fd fd.fullBody = preMap({ case LetDef(fd, bd) => Some(bd) @@ -62,6 +68,7 @@ object FunctionClosure extends TransformationPhase { None }, applyRec = true)(fd.fullBody) + // A dummy substitution for fd, saying we should not change parameters val dummySubst = FunSubst( fd, Map.empty.withDefault(id => id), @@ -69,47 +76,52 @@ object FunctionClosure extends TransformationPhase { ) // Refresh function calls - (dummySubst +: closed.values.toSeq).foreach { case FunSubst(f, paramsMap, tparamsMap) => - //println(f) - //paramsMap foreach { case (from, to) => - // println(from.uniqueName + " -> " + to.uniqueName) - //} - f.fullBody = preMap { - case fi@FunctionInvocation(tfd, args) if closed contains tfd.fd => - val FunSubst(newFd, newParams, newTParams) = closed(tfd.fd) - - // New -> old map for function call - val mapReverse = newParams map { _.swap } - val extraArgs = newFd.paramIds.drop(args.size).map { id => - paramsMap(mapReverse(id)).toVariable - } - - // Similarly for type params - val tReverse = newTParams map { _.swap } - val tOrigExtraOrdered = newFd.tparams.map{_.tp}.drop(tfd.tps.length).map(tReverse) - val tFinalExtra: Seq[TypeParameter] = tOrigExtraOrdered.map( tp => - tparamsMap(tp) - ) - - Some(FunctionInvocation( - newFd.typed(tfd.tps ++ tFinalExtra), - args ++ extraArgs - ).copiedFrom(fi)) - case _ => None - }(f.fullBody) + (dummySubst +: closed.values.toSeq).foreach { + case FunSubst(f, callerMap, callerTMap) => + f.fullBody = preMap { + case fi@FunctionInvocation(tfd, args) if closed contains tfd.fd => + val FunSubst(newCallee, calleeMap, calleeTMap) = closed(tfd.fd) + + // This needs some explanation. + // Say we have caller and callee. First we find the param. substitutions of callee + // (say old -> calleeNew) and reverse them. So we have a mapping (calleeNew -> old). + // We also have the caller mapping, (old -> callerNew). + // So we pass the callee parameters through these two mappings to get the caller parameters. + val mapReverse = calleeMap map { _.swap } + val extraArgs = newCallee.paramIds.drop(args.size).map { id => + callerMap(mapReverse(id)).toVariable + } + + // Similarly for type params + val tReverse = calleeTMap map { _.swap } + val tOrigExtraOrdered = newCallee.tparams.map{_.tp}.drop(tfd.tps.length).map(tReverse) + val tFinalExtra: Seq[TypeParameter] = tOrigExtraOrdered.map( tp => + callerTMap(tp) + ) + + Some(FunctionInvocation( + newCallee.typed(tfd.tps ++ tFinalExtra), + args ++ extraArgs + ).copiedFrom(fi)) + case _ => None + }(f.fullBody) } val funs = closed.values.toSeq.map{ _.newFd } + // Recursively close new functions fd +: funs.flatMap(close) } + // Represents a substitution to a new function, along with parameter and type parameter + // mappings private case class FunSubst( newFd: FunDef, paramsMap: Map[Identifier, Identifier], tparamsMap: Map[TypeParameter, TypeParameter] ) + // Takes one inner function and closes it. private def step(inner: FunDef, outer: FunDef, pc: Expr, free: Seq[Identifier]): FunSubst = { val tpFresh = outer.tparams map { _.freshen }