From 496c4fa729b1388a070e393d4807e97aeedb7cdb Mon Sep 17 00:00:00 2001
From: Etienne Kneuss <colder@php.net>
Date: Thu, 14 Feb 2013 16:07:14 +0100
Subject: [PATCH] Implement support for "case object"

---
 .../scala/leon/plugin/CodeExtraction.scala    | 77 +++++++++++--------
 src/main/scala/leon/plugin/Extractors.scala   | 10 +++
 .../scala/leon/purescala/Definitions.scala    |  1 +
 .../scala/leon/purescala/PrettyPrinter.scala  |  6 +-
 .../scala/leon/purescala/ScalaPrinter.scala   |  6 +-
 .../purescala/valid/CaseObject1.scala         | 21 +++++
 6 files changed, 85 insertions(+), 36 deletions(-)
 create mode 100644 src/test/resources/regression/verification/purescala/valid/CaseObject1.scala

diff --git a/src/main/scala/leon/plugin/CodeExtraction.scala b/src/main/scala/leon/plugin/CodeExtraction.scala
index ad3dc998a..06a6ac624 100644
--- a/src/main/scala/leon/plugin/CodeExtraction.scala
+++ b/src/main/scala/leon/plugin/CodeExtraction.scala
@@ -19,6 +19,7 @@ trait CodeExtraction extends Extractors {
   import global.definitions._
   import StructuralExtractors._
   import ExpressionExtractors._
+  import ExtractorHelpers._
 
 
   private lazy val setTraitSym = definitions.getClass("scala.collection.immutable.Set")
@@ -156,7 +157,9 @@ trait CodeExtraction extends Extractors {
           if(p._1.isAbstractClass) {
             classesToClasses += (p._1 -> new AbstractClassDef(p._2))
           } else if(p._1.isCase) {
-            classesToClasses += (p._1 -> new CaseClassDef(p._2))
+            val ccd = new CaseClassDef(p._2)
+            ccd.isCaseObject = p._1.isModuleClass
+            classesToClasses += (p._1 -> ccd)
           }
       })
 
