From ebbed9d09ed767cf3309d9f52b271a8ec66f81b3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9gis=20Blanc?= <regwblanc@gmail.com>
Date: Mon, 26 Mar 2012 15:46:25 +0200
Subject: [PATCH] function with more than one level of nested function now work

---
 src/main/scala/leon/FunctionHoisting.scala      |  5 ++++-
 .../scala/leon/ImperativeCodeElimination.scala  | 17 ++++++++++-------
 .../scala/leon/purescala/PrettyPrinter.scala    |  2 +-
 3 files changed, 15 insertions(+), 9 deletions(-)

diff --git a/src/main/scala/leon/FunctionHoisting.scala b/src/main/scala/leon/FunctionHoisting.scala
index 9f539f1ba..88881e2ce 100644
--- a/src/main/scala/leon/FunctionHoisting.scala
+++ b/src/main/scala/leon/FunctionHoisting.scala
@@ -24,7 +24,10 @@ object FunctionHoisting extends Pass {
   private def hoist(expr: Expr): (Expr, Set[FunDef]) = expr match {
     case l @ LetDef(fd, rest) => {
       val (e, s) = hoist(rest)
-      (e, s + fd)
+      val (e2, s2) = hoist(fd.getBody)
+      fd.body = Some(e2)
+
+      (e, (s ++ s2) + fd)
     }
     case l @ Let(i,e,b) => {
       val (re, s1) = hoist(e)
diff --git a/src/main/scala/leon/ImperativeCodeElimination.scala b/src/main/scala/leon/ImperativeCodeElimination.scala
index abf197da2..1f311533d 100644
--- a/src/main/scala/leon/ImperativeCodeElimination.scala
+++ b/src/main/scala/leon/ImperativeCodeElimination.scala
@@ -12,13 +12,8 @@ object ImperativeCodeElimination extends Pass {
   def apply(pgm: Program): Program = {
     val allFuns = pgm.definedFunctions
     allFuns.foreach(fd => {
-      val (res, _, _) = toFunction(fd.getBody)
-      //val ((scope, fun), last) = fd.getBody match {
-      //  case Block(stmts, stmt) => (toFunction(Block(stmts.init, stmts.last)), stmt)
-      //  case _ => sys.error("not supported")
-      //}
-      //fd.body = Some(scope(replace(fun.map{case (i1, i2) => (i1.toVariable, i2.toVariable)}, last)))
-      fd.body = Some(res)
+      val (res, scope, _) = toFunction(fd.getBody)
+      fd.body = Some(scope(res))
     })
     pgm
   }
@@ -102,6 +97,14 @@ object ImperativeCodeElimination extends Pass {
       }
 
       //pure expression (that could still contain side effects as a subexpression)
+      case Let(id, e, b) => {
+        val (bodyRes, bodyScope, bodyFun) = toFunction(b)
+        (bodyRes, (b: Expr) => Let(id, e, bodyScope(b)), bodyFun)
+      }
+      case LetDef(fd, b) => {
+        val (bodyRes, bodyScope, bodyFun) = toFunction(b)
+        (bodyRes, (b: Expr) => LetDef(fd, bodyScope(b)), bodyFun)
+      }
       case n @ NAryOperator(args, recons) => {
         val (recArgs, scope, fun) = args.foldLeft((Seq[Expr](), (body: Expr) => body, Map[Identifier, Identifier]()))((acc, arg) => {
           val (accArgs, scope, fun) = acc
diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala
index 7df4adbec..0e2015591 100644
--- a/src/main/scala/leon/purescala/PrettyPrinter.scala
+++ b/src/main/scala/leon/purescala/PrettyPrinter.scala
@@ -80,7 +80,7 @@ object PrettyPrinter {
     }
     case LetDef(fd,e) => {
       sb.append("\n")
-      pp(fd, sb, lvl)
+      pp(fd, sb, lvl+1)
       sb.append("\n")
       sb.append("\n")
       ind(sb, lvl)
-- 
GitLab