From e782ef6f744e13e6d0085bcf7792d3b18af24d63 Mon Sep 17 00:00:00 2001
From: Etienne Kneuss <ekneuss@gmail.com>
Date: Mon, 10 Dec 2012 16:41:43 +0100
Subject: [PATCH] Provide all classes to CompilationUnit

---
 .../scala/leon/codegen/CodeGeneration.scala   |  5 +++-
 .../scala/leon/codegen/CompilationUnit.scala  | 28 +++++++++++++------
 2 files changed, 24 insertions(+), 9 deletions(-)

diff --git a/src/main/scala/leon/codegen/CodeGeneration.scala b/src/main/scala/leon/codegen/CodeGeneration.scala
index 8858051af..84a74702d 100644
--- a/src/main/scala/leon/codegen/CodeGeneration.scala
+++ b/src/main/scala/leon/codegen/CodeGeneration.scala
@@ -43,8 +43,11 @@ object CodeGeneration {
       case Int32Type | BooleanType =>
         ch << IRETURN
 
+      case UnitType | TupleType(_)  | SetType(_) | MapType(_, _) | AbstractClassType(_) | CaseClassType(_) => 
+        ch << ARETURN
+
       case other =>
-        throw CompilationException("Unsupported return type : " + other)
+        throw CompilationException("Unsupported return type : " + other.getClass)
     }
 
     ch.freeze
diff --git a/src/main/scala/leon/codegen/CompilationUnit.scala b/src/main/scala/leon/codegen/CompilationUnit.scala
index f0d9b17ee..f7ae846c2 100644
--- a/src/main/scala/leon/codegen/CompilationUnit.scala
+++ b/src/main/scala/leon/codegen/CompilationUnit.scala
@@ -14,14 +14,14 @@ import cafebabe.Flags._
 
 import CodeGeneration._
 
-class CompilationUnit(val program: Program, val mainClass: ClassFile, implicit val env: CompilationEnvironment) {
-  val mainClassName = defToJVMName(program, program.mainObject)
-
+class CompilationUnit(val program: Program, val classes: Seq[ClassFile], implicit val env: CompilationEnvironment) {
   val loader = new CafebabeClassLoader
-  loader.register(mainClass)
+  classes.foreach(loader.register(_))
 
   def writeClassFiles() {
-    mainClass.writeToFile(mainClassName + ".class")
+    for (cl <- classes) {
+      cl.writeToFile(cl.className + ".class")
+    }
   }
 
   private var _nextExprId = 0
@@ -41,8 +41,12 @@ class CompilationUnit(val program: Program, val mainClass: ClassFile, implicit v
     case b: java.lang.Boolean =>
       BooleanLiteral(b.booleanValue)
 
+    case cc: runtime.CaseClass =>
+      println("YAY")
+      throw CompilationException("YAY Unsupported return value : " + e)
+
     case _ => 
-      throw CompilationException("Unsupported return value : " + e)
+      throw CompilationException("MEH Unsupported return value : " + e.getClass)
   }
 
   def compileExpression(e: Expr, args: Seq[Identifier]): CompiledExpression = {
@@ -83,7 +87,7 @@ class CompilationUnit(val program: Program, val mainClass: ClassFile, implicit v
       case Int32Type | BooleanType =>
         ch << IRETURN
 
-      case UnitType | TupleType(_)  | SetType(_) | MapType(_, _) => 
+      case UnitType | TupleType(_)  | SetType(_) | MapType(_, _) | AbstractClassType(_) | CaseClassType(_) => 
         ch << ARETURN
 
       case other =>
@@ -102,13 +106,21 @@ object CompilationUnit {
   def compileProgram(p: Program): Option[CompilationUnit] = {
     implicit val env = CompilationEnvironment.fromProgram(p)
 
+    var classes = Seq[ClassFile]()
+
     for((parent,children) <- p.algebraicDataTypes) {
       val acf = compileAbstractClassDef(p, parent)
       val ccfs = children.map(c => compileCaseClassDef(p, c))
+
+      classes = classes :+ acf
+      classes = classes ++ ccfs
     } 
 
     val mainClassName = defToJVMName(p, p.mainObject)
     val cf = new ClassFile(mainClassName, None)
+
+    classes = classes :+ cf
+
     cf.addDefaultConstructor
 
     cf.setFlags((
@@ -136,6 +148,6 @@ object CompilationUnit {
       compileFunDef(funDef, m.codeHandler)
     }
 
-    Some(new CompilationUnit(p, cf, env))
+    Some(new CompilationUnit(p, classes, env))
   }
 }
-- 
GitLab