diff --git a/src/main/scala/leon/purescala/Definitions.scala b/src/main/scala/leon/purescala/Definitions.scala index 86c6cf16fd46cad2b0fe8a4bd52c1398c119c785..450114a3b6c10c3cd08be62f598f1ccbc807a8b5 100644 --- a/src/main/scala/leon/purescala/Definitions.scala +++ b/src/main/scala/leon/purescala/Definitions.scala @@ -116,7 +116,7 @@ object Definitions { abstract class Import extends Definition { def subDefinitions = Nil - lazy val importedDefs = this match { + def importedDefs = this match { case PackageImport(pack) => { import DefOps._ // Ignore standalone modules, assume there are extra imports for them diff --git a/src/main/scala/leon/purescala/MethodLifting.scala b/src/main/scala/leon/purescala/MethodLifting.scala index 179533af7a4c93565bba34f5ce3135f9f882ee00..81b2e92f86f823a3fc7e19a08ff9a4df84cfdb18 100644 --- a/src/main/scala/leon/purescala/MethodLifting.scala +++ b/src/main/scala/leon/purescala/MethodLifting.scala @@ -75,39 +75,48 @@ object MethodLifting extends TransformationPhase { }(e) } - val newUnits = program.units map { u => u.copy ( + val modsToMods = ( for { + u <- program.units + m <- u.modules + } yield (m, { + // We remove methods from class definitions and add corresponding functions + val newDefs = m.defs.flatMap { + case acd: AbstractClassDef if acd.methods.nonEmpty => + acd +: acd.methods.map(translateMethod(_)) + + case ccd: CaseClassDef if ccd.methods.nonEmpty => + ccd +: ccd.methods.map(translateMethod(_)) + + case fd: FunDef => + List(translateMethod(fd)) + + case d => + List(d) + } + + // finally, we clear methods from classes + m.defs.foreach { + case cd: ClassDef => + cd.clearMethods() + case _ => + } + ModuleDef(m.id, newDefs, m.isStandalone ) + })).toMap + + val newUnits = program.units map { u => u.copy( + imports = u.imports flatMap { case s@SingleImport(c : ClassDef) => // If a class is imported, also add the "methods" of this class - s :: ( c.methods map { md => SingleImport(mdToFds(md))}) + s :: ( c.methods map { md => SingleImport(mdToFds(md))}) + // If importing a ModuleDef, update to new ModuleDef + case SingleImport(m : ModuleDef) => List(SingleImport(modsToMods(m))) + case WildcardImport(m : ModuleDef) => List(WildcardImport(modsToMods(m))) case other => List(other) }, - - modules = u.modules map { m => - // We remove methods from class definitions and add corresponding functions - val newDefs = m.defs.flatMap { - case acd: AbstractClassDef if acd.methods.nonEmpty => - acd +: acd.methods.map(translateMethod(_)) - - case ccd: CaseClassDef if ccd.methods.nonEmpty => - ccd +: ccd.methods.map(translateMethod(_)) - - case fd: FunDef => - List(translateMethod(fd)) - - case d => - List(d) - } - - // finally, we clear methods from classes - m.defs.foreach { - case cd: ClassDef => - cd.clearMethods() - case _ => - } - - ModuleDef(m.id, newDefs, m.isStandalone ) - } + + modules = u.modules map modsToMods + )} Program(program.id, newUnits)