diff --git a/src/main/scala/leon/purescala/FunctionClosure.scala b/src/main/scala/leon/purescala/FunctionClosure.scala index 7f76beff85796ec47d8ff5a49e4613c752c9ac41..4ee1f89386468ab0582ad461da7a897c5340c4a2 100644 --- a/src/main/scala/leon/purescala/FunctionClosure.scala +++ b/src/main/scala/leon/purescala/FunctionClosure.scala @@ -3,178 +3,160 @@ package leon package purescala -import Common._ import Definitions._ import Expressions._ -import Extractors._ import ExprOps._ import Constructors._ +import TypeOps.instantiateType +import leon.purescala.Common.Identifier +import leon.purescala.Types.TypeParameter +import utils.GraphOps._ class FunctionClosure extends TransformationPhase { - val name = "Function Closure" - val description = "Closing function with its scoping variables" + override val name: String = "Function Closure" + override val description: String = "Closing function with its scoping variables" + + private def close(fd: FunDef): Seq[FunDef] = { + + // Directly neste functions with their p.c. + val nestedWithPaths = { + val funDefs = directlyNestedFunDefs(fd.fullBody) + collectWithPC { + case LetDef(fd1, body) if funDefs(fd1) => fd1 + }(fd.fullBody) + }.toMap + + val nestedFuns = nestedWithPaths.keys.toSeq + + // Transitively called funcions from each function + val callGraph: Map[FunDef, Set[FunDef]] = transitiveClosure( + nestedFuns.map { f => + val calls = functionCallsOf(f.fullBody) collect { + case FunctionInvocation(TypedFunDef(fd, _), _) if nestedFuns.contains(fd) => + fd + } + f -> calls + }.toMap + ) + + def freeVars(fd: FunDef, pc: Expr): Set[Identifier] = + variablesOf(fd.fullBody) ++ variablesOf(pc) -- fd.paramIds + + // All free variables one should include. + // Contains free vars of the function itself plus of all transitively called functions. + val transFree = nestedFuns.map { fd => + fd -> (callGraph(fd) + fd).flatMap( (fd2:FunDef) => freeVars(fd2, nestedWithPaths(fd2)) ).toSeq + }.toMap + + // Closed functions along with a map (old var -> new var). + val closed = nestedWithPaths.map { + case (inner, pc) => inner -> step(inner, fd, pc, transFree(inner)) + } - // TODO: Rewrite this phase - /* I know, that's a lot of mutable variables */ - private var pathConstraints: List[Expr] = Nil - private var enclosingLets: List[(Identifier, Expr)] = Nil - private var newFunDefs: Map[FunDef, FunDef] = Map() - private var topLevelFuns: Set[FunDef] = Set() - private var parent: FunDef = null //refers to the current toplevel parent + // Remove LetDefs + fd.fullBody = preMap({ + case LetDef(fd, bd) => + Some(bd) + case _ => + None + }, applyRec = true)(fd.fullBody) + + val dummySubst = FunSubst( + fd, + Map.empty.withDefault(id => id), + Map.empty.withDefault(id => id) + ) + + // Refresh function calls + (dummySubst +: closed.values.toSeq).foreach { case FunSubst(f, paramsMap, tparamsMap) => + //println(f) + //paramsMap foreach { case (from, to) => + // println(from.uniqueName + " -> " + to.uniqueName) + //} + f.fullBody = preMap { + case FunctionInvocation(tfd, args) if closed contains tfd.fd => + val FunSubst(newFd, newParams, newTParams) = closed(tfd.fd) + + // New -> old map for function call + val mapReverse = newParams map { _.swap } + val extraArgs = newFd.paramIds.drop(args.size).map { id => + paramsMap(mapReverse(id)).toVariable + } + + // Similarly for type params + val tReverse = newTParams map { _.swap } + val tOrigExtraOrdered = newFd.tparams.map{_.tp}.drop(tfd.tps.length).map(tReverse) + val tFinalExtra: Seq[TypeParameter] = tOrigExtraOrdered.map( tp => + tparamsMap(tp) + ) + + Some(FunctionInvocation( + newFd.typed(tfd.tps ++ tFinalExtra), + args ++ extraArgs + )) + case _ => None + }(f.fullBody) + } - def apply(ctx: LeonContext, program: Program): Program = { + val funs = closed.values.toSeq.map{ _.newFd } - val newUnits = program.units.map { u => u.copy(defs = u.defs map { - case m: ModuleDef => - pathConstraints = Nil - enclosingLets = Nil - newFunDefs = Map() - topLevelFuns = Set() - parent = null - - val funDefs = m.definedFunctions - funDefs.foreach(fd => { - parent = fd - pathConstraints = fd.precondition.toList - fd.body = fd.body.map(b => functionClosure(b, fd.params.map(_.id).toSet, Map(), Map())) - }) - - ModuleDef(m.id, m.defs ++ topLevelFuns, m.isPackageObject ) - case cd => cd - })} - Program(newUnits) + fd +: funs.flatMap(close) } - private def functionClosure(expr: Expr, bindedVars: Set[Identifier], id2freshId: Map[Identifier, Identifier], fd2FreshFd: Map[FunDef, (FunDef, Seq[Variable])]): Expr = expr match { - case l @ LetDef(fd, rest) => { - val capturedVars: Set[Identifier] = bindedVars.diff(enclosingLets.map(_._1).toSet) - val capturedConstraints: Set[Expr] = pathConstraints.toSet - - val freshIds: Map[Identifier, Identifier] = capturedVars.map(id => (id, id.freshen)).toMap - - val extraValDefOldIds: Seq[Identifier] = capturedVars.toSeq - val extraValDefFreshIds: Seq[Identifier] = extraValDefOldIds.map(freshIds(_)) - val extraValDefs: Seq[ValDef] = extraValDefFreshIds.map(ValDef(_)) - val newValDefs: Seq[ValDef] = fd.params ++ extraValDefs - val newBindedVars: Set[Identifier] = bindedVars ++ fd.params.map(_.id) - val newFunId = FreshIdentifier(fd.id.name, alwaysShowUniqueID = true) //since we hoist this at the top level, we need to make it a unique name - - val newFunDef = new FunDef(newFunId, fd.tparams, fd.returnType, newValDefs).copiedFrom(fd) - topLevelFuns += newFunDef - newFunDef.copyContentFrom(fd) //TODO: this still has some dangerous side effects (?) - - def introduceLets(expr: Expr, fd2FreshFd: Map[FunDef, (FunDef, Seq[Variable])]): Expr = { - val (newExpr, _) = enclosingLets.foldLeft((expr, Map[Identifier, Identifier]()))((acc, p) => { - val newId = p._1.freshen - val newMap = acc._2 + (p._1 -> newId) - val newBody = functionClosure(acc._1, newBindedVars, freshIds ++ newMap, fd2FreshFd) - (Let(newId, p._2, newBody), newMap) - }) - functionClosure(newExpr, newBindedVars, freshIds, fd2FreshFd) - } - - val newPrecondition = simplifyLets(introduceLets(and((capturedConstraints ++ fd.precondition).toSeq :_*), fd2FreshFd)) - newFunDef.precondition = if(newPrecondition == BooleanLiteral(true)) None else Some(newPrecondition) - - val freshPostcondition = fd.postcondition.map { case post @ Lambda(args, body) => - Lambda(args, introduceLets(body, fd2FreshFd).setPos(body)).setPos(post) - } - newFunDef.postcondition = freshPostcondition - - pathConstraints = fd.precOrTrue :: pathConstraints - val freshBody = fd.body.map(body => introduceLets(body, fd2FreshFd + (fd -> ((newFunDef, extraValDefOldIds.map(_.toVariable)))))) - newFunDef.body = freshBody - pathConstraints = pathConstraints.tail - - val freshRest = functionClosure(rest, bindedVars, id2freshId, fd2FreshFd + (fd -> ((newFunDef, extraValDefOldIds.map(_.toVariable))))) - freshRest.copiedFrom(l) - } - case l @ Let(i,e,b) => { - val re = functionClosure(e, bindedVars, id2freshId, fd2FreshFd) - //we need the enclosing lets to always refer to the original ids, because it might be expand later in a highly nested function - enclosingLets ::= (i, replace(id2freshId.map(p => (p._2.toVariable, p._1.toVariable)), re)) - //pathConstraints :: Equals(i.toVariable, re) - val rb = functionClosure(b, bindedVars + i, id2freshId, fd2FreshFd) - enclosingLets = enclosingLets.tail - //pathConstraints = pathConstraints.tail - Let(i, re, rb).copiedFrom(l) - } - case i @ IfExpr(cond,thenn,elze) => { - /* - when acumulating path constraints, take the condition without closing it first, so this - might not work well with nested fundef in if then else condition - */ - val rCond = functionClosure(cond, bindedVars, id2freshId, fd2FreshFd) - pathConstraints ::= cond//rCond - val rThen = functionClosure(thenn, bindedVars, id2freshId, fd2FreshFd) - pathConstraints = pathConstraints.tail - pathConstraints ::= Not(cond)//Not(rCond) - val rElze = functionClosure(elze, bindedVars, id2freshId, fd2FreshFd) - pathConstraints = pathConstraints.tail - IfExpr(rCond, rThen, rElze).copiedFrom(i) - } - case fi @ FunctionInvocation(tfd, args) => fd2FreshFd.get(tfd.fd) match { - case None => - FunctionInvocation(tfd, - args.map(arg => functionClosure(arg, bindedVars, id2freshId, fd2FreshFd))).copiedFrom(fi) - case Some((nfd, extraArgs)) => - FunctionInvocation(nfd.typed(tfd.tps), - args.map(arg => functionClosure(arg, bindedVars, id2freshId, fd2FreshFd)) ++ - extraArgs.map(v => replace(id2freshId.map(p => (p._1.toVariable, p._2.toVariable)), v))).copiedFrom(fi) - } - case m @ MatchExpr(scrut,cses) => { - val scrutRec = functionClosure(scrut, bindedVars, id2freshId, fd2FreshFd) - val csesRec = cses.map{ cse => - import cse._ - val binders = pattern.binders - val cond = conditionForPattern(scrut, pattern) - pathConstraints ::= cond - val rRhs = functionClosure(rhs, bindedVars ++ binders, id2freshId, fd2FreshFd) - val rGuard = optGuard map { functionClosure(_, bindedVars ++ binders, id2freshId, fd2FreshFd) } - pathConstraints = pathConstraints.tail - MatchCase(pattern, rGuard, rRhs) - } - matchExpr(scrutRec, csesRec).copiedFrom(m) - } - case v @ Variable(id) => id2freshId.get(id) match { - case None => v - case Some(nid) => Variable(nid) - } - case n @ Operator(args, recons) => { - val rargs = args.map(a => functionClosure(a, bindedVars, id2freshId, fd2FreshFd)) - recons(rargs).copiedFrom(n) - } - case unhandled => scala.sys.error("Non-terminal case should be handled in FunctionClosure: " + unhandled) + private case class FunSubst( + newFd: FunDef, + paramsMap: Map[Identifier, Identifier], + tparamsMap: Map[TypeParameter, TypeParameter] + ) + + private def step(inner: FunDef, outer: FunDef, pc: Expr, free: Seq[Identifier]): FunSubst = { + + val tpFresh = outer.tparams map { _.freshen } + val tparamsMap = outer.tparams.zip(tpFresh map {_.tp}).toMap + + val freshVals = (inner.paramIds ++ free).map{_.freshen}.map(instantiateType(_, tparamsMap)) + val freeMap = (inner.paramIds ++ free).zip(freshVals).toMap + + val newFd = new FunDef( + inner.id.freshen, + inner.tparams ++ tpFresh, + instantiateType(inner.returnType, tparamsMap), + freshVals.map(ValDef(_)) + ) + newFd.copyContentFrom(inner) + newFd.precondition = Some(and(pc, inner.precOrTrue)) + + val instBody = instantiateType( + newFd.fullBody, + tparamsMap, + freeMap + ) + + newFd.fullBody = preMap { + case FunctionInvocation(tfd, args) if tfd.fd == inner => + Some(FunctionInvocation( + newFd.typed(tfd.tps ++ tpFresh.map{ _.tp }), + args ++ freshVals.drop(args.length).map(Variable) + )) + case _ => None + }(instBody) + + FunSubst(newFd, freeMap, tparamsMap.map{ case (from, to) => from.tp -> to}) } - def freshIdInPat(pat: Pattern, id2freshId: Map[Identifier, Identifier]): Pattern = pat match { - case InstanceOfPattern(binder, classTypeDef) => InstanceOfPattern(binder.map(id2freshId(_)), classTypeDef) - case WildcardPattern(binder) => WildcardPattern(binder.map(id2freshId(_))) - case CaseClassPattern(binder, caseClassDef, subPatterns) => CaseClassPattern(binder.map(id2freshId(_)), caseClassDef, subPatterns.map(freshIdInPat(_, id2freshId))) - case TuplePattern(binder, subPatterns) => TuplePattern(binder.map(id2freshId(_)), subPatterns.map(freshIdInPat(_, id2freshId))) - case UnapplyPattern(binder, fd, subPatterns) => UnapplyPattern(binder.map(id2freshId(_)), fd, subPatterns.map(freshIdInPat(_, id2freshId))) - case LiteralPattern(binder, lit) => LiteralPattern(binder.map(id2freshId(_)), lit) + override def apply(ctx: LeonContext, program: Program): Program = { + val newUnits = program.units.map { u => u.copy(defs = u.defs map { + case m: ModuleDef => + ModuleDef( + m.id, + m.definedClasses ++ m.definedFunctions.flatMap(close), + m.isPackageObject + ) + case cd => + cd + })} + Program(newUnits) } - //filter the list of constraints, only keeping those relevant to the set of variables - def filterConstraints(vars: Set[Identifier]): (List[Expr], Set[Identifier]) = { - var allVars = vars - var newVars: Set[Identifier] = Set() - var constraints = pathConstraints - var filteredConstraints: List[Expr] = Nil - do { - allVars ++= newVars - newVars = Set() - constraints = pathConstraints.filterNot(filteredConstraints.contains(_)) - constraints.foreach(expr => { - val vs = variablesOf(expr) - if(vs.intersect(allVars).nonEmpty) { - filteredConstraints ::= expr - newVars ++= vs.diff(allVars) - } - }) - } while(newVars != Set()) - (filteredConstraints, allVars) - } }