-
Régis Blanc authored
EpsilonElimination.scala 5.59 KiB
package leon
import purescala.Common._
import purescala.Definitions._
import purescala.Trees._
import purescala.TypeTrees._
/** Compiler pass registered as `EpsilonElimination`.
 *
 *  NOTE(review): despite the object name and `description`, the transformation below contains no
 *  Epsilon-specific handling. What it actually does is strip Unit-typed values from the program:
 *  Unit parameters, Unit tuple components, Unit let-bindings, and functions/nested functions that
 *  return Unit. Confirm this is intended.
 */
object EpsilonElimination extends Pass {
  val description = "Remove all epsilons from the program"

  /** Maps each function to the definition its calls must target after the pass: either itself, or
   *  a fresh copy without Unit parameters. Nested (LetDef) functions are only present in the map
   *  while their scope is being rewritten. */
  private var fun2FreshFun: Map[FunDef, FunDef] = Map()

  /** Maps a let-bound identifier whose tuple type contained Unit to the fresh identifier bound to
   *  the Unit-free tuple. An entry is only present while the corresponding let body is rewritten. */
  private var id2FreshId: Map[Identifier, Identifier] = Map()

  /** Runs the pass over `pgm`.
   *
   *  @return a program with the same classes. Functions returning Unit are dropped, and functions
   *          with Unit parameters are replaced by fresh Unit-free copies.
   */
  def apply(pgm: Program): Program = {
    // Reset both maps so that state left over from a previous (possibly aborted) run cannot leak in.
    fun2FreshFun = Map()
    id2FreshId = Map()
    val allFuns = pgm.definedFunctions

    // First pass: decide every function's target definition before rewriting any body, so that
    // calls are redirected correctly regardless of definition order.
    allFuns.foreach { fd =>
      val target = if (fd.returnType != UnitType && hasUnitArg(fd)) withoutUnitArgs(fd) else fd
      fun2FreshFun += (fd -> target)
    }

    // Second pass: rewrite the bodies. Functions returning Unit are dropped from the program.
    val newFuns = allFuns.filterNot(_.returnType == UnitType).map { fd =>
      val newFd = fun2FreshFun(fd)
      // The original body is read before the assignment on purpose: newFd may be fd itself.
      newFd.body = fd.body.map(removeUnit)
      newFd
    }

    val Program(id, ObjectDef(objId, _, invariants)) = pgm
    val allClasses = pgm.definedClasses
    Program(id, ObjectDef(objId, allClasses ++ newFuns, invariants))
  }

  /** True when `fd` declares at least one parameter of type Unit. */
  private def hasUnitArg(fd: FunDef): Boolean = fd.args.exists(_.tpe == UnitType)

  /** Builds a fresh definition with the same name, return type, position and pre/postconditions as
   *  `fd`, minus its Unit-typed parameters. It does not set the body; callers assign it. */
  private def withoutUnitArgs(fd: FunDef): FunDef = {
    val freshFunDef = new FunDef(FreshIdentifier(fd.id.name), fd.returnType, fd.args.filterNot(_.tpe == UnitType)).setPosInfo(fd)
    freshFunDef.precondition = fd.precondition //TODO: maybe removing unit from the conditions as well..
    freshFunDef.postcondition = fd.postcondition //TODO: maybe removing unit from the conditions as well..
    freshFunDef
  }

  /** Removes Unit components from (possibly nested) tuple types. A tuple left empty becomes Unit,
   *  and a tuple with a single remaining component becomes that component.
   *  NOTE(review): not referenced anywhere in this object. */
  private def simplifyType(tpe: TypeTree): TypeTree = tpe match {
    case TupleType(tpes) => tpes.map(simplifyType).filterNot { case UnitType => true case _ => false } match {
      case Seq() => UnitType
      case Seq(tpe) => tpe
      case tpes => TupleType(tpes)
    }
    case t => t
  }

