diff --git a/build.sbt b/build.sbt index 5c7ac98914456ab5b098ac6a3fd6714b833d1a81..5b984ed01dc07a99e12adf882198d16d2a667043 100644 --- a/build.sbt +++ b/build.sbt @@ -143,7 +143,7 @@ def ghProject(repo: String, version: String) = RootProject(uri(s"${repo}#${versi lazy val bonsai = ghProject("git://github.com/colder/bonsai.git", "0fec9f97f4220fa94b1f3f305f2e8b76a3cd1539") -lazy val scalaSmtLib = ghProject("git://github.com/MikaelMayer/scala-smtlib.git", "8ef13ef3294ab823aed0bad40678f507b1fe63e2") +lazy val scalaSmtLib = ghProject("git://github.com/regb/scala-smtlib.git", "372bb14d0c84953acc17f9a7e1592087adb0a3e1") lazy val root = (project in file(".")). configs(RegressionTest, IsabelleTest, IntegrTest). diff --git a/library/lang/Set.scala b/library/lang/Set.scala index 36c70a8372dd7a81d1737dfd321e81839d38fe6f..8f4595d33fada35acfa38435c3cd542116cf3cbd 100644 --- a/library/lang/Set.scala +++ b/library/lang/Set.scala @@ -17,6 +17,7 @@ case class Set[T](val theSet: scala.collection.immutable.Set[T]) { def ++(a: Set[T]): Set[T] = new Set[T](theSet ++ a.theSet) def -(a: T): Set[T] = new Set[T](theSet - a) def --(a: Set[T]): Set[T] = new Set[T](theSet -- a.theSet) + def size: BigInt = theSet.size def contains(a: T): Boolean = theSet.contains(a) def isEmpty: Boolean = theSet.isEmpty def subsetOf(b: Set[T]): Boolean = theSet.subsetOf(b.theSet) diff --git a/library/lang/package.scala b/library/lang/package.scala index 398ae9e36e53c689b69ae4dbda6788d643c25d42..b19ec529bdb3975956de00a7da8e8ad39bcf34e4 100644 --- a/library/lang/package.scala +++ b/library/lang/package.scala @@ -47,6 +47,9 @@ package object lang { @ignore def error[T](reason: java.lang.String): T = sys.error(reason) + @ignore + def old[T](value: T): T = value + @ignore implicit class Passes[A,B](io : (A,B)) { val (in, out) = io diff --git a/src/main/java/leon/codegen/runtime/Forall.java b/src/main/java/leon/codegen/runtime/Forall.java new file mode 100644 index 0000000000000000000000000000000000000000..f6877b604bcd069af6c2ce20d78a17d82d434dbe --- /dev/null +++ b/src/main/java/leon/codegen/runtime/Forall.java @@ -0,0 +1,35 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon.codegen.runtime; + +import java.util.HashMap; + +public abstract class Forall { + private static final HashMap<Tuple, Boolean> cache = new HashMap<Tuple, Boolean>(); + + protected final LeonCodeGenRuntimeHenkinMonitor monitor; + protected final Tuple closures; + protected final boolean check; + + public Forall(LeonCodeGenRuntimeMonitor monitor, Tuple closures) throws LeonCodeGenEvaluationException { + if (!(monitor instanceof LeonCodeGenRuntimeHenkinMonitor)) + throw new LeonCodeGenEvaluationException("Can't evaluate foralls without domain"); + + this.monitor = (LeonCodeGenRuntimeHenkinMonitor) monitor; + this.closures = closures; + this.check = (boolean) closures.get(closures.getArity() - 1); + } + + public boolean check() { + Tuple key = new Tuple(new Object[] { getClass(), monitor, closures }); // check is in the closures + if (cache.containsKey(key)) { + return cache.get(key); + } else { + boolean res = checkForall(); + cache.put(key, res); + return res; + } + } + + public abstract boolean checkForall(); +} diff --git a/src/main/java/leon/codegen/runtime/Lambda.java b/src/main/java/leon/codegen/runtime/Lambda.java index a6abbef37edbe8f87f480a21a6200e32a9e0206b..af255726311655efaeddea545c5e6e44afc15b8e 100644 --- a/src/main/java/leon/codegen/runtime/Lambda.java +++ b/src/main/java/leon/codegen/runtime/Lambda.java @@ -4,4 +4,6 @@ package leon.codegen.runtime; public abstract class Lambda { public abstract Object apply(Object[] args) throws LeonCodeGenRuntimeException; + public abstract void checkForall(boolean[] quantified); + public abstract void checkAxiom(); } diff --git a/src/main/java/leon/codegen/runtime/LeonCodeGenQuantificationException.java b/src/main/java/leon/codegen/runtime/LeonCodeGenQuantificationException.java new file mode 100644 index 0000000000000000000000000000000000000000..f172316a2548a52c6b294f70101a15ebbb8ce98a --- /dev/null +++ b/src/main/java/leon/codegen/runtime/LeonCodeGenQuantificationException.java @@ -0,0 +1,14 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon.codegen.runtime; + +/** Such exceptions are thrown when the evaluator encounters a forall + * expression whose shape is not supported in Leon. */ +public class LeonCodeGenQuantificationException extends Exception { + + private static final long serialVersionUID = -1824885321497473916L; + + public LeonCodeGenQuantificationException(String msg) { + super(msg); + } +} diff --git a/src/main/java/leon/codegen/runtime/LeonCodeGenRuntimeHenkinMonitor.java b/src/main/java/leon/codegen/runtime/LeonCodeGenRuntimeHenkinMonitor.java index 0be4ad91212be930ec5f8730ed30ebcf9a9f4e0a..597beec44b6a1a1719909e00ecb7d7916f0c7c03 100644 --- a/src/main/java/leon/codegen/runtime/LeonCodeGenRuntimeHenkinMonitor.java +++ b/src/main/java/leon/codegen/runtime/LeonCodeGenRuntimeHenkinMonitor.java @@ -7,26 +7,42 @@ import java.util.LinkedList; import java.util.HashMap; public class LeonCodeGenRuntimeHenkinMonitor extends LeonCodeGenRuntimeMonitor { - private final HashMap<Integer, List<Tuple>> domains = new HashMap<Integer, List<Tuple>>(); + private final HashMap<Integer, List<Tuple>> tpes = new HashMap<Integer, List<Tuple>>(); + private final HashMap<Class<?>, List<Tuple>> lambdas = new HashMap<Class<?>, List<Tuple>>(); + public final boolean checkForalls; - public LeonCodeGenRuntimeHenkinMonitor(int maxInvocations) { + public LeonCodeGenRuntimeHenkinMonitor(int maxInvocations, boolean checkForalls) { super(maxInvocations); + this.checkForalls = checkForalls; + } + + public LeonCodeGenRuntimeHenkinMonitor(int maxInvocations) { + this(maxInvocations, false); } public void add(int type, Tuple input) { - if (!domains.containsKey(type)) domains.put(type, new LinkedList<Tuple>()); - domains.get(type).add(input); + if (!tpes.containsKey(type)) tpes.put(type, new LinkedList<Tuple>()); + tpes.get(type).add(input); + } + + public void add(Class<?> clazz, Tuple input) { + if (!lambdas.containsKey(clazz)) lambdas.put(clazz, new LinkedList<Tuple>()); + lambdas.get(clazz).add(input); } public List<Tuple> domain(Object obj, int type) { List<Tuple> domain = new LinkedList<Tuple>(); if (obj instanceof PartialLambda) { - for (Tuple key : ((PartialLambda) obj).mapping.keySet()) { + PartialLambda l = (PartialLambda) obj; + for (Tuple key : l.mapping.keySet()) { domain.add(key); } + } else if (obj instanceof Lambda) { + List<Tuple> lambdaDomain = lambdas.get(obj.getClass()); + if (lambdaDomain != null) domain.addAll(lambdaDomain); } - List<Tuple> tpeDomain = domains.get(type); + List<Tuple> tpeDomain = tpes.get(type); if (tpeDomain != null) domain.addAll(tpeDomain); return domain; diff --git a/src/main/java/leon/codegen/runtime/PartialLambda.java b/src/main/java/leon/codegen/runtime/PartialLambda.java index 826cc5ed9930e54bc2f50d7f09e6fa09be3fa307..b04036db5e9f81d1eaf7fa2c9a047bfef45a4df8 100644 --- a/src/main/java/leon/codegen/runtime/PartialLambda.java +++ b/src/main/java/leon/codegen/runtime/PartialLambda.java @@ -6,9 +6,15 @@ import java.util.HashMap; public final class PartialLambda extends Lambda { final HashMap<Tuple, Object> mapping = new HashMap<Tuple, Object>(); + private final Object dflt; public PartialLambda() { + this(null); + } + + public PartialLambda(Object dflt) { super(); + this.dflt = dflt; } public void add(Tuple key, Object value) { @@ -20,15 +26,18 @@ public final class PartialLambda extends Lambda { Tuple tuple = new Tuple(args); if (mapping.containsKey(tuple)) { return mapping.get(tuple); + } else if (dflt != null) { + return dflt; } else { - throw new LeonCodeGenRuntimeException("Partial function apply on undefined arguments"); + throw new LeonCodeGenRuntimeException("Partial function apply on undefined arguments " + tuple); } } @Override public boolean equals(Object that) { if (that != null && (that instanceof PartialLambda)) { - return mapping.equals(((PartialLambda) that).mapping); + PartialLambda l = (PartialLambda) that; + return ((dflt != null && dflt.equals(l.dflt)) || (dflt == null && l.dflt == null)) && mapping.equals(l.mapping); } else { return false; } @@ -36,6 +45,12 @@ public final class PartialLambda extends Lambda { @Override public int hashCode() { - return 63 + 11 * mapping.hashCode(); + return 63 + 11 * mapping.hashCode() + (dflt == null ? 0 : dflt.hashCode()); } + + @Override + public void checkForall(boolean[] quantified) {} + + @Override + public void checkAxiom() {} } diff --git a/src/main/java/leon/codegen/runtime/Set.java b/src/main/java/leon/codegen/runtime/Set.java index 965aa20df6ba453e2fcbdfb0c2e4afce66d29d27..522dbe6589eac54bf2acac3e1f81ee4b4c7c962c 100644 --- a/src/main/java/leon/codegen/runtime/Set.java +++ b/src/main/java/leon/codegen/runtime/Set.java @@ -49,8 +49,8 @@ public final class Set { return true; } - public int size() { - return _underlying.size(); + public BigInt size() { + return new BigInt(""+_underlying.size()); } public Set union(Set s) { @@ -84,7 +84,7 @@ public final class Set { Set other = (Set)that; - return this.size() == other.size() && this.subsetOf(other); + return this.size().equals(other.size()) && this.subsetOf(other); } @Override diff --git a/src/main/java/leon/codegen/runtime/Tuple.java b/src/main/java/leon/codegen/runtime/Tuple.java index 9ae7a5f490223b0bc3b6a07dbb05160c76dc5e11..3b72931da6fc04b3504d7f3bee2056560c269634 100644 --- a/src/main/java/leon/codegen/runtime/Tuple.java +++ b/src/main/java/leon/codegen/runtime/Tuple.java @@ -54,4 +54,20 @@ public final class Tuple { _hash = h; return h; } + + @Override + public String toString() { + String str = "("; + boolean first = true; + for (Object obj : elements) { + if (first) { + first = false; + } else { + str += ", "; + } + str += obj == null ? "null" : obj.toString(); + } + str += ")"; + return str; + } } diff --git a/src/main/scala/leon/codegen/CodeGeneration.scala b/src/main/scala/leon/codegen/CodeGeneration.scala index e544addf8084daabaeaf02e22d715f28b68e55e7..f86908c749ed508c8ee350389a2d52a239b5511a 100644 --- a/src/main/scala/leon/codegen/CodeGeneration.scala +++ b/src/main/scala/leon/codegen/CodeGeneration.scala @@ -6,7 +6,7 @@ package codegen import purescala.Common._ import purescala.Definitions._ import purescala.Expressions._ -import purescala.ExprOps.{simplestValue, matchToIfThenElse, collect} +import purescala.ExprOps.{simplestValue, matchToIfThenElse, collect, variablesOf, CollectorWithPaths} import purescala.Types._ import purescala.Constructors._ import purescala.Extractors._ @@ -47,6 +47,8 @@ trait CodeGeneration { def withArgs(newArgs: Map[Identifier, Int]) = new Locals(vars, args ++ newArgs, fields) def withFields(newFields: Map[Identifier,(String,String,String)]) = new Locals(vars, args, fields ++ newFields) + + override def toString = "Locals("+vars + ", " + args + ", " + fields + ")" } object NoLocals extends Locals(Map.empty, Map.empty, Map.empty) @@ -71,8 +73,12 @@ trait CodeGeneration { private[codegen] val RationalClass = "leon/codegen/runtime/Rational" private[codegen] val CaseClassClass = "leon/codegen/runtime/CaseClass" private[codegen] val LambdaClass = "leon/codegen/runtime/Lambda" + private[codegen] val ForallClass = "leon/codegen/runtime/Forall" + private[codegen] val PartialLambdaClass = "leon/codegen/runtime/PartialLambda" private[codegen] val ErrorClass = "leon/codegen/runtime/LeonCodeGenRuntimeException" + private[codegen] val InvalidForallClass = "leon/codegen/runtime/LeonCodeGenQuantificationException" private[codegen] val ImpossibleEvaluationClass = "leon/codegen/runtime/LeonCodeGenEvaluationException" + private[codegen] val BadQuantificationClass = "leon/codegen/runtime/LeonCodeGenQuantificationException" private[codegen] val HashingClass = "leon/codegen/runtime/LeonCodeGenRuntimeHashing" private[codegen] val ChooseEntryPointClass = "leon/codegen/runtime/ChooseEntryPoint" private[codegen] val GenericValuesClass = "leon/codegen/runtime/GenericValues" @@ -228,8 +234,8 @@ trait CodeGeneration { private[codegen] val lambdaToClass = scala.collection.mutable.Map.empty[Lambda, String] private[codegen] val classToLambda = scala.collection.mutable.Map.empty[String, Lambda] - private def compileLambda(l: Lambda, ch: CodeHandler)(implicit locals: Locals): Unit = { - val (normalized, structSubst) = purescala.ExprOps.normalizeStructure(l) + protected def compileLambda(l: Lambda): (String, Seq[(Identifier, String)], String) = { + val (normalized, structSubst) = purescala.ExprOps.normalizeStructure(matchToIfThenElse(l)) val reverseSubst = structSubst.map(p => p._2 -> p._1) val nl = normalized.asInstanceOf[Lambda] @@ -283,6 +289,10 @@ trait CodeGeneration { cch.freeze } + val argMapping = nl.args.map(_.id).zipWithIndex.toMap + val closureMapping = closures.map { case (id, jvmt) => id -> (afName, id.uniqueName, jvmt) }.toMap + val newLocals = NoLocals.withArgs(argMapping).withFields(closureMapping) + locally { val apm = cf.addMethod(s"L$ObjectClass;", "apply", s"[L$ObjectClass;") @@ -291,11 +301,6 @@ trait CodeGeneration { METHOD_ACC_FINAL ).asInstanceOf[U2]) - val argMapping = nl.args.map(_.id).zipWithIndex.toMap - val closureMapping = closures.map { case (id, jvmt) => id -> (afName, id.uniqueName, jvmt) }.toMap - - val newLocals = locals.withArgs(argMapping).withFields(closureMapping) - val apch = apm.codeHandler mkBoxedExpr(nl.body, apch)(newLocals) @@ -380,22 +385,135 @@ trait CodeGeneration { hch.freeze } + locally { + val vmh = cf.addMethod("V", "checkForall", "[Z") + vmh.setFlags(( + METHOD_ACC_PUBLIC | + METHOD_ACC_FINAL + ).asInstanceOf[U2]) + + val vch = vmh.codeHandler + + vch << ALoad(1) // load argument array + def rec(args: Seq[Identifier], idx: Int, quantified: Set[Identifier]): Unit = args match { + case x :: xs => + val notQuantLabel = vch.getFreshLabel("notQuant") + vch << DUP << Ldc(idx) << BALOAD << IfEq(notQuantLabel) + rec(xs, idx + 1, quantified + x) + vch << Label(notQuantLabel) + rec(xs, idx + 1, quantified) + + case Nil => + if (quantified.nonEmpty) { + checkQuantified(quantified, nl.body, vch)(newLocals) + vch << ALoad(0) << InvokeVirtual(LambdaClass, "checkAxiom", "()V") + } + vch << POP << RETURN + } + + if (requireQuantification) { + rec(nl.args.map(_.id), 0, Set.empty) + } else { + vch << POP << RETURN + } + + vch.freeze + } + + locally { + val vmh = cf.addMethod("V", "checkAxiom") + vmh.setFlags(( + METHOD_ACC_PUBLIC | + METHOD_ACC_FINAL + ).asInstanceOf[U2]) + + val vch = vmh.codeHandler + + if (requireQuantification) { + val thisVar = FreshIdentifier("this", l.getType) + val axiom = Equals(Application(Variable(thisVar), nl.args.map(_.toVariable)), nl.body) + val axiomLocals = NoLocals.withFields(closureMapping).withVar(thisVar -> 0) + + mkForall(nl.args.map(_.id).toSet, axiom, vch, check = false)(axiomLocals) + + val skip = vch.getFreshLabel("skip") + vch << IfNe(skip) + vch << New(InvalidForallClass) << DUP + vch << Ldc("Unaxiomatic lambda " + l) + vch << InvokeSpecial(InvalidForallClass, constructorName, "(Ljava/lang/String;)V") + vch << ATHROW + vch << Label(skip) + } + + vch << RETURN + vch.freeze + } + loader.register(cf) afName }) - val consSig = "(" + closures.map(_._2).mkString("") + ")V" + (afName, closures.map { case p @ (id, jvmt) => + if (id == monitorID) p else (reverseSubst(id) -> jvmt) + }, "(" + closures.map(_._2).mkString("") + ")V") + } - ch << New(afName) << DUP - for ((id,jvmt) <- closures) { - if (id == monitorID) { - load(monitorID, ch) - } else { - mkExpr(Variable(reverseSubst(id)), ch) - } + private def checkQuantified(quantified: Set[Identifier], body: Expr, ch: CodeHandler)(implicit locals: Locals): Unit = { + val skipCheck = ch.getFreshLabel("skipCheck") + + load(monitorID, ch) + ch << CheckCast(HenkinClass) << GetField(HenkinClass, "checkForalls", "Z") + ch << IfEq(skipCheck) + + checkForall(quantified, body)(ctx) match { + case status: ForallInvalid => + ch << New(InvalidForallClass) << DUP + ch << Ldc("Invalid forall: " + status.getMessage) + ch << InvokeSpecial(InvalidForallClass, constructorName, "(Ljava/lang/String;)V") + ch << ATHROW + + case ForallValid => + // expand match case expressions and lets so that caller can be compiled given + // the current locals (lets and matches introduce new locals) + val cleanBody = purescala.ExprOps.expandLets(purescala.ExprOps.matchToIfThenElse(body)) + + val calls = new CollectorWithPaths[(Expr, Seq[Expr], Seq[Expr])] { + def collect(e: Expr, path: Seq[Expr]): Option[(Expr, Seq[Expr], Seq[Expr])] = e match { + case QuantificationMatcher(IsTyped(caller, _: FunctionType), args) => Some((caller, args, path)) + case _ => None + } + + override def rec(e: Expr, path: Seq[Expr]): Expr = e match { + case l : Lambda => l + case _ => super.rec(e, path) + } + }.traverse(cleanBody) + + for ((caller, args, paths) <- calls) { + if ((variablesOf(caller) & quantified).isEmpty) { + val enabler = andJoin(paths.filter(expr => (variablesOf(expr) & quantified).isEmpty)) + val skipCall = ch.getFreshLabel("skipCall") + mkExpr(enabler, ch) + ch << IfEq(skipCall) + + mkExpr(caller, ch) + ch << Ldc(args.size) << NewArray.primitive("T_BOOLEAN") + for ((arg, idx) <- args.zipWithIndex) { + ch << DUP << Ldc(idx) << Ldc(arg match { + case Variable(id) if quantified(id) => 1 + case _ => 0 + }) << BASTORE + } + + ch << InvokeVirtual(LambdaClass, "checkForall", "([Z)V") + + ch << Label(skipCall) + } + } } - ch << InvokeSpecial(afName, constructorName, consSig) + + ch << Label(skipCheck) } private val typeIdCache = scala.collection.mutable.Map.empty[TypeTree, Int] @@ -407,134 +525,270 @@ trait CodeGeneration { id } - private def compileForall(f: Forall, ch: CodeHandler)(implicit locals: Locals): Unit = { - // make sure we have an available HenkinModel - val monitorOk = ch.getFreshLabel("monitorOk") + private[codegen] val forallToClass = scala.collection.mutable.Map.empty[Expr, String] + + private def mkForall(quants: Set[Identifier], body: Expr, ch: CodeHandler, check: Boolean = true)(implicit locals: Locals): Unit = { + val (afName, closures, consSig) = compileForall(quants, body) + ch << New(afName) << DUP load(monitorID, ch) - ch << InstanceOf(HenkinClass) << IfNe(monitorOk) - ch << New(ImpossibleEvaluationClass) << DUP - ch << Ldc("Can't evaluate foralls without domain") - ch << InvokeSpecial(ImpossibleEvaluationClass, constructorName, "(Ljava/lang/String;)V") - ch << ATHROW - ch << Label(monitorOk) - - val Forall(fargs, TopLevelAnds(conjuncts)) = f - val endLabel = ch.getFreshLabel("forallEnd") - - for (conj <- conjuncts) { - val vars = purescala.ExprOps.variablesOf(conj) - val args = fargs.map(_.id).filter(vars) - val quantified = args.toSet - - val matchQuorums = extractQuorums(conj, quantified) - - val matcherIndexes = matchQuorums.flatten.distinct.zipWithIndex.toMap - - def buildLoops( - mis: List[(Expr, Seq[Expr], Int)], - localMapping: Map[Identifier, Int], - pointerMapping: Map[(Int, Int), Identifier] - ): Unit = mis match { - case (expr, args, qidx) :: rest => - load(monitorID, ch) - ch << CheckCast(HenkinClass) + mkTuple(closures.map(_.toVariable) :+ BooleanLiteral(check), ch) + ch << InvokeSpecial(afName, constructorName, consSig) + ch << InvokeVirtual(ForallClass, "check", "()Z") + } + + private def compileForall(quants: Set[Identifier], body: Expr): (String, Seq[Identifier], String) = { + val (nl, structSubst) = purescala.ExprOps.normalizeStructure(matchToIfThenElse(body)) + val reverseSubst = structSubst.map(p => p._2 -> p._1) + val nquants = quants.flatMap(structSubst.get) - mkExpr(expr, ch) - ch << Ldc(typeId(expr.getType)) - ch << InvokeVirtual(HenkinClass, "domain", s"(L$ObjectClass;I)L$JavaListClass;") - ch << InvokeInterface(JavaListClass, "iterator", s"()L$JavaIteratorClass;") + val closures = (purescala.ExprOps.variablesOf(nl) -- nquants).toSeq.sortBy(_.uniqueName) - val loop = ch.getFreshLabel("loop") - val out = ch.getFreshLabel("out") - ch << Label(loop) - // it - ch << DUP - // it, it - ch << InvokeInterface(JavaIteratorClass, "hasNext", "()Z") - // it, hasNext - ch << IfEq(out) << DUP - // it, it - ch << InvokeInterface(JavaIteratorClass, "next", s"()L$ObjectClass;") - // it, elem - ch << CheckCast(TupleClass) - - val (newLoc, newPtr) = (for ((arg, aidx) <- args.zipWithIndex) yield { - val id = FreshIdentifier("q", arg.getType, true) - val slot = ch.getFreshVar - - ch << DUP << Ldc(aidx) << InvokeVirtual(TupleClass, "get", s"(I)L$ObjectClass;") - mkUnbox(arg.getType, ch) - ch << (typeToJVM(arg.getType) match { - case "I" | "Z" => IStore(slot) - case _ => AStore(slot) - }) - - (id -> slot, (qidx -> aidx) -> id) - }).unzip - - ch << POP - // it - - buildLoops(rest, localMapping ++ newLoc, pointerMapping ++ newPtr) - - ch << Goto(loop) - ch << Label(out) << POP - - case Nil => - var okLabel: Option[String] = None - for (quorum <- matchQuorums) { - okLabel.foreach(ok => ch << Label(ok)) - okLabel = Some(ch.getFreshLabel("quorumOk")) - - var mappings: Seq[(Identifier, Int, Int)] = Seq.empty - var constraints: Seq[(Expr, Int, Int)] = Seq.empty - var equalities: Seq[((Int, Int), (Int, Int))] = Seq.empty - - for (q @ (expr, args) <- quorum) { - val qidx = matcherIndexes(q) - val (qmappings, qconstraints) = args.zipWithIndex.partition { - case (Variable(id), aidx) => quantified(id) - case _ => false - } - - mappings ++= qmappings.map(p => (p._1.asInstanceOf[Variable].id, qidx, p._2)) - constraints ++= qconstraints.map(p => (p._1, qidx, p._2)) - } + val afName = forallToClass.getOrElse(nl, { + val afName = "Leon$CodeGen$Forall$" + forallCounter.nextGlobal + forallToClass += nl -> afName - val mapping = for ((id, es) <- mappings.groupBy(_._1)) yield { - val base :: others = es.toList.map(p => (p._2, p._3)) - equalities ++= others.map(p => base -> p) - (id -> base) - } + val cf = new ClassFile(afName, Some(ForallClass)) + + cf.setFlags(( + CLASS_ACC_SUPER | + CLASS_ACC_PUBLIC | + CLASS_ACC_FINAL + ).asInstanceOf[U2]) - val enabler = andJoin(constraints.map { - case (e, qidx, aidx) => Equals(e, pointerMapping(qidx -> aidx).toVariable) - } ++ equalities.map { - case (k1, k2) => Equals(pointerMapping(k1).toVariable, pointerMapping(k2).toVariable) - }) + locally { + val cch = cf.addConstructor(s"L$MonitorClass;", s"L$TupleClass;").codeHandler + + cch << ALoad(0) << ALoad(1) << ALoad(2) + cch << InvokeSpecial(ForallClass, constructorName, s"(L$MonitorClass;L$TupleClass;)V") + cch << RETURN + cch.freeze + } + + locally { + val cfm = cf.addMethod("Z", "checkForall") + + cfm.setFlags(( + METHOD_ACC_PUBLIC | + METHOD_ACC_FINAL + ).asInstanceOf[U2]) + + val cfch = cfm.codeHandler + + cfch << ALoad(0) << GetField(ForallClass, "closures", s"L$TupleClass;") + + val closureVars = (for ((id, idx) <- closures.zipWithIndex) yield { + val slot = cfch.getFreshVar + cfch << DUP << Ldc(idx) << InvokeVirtual(TupleClass, "get", s"(I)L$ObjectClass;") + mkUnbox(id.getType, cfch) + cfch << (id.getType match { + case ValueType() => IStore(slot) + case _ => AStore(slot) + }) + id -> slot + }).toMap + + cfch << POP + + val monitorSlot = cfch.getFreshVar + cfch << ALoad(0) << GetField(ForallClass, "monitor", s"L$HenkinClass;") + cfch << AStore(monitorSlot) + + implicit val locals = NoLocals.withVars(closureVars).withVar(monitorID -> monitorSlot) + + val skipCheck = cfch.getFreshLabel("skipCheck") + cfch << ALoad(0) << GetField(ForallClass, "check", "Z") + cfch << IfEq(skipCheck) + checkQuantified(nquants, nl, cfch) + cfch << Label(skipCheck) + + val TopLevelAnds(conjuncts) = nl + val endLabel = cfch.getFreshLabel("forallEnd") + + for (conj <- conjuncts) { + val vars = purescala.ExprOps.variablesOf(conj) + val quantified = nquants.filter(vars) + + val matchQuorums = extractQuorums(conj, quantified) + + var allSlots: List[Int] = Nil + var freeSlots: List[Int] = Nil + def getSlot(): Int = freeSlots match { + case x :: xs => + freeSlots = xs + x + case Nil => + val slot = cfch.getFreshVar + allSlots = allSlots :+ slot + slot + } - mkExpr(enabler, ch)(locals.withVars(localMapping)) - ch << IfEq(okLabel.get) + for ((qrm, others) <- matchQuorums) { + val quorum = qrm.toList + + def rec(mis: List[(Expr, Expr, Seq[Expr], Int)], locs: Map[Identifier, Int], pointers: Map[(Int, Int), Identifier]): Unit = mis match { + case (TopLevelAnds(paths), expr, args, qidx) :: rest => + load(monitorID, cfch) + cfch << CheckCast(HenkinClass) + + mkExpr(expr, cfch) + cfch << Ldc(typeId(expr.getType)) + cfch << InvokeVirtual(HenkinClass, "domain", s"(L$ObjectClass;I)L$JavaListClass;") + cfch << InvokeInterface(JavaListClass, "iterator", s"()L$JavaIteratorClass;") + + val loop = cfch.getFreshLabel("loop") + val out = cfch.getFreshLabel("out") + cfch << Label(loop) + // it + cfch << DUP + // it, it + cfch << InvokeInterface(JavaIteratorClass, "hasNext", "()Z") + // it, hasNext + cfch << IfEq(out) << DUP + // it, it + cfch << InvokeInterface(JavaIteratorClass, "next", s"()L$ObjectClass;") + // it, elem + cfch << CheckCast(TupleClass) + + val (newLoc, newPtr) = (for ((arg, aidx) <- args.zipWithIndex) yield { + val id = FreshIdentifier("q", arg.getType, true) + val slot = getSlot() + + cfch << DUP << Ldc(aidx) << InvokeVirtual(TupleClass, "get", s"(I)L$ObjectClass;") + mkUnbox(arg.getType, cfch) + cfch << (typeToJVM(arg.getType) match { + case "I" | "Z" => IStore(slot) + case _ => AStore(slot) + }) + + (id -> slot, (qidx -> aidx) -> id) + }).unzip + + cfch << POP + // it + + rec(rest, locs ++ newLoc, pointers ++ newPtr) + + cfch << Goto(loop) + cfch << Label(out) << POP + + case Nil => + val okLabel = cfch.getFreshLabel("assignmentOk") + + var mappings: Seq[(Identifier, Int, Int)] = Seq.empty + var constraints: Seq[(Expr, Int, Int)] = Seq.empty + var equalities: Seq[((Int, Int), (Int, Int))] = Seq.empty + + for ((q @ (_, expr, args), qidx) <- quorum.zipWithIndex) { + val (qmappings, qconstraints) = args.zipWithIndex.partition { + case (Variable(id), aidx) => quantified(id) + case _ => false + } + + mappings ++= qmappings.map(p => (p._1.asInstanceOf[Variable].id, qidx, p._2)) + constraints ++= qconstraints.map(p => (p._1, qidx, p._2)) + } + + val mapping = for ((id, es) <- mappings.groupBy(_._1)) yield { + val base :: others = es.toList.map(p => (p._2, p._3)) + equalities ++= others.map(p => base -> p) + (id -> base) + } + + val enabler = andJoin(constraints.map { + case (e, qidx, aidx) => Equals(e, pointers(qidx -> aidx).toVariable) + } ++ equalities.map { + case (k1, k2) => Equals(pointers(k1).toVariable, pointers(k2).toVariable) + }) + + val varsMap = quantified.map(id => id -> locs(pointers(mapping(id)))).toMap + val varLocals = locals.withVars(varsMap) + + mkExpr(enabler, cfch)(varLocals.withVars(locs)) + cfch << IfEq(okLabel) + + val checkOk = cfch.getFreshLabel("checkOk") + load(monitorID, cfch) + cfch << GetField(HenkinClass, "checkForalls", "Z") + cfch << IfEq(checkOk) + + var nextLabel: Option[String] = None + for ((b,caller,args) <- others) { + nextLabel.foreach(label => cfch << Label(label)) + nextLabel = Some(cfch.getFreshLabel("next")) + + mkExpr(b, cfch)(varLocals) + cfch << IfEq(nextLabel.get) + + load(monitorID, cfch) + cfch << CheckCast(HenkinClass) + mkExpr(caller, cfch)(varLocals) + cfch << Ldc(typeId(caller.getType)) + cfch << InvokeVirtual(HenkinClass, "domain", s"(L$ObjectClass;I)L$JavaListClass;") + mkTuple(args, cfch)(varLocals) + cfch << InvokeInterface(JavaListClass, "contains", s"(L$ObjectClass;)Z") + cfch << IfNe(nextLabel.get) + + cfch << New(InvalidForallClass) << DUP + cfch << Ldc("Unhandled transitive implication in " + conj) + cfch << InvokeSpecial(InvalidForallClass, constructorName, "(Ljava/lang/String;)V") + cfch << ATHROW + } + nextLabel.foreach(label => cfch << Label(label)) + + cfch << Label(checkOk) + mkExpr(conj, cfch)(varLocals) + cfch << IfNe(okLabel) + + // -- Forall is false! -- + // POP all the iterators... + for (_ <- List.range(0, quorum.size)) cfch << POP + + // ... and return false + cfch << Ldc(0) << Goto(endLabel) + cfch << Label(okLabel) + } - val varsMap = args.map(id => id -> localMapping(pointerMapping(mapping(id)))).toMap - mkExpr(conj, ch)(locals.withVars(varsMap)) - ch << IfNe(okLabel.get) + val skipQuorum = cfch.getFreshLabel("skipQuorum") + for ((TopLevelAnds(paths), _, _) <- quorum) { + val p = andJoin(paths.filter(path => (variablesOf(path) & quantified).isEmpty)) + mkExpr(p, cfch) + cfch << IfEq(skipQuorum) + } - // -- Forall is false! -- - // POP all the iterators... - for (_ <- List.range(0, matcherIndexes.size)) ch << POP + val mis = quorum.zipWithIndex.map { case ((p, e, as), idx) => (p, e, as, idx) } + rec(mis, Map.empty, Map.empty) + freeSlots = allSlots - // ... and return false - ch << Ldc(0) << Goto(endLabel) + cfch << Label(skipQuorum) } + } + + cfch << Ldc(1) << Label(endLabel) + cfch << IRETURN - ch << Label(okLabel.get) + cfch.freeze } - buildLoops(matcherIndexes.toList.map { case ((e, as), idx) => (e, as, idx) }, Map.empty, Map.empty) - } + loader.register(cf) + + afName + }) - ch << Ldc(1) << Label(endLabel) + (afName, closures.map(reverseSubst), s"(L$MonitorClass;L$TupleClass;)V") + } + + // also makes tuples with 0/1 args + private def mkTuple(es: Seq[Expr], ch: CodeHandler)(implicit locals: Locals) : Unit = { + ch << New(TupleClass) << DUP + ch << Ldc(es.size) + ch << NewArray(s"$ObjectClass") + for((e,i) <- es.zipWithIndex) { + ch << DUP + ch << Ldc(i) + mkBoxedExpr(e, ch) + ch << AASTORE + } + ch << InvokeSpecial(TupleClass, constructorName, s"([L$ObjectClass;)V") } private[codegen] def mkExpr(e: Expr, ch: CodeHandler, canDelegateToMkBranch: Boolean = true)(implicit locals: Locals) { @@ -627,17 +881,7 @@ trait CodeGeneration { instrumentedGetField(ch, cct, sid) // Tuples (note that instanceOf checks are in mkBranch) - case Tuple(es) => - ch << New(TupleClass) << DUP - ch << Ldc(es.size) - ch << NewArray(s"$ObjectClass") - for((e,i) <- es.zipWithIndex) { - ch << DUP - ch << Ldc(i) - mkBoxedExpr(e, ch) - ch << AASTORE - } - ch << InvokeSpecial(TupleClass, constructorName, s"([L$ObjectClass;)V") + case Tuple(es) => mkTuple(es, ch) case TupleSelect(t, i) => val TupleType(bs) = t.getType @@ -662,7 +906,7 @@ trait CodeGeneration { case SetCardinality(s) => mkExpr(s, ch) - ch << InvokeVirtual(SetClass, "size", "()I") + ch << InvokeVirtual(SetClass, "size", s"()$BigIntClass;") case SubsetOf(s1, s2) => mkExpr(s1, ch) @@ -890,11 +1134,38 @@ trait CodeGeneration { ch << InvokeVirtual(LambdaClass, "apply", s"([L$ObjectClass;)L$ObjectClass;") mkUnbox(app.getType, ch) + case p @ PartialLambda(mapping, optDflt, _) => + ch << New(PartialLambdaClass) << DUP + optDflt match { + case Some(dflt) => + mkBoxedExpr(dflt, ch) + ch << InvokeSpecial(PartialLambdaClass, constructorName, s"(L$ObjectClass;)V") + case None => + ch << InvokeSpecial(PartialLambdaClass, constructorName, "()V") + } + + for ((es,v) <- mapping) { + ch << DUP + mkTuple(es, ch) + mkBoxedExpr(v, ch) + ch << InvokeVirtual(PartialLambdaClass, "add", s"(L$TupleClass;L$ObjectClass;)V") + } + case l @ Lambda(args, body) => - compileLambda(l, ch) + val (afName, closures, consSig) = compileLambda(l) + + ch << New(afName) << DUP + for ((id,jvmt) <- closures) { + if (id == monitorID) { + load(monitorID, ch) + } else { + mkExpr(Variable(id), ch) + } + } + ch << InvokeSpecial(afName, constructorName, consSig) case f @ Forall(args, body) => - compileForall(f, ch) + mkForall(args.map(_.id).toSet, body, ch) // String processing => case StringConcat(l, r) => @@ -1219,7 +1490,7 @@ trait CodeGeneration { // Assumes the top of the stack contains of value of the right type, and makes it // compatible with java.lang.Object. - private[codegen] def mkBox(tpe: TypeTree, ch: CodeHandler)(implicit locals: Locals) { + private[codegen] def mkBox(tpe: TypeTree, ch: CodeHandler): Unit = { tpe match { case Int32Type => ch << New(BoxedIntClass) << DUP_X1 << SWAP << InvokeSpecial(BoxedIntClass, constructorName, "(I)V") @@ -1237,7 +1508,7 @@ trait CodeGeneration { } // Assumes that the top of the stack contains a value that should be of type `tpe`, and unboxes it to the right (JVM) type. - private[codegen] def mkUnbox(tpe: TypeTree, ch: CodeHandler)(implicit locals: Locals) { + private[codegen] def mkUnbox(tpe: TypeTree, ch: CodeHandler): Unit = { tpe match { case Int32Type => ch << CheckCast(BoxedIntClass) << InvokeVirtual(BoxedIntClass, "intValue", "()I") @@ -1512,7 +1783,7 @@ trait CodeGeneration { lzy.returnType match { case ValueType() => // Since the underlying field only has boxed types, we have to unbox them to return them - mkUnbox(lzy.returnType, ch)(newLocs) + mkUnbox(lzy.returnType, ch) ch << IRETURN case _ => ch << ARETURN @@ -1846,7 +2117,7 @@ trait CodeGeneration { pech << Ldc(i) pech << ALoad(0) instrumentedGetField(pech, cct, f.id)(newLocs) - mkBox(f.getType, pech)(newLocs) + mkBox(f.getType, pech) pech << AASTORE } diff --git a/src/main/scala/leon/codegen/CompilationUnit.scala b/src/main/scala/leon/codegen/CompilationUnit.scala index 8b2ee0bcab255e29ccb3948f55a9d14b64205d53..93c8c6d021233a6c5968d7de9a78b7c3d135ced1 100644 --- a/src/main/scala/leon/codegen/CompilationUnit.scala +++ b/src/main/scala/leon/codegen/CompilationUnit.scala @@ -127,13 +127,24 @@ class CompilationUnit(val ctx: LeonContext, conss.last } - def modelToJVM(model: solvers.Model, maxInvocations: Int): LeonCodeGenRuntimeMonitor = model match { + def modelToJVM(model: solvers.Model, maxInvocations: Int, check: Boolean): LeonCodeGenRuntimeMonitor = model match { case hModel: solvers.HenkinModel => - val lhm = new LeonCodeGenRuntimeHenkinMonitor(maxInvocations) - for ((tpe, domain) <- hModel.domains; args <- domain) { + val lhm = new LeonCodeGenRuntimeHenkinMonitor(maxInvocations, check) + for ((lambda, domain) <- hModel.doms.lambdas) { + val (afName, _, _) = compileLambda(lambda) + val lc = loader.loadClass(afName) + + for (args <- domain) { + // note here that it doesn't matter that `lhm` doesn't yet have its domains + // filled since all values in `args` should be grounded + val inputJvm = tupleConstructor.newInstance(args.map(valueToJVM(_)(lhm)).toArray).asInstanceOf[leon.codegen.runtime.Tuple] + lhm.add(lc, inputJvm) + } + } + + for ((tpe, domain) <- hModel.doms.tpes; args <- domain) { val tpeId = typeId(tpe) - // note here that it doesn't matter that `lhm` doesn't yet have its domains - // filled since all values in `args` should be grounded + // same remark as above about valueToJVM(_)(lhm) val inputJvm = tupleConstructor.newInstance(args.map(valueToJVM(_)(lhm)).toArray).asInstanceOf[leon.codegen.runtime.Tuple] lhm.add(tpeId, inputJvm) } @@ -201,8 +212,13 @@ class CompilationUnit(val ctx: LeonContext, } m - case f @ PartialLambda(mapping, _) => - val l = new leon.codegen.runtime.PartialLambda() + case f @ PartialLambda(mapping, dflt, _) => + val l = if (dflt.isDefined) { + new leon.codegen.runtime.PartialLambda(dflt.get) + } else { + new leon.codegen.runtime.PartialLambda() + } + for ((ks,v) <- mapping) { // Force tuple even with 1/0 elems. val kJvm = tupleConstructor.newInstance(ks.map(valueToJVM).toArray).asInstanceOf[leon.codegen.runtime.Tuple] @@ -533,3 +549,5 @@ class CompilationUnit(val ctx: LeonContext, private [codegen] object exprCounter extends UniqueCounter[Unit] private [codegen] object lambdaCounter extends UniqueCounter[Unit] +private [codegen] object forallCounter extends UniqueCounter[Unit] + diff --git a/src/main/scala/leon/codegen/CompiledExpression.scala b/src/main/scala/leon/codegen/CompiledExpression.scala index a9d1eda0c5e36e19a6b6c12f99a617b480866f10..f9fca911564c61ad984fa97c3f2ac0da7fc021b4 100644 --- a/src/main/scala/leon/codegen/CompiledExpression.scala +++ b/src/main/scala/leon/codegen/CompiledExpression.scala @@ -8,7 +8,7 @@ import purescala.Expressions._ import cafebabe._ -import runtime.{LeonCodeGenRuntimeMonitor => LM, LeonCodeGenRuntimeHenkinMonitor => LHM} +import runtime.{LeonCodeGenRuntimeMonitor => LM} import java.lang.reflect.InvocationTargetException @@ -51,9 +51,9 @@ class CompiledExpression(unit: CompilationUnit, cf: ClassFile, expression: Expr, } } - def eval(model: solvers.Model) : Expr = { + def eval(model: solvers.Model, check: Boolean = false) : Expr = { try { - val monitor = unit.modelToJVM(model, params.maxFunctionInvocations) + val monitor = unit.modelToJVM(model, params.maxFunctionInvocations, check) evalFromJVM(argsToJVM(argsDecl.map(model), monitor), monitor) } catch { case ite : InvocationTargetException => throw ite.getCause diff --git a/src/main/scala/leon/datagen/VanuatooDataGen.scala b/src/main/scala/leon/datagen/VanuatooDataGen.scala index 9b9224482d6c7ed058b526ec70f414663558942c..0f3df8aede362bb3e3c7fb2c6431c3daf059556c 100644 --- a/src/main/scala/leon/datagen/VanuatooDataGen.scala +++ b/src/main/scala/leon/datagen/VanuatooDataGen.scala @@ -33,6 +33,14 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { b -> Constructor[Expr, TypeTree](List(), BooleanType, s => BooleanLiteral(b), ""+b) }).toMap + val chars = (for (c <- Set('a', 'b', 'c', 'd')) yield { + c -> Constructor[Expr, TypeTree](List(), CharType, s => CharLiteral(c), ""+c) + }).toMap + + val rationals = (for (n <- Set(0, 1, 2, 3); d <- Set(1,2,3,4)) yield { + (n, d) -> Constructor[Expr, TypeTree](List(), RealType, s => FractionalLiteral(n, d), "" + n + "/" + d) + }).toMap + val strings = (for (b <- Set("", "a", "b", "Abcd")) yield { b -> Constructor[Expr, TypeTree](List(), StringType, s => StringLiteral(b), b) }).toMap @@ -44,6 +52,10 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { def boolConstructor(b: Boolean) = booleans(b) + def charConstructor(c: Char) = chars(c) + + def rationalConstructor(n: Int, d: Int) = rationals(n -> d) + def stringConstructor(s: String) = strings(s) def cPattern(c: Constructor[Expr, TypeTree], args: Seq[VPattern[Expr, TypeTree]]) = { @@ -57,7 +69,6 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { getConstructors(t).head.copy(retType = act) } - private def getConstructors(t: TypeTree): List[Constructor[Expr, TypeTree]] = t match { case UnitType => constructors.getOrElse(t, { @@ -105,7 +116,6 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { constructors.getOrElse(mt, { val cs = for (size <- List(0, 1, 2, 5)) yield { val subs = (1 to size).flatMap(i => List(from, to)).toList - Constructor[Expr, TypeTree](subs, mt, s => FiniteMap(s.grouped(2).map(t => (t(0), t(1))).toMap, from, to), mt.asString(ctx)+"@"+size) } constructors += mt -> cs @@ -117,13 +127,9 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { val cs = for (size <- List(1, 2, 3, 5)) yield { val subs = (1 to size).flatMap(_ => from :+ to).toList Constructor[Expr, TypeTree](subs, ft, { s => - val args = from.map(tpe => FreshIdentifier("x", tpe, true)) - val argsTuple = tupleWrap(args.map(_.toVariable)) val grouped = s.grouped(from.size + 1).toSeq - val body = grouped.init.foldRight(grouped.last.last) { case (t, elze) => - IfExpr(Equals(argsTuple, tupleWrap(t.init)), t.last, elze) - } - Lambda(args.map(id => ValDef(id)), body) + val mapping = grouped.init.map { case args :+ res => (args -> res) } + PartialLambda(mapping, Some(grouped.last.last), ft) }, ft.asString(ctx) + "@" + size) } constructors += ft -> cs @@ -173,6 +179,9 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { case (b: java.lang.Boolean, BooleanType) => (cPattern(boolConstructor(b), List()), true) + case (c: java.lang.Character, CharType) => + (cPattern(charConstructor(c), List()), true) + case (b: java.lang.String, StringType) => (cPattern(stringConstructor(b), List()), true) @@ -203,7 +212,7 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { (ConstructorPattern(c, elems.map(_._1)), elems.forall(_._2)) case _ => - ctx.reporter.error("Could not retreive type for :"+cc.getClass.getName) + ctx.reporter.error("Could not retrieve type for :"+cc.getClass.getName) (AnyPattern[Expr, TypeTree](), false) } @@ -227,6 +236,7 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { case (gv: GenericValue, t: TypeParameter) => (cPattern(getConstructors(t)(gv.id-1), List()), true) + case (v, t) => ctx.reporter.debug("Unsupported value, can't paternify : "+v+" ("+v.getClass+") : "+t) (AnyPattern[Expr, TypeTree](), false) @@ -297,8 +307,8 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { None }) - - val gen = new StubGenerator[Expr, TypeTree]((ints.values ++ bigInts.values ++ booleans.values).toSeq, + val stubValues = ints.values ++ bigInts.values ++ booleans.values ++ chars.values ++ rationals.values + val gen = new StubGenerator[Expr, TypeTree](stubValues.toSeq, Some(getConstructors _), treatEmptyStubsAsChildless = true) diff --git a/src/main/scala/leon/evaluators/AngelicEvaluator.scala b/src/main/scala/leon/evaluators/AngelicEvaluator.scala index 090eeff0f6d38a5802718989ff81fe623e670d4e..99d704f67c7485ba8e00727c4edcc5d30644b4cc 100644 --- a/src/main/scala/leon/evaluators/AngelicEvaluator.scala +++ b/src/main/scala/leon/evaluators/AngelicEvaluator.scala @@ -22,6 +22,9 @@ class AngelicEvaluator(underlying: NDEvaluator) case other@(RuntimeError(_) | EvaluatorError(_)) => other.asInstanceOf[Result[Nothing]] } + + /** Checks that `model |= expr` and that quantifications are all valid */ + def check(expr: Expr, model: Model): CheckResult = underlying.check(expr, model) } class DemonicEvaluator(underlying: NDEvaluator) @@ -39,4 +42,7 @@ class DemonicEvaluator(underlying: NDEvaluator) case other@(RuntimeError(_) | EvaluatorError(_)) => other.asInstanceOf[Result[Nothing]] } + + /** Checks that `model |= expr` and that quantifications are all valid */ + def check(expr: Expr, model: Model): CheckResult = underlying.check(expr, model) } \ No newline at end of file diff --git a/src/main/scala/leon/evaluators/CodeGenEvaluator.scala b/src/main/scala/leon/evaluators/CodeGenEvaluator.scala index a32bce7e5a1d286bc09a65e491f6b2bd33a3efde..533ba695ca27f478a03ac7b6cf53d46885e11309 100644 --- a/src/main/scala/leon/evaluators/CodeGenEvaluator.scala +++ b/src/main/scala/leon/evaluators/CodeGenEvaluator.scala @@ -8,9 +8,15 @@ import purescala.Definitions._ import purescala.Expressions._ import codegen.CompilationUnit +import codegen.CompiledExpression import codegen.CodeGenParams +import leon.codegen.runtime.LeonCodeGenRuntimeException +import leon.codegen.runtime.LeonCodeGenEvaluationException +import leon.codegen.runtime.LeonCodeGenQuantificationException + class CodeGenEvaluator(ctx: LeonContext, val unit : CompilationUnit) extends Evaluator(ctx, unit.program) with DeterministicEvaluator { + val name = "codegen-eval" val description = "Evaluator for PureScala expressions based on compilation to JVM" @@ -19,9 +25,55 @@ class CodeGenEvaluator(ctx: LeonContext, val unit : CompilationUnit) extends Eva this(ctx, new CompilationUnit(ctx, prog, params)) } + private def compileExpr(expression: Expr, args: Seq[Identifier]): Option[CompiledExpression] = { + ctx.timers.evaluators.codegen.compilation.start() + try { + Some(unit.compileExpression(expression, args)(ctx)) + } catch { + case t: Throwable => + ctx.reporter.warning(expression.getPos, "Error while compiling expression: "+t.getMessage) + None + } finally { + ctx.timers.evaluators.codegen.compilation.stop() + } + } + + def check(expression: Expr, model: solvers.Model) : CheckResult = { + compileExpr(expression, model.toSeq.map(_._1)).map { ce => + ctx.timers.evaluators.codegen.runtime.start() + try { + val res = ce.eval(model, check = true) + if (res == BooleanLiteral(true)) EvaluationResults.CheckSuccess + else EvaluationResults.CheckValidityFailure + } catch { + case e : ArithmeticException => + EvaluationResults.CheckRuntimeFailure(e.getMessage) + + case e : ArrayIndexOutOfBoundsException => + EvaluationResults.CheckRuntimeFailure(e.getMessage) + + case e : LeonCodeGenRuntimeException => + EvaluationResults.CheckRuntimeFailure(e.getMessage) + + case e : LeonCodeGenEvaluationException => + EvaluationResults.CheckRuntimeFailure(e.getMessage) + + case e : java.lang.ExceptionInInitializerError => + EvaluationResults.CheckRuntimeFailure(e.getException.getMessage) + + case so : java.lang.StackOverflowError => + EvaluationResults.CheckRuntimeFailure("Stack overflow") + + case e : LeonCodeGenQuantificationException => + EvaluationResults.CheckQuantificationFailure(e.getMessage) + } finally { + ctx.timers.evaluators.codegen.runtime.stop() + } + }.getOrElse(EvaluationResults.CheckRuntimeFailure("Couldn't compile expression.")) + } + def eval(expression: Expr, model: solvers.Model) : EvaluationResult = { - val toPairs = model.toSeq - compile(expression, toPairs.map(_._1)).map { e => + compile(expression, model.toSeq.map(_._1)).map { e => ctx.timers.evaluators.codegen.runtime.start() val res = e(model) ctx.timers.evaluators.codegen.runtime.stop() @@ -30,14 +82,7 @@ class CodeGenEvaluator(ctx: LeonContext, val unit : CompilationUnit) extends Eva } override def compile(expression: Expr, args: Seq[Identifier]) : Option[solvers.Model=>EvaluationResult] = { - import leon.codegen.runtime.LeonCodeGenRuntimeException - import leon.codegen.runtime.LeonCodeGenEvaluationException - - ctx.timers.evaluators.codegen.compilation.start() - try { - val ce = unit.compileExpression(expression, args)(ctx) - - Some((model: solvers.Model) => { + compileExpr(expression, args).map(ce => (model: solvers.Model) => { if (args.exists(arg => !model.isDefinedAt(arg))) { EvaluationResults.EvaluatorError("Model undefined for free arguments") } else try { @@ -60,15 +105,7 @@ class CodeGenEvaluator(ctx: LeonContext, val unit : CompilationUnit) extends Eva case so : java.lang.StackOverflowError => EvaluationResults.RuntimeError("Stack overflow") - } }) - } catch { - case t: Throwable => - ctx.reporter.warning(expression.getPos, "Error while compiling expression: "+t.getMessage) - None - } finally { - ctx.timers.evaluators.codegen.compilation.stop() } } -} diff --git a/src/main/scala/leon/evaluators/ContextualEvaluator.scala b/src/main/scala/leon/evaluators/ContextualEvaluator.scala index 59e46a658eda6ddb73b6c62f86d10836d60e7dc1..0fc33102a04716816fc3b2a83faa1384b37da1fd 100644 --- a/src/main/scala/leon/evaluators/ContextualEvaluator.scala +++ b/src/main/scala/leon/evaluators/ContextualEvaluator.scala @@ -3,13 +3,11 @@ package leon package evaluators +import leon.purescala.Extractors.{IsTyped, TopLevelAnds} import purescala.Common._ import purescala.Definitions._ import purescala.Expressions._ import purescala.Types._ -import purescala.Constructors._ -import purescala.ExprOps._ -import purescala.Quantification._ import solvers.{HenkinModel, Model} abstract class ContextualEvaluator(ctx: LeonContext, prog: Program, val maxSteps: Int) extends Evaluator(ctx, prog) with CEvalHelpers { @@ -20,17 +18,18 @@ abstract class ContextualEvaluator(ctx: LeonContext, prog: Program, val maxSteps type GC <: GlobalContext def initRC(mappings: Map[Identifier, Expr]): RC - def initGC(model: solvers.Model): GC + def initGC(model: solvers.Model, check: Boolean): GC case class EvalError(msg : String) extends Exception case class RuntimeError(msg : String) extends Exception + case class QuantificationError(msg: String) extends Exception // Used by leon-web, please do not delete var lastGC: Option[GC] = None def eval(ex: Expr, model: Model) = { try { - lastGC = Some(initGC(model)) + lastGC = Some(initGC(model, check = true)) ctx.timers.evaluators.recursive.runtime.start() EvaluationResults.Successful(e(ex)(initRC(model.toMap), lastGC.get)) } catch { @@ -47,6 +46,30 @@ abstract class ContextualEvaluator(ctx: LeonContext, prog: Program, val maxSteps } } + def check(ex: Expr, model: Model): CheckResult = { + assert(ex.getType == BooleanType, "Can't check non-boolean expression " + ex.asString) + try { + lastGC = Some(initGC(model, check = true)) + ctx.timers.evaluators.recursive.runtime.start() + val res = e(ex)(initRC(model.toMap), lastGC.get) + if (res == BooleanLiteral(true)) EvaluationResults.CheckSuccess + else EvaluationResults.CheckValidityFailure + } catch { + case so: StackOverflowError => + EvaluationResults.CheckRuntimeFailure("Stack overflow") + case e @ EvalError(msg) => + EvaluationResults.CheckRuntimeFailure(msg) + case e @ RuntimeError(msg) => + EvaluationResults.CheckRuntimeFailure(msg) + case jre: java.lang.RuntimeException => + EvaluationResults.CheckRuntimeFailure(jre.getMessage) + case qe @ QuantificationError(msg) => + EvaluationResults.CheckQuantificationFailure(msg) + } finally { + ctx.timers.evaluators.recursive.runtime.stop() + } + } + protected def e(expr: Expr)(implicit rctx: RC, gctx: GC): Value def typeErrorMsg(tree : Expr, expected : TypeTree) : String = s"Type error : expected ${expected.asString}, found ${tree.asString}." @@ -55,60 +78,62 @@ abstract class ContextualEvaluator(ctx: LeonContext, prog: Program, val maxSteps private[evaluators] trait CEvalHelpers { this: ContextualEvaluator => - - def forallInstantiations(gctx:GC, fargs: Seq[ValDef], conj: Expr) = { - - val henkinModel: HenkinModel = gctx.model match { - case hm: HenkinModel => hm - case _ => throw EvalError("Can't evaluate foralls without henkin model") - } - - val vars = variablesOf(conj) - val args = fargs.map(_.id).filter(vars) - val quantified = args.toSet - val matcherQuorums = extractQuorums(conj, quantified) + /* This is an effort to generalize forall to non-det. solvers + def forallInstantiations(gctx:GC, fargs: Seq[ValDef], conj: Expr) = { - matcherQuorums.flatMap { quorum => - var mappings: Seq[(Identifier, Int, Int)] = Seq.empty - var constraints: Seq[(Expr, Int, Int)] = Seq.empty + val henkinModel: HenkinModel = gctx.model match { + case hm: HenkinModel => hm + case _ => throw EvalError("Can't evaluate foralls without henkin model") + } - for (((expr, args), qidx) <- quorum.zipWithIndex) { - val (qmappings, qconstraints) = args.zipWithIndex.partition { - case (Variable(id), aidx) => quantified(id) - case _ => false - } + val vars = variablesOf(conj) + val args = fargs.map(_.id).filter(vars) + val quantified = args.toSet - mappings ++= qmappings.map(p => (p._1.asInstanceOf[Variable].id, qidx, p._2)) - constraints ++= qconstraints.map(p => (p._1, qidx, p._2)) - } + val matcherQuorums = extractQuorums(conj, quantified) - var equalities: Seq[((Int, Int), (Int, Int))] = Seq.empty - val mapping = for ((id, es) <- mappings.groupBy(_._1)) yield { - val base :: others = es.toList.map(p => (p._2, p._3)) - equalities ++= others.map(p => base -> p) - (id -> base) - } + matcherQuorums.flatMap { quorum => + var mappings: Seq[(Identifier, Int, Int)] = Seq.empty + var constraints: Seq[(Expr, Int, Int)] = Seq.empty - val argSets = quorum.foldLeft[List[Seq[Seq[Expr]]]](List(Seq.empty)) { - case (acc, (expr, _)) => acc.flatMap(s => henkinModel.domain(expr).map(d => s :+ d)) + for (((expr, args), qidx) <- quorum.zipWithIndex) { + val (qmappings, qconstraints) = args.zipWithIndex.partition { + case (Variable(id), aidx) => quantified(id) + case _ => false } - argSets.map { args => - val argMap: Map[(Int, Int), Expr] = args.zipWithIndex.flatMap { - case (a, qidx) => a.zipWithIndex.map { case (e, aidx) => (qidx, aidx) -> e } - }.toMap + mappings ++= qmappings.map(p => (p._1.asInstanceOf[Variable].id, qidx, p._2)) + constraints ++= qconstraints.map(p => (p._1, qidx, p._2)) + } + + var equalities: Seq[((Int, Int), (Int, Int))] = Seq.empty + val mapping = for ((id, es) <- mappings.groupBy(_._1)) yield { + val base :: others = es.toList.map(p => (p._2, p._3)) + equalities ++= others.map(p => base -> p) + (id -> base) + } + + val argSets = quorum.foldLeft[List[Seq[Seq[Expr]]]](List(Seq.empty)) { + case (acc, (expr, _)) => acc.flatMap(s => henkinModel.domain(expr).map(d => s :+ d)) + } + + argSets.map { args => + val argMap: Map[(Int, Int), Expr] = args.zipWithIndex.flatMap { + case (a, qidx) => a.zipWithIndex.map { case (e, aidx) => (qidx, aidx) -> e } + }.toMap + + val map = mapping.map { case (id, key) => id -> argMap(key) } + val enabler = andJoin(constraints.map { + case (e, qidx, aidx) => Equals(e, argMap(qidx -> aidx)) + } ++ equalities.map { + case (k1, k2) => Equals(argMap(k1), argMap(k2)) + }) + + (enabler, map) + } + }*/ + - val map = mapping.map { case (id, key) => id -> argMap(key) } - val enabler = andJoin(constraints.map { - case (e, qidx, aidx) => Equals(e, argMap(qidx -> aidx)) - } ++ equalities.map { - case (k1, k2) => Equals(argMap(k1), argMap(k2)) - }) - (enabler, map) - } - } - - } } \ No newline at end of file diff --git a/src/main/scala/leon/evaluators/DefaultEvaluator.scala b/src/main/scala/leon/evaluators/DefaultEvaluator.scala index 5c951eedc1a68943c42d9ee4f07f8d58e5f74bf9..18a9159c3cb0e29f47e0757314f935865f6b10cf 100644 --- a/src/main/scala/leon/evaluators/DefaultEvaluator.scala +++ b/src/main/scala/leon/evaluators/DefaultEvaluator.scala @@ -1,3 +1,5 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + package leon package evaluators @@ -5,4 +7,5 @@ import purescala.Definitions.Program class DefaultEvaluator(ctx: LeonContext, prog: Program) extends RecursiveEvaluator(ctx, prog, 5000) - with DefaultContexts \ No newline at end of file + with HasDefaultGlobalContext + with HasDefaultRecContext diff --git a/src/main/scala/leon/evaluators/DualEvaluator.scala b/src/main/scala/leon/evaluators/DualEvaluator.scala index 05f60fd270f7d7ebc28b8627b703d23f30b6e214..4c405c8b6f216ee9b839101d8bff5574035b05f4 100644 --- a/src/main/scala/leon/evaluators/DualEvaluator.scala +++ b/src/main/scala/leon/evaluators/DualEvaluator.scala @@ -10,16 +10,14 @@ import purescala.Types._ import codegen._ -class DualEvaluator(ctx: LeonContext, prog: Program, params: CodeGenParams) extends RecursiveEvaluator(ctx, prog, params.maxFunctionInvocations) { +class DualEvaluator(ctx: LeonContext, prog: Program, params: CodeGenParams) + extends RecursiveEvaluator(ctx, prog, params.maxFunctionInvocations) + with HasDefaultGlobalContext +{ type RC = DualRecContext - type GC = GlobalContext - - def initGC(model: solvers.Model) = new GlobalContext(model, this.maxSteps) - - implicit val debugSection = utils.DebugSectionEvaluation - def initRC(mappings: Map[Identifier, Expr]): RC = DualRecContext(mappings) + implicit val debugSection = utils.DebugSectionEvaluation var monitor = new runtime.LeonCodeGenRuntimeMonitor(params.maxFunctionInvocations) diff --git a/src/main/scala/leon/evaluators/EvaluationResults.scala b/src/main/scala/leon/evaluators/EvaluationResults.scala index b9cbecdde4cdd31b2b29dcb82fdd23c50aaf9a24..18f7a0c92d448f98c8f6a271d91e021a649e3b9c 100644 --- a/src/main/scala/leon/evaluators/EvaluationResults.scala +++ b/src/main/scala/leon/evaluators/EvaluationResults.scala @@ -15,4 +15,21 @@ object EvaluationResults { /** Represents an evaluation that failed (in the evaluator). */ case class EvaluatorError(message : String) extends Result(None) + + /** Results of checking proposition evaluation. + * Useful for verification of model validity in presence of quantifiers. */ + sealed abstract class CheckResult(val success: Boolean) + + /** Successful proposition evaluation (model |= expr) */ + case object CheckSuccess extends CheckResult(true) + + /** Check failed with `model |= !expr` */ + case object CheckValidityFailure extends CheckResult(false) + + /** Check failed due to evaluation or runtime errors. + * @see [[RuntimeError]] and [[EvaluatorError]] */ + case class CheckRuntimeFailure(msg: String) extends CheckResult(false) + + /** Check failed due to inconsistence of model with quantified propositions. */ + case class CheckQuantificationFailure(msg: String) extends CheckResult(false) } diff --git a/src/main/scala/leon/evaluators/Evaluator.scala b/src/main/scala/leon/evaluators/Evaluator.scala index 9843da81d4187e468cf418c6100f2ce68f2dfefd..ff0f35f1241547d66f81f0b341fca508276b40ea 100644 --- a/src/main/scala/leon/evaluators/Evaluator.scala +++ b/src/main/scala/leon/evaluators/Evaluator.scala @@ -18,6 +18,7 @@ abstract class Evaluator(val context: LeonContext, val program: Program) extends type Value type EvaluationResult = EvaluationResults.Result[Value] + type CheckResult = EvaluationResults.CheckResult /** Evaluates an expression, using [[Model.mapping]] as a valuation function for the free variables. */ def eval(expr: Expr, model: Model) : EvaluationResult @@ -30,6 +31,9 @@ abstract class Evaluator(val context: LeonContext, val program: Program) extends /** Evaluates a ground expression. */ final def eval(expr: Expr) : EvaluationResult = eval(expr, Model.empty) + /** Checks that `model |= expr` and that quantifications are all valid */ + def check(expr: Expr, model: Model) : CheckResult + /** Compiles an expression into a function, where the arguments are the free variables in the expression. * `argorder` specifies in which order the arguments should be passed. * The default implementation uses the evaluation function each time, but evaluators are free @@ -50,4 +54,4 @@ trait DeterministicEvaluator extends Evaluator { trait NDEvaluator extends Evaluator { type Value = Stream[Expr] -} \ No newline at end of file +} diff --git a/src/main/scala/leon/evaluators/EvaluatorContexts.scala b/src/main/scala/leon/evaluators/EvaluatorContexts.scala index 776e389dd855736132d8c066fc837cc2b4bc34ad..a63ee6483bfcdb7804cd95bfcb32df83a66e235b 100644 --- a/src/main/scala/leon/evaluators/EvaluatorContexts.scala +++ b/src/main/scala/leon/evaluators/EvaluatorContexts.scala @@ -4,9 +4,11 @@ package leon package evaluators import purescala.Common.Identifier -import purescala.Expressions.Expr +import leon.purescala.Expressions.{Lambda, Expr} import solvers.Model +import scala.collection.mutable.{Map => MutableMap} + trait RecContext[RC <: RecContext[RC]] { def mappings: Map[Identifier, Expr] @@ -25,15 +27,18 @@ case class DefaultRecContext(mappings: Map[Identifier, Expr]) extends RecContext def newVars(news: Map[Identifier, Expr]) = copy(news) } -class GlobalContext(val model: Model, val maxSteps: Int) { +class GlobalContext(val model: Model, val maxSteps: Int, val check: Boolean) { var stepsLeft = maxSteps -} - -protected[evaluators] trait DefaultContexts extends ContextualEvaluator { - final type RC = DefaultRecContext - final type GC = GlobalContext + val lambdas: MutableMap[Lambda, Lambda] = MutableMap.empty +} +trait HasDefaultRecContext extends ContextualEvaluator { + type RC = DefaultRecContext def initRC(mappings: Map[Identifier, Expr]) = DefaultRecContext(mappings) - def initGC(model: solvers.Model) = new GlobalContext(model, this.maxSteps) +} + +trait HasDefaultGlobalContext extends ContextualEvaluator { + def initGC(model: solvers.Model, check: Boolean) = new GlobalContext(model, this.maxSteps, check) + type GC = GlobalContext } \ No newline at end of file diff --git a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala index b6d1e4c8e4706dbd298bdf38f1af07ad30d96165..2bc4a2bbc337b0e6952090627ccfb0257fb1d947 100644 --- a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala +++ b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala @@ -3,18 +3,19 @@ package leon package evaluators -import leon.purescala.Constructors._ -import leon.purescala.ExprOps._ -import leon.purescala.Expressions.Pattern -import leon.purescala.Extractors._ import leon.purescala.Quantification._ -import leon.purescala.TypeOps._ -import leon.purescala.Types._ -import leon.solvers.{SolverFactory, HenkinModel} +import purescala.Constructors._ +import purescala.ExprOps._ +import purescala.Expressions.Pattern +import purescala.Extractors._ +import purescala.TypeOps._ +import purescala.Types._ import purescala.Common._ import purescala.Expressions._ import purescala.Definitions._ -import leon.utils.DebugSectionSynthesis +import leon.solvers.{HenkinModel, Model, SolverFactory} + +import scala.collection.mutable.{Map => MutableMap} abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int) extends ContextualEvaluator(ctx, prog, maxSteps) @@ -45,11 +46,12 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int val newArgs = args.map(e) val mapping = l.paramSubst(newArgs) e(body)(rctx.withNewVars(mapping), gctx) - case PartialLambda(mapping, _) => + case PartialLambda(mapping, dflt, _) => mapping.find { case (pargs, res) => (args zip pargs).forall(p => e(Equals(p._1, p._2)) == BooleanLiteral(true)) - }.map(_._2).getOrElse { - throw EvalError("Cannot apply partial lambda outside of domain") + }.map(_._2).orElse(dflt).getOrElse { + throw EvalError("Cannot apply partial lambda outside of domain : " + + args.map(e(_).asString(ctx)).mkString("(", ", ", ")")) } case f => throw EvalError("Cannot apply non-lambda function " + f.asString) @@ -183,7 +185,7 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int (lv,rv) match { case (FiniteSet(el1, _),FiniteSet(el2, _)) => BooleanLiteral(el1 == el2) case (FiniteMap(el1, _, _),FiniteMap(el2, _, _)) => BooleanLiteral(el1.toSet == el2.toSet) - case (PartialLambda(m1, _), PartialLambda(m2, _)) => BooleanLiteral(m1.toSet == m2.toSet) + case (PartialLambda(m1, d1, _), PartialLambda(m2, d2, _)) => BooleanLiteral(m1.toSet == m2.toSet && d1 == d2) case _ => BooleanLiteral(lv == rv) } @@ -489,7 +491,7 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int case SetCardinality(s) => val sr = e(s) sr match { - case FiniteSet(els, _) => IntLiteral(els.size) + case FiniteSet(els, _) => InfiniteIntegerLiteral(els.size) case _ => throw EvalError(typeErrorMsg(sr, SetType(Untyped))) } @@ -497,20 +499,19 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int FiniteSet(els.map(e), base) case l @ Lambda(_, _) => - val (nl, structSubst) = normalizeStructure(l) + val (nl, structSubst) = normalizeStructure(matchToIfThenElse(l)) val mapping = variablesOf(l).map(id => structSubst(id) -> e(Variable(id))).toMap - replaceFromIDs(mapping, nl) + val newLambda = replaceFromIDs(mapping, nl).asInstanceOf[Lambda] + if (!gctx.lambdas.isDefinedAt(newLambda)) { + gctx.lambdas += (newLambda -> nl.asInstanceOf[Lambda]) + } + newLambda - case PartialLambda(mapping, tpe) => - PartialLambda(mapping.map(p => p._1.map(e) -> e(p._2)), tpe) + case PartialLambda(mapping, dflt, tpe) => + PartialLambda(mapping.map(p => p._1.map(e) -> e(p._2)), dflt.map(e), tpe) - case f @ Forall(fargs, TopLevelAnds(conjuncts)) => - e(andJoin(for (conj <- conjuncts) yield { - val instantiations = forallInstantiations(gctx, fargs, conj) - e(andJoin(instantiations.map { case (enabler, mapping) => - e(Implies(enabler, conj))(rctx.withNewVars(mapping), gctx) - })) - })) + case Forall(fargs, body) => + evalForall(fargs.map(_.id).toSet, body) case ArrayLength(a) => val FiniteArray(_, _, IntLiteral(length)) = e(a) @@ -716,6 +717,140 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int } + protected def evalForall(quants: Set[Identifier], body: Expr, check: Boolean = true)(implicit rctx: RC, gctx: GC): Expr = { + val henkinModel: HenkinModel = gctx.model match { + case hm: HenkinModel => hm + case _ => throw EvalError("Can't evaluate foralls without henkin model") +} + + val TopLevelAnds(conjuncts) = body + e(andJoin(conjuncts.flatMap { conj => + val vars = variablesOf(conj) + val quantified = quants.filter(vars) + + extractQuorums(conj, quantified).flatMap { case (qrm, others) => + val quorum = qrm.toList + + if (quorum.exists { case (TopLevelAnds(paths), _, _) => + val p = andJoin(paths.filter(path => (variablesOf(path) & quantified).isEmpty)) + e(p) == BooleanLiteral(false) + }) List(BooleanLiteral(true)) else { + + var mappings: Seq[(Identifier, Int, Int)] = Seq.empty + var constraints: Seq[(Expr, Int, Int)] = Seq.empty + var equalities: Seq[((Int, Int), (Int, Int))] = Seq.empty + + for (((_, expr, args), qidx) <- quorum.zipWithIndex) { + val (qmappings, qconstraints) = args.zipWithIndex.partition { + case (Variable(id),aidx) => quantified(id) + case _ => false + } + + mappings ++= qmappings.map(p => (p._1.asInstanceOf[Variable].id, qidx, p._2)) + constraints ++= qconstraints.map(p => (p._1, qidx, p._2)) + } + + val mapping = for ((id, es) <- mappings.groupBy(_._1)) yield { + val base :: others = es.toList.map(p => (p._2, p._3)) + equalities ++= others.map(p => base -> p) + (id -> base) + } + + def domain(expr: Expr): Set[Seq[Expr]] = henkinModel.domain(e(expr) match { + case l: Lambda => gctx.lambdas.getOrElse(l, l) + case ev => ev + }) + + val argSets = quorum.foldLeft[List[Seq[Seq[Expr]]]](List(Seq.empty)) { + case (acc, (_, expr, _)) => acc.flatMap(s => domain(expr).map(d => s :+ d)) + } + + argSets.map { args => + val argMap: Map[(Int, Int), Expr] = args.zipWithIndex.flatMap { + case (a, qidx) => a.zipWithIndex.map { case (e, aidx) => (qidx, aidx) -> e } + }.toMap + + val map = mapping.map { case (id, key) => id -> argMap(key) } + val enabler = andJoin(constraints.map { + case (e, qidx, aidx) => Equals(e, argMap(qidx -> aidx)) + } ++ equalities.map { + case (k1, k2) => Equals(argMap(k1), argMap(k2)) + }) + + val ctx = rctx.withNewVars(map) + if (e(enabler)(ctx, gctx) == BooleanLiteral(true)) { + if (gctx.check) { + for ((b,caller,args) <- others if e(b)(ctx, gctx) == BooleanLiteral(true)) { + val evArgs = args.map(arg => e(arg)(ctx, gctx)) + if (!domain(caller)(evArgs)) + throw QuantificationError("Unhandled transitive implication in " + replaceFromIDs(map, conj)) + } + } + + e(conj)(ctx, gctx) + } else { + BooleanLiteral(true) + } + } + } + } + })) match { + case res @ BooleanLiteral(true) if check => + if (gctx.check) { + checkForall(quants, body) match { + case status: ForallInvalid => + throw QuantificationError("Invalid forall: " + status.getMessage) + case _ => + // make sure the body doesn't contain matches or lets as these introduce new locals + val cleanBody = expandLets(matchToIfThenElse(body)) + val calls = new CollectorWithPaths[(Expr, Seq[Expr], Seq[Expr])] { + def collect(e: Expr, path: Seq[Expr]): Option[(Expr, Seq[Expr], Seq[Expr])] = e match { + case QuantificationMatcher(IsTyped(caller, _: FunctionType), args) => Some((caller, args, path)) + case _ => None + } + + override def rec(e: Expr, path: Seq[Expr]): Expr = e match { + case l : Lambda => l + case _ => super.rec(e, path) + } + }.traverse(cleanBody) + + for ((caller, appArgs, paths) <- calls) { + val path = andJoin(paths.filter(expr => (variablesOf(expr) & quants).isEmpty)) + if (e(path) == BooleanLiteral(true)) e(caller) match { + case _: PartialLambda => // OK + case l: Lambda => + val nl @ Lambda(args, body) = gctx.lambdas.getOrElse(l, l) + val lambdaQuantified = (appArgs zip args).collect { + case (Variable(id), vd) if quants(id) => vd.id + }.toSet + + if (lambdaQuantified.nonEmpty) { + checkForall(lambdaQuantified, body) match { + case lambdaStatus: ForallInvalid => + throw QuantificationError("Invalid forall: " + lambdaStatus.getMessage) + case _ => // do nothing + } + + val axiom = Equals(Application(nl, args.map(_.toVariable)), nl.body) + if (evalForall(args.map(_.id).toSet, axiom, check = false) == BooleanLiteral(false)) { + throw QuantificationError("Unaxiomatic lambda " + l) + } + } + case f => + throw EvalError("Cannot apply non-lambda function " + f.asString) + } + } + } + } + + res + + // `res == false` means the quantification is valid since there effectivelly must + // exist an input for which the proposition doesn't hold + case res => res + } + } } diff --git a/src/main/scala/leon/evaluators/StreamEvaluator.scala b/src/main/scala/leon/evaluators/StreamEvaluator.scala index 72a4dd243ea52829fd9136a3923427bfd4f8a5d2..f9ef63bdd9ad8650d44877ad03ff109fbd387200 100644 --- a/src/main/scala/leon/evaluators/StreamEvaluator.scala +++ b/src/main/scala/leon/evaluators/StreamEvaluator.scala @@ -18,16 +18,13 @@ import leon.utils.StreamUtils._ class StreamEvaluator(ctx: LeonContext, prog: Program) extends ContextualEvaluator(ctx, prog, 50000) with NDEvaluator - with DefaultContexts + with HasDefaultGlobalContext + with HasDefaultRecContext { val name = "ND-evaluator" val description = "Non-deterministic interpreter for Leon programs that returns a Stream of solutions" - case class NDValue(tp: TypeTree) extends Expr with Terminal { - val getType = tp - } - protected[evaluators] def e(expr: Expr)(implicit rctx: RC, gctx: GC): Stream[Expr] = expr match { case Variable(id) => rctx.mappings.get(id).toStream @@ -40,7 +37,8 @@ class StreamEvaluator(ctx: LeonContext, prog: Program) case l @ Lambda(params, body) => val mapping = l.paramSubst(newArgs) e(body)(rctx.withNewVars(mapping), gctx).distinct - case PartialLambda(mapping, _) => + case PartialLambda(mapping, _, _) => + // FIXME mapping.collectFirst { case (pargs, res) if (newArgs zip pargs).forall { case (f, r) => f == r } => res @@ -74,12 +72,6 @@ class StreamEvaluator(ctx: LeonContext, prog: Program) case Error(tpe, desc) => Stream() - case NDValue(tp) => - // FIXME: This is the only source of infinite values, and will in a way break - // the evaluator: the evaluator is not designed to fairly handle infinite streams. - // Of course currently it is only used for boolean type, which is finite :) - valuesOf(tp) - case IfExpr(cond, thenn, elze) => e(cond).distinct.flatMap { case BooleanLiteral(true) => e(thenn) @@ -143,7 +135,8 @@ class StreamEvaluator(ctx: LeonContext, prog: Program) ).toMap Stream(replaceFromIDs(mapping, nl)) - case PartialLambda(mapping, tpe) => + // FIXME + case PartialLambda(mapping, tpe, df) => def solveOne(pair: (Seq[Expr], Expr)) = { val (args, res) = pair for { @@ -151,11 +144,11 @@ class StreamEvaluator(ctx: LeonContext, prog: Program) r <- e(res) } yield as -> r } - cartesianProduct(mapping map solveOne) map (PartialLambda(_, tpe)) + cartesianProduct(mapping map solveOne) map (PartialLambda(_, tpe, df)) // FIXME!!! case f @ Forall(fargs, TopLevelAnds(conjuncts)) => - - def solveOne(conj: Expr) = { + Stream() // FIXME + /*def solveOne(conj: Expr) = { val instantiations = forallInstantiations(gctx, fargs, conj) for { es <- cartesianProduct(instantiations.map { case (enabler, mapping) => @@ -168,7 +161,7 @@ class StreamEvaluator(ctx: LeonContext, prog: Program) for { conj <- cartesianProduct(conjuncts map solveOne) res <- e(andJoin(conj)) - } yield res + } yield res*/ case p : Passes => e(p.asConstraint) @@ -228,7 +221,8 @@ class StreamEvaluator(ctx: LeonContext, prog: Program) solverf.shutdown() } sol - }).takeWhile(_.isDefined).map(_.get) + }).takeWhile(_.isDefined).take(10).map(_.get) + // This take(10) is there because we are not working well with infinite streams yet... } catch { case e: Throwable => solverf.reclaim(solver) @@ -352,7 +346,7 @@ class StreamEvaluator(ctx: LeonContext, prog: Program) (lv, rv) match { case (FiniteSet(el1, _), FiniteSet(el2, _)) => BooleanLiteral(el1 == el2) case (FiniteMap(el1, _, _), FiniteMap(el2, _, _)) => BooleanLiteral(el1.toSet == el2.toSet) - case (PartialLambda(m1, _), PartialLambda(m2, _)) => BooleanLiteral(m1.toSet == m2.toSet) + case (PartialLambda(m1, _, d1), PartialLambda(m2, _, d2)) => BooleanLiteral(m1.toSet == m2.toSet && d1 == d2) case _ => BooleanLiteral(lv == rv) } @@ -588,6 +582,4 @@ class StreamEvaluator(ctx: LeonContext, prog: Program) } - } - diff --git a/src/main/scala/leon/evaluators/StringTracingEvaluator.scala b/src/main/scala/leon/evaluators/StringTracingEvaluator.scala index 84da8f21d7957aaba69ae86811bc8c8ec63a4ae3..43ff4bc2459b9a125d60d2ce16992e90b491505b 100644 --- a/src/main/scala/leon/evaluators/StringTracingEvaluator.scala +++ b/src/main/scala/leon/evaluators/StringTracingEvaluator.scala @@ -10,7 +10,7 @@ import purescala.Definitions.Program import purescala.Expressions.Expr import leon.utils.DebugSectionSynthesis -class StringTracingEvaluator(ctx: LeonContext, prog: Program) extends ContextualEvaluator(ctx, prog, 50000) with DefaultContexts { +class StringTracingEvaluator(ctx: LeonContext, prog: Program) extends ContextualEvaluator(ctx, prog, 50000) with HasDefaultGlobalContext with HasDefaultRecContext { val underlying = new DefaultEvaluator(ctx, prog) { override protected[evaluators] def e(expr: Expr)(implicit rctx: RC, gctx: GC): Expr = expr match { diff --git a/src/main/scala/leon/evaluators/TracingEvaluator.scala b/src/main/scala/leon/evaluators/TracingEvaluator.scala index ad2d5723e7f9744e2e8e41e227055dc91fec9b34..4c0b1f39c9126e4ccf0da6db389394fe0c33d294 100644 --- a/src/main/scala/leon/evaluators/TracingEvaluator.scala +++ b/src/main/scala/leon/evaluators/TracingEvaluator.scala @@ -14,9 +14,10 @@ class TracingEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int = 1000) ex def initRC(mappings: Map[Identifier, Expr]) = TracingRecContext(mappings, 2) - def initGC(model: solvers.Model) = new TracingGlobalContext(Nil, model) + def initGC(model: solvers.Model, check: Boolean) = new TracingGlobalContext(Nil, model, check) - class TracingGlobalContext(var values: List[(Tree, Expr)], model: solvers.Model) extends GlobalContext(model, maxSteps) + class TracingGlobalContext(var values: List[(Tree, Expr)], model: solvers.Model, check: Boolean) + extends GlobalContext(model, this.maxSteps, check) case class TracingRecContext(mappings: Map[Identifier, Expr], tracingFrames: Int) extends RecContext[TracingRecContext] { def newVars(news: Map[Identifier, Expr]) = copy(mappings = news) diff --git a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala b/src/main/scala/leon/frontends/scalac/ASTExtractors.scala index 855efa3360898deb67380c6ea5f3038de6e2dbad..68a4f4bf3d93f3b94c1b71a10d55ab6121a1792d 100644 --- a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala +++ b/src/main/scala/leon/frontends/scalac/ASTExtractors.scala @@ -593,6 +593,15 @@ trait ASTExtractors { } } + object ExOldExpression { + def unapply(tree: Apply) : Option[Symbol] = tree match { + case a @ Apply(TypeApply(ExSymbol("leon", "lang", "old"), List(tpe)), List(arg)) => + Some(arg.symbol) + case _ => + None + } + } + object ExHoleExpression { def unapply(tree: Tree) : Option[(Tree, List[Tree])] = tree match { case a @ Apply(TypeApply(s @ ExSymbol("leon", "lang", "synthesis", "$qmark"), List(tpt)), args1) => diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala index 1903a0c1f7aab62aee1524078417a5adc1024e2a..f3995ee635cbdd80aefb5f48503908be1f1d60d6 100644 --- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala +++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala @@ -1046,6 +1046,15 @@ trait CodeExtraction extends ASTExtractors { val tupleExprs = exprs.map(e => extractTree(e)) Tuple(tupleExprs) + case ex@ExOldExpression(sym) if dctx.isVariable(sym) => + dctx.vars.get(sym).orElse(dctx.mutableVars.get(sym)) match { + case Some(builder) => + val Variable(id) = builder() + Old(id).setPos(ex.pos) + case None => + outOfSubsetError(current, "old can only be used with variables") + } + case ExErrorExpression(str, tpt) => Error(extractType(tpt), str) @@ -1097,7 +1106,7 @@ trait CodeExtraction extends ASTExtractors { val oldCurrentFunDef = currentFunDef - val funDefWithBody = extractFunBody(fd, params, b)(newDctx.copy(mutableVars = Map())) + val funDefWithBody = extractFunBody(fd, params, b)(newDctx) currentFunDef = oldCurrentFunDef @@ -1192,11 +1201,11 @@ trait CodeExtraction extends ASTExtractors { } getOwner(lhsRec) match { - case Some(Some(fd)) if fd != currentFunDef => - outOfSubsetError(tr, "cannot update an array that is not defined locally") + // case Some(Some(fd)) if fd != currentFunDef => + // outOfSubsetError(tr, "cannot update an array that is not defined locally") - case Some(None) => - outOfSubsetError(tr, "cannot update an array that is not defined locally") + // case Some(None) => + // outOfSubsetError(tr, "cannot update an array that is not defined locally") case Some(_) => @@ -1622,6 +1631,9 @@ trait CodeExtraction extends ASTExtractors { or(a1, a2) // Set methods + case (IsTyped(a1, SetType(b1)), "size", Nil) => + SetCardinality(a1) + //case (IsTyped(a1, SetType(b1)), "min", Nil) => // SetMin(a1) @@ -1848,7 +1860,11 @@ trait CodeExtraction extends ASTExtractors { case AnnotatedType(_, tpe) => extractType(tpe) case _ => - outOfSubsetError(tpt.typeSymbol.pos, "Could not extract type as PureScala: "+tpt+" ("+tpt.getClass+")") + if (tpt ne null) { + outOfSubsetError(tpt.typeSymbol.pos, "Could not extract type as PureScala: "+tpt+" ("+tpt.getClass+")") + } else { + outOfSubsetError(NoPosition, "Tree with null-pointer as type found") + } } private def getClassType(sym: Symbol, tps: List[LeonType])(implicit dctx: DefContext) = { diff --git a/src/main/scala/leon/invariant/structure/FunctionUtils.scala b/src/main/scala/leon/invariant/structure/FunctionUtils.scala index a0acdcda7388ee4b6971106cf78cdb49c42768a4..565a6f9f41cbc140c7b6814606897f7cc7eb11f1 100644 --- a/src/main/scala/leon/invariant/structure/FunctionUtils.scala +++ b/src/main/scala/leon/invariant/structure/FunctionUtils.scala @@ -5,13 +5,11 @@ import purescala.Common._ import purescala.Definitions._ import purescala.Expressions._ import purescala.ExprOps._ -import purescala.Extractors._ import purescala.Types._ import invariant.factories._ import invariant.util._ import Util._ import PredicateUtil._ -import ProgramUtil._ import scala.language.implicitConversions /** @@ -37,8 +35,8 @@ object FunctionUtils { def isTemplateInvocation(finv: Expr) = { finv match { case FunctionInvocation(funInv, args) => - (funInv.id.name == "tmpl" && funInv.returnType == BooleanType && - args.size == 1 && args(0).isInstanceOf[Lambda]) + funInv.id.name == "tmpl" && funInv.returnType == BooleanType && + args.size == 1 && args(0).isInstanceOf[Lambda] case _ => false } @@ -46,8 +44,7 @@ object FunctionUtils { def isQMark(e: Expr) = e match { case FunctionInvocation(TypedFunDef(fd, Seq()), args) => - (fd.id.name == "?" && fd.returnType == IntegerType && - args.size <= 1) + fd.id.name == "?" && fd.returnType == IntegerType && args.size <= 1 case _ => false } @@ -104,7 +101,7 @@ object FunctionUtils { val Lambda(_, postBody) = fd.postcondition.get // collect all terms with question marks and convert them to a template val postWoQmarks = postBody match { - case And(args) if args.exists(exists(isQMark) _) => + case And(args) if args.exists(exists(isQMark)) => val (tempExprs, otherPreds) = args.partition { case a if exists(isQMark)(a) => true case _ => false diff --git a/src/main/scala/leon/invariant/util/CallGraph.scala b/src/main/scala/leon/invariant/util/CallGraph.scala index 23f87d5bc0b62163ef40e852d7cc0af0cd658bc3..5e93c451215e9e1f6679e7e94da47a4a2d707a3c 100644 --- a/src/main/scala/leon/invariant/util/CallGraph.scala +++ b/src/main/scala/leon/invariant/util/CallGraph.scala @@ -1,15 +1,10 @@ package leon package invariant.util -import purescala._ -import purescala.Common._ import purescala.Definitions._ import purescala.Expressions._ import purescala.ExprOps._ -import purescala.Extractors._ -import purescala.Types._ import ProgramUtil._ -import Util._ import invariant.structure.FunctionUtils._ import invariant.datastructure._ @@ -71,7 +66,7 @@ class CallGraph { graph.getNodes.toList.foreach((f) => { var inserted = false var index = 0 - for (i <- 0 to funcList.length - 1) { + for (i <- funcList.indices) { if (!inserted && this.transitivelyCalls(funcList(i), f)) { index = i inserted = true @@ -97,7 +92,6 @@ class CallGraph { object CallGraphUtil { def constructCallGraph(prog: Program, onlyBody: Boolean = false, withTemplates: Boolean = false): CallGraph = { -// // println("Constructing call graph") val cg = new CallGraph() functionsWOFields(prog.definedFunctions).foreach((fd) => { @@ -125,17 +119,11 @@ object CallGraphUtil { cg } - def getCallees(expr: Expr): Set[FunDef] = { - var callees = Set[FunDef]() - simplePostTransform((expr) => expr match { - //note: do not consider field invocations - case FunctionInvocation(TypedFunDef(callee, _), args) - if callee.isRealFunction => { - callees += callee - expr - } - case _ => expr - })(expr) - callees - } + def getCallees(expr: Expr): Set[FunDef] = collect { + case expr@FunctionInvocation(TypedFunDef(callee, _), _) if callee.isRealFunction => + Set(callee) + case _ => + Set[FunDef]() + }(expr) + } \ No newline at end of file diff --git a/src/main/scala/leon/purescala/Constructors.scala b/src/main/scala/leon/purescala/Constructors.scala index 728b602f9fa343ae8657a2922ca03a3dfe1ada3f..7625eddf86a1656c96b46a9ad7fb3fd2d894dae2 100644 --- a/src/main/scala/leon/purescala/Constructors.scala +++ b/src/main/scala/leon/purescala/Constructors.scala @@ -161,7 +161,7 @@ object Constructors { BooleanLiteral(true) } } - /** $encodingof `... match { ... }` but simplified if possible. Throws an error if no case can match the scrutined expression. + /** $encodingof `... match { ... }` but simplified if possible. Simplifies to [[Error]] if no case can match the scrutined expression. * @see [[purescala.Expressions.MatchExpr MatchExpr]] */ def matchExpr(scrutinee : Expr, cases : Seq[MatchCase]) : Expr ={ @@ -249,8 +249,8 @@ object Constructors { /** $encodingof Simplified `Array(...)` (array length defined at compile-time) * @see [[purescala.Expressions.NonemptyArray NonemptyArray]] */ - def finiteArray(els: Seq[Expr]): Expr = { - require(els.nonEmpty) + def finiteArray(els: Seq[Expr], tpe: TypeTree = Untyped): Expr = { + require(els.nonEmpty || tpe != Untyped) finiteArray(els, None, Untyped) // Untyped is not correct, but will not be used anyway } /** $encodingof Simplified `Array[...](...)` (array length and default element defined at run-time) with type information diff --git a/src/main/scala/leon/purescala/ExprOps.scala b/src/main/scala/leon/purescala/ExprOps.scala index c8bebf741251b873ee6535d180985b8100a8e9c7..a7a20e2d477ac478d6757723c0f36d39bbf7fe3e 100644 --- a/src/main/scala/leon/purescala/ExprOps.scala +++ b/src/main/scala/leon/purescala/ExprOps.scala @@ -345,6 +345,7 @@ object ExprOps { val subvs = subs.flatten.toSet e match { case Variable(i) => subvs + i + case Old(i) => subvs + i case LetDef(fd, _) => subvs -- fd.params.map(_.id) case Let(i, _, _) => subvs - i case LetVar(i, _, _) => subvs - i @@ -1139,9 +1140,8 @@ object ExprOps { case tp: TypeParameter => GenericValue(tp, 0) - case FunctionType(from, to) => - val args = from.map(tpe => ValDef(FreshIdentifier("x", tpe, alwaysShowUniqueID = true))) - Lambda(args, simplestValue(to)) + case ft @ FunctionType(from, to) => + PartialLambda(Seq.empty, Some(simplestValue(to)), ft) case _ => throw LeonFatalError("I can't choose simplest value for type " + tpe) } @@ -1994,12 +1994,54 @@ object ExprOps { es foreach rec } - def functionAppsOf(expr: Expr): Set[Application] = { - collect[Application] { - case f: Application => Set(f) - case _ => Set() + object InvocationExtractor { + private def flatInvocation(expr: Expr): Option[(TypedFunDef, Seq[Expr])] = expr match { + case fi @ FunctionInvocation(tfd, args) => Some((tfd, args)) + case Application(caller, args) => flatInvocation(caller) match { + case Some((tfd, prevArgs)) => Some((tfd, prevArgs ++ args)) + case None => None + } + case _ => None + } + + def unapply(expr: Expr): Option[(TypedFunDef, Seq[Expr])] = expr match { + case IsTyped(f: FunctionInvocation, ft: FunctionType) => None + case IsTyped(f: Application, ft: FunctionType) => None + case FunctionInvocation(tfd, args) => Some(tfd -> args) + case f: Application => flatInvocation(f) + case _ => None + } + } + + def firstOrderCallsOf(expr: Expr): Set[(TypedFunDef, Seq[Expr])] = + collect[(TypedFunDef, Seq[Expr])] { + case InvocationExtractor(tfd, args) => Set(tfd -> args) + case _ => Set.empty }(expr) + + object ApplicationExtractor { + private def flatApplication(expr: Expr): Option[(Expr, Seq[Expr])] = expr match { + case Application(fi: FunctionInvocation, _) => None + case Application(caller: Application, args) => flatApplication(caller) match { + case Some((c, prevArgs)) => Some((c, prevArgs ++ args)) + case None => None } + case Application(caller, args) => Some((caller, args)) + case _ => None + } + + def unapply(expr: Expr): Option[(Expr, Seq[Expr])] = expr match { + case IsTyped(f: Application, ft: FunctionType) => None + case f: Application => flatApplication(f) + case _ => None + } + } + + def firstOrderAppsOf(expr: Expr): Set[(Expr, Seq[Expr])] = + collect[(Expr, Seq[Expr])] { + case ApplicationExtractor(caller, args) => Set(caller -> args) + case _ => Set.empty + } (expr) def simplifyHOFunctions(expr: Expr) : Expr = { diff --git a/src/main/scala/leon/purescala/Expressions.scala b/src/main/scala/leon/purescala/Expressions.scala index 8d6e00ab0271bbbdef497a539ccbce40ac293066..40aa96e95b67b45452290e4a1a7fa4809fea2c01 100644 --- a/src/main/scala/leon/purescala/Expressions.scala +++ b/src/main/scala/leon/purescala/Expressions.scala @@ -76,6 +76,10 @@ object Expressions { val getType = tpe } + case class Old(id: Identifier) extends Expr with Terminal { + val getType = id.getType + } + /** Precondition of an [[Expressions.Expr]]. Corresponds to the Leon keyword *require* * * @param pred The precondition formula inside ``require(...)`` @@ -226,7 +230,7 @@ object Expressions { } } - case class PartialLambda(mapping: Seq[(Seq[Expr], Expr)], tpe: FunctionType) extends Expr { + case class PartialLambda(mapping: Seq[(Seq[Expr], Expr)], default: Option[Expr], tpe: FunctionType) extends Expr { val getType = tpe } @@ -806,7 +810,7 @@ object Expressions { } /** $encodingof `set.length` */ case class SetCardinality(set: Expr) extends Expr { - val getType = Int32Type + val getType = IntegerType } /** $encodingof `set.subsetOf(set2)` */ case class SubsetOf(set1: Expr, set2: Expr) extends Expr { diff --git a/src/main/scala/leon/purescala/Extractors.scala b/src/main/scala/leon/purescala/Extractors.scala index a467e3083c9e723570a7cea0007836a8725d500b..4673ea6dab2a33fc94b6a6130527099bd706dd2c 100644 --- a/src/main/scala/leon/purescala/Extractors.scala +++ b/src/main/scala/leon/purescala/Extractors.scala @@ -53,7 +53,7 @@ object Extractors { Some((Seq(a), (es: Seq[Expr]) => ArrayLength(es.head))) case Lambda(args, body) => Some((Seq(body), (es: Seq[Expr]) => Lambda(args, es.head))) - case PartialLambda(mapping, tpe) => + case PartialLambda(mapping, dflt, tpe) => val sze = tpe.from.size + 1 val subArgs = mapping.flatMap { case (args, v) => args :+ v } val builder = (as: Seq[Expr]) => { @@ -64,9 +64,10 @@ object Extractors { case Seq() => Seq.empty case _ => sys.error("unexpected number of key/value expressions") } - PartialLambda(rec(as), tpe) + val (nas, nd) = if (dflt.isDefined) (as.init, Some(as.last)) else (as, None) + PartialLambda(rec(nas), nd, tpe) } - Some((subArgs, builder)) + Some((subArgs ++ dflt, builder)) case Forall(args, body) => Some((Seq(body), (es: Seq[Expr]) => Forall(args, es.head))) @@ -196,9 +197,14 @@ object Extractors { val l = as.length nonemptyArray(as.take(l - 2), Some((as(l - 2), as(l - 1)))) })) - case NonemptyArray(elems, None) => - val all = elems.values.toSeq - Some((all, finiteArray)) + case na@NonemptyArray(elems, None) => + val ArrayType(tpe) = na.getType + val (indexes, elsOrdered) = elems.toSeq.unzip + + Some(( + elsOrdered, + es => finiteArray(indexes.zip(es).toMap, None, tpe) + )) case Tuple(args) => Some((args, tupleWrap)) case IfExpr(cond, thenn, elze) => Some(( Seq(cond, thenn, elze), @@ -358,22 +364,20 @@ object Extractors { } def unapply(me : MatchExpr) : Option[(Pattern, Expr, Expr)] = { - if (me eq null) None else { me match { + Option(me) collect { case MatchExpr(scrut, List(SimpleCase(pattern, body))) if !aliased(pattern.binders, variablesOf(scrut)) => - Some(( pattern, scrut, body )) - case _ => None - }} + ( pattern, scrut, body ) + } } } object LetTuple { def unapply(me : MatchExpr) : Option[(Seq[Identifier], Expr, Expr)] = { - if (me eq null) None else { me match { + Option(me) collect { case LetPattern(TuplePattern(None,subPatts), value, body) if subPatts forall { case WildcardPattern(Some(_)) => true; case _ => false } => - Some((subPatts map { _.binder.get }, value, body )) - case _ => None - }} + (subPatts map { _.binder.get }, value, body ) + } } } diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala index 33c200a5ac202c9e18250de4242604d97cff1af7..307d83a4bbf0d8b36bb4aba3c203e860e3af75be 100644 --- a/src/main/scala/leon/purescala/PrettyPrinter.scala +++ b/src/main/scala/leon/purescala/PrettyPrinter.scala @@ -79,6 +79,9 @@ class PrettyPrinter(opts: PrinterOptions, } p"$name" + case Old(id) => + p"old($id)" + case Variable(id) => p"$id" @@ -265,6 +268,22 @@ class PrettyPrinter(opts: PrinterOptions, case Lambda(args, body) => optP { p"($args) => $body" } + case PartialLambda(mapping, dflt, _) => + optP { + def pm(p: (Seq[Expr], Expr)): PrinterHelpers.Printable = + (pctx: PrinterContext) => p"${purescala.Constructors.tupleWrap(p._1)} => ${p._2}"(pctx) + + if (mapping.isEmpty) { + p"{}" + } else { + p"{ ${nary(mapping map pm)} }" + } + + if (dflt.isDefined) { + p" getOrElse ${dflt.get}" + } + } + case Plus(l,r) => optP { p"$l + $r" } case Minus(l,r) => optP { p"$l - $r" } case Times(l,r) => optP { p"$l * $r" } diff --git a/src/main/scala/leon/purescala/Quantification.scala b/src/main/scala/leon/purescala/Quantification.scala index 1b00ed1b41a5053fb07a94695b00f595d54453ba..bb5115baab042a1bd524fdb2f73deeabefa59243 100644 --- a/src/main/scala/leon/purescala/Quantification.scala +++ b/src/main/scala/leon/purescala/Quantification.scala @@ -6,10 +6,13 @@ package purescala import Common._ import Definitions._ import Expressions._ +import Constructors._ import Extractors._ import ExprOps._ import Types._ +import evaluators._ + object Quantification { def extractQuorums[A,B]( @@ -18,6 +21,12 @@ object Quantification { margs: A => Set[A], qargs: A => Set[B] ): Seq[Set[A]] = { + def expand(m: A): Set[A] = Set(m) ++ margs(m).flatMap(expand) + val expandedMap: Map[A, Set[A]] = matchers.map(m => m -> expand(m)).toMap + val reverseMap : Map[A, Set[A]] = expandedMap.toSeq + .flatMap(p => p._2.map(m => m -> p._1)) // flatten to reversed pairs + .groupBy(_._1).mapValues(_.map(_._2).toSet) // rebuild map from pair set + def rec(oms: Seq[A], mSet: Set[A], qss: Seq[Set[B]]): Seq[Set[A]] = { if (qss.contains(quantified)) { Seq(mSet) @@ -34,12 +43,13 @@ object Quantification { } } - def expand(m: A): Set[A] = Set(m) ++ margs(m).flatMap(expand) - val oms = matchers.toSeq.sortBy(m => -expand(m).size) - rec(oms, Set.empty, Seq.empty) + val oms = expandedMap.toSeq.sortBy(p => -p._2.size).map(_._1) + val res = rec(oms, Set.empty, Seq.empty) + + res.filter(ms => ms.forall(m => reverseMap(m) subsetOf ms)) } - def extractQuorums(expr: Expr, quantified: Set[Identifier]): Seq[Set[(Expr, Seq[Expr])]] = { + def extractQuorums(expr: Expr, quantified: Set[Identifier]): Seq[(Set[(Expr, Expr, Seq[Expr])], Set[(Expr, Expr, Seq[Expr])])] = { object QMatcher { def unapply(e: Expr): Option[(Expr, Seq[Expr])] = e match { case QuantificationMatcher(expr, args) => @@ -52,33 +62,73 @@ object Quantification { } } - extractQuorums(collect { - case QMatcher(e, a) => Set(e -> a) - case _ => Set.empty[(Expr, Seq[Expr])] - } (expr), quantified, - (p: (Expr, Seq[Expr])) => p._2.collect { case QMatcher(e, a) => e -> a }.toSet, - (p: (Expr, Seq[Expr])) => p._2.collect { case Variable(id) if quantified(id) => id }.toSet) + val allMatchers = CollectorWithPaths { case QMatcher(expr, args) => expr -> args }.traverse(expr) + val matchers = allMatchers.map { case ((caller, args), path) => (path, caller, args) }.toSet + + val quorums = extractQuorums(matchers, quantified, + (p: (Expr, Expr, Seq[Expr])) => p._3.collect { case QMatcher(e, a) => (p._1, e, a) }.toSet, + (p: (Expr, Expr, Seq[Expr])) => p._3.collect { case Variable(id) if quantified(id) => id }.toSet) + + quorums.map(quorum => quorum -> matchers.filter(m => !quorum(m))) + } + + def extractModel( + asMap: Map[Identifier, Expr], + funDomains: Map[Identifier, Set[Seq[Expr]]], + typeDomains: Map[TypeTree, Set[Seq[Expr]]], + evaluator: DeterministicEvaluator + ): Map[Identifier, Expr] = asMap.map { case (id, expr) => + id -> (funDomains.get(id) match { + case Some(domain) => + PartialLambda(domain.toSeq.map { es => + val optEv = evaluator.eval(Application(expr, es)).result + es -> optEv.getOrElse(scala.sys.error("Unexpectedly failed to evaluate " + Application(expr, es))) + }, None, id.getType.asInstanceOf[FunctionType]) + + case None => postMap { + case p @ PartialLambda(mapping, dflt, tpe) => + Some(PartialLambda(typeDomains.get(tpe) match { + case Some(domain) => domain.toSeq.map { es => + val optEv = evaluator.eval(Application(p, es)).result + es -> optEv.getOrElse(scala.sys.error("Unexpectedly failed to evaluate " + Application(p, es))) + } + case _ => Seq.empty + }, None, tpe)) + case _ => None + } (expr) + }) } object HenkinDomains { - def empty = new HenkinDomains(Map.empty) - def apply(domains: Map[TypeTree, Set[Seq[Expr]]]) = new HenkinDomains(domains) + def empty = new HenkinDomains(Map.empty, Map.empty) } - class HenkinDomains (val domains: Map[TypeTree, Set[Seq[Expr]]]) { - def get(e: Expr): Set[Seq[Expr]] = e match { - case PartialLambda(mapping, _) => mapping.map(_._1).toSet - case _ => domains.get(e.getType) match { - case Some(domain) => domain - case None => scala.sys.error("Undefined Henkin domain for " + e) + class HenkinDomains (val lambdas: Map[Lambda, Set[Seq[Expr]]], val tpes: Map[TypeTree, Set[Seq[Expr]]]) { + def get(e: Expr): Set[Seq[Expr]] = { + val specialized: Set[Seq[Expr]] = e match { + case PartialLambda(_, Some(dflt), _) => scala.sys.error("No domain for non-partial lambdas") + case PartialLambda(mapping, _, _) => mapping.map(_._1).toSet + case l: Lambda => lambdas.getOrElse(l, Set.empty) + case _ => Set.empty } + specialized ++ tpes.getOrElse(e.getType, Set.empty) } } object QuantificationMatcher { + private def flatApplication(expr: Expr): Option[(Expr, Seq[Expr])] = expr match { + case Application(fi: FunctionInvocation, args) => Some((fi, args)) + case Application(caller: Application, args) => flatApplication(caller) match { + case Some((c, prevArgs)) => Some((c, prevArgs ++ args)) + case None => None + } + case Application(caller, args) => Some((caller, args)) + case _ => None + } + def unapply(expr: Expr): Option[(Expr, Seq[Expr])] = expr match { - case Application(_: Application | _: FunctionInvocation, _) => None - case Application(e, args) => Some(e -> args) + case IsTyped(a: Application, ft: FunctionType) => None + case Application(e, args) => flatApplication(expr) case ArraySelect(arr, index) => Some(arr -> Seq(index)) case MapApply(map, key) => Some(map -> Seq(key)) case ElementOfSet(elem, set) => Some(set -> Seq(elem)) @@ -87,8 +137,15 @@ object Quantification { } object QuantificationTypeMatcher { + private def flatType(tpe: TypeTree): (Seq[TypeTree], TypeTree) = tpe match { + case FunctionType(from, to) => + val (nextArgs, finalTo) = flatType(to) + (from ++ nextArgs, finalTo) + case _ => (Seq.empty, tpe) + } + def unapply(tpe: TypeTree): Option[(Seq[TypeTree], TypeTree)] = tpe match { - case FunctionType(from, to) => Some(from -> to) + case FunctionType(from, to) => Some(flatType(tpe)) case ArrayType(base) => Some(Seq(Int32Type) -> base) case MapType(from, to) => Some(Seq(from) -> to) case SetType(base) => Some(Seq(base) -> BooleanType) @@ -96,87 +153,84 @@ object Quantification { } } - object CheckForalls extends UnitPhase[Program] { - - val name = "Foralls" - val description = "Check syntax of foralls to guarantee sound instantiations" - - def apply(ctx: LeonContext, program: Program) = { - program.definedFunctions.foreach { fd => - val foralls = collect[Forall] { - case f: Forall => Set(f) - case _ => Set.empty - } (fd.fullBody) - - val free = fd.paramIds.toSet ++ (fd.postcondition match { - case Some(Lambda(args, _)) => args.map(_.id) - case _ => Seq.empty + sealed abstract class ForallStatus { + def isValid: Boolean + } + + case object ForallValid extends ForallStatus { + def isValid = true + } + + sealed abstract class ForallInvalid(msg: String) extends ForallStatus { + def isValid = false + def getMessage: String = msg + } + + case class NoMatchers(expr: String) extends ForallInvalid("No matchers available for E-Matching in " + expr) + case class ComplexArgument(expr: String) extends ForallInvalid("Unhandled E-Matching pattern in " + expr) + case class NonBijectiveMapping(expr: String) extends ForallInvalid("Non-bijective mapping for quantifiers in " + expr) + case class InvalidOperation(expr: String) extends ForallInvalid("Invalid operation on quantifiers in " + expr) + + def checkForall(quantified: Set[Identifier], body: Expr)(implicit ctx: LeonContext): ForallStatus = { + val TopLevelAnds(conjuncts) = body + for (conjunct <- conjuncts) { + val matchers = collect[(Expr, Seq[Expr])] { + case QuantificationMatcher(e, args) => Set(e -> args) + case _ => Set.empty + } (conjunct) + + if (matchers.isEmpty) return NoMatchers(conjunct.asString) + + val complexArgs = matchers.flatMap { case (_, args) => + args.flatMap(arg => arg match { + case QuantificationMatcher(_, _) => None + case Variable(id) => None + case _ if (variablesOf(arg) & quantified).nonEmpty => Some(arg) + case _ => None }) + } - for (Forall(args, TopLevelAnds(conjuncts)) <- foralls) { - val quantified = args.map(_.id).toSet - - for (conjunct <- conjuncts) { - val matchers = collect[(Expr, Seq[Expr])] { - case QuantificationMatcher(e, args) => Set(e -> args) - case _ => Set.empty - } (conjunct) - - if (matchers.isEmpty) - ctx.reporter.warning("E-matching isn't possible without matchers!") - - if (matchers.exists { case (_, args) => - args.exists{ - case QuantificationMatcher(_, _) => false - case Variable(id) => false - case arg => (variablesOf(arg) & quantified).nonEmpty - } - }) ctx.reporter.warning("Matcher arguments must have simple form in " + conjunct) - - val freeMatchers = matchers.collect { case (Variable(id), args) if free(id) => id -> args } - - val id2Quant = freeMatchers.foldLeft(Map.empty[Identifier, Set[Identifier]]) { - case (acc, (m, args)) => acc + (m -> (acc.getOrElse(m, Set.empty) ++ args.flatMap { - case Variable(id) if quantified(id) => Set(id) - case _ => Set.empty[Identifier] - })) - } + if (complexArgs.nonEmpty) return ComplexArgument(complexArgs.head.asString) - if (id2Quant.filter(_._2.nonEmpty).groupBy(_._2).nonEmpty) - ctx.reporter.warning("Multiple matchers must provide bijective matching in " + conjunct) - - fold[Set[Identifier]] { case (m, children) => - val q = children.toSet.flatten - - m match { - case QuantificationMatcher(_, args) => - q -- args.flatMap { - case Variable(id) if quantified(id) => Set(id) - case _ => Set.empty[Identifier] - } - case LessThan(_: Variable, _: Variable) => q - case LessEquals(_: Variable, _: Variable) => q - case GreaterThan(_: Variable, _: Variable) => q - case GreaterEquals(_: Variable, _: Variable) => q - case And(_) => q - case Or(_) => q - case Implies(_, _) => q - case Operator(es, _) => - val vars = es.flatMap { - case Variable(id) => Set(id) - case _ => Set.empty[Identifier] - }.toSet - - if (!(q.isEmpty || (q.size == 1 && (vars & free).isEmpty))) - ctx.reporter.warning("Invalid operation " + m + " on quantified variables") - q -- vars - case Variable(id) if quantified(id) => Set(id) - case _ => q - } - } (conjunct) - } - } + val matcherToQuants = matchers.foldLeft(Map.empty[Expr, Set[Identifier]]) { + case (acc, (m, args)) => acc + (m -> (acc.getOrElse(m, Set.empty) ++ args.flatMap { + case Variable(id) if quantified(id) => Set(id) + case _ => Set.empty[Identifier] + })) } + + val bijectiveMappings = matcherToQuants.filter(_._2.nonEmpty).groupBy(_._2) + if (bijectiveMappings.size > 1) return NonBijectiveMapping(bijectiveMappings.head._2.head._1.asString) + + val matcherSet = matcherToQuants.filter(_._2.nonEmpty).keys.toSet + + val qs = fold[Set[Identifier]] { case (m, children) => + val q = children.toSet.flatten + + m match { + case QuantificationMatcher(_, args) => + q -- args.flatMap { + case Variable(id) if quantified(id) => Set(id) + case _ => Set.empty[Identifier] + } + case LessThan(_: Variable, _: Variable) => q + case LessEquals(_: Variable, _: Variable) => q + case GreaterThan(_: Variable, _: Variable) => q + case GreaterEquals(_: Variable, _: Variable) => q + case And(_) => q + case Or(_) => q + case Implies(_, _) => q + case Operator(es, _) => + val matcherArgs = matcherSet & es.toSet + if (q.nonEmpty && !(q.size == 1 && matcherArgs.isEmpty && m.getType == BooleanType)) + return InvalidOperation(m.asString) + else Set.empty + case Variable(id) if quantified(id) => Set(id) + case _ => q + } + } (conjunct) } + + ForallValid } } diff --git a/src/main/scala/leon/repair/RepairNDEvaluator.scala b/src/main/scala/leon/repair/RepairNDEvaluator.scala index bd0c50fceea9738ac708193c169b12036a27569a..56e8467478f5f0135945b9a72da0e525e4ac2c70 100644 --- a/src/main/scala/leon/repair/RepairNDEvaluator.scala +++ b/src/main/scala/leon/repair/RepairNDEvaluator.scala @@ -1,19 +1,23 @@ /* Copyright 2009-2015 EPFL, Lausanne */ -package leon.repair +package leon +package repair -import leon.purescala._ -import Definitions._ -import Expressions._ -import leon.LeonContext -import leon.evaluators.StreamEvaluator +import purescala.Definitions.Program +import purescala.Expressions._ +import purescala.ExprOps.valuesOf +import evaluators.StreamEvaluator -/** This evaluator treats the expression [[expr]] (reference equality) as a non-deterministic value */ -class RepairNDEvaluator(ctx: LeonContext, prog: Program, expr: Expr) extends StreamEvaluator(ctx, prog) { +/** This evaluator treats the expression [[nd]] (reference equality) as a non-deterministic value */ +class RepairNDEvaluator(ctx: LeonContext, prog: Program, nd: Expr) extends StreamEvaluator(ctx, prog) { override def e(expr: Expr)(implicit rctx: RC, gctx: GC): Stream[Expr] = expr match { - case c if c eq expr => - e(NDValue(c.getType)) + case Not(c) if c eq nd => + // This is a hack: We know the only way nd is wrapped within a Not is if it is NOT within + // a recursive call. So we need to treat it deterministically at this point... + super.e(c) collect { case BooleanLiteral(b) => BooleanLiteral(!b) } + case c if c eq nd => + valuesOf(c.getType) case other => super.e(other) } diff --git a/src/main/scala/leon/repair/RepairTrackingEvaluator.scala b/src/main/scala/leon/repair/RepairTrackingEvaluator.scala index 2846bc4848bd12147a68f7c03a3b7319f8aae88c..429b34c19bbc20667bac9a6723c9c51740a57293 100644 --- a/src/main/scala/leon/repair/RepairTrackingEvaluator.scala +++ b/src/main/scala/leon/repair/RepairTrackingEvaluator.scala @@ -16,12 +16,10 @@ import leon.evaluators._ * as well as if each invocation was successful or erroneous (led to an error) * (.fiStatus) */ -class RepairTrackingEvaluator(ctx: LeonContext, prog: Program) extends RecursiveEvaluator(ctx, prog, 50000) { +class RepairTrackingEvaluator(ctx: LeonContext, prog: Program) extends RecursiveEvaluator(ctx, prog, 50000) with HasDefaultGlobalContext { type RC = CollectingRecContext - type GC = GlobalContext def initRC(mappings: Map[Identifier, Expr]) = CollectingRecContext(mappings, None) - def initGC(model: leon.solvers.Model) = new GlobalContext(model, maxSteps) type FI = (FunDef, Seq[Expr]) diff --git a/src/main/scala/leon/repair/rules/Focus.scala b/src/main/scala/leon/repair/rules/Focus.scala index 6a7754cf05a9035c39fe0599b10e7851f7b84b66..93520c5ea9cd467cf73a94392a565022782eb3d9 100644 --- a/src/main/scala/leon/repair/rules/Focus.scala +++ b/src/main/scala/leon/repair/rules/Focus.scala @@ -4,6 +4,7 @@ package leon package repair package rules +import sun.nio.cs.StreamEncoder import synthesis._ import leon.evaluators._ @@ -92,7 +93,11 @@ case object Focus extends PreprocessingRule("Focus") { def ws(g: Expr) = andJoin(Guide(g) +: wss) def testCondition(cond: Expr) = { - forAllTests(fdSpec, Map(), new AngelicEvaluator( new RepairNDEvaluator(ctx, program, cond))) + val ndSpec = postMap { + case c if c eq cond => Some(not(cond)) + case _ => None + }(fdSpec) + forAllTests(ndSpec, Map(), new AngelicEvaluator(new RepairNDEvaluator(ctx, program, cond))) } guides.flatMap { diff --git a/src/main/scala/leon/solvers/Model.scala b/src/main/scala/leon/solvers/Model.scala index ab91f594c954805267dd523fd9f82305ab789798..07bdee913f21605fbc41f660af608c492e5ee1b5 100644 --- a/src/main/scala/leon/solvers/Model.scala +++ b/src/main/scala/leon/solvers/Model.scala @@ -12,9 +12,9 @@ trait AbstractModel[+This <: Model with AbstractModel[This]] protected val mapping: Map[Identifier, Expr] - def fill(allVars: Iterable[Identifier]): This = { + def set(allVars: Iterable[Identifier]): This = { val builder = newBuilder - builder ++= mapping ++ (allVars.toSet -- mapping.keys).map(id => id -> simplestValue(id.getType)) + builder ++= allVars.map(id => id -> mapping.getOrElse(id, simplestValue(id.getType))) builder.result } diff --git a/src/main/scala/leon/solvers/QuantificationSolver.scala b/src/main/scala/leon/solvers/QuantificationSolver.scala index dc3e8584fd74578ac2173e1819c059da9f0ec99b..fa11ab6613bd65b196cce87ee062c3c56f0b95f9 100644 --- a/src/main/scala/leon/solvers/QuantificationSolver.scala +++ b/src/main/scala/leon/solvers/QuantificationSolver.scala @@ -4,14 +4,14 @@ package solvers import purescala.Common._ import purescala.Expressions._ import purescala.Quantification._ +import purescala.Definitions._ import purescala.Types._ -class HenkinModel(mapping: Map[Identifier, Expr], doms: HenkinDomains) +class HenkinModel(mapping: Map[Identifier, Expr], val doms: HenkinDomains) extends Model(mapping) with AbstractModel[HenkinModel] { override def newBuilder = new HenkinModelBuilder(doms) - def domains: Map[TypeTree, Set[Seq[Expr]]] = doms.domains def domain(expr: Expr) = doms.get(expr) } @@ -26,5 +26,10 @@ class HenkinModelBuilder(domains: HenkinDomains) } trait QuantificationSolver { + val program: Program def getModel: HenkinModel + + protected lazy val requireQuantification = program.definedFunctions.exists { fd => + purescala.ExprOps.exists { case _: Forall => true case _ => false } (fd.fullBody) + } } diff --git a/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala b/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala index 0c7afeb65199042e34004199d17bba25dde9e1c5..f7d395feb913244ed24e5aa0379c1e2e8480610c 100644 --- a/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala +++ b/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala @@ -13,7 +13,7 @@ import purescala.ExprOps._ import purescala.Types._ import utils._ -import z3.FairZ3Component.{optFeelingLucky, optUseCodeGen, optAssumePre} +import z3.FairZ3Component.{optFeelingLucky, optUseCodeGen, optAssumePre, optNoChecks} import templates._ import evaluators._ @@ -26,6 +26,7 @@ class UnrollingSolver(val context: LeonContext, val program: Program, underlying val feelingLucky = context.findOptionOrDefault(optFeelingLucky) val useCodeGen = context.findOptionOrDefault(optUseCodeGen) val assumePreHolds = context.findOptionOrDefault(optAssumePre) + val disableChecks = context.findOptionOrDefault(optNoChecks) protected var lastCheckResult : (Boolean, Option[Boolean], Option[HenkinModel]) = (false, None, None) @@ -111,9 +112,10 @@ class UnrollingSolver(val context: LeonContext, val program: Program, underlying def extract(b: Expr, m: Matcher[Expr]): Set[Seq[Expr]] = { val QuantificationTypeMatcher(fromTypes, _) = m.tpe - val optEnabler = evaluator.eval(b).result + val optEnabler = evaluator.eval(b, model).result + if (optEnabler == Some(BooleanLiteral(true))) { - val optArgs = m.args.map(arg => evaluator.eval(Matcher.argValue(arg)).result) + val optArgs = m.args.map(arg => evaluator.eval(Matcher.argValue(arg), model).result) if (optArgs.forall(_.isDefined)) { Set(optArgs.map(_.get)) } else { @@ -124,35 +126,22 @@ class UnrollingSolver(val context: LeonContext, val program: Program, underlying } } - val funDomains = allVars.flatMap(id => id.getType match { - case ft @ FunctionType(fromTypes, _) => - Some(id -> templateGenerator.manager.instantiations(Variable(id), ft).flatMap { - case (b, m) => extract(b, m) - }) - case _ => None - }).toMap.mapValues(_.toSet) - - val asDMap = model.map(p => funDomains.get(p._1) match { - case Some(domain) => - val mapping = domain.toSeq.map { es => - val ev: Expr = p._2 match { - case RawArrayValue(_, mapping, dflt) => - mapping.collectFirst { - case (k,v) if evaluator.eval(Equals(k, tupleWrap(es))).result == Some(BooleanLiteral(true)) => v - } getOrElse dflt - case _ => scala.sys.error("Unexpected function encoding " + p._2) - } - es -> ev - } + val (typeInsts, partialInsts, lambdaInsts) = templateGenerator.manager.instantiations + + val typeDomains: Map[TypeTree, Set[Seq[Expr]]] = typeInsts.map { + case (tpe, domain) => tpe -> domain.flatMap { case (b, m) => extract(b, m) }.toSet + } - p._1 -> PartialLambda(mapping, p._1.getType.asInstanceOf[FunctionType]) - case None => p - }).toMap + val funDomains: Map[Identifier, Set[Seq[Expr]]] = partialInsts.map { + case (Variable(id), domain) => id -> domain.flatMap { case (b, m) => extract(b, m) }.toSet + } - val typeGrouped = templateGenerator.manager.instantiations.groupBy(_._2.tpe) - val typeDomains = typeGrouped.mapValues(_.flatMap { case (b, m) => extract(b, m) }.toSet) + val lambdaDomains: Map[Lambda, Set[Seq[Expr]]] = lambdaInsts.map { + case (l, domain) => l -> domain.flatMap { case (b, m) => extract(b, m) }.toSet + } - val domains = new HenkinDomains(typeDomains) + val asDMap = purescala.Quantification.extractModel(model.toMap, funDomains, typeDomains, evaluator) + val domains = new HenkinDomains(lambdaDomains, typeDomains) new HenkinModel(asDMap, domains) } @@ -160,36 +149,78 @@ class UnrollingSolver(val context: LeonContext, val program: Program, underlying lastCheckResult = (true, res, model) } - def isValidModel(model: HenkinModel, silenceErrors: Boolean = false): Boolean = { - import EvaluationResults._ + def validatedModel(silenceErrors: Boolean = false): (Boolean, HenkinModel) = { + val lastModel = solver.getModel + val clauses = templateGenerator.manager.checkClauses + val optModel = if (clauses.isEmpty) Some(lastModel) else { + solver.push() + for (clause <- clauses) { + solver.assertCnstr(clause) + } + + reporter.debug(" - Verifying model transitivity") + val solverModel = solver.check match { + case Some(true) => + Some(solver.getModel) + + case Some(false) => + val msg = "- Transitivity independence not guaranteed for model" + if (silenceErrors) { + reporter.debug(msg) + } else { + reporter.warning(msg) + } + None + + case None => + val msg = "- Unknown for transitivity independence!?" + if (silenceErrors) { + reporter.debug(msg) + } else { + reporter.warning(msg) + } + None + } + + solver.pop() + solverModel + } - val expr = andJoin(constraints.toSeq) - val fullModel = model fill freeVars.toSet + optModel match { + case None => + (false, extractModel(lastModel)) - evaluator.eval(expr, fullModel) match { - case Successful(BooleanLiteral(true)) => - reporter.debug("- Model validated.") - true + case Some(m) => + val model = extractModel(m) - case Successful(BooleanLiteral(false)) => - reporter.debug("- Invalid model.") - false + val expr = andJoin(constraints.toSeq) + val fullModel = model set freeVars.toSet - case Successful(e) => - reporter.warning("- Model leads unexpected result: "+e) - false + (evaluator.check(expr, fullModel) match { + case EvaluationResults.CheckSuccess => + reporter.debug("- Model validated.") + true - case RuntimeError(msg) => - reporter.debug("- Model leads to runtime error.") - false + case EvaluationResults.CheckValidityFailure => + reporter.debug("- Invalid model.") + false - case EvaluatorError(msg) => - if (silenceErrors) { - reporter.debug("- Model leads to evaluator error: " + msg) - } else { - reporter.warning("- Model leads to evaluator error: " + msg) - } - false + case EvaluationResults.CheckRuntimeFailure(msg) => + if (silenceErrors) { + reporter.debug("- Model leads to evaluation error: " + msg) + } else { + reporter.warning("- Model leads to evaluation error: " + msg) + } + false + + case EvaluationResults.CheckQuantificationFailure(msg) => + if (silenceErrors) { + reporter.debug("- Model leads to quantification error: " + msg) + } else { + reporter.warning("- Model leads to quantification error: " + msg) + } + false + }, fullModel) } } @@ -215,9 +246,20 @@ class UnrollingSolver(val context: LeonContext, val program: Program, underlying foundAnswer(None) case Some(true) => // SAT - val model = extractModel(solver.getModel) + val (valid, model) = if (!this.disableChecks && requireQuantification) { + validatedModel(silenceErrors = false) + } else { + true -> extractModel(solver.getModel) + } + solver.pop() - foundAnswer(Some(true), Some(model)) + if (valid) { + foundAnswer(Some(true), Some(model)) + } else { + reporter.error("Something went wrong. The model should have been valid, yet we got this : ") + reporter.error(model.asString(context)) + foundAnswer(None, Some(model)) + } case Some(false) if !unrollingBank.canUnroll => solver.pop() @@ -246,12 +288,9 @@ class UnrollingSolver(val context: LeonContext, val program: Program, underlying case Some(true) => if (feelingLucky && !interrupted) { - val model = extractModel(solver.getModel) - // we might have been lucky :D - if (isValidModel(model, silenceErrors = true)) { - foundAnswer(Some(true), Some(model)) - } + val (valid, model) = validatedModel(silenceErrors = true) + if (valid) foundAnswer(Some(true), Some(model)) } case None => diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala index cb1c3246d7eff487e00904c3acbd61d7068aa316..f1cf73142d51debaeda0fc5064096e522f049955 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala @@ -7,11 +7,13 @@ package smtlib import purescala.Common._ import purescala.Expressions._ import purescala.Constructors._ +import purescala.Extractors._ import purescala.Types._ import _root_.smtlib.parser.Terms.{Identifier => SMTIdentifier, Forall => SMTForall, _} import _root_.smtlib.parser.Commands._ import _root_.smtlib.interpreters.CVC4Interpreter +import _root_.smtlib.theories.experimental.Sets trait SMTLIBCVC4Target extends SMTLIBTarget { @@ -27,7 +29,7 @@ trait SMTLIBCVC4Target extends SMTLIBTarget { sorts.cachedB(tpe) { tpe match { case SetType(base) => - Sort(SMTIdentifier(SSymbol("Set")), Seq(declareSort(base))) + Sets.SetSort(declareSort(base)) case _ => super.declareSort(t) @@ -59,12 +61,11 @@ trait SMTLIBCVC4Target extends SMTLIBTarget { case RawArrayType(k, v) => RawArrayValue(k, Map(), fromSMT(elem, v)) - case FunctionType(from, to) => - RawArrayValue(tupleTypeWrap(from), Map(), fromSMT(elem, to)) + case ft @ FunctionType(from, to) => + PartialLambda(Seq.empty, Some(fromSMT(elem, to)), ft) case MapType(k, v) => FiniteMap(Map(), k, v) - } case (FunctionApplication(SimpleSymbol(SSymbol("__array_store_all__")), Seq(_, elem)), Some(tpe)) => @@ -72,12 +73,11 @@ trait SMTLIBCVC4Target extends SMTLIBTarget { case RawArrayType(k, v) => RawArrayValue(k, Map(), fromSMT(elem, v)) - case FunctionType(from, to) => - RawArrayValue(tupleTypeWrap(from), Map(), fromSMT(elem, to)) + case ft @ FunctionType(from, to) => + PartialLambda(Seq.empty, Some(fromSMT(elem, to)), ft) case MapType(k, v) => FiniteMap(Map(), k, v) - } case (FunctionApplication(SimpleSymbol(SSymbol("store")), Seq(arr, key, elem)), Some(tpe)) => @@ -86,9 +86,10 @@ trait SMTLIBCVC4Target extends SMTLIBTarget { val RawArrayValue(k, elems, base) = fromSMT(arr, otpe) RawArrayValue(k, elems + (fromSMT(key, k) -> fromSMT(elem, v)), base) - case FunctionType(_, v) => - val RawArrayValue(k, elems, base) = fromSMT(arr, otpe) - RawArrayValue(k, elems + (fromSMT(key, k) -> fromSMT(elem, v)), base) + case FunctionType(from, v) => + val PartialLambda(mapping, dflt, ft) = fromSMT(arr, otpe) + val args = unwrapTuple(fromSMT(key, tupleTypeWrap(from)), from.size) + PartialLambda(mapping :+ (args -> fromSMT(elem, v)), dflt, ft) case MapType(k, v) => val FiniteMap(elems, k, v) = fromSMT(arr, otpe) @@ -119,33 +120,24 @@ trait SMTLIBCVC4Target extends SMTLIBTarget { */ case fs @ FiniteSet(elems, _) => if (elems.isEmpty) { - QualifiedIdentifier(SMTIdentifier(SSymbol("emptyset")), Some(declareSort(fs.getType))) + Sets.EmptySet(declareSort(fs.getType)) } else { val selems = elems.toSeq.map(toSMT) - val sgt = FunctionApplication(SSymbol("singleton"), Seq(selems.head)) + val sgt = Sets.Singleton(selems.head) if (selems.size > 1) { - FunctionApplication(SSymbol("insert"), selems.tail :+ sgt) + Sets.Insert(selems.tail :+ sgt) } else { sgt } } - case SubsetOf(ss, s) => - FunctionApplication(SSymbol("subset"), Seq(toSMT(ss), toSMT(s))) - - case ElementOfSet(e, s) => - FunctionApplication(SSymbol("member"), Seq(toSMT(e), toSMT(s))) - - case SetDifference(a, b) => - FunctionApplication(SSymbol("setminus"), Seq(toSMT(a), toSMT(b))) - - case SetUnion(a, b) => - FunctionApplication(SSymbol("union"), Seq(toSMT(a), toSMT(b))) - - case SetIntersection(a, b) => - FunctionApplication(SSymbol("intersection"), Seq(toSMT(a), toSMT(b))) + case SubsetOf(ss, s) => Sets.Subset(toSMT(ss), toSMT(s)) + case ElementOfSet(e, s) => Sets.Member(toSMT(e), toSMT(s)) + case SetDifference(a, b) => Sets.Setminus(toSMT(a), toSMT(b)) + case SetUnion(a, b) => Sets.Union(toSMT(a), toSMT(b)) + case SetIntersection(a, b) => Sets.Intersection(toSMT(a), toSMT(b)) case _ => super.toSMT(e) diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala index 3d873847f878a5c5e071a82445915571ced265d0..f5ee66b505ef6e31dcabf8695a7eaef9b8c36c2a 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala @@ -203,7 +203,8 @@ trait SMTLIBTarget extends Interruptible { r case ft @ FunctionType(from, to) => - r + val elems = r.elems.toSeq.map { case (k, v) => unwrapTuple(k, from.size) -> v } + PartialLambda(elems, Some(r.default), ft) case MapType(from, to) => // We expect a RawArrayValue with keys in from and values in Option[to], @@ -701,6 +702,7 @@ trait SMTLIBTarget extends Interruptible { }.toMap fromSMT(body, tpe)(lets ++ defsMap, letDefs) + case (SimpleSymbol(s), _) if constructors.containsB(s) => constructors.toA(s) match { case cct: CaseClassType => diff --git a/src/main/scala/leon/solvers/templates/LambdaManager.scala b/src/main/scala/leon/solvers/templates/LambdaManager.scala index 3d5eec72c809a7ba9459b4b46752835b63bd6011..00bdbfa07ca49c450cd91b3fce0e67dc7655402a 100644 --- a/src/main/scala/leon/solvers/templates/LambdaManager.scala +++ b/src/main/scala/leon/solvers/templates/LambdaManager.scala @@ -12,30 +12,34 @@ import purescala.Types._ import utils._ import Instantiation._ -class LambdaManager[T](protected val encoder: TemplateEncoder[T]) extends IncrementalState { +class LambdaManager[T](protected[templates] val encoder: TemplateEncoder[T]) extends IncrementalState { + private[templates] lazy val trueT = encoder.encodeExpr(Map.empty)(BooleanLiteral(true)) protected val byID = new IncrementalMap[T, LambdaTemplate[T]] protected val byType = new IncrementalMap[FunctionType, Set[(T, LambdaTemplate[T])]].withDefaultValue(Set.empty) protected val applications = new IncrementalMap[FunctionType, Set[(T, App[T])]].withDefaultValue(Set.empty) protected val freeLambdas = new IncrementalMap[FunctionType, Set[T]].withDefaultValue(Set.empty) + private val instantiated = new IncrementalSet[(T, App[T])] + protected def incrementals: List[IncrementalState] = - List(byID, byType, applications, freeLambdas) + List(byID, byType, applications, freeLambdas, instantiated) def clear(): Unit = incrementals.foreach(_.clear()) def reset(): Unit = incrementals.foreach(_.reset()) def push(): Unit = incrementals.foreach(_.push()) def pop(): Unit = incrementals.foreach(_.pop()) - def registerFree(lambdas: Seq[(TypeTree, T)]): Unit = { - for ((tpe, idT) <- lambdas) tpe match { + def registerFree(lambdas: Seq[(Identifier, T)]): Unit = { + for ((id, idT) <- lambdas) id.getType match { case ft: FunctionType => freeLambdas += ft -> (freeLambdas(ft) + idT) case _ => } } - def instantiateLambda(idT: T, template: LambdaTemplate[T]): Instantiation[T] = { + def instantiateLambda(template: LambdaTemplate[T]): Instantiation[T] = { + val idT = template.ids._2 var clauses : Clauses[T] = equalityClauses(idT, template) var appBlockers : AppBlockers[T] = Map.empty.withDefaultValue(Set.empty) @@ -55,32 +59,33 @@ class LambdaManager[T](protected val encoder: TemplateEncoder[T]) extends Increm def instantiateApp(blocker: T, app: App[T]): Instantiation[T] = { val App(caller, tpe, args) = app - var clauses : Clauses[T] = Seq.empty - var callBlockers : CallBlockers[T] = Map.empty.withDefaultValue(Set.empty) - var appBlockers : AppBlockers[T] = Map.empty.withDefaultValue(Set.empty) - - if (byID contains caller) { - val (newClauses, newCalls, newApps) = byID(caller).instantiate(blocker, args) + val instantiation = Instantiation.empty[T] - clauses ++= newClauses - newCalls.foreach(p => callBlockers += p._1 -> (callBlockers(p._1) ++ p._2)) - newApps.foreach(p => appBlockers += p._1 -> (appBlockers(p._1) ++ p._2)) - } else if (!freeLambdas(tpe).contains(caller)) { + if (freeLambdas(tpe).contains(caller)) instantiation else { val key = blocker -> app - // make sure that even if byType(tpe) is empty, app is recorded in blockers - // so that UnrollingBank will generate the initial block! - if (!(appBlockers contains key)) appBlockers += key -> Set.empty + if (instantiated(key)) instantiation else { + instantiated += key - for ((idT,template) <- byType(tpe)) { - val equals = encoder.mkEquals(idT, caller) - appBlockers += (key -> (appBlockers(key) + TemplateAppInfo(template, equals, args))) - } + if (byID contains caller) { + instantiation withApp (key -> TemplateAppInfo(byID(caller), trueT, args)) + } else { - applications += tpe -> (applications(tpe) + key) - } + // make sure that even if byType(tpe) is empty, app is recorded in blockers + // so that UnrollingBank will generate the initial block! + val init = instantiation withApps Map(key -> Set.empty) + val inst = byType(tpe).foldLeft(init) { + case (instantiation, (idT, template)) => + val equals = encoder.mkEquals(idT, caller) + instantiation withApp (key -> TemplateAppInfo(template, equals, args)) + } - (clauses, callBlockers, appBlockers) + applications += tpe -> (applications(tpe) + key) + + inst + } + } + } } private def equalityClauses(idT: T, template: LambdaTemplate[T]): Seq[T] = { diff --git a/src/main/scala/leon/solvers/templates/QuantificationManager.scala b/src/main/scala/leon/solvers/templates/QuantificationManager.scala index fde9dc746b4a5e6207fc3ed896bbbd5700cbc79a..f2d1c1d6f47f52b266faf0b39173977874ea5179 100644 --- a/src/main/scala/leon/solvers/templates/QuantificationManager.scala +++ b/src/main/scala/leon/solvers/templates/QuantificationManager.scala @@ -11,10 +11,11 @@ import purescala.Constructors._ import purescala.Expressions._ import purescala.ExprOps._ import purescala.Types._ +import purescala.Quantification.{QuantificationTypeMatcher => QTM} import Instantiation._ -import scala.collection.mutable.{Map => MutableMap, Set => MutableSet} +import scala.collection.mutable.{Map => MutableMap, Set => MutableSet, Stack => MutableStack, Queue} object Matcher { def argValue[T](arg: Either[T, Matcher[T]]): T = arg match { @@ -24,18 +25,24 @@ object Matcher { } case class Matcher[T](caller: T, tpe: TypeTree, args: Seq[Either[T, Matcher[T]]], encoded: T) { - override def toString = "M(" + caller + " : " + tpe + ", " + args.map(Matcher.argValue).mkString("(",",",")") + ")" + override def toString = caller + args.map { + case Right(m) => m.toString + case Left(v) => v.toString + }.mkString("(",",",")") - def substitute(substituter: T => T): Matcher[T] = copy( + def substitute(substituter: T => T, matcherSubst: Map[T, Matcher[T]] = Map.empty): Matcher[T] = copy( caller = substituter(caller), - args = args.map { arg => arg.left.map(substituter).right.map(_.substitute(substituter)) }, + args = args.map { + case Left(v) => matcherSubst.get(v) match { + case Some(m) => Right(m) + case None => Left(substituter(v)) + } + case Right(m) => Right(m.substitute(substituter, matcherSubst)) + }, encoded = substituter(encoded) ) } -case class IllegalQuantificationException(expr: Expr, msg: String) - extends Exception(msg +" @ " + expr) - class QuantificationTemplate[T]( val quantificationManager: QuantificationManager[T], val start: T, @@ -50,10 +57,31 @@ class QuantificationTemplate[T]( val blockers: Map[T, Set[TemplateCallInfo[T]]], val applications: Map[T, Set[App[T]]], val matchers: Map[T, Set[Matcher[T]]], - val lambdas: Map[T, LambdaTemplate[T]]) { - - def instantiate(substMap: Map[T, T]): Instantiation[T] = { - quantificationManager.instantiateQuantification(this, substMap) + val lambdas: Seq[LambdaTemplate[T]]) { + + def substitute(substituter: T => T): QuantificationTemplate[T] = { + new QuantificationTemplate[T]( + quantificationManager, + substituter(start), + qs, + q2s, + insts, + guardVar, + quantifiers, + condVars, + exprVars, + clauses.map(substituter), + blockers.map { case (b, fis) => + substituter(b) -> fis.map(fi => fi.copy(args = fi.args.map(substituter))) + }, + applications.map { case (b, apps) => + substituter(b) -> apps.map(app => app.copy(caller = substituter(app.caller), args = app.args.map(substituter))) + }, + matchers.map { case (b, ms) => + substituter(b) -> ms.map(_.substitute(substituter)) + }, + lambdas.map(_.substitute(substituter)) + ) } } @@ -70,7 +98,7 @@ object QuantificationTemplate { condVars: Map[Identifier, T], exprVars: Map[Identifier, T], guardedExprs: Map[Identifier, Seq[Expr]], - lambdas: Map[T, LambdaTemplate[T]], + lambdas: Seq[LambdaTemplate[T]], subst: Map[Identifier, T] ): QuantificationTemplate[T] = { @@ -89,239 +117,535 @@ object QuantificationTemplate { } class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManager[T](encoder) { - private lazy val trueT = encoder.encodeExpr(Map.empty)(BooleanLiteral(true)) + private val quantifications = new IncrementalSeq[MatcherQuantification] + private val instCtx = new InstantiationContext + + private val handled = new ContextMap + private val ignored = new ContextMap - private val quantifications = new IncrementalSeq[Quantification] - private val instantiated = new IncrementalSet[(T, Matcher[T])] - private val fInsts = new IncrementalSet[Matcher[T]] private val known = new IncrementalSet[T] + private val lambdaAxioms = new IncrementalSet[(LambdaTemplate[T], Seq[(Identifier, T)])] - private def fInstantiated = fInsts.map(m => trueT -> m) + override protected def incrementals: List[IncrementalState] = + List(quantifications, instCtx, handled, ignored, known, lambdaAxioms) ++ super.incrementals + + private sealed abstract class MatcherKey(val tpe: TypeTree) + private case class CallerKey(caller: T, tt: TypeTree) extends MatcherKey(tt) + private case class LambdaKey(lambda: Lambda, tt: TypeTree) extends MatcherKey(tt) + private case class TypeKey(tt: TypeTree) extends MatcherKey(tt) - private def correspond(qm: Matcher[T], m: Matcher[T]): Boolean = correspond(qm, m.caller, m.tpe) - private def correspond(qm: Matcher[T], caller: T, tpe: TypeTree): Boolean = qm.tpe match { - case _: FunctionType => qm.tpe == tpe && (qm.caller == caller || !known(caller)) - case _ => qm.tpe == tpe + private def matcherKey(caller: T, tpe: TypeTree): MatcherKey = tpe match { + case _: FunctionType if known(caller) => CallerKey(caller, tpe) + case _: FunctionType if byID.isDefinedAt(caller) => LambdaKey(byID(caller).structuralKey, tpe) + case _ => TypeKey(tpe) } - private val uniformQuantifiers = scala.collection.mutable.Map.empty[TypeTree, Seq[T]] + @inline + private def correspond(qm: Matcher[T], m: Matcher[T]): Boolean = + correspond(qm, m.caller, m.tpe) + + @inline + private def correspond(qm: Matcher[T], caller: T, tpe: TypeTree): Boolean = + matcherKey(qm.caller, qm.tpe) == matcherKey(caller, tpe) + + private val uniformQuantMap: MutableMap[TypeTree, Seq[T]] = MutableMap.empty + private val uniformQuantSet: MutableSet[T] = MutableSet.empty + + def isQuantifier(idT: T): Boolean = uniformQuantSet(idT) private def uniformSubst(qs: Seq[(Identifier, T)]): Map[T, T] = { qs.groupBy(_._1.getType).flatMap { case (tpe, qst) => - val prev = uniformQuantifiers.get(tpe) match { + val prev = uniformQuantMap.get(tpe) match { case Some(seq) => seq case None => Seq.empty } if (prev.size >= qst.size) { - qst.map(_._2) zip prev.take(qst.size - 1) + qst.map(_._2) zip prev.take(qst.size) } else { val (handled, newQs) = qst.splitAt(prev.size) val uQs = newQs.map(p => p._2 -> encoder.encodeId(p._1)) - uniformQuantifiers(tpe) = prev ++ uQs.map(_._2) + + uniformQuantMap(tpe) = prev ++ uQs.map(_._2) + uniformQuantSet ++= uQs.map(_._2) + (handled.map(_._2) zip prev) ++ uQs } }.toMap } - override protected def incrementals: List[IncrementalState] = - List(quantifications, instantiated, fInsts, known) ++ super.incrementals + def assumptions: Seq[T] = quantifications.collect { case q: Quantification => q.currentQ2Var }.toSeq - def assumptions: Seq[T] = quantifications.map(_.currentQ2Var).toSeq + def instantiations: (Map[TypeTree, Matchers], Map[T, Matchers], Map[Lambda, Matchers]) = { + var typeInsts: Map[TypeTree, Matchers] = Map.empty + var partialInsts: Map[T, Matchers] = Map.empty + var lambdaInsts: Map[Lambda, Matchers] = Map.empty - def instantiations: Seq[(T, Matcher[T])] = instantiated.toSeq ++ fInstantiated + val instantiations = handled.instantiations ++ instCtx.map.instantiations + for ((key, matchers) <- instantiations) key match { + case TypeKey(tpe) => typeInsts += tpe -> matchers + case CallerKey(caller, _) => partialInsts += caller -> matchers + case LambdaKey(lambda, _) => lambdaInsts += lambda -> matchers + } - def instantiations(caller: T, tpe: TypeTree): Seq[(T, Matcher[T])] = - instantiations.filter { case (b,m) => correspond(m, caller, tpe) } + (typeInsts, partialInsts, lambdaInsts) + } + + def toto: (Map[TypeTree, Matchers], Map[T, Matchers], Map[Lambda, Matchers]) = { + var typeInsts: Map[TypeTree, Matchers] = Map.empty + var partialInsts: Map[T, Matchers] = Map.empty + var lambdaInsts: Map[Lambda, Matchers] = Map.empty + + for ((key, matchers) <- ignored.instantiations) key match { + case TypeKey(tpe) => typeInsts += tpe -> matchers + case CallerKey(caller, _) => partialInsts += caller -> matchers + case LambdaKey(lambda, _) => lambdaInsts += lambda -> matchers + } + + (typeInsts, partialInsts, lambdaInsts) + } - override def registerFree(ids: Seq[(TypeTree, T)]): Unit = { + override def registerFree(ids: Seq[(Identifier, T)]): Unit = { super.registerFree(ids) known ++= ids.map(_._2) } - private class Quantification ( - val qs: (Identifier, T), - val q2s: (Identifier, T), - val insts: (Identifier, T), - val guardVar: T, - val quantified: Set[T], - val matchers: Set[Matcher[T]], - val allMatchers: Map[T, Set[Matcher[T]]], - val condVars: Map[Identifier, T], - val exprVars: Map[Identifier, T], - val clauses: Seq[T], - val blockers: Map[T, Set[TemplateCallInfo[T]]], - val applications: Map[T, Set[App[T]]], - val lambdas: Map[T, LambdaTemplate[T]]) { + private def matcherDepth(m: Matcher[T]): Int = 1 + (0 +: m.args.map { + case Right(ma) => matcherDepth(ma) + case _ => 0 + }).max - var currentQ2Var: T = qs._2 - private var slaves: Seq[(T, Map[T,T], Quantification)] = Nil - - private def mappings(blocker: T, matcher: Matcher[T]) - (implicit instantiated: Iterable[(T, Matcher[T])]): Set[(T, Map[T, T])] = { - - // Build a mapping from applications in the quantified statement to all potential concrete - // applications previously encountered. Also make sure the current `app` is in the mapping - // as other instantiations have been performed previously when the associated applications - // were first encountered. - val matcherMappings: Set[Set[(T, Matcher[T], Matcher[T])]] = matchers - // 1. select an application in the quantified proposition for which the current app can - // be bound when generating the new constraints - .filter(qm => correspond(qm, matcher)) - // 2. build the instantiation mapping associated to the chosen current application binding - .flatMap { bindingMatcher => + private def encodeEnablers(es: Set[T]): T = encoder.mkAnd(es.toSeq.sortBy(_.toString) : _*) - // 2.1. select all potential matches for each quantified application - val matcherToInstances = matchers - .map(qm => if (qm == bindingMatcher) { - bindingMatcher -> Set(blocker -> matcher) - } else { - val instances: Set[(T, Matcher[T])] = instantiated.filter { case (b, m) => correspond(qm, m) }.toSet + private type Matchers = Set[(T, Matcher[T])] - // concrete applications can appear multiple times in the constraint, and this is also the case - // for the current application for which we are generating the constraints - val withCurrent = if (correspond(qm, matcher)) instances + (blocker -> matcher) else instances + private class Context private(ctx: Map[Matcher[T], Set[Set[T]]]) extends Iterable[(Set[T], Matcher[T])] { + def this() = this(Map.empty) - qm -> withCurrent - }).toMap + def apply(p: (Set[T], Matcher[T])): Boolean = ctx.get(p._2) match { + case None => false + case Some(blockerSets) => blockerSets(p._1) || blockerSets.exists(set => set.subsetOf(p._1)) + } - // 2.2. based on the possible bindings for each quantified application, build a set of - // instantiation mappings that can be used to instantiate all necessary constraints - extractMappings(matcherToInstances) - } + def +(p: (Set[T], Matcher[T])): Context = if (apply(p)) this else { + val prev = ctx.getOrElse(p._2, Seq.empty) + val newSet = prev.filterNot(set => p._1.subsetOf(set)).toSet + p._1 + new Context(ctx + (p._2 -> newSet)) + } + + def ++(that: Context): Context = that.foldLeft(this)((ctx, p) => ctx + p) + + def iterator = ctx.toSeq.flatMap { case (m, bss) => bss.map(bs => bs -> m) }.iterator + def toMatchers: Matchers = this.map(p => encodeEnablers(p._1) -> p._2).toSet + } - for (mapping <- matcherMappings) yield extractSubst(quantified, mapping) + private class ContextMap( + private var tpeMap: MutableMap[TypeTree, Context] = MutableMap.empty, + private var funMap: MutableMap[MatcherKey, Context] = MutableMap.empty + ) extends IncrementalState { + private val stack = new MutableStack[(MutableMap[TypeTree, Context], MutableMap[MatcherKey, Context])] + + def clear(): Unit = { + stack.clear() + tpeMap.clear() + funMap.clear() } - private def extractSlaveInfo(enabler: T, senabler: T, subst: Map[T,T], ssubst: Map[T,T]): (T, Map[T,T]) = { - val substituter = encoder.substitute(subst) - val slaveEnabler = encoder.mkAnd(enabler, substituter(senabler)) - val slaveSubst = ssubst.map(p => p._1 -> substituter(p._2)) - (slaveEnabler, slaveSubst) + def reset(): Unit = clear() + + def push(): Unit = { + stack.push((tpeMap, funMap)) + tpeMap = tpeMap.clone + funMap = funMap.clone } - private def instantiate(enabler: T, subst: Map[T, T], seen: Set[Quantification]): Instantiation[T] = { - if (seen(this)) { - Instantiation.empty[T] - } else { - val nextQ2Var = encoder.encodeId(q2s._1) + def pop(): Unit = { + val (ptpeMap, pfunMap) = stack.pop() + tpeMap = ptpeMap + funMap = pfunMap + } - val baseSubstMap = (condVars ++ exprVars).map { case (id, idT) => idT -> encoder.encodeId(id) } - val lambdaSubstMap = lambdas map { case (idT, lambda) => idT -> encoder.encodeId(lambda.id) } - val substMap = subst ++ baseSubstMap ++ lambdaSubstMap + - (qs._2 -> currentQ2Var) + (guardVar -> enabler) + (q2s._2 -> nextQ2Var) + - (insts._2 -> encoder.encodeId(insts._1)) + def +=(p: (Set[T], Matcher[T])): Unit = matcherKey(p._2.caller, p._2.tpe) match { + case TypeKey(tpe) => tpeMap(tpe) = tpeMap.getOrElse(tpe, new Context) + p + case key => funMap(key) = funMap.getOrElse(key, new Context) + p + } - var instantiation = Template.instantiate(encoder, QuantificationManager.this, - clauses, blockers, applications, Seq.empty, Map.empty[T, Set[Matcher[T]]], lambdas, substMap) + def merge(that: ContextMap): this.type = { + for ((tpe, values) <- that.tpeMap) tpeMap(tpe) = tpeMap.getOrElse(tpe, new Context) ++ values + for ((caller, values) <- that.funMap) funMap(caller) = funMap.getOrElse(caller, new Context) ++ values + this + } + + def get(caller: T, tpe: TypeTree): Context = + funMap.getOrElse(matcherKey(caller, tpe), new Context) ++ tpeMap.getOrElse(tpe, new Context) + + def get(key: MatcherKey): Context = key match { + case TypeKey(tpe) => tpeMap.getOrElse(tpe, new Context) + case key => funMap.getOrElse(key, new Context) + } - for { - (senabler, ssubst, slave) <- slaves - (slaveEnabler, slaveSubst) = extractSlaveInfo(enabler, senabler, subst, ssubst) - } instantiation ++= slave.instantiate(slaveEnabler, slaveSubst, seen + this) + def instantiations: Map[MatcherKey, Matchers] = + (funMap.toMap ++ tpeMap.map { case (tpe,ms) => TypeKey(tpe) -> ms }).mapValues(_.toMatchers) + } + + private class InstantiationContext private ( + private var _instantiated : Context, val map : ContextMap + ) extends IncrementalState { + + private val stack = new MutableStack[Context] + + def this() = this(new Context, new ContextMap) + + def clear(): Unit = { + stack.clear() + map.clear() + _instantiated = new Context + } - currentQ2Var = nextQ2Var + def reset(): Unit = clear() + + def push(): Unit = { + stack.push(_instantiated) + map.push() + } + + def pop(): Unit = { + _instantiated = stack.pop() + map.pop() + } + + def instantiated: Context = _instantiated + def apply(p: (Set[T], Matcher[T])): Boolean = _instantiated(p) + + def corresponding(m: Matcher[T]): Context = map.get(m.caller, m.tpe) + + def instantiate(blockers: Set[T], matcher: Matcher[T])(qs: MatcherQuantification*): Instantiation[T] = { + if (this(blockers -> matcher)) { + Instantiation.empty[T] + } else { + map += (blockers -> matcher) + _instantiated += (blockers -> matcher) + var instantiation = Instantiation.empty[T] + for (q <- qs) instantiation ++= q.instantiate(blockers, matcher) instantiation } } - def register(senabler: T, ssubst: Map[T, T], slave: Quantification): Instantiation[T] = { - var instantiation = Instantiation.empty[T] + def merge(that: InstantiationContext): this.type = { + _instantiated ++= that._instantiated + map.merge(that.map) + this + } + } + + private trait MatcherQuantification { + val quantified: Set[T] + val matchers: Set[Matcher[T]] + val allMatchers: Map[T, Set[Matcher[T]]] + val condVars: Map[Identifier, T] + val exprVars: Map[Identifier, T] + val clauses: Seq[T] + val blockers: Map[T, Set[TemplateCallInfo[T]]] + val applications: Map[T, Set[App[T]]] + val lambdas: Seq[LambdaTemplate[T]] + + private lazy val depth = matchers.map(matcherDepth).max + private lazy val transMatchers: Set[Matcher[T]] = (for { + (b, ms) <- allMatchers.toSeq + m <- ms if !matchers(m) && matcherDepth(m) <= depth + } yield m).toSet + + /* Build a mapping from applications in the quantified statement to all potential concrete + * applications previously encountered. Also make sure the current `app` is in the mapping + * as other instantiations have been performed previously when the associated applications + * were first encountered. + */ + private def mappings(bs: Set[T], matcher: Matcher[T]): Set[Set[(Set[T], Matcher[T], Matcher[T])]] = { + /* 1. select an application in the quantified proposition for which the current app can + * be bound when generating the new constraints + */ + matchers.filter(qm => correspond(qm, matcher)) + + /* 2. build the instantiation mapping associated to the chosen current application binding */ + .flatMap { bindingMatcher => + + /* 2.1. select all potential matches for each quantified application */ + val matcherToInstances = matchers + .map(qm => if (qm == bindingMatcher) { + bindingMatcher -> Set(bs -> matcher) + } else { + qm -> instCtx.corresponding(qm) + }).toMap + + /* 2.2. based on the possible bindings for each quantified application, build a set of + * instantiation mappings that can be used to instantiate all necessary constraints + */ + val allMappings = matcherToInstances.foldLeft[Set[Set[(Set[T], Matcher[T], Matcher[T])]]](Set(Set.empty)) { + case (mappings, (qm, instances)) => Set(instances.toSeq.flatMap { + case (bs, m) => mappings.map(mapping => mapping + ((bs, qm, m))) + } : _*) + } + + /* 2.3. filter out bindings that don't make sense where abstract sub-matchers + * (matchers in arguments of other matchers) are bound to different concrete + * matchers in the argument and quorum positions + */ + allMappings.filter { s => + def expand(ms: Traversable[(Either[T,Matcher[T]], Either[T,Matcher[T]])]): Set[(Matcher[T], Matcher[T])] = ms.flatMap { + case (Right(qm), Right(m)) => Set(qm -> m) ++ expand(qm.args zip m.args) + case _ => Set.empty[(Matcher[T], Matcher[T])] + }.toSet + + expand(s.map(p => Right(p._2) -> Right(p._3))).groupBy(_._1).forall(_._2.size == 1) + } + + allMappings + } + } + + private def extractSubst(mapping: Set[(Set[T], Matcher[T], Matcher[T])]): (Set[T], Map[T,Either[T, Matcher[T]]], Boolean) = { + var constraints: Set[T] = Set.empty + var matcherEqs: List[(T, T)] = Nil + var subst: Map[T, Either[T, Matcher[T]]] = Map.empty for { - instantiated <- List(instantiated, fInstantiated) - (blocker, matcher) <- instantiated - (enabler, subst) <- mappings(blocker, matcher)(instantiated) - (slaveEnabler, slaveSubst) = extractSlaveInfo(enabler, senabler, subst, ssubst) - } instantiation ++= slave.instantiate(slaveEnabler, slaveSubst, Set(this)) + (bs, qm @ Matcher(qcaller, _, qargs, _), m @ Matcher(caller, _, args, _)) <- mapping + _ = constraints ++= bs + _ = matcherEqs :+= qm.encoded -> m.encoded + (qarg, arg) <- (qargs zip args) + } qarg match { + case Left(quant) if subst.isDefinedAt(quant) => + constraints += encoder.mkEquals(quant, Matcher.argValue(arg)) + case Left(quant) if quantified(quant) => + subst += quant -> arg + case Right(qam) => + val argVal = Matcher.argValue(arg) + constraints += encoder.mkEquals(qam.encoded, argVal) + matcherEqs :+= qam.encoded -> argVal + } - slaves :+= (senabler, ssubst, slave) + val substituter = encoder.substitute(subst.mapValues(Matcher.argValue)) + val enablers = (if (constraints.isEmpty) Set(trueT) else constraints).map(substituter) + val isStrict = matcherEqs.forall(p => substituter(p._1) == p._2) - instantiation + (enablers, subst, isStrict) } - def instantiate(blocker: T, matcher: Matcher[T])(implicit instantiated: Iterable[(T, Matcher[T])]): Instantiation[T] = { + def instantiate(bs: Set[T], matcher: Matcher[T]): Instantiation[T] = { var instantiation = Instantiation.empty[T] - for ((enabler, subst) <- mappings(blocker, matcher)) { - instantiation ++= instantiate(enabler, subst, Set.empty) + for (mapping <- mappings(bs, matcher)) { + val (enablers, subst, isStrict) = extractSubst(mapping) + val enabler = encodeEnablers(enablers) + + val baseSubstMap = (condVars ++ exprVars).map { case (id, idT) => idT -> encoder.encodeId(id) } + val lambdaSubstMap = lambdas map(lambda => lambda.ids._2 -> encoder.encodeId(lambda.ids._1)) + val substMap = subst.mapValues(Matcher.argValue) ++ baseSubstMap ++ lambdaSubstMap ++ instanceSubst(enabler) + + instantiation ++= Template.instantiate(encoder, QuantificationManager.this, + clauses, blockers, applications, Seq.empty, Map.empty[T, Set[Matcher[T]]], lambdas, substMap) + + + val msubst = subst.collect { case (c, Right(m)) => c -> m } + val substituter = encoder.substitute(substMap) + + for ((b,ms) <- allMatchers; m <- ms) { + val sb = enablers + substituter(b) + val sm = m.substitute(substituter, matcherSubst = msubst) + + if (matchers(m)) { + handled += sb -> sm + } else if (transMatchers(m) && isStrict) { + instantiation ++= instCtx.instantiate(sb, sm)(quantifications.toSeq : _*) + } else { + ignored += sb -> sm + } + } } instantiation } + + protected def instanceSubst(enabler: T): Map[T, T] } - private def extractMappings( - bindings: Map[Matcher[T], Set[(T, Matcher[T])]] - ): Set[Set[(T, Matcher[T], Matcher[T])]] = { - val allMappings = bindings.foldLeft[Set[Set[(T, Matcher[T], Matcher[T])]]](Set(Set.empty)) { - case (mappings, (qm, instances)) => Set(instances.toSeq.flatMap { - case (b, m) => mappings.map(mapping => mapping + ((b, qm, m))) - } : _*) - } + private class Quantification ( + val qs: (Identifier, T), + val q2s: (Identifier, T), + val insts: (Identifier, T), + val guardVar: T, + val quantified: Set[T], + val matchers: Set[Matcher[T]], + val allMatchers: Map[T, Set[Matcher[T]]], + val condVars: Map[Identifier, T], + val exprVars: Map[Identifier, T], + val clauses: Seq[T], + val blockers: Map[T, Set[TemplateCallInfo[T]]], + val applications: Map[T, Set[App[T]]], + val lambdas: Seq[LambdaTemplate[T]]) extends MatcherQuantification { - def subBindings(b: T, sm: Matcher[T], m: Matcher[T]): Set[(T, Matcher[T], Matcher[T])] = { - (for ((sarg, arg) <- sm.args zip m.args) yield { - (sarg, arg) match { - case (Right(sargm), Right(argm)) => Set((b, sargm, argm)) ++ subBindings(b, sargm, argm) - case _ => Set.empty[(T, Matcher[T], Matcher[T])] - } - }).flatten.toSet - } + var currentQ2Var: T = qs._2 - allMappings.filter { s => - val withSubs = s ++ s.flatMap { case (b, sm, m) => subBindings(b, sm, m) } - withSubs.groupBy(_._2).forall(_._2.size == 1) - } - } + protected def instanceSubst(enabler: T): Map[T, T] = { + val nextQ2Var = encoder.encodeId(q2s._1) - private def extractSubst(quantified: Set[T], mapping: Set[(T, Matcher[T], Matcher[T])]): (T, Map[T,T]) = { - var constraints: List[T] = Nil - var subst: Map[T, T] = Map.empty + val subst = Map(qs._2 -> currentQ2Var, guardVar -> enabler, + q2s._2 -> nextQ2Var, insts._2 -> encoder.encodeId(insts._1)) - for { - (b, Matcher(qcaller, _, qargs, _), Matcher(caller, _, args, _)) <- mapping - _ = constraints :+= b - (qarg, arg) <- (qargs zip args) - argVal = Matcher.argValue(arg) - } qarg match { - case Left(quant) if subst.isDefinedAt(quant) => - constraints :+= encoder.mkEquals(quant, argVal) - case Left(quant) if quantified(quant) => - subst += quant -> argVal - case _ => - constraints :+= encoder.mkEquals(Matcher.argValue(qarg), argVal) + currentQ2Var = nextQ2Var + subst } + } - val enabler = - if (constraints.isEmpty) trueT - else if (constraints.size == 1) constraints.head - else encoder.mkAnd(constraints : _*) + private lazy val blockerId = FreshIdentifier("blocker", BooleanType, true) + private lazy val blockerCache: MutableMap[T, T] = MutableMap.empty - (encoder.substitute(subst)(enabler), subst) - } + private class Axiom ( + val start: T, + val blocker: T, + val guardVar: T, + val quantified: Set[T], + val matchers: Set[Matcher[T]], + val allMatchers: Map[T, Set[Matcher[T]]], + val condVars: Map[Identifier, T], + val exprVars: Map[Identifier, T], + val clauses: Seq[T], + val blockers: Map[T, Set[TemplateCallInfo[T]]], + val applications: Map[T, Set[App[T]]], + val lambdas: Seq[LambdaTemplate[T]]) extends MatcherQuantification { + + protected def instanceSubst(enabler: T): Map[T, T] = { + val newBlocker = blockerCache.get(enabler) match { + case Some(b) => b + case None => + val nb = encoder.encodeId(blockerId) + blockerCache += enabler -> nb + blockerCache += nb -> nb + nb + } - def instantiateQuantification(template: QuantificationTemplate[T], substMap: Map[T, T]): Instantiation[T] = { - val quantified = template.quantifiers.map(_._2).toSet + Map(guardVar -> encoder.mkAnd(start, enabler), blocker -> newBlocker) + } + } - val allMatchers: Map[T, Set[Matcher[T]]] = { - def rec(templates: Map[T, LambdaTemplate[T]]): Map[T, Set[Matcher[T]]] = - templates.foldLeft(Map.empty[T, Set[Matcher[T]]]) { - case (matchers, (_, template)) => matchers merge template.matchers merge rec(template.lambdas) + private def extractQuorums( + quantified: Set[T], + matchers: Set[Matcher[T]], + lambdas: Seq[LambdaTemplate[T]] + ): Seq[Set[Matcher[T]]] = { + val extMatchers: Set[Matcher[T]] = { + def rec(templates: Seq[LambdaTemplate[T]]): Set[Matcher[T]] = + templates.foldLeft(Set.empty[Matcher[T]]) { + case (matchers, template) => matchers ++ template.matchers.flatMap(_._2) ++ rec(template.lambdas) } - template.matchers merge rec(template.lambdas) + matchers ++ rec(lambdas) } - val quantifiedMatchers = (for { - (_, ms) <- allMatchers - m @ Matcher(_, _, args, _) <- ms + val quantifiedMatchers = for { + m @ Matcher(_, _, args, _) <- extMatchers if args exists (_.left.exists(quantified)) - } yield m).toSet + } yield m - val matchQuorums: Seq[Set[Matcher[T]]] = purescala.Quantification.extractQuorums( - quantifiedMatchers, quantified, + purescala.Quantification.extractQuorums(quantifiedMatchers, quantified, (m: Matcher[T]) => m.args.collect { case Right(m) if quantifiedMatchers(m) => m }.toSet, (m: Matcher[T]) => m.args.collect { case Left(a) if quantified(a) => a }.toSet) + } + + def instantiateAxiom(template: LambdaTemplate[T], substMap: Map[T, T]): Instantiation[T] = { + val quantifiers = template.arguments map { + case (id, idT) => id -> substMap(idT) + } filter (p => isQuantifier(p._2)) + + if (quantifiers.isEmpty || lambdaAxioms(template -> quantifiers)) { + Instantiation.empty[T] + } else { + lambdaAxioms += template -> quantifiers + val blockerT = encoder.encodeId(blockerId) + + val guard = FreshIdentifier("guard", BooleanType, true) + val guardT = encoder.encodeId(guard) + + val substituter = encoder.substitute(substMap + (template.start -> blockerT)) + val allMatchers = template.matchers map { case (b, ms) => substituter(b) -> ms.map(_.substitute(substituter)) } + val qMatchers = allMatchers.flatMap(_._2).toSet + + val encArgs = template.args map substituter + val app = Application(Variable(template.ids._1), template.arguments.map(_._1.toVariable)) + val appT = encoder.encodeExpr((template.arguments.map(_._1) zip encArgs).toMap + template.ids)(app) + val selfMatcher = Matcher(template.ids._2, template.tpe, encArgs.map(Left(_)), appT) + + val enablingClause = encoder.mkImplies(guardT, blockerT) + + instantiateAxiom( + substMap(template.start), + blockerT, + guardT, + quantifiers, + qMatchers, + allMatchers + (template.start -> (allMatchers.getOrElse(template.start, Set.empty) + selfMatcher)), + template.condVars map { case (id, idT) => id -> substituter(idT) }, + template.exprVars map { case (id, idT) => id -> substituter(idT) }, + (template.clauses map substituter) :+ enablingClause, + template.blockers map { case (b, fis) => + substituter(b) -> fis.map(fi => fi.copy(args = fi.args.map(substituter))) + }, + template.applications map { case (b, apps) => + substituter(b) -> apps.map(app => app.copy(caller = substituter(app.caller), args = app.args.map(substituter))) + }, + template.lambdas map (_.substitute(substituter)) + ) + } + } + + def instantiateAxiom( + start: T, + blocker: T, + guardVar: T, + quantifiers: Seq[(Identifier, T)], + matchers: Set[Matcher[T]], + allMatchers: Map[T, Set[Matcher[T]]], + condVars: Map[Identifier, T], + exprVars: Map[Identifier, T], + clauses: Seq[T], + blockers: Map[T, Set[TemplateCallInfo[T]]], + applications: Map[T, Set[App[T]]], + lambdas: Seq[LambdaTemplate[T]] + ): Instantiation[T] = { + val quantified = quantifiers.map(_._2).toSet + val matchQuorums = extractQuorums(quantified, matchers, lambdas) + + var instantiation = Instantiation.empty[T] + + for (matchers <- matchQuorums) { + val axiom = new Axiom(start, blocker, guardVar, quantified, + matchers, allMatchers, condVars, exprVars, + clauses, blockers, applications, lambdas + ) + + quantifications += axiom + + val newCtx = new InstantiationContext() + for ((b,m) <- instCtx.instantiated) { + instantiation ++= newCtx.instantiate(b, m)(axiom) + } + instCtx.merge(newCtx) + } + + val quantifierSubst = uniformSubst(quantifiers) + val substituter = encoder.substitute(quantifierSubst) + + for { + m <- matchers + sm = m.substitute(substituter) + if !instCtx.corresponding(sm).exists(_._2.args == sm.args) + } instantiation ++= instCtx.instantiate(Set(trueT), sm)(quantifications.toSeq : _*) + + instantiation + } + + def instantiateQuantification(template: QuantificationTemplate[T], substMap: Map[T, T]): Instantiation[T] = { + val quantified = template.quantifiers.map(_._2).toSet + val matchQuorums = extractQuorums(quantified, template.matchers.flatMap(_._2).toSet, template.lambdas) var instantiation = Instantiation.empty[T] @@ -333,8 +657,8 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage val quantification = new Quantification(template.qs._1 -> newQ, template.q2s, template.insts, template.guardVar, quantified, - matchers map (m => m.substitute(substituter)), - allMatchers map { case (b, ms) => substituter(b) -> ms.map(_.substitute(substituter)) }, + matchers map (_.substitute(substituter)), + template.matchers map { case (b, ms) => substituter(b) -> ms.map(_.substitute(substituter)) }, template.condVars, template.exprVars, template.clauses map substituter, @@ -344,52 +668,17 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage template.applications map { case (b, fas) => substituter(b) -> fas.map(fa => fa.copy(caller = substituter(fa.caller), args = fa.args.map(substituter))) }, - template.lambdas map { case (idT, template) => substituter(idT) -> template.substitute(subst) } + template.lambdas map (_.substitute(substituter)) ) - def extendClauses(master: Quantification, slave: Quantification): Instantiation[T] = { - val bindingsMap: Map[Matcher[T], Set[(T, Matcher[T])]] = slave.matchers.map { sm => - val instances = master.allMatchers.toSeq.flatMap { case (b, ms) => ms.map(b -> _) } - sm -> instances.filter(p => correspond(p._2, sm)).toSet - }.toMap - - val allMappings = extractMappings(bindingsMap) - val filteredMappings = allMappings.filter { s => - s.exists { case (b, sm, m) => !master.matchers(m) } && - s.exists { case (b, sm, m) => sm != m } && - s.forall { case (b, sm, m) => - (sm.args zip m.args).forall { - case (Right(_), Left(_)) => false - case _ => true - } - } - } - - var instantiation = Instantiation.empty[T] - - for (mapping <- filteredMappings) { - val (enabler, subst) = extractSubst(slave.quantified, mapping) - instantiation ++= master.register(enabler, subst, slave) - } - - instantiation - } - - val allSet = quantification.allMatchers.flatMap(_._2).toSet - for (q <- quantifications) { - val allQSet = q.allMatchers.flatMap(_._2).toSet - if (quantification.matchers.forall(m => allQSet.exists(qm => correspond(qm, m)))) - instantiation ++= extendClauses(q, quantification) - - if (q.matchers.forall(qm => allSet.exists(m => correspond(qm, m)))) - instantiation ++= extendClauses(quantification, q) - } + quantifications += quantification - for (instantiated <- List(instantiated, fInstantiated); (b, m) <- instantiated) { - instantiation ++= quantification.instantiate(b, m)(instantiated) + val newCtx = new InstantiationContext() + for ((b,m) <- instCtx.instantiated) { + instantiation ++= newCtx.instantiate(b, m)(quantification) } + instCtx.merge(newCtx) - quantifications += quantification quantification.qs._2 } @@ -404,34 +693,74 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage val quantifierSubst = uniformSubst(template.quantifiers) val substituter = encoder.substitute(substMap ++ quantifierSubst) - for ((_, ms) <- template.matchers; m <- ms) { - val sm = m.substitute(substituter) - - if (!fInsts.exists(fm => correspond(sm, fm) && sm.args == fm.args)) { - for (q <- quantifications) { - instantiation ++= q.instantiate(trueT, sm)(fInstantiated) - } - - fInsts += sm - } - } + for { + (_, ms) <- template.matchers; m <- ms + sm = m.substitute(substituter) + if !instCtx.corresponding(sm).exists(_._2.args == sm.args) + } instantiation ++= instCtx.instantiate(Set(trueT), sm)(quantifications.toSeq : _*) instantiation } def instantiateMatcher(blocker: T, matcher: Matcher[T]): Instantiation[T] = { - val qInst = if (instantiated(blocker -> matcher)) Instantiation.empty[T] else { - var instantiation = Instantiation.empty[T] - for (q <- quantifications) { - instantiation ++= q.instantiate(blocker, matcher)(instantiated) - } + instCtx.instantiate(Set(blocker), matcher)(quantifications.toSeq : _*) + } - instantiated += (blocker -> matcher) + private type SetDef = (T, (Identifier, T), (Identifier, T), Seq[T], T, T, T) + private val setConstructors: MutableMap[TypeTree, SetDef] = MutableMap.empty + + def checkClauses: Seq[T] = { + val clauses = new scala.collection.mutable.ListBuffer[T] + + for ((key, ctx) <- ignored.instantiations) { + val insts = instCtx.map.get(key).toMatchers + + val QTM(argTypes, _) = key.tpe + val tupleType = tupleTypeWrap(argTypes) + + val (guardT, (setPrev, setPrevT), (setNext, setNextT), elems, containsT, emptyT, setT) = + setConstructors.getOrElse(tupleType, { + val guard = FreshIdentifier("guard", BooleanType) + val setPrev = FreshIdentifier("prevSet", SetType(tupleType)) + val setNext = FreshIdentifier("nextSet", SetType(tupleType)) + val elems = argTypes.map(tpe => FreshIdentifier("elem", tpe)) + + val elemExpr = tupleWrap(elems.map(_.toVariable)) + val contextExpr = And( + Implies(Variable(guard), Equals(Variable(setNext), + SetUnion(Variable(setPrev), FiniteSet(Set(elemExpr), tupleType)))), + Implies(Not(Variable(guard)), Equals(Variable(setNext), Variable(setPrev)))) + + val guardP = guard -> encoder.encodeId(guard) + val setPrevP = setPrev -> encoder.encodeId(setPrev) + val setNextP = setNext -> encoder.encodeId(setNext) + val elemsP = elems.map(e => e -> encoder.encodeId(e)) + + val containsT = encoder.encodeExpr(elemsP.toMap + setPrevP)(ElementOfSet(elemExpr, setPrevP._1.toVariable)) + val emptyT = encoder.encodeExpr(Map.empty)(FiniteSet(Set.empty, tupleType)) + val contextT = encoder.encodeExpr(Map(guardP, setPrevP, setNextP) ++ elemsP)(contextExpr) + + val setDef = (guardP._2, setPrevP, setNextP, elemsP.map(_._2), containsT, emptyT, contextT) + setConstructors += key.tpe -> setDef + setDef + }) + + var prev = emptyT + for ((b, m) <- insts.toSeq) { + val next = encoder.encodeId(setNext) + val argsMap = (elems zip m.args).map { case (idT, arg) => idT -> Matcher.argValue(arg) } + val substMap = Map(guardT -> b, setPrevT -> prev, setNextT -> next) ++ argsMap + prev = next + clauses += encoder.substitute(substMap)(setT) + } - instantiation + val setMap = Map(setPrevT -> prev) + for ((b, m) <- ctx.toSeq) { + val substMap = setMap ++ (elems zip m.args).map(p => p._1 -> Matcher.argValue(p._2)) + clauses += encoder.substitute(substMap)(encoder.mkImplies(b, containsT)) + } } - qInst + clauses.toSeq } - } diff --git a/src/main/scala/leon/solvers/templates/TemplateGenerator.scala b/src/main/scala/leon/solvers/templates/TemplateGenerator.scala index df27d25e4af6dd67964b97db823c3e68c245f215..5e6b7213c968558a7143b5ced41487b14d2a9ea3 100644 --- a/src/main/scala/leon/solvers/templates/TemplateGenerator.scala +++ b/src/main/scala/leon/solvers/templates/TemplateGenerator.scala @@ -11,6 +11,7 @@ import purescala.ExprOps._ import purescala.Types._ import purescala.Definitions._ import purescala.Constructors._ +import purescala.Quantification._ class TemplateGenerator[T](val encoder: TemplateEncoder[T], val assumePreHolds: Boolean) { @@ -72,7 +73,7 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], val (bodyConds, bodyExprs, bodyGuarded, bodyLambdas, bodyQuantifications) = if (isRealFunDef) { invocationEqualsBody.map(expr => mkClauses(start, expr, substMap)).getOrElse { - (Map[Identifier,T](), Map[Identifier,T](), Map[Identifier,Seq[Expr]](), Map[T,LambdaTemplate[T]](), Seq[QuantificationTemplate[T]]()) + (Map[Identifier,T](), Map[Identifier,T](), Map[Identifier,Seq[Expr]](), Seq[LambdaTemplate[T]](), Seq[QuantificationTemplate[T]]()) } } else { mkClauses(start, lambdaBody.get, substMap) @@ -133,8 +134,47 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], andJoin(rec(invocation, body, args, inlineFirst)) } + private def minimalFlattening(inits: Set[Identifier], conj: Expr): (Set[Identifier], Expr) = { + var mapping: Map[Expr, Expr] = Map.empty + var quantified: Set[Identifier] = inits + var quantifierEqualities: Seq[(Expr, Identifier)] = Seq.empty + + val newConj = postMap { + case expr if mapping.isDefinedAt(expr) => + Some(mapping(expr)) + + case expr @ QuantificationMatcher(c, args) => + val isMatcher = args.exists { case Variable(id) => quantified(id) case _ => false } + val isRelevant = (variablesOf(expr) & quantified).nonEmpty + if (!isMatcher && isRelevant) { + val newArgs = args.map { + case arg @ QuantificationMatcher(_, _) if (variablesOf(arg) & quantified).nonEmpty => + val id = FreshIdentifier("flat", arg.getType) + quantifierEqualities :+= (arg -> id) + quantified += id + Variable(id) + case arg => arg + } + + val newExpr = replace((args zip newArgs).toMap, expr) + mapping += expr -> newExpr + Some(newExpr) + } else { + None + } + + case _ => None + } (conj) + + val flatConj = implies(andJoin(quantifierEqualities.map { + case (arg, id) => Equals(arg, Variable(id)) + }), newConj) + + (quantified, flatConj) + } + def mkClauses(pathVar: Identifier, expr: Expr, substMap: Map[Identifier, T]): - (Map[Identifier,T], Map[Identifier,T], Map[Identifier, Seq[Expr]], Map[T, LambdaTemplate[T]], Seq[QuantificationTemplate[T]]) = { + (Map[Identifier,T], Map[Identifier,T], Map[Identifier, Seq[Expr]], Seq[LambdaTemplate[T]], Seq[QuantificationTemplate[T]]) = { var condVars = Map[Identifier, T]() @inline def storeCond(id: Identifier) : Unit = condVars += id -> encoder.encodeId(id) @@ -165,8 +205,8 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], @inline def registerQuantification(quantification: QuantificationTemplate[T]): Unit = quantifications :+= quantification - var lambdas = Map[T, LambdaTemplate[T]]() - @inline def registerLambda(idT: T, lambda: LambdaTemplate[T]) : Unit = lambdas += idT -> lambda + var lambdas = Seq[LambdaTemplate[T]]() + @inline def registerLambda(lambda: LambdaTemplate[T]) : Unit = lambdas :+= lambda def requireDecomposition(e: Expr) = { exists{ @@ -280,13 +320,12 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], val localSubst: Map[Identifier, T] = substMap ++ condVars ++ exprVars ++ lambdaVars val clauseSubst: Map[Identifier, T] = localSubst ++ (idArgs zip trArgs) val (lambdaConds, lambdaExprs, lambdaGuarded, lambdaTemplates, lambdaQuants) = mkClauses(pathVar, clause, clauseSubst) - assert(lambdaQuants.isEmpty, "Unhandled quantification in lambdas in " + l) val ids: (Identifier, T) = lid -> storeLambda(lid) val dependencies: Map[Identifier, T] = variablesOf(l).map(id => id -> localSubst(id)).toMap val template = LambdaTemplate(ids, encoder, manager, pathVar -> encodedCond(pathVar), - idArgs zip trArgs, lambdaConds, lambdaExprs, lambdaGuarded, lambdaTemplates, localSubst, dependencies, l) - registerLambda(ids._2, template) + idArgs zip trArgs, lambdaConds, lambdaExprs, lambdaGuarded, lambdaQuants, lambdaTemplates, localSubst, dependencies, l) + registerLambda(template) Variable(lid) @@ -295,7 +334,8 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], val conjunctQs = conjuncts.map { conjunct => val vars = variablesOf(conjunct) - val quantifiers = args.map(_.id).filter(vars).toSet + val inits = args.map(_.id).filter(vars).toSet + val (quantifiers, flatConj) = minimalFlattening(inits, conjunct) val idQuantifiers : Seq[Identifier] = quantifiers.toSeq val trQuantifiers : Seq[T] = idQuantifiers.map(encoder.encodeId) @@ -305,7 +345,7 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], val inst: Identifier = FreshIdentifier("inst", BooleanType, true) val guard: Identifier = FreshIdentifier("guard", BooleanType, true) - val clause = Equals(Variable(inst), Implies(Variable(guard), conjunct)) + val clause = Equals(Variable(inst), Implies(Variable(guard), flatConj)) val qs: (Identifier, T) = q -> encoder.encodeId(q) val localSubst: Map[Identifier, T] = substMap ++ condVars ++ exprVars ++ lambdaVars diff --git a/src/main/scala/leon/solvers/templates/TemplateInfo.scala b/src/main/scala/leon/solvers/templates/TemplateInfo.scala index e298e298a6f828c78dcf4da8de5177f94f16758b..977aeb5711b66c006161ff1af28fe5b9604456eb 100644 --- a/src/main/scala/leon/solvers/templates/TemplateInfo.scala +++ b/src/main/scala/leon/solvers/templates/TemplateInfo.scala @@ -14,6 +14,6 @@ case class TemplateCallInfo[T](tfd: TypedFunDef, args: Seq[T]) { case class TemplateAppInfo[T](template: LambdaTemplate[T], equals: T, args: Seq[T]) { override def toString = { - template.id + "|" + equals + args.mkString("(", ",", ")") + template.ids._1 + "|" + equals + args.mkString("(", ",", ")") } } diff --git a/src/main/scala/leon/solvers/templates/Templates.scala b/src/main/scala/leon/solvers/templates/Templates.scala index 32d273c3937d6ba4b808b79c16edf1ded4ade785..5e7302c549720a0291ecedf592c4a28d181b59fd 100644 --- a/src/main/scala/leon/solvers/templates/Templates.scala +++ b/src/main/scala/leon/solvers/templates/Templates.scala @@ -40,6 +40,12 @@ object Instantiation { def withClause(cl: T): Instantiation[T] = (i._1 :+ cl, i._2, i._3) def withClauses(cls: Seq[T]): Instantiation[T] = (i._1 ++ cls, i._2, i._3) + + def withCalls(calls: CallBlockers[T]): Instantiation[T] = (i._1, i._2 merge calls, i._3) + def withApps(apps: AppBlockers[T]): Instantiation[T] = (i._1, i._2, i._3 merge apps) + def withApp(app: ((T, App[T]), TemplateAppInfo[T])): Instantiation[T] = { + (i._1, i._2, i._3 merge Map(app._1 -> Set(app._2))) + } } } @@ -56,9 +62,9 @@ trait Template[T] { self => val clauses : Seq[T] val blockers : Map[T, Set[TemplateCallInfo[T]]] val applications : Map[T, Set[App[T]]] - val quantifications: Seq[QuantificationTemplate[T]] - val matchers: Map[T, Set[Matcher[T]]] - val lambdas : Map[T, LambdaTemplate[T]] + val quantifications : Seq[QuantificationTemplate[T]] + val matchers : Map[T, Set[Matcher[T]]] + val lambdas : Seq[LambdaTemplate[T]] private var substCache : Map[Seq[T],Map[T,T]] = Map.empty @@ -73,10 +79,13 @@ trait Template[T] { self => subst } - val lambdaSubstMap = lambdas.map { case (idT, lambda) => idT -> encoder.encodeId(lambda.id) } + val lambdaSubstMap = lambdas.map(lambda => lambda.ids._2 -> encoder.encodeId(lambda.ids._1)) val quantificationSubstMap = quantifications.map(q => q.qs._2 -> encoder.encodeId(q.qs._1)) val substMap : Map[T,T] = baseSubstMap ++ lambdaSubstMap ++ quantificationSubstMap + (start -> aVar) + instantiate(substMap) + } + protected def instantiate(substMap: Map[T, T]): Instantiation[T] = { Template.instantiate(encoder, manager, clauses, blockers, applications, quantifications, matchers, lambdas, substMap) } @@ -86,43 +95,6 @@ trait Template[T] { self => object Template { - private object InvocationExtractor { - private def flatInvocation(expr: Expr): Option[(TypedFunDef, Seq[Expr])] = expr match { - case fi @ FunctionInvocation(tfd, args) => Some((tfd, args)) - case Application(caller, args) => flatInvocation(caller) match { - case Some((tfd, prevArgs)) => Some((tfd, prevArgs ++ args)) - case None => None - } - case _ => None - } - - def unapply(expr: Expr): Option[(TypedFunDef, Seq[Expr])] = expr match { - case IsTyped(f: FunctionInvocation, ft: FunctionType) => None - case IsTyped(f: Application, ft: FunctionType) => None - case FunctionInvocation(tfd, args) => Some(tfd -> args) - case f: Application => flatInvocation(f) - case _ => None - } - } - - private object ApplicationExtractor { - private def flatApplication(expr: Expr): Option[(Expr, Seq[Expr])] = expr match { - case Application(fi: FunctionInvocation, _) => None - case Application(caller: Application, args) => flatApplication(caller) match { - case Some((c, prevArgs)) => Some((c, prevArgs ++ args)) - case None => None - } - case Application(caller, args) => Some((caller, args)) - case _ => None - } - - def unapply(expr: Expr): Option[(Expr, Seq[Expr])] = expr match { - case IsTyped(f: Application, ft: FunctionType) => None - case f: Application => flatApplication(f) - case _ => None - } - } - private def invocationMatcher[T](encodeExpr: Expr => T)(tfd: TypedFunDef, args: Seq[Expr]): Matcher[T] = { assert(tfd.returnType.isInstanceOf[FunctionType], "invocationMatcher() is only defined on function-typed defs") @@ -146,16 +118,14 @@ object Template { condVars: Map[Identifier, T], exprVars: Map[Identifier, T], guardedExprs: Map[Identifier, Seq[Expr]], - lambdas: Map[T, LambdaTemplate[T]], + lambdas: Seq[LambdaTemplate[T]], substMap: Map[Identifier, T] = Map.empty[Identifier, T], optCall: Option[TypedFunDef] = None, optApp: Option[(T, FunctionType)] = None ) : (Seq[T], Map[T, Set[TemplateCallInfo[T]]], Map[T, Set[App[T]]], Map[T, Set[Matcher[T]]], () => String) = { - val idToTrId : Map[Identifier, T] = { - condVars ++ exprVars + pathVar ++ arguments ++ substMap ++ - lambdas.map { case (idT, template) => template.id -> idT } - } + val idToTrId : Map[Identifier, T] = + condVars ++ exprVars + pathVar ++ arguments ++ substMap ++ lambdas.map(_.ids) val encodeExpr : Expr => T = encoder.encodeExpr(idToTrId) @@ -180,17 +150,10 @@ object Template { var matchInfos : Set[Matcher[T]] = Set.empty for (e <- es) { - funInfos ++= collect[TemplateCallInfo[T]] { - case InvocationExtractor(tfd, args) => - Set(TemplateCallInfo(tfd, args.map(encodeExpr))) - case _ => Set.empty - }(e) - - appInfos ++= collect[App[T]] { - case ApplicationExtractor(c, args) => - Set(App(encodeExpr(c), c.getType.asInstanceOf[FunctionType], args.map(encodeExpr))) - case _ => Set.empty - }(e) + funInfos ++= firstOrderCallsOf(e).map(p => TemplateCallInfo(p._1, p._2.map(encodeExpr))) + appInfos ++= firstOrderAppsOf(e).map { case (c, args) => + App(encodeExpr(c), c.getType.asInstanceOf[FunctionType], args.map(encodeExpr)) + } matchInfos ++= fold[Map[Expr, Matcher[T]]] { (expr, res) => val result = res.flatten.toMap @@ -247,7 +210,7 @@ object Template { " * Matchers :" + (if (matchers.isEmpty) "\n" else { "\n " + matchers.map(p => p._1 + " ==> " + p._2).mkString("\n ") + "\n" }) + - " * Lambdas :\n" + lambdas.map { case (_, template) => + " * Lambdas :\n" + lambdas.map { case template => " +> " + template.toString.split("\n").mkString("\n ") + "\n" }.mkString("\n") } @@ -263,7 +226,7 @@ object Template { applications: Map[T, Set[App[T]]], quantifications: Seq[QuantificationTemplate[T]], matchers: Map[T, Set[Matcher[T]]], - lambdas: Map[T, LambdaTemplate[T]], + lambdas: Seq[LambdaTemplate[T]], substMap: Map[T, T] ): Instantiation[T] = { @@ -276,10 +239,8 @@ object Template { var instantiation: Instantiation[T] = (newClauses, newBlockers, Map.empty) - for ((idT, lambda) <- lambdas) { - val newIdT = substituter(idT) - val newTemplate = lambda.substitute(substMap) - instantiation ++= manager.instantiateLambda(newIdT, newTemplate) + for (lambda <- lambdas) { + instantiation ++= manager.instantiateLambda(lambda.substitute(substituter)) } for ((b,apps) <- applications; bp = substituter(b); app <- apps) { @@ -292,7 +253,7 @@ object Template { } for (q <- quantifications) { - instantiation ++= q.instantiate(substMap) + instantiation ++= manager.instantiateQuantification(q, substMap) } instantiation @@ -311,7 +272,7 @@ object FunctionTemplate { exprVars: Map[Identifier, T], guardedExprs: Map[Identifier, Seq[Expr]], quantifications: Seq[QuantificationTemplate[T]], - lambdas: Map[T, LambdaTemplate[T]], + lambdas: Seq[LambdaTemplate[T]], isRealFunDef: Boolean ) : FunctionTemplate[T] = { @@ -359,7 +320,7 @@ class FunctionTemplate[T] private( val applications: Map[T, Set[App[T]]], val quantifications: Seq[QuantificationTemplate[T]], val matchers: Map[T, Set[Matcher[T]]], - val lambdas: Map[T, LambdaTemplate[T]], + val lambdas: Seq[LambdaTemplate[T]], isRealFunDef: Boolean, stringRepr: () => String) extends Template[T] { @@ -367,7 +328,7 @@ class FunctionTemplate[T] private( override def toString : String = str override def instantiate(aVar: T, args: Seq[T]): (Seq[T], Map[T, Set[TemplateCallInfo[T]]], Map[(T, App[T]), Set[TemplateAppInfo[T]]]) = { - if (!isRealFunDef) manager.registerFree(tfd.params.map(_.getType) zip args) + if (!isRealFunDef) manager.registerFree(tfd.params.map(_.id) zip args) super.instantiate(aVar, args) } } @@ -383,7 +344,8 @@ object LambdaTemplate { condVars: Map[Identifier, T], exprVars: Map[Identifier, T], guardedExprs: Map[Identifier, Seq[Expr]], - lambdas: Map[T, LambdaTemplate[T]], + quantifications: Seq[QuantificationTemplate[T]], + lambdas: Seq[LambdaTemplate[T]], baseSubstMap: Map[Identifier, T], dependencies: Map[Identifier, T], lambda: Lambda @@ -404,16 +366,17 @@ object LambdaTemplate { val key = structuralLambda.asInstanceOf[Lambda] new LambdaTemplate[T]( - ids._1, + ids, encoder, manager, pathVar._2, - arguments.map(_._2), + arguments, condVars, exprVars, clauses, blockers, applications, + quantifications, matchers, lambdas, keyDeps, @@ -424,30 +387,27 @@ object LambdaTemplate { } class LambdaTemplate[T] private ( - val id: Identifier, + val ids: (Identifier, T), val encoder: TemplateEncoder[T], val manager: QuantificationManager[T], val start: T, - val args: Seq[T], + val arguments: Seq[(Identifier, T)], val condVars: Map[Identifier, T], val exprVars: Map[Identifier, T], val clauses: Seq[T], val blockers: Map[T, Set[TemplateCallInfo[T]]], val applications: Map[T, Set[App[T]]], + val quantifications: Seq[QuantificationTemplate[T]], val matchers: Map[T, Set[Matcher[T]]], - val lambdas: Map[T, LambdaTemplate[T]], + val lambdas: Seq[LambdaTemplate[T]], private[templates] val dependencies: Map[Identifier, T], private[templates] val structuralKey: Lambda, stringRepr: () => String) extends Template[T] { - // Universal quantification is not allowed inside closure bodies! - val quantifications: Seq[QuantificationTemplate[T]] = Seq.empty - - val tpe = id.getType.asInstanceOf[FunctionType] - - def substitute(substMap: Map[T,T]): LambdaTemplate[T] = { - val substituter : T => T = encoder.substitute(substMap) + val args = arguments.map(_._2) + val tpe = ids._1.getType.asInstanceOf[FunctionType] + def substitute(substituter: T => T): LambdaTemplate[T] = { val newStart = substituter(start) val newClauses = clauses.map(substituter) val newBlockers = blockers.map { case (b, fis) => @@ -460,26 +420,29 @@ class LambdaTemplate[T] private ( bp -> fas.map(fa => fa.copy(caller = substituter(fa.caller), args = fa.args.map(substituter))) } + val newQuantifications = quantifications.map(_.substitute(substituter)) + val newMatchers = matchers.map { case (b, ms) => val bp = if (b == start) newStart else b bp -> ms.map(_.substitute(substituter)) } - val newLambdas = lambdas.map { case (idT, template) => idT -> template.substitute(substMap) } + val newLambdas = lambdas.map(_.substitute(substituter)) val newDependencies = dependencies.map(p => p._1 -> substituter(p._2)) new LambdaTemplate[T]( - id, + ids._1 -> substituter(ids._2), encoder, manager, newStart, - args, + arguments, condVars, exprVars, newClauses, newBlockers, newApplications, + newQuantifications, newMatchers, newLambdas, newDependencies, @@ -514,4 +477,8 @@ class LambdaTemplate[T] private ( Some(rec(structuralKey, that.structuralKey)) } } + + override def instantiate(substMap: Map[T, T]): Instantiation[T] = { + super.instantiate(substMap) ++ manager.instantiateAxiom(this, substMap) + } } diff --git a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala index c0249059a300d13b89b9f48b4ac287d1131c4744..547ec3f9bd3a1f490cc8cb97643531208faa1fd6 100644 --- a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala @@ -19,6 +19,9 @@ import purescala.Types._ import scala.collection.mutable.{Map => MutableMap} +case class UnsoundExtractionException(ast: Z3AST, msg: String) + extends Exception("Can't extract " + ast + " : " + msg) + // This is just to factor out the things that are common in "classes that deal // with a Z3 instance" trait AbstractZ3Solver extends Solver { @@ -34,6 +37,9 @@ trait AbstractZ3Solver extends Solver { private[this] var freed = false val traceE = new Exception() + protected def unsound(ast: Z3AST, msg: String): Nothing = + throw UnsoundExtractionException(ast, msg) + override def finalize() { if (!freed) { println("!! Solver "+this.getClass.getName+"["+this.hashCode+"] not freed properly prior to GC:") @@ -309,14 +315,21 @@ trait AbstractZ3Solver extends Solver { newAST } case v @ Variable(id) => z3Vars.get(id) match { - case Some(ast) => ast + case Some(ast) => + ast case None => { + variables.getB(v) match { + case Some(ast) => + ast + + case None => val newAST = z3.mkFreshConst(id.uniqueName, typeToSort(v.getType)) z3Vars = z3Vars + (id -> newAST) variables += (v -> newAST) newAST } } + } case ite @ IfExpr(c, t, e) => z3.mkITE(rec(c), rec(t), rec(e)) case And(exs) => z3.mkAnd(exs.map(rec): _*) @@ -562,10 +575,8 @@ trait AbstractZ3Solver extends Solver { case other => unsupported(other, "Unexpected target type for BV value") } - case None => { - throw LeonFatalError(s"Could not translate hexadecimal Z3 numeral $t") + case None => unsound(t, "could not translate hexadecimal Z3 numeral") } - } } else { tpe match { case Int32Type => IntLiteral(v) @@ -596,15 +607,11 @@ trait AbstractZ3Solver extends Solver { tpe match { case Int32Type => IntLiteral(hexa.toInt) case CharType => CharLiteral(hexa.toInt.toChar) - case IntegerType => InfiniteIntegerLiteral(BigInt(hexa.toInt)) - case _ => - reporter.fatalError("Unexpected target type for BV value: " + tpe.asString) + case _ => unsound(t, "unexpected target type for BV value: " + tpe.asString) } - case None => { - reporter.fatalError(s"Could not translate Z3NumeralIntAST numeral $t to type $tpe") + case None => unsound(t, "could not translate Z3NumeralIntAST numeral") } } - } case Z3NumeralRealAST(n: BigInt, d: BigInt) => FractionalLiteral(n, d) case Z3AppAST(decl, args) => val argsSize = args.size @@ -640,12 +647,12 @@ trait AbstractZ3Solver extends Solver { val entries = elems.map { case (IntLiteral(i), v) => i -> v - case _ => reporter.fatalError("Translation from Z3 to Array failed") + case (e,_) => unsupported(e, s"Z3 returned unexpected array index ${e.asString}") } finiteArray(entries, Some(default, s), to) - case _ => - reporter.fatalError("Translation from Z3 to Array failed") + case (s : IntLiteral, arr) => unsound(args(1), "invalid array type") + case (size, _) => unsound(args(0), "invalid array size") } } } else { @@ -659,7 +666,7 @@ trait AbstractZ3Solver extends Solver { } RawArrayValue(from, entries, default) - case None => reporter.fatalError("Translation from Z3 to Array failed") + case None => unsound(t, "invalid array AST") } case tp: TypeParameter => @@ -684,12 +691,16 @@ trait AbstractZ3Solver extends Solver { FiniteMap(elems, from, to) } - case FunctionType(fts, tt) => - rec(t, RawArrayType(tupleTypeWrap(fts), tt)) + case ft @ FunctionType(fts, tt) => + rec(t, RawArrayType(tupleTypeWrap(fts), tt)) match { + case r: RawArrayValue => + val elems = r.elems.toSeq.map { case (k, v) => unwrapTuple(k, fts.size) -> v } + PartialLambda(elems, Some(r.default), ft) + } case tpe @ SetType(dt) => model.getSetValue(t) match { - case None => reporter.fatalError("Translation from Z3 to set failed") + case None => unsound(t, "invalid set AST") case Some(set) => val elems = set.map(e => rec(e, dt)) FiniteSet(elems, dt) @@ -719,8 +730,7 @@ trait AbstractZ3Solver extends Solver { // case OpDiv => Division(rargs(0), rargs(1)) // case OpIDiv => Division(rargs(0), rargs(1)) // case OpMod => Modulo(rargs(0), rargs(1)) - case other => - reporter.fatalError( + case other => unsound(t, s"""|Don't know what to do with this declKind: $other |Expected type: ${Option(tpe).map{_.asString}.getOrElse("")} |Tree: $t @@ -729,8 +739,7 @@ trait AbstractZ3Solver extends Solver { } } } - case _ => - reporter.fatalError(s"Don't know what to do with this Z3 tree: $t") + case _ => unsound(t, "unexpected AST") } } rec(tree, tpe) @@ -741,6 +750,8 @@ trait AbstractZ3Solver extends Solver { Some(fromZ3Formula(model, tree, tpe)) } catch { case e: Unsupported => None + case e: UnsoundExtractionException => None + case n: java.lang.NumberFormatException => None } } diff --git a/src/main/scala/leon/solvers/z3/FairZ3Component.scala b/src/main/scala/leon/solvers/z3/FairZ3Component.scala index f321d516d66a4d43751519f06e9c65b3a0e4d126..256dcf38fc5dc99f31ce2f6fb30dc12e0caa5268 100644 --- a/src/main/scala/leon/solvers/z3/FairZ3Component.scala +++ b/src/main/scala/leon/solvers/z3/FairZ3Component.scala @@ -13,6 +13,7 @@ trait FairZ3Component extends LeonComponent { val optUseCodeGen = LeonFlagOptionDef("codegen", "Use compiled evaluator instead of interpreter", false) val optUnrollCores = LeonFlagOptionDef("unrollcores", "Use unsat-cores to drive unrolling while remaining fair", false) val optAssumePre = LeonFlagOptionDef("assumepre", "Assume precondition holds (pre && f(x) = body) when unfolding", false) + val optNoChecks = LeonFlagOptionDef("nochecks", "Disable counter-example check in presence of foralls" , false) override val definedOptions: Set[LeonOptionDef[Any]] = Set(optEvalGround, optCheckModels, optFeelingLucky, optUseCodeGen, optUnrollCores, optAssumePre) diff --git a/src/main/scala/leon/solvers/z3/FairZ3Solver.scala b/src/main/scala/leon/solvers/z3/FairZ3Solver.scala index 5baa9b9aa22713f1b4e86782570e124d48aac767..c129243d15b661d9e42d578447fdea6cb1aea3e0 100644 --- a/src/main/scala/leon/solvers/z3/FairZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/FairZ3Solver.scala @@ -26,7 +26,8 @@ class FairZ3Solver(val context: LeonContext, val program: Program) extends AbstractZ3Solver with Z3ModelReconstruction with FairZ3Component - with EvaluatingSolver { + with EvaluatingSolver + with QuantificationSolver { enclosing => @@ -36,6 +37,9 @@ class FairZ3Solver(val context: LeonContext, val program: Program) val evalGroundApps = context.findOptionOrDefault(optEvalGround) val unrollUnsatCores = context.findOptionOrDefault(optUnrollCores) val assumePreHolds = context.findOptionOrDefault(optAssumePre) + val disableChecks = context.findOptionOrDefault(optNoChecks) + + assert(!checkModels || !disableChecks, "Options \"checkmodels\" and \"nochecks\" are mutually exclusive") protected val errors = new IncrementalBijection[Unit, Boolean]() protected def hasError = errors.getB(()) contains true @@ -56,131 +60,45 @@ class FairZ3Solver(val context: LeonContext, val program: Program) toggleWarningMessages(true) private def extractModel(model: Z3Model, ids: Set[Identifier]): HenkinModel = { - val asMap = modelToMap(model, ids) - def extract(b: Z3AST, m: Matcher[Z3AST]): Set[Seq[Expr]] = { val QuantificationTypeMatcher(fromTypes, _) = m.tpe val optEnabler = model.evalAs[Boolean](b) if (optEnabler == Some(true)) { - // FIXME: Dirty hack! - // Unfortunately, blockers don't lead to a true decision tree where all - // unexecutable branches are false since we have - // b1 ==> ( b2 \/ b3) - // b1 ==> (!b2 \/ !b3) - // so b2 /\ b3 is possible when b1 is false. This leads to unsound models - // (like Nil.tail) which obviously cannot be part of a domain but can't be - // translated back from Z3 either. - try { - val optArgs = (m.args zip fromTypes).map { - p => softFromZ3Formula(model, model.eval(Matcher.argValue(p._1), true).get, p._2) - } + val optArgs = (m.args zip fromTypes).map { + p => softFromZ3Formula(model, model.eval(Matcher.argValue(p._1), true).get, p._2) + } - if (optArgs.forall(_.isDefined)) { - Set(optArgs.map(_.get)) - } else { - Set.empty - } - } catch { - case _: Throwable => Set.empty + if (optArgs.forall(_.isDefined)) { + Set(optArgs.map(_.get)) + } else { + Set.empty } } else { Set.empty } } - val funDomains = ids.flatMap(id => id.getType match { - case ft @ FunctionType(fromTypes, _) => variables.getB(id.toVariable) match { - case Some(z3ID) => Some(id -> templateGenerator.manager.instantiations(z3ID, ft).flatMap { - case (b, m) => extract(b, m) - }) - case _ => None - } - case _ => None - }).toMap.mapValues(_.toSet) - - val asDMap = asMap.map(p => funDomains.get(p._1) match { - case Some(domain) => - val mapping = domain.toSeq.map { es => - val ev: Expr = p._2 match { - case RawArrayValue(_, mapping, dflt) => - mapping.collectFirst { - case (k,v) if evaluator.eval(Equals(k, tupleWrap(es))).result == Some(BooleanLiteral(true)) => v - } getOrElse dflt - case _ => scala.sys.error("Unexpected function encoding " + p._2) - } - es -> ev - } - p._1 -> PartialLambda(mapping, p._1.getType.asInstanceOf[FunctionType]) - case None => p - }) - - val typeGrouped = templateGenerator.manager.instantiations.groupBy(_._2.tpe) - val typeDomains = typeGrouped.mapValues(_.flatMap { case (b, m) => extract(b, m) }.toSet) + val (typeInsts, partialInsts, lambdaInsts) = templateGenerator.manager.instantiations - val domain = new HenkinDomains(typeDomains) - new HenkinModel(asDMap, domain) - } - - private def validateModel(model: Z3Model, formula: Expr, variables: Set[Identifier], silenceErrors: Boolean) : (Boolean, HenkinModel) = { - if(!interrupted) { - - val functionsModel: Map[Z3FuncDecl, (Seq[(Seq[Z3AST], Z3AST)], Z3AST)] = model.getModelFuncInterpretations.map(i => (i._1, (i._2, i._3))).toMap - val functionsAsMap: Map[Identifier, Expr] = functionsModel.flatMap(p => { - if (functions containsB p._1) { - val tfd = functions.toA(p._1) - if (!tfd.hasImplementation) { - val (cses, default) = p._2 - val ite = cses.foldLeft(fromZ3Formula(model, default, tfd.returnType))((expr, q) => IfExpr( - andJoin( - q._1.zip(tfd.params).map(a12 => Equals(fromZ3Formula(model, a12._1, a12._2.getType), Variable(a12._2.id))) - ), - fromZ3Formula(model, q._2, tfd.returnType), - expr)) - Seq((tfd.id, ite)) - } else Seq() - } else Seq() - }) - - val constantFunctionsAsMap: Map[Identifier, Expr] = model.getModelConstantInterpretations.flatMap(p => { - if(functions containsB p._1) { - val tfd = functions.toA(p._1) - if(!tfd.hasImplementation) { - Seq((tfd.id, fromZ3Formula(model, p._2, tfd.returnType))) - } else Seq() - } else Seq() - }).toMap - - val leonModel = extractModel(model, variables) - val fullModel = leonModel ++ (functionsAsMap ++ constantFunctionsAsMap) - val evalResult = evaluator.eval(formula, fullModel) - - evalResult match { - case EvaluationResults.Successful(BooleanLiteral(true)) => - reporter.debug("- Model validated.") - (true, fullModel) - - case EvaluationResults.Successful(res) => - assert(res == BooleanLiteral(false), "Checking model returned non-boolean") - reporter.debug("- Invalid model.") - (false, fullModel) - - case EvaluationResults.RuntimeError(msg) => - reporter.debug("- Model leads to runtime error.") - (false, fullModel) - - case EvaluationResults.EvaluatorError(msg) => - if (silenceErrors) { - reporter.debug("- Model leads to evaluator error: " + msg) - } else { - reporter.warning("Something went wrong. While evaluating the model, we got this : " + msg) - } - (false, fullModel) + val typeDomains: Map[TypeTree, Set[Seq[Expr]]] = typeInsts.map { + case (tpe, domain) => tpe -> domain.flatMap { case (b, m) => extract(b, m) }.toSet + } + val funDomains: Map[Identifier, Set[Seq[Expr]]] = partialInsts.flatMap { + case (c, domain) => variables.getA(c).collect { + case Variable(id) => id -> domain.flatMap { case (b, m) => extract(b, m) }.toSet } - } else { - (false, HenkinModel.empty) } + + val lambdaDomains: Map[Lambda, Set[Seq[Expr]]] = lambdaInsts.map { + case (l, domain) => l -> domain.flatMap { case (b, m) => extract(b, m) }.toSet + } + + val asMap = modelToMap(model, ids) + val asDMap = purescala.Quantification.extractModel(asMap, funDomains, typeDomains, evaluator) + val domains = new HenkinDomains(lambdaDomains, typeDomains) + new HenkinModel(asDMap, domains) } implicit val z3Printable = (z3: Z3AST) => new Printable { @@ -315,6 +233,115 @@ class FairZ3Solver(val context: LeonContext, val program: Program) }).toSet } + def validatedModel(silenceErrors: Boolean) : (Boolean, HenkinModel) = { + if (interrupted) { + (false, HenkinModel.empty) + } else { + val lastModel = solver.getModel + val clauses = templateGenerator.manager.checkClauses + val optModel = if (clauses.isEmpty) Some(lastModel) else { + solver.push() + for (clause <- clauses) { + solver.assertCnstr(clause) + } + + reporter.debug(" - Verifying model transitivity") + val timer = context.timers.solvers.z3.check.start() + solver.push() // FIXME: remove when z3 bug is fixed + val res = solver.checkAssumptions((assumptionsAsZ3 ++ unrollingBank.satisfactionAssumptions) :_*) + solver.pop() // FIXME: remove when z3 bug is fixed + timer.stop() + + val solverModel = res match { + case Some(true) => + Some(solver.getModel) + + case Some(false) => + val msg = "- Transitivity independence not guaranteed for model" + if (silenceErrors) { + reporter.debug(msg) + } else { + reporter.warning(msg) + } + None + + case None => + val msg = "- Unknown for transitivity independence!?" + if (silenceErrors) { + reporter.debug(msg) + } else { + reporter.warning(msg) + } + None + } + + solver.pop() + solverModel + } + + val model = optModel getOrElse lastModel + + val functionsModel: Map[Z3FuncDecl, (Seq[(Seq[Z3AST], Z3AST)], Z3AST)] = model.getModelFuncInterpretations.map(i => (i._1, (i._2, i._3))).toMap + val functionsAsMap: Map[Identifier, Expr] = functionsModel.flatMap(p => { + if (functions containsB p._1) { + val tfd = functions.toA(p._1) + if (!tfd.hasImplementation) { + val (cses, default) = p._2 + val ite = cses.foldLeft(fromZ3Formula(model, default, tfd.returnType))((expr, q) => IfExpr( + andJoin( + q._1.zip(tfd.params).map(a12 => Equals(fromZ3Formula(model, a12._1, a12._2.getType), Variable(a12._2.id))) + ), + fromZ3Formula(model, q._2, tfd.returnType), + expr)) + Seq((tfd.id, ite)) + } else Seq() + } else Seq() + }) + + val constantFunctionsAsMap: Map[Identifier, Expr] = model.getModelConstantInterpretations.flatMap(p => { + if(functions containsB p._1) { + val tfd = functions.toA(p._1) + if(!tfd.hasImplementation) { + Seq((tfd.id, fromZ3Formula(model, p._2, tfd.returnType))) + } else Seq() + } else Seq() + }).toMap + + val leonModel = extractModel(model, freeVars.toSet) + val fullModel = leonModel ++ (functionsAsMap ++ constantFunctionsAsMap) + + if (!optModel.isDefined) { + (false, leonModel) + } else { + (evaluator.check(entireFormula, fullModel) match { + case EvaluationResults.CheckSuccess => + reporter.debug("- Model validated.") + true + + case EvaluationResults.CheckValidityFailure => + reporter.debug("- Invalid model.") + false + + case EvaluationResults.CheckRuntimeFailure(msg) => + if (silenceErrors) { + reporter.debug("- Model leads to evaluation error: " + msg) + } else { + reporter.warning("- Model leads to evaluation error: " + msg) + } + false + + case EvaluationResults.CheckQuantificationFailure(msg) => + if (silenceErrors) { + reporter.debug("- Model leads to quantification error: " + msg) + } else { + reporter.warning("- Model leads to quantification error: " + msg) + } + false + }, leonModel) + } + } + } + while(!foundDefinitiveAnswer && !interrupted) { //val blockingSetAsZ3 : Seq[Z3AST] = blockingSet.toSeq.map(toZ3Formula(_).get) @@ -346,27 +373,18 @@ class FairZ3Solver(val context: LeonContext, val program: Program) foundAnswer(None) case Some(true) => // SAT - - val z3model = solver.getModel() - - if (this.checkModels) { - val (isValid, model) = validateModel(z3model, entireFormula, allVars, silenceErrors = false) - - if (isValid) { - foundAnswer(Some(true), model) - } else { - reporter.error("Something went wrong. The model should have been valid, yet we got this : ") - reporter.error(model) - foundAnswer(None, model) - } + val (valid, model) = if (!this.disableChecks && (this.checkModels || requireQuantification)) { + validatedModel(false) } else { - val model = extractModel(z3model, allVars) - - //lazy val modelAsString = model.toList.map(p => p._1 + " -> " + p._2).mkString("\n") - //reporter.debug("- Found a model:") - //reporter.debug(modelAsString) + true -> extractModel(solver.getModel, allVars) + } + if (valid) { foundAnswer(Some(true), model) + } else { + reporter.error("Something went wrong. The model should have been valid, yet we got this : ") + reporter.error(model.asString(context)) + foundAnswer(None, model) } case Some(false) if !unrollingBank.canUnroll => @@ -433,7 +451,7 @@ class FairZ3Solver(val context: LeonContext, val program: Program) //reporter.debug("SAT WITHOUT Blockers") if (this.feelingLucky && !interrupted) { // we might have been lucky :D - val (wereWeLucky, cleanModel) = validateModel(solver.getModel, entireFormula, allVars, silenceErrors = true) + val (wereWeLucky, cleanModel) = validatedModel(true) if(wereWeLucky) { foundAnswer(Some(true), cleanModel) diff --git a/src/main/scala/leon/synthesis/ConversionPhase.scala b/src/main/scala/leon/synthesis/ConversionPhase.scala index 8a9035434b4b719967bdf800481c60738a29666e..d172c27204e8fac1219ad4041221ad874f4b9ba2 100644 --- a/src/main/scala/leon/synthesis/ConversionPhase.scala +++ b/src/main/scala/leon/synthesis/ConversionPhase.scala @@ -75,6 +75,19 @@ object ConversionPhase extends UnitPhase[Program] { * require(..a..) * choose(x => post(x)) * } + * (in practice, there will be no pre-and postcondition) + * + * 4) Functions that have only a choose as body gets their spec from the choose. + * + * def foo(a: T) = { + * choose(x => post(a, x)) + * } + * + * gets converted to: + * + * def foo(a: T) = { + * choose(x => post(a, x)) + * } ensuring { x => post(a, x) } * * (in practice, there will be no pre-and postcondition) */ @@ -116,7 +129,7 @@ object ConversionPhase extends UnitPhase[Program] { } } - body match { + val fullBody = body match { case Some(body) => var holes = List[Identifier]() @@ -173,6 +186,14 @@ object ConversionPhase extends UnitPhase[Program] { val newPost = post getOrElse Lambda(Seq(ValDef(FreshIdentifier("res", e.getType))), BooleanLiteral(true)) withPrecondition(Choose(newPost), pre) } + + // extract spec from chooses at the top-level + fullBody match { + case Choose(spec) => + withPostcondition(fullBody, Some(spec)) + case _ => + fullBody + } } diff --git a/src/main/scala/leon/synthesis/ExamplesBank.scala b/src/main/scala/leon/synthesis/ExamplesBank.scala index 5b00fd2900f44c07b8529d720aa7371ada9cfa09..265b35a5e5136bdcbfbd6e458350580c42bae9e3 100644 --- a/src/main/scala/leon/synthesis/ExamplesBank.scala +++ b/src/main/scala/leon/synthesis/ExamplesBank.scala @@ -4,7 +4,6 @@ package synthesis import purescala.Definitions._ import purescala.Expressions._ import purescala.Constructors._ -import evaluators._ import purescala.Common._ import repair._ import leon.utils.ASCIIHelpers._ @@ -178,9 +177,7 @@ case class QualifiedExamplesBank(as: List[Identifier], xs: List[Identifier], eb: /** Filter inputs throught expr which is an expression evaluating to a boolean */ def filterIns(expr: Expr): ExamplesBank = { - val ev = new DefaultEvaluator(hctx.sctx.context, hctx.sctx.program) - - filterIns(m => ev.eval(expr, m).result == Some(BooleanLiteral(true))) + filterIns(m => hctx.sctx.defaultEvaluator.eval(expr, m).result == Some(BooleanLiteral(true))) } /** Filters inputs through the predicate pred, with an assignment of input variables to expressions. */ diff --git a/src/main/scala/leon/synthesis/SynthesisContext.scala b/src/main/scala/leon/synthesis/SynthesisContext.scala index bb90d2a268411c44e9dfdbde013886b7db5b56d9..b01077ce751f6b53e397b15eb3d5126e560d1b56 100644 --- a/src/main/scala/leon/synthesis/SynthesisContext.scala +++ b/src/main/scala/leon/synthesis/SynthesisContext.scala @@ -6,6 +6,7 @@ package synthesis import solvers._ import solvers.combinators._ import purescala.Definitions.{Program, FunDef} +import evaluators._ /** * This is global information per entire search, contains necessary information @@ -22,6 +23,10 @@ case class SynthesisContext( val rules = settings.rules val solverFactory = SolverFactory.getFromSettings(context, program) + + lazy val defaultEvaluator = { + new DefaultEvaluator(context, program) + } } object SynthesisContext { diff --git a/src/main/scala/leon/synthesis/rules/ADTSplit.scala b/src/main/scala/leon/synthesis/rules/ADTSplit.scala index ff53916a9a84257972dc3b374db2ec4e42db6838..dde6ce913dea17b33d1618a6dd9fe4391591468e 100644 --- a/src/main/scala/leon/synthesis/rules/ADTSplit.scala +++ b/src/main/scala/leon/synthesis/rules/ADTSplit.scala @@ -11,34 +11,45 @@ import purescala.ExprOps._ import purescala.Extractors._ import purescala.Constructors._ import purescala.Definitions._ -import solvers._ /** Abstract data type split. If a variable is typed as an abstract data type, then * it will create a match case statement on all known subtypes. */ case object ADTSplit extends Rule("ADT Split.") { def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { - val solver = SimpleSolverAPI(hctx.sctx.solverFactory.withTimeout(200L)) + // We approximate knowledge of types based on facts found at the top-level + // we don't care if the variables are known to be equal or not, we just + // don't want to split on two variables for which only one split + // alternative is viable. This should be much less expensive than making + // calls to a solver for each pair. + var facts = Map[Identifier, CaseClassType]() + + def addFacts(e: Expr): Unit = e match { + case Equals(Variable(a), CaseClass(cct, _)) => facts += a -> cct + case IsInstanceOf(Variable(a), cct: CaseClassType) => facts += a -> cct + case _ => + } + + val TopLevelAnds(as) = and(p.pc, p.phi) + for (e <- as) { + addFacts(e) + } val candidates = p.as.collect { case IsTyped(id, act @ AbstractClassType(cd, tpes)) => - val optCases = for (dcd <- cd.knownDescendants.sortBy(_.id.name)) yield dcd match { + val optCases = cd.knownDescendants.sortBy(_.id.name).collect { case ccd : CaseClassDef => val cct = CaseClassType(ccd, tpes) - val toSat = and(p.pc, IsInstanceOf(Variable(id), cct)) - val isImplied = solver.solveSAT(toSat) match { - case (Some(false), _) => true - case _ => false - } - - if (!isImplied) { - Some(ccd) + if (facts contains id) { + if (cct == facts(id)) { + Seq(ccd) + } else { + Nil + } } else { - None + Seq(ccd) } - case _ => - None } val cases = optCases.flatten diff --git a/src/main/scala/leon/synthesis/rules/CEGISLike.scala b/src/main/scala/leon/synthesis/rules/CEGISLike.scala index 1bb7ed94851b7c1271c53cfa4bc92850ba583e63..5d06aa76a3c1186e806edf605431643e6b2a9359 100644 --- a/src/main/scala/leon/synthesis/rules/CEGISLike.scala +++ b/src/main/scala/leon/synthesis/rules/CEGISLike.scala @@ -42,14 +42,13 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { val sctx = hctx.sctx val ctx = sctx.context - // CEGIS Flags to activate or deactivate features val useOptTimeout = sctx.settings.cegisUseOptTimeout.getOrElse(true) val useVanuatoo = sctx.settings.cegisUseVanuatoo.getOrElse(false) - val useShrink = sctx.settings.cegisUseShrink.getOrElse(true) + val useShrink = sctx.settings.cegisUseShrink.getOrElse(false) // Limits the number of programs CEGIS will specifically validate individually - val validateUpTo = 5 + val validateUpTo = 3 // Shrink the program when the ratio of passing cases is less than the threshold val shrinkThreshold = 1.0/2 @@ -313,6 +312,8 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { private val cTreeFd = new FunDef(FreshIdentifier("cTree", alwaysShowUniqueID = true), Seq(), p.as.map(id => ValDef(id)), p.outType) + private val solFd = new FunDef(FreshIdentifier("solFd", alwaysShowUniqueID = true), Seq(), p.as.map(id => ValDef(id)), p.outType) + private val phiFd = new FunDef(FreshIdentifier("phiFd", alwaysShowUniqueID = true), Seq(), p.as.map(id => ValDef(id)), BooleanType) @@ -320,13 +321,20 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { val outerSolution = { new PartialSolution(hctx.search.g, true) - .solutionAround(hctx.currentNode)(FunctionInvocation(cTreeFd.typed, p.as.map(_.toVariable))) + .solutionAround(hctx.currentNode)(FunctionInvocation(solFd.typed, p.as.map(_.toVariable))) .getOrElse(ctx.reporter.fatalError("Unable to get outer solution")) } - val program0 = addFunDefs(sctx.program, Seq(cTreeFd, phiFd) ++ outerSolution.defs, hctx.ci.fd) + val program0 = addFunDefs(sctx.program, Seq(cTreeFd, solFd, phiFd) ++ outerSolution.defs, hctx.ci.fd) cTreeFd.body = None + + solFd.fullBody = Ensuring( + FunctionInvocation(cTreeFd.typed, p.as.map(_.toVariable)), + Lambda(p.xs.map(ValDef(_)), p.phi) + ) + + phiFd.body = Some( letTuple(p.xs, FunctionInvocation(cTreeFd.typed, p.as.map(_.toVariable)), @@ -348,7 +356,7 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { // We freshen/duplicate every functions, except these two as they are // fresh anyway and we refer to them directly. - case `cTreeFd` | `phiFd` => + case `cTreeFd` | `phiFd` | `solFd` => None case fd => @@ -453,13 +461,13 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { //println("Solving for: "+cnstr.asString) - val solverf = SolverFactory.default(ctx, innerProgram).withTimeout(cexSolverTo) + val solverf = SolverFactory.getFromSettings(ctx, innerProgram).withTimeout(cexSolverTo) val solver = solverf.getNewSolver() try { solver.assertCnstr(cnstr) solver.check match { case Some(true) => - excludeProgram(bs) + excludeProgram(bs, true) val model = solver.getModel //println("Found counter example: ") //for ((s, v) <- model) { @@ -498,8 +506,24 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { var excludedPrograms = ArrayBuffer[Set[Identifier]]() // Explicitly remove program computed by bValues from the search space - def excludeProgram(bValues: Set[Identifier]): Unit = { - val bvs = bValues.filter(isBActive) + // + // If the bValues comes from models, we make sure the bValues we exclude + // are minimal we make sure we exclude only Bs that are used. + def excludeProgram(bValues: Set[Identifier], isMinimal: Boolean): Unit = { + val bs = bValues.filter(isBActive) + + def filterBTree(c: Identifier): Set[Identifier] = { + (for ((b, _, subcs) <- cTree(c) if bValues(b)) yield { + Set(b) ++ subcs.flatMap(filterBTree) + }).toSet.flatten + } + + val bvs = if (isMinimal) { + bs + } else { + rootCs.flatMap(filterBTree).toSet + } + excludedPrograms += bvs } @@ -512,19 +536,21 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { * First phase of CEGIS: solve for potential programs (that work on at least one input) */ def solveForTentativeProgram(): Option[Option[Set[Identifier]]] = { - val solverf = SolverFactory.default(ctx, programCTree).withTimeout(exSolverTo) + val solverf = SolverFactory.getFromSettings(ctx, programCTree).withTimeout(exSolverTo) val solver = solverf.getNewSolver() val cnstr = FunctionInvocation(phiFd.typed, phiFd.params.map(_.id.toVariable)) - //debugCExpr(cTree) - //println(" --- PhiFD ---") - //println(phiFd.fullBody.asString(ctx)) + //println("Program: ") + //println("-"*80) + //println(programCTree.asString) val toFind = and(innerPc, cnstr) //println(" --- Constraints ---") - //println(" - "+toFind) + //println(" - "+toFind.asString) try { - solver.assertCnstr(andJoin(bsOrdered.map(b => if (bs(b)) b.toVariable else Not(b.toVariable)))) + //TODO: WHAT THE F IS THIS? + //val bsOrNotBs = andJoin(bsOrdered.map(b => if (bs(b)) b.toVariable else Not(b.toVariable))) + //solver.assertCnstr(bsOrNotBs) solver.assertCnstr(toFind) for ((c, alts) <- cTree) { @@ -535,26 +561,25 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { } if (activeBs.nonEmpty) { - //println(" - "+andJoin(either)) + //println(" - "+andJoin(either).asString) solver.assertCnstr(andJoin(either)) val oneOf = orJoin(activeBs.map(_.toVariable)) - //println(" - "+oneOf) + //println(" - "+oneOf.asString) solver.assertCnstr(oneOf) } } - //println(" -- Excluded:") //println(" -- Active:") val isActive = andJoin(bsOrdered.filterNot(isBActive).map(id => Not(id.toVariable))) - //println(" - "+isActive) + //println(" - "+isActive.asString) solver.assertCnstr(isActive) //println(" -- Excluded:") for (ex <- excludedPrograms) { val notThisProgram = Not(andJoin(ex.map(_.toVariable).toSeq)) - //println(f" - $notThisProgram%-40s ("+getExpr(ex)+")") + //println(f" - ${notThisProgram.asString}%-40s ("+getExpr(ex)+")") solver.assertCnstr(notThisProgram) } @@ -564,7 +589,10 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { val bModel = bs.filter(b => model.get(b).contains(BooleanLiteral(true))) + //println("Tentative model: "+model.asString) + //println("Tentative model: "+bModel.filter(isBActive).map(_.asString).toSeq.sorted) //println("Tentative expr: "+getExpr(bModel)) + Some(Some(bModel)) case Some(false) => @@ -590,7 +618,7 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { * Second phase of CEGIS: verify a given program by looking for CEX inputs */ def solveForCounterExample(bs: Set[Identifier]): Option[Option[Seq[Expr]]] = { - val solverf = SolverFactory.default(ctx, programCTree).withTimeout(cexSolverTo) + val solverf = SolverFactory.getFromSettings(ctx, programCTree).withTimeout(cexSolverTo) val solver = solverf.getNewSolver() val cnstr = FunctionInvocation(phiFd.typed, phiFd.params.map(_.id.toVariable)) @@ -600,6 +628,14 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { solver.assertCnstr(innerPc) solver.assertCnstr(Not(cnstr)) + //println("*"*80) + //println(Not(cnstr)) + //println(innerPc) + //println("*"*80) + //println(programCTree.asString) + //println("*"*80) + //Console.in.read() + solver.check match { case Some(true) => val model = solver.getModel @@ -684,17 +720,6 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { */ val nTests = if (p.pc == BooleanLiteral(true)) 50 else 20 - /* - val inputGenerator: FreeableIterator[Example] = { - val sff = { - (ctx: LeonContext, pgm: Program) => - SolverFactory.default(ctx, pgm).withTimeout(exSolverTo) - } - new SolverDataGen(sctx.context, sctx.program, sff).generateFor(p.as, p.pc, nTests, nTests).map { - InExample(_) - } - } */ - val inputGenerator: Iterator[Example] = if (useVanuatoo) { new VanuatooDataGen(sctx.context, sctx.program).generateFor(p.as, p.pc, nTests, 3000).map(InExample) } else { @@ -812,15 +837,15 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { } if (doFilter && !(nPassing < nInitial * shrinkThreshold && useShrink)) { + sctx.reporter.debug("Excluding "+wrongPrograms.size+" programs") wrongPrograms.foreach { - ndProgram.excludeProgram + ndProgram.excludeProgram(_, true) } } } // CEGIS Loop at a given unfolding level while (result.isEmpty && !skipCESearch && !interruptManager.isInterrupted) { - ndProgram.solveForTentativeProgram() match { case Some(Some(bs)) => // Should we validate this program with Z3? @@ -832,18 +857,21 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { // make sure by validating this candidate with z3 true } else { + println("testing failed ?!") // One valid input failed with this candidate, we can skip - ndProgram.excludeProgram(bs) + ndProgram.excludeProgram(bs, false) false } } else { // No inputs or capability to test, we need to ask Z3 true } + sctx.reporter.debug("Found tentative model (Validate="+validateWithZ3+")!") if (validateWithZ3) { ndProgram.solveForCounterExample(bs) match { case Some(Some(inputsCE)) => + sctx.reporter.debug("Found counter-example:"+inputsCE) val ce = InExample(inputsCE) // Found counter example! baseExampleInputs += ce @@ -852,7 +880,7 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { if (prunedPrograms.forall(p => !ndProgram.testForProgram(p)(ce))) { skipCESearch = true } else { - ndProgram.excludeProgram(bs) + ndProgram.excludeProgram(bs, false) } case Some(None) => @@ -862,6 +890,7 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { case None => // We are not sure + sctx.reporter.debug("Unknown") if (useOptTimeout) { // Interpret timeout in CE search as "the candidate is valid" sctx.reporter.info("CEGIS could not prove the validity of the resulting expression") diff --git a/src/main/scala/leon/synthesis/rules/EqualitySplit.scala b/src/main/scala/leon/synthesis/rules/EqualitySplit.scala index 78e30502c2b52bc47bc915bfe72f35c61ffb4e15..79595656c4c53cca6e47747ac34a52f436697b31 100644 --- a/src/main/scala/leon/synthesis/rules/EqualitySplit.scala +++ b/src/main/scala/leon/synthesis/rules/EqualitySplit.scala @@ -6,6 +6,7 @@ package rules import leon.purescala.Common.Identifier import purescala.Expressions._ +import purescala.Extractors._ import purescala.Constructors._ import solvers._ @@ -16,31 +17,31 @@ import scala.concurrent.duration._ * checks equality and output an If-Then-Else statement with the two new branches. */ case object EqualitySplit extends Rule("Eq. Split") { def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { - val solver = SimpleSolverAPI(hctx.sctx.solverFactory.withTimeout(50.millis)) + // We approximate knowledge of equality based on facts found at the top-level + // we don't care if the variables are known to be equal or not, we just + // don't want to split on two variables for which only one split + // alternative is viable. This should be much less expensive than making + // calls to a solver for each pair. + var facts = Set[Set[Identifier]]() - val candidates = p.as.groupBy(_.getType).mapValues(_.combinations(2).filter { - case List(a1, a2) => - val toValEQ = implies(p.pc, Equals(Variable(a1), Variable(a2))) - - val impliesEQ = solver.solveSAT(Not(toValEQ)) match { - case (Some(false), _) => true - case _ => false - } - - if (!impliesEQ) { - val toValNE = implies(p.pc, not(Equals(Variable(a1), Variable(a2)))) + def addFacts(e: Expr): Unit = e match { + case Not(e) => addFacts(e) + case LessThan(Variable(a), Variable(b)) => facts += Set(a,b) + case LessEquals(Variable(a), Variable(b)) => facts += Set(a,b) + case GreaterThan(Variable(a), Variable(b)) => facts += Set(a,b) + case GreaterEquals(Variable(a), Variable(b)) => facts += Set(a,b) + case Equals(Variable(a), Variable(b)) => facts += Set(a,b) + case _ => + } - val impliesNE = solver.solveSAT(Not(toValNE)) match { - case (Some(false), _) => true - case _ => false - } + val TopLevelAnds(as) = and(p.pc, p.phi) + for (e <- as) { + addFacts(e) + } - !impliesNE - } else { - false - } - case _ => false - }).values.flatten + val candidates = p.as.groupBy(_.getType).mapValues{ as => + as.combinations(2).filterNot(facts contains _.toSet) + }.values.flatten candidates.flatMap { case List(a1, a2) => diff --git a/src/main/scala/leon/synthesis/rules/InequalitySplit.scala b/src/main/scala/leon/synthesis/rules/InequalitySplit.scala index 292e99de80d4cf7eb3351621edde3587645b0402..8b728930a10ad7d73c121762bde43637ce988fbc 100644 --- a/src/main/scala/leon/synthesis/rules/InequalitySplit.scala +++ b/src/main/scala/leon/synthesis/rules/InequalitySplit.scala @@ -7,64 +7,42 @@ package rules import purescala.Expressions._ import purescala.Types._ import purescala.Constructors._ - -import solvers._ +import purescala.Extractors._ +import purescala.Common._ import scala.concurrent.duration._ case object InequalitySplit extends Rule("Ineq. Split.") { def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { - val solver = SimpleSolverAPI(hctx.sctx.solverFactory.withTimeout(50.millis)) + // We approximate knowledge of equality based on facts found at the top-level + // we don't care if the variables are known to be equal or not, we just + // don't want to split on two variables for which only one split + // alternative is viable. This should be much less expensive than making + // calls to a solver for each pair. + var facts = Set[Set[Identifier]]() + + def addFacts(e: Expr): Unit = e match { + case Not(e) => addFacts(e) + case LessThan(Variable(a), Variable(b)) => facts += Set(a,b) + case LessEquals(Variable(a), Variable(b)) => facts += Set(a,b) + case GreaterThan(Variable(a), Variable(b)) => facts += Set(a,b) + case GreaterEquals(Variable(a), Variable(b)) => facts += Set(a,b) + case Equals(Variable(a), Variable(b)) => facts += Set(a,b) + case _ => + } + + val TopLevelAnds(as) = and(p.pc, p.phi) + for (e <- as) { + addFacts(e) + } val argsPairs = p.as.filter(_.getType == IntegerType).combinations(2) ++ p.as.filter(_.getType == Int32Type).combinations(2) - val candidates = argsPairs.toList.filter { - case List(a1, a2) => - val toValLT = implies(p.pc, LessThan(Variable(a1), Variable(a2))) - - val impliesLT = solver.solveSAT(not(toValLT)) match { - case (Some(false), _) => true - case _ => false - } - - if (!impliesLT) { - val toValGT = implies(p.pc, GreaterThan(Variable(a1), Variable(a2))) - - val impliesGT = solver.solveSAT(not(toValGT)) match { - case (Some(false), _) => true - case _ => false - } + val candidates = argsPairs.toList.filter { case List(a1, a2) => !(facts contains Set(a1, a2)) } - if (!impliesGT) { - val toValEQ = implies(p.pc, Equals(Variable(a1), Variable(a2))) - - val impliesEQ = solver.solveSAT(not(toValEQ)) match { - case (Some(false), _) => true - case _ => false - } - - !impliesEQ - } else { - false - } - } else { - false - } - case _ => false - } - - - candidates.flatMap { + candidates.collect { case List(a1, a2) => - - val subLT = p.copy(pc = and(LessThan(Variable(a1), Variable(a2)), p.pc), - eb = p.qeb.filterIns(LessThan(Variable(a1), Variable(a2)))) - val subEQ = p.copy(pc = and(Equals(Variable(a1), Variable(a2)), p.pc), - eb = p.qeb.filterIns(Equals(Variable(a1), Variable(a2)))) - val subGT = p.copy(pc = and(GreaterThan(Variable(a1), Variable(a2)), p.pc), - eb = p.qeb.filterIns(GreaterThan(Variable(a1), Variable(a2)))) - val onSuccess: List[Solution] => Option[Solution] = { case sols@List(sLT, sEQ, sGT) => val pre = orJoin(sols.map(_.pre)) @@ -85,9 +63,24 @@ case object InequalitySplit extends Rule("Ineq. Split.") { None } - Some(decomp(List(subLT, subEQ, subGT), onSuccess, s"Ineq. Split on '$a1' and '$a2'")) - case _ => - None + val subTypes = List(p.outType, p.outType, p.outType) + + new RuleInstantiation(s"Ineq. Split on '$a1' and '$a2'", + SolutionBuilderDecomp(subTypes, onSuccess)) { + + def apply(hctx: SearchContext) = { + implicit val _ = hctx + + val subLT = p.copy(pc = and(LessThan(Variable(a1), Variable(a2)), p.pc), + eb = p.qeb.filterIns(LessThan(Variable(a1), Variable(a2)))) + val subEQ = p.copy(pc = and(Equals(Variable(a1), Variable(a2)), p.pc), + eb = p.qeb.filterIns(Equals(Variable(a1), Variable(a2)))) + val subGT = p.copy(pc = and(GreaterThan(Variable(a1), Variable(a2)), p.pc), + eb = p.qeb.filterIns(GreaterThan(Variable(a1), Variable(a2)))) + + RuleExpanded(List(subLT, subEQ, subGT)) + } + } } } } diff --git a/src/main/scala/leon/utils/PreprocessingPhase.scala b/src/main/scala/leon/utils/PreprocessingPhase.scala index 87d1003150995eb236e5f3abb738450516c5b7f7..7bfeece2f6f35e88c4557db2ecafe4c684a1198d 100644 --- a/src/main/scala/leon/utils/PreprocessingPhase.scala +++ b/src/main/scala/leon/utils/PreprocessingPhase.scala @@ -5,7 +5,6 @@ package utils import leon.purescala._ import leon.purescala.Definitions.Program -import leon.purescala.Quantification.CheckForalls import leon.solvers.isabelle.AdaptationPhase import leon.verification.InjectAsserts import leon.xlang.{NoXLangFeaturesChecking, XLangDesugaringPhase} @@ -39,8 +38,7 @@ class PreprocessingPhase(desugarXLang: Boolean = false) extends LeonPhase[Progra synthesis.ConversionPhase andThen CheckADTFieldsTypes andThen InjectAsserts andThen - InliningPhase andThen - CheckForalls + InliningPhase val pipeX = if(desugarXLang) { XLangDesugaringPhase andThen diff --git a/src/main/scala/leon/xlang/ImperativeCodeElimination.scala b/src/main/scala/leon/xlang/ImperativeCodeElimination.scala index fe9582d548deb679e0de7a082d5e4ba178582a1f..e9802d500208b96b5af448391d69c6660a05f843 100644 --- a/src/main/scala/leon/xlang/ImperativeCodeElimination.scala +++ b/src/main/scala/leon/xlang/ImperativeCodeElimination.scala @@ -10,6 +10,7 @@ import leon.purescala.Extractors._ import leon.purescala.Constructors._ import leon.purescala.ExprOps._ import leon.purescala.TypeOps._ +import leon.purescala.Types._ import leon.xlang.Expressions._ object ImperativeCodeElimination extends UnitPhase[Program] { @@ -22,13 +23,22 @@ object ImperativeCodeElimination extends UnitPhase[Program] { fd <- pgm.definedFunctions body <- fd.body } { - val (res, scope, _) = toFunction(body)(State(fd, Set())) + val (res, scope, _) = toFunction(body)(State(fd, Set(), Map())) fd.body = Some(scope(res)) } } - case class State(parent: FunDef, varsInScope: Set[Identifier]) { + /* varsInScope refers to variable declared in the same level scope. + Typically, when entering a nested function body, the scope should be + reset to empty */ + private case class State( + parent: FunDef, + varsInScope: Set[Identifier], + funDefsMapping: Map[FunDef, (FunDef, List[Identifier])] + ) { def withVar(i: Identifier) = copy(varsInScope = varsInScope + i) + def withFunDef(fd: FunDef, nfd: FunDef, ids: List[Identifier]) = + copy(funDefsMapping = funDefsMapping + (fd -> (nfd, ids))) } //return a "scope" consisting of purely functional code that defines potentially needed @@ -119,6 +129,7 @@ object ImperativeCodeElimination extends UnitPhase[Program] { (resId.toVariable, scope, scrutFun ++ modifiedVars.zip(freshIds).toMap) case wh@While(cond, body) => + //TODO: rewrite by re-using the nested function transformation code val (condRes, condScope, condFun) = toFunction(cond) val (_, bodyScope, bodyFun) = toFunction(body) val condBodyFun = condFun ++ bodyFun @@ -218,14 +229,115 @@ object ImperativeCodeElimination extends UnitPhase[Program] { bindFun ++ bodyFun ) + //a function invocation can update variables in scope. + case fi@FunctionInvocation(tfd, args) => + + + val (recArgs, argScope, argFun) = args.foldRight((Seq[Expr](), (body: Expr) => body, Map[Identifier, Identifier]()))((arg, acc) => { + val (accArgs, accScope, accFun) = acc + val (argVal, argScope, argFun) = toFunction(arg) + val newScope = (body: Expr) => argScope(replaceNames(argFun, accScope(body))) + (argVal +: accArgs, newScope, argFun ++ accFun) + }) + + val fd = tfd.fd + state.funDefsMapping.get(fd) match { + case Some((newFd, modifiedVars)) => { + val newInvoc = FunctionInvocation(newFd.typed, recArgs ++ modifiedVars.map(id => id.toVariable)).setPos(fi) + val freshNames = modifiedVars.map(id => id.freshen) + val tmpTuple = FreshIdentifier("t", newFd.returnType) + + val scope = (body: Expr) => { + argScope(Let(tmpTuple, newInvoc, + freshNames.zipWithIndex.foldRight(body)((p, b) => + Let(p._1, TupleSelect(tmpTuple.toVariable, p._2 + 2), b)) + )) + } + val newMap = argFun ++ modifiedVars.zip(freshNames).toMap + + (TupleSelect(tmpTuple.toVariable, 1), scope, newMap) + } + case None => + (FunctionInvocation(tfd, recArgs).copiedFrom(fi), argScope, argFun) + } + + case LetDef(fd, b) => - //Recall that here the nested function should not access mutable variables from an outside scope - fd.body.foreach { bd => - val (fdRes, fdScope, _) = toFunction(bd) - fd.body = Some(fdScope(fdRes)) + + def fdWithoutSideEffects = { + fd.body.foreach { bd => + val (fdRes, fdScope, _) = toFunction(bd) + fd.body = Some(fdScope(fdRes)) + } + val (bodyRes, bodyScope, bodyFun) = toFunction(b) + (bodyRes, (b2: Expr) => LetDef(fd, bodyScope(b2)).setPos(fd).copiedFrom(expr), bodyFun) + } + + fd.body match { + case Some(bd) => { + + val modifiedVars: List[Identifier] = + collect[Identifier]({ + case Assignment(v, _) => Set(v) + case _ => Set() + })(bd).intersect(state.varsInScope).toList + + if(modifiedVars.isEmpty) fdWithoutSideEffects else { + + val freshNames: List[Identifier] = modifiedVars.map(id => id.freshen) + + val newParams: Seq[ValDef] = fd.params ++ freshNames.map(n => ValDef(n)) + val freshVarDecls: List[Identifier] = freshNames.map(id => id.freshen) + + val rewritingMap: Map[Identifier, Identifier] = + modifiedVars.zip(freshVarDecls).toMap + val freshBody = + preMap({ + case Assignment(v, e) => rewritingMap.get(v).map(nv => Assignment(nv, e)) + case Variable(id) => rewritingMap.get(id).map(nid => Variable(nid)) + case _ => None + })(bd) + val wrappedBody = freshNames.zip(freshVarDecls).foldLeft(freshBody)((body, p) => { + LetVar(p._2, Variable(p._1), body) + }) + + val newReturnType = TupleType(fd.returnType :: modifiedVars.map(_.getType)) + + val newFd = new FunDef(fd.id.freshen, fd.tparams, newParams, newReturnType).setPos(fd) + + val (fdRes, fdScope, fdFun) = + toFunction(wrappedBody)( + State(state.parent, Set(), + state.funDefsMapping + (fd -> ((newFd, freshVarDecls)))) + ) + val newRes = Tuple(fdRes :: freshVarDecls.map(vd => fdFun(vd).toVariable)) + val newBody = fdScope(newRes) + + newFd.body = Some(newBody) + newFd.precondition = fd.precondition.map(prec => { + replace(modifiedVars.zip(freshNames).map(p => (p._1.toVariable, p._2.toVariable)).toMap, prec) + }) + newFd.postcondition = fd.postcondition.map(post => { + val Lambda(Seq(res), postBody) = post + val newRes = ValDef(FreshIdentifier(res.id.name, newFd.returnType)) + + val newBody = + replace( + modifiedVars.zipWithIndex.map{ case (v, i) => + (v.toVariable, TupleSelect(newRes.toVariable, i+2)): (Expr, Expr)}.toMap ++ + modifiedVars.zip(freshNames).map{ case (ov, nv) => + (Old(ov), nv.toVariable)}.toMap + + (res.toVariable -> TupleSelect(newRes.toVariable, 1)), + postBody) + Lambda(Seq(newRes), newBody).setPos(post) + }) + + val (bodyRes, bodyScope, bodyFun) = toFunction(b)(state.withFunDef(fd, newFd, modifiedVars)) + (bodyRes, (b2: Expr) => LetDef(newFd, bodyScope(b2)).copiedFrom(expr), bodyFun) + } + } + case None => fdWithoutSideEffects } - val (bodyRes, bodyScope, bodyFun) = toFunction(b) - (bodyRes, (b2: Expr) => LetDef(fd, bodyScope(b2)).copiedFrom(expr), bodyFun) case c @ Choose(b) => //Recall that Choose cannot mutate variables from the scope diff --git a/src/test/resources/regression/verification/purescala/invalid/Existentials.scala b/src/test/resources/regression/verification/purescala/invalid/Existentials.scala new file mode 100644 index 0000000000000000000000000000000000000000..19679db9202b325edc0eb8dfff8eeea41bb47d36 --- /dev/null +++ b/src/test/resources/regression/verification/purescala/invalid/Existentials.scala @@ -0,0 +1,16 @@ +import leon.lang._ + +object Existentials { + + def exists[A](p: A => Boolean): Boolean = !forall((x: A) => !p(x)) + + def check1(y: BigInt, p: BigInt => Boolean) : Boolean = { + p(y) == exists((y1:BigInt) => p(y1)) + }.holds + + /* + def check2(y: BigInt, p: BigInt => Boolean) : Boolean = { + p(y) ==> exists((y1:BigInt) => p(y1)) + }.holds + */ +} diff --git a/src/test/resources/regression/verification/purescala/invalid/ForallAssoc.scala b/src/test/resources/regression/verification/purescala/invalid/ForallAssoc.scala new file mode 100644 index 0000000000000000000000000000000000000000..83773b2224fb8790acfc25125143b358c1226f05 --- /dev/null +++ b/src/test/resources/regression/verification/purescala/invalid/ForallAssoc.scala @@ -0,0 +1,17 @@ +import leon.lang._ + +object ForallAssoc { + + /* + def test3(f: (BigInt, BigInt) => BigInt): Boolean = { + require(forall((x: BigInt, y: BigInt, z: BigInt) => f(x, f(y, z)) == f(f(x, y), z))) + f(1, f(2, f(3, f(4, 5)))) == f(f(f(f(1, 2), 3), 4), 4) + }.holds + */ + + def test4(f: (BigInt, BigInt) => BigInt): Boolean = { + require(forall((x: BigInt, y: BigInt, z: BigInt) => f(x, f(y, z)) == f(f(x, y), z))) + f(1, f(2, f(3, 4))) == 0 + }.holds + +} diff --git a/src/test/resources/regression/verification/purescala/valid/Existentials.scala b/src/test/resources/regression/verification/purescala/valid/Existentials.scala new file mode 100644 index 0000000000000000000000000000000000000000..992d58cd0678e98146ad1085a22269671a35421c --- /dev/null +++ b/src/test/resources/regression/verification/purescala/valid/Existentials.scala @@ -0,0 +1,16 @@ +import leon.lang._ + +object Existentials { + + def exists[A](p: A => Boolean): Boolean = !forall((x: A) => !p(x)) + + /* + def check1(y: BigInt, p: BigInt => Boolean) : Boolean = { + p(y) == exists((y1:BigInt) => p(y1)) + }.holds + */ + + def check2(y: BigInt, p: BigInt => Boolean) : Boolean = { + p(y) ==> exists((y1:BigInt) => p(y1)) + }.holds +} diff --git a/src/test/resources/regression/verification/purescala/valid/ForallAssoc.scala b/src/test/resources/regression/verification/purescala/valid/ForallAssoc.scala new file mode 100644 index 0000000000000000000000000000000000000000..ae90e9489678e1f4cc20d90b0cd23767cc317546 --- /dev/null +++ b/src/test/resources/regression/verification/purescala/valid/ForallAssoc.scala @@ -0,0 +1,23 @@ +import leon.lang._ + +object ForallAssoc { + + def ex[A](x1: A, x2: A, x3: A, x4: A, x5: A, f: (A, A) => A) = { + require(forall { + (x: A, y: A, z: A) => f(x, f(y, z)) == f(f(x, y), z) + }) + + f(x1, f(x2, f(x3, f(x4, x5)))) == f(f(x1, f(x2, f(x3, x4))), x5) + }.holds + + def test1(f: (BigInt, BigInt) => BigInt): Boolean = { + require(forall((x: BigInt, y: BigInt, z: BigInt) => f(x, f(y, z)) == f(f(x, y), z))) + f(1, f(2, f(3, 4))) == f(f(f(1, 2), 3), 4) + }.holds + + def test2(f: (BigInt, BigInt) => BigInt): Boolean = { + require(forall((x: BigInt, y: BigInt, z: BigInt) => f(x, f(y, z)) == f(f(x, y), z))) + f(1, f(2, f(3, f(4, 5)))) == f(f(f(f(1, 2), 3), 4), 5) + }.holds + +} diff --git a/src/test/resources/regression/verification/purescala/valid/Predicate.scala b/src/test/resources/regression/verification/purescala/valid/Predicate.scala new file mode 100644 index 0000000000000000000000000000000000000000..c011d4b910c2d3859b0e2cca80e069de5b19788b --- /dev/null +++ b/src/test/resources/regression/verification/purescala/valid/Predicate.scala @@ -0,0 +1,48 @@ +package leon.monads.predicate + +import leon.collection._ +import leon.lang._ +import leon.annotation._ + +object Predicate { + + def exists[A](p: A => Boolean): Boolean = !forall((a: A) => !p(a)) + + // Monadic bind + @inline + def flatMap[A,B](p: A => Boolean, f: A => (B => Boolean)): B => Boolean = { + (b: B) => exists[A](a => p(a) && f(a)(b)) + } + + // All monads are also functors, and they define the map function + @inline + def map[A,B](p: A => Boolean, f: A => B): B => Boolean = { + (b: B) => exists[A](a => p(a) && f(a) == b) + } + + /* + @inline + def >>=[B](f: A => Predicate[B]): Predicate[B] = flatMap(f) + + @inline + def >>[B](that: Predicate[B]) = >>= ( _ => that ) + + @inline + def withFilter(f: A => Boolean): Predicate[A] = { + Predicate { a => p(a) && f(a) } + } + */ + + def equals[A](p: A => Boolean, that: A => Boolean): Boolean = { + forall[A](a => p(a) == that(a)) + } + + def test[A,B,C](p: A => Boolean, f: A => B, g: B => C): Boolean = { + equals(map(map(p, f), g), map(p, (a: A) => g(f(a)))) + }.holds + + def testInt(p: BigInt => Boolean, f: BigInt => BigInt, g: BigInt => BigInt): Boolean = { + equals(map(map(p, f), g), map(p, (x: BigInt) => g(f(x)))) + }.holds +} + diff --git a/src/test/resources/regression/verification/xlang/invalid/NestedFunState1.scala b/src/test/resources/regression/verification/xlang/invalid/NestedFunState1.scala new file mode 100644 index 0000000000000000000000000000000000000000..9372ea742a278a3da3d476dca57e52c55c09b978 --- /dev/null +++ b/src/test/resources/regression/verification/xlang/invalid/NestedFunState1.scala @@ -0,0 +1,20 @@ +object NestedFunState1 { + + def simpleSideEffect(n: BigInt): BigInt = { + require(n > 0) + + var a = BigInt(0) + + def incA(prevA: BigInt): Unit = { + require(prevA == a) + a += 1 + } ensuring(_ => a == prevA + 1) + + incA(a) + incA(a) + incA(a) + incA(a) + a + } ensuring(_ == 5) + +} diff --git a/src/test/resources/regression/verification/xlang/invalid/NestedFunState2.scala b/src/test/resources/regression/verification/xlang/invalid/NestedFunState2.scala new file mode 100644 index 0000000000000000000000000000000000000000..68769df28353c9defa6d33000bb8ae4de21c7ace --- /dev/null +++ b/src/test/resources/regression/verification/xlang/invalid/NestedFunState2.scala @@ -0,0 +1,23 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +object NestedFunState2 { + + def sum(n: BigInt): BigInt = { + require(n > 0) + var i = BigInt(0) + var res = BigInt(0) + + def iter(): Unit = { + require(res >= i && i >= 0) + if(i < n) { + i += 1 + res += i + iter() + } + } + + iter() + res + } ensuring(_ < 0) + +} diff --git a/src/test/resources/regression/verification/xlang/valid/ArrayNested1.scala b/src/test/resources/regression/verification/xlang/valid/ArrayNested1.scala new file mode 100644 index 0000000000000000000000000000000000000000..196a6442bd544a3981e6a7de7e923e040cadd32d --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/ArrayNested1.scala @@ -0,0 +1,19 @@ +import leon.lang._ + +object ArrayNested1 { + + def test(): Int = { + + var a = Array(1, 2, 0) + + def nested(): Unit = { + require(a.length == 3) + a = a.updated(1, 5) + } + + nested() + a(1) + + } ensuring(_ == 5) + +} diff --git a/src/test/resources/regression/verification/xlang/valid/ArrayNested2.scala b/src/test/resources/regression/verification/xlang/valid/ArrayNested2.scala new file mode 100644 index 0000000000000000000000000000000000000000..a8935ab8c1d4eec6980b85f171e695071f2a9442 --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/ArrayNested2.scala @@ -0,0 +1,19 @@ +import leon.lang._ + +object ArrayNested2 { + + def test(): Int = { + + val a = Array(1, 2, 0) + + def nested(): Unit = { + require(a.length == 3) + a(2) = 5 + } + + nested() + a(2) + + } ensuring(_ == 5) + +} diff --git a/src/test/resources/regression/verification/xlang/valid/FunInvocEvaluationOrder1.scala b/src/test/resources/regression/verification/xlang/valid/FunInvocEvaluationOrder1.scala new file mode 100644 index 0000000000000000000000000000000000000000..8ef668a2afeb61a6b305831358a65e675bf670ed --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/FunInvocEvaluationOrder1.scala @@ -0,0 +1,22 @@ +object FunInvocEvaluationOrder1 { + + def test(): Int = { + + var res = 10 + justAddingStuff({ + res += 1 + res + }, { + res *= 2 + res + }, { + res += 10 + res + }) + + res + } ensuring(_ == 32) + + def justAddingStuff(x: Int, y: Int, z: Int): Int = x + y + z + +} diff --git a/src/test/resources/regression/verification/xlang/valid/FunInvocEvaluationOrder2.scala b/src/test/resources/regression/verification/xlang/valid/FunInvocEvaluationOrder2.scala new file mode 100644 index 0000000000000000000000000000000000000000..12090f0f3fad0a5fcf2451e30fdb6f25bb09c590 --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/FunInvocEvaluationOrder2.scala @@ -0,0 +1,17 @@ +object FunInvocEvaluationOrder2 { + + def leftToRight(n: BigInt): BigInt = { + require(n > 0) + + var a = BigInt(0) + + def nested(x: BigInt, y: BigInt): BigInt = { + require(y >= x) + x + y + } + + nested({a += 10; a}, {a *= 2; a}) + + } ensuring(_ == 30) + +} diff --git a/src/test/resources/regression/verification/xlang/valid/FunInvocEvaluationOrder3.scala b/src/test/resources/regression/verification/xlang/valid/FunInvocEvaluationOrder3.scala new file mode 100644 index 0000000000000000000000000000000000000000..f263529653486ee5477b9fd1934760bcdc737d15 --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/FunInvocEvaluationOrder3.scala @@ -0,0 +1,17 @@ +object FunInvocEvaluationOrder3 { + + def leftToRight(n: BigInt): BigInt = { + require(n > 0) + + var a = BigInt(0) + + def nested(x: BigInt, y: BigInt): Unit = { + a = x + y + } + + nested({a += 10; a}, {a *= 2; a}) + a + + } ensuring(_ == 30) + +} diff --git a/src/test/resources/regression/verification/xlang/valid/NestedFunState1.scala b/src/test/resources/regression/verification/xlang/valid/NestedFunState1.scala new file mode 100644 index 0000000000000000000000000000000000000000..1ec0266f846dd55cfa1a6274d5c6dc7dc67dd125 --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/NestedFunState1.scala @@ -0,0 +1,23 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +object NestedFunState1 { + + def sum(n: BigInt): BigInt = { + require(n > 0) + var i = BigInt(0) + var res = BigInt(0) + + def iter(): Unit = { + require(res >= i && i >= 0) + if(i < n) { + i += 1 + res += i + iter() + } + } ensuring(_ => res >= n) + + iter() + res + } ensuring(_ >= n) + +} diff --git a/src/test/resources/regression/verification/xlang/valid/NestedFunState2.scala b/src/test/resources/regression/verification/xlang/valid/NestedFunState2.scala new file mode 100644 index 0000000000000000000000000000000000000000..e48b1b959b3a83262329cf48c63370a2b591cd41 --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/NestedFunState2.scala @@ -0,0 +1,19 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +object NestedFunState2 { + + def countConst(): Int = { + + var counter = 0 + + def inc(): Unit = { + counter += 1 + } + + inc() + inc() + inc() + counter + } ensuring(_ == 3) + +} diff --git a/src/test/resources/regression/verification/xlang/valid/NestedFunState3.scala b/src/test/resources/regression/verification/xlang/valid/NestedFunState3.scala new file mode 100644 index 0000000000000000000000000000000000000000..650f5987dcd352a3f5c47d41c684e94d4b9be20c --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/NestedFunState3.scala @@ -0,0 +1,25 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ +import leon.lang._ + +object NestedFunState3 { + + + def counterN(n: Int): Int = { + require(n > 0) + + var counter = 0 + + def inc(): Unit = { + counter += 1 + } + + var i = 0 + (while(i < n) { + inc() + i += 1 + }) invariant(i >= 0 && counter == i && i <= n) + + counter + } ensuring(_ == n) + +} diff --git a/src/test/resources/regression/verification/xlang/valid/NestedFunState4.scala b/src/test/resources/regression/verification/xlang/valid/NestedFunState4.scala new file mode 100644 index 0000000000000000000000000000000000000000..9f4fb2621f60c60da3a677dc732e6c93f4d8ae72 --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/NestedFunState4.scala @@ -0,0 +1,38 @@ +import leon.lang._ + +object NestedFunState4 { + + def deep(n: BigInt): BigInt = { + require(n > 0) + + var a = BigInt(0) + + def iter(): Unit = { + require(a >= 0) + + var b = BigInt(0) + + def nestedIter(): Unit = { + b += 1 + } + + var i = BigInt(0) + (while(i < n) { + i += 1 + nestedIter() + }) invariant(i >= 0 && i <= n && b == i) + + a += b + + } ensuring(_ => a >= n) + + var i = BigInt(0) + (while(i < n) { + i += 1 + iter() + }) invariant(i >= 0 && i <= n && a >= i) + + a + } ensuring(_ >= n) + +} diff --git a/src/test/resources/regression/verification/xlang/valid/NestedFunState5.scala b/src/test/resources/regression/verification/xlang/valid/NestedFunState5.scala new file mode 100644 index 0000000000000000000000000000000000000000..13f3cf47cd0ee6c831f12877493cbfe6c7edea5a --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/NestedFunState5.scala @@ -0,0 +1,29 @@ +import leon.lang._ + +object NestedFunState5 { + + def deep(n: BigInt): BigInt = { + require(n > 0) + + var a = BigInt(0) + + def iter(prevA: BigInt): Unit = { + require(prevA == a) + def nestedIter(): Unit = { + a += 1 + } + + nestedIter() + nestedIter() + + } ensuring(_ => a == prevA + 2) + + var i = BigInt(0) + (while(i < n) { + i += 1 + iter(a) + }) invariant(i >= 0 && i <= n && a >= 2*i) + + a + } ensuring(_ >= 2*n) +} diff --git a/src/test/resources/regression/verification/xlang/valid/NestedFunState6.scala b/src/test/resources/regression/verification/xlang/valid/NestedFunState6.scala new file mode 100644 index 0000000000000000000000000000000000000000..cea0f2e1900a00bc803df4107b46dd62cf082c68 --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/NestedFunState6.scala @@ -0,0 +1,20 @@ +object NestedFunState6 { + + def simpleSideEffect(n: BigInt): BigInt = { + require(n > 0) + + var a = BigInt(0) + + def incA(prevA: BigInt): Unit = { + require(prevA == a) + a += 1 + } ensuring(_ => a == prevA + 1) + + incA(a) + incA(a) + incA(a) + incA(a) + a + } ensuring(_ == 4) + +} diff --git a/src/test/resources/regression/verification/xlang/valid/NestedFunState7.scala b/src/test/resources/regression/verification/xlang/valid/NestedFunState7.scala new file mode 100644 index 0000000000000000000000000000000000000000..ff2418a0c972add2d48c08f6319905f35d0493c2 --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/NestedFunState7.scala @@ -0,0 +1,27 @@ +import leon.lang._ + +object NestedFunState7 { + + def test(x: BigInt): BigInt = { + + var a = BigInt(0) + + def defCase(): Unit = { + a = 100 + } + + if(x >= 0) { + a = 2*x + if(a < 100) { + a = 100 - a + } else { + defCase() + } + } else { + defCase() + } + + a + } ensuring(res => res >= 0 && res <= 100) + +} diff --git a/src/test/resources/regression/verification/xlang/valid/NestedOld1.scala b/src/test/resources/regression/verification/xlang/valid/NestedOld1.scala new file mode 100644 index 0000000000000000000000000000000000000000..903dfd63aca5105c9366c459552865646139033e --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/NestedOld1.scala @@ -0,0 +1,16 @@ +import leon.lang._ + +object NestedOld1 { + + def test(): Int = { + var counter = 0 + + def inc(): Unit = { + counter += 1 + } ensuring(_ => counter == old(counter) + 1) + + inc() + counter + } ensuring(_ == 1) + +} diff --git a/src/test/resources/regression/verification/xlang/valid/NestedOld2.scala b/src/test/resources/regression/verification/xlang/valid/NestedOld2.scala new file mode 100644 index 0000000000000000000000000000000000000000..fe6143ecf623c295251ca2b9d75d6474d422fd4f --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/NestedOld2.scala @@ -0,0 +1,24 @@ +import leon.lang._ + +object NestedOld2 { + + def test(): Int = { + var counterPrev = 0 + var counterNext = 1 + + def step(): Unit = { + require(counterNext == counterPrev + 1) + counterPrev = counterNext + counterNext = counterNext+1 + } ensuring(_ => { + counterPrev == old(counterNext) && + counterNext == old(counterNext) + 1 && + counterPrev == old(counterPrev) + 1 + }) + + step() + step() + counterNext + } ensuring(_ == 3) + +} diff --git a/src/test/scala/leon/integration/evaluators/CodegenEvaluatorSuite.scala b/src/test/scala/leon/integration/evaluators/CodegenEvaluatorSuite.scala index 0072da45d8f6556100410f733685c5c8f7177b89..e5565e073dc629a546019e219297c841c9d84258 100644 --- a/src/test/scala/leon/integration/evaluators/CodegenEvaluatorSuite.scala +++ b/src/test/scala/leon/integration/evaluators/CodegenEvaluatorSuite.scala @@ -9,7 +9,7 @@ import leon.purescala.Expressions._ import leon.purescala.Types._ import leon.codegen._ -class CodegenEvaluatorSuite extends LeonTestSuiteWithProgram { +class CodegenEvaluatorSuite extends LeonTestSuiteWithProgram with helpers.ExpressionsDSL{ val sources = List(""" import leon.lang._ @@ -316,7 +316,7 @@ class CodegenEvaluatorSuite extends LeonTestSuiteWithProgram { "Overrides1" -> Tuple(Seq(BooleanLiteral(false), BooleanLiteral(true))), "Overrides2" -> Tuple(Seq(BooleanLiteral(false), BooleanLiteral(true))), "Arrays1" -> IntLiteral(2), - "Arrays2" -> IntLiteral(6) + "Arrays2" -> IntLiteral(10) ) for { @@ -324,29 +324,25 @@ class CodegenEvaluatorSuite extends LeonTestSuiteWithProgram { requireMonitor <- Seq(false, true) doInstrument <- Seq(false,true) } { - val opts = ((if(requireMonitor) Some("monitor") else None) ++ - (if(doInstrument) Some("instrument") else None)).mkString("+") + val opts = (if(requireMonitor) "monitor " else "") + + (if(doInstrument) "instrument" else "") val testName = f"$name%-20s $opts%-18s" - test("Evaluation of "+testName) { case (ctx, pgm) => - val eval = new CodeGenEvaluator(ctx, pgm, CodeGenParams( + test("Evaluation of "+testName) { implicit fix => + val eval = new CodeGenEvaluator(fix._1, fix._2, CodeGenParams( maxFunctionInvocations = if (requireMonitor) 1000 else -1, // Monitor calls and abort execution if more than X calls checkContracts = true, // Generate calls that checks pre/postconditions doInstrument = doInstrument // Instrument reads to case classes (mainly for vanuatoo) )) - val fun = pgm.lookup(name+".test").collect { - case fd: FunDef => fd - }.getOrElse { - fail("Failed to lookup '"+name+".test'") - } - - (eval.eval(FunctionInvocation(fun.typed(Seq()), Seq())).result, exp) match { + (eval.eval(fcall(name + ".test")()).result, exp) match { case (Some(res), exp) => + assert(res === exp) case (None, Error(_, _)) => - case (None, _) => - case (_, Error(_, _)) => + // OK + case _ => + fail("") } } } diff --git a/src/test/scala/leon/integration/solvers/GlobalVariablesSuite.scala b/src/test/scala/leon/integration/solvers/GlobalVariablesSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..64cb3b70b062ba30ccf6223af587ed85c9b4ff88 --- /dev/null +++ b/src/test/scala/leon/integration/solvers/GlobalVariablesSuite.scala @@ -0,0 +1,74 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon.integration.solvers + +import leon.test._ +import leon.test.helpers._ +import leon.purescala.Common._ +import leon.purescala.Definitions._ +import leon.purescala.ExprOps._ +import leon.purescala.Constructors._ +import leon.purescala.Expressions._ +import leon.purescala.Types._ +import leon.LeonContext + +import leon.solvers._ +import leon.solvers.smtlib._ +import leon.solvers.combinators._ +import leon.solvers.z3._ + +class GlobalVariablesSuite extends LeonTestSuiteWithProgram with ExpressionsDSL { + + val sources = List( + """|import leon.lang._ + |import leon.annotation._ + | + |object GlobalVariables { + | + | def test(i: BigInt): BigInt = { + | 0 // will be replaced + | } + |} """.stripMargin + ) + + val getFactories: Seq[(String, (LeonContext, Program) => Solver)] = { + (if (SolverFactory.hasNativeZ3) Seq( + ("fairz3", (ctx: LeonContext, pgm: Program) => new FairZ3Solver(ctx, pgm)) + ) else Nil) ++ + (if (SolverFactory.hasZ3) Seq( + ("smt-z3", (ctx: LeonContext, pgm: Program) => new UnrollingSolver(ctx, pgm, new SMTLIBZ3Solver(ctx, pgm))) + ) else Nil) ++ + (if (SolverFactory.hasCVC4) Seq( + ("smt-cvc4", (ctx: LeonContext, pgm: Program) => new UnrollingSolver(ctx, pgm, new SMTLIBCVC4Solver(ctx, pgm))) + ) else Nil) + } + + // Check that we correctly extract several types from solver models + for ((sname, sf) <- getFactories) { + test(s"Global Variables in $sname") { implicit fix => + val ctx = fix._1 + val pgm = fix._2 + + pgm.lookup("GlobalVariables.test") match { + case Some(fd: FunDef) => + val b0 = FreshIdentifier("B", BooleanType); + fd.body = Some(IfExpr(b0.toVariable, bi(1), bi(-1))) + + val cnstr = LessThan(FunctionInvocation(fd.typed, Seq(bi(42))), bi(0)) + val solver = sf(ctx, pgm) + solver.assertCnstr(And(b0.toVariable, cnstr)) + + try { + if (solver.check != Some(false)) { + fail("Global variables not correctly handled.") + } + } finally { + solver.free() + } + case _ => + fail("Function with global body not found") + } + + } + } +} diff --git a/src/test/scala/leon/integration/solvers/SolversSuite.scala b/src/test/scala/leon/integration/solvers/SolversSuite.scala index f2ba6c89a2a6231f2e1bc18d3364705789c3bb27..24ebf35916d9abce077fdd7d8ee5752fd3c1e0d3 100644 --- a/src/test/scala/leon/integration/solvers/SolversSuite.scala +++ b/src/test/scala/leon/integration/solvers/SolversSuite.scala @@ -32,42 +32,38 @@ class SolversSuite extends LeonTestSuiteWithProgram { ) else Nil) } - // Check that we correctly extract several types from solver models - for ((sname, sf) <- getFactories) { - test(s"Model Extraction in $sname") { implicit fix => - val ctx = fix._1 - val pgm = fix._2 - - val solver = sf(ctx, pgm) - val types = Seq( BooleanType, UnitType, CharType, + RealType, IntegerType, Int32Type, StringType, TypeParameter.fresh("T"), SetType(IntegerType), MapType(IntegerType, IntegerType), + FunctionType(Seq(IntegerType), IntegerType), TupleType(Seq(IntegerType, BooleanType, Int32Type)) ) val vs = types.map(FreshIdentifier("v", _).toVariable) - // We need to make sure models are not co-finite - val cnstr = andJoin(vs.map(v => v.getType match { + val cnstrs = vs.map(v => v.getType match { case UnitType => Equals(v, simplestValue(v.getType)) case SetType(base) => Not(ElementOfSet(simplestValue(base), v)) case MapType(from, to) => Not(Equals(MapApply(v, simplestValue(from)), simplestValue(to))) + case FunctionType(froms, to) => + Not(Equals(Application(v, froms.map(simplestValue)), simplestValue(to))) case _ => not(Equals(v, simplestValue(v.getType))) - })) + }) + def checkSolver(solver: Solver, vs: Set[Variable], cnstr: Expr)(implicit fix: (LeonContext, Program)): Unit = { try { solver.assertCnstr(cnstr) @@ -87,7 +83,22 @@ class SolversSuite extends LeonTestSuiteWithProgram { } finally { solver.free() } + } + // Check that we correctly extract several types from solver models + for ((sname, sf) <- getFactories) { + test(s"Model Extraction in $sname") { implicit fix => + val ctx = fix._1 + val pgm = fix._2 + val solver = sf(ctx, pgm) + checkSolver(solver, vs.toSet, andJoin(cnstrs)) } } + + test(s"Data generation in enum solver") { implicit fix => + for ((v,cnstr) <- vs zip cnstrs) { + val solver = new EnumerationSolver(fix._1, fix._2) + checkSolver(solver, Set(v), cnstr) +} + } } diff --git a/src/test/scala/leon/regression/synthesis/SynthesisSuite.scala b/src/test/scala/leon/regression/synthesis/SynthesisSuite.scala index e7c5fbf2abca03cbf66f238b052078d1bb0a56a1..c70df950e0768ca71dd7bba013ac04cd7404edab 100644 --- a/src/test/scala/leon/regression/synthesis/SynthesisSuite.scala +++ b/src/test/scala/leon/regression/synthesis/SynthesisSuite.scala @@ -154,9 +154,9 @@ object Injection { case class Nil() extends List // proved with unrolling=0 - def size(l: List) : Int = (l match { - case Nil() => 0 - case Cons(t) => 1 + size(t) + def size(l: List) : BigInt = (l match { + case Nil() => BigInt(0) + case Cons(t) => BigInt(1) + size(t) }) ensuring(res => res >= 0) def simple(in: List) = choose{out: List => size(out) == size(in) } @@ -274,10 +274,10 @@ object ChurchNumerals { case object Z extends Num case class S(pred: Num) extends Num - def value(n:Num) : Int = { + def value(n:Num) : BigInt = { n match { - case Z => 0 - case S(p) => 1 + value(p) + case Z => BigInt(0) + case S(p) => BigInt(1) + value(p) } } ensuring (_ >= 0) @@ -309,10 +309,10 @@ object ChurchNumerals { case object Z extends Num case class S(pred: Num) extends Num - def value(n:Num) : Int = { + def value(n:Num) : BigInt = { n match { - case Z => 0 - case S(p) => 1 + value(p) + case Z => BigInt(0) + case S(p) => BigInt(1) + value(p) } } ensuring (_ >= 0) diff --git a/src/test/scala/leon/regression/verification/purescala/PureScalaVerificationSuite.scala b/src/test/scala/leon/regression/verification/purescala/PureScalaVerificationSuite.scala index d308e840f159a8ab7de49380d920dfc12ee8d2ad..974615e6484304de738ebdd43ecafbb323e589c2 100644 --- a/src/test/scala/leon/regression/verification/purescala/PureScalaVerificationSuite.scala +++ b/src/test/scala/leon/regression/verification/purescala/PureScalaVerificationSuite.scala @@ -59,10 +59,10 @@ class PureScalaValidSuite3 extends PureScalaValidSuite { val optionVariants = List(opts(2)) } class PureScalaValidSuiteZ3 extends PureScalaValidSuite { - val optionVariants = if (isZ3Available) List(opts(3)) else Nil + val optionVariants = Nil//if (isZ3Available) List(opts(3)) else Nil } class PureScalaValidSuiteCVC4 extends PureScalaValidSuite { - val optionVariants = if (isCVC4Available) List(opts(4)) else Nil + val optionVariants = Nil//if (isCVC4Available) List(opts(4)) else Nil } class PureScalaInvalidSuite extends PureScalaVerificationSuite { diff --git a/testcases/synthesis/etienne-thesis/List/Delete.scala b/testcases/synthesis/etienne-thesis/List/Delete.scala new file mode 100644 index 0000000000000000000000000000000000000000..46f0710878d25a30e679f034ff07f985078950d1 --- /dev/null +++ b/testcases/synthesis/etienne-thesis/List/Delete.scala @@ -0,0 +1,41 @@ +import leon.annotation._ +import leon.lang._ +import leon.lang.synthesis._ + +object Delete { + sealed abstract class List + case class Cons(head: BigInt, tail: List) extends List + case object Nil extends List + + def size(l: List) : BigInt = (l match { + case Nil => BigInt(0) + case Cons(_, t) => BigInt(1) + size(t) + }) ensuring(res => res >= 0) + + def content(l: List): Set[BigInt] = l match { + case Nil => Set.empty[BigInt] + case Cons(i, t) => Set(i) ++ content(t) + } + + def insert(in1: List, v: BigInt): List = { + Cons(v, in1) + } ensuring { content(_) == content(in1) ++ Set(v) } + + //def delete(in1: List, v: BigInt): List = { + // in1 match { + // case Cons(h,t) => + // if (h == v) { + // delete(t, v) + // } else { + // Cons(h, delete(t, v)) + // } + // case Nil => + // Nil + // } + //} ensuring { content(_) == content(in1) -- Set(v) } + + def delete(in1: List, v: BigInt) = choose { + (out : List) => + content(out) == content(in1) -- Set(v) + } +} diff --git a/testcases/synthesis/etienne-thesis/List/Diff.scala b/testcases/synthesis/etienne-thesis/List/Diff.scala new file mode 100644 index 0000000000000000000000000000000000000000..9fb3ade9558a7db175c6056efb9bf724f487d7ca --- /dev/null +++ b/testcases/synthesis/etienne-thesis/List/Diff.scala @@ -0,0 +1,50 @@ +import leon.annotation._ +import leon.lang._ +import leon.lang.synthesis._ + +object Diff { + sealed abstract class List + case class Cons(head: BigInt, tail: List) extends List + case object Nil extends List + + def size(l: List) : BigInt = (l match { + case Nil => BigInt(0) + case Cons(_, t) => BigInt(1) + size(t) + }) ensuring(res => res >= 0) + + def content(l: List): Set[BigInt] = l match { + case Nil => Set.empty[BigInt] + case Cons(i, t) => Set(i) ++ content(t) + } + + def insert(in1: List, v: BigInt): List = { + Cons(v, in1) + } ensuring { content(_) == content(in1) ++ Set(v) } + + def delete(in1: List, v: BigInt): List = { + in1 match { + case Cons(h,t) => + if (h == v) { + delete(t, v) + } else { + Cons(h, delete(t, v)) + } + case Nil => + Nil + } + } ensuring { content(_) == content(in1) -- Set(v) } + + // def diff(in1: List, in2: List): List = { + // in2 match { + // case Nil => + // in1 + // case Cons(h, t) => + // diff(delete(in1, h), t) + // } + // } ensuring { content(_) == content(in1) -- content(in2) } + + def diff(in1: List, in2: List) = choose { + (out : List) => + content(out) == content(in1) -- content(in2) + } +} diff --git a/testcases/synthesis/etienne-thesis/List/Insert.scala b/testcases/synthesis/etienne-thesis/List/Insert.scala new file mode 100644 index 0000000000000000000000000000000000000000..48c38f4df0098410814f3ed27038c9c36c0d6532 --- /dev/null +++ b/testcases/synthesis/etienne-thesis/List/Insert.scala @@ -0,0 +1,28 @@ +import leon.annotation._ +import leon.lang._ +import leon.lang.synthesis._ + +object Insert { + sealed abstract class List + case class Cons(head: BigInt, tail: List) extends List + case object Nil extends List + + def size(l: List) : BigInt = (l match { + case Nil => BigInt(0) + case Cons(_, t) => BigInt(1) + size(t) + }) ensuring(res => res >= 0) + + def content(l: List): Set[BigInt] = l match { + case Nil => Set.empty[BigInt] + case Cons(i, t) => Set(i) ++ content(t) + } + + //def insert(in1: List, v: BigInt): List = { + // Cons(v, in1) + //} ensuring { content(_) == content(in1) ++ Set(v) } + + def insert(in1: List, v: BigInt) = choose { + (out : List) => + content(out) == content(in1) ++ Set(v) + } +} diff --git a/testcases/synthesis/etienne-thesis/List/Split.scala b/testcases/synthesis/etienne-thesis/List/Split.scala new file mode 100644 index 0000000000000000000000000000000000000000..fb98204096f3e4a97b990527e588b655e8b93fac --- /dev/null +++ b/testcases/synthesis/etienne-thesis/List/Split.scala @@ -0,0 +1,41 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +import leon.annotation._ +import leon.lang._ +import leon.lang.synthesis._ + +object Complete { + sealed abstract class List + case class Cons(head: BigInt, tail: List) extends List + case object Nil extends List + + def size(l: List) : BigInt = (l match { + case Nil => BigInt(0) + case Cons(_, t) => BigInt(1) + size(t) + }) ensuring(res => res >= 0) + + def content(l: List): Set[BigInt] = l match { + case Nil => Set.empty[BigInt] + case Cons(i, t) => Set(i) ++ content(t) + } + + def splitSpec(list : List, res : (List,List)) : Boolean = { + val s1 = size(res._1) + val s2 = size(res._2) + abs(s1 - s2) <= 1 && s1 + s2 == size(list) && + content(res._1) ++ content(res._2) == content(list) + } + + def abs(i : BigInt) : BigInt = { + if(i < 0) -i else i + } ensuring(_ >= 0) + + def dispatch(es: (BigInt, BigInt), rest: (List, List)): (List, List) = { + (Cons(es._1, rest._1), Cons(es._2, rest._2)) + } + + def split(list : List) : (List,List) = { + choose { (res : (List,List)) => splitSpec(list, res) } + } + +} diff --git a/testcases/synthesis/etienne-thesis/List/Union.scala b/testcases/synthesis/etienne-thesis/List/Union.scala new file mode 100644 index 0000000000000000000000000000000000000000..d6a5fa579f5745f00e469f19ee853da15d4fece7 --- /dev/null +++ b/testcases/synthesis/etienne-thesis/List/Union.scala @@ -0,0 +1,37 @@ +import leon.annotation._ +import leon.lang._ +import leon.lang.synthesis._ + +object Union { + sealed abstract class List + case class Cons(head: BigInt, tail: List) extends List + case object Nil extends List + + def size(l: List) : BigInt = (l match { + case Nil => BigInt(0) + case Cons(_, t) => BigInt(1)+ size(t) + }) ensuring(res => res >= 0) + + def content(l: List): Set[BigInt] = l match { + case Nil => Set.empty[BigInt] + case Cons(i, t) => Set(i) ++ content(t) + } + + def insert(in1: List, v: BigInt): List = { + Cons(v, in1) + } ensuring { content(_) == content(in1) ++ Set(v) } + + // def union(in1: List, in2: List): List = { + // in1 match { + // case Cons(h, t) => + // Cons(h, union(t, in2)) + // case Nil => + // in2 + // } + // } ensuring { content(_) == content(in1) ++ content(in2) } + + def union(in1: List, in2: List) = choose { + (out : List) => + content(out) == content(in1) ++ content(in2) + } +} diff --git a/testcases/verification/compilation/ExprCompiler.scala b/testcases/verification/compilation/ExprCompiler.scala new file mode 100644 index 0000000000000000000000000000000000000000..7518c21fb8ba7349e49c37c8acac6f99aaae2cdd --- /dev/null +++ b/testcases/verification/compilation/ExprCompiler.scala @@ -0,0 +1,89 @@ +import leon.lang._ +import leon.lang.Option +import leon.collection._ +import leon.annotation._ +import leon.proof._ +import leon.lang.synthesis._ + +object TinyCertifiedCompiler { + abstract class ByteCode[A] + case class Load[A](c: A) extends ByteCode[A] // loads a constant in to the stack + case class OpInst[A]() extends ByteCode[A] + + abstract class ExprTree[A] + case class Const[A](c: A) extends ExprTree[A] + case class Op[A](e1: ExprTree[A], e2: ExprTree[A]) extends ExprTree[A] + + def compile[A](e: ExprTree[A]): List[ByteCode[A]] = { + e match { + case Const(c) => + Cons(Load(c), Nil[ByteCode[A]]()) + case Op(e1, e2) => + (compile(e1) ++ compile(e2)) ++ Cons(OpInst(), Nil[ByteCode[A]]()) + } + } + + def op[A](x: A, y: A): A = { + ???[A] + } + + def run[A](bytecode: List[ByteCode[A]], S: List[A]): List[A] = { + (bytecode, S) match { + case (Cons(Load(c), tail), _) => + run(tail, Cons[A](c, S)) // adding elements to the head of the stack + case (Cons(OpInst(), tail), Cons(x, Cons(y, rest))) => + run(tail, Cons[A](op(y, x), rest)) + case (Cons(_, tail), _) => + run(tail, S) + case (Nil(), _) => // no bytecode to execute + S + } + } + + def interpret[A](e: ExprTree[A]): A = { + e match { + case Const(c) => c + case Op(e1, e2) => + op(interpret(e1), interpret(e2)) + } + } + + def runAppendLemma[A](l1: List[ByteCode[A]], l2: List[ByteCode[A]], S: List[A]): Boolean = { + // lemma + (run(l1 ++ l2, S) == run(l2, run(l1, S))) because + // induction scheme (induct over l1) + (l1 match { + case Nil() => + true + case Cons(h, tail) => + (h, S) match { + case (Load(c), _) => + runAppendLemma(tail, l2, Cons[A](c, S)) + case (OpInst(), Cons(x, Cons(y, rest))) => + runAppendLemma(tail, l2, Cons[A](op(y, x), rest)) + case (_, _) => + runAppendLemma(tail, l2, S) + case _ => + true + } + }) + }.holds + + def compileInterpretEquivalenceLemma[A](e: ExprTree[A], S: List[A]): Boolean = { + // lemma + (run(compile(e), S) == interpret(e) :: S) because + // induction scheme + (e match { + case Const(c) => + true + case Op(e1, e2) => + // lemma instantiation + val c1 = compile(e1) + val c2 = compile(e2) + runAppendLemma((c1 ++ c2), Cons[ByteCode[A]](OpInst[A](), Nil[ByteCode[A]]()), S) && + runAppendLemma(c1, c2, S) && + compileInterpretEquivalenceLemma(e1, S) && + compileInterpretEquivalenceLemma(e2, Cons[A](interpret(e1), S)) + }) + } holds +} diff --git a/testcases/verification/compilation/IntExprCompiler.scala b/testcases/verification/compilation/IntExprCompiler.scala new file mode 100644 index 0000000000000000000000000000000000000000..b7f46a49a27cf4092518afa08f0cef5ff8f724c7 --- /dev/null +++ b/testcases/verification/compilation/IntExprCompiler.scala @@ -0,0 +1,107 @@ +import leon.lang._ +import leon.lang.Option +import leon.collection._ +import leon.annotation._ +import leon.proof._ + +object TinyCertifiedCompiler { + + abstract class ByteCode + case class Load(c: BigInt) extends ByteCode // loads a constant in to the stack + case class OpInst() extends ByteCode + + abstract class ExprTree + case class Const(c: BigInt) extends ExprTree + case class Op(e1: ExprTree, e2: ExprTree) extends ExprTree + + def compile(e: ExprTree): List[ByteCode] = { + e match { + case Const(c) => + Cons(Load(c), Nil[ByteCode]()) + case Op(e1, e2) => + (compile(e1) ++ compile(e2)) ++ Cons(OpInst(), Nil[ByteCode]()) + } + } + + def op(x: BigInt, y: BigInt): BigInt = { + x - y + } + + def run(bytecode: List[ByteCode], S: List[BigInt]): List[BigInt] = { + (bytecode, S) match { + case (Cons(Load(c), tail), _) => + run(tail, Cons[BigInt](c, S)) // adding elements to the head of the stack + case (Cons(OpInst(), tail), Cons(x, Cons(y, rest))) => + run(tail, Cons[BigInt](op(y, x), rest)) + case (Cons(_, tail), _) => + run(tail, S) + case (Nil(), _) => // no bytecode to execute + S + } + } + + def interpret(e: ExprTree): BigInt = { + e match { + case Const(c) => c + case Op(e1, e2) => + op(interpret(e1), interpret(e2)) + } + } + + def runAppendLemma(l1: List[ByteCode], l2: List[ByteCode], S: List[BigInt]): Boolean = { + // lemma + (run(l1 ++ l2, S) == run(l2, run(l1, S))) because + // induction scheme (induct over l1) + (l1 match { + case Nil() => + true + case Cons(h, tail) => + (h, S) match { + case (Load(c), _) => + runAppendLemma(tail, l2, Cons[BigInt](c, S)) + case (OpInst(), Cons(x, Cons(y, rest))) => + runAppendLemma(tail, l2, Cons[BigInt](op(y, x), rest)) + case (_, _) => + runAppendLemma(tail, l2, S) + case _ => + true + } + }) + }.holds + + def run1(bytecode: List[ByteCode], S: List[BigInt]): List[BigInt] = { + (bytecode, S) match { + case (Cons(Load(c), tail), _) => + run1(tail, Cons[BigInt](c, S)) // adding elements to the head of the stack + case (Cons(OpInst(), tail), Cons(x, Cons(y, rest))) => + run1(tail, Cons[BigInt](op(x, y), rest)) + case (Cons(_, tail), _) => + run1(tail, S) + case (Nil(), _) => // no bytecode to execute + S + } + } + + def compileInterpretEquivalenceLemma1(e: ExprTree, S: List[BigInt]): Boolean = { + // lemma + (run1(compile(e), S) == interpret(e) :: S) + } holds + + def compileInterpretEquivalenceLemma(e: ExprTree, S: List[BigInt]): Boolean = { + // lemma + (run(compile(e), S) == interpret(e) :: S) because + // induction scheme + (e match { + case Const(c) => + true + case Op(e1, e2) => + // lemma instantiation + val c1 = compile(e1) + val c2 = compile(e2) + runAppendLemma((c1 ++ c2), Cons(OpInst(), Nil[ByteCode]()), S) && + runAppendLemma(c1, c2, S) && + compileInterpretEquivalenceLemma(e1, S) && + compileInterpretEquivalenceLemma(e2, Cons[BigInt](interpret(e1), S)) + }) + } holds +}