From e45a3e185ca9d68b7af4a6ff335e3300722c9f15 Mon Sep 17 00:00:00 2001 From: Manos Koukoutos <emmanouil.koukoutos@epfl.ch> Date: Thu, 9 Jul 2015 16:48:51 +0200 Subject: [PATCH] Improve/fix/use consistently features related to deep type hierarchies --- .../scala/leon/codegen/CompilationUnit.scala | 10 ++++++++-- .../leon/frontends/scalac/CodeExtraction.scala | 4 ++-- .../scala/leon/purescala/Definitions.scala | 18 +++++++++++++----- .../scala/leon/purescala/MethodLifting.scala | 3 ++- .../leon/solvers/smtlib/SMTLIBSolver.scala | 4 ++-- .../leon/solvers/z3/AbstractZ3Solver.scala | 6 +----- .../synthesis/utils/ExpressionGrammar.scala | 2 +- .../leon/termination/ChainComparator.scala | 2 +- .../leon/termination/StructuralSize.scala | 12 +++--------- src/main/scala/leon/utils/TypingPhase.scala | 17 ++++++++++------- 10 files changed, 43 insertions(+), 35 deletions(-) diff --git a/src/main/scala/leon/codegen/CompilationUnit.scala b/src/main/scala/leon/codegen/CompilationUnit.scala index 85625077d..ef820f287 100644 --- a/src/main/scala/leon/codegen/CompilationUnit.scala +++ b/src/main/scala/leon/codegen/CompilationUnit.scala @@ -398,7 +398,10 @@ class CompilationUnit(val ctx: LeonContext, // First define all classes/ methods/ functions for (u <- program.units) { - for ( cls <- u.definedClassesOrdered ) { + for { + ch <- u.classHierarchies + cls <- ch + } { defineClass(cls) for (meth <- cls.methods) { defToModuleOrClass += meth -> cls @@ -419,7 +422,10 @@ class CompilationUnit(val ctx: LeonContext, // Compile everything for (u <- program.units) { - for (c <- u.definedClassesOrdered) { + for { + ch <- u.classHierarchies + c <- ch + } { c match { case acd: AbstractClassDef => compileAbstractClassDef(acd) diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala index dae5018a0..4d4a0af16 100644 --- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala +++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala @@ -500,13 +500,13 @@ trait CodeExtraction extends ASTExtractors { val acd = AbstractClassDef(id, tparams, parent).setPos(sym.pos) classesToClasses += sym -> acd - parent.foreach(_.classDef.registerChildren(acd)) + parent.foreach(_.classDef.registerChild(acd)) acd } else { val ccd = CaseClassDef(id, tparams, parent, sym.isModuleClass).setPos(sym.pos) classesToClasses += sym -> ccd - parent.foreach(_.classDef.registerChildren(ccd)) + parent.foreach(_.classDef.registerChild(ccd)) val fields = args.map { case (symbol, t) => val tpt = t.tpt diff --git a/src/main/scala/leon/purescala/Definitions.scala b/src/main/scala/leon/purescala/Definitions.scala index 8b3e6ca9a..6841dccc7 100644 --- a/src/main/scala/leon/purescala/Definitions.scala +++ b/src/main/scala/leon/purescala/Definitions.scala @@ -136,7 +136,10 @@ object Definitions { definedClasses.filter(!_.hasParent) } - def definedClassesOrdered = classHierarchyRoots flatMap { root => root +: root.knownDescendents } + // Guarantees that a parent always appears before its children + def classHierarchies = classHierarchyRoots map { root => + root +: root.knownDescendents + } def singleCaseClasses = { definedClasses.collect { @@ -196,7 +199,7 @@ object Definitions { private var _children: List[ClassDef] = Nil - def registerChildren(chd: ClassDef) = { + def registerChild(chd: ClassDef) = { _children = (chd :: _children).sortBy(_.id.name) } @@ -239,7 +242,6 @@ object Definitions { lazy val definedFunctions : Seq[FunDef] = methods lazy val definedClasses = Seq(this) - lazy val classHierarchyRoots = if (this.hasParent) Seq(this) else Nil def typed(tps: Seq[TypeTree]): ClassType def typed: ClassType @@ -256,7 +258,10 @@ object Definitions { lazy val singleCaseClasses : Seq[CaseClassDef] = Nil - def typed(tps: Seq[TypeTree]) = AbstractClassType(this, tps) + def typed(tps: Seq[TypeTree]) = { + require(tps.length == tparams.length) + AbstractClassType(this, tps) + } def typed: AbstractClassType = typed(tparams.map(_.tp)) } @@ -289,7 +294,10 @@ object Definitions { lazy val singleCaseClasses : Seq[CaseClassDef] = if (hasParent) Nil else Seq(this) - def typed(tps: Seq[TypeTree]): CaseClassType = CaseClassType(this, tps) + def typed(tps: Seq[TypeTree]): CaseClassType = { + require(tps.length == tparams.length) + CaseClassType(this, tps) + } def typed: CaseClassType = typed(tparams.map(_.tp)) } diff --git a/src/main/scala/leon/purescala/MethodLifting.scala b/src/main/scala/leon/purescala/MethodLifting.scala index 7cf7aea4b..e87c332e9 100644 --- a/src/main/scala/leon/purescala/MethodLifting.scala +++ b/src/main/scala/leon/purescala/MethodLifting.scala @@ -83,7 +83,8 @@ object MethodLifting extends TransformationPhase { // Lift methods to the root class for { - c <- u.definedClassesOrdered + ch <- u.classHierarchies + c <- ch if c.parent.isDefined fd <- c.methods if c.ancestors.forall(!_.methods.map{_.id}.contains(fd.id)) diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala index 5b14f2f4c..9b78a1fbb 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala @@ -124,9 +124,9 @@ abstract class SMTLIBSolver(val context: LeonContext, /* Helper functions */ protected def normalizeType(t: TypeTree): TypeTree = t match { - case ct: ClassType if ct.parent.isDefined => ct.parent.get + case ct: ClassType => ct.root case tt: TupleType => tupleTypeWrap(tt.bases.map(normalizeType)) - case _ => t + case _ => t } protected def quantifiedTerm( diff --git a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala index c6f0c342e..9d57eb723 100644 --- a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala @@ -174,11 +174,7 @@ trait AbstractZ3Solver } def rootType(ct: TypeTree): TypeTree = ct match { - case ct: ClassType => - ct.parent match { - case Some(p) => rootType(p) - case None => ct - } + case ct: ClassType => ct.root case t => t } diff --git a/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala b/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala index 2c539be22..8e6fec782 100644 --- a/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala +++ b/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala @@ -313,7 +313,7 @@ object ExpressionGrammars { def ccVariations(gl: L, cc: CaseClass): Seq[(L, Gen)] = { val CaseClass(cct, args) = cc - val neighbors = cct.parent.map(_.knownCCDescendents).getOrElse(Seq()).filter(_ != cct) + val neighbors = cct.root.knownCCDescendents diff Seq(cct) for (scct <- neighbors if scct.fieldsTypes == cct.fieldsTypes) yield { gl -> Generator[L, Expr](Nil, { _ => CaseClass(scct, args) }) diff --git a/src/main/scala/leon/termination/ChainComparator.scala b/src/main/scala/leon/termination/ChainComparator.scala index 2ed2a08de..16d4cb460 100644 --- a/src/main/scala/leon/termination/ChainComparator.scala +++ b/src/main/scala/leon/termination/ChainComparator.scala @@ -16,7 +16,7 @@ trait ChainComparator { self : StructuralSize => private object ContainerType { def unapply(c: ClassType): Option[(CaseClassType, Seq[(Identifier, TypeTree)])] = c match { case cct @ CaseClassType(ccd, _) => - if (cct.fields.exists(arg => isSubtypeOf(arg.getType, cct.parent.getOrElse(c)))) None + if (cct.fields.exists(arg => isSubtypeOf(arg.getType, cct.root))) None else if (ccd.hasParent && ccd.parent.get.knownDescendents.size > 1) None else Some((cct, cct.fields.map(arg => arg.id -> arg.getType))) case _ => None diff --git a/src/main/scala/leon/termination/StructuralSize.scala b/src/main/scala/leon/termination/StructuralSize.scala index 329eaae06..6e0bed6c0 100644 --- a/src/main/scala/leon/termination/StructuralSize.scala +++ b/src/main/scala/leon/termination/StructuralSize.scala @@ -41,15 +41,9 @@ trait StructuralSize { def size(expr: Expr) : Expr = { def funDef(ct: ClassType, cases: ClassType => Seq[MatchCase]): FunDef = { // we want to reuse generic size functions for sub-types - val (argumentType, typeParams) = ct match { - case (cct : CaseClassType) if cct.parent.isDefined => - val classDef = cct.parent.get.classDef - val tparams = classDef.tparams.map(_.tp) - (classDef typed tparams, tparams) - case (ct : ClassType) => - val tparams = ct.classDef.tparams.map(_.tp) - (ct.classDef typed tparams, tparams) - } + val classDef = ct.root.classDef + val argumentType = classDef.typed + val typeParams = classDef.tparams.map(_.tp) sizeCache.get(argumentType) match { case Some(fd) => fd diff --git a/src/main/scala/leon/utils/TypingPhase.scala b/src/main/scala/leon/utils/TypingPhase.scala index ee7fd600d..ec5d8048d 100644 --- a/src/main/scala/leon/utils/TypingPhase.scala +++ b/src/main/scala/leon/utils/TypingPhase.scala @@ -34,9 +34,12 @@ object TypingPhase extends LeonPhase[Program, Program] { // Part (1) fd.precondition = { val argTypesPreconditions = fd.params.flatMap(arg => arg.getType match { - case cct : CaseClassType if cct.parent.isDefined => Seq(IsInstanceOf(cct, arg.id.toVariable)) - case (at : ArrayType) => Seq(GreaterEquals(ArrayLength(arg.id.toVariable), IntLiteral(0))) - case _ => Seq() + case cct: ClassType if cct.parent.isDefined => + Seq(IsInstanceOf(cct, arg.id.toVariable)) + case at: ArrayType => + Seq(GreaterEquals(ArrayLength(arg.id.toVariable), IntLiteral(0))) + case _ => + Seq() }) argTypesPreconditions match { case Nil => fd.precondition @@ -48,17 +51,17 @@ object TypingPhase extends LeonPhase[Program, Program] { } fd.postcondition = fd.returnType match { - case cct : CaseClassType if cct.parent.isDefined => { - val resId = FreshIdentifier("res", cct) + case ct: ClassType if ct.parent.isDefined => { + val resId = FreshIdentifier("res", ct) fd.postcondition match { case Some(p) => Some(Lambda(Seq(ValDef(resId)), and( application(p, Seq(Variable(resId))), - IsInstanceOf(cct, Variable(resId)) + IsInstanceOf(ct, Variable(resId)) ).setPos(p)).setPos(p)) case None => - Some(Lambda(Seq(ValDef(resId)), IsInstanceOf(cct, Variable(resId)))) + Some(Lambda(Seq(ValDef(resId)), IsInstanceOf(ct, Variable(resId)))) } } case _ => fd.postcondition -- GitLab