From a2280006dc7fc68af65b830e0445f36e302dff5a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Mika=C3=ABl=20Mayer?= <mikael.mayer@epfl.ch>
Date: Fri, 22 Apr 2016 10:25:35 +0200
Subject: [PATCH] Added  postFlatmap for all tree operations.

---
 .../scala/leon/purescala/GenTreeOps.scala     | 50 ++++++++++++++++++-
 1 file changed, 49 insertions(+), 1 deletion(-)

diff --git a/src/main/scala/leon/purescala/GenTreeOps.scala b/src/main/scala/leon/purescala/GenTreeOps.scala
index 08404ae15..8a8645da8 100644
--- a/src/main/scala/leon/purescala/GenTreeOps.scala
+++ b/src/main/scala/leon/purescala/GenTreeOps.scala
@@ -186,9 +186,57 @@ trait GenTreeOps[SubTree <: Tree]  {
     } else {
       f(newV) getOrElse newV
     }
-
   }
 
+  /** Post-transformation of the tree using flattening methods.
+    *
+    * Takes a partial function of replacements.
+    * Substitutes '''after''' recursing down the trees.
+    *
+    * Supports two modes :
+    *
+    *   - If applyRec is false (default), will only substitute once on each level.
+    *   e.g.
+    *   {{{
+    *     Add(a, Minus(b, c)) with replacements: Minus(b,c) -> z, Minus(e,c) -> d, b -> e
+    *   }}}
+    *   will yield:
+    *   {{{
+    *     Add(a, Minus(e, c))
+    *   }}}
+    *
+    *   - If applyRec is true, it will substitute multiple times on each level:
+    *   e.g.
+    *   {{{
+    *     Add(a, Minus(b, c)) with replacements: Minus(e,c) -> d, b -> e, d -> f
+    *   }}}
+    *   will yield:
+    *   {{{
+    *     Add(a, f)
+    *   }}}
+    *
+    * @note The mode with applyRec true can diverge if f is not well formed (i.e. not convergent)
+    */
+  def postFlatmap(f: SubTree => Option[Seq[SubTree]], applyRec: Boolean = false)(e: SubTree): Seq[SubTree] = {
+    val rec = postFlatmap(f, applyRec) _
+
+    val Deconstructor(es, builder) = e
+    val newEss = es.map(rec)
+    val newVs: Seq[SubTree] = leon.utils.SeqUtils.cartesianProduct(newEss).map { newEs =>
+      if ((newEs zip es).exists { case (bef, aft) => aft ne bef }) {
+        builder(newEs).copiedFrom(e)
+      } else {
+        e
+      }
+    }
+
+    if (applyRec) {
+      // Apply f as long as it returns Some()
+      fixpoint { (e : Seq[SubTree]) => e.flatMap(newV => f(newV) getOrElse Seq(newV)) } (newVs)
+    } else {
+      newVs.flatMap((newV: SubTree) => f(newV) getOrElse Seq(newV))
+    }
+  }
 
   /** Applies functions and combines results in a generic way
     *
-- 
GitLab