Skip to content
Snippets Groups Projects
Commit 90bfec6d authored by Manos Koukoutos's avatar Manos Koukoutos
Browse files

MethodLifting now handles complex inheritance correctly

parent 9c5101fc
No related branches found
No related tags found
No related merge requests found
...@@ -95,49 +95,43 @@ object MethodLifting extends TransformationPhase { ...@@ -95,49 +95,43 @@ object MethodLifting extends TransformationPhase {
// First we create the appropriate functions from methods: // First we create the appropriate functions from methods:
var mdToFds = Map[FunDef, FunDef]() var mdToFds = Map[FunDef, FunDef]()
var mdToCls = Map[FunDef, ClassDef]()
// Lift methods to the root class
for {
u <- program.units
ch <- u.classHierarchies
c <- ch
if c.parent.isDefined
fd <- c.methods
if c.ancestors.forall(!_.methods.map{_.id}.contains(fd.id))
} {
val root = c.ancestors.last
val tMap = c.tparams.zip(root.tparams.map{_.tp}).toMap
val tSubst: TypeTree => TypeTree = instantiateType(_, tMap)
val fdParams = fd.params map { vd =>
val newId = FreshIdentifier(vd.id.name, tSubst(vd.id.getType))
ValDef(newId).setPos(vd.getPos)
}
val paramsMap = fd.params.zip(fdParams).map{ case (from, to) => from.id -> to.id }.toMap
val eSubst: Expr => Expr = instantiateType(_, tMap, paramsMap)
val newUnits = for (u <- program.units) yield { val newFd = new FunDef(fd.id, fd.tparams, tSubst(fd.returnType), fdParams).copiedFrom(fd)
var fdsOf = Map[String, Set[FunDef]]() newFd.copyContentFrom(fd)
// Lift methods to the root class mdToCls += newFd -> c
for {
ch <- u.classHierarchies
c <- ch
if c.parent.isDefined
fd <- c.methods
if c.ancestors.forall(!_.methods.map{_.id}.contains(fd.id))
} {
val root = c.ancestors.last
val tMap = c.tparams.zip(root.tparams.map{_.tp}).toMap
val tSubst: TypeTree => TypeTree = instantiateType(_, tMap)
val fdParams = fd.params map { vd => newFd.fullBody = eSubst(newFd.fullBody)
val newId = FreshIdentifier(vd.id.name, tSubst(vd.id.getType))
ValDef(newId).setPos(vd.getPos)
}
val paramsMap = fd.params.zip(fdParams).map{ case (from, to) => from.id -> to.id }.toMap
val eSubst: Expr => Expr = instantiateType(_, tMap, paramsMap)
val newFd = new FunDef(fd.id, fd.tparams, tSubst(fd.returnType), fdParams).copiedFrom(fd)
newFd.copyContentFrom(fd)
val prec = fd.precondition.getOrElse(BooleanLiteral(true))
newFd.fullBody = eSubst(withPrecondition(
newFd.fullBody,
Some(and(
prec,
isInstOf(
This(root.typed),
c.typed(root.tparams.map{ _.tp })
)
))
))
c.unregisterMethod(fd.id) c.unregisterMethod(fd.id)
root.registerMethod(newFd) root.registerMethod(newFd)
} }
val newUnits = for (u <- program.units) yield {
var fdsOf = Map[String, Set[FunDef]]()
// 1) Create one function for each method // 1) Create one function for each method
for { cd <- u.classHierarchyRoots if cd.methods.nonEmpty; fd <- cd.methods } { for { cd <- u.classHierarchyRoots; fd <- cd.methods } {
// We import class type params and freshen them // We import class type params and freshen them
val ctParams = cd.tparams map { _.freshen } val ctParams = cd.tparams map { _.freshen }
val tparamsMap = cd.tparams.zip(ctParams map { _.tp }).toMap val tparamsMap = cd.tparams.zip(ctParams map { _.tp }).toMap
...@@ -157,6 +151,14 @@ object MethodLifting extends TransformationPhase { ...@@ -157,6 +151,14 @@ object MethodLifting extends TransformationPhase {
nfd.setPos(fd) nfd.setPos(fd)
nfd.addFlag(IsMethod(cd)) nfd.addFlag(IsMethod(cd))
def classPre(fd: FunDef) = mdToCls.get(fd) match {
case None =>
BooleanLiteral(true)
case Some(cl) =>
isInstOf(Variable(receiver), cl.typed(ctParams map { _.tp }))
}
if (cd.knownDescendants.forall( cd => (cd.methods ++ cd.fields).forall(_.id != fd.id))) { if (cd.knownDescendants.forall( cd => (cd.methods ++ cd.fields).forall(_.id != fd.id))) {
// Don't need to compose methods // Don't need to compose methods
val paramsMap = fd.params.zip(fdParams).map{case (x,y) => (x.id, y.id)}.toMap val paramsMap = fd.params.zip(fdParams).map{case (x,y) => (x.id, y.id)}.toMap
...@@ -168,8 +170,15 @@ object MethodLifting extends TransformationPhase { ...@@ -168,8 +170,15 @@ object MethodLifting extends TransformationPhase {
} }
val insTp: Expr => Expr = instantiateType(_, tparamsMap, paramsMap) val insTp: Expr => Expr = instantiateType(_, tparamsMap, paramsMap)
nfd.fullBody = insTp( postMap(thisToReceiver)(insTp(nfd.fullBody)) ) nfd.fullBody = insTp( postMap(thisToReceiver)(insTp(nfd.fullBody)) )
// Add precondition if the method was defined in a subclass
val pre = and(
classPre(fd),
nfd.precondition.getOrElse(BooleanLiteral(true))
)
nfd.fullBody = withPrecondition(nfd.fullBody, Some(pre))
} else { } else {
// We need to compose methods of subclasses // We need to compose methods of subclasses
...@@ -189,7 +198,7 @@ object MethodLifting extends TransformationPhase { ...@@ -189,7 +198,7 @@ object MethodLifting extends TransformationPhase {
(from,to) <- m.tparams zip fd.tparams (from,to) <- m.tparams zip fd.tparams
} yield (from, to.tp)).toMap } yield (from, to.tp)).toMap
def inst(cs: Seq[MatchCase]) = instantiateType( def inst(cs: Seq[MatchCase]) = instantiateType(
MatchExpr(Variable(receiver), cs).setPos(fd), matchExpr(Variable(receiver), cs).setPos(fd),
classParamsMap ++ methodParamsMap, classParamsMap ++ methodParamsMap,
paramsMap paramsMap
) )
...@@ -201,10 +210,18 @@ object MethodLifting extends TransformationPhase { ...@@ -201,10 +210,18 @@ object MethodLifting extends TransformationPhase {
)) ))
val (body, _) = makeCases(cd, fd.id, withoutSpec(_).getOrElse(NoTree(retType))) val (body, _) = makeCases(cd, fd.id, withoutSpec(_).getOrElse(NoTree(retType)))
/* Some obvious simplifications */ // Some simplifications
val preSimple = { val preSimple = {
val trivial = pre.forall { _.rhs == BooleanLiteral(true) } val nonTrivial = pre.count{ _.rhs != BooleanLiteral(true) }
if (trivial) None else Some(inst(pre).setPos(fd.getPos))
val compositePre =
if (nonTrivial == 0) {
BooleanLiteral(true)
} else {
inst(pre).setPos(fd.getPos)
}
Some(and(classPre(fd), compositePre))
} }
val postSimple = { val postSimple = {
val trivial = post.forall { val trivial = post.forall {
...@@ -235,8 +252,8 @@ object MethodLifting extends TransformationPhase { ...@@ -235,8 +252,8 @@ object MethodLifting extends TransformationPhase {
withPrecondition(bodySimple, preSimple), withPrecondition(bodySimple, preSimple),
postSimple postSimple
) )
} }
mdToFds += fd -> nfd mdToFds += fd -> nfd
fdsOf += cd.id.name -> (fdsOf.getOrElse(cd.id.name, Set()) + nfd) fdsOf += cd.id.name -> (fdsOf.getOrElse(cd.id.name, Set()) + nfd)
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment