Skip to content
Snippets Groups Projects
Commit dc159e1d authored by Régis Blanc's avatar Régis Blanc Committed by Nicolas Voirol
Browse files

function closure only capture required variables

parent a18fc92f
No related branches found
No related tags found
No related merge requests found
...@@ -36,6 +36,8 @@ object FunctionClosure extends TransformationPhase { ...@@ -36,6 +36,8 @@ object FunctionClosure extends TransformationPhase {
val nestedWithPaths = (for((fds, path) <- nestedWithPathsFull; fd <- fds) yield (fd, path)).toMap val nestedWithPaths = (for((fds, path) <- nestedWithPathsFull; fd <- fds) yield (fd, path)).toMap
val nestedFuns = nestedWithPaths.keys.toSeq val nestedFuns = nestedWithPaths.keys.toSeq
//println(nestedWithPaths)
// Transitively called funcions from each function // Transitively called funcions from each function
val callGraph: Map[FunDef, Set[FunDef]] = transitiveClosure( val callGraph: Map[FunDef, Set[FunDef]] = transitiveClosure(
nestedFuns.map { f => nestedFuns.map { f =>
...@@ -55,34 +57,30 @@ object FunctionClosure extends TransformationPhase { ...@@ -55,34 +57,30 @@ object FunctionClosure extends TransformationPhase {
def freeVars(fd: FunDef, pc: Path): Set[Identifier] = def freeVars(fd: FunDef, pc: Path): Set[Identifier] =
variablesOf(fd.fullBody) ++ pc.variables -- fd.paramIds -- pc.bindings.map(_._1) variablesOf(fd.fullBody) ++ pc.variables -- fd.paramIds -- pc.bindings.map(_._1)
//def freeVars(fd: FunDef): Set[Identifier] =
// variablesOf(fd.fullBody) -- fd.paramIds
// All free variables one should include. // All free variables one should include.
// Contains free vars of the function itself plus of all transitively called functions. // Contains free vars of the function itself plus of all transitively called functions.
// also contains free vars from PC if the PC is relevant to the fundef // also contains free vars from PC if the PC is relevant to the fundef
/*
val transFree = { val transFree = {
def step(current: Map[FunDef, Set[Identifier]]): Map[FunDef, Set[Identifier]] = { def step(current: Map[FunDef, Set[Identifier]]): Map[FunDef, Set[Identifier]] = {
nestedFuns.map(fd => { nestedFuns.map(fd => {
val transFreeVars = (callGraph(fd) + fd).flatMap((fd2:FunDef) => freeVars(fd2)) val transFreeVars = (callGraph(fd) + fd).flatMap((fd2:FunDef) => current(fd2))
val reqPaths = Seq(nestedWithPaths(fd)).filter(pathExpr => exists{ val reqPath = nestedWithPaths(fd).filterByIds(transFreeVars)
case _ => true //TODO: for now we take all PCs, need to refine (fd, transFreeVars ++ freeVars(fd, reqPath))
//case Variable(id) => transFreeVars.contains(id)
//case _ => false
}(pathExpr))
(fd, transFreeVars ++ reqPaths.flatMap(p => variablesOf(p)) -- fd.paramIds)
}).toMap }).toMap
} }
utils.fixpoint(step, -1)(nestedFuns.map(fd => (fd, freeVars(fd))).toMap) utils.fixpoint(step, -1)(nestedFuns.map(fd => (fd, variablesOf(fd.fullBody) -- fd.paramIds)).toMap)
}.map(p => (p._1, p._2.toSeq)) }.map(p => (p._1, p._2.toSeq))
*/
//println("free vars: " + transFree) //println("free vars: " + transFree)
// All free variables one should include. // All free variables one should include.
// Contains free vars of the function itself plus of all transitively called functions. // Contains free vars of the function itself plus of all transitively called functions.
val transFree = nestedFuns.map { fd => //val transFree = nestedFuns.map { fd =>
fd -> (callGraph(fd) + fd).flatMap( (fd2: FunDef) => freeVars(fd2, nestedWithPaths(fd2)) ).toSeq // fd -> (callGraph(fd) + fd).flatMap( (fd2: FunDef) => freeVars(fd2, nestedWithPaths(fd2)) ).toSeq
}.toMap //}.toMap
// Closed functions along with a map (old var -> new var). // Closed functions along with a map (old var -> new var).
val closed = nestedWithPaths.map { val closed = nestedWithPaths.map {
...@@ -168,7 +166,7 @@ object FunctionClosure extends TransformationPhase { ...@@ -168,7 +166,7 @@ object FunctionClosure extends TransformationPhase {
) )
val instBody = instantiateType( val instBody = instantiateType(
withPath(newFd.fullBody, pc), withPath(newFd.fullBody, pc.filterByIds(free.toSet)),
tparamsMap, tparamsMap,
freeMap freeMap
) )
......
...@@ -53,6 +53,19 @@ class Path private[purescala]( ...@@ -53,6 +53,19 @@ class Path private[purescala](
new Path(outers :+ Right(not(fold[Expr](BooleanLiteral(true), let, Constructors.and(_, _))(rest)))) new Path(outers :+ Right(not(fold[Expr](BooleanLiteral(true), let, Constructors.and(_, _))(rest))))
} }
def filterByIds(ids: Set[Identifier]): Path = {
def containsIds(ids: Set[Identifier])(e: Expr): Boolean = exists{
case Variable(id) => ids.contains(id)
case _ => false
}(e)
val newElements = elements.filter{
case Left((id, e)) => ids.contains(id) || containsIds(ids)(e)
case Right(e) => containsIds(ids)(e)
}
new Path(newElements)
}
lazy val variables: Set[Identifier] = fold[Set[Identifier]](Set.empty, lazy val variables: Set[Identifier] = fold[Set[Identifier]](Set.empty,
(id, e, res) => res - id ++ variablesOf(e), (e, res) => res ++ variablesOf(e) (id, e, res) => res - id ++ variablesOf(e), (e, res) => res ++ variablesOf(e)
)(elements) )(elements)
......
...@@ -25,4 +25,80 @@ class FunctionClosureSuite extends FunSuite with helpers.ExpressionsDSL { ...@@ -25,4 +25,80 @@ class FunctionClosureSuite extends FunSuite with helpers.ExpressionsDSL {
assert(fd1.body === cfd1.head.body) assert(fd1.body === cfd1.head.body)
} }
test("close does not capture param from parent if not needed") {
val nested = new FunDef(FreshIdentifier("nested"), Seq(), Seq(ValDef(y.id)), IntegerType)
nested.body = Some(y)
val fd = new FunDef(FreshIdentifier("f"), Seq(), Seq(ValDef(x.id)), IntegerType)
fd.body = Some(LetDef(Seq(nested), x))
val cfds = FunctionClosure.close(fd)
assert(cfds.size === 2)
cfds.foreach(cfd => {
if(cfd.id.name == "f") {
assert(cfd.returnType === fd.returnType)
assert(cfd.params.size === fd.params.size)
assert(freeVars(cfd).isEmpty)
} else if(cfd.id.name == "nested") {
assert(cfd.returnType === nested.returnType)
assert(cfd.params.size === nested.params.size)
assert(freeVars(cfd).isEmpty)
} else {
fail("Unexpected fun def: " + cfd)
}
})
val nested2 = new FunDef(FreshIdentifier("nested"), Seq(), Seq(ValDef(y.id)), IntegerType)
nested2.body = Some(y)
val fd2 = new FunDef(FreshIdentifier("f"), Seq(), Seq(ValDef(x.id)), IntegerType)
fd2.body = Some(Let(z.id, Plus(x, bi(1)), LetDef(Seq(nested2), x)))
val cfds2 = FunctionClosure.close(fd2)
assert(cfds2.size === 2)
cfds2.foreach(cfd => {
if(cfd.id.name == "f") {
assert(cfd.returnType === fd2.returnType)
assert(cfd.params.size === fd2.params.size)
assert(freeVars(cfd).isEmpty)
} else if(cfd.id.name == "nested") {
assert(cfd.returnType === nested2.returnType)
assert(cfd.params.size === nested2.params.size)
assert(freeVars(cfd).isEmpty)
} else {
fail("Unexpected fun def: " + cfd)
}
})
}
test("close does not capture enclosing require if not needed") {
val nested = new FunDef(FreshIdentifier("nested"), Seq(), Seq(ValDef(y.id)), IntegerType)
nested.body = Some(y)
val fd = new FunDef(FreshIdentifier("f"), Seq(), Seq(ValDef(x.id)), IntegerType)
fd.body = Some(Require(GreaterEquals(x, bi(0)), Let(z.id, Plus(x, bi(1)), LetDef(Seq(nested), x))))
val cfds = FunctionClosure.close(fd)
assert(cfds.size === 2)
cfds.foreach(cfd => {
if(cfd.id.name == "f") {
assert(cfd.returnType === fd.returnType)
assert(cfd.params.size === fd.params.size)
assert(freeVars(cfd).isEmpty)
} else if(cfd.id.name == "nested") {
assert(cfd.returnType === nested.returnType)
assert(cfd.params.size === nested.params.size)
assert(freeVars(cfd).isEmpty)
} else {
fail("Unexpected fun def: " + cfd)
}
})
}
private def freeVars(fd: FunDef): Set[Identifier] = variablesOf(fd.fullBody) -- fd.paramIds
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment