From acdba3aad40dd1a07427df3e16279b937bd328f5 Mon Sep 17 00:00:00 2001 From: Nicolas Voirol <voirol.nicolas@gmail.com> Date: Mon, 14 Nov 2016 20:09:46 +0100 Subject: [PATCH] Specialization for container adts --- .../inox/solvers/unrolling/BagSuite.scala | 31 ++- src/main/scala/inox/Main.scala | 1 + src/main/scala/inox/ast/Definitions.scala | 43 ++-- src/main/scala/inox/ast/Extractors.scala | 37 +++- src/main/scala/inox/ast/Printers.scala | 1 + src/main/scala/inox/ast/SymbolOps.scala | 209 +++++++++++++----- .../solvers/unrolling/DatatypeTemplates.scala | 5 +- .../solvers/unrolling/FunctionTemplates.scala | 13 +- .../solvers/unrolling/LambdaTemplates.scala | 27 ++- .../unrolling/QuantificationTemplates.scala | 20 +- .../solvers/unrolling/TemplateGenerator.scala | 57 +++-- .../inox/solvers/unrolling/Templates.scala | 90 ++++++-- 12 files changed, 390 insertions(+), 144 deletions(-) diff --git a/src/it/scala/inox/solvers/unrolling/BagSuite.scala b/src/it/scala/inox/solvers/unrolling/BagSuite.scala index af74660a5..09c524c9f 100644 --- a/src/it/scala/inox/solvers/unrolling/BagSuite.scala +++ b/src/it/scala/inox/solvers/unrolling/BagSuite.scala @@ -15,6 +15,7 @@ class BagSuite extends SolvingTestSuite with DatastructureUtils { optSelectedSolvers(Set(solverName)), optCheckModels(true), optFeelingLucky(feelingLucky), + optNoSimplifications(solverName == "smt-cvc4"), optTimeout(300), ast.optPrintUniqueIds(true) ) @@ -57,7 +58,27 @@ class BagSuite extends SolvingTestSuite with DatastructureUtils { }) } - val symbols = baseSymbols.withFunctions(Seq(bag, split)) + val split2ID = FreshIdentifier("split2") + val split2 = mkFunDef(split2ID)("A") { case Seq(aT) => ( + Seq("l" :: List(aT)), T(List(aT), List(aT)), { case Seq(l) => + let( + "res" :: T(List(aT), List(aT)), + if_ (l.isInstOf(Cons(aT)) && l.asInstOf(Cons(aT)).getField(tail).isInstOf(Cons(aT))) { + let( + "tuple" :: T(List(aT), List(aT)), + E(splitID)(aT)(l.asInstOf(Cons(aT)).getField(tail).asInstOf(Cons(aT)).getField(tail)) + ) { tuple => E( + Cons(aT)(l.asInstOf(Cons(aT)).getField(head), tuple._1), + Cons(aT)(l.asInstOf(Cons(aT)).getField(tail).asInstOf(Cons(aT)).getField(head), tuple._2) + )} + } else_ { + E(Nil(aT)(), Nil(aT)()) + } + ) { res => Assume(bag(aT)(l) === BagUnion(bag(aT)(res._1), bag(aT)(res._2)), res) } + }) + } + + val symbols = baseSymbols.withFunctions(Seq(bag, split, split2)) test("Finite model finding 1") { ctx => val program = InoxProgram(ctx, symbols) @@ -102,4 +123,12 @@ class BagSuite extends SolvingTestSuite with DatastructureUtils { assert(SimpleSolverAPI(SolverFactory.default(program)).solveVALID(clause) contains true) } + + test("split2 doesn't preserve content") { ctx => + val program = InoxProgram(ctx, symbols) + val Let(vd, body, Assume(pred, _)) = split2.fullBody + val clause = Let(vd, body, pred) + + assert(SimpleSolverAPI(SolverFactory.default(program)).solveSAT(Not(clause)).isSAT) + } } diff --git a/src/main/scala/inox/Main.scala b/src/main/scala/inox/Main.scala index 09ac0e7fa..6cc0e5fcf 100644 --- a/src/main/scala/inox/Main.scala +++ b/src/main/scala/inox/Main.scala @@ -53,6 +53,7 @@ trait MainHelpers { solvers.unrolling.optUnrollFactor -> "Number of unrollings to perform in each unfold step", solvers.unrolling.optFeelingLucky -> "Use evaluator to find counter-examples early", solvers.unrolling.optUnrollAssumptions -> "Use unsat-assumptions to drive unfolding while remaining fair", + solvers.unrolling.optNoSimplifications -> "Disable selector/quantifier simplifications in solvers", solvers.smtlib.optCVC4Options -> "Pass extra options to CVC4", evaluators.optMaxCalls -> "Maximum number of function invocations allowed during evaluation", evaluators.optIgnoreContracts -> "Don't fail on invalid contracts during evaluation" diff --git a/src/main/scala/inox/ast/Definitions.scala b/src/main/scala/inox/ast/Definitions.scala index 659a94fc8..2c6e87145 100644 --- a/src/main/scala/inox/ast/Definitions.scala +++ b/src/main/scala/inox/ast/Definitions.scala @@ -191,6 +191,22 @@ trait Definitions { self: Trees => /** The root of the class hierarchy */ def root(implicit s: Symbols): ADTDefinition + def isInductive(implicit s: Symbols): Boolean = { + val base = typed + + def rec(adt: TypedADTDefinition, seen: Set[TypedADTDefinition], first: Boolean = false): Boolean = { + if (!first && adt == base) true else if (seen(adt)) false else (adt match { + case tsort: TypedADTSort => tsort.constructors.exists(rec(_, seen + tsort)) + case tcons: TypedADTConstructor => tcons.fieldsTypes.flatMap(tpe => s.typeOps.collect { + case t: ADTType => Set(t.getADT) + case _ => Set.empty[TypedADTDefinition] + } (tpe)).exists(rec(_, seen + tcons)) + }) + } + + rec(base, Set.empty, first = true) + } + /** An invariant that refines this [[ADTDefinition]] */ def invariant(implicit s: Symbols): Option[FunDef] = { val rt = root @@ -222,32 +238,6 @@ trait Definitions { self: Trees => case sort => throw NotWellFormedException(sort) }) - def isInductive(implicit s: Symbols): Boolean = { - def induct(tpe: Type, seen: Set[ADTDefinition]): Boolean = tpe match { - case adt: ADTType => - val tadt = adt.lookupADT.getOrElse(throw ADTLookupException(adt.id)) - val root = tadt.definition.root - seen(root) || { - val constructors = root match { - case tcons: ADTConstructor => Seq(tcons) - case tsort: ADTSort => tsort.constructors - } - - constructors.exists(tcons => tcons.fields.exists(vd => induct(vd.tpe, seen + root))) - } - - case TupleType(tpes) => - tpes.exists(tpe => induct(tpe, seen)) - - case _ => false - } - - if (this == root && !this.isSort) false - else constructors.exists { cons => - cons.fields.exists(vd => induct(vd.getType, Set(root))) - } - } - def root(implicit s: Symbols): ADTDefinition = this def typed(implicit s: Symbols): TypedADTSort = typed(tparams.map(_.tp)) @@ -280,6 +270,7 @@ trait Definitions { self: Trees => val flags: Set[Flag]) extends ADTDefinition { val isSort = false + /** Returns the index of the field with the specified id */ def selectorID2Index(id: Identifier) : Int = { val index = fields.indexWhere(_.id == id) diff --git a/src/main/scala/inox/ast/Extractors.scala b/src/main/scala/inox/ast/Extractors.scala index f09ba2b8b..192054bf0 100644 --- a/src/main/scala/inox/ast/Extractors.scala +++ b/src/main/scala/inox/ast/Extractors.scala @@ -291,7 +291,7 @@ trait Extractors { self: Trees => } object IsTyped { - def unapply[T <: Typed](e: T)(implicit p: Symbols): Option[(T, Type)] = Some((e, e.getType)) + def unapply[T <: Typed](e: T)(implicit s: Symbols): Option[(T, Type)] = Some((e, e.getType)) } def unwrapTuple(e: Expr, isTuple: Boolean)(implicit s: Symbols): Seq[Expr] = e.getType match { @@ -301,7 +301,7 @@ trait Extractors { self: Trees => case tp => sys.error(s"Calling unwrapTuple on non-tuple $e of type $tp") } - def unwrapTuple(e: Expr, expectedSize: Int)(implicit p: Symbols): Seq[Expr] = unwrapTuple(e, expectedSize > 1) + def unwrapTuple(e: Expr, expectedSize: Int)(implicit s: Symbols): Seq[Expr] = unwrapTuple(e, expectedSize > 1) def unwrapTupleType(tp: Type, isTuple: Boolean): Seq[Type] = tp match { case TupleType(subs) if isTuple => subs @@ -311,4 +311,37 @@ trait Extractors { self: Trees => def unwrapTupleType(tp: Type, expectedSize: Int): Seq[Type] = unwrapTupleType(tp, expectedSize > 1) + + object RecordType { + def unapply(tpe: ADTType)(implicit s: Symbols): Option[TypedADTConstructor] = tpe.getADT match { + case tcons: TypedADTConstructor if !tcons.definition.isInductive => Some(tcons) + case tsort: TypedADTSort if tsort.constructors.size == 1 => unapply(tsort.constructors.head.toType) + case _ => None + } + } + + object FunctionContainerType { + def unapply(tpe: Type)(implicit s: Symbols): Boolean = { + def rec(tpe: Type, first: Boolean = false): Boolean = tpe match { + case RecordType(tcons) => tcons.fieldsTypes.exists(rec(_)) + case TupleType(tpes) => tpes.exists(rec(_)) + case _: FunctionType if !first => true + case _ => false + } + + rec(tpe, first = true) + } + } + + object Container { + def unapply(e: Expr)(implicit s: Symbols): Option[(Seq[Expr], Seq[Expr] => Expr)] = e.getType match { + case adt @ RecordType(tcons) => + Some((tcons.fields.map(vd => ADTSelector(e, vd.id)), es => ADT(adt, es))) + + case TupleType(tps) => + Some((tps.indices.map(i => TupleSelect(e, i + 1)).toSeq, es => Tuple(es))) + + case _ => None + } + } } diff --git a/src/main/scala/inox/ast/Printers.scala b/src/main/scala/inox/ast/Printers.scala index af0858fbc..259686819 100644 --- a/src/main/scala/inox/ast/Printers.scala +++ b/src/main/scala/inox/ast/Printers.scala @@ -367,6 +367,7 @@ trait Printers { protected def isSimpleExpr(e: Expr): Boolean = e match { case _: Let => false + case _: Assume => false case _ => true } diff --git a/src/main/scala/inox/ast/SymbolOps.scala b/src/main/scala/inox/ast/SymbolOps.scala index 49cd236dc..990bb8928 100644 --- a/src/main/scala/inox/ast/SymbolOps.scala +++ b/src/main/scala/inox/ast/SymbolOps.scala @@ -54,6 +54,7 @@ trait SymbolOps { self: TypeOps => case And(args) => Some(andJoin(args)) case Or(args) => Some(orJoin(args)) case Tuple(args) => Some(tupleWrap(args)) + case Application(e, es) => Some(application(e, es)) case _ => None } postMap(step)(expr) @@ -142,9 +143,9 @@ trait SymbolOps { self: TypeOps => } } - def rec(vars: Set[Variable], body: Expr): Expr = { + def outer(vars: Set[Variable], body: Expr): Expr = { - class Normalizer extends SelfTreeTransformer { + object normalizer extends SelfTreeTransformer { override def transform(id: Identifier, tpe: Type): (Identifier, Type) = (transformId(id, tpe), tpe) override def transform(e: Expr): Expr = e match { @@ -159,27 +160,30 @@ trait SymbolOps { self: TypeOps => Variable(getId(expr), expr.getType) case f: Forall => - val newBody = rec(vars ++ f.args.map(_.toVariable), f.body) + val newBody = outer(vars ++ f.args.map(_.toVariable), f.body) Forall(f.args.map(vd => vd.copy(id = varSubst(vd.id))), newBody) case l: Lambda => - val newBody = rec(vars ++ l.args.map(_.toVariable), l.body) + val newBody = outer(vars ++ l.args.map(_.toVariable), l.body) Lambda(l.args.map(vd => vd.copy(id = varSubst(vd.id))), newBody) case _ => super.transform(e) } } - val n = new Normalizer - // this registers the argument images into subst - vars foreach n.transform - n.transform(body) + vars foreach (v => transformId(v.id, v.tpe)) + normalizer.transform(body) } - val newExpr = rec(args.map(_.toVariable).toSet, expr) + val newExpr = outer(args.map(_.toVariable).toSet, expr) val bindings = args.map(vd => vd.copy(id = varSubst(vd.id))) - val freeVars = variablesOf(newExpr) -- bindings.map(_.toVariable) + + val bindingVars = bindings.map(_.toVariable).toSet + val freeVars = fixpoint { (vs: Set[Variable]) => + vs ++ subst.filter(p => vs(p._1)).flatMap(p => variablesOf(p._2)) -- bindingVars + } (variablesOf(newExpr) -- bindings.map(_.toVariable)) + val bodySubst = subst.filter(p => freeVars(p._1)).toMap (bindings, newExpr, bodySubst) @@ -242,6 +246,29 @@ trait SymbolOps { self: TypeOps => fixpoint(inline)(e) } + def inlineQuantifications(e: Expr): Expr = postMap { + case Forall(args1, Forall(args2, body)) => Some(Forall(args1 ++ args2, body)) + case a @ Assume(pred, body) => + val vars = variablesOf(a) + var assumptions: Seq[Expr] = Seq.empty + object transformer extends transformers.TransformerWithPC { + val trees: self.trees.type = self.trees + val symbols: self.symbols.type = self.symbols + val initEnv = Path.empty + + override protected def rec(e: Expr, path: Path): Expr = e match { + case Assume(pred, body) if (variablesOf(pred) ++ path.variables) subsetOf vars => + assumptions :+= path implies pred + rec(body, path withCond pred) + case _ => super.rec(e, path) + } + } + val newPred = transformer.transform(pred) + val newBody = transformer.transform(body) + Some(Assume(andJoin(newPred +: assumptions), newBody)) + case _ => None + } (e) + /* Weaker variant of disjunctive normal form */ def normalizeClauses(e: Expr): Expr = e match { case Not(Not(e)) => normalizeClauses(e) @@ -251,7 +278,7 @@ trait SymbolOps { self: TypeOps => case _ => e } - normalizeClauses(inlineFunctions(e)) + normalizeClauses(inlineQuantifications(inlineFunctions(e))) } def simplifyLets(expr: Expr): Expr = postMap({ @@ -530,30 +557,46 @@ trait SymbolOps { self: TypeOps => } object InvocationExtractor { - private def flatInvocation(expr: Expr): Option[(Identifier, Seq[Type], Seq[Expr])] = expr match { - case fi @ FunctionInvocation(id, tps, args) => Some((id, tps, args)) - case Application(caller, args) => flatInvocation(caller) match { - case Some((id, tps, prevArgs)) => Some((id, tps, prevArgs ++ args)) - case None => None + type Invocation = (Identifier, Seq[Type], Seq[Either[Identifier, Int]], Seq[Expr]) + + private def flatSelectors(expr: Expr): Option[Invocation] = expr match { + case ADTSelector(IsTyped(e, FunctionContainerType()), sid) => flatSelectors(e).map { + case (id, tps, path, args) => (id, tps, path :+ Left(sid), args) + } + case TupleSelect(IsTyped(e, FunctionContainerType()), i) => flatSelectors(e).map { + case (id, tps, path, args) => (id, tps, path :+ Right(i), args) } + case fi @ FunctionInvocation(id, tps, args) => Some((id, tps, Seq.empty, args)) case _ => None } - def unapply(expr: Expr): Option[(Identifier, Seq[Type], Seq[Expr])] = expr match { + private def flatInvocation(expr: Expr, specialize: Boolean): Option[Invocation] = expr match { + case fi @ FunctionInvocation(id, tps, args) => Some((id, tps, Seq.empty, args)) + case Application(caller, args) => flatInvocation(caller, specialize) match { + case Some((id, tps, path, prevArgs)) => Some((id, tps, path, prevArgs ++ args)) + case None => None + } + case _ => if (specialize) flatSelectors(expr) else None + } + + def extract(expr: Expr, specialize: Boolean = true): Option[Invocation] = expr match { case IsTyped(f: FunctionInvocation, ft: FunctionType) => None case IsTyped(f: Application, ft: FunctionType) => None - case FunctionInvocation(id, tps, args) => Some((id, tps, args)) - case f: Application => flatInvocation(f) + case IsTyped(f: FunctionInvocation, FunctionContainerType()) if specialize => None + case FunctionInvocation(id, tps, args) => Some((id, tps, Seq.empty, args)) + case f: Application => flatInvocation(f, specialize) case _ => None } + + object Specialized { def unapply(expr: Expr): Option[Invocation] = extract(expr, specialize = true) } + object Unspecialized { def unapply(expr: Expr): Option[Invocation] = extract(expr, specialize = false) } } - def firstOrderCallsOf(expr: Expr): Set[(Identifier, Seq[Type], Seq[Expr])] = - collect { e => InvocationExtractor.unapply(e).toSet[(Identifier, Seq[Type], Seq[Expr])] }(expr) + def firstOrderCallsOf(expr: Expr, specialize: Boolean = true): Set[InvocationExtractor.Invocation] = + collect { e => InvocationExtractor.extract(e, specialize).toSet }(expr) object ApplicationExtractor { private def flatApplication(expr: Expr): Option[(Expr, Seq[Expr])] = expr match { - case Application(fi: FunctionInvocation, _) => None case Application(caller: Application, args) => flatApplication(caller) match { case Some((c, prevArgs)) => Some((c, prevArgs ++ args)) case None => None @@ -562,68 +605,116 @@ trait SymbolOps { self: TypeOps => case _ => None } - def unapply(expr: Expr): Option[(Expr, Seq[Expr])] = expr match { + def extract(expr: Expr, specialize: Boolean = true): Option[(Expr, Seq[Expr])] = expr match { case IsTyped(f: Application, ft: FunctionType) => None - case f: Application => flatApplication(f) - case _ => None + case _ => InvocationExtractor.extract(expr, specialize) match { + case Some(_) => None + case None => expr match { + case f: Application => flatApplication(f) + case _ => None + } + } } + + object Specialized { def unapply(expr: Expr): Option[(Expr, Seq[Expr])] = extract(expr, specialize = true) } + object Unspecialized { def unapply(expr: Expr): Option[(Expr, Seq[Expr])] = extract(expr, specialize = false) } } - def firstOrderAppsOf(expr: Expr): Set[(Expr, Seq[Expr])] = - collect[(Expr, Seq[Expr])] { - case ApplicationExtractor(caller, args) => Set(caller -> args) - case _ => Set.empty - } (expr) + def firstOrderAppsOf(expr: Expr, specialize: Boolean = true): Set[(Expr, Seq[Expr])] = + collect[(Expr, Seq[Expr])](e => ApplicationExtractor.extract(e).toSet)(expr) - def simplifyHOFunctions(expr: Expr): Expr = { + def simplifyHOFunctions(expr: Expr, simplify: Boolean = true): Expr = { - def liftToLambdas(expr: Expr) = { - def apply(expr: Expr, args: Seq[Expr]): Expr = expr match { - case IfExpr(cond, thenn, elze) => - IfExpr(cond, apply(thenn, args), apply(elze, args)) - case Let(i, e, b) => - Let(i, e, apply(b, args)) - //case l @ Lambda(params, body) => - // l.withParamSubst(args, body) - case _ => Application(expr, args) - } + def pushDown(expr: Expr, recons: Expr => Expr): Expr = expr match { + case IfExpr(cond, thenn, elze) => + IfExpr(cond, pushDown(thenn, recons), pushDown(elze, recons)) + case Let(i, e, b) => + Let(i, e, pushDown(b, recons)) + case Assume(pred, body) => + Assume(pred, pushDown(body, recons)) + case _ => recons(expr) + } + + def traverse(expr: Expr, lift: Expr => Expr): Expr = { + def extract(expr: Expr, build: Boolean) = if (build) lift(expr) else expr + + def rec(expr: Expr, build: Boolean): Expr = extract(expr match { + case Application(caller, args) => + val newArgs = args.map(rec(_, true)) + val newCaller = rec(caller, false) + Application(newCaller, newArgs) + case FunctionInvocation(id, tps, args) => + val newArgs = args.map(rec(_, true)) + FunctionInvocation(id, tps, newArgs) + case l @ Lambda(args, body) => + val newBody = rec(body, true) + Lambda(args, newBody) + case Deconstructor(es, recons) => recons(es.map(rec(_, build))) + }, build) + rec(lift(expr), true) + } + + def liftToLambdas(expr: Expr) = { def lift(expr: Expr): Expr = expr.getType match { case FunctionType(from, to) => expr match { case _ : Lambda => expr case _ : Variable => expr case e => val args = from.map(tpe => ValDef(FreshIdentifier("x", true), tpe)) - val application = apply(expr, args.map(_.toVariable)) + val application = pushDown(expr, Application(_, args.map(_.toVariable))) Lambda(args, lift(application)) } case _ => expr } - def extract(expr: Expr, build: Boolean) = if (build) lift(expr) else expr + traverse(expr, lift) + } - def rec(expr: Expr, build: Boolean): Expr = expr match { - case Application(caller, args) => - val newArgs = args.map(rec(_, true)) - val newCaller = rec(caller, false) - extract(Application(newCaller, newArgs), build) - case FunctionInvocation(id, tps, args) => - val newArgs = args.map(rec(_, true)) - extract(FunctionInvocation(id, tps, newArgs), build) - case l @ Lambda(args, body) => - val newBody = rec(body, true) - extract(Lambda(args, newBody), build) - case Deconstructor(es, recons) => recons(es.map(rec(_, build))) + def liftContainers(expr: Expr): Expr = { + def lift(expr: Expr): Expr = expr.getType match { + case tpe @ FunctionContainerType() => expr match { + case _ : ADT => expr + case _ : Tuple => expr + case _ : Variable => expr + case e => tpe match { + case adt @ RecordType(tcons) => + val castExpr = if (tcons.id == adt.id) expr else AsInstanceOf(expr, tcons.toType) + val fields = tcons.fields.map(vd => pushDown(castExpr, ADTSelector(_, vd.id))) + ADT(tcons.toType, fields) + case TupleType(tpes) => + Tuple(tpes.indices.map(i => pushDown(expr, TupleSelect(_, i + 1)))) + } + } + case _ => expr } - rec(lift(expr), true) + traverse(expr, lift) } - liftToLambdas(expr) + def simplifyContainers(expr: Expr): Expr = postMap { + case ADTSelector(IsTyped(e, FunctionContainerType()), id) => + val newExpr = pushDown(e, adtSelector(_, id)) + if (newExpr != expr) Some(newExpr) else None + case TupleSelect(IsTyped(e, FunctionContainerType()), i) => + val newExpr = pushDown(e, tupleSelect(_, i, true)) + if (newExpr != expr) Some(newExpr) else None + case _ => None + } (expr) + + liftToLambdas(if (simplify) { + simplifyContainers(liftContainers(expr)) + } else { + expr + }) } - def simplifyFormula(e: Expr): Expr = { - simplifyHOFunctions(simplifyByConstructors(simplifyQuantifications(e))) + def simplifyFormula(e: Expr, simplify: Boolean = true): Expr = { + if (simplify) { + fixpoint((e: Expr) => simplifyHOFunctions(simplifyByConstructors(simplifyQuantifications(e))))(e) + } else { + simplifyHOFunctions(e, simplify = false) + } } // Use this only to debug isValueOfType diff --git a/src/main/scala/inox/solvers/unrolling/DatatypeTemplates.scala b/src/main/scala/inox/solvers/unrolling/DatatypeTemplates.scala index 5106ccfef..96774c91d 100644 --- a/src/main/scala/inox/solvers/unrolling/DatatypeTemplates.scala +++ b/src/main/scala/inox/solvers/unrolling/DatatypeTemplates.scala @@ -286,8 +286,8 @@ trait DatatypeTemplates { self: Templates => var callInfos : Set[Call] = Set.empty for (e <- es) { - callInfos ++= firstOrderCallsOf(e).map { case (id, tps, args) => - Call(getFunction(id, tps), args.map(arg => Left(encoder(arg)))) + callInfos ++= firstOrderCallsOf(e).map { case (id, tps, path, args) => + Call(getFunction(id, tps), path, args.map(arg => Left(encoder(arg)))) } clauses :+= encoder(Implies(b, e)) @@ -460,6 +460,7 @@ trait DatatypeTemplates { self: Templates => val lambdas = Seq.empty[LambdaTemplate] val matchers = Map.empty[Encoded, Set[Matcher]] val quantifications = Seq.empty[QuantificationTemplate] + val pointers = Map.empty[Encoded, Encoded] override def instantiate(substMap: Map[Encoded, Arg]): Clauses = { val substituter = mkSubstituter(substMap.mapValues(_.encoded)) diff --git a/src/main/scala/inox/solvers/unrolling/FunctionTemplates.scala b/src/main/scala/inox/solvers/unrolling/FunctionTemplates.scala index 29965f7df..4827a6ed0 100644 --- a/src/main/scala/inox/solvers/unrolling/FunctionTemplates.scala +++ b/src/main/scala/inox/solvers/unrolling/FunctionTemplates.scala @@ -19,6 +19,7 @@ trait FunctionTemplates { self: Templates => def apply( tfd: TypedFunDef, + path: Seq[Either[Identifier, Int]], pathVar: (Variable, Encoded), arguments: Seq[(Variable, Encoded)], condVars: Map[Variable, Encoded], @@ -30,9 +31,9 @@ trait FunctionTemplates { self: Templates => quantifications: Seq[QuantificationTemplate] ) : FunctionTemplate = { - val (clauses, blockers, applications, matchers, templateString) = + val (clauses, blockers, applications, matchers, pointers, templateString) = Template.encode(pathVar, arguments, condVars, exprVars, guardedExprs, equations, - lambdas, quantifications, optCall = Some(tfd)) + lambdas, quantifications, optCall = Some(tfd -> path)) val funString : () => String = () => { "Template for def " + tfd.signature + @@ -41,7 +42,6 @@ trait FunctionTemplates { self: Templates => } new FunctionTemplate( - tfd, pathVar, arguments.map(_._2), condVars, @@ -53,13 +53,13 @@ trait FunctionTemplates { self: Templates => matchers, lambdas, quantifications, + pointers, funString ) } } class FunctionTemplate private( - val tfd: TypedFunDef, val pathVar: (Variable, Encoded), val args: Seq[Encoded], val condVars: Map[Variable, Encoded], @@ -71,6 +71,7 @@ trait FunctionTemplates { self: Templates => val matchers: Matchers, val lambdas: Seq[LambdaTemplate], val quantifications: Seq[QuantificationTemplate], + val pointers: Map[Encoded, Encoded], stringRepr: () => String) extends Template { private lazy val str : String = stringRepr() @@ -136,7 +137,7 @@ trait FunctionTemplates { self: Templates => for ((blocker, (gen, _, _, calls)) <- thisCallInfos if calls.nonEmpty && !interrupted; _ = remainingBlockers -= blocker; - call @ Call(tfd, args) <- calls) { + call @ Call(tfd, path, args) <- calls) { val newCls = new scala.collection.mutable.ListBuffer[Encoded] val defBlocker = defBlockers.get(call) match { @@ -149,7 +150,7 @@ trait FunctionTemplates { self: Templates => val defBlocker = encodeSymbol(Variable(FreshIdentifier("d", true), BooleanType)) defBlockers += call -> defBlocker - val template = mkTemplate(tfd) + val template = mkTemplate(tfd, path) //reporter.debug(template) val newExprs = template.instantiate(defBlocker, args) diff --git a/src/main/scala/inox/solvers/unrolling/LambdaTemplates.scala b/src/main/scala/inox/solvers/unrolling/LambdaTemplates.scala index e82f6c25c..879305f93 100644 --- a/src/main/scala/inox/solvers/unrolling/LambdaTemplates.scala +++ b/src/main/scala/inox/solvers/unrolling/LambdaTemplates.scala @@ -69,7 +69,7 @@ trait LambdaTemplates { self: Templates => val id = ids._2 val tpe = ids._1.getType.asInstanceOf[FunctionType] - val (clauses, blockers, applications, matchers, templateString) = + val (clauses, blockers, applications, matchers, pointers, templateString) = Template.encode(pathVar, arguments, condVars, exprVars, guardedExprs, equations, lambdas, quantifications, substMap = baseSubstMap + ids, optApp = Some(id -> tpe)) @@ -81,7 +81,7 @@ trait LambdaTemplates { self: Templates => ids, pathVar, arguments, condVars, exprVars, condTree, clauses, blockers, applications, matchers, - lambdas, quantifications, + lambdas, quantifications, pointers, structure, lambda, lambdaString, false ) @@ -140,7 +140,8 @@ trait LambdaTemplates { self: Templates => val applications: Apps, val matchers: Matchers, val lambdas: Seq[LambdaTemplate], - val quantifications: Seq[QuantificationTemplate]) { + val quantifications: Seq[QuantificationTemplate], + val pointers: Map[Encoded, Encoded]) { def substitute(substituter: Encoded => Encoded, msubst: Map[Encoded, Matcher]) = new LambdaStructure( lambda, @@ -152,7 +153,8 @@ trait LambdaTemplates { self: Templates => applications.map { case (b, fas) => substituter(b) -> fas.map(_.substitute(substituter, msubst)) }, matchers.map { case (b, ms) => substituter(b) -> ms.map(_.substitute(substituter, msubst)) }, lambdas.map(_.substitute(substituter, msubst)), - quantifications.map(_.substitute(substituter, msubst))) + quantifications.map(_.substitute(substituter, msubst)), + pointers.map(p => substituter(p._1) -> substituter(p._2))) /** The [[key]] value (tuple of [[lambda]] and [[dependencies]]) is used * to determine syntactic equality between lambdas. If the keys of two @@ -170,8 +172,8 @@ trait LambdaTemplates { self: Templates => * in those handled by the solver. */ lazy val (key, instantiation, locals, instantiationSubst) = { - val (substMap, substInst) = Template.substitution( - condVars, exprVars, condTree, lambdas, quantifications, Map.empty, pathVar._2) + val (substMap, substInst) = Template.substitution(condVars, exprVars, condTree, + lambdas, quantifications, pointers, Map.empty, pathVar._2) val tmplInst = Template.instantiate(clauses, blockers, applications, matchers, substMap) val instantiation = substInst ++ tmplInst @@ -209,6 +211,7 @@ trait LambdaTemplates { self: Templates => val matchers: Matchers, val lambdas: Seq[LambdaTemplate], val quantifications: Seq[QuantificationTemplate], + val pointers: Map[Encoded, Encoded], val structure: LambdaStructure, val lambda: Lambda, private[unrolling] val stringRepr: () => String, @@ -227,6 +230,7 @@ trait LambdaTemplates { self: Templates => matchers.map { case (b, ms) => substituter(b) -> ms.map(_.substitute(substituter, msubst)) }, lambdas.map(_.substitute(substituter, msubst)), quantifications.map(_.substitute(substituter, msubst)), + pointers.map(p => substituter(p._1) -> substituter(p._2)), structure.substitute(substituter, msubst), lambda, stringRepr, isConcrete) @@ -243,6 +247,7 @@ trait LambdaTemplates { self: Templates => matchers.map { case (b, ms) => b -> ms.map(_.substitute(substituter, Map.empty)) }, lambdas.map(_.substitute(substituter, Map.empty)), quantifications.map(_.substitute(substituter, Map.empty)), + pointers.map(p => substituter(p._1) -> substituter(p._2)), structure, lambda, stringRepr, true) } @@ -305,6 +310,14 @@ trait LambdaTemplates { self: Templates => } } + def registerLambda(pointer: Encoded, target: Encoded): Boolean = byID.get(target) match { + case Some(template) => + byID += pointer -> template + true + case None => + false + } + def instantiateLambda(template: LambdaTemplate): (Encoded, Clauses) = { byType(template.tpe).get(template.structure).map { t => (t.ids._2, Seq.empty) @@ -571,7 +584,7 @@ trait LambdaTemplates { self: Templates => for ((app, (gen, infos)) <- thisAppInfos if remainingApps(app)) appInfos.get(app) match { case Some((newGen, origGen, b, notB, newInfos)) => appInfos += app -> (gen min newGen, origGen, b, notB, infos ++ newInfos) - + case None => val b = appBlockers(app) val notB = mkNot(b) diff --git a/src/main/scala/inox/solvers/unrolling/QuantificationTemplates.scala b/src/main/scala/inox/solvers/unrolling/QuantificationTemplates.scala index f30efa6cb..26792dbf8 100644 --- a/src/main/scala/inox/solvers/unrolling/QuantificationTemplates.scala +++ b/src/main/scala/inox/solvers/unrolling/QuantificationTemplates.scala @@ -103,6 +103,7 @@ trait QuantificationTemplates { self: Templates => val matchers: Matchers, val lambdas: Seq[LambdaTemplate], val quantifications: Seq[QuantificationTemplate], + val pointers: Map[Encoded, Encoded], val key: (Encoded, Seq[ValDef], Expr, Seq[Encoded]), val body: Expr, stringRepr: () => String) { @@ -122,6 +123,7 @@ trait QuantificationTemplates { self: Templates => matchers.map { case (b, ms) => substituter(b) -> ms.map(_.substitute(substituter, msubst)) }, lambdas.map(_.substitute(substituter, msubst)), quantifications.map(_.substitute(substituter, msubst)), + pointers.map(p => substituter(p._1) -> substituter(p._2)), (substituter(key._1), key._2, key._3, key._4.map(substituter)), body, stringRepr) @@ -192,7 +194,7 @@ trait QuantificationTemplates { self: Templates => // the encoded clauses use the `guard` as blocker instead of `pathVar._2`. This only // works due to [[Template.encode]] injecting `pathVar` BEFORE `substMap` into the // global encoding substitution. - val (clauses, blockers, applications, matchers, templateString) = + val (clauses, blockers, applications, matchers, pointers, templateString) = Template.encode(pathVar, quantifiers, condVars, exprVars, extraGuarded merge guardedExprs, extraEqs ++ equations, lambdas, quantifications, substMap = substMap ++ extraSubst) @@ -202,8 +204,8 @@ trait QuantificationTemplates { self: Templates => (optVar, new QuantificationTemplate( pathVar, polarity, quantifiers, condVars, exprVars, condTree, clauses, - blockers, applications, matchers, lambdas, quantifications, key, proposition.body, - () => "Template for " + proposition + " is :\n" + templateString())) + blockers, applications, matchers, lambdas, quantifications, pointers, key, + proposition.body, () => "Template for " + proposition + " is :\n" + templateString())) } } @@ -523,6 +525,7 @@ trait QuantificationTemplates { self: Templates => val matchers: Matchers val lambdas: Seq[LambdaTemplate] val quantifications: Seq[QuantificationTemplate] + val pointers: Map[Encoded, Encoded] val holds: Encoded val body: Expr @@ -700,7 +703,7 @@ trait QuantificationTemplates { self: Templates => val baseSubst = subst ++ instanceSubst(enabler).mapValues(Left(_)) val (substMap, substClauses) = Template.substitution( - condVars, exprVars, condTree, lambdas, quantifications, baseSubst, enabler) + condVars, exprVars, condTree, lambdas, quantifications, pointers, baseSubst, enabler) instantiation ++= substClauses val msubst = substMap.collect { case (c, Right(m)) => c -> m } @@ -807,6 +810,7 @@ trait QuantificationTemplates { self: Templates => val matchers: Matchers, val lambdas: Seq[LambdaTemplate], val quantifications: Seq[QuantificationTemplate], + val pointers: Map[Encoded, Encoded], val body: Expr) extends Quantification { private var _currentQ2Var: Encoded = qs._2 @@ -848,6 +852,7 @@ trait QuantificationTemplates { self: Templates => val matchers: Matchers, val lambdas: Seq[LambdaTemplate], val quantifications: Seq[QuantificationTemplate], + val pointers: Map[Encoded, Encoded], val body: Expr) extends Quantification { val holds = trueT @@ -901,6 +906,7 @@ trait QuantificationTemplates { self: Templates => } merge Map(guard -> Set(matcher)), template.lambdas.map(_.substitute(substituter, Map.empty)), template.quantifications.map(_.substitute(substituter, Map.empty)), + template.pointers.map(p => substituter(p._1) -> substituter(p._2)), key, body, template.stringRepr))._2 // mapping is guaranteed empty!! } } @@ -917,7 +923,7 @@ trait QuantificationTemplates { self: Templates => val axiom = new Axiom(template.pathVar._2, guard, template.quantifiers, template.condVars, template.exprVars, template.condTree, template.clauses, template.blockers, template.applications, template.matchers, - template.lambdas, template.quantifications, template.body) + template.lambdas, template.quantifications, template.pointers, template.body) quantifications += axiom @@ -933,7 +939,7 @@ trait QuantificationTemplates { self: Templates => val instT = encodeSymbol(insts._1) val (substMap, substClauses) = Template.substitution( template.condVars, template.exprVars, template.condTree, - template.lambdas, template.quantifications, + template.lambdas, template.quantifications, template.pointers, Map(insts._2 -> Left(instT)), template.pathVar._2) clauses ++= substClauses @@ -952,7 +958,7 @@ trait QuantificationTemplates { self: Templates => template.quantifiers, template.condVars, template.exprVars, template.condTree, template.clauses map substituter, // one clause depends on 'qs._2' (and therefore 'qT') template.blockers, template.applications, template.matchers, - template.lambdas, template.quantifications, template.body) + template.lambdas, template.quantifications, template.pointers, template.body) quantifications += quantification diff --git a/src/main/scala/inox/solvers/unrolling/TemplateGenerator.scala b/src/main/scala/inox/solvers/unrolling/TemplateGenerator.scala index 3b01c100e..ff73de921 100644 --- a/src/main/scala/inox/solvers/unrolling/TemplateGenerator.scala +++ b/src/main/scala/inox/solvers/unrolling/TemplateGenerator.scala @@ -36,18 +36,37 @@ trait TemplateGenerator { self: Templates => } } - private val cache: MutableMap[TypedFunDef, FunctionTemplate] = MutableMap.empty + type SelectorPath = Seq[Either[Identifier, Int]] + private val cache: MutableMap[(TypedFunDef, SelectorPath), FunctionTemplate] = MutableMap.empty - def mkTemplate(tfd: TypedFunDef): FunctionTemplate = { - if (cache contains tfd) { - return cache(tfd) + def mkTemplate(tfd: TypedFunDef, path: SelectorPath): FunctionTemplate = { + if (cache contains (tfd -> path)) { + return cache(tfd -> path) } - val lambdaBody : Expr = simplifyFormula(tfd.fullBody) + val lambdaBody: Expr = { + def rec(e: Expr, path: SelectorPath): Expr = (e, path) match { + case (ADT(tpe, es), Left(id) +: tail) => + rec(es(tpe.getADT.toConstructor.definition.selectorID2Index(id)), tail) + case (Tuple(es), Right(i) +: tail) => + rec(es(i - 1), tail) + case _ => e + } + + rec(simplifyFormula(tfd.fullBody, simplify), path) + } val funDefArgs: Seq[Variable] = tfd.params.map(_.toVariable) val lambdaArguments: Seq[Variable] = lambdaArgs(lambdaBody) - val invocation : Expr = tfd.applied(funDefArgs) + val invocation: Expr = { + def rec(e: Expr, path: SelectorPath): Expr = path match { + case Left(id) +: tail => rec(ADTSelector(e, id), tail) + case Right(i) +: tail => rec(TupleSelect(e, i), tail) + case _ => e + } + + rec(tfd.applied(funDefArgs), path) + } val invocationEqualsBody : Seq[Expr] = liftedEquals(invocation, lambdaBody, lambdaArguments) :+ Equals(invocation, lambdaBody) @@ -63,23 +82,29 @@ trait TemplateGenerator { self: Templates => val (condVars, exprVars, condTree, guardedExprs, eqs, lambdas, quantifications) = invocationEqualsBody.foldLeft(emptyClauses)((clsSet, cls) => clsSet ++ mkClauses(start, cls, substMap)) - val template = FunctionTemplate(tfd, pathVar, arguments, + val template = FunctionTemplate(tfd, path, pathVar, arguments, condVars, exprVars, condTree, guardedExprs, eqs, lambdas, quantifications) - cache += tfd -> template + cache += (tfd, path) -> template template } private def lambdaArgs(expr: Expr): Seq[Variable] = expr match { case Lambda(args, body) => args.map(_.toVariable.freshen) ++ lambdaArgs(body) - case IsTyped(_, _: FunctionType) => sys.error("Only applicable on lambda chains") + case Assume(pred, body) => lambdaArgs(body) + case IsTyped(_, _: FunctionType) => sys.error("Only applicable on lambda/assume chains") case _ => Seq.empty } private def liftedEquals(invocation: Expr, body: Expr, args: Seq[Variable], inlineFirst: Boolean = false): Seq[Expr] = { def rec(i: Expr, b: Expr, args: Seq[Variable], inline: Boolean): Seq[Expr] = i.getType match { case FunctionType(from, to) => + def apply(e: Expr, es: Seq[Expr]): Expr = e match { + case _: Lambda if inline => application(e, es) + case Assume(pred, l: Lambda) if inline => Assume(pred, application(l, es)) + case _ => Application(e, es) + } + val (currArgs, nextArgs) = args.splitAt(from.size) - val apply = if (inline) application _ else Application val (appliedInv, appliedBody) = (apply(i, currArgs), apply(b, currArgs)) rec(appliedInv, appliedBody, nextArgs, false) :+ Equals(appliedInv, appliedBody) case _ => @@ -287,8 +312,9 @@ trait TemplateGenerator { self: Templates => case _ => (Seq.empty, e) } - extractBody(struct) match { - case (params, app @ ApplicationExtractor(caller: Variable, args)) => + val (params, app) = extractBody(struct) + ApplicationExtractor.extract(app, simplify) match { + case Some((caller: Variable, args)) => !app.getType.isInstanceOf[FunctionType] && (params.map(_.toVariable) == args) && (deps.get(caller) match { @@ -319,7 +345,7 @@ trait TemplateGenerator { self: Templates => } } - val (depClauses, depCalls, depApps, depMatchers, _) = Template.encode( + val (depClauses, depCalls, depApps, depMatchers, depPointers, _) = Template.encode( pathVar -> encodedCond(pathVar), Seq.empty, depConds, depExprs, depGuarded, depEqs, depLambdas, depQuants, depSubst) @@ -335,8 +361,8 @@ trait TemplateGenerator { self: Templates => val dependencies = sortedDeps.map(p => depSubst(p._1)) val structure = new LambdaStructure( - struct, dependencies, pathVar -> encodedCond(pathVar), depClosures, - depConds, depExprs, depTree, depClauses, depCalls, depApps, depMatchers, depLambdas, depQuants) + struct, dependencies, pathVar -> encodedCond(pathVar), depClosures, depConds, depExprs, depTree, + depClauses, depCalls, depApps, depMatchers, depLambdas, depQuants, depPointers) val realLambda = if (isNormalForm) l else struct val lid = Variable(FreshIdentifier("lambda", true), l.getType) @@ -383,5 +409,4 @@ trait TemplateGenerator { self: Templates => val p = rec(pathVar, expr, polarity) (p, (condVars, exprVars, condTree, guardedExprs, eqs, lambdas, quantifications)) } - } diff --git a/src/main/scala/inox/solvers/unrolling/Templates.scala b/src/main/scala/inox/solvers/unrolling/Templates.scala index 95f9f7b56..4afc1a829 100644 --- a/src/main/scala/inox/solvers/unrolling/Templates.scala +++ b/src/main/scala/inox/solvers/unrolling/Templates.scala @@ -8,6 +8,8 @@ import utils._ import scala.collection.generic.CanBuildFrom +object optNoSimplifications extends FlagOptionDef("nosimplifications", false) + trait Templates extends TemplateGenerator with FunctionTemplates with DatatypeTemplates @@ -39,6 +41,8 @@ trait Templates extends TemplateGenerator private[unrolling] lazy val trueT = mkEncoder(Map.empty)(BooleanLiteral(true)) private[unrolling] lazy val falseT = mkEncoder(Map.empty)(BooleanLiteral(false)) + protected lazy val simplify = !ctx.options.findOptionOrDefault(optNoSimplifications) + private var currentGen: Int = 0 protected def currentGeneration: Int = currentGen protected def nextGeneration(gen: Int): Int = gen + 3 @@ -209,12 +213,23 @@ trait Templates extends TemplateGenerator } /** Represents a named function call in the unfolding procedure */ - case class Call(tfd: TypedFunDef, args: Seq[Arg]) { + case class Call(tfd: TypedFunDef, path: Seq[Either[Identifier, Int]], args: Seq[Arg]) { override def toString = { - tfd.signature + args.map { - case Right(m) => m.toString - case Left(v) => asString(v) - }.mkString("(", ", ", ")") + tfd.signature + { + val (fdArgs, appArgs) = args.splitAt(tfd.params.size) + def pArgs(args: Seq[Arg]) = if (args.isEmpty) "" else args.map { + case Right(m) => m.toString + case Left(v) => asString(v) + }.mkString("(", ", ", ")") + + pArgs(fdArgs) + + (if (path.nonEmpty) "." else "") + + path.map { + case Left(id) => id.asString + case Right(i) => "_" + i + }.mkString(".") + + pArgs(appArgs) + } } def substitute(substituter: Encoded => Encoded, msubst: Map[Encoded, Matcher]): Call = copy( @@ -309,11 +324,13 @@ trait Templates extends TemplateGenerator val lambdas : Seq[LambdaTemplate] val quantifications : Seq[QuantificationTemplate] + val pointers : Map[Encoded, Encoded] + lazy val start = pathVar._2 def instantiate(aVar: Encoded, args: Seq[Arg]): Clauses = { val (substMap, clauses) = Template.substitution( - condVars, exprVars, condTree, lambdas, quantifications, + condVars, exprVars, condTree, lambdas, quantifications, pointers, (this.args zip args).toMap + (start -> Left(aVar)), aVar) clauses ++ instantiate(substMap) } @@ -333,6 +350,12 @@ trait Templates extends TemplateGenerator caller } + private[unrolling] def mkSelection(expr: Expr, path: Seq[Either[Identifier, Int]]): Expr = path match { + case Left(id) +: tail => mkSelection(ADTSelector(expr, id), tail) + case Right(i) +: tail => mkSelection(TupleSelect(expr, i), tail) + case _ => expr + } + object Template { def encode( @@ -345,9 +368,9 @@ trait Templates extends TemplateGenerator lambdas: Seq[LambdaTemplate], quantifications: Seq[QuantificationTemplate], substMap: Map[Variable, Encoded] = Map.empty[Variable, Encoded], - optCall: Option[TypedFunDef] = None, + optCall: Option[(TypedFunDef, Seq[Either[Identifier, Int]])] = None, optApp: Option[(Encoded, FunctionType)] = None - ) : (Clauses, Calls, Apps, Matchers, () => String) = { + ) : (Clauses, Calls, Apps, Matchers, Map[Encoded, Encoded], () => String) = { val idToTrId : Map[Variable, Encoded] = condVars ++ exprVars + pathVar ++ arguments ++ substMap ++ @@ -355,19 +378,44 @@ trait Templates extends TemplateGenerator val encoder : Expr => Encoded = mkEncoder(idToTrId) - val optIdCall = optCall.map(tfd => Call(tfd, arguments.map(p => Left(p._2)))) + val optIdCall = optCall.map { case (tfd, path) => Call(tfd, path, arguments.map(p => Left(p._2))) } val optIdApp = optApp.map { case (idT, tpe) => val v = Variable(FreshIdentifier("x", true), tpe) val encoded = mkEncoder(Map(v -> idT) ++ arguments)(mkApplication(v, arguments.map(_._1))) App(idT, bestRealType(tpe).asInstanceOf[FunctionType], arguments.map(p => Left(p._2)), encoded) } - lazy val optIdMatcher = optCall.map { tfd => + lazy val optIdMatcher = optCall.map { case (tfd, path) => val (fiArgs, appArgs) = arguments.map(_._1).splitAt(tfd.params.size) - val encoded = mkEncoder(arguments.toMap)(mkApplication(tfd.applied(fiArgs), appArgs)) + val encoded = mkEncoder(arguments.toMap)(mkApplication(mkSelection(tfd.applied(fiArgs), path), appArgs)) Matcher(Right(tfd), arguments.map(p => Left(p._2)), encoded) } + def lambdaPointers(expr: Expr): Map[Expr, Variable] = { + def collectSelectors(expr: Expr, ptr: Expr): Seq[(Expr, Variable)] = expr match { + case ADT(tpe, es) => (tpe.getADT.toConstructor.fields zip es).flatMap { + case (vd, e) => collectSelectors(e, ADTSelector(ptr, vd.id)) + } + + case Tuple(es) => es.zipWithIndex.flatMap { + case (e, i) => collectSelectors(e, TupleSelect(ptr, i + 1)) + } + + case IsTyped(v: Variable, _: FunctionType) => Seq(ptr -> v) + case _ => Seq.empty + } + + exprOps.collect { + case Equals(v: Variable, e) => collectSelectors(e, v).toSet + case FunctionInvocation(_, _, es) => es.flatMap(e => collectSelectors(e, e)).toSet + case Application(_, es) => es.flatMap(e => collectSelectors(e, e)).toSet + case _ => Set.empty[(Expr, Variable)] + } (expr).toMap + } + + val pointers = (equations ++ guardedExprs.flatMap(_._2)).flatMap(lambdaPointers).toMap + val encodedPointers = pointers.map(p => encoder(p._1) -> encoder(p._2)) + val (clauses, blockers, applications, matchers) = { var clauses : Clauses = Seq.empty var blockers : Map[Variable, Set[Call]] = Map.empty @@ -403,11 +451,11 @@ trait Templates extends TemplateGenerator case None => Left(encoder(arg)) } - funInfos ++= firstOrderCallsOf(e).map { case (id, tps, args) => - Call(getFunction(id, tps), args.map(encodeArg)) + funInfos ++= firstOrderCallsOf(e, simplify).map { case (id, tps, path, args) => + Call(getFunction(id, tps), path, args.map(encodeArg)) } - appInfos ++= firstOrderAppsOf(e).map { case (c, args) => + appInfos ++= firstOrderAppsOf(e, simplify).map { case (c, args) => val tpe = bestRealType(c.getType).asInstanceOf[FunctionType] App(encoder(c), tpe, args.map(encodeArg), encoder(mkApplication(c, args))) } @@ -462,7 +510,7 @@ trait Templates extends TemplateGenerator }.mkString("\n") } - (clauses, encodedBlockers, encodedApps, encodedMatchers, stringRepr) + (clauses, encodedBlockers, encodedApps, encodedMatchers, encodedPointers, stringRepr) } def substitution( @@ -471,6 +519,7 @@ trait Templates extends TemplateGenerator condTree: Map[Variable, Set[Variable]], lambdas: Seq[LambdaTemplate], quantifications: Seq[QuantificationTemplate], + pointers: Map[Encoded, Encoded], baseSubst: Map[Encoded, Arg], aVar: Encoded ): (Map[Encoded, Arg], Clauses) = { @@ -518,6 +567,11 @@ trait Templates extends TemplateGenerator clauses ++= cls } + val substituter = mkSubstituter(subst.mapValues(_.encoded)) + for ((ptr, lambda) <- pointers) { + registerLambda(substituter(ptr), substituter(lambda)) + } + (subst, clauses) } @@ -557,15 +611,15 @@ trait Templates extends TemplateGenerator val tpeClauses = bindings.flatMap { case (v, s) => registerSymbol(encodedStart, s, v.getType) }.toSeq - val instExpr = simplifyFormula(expr) + val instExpr = simplifyFormula(expr, simplify) val (condVars, exprVars, condTree, guardedExprs, eqs, lambdas, quants) = mkClauses(start, instExpr, bindings + (start -> encodedStart), polarity = Some(true)) - val (clauses, calls, apps, matchers, _) = Template.encode( + val (clauses, calls, apps, matchers, pointers, _) = Template.encode( start -> encodedStart, bindings.toSeq, condVars, exprVars, guardedExprs, eqs, lambdas, quants) val (substMap, substClauses) = Template.substitution( - condVars, exprVars, condTree, lambdas, quants, Map.empty, encodedStart) + condVars, exprVars, condTree, lambdas, quants, pointers, Map.empty, encodedStart) val templateClauses = Template.instantiate(clauses, calls, apps, matchers, substMap) val allClauses = encodedStart +: (tpeClauses ++ substClauses ++ templateClauses) -- GitLab