From 7b57317a80b88db0b722c2f159a2128f0ced576d Mon Sep 17 00:00:00 2001
From: "Emmanouil (Manos) Koukoutos" <emmanouil.koukoutos@epfl.ch>
Date: Mon, 8 Sep 2014 10:16:58 +0200
Subject: [PATCH] Handle imports during MethodLifting/RestoreMethods

---
 .../scala/leon/purescala/MethodLifting.scala  | 57 +++++++++++--------
 .../scala/leon/purescala/RestoreMethods.scala | 11 +++-
 2 files changed, 43 insertions(+), 25 deletions(-)

diff --git a/src/main/scala/leon/purescala/MethodLifting.scala b/src/main/scala/leon/purescala/MethodLifting.scala
index 3ab73e549..179533af7 100644
--- a/src/main/scala/leon/purescala/MethodLifting.scala
+++ b/src/main/scala/leon/purescala/MethodLifting.scala
@@ -75,31 +75,40 @@ object MethodLifting extends TransformationPhase {
       }(e)
     }
 
-    val newUnits = program.units map { u => u.copy (modules = u.modules map { m =>
-      // We remove methods from class definitions and add corresponding functions
-      val newDefs = m.defs.flatMap {
-        case acd: AbstractClassDef if acd.methods.nonEmpty =>
-          acd +: acd.methods.map(translateMethod(_))
-
-        case ccd: CaseClassDef if ccd.methods.nonEmpty =>
-          ccd +: ccd.methods.map(translateMethod(_))
-
-        case fd: FunDef =>
-          List(translateMethod(fd))
-
-        case d =>
-          List(d)
+    val newUnits = program.units map { u => u.copy (
+      imports = u.imports flatMap {
+        case s@SingleImport(c : ClassDef) =>
+          // If a class is imported, also add the "methods" of this class
+          s :: ( c.methods map { md => SingleImport(mdToFds(md))})    
+        case other => List(other)
+      },
+        
+      modules = u.modules map { m =>
+        // We remove methods from class definitions and add corresponding functions
+        val newDefs = m.defs.flatMap {
+          case acd: AbstractClassDef if acd.methods.nonEmpty =>
+            acd +: acd.methods.map(translateMethod(_))
+  
+          case ccd: CaseClassDef if ccd.methods.nonEmpty =>
+            ccd +: ccd.methods.map(translateMethod(_))
+  
+          case fd: FunDef =>
+            List(translateMethod(fd))
+  
+          case d =>
+            List(d)
+        }
+  
+        // finally, we clear methods from classes
+        m.defs.foreach {
+          case cd: ClassDef =>
+            cd.clearMethods()
+          case _ =>
+        }
+  
+        ModuleDef(m.id, newDefs, m.isStandalone )
       }
-
-      // finally, we clear methods from classes
-      m.defs.foreach {
-        case cd: ClassDef =>
-          cd.clearMethods()
-        case _ =>
-      }
-
-      ModuleDef(m.id, newDefs, m.isStandalone )
-    })}
+    )}
 
     Program(program.id, newUnits)
   }
diff --git a/src/main/scala/leon/purescala/RestoreMethods.scala b/src/main/scala/leon/purescala/RestoreMethods.scala
index 1e94a5cc2..a6989a131 100644
--- a/src/main/scala/leon/purescala/RestoreMethods.scala
+++ b/src/main/scala/leon/purescala/RestoreMethods.scala
@@ -76,6 +76,7 @@ object RestoreMethods extends TransformationPhase {
      * Renew that function map by applying subsituteMethods on its values to obtain correct functions
      */
     val fd2MdFinal = fd2Md.mapValues(substituteMethods)
+    val oldFuns = fd2MdFinal.map{ _._1 }.toSet
     
     // We need a special type of transitive closure, detecting only trans. calls on the same argument
     def transCallsOnSameArg(fd : FunDef) : Set[FunDef] = {
@@ -120,7 +121,15 @@ object RestoreMethods extends TransformationPhase {
       m.copy(defs = m.definedClasses ++ newFuns).copiedFrom(m)    
     }
     
-    p.copy(units = p.units map { u => u.copy(modules = u.modules map refreshModule)})
+    p.copy(units = p.units map { u => u.copy(
+      modules = u.modules map refreshModule,
+      imports = u.imports flatMap {
+        // Don't include imports for functions that became methods
+        case WildcardImport(fd : FunDef) if oldFuns contains fd => None
+        case other => Some(other)
+      }
+    )})
+      
     
   }
 
-- 
GitLab