From e7e93f4bf1e260363e985767a1f3d9dd32e2b815 Mon Sep 17 00:00:00 2001
From: Manos Koukoutos <emmanouil.koukoutos@epfl.ch>
Date: Wed, 8 Jul 2015 11:11:00 +0200
Subject: [PATCH] Make CaseClassSelector consistent with other Expr's

---
 .../frontends/scalac/CodeExtraction.scala     |  2 +-
 .../scala/leon/purescala/Constructors.scala   | 12 +++++++--
 .../scala/leon/purescala/Definitions.scala    | 15 ++++++-----
 src/main/scala/leon/purescala/ExprOps.scala   | 10 +++----
 .../scala/leon/purescala/Expressions.scala    | 27 +------------------
 .../scala/leon/purescala/Extractors.scala     |  2 +-
 src/main/scala/leon/purescala/TypeOps.scala   |  2 +-
 .../scala/leon/synthesis/rules/ADTDual.scala  |  4 +--
 .../scala/leon/synthesis/rules/ADTSplit.scala |  2 +-
 .../leon/synthesis/rules/DetupleInput.scala   |  2 +-
 .../leon/termination/ChainComparator.scala    |  4 +--
 src/main/scala/leon/utils/InliningPhase.scala |  6 ++---
 .../leon/verification/InductionTactic.scala   |  2 +-
 13 files changed, 37 insertions(+), 53 deletions(-)

diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala
index 3a94aa0fc..8970b698a 100644
--- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala
+++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala
@@ -1453,7 +1453,7 @@ trait CodeExtraction extends ASTExtractors {
 
               val fieldID = cct.fields.find(_.id.name == name).get.id
 
-              CaseClassSelector(cct, rec, fieldID)
+              caseClassSelector(cct, rec, fieldID)
 
             //BigInt methods
             case (IsTyped(a1, IntegerType), "+", List(IsTyped(a2, IntegerType))) =>
diff --git a/src/main/scala/leon/purescala/Constructors.scala b/src/main/scala/leon/purescala/Constructors.scala
index 6ffb833e3..05ed1e034 100644
--- a/src/main/scala/leon/purescala/Constructors.scala
+++ b/src/main/scala/leon/purescala/Constructors.scala
@@ -79,9 +79,17 @@ object Constructors {
       case None => sys.error(s"$actualType cannot be a subtype of $formalType!")
     }
 
-   
   }
