diff --git a/src/main/java/leon/codegen/runtime/Forall.java b/src/main/java/leon/codegen/runtime/Forall.java new file mode 100644 index 0000000000000000000000000000000000000000..f6877b604bcd069af6c2ce20d78a17d82d434dbe --- /dev/null +++ b/src/main/java/leon/codegen/runtime/Forall.java @@ -0,0 +1,35 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon.codegen.runtime; + +import java.util.HashMap; + +public abstract class Forall { + private static final HashMap<Tuple, Boolean> cache = new HashMap<Tuple, Boolean>(); + + protected final LeonCodeGenRuntimeHenkinMonitor monitor; + protected final Tuple closures; + protected final boolean check; + + public Forall(LeonCodeGenRuntimeMonitor monitor, Tuple closures) throws LeonCodeGenEvaluationException { + if (!(monitor instanceof LeonCodeGenRuntimeHenkinMonitor)) + throw new LeonCodeGenEvaluationException("Can't evaluate foralls without domain"); + + this.monitor = (LeonCodeGenRuntimeHenkinMonitor) monitor; + this.closures = closures; + this.check = (boolean) closures.get(closures.getArity() - 1); + } + + public boolean check() { + Tuple key = new Tuple(new Object[] { getClass(), monitor, closures }); // check is in the closures + if (cache.containsKey(key)) { + return cache.get(key); + } else { + boolean res = checkForall(); + cache.put(key, res); + return res; + } + } + + public abstract boolean checkForall(); +} diff --git a/src/main/java/leon/codegen/runtime/Lambda.java b/src/main/java/leon/codegen/runtime/Lambda.java index 0bc5171fd6405f59ab2ec4d60e3bf368c49a7bff..af255726311655efaeddea545c5e6e44afc15b8e 100644 --- a/src/main/java/leon/codegen/runtime/Lambda.java +++ b/src/main/java/leon/codegen/runtime/Lambda.java @@ -5,4 +5,5 @@ package leon.codegen.runtime; public abstract class Lambda { public abstract Object apply(Object[] args) throws LeonCodeGenRuntimeException; public abstract void checkForall(boolean[] quantified); + public abstract void checkAxiom(); } diff --git a/src/main/java/leon/codegen/runtime/LeonCodeGenQuantificationException.java b/src/main/java/leon/codegen/runtime/LeonCodeGenQuantificationException.java new file mode 100644 index 0000000000000000000000000000000000000000..f172316a2548a52c6b294f70101a15ebbb8ce98a --- /dev/null +++ b/src/main/java/leon/codegen/runtime/LeonCodeGenQuantificationException.java @@ -0,0 +1,14 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon.codegen.runtime; + +/** Such exceptions are thrown when the evaluator encounters a forall + * expression whose shape is not supported in Leon. */ +public class LeonCodeGenQuantificationException extends Exception { + + private static final long serialVersionUID = -1824885321497473916L; + + public LeonCodeGenQuantificationException(String msg) { + super(msg); + } +} diff --git a/src/main/java/leon/codegen/runtime/LeonCodeGenRuntimeHenkinMonitor.java b/src/main/java/leon/codegen/runtime/LeonCodeGenRuntimeHenkinMonitor.java index 7314bfae531af0b68432ea5dd5dcf93b51d629af..597beec44b6a1a1719909e00ecb7d7916f0c7c03 100644 --- a/src/main/java/leon/codegen/runtime/LeonCodeGenRuntimeHenkinMonitor.java +++ b/src/main/java/leon/codegen/runtime/LeonCodeGenRuntimeHenkinMonitor.java @@ -7,16 +7,27 @@ import java.util.LinkedList; import java.util.HashMap; public class LeonCodeGenRuntimeHenkinMonitor extends LeonCodeGenRuntimeMonitor { - private final HashMap<Integer, List<Tuple>> domains = new HashMap<Integer, List<Tuple>>(); - private final List<String> warnings = new LinkedList<String>(); + private final HashMap<Integer, List<Tuple>> tpes = new HashMap<Integer, List<Tuple>>(); + private final HashMap<Class<?>, List<Tuple>> lambdas = new HashMap<Class<?>, List<Tuple>>(); + public final boolean checkForalls; - public LeonCodeGenRuntimeHenkinMonitor(int maxInvocations) { + public LeonCodeGenRuntimeHenkinMonitor(int maxInvocations, boolean checkForalls) { super(maxInvocations); + this.checkForalls = checkForalls; + } + + public LeonCodeGenRuntimeHenkinMonitor(int maxInvocations) { + this(maxInvocations, false); } public void add(int type, Tuple input) { - if (!domains.containsKey(type)) domains.put(type, new LinkedList<Tuple>()); - domains.get(type).add(input); + if (!tpes.containsKey(type)) tpes.put(type, new LinkedList<Tuple>()); + tpes.get(type).add(input); + } + + public void add(Class<?> clazz, Tuple input) { + if (!lambdas.containsKey(clazz)) lambdas.put(clazz, new LinkedList<Tuple>()); + lambdas.get(clazz).add(input); } public List<Tuple> domain(Object obj, int type) { @@ -26,19 +37,14 @@ public class LeonCodeGenRuntimeHenkinMonitor extends LeonCodeGenRuntimeMonitor { for (Tuple key : l.mapping.keySet()) { domain.add(key); } + } else if (obj instanceof Lambda) { + List<Tuple> lambdaDomain = lambdas.get(obj.getClass()); + if (lambdaDomain != null) domain.addAll(lambdaDomain); } - List<Tuple> tpeDomain = domains.get(type); + List<Tuple> tpeDomain = tpes.get(type); if (tpeDomain != null) domain.addAll(tpeDomain); return domain; } - - public void warn(String warning) { - warnings.add(warning); - } - - public List<String> getWarnings() { - return warnings; - } } diff --git a/src/main/java/leon/codegen/runtime/PartialLambda.java b/src/main/java/leon/codegen/runtime/PartialLambda.java index 7bab72ea31dfacb6438c7f217da0991d5238a2b2..b04036db5e9f81d1eaf7fa2c9a047bfef45a4df8 100644 --- a/src/main/java/leon/codegen/runtime/PartialLambda.java +++ b/src/main/java/leon/codegen/runtime/PartialLambda.java @@ -29,7 +29,7 @@ public final class PartialLambda extends Lambda { } else if (dflt != null) { return dflt; } else { - throw new LeonCodeGenRuntimeException("Partial function apply on undefined arguments"); + throw new LeonCodeGenRuntimeException("Partial function apply on undefined arguments " + tuple); } } @@ -50,4 +50,7 @@ public final class PartialLambda extends Lambda { @Override public void checkForall(boolean[] quantified) {} + + @Override + public void checkAxiom() {} } diff --git a/src/main/java/leon/codegen/runtime/Tuple.java b/src/main/java/leon/codegen/runtime/Tuple.java index 9ae7a5f490223b0bc3b6a07dbb05160c76dc5e11..3b72931da6fc04b3504d7f3bee2056560c269634 100644 --- a/src/main/java/leon/codegen/runtime/Tuple.java +++ b/src/main/java/leon/codegen/runtime/Tuple.java @@ -54,4 +54,20 @@ public final class Tuple { _hash = h; return h; } + + @Override + public String toString() { + String str = "("; + boolean first = true; + for (Object obj : elements) { + if (first) { + first = false; + } else { + str += ", "; + } + str += obj == null ? "null" : obj.toString(); + } + str += ")"; + return str; + } } diff --git a/src/main/scala/leon/codegen/CodeGeneration.scala b/src/main/scala/leon/codegen/CodeGeneration.scala index 8860b23bc88523f84797e46352251b7b824b1b52..0ac837c752fad1242e035dbfccec0adfcf581871 100644 --- a/src/main/scala/leon/codegen/CodeGeneration.scala +++ b/src/main/scala/leon/codegen/CodeGeneration.scala @@ -6,7 +6,7 @@ package codegen import purescala.Common._ import purescala.Definitions._ import purescala.Expressions._ -import purescala.ExprOps.{simplestValue, matchToIfThenElse, collect} +import purescala.ExprOps.{simplestValue, matchToIfThenElse, collect, variablesOf, CollectorWithPaths} import purescala.Types._ import purescala.Constructors._ import purescala.Extractors._ @@ -47,6 +47,8 @@ trait CodeGeneration { def withArgs(newArgs: Map[Identifier, Int]) = new Locals(vars, args ++ newArgs, fields) def withFields(newFields: Map[Identifier,(String,String,String)]) = new Locals(vars, args, fields ++ newFields) + + override def toString = "Locals("+vars + ", " + args + ", " + fields + ")" } object NoLocals extends Locals(Map.empty, Map.empty, Map.empty) @@ -70,8 +72,10 @@ trait CodeGeneration { private[codegen] val RationalClass = "leon/codegen/runtime/Rational" private[codegen] val CaseClassClass = "leon/codegen/runtime/CaseClass" private[codegen] val LambdaClass = "leon/codegen/runtime/Lambda" + private[codegen] val ForallClass = "leon/codegen/runtime/Forall" private[codegen] val PartialLambdaClass = "leon/codegen/runtime/PartialLambda" private[codegen] val ErrorClass = "leon/codegen/runtime/LeonCodeGenRuntimeException" + private[codegen] val InvalidForallClass = "leon/codegen/runtime/LeonCodeGenQuantificationException" private[codegen] val ImpossibleEvaluationClass = "leon/codegen/runtime/LeonCodeGenEvaluationException" private[codegen] val BadQuantificationClass = "leon/codegen/runtime/LeonCodeGenQuantificationException" private[codegen] val HashingClass = "leon/codegen/runtime/LeonCodeGenRuntimeHashing" @@ -225,8 +229,8 @@ trait CodeGeneration { private[codegen] val lambdaToClass = scala.collection.mutable.Map.empty[Lambda, String] private[codegen] val classToLambda = scala.collection.mutable.Map.empty[String, Lambda] - private def compileLambda(l: Lambda, ch: CodeHandler)(implicit locals: Locals): Unit = { - val (normalized, structSubst) = purescala.ExprOps.normalizeStructure(l) + protected def compileLambda(l: Lambda): (String, Seq[(Identifier, String)], String) = { + val (normalized, structSubst) = purescala.ExprOps.normalizeStructure(matchToIfThenElse(l)) val reverseSubst = structSubst.map(p => p._2 -> p._1) val nl = normalized.asInstanceOf[Lambda] @@ -280,6 +284,10 @@ trait CodeGeneration { cch.freeze } + val argMapping = nl.args.map(_.id).zipWithIndex.toMap + val closureMapping = closures.map { case (id, jvmt) => id -> (afName, id.uniqueName, jvmt) }.toMap + val newLocals = NoLocals.withArgs(argMapping).withFields(closureMapping) + locally { val apm = cf.addMethod(s"L$ObjectClass;", "apply", s"[L$ObjectClass;") @@ -288,11 +296,6 @@ trait CodeGeneration { METHOD_ACC_FINAL ).asInstanceOf[U2]) - val argMapping = nl.args.map(_.id).zipWithIndex.toMap - val closureMapping = closures.map { case (id, jvmt) => id -> (afName, id.uniqueName, jvmt) }.toMap - - val newLocals = locals.withArgs(argMapping).withFields(closureMapping) - val apch = apm.codeHandler mkBoxedExpr(nl.body, apch)(newLocals) @@ -378,7 +381,7 @@ trait CodeGeneration { } locally { - val vmh = cf.addMethod("V", "checkForall", s"[Z") + val vmh = cf.addMethod("V", "checkForall", "[Z") vmh.setFlags(( METHOD_ACC_PUBLIC | METHOD_ACC_FINAL @@ -386,64 +389,126 @@ trait CodeGeneration { val vch = vmh.codeHandler - vch << ALoad(1) // load boolean array `quantified` + vch << ALoad(1) // load argument array def rec(args: Seq[Identifier], idx: Int, quantified: Set[Identifier]): Unit = args match { case x :: xs => val notQuantLabel = vch.getFreshLabel("notQuant") - vch << DUP << ALoad(idx) << IfEq(notQuantLabel) + vch << DUP << Ldc(idx) << BALOAD << IfEq(notQuantLabel) rec(xs, idx + 1, quantified + x) vch << Label(notQuantLabel) rec(xs, idx + 1, quantified) case Nil => - if (quantified.nonEmpty) checkQuantified(quantified, nl.body, vch) + if (quantified.nonEmpty) { + checkQuantified(quantified, nl.body, vch)(newLocals) + vch << ALoad(0) << InvokeVirtual(LambdaClass, "checkAxiom", "()V") + } vch << POP << RETURN } - rec(nl.args.map(_.id), 0, Set.empty) + if (requireQuantification) { + rec(nl.args.map(_.id), 0, Set.empty) + } else { + vch << POP << RETURN + } vch.freeze } + locally { + val vmh = cf.addMethod("V", "checkAxiom") + vmh.setFlags(( + METHOD_ACC_PUBLIC | + METHOD_ACC_FINAL + ).asInstanceOf[U2]) + + val vch = vmh.codeHandler + + if (requireQuantification) { + val thisVar = FreshIdentifier("this", l.getType) + val axiom = Equals(Application(Variable(thisVar), nl.args.map(_.toVariable)), nl.body) + val axiomLocals = NoLocals.withFields(closureMapping).withVar(thisVar -> 0) + + mkForall(nl.args.map(_.id).toSet, axiom, vch, check = false)(axiomLocals) + + val skip = vch.getFreshLabel("skip") + vch << IfNe(skip) + vch << New(InvalidForallClass) << DUP + vch << Ldc("Unaxiomatic lambda " + l) + vch << InvokeSpecial(InvalidForallClass, constructorName, "(Ljava/lang/String;)V") + vch << ATHROW + vch << Label(skip) + } + + vch << RETURN + vch.freeze + } + loader.register(cf) afName }) - val consSig = "(" + closures.map(_._2).mkString("") + ")V" - - ch << New(afName) << DUP - for ((id,jvmt) <- closures) { - if (id == monitorID) { - load(monitorID, ch) - } else { - mkExpr(Variable(reverseSubst(id)), ch) - } - } - ch << InvokeSpecial(afName, constructorName, consSig) + (afName, closures.map { case p @ (id, jvmt) => + if (id == monitorID) p else (reverseSubst(id) -> jvmt) + }, "(" + closures.map(_._2).mkString("") + ")V") } private def checkQuantified(quantified: Set[Identifier], body: Expr, ch: CodeHandler)(implicit locals: Locals): Unit = { - val status = checkForall(quantified, body) - if (status.isValid) { - purescala.ExprOps.preTraversal { - case Application(caller, args) => - ch << NewArray.primitive("T_BOOLEAN") - for ((arg, idx) <- args.zipWithIndex) { - ch << DUP << Ldc(idx) << Ldc(arg match { - case Variable(id) if quantified(id) => 1 - case _ => 0 - }) << BASTORE + val skipCheck = ch.getFreshLabel("skipCheck") + + load(monitorID, ch) + ch << CheckCast(HenkinClass) << GetField(HenkinClass, "checkForalls", "Z") + ch << IfEq(skipCheck) + + checkForall(quantified, body)(ctx) match { + case status: ForallInvalid => + ch << New(InvalidForallClass) << DUP + ch << Ldc("Invalid forall: " + status.getMessage) + ch << InvokeSpecial(InvalidForallClass, constructorName, "(Ljava/lang/String;)V") + ch << ATHROW + + case ForallValid => + // expand match case expressions and lets so that caller can be compiled given + // the current locals (lets and matches introduce new locals) + val cleanBody = purescala.ExprOps.expandLets(purescala.ExprOps.matchToIfThenElse(body)) + + val calls = new CollectorWithPaths[(Expr, Seq[Expr], Seq[Expr])] { + def collect(e: Expr, path: Seq[Expr]): Option[(Expr, Seq[Expr], Seq[Expr])] = e match { + case QuantificationMatcher(IsTyped(caller, _: FunctionType), args) => Some((caller, args, path)) + case _ => None } - ch << InvokeVirtual(LambdaClass, "checkForall", "([Z)V") - case _ => - } (body) - } else { - load(monitorID, ch) - ch << Ldc("Invalid forall: " + status) - ch << InvokeVirtual(HenkinClass, "warn", "(Ljava/lang/String;)V") + override def rec(e: Expr, path: Seq[Expr]): Expr = e match { + case l : Lambda => l + case _ => super.rec(e, path) + } + }.traverse(cleanBody) + + for ((caller, args, paths) <- calls) { + if ((variablesOf(caller) & quantified).isEmpty) { + val enabler = andJoin(paths.filter(expr => (variablesOf(expr) & quantified).isEmpty)) + val skipCall = ch.getFreshLabel("skipCall") + mkExpr(enabler, ch) + ch << IfEq(skipCall) + + mkExpr(caller, ch) + ch << Ldc(args.size) << NewArray.primitive("T_BOOLEAN") + for ((arg, idx) <- args.zipWithIndex) { + ch << DUP << Ldc(idx) << Ldc(arg match { + case Variable(id) if quantified(id) => 1 + case _ => 0 + }) << BASTORE + } + + ch << InvokeVirtual(LambdaClass, "checkForall", "([Z)V") + + ch << Label(skipCall) + } + } } + + ch << Label(skipCheck) } private val typeIdCache = scala.collection.mutable.Map.empty[TypeTree, Int] @@ -455,137 +520,270 @@ trait CodeGeneration { id } - private def compileForall(f: Forall, ch: CodeHandler)(implicit locals: Locals): Unit = { - // make sure we have an available HenkinModel - val monitorOk = ch.getFreshLabel("monitorOk") + private[codegen] val forallToClass = scala.collection.mutable.Map.empty[Expr, String] + + private def mkForall(quants: Set[Identifier], body: Expr, ch: CodeHandler, check: Boolean = true)(implicit locals: Locals): Unit = { + val (afName, closures, consSig) = compileForall(quants, body) + ch << New(afName) << DUP load(monitorID, ch) - ch << InstanceOf(HenkinClass) << IfNe(monitorOk) - ch << New(ImpossibleEvaluationClass) << DUP - ch << Ldc("Can't evaluate foralls without domain") - ch << InvokeSpecial(ImpossibleEvaluationClass, constructorName, "(Ljava/lang/String;)V") - ch << ATHROW - ch << Label(monitorOk) - - val quantified = f.args.map(_.id).toSet - checkQuantified(quantified, f.body, ch) - - val Forall(fargs, TopLevelAnds(conjuncts)) = f - val endLabel = ch.getFreshLabel("forallEnd") - - for (conj <- conjuncts) { - val vars = purescala.ExprOps.variablesOf(conj) - val args = fargs.map(_.id).filter(vars) - val quantified = args.toSet - - val matchQuorums = extractQuorums(conj, quantified) - - val matcherIndexes = matchQuorums.flatten.distinct.zipWithIndex.toMap - - def buildLoops( - mis: List[(Expr, Seq[Expr], Int)], - localMapping: Map[Identifier, Int], - pointerMapping: Map[(Int, Int), Identifier] - ): Unit = mis match { - case (expr, args, qidx) :: rest => - load(monitorID, ch) - ch << CheckCast(HenkinClass) + mkTuple(closures.map(_.toVariable) :+ BooleanLiteral(check), ch) + ch << InvokeSpecial(afName, constructorName, consSig) + ch << InvokeVirtual(ForallClass, "check", "()Z") + } - mkExpr(expr, ch) - ch << Ldc(typeId(expr.getType)) - ch << InvokeVirtual(HenkinClass, "domain", s"(L$ObjectClass;I)L$JavaListClass;") - ch << InvokeInterface(JavaListClass, "iterator", s"()L$JavaIteratorClass;") + private def compileForall(quants: Set[Identifier], body: Expr): (String, Seq[Identifier], String) = { + val (nl, structSubst) = purescala.ExprOps.normalizeStructure(matchToIfThenElse(body)) + val reverseSubst = structSubst.map(p => p._2 -> p._1) + val nquants = quants.flatMap(structSubst.get) - val loop = ch.getFreshLabel("loop") - val out = ch.getFreshLabel("out") - ch << Label(loop) - // it - ch << DUP - // it, it - ch << InvokeInterface(JavaIteratorClass, "hasNext", "()Z") - // it, hasNext - ch << IfEq(out) << DUP - // it, it - ch << InvokeInterface(JavaIteratorClass, "next", s"()L$ObjectClass;") - // it, elem - ch << CheckCast(TupleClass) - - val (newLoc, newPtr) = (for ((arg, aidx) <- args.zipWithIndex) yield { - val id = FreshIdentifier("q", arg.getType, true) - val slot = ch.getFreshVar - - ch << DUP << Ldc(aidx) << InvokeVirtual(TupleClass, "get", s"(I)L$ObjectClass;") - mkUnbox(arg.getType, ch) - ch << (typeToJVM(arg.getType) match { - case "I" | "Z" => IStore(slot) - case _ => AStore(slot) - }) - - (id -> slot, (qidx -> aidx) -> id) - }).unzip - - ch << POP - // it - - buildLoops(rest, localMapping ++ newLoc, pointerMapping ++ newPtr) - - ch << Goto(loop) - ch << Label(out) << POP - - case Nil => - var okLabel: Option[String] = None - for (quorum <- matchQuorums) { - okLabel.foreach(ok => ch << Label(ok)) - okLabel = Some(ch.getFreshLabel("quorumOk")) - - var mappings: Seq[(Identifier, Int, Int)] = Seq.empty - var constraints: Seq[(Expr, Int, Int)] = Seq.empty - var equalities: Seq[((Int, Int), (Int, Int))] = Seq.empty - - for (q @ (expr, args) <- quorum) { - val qidx = matcherIndexes(q) - val (qmappings, qconstraints) = args.zipWithIndex.partition { - case (Variable(id), aidx) => quantified(id) - case _ => false - } - - mappings ++= qmappings.map(p => (p._1.asInstanceOf[Variable].id, qidx, p._2)) - constraints ++= qconstraints.map(p => (p._1, qidx, p._2)) - } + val closures = (purescala.ExprOps.variablesOf(nl) -- nquants).toSeq.sortBy(_.uniqueName) - val mapping = for ((id, es) <- mappings.groupBy(_._1)) yield { - val base :: others = es.toList.map(p => (p._2, p._3)) - equalities ++= others.map(p => base -> p) - (id -> base) - } + val afName = forallToClass.getOrElse(nl, { + val afName = "Leon$CodeGen$Forall$" + forallCounter.nextGlobal + forallToClass += nl -> afName - val enabler = andJoin(constraints.map { - case (e, qidx, aidx) => Equals(e, pointerMapping(qidx -> aidx).toVariable) - } ++ equalities.map { - case (k1, k2) => Equals(pointerMapping(k1).toVariable, pointerMapping(k2).toVariable) - }) + val cf = new ClassFile(afName, Some(ForallClass)) - mkExpr(enabler, ch)(locals.withVars(localMapping)) - ch << IfEq(okLabel.get) + cf.setFlags(( + CLASS_ACC_SUPER | + CLASS_ACC_PUBLIC | + CLASS_ACC_FINAL + ).asInstanceOf[U2]) - val varsMap = args.map(id => id -> localMapping(pointerMapping(mapping(id)))).toMap - mkExpr(conj, ch)(locals.withVars(varsMap)) - ch << IfNe(okLabel.get) + locally { + val cch = cf.addConstructor(s"L$MonitorClass;", s"L$TupleClass;").codeHandler - // -- Forall is false! -- - // POP all the iterators... - for (_ <- List.range(0, matcherIndexes.size)) ch << POP + cch << ALoad(0) << ALoad(1) << ALoad(2) + cch << InvokeSpecial(ForallClass, constructorName, s"(L$MonitorClass;L$TupleClass;)V") + cch << RETURN + cch.freeze + } - // ... and return false - ch << Ldc(0) << Goto(endLabel) + locally { + val cfm = cf.addMethod("Z", "checkForall") + + cfm.setFlags(( + METHOD_ACC_PUBLIC | + METHOD_ACC_FINAL + ).asInstanceOf[U2]) + + val cfch = cfm.codeHandler + + cfch << ALoad(0) << GetField(ForallClass, "closures", s"L$TupleClass;") + + val closureVars = (for ((id, idx) <- closures.zipWithIndex) yield { + val slot = cfch.getFreshVar + cfch << DUP << Ldc(idx) << InvokeVirtual(TupleClass, "get", s"(I)L$ObjectClass;") + mkUnbox(id.getType, cfch) + cfch << (id.getType match { + case ValueType() => IStore(slot) + case _ => AStore(slot) + }) + id -> slot + }).toMap + + cfch << POP + + val monitorSlot = cfch.getFreshVar + cfch << ALoad(0) << GetField(ForallClass, "monitor", s"L$HenkinClass;") + cfch << AStore(monitorSlot) + + implicit val locals = NoLocals.withVars(closureVars).withVar(monitorID -> monitorSlot) + + val skipCheck = cfch.getFreshLabel("skipCheck") + cfch << ALoad(0) << GetField(ForallClass, "check", "Z") + cfch << IfEq(skipCheck) + checkQuantified(nquants, nl, cfch) + cfch << Label(skipCheck) + + val TopLevelAnds(conjuncts) = nl + val endLabel = cfch.getFreshLabel("forallEnd") + + for (conj <- conjuncts) { + val vars = purescala.ExprOps.variablesOf(conj) + val quantified = nquants.filter(vars) + + val matchQuorums = extractQuorums(conj, quantified) + + var allSlots: List[Int] = Nil + var freeSlots: List[Int] = Nil + def getSlot(): Int = freeSlots match { + case x :: xs => + freeSlots = xs + x + case Nil => + val slot = cfch.getFreshVar + allSlots = allSlots :+ slot + slot } - ch << Label(okLabel.get) + for ((qrm, others) <- matchQuorums) { + val quorum = qrm.toList + + def rec(mis: List[(Expr, Expr, Seq[Expr], Int)], locs: Map[Identifier, Int], pointers: Map[(Int, Int), Identifier]): Unit = mis match { + case (TopLevelAnds(paths), expr, args, qidx) :: rest => + load(monitorID, cfch) + cfch << CheckCast(HenkinClass) + + mkExpr(expr, cfch) + cfch << Ldc(typeId(expr.getType)) + cfch << InvokeVirtual(HenkinClass, "domain", s"(L$ObjectClass;I)L$JavaListClass;") + cfch << InvokeInterface(JavaListClass, "iterator", s"()L$JavaIteratorClass;") + + val loop = cfch.getFreshLabel("loop") + val out = cfch.getFreshLabel("out") + cfch << Label(loop) + // it + cfch << DUP + // it, it + cfch << InvokeInterface(JavaIteratorClass, "hasNext", "()Z") + // it, hasNext + cfch << IfEq(out) << DUP + // it, it + cfch << InvokeInterface(JavaIteratorClass, "next", s"()L$ObjectClass;") + // it, elem + cfch << CheckCast(TupleClass) + + val (newLoc, newPtr) = (for ((arg, aidx) <- args.zipWithIndex) yield { + val id = FreshIdentifier("q", arg.getType, true) + val slot = getSlot() + + cfch << DUP << Ldc(aidx) << InvokeVirtual(TupleClass, "get", s"(I)L$ObjectClass;") + mkUnbox(arg.getType, cfch) + cfch << (typeToJVM(arg.getType) match { + case "I" | "Z" => IStore(slot) + case _ => AStore(slot) + }) + + (id -> slot, (qidx -> aidx) -> id) + }).unzip + + cfch << POP + // it + + rec(rest, locs ++ newLoc, pointers ++ newPtr) + + cfch << Goto(loop) + cfch << Label(out) << POP + + case Nil => + val okLabel = cfch.getFreshLabel("assignmentOk") + + var mappings: Seq[(Identifier, Int, Int)] = Seq.empty + var constraints: Seq[(Expr, Int, Int)] = Seq.empty + var equalities: Seq[((Int, Int), (Int, Int))] = Seq.empty + + for ((q @ (_, expr, args), qidx) <- quorum.zipWithIndex) { + val (qmappings, qconstraints) = args.zipWithIndex.partition { + case (Variable(id), aidx) => quantified(id) + case _ => false + } + + mappings ++= qmappings.map(p => (p._1.asInstanceOf[Variable].id, qidx, p._2)) + constraints ++= qconstraints.map(p => (p._1, qidx, p._2)) + } + + val mapping = for ((id, es) <- mappings.groupBy(_._1)) yield { + val base :: others = es.toList.map(p => (p._2, p._3)) + equalities ++= others.map(p => base -> p) + (id -> base) + } + + val enabler = andJoin(constraints.map { + case (e, qidx, aidx) => Equals(e, pointers(qidx -> aidx).toVariable) + } ++ equalities.map { + case (k1, k2) => Equals(pointers(k1).toVariable, pointers(k2).toVariable) + }) + + val varsMap = quantified.map(id => id -> locs(pointers(mapping(id)))).toMap + val varLocals = locals.withVars(varsMap) + + mkExpr(enabler, cfch)(varLocals.withVars(locs)) + cfch << IfEq(okLabel) + + val checkOk = cfch.getFreshLabel("checkOk") + load(monitorID, cfch) + cfch << GetField(HenkinClass, "checkForalls", "Z") + cfch << IfEq(checkOk) + + var nextLabel: Option[String] = None + for ((b,caller,args) <- others) { + nextLabel.foreach(label => cfch << Label(label)) + nextLabel = Some(cfch.getFreshLabel("next")) + + mkExpr(b, cfch)(varLocals) + cfch << IfEq(nextLabel.get) + + load(monitorID, cfch) + cfch << CheckCast(HenkinClass) + mkExpr(caller, cfch)(varLocals) + cfch << Ldc(typeId(caller.getType)) + cfch << InvokeVirtual(HenkinClass, "domain", s"(L$ObjectClass;I)L$JavaListClass;") + mkTuple(args, cfch)(varLocals) + cfch << InvokeInterface(JavaListClass, "contains", s"(L$ObjectClass;)Z") + cfch << IfNe(nextLabel.get) + + cfch << New(InvalidForallClass) << DUP + cfch << Ldc("Unhandled transitive implication in " + conj) + cfch << InvokeSpecial(InvalidForallClass, constructorName, "(Ljava/lang/String;)V") + cfch << ATHROW + } + nextLabel.foreach(label => cfch << Label(label)) + + cfch << Label(checkOk) + mkExpr(conj, cfch)(varLocals) + cfch << IfNe(okLabel) + + // -- Forall is false! -- + // POP all the iterators... + for (_ <- List.range(0, quorum.size)) cfch << POP + + // ... and return false + cfch << Ldc(0) << Goto(endLabel) + cfch << Label(okLabel) + } + + val skipQuorum = cfch.getFreshLabel("skipQuorum") + for ((TopLevelAnds(paths), _, _) <- quorum) { + val p = andJoin(paths.filter(path => (variablesOf(path) & quantified).isEmpty)) + mkExpr(p, cfch) + cfch << IfEq(skipQuorum) + } + + val mis = quorum.zipWithIndex.map { case ((p, e, as), idx) => (p, e, as, idx) } + rec(mis, Map.empty, Map.empty) + freeSlots = allSlots + + cfch << Label(skipQuorum) + } + } + + cfch << Ldc(1) << Label(endLabel) + cfch << IRETURN + + cfch.freeze } - buildLoops(matcherIndexes.toList.map { case ((e, as), idx) => (e, as, idx) }, Map.empty, Map.empty) - } + loader.register(cf) + + afName + }) + + (afName, closures.map(reverseSubst), s"(L$MonitorClass;L$TupleClass;)V") + } - ch << Ldc(1) << Label(endLabel) + // also makes tuples with 0/1 args + private def mkTuple(es: Seq[Expr], ch: CodeHandler)(implicit locals: Locals) : Unit = { + ch << New(TupleClass) << DUP + ch << Ldc(es.size) + ch << NewArray(s"$ObjectClass") + for((e,i) <- es.zipWithIndex) { + ch << DUP + ch << Ldc(i) + mkBoxedExpr(e, ch) + ch << AASTORE + } + ch << InvokeSpecial(TupleClass, constructorName, s"([L$ObjectClass;)V") } private[codegen] def mkExpr(e: Expr, ch: CodeHandler, canDelegateToMkBranch: Boolean = true)(implicit locals: Locals) { @@ -675,17 +873,7 @@ trait CodeGeneration { instrumentedGetField(ch, cct, sid) // Tuples (note that instanceOf checks are in mkBranch) - case Tuple(es) => - ch << New(TupleClass) << DUP - ch << Ldc(es.size) - ch << NewArray(s"$ObjectClass") - for((e,i) <- es.zipWithIndex) { - ch << DUP - ch << Ldc(i) - mkBoxedExpr(e, ch) - ch << AASTORE - } - ch << InvokeSpecial(TupleClass, constructorName, s"([L$ObjectClass;)V") + case Tuple(es) => mkTuple(es, ch) case TupleSelect(t, i) => val TupleType(bs) = t.getType @@ -938,26 +1126,38 @@ trait CodeGeneration { ch << InvokeVirtual(LambdaClass, "apply", s"([L$ObjectClass;)L$ObjectClass;") mkUnbox(app.getType, ch) - case p @ PartialLambda(mapping, dflt, _) => - if (dflt.isDefined) { - mkExpr(dflt.get, ch) - ch << New(PartialLambdaClass) - ch << InvokeSpecial(PartialLambdaClass, constructorName, s"(L$ObjectClass;)V") - } else { - ch << DefaultNew(PartialLambdaClass) + case p @ PartialLambda(mapping, optDflt, _) => + ch << New(PartialLambdaClass) << DUP + optDflt match { + case Some(dflt) => + mkBoxedExpr(dflt, ch) + ch << InvokeSpecial(PartialLambdaClass, constructorName, s"(L$ObjectClass;)V") + case None => + ch << InvokeSpecial(PartialLambdaClass, constructorName, "()V") } for ((es,v) <- mapping) { - mkExpr(Tuple(es), ch) - mkExpr(v, ch) + ch << DUP + mkTuple(es, ch) + mkBoxedExpr(v, ch) ch << InvokeVirtual(PartialLambdaClass, "add", s"(L$TupleClass;L$ObjectClass;)V") } case l @ Lambda(args, body) => - compileLambda(l, ch) + val (afName, closures, consSig) = compileLambda(l) + + ch << New(afName) << DUP + for ((id,jvmt) <- closures) { + if (id == monitorID) { + load(monitorID, ch) + } else { + mkExpr(Variable(id), ch) + } + } + ch << InvokeSpecial(afName, constructorName, consSig) case f @ Forall(args, body) => - compileForall(f, ch) + mkForall(args.map(_.id).toSet, body, ch) // Arithmetic case Plus(l, r) => @@ -1250,7 +1450,7 @@ trait CodeGeneration { // Assumes the top of the stack contains of value of the right type, and makes it // compatible with java.lang.Object. - private[codegen] def mkBox(tpe: TypeTree, ch: CodeHandler)(implicit locals: Locals) { + private[codegen] def mkBox(tpe: TypeTree, ch: CodeHandler): Unit = { tpe match { case Int32Type => ch << New(BoxedIntClass) << DUP_X1 << SWAP << InvokeSpecial(BoxedIntClass, constructorName, "(I)V") @@ -1268,7 +1468,7 @@ trait CodeGeneration { } // Assumes that the top of the stack contains a value that should be of type `tpe`, and unboxes it to the right (JVM) type. - private[codegen] def mkUnbox(tpe: TypeTree, ch: CodeHandler)(implicit locals: Locals) { + private[codegen] def mkUnbox(tpe: TypeTree, ch: CodeHandler): Unit = { tpe match { case Int32Type => ch << CheckCast(BoxedIntClass) << InvokeVirtual(BoxedIntClass, "intValue", "()I") @@ -1540,7 +1740,7 @@ trait CodeGeneration { lzy.returnType match { case ValueType() => // Since the underlying field only has boxed types, we have to unbox them to return them - mkUnbox(lzy.returnType, ch)(newLocs) + mkUnbox(lzy.returnType, ch) ch << IRETURN case _ => ch << ARETURN @@ -1874,7 +2074,7 @@ trait CodeGeneration { pech << Ldc(i) pech << ALoad(0) instrumentedGetField(pech, cct, f.id)(newLocs) - mkBox(f.getType, pech)(newLocs) + mkBox(f.getType, pech) pech << AASTORE } diff --git a/src/main/scala/leon/codegen/CompilationUnit.scala b/src/main/scala/leon/codegen/CompilationUnit.scala index 70428db8cc0be9664e38b8e011f92ed0ebfb04bf..ac303c79351831a6f1c0dca37e5b283a64ccfa50 100644 --- a/src/main/scala/leon/codegen/CompilationUnit.scala +++ b/src/main/scala/leon/codegen/CompilationUnit.scala @@ -127,13 +127,24 @@ class CompilationUnit(val ctx: LeonContext, conss.last } - def modelToJVM(model: solvers.Model, maxInvocations: Int): LeonCodeGenRuntimeMonitor = model match { + def modelToJVM(model: solvers.Model, maxInvocations: Int, check: Boolean): LeonCodeGenRuntimeMonitor = model match { case hModel: solvers.HenkinModel => - val lhm = new LeonCodeGenRuntimeHenkinMonitor(maxInvocations) - for ((tpe, domain) <- hModel.domains; args <- domain) { + val lhm = new LeonCodeGenRuntimeHenkinMonitor(maxInvocations, check) + for ((lambda, domain) <- hModel.doms.lambdas) { + val (afName, _, _) = compileLambda(lambda) + val lc = loader.loadClass(afName) + + for (args <- domain) { + // note here that it doesn't matter that `lhm` doesn't yet have its domains + // filled since all values in `args` should be grounded + val inputJvm = tupleConstructor.newInstance(args.map(valueToJVM(_)(lhm)).toArray).asInstanceOf[leon.codegen.runtime.Tuple] + lhm.add(lc, inputJvm) + } + } + + for ((tpe, domain) <- hModel.doms.tpes; args <- domain) { val tpeId = typeId(tpe) - // note here that it doesn't matter that `lhm` doesn't yet have its domains - // filled since all values in `args` should be grounded + // same remark as above about valueToJVM(_)(lhm) val inputJvm = tupleConstructor.newInstance(args.map(valueToJVM(_)(lhm)).toArray).asInstanceOf[leon.codegen.runtime.Tuple] lhm.add(tpeId, inputJvm) } @@ -537,3 +548,5 @@ class CompilationUnit(val ctx: LeonContext, private [codegen] object exprCounter extends UniqueCounter[Unit] private [codegen] object lambdaCounter extends UniqueCounter[Unit] +private [codegen] object forallCounter extends UniqueCounter[Unit] + diff --git a/src/main/scala/leon/codegen/CompiledExpression.scala b/src/main/scala/leon/codegen/CompiledExpression.scala index fc2d3bd6470ca900c1515eef74bedf679146b1fa..f9fca911564c61ad984fa97c3f2ac0da7fc021b4 100644 --- a/src/main/scala/leon/codegen/CompiledExpression.scala +++ b/src/main/scala/leon/codegen/CompiledExpression.scala @@ -8,7 +8,7 @@ import purescala.Expressions._ import cafebabe._ -import runtime.{LeonCodeGenRuntimeMonitor => LM, LeonCodeGenRuntimeHenkinMonitor => LHM} +import runtime.{LeonCodeGenRuntimeMonitor => LM} import java.lang.reflect.InvocationTargetException @@ -51,19 +51,10 @@ class CompiledExpression(unit: CompilationUnit, cf: ClassFile, expression: Expr, } } - def eval(model: solvers.Model) : Expr = { + def eval(model: solvers.Model, check: Boolean = false) : Expr = { try { - val monitor = unit.modelToJVM(model, params.maxFunctionInvocations) - val res = evalFromJVM(argsToJVM(argsDecl.map(model), monitor), monitor) - monitor match { - case hm: LHM => - val it = hm.getWarnings().iterator() - while (it.hasNext) { - unit.ctx.reporter.warning(it.next) - } - case _ => - } - res + val monitor = unit.modelToJVM(model, params.maxFunctionInvocations, check) + evalFromJVM(argsToJVM(argsDecl.map(model), monitor), monitor) } catch { case ite : InvocationTargetException => throw ite.getCause } diff --git a/src/main/scala/leon/datagen/VanuatooDataGen.scala b/src/main/scala/leon/datagen/VanuatooDataGen.scala index e68e21f011f3936797157688901018923c2127bf..3bd58a01d575928a5663ff5bf5446d9c161dce64 100644 --- a/src/main/scala/leon/datagen/VanuatooDataGen.scala +++ b/src/main/scala/leon/datagen/VanuatooDataGen.scala @@ -110,13 +110,9 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { val cs = for (size <- List(1, 2, 3, 5)) yield { val subs = (1 to size).flatMap(_ => from :+ to).toList Constructor[Expr, TypeTree](subs, ft, { s => - val args = from.map(tpe => FreshIdentifier("x", tpe, true)) - val argsTuple = tupleWrap(args.map(_.toVariable)) val grouped = s.grouped(from.size + 1).toSeq - val body = grouped.init.foldRight(grouped.last.last) { case (t, elze) => - IfExpr(Equals(argsTuple, tupleWrap(t.init)), t.last, elze) - } - Lambda(args.map(id => ValDef(id)), body) + val mapping = grouped.init.map { case args :+ res => (args -> res) } + PartialLambda(mapping, Some(grouped.last.last), ft) }, ft.asString(ctx) + "@" + size) } constructors += ft -> cs diff --git a/src/main/scala/leon/evaluators/CodeGenEvaluator.scala b/src/main/scala/leon/evaluators/CodeGenEvaluator.scala index 36cd9da0c7c35cfdc9d674162c3279d461afe813..f470a7690535dd22d82480ddb65d944c80c17b80 100644 --- a/src/main/scala/leon/evaluators/CodeGenEvaluator.scala +++ b/src/main/scala/leon/evaluators/CodeGenEvaluator.scala @@ -9,8 +9,13 @@ import purescala.Expressions._ import purescala.Quantification._ import codegen.CompilationUnit +import codegen.CompiledExpression import codegen.CodeGenParams +import leon.codegen.runtime.LeonCodeGenRuntimeException +import leon.codegen.runtime.LeonCodeGenEvaluationException +import leon.codegen.runtime.LeonCodeGenQuantificationException + class CodeGenEvaluator(ctx: LeonContext, val unit : CompilationUnit) extends Evaluator(ctx, unit.program) { val name = "codegen-eval" val description = "Evaluator for PureScala expressions based on compilation to JVM" @@ -20,9 +25,55 @@ class CodeGenEvaluator(ctx: LeonContext, val unit : CompilationUnit) extends Eva this(ctx, new CompilationUnit(ctx, prog, params)) } + private def compileExpr(expression: Expr, args: Seq[Identifier]): Option[CompiledExpression] = { + ctx.timers.evaluators.codegen.compilation.start() + try { + Some(unit.compileExpression(expression, args)(ctx)) + } catch { + case t: Throwable => + ctx.reporter.warning(expression.getPos, "Error while compiling expression: "+t.getMessage) + None + } finally { + ctx.timers.evaluators.codegen.compilation.stop() + } + } + + def check(expression: Expr, model: solvers.Model) : CheckResult = { + compileExpr(expression, model.toSeq.map(_._1)).map { ce => + ctx.timers.evaluators.codegen.runtime.start() + try { + val res = ce.eval(model, check = true) + if (res == BooleanLiteral(true)) EvaluationResults.CheckSuccess + else EvaluationResults.CheckValidityFailure + } catch { + case e : ArithmeticException => + EvaluationResults.CheckRuntimeFailure(e.getMessage) + + case e : ArrayIndexOutOfBoundsException => + EvaluationResults.CheckRuntimeFailure(e.getMessage) + + case e : LeonCodeGenRuntimeException => + EvaluationResults.CheckRuntimeFailure(e.getMessage) + + case e : LeonCodeGenEvaluationException => + EvaluationResults.CheckRuntimeFailure(e.getMessage) + + case e : java.lang.ExceptionInInitializerError => + EvaluationResults.CheckRuntimeFailure(e.getException.getMessage) + + case so : java.lang.StackOverflowError => + EvaluationResults.CheckRuntimeFailure("Stack overflow") + + case e : LeonCodeGenQuantificationException => + EvaluationResults.CheckQuantificationFailure(e.getMessage) + } finally { + ctx.timers.evaluators.codegen.runtime.stop() + } + }.getOrElse(EvaluationResults.CheckRuntimeFailure("Couldn't compile expression.")) + } + def eval(expression: Expr, model: solvers.Model) : EvaluationResult = { - val toPairs = model.toSeq - compile(expression, toPairs.map(_._1)).map { e => + compile(expression, model.toSeq.map(_._1)).map { e => ctx.timers.evaluators.codegen.runtime.start() val res = e(model) ctx.timers.evaluators.codegen.runtime.stop() @@ -31,45 +82,30 @@ class CodeGenEvaluator(ctx: LeonContext, val unit : CompilationUnit) extends Eva } override def compile(expression: Expr, args: Seq[Identifier]) : Option[solvers.Model=>EvaluationResult] = { - import leon.codegen.runtime.LeonCodeGenRuntimeException - import leon.codegen.runtime.LeonCodeGenEvaluationException - - ctx.timers.evaluators.codegen.compilation.start() - try { - val ce = unit.compileExpression(expression, args)(ctx) - - Some((model: solvers.Model) => { - if (args.exists(arg => !model.isDefinedAt(arg))) { - EvaluationResults.EvaluatorError("Model undefined for free arguments") - } else try { - EvaluationResults.Successful(ce.eval(model)) - } catch { - case e : ArithmeticException => - EvaluationResults.RuntimeError(e.getMessage) + compileExpr(expression, args).map(ce => (model: solvers.Model) => { + if (args.exists(arg => !model.isDefinedAt(arg))) { + EvaluationResults.EvaluatorError("Model undefined for free arguments") + } else try { + EvaluationResults.Successful(ce.eval(model)) + } catch { + case e : ArithmeticException => + EvaluationResults.RuntimeError(e.getMessage) - case e : ArrayIndexOutOfBoundsException => - EvaluationResults.RuntimeError(e.getMessage) + case e : ArrayIndexOutOfBoundsException => + EvaluationResults.RuntimeError(e.getMessage) - case e : LeonCodeGenRuntimeException => - EvaluationResults.RuntimeError(e.getMessage) + case e : LeonCodeGenRuntimeException => + EvaluationResults.RuntimeError(e.getMessage) - case e : LeonCodeGenEvaluationException => - EvaluationResults.EvaluatorError(e.getMessage) + case e : LeonCodeGenEvaluationException => + EvaluationResults.EvaluatorError(e.getMessage) - case e : java.lang.ExceptionInInitializerError => - EvaluationResults.RuntimeError(e.getException.getMessage) + case e : java.lang.ExceptionInInitializerError => + EvaluationResults.RuntimeError(e.getException.getMessage) - case so : java.lang.StackOverflowError => - EvaluationResults.RuntimeError("Stack overflow") - - } - }) - } catch { - case t: Throwable => - ctx.reporter.warning(expression.getPos, "Error while compiling expression: "+t.getMessage) - None - } finally { - ctx.timers.evaluators.codegen.compilation.stop() - } + case so : java.lang.StackOverflowError => + EvaluationResults.RuntimeError("Stack overflow") + } + }) } } diff --git a/src/main/scala/leon/evaluators/DefaultEvaluator.scala b/src/main/scala/leon/evaluators/DefaultEvaluator.scala index d732d48c0d40aaacf97cd8125b077c1cef397148..e9b4c6a01229c60130cbd588eb2a67a8738b0281 100644 --- a/src/main/scala/leon/evaluators/DefaultEvaluator.scala +++ b/src/main/scala/leon/evaluators/DefaultEvaluator.scala @@ -13,7 +13,7 @@ class DefaultEvaluator(ctx: LeonContext, prog: Program) extends RecursiveEvaluat type GC = GlobalContext def initRC(mappings: Map[Identifier, Expr]) = DefaultRecContext(mappings) - def initGC(model: solvers.Model) = new GlobalContext(model) + def initGC(model: solvers.Model, check: Boolean) = new GlobalContext(model, check) case class DefaultRecContext(mappings: Map[Identifier, Expr]) extends RecContext { def newVars(news: Map[Identifier, Expr]) = copy(news) diff --git a/src/main/scala/leon/evaluators/DualEvaluator.scala b/src/main/scala/leon/evaluators/DualEvaluator.scala index cd843fbb145e4f9220b2e9fe3d91e30d8ff3c1be..d05043d913d7a9df7546f37b579a211250839bcd 100644 --- a/src/main/scala/leon/evaluators/DualEvaluator.scala +++ b/src/main/scala/leon/evaluators/DualEvaluator.scala @@ -18,7 +18,7 @@ class DualEvaluator(ctx: LeonContext, prog: Program, params: CodeGenParams) exte implicit val debugSection = utils.DebugSectionEvaluation def initRC(mappings: Map[Identifier, Expr]) = DefaultRecContext(mappings) - def initGC(model: solvers.Model) = new GlobalContext(model) + def initGC(model: solvers.Model, check: Boolean) = new GlobalContext(model, check) var monitor = new runtime.LeonCodeGenRuntimeMonitor(params.maxFunctionInvocations) diff --git a/src/main/scala/leon/evaluators/EvaluationResults.scala b/src/main/scala/leon/evaluators/EvaluationResults.scala index e37c61e3b73c8261fb9ca2456abd0fd8811b1877..8628042a61049e5c2f5fba202aa01d22e6fa816b 100644 --- a/src/main/scala/leon/evaluators/EvaluationResults.scala +++ b/src/main/scala/leon/evaluators/EvaluationResults.scala @@ -17,4 +17,21 @@ object EvaluationResults { /** Represents an evaluation that failed (in the evaluator). */ case class EvaluatorError(message : String) extends Result(None) + + /** Results of checking proposition evaluation. + * Useful for verification of model validity in presence of quantifiers. */ + sealed abstract class CheckResult(val success: Boolean) + + /** Successful proposition evaluation (model |= expr) */ + case object CheckSuccess extends CheckResult(true) + + /** Check failed with `model |= !expr` */ + case object CheckValidityFailure extends CheckResult(false) + + /** Check failed due to evaluation or runtime errors. + * @see [[RuntimeError]] and [[EvaluatorError]] */ + case class CheckRuntimeFailure(msg: String) extends CheckResult(false) + + /** Check failed due to inconsistence of model with quantified propositions. */ + case class CheckQuantificationFailure(msg: String) extends CheckResult(false) } diff --git a/src/main/scala/leon/evaluators/Evaluator.scala b/src/main/scala/leon/evaluators/Evaluator.scala index 9d14bd3dfcc0eb6f08f5273b98ac1448038b363e..d64bda7fa5b8862ede59b5a0c35863a8439b565a 100644 --- a/src/main/scala/leon/evaluators/Evaluator.scala +++ b/src/main/scala/leon/evaluators/Evaluator.scala @@ -14,6 +14,7 @@ import solvers.Model abstract class Evaluator(val context: LeonContext, val program: Program) extends LeonComponent { type EvaluationResult = EvaluationResults.Result + type CheckResult = EvaluationResults.CheckResult /** Evaluates an expression, using [[Model.mapping]] as a valuation function for the free variables. */ def eval(expr: Expr, model: Model) : EvaluationResult @@ -26,6 +27,9 @@ abstract class Evaluator(val context: LeonContext, val program: Program) extends /** Evaluates a ground expression. */ final def eval(expr: Expr) : EvaluationResult = eval(expr, Model.empty) + /** Checks that `model |= expr` and that quantifications are all valid */ + def check(expr: Expr, model: Model) : CheckResult + /** Compiles an expression into a function, where the arguments are the free variables in the expression. * `argorder` specifies in which order the arguments should be passed. * The default implementation uses the evaluation function each time, but evaluators are free diff --git a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala index 9e33722dd62c67e1357acd2012d6955647715356..4b3e490228a34d164b816c8dc9d35b530506979e 100644 --- a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala +++ b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala @@ -15,6 +15,8 @@ import purescala.Quantification._ import solvers.{Model, HenkinModel} import solvers.SolverFactory +import scala.collection.mutable.{Map => MutableMap} + abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int) extends Evaluator(ctx, prog) { val name = "evaluator" val description = "Recursive interpreter for PureScala expressions" @@ -24,8 +26,9 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int type RC <: RecContext type GC <: GlobalContext - case class EvalError(msg : String) extends Exception - case class RuntimeError(msg : String) extends Exception + case class EvalError(msg: String) extends Exception + case class RuntimeError(msg: String) extends Exception + case class QuantificationError(msg: String) extends Exception val scalaEv = new ScalacEvaluator(this, ctx, prog) @@ -43,29 +46,50 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int } } - class GlobalContext(val model: Model) { + class GlobalContext(val model: Model, val check: Boolean) { def maxSteps = RecursiveEvaluator.this.maxSteps - var stepsLeft = maxSteps - var warnings = List.empty[String] + + val lambdas: MutableMap[Lambda, Lambda] = MutableMap.empty } def initRC(mappings: Map[Identifier, Expr]): RC - def initGC(model: Model): GC + def initGC(model: Model, check: Boolean): GC // Used by leon-web, please do not delete - // Used by quantified proposition checking now too! var lastGC: Option[GC] = None private[this] var clpCache = Map[(Choose, Seq[Expr]), Expr]() - def eval(ex: Expr, model: Model) = { + def check(ex: Expr, model: Model): CheckResult = { + assert(ex.getType == BooleanType, "Can't check non-boolean expression " + ex.asString(ctx)) try { - lastGC = Some(initGC(model)) + lastGC = Some(initGC(model, true)) ctx.timers.evaluators.recursive.runtime.start() val res = e(ex)(initRC(model.toMap), lastGC.get) - for (warning <- lastGC.get.warnings) ctx.reporter.warning(warning) - EvaluationResults.Successful(res) + if (res == BooleanLiteral(true)) EvaluationResults.CheckSuccess + else EvaluationResults.CheckValidityFailure + } catch { + case so: StackOverflowError => + EvaluationResults.CheckRuntimeFailure("Stack overflow") + case e @ EvalError(msg) => + EvaluationResults.CheckRuntimeFailure(msg) + case e @ RuntimeError(msg) => + EvaluationResults.CheckRuntimeFailure(msg) + case jre: java.lang.RuntimeException => + EvaluationResults.CheckRuntimeFailure(jre.getMessage) + case qe @ QuantificationError(msg) => + EvaluationResults.CheckQuantificationFailure(msg) + } finally { + ctx.timers.evaluators.recursive.runtime.stop() + } + } + + def eval(ex: Expr, model: Model) = { + try { + lastGC = Some(initGC(model, false)) + ctx.timers.evaluators.recursive.runtime.start() + EvaluationResults.Successful(e(ex)(initRC(model.toMap), lastGC.get)) } catch { case so: StackOverflowError => EvaluationResults.EvaluatorError("Stack overflow") @@ -80,6 +104,141 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int } } + private def evalForall(quants: Set[Identifier], body: Expr, check: Boolean = true)(implicit rctx: RC, gctx: GC): Expr = { + val henkinModel: HenkinModel = gctx.model match { + case hm: HenkinModel => hm + case _ => throw EvalError("Can't evaluate foralls without henkin model") + } + + val TopLevelAnds(conjuncts) = body + e(andJoin(conjuncts.flatMap { conj => + val vars = variablesOf(conj) + val quantified = quants.filter(vars) + + extractQuorums(conj, quantified).flatMap { case (qrm, others) => + val quorum = qrm.toList + + if (quorum.exists { case (TopLevelAnds(paths), _, _) => + val p = andJoin(paths.filter(path => (variablesOf(path) & quantified).isEmpty)) + e(p) == BooleanLiteral(false) + }) List(BooleanLiteral(true)) else { + + var mappings: Seq[(Identifier, Int, Int)] = Seq.empty + var constraints: Seq[(Expr, Int, Int)] = Seq.empty + var equalities: Seq[((Int, Int), (Int, Int))] = Seq.empty + + for (((_, expr, args), qidx) <- quorum.zipWithIndex) { + val (qmappings, qconstraints) = args.zipWithIndex.partition { + case (Variable(id),aidx) => quantified(id) + case _ => false + } + + mappings ++= qmappings.map(p => (p._1.asInstanceOf[Variable].id, qidx, p._2)) + constraints ++= qconstraints.map(p => (p._1, qidx, p._2)) + } + + val mapping = for ((id, es) <- mappings.groupBy(_._1)) yield { + val base :: others = es.toList.map(p => (p._2, p._3)) + equalities ++= others.map(p => base -> p) + (id -> base) + } + + def domain(expr: Expr): Set[Seq[Expr]] = henkinModel.domain(e(expr) match { + case l: Lambda => gctx.lambdas.getOrElse(l, l) + case ev => ev + }) + + val argSets = quorum.foldLeft[List[Seq[Seq[Expr]]]](List(Seq.empty)) { + case (acc, (_, expr, _)) => acc.flatMap(s => domain(expr).map(d => s :+ d)) + } + + argSets.map { args => + val argMap: Map[(Int, Int), Expr] = args.zipWithIndex.flatMap { + case (a, qidx) => a.zipWithIndex.map { case (e, aidx) => (qidx, aidx) -> e } + }.toMap + + val map = mapping.map { case (id, key) => id -> argMap(key) } + val enabler = andJoin(constraints.map { + case (e, qidx, aidx) => Equals(e, argMap(qidx -> aidx)) + } ++ equalities.map { + case (k1, k2) => Equals(argMap(k1), argMap(k2)) + }) + + val ctx = rctx.withNewVars(map) + if (e(enabler)(ctx, gctx) == BooleanLiteral(true)) { + if (gctx.check) { + for ((b,caller,args) <- others if e(b)(ctx, gctx) == BooleanLiteral(true)) { + val evArgs = args.map(arg => e(arg)(ctx, gctx)) + if (!domain(caller)(evArgs)) + throw QuantificationError("Unhandled transitive implication in " + replaceFromIDs(map, conj)) + } + } + + e(conj)(ctx, gctx) + } else { + BooleanLiteral(true) + } + } + } + } + })) match { + case res @ BooleanLiteral(true) if check => + if (gctx.check) { + checkForall(quants, body) match { + case status: ForallInvalid => + throw QuantificationError("Invalid forall: " + status.getMessage) + case _ => + // make sure the body doesn't contain matches or lets as these introduce new locals + val cleanBody = expandLets(matchToIfThenElse(body)) + val calls = new CollectorWithPaths[(Expr, Seq[Expr], Seq[Expr])] { + def collect(e: Expr, path: Seq[Expr]): Option[(Expr, Seq[Expr], Seq[Expr])] = e match { + case QuantificationMatcher(IsTyped(caller, _: FunctionType), args) => Some((caller, args, path)) + case _ => None + } + + override def rec(e: Expr, path: Seq[Expr]): Expr = e match { + case l : Lambda => l + case _ => super.rec(e, path) + } + }.traverse(cleanBody) + + for ((caller, appArgs, paths) <- calls) { + val path = andJoin(paths.filter(expr => (variablesOf(expr) & quants).isEmpty)) + if (e(path) == BooleanLiteral(true)) e(caller) match { + case _: PartialLambda => // OK + case l: Lambda => + val nl @ Lambda(args, body) = gctx.lambdas.getOrElse(l, l) + val lambdaQuantified = (appArgs zip args).collect { + case (Variable(id), vd) if quants(id) => vd.id + }.toSet + + if (lambdaQuantified.nonEmpty) { + checkForall(lambdaQuantified, body) match { + case lambdaStatus: ForallInvalid => + throw QuantificationError("Invalid forall: " + lambdaStatus.getMessage) + case _ => // do nothing + } + + val axiom = Equals(Application(nl, args.map(_.toVariable)), nl.body) + if (evalForall(args.map(_.id).toSet, axiom, check = false) == BooleanLiteral(false)) { + throw QuantificationError("Unaxiomatic lambda " + l) + } + } + case f => + throw EvalError("Cannot apply non-lambda function " + f.asString) + } + } + } + } + + res + + // `res == false` means the quantification is valid since there effectivelly must + // exist an input for which the proposition doesn't hold + case res => res + } + } + protected def e(expr: Expr)(implicit rctx: RC, gctx: GC): Expr = expr match { case Variable(id) => rctx.mappings.get(id) match { @@ -101,7 +260,8 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int mapping.find { case (pargs, res) => (args zip pargs).forall(p => e(Equals(p._1, p._2)) == BooleanLiteral(true)) }.map(_._2).orElse(dflt).getOrElse { - throw EvalError("Cannot apply partial lambda outside of domain") + throw EvalError("Cannot apply partial lambda outside of domain : " + + args.map(e(_).asString(ctx)).mkString("(", ", ", ")")) } case f => throw EvalError("Cannot apply non-lambda function " + f.asString) @@ -516,101 +676,19 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int FiniteSet(els.map(e), base) case l @ Lambda(_, _) => - val (nl, structSubst) = normalizeStructure(l) + val (nl, structSubst) = normalizeStructure(matchToIfThenElse(l)) val mapping = variablesOf(l).map(id => structSubst(id) -> e(Variable(id))).toMap - replaceFromIDs(mapping, nl) + val newLambda = replaceFromIDs(mapping, nl).asInstanceOf[Lambda] + if (!gctx.lambdas.isDefinedAt(newLambda)) { + gctx.lambdas += (newLambda -> nl.asInstanceOf[Lambda]) + } + newLambda case PartialLambda(mapping, dflt, tpe) => PartialLambda(mapping.map(p => p._1.map(e) -> e(p._2)), dflt.map(e), tpe) - case f @ Forall(fargs, body @ TopLevelAnds(conjuncts)) => - val henkinModel: HenkinModel = gctx.model match { - case hm: HenkinModel => hm - case _ => throw EvalError("Can't evaluate foralls without henkin model") - } - - e(andJoin(for (conj <- conjuncts) yield { - val vars = variablesOf(conj) - val args = fargs.map(_.id).filter(vars) - val quantified = args.toSet - - val matcherQuorums = extractQuorums(conj, quantified) - - val instantiations = matcherQuorums.flatMap { quorum => - var mappings: Seq[(Identifier, Int, Int)] = Seq.empty - var constraints: Seq[(Expr, Int, Int)] = Seq.empty - - for (((expr, args), qidx) <- quorum.zipWithIndex) { - val (qmappings, qconstraints) = args.zipWithIndex.partition { - case (Variable(id),aidx) => quantified(id) - case _ => false - } - - mappings ++= qmappings.map(p => (p._1.asInstanceOf[Variable].id, qidx, p._2)) - constraints ++= qconstraints.map(p => (p._1, qidx, p._2)) - } - - var equalities: Seq[((Int, Int), (Int, Int))] = Seq.empty - val mapping = for ((id, es) <- mappings.groupBy(_._1)) yield { - val base :: others = es.toList.map(p => (p._2, p._3)) - equalities ++= others.map(p => base -> p) - (id -> base) - } - - val argSets = quorum.foldLeft[List[Seq[Seq[Expr]]]](List(Seq.empty)) { - case (acc, (expr, _)) => acc.flatMap(s => henkinModel.domain(expr).map(d => s :+ d)) - } - - argSets.map { args => - val argMap: Map[(Int, Int), Expr] = args.zipWithIndex.flatMap { - case (a, qidx) => a.zipWithIndex.map { case (e, aidx) => (qidx, aidx) -> e } - }.toMap - - val map = mapping.map { case (id, key) => id -> argMap(key) } - val enabler = andJoin(constraints.map { - case (e, qidx, aidx) => Equals(e, argMap(qidx -> aidx)) - } ++ equalities.map { - case (k1, k2) => Equals(argMap(k1), argMap(k2)) - }) - - (enabler, map) - } - } - - e(andJoin(instantiations.map { case (enabler, mapping) => - e(Implies(enabler, conj))(rctx.withNewVars(mapping), gctx) - })) - })) match { - case res @ BooleanLiteral(true) => - val quantified = fargs.map(_.id).toSet - val status = checkForall(quantified, body) - if (!status.isValid) { - gctx.warnings :+= "Invalid forall: " + status - } else { - for ((caller, appArgs) <- firstOrderAppsOf(body)) e(caller) match { - case _: PartialLambda => // OK - case Lambda(args, body) => - val lambdaQuantified = (appArgs zip args).collect { - case (Variable(id), vd) if quantified(id) => vd.id - }.toSet - - if (lambdaQuantified.nonEmpty) { - val lambdaStatus = checkForall(lambdaQuantified, body) - if (!lambdaStatus.isValid) { - gctx.warnings :+= "Invalid forall: " + lambdaStatus - } - } - case f => - throw EvalError("Cannot apply non-lambda function " + f.asString) - } - } - - res - - // `res == false` means the quantification is valid since there effectivelly must - // exist an input for which the proposition doesn't hold - case res => res - } + case Forall(fargs, body) => + evalForall(fargs.map(_.id).toSet, body) case ArrayLength(a) => val FiniteArray(_, _, IntLiteral(length)) = e(a) diff --git a/src/main/scala/leon/evaluators/TracingEvaluator.scala b/src/main/scala/leon/evaluators/TracingEvaluator.scala index ec977763f3da3f583d225cbdde8dd98aa33f312b..82a5bb04536ba240f8f95589a7a9ad87dbd6e9f6 100644 --- a/src/main/scala/leon/evaluators/TracingEvaluator.scala +++ b/src/main/scala/leon/evaluators/TracingEvaluator.scala @@ -15,9 +15,10 @@ class TracingEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int = 1000) ex def initRC(mappings: Map[Identifier, Expr]) = TracingRecContext(mappings, 2) - def initGC(model: solvers.Model) = new TracingGlobalContext(Nil, model) + def initGC(model: solvers.Model, check: Boolean) = new TracingGlobalContext(Nil, model, check) - class TracingGlobalContext(var values: List[(Tree, Expr)], model: solvers.Model) extends GlobalContext(model) + class TracingGlobalContext(var values: List[(Tree, Expr)], model: solvers.Model, check: Boolean) + extends GlobalContext(model, check) case class TracingRecContext(mappings: Map[Identifier, Expr], tracingFrames: Int) extends RecContext { def newVars(news: Map[Identifier, Expr]) = copy(mappings = news) diff --git a/src/main/scala/leon/purescala/ExprOps.scala b/src/main/scala/leon/purescala/ExprOps.scala index 2dc33e801c4a26728ccd0551be046c2b2c6b2ed5..d15a761355794bb3f58b2d7c0f64bdf1903c2433 100644 --- a/src/main/scala/leon/purescala/ExprOps.scala +++ b/src/main/scala/leon/purescala/ExprOps.scala @@ -1139,9 +1139,8 @@ object ExprOps { case tp: TypeParameter => GenericValue(tp, 0) - case FunctionType(from, to) => - val args = from.map(tpe => ValDef(FreshIdentifier("x", tpe, true))) - Lambda(args, simplestValue(to)) + case ft @ FunctionType(from, to) => + PartialLambda(Seq.empty, Some(simplestValue(to)), ft) case _ => throw LeonFatalError("I can't choose simplest value for type " + tpe) } diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala index 8159f1e9e5589acc69bdfe6267e1952c6c015793..755c540d1f3b23c505488d49ecd945fa324d0080 100644 --- a/src/main/scala/leon/purescala/PrettyPrinter.scala +++ b/src/main/scala/leon/purescala/PrettyPrinter.scala @@ -259,7 +259,7 @@ class PrettyPrinter(opts: PrinterOptions, } if (dflt.isDefined) { - p" ${dflt.get}" + p" getOrElse ${dflt.get}" } } diff --git a/src/main/scala/leon/purescala/Quantification.scala b/src/main/scala/leon/purescala/Quantification.scala index bf88450fd54063338d63310cbec7de7ddf9db76b..0889439714161d437cd936fd54734dc2c9e2b657 100644 --- a/src/main/scala/leon/purescala/Quantification.scala +++ b/src/main/scala/leon/purescala/Quantification.scala @@ -6,6 +6,7 @@ package purescala import Common._ import Definitions._ import Expressions._ +import Constructors._ import Extractors._ import ExprOps._ import Types._ @@ -22,7 +23,7 @@ object Quantification { ): Seq[Set[A]] = { def expand(m: A): Set[A] = Set(m) ++ margs(m).flatMap(expand) val expandedMap: Map[A, Set[A]] = matchers.map(m => m -> expand(m)).toMap - val reverseMap : Map[A, Set[A]] = expandedMap + val reverseMap : Map[A, Set[A]] = expandedMap.toSeq .flatMap(p => p._2.map(m => m -> p._1)) // flatten to reversed pairs .groupBy(_._1).mapValues(_.map(_._2).toSet) // rebuild map from pair set @@ -48,7 +49,7 @@ object Quantification { res.filter(ms => ms.forall(m => reverseMap(m) subsetOf ms)) } - def extractQuorums(expr: Expr, quantified: Set[Identifier]): Seq[Set[(Expr, Seq[Expr])]] = { + def extractQuorums(expr: Expr, quantified: Set[Identifier]): Seq[(Set[(Expr, Expr, Seq[Expr])], Set[(Expr, Expr, Seq[Expr])])] = { object QMatcher { def unapply(e: Expr): Option[(Expr, Seq[Expr])] = e match { case QuantificationMatcher(expr, args) => @@ -61,18 +62,20 @@ object Quantification { } } - extractQuorums(collect { - case QMatcher(e, a) => Set(e -> a) - case _ => Set.empty[(Expr, Seq[Expr])] - } (expr), quantified, - (p: (Expr, Seq[Expr])) => p._2.collect { case QMatcher(e, a) => e -> a }.toSet, - (p: (Expr, Seq[Expr])) => p._2.collect { case Variable(id) if quantified(id) => id }.toSet) + val allMatchers = CollectorWithPaths { case QMatcher(expr, args) => expr -> args }.traverse(expr) + val matchers = allMatchers.map { case ((caller, args), path) => (path, caller, args) }.toSet + + val quorums = extractQuorums(matchers, quantified, + (p: (Expr, Expr, Seq[Expr])) => p._3.collect { case QMatcher(e, a) => (p._1, e, a) }.toSet, + (p: (Expr, Expr, Seq[Expr])) => p._3.collect { case Variable(id) if quantified(id) => id }.toSet) + + quorums.map(quorum => quorum -> matchers.filter(m => !quorum(m))) } def extractModel( asMap: Map[Identifier, Expr], funDomains: Map[Identifier, Set[Seq[Expr]]], - tpeDomains: Map[TypeTree, Set[Seq[Expr]]], + typeDomains: Map[TypeTree, Set[Seq[Expr]]], evaluator: Evaluator ): Map[Identifier, Expr] = asMap.map { case (id, expr) => id -> (funDomains.get(id) match { @@ -84,12 +87,12 @@ object Quantification { case None => postMap { case p @ PartialLambda(mapping, dflt, tpe) => - Some(PartialLambda(tpeDomains.get(tpe) match { + Some(PartialLambda(typeDomains.get(tpe) match { case Some(domain) => domain.toSeq.map { es => val optEv = evaluator.eval(Application(p, es)).result es -> optEv.getOrElse(scala.sys.error("Unexpectedly failed to evaluate " + Application(p, es))) } - case _ => scala.sys.error(s"Can't extract $p without domain") + case _ => Seq.empty }, None, tpe)) case _ => None } (expr) @@ -97,28 +100,24 @@ object Quantification { } object HenkinDomains { - def empty = new HenkinDomains(Map.empty) - def apply(domains: Map[TypeTree, Set[Seq[Expr]]]) = new HenkinDomains(domains) + def empty = new HenkinDomains(Map.empty, Map.empty) } - class HenkinDomains (val domains: Map[TypeTree, Set[Seq[Expr]]]) { - def get(e: Expr): Set[Seq[Expr]] = e match { - case PartialLambda(_, Some(dflt), _) => scala.sys.error("No domain for non-partial lambdas") - case PartialLambda(mapping, _, _) => mapping.map(_._1).toSet - case _ => domains.get(e.getType) match { - case Some(domain) => domain - case None => scala.sys.error("Undefined Henkin domain for " + e) + class HenkinDomains (val lambdas: Map[Lambda, Set[Seq[Expr]]], val tpes: Map[TypeTree, Set[Seq[Expr]]]) { + def get(e: Expr): Set[Seq[Expr]] = { + val specialized: Set[Seq[Expr]] = e match { + case PartialLambda(_, Some(dflt), _) => scala.sys.error("No domain for non-partial lambdas") + case PartialLambda(mapping, _, _) => mapping.map(_._1).toSet + case l: Lambda => lambdas.getOrElse(l, Set.empty) + case _ => Set.empty } + specialized ++ tpes.getOrElse(e.getType, Set.empty) } - - override def toString = domains.map { case (tpe, argSet) => - tpe + ": " + argSet.map(_.mkString("(", ",", ")")).mkString(", ") - }.mkString("domain={\n ", "\n ", "}") } object QuantificationMatcher { private def flatApplication(expr: Expr): Option[(Expr, Seq[Expr])] = expr match { - case Application(fi: FunctionInvocation, _) => None + case Application(fi: FunctionInvocation, args) => Some((fi, args)) case Application(caller: Application, args) => flatApplication(caller) match { case Some((c, prevArgs)) => Some((c, prevArgs ++ args)) case None => None @@ -162,16 +161,17 @@ object Quantification { def isValid = true } - sealed abstract class ForallInvalid extends ForallStatus { + sealed abstract class ForallInvalid(msg: String) extends ForallStatus { def isValid = false + def getMessage: String = msg } - case object NoMatchers extends ForallInvalid - case class ComplexArgument(expr: Expr) extends ForallInvalid - case class NonBijectiveMapping(expr: Expr) extends ForallInvalid - case class InvalidOperation(expr: Expr) extends ForallInvalid + case class NoMatchers(expr: String) extends ForallInvalid("No matchers available for E-Matching in " + expr) + case class ComplexArgument(expr: String) extends ForallInvalid("Unhandled E-Matching pattern in " + expr) + case class NonBijectiveMapping(expr: String) extends ForallInvalid("Non-bijective mapping for quantifiers in " + expr) + case class InvalidOperation(expr: String) extends ForallInvalid("Invalid operation on quantifiers in " + expr) - def checkForall(quantified: Set[Identifier], body: Expr): ForallStatus = { + def checkForall(quantified: Set[Identifier], body: Expr)(implicit ctx: LeonContext): ForallStatus = { val TopLevelAnds(conjuncts) = body for (conjunct <- conjuncts) { val matchers = collect[(Expr, Seq[Expr])] { @@ -179,7 +179,7 @@ object Quantification { case _ => Set.empty } (conjunct) - if (matchers.isEmpty) return NoMatchers + if (matchers.isEmpty) return NoMatchers(conjunct.asString) val complexArgs = matchers.flatMap { case (_, args) => args.flatMap(arg => arg match { @@ -190,7 +190,7 @@ object Quantification { }) } - if (complexArgs.nonEmpty) return ComplexArgument(complexArgs.head) + if (complexArgs.nonEmpty) return ComplexArgument(complexArgs.head.asString) val matcherToQuants = matchers.foldLeft(Map.empty[Expr, Set[Identifier]]) { case (acc, (m, args)) => acc + (m -> (acc.getOrElse(m, Set.empty) ++ args.flatMap { @@ -200,7 +200,7 @@ object Quantification { } val bijectiveMappings = matcherToQuants.filter(_._2.nonEmpty).groupBy(_._2) - if (bijectiveMappings.size > 1) return NonBijectiveMapping(bijectiveMappings.head._2.head._1) + if (bijectiveMappings.size > 1) return NonBijectiveMapping(bijectiveMappings.head._2.head._1.asString) val matcherSet = matcherToQuants.filter(_._2.nonEmpty).keys.toSet @@ -223,7 +223,7 @@ object Quantification { case Operator(es, _) => val matcherArgs = matcherSet & es.toSet if (q.nonEmpty && !(q.size == 1 && matcherArgs.isEmpty && m.getType == BooleanType)) - return InvalidOperation(m) + return InvalidOperation(m.asString) else Set.empty case Variable(id) if quantified(id) => Set(id) case _ => q diff --git a/src/main/scala/leon/repair/RepairTrackingEvaluator.scala b/src/main/scala/leon/repair/RepairTrackingEvaluator.scala index 664b9e3b26f0229cfb0ec12cf52bef7a53d6aa6a..8fee704d295e153016d256c18abcd991ad9788b6 100644 --- a/src/main/scala/leon/repair/RepairTrackingEvaluator.scala +++ b/src/main/scala/leon/repair/RepairTrackingEvaluator.scala @@ -22,7 +22,7 @@ class RepairTrackingEvaluator(ctx: LeonContext, prog: Program) extends Recursive type GC = GlobalContext def initRC(mappings: Map[Identifier, Expr]) = CollectingRecContext(mappings, None) - def initGC(model: leon.solvers.Model) = new GlobalContext(model) + def initGC(model: leon.solvers.Model, check: Boolean) = new GlobalContext(model, check) type FI = (FunDef, Seq[Expr]) diff --git a/src/main/scala/leon/solvers/QuantificationSolver.scala b/src/main/scala/leon/solvers/QuantificationSolver.scala index dc3e8584fd74578ac2173e1819c059da9f0ec99b..fa11ab6613bd65b196cce87ee062c3c56f0b95f9 100644 --- a/src/main/scala/leon/solvers/QuantificationSolver.scala +++ b/src/main/scala/leon/solvers/QuantificationSolver.scala @@ -4,14 +4,14 @@ package solvers import purescala.Common._ import purescala.Expressions._ import purescala.Quantification._ +import purescala.Definitions._ import purescala.Types._ -class HenkinModel(mapping: Map[Identifier, Expr], doms: HenkinDomains) +class HenkinModel(mapping: Map[Identifier, Expr], val doms: HenkinDomains) extends Model(mapping) with AbstractModel[HenkinModel] { override def newBuilder = new HenkinModelBuilder(doms) - def domains: Map[TypeTree, Set[Seq[Expr]]] = doms.domains def domain(expr: Expr) = doms.get(expr) } @@ -26,5 +26,10 @@ class HenkinModelBuilder(domains: HenkinDomains) } trait QuantificationSolver { + val program: Program def getModel: HenkinModel + + protected lazy val requireQuantification = program.definedFunctions.exists { fd => + purescala.ExprOps.exists { case _: Forall => true case _ => false } (fd.fullBody) + } } diff --git a/src/main/scala/leon/solvers/Solver.scala b/src/main/scala/leon/solvers/Solver.scala index 10902ebbf49a409d19d584b2c57c799fc5487d0e..81ed53e4bebff420aea42ff09895d3b1e055da79 100644 --- a/src/main/scala/leon/solvers/Solver.scala +++ b/src/main/scala/leon/solvers/Solver.scala @@ -8,6 +8,82 @@ import purescala.Expressions._ import purescala.Common.Tree import verification.VC +trait AbstractModel[+This <: Model with AbstractModel[This]] + extends scala.collection.IterableLike[(Identifier, Expr), This] { + + protected val mapping: Map[Identifier, Expr] + + def set(allVars: Iterable[Identifier]): This = { + val builder = newBuilder + builder ++= allVars.map(id => id -> mapping.getOrElse(id, simplestValue(id.getType))) + builder.result + } + + def ++(mapping: Map[Identifier, Expr]): This = { + val builder = newBuilder + builder ++= this.mapping ++ mapping + builder.result + } + + def filter(allVars: Iterable[Identifier]): This = { + val builder = newBuilder + for (p <- mapping.filterKeys(allVars.toSet)) { + builder += p + } + builder.result + } + + def iterator = mapping.iterator + def seq = mapping.seq + + def asString(ctx: LeonContext): String = { + val strings = toSeq.sortBy(_._1.name).map { + case (id, v) => (id.asString(ctx), purescala.PrettyPrinter(v)) + } + + if (strings.nonEmpty) { + val max = strings.map(_._1.length).max + + strings.map { case (id, v) => ("%-"+max+"s -> %s").format(id, v) }.mkString("\n") + } else { + "(Empty model)" + } + } +} + +trait AbstractModelBuilder[+This <: Model with AbstractModel[This]] + extends scala.collection.mutable.Builder[(Identifier, Expr), This] { + + import scala.collection.mutable.MapBuilder + protected val mapBuilder = new MapBuilder[Identifier, Expr, Map[Identifier, Expr]](Map.empty) + + def +=(elem: (Identifier, Expr)): this.type = { + mapBuilder += elem + this + } + + def clear(): Unit = mapBuilder.clear +} + +class Model(protected val mapping: Map[Identifier, Expr]) + extends AbstractModel[Model] + with (Identifier => Expr) { + + def newBuilder = new ModelBuilder + def isDefinedAt(id: Identifier): Boolean = mapping.isDefinedAt(id) + def get(id: Identifier): Option[Expr] = mapping.get(id) + def getOrElse[E >: Expr](id: Identifier, e: E): E = get(id).getOrElse(e) + def apply(id: Identifier): Expr = get(id).getOrElse { throw new IllegalArgumentException } +} + +object Model { + def empty = new Model(Map.empty) +} + +class ModelBuilder extends AbstractModelBuilder[Model] { + def result = new Model(mapBuilder.result) +} + trait Solver extends Interruptible { def name: String val context: LeonContext diff --git a/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala b/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala index 0ec568cc18903e4b6b5038a678ce1b44aabc8de4..f7d395feb913244ed24e5aa0379c1e2e8480610c 100644 --- a/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala +++ b/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala @@ -13,7 +13,7 @@ import purescala.ExprOps._ import purescala.Types._ import utils._ -import z3.FairZ3Component.{optFeelingLucky, optUseCodeGen, optAssumePre} +import z3.FairZ3Component.{optFeelingLucky, optUseCodeGen, optAssumePre, optNoChecks} import templates._ import evaluators._ @@ -26,6 +26,7 @@ class UnrollingSolver(val context: LeonContext, val program: Program, underlying val feelingLucky = context.findOptionOrDefault(optFeelingLucky) val useCodeGen = context.findOptionOrDefault(optUseCodeGen) val assumePreHolds = context.findOptionOrDefault(optAssumePre) + val disableChecks = context.findOptionOrDefault(optNoChecks) protected var lastCheckResult : (Boolean, Option[Boolean], Option[HenkinModel]) = (false, None, None) @@ -112,6 +113,7 @@ class UnrollingSolver(val context: LeonContext, val program: Program, underlying def extract(b: Expr, m: Matcher[Expr]): Set[Seq[Expr]] = { val QuantificationTypeMatcher(fromTypes, _) = m.tpe val optEnabler = evaluator.eval(b, model).result + if (optEnabler == Some(BooleanLiteral(true))) { val optArgs = m.args.map(arg => evaluator.eval(Matcher.argValue(arg), model).result) if (optArgs.forall(_.isDefined)) { @@ -124,60 +126,101 @@ class UnrollingSolver(val context: LeonContext, val program: Program, underlying } } - val funDomains = allVars.flatMap(id => id.getType match { - case ft @ FunctionType(fromTypes, _) => - Some(id -> templateGenerator.manager.instantiations(Variable(id), ft).flatMap { - case (b, m) => extract(b, m) - }) - case _ => None - }).toMap.mapValues(_.toSet) + val (typeInsts, partialInsts, lambdaInsts) = templateGenerator.manager.instantiations - val typeGrouped = templateGenerator.manager.instantiations.groupBy(_._2.tpe) - val typeDomains = typeGrouped.mapValues(_.flatMap { case (b, m) => extract(b, m) }.toSet) + val typeDomains: Map[TypeTree, Set[Seq[Expr]]] = typeInsts.map { + case (tpe, domain) => tpe -> domain.flatMap { case (b, m) => extract(b, m) }.toSet + } - val asDMap = purescala.Quantification.extractModel(model.toMap, funDomains, typeDomains, evaluator) - val domains = new HenkinDomains(typeDomains) - val hmodel = new HenkinModel(asDMap, domains) + val funDomains: Map[Identifier, Set[Seq[Expr]]] = partialInsts.map { + case (Variable(id), domain) => id -> domain.flatMap { case (b, m) => extract(b, m) }.toSet + } - isValidModel(hmodel) + val lambdaDomains: Map[Lambda, Set[Seq[Expr]]] = lambdaInsts.map { + case (l, domain) => l -> domain.flatMap { case (b, m) => extract(b, m) }.toSet + } - hmodel + val asDMap = purescala.Quantification.extractModel(model.toMap, funDomains, typeDomains, evaluator) + val domains = new HenkinDomains(lambdaDomains, typeDomains) + new HenkinModel(asDMap, domains) } def foundAnswer(res: Option[Boolean], model: Option[HenkinModel] = None) = { lastCheckResult = (true, res, model) } - def isValidModel(model: HenkinModel, silenceErrors: Boolean = false): Boolean = { - import EvaluationResults._ + def validatedModel(silenceErrors: Boolean = false): (Boolean, HenkinModel) = { + val lastModel = solver.getModel + val clauses = templateGenerator.manager.checkClauses + val optModel = if (clauses.isEmpty) Some(lastModel) else { + solver.push() + for (clause <- clauses) { + solver.assertCnstr(clause) + } - val expr = andJoin(constraints.toSeq) - val fullModel = model fill freeVars.toSet + reporter.debug(" - Verifying model transitivity") + val solverModel = solver.check match { + case Some(true) => + Some(solver.getModel) - evaluator.eval(expr, fullModel) match { - case Successful(BooleanLiteral(true)) => - reporter.debug("- Model validated.") - true + case Some(false) => + val msg = "- Transitivity independence not guaranteed for model" + if (silenceErrors) { + reporter.debug(msg) + } else { + reporter.warning(msg) + } + None - case Successful(BooleanLiteral(false)) => - reporter.debug("- Invalid model.") - false + case None => + val msg = "- Unknown for transitivity independence!?" + if (silenceErrors) { + reporter.debug(msg) + } else { + reporter.warning(msg) + } + None + } - case Successful(e) => - reporter.warning("- Model leads unexpected result: "+e) - false + solver.pop() + solverModel + } - case RuntimeError(msg) => - reporter.debug("- Model leads to runtime error.") - false + optModel match { + case None => + (false, extractModel(lastModel)) - case EvaluatorError(msg) => - if (silenceErrors) { - reporter.debug("- Model leads to evaluator error: " + msg) - } else { - reporter.warning("- Model leads to evaluator error: " + msg) - } - false + case Some(m) => + val model = extractModel(m) + + val expr = andJoin(constraints.toSeq) + val fullModel = model set freeVars.toSet + + (evaluator.check(expr, fullModel) match { + case EvaluationResults.CheckSuccess => + reporter.debug("- Model validated.") + true + + case EvaluationResults.CheckValidityFailure => + reporter.debug("- Invalid model.") + false + + case EvaluationResults.CheckRuntimeFailure(msg) => + if (silenceErrors) { + reporter.debug("- Model leads to evaluation error: " + msg) + } else { + reporter.warning("- Model leads to evaluation error: " + msg) + } + false + + case EvaluationResults.CheckQuantificationFailure(msg) => + if (silenceErrors) { + reporter.debug("- Model leads to quantification error: " + msg) + } else { + reporter.warning("- Model leads to quantification error: " + msg) + } + false + }, fullModel) } } @@ -203,9 +246,20 @@ class UnrollingSolver(val context: LeonContext, val program: Program, underlying foundAnswer(None) case Some(true) => // SAT - val model = extractModel(solver.getModel) + val (valid, model) = if (!this.disableChecks && requireQuantification) { + validatedModel(silenceErrors = false) + } else { + true -> extractModel(solver.getModel) + } + solver.pop() - foundAnswer(Some(true), Some(model)) + if (valid) { + foundAnswer(Some(true), Some(model)) + } else { + reporter.error("Something went wrong. The model should have been valid, yet we got this : ") + reporter.error(model.asString(context)) + foundAnswer(None, Some(model)) + } case Some(false) if !unrollingBank.canUnroll => solver.pop() @@ -234,12 +288,9 @@ class UnrollingSolver(val context: LeonContext, val program: Program, underlying case Some(true) => if (feelingLucky && !interrupted) { - val model = extractModel(solver.getModel) - // we might have been lucky :D - if (isValidModel(model, silenceErrors = true)) { - foundAnswer(Some(true), Some(model)) - } + val (valid, model) = validatedModel(silenceErrors = true) + if (valid) foundAnswer(Some(true), Some(model)) } case None => diff --git a/src/main/scala/leon/solvers/templates/QuantificationManager.scala b/src/main/scala/leon/solvers/templates/QuantificationManager.scala index e5908ee1f9ebe139692b4bc0176cb46b7400f2c4..f2d1c1d6f47f52b266faf0b39173977874ea5179 100644 --- a/src/main/scala/leon/solvers/templates/QuantificationManager.scala +++ b/src/main/scala/leon/solvers/templates/QuantificationManager.scala @@ -11,10 +11,11 @@ import purescala.Constructors._ import purescala.Expressions._ import purescala.ExprOps._ import purescala.Types._ +import purescala.Quantification.{QuantificationTypeMatcher => QTM} import Instantiation._ -import scala.collection.mutable.{Map => MutableMap, Set => MutableSet} +import scala.collection.mutable.{Map => MutableMap, Set => MutableSet, Stack => MutableStack, Queue} object Matcher { def argValue[T](arg: Either[T, Matcher[T]]): T = arg match { @@ -24,18 +25,24 @@ object Matcher { } case class Matcher[T](caller: T, tpe: TypeTree, args: Seq[Either[T, Matcher[T]]], encoded: T) { - override def toString = "M(" + caller + " : " + tpe + ", " + args.map(Matcher.argValue).mkString("(",",",")") + ")" + override def toString = caller + args.map { + case Right(m) => m.toString + case Left(v) => v.toString + }.mkString("(",",",")") - def substitute(substituter: T => T): Matcher[T] = copy( + def substitute(substituter: T => T, matcherSubst: Map[T, Matcher[T]] = Map.empty): Matcher[T] = copy( caller = substituter(caller), - args = args.map { arg => arg.left.map(substituter).right.map(_.substitute(substituter)) }, + args = args.map { + case Left(v) => matcherSubst.get(v) match { + case Some(m) => Right(m) + case None => Left(substituter(v)) + } + case Right(m) => Right(m.substitute(substituter, matcherSubst)) + }, encoded = substituter(encoded) ) } -case class IllegalQuantificationException(expr: Expr, msg: String) - extends Exception(msg +" @ " + expr) - class QuantificationTemplate[T]( val quantificationManager: QuantificationManager[T], val start: T, @@ -111,20 +118,36 @@ object QuantificationTemplate { class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManager[T](encoder) { private val quantifications = new IncrementalSeq[MatcherQuantification] - private val instantiated = new InstantiationContext - private val fInstantiated = new InstantiationContext { - override def apply(p: (T, Matcher[T])): Boolean = - corresponding(p._2).exists(_._2.args == p._2.args) - } + private val instCtx = new InstantiationContext + + private val handled = new ContextMap + private val ignored = new ContextMap private val known = new IncrementalSet[T] + private val lambdaAxioms = new IncrementalSet[(LambdaTemplate[T], Seq[(Identifier, T)])] + + override protected def incrementals: List[IncrementalState] = + List(quantifications, instCtx, handled, ignored, known, lambdaAxioms) ++ super.incrementals + + private sealed abstract class MatcherKey(val tpe: TypeTree) + private case class CallerKey(caller: T, tt: TypeTree) extends MatcherKey(tt) + private case class LambdaKey(lambda: Lambda, tt: TypeTree) extends MatcherKey(tt) + private case class TypeKey(tt: TypeTree) extends MatcherKey(tt) - private def correspond(qm: Matcher[T], m: Matcher[T]): Boolean = correspond(qm, m.caller, m.tpe) - private def correspond(qm: Matcher[T], caller: T, tpe: TypeTree): Boolean = qm.tpe match { - case _: FunctionType => qm.tpe == tpe && (qm.caller == caller || !known(caller)) - case _ => qm.tpe == tpe + private def matcherKey(caller: T, tpe: TypeTree): MatcherKey = tpe match { + case _: FunctionType if known(caller) => CallerKey(caller, tpe) + case _: FunctionType if byID.isDefinedAt(caller) => LambdaKey(byID(caller).structuralKey, tpe) + case _ => TypeKey(tpe) } + @inline + private def correspond(qm: Matcher[T], m: Matcher[T]): Boolean = + correspond(qm, m.caller, m.tpe) + + @inline + private def correspond(qm: Matcher[T], caller: T, tpe: TypeTree): Boolean = + matcherKey(qm.caller, qm.tpe) == matcherKey(caller, tpe) + private val uniformQuantMap: MutableMap[TypeTree, Seq[T]] = MutableMap.empty private val uniformQuantSet: MutableSet[T] = MutableSet.empty @@ -150,134 +173,166 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage }.toMap } - override protected def incrementals: List[IncrementalState] = - List(quantifications, instantiated, fInstantiated, known) ++ super.incrementals - def assumptions: Seq[T] = quantifications.collect { case q: Quantification => q.currentQ2Var }.toSeq - def instantiations: Seq[(T, Matcher[T])] = (instantiated.all ++ fInstantiated.all).toSeq + def instantiations: (Map[TypeTree, Matchers], Map[T, Matchers], Map[Lambda, Matchers]) = { + var typeInsts: Map[TypeTree, Matchers] = Map.empty + var partialInsts: Map[T, Matchers] = Map.empty + var lambdaInsts: Map[Lambda, Matchers] = Map.empty - def instantiations(caller: T, tpe: TypeTree): Seq[(T, Matcher[T])] = - (instantiated.corresponding(caller, tpe) ++ fInstantiated.corresponding(caller, tpe)).toSeq + val instantiations = handled.instantiations ++ instCtx.map.instantiations + for ((key, matchers) <- instantiations) key match { + case TypeKey(tpe) => typeInsts += tpe -> matchers + case CallerKey(caller, _) => partialInsts += caller -> matchers + case LambdaKey(lambda, _) => lambdaInsts += lambda -> matchers + } + + (typeInsts, partialInsts, lambdaInsts) + } + + def toto: (Map[TypeTree, Matchers], Map[T, Matchers], Map[Lambda, Matchers]) = { + var typeInsts: Map[TypeTree, Matchers] = Map.empty + var partialInsts: Map[T, Matchers] = Map.empty + var lambdaInsts: Map[Lambda, Matchers] = Map.empty + + for ((key, matchers) <- ignored.instantiations) key match { + case TypeKey(tpe) => typeInsts += tpe -> matchers + case CallerKey(caller, _) => partialInsts += caller -> matchers + case LambdaKey(lambda, _) => lambdaInsts += lambda -> matchers + } + + (typeInsts, partialInsts, lambdaInsts) + } override def registerFree(ids: Seq[(Identifier, T)]): Unit = { super.registerFree(ids) known ++= ids.map(_._2) } - private type Context = Set[(T, Matcher[T])] + private def matcherDepth(m: Matcher[T]): Int = 1 + (0 +: m.args.map { + case Right(ma) => matcherDepth(ma) + case _ => 0 + }).max - private class ContextMap( - private val tpeMap: MutableMap[TypeTree, Context] = MutableMap.empty, - private val funMap: MutableMap[T, Context] = MutableMap.empty - ) { - def +=(p: (T, Matcher[T])): Unit = { - tpeMap(p._2.tpe) = tpeMap.getOrElse(p._2.tpe, Set.empty) + p - p match { - case (_, Matcher(caller, tpe: FunctionType, _, _)) if known(caller) => - funMap(caller) = funMap.getOrElse(caller, Set.empty) + p - case _ => - } - } + private def encodeEnablers(es: Set[T]): T = encoder.mkAnd(es.toSeq.sortBy(_.toString) : _*) - def merge(that: ContextMap): this.type = { - for ((tpe, values) <- that.tpeMap) tpeMap(tpe) = tpeMap.getOrElse(tpe, Set.empty) ++ values - for ((caller, values) <- that.funMap) funMap(caller) = funMap.getOrElse(caller, Set.empty) ++ values - this + private type Matchers = Set[(T, Matcher[T])] + + private class Context private(ctx: Map[Matcher[T], Set[Set[T]]]) extends Iterable[(Set[T], Matcher[T])] { + def this() = this(Map.empty) + + def apply(p: (Set[T], Matcher[T])): Boolean = ctx.get(p._2) match { + case None => false + case Some(blockerSets) => blockerSets(p._1) || blockerSets.exists(set => set.subsetOf(p._1)) } - @inline - def get(m: Matcher[T]): Context = get(m.caller, m.tpe) + def +(p: (Set[T], Matcher[T])): Context = if (apply(p)) this else { + val prev = ctx.getOrElse(p._2, Seq.empty) + val newSet = prev.filterNot(set => p._1.subsetOf(set)).toSet + p._1 + new Context(ctx + (p._2 -> newSet)) + } - def get(caller: T, tpe: TypeTree): Context = - funMap.getOrElse(caller, Set.empty) ++ tpeMap.getOrElse(tpe, Set.empty) + def ++(that: Context): Context = that.foldLeft(this)((ctx, p) => ctx + p) - override def clone = new ContextMap(tpeMap.clone, funMap.clone) + def iterator = ctx.toSeq.flatMap { case (m, bss) => bss.map(bs => bs -> m) }.iterator + def toMatchers: Matchers = this.map(p => encodeEnablers(p._1) -> p._2).toSet } - private class InstantiationContext private ( - private var _instantiated : Context, - private var _next : Context, - private var _map : ContextMap, - private var _count : Int + private class ContextMap( + private var tpeMap: MutableMap[TypeTree, Context] = MutableMap.empty, + private var funMap: MutableMap[MatcherKey, Context] = MutableMap.empty ) extends IncrementalState { - - def this() = this(Set.empty, Set.empty, new ContextMap, 0) - def this(ctx: InstantiationContext) = this(ctx._instantiated, Set.empty, ctx._map.clone, ctx._count) - - private val stack = new scala.collection.mutable.Stack[(Context, Context, ContextMap, Int)] + private val stack = new MutableStack[(MutableMap[TypeTree, Context], MutableMap[MatcherKey, Context])] def clear(): Unit = { stack.clear() - _instantiated = Set.empty - _next = Set.empty - _map = new ContextMap - _count = 0 + tpeMap.clear() + funMap.clear() } def reset(): Unit = clear() - def push(): Unit = stack.push((_instantiated, _next, _map.clone, _count)) + def push(): Unit = { + stack.push((tpeMap, funMap)) + tpeMap = tpeMap.clone + funMap = funMap.clone + } def pop(): Unit = { - val (instantiated, next, map, count) = stack.pop() - _instantiated = instantiated - _next = next - _map = map - _count = count + val (ptpeMap, pfunMap) = stack.pop() + tpeMap = ptpeMap + funMap = pfunMap + } + + def +=(p: (Set[T], Matcher[T])): Unit = matcherKey(p._2.caller, p._2.tpe) match { + case TypeKey(tpe) => tpeMap(tpe) = tpeMap.getOrElse(tpe, new Context) + p + case key => funMap(key) = funMap.getOrElse(key, new Context) + p } - def count = _count - def instantiated = _instantiated - def all = _instantiated ++ _next + def merge(that: ContextMap): this.type = { + for ((tpe, values) <- that.tpeMap) tpeMap(tpe) = tpeMap.getOrElse(tpe, new Context) ++ values + for ((caller, values) <- that.funMap) funMap(caller) = funMap.getOrElse(caller, new Context) ++ values + this + } - def corresponding(m: Matcher[T]): Context = _map.get(m) - def corresponding(caller: T, tpe: TypeTree): Context = _map.get(caller, tpe) + def get(caller: T, tpe: TypeTree): Context = + funMap.getOrElse(matcherKey(caller, tpe), new Context) ++ tpeMap.getOrElse(tpe, new Context) - def apply(p: (T, Matcher[T])): Boolean = _instantiated(p) + def get(key: MatcherKey): Context = key match { + case TypeKey(tpe) => tpeMap.getOrElse(tpe, new Context) + case key => funMap.getOrElse(key, new Context) + } - def inc(): Unit = _count += 1 + def instantiations: Map[MatcherKey, Matchers] = + (funMap.toMap ++ tpeMap.map { case (tpe,ms) => TypeKey(tpe) -> ms }).mapValues(_.toMatchers) + } - def +=(p: (T, Matcher[T])): Unit = { - if (!this(p)) _next += p + private class InstantiationContext private ( + private var _instantiated : Context, val map : ContextMap + ) extends IncrementalState { + + private val stack = new MutableStack[Context] + + def this() = this(new Context, new ContextMap) + + def clear(): Unit = { + stack.clear() + map.clear() + _instantiated = new Context } - def ++=(ps: Iterable[(T, Matcher[T])]): Unit = { - for (p <- ps) this += p + def reset(): Unit = clear() + + def push(): Unit = { + stack.push(_instantiated) + map.push() } - def consume: Iterator[(T, Matcher[T])] = { - var n = _next - _next = Set.empty - - new Iterator[(T, Matcher[T])] { - def hasNext = n.nonEmpty - def next = { - val p @ (b,m) = n.head - _instantiated += p - _map += p - n -= p - p - } - } + def pop(): Unit = { + _instantiated = stack.pop() + map.pop() } - def instantiateNext: Instantiation[T] = { - var instantiation = Instantiation.empty[T] - for ((b,m) <- consume) { - println("consuming " + (b -> m)) - for (q <- quantifications) { - instantiation ++= q.instantiate(b, m)(this) - } + def instantiated: Context = _instantiated + def apply(p: (Set[T], Matcher[T])): Boolean = _instantiated(p) + + def corresponding(m: Matcher[T]): Context = map.get(m.caller, m.tpe) + + def instantiate(blockers: Set[T], matcher: Matcher[T])(qs: MatcherQuantification*): Instantiation[T] = { + if (this(blockers -> matcher)) { + Instantiation.empty[T] + } else { + map += (blockers -> matcher) + _instantiated += (blockers -> matcher) + var instantiation = Instantiation.empty[T] + for (q <- qs) instantiation ++= q.instantiate(blockers, matcher) + instantiation } - instantiation } def merge(that: InstantiationContext): this.type = { _instantiated ++= that._instantiated - _next ++= that._next - _map.merge(that._map) - _count = _count max that._count + map.merge(that.map) this } } @@ -293,50 +348,117 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage val applications: Map[T, Set[App[T]]] val lambdas: Seq[LambdaTemplate[T]] - private def mappings(blocker: T, matcher: Matcher[T], instCtx: InstantiationContext): Set[(T, Map[T, T])] = { - - // Build a mapping from applications in the quantified statement to all potential concrete - // applications previously encountered. Also make sure the current `app` is in the mapping - // as other instantiations have been performed previously when the associated applications - // were first encountered. - val matcherMappings: Set[Set[(T, Matcher[T], Matcher[T])]] = matchers - // 1. select an application in the quantified proposition for which the current app can - // be bound when generating the new constraints - .filter(qm => correspond(qm, matcher)) - // 2. build the instantiation mapping associated to the chosen current application binding + private lazy val depth = matchers.map(matcherDepth).max + private lazy val transMatchers: Set[Matcher[T]] = (for { + (b, ms) <- allMatchers.toSeq + m <- ms if !matchers(m) && matcherDepth(m) <= depth + } yield m).toSet + + /* Build a mapping from applications in the quantified statement to all potential concrete + * applications previously encountered. Also make sure the current `app` is in the mapping + * as other instantiations have been performed previously when the associated applications + * were first encountered. + */ + private def mappings(bs: Set[T], matcher: Matcher[T]): Set[Set[(Set[T], Matcher[T], Matcher[T])]] = { + /* 1. select an application in the quantified proposition for which the current app can + * be bound when generating the new constraints + */ + matchers.filter(qm => correspond(qm, matcher)) + + /* 2. build the instantiation mapping associated to the chosen current application binding */ .flatMap { bindingMatcher => - // 2.1. select all potential matches for each quantified application + /* 2.1. select all potential matches for each quantified application */ val matcherToInstances = matchers .map(qm => if (qm == bindingMatcher) { - bindingMatcher -> Set(blocker -> matcher) + bindingMatcher -> Set(bs -> matcher) } else { qm -> instCtx.corresponding(qm) }).toMap - // 2.2. based on the possible bindings for each quantified application, build a set of - // instantiation mappings that can be used to instantiate all necessary constraints - extractMappings(matcherToInstances) + /* 2.2. based on the possible bindings for each quantified application, build a set of + * instantiation mappings that can be used to instantiate all necessary constraints + */ + val allMappings = matcherToInstances.foldLeft[Set[Set[(Set[T], Matcher[T], Matcher[T])]]](Set(Set.empty)) { + case (mappings, (qm, instances)) => Set(instances.toSeq.flatMap { + case (bs, m) => mappings.map(mapping => mapping + ((bs, qm, m))) + } : _*) + } + + /* 2.3. filter out bindings that don't make sense where abstract sub-matchers + * (matchers in arguments of other matchers) are bound to different concrete + * matchers in the argument and quorum positions + */ + allMappings.filter { s => + def expand(ms: Traversable[(Either[T,Matcher[T]], Either[T,Matcher[T]])]): Set[(Matcher[T], Matcher[T])] = ms.flatMap { + case (Right(qm), Right(m)) => Set(qm -> m) ++ expand(qm.args zip m.args) + case _ => Set.empty[(Matcher[T], Matcher[T])] + }.toSet + + expand(s.map(p => Right(p._2) -> Right(p._3))).groupBy(_._1).forall(_._2.size == 1) + } + + allMappings } + } + + private def extractSubst(mapping: Set[(Set[T], Matcher[T], Matcher[T])]): (Set[T], Map[T,Either[T, Matcher[T]]], Boolean) = { + var constraints: Set[T] = Set.empty + var matcherEqs: List[(T, T)] = Nil + var subst: Map[T, Either[T, Matcher[T]]] = Map.empty + + for { + (bs, qm @ Matcher(qcaller, _, qargs, _), m @ Matcher(caller, _, args, _)) <- mapping + _ = constraints ++= bs + _ = matcherEqs :+= qm.encoded -> m.encoded + (qarg, arg) <- (qargs zip args) + } qarg match { + case Left(quant) if subst.isDefinedAt(quant) => + constraints += encoder.mkEquals(quant, Matcher.argValue(arg)) + case Left(quant) if quantified(quant) => + subst += quant -> arg + case Right(qam) => + val argVal = Matcher.argValue(arg) + constraints += encoder.mkEquals(qam.encoded, argVal) + matcherEqs :+= qam.encoded -> argVal + } + + val substituter = encoder.substitute(subst.mapValues(Matcher.argValue)) + val enablers = (if (constraints.isEmpty) Set(trueT) else constraints).map(substituter) + val isStrict = matcherEqs.forall(p => substituter(p._1) == p._2) - for (mapping <- matcherMappings) yield extractSubst(quantified, mapping) + (enablers, subst, isStrict) } - def instantiate(blocker: T, matcher: Matcher[T])(implicit instCtx: InstantiationContext): Instantiation[T] = { + def instantiate(bs: Set[T], matcher: Matcher[T]): Instantiation[T] = { var instantiation = Instantiation.empty[T] - for ((enabler, subst) <- mappings(blocker, matcher, instCtx)) { + for (mapping <- mappings(bs, matcher)) { + val (enablers, subst, isStrict) = extractSubst(mapping) + val enabler = encodeEnablers(enablers) + val baseSubstMap = (condVars ++ exprVars).map { case (id, idT) => idT -> encoder.encodeId(id) } val lambdaSubstMap = lambdas map(lambda => lambda.ids._2 -> encoder.encodeId(lambda.ids._1)) - val substMap = subst ++ baseSubstMap ++ lambdaSubstMap ++ instanceSubst(enabler) + val substMap = subst.mapValues(Matcher.argValue) ++ baseSubstMap ++ lambdaSubstMap ++ instanceSubst(enabler) instantiation ++= Template.instantiate(encoder, QuantificationManager.this, clauses, blockers, applications, Seq.empty, Map.empty[T, Set[Matcher[T]]], lambdas, substMap) + + val msubst = subst.collect { case (c, Right(m)) => c -> m } val substituter = encoder.substitute(substMap) - for ((b, ms) <- allMatchers; m <- ms if !matchers(m)) { - println(m.substitute(substituter)) - instCtx += substituter(b) -> m.substitute(substituter) + + for ((b,ms) <- allMatchers; m <- ms) { + val sb = enablers + substituter(b) + val sm = m.substitute(substituter, matcherSubst = msubst) + + if (matchers(m)) { + handled += sb -> sm + } else if (transMatchers(m) && isStrict) { + instantiation ++= instCtx.instantiate(sb, sm)(quantifications.toSeq : _*) + } else { + ignored += sb -> sm + } } } @@ -374,8 +496,8 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage } } - private val blockerId = FreshIdentifier("blocker", BooleanType, true) - private val blockerCache: MutableMap[T, T] = MutableMap.empty + private lazy val blockerId = FreshIdentifier("blocker", BooleanType, true) + private lazy val blockerCache: MutableMap[T, T] = MutableMap.empty private class Axiom ( val start: T, @@ -396,8 +518,8 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage case Some(b) => b case None => val nb = encoder.encodeId(blockerId) - blockerCache(enabler) = nb - blockerCache(nb) = nb + blockerCache += enabler -> nb + blockerCache += nb -> nb nb } @@ -405,58 +527,6 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage } } - private def extractMappings( - bindings: Map[Matcher[T], Set[(T, Matcher[T])]] - ): Set[Set[(T, Matcher[T], Matcher[T])]] = { - val allMappings = bindings.foldLeft[Set[Set[(T, Matcher[T], Matcher[T])]]](Set(Set.empty)) { - case (mappings, (qm, instances)) => Set(instances.toSeq.flatMap { - case (b, m) => mappings.map(mapping => mapping + ((b, qm, m))) - } : _*) - } - - def subBindings(b: T, sm: Matcher[T], m: Matcher[T]): Set[(T, Matcher[T], Matcher[T])] = { - (for ((sarg, arg) <- sm.args zip m.args) yield { - (sarg, arg) match { - case (Right(sargm), Right(argm)) => Set((b, sargm, argm)) ++ subBindings(b, sargm, argm) - case _ => Set.empty[(T, Matcher[T], Matcher[T])] - } - }).flatten.toSet - } - - allMappings.filter { s => - val withSubs = s ++ s.flatMap { case (b, sm, m) => subBindings(b, sm, m) } - withSubs.groupBy(_._2).forall(_._2.size == 1) - } - - allMappings - } - - private def extractSubst(quantified: Set[T], mapping: Set[(T, Matcher[T], Matcher[T])]): (T, Map[T,T]) = { - var constraints: List[T] = Nil - var subst: Map[T, T] = Map.empty - - for { - (b, Matcher(qcaller, _, qargs, _), Matcher(caller, _, args, _)) <- mapping - _ = constraints :+= b - (qarg, arg) <- (qargs zip args) - argVal = Matcher.argValue(arg) - } qarg match { - case Left(quant) if subst.isDefinedAt(quant) => - constraints :+= encoder.mkEquals(quant, argVal) - case Left(quant) if quantified(quant) => - subst += quant -> argVal - case _ => - constraints :+= encoder.mkEquals(Matcher.argValue(qarg), argVal) - } - - val enabler = - if (constraints.isEmpty) trueT - else if (constraints.size == 1) constraints.head - else encoder.mkAnd(constraints : _*) - - (encoder.substitute(subst)(enabler), subst) - } - private def extractQuorums( quantified: Set[T], matchers: Set[Matcher[T]], @@ -481,8 +551,6 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage (m: Matcher[T]) => m.args.collect { case Left(a) if quantified(a) => a }.toSet) } - private val lambdaAxioms: MutableSet[(LambdaTemplate[T], Seq[(Identifier, T)])] = MutableSet.empty - def instantiateAxiom(template: LambdaTemplate[T], substMap: Map[T, T]): Instantiation[T] = { val quantifiers = template.arguments map { case (id, idT) => id -> substMap(idT) @@ -556,27 +624,21 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage quantifications += axiom - for (instCtx <- List(instantiated, fInstantiated)) { - val pCtx = new InstantiationContext(instCtx) - - for ((b, m) <- pCtx.instantiated) { - instantiation ++= axiom.instantiate(b, m)(pCtx) - } - - for (i <- (1 to instCtx.count)) { - instantiation ++= pCtx.instantiateNext - } - - instCtx.merge(pCtx) + val newCtx = new InstantiationContext() + for ((b,m) <- instCtx.instantiated) { + instantiation ++= newCtx.instantiate(b, m)(axiom) } + instCtx.merge(newCtx) } val quantifierSubst = uniformSubst(quantifiers) val substituter = encoder.substitute(quantifierSubst) - for (m <- matchers) { - instantiation ++= instantiateMatcher(trueT, m.substitute(substituter), fInstantiated) - } + for { + m <- matchers + sm = m.substitute(substituter) + if !instCtx.corresponding(sm).exists(_._2.args == sm.args) + } instantiation ++= instCtx.instantiate(Set(trueT), sm)(quantifications.toSeq : _*) instantiation } @@ -611,19 +673,11 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage quantifications += quantification - for (instCtx <- List(instantiated, fInstantiated)) { - val pCtx = new InstantiationContext(instCtx) - - for ((b, m) <- pCtx.instantiated) { - instantiation ++= quantification.instantiate(b, m)(pCtx) - } - - for (i <- (1 to instCtx.count)) { - instantiation ++= pCtx.instantiateNext - } - - instCtx.merge(pCtx) + val newCtx = new InstantiationContext() + for ((b,m) <- instCtx.instantiated) { + instantiation ++= newCtx.instantiate(b, m)(quantification) } + instCtx.merge(newCtx) quantification.qs._2 } @@ -639,36 +693,74 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage val quantifierSubst = uniformSubst(template.quantifiers) val substituter = encoder.substitute(substMap ++ quantifierSubst) - for ((_, ms) <- template.matchers; m <- ms) { - instantiation ++= instantiateMatcher(trueT, m.substitute(substituter), fInstantiated) - } + for { + (_, ms) <- template.matchers; m <- ms + sm = m.substitute(substituter) + if !instCtx.corresponding(sm).exists(_._2.args == sm.args) + } instantiation ++= instCtx.instantiate(Set(trueT), sm)(quantifications.toSeq : _*) instantiation } - private def instantiateMatcher(blocker: T, matcher: Matcher[T], instCtx: InstantiationContext): Instantiation[T] = { - if (instCtx(blocker -> matcher)) { - Instantiation.empty[T] - } else { - println("instantiating " + (blocker -> matcher)) - var instantiation: Instantiation[T] = Instantiation.empty - - val pCtx = new InstantiationContext(instCtx) - pCtx += blocker -> matcher - pCtx.inc() // pCtx.count == instCtx.count + 1 + def instantiateMatcher(blocker: T, matcher: Matcher[T]): Instantiation[T] = { + instCtx.instantiate(Set(blocker), matcher)(quantifications.toSeq : _*) + } - // we just inc()'ed so we can start at 1 (instCtx.count is guaranteed to have increased) - for (i <- (1 to instCtx.count)) { - instantiation ++= pCtx.instantiateNext + private type SetDef = (T, (Identifier, T), (Identifier, T), Seq[T], T, T, T) + private val setConstructors: MutableMap[TypeTree, SetDef] = MutableMap.empty + + def checkClauses: Seq[T] = { + val clauses = new scala.collection.mutable.ListBuffer[T] + + for ((key, ctx) <- ignored.instantiations) { + val insts = instCtx.map.get(key).toMatchers + + val QTM(argTypes, _) = key.tpe + val tupleType = tupleTypeWrap(argTypes) + + val (guardT, (setPrev, setPrevT), (setNext, setNextT), elems, containsT, emptyT, setT) = + setConstructors.getOrElse(tupleType, { + val guard = FreshIdentifier("guard", BooleanType) + val setPrev = FreshIdentifier("prevSet", SetType(tupleType)) + val setNext = FreshIdentifier("nextSet", SetType(tupleType)) + val elems = argTypes.map(tpe => FreshIdentifier("elem", tpe)) + + val elemExpr = tupleWrap(elems.map(_.toVariable)) + val contextExpr = And( + Implies(Variable(guard), Equals(Variable(setNext), + SetUnion(Variable(setPrev), FiniteSet(Set(elemExpr), tupleType)))), + Implies(Not(Variable(guard)), Equals(Variable(setNext), Variable(setPrev)))) + + val guardP = guard -> encoder.encodeId(guard) + val setPrevP = setPrev -> encoder.encodeId(setPrev) + val setNextP = setNext -> encoder.encodeId(setNext) + val elemsP = elems.map(e => e -> encoder.encodeId(e)) + + val containsT = encoder.encodeExpr(elemsP.toMap + setPrevP)(ElementOfSet(elemExpr, setPrevP._1.toVariable)) + val emptyT = encoder.encodeExpr(Map.empty)(FiniteSet(Set.empty, tupleType)) + val contextT = encoder.encodeExpr(Map(guardP, setPrevP, setNextP) ++ elemsP)(contextExpr) + + val setDef = (guardP._2, setPrevP, setNextP, elemsP.map(_._2), containsT, emptyT, contextT) + setConstructors += key.tpe -> setDef + setDef + }) + + var prev = emptyT + for ((b, m) <- insts.toSeq) { + val next = encoder.encodeId(setNext) + val argsMap = (elems zip m.args).map { case (idT, arg) => idT -> Matcher.argValue(arg) } + val substMap = Map(guardT -> b, setPrevT -> prev, setNextT -> next) ++ argsMap + prev = next + clauses += encoder.substitute(substMap)(setT) } - instantiation ++= instCtx.merge(pCtx).instantiateNext - - instantiation + val setMap = Map(setPrevT -> prev) + for ((b, m) <- ctx.toSeq) { + val substMap = setMap ++ (elems zip m.args).map(p => p._1 -> Matcher.argValue(p._2)) + clauses += encoder.substitute(substMap)(encoder.mkImplies(b, containsT)) + } } - } - def instantiateMatcher(blocker: T, matcher: Matcher[T]): Instantiation[T] = { - instantiateMatcher(blocker, matcher, instantiated) + clauses.toSeq } } diff --git a/src/main/scala/leon/solvers/templates/TemplateGenerator.scala b/src/main/scala/leon/solvers/templates/TemplateGenerator.scala index f0e0d745baf5492337d922e60db5a18b35e787b8..7a3df85ff334077fd05d0382fad5ab3a45f050d2 100644 --- a/src/main/scala/leon/solvers/templates/TemplateGenerator.scala +++ b/src/main/scala/leon/solvers/templates/TemplateGenerator.scala @@ -11,6 +11,7 @@ import purescala.ExprOps._ import purescala.Types._ import purescala.Definitions._ import purescala.Constructors._ +import purescala.Quantification._ class TemplateGenerator[T](val encoder: TemplateEncoder[T], val assumePreHolds: Boolean) { @@ -133,6 +134,45 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], andJoin(rec(invocation, body, args, inlineFirst)) } + private def minimalFlattening(inits: Set[Identifier], conj: Expr): (Set[Identifier], Expr) = { + var mapping: Map[Expr, Expr] = Map.empty + var quantified: Set[Identifier] = inits + var quantifierEqualities: Seq[(Expr, Identifier)] = Seq.empty + + val newConj = postMap { + case expr if mapping.isDefinedAt(expr) => + Some(mapping(expr)) + + case expr @ QuantificationMatcher(c, args) => + val isMatcher = args.exists { case Variable(id) => quantified(id) case _ => false } + val isRelevant = (variablesOf(expr) & quantified).nonEmpty + if (!isMatcher && isRelevant) { + val newArgs = args.map { + case arg @ QuantificationMatcher(_, _) if (variablesOf(arg) & quantified).nonEmpty => + val id = FreshIdentifier("flat", arg.getType) + quantifierEqualities :+= (arg -> id) + quantified += id + Variable(id) + case arg => arg + } + + val newExpr = replace((args zip newArgs).toMap, expr) + mapping += expr -> newExpr + Some(newExpr) + } else { + None + } + + case _ => None + } (conj) + + val flatConj = implies(andJoin(quantifierEqualities.map { + case (arg, id) => Equals(arg, Variable(id)) + }), newConj) + + (quantified, flatConj) + } + def mkClauses(pathVar: Identifier, expr: Expr, substMap: Map[Identifier, T]): (Map[Identifier,T], Map[Identifier,T], Map[Identifier, Seq[Expr]], Seq[LambdaTemplate[T]], Seq[QuantificationTemplate[T]]) = { @@ -294,7 +334,8 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], val conjunctQs = conjuncts.map { conjunct => val vars = variablesOf(conjunct) - val quantifiers = args.map(_.id).filter(vars).toSet + val inits = args.map(_.id).filter(vars).toSet + val (quantifiers, flatConj) = minimalFlattening(inits, conjunct) val idQuantifiers : Seq[Identifier] = quantifiers.toSeq val trQuantifiers : Seq[T] = idQuantifiers.map(encoder.encodeId) @@ -304,7 +345,7 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], val inst: Identifier = FreshIdentifier("inst", BooleanType, true) val guard: Identifier = FreshIdentifier("guard", BooleanType, true) - val clause = Equals(Variable(inst), Implies(Variable(guard), conjunct)) + val clause = Equals(Variable(inst), Implies(Variable(guard), flatConj)) val qs: (Identifier, T) = q -> encoder.encodeId(q) val localSubst: Map[Identifier, T] = substMap ++ condVars ++ exprVars ++ lambdaVars diff --git a/src/main/scala/leon/solvers/z3/FairZ3Component.scala b/src/main/scala/leon/solvers/z3/FairZ3Component.scala index f321d516d66a4d43751519f06e9c65b3a0e4d126..256dcf38fc5dc99f31ce2f6fb30dc12e0caa5268 100644 --- a/src/main/scala/leon/solvers/z3/FairZ3Component.scala +++ b/src/main/scala/leon/solvers/z3/FairZ3Component.scala @@ -13,6 +13,7 @@ trait FairZ3Component extends LeonComponent { val optUseCodeGen = LeonFlagOptionDef("codegen", "Use compiled evaluator instead of interpreter", false) val optUnrollCores = LeonFlagOptionDef("unrollcores", "Use unsat-cores to drive unrolling while remaining fair", false) val optAssumePre = LeonFlagOptionDef("assumepre", "Assume precondition holds (pre && f(x) = body) when unfolding", false) + val optNoChecks = LeonFlagOptionDef("nochecks", "Disable counter-example check in presence of foralls" , false) override val definedOptions: Set[LeonOptionDef[Any]] = Set(optEvalGround, optCheckModels, optFeelingLucky, optUseCodeGen, optUnrollCores, optAssumePre) diff --git a/src/main/scala/leon/solvers/z3/FairZ3Solver.scala b/src/main/scala/leon/solvers/z3/FairZ3Solver.scala index 105bfa7dc036e790cc10ac5f11f7db47b1a72537..a187549c514cc8fc83ec7476493e1d9e96a48763 100644 --- a/src/main/scala/leon/solvers/z3/FairZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/FairZ3Solver.scala @@ -26,7 +26,8 @@ class FairZ3Solver(val context: LeonContext, val program: Program) extends AbstractZ3Solver with Z3ModelReconstruction with FairZ3Component - with EvaluatingSolver { + with EvaluatingSolver + with QuantificationSolver { enclosing => @@ -36,6 +37,9 @@ class FairZ3Solver(val context: LeonContext, val program: Program) val evalGroundApps = context.findOptionOrDefault(optEvalGround) val unrollUnsatCores = context.findOptionOrDefault(optUnrollCores) val assumePreHolds = context.findOptionOrDefault(optAssumePre) + val disableChecks = context.findOptionOrDefault(optNoChecks) + + assert(!checkModels || !disableChecks, "Options \"checkmodels\" and \"nochecks\" are mutually exclusive") protected val errors = new IncrementalBijection[Unit, Boolean]() protected def hasError = errors.getB(()) contains true @@ -87,85 +91,26 @@ class FairZ3Solver(val context: LeonContext, val program: Program) } } - val funDomains = ids.flatMap(id => id.getType match { - case ft @ FunctionType(fromTypes, _) => variables.getB(id.toVariable) match { - case Some(z3ID) => Some(id -> templateGenerator.manager.instantiations(z3ID, ft).flatMap { - case (b, m) => extract(b, m) - }) - case _ => None + val (typeInsts, partialInsts, lambdaInsts) = templateGenerator.manager.instantiations + + val typeDomains: Map[TypeTree, Set[Seq[Expr]]] = typeInsts.map { + case (tpe, domain) => tpe -> domain.flatMap { case (b, m) => extract(b, m) }.toSet + } + + val funDomains: Map[Identifier, Set[Seq[Expr]]] = partialInsts.flatMap { + case (c, domain) => variables.getA(c).collect { + case Variable(id) => id -> domain.flatMap { case (b, m) => extract(b, m) }.toSet } - case _ => None - }).toMap.mapValues(_.toSet) + } - val typeGrouped = templateGenerator.manager.instantiations.groupBy(_._2.tpe) - val typeDomains = typeGrouped.mapValues(_.flatMap { case (b, m) => extract(b, m) }.toSet) + val lambdaDomains: Map[Lambda, Set[Seq[Expr]]] = lambdaInsts.map { + case (l, domain) => l -> domain.flatMap { case (b, m) => extract(b, m) }.toSet + } val asMap = modelToMap(model, ids) val asDMap = purescala.Quantification.extractModel(asMap, funDomains, typeDomains, evaluator) - - val domain = new HenkinDomains(typeDomains) - new HenkinModel(asDMap, domain) - } - - private def validateModel(model: Z3Model, formula: Expr, variables: Set[Identifier], silenceErrors: Boolean) : (Boolean, HenkinModel) = { - if(!interrupted) { - - val functionsModel: Map[Z3FuncDecl, (Seq[(Seq[Z3AST], Z3AST)], Z3AST)] = model.getModelFuncInterpretations.map(i => (i._1, (i._2, i._3))).toMap - val functionsAsMap: Map[Identifier, Expr] = functionsModel.flatMap(p => { - if (functions containsB p._1) { - val tfd = functions.toA(p._1) - if (!tfd.hasImplementation) { - val (cses, default) = p._2 - val ite = cses.foldLeft(fromZ3Formula(model, default, tfd.returnType))((expr, q) => IfExpr( - andJoin( - q._1.zip(tfd.params).map(a12 => Equals(fromZ3Formula(model, a12._1, a12._2.getType), Variable(a12._2.id))) - ), - fromZ3Formula(model, q._2, tfd.returnType), - expr)) - Seq((tfd.id, ite)) - } else Seq() - } else Seq() - }) - - val constantFunctionsAsMap: Map[Identifier, Expr] = model.getModelConstantInterpretations.flatMap(p => { - if(functions containsB p._1) { - val tfd = functions.toA(p._1) - if(!tfd.hasImplementation) { - Seq((tfd.id, fromZ3Formula(model, p._2, tfd.returnType))) - } else Seq() - } else Seq() - }).toMap - - val leonModel = extractModel(model, variables) - val fullModel = leonModel ++ (functionsAsMap ++ constantFunctionsAsMap) - val evalResult = evaluator.eval(formula, fullModel) - - evalResult match { - case EvaluationResults.Successful(BooleanLiteral(true)) => - reporter.debug("- Model validated.") - (true, fullModel) - - case EvaluationResults.Successful(res) => - assert(res == BooleanLiteral(false), "Checking model returned non-boolean") - reporter.debug("- Invalid model.") - (false, fullModel) - - case EvaluationResults.RuntimeError(msg) => - reporter.debug("- Model leads to runtime error.") - (false, fullModel) - - case EvaluationResults.EvaluatorError(msg) => - if (silenceErrors) { - reporter.debug("- Model leads to evaluator error: " + msg) - } else { - reporter.warning("Something went wrong. While evaluating the model, we got this : " + msg) - } - (false, fullModel) - - } - } else { - (false, HenkinModel.empty) - } + val domains = new HenkinDomains(lambdaDomains, typeDomains) + new HenkinModel(asDMap, domains) } implicit val z3Printable = (z3: Z3AST) => new Printable { @@ -300,6 +245,115 @@ class FairZ3Solver(val context: LeonContext, val program: Program) }).toSet } + def validatedModel(silenceErrors: Boolean) : (Boolean, HenkinModel) = { + if (interrupted) { + (false, HenkinModel.empty) + } else { + val lastModel = solver.getModel + val clauses = templateGenerator.manager.checkClauses + val optModel = if (clauses.isEmpty) Some(lastModel) else { + solver.push() + for (clause <- clauses) { + solver.assertCnstr(clause) + } + + reporter.debug(" - Verifying model transitivity") + val timer = context.timers.solvers.z3.check.start() + solver.push() // FIXME: remove when z3 bug is fixed + val res = solver.checkAssumptions((assumptionsAsZ3 ++ unrollingBank.satisfactionAssumptions) :_*) + solver.pop() // FIXME: remove when z3 bug is fixed + timer.stop() + + val solverModel = res match { + case Some(true) => + Some(solver.getModel) + + case Some(false) => + val msg = "- Transitivity independence not guaranteed for model" + if (silenceErrors) { + reporter.debug(msg) + } else { + reporter.warning(msg) + } + None + + case None => + val msg = "- Unknown for transitivity independence!?" + if (silenceErrors) { + reporter.debug(msg) + } else { + reporter.warning(msg) + } + None + } + + solver.pop() + solverModel + } + + val model = optModel getOrElse lastModel + + val functionsModel: Map[Z3FuncDecl, (Seq[(Seq[Z3AST], Z3AST)], Z3AST)] = model.getModelFuncInterpretations.map(i => (i._1, (i._2, i._3))).toMap + val functionsAsMap: Map[Identifier, Expr] = functionsModel.flatMap(p => { + if (functions containsB p._1) { + val tfd = functions.toA(p._1) + if (!tfd.hasImplementation) { + val (cses, default) = p._2 + val ite = cses.foldLeft(fromZ3Formula(model, default, tfd.returnType))((expr, q) => IfExpr( + andJoin( + q._1.zip(tfd.params).map(a12 => Equals(fromZ3Formula(model, a12._1, a12._2.getType), Variable(a12._2.id))) + ), + fromZ3Formula(model, q._2, tfd.returnType), + expr)) + Seq((tfd.id, ite)) + } else Seq() + } else Seq() + }) + + val constantFunctionsAsMap: Map[Identifier, Expr] = model.getModelConstantInterpretations.flatMap(p => { + if(functions containsB p._1) { + val tfd = functions.toA(p._1) + if(!tfd.hasImplementation) { + Seq((tfd.id, fromZ3Formula(model, p._2, tfd.returnType))) + } else Seq() + } else Seq() + }).toMap + + val leonModel = extractModel(model, freeVars.toSet) + val fullModel = leonModel ++ (functionsAsMap ++ constantFunctionsAsMap) + + if (!optModel.isDefined) { + (false, leonModel) + } else { + (evaluator.check(entireFormula, fullModel) match { + case EvaluationResults.CheckSuccess => + reporter.debug("- Model validated.") + true + + case EvaluationResults.CheckValidityFailure => + reporter.debug("- Invalid model.") + false + + case EvaluationResults.CheckRuntimeFailure(msg) => + if (silenceErrors) { + reporter.debug("- Model leads to evaluation error: " + msg) + } else { + reporter.warning("- Model leads to evaluation error: " + msg) + } + false + + case EvaluationResults.CheckQuantificationFailure(msg) => + if (silenceErrors) { + reporter.debug("- Model leads to quantification error: " + msg) + } else { + reporter.warning("- Model leads to quantification error: " + msg) + } + false + }, leonModel) + } + } + } + while(!foundDefinitiveAnswer && !interrupted) { //val blockingSetAsZ3 : Seq[Z3AST] = blockingSet.toSeq.map(toZ3Formula(_).get) @@ -331,27 +385,18 @@ class FairZ3Solver(val context: LeonContext, val program: Program) foundAnswer(None) case Some(true) => // SAT - - val z3model = solver.getModel() - - if (this.checkModels) { - val (isValid, model) = validateModel(z3model, entireFormula, allVars, silenceErrors = false) - - if (isValid) { - foundAnswer(Some(true), model) - } else { - reporter.error("Something went wrong. The model should have been valid, yet we got this : ") - reporter.error(model) - foundAnswer(None, model) - } + val (valid, model) = if (!this.disableChecks && (this.checkModels || requireQuantification)) { + validatedModel(false) } else { - val model = extractModel(z3model, allVars) - - //lazy val modelAsString = model.toList.map(p => p._1 + " -> " + p._2).mkString("\n") - //reporter.debug("- Found a model:") - //reporter.debug(modelAsString) + true -> extractModel(solver.getModel, allVars) + } + if (valid) { foundAnswer(Some(true), model) + } else { + reporter.error("Something went wrong. The model should have been valid, yet we got this : ") + reporter.error(model.asString(context)) + foundAnswer(None, model) } case Some(false) if !unrollingBank.canUnroll => @@ -418,7 +463,7 @@ class FairZ3Solver(val context: LeonContext, val program: Program) //reporter.debug("SAT WITHOUT Blockers") if (this.feelingLucky && !interrupted) { // we might have been lucky :D - val (wereWeLucky, cleanModel) = validateModel(solver.getModel, entireFormula, allVars, silenceErrors = true) + val (wereWeLucky, cleanModel) = validatedModel(true) if(wereWeLucky) { foundAnswer(Some(true), cleanModel) diff --git a/src/test/resources/regression/verification/purescala/invalid/Existentials.scala b/src/test/resources/regression/verification/purescala/invalid/Existentials.scala new file mode 100644 index 0000000000000000000000000000000000000000..19679db9202b325edc0eb8dfff8eeea41bb47d36 --- /dev/null +++ b/src/test/resources/regression/verification/purescala/invalid/Existentials.scala @@ -0,0 +1,16 @@ +import leon.lang._ + +object Existentials { + + def exists[A](p: A => Boolean): Boolean = !forall((x: A) => !p(x)) + + def check1(y: BigInt, p: BigInt => Boolean) : Boolean = { + p(y) == exists((y1:BigInt) => p(y1)) + }.holds + + /* + def check2(y: BigInt, p: BigInt => Boolean) : Boolean = { + p(y) ==> exists((y1:BigInt) => p(y1)) + }.holds + */ +} diff --git a/src/test/resources/regression/verification/purescala/invalid/ForallAssoc.scala b/src/test/resources/regression/verification/purescala/invalid/ForallAssoc.scala new file mode 100644 index 0000000000000000000000000000000000000000..83773b2224fb8790acfc25125143b358c1226f05 --- /dev/null +++ b/src/test/resources/regression/verification/purescala/invalid/ForallAssoc.scala @@ -0,0 +1,17 @@ +import leon.lang._ + +object ForallAssoc { + + /* + def test3(f: (BigInt, BigInt) => BigInt): Boolean = { + require(forall((x: BigInt, y: BigInt, z: BigInt) => f(x, f(y, z)) == f(f(x, y), z))) + f(1, f(2, f(3, f(4, 5)))) == f(f(f(f(1, 2), 3), 4), 4) + }.holds + */ + + def test4(f: (BigInt, BigInt) => BigInt): Boolean = { + require(forall((x: BigInt, y: BigInt, z: BigInt) => f(x, f(y, z)) == f(f(x, y), z))) + f(1, f(2, f(3, 4))) == 0 + }.holds + +} diff --git a/src/test/resources/regression/verification/purescala/valid/Existentials.scala b/src/test/resources/regression/verification/purescala/valid/Existentials.scala new file mode 100644 index 0000000000000000000000000000000000000000..992d58cd0678e98146ad1085a22269671a35421c --- /dev/null +++ b/src/test/resources/regression/verification/purescala/valid/Existentials.scala @@ -0,0 +1,16 @@ +import leon.lang._ + +object Existentials { + + def exists[A](p: A => Boolean): Boolean = !forall((x: A) => !p(x)) + + /* + def check1(y: BigInt, p: BigInt => Boolean) : Boolean = { + p(y) == exists((y1:BigInt) => p(y1)) + }.holds + */ + + def check2(y: BigInt, p: BigInt => Boolean) : Boolean = { + p(y) ==> exists((y1:BigInt) => p(y1)) + }.holds +} diff --git a/src/test/resources/regression/verification/purescala/valid/ForallAssoc.scala b/src/test/resources/regression/verification/purescala/valid/ForallAssoc.scala new file mode 100644 index 0000000000000000000000000000000000000000..ae90e9489678e1f4cc20d90b0cd23767cc317546 --- /dev/null +++ b/src/test/resources/regression/verification/purescala/valid/ForallAssoc.scala @@ -0,0 +1,23 @@ +import leon.lang._ + +object ForallAssoc { + + def ex[A](x1: A, x2: A, x3: A, x4: A, x5: A, f: (A, A) => A) = { + require(forall { + (x: A, y: A, z: A) => f(x, f(y, z)) == f(f(x, y), z) + }) + + f(x1, f(x2, f(x3, f(x4, x5)))) == f(f(x1, f(x2, f(x3, x4))), x5) + }.holds + + def test1(f: (BigInt, BigInt) => BigInt): Boolean = { + require(forall((x: BigInt, y: BigInt, z: BigInt) => f(x, f(y, z)) == f(f(x, y), z))) + f(1, f(2, f(3, 4))) == f(f(f(1, 2), 3), 4) + }.holds + + def test2(f: (BigInt, BigInt) => BigInt): Boolean = { + require(forall((x: BigInt, y: BigInt, z: BigInt) => f(x, f(y, z)) == f(f(x, y), z))) + f(1, f(2, f(3, f(4, 5)))) == f(f(f(f(1, 2), 3), 4), 5) + }.holds + +} diff --git a/src/test/resources/regression/verification/purescala/valid/Predicate.scala b/src/test/resources/regression/verification/purescala/valid/Predicate.scala new file mode 100644 index 0000000000000000000000000000000000000000..c011d4b910c2d3859b0e2cca80e069de5b19788b --- /dev/null +++ b/src/test/resources/regression/verification/purescala/valid/Predicate.scala @@ -0,0 +1,48 @@ +package leon.monads.predicate + +import leon.collection._ +import leon.lang._ +import leon.annotation._ + +object Predicate { + + def exists[A](p: A => Boolean): Boolean = !forall((a: A) => !p(a)) + + // Monadic bind + @inline + def flatMap[A,B](p: A => Boolean, f: A => (B => Boolean)): B => Boolean = { + (b: B) => exists[A](a => p(a) && f(a)(b)) + } + + // All monads are also functors, and they define the map function + @inline + def map[A,B](p: A => Boolean, f: A => B): B => Boolean = { + (b: B) => exists[A](a => p(a) && f(a) == b) + } + + /* + @inline + def >>=[B](f: A => Predicate[B]): Predicate[B] = flatMap(f) + + @inline + def >>[B](that: Predicate[B]) = >>= ( _ => that ) + + @inline + def withFilter(f: A => Boolean): Predicate[A] = { + Predicate { a => p(a) && f(a) } + } + */ + + def equals[A](p: A => Boolean, that: A => Boolean): Boolean = { + forall[A](a => p(a) == that(a)) + } + + def test[A,B,C](p: A => Boolean, f: A => B, g: B => C): Boolean = { + equals(map(map(p, f), g), map(p, (a: A) => g(f(a)))) + }.holds + + def testInt(p: BigInt => Boolean, f: BigInt => BigInt, g: BigInt => BigInt): Boolean = { + equals(map(map(p, f), g), map(p, (x: BigInt) => g(f(x)))) + }.holds +} +