diff --git a/src/main/scala/leon/FunctionHoisting.scala b/src/main/scala/leon/FunctionHoisting.scala index 9f539f1baf4ccf5a65cfc1630f9ab1ebe82831e4..88881e2ceebcc0957f77dc928e15db448bd72f43 100644 --- a/src/main/scala/leon/FunctionHoisting.scala +++ b/src/main/scala/leon/FunctionHoisting.scala @@ -24,7 +24,10 @@ object FunctionHoisting extends Pass { private def hoist(expr: Expr): (Expr, Set[FunDef]) = expr match { case l @ LetDef(fd, rest) => { val (e, s) = hoist(rest) - (e, s + fd) + val (e2, s2) = hoist(fd.getBody) + fd.body = Some(e2) + + (e, (s ++ s2) + fd) } case l @ Let(i,e,b) => { val (re, s1) = hoist(e) diff --git a/src/main/scala/leon/ImperativeCodeElimination.scala b/src/main/scala/leon/ImperativeCodeElimination.scala index abf197da2412b80be82163bd06324be452869a0f..1f311533dd83c9900f40c35f6ee4863d92547db6 100644 --- a/src/main/scala/leon/ImperativeCodeElimination.scala +++ b/src/main/scala/leon/ImperativeCodeElimination.scala @@ -12,13 +12,8 @@ object ImperativeCodeElimination extends Pass { def apply(pgm: Program): Program = { val allFuns = pgm.definedFunctions allFuns.foreach(fd => { - val (res, _, _) = toFunction(fd.getBody) - //val ((scope, fun), last) = fd.getBody match { - // case Block(stmts, stmt) => (toFunction(Block(stmts.init, stmts.last)), stmt) - // case _ => sys.error("not supported") - //} - //fd.body = Some(scope(replace(fun.map{case (i1, i2) => (i1.toVariable, i2.toVariable)}, last))) - fd.body = Some(res) + val (res, scope, _) = toFunction(fd.getBody) + fd.body = Some(scope(res)) }) pgm } @@ -102,6 +97,14 @@ object ImperativeCodeElimination extends Pass { } //pure expression (that could still contain side effects as a subexpression) + case Let(id, e, b) => { + val (bodyRes, bodyScope, bodyFun) = toFunction(b) + (bodyRes, (b: Expr) => Let(id, e, bodyScope(b)), bodyFun) + } + case LetDef(fd, b) => { + val (bodyRes, bodyScope, bodyFun) = toFunction(b) + (bodyRes, (b: Expr) => LetDef(fd, bodyScope(b)), bodyFun) + } case n @ NAryOperator(args, recons) => { val (recArgs, scope, fun) = args.foldLeft((Seq[Expr](), (body: Expr) => body, Map[Identifier, Identifier]()))((acc, arg) => { val (accArgs, scope, fun) = acc diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala index 7df4adbec24fa1522433d760d66cc573eb0e0195..0e201559186a8cc620726e6f543d2371c07e2409 100644 --- a/src/main/scala/leon/purescala/PrettyPrinter.scala +++ b/src/main/scala/leon/purescala/PrettyPrinter.scala @@ -80,7 +80,7 @@ object PrettyPrinter { } case LetDef(fd,e) => { sb.append("\n") - pp(fd, sb, lvl) + pp(fd, sb, lvl+1) sb.append("\n") sb.append("\n") ind(sb, lvl)