diff --git a/src/main/scala/leon/purescala/ExprOps.scala b/src/main/scala/leon/purescala/ExprOps.scala index b3bd13f5d380c8955d7068d02ee761aac6b30ed9..9f1dfccc44670e2ed5c2378edf2d29bc49cf88d8 100644 --- a/src/main/scala/leon/purescala/ExprOps.scala +++ b/src/main/scala/leon/purescala/ExprOps.scala @@ -676,13 +676,21 @@ object ExprOps { type C = Seq[(Identifier, Expr)] - def lift(e: Expr, defs: C) = e match { - case Let(i, ex, b) => (b, (i, ex) +: defs) + def combiner(e: Expr, defs: Seq[C]): C = (e, defs) match { + case (Let(i, ex, b), Seq(inDef, inBody)) => + inDef ++ ((i, ex) +: inBody) + case _ => + defs.flatten + } + + def noLet(e: Expr, defs: C) = e match { + case Let(_, _, b) => (b, defs) case _ => (e, defs) } - val (bd, defs) = genericTransform[C](noTransformer, lift, _.flatten)(Seq())(e) - defs.foldRight(bd){ case ((id, e), body) => let(id, e, body) } + val (bd, defs) = genericTransform[C](noTransformer, noLet, combiner)(Seq())(e) + + defs.foldRight(bd){ case ((id, e), body) => Let(id, e, body) } } /** @@ -1060,7 +1068,7 @@ object ExprOps { def genericTransform[C](pre: (Expr, C) => (Expr, C), post: (Expr, C) => (Expr, C), - combiner: (Seq[C]) => C)(init: C)(expr: Expr) = { + combiner: (Expr, Seq[C]) => C)(init: C)(expr: Expr) = { def rec(eIn: Expr, cIn: C): (Expr, C) = { @@ -1068,26 +1076,26 @@ object ExprOps { val (newExpr, newC) = expr match { case t: Terminal => - (expr, cIn) + (expr, ctx) case UnaryOperator(e, builder) => val (e1, c) = rec(e, ctx) val newE = builder(e1).copiedFrom(expr) - (newE, combiner(Seq(c))) + (newE, combiner(newE, Seq(c))) case BinaryOperator(e1, e2, builder) => val (ne1, c1) = rec(e1, ctx) val (ne2, c2) = rec(e2, ctx) val newE = builder(ne1, ne2).copiedFrom(expr) - (newE, combiner(Seq(c1, c2))) + (newE, combiner(newE, Seq(c1, c2))) case NAryOperator(es, builder) => val (nes, cs) = es.map{ rec(_, ctx)}.unzip val newE = builder(nes).copiedFrom(expr) - (newE, combiner(cs)) + (newE, combiner(newE, cs)) case e => sys.error("Expression "+e+" ["+e.getClass+"] is not extractable") @@ -1099,7 +1107,7 @@ object ExprOps { rec(expr, init) } - private def noCombiner(subCs: Seq[Unit]) = () + private def noCombiner(e: Expr, subCs: Seq[Unit]) = () private def noTransformer[C](e: Expr, c: C) = (e, c) def simpleTransform(pre: Expr => Expr, post: Expr => Expr)(expr: Expr) = { diff --git a/src/main/scala/leon/synthesis/CostModel.scala b/src/main/scala/leon/synthesis/CostModel.scala index 3c31731efb9580bfa7867cae5fbf00312ef5850d..6a76522b09b3438a4c2706788b128fad36783919 100644 --- a/src/main/scala/leon/synthesis/CostModel.scala +++ b/src/main/scala/leon/synthesis/CostModel.scala @@ -112,7 +112,7 @@ case object WeightedBranchesCostModel extends SizeBasedCostModel("WeightedBranch (e, bc) } - def combiner(cs: Seq[BC]) = { + def combiner(e: Expr, cs: Seq[BC]) = { cs.foldLeft(BC(0,0))((bc1, bc2) => BC(bc1.cost + bc2.cost, 0)) }