From f237a077c246206ffe66432ec89c47a4e07c50be Mon Sep 17 00:00:00 2001
From: Nicolas Voirol <voirol.nicolas@gmail.com>
Date: Tue, 15 Nov 2016 09:43:46 +0100
Subject: [PATCH] More simplifications for quantifiers

---
 src/main/scala/inox/ast/Constructors.scala |  6 +--
 src/main/scala/inox/ast/ExprOps.scala      | 14 +++++
 src/main/scala/inox/ast/Expressions.scala  |  2 +-
 src/main/scala/inox/ast/SymbolOps.scala    | 63 ++++++++++++++++++++--
 4 files changed, 77 insertions(+), 8 deletions(-)

diff --git a/src/main/scala/inox/ast/Constructors.scala b/src/main/scala/inox/ast/Constructors.scala
index d9f55e26e..293c70394 100644
--- a/src/main/scala/inox/ast/Constructors.scala
+++ b/src/main/scala/inox/ast/Constructors.scala
@@ -121,10 +121,10 @@ trait Constructors {
     }
 
     var stop = false
-    val simpler = for(e <- flat if !stop && e != BooleanLiteral(true)) yield {
-      if(e == BooleanLiteral(false)) stop = true
+    val simpler = (for (e <- flat if !stop && e != BooleanLiteral(true)) yield {
+      if (e == BooleanLiteral(false)) stop = true
       e
-    }
+    }).distinct
 
     simpler match {
       case Seq()  => BooleanLiteral(true)
diff --git a/src/main/scala/inox/ast/ExprOps.scala b/src/main/scala/inox/ast/ExprOps.scala
index a6c202ba0..65262b2f9 100644
--- a/src/main/scala/inox/ast/ExprOps.scala
+++ b/src/main/scala/inox/ast/ExprOps.scala
@@ -66,6 +66,20 @@ trait ExprOps extends GenTreeOps {
     }(expr)
   }
 
+  /** Freshens all local variables */
+  def freshenLocals(expr: Expr): Expr = {
+    def rec(expr: Expr, bindings: Map[Variable, Variable]): Expr = expr match {
+      case v: Variable => bindings(v)
+      case _ =>
+        val (vs, es, tps, recons) = deconstructor.deconstruct(expr)
+        val newVs = vs.map(_.freshen)
+        val newBindings = bindings ++ (vs zip newVs)
+        recons(newVs, es map (rec(_, newBindings)), tps)
+    }
+
+    rec(expr, variablesOf(expr).map(v => v -> v).toMap)
+  }
+
   /** Returns true if the expression contains a function call */
   def containsFunctionCalls(expr: Expr): Boolean = {
     exists{
diff --git a/src/main/scala/inox/ast/Expressions.scala b/src/main/scala/inox/ast/Expressions.scala
index 55c8aad2a..5f1fd53fa 100644
--- a/src/main/scala/inox/ast/Expressions.scala
+++ b/src/main/scala/inox/ast/Expressions.scala
@@ -130,7 +130,7 @@ trait Expressions { self: Trees =>
 
     def inlined(implicit s: Symbols): Expr = {
       val tfd = this.tfd
-      tfd.withParamSubst(args, tfd.fullBody)
+      exprOps.freshenLocals(tfd.withParamSubst(args, tfd.fullBody))
     }
   }
 
diff --git a/src/main/scala/inox/ast/SymbolOps.scala b/src/main/scala/inox/ast/SymbolOps.scala
index 0099afd17..3ceabcb3f 100644
--- a/src/main/scala/inox/ast/SymbolOps.scala
+++ b/src/main/scala/inox/ast/SymbolOps.scala
@@ -246,10 +246,65 @@ trait SymbolOps { self: TypeOps =>
       fixpoint(inline)(e)
     }
 
-    def inlineQuantifications(e: Expr): Expr = postMap {
-      case Forall(args1, Forall(args2, body)) => Some(Forall(args1 ++ args2, body))
-      case _ => None
-    } (e)
+    def inlineQuantifications(e: Expr): Expr = {
+      def liftForalls(args: Seq[ValDef], es: Seq[Expr], recons: Seq[Expr] => Expr): Forall = {
+        val (allArgs, allBodies) = es.map {
+          case f: Forall =>
+            val Forall(args, body) = freshenLocals(f)
+            (args, body)
+          case e =>
+            (Seq[ValDef](), e)
+        }.unzip
+
+        Forall(args ++ allArgs.flatten, recons(allBodies))
+      }
+      
+      postMap {
+        case Forall(args1, Forall(args2, body)) =>
+          Some(Forall(args1 ++ args2, body))
+
+        case Forall(args, And(es)) =>
+          Some(liftForalls(args, es, andJoin))
+
+        case Forall(args, Or(es)) =>
+          Some(liftForalls(args, es, orJoin))
+
+        case Forall(args, Implies(e1, e2)) =>
+          Some(liftForalls(args, Seq(e1, e2), es => implies(es(0), es(1))))
+
+        case And(es) => Some(andJoin(SeqUtils.groupWhile(es)(_.isInstanceOf[Forall]).map {
+          case Seq(e) => e
+          case foralls =>
+            val pairs = foralls.collect { case Forall(args, body) => (args, body) }
+            val (allArgs, allBodies) = pairs.foldLeft((Seq[ValDef](), Seq[Expr]())) {
+              case ((allArgs, bodies), (args, body)) =>
+                val available = allArgs.groupBy(_.tpe).mapValues(_.sortBy(_.id.uniqueName))
+                val (_, map) = args.foldLeft((available, Map[ValDef, ValDef]())) {
+                  case ((available, map), vd) => available.get(vd.tpe) match {
+                    case Some(x +: xs) =>
+                      val newAvailable = if (xs.isEmpty) {
+                        available - vd.tpe
+                      } else {
+                        available + (vd.tpe -> xs)
+                      }
+                      (newAvailable, map + (vd -> x))
+                    case _ =>
+                      (available, map + (vd -> vd))
+                  }
+                }
+
+                val newBody = replaceFromSymbols(map.mapValues(_.toVariable), body)
+                val newArgs = allArgs ++ map.map(_._2).filterNot(allArgs contains _)
+                val newBodies = bodies :+ newBody
+                (newArgs, newBodies)
+            }
+
+            Forall(allArgs, andJoin(allBodies))
+        }))
+
+        case _ => None
+      } (e)
+    }
 
     /* Weaker variant of disjunctive normal form */
     def normalizeClauses(e: Expr): Expr = e match {
-- 
GitLab