diff --git a/src/funcheck/CodeExtraction.scala b/src/funcheck/CodeExtraction.scala index 6f4a2b8c4a9e2b929f69a0fcfc911413679811f6..1e189924c04349de5b63e6efefab85b8e591507d 100644 --- a/src/funcheck/CodeExtraction.scala +++ b/src/funcheck/CodeExtraction.scala @@ -371,6 +371,10 @@ trait CodeExtraction extends Extractors { val rr = rec(t2) SetDifference(rl, rr).setType(rl.getType) // same as union } + case ExSetCard(t) => { + val rt = rec(t) + SetCardinality(rt) + } case ExIfThenElse(t1,t2,t3) => { val r1 = rec(t1) val r2 = rec(t2) diff --git a/src/funcheck/Extractors.scala b/src/funcheck/Extractors.scala index 6236796ef36455aafb272811f9414901d7e9134d..2d108abdcd56d32de01d99163d2e31f5fb41bb72 100644 --- a/src/funcheck/Extractors.scala +++ b/src/funcheck/Extractors.scala @@ -345,6 +345,12 @@ trait Extractors { case _ => None } } - + + object ExSetCard { + def unapply(tree: Select): Option[Tree] = tree match { + case Select(t, n) if (n.toString == "size") => Some(t) + case _ => None + } + } } } diff --git a/src/purescala/PrettyPrinter.scala b/src/purescala/PrettyPrinter.scala index c7529cfdf824e5ea8a1df981c91e30f86d66061d..ad018307c3472b37b4193eefa9399643e657b159 100644 --- a/src/purescala/PrettyPrinter.scala +++ b/src/purescala/PrettyPrinter.scala @@ -33,12 +33,11 @@ object PrettyPrinter { // EXPRESSIONS // all expressions are printed in-line - private def ppUnary(sb: StringBuffer, expr: Expr, op: String, lvl: Int): StringBuffer = { + private def ppUnary(sb: StringBuffer, expr: Expr, op1: String, op2: String, lvl: Int): StringBuffer = { var nsb: StringBuffer = sb - nsb.append(op) - nsb.append("(") + nsb.append(op1) nsb = pp(expr, nsb, lvl) - nsb.append(")") + nsb.append(op2) nsb } @@ -71,9 +70,9 @@ object PrettyPrinter { case And(exprs) => ppNary(sb, exprs, "(", " \u2227 ", ")", lvl) // \land case Or(exprs) => ppNary(sb, exprs, "(", " \u2228 ", ")", lvl) // \lor case Not(Equals(l, r)) => ppBinary(sb, l, r, " \u2260 ", lvl) // \neq - case Not(expr) => ppUnary(sb, expr, "\u00AC", lvl) // \neg + case Not(expr) => ppUnary(sb, expr, "\u00AC(", ")", lvl) // \neg case Implies(l,r) => ppBinary(sb, l, r, "==>", lvl) - case UMinus(expr) => ppUnary(sb, expr, "-", lvl) + case UMinus(expr) => ppUnary(sb, expr, "-(", ")", lvl) case Equals(l,r) => ppBinary(sb, l, r, " == ", lvl) case IntLiteral(v) => sb.append(v) case BooleanLiteral(v) => sb.append(v) @@ -104,6 +103,7 @@ object PrettyPrinter { case SetUnion(l,r) => ppBinary(sb, l, r, " \u222A ", lvl) // \cup case SetDifference(l,r) => ppBinary(sb, l, r, " \\ ", lvl) case SetIntersection(l,r) => ppBinary(sb, l, r, " \u2229 ", lvl) // \cap + case SetCardinality(t) => ppUnary(sb, t, "|", "|", lvl) case IfExpr(c, t, e) => { var nsb = sb diff --git a/src/purescala/Trees.scala b/src/purescala/Trees.scala index 1306ed75e7a83a2fb2c25f16076a22cd09691580..39072861072544603ccb130433ead8a2d877ecaa 100644 --- a/src/purescala/Trees.scala +++ b/src/purescala/Trees.scala @@ -134,7 +134,9 @@ object Trees { case class ElementOfSet(element: Expr, set: Expr) extends Expr case class IsEmptySet(set: Expr) extends Expr case class SetEquals(set1: Expr, set2: Expr) extends Expr - case class SetCardinality(set: Expr) extends Expr + case class SetCardinality(set: Expr) extends Expr with FixedType { + val fixedType = Int32Type + } case class SubsetOf(set1: Expr, set2: Expr) extends Expr case class SetIntersection(set1: Expr, set2: Expr) extends Expr case class SetUnion(set1: Expr, set2: Expr) extends Expr diff --git a/testcases/BinarySearchTree.scala b/testcases/BinarySearchTree.scala index 28c5b09194fa805eabf009308491e3e8f47fe9eb..e0f61f5146769332aa530d303eec8a311cfa1b20 100644 --- a/testcases/BinarySearchTree.scala +++ b/testcases/BinarySearchTree.scala @@ -21,9 +21,9 @@ object BinarySearchTree { case Node(l, v, r) if v > value => contains(l, value) } - def contents(tree: Tree) : Set[Int] = tree match { + def contents(tree: Tree) : Set[Int] = (tree match { case Leaf() => Set.empty[Int] case Node(l, v, r) => contents(l) ++ Set(v) ++ contents(r) - } + }) ensuring(res => res == Set.empty[Int] || true) //res.min <= res.max) } diff --git a/testcases/SetOperations.scala b/testcases/SetOperations.scala index 0cd18de4c67ecdc46a648bb3891890a8a65fd7f3..00f4474ec10917ae77f3e9593c6b95bd1cdc82ce 100644 --- a/testcases/SetOperations.scala +++ b/testcases/SetOperations.scala @@ -5,7 +5,7 @@ import scala.collection.immutable.Set object SetOperations { def add(a: Set[Int], b: Int) : Set[Int] = { - // require(a.size >= 0) - a + b - } // ensuring((x:Set[Int]) => x.size == a.size + 1) + require(a.size >= 0) + a ++ Set(b) + } ensuring((x:Set[Int]) => x.size == a.size + 1) }