diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala index c14ade4c7cf6c77e4b2dc91b75b2ed0b582d12a0..3a94aa0fcd45bbb412a04fa69cecc5c232bf6d67 100644 --- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala +++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala @@ -396,7 +396,7 @@ trait CodeExtraction extends ASTExtractors { case (sym, cl) => sym.fullName.toString == "leon.lang.synthesis.Oracle" } match { case Some((_, cd)) => - classDefToClassType(cd, List(tpe)) + cd.typed(List(tpe)) case None => outOfSubsetError(pos, "Could not find class Oracle") } @@ -1700,7 +1700,7 @@ trait CodeExtraction extends ASTExtractors { case tt: ThisType => val cd = getClassDef(tt.sym, pos) - classDefToClassType(cd, cd.tparams.map(_.tp)) // Typed using own's type parameters + cd.typed // Typed using own's type parameters case SingleType(_, sym) => getClassType(sym.moduleClass, Nil) @@ -1737,7 +1737,7 @@ trait CodeExtraction extends ASTExtractors { private def getClassType(sym: Symbol, tps: List[LeonType])(implicit dctx: DefContext) = { if (seenClasses contains sym) { - classDefToClassType(getClassDef(sym, NoPosition), tps) + getClassDef(sym, NoPosition).typed(tps) } else { outOfSubsetError(NoPosition, "Unknown class "+sym.fullName) } diff --git a/src/main/scala/leon/purescala/Definitions.scala b/src/main/scala/leon/purescala/Definitions.scala index ab8ab0cf9d754eea80fa9c0af724bcd7084bf76f..2d37cce58ed9023e86d6a46fb9b2fd5161198e95 100644 --- a/src/main/scala/leon/purescala/Definitions.scala +++ b/src/main/scala/leon/purescala/Definitions.scala @@ -79,9 +79,9 @@ object Definitions { lazy val callGraph = new CallGraph(this) - def caseClassDef(name: String) = definedClasses.collect { + def caseClassDef(name: String) = definedClasses.collectFirst { case ccd: CaseClassDef if ccd.id.name == name => ccd - }.headOption.getOrElse(throw LeonFatalError("Unknown case class '"+name+"'")) + }.getOrElse(throw LeonFatalError("Unknown case class '"+name+"'")) def lookupAll(name: String) = DefOps.searchWithin(name, this) def lookup(name: String) = lookupAll(name).headOption @@ -241,6 +241,8 @@ object Definitions { lazy val definedClasses = Seq(this) lazy val classHierarchyRoots = if (this.hasParent) Seq(this) else Nil + def typed(tps: Seq[TypeTree]): ClassType + def typed: ClassType } /** Abstract classes. */ @@ -253,6 +255,9 @@ object Definitions { val isCaseObject = false lazy val singleCaseClasses : Seq[CaseClassDef] = Nil + + def typed(tps: Seq[TypeTree]) = AbstractClassType(this, tps) + def typed: AbstractClassType = typed(tparams.map(_.tp)) } /** Case classes/objects. */ @@ -282,6 +287,9 @@ object Definitions { } lazy val singleCaseClasses : Seq[CaseClassDef] = if (hasParent) Nil else Seq(this) + + def typed(tps: Seq[TypeTree]): CaseClassType = CaseClassType(this, tps) + def typed: CaseClassType = typed(tparams.map(_.tp)) } // A class that represents flags that annotate a FunDef with different attributes diff --git a/src/main/scala/leon/purescala/ExprOps.scala b/src/main/scala/leon/purescala/ExprOps.scala index 84d7f515608f06e289756282e4d14990fd8ab482..fb25f9d52e098636b9b18dd78f3ecc068cb11dc7 100644 --- a/src/main/scala/leon/purescala/ExprOps.scala +++ b/src/main/scala/leon/purescala/ExprOps.scala @@ -938,7 +938,7 @@ object ExprOps { val orderedChildren = nonRecChildren.sortBy(_.fields.size) - simplestValue(classDefToClassType(orderedChildren.head, tpe)) + simplestValue(orderedChildren.head.typed(tpe)) case cct: CaseClassType => CaseClass(cct, cct.fieldsTypes.map(t => simplestValue(t))) diff --git a/src/main/scala/leon/purescala/MethodLifting.scala b/src/main/scala/leon/purescala/MethodLifting.scala index cb124cc5bb0099ff656f6ec6461fb05881927493..1ab85aed884aad3bf72b5cca65d9b0f8d532e726 100644 --- a/src/main/scala/leon/purescala/MethodLifting.scala +++ b/src/main/scala/leon/purescala/MethodLifting.scala @@ -28,7 +28,7 @@ object MethodLifting extends TransformationPhase { case None => (List(), false) case Some(m) => - val ct = classDefToClassType(ccd).asInstanceOf[CaseClassType] + val ct = ccd.typed val binder = FreshIdentifier(ccd.id.name.toLowerCase, ct, true) val fBinders = ct.fields.map{ f => f.id -> f.id.freshen }.toMap def subst(e: Expr): Expr = e match { @@ -57,7 +57,7 @@ object MethodLifting extends TransformationPhase { } else { // We have something to add val m = acd.methods.find( m => m.id == fdId ).get - val at = classDefToClassType(acd).asInstanceOf[AbstractClassType] + val at = acd.typed val binder = FreshIdentifier(acd.id.name.toLowerCase, at, true) def subst(e: Expr): Expr = e match { case This(`at`) => @@ -107,8 +107,8 @@ object MethodLifting extends TransformationPhase { Some(and( prec, IsInstanceOf( - classDefToClassType(c,root.tparams.map{ _.tp }), - This(classDefToClassType(root)) + c.typed(root.tparams.map{ _.tp }), + This(root.typed) ) )) )) @@ -124,7 +124,7 @@ object MethodLifting extends TransformationPhase { val tparamsMap = cd.tparams.zip(ctParams map { _.tp }).toMap val id = fd.id.freshen - val recType = classDefToClassType(cd, ctParams.map(_.tp)) + val recType = cd.typed(ctParams.map(_.tp)) val retType = instantiateType(fd.returnType, tparamsMap) val fdParams = fd.params map { vd => val newId = FreshIdentifier(vd.id.name, instantiateType(vd.id.getType, tparamsMap)) diff --git a/src/main/scala/leon/purescala/Types.scala b/src/main/scala/leon/purescala/Types.scala index ed9f730220c32b1a2e4ceb9cba8432aa5551e7f9..54e633e08bfa893178d984d9d1ab12925452a630 100644 --- a/src/main/scala/leon/purescala/Types.scala +++ b/src/main/scala/leon/purescala/Types.scala @@ -88,17 +88,14 @@ object Types { if (tmap.isEmpty) { classDef.fields } else { - // !! WARNING !! - // vd.id.getType will NOT match vd.tpe, but we kind of need this for selectorID2Index... - // See with Etienne about changing this! - // @mk Fixed this + // This is the only case where ValDef overrides the type of its Identifier classDef.fields.map(vd => ValDef(vd.id, Some(instantiateType(vd.getType, tmap)))) } } - def knownDescendents = classDef.knownDescendents.map(classDefToClassType(_, tps)) + def knownDescendents = classDef.knownDescendents.map( _.typed(tps) ) - def knownCCDescendents = classDef.knownCCDescendents.map(CaseClassType(_, tps)) + def knownCCDescendents = classDef.knownCCDescendents.map( _.typed(tps) ) lazy val fieldsTypes = fields.map(_.getType) @@ -113,17 +110,7 @@ object Types { } case class AbstractClassType(classDef: AbstractClassDef, tps: Seq[TypeTree]) extends ClassType - case class CaseClassType(override val classDef: CaseClassDef, tps: Seq[TypeTree]) extends ClassType - - def classDefToClassType(cd: ClassDef, tps: Seq[TypeTree]): ClassType = cd match { - case a: AbstractClassDef => AbstractClassType(a, tps) - case c: CaseClassDef => CaseClassType(c, tps) - } - - // Using definition types - def classDefToClassType(cd: ClassDef): ClassType = { - classDefToClassType(cd, cd.tparams.map(_.tp)) - } + case class CaseClassType(classDef: CaseClassDef, tps: Seq[TypeTree]) extends ClassType object NAryType { def unapply(t: TypeTree): Option[(Seq[TypeTree], Seq[TypeTree] => TypeTree)] = t match { diff --git a/src/main/scala/leon/termination/StructuralSize.scala b/src/main/scala/leon/termination/StructuralSize.scala index a6b08d5986b1cc9dbac5736e273697c6df300298..a9bd1c5c4d2dc2dbb9ee1ce8a8d035f6f3905c2f 100644 --- a/src/main/scala/leon/termination/StructuralSize.scala +++ b/src/main/scala/leon/termination/StructuralSize.scala @@ -45,10 +45,10 @@ trait StructuralSize { case (cct : CaseClassType) if cct.parent.isDefined => val classDef = cct.parent.get.classDef val tparams = classDef.tparams.map(_.tp) - (classDefToClassType(classDef, tparams), tparams) + (classDef typed tparams, tparams) case (ct : ClassType) => val tparams = ct.classDef.tparams.map(_.tp) - (classDefToClassType(ct.classDef, tparams), tparams) + (ct.classDef typed tparams, tparams) } sizeCache.get(argumentType) match { diff --git a/src/test/scala/leon/test/purescala/DataGenSuite.scala b/src/test/scala/leon/test/purescala/DataGenSuite.scala index 2ba0bf0178f5589d1b4c8e07004a46232ca94633..6784078c4fc3f10c6e53c5bdcec23be63c214b58 100644 --- a/src/test/scala/leon/test/purescala/DataGenSuite.scala +++ b/src/test/scala/leon/test/purescala/DataGenSuite.scala @@ -67,7 +67,7 @@ class DataGenSuite extends LeonTestSuite { // Make sure we target our own lists val module = prog.units.flatMap{_.modules}.find(_.id.name == "Program").get - val listType : TypeTree = classDefToClassType(module.classHierarchyRoots.head) + val listType : TypeTree = module.classHierarchyRoots.head.typed val sizeDef : FunDef = module.definedFunctions.find(_.id.name == "size").get val sortedDef : FunDef = module.definedFunctions.find(_.id.name == "isSorted").get val contentDef : FunDef = module.definedFunctions.find(_.id.name == "content").get