-  
+
+  def caseClassSelector(classType: CaseClassType, caseClass: Expr, selector: Identifier): Expr = {
+    caseClass match {
+      case CaseClass(ct, fields) if ct.classDef == classType.classDef =>
+        fields(ct.classDef.selectorID2Index(selector))
+      case _ =>
+        CaseClassSelector(classType, caseClass, selector)
+    }
+  }
+
   private def filterCases(scrutType : TypeTree, resType: Option[TypeTree], cases: Seq[MatchCase]): Seq[MatchCase] = {
     val casesFiltered = scrutType match {
       case c: CaseClassType =>
diff --git a/src/main/scala/leon/purescala/Definitions.scala b/src/main/scala/leon/purescala/Definitions.scala
index 2d37cce58..8b3e6ca9a 100644
--- a/src/main/scala/leon/purescala/Definitions.scala
+++ b/src/main/scala/leon/purescala/Definitions.scala
@@ -274,16 +274,17 @@ object Definitions {
       _fields = fields
     }
 
-
     val isAbstract = false
 
-
     def selectorID2Index(id: Identifier) : Int = {
-      val index = fields.zipWithIndex.find(_._1.id == id).map(_._2)
-
-      index.getOrElse {
-        scala.sys.error("Could not find '"+id+"' ("+id.uniqueName+") within "+fields.map(_.id.uniqueName).mkString(", "))
-      }
+      val index = fields.indexWhere(_.id == id)
+
+      if (index < 0) {
+        scala.sys.error(
+          "Could not find '"+id+"' ("+id.uniqueName+") within "+
+          fields.map(_.id.uniqueName).mkString(", ")
+        )
+      } else index
     }
     
     lazy val singleCaseClasses : Seq[CaseClassDef] = if (hasParent) Nil else Seq(this)
diff --git a/src/main/scala/leon/purescala/ExprOps.scala b/src/main/scala/leon/purescala/ExprOps.scala
index 73e4aaed9..727e0d94b 100644
--- a/src/main/scala/leon/purescala/ExprOps.scala
+++ b/src/main/scala/leon/purescala/ExprOps.scala
@@ -392,7 +392,7 @@ object ExprOps {
         Some(letTuple(ids, v, tupleSelect(b, ts, true)))
 
       case CaseClassSelector(cct, cc: CaseClass, id) =>
-        Some(CaseClassSelector(cct, cc, id))
+        Some(caseClassSelector(cct, cc, id))
 
       case IfExpr(c, thenn, elze) if (thenn == elze) && isDeterministic(e) =>
         Some(thenn)
@@ -644,7 +644,7 @@ object ExprOps {
 
       case CaseClassPattern(_, cct, subps) =>
         val subExprs = (subps zip cct.fields) map {
-          case (p, f) => p.binder.map(_.toVariable).getOrElse(CaseClassSelector(cct, in, f.id))
+          case (p, f) => p.binder.map(_.toVariable).getOrElse(caseClassSelector(cct, in, f.id))
         }
 
         // Special case to get rid of Cons(a,b) match { case Cons(c,d) => .. }
@@ -705,7 +705,7 @@ object ExprOps {
         case CaseClassPattern(ob, cct, subps) =>
           assert(cct.fields.size == subps.size)
           val pairs = cct.fields.map(_.id).toList zip subps.toList
-          val subTests = pairs.map(p => rec(CaseClassSelector(cct, in, p._1), p._2))
+          val subTests = pairs.map(p => rec(caseClassSelector(cct, in, p._1), p._2))
           val together = and(bind(ob, in) +: subTests :_*)
           and(IsInstanceOf(cct, in), together)
 
@@ -727,7 +727,7 @@ object ExprOps {
     case CaseClassPattern(b, ccd, subps) =>
       assert(ccd.fields.size == subps.size)
       val pairs = ccd.fields.map(_.id).toList zip subps.toList
-      val subMaps = pairs.map(p => mapForPattern(CaseClassSelector(ccd, in, p._1), p._2))
+      val subMaps = pairs.map(p => mapForPattern(caseClassSelector(ccd, in, p._1), p._2))
       val together = subMaps.flatten.toMap
       b match {
         case Some(id) => Map(id -> in) ++ together
@@ -1283,7 +1283,7 @@ object ExprOps {
             val v = Variable(on)
 
             recSelectors.map{ s =>
-              and(isType, expr, not(replace(Map(v -> CaseClassSelector(cct, v, s)), expr)))
+              and(isType, expr, not(replace(Map(v -> caseClassSelector(cct, v, s)), expr)))
             }
           }
       }.flatten
diff --git a/src/main/scala/leon/purescala/Expressions.scala b/src/main/scala/leon/purescala/Expressions.scala
index b49c61216..77fcdb40c 100644
--- a/src/main/scala/leon/purescala/Expressions.scala
+++ b/src/main/scala/leon/purescala/Expressions.scala
@@ -270,34 +270,9 @@ object Expressions {
     val getType = BooleanType
   }
 
-  object CaseClassSelector {
-    def apply(classType: CaseClassType, caseClass: Expr, selector: Identifier): Expr = {
-      caseClass match {
-        case CaseClass(ct, fields) =>
-          if (ct.classDef == classType.classDef) {
-            fields(ct.classDef.selectorID2Index(selector))
-          } else {
-            new CaseClassSelector(classType, caseClass, selector)
-          }
-        case _ => new CaseClassSelector(classType, caseClass, selector)
-      }
-    }
-
-    def unapply(ccs: CaseClassSelector): Option[(CaseClassType, Expr, Identifier)] = {
-      Some((ccs.classType, ccs.caseClass, ccs.selector))
-    }
-  }
-
-  class CaseClassSelector(val classType: CaseClassType, val caseClass: Expr, val selector: Identifier) extends Expr {
+  case class CaseClassSelector(classType: CaseClassType, caseClass: Expr, selector: Identifier) extends Expr {
     val selectorIndex = classType.classDef.selectorID2Index(selector)
     val getType = classType.fieldsTypes(selectorIndex)
-
-    override def equals(that: Any): Boolean = (that != null) && (that match {
-      case t: CaseClassSelector => (t.classType, t.caseClass, t.selector) == (classType, caseClass, selector)
-      case _ => false
-    })
-
-    override def hashCode: Int = (classType, caseClass, selector).hashCode + 9
   }
 
   /* Arithmetic */
diff --git a/src/main/scala/leon/purescala/Extractors.scala b/src/main/scala/leon/purescala/Extractors.scala
index d74812fd6..d17b58a96 100644
--- a/src/main/scala/leon/purescala/Extractors.scala
+++ b/src/main/scala/leon/purescala/Extractors.scala
@@ -28,7 +28,7 @@ object Extractors {
       case SetCardinality(t) =>
         Some((Seq(t), (es: Seq[Expr]) => SetCardinality(es.head)))
       case CaseClassSelector(cd, e, sel) =>
-        Some((Seq(e), (es: Seq[Expr]) => CaseClassSelector(cd, es.head, sel)))
+        Some((Seq(e), (es: Seq[Expr]) => caseClassSelector(cd, es.head, sel)))
       case IsInstanceOf(cd, e) =>
         Some((Seq(e), (es: Seq[Expr]) => IsInstanceOf(cd, es.head)))
       case TupleSelect(t, i) =>
diff --git a/src/main/scala/leon/purescala/TypeOps.scala b/src/main/scala/leon/purescala/TypeOps.scala
index 49529c035..4b77380bf 100644
--- a/src/main/scala/leon/purescala/TypeOps.scala
+++ b/src/main/scala/leon/purescala/TypeOps.scala
@@ -270,7 +270,7 @@ object TypeOps {
             CaseClass(tpeSub(ct).asInstanceOf[CaseClassType], args.map(srec)).copiedFrom(cc)
 
           case cc @ CaseClassSelector(ct, e, sel) =>
-            CaseClassSelector(tpeSub(ct).asInstanceOf[CaseClassType], srec(e), sel).copiedFrom(cc)
+            caseClassSelector(tpeSub(ct).asInstanceOf[CaseClassType], srec(e), sel).copiedFrom(cc)
 
           case cc @ IsInstanceOf(ct, e) =>
             IsInstanceOf(tpeSub(ct).asInstanceOf[ClassType], srec(e)).copiedFrom(cc)
diff --git a/src/main/scala/leon/synthesis/rules/ADTDual.scala b/src/main/scala/leon/synthesis/rules/ADTDual.scala
index 3f68419ed..ddf37760e 100644
--- a/src/main/scala/leon/synthesis/rules/ADTDual.scala
+++ b/src/main/scala/leon/synthesis/rules/ADTDual.scala
@@ -19,10 +19,10 @@ case object ADTDual extends NormalizingRule("ADTDual") {
 
     val (toRemove, toAdd) = exprs.collect {
       case eq @ Equals(cc @ CaseClass(ct, args), e) if (variablesOf(e) -- as).isEmpty && (variablesOf(cc) & xs).nonEmpty =>
-        (eq, IsInstanceOf(ct, e) +: (ct.fields zip args).map{ case (vd, ex) => Equals(ex, CaseClassSelector(ct, e, vd.id)) } )
+        (eq, IsInstanceOf(ct, e) +: (ct.fields zip args).map{ case (vd, ex) => Equals(ex, caseClassSelector(ct, e, vd.id)) } )
 
       case eq @ Equals(e, cc @ CaseClass(ct, args)) if (variablesOf(e) -- as).isEmpty && (variablesOf(cc) & xs).nonEmpty =>
-        (eq, IsInstanceOf(ct, e) +: (ct.fields zip args).map{ case (vd, ex) => Equals(ex, CaseClassSelector(ct, e, vd.id)) } )
+        (eq, IsInstanceOf(ct, e) +: (ct.fields zip args).map{ case (vd, ex) => Equals(ex, caseClassSelector(ct, e, vd.id)) } )
     }.unzip
 
     if (toRemove.nonEmpty) {
diff --git a/src/main/scala/leon/synthesis/rules/ADTSplit.scala b/src/main/scala/leon/synthesis/rules/ADTSplit.scala
index 38bc0455a..111757777 100644
--- a/src/main/scala/leon/synthesis/rules/ADTSplit.scala
+++ b/src/main/scala/leon/synthesis/rules/ADTSplit.scala
@@ -74,7 +74,7 @@ case object ADTSplit extends Rule("ADT Split.") {
             val cases = for ((sol, (cct, problem, pattern)) <- sols zip subInfo) yield {
               if (sol.pre != BooleanLiteral(true)) {
                 val substs = (for ((field,arg) <- cct.fields zip problem.as ) yield {
-                  (arg, CaseClassSelector(cct, id.toVariable, field.id))
+                  (arg, caseClassSelector(cct, id.toVariable, field.id))
                 }).toMap
                 globalPre ::= and(IsInstanceOf(cct, Variable(id)), replaceFromIDs(substs, sol.pre))
               } else {
diff --git a/src/main/scala/leon/synthesis/rules/DetupleInput.scala b/src/main/scala/leon/synthesis/rules/DetupleInput.scala
index ae5f731b9..ed25fc4b3 100644
--- a/src/main/scala/leon/synthesis/rules/DetupleInput.scala
+++ b/src/main/scala/leon/synthesis/rules/DetupleInput.scala
@@ -23,7 +23,7 @@ case object DetupleInput extends NormalizingRule("Detuple In") {
       case cct @ CaseClassType(ccd, _) if !ccd.isAbstract =>
         val newIds = cct.fields.map{ vd => FreshIdentifier(vd.id.name, vd.getType, true) }
 
-        val map = (ccd.fields zip newIds).map{ case (vd, nid) => nid -> CaseClassSelector(cct, Variable(id), vd.id) }.toMap
+        val map = (ccd.fields zip newIds).map{ case (vd, nid) => nid -> caseClassSelector(cct, Variable(id), vd.id) }.toMap
 
         (newIds.toList, CaseClass(cct, newIds.map(Variable)), map)
 
diff --git a/src/main/scala/leon/termination/ChainComparator.scala b/src/main/scala/leon/termination/ChainComparator.scala
index bbee27665..2ed2a08de 100644
--- a/src/main/scala/leon/termination/ChainComparator.scala
+++ b/src/main/scala/leon/termination/ChainComparator.scala
@@ -34,7 +34,7 @@ trait ChainComparator { self : StructuralSize =>
     def rec(tpe: TypeTree): Set[Expr => Expr] = tpe match  {
       case ContainerType(cct, fields) =>
         powerSetToFunSet(fields.zipWithIndex.flatMap { case ((fieldId, fieldTpe), index) =>
-          rec(fieldTpe).map(recons => (e: Expr) => recons(CaseClassSelector(cct, e, fieldId)))
+          rec(fieldTpe).map(recons => (e: Expr) => recons(caseClassSelector(cct, e, fieldId)))
         })
       case TupleType(tpes) =>
         powerSetToFunSet((0 until tpes.length).flatMap { case index =>
@@ -50,7 +50,7 @@ trait ChainComparator { self : StructuralSize =>
     def rec(tpe: TypeTree): Set[Expr => Expr] = tpe match {
       case ContainerType(cct, fields) =>
         fields.zipWithIndex.flatMap { case ((fieldId, fieldTpe), index) =>
-          rec(fieldTpe).map(recons => (e: Expr) => recons(CaseClassSelector(cct, e, fieldId)))
+          rec(fieldTpe).map(recons => (e: Expr) => recons(caseClassSelector(cct, e, fieldId)))
         }.toSet
       case TupleType(tpes) =>
         (0 until tpes.length).flatMap { case index =>
diff --git a/src/main/scala/leon/utils/InliningPhase.scala b/src/main/scala/leon/utils/InliningPhase.scala
index 62e62a9e2..dea267601 100644
--- a/src/main/scala/leon/utils/InliningPhase.scala
+++ b/src/main/scala/leon/utils/InliningPhase.scala
@@ -5,10 +5,10 @@ package leon.utils
 import leon._
 import purescala.Definitions._
 import purescala.Expressions._
-import purescala.Types._
 import purescala.TypeOps._
 import purescala.ExprOps._
 import purescala.DefOps._
+import purescala.Constructors.caseClassSelector
 
 object InliningPhase extends TransformationPhase {
 
@@ -27,14 +27,14 @@ object InliningPhase extends TransformationPhase {
 
     def simplifyImplicitClass(e: Expr) = e match {
       case CaseClassSelector(cct, cc: CaseClass, id) =>
-        Some(CaseClassSelector(cct, cc, id))
+        Some(caseClassSelector(cct, cc, id))
 
       case _ =>
         None
     }
 
     def simplify(e: Expr) = {
-      fixpoint(postMap(simplifyImplicitClass _))(e)
+      fixpoint(postMap(simplifyImplicitClass))(e)
     }
 
     for (fd <- p.definedFunctions) {
diff --git a/src/main/scala/leon/verification/InductionTactic.scala b/src/main/scala/leon/verification/InductionTactic.scala
index c780c647f..7bbd332c6 100644
--- a/src/main/scala/leon/verification/InductionTactic.scala
+++ b/src/main/scala/leon/verification/InductionTactic.scala
@@ -23,7 +23,7 @@ class InductionTactic(vctx: VerificationContext) extends DefaultTactic(vctx) {
   private def selectorsOfParentType(parentType: ClassType, cct: CaseClassType, expr: Expr): Seq[Expr] = {
     val childrenOfSameType = cct.fields.filter(_.getType == parentType)
     for (field <- childrenOfSameType) yield {
-      CaseClassSelector(cct, expr, field.id)
+      caseClassSelector(cct, expr, field.id)
     }
   }
 
-- 
GitLab