diff --git a/src/main/scala/leon/purescala/ExprOps.scala b/src/main/scala/leon/purescala/ExprOps.scala index c0a219c4ddb18ec73945c718d0af5551c44d5c09..0e76036d6fe35984843fd347abd6e90c53a8cf35 100644 --- a/src/main/scala/leon/purescala/ExprOps.scala +++ b/src/main/scala/leon/purescala/ExprOps.scala @@ -286,8 +286,14 @@ object ExprOps extends GenTreeOps[Expr] { * * This function relies on the static map `typedIds` to ensure identical * structures and must therefore be synchronized. + * + * The optional argument [[onlySimple]] determines whether non-simple expressions + * (see [[isSimple]]) should be normalized into a dependency or recursed into + * (when they don't depend on [[args]]). This distinction is used in the + * unrolling solver to provide geenral equality checks between functions even when + * they have complex closures. */ - def normalizeStructure(args: Seq[Identifier], expr: Expr): (Seq[Identifier], Expr, Map[Identifier, Expr]) = synchronized { + def normalizeStructure(args: Seq[Identifier], expr: Expr, onlySimple: Boolean = true): (Seq[Identifier], Expr, Map[Identifier, Expr]) = synchronized { val vars = args.toSet class Normalizer extends TreeTransformer { @@ -316,7 +322,7 @@ object ExprOps extends GenTreeOps[Expr] { } override def transform(e: Expr)(implicit bindings: Map[Identifier, Identifier]): Expr = e match { - case expr if isSimple(expr) && (variablesOf(expr) & vars).isEmpty => getId(expr).toVariable + case expr if (isSimple(expr) || !onlySimple) && (variablesOf(expr) & vars).isEmpty => getId(expr).toVariable case _ => super.transform(e) } } @@ -332,7 +338,7 @@ object ExprOps extends GenTreeOps[Expr] { } def normalizeStructure(lambda: Lambda): (Lambda, Map[Identifier, Expr]) = { - val (args, body, subst) = normalizeStructure(lambda.args.map(_.id), lambda.body) + val (args, body, subst) = normalizeStructure(lambda.args.map(_.id), lambda.body, onlySimple = false) (Lambda(args.map(ValDef(_)), body), subst) } diff --git a/src/main/scala/leon/solvers/unrolling/LambdaManager.scala b/src/main/scala/leon/solvers/unrolling/LambdaManager.scala index 775fcc2671f7c108fab00a30de2a039d6eda1298..485edfde6c3a938b831e9a8e3f34e1da21f648f1 100644 --- a/src/main/scala/leon/solvers/unrolling/LambdaManager.scala +++ b/src/main/scala/leon/solvers/unrolling/LambdaManager.scala @@ -20,6 +20,7 @@ import Template._ import scala.collection.mutable.{Map => MutableMap, Set => MutableSet} +/** Represents an application of a first-class function in the unfolding procedure */ case class App[T](caller: T, tpe: FunctionType, args: Seq[Arg[T]], encoded: T) { override def toString = "(" + caller + " : " + tpe + ")" + args.map(_.encoded).mkString("(", ",", ")") def substitute(substituter: T => T, msubst: Map[T, Matcher[T]]): App[T] = copy( @@ -29,6 +30,11 @@ case class App[T](caller: T, tpe: FunctionType, args: Seq[Arg[T]], encoded: T) { ) } +/** Constructor object for [[LambdaTemplate]] + * + * The [[apply]] methods performs some pre-processing before creating + * an instance of [[LambdaTemplate]]. + */ object LambdaTemplate { def apply[T]( @@ -43,8 +49,8 @@ object LambdaTemplate { guardedExprs: Map[Identifier, Seq[Expr]], quantifications: Seq[QuantificationTemplate[T]], lambdas: Seq[LambdaTemplate[T]], + structure: LambdaStructure[T], baseSubstMap: Map[Identifier, T], - dependencies: Map[Identifier, T], lambda: Lambda ) : LambdaTemplate[T] = { @@ -60,9 +66,6 @@ object LambdaTemplate { "Template for lambda " + ids._1 + ": " + lambda + " is :\n" + templateString() } - val (structuralLambda, deps) = normalizeStructure(lambda) - val keyDeps = deps.map { case (id, dep) => id -> encoder.encodeExpr(dependencies)(dep) } - new LambdaTemplate[T]( ids, encoder, @@ -78,34 +81,113 @@ object LambdaTemplate { lambdas, matchers, quantifications, - structuralLambda, - keyDeps, + structure, lambda, lambdaString ) } } -trait KeyedTemplate[T, E <: Expr] { - val dependencies: Map[Identifier, T] - val structure: E +/** Semi-template used for hardcore function equality + * + * Function equality, while unhandled in general, can be very useful for certain + * proofs that refer specifically to first-class functions. In order to support + * such proofs, flexible notions of equality on first-class functions are + * necessary. These are provided by [[LambdaStructure]] which, much like a + * [[Template]], will generate clauses that represent equality between two + * functions. + * + * To support complex cases of equality where closed portions of the first-class + * function rely on complex program features (function calls, introducing lambdas, + * foralls, etc.), we use a structure that resembles a [[Template]] that is + * instantiated when function equality is of interest. + * + * Note that lambda creation now introduces clauses to determine equality between + * closed portions (that are independent of the lambda arguments). + */ +class LambdaStructure[T] private[unrolling] ( + /** @see [[Template.encoder]] */ + val encoder: TemplateEncoder[T], + /** @see [[Template.manager]] */ + val manager: QuantificationManager[T], - lazy val key: (E, Seq[T]) = { - def rec(e: Expr): Seq[Identifier] = e match { - case Variable(id) => - if (dependencies.isDefinedAt(id)) { - Seq(id) - } else { - Seq.empty - } + /** The normalized lambda that is shared between all "equal" first-class functions. + * First-class function equality is conditionned on `lambda` equality. + * + * @see [[dependencies]] for the other component of equality between first-class functions + */ + val lambda: Lambda, - case Operator(es, _) => es.flatMap(rec) + /** The closed expressions (independent of the arguments to [[lambda]] contained in + * the first-class function. Equality is conditioned on equality of `dependencies` + * (inside the solver). + * + * @see [[lambda]] for the other component of equality between first-class functions + */ + val dependencies: Seq[T], + val pathVar: (Identifier, T), - case _ => Seq.empty - } + /** The set of closed variables that exist in the associated lambda. + * + * This set is necessary to determine whether other closures have been + * captured by this particular closure when deciding the order of + * lambda instantiations in [[Template.substitution]]. + */ + val closures: Seq[T], + val condVars: Map[Identifier, T], + val exprVars: Map[Identifier, T], + val condTree: Map[Identifier, Set[Identifier]], + val clauses: Seq[T], + val blockers: Calls[T], + val applications: Apps[T], + val lambdas: Seq[LambdaTemplate[T]], + val matchers: Map[T, Set[Matcher[T]]], + val quantifications: Seq[QuantificationTemplate[T]]) { + + def substitute(substituter: T => T, matcherSubst: Map[T, Matcher[T]]) = new LambdaStructure[T]( + encoder, manager, lambda, + dependencies.map(substituter), + pathVar._1 -> substituter(pathVar._2), + closures.map(substituter), condVars, exprVars, condTree, + clauses.map(substituter), + blockers.map { case (b, fis) => substituter(b) -> fis.map(fi => fi.copy( + args = fi.args.map(_.substitute(substituter, matcherSubst)))) }, + applications.map { case (b, fas) => substituter(b) -> fas.map(_.substitute(substituter, matcherSubst)) }, + lambdas.map(_.substitute(substituter, matcherSubst)), + matchers.map { case (b, ms) => substituter(b) -> ms.map(_.substitute(substituter, matcherSubst)) }, + quantifications.map(_.substitute(substituter, matcherSubst))) + + /** The [[key]] value (tuple of [[lambda]] and [[dependencies]]) is used + * to determine syntactic equality between lambdas. If the keys of two + * closures are equal, then they must necessarily be equal in every model. + * + * The [[instantiation]] consists of the clause set instantiation (in the + * sense of [[Template.instantiate]] that is required for [[dependencies]] + * to make sense in the solver (introduces blockers, quantifications, other + * lambdas, etc.) Since [[dependencies]] CHANGE during instantiation and + * [[key]] makes no sense without the associated instantiation, the implicit + * contract here is that whenever a new key appears during unfolding, its + * associated instantiation MUST be added to the set of instantiations + * managed by the solver. However, if an identical pre-existing key has + * already been found, then the associated instantiations must already appear + * in those handled by the solver. + */ + lazy val (key, instantiation) = { + val (substMap, substInst) = Template.substitution[T](encoder, manager, + condVars, exprVars, condTree, quantifications, lambdas, Set.empty, Map.empty, pathVar._1, pathVar._2) + val tmplInst = Template.instantiate(encoder, manager, clauses, blockers, applications, matchers, substMap) + + val key = (lambda, dependencies.map(encoder.substitute(substMap.mapValues(_.encoded)))) + val instantiation = substInst ++ tmplInst + (key, instantiation) + } - structure -> rec(structure).distinct.map(dependencies) + override def equals(that: Any): Boolean = that match { + case (struct: LambdaStructure[T]) => key == struct.key + case _ => false } + + override def hashCode: Int = key.hashCode } class LambdaTemplate[T] private ( @@ -117,16 +199,15 @@ class LambdaTemplate[T] private ( val condVars: Map[Identifier, T], val exprVars: Map[Identifier, T], val condTree: Map[Identifier, Set[Identifier]], - val clauses: Seq[T], - val blockers: Map[T, Set[TemplateCallInfo[T]]], - val applications: Map[T, Set[App[T]]], + val clauses: Clauses[T], + val blockers: Calls[T], + val applications: Apps[T], val lambdas: Seq[LambdaTemplate[T]], val matchers: Map[T, Set[Matcher[T]]], val quantifications: Seq[QuantificationTemplate[T]], - val structure: Lambda, - val dependencies: Map[Identifier, T], + val structure: LambdaStructure[T], val lambda: Lambda, - stringRepr: () => String) extends Template[T] with KeyedTemplate[T, Lambda] { + stringRepr: () => String) extends Template[T] { val args = arguments.map(_._2) val tpe = bestRealType(ids._1.getType).asInstanceOf[FunctionType] @@ -144,10 +225,7 @@ class LambdaTemplate[T] private ( val newApplications = applications.map { case (b, fas) => val bp = if (b == start) newStart else b - bp -> fas.map(fa => fa.copy( - caller = substituter(fa.caller), - args = fa.args.map(_.substitute(substituter, matcherSubst)) - )) + bp -> fas.map(_.substitute(substituter, matcherSubst)) } val newLambdas = lambdas.map(_.substitute(substituter, matcherSubst)) @@ -159,7 +237,7 @@ class LambdaTemplate[T] private ( val newQuantifications = quantifications.map(_.substitute(substituter, matcherSubst)) - val newDependencies = dependencies.map(p => p._1 -> substituter(p._2)) + val newStructure = structure.substitute(substituter, matcherSubst) new LambdaTemplate[T]( ids._1 -> substituter(ids._2), @@ -176,8 +254,7 @@ class LambdaTemplate[T] private ( newLambdas, newMatchers, newQuantifications, - structure, - newDependencies, + newStructure, lambda, stringRepr ) @@ -189,7 +266,7 @@ class LambdaTemplate[T] private ( ids._1 -> idT, encoder, manager, pathVar, arguments, condVars, exprVars, condTree, clauses map substituter, // make sure the body-defining clause is inlined! blockers, applications, lambdas, matchers, quantifications, - structure, dependencies, lambda, stringRepr + structure, lambda, stringRepr ) } @@ -205,7 +282,7 @@ class LambdaManager[T](encoder: TemplateEncoder[T]) extends DatatypeManager(enco private[unrolling] lazy val trueT = encoder.encodeExpr(Map.empty)(BooleanLiteral(true)) protected[unrolling] val byID = new IncrementalMap[T, LambdaTemplate[T]] - protected val byType = new IncrementalMap[FunctionType, Map[(Expr, Seq[T]), LambdaTemplate[T]]].withDefaultValue(Map.empty) + protected val byType = new IncrementalMap[FunctionType, Map[LambdaStructure[T], LambdaTemplate[T]]].withDefaultValue(Map.empty) protected val applications = new IncrementalMap[FunctionType, Set[(T, App[T])]].withDefaultValue(Set.empty) protected val knownFree = new IncrementalMap[FunctionType, Set[T]].withDefaultValue(Set.empty) protected val maybeFree = new IncrementalMap[FunctionType, Set[(T, T)]].withDefaultValue(Set.empty) @@ -286,7 +363,7 @@ class LambdaManager[T](encoder: TemplateEncoder[T]) extends DatatypeManager(enco } def instantiateLambda(template: LambdaTemplate[T]): (T, Instantiation[T]) = { - byType(template.tpe).get(template.key) match { + byType(template.tpe).get(template.structure) match { case Some(template) => (template.ids._2, Instantiation.empty) @@ -295,7 +372,7 @@ class LambdaManager[T](encoder: TemplateEncoder[T]) extends DatatypeManager(enco val newTemplate = template.withId(idT) // make sure the new lambda isn't equal to any free lambda var - val instantiation = Instantiation.empty[T] withClauses ( + val instantiation = newTemplate.structure.instantiation withClauses ( equalityClauses(newTemplate) ++ knownFree(newTemplate.tpe).map(f => encoder.mkNot(encoder.mkEquals(idT, f))).toSeq ++ maybeFree(newTemplate.tpe).map { p => @@ -303,7 +380,7 @@ class LambdaManager[T](encoder: TemplateEncoder[T]) extends DatatypeManager(enco }) byID += idT -> newTemplate - byType += newTemplate.tpe -> (byType(newTemplate.tpe) + (newTemplate.key -> newTemplate)) + byType += newTemplate.tpe -> (byType(newTemplate.tpe) + (newTemplate.structure -> newTemplate)) val inst = applications(newTemplate.tpe).foldLeft(instantiation) { case (instantiation, app @ (_, App(caller, _, args, _))) => @@ -353,19 +430,22 @@ class LambdaManager[T](encoder: TemplateEncoder[T]) extends DatatypeManager(enco } private def equalityClauses(template: LambdaTemplate[T]): Seq[T] = { - val (s1, deps1) = template.key byType(template.tpe).values.map { that => - val (s2, deps2) = that.key val equals = encoder.mkEquals(template.ids._2, that.ids._2) - if (s1 == s2) { - val pairs = (deps1 zip deps2).filter(p => p._1 != p._2) - if (pairs.isEmpty) equals else { - val eqs = pairs.map(p => encoder.mkEquals(p._1, p._2)) - encoder.mkEquals(encoder.mkAnd(eqs : _*), equals) - } - } else { - encoder.mkNot(equals) - } + encoder.mkImplies( + encoder.mkAnd(template.pathVar._2, that.pathVar._2), + if (template.structure.lambda == that.structure.lambda) { + val pairs = template.structure.dependencies zip that.structure.dependencies + val filtered = pairs.filter(p => p._1 != p._2) + if (filtered.isEmpty) { + equals + } else { + val eqs = filtered.map(p => encoder.mkEquals(p._1, p._2)) + encoder.mkEquals(encoder.mkAnd(eqs : _*), equals) + } + } else { + encoder.mkNot(equals) + }) }.toSeq } } diff --git a/src/main/scala/leon/solvers/unrolling/QuantificationManager.scala b/src/main/scala/leon/solvers/unrolling/QuantificationManager.scala index e4fabf2daf1bb200d95e272188223efc9a8c105b..797f11aa1475d063cca706bc7b14da81a3a55dc8 100644 --- a/src/main/scala/leon/solvers/unrolling/QuantificationManager.scala +++ b/src/main/scala/leon/solvers/unrolling/QuantificationManager.scala @@ -60,9 +60,17 @@ class QuantificationTemplate[T]( val structure: Forall, val dependencies: Map[Identifier, T], val forall: Forall, - stringRepr: () => String) extends KeyedTemplate[T, Forall] { + stringRepr: () => String) { lazy val start = pathVar._2 + lazy val key: (Forall, Seq[T]) = (structure, { + var cls: Seq[T] = Seq.empty + purescala.ExprOps.preTraversal { + case Variable(id) => cls ++= dependencies.get(id) + case _ => + } (structure) + cls + }) def substitute(substituter: T => T, matcherSubst: Map[T, Matcher[T]]): QuantificationTemplate[T] = { new QuantificationTemplate[T]( @@ -150,7 +158,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage private val ignoredSubsts = new IncrementalMap[MatcherQuantification, MutableSet[(Int, Set[T], Map[T,Arg[T]])]] private val handledSubsts = new IncrementalMap[MatcherQuantification, MutableSet[(Set[T], Map[T,Arg[T]])]] - private val lambdaAxioms = new IncrementalSet[((Expr, Seq[T]), Seq[(Identifier, T)])] + private val lambdaAxioms = new IncrementalSet[(LambdaStructure[T], Seq[(Identifier, T)])] private val templates = new IncrementalMap[(Expr, Seq[T]), T] override protected def incrementals: List[IncrementalState] = @@ -166,7 +174,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage private def matcherKey(caller: T, tpe: TypeTree): MatcherKey = tpe match { case ft: FunctionType if knownFree(ft)(caller) => CallerKey(caller, tpe) - case _: FunctionType if byID.isDefinedAt(caller) => LambdaKey(byID(caller).structure, tpe) + case _: FunctionType if byID.isDefinedAt(caller) => LambdaKey(byID(caller).structure.lambda, tpe) case _ => TypeKey(tpe) } @@ -668,7 +676,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage } val quantifiers = quantified zip abstractNormalizer.normalize(quantified) - val key = template.key -> quantifiers + val key = template.structure -> quantifiers if (quantifiers.isEmpty || lambdaAxioms(key)) { Instantiation.empty[T] diff --git a/src/main/scala/leon/solvers/unrolling/TemplateGenerator.scala b/src/main/scala/leon/solvers/unrolling/TemplateGenerator.scala index 29b307f3d6fd9cb85b14361b8a16c75e3991fe30..9912becb82b65d8087fc3a8bfbee7bf4b38611c8 100644 --- a/src/main/scala/leon/solvers/unrolling/TemplateGenerator.scala +++ b/src/main/scala/leon/solvers/unrolling/TemplateGenerator.scala @@ -413,9 +413,43 @@ class TemplateGenerator[T](val theories: TheoryEncoder, clauses.foldLeft(emptyClauses)((clsSet, cls) => clsSet ++ mkClauses(pathVar, cls, clauseSubst)) val ids: (Identifier, T) = lid -> storeLambda(lid) - val dependencies: Map[Identifier, T] = variablesOf(l).map(id => id -> localSubst(id)).toMap + + val (struct, deps) = normalizeStructure(l) + + import Template._ + import Instantiation.MapSetWrapper + + val (dependencies, (depConds, depExprs, depTree, depGuarded, depLambdas, depQuants)) = + deps.foldLeft[(Seq[T], Clauses)](Seq.empty -> emptyClauses) { + case ((dependencies, clsSet), (id, expr)) => + if (!isSimple(expr)) { + val encoded = encoder.encodeId(id) + val (e, cls @ (_, _, _, _, lmbds, quants)) = mkExprClauses(pathVar, expr, localSubst) + val clauseSubst = localSubst ++ lmbds.map(_.ids) ++ quants.map(_.qs) + (dependencies :+ encoder.encodeExpr(clauseSubst)(e), clsSet ++ cls) + } else { + (dependencies :+ encoder.encodeExpr(localSubst)(expr), clsSet) + } + } + + val (depClauses, depCalls, depApps, _, depMatchers, _) = Template.encode( + encoder, pathVar -> encodedCond(pathVar), Seq.empty, + depConds, depExprs, depGuarded, depLambdas, depQuants, localSubst) + + val depClosures: Seq[T] = { + val vars = variablesOf(l) + var cls: Seq[Identifier] = Seq.empty + preTraversal { case Variable(id) if vars(id) => cls :+= id case _ => } (l) + cls.distinct.map(localSubst) + } + + val structure = new LambdaStructure[T]( encoder, manager, + struct, dependencies, pathVar -> encodedCond(pathVar), depClosures, + depConds, depExprs, depTree, depClauses, depCalls, depApps, depLambdas, depMatchers, depQuants) + val template = LambdaTemplate(ids, encoder, manager, pathVar -> encodedCond(pathVar), - idArgs zip trArgs, lambdaConds, lambdaExprs, lambdaTree, lambdaGuarded, lambdaQuants, lambdaTemplates, localSubst, dependencies, l) + idArgs zip trArgs, lambdaConds, lambdaExprs, lambdaTree, + lambdaGuarded, lambdaQuants, lambdaTemplates, structure, localSubst, l) registerLambda(template) Variable(lid) diff --git a/src/main/scala/leon/solvers/unrolling/TemplateManager.scala b/src/main/scala/leon/solvers/unrolling/TemplateManager.scala index 2857a0cbb6188efb3d82690c588400d886f79b7f..63cf3ea024e48bf7d3cc509b6533826538776a2b 100644 --- a/src/main/scala/leon/solvers/unrolling/TemplateManager.scala +++ b/src/main/scala/leon/solvers/unrolling/TemplateManager.scala @@ -18,6 +18,19 @@ import utils._ import scala.collection.generic.CanBuildFrom +/** A template instatiation + * + * [[Template]] instances, when provided with concrete arguments and a + * blocker, will generate three outputs used for program unfolding: + * - clauses: clauses that will be added to the underlying solver + * - call blockers: bookkeeping information necessary for named + * function unfolding + * - app blockers: bookkeeping information necessary for first-class + * function unfolding + * + * This object provides helper methods to deal with the triplets + * generated during unfolding. + */ object Instantiation { type Clauses[T] = Seq[T] type CallBlockers[T] = Map[T, Set[TemplateCallInfo[T]]] @@ -61,6 +74,11 @@ object Instantiation { import Instantiation.{empty => _, _} import Template.{Apps, Calls, Functions, Arg} +/** Pre-compiled sets of clauses with extra bookkeeping information that enables + * efficient unfolding of function calls and applications. + * [[Template]] is a super-type for all such clause sets that can be instantiated + * given a concrete argument list and a blocker in the decision-tree. + */ trait Template[T] { self => val encoder : TemplateEncoder[T] val manager : TemplateManager[T] @@ -75,7 +93,7 @@ trait Template[T] { self => val clauses : Clauses[T] val blockers : Calls[T] val applications : Apps[T] - val functions : Set[(T, FunctionType, T)] + val functions : Functions[T] val lambdas : Seq[LambdaTemplate[T]] val quantifications : Seq[QuantificationTemplate[T]] @@ -325,10 +343,10 @@ object Template { // suffices to make sure the traversal order is correct. var seen : Set[LambdaTemplate[T]] = Set.empty - val lambdaKeys = lambdas.map(lambda => lambda.ids._1 -> lambda).toMap + val lambdaKeys = lambdas.map(lambda => lambda.ids._2 -> lambda).toMap def extractSubst(lambda: LambdaTemplate[T]): Unit = { for { - dep <- lambda.dependencies.flatMap(p => lambdaKeys.get(p._1)) + dep <- lambda.structure.closures flatMap lambdaKeys.get if !seen(dep) } extractSubst(dep) diff --git a/src/test/resources/regression/verification/purescala/valid/Lambdas2.scala b/src/test/resources/regression/verification/purescala/valid/Lambdas2.scala new file mode 100644 index 0000000000000000000000000000000000000000..05574dd64b7da9c8cf542360e24244989d8ef869 --- /dev/null +++ b/src/test/resources/regression/verification/purescala/valid/Lambdas2.scala @@ -0,0 +1,29 @@ +import leon.lang._ + +object Lambdas2 { + + val init_messages = (i: BigInt) => BigInt(0) + + + def smallNumbers(n: BigInt, messages: BigInt => BigInt)(i: BigInt, j: BigInt) = { + i < n && j < n + } + + def intForAll2(n: BigInt, m: BigInt, p: (BigInt, BigInt) => Boolean): Boolean = { +// forall ((i: BigInt, j: BigInt) => (0 <= i && i < n && 0 <= j && j < n) ==> p(i,j)) + + if (n <= 0 || m <= 0) true + else p(n-1,m-1) && intForAll2(n-1, m, p) && intForAll2(n, m-1, p) + } + + def invariant(n: BigInt, messages: BigInt => BigInt) = { + intForAll2(n, n, smallNumbers(n, messages)) + } + + def theorem(n: BigInt) = { + require(intForAll2(n, n, smallNumbers(n, init_messages))) + + invariant(n, init_messages) + } holds + +}