From 3ea0f667a7aef560ef9ac8e131193611d53330a2 Mon Sep 17 00:00:00 2001
From: Nicolas Voirol <voirol.nicolas@gmail.com>
Date: Tue, 23 Feb 2016 12:05:43 +0100
Subject: [PATCH] Fixed some data-type unrolling issues

---
 .../solvers/combinators/UnrollingSolver.scala |  4 ++++
 .../solvers/templates/DatatypeManager.scala   |  2 +-
 .../solvers/templates/LambdaManager.scala     |  6 ++---
 .../templates/QuantificationManager.scala     | 12 ++++++----
 .../solvers/templates/UnrollingBank.scala     |  2 --
 .../verification/math/RationalProps.scala     | 23 +++++++------------
 6 files changed, 24 insertions(+), 25 deletions(-)

diff --git a/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala b/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala
index 8c57de44e..b7fefa6b2 100644
--- a/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala
+++ b/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala
@@ -268,6 +268,10 @@ trait AbstractUnrollingSolver[T]
   private def getTotalModel: Model = {
     val wrapped = solverGetModel
 
+    templateGenerator.manager.quantifications.map { q =>
+      q.holds
+    }
+
     val typeInsts = templateGenerator.manager.typeInstantiations
     val partialInsts = templateGenerator.manager.partialInstantiations
 
diff --git a/src/main/scala/leon/solvers/templates/DatatypeManager.scala b/src/main/scala/leon/solvers/templates/DatatypeManager.scala
index 365fe85ff..dcfa67e83 100644
--- a/src/main/scala/leon/solvers/templates/DatatypeManager.scala
+++ b/src/main/scala/leon/solvers/templates/DatatypeManager.scala
@@ -200,7 +200,7 @@ class DatatypeManager[T](encoder: TemplateEncoder[T]) extends TemplateManager(en
       andJoin(inv ++ subtype :+ induct)
 
     case TupleType(tpes) =>
-      andJoin(tpes.zipWithIndex.map(p => typeUnroller(TupleSelect(expr, p._2))))
+      andJoin(tpes.zipWithIndex.map(p => typeUnroller(TupleSelect(expr, p._2 + 1))))
 
     case FunctionType(_, _) =>
       FreshFunction(expr)
diff --git a/src/main/scala/leon/solvers/templates/LambdaManager.scala b/src/main/scala/leon/solvers/templates/LambdaManager.scala
index 647c75a69..541888c83 100644
--- a/src/main/scala/leon/solvers/templates/LambdaManager.scala
+++ b/src/main/scala/leon/solvers/templates/LambdaManager.scala
@@ -211,7 +211,7 @@ class LambdaManager[T](encoder: TemplateEncoder[T]) extends DatatypeManager(enco
 
   def registerFunction(b: T, tpe: FunctionType, f: T): Seq[T] = {
     val ft = bestRealType(tpe).asInstanceOf[FunctionType]
-    val bs = fixpoint((bs: Set[T]) => bs.flatMap(blockerParents))(Set(b))
+    val bs = fixpoint((bs: Set[T]) => bs ++ bs.flatMap(blockerParents))(Set(b))
 
     val (known, neqClauses) = if ((bs intersect typeEnablers).nonEmpty) {
       maybeFree += ft -> (maybeFree(ft) + (b -> f))
@@ -262,10 +262,10 @@ class LambdaManager[T](encoder: TemplateEncoder[T]) extends DatatypeManager(enco
         val nextB  = encoder.encodeId(FreshIdentifier("b_or", BooleanType, true))
         freeBlockers += tpe -> (freeBlockers(tpe) + (nextB -> caller))
 
-        val clause = encoder.mkEquals(firstB, encoder.mkOr(
+        val clause = encoder.mkEquals(firstB, encoder.mkAnd(blocker, encoder.mkOr(
           knownFree(tpe).map(idT => encoder.mkEquals(caller, idT)).toSeq ++
           maybeFree(tpe).map { case (b, idT) => encoder.mkAnd(b, encoder.mkEquals(caller, idT)) } :+
-          nextB : _*))
+          nextB : _*)))
         (firstB, Seq(clause))
       }
 
diff --git a/src/main/scala/leon/solvers/templates/QuantificationManager.scala b/src/main/scala/leon/solvers/templates/QuantificationManager.scala
index b53382f6b..cd894b5f2 100644
--- a/src/main/scala/leon/solvers/templates/QuantificationManager.scala
+++ b/src/main/scala/leon/solvers/templates/QuantificationManager.scala
@@ -132,7 +132,8 @@ object QuantificationTemplate {
 }
 
 class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManager[T](encoder) {
-  private val quantifications = new IncrementalSeq[MatcherQuantification]
+  private[solvers] val quantifications = new IncrementalSeq[MatcherQuantification]
+
   private val instCtx         = new InstantiationContext
 
   private val ignoredMatchers = new IncrementalSeq[(Int, Set[T], Matcher[T])]
@@ -344,7 +345,8 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage
     }
   }
 
-  private trait MatcherQuantification {
+  private[solvers] trait MatcherQuantification {
+    val holds: T
     val pathVar: (Identifier, T)
     val quantifiers: Seq[(Identifier, T)]
     val matchers: Set[Matcher[T]]
@@ -472,10 +474,9 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage
         Instantiation.empty[T]
       } else {
         handledSubsts(this) += enablers -> subst
-        val allEnablers = fixpoint((enablers: Set[T]) => enablers.flatMap(blockerParents))(enablers)
 
         var instantiation = Instantiation.empty[T]
-        val (enabler, optEnabler) = freshBlocker(allEnablers)
+        val (enabler, optEnabler) = freshBlocker(enablers)
         if (optEnabler.isDefined) {
           instantiation = instantiation withClause encoder.mkEquals(enabler, optEnabler.get)
         }
@@ -539,6 +540,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage
     val lambdas: Seq[LambdaTemplate[T]]) extends MatcherQuantification {
 
     var currentQ2Var: T = qs._2
+    val holds = qs._2
 
     protected def instanceSubst(enabler: T): Map[T, T] = {
       val nextQ2Var = encoder.encodeId(q2s._1)
@@ -590,6 +592,8 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage
     val applications: Map[T, Set[App[T]]],
     val lambdas: Seq[LambdaTemplate[T]]) extends MatcherQuantification {
 
+    val holds = start
+
     protected def instanceSubst(enabler: T): Map[T, T] = {
       Map(guardVar -> start, blocker -> enabler)
     }
diff --git a/src/main/scala/leon/solvers/templates/UnrollingBank.scala b/src/main/scala/leon/solvers/templates/UnrollingBank.scala
index e66de514a..ba6ed4382 100644
--- a/src/main/scala/leon/solvers/templates/UnrollingBank.scala
+++ b/src/main/scala/leon/solvers/templates/UnrollingBank.scala
@@ -294,7 +294,6 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat
           // we need to define this defBlocker and link it to definition
           val defBlocker = encoder.encodeId(FreshIdentifier("d", BooleanType))
           defBlockers += info -> defBlocker
-          manager.implies(id, defBlocker)
 
           val template = templateGenerator.mkTemplate(tfd)
           //reporter.debug(template)
@@ -348,7 +347,6 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat
         case None =>
           val lambdaBlocker = encoder.encodeId(FreshIdentifier("d", BooleanType))
           lambdaBlockers += info -> lambdaBlocker
-          manager.implies(b, lambdaBlocker)
 
           val (newExprs, callBlocks, appBlocks) = template.instantiate(lambdaBlocker, args)
           val blockExprs = freshAppBlocks(appBlocks.keys)
diff --git a/testcases/verification/math/RationalProps.scala b/testcases/verification/math/RationalProps.scala
index aec07246e..0b13ff1aa 100644
--- a/testcases/verification/math/RationalProps.scala
+++ b/testcases/verification/math/RationalProps.scala
@@ -7,63 +7,57 @@ import scala.language.postfixOps
 object RationalProps {
 
   def squarePos(r: Rational): Rational = {
-    require(r.isRational)
     r * r
   } ensuring(_ >= Rational(0))
 
   def additionIsCommutative(p: Rational, q: Rational): Boolean = {
-    require(p.isRational && q.isRational)
     p + q == q + p
   } holds
 
   def multiplicationIsCommutative(p: Rational, q: Rational): Boolean = {
-    require(p.isRational && q.isRational)
     p * q == q * p
   } holds
 
   def lessThanIsTransitive(p: Rational, q: Rational, r: Rational): Boolean = {
-    require(p.isRational && q.isRational && r.isRational && p < q && q < r)
+    require(p < q && q < r)
     p < r
   } holds
 
   def lessEqualsIsTransitive(p: Rational, q: Rational, r: Rational): Boolean = {
-    require(p.isRational && q.isRational && r.isRational && p <= q && q <= r)
+    require(p <= q && q <= r)
     p <= r
   } holds
 
   def greaterThanIsTransitive(p: Rational, q: Rational, r: Rational): Boolean = {
-    require(p.isRational && q.isRational && r.isRational && p > q && q > r)
+    require(p > q && q > r)
     p > r
   } holds
 
   def greaterEqualsIsTransitive(p: Rational, q: Rational, r: Rational): Boolean = {
-    require(p.isRational && q.isRational && r.isRational && p >= q && q >= r)
+    require(p >= q && q >= r)
     p >= r
   } holds
 
   def distributionMult(p: Rational, q: Rational, r: Rational): Boolean = {
-    require(p.isRational && q.isRational && r.isRational)
     (p*(q + r)) ~ (p*q + p*r)
   } holds
 
   def reciprocalIsCorrect(p: Rational): Boolean = {
-    require(p.isRational && p.nonZero)
+    require(p.nonZero)
     (p * p.reciprocal) ~ Rational(1)
   } holds
 
   def additiveInverseIsCorrect(p: Rational): Boolean = {
-    require(p.isRational)
     (p + (-p)) ~ Rational(0)
   } holds
 
   //should not hold because q could be 0
   def divByZero(p: Rational, q: Rational): Boolean = {
-    require(p.isRational && q.isRational)
     ((p / q) * q) ~ p
   } holds
 
   def divByNonZero(p: Rational, q: Rational): Boolean = {
-    require(p.isRational && q.isRational && q.nonZero)
+    require(q.nonZero)
     ((p / q) * q) ~ p
   } holds
   
@@ -73,17 +67,16 @@ object RationalProps {
    */
 
   def equivalentIsReflexive(p: Rational): Boolean = {
-    require(p.isRational)
     p ~ p
   } holds
 
   def equivalentIsSymmetric(p: Rational, q: Rational): Boolean = {
-    require(p.isRational && q.isRational && p ~ q)
+    require(p ~ q)
     q ~ p
   } holds
 
   def equivalentIsTransitive(p: Rational, q: Rational, r: Rational): Boolean = {
-    require(p.isRational && q.isRational && r.isRational && p ~ q && q ~ r)
+    require(p ~ q && q ~ r)
     p ~ r
   } holds
 }
-- 
GitLab