diff --git a/src/main/scala/leon/purescala/DefOps.scala b/src/main/scala/leon/purescala/DefOps.scala index f53e9c1e77b931351f834f810ff412a07189f525..78c2b93edecf113c717a625a3c974c8dcf43e474 100644 --- a/src/main/scala/leon/purescala/DefOps.scala +++ b/src/main/scala/leon/purescala/DefOps.scala @@ -289,42 +289,43 @@ object DefOps { fiMapF: (FunctionInvocation, FunDef) => Option[Expr] = defaultFiMap) : (Program, Map[FunDef, FunDef])= { - var fdMapCache = Map[FunDef, FunDef]() - var fdStatic = Set[FunDef]() - def fdMap(fd: FunDef): FunDef = { - if (!(fdMapCache contains fd)) { - if(fdStatic(fd)) fd else { - // We have to duplicate all calling functions except those not required to change and whose transitive descendants are all not required to change. - // If one descendant is required to change, then all its transitive callers have to change including the function itself. - // The RHS of mappings is true if there is a change to propagate for this function. - val mappings = for(callee <- (fd::p.callGraph.transitiveCallees(fd).toList).distinct) yield { - callee -> (fdMapCache.get(callee) match { - case Some(newFd) => None - case None => fdMapF(callee) - }) - } - for(m <- mappings) { - m match { - case (f, Some(newFd)) => - fdMapCache += f -> newFd - for(caller <- p.callGraph.transitiveCallers(f)) { - if(!(fdMapCache contains caller)) { - fdMapCache += f -> f.duplicate() - } - } - case (f, _) => - } + var fdMapFCache = Map[FunDef, Option[FunDef]]() // Original fdMapF cache + var fdMapCache = Map[FunDef, Option[FunDef]]() // Final replacement. + def fdMapFCached(fd: FunDef): Option[FunDef] = { + fdMapFCache.get(fd) match { + case Some(e) => e + case None => + val new_fd = fdMapF(fd) + fdMapFCache += fd -> new_fd + new_fd + } + } + + def duplicateParents(fd: FunDef): Unit = { + fdMapCache.get(fd) match { + case None => + fdMapCache += fd -> fdMapFCached(fd).orElse(Some(fd.duplicate())) + for(fp <- p.callGraph.callers(fd)) { + duplicateParents(fp) } - if(fdMapCache contains fd) { - fdMapCache(fd) - } else { - fdStatic += fd - fd + case _ => + } + } + + def fdMap(fd: FunDef): FunDef = { + fdMapCache.get(fd) match { + case Some(Some(e)) => e + case Some(None) => fd + case None => + if(fdMapFCached(fd).isDefined || p.callGraph.transitiveCallees(fd).exists(fd => fdMapFCached(fd).isDefined)) { + duplicateParents(fd) + } else { // Verify that for all + fdMapCache += fd -> None } - } - } else fdMapCache(fd) + fdMapCache(fd).getOrElse(fd) + } } - + val newP = p.copy(units = for (u <- p.units) yield { u.copy( defs = u.defs.map { @@ -339,7 +340,7 @@ object DefOps { } ) }) - + for(fd <- newP.definedFunctions) { if(ExprOps.exists{ case FunctionInvocation(TypedFunDef(fd, targs), fargs) => fdMapCache contains fd @@ -349,10 +350,10 @@ object DefOps { }(c.pattern)) case _ => false }(fd.fullBody)) { - fd.fullBody = replaceFunCalls(fd.fullBody, fdMapCache.withDefault { x => x }, fiMapF) + fd.fullBody = replaceFunCalls(fd.fullBody, fdMap, fiMapF) } } - (newP, fdMapCache) + (newP, fdMapCache.collect{ case (ofd, Some(nfd)) => ofd -> nfd }) } def replaceFunCalls(e: Expr, fdMapF: FunDef => FunDef, fiMapF: (FunctionInvocation, FunDef) => Option[Expr] = defaultFiMap): Expr = { @@ -384,97 +385,77 @@ object DefOps { * * @param p The original program * @param cdMapF Given c returns Some(d) where d can take an abstract parent and return a class e if c should be replaced by e, and None if c should be kept. - * @param fiMapF Given a previous case class invocation and its new case class definition, returns the expression to use. + * @param ciMapF Given a previous case class invocation and its new case class definition, returns the expression to use. * By default it is the case class construction using the new case class definition. - * @return the new program with a map from the old case classes to the new case classes */ - def replaceCaseClassDefs(p: Program)(cdMapF: CaseClassDef => Option[Option[AbstractClassType] => CaseClassDef], + * @return the new program with a map from the old case classes to the new case classes, with maps concerning identifiers and function definitions. */ + def replaceCaseClassDefs(p: Program)(_cdMapF: CaseClassDef => Option[Option[AbstractClassType] => CaseClassDef], ciMapF: (CaseClass, CaseClassType) => Option[Expr] = defaultCdMap) : (Program, Map[ClassDef, ClassDef], Map[Identifier, Identifier], Map[FunDef, FunDef]) = { - var cdMapCache = Map[ClassDef, ClassDef]() - var cdStatic = Set[ClassDef]() + var cdMapFCache = Map[CaseClassDef, Option[Option[AbstractClassType] => CaseClassDef]]() + var cdMapCache = Map[ClassDef, Option[ClassDef]]() var idMapCache = Map[Identifier, Identifier]() - var fdMapCache = Map[FunDef, FunDef]() - def tpMap(tt: TypeTree): TypeTree = TypeOps.postMap{ + var fdMapFCache = Map[FunDef, Option[FunDef]]() + var fdMapCache = Map[FunDef, Option[FunDef]]() + def cdMapF(cd: ClassDef): Option[Option[AbstractClassType] => CaseClassDef] = { + cd match { + case ccd: CaseClassDef => + cdMapFCache.getOrElse(ccd, { + val new_cd_potential = cdMapF(ccd) + cdMapFCache += ccd -> new_cd_potential + new_cd_potential + }) + case acd: AbstractClassDef => None + } + } + def tpMap[T <: TypeTree](tt: T): T = TypeOps.postMap{ case AbstractClassType(asd, targs) => Some(AbstractClassType(cdMap(asd).asInstanceOf[AbstractClassDef], targs)) case CaseClassType(ccd, targs) => Some(CaseClassType(cdMap(ccd).asInstanceOf[CaseClassDef], targs)) case e => None - }(tt) + }(tt).asInstanceOf[T] - def cdMap(cd: ClassDef): ClassDef = { - if (!(cdMapCache contains cd)) { - if(cdStatic(cd)) cd else { - // If at least one descendants or known case class needs conversion, then all the hierarchy will be converted. - // If something extends List[A] and A is modified, then the first something should be modified. - def dependencies(s: ClassDef): Set[ClassDef] = { - Set(s) ++ s.parent.toList.flatMap(p => TypeOps.collect[ClassDef]{ - case AbstractClassType(acd, _) => Set(acd:ClassDef) ++ acd.knownCCDescendants - case CaseClassType(ccd, _) => Set(ccd:ClassDef) - }(p)) - } - - def hierarchy(s: ClassDef): Seq[ClassDef] = s::s.ancestors.toList - - def collectDependencies(c: ClassDef) = leon.utils.fixpoint((s: Set[ClassDef]) => s.flatMap(dependencies))(Set(c)) - - val collectedDependencies = collectDependencies(cd) - - val mappings = (for(callee <- collectedDependencies.collect{ case c: CaseClassDef => c}) yield { - callee -> (cdMapCache.get(callee) match { - case Some(newcd) => None - case None => cdMapF(callee) - }) - }).toMap - - def isChanging(c: ClassDef): Boolean = { - c match { - case ccd: CaseClassDef => - mappings.get(ccd) match { - case Some(Some(_)) => true - case _ => false } - case acd: AbstractClassDef => - acd.knownCCDescendants.exists(isChanging) - } - } - - def duplicateClassDef(cd: ClassDef): ClassDef = { - cdMapCache.get(cd) match { - case Some(new_cd) => new_cd - case None => - val old_parent = cd.parent - val parent = old_parent.map(duplicateAbstractClassType) - val new_cd = (cd match { case cc:CaseClassDef => mappings.get(cc) case _ => None }) match { - case Some(Some(new_cd_if_parent)) => - new_cd_if_parent(parent) - case _ => - cd match { - case acd:AbstractClassDef => acd.duplicate(parent = parent) - case ccd:CaseClassDef => ccd.duplicate(parent = parent, fields = ccd.fieldsIds.map(id => ValDef(idMap(id)))) // Should not cycle since fields have to be abstract. - } - } - cdMapCache += cd -> new_cd - new_cd + def duplicateClassDef(cd: ClassDef): Unit = { + cdMapCache.get(cd) match { + case Some(new_cd) => + case None => + val parent = cd.parent.map(duplicateAbstractClassType) + val new_cd = cdMapF(cd).map(f => f(parent)).getOrElse{ + cd match { + case acd:AbstractClassDef => acd.duplicate(parent = parent) + case ccd:CaseClassDef => ccd.duplicate(parent = parent, fields = ccd.fieldsIds.map(id => ValDef(idMap(id)))) // Should not cycle since fields have to be abstract. } } - // TODO: Do not unnecessarily duplicate the argument types if they don't have to change. - def duplicateAbstractClassType(act: AbstractClassType): AbstractClassType = { - TypeOps.postMap{ - case AbstractClassType(acd, tps) => Some(AbstractClassType(duplicateClassDef(acd).asInstanceOf[AbstractClassDef], tps)) - case CaseClassType(ccd, tps) => Some(CaseClassType(duplicateClassDef(ccd).asInstanceOf[CaseClassDef], tps)) - case _ => None - }(act).asInstanceOf[AbstractClassType] - } - for((c, funMod) <- mappings) { - // If the dependencies of c contain a class to be transformed, duplicate this class and its hierarchy - if(funMod.nonEmpty || collectDependencies(c).exists(isChanging)) { - cdMapCache += c -> duplicateClassDef(c) - } else { - cdStatic += c - } + cdMapCache += cd -> Some(new_cd) + } + } + + def duplicateAbstractClassType(act: AbstractClassType): AbstractClassType = { + TypeOps.postMap{ + case AbstractClassType(acd, tps) => Some(AbstractClassType(duplicateClassDef(acd).asInstanceOf[AbstractClassDef], tps)) + case CaseClassType(ccd, tps) => Some(CaseClassType(duplicateClassDef(ccd).asInstanceOf[CaseClassDef], tps)) + case _ => None + }(act).asInstanceOf[AbstractClassType] + } + + // If at least one descendants or known case class needs conversion, then all the hierarchy will be converted. + // If something extends List[A] and A is modified, then the first something should be modified. + def dependencies(s: ClassDef): Set[ClassDef] = { + Set(s) ++ s.parent.toList.flatMap(p => TypeOps.collect[ClassDef]{ + case AbstractClassType(acd, _) => Set(acd:ClassDef) ++ acd.knownCCDescendants + case CaseClassType(ccd, _) => Set(ccd:ClassDef) + }(p)) + } + + def cdMap(cd: ClassDef): ClassDef = { + cdMapCache.get(cd) match { + case Some(Some(new_cd)) => new_cd + case Some(None) => cd + case None => + if(cdMapF(cd).isDefined || dependencies(cd).exists(cd => cdMapF(cd).isDefined)) { // Needs replacement in any case. + duplicateClassDef(cd) + } else { + cdMapCache += cd -> None } - cdMapCache.getOrElse(cd, cd) - } - } else { - cdMapCache(cd) + cdMapCache(cd).getOrElse(cd) } } def idMap(id: Identifier): Identifier = { @@ -483,11 +464,97 @@ object DefOps { } idMapCache(id) } + + def idHasToChange(id: Identifier): Boolean = { + typeHasToChange(id.getType) + } + + def typeHasToChange(tp: TypeTree): Boolean = { + TypeOps.exists{ + case AbstractClassType(acd, _) => cdMap(acd) != acd + case CaseClassType(ccd, _) => cdMap(ccd) != ccd + }(tp) + } + + def patternHasToChange(p: Pattern): Boolean = { + PatternOps.exists { + case CaseClassPattern(optId, cct, sub) => optId.exists(idHasToChange) || typeHasToChange(cct) + case InstanceOfPattern(optId, cct) => optId.exists(idHasToChange) || typeHasToChange(cct) + case Extractors.Pattern(optId, subp, builder) => optId.exists(idHasToChange) + case e => false + } (p) + } + + def exprHasToChange(e: Expr): Boolean = { + ExprOps.exists{ + case Let(id, expr, body) => idHasToChange(id) + case Variable(id) => idHasToChange(id) + case ci @ CaseClass(cct, args) => typeHasToChange(cct) + case CaseClassSelector(cct, expr, identifier) => typeHasToChange(cct) || idHasToChange(identifier) + case IsInstanceOf(e, cct) => typeHasToChange(cct) + case AsInstanceOf(e, cct) => typeHasToChange(cct) + case MatchExpr(scrut, cases) => + cases.exists{ + case MatchCase(pattern, optGuard, rhs) => + patternHasToChange(pattern) + } + case fi @ FunctionInvocation(TypedFunDef(fd, tps), args) => + tps.exists(typeHasToChange) + case _ => + false + }(e) + } + + def funDefHasToChange(fd: FunDef): Boolean = { + exprHasToChange(fd.fullBody) || fd.params.exists(vid => typeHasToChange(vid.id.getType)) || typeHasToChange(fd.returnType) + } + + def funHasToChange(fd: FunDef): Boolean = { + funDefHasToChange(fd) || p.callGraph.transitiveCallees(fd).exists(fd => + fdMapFCache.get(fd) match { + case Some(Some(_)) => true + case Some(None) => false + case None => funDefHasToChange(fd) + }) + } + + def fdMapFCached(fd: FunDef): Option[FunDef] = { + fdMapFCache.get(fd) match { + case Some(e) => e + case None => + val new_fd = if(funHasToChange(fd)) { + Some(fd.duplicate(params = fd.params.map(vd => ValDef(idMap(vd.id))), returnType = tpMap(fd.returnType))) + } else { + None + } + fdMapFCache += fd -> new_fd + new_fd + } + } + + def duplicateParents(fd: FunDef): Unit = { + fdMapCache.get(fd) match { + case None => + fdMapCache += fd -> fdMapFCached(fd).orElse(Some(fd.duplicate())) + for(fp <- p.callGraph.callers(fd)) { + duplicateParents(fp) + } + case _ => + } + } + def fdMap(fd: FunDef): FunDef = { - if (!(fdMapCache contains fd)) { - fdMapCache += fd -> fd.duplicate(params = fd.params.map(vd => ValDef(idMap(vd.id))), returnType = tpMap(fd.returnType)) + fdMapCache.get(fd) match { + case Some(Some(e)) => e + case Some(None) => fd + case None => + if(fdMapFCached(fd).isDefined || p.callGraph.transitiveCallees(fd).exists(fd => fdMapFCached(fd).isDefined)) { + duplicateParents(fd) + } else { // Verify that for all + fdMapCache += fd -> None + } + fdMapCache(fd).getOrElse(fd) } - fdMapCache(fd) } val newP = p.copy(units = for (u <- p.units) yield { @@ -505,28 +572,9 @@ object DefOps { } ) }) - object ToTransform { - def unapply(c: ClassType): Option[ClassDef] = Some(cdMap(c.classDef)) - } - trait Transformed[T <: TypeTree] { - def unapply(c: T): Option[T] = Some(TypeOps.postMap({ - case c: ClassType => - val newClassDef = cdMap(c.classDef) - Some((c match { - case CaseClassType(ccd, tps) => - CaseClassType(newClassDef.asInstanceOf[CaseClassDef], tps.map(e => TypeOps.postMap{ case TypeTransformed(ct) => Some(ct) case _ => None }(e))) - case AbstractClassType(acd, tps) => - AbstractClassType(newClassDef.asInstanceOf[AbstractClassDef], tps.map(e => TypeOps.postMap{ case TypeTransformed(ct) => Some(ct) case _ => None }(e))) - }).asInstanceOf[T]) - case _ => None - })(c).asInstanceOf[T]) - } - object CaseClassTransformed extends Transformed[CaseClassType] - object ClassTransformed extends Transformed[ClassType] - object TypeTransformed extends Transformed[TypeTree] def replaceClassDefUse(e: Pattern): Pattern = PatternOps.postMap{ - case CaseClassPattern(optId, CaseClassTransformed(ct), sub) => Some(CaseClassPattern(optId.map(idMap), ct, sub)) - case InstanceOfPattern(optId, ClassTransformed(ct)) => Some(InstanceOfPattern(optId.map(idMap), ct)) + case CaseClassPattern(optId, cct, sub) => Some(CaseClassPattern(optId.map(idMap), tpMap[CaseClassType](cct), sub)) + case InstanceOfPattern(optId, cct) => Some(InstanceOfPattern(optId.map(idMap), tpMap[ClassType](cct))) case UnapplyPattern(optId, TypedFunDef(fd, tps), subp) => Some(UnapplyPattern(optId.map(idMap), TypedFunDef(fdMap(fd), tps.map(tpMap)), subp)) case Extractors.Pattern(Some(id), subp, builder) => Some(builder(Some(idMap(id)), subp)) case e => None @@ -536,13 +584,12 @@ object DefOps { ExprOps.postMap { case Let(id, expr, body) => Some(Let(idMap(id), expr, body)) case Variable(id) => Some(Variable(idMap(id))) - case ci @ CaseClass(CaseClassTransformed(ct), args) => - ciMapF(ci, ct).map(_.setPos(ci)) - //case IsInstanceOf(e, ToTransform()) => - case CaseClassSelector(CaseClassTransformed(cct), expr, identifier) => - Some(CaseClassSelector(cct, expr, idMap(identifier))) - case IsInstanceOf(e, ClassTransformed(ct)) => Some(IsInstanceOf(e, ct)) - case AsInstanceOf(e, ClassTransformed(ct)) => Some(AsInstanceOf(e, ct)) + case ci @ CaseClass(ct, args) => + ciMapF(ci, tpMap(ct)).map(_.setPos(ci)) + case CaseClassSelector(cct, expr, identifier) => + Some(CaseClassSelector(tpMap(cct), expr, idMap(identifier))) + case IsInstanceOf(e, ct) => Some(IsInstanceOf(e, tpMap(ct))) + case AsInstanceOf(e, ct) => Some(AsInstanceOf(e, tpMap(ct))) case MatchExpr(scrut, cases) => Some(MatchExpr(scrut, cases.map{ case MatchCase(pattern, optGuard, rhs) => @@ -558,7 +605,10 @@ object DefOps { for(fd <- newP.definedFunctions) { fd.fullBody = replaceClassDefsUse(fd.fullBody) } - (newP, cdMapCache, idMapCache, fdMapCache) + (newP, + cdMapCache.collect{case (cd, Some(new_cd)) => cd -> new_cd}, + idMapCache, + fdMapCache.collect{case (cd, Some(new_cd)) => cd -> new_cd }) }