diff --git a/src/main/scala/leon/purescala/DefOps.scala b/src/main/scala/leon/purescala/DefOps.scala index 7cc576adbd6dac0a277d40ade8c7dc06a4836602..541053a7751badc090ac06eb73bfcce27bfdcadb 100644 --- a/src/main/scala/leon/purescala/DefOps.scala +++ b/src/main/scala/leon/purescala/DefOps.scala @@ -290,15 +290,42 @@ object DefOps { : (Program, Map[FunDef, FunDef])= { var fdMapCache = Map[FunDef, FunDef]() + var fdStatic = Set[FunDef]() def fdMap(fd: FunDef): FunDef = { if (!(fdMapCache contains fd)) { - fdMapCache += fd -> fdMapF(fd).getOrElse(fd.duplicate()) - } - - fdMapCache(fd) + if(fdStatic(fd)) fd else { + // We have to duplicate all calling functions except those not required to change and whose transitive descendants are all not required to change. + // If one descendant is required to change, then all its transitive callers have to change including the function itself. + val mappings = for(callee <- (fd::p.callGraph.transitiveCallees(fd).toList).distinct) yield { + callee -> fdMapCache.get(callee).orElse(fdMapF(callee)) + } + if(mappings.exists(kv => kv._2 != None && kv._2 != Some(kv._1))) { + for(m <- mappings) { + m match { + case (f, Some(newFd)) => + fdMapCache += f -> newFd + for(caller <- p.callGraph.transitiveCallers(f)) { + if(!(fdMapCache contains caller)) { + fdMapCache += f -> f.duplicate() + } + } + case (f, None) => + } + } + if(fdMapCache contains fd) { + fdMapCache(fd) + } else { + fdStatic += fd + fd + } + } else { + fdStatic += fd + fd + } + } + } else fdMapCache(fd) } - val newP = p.copy(units = for (u <- p.units) yield { u.copy( defs = u.defs.map { @@ -313,17 +340,28 @@ object DefOps { } ) }) - // TODO: Check for function invocations in unapply patterns. + for(fd <- newP.definedFunctions) { - if(ExprOps.exists{ case FunctionInvocation(TypedFunDef(fd, targs), fargs) => fdMapCache contains fd case _ => false }(fd.fullBody)) { - fd.fullBody = replaceFunCalls(fd.fullBody, fdMapCache, fiMapF) + if(ExprOps.exists{ + case FunctionInvocation(TypedFunDef(fd, targs), fargs) => fdMapCache contains fd + case MatchExpr(_, cases) => cases.exists(c => PatternOps.exists{ + case UnapplyPattern(optId, TypedFunDef(fd, tps), subp) => fdMapCache contains fd + case _ => false + }(c.pattern)) + case _ => false + }(fd.fullBody)) { + fd.fullBody = replaceFunCalls(fd.fullBody, fdMapCache.withDefault { x => x }, fiMapF) } } (newP, fdMapCache) } - def replaceFunCalls(e: Expr, fdMapF: FunDef => FunDef, fiMapF: (FunctionInvocation, FunDef) => Option[Expr] = defaultFiMap) = { + def replaceFunCalls(e: Expr, fdMapF: FunDef => FunDef, fiMapF: (FunctionInvocation, FunDef) => Option[Expr] = defaultFiMap): Expr = { preMap { + case MatchExpr(scrut, cases) => + Some(MatchExpr(scrut, cases.map(matchcase => matchcase match { + case MatchCase(pattern, guard, rhs) => MatchCase(replaceFunCalls(pattern, fdMapF), guard, rhs) + }))) case fi @ FunctionInvocation(TypedFunDef(fd, tps), args) => fiMapF(fi, fdMapF(fd)).map(_.setPos(fi)) case _ => @@ -331,6 +369,10 @@ object DefOps { }(e) } + def replaceFunCalls(p: Pattern, fdMapF: FunDef => FunDef): Pattern = PatternOps.preMap{ + case UnapplyPattern(optId, TypedFunDef(fd, tps), subp) => Some(UnapplyPattern(optId, TypedFunDef(fdMapF(fd), tps), subp)) + case _ => None + }(p) private def defaultCdMap(cc: CaseClass, ccd: CaseClassType): Option[Expr] = (cc, ccd) match { case (CaseClass(old, args), newCcd) if old.classDef != newCcd => @@ -488,7 +530,12 @@ object DefOps { ) }) if (!found) { - println("addDefs could not find anchor definition!") + println(s"addDefs could not find anchor definition! Not found: $after") + p.definedFunctions.filter(f => f.id.name == after.id.name).map(fd => fd.id.name + " : " + fd) match { + case Nil => + case e => println("Did you mean " + e) + } + println(Thread.currentThread().getStackTrace.map(_.toString).take(10).mkString("\n")) } res }