From dad26d13408a43a028e0b67f1c0c9243c17b10e0 Mon Sep 17 00:00:00 2001
From: Nicolas Voirol <voirol.nicolas@gmail.com>
Date: Wed, 23 Sep 2015 14:22:53 +0200
Subject: [PATCH] Optimization for quantification in posts

---
 .../templates/QuantificationManager.scala     | 65 ++++++++++++++-----
 .../scala/leon/utils/IncrementalSet.scala     |  7 +-
 2 files changed, 53 insertions(+), 19 deletions(-)

diff --git a/src/main/scala/leon/solvers/templates/QuantificationManager.scala b/src/main/scala/leon/solvers/templates/QuantificationManager.scala
index 74ffe69e1..fde9dc746 100644
--- a/src/main/scala/leon/solvers/templates/QuantificationManager.scala
+++ b/src/main/scala/leon/solvers/templates/QuantificationManager.scala
@@ -89,26 +89,49 @@ object QuantificationTemplate {
 }
 
 class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManager[T](encoder) {
+  private lazy val trueT = encoder.encodeExpr(Map.empty)(BooleanLiteral(true))
 
   private val quantifications = new IncrementalSeq[Quantification]
   private val instantiated    = new IncrementalSet[(T, Matcher[T])]
+  private val fInsts          = new IncrementalSet[Matcher[T]]
   private val known           = new IncrementalSet[T]
 
+  private def fInstantiated = fInsts.map(m => trueT -> m)
+
   private def correspond(qm: Matcher[T], m: Matcher[T]): Boolean = correspond(qm, m.caller, m.tpe)
   private def correspond(qm: Matcher[T], caller: T, tpe: TypeTree): Boolean = qm.tpe match {
     case _: FunctionType => qm.tpe == tpe && (qm.caller == caller || !known(caller))
     case _ => qm.tpe == tpe
   }
 
+  private val uniformQuantifiers = scala.collection.mutable.Map.empty[TypeTree, Seq[T]]
+  private def uniformSubst(qs: Seq[(Identifier, T)]): Map[T, T] = {
+    qs.groupBy(_._1.getType).flatMap { case (tpe, qst) =>
+      val prev = uniformQuantifiers.get(tpe) match {
+        case Some(seq) => seq
+        case None => Seq.empty
+      }
+
+      if (prev.size >= qst.size) {
+        qst.map(_._2) zip prev.take(qst.size - 1)
+      } else {
+        val (handled, newQs) = qst.splitAt(prev.size)
+        val uQs = newQs.map(p => p._2 -> encoder.encodeId(p._1))
+        uniformQuantifiers(tpe) = prev ++ uQs.map(_._2)
+        (handled.map(_._2) zip prev) ++ uQs
+      }
+    }.toMap
+  }
+
   override protected def incrementals: List[IncrementalState] =
-    List(quantifications, instantiated, known) ++ super.incrementals
+    List(quantifications, instantiated, fInsts, known) ++ super.incrementals
 
   def assumptions: Seq[T] = quantifications.map(_.currentQ2Var).toSeq
 
-  def instantiations: Seq[(T, Matcher[T])] = instantiated.toSeq
+  def instantiations: Seq[(T, Matcher[T])] = instantiated.toSeq ++ fInstantiated
 
   def instantiations(caller: T, tpe: TypeTree): Seq[(T, Matcher[T])] =
-    instantiated.toSeq.filter { case (b,m) => correspond(m, caller, tpe) }
+    instantiations.filter { case (b,m) => correspond(m, caller, tpe) }
 
   override def registerFree(ids: Seq[(TypeTree, T)]): Unit = {
     super.registerFree(ids)
@@ -133,7 +156,8 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage
     var currentQ2Var: T = qs._2
     private var slaves: Seq[(T, Map[T,T], Quantification)] = Nil
 
-    private def mappings(blocker: T, matcher: Matcher[T]): Set[(T, Map[T, T])] = {
+    private def mappings(blocker: T, matcher: Matcher[T])
+                        (implicit instantiated: Iterable[(T, Matcher[T])]): Set[(T, Map[T, T])] = {
 
       // Build a mapping from applications in the quantified statement to all potential concrete
       // applications previously encountered. Also make sure the current `app` is in the mapping
@@ -204,8 +228,9 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage
       var instantiation = Instantiation.empty[T]
 
       for {
+        instantiated <- List(instantiated, fInstantiated)
         (blocker, matcher) <- instantiated
-        (enabler, subst) <- mappings(blocker, matcher)
+        (enabler, subst) <- mappings(blocker, matcher)(instantiated)
         (slaveEnabler, slaveSubst) = extractSlaveInfo(enabler, senabler, subst, ssubst)
       } instantiation ++= slave.instantiate(slaveEnabler, slaveSubst, Set(this))
 
@@ -214,7 +239,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage
       instantiation
     }
 
-    def instantiate(blocker: T, matcher: Matcher[T]): Instantiation[T] = {
+    def instantiate(blocker: T, matcher: Matcher[T])(implicit instantiated: Iterable[(T, Matcher[T])]): Instantiation[T] = {
       var instantiation = Instantiation.empty[T]
 
       for ((enabler, subst) <- mappings(blocker, matcher)) {
@@ -268,7 +293,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage
     }
 
     val enabler =
-      if (constraints.isEmpty) encoder.encodeExpr(Map.empty)(BooleanLiteral(true))
+      if (constraints.isEmpty) trueT
       else if (constraints.size == 1) constraints.head
       else encoder.mkAnd(constraints : _*)
 
@@ -304,7 +329,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage
       val newQ = encoder.encodeId(template.qs._1)
       val subst = substMap + (template.qs._2 -> newQ)
 
-      val substituter = encoder.substitute(substMap + (template.qs._2 -> newQ))
+      val substituter = encoder.substitute(subst)
       val quantification = new Quantification(template.qs._1 -> newQ,
         template.q2s, template.insts, template.guardVar,
         quantified,
@@ -360,8 +385,8 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage
           instantiation ++= extendClauses(quantification, q)
       }
 
-      for ((b, m) <- instantiated) {
-        instantiation ++= quantification.instantiate(b, m)
+      for (instantiated <- List(instantiated, fInstantiated); (b, m) <- instantiated) {
+        instantiation ++= quantification.instantiate(b, m)(instantiated)
       }
 
       quantifications += quantification
@@ -370,17 +395,25 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage
 
     instantiation = instantiation withClause {
       val newQs =
-        if (qs.isEmpty) encoder.encodeExpr(Map.empty)(BooleanLiteral(true))
+        if (qs.isEmpty) trueT
         else if (qs.size == 1) qs.head
         else encoder.mkAnd(qs : _*)
       encoder.mkImplies(substMap(template.start), encoder.mkEquals(substMap(template.qs._2), newQs))
     }
 
-    val quantifierSubst = substMap ++ template.quantifiers.map { case (id, idT) => idT -> encoder.encodeId(id) }
-    val substituter = encoder.substitute(quantifierSubst)
+    val quantifierSubst = uniformSubst(template.quantifiers)
+    val substituter = encoder.substitute(substMap ++ quantifierSubst)
+
+    for ((_, ms) <- template.matchers; m <- ms) {
+      val sm = m.substitute(substituter)
 
-    for ((b, ms) <- template.matchers; m <- ms) {
-      instantiation ++= instantiateMatcher(substMap(template.start), m.substitute(substituter))
+      if (!fInsts.exists(fm => correspond(sm, fm) && sm.args == fm.args)) {
+        for (q <- quantifications) {
+          instantiation ++= q.instantiate(trueT, sm)(fInstantiated)
+        }
+
+        fInsts += sm
+      }
     }
 
     instantiation
@@ -390,7 +423,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage
     val qInst = if (instantiated(blocker -> matcher)) Instantiation.empty[T] else {
       var instantiation = Instantiation.empty[T]
       for (q <- quantifications) {
-        instantiation ++= q.instantiate(blocker, matcher)
+        instantiation ++= q.instantiate(blocker, matcher)(instantiated)
       }
 
       instantiated += (blocker -> matcher)
diff --git a/src/main/scala/leon/utils/IncrementalSet.scala b/src/main/scala/leon/utils/IncrementalSet.scala
index b88dcf840..163da296c 100644
--- a/src/main/scala/leon/utils/IncrementalSet.scala
+++ b/src/main/scala/leon/utils/IncrementalSet.scala
@@ -4,7 +4,7 @@ package leon.utils
 
 import scala.collection.mutable.{Stack, Set => MSet}
 import scala.collection.mutable.Builder
-import scala.collection.{Iterable, IterableLike}
+import scala.collection.{Iterable, IterableLike, GenSet}
 
 class IncrementalSet[A] extends IncrementalState
                         with Iterable[A]
@@ -12,6 +12,7 @@ class IncrementalSet[A] extends IncrementalState
                         with Builder[A, IncrementalSet[A]] {
 
   private[this] val stack = new Stack[MSet[A]]()
+  override def repr = stack.flatten.toSet
 
   override def clear(): Unit = {
     stack.clear()
@@ -30,8 +31,8 @@ class IncrementalSet[A] extends IncrementalState
     stack.pop()
   }
 
-  def apply(elem: A) = toSet.contains(elem)
-  def contains(elem: A) = toSet.contains(elem)
+  def apply(elem: A) = repr.contains(elem)
+  def contains(elem: A) = repr.contains(elem)
 
   def iterator = stack.flatten.iterator
   def += (elem: A) = { stack.head += elem; this }
-- 
GitLab