  /** Rewrites `expr` so that Unit values disappear:
   *  - calls are redirected through `fun2FreshFun` and lose their Unit arguments;
   *  - tuple literals lose their Unit components, and tuple selections are re-indexed to match;
   *  - Unit let-bindings and Unit-returning nested functions are dropped.
   *
   *  Invariant (asserted): `expr` itself is never Unit-typed. Each enclosing construct drops Unit
   *  sub-expressions before recursing into them.
   *
   *  @throws RuntimeException (via `sys.error`) for expression forms not handled below.
   */
  private def removeUnit(expr: Expr): Expr = {
    assert(expr.getType != UnitType)
    expr match {
      case fi @ FunctionInvocation(fd, args) =>
        // Drop Unit arguments to match the callee's Unit-free signature. The remaining arguments
        // are rewritten as well: they may contain calls, tuple literals or renamed let-bound
        // variables that need the same treatment.
        val newArgs = args.filterNot(_.getType == UnitType).map(removeUnit)
        FunctionInvocation(fun2FreshFun(fd), newArgs).setPosInfo(fi)

      case t @ Tuple(args) =>
        val TupleType(tpes) = t.getType
        val (newTpes, newArgs) = tpes.zip(args).filterNot { case (UnitType, _) => true case _ => false }.unzip
        Tuple(newArgs.map(removeUnit)).setType(TupleType(newTpes))

      case TupleSelect(t, index) =>
        val TupleType(tpes) = t.getType
        val selectionType = tpes(index - 1)
        // Indices are 1-based. Every Unit component before the selected one has been removed from
        // the rewritten tuple, so shift the index left by that many positions.
        val newIndex = index - tpes.take(index - 1).count(_ == UnitType)
        TupleSelect(removeUnit(t), newIndex).setType(selectionType)

      case Let(id, e, b) =>
        if (id.getType == UnitType)
          removeUnit(b) // the Unit-valued binding is dropped entirely
        else id.getType match {
          case TupleType(tpes) if tpes.exists(_ == UnitType) =>
            // Rebind under a fresh identifier with the Unit-free tuple type. Occurrences of `id` in
            // the body are redirected to it through id2FreshId while the body is rewritten.
            val freshId = FreshIdentifier(id.name).setType(TupleType(tpes.filterNot(_ == UnitType)))
            id2FreshId += (id -> freshId)
            val newBody = removeUnit(b)
            id2FreshId -= id
            Let(freshId, removeUnit(e), newBody)
          case _ =>
            Let(id, removeUnit(e), removeUnit(b))
        }

      case LetDef(fd, b) =>
        if (fd.returnType == UnitType)
          removeUnit(b) // the Unit-returning nested function is dropped entirely
        else {
          val newFd = if (hasUnitArg(fd)) withoutUnitArgs(fd) else fd
          // Register the mapping before rewriting the body so recursive calls are redirected too;
          // unregister it once the LetDef's scope has been rewritten.
          fun2FreshFun += (fd -> newFd)
          newFd.body = fd.body.map(removeUnit)
          val rest = removeUnit(b)
          fun2FreshFun -= fd
          LetDef(newFd, rest)
        }

      case IfExpr(cond, tExpr, eExpr) =>
        val thenRec = removeUnit(tExpr)
        val elseRec = removeUnit(eExpr)
        IfExpr(removeUnit(cond), thenRec, elseRec).setType(thenRec.getType)

      case n @ NAryOperator(args, recons) =>
        recons(args.map(removeUnit)).setType(n.getType)

      case b @ BinaryOperator(a1, a2, recons) =>
        recons(removeUnit(a1), removeUnit(a2)).setType(b.getType)

      case u @ UnaryOperator(a, recons) =>
        recons(removeUnit(a)).setType(u.getType)

      case v @ Variable(id) =>
        if (id2FreshId.isDefinedAt(id)) Variable(id2FreshId(id)) else v

      case (t: Terminal) => t

      case MatchExpr(scrut, cses) =>
        val scrutRec = removeUnit(scrut)
        val csesRec = cses.map {
          case SimpleCase(pat, rhs) => SimpleCase(pat, removeUnit(rhs))
          case GuardedCase(pat, guard, rhs) => GuardedCase(pat, removeUnit(guard), removeUnit(rhs))
        }
        // The match takes the type of its first case's rewritten right-hand side.
        MatchExpr(scrutRec, csesRec).setType(csesRec.head.rhs.getType)

      case _ => sys.error("not supported: " + expr)
    }
  }
}