From 69f8416434e6c10cbbdcc277140dfacbbd82a18e Mon Sep 17 00:00:00 2001
From: Manos Koukoutos <emmanouil.koukoutos@epfl.ch>
Date: Tue, 29 Sep 2015 17:12:23 +0200
Subject: [PATCH] EpsilonElimination should create FunDefs with all binders in
 scope as params

---
 src/main/scala/leon/purescala/ExprOps.scala   | 33 +++++++++++++++++++
 .../scala/leon/xlang/EpsilonElimination.scala | 32 +++++++++---------
 2 files changed, 49 insertions(+), 16 deletions(-)

diff --git a/src/main/scala/leon/purescala/ExprOps.scala b/src/main/scala/leon/purescala/ExprOps.scala
index 54da326bd..6a994e454 100644
--- a/src/main/scala/leon/purescala/ExprOps.scala
+++ b/src/main/scala/leon/purescala/ExprOps.scala
@@ -306,8 +306,40 @@ object ExprOps {
     })(expr)
   }
 
+  def preTransformWithBinders(f: (Expr, Set[Identifier]) => Expr, initBinders: Set[Identifier] = Set())(e: Expr) = {
+    import xlang.Expressions.LetVar
+    def rec(binders: Set[Identifier], e: Expr): Expr = (f(e, binders) match {
+      case LetDef(fd, bd) =>
+        fd.fullBody = rec(binders ++ fd.paramIds, fd.fullBody)
+        LetDef(fd, rec(binders, bd))
+      case Let(i, v, b) =>
+        Let(i, rec(binders + i, v), rec(binders + i, b))
+      case LetVar(i, v, b) =>
+        LetVar(i, rec(binders + i, v), rec(binders + i, b))
+      case MatchExpr(scrut, cses) =>
+        MatchExpr(rec(binders, scrut), cses map { case MatchCase(pat, og, rhs) =>
+          val newBs = binders ++ pat.binders
+          MatchCase(pat, og map (rec(newBs, _)), rec(newBs, rhs))
+        })
+      case Passes(in, out, cses) =>
+        Passes(rec(binders, in), rec(binders, out), cses map { case MatchCase(pat, og, rhs) =>
+          val newBs = binders ++ pat.binders
+          MatchCase(pat, og map (rec(newBs, _)), rec(newBs, rhs))
+        })
+      case Lambda(args, bd) =>
+        Lambda(args, rec(binders ++ args.map(_.id), bd))
+      case Forall(args, bd) =>
+        Forall(args, rec(binders ++ args.map(_.id), bd))
+      case Operator(subs, builder) =>
+        builder(subs map (rec(binders, _)))
+    }).copiedFrom(e)
+
+    rec(initBinders, e)
+  }
+
   /** Returns the set of identifiers in an expression */
   def variablesOf(expr: Expr): Set[Identifier] = {
+    import leon.xlang.Expressions.LetVar
     fold[Set[Identifier]] {
       case (e, subs) =>
         val subvs = subs.flatten.toSet
@@ -315,6 +347,7 @@ object ExprOps {
           case Variable(i) => subvs + i
           case LetDef(fd, _) => subvs -- fd.params.map(_.id)
           case Let(i, _, _) => subvs - i
+          case LetVar(i, _, _) => subvs - i
           case MatchExpr(_, cses) => subvs -- cses.flatMap(_.pattern.binders)
           case Passes(_, _, cses) => subvs -- cses.flatMap(_.pattern.binders)
           case Lambda(args, _) => subvs -- args.map(_.id)
diff --git a/src/main/scala/leon/xlang/EpsilonElimination.scala b/src/main/scala/leon/xlang/EpsilonElimination.scala
index 377756889..9898bd337 100644
--- a/src/main/scala/leon/xlang/EpsilonElimination.scala
+++ b/src/main/scala/leon/xlang/EpsilonElimination.scala
@@ -16,24 +16,24 @@ object EpsilonElimination extends UnitPhase[Program] {
 
   def apply(ctx: LeonContext, pgm: Program) = {
 
-    for {
-      fd <- pgm.definedFunctions
-      body <- fd.body
-    } {
-      val newBody = postMap{
-        case eps@Epsilon(pred, tpe) =>
-          val freshName   = FreshIdentifier("epsilon")
-          val newFunDef   = new FunDef(freshName, Nil, tpe, Seq())
-          val epsilonVar  = EpsilonVariable(eps.getPos, tpe)
-          val resId       = FreshIdentifier("res", tpe)
-          val postcondition = replace(Map(epsilonVar -> Variable(resId)), pred)
+    for (fd <- pgm.definedFunctions) {
+      fd.fullBody = preTransformWithBinders({
+        case (eps@Epsilon(pred, tpe), binders) =>
+          val freshName = FreshIdentifier("epsilon")
+          val bSeq = binders.toSeq
+          val freshParams = bSeq.map { _.freshen }
+          val newFunDef = new FunDef(freshName, Nil, tpe, freshParams map (ValDef(_)))
+          val epsilonVar = EpsilonVariable(eps.getPos, tpe)
+          val resId = FreshIdentifier("res", tpe)
+          val eMap: Map[Expr, Expr] = bSeq.zip(freshParams).map {
+            case (from, to) => (Variable(from), Variable(to))
+          }.toMap ++ Seq((epsilonVar, Variable(resId)))
+          val postcondition = replace(eMap, pred)
           newFunDef.postcondition = Some(Lambda(Seq(ValDef(resId)), postcondition))
-          Some(LetDef(newFunDef, FunctionInvocation(newFunDef.typed, Seq())))
+          LetDef(newFunDef, FunctionInvocation(newFunDef.typed, bSeq map Variable))
 
-        case _ =>
-          None
-      }(body)
-      fd.body = Some(newBody)
+        case (other, _) => other
+      }, fd.paramIds.toSet)(fd.fullBody)
     }
   }
 
-- 
GitLab