diff --git a/src/main/scala/leon/purescala/GenTreeOps.scala b/src/main/scala/leon/purescala/GenTreeOps.scala index 08404ae1534f2bf62e6ef4d82355d83f9fe7d1ca..8a8645da8045d4f96b43f0d42144b654f3c4c4a0 100644 --- a/src/main/scala/leon/purescala/GenTreeOps.scala +++ b/src/main/scala/leon/purescala/GenTreeOps.scala @@ -186,9 +186,57 @@ trait GenTreeOps[SubTree <: Tree] { } else { f(newV) getOrElse newV } - } + /** Post-transformation of the tree using flattening methods. + * + * Takes a partial function of replacements. + * Substitutes '''after''' recursing down the trees. + * + * Supports two modes : + * + * - If applyRec is false (default), will only substitute once on each level. + * e.g. + * {{{ + * Add(a, Minus(b, c)) with replacements: Minus(b,c) -> z, Minus(e,c) -> d, b -> e + * }}} + * will yield: + * {{{ + * Add(a, Minus(e, c)) + * }}} + * + * - If applyRec is true, it will substitute multiple times on each level: + * e.g. + * {{{ + * Add(a, Minus(b, c)) with replacements: Minus(e,c) -> d, b -> e, d -> f + * }}} + * will yield: + * {{{ + * Add(a, f) + * }}} + * + * @note The mode with applyRec true can diverge if f is not well formed (i.e. not convergent) + */ + def postFlatmap(f: SubTree => Option[Seq[SubTree]], applyRec: Boolean = false)(e: SubTree): Seq[SubTree] = { + val rec = postFlatmap(f, applyRec) _ + + val Deconstructor(es, builder) = e + val newEss = es.map(rec) + val newVs: Seq[SubTree] = leon.utils.SeqUtils.cartesianProduct(newEss).map { newEs => + if ((newEs zip es).exists { case (bef, aft) => aft ne bef }) { + builder(newEs).copiedFrom(e) + } else { + e + } + } + + if (applyRec) { + // Apply f as long as it returns Some() + fixpoint { (e : Seq[SubTree]) => e.flatMap(newV => f(newV) getOrElse Seq(newV)) } (newVs) + } else { + newVs.flatMap((newV: SubTree) => f(newV) getOrElse Seq(newV)) + } + } /** Applies functions and combines results in a generic way *