diff --git a/build.sbt b/build.sbt index 9215fea007959ffb86d5226f3e5427c58697aa9e..5b984ed01dc07a99e12adf882198d16d2a667043 100644 --- a/build.sbt +++ b/build.sbt @@ -143,7 +143,7 @@ def ghProject(repo: String, version: String) = RootProject(uri(s"${repo}#${versi lazy val bonsai = ghProject("git://github.com/colder/bonsai.git", "0fec9f97f4220fa94b1f3f305f2e8b76a3cd1539") -lazy val scalaSmtLib = ghProject("git://github.com/regb/scala-smtlib.git", "3b6ef4992b6af15d08a7320fd12202f35e97b905") +lazy val scalaSmtLib = ghProject("git://github.com/regb/scala-smtlib.git", "372bb14d0c84953acc17f9a7e1592087adb0a3e1") lazy val root = (project in file(".")). configs(RegressionTest, IsabelleTest, IntegrTest). diff --git a/library/lang/Set.scala b/library/lang/Set.scala index 36c70a8372dd7a81d1737dfd321e81839d38fe6f..8f4595d33fada35acfa38435c3cd542116cf3cbd 100644 --- a/library/lang/Set.scala +++ b/library/lang/Set.scala @@ -17,6 +17,7 @@ case class Set[T](val theSet: scala.collection.immutable.Set[T]) { def ++(a: Set[T]): Set[T] = new Set[T](theSet ++ a.theSet) def -(a: T): Set[T] = new Set[T](theSet - a) def --(a: Set[T]): Set[T] = new Set[T](theSet -- a.theSet) + def size: BigInt = theSet.size def contains(a: T): Boolean = theSet.contains(a) def isEmpty: Boolean = theSet.isEmpty def subsetOf(b: Set[T]): Boolean = theSet.subsetOf(b.theSet) diff --git a/library/lang/package.scala b/library/lang/package.scala index 44351d02ca9328bbbf86063745d233f2db17cd7c..bb7ec69122f4814bd61b1ce79f94c2eab9054c74 100644 --- a/library/lang/package.scala +++ b/library/lang/package.scala @@ -47,6 +47,9 @@ package object lang { @ignore def error[T](reason: java.lang.String): T = sys.error(reason) + @ignore + def old[T](value: T): T = value + @ignore implicit class Passes[A,B](io : (A,B)) { val (in, out) = io diff --git a/src/main/java/leon/codegen/runtime/Forall.java b/src/main/java/leon/codegen/runtime/Forall.java new file mode 100644 index 0000000000000000000000000000000000000000..f6877b604bcd069af6c2ce20d78a17d82d434dbe --- /dev/null +++ b/src/main/java/leon/codegen/runtime/Forall.java @@ -0,0 +1,35 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon.codegen.runtime; + +import java.util.HashMap; + +public abstract class Forall { + private static final HashMap<Tuple, Boolean> cache = new HashMap<Tuple, Boolean>(); + + protected final LeonCodeGenRuntimeHenkinMonitor monitor; + protected final Tuple closures; + protected final boolean check; + + public Forall(LeonCodeGenRuntimeMonitor monitor, Tuple closures) throws LeonCodeGenEvaluationException { + if (!(monitor instanceof LeonCodeGenRuntimeHenkinMonitor)) + throw new LeonCodeGenEvaluationException("Can't evaluate foralls without domain"); + + this.monitor = (LeonCodeGenRuntimeHenkinMonitor) monitor; + this.closures = closures; + this.check = (boolean) closures.get(closures.getArity() - 1); + } + + public boolean check() { + Tuple key = new Tuple(new Object[] { getClass(), monitor, closures }); // check is in the closures + if (cache.containsKey(key)) { + return cache.get(key); + } else { + boolean res = checkForall(); + cache.put(key, res); + return res; + } + } + + public abstract boolean checkForall(); +} diff --git a/src/main/java/leon/codegen/runtime/Lambda.java b/src/main/java/leon/codegen/runtime/Lambda.java index a6abbef37edbe8f87f480a21a6200e32a9e0206b..af255726311655efaeddea545c5e6e44afc15b8e 100644 --- a/src/main/java/leon/codegen/runtime/Lambda.java +++ b/src/main/java/leon/codegen/runtime/Lambda.java @@ -4,4 +4,6 @@ package leon.codegen.runtime; public abstract class Lambda { public abstract Object apply(Object[] args) throws LeonCodeGenRuntimeException; + public abstract void checkForall(boolean[] quantified); + public abstract void checkAxiom(); } diff --git a/src/main/java/leon/codegen/runtime/LeonCodeGenQuantificationException.java b/src/main/java/leon/codegen/runtime/LeonCodeGenQuantificationException.java new file mode 100644 index 0000000000000000000000000000000000000000..f172316a2548a52c6b294f70101a15ebbb8ce98a --- /dev/null +++ b/src/main/java/leon/codegen/runtime/LeonCodeGenQuantificationException.java @@ -0,0 +1,14 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon.codegen.runtime; + +/** Such exceptions are thrown when the evaluator encounters a forall + * expression whose shape is not supported in Leon. */ +public class LeonCodeGenQuantificationException extends Exception { + + private static final long serialVersionUID = -1824885321497473916L; + + public LeonCodeGenQuantificationException(String msg) { + super(msg); + } +} diff --git a/src/main/java/leon/codegen/runtime/LeonCodeGenRuntimeHenkinMonitor.java b/src/main/java/leon/codegen/runtime/LeonCodeGenRuntimeHenkinMonitor.java index 0be4ad91212be930ec5f8730ed30ebcf9a9f4e0a..597beec44b6a1a1719909e00ecb7d7916f0c7c03 100644 --- a/src/main/java/leon/codegen/runtime/LeonCodeGenRuntimeHenkinMonitor.java +++ b/src/main/java/leon/codegen/runtime/LeonCodeGenRuntimeHenkinMonitor.java @@ -7,26 +7,42 @@ import java.util.LinkedList; import java.util.HashMap; public class LeonCodeGenRuntimeHenkinMonitor extends LeonCodeGenRuntimeMonitor { - private final HashMap<Integer, List<Tuple>> domains = new HashMap<Integer, List<Tuple>>(); + private final HashMap<Integer, List<Tuple>> tpes = new HashMap<Integer, List<Tuple>>(); + private final HashMap<Class<?>, List<Tuple>> lambdas = new HashMap<Class<?>, List<Tuple>>(); + public final boolean checkForalls; - public LeonCodeGenRuntimeHenkinMonitor(int maxInvocations) { + public LeonCodeGenRuntimeHenkinMonitor(int maxInvocations, boolean checkForalls) { super(maxInvocations); + this.checkForalls = checkForalls; + } + + public LeonCodeGenRuntimeHenkinMonitor(int maxInvocations) { + this(maxInvocations, false); } public void add(int type, Tuple input) { - if (!domains.containsKey(type)) domains.put(type, new LinkedList<Tuple>()); - domains.get(type).add(input); + if (!tpes.containsKey(type)) tpes.put(type, new LinkedList<Tuple>()); + tpes.get(type).add(input); + } + + public void add(Class<?> clazz, Tuple input) { + if (!lambdas.containsKey(clazz)) lambdas.put(clazz, new LinkedList<Tuple>()); + lambdas.get(clazz).add(input); } public List<Tuple> domain(Object obj, int type) { List<Tuple> domain = new LinkedList<Tuple>(); if (obj instanceof PartialLambda) { - for (Tuple key : ((PartialLambda) obj).mapping.keySet()) { + PartialLambda l = (PartialLambda) obj; + for (Tuple key : l.mapping.keySet()) { domain.add(key); } + } else if (obj instanceof Lambda) { + List<Tuple> lambdaDomain = lambdas.get(obj.getClass()); + if (lambdaDomain != null) domain.addAll(lambdaDomain); } - List<Tuple> tpeDomain = domains.get(type); + List<Tuple> tpeDomain = tpes.get(type); if (tpeDomain != null) domain.addAll(tpeDomain); return domain; diff --git a/src/main/java/leon/codegen/runtime/PartialLambda.java b/src/main/java/leon/codegen/runtime/PartialLambda.java index 826cc5ed9930e54bc2f50d7f09e6fa09be3fa307..b04036db5e9f81d1eaf7fa2c9a047bfef45a4df8 100644 --- a/src/main/java/leon/codegen/runtime/PartialLambda.java +++ b/src/main/java/leon/codegen/runtime/PartialLambda.java @@ -6,9 +6,15 @@ import java.util.HashMap; public final class PartialLambda extends Lambda { final HashMap<Tuple, Object> mapping = new HashMap<Tuple, Object>(); + private final Object dflt; public PartialLambda() { + this(null); + } + + public PartialLambda(Object dflt) { super(); + this.dflt = dflt; } public void add(Tuple key, Object value) { @@ -20,15 +26,18 @@ public final class PartialLambda extends Lambda { Tuple tuple = new Tuple(args); if (mapping.containsKey(tuple)) { return mapping.get(tuple); + } else if (dflt != null) { + return dflt; } else { - throw new LeonCodeGenRuntimeException("Partial function apply on undefined arguments"); + throw new LeonCodeGenRuntimeException("Partial function apply on undefined arguments " + tuple); } } @Override public boolean equals(Object that) { if (that != null && (that instanceof PartialLambda)) { - return mapping.equals(((PartialLambda) that).mapping); + PartialLambda l = (PartialLambda) that; + return ((dflt != null && dflt.equals(l.dflt)) || (dflt == null && l.dflt == null)) && mapping.equals(l.mapping); } else { return false; } @@ -36,6 +45,12 @@ public final class PartialLambda extends Lambda { @Override public int hashCode() { - return 63 + 11 * mapping.hashCode(); + return 63 + 11 * mapping.hashCode() + (dflt == null ? 0 : dflt.hashCode()); } + + @Override + public void checkForall(boolean[] quantified) {} + + @Override + public void checkAxiom() {} } diff --git a/src/main/java/leon/codegen/runtime/Set.java b/src/main/java/leon/codegen/runtime/Set.java index 965aa20df6ba453e2fcbdfb0c2e4afce66d29d27..522dbe6589eac54bf2acac3e1f81ee4b4c7c962c 100644 --- a/src/main/java/leon/codegen/runtime/Set.java +++ b/src/main/java/leon/codegen/runtime/Set.java @@ -49,8 +49,8 @@ public final class Set { return true; } - public int size() { - return _underlying.size(); + public BigInt size() { + return new BigInt(""+_underlying.size()); } public Set union(Set s) { @@ -84,7 +84,7 @@ public final class Set { Set other = (Set)that; - return this.size() == other.size() && this.subsetOf(other); + return this.size().equals(other.size()) && this.subsetOf(other); } @Override diff --git a/src/main/java/leon/codegen/runtime/Tuple.java b/src/main/java/leon/codegen/runtime/Tuple.java index 9ae7a5f490223b0bc3b6a07dbb05160c76dc5e11..3b72931da6fc04b3504d7f3bee2056560c269634 100644 --- a/src/main/java/leon/codegen/runtime/Tuple.java +++ b/src/main/java/leon/codegen/runtime/Tuple.java @@ -54,4 +54,20 @@ public final class Tuple { _hash = h; return h; } + + @Override + public String toString() { + String str = "("; + boolean first = true; + for (Object obj : elements) { + if (first) { + first = false; + } else { + str += ", "; + } + str += obj == null ? "null" : obj.toString(); + } + str += ")"; + return str; + } } diff --git a/src/main/scala/leon/LeonContext.scala b/src/main/scala/leon/LeonContext.scala index dd4e25ba35ef621c92b5cc880163be77c424295f..6fde42aecd40673887f723438b9ab2acbdbb3ef6 100644 --- a/src/main/scala/leon/LeonContext.scala +++ b/src/main/scala/leon/LeonContext.scala @@ -9,8 +9,9 @@ import java.io.File import scala.reflect.ClassTag /** Everything that is part of a compilation unit, except the actual program tree. - * Contexts are immutable, and so should all there fields (with the possible - * exception of the reporter). */ + * Contexts are immutable, and so should all there fields (with the possible + * exception of the reporter). + */ case class LeonContext( reporter: Reporter, interruptManager: InterruptManager, @@ -20,9 +21,6 @@ case class LeonContext( timers: TimerStorage = new TimerStorage ) { - // @mk: This is not typesafe, because equality for options is implemented as name equality. - // It will fail if an LeonOptionDef is passed that has the same name - // with one in Main,allOptions, but is different def findOption[A: ClassTag](optDef: LeonOptionDef[A]): Option[A] = options.collectFirst { case LeonOption(`optDef`, value:A) => value } diff --git a/src/main/scala/leon/codegen/CodeGeneration.scala b/src/main/scala/leon/codegen/CodeGeneration.scala index 33975a011737d50a45d1511b108d4a2025208eb7..8559b404dd23361ced4eefd6d967fa6d1252f786 100644 --- a/src/main/scala/leon/codegen/CodeGeneration.scala +++ b/src/main/scala/leon/codegen/CodeGeneration.scala @@ -6,7 +6,7 @@ package codegen import purescala.Common._ import purescala.Definitions._ import purescala.Expressions._ -import purescala.ExprOps.{simplestValue, matchToIfThenElse, collect} +import purescala.ExprOps._ import purescala.Types._ import purescala.Constructors._ import purescala.Extractors._ @@ -47,6 +47,8 @@ trait CodeGeneration { def withArgs(newArgs: Map[Identifier, Int]) = new Locals(vars, args ++ newArgs, fields) def withFields(newFields: Map[Identifier,(String,String,String)]) = new Locals(vars, args, fields ++ newFields) + + override def toString = "Locals("+vars + ", " + args + ", " + fields + ")" } object NoLocals extends Locals(Map.empty, Map.empty, Map.empty) @@ -70,8 +72,12 @@ trait CodeGeneration { private[codegen] val RationalClass = "leon/codegen/runtime/Rational" private[codegen] val CaseClassClass = "leon/codegen/runtime/CaseClass" private[codegen] val LambdaClass = "leon/codegen/runtime/Lambda" + private[codegen] val ForallClass = "leon/codegen/runtime/Forall" + private[codegen] val PartialLambdaClass = "leon/codegen/runtime/PartialLambda" private[codegen] val ErrorClass = "leon/codegen/runtime/LeonCodeGenRuntimeException" + private[codegen] val InvalidForallClass = "leon/codegen/runtime/LeonCodeGenQuantificationException" private[codegen] val ImpossibleEvaluationClass = "leon/codegen/runtime/LeonCodeGenEvaluationException" + private[codegen] val BadQuantificationClass = "leon/codegen/runtime/LeonCodeGenQuantificationException" private[codegen] val HashingClass = "leon/codegen/runtime/LeonCodeGenRuntimeHashing" private[codegen] val ChooseEntryPointClass = "leon/codegen/runtime/ChooseEntryPoint" private[codegen] val GenericValuesClass = "leon/codegen/runtime/GenericValues" @@ -223,8 +229,8 @@ trait CodeGeneration { private[codegen] val lambdaToClass = scala.collection.mutable.Map.empty[Lambda, String] private[codegen] val classToLambda = scala.collection.mutable.Map.empty[String, Lambda] - private def compileLambda(l: Lambda, ch: CodeHandler)(implicit locals: Locals): Unit = { - val (normalized, structSubst) = purescala.ExprOps.normalizeStructure(l) + protected def compileLambda(l: Lambda): (String, Seq[(Identifier, String)], String) = { + val (normalized, structSubst) = purescala.ExprOps.normalizeStructure(matchToIfThenElse(l)) val reverseSubst = structSubst.map(p => p._2 -> p._1) val nl = normalized.asInstanceOf[Lambda] @@ -278,6 +284,10 @@ trait CodeGeneration { cch.freeze } + val argMapping = nl.args.map(_.id).zipWithIndex.toMap + val closureMapping = closures.map { case (id, jvmt) => id -> (afName, id.uniqueName, jvmt) }.toMap + val newLocals = NoLocals.withArgs(argMapping).withFields(closureMapping) + locally { val apm = cf.addMethod(s"L$ObjectClass;", "apply", s"[L$ObjectClass;") @@ -286,11 +296,6 @@ trait CodeGeneration { METHOD_ACC_FINAL ).asInstanceOf[U2]) - val argMapping = nl.args.map(_.id).zipWithIndex.toMap - val closureMapping = closures.map { case (id, jvmt) => id -> (afName, id.uniqueName, jvmt) }.toMap - - val newLocals = locals.withArgs(argMapping).withFields(closureMapping) - val apch = apm.codeHandler mkBoxedExpr(nl.body, apch)(newLocals) @@ -375,22 +380,135 @@ trait CodeGeneration { hch.freeze } + locally { + val vmh = cf.addMethod("V", "checkForall", "[Z") + vmh.setFlags(( + METHOD_ACC_PUBLIC | + METHOD_ACC_FINAL + ).asInstanceOf[U2]) + + val vch = vmh.codeHandler + + vch << ALoad(1) // load argument array + def rec(args: Seq[Identifier], idx: Int, quantified: Set[Identifier]): Unit = args match { + case x :: xs => + val notQuantLabel = vch.getFreshLabel("notQuant") + vch << DUP << Ldc(idx) << BALOAD << IfEq(notQuantLabel) + rec(xs, idx + 1, quantified + x) + vch << Label(notQuantLabel) + rec(xs, idx + 1, quantified) + + case Nil => + if (quantified.nonEmpty) { + checkQuantified(quantified, nl.body, vch)(newLocals) + vch << ALoad(0) << InvokeVirtual(LambdaClass, "checkAxiom", "()V") + } + vch << POP << RETURN + } + + if (requireQuantification) { + rec(nl.args.map(_.id), 0, Set.empty) + } else { + vch << POP << RETURN + } + + vch.freeze + } + + locally { + val vmh = cf.addMethod("V", "checkAxiom") + vmh.setFlags(( + METHOD_ACC_PUBLIC | + METHOD_ACC_FINAL + ).asInstanceOf[U2]) + + val vch = vmh.codeHandler + + if (requireQuantification) { + val thisVar = FreshIdentifier("this", l.getType) + val axiom = Equals(Application(Variable(thisVar), nl.args.map(_.toVariable)), nl.body) + val axiomLocals = NoLocals.withFields(closureMapping).withVar(thisVar -> 0) + + mkForall(nl.args.map(_.id).toSet, axiom, vch, check = false)(axiomLocals) + + val skip = vch.getFreshLabel("skip") + vch << IfNe(skip) + vch << New(InvalidForallClass) << DUP + vch << Ldc("Unaxiomatic lambda " + l) + vch << InvokeSpecial(InvalidForallClass, constructorName, "(Ljava/lang/String;)V") + vch << ATHROW + vch << Label(skip) + } + + vch << RETURN + vch.freeze + } + loader.register(cf) afName }) - val consSig = "(" + closures.map(_._2).mkString("") + ")V" + (afName, closures.map { case p @ (id, jvmt) => + if (id == monitorID) p else (reverseSubst(id) -> jvmt) + }, "(" + closures.map(_._2).mkString("") + ")V") + } - ch << New(afName) << DUP - for ((id,jvmt) <- closures) { - if (id == monitorID) { - load(monitorID, ch) - } else { - mkExpr(Variable(reverseSubst(id)), ch) - } + private def checkQuantified(quantified: Set[Identifier], body: Expr, ch: CodeHandler)(implicit locals: Locals): Unit = { + val skipCheck = ch.getFreshLabel("skipCheck") + + load(monitorID, ch) + ch << CheckCast(HenkinClass) << GetField(HenkinClass, "checkForalls", "Z") + ch << IfEq(skipCheck) + + checkForall(quantified, body)(ctx) match { + case status: ForallInvalid => + ch << New(InvalidForallClass) << DUP + ch << Ldc("Invalid forall: " + status.getMessage) + ch << InvokeSpecial(InvalidForallClass, constructorName, "(Ljava/lang/String;)V") + ch << ATHROW + + case ForallValid => + // expand match case expressions and lets so that caller can be compiled given + // the current locals (lets and matches introduce new locals) + val cleanBody = purescala.ExprOps.expandLets(purescala.ExprOps.matchToIfThenElse(body)) + + val calls = new CollectorWithPaths[(Expr, Seq[Expr], Seq[Expr])] { + def collect(e: Expr, path: Seq[Expr]): Option[(Expr, Seq[Expr], Seq[Expr])] = e match { + case QuantificationMatcher(IsTyped(caller, _: FunctionType), args) => Some((caller, args, path)) + case _ => None + } + + override def rec(e: Expr, path: Seq[Expr]): Expr = e match { + case l : Lambda => l + case _ => super.rec(e, path) + } + }.traverse(cleanBody) + + for ((caller, args, paths) <- calls) { + if ((variablesOf(caller) & quantified).isEmpty) { + val enabler = andJoin(paths.filter(expr => (variablesOf(expr) & quantified).isEmpty)) + val skipCall = ch.getFreshLabel("skipCall") + mkExpr(enabler, ch) + ch << IfEq(skipCall) + + mkExpr(caller, ch) + ch << Ldc(args.size) << NewArray.primitive("T_BOOLEAN") + for ((arg, idx) <- args.zipWithIndex) { + ch << DUP << Ldc(idx) << Ldc(arg match { + case Variable(id) if quantified(id) => 1 + case _ => 0 + }) << BASTORE + } + + ch << InvokeVirtual(LambdaClass, "checkForall", "([Z)V") + + ch << Label(skipCall) + } + } } - ch << InvokeSpecial(afName, constructorName, consSig) + + ch << Label(skipCheck) } private val typeIdCache = scala.collection.mutable.Map.empty[TypeTree, Int] @@ -402,134 +520,270 @@ trait CodeGeneration { id } - private def compileForall(f: Forall, ch: CodeHandler)(implicit locals: Locals): Unit = { - // make sure we have an available HenkinModel - val monitorOk = ch.getFreshLabel("monitorOk") + private[codegen] val forallToClass = scala.collection.mutable.Map.empty[Expr, String] + + private def mkForall(quants: Set[Identifier], body: Expr, ch: CodeHandler, check: Boolean = true)(implicit locals: Locals): Unit = { + val (afName, closures, consSig) = compileForall(quants, body) + ch << New(afName) << DUP load(monitorID, ch) - ch << InstanceOf(HenkinClass) << IfNe(monitorOk) - ch << New(ImpossibleEvaluationClass) << DUP - ch << Ldc("Can't evaluate foralls without domain") - ch << InvokeSpecial(ImpossibleEvaluationClass, constructorName, "(Ljava/lang/String;)V") - ch << ATHROW - ch << Label(monitorOk) - - val Forall(fargs, TopLevelAnds(conjuncts)) = f - val endLabel = ch.getFreshLabel("forallEnd") - - for (conj <- conjuncts) { - val vars = purescala.ExprOps.variablesOf(conj) - val args = fargs.map(_.id).filter(vars) - val quantified = args.toSet - - val matchQuorums = extractQuorums(conj, quantified) - - val matcherIndexes = matchQuorums.flatten.distinct.zipWithIndex.toMap - - def buildLoops( - mis: List[(Expr, Seq[Expr], Int)], - localMapping: Map[Identifier, Int], - pointerMapping: Map[(Int, Int), Identifier] - ): Unit = mis match { - case (expr, args, qidx) :: rest => - load(monitorID, ch) - ch << CheckCast(HenkinClass) + mkTuple(closures.map(_.toVariable) :+ BooleanLiteral(check), ch) + ch << InvokeSpecial(afName, constructorName, consSig) + ch << InvokeVirtual(ForallClass, "check", "()Z") + } + + private def compileForall(quants: Set[Identifier], body: Expr): (String, Seq[Identifier], String) = { + val (nl, structSubst) = purescala.ExprOps.normalizeStructure(matchToIfThenElse(body)) + val reverseSubst = structSubst.map(p => p._2 -> p._1) + val nquants = quants.flatMap(structSubst.get) - mkExpr(expr, ch) - ch << Ldc(typeId(expr.getType)) - ch << InvokeVirtual(HenkinClass, "domain", s"(L$ObjectClass;I)L$JavaListClass;") - ch << InvokeInterface(JavaListClass, "iterator", s"()L$JavaIteratorClass;") + val closures = (purescala.ExprOps.variablesOf(nl) -- nquants).toSeq.sortBy(_.uniqueName) - val loop = ch.getFreshLabel("loop") - val out = ch.getFreshLabel("out") - ch << Label(loop) - // it - ch << DUP - // it, it - ch << InvokeInterface(JavaIteratorClass, "hasNext", "()Z") - // it, hasNext - ch << IfEq(out) << DUP - // it, it - ch << InvokeInterface(JavaIteratorClass, "next", s"()L$ObjectClass;") - // it, elem - ch << CheckCast(TupleClass) - - val (newLoc, newPtr) = (for ((arg, aidx) <- args.zipWithIndex) yield { - val id = FreshIdentifier("q", arg.getType, true) - val slot = ch.getFreshVar - - ch << DUP << Ldc(aidx) << InvokeVirtual(TupleClass, "get", s"(I)L$ObjectClass;") - mkUnbox(arg.getType, ch) - ch << (typeToJVM(arg.getType) match { - case "I" | "Z" => IStore(slot) - case _ => AStore(slot) - }) - - (id -> slot, (qidx -> aidx) -> id) - }).unzip - - ch << POP - // it - - buildLoops(rest, localMapping ++ newLoc, pointerMapping ++ newPtr) - - ch << Goto(loop) - ch << Label(out) << POP - - case Nil => - var okLabel: Option[String] = None - for (quorum <- matchQuorums) { - okLabel.foreach(ok => ch << Label(ok)) - okLabel = Some(ch.getFreshLabel("quorumOk")) - - var mappings: Seq[(Identifier, Int, Int)] = Seq.empty - var constraints: Seq[(Expr, Int, Int)] = Seq.empty - var equalities: Seq[((Int, Int), (Int, Int))] = Seq.empty - - for (q @ (expr, args) <- quorum) { - val qidx = matcherIndexes(q) - val (qmappings, qconstraints) = args.zipWithIndex.partition { - case (Variable(id), aidx) => quantified(id) - case _ => false - } - - mappings ++= qmappings.map(p => (p._1.asInstanceOf[Variable].id, qidx, p._2)) - constraints ++= qconstraints.map(p => (p._1, qidx, p._2)) - } + val afName = forallToClass.getOrElse(nl, { + val afName = "Leon$CodeGen$Forall$" + forallCounter.nextGlobal + forallToClass += nl -> afName - val mapping = for ((id, es) <- mappings.groupBy(_._1)) yield { - val base :: others = es.toList.map(p => (p._2, p._3)) - equalities ++= others.map(p => base -> p) - (id -> base) - } + val cf = new ClassFile(afName, Some(ForallClass)) + + cf.setFlags(( + CLASS_ACC_SUPER | + CLASS_ACC_PUBLIC | + CLASS_ACC_FINAL + ).asInstanceOf[U2]) - val enabler = andJoin(constraints.map { - case (e, qidx, aidx) => Equals(e, pointerMapping(qidx -> aidx).toVariable) - } ++ equalities.map { - case (k1, k2) => Equals(pointerMapping(k1).toVariable, pointerMapping(k2).toVariable) - }) + locally { + val cch = cf.addConstructor(s"L$MonitorClass;", s"L$TupleClass;").codeHandler + + cch << ALoad(0) << ALoad(1) << ALoad(2) + cch << InvokeSpecial(ForallClass, constructorName, s"(L$MonitorClass;L$TupleClass;)V") + cch << RETURN + cch.freeze + } + + locally { + val cfm = cf.addMethod("Z", "checkForall") + + cfm.setFlags(( + METHOD_ACC_PUBLIC | + METHOD_ACC_FINAL + ).asInstanceOf[U2]) + + val cfch = cfm.codeHandler + + cfch << ALoad(0) << GetField(ForallClass, "closures", s"L$TupleClass;") + + val closureVars = (for ((id, idx) <- closures.zipWithIndex) yield { + val slot = cfch.getFreshVar + cfch << DUP << Ldc(idx) << InvokeVirtual(TupleClass, "get", s"(I)L$ObjectClass;") + mkUnbox(id.getType, cfch) + cfch << (id.getType match { + case ValueType() => IStore(slot) + case _ => AStore(slot) + }) + id -> slot + }).toMap + + cfch << POP + + val monitorSlot = cfch.getFreshVar + cfch << ALoad(0) << GetField(ForallClass, "monitor", s"L$HenkinClass;") + cfch << AStore(monitorSlot) + + implicit val locals = NoLocals.withVars(closureVars).withVar(monitorID -> monitorSlot) + + val skipCheck = cfch.getFreshLabel("skipCheck") + cfch << ALoad(0) << GetField(ForallClass, "check", "Z") + cfch << IfEq(skipCheck) + checkQuantified(nquants, nl, cfch) + cfch << Label(skipCheck) + + val TopLevelAnds(conjuncts) = nl + val endLabel = cfch.getFreshLabel("forallEnd") + + for (conj <- conjuncts) { + val vars = purescala.ExprOps.variablesOf(conj) + val quantified = nquants.filter(vars) + + val matchQuorums = extractQuorums(conj, quantified) + + var allSlots: List[Int] = Nil + var freeSlots: List[Int] = Nil + def getSlot(): Int = freeSlots match { + case x :: xs => + freeSlots = xs + x + case Nil => + val slot = cfch.getFreshVar + allSlots = allSlots :+ slot + slot + } - mkExpr(enabler, ch)(locals.withVars(localMapping)) - ch << IfEq(okLabel.get) + for ((qrm, others) <- matchQuorums) { + val quorum = qrm.toList + + def rec(mis: List[(Expr, Expr, Seq[Expr], Int)], locs: Map[Identifier, Int], pointers: Map[(Int, Int), Identifier]): Unit = mis match { + case (TopLevelAnds(paths), expr, args, qidx) :: rest => + load(monitorID, cfch) + cfch << CheckCast(HenkinClass) + + mkExpr(expr, cfch) + cfch << Ldc(typeId(expr.getType)) + cfch << InvokeVirtual(HenkinClass, "domain", s"(L$ObjectClass;I)L$JavaListClass;") + cfch << InvokeInterface(JavaListClass, "iterator", s"()L$JavaIteratorClass;") + + val loop = cfch.getFreshLabel("loop") + val out = cfch.getFreshLabel("out") + cfch << Label(loop) + // it + cfch << DUP + // it, it + cfch << InvokeInterface(JavaIteratorClass, "hasNext", "()Z") + // it, hasNext + cfch << IfEq(out) << DUP + // it, it + cfch << InvokeInterface(JavaIteratorClass, "next", s"()L$ObjectClass;") + // it, elem + cfch << CheckCast(TupleClass) + + val (newLoc, newPtr) = (for ((arg, aidx) <- args.zipWithIndex) yield { + val id = FreshIdentifier("q", arg.getType, true) + val slot = getSlot() + + cfch << DUP << Ldc(aidx) << InvokeVirtual(TupleClass, "get", s"(I)L$ObjectClass;") + mkUnbox(arg.getType, cfch) + cfch << (typeToJVM(arg.getType) match { + case "I" | "Z" => IStore(slot) + case _ => AStore(slot) + }) + + (id -> slot, (qidx -> aidx) -> id) + }).unzip + + cfch << POP + // it + + rec(rest, locs ++ newLoc, pointers ++ newPtr) + + cfch << Goto(loop) + cfch << Label(out) << POP + + case Nil => + val okLabel = cfch.getFreshLabel("assignmentOk") + + var mappings: Seq[(Identifier, Int, Int)] = Seq.empty + var constraints: Seq[(Expr, Int, Int)] = Seq.empty + var equalities: Seq[((Int, Int), (Int, Int))] = Seq.empty + + for ((q @ (_, expr, args), qidx) <- quorum.zipWithIndex) { + val (qmappings, qconstraints) = args.zipWithIndex.partition { + case (Variable(id), aidx) => quantified(id) + case _ => false + } + + mappings ++= qmappings.map(p => (p._1.asInstanceOf[Variable].id, qidx, p._2)) + constraints ++= qconstraints.map(p => (p._1, qidx, p._2)) + } + + val mapping = for ((id, es) <- mappings.groupBy(_._1)) yield { + val base :: others = es.toList.map(p => (p._2, p._3)) + equalities ++= others.map(p => base -> p) + (id -> base) + } + + val enabler = andJoin(constraints.map { + case (e, qidx, aidx) => Equals(e, pointers(qidx -> aidx).toVariable) + } ++ equalities.map { + case (k1, k2) => Equals(pointers(k1).toVariable, pointers(k2).toVariable) + }) + + val varsMap = quantified.map(id => id -> locs(pointers(mapping(id)))).toMap + val varLocals = locals.withVars(varsMap) + + mkExpr(enabler, cfch)(varLocals.withVars(locs)) + cfch << IfEq(okLabel) + + val checkOk = cfch.getFreshLabel("checkOk") + load(monitorID, cfch) + cfch << GetField(HenkinClass, "checkForalls", "Z") + cfch << IfEq(checkOk) + + var nextLabel: Option[String] = None + for ((b,caller,args) <- others) { + nextLabel.foreach(label => cfch << Label(label)) + nextLabel = Some(cfch.getFreshLabel("next")) + + mkExpr(b, cfch)(varLocals) + cfch << IfEq(nextLabel.get) + + load(monitorID, cfch) + cfch << CheckCast(HenkinClass) + mkExpr(caller, cfch)(varLocals) + cfch << Ldc(typeId(caller.getType)) + cfch << InvokeVirtual(HenkinClass, "domain", s"(L$ObjectClass;I)L$JavaListClass;") + mkTuple(args, cfch)(varLocals) + cfch << InvokeInterface(JavaListClass, "contains", s"(L$ObjectClass;)Z") + cfch << IfNe(nextLabel.get) + + cfch << New(InvalidForallClass) << DUP + cfch << Ldc("Unhandled transitive implication in " + conj) + cfch << InvokeSpecial(InvalidForallClass, constructorName, "(Ljava/lang/String;)V") + cfch << ATHROW + } + nextLabel.foreach(label => cfch << Label(label)) + + cfch << Label(checkOk) + mkExpr(conj, cfch)(varLocals) + cfch << IfNe(okLabel) + + // -- Forall is false! -- + // POP all the iterators... + for (_ <- List.range(0, quorum.size)) cfch << POP + + // ... and return false + cfch << Ldc(0) << Goto(endLabel) + cfch << Label(okLabel) + } - val varsMap = args.map(id => id -> localMapping(pointerMapping(mapping(id)))).toMap - mkExpr(conj, ch)(locals.withVars(varsMap)) - ch << IfNe(okLabel.get) + val skipQuorum = cfch.getFreshLabel("skipQuorum") + for ((TopLevelAnds(paths), _, _) <- quorum) { + val p = andJoin(paths.filter(path => (variablesOf(path) & quantified).isEmpty)) + mkExpr(p, cfch) + cfch << IfEq(skipQuorum) + } - // -- Forall is false! -- - // POP all the iterators... - for (_ <- List.range(0, matcherIndexes.size)) ch << POP + val mis = quorum.zipWithIndex.map { case ((p, e, as), idx) => (p, e, as, idx) } + rec(mis, Map.empty, Map.empty) + freeSlots = allSlots - // ... and return false - ch << Ldc(0) << Goto(endLabel) + cfch << Label(skipQuorum) } + } + + cfch << Ldc(1) << Label(endLabel) + cfch << IRETURN - ch << Label(okLabel.get) + cfch.freeze } - buildLoops(matcherIndexes.toList.map { case ((e, as), idx) => (e, as, idx) }, Map.empty, Map.empty) - } + loader.register(cf) + + afName + }) - ch << Ldc(1) << Label(endLabel) + (afName, closures.map(reverseSubst), s"(L$MonitorClass;L$TupleClass;)V") + } + + // also makes tuples with 0/1 args + private def mkTuple(es: Seq[Expr], ch: CodeHandler)(implicit locals: Locals) : Unit = { + ch << New(TupleClass) << DUP + ch << Ldc(es.size) + ch << NewArray(s"$ObjectClass") + for((e,i) <- es.zipWithIndex) { + ch << DUP + ch << Ldc(i) + mkBoxedExpr(e, ch) + ch << AASTORE + } + ch << InvokeSpecial(TupleClass, constructorName, s"([L$ObjectClass;)V") } private[codegen] def mkExpr(e: Expr, ch: CodeHandler, canDelegateToMkBranch: Boolean = true)(implicit locals: Locals) { @@ -619,17 +873,7 @@ trait CodeGeneration { instrumentedGetField(ch, cct, sid) // Tuples (note that instanceOf checks are in mkBranch) - case Tuple(es) => - ch << New(TupleClass) << DUP - ch << Ldc(es.size) - ch << NewArray(s"$ObjectClass") - for((e,i) <- es.zipWithIndex) { - ch << DUP - ch << Ldc(i) - mkBoxedExpr(e, ch) - ch << AASTORE - } - ch << InvokeSpecial(TupleClass, constructorName, s"([L$ObjectClass;)V") + case Tuple(es) => mkTuple(es, ch) case TupleSelect(t, i) => val TupleType(bs) = t.getType @@ -654,7 +898,7 @@ trait CodeGeneration { case SetCardinality(s) => mkExpr(s, ch) - ch << InvokeVirtual(SetClass, "size", "()I") + ch << InvokeVirtual(SetClass, "size", s"()$BigIntClass;") case SubsetOf(s1, s2) => mkExpr(s1, ch) @@ -882,11 +1126,38 @@ trait CodeGeneration { ch << InvokeVirtual(LambdaClass, "apply", s"([L$ObjectClass;)L$ObjectClass;") mkUnbox(app.getType, ch) + case p @ PartialLambda(mapping, optDflt, _) => + ch << New(PartialLambdaClass) << DUP + optDflt match { + case Some(dflt) => + mkBoxedExpr(dflt, ch) + ch << InvokeSpecial(PartialLambdaClass, constructorName, s"(L$ObjectClass;)V") + case None => + ch << InvokeSpecial(PartialLambdaClass, constructorName, "()V") + } + + for ((es,v) <- mapping) { + ch << DUP + mkTuple(es, ch) + mkBoxedExpr(v, ch) + ch << InvokeVirtual(PartialLambdaClass, "add", s"(L$TupleClass;L$ObjectClass;)V") + } + case l @ Lambda(args, body) => - compileLambda(l, ch) + val (afName, closures, consSig) = compileLambda(l) + + ch << New(afName) << DUP + for ((id,jvmt) <- closures) { + if (id == monitorID) { + load(monitorID, ch) + } else { + mkExpr(Variable(id), ch) + } + } + ch << InvokeSpecial(afName, constructorName, consSig) case f @ Forall(args, body) => - compileForall(f, ch) + mkForall(args.map(_.id).toSet, body, ch) // Arithmetic case Plus(l, r) => @@ -1098,7 +1369,7 @@ trait CodeGeneration { ch << ATHROW case choose: Choose => - val prob = synthesis.Problem.fromChoose(choose) + val prob = synthesis.Problem.fromSpec(choose.pred) val id = runtime.ChooseEntryPoint.register(prob, this) ch << Ldc(id) @@ -1179,7 +1450,7 @@ trait CodeGeneration { // Assumes the top of the stack contains of value of the right type, and makes it // compatible with java.lang.Object. - private[codegen] def mkBox(tpe: TypeTree, ch: CodeHandler)(implicit locals: Locals) { + private[codegen] def mkBox(tpe: TypeTree, ch: CodeHandler): Unit = { tpe match { case Int32Type => ch << New(BoxedIntClass) << DUP_X1 << SWAP << InvokeSpecial(BoxedIntClass, constructorName, "(I)V") @@ -1197,7 +1468,7 @@ trait CodeGeneration { } // Assumes that the top of the stack contains a value that should be of type `tpe`, and unboxes it to the right (JVM) type. - private[codegen] def mkUnbox(tpe: TypeTree, ch: CodeHandler)(implicit locals: Locals) { + private[codegen] def mkUnbox(tpe: TypeTree, ch: CodeHandler): Unit = { tpe match { case Int32Type => ch << CheckCast(BoxedIntClass) << InvokeVirtual(BoxedIntClass, "intValue", "()I") @@ -1469,7 +1740,7 @@ trait CodeGeneration { lzy.returnType match { case ValueType() => // Since the underlying field only has boxed types, we have to unbox them to return them - mkUnbox(lzy.returnType, ch)(newLocs) + mkUnbox(lzy.returnType, ch) ch << IRETURN case _ => ch << ARETURN @@ -1803,7 +2074,7 @@ trait CodeGeneration { pech << Ldc(i) pech << ALoad(0) instrumentedGetField(pech, cct, f.id)(newLocs) - mkBox(f.getType, pech)(newLocs) + mkBox(f.getType, pech) pech << AASTORE } diff --git a/src/main/scala/leon/codegen/CompilationUnit.scala b/src/main/scala/leon/codegen/CompilationUnit.scala index 69f95f45559737bfd895989aba010dbc0667bb3a..556cf7b05420bbdf29583b46ba13121e5a5e0328 100644 --- a/src/main/scala/leon/codegen/CompilationUnit.scala +++ b/src/main/scala/leon/codegen/CompilationUnit.scala @@ -127,13 +127,24 @@ class CompilationUnit(val ctx: LeonContext, conss.last } - def modelToJVM(model: solvers.Model, maxInvocations: Int): LeonCodeGenRuntimeMonitor = model match { + def modelToJVM(model: solvers.Model, maxInvocations: Int, check: Boolean): LeonCodeGenRuntimeMonitor = model match { case hModel: solvers.HenkinModel => - val lhm = new LeonCodeGenRuntimeHenkinMonitor(maxInvocations) - for ((tpe, domain) <- hModel.domains; args <- domain) { + val lhm = new LeonCodeGenRuntimeHenkinMonitor(maxInvocations, check) + for ((lambda, domain) <- hModel.doms.lambdas) { + val (afName, _, _) = compileLambda(lambda) + val lc = loader.loadClass(afName) + + for (args <- domain) { + // note here that it doesn't matter that `lhm` doesn't yet have its domains + // filled since all values in `args` should be grounded + val inputJvm = tupleConstructor.newInstance(args.map(valueToJVM(_)(lhm)).toArray).asInstanceOf[leon.codegen.runtime.Tuple] + lhm.add(lc, inputJvm) + } + } + + for ((tpe, domain) <- hModel.doms.tpes; args <- domain) { val tpeId = typeId(tpe) - // note here that it doesn't matter that `lhm` doesn't yet have its domains - // filled since all values in `args` should be grounded + // same remark as above about valueToJVM(_)(lhm) val inputJvm = tupleConstructor.newInstance(args.map(valueToJVM(_)(lhm)).toArray).asInstanceOf[leon.codegen.runtime.Tuple] lhm.add(tpeId, inputJvm) } @@ -201,8 +212,13 @@ class CompilationUnit(val ctx: LeonContext, } m - case f @ PartialLambda(mapping, _) => - val l = new leon.codegen.runtime.PartialLambda() + case f @ PartialLambda(mapping, dflt, _) => + val l = if (dflt.isDefined) { + new leon.codegen.runtime.PartialLambda(dflt.get) + } else { + new leon.codegen.runtime.PartialLambda() + } + for ((ks,v) <- mapping) { // Force tuple even with 1/0 elems. val kJvm = tupleConstructor.newInstance(ks.map(valueToJVM).toArray).asInstanceOf[leon.codegen.runtime.Tuple] @@ -505,13 +521,11 @@ class CompilationUnit(val ctx: LeonContext, for { ch <- u.classHierarchies c <- ch - } { - c match { - case acd: AbstractClassDef => - compileAbstractClassDef(acd) - case ccd: CaseClassDef => - compileCaseClassDef(ccd) - } + } c match { + case acd: AbstractClassDef => + compileAbstractClassDef(acd) + case ccd: CaseClassDef => + compileCaseClassDef(ccd) } for (m <- u.modules) compileModule(m) @@ -532,3 +546,5 @@ class CompilationUnit(val ctx: LeonContext, private [codegen] object exprCounter extends UniqueCounter[Unit] private [codegen] object lambdaCounter extends UniqueCounter[Unit] +private [codegen] object forallCounter extends UniqueCounter[Unit] + diff --git a/src/main/scala/leon/codegen/CompiledExpression.scala b/src/main/scala/leon/codegen/CompiledExpression.scala index a9d1eda0c5e36e19a6b6c12f99a617b480866f10..f9fca911564c61ad984fa97c3f2ac0da7fc021b4 100644 --- a/src/main/scala/leon/codegen/CompiledExpression.scala +++ b/src/main/scala/leon/codegen/CompiledExpression.scala @@ -8,7 +8,7 @@ import purescala.Expressions._ import cafebabe._ -import runtime.{LeonCodeGenRuntimeMonitor => LM, LeonCodeGenRuntimeHenkinMonitor => LHM} +import runtime.{LeonCodeGenRuntimeMonitor => LM} import java.lang.reflect.InvocationTargetException @@ -51,9 +51,9 @@ class CompiledExpression(unit: CompilationUnit, cf: ClassFile, expression: Expr, } } - def eval(model: solvers.Model) : Expr = { + def eval(model: solvers.Model, check: Boolean = false) : Expr = { try { - val monitor = unit.modelToJVM(model, params.maxFunctionInvocations) + val monitor = unit.modelToJVM(model, params.maxFunctionInvocations, check) evalFromJVM(argsToJVM(argsDecl.map(model), monitor), monitor) } catch { case ite : InvocationTargetException => throw ite.getCause diff --git a/src/main/scala/leon/datagen/VanuatooDataGen.scala b/src/main/scala/leon/datagen/VanuatooDataGen.scala index e68e21f011f3936797157688901018923c2127bf..7aa99973c1a0ac2e510a5f2f512d501703081ac9 100644 --- a/src/main/scala/leon/datagen/VanuatooDataGen.scala +++ b/src/main/scala/leon/datagen/VanuatooDataGen.scala @@ -33,12 +33,24 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { b -> Constructor[Expr, TypeTree](List(), BooleanType, s => BooleanLiteral(b), ""+b) }).toMap + val chars = (for (c <- Set('a', 'b', 'c', 'd')) yield { + c -> Constructor[Expr, TypeTree](List(), CharType, s => CharLiteral(c), ""+c) + }).toMap + + val rationals = (for (n <- Set(0, 1, 2, 3); d <- Set(1,2,3,4)) yield { + (n, d) -> Constructor[Expr, TypeTree](List(), RealType, s => FractionalLiteral(n, d), "" + n + "/" + d) + }).toMap + def intConstructor(i: Int) = ints(i) def bigIntConstructor(i: Int) = bigInts(i) def boolConstructor(b: Boolean) = booleans(b) + def charConstructor(c: Char) = chars(c) + + def rationalConstructor(n: Int, d: Int) = rationals(n -> d) + def cPattern(c: Constructor[Expr, TypeTree], args: Seq[VPattern[Expr, TypeTree]]) = { ConstructorPattern[Expr, TypeTree](c, args) } @@ -50,7 +62,6 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { getConstructors(t).head.copy(retType = act) } - private def getConstructors(t: TypeTree): List[Constructor[Expr, TypeTree]] = t match { case UnitType => constructors.getOrElse(t, { @@ -97,8 +108,7 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { case mt @ MapType(from, to) => constructors.getOrElse(mt, { val cs = for (size <- List(0, 1, 2, 5)) yield { - val subs = (1 to size).flatMap(i => List(from, to)).toList - + val subs = (1 to size).flatMap(i => List(from, to)).toList Constructor[Expr, TypeTree](subs, mt, s => FiniteMap(s.grouped(2).map(t => (t(0), t(1))).toMap, from, to), mt.asString(ctx)+"@"+size) } constructors += mt -> cs @@ -110,13 +120,9 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { val cs = for (size <- List(1, 2, 3, 5)) yield { val subs = (1 to size).flatMap(_ => from :+ to).toList Constructor[Expr, TypeTree](subs, ft, { s => - val args = from.map(tpe => FreshIdentifier("x", tpe, true)) - val argsTuple = tupleWrap(args.map(_.toVariable)) val grouped = s.grouped(from.size + 1).toSeq - val body = grouped.init.foldRight(grouped.last.last) { case (t, elze) => - IfExpr(Equals(argsTuple, tupleWrap(t.init)), t.last, elze) - } - Lambda(args.map(id => ValDef(id)), body) + val mapping = grouped.init.map { case args :+ res => (args -> res) } + PartialLambda(mapping, Some(grouped.last.last), ft) }, ft.asString(ctx) + "@" + size) } constructors += ft -> cs @@ -166,6 +172,9 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { case (b: java.lang.Boolean, BooleanType) => (cPattern(boolConstructor(b), List()), true) + case (c: java.lang.Character, CharType) => + (cPattern(charConstructor(c), List()), true) + case (cc: codegen.runtime.CaseClass, ct: ClassType) => val r = cc.__getRead() @@ -193,7 +202,7 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { (ConstructorPattern(c, elems.map(_._1)), elems.forall(_._2)) case _ => - ctx.reporter.error("Could not retreive type for :"+cc.getClass.getName) + ctx.reporter.error("Could not retrieve type for :"+cc.getClass.getName) (AnyPattern[Expr, TypeTree](), false) } @@ -217,12 +226,13 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { case (gv: GenericValue, t: TypeParameter) => (cPattern(getConstructors(t)(gv.id-1), List()), true) + case (v, t) => ctx.reporter.debug("Unsupported value, can't paternify : "+v+" ("+v.getClass+") : "+t) (AnyPattern[Expr, TypeTree](), false) } - type InstrumentedResult = (EvaluationResults.Result, Option[vanuatoo.Pattern[Expr, TypeTree]]) + type InstrumentedResult = (EvaluationResults.Result[Expr], Option[vanuatoo.Pattern[Expr, TypeTree]]) def compile(expression: Expr, argorder: Seq[Identifier]) : Option[Expr=>InstrumentedResult] = { import leon.codegen.runtime.LeonCodeGenRuntimeException @@ -287,8 +297,8 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { None }) - - val gen = new StubGenerator[Expr, TypeTree]((ints.values ++ bigInts.values ++ booleans.values).toSeq, + val stubValues = ints.values ++ bigInts.values ++ booleans.values ++ chars.values ++ rationals.values + val gen = new StubGenerator[Expr, TypeTree](stubValues.toSeq, Some(getConstructors _), treatEmptyStubsAsChildless = true) diff --git a/src/main/scala/leon/evaluators/AngelicEvaluator.scala b/src/main/scala/leon/evaluators/AngelicEvaluator.scala new file mode 100644 index 0000000000000000000000000000000000000000..99d704f67c7485ba8e00727c4edcc5d30644b4cc --- /dev/null +++ b/src/main/scala/leon/evaluators/AngelicEvaluator.scala @@ -0,0 +1,48 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package evaluators + +import leon.solvers.Model +import purescala.Expressions.Expr +import EvaluationResults._ + +class AngelicEvaluator(underlying: NDEvaluator) + extends Evaluator(underlying.context, underlying.program) + with DeterministicEvaluator +{ + val description: String = "angelic evaluator" + val name: String = "Interpreter that returns the first solution of a non-deterministic program" + + def eval(expr: Expr, model: Model): EvaluationResult = underlying.eval(expr, model) match { + case Successful(Stream(h, _*)) => + Successful(h) + case Successful(Stream()) => + RuntimeError("Underlying ND-evaluator returned no solution") + case other@(RuntimeError(_) | EvaluatorError(_)) => + other.asInstanceOf[Result[Nothing]] + } + + /** Checks that `model |= expr` and that quantifications are all valid */ + def check(expr: Expr, model: Model): CheckResult = underlying.check(expr, model) +} + +class DemonicEvaluator(underlying: NDEvaluator) + extends Evaluator(underlying.context, underlying.program) + with DeterministicEvaluator +{ + val description: String = "demonic evaluator" + val name: String = "Interpreter that demands an underlying non-deterministic program has unique solution" + + def eval(expr: Expr, model: Model): EvaluationResult = underlying.eval(expr, model) match { + case Successful(Stream(h)) => + Successful(h) + case Successful(_) => + RuntimeError("Underlying ND-evaluator did not return unique solution!") + case other@(RuntimeError(_) | EvaluatorError(_)) => + other.asInstanceOf[Result[Nothing]] + } + + /** Checks that `model |= expr` and that quantifications are all valid */ + def check(expr: Expr, model: Model): CheckResult = underlying.check(expr, model) +} \ No newline at end of file diff --git a/src/main/scala/leon/evaluators/CodeGenEvaluator.scala b/src/main/scala/leon/evaluators/CodeGenEvaluator.scala index 36cd9da0c7c35cfdc9d674162c3279d461afe813..d92d168bf2361e0ca054a7e88c276dd7cc34e839 100644 --- a/src/main/scala/leon/evaluators/CodeGenEvaluator.scala +++ b/src/main/scala/leon/evaluators/CodeGenEvaluator.scala @@ -6,12 +6,17 @@ package evaluators import purescala.Common._ import purescala.Definitions._ import purescala.Expressions._ -import purescala.Quantification._ import codegen.CompilationUnit +import codegen.CompiledExpression import codegen.CodeGenParams -class CodeGenEvaluator(ctx: LeonContext, val unit : CompilationUnit) extends Evaluator(ctx, unit.program) { +import leon.codegen.runtime.LeonCodeGenRuntimeException +import leon.codegen.runtime.LeonCodeGenEvaluationException +import leon.codegen.runtime.LeonCodeGenQuantificationException + +class CodeGenEvaluator(ctx: LeonContext, val unit : CompilationUnit) extends Evaluator(ctx, unit.program) with DeterministicEvaluator { + val name = "codegen-eval" val description = "Evaluator for PureScala expressions based on compilation to JVM" @@ -20,9 +25,55 @@ class CodeGenEvaluator(ctx: LeonContext, val unit : CompilationUnit) extends Eva this(ctx, new CompilationUnit(ctx, prog, params)) } + private def compileExpr(expression: Expr, args: Seq[Identifier]): Option[CompiledExpression] = { + ctx.timers.evaluators.codegen.compilation.start() + try { + Some(unit.compileExpression(expression, args)(ctx)) + } catch { + case t: Throwable => + ctx.reporter.warning(expression.getPos, "Error while compiling expression: "+t.getMessage) + None + } finally { + ctx.timers.evaluators.codegen.compilation.stop() + } + } + + def check(expression: Expr, model: solvers.Model) : CheckResult = { + compileExpr(expression, model.toSeq.map(_._1)).map { ce => + ctx.timers.evaluators.codegen.runtime.start() + try { + val res = ce.eval(model, check = true) + if (res == BooleanLiteral(true)) EvaluationResults.CheckSuccess + else EvaluationResults.CheckValidityFailure + } catch { + case e : ArithmeticException => + EvaluationResults.CheckRuntimeFailure(e.getMessage) + + case e : ArrayIndexOutOfBoundsException => + EvaluationResults.CheckRuntimeFailure(e.getMessage) + + case e : LeonCodeGenRuntimeException => + EvaluationResults.CheckRuntimeFailure(e.getMessage) + + case e : LeonCodeGenEvaluationException => + EvaluationResults.CheckRuntimeFailure(e.getMessage) + + case e : java.lang.ExceptionInInitializerError => + EvaluationResults.CheckRuntimeFailure(e.getException.getMessage) + + case so : java.lang.StackOverflowError => + EvaluationResults.CheckRuntimeFailure("Stack overflow") + + case e : LeonCodeGenQuantificationException => + EvaluationResults.CheckQuantificationFailure(e.getMessage) + } finally { + ctx.timers.evaluators.codegen.runtime.stop() + } + }.getOrElse(EvaluationResults.CheckRuntimeFailure("Couldn't compile expression.")) + } + def eval(expression: Expr, model: solvers.Model) : EvaluationResult = { - val toPairs = model.toSeq - compile(expression, toPairs.map(_._1)).map { e => + compile(expression, model.toSeq.map(_._1)).map { e => ctx.timers.evaluators.codegen.runtime.start() val res = e(model) ctx.timers.evaluators.codegen.runtime.stop() @@ -31,45 +82,30 @@ class CodeGenEvaluator(ctx: LeonContext, val unit : CompilationUnit) extends Eva } override def compile(expression: Expr, args: Seq[Identifier]) : Option[solvers.Model=>EvaluationResult] = { - import leon.codegen.runtime.LeonCodeGenRuntimeException - import leon.codegen.runtime.LeonCodeGenEvaluationException - - ctx.timers.evaluators.codegen.compilation.start() - try { - val ce = unit.compileExpression(expression, args)(ctx) - - Some((model: solvers.Model) => { - if (args.exists(arg => !model.isDefinedAt(arg))) { - EvaluationResults.EvaluatorError("Model undefined for free arguments") - } else try { - EvaluationResults.Successful(ce.eval(model)) - } catch { - case e : ArithmeticException => - EvaluationResults.RuntimeError(e.getMessage) + compileExpr(expression, args).map(ce => (model: solvers.Model) => { + if (args.exists(arg => !model.isDefinedAt(arg))) { + EvaluationResults.EvaluatorError("Model undefined for free arguments") + } else try { + EvaluationResults.Successful(ce.eval(model)) + } catch { + case e : ArithmeticException => + EvaluationResults.RuntimeError(e.getMessage) - case e : ArrayIndexOutOfBoundsException => - EvaluationResults.RuntimeError(e.getMessage) + case e : ArrayIndexOutOfBoundsException => + EvaluationResults.RuntimeError(e.getMessage) - case e : LeonCodeGenRuntimeException => - EvaluationResults.RuntimeError(e.getMessage) + case e : LeonCodeGenRuntimeException => + EvaluationResults.RuntimeError(e.getMessage) - case e : LeonCodeGenEvaluationException => - EvaluationResults.EvaluatorError(e.getMessage) + case e : LeonCodeGenEvaluationException => + EvaluationResults.EvaluatorError(e.getMessage) - case e : java.lang.ExceptionInInitializerError => - EvaluationResults.RuntimeError(e.getException.getMessage) + case e : java.lang.ExceptionInInitializerError => + EvaluationResults.RuntimeError(e.getException.getMessage) - case so : java.lang.StackOverflowError => - EvaluationResults.RuntimeError("Stack overflow") - - } - }) - } catch { - case t: Throwable => - ctx.reporter.warning(expression.getPos, "Error while compiling expression: "+t.getMessage) - None - } finally { - ctx.timers.evaluators.codegen.compilation.stop() - } + case so : java.lang.StackOverflowError => + EvaluationResults.RuntimeError("Stack overflow") + } + }) } } diff --git a/src/main/scala/leon/evaluators/ContextualEvaluator.scala b/src/main/scala/leon/evaluators/ContextualEvaluator.scala new file mode 100644 index 0000000000000000000000000000000000000000..0fc33102a04716816fc3b2a83faa1384b37da1fd --- /dev/null +++ b/src/main/scala/leon/evaluators/ContextualEvaluator.scala @@ -0,0 +1,139 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package evaluators + +import leon.purescala.Extractors.{IsTyped, TopLevelAnds} +import purescala.Common._ +import purescala.Definitions._ +import purescala.Expressions._ +import purescala.Types._ +import solvers.{HenkinModel, Model} + +abstract class ContextualEvaluator(ctx: LeonContext, prog: Program, val maxSteps: Int) extends Evaluator(ctx, prog) with CEvalHelpers { + + protected implicit val _ = ctx + + type RC <: RecContext[RC] + type GC <: GlobalContext + + def initRC(mappings: Map[Identifier, Expr]): RC + def initGC(model: solvers.Model, check: Boolean): GC + + case class EvalError(msg : String) extends Exception + case class RuntimeError(msg : String) extends Exception + case class QuantificationError(msg: String) extends Exception + + // Used by leon-web, please do not delete + var lastGC: Option[GC] = None + + def eval(ex: Expr, model: Model) = { + try { + lastGC = Some(initGC(model, check = true)) + ctx.timers.evaluators.recursive.runtime.start() + EvaluationResults.Successful(e(ex)(initRC(model.toMap), lastGC.get)) + } catch { + case so: StackOverflowError => + EvaluationResults.EvaluatorError("Stack overflow") + case EvalError(msg) => + EvaluationResults.EvaluatorError(msg) + case e @ RuntimeError(msg) => + EvaluationResults.RuntimeError(msg) + case jre: java.lang.RuntimeException => + EvaluationResults.RuntimeError(jre.getMessage) + } finally { + ctx.timers.evaluators.recursive.runtime.stop() + } + } + + def check(ex: Expr, model: Model): CheckResult = { + assert(ex.getType == BooleanType, "Can't check non-boolean expression " + ex.asString) + try { + lastGC = Some(initGC(model, check = true)) + ctx.timers.evaluators.recursive.runtime.start() + val res = e(ex)(initRC(model.toMap), lastGC.get) + if (res == BooleanLiteral(true)) EvaluationResults.CheckSuccess + else EvaluationResults.CheckValidityFailure + } catch { + case so: StackOverflowError => + EvaluationResults.CheckRuntimeFailure("Stack overflow") + case e @ EvalError(msg) => + EvaluationResults.CheckRuntimeFailure(msg) + case e @ RuntimeError(msg) => + EvaluationResults.CheckRuntimeFailure(msg) + case jre: java.lang.RuntimeException => + EvaluationResults.CheckRuntimeFailure(jre.getMessage) + case qe @ QuantificationError(msg) => + EvaluationResults.CheckQuantificationFailure(msg) + } finally { + ctx.timers.evaluators.recursive.runtime.stop() + } + } + + protected def e(expr: Expr)(implicit rctx: RC, gctx: GC): Value + + def typeErrorMsg(tree : Expr, expected : TypeTree) : String = s"Type error : expected ${expected.asString}, found ${tree.asString}." + +} + +private[evaluators] trait CEvalHelpers { + this: ContextualEvaluator => + + /* This is an effort to generalize forall to non-det. solvers + def forallInstantiations(gctx:GC, fargs: Seq[ValDef], conj: Expr) = { + + val henkinModel: HenkinModel = gctx.model match { + case hm: HenkinModel => hm + case _ => throw EvalError("Can't evaluate foralls without henkin model") + } + + val vars = variablesOf(conj) + val args = fargs.map(_.id).filter(vars) + val quantified = args.toSet + + val matcherQuorums = extractQuorums(conj, quantified) + + matcherQuorums.flatMap { quorum => + var mappings: Seq[(Identifier, Int, Int)] = Seq.empty + var constraints: Seq[(Expr, Int, Int)] = Seq.empty + + for (((expr, args), qidx) <- quorum.zipWithIndex) { + val (qmappings, qconstraints) = args.zipWithIndex.partition { + case (Variable(id), aidx) => quantified(id) + case _ => false + } + + mappings ++= qmappings.map(p => (p._1.asInstanceOf[Variable].id, qidx, p._2)) + constraints ++= qconstraints.map(p => (p._1, qidx, p._2)) + } + + var equalities: Seq[((Int, Int), (Int, Int))] = Seq.empty + val mapping = for ((id, es) <- mappings.groupBy(_._1)) yield { + val base :: others = es.toList.map(p => (p._2, p._3)) + equalities ++= others.map(p => base -> p) + (id -> base) + } + + val argSets = quorum.foldLeft[List[Seq[Seq[Expr]]]](List(Seq.empty)) { + case (acc, (expr, _)) => acc.flatMap(s => henkinModel.domain(expr).map(d => s :+ d)) + } + + argSets.map { args => + val argMap: Map[(Int, Int), Expr] = args.zipWithIndex.flatMap { + case (a, qidx) => a.zipWithIndex.map { case (e, aidx) => (qidx, aidx) -> e } + }.toMap + + val map = mapping.map { case (id, key) => id -> argMap(key) } + val enabler = andJoin(constraints.map { + case (e, qidx, aidx) => Equals(e, argMap(qidx -> aidx)) + } ++ equalities.map { + case (k1, k2) => Equals(argMap(k1), argMap(k2)) + }) + + (enabler, map) + } + }*/ + + + +} \ No newline at end of file diff --git a/src/main/scala/leon/evaluators/DefaultEvaluator.scala b/src/main/scala/leon/evaluators/DefaultEvaluator.scala index d732d48c0d40aaacf97cd8125b077c1cef397148..18a9159c3cb0e29f47e0757314f935865f6b10cf 100644 --- a/src/main/scala/leon/evaluators/DefaultEvaluator.scala +++ b/src/main/scala/leon/evaluators/DefaultEvaluator.scala @@ -3,19 +3,9 @@ package leon package evaluators -import purescala.Common._ -import purescala.Expressions._ -import purescala.Definitions._ -import purescala.Quantification._ +import purescala.Definitions.Program -class DefaultEvaluator(ctx: LeonContext, prog: Program) extends RecursiveEvaluator(ctx, prog, 50000) { - type RC = DefaultRecContext - type GC = GlobalContext - - def initRC(mappings: Map[Identifier, Expr]) = DefaultRecContext(mappings) - def initGC(model: solvers.Model) = new GlobalContext(model) - - case class DefaultRecContext(mappings: Map[Identifier, Expr]) extends RecContext { - def newVars(news: Map[Identifier, Expr]) = copy(news) - } -} +class DefaultEvaluator(ctx: LeonContext, prog: Program) + extends RecursiveEvaluator(ctx, prog, 5000) + with HasDefaultGlobalContext + with HasDefaultRecContext diff --git a/src/main/scala/leon/evaluators/DualEvaluator.scala b/src/main/scala/leon/evaluators/DualEvaluator.scala index cd843fbb145e4f9220b2e9fe3d91e30d8ff3c1be..4c405c8b6f216ee9b839101d8bff5574035b05f4 100644 --- a/src/main/scala/leon/evaluators/DualEvaluator.scala +++ b/src/main/scala/leon/evaluators/DualEvaluator.scala @@ -6,27 +6,26 @@ package evaluators import purescala.Common._ import purescala.Expressions._ import purescala.Definitions._ -import purescala.Quantification._ import purescala.Types._ import codegen._ -class DualEvaluator(ctx: LeonContext, prog: Program, params: CodeGenParams) extends RecursiveEvaluator(ctx, prog, params.maxFunctionInvocations) { - type RC = DefaultRecContext - type GC = GlobalContext +class DualEvaluator(ctx: LeonContext, prog: Program, params: CodeGenParams) + extends RecursiveEvaluator(ctx, prog, params.maxFunctionInvocations) + with HasDefaultGlobalContext +{ + type RC = DualRecContext + def initRC(mappings: Map[Identifier, Expr]): RC = DualRecContext(mappings) implicit val debugSection = utils.DebugSectionEvaluation - def initRC(mappings: Map[Identifier, Expr]) = DefaultRecContext(mappings) - def initGC(model: solvers.Model) = new GlobalContext(model) - var monitor = new runtime.LeonCodeGenRuntimeMonitor(params.maxFunctionInvocations) val unit = new CompilationUnit(ctx, prog, params) val isCompiled = prog.definedFunctions.toSet - case class DefaultRecContext(mappings: Map[Identifier, Expr], needJVMRef: Boolean = false) extends RecContext { + case class DualRecContext(mappings: Map[Identifier, Expr], needJVMRef: Boolean = false) extends RecContext[DualRecContext] { def newVars(news: Map[Identifier, Expr]) = copy(news) } diff --git a/src/main/scala/leon/evaluators/EvaluationResults.scala b/src/main/scala/leon/evaluators/EvaluationResults.scala index e37c61e3b73c8261fb9ca2456abd0fd8811b1877..18f7a0c92d448f98c8f6a271d91e021a649e3b9c 100644 --- a/src/main/scala/leon/evaluators/EvaluationResults.scala +++ b/src/main/scala/leon/evaluators/EvaluationResults.scala @@ -3,18 +3,33 @@ package leon package evaluators -import purescala.Expressions.Expr - object EvaluationResults { /** Possible results of expression evaluation. */ - sealed abstract class Result(val result : Option[Expr]) + sealed abstract class Result[+A](val result : Option[A]) /** Represents an evaluation that successfully derived the result `value`. */ - case class Successful(value : Expr) extends Result(Some(value)) + case class Successful[A](value : A) extends Result(Some(value)) /** Represents an evaluation that led to an error (in the program). */ case class RuntimeError(message : String) extends Result(None) /** Represents an evaluation that failed (in the evaluator). */ case class EvaluatorError(message : String) extends Result(None) + + /** Results of checking proposition evaluation. + * Useful for verification of model validity in presence of quantifiers. */ + sealed abstract class CheckResult(val success: Boolean) + + /** Successful proposition evaluation (model |= expr) */ + case object CheckSuccess extends CheckResult(true) + + /** Check failed with `model |= !expr` */ + case object CheckValidityFailure extends CheckResult(false) + + /** Check failed due to evaluation or runtime errors. + * @see [[RuntimeError]] and [[EvaluatorError]] */ + case class CheckRuntimeFailure(msg: String) extends CheckResult(false) + + /** Check failed due to inconsistence of model with quantified propositions. */ + case class CheckQuantificationFailure(msg: String) extends CheckResult(false) } diff --git a/src/main/scala/leon/evaluators/Evaluator.scala b/src/main/scala/leon/evaluators/Evaluator.scala index 9d14bd3dfcc0eb6f08f5273b98ac1448038b363e..ff0f35f1241547d66f81f0b341fca508276b40ea 100644 --- a/src/main/scala/leon/evaluators/Evaluator.scala +++ b/src/main/scala/leon/evaluators/Evaluator.scala @@ -6,14 +6,19 @@ package evaluators import purescala.Common._ import purescala.Definitions._ import purescala.Expressions._ -import purescala.Quantification._ -import purescala.ExprOps._ import solvers.Model abstract class Evaluator(val context: LeonContext, val program: Program) extends LeonComponent { - type EvaluationResult = EvaluationResults.Result + /** The type of value that this [[Evaluator]] calculates + * Typically, it will be [[Expr]] for deterministic evaluators, and + * [[Stream[Expr]]] for non-deterministic ones. + */ + type Value + + type EvaluationResult = EvaluationResults.Result[Value] + type CheckResult = EvaluationResults.CheckResult /** Evaluates an expression, using [[Model.mapping]] as a valuation function for the free variables. */ def eval(expr: Expr, model: Model) : EvaluationResult @@ -26,10 +31,14 @@ abstract class Evaluator(val context: LeonContext, val program: Program) extends /** Evaluates a ground expression. */ final def eval(expr: Expr) : EvaluationResult = eval(expr, Model.empty) + /** Checks that `model |= expr` and that quantifications are all valid */ + def check(expr: Expr, model: Model) : CheckResult + /** Compiles an expression into a function, where the arguments are the free variables in the expression. * `argorder` specifies in which order the arguments should be passed. * The default implementation uses the evaluation function each time, but evaluators are free - * to (and encouraged to) apply any specialization. */ + * to (and encouraged to) apply any specialization. + */ def compile(expr: Expr, args: Seq[Identifier]) : Option[Model => EvaluationResult] = Some( (model: Model) => if(args.exists(arg => !model.isDefinedAt(arg))) { EvaluationResults.EvaluatorError("Wrong number of arguments for evaluation.") @@ -39,3 +48,10 @@ abstract class Evaluator(val context: LeonContext, val program: Program) extends ) } +trait DeterministicEvaluator extends Evaluator { + type Value = Expr +} + +trait NDEvaluator extends Evaluator { + type Value = Stream[Expr] +} diff --git a/src/main/scala/leon/evaluators/EvaluatorContexts.scala b/src/main/scala/leon/evaluators/EvaluatorContexts.scala new file mode 100644 index 0000000000000000000000000000000000000000..a63ee6483bfcdb7804cd95bfcb32df83a66e235b --- /dev/null +++ b/src/main/scala/leon/evaluators/EvaluatorContexts.scala @@ -0,0 +1,44 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package evaluators + +import purescala.Common.Identifier +import leon.purescala.Expressions.{Lambda, Expr} +import solvers.Model + +import scala.collection.mutable.{Map => MutableMap} + +trait RecContext[RC <: RecContext[RC]] { + def mappings: Map[Identifier, Expr] + + def newVars(news: Map[Identifier, Expr]): RC + + def withNewVar(id: Identifier, v: Expr): RC = { + newVars(mappings + (id -> v)) + } + + def withNewVars(news: Map[Identifier, Expr]): RC = { + newVars(mappings ++ news) + } +} + +case class DefaultRecContext(mappings: Map[Identifier, Expr]) extends RecContext[DefaultRecContext] { + def newVars(news: Map[Identifier, Expr]) = copy(news) +} + +class GlobalContext(val model: Model, val maxSteps: Int, val check: Boolean) { + var stepsLeft = maxSteps + + val lambdas: MutableMap[Lambda, Lambda] = MutableMap.empty +} + +trait HasDefaultRecContext extends ContextualEvaluator { + type RC = DefaultRecContext + def initRC(mappings: Map[Identifier, Expr]) = DefaultRecContext(mappings) +} + +trait HasDefaultGlobalContext extends ContextualEvaluator { + def initGC(model: solvers.Model, check: Boolean) = new GlobalContext(model, this.maxSteps, check) + type GC = GlobalContext +} \ No newline at end of file diff --git a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala index 80f71fd87be7f27e1793e13fe38eb09b78f78d62..c3f9eed0f0d55b0f75628d491a998facd29aa346 100644 --- a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala +++ b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala @@ -3,78 +3,30 @@ package leon package evaluators -import purescala.Common._ -import purescala.Definitions._ -import purescala.ExprOps._ -import purescala.Expressions._ -import purescala.Types._ -import purescala.TypeOps.isSubtypeOf +import leon.purescala.Quantification._ import purescala.Constructors._ +import purescala.ExprOps._ +import purescala.Expressions.Pattern import purescala.Extractors._ -import purescala.Quantification._ -import solvers.{Model, HenkinModel} -import solvers.SolverFactory - -abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int) extends Evaluator(ctx, prog) { - val name = "evaluator" - val description = "Recursive interpreter for PureScala expressions" - - private implicit val _ = ctx - - type RC <: RecContext - type GC <: GlobalContext - - case class EvalError(msg : String) extends Exception - case class RuntimeError(msg : String) extends Exception - - val scalaEv = new ScalacEvaluator(this, ctx, prog) - - trait RecContext { - def mappings: Map[Identifier, Expr] - - def newVars(news: Map[Identifier, Expr]): RC +import purescala.TypeOps._ +import purescala.Types._ +import purescala.Common._ +import purescala.Expressions._ +import purescala.Definitions._ +import leon.solvers.{HenkinModel, Model, SolverFactory} - def withNewVar(id: Identifier, v: Expr): RC = { - newVars(mappings + (id -> v)) - } +import scala.collection.mutable.{Map => MutableMap} - def withNewVars(news: Map[Identifier, Expr]): RC = { - newVars(mappings ++ news) - } - } +abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int) + extends ContextualEvaluator(ctx, prog, maxSteps) + with DeterministicEvaluator { - class GlobalContext(val model: Model) { - def maxSteps = RecursiveEvaluator.this.maxSteps + val name = "evaluator" + val description = "Recursive interpreter for PureScala expressions" - var stepsLeft = maxSteps - } + lazy val scalaEv = new ScalacEvaluator(this, ctx, prog) - def initRC(mappings: Map[Identifier, Expr]): RC - def initGC(model: Model): GC - - // Used by leon-web, please do not delete - var lastGC: Option[GC] = None - - private[this] var clpCache = Map[(Choose, Seq[Expr]), Expr]() - - def eval(ex: Expr, model: Model) = { - try { - lastGC = Some(initGC(model)) - ctx.timers.evaluators.recursive.runtime.start() - EvaluationResults.Successful(e(ex)(initRC(model.toMap), lastGC.get)) - } catch { - case so: StackOverflowError => - EvaluationResults.EvaluatorError("Stack overflow") - case EvalError(msg) => - EvaluationResults.EvaluatorError(msg) - case e @ RuntimeError(msg) => - EvaluationResults.RuntimeError(msg) - case jre: java.lang.RuntimeException => - EvaluationResults.RuntimeError(jre.getMessage) - } finally { - ctx.timers.evaluators.recursive.runtime.stop() - } - } + protected var clpCache = Map[(Choose, Seq[Expr]), Expr]() protected def e(expr: Expr)(implicit rctx: RC, gctx: GC): Expr = expr match { case Variable(id) => @@ -93,11 +45,12 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int val newArgs = args.map(e) val mapping = l.paramSubst(newArgs) e(body)(rctx.withNewVars(mapping), gctx) - case PartialLambda(mapping, _) => + case PartialLambda(mapping, dflt, _) => mapping.find { case (pargs, res) => (args zip pargs).forall(p => e(Equals(p._1, p._2)) == BooleanLiteral(true)) - }.map(_._2).getOrElse { - throw EvalError("Cannot apply partial lambda outside of domain") + }.map(_._2).orElse(dflt).getOrElse { + throw EvalError("Cannot apply partial lambda outside of domain : " + + args.map(e(_).asString(ctx)).mkString("(", ", ", ")")) } case f => throw EvalError("Cannot apply non-lambda function " + f.asString) @@ -195,9 +148,7 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int callResult - case And(args) if args.isEmpty => - BooleanLiteral(true) - + case And(args) if args.isEmpty => BooleanLiteral(true) case And(args) => e(args.head) match { case BooleanLiteral(false) => BooleanLiteral(false) @@ -220,9 +171,10 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int } case Implies(l,r) => - (e(l), e(r)) match { - case (BooleanLiteral(b1),BooleanLiteral(b2)) => BooleanLiteral(!b1 || b2) - case (le, re) => throw EvalError(typeErrorMsg(le, BooleanType)) + e(l) match { + case BooleanLiteral(false) => BooleanLiteral(true) + case BooleanLiteral(true) => e(r) + case le => throw EvalError(typeErrorMsg(le, BooleanType)) } case Equals(le,re) => @@ -232,7 +184,7 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int (lv,rv) match { case (FiniteSet(el1, _),FiniteSet(el2, _)) => BooleanLiteral(el1 == el2) case (FiniteMap(el1, _, _),FiniteMap(el2, _, _)) => BooleanLiteral(el1.toSet == el2.toSet) - case (PartialLambda(m1, _), PartialLambda(m2, _)) => BooleanLiteral(m1.toSet == m2.toSet) + case (PartialLambda(m1, d1, _), PartialLambda(m2, d2, _)) => BooleanLiteral(m1.toSet == m2.toSet && d1 == d2) case _ => BooleanLiteral(lv == rv) } @@ -273,7 +225,7 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int case RealPlus(l, r) => (e(l), e(r)) match { case (FractionalLiteral(ln, ld), FractionalLiteral(rn, rd)) => - normalizeFraction(FractionalLiteral((ln * rd + rn * ld), (ld * rd))) + normalizeFraction(FractionalLiteral(ln * rd + rn * ld, ld * rd)) case (le, re) => throw EvalError(typeErrorMsg(le, RealType)) } @@ -379,7 +331,7 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int (e(l), e(r)) match { case (FractionalLiteral(ln, ld), FractionalLiteral(rn, rd)) => if (rn != 0) - normalizeFraction(FractionalLiteral((ln * rd), (ld * rn))) + normalizeFraction(FractionalLiteral(ln * rd, ld * rn)) else throw RuntimeError("Division by 0.") case (le,re) => throw EvalError(typeErrorMsg(le, RealType)) } @@ -426,8 +378,8 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int case (IntLiteral(i1), IntLiteral(i2)) => BooleanLiteral(i1 < i2) case (InfiniteIntegerLiteral(i1), InfiniteIntegerLiteral(i2)) => BooleanLiteral(i1 < i2) case (a @ FractionalLiteral(_, _), b @ FractionalLiteral(_, _)) => - val FractionalLiteral(n, _) = e(RealMinus(a, b)) - BooleanLiteral(n < 0) + val FractionalLiteral(n, _) = e(RealMinus(a, b)) + BooleanLiteral(n < 0) case (CharLiteral(c1), CharLiteral(c2)) => BooleanLiteral(c1 < c2) case (le,re) => throw EvalError(typeErrorMsg(le, Int32Type)) } @@ -437,8 +389,8 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int case (IntLiteral(i1), IntLiteral(i2)) => BooleanLiteral(i1 > i2) case (InfiniteIntegerLiteral(i1), InfiniteIntegerLiteral(i2)) => BooleanLiteral(i1 > i2) case (a @ FractionalLiteral(_, _), b @ FractionalLiteral(_, _)) => - val FractionalLiteral(n, _) = e(RealMinus(a, b)) - BooleanLiteral(n > 0) + val FractionalLiteral(n, _) = e(RealMinus(a, b)) + BooleanLiteral(n > 0) case (CharLiteral(c1), CharLiteral(c2)) => BooleanLiteral(c1 > c2) case (le,re) => throw EvalError(typeErrorMsg(le, Int32Type)) } @@ -448,8 +400,8 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int case (IntLiteral(i1), IntLiteral(i2)) => BooleanLiteral(i1 <= i2) case (InfiniteIntegerLiteral(i1), InfiniteIntegerLiteral(i2)) => BooleanLiteral(i1 <= i2) case (a @ FractionalLiteral(_, _), b @ FractionalLiteral(_, _)) => - val FractionalLiteral(n, _) = e(RealMinus(a, b)) - BooleanLiteral(n <= 0) + val FractionalLiteral(n, _) = e(RealMinus(a, b)) + BooleanLiteral(n <= 0) case (CharLiteral(c1), CharLiteral(c2)) => BooleanLiteral(c1 <= c2) case (le,re) => throw EvalError(typeErrorMsg(le, Int32Type)) } @@ -459,8 +411,8 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int case (IntLiteral(i1), IntLiteral(i2)) => BooleanLiteral(i1 >= i2) case (InfiniteIntegerLiteral(i1), InfiniteIntegerLiteral(i2)) => BooleanLiteral(i1 >= i2) case (a @ FractionalLiteral(_, _), b @ FractionalLiteral(_, _)) => - val FractionalLiteral(n, _) = e(RealMinus(a, b)) - BooleanLiteral(n >= 0) + val FractionalLiteral(n, _) = e(RealMinus(a, b)) + BooleanLiteral(n >= 0) case (CharLiteral(c1), CharLiteral(c2)) => BooleanLiteral(c1 >= c2) case (le,re) => throw EvalError(typeErrorMsg(le, Int32Type)) } @@ -475,21 +427,19 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int case SetIntersection(s1,s2) => (e(s1), e(s2)) match { - case (f @ FiniteSet(els1, _), FiniteSet(els2, _)) => { + case (f @ FiniteSet(els1, _), FiniteSet(els2, _)) => val newElems = els1 intersect els2 val SetType(tpe) = f.getType FiniteSet(newElems, tpe) - } case (le,re) => throw EvalError(typeErrorMsg(le, s1.getType)) } case SetDifference(s1,s2) => (e(s1), e(s2)) match { - case (f @ FiniteSet(els1, _),FiniteSet(els2, _)) => { + case (f @ FiniteSet(els1, _),FiniteSet(els2, _)) => val SetType(tpe) = f.getType val newElems = els1 -- els2 FiniteSet(newElems, tpe) - } case (le,re) => throw EvalError(typeErrorMsg(le, s1.getType)) } @@ -504,7 +454,7 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int case SetCardinality(s) => val sr = e(s) sr match { - case FiniteSet(els, _) => IntLiteral(els.size) + case FiniteSet(els, _) => InfiniteIntegerLiteral(els.size) case _ => throw EvalError(typeErrorMsg(sr, SetType(Untyped))) } @@ -512,71 +462,19 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int FiniteSet(els.map(e), base) case l @ Lambda(_, _) => - val (nl, structSubst) = normalizeStructure(l) + val (nl, structSubst) = normalizeStructure(matchToIfThenElse(l)) val mapping = variablesOf(l).map(id => structSubst(id) -> e(Variable(id))).toMap - replaceFromIDs(mapping, nl) - - case PartialLambda(mapping, tpe) => - PartialLambda(mapping.map(p => p._1.map(e) -> e(p._2)), tpe) - - case f @ Forall(fargs, TopLevelAnds(conjuncts)) => - val henkinModel: HenkinModel = gctx.model match { - case hm: HenkinModel => hm - case _ => throw EvalError("Can't evaluate foralls without henkin model") + val newLambda = replaceFromIDs(mapping, nl).asInstanceOf[Lambda] + if (!gctx.lambdas.isDefinedAt(newLambda)) { + gctx.lambdas += (newLambda -> nl.asInstanceOf[Lambda]) } + newLambda - e(andJoin(for (conj <- conjuncts) yield { - val vars = variablesOf(conj) - val args = fargs.map(_.id).filter(vars) - val quantified = args.toSet - - val matcherQuorums = extractQuorums(conj, quantified) - - val instantiations = matcherQuorums.flatMap { quorum => - var mappings: Seq[(Identifier, Int, Int)] = Seq.empty - var constraints: Seq[(Expr, Int, Int)] = Seq.empty - - for (((expr, args), qidx) <- quorum.zipWithIndex) { - val (qmappings, qconstraints) = args.zipWithIndex.partition { - case (Variable(id),aidx) => quantified(id) - case _ => false - } - - mappings ++= qmappings.map(p => (p._1.asInstanceOf[Variable].id, qidx, p._2)) - constraints ++= qconstraints.map(p => (p._1, qidx, p._2)) - } - - var equalities: Seq[((Int, Int), (Int, Int))] = Seq.empty - val mapping = for ((id, es) <- mappings.groupBy(_._1)) yield { - val base :: others = es.toList.map(p => (p._2, p._3)) - equalities ++= others.map(p => base -> p) - (id -> base) - } - - val argSets = quorum.foldLeft[List[Seq[Seq[Expr]]]](List(Seq.empty)) { - case (acc, (expr, _)) => acc.flatMap(s => henkinModel.domain(expr).map(d => s :+ d)) - } - - argSets.map { args => - val argMap: Map[(Int, Int), Expr] = args.zipWithIndex.flatMap { - case (a, qidx) => a.zipWithIndex.map { case (e, aidx) => (qidx, aidx) -> e } - }.toMap - - val map = mapping.map { case (id, key) => id -> argMap(key) } - val enabler = andJoin(constraints.map { - case (e, qidx, aidx) => Equals(e, argMap(qidx -> aidx)) - } ++ equalities.map { - case (k1, k2) => Equals(argMap(k1), argMap(k2)) - }) - - (enabler, map) - } - } + case PartialLambda(mapping, dflt, tpe) => + PartialLambda(mapping.map(p => p._1.map(e) -> e(p._2)), dflt.map(e), tpe) - e(andJoin(instantiations.map { case (enabler, mapping) => - e(Implies(enabler, conj))(rctx.withNewVars(mapping), gctx) - })) - })) + case Forall(fargs, body) => + evalForall(fargs.map(_.id).toSet, body) case ArrayLength(a) => val FiniteArray(_, _, IntLiteral(length)) = e(a) @@ -631,9 +529,6 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int case (l, r) => throw EvalError(typeErrorMsg(l, m.getType)) } - case gv: GenericValue => - gv - case p : Passes => e(p.asConstraint) @@ -641,7 +536,7 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int implicit val debugSection = utils.DebugSectionSynthesis - val p = synthesis.Problem.fromChoose(choose) + val p = synthesis.Problem.fromSpec(choose.pred) ctx.reporter.debug("Executing choose!") @@ -702,6 +597,7 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int throw RuntimeError("MatchError: "+rscrut.asString+" did not match any of the cases") } + case gl: GenericValue => gl case fl : FractionalLiteral => normalizeFraction(fl) case l : Literal[_] => l @@ -735,11 +631,11 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int None } case (up @ UnapplyPattern(ob, _, subs), scrut) => - e(FunctionInvocation(up.unapplyFun, Seq(scrut))) match { - case CaseClass(CaseClassType(cd, _), Seq()) if cd == program.library.Nil.get => + e(functionInvocation(up.unapplyFun.fd, Seq(scrut))) match { + case CaseClass(CaseClassType(cd, _), Seq()) if cd == program.library.None.get => None - case CaseClass(CaseClassType(cd, _), Seq(arg)) if cd == program.library.Cons.get => - val res = subs zip unwrapTuple(arg, up.unapplyFun.returnType.isInstanceOf[TupleType]) map { + case CaseClass(CaseClassType(cd, _), Seq(arg)) if cd == program.library.Some.get => + val res = subs zip unwrapTuple(arg, subs.size) map { case (s,a) => matchesPattern(s,a) } if (res.forall(_.isDefined)) { @@ -784,6 +680,141 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int } } - def typeErrorMsg(tree : Expr, expected : TypeTree) : String = s"Type error : expected ${expected.asString}, found ${tree.asString}." + + protected def evalForall(quants: Set[Identifier], body: Expr, check: Boolean = true)(implicit rctx: RC, gctx: GC): Expr = { + val henkinModel: HenkinModel = gctx.model match { + case hm: HenkinModel => hm + case _ => throw EvalError("Can't evaluate foralls without henkin model") + } + + val TopLevelAnds(conjuncts) = body + e(andJoin(conjuncts.flatMap { conj => + val vars = variablesOf(conj) + val quantified = quants.filter(vars) + + extractQuorums(conj, quantified).flatMap { case (qrm, others) => + val quorum = qrm.toList + + if (quorum.exists { case (TopLevelAnds(paths), _, _) => + val p = andJoin(paths.filter(path => (variablesOf(path) & quantified).isEmpty)) + e(p) == BooleanLiteral(false) + }) List(BooleanLiteral(true)) else { + + var mappings: Seq[(Identifier, Int, Int)] = Seq.empty + var constraints: Seq[(Expr, Int, Int)] = Seq.empty + var equalities: Seq[((Int, Int), (Int, Int))] = Seq.empty + + for (((_, expr, args), qidx) <- quorum.zipWithIndex) { + val (qmappings, qconstraints) = args.zipWithIndex.partition { + case (Variable(id),aidx) => quantified(id) + case _ => false + } + + mappings ++= qmappings.map(p => (p._1.asInstanceOf[Variable].id, qidx, p._2)) + constraints ++= qconstraints.map(p => (p._1, qidx, p._2)) + } + + val mapping = for ((id, es) <- mappings.groupBy(_._1)) yield { + val base :: others = es.toList.map(p => (p._2, p._3)) + equalities ++= others.map(p => base -> p) + (id -> base) + } + + def domain(expr: Expr): Set[Seq[Expr]] = henkinModel.domain(e(expr) match { + case l: Lambda => gctx.lambdas.getOrElse(l, l) + case ev => ev + }) + + val argSets = quorum.foldLeft[List[Seq[Seq[Expr]]]](List(Seq.empty)) { + case (acc, (_, expr, _)) => acc.flatMap(s => domain(expr).map(d => s :+ d)) + } + + argSets.map { args => + val argMap: Map[(Int, Int), Expr] = args.zipWithIndex.flatMap { + case (a, qidx) => a.zipWithIndex.map { case (e, aidx) => (qidx, aidx) -> e } + }.toMap + + val map = mapping.map { case (id, key) => id -> argMap(key) } + val enabler = andJoin(constraints.map { + case (e, qidx, aidx) => Equals(e, argMap(qidx -> aidx)) + } ++ equalities.map { + case (k1, k2) => Equals(argMap(k1), argMap(k2)) + }) + + val ctx = rctx.withNewVars(map) + if (e(enabler)(ctx, gctx) == BooleanLiteral(true)) { + if (gctx.check) { + for ((b,caller,args) <- others if e(b)(ctx, gctx) == BooleanLiteral(true)) { + val evArgs = args.map(arg => e(arg)(ctx, gctx)) + if (!domain(caller)(evArgs)) + throw QuantificationError("Unhandled transitive implication in " + replaceFromIDs(map, conj)) + } + } + + e(conj)(ctx, gctx) + } else { + BooleanLiteral(true) + } + } + } + } + })) match { + case res @ BooleanLiteral(true) if check => + if (gctx.check) { + checkForall(quants, body) match { + case status: ForallInvalid => + throw QuantificationError("Invalid forall: " + status.getMessage) + case _ => + // make sure the body doesn't contain matches or lets as these introduce new locals + val cleanBody = expandLets(matchToIfThenElse(body)) + val calls = new CollectorWithPaths[(Expr, Seq[Expr], Seq[Expr])] { + def collect(e: Expr, path: Seq[Expr]): Option[(Expr, Seq[Expr], Seq[Expr])] = e match { + case QuantificationMatcher(IsTyped(caller, _: FunctionType), args) => Some((caller, args, path)) + case _ => None + } + + override def rec(e: Expr, path: Seq[Expr]): Expr = e match { + case l : Lambda => l + case _ => super.rec(e, path) + } + }.traverse(cleanBody) + + for ((caller, appArgs, paths) <- calls) { + val path = andJoin(paths.filter(expr => (variablesOf(expr) & quants).isEmpty)) + if (e(path) == BooleanLiteral(true)) e(caller) match { + case _: PartialLambda => // OK + case l: Lambda => + val nl @ Lambda(args, body) = gctx.lambdas.getOrElse(l, l) + val lambdaQuantified = (appArgs zip args).collect { + case (Variable(id), vd) if quants(id) => vd.id + }.toSet + + if (lambdaQuantified.nonEmpty) { + checkForall(lambdaQuantified, body) match { + case lambdaStatus: ForallInvalid => + throw QuantificationError("Invalid forall: " + lambdaStatus.getMessage) + case _ => // do nothing + } + + val axiom = Equals(Application(nl, args.map(_.toVariable)), nl.body) + if (evalForall(args.map(_.id).toSet, axiom, check = false) == BooleanLiteral(false)) { + throw QuantificationError("Unaxiomatic lambda " + l) + } + } + case f => + throw EvalError("Cannot apply non-lambda function " + f.asString) + } + } + } + } + + res + + // `res == false` means the quantification is valid since there effectivelly must + // exist an input for which the proposition doesn't hold + case res => res + } + } } + diff --git a/src/main/scala/leon/evaluators/ScalacEvaluator.scala b/src/main/scala/leon/evaluators/ScalacEvaluator.scala index 7c4d8ee92ad2379987c0563e388d9d4fdfb777fc..63bb9f6fa1193dee0fbda7a69500bd4c2e9fa9b8 100644 --- a/src/main/scala/leon/evaluators/ScalacEvaluator.scala +++ b/src/main/scala/leon/evaluators/ScalacEvaluator.scala @@ -11,13 +11,8 @@ import purescala.Expressions._ import purescala.Types._ import java.io.File -import java.nio.file.Files -import java.net.{URL, URLClassLoader} -import java.lang.reflect.{Constructor, Method} - -import frontends.scalac.FullScalaCompiler - -import scala.tools.nsc.{Settings=>NSCSettings,CompilerCommand} +import java.net.URLClassLoader +import java.lang.reflect.Constructor /** * Call scalac-compiled functions from within Leon @@ -25,7 +20,7 @@ import scala.tools.nsc.{Settings=>NSCSettings,CompilerCommand} * Known limitations: * - Multiple argument lists */ -class ScalacEvaluator(ev: Evaluator, ctx: LeonContext, pgm: Program) extends LeonComponent { +class ScalacEvaluator(ev: DeterministicEvaluator, ctx: LeonContext, pgm: Program) extends LeonComponent { implicit val _: Program = pgm val name = "Evaluator of external functions" @@ -300,11 +295,11 @@ class ScalacEvaluator(ev: Evaluator, ctx: LeonContext, pgm: Program) extends Leo /** * Black magic here we come! */ - import org.objectweb.asm.ClassReader; - import org.objectweb.asm.ClassWriter; - import org.objectweb.asm.ClassVisitor; - import org.objectweb.asm.MethodVisitor; - import org.objectweb.asm.Opcodes; + import org.objectweb.asm.ClassReader + import org.objectweb.asm.ClassWriter + import org.objectweb.asm.ClassVisitor + import org.objectweb.asm.MethodVisitor + import org.objectweb.asm.Opcodes class InterceptingClassLoader(p: ClassLoader) extends ClassLoader(p) { import java.io._ @@ -320,24 +315,24 @@ class ScalacEvaluator(ev: Evaluator, ctx: LeonContext, pgm: Program) extends Leo if (!(toInstrument contains name)) { super.loadClass(name, resolve) } else { - val bs = new ByteArrayOutputStream(); - val is = getResourceAsStream(name.replace('.', '/') + ".class"); - val buf = new Array[Byte](512); - var len = is.read(buf); + val bs = new ByteArrayOutputStream() + val is = getResourceAsStream(name.replace('.', '/') + ".class") + val buf = new Array[Byte](512) + var len = is.read(buf) while (len > 0) { - bs.write(buf, 0, len); + bs.write(buf, 0, len) len = is.read(buf) } // Transform bytecode - val cr = new ClassReader(bs.toByteArray()); - val cw = new ClassWriter(cr, ClassWriter.COMPUTE_FRAMES); - val cv = new LeonCallsClassVisitor(cw, name, toInstrument(name)); - cr.accept(cv, 0); + val cr = new ClassReader(bs.toByteArray()) + val cw = new ClassWriter(cr, ClassWriter.COMPUTE_FRAMES) + val cv = new LeonCallsClassVisitor(cw, name, toInstrument(name)) + cr.accept(cv, 0) - val res = cw.toByteArray(); + val res = cw.toByteArray() - defineClass(name, res, 0, res.length); + defineClass(name, res, 0, res.length) } } } @@ -451,7 +446,7 @@ class ScalacEvaluator(ev: Evaluator, ctx: LeonContext, pgm: Program) extends Leo unbox(fd.returnType) mv.visitInsn(returnOpCode(fd.returnType)) - mv.visitEnd(); + mv.visitEnd() } } } diff --git a/src/main/scala/leon/evaluators/StreamEvaluator.scala b/src/main/scala/leon/evaluators/StreamEvaluator.scala new file mode 100644 index 0000000000000000000000000000000000000000..9cc6dd132036ffdde4a5ef70e2be87e6276f9f3c --- /dev/null +++ b/src/main/scala/leon/evaluators/StreamEvaluator.scala @@ -0,0 +1,585 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package evaluators + +import purescala.Constructors._ +import purescala.ExprOps._ +import purescala.Extractors._ +import purescala.TypeOps._ +import purescala.Types._ +import purescala.Common.Identifier +import purescala.Definitions.{TypedFunDef, Program} +import purescala.Expressions._ + +import leon.solvers.SolverFactory +import leon.utils.StreamUtils._ + +class StreamEvaluator(ctx: LeonContext, prog: Program) + extends ContextualEvaluator(ctx, prog, 50000) + with NDEvaluator + with HasDefaultGlobalContext + with HasDefaultRecContext +{ + + val name = "ND-evaluator" + val description = "Non-deterministic interpreter for Leon programs that returns a Stream of solutions" + + protected[evaluators] def e(expr: Expr)(implicit rctx: RC, gctx: GC): Stream[Expr] = expr match { + case Variable(id) => + rctx.mappings.get(id).toStream + + case Application(caller, args) => + for { + cl <- e(caller).distinct + newArgs <- cartesianProduct(args map e).distinct + res <- cl match { + case l @ Lambda(params, body) => + val mapping = l.paramSubst(newArgs) + e(body)(rctx.withNewVars(mapping), gctx).distinct + case PartialLambda(mapping, _, _) => + // FIXME + mapping.collectFirst { + case (pargs, res) if (newArgs zip pargs).forall { case (f, r) => f == r } => + res + }.toStream + case _ => + Stream() + } + } yield res + + case Let(i,v,b) => + for { + ev <- e(v).distinct + eb <- e(b)(rctx.withNewVar(i, ev), gctx).distinct + } yield eb + + case Assert(cond, oerr, body) => + e(IfExpr(Not(cond), Error(expr.getType, oerr.getOrElse("Assertion failed @"+expr.getPos)), body)) + + case en@Ensuring(body, post) => + if ( exists{ + case Hole(_,_) => true + case WithOracle(_,_) => true + case _ => false + }(en)) { + import synthesis.ConversionPhase.convert + e(convert(en, ctx)) + } else { + e(en.toAssert) + } + + case Error(tpe, desc) => + Stream() + + case IfExpr(cond, thenn, elze) => + e(cond).distinct.flatMap { + case BooleanLiteral(true) => e(thenn) + case BooleanLiteral(false) => e(elze) + case other => Stream() + }.distinct + + case FunctionInvocation(TypedFunDef(fd, Seq(tp)), Seq(set)) if fd == program.library.setToList.get => + val cons = program.library.Cons.get + val nil = CaseClass(CaseClassType(program.library.Nil.get, Seq(tp)), Seq()) + def mkCons(h: Expr, t: Expr) = CaseClass(CaseClassType(cons, Seq(tp)), Seq(h,t)) + e(set).distinct.collect { + case FiniteSet(els, _) => + els.foldRight(nil)(mkCons) + }.distinct + + case FunctionInvocation(tfd, args) => + if (gctx.stepsLeft < 0) { + return Stream() + } + gctx.stepsLeft -= 1 + + for { + evArgs <- cartesianProduct(args map e).distinct + // build a mapping for the function... + frame = rctx.withNewVars(tfd.paramSubst(evArgs)) + BooleanLiteral(true) <- e(tfd.precOrTrue)(frame, gctx).distinct + body <- tfd.body.orElse(rctx.mappings.get(tfd.id)).toStream + callResult <- e(body)(frame, gctx).distinct + BooleanLiteral(true) <- e(application(tfd.postOrTrue, Seq(callResult)))(frame, gctx).distinct + } yield callResult + + case And(args) if args.isEmpty => + Stream(BooleanLiteral(true)) + case And(args) => + e(args.head).distinct.flatMap { + case BooleanLiteral(false) => Stream(BooleanLiteral(false)) + case BooleanLiteral(true) => e(andJoin(args.tail)) + case other => Stream() + } + + case Or(args) if args.isEmpty => + Stream(BooleanLiteral(false)) + case Or(args) => + e(args.head).distinct.flatMap { + case BooleanLiteral(true) => Stream(BooleanLiteral(true)) + case BooleanLiteral(false) => e(orJoin(args.tail)) + case other => Stream() + } + + case Implies(lhs, rhs) => + e(Or(Not(lhs), rhs)) + + case l @ Lambda(_, _) => + val (nl, structSubst) = normalizeStructure(l) + val mapping = variablesOf(l).map(id => + structSubst(id) -> (e(Variable(id)) match { + case Stream(v) => v + case _ => return Stream() + }) + ).toMap + Stream(replaceFromIDs(mapping, nl)) + + // FIXME + case PartialLambda(mapping, tpe, df) => + def solveOne(pair: (Seq[Expr], Expr)) = { + val (args, res) = pair + for { + as <- cartesianProduct(args map e) + r <- e(res) + } yield as -> r + } + cartesianProduct(mapping map solveOne) map (PartialLambda(_, tpe, df)) // FIXME!!! + + case f @ Forall(fargs, TopLevelAnds(conjuncts)) => + Stream() // FIXME + /*def solveOne(conj: Expr) = { + val instantiations = forallInstantiations(gctx, fargs, conj) + for { + es <- cartesianProduct(instantiations.map { case (enabler, mapping) => + e(Implies(enabler, conj))(rctx.withNewVars(mapping), gctx) + }) + res <- e(andJoin(es)) + } yield res + } + + for { + conj <- cartesianProduct(conjuncts map solveOne) + res <- e(andJoin(conj)) + } yield res*/ + + case p : Passes => + e(p.asConstraint) + + case choose: Choose => + + // TODO add memoization + implicit val debugSection = utils.DebugSectionSynthesis + + val p = synthesis.Problem.fromSpec(choose.pred) + + ctx.reporter.debug("Executing choose!") + + val tStart = System.currentTimeMillis + + val solverf = SolverFactory.getFromSettings(ctx, program) + val solver = solverf.getNewSolver() + + try { + val eqs = p.as.map { + case id => + Equals(Variable(id), rctx.mappings(id)) + } + + val cnstr = andJoin(eqs ::: p.pc :: p.phi :: Nil) + solver.assertCnstr(cnstr) + + def getSolution = try { + solver.check match { + case Some(true) => + val model = solver.getModel + + val valModel = valuateWithModel(model) _ + + val res = p.xs.map(valModel) + val leonRes = tupleWrap(res) + val total = System.currentTimeMillis - tStart + + ctx.reporter.debug("Synthesis took " + total + "ms") + ctx.reporter.debug("Finished synthesis with " + leonRes.asString) + + Some(leonRes) + case _ => + None + } + } catch { + case _: Throwable => None + } + + Stream.iterate(getSolution)(prev => { + val ensureDistinct = Not(Equals(tupleWrap(p.xs.map{ _.toVariable}), prev.get)) + solver.assertCnstr(ensureDistinct) + val sol = getSolution + // Clean up when the stream ends + if (sol.isEmpty) { + solverf.reclaim(solver) + solverf.shutdown() + } + sol + }).takeWhile(_.isDefined).take(10).map(_.get) + // This take(10) is there because we are not working well with infinite streams yet... + } catch { + case e: Throwable => + solverf.reclaim(solver) + solverf.shutdown() + Stream() + } + + case MatchExpr(scrut, cases) => + + def matchesCase(scrut: Expr, caze: MatchCase)(implicit rctx: RC, gctx: GC): Stream[(MatchCase, Map[Identifier, Expr])] = { + import purescala.TypeOps.isSubtypeOf + + def matchesPattern(pat: Pattern, expr: Expr): Stream[Map[Identifier, Expr]] = (pat, expr) match { + case (InstanceOfPattern(ob, pct), e) => + (if (isSubtypeOf(e.getType, pct)) { + Some(obind(ob, e)) + } else { + None + }).toStream + case (WildcardPattern(ob), e) => + Stream(obind(ob, e)) + + case (CaseClassPattern(ob, pct, subs), CaseClass(ct, args)) => + if (pct == ct) { + val subMaps = (subs zip args).map{ case (s, a) => matchesPattern(s, a) } + cartesianProduct(subMaps).map( _.flatten.toMap ++ obind(ob, expr)) + } else { + Stream() + } + case (UnapplyPattern(ob, unapplyFun, subs), scrut) => + e(functionInvocation(unapplyFun.fd, Seq(scrut))) flatMap { + case CaseClass(CaseClassType(cd, _), Seq()) if cd == program.library.None.get => + None + case CaseClass(CaseClassType(cd, _), Seq(arg)) if cd == program.library.Some.get => + val subMaps = subs zip unwrapTuple(arg, subs.size) map { + case (s,a) => matchesPattern(s,a) + } + cartesianProduct(subMaps).map( _.flatten.toMap ++ obind(ob, expr)) + case other => + None + } + case (TuplePattern(ob, subs), Tuple(args)) => + if (subs.size == args.size) { + val subMaps = (subs zip args).map { case (s, a) => matchesPattern(s, a) } + cartesianProduct(subMaps).map(_.flatten.toMap ++ obind(ob, expr)) + } else Stream() + case (LiteralPattern(ob, l1) , l2 : Literal[_]) if l1 == l2 => + Stream(obind(ob,l1)) + case _ => Stream() + } + + def obind(ob: Option[Identifier], e: Expr): Map[Identifier, Expr] = { + Map[Identifier, Expr]() ++ ob.map(id => id -> e) + } + + caze match { + case SimpleCase(p, rhs) => + matchesPattern(p, scrut).map(r => + (caze, r) + ) + + case GuardedCase(p, g, rhs) => + for { + r <- matchesPattern(p, scrut) + BooleanLiteral(true) <- e(g)(rctx.withNewVars(r), gctx) + } yield (caze, r) + } + } + + for { + rscrut <- e(scrut) + cs <- cases.toStream.map(c => matchesCase(rscrut, c)).find(_.nonEmpty).toStream + (c, mp) <- cs + res <- e(c.rhs)(rctx.withNewVars(mp), gctx) + } yield res + + case Operator(es, _) => + cartesianProduct(es map e) flatMap { es => + try { + Some(step(expr, es)) + } catch { + case _: RuntimeError => + // EvalErrors stop the computation altogether + None + } + } + + case other => + context.reporter.error(other.getPos, "Error: don't know how to handle " + other.asString + " in Evaluator ("+other.getClass+").") + Stream() + } + + + protected def step(expr: Expr, subs: Seq[Expr])(implicit rctx: RC, gctx: GC): Expr = (expr, subs) match { + case (Tuple(_), ts) => + Tuple(ts) + + case (TupleSelect(_, i), rs) => + rs(i - 1) + + case (Assert(_, oerr, _), Seq(BooleanLiteral(cond), body)) => + if (cond) body + else throw RuntimeError(oerr.getOrElse("Assertion failed @" + expr.getPos)) + + case (Error(_, desc), _) => + throw RuntimeError("Error reached in evaluation: " + desc) + + case (FunctionInvocation(TypedFunDef(fd, Seq(tp)), _), Seq(FiniteSet(els, _))) if fd == program.library.setToList.get => + val cons = program.library.Cons.get + val nil = CaseClass(CaseClassType(program.library.Nil.get, Seq(tp)), Seq()) + def mkCons(h: Expr, t: Expr) = CaseClass(CaseClassType(cons, Seq(tp)), Seq(h, t)) + els.foldRight(nil)(mkCons) + + case (Not(_), Seq(BooleanLiteral(arg))) => + BooleanLiteral(!arg) + + case (Implies(_, _) Seq (BooleanLiteral(b1), BooleanLiteral(b2))) => + BooleanLiteral(!b1 || b2) + + case (Equals(_, _), Seq(lv, rv)) => + (lv, rv) match { + case (FiniteSet(el1, _), FiniteSet(el2, _)) => BooleanLiteral(el1 == el2) + case (FiniteMap(el1, _, _), FiniteMap(el2, _, _)) => BooleanLiteral(el1.toSet == el2.toSet) + case (PartialLambda(m1, _, d1), PartialLambda(m2, _, d2)) => BooleanLiteral(m1.toSet == m2.toSet && d1 == d2) + case _ => BooleanLiteral(lv == rv) + } + + case (CaseClass(cd, _), args) => + CaseClass(cd, args) + + case (AsInstanceOf(_, ct), Seq(ce)) => + if (isSubtypeOf(ce.getType, ct)) { + ce + } else { + throw RuntimeError("Cast error: cannot cast " + ce.asString + " to " + ct.asString) + } + + case (IsInstanceOf(_, ct), Seq(ce)) => + BooleanLiteral(isSubtypeOf(ce.getType, ct)) + + case (CaseClassSelector(ct1, _, sel), Seq(CaseClass(ct2, args))) if ct1 == ct2 => + args(ct1.classDef.selectorID2Index(sel)) + + case (Plus(_, _), Seq(InfiniteIntegerLiteral(i1), InfiniteIntegerLiteral(i2))) => + InfiniteIntegerLiteral(i1 + i2) + + case (Minus(_, _), Seq(InfiniteIntegerLiteral(i1), InfiniteIntegerLiteral(i2))) => + InfiniteIntegerLiteral(i1 - i2) + + case (Times(_, _), Seq(InfiniteIntegerLiteral(i1), InfiniteIntegerLiteral(i2))) => + InfiniteIntegerLiteral(i1 * i2) + + case (Division(_, _), Seq(InfiniteIntegerLiteral(i1), InfiniteIntegerLiteral(i2))) => + if (i2 != BigInt(0)) InfiniteIntegerLiteral(i1 / i2) + else throw RuntimeError("Division by 0.") + + case (Remainder(_, _), Seq(InfiniteIntegerLiteral(i1), InfiniteIntegerLiteral(i2))) => + if (i2 != BigInt(0)) InfiniteIntegerLiteral(i1 % i2) + else throw RuntimeError("Remainder of division by 0.") + + case (Modulo(_, _), Seq(InfiniteIntegerLiteral(i1), InfiniteIntegerLiteral(i2))) => + if (i2 < 0) + InfiniteIntegerLiteral(i1 mod (-i2)) + else if (i2 != BigInt(0)) + InfiniteIntegerLiteral(i1 mod i2) + else + throw RuntimeError("Modulo of division by 0.") + + case (UMinus(_), Seq(InfiniteIntegerLiteral(i))) => + InfiniteIntegerLiteral(-i) + + case (RealPlus(_, _), Seq(FractionalLiteral(ln, ld), FractionalLiteral(rn, rd))) => + normalizeFraction(FractionalLiteral(ln * rd + rn * ld, ld * rd)) + + case (RealMinus(_, _), Seq(FractionalLiteral(ln, ld), FractionalLiteral(rn, rd))) => + normalizeFraction(FractionalLiteral(ln * rd - rn * ld, ld * rd)) + + case (RealTimes(_, _), Seq(FractionalLiteral(ln, ld), FractionalLiteral(rn, rd))) => + normalizeFraction(FractionalLiteral(ln * rn, ld * rd)) + + case (RealDivision(_, _), Seq(FractionalLiteral(ln, ld), FractionalLiteral(rn, rd))) => + if (rn != 0) normalizeFraction(FractionalLiteral(ln * rd, ld * rn)) + else throw RuntimeError("Division by 0.") + + case (BVPlus(_, _), Seq(IntLiteral(i1), IntLiteral(i2))) => + IntLiteral(i1 + i2) + + case (BVMinus(_, _), Seq(IntLiteral(i1), IntLiteral(i2))) => + IntLiteral(i1 - i2) + + case (BVUMinus(_), Seq(IntLiteral(i))) => + IntLiteral(-i) + + case (RealUMinus(_), Seq(FractionalLiteral(n, d))) => + FractionalLiteral(-n, d) + + case (BVNot(_), Seq(IntLiteral(i))) => + IntLiteral(~i) + + case (BVTimes(_, _), Seq(IntLiteral(i1), IntLiteral(i2))) => + IntLiteral(i1 * i2) + + case (BVDivision(_, _), Seq(IntLiteral(i1), IntLiteral(i2))) => + if (i2 != 0) IntLiteral(i1 / i2) + else throw RuntimeError("Division by 0.") + + case (BVRemainder(_, _), Seq(IntLiteral(i1), IntLiteral(i2))) => + if (i2 != 0) IntLiteral(i1 % i2) + else throw RuntimeError("Remainder of division by 0.") + + case (BVAnd(_, _), Seq(IntLiteral(i1), IntLiteral(i2))) => + IntLiteral(i1 & i2) + + case (BVOr(_, _), Seq(IntLiteral(i1), IntLiteral(i2))) => + IntLiteral(i1 | i2) + + case (BVXOr(_, _), Seq(IntLiteral(i1), IntLiteral(i2))) => + IntLiteral(i1 ^ i2) + + case (BVShiftLeft(_, _), Seq(IntLiteral(i1), IntLiteral(i2))) => + IntLiteral(i1 << i2) + + case (BVAShiftRight(_, _), Seq(IntLiteral(i1), IntLiteral(i2))) => + IntLiteral(i1 >> i2) + + case (BVLShiftRight(_, _), Seq(IntLiteral(i1), IntLiteral(i2))) => + IntLiteral(i1 >>> i2) + + case (LessThan(_, _), Seq(el, er)) => + (el, er) match { + case (IntLiteral(i1), IntLiteral(i2)) => BooleanLiteral(i1 < i2) + case (InfiniteIntegerLiteral(i1), InfiniteIntegerLiteral(i2)) => BooleanLiteral(i1 < i2) + case (a@FractionalLiteral(_, _), b@FractionalLiteral(_, _)) => + val FractionalLiteral(n, _) = e(RealMinus(a, b)).head + BooleanLiteral(n < 0) + case (CharLiteral(c1), CharLiteral(c2)) => BooleanLiteral(c1 < c2) + case (le, re) => throw EvalError(typeErrorMsg(le, Int32Type)) + } + + case (GreaterThan(_, _), Seq(el, er)) => + (el, er) match { + case (IntLiteral(i1), IntLiteral(i2)) => BooleanLiteral(i1 > i2) + case (InfiniteIntegerLiteral(i1), InfiniteIntegerLiteral(i2)) => BooleanLiteral(i1 > i2) + case (a@FractionalLiteral(_, _), b@FractionalLiteral(_, _)) => + val FractionalLiteral(n, _) = e(RealMinus(a, b)).head + BooleanLiteral(n > 0) + case (CharLiteral(c1), CharLiteral(c2)) => BooleanLiteral(c1 > c2) + case (le, re) => throw EvalError(typeErrorMsg(le, Int32Type)) + } + + case (LessEquals(_, _), Seq(el, er)) => + (el, er) match { + case (IntLiteral(i1), IntLiteral(i2)) => BooleanLiteral(i1 <= i2) + case (InfiniteIntegerLiteral(i1), InfiniteIntegerLiteral(i2)) => BooleanLiteral(i1 <= i2) + case (a@FractionalLiteral(_, _), b@FractionalLiteral(_, _)) => + val FractionalLiteral(n, _) = e(RealMinus(a, b)).head + BooleanLiteral(n <= 0) + case (CharLiteral(c1), CharLiteral(c2)) => BooleanLiteral(c1 <= c2) + case (le, re) => throw EvalError(typeErrorMsg(le, Int32Type)) + } + + case (GreaterEquals(_, _), Seq(el, er)) => + (el, er) match { + case (IntLiteral(i1), IntLiteral(i2)) => BooleanLiteral(i1 >= i2) + case (InfiniteIntegerLiteral(i1), InfiniteIntegerLiteral(i2)) => BooleanLiteral(i1 >= i2) + case (a@FractionalLiteral(_, _), b@FractionalLiteral(_, _)) => + val FractionalLiteral(n, _) = e(RealMinus(a, b)).head + BooleanLiteral(n >= 0) + case (CharLiteral(c1), CharLiteral(c2)) => BooleanLiteral(c1 >= c2) + case (le, re) => throw EvalError(typeErrorMsg(le, Int32Type)) + } + + case (IsTyped(su@SetUnion(s1, s2), tpe), Seq( + IsTyped(FiniteSet(els1, _), SetType(tpe1)), + IsTyped(FiniteSet(els2, _), SetType(tpe2)) + )) => + FiniteSet( + els1 ++ els2, + leastUpperBound(tpe1, tpe2).getOrElse(throw EvalError(typeErrorMsg(su, tpe))) + ) + + case (IsTyped(su@SetIntersection(s1, s2), tpe), Seq( + IsTyped(FiniteSet(els1, _), SetType(tpe1)), + IsTyped(FiniteSet(els2, _), SetType(tpe2)) + )) => + FiniteSet( + els1 & els2, + leastUpperBound(tpe1, tpe2).getOrElse(throw EvalError(typeErrorMsg(su, tpe))) + ) + + case (IsTyped(su@SetDifference(s1, s2), tpe), Seq( + IsTyped(FiniteSet(els1, _), SetType(tpe1)), + IsTyped(FiniteSet(els2, _), SetType(tpe2)) + )) => + FiniteSet( + els1 -- els2, + leastUpperBound(tpe1, tpe2).getOrElse(throw EvalError(typeErrorMsg(su, tpe))) + ) + + case (ElementOfSet(_, _), Seq(e, FiniteSet(els, _))) => + BooleanLiteral(els.contains(e)) + + case (SubsetOf(_, _), Seq(FiniteSet(els1, _), FiniteSet(els2, _))) => + BooleanLiteral(els1.subsetOf(els2)) + + case (SetCardinality(_), Seq(FiniteSet(els, _))) => + IntLiteral(els.size) + + case (FiniteSet(_, base), els) => + FiniteSet(els.toSet, base) + + case (ArrayLength(_), Seq(FiniteArray(_, _, IntLiteral(length)))) => + IntLiteral(length) + + case (ArrayUpdated(_, _, _), Seq( + IsTyped(FiniteArray(elems, default, length), ArrayType(tp)), + IntLiteral(i), + v + )) => + finiteArray(elems.updated(i, v), default map {(_, length)}, tp) + + case (ArraySelect(_, _), Seq(fa@FiniteArray(elems, default, IntLiteral(length)), IntLiteral(index))) => + elems + .get(index) + .orElse(if (index >= 0 && index < length) default else None) + .getOrElse(throw RuntimeError(s"Array out of bounds error during evaluation:\n array = $fa, index = $index")) + + case (fa@FiniteArray(_, _, _), subs) => + val Operator(_, builder) = fa + builder(subs) + + case (fm@FiniteMap(_, _, _), subs) => + val Operator(_, builder) = fm + builder(subs) + + case (g@MapApply(_, _), Seq(FiniteMap(m, _, _), k)) => + m.getOrElse(k, throw RuntimeError("Key not found: " + k.asString)) + + case (u@IsTyped(MapUnion(_, _), MapType(kT, vT)), Seq(FiniteMap(m1, _, _), FiniteMap(m2, _, _))) => + FiniteMap(m1 ++ m2, kT, vT) + + case (i@MapIsDefinedAt(_, _), Seq(FiniteMap(elems, _, _), k)) => + BooleanLiteral(elems.contains(k)) + + case (gv: GenericValue, Seq()) => + gv + + case (fl: FractionalLiteral, Seq()) => + normalizeFraction(fl) + + case (l: Literal[_], Seq()) => + l + + case (other, _) => + context.reporter.error(other.getPos, "Error: don't know how to handle " + other.asString + " in Evaluator ("+other.getClass+").") + throw EvalError("Unhandled case in Evaluator: " + other.asString) + + } + +} diff --git a/src/main/scala/leon/evaluators/TracingEvaluator.scala b/src/main/scala/leon/evaluators/TracingEvaluator.scala index ec977763f3da3f583d225cbdde8dd98aa33f312b..4c0b1f39c9126e4ccf0da6db389394fe0c33d294 100644 --- a/src/main/scala/leon/evaluators/TracingEvaluator.scala +++ b/src/main/scala/leon/evaluators/TracingEvaluator.scala @@ -6,7 +6,6 @@ package evaluators import purescala.Common._ import purescala.Expressions._ import purescala.Definitions._ -import purescala.Quantification._ import purescala.Types._ class TracingEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int = 1000) extends RecursiveEvaluator(ctx, prog, maxSteps) { @@ -15,11 +14,12 @@ class TracingEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int = 1000) ex def initRC(mappings: Map[Identifier, Expr]) = TracingRecContext(mappings, 2) - def initGC(model: solvers.Model) = new TracingGlobalContext(Nil, model) + def initGC(model: solvers.Model, check: Boolean) = new TracingGlobalContext(Nil, model, check) - class TracingGlobalContext(var values: List[(Tree, Expr)], model: solvers.Model) extends GlobalContext(model) + class TracingGlobalContext(var values: List[(Tree, Expr)], model: solvers.Model, check: Boolean) + extends GlobalContext(model, this.maxSteps, check) - case class TracingRecContext(mappings: Map[Identifier, Expr], tracingFrames: Int) extends RecContext { + case class TracingRecContext(mappings: Map[Identifier, Expr], tracingFrames: Int) extends RecContext[TracingRecContext] { def newVars(news: Map[Identifier, Expr]) = copy(mappings = news) } diff --git a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala b/src/main/scala/leon/frontends/scalac/ASTExtractors.scala index db477ec494742b0dc2b0e42e098676c4ab55282d..9720666d145de9c9d1bf7af5ff2a91548509b8e2 100644 --- a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala +++ b/src/main/scala/leon/frontends/scalac/ASTExtractors.scala @@ -509,9 +509,8 @@ trait ASTExtractors { true case _ => false } - } - + object ExDefaultValueFunction{ /** Matches a function that defines the default value of a parameter */ def unapply(dd: DefDef): Option[(Symbol, Seq[Symbol], Seq[ValDef], Type, String, Int, Tree)] = { @@ -520,11 +519,11 @@ trait ASTExtractors { case DefDef(_, name, tparams, vparamss, tpt, rhs) if( vparamss.size <= 1 && name != nme.CONSTRUCTOR && sym.isSynthetic ) => - + // Split the name into pieces, to find owner of the parameter + param.index // Form has to be <owner name>$default$<param index> val symPieces = sym.name.toString.reverse.split("\\$", 3).reverseMap(_.reverse) - + try { if (symPieces(1) != "default" || symPieces(0) == "copy") throw new IllegalArgumentException("") val ownerString = symPieces(0) @@ -574,6 +573,15 @@ trait ASTExtractors { } } + object ExOldExpression { + def unapply(tree: Apply) : Option[Symbol] = tree match { + case a @ Apply(TypeApply(ExSymbol("leon", "lang", "old"), List(tpe)), List(arg)) => + Some(arg.symbol) + case _ => + None + } + } + object ExHoleExpression { def unapply(tree: Tree) : Option[(Tree, List[Tree])] = tree match { case a @ Apply(TypeApply(s @ ExSymbol("leon", "lang", "synthesis", "$qmark"), List(tpt)), args1) => diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala index 0a75dfebc5ae78d8147d0f601d61d38e2e0c05e2..afd47471513c40222305ec09958d9403aa5162a6 100644 --- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala +++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala @@ -246,13 +246,13 @@ trait CodeExtraction extends ASTExtractors { // ignore None - case t@ExAbstractClass(o2, sym, _) => + case t @ ExAbstractClass(o2, sym, _) => Some(getClassDef(sym, t.pos)) - case t@ExCaseClass(o2, sym, args, _) => + case t @ ExCaseClass(o2, sym, args, _) => Some(getClassDef(sym, t.pos)) - case t@ExObjectDef(n, templ) => + case t @ ExObjectDef(n, templ) => // Module val id = FreshIdentifier(n) val leonDefs = templ.body.flatMap { @@ -481,16 +481,31 @@ trait CodeExtraction extends ASTExtractors { Nil } - val tparams = tparamsMap.map(t => TypeParameterDef(t._2)) - - val defCtx = DefContext(tparamsMap.toMap) - val parent = sym.tpe.parents.headOption match { case Some(TypeRef(_, parentSym, tps)) if seenClasses contains parentSym => getClassDef(parentSym, sym.pos) match { case acd: AbstractClassDef => + val defCtx = DefContext(tparamsMap.toMap) val newTps = tps.map(extractType(_)(defCtx, sym.pos)) - Some(AbstractClassType(acd, newTps)) + val zip = (newTps zip tparamsMap.map(_._2)) + if (newTps.size != tparamsMap.size) { + outOfSubsetError(sym.pos, "Child classes should have the same number of type parameters as their parent") + None + } else if (zip.exists { + case (TypeParameter(_), _) => false + case _ => true + }) { + outOfSubsetError(sym.pos, "Child class type params should have a simple mapping to parent params") + None + } else if (zip.exists { + case (TypeParameter(id), ctp) => id.name != ctp.id.name + case _ => false + }) { + outOfSubsetError(sym.pos, "Child type params should be identical to parent class's (e.g. C[T1,T2] extends P[T1,T2])") + None + } else { + Some(acd.typed -> acd.tparams) + } case cd => outOfSubsetError(sym.pos, s"Class $id cannot extend ${cd.id}") @@ -501,11 +516,18 @@ trait CodeExtraction extends ASTExtractors { None } + val tparams = parent match { + case Some((p, tparams)) => tparams + case None => tparamsMap.map(t => TypeParameterDef(t._2)) + } + + val defCtx = DefContext((tparamsMap.map(_._1) zip tparams.map(_.tp)).toMap) + // Extract class val cd = if (sym.isAbstractClass) { - AbstractClassDef(id, tparams, parent) + AbstractClassDef(id, tparams, parent.map(_._1)) } else { - CaseClassDef(id, tparams, parent, sym.isModuleClass) + CaseClassDef(id, tparams, parent.map(_._1), sym.isModuleClass) } cd.setPos(sym.pos) //println(s"Registering $sym") @@ -513,7 +535,7 @@ trait CodeExtraction extends ASTExtractors { cd.addFlags(annotationsOf(sym).map { case (name, args) => ClassFlag.fromName(name, args) }.toSet) // Register parent - parent.foreach(_.classDef.registerChild(cd)) + parent.map(_._1).foreach(_.classDef.registerChild(cd)) // Extract case class fields cd match { @@ -522,30 +544,14 @@ trait CodeExtraction extends ASTExtractors { val fields = args.map { case (fsym, t) => val tpe = leonType(t.tpt.tpe)(defCtx, fsym.pos) val id = cachedWithOverrides(fsym, Some(ccd), tpe) - LeonValDef(id.setPos(t.pos), Some(tpe)).setPos(t.pos) + if (tpe != id.getType) println(tpe, id.getType) + LeonValDef(id.setPos(t.pos)).setPos(t.pos) } //println(s"Fields of $sym") ccd.setFields(fields) case _ => } - // Validates type parameters - parent foreach { pct => - if(pct.classDef.tparams.size == tparams.size) { - val pcd = pct.classDef - val ptps = pcd.tparams.map(_.tp) - - val targetType = AbstractClassType(pcd, ptps) - val fromChild = cd.typed(ptps).parent.get - - if (fromChild != targetType) { - outOfSubsetError(sym.pos, "Child type should form a simple bijection with parent class type (e.g. C[T1,T2] extends P[T1,T2])") - } - } else { - outOfSubsetError(sym.pos, "Child classes should have the same number of type parameters as their parent") - } - } - //println(s"Body of $sym") // We collect the methods and fields @@ -629,9 +635,10 @@ trait CodeExtraction extends ASTExtractors { val newParams = sym.info.paramss.flatten.map{ sym => val ptpe = leonType(sym.tpe)(nctx, sym.pos) - val newID = FreshIdentifier(sym.name.toString, ptpe).setPos(sym.pos) + val tpe = if (sym.isByNameParam) FunctionType(Seq(), ptpe) else ptpe + val newID = FreshIdentifier(sym.name.toString, tpe).setPos(sym.pos) owners += (newID -> None) - LeonValDef(newID).setPos(sym.pos) + LeonValDef(newID, sym.isByNameParam).setPos(sym.pos) } val tparamsDef = tparams.map(t => TypeParameterDef(t._2)) @@ -768,8 +775,9 @@ trait CodeExtraction extends ASTExtractors { vd.defaultValue = paramsToDefaultValues.get(s.symbol) } - val newVars = for ((s, vd) <- params zip funDef.params) yield { - s.symbol -> (() => Variable(vd.id)) + val newVars = for ((s, vd) <- params zip funDef.params) yield s.symbol -> { + if (s.symbol.isByNameParam) () => Application(Variable(vd.id), Seq()) + else () => Variable(vd.id) } val fctx = dctx.withNewVars(newVars).copy(isExtern = funDef.annotations("extern")) @@ -1054,6 +1062,15 @@ trait CodeExtraction extends ASTExtractors { val tupleExprs = exprs.map(e => extractTree(e)) Tuple(tupleExprs) + case ex@ExOldExpression(sym) if dctx.isVariable(sym) => + dctx.vars.get(sym).orElse(dctx.mutableVars.get(sym)) match { + case Some(builder) => + val Variable(id) = builder() + Old(id).setPos(ex.pos) + case None => + outOfSubsetError(current, "old can only be used with variables") + } + case ExErrorExpression(str, tpt) => Error(extractType(tpt), str) @@ -1083,11 +1100,11 @@ trait CodeExtraction extends ASTExtractors { } val restTree = rest match { - case Some(rst) => { + case Some(rst) => val nctx = dctx.withNewVar(vs -> (() => Variable(newID))) extractTree(rst)(nctx) - } - case None => UnitLiteral() + case None => + UnitLiteral() } rest = None @@ -1105,7 +1122,7 @@ trait CodeExtraction extends ASTExtractors { val oldCurrentFunDef = currentFunDef - val funDefWithBody = extractFunBody(fd, params, b)(newDctx.copy(mutableVars = Map())) + val funDefWithBody = extractFunBody(fd, params, b)(newDctx) currentFunDef = oldCurrentFunDef @@ -1200,11 +1217,11 @@ trait CodeExtraction extends ASTExtractors { } getOwner(lhsRec) match { - case Some(Some(fd)) if fd != currentFunDef => - outOfSubsetError(tr, "cannot update an array that is not defined locally") + // case Some(Some(fd)) if fd != currentFunDef => + // outOfSubsetError(tr, "cannot update an array that is not defined locally") - case Some(None) => - outOfSubsetError(tr, "cannot update an array that is not defined locally") + // case Some(None) => + // outOfSubsetError(tr, "cannot update an array that is not defined locally") case Some(_) => @@ -1521,23 +1538,24 @@ trait CodeExtraction extends ASTExtractors { val fd = getFunDef(sym, c.pos) val newTps = tps.map(t => extractType(t)) + val argsByName = (fd.params zip args).map(p => if (p._1.isLazy) Lambda(Seq(), p._2) else p._2) - FunctionInvocation(fd.typed(newTps), args) + FunctionInvocation(fd.typed(newTps), argsByName) case (IsTyped(rec, ct: ClassType), _, args) if isMethod(sym) => val fd = getFunDef(sym, c.pos) val cd = methodToClass(fd) val newTps = tps.map(t => extractType(t)) + val argsByName = (fd.params zip args).map(p => if (p._1.isLazy) Lambda(Seq(), p._2) else p._2) - MethodInvocation(rec, cd, fd.typed(newTps), args) + MethodInvocation(rec, cd, fd.typed(newTps), argsByName) case (IsTyped(rec, ft: FunctionType), _, args) => application(rec, args) - case (IsTyped(rec, cct: CaseClassType), name, Nil) if cct.fields.exists(_.id.name == name) => - - val fieldID = cct.fields.find(_.id.name == name).get.id + case (IsTyped(rec, cct: CaseClassType), name, Nil) if cct.classDef.fields.exists(_.id.name == name) => + val fieldID = cct.classDef.fields.find(_.id.name == name).get.id caseClassSelector(cct, rec, fieldID) @@ -1623,6 +1641,9 @@ trait CodeExtraction extends ASTExtractors { or(a1, a2) // Set methods + case (IsTyped(a1, SetType(b1)), "size", Nil) => + SetCardinality(a1) + //case (IsTyped(a1, SetType(b1)), "min", Nil) => // SetMin(a1) @@ -1715,6 +1736,14 @@ trait CodeExtraction extends ASTExtractors { LessEquals(a1, a2) case (_, name, _) => + rrec match { + case CaseClass(ct, fields) => + println(ct.fieldsTypes) + println(rrec.getType) + println(ct) + println(fields.map(f => f -> f.getType)) + case _ => + } outOfSubsetError(tr, "Unknown call to "+name) } @@ -1753,6 +1782,9 @@ trait CodeExtraction extends ASTExtractors { case tpe if tpe == NothingClass.tpe => Untyped + case ct: ConstantType => + extractType(ct.value.tpe) + case TypeRef(_, sym, _) if isBigIntSym(sym) => IntegerType @@ -1842,7 +1874,11 @@ trait CodeExtraction extends ASTExtractors { case AnnotatedType(_, tpe) => extractType(tpe) case _ => - outOfSubsetError(tpt.typeSymbol.pos, "Could not extract type as PureScala: "+tpt+" ("+tpt.getClass+")") + if (tpt ne null) { + outOfSubsetError(tpt.typeSymbol.pos, "Could not extract type as PureScala: "+tpt+" ("+tpt.getClass+")") + } else { + outOfSubsetError(NoPosition, "Tree with null-pointer as type found") + } } private def getClassType(sym: Symbol, tps: List[LeonType])(implicit dctx: DefContext) = { diff --git a/src/main/scala/leon/invariant/datastructure/DisjointSet.scala b/src/main/scala/leon/invariant/datastructure/DisjointSets.scala similarity index 94% rename from src/main/scala/leon/invariant/datastructure/DisjointSet.scala rename to src/main/scala/leon/invariant/datastructure/DisjointSets.scala index 4cab7291bfa58dd3d79993233251d7dd589e0019..003bb31d1233ce7ab3ad1fe0dbe775008ae7f806 100644 --- a/src/main/scala/leon/invariant/datastructure/DisjointSet.scala +++ b/src/main/scala/leon/invariant/datastructure/DisjointSets.scala @@ -1,8 +1,6 @@ package leon package invariant.datastructure -import scala.collection.mutable.{ Map => MutableMap } -import scala.annotation.migration import scala.collection.mutable.{Map => MutableMap} class DisjointSets[T] { diff --git a/src/main/scala/leon/invariant/datastructure/Graph.scala b/src/main/scala/leon/invariant/datastructure/Graph.scala index 484f04668c08a6b02f8acfd2db12b343e4b4c303..7262b4b4ef4932a89f76406ed49e4d5e00ee9be1 100644 --- a/src/main/scala/leon/invariant/datastructure/Graph.scala +++ b/src/main/scala/leon/invariant/datastructure/Graph.scala @@ -66,7 +66,7 @@ class DirectedGraph[T] { } }) } - if (!queue.isEmpty) { + if (queue.nonEmpty) { val (head :: tail) = queue queue = tail BFSReachRecur(head) diff --git a/src/main/scala/leon/invariant/datastructure/Maps.scala b/src/main/scala/leon/invariant/datastructure/Maps.scala index 777a79375d874588f3be108d702aa6dc8118df79..ca2dcb98e99f247651dfd1b8632c570630fce7f0 100644 --- a/src/main/scala/leon/invariant/datastructure/Maps.scala +++ b/src/main/scala/leon/invariant/datastructure/Maps.scala @@ -1,13 +1,6 @@ package leon -package invariant.util +package invariant.datastructure -import purescala.Common._ -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Extractors._ -import purescala.Types._ -import scala.collection.mutable.{ Set => MutableSet, Map => MutableMap } import scala.annotation.tailrec class MultiMap[A, B] extends scala.collection.mutable.HashMap[A, scala.collection.mutable.Set[B]] with scala.collection.mutable.MultiMap[A, B] { diff --git a/src/main/scala/leon/invariant/engine/CompositionalTemplateSolver.scala b/src/main/scala/leon/invariant/engine/CompositionalTimeBoundSolver.scala similarity index 92% rename from src/main/scala/leon/invariant/engine/CompositionalTemplateSolver.scala rename to src/main/scala/leon/invariant/engine/CompositionalTimeBoundSolver.scala index 647616fce1bab0eb01182ad50d75c060ce4ff737..7a9605c922e21373abb6df06fe9d585a104da2ba 100644 --- a/src/main/scala/leon/invariant/engine/CompositionalTemplateSolver.scala +++ b/src/main/scala/leon/invariant/engine/CompositionalTimeBoundSolver.scala @@ -1,16 +1,13 @@ package leon package invariant.engine -import purescala.Common._ import purescala.Definitions._ import purescala.Expressions._ import purescala.ExprOps._ import purescala.Extractors._ import purescala.Types._ -import invariant.templateSolvers._ import transformations._ import invariant.structure.FunctionUtils._ -import transformations.InstUtil._ import leon.invariant.structure.Formula import leon.invariant.structure.Call import leon.invariant.util._ @@ -20,7 +17,6 @@ import leon.solvers.Model import Util._ import PredicateUtil._ import ProgramUtil._ -import SolverUtil._ class CompositionalTimeBoundSolver(ctx: InferenceContext, prog: Program, rootFd: FunDef) extends FunctionTemplateSolver { @@ -183,23 +179,21 @@ class CompositionalTimeBoundSolver(ctx: InferenceContext, prog: Program, rootFd: var timeTmpl: Option[Expr] = None var recTmpl: Option[Expr] = None var othersTmpls: Seq[Expr] = Seq[Expr]() - tmplConjuncts.foreach(conj => { - conj match { - case Operator(Seq(lhs, _), _) if (tupleSelectToInst.contains(lhs)) => - tupleSelectToInst(lhs) match { - case n if n == TPR.name => - tprTmpl = Some(conj) - case n if n == Time.name => - timeTmpl = Some(conj) - case n if n == Rec.name => - recTmpl = Some(conj) - case _ => - othersTmpls = othersTmpls :+ conj - } - case _ => - othersTmpls = othersTmpls :+ conj - } - }) + tmplConjuncts.foreach { + case conj@Operator(Seq(lhs, _), _) if (tupleSelectToInst.contains(lhs)) => + tupleSelectToInst(lhs) match { + case n if n == TPR.name => + tprTmpl = Some(conj) + case n if n == Time.name => + timeTmpl = Some(conj) + case n if n == Rec.name => + recTmpl = Some(conj) + case _ => + othersTmpls = othersTmpls :+ conj + } + case conj => + othersTmpls = othersTmpls :+ conj + } (tprTmpl, recTmpl, timeTmpl, othersTmpls) } } diff --git a/src/main/scala/leon/invariant/engine/ConstraintTracker.scala b/src/main/scala/leon/invariant/engine/ConstraintTracker.scala index 3058a267eceef7c5ec1aa537c4766bee15a93204..95927a66ad3bda803a329b61e284f854a94bfcb4 100644 --- a/src/main/scala/leon/invariant/engine/ConstraintTracker.scala +++ b/src/main/scala/leon/invariant/engine/ConstraintTracker.scala @@ -1,19 +1,8 @@ package leon package invariant.engine -import z3.scala._ -import purescala._ -import purescala.Common._ import purescala.Definitions._ import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Extractors._ -import purescala.Types._ -import evaluators._ -import java.io._ - -import invariant.factories._ -import invariant.util._ import invariant.structure._ class ConstraintTracker(ctx : InferenceContext, program: Program, rootFun : FunDef/*, temFactory: TemplateFactory*/) { diff --git a/src/main/scala/leon/invariant/engine/InferInvariantsPhase.scala b/src/main/scala/leon/invariant/engine/InferInvariantsPhase.scala index e79926e62158a095b755b48db43c8f3e46c41590..3ef178a03095a74dc8233dfd1ec3739abbdb3cef 100644 --- a/src/main/scala/leon/invariant/engine/InferInvariantsPhase.scala +++ b/src/main/scala/leon/invariant/engine/InferInvariantsPhase.scala @@ -1,21 +1,7 @@ package leon package invariant.engine -import purescala.Common._ import purescala.Definitions._ -import purescala.ExprOps._ -import purescala.Expressions._ -import purescala.Extractors._ -import purescala.Types._ -import verification.VerificationReport -import invariant.templateSolvers._ -import invariant.factories._ -import invariant.util._ -import invariant.structure.FunctionUtils._ -import invariant.structure._ -import transformations._ -import leon.purescala.ScalaPrinter -import leon.purescala.PrettyPrinter /** * @author ravi diff --git a/src/main/scala/leon/invariant/engine/InferenceContext.scala b/src/main/scala/leon/invariant/engine/InferenceContext.scala index b735cb8de7101f3bf109c6318cc7a05c571ba4b7..a6cdc317d1fcd17aba9abaa5199975e664ada9bf 100644 --- a/src/main/scala/leon/invariant/engine/InferenceContext.scala +++ b/src/main/scala/leon/invariant/engine/InferenceContext.scala @@ -12,7 +12,6 @@ import invariant.util._ import verification._ import verification.VCKinds import InferInvariantsPhase._ -import Util._ import ProgramUtil._ /** @@ -71,13 +70,13 @@ class InferenceContext(val initProgram: Program, val leonContext: LeonContext) { instrumentedProg.definedFunctions.foreach((fd) => { if (!foundStrongest && fd.hasPostcondition) { val cond = fd.postcondition.get - postTraversal((e) => e match { + postTraversal { case Equals(_, _) => { rel = Equals.apply _ foundStrongest = true } case _ => ; - })(cond) + }(cond) } }) rel diff --git a/src/main/scala/leon/invariant/engine/InferenceEngine.scala b/src/main/scala/leon/invariant/engine/InferenceEngine.scala index 85491acf1924560e5fe020a38176b513da989690..6ca6ef6327d00465bbbe5141bd69b6ce1e6338a6 100644 --- a/src/main/scala/leon/invariant/engine/InferenceEngine.scala +++ b/src/main/scala/leon/invariant/engine/InferenceEngine.scala @@ -1,31 +1,18 @@ package leon package invariant.engine -import z3.scala._ -import purescala.Common._ import purescala.Definitions._ -import purescala.Expressions._ import purescala.ExprOps._ -import purescala.Extractors._ -import purescala.Types._ -import solvers._ import java.io._ -import verification.VerificationReport import verification.VC import scala.util.control.Breaks._ -import invariant.templateSolvers._ import invariant.factories._ import invariant.util._ -import invariant.util.Util._ -import invariant.structure._ import invariant.structure.FunctionUtils._ -import leon.invariant.factories.TemplateFactory import transformations._ import leon.utils._ import Util._ -import PredicateUtil._ import ProgramUtil._ -import SolverUtil._ /** * @author ravi @@ -85,13 +72,13 @@ class InferenceEngine(ctx: InferenceContext) extends Interruptible { } else { var remFuncs = functionsToAnalyze var b = 200 - var maxCegisBound = 200 + val maxCegisBound = 200 breakable { while (b <= maxCegisBound) { Stats.updateCumStats(1, "CegisBoundsTried") val succeededFuncs = analyseProgram(program, remFuncs, progressCallback) remFuncs = remFuncs.filterNot(succeededFuncs.contains _) - if (remFuncs.isEmpty) break; + if (remFuncs.isEmpty) break b += 5 //increase bounds in steps of 5 } //println("Inferrence did not succeeded for functions: " + remFuncs.map(_.id)) @@ -200,7 +187,7 @@ class InferenceEngine(ctx: InferenceContext) extends Interruptible { first = false ic } - progressCallback.map(cb => cb(inferCond)) + progressCallback.foreach(cb => cb(inferCond)) } val funsWithTemplates = inferredFuns.filter { fd => val origFd = functionByName(fd.id.name, startProg).get diff --git a/src/main/scala/leon/invariant/engine/InferenceReport.scala b/src/main/scala/leon/invariant/engine/InferenceReport.scala index d5f0196278d04489f657940caf74e30009fd02b9..4b0fb349f591780da0856e4ccbbe7e49f30c8d83 100644 --- a/src/main/scala/leon/invariant/engine/InferenceReport.scala +++ b/src/main/scala/leon/invariant/engine/InferenceReport.scala @@ -9,8 +9,6 @@ import purescala.Expressions._ import purescala.ExprOps._ import purescala.Definitions._ import purescala.Common._ -import invariant.templateSolvers._ -import invariant.factories._ import invariant.util._ import invariant.structure._ import leon.transformations.InstUtil @@ -18,7 +16,6 @@ import leon.purescala.PrettyPrinter import Util._ import PredicateUtil._ import ProgramUtil._ -import SolverUtil._ import FunctionUtils._ import purescala._ @@ -64,7 +61,7 @@ class InferenceReport(fvcs: Map[FunDef, List[VC]], program: Program)(implicit ct "║ └─────────┘" + (" " * (size - 12)) + "║" private def infoLine(str: String, size: Int): String = { - "║ " + str + (" " * (size - str.size - 2)) + " ║" + "║ " + str + (" " * (size - str.length - 2)) + " ║" } private def fit(str: String, maxLength: Int): String = { @@ -77,11 +74,11 @@ class InferenceReport(fvcs: Map[FunDef, List[VC]], program: Program)(implicit ct private def funName(fd: FunDef) = InstUtil.userFunctionName(fd) - override def summaryString: String = if (conditions.size > 0) { - val maxTempSize = (conditions.map(_.status.size).max + 3) + override def summaryString: String = if (conditions.nonEmpty) { + val maxTempSize = (conditions.map(_.status.length).max + 3) val outputStrs = conditions.map(vc => { val timeStr = vc.time.map(t => "%-3.3f".format(t)).getOrElse("") - "%-15s %s %-4s".format(fit(funName(vc.fd), 15), vc.status + (" " * (maxTempSize - vc.status.size)), timeStr) + "%-15s %s %-4s".format(fit(funName(vc.fd), 15), vc.status + (" " * (maxTempSize - vc.status.length)), timeStr) }) val summaryStr = { val totalTime = conditions.foldLeft(0.0)((a, ic) => a + ic.time.getOrElse(0.0)) @@ -89,7 +86,7 @@ class InferenceReport(fvcs: Map[FunDef, List[VC]], program: Program)(implicit ct "total: %-4d inferred: %-4d unknown: %-4d time: %-3.3f".format( conditions.size, inferredConds, conditions.size - inferredConds, totalTime) } - val entrySize = (outputStrs :+ summaryStr).map(_.size).max + 2 + val entrySize = (outputStrs :+ summaryStr).map(_.length).max + 2 infoHeader(entrySize) + outputStrs.map(str => infoLine(str, entrySize)).mkString("\n", "\n", "\n") + @@ -129,7 +126,7 @@ object InferenceReportUtil { def fullNameWoInst(fd: FunDef) = { val splits = DefOps.fullName(fd)(ctx.inferProgram).split("-") - if (!splits.isEmpty) splits(0) + if (splits.nonEmpty) splits(0) else "" } @@ -148,8 +145,8 @@ object InferenceReportUtil { } def mapExpr(ine: Expr): Expr = { - val replaced = simplePostTransform((e: Expr) => e match { - case FunctionInvocation(TypedFunDef(fd, targs), args) => + val replaced = simplePostTransform { + case e@FunctionInvocation(TypedFunDef(fd, targs), args) => if (initToOutput.contains(fd)) { FunctionInvocation(TypedFunDef(initToOutput(fd), targs), args) } else { @@ -159,8 +156,8 @@ object InferenceReportUtil { case _ => e } } - case _ => e - })(ine) + case e => e + }(ine) replaced } // copy bodies and specs diff --git a/src/main/scala/leon/invariant/engine/SpecInstatiator.scala b/src/main/scala/leon/invariant/engine/SpecInstantiator.scala similarity index 100% rename from src/main/scala/leon/invariant/engine/SpecInstatiator.scala rename to src/main/scala/leon/invariant/engine/SpecInstantiator.scala diff --git a/src/main/scala/leon/invariant/engine/TemplateEnumerator.scala b/src/main/scala/leon/invariant/engine/TemplateEnumerator.scala index c568910684e89d78862a516583eb0baf31174888..1e717dbc1a23723680ecdffeb27b3f78553d1747 100644 --- a/src/main/scala/leon/invariant/engine/TemplateEnumerator.scala +++ b/src/main/scala/leon/invariant/engine/TemplateEnumerator.scala @@ -1,26 +1,15 @@ package leon package invariant.engine -import z3.scala._ -import purescala.Common._ import purescala.Definitions._ import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Extractors._ import purescala.Types._ -import scala.collection.mutable.{ Set => MutableSet } -import java.io._ -import scala.collection.mutable.{ Set => MutableSet } -import scala.collection.mutable.{ Set => MutableSet } -import invariant.templateSolvers._ import invariant.factories._ import invariant.util._ -import invariant.structure._ -import Util._ -import PredicateUtil._ import ProgramUtil._ +import scala.collection.mutable.{Set => MutableSet} /** * An enumeration based template generator. * Enumerates all numerical terms in some order (this enumeration is incomplete for termination). @@ -126,7 +115,7 @@ class FunctionTemplateEnumerator(rootFun: FunDef, prog: Program, op: (Expr, Expr if (fun != rootFun && !callGraph.transitivelyCalls(fun, rootFun)) { //check if every argument has at least one satisfying assignment? - if (fun.params.filter((vardecl) => !ttCurrent.contains(vardecl.getType)).isEmpty) { + if (!fun.params.exists((vardecl) => !ttCurrent.contains(vardecl.getType))) { //here compute all the combinations val newcalls = generateFunctionCalls(fun) @@ -153,7 +142,7 @@ class FunctionTemplateEnumerator(rootFun: FunDef, prog: Program, op: (Expr, Expr //return all the integer valued terms of newTerms //++ newTerms.getOrElse(Int32Type, Seq[Expr]()) (for now not handling int 32 terms) val numericTerms = (newTerms.getOrElse(RealType, Seq[Expr]()) ++ newTerms.getOrElse(IntegerType, Seq[Expr]())).toSeq - if (!numericTerms.isEmpty) { + if (numericTerms.nonEmpty) { //create a linear combination of intTerms val newTemp = numericTerms.foldLeft(null: Expr)((acc, t: Expr) => { val summand = Times(t, TemplateIdFactory.freshTemplateVar(): Expr) diff --git a/src/main/scala/leon/invariant/engine/UnfoldingTemplateSolver.scala b/src/main/scala/leon/invariant/engine/UnfoldingTemplateSolver.scala index 4bb5bbc3d76a341151ff9478b139f2ad8157485d..3fe13e7c93fefdfcc4321159b0fa7b354788a801 100644 --- a/src/main/scala/leon/invariant/engine/UnfoldingTemplateSolver.scala +++ b/src/main/scala/leon/invariant/engine/UnfoldingTemplateSolver.scala @@ -1,31 +1,24 @@ package leon package invariant.engine -import z3.scala._ import purescala.Common._ import purescala.Definitions._ import purescala.Expressions._ import purescala.ExprOps._ -import purescala.Extractors._ import purescala.Types._ import purescala.DefOps._ -import solvers._ -import solvers.z3.FairZ3Solver -import java.io._ import purescala.ScalaPrinter + +import solvers._ import verification._ -import scala.reflect.runtime.universe -import invariant.templateSolvers._ import invariant.factories._ import invariant.util._ import invariant.structure._ import transformations._ import FunctionUtils._ -import leon.invariant.templateSolvers.ExtendedUFSolver import Util._ import PredicateUtil._ import ProgramUtil._ -import SolverUtil._ /** * @author ravi @@ -119,7 +112,7 @@ class UnfoldingTemplateSolver(ctx: InferenceContext, program: Program, rootFd: F case (Some(model), callsInPath) => toRefineCalls = callsInPath //Validate the model here - instantiateAndValidateModel(model, constTracker.getFuncs.toSeq) + instantiateAndValidateModel(model, constTracker.getFuncs) Some(InferResult(true, Some(model), constTracker.getFuncs.toList)) case (None, callsInPath) => @@ -129,7 +122,7 @@ class UnfoldingTemplateSolver(ctx: InferenceContext, program: Program, rootFd: F } } } - } while (!infRes.isDefined) + } while (infRes.isEmpty) infRes } @@ -227,7 +220,7 @@ class UnfoldingTemplateSolver(ctx: InferenceContext, program: Program, rootFd: F val resvar = FreshIdentifier("res", fd.returnType, true) // FIXME: Is this correct (ResultVariable(fd.returnType) -> resvar.toVariable)) val ninv = replace(Map(ResultVariable(fd.returnType) -> resvar.toVariable), inv) - Some(Lambda(Seq(ValDef(resvar, Some(fd.returnType))), ninv)) + Some(Lambda(Seq(ValDef(resvar)), ninv)) } } else if (fd.postcondition.isDefined) { val Lambda(resultBinder, _) = fd.postcondition.get diff --git a/src/main/scala/leon/invariant/factories/AxiomFactory.scala b/src/main/scala/leon/invariant/factories/AxiomFactory.scala index 6fdf7fa633def45017de8df5fbab4250ce70928e..907db46ab1cf1fbd4efcc19f58ae5867142575f6 100644 --- a/src/main/scala/leon/invariant/factories/AxiomFactory.scala +++ b/src/main/scala/leon/invariant/factories/AxiomFactory.scala @@ -1,24 +1,13 @@ package leon package invariant.factories -import z3.scala._ -import purescala.Common._ -import purescala.Definitions._ import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Extractors._ import purescala.Types._ -import java.io._ -import leon.invariant._ -import scala.util.control.Breaks._ -import scala.concurrent._ -import scala.concurrent.duration._ import invariant.engine._ import invariant.util._ import invariant.structure._ import FunctionUtils._ -import Util._ import PredicateUtil._ class AxiomFactory(ctx : InferenceContext) { @@ -85,7 +74,6 @@ class AxiomFactory(ctx : InferenceContext) { //this is applicable only to binary operations def undistributeCalls(call1: Call, call2: Call): (Expr,Expr) = { - val fd = call1.fi.tfd.fd val tfd = call1.fi.tfd val Seq(a1,b1) = call1.fi.args @@ -93,9 +81,7 @@ class AxiomFactory(ctx : InferenceContext) { val r1 = call1.retexpr val r2 = call2.retexpr - val dret1 = TVarFactory.createTemp("dt", IntegerType).toVariable val dret2 = TVarFactory.createTemp("dt", IntegerType).toVariable - val dcall1 = Call(dret1, FunctionInvocation(tfd,Seq(a2,Plus(b1,b2)))) val dcall2 = Call(dret2, FunctionInvocation(tfd,Seq(Plus(a1,a2),b2))) (LessEquals(b1,b2), And(LessEquals(Plus(r1,r2),dret2), dcall2.toExpr)) } diff --git a/src/main/scala/leon/invariant/factories/TemplateFactory.scala b/src/main/scala/leon/invariant/factories/TemplateFactory.scala index ab1f51359151de13ef8cd506f4f1c9da9472c208..b44fcfa684d66168cad5de3e0cca7d8df5d11f50 100644 --- a/src/main/scala/leon/invariant/factories/TemplateFactory.scala +++ b/src/main/scala/leon/invariant/factories/TemplateFactory.scala @@ -1,23 +1,16 @@ package leon package invariant.factories -import z3.scala._ import purescala.Common._ import purescala.Definitions._ import purescala.Expressions._ import purescala.ExprOps._ -import purescala.Extractors._ import purescala.Types._ -import java.io._ -import scala.collection.mutable.{ Map => MutableMap } -import invariant._ import scala.collection.mutable.{Map => MutableMap} -import invariant.engine._ import invariant.util._ import invariant.structure._ import FunctionUtils._ -import Util._ import PredicateUtil._ import ProgramUtil._ @@ -28,7 +21,7 @@ object TemplateIdFactory { def getTemplateIds : Set[Identifier] = ids def freshIdentifier(name : String = "", idType: TypeTree = RealType) : Identifier = { - val idname = if(name.isEmpty()) "a?" + val idname = if(name.isEmpty) "a?" else name + "?" val freshid = FreshIdentifier(idname, idType, true) ids += freshid @@ -72,7 +65,7 @@ class TemplateFactory(tempGen : Option[TemplateGenerator], prog: Program, report //a mapping from function definition to the template private var templateMap = { //initialize the template map with predefined user maps - var muMap = MutableMap[FunDef, Expr]() + val muMap = MutableMap[FunDef, Expr]() functionsWOFields(prog.definedFunctions).foreach { fd => val tmpl = fd.template if (tmpl.isDefined) { @@ -114,7 +107,7 @@ class TemplateFactory(tempGen : Option[TemplateGenerator], prog: Program, report //initialize the template for the function if (!templateMap.contains(fd)) { - if(!tempGen.isDefined) templateMap += (fd -> getDefaultTemplate(fd)) + if(tempGen.isEmpty) templateMap += (fd -> getDefaultTemplate(fd)) else { templateMap += (fd -> tempGen.get.getNextTemplate(fd)) refinementSet += fd diff --git a/src/main/scala/leon/invariant/factories/TemplateInstantiator.scala b/src/main/scala/leon/invariant/factories/TemplateInstantiator.scala index ccaabc57aabecb9fa22fd94a5972c1a554196729..dbcdc4c5e1a2de62fb5a1c1e81ebe17b4449a584 100644 --- a/src/main/scala/leon/invariant/factories/TemplateInstantiator.scala +++ b/src/main/scala/leon/invariant/factories/TemplateInstantiator.scala @@ -1,23 +1,15 @@ package leon package invariant.factories -import z3.scala._ -import purescala._ -import purescala.Common._ import purescala.Definitions._ import purescala.Expressions._ import purescala.ExprOps._ import purescala.Extractors._ -import purescala.Types._ -import java.io._ -import invariant.engine._ import invariant.util._ import invariant.structure._ import leon.solvers.Model import leon.invariant.util.RealValuedExprEvaluator -import Util._ import PredicateUtil._ -import ProgramUtil._ object TemplateInstantiator { /** @@ -51,17 +43,17 @@ object TemplateInstantiator { */ def instantiate(expr: Expr, tempVarMap: Map[Expr, Expr], prettyInv: Boolean = false): Expr = { //do a simple post transform and replace the template vars by their values - val inv = simplePostTransform((tempExpr: Expr) => tempExpr match { - case e @ Operator(Seq(lhs, rhs), op) if ((e.isInstanceOf[Equals] || e.isInstanceOf[LessThan] + val inv = simplePostTransform { + case tempExpr@(e@Operator(Seq(lhs, rhs), op)) if ((e.isInstanceOf[Equals] || e.isInstanceOf[LessThan] || e.isInstanceOf[LessEquals] || e.isInstanceOf[GreaterThan] || e.isInstanceOf[GreaterEquals]) && - !getTemplateVars(tempExpr).isEmpty) => { + getTemplateVars(tempExpr).nonEmpty) => { val linearTemp = LinearConstraintUtil.exprToTemplate(tempExpr) instantiateTemplate(linearTemp, tempVarMap, prettyInv) } - case _ => tempExpr - })(expr) + case tempExpr => tempExpr + }(expr) inv } diff --git a/src/main/scala/leon/invariant/factories/TemplateSolverFactory.scala b/src/main/scala/leon/invariant/factories/TemplateSolverFactory.scala index 469317998a98f979a9a90770fd9c902544eee584..96c8d212a59cfc39c8cc428dc521f8e4a278a004 100644 --- a/src/main/scala/leon/invariant/factories/TemplateSolverFactory.scala +++ b/src/main/scala/leon/invariant/factories/TemplateSolverFactory.scala @@ -1,12 +1,8 @@ package leon package invariant.factories -import purescala.Common._ import purescala.Definitions._ import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Extractors._ -import purescala.Types._ import invariant._ import invariant.engine._ import invariant.util._ diff --git a/src/main/scala/leon/invariant/structure/Constraint.scala b/src/main/scala/leon/invariant/structure/Constraint.scala index 9d2490fe5f8b61b6847636007c61d6ba3c0dba34..39fd20bca11971ba0feadf26e20331970f9a9ca7 100644 --- a/src/main/scala/leon/invariant/structure/Constraint.scala +++ b/src/main/scala/leon/invariant/structure/Constraint.scala @@ -1,23 +1,10 @@ package leon package invariant.structure -import z3.scala._ -import purescala._ -import purescala.Common._ -import purescala.Definitions._ import purescala.Expressions._ import purescala.ExprOps._ -import purescala.Extractors._ import purescala.Types._ -import solvers.{ Solver, TimeoutSolver } -import solvers.z3.FairZ3Solver -import scala.collection.mutable.{ Set => MutableSet } -import scala.collection.mutable.{ Map => MutableMap } -import evaluators._ -import java.io._ -import solvers.z3.UninterpretedZ3Solver import invariant.util._ -import Util._ import PredicateUtil._ trait Constraint { @@ -42,9 +29,7 @@ class LinearTemplate(oper: Seq[Expr] => Expr, } val coeffTemplate = { //assert if the coefficients are templated expressions - assert(coeffTemp.values.foldLeft(true)((acc, e) => { - acc && isTemplateExpr(e) - })) + assert(coeffTemp.values.forall(e => isTemplateExpr(e))) coeffTemp } @@ -110,7 +95,7 @@ class LinearTemplate(oper: Seq[Expr] => Expr, rhsExprs :+= InfiniteIntegerLiteral(-v) case Some(c) => lhsExprs :+= c - case _ => Nil + case _ => } val lhsExprOpt = ((None: Option[Expr]) /: lhsExprs) { case (acc, minterm) => @@ -175,9 +160,7 @@ class LinearConstraint(opr: Seq[Expr] => Expr, cMap: Map[Expr, Expr], constant: val coeffMap = { //assert if the coefficients are only constant expressions - assert(cMap.values.foldLeft(true)((acc, e) => { - acc && variablesOf(e).isEmpty - })) + assert(cMap.values.forall(e => variablesOf(e).isEmpty)) //TODO: here we should try to simplify the constant expressions cMap } diff --git a/src/main/scala/leon/invariant/structure/Formula.scala b/src/main/scala/leon/invariant/structure/Formula.scala index daed61ff679cba61666a6c227005c103a31d8d01..c17711e102877861010359256ff902ad7d51d70b 100644 --- a/src/main/scala/leon/invariant/structure/Formula.scala +++ b/src/main/scala/leon/invariant/structure/Formula.scala @@ -22,8 +22,7 @@ import PredicateUtil._ /** * Data associated with a call */ -class CallData(val guard : Variable, val parents: List[FunDef]) { -} +class CallData(val guard : Variable, val parents: List[FunDef]) /** * Representation of an expression as a set of implications. @@ -81,13 +80,13 @@ class Formula(val fd: FunDef, initexpr: Expr, ctx: InferenceContext) { }) } - val f1 = simplePostTransform((e: Expr) => e match { - case Or(args) => { - val newargs = args.map(arg => arg match { - case v: Variable if (disjuncts.contains(v)) => arg - case v: Variable if (conjuncts.contains(v)) => throw new IllegalStateException("or gaurd inside conjunct: "+e+" or-guard: "+v) - case _ => { - val atoms = arg match { + val f1 = simplePostTransform { + case e@Or(args) => { + val newargs = args.map { + case arg@(v: Variable) if (disjuncts.contains(v)) => arg + case v: Variable if (conjuncts.contains(v)) => throw new IllegalStateException("or gaurd inside conjunct: " + e + " or-guard: " + v) + case arg => { + val atoms = arg match { case And(atms) => atms case _ => Seq(arg) } @@ -98,14 +97,14 @@ class Formula(val fd: FunDef, initexpr: Expr, ctx: InferenceContext) { disjuncts += (g -> ctrs) g } - }) + } //create a temporary for Or val gor = TVarFactory.createTemp("b", BooleanType).toVariable val newor = createOr(newargs) conjuncts += (gor -> newor) gor } - case And(args) => { + case e@And(args) => { val newargs = args.map(arg => if (getTemplateVars(e).isEmpty) { arg } else { @@ -118,8 +117,8 @@ class Formula(val fd: FunDef, initexpr: Expr, ctx: InferenceContext) { }) createAnd(newargs) } - case _ => e - })(ExpressionTransformer.simplify(simplifyArithmetic( + case e => e + }(ExpressionTransformer.simplify(simplifyArithmetic( //TODO: this is a hack as of now. Fix this. //Note: it is necessary to convert real literals to integers since the linear constraint cannot handle real literals if(ctx.usereals) ExpressionTransformer.FractionalLiteralToInt(ine) @@ -151,7 +150,7 @@ class Formula(val fd: FunDef, initexpr: Expr, ctx: InferenceContext) { val e @ Or(guards) = conjuncts(gd) //pick one guard that is true val guard = guards.collectFirst { case g @ Variable(id) if (model(id) == tru) => g } - if (!guard.isDefined) + if (guard.isEmpty) throw new IllegalStateException("No satisfiable guard found: " + e) guard.get +: traverseAnds(guard.get, model) } @@ -236,16 +235,16 @@ class Formula(val fd: FunDef, initexpr: Expr, ctx: InferenceContext) { //replace all conjunct guards in disjuncts by their mapping val disjs : Map[Expr,Expr] = disjuncts.map((entry) => { val (g,ctrs) = entry - val newctrs = ctrs.map(_ match { + val newctrs = ctrs.map { case BoolConstraint(g@Variable(_)) if conjuncts.contains(g) => conjuncts(g) case ctr@_ => ctr.toExpr - }) + } (g, createAnd(newctrs)) }) - val rootexprs = roots.map(_ match { - case g@Variable(_) if conjuncts.contains(g) => conjuncts(g) - case e@_ => e - }) + val rootexprs = roots.map { + case g@Variable(_) if conjuncts.contains(g) => conjuncts(g) + case e@_ => e + } //replace every guard in the 'disjs' by its disjunct. DO this as long as every guard is replaced in every disjunct var unpackedDisjs = disjs var replacedGuard = true diff --git a/src/main/scala/leon/invariant/structure/FunctionUtils.scala b/src/main/scala/leon/invariant/structure/FunctionUtils.scala index a0acdcda7388ee4b6971106cf78cdb49c42768a4..565a6f9f41cbc140c7b6814606897f7cc7eb11f1 100644 --- a/src/main/scala/leon/invariant/structure/FunctionUtils.scala +++ b/src/main/scala/leon/invariant/structure/FunctionUtils.scala @@ -5,13 +5,11 @@ import purescala.Common._ import purescala.Definitions._ import purescala.Expressions._ import purescala.ExprOps._ -import purescala.Extractors._ import purescala.Types._ import invariant.factories._ import invariant.util._ import Util._ import PredicateUtil._ -import ProgramUtil._ import scala.language.implicitConversions /** @@ -37,8 +35,8 @@ object FunctionUtils { def isTemplateInvocation(finv: Expr) = { finv match { case FunctionInvocation(funInv, args) => - (funInv.id.name == "tmpl" && funInv.returnType == BooleanType && - args.size == 1 && args(0).isInstanceOf[Lambda]) + funInv.id.name == "tmpl" && funInv.returnType == BooleanType && + args.size == 1 && args(0).isInstanceOf[Lambda] case _ => false } @@ -46,8 +44,7 @@ object FunctionUtils { def isQMark(e: Expr) = e match { case FunctionInvocation(TypedFunDef(fd, Seq()), args) => - (fd.id.name == "?" && fd.returnType == IntegerType && - args.size <= 1) + fd.id.name == "?" && fd.returnType == IntegerType && args.size <= 1 case _ => false } @@ -104,7 +101,7 @@ object FunctionUtils { val Lambda(_, postBody) = fd.postcondition.get // collect all terms with question marks and convert them to a template val postWoQmarks = postBody match { - case And(args) if args.exists(exists(isQMark) _) => + case And(args) if args.exists(exists(isQMark)) => val (tempExprs, otherPreds) = args.partition { case a if exists(isQMark)(a) => true case _ => false diff --git a/src/main/scala/leon/invariant/structure/LinearConstraintUtil.scala b/src/main/scala/leon/invariant/structure/LinearConstraintUtil.scala index 7910df7230e82f7ac8b695bc8089fcd64c005e3a..5ed3316a44bb6d0baaec32dd1962fbe7f78af0f5 100644 --- a/src/main/scala/leon/invariant/structure/LinearConstraintUtil.scala +++ b/src/main/scala/leon/invariant/structure/LinearConstraintUtil.scala @@ -3,23 +3,15 @@ package invariant.structure import purescala._ import purescala.Common._ -import purescala.Definitions._ import purescala.Expressions._ import purescala.ExprOps._ import purescala.Extractors._ -import purescala.Types._ -import scala.collection.mutable.{ Set => MutableSet } import scala.collection.mutable.{ Map => MutableMap } -import java.io._ import invariant.util._ import BigInt._ -import Constructors._ -import Util._ import PredicateUtil._ -class NotImplementedException(message: String) extends RuntimeException(message) { - -} +class NotImplementedException(message: String) extends RuntimeException(message) //a collections of utility methods that manipulate the templates object LinearConstraintUtil { @@ -31,14 +23,14 @@ object LinearConstraintUtil { //some utility methods def getFIs(ctr: LinearConstraint): Set[FunctionInvocation] = { - val fis = ctr.coeffMap.keys.collect((e) => e match { + val fis = ctr.coeffMap.keys.collect { case fi: FunctionInvocation => fi - }) + } fis.toSet } def evaluate(lt: LinearTemplate): Option[Boolean] = lt match { - case lc: LinearConstraint if (lc.coeffMap.size == 0) => + case lc: LinearConstraint if lc.coeffMap.isEmpty => ExpressionTransformer.simplify(lt.toExpr) match { case BooleanLiteral(v) => Some(v) case _ => None @@ -74,7 +66,7 @@ object LinearConstraintUtil { } } else coeffMap += (term -> simplifyArithmetic(coeff)) - if (!variablesOf(coeff).isEmpty) { + if (variablesOf(coeff).nonEmpty) { isTemplate = true } } @@ -86,7 +78,7 @@ object LinearConstraintUtil { } else constant = Some(simplifyArithmetic(coeff)) - if (!variablesOf(coeff).isEmpty) { + if (variablesOf(coeff).nonEmpty) { isTemplate = true } } @@ -341,7 +333,6 @@ object LinearConstraintUtil { //now consider each constraint look for (a) equality involving the elimVar or (b) check if all bounds are lower //or (c) if all bounds are upper. var elimExpr : Option[Expr] = None - var bestExpr = false var elimCtr : Option[LinearConstraint] = None var allUpperBounds : Boolean = true var allLowerBounds : Boolean = true @@ -354,7 +345,7 @@ object LinearConstraintUtil { foundEquality = true //here, sometimes we replace an existing expression with a better one if available - if (!elimExpr.isDefined || shouldReplace(elimExpr.get, lc, elimVar)) { + if (elimExpr.isEmpty || shouldReplace(elimExpr.get, lc, elimVar)) { //if the coeffcient of elimVar is +ve the the sign of the coeff of every other term should be changed val InfiniteIntegerLiteral(elimCoeff) = lc.coeffMap(elimVar.toVariable) //make sure the value of the coefficient is 1 or -1 diff --git a/src/main/scala/leon/invariant/templateSolvers/CegisSolver.scala b/src/main/scala/leon/invariant/templateSolvers/CegisSolver.scala index 98f162fea4b2974702c4672274ed6d8f86075690..15ae673019e9c50eb3ff257fc88114c935230342 100644 --- a/src/main/scala/leon/invariant/templateSolvers/CegisSolver.scala +++ b/src/main/scala/leon/invariant/templateSolvers/CegisSolver.scala @@ -1,23 +1,17 @@ package leon package invariant.templateSolvers -import z3.scala._ + import purescala.Common._ import purescala.Definitions._ import purescala.Expressions._ import purescala.ExprOps._ -import purescala.Extractors._ -import purescala.Types._ -import java.io._ -import scala.util.control.Breaks._ import solvers._ -import solvers.z3._ import invariant.engine._ import invariant.factories._ import invariant.util._ import invariant.structure._ import invariant.structure.FunctionUtils._ import leon.invariant.util.RealValuedExprEvaluator._ -import Util._ import PredicateUtil._ import SolverUtil._ @@ -95,7 +89,7 @@ class CegisCore(ctx: InferenceContext, val tempVarSum = if (minimizeSum) { //compute the sum of the tempIds val rootTempIds = getTemplateVars(cegisSolver.rootFun.getTemplate) - if (rootTempIds.size >= 1) { + if (rootTempIds.nonEmpty) { rootTempIds.tail.foldLeft(rootTempIds.head.asInstanceOf[Expr])((acc, tvar) => Plus(acc, tvar)) } else zero } else zero @@ -126,7 +120,7 @@ class CegisCore(ctx: InferenceContext, //sanity checks val spuriousTempIds = variablesOf(instFormula).intersect(TemplateIdFactory.getTemplateIds) - if (!spuriousTempIds.isEmpty) + if (spuriousTempIds.nonEmpty) throw new IllegalStateException("Found a template variable in instFormula: " + spuriousTempIds) if (hasReals(instFormula)) throw new IllegalStateException("Reals in instFormula: " + instFormula) @@ -144,21 +138,21 @@ class CegisCore(ctx: InferenceContext, //simplify the tempctrs, evaluate every atom that does not involve a template variable //this should get rid of all functions val satctrs = - simplePreTransform((e) => e match { + simplePreTransform { //is 'e' free of template variables ? - case _ if (variablesOf(e).filter(TemplateIdFactory.IsTemplateIdentifier _).isEmpty) => { + case e if !variablesOf(e).exists(TemplateIdFactory.IsTemplateIdentifier _) => { //evaluate the term val value = solver1.evalExpr(e) if (value.isDefined) value.get else throw new IllegalStateException("Cannot evaluate expression: " + e) } - case _ => e - })(Not(formula)) + case e => e + }(Not(formula)) solver1.free() //sanity checks val spuriousProgIds = variablesOf(satctrs).filterNot(TemplateIdFactory.IsTemplateIdentifier _) - if (!spuriousProgIds.isEmpty) + if (spuriousProgIds.nonEmpty) throw new IllegalStateException("Found a progam variable in tempctrs: " + spuriousProgIds) val tempctrs = if (!solveAsInt) ExpressionTransformer.IntLiteralToReal(satctrs) else satctrs @@ -201,7 +195,7 @@ class CegisCore(ctx: InferenceContext, println("2: " + (if (res1.isDefined) "solved" else "timed out") + "... in " + (t4 - t3) / 1000.0 + "s") if (res1.isDefined) { - if (res1.get == false) { + if (!res1.get) { //there exists no solution for templates (Some(false), newctr, Model.empty) diff --git a/src/main/scala/leon/invariant/templateSolvers/ExtendedUFSolver.scala b/src/main/scala/leon/invariant/templateSolvers/ExtendedUFSolver.scala index 805f09a941139d817309a3a10df0b07cd57a21aa..9a9c1b0b4290b07408adb33c0e174813de831a29 100644 --- a/src/main/scala/leon/invariant/templateSolvers/ExtendedUFSolver.scala +++ b/src/main/scala/leon/invariant/templateSolvers/ExtendedUFSolver.scala @@ -4,16 +4,9 @@ package leon package invariant.templateSolvers import z3.scala._ -import leon.solvers._ -import purescala.Common._ import purescala.Definitions._ import purescala.Expressions._ -import purescala.Extractors._ -import purescala.ExprOps._ -import purescala.Types._ -import leon.LeonContext import leon.solvers.z3.UninterpretedZ3Solver -import leon.solvers.smtlib.SMTLIBZ3Solver /** * A uninterpreted solver extended with additional functionalities. diff --git a/src/main/scala/leon/invariant/templateSolvers/FarkasLemmaSolver.scala b/src/main/scala/leon/invariant/templateSolvers/FarkasLemmaSolver.scala index 19dee9ab4be261cd2378d2146864fb9566912ca3..a4e20f61a7424ee59724b8cf4374cecbcaf84f9f 100644 --- a/src/main/scala/leon/invariant/templateSolvers/FarkasLemmaSolver.scala +++ b/src/main/scala/leon/invariant/templateSolvers/FarkasLemmaSolver.scala @@ -1,26 +1,19 @@ package leon package invariant.templateSolvers -import z3.scala._ -import purescala._ import purescala.Common._ import purescala.Definitions._ import purescala.Expressions._ import purescala.ExprOps._ -import purescala.Extractors._ import purescala.Types._ -import java.io._ import solvers.SimpleSolverAPI -import scala.collection.mutable.{ Map => MutableMap } import invariant.engine._ -import invariant.factories._ import invariant.util._ import Util._ -import ProgramUtil._ import SolverUtil._ import PredicateUtil._ -import TimerUtil._ import invariant.structure._ +import invariant.datastructure._ import leon.solvers.TimeoutSolver import leon.solvers.SolverFactory import leon.solvers.TimeoutSolverFactory @@ -83,8 +76,8 @@ class FarkasLemmaSolver(ctx: InferenceContext, program: Program) { println("#" * 20) //Optimization 1: Check if ants are unsat (already handled) - val pathVC = createAnd(antsSimple.map(_.toExpr).toSeq ++ conseqsSimple.map(_.toExpr).toSeq) - val notPathVC = And(createAnd(antsSimple.map(_.toExpr).toSeq), Not(createAnd(conseqsSimple.map(_.toExpr).toSeq))) + val pathVC = createAnd(antsSimple.map(_.toExpr) ++ conseqsSimple.map(_.toExpr)) + val notPathVC = And(createAnd(antsSimple.map(_.toExpr)), Not(createAnd(conseqsSimple.map(_.toExpr)))) val (satVC, _) = uisolver.solveSAT(pathVC) val (satNVC, _) = uisolver.solveSAT(notPathVC) @@ -135,7 +128,7 @@ class FarkasLemmaSolver(ctx: InferenceContext, program: Program) { strictCtrLambdas :+= l GreaterEquals(l, zero) } - }).toSeq :+ GreaterEquals(lambda0, zero)) + }) :+ GreaterEquals(lambda0, zero)) //add the constraints on constant terms val sumConst = ants.foldLeft(UMinus(lambda0): Expr)((acc, ant) => ant.constTemplate match { @@ -206,7 +199,7 @@ class FarkasLemmaSolver(ctx: InferenceContext, program: Program) { // factor out common nonlinear terms and create an equiv-satisfiable constraint def reduceCommonNLTerms(ctrs: Expr) = { - var nlUsage = new CounterMap[Expr]() + val nlUsage = new CounterMap[Expr]() postTraversal{ case t: Times => nlUsage.inc(t) case e => ; @@ -223,7 +216,7 @@ class FarkasLemmaSolver(ctx: InferenceContext, program: Program) { // try eliminate nonlinearity to whatever extent possible var elimMap = Map[Identifier, (Identifier, Identifier)]() // maps the fresh identifiers to the product of the identifiers they represent. def reduceNonlinearity(farkasctrs: Expr): Expr = { - var varCounts = new CounterMap[Identifier]() + val varCounts = new CounterMap[Identifier]() // collect # of uses of each variable postTraversal { case Variable(id) => varCounts.inc(id) diff --git a/src/main/scala/leon/invariant/templateSolvers/NLTemplateSolver.scala b/src/main/scala/leon/invariant/templateSolvers/NLTemplateSolver.scala index 41d672fe5a94861197430a69c9516985e46f2511..90cffa3f181d11deb499a06ec538a515ed164ac1 100644 --- a/src/main/scala/leon/invariant/templateSolvers/NLTemplateSolver.scala +++ b/src/main/scala/leon/invariant/templateSolvers/NLTemplateSolver.scala @@ -9,7 +9,6 @@ import purescala.ExprOps._ import purescala.Extractors._ import purescala.Types._ import evaluators._ -import scala.collection.mutable.{ Map => MutableMap } import java.io._ import solvers._ import solvers.combinators._ @@ -369,7 +368,7 @@ class NLTemplateSolver(ctx: InferenceContext, program: Program, val cegisSolver = new CegisCore(ctx, program, timeout.toInt, this) val (res, ctr, model) = cegisSolver.solve(tempIds, expr, precond, solveAsInt = false, initModel) - if (!res.isDefined) + if (res.isEmpty) reporter.info("cegis timed-out on the disjunct...") (res, ctr, model) } @@ -508,7 +507,7 @@ class NLTemplateSolver(ctx: InferenceContext, program: Program, val InfiniteIntegerLiteral(v) = model(id) v } - def eval: (Expr => Boolean) = e => e match { + def eval: (Expr => Boolean) = { case And(args) => args.forall(eval) // case Iff(Variable(id1), Variable(id2)) => model(id1) == model(id2) case Equals(Variable(id1), Variable(id2)) => model(id1) == model(id2) //note: ADTs can also be compared for equality @@ -516,7 +515,7 @@ class NLTemplateSolver(ctx: InferenceContext, program: Program, case GreaterEquals(Variable(id1), Variable(id2)) => modelVal(id1) >= modelVal(id2) case GreaterThan(Variable(id1), Variable(id2)) => modelVal(id1) > modelVal(id2) case LessThan(Variable(id1), Variable(id2)) => modelVal(id1) < modelVal(id2) - case _ => throw new IllegalStateException("Predicate not handled: " + e) + case e => throw new IllegalStateException("Predicate not handled: " + e) } eval } @@ -526,14 +525,14 @@ class NLTemplateSolver(ctx: InferenceContext, program: Program, //println("Identifier: "+id) model(id).asInstanceOf[FractionalLiteral] } - (e: Expr) => e match { + { case Equals(Variable(id1), Variable(id2)) => model(id1) == model(id2) //note: ADTs can also be compared for equality - case Operator(Seq(Variable(id1), Variable(id2)), op) if (e.isInstanceOf[LessThan] + case e@Operator(Seq(Variable(id1), Variable(id2)), op) if (e.isInstanceOf[LessThan] || e.isInstanceOf[LessEquals] || e.isInstanceOf[GreaterThan] || e.isInstanceOf[GreaterEquals]) => { evaluateRealPredicate(op(Seq(modelVal(id1), modelVal(id2)))) } - case _ => throw new IllegalStateException("Predicate not handled: " + e) + case e => throw new IllegalStateException("Predicate not handled: " + e) } } @@ -587,11 +586,11 @@ class NLTemplateSolver(ctx: InferenceContext, program: Program, var calls = Set[Call]() var cons = Set[Expr]() - satCtrs.foreach(ctr => ctr match { + satCtrs.foreach { case t: Call => calls += t case t: ADTConstraint if (t.cons.isDefined) => cons += t.cons.get case _ => ; - }) + } val callExprs = calls.map(_.toExpr) var t1 = System.currentTimeMillis() @@ -617,11 +616,11 @@ class NLTemplateSolver(ctx: InferenceContext, program: Program, //exclude guards, separate calls and cons from the rest var lnctrs = Set[LinearConstraint]() var temps = Set[LinearTemplate]() - (satCtrs ++ callCtrs ++ axiomCtrs ++ theoryCtrs).foreach(ctr => ctr match { + (satCtrs ++ callCtrs ++ axiomCtrs ++ theoryCtrs).foreach { case t: LinearConstraint => lnctrs += t case t: LinearTemplate => temps += t case _ => ; - }) + } if (this.debugChooseDisjunct) { lnctrs.map(_.toExpr).foreach((ctr) => { @@ -693,7 +692,7 @@ class NLTemplateSolver(ctx: InferenceContext, program: Program, var elimRems = Set[Identifier]() elimLnctrs.foreach((lc) => { val evars = variablesOf(lc.toExpr).intersect(elimVars) - if (!evars.isEmpty) { + if (evars.nonEmpty) { elimCtrs :+= lc elimCtrCount += 1 elimRems ++= evars diff --git a/src/main/scala/leon/invariant/templateSolvers/NLTemplateSolverWithMult.scala b/src/main/scala/leon/invariant/templateSolvers/NLTemplateSolverWithMult.scala index a2a6524c6d6f52ffa232129b4adb2ff0f794ee42..d67955ff9efd38dd9c5c70b0684b04b5cda1661b 100644 --- a/src/main/scala/leon/invariant/templateSolvers/NLTemplateSolverWithMult.scala +++ b/src/main/scala/leon/invariant/templateSolvers/NLTemplateSolverWithMult.scala @@ -1,15 +1,9 @@ package leon package invariant.templateSolvers -import z3.scala._ -import purescala.Common._ + import purescala.Definitions._ import purescala.Expressions._ -import purescala.ExprOps._ import purescala.Extractors._ -import purescala.Types._ -import java.io._ -import leon.invariant._ -import scala.util.control.Breaks._ import solvers._ import invariant.engine._ @@ -18,7 +12,6 @@ import invariant.util._ import invariant.structure._ import Util._ import PredicateUtil._ -import SolverUtil._ class NLTemplateSolverWithMult(ctx : InferenceContext, program: Program, rootFun: FunDef, ctrTracker: ConstraintTracker, minimizer: Option[(Expr, Model) => Model]) @@ -42,10 +35,10 @@ class NLTemplateSolverWithMult(ctx : InferenceContext, program: Program, rootFun //in the sequel we instantiate axioms for multiplication val inst1 = unaryMultAxioms(formula, calls, predEval(model)) val inst2 = binaryMultAxioms(formula,calls, predEval(model)) - val multCtrs = (inst1 ++ inst2).flatMap(_ match { + val multCtrs = (inst1 ++ inst2).flatMap { case And(args) => args.map(ConstraintUtil.createConstriant _) case e => Seq(ConstraintUtil.createConstriant(e)) - }) + } Stats.updateCounterStats(multCtrs.size, "MultAxiomBlowup", "disjuncts") ctx.reporter.info("Number of multiplication induced predicates: "+multCtrs.size) diff --git a/src/main/scala/leon/invariant/templateSolvers/TemplateSolver.scala b/src/main/scala/leon/invariant/templateSolvers/TemplateSolver.scala index dceed68279e089122875d19c86664425df49aa22..a85a7d4f5b9f55af812da50d577ec36d350b7ce5 100644 --- a/src/main/scala/leon/invariant/templateSolvers/TemplateSolver.scala +++ b/src/main/scala/leon/invariant/templateSolvers/TemplateSolver.scala @@ -1,18 +1,11 @@ package leon package invariant.templateSolvers -import z3.scala._ import purescala.Common._ import purescala.Definitions._ import purescala.Expressions._ import purescala.ExprOps._ -import purescala.Extractors._ -import purescala.Types._ import java.io._ -import leon.invariant._ -import scala.util.control.Breaks._ -import scala.concurrent._ -import scala.concurrent.duration._ import invariant.engine._ import invariant.factories._ import invariant.util._ diff --git a/src/main/scala/leon/invariant/templateSolvers/UFADTEliminator.scala b/src/main/scala/leon/invariant/templateSolvers/UFADTEliminator.scala index f41ec7e8963a45978fb4a9ac3fe091de48140d10..6b2aa300b2792f7f86b5f01565ee783731f8d863 100644 --- a/src/main/scala/leon/invariant/templateSolvers/UFADTEliminator.scala +++ b/src/main/scala/leon/invariant/templateSolvers/UFADTEliminator.scala @@ -1,20 +1,14 @@ package leon package invariant.templateSolvers -import z3.scala._ -import purescala.Common._ + import purescala.Definitions._ import purescala.Expressions._ -import purescala.ExprOps._ import purescala.Extractors._ import purescala.Types._ -import java.io._ import invariant.datastructure.UndirectedGraph -import scala.util.control.Breaks._ import invariant.util._ import leon.purescala.TypeOps import PredicateUtil._ -import SolverUtil._ -import Util._ class UFADTEliminator(ctx: LeonContext, program: Program) { @@ -113,7 +107,7 @@ class UFADTEliminator(ctx: LeonContext, program: Program) { lhs :+ rhs } //remove self equalities. - val preds = eqs.filter(_ match { + val preds = eqs.filter { case Operator(Seq(Variable(lid), Variable(rid)), _) => { if (lid == rid) false else { @@ -121,8 +115,8 @@ class UFADTEliminator(ctx: LeonContext, program: Program) { else false } } - case e @ _ => throw new IllegalStateException("Not an equality or Iff: " + e) - }) + case e@_ => throw new IllegalStateException("Not an equality or Iff: " + e) + } preds } @@ -134,21 +128,21 @@ class UFADTEliminator(ctx: LeonContext, program: Program) { axiomatizeADTCons(call1, call2) } - if (makeEfficient && ants.exists(_ match { + if (makeEfficient && ants.exists { case Equals(l, r) if (l.getType != RealType && l.getType != BooleanType && l.getType != IntegerType) => true case _ => false - })) { + }) { Seq() } else { var unsatIntEq: Option[Expr] = None var unsatOtherEq: Option[Expr] = None ants.foreach(eq => - if (!unsatOtherEq.isDefined) { + if (unsatOtherEq.isEmpty) { eq match { case Equals(lhs @ Variable(_), rhs @ Variable(_)) if !predEval(Equals(lhs, rhs)) => { if (lhs.getType != Int32Type && lhs.getType != RealType && lhs.getType != IntegerType) unsatOtherEq = Some(eq) - else if (!unsatIntEq.isDefined) + else if (unsatIntEq.isEmpty) unsatIntEq = Some(eq) } case _ => ; diff --git a/src/main/scala/leon/invariant/util/CallGraph.scala b/src/main/scala/leon/invariant/util/CallGraph.scala index 23f87d5bc0b62163ef40e852d7cc0af0cd658bc3..5e93c451215e9e1f6679e7e94da47a4a2d707a3c 100644 --- a/src/main/scala/leon/invariant/util/CallGraph.scala +++ b/src/main/scala/leon/invariant/util/CallGraph.scala @@ -1,15 +1,10 @@ package leon package invariant.util -import purescala._ -import purescala.Common._ import purescala.Definitions._ import purescala.Expressions._ import purescala.ExprOps._ -import purescala.Extractors._ -import purescala.Types._ import ProgramUtil._ -import Util._ import invariant.structure.FunctionUtils._ import invariant.datastructure._ @@ -71,7 +66,7 @@ class CallGraph { graph.getNodes.toList.foreach((f) => { var inserted = false var index = 0 - for (i <- 0 to funcList.length - 1) { + for (i <- funcList.indices) { if (!inserted && this.transitivelyCalls(funcList(i), f)) { index = i inserted = true @@ -97,7 +92,6 @@ class CallGraph { object CallGraphUtil { def constructCallGraph(prog: Program, onlyBody: Boolean = false, withTemplates: Boolean = false): CallGraph = { -// // println("Constructing call graph") val cg = new CallGraph() functionsWOFields(prog.definedFunctions).foreach((fd) => { @@ -125,17 +119,11 @@ object CallGraphUtil { cg } - def getCallees(expr: Expr): Set[FunDef] = { - var callees = Set[FunDef]() - simplePostTransform((expr) => expr match { - //note: do not consider field invocations - case FunctionInvocation(TypedFunDef(callee, _), args) - if callee.isRealFunction => { - callees += callee - expr - } - case _ => expr - })(expr) - callees - } + def getCallees(expr: Expr): Set[FunDef] = collect { + case expr@FunctionInvocation(TypedFunDef(callee, _), _) if callee.isRealFunction => + Set(callee) + case _ => + Set[FunDef]() + }(expr) + } \ No newline at end of file diff --git a/src/main/scala/leon/invariant/util/ExpressionTransformer.scala b/src/main/scala/leon/invariant/util/ExpressionTransformer.scala index 6750bd59c144fa656b2d820516d6691e3e169862..9bf2889ed103f0e14083a031e5b80c700e5c89d0 100644 --- a/src/main/scala/leon/invariant/util/ExpressionTransformer.scala +++ b/src/main/scala/leon/invariant/util/ExpressionTransformer.scala @@ -8,10 +8,7 @@ import purescala.ExprOps._ import purescala.Extractors._ import purescala.Types._ import java.io._ -import java.io._ import purescala.ScalaPrinter -import invariant.structure.Call -import invariant.structure.FunctionUtils._ import leon.invariant.factories.TemplateIdFactory import PredicateUtil._ import Util._ @@ -185,7 +182,7 @@ object ExpressionTransformer { } } val (nexp, ncjs) = transform(inexpr, false) - val res = if (!ncjs.isEmpty) { + val res = if (ncjs.nonEmpty) { createAnd(nexp +: ncjs.toSeq) } else nexp res @@ -283,31 +280,30 @@ object ExpressionTransformer { def flattenArgs(args: Seq[Expr], insideFunction: Boolean): (Seq[Expr], Set[Expr]) = { var newConjuncts = Set[Expr]() - val newargs = args.map((arg) => - arg match { - case v: Variable => v - case r: ResultVariable => r - case _ => { - val (nexpr, ncjs) = flattenFunc(arg, insideFunction) - - newConjuncts ++= ncjs - - nexpr match { - case v: Variable => v - case r: ResultVariable => r - case _ => { - val freshArgVar = Variable(TVarFactory.createTemp("arg", arg.getType)) - newConjuncts += Equals(freshArgVar, nexpr) - freshArgVar - } + val newargs = args.map { + case v: Variable => v + case r: ResultVariable => r + case arg => { + val (nexpr, ncjs) = flattenFunc(arg, insideFunction) + + newConjuncts ++= ncjs + + nexpr match { + case v: Variable => v + case r: ResultVariable => r + case _ => { + val freshArgVar = Variable(TVarFactory.createTemp("arg", arg.getType)) + newConjuncts += Equals(freshArgVar, nexpr) + freshArgVar } } - }) + } + } (newargs, newConjuncts) } val (nexp, ncjs) = flattenFunc(inExpr, false) - if (!ncjs.isEmpty) { + if (ncjs.nonEmpty) { createAnd(nexp +: ncjs.toSeq) } else nexp } @@ -387,7 +383,7 @@ object ExpressionTransformer { */ def pullAndOrs(expr: Expr): Expr = { - simplePostTransform((e: Expr) => e match { + simplePostTransform { case Or(args) => { val newArgs = args.foldLeft(Seq[Expr]())((acc, arg) => arg match { case Or(inArgs) => acc ++ inArgs @@ -402,8 +398,8 @@ object ExpressionTransformer { }) createAnd(newArgs) } - case _ => e - })(expr) + case e => e + }(expr) } def classSelToCons(e: Expr): Expr = { @@ -466,15 +462,15 @@ object ExpressionTransformer { */ def unFlatten(ine: Expr, freevars: Set[Identifier]): Expr = { var tempMap = Map[Expr, Expr]() - val newinst = simplePostTransform((e: Expr) => e match { - case Equals(v @ Variable(id), rhs @ _) if !freevars.contains(id) => + val newinst = simplePostTransform { + case e@Equals(v@Variable(id), rhs@_) if !freevars.contains(id) => if (tempMap.contains(v)) e else { tempMap += (v -> rhs) tru } - case _ => e - })(ine) + case e => e + }(ine) val closure = (e: Expr) => replace(tempMap, e) fix(closure)(newinst) } @@ -510,11 +506,11 @@ object ExpressionTransformer { def isSubExpr(key: Expr, expr: Expr): Boolean = { var found = false - simplePostTransform((e: Expr) => e match { - case _ if (e == key) => + simplePostTransform { + case e if (e == key) => found = true; e - case _ => e - })(expr) + case e => e + }(expr) found } @@ -524,7 +520,7 @@ object ExpressionTransformer { def simplify(expr: Expr): Expr = { //Note: some simplification are already performed by the class constructors (see Tree.scala) - simplePostTransform((e: Expr) => e match { + simplePostTransform { case Equals(lhs, rhs) if (lhs == rhs) => tru case LessEquals(lhs, rhs) if (lhs == rhs) => tru case GreaterEquals(lhs, rhs) if (lhs == rhs) => tru @@ -536,8 +532,8 @@ object ExpressionTransformer { case LessThan(InfiniteIntegerLiteral(v1), InfiniteIntegerLiteral(v2)) => BooleanLiteral(v1 < v2) case GreaterEquals(InfiniteIntegerLiteral(v1), InfiniteIntegerLiteral(v2)) => BooleanLiteral(v1 >= v2) case GreaterThan(InfiniteIntegerLiteral(v1), InfiniteIntegerLiteral(v2)) => BooleanLiteral(v1 > v2) - case _ => e - })(expr) + case e => e + }(expr) } /** @@ -545,7 +541,7 @@ object ExpressionTransformer { * Note: (a) Not(Equals()) and Not(Variable) is allowed */ def isDisjunct(e: Expr): Boolean = e match { - case And(args) => args.foldLeft(true)((acc, arg) => acc && isDisjunct(arg)) + case And(args) => args.forall(arg => isDisjunct(arg)) case Not(Equals(_, _)) | Not(Variable(_)) => true case Or(_) | Implies(_, _) | Not(_) | Equals(_, _) => false case _ => true @@ -556,7 +552,7 @@ object ExpressionTransformer { * Note: (a) Not(Equals()) and Not(Variable) is allowed */ def isConjunct(e: Expr): Boolean = e match { - case Or(args) => args.foldLeft(true)((acc, arg) => acc && isConjunct(arg)) + case Or(args) => args.forall(arg => isConjunct(arg)) case Not(Equals(_, _)) | Not(Variable(_)) => true case And(_) | Implies(_, _) | Not(_) | Equals(_, _) => false case _ => true @@ -568,17 +564,17 @@ object ExpressionTransformer { case And(args) => { //have we seen an or ? if (seen == 2) false - else args.foldLeft(true)((acc, arg) => acc && uniOP(arg, 1)) + else args.forall(arg => uniOP(arg, 1)) } case Or(args) => { //have we seen an And ? if (seen == 1) false - else args.foldLeft(true)((acc, arg) => acc && uniOP(arg, 2)) + else args.forall(arg => uniOP(arg, 2)) } case t: Terminal => true /*case u @ UnaryOperator(e1, op) => uniOP(e1, seen) case b @ BinaryOperator(e1, e2, op) => uniOP(e1, seen) && uniOP(e2, seen)*/ - case n @ Operator(args, op) => args.foldLeft(true)((acc, arg) => acc && uniOP(arg, seen)) + case n @ Operator(args, op) => args.forall(arg => uniOP(arg, seen)) } def printRec(e: Expr, indent: Int): Unit = { @@ -588,7 +584,7 @@ object ExpressionTransformer { e match { case And(args) => { var start = true - args.map((arg) => { + args.foreach((arg) => { wr.print(" " * (indent + 1)) if (!start) wr.print("^") printRec(arg, indent + 1) @@ -597,7 +593,7 @@ object ExpressionTransformer { } case Or(args) => { var start = true - args.map((arg) => { + args.foreach((arg) => { wr.print(" " * (indent + 1)) if (!start) wr.print("v") printRec(arg, indent + 1) @@ -627,8 +623,8 @@ object ExpressionTransformer { } def distribute(e: Expr): Expr = { - simplePreTransform(_ match { - case e @ FunctionInvocation(TypedFunDef(fd, _), Seq(e1, e2)) if isMultFunctions(fd) => + simplePreTransform { + case e@FunctionInvocation(TypedFunDef(fd, _), Seq(e1, e2)) if isMultFunctions(fd) => val newe = (e1, e2) match { case (Plus(sum1, sum2), _) => // distribute e2 over e1 @@ -655,7 +651,7 @@ object ExpressionTransformer { } newe case other => other - })(e) + }(e) } distribute(e) } diff --git a/src/main/scala/leon/invariant/util/LetTupleSimplifications.scala b/src/main/scala/leon/invariant/util/LetTupleSimplification.scala similarity index 92% rename from src/main/scala/leon/invariant/util/LetTupleSimplifications.scala rename to src/main/scala/leon/invariant/util/LetTupleSimplification.scala index bd99992014bd94ab29c099b7a03eb7e452e154c0..83b7b0cba056c0c7f377caa2c4f76134bab81e24 100644 --- a/src/main/scala/leon/invariant/util/LetTupleSimplifications.scala +++ b/src/main/scala/leon/invariant/util/LetTupleSimplification.scala @@ -7,14 +7,8 @@ import purescala.Expressions._ import purescala.ExprOps._ import purescala.Extractors._ import purescala.Types._ -import java.io._ -import java.io._ -import purescala.ScalaPrinter import leon.utils._ import PredicateUtil._ - -import invariant.structure.Call -import invariant.structure.FunctionUtils._ import leon.transformations.InstUtil._ /** @@ -31,13 +25,13 @@ object LetTupleSimplification { val bone = BigInt(1) def letSanityChecks(ine: Expr) = { - simplePostTransform(_ match { - case letExpr @ Let(binderId, letValue, body) - if (binderId.getType != letValue.getType) => - throw new IllegalStateException("Binder and value type mismatch: "+ - s"(${binderId.getType},${letValue.getType})") + simplePostTransform { + case letExpr@Let(binderId, letValue, body) + if (binderId.getType != letValue.getType) => + throw new IllegalStateException("Binder and value type mismatch: " + + s"(${binderId.getType},${letValue.getType})") case e => e - })(ine) + }(ine) } /** @@ -130,11 +124,11 @@ object LetTupleSimplification { (arg1, arg2) match { case (_: TupleSelect, _) => error = true case (_, _: TupleSelect) => error = true - case _ => { ; } + case _ => } } - case _ => { ; } + case _ => } } @@ -172,8 +166,8 @@ object LetTupleSimplification { // in the sequel, we are using the fact that 'depth' is positive and // 'ine' contains only 'depth' variables - val simpe = simplePostTransform((e: Expr) => e match { - case FunctionInvocation(tfd, args) if (tfd.fd == maxFun) => { + val simpe = simplePostTransform { + case e@FunctionInvocation(tfd, args) if (tfd.fd == maxFun) => { if (debugMaxSimplify) { println("Simplifying: " + e) } @@ -183,20 +177,20 @@ object LetTupleSimplification { import invariant.structure.LinearConstraintUtil._ val lt = exprToTemplate(LessEquals(Minus(arg1, arg2), InfiniteIntegerLiteral(0))) //now, check if all the variables in 'lt' have only positive coefficients - val allPositive = lt.coeffTemplate.forall(entry => entry match { + val allPositive = lt.coeffTemplate.forall { case (k, IntLiteral(v)) if (v >= 0) => true case _ => false - }) && (lt.constTemplate match { + } && (lt.constTemplate match { case None => true case Some(IntLiteral(v)) if (v >= 0) => true case _ => false }) if (allPositive) arg1 else { - val allNegative = lt.coeffTemplate.forall(entry => entry match { + val allNegative = lt.coeffTemplate.forall { case (k, IntLiteral(v)) if (v <= 0) => true case _ => false - }) && (lt.constTemplate match { + } && (lt.constTemplate match { case None => true case Some(IntLiteral(v)) if (v <= 0) => true case _ => false @@ -222,14 +216,14 @@ object LetTupleSimplification { // case FunctionInvocation(tfd, args) if(tfd.fd.id.name == "max") => { // throw new IllegalStateException("Found just max in expression " + e + "\n") // } - case _ => e - })(ine) + case e => e + }(ine) simpe } def inlineMax(ine: Expr): Expr = { //inline 'max' operations here - simplePostTransform((e: Expr) => e match { + simplePostTransform { case FunctionInvocation(tfd, args) if (tfd.fd == maxFun) => val Seq(arg1, arg2) = args val bindWithLet = (value: Expr, body: (Expr with Terminal) => Expr) => { @@ -245,8 +239,8 @@ object LetTupleSimplification { } bindWithLet(arg1, a1 => bindWithLet(arg2, a2 => IfExpr(GreaterEquals(a1, a2), a1, a2))) - case _ => e - })(ine) + case e => e + }(ine) } def removeLetsFromLetValues(ine: Expr): Expr = { @@ -317,10 +311,10 @@ object LetTupleSimplification { Tuple(args :+ e2) })) } - replaceLetBody(transLet, (e: Expr) => e match { + replaceLetBody(transLet, { case Tuple(args) => op(args) - case _ => op(Seq(e)) //here, there was only one argument + case e => op(Seq(e)) //here, there was only one argument }) } transe @@ -378,7 +372,7 @@ object LetTupleSimplification { res } - val transforms = removeLetsFromLetValues _ andThen fixpoint(postMap(simplerLet)) _ andThen simplifyArithmetic + val transforms = removeLetsFromLetValues _ andThen fixpoint(postMap(simplerLet)) andThen simplifyArithmetic transforms(ine) } @@ -398,7 +392,7 @@ object LetTupleSimplification { val allLeaves = getLeaves(e, true) // Here the expression is not of the form we are currently simplifying - if (allLeaves.size == 0) e + if (allLeaves.isEmpty) e else { // fold constants here val allConstantsOpped = allLeaves.foldLeft(identity)((acc, e) => e match { @@ -406,17 +400,17 @@ object LetTupleSimplification { case _ => acc }) - val allNonConstants = allLeaves.filter((e) => e match { + val allNonConstants = allLeaves.filter { case _: InfiniteIntegerLiteral => false case _ => true - }) + } // Reconstruct the expressin tree with the non-constants and the result of constant evaluation above if (allConstantsOpped != identity) { allNonConstants.foldLeft(InfiniteIntegerLiteral(allConstantsOpped): Expr)((acc: Expr, currExpr) => makeTree(acc, currExpr)) } else { - if (allNonConstants.size == 0) InfiniteIntegerLiteral(identity) + if (allNonConstants.isEmpty) InfiniteIntegerLiteral(identity) else { allNonConstants.tail.foldLeft(allNonConstants.head)((acc: Expr, currExpr) => makeTree(acc, currExpr)) } @@ -455,10 +449,11 @@ object LetTupleSimplification { ((a: BigInt, b: BigInt) => if (a > b) a else b), getAllMaximands, 0, - ((e1, e2) => { + (e1, e2) => { val typedMaxFun = TypedFunDef(maxFun, Seq()) FunctionInvocation(typedMaxFun, Seq(e1, e2)) - })) + } + ) maxSimplifiedExpr })(e) diff --git a/src/main/scala/leon/invariant/util/Minimizer.scala b/src/main/scala/leon/invariant/util/Minimizer.scala index e39378f267c0970d85797d19eef468c3a35f1db4..3794f55463490ae048574f3613f720ba092d0ff2 100644 --- a/src/main/scala/leon/invariant/util/Minimizer.scala +++ b/src/main/scala/leon/invariant/util/Minimizer.scala @@ -1,22 +1,15 @@ package leon package invariant.util -import z3.scala._ -import purescala.Common._ + import purescala.Definitions._ import purescala.Expressions._ import purescala.ExprOps._ import purescala.Extractors._ -import purescala.Types._ import solvers._ -import solvers.z3._ import solvers.smtlib.SMTLIBZ3Solver -import leon.invariant._ -import scala.util.control.Breaks._ import invariant.engine.InferenceContext import invariant.factories._ -import leon.invariant.templateSolvers.ExtendedUFSolver import leon.invariant.util.RealValuedExprEvaluator._ -import invariant.util.TimerUtil._ class Minimizer(ctx: InferenceContext, program: Program) { diff --git a/src/main/scala/leon/invariant/util/RealToIntExpr.scala b/src/main/scala/leon/invariant/util/RealToInt.scala similarity index 96% rename from src/main/scala/leon/invariant/util/RealToIntExpr.scala rename to src/main/scala/leon/invariant/util/RealToInt.scala index 54650b2c4111b32e127070cdc23889708a81b6fd..6c1860160f2bebc5bf0cb574b9c9934985246f45 100644 --- a/src/main/scala/leon/invariant/util/RealToIntExpr.scala +++ b/src/main/scala/leon/invariant/util/RealToInt.scala @@ -2,12 +2,9 @@ package leon package invariant.util import purescala.Common._ -import purescala.Definitions._ import purescala.Expressions._ import purescala.ExprOps._ -import purescala.Extractors._ import purescala.Types._ -import leon.invariant._ import invariant.factories._ import solvers._ diff --git a/src/main/scala/leon/invariant/util/RealExprEvaluator.scala b/src/main/scala/leon/invariant/util/RealValuedExprEvaluator.scala similarity index 100% rename from src/main/scala/leon/invariant/util/RealExprEvaluator.scala rename to src/main/scala/leon/invariant/util/RealValuedExprEvaluator.scala diff --git a/src/main/scala/leon/invariant/util/SolverUtil.scala b/src/main/scala/leon/invariant/util/SolverUtil.scala index e67385c435ea94b24ef63ed8fe8bdbf41f602570..466ec1d2def3831b42c1f0b4bc6eaad3351ef716 100644 --- a/src/main/scala/leon/invariant/util/SolverUtil.scala +++ b/src/main/scala/leon/invariant/util/SolverUtil.scala @@ -1,21 +1,13 @@ package leon package invariant.util -import utils._ import purescala.Common._ import purescala.Definitions._ import purescala.Expressions._ import purescala.ExprOps._ -import purescala.Extractors._ import purescala.Types._ -import scala.collection.mutable.{ Set => MutableSet, Map => MutableMap } -import leon.invariant._ import solvers.z3._ import solvers._ -import invariant.engine._ -import invariant.factories._ -import invariant.structure._ -import FunctionUtils._ import leon.invariant.templateSolvers.ExtendedUFSolver import java.io._ import Util._ @@ -55,7 +47,7 @@ object SolverUtil { case id @ _ if (id.name.toString == "c?") => id.toVariable -> InfiniteIntegerLiteral(2) }.toMap //println("found ids: " + idmap.keys) - if (!idmap.keys.isEmpty) { + if (idmap.keys.nonEmpty) { val newpathcond = replace(idmap, expr) //check if this is solvable val solver = SimpleSolverAPI(SolverFactory(() => new ExtendedUFSolver(ctx, prog))) @@ -76,8 +68,8 @@ object SolverUtil { var controlVars = Map[Variable, Expr]() var newEqs = Map[Expr, Expr]() val solver = new ExtendedUFSolver(ctx, prog) - val newe = simplePostTransform((e: Expr) => e match { - case And(_) | Or(_) => { + val newe = simplePostTransform { + case e@(And(_) | Or(_)) => { val v = TVarFactory.createTemp("a", BooleanType).toVariable newEqs += (v -> e) val newe = Equals(v, e) @@ -88,8 +80,8 @@ object SolverUtil { solver.assertCnstr(Or(newe, cvar)) v } - case _ => e - })(ine) + case e => e + }(ine) //create new variable and add it in disjunction val cvar = FreshIdentifier("ctrl", BooleanType, true).toVariable controlVars += (cvar -> newe) diff --git a/src/main/scala/leon/invariant/util/Stats.scala b/src/main/scala/leon/invariant/util/Stats.scala index 76a1f12ee4eede0ba631969ebb7bf3076bb3bf08..72239ac32227122c5f65c1ea73ac231d04014a11 100644 --- a/src/main/scala/leon/invariant/util/Stats.scala +++ b/src/main/scala/leon/invariant/util/Stats.scala @@ -1,19 +1,11 @@ package leon package invariant.util -import purescala.Common._ import purescala.Definitions._ import purescala.Expressions._ -import purescala.ExprOps._ -import purescala.Extractors._ -import purescala.Types._ -import scala.collection.mutable.{ Map => MutableMap } -import java.io._ -import leon.invariant._ import java.io._ import scala.collection.mutable.{Map => MutableMap} - /** * A generic statistics object that provides: * (a) Temporal variables that change over time. We track the total sum and max of the values the variable takes over time diff --git a/src/main/scala/leon/invariant/util/TemporaryVarFactory.scala b/src/main/scala/leon/invariant/util/TVarFactory.scala similarity index 81% rename from src/main/scala/leon/invariant/util/TemporaryVarFactory.scala rename to src/main/scala/leon/invariant/util/TVarFactory.scala index c85c7837b1c4dbf35e2a90b87d93bccca929a6c9..6743e802b27ef44b1b966555d6965130373b5f8b 100644 --- a/src/main/scala/leon/invariant/util/TemporaryVarFactory.scala +++ b/src/main/scala/leon/invariant/util/TVarFactory.scala @@ -2,11 +2,8 @@ package leon package invariant.util import purescala.Common._ -import purescala.Definitions._ -import purescala.Expressions._ -import purescala.ExprOps._ import purescala.Types._ -import scala.collection.mutable.{ Set => MutableSet, Map => MutableMap } +import scala.collection.mutable.{ Set => MutableSet} object TVarFactory { diff --git a/src/main/scala/leon/invariant/util/TimerUtil.scala b/src/main/scala/leon/invariant/util/TimerUtil.scala index 1284951709d4c374c56f1115b8b946f1a796855d..77e9b520116234e1d4bb6fee6ddf7563deed1640 100644 --- a/src/main/scala/leon/invariant/util/TimerUtil.scala +++ b/src/main/scala/leon/invariant/util/TimerUtil.scala @@ -2,9 +2,6 @@ package leon package invariant.util import utils._ -import solvers._ -import invariant.engine._ -import purescala.Expressions._ object TimerUtil { /** @@ -12,13 +9,13 @@ object TimerUtil { */ def scheduleTask(callBack: () => Unit, timeOut: Long): Option[java.util.Timer] = { if (timeOut > 0) { - val timer = new java.util.Timer(); + val timer = new java.util.Timer() timer.schedule(new java.util.TimerTask() { def run() { callBack() timer.cancel() //the timer will be cancelled after it runs } - }, timeOut); + }, timeOut) Some(timer) } else None } @@ -30,7 +27,6 @@ class InterruptOnSignal(it: Interruptible) { private var keepRunning = true override def run(): Unit = { - val startTime: Long = System.currentTimeMillis while (!signal && keepRunning) { Thread.sleep(100) // a relatively infrequent poll } diff --git a/src/main/scala/leon/invariant/util/TreeUtil.scala b/src/main/scala/leon/invariant/util/TreeUtil.scala index 6e3e45f22a49bc3ef70d0d18e3d0371fedac2654..4109e59877c192570e95bb83b819ebd0a59edef7 100644 --- a/src/main/scala/leon/invariant/util/TreeUtil.scala +++ b/src/main/scala/leon/invariant/util/TreeUtil.scala @@ -33,8 +33,8 @@ object ProgramUtil { */ def copyProgram(prog: Program, mapdefs: (Seq[Definition] => Seq[Definition])): Program = { prog.copy(units = prog.units.collect { - case unit if (!unit.defs.isEmpty) => unit.copy(defs = unit.defs.collect { - case module : ModuleDef if (!module.defs.isEmpty) => + case unit if unit.defs.nonEmpty => unit.copy(defs = unit.defs.collect { + case module : ModuleDef if module.defs.nonEmpty => module.copy(defs = mapdefs(module.defs)) case other => other }) @@ -43,8 +43,8 @@ object ProgramUtil { def createTemplateFun(plainTemp: Expr): FunctionInvocation = { val tmpl = Lambda(getTemplateIds(plainTemp).toSeq.map(id => ValDef(id)), plainTemp) - val tmplFd = new FunDef(FreshIdentifier("tmpl", FunctionType(Seq(tmpl.getType), BooleanType), false), Seq(), Seq(ValDef(FreshIdentifier("arg", tmpl.getType), - Some(tmpl.getType))), BooleanType) + val tmplFd = new FunDef(FreshIdentifier("tmpl", FunctionType(Seq(tmpl.getType), BooleanType), false), Seq(), + Seq(ValDef(FreshIdentifier("arg", tmpl.getType))), BooleanType) tmplFd.body = Some(BooleanLiteral(true)) FunctionInvocation(TypedFunDef(tmplFd, Seq()), Seq(tmpl)) } @@ -71,11 +71,11 @@ object ProgramUtil { } def mapFunctionsInExpr(funmap: Map[FunDef, FunDef])(ine: Expr): Expr = { - simplePostTransform((e: Expr) => e match { + simplePostTransform { case FunctionInvocation(tfd, args) if funmap.contains(tfd.fd) => FunctionInvocation(TypedFunDef(funmap(tfd.fd), tfd.tps), args) - case _ => e - })(ine) + case e => e + }(ine) } /** @@ -161,7 +161,7 @@ object ProgramUtil { } def translateExprToProgram(ine: Expr, currProg: Program, newProg: Program): Expr = { - simplePostTransform((e: Expr) => e match { + simplePostTransform { case FunctionInvocation(TypedFunDef(fd, tps), args) => functionByName(fullName(fd)(currProg), newProg) match { case Some(nfd) => @@ -169,8 +169,8 @@ object ProgramUtil { case _ => throw new IllegalStateException(s"Cannot find translation for ${fd.id.name}") } - case _ => e - })(ine) + case e => e + }(ine) } def getFunctionReturnVariable(fd: FunDef) = { @@ -205,18 +205,18 @@ object PredicateUtil { */ def isTemplateExpr(expr: Expr): Boolean = { var foundVar = false - simplePostTransform((e: Expr) => e match { - case Variable(id) => { + simplePostTransform { + case e@Variable(id) => { if (!TemplateIdFactory.IsTemplateIdentifier(id)) foundVar = true e } - case ResultVariable(_) => { + case e@ResultVariable(_) => { foundVar = true e } - case _ => e - })(expr) + case e => e + }(expr) !foundVar } @@ -234,13 +234,13 @@ object PredicateUtil { */ def hasReals(expr: Expr): Boolean = { var foundReal = false - simplePostTransform((e: Expr) => e match { - case _ => { + simplePostTransform { + case e => { if (e.getType == RealType) - foundReal = true; + foundReal = true e } - })(expr) + }(expr) foundReal } @@ -252,13 +252,13 @@ object PredicateUtil { */ def hasInts(expr: Expr): Boolean = { var foundInt = false - simplePostTransform((e: Expr) => e match { + simplePostTransform { case e: Terminal if (e.getType == Int32Type || e.getType == IntegerType) => { - foundInt = true; + foundInt = true e } - case _ => e - })(expr) + case e => e + }(expr) foundInt } @@ -268,29 +268,29 @@ object PredicateUtil { def atomNum(e: Expr): Int = { var count: Int = 0 - simplePostTransform((e: Expr) => e match { - case And(args) => { + simplePostTransform { + case e@And(args) => { count += args.size e } - case Or(args) => { + case e@Or(args) => { count += args.size e } - case _ => e - })(e) + case e => e + }(e) count } def numUIFADT(e: Expr): Int = { var count: Int = 0 - simplePostTransform((e: Expr) => e match { - case FunctionInvocation(_, _) | CaseClass(_, _) | Tuple(_) => { + simplePostTransform { + case e@(FunctionInvocation(_, _) | CaseClass(_, _) | Tuple(_)) => { count += 1 e } - case _ => e - })(e) + case e => e + }(e) count } @@ -327,12 +327,12 @@ object PredicateUtil { //replaces occurrences of mult by Times def multToTimes(ine: Expr): Expr = { - simplePostTransform((e: Expr) => e match { + simplePostTransform { case FunctionInvocation(TypedFunDef(fd, _), args) if isMultFunctions(fd) => { Times(args(0), args(1)) } - case _ => e - })(ine) + case e => e + }(ine) } def createAnd(exprs: Seq[Expr]): Expr = { diff --git a/src/main/scala/leon/invariant/util/Util.scala b/src/main/scala/leon/invariant/util/Util.scala index f8358dec14f22f420a2b19b90dcf9dd9d308d80d..a792f1586479d5989773cdeee7e322eaa66983a6 100644 --- a/src/main/scala/leon/invariant/util/Util.scala +++ b/src/main/scala/leon/invariant/util/Util.scala @@ -1,19 +1,16 @@ package leon package invariant.util -import purescala.Common._ -import purescala.Definitions._ import purescala.Expressions._ -import purescala.ExprOps._ import purescala.Types._ -import leon.purescala.PrettyPrintable -import leon.purescala.PrinterContext +import purescala.PrettyPrintable +import purescala.PrinterContext import purescala.PrinterHelpers._ object FileCountGUID { var fileCount = 0 def getID: Int = { - var oldcnt = fileCount + val oldcnt = fileCount fileCount += 1 oldcnt } @@ -32,7 +29,7 @@ case class ResultVariable(tpe: TypeTree) extends Expr with Terminal with PrettyP val getType = tpe override def toString: String = "#res" - def printWith(implicit pctx: PrinterContext) { + def printWith(implicit pctx: PrinterContext) = { p"#res" } } diff --git a/src/main/scala/leon/purescala/Constructors.scala b/src/main/scala/leon/purescala/Constructors.scala index 728b602f9fa343ae8657a2922ca03a3dfe1ada3f..4b960a043d2baf82b96a05a7133bf05faf6fafe6 100644 --- a/src/main/scala/leon/purescala/Constructors.scala +++ b/src/main/scala/leon/purescala/Constructors.scala @@ -110,7 +110,7 @@ object Constructors { canBeSubtypeOf(actualType, typeParamsOf(formalType).toSeq, formalType) match { case Some(tmap) => FunctionInvocation(fd.typed(fd.tparams map { tpd => tmap.getOrElse(tpd.tp, tpd.tp) }), args) - case None => sys.error(s"$actualType cannot be a subtype of $formalType!") + case None => throw LeonFatalError(s"$args:$actualType cannot be a subtype of $formalType!") } } @@ -161,7 +161,7 @@ object Constructors { BooleanLiteral(true) } } - /** $encodingof `... match { ... }` but simplified if possible. Throws an error if no case can match the scrutined expression. + /** $encodingof `... match { ... }` but simplified if possible. Simplifies to [[Error]] if no case can match the scrutined expression. * @see [[purescala.Expressions.MatchExpr MatchExpr]] */ def matchExpr(scrutinee : Expr, cases : Seq[MatchCase]) : Expr ={ @@ -249,8 +249,8 @@ object Constructors { /** $encodingof Simplified `Array(...)` (array length defined at compile-time) * @see [[purescala.Expressions.NonemptyArray NonemptyArray]] */ - def finiteArray(els: Seq[Expr]): Expr = { - require(els.nonEmpty) + def finiteArray(els: Seq[Expr], tpe: TypeTree = Untyped): Expr = { + require(els.nonEmpty || tpe != Untyped) finiteArray(els, None, Untyped) // Untyped is not correct, but will not be used anyway } /** $encodingof Simplified `Array[...](...)` (array length and default element defined at run-time) with type information diff --git a/src/main/scala/leon/purescala/Definitions.scala b/src/main/scala/leon/purescala/Definitions.scala index ba6b11bcf49398e4a9ff1456aa4d7bffc734d2a8..a3d00ad98624de40611d810ed1f37f011c12f892 100644 --- a/src/main/scala/leon/purescala/Definitions.scala +++ b/src/main/scala/leon/purescala/Definitions.scala @@ -3,7 +3,6 @@ package leon package purescala -import sun.reflect.generics.tree.ReturnType import utils.Library import Common._ import Expressions._ @@ -41,27 +40,21 @@ object Definitions { } } - /** A ValDef represents a parameter of a [[purescala.Definitions.FunDef function]] or - * a [[purescala.Definitions.CaseClassDef case class]]. - * - * The optional [[tpe]], if present, overrides the type of the underlying Identifier [[id]]. - * This is useful to instantiate argument types of polymorphic classes. To be consistent, - * never use the type of [[id]] directly; use [[ValDef#getType]] instead. - */ - case class ValDef(id: Identifier, tpe: Option[TypeTree] = None) extends Definition with Typed { + /** + * A ValDef declares a new identifier to be of a certain type. + * The optional tpe, if present, overrides the type of the underlying Identifier id + * This is useful to instantiate argument types of polymorphic functions + */ + case class ValDef(id: Identifier, isLazy: Boolean = false) extends Definition with Typed { self: Serializable => - val getType = tpe getOrElse id.getType + val getType = id.getType var defaultValue : Option[FunDef] = None def subDefinitions = Seq() - /** Transform this [[ValDef]] into a [[Expressions.Variable Variable]] - * - * Warning: the variable will not have the same type as this ValDef, but currently - * the Identifier type is enough for all uses in Leon. - */ + /** Transform this [[ValDef]] into a [[Expressions.Variable Variable]] */ def toVariable : Variable = Variable(id) } @@ -160,7 +153,7 @@ object Definitions { object UnitDef { def apply(id: Identifier, modules : Seq[ModuleDef]) : UnitDef = - UnitDef(id,Nil, Nil, modules,true) + UnitDef(id, Nil, Nil, modules, true) } /** Objects work as containers for class definitions, functions (def's) and @@ -394,6 +387,10 @@ object Definitions { def postcondition_=(op: Option[Expr]) = { fullBody = withPostcondition(fullBody, op) } + def postOrTrue = postcondition getOrElse { + val arg = ValDef(FreshIdentifier("res", returnType, alwaysShowUniqueID = true)) + Lambda(Seq(arg), BooleanLiteral(true)) + } def hasBody = body.isDefined def hasPrecondition = precondition.isDefined @@ -481,11 +478,20 @@ object Definitions { def translated(e: Expr): Expr = instantiateType(e, typesMap, paramsMap) + /** A mapping from this [[TypedFunDef]]'s formal parameters to real arguments + * + * @param realArgs The arguments to which the formal argumentas are mapped + * */ def paramSubst(realArgs: Seq[Expr]) = { require(realArgs.size == params.size) (paramIds zip realArgs).toMap } + /** Substitute this [[TypedFunDef]]'s formal parameters with real arguments in some expression + * + * @param realArgs The arguments to which the formal argumentas are mapped + * @param e The expression in which the substitution will take place + */ def withParamSubst(realArgs: Seq[Expr], e: Expr) = { replaceFromIDs(paramSubst(realArgs), e) } @@ -505,11 +511,10 @@ object Definitions { if (typesMap.isEmpty) { (fd.params, Map()) } else { - val newParams = fd.params.map { - case vd @ ValDef(id, _) => - val newTpe = translated(vd.getType) - val newId = FreshIdentifier(id.name, newTpe, true).copiedFrom(id) - ValDef(newId).setPos(vd) + val newParams = fd.params.map { vd => + val newTpe = translated(vd.getType) + val newId = FreshIdentifier(vd.id.name, newTpe, true).copiedFrom(vd.id) + vd.copy(id = newId).setPos(vd) } val paramsMap: Map[Identifier, Identifier] = (fd.params zip newParams).map { case (vd1, vd2) => vd1.id -> vd2.id }.toMap @@ -539,6 +544,7 @@ object Definitions { def precondition = fd.precondition map cached def precOrTrue = cached(fd.precOrTrue) def postcondition = fd.postcondition map cached + def postOrTrue = cached(fd.postOrTrue) def hasImplementation = body.isDefined def hasBody = hasImplementation diff --git a/src/main/scala/leon/purescala/ExprOps.scala b/src/main/scala/leon/purescala/ExprOps.scala index c106d1393bcf793a7fe5d6944ecd533d84a42e36..d69bdbf94e608fed2a7d8fb3a2a60b1c0bd5f276 100644 --- a/src/main/scala/leon/purescala/ExprOps.scala +++ b/src/main/scala/leon/purescala/ExprOps.scala @@ -345,6 +345,7 @@ object ExprOps { val subvs = subs.flatten.toSet e match { case Variable(i) => subvs + i + case Old(i) => subvs + i case LetDef(fd, _) => subvs -- fd.params.map(_.id) case Let(i, _, _) => subvs - i case LetVar(i, _, _) => subvs - i @@ -539,8 +540,8 @@ object ExprOps { } val normalized = postMap { - case Lambda(args, body) => Some(Lambda(args.map(vd => ValDef(subst(vd.id), vd.tpe)), body)) - case Forall(args, body) => Some(Forall(args.map(vd => ValDef(subst(vd.id), vd.tpe)), body)) + case Lambda(args, body) => Some(Lambda(args.map(vd => vd.copy(id = subst(vd.id))), body)) + case Forall(args, body) => Some(Forall(args.map(vd => vd.copy(id = subst(vd.id))), body)) case Let(i, e, b) => Some(Let(subst(i), e, b)) case MatchExpr(scrut, cses) => Some(MatchExpr(scrut, cses.map { cse => cse.copy(pattern = replacePatternBinders(cse.pattern, subst)) @@ -791,7 +792,7 @@ object ExprOps { }) case CaseClassPattern(_, cct, subps) => - val subExprs = (subps zip cct.fields) map { + val subExprs = (subps zip cct.classDef.fields) map { case (p, f) => p.binder.map(_.toVariable).getOrElse(caseClassSelector(cct, in, f.id)) } @@ -868,8 +869,8 @@ object ExprOps { } case CaseClassPattern(ob, cct, subps) => - assert(cct.fields.size == subps.size) - val pairs = cct.fields.map(_.id).toList zip subps.toList + assert(cct.classDef.fields.size == subps.size) + val pairs = cct.classDef.fields.map(_.id).toList zip subps.toList val subTests = pairs.map(p => rec(caseClassSelector(cct, in, p._1), p._2)) val together = and(bind(ob, in) +: subTests :_*) and(IsInstanceOf(in, cct), together) @@ -880,7 +881,7 @@ object ExprOps { val subTests = subps.zipWithIndex.map{case (p, i) => rec(tupleSelect(in, i+1, subps.size), p)} and(bind(ob, in) +: subTests: _*) - case up@UnapplyPattern(ob, fd, subps) => + case up @ UnapplyPattern(ob, fd, subps) => def someCase(e: Expr) = { // In the case where unapply returns a Some, it is enough that the subpatterns match andJoin(unwrapTuple(e, subps.size) zip subps map { case (ex, p) => rec(ex, p).setPos(p) }).setPos(e) @@ -904,7 +905,7 @@ object ExprOps { pattern match { case CaseClassPattern(b, cct, subps) => assert(cct.fields.size == subps.size) - val pairs = cct.fields.map(_.id).toList zip subps.toList + val pairs = cct.classDef.fields.map(_.id).toList zip subps.toList val subMaps = pairs.map(p => mapForPattern(caseClassSelector(cct, asInstOf(in, cct), p._1), p._2)) val together = subMaps.flatten.toMap bindIn(b, Some(cct)) ++ together @@ -1138,13 +1139,47 @@ object ExprOps { case tp: TypeParameter => GenericValue(tp, 0) - case FunctionType(from, to) => - val args = from.map(tpe => ValDef(FreshIdentifier("x", tpe, true))) - Lambda(args, simplestValue(to)) + case ft @ FunctionType(from, to) => + PartialLambda(Seq.empty, Some(simplestValue(to)), ft) case _ => throw LeonFatalError("I can't choose simplest value for type " + tpe) } + def valuesOf(tp: TypeTree): Stream[Expr] = { + import utils.StreamUtils._ + tp match { + case BooleanType => + Stream(BooleanLiteral(false), BooleanLiteral(true)) + case Int32Type => + Stream.iterate(0) { prev => + if (prev > 0) -prev else -prev + 1 + } map IntLiteral + case IntegerType => + Stream.iterate(BigInt(0)) { prev => + if (prev > 0) -prev else -prev + 1 + } map InfiniteIntegerLiteral + case UnitType => + Stream(UnitLiteral()) + case tp: TypeParameter => + Stream.from(0) map (GenericValue(tp, _)) + case TupleType(stps) => + cartesianProduct(stps map (tp => valuesOf(tp))) map Tuple + case SetType(base) => + def elems = valuesOf(base) + elems.scanLeft(Stream(FiniteSet(Set(), base): Expr)){ (prev, curr) => + prev flatMap { + case fs@FiniteSet(elems, tp) => + Stream(fs, FiniteSet(elems + curr, tp)) + } + }.flatten // FIXME Need cp οr is this fine? + case cct: CaseClassType => + cartesianProduct(cct.fieldsTypes map valuesOf) map (CaseClass(cct, _)) + case act: AbstractClassType => + interleave(act.knownCCDescendants.map(cct => valuesOf(cct))) + } + } + + /** Hoists all IfExpr at top level. * * Guarantees that all IfExpr will be at the top level and as soon as you @@ -1289,12 +1324,6 @@ object ExprOps { } } - class ChooseCollectorWithPaths extends CollectorWithPaths[(Choose,Expr)] { - def collect(e: Expr, path: Seq[Expr]) = e match { - case c: Choose => Some(c -> and(path: _*)) - case _ => None - } - } def collectWithPC[T](f: PartialFunction[Expr, T])(expr: Expr): Seq[(T, Expr)] = { CollectorWithPaths(f).traverse(expr) @@ -1318,11 +1347,6 @@ object ExprOps { es.map(formulaSize).sum+1 } - /** Return a list of all [[purescala.Expressions.Choose Choose]] construct inside the expression */ - def collectChooses(e: Expr): List[Choose] = { - new ChooseCollectorWithPaths().traverse(e).map(_._1).toList - } - /** Returns true if the expression is deterministic / does not contain any [[purescala.Expressions.Choose Choose]] or [[purescala.Expressions.Hole Hole]]*/ def isDeterministic(e: Expr): Boolean = { preTraversal{ @@ -1456,8 +1480,8 @@ object ExprOps { val isType = IsInstanceOf(Variable(on), cct) - val recSelectors = cct.fields.collect { - case vd if vd.getType == on.getType => vd.id + val recSelectors = (cct.classDef.fields zip cct.fieldsTypes).collect { + case (vd, tpe) if tpe == on.getType => vd.id } if (recSelectors.isEmpty) { @@ -1951,13 +1975,55 @@ object ExprOps { es foreach rec } - def functionAppsOf(expr: Expr): Set[Application] = { - collect[Application] { - case f: Application => Set(f) - case _ => Set() - }(expr) + object InvocationExtractor { + private def flatInvocation(expr: Expr): Option[(TypedFunDef, Seq[Expr])] = expr match { + case fi @ FunctionInvocation(tfd, args) => Some((tfd, args)) + case Application(caller, args) => flatInvocation(caller) match { + case Some((tfd, prevArgs)) => Some((tfd, prevArgs ++ args)) + case None => None + } + case _ => None + } + + def unapply(expr: Expr): Option[(TypedFunDef, Seq[Expr])] = expr match { + case IsTyped(f: FunctionInvocation, ft: FunctionType) => None + case IsTyped(f: Application, ft: FunctionType) => None + case FunctionInvocation(tfd, args) => Some(tfd -> args) + case f: Application => flatInvocation(f) + case _ => None + } } + def firstOrderCallsOf(expr: Expr): Set[(TypedFunDef, Seq[Expr])] = + collect[(TypedFunDef, Seq[Expr])] { + case InvocationExtractor(tfd, args) => Set(tfd -> args) + case _ => Set.empty + } (expr) + + object ApplicationExtractor { + private def flatApplication(expr: Expr): Option[(Expr, Seq[Expr])] = expr match { + case Application(fi: FunctionInvocation, _) => None + case Application(caller: Application, args) => flatApplication(caller) match { + case Some((c, prevArgs)) => Some((c, prevArgs ++ args)) + case None => None + } + case Application(caller, args) => Some((caller, args)) + case _ => None + } + + def unapply(expr: Expr): Option[(Expr, Seq[Expr])] = expr match { + case IsTyped(f: Application, ft: FunctionType) => None + case f: Application => flatApplication(f) + case _ => None + } + } + + def firstOrderAppsOf(expr: Expr): Set[(Expr, Seq[Expr])] = + collect[(Expr, Seq[Expr])] { + case ApplicationExtractor(caller, args) => Set(caller -> args) + case _ => Set.empty + } (expr) + def simplifyHOFunctions(expr: Expr) : Expr = { def liftToLambdas(expr: Expr) = { diff --git a/src/main/scala/leon/purescala/Expressions.scala b/src/main/scala/leon/purescala/Expressions.scala index b834eb9f1270bca728b2cdbaa63bb98a37b0dd49..532d96d44ff00e8530b59e4728dfb34a0edfa3f9 100644 --- a/src/main/scala/leon/purescala/Expressions.scala +++ b/src/main/scala/leon/purescala/Expressions.scala @@ -76,6 +76,10 @@ object Expressions { val getType = tpe } + case class Old(id: Identifier) extends Expr with Terminal { + val getType = id.getType + } + /** Precondition of an [[Expressions.Expr]]. Corresponds to the Leon keyword *require* * * @param pred The precondition formula inside ``require(...)`` @@ -226,7 +230,7 @@ object Expressions { } } - case class PartialLambda(mapping: Seq[(Seq[Expr], Expr)], tpe: FunctionType) extends Expr { + case class PartialLambda(mapping: Seq[(Seq[Expr], Expr)], default: Option[Expr], tpe: FunctionType) extends Expr { val getType = tpe } @@ -327,7 +331,7 @@ object Expressions { // Hacky, but ok lazy val optionType = unapplyFun.returnType.asInstanceOf[AbstractClassType] lazy val Seq(noneType, someType) = optionType.knownCCDescendants.sortBy(_.fields.size) - lazy val someValue = someType.fields.head + lazy val someValue = someType.classDef.fields.head // Pattern match unapply(scrut) // In case of None, return noneCase. // In case of Some(v), return someCase(v). @@ -765,7 +769,7 @@ object Expressions { } /** $encodingof `set.length` */ case class SetCardinality(set: Expr) extends Expr { - val getType = Int32Type + val getType = IntegerType } /** $encodingof `set.subsetOf(set2)` */ case class SubsetOf(set1: Expr, set2: Expr) extends Expr { diff --git a/src/main/scala/leon/purescala/Extractors.scala b/src/main/scala/leon/purescala/Extractors.scala index cd57d2187fe0cd59d45e3d88e17b0926e2279412..47dfd7b7c576bc2e331bf5f7369b2ff6d94e805c 100644 --- a/src/main/scala/leon/purescala/Extractors.scala +++ b/src/main/scala/leon/purescala/Extractors.scala @@ -41,20 +41,21 @@ object Extractors { Some((Seq(a), (es: Seq[Expr]) => ArrayLength(es.head))) case Lambda(args, body) => Some((Seq(body), (es: Seq[Expr]) => Lambda(args, es.head))) - case PartialLambda(mapping, tpe) => + case PartialLambda(mapping, dflt, tpe) => val sze = tpe.from.size + 1 val subArgs = mapping.flatMap { case (args, v) => args :+ v } val builder = (as: Seq[Expr]) => { def rec(kvs: Seq[Expr]): Seq[(Seq[Expr], Expr)] = kvs match { case seq if seq.size >= sze => - val ((args :+ res), rest) = seq.splitAt(sze) + val (args :+ res, rest) = seq.splitAt(sze) (args -> res) +: rec(rest) case Seq() => Seq.empty case _ => sys.error("unexpected number of key/value expressions") } - PartialLambda(rec(as), tpe) + val (nas, nd) = if (dflt.isDefined) (as.init, Some(as.last)) else (as, None) + PartialLambda(rec(nas), nd, tpe) } - Some((subArgs, builder)) + Some((subArgs ++ dflt, builder)) case Forall(args, body) => Some((Seq(body), (es: Seq[Expr]) => Forall(args, es.head))) @@ -181,19 +182,21 @@ object Extractors { val l = as.length nonemptyArray(as.take(l - 2), Some((as(l - 2), as(l - 1)))) })) - case NonemptyArray(elems, None) => - val all = elems.values.toSeq - Some((all, finiteArray)) + case na@NonemptyArray(elems, None) => + val ArrayType(tpe) = na.getType + val (indexes, elsOrdered) = elems.toSeq.unzip + + Some(( + elsOrdered, + es => finiteArray(indexes.zip(es).toMap, None, tpe) + )) case Tuple(args) => Some((args, tupleWrap)) case IfExpr(cond, thenn, elze) => Some(( Seq(cond, thenn, elze), { case Seq(c, t, e) => IfExpr(c, t, e) } )) case MatchExpr(scrut, cases) => Some(( - scrut +: cases.flatMap { - case SimpleCase(_, e) => Seq(e) - case GuardedCase(_, e1, e2) => Seq(e1, e2) - }, + scrut +: cases.flatMap { _.expressions }, (es: Seq[Expr]) => { var i = 1 val newcases = for (caze <- cases) yield caze match { @@ -205,14 +208,13 @@ object Extractors { } )) case Passes(in, out, cases) => Some(( - in +: out +: cases.flatMap { - _.expressions - }, { + in +: out +: cases.flatMap { _.expressions }, + { case Seq(in, out, es@_*) => { var i = 0 val newcases = for (caze <- cases) yield caze match { case SimpleCase(b, _) => i += 1; SimpleCase(b, es(i - 1)) - case GuardedCase(b, _, _) => i += 2; GuardedCase(b, es(i - 1), es(i - 2)) + case GuardedCase(b, _, _) => i += 2; GuardedCase(b, es(i - 2), es(i - 1)) } passes(in, out, newcases) @@ -352,22 +354,20 @@ object Extractors { } def unapply(me : MatchExpr) : Option[(Pattern, Expr, Expr)] = { - if (me eq null) None else { me match { + Option(me) collect { case MatchExpr(scrut, List(SimpleCase(pattern, body))) if !aliased(pattern.binders, variablesOf(scrut)) => - Some(( pattern, scrut, body )) - case _ => None - }} + ( pattern, scrut, body ) + } } } object LetTuple { def unapply(me : MatchExpr) : Option[(Seq[Identifier], Expr, Expr)] = { - if (me eq null) None else { me match { + Option(me) collect { case LetPattern(TuplePattern(None,subPatts), value, body) if subPatts forall { case WildcardPattern(Some(_)) => true; case _ => false } => - Some((subPatts map { _.binder.get }, value, body )) - case _ => None - }} + (subPatts map { _.binder.get }, value, body ) + } } } diff --git a/src/main/scala/leon/purescala/MethodLifting.scala b/src/main/scala/leon/purescala/MethodLifting.scala index 6a8c8f9c9415a6ebc3b9c46de5d5e326a50858da..2014739eaa3fba0c84ce10685087262e7e227872 100644 --- a/src/main/scala/leon/purescala/MethodLifting.scala +++ b/src/main/scala/leon/purescala/MethodLifting.scala @@ -27,7 +27,7 @@ object MethodLifting extends TransformationPhase { // Common for both cases val ct = ccd.typed val binder = FreshIdentifier(ccd.id.name.toLowerCase, ct, true) - val fBinders = ct.fields.map{ f => f.id -> f.id.freshen }.toMap + val fBinders = (ccd.fieldsIds zip ct.fields).map(p => p._1 -> p._2.id.freshen).toMap def subst(e: Expr): Expr = e match { case CaseClassSelector(`ct`, This(`ct`), i) => Variable(fBinders(i)).setPos(e) @@ -37,19 +37,19 @@ object MethodLifting extends TransformationPhase { e } - ccd.methods.find( _.id == fdId).map { m => + ccd.methods.find(_.id == fdId).map { m => // Ancestor's method is a method in the case class - val subPatts = ct.fields map (f => WildcardPattern(Some(fBinders(f.id)))) + val subPatts = ccd.fields map (f => WildcardPattern(Some(fBinders(f.id)))) val patt = CaseClassPattern(Some(binder), ct, subPatts) val newE = simplePreTransform(subst)(breakDown(m.fullBody)) val cse = SimpleCase(patt, newE).setPos(newE) (List(cse), true) - } orElse ccd.fields.find( _.id == fdId).map { f => + } orElse ccd.fields.find(_.id == fdId).map { f => // Ancestor's method is a case class argument in the case class - val subPatts = ct.fields map (fld => + val subPatts = ccd.fields map (fld => if (fld.id == f.id) WildcardPattern(Some(fBinders(f.id))) else @@ -112,7 +112,7 @@ object MethodLifting extends TransformationPhase { val fdParams = fd.params map { vd => val newId = FreshIdentifier(vd.id.name, tSubst(vd.id.getType)) - ValDef(newId).setPos(vd.getPos) + vd.copy(id = newId).setPos(vd.getPos) } val paramsMap = fd.params.zip(fdParams).map{ case (from, to) => from.id -> to.id }.toMap val eSubst: Expr => Expr = instantiateType(_, tMap, paramsMap) @@ -140,7 +140,7 @@ object MethodLifting extends TransformationPhase { val retType = instantiateType(fd.returnType, tparamsMap) val fdParams = fd.params map { vd => val newId = FreshIdentifier(vd.id.name, instantiateType(vd.id.getType, tparamsMap)) - ValDef(newId).setPos(vd.getPos) + vd.copy(id = newId).setPos(vd.getPos) } val receiver = FreshIdentifier("thiss", recType).setPos(cd.id) diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala index 6b22810e542bdb9af65251da1c9a4d4ab1ab7142..9a017e7df4f974475986e4939c20bfef0116ae60 100644 --- a/src/main/scala/leon/purescala/PrettyPrinter.scala +++ b/src/main/scala/leon/purescala/PrettyPrinter.scala @@ -77,6 +77,9 @@ class PrettyPrinter(opts: PrinterOptions, } p"$name" + case Old(id) => + p"old($id)" + case Variable(id) => p"$id" @@ -244,6 +247,22 @@ class PrettyPrinter(opts: PrinterOptions, case Lambda(args, body) => optP { p"($args) => $body" } + case PartialLambda(mapping, dflt, _) => + optP { + def pm(p: (Seq[Expr], Expr)): PrinterHelpers.Printable = + (pctx: PrinterContext) => p"${purescala.Constructors.tupleWrap(p._1)} => ${p._2}"(pctx) + + if (mapping.isEmpty) { + p"{}" + } else { + p"{ ${nary(mapping map pm)} }" + } + + if (dflt.isDefined) { + p" getOrElse ${dflt.get}" + } + } + case Plus(l,r) => optP { p"$l + $r" } case Minus(l,r) => optP { p"$l - $r" } case Times(l,r) => optP { p"$l * $r" } @@ -315,8 +334,8 @@ class PrettyPrinter(opts: PrinterOptions, case Not(expr) => p"\u00AC$expr" - case vd@ValDef(id, _) => - p"$id : ${vd.getType}" + case vd @ ValDef(id, lzy) => + p"$id :${if (lzy) "=> " else ""} ${vd.getType}" vd.defaultValue.foreach { fd => p" = ${fd.body.get}" } case This(_) => p"this" diff --git a/src/main/scala/leon/purescala/Quantification.scala b/src/main/scala/leon/purescala/Quantification.scala index 1b00ed1b41a5053fb07a94695b00f595d54453ba..bb5115baab042a1bd524fdb2f73deeabefa59243 100644 --- a/src/main/scala/leon/purescala/Quantification.scala +++ b/src/main/scala/leon/purescala/Quantification.scala @@ -6,10 +6,13 @@ package purescala import Common._ import Definitions._ import Expressions._ +import Constructors._ import Extractors._ import ExprOps._ import Types._ +import evaluators._ + object Quantification { def extractQuorums[A,B]( @@ -18,6 +21,12 @@ object Quantification { margs: A => Set[A], qargs: A => Set[B] ): Seq[Set[A]] = { + def expand(m: A): Set[A] = Set(m) ++ margs(m).flatMap(expand) + val expandedMap: Map[A, Set[A]] = matchers.map(m => m -> expand(m)).toMap + val reverseMap : Map[A, Set[A]] = expandedMap.toSeq + .flatMap(p => p._2.map(m => m -> p._1)) // flatten to reversed pairs + .groupBy(_._1).mapValues(_.map(_._2).toSet) // rebuild map from pair set + def rec(oms: Seq[A], mSet: Set[A], qss: Seq[Set[B]]): Seq[Set[A]] = { if (qss.contains(quantified)) { Seq(mSet) @@ -34,12 +43,13 @@ object Quantification { } } - def expand(m: A): Set[A] = Set(m) ++ margs(m).flatMap(expand) - val oms = matchers.toSeq.sortBy(m => -expand(m).size) - rec(oms, Set.empty, Seq.empty) + val oms = expandedMap.toSeq.sortBy(p => -p._2.size).map(_._1) + val res = rec(oms, Set.empty, Seq.empty) + + res.filter(ms => ms.forall(m => reverseMap(m) subsetOf ms)) } - def extractQuorums(expr: Expr, quantified: Set[Identifier]): Seq[Set[(Expr, Seq[Expr])]] = { + def extractQuorums(expr: Expr, quantified: Set[Identifier]): Seq[(Set[(Expr, Expr, Seq[Expr])], Set[(Expr, Expr, Seq[Expr])])] = { object QMatcher { def unapply(e: Expr): Option[(Expr, Seq[Expr])] = e match { case QuantificationMatcher(expr, args) => @@ -52,33 +62,73 @@ object Quantification { } } - extractQuorums(collect { - case QMatcher(e, a) => Set(e -> a) - case _ => Set.empty[(Expr, Seq[Expr])] - } (expr), quantified, - (p: (Expr, Seq[Expr])) => p._2.collect { case QMatcher(e, a) => e -> a }.toSet, - (p: (Expr, Seq[Expr])) => p._2.collect { case Variable(id) if quantified(id) => id }.toSet) + val allMatchers = CollectorWithPaths { case QMatcher(expr, args) => expr -> args }.traverse(expr) + val matchers = allMatchers.map { case ((caller, args), path) => (path, caller, args) }.toSet + + val quorums = extractQuorums(matchers, quantified, + (p: (Expr, Expr, Seq[Expr])) => p._3.collect { case QMatcher(e, a) => (p._1, e, a) }.toSet, + (p: (Expr, Expr, Seq[Expr])) => p._3.collect { case Variable(id) if quantified(id) => id }.toSet) + + quorums.map(quorum => quorum -> matchers.filter(m => !quorum(m))) + } + + def extractModel( + asMap: Map[Identifier, Expr], + funDomains: Map[Identifier, Set[Seq[Expr]]], + typeDomains: Map[TypeTree, Set[Seq[Expr]]], + evaluator: DeterministicEvaluator + ): Map[Identifier, Expr] = asMap.map { case (id, expr) => + id -> (funDomains.get(id) match { + case Some(domain) => + PartialLambda(domain.toSeq.map { es => + val optEv = evaluator.eval(Application(expr, es)).result + es -> optEv.getOrElse(scala.sys.error("Unexpectedly failed to evaluate " + Application(expr, es))) + }, None, id.getType.asInstanceOf[FunctionType]) + + case None => postMap { + case p @ PartialLambda(mapping, dflt, tpe) => + Some(PartialLambda(typeDomains.get(tpe) match { + case Some(domain) => domain.toSeq.map { es => + val optEv = evaluator.eval(Application(p, es)).result + es -> optEv.getOrElse(scala.sys.error("Unexpectedly failed to evaluate " + Application(p, es))) + } + case _ => Seq.empty + }, None, tpe)) + case _ => None + } (expr) + }) } object HenkinDomains { - def empty = new HenkinDomains(Map.empty) - def apply(domains: Map[TypeTree, Set[Seq[Expr]]]) = new HenkinDomains(domains) + def empty = new HenkinDomains(Map.empty, Map.empty) } - class HenkinDomains (val domains: Map[TypeTree, Set[Seq[Expr]]]) { - def get(e: Expr): Set[Seq[Expr]] = e match { - case PartialLambda(mapping, _) => mapping.map(_._1).toSet - case _ => domains.get(e.getType) match { - case Some(domain) => domain - case None => scala.sys.error("Undefined Henkin domain for " + e) + class HenkinDomains (val lambdas: Map[Lambda, Set[Seq[Expr]]], val tpes: Map[TypeTree, Set[Seq[Expr]]]) { + def get(e: Expr): Set[Seq[Expr]] = { + val specialized: Set[Seq[Expr]] = e match { + case PartialLambda(_, Some(dflt), _) => scala.sys.error("No domain for non-partial lambdas") + case PartialLambda(mapping, _, _) => mapping.map(_._1).toSet + case l: Lambda => lambdas.getOrElse(l, Set.empty) + case _ => Set.empty } + specialized ++ tpes.getOrElse(e.getType, Set.empty) } } object QuantificationMatcher { + private def flatApplication(expr: Expr): Option[(Expr, Seq[Expr])] = expr match { + case Application(fi: FunctionInvocation, args) => Some((fi, args)) + case Application(caller: Application, args) => flatApplication(caller) match { + case Some((c, prevArgs)) => Some((c, prevArgs ++ args)) + case None => None + } + case Application(caller, args) => Some((caller, args)) + case _ => None + } + def unapply(expr: Expr): Option[(Expr, Seq[Expr])] = expr match { - case Application(_: Application | _: FunctionInvocation, _) => None - case Application(e, args) => Some(e -> args) + case IsTyped(a: Application, ft: FunctionType) => None + case Application(e, args) => flatApplication(expr) case ArraySelect(arr, index) => Some(arr -> Seq(index)) case MapApply(map, key) => Some(map -> Seq(key)) case ElementOfSet(elem, set) => Some(set -> Seq(elem)) @@ -87,8 +137,15 @@ object Quantification { } object QuantificationTypeMatcher { + private def flatType(tpe: TypeTree): (Seq[TypeTree], TypeTree) = tpe match { + case FunctionType(from, to) => + val (nextArgs, finalTo) = flatType(to) + (from ++ nextArgs, finalTo) + case _ => (Seq.empty, tpe) + } + def unapply(tpe: TypeTree): Option[(Seq[TypeTree], TypeTree)] = tpe match { - case FunctionType(from, to) => Some(from -> to) + case FunctionType(from, to) => Some(flatType(tpe)) case ArrayType(base) => Some(Seq(Int32Type) -> base) case MapType(from, to) => Some(Seq(from) -> to) case SetType(base) => Some(Seq(base) -> BooleanType) @@ -96,87 +153,84 @@ object Quantification { } } - object CheckForalls extends UnitPhase[Program] { - - val name = "Foralls" - val description = "Check syntax of foralls to guarantee sound instantiations" - - def apply(ctx: LeonContext, program: Program) = { - program.definedFunctions.foreach { fd => - val foralls = collect[Forall] { - case f: Forall => Set(f) - case _ => Set.empty - } (fd.fullBody) - - val free = fd.paramIds.toSet ++ (fd.postcondition match { - case Some(Lambda(args, _)) => args.map(_.id) - case _ => Seq.empty + sealed abstract class ForallStatus { + def isValid: Boolean + } + + case object ForallValid extends ForallStatus { + def isValid = true + } + + sealed abstract class ForallInvalid(msg: String) extends ForallStatus { + def isValid = false + def getMessage: String = msg + } + + case class NoMatchers(expr: String) extends ForallInvalid("No matchers available for E-Matching in " + expr) + case class ComplexArgument(expr: String) extends ForallInvalid("Unhandled E-Matching pattern in " + expr) + case class NonBijectiveMapping(expr: String) extends ForallInvalid("Non-bijective mapping for quantifiers in " + expr) + case class InvalidOperation(expr: String) extends ForallInvalid("Invalid operation on quantifiers in " + expr) + + def checkForall(quantified: Set[Identifier], body: Expr)(implicit ctx: LeonContext): ForallStatus = { + val TopLevelAnds(conjuncts) = body + for (conjunct <- conjuncts) { + val matchers = collect[(Expr, Seq[Expr])] { + case QuantificationMatcher(e, args) => Set(e -> args) + case _ => Set.empty + } (conjunct) + + if (matchers.isEmpty) return NoMatchers(conjunct.asString) + + val complexArgs = matchers.flatMap { case (_, args) => + args.flatMap(arg => arg match { + case QuantificationMatcher(_, _) => None + case Variable(id) => None + case _ if (variablesOf(arg) & quantified).nonEmpty => Some(arg) + case _ => None }) + } - for (Forall(args, TopLevelAnds(conjuncts)) <- foralls) { - val quantified = args.map(_.id).toSet - - for (conjunct <- conjuncts) { - val matchers = collect[(Expr, Seq[Expr])] { - case QuantificationMatcher(e, args) => Set(e -> args) - case _ => Set.empty - } (conjunct) - - if (matchers.isEmpty) - ctx.reporter.warning("E-matching isn't possible without matchers!") - - if (matchers.exists { case (_, args) => - args.exists{ - case QuantificationMatcher(_, _) => false - case Variable(id) => false - case arg => (variablesOf(arg) & quantified).nonEmpty - } - }) ctx.reporter.warning("Matcher arguments must have simple form in " + conjunct) - - val freeMatchers = matchers.collect { case (Variable(id), args) if free(id) => id -> args } - - val id2Quant = freeMatchers.foldLeft(Map.empty[Identifier, Set[Identifier]]) { - case (acc, (m, args)) => acc + (m -> (acc.getOrElse(m, Set.empty) ++ args.flatMap { - case Variable(id) if quantified(id) => Set(id) - case _ => Set.empty[Identifier] - })) - } + if (complexArgs.nonEmpty) return ComplexArgument(complexArgs.head.asString) - if (id2Quant.filter(_._2.nonEmpty).groupBy(_._2).nonEmpty) - ctx.reporter.warning("Multiple matchers must provide bijective matching in " + conjunct) - - fold[Set[Identifier]] { case (m, children) => - val q = children.toSet.flatten - - m match { - case QuantificationMatcher(_, args) => - q -- args.flatMap { - case Variable(id) if quantified(id) => Set(id) - case _ => Set.empty[Identifier] - } - case LessThan(_: Variable, _: Variable) => q - case LessEquals(_: Variable, _: Variable) => q - case GreaterThan(_: Variable, _: Variable) => q - case GreaterEquals(_: Variable, _: Variable) => q - case And(_) => q - case Or(_) => q - case Implies(_, _) => q - case Operator(es, _) => - val vars = es.flatMap { - case Variable(id) => Set(id) - case _ => Set.empty[Identifier] - }.toSet - - if (!(q.isEmpty || (q.size == 1 && (vars & free).isEmpty))) - ctx.reporter.warning("Invalid operation " + m + " on quantified variables") - q -- vars - case Variable(id) if quantified(id) => Set(id) - case _ => q - } - } (conjunct) - } - } + val matcherToQuants = matchers.foldLeft(Map.empty[Expr, Set[Identifier]]) { + case (acc, (m, args)) => acc + (m -> (acc.getOrElse(m, Set.empty) ++ args.flatMap { + case Variable(id) if quantified(id) => Set(id) + case _ => Set.empty[Identifier] + })) } + + val bijectiveMappings = matcherToQuants.filter(_._2.nonEmpty).groupBy(_._2) + if (bijectiveMappings.size > 1) return NonBijectiveMapping(bijectiveMappings.head._2.head._1.asString) + + val matcherSet = matcherToQuants.filter(_._2.nonEmpty).keys.toSet + + val qs = fold[Set[Identifier]] { case (m, children) => + val q = children.toSet.flatten + + m match { + case QuantificationMatcher(_, args) => + q -- args.flatMap { + case Variable(id) if quantified(id) => Set(id) + case _ => Set.empty[Identifier] + } + case LessThan(_: Variable, _: Variable) => q + case LessEquals(_: Variable, _: Variable) => q + case GreaterThan(_: Variable, _: Variable) => q + case GreaterEquals(_: Variable, _: Variable) => q + case And(_) => q + case Or(_) => q + case Implies(_, _) => q + case Operator(es, _) => + val matcherArgs = matcherSet & es.toSet + if (q.nonEmpty && !(q.size == 1 && matcherArgs.isEmpty && m.getType == BooleanType)) + return InvalidOperation(m.asString) + else Set.empty + case Variable(id) if quantified(id) => Set(id) + case _ => q + } + } (conjunct) } + + ForallValid } } diff --git a/src/main/scala/leon/purescala/TypeOps.scala b/src/main/scala/leon/purescala/TypeOps.scala index 51bed3eaf0f5fc2586f25fc7ff38cd6d8d3857b9..3644191c3261e238e1caf4484604f90dcf920978 100644 --- a/src/main/scala/leon/purescala/TypeOps.scala +++ b/src/main/scala/leon/purescala/TypeOps.scala @@ -185,14 +185,6 @@ object TypeOps { freshId(id, typeParamSubst(tps map { case (tpd, tp) => tpd.tp -> tp })(id.getType)) } - def instantiateType(vd: ValDef, tps: Map[TypeParameterDef, TypeTree]): ValDef = { - val ValDef(id, forcedType) = vd - ValDef( - freshId(id, instantiateType(id.getType, tps)), - forcedType map ((tp: TypeTree) => instantiateType(tp, tps)) - ) - } - def instantiateType(tpe: TypeTree, tps: Map[TypeParameterDef, TypeTree]): TypeTree = { if (tps.isEmpty) { tpe @@ -313,7 +305,7 @@ object TypeOps { TypeParameterDef(tpeSub(p.tp).asInstanceOf[TypeParameter]) } val returnType = tpeSub(fd.returnType) - val params = fd.params map (instantiateType(_, tps)) + val params = fd.params map (vd => vd.copy(id = freshId(vd.id, tpeSub(vd.getType)))) val newFd = fd.duplicate(id, tparams, params, returnType) val subCalls = preMap { @@ -332,7 +324,7 @@ object TypeOps { case l @ Lambda(args, body) => val newArgs = args.map { arg => val tpe = tpeSub(arg.getType) - ValDef(freshId(arg.id, tpe)) + arg.copy(id = freshId(arg.id, tpe)) } val mapping = args.map(_.id) zip newArgs.map(_.id) Lambda(newArgs, rec(idsMap ++ mapping)(body)).copiedFrom(l) @@ -340,7 +332,7 @@ object TypeOps { case f @ Forall(args, body) => val newArgs = args.map { arg => val tpe = tpeSub(arg.getType) - ValDef(freshId(arg.id, tpe)) + arg.copy(id = freshId(arg.id, tpe)) } val mapping = args.map(_.id) zip newArgs.map(_.id) Forall(newArgs, rec(idsMap ++ mapping)(body)).copiedFrom(f) diff --git a/src/main/scala/leon/purescala/Types.scala b/src/main/scala/leon/purescala/Types.scala index d39fda3338fbeab4dfa9c453a58afa6298bb6ccd..626cee7d0cac6c4692cfc11ce7ad9d2f2c954e9e 100644 --- a/src/main/scala/leon/purescala/Types.scala +++ b/src/main/scala/leon/purescala/Types.scala @@ -102,8 +102,14 @@ object Types { if (tmap.isEmpty) { classDef.fields } else { - // This is the only case where ValDef overrides the type of its Identifier - classDef.fields.map(vd => ValDef(vd.id, Some(instantiateType(vd.getType, tmap)))) + // !! WARNING !! + // vd.id changes but this should not be an issue as selector uses + // classDef.params ids which do not change! + classDef.fields.map { vd => + val newTpe = instantiateType(vd.getType, tmap) + val newId = FreshIdentifier(vd.id.name, newTpe).copiedFrom(vd.id) + vd.copy(id = newId).setPos(vd) + } } } diff --git a/src/main/scala/leon/repair/RepairNDEvaluator.scala b/src/main/scala/leon/repair/RepairNDEvaluator.scala index d3e0df1746b1fb0cb12c95764362a58b2215bd5b..56e8467478f5f0135945b9a72da0e525e4ac2c70 100644 --- a/src/main/scala/leon/repair/RepairNDEvaluator.scala +++ b/src/main/scala/leon/repair/RepairNDEvaluator.scala @@ -1,81 +1,25 @@ /* Copyright 2009-2015 EPFL, Lausanne */ -package leon.repair - -import leon.purescala._ -import Definitions._ -import Expressions._ -import Types._ -import ExprOps.postMap -import Constructors.not -import leon.LeonContext -import leon.evaluators.DefaultEvaluator -import scala.util.Try - -// This evaluator treats the condition cond non-deterministically in the following sense: -// If a function invocation fails or violates a postcondition for cond, -// it backtracks and gets executed again for !cond -class RepairNDEvaluator(ctx: LeonContext, prog: Program, fd : FunDef, cond: Expr) extends DefaultEvaluator(ctx, prog) { - - override def e(expr: Expr)(implicit rctx: RC, gctx: GC): Expr = expr match { - - case FunctionInvocation(tfd, args) if tfd.fd == fd => - if (gctx.stepsLeft < 0) { - throw RuntimeError("Exceeded number of allocated methods calls ("+gctx.maxSteps+")") - } - gctx.stepsLeft -= 1 - - val evArgs = args.map(a => e(a)) - - // build a mapping for the function... - val frame = rctx.newVars(tfd.paramSubst(evArgs)) - - if(tfd.hasPrecondition) { - e(tfd.precondition.get)(frame, gctx) match { - case BooleanLiteral(true) => - case BooleanLiteral(false) => - throw RuntimeError("Precondition violation for " + tfd.id.name + " reached in evaluation.: " + tfd.precondition.get) - case other => - throw RuntimeError(typeErrorMsg(other, BooleanType)) - } - } - - if(!tfd.hasBody && !rctx.mappings.isDefinedAt(tfd.id)) { - throw EvalError("Evaluation of function with unknown implementation.") - } - - val body = tfd.body.getOrElse(rctx.mappings(tfd.id)) - - def treat(subst : Expr => Expr) = { - val callResult = e(subst(body))(frame, gctx) - - tfd.postcondition match { - case Some(post) => - e(subst(Application(post, Seq(callResult))))(frame, gctx) match { - case BooleanLiteral(true) => - case BooleanLiteral(false) => throw RuntimeError("Postcondition violation for " + tfd.id.name + " reached in evaluation.") - case other => throw EvalError(typeErrorMsg(other, BooleanType)) - } - case None => - } - - callResult - } - - Try { - treat(e => e) - }.getOrElse { - treat( postMap { - // Use reference equality, just in case cond appears again in the program - case c if c eq cond => Some(not(cond)) - case _ => None - }) - } - - case _ => super.e(expr) +package leon +package repair + +import purescala.Definitions.Program +import purescala.Expressions._ +import purescala.ExprOps.valuesOf +import evaluators.StreamEvaluator + +/** This evaluator treats the expression [[nd]] (reference equality) as a non-deterministic value */ +class RepairNDEvaluator(ctx: LeonContext, prog: Program, nd: Expr) extends StreamEvaluator(ctx, prog) { + + override def e(expr: Expr)(implicit rctx: RC, gctx: GC): Stream[Expr] = expr match { + case Not(c) if c eq nd => + // This is a hack: We know the only way nd is wrapped within a Not is if it is NOT within + // a recursive call. So we need to treat it deterministically at this point... + super.e(c) collect { case BooleanLiteral(b) => BooleanLiteral(!b) } + case c if c eq nd => + valuesOf(c.getType) + case other => + super.e(other) } - - - } diff --git a/src/main/scala/leon/repair/RepairTrackingEvaluator.scala b/src/main/scala/leon/repair/RepairTrackingEvaluator.scala index 664b9e3b26f0229cfb0ec12cf52bef7a53d6aa6a..fd12f7216f482fa50ce84816caa79fa8a450d689 100644 --- a/src/main/scala/leon/repair/RepairTrackingEvaluator.scala +++ b/src/main/scala/leon/repair/RepairTrackingEvaluator.scala @@ -8,22 +8,19 @@ import leon.purescala.Common._ import leon.purescala.Expressions._ import leon.purescala.Types._ import leon.purescala.Definitions._ -import leon.purescala.Quantification._ import leon.LeonContext -import leon.evaluators.RecursiveEvaluator +import leon.evaluators._ /** * This evaluator tracks all dependencies between function calls (.fullCallGraph) * as well as if each invocation was successful or erroneous (led to an error) * (.fiStatus) */ -class RepairTrackingEvaluator(ctx: LeonContext, prog: Program) extends RecursiveEvaluator(ctx, prog, 50000) { +class RepairTrackingEvaluator(ctx: LeonContext, prog: Program) extends RecursiveEvaluator(ctx, prog, 50000) with HasDefaultGlobalContext { type RC = CollectingRecContext - type GC = GlobalContext - + def initRC(mappings: Map[Identifier, Expr]) = CollectingRecContext(mappings, None) - def initGC(model: leon.solvers.Model) = new GlobalContext(model) - + type FI = (FunDef, Seq[Expr]) // This is a call graph to track dependencies of function invocations. @@ -46,7 +43,7 @@ class RepairTrackingEvaluator(ctx: LeonContext, prog: Program) extends Recursive private def registerFailed (fi : FI) = fiStatus_ update (fi, false) def fiStatus = fiStatus_.toMap.withDefaultValue(false) - case class CollectingRecContext(mappings: Map[Identifier, Expr], lastFI : Option[FI]) extends RecContext { + case class CollectingRecContext(mappings: Map[Identifier, Expr], lastFI : Option[FI]) extends RecContext[CollectingRecContext] { def newVars(news: Map[Identifier, Expr]) = copy(news, lastFI) def withLastFI(fi : FI) = copy(lastFI = Some(fi)) } diff --git a/src/main/scala/leon/repair/Repairman.scala b/src/main/scala/leon/repair/Repairman.scala index 1f939c96e9ca70c72c22fac6f1b8378864b6458f..4959ae799c8f34b25dbdf1cac812ef8d954b4717 100644 --- a/src/main/scala/leon/repair/Repairman.scala +++ b/src/main/scala/leon/repair/Repairman.scala @@ -3,13 +3,11 @@ package leon package repair -import purescala.Common._ import purescala.Definitions._ import purescala.Expressions._ import purescala.ExprOps._ import purescala.Types._ import purescala.DefOps._ -import purescala.Quantification._ import purescala.Constructors._ import purescala.Extractors.unwrapTuple @@ -129,22 +127,16 @@ class Repairman(ctx0: LeonContext, initProgram: Program, fd: FunDef, verifTimeou val origBody = fd.body.get - val spec = fd.postcondition.getOrElse( - Lambda(Seq(ValDef(FreshIdentifier("res", fd.returnType))), BooleanLiteral(true)) - ) - - val choose = Choose(spec) - val term = Terminating(fd.typed, fd.params.map(_.id.toVariable)) val guide = Guide(origBody) val pre = fd.precOrTrue - val ci = ChooseInfo( - fd, - andJoin(Seq(pre, guide, term)), - origBody, - choose, - eb + val ci = SourceInfo( + fd = fd, + pc = andJoin(Seq(pre, guide, term)), + source = origBody, + spec = fd.postOrTrue, + eb = eb ) // Return synthesizer for this choose diff --git a/src/main/scala/leon/repair/rules/Focus.scala b/src/main/scala/leon/repair/rules/Focus.scala index 8c1694dbda953071d51226782965e488b406b607..93520c5ea9cd467cf73a94392a565022782eb3d9 100644 --- a/src/main/scala/leon/repair/rules/Focus.scala +++ b/src/main/scala/leon/repair/rules/Focus.scala @@ -4,8 +4,9 @@ package leon package repair package rules +import sun.nio.cs.StreamEncoder import synthesis._ -import evaluators._ +import leon.evaluators._ import purescala.Expressions._ import purescala.Common._ @@ -75,9 +76,7 @@ case object Focus extends PreprocessingRule("Focus") { val fdSpec = { val id = FreshIdentifier("res", fd.returnType) - Let(id, fd.body.get, - fd.postcondition.map(l => application(l, Seq(id.toVariable))).getOrElse(BooleanLiteral(true)) - ) + Let(id, fd.body.get, application(fd.postOrTrue, Seq(id.toVariable))) } val TopLevelAnds(clauses) = p.ws @@ -95,11 +94,10 @@ case object Focus extends PreprocessingRule("Focus") { def testCondition(cond: Expr) = { val ndSpec = postMap { - case c if c eq cond => Some(not(cond)) // Use reference equality + case c if c eq cond => Some(not(cond)) case _ => None }(fdSpec) - - forAllTests(ndSpec, Map(), new RepairNDEvaluator(ctx, program, fd, cond)) + forAllTests(ndSpec, Map(), new AngelicEvaluator(new RepairNDEvaluator(ctx, program, cond))) } guides.flatMap { diff --git a/src/main/scala/leon/solvers/EvaluatingSolver.scala b/src/main/scala/leon/solvers/EvaluatingSolver.scala index 3463235c918d4872b0136d84b54f44e11b36cd69..75dcb5631dd0d57e7bbef24b9ce6d5d461d902c5 100644 --- a/src/main/scala/leon/solvers/EvaluatingSolver.scala +++ b/src/main/scala/leon/solvers/EvaluatingSolver.scala @@ -10,7 +10,7 @@ trait EvaluatingSolver extends Solver { val useCodeGen: Boolean - lazy val evaluator: Evaluator = + lazy val evaluator: DeterministicEvaluator = if (useCodeGen) { new CodeGenEvaluator(context, program) } else { diff --git a/src/main/scala/leon/solvers/Model.scala b/src/main/scala/leon/solvers/Model.scala index ab91f594c954805267dd523fd9f82305ab789798..07bdee913f21605fbc41f660af608c492e5ee1b5 100644 --- a/src/main/scala/leon/solvers/Model.scala +++ b/src/main/scala/leon/solvers/Model.scala @@ -12,9 +12,9 @@ trait AbstractModel[+This <: Model with AbstractModel[This]] protected val mapping: Map[Identifier, Expr] - def fill(allVars: Iterable[Identifier]): This = { + def set(allVars: Iterable[Identifier]): This = { val builder = newBuilder - builder ++= mapping ++ (allVars.toSet -- mapping.keys).map(id => id -> simplestValue(id.getType)) + builder ++= allVars.map(id => id -> mapping.getOrElse(id, simplestValue(id.getType))) builder.result } diff --git a/src/main/scala/leon/solvers/QuantificationSolver.scala b/src/main/scala/leon/solvers/QuantificationSolver.scala index dc3e8584fd74578ac2173e1819c059da9f0ec99b..fa11ab6613bd65b196cce87ee062c3c56f0b95f9 100644 --- a/src/main/scala/leon/solvers/QuantificationSolver.scala +++ b/src/main/scala/leon/solvers/QuantificationSolver.scala @@ -4,14 +4,14 @@ package solvers import purescala.Common._ import purescala.Expressions._ import purescala.Quantification._ +import purescala.Definitions._ import purescala.Types._ -class HenkinModel(mapping: Map[Identifier, Expr], doms: HenkinDomains) +class HenkinModel(mapping: Map[Identifier, Expr], val doms: HenkinDomains) extends Model(mapping) with AbstractModel[HenkinModel] { override def newBuilder = new HenkinModelBuilder(doms) - def domains: Map[TypeTree, Set[Seq[Expr]]] = doms.domains def domain(expr: Expr) = doms.get(expr) } @@ -26,5 +26,10 @@ class HenkinModelBuilder(domains: HenkinDomains) } trait QuantificationSolver { + val program: Program def getModel: HenkinModel + + protected lazy val requireQuantification = program.definedFunctions.exists { fd => + purescala.ExprOps.exists { case _: Forall => true case _ => false } (fd.fullBody) + } } diff --git a/src/main/scala/leon/solvers/SolverFactory.scala b/src/main/scala/leon/solvers/SolverFactory.scala index 6c8a0cb8e5f88d59c3cb6a1657d73a38a9a87c37..99d5abc48d07397fe35162d766e0d6d4a5094a35 100644 --- a/src/main/scala/leon/solvers/SolverFactory.scala +++ b/src/main/scala/leon/solvers/SolverFactory.scala @@ -150,18 +150,16 @@ object SolverFactory { getFromName(ctx, program)("fairz3") } - lazy val hasNativeZ3 = { - try { - new _root_.z3.scala.Z3Config - true - } catch { - case _: java.lang.UnsatisfiedLinkError => - false - } + lazy val hasNativeZ3 = try { + new _root_.z3.scala.Z3Config + true + } catch { + case _: java.lang.UnsatisfiedLinkError => + false } lazy val hasZ3 = try { - Z3Interpreter.buildDefault + Z3Interpreter.buildDefault.free() true } catch { case e: java.io.IOException => @@ -169,7 +167,7 @@ object SolverFactory { } lazy val hasCVC4 = try { - CVC4Interpreter.buildDefault + CVC4Interpreter.buildDefault.free() true } catch { case e: java.io.IOException => diff --git a/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala b/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala index 0c7afeb65199042e34004199d17bba25dde9e1c5..b9c0dfd94e449ac5a7d130e1279da8ebb7c19b92 100644 --- a/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala +++ b/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala @@ -13,7 +13,7 @@ import purescala.ExprOps._ import purescala.Types._ import utils._ -import z3.FairZ3Component.{optFeelingLucky, optUseCodeGen, optAssumePre} +import z3.FairZ3Component.{optFeelingLucky, optUseCodeGen, optAssumePre, optNoChecks, optUnfoldFactor} import templates._ import evaluators._ @@ -26,6 +26,8 @@ class UnrollingSolver(val context: LeonContext, val program: Program, underlying val feelingLucky = context.findOptionOrDefault(optFeelingLucky) val useCodeGen = context.findOptionOrDefault(optUseCodeGen) val assumePreHolds = context.findOptionOrDefault(optAssumePre) + val disableChecks = context.findOptionOrDefault(optNoChecks) + val unfoldFactor = context.findOptionOrDefault(optUnfoldFactor) protected var lastCheckResult : (Boolean, Option[Boolean], Option[HenkinModel]) = (false, None, None) @@ -111,9 +113,10 @@ class UnrollingSolver(val context: LeonContext, val program: Program, underlying def extract(b: Expr, m: Matcher[Expr]): Set[Seq[Expr]] = { val QuantificationTypeMatcher(fromTypes, _) = m.tpe - val optEnabler = evaluator.eval(b).result + val optEnabler = evaluator.eval(b, model).result + if (optEnabler == Some(BooleanLiteral(true))) { - val optArgs = m.args.map(arg => evaluator.eval(Matcher.argValue(arg)).result) + val optArgs = m.args.map(arg => evaluator.eval(Matcher.argValue(arg), model).result) if (optArgs.forall(_.isDefined)) { Set(optArgs.map(_.get)) } else { @@ -124,35 +127,22 @@ class UnrollingSolver(val context: LeonContext, val program: Program, underlying } } - val funDomains = allVars.flatMap(id => id.getType match { - case ft @ FunctionType(fromTypes, _) => - Some(id -> templateGenerator.manager.instantiations(Variable(id), ft).flatMap { - case (b, m) => extract(b, m) - }) - case _ => None - }).toMap.mapValues(_.toSet) - - val asDMap = model.map(p => funDomains.get(p._1) match { - case Some(domain) => - val mapping = domain.toSeq.map { es => - val ev: Expr = p._2 match { - case RawArrayValue(_, mapping, dflt) => - mapping.collectFirst { - case (k,v) if evaluator.eval(Equals(k, tupleWrap(es))).result == Some(BooleanLiteral(true)) => v - } getOrElse dflt - case _ => scala.sys.error("Unexpected function encoding " + p._2) - } - es -> ev - } + val (typeInsts, partialInsts, lambdaInsts) = templateGenerator.manager.instantiations + + val typeDomains: Map[TypeTree, Set[Seq[Expr]]] = typeInsts.map { + case (tpe, domain) => tpe -> domain.flatMap { case (b, m) => extract(b, m) }.toSet + } - p._1 -> PartialLambda(mapping, p._1.getType.asInstanceOf[FunctionType]) - case None => p - }).toMap + val funDomains: Map[Identifier, Set[Seq[Expr]]] = partialInsts.map { + case (Variable(id), domain) => id -> domain.flatMap { case (b, m) => extract(b, m) }.toSet + } - val typeGrouped = templateGenerator.manager.instantiations.groupBy(_._2.tpe) - val typeDomains = typeGrouped.mapValues(_.flatMap { case (b, m) => extract(b, m) }.toSet) + val lambdaDomains: Map[Lambda, Set[Seq[Expr]]] = lambdaInsts.map { + case (l, domain) => l -> domain.flatMap { case (b, m) => extract(b, m) }.toSet + } - val domains = new HenkinDomains(typeDomains) + val asDMap = purescala.Quantification.extractModel(model.toMap, funDomains, typeDomains, evaluator) + val domains = new HenkinDomains(lambdaDomains, typeDomains) new HenkinModel(asDMap, domains) } @@ -160,36 +150,78 @@ class UnrollingSolver(val context: LeonContext, val program: Program, underlying lastCheckResult = (true, res, model) } - def isValidModel(model: HenkinModel, silenceErrors: Boolean = false): Boolean = { - import EvaluationResults._ + def validatedModel(silenceErrors: Boolean = false): (Boolean, HenkinModel) = { + val lastModel = solver.getModel + val clauses = templateGenerator.manager.checkClauses + val optModel = if (clauses.isEmpty) Some(lastModel) else { + solver.push() + for (clause <- clauses) { + solver.assertCnstr(clause) + } + + reporter.debug(" - Verifying model transitivity") + val solverModel = solver.check match { + case Some(true) => + Some(solver.getModel) + + case Some(false) => + val msg = "- Transitivity independence not guaranteed for model" + if (silenceErrors) { + reporter.debug(msg) + } else { + reporter.warning(msg) + } + None + + case None => + val msg = "- Unknown for transitivity independence!?" + if (silenceErrors) { + reporter.debug(msg) + } else { + reporter.warning(msg) + } + None + } + + solver.pop() + solverModel + } - val expr = andJoin(constraints.toSeq) - val fullModel = model fill freeVars.toSet + optModel match { + case None => + (false, extractModel(lastModel)) - evaluator.eval(expr, fullModel) match { - case Successful(BooleanLiteral(true)) => - reporter.debug("- Model validated.") - true + case Some(m) => + val model = extractModel(m) - case Successful(BooleanLiteral(false)) => - reporter.debug("- Invalid model.") - false + val expr = andJoin(constraints.toSeq) + val fullModel = model set freeVars.toSet - case Successful(e) => - reporter.warning("- Model leads unexpected result: "+e) - false + (evaluator.check(expr, fullModel) match { + case EvaluationResults.CheckSuccess => + reporter.debug("- Model validated.") + true - case RuntimeError(msg) => - reporter.debug("- Model leads to runtime error.") - false + case EvaluationResults.CheckValidityFailure => + reporter.debug("- Invalid model.") + false - case EvaluatorError(msg) => - if (silenceErrors) { - reporter.debug("- Model leads to evaluator error: " + msg) - } else { - reporter.warning("- Model leads to evaluator error: " + msg) - } - false + case EvaluationResults.CheckRuntimeFailure(msg) => + if (silenceErrors) { + reporter.debug("- Model leads to evaluation error: " + msg) + } else { + reporter.warning("- Model leads to evaluation error: " + msg) + } + false + + case EvaluationResults.CheckQuantificationFailure(msg) => + if (silenceErrors) { + reporter.debug("- Model leads to quantification error: " + msg) + } else { + reporter.warning("- Model leads to quantification error: " + msg) + } + false + }, fullModel) } } @@ -215,9 +247,20 @@ class UnrollingSolver(val context: LeonContext, val program: Program, underlying foundAnswer(None) case Some(true) => // SAT - val model = extractModel(solver.getModel) + val (valid, model) = if (!this.disableChecks && requireQuantification) { + validatedModel(silenceErrors = false) + } else { + true -> extractModel(solver.getModel) + } + solver.pop() - foundAnswer(Some(true), Some(model)) + if (valid) { + foundAnswer(Some(true), Some(model)) + } else { + reporter.error("Something went wrong. The model should have been valid, yet we got this : ") + reporter.error(model.asString(context)) + foundAnswer(None, Some(model)) + } case Some(false) if !unrollingBank.canUnroll => solver.pop() @@ -246,12 +289,9 @@ class UnrollingSolver(val context: LeonContext, val program: Program, underlying case Some(true) => if (feelingLucky && !interrupted) { - val model = extractModel(solver.getModel) - // we might have been lucky :D - if (isValidModel(model, silenceErrors = true)) { - foundAnswer(Some(true), Some(model)) - } + val (valid, model) = validatedModel(silenceErrors = true) + if (valid) foundAnswer(Some(true), Some(model)) } case None => @@ -267,14 +307,17 @@ class UnrollingSolver(val context: LeonContext, val program: Program, underlying if(!hasFoundAnswer) { reporter.debug("- We need to keep going.") - val toRelease = unrollingBank.getBlockersToUnlock + // unfolling `unfoldFactor` times + for (i <- 1 to unfoldFactor.toInt) { + val toRelease = unrollingBank.getBlockersToUnlock - reporter.debug(" - more unrollings") + reporter.debug(" - more unrollings") - val newClauses = unrollingBank.unrollBehind(toRelease) + val newClauses = unrollingBank.unrollBehind(toRelease) - for(ncl <- newClauses) { - solver.assertCnstr(ncl) + for (ncl <- newClauses) { + solver.assertCnstr(ncl) + } } reporter.debug(" - finished unrolling") diff --git a/src/main/scala/leon/solvers/isabelle/AdaptationPhase.scala b/src/main/scala/leon/solvers/isabelle/AdaptationPhase.scala index 911af8d84b6b6ccc94dd6ed781a77ca99e8a7771..2d3218c7b09edfa465618117a66b7f428f755b1c 100644 --- a/src/main/scala/leon/solvers/isabelle/AdaptationPhase.scala +++ b/src/main/scala/leon/solvers/isabelle/AdaptationPhase.scala @@ -26,7 +26,7 @@ object AdaptationPhase extends TransformationPhase { CaseClassType(dummy, List(tp)) def mkDummyParameter(tp: TypeParameter) = - ValDef(FreshIdentifier("dummy", mkDummyTyp(tp)), Some(mkDummyTyp(tp))) + ValDef(FreshIdentifier("dummy", mkDummyTyp(tp))) def mkDummyArgument(tree: TypeTree) = CaseClass(CaseClassType(dummy, List(tree)), Nil) diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Solver.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Solver.scala index dc267825c8356451944b5fb6b7f9de5d55504232..e2eda5e1cc84da7e72fcea0f4b1ce24ff6d91fc9 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Solver.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Solver.scala @@ -19,8 +19,10 @@ class SMTLIBCVC4Solver(context: LeonContext, program: Program) extends SMTLIBSol Seq( "-q", "--produce-models", - "--no-incremental", - "--tear-down-incremental", + "--incremental", +// "--no-incremental", +// "--tear-down-incremental", +// "--dt-rewrite-error-sel", // Removing since it causes CVC4 to segfault on some inputs "--rewrite-divk", "--print-success", "--lang", "smt" diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala index cb1c3246d7eff487e00904c3acbd61d7068aa316..f1cf73142d51debaeda0fc5064096e522f049955 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala @@ -7,11 +7,13 @@ package smtlib import purescala.Common._ import purescala.Expressions._ import purescala.Constructors._ +import purescala.Extractors._ import purescala.Types._ import _root_.smtlib.parser.Terms.{Identifier => SMTIdentifier, Forall => SMTForall, _} import _root_.smtlib.parser.Commands._ import _root_.smtlib.interpreters.CVC4Interpreter +import _root_.smtlib.theories.experimental.Sets trait SMTLIBCVC4Target extends SMTLIBTarget { @@ -27,7 +29,7 @@ trait SMTLIBCVC4Target extends SMTLIBTarget { sorts.cachedB(tpe) { tpe match { case SetType(base) => - Sort(SMTIdentifier(SSymbol("Set")), Seq(declareSort(base))) + Sets.SetSort(declareSort(base)) case _ => super.declareSort(t) @@ -59,12 +61,11 @@ trait SMTLIBCVC4Target extends SMTLIBTarget { case RawArrayType(k, v) => RawArrayValue(k, Map(), fromSMT(elem, v)) - case FunctionType(from, to) => - RawArrayValue(tupleTypeWrap(from), Map(), fromSMT(elem, to)) + case ft @ FunctionType(from, to) => + PartialLambda(Seq.empty, Some(fromSMT(elem, to)), ft) case MapType(k, v) => FiniteMap(Map(), k, v) - } case (FunctionApplication(SimpleSymbol(SSymbol("__array_store_all__")), Seq(_, elem)), Some(tpe)) => @@ -72,12 +73,11 @@ trait SMTLIBCVC4Target extends SMTLIBTarget { case RawArrayType(k, v) => RawArrayValue(k, Map(), fromSMT(elem, v)) - case FunctionType(from, to) => - RawArrayValue(tupleTypeWrap(from), Map(), fromSMT(elem, to)) + case ft @ FunctionType(from, to) => + PartialLambda(Seq.empty, Some(fromSMT(elem, to)), ft) case MapType(k, v) => FiniteMap(Map(), k, v) - } case (FunctionApplication(SimpleSymbol(SSymbol("store")), Seq(arr, key, elem)), Some(tpe)) => @@ -86,9 +86,10 @@ trait SMTLIBCVC4Target extends SMTLIBTarget { val RawArrayValue(k, elems, base) = fromSMT(arr, otpe) RawArrayValue(k, elems + (fromSMT(key, k) -> fromSMT(elem, v)), base) - case FunctionType(_, v) => - val RawArrayValue(k, elems, base) = fromSMT(arr, otpe) - RawArrayValue(k, elems + (fromSMT(key, k) -> fromSMT(elem, v)), base) + case FunctionType(from, v) => + val PartialLambda(mapping, dflt, ft) = fromSMT(arr, otpe) + val args = unwrapTuple(fromSMT(key, tupleTypeWrap(from)), from.size) + PartialLambda(mapping :+ (args -> fromSMT(elem, v)), dflt, ft) case MapType(k, v) => val FiniteMap(elems, k, v) = fromSMT(arr, otpe) @@ -119,33 +120,24 @@ trait SMTLIBCVC4Target extends SMTLIBTarget { */ case fs @ FiniteSet(elems, _) => if (elems.isEmpty) { - QualifiedIdentifier(SMTIdentifier(SSymbol("emptyset")), Some(declareSort(fs.getType))) + Sets.EmptySet(declareSort(fs.getType)) } else { val selems = elems.toSeq.map(toSMT) - val sgt = FunctionApplication(SSymbol("singleton"), Seq(selems.head)) + val sgt = Sets.Singleton(selems.head) if (selems.size > 1) { - FunctionApplication(SSymbol("insert"), selems.tail :+ sgt) + Sets.Insert(selems.tail :+ sgt) } else { sgt } } - case SubsetOf(ss, s) => - FunctionApplication(SSymbol("subset"), Seq(toSMT(ss), toSMT(s))) - - case ElementOfSet(e, s) => - FunctionApplication(SSymbol("member"), Seq(toSMT(e), toSMT(s))) - - case SetDifference(a, b) => - FunctionApplication(SSymbol("setminus"), Seq(toSMT(a), toSMT(b))) - - case SetUnion(a, b) => - FunctionApplication(SSymbol("union"), Seq(toSMT(a), toSMT(b))) - - case SetIntersection(a, b) => - FunctionApplication(SSymbol("intersection"), Seq(toSMT(a), toSMT(b))) + case SubsetOf(ss, s) => Sets.Subset(toSMT(ss), toSMT(s)) + case ElementOfSet(e, s) => Sets.Member(toSMT(e), toSMT(s)) + case SetDifference(a, b) => Sets.Setminus(toSMT(a), toSMT(b)) + case SetUnion(a, b) => Sets.Union(toSMT(a), toSMT(b)) + case SetIntersection(a, b) => Sets.Intersection(toSMT(a), toSMT(b)) case _ => super.toSMT(e) diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBQuantifiedTarget.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBQuantifiedTarget.scala index d49e3316716ae2277af5ea35c6115d076e7029bd..d52ac0704e65d119bf3f81c4a30d58723f9aea81 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBQuantifiedTarget.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBQuantifiedTarget.scala @@ -9,8 +9,6 @@ import purescala.Definitions._ import purescala.Constructors._ import purescala.ExprOps._ -import _root_.smtlib.parser.Commands.{Assert => _, FunDef => _, _} - trait SMTLIBQuantifiedTarget extends SMTLIBTarget { protected var currentFunDef: Option[FunDef] = None @@ -30,14 +28,10 @@ trait SMTLIBQuantifiedTarget extends SMTLIBTarget { val inductiveHyps = for { fi@FunctionInvocation(tfd, args) <- functionCallsOf(cond).toSeq } yield { - val formalToRealArgs = tfd.paramIds.zip(args).toMap - val post = tfd.postcondition map { post => - application( - replaceFromIDs(formalToRealArgs, post), - Seq(fi) - ) - } getOrElse BooleanLiteral(true) - + val post = application( + tfd.withParamSubst(args, tfd.postOrTrue), + Seq(fi) + ) and(tfd.precOrTrue, post) } diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala index f13ed21a2929a0b845e9c0138092bb36b5fc6428..72ecfb9d34039ad03675d2d5523c34462f9de20e 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala @@ -9,7 +9,7 @@ import purescala.Expressions._ import purescala.ExprOps._ import purescala.Definitions._ -import _root_.smtlib.parser.Commands.{Assert => SMTAssert, _} +import _root_.smtlib.parser.Commands.{Assert => SMTAssert, FunDef => SMTFunDef, _} import _root_.smtlib.parser.Terms.{Identifier => SMTIdentifier, _} import _root_.smtlib.parser.CommandsResponses.{Error => ErrorResponse, _} @@ -62,27 +62,37 @@ abstract class SMTLIBSolver(val context: LeonContext, val program: Program) } protected def getModel(filter: Identifier => Boolean): Model = { - val syms = variables.aSet.filter(filter).toList.map(variables.aToB) + val syms = variables.aSet.filter(filter).map(variables.aToB) if (syms.isEmpty) { Model.empty } else { try { - val cmd: Command = GetValue( - syms.head, - syms.tail.map(s => QualifiedIdentifier(SMTIdentifier(s))) - ) + val cmd = GetModel() emit(cmd) match { - case GetValueResponseSuccess(valuationPairs) => + case GetModelResponseSuccess(smodel) => + var modelFunDefs = Map[SSymbol, DefineFun]() - new Model(valuationPairs.collect { - case (SimpleSymbol(sym), value) if variables.containsB(sym) => - val id = variables.toA(sym) + // first-pass to gather functions + for (me <- smodel) me match { + case me @ DefineFun(SMTFunDef(a, args, _, _)) if args.nonEmpty => + modelFunDefs += a -> me + case _ => + } + + var model = Map[Identifier, Expr]() + + for (me <- smodel) me match { + case DefineFun(SMTFunDef(s, args, kind, e)) if syms(s) => + val id = variables.toA(s) + model += id -> fromSMT(e, id.getType)(Map(), modelFunDefs) + case _ => + } + + new Model(model) - (id, fromSMT(value, id.getType)(Map(), Map())) - }.toMap) case _ => - Model.empty //FIXME improve this + Model.empty // FIXME improve this } } catch { case e : SMTLIBUnsupportedError => @@ -100,6 +110,7 @@ abstract class SMTLIBSolver(val context: LeonContext, val program: Program) variables.push() genericValues.push() sorts.push() + lambdas.push() functions.push() errors.push() @@ -113,6 +124,7 @@ abstract class SMTLIBSolver(val context: LeonContext, val program: Program) variables.pop() genericValues.pop() sorts.pop() + lambdas.pop() functions.pop() errors.pop() diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala index 33b3a26d88a660fe737b144766c8dfdfa48ff685..8b24c2053de01c42becd5336dfce113d0043b75c 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala @@ -11,6 +11,7 @@ import purescala.Expressions._ import purescala.Extractors._ import purescala.ExprOps._ import purescala.Types._ +import purescala.TypeOps._ import purescala.Constructors._ import purescala.Definitions._ @@ -18,7 +19,7 @@ import _root_.smtlib.common._ import _root_.smtlib.printer.{ RecursivePrinter => SMTPrinter } import _root_.smtlib.parser.Commands.{ Constructor => SMTConstructor, - FunDef => _, + FunDef => SMTFunDef, Assert => _, _ } @@ -124,6 +125,7 @@ trait SMTLIBTarget extends Interruptible { /* Symbol handling */ protected object SimpleSymbol { + def apply(sym: SSymbol) = QualifiedIdentifier(SMTIdentifier(sym)) def unapply(term: Term): Option[SSymbol] = term match { case QualifiedIdentifier(SMTIdentifier(sym, Seq()), None) => Some(sym) case _ => None @@ -131,9 +133,7 @@ trait SMTLIBTarget extends Interruptible { } import scala.language.implicitConversions - protected implicit def symbolToQualifiedId(s: SSymbol): QualifiedIdentifier = { - QualifiedIdentifier(SMTIdentifier(s)) - } + protected implicit def symbolToQualifiedId(s: SSymbol): QualifiedIdentifier = SimpleSymbol(s) protected val adtManager = new ADTManager(context) @@ -147,14 +147,15 @@ trait SMTLIBTarget extends Interruptible { protected def freshSym(name: String): SSymbol = id2sym(FreshIdentifier(name)) /* Metadata for CC, and variables */ - protected val constructors = new IncrementalBijection[TypeTree, SSymbol]() - protected val selectors = new IncrementalBijection[(TypeTree, Int), SSymbol]() - protected val testers = new IncrementalBijection[TypeTree, SSymbol]() - protected val variables = new IncrementalBijection[Identifier, SSymbol]() + protected val constructors = new IncrementalBijection[TypeTree, SSymbol]() + protected val selectors = new IncrementalBijection[(TypeTree, Int), SSymbol]() + protected val testers = new IncrementalBijection[TypeTree, SSymbol]() + protected val variables = new IncrementalBijection[Identifier, SSymbol]() protected val genericValues = new IncrementalBijection[GenericValue, SSymbol]() - protected val sorts = new IncrementalBijection[TypeTree, Sort]() - protected val functions = new IncrementalBijection[TypedFunDef, SSymbol]() - protected val errors = new IncrementalBijection[Unit, Boolean]() + protected val sorts = new IncrementalBijection[TypeTree, Sort]() + protected val functions = new IncrementalBijection[TypedFunDef, SSymbol]() + protected val lambdas = new IncrementalBijection[FunctionType, SSymbol]() + protected val errors = new IncrementalBijection[Unit, Boolean]() protected def hasError = errors.getB(()) contains true protected def addError() = errors += () -> true @@ -202,7 +203,8 @@ trait SMTLIBTarget extends Interruptible { r case ft @ FunctionType(from, to) => - r + val elems = r.elems.toSeq.map { case (k, v) => unwrapTuple(k, from.size) -> v } + PartialLambda(elems, Some(r.default), ft) case MapType(from, to) => // We expect a RawArrayValue with keys in from and values in Option[to], @@ -246,7 +248,7 @@ trait SMTLIBTarget extends Interruptible { declareSort(RawArrayType(from, library.optionType(to))) case FunctionType(from, to) => - Sort(SMTIdentifier(SSymbol("Array")), Seq(declareSort(tupleTypeWrap(from)), declareSort(to))) + Ints.IntSort() case tp: TypeParameter => declareUninterpretedSort(tp) @@ -329,6 +331,20 @@ trait SMTLIBTarget extends Interruptible { } } + protected def declareLambda(tpe: FunctionType): SSymbol = { + val realTpe = bestRealType(tpe).asInstanceOf[FunctionType] + lambdas.cachedB(realTpe) { + val id = FreshIdentifier("dynLambda") + val s = id2sym(id) + emit(DeclareFun( + s, + (realTpe +: realTpe.from).map(declareSort), + declareSort(realTpe.to) + )) + s + } + } + /* Translate a Leon Expr to an SMTLIB term */ def sortToSMT(s: Sort): SExpr = { @@ -522,7 +538,10 @@ trait SMTLIBTarget extends Interruptible { * ===== Everything else ===== */ case ap @ Application(caller, args) => - ArraysEx.Select(toSMT(caller), toSMT(tupleWrap(args))) + FunctionApplication( + declareLambda(caller.getType.asInstanceOf[FunctionType]), + (caller +: args).map(toSMT) + ) case Not(u) => Core.Not(toSMT(u)) case UMinus(u) => Ints.Neg(toSMT(u)) @@ -611,6 +630,92 @@ trait SMTLIBTarget extends Interruptible { /* Translate an SMTLIB term back to a Leon Expr */ protected def fromSMT(t: Term, otpe: Option[TypeTree] = None)(implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = { + object EQ { + def unapply(t: Term): Option[(Term, Term)] = t match { + case Core.Equals(e1, e2) => Some((e1, e2)) + case FunctionApplication(f, Seq(e1, e2)) if f.toString == "=" => Some((e1, e2)) + case _ => None + } + } + + object AND { + def unapply(t: Term): Option[Seq[Term]] = t match { + case Core.And(e1, e2) => Some(Seq(e1, e2)) + case FunctionApplication(SimpleSymbol(SSymbol("and")), args) => Some(args) + case _ => None + } + def apply(ts: Seq[Term]): Term = ts match { + case Seq() => throw new IllegalArgumentException + case Seq(t) => t + case _ => FunctionApplication(SimpleSymbol(SSymbol("and")), ts) + } + } + + object Num { + def unapply(t: Term): Option[BigInt] = t match { + case SNumeral(n) => Some(n) + case FunctionApplication(f, Seq(SNumeral(n))) if f.toString == "-" => Some(-n) + case _ => None + } + } + + def extractLambda(n: BigInt, ft: FunctionType): Expr = { + val FunctionType(from, to) = ft + lambdas.getB(ft) match { + case None => simplestValue(ft) + case Some(dynLambda) => letDefs.get(dynLambda) match { + case None => simplestValue(ft) + case Some(DefineFun(SMTFunDef(a, SortedVar(dispatcher, dkind) +: args, rkind, body))) => + val lambdaArgs = from.map(tpe => FreshIdentifier("x", tpe, true)) + val argsMap: Map[Term, Identifier] = (args.map(sv => symbolToQualifiedId(sv.name)) zip lambdaArgs).toMap + + val d = symbolToQualifiedId(dispatcher) + def dispatch(t: Term): Term = t match { + case Core.ITE(EQ(di, Num(ni)), thenn, elze) if di == d => + if (ni == n) thenn else dispatch(elze) + case Core.ITE(AND(EQ(di, Num(ni)) +: rest), thenn, elze) if di == d => + if (ni == n) Core.ITE(AND(rest), thenn, dispatch(elze)) else dispatch(elze) + case _ => t + } + + def extract(t: Term): Expr = { + def recCond(term: Term): Seq[Expr] = term match { + case AND(es) => + es.foldLeft(Seq.empty[Expr]) { + case (seq, e) => seq ++ recCond(e) + } + case EQ(e1, e2) => + argsMap.get(e1).map(l => l -> e2) orElse argsMap.get(e2).map(l => l -> e1) match { + case Some((lambdaArg, term)) => Seq(Equals(lambdaArg.toVariable, fromSMT(term, lambdaArg.getType))) + case _ => Seq.empty + } + case arg => + argsMap.get(arg) match { + case Some(lambdaArg) => Seq(lambdaArg.toVariable) + case _ => Seq.empty + } + } + + def recCases(term: Term): Expr = term match { + case Core.ITE(cond, thenn, elze) => + IfExpr(andJoin(recCond(cond)), recCases(thenn), recCases(elze)) + case AND(es) if to == BooleanType => + andJoin(recCond(term)) + case EQ(e1, e2) if to == BooleanType => + andJoin(recCond(term)) + case _ => + fromSMT(term, to) + } + + val body = recCases(t) + Lambda(lambdaArgs.map(ValDef(_)), body) + } + + extract(dispatch(body)) + } + } + } + // Use as much information as there is, if there is an expected type, great, but it might not always be there (t, otpe) match { case (_, Some(UnitType)) => @@ -644,6 +749,9 @@ trait SMTLIBTarget extends Interruptible { } } + case (Num(n), Some(ft: FunctionType)) => + extractLambda(n, ft) + case (SNumeral(n), Some(RealType)) => FractionalLiteral(n, 1) @@ -665,6 +773,7 @@ trait SMTLIBTarget extends Interruptible { }.toMap fromSMT(body, tpe)(lets ++ defsMap, letDefs) + case (SimpleSymbol(s), _) if constructors.containsB(s) => constructors.toA(s) match { case cct: CaseClassType => diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala index eac8f47f948f5eceb8c728c900c15d31fc8fd310..3d4a06a838a5057a693d85754bb5113e1ce7d0ae 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala @@ -76,7 +76,6 @@ trait SMTLIBZ3Target extends SMTLIBTarget { val n = s.name.split("!").toList.last GenericValue(tp, n.toInt) - case (QualifiedIdentifier(ExtendedIdentifier(SSymbol("as-array"), k: SSymbol), _), Some(tpe)) => if (letDefs contains k) { // Need to recover value form function model diff --git a/src/main/scala/leon/solvers/templates/LambdaManager.scala b/src/main/scala/leon/solvers/templates/LambdaManager.scala index 3d5eec72c809a7ba9459b4b46752835b63bd6011..75d163815a17fbf31c56b684d58e3d26efe7feaa 100644 --- a/src/main/scala/leon/solvers/templates/LambdaManager.scala +++ b/src/main/scala/leon/solvers/templates/LambdaManager.scala @@ -6,93 +6,273 @@ package templates import purescala.Common._ import purescala.Expressions._ +import purescala.Extractors._ import purescala.ExprOps._ import purescala.Types._ import utils._ import Instantiation._ -class LambdaManager[T](protected val encoder: TemplateEncoder[T]) extends IncrementalState { +case class App[T](caller: T, tpe: FunctionType, args: Seq[T]) { + override def toString = "(" + caller + " : " + tpe + ")" + args.mkString("(", ",", ")") +} + +object LambdaTemplate { + + def apply[T]( + ids: (Identifier, T), + encoder: TemplateEncoder[T], + manager: QuantificationManager[T], + pathVar: (Identifier, T), + arguments: Seq[(Identifier, T)], + condVars: Map[Identifier, T], + exprVars: Map[Identifier, T], + condTree: Map[Identifier, Set[Identifier]], + guardedExprs: Map[Identifier, Seq[Expr]], + quantifications: Seq[QuantificationTemplate[T]], + lambdas: Seq[LambdaTemplate[T]], + baseSubstMap: Map[Identifier, T], + dependencies: Map[Identifier, T], + lambda: Lambda + ) : LambdaTemplate[T] = { + + val id = ids._2 + val tpe = ids._1.getType.asInstanceOf[FunctionType] + val (clauses, blockers, applications, matchers, templateString) = + Template.encode(encoder, pathVar, arguments, condVars, exprVars, guardedExprs, lambdas, + substMap = baseSubstMap + ids, optApp = Some(id -> tpe)) + + val lambdaString : () => String = () => { + "Template for lambda " + ids._1 + ": " + lambda + " is :\n" + templateString() + } + + val (structuralLambda, structSubst) = normalizeStructure(lambda) + val keyDeps = dependencies.map { case (id, idT) => structSubst(id) -> idT } + val key = structuralLambda.asInstanceOf[Lambda] + + new LambdaTemplate[T]( + ids, + encoder, + manager, + pathVar, + arguments, + condVars, + exprVars, + condTree, + clauses, + blockers, + applications, + quantifications, + matchers, + lambdas, + keyDeps, + key, + lambdaString + ) + } +} + +class LambdaTemplate[T] private ( + val ids: (Identifier, T), + val encoder: TemplateEncoder[T], + val manager: QuantificationManager[T], + val pathVar: (Identifier, T), + val arguments: Seq[(Identifier, T)], + val condVars: Map[Identifier, T], + val exprVars: Map[Identifier, T], + val condTree: Map[Identifier, Set[Identifier]], + val clauses: Seq[T], + val blockers: Map[T, Set[TemplateCallInfo[T]]], + val applications: Map[T, Set[App[T]]], + val quantifications: Seq[QuantificationTemplate[T]], + val matchers: Map[T, Set[Matcher[T]]], + val lambdas: Seq[LambdaTemplate[T]], + private[templates] val dependencies: Map[Identifier, T], + private[templates] val structuralKey: Lambda, + stringRepr: () => String) extends Template[T] { + + val args = arguments.map(_._2) + val tpe = ids._1.getType.asInstanceOf[FunctionType] + + def substitute(substituter: T => T): LambdaTemplate[T] = { + val newStart = substituter(start) + val newClauses = clauses.map(substituter) + val newBlockers = blockers.map { case (b, fis) => + val bp = if (b == start) newStart else b + bp -> fis.map(fi => fi.copy(args = fi.args.map(substituter))) + } + + val newApplications = applications.map { case (b, fas) => + val bp = if (b == start) newStart else b + bp -> fas.map(fa => fa.copy(caller = substituter(fa.caller), args = fa.args.map(substituter))) + } + + val newQuantifications = quantifications.map(_.substitute(substituter)) + + val newMatchers = matchers.map { case (b, ms) => + val bp = if (b == start) newStart else b + bp -> ms.map(_.substitute(substituter)) + } + + val newLambdas = lambdas.map(_.substitute(substituter)) + + val newDependencies = dependencies.map(p => p._1 -> substituter(p._2)) + + new LambdaTemplate[T]( + ids._1 -> substituter(ids._2), + encoder, + manager, + pathVar._1 -> newStart, + arguments, + condVars, + exprVars, + condTree, + newClauses, + newBlockers, + newApplications, + newQuantifications, + newMatchers, + newLambdas, + newDependencies, + structuralKey, + stringRepr + ) + } + + private lazy val str : String = stringRepr() + override def toString : String = str + + lazy val key: (Expr, Seq[T]) = { + def rec(e: Expr): Seq[Identifier] = e match { + case Variable(id) => + if (dependencies.isDefinedAt(id)) { + Seq(id) + } else { + Seq.empty + } + + case Operator(es, _) => es.flatMap(rec) - protected val byID = new IncrementalMap[T, LambdaTemplate[T]] - protected val byType = new IncrementalMap[FunctionType, Set[(T, LambdaTemplate[T])]].withDefaultValue(Set.empty) - protected val applications = new IncrementalMap[FunctionType, Set[(T, App[T])]].withDefaultValue(Set.empty) - protected val freeLambdas = new IncrementalMap[FunctionType, Set[T]].withDefaultValue(Set.empty) + case _ => Seq.empty + } + + structuralKey -> rec(structuralKey).distinct.map(dependencies) + } + + override def equals(that: Any): Boolean = that match { + case t: LambdaTemplate[T] => + val (lambda1, deps1) = key + val (lambda2, deps2) = t.key + (lambda1 == lambda2) && { + (deps1 zip deps2).forall { case (id1, id2) => + (manager.byID.get(id1), manager.byID.get(id2)) match { + case (Some(t1), Some(t2)) => t1 == t2 + case _ => id1 == id2 + } + } + } + + case _ => false + } + + override def hashCode: Int = key.hashCode + + override def instantiate(substMap: Map[T, T]): Instantiation[T] = { + super.instantiate(substMap) ++ manager.instantiateAxiom(this, substMap) + } +} + +class LambdaManager[T](encoder: TemplateEncoder[T]) extends TemplateManager(encoder) { + private[templates] lazy val trueT = encoder.encodeExpr(Map.empty)(BooleanLiteral(true)) + + protected[templates] val byID = new IncrementalMap[T, LambdaTemplate[T]] + protected val byType = new IncrementalMap[FunctionType, Set[LambdaTemplate[T]]].withDefaultValue(Set.empty) + protected val applications = new IncrementalMap[FunctionType, Set[(T, App[T])]].withDefaultValue(Set.empty) + protected val freeLambdas = new IncrementalMap[FunctionType, Set[T]].withDefaultValue(Set.empty) - protected def incrementals: List[IncrementalState] = - List(byID, byType, applications, freeLambdas) + private val instantiated = new IncrementalSet[(T, App[T])] - def clear(): Unit = incrementals.foreach(_.clear()) - def reset(): Unit = incrementals.foreach(_.reset()) - def push(): Unit = incrementals.foreach(_.push()) - def pop(): Unit = incrementals.foreach(_.pop()) + override protected def incrementals: List[IncrementalState] = + super.incrementals ++ List(byID, byType, applications, freeLambdas, instantiated) - def registerFree(lambdas: Seq[(TypeTree, T)]): Unit = { - for ((tpe, idT) <- lambdas) tpe match { + def registerFree(lambdas: Seq[(Identifier, T)]): Unit = { + for ((id, idT) <- lambdas) id.getType match { case ft: FunctionType => freeLambdas += ft -> (freeLambdas(ft) + idT) case _ => } } - def instantiateLambda(idT: T, template: LambdaTemplate[T]): Instantiation[T] = { - var clauses : Clauses[T] = equalityClauses(idT, template) + def instantiateLambda(template: LambdaTemplate[T]): Instantiation[T] = { + val idT = template.ids._2 + var clauses : Clauses[T] = equalityClauses(template) var appBlockers : AppBlockers[T] = Map.empty.withDefaultValue(Set.empty) // make sure the new lambda isn't equal to any free lambda var clauses ++= freeLambdas(template.tpe).map(pIdT => encoder.mkNot(encoder.mkEquals(pIdT, idT))) byID += idT -> template - byType += template.tpe -> (byType(template.tpe) + (idT -> template)) - for (blockedApp @ (_, App(caller, tpe, args)) <- applications(template.tpe)) { - val equals = encoder.mkEquals(idT, caller) - appBlockers += (blockedApp -> (appBlockers(blockedApp) + TemplateAppInfo(template, equals, args))) - } + if (byType(template.tpe)(template)) { + (clauses, Map.empty, Map.empty) + } else { + byType += template.tpe -> (byType(template.tpe) + template) + + for (blockedApp @ (_, App(caller, tpe, args)) <- applications(template.tpe)) { + val equals = encoder.mkEquals(idT, caller) + appBlockers += (blockedApp -> (appBlockers(blockedApp) + TemplateAppInfo(template, equals, args))) + } - (clauses, Map.empty, appBlockers) + (clauses, Map.empty, appBlockers) + } } def instantiateApp(blocker: T, app: App[T]): Instantiation[T] = { val App(caller, tpe, args) = app - var clauses : Clauses[T] = Seq.empty - var callBlockers : CallBlockers[T] = Map.empty.withDefaultValue(Set.empty) - var appBlockers : AppBlockers[T] = Map.empty.withDefaultValue(Set.empty) - - if (byID contains caller) { - val (newClauses, newCalls, newApps) = byID(caller).instantiate(blocker, args) + val instantiation = Instantiation.empty[T] - clauses ++= newClauses - newCalls.foreach(p => callBlockers += p._1 -> (callBlockers(p._1) ++ p._2)) - newApps.foreach(p => appBlockers += p._1 -> (appBlockers(p._1) ++ p._2)) - } else if (!freeLambdas(tpe).contains(caller)) { + if (freeLambdas(tpe).contains(caller)) instantiation else { val key = blocker -> app - // make sure that even if byType(tpe) is empty, app is recorded in blockers - // so that UnrollingBank will generate the initial block! - if (!(appBlockers contains key)) appBlockers += key -> Set.empty + if (instantiated(key)) instantiation else { + instantiated += key - for ((idT,template) <- byType(tpe)) { - val equals = encoder.mkEquals(idT, caller) - appBlockers += (key -> (appBlockers(key) + TemplateAppInfo(template, equals, args))) - } + if (byID contains caller) { + instantiation withApp (key -> TemplateAppInfo(byID(caller), trueT, args)) + } else { - applications += tpe -> (applications(tpe) + key) - } + // make sure that even if byType(tpe) is empty, app is recorded in blockers + // so that UnrollingBank will generate the initial block! + val init = instantiation withApps Map(key -> Set.empty) + val inst = byType(tpe).foldLeft(init) { + case (instantiation, template) => + val equals = encoder.mkEquals(template.ids._2, caller) + instantiation withApp (key -> TemplateAppInfo(template, equals, args)) + } + + applications += tpe -> (applications(tpe) + key) - (clauses, callBlockers, appBlockers) + inst + } + } + } } - private def equalityClauses(idT: T, template: LambdaTemplate[T]): Seq[T] = { - byType(template.tpe).map { case (thatIdT, that) => - val equals = encoder.mkEquals(idT, thatIdT) - template.contextEquality(that) match { - case None => encoder.mkNot(equals) - case Some(Seq()) => equals - case Some(seq) => encoder.mkEquals(encoder.mkAnd(seq : _*), equals) + private def equalityClauses(template: LambdaTemplate[T]): Seq[T] = { + val (s1, deps1) = template.key + byType(template.tpe).map { that => + val (s2, deps2) = that.key + val equals = encoder.mkEquals(template.ids._2, that.ids._2) + if (s1 == s2) { + val pairs = (deps1 zip deps2).filter(p => p._1 != p._2) + if (pairs.isEmpty) equals else { + val eqs = pairs.map(p => encoder.mkEquals(p._1, p._2)) + encoder.mkEquals(encoder.mkAnd(eqs : _*), equals) + } + } else { + encoder.mkNot(equals) } }.toSeq } - } diff --git a/src/main/scala/leon/solvers/templates/QuantificationManager.scala b/src/main/scala/leon/solvers/templates/QuantificationManager.scala index fde9dc746b4a5e6207fc3ed896bbbd5700cbc79a..c7f715b5eefe573d88eeb94b45293ba7cde646bf 100644 --- a/src/main/scala/leon/solvers/templates/QuantificationManager.scala +++ b/src/main/scala/leon/solvers/templates/QuantificationManager.scala @@ -11,10 +11,11 @@ import purescala.Constructors._ import purescala.Expressions._ import purescala.ExprOps._ import purescala.Types._ +import purescala.Quantification.{QuantificationTypeMatcher => QTM} import Instantiation._ -import scala.collection.mutable.{Map => MutableMap, Set => MutableSet} +import scala.collection.mutable.{Map => MutableMap, Set => MutableSet, Stack => MutableStack, Queue} object Matcher { def argValue[T](arg: Either[T, Matcher[T]]): T = arg match { @@ -24,21 +25,27 @@ object Matcher { } case class Matcher[T](caller: T, tpe: TypeTree, args: Seq[Either[T, Matcher[T]]], encoded: T) { - override def toString = "M(" + caller + " : " + tpe + ", " + args.map(Matcher.argValue).mkString("(",",",")") + ")" + override def toString = caller + args.map { + case Right(m) => m.toString + case Left(v) => v.toString + }.mkString("(",",",")") - def substitute(substituter: T => T): Matcher[T] = copy( + def substitute(substituter: T => T, matcherSubst: Map[T, Matcher[T]] = Map.empty): Matcher[T] = copy( caller = substituter(caller), - args = args.map { arg => arg.left.map(substituter).right.map(_.substitute(substituter)) }, + args = args.map { + case Left(v) => matcherSubst.get(v) match { + case Some(m) => Right(m) + case None => Left(substituter(v)) + } + case Right(m) => Right(m.substitute(substituter, matcherSubst)) + }, encoded = substituter(encoded) ) } -case class IllegalQuantificationException(expr: Expr, msg: String) - extends Exception(msg +" @ " + expr) - class QuantificationTemplate[T]( val quantificationManager: QuantificationManager[T], - val start: T, + val pathVar: (Identifier, T), val qs: (Identifier, T), val q2s: (Identifier, T), val insts: (Identifier, T), @@ -46,14 +53,39 @@ class QuantificationTemplate[T]( val quantifiers: Seq[(Identifier, T)], val condVars: Map[Identifier, T], val exprVars: Map[Identifier, T], + val condTree: Map[Identifier, Set[Identifier]], val clauses: Seq[T], val blockers: Map[T, Set[TemplateCallInfo[T]]], val applications: Map[T, Set[App[T]]], val matchers: Map[T, Set[Matcher[T]]], - val lambdas: Map[T, LambdaTemplate[T]]) { - - def instantiate(substMap: Map[T, T]): Instantiation[T] = { - quantificationManager.instantiateQuantification(this, substMap) + val lambdas: Seq[LambdaTemplate[T]]) { + + lazy val start = pathVar._2 + + def substitute(substituter: T => T): QuantificationTemplate[T] = { + new QuantificationTemplate[T]( + quantificationManager, + pathVar._1 -> substituter(start), + qs, + q2s, + insts, + guardVar, + quantifiers, + condVars, + exprVars, + condTree, + clauses.map(substituter), + blockers.map { case (b, fis) => + substituter(b) -> fis.map(fi => fi.copy(args = fi.args.map(substituter))) + }, + applications.map { case (b, apps) => + substituter(b) -> apps.map(app => app.copy(caller = substituter(app.caller), args = app.args.map(substituter))) + }, + matchers.map { case (b, ms) => + substituter(b) -> ms.map(_.substitute(substituter)) + }, + lambdas.map(_.substitute(substituter)) + ) } } @@ -69,8 +101,9 @@ object QuantificationTemplate { quantifiers: Seq[(Identifier, T)], condVars: Map[Identifier, T], exprVars: Map[Identifier, T], + condTree: Map[Identifier, Set[Identifier]], guardedExprs: Map[Identifier, Seq[Expr]], - lambdas: Map[T, LambdaTemplate[T]], + lambdas: Seq[LambdaTemplate[T]], subst: Map[Identifier, T] ): QuantificationTemplate[T] = { @@ -83,245 +116,564 @@ object QuantificationTemplate { substMap = subst + q2s + insts + guards + qs) new QuantificationTemplate[T](quantificationManager, - pathVar._2, qs, q2s, insts, guards._2, quantifiers, - condVars, exprVars, clauses, blockers, applications, matchers, lambdas) + pathVar, qs, q2s, insts, guards._2, quantifiers, + condVars, exprVars, condTree, clauses, blockers, applications, matchers, lambdas) } } class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManager[T](encoder) { - private lazy val trueT = encoder.encodeExpr(Map.empty)(BooleanLiteral(true)) + private val quantifications = new IncrementalSeq[MatcherQuantification] + private val instCtx = new InstantiationContext + + private val handled = new ContextMap + private val ignored = new ContextMap - private val quantifications = new IncrementalSeq[Quantification] - private val instantiated = new IncrementalSet[(T, Matcher[T])] - private val fInsts = new IncrementalSet[Matcher[T]] private val known = new IncrementalSet[T] + private val lambdaAxioms = new IncrementalSet[(LambdaTemplate[T], Seq[(Identifier, T)])] - private def fInstantiated = fInsts.map(m => trueT -> m) + override protected def incrementals: List[IncrementalState] = + List(quantifications, instCtx, handled, ignored, known, lambdaAxioms) ++ super.incrementals + + private sealed abstract class MatcherKey(val tpe: TypeTree) + private case class CallerKey(caller: T, tt: TypeTree) extends MatcherKey(tt) + private case class LambdaKey(lambda: Lambda, tt: TypeTree) extends MatcherKey(tt) + private case class TypeKey(tt: TypeTree) extends MatcherKey(tt) - private def correspond(qm: Matcher[T], m: Matcher[T]): Boolean = correspond(qm, m.caller, m.tpe) - private def correspond(qm: Matcher[T], caller: T, tpe: TypeTree): Boolean = qm.tpe match { - case _: FunctionType => qm.tpe == tpe && (qm.caller == caller || !known(caller)) - case _ => qm.tpe == tpe + private def matcherKey(caller: T, tpe: TypeTree): MatcherKey = tpe match { + case _: FunctionType if known(caller) => CallerKey(caller, tpe) + case _: FunctionType if byID.isDefinedAt(caller) => LambdaKey(byID(caller).structuralKey, tpe) + case _ => TypeKey(tpe) } - private val uniformQuantifiers = scala.collection.mutable.Map.empty[TypeTree, Seq[T]] + @inline + private def correspond(qm: Matcher[T], m: Matcher[T]): Boolean = + correspond(qm, m.caller, m.tpe) + + @inline + private def correspond(qm: Matcher[T], caller: T, tpe: TypeTree): Boolean = + matcherKey(qm.caller, qm.tpe) == matcherKey(caller, tpe) + + private val uniformQuantMap: MutableMap[TypeTree, Seq[T]] = MutableMap.empty + private val uniformQuantSet: MutableSet[T] = MutableSet.empty + + def isQuantifier(idT: T): Boolean = uniformQuantSet(idT) private def uniformSubst(qs: Seq[(Identifier, T)]): Map[T, T] = { qs.groupBy(_._1.getType).flatMap { case (tpe, qst) => - val prev = uniformQuantifiers.get(tpe) match { + val prev = uniformQuantMap.get(tpe) match { case Some(seq) => seq case None => Seq.empty } if (prev.size >= qst.size) { - qst.map(_._2) zip prev.take(qst.size - 1) + qst.map(_._2) zip prev.take(qst.size) } else { val (handled, newQs) = qst.splitAt(prev.size) val uQs = newQs.map(p => p._2 -> encoder.encodeId(p._1)) - uniformQuantifiers(tpe) = prev ++ uQs.map(_._2) + + uniformQuantMap(tpe) = prev ++ uQs.map(_._2) + uniformQuantSet ++= uQs.map(_._2) + (handled.map(_._2) zip prev) ++ uQs } }.toMap } - override protected def incrementals: List[IncrementalState] = - List(quantifications, instantiated, fInsts, known) ++ super.incrementals + def assumptions: Seq[T] = quantifications.collect { case q: Quantification => q.currentQ2Var }.toSeq - def assumptions: Seq[T] = quantifications.map(_.currentQ2Var).toSeq + def instantiations: (Map[TypeTree, Matchers], Map[T, Matchers], Map[Lambda, Matchers]) = { + var typeInsts: Map[TypeTree, Matchers] = Map.empty + var partialInsts: Map[T, Matchers] = Map.empty + var lambdaInsts: Map[Lambda, Matchers] = Map.empty - def instantiations: Seq[(T, Matcher[T])] = instantiated.toSeq ++ fInstantiated + val instantiations = handled.instantiations ++ instCtx.map.instantiations + for ((key, matchers) <- instantiations) key match { + case TypeKey(tpe) => typeInsts += tpe -> matchers + case CallerKey(caller, _) => partialInsts += caller -> matchers + case LambdaKey(lambda, _) => lambdaInsts += lambda -> matchers + } - def instantiations(caller: T, tpe: TypeTree): Seq[(T, Matcher[T])] = - instantiations.filter { case (b,m) => correspond(m, caller, tpe) } + (typeInsts, partialInsts, lambdaInsts) + } - override def registerFree(ids: Seq[(TypeTree, T)]): Unit = { + override def registerFree(ids: Seq[(Identifier, T)]): Unit = { super.registerFree(ids) known ++= ids.map(_._2) } - private class Quantification ( - val qs: (Identifier, T), - val q2s: (Identifier, T), - val insts: (Identifier, T), - val guardVar: T, - val quantified: Set[T], - val matchers: Set[Matcher[T]], - val allMatchers: Map[T, Set[Matcher[T]]], - val condVars: Map[Identifier, T], - val exprVars: Map[Identifier, T], - val clauses: Seq[T], - val blockers: Map[T, Set[TemplateCallInfo[T]]], - val applications: Map[T, Set[App[T]]], - val lambdas: Map[T, LambdaTemplate[T]]) { + private def matcherDepth(m: Matcher[T]): Int = 1 + (0 +: m.args.map { + case Right(ma) => matcherDepth(ma) + case _ => 0 + }).max - var currentQ2Var: T = qs._2 - private var slaves: Seq[(T, Map[T,T], Quantification)] = Nil - - private def mappings(blocker: T, matcher: Matcher[T]) - (implicit instantiated: Iterable[(T, Matcher[T])]): Set[(T, Map[T, T])] = { - - // Build a mapping from applications in the quantified statement to all potential concrete - // applications previously encountered. Also make sure the current `app` is in the mapping - // as other instantiations have been performed previously when the associated applications - // were first encountered. - val matcherMappings: Set[Set[(T, Matcher[T], Matcher[T])]] = matchers - // 1. select an application in the quantified proposition for which the current app can - // be bound when generating the new constraints - .filter(qm => correspond(qm, matcher)) - // 2. build the instantiation mapping associated to the chosen current application binding - .flatMap { bindingMatcher => + private def encodeEnablers(es: Set[T]): T = + if (es.isEmpty) trueT else encoder.mkAnd(es.toSeq.sortBy(_.toString) : _*) - // 2.1. select all potential matches for each quantified application - val matcherToInstances = matchers - .map(qm => if (qm == bindingMatcher) { - bindingMatcher -> Set(blocker -> matcher) - } else { - val instances: Set[(T, Matcher[T])] = instantiated.filter { case (b, m) => correspond(qm, m) }.toSet + private type Matchers = Set[(T, Matcher[T])] - // concrete applications can appear multiple times in the constraint, and this is also the case - // for the current application for which we are generating the constraints - val withCurrent = if (correspond(qm, matcher)) instances + (blocker -> matcher) else instances + private class Context private(ctx: Map[Matcher[T], Set[Set[T]]]) extends Iterable[(Set[T], Matcher[T])] { + def this() = this(Map.empty) - qm -> withCurrent - }).toMap + def apply(p: (Set[T], Matcher[T])): Boolean = ctx.get(p._2) match { + case None => false + case Some(blockerSets) => blockerSets(p._1) || blockerSets.exists(set => set.subsetOf(p._1)) + } - // 2.2. based on the possible bindings for each quantified application, build a set of - // instantiation mappings that can be used to instantiate all necessary constraints - extractMappings(matcherToInstances) - } + def +(p: (Set[T], Matcher[T])): Context = if (apply(p)) this else { + val prev = ctx.getOrElse(p._2, Seq.empty) + val newSet = prev.filterNot(set => p._1.subsetOf(set)).toSet + p._1 + new Context(ctx + (p._2 -> newSet)) + } + + def ++(that: Context): Context = that.foldLeft(this)((ctx, p) => ctx + p) - for (mapping <- matcherMappings) yield extractSubst(quantified, mapping) + def iterator = ctx.toSeq.flatMap { case (m, bss) => bss.map(bs => bs -> m) }.iterator + def toMatchers: Matchers = this.map(p => encodeEnablers(p._1) -> p._2).toSet + } + + private class ContextMap( + private var tpeMap: MutableMap[TypeTree, Context] = MutableMap.empty, + private var funMap: MutableMap[MatcherKey, Context] = MutableMap.empty + ) extends IncrementalState { + private val stack = new MutableStack[(MutableMap[TypeTree, Context], MutableMap[MatcherKey, Context])] + + def clear(): Unit = { + stack.clear() + tpeMap.clear() + funMap.clear() } - private def extractSlaveInfo(enabler: T, senabler: T, subst: Map[T,T], ssubst: Map[T,T]): (T, Map[T,T]) = { - val substituter = encoder.substitute(subst) - val slaveEnabler = encoder.mkAnd(enabler, substituter(senabler)) - val slaveSubst = ssubst.map(p => p._1 -> substituter(p._2)) - (slaveEnabler, slaveSubst) + def reset(): Unit = clear() + + def push(): Unit = { + stack.push((tpeMap, funMap)) + tpeMap = tpeMap.clone + funMap = funMap.clone } - private def instantiate(enabler: T, subst: Map[T, T], seen: Set[Quantification]): Instantiation[T] = { - if (seen(this)) { - Instantiation.empty[T] - } else { - val nextQ2Var = encoder.encodeId(q2s._1) + def pop(): Unit = { + val (ptpeMap, pfunMap) = stack.pop() + tpeMap = ptpeMap + funMap = pfunMap + } + + def +=(p: (Set[T], Matcher[T])): Unit = matcherKey(p._2.caller, p._2.tpe) match { + case TypeKey(tpe) => tpeMap(tpe) = tpeMap.getOrElse(tpe, new Context) + p + case key => funMap(key) = funMap.getOrElse(key, new Context) + p + } + + def merge(that: ContextMap): this.type = { + for ((tpe, values) <- that.tpeMap) tpeMap(tpe) = tpeMap.getOrElse(tpe, new Context) ++ values + for ((caller, values) <- that.funMap) funMap(caller) = funMap.getOrElse(caller, new Context) ++ values + this + } + + def get(caller: T, tpe: TypeTree): Context = + funMap.getOrElse(matcherKey(caller, tpe), new Context) ++ tpeMap.getOrElse(tpe, new Context) - val baseSubstMap = (condVars ++ exprVars).map { case (id, idT) => idT -> encoder.encodeId(id) } - val lambdaSubstMap = lambdas map { case (idT, lambda) => idT -> encoder.encodeId(lambda.id) } - val substMap = subst ++ baseSubstMap ++ lambdaSubstMap + - (qs._2 -> currentQ2Var) + (guardVar -> enabler) + (q2s._2 -> nextQ2Var) + - (insts._2 -> encoder.encodeId(insts._1)) + def get(key: MatcherKey): Context = key match { + case TypeKey(tpe) => tpeMap.getOrElse(tpe, new Context) + case key => funMap.getOrElse(key, new Context) + } + + def instantiations: Map[MatcherKey, Matchers] = + (funMap.toMap ++ tpeMap.map { case (tpe,ms) => TypeKey(tpe) -> ms }).mapValues(_.toMatchers) + } + + private class InstantiationContext private ( + private var _instantiated : Context, val map : ContextMap + ) extends IncrementalState { + + private val stack = new MutableStack[Context] + + def this() = this(new Context, new ContextMap) + + def clear(): Unit = { + stack.clear() + map.clear() + _instantiated = new Context + } + + def reset(): Unit = clear() + + def push(): Unit = { + stack.push(_instantiated) + map.push() + } + + def pop(): Unit = { + _instantiated = stack.pop() + map.pop() + } - var instantiation = Template.instantiate(encoder, QuantificationManager.this, - clauses, blockers, applications, Seq.empty, Map.empty[T, Set[Matcher[T]]], lambdas, substMap) + def instantiated: Context = _instantiated + def apply(p: (Set[T], Matcher[T])): Boolean = _instantiated(p) - for { - (senabler, ssubst, slave) <- slaves - (slaveEnabler, slaveSubst) = extractSlaveInfo(enabler, senabler, subst, ssubst) - } instantiation ++= slave.instantiate(slaveEnabler, slaveSubst, seen + this) + def corresponding(m: Matcher[T]): Context = map.get(m.caller, m.tpe) - currentQ2Var = nextQ2Var + def instantiate(blockers: Set[T], matcher: Matcher[T])(qs: MatcherQuantification*): Instantiation[T] = { + if (this(blockers -> matcher)) { + Instantiation.empty[T] + } else { + map += (blockers -> matcher) + _instantiated += (blockers -> matcher) + var instantiation = Instantiation.empty[T] + for (q <- qs) instantiation ++= q.instantiate(blockers, matcher) instantiation } } - def register(senabler: T, ssubst: Map[T, T], slave: Quantification): Instantiation[T] = { - var instantiation = Instantiation.empty[T] + def merge(that: InstantiationContext): this.type = { + _instantiated ++= that._instantiated + map.merge(that.map) + this + } + } + + private trait MatcherQuantification { + val pathVar: (Identifier, T) + val quantified: Set[T] + val matchers: Set[Matcher[T]] + val allMatchers: Map[T, Set[Matcher[T]]] + val condVars: Map[Identifier, T] + val exprVars: Map[Identifier, T] + val condTree: Map[Identifier, Set[Identifier]] + val clauses: Seq[T] + val blockers: Map[T, Set[TemplateCallInfo[T]]] + val applications: Map[T, Set[App[T]]] + val lambdas: Seq[LambdaTemplate[T]] + + lazy val start = pathVar._2 + + private lazy val depth = matchers.map(matcherDepth).max + private lazy val transMatchers: Set[Matcher[T]] = (for { + (b, ms) <- allMatchers.toSeq + m <- ms if !matchers(m) && matcherDepth(m) <= depth + } yield m).toSet + + /* Build a mapping from applications in the quantified statement to all potential concrete + * applications previously encountered. Also make sure the current `app` is in the mapping + * as other instantiations have been performed previously when the associated applications + * were first encountered. + */ + private def mappings(bs: Set[T], matcher: Matcher[T]): Set[Set[(Set[T], Matcher[T], Matcher[T])]] = { + /* 1. select an application in the quantified proposition for which the current app can + * be bound when generating the new constraints + */ + matchers.filter(qm => correspond(qm, matcher)) + + /* 2. build the instantiation mapping associated to the chosen current application binding */ + .flatMap { bindingMatcher => + + /* 2.1. select all potential matches for each quantified application */ + val matcherToInstances = matchers + .map(qm => if (qm == bindingMatcher) { + bindingMatcher -> Set(bs -> matcher) + } else { + qm -> instCtx.corresponding(qm) + }).toMap + + /* 2.2. based on the possible bindings for each quantified application, build a set of + * instantiation mappings that can be used to instantiate all necessary constraints + */ + val allMappings = matcherToInstances.foldLeft[Set[Set[(Set[T], Matcher[T], Matcher[T])]]](Set(Set.empty)) { + case (mappings, (qm, instances)) => Set(instances.toSeq.flatMap { + case (bs, m) => mappings.map(mapping => mapping + ((bs, qm, m))) + } : _*) + } + + /* 2.3. filter out bindings that don't make sense where abstract sub-matchers + * (matchers in arguments of other matchers) are bound to different concrete + * matchers in the argument and quorum positions + */ + allMappings.filter { s => + def expand(ms: Traversable[(Either[T,Matcher[T]], Either[T,Matcher[T]])]): Set[(Matcher[T], Matcher[T])] = ms.flatMap { + case (Right(qm), Right(m)) => Set(qm -> m) ++ expand(qm.args zip m.args) + case _ => Set.empty[(Matcher[T], Matcher[T])] + }.toSet + + expand(s.map(p => Right(p._2) -> Right(p._3))).groupBy(_._1).forall(_._2.size == 1) + } + + allMappings + } + } + + private def extractSubst(mapping: Set[(Set[T], Matcher[T], Matcher[T])]): (Set[T], Map[T,Either[T, Matcher[T]]], Boolean) = { + var constraints: Set[T] = Set.empty + var eqConstraints: Set[(T, T)] = Set.empty + var matcherEqs: List[(T, T)] = Nil + var subst: Map[T, Either[T, Matcher[T]]] = Map.empty for { - instantiated <- List(instantiated, fInstantiated) - (blocker, matcher) <- instantiated - (enabler, subst) <- mappings(blocker, matcher)(instantiated) - (slaveEnabler, slaveSubst) = extractSlaveInfo(enabler, senabler, subst, ssubst) - } instantiation ++= slave.instantiate(slaveEnabler, slaveSubst, Set(this)) + (bs, qm @ Matcher(qcaller, _, qargs, _), m @ Matcher(caller, _, args, _)) <- mapping + _ = constraints ++= bs + _ = matcherEqs :+= qm.encoded -> m.encoded + (qarg, arg) <- (qargs zip args) + } qarg match { + case Left(quant) if subst.isDefinedAt(quant) => + eqConstraints += (quant -> Matcher.argValue(arg)) + case Left(quant) if quantified(quant) => + subst += quant -> arg + case Right(qam) => + val argVal = Matcher.argValue(arg) + eqConstraints += (qam.encoded -> argVal) + matcherEqs :+= qam.encoded -> argVal + } - slaves :+= (senabler, ssubst, slave) + val substituter = encoder.substitute(subst.mapValues(Matcher.argValue)) + val substConstraints = constraints.filter(_ != trueT).map(substituter) + val substEqs = eqConstraints.map(p => substituter(p._1) -> p._2) + .filter(p => p._1 != p._2).map(p => encoder.mkEquals(p._1, p._2)) + val enablers = substConstraints ++ substEqs + val isStrict = matcherEqs.forall(p => substituter(p._1) == p._2) - instantiation + (enablers, subst, isStrict) } - def instantiate(blocker: T, matcher: Matcher[T])(implicit instantiated: Iterable[(T, Matcher[T])]): Instantiation[T] = { + def instantiate(bs: Set[T], matcher: Matcher[T]): Instantiation[T] = { var instantiation = Instantiation.empty[T] - for ((enabler, subst) <- mappings(blocker, matcher)) { - instantiation ++= instantiate(enabler, subst, Set.empty) + for (mapping <- mappings(bs, matcher)) { + val (enablers, subst, isStrict) = extractSubst(mapping) + val (enabler, optEnabler) = freshBlocker(enablers) + + if (optEnabler.isDefined) { + instantiation = instantiation withClause encoder.mkEquals(enabler, optEnabler.get) + } + + val baseSubstMap = exprVars.map { case (id, idT) => idT -> encoder.encodeId(id) } ++ + freshConds(pathVar._1 -> enabler, condVars, condTree) + val lambdaSubstMap = lambdas map (lambda => lambda.ids._2 -> encoder.encodeId(lambda.ids._1)) + val substMap = subst.mapValues(Matcher.argValue) ++ baseSubstMap ++ lambdaSubstMap ++ instanceSubst(enablers) + + if (!skip(substMap)) { + instantiation ++= Template.instantiate(encoder, QuantificationManager.this, + clauses, blockers, applications, Seq.empty, Map.empty[T, Set[Matcher[T]]], lambdas, substMap) + + val msubst = subst.collect { case (c, Right(m)) => c -> m } + val substituter = encoder.substitute(substMap) + + for ((b,ms) <- allMatchers; m <- ms) { + val sb = enablers ++ (if (b == start) Set.empty else Set(substituter(b))) + val sm = m.substitute(substituter, matcherSubst = msubst) + + if (matchers(m)) { + handled += sb -> sm + } else if (transMatchers(m) && isStrict) { + instantiation ++= instCtx.instantiate(sb, sm)(quantifications.toSeq : _*) + } else { + ignored += sb -> sm + } + } + } } instantiation } + + protected def instanceSubst(enablers: Set[T]): Map[T, T] + + protected def skip(subst: Map[T, T]): Boolean = false } - private def extractMappings( - bindings: Map[Matcher[T], Set[(T, Matcher[T])]] - ): Set[Set[(T, Matcher[T], Matcher[T])]] = { - val allMappings = bindings.foldLeft[Set[Set[(T, Matcher[T], Matcher[T])]]](Set(Set.empty)) { - case (mappings, (qm, instances)) => Set(instances.toSeq.flatMap { - case (b, m) => mappings.map(mapping => mapping + ((b, qm, m))) - } : _*) - } + private class Quantification ( + val pathVar: (Identifier, T), + val qs: (Identifier, T), + val q2s: (Identifier, T), + val insts: (Identifier, T), + val guardVar: T, + val quantified: Set[T], + val matchers: Set[Matcher[T]], + val allMatchers: Map[T, Set[Matcher[T]]], + val condVars: Map[Identifier, T], + val exprVars: Map[Identifier, T], + val condTree: Map[Identifier, Set[Identifier]], + val clauses: Seq[T], + val blockers: Map[T, Set[TemplateCallInfo[T]]], + val applications: Map[T, Set[App[T]]], + val lambdas: Seq[LambdaTemplate[T]]) extends MatcherQuantification { - def subBindings(b: T, sm: Matcher[T], m: Matcher[T]): Set[(T, Matcher[T], Matcher[T])] = { - (for ((sarg, arg) <- sm.args zip m.args) yield { - (sarg, arg) match { - case (Right(sargm), Right(argm)) => Set((b, sargm, argm)) ++ subBindings(b, sargm, argm) - case _ => Set.empty[(T, Matcher[T], Matcher[T])] - } - }).flatten.toSet - } + var currentQ2Var: T = qs._2 + + protected def instanceSubst(enablers: Set[T]): Map[T, T] = { + val nextQ2Var = encoder.encodeId(q2s._1) - allMappings.filter { s => - val withSubs = s ++ s.flatMap { case (b, sm, m) => subBindings(b, sm, m) } - withSubs.groupBy(_._2).forall(_._2.size == 1) + val subst = Map(qs._2 -> currentQ2Var, guardVar -> encodeEnablers(enablers), + q2s._2 -> nextQ2Var, insts._2 -> encoder.encodeId(insts._1)) + + currentQ2Var = nextQ2Var + subst } } - private def extractSubst(quantified: Set[T], mapping: Set[(T, Matcher[T], Matcher[T])]): (T, Map[T,T]) = { - var constraints: List[T] = Nil - var subst: Map[T, T] = Map.empty + private lazy val blockerId = FreshIdentifier("blocker", BooleanType, true) + private lazy val blockerCache: MutableMap[T, T] = MutableMap.empty + private def freshBlocker(enablers: Set[T]): (T, Option[T]) = enablers.toSeq match { + case Seq(b) if isBlocker(b) => (b, None) + case _ => + val enabler = encodeEnablers(enablers) + blockerCache.get(enabler) match { + case Some(b) => (b, None) + case None => + val nb = encoder.encodeId(blockerId) + blockerCache += enabler -> nb + for (b <- enablers if isBlocker(b)) implies(b, nb) + blocker(nb) + (nb, Some(enabler)) + } + } - for { - (b, Matcher(qcaller, _, qargs, _), Matcher(caller, _, args, _)) <- mapping - _ = constraints :+= b - (qarg, arg) <- (qargs zip args) - argVal = Matcher.argValue(arg) - } qarg match { - case Left(quant) if subst.isDefinedAt(quant) => - constraints :+= encoder.mkEquals(quant, argVal) - case Left(quant) if quantified(quant) => - subst += quant -> argVal - case _ => - constraints :+= encoder.mkEquals(Matcher.argValue(qarg), argVal) - } + private class LambdaAxiom ( + val pathVar: (Identifier, T), + val blocker: T, + val guardVar: T, + val quantified: Set[T], + val matchers: Set[Matcher[T]], + val allMatchers: Map[T, Set[Matcher[T]]], + val condVars: Map[Identifier, T], + val exprVars: Map[Identifier, T], + val condTree: Map[Identifier, Set[Identifier]], + val clauses: Seq[T], + val blockers: Map[T, Set[TemplateCallInfo[T]]], + val applications: Map[T, Set[App[T]]], + val lambdas: Seq[LambdaTemplate[T]]) extends MatcherQuantification { - val enabler = - if (constraints.isEmpty) trueT - else if (constraints.size == 1) constraints.head - else encoder.mkAnd(constraints : _*) + protected def instanceSubst(enablers: Set[T]): Map[T, T] = { + // no need to add an equality clause here since it is already contained in the Axiom clauses + val (newBlocker, optEnabler) = freshBlocker(enablers) + val guardT = if (optEnabler.isDefined) encoder.mkAnd(start, optEnabler.get) else start + Map(guardVar -> guardT, blocker -> newBlocker) + } - (encoder.substitute(subst)(enabler), subst) + override protected def skip(subst: Map[T, T]): Boolean = { + val substituter = encoder.substitute(subst) + allMatchers.forall { case (b, ms) => + ms.forall(m => matchers(m) || instCtx(Set(substituter(b)) -> m.substitute(substituter))) + } + } } - def instantiateQuantification(template: QuantificationTemplate[T], substMap: Map[T, T]): Instantiation[T] = { - val quantified = template.quantifiers.map(_._2).toSet - - val allMatchers: Map[T, Set[Matcher[T]]] = { - def rec(templates: Map[T, LambdaTemplate[T]]): Map[T, Set[Matcher[T]]] = - templates.foldLeft(Map.empty[T, Set[Matcher[T]]]) { - case (matchers, (_, template)) => matchers merge template.matchers merge rec(template.lambdas) + private def extractQuorums( + quantified: Set[T], + matchers: Set[Matcher[T]], + lambdas: Seq[LambdaTemplate[T]] + ): Seq[Set[Matcher[T]]] = { + val extMatchers: Set[Matcher[T]] = { + def rec(templates: Seq[LambdaTemplate[T]]): Set[Matcher[T]] = + templates.foldLeft(Set.empty[Matcher[T]]) { + case (matchers, template) => matchers ++ template.matchers.flatMap(_._2) ++ rec(template.lambdas) } - template.matchers merge rec(template.lambdas) + matchers ++ rec(lambdas) } - val quantifiedMatchers = (for { - (_, ms) <- allMatchers - m @ Matcher(_, _, args, _) <- ms + val quantifiedMatchers = for { + m @ Matcher(_, _, args, _) <- extMatchers if args exists (_.left.exists(quantified)) - } yield m).toSet + } yield m - val matchQuorums: Seq[Set[Matcher[T]]] = purescala.Quantification.extractQuorums( - quantifiedMatchers, quantified, + purescala.Quantification.extractQuorums(quantifiedMatchers, quantified, (m: Matcher[T]) => m.args.collect { case Right(m) if quantifiedMatchers(m) => m }.toSet, (m: Matcher[T]) => m.args.collect { case Left(a) if quantified(a) => a }.toSet) + } + + def instantiateAxiom(template: LambdaTemplate[T], substMap: Map[T, T]): Instantiation[T] = { + val quantifiers = template.arguments map { + case (id, idT) => id -> substMap(idT) + } filter (p => isQuantifier(p._2)) + + if (quantifiers.isEmpty || lambdaAxioms(template -> quantifiers)) { + Instantiation.empty[T] + } else { + lambdaAxioms += template -> quantifiers + val blockerT = encoder.encodeId(blockerId) + + val guard = FreshIdentifier("guard", BooleanType, true) + val guardT = encoder.encodeId(guard) + + val substituter = encoder.substitute(substMap + (template.start -> blockerT)) + val allMatchers = template.matchers map { case (b, ms) => substituter(b) -> ms.map(_.substitute(substituter)) } + val qMatchers = allMatchers.flatMap(_._2).toSet + + val encArgs = template.args map substituter + val app = Application(Variable(template.ids._1), template.arguments.map(_._1.toVariable)) + val appT = encoder.encodeExpr((template.arguments.map(_._1) zip encArgs).toMap + template.ids)(app) + val selfMatcher = Matcher(template.ids._2, template.tpe, encArgs.map(Left(_)), appT) + + val enablingClause = encoder.mkImplies(guardT, blockerT) + + instantiateAxiom( + template.pathVar._1 -> substMap(template.start), + blockerT, + guardT, + quantifiers, + qMatchers, + allMatchers + (template.start -> (allMatchers.getOrElse(template.start, Set.empty) + selfMatcher)), + template.condVars map { case (id, idT) => id -> substituter(idT) }, + template.exprVars map { case (id, idT) => id -> substituter(idT) }, + template.condTree, + (template.clauses map substituter) :+ enablingClause, + template.blockers map { case (b, fis) => + substituter(b) -> fis.map(fi => fi.copy(args = fi.args.map(substituter))) + }, + template.applications map { case (b, apps) => + substituter(b) -> apps.map(app => app.copy(caller = substituter(app.caller), args = app.args.map(substituter))) + }, + template.lambdas map (_.substitute(substituter)) + ) + } + } + + def instantiateAxiom( + pathVar: (Identifier, T), + blocker: T, + guardVar: T, + quantifiers: Seq[(Identifier, T)], + matchers: Set[Matcher[T]], + allMatchers: Map[T, Set[Matcher[T]]], + condVars: Map[Identifier, T], + exprVars: Map[Identifier, T], + condTree: Map[Identifier, Set[Identifier]], + clauses: Seq[T], + blockers: Map[T, Set[TemplateCallInfo[T]]], + applications: Map[T, Set[App[T]]], + lambdas: Seq[LambdaTemplate[T]] + ): Instantiation[T] = { + val quantified = quantifiers.map(_._2).toSet + val matchQuorums = extractQuorums(quantified, matchers, lambdas) + + var instantiation = Instantiation.empty[T] + + for (matchers <- matchQuorums) { + val axiom = new LambdaAxiom(pathVar, blocker, guardVar, quantified, + matchers, allMatchers, condVars, exprVars, condTree, + clauses, blockers, applications, lambdas + ) + + quantifications += axiom + + val newCtx = new InstantiationContext() + for ((b,m) <- instCtx.instantiated) { + instantiation ++= newCtx.instantiate(b, m)(axiom) + } + instCtx.merge(newCtx) + } + + val quantifierSubst = uniformSubst(quantifiers) + val substituter = encoder.substitute(quantifierSubst) + + for { + m <- matchers + sm = m.substitute(substituter) + if !instCtx.corresponding(sm).exists(_._2.args == sm.args) + } instantiation ++= instCtx.instantiate(Set.empty, sm)(quantifications.toSeq : _*) + + instantiation + } + + def instantiateQuantification(template: QuantificationTemplate[T], substMap: Map[T, T]): Instantiation[T] = { + val quantified = template.quantifiers.map(_._2).toSet + val matchQuorums = extractQuorums(quantified, template.matchers.flatMap(_._2).toSet, template.lambdas) var instantiation = Instantiation.empty[T] @@ -330,13 +682,16 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage val subst = substMap + (template.qs._2 -> newQ) val substituter = encoder.substitute(subst) - val quantification = new Quantification(template.qs._1 -> newQ, + val quantification = new Quantification( + template.pathVar._1 -> substituter(template.start), + template.qs._1 -> newQ, template.q2s, template.insts, template.guardVar, quantified, - matchers map (m => m.substitute(substituter)), - allMatchers map { case (b, ms) => substituter(b) -> ms.map(_.substitute(substituter)) }, + matchers map (_.substitute(substituter)), + template.matchers map { case (b, ms) => substituter(b) -> ms.map(_.substitute(substituter)) }, template.condVars, template.exprVars, + template.condTree, template.clauses map substituter, template.blockers map { case (b, fis) => substituter(b) -> fis.map(fi => fi.copy(args = fi.args.map(substituter))) @@ -344,52 +699,17 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage template.applications map { case (b, fas) => substituter(b) -> fas.map(fa => fa.copy(caller = substituter(fa.caller), args = fa.args.map(substituter))) }, - template.lambdas map { case (idT, template) => substituter(idT) -> template.substitute(subst) } + template.lambdas map (_.substitute(substituter)) ) - def extendClauses(master: Quantification, slave: Quantification): Instantiation[T] = { - val bindingsMap: Map[Matcher[T], Set[(T, Matcher[T])]] = slave.matchers.map { sm => - val instances = master.allMatchers.toSeq.flatMap { case (b, ms) => ms.map(b -> _) } - sm -> instances.filter(p => correspond(p._2, sm)).toSet - }.toMap - - val allMappings = extractMappings(bindingsMap) - val filteredMappings = allMappings.filter { s => - s.exists { case (b, sm, m) => !master.matchers(m) } && - s.exists { case (b, sm, m) => sm != m } && - s.forall { case (b, sm, m) => - (sm.args zip m.args).forall { - case (Right(_), Left(_)) => false - case _ => true - } - } - } - - var instantiation = Instantiation.empty[T] - - for (mapping <- filteredMappings) { - val (enabler, subst) = extractSubst(slave.quantified, mapping) - instantiation ++= master.register(enabler, subst, slave) - } - - instantiation - } - - val allSet = quantification.allMatchers.flatMap(_._2).toSet - for (q <- quantifications) { - val allQSet = q.allMatchers.flatMap(_._2).toSet - if (quantification.matchers.forall(m => allQSet.exists(qm => correspond(qm, m)))) - instantiation ++= extendClauses(q, quantification) - - if (q.matchers.forall(qm => allSet.exists(m => correspond(qm, m)))) - instantiation ++= extendClauses(quantification, q) - } + quantifications += quantification - for (instantiated <- List(instantiated, fInstantiated); (b, m) <- instantiated) { - instantiation ++= quantification.instantiate(b, m)(instantiated) + val newCtx = new InstantiationContext() + for ((b,m) <- instCtx.instantiated) { + instantiation ++= newCtx.instantiate(b, m)(quantification) } + instCtx.merge(newCtx) - quantifications += quantification quantification.qs._2 } @@ -404,34 +724,74 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage val quantifierSubst = uniformSubst(template.quantifiers) val substituter = encoder.substitute(substMap ++ quantifierSubst) - for ((_, ms) <- template.matchers; m <- ms) { - val sm = m.substitute(substituter) - - if (!fInsts.exists(fm => correspond(sm, fm) && sm.args == fm.args)) { - for (q <- quantifications) { - instantiation ++= q.instantiate(trueT, sm)(fInstantiated) - } - - fInsts += sm - } - } + for { + (_, ms) <- template.matchers; m <- ms + sm = m.substitute(substituter) + if !instCtx.corresponding(sm).exists(_._2.args == sm.args) + } instantiation ++= instCtx.instantiate(Set.empty, sm)(quantifications.toSeq : _*) instantiation } def instantiateMatcher(blocker: T, matcher: Matcher[T]): Instantiation[T] = { - val qInst = if (instantiated(blocker -> matcher)) Instantiation.empty[T] else { - var instantiation = Instantiation.empty[T] - for (q <- quantifications) { - instantiation ++= q.instantiate(blocker, matcher)(instantiated) - } + instCtx.instantiate(Set(blocker), matcher)(quantifications.toSeq : _*) + } - instantiated += (blocker -> matcher) + private type SetDef = (T, (Identifier, T), (Identifier, T), Seq[T], T, T, T) + private val setConstructors: MutableMap[TypeTree, SetDef] = MutableMap.empty + + def checkClauses: Seq[T] = { + val clauses = new scala.collection.mutable.ListBuffer[T] + + for ((key, ctx) <- ignored.instantiations) { + val insts = instCtx.map.get(key).toMatchers + + val QTM(argTypes, _) = key.tpe + val tupleType = tupleTypeWrap(argTypes) + + val (guardT, (setPrev, setPrevT), (setNext, setNextT), elems, containsT, emptyT, setT) = + setConstructors.getOrElse(tupleType, { + val guard = FreshIdentifier("guard", BooleanType) + val setPrev = FreshIdentifier("prevSet", SetType(tupleType)) + val setNext = FreshIdentifier("nextSet", SetType(tupleType)) + val elems = argTypes.map(tpe => FreshIdentifier("elem", tpe)) + + val elemExpr = tupleWrap(elems.map(_.toVariable)) + val contextExpr = And( + Implies(Variable(guard), Equals(Variable(setNext), + SetUnion(Variable(setPrev), FiniteSet(Set(elemExpr), tupleType)))), + Implies(Not(Variable(guard)), Equals(Variable(setNext), Variable(setPrev)))) + + val guardP = guard -> encoder.encodeId(guard) + val setPrevP = setPrev -> encoder.encodeId(setPrev) + val setNextP = setNext -> encoder.encodeId(setNext) + val elemsP = elems.map(e => e -> encoder.encodeId(e)) + + val containsT = encoder.encodeExpr(elemsP.toMap + setPrevP)(ElementOfSet(elemExpr, setPrevP._1.toVariable)) + val emptyT = encoder.encodeExpr(Map.empty)(FiniteSet(Set.empty, tupleType)) + val contextT = encoder.encodeExpr(Map(guardP, setPrevP, setNextP) ++ elemsP)(contextExpr) + + val setDef = (guardP._2, setPrevP, setNextP, elemsP.map(_._2), containsT, emptyT, contextT) + setConstructors += key.tpe -> setDef + setDef + }) + + var prev = emptyT + for ((b, m) <- insts.toSeq) { + val next = encoder.encodeId(setNext) + val argsMap = (elems zip m.args).map { case (idT, arg) => idT -> Matcher.argValue(arg) } + val substMap = Map(guardT -> b, setPrevT -> prev, setNextT -> next) ++ argsMap + prev = next + clauses += encoder.substitute(substMap)(setT) + } - instantiation + val setMap = Map(setPrevT -> prev) + for ((b, m) <- ctx.toSeq) { + val substMap = setMap ++ (elems zip m.args).map(p => p._1 -> Matcher.argValue(p._2)) + clauses += encoder.substitute(substMap)(encoder.mkImplies(b, containsT)) + } } - qInst + clauses.toSeq } - } diff --git a/src/main/scala/leon/solvers/templates/TemplateGenerator.scala b/src/main/scala/leon/solvers/templates/TemplateGenerator.scala index f6484ff7be0152f316d8cf911c56a94320fade5a..60ddb145058eeb33b4bcb796307fba87d586d52f 100644 --- a/src/main/scala/leon/solvers/templates/TemplateGenerator.scala +++ b/src/main/scala/leon/solvers/templates/TemplateGenerator.scala @@ -9,14 +9,29 @@ import purescala.Expressions._ import purescala.Extractors._ import purescala.ExprOps._ import purescala.Types._ +import purescala.TypeOps._ import purescala.Definitions._ import purescala.Constructors._ +import purescala.Quantification._ + +import Instantiation._ class TemplateGenerator[T](val encoder: TemplateEncoder[T], val assumePreHolds: Boolean) { private var cache = Map[TypedFunDef, FunctionTemplate[T]]() private var cacheExpr = Map[Expr, FunctionTemplate[T]]() + private type Clauses = ( + Map[Identifier,T], + Map[Identifier,T], + Map[Identifier, Set[Identifier]], + Map[Identifier, Seq[Expr]], + Seq[LambdaTemplate[T]], + Seq[QuantificationTemplate[T]] + ) + + private def emptyClauses: Clauses = (Map.empty, Map.empty, Map.empty, Map.empty, Seq.empty, Seq.empty) + val manager = new QuantificationManager[T](encoder) def mkTemplate(body: Expr): FunctionTemplate[T] = { @@ -70,16 +85,14 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], val substMap : Map[Identifier, T] = arguments.toMap + pathVar - val (bodyConds, bodyExprs, bodyGuarded, bodyLambdas, bodyQuantifications) = if (isRealFunDef) { - invocationEqualsBody.map(expr => mkClauses(start, expr, substMap)).getOrElse { - (Map[Identifier,T](), Map[Identifier,T](), Map[Identifier,Seq[Expr]](), Map[T,LambdaTemplate[T]](), Seq[QuantificationTemplate[T]]()) - } + val (bodyConds, bodyExprs, bodyTree, bodyGuarded, bodyLambdas, bodyQuantifications) = if (isRealFunDef) { + invocationEqualsBody.map(expr => mkClauses(start, expr, substMap)).getOrElse(emptyClauses) } else { mkClauses(start, lambdaBody.get, substMap) } // Now the postcondition. - val (condVars, exprVars, guardedExprs, lambdas, quantifications) = tfd.postcondition match { + val (condVars, exprVars, condTree, guardedExprs, lambdas, quantifications) = tfd.postcondition match { case Some(post) => val newPost : Expr = application(matchToIfThenElse(post), Seq(invocation)) @@ -94,19 +107,15 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], newPost } - val (postConds, postExprs, postGuarded, postLambdas, postQuantifications) = mkClauses(start, postHolds, substMap) - val allGuarded = (bodyGuarded.keys ++ postGuarded.keys).map { k => - k -> (bodyGuarded.getOrElse(k, Seq.empty) ++ postGuarded.getOrElse(k, Seq.empty)) - }.toMap - - (bodyConds ++ postConds, bodyExprs ++ postExprs, allGuarded, bodyLambdas ++ postLambdas, bodyQuantifications ++ postQuantifications) + val (postConds, postExprs, postTree, postGuarded, postLambdas, postQuantifications) = mkClauses(start, postHolds, substMap) + (bodyConds ++ postConds, bodyExprs ++ postExprs, bodyTree merge postTree, bodyGuarded merge postGuarded, bodyLambdas ++ postLambdas, bodyQuantifications ++ postQuantifications) case None => - (bodyConds, bodyExprs, bodyGuarded, bodyLambdas, bodyQuantifications) + (bodyConds, bodyExprs, bodyTree, bodyGuarded, bodyLambdas, bodyQuantifications) } val template = FunctionTemplate(tfd, encoder, manager, - pathVar, arguments, condVars, exprVars, guardedExprs, quantifications, lambdas, isRealFunDef) + pathVar, arguments, condVars, exprVars, condTree, guardedExprs, quantifications, lambdas, isRealFunDef) cache += tfd -> template template } @@ -133,11 +142,60 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], andJoin(rec(invocation, body, args, inlineFirst)) } - def mkClauses(pathVar: Identifier, expr: Expr, substMap: Map[Identifier, T]): - (Map[Identifier,T], Map[Identifier,T], Map[Identifier, Seq[Expr]], Map[T, LambdaTemplate[T]], Seq[QuantificationTemplate[T]]) = { + private def minimalFlattening(inits: Set[Identifier], conj: Expr): (Set[Identifier], Expr) = { + var mapping: Map[Expr, Expr] = Map.empty + var quantified: Set[Identifier] = inits + var quantifierEqualities: Seq[(Expr, Identifier)] = Seq.empty + + val newConj = postMap { + case expr if mapping.isDefinedAt(expr) => + Some(mapping(expr)) + + case expr @ QuantificationMatcher(c, args) => + val isMatcher = args.exists { case Variable(id) => quantified(id) case _ => false } + val isRelevant = (variablesOf(expr) & quantified).nonEmpty + if (!isMatcher && isRelevant) { + val newArgs = args.map { + case arg @ QuantificationMatcher(_, _) if (variablesOf(arg) & quantified).nonEmpty => + val id = FreshIdentifier("flat", arg.getType) + quantifierEqualities :+= (arg -> id) + quantified += id + Variable(id) + case arg => arg + } + + val newExpr = replace((args zip newArgs).toMap, expr) + mapping += expr -> newExpr + Some(newExpr) + } else { + None + } + + case _ => None + } (conj) + + val flatConj = implies(andJoin(quantifierEqualities.map { + case (arg, id) => Equals(arg, Variable(id)) + }), newConj) + + (quantified, flatConj) + } + + def mkClauses(pathVar: Identifier, expr: Expr, substMap: Map[Identifier, T]): Clauses = { + val (p, (condVars, exprVars, condTree, guardedExprs, lambdas, quantifications)) = mkExprClauses(pathVar, expr, substMap) + val allGuarded = guardedExprs + (pathVar -> (p +: guardedExprs.getOrElse(pathVar, Seq.empty))) + (condVars, exprVars, condTree, allGuarded, lambdas, quantifications) + } + + private def mkExprClauses(pathVar: Identifier, expr: Expr, substMap: Map[Identifier, T]): (Expr, Clauses) = { var condVars = Map[Identifier, T]() - @inline def storeCond(id: Identifier) : Unit = condVars += id -> encoder.encodeId(id) + var condTree = Map[Identifier, Set[Identifier]](pathVar -> Set.empty).withDefaultValue(Set.empty) + def storeCond(pathVar: Identifier, id: Identifier) : Unit = { + condVars += id -> encoder.encodeId(id) + condTree += pathVar -> (condTree(pathVar) + id) + } + @inline def encodedCond(id: Identifier) : T = substMap.getOrElse(id, condVars(id)) var exprVars = Map[Identifier, T]() @@ -165,18 +223,37 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], @inline def registerQuantification(quantification: QuantificationTemplate[T]): Unit = quantifications :+= quantification - var lambdas = Map[T, LambdaTemplate[T]]() - @inline def registerLambda(idT: T, lambda: LambdaTemplate[T]) : Unit = lambdas += idT -> lambda + var lambdas = Seq[LambdaTemplate[T]]() + @inline def registerLambda(lambda: LambdaTemplate[T]) : Unit = lambdas :+= lambda def requireDecomposition(e: Expr) = { exists{ - case (_: Choose) | (_: Forall) => true + case (_: Choose) | (_: Forall) | (_: Lambda) => true case (_: Assert) | (_: Ensuring) => true case (_: FunctionInvocation) | (_: Application) => true case _ => false }(e) } + def groupWhile[T](es: Seq[T])(p: T => Boolean): Seq[Seq[T]] = { + var res: Seq[Seq[T]] = Nil + + var c = es + while (!c.isEmpty) { + val (span, rest) = c.span(p) + + if (span.isEmpty) { + res :+= Seq(rest.head) + c = rest.tail + } else { + res :+= span + c = rest + } + } + + res + } + def rec(pathVar: Identifier, expr: Expr): Expr = { expr match { case a @ Assert(cond, err, body) => @@ -220,13 +297,71 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], case p : Passes => sys.error("'Passes's should have been eliminated before generating templates.") case i @ Implies(lhs, rhs) => - implies(rec(pathVar, lhs), rec(pathVar, rhs)) + if (requireDecomposition(i)) { + rec(pathVar, Or(Not(lhs), rhs)) + } else { + implies(rec(pathVar, lhs), rec(pathVar, rhs)) + } case a @ And(parts) => - andJoin(parts.map(rec(pathVar, _))) + val partitions = groupWhile(parts)(!requireDecomposition(_)) + partitions.map(andJoin) match { + case Seq(e) => e + case seq => + val newExpr : Identifier = FreshIdentifier("e", BooleanType, true) + storeExpr(newExpr) + + def recAnd(pathVar: Identifier, partitions: Seq[Expr]): Unit = partitions match { + case x :: Nil if !requireDecomposition(x) => + storeGuarded(pathVar, Equals(Variable(newExpr), x)) + + case x :: xs => + val newBool : Identifier = FreshIdentifier("b", BooleanType, true) + storeCond(pathVar, newBool) + + val xrec = rec(pathVar, x) + storeGuarded(pathVar, Equals(Variable(newBool), xrec)) + storeGuarded(pathVar, Implies(Not(Variable(newBool)), Not(Variable(newExpr)))) + + recAnd(newBool, xs) + + case Nil => + storeGuarded(pathVar, Variable(newExpr)) + } + + recAnd(pathVar, seq) + Variable(newExpr) + } case o @ Or(parts) => - orJoin(parts.map(rec(pathVar, _))) + val partitions = groupWhile(parts)(!requireDecomposition(_)) + partitions.map(orJoin) match { + case Seq(e) => e + case seq => + val newExpr : Identifier = FreshIdentifier("e", BooleanType, true) + storeExpr(newExpr) + + def recOr(pathVar: Identifier, partitions: Seq[Expr]): Unit = partitions match { + case x :: Nil if !requireDecomposition(x) => + storeGuarded(pathVar, Equals(Variable(newExpr), x)) + + case x :: xs => + val newBool : Identifier = FreshIdentifier("b", BooleanType, true) + storeCond(pathVar, newBool) + + val xrec = rec(pathVar, x) + storeGuarded(pathVar, Equals(Not(Variable(newBool)), xrec)) + storeGuarded(pathVar, Implies(Not(Variable(newBool)), Variable(newExpr))) + + recOr(newBool, xs) + + case Nil => + storeGuarded(pathVar, Not(Variable(newExpr))) + } + + recOr(pathVar, seq) + Variable(newExpr) + } case i @ IfExpr(cond, thenn, elze) => { if(!requireDecomposition(i)) { @@ -236,8 +371,8 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], val newBool2 : Identifier = FreshIdentifier("b", BooleanType, true) val newExpr : Identifier = FreshIdentifier("e", i.getType, true) - storeCond(newBool1) - storeCond(newBool2) + storeCond(pathVar, newBool1) + storeCond(pathVar, newBool2) storeExpr(newExpr) @@ -274,19 +409,18 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], val idArgs : Seq[Identifier] = lambdaArgs(l) val trArgs : Seq[T] = idArgs.map(id => substMap.getOrElse(id, encoder.encodeId(id))) - val lid = FreshIdentifier("lambda", l.getType, true) + val lid = FreshIdentifier("lambda", bestRealType(l.getType), true) val clause = liftedEquals(Variable(lid), l, idArgs, inlineFirst = true) val localSubst: Map[Identifier, T] = substMap ++ condVars ++ exprVars ++ lambdaVars val clauseSubst: Map[Identifier, T] = localSubst ++ (idArgs zip trArgs) - val (lambdaConds, lambdaExprs, lambdaGuarded, lambdaTemplates, lambdaQuants) = mkClauses(pathVar, clause, clauseSubst) - assert(lambdaQuants.isEmpty, "Unhandled quantification in lambdas in " + l) + val (lambdaConds, lambdaExprs, lambdaTree, lambdaGuarded, lambdaTemplates, lambdaQuants) = mkClauses(pathVar, clause, clauseSubst) val ids: (Identifier, T) = lid -> storeLambda(lid) val dependencies: Map[Identifier, T] = variablesOf(l).map(id => id -> localSubst(id)).toMap val template = LambdaTemplate(ids, encoder, manager, pathVar -> encodedCond(pathVar), - idArgs zip trArgs, lambdaConds, lambdaExprs, lambdaGuarded, lambdaTemplates, localSubst, dependencies, l) - registerLambda(ids._2, template) + idArgs zip trArgs, lambdaConds, lambdaExprs, lambdaTree, lambdaGuarded, lambdaQuants, lambdaTemplates, localSubst, dependencies, l) + registerLambda(template) Variable(lid) @@ -295,7 +429,8 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], val conjunctQs = conjuncts.map { conjunct => val vars = variablesOf(conjunct) - val quantifiers = args.map(_.id).filter(vars).toSet + val inits = args.map(_.id).filter(vars).toSet + val (quantifiers, flatConj) = minimalFlattening(inits, conjunct) val idQuantifiers : Seq[Identifier] = quantifiers.toSeq val trQuantifiers : Seq[T] = idQuantifiers.map(encoder.encodeId) @@ -305,19 +440,21 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], val inst: Identifier = FreshIdentifier("inst", BooleanType, true) val guard: Identifier = FreshIdentifier("guard", BooleanType, true) - val clause = Equals(Variable(inst), Implies(Variable(guard), conjunct)) + val clause = Equals(Variable(inst), Implies(Variable(guard), flatConj)) val qs: (Identifier, T) = q -> encoder.encodeId(q) val localSubst: Map[Identifier, T] = substMap ++ condVars ++ exprVars ++ lambdaVars val clauseSubst: Map[Identifier, T] = localSubst ++ (idQuantifiers zip trQuantifiers) - val (qConds, qExprs, qGuarded, qTemplates, qQuants) = mkClauses(pathVar, clause, clauseSubst) + val (p, (qConds, qExprs, qTree, qGuarded, qTemplates, qQuants)) = mkExprClauses(pathVar, flatConj, clauseSubst) assert(qQuants.isEmpty, "Unhandled nested quantification in "+clause) - val binder = Equals(Variable(q), And(Variable(q2), Variable(inst))) - val allQGuarded = qGuarded + (pathVar -> (binder +: qGuarded.getOrElse(pathVar, Seq.empty))) + val allGuarded = qGuarded + (pathVar -> (Seq( + Equals(Variable(inst), Implies(Variable(guard), p)), + Equals(Variable(q), And(Variable(q2), Variable(inst))) + ) ++ qGuarded.getOrElse(pathVar, Seq.empty))) val template = QuantificationTemplate[T](encoder, manager, pathVar -> encodedCond(pathVar), - qs, q2, inst, guard, idQuantifiers zip trQuantifiers, qConds, qExprs, allQGuarded, qTemplates, localSubst) + qs, q2, inst, guard, idQuantifiers zip trQuantifiers, qConds, qExprs, qTree, allGuarded, qTemplates, localSubst) registerQuantification(template) Variable(q) } @@ -329,9 +466,7 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], } val p = rec(pathVar, expr) - storeGuarded(pathVar, p) - - (condVars, exprVars, guardedExprs, lambdas, quantifications) + (p, (condVars, exprVars, condTree, guardedExprs, lambdas, quantifications)) } } diff --git a/src/main/scala/leon/solvers/templates/TemplateInfo.scala b/src/main/scala/leon/solvers/templates/TemplateInfo.scala index e298e298a6f828c78dcf4da8de5177f94f16758b..033f15dd6f251026a260ed6344212239e1714a37 100644 --- a/src/main/scala/leon/solvers/templates/TemplateInfo.scala +++ b/src/main/scala/leon/solvers/templates/TemplateInfo.scala @@ -14,6 +14,6 @@ case class TemplateCallInfo[T](tfd: TypedFunDef, args: Seq[T]) { case class TemplateAppInfo[T](template: LambdaTemplate[T], equals: T, args: Seq[T]) { override def toString = { - template.id + "|" + equals + args.mkString("(", ",", ")") + template.ids._2 + "|" + equals + args.mkString("(", ",", ")") } } diff --git a/src/main/scala/leon/solvers/templates/Templates.scala b/src/main/scala/leon/solvers/templates/TemplateManager.scala similarity index 57% rename from src/main/scala/leon/solvers/templates/Templates.scala rename to src/main/scala/leon/solvers/templates/TemplateManager.scala index 32d273c3937d6ba4b808b79c16edf1ded4ade785..bb6629ec16988a79eb686259aea4fa32c6111dbb 100644 --- a/src/main/scala/leon/solvers/templates/Templates.scala +++ b/src/main/scala/leon/solvers/templates/TemplateManager.scala @@ -11,10 +11,11 @@ import purescala.Quantification._ import purescala.Extractors._ import purescala.ExprOps._ import purescala.Types._ +import purescala.TypeOps._ -case class App[T](caller: T, tpe: FunctionType, args: Seq[T]) { - override def toString = "(" + caller + " : " + tpe + ")" + args.mkString("(", ",", ")") -} +import utils._ + +import scala.collection.generic.CanBuildFrom object Instantiation { type Clauses[T] = Seq[T] @@ -24,12 +25,18 @@ object Instantiation { def empty[T] = (Seq.empty[T], Map.empty[T, Set[TemplateCallInfo[T]]], Map.empty[(T, App[T]), Set[TemplateAppInfo[T]]]) - implicit class MapWrapper[A,B](map: Map[A,Set[B]]) { + implicit class MapSetWrapper[A,B](map: Map[A,Set[B]]) { def merge(that: Map[A,Set[B]]): Map[A,Set[B]] = (map.keys ++ that.keys).map { k => k -> (map.getOrElse(k, Set.empty) ++ that.getOrElse(k, Set.empty)) }.toMap } + implicit class MapSeqWrapper[A,B](map: Map[A,Seq[B]]) { + def merge(that: Map[A,Seq[B]]): Map[A,Seq[B]] = (map.keys ++ that.keys).map { k => + k -> (map.getOrElse(k, Seq.empty) ++ that.getOrElse(k, Seq.empty)) + }.toMap + } + implicit class InstantiationWrapper[T](i: Instantiation[T]) { def ++(that: Instantiation[T]): Instantiation[T] = { val (thisClauses, thisBlockers, thisApps) = i @@ -40,6 +47,12 @@ object Instantiation { def withClause(cl: T): Instantiation[T] = (i._1 :+ cl, i._2, i._3) def withClauses(cls: Seq[T]): Instantiation[T] = (i._1 ++ cls, i._2, i._3) + + def withCalls(calls: CallBlockers[T]): Instantiation[T] = (i._1, i._2 merge calls, i._3) + def withApps(apps: AppBlockers[T]): Instantiation[T] = (i._1, i._2, i._3 merge apps) + def withApp(app: ((T, App[T]), TemplateAppInfo[T])): Instantiation[T] = { + (i._1, i._2, i._3 merge Map(app._1 -> Set(app._2))) + } } } @@ -49,16 +62,19 @@ trait Template[T] { self => val encoder : TemplateEncoder[T] val manager : QuantificationManager[T] - val start : T + val pathVar: (Identifier, T) val args : Seq[T] val condVars : Map[Identifier, T] val exprVars : Map[Identifier, T] + val condTree : Map[Identifier, Set[Identifier]] val clauses : Seq[T] val blockers : Map[T, Set[TemplateCallInfo[T]]] val applications : Map[T, Set[App[T]]] - val quantifications: Seq[QuantificationTemplate[T]] - val matchers: Map[T, Set[Matcher[T]]] - val lambdas : Map[T, LambdaTemplate[T]] + val quantifications : Seq[QuantificationTemplate[T]] + val matchers : Map[T, Set[Matcher[T]]] + val lambdas : Seq[LambdaTemplate[T]] + + lazy val start = pathVar._2 private var substCache : Map[Seq[T],Map[T,T]] = Map.empty @@ -67,16 +83,20 @@ trait Template[T] { self => val baseSubstMap : Map[T,T] = substCache.get(args) match { case Some(subst) => subst case None => - val subst = (condVars ++ exprVars).map { case (id, idT) => idT -> encoder.encodeId(id) } ++ + val subst = exprVars.map { case (id, idT) => idT -> encoder.encodeId(id) } ++ + manager.freshConds(pathVar._1 -> aVar, condVars, condTree) ++ (this.args zip args) substCache += args -> subst subst } - val lambdaSubstMap = lambdas.map { case (idT, lambda) => idT -> encoder.encodeId(lambda.id) } + val lambdaSubstMap = lambdas.map(lambda => lambda.ids._2 -> encoder.encodeId(lambda.ids._1)) val quantificationSubstMap = quantifications.map(q => q.qs._2 -> encoder.encodeId(q.qs._1)) val substMap : Map[T,T] = baseSubstMap ++ lambdaSubstMap ++ quantificationSubstMap + (start -> aVar) + instantiate(substMap) + } + protected def instantiate(substMap: Map[T, T]): Instantiation[T] = { Template.instantiate(encoder, manager, clauses, blockers, applications, quantifications, matchers, lambdas, substMap) } @@ -86,43 +106,6 @@ trait Template[T] { self => object Template { - private object InvocationExtractor { - private def flatInvocation(expr: Expr): Option[(TypedFunDef, Seq[Expr])] = expr match { - case fi @ FunctionInvocation(tfd, args) => Some((tfd, args)) - case Application(caller, args) => flatInvocation(caller) match { - case Some((tfd, prevArgs)) => Some((tfd, prevArgs ++ args)) - case None => None - } - case _ => None - } - - def unapply(expr: Expr): Option[(TypedFunDef, Seq[Expr])] = expr match { - case IsTyped(f: FunctionInvocation, ft: FunctionType) => None - case IsTyped(f: Application, ft: FunctionType) => None - case FunctionInvocation(tfd, args) => Some(tfd -> args) - case f: Application => flatInvocation(f) - case _ => None - } - } - - private object ApplicationExtractor { - private def flatApplication(expr: Expr): Option[(Expr, Seq[Expr])] = expr match { - case Application(fi: FunctionInvocation, _) => None - case Application(caller: Application, args) => flatApplication(caller) match { - case Some((c, prevArgs)) => Some((c, prevArgs ++ args)) - case None => None - } - case Application(caller, args) => Some((caller, args)) - case _ => None - } - - def unapply(expr: Expr): Option[(Expr, Seq[Expr])] = expr match { - case IsTyped(f: Application, ft: FunctionType) => None - case f: Application => flatApplication(f) - case _ => None - } - } - private def invocationMatcher[T](encodeExpr: Expr => T)(tfd: TypedFunDef, args: Seq[Expr]): Matcher[T] = { assert(tfd.returnType.isInstanceOf[FunctionType], "invocationMatcher() is only defined on function-typed defs") @@ -136,7 +119,7 @@ object Template { val (fiArgs, appArgs) = args.splitAt(tfd.params.size) val app @ Application(caller, arguments) = rec(FunctionInvocation(tfd, fiArgs), appArgs) - Matcher(encodeExpr(caller), caller.getType, arguments.map(arg => Left(encodeExpr(arg))), encodeExpr(app)) + Matcher(encodeExpr(caller), bestRealType(caller.getType), arguments.map(arg => Left(encodeExpr(arg))), encodeExpr(app)) } def encode[T]( @@ -146,16 +129,14 @@ object Template { condVars: Map[Identifier, T], exprVars: Map[Identifier, T], guardedExprs: Map[Identifier, Seq[Expr]], - lambdas: Map[T, LambdaTemplate[T]], + lambdas: Seq[LambdaTemplate[T]], substMap: Map[Identifier, T] = Map.empty[Identifier, T], optCall: Option[TypedFunDef] = None, optApp: Option[(T, FunctionType)] = None ) : (Seq[T], Map[T, Set[TemplateCallInfo[T]]], Map[T, Set[App[T]]], Map[T, Set[Matcher[T]]], () => String) = { - val idToTrId : Map[Identifier, T] = { - condVars ++ exprVars + pathVar ++ arguments ++ substMap ++ - lambdas.map { case (idT, template) => template.id -> idT } - } + val idToTrId : Map[Identifier, T] = + condVars ++ exprVars + pathVar ++ arguments ++ substMap ++ lambdas.map(_.ids) val encodeExpr : Expr => T = encoder.encodeExpr(idToTrId) @@ -164,7 +145,9 @@ object Template { }).toSeq val optIdCall = optCall.map(tfd => TemplateCallInfo[T](tfd, arguments.map(_._2))) - val optIdApp = optApp.map { case (idT, tpe) => App(idT, tpe, arguments.map(_._2)) } + val optIdApp = optApp.map { case (idT, tpe) => + App(idT, bestRealType(tpe).asInstanceOf[FunctionType], arguments.map(_._2)) + } lazy val invocMatcher = optCall.filter(_.returnType.isInstanceOf[FunctionType]) .map(tfd => invocationMatcher(encodeExpr)(tfd, arguments.map(_._1.toVariable))) @@ -180,17 +163,10 @@ object Template { var matchInfos : Set[Matcher[T]] = Set.empty for (e <- es) { - funInfos ++= collect[TemplateCallInfo[T]] { - case InvocationExtractor(tfd, args) => - Set(TemplateCallInfo(tfd, args.map(encodeExpr))) - case _ => Set.empty - }(e) - - appInfos ++= collect[App[T]] { - case ApplicationExtractor(c, args) => - Set(App(encodeExpr(c), c.getType.asInstanceOf[FunctionType], args.map(encodeExpr))) - case _ => Set.empty - }(e) + funInfos ++= firstOrderCallsOf(e).map(p => TemplateCallInfo(p._1, p._2.map(encodeExpr))) + appInfos ++= firstOrderAppsOf(e).map { case (c, args) => + App(encodeExpr(c), bestRealType(c.getType).asInstanceOf[FunctionType], args.map(encodeExpr)) + } matchInfos ++= fold[Map[Expr, Matcher[T]]] { (expr, res) => val result = res.flatten.toMap @@ -204,7 +180,7 @@ object Template { case None => Left(encodeExpr(arg)) }) - Some(expr -> Matcher(encodeExpr(c), c.getType, encodedArgs, encodeExpr(expr))) + Some(expr -> Matcher(encodeExpr(c), bestRealType(c.getType), encodedArgs, encodeExpr(expr))) case _ => None }) }(e).values @@ -247,7 +223,7 @@ object Template { " * Matchers :" + (if (matchers.isEmpty) "\n" else { "\n " + matchers.map(p => p._1 + " ==> " + p._2).mkString("\n ") + "\n" }) + - " * Lambdas :\n" + lambdas.map { case (_, template) => + " * Lambdas :\n" + lambdas.map { case template => " +> " + template.toString.split("\n").mkString("\n ") + "\n" }.mkString("\n") } @@ -263,7 +239,7 @@ object Template { applications: Map[T, Set[App[T]]], quantifications: Seq[QuantificationTemplate[T]], matchers: Map[T, Set[Matcher[T]]], - lambdas: Map[T, LambdaTemplate[T]], + lambdas: Seq[LambdaTemplate[T]], substMap: Map[T, T] ): Instantiation[T] = { @@ -276,10 +252,8 @@ object Template { var instantiation: Instantiation[T] = (newClauses, newBlockers, Map.empty) - for ((idT, lambda) <- lambdas) { - val newIdT = substituter(idT) - val newTemplate = lambda.substitute(substMap) - instantiation ++= manager.instantiateLambda(newIdT, newTemplate) + for (lambda <- lambdas) { + instantiation ++= manager.instantiateLambda(lambda.substitute(substituter)) } for ((b,apps) <- applications; bp = substituter(b); app <- apps) { @@ -292,7 +266,7 @@ object Template { } for (q <- quantifications) { - instantiation ++= q.instantiate(substMap) + instantiation ++= manager.instantiateQuantification(q, substMap) } instantiation @@ -309,9 +283,10 @@ object FunctionTemplate { arguments: Seq[(Identifier, T)], condVars: Map[Identifier, T], exprVars: Map[Identifier, T], + condTree: Map[Identifier, Set[Identifier]], guardedExprs: Map[Identifier, Seq[Expr]], quantifications: Seq[QuantificationTemplate[T]], - lambdas: Map[T, LambdaTemplate[T]], + lambdas: Seq[LambdaTemplate[T]], isRealFunDef: Boolean ) : FunctionTemplate[T] = { @@ -330,10 +305,11 @@ object FunctionTemplate { tfd, encoder, manager, - pathVar._2, + pathVar, arguments.map(_._2), condVars, exprVars, + condTree, clauses, blockers, applications, @@ -350,16 +326,17 @@ class FunctionTemplate[T] private( val tfd: TypedFunDef, val encoder: TemplateEncoder[T], val manager: QuantificationManager[T], - val start: T, + val pathVar: (Identifier, T), val args: Seq[T], val condVars: Map[Identifier, T], val exprVars: Map[Identifier, T], + val condTree: Map[Identifier, Set[Identifier]], val clauses: Seq[T], val blockers: Map[T, Set[TemplateCallInfo[T]]], val applications: Map[T, Set[App[T]]], val quantifications: Seq[QuantificationTemplate[T]], val matchers: Map[T, Set[Matcher[T]]], - val lambdas: Map[T, LambdaTemplate[T]], + val lambdas: Seq[LambdaTemplate[T]], isRealFunDef: Boolean, stringRepr: () => String) extends Template[T] { @@ -367,151 +344,45 @@ class FunctionTemplate[T] private( override def toString : String = str override def instantiate(aVar: T, args: Seq[T]): (Seq[T], Map[T, Set[TemplateCallInfo[T]]], Map[(T, App[T]), Set[TemplateAppInfo[T]]]) = { - if (!isRealFunDef) manager.registerFree(tfd.params.map(_.getType) zip args) + if (!isRealFunDef) manager.registerFree(tfd.params.map(_.id) zip args) super.instantiate(aVar, args) } } -object LambdaTemplate { +class TemplateManager[T](protected[templates] val encoder: TemplateEncoder[T]) extends IncrementalState { + private val condImplies = new IncrementalMap[T, Set[T]].withDefaultValue(Set.empty) + private val condImplied = new IncrementalMap[T, Set[T]].withDefaultValue(Set.empty) - def apply[T]( - ids: (Identifier, T), - encoder: TemplateEncoder[T], - manager: QuantificationManager[T], - pathVar: (Identifier, T), - arguments: Seq[(Identifier, T)], - condVars: Map[Identifier, T], - exprVars: Map[Identifier, T], - guardedExprs: Map[Identifier, Seq[Expr]], - lambdas: Map[T, LambdaTemplate[T]], - baseSubstMap: Map[Identifier, T], - dependencies: Map[Identifier, T], - lambda: Lambda - ) : LambdaTemplate[T] = { - - val id = ids._2 - val tpe = ids._1.getType.asInstanceOf[FunctionType] - val (clauses, blockers, applications, matchers, templateString) = - Template.encode(encoder, pathVar, arguments, condVars, exprVars, guardedExprs, lambdas, - substMap = baseSubstMap + ids, optApp = Some(id -> tpe)) + protected def incrementals: List[IncrementalState] = List(condImplies, condImplied) - val lambdaString : () => String = () => { - "Template for lambda " + ids._1 + ": " + lambda + " is :\n" + templateString() - } + def clear(): Unit = incrementals.foreach(_.clear()) + def reset(): Unit = incrementals.foreach(_.reset()) + def push(): Unit = incrementals.foreach(_.push()) + def pop(): Unit = incrementals.foreach(_.pop()) - val (structuralLambda, structSubst) = normalizeStructure(lambda) - val keyDeps = dependencies.map { case (id, idT) => structSubst(id) -> idT } - val key = structuralLambda.asInstanceOf[Lambda] + def freshConds(path: (Identifier, T), condVars: Map[Identifier, T], tree: Map[Identifier, Set[Identifier]]): Map[T, T] = { + val subst = condVars.map { case (id, idT) => idT -> encoder.encodeId(id) } + val mapping = condVars.mapValues(subst) + path - new LambdaTemplate[T]( - ids._1, - encoder, - manager, - pathVar._2, - arguments.map(_._2), - condVars, - exprVars, - clauses, - blockers, - applications, - matchers, - lambdas, - keyDeps, - key, - lambdaString - ) - } -} - -class LambdaTemplate[T] private ( - val id: Identifier, - val encoder: TemplateEncoder[T], - val manager: QuantificationManager[T], - val start: T, - val args: Seq[T], - val condVars: Map[Identifier, T], - val exprVars: Map[Identifier, T], - val clauses: Seq[T], - val blockers: Map[T, Set[TemplateCallInfo[T]]], - val applications: Map[T, Set[App[T]]], - val matchers: Map[T, Set[Matcher[T]]], - val lambdas: Map[T, LambdaTemplate[T]], - private[templates] val dependencies: Map[Identifier, T], - private[templates] val structuralKey: Lambda, - stringRepr: () => String) extends Template[T] { - - // Universal quantification is not allowed inside closure bodies! - val quantifications: Seq[QuantificationTemplate[T]] = Seq.empty - - val tpe = id.getType.asInstanceOf[FunctionType] - - def substitute(substMap: Map[T,T]): LambdaTemplate[T] = { - val substituter : T => T = encoder.substitute(substMap) - - val newStart = substituter(start) - val newClauses = clauses.map(substituter) - val newBlockers = blockers.map { case (b, fis) => - val bp = if (b == start) newStart else b - bp -> fis.map(fi => fi.copy(args = fi.args.map(substituter))) - } - - val newApplications = applications.map { case (b, fas) => - val bp = if (b == start) newStart else b - bp -> fas.map(fa => fa.copy(caller = substituter(fa.caller), args = fa.args.map(substituter))) + for ((parent, children) <- tree; ep = mapping(parent); child <- children) { + val ec = mapping(child) + condImplies += ep -> (condImplies(ep) + ec) + condImplied += ec -> (condImplied(ec) + ep) } - val newMatchers = matchers.map { case (b, ms) => - val bp = if (b == start) newStart else b - bp -> ms.map(_.substitute(substituter)) - } - - val newLambdas = lambdas.map { case (idT, template) => idT -> template.substitute(substMap) } - - val newDependencies = dependencies.map(p => p._1 -> substituter(p._2)) - - new LambdaTemplate[T]( - id, - encoder, - manager, - newStart, - args, - condVars, - exprVars, - newClauses, - newBlockers, - newApplications, - newMatchers, - newLambdas, - newDependencies, - structuralKey, - stringRepr - ) + subst } - private lazy val str : String = stringRepr() - override def toString : String = str - - def contextEquality(that: LambdaTemplate[T]) : Option[Seq[T]] = { - if (structuralKey != that.structuralKey) { - None // can't be equal - } else if (dependencies.isEmpty) { - Some(Seq.empty) // must be equal - } else { - def rec(e1: Expr, e2: Expr): Seq[T] = (e1,e2) match { - case (Variable(id1), Variable(id2)) => - if (dependencies.isDefinedAt(id1)) { - Seq(encoder.mkEquals(dependencies(id1), that.dependencies(id2))) - } else { - Seq.empty - } - - case (Operator(es1, _), Operator(es2, _)) => - (es1 zip es2).flatMap(p => rec(p._1, p._2)) - - case _ => Seq.empty - } - - Some(rec(structuralKey, that.structuralKey)) + def blocker(b: T): Unit = condImplies += (b -> Set.empty) + def isBlocker(b: T): Boolean = condImplies.isDefinedAt(b) || condImplied.isDefinedAt(b) + + def implies(b1: T, b2: T): Unit = implies(b1, Set(b2)) + def implies(b1: T, b2s: Set[T]): Unit = { + val fb2s = b2s.filter(_ != b1) + condImplies += b1 -> (condImplies(b1) ++ fb2s) + for (b2 <- fb2s) { + condImplied += b2 -> (condImplies(b2) + b1) } } + } diff --git a/src/main/scala/leon/solvers/templates/UnrollingBank.scala b/src/main/scala/leon/solvers/templates/UnrollingBank.scala index ddfb22b0bdbd2a3be90e8950b7a21144b472d299..d1c4432a3c69f1882c9b4a1787a12dc8cf64f520 100644 --- a/src/main/scala/leon/solvers/templates/UnrollingBank.scala +++ b/src/main/scala/leon/solvers/templates/UnrollingBank.scala @@ -19,7 +19,8 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat private val manager = templateGenerator.manager // Function instantiations have their own defblocker - private val defBlockers = new IncrementalMap[TemplateCallInfo[T], T]() + private val defBlockers = new IncrementalMap[TemplateCallInfo[T], T]() + private val lambdaBlockers = new IncrementalMap[TemplateAppInfo[T], T]() // Keep which function invocation is guarded by which guard, // also specify the generation of the blocker. @@ -32,6 +33,7 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat def push() { callInfos.push() defBlockers.push() + lambdaBlockers.push() appInfos.push() appBlockers.push() blockerToApps.push() @@ -41,6 +43,7 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat def pop() { callInfos.pop() defBlockers.pop() + lambdaBlockers.pop() appInfos.pop() appBlockers.pop() blockerToApps.pop() @@ -50,6 +53,7 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat def clear() { callInfos.clear() defBlockers.clear() + lambdaBlockers.clear() appInfos.clear() appBlockers.clear() blockerToApps.clear() @@ -59,6 +63,7 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat def reset() { callInfos.reset() defBlockers.reset() + lambdaBlockers.reset() appInfos.reset() appBlockers.reset() blockerToApps.clear() @@ -257,6 +262,7 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat // we need to define this defBlocker and link it to definition val defBlocker = encoder.encodeId(FreshIdentifier("d", BooleanType)) defBlockers += info -> defBlocker + manager.implies(id, defBlocker) val template = templateGenerator.mkTemplate(tfd) //reporter.debug(template) @@ -279,7 +285,7 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat // We connect it to the defBlocker: blocker => defBlocker if (defBlocker != id) { - newCls ++= List(encoder.mkImplies(id, defBlocker)) + newCls :+= encoder.mkImplies(id, defBlocker) } reporter.debug("Unrolling behind "+info+" ("+newCls.size+")") @@ -293,22 +299,32 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat for ((app @ (b, _), (gen, _, _, _, infos)) <- thisAppInfos; info @ TemplateAppInfo(template, equals, args) <- infos) { var newCls = Seq.empty[T] - val nb = encoder.encodeId(FreshIdentifier("b", BooleanType, true)) - newCls :+= encoder.mkEquals(nb, encoder.mkAnd(equals, b)) + val lambdaBlocker = lambdaBlockers.get(info) match { + case Some(lambdaBlocker) => lambdaBlocker - val (newExprs, callBlocks, appBlocks) = template.instantiate(nb, args) - val blockExprs = freshAppBlocks(appBlocks.keys) + case None => + val lambdaBlocker = encoder.encodeId(FreshIdentifier("d", BooleanType)) + lambdaBlockers += info -> lambdaBlocker + manager.implies(b, lambdaBlocker) - for ((b, newInfos) <- callBlocks) { - registerCallBlocker(nextGeneration(gen), b, newInfos) - } + val (newExprs, callBlocks, appBlocks) = template.instantiate(lambdaBlocker, args) + val blockExprs = freshAppBlocks(appBlocks.keys) + + for ((b, newInfos) <- callBlocks) { + registerCallBlocker(nextGeneration(gen), b, newInfos) + } - for ((newApp, newInfos) <- appBlocks) { - registerAppBlocker(nextGeneration(gen), newApp, newInfos) + for ((newApp, newInfos) <- appBlocks) { + registerAppBlocker(nextGeneration(gen), newApp, newInfos) + } + + newCls ++= newExprs + newCls ++= blockExprs + lambdaBlocker } - newCls ++= newExprs - newCls ++= blockExprs + val enabler = if (equals == manager.trueT) b else encoder.mkAnd(equals, b) + newCls :+= encoder.mkImplies(enabler, lambdaBlocker) reporter.debug("Unrolling behind "+info+" ("+newCls.size+")") for (cl <- newCls) { @@ -318,6 +334,12 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat newClauses ++= newCls } + /* + for ((app @ (b, _), (gen, _, _, _, infos)) <- thisAppInfos if infos.isEmpty) { + registerAppBlocker(nextGeneration(gen), app, infos) + } + */ + reporter.debug(s" - ${newClauses.size} new clauses") //context.reporter.ifDebug { debug => // debug(s" - new clauses:") diff --git a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala index 9a06cff266e221b0459b7d23e9417b4940ec5743..bc6aaf207c071e694f6154fca23d128479aa50d0 100644 --- a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala @@ -19,6 +19,9 @@ import purescala.Types._ import scala.collection.mutable.{Map => MutableMap} +case class UnsoundExtractionException(ast: Z3AST, msg: String) + extends Exception("Can't extract " + ast + " : " + msg) + // This is just to factor out the things that are common in "classes that deal // with a Z3 instance" trait AbstractZ3Solver extends Solver { @@ -34,6 +37,9 @@ trait AbstractZ3Solver extends Solver { private[this] var freed = false val traceE = new Exception() + protected def unsound(ast: Z3AST, msg: String): Nothing = + throw UnsoundExtractionException(ast, msg) + override def finalize() { if (!freed) { println("!! Solver "+this.getClass.getName+"["+this.hashCode+"] not freed properly prior to GC:") @@ -85,10 +91,11 @@ trait AbstractZ3Solver extends Solver { protected val adtManager = new ADTManager(context) // Bijections between Leon Types/Functions/Ids to Z3 Sorts/Decls/ASTs - protected val functions = new IncrementalBijection[TypedFunDef, Z3FuncDecl]() - protected val generics = new IncrementalBijection[GenericValue, Z3FuncDecl]() - protected val sorts = new IncrementalBijection[TypeTree, Z3Sort]() - protected val variables = new IncrementalBijection[Expr, Z3AST]() + protected val functions = new IncrementalBijection[TypedFunDef, Z3FuncDecl]() + protected val generics = new IncrementalBijection[GenericValue, Z3FuncDecl]() + protected val lambdas = new IncrementalBijection[FunctionType, Z3FuncDecl]() + protected val sorts = new IncrementalBijection[TypeTree, Z3Sort]() + protected val variables = new IncrementalBijection[Expr, Z3AST]() protected val constructors = new IncrementalBijection[TypeTree, Z3FuncDecl]() protected val selectors = new IncrementalBijection[(TypeTree, Int), Z3FuncDecl]() @@ -102,6 +109,7 @@ trait AbstractZ3Solver extends Solver { z3 = new Z3Context(z3cfg) functions.clear() + lambdas.clear() generics.clear() sorts.clear() variables.clear() @@ -184,7 +192,6 @@ trait AbstractZ3Solver extends Solver { } } } - } // Prepares some of the Z3 sorts, but *not* the tuple sorts; these are created on-demand. @@ -224,7 +231,6 @@ trait AbstractZ3Solver extends Solver { declareStructuralSort(tpe) } - case tt @ SetType(base) => sorts.cachedB(tt) { z3.mkSetSort(typeToSort(base)) @@ -251,10 +257,8 @@ trait AbstractZ3Solver extends Solver { case ft @ FunctionType(from, to) => sorts.cachedB(ft) { - val fromSort = typeToSort(tupleTypeWrap(from)) - val toSort = typeToSort(to) - - z3.mkArraySort(fromSort, toSort) + val symbol = z3.mkFreshStringSymbol(ft.toString) + z3.mkUninterpretedSort(symbol) } case other => @@ -309,12 +313,19 @@ trait AbstractZ3Solver extends Solver { newAST } case v @ Variable(id) => z3Vars.get(id) match { - case Some(ast) => ast + case Some(ast) => + ast case None => { - val newAST = z3.mkFreshConst(id.uniqueName, typeToSort(v.getType)) - z3Vars = z3Vars + (id -> newAST) - variables += (v -> newAST) - newAST + variables.getB(v) match { + case Some(ast) => + ast + + case None => + val newAST = z3.mkFreshConst(id.uniqueName, typeToSort(v.getType)) + z3Vars = z3Vars + (id -> newAST) + variables += (v -> newAST) + newAST + } } } @@ -483,7 +494,15 @@ trait AbstractZ3Solver extends Solver { z3.mkApp(functionDefToDecl(tfd), args.map(rec): _*) case fa @ Application(caller, args) => - z3.mkSelect(rec(caller), rec(tupleWrap(args))) + val ft @ FunctionType(froms, to) = normalizeType(caller.getType) + val funDecl = lambdas.cachedB(ft) { + val sortSeq = (ft +: froms).map(tpe => typeToSort(tpe)) + val returnSort = typeToSort(to) + + val name = FreshIdentifier("dynLambda").uniqueName + z3.mkFreshFuncDecl(name, sortSeq, returnSort) + } + z3.mkApp(funDecl, (caller +: args).map(rec): _*) case ElementOfSet(e, s) => z3.mkSetMember(rec(e), rec(s)) case SubsetOf(s1, s2) => z3.mkSetSubset(rec(s1), rec(s2)) @@ -551,7 +570,7 @@ trait AbstractZ3Solver extends Solver { val kind = z3.getASTKind(t) kind match { - case Z3NumeralIntAST(Some(v)) => { + case Z3NumeralIntAST(Some(v)) => val leading = t.toString.substring(0, 2 min t.toString.length) if(leading == "#x") { _root_.smtlib.common.Hexadecimal.fromString(t.toString.substring(2)) match { @@ -562,29 +581,25 @@ trait AbstractZ3Solver extends Solver { case other => unsupported(other, "Unexpected target type for BV value") } - case None => { - throw LeonFatalError(s"Could not translate hexadecimal Z3 numeral $t") - } + case None => unsound(t, "could not translate hexadecimal Z3 numeral") } } else { InfiniteIntegerLiteral(v) } - } - case Z3NumeralIntAST(None) => { + + case Z3NumeralIntAST(None) => _root_.smtlib.common.Hexadecimal.fromString(t.toString.substring(2)) match { case Some(hexa) => tpe match { case Int32Type => IntLiteral(hexa.toInt) case CharType => CharLiteral(hexa.toInt.toChar) - case _ => - reporter.fatalError("Unexpected target type for BV value: " + tpe.asString) + case _ => unsound(t, "unexpected target type for BV value: " + tpe.asString) } - case None => { - reporter.fatalError(s"Could not translate Z3NumeralIntAST numeral $t") - } + case None => unsound(t, "could not translate Z3NumeralIntAST numeral") } - } + case Z3NumeralRealAST(n: BigInt, d: BigInt) => FractionalLiteral(n, d) + case Z3AppAST(decl, args) => val argsSize = args.size if(argsSize == 0 && (variables containsB t)) { @@ -619,12 +634,12 @@ trait AbstractZ3Solver extends Solver { val entries = elems.map { case (IntLiteral(i), v) => i -> v - case _ => reporter.fatalError("Translation from Z3 to Array failed") + case (e,_) => unsupported(e, s"Z3 returned unexpected array index ${e.asString}") } finiteArray(entries, Some(default, s), to) - case _ => - reporter.fatalError("Translation from Z3 to Array failed") + case (s : IntLiteral, arr) => unsound(args(1), "invalid array type") + case (size, _) => unsound(args(0), "invalid array size") } } } else { @@ -638,9 +653,25 @@ trait AbstractZ3Solver extends Solver { } RawArrayValue(from, entries, default) - case None => reporter.fatalError("Translation from Z3 to Array failed") + case None => unsound(t, "invalid array AST") } + case ft @ FunctionType(fts, tt) => lambdas.getB(ft) match { + case None => simplestValue(ft) + case Some(decl) => model.getModelFuncInterpretations.find(_._1 == decl) match { + case None => simplestValue(ft) + case Some((_, mapping, elseValue)) => + val leonElseValue = rec(elseValue, tt) + PartialLambda(mapping.flatMap { case (z3Args, z3Result) => + if (t == z3Args.head) { + List((z3Args.tail zip fts).map(p => rec(p._1, p._2)) -> rec(z3Result, tt)) + } else { + Nil + } + }, Some(leonElseValue), ft) + } + } + case tp: TypeParameter => val id = t.toString.split("!").last.toInt GenericValue(tp, id) @@ -663,12 +694,9 @@ trait AbstractZ3Solver extends Solver { FiniteMap(elems, from, to) } - case FunctionType(fts, tt) => - rec(t, RawArrayType(tupleTypeWrap(fts), tt)) - case tpe @ SetType(dt) => model.getSetValue(t) match { - case None => reporter.fatalError("Translation from Z3 to set failed") + case None => unsound(t, "invalid set AST") case Some(set) => val elems = set.map(e => rec(e, dt)) FiniteSet(elems, dt) @@ -698,8 +726,7 @@ trait AbstractZ3Solver extends Solver { // case OpDiv => Division(rargs(0), rargs(1)) // case OpIDiv => Division(rargs(0), rargs(1)) // case OpMod => Modulo(rargs(0), rargs(1)) - case other => - reporter.fatalError( + case other => unsound(t, s"""|Don't know what to do with this declKind: $other |Expected type: ${Option(tpe).map{_.asString}.getOrElse("")} |Tree: $t @@ -708,11 +735,11 @@ trait AbstractZ3Solver extends Solver { } } } - case _ => - reporter.fatalError(s"Don't know what to do with this Z3 tree: $t") + case _ => unsound(t, "unexpected AST") } } - rec(tree, tpe) + + rec(tree, normalizeType(tpe)) } protected[leon] def softFromZ3Formula(model: Z3Model, tree: Z3AST, tpe: TypeTree) : Option[Expr] = { @@ -720,6 +747,8 @@ trait AbstractZ3Solver extends Solver { Some(fromZ3Formula(model, tree, tpe)) } catch { case e: Unsupported => None + case e: UnsoundExtractionException => None + case n: java.lang.NumberFormatException => None } } diff --git a/src/main/scala/leon/solvers/z3/FairZ3Component.scala b/src/main/scala/leon/solvers/z3/FairZ3Component.scala index f321d516d66a4d43751519f06e9c65b3a0e4d126..70ebd260f931a6a428a84afad9640accf193887d 100644 --- a/src/main/scala/leon/solvers/z3/FairZ3Component.scala +++ b/src/main/scala/leon/solvers/z3/FairZ3Component.scala @@ -13,9 +13,11 @@ trait FairZ3Component extends LeonComponent { val optUseCodeGen = LeonFlagOptionDef("codegen", "Use compiled evaluator instead of interpreter", false) val optUnrollCores = LeonFlagOptionDef("unrollcores", "Use unsat-cores to drive unrolling while remaining fair", false) val optAssumePre = LeonFlagOptionDef("assumepre", "Assume precondition holds (pre && f(x) = body) when unfolding", false) + val optNoChecks = LeonFlagOptionDef("nochecks", "Disable counter-example check in presence of foralls" , false) + val optUnfoldFactor = LeonLongOptionDef("unfoldFactor", "Number of unfoldings to perform in each unfold step", default = 1, "<PosInt>") override val definedOptions: Set[LeonOptionDef[Any]] = - Set(optEvalGround, optCheckModels, optFeelingLucky, optUseCodeGen, optUnrollCores, optAssumePre) + Set(optEvalGround, optCheckModels, optFeelingLucky, optUseCodeGen, optUnrollCores, optAssumePre, optUnfoldFactor) } object FairZ3Component extends FairZ3Component diff --git a/src/main/scala/leon/solvers/z3/FairZ3Solver.scala b/src/main/scala/leon/solvers/z3/FairZ3Solver.scala index 5baa9b9aa22713f1b4e86782570e124d48aac767..6a96848c5fb283ec1a6ae22817afff95f9300122 100644 --- a/src/main/scala/leon/solvers/z3/FairZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/FairZ3Solver.scala @@ -26,7 +26,8 @@ class FairZ3Solver(val context: LeonContext, val program: Program) extends AbstractZ3Solver with Z3ModelReconstruction with FairZ3Component - with EvaluatingSolver { + with EvaluatingSolver + with QuantificationSolver { enclosing => @@ -36,6 +37,9 @@ class FairZ3Solver(val context: LeonContext, val program: Program) val evalGroundApps = context.findOptionOrDefault(optEvalGround) val unrollUnsatCores = context.findOptionOrDefault(optUnrollCores) val assumePreHolds = context.findOptionOrDefault(optAssumePre) + val disableChecks = context.findOptionOrDefault(optNoChecks) + + assert(!checkModels || !disableChecks, "Options \"checkmodels\" and \"nochecks\" are mutually exclusive") protected val errors = new IncrementalBijection[Unit, Boolean]() protected def hasError = errors.getB(()) contains true @@ -56,131 +60,45 @@ class FairZ3Solver(val context: LeonContext, val program: Program) toggleWarningMessages(true) private def extractModel(model: Z3Model, ids: Set[Identifier]): HenkinModel = { - val asMap = modelToMap(model, ids) - def extract(b: Z3AST, m: Matcher[Z3AST]): Set[Seq[Expr]] = { val QuantificationTypeMatcher(fromTypes, _) = m.tpe val optEnabler = model.evalAs[Boolean](b) if (optEnabler == Some(true)) { - // FIXME: Dirty hack! - // Unfortunately, blockers don't lead to a true decision tree where all - // unexecutable branches are false since we have - // b1 ==> ( b2 \/ b3) - // b1 ==> (!b2 \/ !b3) - // so b2 /\ b3 is possible when b1 is false. This leads to unsound models - // (like Nil.tail) which obviously cannot be part of a domain but can't be - // translated back from Z3 either. - try { - val optArgs = (m.args zip fromTypes).map { - p => softFromZ3Formula(model, model.eval(Matcher.argValue(p._1), true).get, p._2) - } + val optArgs = (m.args zip fromTypes).map { + p => softFromZ3Formula(model, model.eval(Matcher.argValue(p._1), true).get, p._2) + } - if (optArgs.forall(_.isDefined)) { - Set(optArgs.map(_.get)) - } else { - Set.empty - } - } catch { - case _: Throwable => Set.empty + if (optArgs.forall(_.isDefined)) { + Set(optArgs.map(_.get)) + } else { + Set.empty } } else { Set.empty } } - val funDomains = ids.flatMap(id => id.getType match { - case ft @ FunctionType(fromTypes, _) => variables.getB(id.toVariable) match { - case Some(z3ID) => Some(id -> templateGenerator.manager.instantiations(z3ID, ft).flatMap { - case (b, m) => extract(b, m) - }) - case _ => None - } - case _ => None - }).toMap.mapValues(_.toSet) - - val asDMap = asMap.map(p => funDomains.get(p._1) match { - case Some(domain) => - val mapping = domain.toSeq.map { es => - val ev: Expr = p._2 match { - case RawArrayValue(_, mapping, dflt) => - mapping.collectFirst { - case (k,v) if evaluator.eval(Equals(k, tupleWrap(es))).result == Some(BooleanLiteral(true)) => v - } getOrElse dflt - case _ => scala.sys.error("Unexpected function encoding " + p._2) - } - es -> ev - } - p._1 -> PartialLambda(mapping, p._1.getType.asInstanceOf[FunctionType]) - case None => p - }) - - val typeGrouped = templateGenerator.manager.instantiations.groupBy(_._2.tpe) - val typeDomains = typeGrouped.mapValues(_.flatMap { case (b, m) => extract(b, m) }.toSet) + val (typeInsts, partialInsts, lambdaInsts) = templateGenerator.manager.instantiations - val domain = new HenkinDomains(typeDomains) - new HenkinModel(asDMap, domain) - } - - private def validateModel(model: Z3Model, formula: Expr, variables: Set[Identifier], silenceErrors: Boolean) : (Boolean, HenkinModel) = { - if(!interrupted) { - - val functionsModel: Map[Z3FuncDecl, (Seq[(Seq[Z3AST], Z3AST)], Z3AST)] = model.getModelFuncInterpretations.map(i => (i._1, (i._2, i._3))).toMap - val functionsAsMap: Map[Identifier, Expr] = functionsModel.flatMap(p => { - if (functions containsB p._1) { - val tfd = functions.toA(p._1) - if (!tfd.hasImplementation) { - val (cses, default) = p._2 - val ite = cses.foldLeft(fromZ3Formula(model, default, tfd.returnType))((expr, q) => IfExpr( - andJoin( - q._1.zip(tfd.params).map(a12 => Equals(fromZ3Formula(model, a12._1, a12._2.getType), Variable(a12._2.id))) - ), - fromZ3Formula(model, q._2, tfd.returnType), - expr)) - Seq((tfd.id, ite)) - } else Seq() - } else Seq() - }) - - val constantFunctionsAsMap: Map[Identifier, Expr] = model.getModelConstantInterpretations.flatMap(p => { - if(functions containsB p._1) { - val tfd = functions.toA(p._1) - if(!tfd.hasImplementation) { - Seq((tfd.id, fromZ3Formula(model, p._2, tfd.returnType))) - } else Seq() - } else Seq() - }).toMap - - val leonModel = extractModel(model, variables) - val fullModel = leonModel ++ (functionsAsMap ++ constantFunctionsAsMap) - val evalResult = evaluator.eval(formula, fullModel) - - evalResult match { - case EvaluationResults.Successful(BooleanLiteral(true)) => - reporter.debug("- Model validated.") - (true, fullModel) - - case EvaluationResults.Successful(res) => - assert(res == BooleanLiteral(false), "Checking model returned non-boolean") - reporter.debug("- Invalid model.") - (false, fullModel) - - case EvaluationResults.RuntimeError(msg) => - reporter.debug("- Model leads to runtime error.") - (false, fullModel) - - case EvaluationResults.EvaluatorError(msg) => - if (silenceErrors) { - reporter.debug("- Model leads to evaluator error: " + msg) - } else { - reporter.warning("Something went wrong. While evaluating the model, we got this : " + msg) - } - (false, fullModel) + val typeDomains: Map[TypeTree, Set[Seq[Expr]]] = typeInsts.map { + case (tpe, domain) => tpe -> domain.flatMap { case (b, m) => extract(b, m) }.toSet + } + val funDomains: Map[Identifier, Set[Seq[Expr]]] = partialInsts.flatMap { + case (c, domain) => variables.getA(c).collect { + case Variable(id) => id -> domain.flatMap { case (b, m) => extract(b, m) }.toSet } - } else { - (false, HenkinModel.empty) } + + val lambdaDomains: Map[Lambda, Set[Seq[Expr]]] = lambdaInsts.map { + case (l, domain) => l -> domain.flatMap { case (b, m) => extract(b, m) }.toSet + } + + val asMap = modelToMap(model, ids) + val asDMap = purescala.Quantification.extractModel(asMap, funDomains, typeDomains, evaluator) + val domains = new HenkinDomains(lambdaDomains, typeDomains) + new HenkinModel(asDMap, domains) } implicit val z3Printable = (z3: Z3AST) => new Printable { @@ -218,23 +136,21 @@ class FairZ3Solver(val context: LeonContext, val program: Program) private val freeVars = new IncrementalSet[Identifier]() private val constraints = new IncrementalSeq[Expr]() - val unrollingBank = new UnrollingBank(context, templateGenerator) + private val incrementals: List[IncrementalState] = List( + errors, freeVars, constraints, functions, generics, lambdas, sorts, variables, + constructors, selectors, testers, unrollingBank + ) + def push() { - errors.push() solver.push() - unrollingBank.push() - freeVars.push() - constraints.push() + incrementals.foreach(_.push()) } def pop() { - errors.pop() solver.pop(1) - unrollingBank.pop() - freeVars.pop() - constraints.pop() + incrementals.foreach(_.pop()) } override def check: Option[Boolean] = { @@ -315,6 +231,115 @@ class FairZ3Solver(val context: LeonContext, val program: Program) }).toSet } + def validatedModel(silenceErrors: Boolean) : (Boolean, HenkinModel) = { + if (interrupted) { + (false, HenkinModel.empty) + } else { + val lastModel = solver.getModel + val clauses = templateGenerator.manager.checkClauses + val optModel = if (clauses.isEmpty) Some(lastModel) else { + solver.push() + for (clause <- clauses) { + solver.assertCnstr(clause) + } + + reporter.debug(" - Verifying model transitivity") + val timer = context.timers.solvers.z3.check.start() + solver.push() // FIXME: remove when z3 bug is fixed + val res = solver.checkAssumptions((assumptionsAsZ3 ++ unrollingBank.satisfactionAssumptions) :_*) + solver.pop() // FIXME: remove when z3 bug is fixed + timer.stop() + + val solverModel = res match { + case Some(true) => + Some(solver.getModel) + + case Some(false) => + val msg = "- Transitivity independence not guaranteed for model" + if (silenceErrors) { + reporter.debug(msg) + } else { + reporter.warning(msg) + } + None + + case None => + val msg = "- Unknown for transitivity independence!?" + if (silenceErrors) { + reporter.debug(msg) + } else { + reporter.warning(msg) + } + None + } + + solver.pop() + solverModel + } + + val model = optModel getOrElse lastModel + + val functionsModel: Map[Z3FuncDecl, (Seq[(Seq[Z3AST], Z3AST)], Z3AST)] = model.getModelFuncInterpretations.map(i => (i._1, (i._2, i._3))).toMap + val functionsAsMap: Map[Identifier, Expr] = functionsModel.flatMap(p => { + if (functions containsB p._1) { + val tfd = functions.toA(p._1) + if (!tfd.hasImplementation) { + val (cses, default) = p._2 + val ite = cses.foldLeft(fromZ3Formula(model, default, tfd.returnType))((expr, q) => IfExpr( + andJoin( + q._1.zip(tfd.params).map(a12 => Equals(fromZ3Formula(model, a12._1, a12._2.getType), Variable(a12._2.id))) + ), + fromZ3Formula(model, q._2, tfd.returnType), + expr)) + Seq((tfd.id, ite)) + } else Seq() + } else Seq() + }) + + val constantFunctionsAsMap: Map[Identifier, Expr] = model.getModelConstantInterpretations.flatMap(p => { + if(functions containsB p._1) { + val tfd = functions.toA(p._1) + if(!tfd.hasImplementation) { + Seq((tfd.id, fromZ3Formula(model, p._2, tfd.returnType))) + } else Seq() + } else Seq() + }).toMap + + val leonModel = extractModel(model, freeVars.toSet) + val fullModel = leonModel ++ (functionsAsMap ++ constantFunctionsAsMap) + + if (!optModel.isDefined) { + (false, leonModel) + } else { + (evaluator.check(entireFormula, fullModel) match { + case EvaluationResults.CheckSuccess => + reporter.debug("- Model validated.") + true + + case EvaluationResults.CheckValidityFailure => + reporter.debug("- Invalid model.") + false + + case EvaluationResults.CheckRuntimeFailure(msg) => + if (silenceErrors) { + reporter.debug("- Model leads to evaluation error: " + msg) + } else { + reporter.warning("- Model leads to evaluation error: " + msg) + } + false + + case EvaluationResults.CheckQuantificationFailure(msg) => + if (silenceErrors) { + reporter.debug("- Model leads to quantification error: " + msg) + } else { + reporter.warning("- Model leads to quantification error: " + msg) + } + false + }, leonModel) + } + } + } + while(!foundDefinitiveAnswer && !interrupted) { //val blockingSetAsZ3 : Seq[Z3AST] = blockingSet.toSeq.map(toZ3Formula(_).get) @@ -346,27 +371,18 @@ class FairZ3Solver(val context: LeonContext, val program: Program) foundAnswer(None) case Some(true) => // SAT - - val z3model = solver.getModel() - - if (this.checkModels) { - val (isValid, model) = validateModel(z3model, entireFormula, allVars, silenceErrors = false) - - if (isValid) { - foundAnswer(Some(true), model) - } else { - reporter.error("Something went wrong. The model should have been valid, yet we got this : ") - reporter.error(model) - foundAnswer(None, model) - } + val (valid, model) = if (!this.disableChecks && (this.checkModels || requireQuantification)) { + validatedModel(false) } else { - val model = extractModel(z3model, allVars) - - //lazy val modelAsString = model.toList.map(p => p._1 + " -> " + p._2).mkString("\n") - //reporter.debug("- Found a model:") - //reporter.debug(modelAsString) + true -> extractModel(solver.getModel, allVars) + } + if (valid) { foundAnswer(Some(true), model) + } else { + reporter.error("Something went wrong. The model should have been valid, yet we got this : ") + reporter.error(model.asString(context)) + foundAnswer(None, model) } case Some(false) if !unrollingBank.canUnroll => @@ -433,7 +449,7 @@ class FairZ3Solver(val context: LeonContext, val program: Program) //reporter.debug("SAT WITHOUT Blockers") if (this.feelingLucky && !interrupted) { // we might have been lucky :D - val (wereWeLucky, cleanModel) = validateModel(solver.getModel, entireFormula, allVars, silenceErrors = true) + val (wereWeLucky, cleanModel) = validatedModel(true) if(wereWeLucky) { foundAnswer(Some(true), cleanModel) diff --git a/src/main/scala/leon/synthesis/ConversionPhase.scala b/src/main/scala/leon/synthesis/ConversionPhase.scala index d76cff69b1a630718632f70e8684ff2426506cd2..d5f4e842f698e54d353d4170c1591073debefd5e 100644 --- a/src/main/scala/leon/synthesis/ConversionPhase.scala +++ b/src/main/scala/leon/synthesis/ConversionPhase.scala @@ -64,7 +64,7 @@ object ConversionPhase extends UnitPhase[Program] { * * def foo(a: T) = { * require(..a..) - * ??? + * _ * } ensuring { res => * post(res) * } @@ -75,6 +75,19 @@ object ConversionPhase extends UnitPhase[Program] { * require(..a..) * choose(x => post(x)) * } + * (in practice, there will be no pre-and postcondition) + * + * 4) Functions that have only a choose as body gets their spec from the choose. + * + * def foo(a: T) = { + * choose(x => post(a, x)) + * } + * + * gets converted to: + * + * def foo(a: T) = { + * choose(x => post(a, x)) + * } ensuring { x => post(a, x) } * */ @@ -115,7 +128,7 @@ object ConversionPhase extends UnitPhase[Program] { } } - body match { + val fullBody = body match { case Some(body) => var holes = List[Identifier]() @@ -172,6 +185,14 @@ object ConversionPhase extends UnitPhase[Program] { val newPost = post getOrElse Lambda(Seq(ValDef(FreshIdentifier("res", e.getType))), BooleanLiteral(true)) withPrecondition(Choose(newPost), pre) } + + // extract spec from chooses at the top-level + fullBody match { + case Choose(spec) => + withPostcondition(fullBody, Some(spec)) + case _ => + fullBody + } } diff --git a/src/main/scala/leon/synthesis/ExamplesBank.scala b/src/main/scala/leon/synthesis/ExamplesBank.scala index 77b451094ae1451ce7e46c9e3312488c7c398058..1dfddd0ca025caf3c7faafce3904c304be362f43 100644 --- a/src/main/scala/leon/synthesis/ExamplesBank.scala +++ b/src/main/scala/leon/synthesis/ExamplesBank.scala @@ -4,7 +4,6 @@ package synthesis import purescala.Definitions._ import purescala.Expressions._ import purescala.Constructors._ -import evaluators._ import purescala.Common._ import repair._ import leon.utils.ASCIIHelpers._ @@ -175,9 +174,7 @@ case class QualifiedExamplesBank(as: List[Identifier], xs: List[Identifier], eb: } def filterIns(expr: Expr): ExamplesBank = { - val ev = new DefaultEvaluator(hctx.sctx.context, hctx.sctx.program) - - filterIns(m => ev.eval(expr, m).result == Some(BooleanLiteral(true))) + filterIns(m => hctx.sctx.defaultEvaluator.eval(expr, m).result == Some(BooleanLiteral(true))) } def filterIns(pred: Map[Identifier, Expr] => Boolean): ExamplesBank = { diff --git a/src/main/scala/leon/synthesis/FileInterface.scala b/src/main/scala/leon/synthesis/FileInterface.scala index 251c6af0e8b11603eb9d57cc282d21c1c95e751f..25edd338fd2111314dd050d55da9ca14209a62d1 100644 --- a/src/main/scala/leon/synthesis/FileInterface.scala +++ b/src/main/scala/leon/synthesis/FileInterface.scala @@ -5,7 +5,6 @@ package synthesis import purescala.Expressions._ import purescala.Common.Tree -import purescala.Definitions.Definition import purescala.ScalaPrinter import purescala.PrinterOptions import purescala.PrinterContext @@ -15,7 +14,7 @@ import leon.utils.RangePosition import java.io.File class FileInterface(reporter: Reporter) { - def updateFile(origFile: File, solutions: Map[ChooseInfo, Expr])(implicit ctx: LeonContext) { + def updateFile(origFile: File, solutions: Map[SourceInfo, Expr])(implicit ctx: LeonContext) { import java.io.{File, BufferedWriter, FileWriter} val FileExt = """^(.+)\.([^.]+)$""".r @@ -35,7 +34,7 @@ class FileInterface(reporter: Reporter) { var newCode = origCode for ( (ci, e) <- solutions) { - newCode = substitute(newCode, ci.ch, e) + newCode = substitute(newCode, ci.source, e) } val out = new BufferedWriter(new FileWriter(newFile)) diff --git a/src/main/scala/leon/synthesis/LinearEquations.scala b/src/main/scala/leon/synthesis/LinearEquations.scala index f68fea066199089f206609be028b84af9056f280..da0f918103a959b5576a3909ebebeab3ec901305 100644 --- a/src/main/scala/leon/synthesis/LinearEquations.scala +++ b/src/main/scala/leon/synthesis/LinearEquations.scala @@ -74,7 +74,7 @@ object LinearEquations { var i = 0 while(i < sols.size) { // seriously ??? - K(i+j+1)(j) = evaluator.eval(sols(i)).asInstanceOf[EvaluationResults.Successful].value.asInstanceOf[InfiniteIntegerLiteral].value + K(i+j+1)(j) = evaluator.eval(sols(i)).asInstanceOf[EvaluationResults.Successful[Expr]].value.asInstanceOf[InfiniteIntegerLiteral].value i += 1 } } diff --git a/src/main/scala/leon/synthesis/Problem.scala b/src/main/scala/leon/synthesis/Problem.scala index f0d266df7c41b6f4a3fbbb3839e2246ca258f200..cf3640f70582fde36189bd4d525a2fc1903f26e6 100644 --- a/src/main/scala/leon/synthesis/Problem.scala +++ b/src/main/scala/leon/synthesis/Problem.scala @@ -38,13 +38,13 @@ case class Problem(as: List[Identifier], ws: Expr, pc: Expr, phi: Expr, xs: List } object Problem { - def fromChoose(ch: Choose, pc: Expr = BooleanLiteral(true), eb: ExamplesBank = ExamplesBank.empty): Problem = { + def fromSpec(spec: Expr, pc: Expr = BooleanLiteral(true), eb: ExamplesBank = ExamplesBank.empty): Problem = { val xs = { - val tps = ch.pred.getType.asInstanceOf[FunctionType].from - tps map (FreshIdentifier("x", _, true)) + val tps = spec.getType.asInstanceOf[FunctionType].from + tps map (FreshIdentifier("x", _, alwaysShowUniqueID = true)) }.toList - val phi = application(simplifyLets(ch.pred), xs map { _.toVariable}) + val phi = application(simplifyLets(spec), xs map { _.toVariable}) val as = (variablesOf(And(pc, phi)) -- xs).toList.sortBy(_.name) val TopLevelAnds(clauses) = pc @@ -57,10 +57,10 @@ object Problem { Problem(as, andJoin(wss), andJoin(pcs), phi, xs, eb) } - def fromChooseInfo(ci: ChooseInfo): Problem = { + def fromSourceInfo(ci: SourceInfo): Problem = { // Same as fromChoose, but we order the input variables by the arguments of // the functions, so that tests are compatible - val p = fromChoose(ci.ch, ci.pc, ci.eb) + val p = fromSpec(ci.spec, ci.pc, ci.eb) val argsIndex = ci.fd.params.map(_.id).zipWithIndex.toMap.withDefaultValue(100) p.copy( as = p.as.sortBy(a => argsIndex(a))) diff --git a/src/main/scala/leon/synthesis/SearchContext.scala b/src/main/scala/leon/synthesis/SearchContext.scala index 5fc938b77f54c84cd1dac691a078dd600df6f52a..1ee7361d5a878978fede4084b0909a8e711137cd 100644 --- a/src/main/scala/leon/synthesis/SearchContext.scala +++ b/src/main/scala/leon/synthesis/SearchContext.scala @@ -11,7 +11,7 @@ import graph._ */ case class SearchContext ( sctx: SynthesisContext, - ci: ChooseInfo, + ci: SourceInfo, currentNode: Node, search: Search ) { diff --git a/src/main/scala/leon/synthesis/ChooseInfo.scala b/src/main/scala/leon/synthesis/SourceInfo.scala similarity index 74% rename from src/main/scala/leon/synthesis/ChooseInfo.scala rename to src/main/scala/leon/synthesis/SourceInfo.scala index 5e36c593b12f7df74575ad634b0683991245cbb5..4bb10d38c9ffc7a7667d165b84e4f65c1edc9e0c 100644 --- a/src/main/scala/leon/synthesis/ChooseInfo.scala +++ b/src/main/scala/leon/synthesis/SourceInfo.scala @@ -9,17 +9,25 @@ import purescala.Expressions._ import purescala.ExprOps._ import Witnesses._ -case class ChooseInfo(fd: FunDef, +case class SourceInfo(fd: FunDef, pc: Expr, source: Expr, - ch: Choose, + spec: Expr, eb: ExamplesBank) { - val problem = Problem.fromChooseInfo(this) + val problem = Problem.fromSourceInfo(this) } -object ChooseInfo { - def extractFromProgram(ctx: LeonContext, prog: Program): List[ChooseInfo] = { +object SourceInfo { + + class ChooseCollectorWithPaths extends CollectorWithPaths[(Choose,Expr)] { + def collect(e: Expr, path: Seq[Expr]) = e match { + case c: Choose => Some(c -> and(path: _*)) + case _ => None + } + } + + def extractFromProgram(ctx: LeonContext, prog: Program): List[SourceInfo] = { val functions = ctx.findOption(SharedOptions.optFunctions) map { _.toSet } def excludeByDefault(fd: FunDef): Boolean = { @@ -40,9 +48,8 @@ object ChooseInfo { results.sortBy(_.source.getPos) } - def extractFromFunction(ctx: LeonContext, prog: Program, fd: FunDef): Seq[ChooseInfo] = { + def extractFromFunction(ctx: LeonContext, prog: Program, fd: FunDef): Seq[SourceInfo] = { - val actualBody = and(fd.precOrTrue, fd.body.get) val term = Terminating(fd.typed, fd.params.map(_.id.toVariable)) val eFinder = new ExamplesFinder(ctx, prog) @@ -50,14 +57,14 @@ object ChooseInfo { // We are synthesizing, so all examples are valid ones val functionEb = eFinder.extractFromFunDef(fd, partition = false) - for ((ch, path) <- new ChooseCollectorWithPaths().traverse(actualBody)) yield { + for ((ch, path) <- new ChooseCollectorWithPaths().traverse(fd)) yield { val outerEb = if (path == BooleanLiteral(true)) { functionEb } else { ExamplesBank.empty } - val ci = ChooseInfo(fd, and(path, term), ch, ch, outerEb) + val ci = SourceInfo(fd, and(path, term), ch, ch.pred, outerEb) val pcEb = eFinder.generateForPC(ci.problem.as, path, 20) val chooseEb = eFinder.extractFromProblem(ci.problem) diff --git a/src/main/scala/leon/synthesis/SynthesisContext.scala b/src/main/scala/leon/synthesis/SynthesisContext.scala index bb90d2a268411c44e9dfdbde013886b7db5b56d9..b01077ce751f6b53e397b15eb3d5126e560d1b56 100644 --- a/src/main/scala/leon/synthesis/SynthesisContext.scala +++ b/src/main/scala/leon/synthesis/SynthesisContext.scala @@ -6,6 +6,7 @@ package synthesis import solvers._ import solvers.combinators._ import purescala.Definitions.{Program, FunDef} +import evaluators._ /** * This is global information per entire search, contains necessary information @@ -22,6 +23,10 @@ case class SynthesisContext( val rules = settings.rules val solverFactory = SolverFactory.getFromSettings(context, program) + + lazy val defaultEvaluator = { + new DefaultEvaluator(context, program) + } } object SynthesisContext { diff --git a/src/main/scala/leon/synthesis/SynthesisPhase.scala b/src/main/scala/leon/synthesis/SynthesisPhase.scala index ef2bb41e75eb256484f1a663699bd16ca6efa40b..72cfffec3f8ef3fd8a526d778fda1f26a9ec1a41 100644 --- a/src/main/scala/leon/synthesis/SynthesisPhase.scala +++ b/src/main/scala/leon/synthesis/SynthesisPhase.scala @@ -65,8 +65,7 @@ object SynthesisPhase extends TransformationPhase { def apply(ctx: LeonContext, program: Program): Program = { val options = processOptions(ctx) - - val chooses = ChooseInfo.extractFromProgram(ctx, program) + val chooses = SourceInfo.extractFromProgram(ctx, program) var functions = Set[FunDef]() @@ -75,7 +74,7 @@ object SynthesisPhase extends TransformationPhase { val synthesizer = new Synthesizer(ctx, program, ci, options) - val (search, solutions) = synthesizer.validate(synthesizer.synthesize(), true) + val (search, solutions) = synthesizer.validate(synthesizer.synthesize(), allowPartial = true) try { if (options.generateDerivationTrees) { diff --git a/src/main/scala/leon/synthesis/Synthesizer.scala b/src/main/scala/leon/synthesis/Synthesizer.scala index 54e305adf11ebc9314c262b4be5b203df8337771..35fdfb845ae0a5de611ab3a01a3111005aff7a05 100644 --- a/src/main/scala/leon/synthesis/Synthesizer.scala +++ b/src/main/scala/leon/synthesis/Synthesizer.scala @@ -15,7 +15,7 @@ import synthesis.graph._ class Synthesizer(val context : LeonContext, val program: Program, - val ci: ChooseInfo, + val ci: SourceInfo, val settings: SynthesisSettings) { val problem = ci.problem diff --git a/src/main/scala/leon/synthesis/graph/Search.scala b/src/main/scala/leon/synthesis/graph/Search.scala index c33a31cebe0c430a7c2b21c932529ca5b72ed053..98554a5ae492972e0b7b3915979d9af829d81555 100644 --- a/src/main/scala/leon/synthesis/graph/Search.scala +++ b/src/main/scala/leon/synthesis/graph/Search.scala @@ -10,7 +10,7 @@ import scala.collection.mutable.ArrayBuffer import leon.utils.Interruptible import java.util.concurrent.atomic.AtomicBoolean -abstract class Search(ctx: LeonContext, ci: ChooseInfo, p: Problem, costModel: CostModel) extends Interruptible { +abstract class Search(ctx: LeonContext, ci: SourceInfo, p: Problem, costModel: CostModel) extends Interruptible { val g = new Graph(costModel, p) def findNodeToExpandFrom(n: Node): Option[Node] @@ -83,7 +83,7 @@ abstract class Search(ctx: LeonContext, ci: ChooseInfo, p: Problem, costModel: C ctx.interruptManager.registerForInterrupts(this) } -class SimpleSearch(ctx: LeonContext, ci: ChooseInfo, p: Problem, costModel: CostModel, bound: Option[Int]) extends Search(ctx, ci, p, costModel) { +class SimpleSearch(ctx: LeonContext, ci: SourceInfo, p: Problem, costModel: CostModel, bound: Option[Int]) extends Search(ctx, ci, p, costModel) { val expansionBuffer = ArrayBuffer[Node]() def findIn(n: Node) { @@ -124,7 +124,7 @@ class SimpleSearch(ctx: LeonContext, ci: ChooseInfo, p: Problem, costModel: Cost } } -class ManualSearch(ctx: LeonContext, ci: ChooseInfo, problem: Problem, costModel: CostModel, initCmd: Option[String]) extends Search(ctx, ci, problem, costModel) { +class ManualSearch(ctx: LeonContext, ci: SourceInfo, problem: Problem, costModel: CostModel, initCmd: Option[String]) extends Search(ctx, ci, problem, costModel) { import ctx.reporter._ abstract class Command diff --git a/src/main/scala/leon/synthesis/rules/ADTDual.scala b/src/main/scala/leon/synthesis/rules/ADTDual.scala index 392670edff7420493b7d17d87ff94358f962c172..004a88d04dcde66008f37cef81fdd5e60261dfea 100644 --- a/src/main/scala/leon/synthesis/rules/ADTDual.scala +++ b/src/main/scala/leon/synthesis/rules/ADTDual.scala @@ -18,10 +18,10 @@ case object ADTDual extends NormalizingRule("ADTDual") { val (toRemove, toAdd) = exprs.collect { case eq @ Equals(cc @ CaseClass(ct, args), e) if (variablesOf(e) -- as).isEmpty && (variablesOf(cc) & xs).nonEmpty => - (eq, IsInstanceOf(e, ct) +: (ct.fields zip args).map{ case (vd, ex) => Equals(ex, caseClassSelector(ct, e, vd.id)) } ) + (eq, IsInstanceOf(e, ct) +: (ct.classDef.fields zip args).map{ case (vd, ex) => Equals(ex, caseClassSelector(ct, e, vd.id)) } ) case eq @ Equals(e, cc @ CaseClass(ct, args)) if (variablesOf(e) -- as).isEmpty && (variablesOf(cc) & xs).nonEmpty => - (eq, IsInstanceOf(e, ct) +: (ct.fields zip args).map{ case (vd, ex) => Equals(ex, caseClassSelector(ct, e, vd.id)) } ) + (eq, IsInstanceOf(e, ct) +: (ct.classDef.fields zip args).map{ case (vd, ex) => Equals(ex, caseClassSelector(ct, e, vd.id)) } ) }.unzip if (toRemove.nonEmpty) { diff --git a/src/main/scala/leon/synthesis/rules/ADTSplit.scala b/src/main/scala/leon/synthesis/rules/ADTSplit.scala index 294c6e1374fa09ed6e54690aecfb4661b5073d04..32848ae1c52fa2a2d1e960cf51c8b26b166fe577 100644 --- a/src/main/scala/leon/synthesis/rules/ADTSplit.scala +++ b/src/main/scala/leon/synthesis/rules/ADTSplit.scala @@ -11,32 +11,43 @@ import purescala.ExprOps._ import purescala.Extractors._ import purescala.Constructors._ import purescala.Definitions._ -import solvers._ case object ADTSplit extends Rule("ADT Split.") { def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { - val solver = SimpleSolverAPI(hctx.sctx.solverFactory.withTimeout(200L)) + // We approximate knowledge of types based on facts found at the top-level + // we don't care if the variables are known to be equal or not, we just + // don't want to split on two variables for which only one split + // alternative is viable. This should be much less expensive than making + // calls to a solver for each pair. + var facts = Map[Identifier, CaseClassType]() + + def addFacts(e: Expr): Unit = e match { + case Equals(Variable(a), CaseClass(cct, _)) => facts += a -> cct + case IsInstanceOf(Variable(a), cct: CaseClassType) => facts += a -> cct + case _ => + } + + val TopLevelAnds(as) = and(p.pc, p.phi) + for (e <- as) { + addFacts(e) + } val candidates = p.as.collect { case IsTyped(id, act @ AbstractClassType(cd, tpes)) => - val optCases = for (dcd <- cd.knownDescendants.sortBy(_.id.name)) yield dcd match { + val optCases = cd.knownDescendants.sortBy(_.id.name).collect { case ccd : CaseClassDef => val cct = CaseClassType(ccd, tpes) - val toSat = and(p.pc, IsInstanceOf(Variable(id), cct)) - val isImplied = solver.solveSAT(toSat) match { - case (Some(false), _) => true - case _ => false - } - - if (!isImplied) { - Some(ccd) + if (facts contains id) { + if (cct == facts(id)) { + Seq(ccd) + } else { + Nil + } } else { - None + Seq(ccd) } - case _ => - None } val cases = optCases.flatten @@ -83,7 +94,7 @@ case object ADTSplit extends Rule("ADT Split.") { val cases = for ((sol, (cct, problem, pattern)) <- sols zip subInfo) yield { if (sol.pre != BooleanLiteral(true)) { - val substs = (for ((field,arg) <- cct.fields zip problem.as ) yield { + val substs = (for ((field,arg) <- cct.classDef.fields zip problem.as ) yield { (arg, caseClassSelector(cct, id.toVariable, field.id)) }).toMap globalPre ::= and(IsInstanceOf(Variable(id), cct), replaceFromIDs(substs, sol.pre)) diff --git a/src/main/scala/leon/synthesis/rules/CEGISLike.scala b/src/main/scala/leon/synthesis/rules/CEGISLike.scala index f3fa5947d0c8cb8092f9d8fc032ca71c12cc2846..5d06aa76a3c1186e806edf605431643e6b2a9359 100644 --- a/src/main/scala/leon/synthesis/rules/CEGISLike.scala +++ b/src/main/scala/leon/synthesis/rules/CEGISLike.scala @@ -20,7 +20,6 @@ import scala.collection.mutable.{HashMap=>MutableMap, ArrayBuffer} import evaluators._ import datagen._ -import leon.utils._ import codegen.CodeGenParams abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { @@ -43,14 +42,13 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { val sctx = hctx.sctx val ctx = sctx.context - // CEGIS Flags to activate or deactivate features val useOptTimeout = sctx.settings.cegisUseOptTimeout.getOrElse(true) val useVanuatoo = sctx.settings.cegisUseVanuatoo.getOrElse(false) - val useShrink = sctx.settings.cegisUseShrink.getOrElse(true) + val useShrink = sctx.settings.cegisUseShrink.getOrElse(false) // Limits the number of programs CEGIS will specifically validate individually - val validateUpTo = 5 + val validateUpTo = 3 // Shrink the program when the ratio of passing cases is less than the threshold val shrinkThreshold = 1.0/2 @@ -314,6 +312,8 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { private val cTreeFd = new FunDef(FreshIdentifier("cTree", alwaysShowUniqueID = true), Seq(), p.as.map(id => ValDef(id)), p.outType) + private val solFd = new FunDef(FreshIdentifier("solFd", alwaysShowUniqueID = true), Seq(), p.as.map(id => ValDef(id)), p.outType) + private val phiFd = new FunDef(FreshIdentifier("phiFd", alwaysShowUniqueID = true), Seq(), p.as.map(id => ValDef(id)), BooleanType) @@ -321,13 +321,20 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { val outerSolution = { new PartialSolution(hctx.search.g, true) - .solutionAround(hctx.currentNode)(FunctionInvocation(cTreeFd.typed, p.as.map(_.toVariable))) + .solutionAround(hctx.currentNode)(FunctionInvocation(solFd.typed, p.as.map(_.toVariable))) .getOrElse(ctx.reporter.fatalError("Unable to get outer solution")) } - val program0 = addFunDefs(sctx.program, Seq(cTreeFd, phiFd) ++ outerSolution.defs, hctx.ci.fd) + val program0 = addFunDefs(sctx.program, Seq(cTreeFd, solFd, phiFd) ++ outerSolution.defs, hctx.ci.fd) cTreeFd.body = None + + solFd.fullBody = Ensuring( + FunctionInvocation(cTreeFd.typed, p.as.map(_.toVariable)), + Lambda(p.xs.map(ValDef(_)), p.phi) + ) + + phiFd.body = Some( letTuple(p.xs, FunctionInvocation(cTreeFd.typed, p.as.map(_.toVariable)), @@ -349,7 +356,7 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { // We freshen/duplicate every functions, except these two as they are // fresh anyway and we refer to them directly. - case `cTreeFd` | `phiFd` => + case `cTreeFd` | `phiFd` | `solFd` => None case fd => @@ -372,7 +379,7 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { private val innerPhi = outerExprToInnerExpr(p.phi) private var programCTree: Program = _ - private var tester: (Example, Set[Identifier]) => EvaluationResults.Result = _ + private var tester: (Example, Set[Identifier]) => EvaluationResults.Result[Expr] = _ private def setCExpr(cTreeInfo: (Expr, Seq[FunDef])): Unit = { val (cTree, newFds) = cTreeInfo @@ -384,7 +391,7 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { //println(programCTree.asString) //println(".. "*30) -// val evaluator = new DualEvaluator(sctx.context, programCTree, CodeGenParams.default) + //val evaluator = new DualEvaluator(sctx.context, programCTree, CodeGenParams.default) val evaluator = new DefaultEvaluator(sctx.context, programCTree) tester = @@ -454,13 +461,13 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { //println("Solving for: "+cnstr.asString) - val solverf = SolverFactory.default(ctx, innerProgram).withTimeout(cexSolverTo) + val solverf = SolverFactory.getFromSettings(ctx, innerProgram).withTimeout(cexSolverTo) val solver = solverf.getNewSolver() try { solver.assertCnstr(cnstr) solver.check match { case Some(true) => - excludeProgram(bs) + excludeProgram(bs, true) val model = solver.getModel //println("Found counter example: ") //for ((s, v) <- model) { @@ -499,8 +506,24 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { var excludedPrograms = ArrayBuffer[Set[Identifier]]() // Explicitly remove program computed by bValues from the search space - def excludeProgram(bValues: Set[Identifier]): Unit = { - val bvs = bValues.filter(isBActive) + // + // If the bValues comes from models, we make sure the bValues we exclude + // are minimal we make sure we exclude only Bs that are used. + def excludeProgram(bValues: Set[Identifier], isMinimal: Boolean): Unit = { + val bs = bValues.filter(isBActive) + + def filterBTree(c: Identifier): Set[Identifier] = { + (for ((b, _, subcs) <- cTree(c) if bValues(b)) yield { + Set(b) ++ subcs.flatMap(filterBTree) + }).toSet.flatten + } + + val bvs = if (isMinimal) { + bs + } else { + rootCs.flatMap(filterBTree).toSet + } + excludedPrograms += bvs } @@ -513,19 +536,21 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { * First phase of CEGIS: solve for potential programs (that work on at least one input) */ def solveForTentativeProgram(): Option[Option[Set[Identifier]]] = { - val solverf = SolverFactory.default(ctx, programCTree).withTimeout(exSolverTo) + val solverf = SolverFactory.getFromSettings(ctx, programCTree).withTimeout(exSolverTo) val solver = solverf.getNewSolver() val cnstr = FunctionInvocation(phiFd.typed, phiFd.params.map(_.id.toVariable)) - //debugCExpr(cTree) - //println(" --- PhiFD ---") - //println(phiFd.fullBody.asString(ctx)) + //println("Program: ") + //println("-"*80) + //println(programCTree.asString) val toFind = and(innerPc, cnstr) //println(" --- Constraints ---") - //println(" - "+toFind) + //println(" - "+toFind.asString) try { - solver.assertCnstr(andJoin(bsOrdered.map(b => if (bs(b)) b.toVariable else Not(b.toVariable)))) + //TODO: WHAT THE F IS THIS? + //val bsOrNotBs = andJoin(bsOrdered.map(b => if (bs(b)) b.toVariable else Not(b.toVariable))) + //solver.assertCnstr(bsOrNotBs) solver.assertCnstr(toFind) for ((c, alts) <- cTree) { @@ -536,26 +561,25 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { } if (activeBs.nonEmpty) { - //println(" - "+andJoin(either)) + //println(" - "+andJoin(either).asString) solver.assertCnstr(andJoin(either)) val oneOf = orJoin(activeBs.map(_.toVariable)) - //println(" - "+oneOf) + //println(" - "+oneOf.asString) solver.assertCnstr(oneOf) } } - //println(" -- Excluded:") //println(" -- Active:") val isActive = andJoin(bsOrdered.filterNot(isBActive).map(id => Not(id.toVariable))) - //println(" - "+isActive) + //println(" - "+isActive.asString) solver.assertCnstr(isActive) //println(" -- Excluded:") for (ex <- excludedPrograms) { val notThisProgram = Not(andJoin(ex.map(_.toVariable).toSeq)) - //println(f" - $notThisProgram%-40s ("+getExpr(ex)+")") + //println(f" - ${notThisProgram.asString}%-40s ("+getExpr(ex)+")") solver.assertCnstr(notThisProgram) } @@ -565,7 +589,10 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { val bModel = bs.filter(b => model.get(b).contains(BooleanLiteral(true))) + //println("Tentative model: "+model.asString) + //println("Tentative model: "+bModel.filter(isBActive).map(_.asString).toSeq.sorted) //println("Tentative expr: "+getExpr(bModel)) + Some(Some(bModel)) case Some(false) => @@ -591,7 +618,7 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { * Second phase of CEGIS: verify a given program by looking for CEX inputs */ def solveForCounterExample(bs: Set[Identifier]): Option[Option[Seq[Expr]]] = { - val solverf = SolverFactory.default(ctx, programCTree).withTimeout(cexSolverTo) + val solverf = SolverFactory.getFromSettings(ctx, programCTree).withTimeout(cexSolverTo) val solver = solverf.getNewSolver() val cnstr = FunctionInvocation(phiFd.typed, phiFd.params.map(_.id.toVariable)) @@ -601,6 +628,14 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { solver.assertCnstr(innerPc) solver.assertCnstr(Not(cnstr)) + //println("*"*80) + //println(Not(cnstr)) + //println(innerPc) + //println("*"*80) + //println(programCTree.asString) + //println("*"*80) + //Console.in.read() + solver.check match { case Some(true) => val model = solver.getModel @@ -685,17 +720,6 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { */ val nTests = if (p.pc == BooleanLiteral(true)) 50 else 20 - /* - val inputGenerator: FreeableIterator[Example] = { - val sff = { - (ctx: LeonContext, pgm: Program) => - SolverFactory.default(ctx, pgm).withTimeout(exSolverTo) - } - new SolverDataGen(sctx.context, sctx.program, sff).generateFor(p.as, p.pc, nTests, nTests).map { - InExample(_) - } - } */ - val inputGenerator: Iterator[Example] = if (useVanuatoo) { new VanuatooDataGen(sctx.context, sctx.program).generateFor(p.as, p.pc, nTests, 3000).map(InExample) } else { @@ -813,15 +837,15 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { } if (doFilter && !(nPassing < nInitial * shrinkThreshold && useShrink)) { + sctx.reporter.debug("Excluding "+wrongPrograms.size+" programs") wrongPrograms.foreach { - ndProgram.excludeProgram + ndProgram.excludeProgram(_, true) } } } // CEGIS Loop at a given unfolding level while (result.isEmpty && !skipCESearch && !interruptManager.isInterrupted) { - ndProgram.solveForTentativeProgram() match { case Some(Some(bs)) => // Should we validate this program with Z3? @@ -833,18 +857,21 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { // make sure by validating this candidate with z3 true } else { + println("testing failed ?!") // One valid input failed with this candidate, we can skip - ndProgram.excludeProgram(bs) + ndProgram.excludeProgram(bs, false) false } } else { // No inputs or capability to test, we need to ask Z3 true } + sctx.reporter.debug("Found tentative model (Validate="+validateWithZ3+")!") if (validateWithZ3) { ndProgram.solveForCounterExample(bs) match { case Some(Some(inputsCE)) => + sctx.reporter.debug("Found counter-example:"+inputsCE) val ce = InExample(inputsCE) // Found counter example! baseExampleInputs += ce @@ -853,7 +880,7 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { if (prunedPrograms.forall(p => !ndProgram.testForProgram(p)(ce))) { skipCESearch = true } else { - ndProgram.excludeProgram(bs) + ndProgram.excludeProgram(bs, false) } case Some(None) => @@ -863,6 +890,7 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { case None => // We are not sure + sctx.reporter.debug("Unknown") if (useOptTimeout) { // Interpret timeout in CE search as "the candidate is valid" sctx.reporter.info("CEGIS could not prove the validity of the resulting expression") diff --git a/src/main/scala/leon/synthesis/rules/EqualitySplit.scala b/src/main/scala/leon/synthesis/rules/EqualitySplit.scala index bc62dbba88c8b7dea8b6aaefc5879dfc3909774f..132ea9f766720d776af801e36ac58cc1f6b87b73 100644 --- a/src/main/scala/leon/synthesis/rules/EqualitySplit.scala +++ b/src/main/scala/leon/synthesis/rules/EqualitySplit.scala @@ -6,6 +6,7 @@ package rules import leon.purescala.Common.Identifier import purescala.Expressions._ +import purescala.Extractors._ import purescala.Constructors._ import solvers._ @@ -14,31 +15,31 @@ import scala.concurrent.duration._ case object EqualitySplit extends Rule("Eq. Split") { def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { - val solver = SimpleSolverAPI(hctx.sctx.solverFactory.withTimeout(50.millis)) + // We approximate knowledge of equality based on facts found at the top-level + // we don't care if the variables are known to be equal or not, we just + // don't want to split on two variables for which only one split + // alternative is viable. This should be much less expensive than making + // calls to a solver for each pair. + var facts = Set[Set[Identifier]]() - val candidates = p.as.groupBy(_.getType).mapValues(_.combinations(2).filter { - case List(a1, a2) => - val toValEQ = implies(p.pc, Equals(Variable(a1), Variable(a2))) - - val impliesEQ = solver.solveSAT(Not(toValEQ)) match { - case (Some(false), _) => true - case _ => false - } - - if (!impliesEQ) { - val toValNE = implies(p.pc, not(Equals(Variable(a1), Variable(a2)))) + def addFacts(e: Expr): Unit = e match { + case Not(e) => addFacts(e) + case LessThan(Variable(a), Variable(b)) => facts += Set(a,b) + case LessEquals(Variable(a), Variable(b)) => facts += Set(a,b) + case GreaterThan(Variable(a), Variable(b)) => facts += Set(a,b) + case GreaterEquals(Variable(a), Variable(b)) => facts += Set(a,b) + case Equals(Variable(a), Variable(b)) => facts += Set(a,b) + case _ => + } - val impliesNE = solver.solveSAT(Not(toValNE)) match { - case (Some(false), _) => true - case _ => false - } + val TopLevelAnds(as) = and(p.pc, p.phi) + for (e <- as) { + addFacts(e) + } - !impliesNE - } else { - false - } - case _ => false - }).values.flatten + val candidates = p.as.groupBy(_.getType).mapValues{ as => + as.combinations(2).filterNot(facts contains _.toSet) + }.values.flatten candidates.flatMap { case List(a1, a2) => diff --git a/src/main/scala/leon/synthesis/rules/EquivalentInputs.scala b/src/main/scala/leon/synthesis/rules/EquivalentInputs.scala index 227dd691c29734e0dd85feb179f40fac15ceea7a..6b2f9e8585a98ab9677aa5f1c439b2a1afa136ef 100644 --- a/src/main/scala/leon/synthesis/rules/EquivalentInputs.scala +++ b/src/main/scala/leon/synthesis/rules/EquivalentInputs.scala @@ -24,7 +24,7 @@ case object EquivalentInputs extends NormalizingRule("EquivalentInputs") { val ccSubsts = for (IsInstanceOf(s, cct: CaseClassType) <- instanceOfs) yield { - val fieldsVals = (for (f <- cct.fields) yield { + val fieldsVals = (for (f <- cct.classDef.fields) yield { val id = f.id clauses.find { diff --git a/src/main/scala/leon/synthesis/rules/InequalitySplit.scala b/src/main/scala/leon/synthesis/rules/InequalitySplit.scala index 292e99de80d4cf7eb3351621edde3587645b0402..8b728930a10ad7d73c121762bde43637ce988fbc 100644 --- a/src/main/scala/leon/synthesis/rules/InequalitySplit.scala +++ b/src/main/scala/leon/synthesis/rules/InequalitySplit.scala @@ -7,64 +7,42 @@ package rules import purescala.Expressions._ import purescala.Types._ import purescala.Constructors._ - -import solvers._ +import purescala.Extractors._ +import purescala.Common._ import scala.concurrent.duration._ case object InequalitySplit extends Rule("Ineq. Split.") { def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { - val solver = SimpleSolverAPI(hctx.sctx.solverFactory.withTimeout(50.millis)) + // We approximate knowledge of equality based on facts found at the top-level + // we don't care if the variables are known to be equal or not, we just + // don't want to split on two variables for which only one split + // alternative is viable. This should be much less expensive than making + // calls to a solver for each pair. + var facts = Set[Set[Identifier]]() + + def addFacts(e: Expr): Unit = e match { + case Not(e) => addFacts(e) + case LessThan(Variable(a), Variable(b)) => facts += Set(a,b) + case LessEquals(Variable(a), Variable(b)) => facts += Set(a,b) + case GreaterThan(Variable(a), Variable(b)) => facts += Set(a,b) + case GreaterEquals(Variable(a), Variable(b)) => facts += Set(a,b) + case Equals(Variable(a), Variable(b)) => facts += Set(a,b) + case _ => + } + + val TopLevelAnds(as) = and(p.pc, p.phi) + for (e <- as) { + addFacts(e) + } val argsPairs = p.as.filter(_.getType == IntegerType).combinations(2) ++ p.as.filter(_.getType == Int32Type).combinations(2) - val candidates = argsPairs.toList.filter { - case List(a1, a2) => - val toValLT = implies(p.pc, LessThan(Variable(a1), Variable(a2))) - - val impliesLT = solver.solveSAT(not(toValLT)) match { - case (Some(false), _) => true - case _ => false - } - - if (!impliesLT) { - val toValGT = implies(p.pc, GreaterThan(Variable(a1), Variable(a2))) - - val impliesGT = solver.solveSAT(not(toValGT)) match { - case (Some(false), _) => true - case _ => false - } + val candidates = argsPairs.toList.filter { case List(a1, a2) => !(facts contains Set(a1, a2)) } - if (!impliesGT) { - val toValEQ = implies(p.pc, Equals(Variable(a1), Variable(a2))) - - val impliesEQ = solver.solveSAT(not(toValEQ)) match { - case (Some(false), _) => true - case _ => false - } - - !impliesEQ - } else { - false - } - } else { - false - } - case _ => false - } - - - candidates.flatMap { + candidates.collect { case List(a1, a2) => - - val subLT = p.copy(pc = and(LessThan(Variable(a1), Variable(a2)), p.pc), - eb = p.qeb.filterIns(LessThan(Variable(a1), Variable(a2)))) - val subEQ = p.copy(pc = and(Equals(Variable(a1), Variable(a2)), p.pc), - eb = p.qeb.filterIns(Equals(Variable(a1), Variable(a2)))) - val subGT = p.copy(pc = and(GreaterThan(Variable(a1), Variable(a2)), p.pc), - eb = p.qeb.filterIns(GreaterThan(Variable(a1), Variable(a2)))) - val onSuccess: List[Solution] => Option[Solution] = { case sols@List(sLT, sEQ, sGT) => val pre = orJoin(sols.map(_.pre)) @@ -85,9 +63,24 @@ case object InequalitySplit extends Rule("Ineq. Split.") { None } - Some(decomp(List(subLT, subEQ, subGT), onSuccess, s"Ineq. Split on '$a1' and '$a2'")) - case _ => - None + val subTypes = List(p.outType, p.outType, p.outType) + + new RuleInstantiation(s"Ineq. Split on '$a1' and '$a2'", + SolutionBuilderDecomp(subTypes, onSuccess)) { + + def apply(hctx: SearchContext) = { + implicit val _ = hctx + + val subLT = p.copy(pc = and(LessThan(Variable(a1), Variable(a2)), p.pc), + eb = p.qeb.filterIns(LessThan(Variable(a1), Variable(a2)))) + val subEQ = p.copy(pc = and(Equals(Variable(a1), Variable(a2)), p.pc), + eb = p.qeb.filterIns(Equals(Variable(a1), Variable(a2)))) + val subGT = p.copy(pc = and(GreaterThan(Variable(a1), Variable(a2)), p.pc), + eb = p.qeb.filterIns(GreaterThan(Variable(a1), Variable(a2)))) + + RuleExpanded(List(subLT, subEQ, subGT)) + } + } } } } diff --git a/src/main/scala/leon/synthesis/utils/SynthesisProblemExtractionPhase.scala b/src/main/scala/leon/synthesis/utils/SynthesisProblemExtractionPhase.scala index dfd1360d5a3ea8536edd6d1ff3df89d07d12fe24..c3df2a38297ea7766cdb85cb011124a41d0fbb20 100644 --- a/src/main/scala/leon/synthesis/utils/SynthesisProblemExtractionPhase.scala +++ b/src/main/scala/leon/synthesis/utils/SynthesisProblemExtractionPhase.scala @@ -7,14 +7,14 @@ package utils import purescala.DefOps.funDefsFromMain import purescala.Definitions._ -object SynthesisProblemExtractionPhase extends SimpleLeonPhase[Program, (Program, Map[FunDef, Seq[ChooseInfo]])] { +object SynthesisProblemExtractionPhase extends SimpleLeonPhase[Program, (Program, Map[FunDef, Seq[SourceInfo]])] { val name = "Synthesis Problem Extraction" val description = "Synthesis Problem Extraction" - def apply(ctx: LeonContext, p: Program): (Program, Map[FunDef, Seq[ChooseInfo]]) = { + def apply(ctx: LeonContext, p: Program): (Program, Map[FunDef, Seq[SourceInfo]]) = { // Look for choose() val results = for (f <- funDefsFromMain(p).toSeq.sortBy(_.id) if f.body.isDefined) yield { - f -> ChooseInfo.extractFromFunction(ctx, p, f) + f -> SourceInfo.extractFromFunction(ctx, p, f) } (p, results.toMap) diff --git a/src/main/scala/leon/termination/ChainProcessor.scala b/src/main/scala/leon/termination/ChainProcessor.scala index c5a8ccb69c870b16f91e1c1dcfa5cb933bc8138b..799e51db6c1d08895ffb23793fc7328e881a7242 100644 --- a/src/main/scala/leon/termination/ChainProcessor.scala +++ b/src/main/scala/leon/termination/ChainProcessor.scala @@ -6,7 +6,7 @@ package termination import purescala.Expressions._ import purescala.Common._ import purescala.Definitions._ -import purescala.Constructors.tupleWrap +import purescala.Constructors._ class ChainProcessor( val checker: TerminationChecker, @@ -33,40 +33,40 @@ class ChainProcessor( reporter.debug("-+> Multiple looping points, can't build chain proof") None } else { + val funDef = loopPoints.headOption getOrElse { + chainsMap.collectFirst { case (fd, (fds, chains)) if chains.nonEmpty => fd }.get + } - def exprs(fd: FunDef): (Expr, Seq[(Seq[Expr], Expr)], Set[Chain]) = { - val fdChains = chainsMap(fd)._2 - - val e1 = tupleWrap(fd.params.map(_.toVariable)) - val e2s = fdChains.toSeq.map { chain => - val freshParams = chain.finalParams.map(arg => FreshIdentifier(arg.id.name, arg.id.getType, true)) - (chain.loop(finalArgs = freshParams), tupleWrap(freshParams.map(_.toVariable))) - } + val chains = chainsMap(funDef)._2 - (e1, e2s, fdChains) + val e1 = tupleWrap(funDef.params.map(_.toVariable)) + val e2s = chains.toSeq.map { chain => + val freshParams = chain.finalParams.map(arg => FreshIdentifier(arg.id.name, arg.id.getType, true)) + (chain.loop(finalArgs = freshParams), tupleWrap(freshParams.map(_.toVariable))) } - val funDefs = if (loopPoints.size == 1) Set(loopPoints.head) else problem.funSet - reporter.debug("-+> Searching for structural size decrease") - val (se1, se2s, _) = exprs(funDefs.head) - val structuralFormulas = modules.structuralDecreasing(se1, se2s) + val structuralFormulas = modules.structuralDecreasing(e1, e2s) val structuralDecreasing = structuralFormulas.exists(formula => definitiveALL(formula)) reporter.debug("-+> Searching for numerical converging") - // worth checking multiple funDefs as the endpoint discovery can be context sensitive - val numericDecreasing = funDefs.exists { fd => - val (ne1, ne2s, fdChains) = exprs(fd) - val numericFormulas = modules.numericConverging(ne1, ne2s, fdChains) - numericFormulas.exists(formula => definitiveALL(formula)) - } + val numericFormulas = modules.numericConverging(e1, e2s, chains) + val numericDecreasing = numericFormulas.exists(formula => definitiveALL(formula)) if (structuralDecreasing || numericDecreasing) Some(problem.funDefs map Cleared) - else - None + else { + val chainsUnlooping = chains.flatMap(c1 => chains.flatMap(c2 => c1 compose c2)).forall { + chain => !definitiveSATwithModel(andJoin(chain.loop())).isDefined + } + + if (chainsUnlooping) + Some(problem.funDefs map Cleared) + else + None + } } } } diff --git a/src/main/scala/leon/termination/ProcessingPipeline.scala b/src/main/scala/leon/termination/ProcessingPipeline.scala index d2bae17b326623178671a2fc7db26954e013057e..92a11c6452c88f49986cd9382cae9ed72f70619a 100644 --- a/src/main/scala/leon/termination/ProcessingPipeline.scala +++ b/src/main/scala/leon/termination/ProcessingPipeline.scala @@ -165,8 +165,8 @@ abstract class ProcessingPipeline(context: LeonContext, initProgram: Program) ex val components = SCC.scc(callGraph) for (fd <- funDefs -- components.toSet.flatten) clearedMap(fd) = "Non-recursive" - - components.map(fds => Problem(fds.toSeq)) + val newProblems = components.filter(fds => fds.forall { fd => !terminationMap.isDefinedAt(fd) }) + newProblems.map(fds => Problem(fds.toSeq)) } def verifyTermination(funDef: FunDef): Unit = { diff --git a/src/main/scala/leon/termination/Processor.scala b/src/main/scala/leon/termination/Processor.scala index 99124c5e64ed8bb61e3c44a015775280d06c584e..3f7be09f145d2768ff00f8573f078d6e90bc1b2d 100644 --- a/src/main/scala/leon/termination/Processor.scala +++ b/src/main/scala/leon/termination/Processor.scala @@ -34,7 +34,7 @@ trait Solvable extends Processor { val sizeUnit : UnitDef = UnitDef(FreshIdentifier("$size"),Seq(sizeModule)) val newProgram : Program = program.copy( units = sizeUnit :: program.units) - SolverFactory.getFromSettings(context, newProgram).withTimeout(500.millisecond) + SolverFactory.getFromSettings(context, newProgram).withTimeout(10.seconds) } type Solution = (Option[Boolean], Map[Identifier, Expr]) diff --git a/src/main/scala/leon/transformations/InstrumentationUtil.scala b/src/main/scala/leon/transformations/InstrumentationUtil.scala index 2c461e466298614ebae30df8a33ff901fea72fa5..1c3b744d8f667b3a1bb7b653ff299868fe98b38a 100644 --- a/src/main/scala/leon/transformations/InstrumentationUtil.scala +++ b/src/main/scala/leon/transformations/InstrumentationUtil.scala @@ -63,7 +63,7 @@ object InstUtil { val vary = yid.toVariable val args = Seq(xid, yid) val maxType = FunctionType(Seq(IntegerType, IntegerType), IntegerType) - val mfd = new FunDef(FreshIdentifier("max", maxType, false), Seq(), args.map((arg) => ValDef(arg, Some(arg.getType))), IntegerType) + val mfd = new FunDef(FreshIdentifier("max", maxType, false), Seq(), args.map(arg => ValDef(arg)), IntegerType) val cond = GreaterEquals(varx, vary) mfd.body = Some(IfExpr(cond, varx, vary)) diff --git a/src/main/scala/leon/transformations/IntToRealProgram.scala b/src/main/scala/leon/transformations/IntToRealProgram.scala index 21229282b398fcc10cd9548ddaead2f25d8087a3..5ab16991398e5dcb4352e870795fff51045fb9c9 100644 --- a/src/main/scala/leon/transformations/IntToRealProgram.scala +++ b/src/main/scala/leon/transformations/IntToRealProgram.scala @@ -72,8 +72,7 @@ abstract class ProgramTypeTransformer { } def mapDecl(decl: ValDef): ValDef = { - val newtpe = mapType(decl.getType) - new ValDef(mapId(decl.id), Some(newtpe)) + decl.copy(id = mapId(decl.id)) } def mapType(tpe: TypeTree): TypeTree = { @@ -141,9 +140,9 @@ abstract class ProgramTypeTransformer { // FIXME //add a new postcondition newfd.fullBody = if (fd.postcondition.isDefined && newfd.body.isDefined) { - val Lambda(Seq(ValDef(resid, _)), pexpr) = fd.postcondition.get + val Lambda(Seq(ValDef(resid, lzy)), pexpr) = fd.postcondition.get val tempRes = mapId(resid).toVariable - Ensuring(newfd.body.get, Lambda(Seq(ValDef(tempRes.id, Some(tempRes.getType))), transformExpr(pexpr))) + Ensuring(newfd.body.get, Lambda(Seq(ValDef(tempRes.id, lzy)), transformExpr(pexpr))) // Some(mapId(resid), transformExpr(pexpr)) } else NoTree(fd.returnType) @@ -233,4 +232,4 @@ class RealToIntProgram extends ProgramTypeTransformer { } def mappedFun(fd: FunDef): FunDef = newFundefs(fd) -} \ No newline at end of file +} diff --git a/src/main/scala/leon/transformations/NonlinearityEliminationPhase.scala b/src/main/scala/leon/transformations/NonlinearityEliminationPhase.scala index d17968dc05b789b936554b9c10dbde22ac360dfb..a696dae0580baf44c6ef955c72ccdcce9791bdab 100644 --- a/src/main/scala/leon/transformations/NonlinearityEliminationPhase.scala +++ b/src/main/scala/leon/transformations/NonlinearityEliminationPhase.scala @@ -26,7 +26,7 @@ object MultFuncs { val vary = yid.toVariable val args = Seq(xid, yid) val funcType = FunctionType(Seq(domain, domain), domain) - val mfd = new FunDef(FreshIdentifier("pmult", funcType, false), Seq(), args.map((arg) => ValDef(arg, Some(arg.getType))), domain) + val mfd = new FunDef(FreshIdentifier("pmult", funcType, false), Seq(), args.map(arg => ValDef(arg)), domain) val tmfd = TypedFunDef(mfd, Seq()) //define a body (a) using mult(x,y) = if(x == 0 || y ==0) 0 else mult(x-1,y) + y @@ -47,7 +47,7 @@ object MultFuncs { val post1 = Implies(guard, defn2) // mfd.postcondition = Some((resvar.id, And(Seq(post0, post1)))) - mfd.fullBody = Ensuring(mfd.body.get, Lambda(Seq(ValDef(resvar.id, Some(resvar.getType))), And(Seq(post0, post1)))) + mfd.fullBody = Ensuring(mfd.body.get, Lambda(Seq(ValDef(resvar.id)), And(Seq(post0, post1)))) //set function properties (for now, only monotonicity) mfd.addFlags(Set(Annotation("theoryop", Seq()), Annotation("monotonic", Seq()))) //"distributive" ? mfd @@ -59,7 +59,7 @@ object MultFuncs { val yid = FreshIdentifier("y", domain) val args = Seq(xid, yid) val funcType = FunctionType(Seq(domain, domain), domain) - val fd = new FunDef(FreshIdentifier("mult", funcType, false), Seq(), args.map((arg) => ValDef(arg, Some(arg.getType))), domain) + val fd = new FunDef(FreshIdentifier("mult", funcType, false), Seq(), args.map(arg => ValDef(arg)), domain) val tpivMultFun = TypedFunDef(pivMultFun, Seq()) //the body is defined as mult(x,y) = val px = if(x < 0) -x else x; diff --git a/src/main/scala/leon/transformations/SerialInstrumentationPhase.scala b/src/main/scala/leon/transformations/SerialInstrumentationPhase.scala index 87c5795838b460314d19805cbb8ca6caae4ee233..c4830cf40ca429358b0d2e420bbbe7f7d1f5c276 100644 --- a/src/main/scala/leon/transformations/SerialInstrumentationPhase.scala +++ b/src/main/scala/leon/transformations/SerialInstrumentationPhase.scala @@ -111,7 +111,7 @@ class SerialInstrumenter(ctx: LeonContext, program: Program) { def mapPost(pred: Expr, from: FunDef, to: FunDef) = { pred match { - case Lambda(Seq(ValDef(fromRes, _)), postCond) if (instFuncs.contains(from)) => + case Lambda(Seq(ValDef(fromRes, lzy)), postCond) if (instFuncs.contains(from)) => val toResId = FreshIdentifier(fromRes.name, to.returnType, true) val newpost = postMap((e: Expr) => e match { case Variable(`fromRes`) => @@ -124,7 +124,7 @@ class SerialInstrumenter(ctx: LeonContext, program: Program) { case _ => None })(postCond) - Lambda(Seq(ValDef(toResId)), mapExpr(newpost)) + Lambda(Seq(ValDef(toResId, lzy)), mapExpr(newpost)) case _ => mapExpr(pred) } @@ -489,4 +489,4 @@ abstract class Instrumenter(program: Program, si: SerialInstrumenter) { */ def instrumentMatchCase(me: MatchExpr, mc: MatchCase, caseExprCost: Expr, scrutineeCost: Expr): Expr -} \ No newline at end of file +} diff --git a/src/main/scala/leon/utils/PreprocessingPhase.scala b/src/main/scala/leon/utils/PreprocessingPhase.scala index 87d1003150995eb236e5f3abb738450516c5b7f7..7bfeece2f6f35e88c4557db2ecafe4c684a1198d 100644 --- a/src/main/scala/leon/utils/PreprocessingPhase.scala +++ b/src/main/scala/leon/utils/PreprocessingPhase.scala @@ -5,7 +5,6 @@ package utils import leon.purescala._ import leon.purescala.Definitions.Program -import leon.purescala.Quantification.CheckForalls import leon.solvers.isabelle.AdaptationPhase import leon.verification.InjectAsserts import leon.xlang.{NoXLangFeaturesChecking, XLangDesugaringPhase} @@ -39,8 +38,7 @@ class PreprocessingPhase(desugarXLang: Boolean = false) extends LeonPhase[Progra synthesis.ConversionPhase andThen CheckADTFieldsTypes andThen InjectAsserts andThen - InliningPhase andThen - CheckForalls + InliningPhase val pipeX = if(desugarXLang) { XLangDesugaringPhase andThen diff --git a/src/main/scala/leon/utils/SeqUtils.scala b/src/main/scala/leon/utils/SeqUtils.scala index 5a5e2dff3088da991ac99a6b8f1e46f759837629..002f2ebedc8a6dfb265fbf101c2185b3bfa17ce1 100644 --- a/src/main/scala/leon/utils/SeqUtils.scala +++ b/src/main/scala/leon/utils/SeqUtils.scala @@ -2,6 +2,7 @@ package leon.utils +import scala.collection.SeqView import scala.collection.mutable.ArrayBuffer object SeqUtils { @@ -42,3 +43,47 @@ object SeqUtils { } } } + +class CartesianView[+A](views: Seq[SeqView[A, Seq[A]]]) extends SeqView[Seq[A], Seq[Seq[A]]] { + override protected def underlying: Seq[Seq[A]] = SeqUtils.cartesianProduct(views) + + override def length: Int = views.map{ _.size }.product + + override def apply(idx: Int): Seq[A] = { + if (idx < 0 || idx >= length) throw new IndexOutOfBoundsException + var c = idx + for (v <- views) yield { + val ic = c % v.size + c /= v.size + v(ic) + } + } + + override def iterator: Iterator[Seq[A]] = new Iterator[Seq[A]] { + // It's unfortunate, but we have to use streams to memoize + private val streams = views.map { _.toStream } + private val current = streams.toArray + + // We take a note if there exists an empty view to begin with + // (which means the whole iterator is empty) + private val empty = streams exists { _.isEmpty } + + override def hasNext: Boolean = !empty && current.exists { _.nonEmpty } + + override def next(): Seq[A] = { + if (!hasNext) throw new NoSuchElementException("next on empty iterator") + // Propagate curry + for (i <- (0 to streams.size).takeWhile(current(_).isEmpty)) { + current(i) = streams(i) + } + + val ret = current map { _.head } + + for (i <- (0 to streams.size)) { + current(i) = current(i).tail + } + + ret + } + } +} \ No newline at end of file diff --git a/src/main/scala/leon/utils/StreamUtils.scala b/src/main/scala/leon/utils/StreamUtils.scala index c972f2e27dd1a28325f8477fa2042ea0620bc78f..2ea08a593725b6f28f4e96be85a00088eb1f9f76 100644 --- a/src/main/scala/leon/utils/StreamUtils.scala +++ b/src/main/scala/leon/utils/StreamUtils.scala @@ -3,17 +3,23 @@ package leon.utils object StreamUtils { - def interleave[T](streams : Seq[Stream[T]]) : Stream[T] = { - var ss = streams - while(ss.nonEmpty && ss.head.isEmpty) { - ss = ss.tail + + def interleave[T](streams: Stream[Stream[T]]): Stream[T] = { + def rec(streams: Stream[Stream[T]], diag: Int): Stream[T] = { + if(streams.isEmpty) Stream() else { + val (take, leave) = streams.splitAt(diag) + val (nonEmpty, empty) = take partition (_.nonEmpty) + nonEmpty.map(_.head) ++ rec(nonEmpty.map(_.tail) ++ leave, diag + 1 - empty.size) + } } - if(ss.isEmpty) return Stream.empty - if(ss.size == 1) return ss.head + rec(streams, 1) + } - // TODO: This circular-shifts the list. I'd be interested in a constant time - // operation. Perhaps simply by choosing the right data-structure? - Stream.cons(ss.head.head, interleave(ss.tail :+ ss.head.tail)) + def interleave[T](streams : Seq[Stream[T]]) : Stream[T] = { + if (streams.isEmpty) Stream() else { + val nonEmpty = streams filter (_.nonEmpty) + nonEmpty.toStream.map(_.head) ++ interleave(nonEmpty.map(_.tail)) + } } def cartesianProduct[T](streams : Seq[Stream[T]]) : Stream[List[T]] = { diff --git a/src/main/scala/leon/verification/InductionTactic.scala b/src/main/scala/leon/verification/InductionTactic.scala index 65f96a090d4449845921252737bb35a3c3b9a327..dd437c224170c908e9175815354748a4236e4880 100644 --- a/src/main/scala/leon/verification/InductionTactic.scala +++ b/src/main/scala/leon/verification/InductionTactic.scala @@ -21,7 +21,7 @@ class InductionTactic(vctx: VerificationContext) extends DefaultTactic(vctx) { } private def selectorsOfParentType(parentType: ClassType, cct: CaseClassType, expr: Expr): Seq[Expr] = { - val childrenOfSameType = cct.fields.filter(_.getType == parentType) + val childrenOfSameType = (cct.classDef.fields zip cct.fieldsTypes).collect { case (vd, tpe) if tpe == parentType => vd } for (field <- childrenOfSameType) yield { caseClassSelector(cct, expr, field.id) } diff --git a/src/main/scala/leon/verification/VerificationPhase.scala b/src/main/scala/leon/verification/VerificationPhase.scala index 21119c4e2728c99850e3af0d1c12e498866f771e..ea09de821427e39cffa3ea9c304e7d1e5a910a4c 100644 --- a/src/main/scala/leon/verification/VerificationPhase.scala +++ b/src/main/scala/leon/verification/VerificationPhase.scala @@ -57,7 +57,6 @@ object VerificationPhase extends SimpleLeonPhase[Program,VerificationReport] { } } - try { val vcs = generateVCs(vctx, toVerify) diff --git a/src/main/scala/leon/xlang/ImperativeCodeElimination.scala b/src/main/scala/leon/xlang/ImperativeCodeElimination.scala index fe9582d548deb679e0de7a082d5e4ba178582a1f..e9802d500208b96b5af448391d69c6660a05f843 100644 --- a/src/main/scala/leon/xlang/ImperativeCodeElimination.scala +++ b/src/main/scala/leon/xlang/ImperativeCodeElimination.scala @@ -10,6 +10,7 @@ import leon.purescala.Extractors._ import leon.purescala.Constructors._ import leon.purescala.ExprOps._ import leon.purescala.TypeOps._ +import leon.purescala.Types._ import leon.xlang.Expressions._ object ImperativeCodeElimination extends UnitPhase[Program] { @@ -22,13 +23,22 @@ object ImperativeCodeElimination extends UnitPhase[Program] { fd <- pgm.definedFunctions body <- fd.body } { - val (res, scope, _) = toFunction(body)(State(fd, Set())) + val (res, scope, _) = toFunction(body)(State(fd, Set(), Map())) fd.body = Some(scope(res)) } } - case class State(parent: FunDef, varsInScope: Set[Identifier]) { + /* varsInScope refers to variable declared in the same level scope. + Typically, when entering a nested function body, the scope should be + reset to empty */ + private case class State( + parent: FunDef, + varsInScope: Set[Identifier], + funDefsMapping: Map[FunDef, (FunDef, List[Identifier])] + ) { def withVar(i: Identifier) = copy(varsInScope = varsInScope + i) + def withFunDef(fd: FunDef, nfd: FunDef, ids: List[Identifier]) = + copy(funDefsMapping = funDefsMapping + (fd -> (nfd, ids))) } //return a "scope" consisting of purely functional code that defines potentially needed @@ -119,6 +129,7 @@ object ImperativeCodeElimination extends UnitPhase[Program] { (resId.toVariable, scope, scrutFun ++ modifiedVars.zip(freshIds).toMap) case wh@While(cond, body) => + //TODO: rewrite by re-using the nested function transformation code val (condRes, condScope, condFun) = toFunction(cond) val (_, bodyScope, bodyFun) = toFunction(body) val condBodyFun = condFun ++ bodyFun @@ -218,14 +229,115 @@ object ImperativeCodeElimination extends UnitPhase[Program] { bindFun ++ bodyFun ) + //a function invocation can update variables in scope. + case fi@FunctionInvocation(tfd, args) => + + + val (recArgs, argScope, argFun) = args.foldRight((Seq[Expr](), (body: Expr) => body, Map[Identifier, Identifier]()))((arg, acc) => { + val (accArgs, accScope, accFun) = acc + val (argVal, argScope, argFun) = toFunction(arg) + val newScope = (body: Expr) => argScope(replaceNames(argFun, accScope(body))) + (argVal +: accArgs, newScope, argFun ++ accFun) + }) + + val fd = tfd.fd + state.funDefsMapping.get(fd) match { + case Some((newFd, modifiedVars)) => { + val newInvoc = FunctionInvocation(newFd.typed, recArgs ++ modifiedVars.map(id => id.toVariable)).setPos(fi) + val freshNames = modifiedVars.map(id => id.freshen) + val tmpTuple = FreshIdentifier("t", newFd.returnType) + + val scope = (body: Expr) => { + argScope(Let(tmpTuple, newInvoc, + freshNames.zipWithIndex.foldRight(body)((p, b) => + Let(p._1, TupleSelect(tmpTuple.toVariable, p._2 + 2), b)) + )) + } + val newMap = argFun ++ modifiedVars.zip(freshNames).toMap + + (TupleSelect(tmpTuple.toVariable, 1), scope, newMap) + } + case None => + (FunctionInvocation(tfd, recArgs).copiedFrom(fi), argScope, argFun) + } + + case LetDef(fd, b) => - //Recall that here the nested function should not access mutable variables from an outside scope - fd.body.foreach { bd => - val (fdRes, fdScope, _) = toFunction(bd) - fd.body = Some(fdScope(fdRes)) + + def fdWithoutSideEffects = { + fd.body.foreach { bd => + val (fdRes, fdScope, _) = toFunction(bd) + fd.body = Some(fdScope(fdRes)) + } + val (bodyRes, bodyScope, bodyFun) = toFunction(b) + (bodyRes, (b2: Expr) => LetDef(fd, bodyScope(b2)).setPos(fd).copiedFrom(expr), bodyFun) + } + + fd.body match { + case Some(bd) => { + + val modifiedVars: List[Identifier] = + collect[Identifier]({ + case Assignment(v, _) => Set(v) + case _ => Set() + })(bd).intersect(state.varsInScope).toList + + if(modifiedVars.isEmpty) fdWithoutSideEffects else { + + val freshNames: List[Identifier] = modifiedVars.map(id => id.freshen) + + val newParams: Seq[ValDef] = fd.params ++ freshNames.map(n => ValDef(n)) + val freshVarDecls: List[Identifier] = freshNames.map(id => id.freshen) + + val rewritingMap: Map[Identifier, Identifier] = + modifiedVars.zip(freshVarDecls).toMap + val freshBody = + preMap({ + case Assignment(v, e) => rewritingMap.get(v).map(nv => Assignment(nv, e)) + case Variable(id) => rewritingMap.get(id).map(nid => Variable(nid)) + case _ => None + })(bd) + val wrappedBody = freshNames.zip(freshVarDecls).foldLeft(freshBody)((body, p) => { + LetVar(p._2, Variable(p._1), body) + }) + + val newReturnType = TupleType(fd.returnType :: modifiedVars.map(_.getType)) + + val newFd = new FunDef(fd.id.freshen, fd.tparams, newParams, newReturnType).setPos(fd) + + val (fdRes, fdScope, fdFun) = + toFunction(wrappedBody)( + State(state.parent, Set(), + state.funDefsMapping + (fd -> ((newFd, freshVarDecls)))) + ) + val newRes = Tuple(fdRes :: freshVarDecls.map(vd => fdFun(vd).toVariable)) + val newBody = fdScope(newRes) + + newFd.body = Some(newBody) + newFd.precondition = fd.precondition.map(prec => { + replace(modifiedVars.zip(freshNames).map(p => (p._1.toVariable, p._2.toVariable)).toMap, prec) + }) + newFd.postcondition = fd.postcondition.map(post => { + val Lambda(Seq(res), postBody) = post + val newRes = ValDef(FreshIdentifier(res.id.name, newFd.returnType)) + + val newBody = + replace( + modifiedVars.zipWithIndex.map{ case (v, i) => + (v.toVariable, TupleSelect(newRes.toVariable, i+2)): (Expr, Expr)}.toMap ++ + modifiedVars.zip(freshNames).map{ case (ov, nv) => + (Old(ov), nv.toVariable)}.toMap + + (res.toVariable -> TupleSelect(newRes.toVariable, 1)), + postBody) + Lambda(Seq(newRes), newBody).setPos(post) + }) + + val (bodyRes, bodyScope, bodyFun) = toFunction(b)(state.withFunDef(fd, newFd, modifiedVars)) + (bodyRes, (b2: Expr) => LetDef(newFd, bodyScope(b2)).copiedFrom(expr), bodyFun) + } + } + case None => fdWithoutSideEffects } - val (bodyRes, bodyScope, bodyFun) = toFunction(b) - (bodyRes, (b2: Expr) => LetDef(fd, bodyScope(b2)).copiedFrom(expr), bodyFun) case c @ Choose(b) => //Recall that Choose cannot mutate variables from the scope diff --git a/src/sphinx/library.rst b/src/sphinx/library.rst index cd2fd9c691250b3457ca74749d3a0edc7d9ab495..5cbb60c2dfc91266a52e0a273ae0647162832a36 100644 --- a/src/sphinx/library.rst +++ b/src/sphinx/library.rst @@ -74,7 +74,9 @@ which instruct Leon to handle some functions or objects in a specialized way. | | code written in full Scala which is not verifiable| | | by Leon. | +-------------------+---------------------------------------------------+ - +| ``@inline`` | Inline this function. Leon will refuse to inline | +| | (mutually) recursive functions. | ++-------------------+---------------------------------------------------+ List[T] ------- diff --git a/src/sphinx/options.rst b/src/sphinx/options.rst index 5648b042c47f3a252672d698d496a2b3309221a8..1c929cd29cdc5281378cd322389b867902f0ff2b 100644 --- a/src/sphinx/options.rst +++ b/src/sphinx/options.rst @@ -9,10 +9,10 @@ or just ``--option``. To disable a flag option, use ``--option=false`` or ``off`` or ``no``. Additionally, if you need to pass options to the ``scalac`` frontend of Leon, -you can do it by using a single dash ``-``. For example, ``-Ybrowse:typer``. +you can do it by using a single dash ``-``. For example, try ``-Ybrowse:typer``. The rest of this section presents command-line options that Leon recognizes. -For more up-to-date list, please invoke ``leon --help``. +For a short (but always up-to-date) summary, you can also invoke ``leon --help``. Choosing which Leon feature to use ---------------------------------- @@ -73,10 +73,14 @@ These options are available to all Leon components: * ``eval`` (Evaluators) + * ``isabelle`` (:ref:`The Isabelle-based solver <isabelle>`) + * ``leon`` (The top-level component) * ``options`` (Options parsed by Leon) - + + * ``positions`` (When printing, attach positions to trees) + * ``repair`` (Program repair) * ``solver`` (SMT solvers and their wrappers) @@ -88,7 +92,9 @@ These options are available to all Leon components: * ``timers`` (Timers, timer pools) * ``trees`` (Manipulation of trees) - + + * ``types`` (When printing, attach types to expressions) + * ``verification`` (Verification) * ``xlang`` (Transformation of XLang into Pure Scala programs) @@ -100,9 +106,9 @@ These options are available to all Leon components: where Leon manipulates the input in a per-function basis. Leon will match against suffixes of qualified names. For instance: - ``--functions=List.size`` will match the method - ``leon.collection.List.size`` while ``--functions=size`` will match all ``size`` - methods and functions. This option supports ``_`` as wildcard: ``--functions=List._`` will + ``--functions=List.size`` will match the method ``leon.collection.List.size``, + while ``--functions=size`` will match all methods and functions named ``size``. + This option supports ``_`` as wildcard: ``--functions=List._`` will match all ``List`` methods. * ``--solvers=s1,s2,...`` @@ -201,7 +207,8 @@ Additional Options (by component) --------------------------------- The following options relate to specific components in Leon. Bear in mind -that related components might still use these options, e.g. repair will use +that related components might still use these options, e.g. repair, +which invokes synthesis and verification, will also use synthesis options and verification options. Verification diff --git a/src/sphinx/purescala.rst b/src/sphinx/purescala.rst index ee41a7ec83a68ff479b2316b9085b5302c017237..7cde91421d3a77cfc68162b6c643e4396a8bc85b 100644 --- a/src/sphinx/purescala.rst +++ b/src/sphinx/purescala.rst @@ -36,6 +36,51 @@ Pure Scala supports two kinds of top-level declarations: .. _adts: + +Boolean +####### + +Booleans are used to express truth conditions in Leon. +Unlike some proof assistants, there is no separation +at the type level between +Boolean values and the truth conditions of conjectures +and theorems. + +Typical propositional operations are available using Scala +notation, along +with a new shorthand for implication. The `if` expression +is also present. + +.. code-block:: scala + + a && b + a || b + a == b + !a + a ==> b // Leon syntax for boolean implication + +Leon uses short-circuit interpretation of `&&`, `||`, and `==>`, +which evaluates the second argument only when needed: + +.. code-block:: scala + + a && b === if (a) b else false + + a || b === if (a) true else b + + a ==> b === if (a) b else true + +This aspect is important because of: + +1. evaluation of expressions, which is kept compatible with Scala + +2. verification condition generation for safety: arguments of Boolean operations +may be operations with preconditions; these preconditions apply only in case +that the corresponding argument is evaluated. + +3. termination checking, which takes into account that only one of the paths in an if expression is evaluated for a given truth value of the condition. + + Algebraic Data Types -------------------- @@ -266,16 +311,6 @@ TupleX val y = 1 -> 2 // alternative Scala syntax for Tuple2 x._1 // == 1 -Boolean -####### - -.. code-block:: scala - - a && b - a || b - a == b - !a - a ==> b // Leon syntax for boolean implication Int ### diff --git a/src/test/resources/regression/repair/Heap3.scala b/src/test/resources/regression/repair/Heap3.scala new file mode 100644 index 0000000000000000000000000000000000000000..3305b5d15a0731bd3aeaa1269c2404807f49a266 --- /dev/null +++ b/src/test/resources/regression/repair/Heap3.scala @@ -0,0 +1,113 @@ +/* Copyright 2009-2015 EPFL, Lausanne + * + * Author: Ravi + * Date: 20.11.2013 + **/ + +import leon.lang._ +import leon.collection._ + +object Heaps { + + sealed abstract class Heap { + val rank : BigInt = this match { + case Leaf() => 0 + case Node(_, l, r) => + 1 + max(l.rank, r.rank) + } + def content : Set[BigInt] = this match { + case Leaf() => Set[BigInt]() + case Node(v,l,r) => l.content ++ Set(v) ++ r.content + } + } + case class Leaf() extends Heap + case class Node(value:BigInt, left: Heap, right: Heap) extends Heap + + def max(i1 : BigInt, i2 : BigInt) = if (i1 >= i2) i1 else i2 + + def hasHeapProperty(h : Heap) : Boolean = h match { + case Leaf() => true + case Node(v, l, r) => + ( l match { + case Leaf() => true + case n@Node(v2,_,_) => v >= v2 && hasHeapProperty(n) + }) && + ( r match { + case Leaf() => true + case n@Node(v2,_,_) => v >= v2 && hasHeapProperty(n) + }) + } + + def hasLeftistProperty(h: Heap) : Boolean = h match { + case Leaf() => true + case Node(_,l,r) => + hasLeftistProperty(l) && + hasLeftistProperty(r) && + l.rank >= r.rank + } + + def heapSize(t: Heap): BigInt = { t match { + case Leaf() => BigInt(0) + case Node(v, l, r) => heapSize(l) + 1 + heapSize(r) + }} ensuring(_ >= 0) + + private def merge(h1: Heap, h2: Heap) : Heap = { + require( + hasLeftistProperty(h1) && hasLeftistProperty(h2) && + hasHeapProperty(h1) && hasHeapProperty(h2) + ) + (h1,h2) match { + case (Leaf(), _) => h2 + case (_, Leaf()) => h1 + case (Node(v1, l1, r1), Node(v2, l2, r2)) => + if(v1 >= v2) // FIXME swapped the branches + makeN(v2, l2, merge(h1, r2)) + else + makeN(v1, l1, merge(r1, h2)) + } + } ensuring { res => + hasLeftistProperty(res) && hasHeapProperty(res) && + heapSize(h1) + heapSize(h2) == heapSize(res) && + h1.content ++ h2.content == res.content + } + + private def makeN(value: BigInt, left: Heap, right: Heap) : Heap = { + require( + hasLeftistProperty(left) && hasLeftistProperty(right) + ) + if(left.rank >= right.rank) + Node(value, left, right) + else + Node(value, right, left) + } ensuring { res => + hasLeftistProperty(res) } + + def insert(element: BigInt, heap: Heap) : Heap = { + require(hasLeftistProperty(heap) && hasHeapProperty(heap)) + + merge(Node(element, Leaf(), Leaf()), heap) + + } ensuring { res => + hasLeftistProperty(res) && hasHeapProperty(res) && + heapSize(res) == heapSize(heap) + 1 && + res.content == heap.content ++ Set(element) + } + + def findMax(h: Heap) : Option[BigInt] = { + h match { + case Node(m,_,_) => Some(m) + case Leaf() => None() + } + } + + def removeMax(h: Heap) : Heap = { + require(hasLeftistProperty(h) && hasHeapProperty(h)) + h match { + case Node(_,l,r) => merge(l, r) + case l => l + } + } ensuring { res => + hasLeftistProperty(res) && hasHeapProperty(res) + } + +} diff --git a/src/test/resources/regression/termination/valid/NNF.scala b/src/test/resources/regression/termination/valid/NNF.scala new file mode 100644 index 0000000000000000000000000000000000000000..455cef7ec5201fd17edee71059dc38c9d7afa3d8 --- /dev/null +++ b/src/test/resources/regression/termination/valid/NNF.scala @@ -0,0 +1,99 @@ +import leon.lang._ +import leon.annotation._ + +object PropositionalLogic { + + sealed abstract class Formula + case class And(lhs: Formula, rhs: Formula) extends Formula + case class Or(lhs: Formula, rhs: Formula) extends Formula + case class Implies(lhs: Formula, rhs: Formula) extends Formula + case class Not(f: Formula) extends Formula + case class Literal(id: BigInt) extends Formula + + def simplify(f: Formula): Formula = (f match { + case And(lhs, rhs) => And(simplify(lhs), simplify(rhs)) + case Or(lhs, rhs) => Or(simplify(lhs), simplify(rhs)) + case Implies(lhs, rhs) => Or(Not(simplify(lhs)), simplify(rhs)) + case Not(f) => Not(simplify(f)) + case Literal(_) => f + }) ensuring(isSimplified(_)) + + def isSimplified(f: Formula): Boolean = f match { + case And(lhs, rhs) => isSimplified(lhs) && isSimplified(rhs) + case Or(lhs, rhs) => isSimplified(lhs) && isSimplified(rhs) + case Implies(_,_) => false + case Not(f) => isSimplified(f) + case Literal(_) => true + } + + def nnf(formula: Formula): Formula = (formula match { + case And(lhs, rhs) => And(nnf(lhs), nnf(rhs)) + case Or(lhs, rhs) => Or(nnf(lhs), nnf(rhs)) + case Implies(lhs, rhs) => nnf(Or(Not(lhs), rhs)) + case Not(And(lhs, rhs)) => Or(nnf(Not(lhs)), nnf(Not(rhs))) + case Not(Or(lhs, rhs)) => And(nnf(Not(lhs)), nnf(Not(rhs))) + case Not(Implies(lhs, rhs)) => And(nnf(lhs), nnf(Not(rhs))) + case Not(Not(f)) => nnf(f) + case Not(Literal(_)) => formula + case Literal(_) => formula + }) ensuring(isNNF(_)) + + def isNNF(f: Formula): Boolean = f match { + case And(lhs, rhs) => isNNF(lhs) && isNNF(rhs) + case Or(lhs, rhs) => isNNF(lhs) && isNNF(rhs) + case Implies(lhs, rhs) => false + case Not(Literal(_)) => true + case Not(_) => false + case Literal(_) => true + } + + def evalLit(id : BigInt) : Boolean = (id == 42) // could be any function + def eval(f: Formula) : Boolean = f match { + case And(lhs, rhs) => eval(lhs) && eval(rhs) + case Or(lhs, rhs) => eval(lhs) || eval(rhs) + case Implies(lhs, rhs) => !eval(lhs) || eval(rhs) + case Not(f) => !eval(f) + case Literal(id) => evalLit(id) + } + + @induct + def simplifySemantics(f: Formula) : Boolean = { + eval(f) == eval(simplify(f)) + } holds + + // Note that matching is exhaustive due to precondition. + def vars(f: Formula): Set[BigInt] = { + require(isNNF(f)) + f match { + case And(lhs, rhs) => vars(lhs) ++ vars(rhs) + case Or(lhs, rhs) => vars(lhs) ++ vars(rhs) + case Not(Literal(i)) => Set[BigInt](i) + case Literal(i) => Set[BigInt](i) + } + } + + def fv(f : Formula) = { vars(nnf(f)) } + + @induct + def wrongCommutative(f: Formula) : Boolean = { + nnf(simplify(f)) == simplify(nnf(f)) + } holds + + @induct + def simplifyPreservesNNF(f: Formula) : Boolean = { + require(isNNF(f)) + isNNF(simplify(f)) + } holds + + @induct + def nnfIsStable(f: Formula) : Boolean = { + require(isNNF(f)) + nnf(f) == f + } holds + + @induct + def simplifyIsStable(f: Formula) : Boolean = { + require(isSimplified(f)) + simplify(f) == f + } holds +} diff --git a/src/test/resources/regression/verification/purescala/invalid/CallByName1.scala b/src/test/resources/regression/verification/purescala/invalid/CallByName1.scala new file mode 100644 index 0000000000000000000000000000000000000000..c96ab1617e254c19d8b08a8ecbd818d9cdc305e4 --- /dev/null +++ b/src/test/resources/regression/verification/purescala/invalid/CallByName1.scala @@ -0,0 +1,17 @@ +import leon.lang._ + +object CallByName1 { + def byName1(i: Int, a: => Int): Int = { + if (i > 0) a + 1 + else 0 + } + + def byName2(i: Int, a: => Int): Int = { + if (i > 0) byName1(i - 1, a) + 2 + else 0 + } + + def test(): Boolean = { + byName1(1, byName2(3, 0)) == 0 && byName1(1, byName2(3, 0)) == 1 + }.holds +} diff --git a/src/test/resources/regression/verification/purescala/invalid/Existentials.scala b/src/test/resources/regression/verification/purescala/invalid/Existentials.scala new file mode 100644 index 0000000000000000000000000000000000000000..19679db9202b325edc0eb8dfff8eeea41bb47d36 --- /dev/null +++ b/src/test/resources/regression/verification/purescala/invalid/Existentials.scala @@ -0,0 +1,16 @@ +import leon.lang._ + +object Existentials { + + def exists[A](p: A => Boolean): Boolean = !forall((x: A) => !p(x)) + + def check1(y: BigInt, p: BigInt => Boolean) : Boolean = { + p(y) == exists((y1:BigInt) => p(y1)) + }.holds + + /* + def check2(y: BigInt, p: BigInt => Boolean) : Boolean = { + p(y) ==> exists((y1:BigInt) => p(y1)) + }.holds + */ +} diff --git a/src/test/resources/regression/verification/purescala/invalid/ForallAssoc.scala b/src/test/resources/regression/verification/purescala/invalid/ForallAssoc.scala new file mode 100644 index 0000000000000000000000000000000000000000..83773b2224fb8790acfc25125143b358c1226f05 --- /dev/null +++ b/src/test/resources/regression/verification/purescala/invalid/ForallAssoc.scala @@ -0,0 +1,17 @@ +import leon.lang._ + +object ForallAssoc { + + /* + def test3(f: (BigInt, BigInt) => BigInt): Boolean = { + require(forall((x: BigInt, y: BigInt, z: BigInt) => f(x, f(y, z)) == f(f(x, y), z))) + f(1, f(2, f(3, f(4, 5)))) == f(f(f(f(1, 2), 3), 4), 4) + }.holds + */ + + def test4(f: (BigInt, BigInt) => BigInt): Boolean = { + require(forall((x: BigInt, y: BigInt, z: BigInt) => f(x, f(y, z)) == f(f(x, y), z))) + f(1, f(2, f(3, 4))) == 0 + }.holds + +} diff --git a/src/test/resources/regression/verification/purescala/invalid/PropositionalLogic.scala b/src/test/resources/regression/verification/purescala/invalid/PropositionalLogic.scala index 547978dbcca900272b76aac30dfd8e2a8f9e2c62..a8927f360e817d10761723c8a4b7e9085bf0738d 100644 --- a/src/test/resources/regression/verification/purescala/invalid/PropositionalLogic.scala +++ b/src/test/resources/regression/verification/purescala/invalid/PropositionalLogic.scala @@ -78,7 +78,7 @@ object PropositionalLogic { require(isNNF(f)) nnf(f) == f }.holds - + @induct def simplifyIsStable(f: Formula) : Boolean = { require(isSimplified(f)) diff --git a/src/test/resources/regression/verification/purescala/invalid/TestLazinessOfAnd.scala b/src/test/resources/regression/verification/purescala/invalid/TestLazinessOfAnd.scala new file mode 100644 index 0000000000000000000000000000000000000000..9ee698e1d42d38214a0290cbad859b0b5ad87405 --- /dev/null +++ b/src/test/resources/regression/verification/purescala/invalid/TestLazinessOfAnd.scala @@ -0,0 +1,18 @@ +import leon.lang._ + +object AndTest { + + def nonterm(x: BigInt) : BigInt = { + nonterm(x + 1) + } ensuring(res => false) + + def precond(y : BigInt) = y < 0 + + /** + * Leon should find a counter-example here. + **/ + def foo(y: BigInt) : Boolean = { + require(precond(y)) + y >= 0 && (nonterm(0) == 0) + } holds +} diff --git a/src/test/resources/regression/verification/purescala/invalid/Unapply1.scala b/src/test/resources/regression/verification/purescala/invalid/Unapply1.scala index 20ff95383b4f7042b0eba32fcc4e5c8b0cea3680..674ca7c69fa29333755d8908f707ae9931d20bf2 100644 --- a/src/test/resources/regression/verification/purescala/invalid/Unapply1.scala +++ b/src/test/resources/regression/verification/purescala/invalid/Unapply1.scala @@ -6,10 +6,14 @@ object Unap1 { } object Unapply1 { + + sealed abstract class Bool + case class True() extends Bool + case class False() extends Bool - def bar: Boolean = { (42, false, ()) match { - case Unap1(_, b) if b => b + def bar: Bool = { (42, False().asInstanceOf[Bool], ()) match { + case Unap1(_, b) if b == True() => b case Unap1((), b) => b - }} ensuring { res => res } + }} ensuring { res => res == True() } } diff --git a/src/test/resources/regression/verification/purescala/invalid/Unapply2.scala b/src/test/resources/regression/verification/purescala/invalid/Unapply2.scala index ae4167c20026f0e926494af6cf3a11f25e9399e5..7efd6e220fd1c7fcd08fbaa5b7016bf6801e151f 100644 --- a/src/test/resources/regression/verification/purescala/invalid/Unapply2.scala +++ b/src/test/resources/regression/verification/purescala/invalid/Unapply2.scala @@ -5,7 +5,12 @@ object Unap2 { } object Unapply { - def bar: Boolean = { (42, false, ()) match { - case Unap2(_, b) if b => b - }} ensuring { res => res } + + sealed abstract class Bool + case class True() extends Bool + case class False() extends Bool + + def bar: Bool = { (42, False().asInstanceOf[Bool], ()) match { + case Unap2(_, b) if b == True() => b + }} ensuring { res => res == True() } } diff --git a/src/test/resources/regression/verification/purescala/valid/CallByName1.scala b/src/test/resources/regression/verification/purescala/valid/CallByName1.scala new file mode 100644 index 0000000000000000000000000000000000000000..912acdc481e8191bd072bd7eaeb13684cd93c384 --- /dev/null +++ b/src/test/resources/regression/verification/purescala/valid/CallByName1.scala @@ -0,0 +1,9 @@ +import leon.lang._ + +object CallByName1 { + def add(a: => Int, b: => Int): Int = a + b + + def test(): Int = { + add(1,2) + } ensuring (_ == 3) +} diff --git a/src/test/resources/regression/verification/purescala/valid/Client.scala b/src/test/resources/regression/verification/purescala/valid/Client.scala new file mode 100644 index 0000000000000000000000000000000000000000..c0dd37cda62e30202e9255b534685f9e0e9b840a --- /dev/null +++ b/src/test/resources/regression/verification/purescala/valid/Client.scala @@ -0,0 +1,18 @@ +import leon.collection._ +import leon.lang._ + +object Minimal { + + case class Client(f: Int => List[Int]) + + val client = Client(x => List(1)) + + // def f(x: Int) = List(1) + // val client = Client(f) + + def theorem() = { + client.f(0).size != BigInt(0) + } holds + +} + diff --git a/src/test/resources/regression/verification/purescala/valid/Existentials.scala b/src/test/resources/regression/verification/purescala/valid/Existentials.scala new file mode 100644 index 0000000000000000000000000000000000000000..992d58cd0678e98146ad1085a22269671a35421c --- /dev/null +++ b/src/test/resources/regression/verification/purescala/valid/Existentials.scala @@ -0,0 +1,16 @@ +import leon.lang._ + +object Existentials { + + def exists[A](p: A => Boolean): Boolean = !forall((x: A) => !p(x)) + + /* + def check1(y: BigInt, p: BigInt => Boolean) : Boolean = { + p(y) == exists((y1:BigInt) => p(y1)) + }.holds + */ + + def check2(y: BigInt, p: BigInt => Boolean) : Boolean = { + p(y) ==> exists((y1:BigInt) => p(y1)) + }.holds +} diff --git a/src/test/resources/regression/verification/purescala/valid/FlatMap.scala b/src/test/resources/regression/verification/purescala/valid/FlatMap.scala index 3f1dd39ffd79dd10d1300a9aa0ac60f2b29d46a0..878b61456b6658b13f05d0e5f1b05dc25d959676 100644 --- a/src/test/resources/regression/verification/purescala/valid/FlatMap.scala +++ b/src/test/resources/regression/verification/purescala/valid/FlatMap.scala @@ -1,6 +1,7 @@ /* Copyright 2009-2015 EPFL, Lausanne */ import leon.lang._ +import leon.proof._ import leon.collection._ object FlatMap { @@ -17,7 +18,7 @@ object FlatMap { def associative_append_lemma_induct[T](l1: List[T], l2: List[T], l3: List[T]): Boolean = { l1 match { case Nil() => associative_append_lemma(l1, l2, l3) - case Cons(head, tail) => associative_append_lemma(l1, l2, l3) && associative_append_lemma_induct(tail, l2, l3) + case Cons(head, tail) => associative_append_lemma(l1, l2, l3) because associative_append_lemma_induct(tail, l2, l3) } }.holds @@ -31,20 +32,21 @@ object FlatMap { } def associative_lemma_induct[T,U,V](list: List[T], flist: List[U], glist: List[V], f: T => List[U], g: U => List[V]): Boolean = { - associative_lemma(list, f, g) && - append(glist, flatMap(append(flist, flatMap(list, f)), g)) == append(append(glist, flatMap(flist, g)), flatMap(list, (x: T) => flatMap(f(x), g))) && - (glist match { - case Cons(ghead, gtail) => - associative_lemma_induct(list, flist, gtail, f, g) - case Nil() => flist match { - case Cons(fhead, ftail) => - associative_lemma_induct(list, ftail, g(fhead), f, g) - case Nil() => list match { - case Cons(head, tail) => associative_lemma_induct(tail, f(head), Nil(), f, g) - case Nil() => true + associative_lemma(list, f, g) because { + append(glist, flatMap(append(flist, flatMap(list, f)), g)) == append(append(glist, flatMap(flist, g)), flatMap(list, (x: T) => flatMap(f(x), g))) because + (glist match { + case Cons(ghead, gtail) => + associative_lemma_induct(list, flist, gtail, f, g) + case Nil() => flist match { + case Cons(fhead, ftail) => + associative_lemma_induct(list, ftail, g(fhead), f, g) + case Nil() => list match { + case Cons(head, tail) => associative_lemma_induct(tail, f(head), Nil(), f, g) + case Nil() => true + } } - } - }) + }) + } }.holds } diff --git a/src/test/resources/regression/verification/purescala/valid/FoldAssociative.scala b/src/test/resources/regression/verification/purescala/valid/FoldAssociative.scala index ff0444f269dd6f9b2a2b7f2b54e13f54c6755ab2..6920231115fea93a1b7c7cefc27691a0ee2471bb 100644 --- a/src/test/resources/regression/verification/purescala/valid/FoldAssociative.scala +++ b/src/test/resources/regression/verification/purescala/valid/FoldAssociative.scala @@ -2,6 +2,7 @@ import leon._ import leon.lang._ +import leon.proof._ object FoldAssociative { @@ -56,7 +57,7 @@ object FoldAssociative { val f = (x: Int, s: Int) => x + s val l1 = take(list, x) val l2 = drop(list, x) - lemma_split(list, x) && (list match { + lemma_split(list, x) because (list match { case Cons(head, tail) if x > 0 => lemma_split_induct(tail, x - 1) case _ => true @@ -77,7 +78,7 @@ object FoldAssociative { val f = (x: Int, s: Int) => x + s val l1 = take(list, x) val l2 = drop(list, x) - lemma_reassociative(list, x) && (list match { + lemma_reassociative(list, x) because (list match { case Cons(head, tail) if x > 0 => lemma_reassociative_induct(tail, x - 1) case _ => true @@ -93,7 +94,7 @@ object FoldAssociative { def lemma_reassociative_presplit_induct(l1: List, l2: List): Boolean = { val f = (x: Int, s: Int) => x + s val list = append(l1, l2) - lemma_reassociative_presplit(l1, l2) && (l1 match { + lemma_reassociative_presplit(l1, l2) because (l1 match { case Cons(head, tail) => lemma_reassociative_presplit_induct(tail, l2) case Nil() => true diff --git a/src/test/resources/regression/verification/purescala/valid/ForallAssoc.scala b/src/test/resources/regression/verification/purescala/valid/ForallAssoc.scala new file mode 100644 index 0000000000000000000000000000000000000000..ae90e9489678e1f4cc20d90b0cd23767cc317546 --- /dev/null +++ b/src/test/resources/regression/verification/purescala/valid/ForallAssoc.scala @@ -0,0 +1,23 @@ +import leon.lang._ + +object ForallAssoc { + + def ex[A](x1: A, x2: A, x3: A, x4: A, x5: A, f: (A, A) => A) = { + require(forall { + (x: A, y: A, z: A) => f(x, f(y, z)) == f(f(x, y), z) + }) + + f(x1, f(x2, f(x3, f(x4, x5)))) == f(f(x1, f(x2, f(x3, x4))), x5) + }.holds + + def test1(f: (BigInt, BigInt) => BigInt): Boolean = { + require(forall((x: BigInt, y: BigInt, z: BigInt) => f(x, f(y, z)) == f(f(x, y), z))) + f(1, f(2, f(3, 4))) == f(f(f(1, 2), 3), 4) + }.holds + + def test2(f: (BigInt, BigInt) => BigInt): Boolean = { + require(forall((x: BigInt, y: BigInt, z: BigInt) => f(x, f(y, z)) == f(f(x, y), z))) + f(1, f(2, f(3, f(4, 5)))) == f(f(f(f(1, 2), 3), 4), 5) + }.holds + +} diff --git a/src/test/resources/regression/verification/purescala/valid/Lists1.scala b/src/test/resources/regression/verification/purescala/valid/Lists1.scala index 4dbee2b7137c8a427cf7bd302e3849050f0b2bc7..0b10f27d05354ef02fab45cc935f7f3695bee378 100644 --- a/src/test/resources/regression/verification/purescala/valid/Lists1.scala +++ b/src/test/resources/regression/verification/purescala/valid/Lists1.scala @@ -1,6 +1,7 @@ /* Copyright 2009-2015 EPFL, Lausanne */ import leon.lang._ +import leon.proof._ import leon.collection._ import leon.annotation._ @@ -21,10 +22,10 @@ object Lists1 { } def exists_lemma_induct[T](list: List[T], f: T => Boolean): Boolean = { - list match { - case Nil() => exists_lemma(list, f) - case Cons(head, tail) => exists_lemma(list, f) && exists_lemma_induct(tail, f) - } + exists_lemma(list, f) because (list match { + case Nil() => true + case Cons(head, tail) => exists_lemma_induct(tail, f) + }) }.holds } diff --git a/src/test/resources/regression/verification/purescala/valid/Lists2.scala b/src/test/resources/regression/verification/purescala/valid/Lists2.scala index 57d86819bf5cb17a77e451aead68393f7c30ed17..79126cb2226b55c811313c1c82375b977265ffc2 100644 --- a/src/test/resources/regression/verification/purescala/valid/Lists2.scala +++ b/src/test/resources/regression/verification/purescala/valid/Lists2.scala @@ -1,6 +1,7 @@ /* Copyright 2009-2015 EPFL, Lausanne */ import leon.lang._ +import leon.proof._ object Lists2 { abstract class List[T] @@ -22,10 +23,10 @@ object Lists2 { } def positive_lemma_induct(list: List[Int]): Boolean = { - list match { - case Nil() => positive_lemma(list) - case Cons(head, tail) => positive_lemma(list) && positive_lemma_induct(tail) - } + positive_lemma(list) because (list match { + case Nil() => true + case Cons(head, tail) => positive_lemma_induct(tail) + }) }.holds def remove[T](list: List[T], e: T) : List[T] = { @@ -41,10 +42,10 @@ object Lists2 { } def remove_lemma_induct[T](list: List[T], e: T): Boolean = { - list match { - case Nil() => remove_lemma(list, e) - case Cons(head, tail) => remove_lemma(list, e) && remove_lemma_induct(tail, e) - } + remove_lemma(list, e) because (list match { + case Nil() => true + case Cons(head, tail) => remove_lemma_induct(tail, e) + }) }.holds } diff --git a/src/test/resources/regression/verification/purescala/valid/Lists3.scala b/src/test/resources/regression/verification/purescala/valid/Lists3.scala index 1dae4a5c6f5ecf649bc78d2b19151007cc57a47d..dedab2a4e32aee2bf02e816ad8677e02a0a669a2 100644 --- a/src/test/resources/regression/verification/purescala/valid/Lists3.scala +++ b/src/test/resources/regression/verification/purescala/valid/Lists3.scala @@ -1,6 +1,7 @@ /* Copyright 2009-2015 EPFL, Lausanne */ import leon.lang._ +import leon.proof._ object Lists3 { abstract class List[T] @@ -26,10 +27,10 @@ object Lists3 { } def positive_lemma_induct(list: List[Int]): Boolean = { - list match { - case Nil() => positive_lemma(list) - case Cons(head, tail) => positive_lemma(list) && positive_lemma_induct(tail) - } + positive_lemma(list) because (list match { + case Nil() => true + case Cons(head, tail) => positive_lemma_induct(tail) + }) }.holds } diff --git a/src/test/resources/regression/verification/purescala/valid/Lists4.scala b/src/test/resources/regression/verification/purescala/valid/Lists4.scala index d4e212aff31bd011ce143901efa34a5a428f0d94..124d944da9e0769e74d43c4b3612f5a79657d9fa 100644 --- a/src/test/resources/regression/verification/purescala/valid/Lists4.scala +++ b/src/test/resources/regression/verification/purescala/valid/Lists4.scala @@ -1,6 +1,7 @@ /* Copyright 2009-2015 EPFL, Lausanne */ import leon.lang._ +import leon.proof._ object Lists4 { abstract class List[T] @@ -17,10 +18,10 @@ object Lists4 { } def map_lemma_induct[D,E,F](list: List[D], f: D => E, g: E => F): Boolean = { - list match { - case Nil() => map_lemma(list, f, g) - case Cons(head, tail) => map_lemma(list, f, g) && map_lemma_induct(tail, f, g) - } + map_lemma(list, f, g) because (list match { + case Nil() => true + case Cons(head, tail) => map_lemma_induct(tail, f, g) + }) }.holds } diff --git a/src/test/resources/regression/verification/purescala/valid/Lists5.scala b/src/test/resources/regression/verification/purescala/valid/Lists5.scala index f8fd9491fee2ad4a27ed9a4230ba1b30afec3093..9826dc372a2a79ae733696126de8565394148ff8 100644 --- a/src/test/resources/regression/verification/purescala/valid/Lists5.scala +++ b/src/test/resources/regression/verification/purescala/valid/Lists5.scala @@ -1,6 +1,7 @@ /* Copyright 2009-2015 EPFL, Lausanne */ import leon.lang._ +import leon.proof._ import leon.collection._ object Lists5 { @@ -26,9 +27,9 @@ object Lists5 { } def equivalence_lemma_induct[T](f: T => Boolean, list: List[T]): Boolean = { - list match { - case Cons(head, tail) => equivalence_lemma(f, list) && equivalence_lemma_induct(f, tail) - case Nil() => equivalence_lemma(f, list) - } + equivalence_lemma(f, list) because (list match { + case Cons(head, tail) => equivalence_lemma_induct(f, tail) + case Nil() => true + }) }.holds } diff --git a/src/test/resources/regression/verification/purescala/valid/Lists6.scala b/src/test/resources/regression/verification/purescala/valid/Lists6.scala index 8257249830949d07dc36c0c74633d166bebbe561..763fabdaddab6a7d4732dcd2c102beb897b27d0e 100644 --- a/src/test/resources/regression/verification/purescala/valid/Lists6.scala +++ b/src/test/resources/regression/verification/purescala/valid/Lists6.scala @@ -1,6 +1,7 @@ /* Copyright 2009-2015 EPFL, Lausanne */ import leon.lang._ +import leon.proof._ import leon.collection._ object Lists6 { @@ -16,7 +17,7 @@ object Lists6 { } def associative_lemma_induct[T](list: List[T], f: T => Boolean, g: T => Boolean): Boolean = { - associative_lemma(list, f, g) && (list match { + associative_lemma(list, f, g) because (list match { case Cons(head, tail) => associative_lemma_induct(tail, f, g) case Nil() => true }) diff --git a/src/test/resources/regression/verification/purescala/valid/Monads3.scala b/src/test/resources/regression/verification/purescala/valid/Monads3.scala index 05b46ceb0f02a96811b7ca8a4606b4b4b30de3e3..fc0e7149b7f43455b0531c4f2502d592bbd764da 100644 --- a/src/test/resources/regression/verification/purescala/valid/Monads3.scala +++ b/src/test/resources/regression/verification/purescala/valid/Monads3.scala @@ -1,6 +1,7 @@ /* Copyright 2009-2015 EPFL, Lausanne */ import leon.lang._ +import leon.proof._ import leon.collection._ object Monads3 { @@ -22,20 +23,21 @@ object Monads3 { } def associative_lemma_induct[T,U,V](list: List[T], flist: List[U], glist: List[V], f: T => List[U], g: U => List[V]): Boolean = { - associative_lemma(list, f, g) && - append(glist, flatMap(append(flist, flatMap(list, f)), g)) == append(append(glist, flatMap(flist, g)), flatMap(list, (x: T) => flatMap(f(x), g))) && - (glist match { - case Cons(ghead, gtail) => - associative_lemma_induct(list, flist, gtail, f, g) - case Nil() => flist match { - case Cons(fhead, ftail) => - associative_lemma_induct(list, ftail, g(fhead), f, g) - case Nil() => list match { - case Cons(head, tail) => associative_lemma_induct(tail, f(head), Nil(), f, g) - case Nil() => true + associative_lemma(list, f, g) because { + append(glist, flatMap(append(flist, flatMap(list, f)), g)) == append(append(glist, flatMap(flist, g)), flatMap(list, (x: T) => flatMap(f(x), g))) because + (glist match { + case Cons(ghead, gtail) => + associative_lemma_induct(list, flist, gtail, f, g) + case Nil() => flist match { + case Cons(fhead, ftail) => + associative_lemma_induct(list, ftail, g(fhead), f, g) + case Nil() => list match { + case Cons(head, tail) => associative_lemma_induct(tail, f(head), Nil(), f, g) + case Nil() => true + } } - } - }) + }) + } }.holds def left_unit_law[T,U](x: T, f: T => List[U]): Boolean = { @@ -47,7 +49,7 @@ object Monads3 { } def right_unit_induct[T,U](list: List[T]): Boolean = { - right_unit_law(list) && (list match { + right_unit_law(list) because (list match { case Cons(head, tail) => right_unit_induct(tail) case Nil() => true }) @@ -62,7 +64,7 @@ object Monads3 { } def flatMap_to_zero_induct[T,U](list: List[T]): Boolean = { - flatMap_to_zero_law(list) && (list match { + flatMap_to_zero_law(list) because (list match { case Cons(head, tail) => flatMap_to_zero_induct(tail) case Nil() => true }) diff --git a/src/test/resources/regression/verification/purescala/valid/ParBalance.scala b/src/test/resources/regression/verification/purescala/valid/ParBalance.scala index 29133762e6da8118e1e380b6fabb4fb3880d482b..35ca7537645182c003ebd3e8003e6db7c5648de2 100644 --- a/src/test/resources/regression/verification/purescala/valid/ParBalance.scala +++ b/src/test/resources/regression/verification/purescala/valid/ParBalance.scala @@ -2,6 +2,7 @@ import leon._ import leon.lang._ +import leon.proof._ object ParBalance { @@ -72,7 +73,7 @@ object ParBalance { def balanced_foldLeft_equivalence(list: List, p: (BigInt, BigInt)): Boolean = { require(p._1 >= 0 && p._2 >= 0) val f = (s: (BigInt, BigInt), x: BigInt) => reduce(s, parPair(x)) - (foldLeft(list, p, f) == (BigInt(0), BigInt(0))) == balanced_withReduce(list, p) && (list match { + (foldLeft(list, p, f) == (BigInt(0), BigInt(0))) == balanced_withReduce(list, p) because (list match { case Cons(head, tail) => val p2 = f(p, head) balanced_foldLeft_equivalence(tail, p2) @@ -178,14 +179,14 @@ object ParBalance { }.holds def reverse_init_equivalence(list: List): Boolean = { - reverse(init(list)) == tail(reverse(list)) && (list match { + reverse(init(list)) == tail(reverse(list)) because (list match { case Cons(head, tail) => reverse_init_equivalence(tail) case Nil() => true }) }.holds def reverse_equality_equivalence(l1: List, l2: List): Boolean = { - (l1 == l2) == (reverse(l1) == reverse(l2)) && ((l1, l2) match { + (l1 == l2) == (reverse(l1) == reverse(l2)) because ((l1, l2) match { case (Cons(h1, t1), Cons(h2, t2)) => reverse_equality_equivalence(t1, t2) case _ => true }) @@ -198,7 +199,7 @@ object ParBalance { // always decreasing, so that the termination checker can prove termination. def reverse_reverse_equivalence(s: BigInt, list: List): Boolean = { require(size(list) == s) - reverse(reverse(list)) == list && ((list, reverse(list)) match { + reverse(reverse(list)) == list because ((list, reverse(list)) match { case (Cons(h1, t1), Cons(h2, t2)) => reverse_reverse_equivalence(size(t1), t1) && reverse_reverse_equivalence(size(t2), t2) case _ => true diff --git a/src/test/resources/regression/verification/purescala/valid/Predicate.scala b/src/test/resources/regression/verification/purescala/valid/Predicate.scala new file mode 100644 index 0000000000000000000000000000000000000000..c011d4b910c2d3859b0e2cca80e069de5b19788b --- /dev/null +++ b/src/test/resources/regression/verification/purescala/valid/Predicate.scala @@ -0,0 +1,48 @@ +package leon.monads.predicate + +import leon.collection._ +import leon.lang._ +import leon.annotation._ + +object Predicate { + + def exists[A](p: A => Boolean): Boolean = !forall((a: A) => !p(a)) + + // Monadic bind + @inline + def flatMap[A,B](p: A => Boolean, f: A => (B => Boolean)): B => Boolean = { + (b: B) => exists[A](a => p(a) && f(a)(b)) + } + + // All monads are also functors, and they define the map function + @inline + def map[A,B](p: A => Boolean, f: A => B): B => Boolean = { + (b: B) => exists[A](a => p(a) && f(a) == b) + } + + /* + @inline + def >>=[B](f: A => Predicate[B]): Predicate[B] = flatMap(f) + + @inline + def >>[B](that: Predicate[B]) = >>= ( _ => that ) + + @inline + def withFilter(f: A => Boolean): Predicate[A] = { + Predicate { a => p(a) && f(a) } + } + */ + + def equals[A](p: A => Boolean, that: A => Boolean): Boolean = { + forall[A](a => p(a) == that(a)) + } + + def test[A,B,C](p: A => Boolean, f: A => B, g: B => C): Boolean = { + equals(map(map(p, f), g), map(p, (a: A) => g(f(a)))) + }.holds + + def testInt(p: BigInt => Boolean, f: BigInt => BigInt, g: BigInt => BigInt): Boolean = { + equals(map(map(p, f), g), map(p, (x: BigInt) => g(f(x)))) + }.holds +} + diff --git a/src/test/resources/regression/verification/purescala/valid/Unapply.scala b/src/test/resources/regression/verification/purescala/valid/Unapply.scala index 941b1f370d740204e8fa850a9d5e5a787b1c401e..1885837d990c32185a42c1232a53de48c3935340 100644 --- a/src/test/resources/regression/verification/purescala/valid/Unapply.scala +++ b/src/test/resources/regression/verification/purescala/valid/Unapply.scala @@ -5,8 +5,18 @@ object Unap { } object Unapply { - def bar: Boolean = { (42, true, ()) match { - case Unap(_, b) if b => b - case Unap((), b) => !b - }} ensuring { res => res } + + sealed abstract class Bool + case class True() extends Bool + case class False() extends Bool + + def not(b: Bool): Bool = b match { + case True() => False() + case False() => True() + } + + def bar: Bool = { (42, True().asInstanceOf[Bool], ()) match { + case Unap(_, b) if b == True() => b + case Unap((), b) => not(b) + }} ensuring { res => res == True() } } diff --git a/src/test/resources/regression/verification/xlang/invalid/NestedFunState1.scala b/src/test/resources/regression/verification/xlang/invalid/NestedFunState1.scala new file mode 100644 index 0000000000000000000000000000000000000000..9372ea742a278a3da3d476dca57e52c55c09b978 --- /dev/null +++ b/src/test/resources/regression/verification/xlang/invalid/NestedFunState1.scala @@ -0,0 +1,20 @@ +object NestedFunState1 { + + def simpleSideEffect(n: BigInt): BigInt = { + require(n > 0) + + var a = BigInt(0) + + def incA(prevA: BigInt): Unit = { + require(prevA == a) + a += 1 + } ensuring(_ => a == prevA + 1) + + incA(a) + incA(a) + incA(a) + incA(a) + a + } ensuring(_ == 5) + +} diff --git a/src/test/resources/regression/verification/xlang/invalid/NestedFunState2.scala b/src/test/resources/regression/verification/xlang/invalid/NestedFunState2.scala new file mode 100644 index 0000000000000000000000000000000000000000..68769df28353c9defa6d33000bb8ae4de21c7ace --- /dev/null +++ b/src/test/resources/regression/verification/xlang/invalid/NestedFunState2.scala @@ -0,0 +1,23 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +object NestedFunState2 { + + def sum(n: BigInt): BigInt = { + require(n > 0) + var i = BigInt(0) + var res = BigInt(0) + + def iter(): Unit = { + require(res >= i && i >= 0) + if(i < n) { + i += 1 + res += i + iter() + } + } + + iter() + res + } ensuring(_ < 0) + +} diff --git a/src/test/resources/regression/verification/xlang/valid/ArrayNested1.scala b/src/test/resources/regression/verification/xlang/valid/ArrayNested1.scala new file mode 100644 index 0000000000000000000000000000000000000000..196a6442bd544a3981e6a7de7e923e040cadd32d --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/ArrayNested1.scala @@ -0,0 +1,19 @@ +import leon.lang._ + +object ArrayNested1 { + + def test(): Int = { + + var a = Array(1, 2, 0) + + def nested(): Unit = { + require(a.length == 3) + a = a.updated(1, 5) + } + + nested() + a(1) + + } ensuring(_ == 5) + +} diff --git a/src/test/resources/regression/verification/xlang/valid/ArrayNested2.scala b/src/test/resources/regression/verification/xlang/valid/ArrayNested2.scala new file mode 100644 index 0000000000000000000000000000000000000000..a8935ab8c1d4eec6980b85f171e695071f2a9442 --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/ArrayNested2.scala @@ -0,0 +1,19 @@ +import leon.lang._ + +object ArrayNested2 { + + def test(): Int = { + + val a = Array(1, 2, 0) + + def nested(): Unit = { + require(a.length == 3) + a(2) = 5 + } + + nested() + a(2) + + } ensuring(_ == 5) + +} diff --git a/src/test/resources/regression/verification/xlang/valid/FunInvocEvaluationOrder1.scala b/src/test/resources/regression/verification/xlang/valid/FunInvocEvaluationOrder1.scala new file mode 100644 index 0000000000000000000000000000000000000000..8ef668a2afeb61a6b305831358a65e675bf670ed --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/FunInvocEvaluationOrder1.scala @@ -0,0 +1,22 @@ +object FunInvocEvaluationOrder1 { + + def test(): Int = { + + var res = 10 + justAddingStuff({ + res += 1 + res + }, { + res *= 2 + res + }, { + res += 10 + res + }) + + res + } ensuring(_ == 32) + + def justAddingStuff(x: Int, y: Int, z: Int): Int = x + y + z + +} diff --git a/src/test/resources/regression/verification/xlang/valid/FunInvocEvaluationOrder2.scala b/src/test/resources/regression/verification/xlang/valid/FunInvocEvaluationOrder2.scala new file mode 100644 index 0000000000000000000000000000000000000000..12090f0f3fad0a5fcf2451e30fdb6f25bb09c590 --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/FunInvocEvaluationOrder2.scala @@ -0,0 +1,17 @@ +object FunInvocEvaluationOrder2 { + + def leftToRight(n: BigInt): BigInt = { + require(n > 0) + + var a = BigInt(0) + + def nested(x: BigInt, y: BigInt): BigInt = { + require(y >= x) + x + y + } + + nested({a += 10; a}, {a *= 2; a}) + + } ensuring(_ == 30) + +} diff --git a/src/test/resources/regression/verification/xlang/valid/FunInvocEvaluationOrder3.scala b/src/test/resources/regression/verification/xlang/valid/FunInvocEvaluationOrder3.scala new file mode 100644 index 0000000000000000000000000000000000000000..f263529653486ee5477b9fd1934760bcdc737d15 --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/FunInvocEvaluationOrder3.scala @@ -0,0 +1,17 @@ +object FunInvocEvaluationOrder3 { + + def leftToRight(n: BigInt): BigInt = { + require(n > 0) + + var a = BigInt(0) + + def nested(x: BigInt, y: BigInt): Unit = { + a = x + y + } + + nested({a += 10; a}, {a *= 2; a}) + a + + } ensuring(_ == 30) + +} diff --git a/src/test/resources/regression/verification/xlang/valid/NestedFunState1.scala b/src/test/resources/regression/verification/xlang/valid/NestedFunState1.scala new file mode 100644 index 0000000000000000000000000000000000000000..1ec0266f846dd55cfa1a6274d5c6dc7dc67dd125 --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/NestedFunState1.scala @@ -0,0 +1,23 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +object NestedFunState1 { + + def sum(n: BigInt): BigInt = { + require(n > 0) + var i = BigInt(0) + var res = BigInt(0) + + def iter(): Unit = { + require(res >= i && i >= 0) + if(i < n) { + i += 1 + res += i + iter() + } + } ensuring(_ => res >= n) + + iter() + res + } ensuring(_ >= n) + +} diff --git a/src/test/resources/regression/verification/xlang/valid/NestedFunState2.scala b/src/test/resources/regression/verification/xlang/valid/NestedFunState2.scala new file mode 100644 index 0000000000000000000000000000000000000000..e48b1b959b3a83262329cf48c63370a2b591cd41 --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/NestedFunState2.scala @@ -0,0 +1,19 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +object NestedFunState2 { + + def countConst(): Int = { + + var counter = 0 + + def inc(): Unit = { + counter += 1 + } + + inc() + inc() + inc() + counter + } ensuring(_ == 3) + +} diff --git a/src/test/resources/regression/verification/xlang/valid/NestedFunState3.scala b/src/test/resources/regression/verification/xlang/valid/NestedFunState3.scala new file mode 100644 index 0000000000000000000000000000000000000000..650f5987dcd352a3f5c47d41c684e94d4b9be20c --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/NestedFunState3.scala @@ -0,0 +1,25 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ +import leon.lang._ + +object NestedFunState3 { + + + def counterN(n: Int): Int = { + require(n > 0) + + var counter = 0 + + def inc(): Unit = { + counter += 1 + } + + var i = 0 + (while(i < n) { + inc() + i += 1 + }) invariant(i >= 0 && counter == i && i <= n) + + counter + } ensuring(_ == n) + +} diff --git a/src/test/resources/regression/verification/xlang/valid/NestedFunState4.scala b/src/test/resources/regression/verification/xlang/valid/NestedFunState4.scala new file mode 100644 index 0000000000000000000000000000000000000000..9f4fb2621f60c60da3a677dc732e6c93f4d8ae72 --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/NestedFunState4.scala @@ -0,0 +1,38 @@ +import leon.lang._ + +object NestedFunState4 { + + def deep(n: BigInt): BigInt = { + require(n > 0) + + var a = BigInt(0) + + def iter(): Unit = { + require(a >= 0) + + var b = BigInt(0) + + def nestedIter(): Unit = { + b += 1 + } + + var i = BigInt(0) + (while(i < n) { + i += 1 + nestedIter() + }) invariant(i >= 0 && i <= n && b == i) + + a += b + + } ensuring(_ => a >= n) + + var i = BigInt(0) + (while(i < n) { + i += 1 + iter() + }) invariant(i >= 0 && i <= n && a >= i) + + a + } ensuring(_ >= n) + +} diff --git a/src/test/resources/regression/verification/xlang/valid/NestedFunState5.scala b/src/test/resources/regression/verification/xlang/valid/NestedFunState5.scala new file mode 100644 index 0000000000000000000000000000000000000000..13f3cf47cd0ee6c831f12877493cbfe6c7edea5a --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/NestedFunState5.scala @@ -0,0 +1,29 @@ +import leon.lang._ + +object NestedFunState5 { + + def deep(n: BigInt): BigInt = { + require(n > 0) + + var a = BigInt(0) + + def iter(prevA: BigInt): Unit = { + require(prevA == a) + def nestedIter(): Unit = { + a += 1 + } + + nestedIter() + nestedIter() + + } ensuring(_ => a == prevA + 2) + + var i = BigInt(0) + (while(i < n) { + i += 1 + iter(a) + }) invariant(i >= 0 && i <= n && a >= 2*i) + + a + } ensuring(_ >= 2*n) +} diff --git a/src/test/resources/regression/verification/xlang/valid/NestedFunState6.scala b/src/test/resources/regression/verification/xlang/valid/NestedFunState6.scala new file mode 100644 index 0000000000000000000000000000000000000000..cea0f2e1900a00bc803df4107b46dd62cf082c68 --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/NestedFunState6.scala @@ -0,0 +1,20 @@ +object NestedFunState6 { + + def simpleSideEffect(n: BigInt): BigInt = { + require(n > 0) + + var a = BigInt(0) + + def incA(prevA: BigInt): Unit = { + require(prevA == a) + a += 1 + } ensuring(_ => a == prevA + 1) + + incA(a) + incA(a) + incA(a) + incA(a) + a + } ensuring(_ == 4) + +} diff --git a/src/test/resources/regression/verification/xlang/valid/NestedFunState7.scala b/src/test/resources/regression/verification/xlang/valid/NestedFunState7.scala new file mode 100644 index 0000000000000000000000000000000000000000..ff2418a0c972add2d48c08f6319905f35d0493c2 --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/NestedFunState7.scala @@ -0,0 +1,27 @@ +import leon.lang._ + +object NestedFunState7 { + + def test(x: BigInt): BigInt = { + + var a = BigInt(0) + + def defCase(): Unit = { + a = 100 + } + + if(x >= 0) { + a = 2*x + if(a < 100) { + a = 100 - a + } else { + defCase() + } + } else { + defCase() + } + + a + } ensuring(res => res >= 0 && res <= 100) + +} diff --git a/src/test/resources/regression/verification/xlang/valid/NestedOld1.scala b/src/test/resources/regression/verification/xlang/valid/NestedOld1.scala new file mode 100644 index 0000000000000000000000000000000000000000..903dfd63aca5105c9366c459552865646139033e --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/NestedOld1.scala @@ -0,0 +1,16 @@ +import leon.lang._ + +object NestedOld1 { + + def test(): Int = { + var counter = 0 + + def inc(): Unit = { + counter += 1 + } ensuring(_ => counter == old(counter) + 1) + + inc() + counter + } ensuring(_ == 1) + +} diff --git a/src/test/resources/regression/verification/xlang/valid/NestedOld2.scala b/src/test/resources/regression/verification/xlang/valid/NestedOld2.scala new file mode 100644 index 0000000000000000000000000000000000000000..fe6143ecf623c295251ca2b9d75d6474d422fd4f --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/NestedOld2.scala @@ -0,0 +1,24 @@ +import leon.lang._ + +object NestedOld2 { + + def test(): Int = { + var counterPrev = 0 + var counterNext = 1 + + def step(): Unit = { + require(counterNext == counterPrev + 1) + counterPrev = counterNext + counterNext = counterNext+1 + } ensuring(_ => { + counterPrev == old(counterNext) && + counterNext == old(counterNext) + 1 && + counterPrev == old(counterPrev) + 1 + }) + + step() + step() + counterNext + } ensuring(_ == 3) + +} diff --git a/src/test/scala/leon/integration/evaluators/CodegenEvaluatorSuite.scala b/src/test/scala/leon/integration/evaluators/CodegenEvaluatorSuite.scala index 0072da45d8f6556100410f733685c5c8f7177b89..e5565e073dc629a546019e219297c841c9d84258 100644 --- a/src/test/scala/leon/integration/evaluators/CodegenEvaluatorSuite.scala +++ b/src/test/scala/leon/integration/evaluators/CodegenEvaluatorSuite.scala @@ -9,7 +9,7 @@ import leon.purescala.Expressions._ import leon.purescala.Types._ import leon.codegen._ -class CodegenEvaluatorSuite extends LeonTestSuiteWithProgram { +class CodegenEvaluatorSuite extends LeonTestSuiteWithProgram with helpers.ExpressionsDSL{ val sources = List(""" import leon.lang._ @@ -316,7 +316,7 @@ class CodegenEvaluatorSuite extends LeonTestSuiteWithProgram { "Overrides1" -> Tuple(Seq(BooleanLiteral(false), BooleanLiteral(true))), "Overrides2" -> Tuple(Seq(BooleanLiteral(false), BooleanLiteral(true))), "Arrays1" -> IntLiteral(2), - "Arrays2" -> IntLiteral(6) + "Arrays2" -> IntLiteral(10) ) for { @@ -324,29 +324,25 @@ class CodegenEvaluatorSuite extends LeonTestSuiteWithProgram { requireMonitor <- Seq(false, true) doInstrument <- Seq(false,true) } { - val opts = ((if(requireMonitor) Some("monitor") else None) ++ - (if(doInstrument) Some("instrument") else None)).mkString("+") + val opts = (if(requireMonitor) "monitor " else "") + + (if(doInstrument) "instrument" else "") val testName = f"$name%-20s $opts%-18s" - test("Evaluation of "+testName) { case (ctx, pgm) => - val eval = new CodeGenEvaluator(ctx, pgm, CodeGenParams( + test("Evaluation of "+testName) { implicit fix => + val eval = new CodeGenEvaluator(fix._1, fix._2, CodeGenParams( maxFunctionInvocations = if (requireMonitor) 1000 else -1, // Monitor calls and abort execution if more than X calls checkContracts = true, // Generate calls that checks pre/postconditions doInstrument = doInstrument // Instrument reads to case classes (mainly for vanuatoo) )) - val fun = pgm.lookup(name+".test").collect { - case fd: FunDef => fd - }.getOrElse { - fail("Failed to lookup '"+name+".test'") - } - - (eval.eval(FunctionInvocation(fun.typed(Seq()), Seq())).result, exp) match { + (eval.eval(fcall(name + ".test")()).result, exp) match { case (Some(res), exp) => + assert(res === exp) case (None, Error(_, _)) => - case (None, _) => - case (_, Error(_, _)) => + // OK + case _ => + fail("") } } } diff --git a/src/test/scala/leon/integration/evaluators/EvaluatorSuite.scala b/src/test/scala/leon/integration/evaluators/EvaluatorSuite.scala index 810d9d40edfd4aaf18c7c54aff0a579db600f2db..84b0087caee305b66e16d82ad2adcd9f4b10e06d 100644 --- a/src/test/scala/leon/integration/evaluators/EvaluatorSuite.scala +++ b/src/test/scala/leon/integration/evaluators/EvaluatorSuite.scala @@ -5,7 +5,7 @@ package leon.integration.evaluators import leon._ import leon.test._ import leon.test.helpers._ -import leon.evaluators._ +import leon.evaluators.{Evaluator => _, DeterministicEvaluator => Evaluator, _} import leon.purescala.Common._ import leon.purescala.Definitions._ import leon.purescala.Expressions._ @@ -187,6 +187,38 @@ class EvaluatorSuite extends LeonTestSuiteWithProgram with ExpressionsDSL { | def f3 = C(42).isInstanceOf[A] |}""".stripMargin, + """import leon.lang._ + |import leon.collection._ + | + |object Foo { + | def unapply(i: BigInt): Option[BigInt] = if (i > 0) Some(i) else None() + |} + | + |object Unapply { + | def foo = + | (BigInt(1) match { + | case Foo(i) => i + | case _ => BigInt(0) + | }) + (BigInt(-12) match { + | case Foo(i) => i + | case _ => BigInt(2) + | }) + | + | def size[A](l: List[A]): BigInt = l match { + | case _ :: _ :: _ :: Nil() => 3 + | case _ :: _ :: Nil() => 2 + | case _ :: Nil() => 1 + | case Nil() => 0 + | case _ :: _ => 42 + | } + | + | def s1 = size(1 :: 2 :: 3 :: Nil[Int]()) + | def s2 = size(Nil[Int]()) + | def s3 = size(List(1,2,3,4,5,6,7,8)) + | + |} + """.stripMargin, + """object Casts1 { | abstract class Foo | case class Bar1(v: BigInt) extends Foo @@ -205,7 +237,8 @@ class EvaluatorSuite extends LeonTestSuiteWithProgram with ExpressionsDSL { def normalEvaluators(implicit ctx: LeonContext, pgm: Program): List[Evaluator] = { List( - new DefaultEvaluator(ctx, pgm) + new DefaultEvaluator(ctx, pgm), + new AngelicEvaluator(new StreamEvaluator(ctx, pgm)) ) } @@ -370,6 +403,15 @@ class EvaluatorSuite extends LeonTestSuiteWithProgram with ExpressionsDSL { } } + test("Unapply") { implicit fix => + for(e <- allEvaluators) { + eval(e, fcall("Unapply.foo")()) === bi(3) + eval(e, fcall("Unapply.s1")()) === bi(3) + eval(e, fcall("Unapply.s2")()) === bi(0) + eval(e, fcall("Unapply.s3")()) === bi(42) + } + } + test("Casts1") { implicit fix => def bar1(es: Expr*) = cc("Casts1.Bar1")(es: _*) def bar2(es: Expr*) = cc("Casts1.Bar2")(es: _*) diff --git a/src/test/scala/leon/integration/purescala/CallGraphSuite.scala b/src/test/scala/leon/integration/purescala/CallGraphSuite.scala index 7c47baa4b1abf2ec0c6493503fd7c944983c8c50..636cd4f406d37c2786f3e3493e48db9a7a98bec7 100644 --- a/src/test/scala/leon/integration/purescala/CallGraphSuite.scala +++ b/src/test/scala/leon/integration/purescala/CallGraphSuite.scala @@ -4,9 +4,7 @@ package leon.integration.purescala import leon.test._ -import leon._ -import leon.purescala.Definitions._ -import leon.utils._ +import leon.purescala.Definitions.Program class CallGraphSuite extends LeonTestSuiteWithProgram with helpers.ExpressionsDSL { diff --git a/src/test/scala/leon/integration/purescala/DataGenSuite.scala b/src/test/scala/leon/integration/purescala/DataGenSuite.scala index 3a2eaf24fbdc5b6cbbf66625970a068b513dd595..b1acfd74c26a745c5016bd219716783a6bea0ccd 100644 --- a/src/test/scala/leon/integration/purescala/DataGenSuite.scala +++ b/src/test/scala/leon/integration/purescala/DataGenSuite.scala @@ -3,12 +3,9 @@ package leon.integration.purescala import leon.test._ -import leon.utils.{TemporaryInputPhase, PreprocessingPhase} -import leon.frontends.scalac.ExtractionPhase import leon.purescala.Common._ import leon.purescala.Expressions._ -import leon.purescala.Definitions._ import leon.purescala.Types._ import leon.datagen._ diff --git a/src/test/scala/leon/integration/purescala/DefOpsSuite.scala b/src/test/scala/leon/integration/purescala/DefOpsSuite.scala index 3df239fb142498445443586e49e1fc3fc57e177b..c7aad05c63a2365e711a5f093853bd5524a87473 100644 --- a/src/test/scala/leon/integration/purescala/DefOpsSuite.scala +++ b/src/test/scala/leon/integration/purescala/DefOpsSuite.scala @@ -4,10 +4,8 @@ package leon.integration.purescala import leon.test._ -import leon._ import leon.purescala.Definitions._ import leon.purescala.DefOps._ -import leon.utils._ class DefOpsSuite extends LeonTestSuiteWithProgram { diff --git a/src/test/scala/leon/integration/solvers/EnumerationSolverSuite.scala b/src/test/scala/leon/integration/solvers/EnumerationSolverSuite.scala index 59324582603c4460234f3a493766791a34ba27d9..89031230bf40c9bcf03a27f0f41bc5be06c125a7 100644 --- a/src/test/scala/leon/integration/solvers/EnumerationSolverSuite.scala +++ b/src/test/scala/leon/integration/solvers/EnumerationSolverSuite.scala @@ -2,7 +2,6 @@ package leon.integration.solvers -import leon.test._ import leon.solvers._ import leon.purescala.Common._ import leon.purescala.Definitions._ diff --git a/src/test/scala/leon/integration/solvers/FairZ3SolverTests.scala b/src/test/scala/leon/integration/solvers/FairZ3SolverTests.scala index f89ecfd13746390086b6566573e8a45c8391ebd6..b08e7f9164173d12e681b6dce086a75a9e225990 100644 --- a/src/test/scala/leon/integration/solvers/FairZ3SolverTests.scala +++ b/src/test/scala/leon/integration/solvers/FairZ3SolverTests.scala @@ -2,14 +2,12 @@ package leon.integration.solvers -import leon.test._ import leon.purescala.Common._ import leon.purescala.Definitions._ import leon.purescala.Expressions._ import leon.purescala.Types._ import leon.LeonContext -import leon.solvers._ import leon.solvers.z3._ class FairZ3SolverTests extends LeonSolverSuite { diff --git a/src/test/scala/leon/integration/solvers/GlobalVariablesSuite.scala b/src/test/scala/leon/integration/solvers/GlobalVariablesSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..64cb3b70b062ba30ccf6223af587ed85c9b4ff88 --- /dev/null +++ b/src/test/scala/leon/integration/solvers/GlobalVariablesSuite.scala @@ -0,0 +1,74 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon.integration.solvers + +import leon.test._ +import leon.test.helpers._ +import leon.purescala.Common._ +import leon.purescala.Definitions._ +import leon.purescala.ExprOps._ +import leon.purescala.Constructors._ +import leon.purescala.Expressions._ +import leon.purescala.Types._ +import leon.LeonContext + +import leon.solvers._ +import leon.solvers.smtlib._ +import leon.solvers.combinators._ +import leon.solvers.z3._ + +class GlobalVariablesSuite extends LeonTestSuiteWithProgram with ExpressionsDSL { + + val sources = List( + """|import leon.lang._ + |import leon.annotation._ + | + |object GlobalVariables { + | + | def test(i: BigInt): BigInt = { + | 0 // will be replaced + | } + |} """.stripMargin + ) + + val getFactories: Seq[(String, (LeonContext, Program) => Solver)] = { + (if (SolverFactory.hasNativeZ3) Seq( + ("fairz3", (ctx: LeonContext, pgm: Program) => new FairZ3Solver(ctx, pgm)) + ) else Nil) ++ + (if (SolverFactory.hasZ3) Seq( + ("smt-z3", (ctx: LeonContext, pgm: Program) => new UnrollingSolver(ctx, pgm, new SMTLIBZ3Solver(ctx, pgm))) + ) else Nil) ++ + (if (SolverFactory.hasCVC4) Seq( + ("smt-cvc4", (ctx: LeonContext, pgm: Program) => new UnrollingSolver(ctx, pgm, new SMTLIBCVC4Solver(ctx, pgm))) + ) else Nil) + } + + // Check that we correctly extract several types from solver models + for ((sname, sf) <- getFactories) { + test(s"Global Variables in $sname") { implicit fix => + val ctx = fix._1 + val pgm = fix._2 + + pgm.lookup("GlobalVariables.test") match { + case Some(fd: FunDef) => + val b0 = FreshIdentifier("B", BooleanType); + fd.body = Some(IfExpr(b0.toVariable, bi(1), bi(-1))) + + val cnstr = LessThan(FunctionInvocation(fd.typed, Seq(bi(42))), bi(0)) + val solver = sf(ctx, pgm) + solver.assertCnstr(And(b0.toVariable, cnstr)) + + try { + if (solver.check != Some(false)) { + fail("Global variables not correctly handled.") + } + } finally { + solver.free() + } + case _ => + fail("Function with global body not found") + } + + } + } +} diff --git a/src/test/scala/leon/integration/solvers/ModelEnumerationSuite.scala b/src/test/scala/leon/integration/solvers/ModelEnumerationSuite.scala index 38553d171ddcdb1c0738aacf97e717e05f3e8256..3b405720a004c288f33a25456b6cd961d7ed7e53 100644 --- a/src/test/scala/leon/integration/solvers/ModelEnumerationSuite.scala +++ b/src/test/scala/leon/integration/solvers/ModelEnumerationSuite.scala @@ -12,7 +12,7 @@ import leon.purescala.Common._ import leon.evaluators._ import leon.purescala.Expressions._ -class ModelEnumeratorSuite extends LeonTestSuiteWithProgram with ExpressionsDSL { +class ModelEnumerationSuite extends LeonTestSuiteWithProgram with ExpressionsDSL { val sources = List( """|import leon.lang._ |import leon.annotation._ diff --git a/src/test/scala/leon/integration/solvers/SolversSuite.scala b/src/test/scala/leon/integration/solvers/SolversSuite.scala index 40fd073574cc9cb4f58d97ef22ca3b70ad87ce69..7ba3913305d25006906faa9791fa719ffccd8121 100644 --- a/src/test/scala/leon/integration/solvers/SolversSuite.scala +++ b/src/test/scala/leon/integration/solvers/SolversSuite.scala @@ -32,61 +32,72 @@ class SolversSuite extends LeonTestSuiteWithProgram { ) else Nil) } - // Check that we correctly extract several types from solver models - for ((sname, sf) <- getFactories) { - test(s"Model Extraction in $sname") { implicit fix => - val ctx = fix._1 - val pgm = fix._2 + val types = Seq( + BooleanType, + UnitType, + CharType, + RealType, + IntegerType, + Int32Type, + TypeParameter.fresh("T"), + SetType(IntegerType), + MapType(IntegerType, IntegerType), + FunctionType(Seq(IntegerType), IntegerType), + TupleType(Seq(IntegerType, BooleanType, Int32Type)) + ) - val solver = sf(ctx, pgm) + val vs = types.map(FreshIdentifier("v", _).toVariable) - val types = Seq( - BooleanType, - UnitType, - CharType, - IntegerType, - Int32Type, - TypeParameter.fresh("T"), - SetType(IntegerType), - MapType(IntegerType, IntegerType), - TupleType(Seq(IntegerType, BooleanType, Int32Type)) - ) + // We need to make sure models are not co-finite + val cnstrs = vs.map(v => v.getType match { + case UnitType => + Equals(v, simplestValue(v.getType)) + case SetType(base) => + Not(ElementOfSet(simplestValue(base), v)) + case MapType(from, to) => + Not(Equals(MapApply(v, simplestValue(from)), simplestValue(to))) + case FunctionType(froms, to) => + Not(Equals(Application(v, froms.map(simplestValue)), simplestValue(to))) + case _ => + not(Equals(v, simplestValue(v.getType))) + }) - val vs = types.map(FreshIdentifier("v", _).toVariable) + def checkSolver(solver: Solver, vs: Set[Variable], cnstr: Expr)(implicit fix: (LeonContext, Program)): Unit = { + try { + solver.assertCnstr(cnstr) - - // We need to make sure models are not co-finite - val cnstr = andJoin(vs.map(v => v.getType match { - case UnitType => - Equals(v, simplestValue(v.getType)) - case SetType(base) => - Not(ElementOfSet(simplestValue(base), v)) - case MapType(from, to) => - Not(Equals(MapApply(v, simplestValue(from)), simplestValue(to))) - case _ => - not(Equals(v, simplestValue(v.getType))) - })) - - try { - solver.assertCnstr(cnstr) - - solver.check match { - case Some(true) => - val model = solver.getModel - for (v <- vs) { - if (model.isDefinedAt(v.id)) { - assert(model(v.id).getType === v.getType, "Extracting value of type "+v.getType) - } else { - fail("Model does not contain "+v.id+" of type "+v.getType) - } + solver.check match { + case Some(true) => + val model = solver.getModel + for (v <- vs) { + if (model.isDefinedAt(v.id)) { + assert(model(v.id).getType === v.getType, "Extracting value of type "+v.getType) + } else { + fail("Model does not contain "+v.id+" of type "+v.getType) } - case _ => - fail("Constraint "+cnstr.asString+" is unsat!?") - } - } finally { - solver.free + } + case _ => + fail("Constraint "+cnstr.asString+" is unsat!?") } + } finally { + solver.free + } + } + + // Check that we correctly extract several types from solver models + for ((sname, sf) <- getFactories) { + test(s"Model Extraction in $sname") { implicit fix => + val ctx = fix._1 + val pgm = fix._2 + val solver = sf(ctx, pgm) + checkSolver(solver, vs.toSet, andJoin(cnstrs)) + } + } + test(s"Data generation in enum solver") { implicit fix => + for ((v,cnstr) <- vs zip cnstrs) { + val solver = new EnumerationSolver(fix._1, fix._2) + checkSolver(solver, Set(v), cnstr) } } } diff --git a/src/test/scala/leon/integration/solvers/TimeoutSolverSuite.scala b/src/test/scala/leon/integration/solvers/TimeoutSolverSuite.scala index cce647cc1eed81df80b6f64d1ffcff2acf9c84aa..4ee34098827272c7039b5865fb016b465ca9ccd4 100644 --- a/src/test/scala/leon/integration/solvers/TimeoutSolverSuite.scala +++ b/src/test/scala/leon/integration/solvers/TimeoutSolverSuite.scala @@ -4,7 +4,6 @@ package leon.integration.solvers import leon._ import leon.test._ -import leon.utils.Interruptible import leon.solvers._ import leon.purescala.Common._ import leon.purescala.Definitions._ diff --git a/src/test/scala/leon/integration/solvers/UnrollingSolverSuite.scala b/src/test/scala/leon/integration/solvers/UnrollingSolverSuite.scala index d9e4d8bb969138727a9fcbde103d709813f88c5e..b3c0f82f2d98362c873c555992fca0367fe13ba1 100644 --- a/src/test/scala/leon/integration/solvers/UnrollingSolverSuite.scala +++ b/src/test/scala/leon/integration/solvers/UnrollingSolverSuite.scala @@ -2,13 +2,11 @@ package leon.integration.solvers -import leon.test._ import leon.LeonContext import leon.purescala.Expressions._ import leon.purescala.Types._ import leon.purescala.Common._ import leon.purescala.Definitions._ -import leon.solvers._ import leon.solvers.z3._ import leon.solvers.combinators._ diff --git a/src/test/scala/leon/regression/repair/RepairSuite.scala b/src/test/scala/leon/regression/repair/RepairSuite.scala index 50b845f923e2356d1ce4c0c9a27692a78cd30ce5..6301f4ea1a61144f4b38cd2d7660579c192f5089 100644 --- a/src/test/scala/leon/regression/repair/RepairSuite.scala +++ b/src/test/scala/leon/regression/repair/RepairSuite.scala @@ -16,6 +16,7 @@ class RepairSuite extends LeonRegressionSuite { val fileToFun = Map( "Compiler1.scala" -> "desugar", + "Heap3.scala" -> "merge", "Heap4.scala" -> "merge", "ListEasy.scala" -> "pad", "List1.scala" -> "pad", @@ -23,7 +24,10 @@ class RepairSuite extends LeonRegressionSuite { "MergeSort2.scala" -> "merge" ) - for (file <- filesInResourceDir("regression/repair/", _.endsWith(".scala")) if fileToFun contains file.getName) { + for (file <- filesInResourceDir("regression/repair/", _.endsWith(".scala"))) { + if (!(fileToFun contains file.getName)) { + fail(s"Don't know which function to repair for ${file.getName}") + } val path = file.getAbsoluteFile.toString val name = file.getName @@ -43,7 +47,7 @@ class RepairSuite extends LeonRegressionSuite { test(name) { pipeline.run(ctx, List(path)) if(reporter.errorCount > 0) { - fail("Errors during repair:\n")//+reporter.lastErrors.mkString("\n")) + fail("Errors during repair:\n"+reporter.lastErrors.mkString("\n")) } } } diff --git a/src/test/scala/leon/regression/synthesis/StablePrintingSuite.scala b/src/test/scala/leon/regression/synthesis/StablePrintingSuite.scala index 8ed9ed14e43e81cf7bcdfe10143b8627bf40e182..c8bfc2688ec53e75e886cffc26621778675c454d 100644 --- a/src/test/scala/leon/regression/synthesis/StablePrintingSuite.scala +++ b/src/test/scala/leon/regression/synthesis/StablePrintingSuite.scala @@ -50,7 +50,7 @@ class StablePrintingSuite extends LeonRegressionSuite { InnerCaseSplit ) - def getChooses(ctx: LeonContext, content: String): (Program, SynthesisSettings, Seq[ChooseInfo]) = { + def getChooses(ctx: LeonContext, content: String): (Program, SynthesisSettings, Seq[SourceInfo]) = { val opts = SynthesisSettings() val pipeline = leon.utils.TemporaryInputPhase andThen frontends.scalac.ExtractionPhase andThen @@ -58,7 +58,7 @@ class StablePrintingSuite extends LeonRegressionSuite { val (ctx2, program) = pipeline.run(ctx, (List(content), Nil)) - (program, opts, ChooseInfo.extractFromProgram(ctx2, program)) + (program, opts, SourceInfo.extractFromProgram(ctx2, program)) } case class Job(content: String, choosesToProcess: Set[Int], rules: List[String]) { @@ -93,8 +93,8 @@ class StablePrintingSuite extends LeonRegressionSuite { for (e <- reporter.lastErrors) { info(e) } - println(e) info(e.getMessage) + e.printStackTrace() fail("Compilation failed") } @@ -117,7 +117,7 @@ class StablePrintingSuite extends LeonRegressionSuite { case Some(sol) => val result = sol.toSimplifiedExpr(ctx, pgm) - val newContent = new FileInterface(ctx.reporter).substitute(j.content, ci.ch, (indent: Int) => { + val newContent = new FileInterface(ctx.reporter).substitute(j.content, ci.source, (indent: Int) => { val p = new ScalaPrinter(PrinterOptions(), Some(pgm)) p.pp(result)(PrinterContext(result, List(ci.fd), indent, p)) p.toString diff --git a/src/test/scala/leon/regression/synthesis/SynthesisRegressionSuite.scala b/src/test/scala/leon/regression/synthesis/SynthesisRegressionSuite.scala index b5e4bfe993f0a92b10f0dbe42582eee162625281..9dbb0e36fc65223a06b132dc393ab62f7748948a 100644 --- a/src/test/scala/leon/regression/synthesis/SynthesisRegressionSuite.scala +++ b/src/test/scala/leon/regression/synthesis/SynthesisRegressionSuite.scala @@ -21,7 +21,7 @@ class SynthesisRegressionSuite extends LeonRegressionSuite { private def testSynthesis(cat: String, f: File, bound: Int) { - var chooses = List[ChooseInfo]() + var chooses = List[SourceInfo]() var program: Program = null var ctx: LeonContext = null var opts: SynthesisSettings = null @@ -37,7 +37,7 @@ class SynthesisRegressionSuite extends LeonRegressionSuite { program = pgm2 - chooses = ChooseInfo.extractFromProgram(ctx2, program) + chooses = SourceInfo.extractFromProgram(ctx2, program) } for (ci <- chooses) { diff --git a/src/test/scala/leon/regression/synthesis/SynthesisSuite.scala b/src/test/scala/leon/regression/synthesis/SynthesisSuite.scala index f95d7b964df84fb1d44b8a560ea4a01bdbd51b6e..c70df950e0768ca71dd7bba013ac04cd7404edab 100644 --- a/src/test/scala/leon/regression/synthesis/SynthesisSuite.scala +++ b/src/test/scala/leon/regression/synthesis/SynthesisSuite.scala @@ -24,7 +24,7 @@ class SynthesisSuite extends LeonRegressionSuite { class TestSearch(ctx: LeonContext, - ci: ChooseInfo, + ci: SourceInfo, p: Problem, strat: SynStrat) extends SimpleSearch(ctx, ci, p, CostModels.default, None) { @@ -154,9 +154,9 @@ object Injection { case class Nil() extends List // proved with unrolling=0 - def size(l: List) : Int = (l match { - case Nil() => 0 - case Cons(t) => 1 + size(t) + def size(l: List) : BigInt = (l match { + case Nil() => BigInt(0) + case Cons(t) => BigInt(1) + size(t) }) ensuring(res => res >= 0) def simple(in: List) = choose{out: List => size(out) == size(in) } @@ -274,10 +274,10 @@ object ChurchNumerals { case object Z extends Num case class S(pred: Num) extends Num - def value(n:Num) : Int = { + def value(n:Num) : BigInt = { n match { - case Z => 0 - case S(p) => 1 + value(p) + case Z => BigInt(0) + case S(p) => BigInt(1) + value(p) } } ensuring (_ >= 0) @@ -309,10 +309,10 @@ object ChurchNumerals { case object Z extends Num case class S(pred: Num) extends Num - def value(n:Num) : Int = { + def value(n:Num) : BigInt = { n match { - case Z => 0 - case S(p) => 1 + value(p) + case Z => BigInt(0) + case S(p) => BigInt(1) + value(p) } } ensuring (_ >= 0) diff --git a/src/test/scala/leon/regression/verification/purescala/PureScalaVerificationSuite.scala b/src/test/scala/leon/regression/verification/purescala/PureScalaVerificationSuite.scala index d308e840f159a8ab7de49380d920dfc12ee8d2ad..569eb95df2f7f66a05d7a825765adb6123b4d854 100644 --- a/src/test/scala/leon/regression/verification/purescala/PureScalaVerificationSuite.scala +++ b/src/test/scala/leon/regression/verification/purescala/PureScalaVerificationSuite.scala @@ -62,7 +62,7 @@ class PureScalaValidSuiteZ3 extends PureScalaValidSuite { val optionVariants = if (isZ3Available) List(opts(3)) else Nil } class PureScalaValidSuiteCVC4 extends PureScalaValidSuite { - val optionVariants = if (isCVC4Available) List(opts(4)) else Nil + val optionVariants = if (isCVC4Available) opts.takeRight(1) else Nil } class PureScalaInvalidSuite extends PureScalaVerificationSuite { diff --git a/src/test/scala/leon/test/LeonTestSuite.scala b/src/test/scala/leon/test/LeonTestSuite.scala index 4c459cf6b2d1608db9371d21086c3bd462528b34..18bd8fc1465988ed7c3894858cc266f975b2395c 100644 --- a/src/test/scala/leon/test/LeonTestSuite.scala +++ b/src/test/scala/leon/test/LeonTestSuite.scala @@ -3,10 +3,6 @@ package leon.test import leon._ -import leon.purescala.Definitions.Program -import leon.LeonContext -import leon.utils._ -import leon.frontends.scalac.ExtractionPhase import org.scalatest._ import org.scalatest.exceptions.TestFailedException diff --git a/src/test/scala/leon/unit/evaluators/EvaluatorSuite.scala b/src/test/scala/leon/unit/evaluators/EvaluatorSuite.scala index 4d97487f6d85a50e39efd8277d64c99958a63df4..b510c01bb891532c60cace6017b2a61d09bfdbb3 100644 --- a/src/test/scala/leon/unit/evaluators/EvaluatorSuite.scala +++ b/src/test/scala/leon/unit/evaluators/EvaluatorSuite.scala @@ -1,40 +1,36 @@ /* Copyright 2009-2015 EPFL, Lausanne */ -package leon.unit.allEvaluators +package leon.unit.evaluators import leon._ import leon.test._ import leon.evaluators._ -import leon.utils.{TemporaryInputPhase, PreprocessingPhase} -import leon.frontends.scalac.ExtractionPhase - import leon.purescala.Common._ import leon.purescala.Definitions._ import leon.purescala.Expressions._ -import leon.purescala.DefOps._ import leon.purescala.Types._ import leon.purescala.Extractors._ import leon.purescala.Constructors._ -import leon.codegen._ class EvaluatorSuite extends LeonTestSuite with helpers.ExpressionsDSL { implicit val pgm = Program.empty - def normalEvaluators(implicit ctx: LeonContext, pgm: Program): List[Evaluator] = { + def normalEvaluators(implicit ctx: LeonContext, pgm: Program): List[DeterministicEvaluator] = { List( - new DefaultEvaluator(ctx, pgm) + new DefaultEvaluator(ctx, pgm), + new AngelicEvaluator(new StreamEvaluator(ctx, pgm)) ) } - def codegenEvaluators(implicit ctx: LeonContext, pgm: Program): List[Evaluator] = { + def codegenEvaluators(implicit ctx: LeonContext, pgm: Program): List[DeterministicEvaluator] = { List( new CodeGenEvaluator(ctx, pgm) ) } - def allEvaluators(implicit ctx: LeonContext, pgm: Program): List[Evaluator] = { + def allEvaluators(implicit ctx: LeonContext, pgm: Program): List[DeterministicEvaluator] = { normalEvaluators ++ codegenEvaluators } @@ -275,7 +271,7 @@ class EvaluatorSuite extends LeonTestSuite with helpers.ExpressionsDSL { def success: Expr = res } - case class Success(expr: Expr, env: Map[Identifier, Expr], evaluator: Evaluator, res: Expr) extends EvalDSL { + case class Success(expr: Expr, env: Map[Identifier, Expr], evaluator: DeterministicEvaluator, res: Expr) extends EvalDSL { override def failed = { fail(s"Evaluation of '$expr' with '$evaluator' (and env $env) should have failed") } @@ -285,7 +281,7 @@ class EvaluatorSuite extends LeonTestSuite with helpers.ExpressionsDSL { } } - case class Failed(expr: Expr, env: Map[Identifier, Expr], evaluator: Evaluator, err: String) extends EvalDSL { + case class Failed(expr: Expr, env: Map[Identifier, Expr], evaluator: DeterministicEvaluator, err: String) extends EvalDSL { override def success = { fail(s"Evaluation of '$expr' with '$evaluator' (and env $env) should have succeeded but failed with $err") } @@ -295,7 +291,7 @@ class EvaluatorSuite extends LeonTestSuite with helpers.ExpressionsDSL { def ===(res: Expr) = success } - def eval(e: Evaluator, toEval: Expr, env: Map[Identifier, Expr] = Map()): EvalDSL = { + def eval(e: DeterministicEvaluator, toEval: Expr, env: Map[Identifier, Expr] = Map()): EvalDSL = { e.eval(toEval, env) match { case EvaluationResults.Successful(res) => Success(toEval, env, e, res) case EvaluationResults.RuntimeError(err) => Failed(toEval, env, e, err) diff --git a/testcases/synthesis/etienne-thesis/List/Delete.scala b/testcases/synthesis/etienne-thesis/List/Delete.scala new file mode 100644 index 0000000000000000000000000000000000000000..46f0710878d25a30e679f034ff07f985078950d1 --- /dev/null +++ b/testcases/synthesis/etienne-thesis/List/Delete.scala @@ -0,0 +1,41 @@ +import leon.annotation._ +import leon.lang._ +import leon.lang.synthesis._ + +object Delete { + sealed abstract class List + case class Cons(head: BigInt, tail: List) extends List + case object Nil extends List + + def size(l: List) : BigInt = (l match { + case Nil => BigInt(0) + case Cons(_, t) => BigInt(1) + size(t) + }) ensuring(res => res >= 0) + + def content(l: List): Set[BigInt] = l match { + case Nil => Set.empty[BigInt] + case Cons(i, t) => Set(i) ++ content(t) + } + + def insert(in1: List, v: BigInt): List = { + Cons(v, in1) + } ensuring { content(_) == content(in1) ++ Set(v) } + + //def delete(in1: List, v: BigInt): List = { + // in1 match { + // case Cons(h,t) => + // if (h == v) { + // delete(t, v) + // } else { + // Cons(h, delete(t, v)) + // } + // case Nil => + // Nil + // } + //} ensuring { content(_) == content(in1) -- Set(v) } + + def delete(in1: List, v: BigInt) = choose { + (out : List) => + content(out) == content(in1) -- Set(v) + } +} diff --git a/testcases/synthesis/etienne-thesis/List/Diff.scala b/testcases/synthesis/etienne-thesis/List/Diff.scala new file mode 100644 index 0000000000000000000000000000000000000000..9fb3ade9558a7db175c6056efb9bf724f487d7ca --- /dev/null +++ b/testcases/synthesis/etienne-thesis/List/Diff.scala @@ -0,0 +1,50 @@ +import leon.annotation._ +import leon.lang._ +import leon.lang.synthesis._ + +object Diff { + sealed abstract class List + case class Cons(head: BigInt, tail: List) extends List + case object Nil extends List + + def size(l: List) : BigInt = (l match { + case Nil => BigInt(0) + case Cons(_, t) => BigInt(1) + size(t) + }) ensuring(res => res >= 0) + + def content(l: List): Set[BigInt] = l match { + case Nil => Set.empty[BigInt] + case Cons(i, t) => Set(i) ++ content(t) + } + + def insert(in1: List, v: BigInt): List = { + Cons(v, in1) + } ensuring { content(_) == content(in1) ++ Set(v) } + + def delete(in1: List, v: BigInt): List = { + in1 match { + case Cons(h,t) => + if (h == v) { + delete(t, v) + } else { + Cons(h, delete(t, v)) + } + case Nil => + Nil + } + } ensuring { content(_) == content(in1) -- Set(v) } + + // def diff(in1: List, in2: List): List = { + // in2 match { + // case Nil => + // in1 + // case Cons(h, t) => + // diff(delete(in1, h), t) + // } + // } ensuring { content(_) == content(in1) -- content(in2) } + + def diff(in1: List, in2: List) = choose { + (out : List) => + content(out) == content(in1) -- content(in2) + } +} diff --git a/testcases/synthesis/etienne-thesis/List/Insert.scala b/testcases/synthesis/etienne-thesis/List/Insert.scala new file mode 100644 index 0000000000000000000000000000000000000000..48c38f4df0098410814f3ed27038c9c36c0d6532 --- /dev/null +++ b/testcases/synthesis/etienne-thesis/List/Insert.scala @@ -0,0 +1,28 @@ +import leon.annotation._ +import leon.lang._ +import leon.lang.synthesis._ + +object Insert { + sealed abstract class List + case class Cons(head: BigInt, tail: List) extends List + case object Nil extends List + + def size(l: List) : BigInt = (l match { + case Nil => BigInt(0) + case Cons(_, t) => BigInt(1) + size(t) + }) ensuring(res => res >= 0) + + def content(l: List): Set[BigInt] = l match { + case Nil => Set.empty[BigInt] + case Cons(i, t) => Set(i) ++ content(t) + } + + //def insert(in1: List, v: BigInt): List = { + // Cons(v, in1) + //} ensuring { content(_) == content(in1) ++ Set(v) } + + def insert(in1: List, v: BigInt) = choose { + (out : List) => + content(out) == content(in1) ++ Set(v) + } +} diff --git a/testcases/synthesis/etienne-thesis/List/Split.scala b/testcases/synthesis/etienne-thesis/List/Split.scala new file mode 100644 index 0000000000000000000000000000000000000000..fb98204096f3e4a97b990527e588b655e8b93fac --- /dev/null +++ b/testcases/synthesis/etienne-thesis/List/Split.scala @@ -0,0 +1,41 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +import leon.annotation._ +import leon.lang._ +import leon.lang.synthesis._ + +object Complete { + sealed abstract class List + case class Cons(head: BigInt, tail: List) extends List + case object Nil extends List + + def size(l: List) : BigInt = (l match { + case Nil => BigInt(0) + case Cons(_, t) => BigInt(1) + size(t) + }) ensuring(res => res >= 0) + + def content(l: List): Set[BigInt] = l match { + case Nil => Set.empty[BigInt] + case Cons(i, t) => Set(i) ++ content(t) + } + + def splitSpec(list : List, res : (List,List)) : Boolean = { + val s1 = size(res._1) + val s2 = size(res._2) + abs(s1 - s2) <= 1 && s1 + s2 == size(list) && + content(res._1) ++ content(res._2) == content(list) + } + + def abs(i : BigInt) : BigInt = { + if(i < 0) -i else i + } ensuring(_ >= 0) + + def dispatch(es: (BigInt, BigInt), rest: (List, List)): (List, List) = { + (Cons(es._1, rest._1), Cons(es._2, rest._2)) + } + + def split(list : List) : (List,List) = { + choose { (res : (List,List)) => splitSpec(list, res) } + } + +} diff --git a/testcases/synthesis/etienne-thesis/List/Union.scala b/testcases/synthesis/etienne-thesis/List/Union.scala new file mode 100644 index 0000000000000000000000000000000000000000..d6a5fa579f5745f00e469f19ee853da15d4fece7 --- /dev/null +++ b/testcases/synthesis/etienne-thesis/List/Union.scala @@ -0,0 +1,37 @@ +import leon.annotation._ +import leon.lang._ +import leon.lang.synthesis._ + +object Union { + sealed abstract class List + case class Cons(head: BigInt, tail: List) extends List + case object Nil extends List + + def size(l: List) : BigInt = (l match { + case Nil => BigInt(0) + case Cons(_, t) => BigInt(1)+ size(t) + }) ensuring(res => res >= 0) + + def content(l: List): Set[BigInt] = l match { + case Nil => Set.empty[BigInt] + case Cons(i, t) => Set(i) ++ content(t) + } + + def insert(in1: List, v: BigInt): List = { + Cons(v, in1) + } ensuring { content(_) == content(in1) ++ Set(v) } + + // def union(in1: List, in2: List): List = { + // in1 match { + // case Cons(h, t) => + // Cons(h, union(t, in2)) + // case Nil => + // in2 + // } + // } ensuring { content(_) == content(in1) ++ content(in2) } + + def union(in1: List, in2: List) = choose { + (out : List) => + content(out) == content(in1) ++ content(in2) + } +} diff --git a/testcases/verification/compilation/ExprCompiler.scala b/testcases/verification/compilation/ExprCompiler.scala new file mode 100644 index 0000000000000000000000000000000000000000..d6b7d74caa52a4a49255e71781c86f1052c439db --- /dev/null +++ b/testcases/verification/compilation/ExprCompiler.scala @@ -0,0 +1,85 @@ +import leon.lang._ +import leon.lang.Option +import leon.collection._ +import leon.annotation._ +import leon.proof._ +import leon.lang.synthesis._ + +object TinyCertifiedCompiler { + abstract class ByteCode[A] + case class Load[A](c: A) extends ByteCode[A] // loads a constant in to the stack + case class OpInst[A]() extends ByteCode[A] + + abstract class ExprTree[A] + case class Const[A](c: A) extends ExprTree[A] + case class Op[A](e1: ExprTree[A], e2: ExprTree[A]) extends ExprTree[A] + + def compile[A](e: ExprTree[A]): List[ByteCode[A]] = { + e match { + case Const(c) => + Cons(Load(c), Nil()) + case Op(e1, e2) => + (compile(e1) ++ compile(e2)) ++ Cons(OpInst(), Nil()) + } + } + + def op[A](x: A, y: A): A = { // uninterpreted + ???[A] + } + + def run[A](bytecode: List[ByteCode[A]], S: List[A]): List[A] = { + (bytecode, S) match { + case (Cons(Load(c), tail), _) => + run(tail, c :: S) // adding elements to the head of the stack + case (Cons(OpInst(), tail), Cons(x, Cons(y, rest))) => + run(tail, op(y, x) :: rest) + case (Cons(_, tail), _) => + run(tail, S) + case (Nil(), _) => // no bytecode to execute + S + } + } + + def interpret[A](e: ExprTree[A]): A = { + e match { + case Const(c) => c + case Op(e1, e2) => op(interpret(e1), interpret(e2)) + } + } + + def runAppendLemma[A](l1: List[ByteCode[A]], l2: List[ByteCode[A]], S: List[A]): Boolean = { + // lemma + (run(l1 ++ l2, S) == run(l2, run(l1, S))) because + // induction scheme (induct over l1) + (l1 match { + case Nil() => true + case Cons(h, tail) => + (h, S) match { + case (Load(c), _) => + runAppendLemma(tail, l2, Cons[A](c, S)) + case (OpInst(), Cons(x, Cons(y, rest))) => + runAppendLemma(tail, l2, Cons[A](op(y, x), rest)) + case (_, _) => + runAppendLemma(tail, l2, S) + case _ => true + } + }) + }.holds + + def compileInterpretEquivalenceLemma[A](e: ExprTree[A], S: List[A]): Boolean = { + // lemma + (run(compile(e), S) == interpret(e) :: S) because + // induction scheme + (e match { + case Const(c) => true + case Op(e1, e2) => + // lemma instantiation + val c1 = compile(e1) + val c2 = compile(e2) + runAppendLemma((c1 ++ c2), Cons[ByteCode[A]](OpInst[A](), Nil()), S) && + runAppendLemma(c1, c2, S) && + compileInterpretEquivalenceLemma(e1, S) && + compileInterpretEquivalenceLemma(e2, Cons[A](interpret(e1), S)) + }) + }.holds +} diff --git a/testcases/verification/compilation/IntExprCompiler.scala b/testcases/verification/compilation/IntExprCompiler.scala new file mode 100644 index 0000000000000000000000000000000000000000..b7f46a49a27cf4092518afa08f0cef5ff8f724c7 --- /dev/null +++ b/testcases/verification/compilation/IntExprCompiler.scala @@ -0,0 +1,107 @@ +import leon.lang._ +import leon.lang.Option +import leon.collection._ +import leon.annotation._ +import leon.proof._ + +object TinyCertifiedCompiler { + + abstract class ByteCode + case class Load(c: BigInt) extends ByteCode // loads a constant in to the stack + case class OpInst() extends ByteCode + + abstract class ExprTree + case class Const(c: BigInt) extends ExprTree + case class Op(e1: ExprTree, e2: ExprTree) extends ExprTree + + def compile(e: ExprTree): List[ByteCode] = { + e match { + case Const(c) => + Cons(Load(c), Nil[ByteCode]()) + case Op(e1, e2) => + (compile(e1) ++ compile(e2)) ++ Cons(OpInst(), Nil[ByteCode]()) + } + } + + def op(x: BigInt, y: BigInt): BigInt = { + x - y + } + + def run(bytecode: List[ByteCode], S: List[BigInt]): List[BigInt] = { + (bytecode, S) match { + case (Cons(Load(c), tail), _) => + run(tail, Cons[BigInt](c, S)) // adding elements to the head of the stack + case (Cons(OpInst(), tail), Cons(x, Cons(y, rest))) => + run(tail, Cons[BigInt](op(y, x), rest)) + case (Cons(_, tail), _) => + run(tail, S) + case (Nil(), _) => // no bytecode to execute + S + } + } + + def interpret(e: ExprTree): BigInt = { + e match { + case Const(c) => c + case Op(e1, e2) => + op(interpret(e1), interpret(e2)) + } + } + + def runAppendLemma(l1: List[ByteCode], l2: List[ByteCode], S: List[BigInt]): Boolean = { + // lemma + (run(l1 ++ l2, S) == run(l2, run(l1, S))) because + // induction scheme (induct over l1) + (l1 match { + case Nil() => + true + case Cons(h, tail) => + (h, S) match { + case (Load(c), _) => + runAppendLemma(tail, l2, Cons[BigInt](c, S)) + case (OpInst(), Cons(x, Cons(y, rest))) => + runAppendLemma(tail, l2, Cons[BigInt](op(y, x), rest)) + case (_, _) => + runAppendLemma(tail, l2, S) + case _ => + true + } + }) + }.holds + + def run1(bytecode: List[ByteCode], S: List[BigInt]): List[BigInt] = { + (bytecode, S) match { + case (Cons(Load(c), tail), _) => + run1(tail, Cons[BigInt](c, S)) // adding elements to the head of the stack + case (Cons(OpInst(), tail), Cons(x, Cons(y, rest))) => + run1(tail, Cons[BigInt](op(x, y), rest)) + case (Cons(_, tail), _) => + run1(tail, S) + case (Nil(), _) => // no bytecode to execute + S + } + } + + def compileInterpretEquivalenceLemma1(e: ExprTree, S: List[BigInt]): Boolean = { + // lemma + (run1(compile(e), S) == interpret(e) :: S) + } holds + + def compileInterpretEquivalenceLemma(e: ExprTree, S: List[BigInt]): Boolean = { + // lemma + (run(compile(e), S) == interpret(e) :: S) because + // induction scheme + (e match { + case Const(c) => + true + case Op(e1, e2) => + // lemma instantiation + val c1 = compile(e1) + val c2 = compile(e2) + runAppendLemma((c1 ++ c2), Cons(OpInst(), Nil[ByteCode]()), S) && + runAppendLemma(c1, c2, S) && + compileInterpretEquivalenceLemma(e1, S) && + compileInterpretEquivalenceLemma(e2, Cons[BigInt](interpret(e1), S)) + }) + } holds +} diff --git a/testcases/verification/higher-order/valid/FlatMap.scala b/testcases/verification/higher-order/valid/FlatMap.scala index 56f8c47e86d835055e6c3a9ea44cb7b886e464fb..06d356c5fe5b1e7443831e7d541450c36449ef18 100644 --- a/testcases/verification/higher-order/valid/FlatMap.scala +++ b/testcases/verification/higher-order/valid/FlatMap.scala @@ -1,4 +1,5 @@ import leon.lang._ +import leon.proof._ import leon.collection._ object FlatMap { @@ -13,10 +14,10 @@ object FlatMap { } def associative_append_lemma_induct[T](l1: List[T], l2: List[T], l3: List[T]): Boolean = { - l1 match { - case Nil() => associative_append_lemma(l1, l2, l3) - case Cons(head, tail) => associative_append_lemma(l1, l2, l3) && associative_append_lemma_induct(tail, l2, l3) - } + associative_append_lemma(l1, l2, l3) because (l1 match { + case Nil() => true + case Cons(head, tail) => associative_append_lemma_induct(tail, l2, l3) + }) }.holds def flatMap[T,U](list: List[T], f: T => List[U]): List[U] = list match { @@ -29,20 +30,21 @@ object FlatMap { } def associative_lemma_induct[T,U,V](list: List[T], flist: List[U], glist: List[V], f: T => List[U], g: U => List[V]): Boolean = { - associative_lemma(list, f, g) && - append(glist, flatMap(append(flist, flatMap(list, f)), g)) == append(append(glist, flatMap(flist, g)), flatMap(list, (x: T) => flatMap(f(x), g))) && - (glist match { - case Cons(ghead, gtail) => - associative_lemma_induct(list, flist, gtail, f, g) - case Nil() => flist match { - case Cons(fhead, ftail) => - associative_lemma_induct(list, ftail, g(fhead), f, g) - case Nil() => list match { - case Cons(head, tail) => associative_lemma_induct(tail, f(head), Nil(), f, g) - case Nil() => true + associative_lemma(list, f, g) because { + append(glist, flatMap(append(flist, flatMap(list, f)), g)) == append(append(glist, flatMap(flist, g)), flatMap(list, (x: T) => flatMap(f(x), g))) because + (glist match { + case Cons(ghead, gtail) => + associative_lemma_induct(list, flist, gtail, f, g) + case Nil() => flist match { + case Cons(fhead, ftail) => + associative_lemma_induct(list, ftail, g(fhead), f, g) + case Nil() => list match { + case Cons(head, tail) => associative_lemma_induct(tail, f(head), Nil(), f, g) + case Nil() => true + } } - } - }) + }) + } }.holds } diff --git a/testcases/web/verification/10_FoldAssociative.scala b/testcases/web/verification/10_FoldAssociative.scala index ff0444f269dd6f9b2a2b7f2b54e13f54c6755ab2..6920231115fea93a1b7c7cefc27691a0ee2471bb 100644 --- a/testcases/web/verification/10_FoldAssociative.scala +++ b/testcases/web/verification/10_FoldAssociative.scala @@ -2,6 +2,7 @@ import leon._ import leon.lang._ +import leon.proof._ object FoldAssociative { @@ -56,7 +57,7 @@ object FoldAssociative { val f = (x: Int, s: Int) => x + s val l1 = take(list, x) val l2 = drop(list, x) - lemma_split(list, x) && (list match { + lemma_split(list, x) because (list match { case Cons(head, tail) if x > 0 => lemma_split_induct(tail, x - 1) case _ => true @@ -77,7 +78,7 @@ object FoldAssociative { val f = (x: Int, s: Int) => x + s val l1 = take(list, x) val l2 = drop(list, x) - lemma_reassociative(list, x) && (list match { + lemma_reassociative(list, x) because (list match { case Cons(head, tail) if x > 0 => lemma_reassociative_induct(tail, x - 1) case _ => true @@ -93,7 +94,7 @@ object FoldAssociative { def lemma_reassociative_presplit_induct(l1: List, l2: List): Boolean = { val f = (x: Int, s: Int) => x + s val list = append(l1, l2) - lemma_reassociative_presplit(l1, l2) && (l1 match { + lemma_reassociative_presplit(l1, l2) because (l1 match { case Cons(head, tail) => lemma_reassociative_presplit_induct(tail, l2) case Nil() => true