diff --git a/src/main/scala/leon/purescala/MethodLifting.scala b/src/main/scala/leon/purescala/MethodLifting.scala index 3ab73e549b25eb7ff3be969fe8fa98ca0a4b6255..179533af7a4c93565bba34f5ce3135f9f882ee00 100644 --- a/src/main/scala/leon/purescala/MethodLifting.scala +++ b/src/main/scala/leon/purescala/MethodLifting.scala @@ -75,31 +75,40 @@ object MethodLifting extends TransformationPhase { }(e) } - val newUnits = program.units map { u => u.copy (modules = u.modules map { m => - // We remove methods from class definitions and add corresponding functions - val newDefs = m.defs.flatMap { - case acd: AbstractClassDef if acd.methods.nonEmpty => - acd +: acd.methods.map(translateMethod(_)) - - case ccd: CaseClassDef if ccd.methods.nonEmpty => - ccd +: ccd.methods.map(translateMethod(_)) - - case fd: FunDef => - List(translateMethod(fd)) - - case d => - List(d) + val newUnits = program.units map { u => u.copy ( + imports = u.imports flatMap { + case s@SingleImport(c : ClassDef) => + // If a class is imported, also add the "methods" of this class + s :: ( c.methods map { md => SingleImport(mdToFds(md))}) + case other => List(other) + }, + + modules = u.modules map { m => + // We remove methods from class definitions and add corresponding functions + val newDefs = m.defs.flatMap { + case acd: AbstractClassDef if acd.methods.nonEmpty => + acd +: acd.methods.map(translateMethod(_)) + + case ccd: CaseClassDef if ccd.methods.nonEmpty => + ccd +: ccd.methods.map(translateMethod(_)) + + case fd: FunDef => + List(translateMethod(fd)) + + case d => + List(d) + } + + // finally, we clear methods from classes + m.defs.foreach { + case cd: ClassDef => + cd.clearMethods() + case _ => + } + + ModuleDef(m.id, newDefs, m.isStandalone ) } - - // finally, we clear methods from classes - m.defs.foreach { - case cd: ClassDef => - cd.clearMethods() - case _ => - } - - ModuleDef(m.id, newDefs, m.isStandalone ) - })} + )} Program(program.id, newUnits) } diff --git a/src/main/scala/leon/purescala/RestoreMethods.scala b/src/main/scala/leon/purescala/RestoreMethods.scala index 1e94a5cc21176a0e05aea2b2023572033237db98..a6989a13144b61bf427137722392bde65c91a32b 100644 --- a/src/main/scala/leon/purescala/RestoreMethods.scala +++ b/src/main/scala/leon/purescala/RestoreMethods.scala @@ -76,6 +76,7 @@ object RestoreMethods extends TransformationPhase { * Renew that function map by applying subsituteMethods on its values to obtain correct functions */ val fd2MdFinal = fd2Md.mapValues(substituteMethods) + val oldFuns = fd2MdFinal.map{ _._1 }.toSet // We need a special type of transitive closure, detecting only trans. calls on the same argument def transCallsOnSameArg(fd : FunDef) : Set[FunDef] = { @@ -120,7 +121,15 @@ object RestoreMethods extends TransformationPhase { m.copy(defs = m.definedClasses ++ newFuns).copiedFrom(m) } - p.copy(units = p.units map { u => u.copy(modules = u.modules map refreshModule)}) + p.copy(units = p.units map { u => u.copy( + modules = u.modules map refreshModule, + imports = u.imports flatMap { + // Don't include imports for functions that became methods + case WildcardImport(fd : FunDef) if oldFuns contains fd => None + case other => Some(other) + } + )}) + }