diff --git a/library/lang/Bag.scala b/library/lang/Bag.scala new file mode 100644 index 0000000000000000000000000000000000000000..2825bd1fa4d5ca5136adb9dd0afc391b10ca5cbb --- /dev/null +++ b/library/lang/Bag.scala @@ -0,0 +1,37 @@ +/* Copyright 2009-2016 EPFL, Lausanne */ + +package leon.lang +import leon.annotation._ + +object Bag { + @library + def empty[T] = Bag[T]() + + @ignore + def apply[T](elems: (T, BigInt)*) = { + new Bag[T](scala.collection.immutable.Map[T, BigInt](elems : _*)) + } +} + +@ignore +case class Bag[T](theBag: scala.collection.immutable.Map[T, BigInt]) { + def +(a: T): Bag[T] = new Bag(theBag + (a -> (theBag.getOrElse(a, BigInt(0)) + 1))) + def ++(that: Bag[T]): Bag[T] = new Bag[T]((theBag.keys ++ that.theBag.keys).toSet.map { (k: T) => + k -> (theBag.getOrElse(k, BigInt(0)) + that.theBag.getOrElse(k, BigInt(0))) + }.toMap) + + def --(that: Bag[T]): Bag[T] = new Bag[T](theBag.flatMap { case (k,v) => + val res = v - that.get(k) + if (res <= 0) Nil else List(k -> res) + }) + + def &(that: Bag[T]): Bag[T] = new Bag[T](theBag.flatMap { case (k,v) => + val res = v min that.get(k) + if (res <= 0) Nil else List(k -> res) + }) + + def get(a: T): BigInt = theBag.getOrElse(a, BigInt(0)) + def apply(a: T): BigInt = get(a) + def isEmpty: Boolean = theBag.isEmpty +} + diff --git a/library/lang/Set.scala b/library/lang/Set.scala index d27fd18bbab181e148c1246dd104f322a69e7c4a..351af3f6abb256229cfe27d3a13ce7ad9c191e17 100644 --- a/library/lang/Set.scala +++ b/library/lang/Set.scala @@ -4,25 +4,25 @@ package leon.lang import leon.annotation._ object Set { - @library - def empty[T] = Set[T]() + @library + def empty[T] = Set[T]() - @ignore - def apply[T](elems: T*) = { - new Set[T](scala.collection.immutable.Set[T](elems : _*)) - } + @ignore + def apply[T](elems: T*) = { + new Set[T](scala.collection.immutable.Set[T](elems : _*)) + } } @ignore case class Set[T](val theSet: scala.collection.immutable.Set[T]) { - def +(a: T): Set[T] = new Set[T](theSet + a) - def ++(a: Set[T]): Set[T] = new Set[T](theSet ++ a.theSet) - def -(a: T): Set[T] = new Set[T](theSet - a) - def --(a: Set[T]): Set[T] = new Set[T](theSet -- a.theSet) - def size: BigInt = theSet.size - def contains(a: T): Boolean = theSet.contains(a) - def isEmpty: Boolean = theSet.isEmpty - def subsetOf(b: Set[T]): Boolean = theSet.subsetOf(b.theSet) - def &(a: Set[T]): Set[T] = new Set[T](theSet & a.theSet) + def +(a: T): Set[T] = new Set[T](theSet + a) + def ++(a: Set[T]): Set[T] = new Set[T](theSet ++ a.theSet) + def -(a: T): Set[T] = new Set[T](theSet - a) + def --(a: Set[T]): Set[T] = new Set[T](theSet -- a.theSet) + def size: BigInt = theSet.size + def contains(a: T): Boolean = theSet.contains(a) + def isEmpty: Boolean = theSet.isEmpty + def subsetOf(b: Set[T]): Boolean = theSet.subsetOf(b.theSet) + def &(a: Set[T]): Set[T] = new Set[T](theSet & a.theSet) } diff --git a/library/theories/Bag.scala b/library/theories/Bag.scala new file mode 100644 index 0000000000000000000000000000000000000000..833b142f813c4e37e7213124b9fa9dcd90792d96 --- /dev/null +++ b/library/theories/Bag.scala @@ -0,0 +1,26 @@ +/* Copyright 2009-2016 EPFL, Lausanne */ + +package leon.theories +import leon.lang.forall +import leon.annotation._ + +@library +sealed case class Bag[T](f: T => BigInt) { + require(forall((x: T) => f(x) >= 0)) + + def get(x: T): BigInt = f(x) + def add(elem: T): Bag[T] = Bag((x: T) => f(x) + (if (x == elem) 1 else 0)) + def union(that: Bag[T]): Bag[T] = Bag((x: T) => f(x) + that.f(x)) + def difference(that: Bag[T]): Bag[T] = Bag((x: T) => { + val res = f(x) - that.f(x) + if (res < 0) 0 else res + }) + + def intersect(that: Bag[T]): Bag[T] = Bag((x: T) => { + val r1 = f(x) + val r2 = that.f(x) + if (r1 > r2) r2 else r1 + }) + + def equals(that: Bag[T]): Boolean = forall((x: T) => f(x) == that.f(x)) +} diff --git a/library/theories/String.scala b/library/theories/String.scala new file mode 100644 index 0000000000000000000000000000000000000000..85e4d90c7e1eca5a72a2d7c74aa0c63437d6c872 --- /dev/null +++ b/library/theories/String.scala @@ -0,0 +1,32 @@ +/* Copyright 2009-2016 EPFL, Lausanne */ + +package leon.theories +import leon.annotation._ + +@library +sealed abstract class String { + def size: BigInt = (this match { + case StringCons(_, tail) => 1 + tail.size + case StringNil() => BigInt(0) + }) ensuring (_ >= 0) + + def concat(that: String): String = this match { + case StringCons(head, tail) => StringCons(head, tail concat that) + case StringNil() => that + } + + def take(i: BigInt): String = this match { + case StringCons(head, tail) if i > 0 => StringCons(head, tail take (i - 1)) + case _ => StringNil() + } + + def drop(i: BigInt): String = this match { + case StringCons(head, tail) if i > 0 => tail drop (i - 1) + case _ => this + } + + def slice(from: BigInt, to: BigInt): String = drop(from).take(to - from) +} + +case class StringCons(head: Char, tail: String) extends String +case class StringNil() extends String diff --git a/src/main/java/leon/codegen/runtime/Bag.java b/src/main/java/leon/codegen/runtime/Bag.java new file mode 100644 index 0000000000000000000000000000000000000000..8276935f780c3bb74b2d1a18213a5121d86dcd85 --- /dev/null +++ b/src/main/java/leon/codegen/runtime/Bag.java @@ -0,0 +1,117 @@ +/* Copyright 2009-2016 EPFL, Lausanne */ + +package leon.codegen.runtime; + +import java.util.Arrays; +import java.util.Iterator; +import java.util.HashMap; + +/** If someone wants to make it an efficient implementation of immutable + * sets, go ahead. */ +public final class Bag { + private final HashMap<Object, Integer> _underlying; + + protected final HashMap<Object, Integer> underlying() { + return _underlying; + } + + public Bag() { + _underlying = new HashMap<Object, Integer>(); + } + + private Bag(HashMap<Object, Integer> u) { + _underlying = u; + } + + // Uses mutation! Useful at building time. + public void add(Object e) { + add(e, 1); + } + + // Uses mutation! Useful at building time. + private void add(Object e, int count) { + _underlying.put(e, get(e) + count); + } + + public Iterator<java.util.Map.Entry<Object, Integer>> getElements() { + return _underlying.entrySet().iterator(); + } + + public int get(Object element) { + Integer r = _underlying.get(element); + if (r == null) return 0; + else return r.intValue(); + } + + public Bag plus(Object e) { + Bag n = new Bag(new HashMap<Object, Integer>(_underlying)); + n.add(e); + return n; + } + + public Bag union(Bag b) { + Bag n = new Bag(new HashMap<Object, Integer>(_underlying)); + for (java.util.Map.Entry<Object, Integer> entry : b.underlying().entrySet()) { + n.add(entry.getKey(), entry.getValue()); + } + return n; + } + + public Bag intersect(Bag b) { + Bag n = new Bag(); + for (java.util.Map.Entry<Object, Integer> entry : _underlying.entrySet()) { + int m = Math.min(entry.getValue(), b.get(entry.getKey())); + if (m > 0) n.add(entry.getKey(), m); + } + return n; + } + + public Bag difference(Bag b) { + Bag n = new Bag(); + for (java.util.Map.Entry<Object, Integer> entry : _underlying.entrySet()) { + int m = entry.getValue() - b.get(entry.getKey()); + if (m > 0) n.add(entry.getKey(), m); + } + return n; + } + + @Override + public boolean equals(Object that) { + if(that == this) return true; + if(!(that instanceof Bag)) return false; + + Bag other = (Bag) that; + for (Object key : _underlying.keySet()) { + if (get(key) != other.get(key)) return false; + } + + for (Object key : other.underlying().keySet()) { + if (get(key) != other.get(key)) return false; + } + + return true; + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append("Bag("); + boolean first = true; + for (java.util.Map.Entry<Object, Integer> entry : _underlying.entrySet()) { + if(!first) { + sb.append(", "); + first = false; + } + sb.append(entry.getKey().toString()); + sb.append(" -> "); + sb.append(entry.getValue().toString()); + } + sb.append(")"); + return sb.toString(); + } + + @Override + public int hashCode() { + return _underlying.hashCode(); + } +} diff --git a/src/main/java/leon/codegen/runtime/Set.java b/src/main/java/leon/codegen/runtime/Set.java index 42cadce43616eeafaaf7133d3ac7b55d229c56fb..c361ba9e662a5b84bbd84391c856ab9853d66860 100644 --- a/src/main/java/leon/codegen/runtime/Set.java +++ b/src/main/java/leon/codegen/runtime/Set.java @@ -23,6 +23,10 @@ public final class Set { _underlying = new HashSet<Object>(Arrays.asList(elements)); } + private Set(HashSet<Object> u) { + _underlying = u; + } + // Uses mutation! Useful at building time. public void add(Object e) { _underlying.add(e); @@ -32,14 +36,16 @@ public final class Set { return _underlying.iterator(); } - private Set(HashSet<Object> u) { - _underlying = u; - } - public boolean contains(Object element) { return _underlying.contains(element); } + public Set plus(Object e) { + Set n = new Set(new HashSet<Object>(_underlying)); + n.add(e); + return n; + } + public boolean subsetOf(Set s) { for(Object o : _underlying) { if(!s.underlying().contains(o)) { diff --git a/src/main/scala/leon/codegen/CodeGeneration.scala b/src/main/scala/leon/codegen/CodeGeneration.scala index 960e57ff8179e4d3f96667bf0c73842754f91781..d615c4a7e0d2e58420a64097442458ceb006207c 100644 --- a/src/main/scala/leon/codegen/CodeGeneration.scala +++ b/src/main/scala/leon/codegen/CodeGeneration.scala @@ -72,6 +72,7 @@ trait CodeGeneration { private[codegen] val TupleClass = "leon/codegen/runtime/Tuple" private[codegen] val SetClass = "leon/codegen/runtime/Set" + private[codegen] val BagClass = "leon/codegen/runtime/Bag" private[codegen] val MapClass = "leon/codegen/runtime/Map" private[codegen] val BigIntClass = "leon/codegen/runtime/BigInt" private[codegen] val RealClass = "leon/codegen/runtime/Real" @@ -126,6 +127,9 @@ trait CodeGeneration { case _ : SetType => "L" + SetClass + ";" + case _ : BagType => + "L" + BagClass + ";" + case _ : MapType => "L" + MapClass + ";" @@ -544,6 +548,11 @@ trait CodeGeneration { ch << InvokeVirtual(SetClass, "add", s"(L$ObjectClass;)V") } + case SetAdd(s, e) => + mkExpr(s, ch) + mkBoxedExpr(e, ch) + ch << InvokeVirtual(SetClass, "plus", s"(L$ObjectClass;)L$SetClass;") + case ElementOfSet(e, s) => mkExpr(s, ch) mkBoxedExpr(e, ch) @@ -573,6 +582,41 @@ trait CodeGeneration { mkExpr(s2, ch) ch << InvokeVirtual(SetClass, "minus", s"(L$SetClass;)L$SetClass;") + // Bags + case FiniteBag(els, _) => + ch << DefaultNew(BagClass) + for((k,v) <- els) { + ch << DUP + mkBoxedExpr(k, ch) + mkExpr(v, ch) + ch << InvokeVirtual(BagClass, "add", s"(L$ObjectClass;I)V") + } + + case BagAdd(b, e) => + mkExpr(b, ch) + mkBoxedExpr(e, ch) + ch << InvokeVirtual(BagClass, "plus", s"(L$ObjectClass;)L$BagClass;") + + case MultiplicityInBag(e, b) => + mkExpr(b, ch) + mkBoxedExpr(e, ch) + ch << InvokeVirtual(BagClass, "get", s"(L$ObjectClass;)I") + + case BagIntersection(b1, b2) => + mkExpr(b1, ch) + mkExpr(b2, ch) + ch << InvokeVirtual(BagClass, "intersect", s"(L$BagClass;)L$BagClass;") + + case BagUnion(b1, b2) => + mkExpr(b1, ch) + mkExpr(b2, ch) + ch << InvokeVirtual(BagClass, "union", s"(L$BagClass;)L$BagClass;") + + case BagDifference(b1, b2) => + mkExpr(b1, ch) + mkExpr(b2, ch) + ch << InvokeVirtual(BagClass, "difference", s"(L$BagClass;)L$BagClass;") + // Maps case FiniteMap(ss, _, _) => ch << DefaultNew(MapClass) diff --git a/src/main/scala/leon/codegen/CompilationUnit.scala b/src/main/scala/leon/codegen/CompilationUnit.scala index e8bd40eb2302fd9ced5fc44d6c18848c4461b867..bde7a7649541ade3c1ea989b4e4a8936e52edb45 100644 --- a/src/main/scala/leon/codegen/CompilationUnit.scala +++ b/src/main/scala/leon/codegen/CompilationUnit.scala @@ -232,6 +232,13 @@ class CompilationUnit(val ctx: LeonContext, } s + case b @ FiniteBag(els, _) => + val b = new leon.codegen.runtime.Bag() + for ((k,v) <- els) { + b.add(valueToJVM(k), valueToJVM(v)) + } + b + case m @ FiniteMap(els, _, _) => val m = new leon.codegen.runtime.Map() for ((k,v) <- els) { @@ -362,6 +369,13 @@ class CompilationUnit(val ctx: LeonContext, case (set: runtime.Set, SetType(b)) => FiniteSet(set.getElements.asScala.map(jvmToValue(_, b)).toSet, b) + case (bag: runtime.Bag, BagType(b)) => + FiniteBag(bag.getElements.asScala.map { entry => + val k = jvmToValue(entry.getKey, b) + val v = jvmToValue(entry.getValue, IntegerType) + (k, v) + }.toMap, b) + case (map: runtime.Map, MapType(from, to)) => val pairs = map.getElements.asScala.map { entry => val k = jvmToValue(entry.getKey, from) diff --git a/src/main/scala/leon/datagen/VanuatooDataGen.scala b/src/main/scala/leon/datagen/VanuatooDataGen.scala index 0b533e631dee5b0a41181262e0d8c8702b40ea62..0c5ad1ac7be71076ff013e43b7f1a6df9b92d846 100644 --- a/src/main/scala/leon/datagen/VanuatooDataGen.scala +++ b/src/main/scala/leon/datagen/VanuatooDataGen.scala @@ -106,6 +106,19 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { constructors += st -> cs cs }) + + case bt @ BagType(sub) => + constructors.getOrElse(bt, { + val cs = for (size <- List(0, 1, 2, 5)) yield { + val subs = (1 to size).flatMap(i => List(sub, IntegerType)).toList + Constructor[Expr, TypeTree](subs, bt, s => FiniteBag(s.grouped(2).map { + case List(k, i @ InfiniteIntegerLiteral(v)) => + k -> (if (v > 0) i else InfiniteIntegerLiteral(-v + 1)) + }.toMap, sub), bt.asString(ctx)+"@"+size) + } + constructors += bt -> cs + cs + }) case tt @ TupleType(parts) => constructors.getOrElse(tt, { @@ -117,7 +130,7 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { case mt @ MapType(from, to) => constructors.getOrElse(mt, { val cs = for (size <- List(0, 1, 2, 5)) yield { - val subs = (1 to size).flatMap(i => List(from, to)).toList + val subs = (1 to size).flatMap(i => List(from, to)).toList Constructor[Expr, TypeTree](subs, mt, s => FiniteMap(s.grouped(2).map(t => (t(0), t(1))).toMap, from, to), mt.asString(ctx)+"@"+size) } constructors += mt -> cs diff --git a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala index 418ac42e62409d18d7ebb39276b9fe40e85b063b..cc1acc3bcc8c12883275ce10825d59989062a21d 100644 --- a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala +++ b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala @@ -186,6 +186,7 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int (lv,rv) match { case (FiniteSet(el1, _),FiniteSet(el2, _)) => BooleanLiteral(el1 == el2) + case (FiniteBag(el1, _),FiniteBag(el2, _)) => BooleanLiteral(el1 == el2) case (FiniteMap(el1, _, _),FiniteMap(el2, _, _)) => BooleanLiteral(el1.toSet == el2.toSet) case (FiniteLambda(m1, d1, _), FiniteLambda(m2, d2, _)) => BooleanLiteral(m1.toSet == m2.toSet && d1 == d2) case _ => BooleanLiteral(lv == rv) @@ -466,40 +467,40 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int case (le,re) => throw EvalError(typeErrorMsg(le, Int32Type)) } + case SetAdd(s1, elem) => + (e(s1), e(elem)) match { + case (FiniteSet(els1, tpe), evElem) => FiniteSet(els1 + evElem, tpe) + case (le, re) => throw EvalError(typeErrorMsg(le, s1.getType)) + } + case SetUnion(s1,s2) => (e(s1), e(s2)) match { - case (f@FiniteSet(els1, _),FiniteSet(els2, _)) => - val SetType(tpe) = f.getType - FiniteSet(els1 ++ els2, tpe) - case (le,re) => throw EvalError(typeErrorMsg(le, s1.getType)) + case (FiniteSet(els1, tpe), FiniteSet(els2, _)) => FiniteSet(els1 ++ els2, tpe) + case (le, re) => throw EvalError(typeErrorMsg(le, s1.getType)) } case SetIntersection(s1,s2) => (e(s1), e(s2)) match { - case (f @ FiniteSet(els1, _), FiniteSet(els2, _)) => - val newElems = els1 intersect els2 - val SetType(tpe) = f.getType - FiniteSet(newElems, tpe) + case (FiniteSet(els1, tpe), FiniteSet(els2, _)) => FiniteSet(els1 intersect els2, tpe) case (le,re) => throw EvalError(typeErrorMsg(le, s1.getType)) } case SetDifference(s1,s2) => (e(s1), e(s2)) match { - case (f @ FiniteSet(els1, _),FiniteSet(els2, _)) => - val SetType(tpe) = f.getType - val newElems = els1 -- els2 - FiniteSet(newElems, tpe) + case (FiniteSet(els1, tpe), FiniteSet(els2, _)) => FiniteSet(els1 -- els2, tpe) case (le,re) => throw EvalError(typeErrorMsg(le, s1.getType)) } case ElementOfSet(el,s) => (e(el), e(s)) match { - case (e, f @ FiniteSet(els, _)) => BooleanLiteral(els.contains(e)) + case (e, FiniteSet(els, _)) => BooleanLiteral(els.contains(e)) case (l,r) => throw EvalError(typeErrorMsg(r, SetType(l.getType))) } + case SubsetOf(s1,s2) => (e(s1), e(s2)) match { - case (f@FiniteSet(els1, _),FiniteSet(els2, _)) => BooleanLiteral(els1.subsetOf(els2)) + case (FiniteSet(els1, _),FiniteSet(els2, _)) => BooleanLiteral(els1.subsetOf(els2)) case (le,re) => throw EvalError(typeErrorMsg(le, s1.getType)) } + case SetCardinality(s) => val sr = e(s) sr match { @@ -510,6 +511,61 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int case f @ FiniteSet(els, base) => FiniteSet(els.map(e), base) + case BagAdd(bag, elem) => (e(bag), e(elem)) match { + case (FiniteBag(els, tpe), evElem) => FiniteBag(els + (evElem -> (els.get(evElem) match { + case Some(InfiniteIntegerLiteral(i)) => InfiniteIntegerLiteral(i + 1) + case Some(i) => throw EvalError(typeErrorMsg(i, IntegerType)) + case None => InfiniteIntegerLiteral(0) + })), tpe) + + case (le, re) => throw EvalError(typeErrorMsg(le, bag.getType)) + } + + case MultiplicityInBag(elem, bag) => (e(elem), e(bag)) match { + case (evElem, FiniteBag(els, tpe)) => els.getOrElse(evElem, InfiniteIntegerLiteral(0)) + case (le, re) => throw EvalError(typeErrorMsg(re, bag.getType)) + } + + case BagIntersection(b1, b2) => (e(b1), e(b2)) match { + case (FiniteBag(els1, tpe), FiniteBag(els2, _)) => FiniteBag(els1.flatMap { case (k, v) => + val i = (v, els2.getOrElse(k, InfiniteIntegerLiteral(0))) match { + case (InfiniteIntegerLiteral(i1), InfiniteIntegerLiteral(i2)) => i1 min i2 + case (le, re) => throw EvalError(typeErrorMsg(le, IntegerType)) + } + + if (i <= 0) None else Some(k -> InfiniteIntegerLiteral(i)) + }, tpe) + + case (le, re) => throw EvalError(typeErrorMsg(le, b1.getType)) + } + + case BagUnion(b1, b2) => (e(b1), e(b2)) match { + case (FiniteBag(els1, tpe), FiniteBag(els2, _)) => FiniteBag((els1.keys ++ els2.keys).toSet.map { (k: Expr) => + k -> ((els1.getOrElse(k, InfiniteIntegerLiteral(0)), els2.getOrElse(k, InfiniteIntegerLiteral(0))) match { + case (InfiniteIntegerLiteral(i1), InfiniteIntegerLiteral(i2)) => InfiniteIntegerLiteral(i1 + i2) + case (le, re) => throw EvalError(typeErrorMsg(le, IntegerType)) + }) + }.toMap, tpe) + + case (le, re) => throw EvalError(typeErrorMsg(le, b1.getType)) + } + + case BagDifference(b1, b2) => (e(b1), e(b2)) match { + case (FiniteBag(els1, tpe), FiniteBag(els2, _)) => FiniteBag(els1.flatMap { case (k, v) => + val i = (v, els2.getOrElse(k, InfiniteIntegerLiteral(0))) match { + case (InfiniteIntegerLiteral(i1), InfiniteIntegerLiteral(i2)) => i1 - i2 + case (le, re) => throw EvalError(typeErrorMsg(le, IntegerType)) + } + + if (i <= 0) None else Some(k -> InfiniteIntegerLiteral(i)) + }, tpe) + + case (le, re) => throw EvalError(typeErrorMsg(le, b1.getType)) + } + + case FiniteBag(els, base) => + FiniteBag(els.map{ case (k, v) => (e(k), e(v)) }, base) + case l @ Lambda(_, _) => val mapping = variablesOf(l).map(id => id -> e(Variable(id))).toMap val newLambda = replaceFromIDs(mapping, l).asInstanceOf[Lambda] diff --git a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala b/src/main/scala/leon/frontends/scalac/ASTExtractors.scala index 7890d7a229da487e8ae8152ae92c7b0b066eb12e..1d107703b1bfe5dff1658bab08b44d0ae84fa2e8 100644 --- a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala +++ b/src/main/scala/leon/frontends/scalac/ASTExtractors.scala @@ -44,6 +44,7 @@ trait ASTExtractors { protected lazy val scalaSetSym = classFromName("scala.collection.immutable.Set") protected lazy val setSym = classFromName("leon.lang.Set") protected lazy val mapSym = classFromName("leon.lang.Map") + protected lazy val bagSym = classFromName("leon.lang.Bag") protected lazy val realSym = classFromName("leon.lang.Real") protected lazy val optionClassSym = classFromName("scala.Option") protected lazy val arraySym = classFromName("scala.Array") @@ -80,6 +81,10 @@ trait ASTExtractors { getResolvedTypeSym(sym) == setSym } + def isBagSym(sym: Symbol) : Boolean = { + getResolvedTypeSym(sym) == bagSym + } + def isRealSym(sym: Symbol) : Boolean = { getResolvedTypeSym(sym) == realSym } @@ -1038,6 +1043,16 @@ trait ASTExtractors { } } + object ExFiniteBag { + def unapply(tree: Apply): Option[(Tree, List[Tree])] = tree match { + case Apply(TypeApply(ExSelected("Bag", "apply"), Seq(tpt)), args) => + Some(tpt, args) + case Apply(TypeApply(ExSelected("leon", "lang", "Bag", "apply"), Seq(tpt)), args) => + Some(tpt, args) + case _ => None + } + } + object ExFiniteMap { def unapply(tree: Apply): Option[(Tree, Tree, List[Tree])] = tree match { case Apply(TypeApply(ExSelected("Map", "apply"), Seq(tptFrom, tptTo)), args) => diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala index 993bc4fd0c0f48b56c18eae3b2889f303306027c..0a80b3b53933a6bb6f5c88d8785ea8f3eb1cebaa 100644 --- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala +++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala @@ -1461,20 +1461,26 @@ trait CodeExtraction extends ASTExtractors { Forall(vds, exBody) case ExFiniteMap(tptFrom, tptTo, args) => - val singletons = args.collect { - case ExTuple(tpes, trees) if trees.size == 2 => - (extractTree(trees(0)), extractTree(trees(1))) - }.toMap - - if (singletons.size != args.size) { - outOfSubsetError(tr, "Some map elements could not be extracted as Tuple2") - } - - FiniteMap(singletons, extractType(tptFrom), extractType(tptTo)) + FiniteMap(args.map { + case ExTuple(tpes, Seq(key, value)) => + (extractTree(key), extractTree(value)) + case tree => + val ex = extractTree(tree) + (TupleSelect(ex, 1), TupleSelect(ex, 2)) + }.toMap, extractType(tptFrom), extractType(tptTo)) case ExFiniteSet(tpt, args) => FiniteSet(args.map(extractTree).toSet, extractType(tpt)) + case ExFiniteBag(tpt, args) => + FiniteBag(args.map { + case ExTuple(tpes, Seq(key, value)) => + (extractTree(key), extractTree(value)) + case tree => + val ex = extractTree(tree) + (TupleSelect(ex, 1), TupleSelect(ex, 2)) + }.toMap, extractType(tpt)) + case ExCaseClassConstruction(tpt, args) => extractType(tpt) match { case cct: CaseClassType => @@ -1669,6 +1675,7 @@ trait CodeExtraction extends ASTExtractors { val id = cct.classDef.fields.find(_.id.name == name.dropRight(2)).get.id FieldAssignment(rec, id, e1) + //String methods case (IsTyped(a1, StringType), "toString", List()) => a1 @@ -1686,6 +1693,8 @@ trait CodeExtraction extends ASTExtractors { SubString(a1, start, StringLength(a1)) case (IsTyped(a1, StringType), "substring", List(IsTyped(start, IntegerType | Int32Type), IsTyped(end, IntegerType | Int32Type))) => SubString(a1, start, end) + + //BigInt methods case (IsTyped(a1, IntegerType), "+", List(IsTyped(a2, IntegerType))) => Plus(a1, a2) @@ -1708,6 +1717,7 @@ trait CodeExtraction extends ASTExtractors { case (IsTyped(a1, IntegerType), "<=", List(IsTyped(a2, IntegerType))) => LessEquals(a1, a2) + //Real methods case (IsTyped(a1, RealType), "+", List(IsTyped(a2, RealType))) => RealPlus(a1, a2) @@ -1726,6 +1736,7 @@ trait CodeExtraction extends ASTExtractors { case (IsTyped(a1, RealType), "<=", List(IsTyped(a2, RealType))) => LessEquals(a1, a2) + // Int methods case (IsTyped(a1, Int32Type), "+", List(IsTyped(a2, Int32Type))) => BVPlus(a1, a2) @@ -1760,6 +1771,7 @@ trait CodeExtraction extends ASTExtractors { case (IsTyped(a1, Int32Type), "<=", List(IsTyped(a2, Int32Type))) => LessEquals(a1, a2) + // Boolean methods case (IsTyped(a1, BooleanType), "&&", List(IsTyped(a2, BooleanType))) => and(a1, a2) @@ -1767,6 +1779,7 @@ trait CodeExtraction extends ASTExtractors { case (IsTyped(a1, BooleanType), "||", List(IsTyped(a2, BooleanType))) => or(a1, a2) + // Set methods case (IsTyped(a1, SetType(b1)), "size", Nil) => SetCardinality(a1) @@ -1777,6 +1790,9 @@ trait CodeExtraction extends ASTExtractors { //case (IsTyped(a1, SetType(b1)), "max", Nil) => // SetMax(a1) + case (IsTyped(a1, SetType(b1)), "+", List(a2)) => + SetAdd(a1, a2) + case (IsTyped(a1, SetType(b1)), "++", List(IsTyped(a2, SetType(b2)))) if b1 == b2 => SetUnion(a1, a2) @@ -1795,6 +1811,27 @@ trait CodeExtraction extends ASTExtractors { case (IsTyped(a1, SetType(b1)), "isEmpty", List()) => Equals(a1, FiniteSet(Set(), b1)) + + // Bag methods + case (IsTyped(a1, BagType(b1)), "+", List(a2)) => + BagAdd(a1, a2) + + case (IsTyped(a1, BagType(b1)), "++", List(IsTyped(a2, BagType(b2)))) if b1 == b2 => + BagUnion(a1, a2) + + case (IsTyped(a1, BagType(b1)), "&", List(IsTyped(a2, BagType(b2)))) if b1 == b2 => + BagIntersection(a1, a2) + + case (IsTyped(a1, BagType(b1)), "--", List(IsTyped(a2, BagType(b2)))) if b1 == b2 => + BagDifference(a1, a2) + + case (IsTyped(a1, BagType(b1)), "get", List(a2)) => + MultiplicityInBag(a2, a1) + + case (IsTyped(a1, BagType(b1)), "isEmpty", List()) => + Equals(a1, FiniteBag(Map(), b1)) + + // Array methods case (IsTyped(a1, ArrayType(vt)), "apply", List(a2)) => ArraySelect(a1, a2) @@ -1926,6 +1963,9 @@ trait CodeExtraction extends ASTExtractors { case TypeRef(_, sym, btt :: Nil) if isSetSym(sym) => SetType(extractType(btt)) + case TypeRef(_, sym, btt :: Nil) if isBagSym(sym) => + BagType(extractType(btt)) + case TypeRef(_, sym, List(ftt,ttt)) if isMapSym(sym) => MapType(extractType(ftt), extractType(ttt)) diff --git a/src/main/scala/leon/purescala/ExprOps.scala b/src/main/scala/leon/purescala/ExprOps.scala index 4f7bc819650c5316ab4dc3ff3bf1cf1203cb9a1d..b917447942ad3996163f5e970e316d0e1aee4b67 100644 --- a/src/main/scala/leon/purescala/ExprOps.scala +++ b/src/main/scala/leon/purescala/ExprOps.scala @@ -856,6 +856,7 @@ object ExprOps extends { val Deconstructor = Operator } with SubTreeOps[Expr] { case BooleanType => BooleanLiteral(false) case UnitType => UnitLiteral() case SetType(baseType) => FiniteSet(Set(), baseType) + case BagType(baseType) => FiniteBag(Map(), baseType) case MapType(fromType, toType) => FiniteMap(Map(), fromType, toType) case TupleType(tpes) => Tuple(tpes.map(simplestValue)) case ArrayType(tpe) => EmptyArray(tpe) diff --git a/src/main/scala/leon/purescala/Expressions.scala b/src/main/scala/leon/purescala/Expressions.scala index 79172eb7a538c298e5b5eb026eb3e8dfca397f3a..61640bce6b4fd59de5f2366838dda5f75a947e8d 100644 --- a/src/main/scala/leon/purescala/Expressions.scala +++ b/src/main/scala/leon/purescala/Expressions.scala @@ -821,6 +821,10 @@ object Expressions { case class FiniteSet(elements: Set[Expr], base: TypeTree) extends Expr { val getType = SetType(base).unveilUntyped } + /** $encodingof `set + elem` */ + case class SetAdd(set: Expr, elem: Expr) extends Expr { + val getType = set.getType + } /** $encodingof `set.contains(element)` or `set(element)` */ case class ElementOfSet(element: Expr, set: Expr) extends Expr { val getType = BooleanType @@ -833,7 +837,7 @@ object Expressions { case class SubsetOf(set1: Expr, set2: Expr) extends Expr { val getType = BooleanType } - /** $encodingof `set.intersect(set2)` */ + /** $encodingof `set & set2` */ case class SetIntersection(set1: Expr, set2: Expr) extends Expr { val getType = leastUpperBound(Seq(set1, set2).map(_.getType)).getOrElse(Untyped).unveilUntyped } @@ -848,26 +852,18 @@ object Expressions { /* Bag operations */ /** $encodingof `Bag[base](elements)` */ - case class FiniteBag(elements: Map[Expr, Int], base: TypeTree) extends Expr { + case class FiniteBag(elements: Map[Expr, Expr], base: TypeTree) extends Expr { val getType = BagType(base).unveilUntyped } + /** $encodingof `bag + elem` */ + case class BagAdd(bag: Expr, elem: Expr) extends Expr { + val getType = bag.getType + } /** $encodingof `bag.get(element)` or `bag(element)` */ case class MultiplicityInBag(element: Expr, bag: Expr) extends Expr { val getType = IntegerType } - /** $encodingof `bag.length` */ - /* - case class BagCardinality(bag: Expr) extends Expr { - val getType = IntegerType - } - */ - /** $encodingof `bag1.subsetOf(bag2)` */ - /* - case class SubbagOf(bag1: Expr, bag2: Expr) extends Expr { - val getType = BooleanType - } - */ - /** $encodingof `bag1.intersect(bag2)` */ + /** $encodingof `bag1 & bag2` */ case class BagIntersection(bag1: Expr, bag2: Expr) extends Expr { val getType = leastUpperBound(Seq(bag1, bag2).map(_.getType)).getOrElse(Untyped).unveilUntyped } @@ -876,12 +872,10 @@ object Expressions { val getType = leastUpperBound(Seq(bag1, bag2).map(_.getType)).getOrElse(Untyped).unveilUntyped } /** $encodingof `bag1 -- bag2` */ - /* - case class SetDifference(bag1: Expr, bag2: Expr) extends Expr { - val getType = leastUpperBound(Seq(set1, set2).map(_.getType)).getOrElse(Untyped).unveilUntyped + case class BagDifference(bag1: Expr, bag2: Expr) extends Expr { + val getType = leastUpperBound(Seq(bag1, bag2).map(_.getType)).getOrElse(Untyped).unveilUntyped } - */ - + // TODO: Add checks for these expressions too diff --git a/src/main/scala/leon/purescala/Extractors.scala b/src/main/scala/leon/purescala/Extractors.scala index 5c09c2d296bcbf7ddeb3518f2c42acdc6a695cb8..ea5b30430005fecbb8eb6c5d8082a415cb1e24e2 100644 --- a/src/main/scala/leon/purescala/Extractors.scala +++ b/src/main/scala/leon/purescala/Extractors.scala @@ -134,6 +134,8 @@ object Extractors { Some(Seq(t1, t2), (es: Seq[Expr]) => RealDivision(es(0), es(1))) case StringConcat(t1, t2) => Some(Seq(t1, t2), (es: Seq[Expr]) => StringConcat(es(0), es(1))) + case SetAdd(t1, t2) => + Some(Seq(t1, t2), (es: Seq[Expr]) => SetAdd(es(0), es(1))) case ElementOfSet(t1, t2) => Some(Seq(t1, t2), (es: Seq[Expr]) => ElementOfSet(es(0), es(1))) case SubsetOf(t1, t2) => @@ -144,12 +146,16 @@ object Extractors { Some(Seq(t1, t2), (es: Seq[Expr]) => SetUnion(es(0), es(1))) case SetDifference(t1, t2) => Some(Seq(t1, t2), (es: Seq[Expr]) => SetDifference(es(0), es(1))) + case BagAdd(e1, e2) => + Some(Seq(e1, e2), (es: Seq[Expr]) => BagAdd(es(0), es(1))) case MultiplicityInBag(e1, e2) => Some(Seq(e1, e2), (es: Seq[Expr]) => MultiplicityInBag(es(0), es(1))) case BagIntersection(e1, e2) => Some(Seq(e1, e2), (es: Seq[Expr]) => BagIntersection(es(0), es(1))) case BagUnion(e1, e2) => Some(Seq(e1, e2), (es: Seq[Expr]) => BagUnion(es(0), es(1))) + case BagDifference(e1, e2) => + Some(Seq(e1, e2), (es: Seq[Expr]) => BagDifference(es(0), es(1))) case mg @ MapApply(t1, t2) => Some(Seq(t1, t2), (es: Seq[Expr]) => MapApply(es(0), es(1))) case MapUnion(t1, t2) => @@ -180,13 +186,22 @@ object Extractors { case FiniteSet(els, base) => Some((els.toSeq, els => FiniteSet(els.toSet, base))) case FiniteBag(els, base) => - val seq = els.toSeq - Some((seq.map(_._1), els => FiniteBag((els zip seq.map(_._2)).toMap, base))) + val subArgs = els.flatMap { case (k, v) => Seq(k, v) }.toSeq + val builder = (as: Seq[Expr]) => { + def rec(kvs: Seq[Expr]): Map[Expr, Expr] = kvs match { + case Seq(k, v, t @ _*) => + Map(k -> v) ++ rec(t) + case Seq() => Map() + case _ => sys.error("odd number of key/value expressions") + } + FiniteBag(rec(as), base) + } + Some((subArgs, builder)) case FiniteMap(args, f, t) => { val subArgs = args.flatMap { case (k, v) => Seq(k, v) }.toSeq val builder = (as: Seq[Expr]) => { def rec(kvs: Seq[Expr]): Map[Expr, Expr] = kvs match { - case Seq(k, v, t@_*) => + case Seq(k, v, t @ _*) => Map(k -> v) ++ rec(t) case Seq() => Map() case _ => sys.error("odd number of key/value expressions") diff --git a/src/main/scala/leon/solvers/cvc4/CVC4UnrollingSolver.scala b/src/main/scala/leon/solvers/cvc4/CVC4UnrollingSolver.scala index 3ddfca8d9524f9052dc5e6e49266006e9ffa2e3b..97b5d4dc91a2efcf8c4c6446a7049f4ef28b42fa 100644 --- a/src/main/scala/leon/solvers/cvc4/CVC4UnrollingSolver.scala +++ b/src/main/scala/leon/solvers/cvc4/CVC4UnrollingSolver.scala @@ -10,4 +10,4 @@ import unrolling._ import theories._ class CVC4UnrollingSolver(context: LeonContext, program: Program, underlying: Solver) - extends UnrollingSolver(context, program, underlying, theories = new NoEncoder) + extends UnrollingSolver(context, program, underlying, theories = new BagEncoder(context, program)) diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Solver.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Solver.scala index 5de74510b465d76cfc152c476cff96321c4758e9..8838ae8ffc6443ba2961f51ecce795004dc98233 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Solver.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Solver.scala @@ -9,7 +9,7 @@ import solvers.theories._ import purescala.Definitions.Program class SMTLIBCVC4Solver(context: LeonContext, program: Program) - extends SMTLIBSolver(context, program, new NoEncoder) + extends SMTLIBSolver(context, program, new BagEncoder(context, program)) with SMTLIBCVC4Target { def targetName = "cvc4" diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala index 7418526c9338fce9c558d118fb947af503ee53ec..cd7eae5e8330bac85f5a554b624bf98576655895 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala @@ -197,9 +197,15 @@ trait SMTLIBTarget extends Interruptible { unsupported(r, "Solver returned a co-finite set which is not supported.") } require(r.keyTpe == base, s"Type error in solver model, expected $base, found ${r.keyTpe}") - FiniteSet(r.elems.keySet, base) + case BagType(base) => + if (r.default != InfiniteIntegerLiteral(0)) { + unsupported(r, "Solver returned an infinite bag which is not supported.") + } + require(r.keyTpe == base, s"Type error in solver model, expected $base, found ${r.keyTpe}") + FiniteBag(r.elems, base) + case RawArrayType(from, to) => r diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Solver.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Solver.scala index 6a58f9f1a294180182595bcb34d35d1a2ffe08a7..43c1643cf89bef425d30fa1c4e6e11eaa02cf3ed 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Solver.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Solver.scala @@ -16,7 +16,7 @@ import _root_.smtlib.theories.Core.{Equals => _, _} import theories._ class SMTLIBZ3Solver(context: LeonContext, program: Program) - extends SMTLIBSolver(context, program, new StringEncoder) + extends SMTLIBSolver(context, program, new StringEncoder(context, program)) with SMTLIBZ3Target { def getProgram: Program = program diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala index 1be3f7ecf930990749565a1c7ca2181cd2a100dd..00b361323de76d5647a124eb65f449161390a15e 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala @@ -9,11 +9,12 @@ import purescala.Expressions._ import purescala.Constructors._ import purescala.Types._ -import _root_.smtlib.parser.Terms.{Identifier => SMTIdentifier, _} +import _root_.smtlib.common._ +import _root_.smtlib.parser.Terms.{Identifier => SMTIdentifier, Let => SMTLet, _} import _root_.smtlib.parser.Commands.{FunDef => SMTFunDef, _} import _root_.smtlib.interpreters.Z3Interpreter import _root_.smtlib.theories.Core.{Equals => SMTEquals, _} -import _root_.smtlib.theories.ArraysEx +import _root_.smtlib.theories._ trait SMTLIBZ3Target extends SMTLIBTarget { @@ -44,6 +45,8 @@ trait SMTLIBZ3Target extends SMTLIBTarget { case SetType(base) => super.declareSort(BooleanType) declareSetSort(base) + case BagType(base) => + declareSort(RawArrayType(base, IntegerType)) case _ => super.declareSort(t) } @@ -112,6 +115,9 @@ trait SMTLIBZ3Target extends SMTLIBTarget { SMTEquals(ArrayMap(SSymbol("implies"), toSMT(ss), toSMT(s)), allTrue) + case SetAdd(s, e) => + ArraysEx.Store(toSMT(s), toSMT(e), True()) + case ElementOfSet(e, s) => ArraysEx.Select(toSMT(s), toSMT(e)) @@ -128,6 +134,39 @@ trait SMTLIBZ3Target extends SMTLIBTarget { case SetIntersection(l, r) => ArrayMap(SSymbol("and"), toSMT(l), toSMT(r)) + case fb @ FiniteBag(elems, base) => + declareSort(fb.getType) + + toSMT(RawArrayValue(base, elems, InfiniteIntegerLiteral(0))) + + case BagAdd(b, e) => + val bid = FreshIdentifier("b", b.getType, true) + val eid = FreshIdentifier("e", e.getType, true) + val (bSym, eSym) = (id2sym(bid), id2sym(eid)) + SMTLet(VarBinding(bSym, toSMT(b)), Seq(VarBinding(eSym, toSMT(e))), ArraysEx.Store(bSym, eSym, + Ints.Add(ArraysEx.Select(bSym, eSym), Ints.NumeralLit(1)))) + + case MultiplicityInBag(e, b) => + ArraysEx.Select(toSMT(b), toSMT(e)) + + case BagUnion(b1, b2) => + val plus = SortedSymbol("+", List(IntegerType, IntegerType).map(declareSort), declareSort(IntegerType)) + ArrayMap(plus, toSMT(b1), toSMT(b2)) + + case BagIntersection(b1, b2) => + val abs = SortedSymbol("abs", List(IntegerType).map(declareSort), declareSort(IntegerType)) + val plus = SortedSymbol("+", List(IntegerType, IntegerType).map(declareSort), declareSort(IntegerType)) + val minus = SortedSymbol("-", List(IntegerType, IntegerType).map(declareSort), declareSort(IntegerType)) + val div = SortedSymbol("/", List(IntegerType, IntegerType).map(declareSort), declareSort(IntegerType)) + + val did = FreshIdentifier("d", b1.getType, true) + val dSym = id2sym(did) + + val all2 = ArrayConst(declareSort(IntegerType), Ints.NumeralLit(2)) + + SMTLet(VarBinding(dSym, ArrayMap(minus, toSMT(b1), toSMT(b2))), Seq(), + ArrayMap(div, ArrayMap(plus, dSym, ArrayMap(abs, dSym)), all2)) + case _ => super.toSMT(e) } @@ -159,8 +198,16 @@ trait SMTLIBZ3Target extends SMTLIBTarget { throw LeonFatalError("Unable to extract "+s) } + protected object SortedSymbol { + def apply(op: String, from: List[Sort], to: Sort) = { + def simpleSort(s: Sort): Boolean = s.subSorts.isEmpty && !s.id.isIndexed + assert(from.forall(simpleSort) && simpleSort(to), "SortedSymbol is only defined for simple sorts") + SList(SSymbol(op), SList(from.map(_.id.symbol): _*), to.id.symbol) + } + } + protected object ArrayMap { - def apply(op: SSymbol, arrs: Term*) = { + def apply(op: SExpr, arrs: Term*) = { FunctionApplication( QualifiedIdentifier(SMTIdentifier(SSymbol("map"), List(op))), arrs diff --git a/src/main/scala/leon/solvers/theories/BagEncoder.scala b/src/main/scala/leon/solvers/theories/BagEncoder.scala index 4ba7fcaa42307ae9c1b65bd3b81c4da03ab62b22..2ae7f2ac8a59b0cc2e96a2cec06b22d1983b79da 100644 --- a/src/main/scala/leon/solvers/theories/BagEncoder.scala +++ b/src/main/scala/leon/solvers/theories/BagEncoder.scala @@ -6,9 +6,113 @@ package theories import purescala.Common._ import purescala.Expressions._ +import purescala.Definitions._ import purescala.Types._ -class BagEncoder(val context: LeonContext) extends TheoryEncoder { - val encoder = new Encoder - val decoder = new Decoder +class BagEncoder(ctx: LeonContext, p: Program) extends TheoryEncoder { + val Bag = p.library.lookupUnique[CaseClassDef]("leon.theories.Bag") + + val Get = p.library.lookupUnique[FunDef]("leon.theories.Bag.get") + val Add = p.library.lookupUnique[FunDef]("leon.theories.Bag.add") + val Union = p.library.lookupUnique[FunDef]("leon.theories.Bag.union") + val Difference = p.library.lookupUnique[FunDef]("leon.theories.Bag.difference") + val Intersect = p.library.lookupUnique[FunDef]("leon.theories.Bag.intersect") + val BagEquals = p.library.lookupUnique[FunDef]("leon.theories.Bag.equals") + + val encoder = new Encoder { + override def transform(e: Expr)(implicit binders: Map[Identifier, Identifier]): Expr = e match { + case FiniteBag(elems, tpe) => + val newTpe = transform(tpe) + val id = FreshIdentifier("x", newTpe, true) + CaseClass(Bag.typed(Seq(newTpe)), Seq(Lambda(Seq(ValDef(id)), + elems.foldRight[Expr](InfiniteIntegerLiteral(0).copiedFrom(e)) { case ((k, v), ite) => + IfExpr(Equals(Variable(id), transform(k)), transform(v), ite).copiedFrom(e) + }))).copiedFrom(e) + + case BagAdd(bag, elem) => + val BagType(base) = bag.getType + FunctionInvocation(Add.typed(Seq(transform(base))), Seq(transform(bag), transform(elem))).copiedFrom(e) + + case MultiplicityInBag(elem, bag) => + val BagType(base) = bag.getType + FunctionInvocation(Get.typed(Seq(transform(base))), Seq(transform(bag), transform(elem))).copiedFrom(e) + + case BagIntersection(b1, b2) => + val BagType(base) = b1.getType + FunctionInvocation(Intersect.typed(Seq(transform(base))), Seq(transform(b1), transform(b2))).copiedFrom(e) + + case BagUnion(b1, b2) => + val BagType(base) = b1.getType + FunctionInvocation(Union.typed(Seq(transform(base))), Seq(transform(b1), transform(b2))).copiedFrom(e) + + case BagDifference(b1, b2) => + val BagType(base) = b1.getType + FunctionInvocation(Difference.typed(Seq(transform(base))), Seq(transform(b1), transform(b2))).copiedFrom(e) + + case Equals(b1, b2) if b1.getType.isInstanceOf[BagType] => + val BagType(base) = b1.getType + FunctionInvocation(BagEquals.typed(Seq(transform(base))), Seq(transform(b1), transform(b2))).copiedFrom(e) + + case _ => super.transform(e) + } + + override def transform(tpe: TypeTree): TypeTree = tpe match { + case BagType(base) => Bag.typed(Seq(transform(base))).copiedFrom(tpe) + case _ => super.transform(tpe) + } + } + + val decoder = new Decoder { + override def transform(e: Expr)(implicit binders: Map[Identifier, Identifier]): Expr = e match { + case cc @ CaseClass(CaseClassType(Bag, Seq(tpe)), args) => + FiniteBag(args(0) match { + case FiniteLambda(mapping, dflt, tpe) => + if (dflt != InfiniteIntegerLiteral(0)) + throw new Unsupported(cc, "Bags can't have default value " + dflt.asString(ctx))(ctx) + mapping.map { case (ks, v) => transform(ks.head) -> transform(v) }.toMap + + case Lambda(Seq(ValDef(id)), body) => + def rec(expr: Expr): Map[Expr, Expr] = expr match { + case IfExpr(Equals(`id`, k), v, elze) => rec(elze) + (transform(k) -> transform(v)) + case InfiniteIntegerLiteral(v) if v == 0 => Map.empty + case _ => throw new Unsupported(expr, "Bags can't have default value " + expr.asString(ctx))(ctx) + } + + rec(body) + + case f => scala.sys.error("Unexpected function " + f.asString(ctx)) + }, transform(tpe)).copiedFrom(e) + + case FunctionInvocation(TypedFunDef(Add, Seq(_)), Seq(bag, elem)) => + BagAdd(transform(bag), transform(elem)).copiedFrom(e) + + case FunctionInvocation(TypedFunDef(Get, Seq(_)), Seq(bag, elem)) => + MultiplicityInBag(transform(elem), transform(bag)).copiedFrom(e) + + case FunctionInvocation(TypedFunDef(Intersect, Seq(_)), Seq(b1, b2)) => + BagIntersection(transform(b1), transform(b2)).copiedFrom(e) + + case FunctionInvocation(TypedFunDef(Union, Seq(_)), Seq(b1, b2)) => + BagUnion(transform(b1), transform(b2)).copiedFrom(e) + + case FunctionInvocation(TypedFunDef(Difference, Seq(_)), Seq(b1, b2)) => + BagDifference(transform(b1), transform(b2)).copiedFrom(e) + + case FunctionInvocation(TypedFunDef(BagEquals, Seq(_)), Seq(b1, b2)) => + Equals(transform(b1), transform(b2)).copiedFrom(e) + + case _ => super.transform(e) + } + + override def transform(tpe: TypeTree): TypeTree = tpe match { + case CaseClassType(Bag, Seq(base)) => BagType(transform(base)).copiedFrom(tpe) + case _ => super.transform(tpe) + } + + override def transform(pat: Pattern): (Pattern, Map[Identifier, Identifier]) = pat match { + case CaseClassPattern(b, CaseClassType(Bag, Seq(tpe)), Seq(sub)) => + throw new Unsupported(pat, "Can't translate Bag case class pattern")(ctx) + case _ => super.transform(pat) + } + } } diff --git a/src/main/scala/leon/solvers/theories/StringEncoder.scala b/src/main/scala/leon/solvers/theories/StringEncoder.scala index 8f33513a64898c17a7a4b2dfda13a037b2d93f5d..5af13c1c0e52e29d3f79654a6cbe30f009630d5d 100644 --- a/src/main/scala/leon/solvers/theories/StringEncoder.scala +++ b/src/main/scala/leon/solvers/theories/StringEncoder.scala @@ -10,114 +10,18 @@ import purescala.Constructors._ import purescala.Types._ import purescala.Definitions._ import leon.utils.Bijection -import leon.purescala.DefOps import leon.purescala.TypeOps -import leon.purescala.Extractors.Operator -import leon.evaluators.EvaluationResults -object StringEcoSystem { - private def withIdentifier[T](name: String, tpe: TypeTree = Untyped)(f: Identifier => T): T = { - val id = FreshIdentifier(name, tpe) - f(id) - } - - private def withIdentifiers[T](name: String, tpe: TypeTree, name2: String, tpe2: TypeTree = Untyped)(f: (Identifier, Identifier) => T): T = { - withIdentifier(name, tpe)(id => withIdentifier(name2, tpe2)(id2 => f(id, id2))) - } - - val StringList = new AbstractClassDef(FreshIdentifier("StringList"), Seq(), None) - val StringListTyped = StringList.typed - val StringCons = withIdentifiers("head", CharType, "tail", StringListTyped){ (head, tail) => - val d = new CaseClassDef(FreshIdentifier("StringCons"), Seq(), Some(StringListTyped), false) - d.setFields(Seq(ValDef(head), ValDef(tail))) - d - } - - StringList.registerChild(StringCons) - val StringConsTyped = StringCons.typed - val StringNil = new CaseClassDef(FreshIdentifier("StringNil"), Seq(), Some(StringListTyped), false) - val StringNilTyped = StringNil.typed - StringList.registerChild(StringNil) - - val StringSize = withIdentifiers("l", StringListTyped, "StringSize"){ (lengthArg, id) => - val fd = new FunDef(id, Seq(), Seq(ValDef(lengthArg)), IntegerType) - fd.body = Some(withIdentifiers("h", CharType, "t", StringListTyped){ (h, t) => - MatchExpr(Variable(lengthArg), Seq( - MatchCase(CaseClassPattern(None, StringNilTyped, Seq()), None, InfiniteIntegerLiteral(BigInt(0))), - MatchCase(CaseClassPattern(None, StringConsTyped, Seq(WildcardPattern(Some(h)), WildcardPattern(Some(t)))), None, - Plus(InfiniteIntegerLiteral(BigInt(1)), FunctionInvocation(fd.typed, Seq(Variable(t))))) - )) - }) - fd - } - - val StringListConcat = withIdentifiers("x", StringListTyped, "y", StringListTyped) { (x, y) => - val fd = new FunDef(FreshIdentifier("StringListConcat"), Seq(), Seq(ValDef(x), ValDef(y)), StringListTyped) - fd.body = Some( - withIdentifiers("h", CharType, "t", StringListTyped){ (h, t) => - MatchExpr(Variable(x), Seq( - MatchCase(CaseClassPattern(None, StringNilTyped, Seq()), None, Variable(y)), - MatchCase(CaseClassPattern(None, StringConsTyped, Seq(WildcardPattern(Some(h)), WildcardPattern(Some(t)))), None, - CaseClass(StringConsTyped, Seq(Variable(h), FunctionInvocation(fd.typed, Seq(Variable(t), Variable(y))))) - ))) - } - ) - fd - } - - val StringTake = withIdentifiers("tt", StringListTyped, "it", StringListTyped) { (tt, it) => - val fd = new FunDef(FreshIdentifier("StringTake"), Seq(), Seq(ValDef(tt), ValDef(it)), StringListTyped) - fd.body = Some{ - withIdentifiers("h", CharType, "t", StringListTyped) { (h, t) => - withIdentifier("i", IntegerType){ i => - MatchExpr(Tuple(Seq(Variable(tt), Variable(it))), Seq( - MatchCase(TuplePattern(None, Seq(CaseClassPattern(None, StringNilTyped, Seq()), WildcardPattern(None))), None, - InfiniteIntegerLiteral(BigInt(0))), - MatchCase(TuplePattern(None, Seq(CaseClassPattern(None, StringConsTyped, Seq(WildcardPattern(Some(h)), WildcardPattern(Some(t)))), WildcardPattern(Some(i)))), None, - IfExpr(LessThan(Variable(i), InfiniteIntegerLiteral(BigInt(0))), - CaseClass(StringNilTyped, Seq()), - CaseClass(StringConsTyped, Seq(Variable(h), - FunctionInvocation(fd.typed, Seq(Variable(t), Minus(Variable(i), InfiniteIntegerLiteral(BigInt(1))))))) - )))) - } - } - } - fd - } - - val StringDrop = withIdentifiers("td", StringListTyped, "id", IntegerType) { (td, id) => - val fd = new FunDef(FreshIdentifier("StringDrop"), Seq(), Seq(ValDef(td), ValDef(id)), StringListTyped) - fd.body = Some( - withIdentifiers("h", CharType, "t", StringListTyped) { (h, t) => - withIdentifier("i", IntegerType){ i => - MatchExpr(Tuple(Seq(Variable(td), Variable(id))), Seq( - MatchCase(TuplePattern(None, Seq(CaseClassPattern(None, StringNilTyped, Seq()), WildcardPattern(None))), None, - InfiniteIntegerLiteral(BigInt(0))), - MatchCase(TuplePattern(None, Seq(CaseClassPattern(None, StringConsTyped, Seq(WildcardPattern(Some(h)), WildcardPattern(Some(t)))), WildcardPattern(Some(i)))), None, - IfExpr(LessThan(Variable(i), InfiniteIntegerLiteral(BigInt(0))), - CaseClass(StringConsTyped, Seq(Variable(h), Variable(t))), - FunctionInvocation(fd.typed, Seq(Variable(t), Minus(Variable(i), InfiniteIntegerLiteral(BigInt(1))))) - )))) - }} - ) - fd - } - - val StringSlice = withIdentifier("s", StringListTyped) { s => withIdentifiers("from", IntegerType, "to", IntegerType) { (from, to) => - val fd = new FunDef(FreshIdentifier("StringSlice"), Seq(), Seq(ValDef(s), ValDef(from), ValDef(to)), StringListTyped) - fd.body = Some( - FunctionInvocation(StringTake.typed, - Seq(FunctionInvocation(StringDrop.typed, Seq(Variable(s), Variable(from))), - Minus(Variable(to), Variable(from))))) - fd - } } - - val classDefs = Seq(StringList, StringCons, StringNil) - val funDefs = Seq(StringSize, StringListConcat, StringTake, StringDrop, StringSlice) -} +class StringEncoder(ctx: LeonContext, p: Program) extends TheoryEncoder { + val String = p.library.lookupUnique[ClassDef]("leon.theories.String").typed + val StringCons = p.library.lookupUnique[CaseClassDef]("leon.theories.StringCons").typed + val StringNil = p.library.lookupUnique[CaseClassDef]("leon.theories.StringNil").typed -class StringEncoder extends TheoryEncoder { - import StringEcoSystem._ + val Size = p.library.lookupUnique[FunDef]("leon.theories.String.size").typed + val Take = p.library.lookupUnique[FunDef]("leon.theories.String.take").typed + val Drop = p.library.lookupUnique[FunDef]("leon.theories.String.drop").typed + val Slice = p.library.lookupUnique[FunDef]("leon.theories.String.slice").typed + val Concat = p.library.lookupUnique[FunDef]("leon.theories.String.concat").typed private val stringBijection = new Bijection[String, Expr]() @@ -127,8 +31,8 @@ class StringEncoder extends TheoryEncoder { }) private def convertFromString(v: String): Expr = stringBijection.cachedB(v) { - v.toList.foldRight(CaseClass(StringNilTyped, Seq())){ - case (char, l) => CaseClass(StringConsTyped, Seq(CharLiteral(char), l)) + v.toList.foldRight(CaseClass(StringNil, Seq())){ + case (char, l) => CaseClass(StringCons, Seq(CharLiteral(char), l)) } } @@ -137,26 +41,26 @@ class StringEncoder extends TheoryEncoder { case StringLiteral(v) => convertFromString(v) case StringLength(a) => - FunctionInvocation(StringSize.typed, Seq(transform(a))).copiedFrom(e) + FunctionInvocation(Size, Seq(transform(a))).copiedFrom(e) case StringConcat(a, b) => - FunctionInvocation(StringListConcat.typed, Seq(transform(a), transform(b))).copiedFrom(e) + FunctionInvocation(Concat, Seq(transform(a), transform(b))).copiedFrom(e) case SubString(a, start, Plus(start2, length)) if start == start2 => - FunctionInvocation(StringTake.typed, Seq(FunctionInvocation(StringDrop.typed, Seq(transform(a), transform(start))), transform(length))).copiedFrom(e) + FunctionInvocation(Take, Seq(FunctionInvocation(Drop, Seq(transform(a), transform(start))), transform(length))).copiedFrom(e) case SubString(a, start, end) => - FunctionInvocation(StringSlice.typed, Seq(transform(a), transform(start), transform(end))).copiedFrom(e) + FunctionInvocation(Slice, Seq(transform(a), transform(start), transform(end))).copiedFrom(e) case _ => super.transform(e) } override def transform(tpe: TypeTree): TypeTree = tpe match { - case StringType => StringListTyped + case StringType => String case _ => super.transform(tpe) } override def transform(pat: Pattern): (Pattern, Map[Identifier, Identifier]) = pat match { case LiteralPattern(binder, StringLiteral(s)) => val newBinder = binder map transform - val newPattern = s.foldRight(CaseClassPattern(None, StringNilTyped, Seq())) { - case (elem, pattern) => CaseClassPattern(None, StringConsTyped, Seq(LiteralPattern(None, CharLiteral(elem)), pattern)) + val newPattern = s.foldRight(CaseClassPattern(None, StringNil, Seq())) { + case (elem, pattern) => CaseClassPattern(None, StringCons, Seq(LiteralPattern(None, CharLiteral(elem)), pattern)) } (newPattern.copy(binder = newBinder), (binder zip newBinder).filter(p => p._1 != p._2).toMap) case _ => super.transform(pat) @@ -165,35 +69,42 @@ class StringEncoder extends TheoryEncoder { val decoder = new Decoder { override def transform(e: Expr)(implicit binders: Map[Identifier, Identifier]): Expr = e match { - case cc @ CaseClass(cct, args) if TypeOps.isSubtypeOf(cct, StringListTyped)=> + case cc @ CaseClass(cct, args) if TypeOps.isSubtypeOf(cct, String)=> StringLiteral(convertToString(cc)).copiedFrom(cc) - case FunctionInvocation(StringSize, Seq(a)) => + case FunctionInvocation(Size, Seq(a)) => StringLength(transform(a)).copiedFrom(e) - case FunctionInvocation(StringListConcat, Seq(a, b)) => + case FunctionInvocation(Concat, Seq(a, b)) => StringConcat(transform(a), transform(b)).copiedFrom(e) - case FunctionInvocation(StringTake, Seq(FunctionInvocation(StringDrop, Seq(a, start)), length)) => + case FunctionInvocation(Slice, Seq(a, from, to)) => + SubString(transform(a), transform(from), transform(to)).copiedFrom(e) + case FunctionInvocation(Take, Seq(FunctionInvocation(Drop, Seq(a, start)), length)) => val rstart = transform(start) SubString(transform(a), rstart, plus(rstart, transform(length))).copiedFrom(e) + case FunctionInvocation(Take, Seq(a, length)) => + SubString(transform(a), InfiniteIntegerLiteral(0), transform(length)).copiedFrom(e) + case FunctionInvocation(Drop, Seq(a, count)) => + val ra = transform(a) + SubString(ra, transform(count), StringLength(ra)).copiedFrom(e) case _ => super.transform(e) } override def transform(tpe: TypeTree): TypeTree = tpe match { - case StringListTyped | StringConsTyped | StringNilTyped => StringType + case String | StringCons | StringNil => StringType case _ => super.transform(tpe) } override def transform(pat: Pattern): (Pattern, Map[Identifier, Identifier]) = pat match { - case CaseClassPattern(b, StringNilTyped, Seq()) => + case CaseClassPattern(b, StringNil, Seq()) => val newBinder = b map transform (LiteralPattern(newBinder , StringLiteral("")), (b zip newBinder).filter(p => p._1 != p._2).toMap) - case CaseClassPattern(b, StringConsTyped, Seq(LiteralPattern(_, CharLiteral(elem)), sub)) => transform(sub) match { + case CaseClassPattern(b, StringCons, Seq(LiteralPattern(_, CharLiteral(elem)), sub)) => transform(sub) match { case (LiteralPattern(_, StringLiteral(s)), binders) => val newBinder = b map transform (LiteralPattern(newBinder, StringLiteral(elem + s)), (b zip newBinder).filter(p => p._1 != p._2).toMap ++ binders) case (e, binders) => - (LiteralPattern(None, StringLiteral("Failed to parse pattern back as string:" + e)), binders) + throw new Unsupported(pat, "Failed to parse pattern back as string: " + e)(ctx) } case _ => super.transform(pat) diff --git a/src/main/scala/leon/solvers/z3/FairZ3Solver.scala b/src/main/scala/leon/solvers/z3/FairZ3Solver.scala index f8237a4885fdc459efe261bb4333afc36944b500..abf4c63d3d2c019a3a33e62fee88a5d663c4d1e8 100644 --- a/src/main/scala/leon/solvers/z3/FairZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/FairZ3Solver.scala @@ -103,7 +103,7 @@ class FairZ3Solver(val context: LeonContext, val program: Program) def asString(implicit ctx: LeonContext) = z3.toString } - val theoryEncoder = new StringEncoder + val theoryEncoder = new StringEncoder(context, program) >> new BagEncoder(context, program) val templateEncoder = new TemplateEncoder[Z3AST] { def encodeId(id: Identifier): Z3AST = { diff --git a/src/main/scala/leon/solvers/z3/Z3UnrollingSolver.scala b/src/main/scala/leon/solvers/z3/Z3UnrollingSolver.scala index 99c75199a7d96b64ce42ba9039dd648032e60888..a1ed39dfe442044be38be88dac287cfd34aca391 100644 --- a/src/main/scala/leon/solvers/z3/Z3UnrollingSolver.scala +++ b/src/main/scala/leon/solvers/z3/Z3UnrollingSolver.scala @@ -10,4 +10,4 @@ import unrolling._ import theories._ class Z3UnrollingSolver(context: LeonContext, program: Program, underlying: Solver) - extends UnrollingSolver(context, program, underlying, theories = new StringEncoder) + extends UnrollingSolver(context, program, underlying, new StringEncoder(context, program)) diff --git a/src/main/scala/leon/termination/ChainBuilder.scala b/src/main/scala/leon/termination/ChainBuilder.scala index 77b15659c3e2e7be97d36d9e5645af5daf9cfb58..daab94567d4ef6cb27f79d074ab58ea37d10a0b3 100644 --- a/src/main/scala/leon/termination/ChainBuilder.scala +++ b/src/main/scala/leon/termination/ChainBuilder.scala @@ -33,7 +33,7 @@ final case class Chain(relations: List[Relation]) { def rec(list: List[Relation], funDef: TypedFunDef, args: Seq[Expr]): Seq[(Seq[ValDef], Expr)] = list match { case Relation(_, _, fi @ FunctionInvocation(fitfd, nextArgs), _) :: xs => val tfd = TypedFunDef(fitfd.fd, fitfd.tps.map(funDef.translated)) - val subst = tfd.paramSubst(args) + val subst = funDef.paramSubst(args) val expr = replaceFromIDs(subst, hoistIte(expandLets(matchToIfThenElse(tfd.body.get)))) val mappedArgs = nextArgs.map(e => replaceFromIDs(subst, tfd.translated(e))) diff --git a/src/main/scala/leon/termination/StructuralSize.scala b/src/main/scala/leon/termination/StructuralSize.scala index e075c374e8d4796680c940f3cf6e4d3ec455df79..610472d92bcd8c43e80235c51b0fb89e9a5354bd 100644 --- a/src/main/scala/leon/termination/StructuralSize.scala +++ b/src/main/scala/leon/termination/StructuralSize.scala @@ -78,21 +78,23 @@ trait StructuralSize { }).foldLeft[Expr](InfiniteIntegerLiteral(0))(Plus) case IntegerType => FunctionInvocation(typedAbsBigIntFun, Seq(expr)) + case Int32Type => + FunctionInvocation(typedAbsIntFun, Seq(expr)) case _ => InfiniteIntegerLiteral(0) } } def lexicographicDecreasing(s1: Seq[Expr], s2: Seq[Expr], strict: Boolean, sizeOfOneExpr: Expr => Expr): Expr = { // Note: The Equal and GreaterThan ASTs work for both BigInt and Bitvector - + val sameSizeExprs = for ((arg1, arg2) <- s1 zip s2) yield Equals(sizeOfOneExpr(arg1), sizeOfOneExpr(arg2)) - + val greaterBecauseGreaterAtFirstDifferentPos = orJoin(for (firstDifferent <- 0 until scala.math.min(s1.length, s2.length)) yield and( andJoin(sameSizeExprs.take(firstDifferent)), GreaterThan(sizeOfOneExpr(s1(firstDifferent)), sizeOfOneExpr(s2(firstDifferent))) )) - + if (s1.length > s2.length || (s1.length == s2.length && !strict)) { or(andJoin(sameSizeExprs), greaterBecauseGreaterAtFirstDifferentPos) } else { diff --git a/src/main/scala/leon/utils/Library.scala b/src/main/scala/leon/utils/Library.scala index 18dc12938bd2dd55ff51e568e5af4bcd79cd245a..6bff50527120342bd5610c7e3ba9576a4bf2af95 100644 --- a/src/main/scala/leon/utils/Library.scala +++ b/src/main/scala/leon/utils/Library.scala @@ -7,6 +7,8 @@ import purescala.Definitions._ import purescala.Types._ import purescala.DefOps._ +import scala.reflect._ + case class Library(pgm: Program) { lazy val List = lookup("leon.collection.List").collectFirst { case acd : AbstractClassDef => acd } lazy val Cons = lookup("leon.collection.Cons").collectFirst { case ccd : CaseClassDef => ccd } @@ -28,6 +30,13 @@ case class Library(pgm: Program) { pgm.lookupAll(name) } + def lookupUnique[D <: Definition : ClassTag](name: String): D = { + val ct = classTag[D] + val all = pgm.lookupAll(name).filter(d => ct.runtimeClass.isInstance(d)) + assert(all.size == 1, "lookupUnique(\"name\") returned results " + all.map(_.id.uniqueName)) + all.head.asInstanceOf[D] + } + def optionType(tp: TypeTree) = AbstractClassType(Option.get, Seq(tp)) def someType(tp: TypeTree) = CaseClassType(Some.get, Seq(tp)) def noneType(tp: TypeTree) = CaseClassType(None.get, Seq(tp)) diff --git a/src/test/resources/regression/verification/purescala/valid/Nested14.scala b/src/test/resources/regression/verification/purescala/valid/Nested14.scala index 82f4fd1b965dc06c2458d237c84eb96d23ba8349..39e668aa208f1eeefd8257b359b3b5a5be51b39f 100644 --- a/src/test/resources/regression/verification/purescala/valid/Nested14.scala +++ b/src/test/resources/regression/verification/purescala/valid/Nested14.scala @@ -4,7 +4,11 @@ object Nested14 { def foo(i: Int): Int = { def rec1(j: Int): Int = { - def rec2(k: Int): Int = if(k == 0) 0 else rec1(j-1) + require(j >= 0) + def rec2(k: Int): Int = { + require(j > 0 || j == k) + if(k == 0) 0 else rec1(j-1) + } rec2(j) } rec1(3) diff --git a/src/test/scala/leon/integration/solvers/SolversSuite.scala b/src/test/scala/leon/integration/solvers/SolversSuite.scala index 8b0d8026e1540a170df577b3952140e4ef8fa910..f47f3c9b9e58f9eb44c601716d363cbeedc6b3d3 100644 --- a/src/test/scala/leon/integration/solvers/SolversSuite.scala +++ b/src/test/scala/leon/integration/solvers/SolversSuite.scala @@ -41,6 +41,7 @@ class SolversSuite extends LeonTestSuiteWithProgram { StringType, TypeParameter.fresh("T"), SetType(IntegerType), + BagType(IntegerType), MapType(IntegerType, IntegerType), FunctionType(Seq(IntegerType), IntegerType), TupleType(Seq(IntegerType, BooleanType, Int32Type)) @@ -54,6 +55,8 @@ class SolversSuite extends LeonTestSuiteWithProgram { Equals(v, simplestValue(v.getType)) case SetType(base) => Not(ElementOfSet(simplestValue(base), v)) + case BagType(base) => + Not(Equals(MultiplicityInBag(simplestValue(base), v), simplestValue(IntegerType))) case MapType(from, to) => Not(Equals(MapApply(v, simplestValue(from)), simplestValue(to))) case FunctionType(froms, to) =>