From 78c4285cb830ad48f6bd420c0fe91764eec338b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9gis=20Blanc?= <regwblanc@gmail.com> Date: Sun, 25 Mar 2012 03:37:53 +0000 Subject: [PATCH] imperative transformation working with if then else and assignment --- mytest/IfExpr1.scala | 2 +- src/main/scala/leon/Analysis.scala | 2 +- .../leon/ImperativeCodeElimination.scala | 77 ++++++++++--------- src/main/scala/leon/plugin/Extractors.scala | 4 +- .../scala/leon/purescala/PrettyPrinter.scala | 6 +- src/main/scala/leon/purescala/Trees.scala | 5 +- 6 files changed, 52 insertions(+), 44 deletions(-) diff --git a/mytest/IfExpr1.scala b/mytest/IfExpr1.scala index 3978b9745..4910ca339 100644 --- a/mytest/IfExpr1.scala +++ b/mytest/IfExpr1.scala @@ -8,6 +8,6 @@ object IfExpr1 { else b = a + b a - } + } ensuring(_ == 1) } diff --git a/src/main/scala/leon/Analysis.scala b/src/main/scala/leon/Analysis.scala index ed874898c..3ae042b12 100644 --- a/src/main/scala/leon/Analysis.scala +++ b/src/main/scala/leon/Analysis.scala @@ -11,7 +11,7 @@ class Analysis(pgm : Program, val reporter: Reporter = Settings.reporter) { Extensions.loadAll(reporter) println("Analysis on program:\n" + pgm) - val passManager = new PassManager(Seq(FunctionClosure, FunctionHoisting)) + val passManager = new PassManager(Seq(ImperativeCodeElimination, FunctionClosure, FunctionHoisting)) val program = passManager.run(pgm) val analysisExtensions: Seq[Analyser] = loadedAnalysisExtensions diff --git a/src/main/scala/leon/ImperativeCodeElimination.scala b/src/main/scala/leon/ImperativeCodeElimination.scala index 11f8df42e..dab564e46 100644 --- a/src/main/scala/leon/ImperativeCodeElimination.scala +++ b/src/main/scala/leon/ImperativeCodeElimination.scala @@ -13,7 +13,7 @@ object ImperativeCodeElimination extends Pass { val allFuns = pgm.definedFunctions allFuns.foreach(fd => { val ((scope, fun), last) = fd.getBody match { - case Block(stmts) => (toFunction(Block(stmts.init)), stmts.last) + case Block(stmts, stmt) => (toFunction(Block(stmts.init, stmts.last)), stmt) case _ => sys.error("not supported") } fd.body = Some(scope(replace(fun.map{case (i1, i2) => (i1.toVariable, i2.toVariable)}, last))) @@ -27,41 +27,48 @@ object ImperativeCodeElimination extends Pass { private def toFunction(expr: Expr): (Expr => Expr, Map[Identifier, Identifier]) = { println("toFunction of: " + expr) val res = expr match { - case Assignment(id, e) => { - val newId = FreshIdentifier(id.name).setType(id.getType) - val scope = ((body: Expr) => Let(newId, e, body)) - (scope, Map(id -> newId)) - } - case IfExpr(cond, tExpr, eExpr) => { - val (tScope, tFun) = toFunction(tExpr) - val (eScope, eFun) = toFunction(eExpr) - val modifiedVars: Seq[Identifier] = (tFun.keys ++ eFun.keys).toSeq - val freshIds = modifiedVars.map(id => FreshIdentifier(id.name).setType(id.getType)) - val newTExpr = tScope(Tuple(modifiedVars.map(vId => tFun.get(vId) match { - case Some(newId) => newId.toVariable - case None => vId.toVariable - }))) - val newEExpr = eScope(Tuple(modifiedVars.map(vId => eFun.get(vId) match { - case Some(newId) => newId.toVariable - case None => vId.toVariable - }))) - val newIfExpr = IfExpr(cond, newTExpr, newEExpr).setType(newTExpr.getType) - val scope = ((body: Expr) => LetTuple(freshIds, newIfExpr, body)) - (scope, Map(modifiedVars.zip(freshIds):_*)) - } - case Block(exprs) => { - val (headScope, headFun) = toFunction(exprs.head) - exprs.tail.foldLeft((headScope, headFun))((acc, e) => { - val (accScope, accFun) = acc - val (rScope, rFun) = toFunction(e) - val scope = ((body: Expr) => - accScope(replace(accFun.map{case (i1, i2) => (i1.toVariable, i2.toVariable)}, rScope(body)))) - (scope, accFun ++ rFun) - }) - } - case _ => sys.error("not supported: " + expr) + case Assignment(id, e) => { + val newId = FreshIdentifier(id.name).setType(id.getType) + val scope = ((body: Expr) => Let(newId, e, body)) + (scope, Map(id -> newId)) + } + case IfExpr(cond, tExpr, eExpr) => { + val (tScope, tFun) = toFunction(tExpr) + val (eScope, eFun) = toFunction(eExpr) + val modifiedVars: Seq[Identifier] = (tFun.keys ++ eFun.keys).toSeq + val freshIds = modifiedVars.map(id => FreshIdentifier(id.name).setType(id.getType)) + val tupleType = TupleType(freshIds.map(_.getType)) + val newTExpr = tScope(Tuple(modifiedVars.map(vId => tFun.get(vId) match { + case Some(newId) => newId.toVariable + case None => vId.toVariable + })).setType(tupleType)) + val newEExpr = eScope(Tuple(modifiedVars.map(vId => eFun.get(vId) match { + case Some(newId) => newId.toVariable + case None => vId.toVariable + })).setType(tupleType)) + val newIfExpr = IfExpr(cond, newTExpr, newEExpr).setType(newTExpr.getType) + val scope = ((body: Expr) => { + val tupleId = FreshIdentifier("t").setType(TupleType(freshIds.map(_.getType))) + Let(tupleId, newIfExpr, freshIds.zipWithIndex.foldLeft(body)((b, id) => + Let(id._1, + TupleSelect(tupleId.toVariable, id._2 + 1).setType(id._1.getType), + b))) + }) + (scope, Map(modifiedVars.zip(freshIds):_*)) + } + case Block(head::exprs, expr) => { + val (headScope, headFun) = toFunction(head) + (exprs:+expr).foldLeft((headScope, headFun))((acc, e) => { + val (accScope, accFun) = acc + val (rScope, rFun) = toFunction(e) + val scope = ((body: Expr) => + accScope(replace(accFun.map{case (i1, i2) => (i1.toVariable, i2.toVariable)}, rScope(body)))) + (scope, accFun ++ rFun) + }) + } + case _ => sys.error("not supported: " + expr) } - val codeRepresentation = res._1(Block(res._2.map{ case (id1, id2) => Assignment(id1, id2.toVariable)}.toSeq)) + val codeRepresentation = res._1(Block(res._2.map{ case (id1, id2) => Assignment(id1, id2.toVariable)}.toSeq, UnitLiteral)) println("res of toFunction on: " + expr + " IS: " + codeRepresentation) res } diff --git a/src/main/scala/leon/plugin/Extractors.scala b/src/main/scala/leon/plugin/Extractors.scala index 02493c088..54a5fb7f0 100644 --- a/src/main/scala/leon/plugin/Extractors.scala +++ b/src/main/scala/leon/plugin/Extractors.scala @@ -185,8 +185,8 @@ trait Extractors { object ExWhile { def unapply(tree: LabelDef): Option[(Tree,Tree)] = tree match { case (label@LabelDef( - _, _, If(cond, Block(body, jump@Apply(_, _)), unit@Literal(_)))) - if label.symbol == jump.symbol && unit.symbol == null => Some((cond, body)) + _, _, If(cond, Block(body, jump@Apply(_, _)), unit@ExUnitLiteral()))) + if label.symbol == jump.symbol && unit.symbol == null => Some((cond, Block(body, unit))) case _ => None } } diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala index 002a9ef79..0c8d88f14 100644 --- a/src/main/scala/leon/purescala/PrettyPrinter.scala +++ b/src/main/scala/leon/purescala/PrettyPrinter.scala @@ -90,9 +90,10 @@ object PrettyPrinter { case IntLiteral(v) => sb.append(v) case BooleanLiteral(v) => sb.append(v) case StringLiteral(s) => sb.append("\"" + s + "\"") - case Block(exprs) => { + case UnitLiteral => sb.append("()") + case Block(exprs, last) => { sb.append("{\n") - exprs.foreach(e => { + (exprs :+ last).foreach(e => { ind(sb, lvl+1) pp(e, sb, lvl+1) sb.append("\n") @@ -102,7 +103,6 @@ object PrettyPrinter { sb } case Assignment(lhs, rhs) => ppBinary(sb, lhs.toVariable, rhs, " = ", lvl) - case Skip => sb.append("()") case While(cond, body) => { sb.append("while(") pp(cond, sb, lvl) diff --git a/src/main/scala/leon/purescala/Trees.scala b/src/main/scala/leon/purescala/Trees.scala index 309cc6873..3157e8eb7 100644 --- a/src/main/scala/leon/purescala/Trees.scala +++ b/src/main/scala/leon/purescala/Trees.scala @@ -258,6 +258,7 @@ object Trees { case class StringLiteral(value: String) extends Literal[String] case object UnitLiteral extends Literal[Unit] with FixedType { val fixedType = UnitType + val value = () } case class CaseClass(classDef: CaseClassDef, args: Seq[Expr]) extends Expr with FixedType { @@ -730,11 +731,11 @@ object Trees { case Block(exprs, last) => { val nexprs = (exprs :+ last).flatMap{ case Block(es2, el) => es2 :+ el - case Skip => Seq() + case UnitLiteral => Seq() case e2 => Seq(e2) } val fexpr = nexprs match { - case Seq() => Skip + case Seq() => UnitLiteral case Seq(e) => e case es => Block(es.init, es.last).setType(es.last.getType) } -- GitLab