diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala index b159fdbdba34313cbe538767910750010456d45d..b69cda26f8d994714d03aa4927c8af470a384df6 100644 --- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala +++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala @@ -384,18 +384,7 @@ trait CodeExtraction extends ASTExtractors { case _ => (selectors, false) } - val theDef = searchRelative(thePath.mkString("."), current).find { - case _: UnitDef => false - case m: LeonModuleDef => !m.isPackageObject - case _ => true - } - - (isWild, theDef) match { - case (true, Some(df)) => Some(WildcardImport(df)) - case (false, Some(df)) => Some(SingleImport(df)) - case (true, None) => Some(PackageImport(thePath)) - case (false, None) => None // import comes from a Scala library or something... - } + Some(LeonImport(thePath, isWild)) } } diff --git a/src/main/scala/leon/purescala/DefOps.scala b/src/main/scala/leon/purescala/DefOps.scala index 7e4d216a32bbef4801936cb76f6bf8588820e52e..0e6150ac1ce9cf0b9d3b1adc3c6810263ca7064c 100644 --- a/src/main/scala/leon/purescala/DefOps.scala +++ b/src/main/scala/leon/purescala/DefOps.scala @@ -54,7 +54,7 @@ object DefOps { List(pgm) ++ ( for ( u <- unitOf(df).toSeq; imp <- u.imports; - impDf <- imp.importedDefs + impDf <- imp.importedDefs(u) ) yield impDf ) for ( @@ -108,9 +108,8 @@ object DefOps { pathFrom match { case (u: UnitDef) :: _ => val imports = u.imports.map { - case PackageImport(ls) => ls - case SingleImport(d) => nameToParts(fullName(d, useUniqueIds)).init - case WildcardImport(d) => nameToParts(fullName(d, useUniqueIds)) + case Import(path, true) => path + case Import(path, false) => path.init }.toList def stripImport(of: List[String], imp: List[String]): Option[List[String]] = { @@ -210,14 +209,12 @@ object DefOps { searchRelative(names, path.reverse) } - private case class ImportPath(ls: List[String], wild: Boolean) - - private def resolveImports(imports: List[ImportPath], names: List[String]): List[List[String]] = { - def resolveImport(i: ImportPath): Option[List[String]] = { - if (!i.wild && names.startsWith(i.ls.last)) { - Some(i.ls ++ names.tail) - } else if (i.wild) { - Some(i.ls ++ names) + private def resolveImports(imports: Seq[Import], names: List[String]): Seq[List[String]] = { + def resolveImport(i: Import): Option[List[String]] = { + if (!i.isWild && names.startsWith(i.path.last)) { + Some(i.path ++ names.tail) + } else if (i.isWild) { + Some(i.path ++ names) } else { None } @@ -234,17 +231,11 @@ object DefOps { searchWithin(names, p) case u: UnitDef => - val imports = u.imports.map { - case PackageImport(ls) => ImportPath(ls, true) - case SingleImport(d) => ImportPath(nameToParts(fullName(d)), false) - case WildcardImport(d) => ImportPath(nameToParts(fullName(d)), true) - }.toList - val inModules = d.subDefinitions.filter(_.id.name == n).flatMap { sd => searchWithin(ns, sd) } - val namesImported = resolveImports(imports, names) + val namesImported = resolveImports(u.imports, names) val nameWithPackage = u.pack ++ names val allNames = namesImported :+ nameWithPackage @@ -299,11 +290,6 @@ object DefOps { } }) case d => d - }, - imports = u.imports map { - case SingleImport(fd : FunDef) => - SingleImport(fdMap(fd)) - case other => other } ) }) @@ -347,13 +333,16 @@ object DefOps { res } - // @Note: This function does not remove functions in classdefs - def removeDefs(p: Program, dds: Set[Definition]): Program = { - p.copy(units = for (u <- p.units if !dds(u)) yield { + // @Note: This function does not filter functions in classdefs + def filterFunDefs(p: Program, fdF: FunDef => Boolean): Program = { + p.copy(units = p.units.map { u => u.copy( - defs = for (d <- u.defs if !dds(d)) yield d match { + defs = u.defs.collect { case md: ModuleDef => - md.copy(defs = md.defs.filterNot(dds)) + md.copy(defs = md.defs.filter { + case fd: FunDef => fdF(fd) + case d => true + }) case cd => cd } diff --git a/src/main/scala/leon/purescala/Definitions.scala b/src/main/scala/leon/purescala/Definitions.scala index 127c562217fbb8d64329d87db3bbff969231405a..da9f619efbc8a4e14ac93eb32e0a01aeaa0dcbbb 100644 --- a/src/main/scala/leon/purescala/Definitions.scala +++ b/src/main/scala/leon/purescala/Definitions.scala @@ -101,11 +101,17 @@ object Definitions { /** A package as a path of names */ type PackageRef = List[String] - abstract class Import extends Definition { - def subDefinitions = Nil - def importedDefs(implicit pgm: Program): Seq[Definition] + case class Import(path: List[String], isWild: Boolean) extends Tree { + def importedDefs(in: UnitDef)(implicit pgm: Program): Seq[Definition] = { + val found = DefOps.searchRelative(path.mkString("."), in) + if (isWild) { + found.flatMap(_.subDefinitions) + } else { + found + } + } } - +/* // import pack._ case class PackageImport(pack : PackageRef) extends Import { val id = FreshIdentifier("import " + (pack mkString ".")) @@ -132,6 +138,7 @@ object Definitions { def importedDefs(implicit pgm: Program): Seq[Definition] = df.subDefinitions } + */ case class UnitDef( id: Identifier, diff --git a/src/main/scala/leon/purescala/FunctionMapping.scala b/src/main/scala/leon/purescala/FunctionMapping.scala index c7a31c61500a144f531dc7963ed87b0c963a400e..34a9c2d8087c76dc31dd0e9a81c650e75239b1f5 100644 --- a/src/main/scala/leon/purescala/FunctionMapping.scala +++ b/src/main/scala/leon/purescala/FunctionMapping.scala @@ -53,13 +53,7 @@ abstract class FunctionMapping extends TransformationPhase { } }) case d => d - }, - imports = u.imports map { - case SingleImport(fd : FunDef) => - SingleImport(functionToFunction.get(fd).map{ _.to }.getOrElse(fd)) - case other => other - } - + } ) }) diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala index 57269f7cedaa5cb84697e78dd6bc54ed110b6d63..887c50464eab12fcc4bebf46b83f1c5201bd92b9 100644 --- a/src/main/scala/leon/purescala/PrettyPrinter.scala +++ b/src/main/scala/leon/purescala/PrettyPrinter.scala @@ -450,14 +450,12 @@ class PrettyPrinter(opts: PrinterOptions, |${nary(defs,"\n\n")} |""" - case PackageImport(pack) => - p"import ${nary(pack,".")}._" - - case SingleImport(df) => - p"import "; printWithPath(df) - - case WildcardImport(df) => - p"import "; printWithPath(df); p"._" + case Import(path, isWild) => + if (isWild) { + p"import ${nary(path,".")}._" + } else { + p"import ${nary(path,".")}" + } case ModuleDef(id, defs, _) => p"""|object $id { diff --git a/src/main/scala/leon/utils/InliningPhase.scala b/src/main/scala/leon/utils/InliningPhase.scala index 5382275db54b3508a63aee57a11b5f229078f545..e11fec3f653302f783507c3d3034c435743e005f 100644 --- a/src/main/scala/leon/utils/InliningPhase.scala +++ b/src/main/scala/leon/utils/InliningPhase.scala @@ -16,14 +16,6 @@ object InliningPhase extends TransformationPhase { val description = "Inline functions marked as @inline and remove their definitions" def apply(ctx: LeonContext, p: Program): Program = { - val toInline = p.definedFunctions.filter(_.flags(IsInlined)).toSet - - val substs = toInline.map { fd => - fd -> { (tps: Seq[TypeTree], s: Seq[Expr]) => - val newBody = replaceFromIDs(fd.params.map(_.id).zip(s).toMap, fd.fullBody) - instantiateType(newBody, (fd.tparams zip tps).toMap, Map()) - } - }.toMap def simplifyImplicitClass(e: Expr) = e match { case CaseClassSelector(cct, cc: CaseClass, id) => @@ -37,15 +29,17 @@ object InliningPhase extends TransformationPhase { fixpoint(postMap(simplifyImplicitClass _))(e) } - val (np, _) = replaceFunDefs(p)({fd => None}, {(fi, fd) => - if (substs contains fd) { - Some(simplify(substs(fd)(fi.tfd.tps, fi.args))) - } else { - None - } - }) + for (fd <- p.definedFunctions) { + fd.fullBody = simplify(preMap { + case FunctionInvocation(TypedFunDef(fd, tps), args) if fd.flags(IsInlined) => + val newBody = replaceFromIDs(fd.params.map(_.id).zip(args).toMap, fd.fullBody) + Some(instantiateType(newBody, (fd.tparams zip tps).toMap, Map())) + case _ => + None + }(fd.fullBody)) + } - removeDefs(np, toInline.map { fd => (fd: Definition) }) + filterFunDefs(p, fd => !fd.flags(IsInlined)) } } diff --git a/src/test/scala/leon/test/purescala/DefOpsSuite.scala b/src/test/scala/leon/test/purescala/DefOpsSuite.scala index 245c1191fac5b801a8c7bb122749a5ae278161f9..86350f2948a1e9d01487203d02dd23482179cdd4 100644 --- a/src/test/scala/leon/test/purescala/DefOpsSuite.scala +++ b/src/test/scala/leon/test/purescala/DefOpsSuite.scala @@ -13,13 +13,13 @@ private [purescala] object DefOpsHelper extends LeonTestSuite { private def parseStrings(strs : List[String]) : Program = { val context = createLeonContext() - val pipeline = - ExtractionPhase andThen - PreprocessingPhase - + val pipeline = + ExtractionPhase andThen + PreprocessingPhase + val inputs = strs map { str => TemporaryInputPhase.run(context)((str, Nil)).head } val program = pipeline.run(context)(inputs) - + program }