@@ -367,51 +370,48 @@ trait CodeExtraction extends Extractors {
    * errors. */
   private def scala2PureScala(unit: CompilationUnit, silent: Boolean)(tree: Tree): Expr = {
     def rewriteCaseDef(cd: CaseDef): MatchCase = {
-      def pat2pat(p: Tree): Pattern = p match {
-        case Ident(nme.WILDCARD) => WildcardPattern(None)
-        case b @ Bind(name, Ident(nme.WILDCARD)) => {
-          val newID = FreshIdentifier(name.toString).setType(scalaType2PureScala(unit,silent)(b.symbol.tpe))
-          varSubsts(b.symbol) = (() => Variable(newID))
-          WildcardPattern(Some(newID))
-        }
-        case b @ Bind(name, Typed(Ident(nme.WILDCARD), tpe)) => {
+
+      def pat2pat(p: Tree, binder: Option[Identifier] = None): Pattern = p match {
+        case b @ Bind(name, Typed(pat, tpe)) =>
           val newID = FreshIdentifier(name.toString).setType(scalaType2PureScala(unit,silent)(tpe.tpe))
           varSubsts(b.symbol) = (() => Variable(newID))
-          WildcardPattern(Some(newID))
-        }
-        case a @ Apply(fn, args) if fn.isType && a.tpe.typeSymbol.isCase && classesToClasses.keySet.contains(a.tpe.typeSymbol) => {
-          val cd = classesToClasses(a.tpe.typeSymbol).asInstanceOf[CaseClassDef]
-          assert(args.size == cd.fields.size)
-          CaseClassPattern(None, cd, args.map(pat2pat(_)))
-        }
-        case b @ Bind(name, a @ Apply(fn, args)) if fn.isType && a.tpe.typeSymbol.isCase && classesToClasses.keySet.contains(a.tpe.typeSymbol) => {
+          pat2pat(pat, Some(newID))
+
+        case b @ Bind(name, pat) =>
           val newID = FreshIdentifier(name.toString).setType(scalaType2PureScala(unit,silent)(b.symbol.tpe))
           varSubsts(b.symbol) = (() => Variable(newID))
+          pat2pat(pat, Some(newID))
+
+        case Ident(nme.WILDCARD) =>
+          WildcardPattern(binder)
+
+        case s @ Select(This(_), b) if s.tpe.typeSymbol.isCase &&
+                                       classesToClasses.keySet.contains(s.tpe.typeSymbol) =>
+          // case Obj =>
+          val cd = classesToClasses(s.tpe.typeSymbol).asInstanceOf[CaseClassDef]
+          assert(cd.fields.size == 0)
+          CaseClassPattern(binder, cd, Seq())
+
+        case a @ Apply(fn, args) if fn.isType &&
+                                    a.tpe.typeSymbol.isCase &&
+                                    classesToClasses.keySet.contains(a.tpe.typeSymbol) =>
+
           val cd = classesToClasses(a.tpe.typeSymbol).asInstanceOf[CaseClassDef]
           assert(args.size == cd.fields.size)
-          CaseClassPattern(Some(newID), cd, args.map(pat2pat(_)))
-        }
-        case a@Apply(fn, args) => {
-          val pst = scalaType2PureScala(unit, silent)(a.tpe)
-          pst match {
-            case TupleType(argsTpes) => TuplePattern(None, args.map(pat2pat))
-            case _ => throw ImpureCodeEncounteredException(p)
-          }
-        }
-        case b @ Bind(name, a @ Apply(fn, args)) => {
-          val newID = FreshIdentifier(name.toString).setType(scalaType2PureScala(unit,silent)(b.symbol.tpe))
-          varSubsts(b.symbol) = (() => Variable(newID))
+          CaseClassPattern(binder, cd, args.map(pat2pat(_)))
+
+        case a @ Apply(fn, args) =>
           val pst = scalaType2PureScala(unit, silent)(a.tpe)
           pst match {
-            case TupleType(argsTpes) => TuplePattern(Some(newID), args.map(pat2pat))
+            case TupleType(argsTpes) => TuplePattern(binder, args.map(pat2pat(_)))
             case _ => throw ImpureCodeEncounteredException(p)
           }
-        }
-        case _ => {
-          if(!silent)
+
+        case _ =>
+          if (!silent) {
             unit.error(p.pos, "Unsupported pattern.")
+          }
           throw ImpureCodeEncounteredException(p)
-        }
       }
 
       if(cd.guard == EmptyTree) {
@@ -508,6 +508,14 @@ trait CodeExtraction extends Extractors {
       }
 
       val e2: Option[Expr] = nextExpr match {
+        case ExCaseObject(sym) =>
+          classesToClasses.get(sym) match {
+            case Some(ccd: CaseClassDef) =>
+              Some(CaseClass(ccd, Seq()))
+            case _ =>
+              None
+          }
+
         case ExParameterlessMethodCall(t,n) => {
           val selector = rec(t)
           val selType = selector.getType
@@ -726,6 +734,7 @@ trait CodeExtraction extends Extractors {
             val cct = cctype.asInstanceOf[CaseClassType]
             CaseClass(cct.classDef, nargs).setType(cct)
           }
+
           case ExAnd(l, r) => And(rec(l), rec(r)).setType(BooleanType)
           case ExOr(l, r) => Or(rec(l), rec(r)).setType(BooleanType)
           case ExNot(e) => Not(rec(e)).setType(BooleanType)
diff --git a/src/main/scala/leon/plugin/Extractors.scala b/src/main/scala/leon/plugin/Extractors.scala
index b6b15a162..c333df79f 100644
--- a/src/main/scala/leon/plugin/Extractors.scala
+++ b/src/main/scala/leon/plugin/Extractors.scala
@@ -123,6 +123,16 @@ trait Extractors {
       }
     }
 
+    object ExCaseObject {
+      def unapply(s: Select): Option[Symbol] = {
+        if (s.tpe.typeSymbol.isModuleClass) {
+          Some(s.tpe.typeSymbol)
+        } else {
+          None
+        }
+      }
+    }
+
     object ExCaseClassSyntheticJunk {
       def unapply(cd: ClassDef): Boolean = cd match {
         case ClassDef(_, _, _, _) if (cd.symbol.isSynthetic && cd.symbol.isFinal) => true
diff --git a/src/main/scala/leon/purescala/Definitions.scala b/src/main/scala/leon/purescala/Definitions.scala
index 6e3c0b926..fe9ae283a 100644
--- a/src/main/scala/leon/purescala/Definitions.scala
+++ b/src/main/scala/leon/purescala/Definitions.scala
@@ -253,6 +253,7 @@ object Definitions {
   class CaseClassDef(val id: Identifier, prnt: Option[AbstractClassDef] = None) extends ClassTypeDef with ExtractorTypeDef {
     private var parent_ = prnt
     var fields: VarDecls = Nil
+    var isCaseObject = false
     val isAbstract = false
 
     def setParent(newParent: AbstractClassDef) = {
diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala
index 8a244a5e8..fd8fa2ed0 100644
--- a/src/main/scala/leon/purescala/PrettyPrinter.scala
+++ b/src/main/scala/leon/purescala/PrettyPrinter.scala
@@ -117,7 +117,11 @@ object PrettyPrinter {
     case CaseClass(cd, args) => {
       var nsb = sb
       nsb.append(cd.id)
-      nsb = ppNary(nsb, args, "(", ", ", ")", lvl)
+      if (cd.isCaseObject) {
+        nsb = ppNary(nsb, args, "", "", "", lvl)
+      } else {
+        nsb = ppNary(nsb, args, "(", ", ", ")", lvl)
+      }
       nsb
     }
     case CaseClassInstanceOf(cd, e) => {
diff --git a/src/main/scala/leon/purescala/ScalaPrinter.scala b/src/main/scala/leon/purescala/ScalaPrinter.scala
index d9a82aa8c..9e2579f6c 100644
--- a/src/main/scala/leon/purescala/ScalaPrinter.scala
+++ b/src/main/scala/leon/purescala/ScalaPrinter.scala
@@ -124,7 +124,11 @@ object ScalaPrinter {
 
     case CaseClass(cd, args) => {
       sb.append(cd.id)
-      ppNary(sb, args, "(", ", ", ")", lvl)
+      if (cd.isCaseObject) {
+        ppNary(sb, args, "", "", "", lvl)
+      } else {
+        ppNary(sb, args, "(", ", ", ")", lvl)
+      }
     }
     case CaseClassInstanceOf(cd, e) => {
       pp(e, sb, lvl)
diff --git a/src/test/resources/regression/verification/purescala/valid/CaseObject1.scala b/src/test/resources/regression/verification/purescala/valid/CaseObject1.scala
new file mode 100644
index 000000000..90dcdf064
--- /dev/null
+++ b/src/test/resources/regression/verification/purescala/valid/CaseObject1.scala
@@ -0,0 +1,21 @@
+object CaseObject1 {
+
+  abstract sealed class A
+  case class B(size: Int) extends A
+  case object C extends A
+
+  def foo(): A = {
+    C
+  }
+
+  def foo1(a: A): A = a match {
+    case C => a
+    case B(s) => a 
+  }
+
+  def foo2(a: A): A = a match {
+    case c @ C => c
+    case B(s) => a
+  }
+
+}
-- 
GitLab