From 7f17d218ab92c1473754d79dfcb721c4c9dabd8a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Mika=C3=ABl=20Mayer?= <a-mikmay@microsoft.com>
Date: Tue, 8 Dec 2015 15:38:53 +0100
Subject: [PATCH] Added support for LetDef with mutually recursive functions.

---
 .../frontends/scalac/CodeExtraction.scala     |  20 ++-
 src/main/scala/leon/purescala/ExprOps.scala   |  67 +++++----
 .../scala/leon/purescala/Expressions.scala    |   6 +-
 .../scala/leon/purescala/Extractors.scala     |  10 +-
 .../leon/purescala/FunctionClosure.scala      |   8 +-
 .../leon/purescala/ScopeSimplifier.scala      |  41 ++++--
 src/main/scala/leon/purescala/TypeOps.scala   |  47 +++---
 src/main/scala/leon/synthesis/Solution.scala  |   2 +-
 .../leon/termination/SelfCallsProcessor.scala |   2 +-
 .../transformations/StackSpacePhase.scala     |   2 +-
 .../scala/leon/utils/UnitElimination.scala    |  46 ++++--
 .../scala/leon/xlang/EpsilonElimination.scala |   2 +-
 .../xlang/ImperativeCodeElimination.scala     | 138 +++++++++++-------
 13 files changed, 238 insertions(+), 153 deletions(-)

diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala
index c50aae653..81d146863 100644
--- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala
+++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala
@@ -1104,18 +1104,25 @@ trait CodeExtraction extends ASTExtractors {
 
           val newDctx = dctx.copy(tparams = dctx.tparams ++ tparamsMap)
 
+          val restTree = rest match {
+            case Some(rst) => extractTree(rst)
+            case None => UnitLiteral()
+          }
+          rest = None
+          
           val oldCurrentFunDef = currentFunDef
 
           val funDefWithBody = extractFunBody(fd, params, b)(newDctx)
 
           currentFunDef = oldCurrentFunDef
-
-          val restTree = rest match {
-            case Some(rst) => extractTree(rst)
-            case None => UnitLiteral()
+          
+          val (other_fds, block) = restTree match {
+            case LetDef(fds, block) =>
+              (fds, block)
+            case _ =>
+              (Nil, restTree)
           }
-          rest = None
-          LetDef(funDefWithBody, restTree)
+          LetDef(funDefWithBody +: other_fds, block)
 
         // FIXME case ExDefaultValueFunction
 
@@ -1495,6 +1502,7 @@ trait CodeExtraction extends ASTExtractors {
           Implies(extractTree(lhs), extractTree(rhs)).setPos(current.pos)
 
         case c @ ExCall(rec, sym, tps, args) =>
+          // The object on which it is called is null if the symbol sym is a valid function in the scope and not a method.
           val rrec = rec match {
             case t if (defsToDefs contains sym) && !isMethod(sym) =>
               null
diff --git a/src/main/scala/leon/purescala/ExprOps.scala b/src/main/scala/leon/purescala/ExprOps.scala
index a7a20e2d4..0d096025a 100644
--- a/src/main/scala/leon/purescala/ExprOps.scala
+++ b/src/main/scala/leon/purescala/ExprOps.scala
@@ -309,9 +309,11 @@ object ExprOps {
   def preTransformWithBinders(f: (Expr, Set[Identifier]) => Expr, initBinders: Set[Identifier] = Set())(e: Expr) = {
     import xlang.Expressions.LetVar
     def rec(binders: Set[Identifier], e: Expr): Expr = (f(e, binders) match {
-      case LetDef(fd, bd) =>
-        fd.fullBody = rec(binders ++ fd.paramIds, fd.fullBody)
-        LetDef(fd, rec(binders, bd))
+      case LetDef(fds, bd) =>
+        fds.foreach(fd => {
+          fd.fullBody = rec(binders ++ fd.paramIds, fd.fullBody)
+        })
+        LetDef(fds, rec(binders, bd))
       case Let(i, v, b) =>
         Let(i, rec(binders + i, v), rec(binders + i, b))
       case LetVar(i, v, b) =>
@@ -346,7 +348,7 @@ object ExprOps {
         e match {
           case Variable(i) => subvs + i
           case Old(i) => subvs + i
-          case LetDef(fd, _) => subvs -- fd.params.map(_.id)
+          case LetDef(fds, _) => subvs -- fds.flatMap(_.params.map(_.id))
           case Let(i, _, _) => subvs - i
           case LetVar(i, _, _) => subvs - i
           case MatchExpr(_, cses) => subvs -- cses.flatMap(_.pattern.binders)
@@ -377,7 +379,7 @@ object ExprOps {
   /** Returns functions in directly nested LetDefs */
   def directlyNestedFunDefs(e: Expr): Set[FunDef] = {
     fold[Set[FunDef]]{
-      case (LetDef(fd,_), Seq(fromFd, fromBd)) => fromBd + fd
+      case (LetDef(fds,_), Seq(fromFds, fromBd)) => fromBd ++ fds
       case (_, subs) => subs.flatten.toSet
     }(e)
   }
@@ -514,7 +516,7 @@ object ExprOps {
       (expr, idSeqs) => idSeqs.foldLeft(expr match {
         case Lambda(args, _) => args.map(_.id)
         case Forall(args, _) => args.map(_.id)
-        case LetDef(fd, _) => fd.paramIds
+        case LetDef(fds, _) => fds.flatMap(_.paramIds)
         case Let(i, _, _) => Seq(i)
         case MatchExpr(_, cses) => cses.flatMap(_.pattern.binders)
         case Passes(_, _, cses) => cses.flatMap(_.pattern.binders)
@@ -1239,23 +1241,23 @@ object ExprOps {
 
     def pre(e : Expr) = e match {
 
-      case LetDef(fd, expr) if fd.hasPrecondition =>
-       val pre = fd.precondition.get
-
-        solver.solveVALID(pre) match {
-          case Some(true)  =>
-            fd.precondition = None
+      case LetDef(fds, expr) =>
+       for(fd <- fds if fd.hasPrecondition) {
+          val pre = fd.precondition.get
 
-          case Some(false) => solver.solveSAT(pre) match {
-            case (Some(false), _) =>
-              fd.precondition = Some(BooleanLiteral(false).copiedFrom(e))
-            case _ =>
+          solver.solveVALID(pre) match {
+            case Some(true)  =>
+              fd.precondition = None
+  
+            case Some(false) => solver.solveSAT(pre) match {
+              case (Some(false), _) =>
+                fd.precondition = Some(BooleanLiteral(false).copiedFrom(e))
+              case _ =>
+            }
+            case None =>
           }
-          case None =>
-        }
-
-        e
-
+       }
+       e
       case IfExpr(cond, thenn, elze) =>
         try {
           solver.solveVALID(cond) match {
@@ -1630,9 +1632,15 @@ object ExprOps {
           isHomo(v1, v2) &&
           isHomo(e1, e2)(map + (id1 -> id2))
 
-        case (LetDef(fd1, e1), LetDef(fd2, e2)) =>
-          fdHomo(fd1, fd2) &&
-          isHomo(e1, e2)(map + (fd1.id -> fd2.id))
+        case (LetDef(fds1, e1), LetDef(fds2, e2)) =>
+          fds1.size == fds2.size &&
+          {
+            val zipped = fds1.zip(fds2)
+            zipped.forall( fds =>
+            fdHomo(fds._1, fds._2)
+            ) &&
+            isHomo(e1, e2)(map ++ zipped.map(fds => fds._1.id -> fds._2.id))
+          }
 
         case (MatchExpr(s1, cs1), MatchExpr(s2, cs2)) =>
           cs1.size == cs2.size && isHomo(s1, s2) && casesMatch(cs1,cs2)
@@ -1819,7 +1827,8 @@ object ExprOps {
     */
   def flattenFunctions(fdOuter: FunDef, ctx: LeonContext, p: Program): FunDef = {
     fdOuter.body match {
-      case Some(LetDef(fdInner, FunctionInvocation(tfdInner2, args))) if fdInner == tfdInner2.fd =>
+      case Some(LetDef(fdsInner, FunctionInvocation(tfdInner2, args))) if fdsInner.size == 1 && fdsInner.head == tfdInner2.fd =>
+        val fdInner = fdsInner.head
         val argsDef  = fdOuter.paramIds
         val argsCall = args.collect { case Variable(id) => id }
 
@@ -2106,12 +2115,12 @@ object ExprOps {
 
     import synthesis.Witnesses.Terminating
     val res1 = preMap({
-      case LetDef(fd, b) =>
-        val nfd = fd.duplicate()
+      case LetDef(lfds, b) =>
+        val nfds = lfds.map(fd => fd -> fd.duplicate())
 
-        fds += fd -> nfd
+        fds ++= nfds
 
-        Some(LetDef(nfd, b))
+        Some(LetDef(nfds.map(_._2), b))
 
       case FunctionInvocation(tfd, args) =>
         if (fds contains tfd.fd) {
diff --git a/src/main/scala/leon/purescala/Expressions.scala b/src/main/scala/leon/purescala/Expressions.scala
index 0915a0243..7030dfb94 100644
--- a/src/main/scala/leon/purescala/Expressions.scala
+++ b/src/main/scala/leon/purescala/Expressions.scala
@@ -159,12 +159,12 @@ object Expressions {
     }
   }
 
-  /** $encodingof `def ... = ...; ...` (local function definition)
+  /** $encodingof multiple `def ... = ...; ...` (local function definition and possibly mutually recursive)
     *
-    * @param fd The function definition.
+    * @param fds The function definitions.
     * @param body The body of the expression after the function
     */
-  case class LetDef(fd: FunDef, body: Expr) extends Expr {
+  case class LetDef(fds: Seq[FunDef], body: Expr) extends Expr {
     val getType = body.getType
   }
 
diff --git a/src/main/scala/leon/purescala/Extractors.scala b/src/main/scala/leon/purescala/Extractors.scala
index 3a3bd7c7b..e2581dd8c 100644
--- a/src/main/scala/leon/purescala/Extractors.scala
+++ b/src/main/scala/leon/purescala/Extractors.scala
@@ -72,11 +72,13 @@ object Extractors {
         Some((Seq(body), (es: Seq[Expr]) => Forall(args, es.head)))
 
       /* Binary operators */
-      case LetDef(fd, body) => Some((
-        Seq(fd.fullBody, body),
+      case LetDef(fds, rest) => Some((
+        fds.map(_.fullBody) ++ Seq(rest),
         (es: Seq[Expr]) => {
-          fd.fullBody = es(0)
-          LetDef(fd, es(1))
+          for((fd, i) <- fds.zipWithIndex) {
+            fd.fullBody = es(i)
+          }
+          LetDef(fds, es(fds.length))
         }
       ))
       case Equals(t1, t2) =>
diff --git a/src/main/scala/leon/purescala/FunctionClosure.scala b/src/main/scala/leon/purescala/FunctionClosure.scala
index 65fd1de8a..0a62d8e03 100644
--- a/src/main/scala/leon/purescala/FunctionClosure.scala
+++ b/src/main/scala/leon/purescala/FunctionClosure.scala
@@ -26,13 +26,13 @@ object FunctionClosure extends TransformationPhase {
   private def close(fd: FunDef): Seq[FunDef] = { 
 
     // Directly nested functions with their p.c.
-    val nestedWithPaths = {
+    val nestedWithPathsFull = {
       val funDefs = directlyNestedFunDefs(fd.fullBody)
       collectWithPC {
-        case LetDef(fd1, body) if funDefs(fd1) => fd1
+        case LetDef(fd1, body) => fd1.filter(funDefs)
       }(fd.fullBody)
-    }.toMap
-    
+    }
+    val nestedWithPaths = (for((fds, path) <- nestedWithPathsFull; fd <- fds) yield (fd, path)).toMap
     val nestedFuns = nestedWithPaths.keys.toSeq
 
     // Transitively called funcions from each function
diff --git a/src/main/scala/leon/purescala/ScopeSimplifier.scala b/src/main/scala/leon/purescala/ScopeSimplifier.scala
index 9eca53058..7bae40c5b 100644
--- a/src/main/scala/leon/purescala/ScopeSimplifier.scala
+++ b/src/main/scala/leon/purescala/ScopeSimplifier.scala
@@ -34,23 +34,32 @@ class ScopeSimplifier extends Transformer {
       val sb = rec(b, scope.register(i -> si))
       Let(si, se, sb)
 
-    case LetDef(fd: FunDef, body: Expr) =>
-      val newId    = genId(fd.id, scope)
-      var newScope = scope.register(fd.id -> newId)
-
-      val newArgs = for(ValDef(id, tpe) <- fd.params) yield {
-        val newArg = genId(id, newScope)
-        newScope = newScope.register(id -> newArg)
-        ValDef(newArg, tpe)
+    case LetDef(fds, body: Expr) =>
+      var newScope: Scope = scope
+      // First register all functions
+      val fds_newIds = for(fd <- fds) yield {
+        val newId    = genId(fd.id, scope)
+        newScope = newScope.register(fd.id -> newId)
+        (fd, newId)
       }
-
-      val newFd = fd.duplicate(id = newId, params = newArgs)
-
-      newScope = newScope.registerFunDef(fd -> newFd)
-
-      newFd.fullBody = rec(fd.fullBody, newScope)
-
-      LetDef(newFd, rec(body, newScope))
+      
+      val fds_mapping = for((fd, newId) <- fds_newIds) yield {
+        val newArgs = for(ValDef(id, tpe) <- fd.params) yield {
+          val newArg = genId(id, newScope)
+          newScope = newScope.register(id -> newArg)
+          ValDef(newArg, tpe)
+        }
+  
+        val newFd = fd.duplicate(id = newId, params = newArgs)
+  
+        newScope = newScope.registerFunDef(fd -> newFd)
+        (newFd, fd)
+      }
+      
+      for((newFd, fd) <- fds_mapping) {
+        newFd.fullBody = rec(fd.fullBody, newScope)
+      }
+      LetDef(fds_mapping.map(_._1), rec(body, newScope))
    
     case MatchExpr(scrut, cases) =>
       val rs = rec(scrut, scope)
diff --git a/src/main/scala/leon/purescala/TypeOps.scala b/src/main/scala/leon/purescala/TypeOps.scala
index 51bed3eaf..f15b3e5af 100644
--- a/src/main/scala/leon/purescala/TypeOps.scala
+++ b/src/main/scala/leon/purescala/TypeOps.scala
@@ -307,27 +307,38 @@ object TypeOps {
             val newId = freshId(id, tpeSub(id.getType))
             Let(newId, srec(value), rec(idsMap + (id -> newId))(body)).copiedFrom(l)
 
-          case l @ LetDef(fd, bd) =>
-            val id = fd.id.freshen
-            val tparams = fd.tparams map { p => 
-              TypeParameterDef(tpeSub(p.tp).asInstanceOf[TypeParameter])
+          case l @ LetDef(fds, bd) =>
+            val fds_mapping = for(fd <- fds) yield {
+              val id = fd.id.freshen
+              val tparams = fd.tparams map { p => 
+                TypeParameterDef(tpeSub(p.tp).asInstanceOf[TypeParameter])
+              }
+              val returnType = tpeSub(fd.returnType)
+              val params = fd.params map (instantiateType(_, tps))
+              val newFd = fd.duplicate(id, tparams, params, returnType)
+              val subCalls = preMap {
+                case fi @ FunctionInvocation(tfd, args) if tfd.fd == fd =>
+                  Some(FunctionInvocation(newFd.typed(tfd.tps), args).copiedFrom(fi))
+                case _ => 
+                  None
+              } _
+              (fd, newFd, subCalls)
+            }
+            // We group the subcalls functions all in once
+            val subCalls = (((None:Option[Expr => Expr]) /: fds_mapping) {
+              case (None, (_, _, subCalls)) => Some(subCalls)
+              case (Some(fn), (_, _, subCalls)) => Some(fn andThen subCalls)
+            }).get
+            
+            // We apply all the functions mappings at once
+            val newFds = for((fd, newFd, _) <- fds_mapping) yield {
+              val fullBody = rec(idsMap ++ fd.paramIds.zip(newFd.paramIds))(subCalls(fd.fullBody))
+              newFd.fullBody = fullBody
+              newFd
             }
-            val returnType = tpeSub(fd.returnType)
-            val params = fd.params map (instantiateType(_, tps))
-            val newFd = fd.duplicate(id, tparams, params, returnType)
-
-            val subCalls = preMap {
-              case fi @ FunctionInvocation(tfd, args) if tfd.fd == fd =>
-                Some(FunctionInvocation(newFd.typed(tfd.tps), args).copiedFrom(fi))
-              case _ => 
-                None
-            } _
-            val fullBody = rec(idsMap ++ fd.paramIds.zip(newFd.paramIds))(subCalls(fd.fullBody))
-            newFd.fullBody = fullBody
-
             val newBd = srec(subCalls(bd)).copiedFrom(bd)
 
-            LetDef(newFd, newBd).copiedFrom(l)
+            LetDef(newFds, newBd).copiedFrom(l)
 
           case l @ Lambda(args, body) =>
             val newArgs = args.map { arg =>
diff --git a/src/main/scala/leon/synthesis/Solution.scala b/src/main/scala/leon/synthesis/Solution.scala
index e87a7883d..ea52b9bab 100644
--- a/src/main/scala/leon/synthesis/Solution.scala
+++ b/src/main/scala/leon/synthesis/Solution.scala
@@ -31,7 +31,7 @@ class Solution(val pre: Expr, val defs: Set[FunDef], val term: Expr, val isTrust
   }
 
   def toExpr = {
-    defs.foldLeft(guardedTerm){ case (t, fd) => LetDef(fd, t) }
+    LetDef(defs.toList, guardedTerm)
   }
 
   // Projects a solution (ignore several output variables)
diff --git a/src/main/scala/leon/termination/SelfCallsProcessor.scala b/src/main/scala/leon/termination/SelfCallsProcessor.scala
index 157dbce24..320c230c2 100644
--- a/src/main/scala/leon/termination/SelfCallsProcessor.scala
+++ b/src/main/scala/leon/termination/SelfCallsProcessor.scala
@@ -30,7 +30,7 @@ class SelfCallsProcessor(val checker: TerminationChecker) extends Processor {
     def rec(e0: Expr): Boolean = e0 match {
       case Assert(pred: Expr, error: Option[String], body: Expr) => rec(pred) || rec(body)
       case Let(binder: Identifier, value: Expr, body: Expr) => rec(value) || rec(body)
-      case LetDef(fd: FunDef, body: Expr) => rec(body) // don't enter fd because we don't know if it will be called
+      case LetDef(fds, body: Expr) => rec(body) // don't enter fds because we don't know if it will be called
       case FunctionInvocation(tfd: TypedFunDef, args: Seq[Expr]) =>
         tfd.fd == f /* <-- success in proving non-termination */ ||
         args.exists(arg => rec(arg)) || (tfd.fd.hasBody && (!seenFunDefs.contains(tfd.fd)) && {
diff --git a/src/main/scala/leon/transformations/StackSpacePhase.scala b/src/main/scala/leon/transformations/StackSpacePhase.scala
index f4edb7ee1..43fb3930e 100644
--- a/src/main/scala/leon/transformations/StackSpacePhase.scala
+++ b/src/main/scala/leon/transformations/StackSpacePhase.scala
@@ -141,7 +141,7 @@ class StackSpaceInstrumenter(p: Program, si: SerialInstrumenter) extends Instrum
         (1 + valTemp + bodyTemp, Math.max(valStack, bodyStack))
       }
 
-      case LetDef(fd: FunDef, body: Expr) => {
+      case LetDef(fds, body: Expr) => {
       // The function definition does not take up stack space. Goes into the constant pool
         estimateTemporaries(body)
       }
diff --git a/src/main/scala/leon/utils/UnitElimination.scala b/src/main/scala/leon/utils/UnitElimination.scala
index 3d486b57f..f4f603393 100644
--- a/src/main/scala/leon/utils/UnitElimination.scala
+++ b/src/main/scala/leon/utils/UnitElimination.scala
@@ -93,25 +93,39 @@ object UnitElimination extends TransformationPhase {
           }
         }
 
-      case LetDef(fd, b) =>
-        if(fd.returnType == UnitType) 
+      case LetDef(fds, b) =>
+        val nonUnits = fds.filter(fd => fd.returnType != UnitType)
+        if(nonUnits.isEmpty) {
           removeUnit(b)
-        else {
-          val (newFd, rest) = if(fd.params.exists(vd => vd.getType == UnitType)) {
-            val freshFunDef = fd.duplicate(params = fd.params.filterNot(vd => vd.getType == UnitType))
-            fun2FreshFun += (fd -> freshFunDef)
-            freshFunDef.fullBody = removeUnit(fd.fullBody)
-            val restRec = removeUnit(b)
-            fun2FreshFun -= fd
-            (freshFunDef, restRec)
-          } else {
-            fun2FreshFun += (fd -> fd)
-            fd.body = fd.body.map(b => removeUnit(b))
-            val restRec = removeUnit(b)
+        } else {
+          val fdtoFreshFd = for(fd <- nonUnits) yield {
+            val m = if(fd.params.exists(vd => vd.getType == UnitType)) {
+              val freshFunDef = fd.duplicate(params = fd.params.filterNot(vd => vd.getType == UnitType))
+              fd -> freshFunDef
+            } else {
+              fd -> fd
+            }
+            fun2FreshFun += m
+            m
+          }
+          for((fd, freshFunDef) <- fdtoFreshFd) {
+            if(fd.params.exists(vd => vd.getType == UnitType)) {
+              freshFunDef.fullBody = removeUnit(fd.fullBody)
+            } else {
+              fd.body = fd.body.map(b => removeUnit(b))
+            }
+          }
+          val rest = removeUnit(b)
+          val newFds = for((fd, freshFunDef) <- fdtoFreshFd) yield {
             fun2FreshFun -= fd
-            (fd, restRec)
+            if(fd.params.exists(vd => vd.getType == UnitType)) {
+              freshFunDef
+            } else {
+              fd
+            }
           }
-          LetDef(newFd, rest)
+          
+          LetDef(newFds, rest)
         }
 
       case ite@IfExpr(cond, tExpr, eExpr) =>
diff --git a/src/main/scala/leon/xlang/EpsilonElimination.scala b/src/main/scala/leon/xlang/EpsilonElimination.scala
index a09eeb637..51b23be1b 100644
--- a/src/main/scala/leon/xlang/EpsilonElimination.scala
+++ b/src/main/scala/leon/xlang/EpsilonElimination.scala
@@ -30,7 +30,7 @@ object EpsilonElimination extends UnitPhase[Program] {
           }.toMap ++ Seq((epsilonVar, Variable(resId)))
           val postcondition = replace(eMap, pred)
           newFunDef.postcondition = Some(Lambda(Seq(ValDef(resId)), postcondition))
-          LetDef(newFunDef, FunctionInvocation(newFunDef.typed, bSeq map Variable))
+          LetDef(Seq(newFunDef), FunctionInvocation(newFunDef.typed, bSeq map Variable))
 
         case (other, _) => other
       }, fd.paramIds.toSet)(fd.fullBody)
diff --git a/src/main/scala/leon/xlang/ImperativeCodeElimination.scala b/src/main/scala/leon/xlang/ImperativeCodeElimination.scala
index e9802d500..600640c7b 100644
--- a/src/main/scala/leon/xlang/ImperativeCodeElimination.scala
+++ b/src/main/scala/leon/xlang/ImperativeCodeElimination.scala
@@ -28,9 +28,9 @@ object ImperativeCodeElimination extends UnitPhase[Program] {
     }
   }
 
-  /* varsInScope refers to variable declared in the same level scope.
-     Typically, when entering a nested function body, the scope should be
-     reset to empty */
+  /** varsInScope refers to variable declared in the same level scope.
+    * Typically, when entering a nested function body, the scope should be
+    * reset to empty */
   private case class State(
     parent: FunDef, 
     varsInScope: Set[Identifier],
@@ -39,12 +39,14 @@ object ImperativeCodeElimination extends UnitPhase[Program] {
     def withVar(i: Identifier) = copy(varsInScope = varsInScope + i)
     def withFunDef(fd: FunDef, nfd: FunDef, ids: List[Identifier]) = 
       copy(funDefsMapping = funDefsMapping + (fd -> (nfd, ids)))
+    def withFunDefs(fdNfd: Seq[(FunDef, (FunDef, List[Identifier]))]) = 
+      copy(funDefsMapping = funDefsMapping ++ fdNfd)
   }
 
-  //return a "scope" consisting of purely functional code that defines potentially needed 
-  //new variables (val, not var) and a mapping for each modified variable (var, not val :) )
-  //to their new name defined in the scope. The first returned valued is the value of the expression
-  //that should be introduced as such in the returned scope (the val already refers to the new names)
+  /** Returns a "scope" consisting of purely functional code that defines potentially needed 
+    * new variables (val, not var) and a mapping for each modified variable (var, not val :) )
+    * to their new name defined in the scope. The first returned valued is the value of the expression
+    * that should be introduced as such in the returned scope (the val already refers to the new names) */
   private def toFunction(expr: Expr)(implicit state: State): (Expr, Expr => Expr, Map[Identifier, Identifier]) = {
     import state._
     expr match {
@@ -180,7 +182,7 @@ object ImperativeCodeElimination extends UnitPhase[Program] {
           val finalVars = modifiedVars.map(_.freshen)
           val finalScope = (body: Expr) => {
             val tupleId = FreshIdentifier("t", whileFunReturnType)
-            LetDef(whileFunDef, Let(
+            LetDef(Seq(whileFunDef), Let(
               tupleId,
               FunctionInvocation(whileFunDef.typed, modifiedVars.map(_.toVariable)).setPos(wh),
               finalVars.zipWithIndex.foldLeft(body) { (b, id) =>
@@ -262,58 +264,89 @@ object ImperativeCodeElimination extends UnitPhase[Program] {
         }
 
 
-      case LetDef(fd, b) =>
-
-        def fdWithoutSideEffects =  {
-          fd.body.foreach { bd =>
-            val (fdRes, fdScope, _) = toFunction(bd)
-            fd.body = Some(fdScope(fdRes))
+      case LetDef(fds, b) =>
+        def fdsWithoutSideEffects =  {
+          for(fd <- fds) {
+            fd.body.foreach { bd =>
+              val (fdRes, fdScope, _) = toFunction(bd)
+              fd.body = Some(fdScope(fdRes))
+            }
           }
           val (bodyRes, bodyScope, bodyFun) = toFunction(b)
-          (bodyRes, (b2: Expr) => LetDef(fd, bodyScope(b2)).setPos(fd).copiedFrom(expr), bodyFun)
+          (bodyRes, (b2: Expr) => LetDef(fds, bodyScope(b2)).setPos(fds.head).copiedFrom(expr), bodyFun)
         }
-
-        fd.body match {
-          case Some(bd) => {
-
+        if(fds.forall(_.body.isEmpty)) fdsWithoutSideEffects
+        else {
+          val modified_vars: Seq[(FunDef, List[Identifier])] = for(fd <- fds; bd <- fd.body) yield {
             val modifiedVars: List[Identifier] =
               collect[Identifier]({
                 case Assignment(v, _) => Set(v)
                 case _ => Set()
               })(bd).intersect(state.varsInScope).toList
-
-            if(modifiedVars.isEmpty) fdWithoutSideEffects else {
-
-              val freshNames: List[Identifier] = modifiedVars.map(id => id.freshen)
-
-              val newParams: Seq[ValDef] = fd.params ++ freshNames.map(n => ValDef(n))
-              val freshVarDecls: List[Identifier] = freshNames.map(id => id.freshen)
-
-              val rewritingMap: Map[Identifier, Identifier] =
-                modifiedVars.zip(freshVarDecls).toMap
-              val freshBody =
-                preMap({
-                  case Assignment(v, e) => rewritingMap.get(v).map(nv => Assignment(nv, e))
-                  case Variable(id) => rewritingMap.get(id).map(nid => Variable(nid))
-                  case _ => None
-                })(bd)
-              val wrappedBody = freshNames.zip(freshVarDecls).foldLeft(freshBody)((body, p) => {
-                LetVar(p._2, Variable(p._1), body)
-              })
-
-              val newReturnType = TupleType(fd.returnType :: modifiedVars.map(_.getType))
-
-              val newFd = new FunDef(fd.id.freshen, fd.tparams, newParams, newReturnType).setPos(fd)
-
-              val (fdRes, fdScope, fdFun) = 
+            (fd, modifiedVars)
+          }
+          if(modified_vars.forall(_._2.isEmpty)) fdsWithoutSideEffects else {
+            val freshNames: Seq[(FunDef, Seq[Identifier])] = modified_vars.map(fdmv => (fdmv._1, fdmv._2.map(id => id.freshen)))
+            
+            val newParams: Seq[(FunDef, Seq[ValDef])] = freshNames.map(fdfn => (fdfn._1, fdfn._1.params ++ fdfn._2.map(n => ValDef(n))))
+            
+            val freshVarDecls: Seq[(FunDef, List[Identifier])] = freshNames.map(id => (id._1, id._2.map(_.freshen).toList))
+            
+            val rewritingMap: Map[Identifier, Identifier] =
+                (modified_vars.zip(freshVarDecls).map{
+              case ((fd, md), (_, fv)) => (fd, md.zip(fv).toMap)
+            }).map(_._2).foldLeft(Map[Identifier, Identifier]())(_ ++ _)
+            
+            //TODO:
+            
+            val freshBody: Seq[Option[Expr]] = for(fd <- fds) yield {
+              fd.body.map(bd => 
+              preMap({
+                case Assignment(v, e) => rewritingMap.get(v).map(nv => Assignment(nv, e))
+                case Variable(id) => rewritingMap.get(id).map(nid => Variable(nid))
+                case _ => None
+              })(bd))
+            }
+            
+            val wrappedBody = freshBody.zip(freshNames).zip(freshVarDecls).map{
+              case ((freshBodyOpt, (_, freshNames)), (_, freshVarDecls)) =>
+                freshBodyOpt.map(freshBody => freshNames.zip(freshVarDecls).foldLeft(freshBody)((body, p) => {
+              LetVar(p._2, Variable(p._1), body)
+            }))}
+
+            val newReturnType = for((fd, modifiedVars) <- modified_vars)
+              yield TupleType(fd.returnType :: modifiedVars.map(_.getType))
+
+            val newFds = for(((fd, newParams), newReturnType) <- newParams.zip(newReturnType))
+              yield (fd, new FunDef(fd.id.freshen, fd.tparams, newParams, newReturnType).setPos(fd))
+              
+            val mappingToAdd: Seq[(FunDef, (FunDef, List[Identifier]))] =
+              for(((fd, newFd), (_, freshVarDecls)) <- newFds.zip(freshVarDecls)) yield (fd -> ((newFd, freshVarDecls.toList)))
+
+            //Seq[Option[(fdRes, fdScope, fdFun)]] = 
+            val fdsResScopeFun = for(wrappedBodyOpt <- wrappedBody) yield {
+              wrappedBodyOpt.map(wrappedBody => 
                 toFunction(wrappedBody)(
                   State(state.parent, Set(), 
-                        state.funDefsMapping + (fd -> ((newFd, freshVarDecls))))
+                        state.funDefsMapping ++ mappingToAdd)
                 )
-              val newRes = Tuple(fdRes :: freshVarDecls.map(vd => fdFun(vd).toVariable))
-              val newBody = fdScope(newRes)
-
-              newFd.body = Some(newBody)
+              )
+            }
+            
+            val newRes= for((optFdsResScopeFun, (_, freshVarDecls)) <- fdsResScopeFun.zip(freshVarDecls)) yield {
+              for((fdRes, fdScope, fdFun) <- optFdsResScopeFun) yield {
+                Tuple(fdRes :: freshVarDecls.map(vd => fdFun(vd).toVariable))
+              }
+            }
+            val newbody = for((optFdsResScopeFun, newRes) <- fdsResScopeFun.zip(newRes)) yield {
+              for(newRes <- newRes;
+                  (fdRes, fdScope, fdFun) <- optFdsResScopeFun) yield {
+                fdScope(newRes)
+              }
+            }
+            val fdForState = for(((((fd, newFd), optNewbody), (_, modifiedVars)), (_, freshNames))
+                <- newFds.zip(newbody).zip(modified_vars).zip(freshNames)) yield {
+              newFd.body = optNewbody
               newFd.precondition = fd.precondition.map(prec => {
                 replace(modifiedVars.zip(freshNames).map(p => (p._1.toVariable, p._2.toVariable)).toMap, prec)
               })
@@ -331,12 +364,11 @@ object ImperativeCodeElimination extends UnitPhase[Program] {
                   postBody)
                 Lambda(Seq(newRes), newBody).setPos(post)
               })
-
-              val (bodyRes, bodyScope, bodyFun) = toFunction(b)(state.withFunDef(fd, newFd, modifiedVars))
-              (bodyRes, (b2: Expr) => LetDef(newFd, bodyScope(b2)).copiedFrom(expr), bodyFun)
+              (fd, (newFd, modifiedVars))
             }
+            val (bodyRes, bodyScope, bodyFun) = toFunction(b)(state.withFunDefs(fdForState))
+            (bodyRes, (b2: Expr) => LetDef(newFds.map(_._1), bodyScope(b2)).copiedFrom(expr), bodyFun)
           }
-          case None => fdWithoutSideEffects
         }
 
       case c @ Choose(b) =>
-- 
GitLab