From e6f9e4b490129dde96265db72dadb038c1a58b2c Mon Sep 17 00:00:00 2001 From: Nicolas Voirol <voirol.nicolas@gmail.com> Date: Thu, 24 Mar 2016 14:41:20 +0100 Subject: [PATCH] ADT Invariants in program transformations --- .../scala/leon/codegen/CodeGeneration.scala | 35 +++- .../scala/leon/codegen/CompilationUnit.scala | 13 +- .../leon/evaluators/RecursiveEvaluator.scala | 14 +- src/main/scala/leon/purescala/Common.scala | 8 +- src/main/scala/leon/purescala/DefOps.scala | 162 +++++++++-------- .../purescala/DefinitionTransformer.scala | 104 +++++++++++ .../leon/purescala/DependencyFinder.scala | 84 +++++++++ .../scala/leon/purescala/Expressions.scala | 2 +- .../scala/leon/purescala/MethodLifting.scala | 12 +- .../scala/leon/purescala/RestoreMethods.scala | 2 +- .../leon/purescala/TreeTransformer.scala | 23 ++- .../leon/solvers/theories/TheoryEncoder.scala | 171 +----------------- .../solvers/unrolling/UnrollingBank.scala | 46 +++-- .../solvers/unrolling/UnrollingSolver.scala | 10 +- .../scala/leon/synthesis/Synthesizer.scala | 2 +- .../leon/synthesis/rules/CEGISLike.scala | 11 +- .../leon/termination/TerminationChecker.scala | 2 +- .../solvers/StringRenderSuite.scala | 2 +- 18 files changed, 399 insertions(+), 304 deletions(-) create mode 100644 src/main/scala/leon/purescala/DefinitionTransformer.scala create mode 100644 src/main/scala/leon/purescala/DependencyFinder.scala diff --git a/src/main/scala/leon/codegen/CodeGeneration.scala b/src/main/scala/leon/codegen/CodeGeneration.scala index 0e2a3fe23..866c3ce09 100644 --- a/src/main/scala/leon/codegen/CodeGeneration.scala +++ b/src/main/scala/leon/codegen/CodeGeneration.scala @@ -92,7 +92,8 @@ trait CodeGeneration { def idToSafeJVMName(id: Identifier) = { scala.reflect.NameTransformer.encode(id.uniqueName).replaceAll("\\.", "\\$") } - def defToJVMName(d : Definition) : String = "Leon$CodeGen$" + idToSafeJVMName(d.id) + + def defToJVMName(d: Definition): String = "Leon$CodeGen$" + idToSafeJVMName(d.id) /** Retrieve the name of the underlying lazy field from a lazy field accessor method */ private[codegen] def underlyingField(lazyAccessor : String) = lazyAccessor + "$underlying" @@ -201,7 +202,7 @@ trait CodeGeneration { funDef.fullBody } else { funDef.body.getOrElse( - if(funDef.annotations contains "extern") { + if (funDef.annotations contains "extern") { Error(funDef.id.getType, "Body of " + funDef.id.name + " not implemented at compile-time and still executed.") } else { throw CompilationException("Can't compile a FunDef without body: "+funDef.id.name) @@ -489,8 +490,9 @@ trait CodeGeneration { } ch << New(ccName) << DUP load(monitorID, ch) + loadTypes(cct.tps, ch) - for((a, vd) <- as zip cct.classDef.fields) { + for ((a, vd) <- as zip cct.classDef.fields) { vd.getType match { case TypeParameter(_) => mkBoxedExpr(a, ch) @@ -668,6 +670,8 @@ trait CodeGeneration { load(monitorID, ch) ch << DUP_X2 << POP + loadTypes(Seq(tp), ch) + ch << DUP_X2 << POP ch << InvokeSpecial(consName, constructorName, ccApplySig) // list, it, newList @@ -1529,7 +1533,7 @@ trait CodeGeneration { } } - def compileAbstractClassDef(acd : AbstractClassDef) { + def compileAbstractClassDef(acd: AbstractClassDef) { val cName = defToJVMName(acd) @@ -1642,7 +1646,6 @@ trait CodeGeneration { } def compileCaseClassDef(ccd: CaseClassDef) { - val cName = defToJVMName(ccd) val pName = ccd.parent.map(parent => defToJVMName(parent.classDef)) // An instantiation of ccd with its own type parameters @@ -1656,13 +1659,14 @@ trait CodeGeneration { CLASS_ACC_FINAL ).asInstanceOf[U2]) - if(ccd.parent.isEmpty) { + if (ccd.parent.isEmpty) { cf.addInterface(CaseClassClass) } // Case class parameters val fieldsTypes = ccd.fields.map { vd => (vd.id, typeToJVM(vd.getType)) } - val constructorArgs = (monitorID -> s"L$MonitorClass;") +: fieldsTypes + val tpeParam = if (ccd.tparams.isEmpty) Seq() else Seq(tpsID -> "[I") + val constructorArgs = (monitorID -> s"L$MonitorClass;") +: (tpeParam ++ fieldsTypes) val newLocs = NoLocals.withFields(constructorArgs.map { case (id, jvmt) => (id, (cName, id.name, jvmt)) @@ -1674,7 +1678,7 @@ trait CodeGeneration { // Compile methods for (method <- methods) { - compileFunDef(method,ccd) + compileFunDef(method, ccd) } // Compile lazy fields @@ -1688,7 +1692,7 @@ trait CodeGeneration { } // definition of the constructor - for((id, jvmt) <- constructorArgs) { + for ((id, jvmt) <- constructorArgs) { val fh = cf.addField(jvmt, id.name) fh.setFlags(( FIELD_ACC_PUBLIC | @@ -1710,7 +1714,7 @@ trait CodeGeneration { } var c = 1 - for((id, jvmt) <- constructorArgs) { + for ((id, jvmt) <- constructorArgs) { cch << ALoad(0) cch << (jvmt match { case "I" | "Z" => ILoad(c) @@ -1734,6 +1738,17 @@ trait CodeGeneration { // Now initialize fields for (lzy <- lazyFields) { initLazyField(cch, cName, lzy, isStatic = false)(newLocs) } for (field <- strictFields) { initStrictField(cch, cName , field, isStatic = false)(newLocs) } + + // Finally check invariant (if it exists) + if (params.checkContracts && ccd.hasInvariant) { + val thisId = FreshIdentifier("this", cct, true) + val invLocals = newLocs.withVar(thisId -> 0) + mkExpr(IfExpr(FunctionInvocation(cct.invariant.get, Seq(Variable(thisId))), + BooleanLiteral(true), + Error(BooleanType, "ADT Invariant failed @" + ccd.invariant.get.getPos)), cch)(invLocals) + cch << POP + } + cch << RETURN cch.freeze } diff --git a/src/main/scala/leon/codegen/CompilationUnit.scala b/src/main/scala/leon/codegen/CompilationUnit.scala index 1bc3600fb..cd6bad57e 100644 --- a/src/main/scala/leon/codegen/CompilationUnit.scala +++ b/src/main/scala/leon/codegen/CompilationUnit.scala @@ -80,7 +80,7 @@ class CompilationUnit(val ctx: LeonContext, id } - def defineClass(df: Definition) { + def defineClass(df: Definition): Unit = { val cName = defToJVMName(df) val cf = df match { @@ -105,7 +105,8 @@ class CompilationUnit(val ctx: LeonContext, def leonClassToJVMInfo(cd: ClassDef): Option[(String, String)] = { classes.get(cd) match { case Some(cf) => - val sig = "(L"+MonitorClass+";" + cd.fields.map(f => typeToJVM(f.getType)).mkString("") + ")V" + val tpeParam = if (cd.tparams.isEmpty) "" else "[I" + val sig = "(L"+MonitorClass+";" + tpeParam + cd.fields.map(f => typeToJVM(f.getType)).mkString("") + ")V" Some((cf.className, sig)) case _ => None } @@ -114,7 +115,6 @@ class CompilationUnit(val ctx: LeonContext, // Returns className, methodName, methodSignature private[this] var funDefInfo = Map[FunDef, (String, String, String)]() - /** * Returns (cn, mn, sig) where * cn is the module name @@ -213,7 +213,8 @@ class CompilationUnit(val ctx: LeonContext, case CaseClass(cct, args) => caseClassConstructor(cct.classDef) match { case Some(cons) => - val jvmArgs = monitor +: args.map(valueToJVM) + val tpeParam = if (cct.tps.isEmpty) Seq() else Seq(cct.tps.map(registerType).toArray) + val jvmArgs = monitor +: (tpeParam ++ args.map(valueToJVM)) cons.newInstance(jvmArgs.toArray : _*).asInstanceOf[AnyRef] case None => ctx.reporter.fatalError("Case class constructor not found?!?") @@ -259,10 +260,8 @@ class CompilationUnit(val ctx: LeonContext, val lc = loader.loadClass(afName) val conss = lc.getConstructors.sortBy(_.getParameterTypes.length) - println(conss) assert(conss.nonEmpty) val lambdaConstructor = conss.last - println(args.toArray) lambdaConstructor.newInstance(args.toArray : _*).asInstanceOf[AnyRef] case f @ IsTyped(FiniteArray(elems, default, IntLiteral(length)), ArrayType(underlying)) => @@ -541,7 +540,7 @@ class CompilationUnit(val ctx: LeonContext, for (m <- u.modules) { defineClass(m) - for(funDef <- m.definedFunctions) { + for (funDef <- m.definedFunctions) { defToModuleOrClass += funDef -> m } } diff --git a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala index 5b8b9ccae..418ac42e6 100644 --- a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala +++ b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala @@ -191,8 +191,18 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int case _ => BooleanLiteral(lv == rv) } - case CaseClass(cd, args) => - CaseClass(cd, args.map(e)) + case CaseClass(cct, args) => + val cc = CaseClass(cct, args.map(e)) + if (cct.classDef.hasInvariant) { + e(FunctionInvocation(cct.invariant.get, Seq(cc))) match { + case BooleanLiteral(true) => + case BooleanLiteral(false) => + throw RuntimeError("ADT invariant violation for " + cct.classDef.id.asString + " reached in evaluation.: " + cct.invariant.get.asString) + case other => + throw RuntimeError(typeErrorMsg(other, BooleanType)) + } + } + cc case AsInstanceOf(expr, ct) => val le = e(expr) diff --git a/src/main/scala/leon/purescala/Common.scala b/src/main/scala/leon/purescala/Common.scala index 848ded806..c3af4f200 100644 --- a/src/main/scala/leon/purescala/Common.scala +++ b/src/main/scala/leon/purescala/Common.scala @@ -95,8 +95,10 @@ object Common { * @param tpe The type of the identifier * @param alwaysShowUniqueID If the unique ID should always be shown */ - def apply(name: String, tpe: TypeTree = Untyped, alwaysShowUniqueID: Boolean = false) : Identifier = - new Identifier(decode(name), uniqueCounter.nextGlobal, uniqueCounter.next(name), tpe, alwaysShowUniqueID) + def apply(name: String, tpe: TypeTree = Untyped, alwaysShowUniqueID: Boolean = false) : Identifier = { + val decoded = decode(name) + new Identifier(decoded, uniqueCounter.nextGlobal, uniqueCounter.next(decoded), tpe, alwaysShowUniqueID) + } /** Builds a fresh identifier, whose ID is always shown * @@ -104,7 +106,7 @@ object Common { * @param forceId The forced ID of the identifier * @param tpe The type of the identifier */ - def apply(name: String, forceId: Int, tpe: TypeTree): Identifier = + def apply(name: String, forceId: Int, tpe: TypeTree): Identifier = new Identifier(decode(name), uniqueCounter.nextGlobal, forceId, tpe, alwaysShowUniqueID = true) } diff --git a/src/main/scala/leon/purescala/DefOps.scala b/src/main/scala/leon/purescala/DefOps.scala index 5363066f6..7afb63104 100644 --- a/src/main/scala/leon/purescala/DefOps.scala +++ b/src/main/scala/leon/purescala/DefOps.scala @@ -9,6 +9,8 @@ import ExprOps.{preMap, functionCallsOf} import leon.purescala.Types.AbstractClassType import leon.purescala.Types._ +import scala.collection.mutable.{Map => MutableMap} + object DefOps { private def packageOf(df: Definition)(implicit pgm: Program): PackageRef = { @@ -291,95 +293,114 @@ object DefOps { None } - /** Clones the given program by replacing some functions by other functions. - * - * @param p The original program - * @param fdMapF Given f, returns Some(g) if f should be replaced by g, and None if f should be kept. - * @param fiMapF Given a previous function invocation and its new function definition, returns the expression to use. - * By default it is the function invocation using the new function definition. - * @return the new program with a map from the old functions to the new functions */ - def replaceFunDefs(p: Program)(fdMapF: FunDef => Option[FunDef], - fiMapF: (FunctionInvocation, FunDef) => Option[Expr] = defaultFiMap) - : (Program, Map[FunDef, FunDef]) = { + def replaceDefs(p: Program)(fdMapF: FunDef => Option[FunDef], + cdMapF: ClassDef => Option[ClassDef], + fiMapF: (FunctionInvocation, FunDef) => Option[Expr] = defaultFiMap, + ciMapF: (CaseClass, CaseClassType) => Option[Expr] = defaultCdMap) + : (Program, Map[Identifier, Identifier], Map[FunDef, FunDef], Map[ClassDef, ClassDef]) = { + + val idMap: MutableMap[Identifier, Identifier] = MutableMap.empty + val cdMap: MutableMap[ClassDef , ClassDef ] = MutableMap.empty + val fdMap: MutableMap[FunDef , FunDef ] = MutableMap.empty + + val dependencies = new DependencyFinder + val transformer = new TreeTransformer { + override def transform(id: Identifier): Identifier = idMap.getOrElse(id, { + val ntpe = transform(id.getType) + val nid = if (ntpe == id.getType) id else id.duplicate(tpe = ntpe) + idMap += id -> nid + nid + }) - var fdMapFCache = Map[FunDef, Option[FunDef]]() // Original fdMapF cache - var fdMapCache = Map[FunDef, Option[FunDef]]() // Final replacement. + override def transform(cd: ClassDef): ClassDef = cdMap.getOrElse(cd, cd) + override def transform(fd: FunDef): FunDef = fdMap.getOrElse(fd, fd) - def fdMapFCached(fd: FunDef): Option[FunDef] = { - fdMapFCache.get(fd) match { - case Some(e) => e - case None => - val new_fd = fdMapF(fd) - fdMapFCache += fd -> new_fd - new_fd + override def transform(expr: Expr)(implicit bindings: Map[Identifier, Identifier]): Expr = expr match { + case fi @ FunctionInvocation(TypedFunDef(fd, tps), args) => + val nfi = fiMapF(fi, transform(fd)) getOrElse expr + super.transform(nfi) + case cc @ CaseClass(cct, args) => + val ncc = ciMapF(cc, transform(cct).asInstanceOf[CaseClassType]) getOrElse expr + super.transform(ncc) + case _ => super.transform(expr) } } - def duplicateParents(fd: FunDef): Unit = { - fdMapCache.get(fd) match { - case None => - fdMapCache += fd -> fdMapFCached(fd).orElse(Some(fd.duplicate())) - for(fp <- p.callGraph.callers(fd)) { - duplicateParents(fp) - } - case _ => - } - } + for (fd <- p.definedFunctions; nfd <- fdMapF(fd)) fdMap += fd -> nfd + for (cd <- p.definedClasses; ncd <- cdMapF(cd)) cdMap += cd -> ncd - def fdMap(fd: FunDef): FunDef = { - fdMapCache.get(fd) match { - case Some(Some(e)) => e - case Some(None) => fd - case None => - if(fdMapFCached(fd).isDefined || p.callGraph.transitiveCallees(fd).exists(fd => fdMapFCached(fd).isDefined)) { - duplicateParents(fd) - } else { // Verify that for all - fdMapCache += fd -> None - } - fdMapCache(fd).getOrElse(fd) - } + def requiresReplacement(d: Definition): Boolean = dependencies(d).exists { + case cd: ClassDef => cdMap contains cd + case fd: FunDef => fdMap contains fd + case _ => false } - val newP = p.copy(units = for (u <- p.units) yield { - u.copy( - defs = u.defs.map { - case m : ModuleDef => - m.copy(defs = for (df <- m.defs) yield { - df match { - case f : FunDef => fdMap(f) - case d => d - } - }) - case d => d - } - ) + def trCd(cd: ClassDef): ClassDef = cdMap.getOrElse(cd, { + val parent = cd.parent.map(act => act.copy(classDef = trCd(act.classDef).asInstanceOf[AbstractClassDef])) + val newCd = cd match { + case acd: AbstractClassDef => acd.duplicate(parent = parent) + case ccd: CaseClassDef => ccd.duplicate(parent = parent) + } + cdMap += cd -> newCd + newCd }) - for (fd <- newP.definedFunctions) { - if (ExprOps.exists{ - case FunctionInvocation(TypedFunDef(fd, targs), fargs) => fdMapCache contains fd - case MatchExpr(_, cases) => cases.exists(c => PatternOps.exists{ - case UnapplyPattern(optId, TypedFunDef(fd, tps), subp) => fdMapCache contains fd - case _ => false - }(c.pattern)) - case _ => false - }(fd.fullBody)) { - fd.fullBody = replaceFunCalls(fd.fullBody, fdMap, fiMapF) - } + for (cd <- p.definedClasses if requiresReplacement(cd)) trCd(cd) + for (fd <- p.definedFunctions if requiresReplacement(fd)) { + val newId = transformer.transform(fd.id) + val newReturn = transformer.transform(fd.returnType) + val newParams = fd.params map (vd => ValDef(transformer.transform(vd.id))) + fdMap += fd -> fd.duplicate(id = newId, params = newParams, returnType = newReturn) + } + + for ((cd,ncd) <- cdMap) (cd, ncd) match { + case (ccd: CaseClassDef, nccd: CaseClassDef) => + nccd.setFields(ccd.fields map (vd => ValDef(transformer.transform(vd.id)))) + ccd.invariant.foreach(fd => nccd.setInvariant(transformer.transform(fd))) + case _ => } - for (cd <- newP.classHierarchyRoots) { - cd.invariant.foreach(inv => cd.setInvariant(fdMap(inv))) + for ((fd,nfd) <- fdMap) { + val bindings = (fd.params zip nfd.params).map(p => p._1.id -> p._2.id).toMap + nfd.fullBody = transformer.transform(fd.fullBody)(bindings) } - (newP, fdMapCache.collect{ case (ofd, Some(nfd)) => ofd -> nfd }) + val newP = p.copy(units = for (u <- p.units) yield { + u.copy(defs = u.defs.map { + case m : ModuleDef => + m.copy(defs = for (df <- m.defs) yield { + df match { + case cd : ClassDef => transformer.transform(cd) + case fd : FunDef => transformer.transform(fd) + case d => d + } + }) + case cd: ClassDef => transformer.transform(cd) + case d => d + }) + }) + + (newP, idMap.toMap, fdMap.toMap, cdMap.toMap) + } + + /** Clones the given program by replacing some functions by other functions. + * + * @param p The original program + * @param fdMapF Given f, returns Some(g) if f should be replaced by g, and None if f should be kept. + * @param fiMapF Given a previous function invocation and its new function definition, returns the expression to use. + * By default it is the function invocation using the new function definition. + * @return the new program with a map from the old functions to the new functions */ + def replaceFunDefs(p: Program)(fdMapF: FunDef => Option[FunDef], + fiMapF: (FunctionInvocation, FunDef) => Option[Expr] = defaultFiMap) + : (Program, Map[Identifier, Identifier], Map[FunDef, FunDef], Map[ClassDef, ClassDef]) = { + replaceDefs(p)(fdMapF, cd => None, fiMapF) } def replaceFunCalls(e: Expr, fdMapF: FunDef => FunDef, fiMapF: (FunctionInvocation, FunDef) => Option[Expr] = defaultFiMap): Expr = { preMap { - case me@MatchExpr(scrut, cases) => + case me @ MatchExpr(scrut, cases) => Some(MatchExpr(scrut, cases.map(matchcase => matchcase match { - case mc@MatchCase(pattern, guard, rhs) => MatchCase(replaceFunCalls(pattern, fdMapF), guard, rhs).copiedFrom(mc) + case mc @ MatchCase(pattern, guard, rhs) => MatchCase(replaceFunCalls(pattern, fdMapF), guard, rhs).copiedFrom(mc) })).copiedFrom(me)) case fi @ FunctionInvocation(TypedFunDef(fd, tps), args) => fiMapF(fi, fdMapF(fd)).map(_.copiedFrom(fi)) @@ -410,7 +431,6 @@ object DefOps { def replaceCaseClassDefs(p: Program)(cdMapFOriginal: CaseClassDef => Option[Option[AbstractClassType] => CaseClassDef], ciMapF: (CaseClass, CaseClassType) => Option[Expr] = defaultCdMap) : (Program, Map[ClassDef, ClassDef], Map[Identifier, Identifier], Map[FunDef, FunDef]) = { - var cdMapFCache = Map[CaseClassDef, Option[Option[AbstractClassType] => CaseClassDef]]() var cdMapCache = Map[ClassDef, Option[ClassDef]]() var idMapCache = Map[Identifier, Identifier]() diff --git a/src/main/scala/leon/purescala/DefinitionTransformer.scala b/src/main/scala/leon/purescala/DefinitionTransformer.scala new file mode 100644 index 000000000..89c5b613b --- /dev/null +++ b/src/main/scala/leon/purescala/DefinitionTransformer.scala @@ -0,0 +1,104 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package purescala + +import Common._ +import Definitions._ +import Expressions._ +import Extractors._ +import Types._ + +import utils._ +import scala.collection.mutable.{Set => MutableSet} + +class DefinitionTransformer( + idMap: Bijection[Identifier, Identifier] = new Bijection[Identifier, Identifier], + fdMap: Bijection[FunDef , FunDef ] = new Bijection[FunDef , FunDef ], + cdMap: Bijection[ClassDef , ClassDef ] = new Bijection[ClassDef , ClassDef ]) extends TreeTransformer { + + override def transform(id: Identifier): Identifier = idMap.cachedB(id) { + val ntpe = transform(id.getType) + if (ntpe == id.getType) id else id.duplicate(tpe = ntpe) + } + + override def transform(fd: FunDef): FunDef = fdMap.getBorElse(fd, if (tmpDefs(fd)) fd else { + transformDefs(fd) + fdMap.toB(fd) + }) + + override def transform(cd: ClassDef): ClassDef = cdMap.getBorElse(cd, if (tmpDefs(cd)) cd else { + transformDefs(cd) + cdMap.toB(cd) + }) + + private val dependencies = new DependencyFinder + private val tmpDefs: MutableSet[Definition] = MutableSet.empty + + private def transformDefs(base: Definition): Unit = { + val deps = dependencies(base) + val (cds, fds) = { + val (c, f) = deps.partition(_.isInstanceOf[ClassDef]) + (c.map(_.asInstanceOf[ClassDef]), f.map(_.asInstanceOf[FunDef])) + } + + tmpDefs ++= cds.filterNot(cdMap containsA _) ++ fds.filterNot(fdMap containsA _) + + var requireCache: Map[Definition, Boolean] = Map.empty + def required(d: Definition): Boolean = requireCache.getOrElse(d, { + val res = d match { + case fd: FunDef => + val newReturn = transform(fd.returnType) + lazy val newParams = fd.params.map(vd => ValDef(transform(vd.id))) + lazy val newBody = transform(fd.fullBody)((fd.params.map(_.id) zip newParams.map(_.id)).toMap) + newReturn != fd.returnType || newParams != fd.params || newBody != fd.fullBody + + case cd: ClassDef => + cd.fieldsIds.exists(id => transform(id.getType) != id.getType) || + cd.invariant.exists(required) + + case _ => scala.sys.error("Should never happen!?") + } + + requireCache += d -> res + res + }) + + val req = deps filter required + val allReq = req ++ (deps filter (d => (dependencies(d) & req).nonEmpty)) + val requiredCds = allReq collect { case cd: ClassDef => cd } + val requiredFds = allReq collect { case fd: FunDef => fd } + tmpDefs --= deps + + val nonReq = deps filterNot allReq + cdMap ++= nonReq collect { case cd: ClassDef => cd -> cd } + fdMap ++= nonReq collect { case fd: FunDef => fd -> fd } + + def trCd(cd: ClassDef): ClassDef = cdMap.cachedB(cd) { + val parent = cd.parent.map(act => act.copy(classDef = trCd(act.classDef).asInstanceOf[AbstractClassDef])) + cd match { + case acd: AbstractClassDef => acd.duplicate(id = transform(acd.id), parent = parent) + case ccd: CaseClassDef => ccd.duplicate(id = transform(ccd.id), parent = parent) + } + } + + for (cd <- requiredCds) trCd(cd) + for (fd <- requiredFds) { + val newReturn = transform(fd.returnType) + val newParams = fd.params map (vd => ValDef(transform(vd.id))) + fdMap += fd -> fd.duplicate(id = transform(fd.id), params = newParams, returnType = newReturn) + } + + for (ccd <- requiredCds collect { case ccd: CaseClassDef => ccd }) { + val newCcd = cdMap.toB(ccd).asInstanceOf[CaseClassDef] + newCcd.setFields(ccd.fields.map(vd => ValDef(transform(vd.id)))) + ccd.invariant.foreach(fd => newCcd.setInvariant(transform(fd))) + } + + for (fd <- requiredFds) { + val nfd = fdMap.toB(fd) + fd.fullBody = transform(fd.fullBody)((fd.params zip nfd.params).map(p => p._1.id -> p._2.id).toMap) + } + } +} + diff --git a/src/main/scala/leon/purescala/DependencyFinder.scala b/src/main/scala/leon/purescala/DependencyFinder.scala new file mode 100644 index 000000000..5382046e1 --- /dev/null +++ b/src/main/scala/leon/purescala/DependencyFinder.scala @@ -0,0 +1,84 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package purescala + +import Common._ +import Definitions._ +import Expressions._ +import Extractors._ +import Types._ + +import scala.collection.mutable.{Map => MutableMap, Set => MutableSet} + +class DependencyFinder { + private val deps: MutableMap[Definition, Set[Definition]] = MutableMap.empty + + def apply(d: Definition): Set[Definition] = deps.getOrElse(d, { + new Finder(d).dependencies + }) + + private class Finder(var current: Definition) extends TreeTraverser { + val foundDeps: MutableMap[Definition, MutableSet[Definition]] = MutableMap.empty + foundDeps(current) = MutableSet.empty + + private def withCurrent[T](d: Definition)(b: => T): T = { + if (!(foundDeps contains d)) foundDeps(d) = MutableSet.empty + val c = current + current = d + val res = b + current = c + res + } + + override def traverse(id: Identifier): Unit = traverse(id.getType) + + override def traverse(cd: ClassDef): Unit = if (!foundDeps(current)(cd)) { + foundDeps(current) += cd + if (!(deps contains cd) && !(foundDeps contains cd)) { + for (cd <- cd.root.knownDescendants :+ cd) { + cd.invariant foreach (fd => withCurrent(cd)(traverse(fd))) + withCurrent(cd)(cd.fieldsIds foreach traverse) + cd.parent foreach { p => + foundDeps(p.classDef) = foundDeps.getOrElse(p.classDef, MutableSet.empty) + cd + foundDeps(cd) = foundDeps.getOrElse(cd, MutableSet.empty) + p.classDef + } + } + } + } + + override def traverse(fd: FunDef): Unit = if (!foundDeps(current)(fd)) { + foundDeps(current) += fd + if (!(deps contains fd) && !(foundDeps contains fd)) withCurrent(fd) { + fd.params foreach (vd => traverse(vd.id)) + traverse(fd.returnType) + traverse(fd.fullBody) + } + } + + def dependencies: Set[Definition] = { + current match { + case fd: FunDef => traverse(fd) + case cd: ClassDef => traverse(cd) + case _ => + } + + for ((d, ds) <- foundDeps) { + deps(d) = deps.getOrElse(d, Set.empty) ++ ds + } + + var changed = false + do { + for ((d, ds) <- deps.toSeq) { + val next = ds.flatMap(d => deps.getOrElse(d, Set.empty)) + if (!(next subsetOf ds)) { + deps(d) = next + changed = true + } + } + } while (changed) + + deps(current) + } + } +} diff --git a/src/main/scala/leon/purescala/Expressions.scala b/src/main/scala/leon/purescala/Expressions.scala index 25d1da631..79172eb7a 100644 --- a/src/main/scala/leon/purescala/Expressions.scala +++ b/src/main/scala/leon/purescala/Expressions.scala @@ -388,7 +388,7 @@ object Expressions { * @param out The output expression * @param cases The cases to compare against */ - case class Passes(in: Expr, out : Expr, cases : Seq[MatchCase]) extends Expr { + case class Passes(in: Expr, out: Expr, cases: Seq[MatchCase]) extends Expr { require(cases.nonEmpty) val getType = leastUpperBound(cases.map(_.rhs.getType)) match { diff --git a/src/main/scala/leon/purescala/MethodLifting.scala b/src/main/scala/leon/purescala/MethodLifting.scala index 957505f29..13810f83d 100644 --- a/src/main/scala/leon/purescala/MethodLifting.scala +++ b/src/main/scala/leon/purescala/MethodLifting.scala @@ -156,19 +156,18 @@ object MethodLifting extends TransformationPhase { isInstOf(Variable(receiver), cl.typed(ctParams map { _.tp })) } - if (cd.knownDescendants.forall( cd => (cd.methods ++ cd.fields).forall(_.id != fd.id))) { // Don't need to compose methods - val paramsMap = fd.params.zip(fdParams).map{case (x,y) => (x.id, y.id)}.toMap + val paramsMap = fd.params.zip(fdParams).map { case (x,y) => (x.id, y.id) }.toMap def thisToReceiver(e: Expr): Option[Expr] = e match { - case th@This(ct) => + case th @ This(ct) => Some(asInstOf(receiver.toVariable, ct).setPos(th)) case _ => None } val insTp: Expr => Expr = instantiateType(_, tparamsMap, paramsMap) - nfd.fullBody = insTp( postMap(thisToReceiver)(insTp(nfd.fullBody)) ) + nfd.fullBody = postMap(thisToReceiver)(insTp(nfd.fullBody)) // Add precondition if the method was defined in a subclass val pre = and( @@ -186,19 +185,22 @@ object MethodLifting extends TransformationPhase { m <- c.methods if m.id == fd.id (from,to) <- m.params zip fdParams } yield (from.id, to.id)).toMap + val classParamsMap = (for { c <- cd.knownDescendants :+ cd (from, to) <- c.tparams zip ctParams } yield (from, to.tp)).toMap + val methodParamsMap = (for { c <- cd.knownDescendants :+ cd m <- c.methods if m.id == fd.id (from,to) <- m.tparams zip fd.tparams } yield (from, to.tp)).toMap + def inst(cs: Seq[MatchCase]) = instantiateType( matchExpr(Variable(receiver), cs).setPos(fd), classParamsMap ++ methodParamsMap, - paramsMap + paramsMap + (receiver -> receiver) ) /* Separately handle pre, post, body */ diff --git a/src/main/scala/leon/purescala/RestoreMethods.scala b/src/main/scala/leon/purescala/RestoreMethods.scala index 5c0abc5e5..664b18978 100644 --- a/src/main/scala/leon/purescala/RestoreMethods.scala +++ b/src/main/scala/leon/purescala/RestoreMethods.scala @@ -57,7 +57,7 @@ object RestoreMethods extends TransformationPhase { }) }) - val (np2, _) = replaceFunDefs(np)(fd => None, { (fi, fd) => + val (np2, _, _, _) = replaceFunDefs(np)(fd => None, { (fi, fd) => fdToMd.get(fi.tfd.fd) match { case Some(md) => Some(MethodInvocation( diff --git a/src/main/scala/leon/purescala/TreeTransformer.scala b/src/main/scala/leon/purescala/TreeTransformer.scala index 88b267670..77a5659f1 100644 --- a/src/main/scala/leon/purescala/TreeTransformer.scala +++ b/src/main/scala/leon/purescala/TreeTransformer.scala @@ -9,6 +9,9 @@ import Expressions._ import Extractors._ import Types._ +import utils._ +import scala.collection.mutable.{Map => MutableMap, Set => MutableSet} + trait TreeTransformer { def transform(id: Identifier): Identifier = id def transform(cd: ClassDef): ClassDef = cd @@ -22,11 +25,11 @@ trait TreeTransformer { transform(default), transform(tpe).asInstanceOf[FunctionType]).copiedFrom(e) case Lambda(args, body) => val newArgs = args.map(vd => ValDef(transform(vd.id))) - val newBindings = (args zip newArgs).filter(p => p._1 != p._2).map(p => p._1.id -> p._2.id) + val newBindings = (args zip newArgs).map(p => p._1.id -> p._2.id) Lambda(newArgs, transform(body)(bindings ++ newBindings)).copiedFrom(e) case Forall(args, body) => val newArgs = args.map(vd => ValDef(transform(vd.id))) - val newBindings = (args zip newArgs).filter(p => p._1 != p._2).map(p => p._1.id -> p._2.id) + val newBindings = (args zip newArgs).map(p => p._1.id -> p._2.id) Forall(newArgs, transform(body)(bindings ++ newBindings)).copiedFrom(e) case Let(a, expr, body) => val newA = transform(a) @@ -53,6 +56,12 @@ trait TreeTransformer { val allBindings = bindings ++ newBindings MatchCase(newPattern, guard.map(g => transform(g)(allBindings)), transform(rhs)(allBindings)).copiedFrom(cse) }).copiedFrom(e) + case Passes(in, out, cases) => + Passes(transform(in), transform(out), for (cse @ MatchCase(pattern, guard, rhs) <- cases) yield { + val (newPattern, newBindings) = transform(pattern) + val allBindings = bindings ++ newBindings + MatchCase(newPattern, guard.map(g => transform(g)(allBindings)), transform(rhs)(allBindings)).copiedFrom(cse) + }).copiedFrom(e) case FiniteSet(es, tpe) => FiniteSet(es map transform, transform(tpe)).copiedFrom(e) case FiniteBag(es, tpe) => @@ -81,26 +90,26 @@ trait TreeTransformer { case InstanceOfPattern(binder, ct) => val newBinder = binder map transform val newPat = InstanceOfPattern(newBinder, transform(ct).asInstanceOf[ClassType]).copiedFrom(pat) - (newPat, (binder zip newBinder).filter(p => p._1 != p._2).toMap) + (newPat, (binder zip newBinder).toMap) case WildcardPattern(binder) => val newBinder = binder map transform val newPat = WildcardPattern(newBinder).copiedFrom(pat) - (newPat, (binder zip newBinder).filter(p => p._1 != p._2).toMap) + (newPat, (binder zip newBinder).toMap) case CaseClassPattern(binder, ct, subs) => val newBinder = binder map transform val (newSubs, subBinders) = (subs map transform).unzip val newPat = CaseClassPattern(newBinder, transform(ct).asInstanceOf[CaseClassType], newSubs).copiedFrom(pat) - (newPat, (binder zip newBinder).filter(p => p._1 != p._2).toMap ++ subBinders.flatten) + (newPat, (binder zip newBinder).toMap ++ subBinders.flatten) case TuplePattern(binder, subs) => val newBinder = binder map transform val (newSubs, subBinders) = (subs map transform).unzip val newPat = TuplePattern(newBinder, newSubs).copiedFrom(pat) - (newPat, (binder zip newBinder).filter(p => p._1 != p._2).toMap ++ subBinders.flatten) + (newPat, (binder zip newBinder).toMap ++ subBinders.flatten) case UnapplyPattern(binder, TypedFunDef(fd, tpes), subs) => val newBinder = binder map transform val (newSubs, subBinders) = (subs map transform).unzip val newPat = UnapplyPattern(newBinder, TypedFunDef(transform(fd), tpes map transform), newSubs).copiedFrom(pat) - (newPat, (binder zip newBinder).filter(p => p._1 != p._2).toMap ++ subBinders.flatten) + (newPat, (binder zip newBinder).toMap ++ subBinders.flatten) case PatternExtractor(subs, builder) => val (newSubs, subBinders) = (subs map transform).unzip (builder(newSubs).copiedFrom(pat), subBinders.flatten.toMap) diff --git a/src/main/scala/leon/solvers/theories/TheoryEncoder.scala b/src/main/scala/leon/solvers/theories/TheoryEncoder.scala index e0de93c43..9cb18effd 100644 --- a/src/main/scala/leon/solvers/theories/TheoryEncoder.scala +++ b/src/main/scala/leon/solvers/theories/TheoryEncoder.scala @@ -35,175 +35,8 @@ trait TheoryEncoder { self => def encode(fd: FunDef): FunDef = encoder.transform(fd) def decode(fd: FunDef): FunDef = decoder.transform(fd) - protected trait Converter extends purescala.TreeTransformer { - private[TheoryEncoder] val idMap: Bijection[Identifier, Identifier] - private[TheoryEncoder] val fdMap: Bijection[FunDef , FunDef ] - private[TheoryEncoder] val cdMap: Bijection[ClassDef , ClassDef ] - - override def transform(id: Identifier): Identifier = idMap.cachedB(id) { - val ntpe = transform(id.getType) - if (ntpe == id.getType) id else id.duplicate(tpe = ntpe) - } - - override def transform(fd: FunDef): FunDef = fdMap.getBorElse(fd, if (tmpDefs(fd)) fd else { - transformDefs(fd) - fdMap.toB(fd) - }) - - override def transform(cd: ClassDef): ClassDef = cdMap.getBorElse(cd, if (tmpDefs(cd)) cd else { - transformDefs(cd) - cdMap.toB(cd) - }) - - private val deps: MutableMap[Definition, Set[Definition]] = MutableMap.empty - private val tmpDefs: MutableSet[Definition] = MutableSet.empty - - private class DependencyFinder(var current: Definition) extends purescala.TreeTraverser { - val deps: MutableMap[Definition, MutableSet[Definition]] = MutableMap.empty - deps(current) = MutableSet.empty - - private def withCurrent[T](d: Definition)(b: => T): T = { - if (!(deps contains d)) deps(d) = MutableSet.empty - val c = current - current = d - val res = b - current = c - res - } - - override def traverse(id: Identifier): Unit = traverse(id.getType) - - override def traverse(cd: ClassDef): Unit = if (!deps(current)(cd)) { - deps(current) += cd - if (!(Converter.this.deps contains cd) && !(deps contains cd)) { - for (cd <- cd.root.knownDescendants :+ cd) { - cd.invariant foreach (fd => withCurrent(cd)(traverse(fd))) - withCurrent(cd)(cd.fieldsIds foreach traverse) - cd.parent foreach { p => - deps(p.classDef) = deps.getOrElse(p.classDef, MutableSet.empty) + cd - deps(cd) = deps.getOrElse(cd, MutableSet.empty) + p.classDef - } - } - } - } - - override def traverse(fd: FunDef): Unit = if (!deps(current)(fd)) { - deps(current) += fd - if (!(Converter.this.deps contains fd) && !(deps contains fd)) withCurrent(fd) { - fd.params foreach (vd => traverse(vd.id)) - traverse(fd.returnType) - traverse(fd.fullBody) - } - } - - def dependencies: Set[Definition] = { - current match { - case fd: FunDef => traverse(fd) - case cd: ClassDef => traverse(cd) - case _ => - } - - for ((d, ds) <- deps) { - Converter.this.deps(d) = Converter.this.deps.getOrElse(d, Set.empty) ++ ds - } - - var changed = false - do { - for ((d, ds) <- Converter.this.deps.toSeq) { - val next = ds.flatMap(d => Converter.this.deps.getOrElse(d, Set.empty)) - if (!(next subsetOf ds)) { - Converter.this.deps(d) = next - changed = true - } - } - } while (changed) - - Converter.this.deps(current) - } - } - - private def dependencies(d: Definition): Set[Definition] = deps.getOrElse(d, { - new DependencyFinder(d).dependencies - }) - - private def transformDefs(base: Definition): Unit = { - val deps = dependencies(base) - val (cds, fds) = { - val (c, f) = deps.partition(_.isInstanceOf[ClassDef]) - (c.map(_.asInstanceOf[ClassDef]), f.map(_.asInstanceOf[FunDef])) - } - - tmpDefs ++= cds.filterNot(cdMap containsA _) ++ fds.filterNot(fdMap containsA _) - - var requireCache: Map[Definition, Boolean] = Map.empty - def required(d: Definition): Boolean = requireCache.getOrElse(d, { - val res = d match { - case fd: FunDef => - val newReturn = transform(fd.returnType) - lazy val newParams = fd.params.map(vd => ValDef(transform(vd.id))) - lazy val newBody = transform(fd.fullBody)((fd.params.map(_.id) zip newParams.map(_.id)).toMap) - newReturn != fd.returnType || newParams != fd.params || newBody != fd.fullBody - - case cd: ClassDef => - cd.fieldsIds.exists(id => transform(id.getType) != id.getType) || - cd.invariant.exists(required) - - case _ => scala.sys.error("Should never happen!?") - } - - requireCache += d -> res - res - }) - - val req = deps filter required - val allReq = req ++ (deps filter (d => (dependencies(d) & req).nonEmpty)) - val requiredCds = allReq collect { case cd: ClassDef => cd } - val requiredFds = allReq collect { case fd: FunDef => fd } - tmpDefs --= deps - - val nonReq = deps filterNot allReq - cdMap ++= nonReq collect { case cd: ClassDef => cd -> cd } - fdMap ++= nonReq collect { case fd: FunDef => fd -> fd } - - def trCd(cd: ClassDef): ClassDef = cdMap.cachedB(cd) { - val parent = cd.parent.map(act => act.copy(classDef = trCd(act.classDef).asInstanceOf[AbstractClassDef])) - cd match { - case acd: AbstractClassDef => acd.duplicate(id = transform(acd.id), parent = parent) - case ccd: CaseClassDef => ccd.duplicate(id = transform(ccd.id), parent = parent) - } - } - - for (cd <- requiredCds) trCd(cd) - for (fd <- requiredFds) { - val newReturn = transform(fd.returnType) - val newParams = fd.params map (vd => ValDef(transform(vd.id))) - fdMap += fd -> fd.duplicate(id = transform(fd.id), params = newParams, returnType = newReturn) - } - - for (ccd <- requiredCds collect { case ccd: CaseClassDef => ccd }) { - val newCcd = cdMap.toB(ccd).asInstanceOf[CaseClassDef] - newCcd.setFields(ccd.fields.map(vd => ValDef(transform(vd.id)))) - newCcd.invariant.foreach(fd => ccd.setInvariant(transform(fd))) - } - - for (fd <- requiredFds) { - val nfd = fdMap.toB(fd) - fd.fullBody = transform(fd.fullBody)((fd.params zip nfd.params).map(p => p._1.id -> p._2.id).toMap) - } - } - } - - protected class Encoder extends Converter { - private[TheoryEncoder] final val idMap: Bijection[Identifier, Identifier] = TheoryEncoder.this.idMap - private[TheoryEncoder] final val fdMap: Bijection[FunDef , FunDef ] = TheoryEncoder.this.fdMap - private[TheoryEncoder] final val cdMap: Bijection[ClassDef , ClassDef ] = TheoryEncoder.this.cdMap - } - - protected class Decoder extends Converter { - private[TheoryEncoder] final val idMap: Bijection[Identifier, Identifier] = TheoryEncoder.this.idMap.swap - private[TheoryEncoder] final val fdMap: Bijection[FunDef , FunDef ] = TheoryEncoder.this.fdMap.swap - private[TheoryEncoder] final val cdMap: Bijection[ClassDef , ClassDef ] = TheoryEncoder.this.cdMap.swap - } + protected class Encoder extends purescala.DefinitionTransformer(idMap, fdMap, cdMap) + protected class Decoder extends purescala.DefinitionTransformer(idMap.swap, fdMap.swap, cdMap.swap) def >>(that: TheoryEncoder): TheoryEncoder = new TheoryEncoder { val encoder = new Encoder { diff --git a/src/main/scala/leon/solvers/unrolling/UnrollingBank.scala b/src/main/scala/leon/solvers/unrolling/UnrollingBank.scala index 7ba55bd37..2c38aca15 100644 --- a/src/main/scala/leon/solvers/unrolling/UnrollingBank.scala +++ b/src/main/scala/leon/solvers/unrolling/UnrollingBank.scala @@ -212,30 +212,40 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat def promoteBlocker(b: T, force: Boolean = false): Boolean = { var seen: Set[T] = Set.empty var promoted: Boolean = false + var blockers: Seq[Set[T]] = Seq(Set(b)) - def rec(b: T): Unit = if (!seen(b)) { - seen += b - if (callInfos contains b) { - val (_, origGen, notB, fis) = callInfos(b) + do { + val (bs +: rest) = blockers + blockers = rest - callInfos += b -> (1, origGen, notB, fis) - promoted = true - } + val next = (for (b <- bs if !seen(b)) yield { + seen += b - if (blockerToApps contains b) { - val app = blockerToApps(b) - val (_, origGen, _, notB, infos) = appInfos(app) + if (callInfos contains b) { + val (_, origGen, notB, fis) = callInfos(b) - appInfos += app -> (1, origGen, b, notB, infos) - promoted = true - } + callInfos += b -> (1, origGen, notB, fis) + promoted = true + } - if (!promoted && force) { - for (cb <- templateGenerator.manager.blockerChildren(b)) rec(cb) - } - } + if (blockerToApps contains b) { + val app = blockerToApps(b) + val (_, origGen, _, notB, infos) = appInfos(app) + + appInfos += app -> (1, origGen, b, notB, infos) + promoted = true + } + + if (force) { + templateGenerator.manager.blockerChildren(b) + } else { + Set.empty[T] + } + }).flatten + + if (next.nonEmpty) blockers :+= next + } while (!promoted && blockers.nonEmpty) - rec(b) promoted } diff --git a/src/main/scala/leon/solvers/unrolling/UnrollingSolver.scala b/src/main/scala/leon/solvers/unrolling/UnrollingSolver.scala index 76f18864b..601aa84eb 100644 --- a/src/main/scala/leon/solvers/unrolling/UnrollingSolver.scala +++ b/src/main/scala/leon/solvers/unrolling/UnrollingSolver.scala @@ -388,9 +388,13 @@ trait AbstractUnrollingSolver[T] if (this.feelingLucky) { // we might have been lucky :D - val extracted = extractModel(model) - val valid = validateModel(extracted, assumptionsSeq, silenceErrors = true) - if (valid) foundAnswer(Some(true), extracted) + try { + val extracted = extractModel(model) + val valid = validateModel(extracted, assumptionsSeq, silenceErrors = true) + if (valid) foundAnswer(Some(true), extracted) + } catch { + case u: Unsupported => + } } if (!foundDefinitiveAnswer) { diff --git a/src/main/scala/leon/synthesis/Synthesizer.scala b/src/main/scala/leon/synthesis/Synthesizer.scala index c81b045d0..23d79c1e7 100644 --- a/src/main/scala/leon/synthesis/Synthesizer.scala +++ b/src/main/scala/leon/synthesis/Synthesizer.scala @@ -167,7 +167,7 @@ class Synthesizer(val context : LeonContext, val solutionExpr = sol.toSimplifiedExpr(context, program, ci.fd) - val (npr, fdMap) = replaceFunDefs(program)({ + val (npr, _, fdMap, _) = replaceFunDefs(program)({ case fd if fd eq ci.fd => val nfd = fd.duplicate() nfd.fullBody = replace(Map(ci.source -> solutionExpr), nfd.fullBody) diff --git a/src/main/scala/leon/synthesis/rules/CEGISLike.scala b/src/main/scala/leon/synthesis/rules/CEGISLike.scala index 7ac1efc04..238d3198f 100644 --- a/src/main/scala/leon/synthesis/rules/CEGISLike.scala +++ b/src/main/scala/leon/synthesis/rules/CEGISLike.scala @@ -274,7 +274,7 @@ abstract class CEGISLike(name: String) extends Rule(name) { private val phiFd = new FunDef(FreshIdentifier("phiFd", alwaysShowUniqueID = true), Seq(), p.as.map(id => ValDef(id)), BooleanType) // The program with the body of the current function replaced by the current partial solution - private val (innerProgram, origFdMap) = { + private val (innerProgram, origIdMap, origFdMap, origCdMap) = { val outerSolution = { new PartialSolution(hctx.search.strat, true) @@ -313,7 +313,12 @@ abstract class CEGISLike(name: String) extends Rule(name) { case fd => Some(fd.duplicate()) } + } + private val outerToInner = new purescala.TreeTransformer { + override def transform(id: Identifier): Identifier = origIdMap.getOrElse(id, id) + override def transform(cd: ClassDef): ClassDef = origCdMap.getOrElse(cd, cd) + override def transform(fd: FunDef): FunDef = origFdMap.getOrElse(fd, fd) } /** @@ -322,9 +327,7 @@ abstract class CEGISLike(name: String) extends Rule(name) { * to the CEGIS-specific program, 'outer' refers to the actual program on * which we do synthesis. */ - private def outerExprToInnerExpr(e: Expr): Expr = { - replaceFunCalls(e, {fd => origFdMap.getOrElse(fd, fd) }) - } + private def outerExprToInnerExpr(e: Expr): Expr = outerToInner.transform(e)(Map.empty) private val innerPc = outerExprToInnerExpr(p.pc) private val innerPhi = outerExprToInnerExpr(p.phi) diff --git a/src/main/scala/leon/termination/TerminationChecker.scala b/src/main/scala/leon/termination/TerminationChecker.scala index efe8bf438..266e13aa2 100644 --- a/src/main/scala/leon/termination/TerminationChecker.scala +++ b/src/main/scala/leon/termination/TerminationChecker.scala @@ -9,7 +9,7 @@ import purescala.DefOps._ abstract class TerminationChecker(val context: LeonContext, initProgram: Program) extends LeonComponent { val program = { - val (pgm, _) = replaceFunDefs(initProgram){ fd => Some(fd.duplicate()) } + val (pgm, _, _, _) = replaceFunDefs(initProgram){ fd => Some(fd.duplicate()) } pgm } diff --git a/src/test/scala/leon/integration/solvers/StringRenderSuite.scala b/src/test/scala/leon/integration/solvers/StringRenderSuite.scala index c8a3000dc..fe02ce2ef 100644 --- a/src/test/scala/leon/integration/solvers/StringRenderSuite.scala +++ b/src/test/scala/leon/integration/solvers/StringRenderSuite.scala @@ -115,7 +115,7 @@ class StringRenderSuite extends LeonTestSuiteWithProgram with Matchers with Scal val newProgram = DefOps.addFunDefs(synth.program, solutions.head.defs, synth.sctx.functionContext) val newFd = ci.fd.duplicate() newFd.body = Some(solutions.head.term) - val (newProgram2, _) = DefOps.replaceFunDefs(newProgram)({ fd => + val (newProgram2, _, _, _) = DefOps.replaceFunDefs(newProgram)({ fd => if(fd == ci.fd) { Some(newFd) } else None -- GitLab