diff --git a/src/main/scala/leon/purescala/TreeOps.scala b/src/main/scala/leon/purescala/TreeOps.scala index 4bca1518def1b35d63ab6ff5d3bfeb67bd084efa..0fa481c6887d1b87cae7ab89e2a4603d6e5b3bcc 100644 --- a/src/main/scala/leon/purescala/TreeOps.scala +++ b/src/main/scala/leon/purescala/TreeOps.scala @@ -1949,17 +1949,39 @@ object TreeOps { * the necessary information as arguments, no need to close them. */ def liftClosures(e: Expr): (Set[FunDef], Expr) = { - var fds: Set[FunDef] = Set() + var fds: Map[FunDef, FunDef] = Map() - val res = postMap{ + val res1 = preMap({ + case LetDef(fd, b) => + val nfd = new FunDef(fd.id.freshen, fd.tparams, fd.returnType, fd.params, fd.defType) + nfd.copyContentFrom(fd) + nfd.copiedFrom(fd) + + fds += fd -> nfd + + Some(LetDef(nfd, b)) + + case fi @ FunctionInvocation(tfd, args) => + if (fds contains tfd.fd) { + Some(FunctionInvocation(fds(tfd.fd).typed(tfd.tps), args)) + } else { + None + } + + case _ => + None + })(e) + + // we now remove LetDefs + val res2 = preMap({ case LetDef(fd, b) => - fds += fd Some(b) case _ => None - }(e) + }, applyRec = true)(res1) + - (fds, res) + (fds.values.toSet, res2) } def preTraversalWithParent(f: (Expr, Option[Tree]) => Unit, initParent: Option[Tree] = None)(e: Expr): Unit = {