Skip to content
Snippets Groups Projects
Commit f9b714dd authored by Régis Blanc's avatar Régis Blanc
Browse files

Some sort of elimination of all unit values/types

parent 2ca9b5ed
Branches
Tags
No related merge requests found
......@@ -11,7 +11,7 @@ class Analysis(pgm : Program, val reporter: Reporter = Settings.reporter) {
Extensions.loadAll(reporter)
println("Analysis on program:\n" + pgm)
val passManager = new PassManager(Seq(ImperativeCodeElimination, FunctionClosure, FunctionHoisting, Simplificator))
val passManager = new PassManager(Seq(ImperativeCodeElimination, UnitElimination, FunctionClosure, FunctionHoisting, Simplificator))
val program = passManager.run(pgm)
val analysisExtensions: Seq[Analyser] = loadedAnalysisExtensions
......
package leon
import purescala.Common._
import purescala.Definitions._
import purescala.Trees._
import purescala.TypeTrees._
object UnitElimination extends Pass {
val description = "Remove all usage of the Unit type and value"
private var fun2FreshFun: Map[FunDef, FunDef] = Map()
private var id2FreshId: Map[Identifier, Identifier] = Map()
def apply(pgm: Program): Program = {
fun2FreshFun = Map()
val allFuns = pgm.definedFunctions
//first introduce new signatures without Unit parameters
allFuns.foreach(fd => {
if(fd.returnType != UnitType && fd.args.exists(vd => vd.tpe == UnitType)) {
val freshFunDef = new FunDef(FreshIdentifier(fd.id.name), fd.returnType, fd.args.filterNot(vd => vd.tpe == UnitType))
fun2FreshFun += (fd -> freshFunDef)
} else {
fun2FreshFun += (fd -> fd) //this will make the next step simpler
}
})
//then apply recursively to the bodies
val newFuns = allFuns.flatMap(fd => if(fd.returnType == UnitType) Seq() else {
val body = fd.getBody
val newFd = fun2FreshFun(fd)
newFd.body = Some(removeUnit(body))
Seq(newFd)
})
val Program(id, ObjectDef(objId, _, invariants)) = pgm
val allClasses = pgm.definedClasses
Program(id, ObjectDef(objId, allClasses ++ newFuns, invariants))
}
//remove unit value as soon as possible, so expr should never be equal to a unit
private def removeUnit(expr: Expr): Expr = {
assert(expr.getType != UnitType)
expr match {
case fi@FunctionInvocation(fd, args) => {
val newArgs = args.filterNot(arg => arg.getType == UnitType)
FunctionInvocation(fun2FreshFun(fd), newArgs).setPosInfo(fi)
}
case t@Tuple(args) => {
val TupleType(tpes) = t.getType
val (newTpes, newArgs) = tpes.zip(args).filterNot{ case (UnitType, _) => true case _ => false }.unzip
Tuple(newArgs.map(removeUnit)).setType(TupleType(newTpes))
}
case ts@TupleSelect(t, index) => {
println("Select Tuple: " + ts)
val TupleType(tpes) = t.getType
val selectionType = tpes(index-1)
val (_, newIndex) = tpes.zipWithIndex.foldLeft((0,-1)){
case ((nbUnit, newIndex), (tpe, i)) =>
if(i == index-1) (nbUnit, index - nbUnit) else (if(tpe == UnitType) nbUnit + 1 else nbUnit, newIndex)
}
println("new index = " + newIndex)
TupleSelect(removeUnit(t), newIndex).setType(selectionType)
}
case Let(id, e, b) => {
if(id.getType == UnitType)
removeUnit(b)
else {
id.getType match {
case TupleType(tpes) if tpes.exists(_ == UnitType) => {
val newTupleType = TupleType(tpes.filterNot(_ == UnitType))
val freshId = FreshIdentifier(id.name).setType(newTupleType)
id2FreshId += (id -> freshId)
val newBody = removeUnit(b)
id2FreshId -= id
Let(freshId, removeUnit(e), newBody)
}
case _ => Let(id, removeUnit(e), removeUnit(b))
}
}
}
case LetDef(fd, b) => {
if(fd.returnType == UnitType)
removeUnit(b)
else {
val (newFd, rest) = if(fd.args.exists(vd => vd.tpe == UnitType)) {
val freshFunDef = new FunDef(FreshIdentifier(fd.id.name), fd.returnType, fd.args.filterNot(vd => vd.tpe == UnitType))
fun2FreshFun += (fd -> freshFunDef)
freshFunDef.body = Some(removeUnit(fd.getBody))
val restRec = removeUnit(b)
fun2FreshFun -= fd
(freshFunDef, restRec)
} else {
fd.body = Some(removeUnit(fd.getBody))
val restRec = removeUnit(b)
(fd, restRec)
}
LetDef(newFd, rest)
}
}
case ite@IfExpr(cond, tExpr, eExpr) => {
val thenRec = removeUnit(tExpr)
val elseRec = removeUnit(eExpr)
IfExpr(removeUnit(cond), thenRec, elseRec).setType(thenRec.getType)
}
case n @ NAryOperator(args, recons) => {
recons(args.map(removeUnit(_))).setType(n.getType)
}
case b @ BinaryOperator(a1, a2, recons) => {
recons(removeUnit(a1), removeUnit(a2)).setType(b.getType)
}
case u @ UnaryOperator(a, recons) => {
recons(removeUnit(a)).setType(u.getType)
}
case v @ Variable(id) => if(id2FreshId.isDefinedAt(id)) Variable(id2FreshId(id)) else v
case (t: Terminal) => t
case m @ MatchExpr(scrut, cses) => sys.error("not supported: " + expr)
case _ => sys.error("not supported: " + expr)
}
}
}
......@@ -40,12 +40,12 @@ object Trees {
setType(et)
}
case class LetTuple(binders: Seq[Identifier], value: Expr, body: Expr) extends Expr {
binders.foreach(_.markAsLetBinder)
val et = body.getType
if(et != Untyped)
setType(et)
}
//case class LetTuple(binders: Seq[Identifier], value: Expr, body: Expr) extends Expr {
// binders.foreach(_.markAsLetBinder)
// val et = body.getType
// if(et != Untyped)
// setType(et)
//}
case class LetDef(value: FunDef, body: Expr) extends Expr {
val et = body.getType
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment