From dc159e1df9d5b5e2117cc8bd4ac2b4b73f9e650a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9gis=20Blanc?= <regwblanc@gmail.com>
Date: Thu, 14 Apr 2016 22:54:37 +0200
Subject: [PATCH] function closure only capture required variables

---
 .../leon/purescala/FunctionClosure.scala      | 26 +++----
 src/main/scala/leon/purescala/Path.scala      | 13 ++++
 .../unit/purescala/FunctionClosureSuite.scala | 76 +++++++++++++++++++
 3 files changed, 101 insertions(+), 14 deletions(-)

diff --git a/src/main/scala/leon/purescala/FunctionClosure.scala b/src/main/scala/leon/purescala/FunctionClosure.scala
index c67e222d8..e69050bf6 100644
--- a/src/main/scala/leon/purescala/FunctionClosure.scala
+++ b/src/main/scala/leon/purescala/FunctionClosure.scala
@@ -36,6 +36,8 @@ object FunctionClosure extends TransformationPhase {
     val nestedWithPaths = (for((fds, path) <- nestedWithPathsFull; fd <- fds) yield (fd, path)).toMap
     val nestedFuns = nestedWithPaths.keys.toSeq
 
+    //println(nestedWithPaths)
+
     // Transitively called funcions from each function
     val callGraph: Map[FunDef, Set[FunDef]] = transitiveClosure(
       nestedFuns.map { f =>
@@ -55,34 +57,30 @@ object FunctionClosure extends TransformationPhase {
 
     def freeVars(fd: FunDef, pc: Path): Set[Identifier] =
       variablesOf(fd.fullBody) ++ pc.variables -- fd.paramIds -- pc.bindings.map(_._1)
+    //def freeVars(fd: FunDef): Set[Identifier] =
+    //  variablesOf(fd.fullBody) -- fd.paramIds
 
     // All free variables one should include.
     // Contains free vars of the function itself plus of all transitively called functions.
     // also contains free vars from PC if the PC is relevant to the fundef
-    /*
     val transFree = {
       def step(current: Map[FunDef, Set[Identifier]]): Map[FunDef, Set[Identifier]] = {
         nestedFuns.map(fd => {
-          val transFreeVars = (callGraph(fd) + fd).flatMap((fd2:FunDef) => freeVars(fd2))
-          val reqPaths = Seq(nestedWithPaths(fd)).filter(pathExpr => exists{
-            case _ => true //TODO: for now we take all PCs, need to refine
-            //case Variable(id) => transFreeVars.contains(id)
-            //case _ => false
-          }(pathExpr))
-          (fd, transFreeVars ++ reqPaths.flatMap(p => variablesOf(p)) -- fd.paramIds)
+          val transFreeVars = (callGraph(fd) + fd).flatMap((fd2:FunDef) => current(fd2))
+          val reqPath = nestedWithPaths(fd).filterByIds(transFreeVars)
+          (fd, transFreeVars ++ freeVars(fd, reqPath))
         }).toMap
       }
 
-      utils.fixpoint(step, -1)(nestedFuns.map(fd => (fd, freeVars(fd))).toMap)
+      utils.fixpoint(step, -1)(nestedFuns.map(fd => (fd, variablesOf(fd.fullBody) -- fd.paramIds)).toMap)
     }.map(p => (p._1, p._2.toSeq))
-    */
     //println("free vars: " + transFree)
 
     // All free variables one should include.
     // Contains free vars of the function itself plus of all transitively called functions.
-    val transFree = nestedFuns.map { fd =>
-      fd -> (callGraph(fd) + fd).flatMap( (fd2: FunDef) => freeVars(fd2, nestedWithPaths(fd2)) ).toSeq
-    }.toMap
+    //val transFree = nestedFuns.map { fd =>
+    //  fd -> (callGraph(fd) + fd).flatMap( (fd2: FunDef) => freeVars(fd2, nestedWithPaths(fd2)) ).toSeq
+    //}.toMap
 
     // Closed functions along with a map (old var -> new var).
     val closed = nestedWithPaths.map {
@@ -168,7 +166,7 @@ object FunctionClosure extends TransformationPhase {
     )
 
     val instBody = instantiateType(
-      withPath(newFd.fullBody, pc),
+      withPath(newFd.fullBody, pc.filterByIds(free.toSet)),
       tparamsMap,
       freeMap
     )
diff --git a/src/main/scala/leon/purescala/Path.scala b/src/main/scala/leon/purescala/Path.scala
index 1ce4e13eb..9b7e4786a 100644
--- a/src/main/scala/leon/purescala/Path.scala
+++ b/src/main/scala/leon/purescala/Path.scala
@@ -53,6 +53,19 @@ class Path private[purescala](
     new Path(outers :+ Right(not(fold[Expr](BooleanLiteral(true), let, Constructors.and(_, _))(rest))))
   }
 
+  def filterByIds(ids: Set[Identifier]): Path = {
+    def containsIds(ids: Set[Identifier])(e: Expr): Boolean = exists{
+      case Variable(id) => ids.contains(id)
+      case _ => false
+    }(e)
+    
+    val newElements = elements.filter{
+      case Left((id, e)) => ids.contains(id) || containsIds(ids)(e)
+      case Right(e) => containsIds(ids)(e)
+    }
+    new Path(newElements)
+  }
+
   lazy val variables: Set[Identifier] = fold[Set[Identifier]](Set.empty,
     (id, e, res) => res - id ++ variablesOf(e), (e, res) => res ++ variablesOf(e)
   )(elements)
diff --git a/src/test/scala/leon/unit/purescala/FunctionClosureSuite.scala b/src/test/scala/leon/unit/purescala/FunctionClosureSuite.scala
index c7274e560..faef38195 100644
--- a/src/test/scala/leon/unit/purescala/FunctionClosureSuite.scala
+++ b/src/test/scala/leon/unit/purescala/FunctionClosureSuite.scala
@@ -25,4 +25,80 @@ class FunctionClosureSuite extends FunSuite with helpers.ExpressionsDSL {
     assert(fd1.body === cfd1.head.body)
   }
 
+  test("close does not capture param from parent if not needed") {
+    val nested = new FunDef(FreshIdentifier("nested"), Seq(), Seq(ValDef(y.id)), IntegerType)
+    nested.body = Some(y)
+
+    val fd = new FunDef(FreshIdentifier("f"), Seq(), Seq(ValDef(x.id)), IntegerType)
+    fd.body = Some(LetDef(Seq(nested), x))
+
+    val cfds = FunctionClosure.close(fd)
+    assert(cfds.size === 2)
+
+    cfds.foreach(cfd => {
+      if(cfd.id.name == "f") {
+        assert(cfd.returnType === fd.returnType)
+        assert(cfd.params.size === fd.params.size)
+        assert(freeVars(cfd).isEmpty)
+      } else if(cfd.id.name == "nested") {
+        assert(cfd.returnType === nested.returnType)
+        assert(cfd.params.size === nested.params.size)
+        assert(freeVars(cfd).isEmpty)
+      } else {
+        fail("Unexpected fun def: " + cfd)
+      }
+    })
+
+
+    val nested2 = new FunDef(FreshIdentifier("nested"), Seq(), Seq(ValDef(y.id)), IntegerType)
+    nested2.body = Some(y)
+
+    val fd2 = new FunDef(FreshIdentifier("f"), Seq(), Seq(ValDef(x.id)), IntegerType)
+    fd2.body = Some(Let(z.id, Plus(x, bi(1)), LetDef(Seq(nested2), x)))
+
+    val cfds2 = FunctionClosure.close(fd2)
+    assert(cfds2.size === 2)
+
+    cfds2.foreach(cfd => {
+      if(cfd.id.name == "f") {
+        assert(cfd.returnType === fd2.returnType)
+        assert(cfd.params.size === fd2.params.size)
+        assert(freeVars(cfd).isEmpty)
+      } else if(cfd.id.name == "nested") {
+        assert(cfd.returnType === nested2.returnType)
+        assert(cfd.params.size === nested2.params.size)
+        assert(freeVars(cfd).isEmpty)
+      } else {
+        fail("Unexpected fun def: " + cfd)
+      }
+    })
+  }
+
+  test("close does not capture enclosing require if not needed") {
+    val nested = new FunDef(FreshIdentifier("nested"), Seq(), Seq(ValDef(y.id)), IntegerType)
+    nested.body = Some(y)
+
+    val fd = new FunDef(FreshIdentifier("f"), Seq(), Seq(ValDef(x.id)), IntegerType)
+    fd.body = Some(Require(GreaterEquals(x, bi(0)), Let(z.id, Plus(x, bi(1)), LetDef(Seq(nested), x))))
+
+    val cfds = FunctionClosure.close(fd)
+    assert(cfds.size === 2)
+
+    cfds.foreach(cfd => {
+      if(cfd.id.name == "f") {
+        assert(cfd.returnType === fd.returnType)
+        assert(cfd.params.size === fd.params.size)
+        assert(freeVars(cfd).isEmpty)
+      } else if(cfd.id.name == "nested") {
+        assert(cfd.returnType === nested.returnType)
+        assert(cfd.params.size === nested.params.size)
+        assert(freeVars(cfd).isEmpty)
+      } else {
+        fail("Unexpected fun def: " + cfd)
+      }
+    })
+  }
+
+  private def freeVars(fd: FunDef): Set[Identifier] = variablesOf(fd.fullBody) -- fd.paramIds
+
 }
-- 
GitLab