From 27e7049bab6eb607fc78be3fb859287eb7fab4be Mon Sep 17 00:00:00 2001 From: Manos Koukoutos <emmanouil.koukoutos@epfl.ch> Date: Fri, 19 Jun 2015 17:21:09 +0200 Subject: [PATCH] Fix and generalize genericTransform, fix bug in liftLets --- src/main/scala/leon/purescala/ExprOps.scala | 28 ++++++++++++------- src/main/scala/leon/synthesis/CostModel.scala | 2 +- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/src/main/scala/leon/purescala/ExprOps.scala b/src/main/scala/leon/purescala/ExprOps.scala index b3bd13f5d..9f1dfccc4 100644 --- a/src/main/scala/leon/purescala/ExprOps.scala +++ b/src/main/scala/leon/purescala/ExprOps.scala @@ -676,13 +676,21 @@ object ExprOps { type C = Seq[(Identifier, Expr)] - def lift(e: Expr, defs: C) = e match { - case Let(i, ex, b) => (b, (i, ex) +: defs) + def combiner(e: Expr, defs: Seq[C]): C = (e, defs) match { + case (Let(i, ex, b), Seq(inDef, inBody)) => + inDef ++ ((i, ex) +: inBody) + case _ => + defs.flatten + } + + def noLet(e: Expr, defs: C) = e match { + case Let(_, _, b) => (b, defs) case _ => (e, defs) } - val (bd, defs) = genericTransform[C](noTransformer, lift, _.flatten)(Seq())(e) - defs.foldRight(bd){ case ((id, e), body) => let(id, e, body) } + val (bd, defs) = genericTransform[C](noTransformer, noLet, combiner)(Seq())(e) + + defs.foldRight(bd){ case ((id, e), body) => Let(id, e, body) } } /** @@ -1060,7 +1068,7 @@ object ExprOps { def genericTransform[C](pre: (Expr, C) => (Expr, C), post: (Expr, C) => (Expr, C), - combiner: (Seq[C]) => C)(init: C)(expr: Expr) = { + combiner: (Expr, Seq[C]) => C)(init: C)(expr: Expr) = { def rec(eIn: Expr, cIn: C): (Expr, C) = { @@ -1068,26 +1076,26 @@ object ExprOps { val (newExpr, newC) = expr match { case t: Terminal => - (expr, cIn) + (expr, ctx) case UnaryOperator(e, builder) => val (e1, c) = rec(e, ctx) val newE = builder(e1).copiedFrom(expr) - (newE, combiner(Seq(c))) + (newE, combiner(newE, Seq(c))) case BinaryOperator(e1, e2, builder) => val (ne1, c1) = rec(e1, ctx) val (ne2, c2) = rec(e2, ctx) val newE = builder(ne1, ne2).copiedFrom(expr) - (newE, combiner(Seq(c1, c2))) + (newE, combiner(newE, Seq(c1, c2))) case NAryOperator(es, builder) => val (nes, cs) = es.map{ rec(_, ctx)}.unzip val newE = builder(nes).copiedFrom(expr) - (newE, combiner(cs)) + (newE, combiner(newE, cs)) case e => sys.error("Expression "+e+" ["+e.getClass+"] is not extractable") @@ -1099,7 +1107,7 @@ object ExprOps { rec(expr, init) } - private def noCombiner(subCs: Seq[Unit]) = () + private def noCombiner(e: Expr, subCs: Seq[Unit]) = () private def noTransformer[C](e: Expr, c: C) = (e, c) def simpleTransform(pre: Expr => Expr, post: Expr => Expr)(expr: Expr) = { diff --git a/src/main/scala/leon/synthesis/CostModel.scala b/src/main/scala/leon/synthesis/CostModel.scala index 3c31731ef..6a76522b0 100644 --- a/src/main/scala/leon/synthesis/CostModel.scala +++ b/src/main/scala/leon/synthesis/CostModel.scala @@ -112,7 +112,7 @@ case object WeightedBranchesCostModel extends SizeBasedCostModel("WeightedBranch (e, bc) } - def combiner(cs: Seq[BC]) = { + def combiner(e: Expr, cs: Seq[BC]) = { cs.foldLeft(BC(0,0))((bc1, bc2) => BC(bc1.cost + bc2.cost, 0)) } -- GitLab