diff --git a/src/main/scala/leon/Analysis.scala b/src/main/scala/leon/Analysis.scala index d7efcaecf8557e31fab4b61ab04fe17384236718..5de2c671af169d9bd9671a7f4c340cbd6548df3e 100644 --- a/src/main/scala/leon/Analysis.scala +++ b/src/main/scala/leon/Analysis.scala @@ -11,7 +11,7 @@ class Analysis(pgm : Program, val reporter: Reporter = Settings.reporter) { Extensions.loadAll(reporter) println("Analysis on program:\n" + pgm) - val passManager = new PassManager(Seq(ImperativeCodeElimination, UnitElimination, FunctionClosure, FunctionHoisting, Simplificator)) + val passManager = new PassManager(Seq(EpsilonElimination, ImperativeCodeElimination, UnitElimination, FunctionClosure, FunctionHoisting, Simplificator)) val program = passManager.run(pgm) val analysisExtensions: Seq[Analyser] = loadedAnalysisExtensions diff --git a/src/main/scala/leon/EpsilonElimination.scala b/src/main/scala/leon/EpsilonElimination.scala new file mode 100644 index 0000000000000000000000000000000000000000..b8071a721080a67427f3faae255ff37fc4c76585 --- /dev/null +++ b/src/main/scala/leon/EpsilonElimination.scala @@ -0,0 +1,144 @@ +package leon + +import purescala.Common._ +import purescala.Definitions._ +import purescala.Trees._ +import purescala.TypeTrees._ + +object EpsilonElimination extends Pass { + + val description = "Remove all epsilons from the program" + + private var fun2FreshFun: Map[FunDef, FunDef] = Map() + private var id2FreshId: Map[Identifier, Identifier] = Map() + + def apply(pgm: Program): Program = { + fun2FreshFun = Map() + val allFuns = pgm.definedFunctions + + //first introduce new signatures without Unit parameters + allFuns.foreach(fd => { + if(fd.returnType != UnitType && fd.args.exists(vd => vd.tpe == UnitType)) { + val freshFunDef = new FunDef(FreshIdentifier(fd.id.name), fd.returnType, fd.args.filterNot(vd => vd.tpe == UnitType)).setPosInfo(fd) + freshFunDef.precondition = fd.precondition //TODO: maybe removing unit from the conditions as well.. + freshFunDef.postcondition = fd.postcondition//TODO: maybe removing unit from the conditions as well.. + fun2FreshFun += (fd -> freshFunDef) + } else { + fun2FreshFun += (fd -> fd) //this will make the next step simpler + } + }) + + //then apply recursively to the bodies + val newFuns = allFuns.flatMap(fd => if(fd.returnType == UnitType) Seq() else { + val body = fd.getBody + val newFd = fun2FreshFun(fd) + newFd.body = Some(removeUnit(body)) + Seq(newFd) + }) + + val Program(id, ObjectDef(objId, _, invariants)) = pgm + val allClasses = pgm.definedClasses + Program(id, ObjectDef(objId, allClasses ++ newFuns, invariants)) + } + + private def simplifyType(tpe: TypeTree): TypeTree = tpe match { + case TupleType(tpes) => tpes.map(simplifyType).filterNot{ case UnitType => true case _ => false } match { + case Seq() => UnitType + case Seq(tpe) => tpe + case tpes => TupleType(tpes) + } + case t => t + } + + //remove unit value as soon as possible, so expr should never be equal to a unit + private def removeUnit(expr: Expr): Expr = { + assert(expr.getType != UnitType) + expr match { + case fi@FunctionInvocation(fd, args) => { + val newArgs = args.filterNot(arg => arg.getType == UnitType) + FunctionInvocation(fun2FreshFun(fd), newArgs).setPosInfo(fi) + } + case t@Tuple(args) => { + val TupleType(tpes) = t.getType + val (newTpes, newArgs) = tpes.zip(args).filterNot{ case (UnitType, _) => true case _ => false }.unzip + Tuple(newArgs.map(removeUnit)).setType(TupleType(newTpes)) + } + case ts@TupleSelect(t, index) => { + val TupleType(tpes) = t.getType + val selectionType = tpes(index-1) + val (_, newIndex) = tpes.zipWithIndex.foldLeft((0,-1)){ + case ((nbUnit, newIndex), (tpe, i)) => + if(i == index-1) (nbUnit, index - nbUnit) else (if(tpe == UnitType) nbUnit + 1 else nbUnit, newIndex) + } + TupleSelect(removeUnit(t), newIndex).setType(selectionType) + } + case Let(id, e, b) => { + if(id.getType == UnitType) + removeUnit(b) + else { + id.getType match { + case TupleType(tpes) if tpes.exists(_ == UnitType) => { + val newTupleType = TupleType(tpes.filterNot(_ == UnitType)) + val freshId = FreshIdentifier(id.name).setType(newTupleType) + id2FreshId += (id -> freshId) + val newBody = removeUnit(b) + id2FreshId -= id + Let(freshId, removeUnit(e), newBody) + } + case _ => Let(id, removeUnit(e), removeUnit(b)) + } + } + } + case LetDef(fd, b) => { + if(fd.returnType == UnitType) + removeUnit(b) + else { + val (newFd, rest) = if(fd.args.exists(vd => vd.tpe == UnitType)) { + val freshFunDef = new FunDef(FreshIdentifier(fd.id.name), fd.returnType, fd.args.filterNot(vd => vd.tpe == UnitType)).setPosInfo(fd) + freshFunDef.precondition = fd.precondition //TODO: maybe removing unit from the conditions as well.. + freshFunDef.postcondition = fd.postcondition//TODO: maybe removing unit from the conditions as well.. + fun2FreshFun += (fd -> freshFunDef) + freshFunDef.body = Some(removeUnit(fd.getBody)) + val restRec = removeUnit(b) + fun2FreshFun -= fd + (freshFunDef, restRec) + } else { + fun2FreshFun += (fd -> fd) + fd.body = Some(removeUnit(fd.getBody)) + val restRec = removeUnit(b) + fun2FreshFun -= fd + (fd, restRec) + } + LetDef(newFd, rest) + } + } + case ite@IfExpr(cond, tExpr, eExpr) => { + val thenRec = removeUnit(tExpr) + val elseRec = removeUnit(eExpr) + IfExpr(removeUnit(cond), thenRec, elseRec).setType(thenRec.getType) + } + case n @ NAryOperator(args, recons) => { + recons(args.map(removeUnit(_))).setType(n.getType) + } + case b @ BinaryOperator(a1, a2, recons) => { + recons(removeUnit(a1), removeUnit(a2)).setType(b.getType) + } + case u @ UnaryOperator(a, recons) => { + recons(removeUnit(a)).setType(u.getType) + } + case v @ Variable(id) => if(id2FreshId.isDefinedAt(id)) Variable(id2FreshId(id)) else v + case (t: Terminal) => t + case m @ MatchExpr(scrut, cses) => { + val scrutRec = removeUnit(scrut) + val csesRec = cses.map{ + case SimpleCase(pat, rhs) => SimpleCase(pat, removeUnit(rhs)) + case GuardedCase(pat, guard, rhs) => GuardedCase(pat, removeUnit(guard), removeUnit(rhs)) + } + val tpe = csesRec.head.rhs.getType + MatchExpr(scrutRec, csesRec).setType(tpe) + } + case _ => sys.error("not supported: " + expr) + } + } + +}