// MethodLifting.scala (3.96 KiB)
// Authored by Manos Koukoutos
/* Copyright 2009-2014 EPFL, Lausanne */
package leon
package purescala
import Common._
import Definitions._
import Expressions._
import Extractors._
import ExprOps._
import Types._
import TypeOps.instantiateType
/** Transformation phase that turns methods into top-level functions.
  *
  * For every method `m` of a class-hierarchy root `C` that has methods, a new
  * function named `C$m` is created. Its type parameters are freshened copies of
  * the class type parameters followed by the method's own; its first value
  * parameter is an explicit receiver named `$this`, followed by fresh copies of
  * the method's parameters. Inside every function body, `this` is replaced by
  * the receiver variable and each `MethodInvocation` is rewritten into a
  * `FunctionInvocation` of the lifted function.
  *
  * Side effects: the `fullBody` of the affected `FunDef`s is reassigned in
  * place, and `clearMethods()` is called on every class definition found in
  * the program's modules.
  */
object MethodLifting extends TransformationPhase {

  val name = "Method Lifting"
  val description = "Translate methods into top-level functions"

  /** Runs the phase.
    *
    * @param ctx     context whose reporter is used to signal fatal errors
    * @param program the program to transform
    * @return a new `Program` with the same id, whose modules contain the
    *         lifted functions and whose imports point at the rebuilt modules
    */
  def apply(ctx: LeonContext, program: Program): Program = {

    // First we create the appropriate functions from methods.
    // Maps each original method to its lifted top-level counterpart.
    val mdToFds: Map[FunDef, FunDef] = (for {
      cd <- program.classHierarchyRoots if cd.methods.nonEmpty
      fd <- cd.methods
    } yield {
      // We import class type params and freshen them (once per method)
      val ctParams = cd.tparams map { _.freshen }
      val tparamsMap = cd.tparams.zip(ctParams map { _.tp }).toMap

      val id = FreshIdentifier(cd.id.name+"$"+fd.id.name).setPos(fd.id)
      val recType = classDefToClassType(cd, ctParams.map(_.tp))
      val retType = instantiateType(fd.returnType, tparamsMap)

      // Fresh value parameters whose types use the freshened type parameters
      val fdParams = fd.params map { vd =>
        val newId = FreshIdentifier(vd.id.name, instantiateType(vd.id.getType, tparamsMap))
        ValDef(newId).setPos(vd.getPos)
      }
      val paramsMap = fd.params.zip(fdParams).map { case (x, y) => (x.id, y.id) }.toMap

      // The receiver becomes the first explicit parameter of the lifted function
      val receiver = FreshIdentifier("$this", recType).setPos(cd.id)

      val nfd = new FunDef(id, ctParams ++ fd.tparams, retType, ValDef(receiver) +: fdParams, fd.defType)
      nfd.copyContentFrom(fd)
      nfd.setPos(fd)
      nfd.fullBody = instantiateType(nfd.fullBody, tparamsMap, paramsMap)

      fd -> nfd
    }).toMap

    /** Returns the lifted function for `fd` if it is a lifted method, `fd`
      * itself otherwise; in both cases the body has `this` and method calls
      * removed (the body is reassigned in place).
      */
    def translateMethod(fd: FunDef): FunDef = {
      val (nfd, rec) = mdToFds.get(fd) match {
        case Some(lifted) =>
          // `this` will be replaced by the explicit receiver parameter
          (lifted, Some(() => Variable(lifted.params(0).id)))
        case None =>
          (fd, None)
      }

      nfd.fullBody = removeMethodCalls(rec)(nfd.fullBody)
      nfd
    }

    /** Rewrites `e` so that it contains neither `This` nor `MethodInvocation`.
      *
      * @param rec builds the expression that stands for `this`, or `None` when
      *            `e` does not belong to a lifted method; encountering `this`
      *            in that case is reported as a fatal error
      */
    def removeMethodCalls(rec: Option[() => Expr])(e: Expr): Expr = {
      postMap{
        case th: This =>
          rec match {
            case Some(r) =>
              Some(r().setPos(th))
            case None =>
              ctx.reporter.fatalError("`this` used out of a method context?!?")
          }

        case mi @ MethodInvocation(receiver, _, tfd, args) =>
          receiver match {
            case IsTyped(typedReceiver, ct: ClassType) =>
              // Report an unknown method through the reporter rather than
              // failing with a bare NoSuchElementException from Map.apply.
              val lifted = mdToFds.getOrElse(
                tfd.fd,
                ctx.reporter.fatalError("MethodInvocation of a method that was not lifted: " + tfd.fd.id.name)
              )
              Some(FunctionInvocation(lifted.typed(ct.tps ++ tfd.tps), typedReceiver +: args).setPos(mi))
            case _ =>
              ctx.reporter.fatalError("MethodInvocation on a non-class receiver !?!")
          }

        case _ => None
      }(e)
    }

    // NOTE(review): `clearMethods()` below presumably empties `cd.methods`.
    // The import rewrite further down still needs the methods of imported
    // classes, so we record them here before any class is cleared.
    val methodsBeforeClearing = (for {
      u <- program.units
      m <- u.modules
      cd <- m.defs.collect { case cd: ClassDef => cd }
    } yield cd -> cd.methods).toMap

    // Maps every original module to its rebuilt counterpart.
    val modsToMods = ( for {
      u <- program.units
      m <- u.modules
    } yield (m, {
      // We remove methods from class definitions and add corresponding functions
      val newDefs = m.defs.flatMap {
        case acd: AbstractClassDef if acd.methods.nonEmpty =>
          acd +: acd.methods.map(translateMethod)
        case ccd: CaseClassDef if ccd.methods.nonEmpty =>
          ccd +: ccd.methods.map(translateMethod)
        case fd: FunDef =>
          List(translateMethod(fd))
        case d =>
          List(d)
      }

      // finally, we clear methods from classes
      m.defs.foreach {
        case cd: ClassDef =>
          cd.clearMethods()
        case _ =>
      }

      ModuleDef(m.id, newDefs, m.isStandalone )
    })).toMap

    val newUnits = program.units map { u => u.copy(

      imports = u.imports flatMap {
        case s @ SingleImport(c : ClassDef) =>
          // If a class is imported, also import the functions lifted from its
          // methods. Use the snapshot: the class's own method list has already
          // been cleared at this point. Methods without a lifted counterpart
          // are skipped.
          val methods = methodsBeforeClearing.getOrElse(c, c.methods)
          s :: methods.toList.flatMap { md => mdToFds.get(md).map(SingleImport(_)).toList }
        // If importing a ModuleDef, update to new ModuleDef
        case SingleImport(m : ModuleDef) => List(SingleImport(modsToMods(m)))
        case WildcardImport(m : ModuleDef) => List(WildcardImport(modsToMods(m)))
        case other => List(other)
      },

      modules = u.modules map modsToMods
    )}

    Program(program.id, newUnits)
  }

}