From a462d05be83eec8ebc9b47981f967a6a1f6cacd3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9gis=20Blanc?= <regwblanc@gmail.com> Date: Sat, 16 Apr 2016 21:39:03 +0200 Subject: [PATCH] function closure does not capture let bindings --- .../leon/purescala/FunctionClosure.scala | 34 ++-- src/main/scala/leon/purescala/Path.scala | 2 + .../unit/purescala/FunctionClosureSuite.scala | 164 ++++++++++++++++++ 3 files changed, 184 insertions(+), 16 deletions(-) diff --git a/src/main/scala/leon/purescala/FunctionClosure.scala b/src/main/scala/leon/purescala/FunctionClosure.scala index 98d819eb2..1375cdf26 100644 --- a/src/main/scala/leon/purescala/FunctionClosure.scala +++ b/src/main/scala/leon/purescala/FunctionClosure.scala @@ -36,8 +36,6 @@ object FunctionClosure extends TransformationPhase { val nestedWithPaths = (for((fds, path) <- nestedWithPathsFull; fd <- fds) yield (fd, path)).toMap val nestedFuns = nestedWithPaths.keys.toSeq - //println(nestedWithPaths) - // Transitively called funcions from each function val callGraph: Map[FunDef, Set[FunDef]] = transitiveClosure( nestedFuns.map { f => @@ -56,14 +54,12 @@ object FunctionClosure extends TransformationPhase { //println("call graph: " + callGraph) def freeVars(fd: FunDef, pc: Path): Set[Identifier] = - variablesOf(fd.fullBody) ++ pc.variables -- fd.paramIds -- pc.bindings.map(_._1) - //def freeVars(fd: FunDef): Set[Identifier] = - // variablesOf(fd.fullBody) -- fd.paramIds + variablesOf(fd.fullBody) ++ pc.variables ++ pc.bindings.map(_._1) -- fd.paramIds // All free variables one should include. // Contains free vars of the function itself plus of all transitively called functions. // also contains free vars from PC if the PC is relevant to the fundef - val transFree = { + val transFreeWithBindings = { def step(current: Map[FunDef, Set[Identifier]]): Map[FunDef, Set[Identifier]] = { nestedFuns.map(fd => { val transFreeVars = (callGraph(fd) + fd).flatMap((fd2:FunDef) => current(fd2)) @@ -73,14 +69,12 @@ object FunctionClosure extends TransformationPhase { } utils.fixpoint(step, -1)(nestedFuns.map(fd => (fd, variablesOf(fd.fullBody) -- fd.paramIds)).toMap) - }.map(p => (p._1, p._2.toSeq)) - //println("free vars: " + transFree) + } + + val transFree: Map[FunDef, Seq[Identifier]] = + //transFreeWithBindings.map(p => (p._1, p._2 -- nestedWithPaths(p._1).bindings.map(_._1))).map(p => (p._1, p._2.toSeq)) + transFreeWithBindings.map(p => (p._1, p._2.toSeq)) - // All free variables one should include. - // Contains free vars of the function itself plus of all transitively called functions. - //val transFree = nestedFuns.map { fd => - // fd -> (callGraph(fd) + fd).flatMap( (fd2: FunDef) => freeVars(fd2, nestedWithPaths(fd2)) ).toSeq - //}.toMap // Closed functions along with a map (old var -> new var). val closed = nestedWithPaths.map { @@ -151,31 +145,38 @@ object FunctionClosure extends TransformationPhase { // Takes one inner function and closes it. private def closeFd(inner: FunDef, outer: FunDef, pc: Path, free: Seq[Identifier]): FunSubst = { + //println("inner: " + inner) + //println("pc: " + pc) + //println("free: " + free.map(_.uniqueName)) + + val reqPC = pc.filterByIds(free.toSet) val tpFresh = outer.tparams map { _.freshen } val tparamsMap = outer.tparams.zip(tpFresh map {_.tp}).toMap val freshVals = (inner.paramIds ++ free).map{_.freshen}.map(instantiateType(_, tparamsMap)) val freeMap = (inner.paramIds ++ free).zip(freshVals).toMap + val freshParams = (inner.paramIds ++ free).filterNot(v => reqPC.isBound(v)).map(v => freeMap(v)) val newFd = inner.duplicate( inner.id.freshen, inner.tparams ++ tpFresh, - freshVals.map(ValDef(_)), + freshParams.map(ValDef(_)), instantiateType(inner.returnType, tparamsMap) ) val instBody = instantiateType( - withPath(newFd.fullBody, pc.filterByIds(free.toSet)), + withPath(newFd.fullBody, reqPC), tparamsMap, freeMap ) newFd.fullBody = preMap { + case Let(id, v, r) if freeMap.isDefinedAt(id) => Some(Let(freeMap(id), v, r)) case fi @ FunctionInvocation(tfd, args) if tfd.fd == inner => Some(FunctionInvocation( newFd.typed(tfd.tps ++ tpFresh.map{ _.tp }), - args ++ freshVals.drop(args.length).map(Variable) + args ++ freshParams.drop(args.length).map(Variable) ).setPos(fi)) case _ => None }(instBody) @@ -183,6 +184,7 @@ object FunctionClosure extends TransformationPhase { //HACK to make sure substitution happened even in nested fundef newFd.fullBody = replaceFromIDs(freeMap.map(p => (p._1, p._2.toVariable)), newFd.fullBody) + FunSubst(newFd, freeMap, tparamsMap.map{ case (from, to) => from.tp -> to}) } diff --git a/src/main/scala/leon/purescala/Path.scala b/src/main/scala/leon/purescala/Path.scala index 9b7e4786a..c7d667676 100644 --- a/src/main/scala/leon/purescala/Path.scala +++ b/src/main/scala/leon/purescala/Path.scala @@ -73,6 +73,8 @@ class Path private[purescala]( lazy val bindings: Seq[(Identifier, Expr)] = elements.collect { case Left(p) => p } lazy val conditions: Seq[Expr] = elements.collect { case Right(e) => e } + def isBound(id: Identifier): Boolean = bindings.exists(p => p._1 == id) + private def fold[T](base: T, combineLet: (Identifier, Expr, T) => T, combineCond: (Expr, T) => T) (elems: Seq[Either[(Identifier, Expr), Expr]]): T = elems.foldRight(base) { case (Left((id, e)), res) => combineLet(id, e, res) diff --git a/src/test/scala/leon/unit/purescala/FunctionClosureSuite.scala b/src/test/scala/leon/unit/purescala/FunctionClosureSuite.scala index 2019c0452..98132fd9a 100644 --- a/src/test/scala/leon/unit/purescala/FunctionClosureSuite.scala +++ b/src/test/scala/leon/unit/purescala/FunctionClosureSuite.scala @@ -131,6 +131,170 @@ class FunctionClosureSuite extends FunSuite with helpers.ExpressionsDSL { }) } + test("close captures enclosing require if needed") { + val nested = new FunDef(FreshIdentifier("nested"), Seq(), Seq(ValDef(y.id)), IntegerType) + nested.body = Some(Plus(x, y)) + + val fd = new FunDef(FreshIdentifier("f"), Seq(), Seq(ValDef(x.id)), IntegerType) + fd.body = Some(Require(GreaterEquals(x, bi(0)), LetDef(Seq(nested), x))) + + val cfds = FunctionClosure.close(fd) + assert(cfds.size === 2) + + cfds.foreach(cfd => { + if(cfd.id.name == "f") { + assert(cfd.returnType === fd.returnType) + assert(cfd.params.size === fd.params.size) + assert(freeVars(cfd).isEmpty) + } else if(cfd.id.name == "nested") { + assert(cfd.returnType === nested.returnType) + assert(cfd.params.size === 2) + assert(freeVars(cfd).isEmpty) + assert(cfd.precondition != None) + //next assert is assuming that the function closures always adds paramters at the end of the parameter list + cfd.precondition.foreach(pre => assert(pre == GreaterEquals(cfd.params.last.toVariable, bi(0)))) + } else { + fail("Unexpected fun def: " + cfd) + } + }) + } + + test("close captures transitive dependencies within path") { + val x2 = FreshIdentifier("x2", IntegerType).toVariable + val x3 = FreshIdentifier("x3", IntegerType).toVariable + + val nested = new FunDef(FreshIdentifier("nested"), Seq(), Seq(ValDef(y.id)), IntegerType) + nested.body = Some(Plus(y, x3)) + + + val fd = new FunDef(FreshIdentifier("f"), Seq(), Seq(ValDef(x.id)), IntegerType) + fd.body = Some( + Let(x2.id, Plus(x, bi(1)), + Let(x3.id, Plus(x2, bi(1)), + LetDef(Seq(nested), x))) + ) + + val cfds = FunctionClosure.close(fd) + assert(cfds.size === 2) + + cfds.foreach(cfd => { + if(cfd.id.name == "f") { + assert(cfd.returnType === fd.returnType) + assert(cfd.params.size === fd.params.size) + assert(freeVars(cfd).isEmpty) + } else if(cfd.id.name == "nested") { + assert(cfd.returnType === nested.returnType) + assert(cfd.params.size === 2) + assert(freeVars(cfd).isEmpty) + } else { + fail("Unexpected fun def: " + cfd) + } + }) + } + + test("close captures transitive dependencies within path but not too many") { + val x2 = FreshIdentifier("x2", IntegerType).toVariable + val x3 = FreshIdentifier("x3", IntegerType).toVariable + val x4 = FreshIdentifier("x4", IntegerType).toVariable + + val nested = new FunDef(FreshIdentifier("nested"), Seq(), Seq(ValDef(y.id)), IntegerType) + nested.body = Some(Plus(y, x4)) + + + val fd = new FunDef(FreshIdentifier("f"), Seq(), Seq(ValDef(x.id), ValDef(z.id)), IntegerType) + fd.body = Some( + Let(x2.id, Plus(x, bi(1)), + Let(x3.id, Plus(z, bi(1)), + Let(x4.id, Plus(x2, bi(1)), + LetDef(Seq(nested), x)))) + ) + + val cfds = FunctionClosure.close(fd) + assert(cfds.size === 2) + + cfds.foreach(cfd => { + if(cfd.id.name == "f") { + assert(cfd.returnType === fd.returnType) + assert(cfd.params.size === fd.params.size) + assert(freeVars(cfd).isEmpty) + } else if(cfd.id.name == "nested") { + assert(cfd.returnType === nested.returnType) + assert(cfd.params.size === 2) + assert(freeVars(cfd).isEmpty) + } else { + fail("Unexpected fun def: " + cfd) + } + }) + } + + test("close captures enclosing require of callee functions") { + val callee = new FunDef(FreshIdentifier("callee"), Seq(), Seq(), IntegerType) + callee.body = Some(x) + + val caller = new FunDef(FreshIdentifier("caller"), Seq(), Seq(), IntegerType) + caller.body = Some(FunctionInvocation(callee.typed, Seq())) + + val fd = new FunDef(FreshIdentifier("f"), Seq(), Seq(ValDef(x.id)), IntegerType) + fd.body = Some(Require(GreaterEquals(x, bi(0)), LetDef(Seq(callee, caller), x))) + + val cfds = FunctionClosure.close(fd) + assert(cfds.size === 3) + + cfds.foreach(cfd => { + if(cfd.id.name == "f") { + assert(cfd.returnType === fd.returnType) + assert(cfd.params.size === fd.params.size) + assert(freeVars(cfd).isEmpty) + } else if(cfd.id.name == "callee") { + assert(cfd.returnType === callee.returnType) + assert(cfd.params.size === 1) + assert(freeVars(cfd).isEmpty) + } else if(cfd.id.name == "caller") { + assert(cfd.returnType === caller.returnType) + assert(cfd.params.size === 1) + assert(freeVars(cfd).isEmpty) + assert(cfd.precondition != None) + //next assert is assuming that the function closures always adds paramters at the end of the parameter list + cfd.precondition.foreach(pre => assert(pre == GreaterEquals(cfd.params.last.toVariable, bi(0)))) + } else { + fail("Unexpected fun def: " + cfd) + } + }) + + + val deeplyNested2 = new FunDef(FreshIdentifier("deeplyNested"), Seq(), Seq(ValDef(z.id)), IntegerType) + deeplyNested2.body = Some(Require(GreaterEquals(x, bi(0)), z)) + + val nested2 = new FunDef(FreshIdentifier("nested"), Seq(), Seq(ValDef(y.id)), IntegerType) + nested2.body = Some(LetDef(Seq(deeplyNested2), FunctionInvocation(deeplyNested2.typed, Seq(y)))) + + val fd2 = new FunDef(FreshIdentifier("f"), Seq(), Seq(ValDef(x.id)), IntegerType) + fd2.body = Some(Require(GreaterEquals(x, bi(0)), + LetDef(Seq(nested2), FunctionInvocation(nested2.typed, Seq(x))))) + + val cfds2 = FunctionClosure.close(fd2) + assert(cfds2.size === 3) + + cfds2.foreach(cfd => { + if(cfd.id.name == "f") { + assert(cfd.returnType === fd2.returnType) + assert(cfd.params.size === fd2.params.size) + assert(freeVars(cfd).isEmpty) + } else if(cfd.id.name == "nested") { + assert(cfd.returnType === nested2.returnType) + assert(cfd.params.size === 2) + assert(freeVars(cfd).isEmpty) + } else if(cfd.id.name == "deeplyNested") { + assert(cfd.returnType === deeplyNested2.returnType) + assert(cfd.params.size === 2) + assert(freeVars(cfd).isEmpty) + } else { + fail("Unexpected fun def: " + cfd) + } + }) + } + + private def freeVars(fd: FunDef): Set[Identifier] = variablesOf(fd.fullBody) -- fd.paramIds } -- GitLab