From 0db9dc5bd509e30f96530997c973fafaa1fc7b8c Mon Sep 17 00:00:00 2001 From: Nicolas Voirol <voirol.nicolas@gmail.com> Date: Thu, 7 Apr 2016 15:59:49 +0200 Subject: [PATCH] Finished up Bags in leon --- library/theories/Bag.scala | 2 - .../purescala/DefinitionTransformer.scala | 2 +- .../scala/leon/purescala/Expressions.scala | 31 ++++++++++--- .../scala/leon/purescala/PrettyPrinter.scala | 8 ++++ .../scala/leon/purescala/ScalaPrinter.scala | 7 +++ src/main/scala/leon/purescala/Types.scala | 1 + .../solvers/unrolling/UnrollingSolver.scala | 4 +- .../purescala/valid/MergeSort2.scala | 43 +++++++++++++++++++ 8 files changed, 89 insertions(+), 9 deletions(-) create mode 100644 src/test/resources/regression/verification/purescala/valid/MergeSort2.scala diff --git a/library/theories/Bag.scala b/library/theories/Bag.scala index 833b142f8..53089dcd2 100644 --- a/library/theories/Bag.scala +++ b/library/theories/Bag.scala @@ -6,8 +6,6 @@ import leon.annotation._ @library sealed case class Bag[T](f: T => BigInt) { - require(forall((x: T) => f(x) >= 0)) - def get(x: T): BigInt = f(x) def add(elem: T): Bag[T] = Bag((x: T) => f(x) + (if (x == elem) 1 else 0)) def union(that: Bag[T]): Bag[T] = Bag((x: T) => f(x) + that.f(x)) diff --git a/src/main/scala/leon/purescala/DefinitionTransformer.scala b/src/main/scala/leon/purescala/DefinitionTransformer.scala index 89c5b613b..6be2448ad 100644 --- a/src/main/scala/leon/purescala/DefinitionTransformer.scala +++ b/src/main/scala/leon/purescala/DefinitionTransformer.scala @@ -97,7 +97,7 @@ class DefinitionTransformer( for (fd <- requiredFds) { val nfd = fdMap.toB(fd) - fd.fullBody = transform(fd.fullBody)((fd.params zip nfd.params).map(p => p._1.id -> p._2.id).toMap) + nfd.fullBody = transform(fd.fullBody)((fd.params zip nfd.params).map(p => p._1.id -> p._2.id).toMap) } } } diff --git a/src/main/scala/leon/purescala/Expressions.scala b/src/main/scala/leon/purescala/Expressions.scala index 61640bce6..3855813de 100644 --- a/src/main/scala/leon/purescala/Expressions.scala +++ b/src/main/scala/leon/purescala/Expressions.scala @@ -823,11 +823,20 @@ object Expressions { } /** $encodingof `set + elem` */ case class SetAdd(set: Expr, elem: Expr) extends Expr { - val getType = set.getType + val getType = { + val base = set.getType match { + case SetType(base) => base + case _ => Untyped + } + checkParamTypes(Seq(elem.getType), Seq(base), SetType(base).unveilUntyped) + } } /** $encodingof `set.contains(element)` or `set(element)` */ case class ElementOfSet(element: Expr, set: Expr) extends Expr { - val getType = BooleanType + val getType = checkParamTypes(Seq(element.getType), Seq(set.getType match { + case SetType(base) => base + case _ => Untyped + }), BooleanType) } /** $encodingof `set.length` */ case class SetCardinality(set: Expr) extends Expr { @@ -835,7 +844,10 @@ object Expressions { } /** $encodingof `set.subsetOf(set2)` */ case class SubsetOf(set1: Expr, set2: Expr) extends Expr { - val getType = BooleanType + val getType = (set1.getType, set2.getType) match { + case (SetType(b1), SetType(b2)) if b1 == b2 => BooleanType + case _ => Untyped + } } /** $encodingof `set & set2` */ case class SetIntersection(set1: Expr, set2: Expr) extends Expr { @@ -857,11 +869,20 @@ object Expressions { } /** $encodingof `bag + elem` */ case class BagAdd(bag: Expr, elem: Expr) extends Expr { - val getType = bag.getType + val getType = { + val base = bag.getType match { + case BagType(base) => base + case _ => Untyped + } + checkParamTypes(Seq(base), Seq(elem.getType), BagType(base).unveilUntyped) + } } /** $encodingof `bag.get(element)` or `bag(element)` */ case class MultiplicityInBag(element: Expr, bag: Expr) extends Expr { - val getType = IntegerType + val getType = checkParamTypes(Seq(element.getType), Seq(bag.getType match { + case BagType(base) => base + case _ => Untyped + }), IntegerType) } /** $encodingof `bag1 & bag2` */ case class BagIntersection(bag1: Expr, bag2: Expr) extends Expr { diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala index f1b80fd9d..38644fc4a 100644 --- a/src/main/scala/leon/purescala/PrettyPrinter.scala +++ b/src/main/scala/leon/purescala/PrettyPrinter.scala @@ -321,16 +321,23 @@ class PrettyPrinter(opts: PrinterOptions, case RealTimes(l,r) => optP { p"$l * $r" } case RealDivision(l,r) => optP { p"$l / $r" } case fs @ FiniteSet(rs, _) => p"{${rs.toSeq}}" + case fs @ FiniteBag(rs, _) => p"{$rs}" case fm @ FiniteMap(rs, _, _) => p"{$rs}" case Not(ElementOfSet(e,s)) => p"$e \u2209 $s" case ElementOfSet(e,s) => p"$e \u2208 $s" case SubsetOf(l,r) => p"$l \u2286 $r" case Not(SubsetOf(l,r)) => p"$l \u2288 $r" + case SetAdd(s,e) => p"$s \u222A {$e}" case SetUnion(l,r) => p"$l \u222A $r" + case BagUnion(l,r) => p"$l \u222A $r" case MapUnion(l,r) => p"$l \u222A $r" case SetDifference(l,r) => p"$l \\ $r" + case BagDifference(l,r) => p"$l \\ $r" case SetIntersection(l,r) => p"$l \u2229 $r" + case BagIntersection(l,r) => p"$l \u2229 $r" case SetCardinality(s) => p"$s.size" + case BagAdd(b,e) => p"$b + $e" + case MultiplicityInBag(e, b) => p"$b($e)" case MapApply(m,k) => p"$m($k)" case MapIsDefinedAt(m,k) => p"$m.isDefinedAt($k)" case ArrayLength(a) => p"$a.length" @@ -464,6 +471,7 @@ class PrettyPrinter(opts: PrinterOptions, case StringType => p"String" case ArrayType(bt) => p"Array[$bt]" case SetType(bt) => p"Set[$bt]" + case BagType(bt) => p"Bag[$bt]" case MapType(ft,tt) => p"Map[$ft, $tt]" case TupleType(tpes) => p"($tpes)" case FunctionType(fts, tt) => p"($fts) => $tt" diff --git a/src/main/scala/leon/purescala/ScalaPrinter.scala b/src/main/scala/leon/purescala/ScalaPrinter.scala index 03c50eb2e..40dcbdab8 100644 --- a/src/main/scala/leon/purescala/ScalaPrinter.scala +++ b/src/main/scala/leon/purescala/ScalaPrinter.scala @@ -28,6 +28,7 @@ class ScalaPrinter(opts: PrinterOptions, case Choose(pred) => p"choose($pred)" case s @ FiniteSet(rss, t) => p"Set[$t](${rss.toSeq})" + case SetAdd(s,e) => optP { p"$s + $e" } case ElementOfSet(e,s) => p"$s.contains($e)" case SetUnion(l,r) => optP { p"$l ++ $r" } case SetDifference(l,r) => optP { p"$l -- $r" } @@ -35,6 +36,12 @@ class ScalaPrinter(opts: PrinterOptions, case SetCardinality(s) => p"$s.size" case SubsetOf(subset,superset) => p"$subset.subsetOf($superset)" + case b @ FiniteBag(els, t) => p"Bag[$t]($els)" + case BagAdd(s,e) => optP { p"$s + $e" } + case BagUnion(l,r) => optP { p"$l ++ $r" } + case BagDifference(l,r) => optP { p"$l -- $r" } + case BagIntersection(l,r) => optP { p"$l & $r" } + case MapUnion(l,r) => optP { p"$l ++ $r" } case m @ FiniteMap(els, k, v) => p"Map[$k,$v]($els)" diff --git a/src/main/scala/leon/purescala/Types.scala b/src/main/scala/leon/purescala/Types.scala index e2838104f..1536395e9 100644 --- a/src/main/scala/leon/purescala/Types.scala +++ b/src/main/scala/leon/purescala/Types.scala @@ -143,6 +143,7 @@ object Types { case TupleType(ts) => Some((ts, Constructors.tupleTypeWrap _)) case ArrayType(t) => Some((Seq(t), ts => ArrayType(ts.head))) case SetType(t) => Some((Seq(t), ts => SetType(ts.head))) + case BagType(t) => Some((Seq(t), ts => BagType(ts.head))) case MapType(from,to) => Some((Seq(from, to), t => MapType(t(0), t(1)))) case FunctionType(fts, tt) => Some((tt +: fts, ts => FunctionType(ts.tail.toList, ts.head))) /* n-ary operators */ diff --git a/src/main/scala/leon/solvers/unrolling/UnrollingSolver.scala b/src/main/scala/leon/solvers/unrolling/UnrollingSolver.scala index 852e125ed..dfc48fffe 100644 --- a/src/main/scala/leon/solvers/unrolling/UnrollingSolver.scala +++ b/src/main/scala/leon/solvers/unrolling/UnrollingSolver.scala @@ -259,7 +259,9 @@ trait AbstractUnrollingSolver[T] var quantify = false def check[R](clauses: Seq[T])(block: Option[Boolean] => R) = - if (partialModels) solverCheckAssumptions(clauses)(block) else solverCheck(clauses)(block) + if (partialModels || templateGenerator.manager.quantifications.isEmpty) + solverCheckAssumptions(clauses)(block) + else solverCheck(clauses)(block) val timer = context.timers.solvers.check.start() check(encodedAssumptions.toSeq ++ unrollingBank.satisfactionAssumptions) { res => diff --git a/src/test/resources/regression/verification/purescala/valid/MergeSort2.scala b/src/test/resources/regression/verification/purescala/valid/MergeSort2.scala new file mode 100644 index 000000000..0bd0ab35e --- /dev/null +++ b/src/test/resources/regression/verification/purescala/valid/MergeSort2.scala @@ -0,0 +1,43 @@ +/* Copyright 2009-2016 EPFL, Lausanne */ + +import leon.annotation._ +import leon.collection._ +import leon.lang._ + +object MergeSort2 { + + def bag[T](list: List[T]): Bag[T] = list match { + case Nil() => Bag.empty[T] + case Cons(x, xs) => bag(xs) + x + } + + def isSorted(list: List[BigInt]): Boolean = list match { + case Cons(x1, tail @ Cons(x2, _)) => x1 <= x2 && isSorted(tail) + case _ => true + } + + def merge(l1: List[BigInt], l2: List[BigInt]): List[BigInt] = { + require(isSorted(l1) && isSorted(l2)) + + (l1, l2) match { + case (Cons(x, xs), Cons(y, ys)) => + if (x <= y) Cons(x, merge(xs, l2)) + else Cons(y, merge(l1, ys)) + case _ => l1 ++ l2 + } + } ensuring (res => isSorted(res) && bag(res) == bag(l1) ++ bag(l2)) + + def split(list: List[BigInt]): (List[BigInt], List[BigInt]) = (list match { + case Cons(x1, Cons(x2, xs)) => + val (s1, s2) = split(xs) + (Cons(x1, s1), Cons(x2, s2)) + case _ => (list, Nil[BigInt]()) + }) ensuring (res => bag(res._1) ++ bag(res._2) == bag(list)) + + def mergeSort(list: List[BigInt]): List[BigInt] = (list match { + case Cons(_, Cons(_, _)) => + val (s1, s2) = split(list) + merge(mergeSort(s1), mergeSort(s2)) + case _ => list + }) ensuring (res => isSorted(res) && bag(res) == bag(list)) +} -- GitLab