diff --git a/src/main/scala/leon/xlang/ImperativeCodeElimination.scala b/src/main/scala/leon/xlang/ImperativeCodeElimination.scala index f9b2b85fbca82ba289986fe6221cc229c30cc84f..b1aa75ea38e8f383a5b3a09ff8cec3c479235ad8 100644 --- a/src/main/scala/leon/xlang/ImperativeCodeElimination.scala +++ b/src/main/scala/leon/xlang/ImperativeCodeElimination.scala @@ -12,21 +12,19 @@ import leon.purescala.ExprOps._ import leon.purescala.TypeOps._ import leon.xlang.Expressions._ -object ImperativeCodeElimination extends TransformationPhase { +object ImperativeCodeElimination extends UnitPhase[Program] { val name = "Imperative Code Elimination" val description = "Transform imperative constructs into purely functional code" - def apply(ctx: LeonContext, pgm: Program): Program = { - val allFuns = pgm.definedFunctions + def apply(ctx: LeonContext, pgm: Program): Unit = { for { - fd <- allFuns + fd <- pgm.definedFunctions body <- fd.body } { val (res, scope, _) = toFunction(body)(State(fd, Set())) fd.body = Some(scope(res)) } - pgm } case class State(parent: FunDef, varsInScope: Set[Identifier]) { @@ -39,23 +37,22 @@ object ImperativeCodeElimination extends TransformationPhase { //that should be introduced as such in the returned scope (the val already refers to the new names) private def toFunction(expr: Expr)(implicit state: State): (Expr, Expr => Expr, Map[Identifier, Identifier]) = { import state._ - val res = expr match { - case LetVar(id, e, b) => { + expr match { + case LetVar(id, e, b) => val newId = id.freshen val (rhsVal, rhsScope, rhsFun) = toFunction(e) val (bodyRes, bodyScope, bodyFun) = toFunction(b)(state.withVar(id)) val scope = (body: Expr) => rhsScope(Let(newId, rhsVal, replaceNames(rhsFun + (id -> newId), bodyScope(body))).copiedFrom(expr)) (bodyRes, scope, (rhsFun + (id -> newId)) ++ bodyFun) - } - case Assignment(id, e) => { + + case Assignment(id, e) => assert(varsInScope.contains(id)) val newId = id.freshen val (rhsVal, rhsScope, rhsFun) = toFunction(e) val scope = (body: Expr) => rhsScope(Let(newId, rhsVal, body).copiedFrom(expr)) (UnitLiteral(), scope, rhsFun + (id -> newId)) - } - case ite@IfExpr(cond, tExpr, eExpr) => { + case ite@IfExpr(cond, tExpr, eExpr) => val (cRes, cScope, cFun) = toFunction(cond) val (tRes, tScope, tFun) = toFunction(tExpr) val (eRes, eScope, eFun) = toFunction(eExpr) @@ -64,19 +61,11 @@ object ImperativeCodeElimination extends TransformationPhase { val modifiedVars: Seq[Identifier] = (tFun.keys ++ eFun.keys).toSet.intersect(varsInScope).toSeq val resId = FreshIdentifier("res", iteRType) - val freshIds = modifiedVars.map( { _.freshen }) + val freshIds = modifiedVars.map( _.freshen ) val iteType = tupleTypeWrap(resId.getType +: freshIds.map(_.getType)) - val thenVal = tupleWrap(tRes +: modifiedVars.map(vId => tFun.get(vId) match { - case Some(newId) => newId.toVariable - case None => vId.toVariable - })) - - val elseVal = tupleWrap(eRes +: modifiedVars.map(vId => eFun.get(vId) match { - case Some(newId) => newId.toVariable - case None => vId.toVariable - })) - + val thenVal = tupleWrap(tRes +: modifiedVars.map(vId => tFun.getOrElse(vId, vId).toVariable)) + val elseVal = tupleWrap(eRes +: modifiedVars.map(vId => eFun.getOrElse(vId, vId).toVariable)) val iteExpr = IfExpr(cRes, replaceNames(cFun, tScope(thenVal)), replaceNames(cFun, eScope(elseVal))).copiedFrom(ite) val scope = (body: Expr) => { @@ -91,9 +80,8 @@ object ImperativeCodeElimination extends TransformationPhase { } (resId.toVariable, scope, cFun ++ modifiedVars.zip(freshIds).toMap) - } - case m @ MatchExpr(scrut, cses) => { + case m @ MatchExpr(scrut, cses) => val csesRhs = cses.map(_.rhs) //we can ignore pattern, and the guard is required to be pure val (csesRes, csesScope, csesFun) = csesRhs.map(toFunction).unzip3 val (scrutRes, scrutScope, scrutFun) = toFunction(scrut) @@ -103,18 +91,16 @@ object ImperativeCodeElimination extends TransformationPhase { val freshIds = modifiedVars.map(id => FreshIdentifier(id.name, id.getType)) val matchType = tupleTypeWrap(resId.getType +: freshIds.map(_.getType)) - val csesVals = csesRes.zip(csesFun).map{ - case (cRes, cFun) => tupleWrap(cRes +: modifiedVars.map(vId => cFun.get(vId) match { - case Some(newId) => newId.toVariable - case None => vId.toVariable - })) + val csesVals = csesRes.zip(csesFun).map { + case (cRes, cFun) => tupleWrap(cRes +: modifiedVars.map(vId => cFun.getOrElse(vId, vId).toVariable)) } - val newRhs = csesVals.zip(csesScope).map{ + val newRhs = csesVals.zip(csesScope).map { case (cVal, cScope) => replaceNames(scrutFun, cScope(cVal)) } - val matchE = matchExpr(scrutRes, cses.zip(newRhs).map{ - case (mc @ MatchCase(pat, guard, _), newRhs) => MatchCase(pat, guard map { replaceNames(scrutFun, _)}, newRhs).setPos(mc) + val matchE = matchExpr(scrutRes, cses.zip(newRhs).map { + case (mc @ MatchCase(pat, guard, _), newRhs) => + MatchCase(pat, guard map { replaceNames(scrutFun, _) }, newRhs).setPos(mc) }).setPos(m) val scope = (body: Expr) => { @@ -131,8 +117,8 @@ object ImperativeCodeElimination extends TransformationPhase { } (resId.toVariable, scope, scrutFun ++ modifiedVars.zip(freshIds).toMap) - } - case wh@While(cond, body) => { + + case wh@While(cond, body) => val (condRes, condScope, condFun) = toFunction(cond) val (_, bodyScope, bodyFun) = toFunction(body) val condBodyFun = condFun ++ bodyFun @@ -173,13 +159,12 @@ object ImperativeCodeElimination extends TransformationPhase { val invariantPrecondition: Option[Expr] = wh.invariant.map(expr => replaceNames(modifiedVars2WhileFunVars, expr)) val invariantPostcondition: Option[Expr] = wh.invariant.map(expr => replace(modifiedVars2ResultVars, expr)) whileFunDef.precondition = invariantPrecondition - whileFunDef.postcondition = trivialPostcondition.map(expr => - Lambda(Seq(ValDef(resVar.id)), and(expr, invariantPostcondition match { - case Some(e) => e - case None => BooleanLiteral(true) - }).setPos(wh) - ).setPos(wh) - ) + whileFunDef.postcondition = trivialPostcondition.map( expr => + Lambda( + Seq(ValDef(resVar.id)), + and(expr, invariantPostcondition.getOrElse(BooleanLiteral(true))).setPos(wh) + ).setPos(wh) + ) val finalVars = modifiedVars.map(_.freshen) val finalScope = (body: Expr) => { @@ -195,10 +180,11 @@ object ImperativeCodeElimination extends TransformationPhase { (UnitLiteral(), finalScope, modifiedVars.zip(finalVars).toMap) } - } - case Block(Seq(), expr) => toFunction(expr) - case Block(exprs, expr) => { + case Block(Seq(), expr) => + toFunction(expr) + + case Block(exprs, expr) => val (scope, fun) = exprs.foldRight((body: Expr) => body, Map[Identifier, Identifier]())((e, acc) => { val (accScope, accFun) = acc val (_, rScope, rFun) = toFunction(e) @@ -207,58 +193,55 @@ object ImperativeCodeElimination extends TransformationPhase { }) val (lastRes, lastScope, lastFun) = toFunction(expr) val finalFun = fun ++ lastFun - (replaceNames(finalFun, lastRes), - (body: Expr) => scope(replaceNames(fun, lastScope(body))), - finalFun) - } + ( + replaceNames(finalFun, lastRes), + (body: Expr) => scope(replaceNames(fun, lastScope(body))), + finalFun + ) //pure expression (that could still contain side effects as a subexpression) (evaluation order is from left to right) - case Let(id, e, b) => { + case Let(id, e, b) => val (bindRes, bindScope, bindFun) = toFunction(e) val (bodyRes, bodyScope, bodyFun) = toFunction(b) - (bodyRes, - (b2: Expr) => bindScope(Let(id, bindRes, replaceNames(bindFun, bodyScope(b2))).copiedFrom(expr)), - bindFun ++ bodyFun) - } - case LetDef(fd, b) => { + ( + bodyRes, + (b2: Expr) => bindScope(Let(id, bindRes, replaceNames(bindFun, bodyScope(b2))).copiedFrom(expr)), + bindFun ++ bodyFun + ) + + case LetDef(fd, b) => //Recall that here the nested function should not access mutable variables from an outside scope - val newFd = fd.body match { - case Some(b) => - val (fdRes, fdScope, _) = toFunction(b) - fd.body = Some(fdScope(fdRes)) - fd - case None => - fd + fd.body.foreach { bd => + val (fdRes, fdScope, _) = toFunction(bd) + fd.body = Some(fdScope(fdRes)) } val (bodyRes, bodyScope, bodyFun) = toFunction(b) - (bodyRes, (b2: Expr) => LetDef(newFd, bodyScope(b2)).copiedFrom(expr), bodyFun) - } - case c @ Choose(b) => { + (bodyRes, (b2: Expr) => LetDef(fd, bodyScope(b2)).copiedFrom(expr), bodyFun) + + case c @ Choose(b) => //Recall that Choose cannot mutate variables from the scope (c, (b2: Expr) => b2, Map()) - } case i @ FunctionInvocation(fd, args) => //function invocation can have side effects so we should keep them as local //val names. - val scope = (body: Expr) => { - Let(FreshIdentifier("tmp", fd.returnType), - i, - body) - } + val scope = + (body: Expr) => Let( + FreshIdentifier("tmp", fd.returnType), + i, + body + ) (i, scope, Map()) - case And(args) => { + case And(args) => val ifExpr = args.reduceRight((el, acc) => IfExpr(el, acc, BooleanLiteral(false))) toFunction(ifExpr) - } - case Or(args) => { + case Or(args) => val ifExpr = args.reduceRight((el, acc) => IfExpr(el, BooleanLiteral(true), acc)) toFunction(ifExpr) - } - case n @ Operator(args, recons) => { + case n @ Operator(args, recons) => val (recArgs, scope, fun) = args.foldRight((Seq[Expr](), (body: Expr) => body, Map[Identifier, Identifier]()))((arg, acc) => { val (accArgs, accScope, accFun) = acc val (argVal, argScope, argFun) = toFunction(arg) @@ -266,15 +249,12 @@ object ImperativeCodeElimination extends TransformationPhase { (argVal +: accArgs, newScope, argFun ++ accFun) }) (recons(recArgs).copiedFrom(n), scope, fun) - } - case _ => sys.error("not supported: " + expr) + case _ => + sys.error("not supported: " + expr) } - //val codeRepresentation = res._2(Block(res._3.map{ case (id1, id2) => Assignment(id1, id2.toVariable)}.toSeq, res._1)) - //println("res of toFunction on: " + expr + " IS: " + codeRepresentation) - res.asInstanceOf[(Expr, (Expr) => Expr, Map[Identifier, Identifier])] //need cast because it seems that res first map type is _ <: Identifier instead of Identifier } - def replaceNames(fun: Map[Identifier, Identifier], expr: Expr) = replace(fun.map(ids => (ids._1.toVariable, ids._2.toVariable)), expr) + def replaceNames(fun: Map[Identifier, Identifier], expr: Expr) = replaceFromIDs(fun mapValues Variable, expr) }