diff --git a/library/theories/Bag.scala b/library/theories/Bag.scala index 833b142f813c4e37e7213124b9fa9dcd90792d96..53089dcd200b379277feae1c18d7bc98bfbeb581 100644 --- a/library/theories/Bag.scala +++ b/library/theories/Bag.scala @@ -6,8 +6,6 @@ import leon.annotation._ @library sealed case class Bag[T](f: T => BigInt) { - require(forall((x: T) => f(x) >= 0)) - def get(x: T): BigInt = f(x) def add(elem: T): Bag[T] = Bag((x: T) => f(x) + (if (x == elem) 1 else 0)) def union(that: Bag[T]): Bag[T] = Bag((x: T) => f(x) + that.f(x)) diff --git a/src/main/scala/leon/purescala/DefinitionTransformer.scala b/src/main/scala/leon/purescala/DefinitionTransformer.scala index 89c5b613b426f8d6c9b6b2eeaa1273da2496ed40..6be2448adf41b1c2962ef458ca838d04c13b9bdb 100644 --- a/src/main/scala/leon/purescala/DefinitionTransformer.scala +++ b/src/main/scala/leon/purescala/DefinitionTransformer.scala @@ -97,7 +97,7 @@ class DefinitionTransformer( for (fd <- requiredFds) { val nfd = fdMap.toB(fd) - fd.fullBody = transform(fd.fullBody)((fd.params zip nfd.params).map(p => p._1.id -> p._2.id).toMap) + nfd.fullBody = transform(fd.fullBody)((fd.params zip nfd.params).map(p => p._1.id -> p._2.id).toMap) } } } diff --git a/src/main/scala/leon/purescala/Expressions.scala b/src/main/scala/leon/purescala/Expressions.scala index 61640bce6b4fd59de5f2366838dda5f75a947e8d..3855813de7e618bc272c6cc845cb24bb3983a5e3 100644 --- a/src/main/scala/leon/purescala/Expressions.scala +++ b/src/main/scala/leon/purescala/Expressions.scala @@ -823,11 +823,20 @@ object Expressions { } /** $encodingof `set + elem` */ case class SetAdd(set: Expr, elem: Expr) extends Expr { - val getType = set.getType + val getType = { + val base = set.getType match { + case SetType(base) => base + case _ => Untyped + } + checkParamTypes(Seq(elem.getType), Seq(base), SetType(base).unveilUntyped) + } } /** $encodingof `set.contains(element)` or `set(element)` */ case class ElementOfSet(element: Expr, set: Expr) extends Expr { - val getType = BooleanType + val getType = checkParamTypes(Seq(element.getType), Seq(set.getType match { + case SetType(base) => base + case _ => Untyped + }), BooleanType) } /** $encodingof `set.length` */ case class SetCardinality(set: Expr) extends Expr { @@ -835,7 +844,10 @@ object Expressions { } /** $encodingof `set.subsetOf(set2)` */ case class SubsetOf(set1: Expr, set2: Expr) extends Expr { - val getType = BooleanType + val getType = (set1.getType, set2.getType) match { + case (SetType(b1), SetType(b2)) if b1 == b2 => BooleanType + case _ => Untyped + } } /** $encodingof `set & set2` */ case class SetIntersection(set1: Expr, set2: Expr) extends Expr { @@ -857,11 +869,20 @@ object Expressions { } /** $encodingof `bag + elem` */ case class BagAdd(bag: Expr, elem: Expr) extends Expr { - val getType = bag.getType + val getType = { + val base = bag.getType match { + case BagType(base) => base + case _ => Untyped + } + checkParamTypes(Seq(base), Seq(elem.getType), BagType(base).unveilUntyped) + } } /** $encodingof `bag.get(element)` or `bag(element)` */ case class MultiplicityInBag(element: Expr, bag: Expr) extends Expr { - val getType = IntegerType + val getType = checkParamTypes(Seq(element.getType), Seq(bag.getType match { + case BagType(base) => base + case _ => Untyped + }), IntegerType) } /** $encodingof `bag1 & bag2` */ case class BagIntersection(bag1: Expr, bag2: Expr) extends Expr { diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala index f1b80fd9de6df1996d33fdc7936669be6a997b73..38644fc4a896cb2f84eefac53eb6652dd041e8a1 100644 --- a/src/main/scala/leon/purescala/PrettyPrinter.scala +++ b/src/main/scala/leon/purescala/PrettyPrinter.scala @@ -321,16 +321,23 @@ class PrettyPrinter(opts: PrinterOptions, case RealTimes(l,r) => optP { p"$l * $r" } case RealDivision(l,r) => optP { p"$l / $r" } case fs @ FiniteSet(rs, _) => p"{${rs.toSeq}}" + case fs @ FiniteBag(rs, _) => p"{$rs}" case fm @ FiniteMap(rs, _, _) => p"{$rs}" case Not(ElementOfSet(e,s)) => p"$e \u2209 $s" case ElementOfSet(e,s) => p"$e \u2208 $s" case SubsetOf(l,r) => p"$l \u2286 $r" case Not(SubsetOf(l,r)) => p"$l \u2288 $r" + case SetAdd(s,e) => p"$s \u222A {$e}" case SetUnion(l,r) => p"$l \u222A $r" + case BagUnion(l,r) => p"$l \u222A $r" case MapUnion(l,r) => p"$l \u222A $r" case SetDifference(l,r) => p"$l \\ $r" + case BagDifference(l,r) => p"$l \\ $r" case SetIntersection(l,r) => p"$l \u2229 $r" + case BagIntersection(l,r) => p"$l \u2229 $r" case SetCardinality(s) => p"$s.size" + case BagAdd(b,e) => p"$b + $e" + case MultiplicityInBag(e, b) => p"$b($e)" case MapApply(m,k) => p"$m($k)" case MapIsDefinedAt(m,k) => p"$m.isDefinedAt($k)" case ArrayLength(a) => p"$a.length" @@ -464,6 +471,7 @@ class PrettyPrinter(opts: PrinterOptions, case StringType => p"String" case ArrayType(bt) => p"Array[$bt]" case SetType(bt) => p"Set[$bt]" + case BagType(bt) => p"Bag[$bt]" case MapType(ft,tt) => p"Map[$ft, $tt]" case TupleType(tpes) => p"($tpes)" case FunctionType(fts, tt) => p"($fts) => $tt" diff --git a/src/main/scala/leon/purescala/ScalaPrinter.scala b/src/main/scala/leon/purescala/ScalaPrinter.scala index 03c50eb2e752ad0dd3827703eff4fc44542e5b2c..40dcbdab8188f71988d78e58ffb200224de57891 100644 --- a/src/main/scala/leon/purescala/ScalaPrinter.scala +++ b/src/main/scala/leon/purescala/ScalaPrinter.scala @@ -28,6 +28,7 @@ class ScalaPrinter(opts: PrinterOptions, case Choose(pred) => p"choose($pred)" case s @ FiniteSet(rss, t) => p"Set[$t](${rss.toSeq})" + case SetAdd(s,e) => optP { p"$s + $e" } case ElementOfSet(e,s) => p"$s.contains($e)" case SetUnion(l,r) => optP { p"$l ++ $r" } case SetDifference(l,r) => optP { p"$l -- $r" } @@ -35,6 +36,12 @@ class ScalaPrinter(opts: PrinterOptions, case SetCardinality(s) => p"$s.size" case SubsetOf(subset,superset) => p"$subset.subsetOf($superset)" + case b @ FiniteBag(els, t) => p"Bag[$t]($els)" + case BagAdd(s,e) => optP { p"$s + $e" } + case BagUnion(l,r) => optP { p"$l ++ $r" } + case BagDifference(l,r) => optP { p"$l -- $r" } + case BagIntersection(l,r) => optP { p"$l & $r" } + case MapUnion(l,r) => optP { p"$l ++ $r" } case m @ FiniteMap(els, k, v) => p"Map[$k,$v]($els)" diff --git a/src/main/scala/leon/purescala/Types.scala b/src/main/scala/leon/purescala/Types.scala index e2838104fd704110fbb06e6234ea1ae2599a20e7..1536395e92798360c70b5ad8fcebf3520520768c 100644 --- a/src/main/scala/leon/purescala/Types.scala +++ b/src/main/scala/leon/purescala/Types.scala @@ -143,6 +143,7 @@ object Types { case TupleType(ts) => Some((ts, Constructors.tupleTypeWrap _)) case ArrayType(t) => Some((Seq(t), ts => ArrayType(ts.head))) case SetType(t) => Some((Seq(t), ts => SetType(ts.head))) + case BagType(t) => Some((Seq(t), ts => BagType(ts.head))) case MapType(from,to) => Some((Seq(from, to), t => MapType(t(0), t(1)))) case FunctionType(fts, tt) => Some((tt +: fts, ts => FunctionType(ts.tail.toList, ts.head))) /* n-ary operators */ diff --git a/src/main/scala/leon/solvers/unrolling/UnrollingSolver.scala b/src/main/scala/leon/solvers/unrolling/UnrollingSolver.scala index 852e125ed88db4327b18cfabc329d31e1366a2f6..dfc48fffed88eadfc3cfb6587819ca2bdce16567 100644 --- a/src/main/scala/leon/solvers/unrolling/UnrollingSolver.scala +++ b/src/main/scala/leon/solvers/unrolling/UnrollingSolver.scala @@ -259,7 +259,9 @@ trait AbstractUnrollingSolver[T] var quantify = false def check[R](clauses: Seq[T])(block: Option[Boolean] => R) = - if (partialModels) solverCheckAssumptions(clauses)(block) else solverCheck(clauses)(block) + if (partialModels || templateGenerator.manager.quantifications.isEmpty) + solverCheckAssumptions(clauses)(block) + else solverCheck(clauses)(block) val timer = context.timers.solvers.check.start() check(encodedAssumptions.toSeq ++ unrollingBank.satisfactionAssumptions) { res => diff --git a/src/test/resources/regression/verification/purescala/valid/MergeSort2.scala b/src/test/resources/regression/verification/purescala/valid/MergeSort2.scala new file mode 100644 index 0000000000000000000000000000000000000000..0bd0ab35ecaef991f618c0938f8446051eabf00e --- /dev/null +++ b/src/test/resources/regression/verification/purescala/valid/MergeSort2.scala @@ -0,0 +1,43 @@ +/* Copyright 2009-2016 EPFL, Lausanne */ + +import leon.annotation._ +import leon.collection._ +import leon.lang._ + +object MergeSort2 { + + def bag[T](list: List[T]): Bag[T] = list match { + case Nil() => Bag.empty[T] + case Cons(x, xs) => bag(xs) + x + } + + def isSorted(list: List[BigInt]): Boolean = list match { + case Cons(x1, tail @ Cons(x2, _)) => x1 <= x2 && isSorted(tail) + case _ => true + } + + def merge(l1: List[BigInt], l2: List[BigInt]): List[BigInt] = { + require(isSorted(l1) && isSorted(l2)) + + (l1, l2) match { + case (Cons(x, xs), Cons(y, ys)) => + if (x <= y) Cons(x, merge(xs, l2)) + else Cons(y, merge(l1, ys)) + case _ => l1 ++ l2 + } + } ensuring (res => isSorted(res) && bag(res) == bag(l1) ++ bag(l2)) + + def split(list: List[BigInt]): (List[BigInt], List[BigInt]) = (list match { + case Cons(x1, Cons(x2, xs)) => + val (s1, s2) = split(xs) + (Cons(x1, s1), Cons(x2, s2)) + case _ => (list, Nil[BigInt]()) + }) ensuring (res => bag(res._1) ++ bag(res._2) == bag(list)) + + def mergeSort(list: List[BigInt]): List[BigInt] = (list match { + case Cons(_, Cons(_, _)) => + val (s1, s2) = split(list) + merge(mergeSort(s1), mergeSort(s2)) + case _ => list + }) ensuring (res => isSorted(res) && bag(res) == bag(list)) +}