// Etienne Kneuss authored
// UnitElimination.scala 6.01 KiB
package leon
import purescala.Common._
import purescala.Definitions._
import purescala.Trees._
import purescala.Extractors._
import purescala.TypeTrees._
object UnitElimination extends TransformationPhase {
// Short name of this phase.
// NOTE(review): presumably implements a member declared by TransformationPhase — confirm.
val name = "Unit Elimination"
// One-line summary of what this phase does.
val description = "Remove all usage of the Unit type and value"
// Maps each defined function of the program to the definition that should be
// used in its place: either a fresh FunDef whose signature omits the Unit-typed
// parameters, or the function itself. Reset and fully populated at the start of
// apply, then read when rewriting function invocations.
private var fun2FreshFun: Map[FunDef, FunDef] = Map()
// Maps an identifier to a fresh replacement identifier.
// NOTE(review): not populated or read in the visible part of this file, and
// unlike fun2FreshFun it is not reset by apply — confirm how it is used below.
private var id2FreshId: Map[Identifier, Identifier] = Map()
/**
 * Runs the phase on `pgm` and returns the rewritten program.
 *
 * Functions whose return type is Unit are dropped from the result. Every other
 * function gets its body rewritten through `removeUnit`; those that also take
 * Unit-typed parameters are replaced by a fresh definition whose signature
 * omits these parameters. Functions that need no new signature are reused and
 * have their body overwritten in place.
 *
 * @param ctx the Leon context (not read by this method)
 * @param pgm the program to transform
 * @return a program with the identifiers and invariants of `pgm`, its defined
 *         classes, and the rewritten functions
 */
def apply(ctx: LeonContext, pgm: Program): Program = {
  fun2FreshFun = Map()
  val functions = pgm.definedFunctions

  // Pass 1: register a replacement for every function before touching any body,
  // so that invocations rewritten in pass 2 can always be resolved.
  for (original <- functions) {
    val needsNewSignature =
      original.returnType != UnitType && original.args.exists(_.tpe == UnitType)
    if (needsNewSignature) {
      val freshId = FreshIdentifier(original.id.name)
      val keptArgs = original.args.filterNot(_.tpe == UnitType)
      val replacement = new FunDef(freshId, original.returnType, keptArgs).setPosInfo(original)
      replacement.fromLoop = original.fromLoop
      replacement.parent = original.parent
      // TODO: maybe removing unit from the conditions as well..
      replacement.precondition = original.precondition
      replacement.postcondition = original.postcondition
      replacement.addAnnotation(original.annotations.toSeq: _*)
      fun2FreshFun += (original -> replacement)
    } else {
      // Identity entry: lets pass 2 look up every function uniformly.
      fun2FreshFun += (original -> original)
    }
  }

  // Pass 2: rewrite the bodies of all functions that do not return Unit and
  // attach each rewritten body to the replacement registered in pass 1.
  val rewrittenFuns =
    for (original <- functions if original.returnType != UnitType) yield {
      val cleanedBody = original.body.map(removeUnit)
      val target = fun2FreshFun(original)
      target.body = cleanedBody
      target
    }

  pgm match {
    case Program(id, ObjectDef(objId, _, invariants)) =>
      Program(id, ObjectDef(objId, pgm.definedClasses ++ rewrittenFuns, invariants))
  }
}
/**
 * Strips Unit components out of tuple types.
 *
 * For a tuple type, each component is simplified recursively and the
 * components that end up being Unit are discarded. If nothing remains the
 * result is Unit, if exactly one component remains the result is that
 * component, otherwise it is a tuple of the remaining components. Any
 * non-tuple type is returned unchanged.
 *
 * @param tpe the type to simplify
 * @return the simplified type
 */
private def simplifyType(tpe: TypeTree): TypeTree = tpe match {
  case TupleType(components) =>
    val kept = components.map(simplifyType).filter {
      case UnitType => false
      case _ => true
    }
    if (kept.isEmpty) UnitType
    else if (kept.tail.isEmpty) kept.head // a single survivor is not wrapped in a tuple
    else TupleType(kept)
  case other => other
}
//remove unit value as soon as possible, so expr should never be equal to a unit
private def removeUnit(expr: Expr): Expr = {
assert(expr.getType != UnitType)
expr match {
case fi@FunctionInvocation(fd, args) => {
val newArgs = args.filterNot(arg => arg.getType == UnitType)
FunctionInvocation(fun2FreshFun(fd), newArgs).setPosInfo(fi)
}
case t@Tuple(args) => {
val TupleType(tpes) = t.getType
val (newTpes, newArgs) = tpes.zip(args).filterNot{ case (UnitType, _) => true case _ => false }.unzip
Tuple(newArgs.map(removeUnit)).setType(TupleType(newTpes))
}