Skip to content
Snippets Groups Projects
Commit 615494a4 authored by Mikaël Mayer's avatar Mikaël Mayer
Browse files

Working fundef and classdef replacement w.r.t. the tests (thanks to Etienne)

parent 325d0379
No related branches found
No related tags found
No related merge requests found
...@@ -289,42 +289,43 @@ object DefOps { ...@@ -289,42 +289,43 @@ object DefOps {
fiMapF: (FunctionInvocation, FunDef) => Option[Expr] = defaultFiMap) fiMapF: (FunctionInvocation, FunDef) => Option[Expr] = defaultFiMap)
: (Program, Map[FunDef, FunDef])= { : (Program, Map[FunDef, FunDef])= {
var fdMapCache = Map[FunDef, FunDef]() var fdMapFCache = Map[FunDef, Option[FunDef]]() // Original fdMapF cache
var fdStatic = Set[FunDef]() var fdMapCache = Map[FunDef, Option[FunDef]]() // Final replacement.
def fdMap(fd: FunDef): FunDef = { def fdMapFCached(fd: FunDef): Option[FunDef] = {
if (!(fdMapCache contains fd)) { fdMapFCache.get(fd) match {
if(fdStatic(fd)) fd else { case Some(e) => e
// We have to duplicate all calling functions except those not required to change and whose transitive descendants are all not required to change. case None =>
// If one descendant is required to change, then all its transitive callers have to change including the function itself. val new_fd = fdMapF(fd)
// The RHS of mappings is true if there is a change to propagate for this function. fdMapFCache += fd -> new_fd
val mappings = for(callee <- (fd::p.callGraph.transitiveCallees(fd).toList).distinct) yield { new_fd
callee -> (fdMapCache.get(callee) match { }
case Some(newFd) => None }
case None => fdMapF(callee)
}) def duplicateParents(fd: FunDef): Unit = {
} fdMapCache.get(fd) match {
for(m <- mappings) { case None =>
m match { fdMapCache += fd -> fdMapFCached(fd).orElse(Some(fd.duplicate()))
case (f, Some(newFd)) => for(fp <- p.callGraph.callers(fd)) {
fdMapCache += f -> newFd duplicateParents(fp)
for(caller <- p.callGraph.transitiveCallers(f)) {
if(!(fdMapCache contains caller)) {
fdMapCache += f -> f.duplicate()
}
}
case (f, _) =>
}
} }
if(fdMapCache contains fd) { case _ =>
fdMapCache(fd) }
} else { }
fdStatic += fd
fd def fdMap(fd: FunDef): FunDef = {
fdMapCache.get(fd) match {
case Some(Some(e)) => e
case Some(None) => fd
case None =>
if(fdMapFCached(fd).isDefined || p.callGraph.transitiveCallees(fd).exists(fd => fdMapFCached(fd).isDefined)) {
duplicateParents(fd)
} else { // Verify that for all
fdMapCache += fd -> None
} }
} fdMapCache(fd).getOrElse(fd)
} else fdMapCache(fd) }
} }
val newP = p.copy(units = for (u <- p.units) yield { val newP = p.copy(units = for (u <- p.units) yield {
u.copy( u.copy(
defs = u.defs.map { defs = u.defs.map {
...@@ -339,7 +340,7 @@ object DefOps { ...@@ -339,7 +340,7 @@ object DefOps {
} }
) )
}) })
for(fd <- newP.definedFunctions) { for(fd <- newP.definedFunctions) {
if(ExprOps.exists{ if(ExprOps.exists{
case FunctionInvocation(TypedFunDef(fd, targs), fargs) => fdMapCache contains fd case FunctionInvocation(TypedFunDef(fd, targs), fargs) => fdMapCache contains fd
...@@ -349,10 +350,10 @@ object DefOps { ...@@ -349,10 +350,10 @@ object DefOps {
}(c.pattern)) }(c.pattern))
case _ => false case _ => false
}(fd.fullBody)) { }(fd.fullBody)) {
fd.fullBody = replaceFunCalls(fd.fullBody, fdMapCache.withDefault { x => x }, fiMapF) fd.fullBody = replaceFunCalls(fd.fullBody, fdMap, fiMapF)
} }
} }
(newP, fdMapCache) (newP, fdMapCache.collect{ case (ofd, Some(nfd)) => ofd -> nfd })
} }
def replaceFunCalls(e: Expr, fdMapF: FunDef => FunDef, fiMapF: (FunctionInvocation, FunDef) => Option[Expr] = defaultFiMap): Expr = { def replaceFunCalls(e: Expr, fdMapF: FunDef => FunDef, fiMapF: (FunctionInvocation, FunDef) => Option[Expr] = defaultFiMap): Expr = {
...@@ -384,97 +385,77 @@ object DefOps { ...@@ -384,97 +385,77 @@ object DefOps {
* *
* @param p The original program * @param p The original program
* @param cdMapF Given c returns Some(d) where d can take an abstract parent and return a class e if c should be replaced by e, and None if c should be kept. * @param cdMapF Given c returns Some(d) where d can take an abstract parent and return a class e if c should be replaced by e, and None if c should be kept.
* @param fiMapF Given a previous case class invocation and its new case class definition, returns the expression to use. * @param ciMapF Given a previous case class invocation and its new case class definition, returns the expression to use.
* By default it is the case class construction using the new case class definition. * By default it is the case class construction using the new case class definition.
* @return the new program with a map from the old case classes to the new case classes */ * @return the new program with a map from the old case classes to the new case classes, with maps concerning identifiers and function definitions. */
def replaceCaseClassDefs(p: Program)(cdMapF: CaseClassDef => Option[Option[AbstractClassType] => CaseClassDef], def replaceCaseClassDefs(p: Program)(_cdMapF: CaseClassDef => Option[Option[AbstractClassType] => CaseClassDef],
ciMapF: (CaseClass, CaseClassType) => Option[Expr] = defaultCdMap) ciMapF: (CaseClass, CaseClassType) => Option[Expr] = defaultCdMap)
: (Program, Map[ClassDef, ClassDef], Map[Identifier, Identifier], Map[FunDef, FunDef]) = { : (Program, Map[ClassDef, ClassDef], Map[Identifier, Identifier], Map[FunDef, FunDef]) = {
var cdMapCache = Map[ClassDef, ClassDef]() var cdMapFCache = Map[CaseClassDef, Option[Option[AbstractClassType] => CaseClassDef]]()
var cdStatic = Set[ClassDef]() var cdMapCache = Map[ClassDef, Option[ClassDef]]()
var idMapCache = Map[Identifier, Identifier]() var idMapCache = Map[Identifier, Identifier]()
var fdMapCache = Map[FunDef, FunDef]() var fdMapFCache = Map[FunDef, Option[FunDef]]()
def tpMap(tt: TypeTree): TypeTree = TypeOps.postMap{ var fdMapCache = Map[FunDef, Option[FunDef]]()
def cdMapF(cd: ClassDef): Option[Option[AbstractClassType] => CaseClassDef] = {
cd match {
case ccd: CaseClassDef =>
cdMapFCache.getOrElse(ccd, {
val new_cd_potential = cdMapF(ccd)
cdMapFCache += ccd -> new_cd_potential
new_cd_potential
})
case acd: AbstractClassDef => None
}
}
def tpMap[T <: TypeTree](tt: T): T = TypeOps.postMap{
case AbstractClassType(asd, targs) => Some(AbstractClassType(cdMap(asd).asInstanceOf[AbstractClassDef], targs)) case AbstractClassType(asd, targs) => Some(AbstractClassType(cdMap(asd).asInstanceOf[AbstractClassDef], targs))
case CaseClassType(ccd, targs) => Some(CaseClassType(cdMap(ccd).asInstanceOf[CaseClassDef], targs)) case CaseClassType(ccd, targs) => Some(CaseClassType(cdMap(ccd).asInstanceOf[CaseClassDef], targs))
case e => None case e => None
}(tt) }(tt).asInstanceOf[T]
def cdMap(cd: ClassDef): ClassDef = { def duplicateClassDef(cd: ClassDef): Unit = {
if (!(cdMapCache contains cd)) { cdMapCache.get(cd) match {
if(cdStatic(cd)) cd else { case Some(new_cd) =>
// If at least one descendants or known case class needs conversion, then all the hierarchy will be converted. case None =>
// If something extends List[A] and A is modified, then the first something should be modified. val parent = cd.parent.map(duplicateAbstractClassType)
def dependencies(s: ClassDef): Set[ClassDef] = { val new_cd = cdMapF(cd).map(f => f(parent)).getOrElse{
Set(s) ++ s.parent.toList.flatMap(p => TypeOps.collect[ClassDef]{ cd match {
case AbstractClassType(acd, _) => Set(acd:ClassDef) ++ acd.knownCCDescendants case acd:AbstractClassDef => acd.duplicate(parent = parent)
case CaseClassType(ccd, _) => Set(ccd:ClassDef) case ccd:CaseClassDef => ccd.duplicate(parent = parent, fields = ccd.fieldsIds.map(id => ValDef(idMap(id)))) // Should not cycle since fields have to be abstract.
}(p))
}
def hierarchy(s: ClassDef): Seq[ClassDef] = s::s.ancestors.toList
def collectDependencies(c: ClassDef) = leon.utils.fixpoint((s: Set[ClassDef]) => s.flatMap(dependencies))(Set(c))
val collectedDependencies = collectDependencies(cd)
val mappings = (for(callee <- collectedDependencies.collect{ case c: CaseClassDef => c}) yield {
callee -> (cdMapCache.get(callee) match {
case Some(newcd) => None
case None => cdMapF(callee)
})
}).toMap
def isChanging(c: ClassDef): Boolean = {
c match {
case ccd: CaseClassDef =>
mappings.get(ccd) match {
case Some(Some(_)) => true
case _ => false }
case acd: AbstractClassDef =>
acd.knownCCDescendants.exists(isChanging)
}
}
def duplicateClassDef(cd: ClassDef): ClassDef = {
cdMapCache.get(cd) match {
case Some(new_cd) => new_cd
case None =>
val old_parent = cd.parent
val parent = old_parent.map(duplicateAbstractClassType)
val new_cd = (cd match { case cc:CaseClassDef => mappings.get(cc) case _ => None }) match {
case Some(Some(new_cd_if_parent)) =>
new_cd_if_parent(parent)
case _ =>
cd match {
case acd:AbstractClassDef => acd.duplicate(parent = parent)
case ccd:CaseClassDef => ccd.duplicate(parent = parent, fields = ccd.fieldsIds.map(id => ValDef(idMap(id)))) // Should not cycle since fields have to be abstract.
}
}
cdMapCache += cd -> new_cd
new_cd
} }
} }
// TODO: Do not unnecessarily duplicate the argument types if they don't have to change. cdMapCache += cd -> Some(new_cd)
def duplicateAbstractClassType(act: AbstractClassType): AbstractClassType = { }
TypeOps.postMap{ }
case AbstractClassType(acd, tps) => Some(AbstractClassType(duplicateClassDef(acd).asInstanceOf[AbstractClassDef], tps))
case CaseClassType(ccd, tps) => Some(CaseClassType(duplicateClassDef(ccd).asInstanceOf[CaseClassDef], tps)) def duplicateAbstractClassType(act: AbstractClassType): AbstractClassType = {
case _ => None TypeOps.postMap{
}(act).asInstanceOf[AbstractClassType] case AbstractClassType(acd, tps) => Some(AbstractClassType(duplicateClassDef(acd).asInstanceOf[AbstractClassDef], tps))
} case CaseClassType(ccd, tps) => Some(CaseClassType(duplicateClassDef(ccd).asInstanceOf[CaseClassDef], tps))
for((c, funMod) <- mappings) { case _ => None
// If the dependencies of c contain a class to be transformed, duplicate this class and its hierarchy }(act).asInstanceOf[AbstractClassType]
if(funMod.nonEmpty || collectDependencies(c).exists(isChanging)) { }
cdMapCache += c -> duplicateClassDef(c)
} else { // If at least one descendants or known case class needs conversion, then all the hierarchy will be converted.
cdStatic += c // If something extends List[A] and A is modified, then the first something should be modified.
} def dependencies(s: ClassDef): Set[ClassDef] = {
Set(s) ++ s.parent.toList.flatMap(p => TypeOps.collect[ClassDef]{
case AbstractClassType(acd, _) => Set(acd:ClassDef) ++ acd.knownCCDescendants
case CaseClassType(ccd, _) => Set(ccd:ClassDef)
}(p))
}
def cdMap(cd: ClassDef): ClassDef = {
cdMapCache.get(cd) match {
case Some(Some(new_cd)) => new_cd
case Some(None) => cd
case None =>
if(cdMapF(cd).isDefined || dependencies(cd).exists(cd => cdMapF(cd).isDefined)) { // Needs replacement in any case.
duplicateClassDef(cd)
} else {
cdMapCache += cd -> None
} }
cdMapCache.getOrElse(cd, cd) cdMapCache(cd).getOrElse(cd)
}
} else {
cdMapCache(cd)
} }
} }
def idMap(id: Identifier): Identifier = { def idMap(id: Identifier): Identifier = {
...@@ -483,11 +464,97 @@ object DefOps { ...@@ -483,11 +464,97 @@ object DefOps {
} }
idMapCache(id) idMapCache(id)
} }
def idHasToChange(id: Identifier): Boolean = {
typeHasToChange(id.getType)
}
def typeHasToChange(tp: TypeTree): Boolean = {
TypeOps.exists{
case AbstractClassType(acd, _) => cdMap(acd) != acd
case CaseClassType(ccd, _) => cdMap(ccd) != ccd
}(tp)
}
def patternHasToChange(p: Pattern): Boolean = {
PatternOps.exists {
case CaseClassPattern(optId, cct, sub) => optId.exists(idHasToChange) || typeHasToChange(cct)
case InstanceOfPattern(optId, cct) => optId.exists(idHasToChange) || typeHasToChange(cct)
case Extractors.Pattern(optId, subp, builder) => optId.exists(idHasToChange)
case e => false
} (p)
}
def exprHasToChange(e: Expr): Boolean = {
ExprOps.exists{
case Let(id, expr, body) => idHasToChange(id)
case Variable(id) => idHasToChange(id)
case ci @ CaseClass(cct, args) => typeHasToChange(cct)
case CaseClassSelector(cct, expr, identifier) => typeHasToChange(cct) || idHasToChange(identifier)
case IsInstanceOf(e, cct) => typeHasToChange(cct)
case AsInstanceOf(e, cct) => typeHasToChange(cct)
case MatchExpr(scrut, cases) =>
cases.exists{
case MatchCase(pattern, optGuard, rhs) =>
patternHasToChange(pattern)
}
case fi @ FunctionInvocation(TypedFunDef(fd, tps), args) =>
tps.exists(typeHasToChange)
case _ =>
false
}(e)
}
def funDefHasToChange(fd: FunDef): Boolean = {
exprHasToChange(fd.fullBody) || fd.params.exists(vid => typeHasToChange(vid.id.getType)) || typeHasToChange(fd.returnType)
}
def funHasToChange(fd: FunDef): Boolean = {
funDefHasToChange(fd) || p.callGraph.transitiveCallees(fd).exists(fd =>
fdMapFCache.get(fd) match {
case Some(Some(_)) => true
case Some(None) => false
case None => funDefHasToChange(fd)
})
}
def fdMapFCached(fd: FunDef): Option[FunDef] = {
fdMapFCache.get(fd) match {
case Some(e) => e
case None =>
val new_fd = if(funHasToChange(fd)) {
Some(fd.duplicate(params = fd.params.map(vd => ValDef(idMap(vd.id))), returnType = tpMap(fd.returnType)))
} else {
None
}
fdMapFCache += fd -> new_fd
new_fd
}
}
def duplicateParents(fd: FunDef): Unit = {
fdMapCache.get(fd) match {
case None =>
fdMapCache += fd -> fdMapFCached(fd).orElse(Some(fd.duplicate()))
for(fp <- p.callGraph.callers(fd)) {
duplicateParents(fp)
}
case _ =>
}
}
def fdMap(fd: FunDef): FunDef = { def fdMap(fd: FunDef): FunDef = {
if (!(fdMapCache contains fd)) { fdMapCache.get(fd) match {
fdMapCache += fd -> fd.duplicate(params = fd.params.map(vd => ValDef(idMap(vd.id))), returnType = tpMap(fd.returnType)) case Some(Some(e)) => e
case Some(None) => fd
case None =>
if(fdMapFCached(fd).isDefined || p.callGraph.transitiveCallees(fd).exists(fd => fdMapFCached(fd).isDefined)) {
duplicateParents(fd)
} else { // Verify that for all
fdMapCache += fd -> None
}
fdMapCache(fd).getOrElse(fd)
} }
fdMapCache(fd)
} }
val newP = p.copy(units = for (u <- p.units) yield { val newP = p.copy(units = for (u <- p.units) yield {
...@@ -505,28 +572,9 @@ object DefOps { ...@@ -505,28 +572,9 @@ object DefOps {
} }
) )
}) })
object ToTransform {
def unapply(c: ClassType): Option[ClassDef] = Some(cdMap(c.classDef))
}
trait Transformed[T <: TypeTree] {
def unapply(c: T): Option[T] = Some(TypeOps.postMap({
case c: ClassType =>
val newClassDef = cdMap(c.classDef)
Some((c match {
case CaseClassType(ccd, tps) =>
CaseClassType(newClassDef.asInstanceOf[CaseClassDef], tps.map(e => TypeOps.postMap{ case TypeTransformed(ct) => Some(ct) case _ => None }(e)))
case AbstractClassType(acd, tps) =>
AbstractClassType(newClassDef.asInstanceOf[AbstractClassDef], tps.map(e => TypeOps.postMap{ case TypeTransformed(ct) => Some(ct) case _ => None }(e)))
}).asInstanceOf[T])
case _ => None
})(c).asInstanceOf[T])
}
object CaseClassTransformed extends Transformed[CaseClassType]
object ClassTransformed extends Transformed[ClassType]
object TypeTransformed extends Transformed[TypeTree]
def replaceClassDefUse(e: Pattern): Pattern = PatternOps.postMap{ def replaceClassDefUse(e: Pattern): Pattern = PatternOps.postMap{
case CaseClassPattern(optId, CaseClassTransformed(ct), sub) => Some(CaseClassPattern(optId.map(idMap), ct, sub)) case CaseClassPattern(optId, cct, sub) => Some(CaseClassPattern(optId.map(idMap), tpMap[CaseClassType](cct), sub))
case InstanceOfPattern(optId, ClassTransformed(ct)) => Some(InstanceOfPattern(optId.map(idMap), ct)) case InstanceOfPattern(optId, cct) => Some(InstanceOfPattern(optId.map(idMap), tpMap[ClassType](cct)))
case UnapplyPattern(optId, TypedFunDef(fd, tps), subp) => Some(UnapplyPattern(optId.map(idMap), TypedFunDef(fdMap(fd), tps.map(tpMap)), subp)) case UnapplyPattern(optId, TypedFunDef(fd, tps), subp) => Some(UnapplyPattern(optId.map(idMap), TypedFunDef(fdMap(fd), tps.map(tpMap)), subp))
case Extractors.Pattern(Some(id), subp, builder) => Some(builder(Some(idMap(id)), subp)) case Extractors.Pattern(Some(id), subp, builder) => Some(builder(Some(idMap(id)), subp))
case e => None case e => None
...@@ -536,13 +584,12 @@ object DefOps { ...@@ -536,13 +584,12 @@ object DefOps {
ExprOps.postMap { ExprOps.postMap {
case Let(id, expr, body) => Some(Let(idMap(id), expr, body)) case Let(id, expr, body) => Some(Let(idMap(id), expr, body))
case Variable(id) => Some(Variable(idMap(id))) case Variable(id) => Some(Variable(idMap(id)))
case ci @ CaseClass(CaseClassTransformed(ct), args) => case ci @ CaseClass(ct, args) =>
ciMapF(ci, ct).map(_.setPos(ci)) ciMapF(ci, tpMap(ct)).map(_.setPos(ci))
//case IsInstanceOf(e, ToTransform()) => case CaseClassSelector(cct, expr, identifier) =>
case CaseClassSelector(CaseClassTransformed(cct), expr, identifier) => Some(CaseClassSelector(tpMap(cct), expr, idMap(identifier)))
Some(CaseClassSelector(cct, expr, idMap(identifier))) case IsInstanceOf(e, ct) => Some(IsInstanceOf(e, tpMap(ct)))
case IsInstanceOf(e, ClassTransformed(ct)) => Some(IsInstanceOf(e, ct)) case AsInstanceOf(e, ct) => Some(AsInstanceOf(e, tpMap(ct)))
case AsInstanceOf(e, ClassTransformed(ct)) => Some(AsInstanceOf(e, ct))
case MatchExpr(scrut, cases) => case MatchExpr(scrut, cases) =>
Some(MatchExpr(scrut, cases.map{ Some(MatchExpr(scrut, cases.map{
case MatchCase(pattern, optGuard, rhs) => case MatchCase(pattern, optGuard, rhs) =>
...@@ -558,7 +605,10 @@ object DefOps { ...@@ -558,7 +605,10 @@ object DefOps {
for(fd <- newP.definedFunctions) { for(fd <- newP.definedFunctions) {
fd.fullBody = replaceClassDefsUse(fd.fullBody) fd.fullBody = replaceClassDefsUse(fd.fullBody)
} }
(newP, cdMapCache, idMapCache, fdMapCache) (newP,
cdMapCache.collect{case (cd, Some(new_cd)) => cd -> new_cd},
idMapCache,
fdMapCache.collect{case (cd, Some(new_cd)) => cd -> new_cd })
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment