diff --git a/src/main/scala/inox/solvers/unrolling/QuantificationTemplates.scala b/src/main/scala/inox/solvers/unrolling/QuantificationTemplates.scala index 7c1bf7ad13756f11720c27718975b4c1d357e6b2..89d522e7f2f4f643b1730b1fcf92a22bf47b5e1b 100644 --- a/src/main/scala/inox/solvers/unrolling/QuantificationTemplates.scala +++ b/src/main/scala/inox/solvers/unrolling/QuantificationTemplates.scala @@ -248,7 +248,7 @@ trait QuantificationTemplates { self: Templates => def unroll: Clauses = { val imClauses = new scala.collection.mutable.ListBuffer[Encoded] for (e @ (gen, bs, m) <- ignoredMatchers.toSeq if gen <= currentGeneration) { - imClauses ++= instantiateMatcher(bs, m) + imClauses ++= instantiateMatcher(bs, m, defer = true) ignoredMatchers -= e } @@ -263,7 +263,7 @@ trait QuantificationTemplates { self: Templates => ignoredSubsts += q -> keep for ((_, bs, subst) <- release) { - suClauses ++= q.instantiateSubst(bs, subst) + suClauses ++= q.instantiateSubst(bs, subst, defer = true) } } @@ -293,18 +293,18 @@ trait QuantificationTemplates { self: Templates => } def instantiateMatcher(blocker: Encoded, matcher: Matcher): Clauses = { - instantiateMatcher(Set(blocker), matcher) + instantiateMatcher(Set(blocker), matcher, false) } @inline - private def instantiateMatcher(blockers: Set[Encoded], matcher: Matcher): Clauses = { + private def instantiateMatcher(blockers: Set[Encoded], matcher: Matcher, defer: Boolean = false): Clauses = { val relevantBlockers = blockerPath(blockers) if (handledMatchers(relevantBlockers -> matcher)) { Seq.empty } else { handledMatchers += relevantBlockers -> matcher - quantifications.flatMap(_.instantiate(relevantBlockers, matcher)) + quantifications.flatMap(_.instantiate(relevantBlockers, matcher, defer)) } } @@ -509,7 +509,29 @@ trait QuantificationTemplates { self: Templates => def clear(): Unit = for (gs <- grounds.values) gs.clear() def reset(): Unit = for (gs <- grounds.values) gs.reset() - def instantiate(bs: Set[Encoded], m: Matcher): Clauses = { + private val optimizationQuorums: Seq[Set[Matcher]] = { + def matchersOf(m: Matcher): Set[Matcher] = m.args.flatMap { + case Right(m) => matchersOf(m) + m + case _ => Set.empty[Matcher] + }.toSet + + def quantifiersOf(m: Matcher): Set[Encoded] = + (matchersOf(m) + m).flatMap(_.args.collect { case Left(q) if quantified(q) => q }) + + val allMatchers = matchers.flatMap(_._2).toList + val allQuorums = allMatchers.toSet.subsets + .filter(ms => ms.flatMap(quantifiersOf) == quantified) + .filterNot(ms => allMatchers.exists { m => + !ms(m) && { + val common = ms & matchersOf(m) + common.nonEmpty && + (quantifiersOf(m) -- common.flatMap(quantifiersOf)).nonEmpty + } + }) + allQuorums.foldLeft(Seq[Set[Matcher]]())((qs,q) => if (qs.exists(_ subsetOf q)) qs else qs :+ q) + } + + def instantiate(bs: Set[Encoded], m: Matcher, defer: Boolean = false): Clauses = { generationCounter += 1 val gen = generationCounter @@ -531,7 +553,7 @@ trait QuantificationTemplates { self: Templates => (key, i) <- constraints; perfect <- correspond(matcherKey(m), key) if !grounds(q)(bs, m.args(i))) { - val initGens: Set[Int] = if (perfect) Set.empty else Set(gen) + val initGens: Set[Int] = if (perfect && !defer) Set.empty else Set(gen) val newMappings = (quantified - q).foldLeft(Seq((bs, Map(q -> m.args(i)), initGens, 0))) { case (maps, oq) => for { (bs, map, gens, c) <- maps @@ -546,7 +568,21 @@ trait QuantificationTemplates { self: Templates => /* @nv: I tried some smarter cost computations here but it turns out that the * overhead needed to determine when to optimize exceeds the benefits */ mappings ++= newMappings.map { case (bs, map, gens, c) => - (bs, map, c + (3 * map.values.collect { case Right(m) => totalDepth(m) }.sum)) + val substituter = mkSubstituter(map.mapValues(_.encoded)) + val msubst = map.collect { case (q, Right(m)) => q -> m } + val isOpt = optimizationQuorums.exists { ms => + ms.forall(m => handledMatchers.contains(m.substitute(substituter, msubst))) + } + + val cost = if (initGens.nonEmpty) { + 1 + 3 * map.values.collect { case Right(m) => totalDepth(m) }.sum + } else if (!isOpt) { + 3 + } else { + 0 + } + + (bs, map, c + cost) } // register ground instantiation for future instantiations @@ -600,14 +636,14 @@ trait QuantificationTemplates { self: Templates => val gen = currentGeneration + delay + (if (getPolarity.isEmpty) 2 else 0) ignoredSubsts += this -> (ignoredSubsts.getOrElse(this, Set.empty) + ((gen, bs, subst))) } else { - instantiation ++= instantiateSubst(bs, subst) + instantiation ++= instantiateSubst(bs, subst, defer = false) } } instantiation.toSeq } - def instantiateSubst(bs: Set[Encoded], subst: Map[Encoded, Arg]): Clauses = { + def instantiateSubst(bs: Set[Encoded], subst: Map[Encoded, Arg], defer: Boolean = false): Clauses = { handledSubsts += this -> (handledSubsts.getOrElse(this, Set.empty) + (bs -> subst)) val instantiation = new scala.collection.mutable.ListBuffer[Encoded] @@ -630,14 +666,11 @@ trait QuantificationTemplates { self: Templates => val sb = bs ++ (if (b == guard) Set.empty else Set(substituter(b))) val sm = m.substitute(substituter, msubst) - def abs(i: Int): Int = if (i < 0) -i else i - val totalDelay = 2 * (abs(totalDepth(sm) - totalDepth(m)) + (if (b == guard) 0 else 1)) - - if (totalDelay > 0) { - val gen = currentGeneration + totalDelay + if (b != guard) { + val gen = currentGeneration + 1 ignoredMatchers += ((gen, sb, sm)) } else { - instantiation ++= instantiateMatcher(sb, sm) + instantiation ++= instantiateMatcher(sb, sm, defer = defer) } } @@ -883,7 +916,7 @@ trait QuantificationTemplates { self: Templates => val freshSubst = mkSubstituter(template.quantifiers.map(p => p._2 -> encodeSymbol(p._1)).toMap) for ((b,ms) <- template.matchers; m <- ms) { - clauses ++= instantiateMatcher(Set.empty[Encoded], m) + clauses ++= instantiateMatcher(Set.empty[Encoded], m, false) // it is very rare that such instantiations are actually required, so we defer them val gen = currentGeneration + 20 ignoredMatchers += ((gen, Set(b), m.substitute(freshSubst, Map.empty))) @@ -996,7 +1029,7 @@ trait QuantificationTemplates { self: Templates => val valuesP = values.map(v => v -> encodeSymbol(v)) val exprT = mkEncoder(elemsP.toMap ++ valuesP + guardP)(expr) - val disjunction = handledSubsts(q) match { + val disjunction = handledSubsts.getOrElse(q, Set.empty) match { case set if set.isEmpty => mkEncoder(Map.empty)(BooleanLiteral(false)) case set => mkOr(set.toSeq.map { case (enablers, subst) => val b = if (enablers.isEmpty) trueT else mkAnd(enablers.toSeq : _*)