diff --git a/mytest/IfExpr1.scala b/mytest/IfExpr1.scala index 3978b974552545f1db46e0676b758251af741c4e..4910ca339d6c94ac83eddcd56d854a0a7e6b045b 100644 --- a/mytest/IfExpr1.scala +++ b/mytest/IfExpr1.scala @@ -8,6 +8,6 @@ object IfExpr1 { else b = a + b a - } + } ensuring(_ == 1) } diff --git a/src/main/scala/leon/Analysis.scala b/src/main/scala/leon/Analysis.scala index ed874898c05507c57b3e880ca8df1a7941255e27..3ae042b129b87bf26e8164fd766f2ab71d600ff6 100644 --- a/src/main/scala/leon/Analysis.scala +++ b/src/main/scala/leon/Analysis.scala @@ -11,7 +11,7 @@ class Analysis(pgm : Program, val reporter: Reporter = Settings.reporter) { Extensions.loadAll(reporter) println("Analysis on program:\n" + pgm) - val passManager = new PassManager(Seq(FunctionClosure, FunctionHoisting)) + val passManager = new PassManager(Seq(ImperativeCodeElimination, FunctionClosure, FunctionHoisting)) val program = passManager.run(pgm) val analysisExtensions: Seq[Analyser] = loadedAnalysisExtensions diff --git a/src/main/scala/leon/ImperativeCodeElimination.scala b/src/main/scala/leon/ImperativeCodeElimination.scala index 11f8df42efa056fc436e6902812d92513321350f..dab564e46fe68e422ce279e5168d29b2fdb25a2f 100644 --- a/src/main/scala/leon/ImperativeCodeElimination.scala +++ b/src/main/scala/leon/ImperativeCodeElimination.scala @@ -13,7 +13,7 @@ object ImperativeCodeElimination extends Pass { val allFuns = pgm.definedFunctions allFuns.foreach(fd => { val ((scope, fun), last) = fd.getBody match { - case Block(stmts) => (toFunction(Block(stmts.init)), stmts.last) + case Block(stmts, stmt) => (toFunction(Block(stmts.init, stmts.last)), stmt) case _ => sys.error("not supported") } fd.body = Some(scope(replace(fun.map{case (i1, i2) => (i1.toVariable, i2.toVariable)}, last))) @@ -27,41 +27,48 @@ object ImperativeCodeElimination extends Pass { private def toFunction(expr: Expr): (Expr => Expr, Map[Identifier, Identifier]) = { println("toFunction of: " + expr) val res = expr match { - case Assignment(id, e) => { - val newId = FreshIdentifier(id.name).setType(id.getType) - val scope = ((body: Expr) => Let(newId, e, body)) - (scope, Map(id -> newId)) - } - case IfExpr(cond, tExpr, eExpr) => { - val (tScope, tFun) = toFunction(tExpr) - val (eScope, eFun) = toFunction(eExpr) - val modifiedVars: Seq[Identifier] = (tFun.keys ++ eFun.keys).toSeq - val freshIds = modifiedVars.map(id => FreshIdentifier(id.name).setType(id.getType)) - val newTExpr = tScope(Tuple(modifiedVars.map(vId => tFun.get(vId) match { - case Some(newId) => newId.toVariable - case None => vId.toVariable - }))) - val newEExpr = eScope(Tuple(modifiedVars.map(vId => eFun.get(vId) match { - case Some(newId) => newId.toVariable - case None => vId.toVariable - }))) - val newIfExpr = IfExpr(cond, newTExpr, newEExpr).setType(newTExpr.getType) - val scope = ((body: Expr) => LetTuple(freshIds, newIfExpr, body)) - (scope, Map(modifiedVars.zip(freshIds):_*)) - } - case Block(exprs) => { - val (headScope, headFun) = toFunction(exprs.head) - exprs.tail.foldLeft((headScope, headFun))((acc, e) => { - val (accScope, accFun) = acc - val (rScope, rFun) = toFunction(e) - val scope = ((body: Expr) => - accScope(replace(accFun.map{case (i1, i2) => (i1.toVariable, i2.toVariable)}, rScope(body)))) - (scope, accFun ++ rFun) - }) - } - case _ => sys.error("not supported: " + expr) + case Assignment(id, e) => { + val newId = FreshIdentifier(id.name).setType(id.getType) + val scope = ((body: Expr) => Let(newId, e, body)) + (scope, Map(id -> newId)) + } + case IfExpr(cond, tExpr, eExpr) => { + val (tScope, tFun) = toFunction(tExpr) + val (eScope, eFun) = toFunction(eExpr) + val modifiedVars: Seq[Identifier] = (tFun.keys ++ eFun.keys).toSeq + val freshIds = modifiedVars.map(id => FreshIdentifier(id.name).setType(id.getType)) + val tupleType = TupleType(freshIds.map(_.getType)) + val newTExpr = tScope(Tuple(modifiedVars.map(vId => tFun.get(vId) match { + case Some(newId) => newId.toVariable + case None => vId.toVariable + })).setType(tupleType)) + val newEExpr = eScope(Tuple(modifiedVars.map(vId => eFun.get(vId) match { + case Some(newId) => newId.toVariable + case None => vId.toVariable + })).setType(tupleType)) + val newIfExpr = IfExpr(cond, newTExpr, newEExpr).setType(newTExpr.getType) + val scope = ((body: Expr) => { + val tupleId = FreshIdentifier("t").setType(TupleType(freshIds.map(_.getType))) + Let(tupleId, newIfExpr, freshIds.zipWithIndex.foldLeft(body)((b, id) => + Let(id._1, + TupleSelect(tupleId.toVariable, id._2 + 1).setType(id._1.getType), + b))) + }) + (scope, Map(modifiedVars.zip(freshIds):_*)) + } + case Block(head::exprs, expr) => { + val (headScope, headFun) = toFunction(head) + (exprs:+expr).foldLeft((headScope, headFun))((acc, e) => { + val (accScope, accFun) = acc + val (rScope, rFun) = toFunction(e) + val scope = ((body: Expr) => + accScope(replace(accFun.map{case (i1, i2) => (i1.toVariable, i2.toVariable)}, rScope(body)))) + (scope, accFun ++ rFun) + }) + } + case _ => sys.error("not supported: " + expr) } - val codeRepresentation = res._1(Block(res._2.map{ case (id1, id2) => Assignment(id1, id2.toVariable)}.toSeq)) + val codeRepresentation = res._1(Block(res._2.map{ case (id1, id2) => Assignment(id1, id2.toVariable)}.toSeq, UnitLiteral)) println("res of toFunction on: " + expr + " IS: " + codeRepresentation) res } diff --git a/src/main/scala/leon/plugin/Extractors.scala b/src/main/scala/leon/plugin/Extractors.scala index 02493c088d9f5d0407242e0a9b3133fb866a65ea..54a5fb7f02c4ddd0d56ee371deddd0700894c660 100644 --- a/src/main/scala/leon/plugin/Extractors.scala +++ b/src/main/scala/leon/plugin/Extractors.scala @@ -185,8 +185,8 @@ trait Extractors { object ExWhile { def unapply(tree: LabelDef): Option[(Tree,Tree)] = tree match { case (label@LabelDef( - _, _, If(cond, Block(body, jump@Apply(_, _)), unit@Literal(_)))) - if label.symbol == jump.symbol && unit.symbol == null => Some((cond, body)) + _, _, If(cond, Block(body, jump@Apply(_, _)), unit@ExUnitLiteral()))) + if label.symbol == jump.symbol && unit.symbol == null => Some((cond, Block(body, unit))) case _ => None } } diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala index 002a9ef79c1c99647982fe39fa85d9caa5c392fb..0c8d88f14c1e51cdafc80ab0c8ae1c8a9f02570c 100644 --- a/src/main/scala/leon/purescala/PrettyPrinter.scala +++ b/src/main/scala/leon/purescala/PrettyPrinter.scala @@ -90,9 +90,10 @@ object PrettyPrinter { case IntLiteral(v) => sb.append(v) case BooleanLiteral(v) => sb.append(v) case StringLiteral(s) => sb.append("\"" + s + "\"") - case Block(exprs) => { + case UnitLiteral => sb.append("()") + case Block(exprs, last) => { sb.append("{\n") - exprs.foreach(e => { + (exprs :+ last).foreach(e => { ind(sb, lvl+1) pp(e, sb, lvl+1) sb.append("\n") @@ -102,7 +103,6 @@ object PrettyPrinter { sb } case Assignment(lhs, rhs) => ppBinary(sb, lhs.toVariable, rhs, " = ", lvl) - case Skip => sb.append("()") case While(cond, body) => { sb.append("while(") pp(cond, sb, lvl) diff --git a/src/main/scala/leon/purescala/Trees.scala b/src/main/scala/leon/purescala/Trees.scala index 309cc68732e865f11dcb3f1d4a38be18a99de28b..3157e8eb7f24c9dbe17984cfb987d6409f0b83c0 100644 --- a/src/main/scala/leon/purescala/Trees.scala +++ b/src/main/scala/leon/purescala/Trees.scala @@ -258,6 +258,7 @@ object Trees { case class StringLiteral(value: String) extends Literal[String] case object UnitLiteral extends Literal[Unit] with FixedType { val fixedType = UnitType + val value = () } case class CaseClass(classDef: CaseClassDef, args: Seq[Expr]) extends Expr with FixedType { @@ -730,11 +731,11 @@ object Trees { case Block(exprs, last) => { val nexprs = (exprs :+ last).flatMap{ case Block(es2, el) => es2 :+ el - case Skip => Seq() + case UnitLiteral => Seq() case e2 => Seq(e2) } val fexpr = nexprs match { - case Seq() => Skip + case Seq() => UnitLiteral case Seq(e) => e case es => Block(es.init, es.last).setType(es.last.getType) }