From 0db9dc5bd509e30f96530997c973fafaa1fc7b8c Mon Sep 17 00:00:00 2001
From: Nicolas Voirol <voirol.nicolas@gmail.com>
Date: Thu, 7 Apr 2016 15:59:49 +0200
Subject: [PATCH] Finished up Bags in leon

---
 library/theories/Bag.scala                    |  2 -
 .../purescala/DefinitionTransformer.scala     |  2 +-
 .../scala/leon/purescala/Expressions.scala    | 31 ++++++++++---
 .../scala/leon/purescala/PrettyPrinter.scala  |  8 ++++
 .../scala/leon/purescala/ScalaPrinter.scala   |  7 +++
 src/main/scala/leon/purescala/Types.scala     |  1 +
 .../solvers/unrolling/UnrollingSolver.scala   |  4 +-
 .../purescala/valid/MergeSort2.scala          | 43 +++++++++++++++++++
 8 files changed, 89 insertions(+), 9 deletions(-)
 create mode 100644 src/test/resources/regression/verification/purescala/valid/MergeSort2.scala

diff --git a/library/theories/Bag.scala b/library/theories/Bag.scala
index 833b142f8..53089dcd2 100644
--- a/library/theories/Bag.scala
+++ b/library/theories/Bag.scala
@@ -6,8 +6,6 @@ import leon.annotation._
 
 @library
 sealed case class Bag[T](f: T => BigInt) {
-  require(forall((x: T) => f(x) >= 0))
-
   def get(x: T): BigInt = f(x)
   def add(elem: T): Bag[T] = Bag((x: T) => f(x) + (if (x == elem) 1 else 0))
   def union(that: Bag[T]): Bag[T] = Bag((x: T) => f(x) + that.f(x))
diff --git a/src/main/scala/leon/purescala/DefinitionTransformer.scala b/src/main/scala/leon/purescala/DefinitionTransformer.scala
index 89c5b613b..6be2448ad 100644
--- a/src/main/scala/leon/purescala/DefinitionTransformer.scala
+++ b/src/main/scala/leon/purescala/DefinitionTransformer.scala
@@ -97,7 +97,7 @@ class DefinitionTransformer(
 
     for (fd <- requiredFds) {
       val nfd = fdMap.toB(fd)
-      fd.fullBody = transform(fd.fullBody)((fd.params zip nfd.params).map(p => p._1.id -> p._2.id).toMap)
+      nfd.fullBody = transform(fd.fullBody)((fd.params zip nfd.params).map(p => p._1.id -> p._2.id).toMap)
     }
   }
 }
