From 014bf6ab2edfd9a7dd54dd5967db18653927c0ff Mon Sep 17 00:00:00 2001
From: Philippe Suter <philippe.suter@gmail.com>
Date: Wed, 9 Jun 2010 20:59:05 +0000
Subject: [PATCH] support for some set operations.

---
 src/funcheck/CodeExtraction.scala | 25 +++++++++++++++++--
 src/funcheck/Extractors.scala     | 40 +++++++++++++++++++++++++++++++
 src/purescala/PrettyPrinter.scala |  5 ++++
 3 files changed, 68 insertions(+), 2 deletions(-)

diff --git a/src/funcheck/CodeExtraction.scala b/src/funcheck/CodeExtraction.scala
index 7b7a536f5..857d31984 100644
--- a/src/funcheck/CodeExtraction.scala
+++ b/src/funcheck/CodeExtraction.scala
@@ -307,8 +307,8 @@ trait CodeExtraction extends Extractors {
     }
 
     def rec(tr: Tree): Expr = tr match {
-      case ExInt32Literal(v) => IntLiteral(v)
-      case ExBooleanLiteral(v) => BooleanLiteral(v)
+      case ExInt32Literal(v) => IntLiteral(v).setType(Int32Type)
+      case ExBooleanLiteral(v) => BooleanLiteral(v).setType(BooleanType)
       case ExIdentifier(sym,tpt) => varSubsts.get(sym) match {
         case Some(fun) => fun()
         case None => {
@@ -337,10 +337,31 @@ trait CodeExtraction extends Extractors {
       case ExTimes(l, r) => Times(rec(l), rec(r)).setType(Int32Type)
       case ExDiv(l, r) => Division(rec(l), rec(r)).setType(Int32Type)
       case ExEquals(l, r) => Equals(rec(l), rec(r)).setType(BooleanType)
+      case ExNotEquals(l, r) => Not(Equals(rec(l), rec(r)).setType(BooleanType)).setType(BooleanType)
       case ExGreaterThan(l, r) => GreaterThan(rec(l), rec(r)).setType(BooleanType)
       case ExGreaterEqThan(l, r) => GreaterEquals(rec(l), rec(r)).setType(BooleanType)
       case ExLessThan(l, r) => LessThan(rec(l), rec(r)).setType(BooleanType)
       case ExLessEqThan(l, r) => LessEquals(rec(l), rec(r)).setType(BooleanType)
+
+      case ExEmptySet(tt) => {
+        val underlying = scalaType2PureScala(unit, silent)(tt.tpe)
+        EmptySet(underlying).setType(SetType(underlying))          
+      }
+      case ExUnion(t1,t2) => {
+        val rl = rec(t1)
+        val rr = rec(t2)
+        SetUnion(rl, rr).setType(rl.getType) // this is not entirely correct: should be a setype of LUB of underlying types of left and right.
+      }
+      case ExIntersection(t1,t2) => {
+        val rl = rec(t1)
+        val rr = rec(t2)
+        SetIntersection(rl, rr).setType(rl.getType) // same as union
+      } 
+      case ExSetMinus(t1,t2) => {
+        val rl = rec(t1)
+        val rr = rec(t2)
+        SetDifference(rl, rr).setType(rl.getType) // same as union
+      } 
       case ExIfThenElse(t1,t2,t3) => {
         val r1 = rec(t1)
         val r2 = rec(t2)
diff --git a/src/funcheck/Extractors.scala b/src/funcheck/Extractors.scala
index 7fba429eb..ae2bdf496 100644
--- a/src/funcheck/Extractors.scala
+++ b/src/funcheck/Extractors.scala
@@ -10,6 +10,8 @@ trait Extractors {
   import global._
   import global.definitions._
 
+  private lazy val setTraitSym = definitions.getClass("scala.collection.immutable.Set")
+
   object StructuralExtractors {
     object ScalaPredef {
       /** Extracts method calls from scala.Predef. */
@@ -286,5 +288,43 @@ trait Extractors {
       def unapply(tree: Match): Option[(Tree,List[CaseDef])] =
         if(tree != null) Some((tree.selector, tree.cases)) else None
     }
+
+    object ExEmptySet {
+      def unapply(tree: TypeApply): Option[Tree] = tree match {
+        case TypeApply(
+          Select(
+            Select(
+              Select(
+                Select(Ident(s), collectionName),
+                immutableName),
+              setName),
+            emptyName),  theTypeTree :: Nil) if (
+            collectionName.toString == "collection" && immutableName.toString == "immutable" && setName.toString == "Set" && emptyName.toString == "empty"
+          ) => Some(theTypeTree)
+        case _ => None
+      }
+    }
+
+    object ExUnion {
+      def unapply(tree: Apply): Option[(Tree,Tree)] = tree match {
+        case Apply(Select(lhs, n), List(rhs)) if (n == nme.PLUSPLUS) => Some((lhs,rhs))
+        case _ => None
+      }
+    }
+  
+    object ExIntersection {
+      def unapply(tree: Apply): Option[(Tree,Tree)] = tree match {
+        case Apply(Select(lhs, n), List(rhs)) if (n == encode("**")) => Some((lhs,rhs))
+        case _ => None
+      }
+    }
+  
+    object ExSetMinus {
+      def unapply(tree: Apply): Option[(Tree,Tree)] = tree match {
+        case Apply(Select(lhs, n), List(rhs)) if (n == encode("--")) => Some((lhs,rhs))
+        case _ => None
+      }
+    }
+    
   }
 }
diff --git a/src/purescala/PrettyPrinter.scala b/src/purescala/PrettyPrinter.scala
index ccd9f6580..8a835a896 100644
--- a/src/purescala/PrettyPrinter.scala
+++ b/src/purescala/PrettyPrinter.scala
@@ -98,6 +98,10 @@ object PrettyPrinter {
     case GreaterThan(l,r) => ppBinary(sb, l, r, " > ", lvl)
     case LessEquals(l,r) => ppBinary(sb, l, r, " \u2264 ", lvl)      // \leq
     case GreaterEquals(l,r) => ppBinary(sb, l, r, " \u2265 ", lvl)   // \geq
+    case EmptySet(_) => sb.append("Ø")
+    case SetUnion(l,r) => ppBinary(sb, l, r, " U ", lvl)
+    case SetDifference(l,r) => ppBinary(sb, l, r, " \\ ", lvl)
+    case SetIntersection(l,r) => ppBinary(sb, l, r, " INT ", lvl)
     
     case IfExpr(c, t, e) => {
       var nsb = sb
@@ -162,6 +166,7 @@ object PrettyPrinter {
 
     case ResultVariable() => sb.append("#res")
 
+
     case _ => sb.append("Expr?")
   }
 
-- 
GitLab