Skip to content
Snippets Groups Projects
Commit 9d7510a8 authored by Régis Blanc's avatar Régis Blanc Committed by Ravi
Browse files

function closure does not capture let bindings

parent 62c98020
No related branches found
No related tags found
No related merge requests found
...@@ -36,8 +36,6 @@ object FunctionClosure extends TransformationPhase { ...@@ -36,8 +36,6 @@ object FunctionClosure extends TransformationPhase {
val nestedWithPaths = (for((fds, path) <- nestedWithPathsFull; fd <- fds) yield (fd, path)).toMap val nestedWithPaths = (for((fds, path) <- nestedWithPathsFull; fd <- fds) yield (fd, path)).toMap
val nestedFuns = nestedWithPaths.keys.toSeq val nestedFuns = nestedWithPaths.keys.toSeq
//println(nestedWithPaths)
// Transitively called funcions from each function // Transitively called funcions from each function
val callGraph: Map[FunDef, Set[FunDef]] = transitiveClosure( val callGraph: Map[FunDef, Set[FunDef]] = transitiveClosure(
nestedFuns.map { f => nestedFuns.map { f =>
...@@ -56,14 +54,12 @@ object FunctionClosure extends TransformationPhase { ...@@ -56,14 +54,12 @@ object FunctionClosure extends TransformationPhase {
//println("call graph: " + callGraph) //println("call graph: " + callGraph)
def freeVars(fd: FunDef, pc: Path): Set[Identifier] = def freeVars(fd: FunDef, pc: Path): Set[Identifier] =
variablesOf(fd.fullBody) ++ pc.variables -- fd.paramIds -- pc.bindings.map(_._1) variablesOf(fd.fullBody) ++ pc.variables ++ pc.bindings.map(_._1) -- fd.paramIds
//def freeVars(fd: FunDef): Set[Identifier] =
// variablesOf(fd.fullBody) -- fd.paramIds
// All free variables one should include. // All free variables one should include.
// Contains free vars of the function itself plus of all transitively called functions. // Contains free vars of the function itself plus of all transitively called functions.
// also contains free vars from PC if the PC is relevant to the fundef // also contains free vars from PC if the PC is relevant to the fundef
val transFree = { val transFreeWithBindings = {
def step(current: Map[FunDef, Set[Identifier]]): Map[FunDef, Set[Identifier]] = { def step(current: Map[FunDef, Set[Identifier]]): Map[FunDef, Set[Identifier]] = {
nestedFuns.map(fd => { nestedFuns.map(fd => {
val transFreeVars = (callGraph(fd) + fd).flatMap((fd2:FunDef) => current(fd2)) val transFreeVars = (callGraph(fd) + fd).flatMap((fd2:FunDef) => current(fd2))
...@@ -73,14 +69,12 @@ object FunctionClosure extends TransformationPhase { ...@@ -73,14 +69,12 @@ object FunctionClosure extends TransformationPhase {
} }
utils.fixpoint(step, -1)(nestedFuns.map(fd => (fd, variablesOf(fd.fullBody) -- fd.paramIds)).toMap) utils.fixpoint(step, -1)(nestedFuns.map(fd => (fd, variablesOf(fd.fullBody) -- fd.paramIds)).toMap)
}.map(p => (p._1, p._2.toSeq)) }
//println("free vars: " + transFree)
val transFree: Map[FunDef, Seq[Identifier]] =
//transFreeWithBindings.map(p => (p._1, p._2 -- nestedWithPaths(p._1).bindings.map(_._1))).map(p => (p._1, p._2.toSeq))
transFreeWithBindings.map(p => (p._1, p._2.toSeq))
// All free variables one should include.
// Contains free vars of the function itself plus of all transitively called functions.
//val transFree = nestedFuns.map { fd =>
// fd -> (callGraph(fd) + fd).flatMap( (fd2: FunDef) => freeVars(fd2, nestedWithPaths(fd2)) ).toSeq
//}.toMap
// Closed functions along with a map (old var -> new var). // Closed functions along with a map (old var -> new var).
val closed = nestedWithPaths.map { val closed = nestedWithPaths.map {
...@@ -151,31 +145,38 @@ object FunctionClosure extends TransformationPhase { ...@@ -151,31 +145,38 @@ object FunctionClosure extends TransformationPhase {
// Takes one inner function and closes it. // Takes one inner function and closes it.
private def closeFd(inner: FunDef, outer: FunDef, pc: Path, free: Seq[Identifier]): FunSubst = { private def closeFd(inner: FunDef, outer: FunDef, pc: Path, free: Seq[Identifier]): FunSubst = {
//println("inner: " + inner)
//println("pc: " + pc)
//println("free: " + free.map(_.uniqueName))
val reqPC = pc.filterByIds(free.toSet)
val tpFresh = outer.tparams map { _.freshen } val tpFresh = outer.tparams map { _.freshen }
val tparamsMap = outer.tparams.zip(tpFresh map {_.tp}).toMap val tparamsMap = outer.tparams.zip(tpFresh map {_.tp}).toMap
val freshVals = (inner.paramIds ++ free).map{_.freshen}.map(instantiateType(_, tparamsMap)) val freshVals = (inner.paramIds ++ free).map{_.freshen}.map(instantiateType(_, tparamsMap))
val freeMap = (inner.paramIds ++ free).zip(freshVals).toMap val freeMap = (inner.paramIds ++ free).zip(freshVals).toMap
val freshParams = (inner.paramIds ++ free).filterNot(v => reqPC.isBound(v)).map(v => freeMap(v))
val newFd = inner.duplicate( val newFd = inner.duplicate(
inner.id.freshen, inner.id.freshen,
inner.tparams ++ tpFresh, inner.tparams ++ tpFresh,
freshVals.map(ValDef(_)), freshParams.map(ValDef(_)),
instantiateType(inner.returnType, tparamsMap) instantiateType(inner.returnType, tparamsMap)
) )
val instBody = instantiateType( val instBody = instantiateType(
withPath(newFd.fullBody, pc.filterByIds(free.toSet)), withPath(newFd.fullBody, reqPC),
tparamsMap, tparamsMap,
freeMap freeMap
) )
newFd.fullBody = preMap { newFd.fullBody = preMap {
case Let(id, v, r) if freeMap.isDefinedAt(id) => Some(Let(freeMap(id), v, r))
case fi @ FunctionInvocation(tfd, args) if tfd.fd == inner => case fi @ FunctionInvocation(tfd, args) if tfd.fd == inner =>
Some(FunctionInvocation( Some(FunctionInvocation(
newFd.typed(tfd.tps ++ tpFresh.map{ _.tp }), newFd.typed(tfd.tps ++ tpFresh.map{ _.tp }),
args ++ freshVals.drop(args.length).map(Variable) args ++ freshParams.drop(args.length).map(Variable)
).setPos(fi)) ).setPos(fi))
case _ => None case _ => None
}(instBody) }(instBody)
...@@ -183,6 +184,7 @@ object FunctionClosure extends TransformationPhase { ...@@ -183,6 +184,7 @@ object FunctionClosure extends TransformationPhase {
//HACK to make sure substitution happened even in nested fundef //HACK to make sure substitution happened even in nested fundef
newFd.fullBody = replaceFromIDs(freeMap.map(p => (p._1, p._2.toVariable)), newFd.fullBody) newFd.fullBody = replaceFromIDs(freeMap.map(p => (p._1, p._2.toVariable)), newFd.fullBody)
FunSubst(newFd, freeMap, tparamsMap.map{ case (from, to) => from.tp -> to}) FunSubst(newFd, freeMap, tparamsMap.map{ case (from, to) => from.tp -> to})
} }
......
...@@ -73,6 +73,8 @@ class Path private[purescala]( ...@@ -73,6 +73,8 @@ class Path private[purescala](
lazy val bindings: Seq[(Identifier, Expr)] = elements.collect { case Left(p) => p } lazy val bindings: Seq[(Identifier, Expr)] = elements.collect { case Left(p) => p }
lazy val conditions: Seq[Expr] = elements.collect { case Right(e) => e } lazy val conditions: Seq[Expr] = elements.collect { case Right(e) => e }
def isBound(id: Identifier): Boolean = bindings.exists(p => p._1 == id)
private def fold[T](base: T, combineLet: (Identifier, Expr, T) => T, combineCond: (Expr, T) => T) private def fold[T](base: T, combineLet: (Identifier, Expr, T) => T, combineCond: (Expr, T) => T)
(elems: Seq[Either[(Identifier, Expr), Expr]]): T = elems.foldRight(base) { (elems: Seq[Either[(Identifier, Expr), Expr]]): T = elems.foldRight(base) {
case (Left((id, e)), res) => combineLet(id, e, res) case (Left((id, e)), res) => combineLet(id, e, res)
......
...@@ -131,6 +131,170 @@ class FunctionClosureSuite extends FunSuite with helpers.ExpressionsDSL { ...@@ -131,6 +131,170 @@ class FunctionClosureSuite extends FunSuite with helpers.ExpressionsDSL {
}) })
} }
test("close captures enclosing require if needed") {
val nested = new FunDef(FreshIdentifier("nested"), Seq(), Seq(ValDef(y.id)), IntegerType)
nested.body = Some(Plus(x, y))
val fd = new FunDef(FreshIdentifier("f"), Seq(), Seq(ValDef(x.id)), IntegerType)
fd.body = Some(Require(GreaterEquals(x, bi(0)), LetDef(Seq(nested), x)))
val cfds = FunctionClosure.close(fd)
assert(cfds.size === 2)
cfds.foreach(cfd => {
if(cfd.id.name == "f") {
assert(cfd.returnType === fd.returnType)
assert(cfd.params.size === fd.params.size)
assert(freeVars(cfd).isEmpty)
} else if(cfd.id.name == "nested") {
assert(cfd.returnType === nested.returnType)
assert(cfd.params.size === 2)
assert(freeVars(cfd).isEmpty)
assert(cfd.precondition != None)
//next assert is assuming that the function closures always adds paramters at the end of the parameter list
cfd.precondition.foreach(pre => assert(pre == GreaterEquals(cfd.params.last.toVariable, bi(0))))
} else {
fail("Unexpected fun def: " + cfd)
}
})
}
test("close captures transitive dependencies within path") {
val x2 = FreshIdentifier("x2", IntegerType).toVariable
val x3 = FreshIdentifier("x3", IntegerType).toVariable
val nested = new FunDef(FreshIdentifier("nested"), Seq(), Seq(ValDef(y.id)), IntegerType)
nested.body = Some(Plus(y, x3))
val fd = new FunDef(FreshIdentifier("f"), Seq(), Seq(ValDef(x.id)), IntegerType)
fd.body = Some(
Let(x2.id, Plus(x, bi(1)),
Let(x3.id, Plus(x2, bi(1)),
LetDef(Seq(nested), x)))
)
val cfds = FunctionClosure.close(fd)
assert(cfds.size === 2)
cfds.foreach(cfd => {
if(cfd.id.name == "f") {
assert(cfd.returnType === fd.returnType)
assert(cfd.params.size === fd.params.size)
assert(freeVars(cfd).isEmpty)
} else if(cfd.id.name == "nested") {
assert(cfd.returnType === nested.returnType)
assert(cfd.params.size === 2)
assert(freeVars(cfd).isEmpty)
} else {
fail("Unexpected fun def: " + cfd)
}
})
}
test("close captures transitive dependencies within path but not too many") {
val x2 = FreshIdentifier("x2", IntegerType).toVariable
val x3 = FreshIdentifier("x3", IntegerType).toVariable
val x4 = FreshIdentifier("x4", IntegerType).toVariable
val nested = new FunDef(FreshIdentifier("nested"), Seq(), Seq(ValDef(y.id)), IntegerType)
nested.body = Some(Plus(y, x4))
val fd = new FunDef(FreshIdentifier("f"), Seq(), Seq(ValDef(x.id), ValDef(z.id)), IntegerType)
fd.body = Some(
Let(x2.id, Plus(x, bi(1)),
Let(x3.id, Plus(z, bi(1)),
Let(x4.id, Plus(x2, bi(1)),
LetDef(Seq(nested), x))))
)
val cfds = FunctionClosure.close(fd)
assert(cfds.size === 2)
cfds.foreach(cfd => {
if(cfd.id.name == "f") {
assert(cfd.returnType === fd.returnType)
assert(cfd.params.size === fd.params.size)
assert(freeVars(cfd).isEmpty)
} else if(cfd.id.name == "nested") {
assert(cfd.returnType === nested.returnType)
assert(cfd.params.size === 2)
assert(freeVars(cfd).isEmpty)
} else {
fail("Unexpected fun def: " + cfd)
}
})
}
test("close captures enclosing require of callee functions") {
val callee = new FunDef(FreshIdentifier("callee"), Seq(), Seq(), IntegerType)
callee.body = Some(x)
val caller = new FunDef(FreshIdentifier("caller"), Seq(), Seq(), IntegerType)
caller.body = Some(FunctionInvocation(callee.typed, Seq()))
val fd = new FunDef(FreshIdentifier("f"), Seq(), Seq(ValDef(x.id)), IntegerType)
fd.body = Some(Require(GreaterEquals(x, bi(0)), LetDef(Seq(callee, caller), x)))
val cfds = FunctionClosure.close(fd)
assert(cfds.size === 3)
cfds.foreach(cfd => {
if(cfd.id.name == "f") {
assert(cfd.returnType === fd.returnType)
assert(cfd.params.size === fd.params.size)
assert(freeVars(cfd).isEmpty)
} else if(cfd.id.name == "callee") {
assert(cfd.returnType === callee.returnType)
assert(cfd.params.size === 1)
assert(freeVars(cfd).isEmpty)
} else if(cfd.id.name == "caller") {
assert(cfd.returnType === caller.returnType)
assert(cfd.params.size === 1)
assert(freeVars(cfd).isEmpty)
assert(cfd.precondition != None)
//next assert is assuming that the function closures always adds paramters at the end of the parameter list
cfd.precondition.foreach(pre => assert(pre == GreaterEquals(cfd.params.last.toVariable, bi(0))))
} else {
fail("Unexpected fun def: " + cfd)
}
})
val deeplyNested2 = new FunDef(FreshIdentifier("deeplyNested"), Seq(), Seq(ValDef(z.id)), IntegerType)
deeplyNested2.body = Some(Require(GreaterEquals(x, bi(0)), z))
val nested2 = new FunDef(FreshIdentifier("nested"), Seq(), Seq(ValDef(y.id)), IntegerType)
nested2.body = Some(LetDef(Seq(deeplyNested2), FunctionInvocation(deeplyNested2.typed, Seq(y))))
val fd2 = new FunDef(FreshIdentifier("f"), Seq(), Seq(ValDef(x.id)), IntegerType)
fd2.body = Some(Require(GreaterEquals(x, bi(0)),
LetDef(Seq(nested2), FunctionInvocation(nested2.typed, Seq(x)))))
val cfds2 = FunctionClosure.close(fd2)
assert(cfds2.size === 3)
cfds2.foreach(cfd => {
if(cfd.id.name == "f") {
assert(cfd.returnType === fd2.returnType)
assert(cfd.params.size === fd2.params.size)
assert(freeVars(cfd).isEmpty)
} else if(cfd.id.name == "nested") {
assert(cfd.returnType === nested2.returnType)
assert(cfd.params.size === 2)
assert(freeVars(cfd).isEmpty)
} else if(cfd.id.name == "deeplyNested") {
assert(cfd.returnType === deeplyNested2.returnType)
assert(cfd.params.size === 2)
assert(freeVars(cfd).isEmpty)
} else {
fail("Unexpected fun def: " + cfd)
}
})
}
private def freeVars(fd: FunDef): Set[Identifier] = variablesOf(fd.fullBody) -- fd.paramIds private def freeVars(fd: FunDef): Set[Identifier] = variablesOf(fd.fullBody) -- fd.paramIds
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment