// FunctionClosure.scala -- authored by Régis Blanc
package leon
import purescala.Common._
import purescala.Definitions._
import purescala.Trees._
import purescala.TypeTrees._
object FunctionClosure extends Pass {
// Human-readable description of this pass.
val description = "Closing function with its scoping variables"
// Stack (head = most recently pushed) of the conditions assumed to hold at the current point of
// the traversal. functionClosure pushes/pops entries around if-branches, match cases and nested
// function bodies; apply seeds it with the precondition of each function it processes.
private var pathConstraints: List[Expr] = Nil
// Stack (head = innermost) of the `Let` bindings enclosing the expression currently being
// traversed, stored as (bound identifier, bound value) pairs.
private var enclosingLets: List[(Identifier, Expr)] = Nil
// Cleared at the start of apply; no other read or write of this field is visible in this file.
private var newFunDefs: Map[FunDef, FunDef] = Map()
// Function definitions created when lifting nested functions; apply appends them to the
// definitions of the program it returns.
private var topLevelFuns: Set[FunDef] = Set()
/**
 * Runs the closure pass on `program`.
 *
 * The body of every function returned by `program.definedFunctions` is rewritten in place with
 * `functionClosure`, starting from the function's own arguments as bound variables and its
 * precondition (if any) as the only path constraint. Only bodies are traversed here; pre- and
 * postconditions of those functions are left as they are.
 *
 * @param program the program to transform; its function definitions are mutated
 * @return a program with the same identifiers and invariants whose definitions are the original
 *         ones followed by the functions lifted during this run
 */
def apply(program: Program): Program = {
  // This pass is a singleton object, so its mutable fields outlive a call. Reset all
  // accumulators: without clearing `topLevelFuns`, functions lifted while processing an earlier
  // program would be appended to this one as well.
  newFunDefs = Map()
  topLevelFuns = Set()
  enclosingLets = Nil
  val funDefs = program.definedFunctions
  funDefs.foreach(fd => {
    // The precondition of the enclosing function holds everywhere inside its body.
    pathConstraints = fd.precondition.toList
    fd.body = fd.body.map(b => functionClosure(b, fd.args.map(_.id).toSet, Map(), Map()))
  })
  val Program(id, ObjectDef(objId, defs, invariants)) = program
  Program(id, ObjectDef(objId, defs ++ topLevelFuns, invariants))
}
/**
 * Recursively rewrites `expr`, eliminating every nested function definition (`LetDef`).
 *
 * For each `LetDef(fd, rest)` a new `FunDef` is created and added to `topLevelFuns`. It takes the
 * arguments of `fd` plus one extra argument per captured variable, and the `LetDef` node itself is
 * replaced by the rewritten `rest`. Invocations of a function present in `fd2FreshFd` are
 * redirected to its replacement with the recorded extra arguments appended.
 *
 * The traversal uses `pathConstraints` and `enclosingLets` as stacks and mutates them while
 * descending, so the order of the statements inside each case matters.
 *
 * @param expr       expression to rewrite
 * @param bindedVars identifiers bound at this point (function arguments, let-bound identifiers,
 *                   pattern binders)
 * @param id2freshId renaming applied to every `Variable` met; identifiers absent from the map are
 *                   kept unchanged
 * @param fd2FreshFd for each nested function already lifted, its replacement definition and the
 *                   extra arguments to append at call sites
 * @return the rewritten expression; aborts through `scala.sys.error` on a non-terminal node that
 *         none of the cases below handles
 */
private def functionClosure(expr: Expr, bindedVars: Set[Identifier], id2freshId: Map[Identifier, Identifier], fd2FreshFd: Map[FunDef, (FunDef, Seq[Variable])]): Expr = expr match {
// Nested function definition: lift `fd` into a new definition and continue with `rest`.
case l @ LetDef(fd, rest) => {
// Bound variables become extra parameters, except identifiers bound by an enclosing let:
// those are re-bound inside the lifted function by introduceLets below.
val capturedVars: Set[Identifier] = bindedVars.diff(enclosingLets.map(_._1).toSet)
// Every constraint currently on the stack is captured.
val capturedConstraints: Set[Expr] = pathConstraints.toSet
// One fresh identifier (same name, same type) per captured variable.
val freshIds: Map[Identifier, Identifier] = capturedVars.map(id => (id, FreshIdentifier(id.name).setType(id.getType))).toMap
// NOTE(review): freshVars is not referenced again in this method.
val freshVars: Map[Expr, Expr] = freshIds.map(p => (p._1.toVariable, p._2.toVariable))
// Fix one ordering of the captured variables so that declarations and call-site arguments line up.
val extraVarDeclOldIds: Seq[Identifier] = capturedVars.toSeq
val extraVarDeclFreshIds: Seq[Identifier] = extraVarDeclOldIds.map(freshIds(_))
val extraVarDecls: Seq[VarDecl] = extraVarDeclFreshIds.map(id => VarDecl(id, id.getType))
// Parameters of the lifted function: the original ones followed by the captured variables.
val newVarDecls: Seq[VarDecl] = fd.args ++ extraVarDecls
val newBindedVars: Set[Identifier] = bindedVars ++ fd.args.map(_.id)
// Create the replacement definition, register it, and copy the metadata of `fd` onto it.
val newFunId = FreshIdentifier(fd.id.name)
val newFunDef = new FunDef(newFunId, fd.returnType, newVarDecls).setPosInfo(fd)
topLevelFuns += newFunDef
newFunDef.fromLoop = fd.fromLoop
newFunDef.parent = fd.parent
newFunDef.addAnnotation(fd.annotations.toSeq:_*)
// Wraps `expr` in one `Let` per entry of enclosingLets, each binding a fresh copy of the
// original identifier. The fold visits the innermost let first (enclosingLets is built by
// prepending), so the outermost let ends up outermost. Each step first rewrites the expression
// accumulated so far, renaming captured variables and the lets introduced up to that step, and
// the final result is rewritten once more with the captured-variable renaming.
def introduceLets(expr: Expr, fd2FreshFd: Map[FunDef, (FunDef, Seq[Variable])]): Expr = {
val (newExpr, _) = enclosingLets.foldLeft((expr, Map[Identifier, Identifier]()))((acc, p) => {
val newId = FreshIdentifier(p._1.name).setType(p._1.getType)
val newMap = acc._2 + (p._1 -> newId)
val newBody = functionClosure(acc._1, newBindedVars, freshIds ++ newMap, fd2FreshFd)
(Let(newId, p._2, newBody), newMap)
})
functionClosure(newExpr, newBindedVars, freshIds, fd2FreshFd)
}
// Precondition of the lifted function: the captured path constraints together with the
// precondition of `fd`; it is dropped when it is exactly the literal `true` after simplifyLets.
val newPrecondition = simplifyLets(introduceLets(And((capturedConstraints ++ fd.precondition).toSeq), fd2FreshFd))
newFunDef.precondition = if(newPrecondition == BooleanLiteral(true)) None else Some(newPrecondition)
val freshPostcondition = fd.postcondition.map(post => introduceLets(post, fd2FreshFd))
newFunDef.postcondition = freshPostcondition
// While rewriting the body, the precondition of `fd` (or `true`) is on the constraint stack.
pathConstraints = fd.precondition.getOrElse(BooleanLiteral(true)) :: pathConstraints
// Inside the body, calls to `fd` (recursive calls) go to newFunDef and pass the fresh parameters along.
val freshBody = fd.body.map(body => introduceLets(body, fd2FreshFd + (fd -> ((newFunDef, extraVarDeclFreshIds.map(_.toVariable))))))
newFunDef.body = freshBody
pathConstraints = pathConstraints.tail
// In `rest`, calls to `fd` pass the captured variables under the names valid in the current
// scope: the image under id2freshId when there is one, the original identifier otherwise.
val freshRest = functionClosure(rest, bindedVars, id2freshId, fd2FreshFd + (fd ->
((newFunDef, extraVarDeclOldIds.map(id => id2freshId.get(id).getOrElse(id).toVariable)))))
freshRest.setType(l.getType)
}
// Let binding: rewrite the value, then rewrite the body with the binding pushed on enclosingLets.
case l @ Let(i,e,b) => {
val re = functionClosure(e, bindedVars, id2freshId, fd2FreshFd)
// The enclosing lets must always refer to the original ids (the renaming is undone here),
// because they may be expanded later inside a deeply nested function.
enclosingLets ::= (i, replace(id2freshId.map(p => (p._2.toVariable, p._1.toVariable)), re))
//pathConstraints :: Equals(i.toVariable, re)
val rb = functionClosure(b, bindedVars + i, id2freshId, fd2FreshFd)
enclosingLets = enclosingLets.tail
//pathConstraints = pathConstraints.tail
Let(i, re, rb).setType(l.getType)
}
// Conditional: each branch is rewritten with the (negated) condition pushed as a path constraint.
case i @ IfExpr(cond,then,elze) => {
/*
When accumulating path constraints, the condition is taken as-is, without closing it first
(`cond` is pushed, not `rCond`), so this might not work well with a nested function definition
inside the condition of an if-then-else.
*/
val rCond = functionClosure(cond, bindedVars, id2freshId, fd2FreshFd)
pathConstraints ::= cond//rCond
val rThen = functionClosure(then, bindedVars, id2freshId, fd2FreshFd)
pathConstraints = pathConstraints.tail
pathConstraints ::= Not(cond)//Not(rCond)
val rElze = functionClosure(elze, bindedVars, id2freshId, fd2FreshFd)
pathConstraints = pathConstraints.tail
IfExpr(rCond, rThen, rElze).setType(i.getType)
}
// Call: arguments are rewritten; a call to a lifted function is redirected to its replacement
// with the extra (captured) arguments appended.
case fi @ FunctionInvocation(fd, args) => fd2FreshFd.get(fd) match {
case None => FunctionInvocation(fd, args.map(arg => functionClosure(arg, bindedVars, id2freshId, fd2FreshFd))).setPosInfo(fi)
case Some((nfd, extraArgs)) =>
FunctionInvocation(nfd, args.map(arg => functionClosure(arg, bindedVars, id2freshId, fd2FreshFd)) ++ extraArgs).setPosInfo(fi)
}
// Generic operators: rewrite the operands, rebuild the node, keep the original type.
case n @ NAryOperator(args, recons) => {
val rargs = args.map(a => functionClosure(a, bindedVars, id2freshId, fd2FreshFd))
recons(rargs).setType(n.getType)
}
case b @ BinaryOperator(t1,t2,recons) => {
val r1 = functionClosure(t1, bindedVars, id2freshId, fd2FreshFd)
val r2 = functionClosure(t2, bindedVars, id2freshId, fd2FreshFd)
recons(r1,r2).setType(b.getType)
}
case u @ UnaryOperator(t,recons) => {
val r = functionClosure(t, bindedVars, id2freshId, fd2FreshFd)
recons(r).setType(u.getType)
}
// Pattern matching: patterns are kept unchanged (their binders are not renamed). Each case is
// rewritten with its pattern condition pushed as a path constraint and the pattern binders added
// to the bound variables. The condition is built from the original scrutinee `scrut`, not from
// `scrutRec`; for guarded cases the guard itself is not pushed as a constraint.
case m @ MatchExpr(scrut,cses) => { //still needs to handle the new ids introduced by the patterns
val scrutRec = functionClosure(scrut, bindedVars, id2freshId, fd2FreshFd)
val csesRec = cses.map{
case SimpleCase(pat, rhs) => {
val binders = pat.binders
val cond = conditionForPattern(scrut, pat)
pathConstraints ::= cond
val rRhs = functionClosure(rhs, bindedVars ++ binders, id2freshId, fd2FreshFd)
pathConstraints = pathConstraints.tail
SimpleCase(pat, rRhs)
}
case GuardedCase(pat, guard, rhs) => {
val binders = pat.binders
val cond = conditionForPattern(scrut, pat)
pathConstraints ::= cond
val rRhs = functionClosure(rhs, bindedVars ++ binders, id2freshId, fd2FreshFd)
val rGuard = functionClosure(guard, bindedVars ++ binders, id2freshId, fd2FreshFd)
pathConstraints = pathConstraints.tail
GuardedCase(pat, rGuard, rRhs)
}
}
// The type of the rebuilt match is taken from the right-hand side of its first case.
val tpe = csesRec.head.rhs.getType
MatchExpr(scrutRec, csesRec).setType(tpe).setPosInfo(m)
}
// Variable: apply the renaming if the identifier has an entry, otherwise keep the node.
case v @ Variable(id) => id2freshId.get(id) match {
case None => v
case Some(nid) => Variable(nid)
}
// Any other terminal is returned unchanged.
case t if t.isInstanceOf[Terminal] => t
case unhandled => scala.sys.error("Non-terminal case should be handled in FunctionClosure: " + unhandled)
}
/**
 * Rebuilds `pat` with every binder replaced by its image under `id2freshId`, descending into the
 * sub-patterns of case-class and tuple patterns.
 *
 * Binders are looked up with `Map.apply`, so a binder that is not a key of `id2freshId` makes the
 * lookup throw instead of being left unchanged.
 *
 * @param pat        pattern to rebuild
 * @param id2freshId renaming to apply to the binders
 * @return a pattern of the same shape carrying the renamed binders
 */
def freshIdInPat(pat: Pattern, id2freshId: Map[Identifier, Identifier]): Pattern = {
  // Same renaming applied to a nested pattern.
  def renameIn(sub: Pattern): Pattern = freshIdInPat(sub, id2freshId)

  pat match {
    case InstanceOfPattern(b, classTypeDef) =>
      InstanceOfPattern(b.map(id2freshId), classTypeDef)
    case WildcardPattern(b) =>
      WildcardPattern(b.map(id2freshId))
    case CaseClassPattern(b, caseClassDef, subs) =>
      CaseClassPattern(b.map(id2freshId), caseClassDef, subs.map(renameIn))
    case TuplePattern(b, subs) =>
      TuplePattern(b.map(id2freshId), subs.map(renameIn))
  }
}
/**
 * Filters the current `pathConstraints`, keeping only the constraints relevant to `vars`.
 *
 * A constraint is relevant when it mentions at least one relevant variable; the other variables
 * it mentions then become relevant too, and the selection is repeated until a round brings in no
 * new variable.
 *
 * @param vars the variables of interest
 * @return the selected constraints (most recently selected first) and the set of all relevant
 *         variables, including `vars`
 */
def filterConstraints(vars: Set[Identifier]): (List[Expr], Set[Identifier]) = {
  @scala.annotation.tailrec
  def saturate(known: Set[Identifier], kept: List[Expr]): (List[Expr], Set[Identifier]) = {
    // Only constraints that were not selected in an earlier round are examined again.
    val candidates = pathConstraints.filterNot(kept.contains(_))
    val (keptAfter, discovered) = candidates.foldLeft((kept, Set[Identifier]())) {
      case ((accKept, accNew), constraint) =>
        val mentioned = variablesOf(constraint)
        if (mentioned.intersect(known).isEmpty) (accKept, accNew)
        else (constraint :: accKept, accNew ++ mentioned.diff(known))
    }
    if (discovered.isEmpty) (keptAfter, known)
    else saturate(known ++ discovered, keptAfter)
  }

  saturate(vars, Nil)
}
}