diff --git a/src/main/scala/leon/codegen/CompilationUnit.scala b/src/main/scala/leon/codegen/CompilationUnit.scala index 6015fa8d2cbf5ddefb0d27054e30d38c1c0884b2..9be718da8915d89cd5dfd86cf4af27ba7c85b80e 100644 --- a/src/main/scala/leon/codegen/CompilationUnit.scala +++ b/src/main/scala/leon/codegen/CompilationUnit.scala @@ -19,6 +19,7 @@ import scala.collection.JavaConverters._ import java.lang.reflect.Constructor + class CompilationUnit(val ctx: LeonContext, val program: Program, val params: CodeGenParams = CodeGenParams()) extends CodeGeneration { @@ -362,35 +363,29 @@ class CompilationUnit(val ctx: LeonContext, /** Traverses the program to find all definitions, and stores those in global variables */ def init() { // First define all classes/ methods/ functions - for (m <- program.modules) { - for ( (parent, children) <- m.algebraicDataTypes; - cls <- Seq(parent) ++ children) { + for (u <- program.units; m <- u.modules) { + val (parents, children) = m.algebraicDataTypes.toSeq.unzip + for ( cls <- parents ++ children.flatten ++ m.singleCaseClasses) { defineClass(cls) for (meth <- cls.methods) { defToModuleOrClass += meth -> cls } } - - for ( single <- m.singleCaseClasses ) { - defineClass(single) - for (meth <- single.methods) { - defToModuleOrClass += meth -> single - } - } - + + defineClass(m) for(funDef <- m.definedFunctions) { defToModuleOrClass += funDef -> m } - defineClass(m) + } } /** Compiles the program. Uses information provided by $init */ def compile() { // Compile everything - for (m <- program.modules) { + for (u <- program.units) { - for ((parent, children) <- m.algebraicDataTypes) { + for ((parent, children) <- u.algebraicDataTypes) { compileAbstractClassDef(parent) for (c <- children) { @@ -398,17 +393,14 @@ class CompilationUnit(val ctx: LeonContext, } } - for(single <- m.singleCaseClasses) { + for(single <- u.singleCaseClasses) { compileCaseClassDef(single) } + for (m <- u.modules) compileModule(m) } - for (m <- program.modules) { - compileModule(m) - } - classes.values.foreach(loader.register _) } diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala index 2398e51b979bd38d67b115a123c70d9eee357e2b..9ec20b64184f8298431b2fe540e8158d70c39ae3 100644 --- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala +++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala @@ -105,6 +105,10 @@ trait CodeExtraction extends ASTExtractors { throw new ImpureCodeEncounteredException(t.pos, msg, Some(t)) } + // Simple case classes to capture the representation of units/modules after discovering them. + case class TempModule(name : String, trees : List[Tree]) + case class TempUnit(name : String, modules : List[TempModule]) + class Extraction(units: List[CompilationUnit]) { private var currentFunDef: FunDef = null @@ -173,26 +177,26 @@ trait CodeExtraction extends ASTExtractors { annotationsOf(s) contains "extern" } - def extractModules: List[LeonModuleDef] = { + def extractUnits: List[UnitDef] = { try { - val templates: List[(String, List[Tree])] = { - - var standaloneDefs = List[Tree]() - - val modules = units.reverse.flatMap { u => u.body match { + val templates: List[TempUnit] = units.reverse.map { u => u.body match { - case PackageDef(name, lst) => - - lst.flatMap { _ match { + case PackageDef(refTree, lst) => + + val name = refTree.name.toString + + var standaloneDefs = List[Tree]() + + val modules = lst.flatMap { _ match { case t if isIgnored(t.symbol) => None case PackageDef(_, List(ExObjectDef(n, templ))) => - Some((n.toString, templ.body)) + Some(TempModule(n.toString, templ.body)) case ExObjectDef(n, templ) => - Some((n.toString, templ.body)) + Some(TempModule(n.toString, templ.body)) case d @ ExAbstractClass(_, _, _) => standaloneDefs ::= d @@ -208,33 +212,41 @@ trait CodeExtraction extends ASTExtractors { case other => outOfSubsetError(other, "Expected: top-level object/class.") None - }}.toList + }} + + TempUnit(name, + if (standaloneDefs.isEmpty) modules + else ( TempModule(name+ "$standalone", standaloneDefs) ) :: modules + ) - }} - + } + - // Combine all standalone definitions into one module - if (standaloneDefs.isEmpty) modules - else modules :+ ("standalone$", standaloneDefs.reverse) } // Phase 1, we detect classes/types - templates.foreach{ case (name, templ) => collectClassSymbols(templ) } - + for (TempUnit(name,mods) <- templates; mod <- mods) collectClassSymbols(mod.trees) + // Phase 2, we collect functions signatures - templates.foreach{ case (name, templ) => collectFunSigs(templ) } + for (TempUnit(name,mods) <- templates; mod <- mods) collectFunSigs(mod.trees) // Phase 3, we collect classes/types' definitions - templates.foreach{ case (name, templ) => extractClassDefs(templ) } + for (TempUnit(name,mods) <- templates; mod <- mods) extractClassDefs(mod.trees) // Phase 4, we collect methods' definitions - templates.foreach{ case (name, templ) => extractMethodDefs(templ) } + for (TempUnit(name,mods) <- templates; mod <- mods) extractMethodDefs(mod.trees) // Phase 5, we collect function definitions - templates.foreach{ case (name, templ) => extractFunDefs(templ) } + for (TempUnit(name,mods) <- templates; mod <- mods) extractFunDefs(mod.trees) // Phase 6, we create modules and extract bodies - templates.map{ case (name, templ) => extractObjectDef(name, templ) } + for (TempUnit(name,mods) <- templates) yield { + UnitDef( + FreshIdentifier(name), + for( TempModule(name,trees) <- mods) yield extractObjectDef(name, trees), + false + ) + } } catch { case icee: ImpureCodeEncounteredException => diff --git a/src/main/scala/leon/frontends/scalac/ExtractionPhase.scala b/src/main/scala/leon/frontends/scalac/ExtractionPhase.scala index c91b5582694d8ca7a6a1bd5e2b0915e55974b1c0..a7bb5891cf3b6ae0faec02e087eb4f951ea7711d 100644 --- a/src/main/scala/leon/frontends/scalac/ExtractionPhase.scala +++ b/src/main/scala/leon/frontends/scalac/ExtractionPhase.scala @@ -12,7 +12,7 @@ import scala.tools.nsc.{Settings=>NSCSettings,CompilerCommand} object ExtractionPhase extends LeonPhase[List[String], Program] { - val name = "Scalc Extraction" + val name = "Scalac Extraction" val description = "Extraction of trees from the Scala Compiler" implicit val debug = DebugSectionTrees @@ -26,7 +26,8 @@ object ExtractionPhase extends LeonPhase[List[String], Program] { scala.Predef.getClass ) - val urls = neededClasses.map(_.getProtectionDomain().getCodeSource().getLocation()) + val urls = neededClasses.map{ _.getProtectionDomain().getCodeSource().getLocation() } + val classpath = urls.map(_.getPath).mkString(":") settings.classpath.value = classpath @@ -61,10 +62,9 @@ object ExtractionPhase extends LeonPhase[List[String], Program] { val run = new compiler.Run run.compile(command.files) - timer.stop() - val pgm = Program(FreshIdentifier("__program"), compiler.leonExtraction.modules) + val pgm = Program(FreshIdentifier("__program"), compiler.leonExtraction.compiledUnits) ctx.reporter.debug(pgm.asString(ctx)) pgm } else { diff --git a/src/main/scala/leon/frontends/scalac/LeonExtraction.scala b/src/main/scala/leon/frontends/scalac/LeonExtraction.scala index 075f3dc26db222fea2e5d53527be33629906bd85..bd5e8b0561e589ce49f865a424f880f728f8136a 100644 --- a/src/main/scala/leon/frontends/scalac/LeonExtraction.scala +++ b/src/main/scala/leon/frontends/scalac/LeonExtraction.scala @@ -19,8 +19,8 @@ trait LeonExtraction extends SubComponent with CodeExtraction { val ctx: LeonContext - def modules = { - new Extraction(units).extractModules + def compiledUnits = { + new Extraction(units).extractUnits } def newPhase(prev: scala.tools.nsc.Phase): StdPhase = new Phase(prev) diff --git a/src/main/scala/leon/purescala/CompleteAbstractDefinitions.scala b/src/main/scala/leon/purescala/CompleteAbstractDefinitions.scala index b4d1817999b532ca19e79045f03c4c19fea118b9..a287b3d5f135d81c7876e1dab6ad5e0a353918c9 100644 --- a/src/main/scala/leon/purescala/CompleteAbstractDefinitions.scala +++ b/src/main/scala/leon/purescala/CompleteAbstractDefinitions.scala @@ -19,7 +19,7 @@ object CompleteAbstractDefinitions extends TransformationPhase { // First we create the appropriate functions from methods: var mdToFds = Map[FunDef, FunDef]() - program.modules foreach { m => + for (u <- program.units; m <- u.modules ) { // We remove methods from class definitions and add corresponding functions m.defs.foreach { case fd: FunDef if fd.body.isEmpty => diff --git a/src/main/scala/leon/purescala/Definitions.scala b/src/main/scala/leon/purescala/Definitions.scala index 0bf35b4869a46fedf0ed5f8c85d68c23b350f9d9..1932cfd6de3d39889b46c6e47b89ede74425d395 100644 --- a/src/main/scala/leon/purescala/Definitions.scala +++ b/src/main/scala/leon/purescala/Definitions.scala @@ -51,15 +51,15 @@ object Definitions { /** A wrapper for a program. For now a program is simply a single object. The * name is meaningless and we just use the package name as id. */ - case class Program(id: Identifier, modules: List[ModuleDef]) extends Definition { + case class Program(id: Identifier, units: List[UnitDef]) extends Definition { enclosing = None - def subDefinitions = modules + def subDefinitions = units - def definedFunctions = modules.flatMap(_.definedFunctions) - def definedClasses = modules.flatMap(_.definedClasses) - def classHierarchyRoots = modules.flatMap(_.classHierarchyRoots) - def algebraicDataTypes = modules.flatMap(_.algebraicDataTypes).toMap - def singleCaseClasses = modules.flatMap(_.singleCaseClasses) + def definedFunctions = units.flatMap(_.definedFunctions) + def definedClasses = units.flatMap(_.definedClasses) + def classHierarchyRoots = units.flatMap(_.classHierarchyRoots) + def algebraicDataTypes = units.flatMap(_.algebraicDataTypes).toMap + def singleCaseClasses = units.flatMap(_.singleCaseClasses) lazy val callGraph = new CallGraph(this) @@ -68,10 +68,7 @@ object Definitions { }.headOption.getOrElse(throw LeonFatalError("Unknown case class '"+name+"'")) def duplicate = { - copy(modules = modules.map(m => m.copy(defs = m.defs.collect { - case fd: FunDef => fd.duplicate - case d => d - }))) + copy(units = units.map{_.duplicate}) } def writeScalaFile(filename: String) { @@ -96,6 +93,41 @@ object Definitions { val id = tp.id } + + object UnitDef { + def apply(id : Identifier, modules : Seq[ModuleDef]) : UnitDef = UnitDef(id,modules, true) + } + + case class UnitDef( + val id: Identifier, + modules : Seq[ModuleDef], + isMainUnit : Boolean // false for libraries/imports + ) extends Definition { + + def subDefinitions = modules + + def definedFunctions = modules.flatMap(_.definedFunctions) + def definedClasses = modules.flatMap(_.definedClasses) + def classHierarchyRoots = modules.flatMap(_.classHierarchyRoots) + def algebraicDataTypes = modules.flatMap(_.algebraicDataTypes) + def singleCaseClasses = modules.flatMap(_.singleCaseClasses) + + def duplicate = { + copy(modules = modules map { _.duplicate } ) + } + + def writeScalaFile(filename: String) { + import java.io.FileWriter + import java.io.BufferedWriter + val fstream = new FileWriter(filename) + val out = new BufferedWriter(fstream) + out.write(ScalaPrinter(this)) + out.close + } + } + + + /** Objects work as containers for class definitions, functions (def's) and * val's. */ case class ModuleDef(id: Identifier, defs : Seq[Definition]) extends Definition { @@ -117,6 +149,13 @@ object Definitions { lazy val singleCaseClasses : Seq[CaseClassDef] = defs.collect { case c @ CaseClassDef(_, _, None, _) => c } + + def duplicate = copy(defs = defs map { _ match { + case f : FunDef => f.duplicate + case cd : ClassDef => cd.duplicate + case other => other // FIXME: huh? + }}) + } @@ -170,6 +209,25 @@ object Definitions { val isAbstract: Boolean val isCaseObject: Boolean + + def duplicate = this match { + case ab : AbstractClassDef => { + val ab2 = ab.copy() + ab.knownChildren foreach ab2.registerChildren + ab.methods foreach { m => ab2.registerMethod(m.duplicate) } + ab2 + } + case cc : CaseClassDef => { + val cc2 = cc.copy() + cc.methods foreach { m => cc2.registerMethod(m.duplicate) } + cc2.setFields(cc.fields map { _.copy() }) + cc2 + } + } + + lazy val definedFunctions : Seq[FunDef] = methods + lazy val definedClasses = Seq(this) + lazy val classHierarchyRoots = if (this.hasParent) Seq(this) else Nil } /** Abstract classes. */ @@ -180,6 +238,9 @@ object Definitions { val fields = Nil val isAbstract = true val isCaseObject = false + + lazy val singleCaseClasses : Seq[CaseClassDef] = Nil + } /** Case classes/objects. */ @@ -207,6 +268,9 @@ object Definitions { scala.sys.error("Could not find '"+id+"' ("+id.uniqueName+") within "+fields.map(_.id.uniqueName).mkString(", ")) } } + + lazy val singleCaseClasses : Seq[CaseClassDef] = if (hasParent) Nil else Seq(this) + } diff --git a/src/main/scala/leon/purescala/FunctionClosure.scala b/src/main/scala/leon/purescala/FunctionClosure.scala index 8a4b6f6a5c9f3b8f691572106f1a754d3b963819..5151f2b708a1448922370697f96f452107c14eaf 100644 --- a/src/main/scala/leon/purescala/FunctionClosure.scala +++ b/src/main/scala/leon/purescala/FunctionClosure.scala @@ -24,7 +24,7 @@ object FunctionClosure extends TransformationPhase { def apply(ctx: LeonContext, program: Program): Program = { - val newModules = program.modules.map { m => + val newUnits = program.units.map { u => u.copy(modules = u.modules map { m => pathConstraints = Nil enclosingLets = Nil newFunDefs = Map() @@ -39,8 +39,8 @@ object FunctionClosure extends TransformationPhase { }) ModuleDef(m.id, m.defs ++ topLevelFuns) - } - val res = Program(program.id, newModules) + })} + val res = Program(program.id, newUnits) res } diff --git a/src/main/scala/leon/purescala/MethodLifting.scala b/src/main/scala/leon/purescala/MethodLifting.scala index f041559d3cad2c30db08e1ad2ad44d95b84a691b..abac76b18f7049e5286988c00f081e9d40770343 100644 --- a/src/main/scala/leon/purescala/MethodLifting.scala +++ b/src/main/scala/leon/purescala/MethodLifting.scala @@ -75,7 +75,7 @@ object MethodLifting extends TransformationPhase { }(e) } - val newModules = program.modules map { m => + val newUnits = program.units map { u => u.copy (modules = u.modules map { m => // We remove methods from class definitions and add corresponding functions val newDefs = m.defs.flatMap { case acd: AbstractClassDef if acd.methods.nonEmpty => @@ -99,9 +99,9 @@ object MethodLifting extends TransformationPhase { } ModuleDef(m.id, newDefs) - } + })} - Program(program.id, newModules) + Program(program.id, newUnits) } } diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala index aafaff7539ee0f63550687f081ce55a386defc92..7338b58107e1f4c9e9ae0981f33c09c5d6365ae4 100644 --- a/src/main/scala/leon/purescala/PrettyPrinter.scala +++ b/src/main/scala/leon/purescala/PrettyPrinter.scala @@ -372,11 +372,15 @@ class PrettyPrinter(opts: PrinterOptions, val sb: StringBuffer = new StringBuffe // Definitions - case Program(id, modules) => - p"""|package $id { - | ${nary(modules, "\n\n")} - |}""" + case Program(id, units) => + p"""${nary(units, "\n\n")}""" + case UnitDef(id,modules,isBasic) => + if (isBasic) { + p"""|package $id { + | ${nary(modules,"\n\n")} + |}""" + } case ModuleDef(id, defs) => p"""|object $id { | ${nary(defs, "\n\n")} diff --git a/src/main/scala/leon/purescala/RestoreMethods.scala b/src/main/scala/leon/purescala/RestoreMethods.scala index 57f50ad4a0da0c85c3d5766cedee916934eb70d6..0a808649884927dada2f31a1f14ea8957ea335ae 100644 --- a/src/main/scala/leon/purescala/RestoreMethods.scala +++ b/src/main/scala/leon/purescala/RestoreMethods.scala @@ -119,7 +119,7 @@ object RestoreMethods extends TransformationPhase { m.copy(defs = m.definedClasses ++ newFuns).copiedFrom(m) } - p.copy(modules = p.modules map refreshModule) + p.copy(units = p.units map { u => u.copy(modules = u.modules map refreshModule)}) } diff --git a/src/main/scala/leon/synthesis/Synthesizer.scala b/src/main/scala/leon/synthesis/Synthesizer.scala index 238c893e8dbd3e9a2b9be4c5feeae00cd8ab9f19..d048771eb35fe3ce8cb9ab29c3cda3e2105bc6fd 100644 --- a/src/main/scala/leon/synthesis/Synthesizer.scala +++ b/src/main/scala/leon/synthesis/Synthesizer.scala @@ -112,7 +112,9 @@ class Synthesizer(val context : LeonContext, val newDefs = sol.defs + fd - val npr = program.copy(modules = ModuleDef(FreshIdentifier("synthesis"), newDefs.toSeq) :: program.modules) + val npr = program.copy(units = program.units map { u => + u.copy(modules = ModuleDef(FreshIdentifier("synthesis"), newDefs.toSeq) +: u.modules ) + }) (npr, newDefs) } diff --git a/src/main/scala/leon/termination/Processor.scala b/src/main/scala/leon/termination/Processor.scala index fb85214a724c2e8593663bd1593ea04feeffd780..965380fab10955dd450839725f670bd234b4ace2 100644 --- a/src/main/scala/leon/termination/Processor.scala +++ b/src/main/scala/leon/termination/Processor.scala @@ -41,7 +41,8 @@ trait Solvable extends Processor { val program : Program = checker.program val context : LeonContext = checker.context val sizeModule : ModuleDef = ModuleDef(FreshIdentifier("$size", false), checker.defs.toSeq) - val newProgram : Program = program.copy(modules = sizeModule :: program.modules) + val sizeUnit : UnitDef = UnitDef(FreshIdentifier("$size", false),Seq(sizeModule),false) + val newProgram : Program = program.copy( units = sizeUnit :: program.units) (new FairZ3Solver(context, newProgram) with TimeoutAssumptionSolver).setTimeout(500L) }) diff --git a/src/main/scala/leon/utils/UnitElimination.scala b/src/main/scala/leon/utils/UnitElimination.scala index 1c67c8f165eee6d4c2a4893dee068e9a601fb3f5..0bd55feb3b29d29f686896730ea82d13207f0bf8 100644 --- a/src/main/scala/leon/utils/UnitElimination.scala +++ b/src/main/scala/leon/utils/UnitElimination.scala @@ -19,7 +19,7 @@ object UnitElimination extends TransformationPhase { private var id2FreshId: Map[Identifier, Identifier] = Map() def apply(ctx: LeonContext, pgm: Program): Program = { - val newModules = pgm.modules.map { m => + val newUnits = pgm.units map { u => u.copy(modules = u.modules.map { m => fun2FreshFun = Map() val allFuns = m.definedFunctions @@ -45,10 +45,10 @@ object UnitElimination extends TransformationPhase { }) ModuleDef(m.id, m.definedClasses ++ newFuns) - } + })} - Program(pgm.id, newModules) + Program(pgm.id, newUnits) } private def simplifyType(tpe: TypeTree): TypeTree = tpe match { diff --git a/src/test/scala/leon/test/purescala/DataGen.scala b/src/test/scala/leon/test/purescala/DataGen.scala index f537a06a3bf6b732901a7ec25e3133f572166780..b68c1526d12aa600a5d0c1dc27869ef3253fe7df 100644 --- a/src/test/scala/leon/test/purescala/DataGen.scala +++ b/src/test/scala/leon/test/purescala/DataGen.scala @@ -68,7 +68,7 @@ class DataGen extends LeonTestSuite { generator.generate(TupleType(Seq(BooleanType,BooleanType))).toSet.size === 4 // Make sure we target our own lists - val module = prog.modules.find(_.id.name == "Program").get + val module = prog.units.flatMap{_.modules}.find(_.id.name == "Program").get val listType : TypeTree = classDefToClassType(module.classHierarchyRoots.head) val sizeDef : FunDef = module.definedFunctions.find(_.id.name == "size").get val sortedDef : FunDef = module.definedFunctions.find(_.id.name == "isSorted").get diff --git a/src/test/scala/leon/test/solvers/z3/FairZ3SolverTests.scala b/src/test/scala/leon/test/solvers/z3/FairZ3SolverTests.scala index c0ab6eadba1465a232eff3c062b5d242b4039d1e..b646893d3b4de1ce3032b022a7a276bf04430a31 100644 --- a/src/test/scala/leon/test/solvers/z3/FairZ3SolverTests.scala +++ b/src/test/scala/leon/test/solvers/z3/FairZ3SolverTests.scala @@ -44,8 +44,9 @@ class FairZ3SolverTests extends LeonTestSuite { private val minimalProgram = Program( FreshIdentifier("Minimal"), - List(ModuleDef(FreshIdentifier("Minimal"), Seq( - fDef + List(UnitDef( + FreshIdentifier("Minimal"), + List(ModuleDef(FreshIdentifier("Minimal"), Seq(fDef)) ))) ) diff --git a/src/test/scala/leon/test/solvers/z3/FairZ3SolverTestsNewAPI.scala b/src/test/scala/leon/test/solvers/z3/FairZ3SolverTestsNewAPI.scala index b10a84d80d064a1fdf0453bfcb839bb2537d99a1..663a9038a6dca67284004defff6953601890910c 100644 --- a/src/test/scala/leon/test/solvers/z3/FairZ3SolverTestsNewAPI.scala +++ b/src/test/scala/leon/test/solvers/z3/FairZ3SolverTestsNewAPI.scala @@ -52,8 +52,9 @@ class FairZ3SolverTestsNewAPI extends LeonTestSuite { private val minimalProgram = Program( FreshIdentifier("Minimal"), - List(ModuleDef(FreshIdentifier("Minimal"), Seq( - fDef + List(UnitDef( + FreshIdentifier("Minimal"), + List(ModuleDef(FreshIdentifier("Minimal"), Seq(fDef)) ))) ) diff --git a/src/test/scala/leon/test/solvers/z3/UninterpretedZ3SolverTests.scala b/src/test/scala/leon/test/solvers/z3/UninterpretedZ3SolverTests.scala index 2fb9909c4e9dee24992a95628ddead0e570b751a..561a107dc3f9d85b5104b9401a6337cddad6a404 100644 --- a/src/test/scala/leon/test/solvers/z3/UninterpretedZ3SolverTests.scala +++ b/src/test/scala/leon/test/solvers/z3/UninterpretedZ3SolverTests.scala @@ -49,8 +49,9 @@ class UninterpretedZ3SolverTests extends LeonTestSuite { private val minimalProgram = Program( FreshIdentifier("Minimal"), - List(ModuleDef(FreshIdentifier("Minimal"), Seq( - fDef + List(UnitDef( + FreshIdentifier("Minimal"), + List(ModuleDef(FreshIdentifier("Minimal"), Seq(fDef)) ))) )