Skip to content
Snippets Groups Projects
Commit 4e9c02f2 authored by Emmanouil (Manos) Koukoutos's avatar Emmanouil (Manos) Koukoutos
Browse files

RestoreMethods after they have been turned to funs

parent 0f51db75
No related branches found
No related tags found
No related merge requests found
...@@ -12,6 +12,7 @@ object Main { ...@@ -12,6 +12,7 @@ object Main {
utils.TypingPhase, utils.TypingPhase,
FileOutputPhase, FileOutputPhase,
ScopingPhase, ScopingPhase,
purescala.RestoreMethods,
xlang.ArrayTransformation, xlang.ArrayTransformation,
xlang.EpsilonElimination, xlang.EpsilonElimination,
xlang.ImperativeCodeElimination, xlang.ImperativeCodeElimination,
...@@ -212,7 +213,7 @@ object Main { ...@@ -212,7 +213,7 @@ object Main {
} else if (settings.verify) { } else if (settings.verify) {
AnalysisPhase AnalysisPhase
} else { } else {
utils.FileOutputPhase purescala.RestoreMethods andThen utils.FileOutputPhase
} }
} }
......
package leon.purescala
import leon._
import leon.purescala.Definitions._
import leon.purescala.Common._
import leon.purescala.Trees._
import leon.purescala.TreeOps.{applyOnFunDef,preMapOnFunDef,replaceFromIDs,functionCallsOf}
import leon.purescala.TypeTrees._
import utils.GraphOps._
object RestoreMethods extends TransformationPhase {
val name = "Restore methods"
val description = "Restore methods that were previously turned into standalone functions"
/**
* New functions are returned, whereas classes are mutated
*/
def apply(ctx : LeonContext, p : Program) = {
var fd2Md = Map[FunDef, FunDef]()
var whoseMethods = Map[ClassDef, Seq[FunDef]]()
for ( (Some(cd : ClassDef), funDefs) <- p.definedFunctions.groupBy(_.enclosing).toSeq ) {
whoseMethods += cd -> funDefs
for (fn <- funDefs) {
val theName = try {
// Remove class name from function name, if it is there
val maybeClassName = fn.id.name.substring(0,cd.id.name.length)
if (maybeClassName == cd.id.name) {
fn.id.name.substring(cd.id.name.length + 1) // +1 to also remove $
} else {
fn.id.name
}
} catch {
case e : IndexOutOfBoundsException => fn.id.name
}
val md = new FunDef(
id = FreshIdentifier(theName),
tparams = fn.tparams diff cd.tparams,
params = fn.params.tail, // no this$
returnType = fn.returnType
).copiedFrom(fn)
md.copyContentFrom(fn)
// first parameter should become "this"
val thisType = fn.params.head.getType.asInstanceOf[ClassType]
val theMap = Map(fn.params.head.id -> This(thisType))
val mdFinal = applyOnFunDef(replaceFromIDs(theMap, _))(md)
fd2Md += fn -> mdFinal
}
}
/**
* Substitute a function in an expression with the respective new method
*/
def substituteMethods = preMapOnFunDef ({
case FunctionInvocation(tfd,args) if fd2Md contains tfd.fd => {
val md = fd2Md.get(tfd.fd).get // the method we are substituting
val mi = MethodInvocation(
args.head, // "this"
args.head.getType.asInstanceOf[ClassType].classDef, // this.type
md.typed(tfd.tps.takeRight(md.tparams.length)), // throw away class parameters
args.tail // rest of the arguments
)
Some(mi)
}
case _ => None
}, true) _
/**
* Renew that function map by applying subsituteMethods on its values to obtain correct functions
*/
val fd2MdFinal = fd2Md.mapValues(substituteMethods)
// We need a special type of transitive closure, detecting only trans. calls on the same argument
def transCallsOnSameArg(fd : FunDef) : Set[FunDef] = {
require(fd.params.length == 1)
require(fd.params.head.getType.isInstanceOf[ClassType])
def callsOnSameArg(fd : FunDef) : Set[FunDef] = {
val theArg = fd.params.head.toVariable
functionCallsOf(fd.fullBody) collect { case fi if fi.args contains theArg => fi.tfd.fd }
}
reachable(callsOnSameArg,fd)
}
def refreshModule(m : ModuleDef) = {
val newFuns : Seq[FunDef] = m.definedFunctions diff fd2MdFinal.keys.toSeq map substituteMethods// only keep non-methods
for (cl <- m.definedClasses) {
// We're going through some hoops to ensure strict fields are defined in topological order
// We need to work with the functions of the original program to have access to CallGraph
val (strict, other) = whoseMethods.getOrElse(cl,Seq()).partition{ fd2MdFinal(_).canBeStrictField }
val strictSet = strict.toSet
// Make the call-subgraph that only includes the strict fields of this class
val strictCallGraph = strict.map { st =>
(st, transCallsOnSameArg(st) & strictSet)
}.toMap
// Topologically sort, or warn in case of cycle
val strictOrdered = topologicalSorting(strictCallGraph) fold (
cycle => {
ctx.reporter.warning(
s"""|Fields
|${cycle map {_.id} mkString "\n"}
|are involved in circular definition!""".stripMargin
)
strict
},
r => r
)
for (fun <- strictOrdered ++ other) {
cl.registerMethod(fd2MdFinal(fun))
}
}
m.copy(defs = m.definedClasses ++ newFuns).copiedFrom(m)
}
p.copy(modules = p.modules map refreshModule)
}
}
package leon.utils
object GraphOps {
/**
* Takes an graph in form of a map (vertex -> out neighbors).
* Returns a topological sorting of the vertices (Right value) if there is one.
* If there is none, it returns the set of vertices that belong to a cycle
* or come before a cycle (Left value)
*/
def topologicalSorting[A](toPreds: Map[A,Set[A]]) : Either[Set[A], Seq[A]] = {
def tSort(toPreds: Map[A, Set[A]], done: Seq[A]): Either[Set[A], Seq[A]] = {
val (noPreds, hasPreds) = toPreds.partition { _._2.isEmpty }
if (noPreds.isEmpty) {
if (hasPreds.isEmpty) Right(done.reverse)
else Left(hasPreds.keySet)
}
else {
val found : Seq[A] = noPreds.keys.toSeq
tSort(hasPreds mapValues { _ -- found }, found ++ done)
}
}
tSort(toPreds, Seq())
}
/**
* Returns the set of reachable nodes from a given node,
* not including the node itself (unless it is member of a cycle)
* @param next A function giving the nodes directly accessible from a given node
* @param source The source from which to begin the search
*/
def reachable[A](next : A => Set[A], source : A) : Set[A] = {
var seen = Set[A]()
def rec(current : A) {
val notSeen = next(current) -- seen
seen ++= notSeen
for (node <- notSeen) {
rec(node)
}
}
rec (source)
seen
}
/**
* Returns true if there is a path from source to target.
* @param next A function giving the nodes directly accessible from a given node
*/
def isReachable[A](next : A => Set[A], source : A, target : A) : Boolean = {
var seen : Set[A] = Set(source)
def rec(current : A) : Boolean = {
val notSeen = next(current) -- seen
seen ++= notSeen
(next(current) contains target) || (notSeen exists rec)
}
rec(source)
}
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment