From dc159e1df9d5b5e2117cc8bd4ac2b4b73f9e650a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9gis=20Blanc?= <regwblanc@gmail.com> Date: Thu, 14 Apr 2016 22:54:37 +0200 Subject: [PATCH] function closure only capture required variables --- .../leon/purescala/FunctionClosure.scala | 26 +++---- src/main/scala/leon/purescala/Path.scala | 13 ++++ .../unit/purescala/FunctionClosureSuite.scala | 76 +++++++++++++++++++ 3 files changed, 101 insertions(+), 14 deletions(-) diff --git a/src/main/scala/leon/purescala/FunctionClosure.scala b/src/main/scala/leon/purescala/FunctionClosure.scala index c67e222d8..e69050bf6 100644 --- a/src/main/scala/leon/purescala/FunctionClosure.scala +++ b/src/main/scala/leon/purescala/FunctionClosure.scala @@ -36,6 +36,8 @@ object FunctionClosure extends TransformationPhase { val nestedWithPaths = (for((fds, path) <- nestedWithPathsFull; fd <- fds) yield (fd, path)).toMap val nestedFuns = nestedWithPaths.keys.toSeq + //println(nestedWithPaths) + // Transitively called funcions from each function val callGraph: Map[FunDef, Set[FunDef]] = transitiveClosure( nestedFuns.map { f => @@ -55,34 +57,30 @@ object FunctionClosure extends TransformationPhase { def freeVars(fd: FunDef, pc: Path): Set[Identifier] = variablesOf(fd.fullBody) ++ pc.variables -- fd.paramIds -- pc.bindings.map(_._1) + //def freeVars(fd: FunDef): Set[Identifier] = + // variablesOf(fd.fullBody) -- fd.paramIds // All free variables one should include. // Contains free vars of the function itself plus of all transitively called functions. // also contains free vars from PC if the PC is relevant to the fundef - /* val transFree = { def step(current: Map[FunDef, Set[Identifier]]): Map[FunDef, Set[Identifier]] = { nestedFuns.map(fd => { - val transFreeVars = (callGraph(fd) + fd).flatMap((fd2:FunDef) => freeVars(fd2)) - val reqPaths = Seq(nestedWithPaths(fd)).filter(pathExpr => exists{ - case _ => true //TODO: for now we take all PCs, need to refine - //case Variable(id) => transFreeVars.contains(id) - //case _ => false - }(pathExpr)) - (fd, transFreeVars ++ reqPaths.flatMap(p => variablesOf(p)) -- fd.paramIds) + val transFreeVars = (callGraph(fd) + fd).flatMap((fd2:FunDef) => current(fd2)) + val reqPath = nestedWithPaths(fd).filterByIds(transFreeVars) + (fd, transFreeVars ++ freeVars(fd, reqPath)) }).toMap } - utils.fixpoint(step, -1)(nestedFuns.map(fd => (fd, freeVars(fd))).toMap) + utils.fixpoint(step, -1)(nestedFuns.map(fd => (fd, variablesOf(fd.fullBody) -- fd.paramIds)).toMap) }.map(p => (p._1, p._2.toSeq)) - */ //println("free vars: " + transFree) // All free variables one should include. // Contains free vars of the function itself plus of all transitively called functions. - val transFree = nestedFuns.map { fd => - fd -> (callGraph(fd) + fd).flatMap( (fd2: FunDef) => freeVars(fd2, nestedWithPaths(fd2)) ).toSeq - }.toMap + //val transFree = nestedFuns.map { fd => + // fd -> (callGraph(fd) + fd).flatMap( (fd2: FunDef) => freeVars(fd2, nestedWithPaths(fd2)) ).toSeq + //}.toMap // Closed functions along with a map (old var -> new var). val closed = nestedWithPaths.map { @@ -168,7 +166,7 @@ object FunctionClosure extends TransformationPhase { ) val instBody = instantiateType( - withPath(newFd.fullBody, pc), + withPath(newFd.fullBody, pc.filterByIds(free.toSet)), tparamsMap, freeMap ) diff --git a/src/main/scala/leon/purescala/Path.scala b/src/main/scala/leon/purescala/Path.scala index 1ce4e13eb..9b7e4786a 100644 --- a/src/main/scala/leon/purescala/Path.scala +++ b/src/main/scala/leon/purescala/Path.scala @@ -53,6 +53,19 @@ class Path private[purescala]( new Path(outers :+ Right(not(fold[Expr](BooleanLiteral(true), let, Constructors.and(_, _))(rest)))) } + def filterByIds(ids: Set[Identifier]): Path = { + def containsIds(ids: Set[Identifier])(e: Expr): Boolean = exists{ + case Variable(id) => ids.contains(id) + case _ => false + }(e) + + val newElements = elements.filter{ + case Left((id, e)) => ids.contains(id) || containsIds(ids)(e) + case Right(e) => containsIds(ids)(e) + } + new Path(newElements) + } + lazy val variables: Set[Identifier] = fold[Set[Identifier]](Set.empty, (id, e, res) => res - id ++ variablesOf(e), (e, res) => res ++ variablesOf(e) )(elements) diff --git a/src/test/scala/leon/unit/purescala/FunctionClosureSuite.scala b/src/test/scala/leon/unit/purescala/FunctionClosureSuite.scala index c7274e560..faef38195 100644 --- a/src/test/scala/leon/unit/purescala/FunctionClosureSuite.scala +++ b/src/test/scala/leon/unit/purescala/FunctionClosureSuite.scala @@ -25,4 +25,80 @@ class FunctionClosureSuite extends FunSuite with helpers.ExpressionsDSL { assert(fd1.body === cfd1.head.body) } + test("close does not capture param from parent if not needed") { + val nested = new FunDef(FreshIdentifier("nested"), Seq(), Seq(ValDef(y.id)), IntegerType) + nested.body = Some(y) + + val fd = new FunDef(FreshIdentifier("f"), Seq(), Seq(ValDef(x.id)), IntegerType) + fd.body = Some(LetDef(Seq(nested), x)) + + val cfds = FunctionClosure.close(fd) + assert(cfds.size === 2) + + cfds.foreach(cfd => { + if(cfd.id.name == "f") { + assert(cfd.returnType === fd.returnType) + assert(cfd.params.size === fd.params.size) + assert(freeVars(cfd).isEmpty) + } else if(cfd.id.name == "nested") { + assert(cfd.returnType === nested.returnType) + assert(cfd.params.size === nested.params.size) + assert(freeVars(cfd).isEmpty) + } else { + fail("Unexpected fun def: " + cfd) + } + }) + + + val nested2 = new FunDef(FreshIdentifier("nested"), Seq(), Seq(ValDef(y.id)), IntegerType) + nested2.body = Some(y) + + val fd2 = new FunDef(FreshIdentifier("f"), Seq(), Seq(ValDef(x.id)), IntegerType) + fd2.body = Some(Let(z.id, Plus(x, bi(1)), LetDef(Seq(nested2), x))) + + val cfds2 = FunctionClosure.close(fd2) + assert(cfds2.size === 2) + + cfds2.foreach(cfd => { + if(cfd.id.name == "f") { + assert(cfd.returnType === fd2.returnType) + assert(cfd.params.size === fd2.params.size) + assert(freeVars(cfd).isEmpty) + } else if(cfd.id.name == "nested") { + assert(cfd.returnType === nested2.returnType) + assert(cfd.params.size === nested2.params.size) + assert(freeVars(cfd).isEmpty) + } else { + fail("Unexpected fun def: " + cfd) + } + }) + } + + test("close does not capture enclosing require if not needed") { + val nested = new FunDef(FreshIdentifier("nested"), Seq(), Seq(ValDef(y.id)), IntegerType) + nested.body = Some(y) + + val fd = new FunDef(FreshIdentifier("f"), Seq(), Seq(ValDef(x.id)), IntegerType) + fd.body = Some(Require(GreaterEquals(x, bi(0)), Let(z.id, Plus(x, bi(1)), LetDef(Seq(nested), x)))) + + val cfds = FunctionClosure.close(fd) + assert(cfds.size === 2) + + cfds.foreach(cfd => { + if(cfd.id.name == "f") { + assert(cfd.returnType === fd.returnType) + assert(cfd.params.size === fd.params.size) + assert(freeVars(cfd).isEmpty) + } else if(cfd.id.name == "nested") { + assert(cfd.returnType === nested.returnType) + assert(cfd.params.size === nested.params.size) + assert(freeVars(cfd).isEmpty) + } else { + fail("Unexpected fun def: " + cfd) + } + }) + } + + private def freeVars(fd: FunDef): Set[Identifier] = variablesOf(fd.fullBody) -- fd.paramIds + } -- GitLab