diff --git a/src/main/scala/leon/purescala/TreeOps.scala b/src/main/scala/leon/purescala/TreeOps.scala
index 4bca1518def1b35d63ab6ff5d3bfeb67bd084efa..0fa481c6887d1b87cae7ab89e2a4603d6e5b3bcc 100644
--- a/src/main/scala/leon/purescala/TreeOps.scala
+++ b/src/main/scala/leon/purescala/TreeOps.scala
@@ -1949,17 +1949,39 @@ object TreeOps {
    * the necessary information as arguments, no need to close them.
    */
   def liftClosures(e: Expr): (Set[FunDef], Expr) = {
-    var fds: Set[FunDef] = Set()
+    var fds: Map[FunDef, FunDef] = Map()
 
-    val res = postMap{
+    val res1 = preMap({
+      case LetDef(fd, b) =>
+        val nfd = new FunDef(fd.id.freshen, fd.tparams, fd.returnType, fd.params, fd.defType)
+        nfd.copyContentFrom(fd)
+        nfd.copiedFrom(fd)
+
+        fds += fd -> nfd
+
+        Some(LetDef(nfd, b))
+
+      case fi @ FunctionInvocation(tfd, args) =>
+        if (fds contains tfd.fd) {
+          Some(FunctionInvocation(fds(tfd.fd).typed(tfd.tps), args))
+        } else {
+          None
+        }
+
+      case _ =>
+        None
+    })(e)
+
+    // we now remove LetDefs
+    val res2 = preMap({
       case LetDef(fd, b) =>
-        fds += fd
         Some(b)
       case _ =>
         None
-    }(e)
+    }, applyRec = true)(res1)
+
 
-    (fds, res)
+    (fds.values.toSet, res2)
   }
   
   def preTraversalWithParent(f: (Expr, Option[Tree]) => Unit, initParent: Option[Tree] = None)(e: Expr): Unit = {