From 5a7f93afee19dc9bec6156039f2e03aa33bf9556 Mon Sep 17 00:00:00 2001 From: Nicolas Voirol <voirol.nicolas@gmail.com> Date: Fri, 4 Dec 2015 12:48:08 +0100 Subject: [PATCH] Small fix for termination and optimized lambda instantiation --- .../solvers/templates/LambdaManager.scala | 233 +++++++++++++++--- .../templates/QuantificationManager.scala | 72 ++++-- .../solvers/templates/TemplateGenerator.scala | 57 +++-- .../leon/solvers/templates/TemplateInfo.scala | 2 +- ...{Templates.scala => TemplateManager.scala} | 184 ++++---------- .../solvers/templates/UnrollingBank.scala | 42 +++- .../scala/leon/termination/ChainBuilder.scala | 2 +- .../leon/termination/ChainProcessor.scala | 5 +- 8 files changed, 362 insertions(+), 235 deletions(-) rename src/main/scala/leon/solvers/templates/{Templates.scala => TemplateManager.scala} (72%) diff --git a/src/main/scala/leon/solvers/templates/LambdaManager.scala b/src/main/scala/leon/solvers/templates/LambdaManager.scala index 00bdbfa07..74bfe2fd4 100644 --- a/src/main/scala/leon/solvers/templates/LambdaManager.scala +++ b/src/main/scala/leon/solvers/templates/LambdaManager.scala @@ -6,29 +6,194 @@ package templates import purescala.Common._ import purescala.Expressions._ +import purescala.Extractors._ import purescala.ExprOps._ import purescala.Types._ import utils._ import Instantiation._ -class LambdaManager[T](protected[templates] val encoder: TemplateEncoder[T]) extends IncrementalState { +case class App[T](caller: T, tpe: FunctionType, args: Seq[T]) { + override def toString = "(" + caller + " : " + tpe + ")" + args.mkString("(", ",", ")") +} + +object LambdaTemplate { + + def apply[T]( + ids: (Identifier, T), + encoder: TemplateEncoder[T], + manager: QuantificationManager[T], + pathVar: (Identifier, T), + arguments: Seq[(Identifier, T)], + condVars: Map[Identifier, T], + exprVars: Map[Identifier, T], + condTree: Map[Identifier, Set[Identifier]], + guardedExprs: Map[Identifier, Seq[Expr]], + quantifications: Seq[QuantificationTemplate[T]], + lambdas: Seq[LambdaTemplate[T]], + baseSubstMap: Map[Identifier, T], + dependencies: Map[Identifier, T], + lambda: Lambda + ) : LambdaTemplate[T] = { + + val id = ids._2 + val tpe = ids._1.getType.asInstanceOf[FunctionType] + val (clauses, blockers, applications, matchers, templateString) = + Template.encode(encoder, pathVar, arguments, condVars, exprVars, guardedExprs, lambdas, + substMap = baseSubstMap + ids, optApp = Some(id -> tpe)) + + val lambdaString : () => String = () => { + "Template for lambda " + ids._1 + ": " + lambda + " is :\n" + templateString() + } + + val (structuralLambda, structSubst) = normalizeStructure(lambda) + val keyDeps = dependencies.map { case (id, idT) => structSubst(id) -> idT } + val key = structuralLambda.asInstanceOf[Lambda] + + new LambdaTemplate[T]( + ids, + encoder, + manager, + pathVar._2, + arguments, + condVars, + exprVars, + condTree, + clauses, + blockers, + applications, + quantifications, + matchers, + lambdas, + keyDeps, + key, + lambdaString + ) + } +} + +class LambdaTemplate[T] private ( + val ids: (Identifier, T), + val encoder: TemplateEncoder[T], + val manager: QuantificationManager[T], + val start: T, + val arguments: Seq[(Identifier, T)], + val condVars: Map[Identifier, T], + val exprVars: Map[Identifier, T], + val condTree: Map[Identifier, Set[Identifier]], + val clauses: Seq[T], + val blockers: Map[T, Set[TemplateCallInfo[T]]], + val applications: Map[T, Set[App[T]]], + val quantifications: Seq[QuantificationTemplate[T]], + val matchers: Map[T, Set[Matcher[T]]], + val lambdas: Seq[LambdaTemplate[T]], + private[templates] val dependencies: Map[Identifier, T], + private[templates] val structuralKey: Lambda, + stringRepr: () => String) extends Template[T] { + + val args = arguments.map(_._2) + val tpe = ids._1.getType.asInstanceOf[FunctionType] + + def substitute(substituter: T => T): LambdaTemplate[T] = { + val newStart = substituter(start) + val newClauses = clauses.map(substituter) + val newBlockers = blockers.map { case (b, fis) => + val bp = if (b == start) newStart else b + bp -> fis.map(fi => fi.copy(args = fi.args.map(substituter))) + } + + val newApplications = applications.map { case (b, fas) => + val bp = if (b == start) newStart else b + bp -> fas.map(fa => fa.copy(caller = substituter(fa.caller), args = fa.args.map(substituter))) + } + + val newQuantifications = quantifications.map(_.substitute(substituter)) + + val newMatchers = matchers.map { case (b, ms) => + val bp = if (b == start) newStart else b + bp -> ms.map(_.substitute(substituter)) + } + + val newLambdas = lambdas.map(_.substitute(substituter)) + + val newDependencies = dependencies.map(p => p._1 -> substituter(p._2)) + + new LambdaTemplate[T]( + ids._1 -> substituter(ids._2), + encoder, + manager, + newStart, + arguments, + condVars, + exprVars, + condTree, + newClauses, + newBlockers, + newApplications, + newQuantifications, + newMatchers, + newLambdas, + newDependencies, + structuralKey, + stringRepr + ) + } + + private lazy val str : String = stringRepr() + override def toString : String = str + + lazy val key: (Expr, Seq[T]) = { + def rec(e: Expr): Seq[Identifier] = e match { + case Variable(id) => + if (dependencies.isDefinedAt(id)) { + Seq(id) + } else { + Seq.empty + } + + case Operator(es, _) => es.flatMap(rec) + + case _ => Seq.empty + } + + structuralKey -> rec(structuralKey).distinct.map(dependencies) + } + + override def equals(that: Any): Boolean = that match { + case t: LambdaTemplate[T] => + val (lambda1, deps1) = key + val (lambda2, deps2) = t.key + (lambda1 == lambda2) && { + (deps1 zip deps2).forall { case (id1, id2) => + (manager.byID.get(id1), manager.byID.get(id2)) match { + case (Some(t1), Some(t2)) => t1 == t2 + case _ => id1 == id2 + } + } + } + + case _ => false + } + + override def hashCode: Int = key.hashCode + + override def instantiate(substMap: Map[T, T]): Instantiation[T] = { + super.instantiate(substMap) ++ manager.instantiateAxiom(this, substMap) + } +} + +class LambdaManager[T](encoder: TemplateEncoder[T]) extends TemplateManager(encoder) { private[templates] lazy val trueT = encoder.encodeExpr(Map.empty)(BooleanLiteral(true)) - protected val byID = new IncrementalMap[T, LambdaTemplate[T]] - protected val byType = new IncrementalMap[FunctionType, Set[(T, LambdaTemplate[T])]].withDefaultValue(Set.empty) - protected val applications = new IncrementalMap[FunctionType, Set[(T, App[T])]].withDefaultValue(Set.empty) - protected val freeLambdas = new IncrementalMap[FunctionType, Set[T]].withDefaultValue(Set.empty) + protected[templates] val byID = new IncrementalMap[T, LambdaTemplate[T]] + protected val byType = new IncrementalMap[FunctionType, Set[LambdaTemplate[T]]].withDefaultValue(Set.empty) + protected val applications = new IncrementalMap[FunctionType, Set[(T, App[T])]].withDefaultValue(Set.empty) + protected val freeLambdas = new IncrementalMap[FunctionType, Set[T]].withDefaultValue(Set.empty) private val instantiated = new IncrementalSet[(T, App[T])] - protected def incrementals: List[IncrementalState] = - List(byID, byType, applications, freeLambdas, instantiated) - - def clear(): Unit = incrementals.foreach(_.clear()) - def reset(): Unit = incrementals.foreach(_.reset()) - def push(): Unit = incrementals.foreach(_.push()) - def pop(): Unit = incrementals.foreach(_.pop()) + override protected def incrementals: List[IncrementalState] = + super.incrementals ++ List(byID, byType, applications, freeLambdas, instantiated) def registerFree(lambdas: Seq[(Identifier, T)]): Unit = { for ((id, idT) <- lambdas) id.getType match { @@ -40,21 +205,26 @@ class LambdaManager[T](protected[templates] val encoder: TemplateEncoder[T]) ext def instantiateLambda(template: LambdaTemplate[T]): Instantiation[T] = { val idT = template.ids._2 - var clauses : Clauses[T] = equalityClauses(idT, template) + var clauses : Clauses[T] = equalityClauses(template) var appBlockers : AppBlockers[T] = Map.empty.withDefaultValue(Set.empty) // make sure the new lambda isn't equal to any free lambda var clauses ++= freeLambdas(template.tpe).map(pIdT => encoder.mkNot(encoder.mkEquals(pIdT, idT))) byID += idT -> template - byType += template.tpe -> (byType(template.tpe) + (idT -> template)) - for (blockedApp @ (_, App(caller, tpe, args)) <- applications(template.tpe)) { - val equals = encoder.mkEquals(idT, caller) - appBlockers += (blockedApp -> (appBlockers(blockedApp) + TemplateAppInfo(template, equals, args))) - } + if (byType(template.tpe)(template)) { + (clauses, Map.empty, Map.empty) + } else { + byType += template.tpe -> (byType(template.tpe) + template) + + for (blockedApp @ (_, App(caller, tpe, args)) <- applications(template.tpe)) { + val equals = encoder.mkEquals(idT, caller) + appBlockers += (blockedApp -> (appBlockers(blockedApp) + TemplateAppInfo(template, equals, args))) + } - (clauses, Map.empty, appBlockers) + (clauses, Map.empty, appBlockers) + } } def instantiateApp(blocker: T, app: App[T]): Instantiation[T] = { @@ -75,8 +245,8 @@ class LambdaManager[T](protected[templates] val encoder: TemplateEncoder[T]) ext // so that UnrollingBank will generate the initial block! val init = instantiation withApps Map(key -> Set.empty) val inst = byType(tpe).foldLeft(init) { - case (instantiation, (idT, template)) => - val equals = encoder.mkEquals(idT, caller) + case (instantiation, template) => + val equals = encoder.mkEquals(template.ids._2, caller) instantiation withApp (key -> TemplateAppInfo(template, equals, args)) } @@ -88,16 +258,21 @@ class LambdaManager[T](protected[templates] val encoder: TemplateEncoder[T]) ext } } - private def equalityClauses(idT: T, template: LambdaTemplate[T]): Seq[T] = { - byType(template.tpe).map { case (thatIdT, that) => - val equals = encoder.mkEquals(idT, thatIdT) - template.contextEquality(that) match { - case None => encoder.mkNot(equals) - case Some(Seq()) => equals - case Some(seq) => encoder.mkEquals(encoder.mkAnd(seq : _*), equals) + private def equalityClauses(template: LambdaTemplate[T]): Seq[T] = { + val (s1, deps1) = template.key + byType(template.tpe).map { that => + val (s2, deps2) = that.key + val equals = encoder.mkEquals(template.ids._2, that.ids._2) + if (s1 == s2) { + val pairs = (deps1 zip deps2).filter(p => p._1 != p._2) + if (pairs.isEmpty) equals else { + val eqs = pairs.map(p => encoder.mkEquals(p._1, p._2)) + encoder.mkEquals(encoder.mkAnd(eqs : _*), equals) + } + } else { + encoder.mkNot(equals) } }.toSeq } - } diff --git a/src/main/scala/leon/solvers/templates/QuantificationManager.scala b/src/main/scala/leon/solvers/templates/QuantificationManager.scala index f2d1c1d6f..e9ac6bbec 100644 --- a/src/main/scala/leon/solvers/templates/QuantificationManager.scala +++ b/src/main/scala/leon/solvers/templates/QuantificationManager.scala @@ -53,6 +53,7 @@ class QuantificationTemplate[T]( val quantifiers: Seq[(Identifier, T)], val condVars: Map[Identifier, T], val exprVars: Map[Identifier, T], + val condTree: Map[Identifier, Set[Identifier]], val clauses: Seq[T], val blockers: Map[T, Set[TemplateCallInfo[T]]], val applications: Map[T, Set[App[T]]], @@ -70,6 +71,7 @@ class QuantificationTemplate[T]( quantifiers, condVars, exprVars, + condTree, clauses.map(substituter), blockers.map { case (b, fis) => substituter(b) -> fis.map(fi => fi.copy(args = fi.args.map(substituter))) @@ -97,6 +99,7 @@ object QuantificationTemplate { quantifiers: Seq[(Identifier, T)], condVars: Map[Identifier, T], exprVars: Map[Identifier, T], + condTree: Map[Identifier, Set[Identifier]], guardedExprs: Map[Identifier, Seq[Expr]], lambdas: Seq[LambdaTemplate[T]], subst: Map[Identifier, T] @@ -112,7 +115,7 @@ object QuantificationTemplate { new QuantificationTemplate[T](quantificationManager, pathVar._2, qs, q2s, insts, guards._2, quantifiers, - condVars, exprVars, clauses, blockers, applications, matchers, lambdas) + condVars, exprVars, condTree, clauses, blockers, applications, matchers, lambdas) } } @@ -214,7 +217,8 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage case _ => 0 }).max - private def encodeEnablers(es: Set[T]): T = encoder.mkAnd(es.toSeq.sortBy(_.toString) : _*) + private def encodeEnablers(es: Set[T]): T = + if (es.isEmpty) trueT else encoder.mkAnd(es.toSeq.sortBy(_.toString) : _*) private type Matchers = Set[(T, Matcher[T])] @@ -343,6 +347,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage val allMatchers: Map[T, Set[Matcher[T]]] val condVars: Map[Identifier, T] val exprVars: Map[Identifier, T] + val condTree: Map[Identifier, Set[Identifier]] val clauses: Seq[T] val blockers: Map[T, Set[TemplateCallInfo[T]]] val applications: Map[T, Set[App[T]]] @@ -424,7 +429,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage } val substituter = encoder.substitute(subst.mapValues(Matcher.argValue)) - val enablers = (if (constraints.isEmpty) Set(trueT) else constraints).map(substituter) + val enablers = constraints.filter(_ != trueT).map(substituter) val isStrict = matcherEqs.forall(p => substituter(p._1) == p._2) (enablers, subst, isStrict) @@ -435,16 +440,20 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage for (mapping <- mappings(bs, matcher)) { val (enablers, subst, isStrict) = extractSubst(mapping) - val enabler = encodeEnablers(enablers) + val (enabler, optEnabler) = freshBlocker(enablers) - val baseSubstMap = (condVars ++ exprVars).map { case (id, idT) => idT -> encoder.encodeId(id) } - val lambdaSubstMap = lambdas map(lambda => lambda.ids._2 -> encoder.encodeId(lambda.ids._1)) - val substMap = subst.mapValues(Matcher.argValue) ++ baseSubstMap ++ lambdaSubstMap ++ instanceSubst(enabler) + if (optEnabler.isDefined) { + instantiation = instantiation withClause encoder.mkEquals(enabler, optEnabler.get) + } + + val baseSubstMap = exprVars.map { case (id, idT) => idT -> encoder.encodeId(id) } ++ + freshConds(enabler, condVars, condTree) + val lambdaSubstMap = lambdas map (lambda => lambda.ids._2 -> encoder.encodeId(lambda.ids._1)) + val substMap = subst.mapValues(Matcher.argValue) ++ baseSubstMap ++ lambdaSubstMap ++ instanceSubst(enablers) instantiation ++= Template.instantiate(encoder, QuantificationManager.this, clauses, blockers, applications, Seq.empty, Map.empty[T, Set[Matcher[T]]], lambdas, substMap) - val msubst = subst.collect { case (c, Right(m)) => c -> m } val substituter = encoder.substitute(substMap) @@ -465,7 +474,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage instantiation } - protected def instanceSubst(enabler: T): Map[T, T] + protected def instanceSubst(enablers: Set[T]): Map[T, T] } private class Quantification ( @@ -478,6 +487,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage val allMatchers: Map[T, Set[Matcher[T]]], val condVars: Map[Identifier, T], val exprVars: Map[Identifier, T], + val condTree: Map[Identifier, Set[Identifier]], val clauses: Seq[T], val blockers: Map[T, Set[TemplateCallInfo[T]]], val applications: Map[T, Set[App[T]]], @@ -485,10 +495,10 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage var currentQ2Var: T = qs._2 - protected def instanceSubst(enabler: T): Map[T, T] = { + protected def instanceSubst(enablers: Set[T]): Map[T, T] = { val nextQ2Var = encoder.encodeId(q2s._1) - val subst = Map(qs._2 -> currentQ2Var, guardVar -> enabler, + val subst = Map(qs._2 -> currentQ2Var, guardVar -> encodeEnablers(enablers), q2s._2 -> nextQ2Var, insts._2 -> encoder.encodeId(insts._1)) currentQ2Var = nextQ2Var @@ -498,6 +508,20 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage private lazy val blockerId = FreshIdentifier("blocker", BooleanType, true) private lazy val blockerCache: MutableMap[T, T] = MutableMap.empty + private def freshBlocker(enablers: Set[T]): (T, Option[T]) = enablers.toSeq match { + case Seq(b) if isBlocker(b) => (b, None) + case _ => + val enabler = encodeEnablers(enablers) + blockerCache.get(enabler) match { + case Some(b) => (b, None) + case None => + val nb = encoder.encodeId(blockerId) + blockerCache += enabler -> nb + for (b <- enablers if isBlocker(b)) implies(b, nb) + blocker(nb) + (nb, Some(enabler)) + } + } private class Axiom ( val start: T, @@ -508,22 +532,17 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage val allMatchers: Map[T, Set[Matcher[T]]], val condVars: Map[Identifier, T], val exprVars: Map[Identifier, T], + val condTree: Map[Identifier, Set[Identifier]], val clauses: Seq[T], val blockers: Map[T, Set[TemplateCallInfo[T]]], val applications: Map[T, Set[App[T]]], val lambdas: Seq[LambdaTemplate[T]]) extends MatcherQuantification { - protected def instanceSubst(enabler: T): Map[T, T] = { - val newBlocker = blockerCache.get(enabler) match { - case Some(b) => b - case None => - val nb = encoder.encodeId(blockerId) - blockerCache += enabler -> nb - blockerCache += nb -> nb - nb - } - - Map(guardVar -> encoder.mkAnd(start, enabler), blocker -> newBlocker) + protected def instanceSubst(enablers: Set[T]): Map[T, T] = { + // no need to add an equality clause here since it is already contained in the Axiom clauses + val (newBlocker, optEnabler) = freshBlocker(enablers) + val guardT = if (optEnabler.isDefined) encoder.mkAnd(start, optEnabler.get) else start + Map(guardVar -> guardT, blocker -> newBlocker) } } @@ -585,6 +604,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage allMatchers + (template.start -> (allMatchers.getOrElse(template.start, Set.empty) + selfMatcher)), template.condVars map { case (id, idT) => id -> substituter(idT) }, template.exprVars map { case (id, idT) => id -> substituter(idT) }, + template.condTree, (template.clauses map substituter) :+ enablingClause, template.blockers map { case (b, fis) => substituter(b) -> fis.map(fi => fi.copy(args = fi.args.map(substituter))) @@ -606,6 +626,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage allMatchers: Map[T, Set[Matcher[T]]], condVars: Map[Identifier, T], exprVars: Map[Identifier, T], + condTree: Map[Identifier, Set[Identifier]], clauses: Seq[T], blockers: Map[T, Set[TemplateCallInfo[T]]], applications: Map[T, Set[App[T]]], @@ -618,7 +639,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage for (matchers <- matchQuorums) { val axiom = new Axiom(start, blocker, guardVar, quantified, - matchers, allMatchers, condVars, exprVars, + matchers, allMatchers, condVars, exprVars, condTree, clauses, blockers, applications, lambdas ) @@ -638,7 +659,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage m <- matchers sm = m.substitute(substituter) if !instCtx.corresponding(sm).exists(_._2.args == sm.args) - } instantiation ++= instCtx.instantiate(Set(trueT), sm)(quantifications.toSeq : _*) + } instantiation ++= instCtx.instantiate(Set.empty, sm)(quantifications.toSeq : _*) instantiation } @@ -661,6 +682,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage template.matchers map { case (b, ms) => substituter(b) -> ms.map(_.substitute(substituter)) }, template.condVars, template.exprVars, + template.condTree, template.clauses map substituter, template.blockers map { case (b, fis) => substituter(b) -> fis.map(fi => fi.copy(args = fi.args.map(substituter))) @@ -697,7 +719,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage (_, ms) <- template.matchers; m <- ms sm = m.substitute(substituter) if !instCtx.corresponding(sm).exists(_._2.args == sm.args) - } instantiation ++= instCtx.instantiate(Set(trueT), sm)(quantifications.toSeq : _*) + } instantiation ++= instCtx.instantiate(Set.empty, sm)(quantifications.toSeq : _*) instantiation } diff --git a/src/main/scala/leon/solvers/templates/TemplateGenerator.scala b/src/main/scala/leon/solvers/templates/TemplateGenerator.scala index 7a3df85ff..3ec1f060f 100644 --- a/src/main/scala/leon/solvers/templates/TemplateGenerator.scala +++ b/src/main/scala/leon/solvers/templates/TemplateGenerator.scala @@ -13,11 +13,24 @@ import purescala.Definitions._ import purescala.Constructors._ import purescala.Quantification._ +import Instantiation._ + class TemplateGenerator[T](val encoder: TemplateEncoder[T], val assumePreHolds: Boolean) { private var cache = Map[TypedFunDef, FunctionTemplate[T]]() private var cacheExpr = Map[Expr, FunctionTemplate[T]]() + private type Clauses = ( + Map[Identifier,T], + Map[Identifier,T], + Map[Identifier, Set[Identifier]], + Map[Identifier, Seq[Expr]], + Seq[LambdaTemplate[T]], + Seq[QuantificationTemplate[T]] + ) + + private def emptyClauses: Clauses = (Map.empty, Map.empty, Map.empty, Map.empty, Seq.empty, Seq.empty) + val manager = new QuantificationManager[T](encoder) def mkTemplate(body: Expr): FunctionTemplate[T] = { @@ -71,16 +84,14 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], val substMap : Map[Identifier, T] = arguments.toMap + pathVar - val (bodyConds, bodyExprs, bodyGuarded, bodyLambdas, bodyQuantifications) = if (isRealFunDef) { - invocationEqualsBody.map(expr => mkClauses(start, expr, substMap)).getOrElse { - (Map[Identifier,T](), Map[Identifier,T](), Map[Identifier,Seq[Expr]](), Seq[LambdaTemplate[T]](), Seq[QuantificationTemplate[T]]()) - } + val (bodyConds, bodyExprs, bodyTree, bodyGuarded, bodyLambdas, bodyQuantifications) = if (isRealFunDef) { + invocationEqualsBody.map(expr => mkClauses(start, expr, substMap)).getOrElse(emptyClauses) } else { mkClauses(start, lambdaBody.get, substMap) } // Now the postcondition. - val (condVars, exprVars, guardedExprs, lambdas, quantifications) = tfd.postcondition match { + val (condVars, exprVars, condTree, guardedExprs, lambdas, quantifications) = tfd.postcondition match { case Some(post) => val newPost : Expr = application(matchToIfThenElse(post), Seq(invocation)) @@ -95,19 +106,15 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], newPost } - val (postConds, postExprs, postGuarded, postLambdas, postQuantifications) = mkClauses(start, postHolds, substMap) - val allGuarded = (bodyGuarded.keys ++ postGuarded.keys).map { k => - k -> (bodyGuarded.getOrElse(k, Seq.empty) ++ postGuarded.getOrElse(k, Seq.empty)) - }.toMap - - (bodyConds ++ postConds, bodyExprs ++ postExprs, allGuarded, bodyLambdas ++ postLambdas, bodyQuantifications ++ postQuantifications) + val (postConds, postExprs, postTree, postGuarded, postLambdas, postQuantifications) = mkClauses(start, postHolds, substMap) + (bodyConds ++ postConds, bodyExprs ++ postExprs, bodyTree merge postTree, bodyGuarded merge postGuarded, bodyLambdas ++ postLambdas, bodyQuantifications ++ postQuantifications) case None => - (bodyConds, bodyExprs, bodyGuarded, bodyLambdas, bodyQuantifications) + (bodyConds, bodyExprs, bodyTree, bodyGuarded, bodyLambdas, bodyQuantifications) } val template = FunctionTemplate(tfd, encoder, manager, - pathVar, arguments, condVars, exprVars, guardedExprs, quantifications, lambdas, isRealFunDef) + pathVar, arguments, condVars, exprVars, condTree, guardedExprs, quantifications, lambdas, isRealFunDef) cache += tfd -> template template } @@ -173,11 +180,15 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], (quantified, flatConj) } - def mkClauses(pathVar: Identifier, expr: Expr, substMap: Map[Identifier, T]): - (Map[Identifier,T], Map[Identifier,T], Map[Identifier, Seq[Expr]], Seq[LambdaTemplate[T]], Seq[QuantificationTemplate[T]]) = { + def mkClauses(pathVar: Identifier, expr: Expr, substMap: Map[Identifier, T]): Clauses = { var condVars = Map[Identifier, T]() - @inline def storeCond(id: Identifier) : Unit = condVars += id -> encoder.encodeId(id) + var condTree = Map[Identifier, Set[Identifier]](pathVar -> Set.empty).withDefaultValue(Set.empty) + def storeCond(pathVar: Identifier, id: Identifier) : Unit = { + condVars += id -> encoder.encodeId(id) + condTree += pathVar -> (condTree(pathVar) + id) + } + @inline def encodedCond(id: Identifier) : T = substMap.getOrElse(id, condVars(id)) var exprVars = Map[Identifier, T]() @@ -276,8 +287,8 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], val newBool2 : Identifier = FreshIdentifier("b", BooleanType, true) val newExpr : Identifier = FreshIdentifier("e", i.getType, true) - storeCond(newBool1) - storeCond(newBool2) + storeCond(pathVar, newBool1) + storeCond(pathVar, newBool2) storeExpr(newExpr) @@ -319,12 +330,12 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], val localSubst: Map[Identifier, T] = substMap ++ condVars ++ exprVars ++ lambdaVars val clauseSubst: Map[Identifier, T] = localSubst ++ (idArgs zip trArgs) - val (lambdaConds, lambdaExprs, lambdaGuarded, lambdaTemplates, lambdaQuants) = mkClauses(pathVar, clause, clauseSubst) + val (lambdaConds, lambdaExprs, lambdaTree, lambdaGuarded, lambdaTemplates, lambdaQuants) = mkClauses(pathVar, clause, clauseSubst) val ids: (Identifier, T) = lid -> storeLambda(lid) val dependencies: Map[Identifier, T] = variablesOf(l).map(id => id -> localSubst(id)).toMap val template = LambdaTemplate(ids, encoder, manager, pathVar -> encodedCond(pathVar), - idArgs zip trArgs, lambdaConds, lambdaExprs, lambdaGuarded, lambdaQuants, lambdaTemplates, localSubst, dependencies, l) + idArgs zip trArgs, lambdaConds, lambdaExprs, lambdaTree, lambdaGuarded, lambdaQuants, lambdaTemplates, localSubst, dependencies, l) registerLambda(template) Variable(lid) @@ -350,14 +361,14 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], val qs: (Identifier, T) = q -> encoder.encodeId(q) val localSubst: Map[Identifier, T] = substMap ++ condVars ++ exprVars ++ lambdaVars val clauseSubst: Map[Identifier, T] = localSubst ++ (idQuantifiers zip trQuantifiers) - val (qConds, qExprs, qGuarded, qTemplates, qQuants) = mkClauses(pathVar, clause, clauseSubst) + val (qConds, qExprs, qTree, qGuarded, qTemplates, qQuants) = mkClauses(pathVar, clause, clauseSubst) assert(qQuants.isEmpty, "Unhandled nested quantification in "+clause) val binder = Equals(Variable(q), And(Variable(q2), Variable(inst))) val allQGuarded = qGuarded + (pathVar -> (binder +: qGuarded.getOrElse(pathVar, Seq.empty))) val template = QuantificationTemplate[T](encoder, manager, pathVar -> encodedCond(pathVar), - qs, q2, inst, guard, idQuantifiers zip trQuantifiers, qConds, qExprs, allQGuarded, qTemplates, localSubst) + qs, q2, inst, guard, idQuantifiers zip trQuantifiers, qConds, qExprs, qTree, allQGuarded, qTemplates, localSubst) registerQuantification(template) Variable(q) } @@ -371,7 +382,7 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], val p = rec(pathVar, expr) storeGuarded(pathVar, p) - (condVars, exprVars, guardedExprs, lambdas, quantifications) + (condVars, exprVars, condTree, guardedExprs, lambdas, quantifications) } } diff --git a/src/main/scala/leon/solvers/templates/TemplateInfo.scala b/src/main/scala/leon/solvers/templates/TemplateInfo.scala index 977aeb571..033f15dd6 100644 --- a/src/main/scala/leon/solvers/templates/TemplateInfo.scala +++ b/src/main/scala/leon/solvers/templates/TemplateInfo.scala @@ -14,6 +14,6 @@ case class TemplateCallInfo[T](tfd: TypedFunDef, args: Seq[T]) { case class TemplateAppInfo[T](template: LambdaTemplate[T], equals: T, args: Seq[T]) { override def toString = { - template.ids._1 + "|" + equals + args.mkString("(", ",", ")") + template.ids._2 + "|" + equals + args.mkString("(", ",", ")") } } diff --git a/src/main/scala/leon/solvers/templates/Templates.scala b/src/main/scala/leon/solvers/templates/TemplateManager.scala similarity index 72% rename from src/main/scala/leon/solvers/templates/Templates.scala rename to src/main/scala/leon/solvers/templates/TemplateManager.scala index 5e7302c54..07019aefd 100644 --- a/src/main/scala/leon/solvers/templates/Templates.scala +++ b/src/main/scala/leon/solvers/templates/TemplateManager.scala @@ -12,9 +12,9 @@ import purescala.Extractors._ import purescala.ExprOps._ import purescala.Types._ -case class App[T](caller: T, tpe: FunctionType, args: Seq[T]) { - override def toString = "(" + caller + " : " + tpe + ")" + args.mkString("(", ",", ")") -} +import utils._ + +import scala.collection.generic.CanBuildFrom object Instantiation { type Clauses[T] = Seq[T] @@ -24,12 +24,18 @@ object Instantiation { def empty[T] = (Seq.empty[T], Map.empty[T, Set[TemplateCallInfo[T]]], Map.empty[(T, App[T]), Set[TemplateAppInfo[T]]]) - implicit class MapWrapper[A,B](map: Map[A,Set[B]]) { + implicit class MapSetWrapper[A,B](map: Map[A,Set[B]]) { def merge(that: Map[A,Set[B]]): Map[A,Set[B]] = (map.keys ++ that.keys).map { k => k -> (map.getOrElse(k, Set.empty) ++ that.getOrElse(k, Set.empty)) }.toMap } + implicit class MapSeqWrapper[A,B](map: Map[A,Seq[B]]) { + def merge(that: Map[A,Seq[B]]): Map[A,Seq[B]] = (map.keys ++ that.keys).map { k => + k -> (map.getOrElse(k, Seq.empty) ++ that.getOrElse(k, Seq.empty)) + }.toMap + } + implicit class InstantiationWrapper[T](i: Instantiation[T]) { def ++(that: Instantiation[T]): Instantiation[T] = { val (thisClauses, thisBlockers, thisApps) = i @@ -59,6 +65,7 @@ trait Template[T] { self => val args : Seq[T] val condVars : Map[Identifier, T] val exprVars : Map[Identifier, T] + val condTree : Map[Identifier, Set[Identifier]] val clauses : Seq[T] val blockers : Map[T, Set[TemplateCallInfo[T]]] val applications : Map[T, Set[App[T]]] @@ -73,7 +80,8 @@ trait Template[T] { self => val baseSubstMap : Map[T,T] = substCache.get(args) match { case Some(subst) => subst case None => - val subst = (condVars ++ exprVars).map { case (id, idT) => idT -> encoder.encodeId(id) } ++ + val subst = exprVars.map { case (id, idT) => idT -> encoder.encodeId(id) } ++ + manager.freshConds(aVar, condVars, condTree) ++ (this.args zip args) substCache += args -> subst subst @@ -270,6 +278,7 @@ object FunctionTemplate { arguments: Seq[(Identifier, T)], condVars: Map[Identifier, T], exprVars: Map[Identifier, T], + condTree: Map[Identifier, Set[Identifier]], guardedExprs: Map[Identifier, Seq[Expr]], quantifications: Seq[QuantificationTemplate[T]], lambdas: Seq[LambdaTemplate[T]], @@ -295,6 +304,7 @@ object FunctionTemplate { arguments.map(_._2), condVars, exprVars, + condTree, clauses, blockers, applications, @@ -315,6 +325,7 @@ class FunctionTemplate[T] private( val args: Seq[T], val condVars: Map[Identifier, T], val exprVars: Map[Identifier, T], + val condTree: Map[Identifier, Set[Identifier]], val clauses: Seq[T], val blockers: Map[T, Set[TemplateCallInfo[T]]], val applications: Map[T, Set[App[T]]], @@ -333,152 +344,41 @@ class FunctionTemplate[T] private( } } -object LambdaTemplate { - - def apply[T]( - ids: (Identifier, T), - encoder: TemplateEncoder[T], - manager: QuantificationManager[T], - pathVar: (Identifier, T), - arguments: Seq[(Identifier, T)], - condVars: Map[Identifier, T], - exprVars: Map[Identifier, T], - guardedExprs: Map[Identifier, Seq[Expr]], - quantifications: Seq[QuantificationTemplate[T]], - lambdas: Seq[LambdaTemplate[T]], - baseSubstMap: Map[Identifier, T], - dependencies: Map[Identifier, T], - lambda: Lambda - ) : LambdaTemplate[T] = { - - val id = ids._2 - val tpe = ids._1.getType.asInstanceOf[FunctionType] - val (clauses, blockers, applications, matchers, templateString) = - Template.encode(encoder, pathVar, arguments, condVars, exprVars, guardedExprs, lambdas, - substMap = baseSubstMap + ids, optApp = Some(id -> tpe)) - - val lambdaString : () => String = () => { - "Template for lambda " + ids._1 + ": " + lambda + " is :\n" + templateString() - } - - val (structuralLambda, structSubst) = normalizeStructure(lambda) - val keyDeps = dependencies.map { case (id, idT) => structSubst(id) -> idT } - val key = structuralLambda.asInstanceOf[Lambda] - - new LambdaTemplate[T]( - ids, - encoder, - manager, - pathVar._2, - arguments, - condVars, - exprVars, - clauses, - blockers, - applications, - quantifications, - matchers, - lambdas, - keyDeps, - key, - lambdaString - ) - } -} - -class LambdaTemplate[T] private ( - val ids: (Identifier, T), - val encoder: TemplateEncoder[T], - val manager: QuantificationManager[T], - val start: T, - val arguments: Seq[(Identifier, T)], - val condVars: Map[Identifier, T], - val exprVars: Map[Identifier, T], - val clauses: Seq[T], - val blockers: Map[T, Set[TemplateCallInfo[T]]], - val applications: Map[T, Set[App[T]]], - val quantifications: Seq[QuantificationTemplate[T]], - val matchers: Map[T, Set[Matcher[T]]], - val lambdas: Seq[LambdaTemplate[T]], - private[templates] val dependencies: Map[Identifier, T], - private[templates] val structuralKey: Lambda, - stringRepr: () => String) extends Template[T] { +class TemplateManager[T](protected[templates] val encoder: TemplateEncoder[T]) extends IncrementalState { + private val condImplies = new IncrementalMap[T, Set[T]].withDefaultValue(Set.empty) + private val condImplied = new IncrementalMap[T, Set[T]].withDefaultValue(Set.empty) - val args = arguments.map(_._2) - val tpe = ids._1.getType.asInstanceOf[FunctionType] + protected def incrementals: List[IncrementalState] = List(condImplies, condImplied) - def substitute(substituter: T => T): LambdaTemplate[T] = { - val newStart = substituter(start) - val newClauses = clauses.map(substituter) - val newBlockers = blockers.map { case (b, fis) => - val bp = if (b == start) newStart else b - bp -> fis.map(fi => fi.copy(args = fi.args.map(substituter))) - } + def clear(): Unit = incrementals.foreach(_.clear()) + def reset(): Unit = incrementals.foreach(_.reset()) + def push(): Unit = incrementals.foreach(_.push()) + def pop(): Unit = incrementals.foreach(_.pop()) - val newApplications = applications.map { case (b, fas) => - val bp = if (b == start) newStart else b - bp -> fas.map(fa => fa.copy(caller = substituter(fa.caller), args = fa.args.map(substituter))) - } + def freshConds(path: T, condVars: Map[Identifier, T], tree: Map[Identifier, Set[Identifier]]): Map[T, T] = { + val subst = condVars.map { case (id, idT) => idT -> encoder.encodeId(id) } + val pathVar = tree.keys.filter(id => !condVars.isDefinedAt(id)).head + val mapping = condVars.mapValues(subst) + (pathVar -> path) - val newQuantifications = quantifications.map(_.substitute(substituter)) - - val newMatchers = matchers.map { case (b, ms) => - val bp = if (b == start) newStart else b - bp -> ms.map(_.substitute(substituter)) + for ((parent, children) <- tree; ep = mapping(parent); child <- children) { + val ec = mapping(child) + condImplies += ep -> (condImplies(ep) + ec) + condImplied += ec -> (condImplied(ec) + ep) } - val newLambdas = lambdas.map(_.substitute(substituter)) - - val newDependencies = dependencies.map(p => p._1 -> substituter(p._2)) - - new LambdaTemplate[T]( - ids._1 -> substituter(ids._2), - encoder, - manager, - newStart, - arguments, - condVars, - exprVars, - newClauses, - newBlockers, - newApplications, - newQuantifications, - newMatchers, - newLambdas, - newDependencies, - structuralKey, - stringRepr - ) + subst } - private lazy val str : String = stringRepr() - override def toString : String = str - - def contextEquality(that: LambdaTemplate[T]) : Option[Seq[T]] = { - if (structuralKey != that.structuralKey) { - None // can't be equal - } else if (dependencies.isEmpty) { - Some(Seq.empty) // must be equal - } else { - def rec(e1: Expr, e2: Expr): Seq[T] = (e1,e2) match { - case (Variable(id1), Variable(id2)) => - if (dependencies.isDefinedAt(id1)) { - Seq(encoder.mkEquals(dependencies(id1), that.dependencies(id2))) - } else { - Seq.empty - } - - case (Operator(es1, _), Operator(es2, _)) => - (es1 zip es2).flatMap(p => rec(p._1, p._2)) - - case _ => Seq.empty - } - - Some(rec(structuralKey, that.structuralKey)) + def blocker(b: T): Unit = condImplies += (b -> Set.empty) + def isBlocker(b: T): Boolean = condImplies.isDefinedAt(b) || condImplied.isDefinedAt(b) + + def implies(b1: T, b2: T): Unit = implies(b1, Set(b2)) + def implies(b1: T, b2s: Set[T]): Unit = { + val fb2s = b2s.filter(_ != b1) + condImplies += b1 -> (condImplies(b1) ++ fb2s) + for (b2 <- fb2s) { + condImplied += b2 -> (condImplies(b2) + b1) } } - override def instantiate(substMap: Map[T, T]): Instantiation[T] = { - super.instantiate(substMap) ++ manager.instantiateAxiom(this, substMap) - } } diff --git a/src/main/scala/leon/solvers/templates/UnrollingBank.scala b/src/main/scala/leon/solvers/templates/UnrollingBank.scala index ddfb22b0b..e6e5cd223 100644 --- a/src/main/scala/leon/solvers/templates/UnrollingBank.scala +++ b/src/main/scala/leon/solvers/templates/UnrollingBank.scala @@ -19,7 +19,8 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat private val manager = templateGenerator.manager // Function instantiations have their own defblocker - private val defBlockers = new IncrementalMap[TemplateCallInfo[T], T]() + private val defBlockers = new IncrementalMap[TemplateCallInfo[T], T]() + private val lambdaBlockers = new IncrementalMap[TemplateAppInfo[T], T]() // Keep which function invocation is guarded by which guard, // also specify the generation of the blocker. @@ -32,6 +33,7 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat def push() { callInfos.push() defBlockers.push() + lambdaBlockers.push() appInfos.push() appBlockers.push() blockerToApps.push() @@ -41,6 +43,7 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat def pop() { callInfos.pop() defBlockers.pop() + lambdaBlockers.pop() appInfos.pop() appBlockers.pop() blockerToApps.pop() @@ -50,6 +53,7 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat def clear() { callInfos.clear() defBlockers.clear() + lambdaBlockers.clear() appInfos.clear() appBlockers.clear() blockerToApps.clear() @@ -59,6 +63,7 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat def reset() { callInfos.reset() defBlockers.reset() + lambdaBlockers.reset() appInfos.reset() appBlockers.reset() blockerToApps.clear() @@ -257,6 +262,7 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat // we need to define this defBlocker and link it to definition val defBlocker = encoder.encodeId(FreshIdentifier("d", BooleanType)) defBlockers += info -> defBlocker + manager.implies(id, defBlocker) val template = templateGenerator.mkTemplate(tfd) //reporter.debug(template) @@ -279,7 +285,7 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat // We connect it to the defBlocker: blocker => defBlocker if (defBlocker != id) { - newCls ++= List(encoder.mkImplies(id, defBlocker)) + newCls :+= encoder.mkImplies(id, defBlocker) } reporter.debug("Unrolling behind "+info+" ("+newCls.size+")") @@ -293,22 +299,32 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat for ((app @ (b, _), (gen, _, _, _, infos)) <- thisAppInfos; info @ TemplateAppInfo(template, equals, args) <- infos) { var newCls = Seq.empty[T] - val nb = encoder.encodeId(FreshIdentifier("b", BooleanType, true)) - newCls :+= encoder.mkEquals(nb, encoder.mkAnd(equals, b)) + val lambdaBlocker = lambdaBlockers.get(info) match { + case Some(lambdaBlocker) => lambdaBlocker - val (newExprs, callBlocks, appBlocks) = template.instantiate(nb, args) - val blockExprs = freshAppBlocks(appBlocks.keys) + case None => + val lambdaBlocker = encoder.encodeId(FreshIdentifier("d", BooleanType)) + lambdaBlockers += info -> lambdaBlocker + manager.implies(b, lambdaBlocker) - for ((b, newInfos) <- callBlocks) { - registerCallBlocker(nextGeneration(gen), b, newInfos) - } + val (newExprs, callBlocks, appBlocks) = template.instantiate(lambdaBlocker, args) + val blockExprs = freshAppBlocks(appBlocks.keys) + + for ((b, newInfos) <- callBlocks) { + registerCallBlocker(nextGeneration(gen), b, newInfos) + } + + for ((newApp, newInfos) <- appBlocks) { + registerAppBlocker(nextGeneration(gen), newApp, newInfos) + } - for ((newApp, newInfos) <- appBlocks) { - registerAppBlocker(nextGeneration(gen), newApp, newInfos) + newCls ++= newExprs + newCls ++= blockExprs + lambdaBlocker } - newCls ++= newExprs - newCls ++= blockExprs + val enabler = if (equals == manager.trueT) b else encoder.mkAnd(equals, b) + newCls :+= encoder.mkImplies(enabler, lambdaBlocker) reporter.debug("Unrolling behind "+info+" ("+newCls.size+")") for (cl <- newCls) { diff --git a/src/main/scala/leon/termination/ChainBuilder.scala b/src/main/scala/leon/termination/ChainBuilder.scala index f2410fd82..0c6106127 100644 --- a/src/main/scala/leon/termination/ChainBuilder.scala +++ b/src/main/scala/leon/termination/ChainBuilder.scala @@ -156,7 +156,7 @@ trait ChainBuilder extends RelationBuilder { self: Strengthener with RelationCom if (!checker.program.callGraph.transitivelyCalls(fd, funDef)) { Set.empty[FunDef] -> Set.empty[Chain] } else if (fd == funDef) { - Set(fd) -> Set(Chain(chain.reverse)) + Set.empty[FunDef] -> Set(Chain(chain.reverse)) } else if (seen(fd)) { Set(fd) -> Set.empty[Chain] } else { diff --git a/src/main/scala/leon/termination/ChainProcessor.scala b/src/main/scala/leon/termination/ChainProcessor.scala index af50911bd..799e51db6 100644 --- a/src/main/scala/leon/termination/ChainProcessor.scala +++ b/src/main/scala/leon/termination/ChainProcessor.scala @@ -33,7 +33,10 @@ class ChainProcessor( reporter.debug("-+> Multiple looping points, can't build chain proof") None } else { - val funDef = loopPoints.head + val funDef = loopPoints.headOption getOrElse { + chainsMap.collectFirst { case (fd, (fds, chains)) if chains.nonEmpty => fd }.get + } + val chains = chainsMap(funDef)._2 val e1 = tupleWrap(funDef.params.map(_.toVariable)) -- GitLab