diff --git a/src/main/scala/leon/purescala/Expressions.scala b/src/main/scala/leon/purescala/Expressions.scala
index 61640bce6..3855813de 100644
--- a/src/main/scala/leon/purescala/Expressions.scala
+++ b/src/main/scala/leon/purescala/Expressions.scala
@@ -823,11 +823,20 @@ object Expressions {
   }
   /** $encodingof `set + elem` */
   case class SetAdd(set: Expr, elem: Expr) extends Expr {
-    val getType = set.getType
+    val getType = {
+      val base = set.getType match {
+        case SetType(base) => base
+        case _ => Untyped
+      }
+      checkParamTypes(Seq(elem.getType), Seq(base), SetType(base).unveilUntyped)
+    }
   }
   /** $encodingof `set.contains(element)` or `set(element)` */
   case class ElementOfSet(element: Expr, set: Expr) extends Expr {
-    val getType = BooleanType
+    val getType = checkParamTypes(Seq(element.getType), Seq(set.getType match {
+      case SetType(base) => base
+      case _ => Untyped
+    }), BooleanType)
   }
   /** $encodingof `set.length` */
   case class SetCardinality(set: Expr) extends Expr {
@@ -835,7 +844,10 @@ object Expressions {
   }
   /** $encodingof `set.subsetOf(set2)` */
   case class SubsetOf(set1: Expr, set2: Expr) extends Expr {
-    val getType  = BooleanType
+    val getType = (set1.getType, set2.getType) match {
+      case (SetType(b1), SetType(b2)) if b1 == b2 => BooleanType
+      case _ => Untyped
+    }
   }
   /** $encodingof `set & set2` */
   case class SetIntersection(set1: Expr, set2: Expr) extends Expr {
@@ -857,11 +869,20 @@ object Expressions {
   }
   /** $encodingof `bag + elem` */
   case class BagAdd(bag: Expr, elem: Expr) extends Expr {
-    val getType = bag.getType
+    val getType = {
+      val base = bag.getType match {
+        case BagType(base) => base
+        case _ => Untyped
+      }
+      checkParamTypes(Seq(base), Seq(elem.getType), BagType(base).unveilUntyped)
+    }
   }
   /** $encodingof `bag.get(element)` or `bag(element)` */
   case class MultiplicityInBag(element: Expr, bag: Expr) extends Expr {
-    val getType = IntegerType
+    val getType = checkParamTypes(Seq(element.getType), Seq(bag.getType match {
+      case BagType(base) => base
+      case _ => Untyped
+    }), IntegerType)
   }
   /** $encodingof `bag1 & bag2` */
   case class BagIntersection(bag1: Expr, bag2: Expr) extends Expr {
diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala
index f1b80fd9d..38644fc4a 100644
--- a/src/main/scala/leon/purescala/PrettyPrinter.scala
+++ b/src/main/scala/leon/purescala/PrettyPrinter.scala
@@ -321,16 +321,23 @@ class PrettyPrinter(opts: PrinterOptions,
       case RealTimes(l,r)            => optP { p"$l * $r" }
       case RealDivision(l,r)         => optP { p"$l / $r" }
       case fs @ FiniteSet(rs, _)     => p"{${rs.toSeq}}"
+      case fs @ FiniteBag(rs, _)     => p"{$rs}"
       case fm @ FiniteMap(rs, _, _)  => p"{$rs}"
       case Not(ElementOfSet(e,s))    => p"$e \u2209 $s"
       case ElementOfSet(e,s)         => p"$e \u2208 $s"
       case SubsetOf(l,r)             => p"$l \u2286 $r"
       case Not(SubsetOf(l,r))        => p"$l \u2288 $r"
+      case SetAdd(s,e)               => p"$s \u222A {$e}"
       case SetUnion(l,r)             => p"$l \u222A $r"
+      case BagUnion(l,r)             => p"$l \u222A $r"
       case MapUnion(l,r)             => p"$l \u222A $r"
       case SetDifference(l,r)        => p"$l \\ $r"
+      case BagDifference(l,r)        => p"$l \\ $r"
       case SetIntersection(l,r)      => p"$l \u2229 $r"
+      case BagIntersection(l,r)      => p"$l \u2229 $r"
       case SetCardinality(s)         => p"$s.size"
+      case BagAdd(b,e)               => p"$b + $e"
+      case MultiplicityInBag(e, b)   => p"$b($e)"
       case MapApply(m,k)             => p"$m($k)"
       case MapIsDefinedAt(m,k)       => p"$m.isDefinedAt($k)"
       case ArrayLength(a)            => p"$a.length"
@@ -464,6 +471,7 @@ class PrettyPrinter(opts: PrinterOptions,
       case StringType            => p"String"
       case ArrayType(bt)         => p"Array[$bt]"
       case SetType(bt)           => p"Set[$bt]"
+      case BagType(bt)           => p"Bag[$bt]"
       case MapType(ft,tt)        => p"Map[$ft, $tt]"
       case TupleType(tpes)       => p"($tpes)"
       case FunctionType(fts, tt) => p"($fts) => $tt"
diff --git a/src/main/scala/leon/purescala/ScalaPrinter.scala b/src/main/scala/leon/purescala/ScalaPrinter.scala
index 03c50eb2e..40dcbdab8 100644
--- a/src/main/scala/leon/purescala/ScalaPrinter.scala
+++ b/src/main/scala/leon/purescala/ScalaPrinter.scala
@@ -28,6 +28,7 @@ class ScalaPrinter(opts: PrinterOptions,
       case Choose(pred)              => p"choose($pred)"
 
       case s @ FiniteSet(rss, t)     => p"Set[$t](${rss.toSeq})"
+      case SetAdd(s,e)               => optP { p"$s + $e" }
       case ElementOfSet(e,s)         => p"$s.contains($e)"
       case SetUnion(l,r)             => optP { p"$l ++ $r" }
       case SetDifference(l,r)        => optP { p"$l -- $r" }
@@ -35,6 +36,12 @@ class ScalaPrinter(opts: PrinterOptions,
       case SetCardinality(s)         => p"$s.size"
       case SubsetOf(subset,superset) => p"$subset.subsetOf($superset)"
 
+      case b @ FiniteBag(els, t)     => p"Bag[$t]($els)"
+      case BagAdd(s,e)               => optP { p"$s + $e" }
+      case BagUnion(l,r)             => optP { p"$l ++ $r" }
+      case BagDifference(l,r)        => optP { p"$l -- $r" }
+      case BagIntersection(l,r)      => optP { p"$l & $r" }
+
       case MapUnion(l,r)             => optP { p"$l ++ $r" }
       case m @ FiniteMap(els, k, v)  => p"Map[$k,$v]($els)"
 
diff --git a/src/main/scala/leon/purescala/Types.scala b/src/main/scala/leon/purescala/Types.scala
index e2838104f..1536395e9 100644
--- a/src/main/scala/leon/purescala/Types.scala
+++ b/src/main/scala/leon/purescala/Types.scala
@@ -143,6 +143,7 @@ object Types {
       case TupleType(ts) => Some((ts, Constructors.tupleTypeWrap _))
       case ArrayType(t) => Some((Seq(t), ts => ArrayType(ts.head)))
       case SetType(t) => Some((Seq(t), ts => SetType(ts.head)))
+      case BagType(t) => Some((Seq(t), ts => BagType(ts.head)))
       case MapType(from,to) => Some((Seq(from, to), t => MapType(t(0), t(1))))
       case FunctionType(fts, tt) => Some((tt +: fts, ts => FunctionType(ts.tail.toList, ts.head)))
       /* n-ary operators */
diff --git a/src/main/scala/leon/solvers/unrolling/UnrollingSolver.scala b/src/main/scala/leon/solvers/unrolling/UnrollingSolver.scala
index 852e125ed..dfc48fffe 100644
--- a/src/main/scala/leon/solvers/unrolling/UnrollingSolver.scala
+++ b/src/main/scala/leon/solvers/unrolling/UnrollingSolver.scala
@@ -259,7 +259,9 @@ trait AbstractUnrollingSolver[T]
       var quantify = false
 
       def check[R](clauses: Seq[T])(block: Option[Boolean] => R) =
-        if (partialModels) solverCheckAssumptions(clauses)(block) else solverCheck(clauses)(block)
+        if (partialModels || templateGenerator.manager.quantifications.isEmpty)
+          solverCheckAssumptions(clauses)(block)
+        else solverCheck(clauses)(block)
 
       val timer = context.timers.solvers.check.start()
       check(encodedAssumptions.toSeq ++ unrollingBank.satisfactionAssumptions) { res =>
diff --git a/src/test/resources/regression/verification/purescala/valid/MergeSort2.scala b/src/test/resources/regression/verification/purescala/valid/MergeSort2.scala
new file mode 100644
index 000000000..0bd0ab35e
--- /dev/null
+++ b/src/test/resources/regression/verification/purescala/valid/MergeSort2.scala
@@ -0,0 +1,43 @@
+/* Copyright 2009-2016 EPFL, Lausanne */
+
+import leon.annotation._
+import leon.collection._
+import leon.lang._
+
+object MergeSort2 {
+
+  def bag[T](list: List[T]): Bag[T] = list match {
+    case Nil() => Bag.empty[T]
+    case Cons(x, xs) => bag(xs) + x
+  }
+
+  def isSorted(list: List[BigInt]): Boolean = list match {
+    case Cons(x1, tail @ Cons(x2, _)) => x1 <= x2 && isSorted(tail)
+    case _ => true
+  }
+
+  def merge(l1: List[BigInt], l2: List[BigInt]): List[BigInt] = {
+    require(isSorted(l1) && isSorted(l2))
+
+    (l1, l2) match {
+      case (Cons(x, xs), Cons(y, ys)) =>
+        if (x <= y) Cons(x, merge(xs, l2))
+        else Cons(y, merge(l1, ys))
+      case _ => l1 ++ l2
+    }
+  } ensuring (res => isSorted(res) && bag(res) == bag(l1) ++ bag(l2))
+
+  def split(list: List[BigInt]): (List[BigInt], List[BigInt]) = (list match {
+    case Cons(x1, Cons(x2, xs)) =>
+      val (s1, s2) = split(xs)
+      (Cons(x1, s1), Cons(x2, s2))
+    case _ => (list, Nil[BigInt]())
+  }) ensuring (res => bag(res._1) ++ bag(res._2) == bag(list))
+
+  def mergeSort(list: List[BigInt]): List[BigInt] = (list match {
+    case Cons(_, Cons(_, _)) =>
+      val (s1, s2) = split(list)
+      merge(mergeSort(s1), mergeSort(s2))
+    case _ => list
+  }) ensuring (res => isSorted(res) && bag(res) == bag(list))
+}
-- 
GitLab