From f85c484b679125fc468e301c362c4e2269c2a543 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9gis=20Blanc?= <regwblanc@gmail.com>
Date: Mon, 7 May 2012 15:35:01 +0200
Subject: [PATCH] merge closure and hoisting. fix wrong id when accessing let
 value defined with more than 1 level higher

---
 run-tests.sh                              |  6 +++---
 src/main/scala/leon/FunctionClosure.scala | 17 ++++++++++++-----
 src/main/scala/leon/Main.scala            |  2 +-
 testcases/regression/valid/Nested9.scala  | 23 +++++++++++++++++++++++
 4 files changed, 39 insertions(+), 9 deletions(-)
 create mode 100644 testcases/regression/valid/Nested9.scala

diff --git a/run-tests.sh b/run-tests.sh
index cedaea2d3..951862c3e 100755
--- a/run-tests.sh
+++ b/run-tests.sh
@@ -6,7 +6,7 @@ failedtests=""
 
 for f in testcases/regression/valid/*.scala; do
   echo -n "Running $f, expecting VALID, got: "
-  res=`./leon --timeout=5 --oneline "$f"`
+  res=`./leon --timeout=10 --oneline "$f"`
   echo $res | tr [a-z] [A-Z]
   if [ $res = valid ]; then
     nbsuccess=$((nbsuccess + 1))
@@ -17,7 +17,7 @@ done
 
 for f in testcases/regression/invalid/*.scala; do
   echo -n "Running $f, expecting INVALID, got: "
-  res=`./leon --timeout=5 --oneline "$f"`
+  res=`./leon --timeout=10 --oneline "$f"`
   echo $res | tr [a-z] [A-Z]
   if [ $res = invalid ]; then
     nbsuccess=$((nbsuccess + 1))
@@ -28,7 +28,7 @@ done
 
 for f in testcases/regression/error/*.scala; do
   echo -n "Running $f, expecting ERROR, got: "
-  res=`./leon --timeout=5 --oneline "$f"`
+  res=`./leon --timeout=10 --oneline "$f"`
   echo $res | tr [a-z] [A-Z]
   if [ $res = error ]; then
     nbsuccess=$((nbsuccess + 1))
diff --git a/src/main/scala/leon/FunctionClosure.scala b/src/main/scala/leon/FunctionClosure.scala
index 01281be29..40f0b8716 100644
--- a/src/main/scala/leon/FunctionClosure.scala
+++ b/src/main/scala/leon/FunctionClosure.scala
@@ -12,6 +12,8 @@ object FunctionClosure extends Pass {
   private var pathConstraints: List[Expr] = Nil
   private var enclosingLets: List[(Identifier, Expr)] = Nil
   private var newFunDefs: Map[FunDef, FunDef] = Map()
+  private var originalsIds: Map[Identifier, Identifier] = Map()
+  private var topLevelFuns: Set[FunDef] = Set()
 
   def apply(program: Program): Program = {
     newFunDefs = Map()
@@ -20,7 +22,8 @@ object FunctionClosure extends Pass {
       pathConstraints = fd.precondition.toList
       fd.body = fd.body.map(b => functionClosure(b, fd.args.map(_.id).toSet, Map(), Map()))
     })
-    program
+    val Program(id, ObjectDef(objId, defs, invariants)) = program
+    Program(id, ObjectDef(objId, defs ++ topLevelFuns, invariants))
   }
 
   private def functionClosure(expr: Expr, bindedVars: Set[Identifier], id2freshId: Map[Identifier, Identifier], fd2FreshFd: Map[FunDef, (FunDef, Seq[Variable])]): Expr = expr match {
@@ -29,6 +32,7 @@ object FunctionClosure extends Pass {
       val capturedConstraints: Set[Expr] = pathConstraints.toSet
 
       val freshIds: Map[Identifier, Identifier] = capturedVars.map(id => (id, FreshIdentifier(id.name).setType(id.getType))).toMap
+      freshIds.foreach(p => originalsIds += (p._2 -> p._1))
       val freshVars: Map[Expr, Expr] = freshIds.map(p => (p._1.toVariable, p._2.toVariable))
       
       val extraVarDeclOldIds: Seq[Identifier] = capturedVars.toSeq
@@ -39,6 +43,7 @@ object FunctionClosure extends Pass {
       val newFunId = FreshIdentifier(fd.id.name)
 
       val newFunDef = new FunDef(newFunId, fd.returnType, newVarDecls).setPosInfo(fd)
+      topLevelFuns += newFunDef
       newFunDef.fromLoop = fd.fromLoop
       newFunDef.parent = fd.parent
       newFunDef.addAnnotation(fd.annotations.toSeq:_*)
@@ -57,11 +62,11 @@ object FunctionClosure extends Pass {
 
       val freshRest = functionClosure(rest, bindedVars, id2freshId, fd2FreshFd + (fd -> 
                         ((newFunDef, extraVarDeclOldIds.map(id => id2freshId.get(id).getOrElse(id).toVariable)))))
-      LetDef(newFunDef, freshRest).setType(l.getType)
+      freshRest.setType(l.getType)
     }
     case l @ Let(i,e,b) => {
       val re = functionClosure(e, bindedVars, id2freshId, fd2FreshFd)
-      enclosingLets ::= (i, re)
+      enclosingLets ::= (i, replace(originalsIds.map(p => (p._1.toVariable, p._2.toVariable)), re))
       //pathConstraints :: Equals(i.toVariable, re)
       val rb = functionClosure(b, bindedVars + i, id2freshId, fd2FreshFd)
       enclosingLets = enclosingLets.tail
@@ -106,8 +111,10 @@ object FunctionClosure extends Pass {
     }
     case v @ Variable(id) => id2freshId.get(id) match {
       case None => replace(
-                    id2freshId.map(p => (p._1.toVariable, p._2.toVariable)).toMap, 
-                    enclosingLets.foldLeft(v: Expr){ case (expr, (id, value)) => replace(Map(id.toVariable -> value), expr) })
+                     id2freshId.map(p => (p._1.toVariable, p._2.toVariable)),
+                     enclosingLets.foldLeft(v: Expr){ 
+                       case (expr, (id, value)) => replace(Map(id.toVariable -> value), expr) 
+                     })
       case Some(nid) => Variable(nid)
     }
     case t if t.isInstanceOf[Terminal] => t
diff --git a/src/main/scala/leon/Main.scala b/src/main/scala/leon/Main.scala
index b532c2e7c..9af1b41e7 100644
--- a/src/main/scala/leon/Main.scala
+++ b/src/main/scala/leon/Main.scala
@@ -32,7 +32,7 @@ object Main {
 
   private def defaultAction(program: Program, reporter: Reporter) : Unit = {
     Logger.debug("Default action on program: " + program, 3, "main")
-    val passManager = new PassManager(Seq(ArrayTransformation, EpsilonElimination, ImperativeCodeElimination, /*UnitElimination,*/ FunctionClosure, FunctionHoisting, Simplificator))
+    val passManager = new PassManager(Seq(ArrayTransformation, EpsilonElimination, ImperativeCodeElimination, /*UnitElimination,*/ FunctionClosure, /*FunctionHoisting,*/ Simplificator))
     val program2 = passManager.run(program)
     assert(program2.isPure)
     val analysis = new Analysis(program2, reporter)
diff --git a/testcases/regression/valid/Nested9.scala b/testcases/regression/valid/Nested9.scala
new file mode 100644
index 000000000..3344a2c79
--- /dev/null
+++ b/testcases/regression/valid/Nested9.scala
@@ -0,0 +1,23 @@
+object Nested4 {
+
+  def foo(a: Int, a2: Int): Int = {
+    require(a >= 0 && a <= 50)
+    val b = a + 2
+    val c = a + b
+    if(a2 > a) {
+      def rec1(d: Int): Int = {
+        require(d >= 0 && d <= 50)
+        val e = d + b + c + a2
+        def rec2(f: Int): Int = {
+          require(f >= c)
+          f + e
+        } ensuring(_ > 0)
+        rec2(c+1)
+      } ensuring(_ > 0)
+      rec1(2)
+    } else {
+      5
+    }
+  } ensuring(_ > 0)
+
+}
-- 
GitLab