diff --git a/src/main/scala/inox/ast/SymbolOps.scala b/src/main/scala/inox/ast/SymbolOps.scala index 3a51cd78aa9dcbe46da38771dff17a2418920144..c0c7b70bbdef015524dd31352f0dfef9e9cda972 100644 --- a/src/main/scala/inox/ast/SymbolOps.scala +++ b/src/main/scala/inox/ast/SymbolOps.scala @@ -124,7 +124,7 @@ trait SymbolOps { self: TypeOps => * unrolling solver to provide geenral equality checks between functions even when * they have complex closures. */ - def normalizeStructure(args: Seq[ValDef], expr: Expr, onlySimple: Boolean = true): + def normalizeStructure(args: Seq[ValDef], expr: Expr, preserveApps: Boolean = true): (Seq[ValDef], Expr, Map[Variable, Expr]) = synchronized { val subst: MutableMap[Variable, Expr] = MutableMap.empty @@ -160,6 +160,33 @@ trait SymbolOps { self: TypeOps => } } + def extractMatcher(e: Expr): (Seq[Expr], Seq[Expr] => Expr) = e match { + case Application(caller, args) => + val (es, recons) = extractMatcher(caller) + (args ++ es, es => { + val (newArgs, newEs) = es.splitAt(args.size) + Application(recons(newEs), newArgs) + }) + + case ADTSelector(adt, id) => + val (es, recons) = extractMatcher(adt) + (es, es => ADTSelector(recons(es), id)) + + case ElementOfSet(elem, set) => + val (es, recons) = extractMatcher(set) + (elem +: es, { case elem +: es => ElementOfSet(elem, recons(es)) }) + + case MultiplicityInBag(elem, bag) => + val (es, recons) = extractMatcher(bag) + (elem +: es, { case elem +: es => MultiplicityInBag(elem, recons(es)) }) + + case MapApply(map, key) => + val (es, recons) = extractMatcher(map) + (key +: es, { case key +: es => MapApply(recons(es), key) }) + + case _ => (Seq(e), es => es.head) + } + def outer(vars: Set[Variable], body: Expr): Expr = { // this registers the argument images into subst val tvars = vars map (v => v.copy(id = transformId(v.id, v.tpe))) @@ -236,9 +263,16 @@ trait SymbolOps { self: TypeOps => case Variable(id, tpe) => Variable(transformId(id, tpe), tpe) + case (_: Application) | (_: MultiplicityInBag) | (_: ElementOfSet) | (_: MapApply) if ( + !isLocal(e, path) && + preserveApps + ) => + val (es, recons) = extractMatcher(e) + val newEs = es.map(rec(_, path)) + recons(newEs) + case Let(vd, e, b) if ( isLocal(e, path) && - (!onlySimple || isSimple(e)) && ((isSatisfiable(path) contains true) || isPure(e)) ) => val newId = getId(e) @@ -246,7 +280,6 @@ trait SymbolOps { self: TypeOps => case expr if ( isLocal(expr, path) && - (!onlySimple || isSimple(expr)) && ((isSatisfiable(path) contains true) || isPure(expr)) ) => Variable(getId(expr), expr.getType) @@ -282,14 +315,16 @@ trait SymbolOps { self: TypeOps => (bindings, newExpr, bodySubst) } - def normalizeStructure(lambda: Lambda): (Lambda, Map[Variable, Expr]) = { - val (args, body, subst) = normalizeStructure(lambda.args, lambda.body, onlySimple = false) - (Lambda(args, body), subst) - } - - def normalizeStructure(forall: Forall): (Forall, Map[Variable, Expr]) = { - val (args, body, subst) = normalizeStructure(forall.args, forall.body) - (Forall(args, body), subst) + def normalizeStructure(e: Expr): (Expr, Map[Variable, Expr]) = e match { + case lambda: Lambda => + val (args, body, subst) = normalizeStructure(lambda.args, lambda.body, preserveApps = false) + (Lambda(args, body), subst) + case forall: Forall => + val (args, body, subst) = normalizeStructure(forall.args, forall.body, preserveApps = true) + (Forall(args, body), subst) + case _ => + val (_, body, subst) = normalizeStructure(Seq.empty, e) + (body, subst) } /** Ensures the closure [[l]] can only be equal to some other closure if they share diff --git a/src/main/scala/inox/solvers/unrolling/LambdaTemplates.scala b/src/main/scala/inox/solvers/unrolling/LambdaTemplates.scala index 879305f93a9f1cb0db9e7d514c9399e6d9d5d0b5..6c3ea1cc3f510ebf06dbaa8ae84d884ac4f26922 100644 --- a/src/main/scala/inox/solvers/unrolling/LambdaTemplates.scala +++ b/src/main/scala/inox/solvers/unrolling/LambdaTemplates.scala @@ -62,7 +62,8 @@ trait LambdaTemplates { self: Templates => equations: Seq[Expr], lambdas: Seq[LambdaTemplate], quantifications: Seq[QuantificationTemplate], - structure: LambdaStructure, + structure: TemplateStructure, + closures: Set[Encoded], baseSubstMap: Map[Variable, Encoded], lambda: Lambda ) : LambdaTemplate = { @@ -82,122 +83,12 @@ trait LambdaTemplates { self: Templates => condVars, exprVars, condTree, clauses, blockers, applications, matchers, lambdas, quantifications, pointers, - structure, + structure, closures, lambda, lambdaString, false ) } } - /** Semi-template used for hardcore function equality - * - * Function equality, while unhandled in general, can be very useful for certain - * proofs that refer specifically to first-class functions. In order to support - * such proofs, flexible notions of equality on first-class functions are - * necessary. These are provided by [[LambdaStructure]] which, much like a - * [[Template]], will generate clauses that represent equality between two - * functions. - * - * To support complex cases of equality where closed portions of the first-class - * function rely on complex program features (function calls, introducing lambdas, - * foralls, etc.), we use a structure that resembles a [[Template]] that is - * instantiated when function equality is of interest. - * - * Note that lambda creation now introduces clauses to determine equality between - * closed portions (that are independent of the lambda arguments). - */ - class LambdaStructure private[unrolling] ( - /** The normalized lambda that is shared between all "equal" first-class functions. - * First-class function equality is conditionned on `lambda` equality. - * - * @see [[dependencies]] for the other component of equality between first-class functions - */ - val lambda: Lambda, - - /** The closed expressions (independent of the arguments to [[lambda]]) contained in - * the first-class function. Equality is conditioned on equality of `dependencies` - * (inside the solver). - * - * @see [[lambda]] for the other component of equality between first-class functions - */ - val dependencies: Seq[Encoded], - val pathVar: (Variable, Encoded), - - /** The set of closed variables that exist in the associated lambda. - * - * This set is necessary to determine whether other closures have been - * captured by this particular closure when deciding the order of - * lambda instantiations in [[Template.substitution]]. - * - * We also use this set when computing lambda instantiation orders to - * determine whether equality with free first-class functions is possible. - */ - val closures: Seq[Encoded], - val condVars: Map[Variable, Encoded], - val exprVars: Map[Variable, Encoded], - val condTree: Map[Variable, Set[Variable]], - val clauses: Clauses, - val blockers: Calls, - val applications: Apps, - val matchers: Matchers, - val lambdas: Seq[LambdaTemplate], - val quantifications: Seq[QuantificationTemplate], - val pointers: Map[Encoded, Encoded]) { - - def substitute(substituter: Encoded => Encoded, msubst: Map[Encoded, Matcher]) = new LambdaStructure( - lambda, - dependencies.map(substituter), - pathVar._1 -> substituter(pathVar._2), - closures.map(substituter), condVars, exprVars, condTree, - clauses.map(substituter), - blockers.map { case (b, fis) => substituter(b) -> fis.map(_.substitute(substituter, msubst)) }, - applications.map { case (b, fas) => substituter(b) -> fas.map(_.substitute(substituter, msubst)) }, - matchers.map { case (b, ms) => substituter(b) -> ms.map(_.substitute(substituter, msubst)) }, - lambdas.map(_.substitute(substituter, msubst)), - quantifications.map(_.substitute(substituter, msubst)), - pointers.map(p => substituter(p._1) -> substituter(p._2))) - - /** The [[key]] value (tuple of [[lambda]] and [[dependencies]]) is used - * to determine syntactic equality between lambdas. If the keys of two - * closures are equal, then they must necessarily be equal in every model. - * - * The [[instantiation]] consists of the clause set instantiation (in the - * sense of [[Template.instantiate]] that is required for [[dependencies]] - * to make sense in the solver (introduces blockers, quantifications, other - * lambdas, etc.) Since [[dependencies]] CHANGE during instantiation and - * [[key]] makes no sense without the associated instantiation, the implicit - * contract here is that whenever a new key appears during unfolding, its - * associated instantiation MUST be added to the set of instantiations - * managed by the solver. However, if an identical pre-existing key has - * already been found, then the associated instantiations must already appear - * in those handled by the solver. - */ - lazy val (key, instantiation, locals, instantiationSubst) = { - val (substMap, substInst) = Template.substitution(condVars, exprVars, condTree, - lambdas, quantifications, pointers, Map.empty, pathVar._2) - val tmplInst = Template.instantiate(clauses, blockers, applications, matchers, substMap) - val instantiation = substInst ++ tmplInst - - val substituter = mkSubstituter(substMap.mapValues(_.encoded)) - val deps = dependencies.map(substituter) - val key = (lambda, blockerPath(pathVar._2), deps) - - val sortedDeps = exprOps.variablesOf(lambda).toSeq.sortBy(_.id.uniqueName) - val locals = sortedDeps zip deps - (key, instantiation, locals, substMap.mapValues(_.encoded)) - } - - override def equals(that: Any): Boolean = that match { - case (struct: LambdaStructure) => key == struct.key - case _ => false - } - - override def hashCode: Int = key.hashCode - - def subsumes(that: LambdaStructure): Boolean = { - key._1 == that.key._1 && key._3 == that.key._3 && key._2.subsetOf(that.key._2) - } - } - class LambdaTemplate private ( val ids: (Variable, Encoded), val pathVar: (Variable, Encoded), @@ -212,7 +103,8 @@ trait LambdaTemplates { self: Templates => val lambdas: Seq[LambdaTemplate], val quantifications: Seq[QuantificationTemplate], val pointers: Map[Encoded, Encoded], - val structure: LambdaStructure, + val structure: TemplateStructure, + val closures: Set[Encoded], val lambda: Lambda, private[unrolling] val stringRepr: () => String, private val isConcrete: Boolean) extends Template { @@ -232,6 +124,7 @@ trait LambdaTemplates { self: Templates => quantifications.map(_.substitute(substituter, msubst)), pointers.map(p => substituter(p._1) -> substituter(p._2)), structure.substitute(substituter, msubst), + closures.map(substituter), lambda, stringRepr, isConcrete) /** This must be called right before returning the clauses in [[structure.instantiation]]! */ @@ -248,7 +141,7 @@ trait LambdaTemplates { self: Templates => lambdas.map(_.substitute(substituter, Map.empty)), quantifications.map(_.substitute(substituter, Map.empty)), pointers.map(p => substituter(p._1) -> substituter(p._2)), - structure, lambda, stringRepr, true) + structure, closures, lambda, stringRepr, true) } override def instantiate(blocker: Encoded, args: Seq[Arg]): Clauses = { @@ -447,7 +340,7 @@ trait LambdaTemplates { self: Templates => val equals = mkEquals(template.ids._2, that.ids._2) mkImplies( mkAnd(template.start, that.start), - if (template.structure.lambda == that.structure.lambda) { + if (template.structure.body == that.structure.body) { val pairs = template.structure.locals zip that.structure.locals val filtered = pairs.filter(p => p._1 != p._2) if (filtered.isEmpty) { @@ -475,7 +368,7 @@ trait LambdaTemplates { self: Templates => val blockerToApps = new IncrementalMap[Encoded, (Encoded, App)]() val byID = new IncrementalMap[Encoded, LambdaTemplate] - val byType = new IncrementalMap[FunctionType, Map[LambdaStructure, LambdaTemplate]].withDefaultValue(Map.empty) + val byType = new IncrementalMap[FunctionType, Map[TemplateStructure, LambdaTemplate]].withDefaultValue(Map.empty) val applications = new IncrementalMap[FunctionType, Set[(Encoded, App)]].withDefaultValue(Set.empty) val freeBlockers = new IncrementalMap[FunctionType, Set[(Encoded, Encoded)]].withDefaultValue(Set.empty) val freeFunctions = new IncrementalMap[FunctionType, Set[(Encoded, Encoded)]].withDefaultValue(Set.empty) diff --git a/src/main/scala/inox/solvers/unrolling/QuantificationTemplates.scala b/src/main/scala/inox/solvers/unrolling/QuantificationTemplates.scala index c9fe21763ea309ed859566f0a2fe9fb74079299c..fac5a2213075f8a7883f945bcd31f087f1629793 100644 --- a/src/main/scala/inox/solvers/unrolling/QuantificationTemplates.scala +++ b/src/main/scala/inox/solvers/unrolling/QuantificationTemplates.scala @@ -104,9 +104,10 @@ trait QuantificationTemplates { self: Templates => val lambdas: Seq[LambdaTemplate], val quantifications: Seq[QuantificationTemplate], val pointers: Map[Encoded, Encoded], - val key: (Encoded, Seq[ValDef], Expr, Seq[Encoded]), + val structure: TemplateStructure, val body: Expr, - stringRepr: () => String) { + stringRepr: () => String, + private val isConcrete: Boolean) { lazy val start = pathVar._2 lazy val mapping: Map[Variable, Encoded] = polarity match { @@ -124,21 +125,29 @@ trait QuantificationTemplates { self: Templates => lambdas.map(_.substitute(substituter, msubst)), quantifications.map(_.substitute(substituter, msubst)), pointers.map(p => substituter(p._1) -> substituter(p._2)), - (substituter(key._1), key._2, key._3, key._4.map(substituter)), - body, stringRepr) + structure.substitute(substituter, msubst), + body, stringRepr, isConcrete) + + def concretize: QuantificationTemplate = { + assert(!isConcrete, "Can't concretize concrete quantification template") + val substituter = mkSubstituter(structure.instantiationSubst) + new QuantificationTemplate( + pathVar, polarity, quantifiers, condVars, exprVars, condTree, + clauses map substituter, + blockers.map { case (b, fis) => b -> fis.map(_.substitute(substituter, Map.empty)) }, + applications.map { case (b, fas) => b -> fas.map(_.substitute(substituter, Map.empty)) }, + matchers.map { case (b, ms) => b -> ms.map(_.substitute(substituter, Map.empty)) }, + lambdas.map(_.substitute(substituter, Map.empty)), + quantifications.map(_.substitute(substituter, Map.empty)), + pointers.map(p => substituter(p._1) -> substituter(p._2)), + structure, body, stringRepr, true) + } private lazy val str : String = stringRepr() override def toString : String = str } object QuantificationTemplate { - def templateKey(quantifiers: Seq[ValDef], expr: Expr, substMap: Map[Variable, Encoded]): (Seq[ValDef], Expr, Seq[Encoded]) = { - val (vals, struct, deps) = normalizeStructure(quantifiers, expr) - val encoder = mkEncoder(substMap) _ - val depClosures = deps.toSeq.sortBy(_._1.id.uniqueName).map(p => encoder(p._2)) - (vals, struct, depClosures) - } - def apply( pathVar: (Variable, Encoded), optPol: Option[Boolean], @@ -151,6 +160,7 @@ trait QuantificationTemplates { self: Templates => equations: Seq[Expr], lambdas: Seq[LambdaTemplate], quantifications: Seq[QuantificationTemplate], + structure: TemplateStructure, substMap: Map[Variable, Encoded], proposition: Forall ): (Option[Variable], QuantificationTemplate) = { @@ -199,13 +209,10 @@ trait QuantificationTemplates { self: Templates => extraGuarded merge guardedExprs, extraEqs ++ equations, lambdas, quantifications, substMap = substMap ++ extraSubst) - val tk = templateKey(proposition.args, proposition.body, substMap) - val key = (pathVar._2, tk._1, tk._2, tk._3) - (optVar, new QuantificationTemplate( pathVar, polarity, quantifiers, condVars, exprVars, condTree, clauses, - blockers, applications, matchers, lambdas, quantifications, pointers, key, - proposition.body, () => "Template for " + proposition + " is :\n" + templateString())) + blockers, applications, matchers, lambdas, quantifications, pointers, structure, + proposition.body, () => "Template for " + proposition + " is :\n" + templateString(), false)) } } @@ -221,8 +228,8 @@ trait QuantificationTemplates { self: Templates => val handledSubsts = new IncrementalMap[Quantification, Set[(Set[Encoded], Map[Encoded,Arg])]] val ignoredGrounds = new IncrementalMap[Int, Set[Quantification]] - val lambdaAxioms = new IncrementalSet[LambdaStructure] - val templates = new IncrementalMap[(Encoded, Seq[ValDef], Expr, Seq[Encoded]), (QuantificationTemplate, Map[Encoded, Encoded])] + val lambdaAxioms = new IncrementalSet[TemplateStructure] + val templates = new IncrementalMap[TemplateStructure, (QuantificationTemplate, Encoded, Map[Encoded, Encoded])] val incrementals: Seq[IncrementalState] = Seq(quantifications, lambdaAxioms, templates, ignoredMatchers, handledMatchers, ignoredSubsts, handledSubsts, ignoredGrounds) @@ -355,10 +362,14 @@ trait QuantificationTemplates { self: Templates => private case class TypeKey(tt: Type) extends TypedKey(tt) private def matcherKey(key: Either[(Encoded, Type), TypedFunDef]): MatcherKey = key match { - case Right(tfd) => FunctionKey(tfd) - case Left((caller, ft: FunctionType)) if byID.isDefinedAt(caller) => LambdaKey(byID(caller).structure.lambda, ft) - case Left((caller, ft: FunctionType)) => CallerKey(caller, ft) - case Left((_, tpe)) => TypeKey(tpe) + case Right(tfd) => + FunctionKey(tfd) + case Left((caller, ft: FunctionType)) if byID.isDefinedAt(caller) => + LambdaKey(byID(caller).structure.body.asInstanceOf[Lambda], ft) + case Left((caller, ft: FunctionType)) => + CallerKey(caller, ft) + case Left((_, tpe)) => + TypeKey(tpe) } @inline @@ -881,16 +892,9 @@ trait QuantificationTemplates { self: Templates => case Lambda(args, body) => rec(Application(caller, args.map(_.toVariable)), body) case _ => Equals(caller, body) } - rec(template.ids._1, template.structure.lambda) + rec(template.ids._1, template.structure.body) } - val tk = QuantificationTemplate.templateKey( - quantifiers.map(_._1.toVal), - body, - template.structure.locals.toMap + template.ids - ) - val key = (template.pathVar._2, tk._1, tk._2, tk._3) - instantiateQuantification(new QuantificationTemplate( template.pathVar, Positive(guard), @@ -907,23 +911,28 @@ trait QuantificationTemplates { self: Templates => template.lambdas.map(_.substitute(substituter, Map.empty)), template.quantifications.map(_.substitute(substituter, Map.empty)), template.pointers.map(p => substituter(p._1) -> substituter(p._2)), - key, body, template.stringRepr))._2 // mapping is guaranteed empty!! + template.structure, body, template.stringRepr, false))._2 // mapping is guaranteed empty!! } } def instantiateQuantification(template: QuantificationTemplate): (Map[Encoded, Encoded], Clauses) = { - templates.get(template.key) match { - case Some((_, map)) => + templates.get(template.structure).orElse { + templates.collectFirst { case (s, t) if s subsumes template.structure => t } + } match { + case Some((_, _, map)) => (map, Seq.empty) case None => + val newTemplate = template.concretize val clauses = new scala.collection.mutable.ListBuffer[Encoded] - val mapping: Map[Encoded, Encoded] = template.polarity match { + clauses ++= newTemplate.structure.instantiation + + val (inst, mapping): (Encoded, Map[Encoded, Encoded]) = newTemplate.polarity match { case Positive(guard) => - val axiom = new Axiom(template.pathVar._2, guard, - template.quantifiers, template.condVars, template.exprVars, template.condTree, - template.clauses, template.blockers, template.applications, template.matchers, - template.lambdas, template.quantifications, template.pointers, template.body) + val axiom = new Axiom(newTemplate.pathVar._2, guard, + newTemplate.quantifiers, newTemplate.condVars, newTemplate.exprVars, newTemplate.condTree, + newTemplate.clauses, newTemplate.blockers, newTemplate.applications, newTemplate.matchers, + newTemplate.lambdas, newTemplate.quantifications, newTemplate.pointers, newTemplate.body) quantifications += axiom @@ -933,32 +942,32 @@ trait QuantificationTemplates { self: Templates => val groundGen = currentGeneration + 3 ignoredGrounds += groundGen -> (ignoredGrounds.getOrElse(groundGen, Set.empty) + axiom) - Map.empty + (trueT, Map.empty) case Negative(insts) => val instT = encodeSymbol(insts._1) val (substMap, substClauses) = Template.substitution( - template.condVars, template.exprVars, template.condTree, - template.lambdas, template.quantifications, template.pointers, - Map(insts._2 -> Left(instT)), template.pathVar._2) + newTemplate.condVars, newTemplate.exprVars, newTemplate.condTree, + newTemplate.lambdas, newTemplate.quantifications, newTemplate.pointers, + Map(insts._2 -> Left(instT)), newTemplate.pathVar._2) clauses ++= substClauses - // this will call `instantiateMatcher` on all matchers in `template.matchers` - clauses ++= Template.instantiate(template.clauses, - template.blockers, template.applications, template.matchers, substMap) + // this will call `instantiateMatcher` on all matchers in `newTemplate.matchers` + clauses ++= Template.instantiate(newTemplate.clauses, + newTemplate.blockers, newTemplate.applications, newTemplate.matchers, substMap) - Map(insts._2 -> instT) + (instT, Map(insts._2 -> instT)) case Unknown(qs, q2s, insts, guard) => val qT = encodeSymbol(qs._1) val substituter = mkSubstituter(Map(qs._2 -> qT)) - val quantification = new GeneralQuantification(template.pathVar._2, + val quantification = new GeneralQuantification(newTemplate.pathVar._2, qs._1 -> qT, q2s, insts, guard, - template.quantifiers, template.condVars, template.exprVars, template.condTree, - template.clauses map substituter, // one clause depends on 'qs._2' (and therefore 'qT') - template.blockers, template.applications, template.matchers, - template.lambdas, template.quantifications, template.pointers, template.body) + newTemplate.quantifiers, newTemplate.condVars, newTemplate.exprVars, newTemplate.condTree, + newTemplate.clauses map substituter, // one clause depends on 'qs._2' (and therefore 'qT') + newTemplate.blockers, newTemplate.applications, newTemplate.matchers, + newTemplate.lambdas, newTemplate.quantifications, newTemplate.pointers, newTemplate.body) quantifications += quantification @@ -966,9 +975,9 @@ trait QuantificationTemplates { self: Templates => clauses ++= quantification.instantiate(bs, m) } - val freshQuantifiers = template.quantifiers.map(p => encodeSymbol(p._1)) - val freshSubst = mkSubstituter((template.quantifiers.map(_._2) zip freshQuantifiers).toMap) - for ((b,ms) <- template.matchers; m <- ms) { + val freshQuantifiers = newTemplate.quantifiers.map(p => encodeSymbol(p._1)) + val freshSubst = mkSubstituter((newTemplate.quantifiers.map(_._2) zip freshQuantifiers).toMap) + for ((b,ms) <- newTemplate.matchers; m <- ms) { clauses ++= instantiateMatcher(Set.empty[Encoded], m, false) // it is very rare that such instantiations are actually required, so we defer them val gen = currentGeneration + 20 @@ -976,10 +985,22 @@ trait QuantificationTemplates { self: Templates => } clauses ++= quantification.ensureGrounds - Map(qs._2 -> qT) + (qT, Map(qs._2 -> qT)) + } + + clauses ++= templates.flatMap { case (key, (tmpl, tinst, _)) => + if (newTemplate.structure.body == tmpl.structure.body) { + val eqConds = (newTemplate.structure.locals zip tmpl.structure.locals) + .filter(p => p._1 != p._2) + .map(p => mkEquals(p._1._2, p._2._2)) + val cond = mkAnd(newTemplate.pathVar._2 +: tmpl.pathVar._2 +: eqConds : _*) + Some(mkImplies(cond, mkEquals(inst, tinst))) + } else { + None + } } - templates += template.key -> ((template, mapping)) + templates += newTemplate.structure -> ((newTemplate, inst, mapping)) (mapping, clauses.toSeq) } } diff --git a/src/main/scala/inox/solvers/unrolling/TemplateGenerator.scala b/src/main/scala/inox/solvers/unrolling/TemplateGenerator.scala index 66887e0fcffa3f425cd2a811f11308979a03ec68..453e8fc0661bd107f0fdd789fd9b72ec089f08e9 100644 --- a/src/main/scala/inox/solvers/unrolling/TemplateGenerator.scala +++ b/src/main/scala/inox/solvers/unrolling/TemplateGenerator.scala @@ -307,70 +307,10 @@ trait TemplateGenerator { self: Templates => rec(pathVar, a, Some(true)) } - val (struct, deps) = normalizeStructure(without.asInstanceOf[Lambda]) - val sortedDeps = exprOps.variablesOf(struct).map(v => v -> deps(v)).toSeq.sortBy(_._1.id.uniqueName) - - val isNormalForm: Boolean = { - def extractBody(e: Expr): (Seq[ValDef], Expr) = e match { - case Lambda(args, body) => - val (nextArgs, nextBody) = extractBody(body) - (args ++ nextArgs, nextBody) - case _ => (Seq.empty, e) - } + val (realLambda, structure, depSubst) = recStructure(pathVar, l) - val (params, app) = extractBody(struct) - ApplicationExtractor.extract(app, simplify) match { - case Some((caller: Variable, args)) => - !app.getType.isInstanceOf[FunctionType] && - (params.map(_.toVariable) == args) && - (deps.get(caller) match { - case Some(_: Application | _: FunctionInvocation | _: Variable | _: ADTSelector) => true - case _ => false - }) - case _ => false - } - } - - val depsByScope: Seq[(Variable, Expr)] = { - def rec(v: Variable): Seq[Variable] = - (exprOps.variablesOf(deps(v)) & deps.keySet - v).toSeq.flatMap(rec) :+ v - deps.keys.toSeq.flatMap(rec).distinct.map(v => v -> deps(v)) - } + val depClosures = exprOps.variablesOf(l).flatMap(lambdaVars.get) - val localSubst: Map[Variable, Encoded] = substMap ++ condVars ++ exprVars ++ lambdaVars - val (depSubst, (depConds, depExprs, depTree, depGuarded, depEqs, depLambdas, depQuants)) = - depsByScope.foldLeft[(Map[Variable, Encoded], TemplateClauses)](localSubst -> emptyClauses) { - case ((depSubst, clsSet), (v, expr)) => - if (!exprOps.isSimple(expr)) { - val normalExpr = if (!isNormalForm) simplifyHOFunctions(expr) else expr - val (e, cls @ (_, _, _, _, _, lmbds, quants)) = mkExprClauses(pathVar, normalExpr, depSubst) - val clauseSubst = depSubst ++ lmbds.map(_.ids) ++ quants.flatMap(_.mapping) - (depSubst + (v -> mkEncoder(clauseSubst)(e)), clsSet ++ cls) - } else { - (depSubst + (v -> mkEncoder(depSubst)(expr)), clsSet) - } - } - - val (depClauses, depCalls, depApps, depMatchers, depPointers, _) = Template.encode( - pathVar -> encodedCond(pathVar), Seq.empty, - depConds, depExprs, depGuarded, depEqs, depLambdas, depQuants, depSubst) - - val depClosures: Seq[Encoded] = { - var cls: Seq[Variable] = Seq.empty - for ((_, e) <- sortedDeps) { - val vars = exprOps.variablesOf(e).toSet - exprOps.preTraversal { case v: Variable if vars(v) => cls :+= v case _ => } (e) - } - cls.distinct.map(depSubst) - } - - val dependencies = sortedDeps.map(p => depSubst(p._1)) - - val structure = new LambdaStructure( - struct, dependencies, pathVar -> encodedCond(pathVar), depClosures, depConds, depExprs, depTree, - depClauses, depCalls, depApps, depMatchers, depLambdas, depQuants, depPointers) - - val realLambda = if (isNormalForm) l else struct val lid = Variable(FreshIdentifier("lambda", true), l.getType) val clauses = liftedEquals(lid, realLambda, idArgs, inlineFirst = true) @@ -382,12 +322,13 @@ trait TemplateGenerator { self: Templates => val template = LambdaTemplate(ids, pathVar -> encodedCond(pathVar), idArgs zip trArgs, lambdaConds, lambdaExprs, lambdaTree, - lambdaGuarded, lambdaEqs, lambdaTemplates, lambdaQuants, structure, depSubst, l) + lambdaGuarded, lambdaEqs, lambdaTemplates, lambdaQuants, structure, depClosures, depSubst, l) registerLambda(template) lid case f: Forall => val (assumptions, Forall(args, body)) = liftAssumptions(f) + val argsSet = args.toSet for (a <- assumptions) { rec(pathVar, a, Some(true)) @@ -395,20 +336,25 @@ trait TemplateGenerator { self: Templates => val TopLevelAnds(conjuncts) = body - val conjunctQs = conjuncts.map { conjunct => - val vars = exprOps.variablesOf(conjunct) - val quantifiers = args.map(_.toVariable).filter(vars).toSet + val conjunctQs = conjuncts.map { conj => + val conjArgs: Seq[ValDef] = { + var vds: Seq[ValDef] = Seq.empty + exprOps.preTraversal { case v: Variable if argsSet(v.toVal) => vds :+= v.toVal case _ => } (conj) + vds.distinct + } + + val (Forall(args, body), structure, depSubst) = recStructure(pathVar, Forall(conjArgs, conj)) + val quantifiers = args.map(_.toVariable).toSet val idQuantifiers : Seq[Variable] = quantifiers.toSeq val trQuantifiers : Seq[Encoded] = idQuantifiers.map(encodeSymbol) - val localSubst: Map[Variable, Encoded] = substMap ++ condVars ++ exprVars ++ lambdaVars - val clauseSubst: Map[Variable, Encoded] = localSubst ++ (idQuantifiers zip trQuantifiers) - val (p, (qConds, qExprs, qTree, qGuarded, qEqs, qLambdas, qQuants)) = mkExprClauses(pathVar, conjunct, clauseSubst) + val clauseSubst: Map[Variable, Encoded] = depSubst ++ (idQuantifiers zip trQuantifiers) + val (p, (qConds, qExprs, qTree, qGuarded, qEqs, qLambdas, qQuants)) = mkExprClauses(pathVar, body, clauseSubst) val (optVar, template) = QuantificationTemplate(pathVar -> encodedCond(pathVar), pol, p, idQuantifiers zip trQuantifiers, qConds, qExprs, qTree, qGuarded, qEqs, qLambdas, qQuants, - localSubst, Forall(quantifiers.toSeq.sortBy(_.id.uniqueName).map(_.toVal), conjunct)) + structure, depSubst, Forall(conjArgs, conj)) registerQuantification(template) optVar.getOrElse(BooleanLiteral(true)) } @@ -418,6 +364,71 @@ trait TemplateGenerator { self: Templates => case Operator(as, r) => r(as.map(a => rec(pathVar, a, None))) } + def recStructure(pathVar: Variable, expr: Expr): (Expr, TemplateStructure, Map[Variable, Encoded]) = { + val (assumptions, without) = liftAssumptions(expr) + + for (a <- assumptions) { + rec(pathVar, a, Some(true)) + } + + val (struct, deps) = normalizeStructure(without) + val sortedDeps = exprOps.variablesOf(struct).map(v => v -> deps(v)).toSeq.sortBy(_._1.id.uniqueName) + + lazy val isNormalForm: Boolean = { + def extractBody(e: Expr): (Seq[ValDef], Expr) = e match { + case Lambda(args, body) => + val (nextArgs, nextBody) = extractBody(body) + (args ++ nextArgs, nextBody) + case _ => (Seq.empty, e) + } + + val (params, app) = extractBody(struct) + ApplicationExtractor.extract(app, simplify) match { + case Some((caller: Variable, args)) => + !app.getType.isInstanceOf[FunctionType] && + (params.map(_.toVariable) == args) && + (deps.get(caller) match { + case Some(_: Application | _: FunctionInvocation | _: Variable | _: ADTSelector) => true + case _ => false + }) + case _ => false + } + } + + val depsByScope: Seq[(Variable, Expr)] = { + def rec(v: Variable): Seq[Variable] = + (exprOps.variablesOf(deps(v)) & deps.keySet - v).toSeq.flatMap(rec) :+ v + deps.keys.toSeq.flatMap(rec).distinct.map(v => v -> deps(v)) + } + + val localSubst: Map[Variable, Encoded] = substMap ++ condVars ++ exprVars ++ lambdaVars + val (depSubst, (depConds, depExprs, depTree, depGuarded, depEqs, depLambdas, depQuants)) = + depsByScope.foldLeft[(Map[Variable, Encoded], TemplateClauses)](localSubst -> emptyClauses) { + case ((depSubst, clsSet), (v, expr)) => + if (!exprOps.isSimple(expr)) { + val normalExpr = if (!isNormalForm) simplifyHOFunctions(expr) else expr + val (e, cls @ (_, _, _, _, _, lmbds, quants)) = mkExprClauses(pathVar, normalExpr, depSubst) + val clauseSubst = depSubst ++ lmbds.map(_.ids) ++ quants.flatMap(_.mapping) + (depSubst + (v -> mkEncoder(clauseSubst)(e)), clsSet ++ cls) + } else { + (depSubst + (v -> mkEncoder(depSubst)(expr)), clsSet) + } + } + + val (depClauses, depCalls, depApps, depMatchers, depPointers, _) = Template.encode( + pathVar -> encodedCond(pathVar), Seq.empty, + depConds, depExprs, depGuarded, depEqs, depLambdas, depQuants, depSubst) + + val dependencies = sortedDeps.map(p => depSubst(p._1)) + + val structure = new TemplateStructure(struct, dependencies, + pathVar -> encodedCond(pathVar), depConds, depExprs, depTree, + depClauses, depCalls, depApps, depMatchers, depLambdas, depQuants, depPointers) + + val res = if (isNormalForm) expr else struct + (res, structure, depSubst) + } + val p = rec(pathVar, expr, polarity) (p, (condVars, exprVars, condTree, guardedExprs, eqs, lambdas, quantifications)) } diff --git a/src/main/scala/inox/solvers/unrolling/Templates.scala b/src/main/scala/inox/solvers/unrolling/Templates.scala index 4afc1a8298ead11b16113f9f3f3e17c815782387..c834726f5024fa12aa03f05c71f5c55e0f56aa3d 100644 --- a/src/main/scala/inox/solvers/unrolling/Templates.scala +++ b/src/main/scala/inox/solvers/unrolling/Templates.scala @@ -341,6 +341,104 @@ trait Templates extends TemplateGenerator override def toString : String = "Instantiated template" } + /** Semi-template used for inner-template equality + * + * We introduce a structure here that resembles a [[Template]] that is instantiated + * ONCE when the corresponding template becomes of interest. */ + class TemplateStructure( + + /** The normalized expression that is shared between all templates that are "equal". + * Template equality is conditioned on [[body]] equality. + * + * @see [[dependencies]] for the other component of equality + */ + val body: Expr, + + /** The closed expressions (independent of the arguments to [[body]]) contained in + * the inner-template. Equality is conditionned on equality of [[dependencies]] + * (inside the solver). + * + * @see [[body]] for the other component of equality + */ + val dependencies: Seq[Encoded], + + /** The condition under which this structure can be reached within the program. If + * the `pathVar` does not hold, then equality will not be checked. */ + val pathVar: (Variable, Encoded), + + val condVars: Map[Variable, Encoded], + val exprVars: Map[Variable, Encoded], + val condTree: Map[Variable, Set[Variable]], + val clauses: Clauses, + val blockers: Calls, + val applications: Apps, + val matchers: Matchers, + val lambdas: Seq[LambdaTemplate], + val quantifications: Seq[QuantificationTemplate], + val pointers: Map[Encoded, Encoded]) { + + def substitute(substituter: Encoded => Encoded, msubst: Map[Encoded, Matcher]) = new TemplateStructure( + body, + dependencies.map(substituter), + pathVar._1 -> substituter(pathVar._2), + condVars, exprVars, condTree, + clauses.map(substituter), + blockers.map { case (b, fis) => substituter(b) -> fis.map(_.substitute(substituter, msubst)) }, + applications.map { case (b, fas) => substituter(b) -> fas.map(_.substitute(substituter, msubst)) }, + matchers.map { case (b, ms) => substituter(b) -> ms.map(_.substitute(substituter, msubst)) }, + lambdas.map(_.substitute(substituter, msubst)), + quantifications.map(_.substitute(substituter, msubst)), + pointers.map(p => substituter(p._1) -> substituter(p._2)) + ) + + /** The [[key]] value (triplet of [[body]], a normalization of [[pathVar]] and [[locals]]) + * is used to determine syntactic equality between inner-templates. If the key of two such + * templates are equal, then they must necessarily be equal in every model. + * + * The [[instantiation]] consists of the clause set instantiation (in the sense of + * [[Template.instantiate]] that is required for [[dependencies]] to make sense in the solver + * (introduces blockers, lambdas, quantifications, etc.) Since [[dependencies]] CHANGE during + * instantiation and [[key]] makes no sense without the associated instantiation, the implicit + * contract here is that whenever a new key appears during unfolding, its associated + * instantiation MUST be added to the set of instantiations managed by the solver. However, if + * an identical (or subsuming) pre-existing key has already been found, then the associated + * instantiation must already appear in the handled by the solver and the new one can be discarded. + * + * The [[locals]] value consists of the [[dependencies]] on which the substitution resulting + * from instantiation has been applied. The [[dependencies]] should not be directly used here + * as they may depend on closure and quantifier ids that were only obtained when [[instantiation]] + * was computed. + * + * The [[instantiationSubst]] substitution corresponds that applied to [[dependencies]] when + * constructing [[locals]]. + */ + lazy val (key, instantiation, locals, instantiationSubst) = { + val (substMap, substInst) = Template.substitution(condVars, exprVars, condTree, + lambdas, quantifications, pointers, Map.empty, pathVar._2) + val tmplInst = Template.instantiate(clauses, blockers, applications, matchers, substMap) + val instantiation = substInst ++ tmplInst + + val substituter = mkSubstituter(substMap.mapValues(_.encoded)) + val deps = dependencies.map(substituter) + val key = (body, blockerPath(pathVar._2), deps) + + val sortedDeps = exprOps.variablesOf(body).toSeq.sortBy(_.id.uniqueName) + val locals = sortedDeps zip deps + (key, instantiation, locals, substMap.mapValues(_.encoded)) + } + + override def equals(that: Any): Boolean = that match { + case (struct: TemplateStructure) => key == struct.key + case _ => false + } + + override def hashCode: Int = key.hashCode + + def subsumes(that: TemplateStructure): Boolean = { + key._1 == that.key._1 && key._3 == that.key._3 && key._2.subsetOf(that.key._2) + } + } + private[unrolling] def mkApplication(caller: Expr, args: Seq[Expr]): Expr = caller.getType match { case FunctionType(from, to) => val (curr, next) = args.splitAt(from.size) @@ -541,7 +639,7 @@ trait Templates extends TemplateGenerator val lambdaKeys = lambdas.map(lambda => lambda.ids._2 -> lambda).toMap def extractSubst(lambda: LambdaTemplate): Unit = { for { - dep <- lambda.structure.closures flatMap lambdaKeys.get + dep <- lambda.closures map lambdaKeys if !seen(dep) } extractSubst(dep) diff --git a/src/main/scala/inox/solvers/unrolling/UnrollingSolver.scala b/src/main/scala/inox/solvers/unrolling/UnrollingSolver.scala index 8aa6d10e5024b3174fdd30398b8706dec12fb6d0..d279eccb0e4fcc224cc4dcb94d309b585c67d7f1 100644 --- a/src/main/scala/inox/solvers/unrolling/UnrollingSolver.scala +++ b/src/main/scala/inox/solvers/unrolling/UnrollingSolver.scala @@ -302,7 +302,7 @@ trait AbstractUnrollingSolver extends Solver { self => } }.toMap - exprOps.replaceFromSymbols(localsSubst, tmpl.structure.lambda).asInstanceOf[Lambda] + exprOps.replaceFromSymbols(localsSubst, tmpl.structure.body).asInstanceOf[Lambda] } }