From 4d6019a2cd03181ce6a9774c532b3d9331b32ee9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9gis=20Blanc?= <regwblanc@gmail.com>
Date: Tue, 5 Jun 2012 11:53:46 +0200
Subject: [PATCH] propagate elimination of imperative code in nested functions

---
 .../leon/ImperativeCodeElimination.scala      |  7 ++++++-
 testcases/regression/valid/IfExpr3.scala      | 19 +++++++++++++++++++
 testcases/regression/valid/IfExpr4.scala      | 18 ++++++++++++++++++
 testcases/regression/valid/NestedVar.scala    | 17 +++++++++++++++++
 4 files changed, 60 insertions(+), 1 deletion(-)
 create mode 100644 testcases/regression/valid/IfExpr3.scala
 create mode 100644 testcases/regression/valid/IfExpr4.scala
 create mode 100644 testcases/regression/valid/NestedVar.scala

diff --git a/src/main/scala/leon/ImperativeCodeElimination.scala b/src/main/scala/leon/ImperativeCodeElimination.scala
index 3cc81b703..d3cd29907 100644
--- a/src/main/scala/leon/ImperativeCodeElimination.scala
+++ b/src/main/scala/leon/ImperativeCodeElimination.scala
@@ -225,8 +225,13 @@ object ImperativeCodeElimination extends Pass {
       }
       case LetDef(fd, b) => {
         //Recall that here the nested function should not access mutable variables from an outside scope
+        val newFd = if(!fd.hasImplementation) fd else {
+          val (fdRes, fdScope, fdFun) = toFunction(fd.getBody)
+          fd.body = Some(fdScope(fdRes))
+          fd
+        }
         val (bodyRes, bodyScope, bodyFun) = toFunction(b)
-        (bodyRes, (b2: Expr) => LetDef(fd, bodyScope(b2)), bodyFun)
+        (bodyRes, (b2: Expr) => LetDef(newFd, bodyScope(b2)), bodyFun)
       }
       case n @ NAryOperator(Seq(), recons) => (n, (body: Expr) => body, Map())
       case n @ NAryOperator(args, recons) => {
diff --git a/testcases/regression/valid/IfExpr3.scala b/testcases/regression/valid/IfExpr3.scala
new file mode 100644
index 000000000..86d4e494a
--- /dev/null
+++ b/testcases/regression/valid/IfExpr3.scala
@@ -0,0 +1,19 @@
+object IfExpr1 {
+
+  def foo(a: Int): Int = {
+
+    if(a > 0) {
+      var a = 1
+      var b = 2
+      a = 3
+      a + b
+    } else {
+      5
+      //var a = 3
+      //var b = 1
+      //b = b + 1
+      //a + b
+    }
+  } ensuring(_ == 5)
+
+}
diff --git a/testcases/regression/valid/IfExpr4.scala b/testcases/regression/valid/IfExpr4.scala
new file mode 100644
index 000000000..94b99fde3
--- /dev/null
+++ b/testcases/regression/valid/IfExpr4.scala
@@ -0,0 +1,18 @@
+object IfExpr4 {
+
+  def foo(a: Int): Int = {
+
+    if(a > 0) {
+      var a = 1
+      var b = 2
+      a = 3
+      a + b
+    } else {
+      var a = 3
+      var b = 1
+      b = b + 1
+      a + b
+    }
+  } ensuring(_ == 5)
+
+}
diff --git a/testcases/regression/valid/NestedVar.scala b/testcases/regression/valid/NestedVar.scala
new file mode 100644
index 000000000..a26b7312b
--- /dev/null
+++ b/testcases/regression/valid/NestedVar.scala
@@ -0,0 +1,17 @@
+object NestedVar {
+
+  def foo(): Int = {
+    val a = 3
+    def rec(x: Int): Int = {
+      var b = 3
+      var c = 3
+      if(x > 0)
+        b = 2
+      else
+        c = 2
+      c+b
+    }
+    rec(a)
+  } ensuring(_ == 5)
+
+}
-- 
GitLab