Skip to content
Snippets Groups Projects
Commit 6c0b1684 authored by Régis Blanc's avatar Régis Blanc
Browse files

Template for Epsilon elimination pass

parent 1018f32e
No related branches found
No related tags found
No related merge requests found
......@@ -11,7 +11,7 @@ class Analysis(pgm : Program, val reporter: Reporter = Settings.reporter) {
Extensions.loadAll(reporter)
println("Analysis on program:\n" + pgm)
val passManager = new PassManager(Seq(ImperativeCodeElimination, UnitElimination, FunctionClosure, FunctionHoisting, Simplificator))
val passManager = new PassManager(Seq(EpsilonElimination, ImperativeCodeElimination, UnitElimination, FunctionClosure, FunctionHoisting, Simplificator))
val program = passManager.run(pgm)
val analysisExtensions: Seq[Analyser] = loadedAnalysisExtensions
......
package leon
import purescala.Common._
import purescala.Definitions._
import purescala.Trees._
import purescala.TypeTrees._
object EpsilonElimination extends Pass {
val description = "Remove all epsilons from the program"
private var fun2FreshFun: Map[FunDef, FunDef] = Map()
private var id2FreshId: Map[Identifier, Identifier] = Map()
def apply(pgm: Program): Program = {
fun2FreshFun = Map()
val allFuns = pgm.definedFunctions
//first introduce new signatures without Unit parameters
allFuns.foreach(fd => {
if(fd.returnType != UnitType && fd.args.exists(vd => vd.tpe == UnitType)) {
val freshFunDef = new FunDef(FreshIdentifier(fd.id.name), fd.returnType, fd.args.filterNot(vd => vd.tpe == UnitType)).setPosInfo(fd)
freshFunDef.precondition = fd.precondition //TODO: maybe removing unit from the conditions as well..
freshFunDef.postcondition = fd.postcondition//TODO: maybe removing unit from the conditions as well..
fun2FreshFun += (fd -> freshFunDef)
} else {
fun2FreshFun += (fd -> fd) //this will make the next step simpler
}
})
//then apply recursively to the bodies
val newFuns = allFuns.flatMap(fd => if(fd.returnType == UnitType) Seq() else {
val body = fd.getBody
val newFd = fun2FreshFun(fd)
newFd.body = Some(removeUnit(body))
Seq(newFd)
})
val Program(id, ObjectDef(objId, _, invariants)) = pgm
val allClasses = pgm.definedClasses
Program(id, ObjectDef(objId, allClasses ++ newFuns, invariants))
}
private def simplifyType(tpe: TypeTree): TypeTree = tpe match {
case TupleType(tpes) => tpes.map(simplifyType).filterNot{ case UnitType => true case _ => false } match {
case Seq() => UnitType
case Seq(tpe) => tpe
case tpes => TupleType(tpes)
}
case t => t
}
//remove unit value as soon as possible, so expr should never be equal to a unit
private def removeUnit(expr: Expr): Expr = {
assert(expr.getType != UnitType)
expr match {
case fi@FunctionInvocation(fd, args) => {
val newArgs = args.filterNot(arg => arg.getType == UnitType)
FunctionInvocation(fun2FreshFun(fd), newArgs).setPosInfo(fi)
}
case t@Tuple(args) => {
val TupleType(tpes) = t.getType
val (newTpes, newArgs) = tpes.zip(args).filterNot{ case (UnitType, _) => true case _ => false }.unzip
Tuple(newArgs.map(removeUnit)).setType(TupleType(newTpes))
}
case ts@TupleSelect(t, index) => {
val TupleType(tpes) = t.getType
val selectionType = tpes(index-1)
val (_, newIndex) = tpes.zipWithIndex.foldLeft((0,-1)){
case ((nbUnit, newIndex), (tpe, i)) =>
if(i == index-1) (nbUnit, index - nbUnit) else (if(tpe == UnitType) nbUnit + 1 else nbUnit, newIndex)
}
TupleSelect(removeUnit(t), newIndex).setType(selectionType)
}
case Let(id, e, b) => {
if(id.getType == UnitType)
removeUnit(b)
else {
id.getType match {
case TupleType(tpes) if tpes.exists(_ == UnitType) => {
val newTupleType = TupleType(tpes.filterNot(_ == UnitType))
val freshId = FreshIdentifier(id.name).setType(newTupleType)
id2FreshId += (id -> freshId)
val newBody = removeUnit(b)
id2FreshId -= id
Let(freshId, removeUnit(e), newBody)
}
case _ => Let(id, removeUnit(e), removeUnit(b))
}
}
}
case LetDef(fd, b) => {
if(fd.returnType == UnitType)
removeUnit(b)
else {
val (newFd, rest) = if(fd.args.exists(vd => vd.tpe == UnitType)) {
val freshFunDef = new FunDef(FreshIdentifier(fd.id.name), fd.returnType, fd.args.filterNot(vd => vd.tpe == UnitType)).setPosInfo(fd)
freshFunDef.precondition = fd.precondition //TODO: maybe removing unit from the conditions as well..
freshFunDef.postcondition = fd.postcondition//TODO: maybe removing unit from the conditions as well..
fun2FreshFun += (fd -> freshFunDef)
freshFunDef.body = Some(removeUnit(fd.getBody))
val restRec = removeUnit(b)
fun2FreshFun -= fd
(freshFunDef, restRec)
} else {
fun2FreshFun += (fd -> fd)
fd.body = Some(removeUnit(fd.getBody))
val restRec = removeUnit(b)
fun2FreshFun -= fd
(fd, restRec)
}
LetDef(newFd, rest)
}
}
case ite@IfExpr(cond, tExpr, eExpr) => {
val thenRec = removeUnit(tExpr)
val elseRec = removeUnit(eExpr)
IfExpr(removeUnit(cond), thenRec, elseRec).setType(thenRec.getType)
}
case n @ NAryOperator(args, recons) => {
recons(args.map(removeUnit(_))).setType(n.getType)
}
case b @ BinaryOperator(a1, a2, recons) => {
recons(removeUnit(a1), removeUnit(a2)).setType(b.getType)
}
case u @ UnaryOperator(a, recons) => {
recons(removeUnit(a)).setType(u.getType)
}
case v @ Variable(id) => if(id2FreshId.isDefinedAt(id)) Variable(id2FreshId(id)) else v
case (t: Terminal) => t
case m @ MatchExpr(scrut, cses) => {
val scrutRec = removeUnit(scrut)
val csesRec = cses.map{
case SimpleCase(pat, rhs) => SimpleCase(pat, removeUnit(rhs))
case GuardedCase(pat, guard, rhs) => GuardedCase(pat, removeUnit(guard), removeUnit(rhs))
}
val tpe = csesRec.head.rhs.getType
MatchExpr(scrutRec, csesRec).setType(tpe)
}
case _ => sys.error("not supported: " + expr)
}
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment