From ebbed9d09ed767cf3309d9f52b271a8ec66f81b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9gis=20Blanc?= <regwblanc@gmail.com> Date: Mon, 26 Mar 2012 15:46:25 +0200 Subject: [PATCH] function with more than one level of nested function now work --- src/main/scala/leon/FunctionHoisting.scala | 5 ++++- .../scala/leon/ImperativeCodeElimination.scala | 17 ++++++++++------- .../scala/leon/purescala/PrettyPrinter.scala | 2 +- 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/src/main/scala/leon/FunctionHoisting.scala b/src/main/scala/leon/FunctionHoisting.scala index 9f539f1ba..88881e2ce 100644 --- a/src/main/scala/leon/FunctionHoisting.scala +++ b/src/main/scala/leon/FunctionHoisting.scala @@ -24,7 +24,10 @@ object FunctionHoisting extends Pass { private def hoist(expr: Expr): (Expr, Set[FunDef]) = expr match { case l @ LetDef(fd, rest) => { val (e, s) = hoist(rest) - (e, s + fd) + val (e2, s2) = hoist(fd.getBody) + fd.body = Some(e2) + + (e, (s ++ s2) + fd) } case l @ Let(i,e,b) => { val (re, s1) = hoist(e) diff --git a/src/main/scala/leon/ImperativeCodeElimination.scala b/src/main/scala/leon/ImperativeCodeElimination.scala index abf197da2..1f311533d 100644 --- a/src/main/scala/leon/ImperativeCodeElimination.scala +++ b/src/main/scala/leon/ImperativeCodeElimination.scala @@ -12,13 +12,8 @@ object ImperativeCodeElimination extends Pass { def apply(pgm: Program): Program = { val allFuns = pgm.definedFunctions allFuns.foreach(fd => { - val (res, _, _) = toFunction(fd.getBody) - //val ((scope, fun), last) = fd.getBody match { - // case Block(stmts, stmt) => (toFunction(Block(stmts.init, stmts.last)), stmt) - // case _ => sys.error("not supported") - //} - //fd.body = Some(scope(replace(fun.map{case (i1, i2) => (i1.toVariable, i2.toVariable)}, last))) - fd.body = Some(res) + val (res, scope, _) = toFunction(fd.getBody) + fd.body = Some(scope(res)) }) pgm } @@ -102,6 +97,14 @@ object ImperativeCodeElimination extends Pass { } //pure expression (that could still contain side effects as a subexpression) + case Let(id, e, b) => { + val (bodyRes, bodyScope, bodyFun) = toFunction(b) + (bodyRes, (b: Expr) => Let(id, e, bodyScope(b)), bodyFun) + } + case LetDef(fd, b) => { + val (bodyRes, bodyScope, bodyFun) = toFunction(b) + (bodyRes, (b: Expr) => LetDef(fd, bodyScope(b)), bodyFun) + } case n @ NAryOperator(args, recons) => { val (recArgs, scope, fun) = args.foldLeft((Seq[Expr](), (body: Expr) => body, Map[Identifier, Identifier]()))((acc, arg) => { val (accArgs, scope, fun) = acc diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala index 7df4adbec..0e2015591 100644 --- a/src/main/scala/leon/purescala/PrettyPrinter.scala +++ b/src/main/scala/leon/purescala/PrettyPrinter.scala @@ -80,7 +80,7 @@ object PrettyPrinter { } case LetDef(fd,e) => { sb.append("\n") - pp(fd, sb, lvl) + pp(fd, sb, lvl+1) sb.append("\n") sb.append("\n") ind(sb, lvl) -- GitLab