diff --git a/src/main/scala/leon/purescala/DefOps.scala b/src/main/scala/leon/purescala/DefOps.scala index 541053a7751badc090ac06eb73bfcce27bfdcadb..f53e9c1e77b931351f834f810ff412a07189f525 100644 --- a/src/main/scala/leon/purescala/DefOps.scala +++ b/src/main/scala/leon/purescala/DefOps.scala @@ -296,28 +296,27 @@ object DefOps { if(fdStatic(fd)) fd else { // We have to duplicate all calling functions except those not required to change and whose transitive descendants are all not required to change. // If one descendant is required to change, then all its transitive callers have to change including the function itself. + // The RHS of mappings is true if there is a change to propagate for this function. val mappings = for(callee <- (fd::p.callGraph.transitiveCallees(fd).toList).distinct) yield { - callee -> fdMapCache.get(callee).orElse(fdMapF(callee)) + callee -> (fdMapCache.get(callee) match { + case Some(newFd) => None + case None => fdMapF(callee) + }) } - if(mappings.exists(kv => kv._2 != None && kv._2 != Some(kv._1))) { - for(m <- mappings) { - m match { - case (f, Some(newFd)) => - fdMapCache += f -> newFd - for(caller <- p.callGraph.transitiveCallers(f)) { - if(!(fdMapCache contains caller)) { - fdMapCache += f -> f.duplicate() - } + for(m <- mappings) { + m match { + case (f, Some(newFd)) => + fdMapCache += f -> newFd + for(caller <- p.callGraph.transitiveCallers(f)) { + if(!(fdMapCache contains caller)) { + fdMapCache += f -> f.duplicate() } - case (f, None) => - } - } - if(fdMapCache contains fd) { - fdMapCache(fd) - } else { - fdStatic += fd - fd + } + case (f, _) => } + } + if(fdMapCache contains fd) { + fdMapCache(fd) } else { fdStatic += fd fd @@ -384,15 +383,15 @@ object DefOps { /** Clones the given program by replacing some classes by other classes. * * @param p The original program - * @param cdMapF Given c and its cloned parent, returns Some(d) if c should be replaced by d, and None if c should be kept. - * Will always start to call this method for the topmost parents, and then descending. + * @param cdMapF Given c returns Some(d) where d can take an abstract parent and return a class e if c should be replaced by e, and None if c should be kept. * @param fiMapF Given a previous case class invocation and its new case class definition, returns the expression to use. * By default it is the case class construction using the new case class definition. * @return the new program with a map from the old case classes to the new case classes */ - def replaceClassDefs(p: Program)(cdMapF: (ClassDef, Option[AbstractClassType]) => Option[ClassDef], - ciMapF: (CaseClass, CaseClassType) => Option[Expr] = defaultCdMap) - : (Program, Map[ClassDef, ClassDef], Map[Identifier, Identifier], Map[FunDef, FunDef]) = { + def replaceCaseClassDefs(p: Program)(cdMapF: CaseClassDef => Option[Option[AbstractClassType] => CaseClassDef], + ciMapF: (CaseClass, CaseClassType) => Option[Expr] = defaultCdMap) + : (Program, Map[ClassDef, ClassDef], Map[Identifier, Identifier], Map[FunDef, FunDef]) = { var cdMapCache = Map[ClassDef, ClassDef]() + var cdStatic = Set[ClassDef]() var idMapCache = Map[Identifier, Identifier]() var fdMapCache = Map[FunDef, FunDef]() def tpMap(tt: TypeTree): TypeTree = TypeOps.postMap{ @@ -403,22 +402,80 @@ object DefOps { def cdMap(cd: ClassDef): ClassDef = { if (!(cdMapCache contains cd)) { - lazy val parent = cd.parent.map( tpMap(_).asInstanceOf[AbstractClassType] ) - val ncd = cdMapF(cd, parent) match { - case Some(new_ccd) => - for((old_id, new_id) <- cd.fieldsIds.zip(new_ccd.fieldsIds)) { - idMapCache += old_id -> new_id + if(cdStatic(cd)) cd else { + // If at least one descendants or known case class needs conversion, then all the hierarchy will be converted. + // If something extends List[A] and A is modified, then the first something should be modified. + def dependencies(s: ClassDef): Set[ClassDef] = { + Set(s) ++ s.parent.toList.flatMap(p => TypeOps.collect[ClassDef]{ + case AbstractClassType(acd, _) => Set(acd:ClassDef) ++ acd.knownCCDescendants + case CaseClassType(ccd, _) => Set(ccd:ClassDef) + }(p)) + } + + def hierarchy(s: ClassDef): Seq[ClassDef] = s::s.ancestors.toList + + def collectDependencies(c: ClassDef) = leon.utils.fixpoint((s: Set[ClassDef]) => s.flatMap(dependencies))(Set(c)) + + val collectedDependencies = collectDependencies(cd) + + val mappings = (for(callee <- collectedDependencies.collect{ case c: CaseClassDef => c}) yield { + callee -> (cdMapCache.get(callee) match { + case Some(newcd) => None + case None => cdMapF(callee) + }) + }).toMap + + def isChanging(c: ClassDef): Boolean = { + c match { + case ccd: CaseClassDef => + mappings.get(ccd) match { + case Some(Some(_)) => true + case _ => false } + case acd: AbstractClassDef => + acd.knownCCDescendants.exists(isChanging) } - new_ccd - case None => - cd match { - case acd:AbstractClassDef => acd.duplicate(parent = parent) - case ccd:CaseClassDef => ccd.duplicate(parent = parent, fields = ccd.fieldsIds.map(id => ValDef(idMap(id)))) // Should not cycle since fields have to be abstract. + } + + def duplicateClassDef(cd: ClassDef): ClassDef = { + cdMapCache.get(cd) match { + case Some(new_cd) => new_cd + case None => + val old_parent = cd.parent + val parent = old_parent.map(duplicateAbstractClassType) + val new_cd = (cd match { case cc:CaseClassDef => mappings.get(cc) case _ => None }) match { + case Some(Some(new_cd_if_parent)) => + new_cd_if_parent(parent) + case _ => + cd match { + case acd:AbstractClassDef => acd.duplicate(parent = parent) + case ccd:CaseClassDef => ccd.duplicate(parent = parent, fields = ccd.fieldsIds.map(id => ValDef(idMap(id)))) // Should not cycle since fields have to be abstract. + } + } + cdMapCache += cd -> new_cd + new_cd + } + } + // TODO: Do not unnecessarily duplicate the argument types if they don't have to change. + def duplicateAbstractClassType(act: AbstractClassType): AbstractClassType = { + TypeOps.postMap{ + case AbstractClassType(acd, tps) => Some(AbstractClassType(duplicateClassDef(acd).asInstanceOf[AbstractClassDef], tps)) + case CaseClassType(ccd, tps) => Some(CaseClassType(duplicateClassDef(ccd).asInstanceOf[CaseClassDef], tps)) + case _ => None + }(act).asInstanceOf[AbstractClassType] + } + for((c, funMod) <- mappings) { + // If the dependencies of c contain a class to be transformed, duplicate this class and its hierarchy + if(funMod.nonEmpty || collectDependencies(c).exists(isChanging)) { + cdMapCache += c -> duplicateClassDef(c) + } else { + cdStatic += c } + } + cdMapCache.getOrElse(cd, cd) } - cdMapCache += cd -> ncd + } else { + cdMapCache(cd) } - cdMapCache(cd) } def idMap(id: Identifier): Identifier = { if (!(idMapCache contains id)) { diff --git a/src/main/scala/leon/solvers/combinators/Z3StringCapableSolver.scala b/src/main/scala/leon/solvers/combinators/Z3StringCapableSolver.scala index 2807c2394c0b1540094ddeffd0bccd2071be9b3f..b94233f285e1fe63c486dd2711c63e115ef1dd7b 100644 --- a/src/main/scala/leon/solvers/combinators/Z3StringCapableSolver.scala +++ b/src/main/scala/leon/solvers/combinators/Z3StringCapableSolver.scala @@ -45,12 +45,12 @@ object Z3StringCapableSolver { var hasStrings = false val program_with_strings = converter.getProgram val (program_with_correct_classes, cdMap, idMap, fdMap) = if(program_with_strings.definedClasses.exists{ case c: CaseClassDef => c.fieldsIds.exists(id => TypeOps.exists{ _ == StringType}(id.getType)) case _ => false}) { - val res:(Program, Map[ClassDef, ClassDef], Map[Identifier, Identifier], Map[FunDef, FunDef]) = DefOps.replaceClassDefs(program_with_strings)((cd: ClassDef, parent: Option[AbstractClassType]) => { + val res:(Program, Map[ClassDef, ClassDef], Map[Identifier, Identifier], Map[FunDef, FunDef]) = DefOps.replaceCaseClassDefs(program_with_strings)((cd: ClassDef) => { cd match { case acd:AbstractClassDef => None case ccd:CaseClassDef => if(ccd.fieldsIds.exists(id => TypeOps.exists(StringType == _)(id.getType))) { - Some(ccd.duplicate(convertId(ccd.id), ccd.tparams, ccd.fieldsIds.map(id => ValDef(convertId(id))), parent, ccd.isCaseObject)) + Some((parent: Option[AbstractClassType]) => ccd.duplicate(convertId(ccd.id), ccd.tparams, ccd.fieldsIds.map(id => ValDef(convertId(id))), parent, ccd.isCaseObject)) } else None } })