diff --git a/src/main/scala/leon/purescala/Definitions.scala b/src/main/scala/leon/purescala/Definitions.scala
index fefbde0456bd65ea01d7e08c13230d402d417b5e..6467226a0a788866891cad6ec183a89569bc5a66 100644
--- a/src/main/scala/leon/purescala/Definitions.scala
+++ b/src/main/scala/leon/purescala/Definitions.scala
@@ -403,7 +403,9 @@ object Definitions {
     def directlyNestedFuns = directlyNestedFunDefs(fullBody)
     def subDefinitions = params ++ tparams ++ directlyNestedFuns.toList
 
-    /* Duplication */
+    /** Duplication of this [[FunDef]].
+      * @note This will not replace recursive function calls
+      */
     def duplicate(
       id: Identifier = this.id.freshen,
       tparams: Seq[TypeParameterDef] = this.tparams,
diff --git a/src/main/scala/leon/purescala/FunctionClosure.scala b/src/main/scala/leon/purescala/FunctionClosure.scala
index e86fef5af7074e20d516ef43c95b9ce40965867b..20dbd1eeced9c951221665b2a3136b2225681c86 100644
--- a/src/main/scala/leon/purescala/FunctionClosure.scala
+++ b/src/main/scala/leon/purescala/FunctionClosure.scala
@@ -17,9 +17,15 @@ object FunctionClosure extends TransformationPhase {
   override val name: String = "Function Closure"
   override val description: String = "Closing function with its scoping variables"
 
+  /** Takes a FunDef and returns a Seq of all internal FunDef's contained in fd in closed form
+    * (and fd itself, without inned FunDef's).
+    *
+    * The strategy is as follows: Remove one layer of nested FunDef's, then call
+    * close recursively on the new functions.
+    */
   private def close(fd: FunDef): Seq[FunDef] = { 
 
-    // Directly neste functions with their p.c.
+    // Directly nested functions with their p.c.
     val nestedWithPaths = {
       val funDefs = directlyNestedFunDefs(fd.fullBody)
       collectWithPC {
@@ -54,7 +60,7 @@ object FunctionClosure extends TransformationPhase {
       case (inner, pc) => inner -> step(inner, fd, pc, transFree(inner))
     }
 
-    // Remove LetDefs
+    // Remove LetDefs from fd
     fd.fullBody = preMap({
       case LetDef(fd, bd) =>
         Some(bd)
@@ -62,6 +68,7 @@ object FunctionClosure extends TransformationPhase {
         None
     }, applyRec = true)(fd.fullBody)
 
+    // A dummy substitution for fd, saying we should not change parameters
     val dummySubst = FunSubst(
       fd,
       Map.empty.withDefault(id => id),
@@ -69,47 +76,52 @@ object FunctionClosure extends TransformationPhase {
     )
 
     // Refresh function calls
-    (dummySubst +: closed.values.toSeq).foreach { case FunSubst(f, paramsMap, tparamsMap) =>
-      //println(f)
-      //paramsMap foreach { case (from, to) =>
-      //  println(from.uniqueName + " -> " + to.uniqueName)
-      //}
-      f.fullBody = preMap {
-        case fi@FunctionInvocation(tfd, args) if closed contains tfd.fd =>
-          val FunSubst(newFd, newParams, newTParams) = closed(tfd.fd)
-
-          // New -> old map for function call
-          val mapReverse = newParams map { _.swap }
-          val extraArgs = newFd.paramIds.drop(args.size).map { id =>
-            paramsMap(mapReverse(id)).toVariable
-          }
-
-          // Similarly for type params
-          val tReverse = newTParams map { _.swap }
-          val tOrigExtraOrdered = newFd.tparams.map{_.tp}.drop(tfd.tps.length).map(tReverse)
-          val tFinalExtra: Seq[TypeParameter] = tOrigExtraOrdered.map( tp =>
-            tparamsMap(tp)
-          )
-
-          Some(FunctionInvocation(
-            newFd.typed(tfd.tps ++ tFinalExtra),
-            args ++ extraArgs
-          ).copiedFrom(fi))
-        case _ => None
-      }(f.fullBody)
+    (dummySubst +: closed.values.toSeq).foreach {
+      case FunSubst(f, callerMap, callerTMap) =>
+        f.fullBody = preMap {
+          case fi@FunctionInvocation(tfd, args) if closed contains tfd.fd =>
+            val FunSubst(newCallee, calleeMap, calleeTMap) = closed(tfd.fd)
+
+            // This needs some explanation.
+            // Say we have caller and callee. First we find the param. substitutions of callee
+            // (say old -> calleeNew) and reverse them. So we have a mapping (calleeNew -> old).
+            // We also have the caller mapping, (old -> callerNew).
+            // So we pass the callee parameters through these two mappings to get the caller parameters.
+            val mapReverse = calleeMap map { _.swap }
+            val extraArgs = newCallee.paramIds.drop(args.size).map { id =>
+              callerMap(mapReverse(id)).toVariable
+            }
+
+            // Similarly for type params
+            val tReverse = calleeTMap map { _.swap }
+            val tOrigExtraOrdered = newCallee.tparams.map{_.tp}.drop(tfd.tps.length).map(tReverse)
+            val tFinalExtra: Seq[TypeParameter] = tOrigExtraOrdered.map( tp =>
+              callerTMap(tp)
+            )
+
+            Some(FunctionInvocation(
+              newCallee.typed(tfd.tps ++ tFinalExtra),
+              args ++ extraArgs
+            ).copiedFrom(fi))
+          case _ => None
+        }(f.fullBody)
     }
 
     val funs = closed.values.toSeq.map{ _.newFd }
 
+    // Recursively close new functions
     fd +: funs.flatMap(close)
   }
 
+  // Represents a substitution to a new function, along with parameter and type parameter
+  // mappings
   private case class FunSubst(
     newFd: FunDef,
     paramsMap: Map[Identifier, Identifier],
     tparamsMap: Map[TypeParameter, TypeParameter]
   )
 
+  // Takes one inner function and closes it. 
   private def step(inner: FunDef, outer: FunDef, pc: Expr, free: Seq[Identifier]): FunSubst = {
 
     val tpFresh = outer.tparams map { _.freshen }