diff --git a/src/main/java/leon/codegen/runtime/Lambda.java b/src/main/java/leon/codegen/runtime/Lambda.java index a6abbef37edbe8f87f480a21a6200e32a9e0206b..0bc5171fd6405f59ab2ec4d60e3bf368c49a7bff 100644 --- a/src/main/java/leon/codegen/runtime/Lambda.java +++ b/src/main/java/leon/codegen/runtime/Lambda.java @@ -4,4 +4,5 @@ package leon.codegen.runtime; public abstract class Lambda { public abstract Object apply(Object[] args) throws LeonCodeGenRuntimeException; + public abstract void checkForall(boolean[] quantified); } diff --git a/src/main/java/leon/codegen/runtime/LeonCodeGenRuntimeHenkinMonitor.java b/src/main/java/leon/codegen/runtime/LeonCodeGenRuntimeHenkinMonitor.java index 0be4ad91212be930ec5f8730ed30ebcf9a9f4e0a..7314bfae531af0b68432ea5dd5dcf93b51d629af 100644 --- a/src/main/java/leon/codegen/runtime/LeonCodeGenRuntimeHenkinMonitor.java +++ b/src/main/java/leon/codegen/runtime/LeonCodeGenRuntimeHenkinMonitor.java @@ -8,6 +8,7 @@ import java.util.HashMap; public class LeonCodeGenRuntimeHenkinMonitor extends LeonCodeGenRuntimeMonitor { private final HashMap<Integer, List<Tuple>> domains = new HashMap<Integer, List<Tuple>>(); + private final List<String> warnings = new LinkedList<String>(); public LeonCodeGenRuntimeHenkinMonitor(int maxInvocations) { super(maxInvocations); @@ -21,7 +22,8 @@ public class LeonCodeGenRuntimeHenkinMonitor extends LeonCodeGenRuntimeMonitor { public List<Tuple> domain(Object obj, int type) { List<Tuple> domain = new LinkedList<Tuple>(); if (obj instanceof PartialLambda) { - for (Tuple key : ((PartialLambda) obj).mapping.keySet()) { + PartialLambda l = (PartialLambda) obj; + for (Tuple key : l.mapping.keySet()) { domain.add(key); } } @@ -31,4 +33,12 @@ public class LeonCodeGenRuntimeHenkinMonitor extends LeonCodeGenRuntimeMonitor { return domain; } + + public void warn(String warning) { + warnings.add(warning); + } + + public List<String> getWarnings() { + return warnings; + } } diff --git a/src/main/java/leon/codegen/runtime/PartialLambda.java b/src/main/java/leon/codegen/runtime/PartialLambda.java index 826cc5ed9930e54bc2f50d7f09e6fa09be3fa307..7bab72ea31dfacb6438c7f217da0991d5238a2b2 100644 --- a/src/main/java/leon/codegen/runtime/PartialLambda.java +++ b/src/main/java/leon/codegen/runtime/PartialLambda.java @@ -6,9 +6,15 @@ import java.util.HashMap; public final class PartialLambda extends Lambda { final HashMap<Tuple, Object> mapping = new HashMap<Tuple, Object>(); + private final Object dflt; public PartialLambda() { + this(null); + } + + public PartialLambda(Object dflt) { super(); + this.dflt = dflt; } public void add(Tuple key, Object value) { @@ -20,6 +26,8 @@ public final class PartialLambda extends Lambda { Tuple tuple = new Tuple(args); if (mapping.containsKey(tuple)) { return mapping.get(tuple); + } else if (dflt != null) { + return dflt; } else { throw new LeonCodeGenRuntimeException("Partial function apply on undefined arguments"); } @@ -28,7 +36,8 @@ public final class PartialLambda extends Lambda { @Override public boolean equals(Object that) { if (that != null && (that instanceof PartialLambda)) { - return mapping.equals(((PartialLambda) that).mapping); + PartialLambda l = (PartialLambda) that; + return ((dflt != null && dflt.equals(l.dflt)) || (dflt == null && l.dflt == null)) && mapping.equals(l.mapping); } else { return false; } @@ -36,6 +45,9 @@ public final class PartialLambda extends Lambda { @Override public int hashCode() { - return 63 + 11 * mapping.hashCode(); + return 63 + 11 * mapping.hashCode() + (dflt == null ? 0 : dflt.hashCode()); } + + @Override + public void checkForall(boolean[] quantified) {} } diff --git a/src/main/scala/leon/codegen/CodeGeneration.scala b/src/main/scala/leon/codegen/CodeGeneration.scala index 85ed5eb0be33daa2797d08a004048ebf2da8ac17..8860b23bc88523f84797e46352251b7b824b1b52 100644 --- a/src/main/scala/leon/codegen/CodeGeneration.scala +++ b/src/main/scala/leon/codegen/CodeGeneration.scala @@ -70,8 +70,10 @@ trait CodeGeneration { private[codegen] val RationalClass = "leon/codegen/runtime/Rational" private[codegen] val CaseClassClass = "leon/codegen/runtime/CaseClass" private[codegen] val LambdaClass = "leon/codegen/runtime/Lambda" + private[codegen] val PartialLambdaClass = "leon/codegen/runtime/PartialLambda" private[codegen] val ErrorClass = "leon/codegen/runtime/LeonCodeGenRuntimeException" private[codegen] val ImpossibleEvaluationClass = "leon/codegen/runtime/LeonCodeGenEvaluationException" + private[codegen] val BadQuantificationClass = "leon/codegen/runtime/LeonCodeGenQuantificationException" private[codegen] val HashingClass = "leon/codegen/runtime/LeonCodeGenRuntimeHashing" private[codegen] val ChooseEntryPointClass = "leon/codegen/runtime/ChooseEntryPoint" private[codegen] val GenericValuesClass = "leon/codegen/runtime/GenericValues" @@ -375,6 +377,34 @@ trait CodeGeneration { hch.freeze } + locally { + val vmh = cf.addMethod("V", "checkForall", s"[Z") + vmh.setFlags(( + METHOD_ACC_PUBLIC | + METHOD_ACC_FINAL + ).asInstanceOf[U2]) + + val vch = vmh.codeHandler + + vch << ALoad(1) // load boolean array `quantified` + def rec(args: Seq[Identifier], idx: Int, quantified: Set[Identifier]): Unit = args match { + case x :: xs => + val notQuantLabel = vch.getFreshLabel("notQuant") + vch << DUP << ALoad(idx) << IfEq(notQuantLabel) + rec(xs, idx + 1, quantified + x) + vch << Label(notQuantLabel) + rec(xs, idx + 1, quantified) + + case Nil => + if (quantified.nonEmpty) checkQuantified(quantified, nl.body, vch) + vch << POP << RETURN + } + + rec(nl.args.map(_.id), 0, Set.empty) + + vch.freeze + } + loader.register(cf) afName @@ -393,6 +423,29 @@ trait CodeGeneration { ch << InvokeSpecial(afName, constructorName, consSig) } + private def checkQuantified(quantified: Set[Identifier], body: Expr, ch: CodeHandler)(implicit locals: Locals): Unit = { + val status = checkForall(quantified, body) + if (status.isValid) { + purescala.ExprOps.preTraversal { + case Application(caller, args) => + ch << NewArray.primitive("T_BOOLEAN") + for ((arg, idx) <- args.zipWithIndex) { + ch << DUP << Ldc(idx) << Ldc(arg match { + case Variable(id) if quantified(id) => 1 + case _ => 0 + }) << BASTORE + } + + ch << InvokeVirtual(LambdaClass, "checkForall", "([Z)V") + case _ => + } (body) + } else { + load(monitorID, ch) + ch << Ldc("Invalid forall: " + status) + ch << InvokeVirtual(HenkinClass, "warn", "(Ljava/lang/String;)V") + } + } + private val typeIdCache = scala.collection.mutable.Map.empty[TypeTree, Int] private[codegen] def typeId(tpe: TypeTree): Int = typeIdCache.get(tpe) match { case Some(id) => id @@ -413,6 +466,9 @@ trait CodeGeneration { ch << ATHROW ch << Label(monitorOk) + val quantified = f.args.map(_.id).toSet + checkQuantified(quantified, f.body, ch) + val Forall(fargs, TopLevelAnds(conjuncts)) = f val endLabel = ch.getFreshLabel("forallEnd") @@ -882,6 +938,21 @@ trait CodeGeneration { ch << InvokeVirtual(LambdaClass, "apply", s"([L$ObjectClass;)L$ObjectClass;") mkUnbox(app.getType, ch) + case p @ PartialLambda(mapping, dflt, _) => + if (dflt.isDefined) { + mkExpr(dflt.get, ch) + ch << New(PartialLambdaClass) + ch << InvokeSpecial(PartialLambdaClass, constructorName, s"(L$ObjectClass;)V") + } else { + ch << DefaultNew(PartialLambdaClass) + } + + for ((es,v) <- mapping) { + mkExpr(Tuple(es), ch) + mkExpr(v, ch) + ch << InvokeVirtual(PartialLambdaClass, "add", s"(L$TupleClass;L$ObjectClass;)V") + } + case l @ Lambda(args, body) => compileLambda(l, ch) diff --git a/src/main/scala/leon/codegen/CompilationUnit.scala b/src/main/scala/leon/codegen/CompilationUnit.scala index 69f95f45559737bfd895989aba010dbc0667bb3a..70428db8cc0be9664e38b8e011f92ed0ebfb04bf 100644 --- a/src/main/scala/leon/codegen/CompilationUnit.scala +++ b/src/main/scala/leon/codegen/CompilationUnit.scala @@ -201,8 +201,13 @@ class CompilationUnit(val ctx: LeonContext, } m - case f @ PartialLambda(mapping, _) => - val l = new leon.codegen.runtime.PartialLambda() + case f @ PartialLambda(mapping, dflt, _) => + val l = if (dflt.isDefined) { + new leon.codegen.runtime.PartialLambda(dflt.get) + } else { + new leon.codegen.runtime.PartialLambda() + } + for ((ks,v) <- mapping) { // Force tuple even with 1/0 elems. val kJvm = tupleConstructor.newInstance(ks.map(valueToJVM).toArray).asInstanceOf[leon.codegen.runtime.Tuple] diff --git a/src/main/scala/leon/codegen/CompiledExpression.scala b/src/main/scala/leon/codegen/CompiledExpression.scala index a9d1eda0c5e36e19a6b6c12f99a617b480866f10..fc2d3bd6470ca900c1515eef74bedf679146b1fa 100644 --- a/src/main/scala/leon/codegen/CompiledExpression.scala +++ b/src/main/scala/leon/codegen/CompiledExpression.scala @@ -54,7 +54,16 @@ class CompiledExpression(unit: CompilationUnit, cf: ClassFile, expression: Expr, def eval(model: solvers.Model) : Expr = { try { val monitor = unit.modelToJVM(model, params.maxFunctionInvocations) - evalFromJVM(argsToJVM(argsDecl.map(model), monitor), monitor) + val res = evalFromJVM(argsToJVM(argsDecl.map(model), monitor), monitor) + monitor match { + case hm: LHM => + val it = hm.getWarnings().iterator() + while (it.hasNext) { + unit.ctx.reporter.warning(it.next) + } + case _ => + } + res } catch { case ite : InvocationTargetException => throw ite.getCause } diff --git a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala index fccda2aeabd69b72bef9690e8f55fa61a562911c..9e33722dd62c67e1357acd2012d6955647715356 100644 --- a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala +++ b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala @@ -47,12 +47,14 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int def maxSteps = RecursiveEvaluator.this.maxSteps var stepsLeft = maxSteps + var warnings = List.empty[String] } def initRC(mappings: Map[Identifier, Expr]): RC def initGC(model: Model): GC // Used by leon-web, please do not delete + // Used by quantified proposition checking now too! var lastGC: Option[GC] = None private[this] var clpCache = Map[(Choose, Seq[Expr]), Expr]() @@ -61,7 +63,9 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int try { lastGC = Some(initGC(model)) ctx.timers.evaluators.recursive.runtime.start() - EvaluationResults.Successful(e(ex)(initRC(model.toMap), lastGC.get)) + val res = e(ex)(initRC(model.toMap), lastGC.get) + for (warning <- lastGC.get.warnings) ctx.reporter.warning(warning) + EvaluationResults.Successful(res) } catch { case so: StackOverflowError => EvaluationResults.EvaluatorError("Stack overflow") @@ -93,10 +97,10 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int val newArgs = args.map(e) val mapping = l.paramSubst(newArgs) e(body)(rctx.withNewVars(mapping), gctx) - case PartialLambda(mapping, _) => + case PartialLambda(mapping, dflt, _) => mapping.find { case (pargs, res) => (args zip pargs).forall(p => e(Equals(p._1, p._2)) == BooleanLiteral(true)) - }.map(_._2).getOrElse { + }.map(_._2).orElse(dflt).getOrElse { throw EvalError("Cannot apply partial lambda outside of domain") } case f => @@ -232,7 +236,7 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int (lv,rv) match { case (FiniteSet(el1, _),FiniteSet(el2, _)) => BooleanLiteral(el1 == el2) case (FiniteMap(el1, _, _),FiniteMap(el2, _, _)) => BooleanLiteral(el1.toSet == el2.toSet) - case (PartialLambda(m1, _), PartialLambda(m2, _)) => BooleanLiteral(m1.toSet == m2.toSet) + case (PartialLambda(m1, d1, _), PartialLambda(m2, d2, _)) => BooleanLiteral(m1.toSet == m2.toSet && d1 == d2) case _ => BooleanLiteral(lv == rv) } @@ -516,10 +520,10 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int val mapping = variablesOf(l).map(id => structSubst(id) -> e(Variable(id))).toMap replaceFromIDs(mapping, nl) - case PartialLambda(mapping, tpe) => - PartialLambda(mapping.map(p => p._1.map(e) -> e(p._2)), tpe) + case PartialLambda(mapping, dflt, tpe) => + PartialLambda(mapping.map(p => p._1.map(e) -> e(p._2)), dflt.map(e), tpe) - case f @ Forall(fargs, TopLevelAnds(conjuncts)) => + case f @ Forall(fargs, body @ TopLevelAnds(conjuncts)) => val henkinModel: HenkinModel = gctx.model match { case hm: HenkinModel => hm case _ => throw EvalError("Can't evaluate foralls without henkin model") @@ -576,7 +580,37 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int e(andJoin(instantiations.map { case (enabler, mapping) => e(Implies(enabler, conj))(rctx.withNewVars(mapping), gctx) })) - })) + })) match { + case res @ BooleanLiteral(true) => + val quantified = fargs.map(_.id).toSet + val status = checkForall(quantified, body) + if (!status.isValid) { + gctx.warnings :+= "Invalid forall: " + status + } else { + for ((caller, appArgs) <- firstOrderAppsOf(body)) e(caller) match { + case _: PartialLambda => // OK + case Lambda(args, body) => + val lambdaQuantified = (appArgs zip args).collect { + case (Variable(id), vd) if quantified(id) => vd.id + }.toSet + + if (lambdaQuantified.nonEmpty) { + val lambdaStatus = checkForall(lambdaQuantified, body) + if (!lambdaStatus.isValid) { + gctx.warnings :+= "Invalid forall: " + lambdaStatus + } + } + case f => + throw EvalError("Cannot apply non-lambda function " + f.asString) + } + } + + res + + // `res == false` means the quantification is valid since there effectivelly must + // exist an input for which the proposition doesn't hold + case res => res + } case ArrayLength(a) => val FiniteArray(_, _, IntLiteral(length)) = e(a) diff --git a/src/main/scala/leon/purescala/ExprOps.scala b/src/main/scala/leon/purescala/ExprOps.scala index c80d777480f66d795f2bc15eef0a3010452394fb..2dc33e801c4a26728ccd0551be046c2b2c6b2ed5 100644 --- a/src/main/scala/leon/purescala/ExprOps.scala +++ b/src/main/scala/leon/purescala/ExprOps.scala @@ -1952,13 +1952,55 @@ object ExprOps { es foreach rec } - def functionAppsOf(expr: Expr): Set[Application] = { - collect[Application] { - case f: Application => Set(f) - case _ => Set() - }(expr) + object InvocationExtractor { + private def flatInvocation(expr: Expr): Option[(TypedFunDef, Seq[Expr])] = expr match { + case fi @ FunctionInvocation(tfd, args) => Some((tfd, args)) + case Application(caller, args) => flatInvocation(caller) match { + case Some((tfd, prevArgs)) => Some((tfd, prevArgs ++ args)) + case None => None + } + case _ => None + } + + def unapply(expr: Expr): Option[(TypedFunDef, Seq[Expr])] = expr match { + case IsTyped(f: FunctionInvocation, ft: FunctionType) => None + case IsTyped(f: Application, ft: FunctionType) => None + case FunctionInvocation(tfd, args) => Some(tfd -> args) + case f: Application => flatInvocation(f) + case _ => None + } + } + + def firstOrderCallsOf(expr: Expr): Set[(TypedFunDef, Seq[Expr])] = + collect[(TypedFunDef, Seq[Expr])] { + case InvocationExtractor(tfd, args) => Set(tfd -> args) + case _ => Set.empty + } (expr) + + object ApplicationExtractor { + private def flatApplication(expr: Expr): Option[(Expr, Seq[Expr])] = expr match { + case Application(fi: FunctionInvocation, _) => None + case Application(caller: Application, args) => flatApplication(caller) match { + case Some((c, prevArgs)) => Some((c, prevArgs ++ args)) + case None => None + } + case Application(caller, args) => Some((caller, args)) + case _ => None + } + + def unapply(expr: Expr): Option[(Expr, Seq[Expr])] = expr match { + case IsTyped(f: Application, ft: FunctionType) => None + case f: Application => flatApplication(f) + case _ => None + } } + def firstOrderAppsOf(expr: Expr): Set[(Expr, Seq[Expr])] = + collect[(Expr, Seq[Expr])] { + case ApplicationExtractor(caller, args) => Set(caller -> args) + case _ => Set.empty + } (expr) + def simplifyHOFunctions(expr: Expr) : Expr = { def liftToLambdas(expr: Expr) = { diff --git a/src/main/scala/leon/purescala/Expressions.scala b/src/main/scala/leon/purescala/Expressions.scala index 1569ebfb8ab3d8c0ab7f0e09bbb31f82870fc8f9..a45c4ee59a0bf518f320769286b00b95ce2fc42a 100644 --- a/src/main/scala/leon/purescala/Expressions.scala +++ b/src/main/scala/leon/purescala/Expressions.scala @@ -230,7 +230,7 @@ object Expressions { } } - case class PartialLambda(mapping: Seq[(Seq[Expr], Expr)], tpe: FunctionType) extends Expr { + case class PartialLambda(mapping: Seq[(Seq[Expr], Expr)], default: Option[Expr], tpe: FunctionType) extends Expr { val getType = tpe } diff --git a/src/main/scala/leon/purescala/Extractors.scala b/src/main/scala/leon/purescala/Extractors.scala index cd57d2187fe0cd59d45e3d88e17b0926e2279412..b865d474b0ca5d10fbade26338afbb9766fb788c 100644 --- a/src/main/scala/leon/purescala/Extractors.scala +++ b/src/main/scala/leon/purescala/Extractors.scala @@ -41,7 +41,7 @@ object Extractors { Some((Seq(a), (es: Seq[Expr]) => ArrayLength(es.head))) case Lambda(args, body) => Some((Seq(body), (es: Seq[Expr]) => Lambda(args, es.head))) - case PartialLambda(mapping, tpe) => + case PartialLambda(mapping, dflt, tpe) => val sze = tpe.from.size + 1 val subArgs = mapping.flatMap { case (args, v) => args :+ v } val builder = (as: Seq[Expr]) => { @@ -52,9 +52,10 @@ object Extractors { case Seq() => Seq.empty case _ => sys.error("unexpected number of key/value expressions") } - PartialLambda(rec(as), tpe) + val (nas, nd) = if (dflt.isDefined) (as.init, Some(as.last)) else (as, None) + PartialLambda(rec(nas), nd, tpe) } - Some((subArgs, builder)) + Some((subArgs ++ dflt, builder)) case Forall(args, body) => Some((Seq(body), (es: Seq[Expr]) => Forall(args, es.head))) diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala index 3b96fff138dd7bc6a91c66acf3e3928fac0e44d0..8159f1e9e5589acc69bdfe6267e1952c6c015793 100644 --- a/src/main/scala/leon/purescala/PrettyPrinter.scala +++ b/src/main/scala/leon/purescala/PrettyPrinter.scala @@ -247,6 +247,22 @@ class PrettyPrinter(opts: PrinterOptions, case Lambda(args, body) => optP { p"($args) => $body" } + case PartialLambda(mapping, dflt, _) => + optP { + def pm(p: (Seq[Expr], Expr)): PrinterHelpers.Printable = + (pctx: PrinterContext) => p"${purescala.Constructors.tupleWrap(p._1)} => ${p._2}"(pctx) + + if (mapping.isEmpty) { + p"{}" + } else { + p"{ ${nary(mapping map pm)} }" + } + + if (dflt.isDefined) { + p" ${dflt.get}" + } + } + case Plus(l,r) => optP { p"$l + $r" } case Minus(l,r) => optP { p"$l - $r" } case Times(l,r) => optP { p"$l * $r" } diff --git a/src/main/scala/leon/purescala/Quantification.scala b/src/main/scala/leon/purescala/Quantification.scala index 1b00ed1b41a5053fb07a94695b00f595d54453ba..bf88450fd54063338d63310cbec7de7ddf9db76b 100644 --- a/src/main/scala/leon/purescala/Quantification.scala +++ b/src/main/scala/leon/purescala/Quantification.scala @@ -10,6 +10,8 @@ import Extractors._ import ExprOps._ import Types._ +import evaluators._ + object Quantification { def extractQuorums[A,B]( @@ -18,6 +20,12 @@ object Quantification { margs: A => Set[A], qargs: A => Set[B] ): Seq[Set[A]] = { + def expand(m: A): Set[A] = Set(m) ++ margs(m).flatMap(expand) + val expandedMap: Map[A, Set[A]] = matchers.map(m => m -> expand(m)).toMap + val reverseMap : Map[A, Set[A]] = expandedMap + .flatMap(p => p._2.map(m => m -> p._1)) // flatten to reversed pairs + .groupBy(_._1).mapValues(_.map(_._2).toSet) // rebuild map from pair set + def rec(oms: Seq[A], mSet: Set[A], qss: Seq[Set[B]]): Seq[Set[A]] = { if (qss.contains(quantified)) { Seq(mSet) @@ -34,9 +42,10 @@ object Quantification { } } - def expand(m: A): Set[A] = Set(m) ++ margs(m).flatMap(expand) - val oms = matchers.toSeq.sortBy(m => -expand(m).size) - rec(oms, Set.empty, Seq.empty) + val oms = expandedMap.toSeq.sortBy(p => -p._2.size).map(_._1) + val res = rec(oms, Set.empty, Seq.empty) + + res.filter(ms => ms.forall(m => reverseMap(m) subsetOf ms)) } def extractQuorums(expr: Expr, quantified: Set[Identifier]): Seq[Set[(Expr, Seq[Expr])]] = { @@ -60,6 +69,33 @@ object Quantification { (p: (Expr, Seq[Expr])) => p._2.collect { case Variable(id) if quantified(id) => id }.toSet) } + def extractModel( + asMap: Map[Identifier, Expr], + funDomains: Map[Identifier, Set[Seq[Expr]]], + tpeDomains: Map[TypeTree, Set[Seq[Expr]]], + evaluator: Evaluator + ): Map[Identifier, Expr] = asMap.map { case (id, expr) => + id -> (funDomains.get(id) match { + case Some(domain) => + PartialLambda(domain.toSeq.map { es => + val optEv = evaluator.eval(Application(expr, es)).result + es -> optEv.getOrElse(scala.sys.error("Unexpectedly failed to evaluate " + Application(expr, es))) + }, None, id.getType.asInstanceOf[FunctionType]) + + case None => postMap { + case p @ PartialLambda(mapping, dflt, tpe) => + Some(PartialLambda(tpeDomains.get(tpe) match { + case Some(domain) => domain.toSeq.map { es => + val optEv = evaluator.eval(Application(p, es)).result + es -> optEv.getOrElse(scala.sys.error("Unexpectedly failed to evaluate " + Application(p, es))) + } + case _ => scala.sys.error(s"Can't extract $p without domain") + }, None, tpe)) + case _ => None + } (expr) + }) + } + object HenkinDomains { def empty = new HenkinDomains(Map.empty) def apply(domains: Map[TypeTree, Set[Seq[Expr]]]) = new HenkinDomains(domains) @@ -67,18 +103,33 @@ object Quantification { class HenkinDomains (val domains: Map[TypeTree, Set[Seq[Expr]]]) { def get(e: Expr): Set[Seq[Expr]] = e match { - case PartialLambda(mapping, _) => mapping.map(_._1).toSet + case PartialLambda(_, Some(dflt), _) => scala.sys.error("No domain for non-partial lambdas") + case PartialLambda(mapping, _, _) => mapping.map(_._1).toSet case _ => domains.get(e.getType) match { case Some(domain) => domain case None => scala.sys.error("Undefined Henkin domain for " + e) } } + + override def toString = domains.map { case (tpe, argSet) => + tpe + ": " + argSet.map(_.mkString("(", ",", ")")).mkString(", ") + }.mkString("domain={\n ", "\n ", "}") } object QuantificationMatcher { + private def flatApplication(expr: Expr): Option[(Expr, Seq[Expr])] = expr match { + case Application(fi: FunctionInvocation, _) => None + case Application(caller: Application, args) => flatApplication(caller) match { + case Some((c, prevArgs)) => Some((c, prevArgs ++ args)) + case None => None + } + case Application(caller, args) => Some((caller, args)) + case _ => None + } + def unapply(expr: Expr): Option[(Expr, Seq[Expr])] = expr match { - case Application(_: Application | _: FunctionInvocation, _) => None - case Application(e, args) => Some(e -> args) + case IsTyped(a: Application, ft: FunctionType) => None + case Application(e, args) => flatApplication(expr) case ArraySelect(arr, index) => Some(arr -> Seq(index)) case MapApply(map, key) => Some(map -> Seq(key)) case ElementOfSet(elem, set) => Some(set -> Seq(elem)) @@ -87,8 +138,15 @@ object Quantification { } object QuantificationTypeMatcher { + private def flatType(tpe: TypeTree): (Seq[TypeTree], TypeTree) = tpe match { + case FunctionType(from, to) => + val (nextArgs, finalTo) = flatType(to) + (from ++ nextArgs, finalTo) + case _ => (Seq.empty, tpe) + } + def unapply(tpe: TypeTree): Option[(Seq[TypeTree], TypeTree)] = tpe match { - case FunctionType(from, to) => Some(from -> to) + case FunctionType(from, to) => Some(flatType(tpe)) case ArrayType(base) => Some(Seq(Int32Type) -> base) case MapType(from, to) => Some(Seq(from) -> to) case SetType(base) => Some(Seq(base) -> BooleanType) @@ -96,87 +154,83 @@ object Quantification { } } - object CheckForalls extends UnitPhase[Program] { - - val name = "Foralls" - val description = "Check syntax of foralls to guarantee sound instantiations" - - def apply(ctx: LeonContext, program: Program) = { - program.definedFunctions.foreach { fd => - val foralls = collect[Forall] { - case f: Forall => Set(f) - case _ => Set.empty - } (fd.fullBody) - - val free = fd.paramIds.toSet ++ (fd.postcondition match { - case Some(Lambda(args, _)) => args.map(_.id) - case _ => Seq.empty + sealed abstract class ForallStatus { + def isValid: Boolean + } + + case object ForallValid extends ForallStatus { + def isValid = true + } + + sealed abstract class ForallInvalid extends ForallStatus { + def isValid = false + } + + case object NoMatchers extends ForallInvalid + case class ComplexArgument(expr: Expr) extends ForallInvalid + case class NonBijectiveMapping(expr: Expr) extends ForallInvalid + case class InvalidOperation(expr: Expr) extends ForallInvalid + + def checkForall(quantified: Set[Identifier], body: Expr): ForallStatus = { + val TopLevelAnds(conjuncts) = body + for (conjunct <- conjuncts) { + val matchers = collect[(Expr, Seq[Expr])] { + case QuantificationMatcher(e, args) => Set(e -> args) + case _ => Set.empty + } (conjunct) + + if (matchers.isEmpty) return NoMatchers + + val complexArgs = matchers.flatMap { case (_, args) => + args.flatMap(arg => arg match { + case QuantificationMatcher(_, _) => None + case Variable(id) => None + case _ if (variablesOf(arg) & quantified).nonEmpty => Some(arg) + case _ => None }) + } - for (Forall(args, TopLevelAnds(conjuncts)) <- foralls) { - val quantified = args.map(_.id).toSet - - for (conjunct <- conjuncts) { - val matchers = collect[(Expr, Seq[Expr])] { - case QuantificationMatcher(e, args) => Set(e -> args) - case _ => Set.empty - } (conjunct) - - if (matchers.isEmpty) - ctx.reporter.warning("E-matching isn't possible without matchers!") - - if (matchers.exists { case (_, args) => - args.exists{ - case QuantificationMatcher(_, _) => false - case Variable(id) => false - case arg => (variablesOf(arg) & quantified).nonEmpty - } - }) ctx.reporter.warning("Matcher arguments must have simple form in " + conjunct) - - val freeMatchers = matchers.collect { case (Variable(id), args) if free(id) => id -> args } - - val id2Quant = freeMatchers.foldLeft(Map.empty[Identifier, Set[Identifier]]) { - case (acc, (m, args)) => acc + (m -> (acc.getOrElse(m, Set.empty) ++ args.flatMap { - case Variable(id) if quantified(id) => Set(id) - case _ => Set.empty[Identifier] - })) - } + if (complexArgs.nonEmpty) return ComplexArgument(complexArgs.head) - if (id2Quant.filter(_._2.nonEmpty).groupBy(_._2).nonEmpty) - ctx.reporter.warning("Multiple matchers must provide bijective matching in " + conjunct) - - fold[Set[Identifier]] { case (m, children) => - val q = children.toSet.flatten - - m match { - case QuantificationMatcher(_, args) => - q -- args.flatMap { - case Variable(id) if quantified(id) => Set(id) - case _ => Set.empty[Identifier] - } - case LessThan(_: Variable, _: Variable) => q - case LessEquals(_: Variable, _: Variable) => q - case GreaterThan(_: Variable, _: Variable) => q - case GreaterEquals(_: Variable, _: Variable) => q - case And(_) => q - case Or(_) => q - case Implies(_, _) => q - case Operator(es, _) => - val vars = es.flatMap { - case Variable(id) => Set(id) - case _ => Set.empty[Identifier] - }.toSet - - if (!(q.isEmpty || (q.size == 1 && (vars & free).isEmpty))) - ctx.reporter.warning("Invalid operation " + m + " on quantified variables") - q -- vars - case Variable(id) if quantified(id) => Set(id) - case _ => q - } - } (conjunct) - } - } + val matcherToQuants = matchers.foldLeft(Map.empty[Expr, Set[Identifier]]) { + case (acc, (m, args)) => acc + (m -> (acc.getOrElse(m, Set.empty) ++ args.flatMap { + case Variable(id) if quantified(id) => Set(id) + case _ => Set.empty[Identifier] + })) } + + val bijectiveMappings = matcherToQuants.filter(_._2.nonEmpty).groupBy(_._2) + if (bijectiveMappings.size > 1) return NonBijectiveMapping(bijectiveMappings.head._2.head._1) + + val matcherSet = matcherToQuants.filter(_._2.nonEmpty).keys.toSet + + val qs = foldRight[Set[Identifier]] { case (m, children) => + val q = children.toSet.flatten + + m match { + case QuantificationMatcher(_, args) => + q -- args.flatMap { + case Variable(id) if quantified(id) => Set(id) + case _ => Set.empty[Identifier] + } + case LessThan(_: Variable, _: Variable) => q + case LessEquals(_: Variable, _: Variable) => q + case GreaterThan(_: Variable, _: Variable) => q + case GreaterEquals(_: Variable, _: Variable) => q + case And(_) => q + case Or(_) => q + case Implies(_, _) => q + case Operator(es, _) => + val matcherArgs = matcherSet & es.toSet + if (q.nonEmpty && !(q.size == 1 && matcherArgs.isEmpty && m.getType == BooleanType)) + return InvalidOperation(m) + else Set.empty + case Variable(id) if quantified(id) => Set(id) + case _ => q + } + } (conjunct) } + + ForallValid } } diff --git a/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala b/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala index 0c7afeb65199042e34004199d17bba25dde9e1c5..0ec568cc18903e4b6b5038a678ce1b44aabc8de4 100644 --- a/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala +++ b/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala @@ -111,9 +111,9 @@ class UnrollingSolver(val context: LeonContext, val program: Program, underlying def extract(b: Expr, m: Matcher[Expr]): Set[Seq[Expr]] = { val QuantificationTypeMatcher(fromTypes, _) = m.tpe - val optEnabler = evaluator.eval(b).result + val optEnabler = evaluator.eval(b, model).result if (optEnabler == Some(BooleanLiteral(true))) { - val optArgs = m.args.map(arg => evaluator.eval(Matcher.argValue(arg)).result) + val optArgs = m.args.map(arg => evaluator.eval(Matcher.argValue(arg), model).result) if (optArgs.forall(_.isDefined)) { Set(optArgs.map(_.get)) } else { @@ -132,28 +132,16 @@ class UnrollingSolver(val context: LeonContext, val program: Program, underlying case _ => None }).toMap.mapValues(_.toSet) - val asDMap = model.map(p => funDomains.get(p._1) match { - case Some(domain) => - val mapping = domain.toSeq.map { es => - val ev: Expr = p._2 match { - case RawArrayValue(_, mapping, dflt) => - mapping.collectFirst { - case (k,v) if evaluator.eval(Equals(k, tupleWrap(es))).result == Some(BooleanLiteral(true)) => v - } getOrElse dflt - case _ => scala.sys.error("Unexpected function encoding " + p._2) - } - es -> ev - } - - p._1 -> PartialLambda(mapping, p._1.getType.asInstanceOf[FunctionType]) - case None => p - }).toMap - val typeGrouped = templateGenerator.manager.instantiations.groupBy(_._2.tpe) val typeDomains = typeGrouped.mapValues(_.flatMap { case (b, m) => extract(b, m) }.toSet) + val asDMap = purescala.Quantification.extractModel(model.toMap, funDomains, typeDomains, evaluator) val domains = new HenkinDomains(typeDomains) - new HenkinModel(asDMap, domains) + val hmodel = new HenkinModel(asDMap, domains) + + isValidModel(hmodel) + + hmodel } def foundAnswer(res: Option[Boolean], model: Option[HenkinModel] = None) = { diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala index 33b3a26d88a660fe737b144766c8dfdfa48ff685..05a04128a4dd6b62c0f641466d2710f8f6206df1 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala @@ -202,7 +202,8 @@ trait SMTLIBTarget extends Interruptible { r case ft @ FunctionType(from, to) => - r + val elems = r.elems.toSeq.map { case (k, v) => unwrapTuple(k, from.size) -> v } + PartialLambda(elems, Some(r.default), ft) case MapType(from, to) => // We expect a RawArrayValue with keys in from and values in Option[to], diff --git a/src/main/scala/leon/solvers/templates/LambdaManager.scala b/src/main/scala/leon/solvers/templates/LambdaManager.scala index 3d5eec72c809a7ba9459b4b46752835b63bd6011..00bdbfa07ca49c450cd91b3fce0e67dc7655402a 100644 --- a/src/main/scala/leon/solvers/templates/LambdaManager.scala +++ b/src/main/scala/leon/solvers/templates/LambdaManager.scala @@ -12,30 +12,34 @@ import purescala.Types._ import utils._ import Instantiation._ -class LambdaManager[T](protected val encoder: TemplateEncoder[T]) extends IncrementalState { +class LambdaManager[T](protected[templates] val encoder: TemplateEncoder[T]) extends IncrementalState { + private[templates] lazy val trueT = encoder.encodeExpr(Map.empty)(BooleanLiteral(true)) protected val byID = new IncrementalMap[T, LambdaTemplate[T]] protected val byType = new IncrementalMap[FunctionType, Set[(T, LambdaTemplate[T])]].withDefaultValue(Set.empty) protected val applications = new IncrementalMap[FunctionType, Set[(T, App[T])]].withDefaultValue(Set.empty) protected val freeLambdas = new IncrementalMap[FunctionType, Set[T]].withDefaultValue(Set.empty) + private val instantiated = new IncrementalSet[(T, App[T])] + protected def incrementals: List[IncrementalState] = - List(byID, byType, applications, freeLambdas) + List(byID, byType, applications, freeLambdas, instantiated) def clear(): Unit = incrementals.foreach(_.clear()) def reset(): Unit = incrementals.foreach(_.reset()) def push(): Unit = incrementals.foreach(_.push()) def pop(): Unit = incrementals.foreach(_.pop()) - def registerFree(lambdas: Seq[(TypeTree, T)]): Unit = { - for ((tpe, idT) <- lambdas) tpe match { + def registerFree(lambdas: Seq[(Identifier, T)]): Unit = { + for ((id, idT) <- lambdas) id.getType match { case ft: FunctionType => freeLambdas += ft -> (freeLambdas(ft) + idT) case _ => } } - def instantiateLambda(idT: T, template: LambdaTemplate[T]): Instantiation[T] = { + def instantiateLambda(template: LambdaTemplate[T]): Instantiation[T] = { + val idT = template.ids._2 var clauses : Clauses[T] = equalityClauses(idT, template) var appBlockers : AppBlockers[T] = Map.empty.withDefaultValue(Set.empty) @@ -55,32 +59,33 @@ class LambdaManager[T](protected val encoder: TemplateEncoder[T]) extends Increm def instantiateApp(blocker: T, app: App[T]): Instantiation[T] = { val App(caller, tpe, args) = app - var clauses : Clauses[T] = Seq.empty - var callBlockers : CallBlockers[T] = Map.empty.withDefaultValue(Set.empty) - var appBlockers : AppBlockers[T] = Map.empty.withDefaultValue(Set.empty) - - if (byID contains caller) { - val (newClauses, newCalls, newApps) = byID(caller).instantiate(blocker, args) + val instantiation = Instantiation.empty[T] - clauses ++= newClauses - newCalls.foreach(p => callBlockers += p._1 -> (callBlockers(p._1) ++ p._2)) - newApps.foreach(p => appBlockers += p._1 -> (appBlockers(p._1) ++ p._2)) - } else if (!freeLambdas(tpe).contains(caller)) { + if (freeLambdas(tpe).contains(caller)) instantiation else { val key = blocker -> app - // make sure that even if byType(tpe) is empty, app is recorded in blockers - // so that UnrollingBank will generate the initial block! - if (!(appBlockers contains key)) appBlockers += key -> Set.empty + if (instantiated(key)) instantiation else { + instantiated += key - for ((idT,template) <- byType(tpe)) { - val equals = encoder.mkEquals(idT, caller) - appBlockers += (key -> (appBlockers(key) + TemplateAppInfo(template, equals, args))) - } + if (byID contains caller) { + instantiation withApp (key -> TemplateAppInfo(byID(caller), trueT, args)) + } else { - applications += tpe -> (applications(tpe) + key) - } + // make sure that even if byType(tpe) is empty, app is recorded in blockers + // so that UnrollingBank will generate the initial block! + val init = instantiation withApps Map(key -> Set.empty) + val inst = byType(tpe).foldLeft(init) { + case (instantiation, (idT, template)) => + val equals = encoder.mkEquals(idT, caller) + instantiation withApp (key -> TemplateAppInfo(template, equals, args)) + } - (clauses, callBlockers, appBlockers) + applications += tpe -> (applications(tpe) + key) + + inst + } + } + } } private def equalityClauses(idT: T, template: LambdaTemplate[T]): Seq[T] = { diff --git a/src/main/scala/leon/solvers/templates/QuantificationManager.scala b/src/main/scala/leon/solvers/templates/QuantificationManager.scala index fde9dc746b4a5e6207fc3ed896bbbd5700cbc79a..e5908ee1f9ebe139692b4bc0176cb46b7400f2c4 100644 --- a/src/main/scala/leon/solvers/templates/QuantificationManager.scala +++ b/src/main/scala/leon/solvers/templates/QuantificationManager.scala @@ -50,10 +50,31 @@ class QuantificationTemplate[T]( val blockers: Map[T, Set[TemplateCallInfo[T]]], val applications: Map[T, Set[App[T]]], val matchers: Map[T, Set[Matcher[T]]], - val lambdas: Map[T, LambdaTemplate[T]]) { - - def instantiate(substMap: Map[T, T]): Instantiation[T] = { - quantificationManager.instantiateQuantification(this, substMap) + val lambdas: Seq[LambdaTemplate[T]]) { + + def substitute(substituter: T => T): QuantificationTemplate[T] = { + new QuantificationTemplate[T]( + quantificationManager, + substituter(start), + qs, + q2s, + insts, + guardVar, + quantifiers, + condVars, + exprVars, + clauses.map(substituter), + blockers.map { case (b, fis) => + substituter(b) -> fis.map(fi => fi.copy(args = fi.args.map(substituter))) + }, + applications.map { case (b, apps) => + substituter(b) -> apps.map(app => app.copy(caller = substituter(app.caller), args = app.args.map(substituter))) + }, + matchers.map { case (b, ms) => + substituter(b) -> ms.map(_.substitute(substituter)) + }, + lambdas.map(_.substitute(substituter)) + ) } } @@ -70,7 +91,7 @@ object QuantificationTemplate { condVars: Map[Identifier, T], exprVars: Map[Identifier, T], guardedExprs: Map[Identifier, Seq[Expr]], - lambdas: Map[T, LambdaTemplate[T]], + lambdas: Seq[LambdaTemplate[T]], subst: Map[Identifier, T] ): QuantificationTemplate[T] = { @@ -89,75 +110,190 @@ object QuantificationTemplate { } class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManager[T](encoder) { - private lazy val trueT = encoder.encodeExpr(Map.empty)(BooleanLiteral(true)) + private val quantifications = new IncrementalSeq[MatcherQuantification] + private val instantiated = new InstantiationContext + private val fInstantiated = new InstantiationContext { + override def apply(p: (T, Matcher[T])): Boolean = + corresponding(p._2).exists(_._2.args == p._2.args) + } - private val quantifications = new IncrementalSeq[Quantification] - private val instantiated = new IncrementalSet[(T, Matcher[T])] - private val fInsts = new IncrementalSet[Matcher[T]] private val known = new IncrementalSet[T] - private def fInstantiated = fInsts.map(m => trueT -> m) - private def correspond(qm: Matcher[T], m: Matcher[T]): Boolean = correspond(qm, m.caller, m.tpe) private def correspond(qm: Matcher[T], caller: T, tpe: TypeTree): Boolean = qm.tpe match { case _: FunctionType => qm.tpe == tpe && (qm.caller == caller || !known(caller)) case _ => qm.tpe == tpe } - private val uniformQuantifiers = scala.collection.mutable.Map.empty[TypeTree, Seq[T]] + private val uniformQuantMap: MutableMap[TypeTree, Seq[T]] = MutableMap.empty + private val uniformQuantSet: MutableSet[T] = MutableSet.empty + + def isQuantifier(idT: T): Boolean = uniformQuantSet(idT) private def uniformSubst(qs: Seq[(Identifier, T)]): Map[T, T] = { qs.groupBy(_._1.getType).flatMap { case (tpe, qst) => - val prev = uniformQuantifiers.get(tpe) match { + val prev = uniformQuantMap.get(tpe) match { case Some(seq) => seq case None => Seq.empty } if (prev.size >= qst.size) { - qst.map(_._2) zip prev.take(qst.size - 1) + qst.map(_._2) zip prev.take(qst.size) } else { val (handled, newQs) = qst.splitAt(prev.size) val uQs = newQs.map(p => p._2 -> encoder.encodeId(p._1)) - uniformQuantifiers(tpe) = prev ++ uQs.map(_._2) + + uniformQuantMap(tpe) = prev ++ uQs.map(_._2) + uniformQuantSet ++= uQs.map(_._2) + (handled.map(_._2) zip prev) ++ uQs } }.toMap } override protected def incrementals: List[IncrementalState] = - List(quantifications, instantiated, fInsts, known) ++ super.incrementals + List(quantifications, instantiated, fInstantiated, known) ++ super.incrementals - def assumptions: Seq[T] = quantifications.map(_.currentQ2Var).toSeq + def assumptions: Seq[T] = quantifications.collect { case q: Quantification => q.currentQ2Var }.toSeq - def instantiations: Seq[(T, Matcher[T])] = instantiated.toSeq ++ fInstantiated + def instantiations: Seq[(T, Matcher[T])] = (instantiated.all ++ fInstantiated.all).toSeq def instantiations(caller: T, tpe: TypeTree): Seq[(T, Matcher[T])] = - instantiations.filter { case (b,m) => correspond(m, caller, tpe) } + (instantiated.corresponding(caller, tpe) ++ fInstantiated.corresponding(caller, tpe)).toSeq - override def registerFree(ids: Seq[(TypeTree, T)]): Unit = { + override def registerFree(ids: Seq[(Identifier, T)]): Unit = { super.registerFree(ids) known ++= ids.map(_._2) } - private class Quantification ( - val qs: (Identifier, T), - val q2s: (Identifier, T), - val insts: (Identifier, T), - val guardVar: T, - val quantified: Set[T], - val matchers: Set[Matcher[T]], - val allMatchers: Map[T, Set[Matcher[T]]], - val condVars: Map[Identifier, T], - val exprVars: Map[Identifier, T], - val clauses: Seq[T], - val blockers: Map[T, Set[TemplateCallInfo[T]]], - val applications: Map[T, Set[App[T]]], - val lambdas: Map[T, LambdaTemplate[T]]) { + private type Context = Set[(T, Matcher[T])] + + private class ContextMap( + private val tpeMap: MutableMap[TypeTree, Context] = MutableMap.empty, + private val funMap: MutableMap[T, Context] = MutableMap.empty + ) { + def +=(p: (T, Matcher[T])): Unit = { + tpeMap(p._2.tpe) = tpeMap.getOrElse(p._2.tpe, Set.empty) + p + p match { + case (_, Matcher(caller, tpe: FunctionType, _, _)) if known(caller) => + funMap(caller) = funMap.getOrElse(caller, Set.empty) + p + case _ => + } + } - var currentQ2Var: T = qs._2 - private var slaves: Seq[(T, Map[T,T], Quantification)] = Nil + def merge(that: ContextMap): this.type = { + for ((tpe, values) <- that.tpeMap) tpeMap(tpe) = tpeMap.getOrElse(tpe, Set.empty) ++ values + for ((caller, values) <- that.funMap) funMap(caller) = funMap.getOrElse(caller, Set.empty) ++ values + this + } + + @inline + def get(m: Matcher[T]): Context = get(m.caller, m.tpe) + + def get(caller: T, tpe: TypeTree): Context = + funMap.getOrElse(caller, Set.empty) ++ tpeMap.getOrElse(tpe, Set.empty) + + override def clone = new ContextMap(tpeMap.clone, funMap.clone) + } + + private class InstantiationContext private ( + private var _instantiated : Context, + private var _next : Context, + private var _map : ContextMap, + private var _count : Int + ) extends IncrementalState { + + def this() = this(Set.empty, Set.empty, new ContextMap, 0) + def this(ctx: InstantiationContext) = this(ctx._instantiated, Set.empty, ctx._map.clone, ctx._count) + + private val stack = new scala.collection.mutable.Stack[(Context, Context, ContextMap, Int)] + + def clear(): Unit = { + stack.clear() + _instantiated = Set.empty + _next = Set.empty + _map = new ContextMap + _count = 0 + } + + def reset(): Unit = clear() + + def push(): Unit = stack.push((_instantiated, _next, _map.clone, _count)) + + def pop(): Unit = { + val (instantiated, next, map, count) = stack.pop() + _instantiated = instantiated + _next = next + _map = map + _count = count + } + + def count = _count + def instantiated = _instantiated + def all = _instantiated ++ _next + + def corresponding(m: Matcher[T]): Context = _map.get(m) + def corresponding(caller: T, tpe: TypeTree): Context = _map.get(caller, tpe) + + def apply(p: (T, Matcher[T])): Boolean = _instantiated(p) + + def inc(): Unit = _count += 1 - private def mappings(blocker: T, matcher: Matcher[T]) - (implicit instantiated: Iterable[(T, Matcher[T])]): Set[(T, Map[T, T])] = { + def +=(p: (T, Matcher[T])): Unit = { + if (!this(p)) _next += p + } + + def ++=(ps: Iterable[(T, Matcher[T])]): Unit = { + for (p <- ps) this += p + } + + def consume: Iterator[(T, Matcher[T])] = { + var n = _next + _next = Set.empty + + new Iterator[(T, Matcher[T])] { + def hasNext = n.nonEmpty + def next = { + val p @ (b,m) = n.head + _instantiated += p + _map += p + n -= p + p + } + } + } + + def instantiateNext: Instantiation[T] = { + var instantiation = Instantiation.empty[T] + for ((b,m) <- consume) { + println("consuming " + (b -> m)) + for (q <- quantifications) { + instantiation ++= q.instantiate(b, m)(this) + } + } + instantiation + } + + def merge(that: InstantiationContext): this.type = { + _instantiated ++= that._instantiated + _next ++= that._next + _map.merge(that._map) + _count = _count max that._count + this + } + } + + private trait MatcherQuantification { + val quantified: Set[T] + val matchers: Set[Matcher[T]] + val allMatchers: Map[T, Set[Matcher[T]]] + val condVars: Map[Identifier, T] + val exprVars: Map[Identifier, T] + val clauses: Seq[T] + val blockers: Map[T, Set[TemplateCallInfo[T]]] + val applications: Map[T, Set[App[T]]] + val lambdas: Seq[LambdaTemplate[T]] + + private def mappings(blocker: T, matcher: Matcher[T], instCtx: InstantiationContext): Set[(T, Map[T, T])] = { // Build a mapping from applications in the quantified statement to all potential concrete // applications previously encountered. Also make sure the current `app` is in the mapping @@ -175,13 +311,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage .map(qm => if (qm == bindingMatcher) { bindingMatcher -> Set(blocker -> matcher) } else { - val instances: Set[(T, Matcher[T])] = instantiated.filter { case (b, m) => correspond(qm, m) }.toSet - - // concrete applications can appear multiple times in the constraint, and this is also the case - // for the current application for which we are generating the constraints - val withCurrent = if (correspond(qm, matcher)) instances + (blocker -> matcher) else instances - - qm -> withCurrent + qm -> instCtx.corresponding(qm) }).toMap // 2.2. based on the possible bindings for each quantified application, build a set of @@ -192,61 +322,86 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage for (mapping <- matcherMappings) yield extractSubst(quantified, mapping) } - private def extractSlaveInfo(enabler: T, senabler: T, subst: Map[T,T], ssubst: Map[T,T]): (T, Map[T,T]) = { - val substituter = encoder.substitute(subst) - val slaveEnabler = encoder.mkAnd(enabler, substituter(senabler)) - val slaveSubst = ssubst.map(p => p._1 -> substituter(p._2)) - (slaveEnabler, slaveSubst) - } - - private def instantiate(enabler: T, subst: Map[T, T], seen: Set[Quantification]): Instantiation[T] = { - if (seen(this)) { - Instantiation.empty[T] - } else { - val nextQ2Var = encoder.encodeId(q2s._1) + def instantiate(blocker: T, matcher: Matcher[T])(implicit instCtx: InstantiationContext): Instantiation[T] = { + var instantiation = Instantiation.empty[T] + for ((enabler, subst) <- mappings(blocker, matcher, instCtx)) { val baseSubstMap = (condVars ++ exprVars).map { case (id, idT) => idT -> encoder.encodeId(id) } - val lambdaSubstMap = lambdas map { case (idT, lambda) => idT -> encoder.encodeId(lambda.id) } - val substMap = subst ++ baseSubstMap ++ lambdaSubstMap + - (qs._2 -> currentQ2Var) + (guardVar -> enabler) + (q2s._2 -> nextQ2Var) + - (insts._2 -> encoder.encodeId(insts._1)) + val lambdaSubstMap = lambdas map(lambda => lambda.ids._2 -> encoder.encodeId(lambda.ids._1)) + val substMap = subst ++ baseSubstMap ++ lambdaSubstMap ++ instanceSubst(enabler) - var instantiation = Template.instantiate(encoder, QuantificationManager.this, + instantiation ++= Template.instantiate(encoder, QuantificationManager.this, clauses, blockers, applications, Seq.empty, Map.empty[T, Set[Matcher[T]]], lambdas, substMap) - for { - (senabler, ssubst, slave) <- slaves - (slaveEnabler, slaveSubst) = extractSlaveInfo(enabler, senabler, subst, ssubst) - } instantiation ++= slave.instantiate(slaveEnabler, slaveSubst, seen + this) - - currentQ2Var = nextQ2Var - instantiation + val substituter = encoder.substitute(substMap) + for ((b, ms) <- allMatchers; m <- ms if !matchers(m)) { + println(m.substitute(substituter)) + instCtx += substituter(b) -> m.substitute(substituter) + } } + + instantiation } - def register(senabler: T, ssubst: Map[T, T], slave: Quantification): Instantiation[T] = { - var instantiation = Instantiation.empty[T] + protected def instanceSubst(enabler: T): Map[T, T] + } + + private class Quantification ( + val qs: (Identifier, T), + val q2s: (Identifier, T), + val insts: (Identifier, T), + val guardVar: T, + val quantified: Set[T], + val matchers: Set[Matcher[T]], + val allMatchers: Map[T, Set[Matcher[T]]], + val condVars: Map[Identifier, T], + val exprVars: Map[Identifier, T], + val clauses: Seq[T], + val blockers: Map[T, Set[TemplateCallInfo[T]]], + val applications: Map[T, Set[App[T]]], + val lambdas: Seq[LambdaTemplate[T]]) extends MatcherQuantification { - for { - instantiated <- List(instantiated, fInstantiated) - (blocker, matcher) <- instantiated - (enabler, subst) <- mappings(blocker, matcher)(instantiated) - (slaveEnabler, slaveSubst) = extractSlaveInfo(enabler, senabler, subst, ssubst) - } instantiation ++= slave.instantiate(slaveEnabler, slaveSubst, Set(this)) + var currentQ2Var: T = qs._2 - slaves :+= (senabler, ssubst, slave) + protected def instanceSubst(enabler: T): Map[T, T] = { + val nextQ2Var = encoder.encodeId(q2s._1) - instantiation + val subst = Map(qs._2 -> currentQ2Var, guardVar -> enabler, + q2s._2 -> nextQ2Var, insts._2 -> encoder.encodeId(insts._1)) + + currentQ2Var = nextQ2Var + subst } + } - def instantiate(blocker: T, matcher: Matcher[T])(implicit instantiated: Iterable[(T, Matcher[T])]): Instantiation[T] = { - var instantiation = Instantiation.empty[T] + private val blockerId = FreshIdentifier("blocker", BooleanType, true) + private val blockerCache: MutableMap[T, T] = MutableMap.empty - for ((enabler, subst) <- mappings(blocker, matcher)) { - instantiation ++= instantiate(enabler, subst, Set.empty) + private class Axiom ( + val start: T, + val blocker: T, + val guardVar: T, + val quantified: Set[T], + val matchers: Set[Matcher[T]], + val allMatchers: Map[T, Set[Matcher[T]]], + val condVars: Map[Identifier, T], + val exprVars: Map[Identifier, T], + val clauses: Seq[T], + val blockers: Map[T, Set[TemplateCallInfo[T]]], + val applications: Map[T, Set[App[T]]], + val lambdas: Seq[LambdaTemplate[T]]) extends MatcherQuantification { + + protected def instanceSubst(enabler: T): Map[T, T] = { + val newBlocker = blockerCache.get(enabler) match { + case Some(b) => b + case None => + val nb = encoder.encodeId(blockerId) + blockerCache(enabler) = nb + blockerCache(nb) = nb + nb } - instantiation + Map(guardVar -> encoder.mkAnd(start, enabler), blocker -> newBlocker) } } @@ -272,6 +427,8 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage val withSubs = s ++ s.flatMap { case (b, sm, m) => subBindings(b, sm, m) } withSubs.groupBy(_._2).forall(_._2.size == 1) } + + allMappings } private def extractSubst(quantified: Set[T], mapping: Set[(T, Matcher[T], Matcher[T])]): (T, Map[T,T]) = { @@ -300,28 +457,133 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage (encoder.substitute(subst)(enabler), subst) } - def instantiateQuantification(template: QuantificationTemplate[T], substMap: Map[T, T]): Instantiation[T] = { - val quantified = template.quantifiers.map(_._2).toSet - - val allMatchers: Map[T, Set[Matcher[T]]] = { - def rec(templates: Map[T, LambdaTemplate[T]]): Map[T, Set[Matcher[T]]] = - templates.foldLeft(Map.empty[T, Set[Matcher[T]]]) { - case (matchers, (_, template)) => matchers merge template.matchers merge rec(template.lambdas) + private def extractQuorums( + quantified: Set[T], + matchers: Set[Matcher[T]], + lambdas: Seq[LambdaTemplate[T]] + ): Seq[Set[Matcher[T]]] = { + val extMatchers: Set[Matcher[T]] = { + def rec(templates: Seq[LambdaTemplate[T]]): Set[Matcher[T]] = + templates.foldLeft(Set.empty[Matcher[T]]) { + case (matchers, template) => matchers ++ template.matchers.flatMap(_._2) ++ rec(template.lambdas) } - template.matchers merge rec(template.lambdas) + matchers ++ rec(lambdas) } - val quantifiedMatchers = (for { - (_, ms) <- allMatchers - m @ Matcher(_, _, args, _) <- ms + val quantifiedMatchers = for { + m @ Matcher(_, _, args, _) <- extMatchers if args exists (_.left.exists(quantified)) - } yield m).toSet + } yield m - val matchQuorums: Seq[Set[Matcher[T]]] = purescala.Quantification.extractQuorums( - quantifiedMatchers, quantified, + purescala.Quantification.extractQuorums(quantifiedMatchers, quantified, (m: Matcher[T]) => m.args.collect { case Right(m) if quantifiedMatchers(m) => m }.toSet, (m: Matcher[T]) => m.args.collect { case Left(a) if quantified(a) => a }.toSet) + } + + private val lambdaAxioms: MutableSet[(LambdaTemplate[T], Seq[(Identifier, T)])] = MutableSet.empty + + def instantiateAxiom(template: LambdaTemplate[T], substMap: Map[T, T]): Instantiation[T] = { + val quantifiers = template.arguments map { + case (id, idT) => id -> substMap(idT) + } filter (p => isQuantifier(p._2)) + + if (quantifiers.isEmpty || lambdaAxioms(template -> quantifiers)) { + Instantiation.empty[T] + } else { + lambdaAxioms += template -> quantifiers + val blockerT = encoder.encodeId(blockerId) + + val guard = FreshIdentifier("guard", BooleanType, true) + val guardT = encoder.encodeId(guard) + + val substituter = encoder.substitute(substMap + (template.start -> blockerT)) + val allMatchers = template.matchers map { case (b, ms) => substituter(b) -> ms.map(_.substitute(substituter)) } + val qMatchers = allMatchers.flatMap(_._2).toSet + + val encArgs = template.args map substituter + val app = Application(Variable(template.ids._1), template.arguments.map(_._1.toVariable)) + val appT = encoder.encodeExpr((template.arguments.map(_._1) zip encArgs).toMap + template.ids)(app) + val selfMatcher = Matcher(template.ids._2, template.tpe, encArgs.map(Left(_)), appT) + + val enablingClause = encoder.mkImplies(guardT, blockerT) + + instantiateAxiom( + substMap(template.start), + blockerT, + guardT, + quantifiers, + qMatchers, + allMatchers + (template.start -> (allMatchers.getOrElse(template.start, Set.empty) + selfMatcher)), + template.condVars map { case (id, idT) => id -> substituter(idT) }, + template.exprVars map { case (id, idT) => id -> substituter(idT) }, + (template.clauses map substituter) :+ enablingClause, + template.blockers map { case (b, fis) => + substituter(b) -> fis.map(fi => fi.copy(args = fi.args.map(substituter))) + }, + template.applications map { case (b, apps) => + substituter(b) -> apps.map(app => app.copy(caller = substituter(app.caller), args = app.args.map(substituter))) + }, + template.lambdas map (_.substitute(substituter)) + ) + } + } + + def instantiateAxiom( + start: T, + blocker: T, + guardVar: T, + quantifiers: Seq[(Identifier, T)], + matchers: Set[Matcher[T]], + allMatchers: Map[T, Set[Matcher[T]]], + condVars: Map[Identifier, T], + exprVars: Map[Identifier, T], + clauses: Seq[T], + blockers: Map[T, Set[TemplateCallInfo[T]]], + applications: Map[T, Set[App[T]]], + lambdas: Seq[LambdaTemplate[T]] + ): Instantiation[T] = { + val quantified = quantifiers.map(_._2).toSet + val matchQuorums = extractQuorums(quantified, matchers, lambdas) + + var instantiation = Instantiation.empty[T] + + for (matchers <- matchQuorums) { + val axiom = new Axiom(start, blocker, guardVar, quantified, + matchers, allMatchers, condVars, exprVars, + clauses, blockers, applications, lambdas + ) + + quantifications += axiom + + for (instCtx <- List(instantiated, fInstantiated)) { + val pCtx = new InstantiationContext(instCtx) + + for ((b, m) <- pCtx.instantiated) { + instantiation ++= axiom.instantiate(b, m)(pCtx) + } + + for (i <- (1 to instCtx.count)) { + instantiation ++= pCtx.instantiateNext + } + + instCtx.merge(pCtx) + } + } + + val quantifierSubst = uniformSubst(quantifiers) + val substituter = encoder.substitute(quantifierSubst) + + for (m <- matchers) { + instantiation ++= instantiateMatcher(trueT, m.substitute(substituter), fInstantiated) + } + + instantiation + } + + def instantiateQuantification(template: QuantificationTemplate[T], substMap: Map[T, T]): Instantiation[T] = { + val quantified = template.quantifiers.map(_._2).toSet + val matchQuorums = extractQuorums(quantified, template.matchers.flatMap(_._2).toSet, template.lambdas) var instantiation = Instantiation.empty[T] @@ -333,8 +595,8 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage val quantification = new Quantification(template.qs._1 -> newQ, template.q2s, template.insts, template.guardVar, quantified, - matchers map (m => m.substitute(substituter)), - allMatchers map { case (b, ms) => substituter(b) -> ms.map(_.substitute(substituter)) }, + matchers map (_.substitute(substituter)), + template.matchers map { case (b, ms) => substituter(b) -> ms.map(_.substitute(substituter)) }, template.condVars, template.exprVars, template.clauses map substituter, @@ -344,52 +606,25 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage template.applications map { case (b, fas) => substituter(b) -> fas.map(fa => fa.copy(caller = substituter(fa.caller), args = fa.args.map(substituter))) }, - template.lambdas map { case (idT, template) => substituter(idT) -> template.substitute(subst) } + template.lambdas map (_.substitute(substituter)) ) - def extendClauses(master: Quantification, slave: Quantification): Instantiation[T] = { - val bindingsMap: Map[Matcher[T], Set[(T, Matcher[T])]] = slave.matchers.map { sm => - val instances = master.allMatchers.toSeq.flatMap { case (b, ms) => ms.map(b -> _) } - sm -> instances.filter(p => correspond(p._2, sm)).toSet - }.toMap - - val allMappings = extractMappings(bindingsMap) - val filteredMappings = allMappings.filter { s => - s.exists { case (b, sm, m) => !master.matchers(m) } && - s.exists { case (b, sm, m) => sm != m } && - s.forall { case (b, sm, m) => - (sm.args zip m.args).forall { - case (Right(_), Left(_)) => false - case _ => true - } - } - } + quantifications += quantification - var instantiation = Instantiation.empty[T] + for (instCtx <- List(instantiated, fInstantiated)) { + val pCtx = new InstantiationContext(instCtx) - for (mapping <- filteredMappings) { - val (enabler, subst) = extractSubst(slave.quantified, mapping) - instantiation ++= master.register(enabler, subst, slave) + for ((b, m) <- pCtx.instantiated) { + instantiation ++= quantification.instantiate(b, m)(pCtx) } - instantiation - } - - val allSet = quantification.allMatchers.flatMap(_._2).toSet - for (q <- quantifications) { - val allQSet = q.allMatchers.flatMap(_._2).toSet - if (quantification.matchers.forall(m => allQSet.exists(qm => correspond(qm, m)))) - instantiation ++= extendClauses(q, quantification) - - if (q.matchers.forall(qm => allSet.exists(m => correspond(qm, m)))) - instantiation ++= extendClauses(quantification, q) - } + for (i <- (1 to instCtx.count)) { + instantiation ++= pCtx.instantiateNext + } - for (instantiated <- List(instantiated, fInstantiated); (b, m) <- instantiated) { - instantiation ++= quantification.instantiate(b, m)(instantiated) + instCtx.merge(pCtx) } - quantifications += quantification quantification.qs._2 } @@ -405,33 +640,35 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage val substituter = encoder.substitute(substMap ++ quantifierSubst) for ((_, ms) <- template.matchers; m <- ms) { - val sm = m.substitute(substituter) - - if (!fInsts.exists(fm => correspond(sm, fm) && sm.args == fm.args)) { - for (q <- quantifications) { - instantiation ++= q.instantiate(trueT, sm)(fInstantiated) - } - - fInsts += sm - } + instantiation ++= instantiateMatcher(trueT, m.substitute(substituter), fInstantiated) } instantiation } - def instantiateMatcher(blocker: T, matcher: Matcher[T]): Instantiation[T] = { - val qInst = if (instantiated(blocker -> matcher)) Instantiation.empty[T] else { - var instantiation = Instantiation.empty[T] - for (q <- quantifications) { - instantiation ++= q.instantiate(blocker, matcher)(instantiated) + private def instantiateMatcher(blocker: T, matcher: Matcher[T], instCtx: InstantiationContext): Instantiation[T] = { + if (instCtx(blocker -> matcher)) { + Instantiation.empty[T] + } else { + println("instantiating " + (blocker -> matcher)) + var instantiation: Instantiation[T] = Instantiation.empty + + val pCtx = new InstantiationContext(instCtx) + pCtx += blocker -> matcher + pCtx.inc() // pCtx.count == instCtx.count + 1 + + // we just inc()'ed so we can start at 1 (instCtx.count is guaranteed to have increased) + for (i <- (1 to instCtx.count)) { + instantiation ++= pCtx.instantiateNext } - instantiated += (blocker -> matcher) + instantiation ++= instCtx.merge(pCtx).instantiateNext instantiation } - - qInst } + def instantiateMatcher(blocker: T, matcher: Matcher[T]): Instantiation[T] = { + instantiateMatcher(blocker, matcher, instantiated) + } } diff --git a/src/main/scala/leon/solvers/templates/TemplateGenerator.scala b/src/main/scala/leon/solvers/templates/TemplateGenerator.scala index f6484ff7be0152f316d8cf911c56a94320fade5a..f0e0d745baf5492337d922e60db5a18b35e787b8 100644 --- a/src/main/scala/leon/solvers/templates/TemplateGenerator.scala +++ b/src/main/scala/leon/solvers/templates/TemplateGenerator.scala @@ -72,7 +72,7 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], val (bodyConds, bodyExprs, bodyGuarded, bodyLambdas, bodyQuantifications) = if (isRealFunDef) { invocationEqualsBody.map(expr => mkClauses(start, expr, substMap)).getOrElse { - (Map[Identifier,T](), Map[Identifier,T](), Map[Identifier,Seq[Expr]](), Map[T,LambdaTemplate[T]](), Seq[QuantificationTemplate[T]]()) + (Map[Identifier,T](), Map[Identifier,T](), Map[Identifier,Seq[Expr]](), Seq[LambdaTemplate[T]](), Seq[QuantificationTemplate[T]]()) } } else { mkClauses(start, lambdaBody.get, substMap) @@ -134,7 +134,7 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], } def mkClauses(pathVar: Identifier, expr: Expr, substMap: Map[Identifier, T]): - (Map[Identifier,T], Map[Identifier,T], Map[Identifier, Seq[Expr]], Map[T, LambdaTemplate[T]], Seq[QuantificationTemplate[T]]) = { + (Map[Identifier,T], Map[Identifier,T], Map[Identifier, Seq[Expr]], Seq[LambdaTemplate[T]], Seq[QuantificationTemplate[T]]) = { var condVars = Map[Identifier, T]() @inline def storeCond(id: Identifier) : Unit = condVars += id -> encoder.encodeId(id) @@ -165,8 +165,8 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], @inline def registerQuantification(quantification: QuantificationTemplate[T]): Unit = quantifications :+= quantification - var lambdas = Map[T, LambdaTemplate[T]]() - @inline def registerLambda(idT: T, lambda: LambdaTemplate[T]) : Unit = lambdas += idT -> lambda + var lambdas = Seq[LambdaTemplate[T]]() + @inline def registerLambda(lambda: LambdaTemplate[T]) : Unit = lambdas :+= lambda def requireDecomposition(e: Expr) = { exists{ @@ -280,13 +280,12 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], val localSubst: Map[Identifier, T] = substMap ++ condVars ++ exprVars ++ lambdaVars val clauseSubst: Map[Identifier, T] = localSubst ++ (idArgs zip trArgs) val (lambdaConds, lambdaExprs, lambdaGuarded, lambdaTemplates, lambdaQuants) = mkClauses(pathVar, clause, clauseSubst) - assert(lambdaQuants.isEmpty, "Unhandled quantification in lambdas in " + l) val ids: (Identifier, T) = lid -> storeLambda(lid) val dependencies: Map[Identifier, T] = variablesOf(l).map(id => id -> localSubst(id)).toMap val template = LambdaTemplate(ids, encoder, manager, pathVar -> encodedCond(pathVar), - idArgs zip trArgs, lambdaConds, lambdaExprs, lambdaGuarded, lambdaTemplates, localSubst, dependencies, l) - registerLambda(ids._2, template) + idArgs zip trArgs, lambdaConds, lambdaExprs, lambdaGuarded, lambdaQuants, lambdaTemplates, localSubst, dependencies, l) + registerLambda(template) Variable(lid) diff --git a/src/main/scala/leon/solvers/templates/TemplateInfo.scala b/src/main/scala/leon/solvers/templates/TemplateInfo.scala index e298e298a6f828c78dcf4da8de5177f94f16758b..977aeb5711b66c006161ff1af28fe5b9604456eb 100644 --- a/src/main/scala/leon/solvers/templates/TemplateInfo.scala +++ b/src/main/scala/leon/solvers/templates/TemplateInfo.scala @@ -14,6 +14,6 @@ case class TemplateCallInfo[T](tfd: TypedFunDef, args: Seq[T]) { case class TemplateAppInfo[T](template: LambdaTemplate[T], equals: T, args: Seq[T]) { override def toString = { - template.id + "|" + equals + args.mkString("(", ",", ")") + template.ids._1 + "|" + equals + args.mkString("(", ",", ")") } } diff --git a/src/main/scala/leon/solvers/templates/Templates.scala b/src/main/scala/leon/solvers/templates/Templates.scala index 32d273c3937d6ba4b808b79c16edf1ded4ade785..5e7302c549720a0291ecedf592c4a28d181b59fd 100644 --- a/src/main/scala/leon/solvers/templates/Templates.scala +++ b/src/main/scala/leon/solvers/templates/Templates.scala @@ -40,6 +40,12 @@ object Instantiation { def withClause(cl: T): Instantiation[T] = (i._1 :+ cl, i._2, i._3) def withClauses(cls: Seq[T]): Instantiation[T] = (i._1 ++ cls, i._2, i._3) + + def withCalls(calls: CallBlockers[T]): Instantiation[T] = (i._1, i._2 merge calls, i._3) + def withApps(apps: AppBlockers[T]): Instantiation[T] = (i._1, i._2, i._3 merge apps) + def withApp(app: ((T, App[T]), TemplateAppInfo[T])): Instantiation[T] = { + (i._1, i._2, i._3 merge Map(app._1 -> Set(app._2))) + } } } @@ -56,9 +62,9 @@ trait Template[T] { self => val clauses : Seq[T] val blockers : Map[T, Set[TemplateCallInfo[T]]] val applications : Map[T, Set[App[T]]] - val quantifications: Seq[QuantificationTemplate[T]] - val matchers: Map[T, Set[Matcher[T]]] - val lambdas : Map[T, LambdaTemplate[T]] + val quantifications : Seq[QuantificationTemplate[T]] + val matchers : Map[T, Set[Matcher[T]]] + val lambdas : Seq[LambdaTemplate[T]] private var substCache : Map[Seq[T],Map[T,T]] = Map.empty @@ -73,10 +79,13 @@ trait Template[T] { self => subst } - val lambdaSubstMap = lambdas.map { case (idT, lambda) => idT -> encoder.encodeId(lambda.id) } + val lambdaSubstMap = lambdas.map(lambda => lambda.ids._2 -> encoder.encodeId(lambda.ids._1)) val quantificationSubstMap = quantifications.map(q => q.qs._2 -> encoder.encodeId(q.qs._1)) val substMap : Map[T,T] = baseSubstMap ++ lambdaSubstMap ++ quantificationSubstMap + (start -> aVar) + instantiate(substMap) + } + protected def instantiate(substMap: Map[T, T]): Instantiation[T] = { Template.instantiate(encoder, manager, clauses, blockers, applications, quantifications, matchers, lambdas, substMap) } @@ -86,43 +95,6 @@ trait Template[T] { self => object Template { - private object InvocationExtractor { - private def flatInvocation(expr: Expr): Option[(TypedFunDef, Seq[Expr])] = expr match { - case fi @ FunctionInvocation(tfd, args) => Some((tfd, args)) - case Application(caller, args) => flatInvocation(caller) match { - case Some((tfd, prevArgs)) => Some((tfd, prevArgs ++ args)) - case None => None - } - case _ => None - } - - def unapply(expr: Expr): Option[(TypedFunDef, Seq[Expr])] = expr match { - case IsTyped(f: FunctionInvocation, ft: FunctionType) => None - case IsTyped(f: Application, ft: FunctionType) => None - case FunctionInvocation(tfd, args) => Some(tfd -> args) - case f: Application => flatInvocation(f) - case _ => None - } - } - - private object ApplicationExtractor { - private def flatApplication(expr: Expr): Option[(Expr, Seq[Expr])] = expr match { - case Application(fi: FunctionInvocation, _) => None - case Application(caller: Application, args) => flatApplication(caller) match { - case Some((c, prevArgs)) => Some((c, prevArgs ++ args)) - case None => None - } - case Application(caller, args) => Some((caller, args)) - case _ => None - } - - def unapply(expr: Expr): Option[(Expr, Seq[Expr])] = expr match { - case IsTyped(f: Application, ft: FunctionType) => None - case f: Application => flatApplication(f) - case _ => None - } - } - private def invocationMatcher[T](encodeExpr: Expr => T)(tfd: TypedFunDef, args: Seq[Expr]): Matcher[T] = { assert(tfd.returnType.isInstanceOf[FunctionType], "invocationMatcher() is only defined on function-typed defs") @@ -146,16 +118,14 @@ object Template { condVars: Map[Identifier, T], exprVars: Map[Identifier, T], guardedExprs: Map[Identifier, Seq[Expr]], - lambdas: Map[T, LambdaTemplate[T]], + lambdas: Seq[LambdaTemplate[T]], substMap: Map[Identifier, T] = Map.empty[Identifier, T], optCall: Option[TypedFunDef] = None, optApp: Option[(T, FunctionType)] = None ) : (Seq[T], Map[T, Set[TemplateCallInfo[T]]], Map[T, Set[App[T]]], Map[T, Set[Matcher[T]]], () => String) = { - val idToTrId : Map[Identifier, T] = { - condVars ++ exprVars + pathVar ++ arguments ++ substMap ++ - lambdas.map { case (idT, template) => template.id -> idT } - } + val idToTrId : Map[Identifier, T] = + condVars ++ exprVars + pathVar ++ arguments ++ substMap ++ lambdas.map(_.ids) val encodeExpr : Expr => T = encoder.encodeExpr(idToTrId) @@ -180,17 +150,10 @@ object Template { var matchInfos : Set[Matcher[T]] = Set.empty for (e <- es) { - funInfos ++= collect[TemplateCallInfo[T]] { - case InvocationExtractor(tfd, args) => - Set(TemplateCallInfo(tfd, args.map(encodeExpr))) - case _ => Set.empty - }(e) - - appInfos ++= collect[App[T]] { - case ApplicationExtractor(c, args) => - Set(App(encodeExpr(c), c.getType.asInstanceOf[FunctionType], args.map(encodeExpr))) - case _ => Set.empty - }(e) + funInfos ++= firstOrderCallsOf(e).map(p => TemplateCallInfo(p._1, p._2.map(encodeExpr))) + appInfos ++= firstOrderAppsOf(e).map { case (c, args) => + App(encodeExpr(c), c.getType.asInstanceOf[FunctionType], args.map(encodeExpr)) + } matchInfos ++= fold[Map[Expr, Matcher[T]]] { (expr, res) => val result = res.flatten.toMap @@ -247,7 +210,7 @@ object Template { " * Matchers :" + (if (matchers.isEmpty) "\n" else { "\n " + matchers.map(p => p._1 + " ==> " + p._2).mkString("\n ") + "\n" }) + - " * Lambdas :\n" + lambdas.map { case (_, template) => + " * Lambdas :\n" + lambdas.map { case template => " +> " + template.toString.split("\n").mkString("\n ") + "\n" }.mkString("\n") } @@ -263,7 +226,7 @@ object Template { applications: Map[T, Set[App[T]]], quantifications: Seq[QuantificationTemplate[T]], matchers: Map[T, Set[Matcher[T]]], - lambdas: Map[T, LambdaTemplate[T]], + lambdas: Seq[LambdaTemplate[T]], substMap: Map[T, T] ): Instantiation[T] = { @@ -276,10 +239,8 @@ object Template { var instantiation: Instantiation[T] = (newClauses, newBlockers, Map.empty) - for ((idT, lambda) <- lambdas) { - val newIdT = substituter(idT) - val newTemplate = lambda.substitute(substMap) - instantiation ++= manager.instantiateLambda(newIdT, newTemplate) + for (lambda <- lambdas) { + instantiation ++= manager.instantiateLambda(lambda.substitute(substituter)) } for ((b,apps) <- applications; bp = substituter(b); app <- apps) { @@ -292,7 +253,7 @@ object Template { } for (q <- quantifications) { - instantiation ++= q.instantiate(substMap) + instantiation ++= manager.instantiateQuantification(q, substMap) } instantiation @@ -311,7 +272,7 @@ object FunctionTemplate { exprVars: Map[Identifier, T], guardedExprs: Map[Identifier, Seq[Expr]], quantifications: Seq[QuantificationTemplate[T]], - lambdas: Map[T, LambdaTemplate[T]], + lambdas: Seq[LambdaTemplate[T]], isRealFunDef: Boolean ) : FunctionTemplate[T] = { @@ -359,7 +320,7 @@ class FunctionTemplate[T] private( val applications: Map[T, Set[App[T]]], val quantifications: Seq[QuantificationTemplate[T]], val matchers: Map[T, Set[Matcher[T]]], - val lambdas: Map[T, LambdaTemplate[T]], + val lambdas: Seq[LambdaTemplate[T]], isRealFunDef: Boolean, stringRepr: () => String) extends Template[T] { @@ -367,7 +328,7 @@ class FunctionTemplate[T] private( override def toString : String = str override def instantiate(aVar: T, args: Seq[T]): (Seq[T], Map[T, Set[TemplateCallInfo[T]]], Map[(T, App[T]), Set[TemplateAppInfo[T]]]) = { - if (!isRealFunDef) manager.registerFree(tfd.params.map(_.getType) zip args) + if (!isRealFunDef) manager.registerFree(tfd.params.map(_.id) zip args) super.instantiate(aVar, args) } } @@ -383,7 +344,8 @@ object LambdaTemplate { condVars: Map[Identifier, T], exprVars: Map[Identifier, T], guardedExprs: Map[Identifier, Seq[Expr]], - lambdas: Map[T, LambdaTemplate[T]], + quantifications: Seq[QuantificationTemplate[T]], + lambdas: Seq[LambdaTemplate[T]], baseSubstMap: Map[Identifier, T], dependencies: Map[Identifier, T], lambda: Lambda @@ -404,16 +366,17 @@ object LambdaTemplate { val key = structuralLambda.asInstanceOf[Lambda] new LambdaTemplate[T]( - ids._1, + ids, encoder, manager, pathVar._2, - arguments.map(_._2), + arguments, condVars, exprVars, clauses, blockers, applications, + quantifications, matchers, lambdas, keyDeps, @@ -424,30 +387,27 @@ object LambdaTemplate { } class LambdaTemplate[T] private ( - val id: Identifier, + val ids: (Identifier, T), val encoder: TemplateEncoder[T], val manager: QuantificationManager[T], val start: T, - val args: Seq[T], + val arguments: Seq[(Identifier, T)], val condVars: Map[Identifier, T], val exprVars: Map[Identifier, T], val clauses: Seq[T], val blockers: Map[T, Set[TemplateCallInfo[T]]], val applications: Map[T, Set[App[T]]], + val quantifications: Seq[QuantificationTemplate[T]], val matchers: Map[T, Set[Matcher[T]]], - val lambdas: Map[T, LambdaTemplate[T]], + val lambdas: Seq[LambdaTemplate[T]], private[templates] val dependencies: Map[Identifier, T], private[templates] val structuralKey: Lambda, stringRepr: () => String) extends Template[T] { - // Universal quantification is not allowed inside closure bodies! - val quantifications: Seq[QuantificationTemplate[T]] = Seq.empty - - val tpe = id.getType.asInstanceOf[FunctionType] - - def substitute(substMap: Map[T,T]): LambdaTemplate[T] = { - val substituter : T => T = encoder.substitute(substMap) + val args = arguments.map(_._2) + val tpe = ids._1.getType.asInstanceOf[FunctionType] + def substitute(substituter: T => T): LambdaTemplate[T] = { val newStart = substituter(start) val newClauses = clauses.map(substituter) val newBlockers = blockers.map { case (b, fis) => @@ -460,26 +420,29 @@ class LambdaTemplate[T] private ( bp -> fas.map(fa => fa.copy(caller = substituter(fa.caller), args = fa.args.map(substituter))) } + val newQuantifications = quantifications.map(_.substitute(substituter)) + val newMatchers = matchers.map { case (b, ms) => val bp = if (b == start) newStart else b bp -> ms.map(_.substitute(substituter)) } - val newLambdas = lambdas.map { case (idT, template) => idT -> template.substitute(substMap) } + val newLambdas = lambdas.map(_.substitute(substituter)) val newDependencies = dependencies.map(p => p._1 -> substituter(p._2)) new LambdaTemplate[T]( - id, + ids._1 -> substituter(ids._2), encoder, manager, newStart, - args, + arguments, condVars, exprVars, newClauses, newBlockers, newApplications, + newQuantifications, newMatchers, newLambdas, newDependencies, @@ -514,4 +477,8 @@ class LambdaTemplate[T] private ( Some(rec(structuralKey, that.structuralKey)) } } + + override def instantiate(substMap: Map[T, T]): Instantiation[T] = { + super.instantiate(substMap) ++ manager.instantiateAxiom(this, substMap) + } } diff --git a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala index d936a4e3ac73612e62fa15e4eb3ba15102d1c6fb..8add63cce9366fb24c7d50ac79dcc1c9a8e045f1 100644 --- a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala @@ -670,8 +670,12 @@ trait AbstractZ3Solver extends Solver { FiniteMap(elems, from, to) } - case FunctionType(fts, tt) => - rec(t, RawArrayType(tupleTypeWrap(fts), tt)) + case ft @ FunctionType(fts, tt) => + rec(t, RawArrayType(tupleTypeWrap(fts), tt)) match { + case r: RawArrayValue => + val elems = r.elems.toSeq.map { case (k, v) => unwrapTuple(k, fts.size) -> v } + PartialLambda(elems, Some(r.default), ft) + } case tpe @ SetType(dt) => model.getSetValue(t) match { diff --git a/src/main/scala/leon/solvers/z3/FairZ3Solver.scala b/src/main/scala/leon/solvers/z3/FairZ3Solver.scala index 5baa9b9aa22713f1b4e86782570e124d48aac767..105bfa7dc036e790cc10ac5f11f7db47b1a72537 100644 --- a/src/main/scala/leon/solvers/z3/FairZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/FairZ3Solver.scala @@ -56,8 +56,6 @@ class FairZ3Solver(val context: LeonContext, val program: Program) toggleWarningMessages(true) private def extractModel(model: Z3Model, ids: Set[Identifier]): HenkinModel = { - val asMap = modelToMap(model, ids) - def extract(b: Z3AST, m: Matcher[Z3AST]): Set[Seq[Expr]] = { val QuantificationTypeMatcher(fromTypes, _) = m.tpe val optEnabler = model.evalAs[Boolean](b) @@ -99,25 +97,12 @@ class FairZ3Solver(val context: LeonContext, val program: Program) case _ => None }).toMap.mapValues(_.toSet) - val asDMap = asMap.map(p => funDomains.get(p._1) match { - case Some(domain) => - val mapping = domain.toSeq.map { es => - val ev: Expr = p._2 match { - case RawArrayValue(_, mapping, dflt) => - mapping.collectFirst { - case (k,v) if evaluator.eval(Equals(k, tupleWrap(es))).result == Some(BooleanLiteral(true)) => v - } getOrElse dflt - case _ => scala.sys.error("Unexpected function encoding " + p._2) - } - es -> ev - } - p._1 -> PartialLambda(mapping, p._1.getType.asInstanceOf[FunctionType]) - case None => p - }) - val typeGrouped = templateGenerator.manager.instantiations.groupBy(_._2.tpe) val typeDomains = typeGrouped.mapValues(_.flatMap { case (b, m) => extract(b, m) }.toSet) + val asMap = modelToMap(model, ids) + val asDMap = purescala.Quantification.extractModel(asMap, funDomains, typeDomains, evaluator) + val domain = new HenkinDomains(typeDomains) new HenkinModel(asDMap, domain) } diff --git a/src/main/scala/leon/utils/PreprocessingPhase.scala b/src/main/scala/leon/utils/PreprocessingPhase.scala index 87d1003150995eb236e5f3abb738450516c5b7f7..7bfeece2f6f35e88c4557db2ecafe4c684a1198d 100644 --- a/src/main/scala/leon/utils/PreprocessingPhase.scala +++ b/src/main/scala/leon/utils/PreprocessingPhase.scala @@ -5,7 +5,6 @@ package utils import leon.purescala._ import leon.purescala.Definitions.Program -import leon.purescala.Quantification.CheckForalls import leon.solvers.isabelle.AdaptationPhase import leon.verification.InjectAsserts import leon.xlang.{NoXLangFeaturesChecking, XLangDesugaringPhase} @@ -39,8 +38,7 @@ class PreprocessingPhase(desugarXLang: Boolean = false) extends LeonPhase[Progra synthesis.ConversionPhase andThen CheckADTFieldsTypes andThen InjectAsserts andThen - InliningPhase andThen - CheckForalls + InliningPhase val pipeX = if(desugarXLang) { XLangDesugaringPhase andThen