diff --git a/src/main/scala/leon/purescala/FunctionClosure.scala b/src/main/scala/leon/purescala/FunctionClosure.scala
index 7f76beff85796ec47d8ff5a49e4613c752c9ac41..4ee1f89386468ab0582ad461da7a897c5340c4a2 100644
--- a/src/main/scala/leon/purescala/FunctionClosure.scala
+++ b/src/main/scala/leon/purescala/FunctionClosure.scala
@@ -3,178 +3,160 @@
 package leon
 package purescala
 
-import Common._
 import Definitions._
 import Expressions._
-import Extractors._
 import ExprOps._
 import Constructors._
+import TypeOps.instantiateType
+import leon.purescala.Common.Identifier
+import leon.purescala.Types.TypeParameter
+import utils.GraphOps._
 
 class FunctionClosure extends TransformationPhase {
 
-  val name = "Function Closure"
-  val description = "Closing function with its scoping variables"
+  override val name: String = "Function Closure"
+  override val description: String = "Closing function with its scoping variables"
+
+  private def close(fd: FunDef): Seq[FunDef] = { 
+
+    // Directly neste functions with their p.c.
+    val nestedWithPaths = {
+      val funDefs = directlyNestedFunDefs(fd.fullBody)
+      collectWithPC {
+        case LetDef(fd1, body) if funDefs(fd1) => fd1
+      }(fd.fullBody)
+    }.toMap
+    
+    val nestedFuns = nestedWithPaths.keys.toSeq
+
+    // Transitively called funcions from each function
+    val callGraph: Map[FunDef, Set[FunDef]] = transitiveClosure(
+      nestedFuns.map { f =>
+        val calls = functionCallsOf(f.fullBody) collect {
+          case FunctionInvocation(TypedFunDef(fd, _), _) if nestedFuns.contains(fd) =>
+            fd
+        }
+        f -> calls
+      }.toMap
+    )
+
+    def freeVars(fd: FunDef, pc: Expr): Set[Identifier] =
+      variablesOf(fd.fullBody) ++ variablesOf(pc) -- fd.paramIds
+
+    // All free variables one should include.
+    // Contains free vars of the function itself plus of all transitively called functions.
+    val transFree = nestedFuns.map { fd =>
+      fd -> (callGraph(fd) + fd).flatMap( (fd2:FunDef) => freeVars(fd2, nestedWithPaths(fd2)) ).toSeq
+    }.toMap
+
+    // Closed functions along with a map (old var -> new var).
+    val closed = nestedWithPaths.map {
+      case (inner, pc) => inner -> step(inner, fd, pc, transFree(inner))
+    }
 
-  // TODO: Rewrite this phase
-  /* I know, that's a lot of mutable variables */
-  private var pathConstraints: List[Expr] = Nil
-  private var enclosingLets: List[(Identifier, Expr)] = Nil
-  private var newFunDefs: Map[FunDef, FunDef] = Map()
-  private var topLevelFuns: Set[FunDef] = Set()
-  private var parent: FunDef = null //refers to the current toplevel parent
+    // Remove LetDefs
+    fd.fullBody = preMap({
+      case LetDef(fd, bd) =>
+        Some(bd)
+      case _ =>
+        None
+    }, applyRec = true)(fd.fullBody)
+
+    val dummySubst = FunSubst(
+      fd,
+      Map.empty.withDefault(id => id),
+      Map.empty.withDefault(id => id)
+    )
+
+    // Refresh function calls
+    (dummySubst +: closed.values.toSeq).foreach { case FunSubst(f, paramsMap, tparamsMap) =>
+      //println(f)
+      //paramsMap foreach { case (from, to) =>
+      //  println(from.uniqueName + " -> " + to.uniqueName)
+      //}
+      f.fullBody = preMap {
+        case FunctionInvocation(tfd, args) if closed contains tfd.fd =>
+          val FunSubst(newFd, newParams, newTParams) = closed(tfd.fd)
+
+          // New -> old map for function call
+          val mapReverse = newParams map { _.swap }
+          val extraArgs = newFd.paramIds.drop(args.size).map { id =>
+            paramsMap(mapReverse(id)).toVariable
+          }
+
+          // Similarly for type params
+          val tReverse = newTParams map { _.swap }
+          val tOrigExtraOrdered = newFd.tparams.map{_.tp}.drop(tfd.tps.length).map(tReverse)
+          val tFinalExtra: Seq[TypeParameter] = tOrigExtraOrdered.map( tp =>
+            tparamsMap(tp)
+          )
+
+          Some(FunctionInvocation(
+            newFd.typed(tfd.tps ++ tFinalExtra),
+            args ++ extraArgs
+          ))
+        case _ => None
+      }(f.fullBody)
+    }
 
-  def apply(ctx: LeonContext, program: Program): Program = {
+    val funs = closed.values.toSeq.map{ _.newFd }
 
-    val newUnits = program.units.map { u => u.copy(defs = u.defs map { 
-      case m: ModuleDef =>
-        pathConstraints = Nil
-        enclosingLets  = Nil
-        newFunDefs  = Map()
-        topLevelFuns = Set()
-        parent = null
-
-        val funDefs = m.definedFunctions
-        funDefs.foreach(fd => {
-          parent = fd
-          pathConstraints = fd.precondition.toList
-          fd.body = fd.body.map(b => functionClosure(b, fd.params.map(_.id).toSet, Map(), Map()))
-        })
-
-        ModuleDef(m.id, m.defs ++ topLevelFuns, m.isPackageObject )
-      case cd => cd
-    })}
-    Program(newUnits)
+    fd +: funs.flatMap(close)
   }
 
-  private def functionClosure(expr: Expr, bindedVars: Set[Identifier], id2freshId: Map[Identifier, Identifier], fd2FreshFd: Map[FunDef, (FunDef, Seq[Variable])]): Expr = expr match {
-    case l @ LetDef(fd, rest) => {
-      val capturedVars: Set[Identifier] = bindedVars.diff(enclosingLets.map(_._1).toSet)
-      val capturedConstraints: Set[Expr] = pathConstraints.toSet
-
-      val freshIds: Map[Identifier, Identifier] = capturedVars.map(id => (id, id.freshen)).toMap
-
-      val extraValDefOldIds: Seq[Identifier] = capturedVars.toSeq
-      val extraValDefFreshIds: Seq[Identifier] = extraValDefOldIds.map(freshIds(_))
-      val extraValDefs: Seq[ValDef] = extraValDefFreshIds.map(ValDef(_))
-      val newValDefs: Seq[ValDef] = fd.params ++ extraValDefs
-      val newBindedVars: Set[Identifier] = bindedVars ++ fd.params.map(_.id)
-      val newFunId = FreshIdentifier(fd.id.name, alwaysShowUniqueID = true) //since we hoist this at the top level, we need to make it a unique name
-
-      val newFunDef = new FunDef(newFunId, fd.tparams, fd.returnType, newValDefs).copiedFrom(fd)
-      topLevelFuns += newFunDef
-      newFunDef.copyContentFrom(fd) //TODO: this still has some dangerous side effects (?)
-
-      def introduceLets(expr: Expr, fd2FreshFd: Map[FunDef, (FunDef, Seq[Variable])]): Expr = {
-        val (newExpr, _) = enclosingLets.foldLeft((expr, Map[Identifier, Identifier]()))((acc, p) => {
-          val newId = p._1.freshen
-          val newMap = acc._2 + (p._1 -> newId)
-          val newBody = functionClosure(acc._1, newBindedVars, freshIds ++ newMap, fd2FreshFd)
-          (Let(newId, p._2, newBody), newMap)
-        })
-        functionClosure(newExpr, newBindedVars, freshIds, fd2FreshFd)
-      }
-
-      val newPrecondition = simplifyLets(introduceLets(and((capturedConstraints ++ fd.precondition).toSeq :_*), fd2FreshFd))
-      newFunDef.precondition = if(newPrecondition == BooleanLiteral(true)) None else Some(newPrecondition)
-
-      val freshPostcondition = fd.postcondition.map { case post @ Lambda(args, body) =>
-        Lambda(args, introduceLets(body, fd2FreshFd).setPos(body)).setPos(post)
-      }
-      newFunDef.postcondition = freshPostcondition
-      
-      pathConstraints = fd.precOrTrue :: pathConstraints
-      val freshBody = fd.body.map(body => introduceLets(body, fd2FreshFd + (fd -> ((newFunDef, extraValDefOldIds.map(_.toVariable))))))
-      newFunDef.body = freshBody
-      pathConstraints = pathConstraints.tail
-
-      val freshRest = functionClosure(rest, bindedVars, id2freshId, fd2FreshFd + (fd -> ((newFunDef, extraValDefOldIds.map(_.toVariable)))))
-      freshRest.copiedFrom(l)
-    }
-    case l @ Let(i,e,b) => {
-      val re = functionClosure(e, bindedVars, id2freshId, fd2FreshFd)
-      //we need the enclosing lets to always refer to the original ids, because it might be expand later in a highly nested function
-      enclosingLets ::= (i, replace(id2freshId.map(p => (p._2.toVariable, p._1.toVariable)), re)) 
-      //pathConstraints :: Equals(i.toVariable, re)
-      val rb = functionClosure(b, bindedVars + i, id2freshId, fd2FreshFd)
-      enclosingLets = enclosingLets.tail
-      //pathConstraints = pathConstraints.tail
-      Let(i, re, rb).copiedFrom(l)
-    }
-    case i @ IfExpr(cond,thenn,elze) => {
-      /*
-         when acumulating path constraints, take the condition without closing it first, so this
-         might not work well with nested fundef in if then else condition
-      */
-      val rCond = functionClosure(cond, bindedVars, id2freshId, fd2FreshFd)
-      pathConstraints ::= cond//rCond
-      val rThen = functionClosure(thenn, bindedVars, id2freshId, fd2FreshFd)
-      pathConstraints = pathConstraints.tail
-      pathConstraints ::= Not(cond)//Not(rCond)
-      val rElze = functionClosure(elze, bindedVars, id2freshId, fd2FreshFd)
-      pathConstraints = pathConstraints.tail
-      IfExpr(rCond, rThen, rElze).copiedFrom(i)
-    }
-    case fi @ FunctionInvocation(tfd, args) => fd2FreshFd.get(tfd.fd) match {
-      case None =>
-        FunctionInvocation(tfd,
-                           args.map(arg => functionClosure(arg, bindedVars, id2freshId, fd2FreshFd))).copiedFrom(fi)
-      case Some((nfd, extraArgs)) => 
-        FunctionInvocation(nfd.typed(tfd.tps),
-                           args.map(arg => functionClosure(arg, bindedVars, id2freshId, fd2FreshFd)) ++ 
-                           extraArgs.map(v => replace(id2freshId.map(p => (p._1.toVariable, p._2.toVariable)), v))).copiedFrom(fi)
-    }
-    case m @ MatchExpr(scrut,cses) => {
-      val scrutRec = functionClosure(scrut, bindedVars, id2freshId, fd2FreshFd)
-      val csesRec = cses.map{ cse =>
-        import cse._
-        val binders = pattern.binders
-        val cond = conditionForPattern(scrut, pattern)
-        pathConstraints ::= cond
-        val rRhs = functionClosure(rhs, bindedVars ++ binders, id2freshId, fd2FreshFd)
-        val rGuard = optGuard map { functionClosure(_, bindedVars ++ binders, id2freshId, fd2FreshFd) }
-        pathConstraints = pathConstraints.tail
-        MatchCase(pattern, rGuard, rRhs)
-      }
-      matchExpr(scrutRec, csesRec).copiedFrom(m)
-    }
-    case v @ Variable(id) => id2freshId.get(id) match {
-      case None => v
-      case Some(nid) => Variable(nid)
-    }
-    case n @ Operator(args, recons) => {
-      val rargs = args.map(a => functionClosure(a, bindedVars, id2freshId, fd2FreshFd))
-      recons(rargs).copiedFrom(n)
-    }
-    case unhandled => scala.sys.error("Non-terminal case should be handled in FunctionClosure: " + unhandled)
+  private case class FunSubst(
+    newFd: FunDef,
+    paramsMap: Map[Identifier, Identifier],
+    tparamsMap: Map[TypeParameter, TypeParameter]
+  )
+
+  private def step(inner: FunDef, outer: FunDef, pc: Expr, free: Seq[Identifier]): FunSubst = {
+
+    val tpFresh = outer.tparams map { _.freshen }
+    val tparamsMap = outer.tparams.zip(tpFresh map {_.tp}).toMap
+    
+    val freshVals = (inner.paramIds ++ free).map{_.freshen}.map(instantiateType(_, tparamsMap))
+    val freeMap   = (inner.paramIds ++ free).zip(freshVals).toMap
+
+    val newFd = new FunDef(
+      inner.id.freshen,
+      inner.tparams ++ tpFresh,
+      instantiateType(inner.returnType, tparamsMap),
+      freshVals.map(ValDef(_))
+    )
+    newFd.copyContentFrom(inner)
+    newFd.precondition = Some(and(pc, inner.precOrTrue))
+
+    val instBody = instantiateType(
+      newFd.fullBody,
+      tparamsMap,
+      freeMap
+    )
+
+    newFd.fullBody = preMap {
+      case FunctionInvocation(tfd, args) if tfd.fd == inner =>
+        Some(FunctionInvocation(
+          newFd.typed(tfd.tps ++ tpFresh.map{ _.tp }),
+          args ++ freshVals.drop(args.length).map(Variable)
+        ))
+      case _ => None
+    }(instBody)
+
+    FunSubst(newFd, freeMap, tparamsMap.map{ case (from, to) => from.tp -> to})
   }
 
-  def freshIdInPat(pat: Pattern, id2freshId: Map[Identifier, Identifier]): Pattern = pat match {
-    case InstanceOfPattern(binder, classTypeDef) => InstanceOfPattern(binder.map(id2freshId(_)), classTypeDef)
-    case WildcardPattern(binder) => WildcardPattern(binder.map(id2freshId(_)))
-    case CaseClassPattern(binder, caseClassDef, subPatterns) => CaseClassPattern(binder.map(id2freshId(_)), caseClassDef, subPatterns.map(freshIdInPat(_, id2freshId)))
-    case TuplePattern(binder, subPatterns) => TuplePattern(binder.map(id2freshId(_)), subPatterns.map(freshIdInPat(_, id2freshId)))
-    case UnapplyPattern(binder, fd, subPatterns) => UnapplyPattern(binder.map(id2freshId(_)), fd, subPatterns.map(freshIdInPat(_, id2freshId)))
-    case LiteralPattern(binder, lit) => LiteralPattern(binder.map(id2freshId(_)), lit)
+  override def apply(ctx: LeonContext, program: Program): Program = {
+    val newUnits = program.units.map { u => u.copy(defs = u.defs map {
+      case m: ModuleDef =>
+        ModuleDef(
+          m.id,
+          m.definedClasses ++ m.definedFunctions.flatMap(close),
+          m.isPackageObject
+        )
+      case cd =>
+        cd
+    })}
+    Program(newUnits)
   }
 
-  //filter the list of constraints, only keeping those relevant to the set of variables
-  def filterConstraints(vars: Set[Identifier]): (List[Expr], Set[Identifier]) = {
-    var allVars = vars
-    var newVars: Set[Identifier] = Set()
-    var constraints = pathConstraints
-    var filteredConstraints: List[Expr] = Nil
-    do {
-      allVars ++= newVars
-      newVars = Set()
-      constraints = pathConstraints.filterNot(filteredConstraints.contains(_))
-      constraints.foreach(expr => {
-        val vs = variablesOf(expr)
-        if(vs.intersect(allVars).nonEmpty) {
-          filteredConstraints ::= expr
-          newVars ++= vs.diff(allVars)
-        }
-      })
-    } while(newVars != Set())
-    (filteredConstraints, allVars)
-  }
 }