Skip to content
Snippets Groups Projects
Commit caab26d7 authored by Manos Koukoutos's avatar Manos Koukoutos
Browse files

New FunctionClosure

parent e2276093
No related branches found
No related tags found
No related merge requests found
...@@ -3,178 +3,160 @@ ...@@ -3,178 +3,160 @@
package leon package leon
package purescala package purescala
import Common._
import Definitions._ import Definitions._
import Expressions._ import Expressions._
import Extractors._
import ExprOps._ import ExprOps._
import Constructors._ import Constructors._
import TypeOps.instantiateType
import leon.purescala.Common.Identifier
import leon.purescala.Types.TypeParameter
import utils.GraphOps._
class FunctionClosure extends TransformationPhase { class FunctionClosure extends TransformationPhase {
val name = "Function Closure" override val name: String = "Function Closure"
val description = "Closing function with its scoping variables" override val description: String = "Closing function with its scoping variables"
private def close(fd: FunDef): Seq[FunDef] = {
// Directly neste functions with their p.c.
val nestedWithPaths = {
val funDefs = directlyNestedFunDefs(fd.fullBody)
collectWithPC {
case LetDef(fd1, body) if funDefs(fd1) => fd1
}(fd.fullBody)
}.toMap
val nestedFuns = nestedWithPaths.keys.toSeq
// Transitively called funcions from each function
val callGraph: Map[FunDef, Set[FunDef]] = transitiveClosure(
nestedFuns.map { f =>
val calls = functionCallsOf(f.fullBody) collect {
case FunctionInvocation(TypedFunDef(fd, _), _) if nestedFuns.contains(fd) =>
fd
}
f -> calls
}.toMap
)
def freeVars(fd: FunDef, pc: Expr): Set[Identifier] =
variablesOf(fd.fullBody) ++ variablesOf(pc) -- fd.paramIds
// All free variables one should include.
// Contains free vars of the function itself plus of all transitively called functions.
val transFree = nestedFuns.map { fd =>
fd -> (callGraph(fd) + fd).flatMap( (fd2:FunDef) => freeVars(fd2, nestedWithPaths(fd2)) ).toSeq
}.toMap
// Closed functions along with a map (old var -> new var).
val closed = nestedWithPaths.map {
case (inner, pc) => inner -> step(inner, fd, pc, transFree(inner))
}
// TODO: Rewrite this phase // Remove LetDefs
/* I know, that's a lot of mutable variables */ fd.fullBody = preMap({
private var pathConstraints: List[Expr] = Nil case LetDef(fd, bd) =>
private var enclosingLets: List[(Identifier, Expr)] = Nil Some(bd)
private var newFunDefs: Map[FunDef, FunDef] = Map() case _ =>
private var topLevelFuns: Set[FunDef] = Set() None
private var parent: FunDef = null //refers to the current toplevel parent }, applyRec = true)(fd.fullBody)
val dummySubst = FunSubst(
fd,
Map.empty.withDefault(id => id),
Map.empty.withDefault(id => id)
)
// Refresh function calls
(dummySubst +: closed.values.toSeq).foreach { case FunSubst(f, paramsMap, tparamsMap) =>
//println(f)
//paramsMap foreach { case (from, to) =>
// println(from.uniqueName + " -> " + to.uniqueName)
//}
f.fullBody = preMap {
case FunctionInvocation(tfd, args) if closed contains tfd.fd =>
val FunSubst(newFd, newParams, newTParams) = closed(tfd.fd)
// New -> old map for function call
val mapReverse = newParams map { _.swap }
val extraArgs = newFd.paramIds.drop(args.size).map { id =>
paramsMap(mapReverse(id)).toVariable
}
// Similarly for type params
val tReverse = newTParams map { _.swap }
val tOrigExtraOrdered = newFd.tparams.map{_.tp}.drop(tfd.tps.length).map(tReverse)
val tFinalExtra: Seq[TypeParameter] = tOrigExtraOrdered.map( tp =>
tparamsMap(tp)
)
Some(FunctionInvocation(
newFd.typed(tfd.tps ++ tFinalExtra),
args ++ extraArgs
))
case _ => None
}(f.fullBody)
}
def apply(ctx: LeonContext, program: Program): Program = { val funs = closed.values.toSeq.map{ _.newFd }
val newUnits = program.units.map { u => u.copy(defs = u.defs map { fd +: funs.flatMap(close)
case m: ModuleDef =>
pathConstraints = Nil
enclosingLets = Nil
newFunDefs = Map()
topLevelFuns = Set()
parent = null
val funDefs = m.definedFunctions
funDefs.foreach(fd => {
parent = fd
pathConstraints = fd.precondition.toList
fd.body = fd.body.map(b => functionClosure(b, fd.params.map(_.id).toSet, Map(), Map()))
})
ModuleDef(m.id, m.defs ++ topLevelFuns, m.isPackageObject )
case cd => cd
})}
Program(newUnits)
} }
private def functionClosure(expr: Expr, bindedVars: Set[Identifier], id2freshId: Map[Identifier, Identifier], fd2FreshFd: Map[FunDef, (FunDef, Seq[Variable])]): Expr = expr match { private case class FunSubst(
case l @ LetDef(fd, rest) => { newFd: FunDef,
val capturedVars: Set[Identifier] = bindedVars.diff(enclosingLets.map(_._1).toSet) paramsMap: Map[Identifier, Identifier],
val capturedConstraints: Set[Expr] = pathConstraints.toSet tparamsMap: Map[TypeParameter, TypeParameter]
)
val freshIds: Map[Identifier, Identifier] = capturedVars.map(id => (id, id.freshen)).toMap
private def step(inner: FunDef, outer: FunDef, pc: Expr, free: Seq[Identifier]): FunSubst = {
val extraValDefOldIds: Seq[Identifier] = capturedVars.toSeq
val extraValDefFreshIds: Seq[Identifier] = extraValDefOldIds.map(freshIds(_)) val tpFresh = outer.tparams map { _.freshen }
val extraValDefs: Seq[ValDef] = extraValDefFreshIds.map(ValDef(_)) val tparamsMap = outer.tparams.zip(tpFresh map {_.tp}).toMap
val newValDefs: Seq[ValDef] = fd.params ++ extraValDefs
val newBindedVars: Set[Identifier] = bindedVars ++ fd.params.map(_.id) val freshVals = (inner.paramIds ++ free).map{_.freshen}.map(instantiateType(_, tparamsMap))
val newFunId = FreshIdentifier(fd.id.name, alwaysShowUniqueID = true) //since we hoist this at the top level, we need to make it a unique name val freeMap = (inner.paramIds ++ free).zip(freshVals).toMap
val newFunDef = new FunDef(newFunId, fd.tparams, fd.returnType, newValDefs).copiedFrom(fd) val newFd = new FunDef(
topLevelFuns += newFunDef inner.id.freshen,
newFunDef.copyContentFrom(fd) //TODO: this still has some dangerous side effects (?) inner.tparams ++ tpFresh,
instantiateType(inner.returnType, tparamsMap),
def introduceLets(expr: Expr, fd2FreshFd: Map[FunDef, (FunDef, Seq[Variable])]): Expr = { freshVals.map(ValDef(_))
val (newExpr, _) = enclosingLets.foldLeft((expr, Map[Identifier, Identifier]()))((acc, p) => { )
val newId = p._1.freshen newFd.copyContentFrom(inner)
val newMap = acc._2 + (p._1 -> newId) newFd.precondition = Some(and(pc, inner.precOrTrue))
val newBody = functionClosure(acc._1, newBindedVars, freshIds ++ newMap, fd2FreshFd)
(Let(newId, p._2, newBody), newMap) val instBody = instantiateType(
}) newFd.fullBody,
functionClosure(newExpr, newBindedVars, freshIds, fd2FreshFd) tparamsMap,
} freeMap
)
val newPrecondition = simplifyLets(introduceLets(and((capturedConstraints ++ fd.precondition).toSeq :_*), fd2FreshFd))
newFunDef.precondition = if(newPrecondition == BooleanLiteral(true)) None else Some(newPrecondition) newFd.fullBody = preMap {
case FunctionInvocation(tfd, args) if tfd.fd == inner =>
val freshPostcondition = fd.postcondition.map { case post @ Lambda(args, body) => Some(FunctionInvocation(
Lambda(args, introduceLets(body, fd2FreshFd).setPos(body)).setPos(post) newFd.typed(tfd.tps ++ tpFresh.map{ _.tp }),
} args ++ freshVals.drop(args.length).map(Variable)
newFunDef.postcondition = freshPostcondition ))
case _ => None
pathConstraints = fd.precOrTrue :: pathConstraints }(instBody)
val freshBody = fd.body.map(body => introduceLets(body, fd2FreshFd + (fd -> ((newFunDef, extraValDefOldIds.map(_.toVariable))))))
newFunDef.body = freshBody FunSubst(newFd, freeMap, tparamsMap.map{ case (from, to) => from.tp -> to})
pathConstraints = pathConstraints.tail
val freshRest = functionClosure(rest, bindedVars, id2freshId, fd2FreshFd + (fd -> ((newFunDef, extraValDefOldIds.map(_.toVariable)))))
freshRest.copiedFrom(l)
}
case l @ Let(i,e,b) => {
val re = functionClosure(e, bindedVars, id2freshId, fd2FreshFd)
//we need the enclosing lets to always refer to the original ids, because it might be expand later in a highly nested function
enclosingLets ::= (i, replace(id2freshId.map(p => (p._2.toVariable, p._1.toVariable)), re))
//pathConstraints :: Equals(i.toVariable, re)
val rb = functionClosure(b, bindedVars + i, id2freshId, fd2FreshFd)
enclosingLets = enclosingLets.tail
//pathConstraints = pathConstraints.tail
Let(i, re, rb).copiedFrom(l)
}
case i @ IfExpr(cond,thenn,elze) => {
/*
when acumulating path constraints, take the condition without closing it first, so this
might not work well with nested fundef in if then else condition
*/
val rCond = functionClosure(cond, bindedVars, id2freshId, fd2FreshFd)
pathConstraints ::= cond//rCond
val rThen = functionClosure(thenn, bindedVars, id2freshId, fd2FreshFd)
pathConstraints = pathConstraints.tail
pathConstraints ::= Not(cond)//Not(rCond)
val rElze = functionClosure(elze, bindedVars, id2freshId, fd2FreshFd)
pathConstraints = pathConstraints.tail
IfExpr(rCond, rThen, rElze).copiedFrom(i)
}
case fi @ FunctionInvocation(tfd, args) => fd2FreshFd.get(tfd.fd) match {
case None =>
FunctionInvocation(tfd,
args.map(arg => functionClosure(arg, bindedVars, id2freshId, fd2FreshFd))).copiedFrom(fi)
case Some((nfd, extraArgs)) =>
FunctionInvocation(nfd.typed(tfd.tps),
args.map(arg => functionClosure(arg, bindedVars, id2freshId, fd2FreshFd)) ++
extraArgs.map(v => replace(id2freshId.map(p => (p._1.toVariable, p._2.toVariable)), v))).copiedFrom(fi)
}
case m @ MatchExpr(scrut,cses) => {
val scrutRec = functionClosure(scrut, bindedVars, id2freshId, fd2FreshFd)
val csesRec = cses.map{ cse =>
import cse._
val binders = pattern.binders
val cond = conditionForPattern(scrut, pattern)
pathConstraints ::= cond
val rRhs = functionClosure(rhs, bindedVars ++ binders, id2freshId, fd2FreshFd)
val rGuard = optGuard map { functionClosure(_, bindedVars ++ binders, id2freshId, fd2FreshFd) }
pathConstraints = pathConstraints.tail
MatchCase(pattern, rGuard, rRhs)
}
matchExpr(scrutRec, csesRec).copiedFrom(m)
}
case v @ Variable(id) => id2freshId.get(id) match {
case None => v
case Some(nid) => Variable(nid)
}
case n @ Operator(args, recons) => {
val rargs = args.map(a => functionClosure(a, bindedVars, id2freshId, fd2FreshFd))
recons(rargs).copiedFrom(n)
}
case unhandled => scala.sys.error("Non-terminal case should be handled in FunctionClosure: " + unhandled)
} }
def freshIdInPat(pat: Pattern, id2freshId: Map[Identifier, Identifier]): Pattern = pat match { override def apply(ctx: LeonContext, program: Program): Program = {
case InstanceOfPattern(binder, classTypeDef) => InstanceOfPattern(binder.map(id2freshId(_)), classTypeDef) val newUnits = program.units.map { u => u.copy(defs = u.defs map {
case WildcardPattern(binder) => WildcardPattern(binder.map(id2freshId(_))) case m: ModuleDef =>
case CaseClassPattern(binder, caseClassDef, subPatterns) => CaseClassPattern(binder.map(id2freshId(_)), caseClassDef, subPatterns.map(freshIdInPat(_, id2freshId))) ModuleDef(
case TuplePattern(binder, subPatterns) => TuplePattern(binder.map(id2freshId(_)), subPatterns.map(freshIdInPat(_, id2freshId))) m.id,
case UnapplyPattern(binder, fd, subPatterns) => UnapplyPattern(binder.map(id2freshId(_)), fd, subPatterns.map(freshIdInPat(_, id2freshId))) m.definedClasses ++ m.definedFunctions.flatMap(close),
case LiteralPattern(binder, lit) => LiteralPattern(binder.map(id2freshId(_)), lit) m.isPackageObject
)
case cd =>
cd
})}
Program(newUnits)
} }
//filter the list of constraints, only keeping those relevant to the set of variables
def filterConstraints(vars: Set[Identifier]): (List[Expr], Set[Identifier]) = {
var allVars = vars
var newVars: Set[Identifier] = Set()
var constraints = pathConstraints
var filteredConstraints: List[Expr] = Nil
do {
allVars ++= newVars
newVars = Set()
constraints = pathConstraints.filterNot(filteredConstraints.contains(_))
constraints.foreach(expr => {
val vs = variablesOf(expr)
if(vs.intersect(allVars).nonEmpty) {
filteredConstraints ::= expr
newVars ++= vs.diff(allVars)
}
})
} while(newVars != Set())
(filteredConstraints, allVars)
}
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment