diff --git a/src/main/scala/leon/EpsilonElimination.scala b/src/main/scala/leon/EpsilonElimination.scala index b8071a721080a67427f3faae255ff37fc4c76585..585888f4c771ea6354df64f8792d0becd90dea52 100644 --- a/src/main/scala/leon/EpsilonElimination.scala +++ b/src/main/scala/leon/EpsilonElimination.scala @@ -9,136 +9,25 @@ object EpsilonElimination extends Pass { val description = "Remove all epsilons from the program" - private var fun2FreshFun: Map[FunDef, FunDef] = Map() - private var id2FreshId: Map[Identifier, Identifier] = Map() - def apply(pgm: Program): Program = { - fun2FreshFun = Map() - val allFuns = pgm.definedFunctions - - //first introduce new signatures without Unit parameters - allFuns.foreach(fd => { - if(fd.returnType != UnitType && fd.args.exists(vd => vd.tpe == UnitType)) { - val freshFunDef = new FunDef(FreshIdentifier(fd.id.name), fd.returnType, fd.args.filterNot(vd => vd.tpe == UnitType)).setPosInfo(fd) - freshFunDef.precondition = fd.precondition //TODO: maybe removing unit from the conditions as well.. - freshFunDef.postcondition = fd.postcondition//TODO: maybe removing unit from the conditions as well.. - fun2FreshFun += (fd -> freshFunDef) - } else { - fun2FreshFun += (fd -> fd) //this will make the next step simpler - } - }) - - //then apply recursively to the bodies - val newFuns = allFuns.flatMap(fd => if(fd.returnType == UnitType) Seq() else { - val body = fd.getBody - val newFd = fun2FreshFun(fd) - newFd.body = Some(removeUnit(body)) - Seq(newFd) - }) - - val Program(id, ObjectDef(objId, _, invariants)) = pgm - val allClasses = pgm.definedClasses - Program(id, ObjectDef(objId, allClasses ++ newFuns, invariants)) - } - private def simplifyType(tpe: TypeTree): TypeTree = tpe match { - case TupleType(tpes) => tpes.map(simplifyType).filterNot{ case UnitType => true case _ => false } match { - case Seq() => UnitType - case Seq(tpe) => tpe - case tpes => TupleType(tpes) - } - case t => t - } - - //remove unit value as soon as possible, so expr should never be equal to a unit - private def removeUnit(expr: Expr): Expr = { - assert(expr.getType != UnitType) - expr match { - case fi@FunctionInvocation(fd, args) => { - val newArgs = args.filterNot(arg => arg.getType == UnitType) - FunctionInvocation(fun2FreshFun(fd), newArgs).setPosInfo(fi) - } - case t@Tuple(args) => { - val TupleType(tpes) = t.getType - val (newTpes, newArgs) = tpes.zip(args).filterNot{ case (UnitType, _) => true case _ => false }.unzip - Tuple(newArgs.map(removeUnit)).setType(TupleType(newTpes)) - } - case ts@TupleSelect(t, index) => { - val TupleType(tpes) = t.getType - val selectionType = tpes(index-1) - val (_, newIndex) = tpes.zipWithIndex.foldLeft((0,-1)){ - case ((nbUnit, newIndex), (tpe, i)) => - if(i == index-1) (nbUnit, index - nbUnit) else (if(tpe == UnitType) nbUnit + 1 else nbUnit, newIndex) - } - TupleSelect(removeUnit(t), newIndex).setType(selectionType) - } - case Let(id, e, b) => { - if(id.getType == UnitType) - removeUnit(b) - else { - id.getType match { - case TupleType(tpes) if tpes.exists(_ == UnitType) => { - val newTupleType = TupleType(tpes.filterNot(_ == UnitType)) - val freshId = FreshIdentifier(id.name).setType(newTupleType) - id2FreshId += (id -> freshId) - val newBody = removeUnit(b) - id2FreshId -= id - Let(freshId, removeUnit(e), newBody) - } - case _ => Let(id, removeUnit(e), removeUnit(b)) - } - } - } - case LetDef(fd, b) => { - if(fd.returnType == UnitType) - removeUnit(b) - else { - val (newFd, rest) = if(fd.args.exists(vd => vd.tpe == UnitType)) { - val freshFunDef = new FunDef(FreshIdentifier(fd.id.name), fd.returnType, fd.args.filterNot(vd => vd.tpe == UnitType)).setPosInfo(fd) - freshFunDef.precondition = fd.precondition //TODO: maybe removing unit from the conditions as well.. - freshFunDef.postcondition = fd.postcondition//TODO: maybe removing unit from the conditions as well.. - fun2FreshFun += (fd -> freshFunDef) - freshFunDef.body = Some(removeUnit(fd.getBody)) - val restRec = removeUnit(b) - fun2FreshFun -= fd - (freshFunDef, restRec) - } else { - fun2FreshFun += (fd -> fd) - fd.body = Some(removeUnit(fd.getBody)) - val restRec = removeUnit(b) - fun2FreshFun -= fd - (fd, restRec) - } - LetDef(newFd, rest) - } - } - case ite@IfExpr(cond, tExpr, eExpr) => { - val thenRec = removeUnit(tExpr) - val elseRec = removeUnit(eExpr) - IfExpr(removeUnit(cond), thenRec, elseRec).setType(thenRec.getType) - } - case n @ NAryOperator(args, recons) => { - recons(args.map(removeUnit(_))).setType(n.getType) - } - case b @ BinaryOperator(a1, a2, recons) => { - recons(removeUnit(a1), removeUnit(a2)).setType(b.getType) - } - case u @ UnaryOperator(a, recons) => { - recons(removeUnit(a)).setType(u.getType) - } - case v @ Variable(id) => if(id2FreshId.isDefinedAt(id)) Variable(id2FreshId(id)) else v - case (t: Terminal) => t - case m @ MatchExpr(scrut, cses) => { - val scrutRec = removeUnit(scrut) - val csesRec = cses.map{ - case SimpleCase(pat, rhs) => SimpleCase(pat, removeUnit(rhs)) - case GuardedCase(pat, guard, rhs) => GuardedCase(pat, removeUnit(guard), removeUnit(rhs)) + val allFuns = pgm.definedFunctions + allFuns.foreach(fd => fd.body.map(body => { + val newBody = searchAndReplaceDFS{ + case eps@Epsilon(pred) => { + val freshName = FreshIdentifier("epsilon") + val newFunDef = new FunDef(freshName, eps.getType, Seq()) + val epsilonVar = EpsilonVariable(eps.posIntInfo) + val resultVar = ResultVariable().setType(eps.getType) + val postcondition = replace(Map(epsilonVar -> resultVar), pred) + newFunDef.postcondition = Some(postcondition) + Some(LetDef(newFunDef, FunctionInvocation(newFunDef, Seq()))) } - val tpe = csesRec.head.rhs.getType - MatchExpr(scrutRec, csesRec).setType(tpe) - } - case _ => sys.error("not supported: " + expr) - } + case _ => None + }(body) + fd.body = Some(newBody) + })) + pgm } } diff --git a/src/main/scala/leon/FunctionClosure.scala b/src/main/scala/leon/FunctionClosure.scala index 47c0be1391b1f799a1a45ccfb9cc9b30351678c8..ba1c793d376439450cb4565b764014f7c73aa4ae 100644 --- a/src/main/scala/leon/FunctionClosure.scala +++ b/src/main/scala/leon/FunctionClosure.scala @@ -21,7 +21,7 @@ object FunctionClosure extends Pass { funDefs.foreach(fd => { enclosingPreconditions = fd.precondition.toList pathConstraints = fd.precondition.toList - fd.body = Some(functionClosure(fd.getBody, fd.args.map(_.id).toSet)) + fd.body = fd.body.map(b => functionClosure(b, fd.args.map(_.id).toSet)) }) program } @@ -32,11 +32,14 @@ object FunctionClosure extends Pass { val id = fd.id val rt = fd.returnType val varDecl = fd.args - val funBody = fd.getBody val precondition = fd.precondition val postcondition = fd.postcondition - val bodyVars: Set[Identifier] = variablesOf(funBody) ++ variablesOf(precondition.getOrElse(BooleanLiteral(true))) + val bodyVars: Set[Identifier] = (fd.body match { + case Some(body) => variablesOf(body) + case None => Set() + }) ++ variablesOf(precondition.getOrElse(BooleanLiteral(true))) ++ variablesOf(postcondition.getOrElse(BooleanLiteral(true))) + val capturedVars = bodyVars.intersect(bindedVars)// this should be the variable used that are in the scope val (constraints, allCapturedVars) = filterConstraints(capturedVars) //all relevant path constraints val capturedVarsWithConstraints = allCapturedVars.toSeq @@ -61,13 +64,13 @@ object FunctionClosure extends Pass { case FunctionInvocation(fd, args) if fd.id == id => Some(FunctionInvocation(newFunDef, args ++ extraVarDecls.map(_.id.toVariable))) case _ => None } - val freshBody = replace(freshVarsExpr, funBody) + val freshBody = fd.body.map(b => replace(freshVarsExpr, b)) val oldPathConstraints = pathConstraints pathConstraints = (precondition.getOrElse(BooleanLiteral(true)) :: pathConstraints).map(e => replace(freshVarsExpr, e)) - val recBody = functionClosure(freshBody, bindedVars ++ newVarDecls.map(_.id)) + val recBody = freshBody.map(b => functionClosure(b, bindedVars ++ newVarDecls.map(_.id))) pathConstraints = oldPathConstraints - val recBody2 = searchAndReplaceDFS(substFunInvocInDef)(recBody) - newFunDef.body = Some(recBody2) + val recBody2 = recBody.map(b => searchAndReplaceDFS(substFunInvocInDef)(b)) + newFunDef.body = recBody2 def substFunInvocInRest(expr: Expr): Option[Expr] = expr match { case FunctionInvocation(fd, args) if fd.id == id => Some(FunctionInvocation(newFunDef, args ++ capturedVarsWithConstraints.map(_.toVariable))) diff --git a/src/main/scala/leon/FunctionHoisting.scala b/src/main/scala/leon/FunctionHoisting.scala index ec02d43360c0180925252ab6ab8f32590dac038b..799cea51aa28945996b6307d27de4fde7eec1867 100644 --- a/src/main/scala/leon/FunctionHoisting.scala +++ b/src/main/scala/leon/FunctionHoisting.scala @@ -12,11 +12,11 @@ object FunctionHoisting extends Pass { def apply(program: Program): Program = { val funDefs = program.definedFunctions var topLevelFuns: Set[FunDef] = Set() - funDefs.foreach(fd => { - val (newBody, additionalTopLevelFun) = hoist(fd.getBody) + funDefs.foreach(fd => fd.body.map(body => { + val (newBody, additionalTopLevelFun) = hoist(body) fd.body = Some(newBody) topLevelFuns ++= additionalTopLevelFun - }) + })) val Program(id, ObjectDef(objId, defs, invariants)) = program Program(id, ObjectDef(objId, defs ++ topLevelFuns, invariants)) } @@ -24,9 +24,14 @@ object FunctionHoisting extends Pass { private def hoist(expr: Expr): (Expr, Set[FunDef]) = expr match { case l @ LetDef(fd, rest) => { val (e, s) = hoist(rest) - val (e2, s2) = hoist(fd.getBody) - fd.body = Some(e2) - + val s2 = fd.body match { + case Some(body) => { + val (e2, s2) = hoist(body) + fd.body = Some(e2) + s2 + } + case None => Set() + } (e, (s ++ s2) + fd) } case l @ Let(i,e,b) => { diff --git a/src/main/scala/leon/ImperativeCodeElimination.scala b/src/main/scala/leon/ImperativeCodeElimination.scala index c60dd94b293f9098dd3639d4da4e91d5805a1672..cb14cc02a9b4d116a2fd17774c3f2e6dadd65135 100644 --- a/src/main/scala/leon/ImperativeCodeElimination.scala +++ b/src/main/scala/leon/ImperativeCodeElimination.scala @@ -13,10 +13,10 @@ object ImperativeCodeElimination extends Pass { def apply(pgm: Program): Program = { val allFuns = pgm.definedFunctions - allFuns.foreach(fd => { - val (res, scope, _) = toFunction(fd.getBody) + allFuns.foreach(fd => fd.body.map(body => { + val (res, scope, _) = toFunction(body) fd.body = Some(scope(res)) - }) + })) pgm } diff --git a/src/main/scala/leon/Simplificator.scala b/src/main/scala/leon/Simplificator.scala index e8a7e2e005b68ce54a6482cb4a06f942d54ba58f..43dcc8c01dc11811ed148bca3191d516e4af5029 100644 --- a/src/main/scala/leon/Simplificator.scala +++ b/src/main/scala/leon/Simplificator.scala @@ -12,9 +12,9 @@ object Simplificator extends Pass { def apply(pgm: Program): Program = { val allFuns = pgm.definedFunctions - allFuns.foreach(fd => { - fd.body = Some(simplifyLets(fd.getBody)) - }) + allFuns.foreach(fd => fd.body.map(body => { + fd.body = Some(simplifyLets(body)) + })) pgm } diff --git a/src/main/scala/leon/UnitElimination.scala b/src/main/scala/leon/UnitElimination.scala index 005544d9be2e8fff0e86f106464c29358383d772..d7c9781386cd9630d0f2372900678c318c2e7155 100644 --- a/src/main/scala/leon/UnitElimination.scala +++ b/src/main/scala/leon/UnitElimination.scala @@ -30,9 +30,9 @@ object UnitElimination extends Pass { //then apply recursively to the bodies val newFuns = allFuns.flatMap(fd => if(fd.returnType == UnitType) Seq() else { - val body = fd.getBody + val newBody = fd.body.map(body => removeUnit(body)) val newFd = fun2FreshFun(fd) - newFd.body = Some(removeUnit(body)) + newFd.body = newBody Seq(newFd) }) @@ -98,13 +98,13 @@ object UnitElimination extends Pass { freshFunDef.precondition = fd.precondition //TODO: maybe removing unit from the conditions as well.. freshFunDef.postcondition = fd.postcondition//TODO: maybe removing unit from the conditions as well.. fun2FreshFun += (fd -> freshFunDef) - freshFunDef.body = Some(removeUnit(fd.getBody)) + freshFunDef.body = fd.body.map(b => removeUnit(b)) val restRec = removeUnit(b) fun2FreshFun -= fd (freshFunDef, restRec) } else { fun2FreshFun += (fd -> fd) - fd.body = Some(removeUnit(fd.getBody)) + fd.body = fd.body.map(b => removeUnit(b)) val restRec = removeUnit(b) fun2FreshFun -= fd (fd, restRec) diff --git a/src/main/scala/leon/Utils.scala b/src/main/scala/leon/Utils.scala index 2a1a8b6b6d53c164b24c6a15b063c07292646b69..ae60002b693a0b74657fa8640a7eef56a641703f 100644 --- a/src/main/scala/leon/Utils.scala +++ b/src/main/scala/leon/Utils.scala @@ -10,7 +10,7 @@ object Utils { implicit def any2IsValid(x: Boolean) : IsValid = new IsValid(x) - def epsilon[A](pred: (A) => Boolean): A + def epsilon[A](pred: (A) => Boolean): A = throw new RuntimeException("Implementation not supported") object InvariantFunction { def invariant(x: Boolean): Unit = ()