Skip to content
Snippets Groups Projects
Commit 6a21d74f authored by Mikaël Mayer's avatar Mikaël Mayer
Browse files

Keep original functions during function replacement if they can be kept.

parent d05422f7
No related branches found
No related tags found
No related merge requests found
......@@ -290,15 +290,42 @@ object DefOps {
: (Program, Map[FunDef, FunDef])= {
var fdMapCache = Map[FunDef, FunDef]()
var fdStatic = Set[FunDef]()
def fdMap(fd: FunDef): FunDef = {
if (!(fdMapCache contains fd)) {
fdMapCache += fd -> fdMapF(fd).getOrElse(fd.duplicate())
}
fdMapCache(fd)
if(fdStatic(fd)) fd else {
// We have to duplicate all calling functions except those not required to change and whose transitive descendants are all not required to change.
// If one descendant is required to change, then all its transitive callers have to change including the function itself.
val mappings = for(callee <- (fd::p.callGraph.transitiveCallees(fd).toList).distinct) yield {
callee -> fdMapCache.get(callee).orElse(fdMapF(callee))
}
if(mappings.exists(kv => kv._2 != None && kv._2 != Some(kv._1))) {
for(m <- mappings) {
m match {
case (f, Some(newFd)) =>
fdMapCache += f -> newFd
for(caller <- p.callGraph.transitiveCallers(f)) {
if(!(fdMapCache contains caller)) {
fdMapCache += f -> f.duplicate()
}
}
case (f, None) =>
}
}
if(fdMapCache contains fd) {
fdMapCache(fd)
} else {
fdStatic += fd
fd
}
} else {
fdStatic += fd
fd
}
}
} else fdMapCache(fd)
}
val newP = p.copy(units = for (u <- p.units) yield {
u.copy(
defs = u.defs.map {
......@@ -313,17 +340,28 @@ object DefOps {
}
)
})
// TODO: Check for function invocations in unapply patterns.
for(fd <- newP.definedFunctions) {
if(ExprOps.exists{ case FunctionInvocation(TypedFunDef(fd, targs), fargs) => fdMapCache contains fd case _ => false }(fd.fullBody)) {
fd.fullBody = replaceFunCalls(fd.fullBody, fdMapCache, fiMapF)
if(ExprOps.exists{
case FunctionInvocation(TypedFunDef(fd, targs), fargs) => fdMapCache contains fd
case MatchExpr(_, cases) => cases.exists(c => PatternOps.exists{
case UnapplyPattern(optId, TypedFunDef(fd, tps), subp) => fdMapCache contains fd
case _ => false
}(c.pattern))
case _ => false
}(fd.fullBody)) {
fd.fullBody = replaceFunCalls(fd.fullBody, fdMapCache.withDefault { x => x }, fiMapF)
}
}
(newP, fdMapCache)
}
def replaceFunCalls(e: Expr, fdMapF: FunDef => FunDef, fiMapF: (FunctionInvocation, FunDef) => Option[Expr] = defaultFiMap) = {
def replaceFunCalls(e: Expr, fdMapF: FunDef => FunDef, fiMapF: (FunctionInvocation, FunDef) => Option[Expr] = defaultFiMap): Expr = {
preMap {
case MatchExpr(scrut, cases) =>
Some(MatchExpr(scrut, cases.map(matchcase => matchcase match {
case MatchCase(pattern, guard, rhs) => MatchCase(replaceFunCalls(pattern, fdMapF), guard, rhs)
})))
case fi @ FunctionInvocation(TypedFunDef(fd, tps), args) =>
fiMapF(fi, fdMapF(fd)).map(_.setPos(fi))
case _ =>
......@@ -331,6 +369,10 @@ object DefOps {
}(e)
}
def replaceFunCalls(p: Pattern, fdMapF: FunDef => FunDef): Pattern = PatternOps.preMap{
case UnapplyPattern(optId, TypedFunDef(fd, tps), subp) => Some(UnapplyPattern(optId, TypedFunDef(fdMapF(fd), tps), subp))
case _ => None
}(p)
private def defaultCdMap(cc: CaseClass, ccd: CaseClassType): Option[Expr] = (cc, ccd) match {
case (CaseClass(old, args), newCcd) if old.classDef != newCcd =>
......@@ -488,7 +530,12 @@ object DefOps {
)
})
if (!found) {
println("addDefs could not find anchor definition!")
println(s"addDefs could not find anchor definition! Not found: $after")
p.definedFunctions.filter(f => f.id.name == after.id.name).map(fd => fd.id.name + " : " + fd) match {
case Nil =>
case e => println("Did you mean " + e)
}
println(Thread.currentThread().getStackTrace.map(_.toString).take(10).mkString("\n"))
}
res
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment