From 7f17d218ab92c1473754d79dfcb721c4c9dabd8a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mika=C3=ABl=20Mayer?= <a-mikmay@microsoft.com> Date: Tue, 8 Dec 2015 15:38:53 +0100 Subject: [PATCH] Added support for LetDef with mutually recursive functions. --- .../frontends/scalac/CodeExtraction.scala | 20 ++- src/main/scala/leon/purescala/ExprOps.scala | 67 +++++---- .../scala/leon/purescala/Expressions.scala | 6 +- .../scala/leon/purescala/Extractors.scala | 10 +- .../leon/purescala/FunctionClosure.scala | 8 +- .../leon/purescala/ScopeSimplifier.scala | 41 ++++-- src/main/scala/leon/purescala/TypeOps.scala | 47 +++--- src/main/scala/leon/synthesis/Solution.scala | 2 +- .../leon/termination/SelfCallsProcessor.scala | 2 +- .../transformations/StackSpacePhase.scala | 2 +- .../scala/leon/utils/UnitElimination.scala | 46 ++++-- .../scala/leon/xlang/EpsilonElimination.scala | 2 +- .../xlang/ImperativeCodeElimination.scala | 138 +++++++++++------- 13 files changed, 238 insertions(+), 153 deletions(-) diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala index c50aae653..81d146863 100644 --- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala +++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala @@ -1104,18 +1104,25 @@ trait CodeExtraction extends ASTExtractors { val newDctx = dctx.copy(tparams = dctx.tparams ++ tparamsMap) + val restTree = rest match { + case Some(rst) => extractTree(rst) + case None => UnitLiteral() + } + rest = None + val oldCurrentFunDef = currentFunDef val funDefWithBody = extractFunBody(fd, params, b)(newDctx) currentFunDef = oldCurrentFunDef - - val restTree = rest match { - case Some(rst) => extractTree(rst) - case None => UnitLiteral() + + val (other_fds, block) = restTree match { + case LetDef(fds, block) => + (fds, block) + case _ => + (Nil, restTree) } - rest = None - LetDef(funDefWithBody, restTree) + LetDef(funDefWithBody +: other_fds, block) // FIXME case ExDefaultValueFunction @@ -1495,6 +1502,7 @@ trait CodeExtraction extends ASTExtractors { Implies(extractTree(lhs), extractTree(rhs)).setPos(current.pos) case c @ ExCall(rec, sym, tps, args) => + // The object on which it is called is null if the symbol sym is a valid function in the scope and not a method. val rrec = rec match { case t if (defsToDefs contains sym) && !isMethod(sym) => null diff --git a/src/main/scala/leon/purescala/ExprOps.scala b/src/main/scala/leon/purescala/ExprOps.scala index a7a20e2d4..0d096025a 100644 --- a/src/main/scala/leon/purescala/ExprOps.scala +++ b/src/main/scala/leon/purescala/ExprOps.scala @@ -309,9 +309,11 @@ object ExprOps { def preTransformWithBinders(f: (Expr, Set[Identifier]) => Expr, initBinders: Set[Identifier] = Set())(e: Expr) = { import xlang.Expressions.LetVar def rec(binders: Set[Identifier], e: Expr): Expr = (f(e, binders) match { - case LetDef(fd, bd) => - fd.fullBody = rec(binders ++ fd.paramIds, fd.fullBody) - LetDef(fd, rec(binders, bd)) + case LetDef(fds, bd) => + fds.foreach(fd => { + fd.fullBody = rec(binders ++ fd.paramIds, fd.fullBody) + }) + LetDef(fds, rec(binders, bd)) case Let(i, v, b) => Let(i, rec(binders + i, v), rec(binders + i, b)) case LetVar(i, v, b) => @@ -346,7 +348,7 @@ object ExprOps { e match { case Variable(i) => subvs + i case Old(i) => subvs + i - case LetDef(fd, _) => subvs -- fd.params.map(_.id) + case LetDef(fds, _) => subvs -- fds.flatMap(_.params.map(_.id)) case Let(i, _, _) => subvs - i case LetVar(i, _, _) => subvs - i case MatchExpr(_, cses) => subvs -- cses.flatMap(_.pattern.binders) @@ -377,7 +379,7 @@ object ExprOps { /** Returns functions in directly nested LetDefs */ def directlyNestedFunDefs(e: Expr): Set[FunDef] = { fold[Set[FunDef]]{ - case (LetDef(fd,_), Seq(fromFd, fromBd)) => fromBd + fd + case (LetDef(fds,_), Seq(fromFds, fromBd)) => fromBd ++ fds case (_, subs) => subs.flatten.toSet }(e) } @@ -514,7 +516,7 @@ object ExprOps { (expr, idSeqs) => idSeqs.foldLeft(expr match { case Lambda(args, _) => args.map(_.id) case Forall(args, _) => args.map(_.id) - case LetDef(fd, _) => fd.paramIds + case LetDef(fds, _) => fds.flatMap(_.paramIds) case Let(i, _, _) => Seq(i) case MatchExpr(_, cses) => cses.flatMap(_.pattern.binders) case Passes(_, _, cses) => cses.flatMap(_.pattern.binders) @@ -1239,23 +1241,23 @@ object ExprOps { def pre(e : Expr) = e match { - case LetDef(fd, expr) if fd.hasPrecondition => - val pre = fd.precondition.get - - solver.solveVALID(pre) match { - case Some(true) => - fd.precondition = None + case LetDef(fds, expr) => + for(fd <- fds if fd.hasPrecondition) { + val pre = fd.precondition.get - case Some(false) => solver.solveSAT(pre) match { - case (Some(false), _) => - fd.precondition = Some(BooleanLiteral(false).copiedFrom(e)) - case _ => + solver.solveVALID(pre) match { + case Some(true) => + fd.precondition = None + + case Some(false) => solver.solveSAT(pre) match { + case (Some(false), _) => + fd.precondition = Some(BooleanLiteral(false).copiedFrom(e)) + case _ => + } + case None => } - case None => - } - - e - + } + e case IfExpr(cond, thenn, elze) => try { solver.solveVALID(cond) match { @@ -1630,9 +1632,15 @@ object ExprOps { isHomo(v1, v2) && isHomo(e1, e2)(map + (id1 -> id2)) - case (LetDef(fd1, e1), LetDef(fd2, e2)) => - fdHomo(fd1, fd2) && - isHomo(e1, e2)(map + (fd1.id -> fd2.id)) + case (LetDef(fds1, e1), LetDef(fds2, e2)) => + fds1.size == fds2.size && + { + val zipped = fds1.zip(fds2) + zipped.forall( fds => + fdHomo(fds._1, fds._2) + ) && + isHomo(e1, e2)(map ++ zipped.map(fds => fds._1.id -> fds._2.id)) + } case (MatchExpr(s1, cs1), MatchExpr(s2, cs2)) => cs1.size == cs2.size && isHomo(s1, s2) && casesMatch(cs1,cs2) @@ -1819,7 +1827,8 @@ object ExprOps { */ def flattenFunctions(fdOuter: FunDef, ctx: LeonContext, p: Program): FunDef = { fdOuter.body match { - case Some(LetDef(fdInner, FunctionInvocation(tfdInner2, args))) if fdInner == tfdInner2.fd => + case Some(LetDef(fdsInner, FunctionInvocation(tfdInner2, args))) if fdsInner.size == 1 && fdsInner.head == tfdInner2.fd => + val fdInner = fdsInner.head val argsDef = fdOuter.paramIds val argsCall = args.collect { case Variable(id) => id } @@ -2106,12 +2115,12 @@ object ExprOps { import synthesis.Witnesses.Terminating val res1 = preMap({ - case LetDef(fd, b) => - val nfd = fd.duplicate() + case LetDef(lfds, b) => + val nfds = lfds.map(fd => fd -> fd.duplicate()) - fds += fd -> nfd + fds ++= nfds - Some(LetDef(nfd, b)) + Some(LetDef(nfds.map(_._2), b)) case FunctionInvocation(tfd, args) => if (fds contains tfd.fd) { diff --git a/src/main/scala/leon/purescala/Expressions.scala b/src/main/scala/leon/purescala/Expressions.scala index 0915a0243..7030dfb94 100644 --- a/src/main/scala/leon/purescala/Expressions.scala +++ b/src/main/scala/leon/purescala/Expressions.scala @@ -159,12 +159,12 @@ object Expressions { } } - /** $encodingof `def ... = ...; ...` (local function definition) + /** $encodingof multiple `def ... = ...; ...` (local function definition and possibly mutually recursive) * - * @param fd The function definition. + * @param fds The function definitions. * @param body The body of the expression after the function */ - case class LetDef(fd: FunDef, body: Expr) extends Expr { + case class LetDef(fds: Seq[FunDef], body: Expr) extends Expr { val getType = body.getType } diff --git a/src/main/scala/leon/purescala/Extractors.scala b/src/main/scala/leon/purescala/Extractors.scala index 3a3bd7c7b..e2581dd8c 100644 --- a/src/main/scala/leon/purescala/Extractors.scala +++ b/src/main/scala/leon/purescala/Extractors.scala @@ -72,11 +72,13 @@ object Extractors { Some((Seq(body), (es: Seq[Expr]) => Forall(args, es.head))) /* Binary operators */ - case LetDef(fd, body) => Some(( - Seq(fd.fullBody, body), + case LetDef(fds, rest) => Some(( + fds.map(_.fullBody) ++ Seq(rest), (es: Seq[Expr]) => { - fd.fullBody = es(0) - LetDef(fd, es(1)) + for((fd, i) <- fds.zipWithIndex) { + fd.fullBody = es(i) + } + LetDef(fds, es(fds.length)) } )) case Equals(t1, t2) => diff --git a/src/main/scala/leon/purescala/FunctionClosure.scala b/src/main/scala/leon/purescala/FunctionClosure.scala index 65fd1de8a..0a62d8e03 100644 --- a/src/main/scala/leon/purescala/FunctionClosure.scala +++ b/src/main/scala/leon/purescala/FunctionClosure.scala @@ -26,13 +26,13 @@ object FunctionClosure extends TransformationPhase { private def close(fd: FunDef): Seq[FunDef] = { // Directly nested functions with their p.c. - val nestedWithPaths = { + val nestedWithPathsFull = { val funDefs = directlyNestedFunDefs(fd.fullBody) collectWithPC { - case LetDef(fd1, body) if funDefs(fd1) => fd1 + case LetDef(fd1, body) => fd1.filter(funDefs) }(fd.fullBody) - }.toMap - + } + val nestedWithPaths = (for((fds, path) <- nestedWithPathsFull; fd <- fds) yield (fd, path)).toMap val nestedFuns = nestedWithPaths.keys.toSeq // Transitively called funcions from each function diff --git a/src/main/scala/leon/purescala/ScopeSimplifier.scala b/src/main/scala/leon/purescala/ScopeSimplifier.scala index 9eca53058..7bae40c5b 100644 --- a/src/main/scala/leon/purescala/ScopeSimplifier.scala +++ b/src/main/scala/leon/purescala/ScopeSimplifier.scala @@ -34,23 +34,32 @@ class ScopeSimplifier extends Transformer { val sb = rec(b, scope.register(i -> si)) Let(si, se, sb) - case LetDef(fd: FunDef, body: Expr) => - val newId = genId(fd.id, scope) - var newScope = scope.register(fd.id -> newId) - - val newArgs = for(ValDef(id, tpe) <- fd.params) yield { - val newArg = genId(id, newScope) - newScope = newScope.register(id -> newArg) - ValDef(newArg, tpe) + case LetDef(fds, body: Expr) => + var newScope: Scope = scope + // First register all functions + val fds_newIds = for(fd <- fds) yield { + val newId = genId(fd.id, scope) + newScope = newScope.register(fd.id -> newId) + (fd, newId) } - - val newFd = fd.duplicate(id = newId, params = newArgs) - - newScope = newScope.registerFunDef(fd -> newFd) - - newFd.fullBody = rec(fd.fullBody, newScope) - - LetDef(newFd, rec(body, newScope)) + + val fds_mapping = for((fd, newId) <- fds_newIds) yield { + val newArgs = for(ValDef(id, tpe) <- fd.params) yield { + val newArg = genId(id, newScope) + newScope = newScope.register(id -> newArg) + ValDef(newArg, tpe) + } + + val newFd = fd.duplicate(id = newId, params = newArgs) + + newScope = newScope.registerFunDef(fd -> newFd) + (newFd, fd) + } + + for((newFd, fd) <- fds_mapping) { + newFd.fullBody = rec(fd.fullBody, newScope) + } + LetDef(fds_mapping.map(_._1), rec(body, newScope)) case MatchExpr(scrut, cases) => val rs = rec(scrut, scope) diff --git a/src/main/scala/leon/purescala/TypeOps.scala b/src/main/scala/leon/purescala/TypeOps.scala index 51bed3eaf..f15b3e5af 100644 --- a/src/main/scala/leon/purescala/TypeOps.scala +++ b/src/main/scala/leon/purescala/TypeOps.scala @@ -307,27 +307,38 @@ object TypeOps { val newId = freshId(id, tpeSub(id.getType)) Let(newId, srec(value), rec(idsMap + (id -> newId))(body)).copiedFrom(l) - case l @ LetDef(fd, bd) => - val id = fd.id.freshen - val tparams = fd.tparams map { p => - TypeParameterDef(tpeSub(p.tp).asInstanceOf[TypeParameter]) + case l @ LetDef(fds, bd) => + val fds_mapping = for(fd <- fds) yield { + val id = fd.id.freshen + val tparams = fd.tparams map { p => + TypeParameterDef(tpeSub(p.tp).asInstanceOf[TypeParameter]) + } + val returnType = tpeSub(fd.returnType) + val params = fd.params map (instantiateType(_, tps)) + val newFd = fd.duplicate(id, tparams, params, returnType) + val subCalls = preMap { + case fi @ FunctionInvocation(tfd, args) if tfd.fd == fd => + Some(FunctionInvocation(newFd.typed(tfd.tps), args).copiedFrom(fi)) + case _ => + None + } _ + (fd, newFd, subCalls) + } + // We group the subcalls functions all in once + val subCalls = (((None:Option[Expr => Expr]) /: fds_mapping) { + case (None, (_, _, subCalls)) => Some(subCalls) + case (Some(fn), (_, _, subCalls)) => Some(fn andThen subCalls) + }).get + + // We apply all the functions mappings at once + val newFds = for((fd, newFd, _) <- fds_mapping) yield { + val fullBody = rec(idsMap ++ fd.paramIds.zip(newFd.paramIds))(subCalls(fd.fullBody)) + newFd.fullBody = fullBody + newFd } - val returnType = tpeSub(fd.returnType) - val params = fd.params map (instantiateType(_, tps)) - val newFd = fd.duplicate(id, tparams, params, returnType) - - val subCalls = preMap { - case fi @ FunctionInvocation(tfd, args) if tfd.fd == fd => - Some(FunctionInvocation(newFd.typed(tfd.tps), args).copiedFrom(fi)) - case _ => - None - } _ - val fullBody = rec(idsMap ++ fd.paramIds.zip(newFd.paramIds))(subCalls(fd.fullBody)) - newFd.fullBody = fullBody - val newBd = srec(subCalls(bd)).copiedFrom(bd) - LetDef(newFd, newBd).copiedFrom(l) + LetDef(newFds, newBd).copiedFrom(l) case l @ Lambda(args, body) => val newArgs = args.map { arg => diff --git a/src/main/scala/leon/synthesis/Solution.scala b/src/main/scala/leon/synthesis/Solution.scala index e87a7883d..ea52b9bab 100644 --- a/src/main/scala/leon/synthesis/Solution.scala +++ b/src/main/scala/leon/synthesis/Solution.scala @@ -31,7 +31,7 @@ class Solution(val pre: Expr, val defs: Set[FunDef], val term: Expr, val isTrust } def toExpr = { - defs.foldLeft(guardedTerm){ case (t, fd) => LetDef(fd, t) } + LetDef(defs.toList, guardedTerm) } // Projects a solution (ignore several output variables) diff --git a/src/main/scala/leon/termination/SelfCallsProcessor.scala b/src/main/scala/leon/termination/SelfCallsProcessor.scala index 157dbce24..320c230c2 100644 --- a/src/main/scala/leon/termination/SelfCallsProcessor.scala +++ b/src/main/scala/leon/termination/SelfCallsProcessor.scala @@ -30,7 +30,7 @@ class SelfCallsProcessor(val checker: TerminationChecker) extends Processor { def rec(e0: Expr): Boolean = e0 match { case Assert(pred: Expr, error: Option[String], body: Expr) => rec(pred) || rec(body) case Let(binder: Identifier, value: Expr, body: Expr) => rec(value) || rec(body) - case LetDef(fd: FunDef, body: Expr) => rec(body) // don't enter fd because we don't know if it will be called + case LetDef(fds, body: Expr) => rec(body) // don't enter fds because we don't know if it will be called case FunctionInvocation(tfd: TypedFunDef, args: Seq[Expr]) => tfd.fd == f /* <-- success in proving non-termination */ || args.exists(arg => rec(arg)) || (tfd.fd.hasBody && (!seenFunDefs.contains(tfd.fd)) && { diff --git a/src/main/scala/leon/transformations/StackSpacePhase.scala b/src/main/scala/leon/transformations/StackSpacePhase.scala index f4edb7ee1..43fb3930e 100644 --- a/src/main/scala/leon/transformations/StackSpacePhase.scala +++ b/src/main/scala/leon/transformations/StackSpacePhase.scala @@ -141,7 +141,7 @@ class StackSpaceInstrumenter(p: Program, si: SerialInstrumenter) extends Instrum (1 + valTemp + bodyTemp, Math.max(valStack, bodyStack)) } - case LetDef(fd: FunDef, body: Expr) => { + case LetDef(fds, body: Expr) => { // The function definition does not take up stack space. Goes into the constant pool estimateTemporaries(body) } diff --git a/src/main/scala/leon/utils/UnitElimination.scala b/src/main/scala/leon/utils/UnitElimination.scala index 3d486b57f..f4f603393 100644 --- a/src/main/scala/leon/utils/UnitElimination.scala +++ b/src/main/scala/leon/utils/UnitElimination.scala @@ -93,25 +93,39 @@ object UnitElimination extends TransformationPhase { } } - case LetDef(fd, b) => - if(fd.returnType == UnitType) + case LetDef(fds, b) => + val nonUnits = fds.filter(fd => fd.returnType != UnitType) + if(nonUnits.isEmpty) { removeUnit(b) - else { - val (newFd, rest) = if(fd.params.exists(vd => vd.getType == UnitType)) { - val freshFunDef = fd.duplicate(params = fd.params.filterNot(vd => vd.getType == UnitType)) - fun2FreshFun += (fd -> freshFunDef) - freshFunDef.fullBody = removeUnit(fd.fullBody) - val restRec = removeUnit(b) - fun2FreshFun -= fd - (freshFunDef, restRec) - } else { - fun2FreshFun += (fd -> fd) - fd.body = fd.body.map(b => removeUnit(b)) - val restRec = removeUnit(b) + } else { + val fdtoFreshFd = for(fd <- nonUnits) yield { + val m = if(fd.params.exists(vd => vd.getType == UnitType)) { + val freshFunDef = fd.duplicate(params = fd.params.filterNot(vd => vd.getType == UnitType)) + fd -> freshFunDef + } else { + fd -> fd + } + fun2FreshFun += m + m + } + for((fd, freshFunDef) <- fdtoFreshFd) { + if(fd.params.exists(vd => vd.getType == UnitType)) { + freshFunDef.fullBody = removeUnit(fd.fullBody) + } else { + fd.body = fd.body.map(b => removeUnit(b)) + } + } + val rest = removeUnit(b) + val newFds = for((fd, freshFunDef) <- fdtoFreshFd) yield { fun2FreshFun -= fd - (fd, restRec) + if(fd.params.exists(vd => vd.getType == UnitType)) { + freshFunDef + } else { + fd + } } - LetDef(newFd, rest) + + LetDef(newFds, rest) } case ite@IfExpr(cond, tExpr, eExpr) => diff --git a/src/main/scala/leon/xlang/EpsilonElimination.scala b/src/main/scala/leon/xlang/EpsilonElimination.scala index a09eeb637..51b23be1b 100644 --- a/src/main/scala/leon/xlang/EpsilonElimination.scala +++ b/src/main/scala/leon/xlang/EpsilonElimination.scala @@ -30,7 +30,7 @@ object EpsilonElimination extends UnitPhase[Program] { }.toMap ++ Seq((epsilonVar, Variable(resId))) val postcondition = replace(eMap, pred) newFunDef.postcondition = Some(Lambda(Seq(ValDef(resId)), postcondition)) - LetDef(newFunDef, FunctionInvocation(newFunDef.typed, bSeq map Variable)) + LetDef(Seq(newFunDef), FunctionInvocation(newFunDef.typed, bSeq map Variable)) case (other, _) => other }, fd.paramIds.toSet)(fd.fullBody) diff --git a/src/main/scala/leon/xlang/ImperativeCodeElimination.scala b/src/main/scala/leon/xlang/ImperativeCodeElimination.scala index e9802d500..600640c7b 100644 --- a/src/main/scala/leon/xlang/ImperativeCodeElimination.scala +++ b/src/main/scala/leon/xlang/ImperativeCodeElimination.scala @@ -28,9 +28,9 @@ object ImperativeCodeElimination extends UnitPhase[Program] { } } - /* varsInScope refers to variable declared in the same level scope. - Typically, when entering a nested function body, the scope should be - reset to empty */ + /** varsInScope refers to variable declared in the same level scope. + * Typically, when entering a nested function body, the scope should be + * reset to empty */ private case class State( parent: FunDef, varsInScope: Set[Identifier], @@ -39,12 +39,14 @@ object ImperativeCodeElimination extends UnitPhase[Program] { def withVar(i: Identifier) = copy(varsInScope = varsInScope + i) def withFunDef(fd: FunDef, nfd: FunDef, ids: List[Identifier]) = copy(funDefsMapping = funDefsMapping + (fd -> (nfd, ids))) + def withFunDefs(fdNfd: Seq[(FunDef, (FunDef, List[Identifier]))]) = + copy(funDefsMapping = funDefsMapping ++ fdNfd) } - //return a "scope" consisting of purely functional code that defines potentially needed - //new variables (val, not var) and a mapping for each modified variable (var, not val :) ) - //to their new name defined in the scope. The first returned valued is the value of the expression - //that should be introduced as such in the returned scope (the val already refers to the new names) + /** Returns a "scope" consisting of purely functional code that defines potentially needed + * new variables (val, not var) and a mapping for each modified variable (var, not val :) ) + * to their new name defined in the scope. The first returned valued is the value of the expression + * that should be introduced as such in the returned scope (the val already refers to the new names) */ private def toFunction(expr: Expr)(implicit state: State): (Expr, Expr => Expr, Map[Identifier, Identifier]) = { import state._ expr match { @@ -180,7 +182,7 @@ object ImperativeCodeElimination extends UnitPhase[Program] { val finalVars = modifiedVars.map(_.freshen) val finalScope = (body: Expr) => { val tupleId = FreshIdentifier("t", whileFunReturnType) - LetDef(whileFunDef, Let( + LetDef(Seq(whileFunDef), Let( tupleId, FunctionInvocation(whileFunDef.typed, modifiedVars.map(_.toVariable)).setPos(wh), finalVars.zipWithIndex.foldLeft(body) { (b, id) => @@ -262,58 +264,89 @@ object ImperativeCodeElimination extends UnitPhase[Program] { } - case LetDef(fd, b) => - - def fdWithoutSideEffects = { - fd.body.foreach { bd => - val (fdRes, fdScope, _) = toFunction(bd) - fd.body = Some(fdScope(fdRes)) + case LetDef(fds, b) => + def fdsWithoutSideEffects = { + for(fd <- fds) { + fd.body.foreach { bd => + val (fdRes, fdScope, _) = toFunction(bd) + fd.body = Some(fdScope(fdRes)) + } } val (bodyRes, bodyScope, bodyFun) = toFunction(b) - (bodyRes, (b2: Expr) => LetDef(fd, bodyScope(b2)).setPos(fd).copiedFrom(expr), bodyFun) + (bodyRes, (b2: Expr) => LetDef(fds, bodyScope(b2)).setPos(fds.head).copiedFrom(expr), bodyFun) } - - fd.body match { - case Some(bd) => { - + if(fds.forall(_.body.isEmpty)) fdsWithoutSideEffects + else { + val modified_vars: Seq[(FunDef, List[Identifier])] = for(fd <- fds; bd <- fd.body) yield { val modifiedVars: List[Identifier] = collect[Identifier]({ case Assignment(v, _) => Set(v) case _ => Set() })(bd).intersect(state.varsInScope).toList - - if(modifiedVars.isEmpty) fdWithoutSideEffects else { - - val freshNames: List[Identifier] = modifiedVars.map(id => id.freshen) - - val newParams: Seq[ValDef] = fd.params ++ freshNames.map(n => ValDef(n)) - val freshVarDecls: List[Identifier] = freshNames.map(id => id.freshen) - - val rewritingMap: Map[Identifier, Identifier] = - modifiedVars.zip(freshVarDecls).toMap - val freshBody = - preMap({ - case Assignment(v, e) => rewritingMap.get(v).map(nv => Assignment(nv, e)) - case Variable(id) => rewritingMap.get(id).map(nid => Variable(nid)) - case _ => None - })(bd) - val wrappedBody = freshNames.zip(freshVarDecls).foldLeft(freshBody)((body, p) => { - LetVar(p._2, Variable(p._1), body) - }) - - val newReturnType = TupleType(fd.returnType :: modifiedVars.map(_.getType)) - - val newFd = new FunDef(fd.id.freshen, fd.tparams, newParams, newReturnType).setPos(fd) - - val (fdRes, fdScope, fdFun) = + (fd, modifiedVars) + } + if(modified_vars.forall(_._2.isEmpty)) fdsWithoutSideEffects else { + val freshNames: Seq[(FunDef, Seq[Identifier])] = modified_vars.map(fdmv => (fdmv._1, fdmv._2.map(id => id.freshen))) + + val newParams: Seq[(FunDef, Seq[ValDef])] = freshNames.map(fdfn => (fdfn._1, fdfn._1.params ++ fdfn._2.map(n => ValDef(n)))) + + val freshVarDecls: Seq[(FunDef, List[Identifier])] = freshNames.map(id => (id._1, id._2.map(_.freshen).toList)) + + val rewritingMap: Map[Identifier, Identifier] = + (modified_vars.zip(freshVarDecls).map{ + case ((fd, md), (_, fv)) => (fd, md.zip(fv).toMap) + }).map(_._2).foldLeft(Map[Identifier, Identifier]())(_ ++ _) + + //TODO: + + val freshBody: Seq[Option[Expr]] = for(fd <- fds) yield { + fd.body.map(bd => + preMap({ + case Assignment(v, e) => rewritingMap.get(v).map(nv => Assignment(nv, e)) + case Variable(id) => rewritingMap.get(id).map(nid => Variable(nid)) + case _ => None + })(bd)) + } + + val wrappedBody = freshBody.zip(freshNames).zip(freshVarDecls).map{ + case ((freshBodyOpt, (_, freshNames)), (_, freshVarDecls)) => + freshBodyOpt.map(freshBody => freshNames.zip(freshVarDecls).foldLeft(freshBody)((body, p) => { + LetVar(p._2, Variable(p._1), body) + }))} + + val newReturnType = for((fd, modifiedVars) <- modified_vars) + yield TupleType(fd.returnType :: modifiedVars.map(_.getType)) + + val newFds = for(((fd, newParams), newReturnType) <- newParams.zip(newReturnType)) + yield (fd, new FunDef(fd.id.freshen, fd.tparams, newParams, newReturnType).setPos(fd)) + + val mappingToAdd: Seq[(FunDef, (FunDef, List[Identifier]))] = + for(((fd, newFd), (_, freshVarDecls)) <- newFds.zip(freshVarDecls)) yield (fd -> ((newFd, freshVarDecls.toList))) + + //Seq[Option[(fdRes, fdScope, fdFun)]] = + val fdsResScopeFun = for(wrappedBodyOpt <- wrappedBody) yield { + wrappedBodyOpt.map(wrappedBody => toFunction(wrappedBody)( State(state.parent, Set(), - state.funDefsMapping + (fd -> ((newFd, freshVarDecls)))) + state.funDefsMapping ++ mappingToAdd) ) - val newRes = Tuple(fdRes :: freshVarDecls.map(vd => fdFun(vd).toVariable)) - val newBody = fdScope(newRes) - - newFd.body = Some(newBody) + ) + } + + val newRes= for((optFdsResScopeFun, (_, freshVarDecls)) <- fdsResScopeFun.zip(freshVarDecls)) yield { + for((fdRes, fdScope, fdFun) <- optFdsResScopeFun) yield { + Tuple(fdRes :: freshVarDecls.map(vd => fdFun(vd).toVariable)) + } + } + val newbody = for((optFdsResScopeFun, newRes) <- fdsResScopeFun.zip(newRes)) yield { + for(newRes <- newRes; + (fdRes, fdScope, fdFun) <- optFdsResScopeFun) yield { + fdScope(newRes) + } + } + val fdForState = for(((((fd, newFd), optNewbody), (_, modifiedVars)), (_, freshNames)) + <- newFds.zip(newbody).zip(modified_vars).zip(freshNames)) yield { + newFd.body = optNewbody newFd.precondition = fd.precondition.map(prec => { replace(modifiedVars.zip(freshNames).map(p => (p._1.toVariable, p._2.toVariable)).toMap, prec) }) @@ -331,12 +364,11 @@ object ImperativeCodeElimination extends UnitPhase[Program] { postBody) Lambda(Seq(newRes), newBody).setPos(post) }) - - val (bodyRes, bodyScope, bodyFun) = toFunction(b)(state.withFunDef(fd, newFd, modifiedVars)) - (bodyRes, (b2: Expr) => LetDef(newFd, bodyScope(b2)).copiedFrom(expr), bodyFun) + (fd, (newFd, modifiedVars)) } + val (bodyRes, bodyScope, bodyFun) = toFunction(b)(state.withFunDefs(fdForState)) + (bodyRes, (b2: Expr) => LetDef(newFds.map(_._1), bodyScope(b2)).copiedFrom(expr), bodyFun) } - case None => fdWithoutSideEffects } case c @ Choose(b) => -- GitLab