diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala index abe7d0ac0351bb9d2c81208e73daadde207fbfa9..eb010351db8a374196eef1582b3670b7cfe96d05 100644 --- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala +++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala @@ -519,12 +519,6 @@ trait CodeExtraction extends ASTExtractors { acd } else { - val ccd = CaseClassDef(id, tparams, parent, sym.isModuleClass).setPos(sym.pos) - - parent.foreach(_.classDef.registerChildren(ccd)) - - - classesToClasses += sym -> ccd val fields = args.map { case (symbol, t) => val tpt = t.tpt @@ -532,7 +526,11 @@ trait CodeExtraction extends ASTExtractors { LeonValDef(FreshIdentifier(symbol.name.toString, tpe).setPos(t.pos)).setPos(t.pos) } - ccd.setFields(fields) + val ccd = CaseClassDef(id, tparams, fields, parent, sym.isModuleClass).setPos(sym.pos) + + parent.foreach(_.classDef.registerChildren(ccd)) + + classesToClasses += sym -> ccd // Validates type parameters parent match { diff --git a/src/main/scala/leon/purescala/DefOps.scala b/src/main/scala/leon/purescala/DefOps.scala index ae84099b75f113fa98e1874072aca6d52e4b333c..e7acc9a42f34f3b014ded101a86484d8747cac83 100644 --- a/src/main/scala/leon/purescala/DefOps.scala +++ b/src/main/scala/leon/purescala/DefOps.scala @@ -22,7 +22,7 @@ object DefOps { case other => pgm.units.find(_.containsDef(df)) } - private def pathFromRoot(df: Definition)(implicit pgm: Program): List[Definition] ={ + private def pathFromRoot(df: Definition)(implicit pgm: Program): List[Definition] = { def rec(from: Definition): List[Definition] = { from :: (if (from == df) { Nil @@ -234,22 +234,22 @@ object DefOps { searchWithin(names, p) case u: UnitDef => - val imports = u.imports.map { - case PackageImport(ls) => ImportPath(ls, true) - case SingleImport(d) => ImportPath(nameToParts(fullName(d)), false) - case WildcardImport(d) => ImportPath(nameToParts(fullName(d)), true) - }.toList - - val inModules = d.subDefinitions.filter(_.id.name == n).flatMap { sd => - searchWithin(ns, sd) - } + val imports = u.imports.map { + case PackageImport(ls) => ImportPath(ls, true) + case SingleImport(d) => ImportPath(nameToParts(fullName(d)), false) + case WildcardImport(d) => ImportPath(nameToParts(fullName(d)), true) + }.toList + + val inModules = d.subDefinitions.filter(_.id.name == n).flatMap { sd => + searchWithin(ns, sd) + } - val namesImported = resolveImports(imports, names) - val nameWithPackage = u.pack ++ names + val namesImported = resolveImports(imports, names) + val nameWithPackage = u.pack ++ names - val allNames = namesImported :+ nameWithPackage + val allNames = namesImported :+ nameWithPackage - allNames.foldLeft(inModules) { _ ++ searchRelative(_, ds) } + allNames.foldLeft(inModules) { _ ++ searchRelative(_, ds) } case d => if (n == d.id.name) { diff --git a/src/main/scala/leon/purescala/Definitions.scala b/src/main/scala/leon/purescala/Definitions.scala index a352e788f3e5e0c785798120f51207068dcb33ec..439e5445c9b7f086d380faa862c00908ce64aee0 100644 --- a/src/main/scala/leon/purescala/Definitions.scala +++ b/src/main/scala/leon/purescala/Definitions.scala @@ -209,11 +209,11 @@ object Definitions { } lazy val algebraicDataTypes : Map[AbstractClassDef, Seq[CaseClassDef]] = defs.collect { - case c@CaseClassDef(_, _, Some(p), _) => c + case c@CaseClassDef(_, _, _, Some(p), _) => c }.groupBy(_.parent.get.classDef) lazy val singleCaseClasses : Seq[CaseClassDef] = defs.collect { - case c @ CaseClassDef(_, _, None, _) => c + case c @ CaseClassDef(_, _, _, None, _) => c } def duplicate = copy(defs = defs map { @@ -284,7 +284,6 @@ object Definitions { case cc : CaseClassDef => { val cc2 = cc.copy() cc.methods foreach { m => cc2.registerMethod(m.duplicate) } - cc2.setFields(cc.fields map { _.copy() }) cc2 } } @@ -310,21 +309,12 @@ object Definitions { /** Case classes/objects. */ case class CaseClassDef(id: Identifier, tparams: Seq[TypeParameterDef], + fields: Seq[ValDef], parent: Option[AbstractClassType], isCaseObject: Boolean) extends ClassDef { - private var _fields = Seq[ValDef]() - - def fields = _fields - - def setFields(fields: Seq[ValDef]) { - _fields = fields - } - - val isAbstract = false - def selectorID2Index(id: Identifier) : Int = { val index = fields.zipWithIndex.find(_._1.id == id).map(_._2) diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala index a5eb097d78f83570c948652f62550eb594aa599f..a1f9eb335f50c33442804fbfb4b681e9ffd369fc 100644 --- a/src/main/scala/leon/purescala/PrettyPrinter.scala +++ b/src/main/scala/leon/purescala/PrettyPrinter.scala @@ -487,7 +487,7 @@ class PrettyPrinter(opts: PrinterOptions, |}""" } - case ccd @ CaseClassDef(id, tparams, parent, isObj) => + case ccd @ CaseClassDef(id, tparams, fields, parent, isObj) => if (isObj) { p"case object $id" } else { @@ -499,7 +499,7 @@ class PrettyPrinter(opts: PrinterOptions, } if (!isObj) { - p"(${ccd.fields})" + p"($fields)" } parent.foreach{ par => diff --git a/src/main/scala/leon/purescala/RestoreMethods.scala b/src/main/scala/leon/purescala/RestoreMethods.scala index 6f4674b71107fd75e19a02b5e1399db66cd8790d..74802b9aff126faf145feb687e7d88a55acd2bab 100644 --- a/src/main/scala/leon/purescala/RestoreMethods.scala +++ b/src/main/scala/leon/purescala/RestoreMethods.scala @@ -6,10 +6,9 @@ package purescala import Definitions._ import Common._ import Expressions._ -import ExprOps.{replaceFromIDs,functionCallsOf} +import ExprOps.replaceFromIDs import DefOps.replaceFunDefs import Types._ -import utils.GraphOps._ object RestoreMethods extends TransformationPhase { @@ -19,9 +18,9 @@ object RestoreMethods extends TransformationPhase { // @TODO: This code probably needs fixing, but is mostly unused and completely untested. def apply(ctx: LeonContext, p: Program) = { - val classMethods = (p.definedFunctions.groupBy(_.origOwner).collect { + val classMethods = p.definedFunctions.groupBy(_.origOwner).collect { case (Some(cd: ClassDef), fds) => cd -> fds - }).toMap + } val fdToMd = for( (cd, fds) <- classMethods; fd <- fds) yield { val md = new FunDef(