diff --git a/src/main/scala/leon/purescala/Path.scala b/src/main/scala/leon/purescala/Path.scala index 0a78f0b6dd7355485f60b9acd88f67cca43f68d0..77eab7b8399e92b3da309eaef8aa781e42697c4e 100644 --- a/src/main/scala/leon/purescala/Path.scala +++ b/src/main/scala/leon/purescala/Path.scala @@ -11,6 +11,7 @@ import ExprOps._ import Types._ object Path { + final type Element = Either[(Identifier, Expr), Expr] def empty: Path = new Path(Seq.empty) def apply(p: Expr): Path = p match { case Let(i, e, b) => new Path(Seq(Left(i -> e))) merge apply(b) @@ -26,18 +27,30 @@ object Path { * in the context of the provided let-bindings. * * This encoding enables let-bindings over types for which equality is - * not defined, whereas previous encoding of let-bindings into equals + * not defined, whereas an encoding of let-bindings with equalities * could introduce non-sensical equations. */ class Path private[purescala]( - private[purescala] val elements: Seq[Either[(Identifier, Expr), Expr]]) extends Printable { + private[purescala] val elements: Seq[Path.Element]) extends Printable { + import Path.Element + /** Add a binding to this [[Path]] */ - def withBinding(p: (Identifier, Expr)) = new Path(elements :+ Left(p)) - def withBindings(ps: Iterable[(Identifier, Expr)]) = new Path(elements ++ ps.map(Left(_))) + def withBinding(p: (Identifier, Expr)) = { + def exprOf(e: Element) = e match { case Right(e) => e; case Left((_, e)) => e } + val (before, after) = elements span (el => !variablesOf(exprOf(el)).contains(p._1)) + new Path(before ++ Seq(Left(p)) ++ after) + } + def withBindings(ps: Iterable[(Identifier, Expr)]) = { + ps.foldLeft(this)( _ withBinding _ ) + } + /** Add a condition to this [[Path]] */ - def withCond(e: Expr) = new Path(elements :+ Right(e)) - def withConds(es: Iterable[Expr]) = new Path(elements ++ es.map(Right(_))) + def withCond(e: Expr) = { + if (e == BooleanLiteral(true)) this + else new Path(elements :+ Right(e)) + } + def withConds(es: Iterable[Expr]) = new Path(elements ++ es.filterNot( _ == BooleanLiteral(true)).map(Right(_))) /** Remove bound variables from this [[Path]] * @param ids the bound variables to remove @@ -131,7 +144,7 @@ class Path private[purescala]( * for proposition expressions. */ private def fold[T](base: T, combineLet: (Identifier, Expr, T) => T, combineCond: (Expr, T) => T) - (elems: Seq[Either[(Identifier, Expr), Expr]]): T = elems.foldRight(base) { + (elems: Seq[Element]): T = elems.foldRight(base) { case (Left((id, e)), res) => combineLet(id, e, res) case (Right(e), res) => combineCond(e, res) }