From e45a3e185ca9d68b7af4a6ff335e3300722c9f15 Mon Sep 17 00:00:00 2001
From: Manos Koukoutos <emmanouil.koukoutos@epfl.ch>
Date: Thu, 9 Jul 2015 16:48:51 +0200
Subject: [PATCH] Improve/fix/use consistently features related to deep type
 hierarchies

---
 .../scala/leon/codegen/CompilationUnit.scala   | 10 ++++++++--
 .../leon/frontends/scalac/CodeExtraction.scala |  4 ++--
 .../scala/leon/purescala/Definitions.scala     | 18 +++++++++++++-----
 .../scala/leon/purescala/MethodLifting.scala   |  3 ++-
 .../leon/solvers/smtlib/SMTLIBSolver.scala     |  4 ++--
 .../leon/solvers/z3/AbstractZ3Solver.scala     |  6 +-----
 .../synthesis/utils/ExpressionGrammar.scala    |  2 +-
 .../leon/termination/ChainComparator.scala     |  2 +-
 .../leon/termination/StructuralSize.scala      | 12 +++---------
 src/main/scala/leon/utils/TypingPhase.scala    | 17 ++++++++++-------
 10 files changed, 43 insertions(+), 35 deletions(-)

diff --git a/src/main/scala/leon/codegen/CompilationUnit.scala b/src/main/scala/leon/codegen/CompilationUnit.scala
index 85625077d..ef820f287 100644
--- a/src/main/scala/leon/codegen/CompilationUnit.scala
+++ b/src/main/scala/leon/codegen/CompilationUnit.scala
@@ -398,7 +398,10 @@ class CompilationUnit(val ctx: LeonContext,
     // First define all classes/ methods/ functions
     for (u <- program.units) {
 
-      for ( cls <- u.definedClassesOrdered ) {
+      for {
+        ch  <- u.classHierarchies
+        cls <- ch
+      } {
         defineClass(cls)
         for (meth <- cls.methods) {
           defToModuleOrClass += meth -> cls
@@ -419,7 +422,10 @@ class CompilationUnit(val ctx: LeonContext,
     // Compile everything
     for (u <- program.units) {
       
-      for (c <- u.definedClassesOrdered) {
+      for {
+        ch <- u.classHierarchies
+        c  <- ch
+      } {
         c match {
           case acd: AbstractClassDef =>
             compileAbstractClassDef(acd)
diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala
index dae5018a0..4d4a0af16 100644
--- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala
+++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala
@@ -500,13 +500,13 @@ trait CodeExtraction extends ASTExtractors {
         val acd = AbstractClassDef(id, tparams, parent).setPos(sym.pos)
 
         classesToClasses += sym -> acd
-        parent.foreach(_.classDef.registerChildren(acd))
+        parent.foreach(_.classDef.registerChild(acd))
 
         acd
       } else {
         val ccd = CaseClassDef(id, tparams, parent, sym.isModuleClass).setPos(sym.pos)
         classesToClasses += sym -> ccd
-        parent.foreach(_.classDef.registerChildren(ccd))
+        parent.foreach(_.classDef.registerChild(ccd))
 
         val fields = args.map { case (symbol, t) =>
           val tpt = t.tpt
diff --git a/src/main/scala/leon/purescala/Definitions.scala b/src/main/scala/leon/purescala/Definitions.scala
index 8b3e6ca9a..6841dccc7 100644
--- a/src/main/scala/leon/purescala/Definitions.scala
+++ b/src/main/scala/leon/purescala/Definitions.scala
@@ -136,7 +136,10 @@ object Definitions {
       definedClasses.filter(!_.hasParent)
     }
 
-    def definedClassesOrdered = classHierarchyRoots flatMap { root => root +: root.knownDescendents }
+    // Guarantees that a parent always appears before its children
+    def classHierarchies = classHierarchyRoots map { root =>
+      root +: root.knownDescendents
+    }
 
     def singleCaseClasses = {
       definedClasses.collect {
@@ -196,7 +199,7 @@ object Definitions {
 
     private var _children: List[ClassDef] = Nil
 
-    def registerChildren(chd: ClassDef) = {
+    def registerChild(chd: ClassDef) = {
       _children = (chd :: _children).sortBy(_.id.name)
     }
 
@@ -239,7 +242,6 @@ object Definitions {
 
     lazy val definedFunctions : Seq[FunDef] = methods
     lazy val definedClasses = Seq(this)
-    lazy val classHierarchyRoots = if (this.hasParent) Seq(this) else Nil
 
     def typed(tps: Seq[TypeTree]): ClassType
     def typed: ClassType
@@ -256,7 +258,10 @@ object Definitions {
     
     lazy val singleCaseClasses : Seq[CaseClassDef] = Nil
 
-    def typed(tps: Seq[TypeTree]) = AbstractClassType(this, tps)
+    def typed(tps: Seq[TypeTree]) = {
+      require(tps.length == tparams.length)
+      AbstractClassType(this, tps)
+    }
     def typed: AbstractClassType = typed(tparams.map(_.tp))
   }
 
@@ -289,7 +294,10 @@ object Definitions {
     
     lazy val singleCaseClasses : Seq[CaseClassDef] = if (hasParent) Nil else Seq(this)
 
-    def typed(tps: Seq[TypeTree]): CaseClassType = CaseClassType(this, tps)
+    def typed(tps: Seq[TypeTree]): CaseClassType = {
+      require(tps.length == tparams.length)
+      CaseClassType(this, tps)
+    }
     def typed: CaseClassType = typed(tparams.map(_.tp))
   }
 
diff --git a/src/main/scala/leon/purescala/MethodLifting.scala b/src/main/scala/leon/purescala/MethodLifting.scala
index 7cf7aea4b..e87c332e9 100644
--- a/src/main/scala/leon/purescala/MethodLifting.scala
+++ b/src/main/scala/leon/purescala/MethodLifting.scala
@@ -83,7 +83,8 @@ object MethodLifting extends TransformationPhase {
 
       // Lift methods to the root class
       for {
-        c <- u.definedClassesOrdered
+        ch <- u.classHierarchies
+        c  <- ch
         if c.parent.isDefined
         fd <- c.methods
         if c.ancestors.forall(!_.methods.map{_.id}.contains(fd.id))
diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala
index 5b14f2f4c..9b78a1fbb 100644
--- a/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala
+++ b/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala
@@ -124,9 +124,9 @@ abstract class SMTLIBSolver(val context: LeonContext,
   /* Helper functions */
 
   protected def normalizeType(t: TypeTree): TypeTree = t match {
-    case ct: ClassType if ct.parent.isDefined => ct.parent.get
+    case ct: ClassType => ct.root
     case tt: TupleType => tupleTypeWrap(tt.bases.map(normalizeType))
-    case _ =>   t
+    case _ => t
   }
 
   protected def quantifiedTerm(
diff --git a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala
index c6f0c342e..9d57eb723 100644
--- a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala
+++ b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala
@@ -174,11 +174,7 @@ trait AbstractZ3Solver
   }
 
   def rootType(ct: TypeTree): TypeTree = ct match {
-    case ct: ClassType =>
-      ct.parent match {
-        case Some(p) => rootType(p)
-        case None => ct
-      }
+    case ct: ClassType => ct.root
     case t => t
   }
 
diff --git a/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala b/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala
index 2c539be22..8e6fec782 100644
--- a/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala
+++ b/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala
@@ -313,7 +313,7 @@ object ExpressionGrammars {
         def ccVariations(gl: L, cc: CaseClass): Seq[(L, Gen)] = {
           val CaseClass(cct, args) = cc
 
-          val neighbors = cct.parent.map(_.knownCCDescendents).getOrElse(Seq()).filter(_ != cct)
+          val neighbors = cct.root.knownCCDescendents diff Seq(cct)
 
           for (scct <- neighbors if scct.fieldsTypes == cct.fieldsTypes) yield {
             gl -> Generator[L, Expr](Nil, { _ => CaseClass(scct, args) })
diff --git a/src/main/scala/leon/termination/ChainComparator.scala b/src/main/scala/leon/termination/ChainComparator.scala
index 2ed2a08de..16d4cb460 100644
--- a/src/main/scala/leon/termination/ChainComparator.scala
+++ b/src/main/scala/leon/termination/ChainComparator.scala
@@ -16,7 +16,7 @@ trait ChainComparator { self : StructuralSize =>
   private object ContainerType {
     def unapply(c: ClassType): Option[(CaseClassType, Seq[(Identifier, TypeTree)])] = c match {
       case cct @ CaseClassType(ccd, _) =>
-        if (cct.fields.exists(arg => isSubtypeOf(arg.getType, cct.parent.getOrElse(c)))) None
+        if (cct.fields.exists(arg => isSubtypeOf(arg.getType, cct.root))) None
         else if (ccd.hasParent && ccd.parent.get.knownDescendents.size > 1) None
         else Some((cct, cct.fields.map(arg => arg.id -> arg.getType)))
       case _ => None
diff --git a/src/main/scala/leon/termination/StructuralSize.scala b/src/main/scala/leon/termination/StructuralSize.scala
index 329eaae06..6e0bed6c0 100644
--- a/src/main/scala/leon/termination/StructuralSize.scala
+++ b/src/main/scala/leon/termination/StructuralSize.scala
@@ -41,15 +41,9 @@ trait StructuralSize {
   def size(expr: Expr) : Expr = {
     def funDef(ct: ClassType, cases: ClassType => Seq[MatchCase]): FunDef = {
       // we want to reuse generic size functions for sub-types
-      val (argumentType, typeParams) = ct match {
-        case (cct : CaseClassType) if cct.parent.isDefined =>
-          val classDef = cct.parent.get.classDef
-          val tparams = classDef.tparams.map(_.tp)
-          (classDef typed tparams, tparams)
-        case (ct : ClassType) =>
-          val tparams = ct.classDef.tparams.map(_.tp)
-          (ct.classDef typed tparams, tparams)
-      }
+      val classDef = ct.root.classDef
+      val argumentType = classDef.typed
+      val typeParams = classDef.tparams.map(_.tp)
 
       sizeCache.get(argumentType) match {
         case Some(fd) => fd
diff --git a/src/main/scala/leon/utils/TypingPhase.scala b/src/main/scala/leon/utils/TypingPhase.scala
index ee7fd600d..ec5d8048d 100644
--- a/src/main/scala/leon/utils/TypingPhase.scala
+++ b/src/main/scala/leon/utils/TypingPhase.scala
@@ -34,9 +34,12 @@ object TypingPhase extends LeonPhase[Program, Program] {
       // Part (1)
       fd.precondition = {
         val argTypesPreconditions = fd.params.flatMap(arg => arg.getType match {
-          case cct : CaseClassType if cct.parent.isDefined => Seq(IsInstanceOf(cct, arg.id.toVariable))
-          case (at : ArrayType) => Seq(GreaterEquals(ArrayLength(arg.id.toVariable), IntLiteral(0)))
-          case _ => Seq()
+          case cct: ClassType if cct.parent.isDefined =>
+            Seq(IsInstanceOf(cct, arg.id.toVariable))
+          case at: ArrayType =>
+            Seq(GreaterEquals(ArrayLength(arg.id.toVariable), IntLiteral(0)))
+          case _ =>
+            Seq()
         })
         argTypesPreconditions match {
           case Nil => fd.precondition
@@ -48,17 +51,17 @@ object TypingPhase extends LeonPhase[Program, Program] {
       }
 
       fd.postcondition = fd.returnType match {
-        case cct : CaseClassType if cct.parent.isDefined => {
-          val resId = FreshIdentifier("res", cct)
+        case ct: ClassType if ct.parent.isDefined => {
+          val resId = FreshIdentifier("res", ct)
           fd.postcondition match {
             case Some(p) =>
               Some(Lambda(Seq(ValDef(resId)), and(
                 application(p, Seq(Variable(resId))),
-                IsInstanceOf(cct, Variable(resId))
+                IsInstanceOf(ct, Variable(resId))
               ).setPos(p)).setPos(p))
 
             case None =>
-              Some(Lambda(Seq(ValDef(resId)), IsInstanceOf(cct, Variable(resId))))
+              Some(Lambda(Seq(ValDef(resId)), IsInstanceOf(ct, Variable(resId))))
           }
         }
         case _ => fd.postcondition
-- 
GitLab