diff --git a/src/main/scala/leon/ImperativeCodeElimination.scala b/src/main/scala/leon/ImperativeCodeElimination.scala index 1ec71eea823b2b459a702302feb65e25aac3f806..c60dd94b293f9098dd3639d4da4e91d5805a1672 100644 --- a/src/main/scala/leon/ImperativeCodeElimination.scala +++ b/src/main/scala/leon/ImperativeCodeElimination.scala @@ -9,6 +9,8 @@ object ImperativeCodeElimination extends Pass { val description = "Transform imperative constructs into purely functional code" + private var varInScope = Set[Identifier]() + def apply(pgm: Program): Program = { val allFuns = pgm.definedFunctions allFuns.foreach(fd => { @@ -24,7 +26,17 @@ object ImperativeCodeElimination extends Pass { //that should be introduced as such in the returned scope (the val already refers to the new names) private def toFunction(expr: Expr): (Expr, Expr => Expr, Map[Identifier, Identifier]) = { val res = expr match { + case LetVar(id, e, b) => { + val newId = FreshIdentifier(id.name).setType(id.getType) + val (rhsVal, rhsScope, rhsFun) = toFunction(e) + varInScope += id + val (bodyRes, bodyScope, bodyFun) = toFunction(b) + varInScope -= id + val scope = (body: Expr) => rhsScope(Let(newId, rhsVal, replaceNames(rhsFun + (id -> newId), bodyScope(body)))) + (bodyRes, scope, (rhsFun + (id -> newId)) ++ bodyFun) + } case Assignment(id, e) => { + assert(varInScope.contains(id)) val newId = FreshIdentifier(id.name).setType(id.getType) val (rhsVal, rhsScope, rhsFun) = toFunction(e) val scope = (body: Expr) => rhsScope(Let(newId, rhsVal, body)) @@ -36,7 +48,7 @@ object ImperativeCodeElimination extends Pass { val (tRes, tScope, tFun) = toFunction(tExpr) val (eRes, eScope, eFun) = toFunction(eExpr) - val modifiedVars: Seq[Identifier] = (tFun.keys ++ eFun.keys).toSeq + val modifiedVars: Seq[Identifier] = (tFun.keys ++ eFun.keys).toSet.intersect(varInScope).toSeq val resId = FreshIdentifier("res").setType(ite.getType) val freshIds = modifiedVars.map(id => FreshIdentifier(id.name).setType(id.getType)) val iteType = if(modifiedVars.isEmpty) resId.getType else TupleType(resId.getType +: freshIds.map(_.getType)) @@ -72,12 +84,53 @@ object ImperativeCodeElimination extends Pass { (resId.toVariable, scope, cFun ++ modifiedVars.zip(freshIds).toMap) } + case m @ MatchExpr(scrut, cses) => { + val csesRhs = cses.map(_.rhs) //we can ignore pattern, and the guard is required to be pure + val (csesRes, csesScope, csesFun) = csesRhs.map(toFunction).unzip3 + val (scrutRes, scrutScope, scrutFun) = toFunction(scrut) + + val modifiedVars: Seq[Identifier] = csesFun.toSet.flatMap((m: Map[Identifier, Identifier]) => m.keys).intersect(varInScope).toSeq + val resId = FreshIdentifier("res").setType(m.getType) + val freshIds = modifiedVars.map(id => FreshIdentifier(id.name).setType(id.getType)) + val matchType = if(modifiedVars.isEmpty) resId.getType else TupleType(resId.getType +: freshIds.map(_.getType)) + + val csesVals = csesRes.zip(csesFun).map{ + case (cRes, cFun) => (if(modifiedVars.isEmpty) cRes else Tuple(cRes +: modifiedVars.map(vId => cFun.get(vId) match { + case Some(newId) => newId.toVariable + case None => vId.toVariable + }))).setType(matchType) + } + + val newRhs = csesVals.zip(csesScope).map{ + case (cVal, cScope) => replaceNames(scrutFun, cScope(cVal)).setType(matchType) + } + val matchExpr = MatchExpr(scrutRes, cses.zip(newRhs).map{ + case (SimpleCase(pat, _), newRhs) => SimpleCase(pat, newRhs) + case (GuardedCase(pat, guard, _), newRhs) => GuardedCase(pat, replaceNames(scrutFun, guard), newRhs) + }).setType(matchType) + + val scope = ((body: Expr) => { + val tupleId = FreshIdentifier("t").setType(matchType) + scrutScope( + Let(tupleId, matchExpr, + if(freshIds.isEmpty) + Let(resId, tupleId.toVariable, body) + else + Let(resId, TupleSelect(tupleId.toVariable, 1), + freshIds.zipWithIndex.foldLeft(body)((b, id) => + Let(id._1, + TupleSelect(tupleId.toVariable, id._2 + 2).setType(id._1.getType), + b))))) + }) + + (resId.toVariable, scope, scrutFun ++ modifiedVars.zip(freshIds).toMap) + } case wh@While(cond, body) => { val (condRes, condScope, condFun) = toFunction(cond) val (_, bodyScope, bodyFun) = toFunction(body) val condBodyFun = condFun ++ bodyFun - val modifiedVars: Seq[Identifier] = condBodyFun.keys.toSeq + val modifiedVars: Seq[Identifier] = condBodyFun.keys.toSet.intersect(varInScope).toSeq if(modifiedVars.isEmpty) (UnitLiteral, (b: Expr) => b, Map()) @@ -163,7 +216,7 @@ object ImperativeCodeElimination extends Pass { val (bindRes, bindScope, bindFun) = toFunction(e) val (bodyRes, bodyScope, bodyFun) = toFunction(b) (bodyRes, - (b2: Expr) => bindScope(Let(id, replaceNames(bindFun, bindRes), bodyScope(b2))), + (b2: Expr) => bindScope(Let(id, bindRes, replaceNames(bindFun, bodyScope(b2)))), bindFun ++ bodyFun) } case LetDef(fd, b) => { @@ -197,47 +250,6 @@ object ImperativeCodeElimination extends Pass { } case (t: Terminal) => (t, (body: Expr) => body, Map()) - case m @ MatchExpr(scrut, cses) => { - val csesRhs = cses.map(_.rhs) //we can ignore pattern, and the guard is required to be pure - val (csesRes, csesScope, csesFun) = csesRhs.map(toFunction).unzip3 - val (scrutRes, scrutScope, scrutFun) = toFunction(scrut) - - val modifiedVars: Seq[Identifier] = csesFun.toSet.flatMap((m: Map[Identifier, Identifier]) => m.keys).toSeq - val resId = FreshIdentifier("res").setType(m.getType) - val freshIds = modifiedVars.map(id => FreshIdentifier(id.name).setType(id.getType)) - val matchType = if(modifiedVars.isEmpty) resId.getType else TupleType(resId.getType +: freshIds.map(_.getType)) - - val csesVals = csesRes.zip(csesFun).map{ - case (cRes, cFun) => (if(modifiedVars.isEmpty) cRes else Tuple(cRes +: modifiedVars.map(vId => cFun.get(vId) match { - case Some(newId) => newId.toVariable - case None => vId.toVariable - }))).setType(matchType) - } - - val newRhs = csesVals.zip(csesScope).map{ - case (cVal, cScope) => replaceNames(scrutFun, cScope(cVal)).setType(matchType) - } - val matchExpr = MatchExpr(scrutRes, cses.zip(newRhs).map{ - case (SimpleCase(pat, _), newRhs) => SimpleCase(pat, newRhs) - case (GuardedCase(pat, guard, _), newRhs) => GuardedCase(pat, replaceNames(scrutFun, guard), newRhs) - }).setType(matchType) - - val scope = ((body: Expr) => { - val tupleId = FreshIdentifier("t").setType(matchType) - scrutScope( - Let(tupleId, matchExpr, - if(freshIds.isEmpty) - Let(resId, tupleId.toVariable, body) - else - Let(resId, TupleSelect(tupleId.toVariable, 1), - freshIds.zipWithIndex.foldLeft(body)((b, id) => - Let(id._1, - TupleSelect(tupleId.toVariable, id._2 + 2).setType(id._1.getType), - b))))) - }) - - (resId.toVariable, scope, scrutFun ++ modifiedVars.zip(freshIds).toMap) - } case _ => sys.error("not supported: " + expr) } diff --git a/src/main/scala/leon/plugin/CodeExtraction.scala b/src/main/scala/leon/plugin/CodeExtraction.scala index d0753ec6580919ca8764859771ed406ce1faf0a4..cde76c0fa1b1b7f38c54aa1809628c1d0396db98 100644 --- a/src/main/scala/leon/plugin/CodeExtraction.scala +++ b/src/main/scala/leon/plugin/CodeExtraction.scala @@ -459,7 +459,6 @@ trait CodeExtraction extends Extractors { } handleRest = false val res = Let(newID, valTree, restTree) - println(res + " with type: " + res.getType + " and restree type: " + restTree.getType) res } case dd@ExFunctionDef(n, p, t, b) => { @@ -482,7 +481,18 @@ trait CodeExtraction extends Extractors { val newID = FreshIdentifier(vs.name.toString).setType(binderTpe) val valTree = rec(bdy) mutableVarSubsts += (vs -> (() => Variable(newID))) - Assignment(newID, valTree) + val restTree = rest match { + case Some(rst) => { + varSubsts(vs) = (() => Variable(newID)) + val res = rec(rst) + varSubsts.remove(vs) + res + } + case None => UnitLiteral + } + handleRest = false + val res = LetVar(newID, valTree, restTree) + res } case ExAssign(sym, rhs) => mutableVarSubsts.get(sym) match { diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala index ec7236d50b3458edccc84fdd7551f0300f96cf02..9cc5b7b626a866430398cbd38a0faadb7b54cdee 100644 --- a/src/main/scala/leon/purescala/PrettyPrinter.scala +++ b/src/main/scala/leon/purescala/PrettyPrinter.scala @@ -78,6 +78,15 @@ object PrettyPrinter { sb.append(")") sb } + case LetVar(b,d,e) => { + sb.append("(letvar (" + b + " := "); + pp(d, sb, lvl) + sb.append(") in\n") + ind(sb, lvl+1) + pp(e, sb, lvl+1) + sb.append(")") + sb + } case LetDef(fd,e) => { sb.append("\n") pp(fd, sb, lvl+1) diff --git a/src/main/scala/leon/purescala/Trees.scala b/src/main/scala/leon/purescala/Trees.scala index bb90f60c3e63267734ecbfe29d8e6ce15dcf90c9..09f91e3795562854af4fb9b72c98ef6f68afaea3 100644 --- a/src/main/scala/leon/purescala/Trees.scala +++ b/src/main/scala/leon/purescala/Trees.scala @@ -47,6 +47,13 @@ object Trees { if(et != Untyped) setType(et) } + //same as let, buf for mutable variable declaration + case class LetVar(binder: Identifier, value: Expr, body: Expr) extends Expr { + binder.markAsLetBinder + val et = body.getType + if(et != Untyped) + setType(et) + } //case class LetTuple(binders: Seq[Identifier], value: Expr, body: Expr) extends Expr { // binders.foreach(_.markAsLetBinder) @@ -502,6 +509,14 @@ object Trees { else l } + case l @ LetVar(i,e,b) => { + val re = rec(e) + val rb = rec(b) + if(re != e || rb != b) + LetVar(i, re, rb).setType(l.getType) + else + l + } case l @ LetDef(fd, b) => { //TODO, not sure, see comment for the next LetDef fd.body = fd.body.map(rec(_)) @@ -592,6 +607,15 @@ object Trees { l }) } + case l @ LetVar(i,e,b) => { + val re = rec(e) + val rb = rec(b) + applySubst(if(re != e || rb != b) { + LetVar(i,re,rb).setType(l.getType) + } else { + l + }) + } case l @ LetDef(fd,b) => { //TODO: Not sure: I actually need the replace to occurs even in the pre/post condition, hope this is correct fd.body = fd.body.map(rec(_)) @@ -723,6 +747,7 @@ object Trees { def treeCatamorphism[A](convert: Expr=>A, combine: (A,A)=>A, compute: (Expr,A)=>A, expression: Expr) : A = { def rec(expr: Expr) : A = expr match { case l @ Let(_, e, b) => compute(l, combine(rec(e), rec(b))) + case l @ LetVar(_, e, b) => compute(l, combine(rec(e), rec(b))) case l @ LetDef(fd, b) => compute(l, combine(rec(fd.getBody), rec(b))) //TODO, still not sure about the semantic case n @ NAryOperator(args, _) => { if(args.size == 0) @@ -768,6 +793,7 @@ object Trees { case Block(_, _) => false case Assignment(_, _) => false case While(_, _) => false + case LetVar(_, _, _) => false case _ => true } def combine(b1: Boolean, b2: Boolean) = b1 && b2 @@ -775,6 +801,7 @@ object Trees { case Block(_, _) => false case Assignment(_, _) => false case While(_, _) => false + case LetVar(_, _, _) => false case _ => true } treeCatamorphism(convert, combine, compute, expr) @@ -859,6 +886,7 @@ object Trees { def allIdentifiers(expr: Expr) : Set[Identifier] = expr match { case l @ Let(binder, e, b) => allIdentifiers(e) ++ allIdentifiers(b) + binder + case l @ LetVar(binder, e, b) => allIdentifiers(e) ++ allIdentifiers(b) + binder case l @ LetDef(fd, b) => allIdentifiers(fd.getBody) ++ allIdentifiers(b) + fd.id case n @ NAryOperator(args, _) => (args map (Trees.allIdentifiers(_))).foldLeft(Set[Identifier]())((a, b) => a ++ b)