diff --git a/src/main/scala/leon/Main.scala b/src/main/scala/leon/Main.scala index 77772b1cf2a168327eef06437efcc0066eb69f9b..1b7377975e4a6c6a0e8309d0cc36dc1f4a345262 100644 --- a/src/main/scala/leon/Main.scala +++ b/src/main/scala/leon/Main.scala @@ -12,6 +12,7 @@ object Main { utils.TypingPhase, FileOutputPhase, ScopingPhase, + purescala.RestoreMethods, xlang.ArrayTransformation, xlang.EpsilonElimination, xlang.ImperativeCodeElimination, @@ -212,7 +213,7 @@ object Main { } else if (settings.verify) { AnalysisPhase } else { - utils.FileOutputPhase + purescala.RestoreMethods andThen utils.FileOutputPhase } } diff --git a/src/main/scala/leon/purescala/RestoreMethods.scala b/src/main/scala/leon/purescala/RestoreMethods.scala new file mode 100644 index 0000000000000000000000000000000000000000..43a53fe47b45c679fa9343438eaa80ff2295fc00 --- /dev/null +++ b/src/main/scala/leon/purescala/RestoreMethods.scala @@ -0,0 +1,125 @@ +package leon.purescala + +import leon._ +import leon.purescala.Definitions._ +import leon.purescala.Common._ +import leon.purescala.Trees._ +import leon.purescala.TreeOps.{applyOnFunDef,preMapOnFunDef,replaceFromIDs,functionCallsOf} +import leon.purescala.TypeTrees._ +import utils.GraphOps._ + +object RestoreMethods extends TransformationPhase { + + val name = "Restore methods" + val description = "Restore methods that were previously turned into standalone functions" + + /** + * New functions are returned, whereas classes are mutated + */ + def apply(ctx : LeonContext, p : Program) = { + + var fd2Md = Map[FunDef, FunDef]() + var whoseMethods = Map[ClassDef, Seq[FunDef]]() + + for ( (Some(cd : ClassDef), funDefs) <- p.definedFunctions.groupBy(_.enclosing).toSeq ) { + whoseMethods += cd -> funDefs + for (fn <- funDefs) { + val theName = try { + // Remove class name from function name, if it is there + val maybeClassName = fn.id.name.substring(0,cd.id.name.length) + + if (maybeClassName == cd.id.name) { + fn.id.name.substring(cd.id.name.length + 1) // +1 to also remove $ + } else { + fn.id.name + } + } catch { + case e : IndexOutOfBoundsException => fn.id.name + } + val md = new FunDef( + id = FreshIdentifier(theName), + tparams = fn.tparams diff cd.tparams, + params = fn.params.tail, // no this$ + returnType = fn.returnType + ).copiedFrom(fn) + md.copyContentFrom(fn) + + // first parameter should become "this" + val thisType = fn.params.head.getType.asInstanceOf[ClassType] + val theMap = Map(fn.params.head.id -> This(thisType)) + val mdFinal = applyOnFunDef(replaceFromIDs(theMap, _))(md) + + fd2Md += fn -> mdFinal + } + } + + /** + * Substitute a function in an expression with the respective new method + */ + def substituteMethods = preMapOnFunDef ({ + case FunctionInvocation(tfd,args) if fd2Md contains tfd.fd => { + val md = fd2Md.get(tfd.fd).get // the method we are substituting + val mi = MethodInvocation( + args.head, // "this" + args.head.getType.asInstanceOf[ClassType].classDef, // this.type + md.typed(tfd.tps.takeRight(md.tparams.length)), // throw away class parameters + args.tail // rest of the arguments + ) + Some(mi) + } + case _ => None + }, true) _ + + /** + * Renew that function map by applying subsituteMethods on its values to obtain correct functions + */ + val fd2MdFinal = fd2Md.mapValues(substituteMethods) + + // We need a special type of transitive closure, detecting only trans. calls on the same argument + def transCallsOnSameArg(fd : FunDef) : Set[FunDef] = { + require(fd.params.length == 1) + require(fd.params.head.getType.isInstanceOf[ClassType]) + def callsOnSameArg(fd : FunDef) : Set[FunDef] = { + val theArg = fd.params.head.toVariable + functionCallsOf(fd.fullBody) collect { case fi if fi.args contains theArg => fi.tfd.fd } + } + reachable(callsOnSameArg,fd) + } + + def refreshModule(m : ModuleDef) = { + val newFuns : Seq[FunDef] = m.definedFunctions diff fd2MdFinal.keys.toSeq map substituteMethods// only keep non-methods + for (cl <- m.definedClasses) { + // We're going through some hoops to ensure strict fields are defined in topological order + + // We need to work with the functions of the original program to have access to CallGraph + val (strict, other) = whoseMethods.getOrElse(cl,Seq()).partition{ fd2MdFinal(_).canBeStrictField } + val strictSet = strict.toSet + // Make the call-subgraph that only includes the strict fields of this class + val strictCallGraph = strict.map { st => + (st, transCallsOnSameArg(st) & strictSet) + }.toMap + // Topologically sort, or warn in case of cycle + val strictOrdered = topologicalSorting(strictCallGraph) fold ( + cycle => { + ctx.reporter.warning( + s"""|Fields + |${cycle map {_.id} mkString "\n"} + |are involved in circular definition!""".stripMargin + ) + strict + }, + r => r + ) + + for (fun <- strictOrdered ++ other) { + cl.registerMethod(fd2MdFinal(fun)) + } + } + m.copy(defs = m.definedClasses ++ newFuns).copiedFrom(m) + } + + p.copy(modules = p.modules map refreshModule) + + } + +} diff --git a/src/main/scala/leon/utils/GraphOps.scala b/src/main/scala/leon/utils/GraphOps.scala new file mode 100644 index 0000000000000000000000000000000000000000..5dba04ece0f906450cf0108f92b7e9c032526d84 --- /dev/null +++ b/src/main/scala/leon/utils/GraphOps.scala @@ -0,0 +1,59 @@ +package leon.utils + +object GraphOps { + + /** + * Takes an graph in form of a map (vertex -> out neighbors). + * Returns a topological sorting of the vertices (Right value) if there is one. + * If there is none, it returns the set of vertices that belong to a cycle + * or come before a cycle (Left value) + */ + def topologicalSorting[A](toPreds: Map[A,Set[A]]) : Either[Set[A], Seq[A]] = { + def tSort(toPreds: Map[A, Set[A]], done: Seq[A]): Either[Set[A], Seq[A]] = { + val (noPreds, hasPreds) = toPreds.partition { _._2.isEmpty } + if (noPreds.isEmpty) { + if (hasPreds.isEmpty) Right(done.reverse) + else Left(hasPreds.keySet) + } + else { + val found : Seq[A] = noPreds.keys.toSeq + tSort(hasPreds mapValues { _ -- found }, found ++ done) + } + } + tSort(toPreds, Seq()) + } + + /** + * Returns the set of reachable nodes from a given node, + * not including the node itself (unless it is member of a cycle) + * @param next A function giving the nodes directly accessible from a given node + * @param source The source from which to begin the search + */ + def reachable[A](next : A => Set[A], source : A) : Set[A] = { + var seen = Set[A]() + def rec(current : A) { + val notSeen = next(current) -- seen + seen ++= notSeen + for (node <- notSeen) { + rec(node) + } + } + rec (source) + seen + } + + /** + * Returns true if there is a path from source to target. + * @param next A function giving the nodes directly accessible from a given node + */ + def isReachable[A](next : A => Set[A], source : A, target : A) : Boolean = { + var seen : Set[A] = Set(source) + def rec(current : A) : Boolean = { + val notSeen = next(current) -- seen + seen ++= notSeen + (next(current) contains target) || (notSeen exists rec) + } + rec(source) + } + +} \ No newline at end of file