From 5d8d8ea1f0800c351c87117353a495208f91bc39 Mon Sep 17 00:00:00 2001
From: Manos Koukoutos <emmanouil.koukoutos@epfl.ch>
Date: Thu, 27 Aug 2015 12:03:58 +0200
Subject: [PATCH] Handle case-class field implementing method

---
 .../leon/frontends/scalac/ASTExtractors.scala | 19 +++++-
 .../frontends/scalac/CodeExtraction.scala     | 54 +++++++----------
 .../scala/leon/purescala/MethodLifting.scala  | 60 ++++++++++++-------
 .../purescala/invalid/FieldInheritance.scala  | 22 +++++++
 .../purescala/valid/FieldInheritance.scala    | 25 ++++++++
 5 files changed, 127 insertions(+), 53 deletions(-)
 create mode 100644 src/test/resources/regression/verification/purescala/invalid/FieldInheritance.scala
 create mode 100644 src/test/resources/regression/verification/purescala/valid/FieldInheritance.scala

diff --git a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala b/src/main/scala/leon/frontends/scalac/ASTExtractors.scala
index d8910012f..066955731 100644
--- a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala
+++ b/src/main/scala/leon/frontends/scalac/ASTExtractors.scala
@@ -350,7 +350,24 @@ trait ASTExtractors {
             case _ => false
           }.get.asInstanceOf[DefDef]
 
-          val args = constructor.vparamss.flatten.map(vd => ( vd.symbol, vd))
+          val valDefs = constructor.vparamss.flatten
+
+          //impl.children foreach println
+
+          val symbols = impl.children.collect {
+            case df: DefDef if df.symbol.isStable && df.symbol.isAccessor &&
+                df.symbol.isParamAccessor =>
+              df.symbol
+          }
+
+          //if (symbols.size != valDefs.size) {
+          //  println(" >>>>> " + cd.name)
+          //  symbols foreach println
+          //  valDefs foreach println
+          //}
+
+          val args = symbols zip valDefs
+
           Some((name.toString, cd.symbol, args, impl))
         }
         case _ => None
diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala
index 62627f8c8..9f9db555e 100644
--- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala
+++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala
@@ -518,10 +518,10 @@ trait CodeExtraction extends ASTExtractors {
         classesToClasses += sym -> ccd
         parent.foreach(_.classDef.registerChild(ccd))
 
-        val fields = args.map { case (symbol, t) =>
-          val tpt = t.tpt
-          val tpe = leonType(tpt.tpe)(defCtx, sym.pos)
-          LeonValDef(FreshIdentifier(symbol.name.toString, tpe).setPos(t.pos)).setPos(t.pos)
+        val fields = args.map { case (fsym, t) =>
+          val tpe = leonType(t.tpt.tpe)(defCtx, fsym.pos)
+          val id = overridenOrFresh(fsym, Some(ccd), tpe)
+          LeonValDef(id.setPos(t.pos), Some(tpe)).setPos(t.pos)
         }
 
         ccd.setFields(fields)
@@ -607,6 +607,20 @@ trait CodeExtraction extends ASTExtractors {
       cd
     }
 
+    // Returns the parent's method Identifier if sym overrides a symbol, otherwise a fresh Identifier
+    private def overridenOrFresh(sym: Symbol, within: Option[LeonClassDef], tpe: LeonType = Untyped) = {
+      val name = sym.name.toString
+      if (sym.overrideChain.length > 1) {
+        (for {
+          cd <- within
+          p <- cd.parent
+          m <- p.classDef.methods.find(_.id.name == name)
+        } yield m.id).getOrElse(FreshIdentifier(name, tpe))
+      } else {
+        FreshIdentifier(name, tpe)
+      }
+    }
+
     private var defsToDefs = Map[Symbol, FunDef]()
 
     private def defineFunDef(sym: Symbol, within: Option[LeonClassDef] = None)(implicit dctx: DefContext): FunDef = {
@@ -626,19 +640,7 @@ trait CodeExtraction extends ASTExtractors {
 
       val returnType = leonType(sym.info.finalResultType)(nctx, sym.pos)
 
-      val name = sym.name.toString
-
-      val id = {
-        if (sym.overrideChain.length > 1) {
-          (for {
-            cd <- within
-            p <- cd.parent
-            m <- p.classDef.methods.find(_.id.name == name)
-          } yield m.id).getOrElse(FreshIdentifier(name))
-        } else {
-          FreshIdentifier(name)
-        }
-      }
+      val id = overridenOrFresh(sym, within)
 
       val fd = new FunDef(id.setPos(sym.pos), tparamsDef, returnType, newParams)
 
@@ -661,19 +663,7 @@ trait CodeExtraction extends ASTExtractors {
 
       val returnType = leonType(sym.info.finalResultType)(nctx, sym.pos)
 
-      val name = sym.name.toString
-
-      val id =
-        if (sym.overrideChain.length == 1) {
-          FreshIdentifier(name)
-        } else {
-          ( for {
-            cd <- within
-            p <- cd.parent
-            m <- p.classDef.methods.find(_.id.name == name)
-          } yield m.id).getOrElse(FreshIdentifier(name))
-        }
-
+      val id = overridenOrFresh(sym, within)
       val fd = new FunDef(id.setPos(sym.pos), Seq(), returnType, Seq())
 
       fd.setPos(sym.pos)
@@ -861,7 +851,7 @@ trait CodeExtraction extends ASTExtractors {
         // case Obj =>
         extractType(s) match {
           case ct: CaseClassType =>
-            assert(ct.classDef.fields.size == 0)
+            assert(ct.classDef.fields.isEmpty)
             (CaseClassPattern(binder, ct, Seq()).setPos(p.pos), dctx)
           case _ =>
             outOfSubsetError(s, "Invalid type "+s.tpe+" for .isInstanceOf")
@@ -1494,7 +1484,7 @@ trait CodeExtraction extends ASTExtractors {
 
           //println(s"symbol $sym with id ${sym.id}")
           //println(s"isMethod($sym) == ${isMethod(sym)}")
-          
+
           (rrec, sym.name.decoded, rargs) match {
             case (null, _, args) =>
               val fd = getFunDef(sym, c.pos)
diff --git a/src/main/scala/leon/purescala/MethodLifting.scala b/src/main/scala/leon/purescala/MethodLifting.scala
index 5f41bc811..38319eb22 100644
--- a/src/main/scala/leon/purescala/MethodLifting.scala
+++ b/src/main/scala/leon/purescala/MethodLifting.scala
@@ -24,25 +24,45 @@ object MethodLifting extends TransformationPhase {
   // A Seq of MatchCases is returned, along with a boolean that signifies if the matching is complete.
   private def makeCases(cd: ClassDef, fdId: Identifier, breakDown: Expr => Expr): (Seq[MatchCase], Boolean) = cd match {
     case ccd: CaseClassDef =>
-      ccd.methods.find( _.id == fdId) match {
-        case None =>
-          (List(), false)
-        case Some(m) =>
-          val ct = ccd.typed
-          val binder = FreshIdentifier(ccd.id.name.toLowerCase, ct, true)
-          val fBinders = ct.fields.map{ f => f.id -> f.id.freshen }.toMap
-          def subst(e: Expr): Expr = e match {
-            case CaseClassSelector(`ct`, This(`ct`), i) =>
-              Variable(fBinders(i)).setPos(e)
-            case This(`ct`) =>
-              Variable(binder).setPos(e)
-            case e =>
-              e
-          }
-          val newE = simplePreTransform(subst)(breakDown(m.fullBody))
-          val subPatts = ct.fields map (f => WildcardPattern(Some(fBinders(f.id))))
-          val cse = SimpleCase(CaseClassPattern(Some(binder), ct, subPatts), newE).setPos(newE)
-          (List(cse), true)
+
+      // Common for both cases
+      val ct = ccd.typed
+      val binder = FreshIdentifier(ccd.id.name.toLowerCase, ct, true)
+      val fBinders = ct.fields.map{ f => f.id -> f.id.freshen }.toMap
+      def subst(e: Expr): Expr = e match {
+        case CaseClassSelector(`ct`, This(`ct`), i) =>
+          Variable(fBinders(i)).setPos(e)
+        case This(`ct`) =>
+          Variable(binder).setPos(e)
+        case e =>
+          e
+      }
+
+      ccd.methods.find( _.id == fdId).map { m =>
+
+        // Ancestor's method is a method in the case class
+        val subPatts = ct.fields map (f => WildcardPattern(Some(fBinders(f.id))))
+        val patt = CaseClassPattern(Some(binder), ct, subPatts)
+        val newE = simplePreTransform(subst)(breakDown(m.fullBody))
+        val cse = SimpleCase(patt, newE).setPos(newE)
+        (List(cse), true)
+
+      } orElse ccd.fields.find( _.id == fdId).map { f =>
+
+        // Ancestor's method is a case class argument in the case class
+        val subPatts = ct.fields map (fld =>
+          if (fld.id == f.id)
+            WildcardPattern(Some(fBinders(f.id)))
+          else
+            WildcardPattern(None)
+        )
+        val patt = CaseClassPattern(Some(binder), ct, subPatts)
+        val newE = breakDown(Variable(fBinders(f.id)))
+        val cse = SimpleCase(patt, newE).setPos(newE)
+        (List(cse), true)
+
+      } getOrElse {
+        (List(), false)
       }
     case acd: AbstractClassDef =>
       val (r, c) = acd.knownChildren.map(makeCases(_, fdId, breakDown)).unzip
@@ -139,7 +159,7 @@ object MethodLifting extends TransformationPhase {
         nfd.setPos(fd)
         nfd.addFlag(IsMethod(cd))
 
-        if (cd.knownDescendants.forall( _.methods.forall(_.id != fd.id))) {
+        if (cd.knownDescendants.forall( cd => (cd.methods ++ cd.fields).forall(_.id != fd.id))) {
           val paramsMap = fd.params.zip(fdParams).map{case (x,y) => (x.id, y.id)}.toMap
           // Don't need to compose methods
           nfd.fullBody = postMap {
diff --git a/src/test/resources/regression/verification/purescala/invalid/FieldInheritance.scala b/src/test/resources/regression/verification/purescala/invalid/FieldInheritance.scala
new file mode 100644
index 000000000..7ef9d9522
--- /dev/null
+++ b/src/test/resources/regression/verification/purescala/invalid/FieldInheritance.scala
@@ -0,0 +1,22 @@
+import leon.lang._
+
+object FieldInheritance {
+
+  abstract class Foo[B] {
+    val thisIsIt: BigInt = 1
+    val y: BigInt
+    val weird: B
+  }
+
+  case class Bar[X](override val thisIsIt: BigInt, weird: X) extends Foo[X] {
+    val y = thisIsIt
+  }
+  
+  case class Baz[X](weird: X) extends Foo[X] {
+    val y = thisIsIt + 1
+  }
+
+
+  def foo[A](f: Foo[A]) = (f.thisIsIt == 1).holds
+
+}
diff --git a/src/test/resources/regression/verification/purescala/valid/FieldInheritance.scala b/src/test/resources/regression/verification/purescala/valid/FieldInheritance.scala
new file mode 100644
index 000000000..9b99fc84c
--- /dev/null
+++ b/src/test/resources/regression/verification/purescala/valid/FieldInheritance.scala
@@ -0,0 +1,25 @@
+import leon.lang._
+
+object FieldInheritance {
+
+  abstract class Foo[B] {
+    val thisIsIt: BigInt = 1
+    val y: BigInt
+    val weird: B
+  }
+
+  case class Bar[X](override val thisIsIt: BigInt, weird: X) extends Foo[X] {
+    val y = thisIsIt
+  }
+  
+  case class Baz[X](weird: X) extends Foo[X] {
+    val y = thisIsIt + 1
+  }
+
+
+  def foo[A](f: Foo[A]) = { f match {
+    case Bar(t, _) => f.thisIsIt == t
+    case _ => true
+  }}.holds
+
+}
-- 
GitLab