From 89f7731a5746982fdaf02378a62aefa0640cd801 Mon Sep 17 00:00:00 2001 From: Nicolas Voirol <voirol.nicolas@gmail.com> Date: Tue, 15 Sep 2015 15:22:39 +0200 Subject: [PATCH] Evaluating foralls and lots of fixes --- .../leon/codegen/runtime/FiniteLambda.java | 29 - .../java/leon/codegen/runtime/Lambda.java | 2 +- .../LeonCodeGenRuntimeHenkinMonitor.java | 33 ++ .../leon/codegen/runtime/PartialLambda.java | 41 ++ .../scala/leon/codegen/CodeGeneration.scala | 541 +++++++++++++----- .../scala/leon/codegen/CompilationUnit.scala | 60 +- .../leon/codegen/CompiledExpression.scala | 13 +- .../scala/leon/datagen/NaiveDataGen.scala | 3 +- .../scala/leon/datagen/VanuatooDataGen.scala | 4 +- .../leon/evaluators/CodeGenEvaluator.scala | 23 +- .../leon/evaluators/DefaultEvaluator.scala | 3 +- .../scala/leon/evaluators/DualEvaluator.scala | 7 +- .../scala/leon/evaluators/Evaluator.scala | 24 +- .../leon/evaluators/RecursiveEvaluator.scala | 88 ++- .../leon/evaluators/TracingEvaluator.scala | 5 +- .../scala/leon/purescala/CheckForalls.scala | 106 ---- .../scala/leon/purescala/Constructors.scala | 17 - .../scala/leon/purescala/Definitions.scala | 4 +- src/main/scala/leon/purescala/ExprOps.scala | 91 ++- .../scala/leon/purescala/Expressions.scala | 4 + .../scala/leon/purescala/Extractors.scala | 47 +- .../scala/leon/purescala/Quantification.scala | 192 +++++++ .../leon/repair/RepairTrackingEvaluator.scala | 3 +- src/main/scala/leon/repair/Repairman.scala | 6 +- .../leon/solvers/EnumerationSolver.scala | 14 +- .../scala/leon/solvers/EvaluatingSolver.scala | 19 + .../scala/leon/solvers/GroundSolver.scala | 2 +- .../leon/solvers/QuantificationSolver.scala | 30 + .../scala/leon/solvers/SimpleSolverAPI.scala | 12 +- src/main/scala/leon/solvers/Solver.scala | 67 ++- .../solvers/combinators/PortfolioSolver.scala | 18 +- .../solvers/combinators/RewritingSolver.scala | 8 +- .../solvers/combinators/UnrollingSolver.scala | 88 ++- .../smtlib/SMTLIBCVC4ProofSolver.scala | 2 +- .../solvers/smtlib/SMTLIBCVC4Solver.scala | 14 +- .../smtlib/SMTLIBQuantifiedSolver.scala | 2 +- .../leon/solvers/smtlib/SMTLIBSolver.scala | 24 +- .../leon/solvers/smtlib/SMTLIBZ3Solver.scala | 5 +- .../solvers/templates/LambdaManager.scala | 52 +- .../templates/QuantificationManager.scala | 71 +-- .../leon/solvers/templates/Templates.scala | 61 +- .../solvers/templates/UnrollingBank.scala | 128 ++--- .../leon/solvers/z3/AbstractZ3Solver.scala | 28 +- .../scala/leon/solvers/z3/FairZ3Solver.scala | 95 ++- .../solvers/z3/UninterpretedZ3Solver.scala | 2 +- .../solvers/z3/Z3ModelReconstruction.scala | 6 +- .../leon/synthesis/rules/BottomUpTegis.scala | 3 +- .../leon/termination/LoopProcessor.scala | 4 +- .../scala/leon/termination/Processor.scala | 2 +- .../scala/leon/utils/IncrementalMap.scala | 51 +- .../scala/leon/utils/IncrementalSeq.scala | 16 +- .../scala/leon/utils/IncrementalSet.scala | 24 +- .../scala/leon/utils/IncrementalState.scala | 2 + .../scala/leon/utils/PreprocessingPhase.scala | 1 + .../verification/VerificationCondition.scala | 2 +- .../purescala/invalid/HOInvocations2.scala | 2 +- .../purescala/invalid/PositiveMap2.scala | 2 +- .../purescala/valid/HOInvocations2.scala | 2 +- .../purescala/valid/PositiveMap2.scala | 2 +- .../synthesis/StablePrintingSuite.scala | 1 + .../verification/VerificationSuite.scala | 3 +- 61 files changed, 1415 insertions(+), 796 deletions(-) delete mode 100644 src/main/java/leon/codegen/runtime/FiniteLambda.java create mode 100644 src/main/java/leon/codegen/runtime/LeonCodeGenRuntimeHenkinMonitor.java create mode 100644 src/main/java/leon/codegen/runtime/PartialLambda.java delete mode 100644 src/main/scala/leon/purescala/CheckForalls.scala create mode 100644 src/main/scala/leon/purescala/Quantification.scala create mode 100644 src/main/scala/leon/solvers/EvaluatingSolver.scala create mode 100644 src/main/scala/leon/solvers/QuantificationSolver.scala diff --git a/src/main/java/leon/codegen/runtime/FiniteLambda.java b/src/main/java/leon/codegen/runtime/FiniteLambda.java deleted file mode 100644 index cefc75522..000000000 --- a/src/main/java/leon/codegen/runtime/FiniteLambda.java +++ /dev/null @@ -1,29 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - -package leon.codegen.runtime; - -import java.util.HashMap; - -public final class FiniteLambda extends Lambda { - private final HashMap<Tuple, Object> _underlying = new HashMap<Tuple, Object>(); - private final Object dflt; - - public FiniteLambda(Object dflt) { - super(); - this.dflt = dflt; - } - - public void add(Tuple key, Object value) { - _underlying.put(key, value); - } - - @Override - public Object apply(Object[] args) { - Tuple tuple = new Tuple(args); - if (_underlying.containsKey(tuple)) { - return _underlying.get(tuple); - } else { - return dflt; - } - } -} diff --git a/src/main/java/leon/codegen/runtime/Lambda.java b/src/main/java/leon/codegen/runtime/Lambda.java index b266f83c9..a6abbef37 100644 --- a/src/main/java/leon/codegen/runtime/Lambda.java +++ b/src/main/java/leon/codegen/runtime/Lambda.java @@ -3,5 +3,5 @@ package leon.codegen.runtime; public abstract class Lambda { - public abstract Object apply(Object[] args); + public abstract Object apply(Object[] args) throws LeonCodeGenRuntimeException; } diff --git a/src/main/java/leon/codegen/runtime/LeonCodeGenRuntimeHenkinMonitor.java b/src/main/java/leon/codegen/runtime/LeonCodeGenRuntimeHenkinMonitor.java new file mode 100644 index 000000000..62e7a7b3b --- /dev/null +++ b/src/main/java/leon/codegen/runtime/LeonCodeGenRuntimeHenkinMonitor.java @@ -0,0 +1,33 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon.codegen.runtime; + +import java.util.List; +import java.util.LinkedList; +import java.util.HashMap; + +public class LeonCodeGenRuntimeHenkinMonitor extends LeonCodeGenRuntimeMonitor { + private final HashMap<Integer, List<Tuple>> domains = new HashMap<Integer, List<Tuple>>(); + + public LeonCodeGenRuntimeHenkinMonitor(int maxInvocations) { + super(maxInvocations); + } + + public void add(int type, Tuple input) { + if (!domains.containsKey(type)) domains.put(type, new LinkedList<Tuple>()); + domains.get(type).add(input); + } + + public List<Tuple> domain(Object obj, int type) { + List<Tuple> domain = new LinkedList<Tuple>(); + if (obj instanceof PartialLambda) { + for (Tuple key : ((PartialLambda) obj).mapping.keySet()) { + domain.add(key); + } + } + + domain.addAll(domains.get(type)); + + return domain; + } +} diff --git a/src/main/java/leon/codegen/runtime/PartialLambda.java b/src/main/java/leon/codegen/runtime/PartialLambda.java new file mode 100644 index 000000000..826cc5ed9 --- /dev/null +++ b/src/main/java/leon/codegen/runtime/PartialLambda.java @@ -0,0 +1,41 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon.codegen.runtime; + +import java.util.HashMap; + +public final class PartialLambda extends Lambda { + final HashMap<Tuple, Object> mapping = new HashMap<Tuple, Object>(); + + public PartialLambda() { + super(); + } + + public void add(Tuple key, Object value) { + mapping.put(key, value); + } + + @Override + public Object apply(Object[] args) throws LeonCodeGenRuntimeException { + Tuple tuple = new Tuple(args); + if (mapping.containsKey(tuple)) { + return mapping.get(tuple); + } else { + throw new LeonCodeGenRuntimeException("Partial function apply on undefined arguments"); + } + } + + @Override + public boolean equals(Object that) { + if (that != null && (that instanceof PartialLambda)) { + return mapping.equals(((PartialLambda) that).mapping); + } else { + return false; + } + } + + @Override + public int hashCode() { + return 63 + 11 * mapping.hashCode(); + } +} diff --git a/src/main/scala/leon/codegen/CodeGeneration.scala b/src/main/scala/leon/codegen/CodeGeneration.scala index 985c5970e..e8b232715 100644 --- a/src/main/scala/leon/codegen/CodeGeneration.scala +++ b/src/main/scala/leon/codegen/CodeGeneration.scala @@ -6,10 +6,11 @@ package codegen import purescala.Common._ import purescala.Definitions._ import purescala.Expressions._ -import purescala.ExprOps.{simplestValue, matchToIfThenElse} +import purescala.ExprOps.{simplestValue, matchToIfThenElse, collect} import purescala.Types._ import purescala.Constructors._ import purescala.Extractors._ +import purescala.Quantification._ import cafebabe._ import cafebabe.AbstractByteCodes._ @@ -47,22 +48,24 @@ trait CodeGeneration { def withArgs(newArgs: Map[Identifier, Int]) = Locals(vars, args ++ newArgs, closures, isStatic) def withClosures(newClosures: Map[Identifier,(String,String,String)]) = Locals(vars, args, closures ++ newClosures, isStatic) - - /** The index of the monitor object in this function */ - def monitorIndex = if (isStatic) 0 else 1 } - + object NoLocals { /** Make a [[Locals]] object without any local variables */ def apply(isStatic : Boolean) = new Locals(Map(), Map(), Map(), isStatic) } + lazy val monitorID = FreshIdentifier("__$monitor") + private[codegen] val ObjectClass = "java/lang/Object" private[codegen] val BoxedIntClass = "java/lang/Integer" private[codegen] val BoxedBoolClass = "java/lang/Boolean" private[codegen] val BoxedCharClass = "java/lang/Character" private[codegen] val BoxedArrayClass = "leon/codegen/runtime/ArrayBox" + private[codegen] val JavaListClass = "java/util/List" + private[codegen] val JavaIteratorClass = "java/util/Iterator" + private[codegen] val TupleClass = "leon/codegen/runtime/Tuple" private[codegen] val SetClass = "leon/codegen/runtime/Set" private[codegen] val MapClass = "leon/codegen/runtime/Map" @@ -76,6 +79,7 @@ trait CodeGeneration { private[codegen] val ChooseEntryPointClass = "leon/codegen/runtime/ChooseEntryPoint" private[codegen] val GenericValuesClass = "leon/codegen/runtime/GenericValues" private[codegen] val MonitorClass = "leon/codegen/runtime/LeonCodeGenRuntimeMonitor" + private[codegen] val HenkinClass = "leon/codegen/runtime/LeonCodeGenRuntimeHenkinMonitor" def idToSafeJVMName(id: Identifier) = { scala.reflect.NameTransformer.encode(id.uniqueName).replaceAll("\\.", "\\$") @@ -147,15 +151,15 @@ trait CodeGeneration { * @param funDef The function definition to be compiled * @param owner The module/class that contains `funDef` */ - def compileFunDef(funDef : FunDef, owner : Definition) { + def compileFunDef(funDef: FunDef, owner: Definition) { val isStatic = owner.isInstanceOf[ModuleDef] - + val cf = classes(owner) val (_,mn,_) = leonFunDefToJVMInfo(funDef).get val paramsTypes = funDef.params.map(a => typeToJVM(a.getType)) - val realParams = if (params.requireMonitor) { + val realParams = if (requireMonitor) { ("L" + MonitorClass + ";") +: paramsTypes } else { paramsTypes @@ -176,16 +180,15 @@ trait CodeGeneration { METHOD_ACC_PUBLIC | METHOD_ACC_FINAL ).asInstanceOf[U2]) - + val ch = m.codeHandler - + // An offset we introduce to the parameters: // 1 if this is a method, so we need "this" in position 0 of the stack // 1 if we are monitoring - val paramsOffset = Seq(!isStatic, params.requireMonitor).count(x => x) - val newMapping = - funDef.params.map(_.id).zipWithIndex.toMap.mapValues(_ + paramsOffset) - + val idParams = (if (requireMonitor) Seq(monitorID) else Seq.empty) ++ funDef.params.map(_.id) + val newMapping = idParams.zipWithIndex.toMap.mapValues(_ + (if (!isStatic) 1 else 0)) + val body = funDef.body.getOrElse(throw CompilationException("Can't compile a FunDef without body: "+funDef.id.name)) val bodyWithPre = if(funDef.hasPrecondition && params.checkContracts) { @@ -200,12 +203,14 @@ trait CodeGeneration { case _ => bodyWithPre } + val locals = Locals(newMapping, Map.empty, Map.empty, isStatic) + if (params.recordInvocations) { - // index of monitor object will be before the first Scala parameter - ch << ALoad(paramsOffset-1) << InvokeVirtual(MonitorClass, "onInvoke", "()V") + load(monitorID, ch)(locals) + ch << InvokeVirtual(MonitorClass, "onInvoke", "()V") } - mkExpr(bodyWithPost, ch)(Locals(newMapping, Map.empty, Map.empty, isStatic)) + mkExpr(bodyWithPost, ch)(locals) funDef.returnType match { case ValueType() => @@ -218,6 +223,318 @@ trait CodeGeneration { ch.freeze } + private[codegen] val lambdaToClass = scala.collection.mutable.Map.empty[Lambda, String] + private[codegen] val classToLambda = scala.collection.mutable.Map.empty[String, Lambda] + + private def compileLambda(l: Lambda, ch: CodeHandler)(implicit locals: Locals): Unit = { + val (normalized, structSubst) = purescala.ExprOps.normalizeStructure(l) + val reverseSubst = structSubst.map(p => p._2 -> p._1) + val nl = normalized.asInstanceOf[Lambda] + + val closureIDs = purescala.ExprOps.variablesOf(nl).toSeq.sortBy(_.uniqueName) + val closuresWithoutMonitor = closureIDs.map(id => id -> typeToJVM(id.getType)) + val closures = if (requireMonitor) { + (monitorID -> s"L$MonitorClass;") +: closuresWithoutMonitor + } else closuresWithoutMonitor + + val afName = lambdaToClass.getOrElse(nl, { + val afName = "Leon$CodeGen$Lambda$" + lambdaCounter.nextGlobal + lambdaToClass += nl -> afName + classToLambda += afName -> nl + + val cf = new ClassFile(afName, Some(LambdaClass)) + + cf.setFlags(( + CLASS_ACC_SUPER | + CLASS_ACC_PUBLIC | + CLASS_ACC_FINAL + ).asInstanceOf[U2]) + + if (closures.isEmpty) { + cf.addDefaultConstructor + } else { + for ((id, jvmt) <- closures) { + val fh = cf.addField(jvmt, id.uniqueName) + fh.setFlags(( + FIELD_ACC_PUBLIC | + FIELD_ACC_FINAL + ).asInstanceOf[U2]) + } + + val cch = cf.addConstructor(closures.map(_._2).toList).codeHandler + + cch << ALoad(0) + cch << InvokeSpecial(LambdaClass, constructorName, "()V") + + var c = 1 + for ((id, jvmt) <- closures) { + cch << ALoad(0) + cch << (jvmt match { + case "I" | "Z" => ILoad(c) + case _ => ALoad(c) + }) + cch << PutField(afName, id.uniqueName, jvmt) + c += 1 + } + + cch << RETURN + cch.freeze + } + + locally { + val apm = cf.addMethod(s"L$ObjectClass;", "apply", s"[L$ObjectClass;") + + apm.setFlags(( + METHOD_ACC_PUBLIC | + METHOD_ACC_FINAL + ).asInstanceOf[U2]) + + val argMapping = nl.args.map(_.id).zipWithIndex.toMap + val closureMapping = closures.map { case (id, jvmt) => id -> (afName, id.uniqueName, jvmt) }.toMap + + val newLocals = locals.withArgs(argMapping).withClosures(closureMapping) + + val apch = apm.codeHandler + + mkBoxedExpr(nl.body, apch)(newLocals) + + apch << ARETURN + + apch.freeze + } + + locally { + val emh = cf.addMethod("Z", "equals", s"L$ObjectClass;") + emh.setFlags(( + METHOD_ACC_PUBLIC | + METHOD_ACC_FINAL + ).asInstanceOf[U2]) + + val ech = emh.codeHandler + + val notRefEq = ech.getFreshLabel("notrefeq") + val notEq = ech.getFreshLabel("noteq") + val castSlot = ech.getFreshVar + + // If references are equal, trees are equal. + ech << ALoad(0) << ALoad(1) << If_ACmpNe(notRefEq) << Ldc(1) << IRETURN << Label(notRefEq) + + // We check the type (this also checks against null).... + ech << ALoad(1) << InstanceOf(afName) << IfEq(notEq) + + // ...finally, we compare fields one by one, shortcircuiting on disequalities. + if(closures.nonEmpty) { + ech << ALoad(1) << CheckCast(afName) << AStore(castSlot) + + for((id,jvmt) <- closures) { + ech << ALoad(0) << GetField(afName, id.uniqueName, jvmt) + ech << ALoad(castSlot) << GetField(afName, id.uniqueName, jvmt) + + jvmt match { + case "I" | "Z" => + ech << If_ICmpNe(notEq) + + case ot => + ech << InvokeVirtual(ObjectClass, "equals", s"(L$ObjectClass;)Z") << IfEq(notEq) + } + } + } + + ech << Ldc(1) << IRETURN << Label(notEq) << Ldc(0) << IRETURN + ech.freeze + } + + locally { + val hashFieldName = "$leon$hashCode" + cf.addField("I", hashFieldName).setFlags(FIELD_ACC_PRIVATE) + val hmh = cf.addMethod("I", "hashCode", "") + hmh.setFlags(( + METHOD_ACC_PUBLIC | + METHOD_ACC_FINAL + ).asInstanceOf[U2]) + + val hch = hmh.codeHandler + + val wasNotCached = hch.getFreshLabel("wasNotCached") + + hch << ALoad(0) << GetField(afName, hashFieldName, "I") << DUP + hch << IfEq(wasNotCached) + hch << IRETURN + hch << Label(wasNotCached) << POP + + hch << Ldc(closuresWithoutMonitor.size) << NewArray(s"$ObjectClass") + for (((id, jvmt),i) <- closuresWithoutMonitor.zipWithIndex) { + hch << DUP << Ldc(i) + hch << ALoad(0) << GetField(afName, id.uniqueName, jvmt) + mkBox(id.getType, hch) + hch << AASTORE + } + + hch << Ldc(afName.hashCode) + hch << InvokeStatic(HashingClass, "seqHash", s"([L$ObjectClass;I)I") << DUP + hch << ALoad(0) << SWAP << PutField(afName, hashFieldName, "I") + hch << IRETURN + + hch.freeze + } + + loader.register(cf) + + afName + }) + + val consSig = "(" + closures.map(_._2).mkString("") + ")V" + + ch << New(afName) << DUP + for ((id,jvmt) <- closures) { + if (id == monitorID) { + load(monitorID, ch) + } else { + mkExpr(Variable(reverseSubst(id)), ch) + } + } + ch << InvokeSpecial(afName, constructorName, consSig) + } + + private val typeIdCache = scala.collection.mutable.Map.empty[TypeTree, Int] + private[codegen] def typeId(tpe: TypeTree): Int = typeIdCache.get(tpe) match { + case Some(id) => id + case None => + val id = typeIdCache.size + typeIdCache += tpe -> id + id + } + + private def compileForall(f: Forall, ch: CodeHandler)(implicit locals: Locals): Unit = { + // make sure we have an available HenkinModel + val monitorOk = ch.getFreshLabel("monitorOk") + load(monitorID, ch) + ch << InstanceOf(HenkinClass) << IfNe(monitorOk) + ch << New(ImpossibleEvaluationClass) << DUP + ch << Ldc("Can't evaluate foralls without domain") + ch << InvokeSpecial(ImpossibleEvaluationClass, constructorName, "(Ljava/lang/String;)V") + ch << ATHROW + ch << Label(monitorOk) + + val Forall(fargs, TopLevelAnds(conjuncts)) = f + val endLabel = ch.getFreshLabel("forallEnd") + + for (conj <- conjuncts) { + val vars = purescala.ExprOps.variablesOf(conj) + val args = fargs.map(_.id).filter(vars) + val quantified = args.toSet + + val matchQuorums = extractQuorums(conj, quantified) + + val matcherIndexes = matchQuorums.flatten.distinct.zipWithIndex.toMap + + def buildLoops( + mis: List[(Expr, Seq[Expr], Int)], + localMapping: Map[Identifier, Int], + pointerMapping: Map[(Int, Int), Identifier] + ): Unit = mis match { + case (expr, args, qidx) :: rest => + load(monitorID, ch) + ch << CheckCast(HenkinClass) + + mkExpr(expr, ch) + ch << Ldc(typeId(expr.getType)) + ch << InvokeVirtual(HenkinClass, "domain", s"(L$ObjectClass;I)L$JavaListClass;") + ch << InvokeInterface(JavaListClass, "iterator", s"()L$JavaIteratorClass;") + + val loop = ch.getFreshLabel("loop") + val out = ch.getFreshLabel("out") + ch << Label(loop) + // it + ch << DUP + // it, it + ch << InvokeInterface(JavaIteratorClass, "hasNext", "()Z") + // it, hasNext + ch << IfEq(out) << DUP + // it, it + ch << InvokeInterface(JavaIteratorClass, "next", s"()L$ObjectClass;") + // it, elem + ch << CheckCast(TupleClass) + + val (newLoc, newPtr) = (for ((arg, aidx) <- args.zipWithIndex) yield { + val id = FreshIdentifier("q", arg.getType, true) + val slot = ch.getFreshVar + + ch << DUP << Ldc(aidx) << InvokeVirtual(TupleClass, "get", s"(I)L$ObjectClass;") + mkUnbox(arg.getType, ch) + ch << (typeToJVM(arg.getType) match { + case "I" | "Z" => IStore(slot) + case _ => AStore(slot) + }) + + (id -> slot, (qidx -> aidx) -> id) + }).unzip + + ch << POP + // it + + buildLoops(rest, localMapping ++ newLoc, pointerMapping ++ newPtr) + + ch << Goto(loop) + ch << Label(out) << POP + + case Nil => + var okLabel: Option[String] = None + for (quorum <- matchQuorums) { + okLabel.foreach(ok => ch << Label(ok)) + okLabel = Some(ch.getFreshLabel("quorumOk")) + + var mappings: Seq[(Identifier, Int, Int)] = Seq.empty + var constraints: Seq[(Expr, Int, Int)] = Seq.empty + var equalities: Seq[((Int, Int), (Int, Int))] = Seq.empty + + for (q @ (expr, args) <- quorum) { + val qidx = matcherIndexes(q) + val (qmappings, qconstraints) = args.zipWithIndex.partition { + case (Variable(id), aidx) => quantified(id) + case _ => false + } + + mappings ++= qmappings.map(p => (p._1.asInstanceOf[Variable].id, qidx, p._2)) + constraints ++= qconstraints.map(p => (p._1, qidx, p._2)) + } + + val mapping = for ((id, es) <- mappings.groupBy(_._1)) yield { + val base :: others = es.toList.map(p => (p._2, p._3)) + equalities ++= others.map(p => base -> p) + (id -> base) + } + + val enabler = andJoin(constraints.map { + case (e, qidx, aidx) => Equals(e, pointerMapping(qidx -> aidx).toVariable) + } ++ equalities.map { + case (k1, k2) => Equals(pointerMapping(k1).toVariable, pointerMapping(k2).toVariable) + }) + + mkExpr(enabler, ch)(locals.withVars(localMapping)) + ch << IfEq(okLabel.get) + + val varsMap = args.map(id => id -> localMapping(pointerMapping(mapping(id)))).toMap + mkExpr(conj, ch)(locals.withVars(varsMap)) + ch << IfNe(okLabel.get) + + // -- Forall is false! -- + // POP all the iterators... + for (_ <- List.range(0, matcherIndexes.size)) ch << POP + + // ... and return false + ch << Ldc(0) << Goto(endLabel) + } + + ch << Label(okLabel.get) + } + + buildLoops(matcherIndexes.toList.map { case ((e, as), idx) => (e, as, idx) }, Map.empty, Map.empty) + } + + ch << Ldc(1) << Label(endLabel) + } + private[codegen] def mkExpr(e: Expr, ch: CodeHandler, canDelegateToMkBranch: Boolean = true)(implicit locals: Locals) { e match { case Variable(id) => @@ -226,7 +543,7 @@ trait CodeGeneration { case Assert(cond, oerr, body) => mkExpr(IfExpr(Not(cond), Error(body.getType, oerr.getOrElse("Assertion failed @"+e.getPos)), body), ch) - case en@Ensuring(_, _) => + case en @ Ensuring(_, _) => mkExpr(en.toAssert, ch) case Let(i,d,b) => @@ -267,8 +584,10 @@ trait CodeGeneration { throw CompilationException("Unknown class : " + cct.id) } ch << New(ccName) << DUP - if (params.requireMonitor) - ch << ALoad(locals.monitorIndex) + if (requireMonitor) { + load(monitorID, ch) + } + for((a, vd) <- as zip cct.classDef.fields) { vd.getType match { case TypeParameter(_) => @@ -404,30 +723,29 @@ trait CodeGeneration { throw CompilationException("Unknown method : " + tfd.id) } - if (params.requireMonitor) { - // index of monitor object will be before the first Scala parameter - ch << ALoad(locals.monitorIndex) << InvokeVirtual(MonitorClass, "onInvoke", "()V") + if (requireMonitor) { + load(monitorID, ch) + ch << InvokeVirtual(MonitorClass, "onInvoke", "()V") } - + // Get static field ch << GetStatic(className, fieldName, typeToJVM(tfd.fd.returnType)) - + // unbox field (tfd.fd.returnType, tfd.returnType) match { case (TypeParameter(_), tpe) => mkUnbox(tpe, ch) case _ => } - + case FunctionInvocation(TypedFunDef(fd, Seq(tp)), Seq(set)) if fd == program.library.setToList.get => - val IteratorClass = "java/util/Iterator" val nil = CaseClass(CaseClassType(program.library.Nil.get, Seq(tp)), Seq()) val cons = program.library.Cons.get val (consName, ccApplySig) = leonClassToJVMInfo(cons).getOrElse { throw CompilationException("Unknown class : " + cons) } - + mkExpr(nil, ch) mkExpr(set, ch) //if (params.requireMonitor) { @@ -436,49 +754,50 @@ trait CodeGeneration { // No dynamic dispatching/overriding in Leon, // so no need to take care of own vs. "super" methods - ch << InvokeVirtual(SetClass, "getElements", s"()L$IteratorClass;") - + ch << InvokeVirtual(SetClass, "getElements", s"()L$JavaIteratorClass;") + val loop = ch.getFreshLabel("loop") val out = ch.getFreshLabel("out") ch << Label(loop) // list, it ch << DUP // list, it, it - ch << InvokeInterface(IteratorClass, "hasNext", "()Z") + ch << InvokeInterface(JavaIteratorClass, "hasNext", "()Z") // list, it, hasNext ch << IfEq(out) // list, it ch << DUP2 // list, it, list, it - ch << InvokeInterface(IteratorClass, "next", s"()L$ObjectClass;") << SWAP + ch << InvokeInterface(JavaIteratorClass, "next", s"()L$ObjectClass;") << SWAP // list, it, elem, list ch << New(consName) << DUP << DUP2_X2 // list, it, cons, cons, elem, list, cons, cons ch << POP << POP // list, it, cons, cons, elem, list - - if (params.requireMonitor) { - ch << ALoad(locals.monitorIndex) << DUP_X2 << POP + + if (requireMonitor) { + load(monitorID, ch) + ch << DUP_X2 << POP } ch << InvokeSpecial(consName, constructorName, ccApplySig) // list, it, newList ch << DUP_X2 << POP << SWAP << POP // newList, it ch << Goto(loop) - + ch << Label(out) // list, it ch << POP // list - + // Static lazy fields/ functions case fi @ FunctionInvocation(tfd, as) => val (cn, mn, ms) = leonFunDefToJVMInfo(tfd.fd).getOrElse { throw CompilationException("Unknown method : " + tfd.id) } - - if (params.requireMonitor) { - ch << ALoad(locals.monitorIndex) + + if (requireMonitor) { + load(monitorID, ch) } for((a, vd) <- as zip tfd.fd.params) { @@ -497,16 +816,16 @@ trait CodeGeneration { mkUnbox(tpe, ch) case _ => } - + // Strict fields are handled as fields case MethodInvocation(rec, _, tfd, _) if tfd.fd.canBeStrictField => val (className, fieldName, _) = leonFunDefToJVMInfo(tfd.fd).getOrElse { throw CompilationException("Unknown method : " + tfd.id) } - if (params.requireMonitor) { - // index of monitor object will be before the first Scala parameter - ch << ALoad(locals.monitorIndex) << InvokeVirtual(MonitorClass, "onInvoke", "()V") + if (requireMonitor) { + load(monitorID, ch) + ch << InvokeVirtual(MonitorClass, "onInvoke", "()V") } // Load receiver mkExpr(rec,ch) @@ -520,19 +839,19 @@ trait CodeGeneration { mkUnbox(tpe, ch) case _ => } - + // This is for lazy fields and real methods. // To access a lazy field, we call its accessor function. case MethodInvocation(rec, cd, tfd, as) => val (className, methodName, sig) = leonFunDefToJVMInfo(tfd.fd).getOrElse { throw CompilationException("Unknown method : " + tfd.id) } - + // Receiver of the method call mkExpr(rec,ch) - - if (params.requireMonitor) { - ch << ALoad(locals.monitorIndex) + + if (requireMonitor) { + load(monitorID, ch) } for((a, vd) <- as zip tfd.fd.params) { @@ -543,10 +862,10 @@ trait CodeGeneration { mkExpr(a, ch) } } - + // No interfaces in Leon, so no need to use InvokeInterface ch << InvokeVirtual(className, methodName, sig) - + (tfd.fd.returnType, tfd.returnType) match { case (TypeParameter(_), tpe) => mkUnbox(tpe, ch) @@ -566,84 +885,11 @@ trait CodeGeneration { mkUnbox(app.getType, ch) case l @ Lambda(args, body) => - val afName = "Leon$CodeGen$Lambda$" + lambdaCounter.nextGlobal - lambdas += afName -> l - - val cf = new ClassFile(afName, Some(LambdaClass)) - - cf.setFlags(( - CLASS_ACC_SUPER | - CLASS_ACC_PUBLIC | - CLASS_ACC_FINAL - ).asInstanceOf[U2]) - - val closures = purescala.ExprOps.variablesOf(l).toSeq.sortBy(_.uniqueName) - val closureTypes = closures.map(id => id.name -> typeToJVM(id.getType)) - - if (closureTypes.isEmpty) { - cf.addDefaultConstructor - } else { - for ((nme, jvmt) <- closureTypes) { - val fh = cf.addField(jvmt, nme) - fh.setFlags(( - FIELD_ACC_PUBLIC | - FIELD_ACC_FINAL - ).asInstanceOf[U2]) - } - - val cch = cf.addConstructor(closureTypes.map(_._2).toList).codeHandler - - cch << ALoad(0) - cch << InvokeSpecial(LambdaClass, constructorName, "()V") - - var c = 1 - for ((nme, jvmt) <- closureTypes) { - cch << ALoad(0) - cch << (jvmt match { - case "I" | "Z" => ILoad(c) - case _ => ALoad(c) - }) - cch << PutField(afName, nme, jvmt) - c += 1 - } - - cch << RETURN - cch.freeze - } - - locally { - - val apm = cf.addMethod(s"L$ObjectClass;", "apply", s"[L$ObjectClass;") - - apm.setFlags(( - METHOD_ACC_PUBLIC | - METHOD_ACC_FINAL - ).asInstanceOf[U2]) - - val argMapping = args.map(_.id).zipWithIndex.toMap - val closureMapping = (closures zip closureTypes).map { case (id, (name, tpe)) => id -> (afName, name, tpe) }.toMap - - val newLocals = locals.withArgs(argMapping).withClosures(closureMapping) - - val apch = apm.codeHandler + compileLambda(l, ch) - mkBoxedExpr(body, apch)(newLocals) + case f @ Forall(args, body) => + compileForall(f, ch) - apch << ARETURN - - apch.freeze - } - - loader.register(cf) - - val consSig = "(" + closures.map(id => typeToJVM(id.getType)).mkString("") + ")V" - - ch << New(afName) << DUP - for (a <- closures) { - mkExpr(Variable(a), ch) - } - ch << InvokeSpecial(afName, constructorName, consSig) - // Arithmetic case Plus(l, r) => mkExpr(l, ch) @@ -1133,7 +1379,7 @@ trait CodeGeneration { * @param lzy The lazy field to be compiled * @param owner The module/class containing `lzy` */ - def compileLazyField(lzy : FunDef, owner : Definition) { + def compileLazyField(lzy: FunDef, owner: Definition) { ctx.reporter.internalAssertion(lzy.canBeLazyField, s"Trying to compile non-lazy ${lzy.id.name} as a lazy field") val (_, accessorName, _ ) = leonFunDefToJVMInfo(lzy).get @@ -1158,7 +1404,7 @@ trait CodeGeneration { // accessor method locally { - val parameters = if (params.requireMonitor) { + val parameters = if (requireMonitor) { Seq("L" + MonitorClass + ";") } else Seq() @@ -1173,7 +1419,7 @@ trait CodeGeneration { val body = lzy.body.getOrElse(throw CompilationException("Lazy field without body?")) val initLabel = ch.getFreshLabel("isInitialized") - if (params.requireMonitor) { + if (requireMonitor) { ch << ALoad(if (isStatic) 0 else 1) << InvokeVirtual(MonitorClass, "onInvoke", "()V") } @@ -1271,9 +1517,8 @@ trait CodeGeneration { } else { ch << ALoad(0) << SWAP << PutField (className, name, jvmType) } - } - - + } + def compileAbstractClassDef(acd : AbstractClassDef) { val cName = defToJVMName(acd) @@ -1314,8 +1559,7 @@ trait CodeGeneration { // definition of the constructor locally { - - val constrParams = if (params.requireMonitor) { + val constrParams = if (requireMonitor) { Seq("L" + MonitorClass + ";") } else Seq() @@ -1330,8 +1574,8 @@ trait CodeGeneration { case Some(parent) => val pName = defToJVMName(parent.classDef) // Load monitor object - if (params.requireMonitor) cch << ALoad(1) - val constrSig = if (params.requireMonitor) "(L" + MonitorClass + ";)V" else "()V" + if (requireMonitor) cch << ALoad(1) + val constrSig = if (requireMonitor) "(L" + MonitorClass + ";)V" else "()V" cch << InvokeSpecial(pName, constructorName, constrSig) case None => @@ -1349,7 +1593,7 @@ trait CodeGeneration { cch << RETURN cch.freeze } - + } /** @@ -1385,8 +1629,6 @@ trait CodeGeneration { } } - - def compileCaseClassDef(ccd: CaseClassDef) { val cName = defToJVMName(ccd) @@ -1407,7 +1649,6 @@ trait CodeGeneration { } locally { - val (fields, methods) = ccd.methods partition { _.canBeField } val (strictFields, lazyFields) = fields partition { _.canBeStrictField } @@ -1430,7 +1671,7 @@ trait CodeGeneration { val namesTypes = ccd.fields.map { vd => (vd.id.name, typeToJVM(vd.getType)) } // definition of the constructor - if(!params.doInstrument && !params.requireMonitor && ccd.fields.isEmpty && !ccd.methods.exists(_.canBeField)) { + if(!params.doInstrument && !requireMonitor && ccd.fields.isEmpty && !ccd.methods.exists(_.canBeField)) { cf.addDefaultConstructor } else { for((nme, jvmt) <- namesTypes) { @@ -1447,13 +1688,10 @@ trait CodeGeneration { } // If we are monitoring function calls, we have an extra argument on the constructor - val realArgs = if (params.requireMonitor) { + val realArgs = if (requireMonitor) { ("L" + MonitorClass + ";") +: (namesTypes map (_._2)) } else namesTypes map (_._2) - // Offset of the first Scala parameter of the constructor - val paramOffset = if (params.requireMonitor) 2 else 1 - val cch = cf.addConstructor(realArgs.toList).codeHandler if (params.doInstrument) { @@ -1462,7 +1700,7 @@ trait CodeGeneration { cch << PutField(cName, instrumentedField, "I") } - var c = paramOffset + var c = if (requireMonitor) 2 else 1 for((nme, jvmt) <- namesTypes) { cch << ALoad(0) cch << (jvmt match { @@ -1478,8 +1716,8 @@ trait CodeGeneration { // Load this cch << ALoad(0) // Load monitor object - if (params.requireMonitor) cch << ALoad(1) - val constrSig = if (params.requireMonitor) "(L" + MonitorClass + ";)V" else "()V" + if (requireMonitor) cch << ALoad(1) + val constrSig = if (requireMonitor) "(L" + MonitorClass + ";)V" else "()V" cch << InvokeSpecial(pName.get, constructorName, constrSig) } else { // Call constructor of java.lang.Object @@ -1487,7 +1725,6 @@ trait CodeGeneration { cch << InvokeSpecial(ObjectClass, constructorName, "()V") } - // Now initialize fields for (lzy <- lazyFields) { initLazyField(cch, cName, lzy, isStatic = false)} for (field <- strictFields) { initStrictField(cch, cName , field, isStatic = false)} @@ -1612,7 +1849,7 @@ trait CodeGeneration { ).asInstanceOf[U2]) val hch = hmh.codeHandler - + val wasNotCached = hch.getFreshLabel("wasNotCached") hch << ALoad(0) << GetField(cName, hashFieldName, "I") << DUP @@ -1625,7 +1862,7 @@ trait CodeGeneration { hch << InvokeStatic(HashingClass, "seqHash", s"([L$ObjectClass;I)I") << DUP hch << ALoad(0) << SWAP << PutField(cName, hashFieldName, "I") hch << IRETURN - + hch.freeze } diff --git a/src/main/scala/leon/codegen/CompilationUnit.scala b/src/main/scala/leon/codegen/CompilationUnit.scala index a4dc5c0e3..a4597b88b 100644 --- a/src/main/scala/leon/codegen/CompilationUnit.scala +++ b/src/main/scala/leon/codegen/CompilationUnit.scala @@ -6,11 +6,12 @@ package codegen import purescala.Common._ import purescala.Definitions._ import purescala.Expressions._ +import purescala.ExprOps._ import purescala.Types._ import purescala.Extractors._ import purescala.Constructors._ - -import runtime.LeonCodeGenRuntimeMonitor +import codegen.runtime.LeonCodeGenRuntimeMonitor +import codegen.runtime.LeonCodeGenRuntimeHenkinMonitor import utils.UniqueCounter import cafebabe._ @@ -28,12 +29,17 @@ class CompilationUnit(val ctx: LeonContext, val program: Program, val params: CodeGenParams = CodeGenParams.default) extends CodeGeneration { + protected[codegen] val requireQuantification = program.definedFunctions.exists { fd => + exists { case _: Forall => true case _ => false } (fd.fullBody) + } + + protected[codegen] val requireMonitor = params.requireMonitor || requireQuantification + val loader = new CafebabeClassLoader(classOf[CompilationUnit].getClassLoader) - var lambdas = Map[String, Lambda]() - var classes = Map[Definition, ClassFile]() + var classes = Map[Definition, ClassFile]() var defToModuleOrClass = Map[Definition, Definition]() - + def defineClass(df: Definition) { val cName = defToJVMName(df) @@ -59,7 +65,7 @@ class CompilationUnit(val ctx: LeonContext, def leonClassToJVMInfo(cd: ClassDef): Option[(String, String)] = { classes.get(cd) match { case Some(cf) => - val monitorType = if (params.requireMonitor) "L"+MonitorClass+";" else "" + val monitorType = if (requireMonitor) "L"+MonitorClass+";" else "" val sig = "(" + monitorType + cd.fields.map(f => typeToJVM(f.getType)).mkString("") + ")V" Some((cf.className, sig)) case _ => None @@ -78,7 +84,7 @@ class CompilationUnit(val ctx: LeonContext, */ def leonFunDefToJVMInfo(fd: FunDef): Option[(String, String, String)] = { funDefInfo.get(fd).orElse { - val monitorType = if (params.requireMonitor) "L"+MonitorClass+";" else "" + val monitorType = if (requireMonitor) "L"+MonitorClass+";" else "" val sig = "(" + monitorType + fd.params.map(a => typeToJVM(a.getType)).mkString("") + ")" + typeToJVM(fd.returnType) @@ -121,6 +127,21 @@ class CompilationUnit(val ctx: LeonContext, conss.last } + def modelToJVM(model: solvers.Model, maxInvocations: Int): LeonCodeGenRuntimeMonitor = model match { + case hModel: solvers.HenkinModel => + val lhm = new LeonCodeGenRuntimeHenkinMonitor(maxInvocations) + for ((tpe, domain) <- hModel.domains; args <- domain) { + val tpeId = typeId(tpe) + // note here that it doesn't matter that `lhm` doesn't yet have its domains + // filled since all values in `args` should be grounded + val inputJvm = tupleConstructor.newInstance(args.map(valueToJVM(_)(lhm)).toArray).asInstanceOf[leon.codegen.runtime.Tuple] + lhm.add(tpeId, inputJvm) + } + lhm + case _ => + new LeonCodeGenRuntimeMonitor(maxInvocations) + } + /** Translates Leon values (not generic expressions) to JVM compatible objects. * * Currently, this method is only used to prepare arguments to reflective calls. @@ -155,7 +176,7 @@ class CompilationUnit(val ctx: LeonContext, case CaseClass(cct, args) => caseClassConstructor(cct.classDef) match { case Some(cons) => - val realArgs = if (params.requireMonitor) monitor +: args.map(valueToJVM) else args.map(valueToJVM) + val realArgs = if (requireMonitor) monitor +: args.map(valueToJVM) else args.map(valueToJVM) cons.newInstance(realArgs.toArray : _*).asInstanceOf[AnyRef] case None => ctx.reporter.fatalError("Case class constructor not found?!?") @@ -180,11 +201,9 @@ class CompilationUnit(val ctx: LeonContext, } m - case f @ purescala.Extractors.FiniteLambda(dflt, els) => - val l = new leon.codegen.runtime.FiniteLambda(valueToJVM(dflt)) - - for ((k,v) <- els) { - val ks = unwrapTuple(k, f.getType.asInstanceOf[FunctionType].from.size) + case f @ PartialLambda(mapping, _) => + val l = new leon.codegen.runtime.PartialLambda() + for ((ks,v) <- mapping) { // Force tuple even with 1/0 elems. val kJvm = tupleConstructor.newInstance(ks.map(valueToJVM).toArray).asInstanceOf[leon.codegen.runtime.Tuple] val vJvm = valueToJVM(v) @@ -261,7 +280,7 @@ class CompilationUnit(val ctx: LeonContext, case (lambda: runtime.Lambda, _: FunctionType) => val cls = lambda.getClass - val l = lambdas(cls.getName) + val l = classToLambda(cls.getName) val closures = purescala.ExprOps.variablesOf(l).toSeq.sortBy(_.uniqueName) val closureVals = closures.map { id => val fieldVal = lambda.getClass.getField(id.name).get(lambda) @@ -308,7 +327,7 @@ class CompilationUnit(val ctx: LeonContext, val argsTypes = args.map(a => typeToJVM(a.getType)) - val realArgs = if (params.requireMonitor) { + val realArgs = if (requireMonitor) { ("L" + MonitorClass + ";") +: argsTypes } else { argsTypes @@ -328,11 +347,11 @@ class CompilationUnit(val ctx: LeonContext, val ch = m.codeHandler - val newMapping = if (params.requireMonitor) { - args.zipWithIndex.toMap.mapValues(_ + 1) - } else { - args.zipWithIndex.toMap - } + val newMapping = if (requireMonitor) { + args.zipWithIndex.toMap.mapValues(_ + 1) + (monitorID -> 0) + } else { + args.zipWithIndex.toMap + } mkExpr(e, ch)(Locals(newMapping, Map.empty, Map.empty, isStatic = true)) @@ -454,7 +473,6 @@ class CompilationUnit(val ctx: LeonContext, } for (m <- u.modules) compileModule(m) - } classes.values.foreach(loader.register) diff --git a/src/main/scala/leon/codegen/CompiledExpression.scala b/src/main/scala/leon/codegen/CompiledExpression.scala index e31e28e92..ad012bb74 100644 --- a/src/main/scala/leon/codegen/CompiledExpression.scala +++ b/src/main/scala/leon/codegen/CompiledExpression.scala @@ -8,11 +8,11 @@ import purescala.Expressions._ import cafebabe._ -import runtime.{LeonCodeGenRuntimeMonitor => LM } +import runtime.{LeonCodeGenRuntimeMonitor => LM, LeonCodeGenRuntimeHenkinMonitor => LHM} import java.lang.reflect.InvocationTargetException -class CompiledExpression(unit: CompilationUnit, cf: ClassFile, expression : Expr, argsDecl: Seq[Identifier]) { +class CompiledExpression(unit: CompilationUnit, cf: ClassFile, expression: Expr, argsDecl: Seq[Identifier]) { private lazy val cl = unit.loader.loadClass(cf.className) private lazy val meth = cl.getMethods()(0) @@ -28,7 +28,7 @@ class CompiledExpression(unit: CompilationUnit, cf: ClassFile, expression : Expr def evalToJVM(args: Seq[AnyRef], monitor: LM): AnyRef = { assert(args.size == argsDecl.size) - val realArgs = if (params.requireMonitor) { + val realArgs = if (unit.requireMonitor) { monitor +: args } else { args @@ -51,11 +51,10 @@ class CompiledExpression(unit: CompilationUnit, cf: ClassFile, expression : Expr } } - def eval(args: Seq[Expr]) : Expr = { + def eval(model: solvers.Model) : Expr = { try { - val monitor = - new LM(params.maxFunctionInvocations) - evalFromJVM(argsToJVM(args, monitor),monitor) + val monitor = unit.modelToJVM(model, params.maxFunctionInvocations) + evalFromJVM(argsToJVM(argsDecl.map(model), monitor), monitor) } catch { case ite : InvocationTargetException => throw ite.getCause } diff --git a/src/main/scala/leon/datagen/NaiveDataGen.scala b/src/main/scala/leon/datagen/NaiveDataGen.scala index 394993520..121c42b8c 100644 --- a/src/main/scala/leon/datagen/NaiveDataGen.scala +++ b/src/main/scala/leon/datagen/NaiveDataGen.scala @@ -7,6 +7,7 @@ import purescala.Common._ import purescala.Expressions._ import purescala.Types._ import purescala.Definitions._ +import purescala.Quantification._ import utils.StreamUtils._ import evaluators._ @@ -87,7 +88,7 @@ class NaiveDataGen(ctx: LeonContext, p: Program, evaluator: Evaluator, _bounds : evaluator.compile(satisfying, ins).map { evalFun => val sat = EvaluationResults.Successful(BooleanLiteral(true)) - { (e: Seq[Expr]) => evalFun(e) == sat } + { (e: Seq[Expr]) => evalFun(new solvers.Model((ins zip e).toMap)) == sat } } getOrElse { { (e: Seq[Expr]) => false } } diff --git a/src/main/scala/leon/datagen/VanuatooDataGen.scala b/src/main/scala/leon/datagen/VanuatooDataGen.scala index 0cf311562..b742ec58f 100644 --- a/src/main/scala/leon/datagen/VanuatooDataGen.scala +++ b/src/main/scala/leon/datagen/VanuatooDataGen.scala @@ -224,7 +224,7 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { type InstrumentedResult = (EvaluationResults.Result, Option[vanuatoo.Pattern[Expr, TypeTree]]) - def compile(expression : Expr, argorder : Seq[Identifier]) : Option[Expr=>InstrumentedResult] = { + def compile(expression: Expr, argorder: Seq[Identifier]) : Option[Expr=>InstrumentedResult] = { import leon.codegen.runtime.LeonCodeGenRuntimeException import leon.codegen.runtime.LeonCodeGenEvaluationException @@ -241,7 +241,7 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { Some((args : Expr) => { try { val monitor = new LeonCodeGenRuntimeMonitor(unit.params.maxFunctionInvocations) - val jvmArgs = ce.argsToJVM(Seq(args), monitor ) + val jvmArgs = ce.argsToJVM(Seq(args), monitor) val result = ce.evalFromJVM(jvmArgs, monitor) diff --git a/src/main/scala/leon/evaluators/CodeGenEvaluator.scala b/src/main/scala/leon/evaluators/CodeGenEvaluator.scala index 64456226b..769748dff 100644 --- a/src/main/scala/leon/evaluators/CodeGenEvaluator.scala +++ b/src/main/scala/leon/evaluators/CodeGenEvaluator.scala @@ -6,11 +6,12 @@ package evaluators import purescala.Common._ import purescala.Definitions._ import purescala.Expressions._ +import purescala.Quantification._ import codegen.CompilationUnit import codegen.CodeGenParams -class CodeGenEvaluator(ctx : LeonContext, val unit : CompilationUnit) extends Evaluator(ctx, unit.program) { +class CodeGenEvaluator(ctx: LeonContext, val unit : CompilationUnit) extends Evaluator(ctx, unit.program) { val name = "codegen-eval" val description = "Evaluator for PureScala expressions based on compilation to JVM" @@ -19,28 +20,29 @@ class CodeGenEvaluator(ctx : LeonContext, val unit : CompilationUnit) extends Ev this(ctx, new CompilationUnit(ctx, prog, params)) } - def eval(expression : Expr, mapping : Map[Identifier,Expr]) : EvaluationResult = { - val toPairs = mapping.toSeq + def eval(expression: Expr, model: solvers.Model) : EvaluationResult = { + val toPairs = model.toSeq compile(expression, toPairs.map(_._1)).map { e => - ctx.timers.evaluators.codegen.runtime.start() - val res = e(toPairs.map(_._2)) + val res = e(model) ctx.timers.evaluators.codegen.runtime.stop() res }.getOrElse(EvaluationResults.EvaluatorError("Couldn't compile expression.")) } - override def compile(expression : Expr, argorder : Seq[Identifier]) : Option[Seq[Expr]=>EvaluationResult] = { + override def compile(expression: Expr, args: Seq[Identifier]) : Option[solvers.Model=>EvaluationResult] = { import leon.codegen.runtime.LeonCodeGenRuntimeException import leon.codegen.runtime.LeonCodeGenEvaluationException ctx.timers.evaluators.codegen.compilation.start() try { - val ce = unit.compileExpression(expression, argorder)(ctx) + val ce = unit.compileExpression(expression, args)(ctx) - Some((args : Seq[Expr]) => { - try { - EvaluationResults.Successful(ce.eval(args)) + Some((model: solvers.Model) => { + if (args.exists(arg => !model.isDefinedAt(arg))) { + EvaluationResults.EvaluatorError("Model undefined for free arguments") + } else try { + EvaluationResults.Successful(ce.eval(model)) } catch { case e : ArithmeticException => EvaluationResults.RuntimeError(e.getMessage) @@ -65,6 +67,7 @@ class CodeGenEvaluator(ctx : LeonContext, val unit : CompilationUnit) extends Ev } catch { case t: Throwable => ctx.reporter.warning(expression.getPos, "Error while compiling expression: "+t.getMessage) + t.printStackTrace() None } finally { ctx.timers.evaluators.codegen.compilation.stop() diff --git a/src/main/scala/leon/evaluators/DefaultEvaluator.scala b/src/main/scala/leon/evaluators/DefaultEvaluator.scala index 148ac359a..d732d48c0 100644 --- a/src/main/scala/leon/evaluators/DefaultEvaluator.scala +++ b/src/main/scala/leon/evaluators/DefaultEvaluator.scala @@ -6,13 +6,14 @@ package evaluators import purescala.Common._ import purescala.Expressions._ import purescala.Definitions._ +import purescala.Quantification._ class DefaultEvaluator(ctx: LeonContext, prog: Program) extends RecursiveEvaluator(ctx, prog, 50000) { type RC = DefaultRecContext type GC = GlobalContext def initRC(mappings: Map[Identifier, Expr]) = DefaultRecContext(mappings) - def initGC = new GlobalContext() + def initGC(model: solvers.Model) = new GlobalContext(model) case class DefaultRecContext(mappings: Map[Identifier, Expr]) extends RecContext { def newVars(news: Map[Identifier, Expr]) = copy(news) diff --git a/src/main/scala/leon/evaluators/DualEvaluator.scala b/src/main/scala/leon/evaluators/DualEvaluator.scala index a058dd3a0..cd843fbb1 100644 --- a/src/main/scala/leon/evaluators/DualEvaluator.scala +++ b/src/main/scala/leon/evaluators/DualEvaluator.scala @@ -6,6 +6,7 @@ package evaluators import purescala.Common._ import purescala.Expressions._ import purescala.Definitions._ +import purescala.Quantification._ import purescala.Types._ import codegen._ @@ -17,7 +18,7 @@ class DualEvaluator(ctx: LeonContext, prog: Program, params: CodeGenParams) exte implicit val debugSection = utils.DebugSectionEvaluation def initRC(mappings: Map[Identifier, Expr]) = DefaultRecContext(mappings) - def initGC = new GlobalContext() + def initGC(model: solvers.Model) = new GlobalContext(model) var monitor = new runtime.LeonCodeGenRuntimeMonitor(params.maxFunctionInvocations) @@ -125,9 +126,9 @@ class DualEvaluator(ctx: LeonContext, prog: Program, params: CodeGenParams) exte } - override def eval(ex: Expr, mappings: Map[Identifier, Expr]) = { + override def eval(ex: Expr, model: solvers.Model) = { monitor = new runtime.LeonCodeGenRuntimeMonitor(params.maxFunctionInvocations) - super.eval(ex, mappings) + super.eval(ex, model) } } diff --git a/src/main/scala/leon/evaluators/Evaluator.scala b/src/main/scala/leon/evaluators/Evaluator.scala index fa1f18352..e24c0e364 100644 --- a/src/main/scala/leon/evaluators/Evaluator.scala +++ b/src/main/scala/leon/evaluators/Evaluator.scala @@ -6,27 +6,35 @@ package evaluators import purescala.Common._ import purescala.Definitions._ import purescala.Expressions._ +import purescala.Quantification._ +import purescala.ExprOps._ -abstract class Evaluator(val context : LeonContext, val program : Program) extends LeonComponent { +import solvers.Model + +abstract class Evaluator(val context: LeonContext, val program: Program) extends LeonComponent { type EvaluationResult = EvaluationResults.Result /** Evaluates an expression, using `mapping` as a valuation function for the free variables. */ - def eval(expr: Expr, mapping: Map[Identifier,Expr]) : EvaluationResult + def eval(expr: Expr, model: Model) : EvaluationResult + + /** Evaluates an expression given a simple model (assumes expr is quantifier-free). + * Mainly useful for compatibility reasons. + */ + final def eval(expr: Expr, mapping: Map[Identifier, Expr]) : EvaluationResult = eval(expr, new Model(mapping)) /** Evaluates a ground expression. */ - final def eval(expr: Expr) : EvaluationResult = eval(expr, Map.empty) + final def eval(expr: Expr) : EvaluationResult = eval(expr, Model.empty) /** Compiles an expression into a function, where the arguments are the free variables in the expression. * `argorder` specifies in which order the arguments should be passed. * The default implementation uses the evaluation function each time, but evaluators are free * to (and encouraged to) apply any specialization. */ - def compile(expr : Expr, argorder : Seq[Identifier]) : Option[Seq[Expr]=>EvaluationResult] = Some( - (args : Seq[Expr]) => if(args.size != argorder.size) { - EvaluationResults.EvaluatorError("Wrong number of arguments for evaluation.") + def compile(expr: Expr, args: Seq[Identifier]) : Option[Model => EvaluationResult] = Some( + (model: Model) => if(args.exists(arg => !model.isDefinedAt(arg))) { + EvaluationResults.EvaluatorError("Wrong number of arguments for evaluation.") } else { - val mapping = argorder.zip(args).toMap - eval(expr, mapping) + eval (expr, model) } ) } diff --git a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala index 38710837f..8ca885dbe 100644 --- a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala +++ b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala @@ -11,7 +11,9 @@ import purescala.Types._ import purescala.TypeOps.isSubtypeOf import purescala.Constructors._ import purescala.Extractors._ +import purescala.Quantification._ +import solvers.{Model, HenkinModel} import solvers.SolverFactory import synthesis.ConvertHoles.convertHoles @@ -41,25 +43,25 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int } } - class GlobalContext { + class GlobalContext(val model: Model) { def maxSteps = RecursiveEvaluator.this.maxSteps var stepsLeft = maxSteps } def initRC(mappings: Map[Identifier, Expr]): RC - def initGC(): GC + def initGC(model: Model): GC // Used by leon-web, please do not delete var lastGC: Option[GC] = None private[this] var clpCache = Map[(Choose, Seq[Expr]), Expr]() - def eval(ex: Expr, mappings: Map[Identifier, Expr]) = { + def eval(ex: Expr, model: Model) = { try { - lastGC = Some(initGC()) + lastGC = Some(initGC(model)) ctx.timers.evaluators.recursive.runtime.start() - EvaluationResults.Successful(e(ex)(initRC(mappings), lastGC.get)) + EvaluationResults.Successful(e(ex)(initRC(model.toMap), lastGC.get)) } catch { case so: StackOverflowError => EvaluationResults.EvaluatorError("Stack overflow") @@ -87,10 +89,16 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int case Application(caller, args) => e(caller) match { - case l@Lambda(params, body) => + case l @ Lambda(params, body) => val newArgs = args.map(e) val mapping = l.paramSubst(newArgs) e(body)(rctx.withNewVars(mapping), gctx) + case PartialLambda(mapping, _) => + mapping.find { case (pargs, res) => + (args zip pargs).forall(p => e(Equals(p._1, p._2)) == BooleanLiteral(true)) + }.map(_._2).getOrElse { + throw EvalError("Cannot apply partial lambda outside of domain") + } case f => throw EvalError("Cannot apply non-lambda function " + f.asString) } @@ -217,6 +225,7 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int (lv,rv) match { case (FiniteSet(el1, _),FiniteSet(el2, _)) => BooleanLiteral(el1 == el2) case (FiniteMap(el1, _, _),FiniteMap(el2, _, _)) => BooleanLiteral(el1.toSet == el2.toSet) + case (PartialLambda(m1, _), PartialLambda(m2, _)) => BooleanLiteral(m1.toSet == m2.toSet) case _ => BooleanLiteral(lv == rv) } @@ -487,8 +496,71 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int FiniteSet(els.map(e), base) case l @ Lambda(_, _) => - val mapping = variablesOf(l).map(id => id -> e(Variable(id))).toMap - replaceFromIDs(mapping, l) + val (nl, structSubst) = normalizeStructure(l) + val mapping = variablesOf(l).map(id => structSubst(id) -> e(Variable(id))).toMap + replaceFromIDs(mapping, nl) + + case PartialLambda(mapping, tpe) => + PartialLambda(mapping.map(p => p._1.map(e) -> e(p._2)), tpe) + + case f @ Forall(fargs, TopLevelAnds(conjuncts)) => + val henkinModel: HenkinModel = gctx.model match { + case hm: HenkinModel => hm + case _ => throw EvalError("Can't evaluate foralls without henkin model") + } + + e(andJoin(for (conj <- conjuncts) yield { + val vars = variablesOf(conj) + val args = fargs.map(_.id).filter(vars) + val quantified = args.toSet + + val matcherQuorums = extractQuorums(conj, quantified) + + val instantiations = matcherQuorums.flatMap { quorum => + var mappings: Seq[(Identifier, Int, Int)] = Seq.empty + var constraints: Seq[(Expr, Int, Int)] = Seq.empty + + for (((expr, args), qidx) <- quorum.zipWithIndex) { + val (qmappings, qconstraints) = args.zipWithIndex.partition { + case (Variable(id),aidx) => quantified(id) + case _ => false + } + + mappings ++= qmappings.map(p => (p._1.asInstanceOf[Variable].id, qidx, p._2)) + constraints ++= qconstraints.map(p => (p._1, qidx, p._2)) + } + + var equalities: Seq[((Int, Int), (Int, Int))] = Seq.empty + val mapping = for ((id, es) <- mappings.groupBy(_._1)) yield { + val base :: others = es.toList.map(p => (p._2, p._3)) + equalities ++= others.map(p => base -> p) + (id -> base) + } + + val argSets = quorum.foldLeft[List[Seq[Seq[Expr]]]](List(Seq.empty)) { + case (acc, (expr, _)) => acc.flatMap(s => henkinModel.domain(expr).map(d => s :+ d)) + } + + argSets.map { args => + val argMap: Map[(Int, Int), Expr] = args.zipWithIndex.flatMap { + case (a, qidx) => a.zipWithIndex.map { case (e, aidx) => (qidx, aidx) -> e } + }.toMap + + val map = mapping.map { case (id, key) => id -> argMap(key) } + val enabler = andJoin(constraints.map { + case (e, qidx, aidx) => Equals(e, argMap(qidx -> aidx)) + } ++ equalities.map { + case (k1, k2) => Equals(argMap(k1), argMap(k2)) + }) + + (enabler, map) + } + } + + e(andJoin(instantiations.map { case (enabler, mapping) => + e(Implies(enabler, conj))(rctx.withNewVars(mapping), gctx) + })) + })) case ArrayLength(a) => val FiniteArray(_, _, IntLiteral(length)) = e(a) diff --git a/src/main/scala/leon/evaluators/TracingEvaluator.scala b/src/main/scala/leon/evaluators/TracingEvaluator.scala index ea5ad3e0d..ec977763f 100644 --- a/src/main/scala/leon/evaluators/TracingEvaluator.scala +++ b/src/main/scala/leon/evaluators/TracingEvaluator.scala @@ -6,6 +6,7 @@ package evaluators import purescala.Common._ import purescala.Expressions._ import purescala.Definitions._ +import purescala.Quantification._ import purescala.Types._ class TracingEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int = 1000) extends RecursiveEvaluator(ctx, prog, maxSteps) { @@ -14,9 +15,9 @@ class TracingEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int = 1000) ex def initRC(mappings: Map[Identifier, Expr]) = TracingRecContext(mappings, 2) - def initGC() = new TracingGlobalContext(Nil) + def initGC(model: solvers.Model) = new TracingGlobalContext(Nil, model) - class TracingGlobalContext(var values: List[(Tree, Expr)]) extends GlobalContext + class TracingGlobalContext(var values: List[(Tree, Expr)], model: solvers.Model) extends GlobalContext(model) case class TracingRecContext(mappings: Map[Identifier, Expr], tracingFrames: Int) extends RecContext { def newVars(news: Map[Identifier, Expr]) = copy(mappings = news) diff --git a/src/main/scala/leon/purescala/CheckForalls.scala b/src/main/scala/leon/purescala/CheckForalls.scala deleted file mode 100644 index bb9874373..000000000 --- a/src/main/scala/leon/purescala/CheckForalls.scala +++ /dev/null @@ -1,106 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - -package leon -package purescala - -import Common._ -import Definitions._ -import Expressions._ -import Extractors._ -import ExprOps._ - -object CheckForalls extends UnitPhase[Program] { - - val name = "Foralls" - val description = "Check syntax of foralls to guarantee sound instantiations" - - def apply(ctx: LeonContext, program: Program) = { - program.definedFunctions.foreach { fd => - if (fd.body.exists(b => exists { - case f: Forall => true - case _ => false - } (b))) ctx.reporter.warning("Universal quantification in function bodies is not supported in " + fd) - - val foralls = (fd.precondition.toSeq ++ fd.postcondition.toSeq).flatMap { prec => - collect[Forall] { - case f: Forall => Set(f) - case _ => Set.empty - } (prec) - } - - val free = fd.params.map(_.id).toSet ++ (fd.postcondition match { - case Some(Lambda(args, _)) => args.map(_.id) - case _ => Seq.empty - }) - - object Matcher { - def unapply(e: Expr): Option[(Identifier, Seq[Expr])] = e match { - case Application(Variable(id), args) if free(id) => Some(id -> args) - case ArraySelect(Variable(id), index) if free(id) => Some(id -> Seq(index)) - case MapGet(Variable(id), key) if free(id) => Some(id -> Seq(key)) - case _ => None - } - } - - for (Forall(args, TopLevelAnds(conjuncts)) <- foralls) { - val quantified = args.map(_.id).toSet - - for (conjunct <- conjuncts) { - val matchers = collect[(Identifier, Seq[Expr])] { - case Matcher(id, args) => Set(id -> args) - case _ => Set.empty - } (conjunct) - - if (matchers.exists { case (id, args) => - args.exists(arg => arg match { - case Matcher(_, _) => false - case Variable(id) => false - case _ if (variablesOf(arg) & quantified).nonEmpty => true - case _ => false - }) - }) ctx.reporter.warning("Matcher arguments must have simple form in " + conjunct) - - val id2Quant = matchers.foldLeft(Map.empty[Identifier, Set[Identifier]]) { - case (acc, (m, args)) => acc + (m -> (acc.getOrElse(m, Set.empty) ++ args.flatMap { - case Variable(id) if quantified(id) => Set(id) - case _ => Set.empty[Identifier] - })) - } - - if (id2Quant.filter(_._2.nonEmpty).groupBy(_._2).size != 1) - ctx.reporter.warning("Multiple matchers must provide bijective matching in " + conjunct) - - foldRight[Set[Identifier]] { case (m, children) => - val q = children.toSet.flatten - - m match { - case Matcher(_, args) => - q -- args.flatMap { - case Variable(id) if quantified(id) => Set(id) - case _ => Set.empty[Identifier] - } - case LessThan(_: Variable, _: Variable) => q - case LessEquals(_: Variable, _: Variable) => q - case GreaterThan(_: Variable, _: Variable) => q - case GreaterEquals(_: Variable, _: Variable) => q - case And(_) => q - case Or(_) => q - case Implies(_, _) => q - case Operator(es, _) => - val vars = es.flatMap { - case Variable(id) => Set(id) - case _ => Set.empty[Identifier] - }.toSet - - if (!(q.isEmpty || (q.size == 1 && (vars & free).isEmpty))) - ctx.reporter.warning("Invalid operation " + m + " on quantified variables") - q -- vars - case Variable(id) if quantified(id) => Set(id) - case _ => q - } - } (conjunct) - } - } - } - } -} diff --git a/src/main/scala/leon/purescala/Constructors.scala b/src/main/scala/leon/purescala/Constructors.scala index 9d347eccf..7747ec680 100644 --- a/src/main/scala/leon/purescala/Constructors.scala +++ b/src/main/scala/leon/purescala/Constructors.scala @@ -273,23 +273,6 @@ object Constructors { NonemptyArray(els.zipWithIndex.map{ _.swap }.toMap, defaultLength) } - /** Takes a mapping from keys to values and a default expression and return a lambda of the form - * {{{ - * (x1, ..., xn) => - * if ( key1 == (x1, ..., xn) ) value1 - * else if ( key2 == (x1, ..., xn) ) value2 - * ... - * else default - * }}} - */ - def finiteLambda(default: Expr, els: Seq[(Expr, Expr)], inputTypes: Seq[TypeTree]): Lambda = { - val args = inputTypes map { tpe => ValDef(FreshIdentifier("x", tpe, true)) } - val argsExpr = tupleWrap(args map { _.toVariable }) - val body = els.foldRight(default) { case ((key, value), default) => - IfExpr(Equals(argsExpr, key), value, default) - } - Lambda(args, body) - } /** $encodingof simplified `... == ...` (equality). * @see [[purescala.Expressions.Equals Equals]] */ diff --git a/src/main/scala/leon/purescala/Definitions.scala b/src/main/scala/leon/purescala/Definitions.scala index 319d0d610..7a8056c6a 100644 --- a/src/main/scala/leon/purescala/Definitions.scala +++ b/src/main/scala/leon/purescala/Definitions.scala @@ -71,7 +71,7 @@ object Definitions { lazy val library = Library(this) def subDefinitions = units - + def definedFunctions = units.flatMap(_.definedFunctions) def definedClasses = units.flatMap(_.definedClasses) def classHierarchyRoots = units.flatMap(_.classHierarchyRoots) @@ -81,7 +81,7 @@ object Definitions { case md: ModuleDef => md }) } - + lazy val callGraph = new CallGraph(this) def caseClassDef(name: String) = definedClasses.collectFirst { diff --git a/src/main/scala/leon/purescala/ExprOps.scala b/src/main/scala/leon/purescala/ExprOps.scala index 3f641000a..1d98f57b8 100644 --- a/src/main/scala/leon/purescala/ExprOps.scala +++ b/src/main/scala/leon/purescala/ExprOps.scala @@ -26,7 +26,7 @@ import solvers._ * - [[ExprOps.postMap postMap]] * - [[ExprOps.genericTransform genericTransform]] * - * These operations usually take a higher order function that gets apply to the + * These operations usually take a higher order function that gets applied to the * expression tree in some strategy. They provide an expressive way to build complex * operations on Leon expressions. * @@ -317,8 +317,8 @@ object ExprOps { case Variable(i) => subvs + i case LetDef(fd,_) => subvs -- fd.params.map(_.id) case Let(i,_,_) => subvs - i - case MatchExpr(_, cses) => subvs -- cses.flatMap(_.pattern.binders) - case Passes(_, _ , cses) => subvs -- cses.flatMap(_.pattern.binders) + case MatchExpr(_, cses) => subvs -- cses.flatMap(_.pattern.binders) + case Passes(_, _, cses) => subvs -- cses.flatMap(_.pattern.binders) case Lambda(args, _) => subvs -- args.map(_.id) case Forall(args, _) => subvs -- args.map(_.id) case _ => subvs @@ -369,19 +369,23 @@ object ExprOps { }).setPos(expr) } + def replacePatternBinders(pat: Pattern, subst: Map[Identifier, Identifier]): Pattern = { + def rec(p: Pattern): Pattern = p match { + case InstanceOfPattern(ob, ctd) => InstanceOfPattern(ob map subst, ctd) + case WildcardPattern(ob) => WildcardPattern(ob map subst) + case TuplePattern(ob, sps) => TuplePattern(ob map subst, sps map rec) + case CaseClassPattern(ob, ccd, sps) => CaseClassPattern(ob map subst, ccd, sps map rec) + case UnapplyPattern(ob, obj, sps) => UnapplyPattern(ob map subst, obj, sps map rec) + case LiteralPattern(ob, lit) => LiteralPattern(ob map subst, lit) + } + + rec(pat) + } + /** ATTENTION: Unused, and untested * rewrites pattern-matching expressions to use fresh variables for the binders */ def freshenLocals(expr: Expr) : Expr = { - def rewritePattern(p: Pattern, sm: Map[Identifier,Identifier]) : Pattern = p match { - case InstanceOfPattern(ob, ctd) => InstanceOfPattern(ob map sm, ctd) - case WildcardPattern(ob) => WildcardPattern(ob map sm) - case TuplePattern(ob, sps) => TuplePattern(ob.map(sm(_)), sps.map(rewritePattern(_, sm))) - case CaseClassPattern(ob, ccd, sps) => CaseClassPattern(ob.map(sm(_)), ccd, sps.map(rewritePattern(_, sm))) - case UnapplyPattern(ob, obj, sps) => UnapplyPattern(ob.map(sm(_)), obj, sps.map(rewritePattern(_, sm))) - case LiteralPattern(ob, lit) => LiteralPattern(ob map sm, lit) - } - def freshenCase(cse: MatchCase) : MatchCase = { val allBinders: Set[Identifier] = cse.pattern.binders val subMap: Map[Identifier,Identifier] = @@ -389,7 +393,7 @@ object ExprOps { val subVarMap: Map[Expr,Expr] = subMap.map(kv => Variable(kv._1) -> Variable(kv._2)) MatchCase( - rewritePattern(cse.pattern, subMap), + replacePatternBinders(cse.pattern, subMap), cse.optGuard map { replace(subVarMap, _)}, replace(subVarMap,cse.rhs) ) @@ -463,6 +467,63 @@ object ExprOps { fixpoint(postMap(rec))(expr) } + private val typedIds: scala.collection.mutable.Map[TypeTree, List[Identifier]] = + scala.collection.mutable.Map.empty.withDefaultValue(List.empty) + + /** Normalizes identifiers in an expression to enable some notion of structural + * equality between expressions on which usual equality doesn't make sense + * (i.e. closures). + * + * This function relies on the static map `typedIds` to ensure identical + * structures and must therefore be synchronized. + */ + def normalizeStructure(expr: Expr): (Expr, Map[Identifier, Identifier]) = synchronized { + val allVars : Seq[Identifier] = foldRight[Seq[Identifier]] { + (expr, idSeqs) => idSeqs.foldLeft(expr match { + case Lambda(args, _) => args.map(_.id) + case Forall(args, _) => args.map(_.id) + case LetDef(fd, _) => fd.params.map(_.id) + case Let(i, _, _) => Seq(i) + case MatchExpr(_, cses) => cses.flatMap(_.pattern.binders) + case Passes(_, _, cses) => cses.flatMap(_.pattern.binders) + case Variable(id) => Seq(id) + case _ => Seq.empty[Identifier] + })((acc, seq) => acc ++ seq) + } (expr).distinct + + val grouped : Map[TypeTree, Seq[Identifier]] = allVars.groupBy(_.getType) + val subst = grouped.foldLeft(Map.empty[Identifier, Identifier]) { case (subst, (tpe, ids)) => + val currentVars = typedIds(tpe) + + val freshCount = ids.size - currentVars.size + val typedVars = if (freshCount > 0) { + val allIds = currentVars ++ List.range(0, freshCount).map(_ => FreshIdentifier("x", tpe, true)) + typedIds += tpe -> allIds + allIds + } else { + currentVars + } + + subst ++ (ids zip typedVars) + } + + val normalized = postMap { + case Lambda(args, body) => Some(Lambda(args.map(vd => ValDef(subst(vd.id), vd.tpe)), body)) + case Forall(args, body) => Some(Forall(args.map(vd => ValDef(subst(vd.id), vd.tpe)), body)) + case Let(i, e, b) => Some(Let(subst(i), e, b)) + case MatchExpr(scrut, cses) => Some(MatchExpr(scrut, cses.map { cse => + cse.copy(pattern = replacePatternBinders(cse.pattern, subst)) + })) + case Passes(in, out, cses) => Some(Passes(in, out, cses.map { cse => + cse.copy(pattern = replacePatternBinders(cse.pattern, subst)) + })) + case Variable(id) => Some(Variable(subst(id))) + case _ => None + } (expr) + + (normalized, subst) + } + /** Returns '''true''' if the formula is Ground, * which means that it does not contain any variable ([[purescala.ExprOps#variablesOf]] e is empty) * and [[purescala.ExprOps#isDeterministic isDeterministic]] @@ -1244,7 +1305,7 @@ object ExprOps { } /** Returns the value for an identifier given a model. */ - def valuateWithModel(model: Map[Identifier, Expr])(id: Identifier): Expr = { + def valuateWithModel(model: Model)(id: Identifier): Expr = { model.getOrElse(id, simplestValue(id.getType)) } @@ -1252,7 +1313,7 @@ object ExprOps { * * Complete with simplest values in case of incomplete model. */ - def valuateWithModelIn(expr: Expr, vars: Set[Identifier], model: Map[Identifier, Expr]): Expr = { + def valuateWithModelIn(expr: Expr, vars: Set[Identifier], model: Model): Expr = { val valuator = valuateWithModel(model) _ replace(vars.map(id => Variable(id) -> valuator(id)).toMap, expr) } diff --git a/src/main/scala/leon/purescala/Expressions.scala b/src/main/scala/leon/purescala/Expressions.scala index f738c3f38..c807ef816 100644 --- a/src/main/scala/leon/purescala/Expressions.scala +++ b/src/main/scala/leon/purescala/Expressions.scala @@ -226,6 +226,10 @@ object Expressions { } } + case class PartialLambda(mapping: Seq[(Seq[Expr], Expr)], tpe: FunctionType) extends Expr { + val getType = tpe + } + /* Universal Quantification */ case class Forall(args: Seq[ValDef], body: Expr) extends Expr { diff --git a/src/main/scala/leon/purescala/Extractors.scala b/src/main/scala/leon/purescala/Extractors.scala index c1e1ff439..d2ccfa7da 100644 --- a/src/main/scala/leon/purescala/Extractors.scala +++ b/src/main/scala/leon/purescala/Extractors.scala @@ -41,6 +41,20 @@ object Extractors { Some((Seq(a), (es: Seq[Expr]) => ArrayLength(es.head))) case Lambda(args, body) => Some((Seq(body), (es: Seq[Expr]) => Lambda(args, es.head))) + case PartialLambda(mapping, tpe) => + val sze = tpe.from.size + 1 + val subArgs = mapping.flatMap { case (args, v) => args :+ v } + val builder = (as: Seq[Expr]) => { + def rec(kvs: Seq[Expr]): Seq[(Seq[Expr], Expr)] = kvs match { + case seq if seq.size >= sze => + val ((args :+ res), rest) = seq.splitAt(sze) + (args -> res) +: rec(rest) + case Seq() => Seq.empty + case _ => sys.error("unexpected number of key/value expressions") + } + PartialLambda(rec(as), tpe) + } + Some((subArgs, builder)) case Forall(args, body) => Some((Seq(body), (es: Seq[Expr]) => Forall(args, es.head))) @@ -261,39 +275,6 @@ object Extractors { def unapply[T <: Typed](e: T): Option[(T, TypeTree)] = Some((e, e.getType)) } - /* - * Extract a default expression and key-value pairs from a lambda constructed with - * Constructors.finiteLambda - */ - object FiniteLambda { - def unapply(lambda: Lambda): Option[(Expr, Seq[(Expr, Expr)])] = { - val inSize = lambda.getType.asInstanceOf[FunctionType].from.size - val Lambda(args, body) = lambda - def step(e: Expr): (Option[(Expr, Expr)], Expr) = e match { - case IfExpr(Equals(argsExpr, key), value, default) if { - val formal = args.map{ _.id } - val real = unwrapTuple(argsExpr, inSize).collect{ case Variable(id) => id} - formal == real - } => - (Some((key, value)), default) - case other => - (None, other) - } - - def rec(e: Expr): (Expr, Seq[(Expr, Expr)]) = { - step(e) match { - case (None, default) => (default, Seq()) - case (Some(pair), default) => - val (defaultRest, pairs) = rec(default) - (defaultRest, pair +: pairs) - } - } - - Some(rec(body)) - - } - } - object FiniteArray { def unapply(e: Expr): Option[(Map[Int, Expr], Option[Expr], Expr)] = e match { case EmptyArray(_) => diff --git a/src/main/scala/leon/purescala/Quantification.scala b/src/main/scala/leon/purescala/Quantification.scala new file mode 100644 index 000000000..34392526d --- /dev/null +++ b/src/main/scala/leon/purescala/Quantification.scala @@ -0,0 +1,192 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package purescala + +import Common._ +import Definitions._ +import Expressions._ +import Extractors._ +import ExprOps._ +import Types._ + +object Quantification { + + def extractQuorums[A,B]( + matchers: Set[A], + quantified: Set[B], + margs: A => Set[A], + qargs: A => Set[B] + ): Seq[Set[A]] = { + def rec(oms: Seq[A], mSet: Set[A], qss: Seq[Set[B]]): Seq[Set[A]] = { + if (qss.exists(_ == quantified)) { + Seq(mSet) + } else { + var res = Seq.empty[Set[A]] + val rest = oms.scanLeft(List.empty[A])((acc, o) => o :: acc).drop(1) + for ((m :: ms) <- rest if margs(m).forall(mSet)) { + val qas = qargs(m) + if (!qss.exists(qs => qs.subsetOf(qas) || qas.subsetOf(qs))) { + res ++= rec(ms, mSet + m, qss ++ qss.map(_ ++ qas) :+ qas) + } + } + res + } + } + + def expand(m: A): Set[A] = Set(m) ++ margs(m).flatMap(expand) + val oms = matchers.toSeq.sortBy(m => -expand(m).size) + rec(oms, Set.empty, Seq.empty) + } + + def extractQuorums(expr: Expr, quantified: Set[Identifier]): Seq[Set[(Expr, Seq[Expr])]] = { + object QMatcher { + def unapply(e: Expr): Option[(Expr, Seq[Expr])] = e match { + case QuantificationMatcher(expr, args) => + if (args.exists { case Variable(id) => quantified(id) case _ => false }) { + Some(expr -> args) + } else { + None + } + case _ => None + } + } + + extractQuorums(collect { + case QMatcher(e, a) => Set(e -> a) + case _ => Set.empty[(Expr, Seq[Expr])] + } (expr), quantified, + (p: (Expr, Seq[Expr])) => p._2.collect { case QMatcher(e, a) => e -> a }.toSet, + (p: (Expr, Seq[Expr])) => p._2.collect { case Variable(id) if quantified(id) => id }.toSet) + } + + object HenkinDomains { + def empty = new HenkinDomains(Map.empty) + def apply(domains: Map[TypeTree, Set[Seq[Expr]]]) = new HenkinDomains(domains) + } + + class HenkinDomains (val domains: Map[TypeTree, Set[Seq[Expr]]]) { + def get(e: Expr): Set[Seq[Expr]] = e match { + case PartialLambda(mapping, _) => mapping.map(_._1).toSet + case _ => domains.get(e.getType) match { + case Some(domain) => domain + case None => scala.sys.error("Undefined Henkin domain for " + e) + } + } + } + + object QuantificationMatcher { + def unapply(expr: Expr): Option[(Expr, Seq[Expr])] = expr match { + case Application(_: Application | _: FunctionInvocation, _) => None + case Application(e, args) => Some(e -> args) + case ArraySelect(arr, index) => Some(arr -> Seq(index)) + case MapApply(map, key) => Some(map -> Seq(key)) + // case ElementOfSet(set, elem) => Some(set -> Seq(elem)) + case _ => None + } + } + + object QuantificationTypeMatcher { + def unapply(tpe: TypeTree): Option[(Seq[TypeTree], TypeTree)] = tpe match { + case FunctionType(from, to) => Some(from -> to) + case ArrayType(base) => Some(Seq(Int32Type) -> base) + case MapType(from, to) => Some(Seq(from) -> to) + case SetType(base) => Some(Seq(base) -> BooleanType) + case _ => None + } + } + + object CheckForalls extends UnitPhase[Program] { + + val name = "Foralls" + val description = "Check syntax of foralls to guarantee sound instantiations" + + def apply(ctx: LeonContext, program: Program) = { + program.definedFunctions.foreach { fd => + if (fd.body.exists(b => exists { + case f: Forall => true + case _ => false + } (b))) ctx.reporter.warning("Universal quantification in function bodies is not supported in " + fd) + + val foralls = (fd.precondition.toSeq ++ fd.postcondition.toSeq).flatMap { prec => + collect[Forall] { + case f: Forall => Set(f) + case _ => Set.empty + } (prec) + } + + val free = fd.params.map(_.id).toSet ++ (fd.postcondition match { + case Some(Lambda(args, _)) => args.map(_.id) + case _ => Seq.empty + }) + + object Matcher { + def unapply(e: Expr): Option[(Identifier, Seq[Expr])] = e match { + case QuantificationMatcher(Variable(id), args) if free(id) => Some(id -> args) + case _ => None + } + } + + for (Forall(args, TopLevelAnds(conjuncts)) <- foralls) { + val quantified = args.map(_.id).toSet + + for (conjunct <- conjuncts) { + val matchers = collect[(Identifier, Seq[Expr])] { + case Matcher(id, args) => Set(id -> args) + case _ => Set.empty + } (conjunct) + + if (matchers.exists { case (id, args) => + args.exists(arg => arg match { + case Matcher(_, _) => false + case Variable(id) => false + case _ if (variablesOf(arg) & quantified).nonEmpty => true + case _ => false + }) + }) ctx.reporter.warning("Matcher arguments must have simple form in " + conjunct) + + val id2Quant = matchers.foldLeft(Map.empty[Identifier, Set[Identifier]]) { + case (acc, (m, args)) => acc + (m -> (acc.getOrElse(m, Set.empty) ++ args.flatMap { + case Variable(id) if quantified(id) => Set(id) + case _ => Set.empty[Identifier] + })) + } + + if (id2Quant.filter(_._2.nonEmpty).groupBy(_._2).size != 1) + ctx.reporter.warning("Multiple matchers must provide bijective matching in " + conjunct) + + foldRight[Set[Identifier]] { case (m, children) => + val q = children.toSet.flatten + + m match { + case Matcher(_, args) => + q -- args.flatMap { + case Variable(id) if quantified(id) => Set(id) + case _ => Set.empty[Identifier] + } + case LessThan(_: Variable, _: Variable) => q + case LessEquals(_: Variable, _: Variable) => q + case GreaterThan(_: Variable, _: Variable) => q + case GreaterEquals(_: Variable, _: Variable) => q + case And(_) => q + case Or(_) => q + case Implies(_, _) => q + case Operator(es, _) => + val vars = es.flatMap { + case Variable(id) => Set(id) + case _ => Set.empty[Identifier] + }.toSet + + if (!(q.isEmpty || (q.size == 1 && (vars & free).isEmpty))) + ctx.reporter.warning("Invalid operation " + m + " on quantified variables") + q -- vars + case Variable(id) if quantified(id) => Set(id) + case _ => q + } + } (conjunct) + } + } + } + } + } +} diff --git a/src/main/scala/leon/repair/RepairTrackingEvaluator.scala b/src/main/scala/leon/repair/RepairTrackingEvaluator.scala index 87905111f..664b9e3b2 100644 --- a/src/main/scala/leon/repair/RepairTrackingEvaluator.scala +++ b/src/main/scala/leon/repair/RepairTrackingEvaluator.scala @@ -8,6 +8,7 @@ import leon.purescala.Common._ import leon.purescala.Expressions._ import leon.purescala.Types._ import leon.purescala.Definitions._ +import leon.purescala.Quantification._ import leon.LeonContext import leon.evaluators.RecursiveEvaluator @@ -21,7 +22,7 @@ class RepairTrackingEvaluator(ctx: LeonContext, prog: Program) extends Recursive type GC = GlobalContext def initRC(mappings: Map[Identifier, Expr]) = CollectingRecContext(mappings, None) - def initGC() = new GlobalContext() + def initGC(model: leon.solvers.Model) = new GlobalContext(model) type FI = (FunDef, Seq[Expr]) diff --git a/src/main/scala/leon/repair/Repairman.scala b/src/main/scala/leon/repair/Repairman.scala index e4cfb8aaf..d38af251c 100644 --- a/src/main/scala/leon/repair/Repairman.scala +++ b/src/main/scala/leon/repair/Repairman.scala @@ -9,6 +9,7 @@ import purescala.Expressions._ import purescala.ExprOps._ import purescala.Types._ import purescala.DefOps._ +import purescala.Quantification._ import purescala.Constructors._ import purescala.Extractors.unwrapTuple @@ -202,10 +203,11 @@ class Repairman(ctx0: LeonContext, initProgram: Program, fd: FunDef, verifTimeou case None => _ => true case Some(pre) => - evaluator.compile(pre, fd.params map { _.id }) match { + val argIds = fd.params.map(_.id) + evaluator.compile(pre, argIds) match { case Some(evalFun) => val sat = EvaluationResults.Successful(BooleanLiteral(true)); - { (e: Seq[Expr]) => evalFun(e) == sat } + { (es: Seq[Expr]) => evalFun(new solvers.Model((argIds zip es).toMap)) == sat } case None => { _ => false } } diff --git a/src/main/scala/leon/solvers/EnumerationSolver.scala b/src/main/scala/leon/solvers/EnumerationSolver.scala index 3a2db100e..d7d52d371 100644 --- a/src/main/scala/leon/solvers/EnumerationSolver.scala +++ b/src/main/scala/leon/solvers/EnumerationSolver.scala @@ -46,7 +46,7 @@ class EnumerationSolver(val context: LeonContext, val program: Program) extends datagen = None } - private var modelMap = Map[Identifier, Expr]() + private var model = Model.empty def check: Option[Boolean] = { val timer = context.timers.solvers.enum.check.start() @@ -55,15 +55,15 @@ class EnumerationSolver(val context: LeonContext, val program: Program) extends if (interrupted) { None } else { - modelMap = Map() - val allFreeVars = freeVars.toSet.toSeq.sortBy(_.name) + model = Model.empty + val allFreeVars = freeVars.toSeq.sortBy(_.name) val allConstraints = constraints.toSeq val it = datagen.get.generateFor(allFreeVars, andJoin(allConstraints), 1, maxTried) if (it.hasNext) { - val model = it.next - modelMap = (allFreeVars zip model).toMap + val varModels = it.next + model = new Model((allFreeVars zip varModels).toMap) Some(true) } else { None @@ -78,8 +78,8 @@ class EnumerationSolver(val context: LeonContext, val program: Program) extends res } - def getModel: Map[Identifier, Expr] = { - modelMap + def getModel: Model = { + model } def free() = { diff --git a/src/main/scala/leon/solvers/EvaluatingSolver.scala b/src/main/scala/leon/solvers/EvaluatingSolver.scala new file mode 100644 index 000000000..3463235c9 --- /dev/null +++ b/src/main/scala/leon/solvers/EvaluatingSolver.scala @@ -0,0 +1,19 @@ +package leon +package solvers + +import purescala.Definitions._ +import evaluators._ + +trait EvaluatingSolver extends Solver { + val context: LeonContext + val program: Program + + val useCodeGen: Boolean + + lazy val evaluator: Evaluator = + if (useCodeGen) { + new CodeGenEvaluator(context, program) + } else { + new DefaultEvaluator(context, program) + } +} diff --git a/src/main/scala/leon/solvers/GroundSolver.scala b/src/main/scala/leon/solvers/GroundSolver.scala index 29ee75238..f38ddd188 100644 --- a/src/main/scala/leon/solvers/GroundSolver.scala +++ b/src/main/scala/leon/solvers/GroundSolver.scala @@ -26,7 +26,7 @@ class GroundSolver(val context: LeonContext, val program: Program) extends Solve private val assertions = new IncrementalSeq[Expr]() // Ground terms will always have the empty model - def getModel: Map[Identifier, Expr] = Map() + def getModel: Model = Model.empty def assertCnstr(expression: Expr): Unit = { assertions += expression diff --git a/src/main/scala/leon/solvers/QuantificationSolver.scala b/src/main/scala/leon/solvers/QuantificationSolver.scala new file mode 100644 index 000000000..dc3e8584f --- /dev/null +++ b/src/main/scala/leon/solvers/QuantificationSolver.scala @@ -0,0 +1,30 @@ +package leon +package solvers + +import purescala.Common._ +import purescala.Expressions._ +import purescala.Quantification._ +import purescala.Types._ + +class HenkinModel(mapping: Map[Identifier, Expr], doms: HenkinDomains) + extends Model(mapping) + with AbstractModel[HenkinModel] { + override def newBuilder = new HenkinModelBuilder(doms) + + def domains: Map[TypeTree, Set[Seq[Expr]]] = doms.domains + def domain(expr: Expr) = doms.get(expr) +} + +object HenkinModel { + def empty = new HenkinModel(Map.empty, HenkinDomains.empty) +} + +class HenkinModelBuilder(domains: HenkinDomains) + extends ModelBuilder + with AbstractModelBuilder[HenkinModel] { + override def result = new HenkinModel(mapBuilder.result, domains) +} + +trait QuantificationSolver { + def getModel: HenkinModel +} diff --git a/src/main/scala/leon/solvers/SimpleSolverAPI.scala b/src/main/scala/leon/solvers/SimpleSolverAPI.scala index 37aefc1b5..33f6f1336 100644 --- a/src/main/scala/leon/solvers/SimpleSolverAPI.scala +++ b/src/main/scala/leon/solvers/SimpleSolverAPI.scala @@ -17,7 +17,7 @@ class SimpleSolverAPI(sf: SolverFactory[Solver]) { } } - def solveSAT(expression: Expr): (Option[Boolean], Map[Identifier, Expr]) = { + def solveSAT(expression: Expr): (Option[Boolean], Model) = { val s = sf.getNewSolver() try { s.assertCnstr(expression) @@ -25,16 +25,16 @@ class SimpleSolverAPI(sf: SolverFactory[Solver]) { case Some(true) => (Some(true), s.getModel) case Some(false) => - (Some(false), Map()) + (Some(false), Model.empty) case None => - (None, Map()) + (None, Model.empty) } } finally { sf.reclaim(s) } } - def solveSATWithCores(expression: Expr, assumptions: Set[Expr]): (Option[Boolean], Map[Identifier, Expr], Set[Expr]) = { + def solveSATWithCores(expression: Expr, assumptions: Set[Expr]): (Option[Boolean], Model, Set[Expr]) = { val s = sf.getNewSolver() try { s.assertCnstr(expression) @@ -42,9 +42,9 @@ class SimpleSolverAPI(sf: SolverFactory[Solver]) { case Some(true) => (Some(true), s.getModel, Set()) case Some(false) => - (Some(false), Map(), s.getUnsatCore) + (Some(false), Model.empty, s.getUnsatCore) case None => - (None, Map(), Set()) + (None, Model.empty, Set()) } } finally { sf.reclaim(s) diff --git a/src/main/scala/leon/solvers/Solver.scala b/src/main/scala/leon/solvers/Solver.scala index a9018b2de..3188031e9 100644 --- a/src/main/scala/leon/solvers/Solver.scala +++ b/src/main/scala/leon/solvers/Solver.scala @@ -5,9 +5,72 @@ package solvers import utils.{DebugSectionSolver, Interruptible} import purescala.Expressions._ -import leon.purescala.Common.{Tree, Identifier} +import purescala.Common.{Tree, Identifier} +import purescala.ExprOps._ import verification.VC +trait AbstractModel[+This <: Model with AbstractModel[This]] + extends scala.collection.IterableLike[(Identifier, Expr), This] { + + protected val mapping: Map[Identifier, Expr] + + def fill(allVars: Iterable[Identifier]): This = { + val builder = newBuilder + builder ++= mapping ++ (allVars.toSet -- mapping.keys).map(id => id -> simplestValue(id.getType)) + builder.result + } + + def ++(mapping: Map[Identifier, Expr]): This = { + val builder = newBuilder + builder ++= this.mapping ++ mapping + builder.result + } + + def filter(allVars: Iterable[Identifier]): This = { + val builder = newBuilder + for (p <- mapping.filterKeys(allVars.toSet)) { + builder += p + } + builder.result + } + + def iterator = mapping.iterator + def seq = mapping.seq +} + +trait AbstractModelBuilder[+This <: Model with AbstractModel[This]] + extends scala.collection.mutable.Builder[(Identifier, Expr), This] { + + import scala.collection.mutable.MapBuilder + protected val mapBuilder = new MapBuilder[Identifier, Expr, Map[Identifier, Expr]](Map.empty) + + def +=(elem: (Identifier, Expr)): this.type = { + mapBuilder += elem + this + } + + def clear(): Unit = mapBuilder.clear +} + +class Model(protected val mapping: Map[Identifier, Expr]) + extends AbstractModel[Model] + with (Identifier => Expr) { + + def newBuilder = new ModelBuilder + def isDefinedAt(id: Identifier): Boolean = mapping.isDefinedAt(id) + def get(id: Identifier): Option[Expr] = mapping.get(id) + def getOrElse[E >: Expr](id: Identifier, e: E): E = get(id).getOrElse(e) + def apply(id: Identifier): Expr = get(id).getOrElse { throw new IllegalArgumentException } +} + +object Model { + def empty = new Model(Map.empty) +} + +class ModelBuilder extends AbstractModelBuilder[Model] { + def result = new Model(mapBuilder.result) +} + trait Solver extends Interruptible { def name: String val context: LeonContext @@ -20,7 +83,7 @@ trait Solver extends Interruptible { } def check: Option[Boolean] - def getModel: Map[Identifier, Expr] + def getModel: Model def getResultSolver: Option[Solver] = Some(this) def free() diff --git a/src/main/scala/leon/solvers/combinators/PortfolioSolver.scala b/src/main/scala/leon/solvers/combinators/PortfolioSolver.scala index 9997d5176..429bc24b8 100644 --- a/src/main/scala/leon/solvers/combinators/PortfolioSolver.scala +++ b/src/main/scala/leon/solvers/combinators/PortfolioSolver.scala @@ -21,7 +21,7 @@ class PortfolioSolver[S <: Solver with Interruptible](val context: LeonContext, var constraints = List[Expr]() - protected var modelMap = Map[Identifier, Expr]() + protected var model = Model.empty protected var resultSolver: Option[Solver] = None override def getResultSolver = resultSolver @@ -35,17 +35,17 @@ class PortfolioSolver[S <: Solver with Interruptible](val context: LeonContext, } def check: Option[Boolean] = { - modelMap = Map() + model = Model.empty context.reporter.debug("Running portfolio check") // solving val fs = solvers.map { s => Future { val result = s.check - val model: Map[Identifier, Expr] = if (result == Some(true)) { + val model: Model = if (result == Some(true)) { s.getModel } else { - Map() + Model.empty } (s, result, model) } @@ -55,7 +55,7 @@ class PortfolioSolver[S <: Solver with Interruptible](val context: LeonContext, val res = Await.result(result, Duration.Inf) match { case Some((s, r, m)) => - modelMap = m + model = m resultSolver = s.getResultSolver resultSolver.foreach { solv => context.reporter.debug("Solved with "+solv) @@ -80,12 +80,12 @@ class PortfolioSolver[S <: Solver with Interruptible](val context: LeonContext, def free() = { solvers.foreach(_.free) - modelMap = Map() + model = Model.empty constraints = Nil } - def getModel: Map[Identifier, Expr] = { - modelMap + def getModel: Model = { + model } def interrupt(): Unit = { @@ -98,7 +98,7 @@ class PortfolioSolver[S <: Solver with Interruptible](val context: LeonContext, def reset() = { solvers.foreach(_.reset) - modelMap = Map() + model = Model.empty constraints = Nil } } diff --git a/src/main/scala/leon/solvers/combinators/RewritingSolver.scala b/src/main/scala/leon/solvers/combinators/RewritingSolver.scala index 8aeb812fc..2c414f2c0 100644 --- a/src/main/scala/leon/solvers/combinators/RewritingSolver.scala +++ b/src/main/scala/leon/solvers/combinators/RewritingSolver.scala @@ -12,19 +12,19 @@ abstract class RewritingSolver[+S <: Solver, T](underlying: S) { /** The type T is used to encode any meta information useful, for instance, to reconstruct * models. */ - def rewriteCnstr(expression : Expr) : (Expr,T) + def rewriteCnstr(expression: Expr) : (Expr,T) - def reconstructModel(model : Map[Identifier,Expr], meta : T) : Map[Identifier,Expr] + def reconstructModel(model: Model, meta: T) : Model private var storedMeta : List[T] = Nil - def assertCnstr(expression : Expr) { + def assertCnstr(expression: Expr) { val (rewritten, meta) = rewriteCnstr(expression) storedMeta = meta :: storedMeta underlying.assertCnstr(rewritten) } - def getModel : Map[Identifier,Expr] = { + def getModel: Model = { storedMeta match { case Nil => underlying.getModel case m :: _ => reconstructModel(underlying.getModel, m) diff --git a/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala b/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala index aa936b283..b5cd0a04a 100644 --- a/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala +++ b/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala @@ -6,9 +6,11 @@ package combinators import purescala.Common._ import purescala.Definitions._ +import purescala.Quantification._ import purescala.Constructors._ import purescala.Expressions._ import purescala.ExprOps._ +import purescala.Types._ import utils._ import z3.FairZ3Component.{optFeelingLucky, optUseCodeGen, optAssumePre} @@ -16,20 +18,17 @@ import templates._ import utils.Interruptible import evaluators._ -class UnrollingSolver(val context: LeonContext, program: Program, underlying: Solver) extends Solver with NaiveAssumptionSolver { +class UnrollingSolver(val context: LeonContext, val program: Program, underlying: Solver) + extends Solver + with NaiveAssumptionSolver + with EvaluatingSolver + with QuantificationSolver { + val feelingLucky = context.findOptionOrDefault(optFeelingLucky) val useCodeGen = context.findOptionOrDefault(optUseCodeGen) val assumePreHolds = context.findOptionOrDefault(optAssumePre) - private val evaluator : Evaluator = { - if(useCodeGen) { - new CodeGenEvaluator(context, program) - } else { - new DefaultEvaluator(context, program) - } - } - - protected var lastCheckResult : (Boolean, Option[Boolean], Option[Map[Identifier,Expr]]) = (false, None, None) + protected var lastCheckResult : (Boolean, Option[Boolean], Option[HenkinModel]) = (false, None, None) private val freeVars = new IncrementalSet[Identifier]() private val constraints = new IncrementalSeq[Expr]() @@ -106,17 +105,61 @@ class UnrollingSolver(val context: LeonContext, program: Program, underlying: So def hasFoundAnswer = lastCheckResult._1 - def foundAnswer(res: Option[Boolean], model: Option[Map[Identifier, Expr]] = None) = { + private def extractModel(model: Model): HenkinModel = { + val allVars = freeVars.toSet + + def extract(b: Expr, m: Matcher[Expr]): Set[Seq[Expr]] = { + val QuantificationTypeMatcher(fromTypes, _) = m.tpe + val optEnabler = evaluator.eval(b).result + val optArgs = m.args.map(arg => evaluator.eval(Matcher.argValue(arg)).result) + if (optEnabler == Some(BooleanLiteral(true)) && optArgs.forall(_.isDefined)) { + Set(optArgs.map(_.get)) + } else { + Set.empty + } + } + + val funDomains = allVars.flatMap(id => id.getType match { + case ft @ FunctionType(fromTypes, _) => + Some(id -> templateGenerator.manager.instantiations(Variable(id), ft).flatMap { + case (b, m) => extract(b, m) + }) + case _ => None + }).toMap.mapValues(_.toSet) + + val asDMap = model.map(p => funDomains.get(p._1) match { + case Some(domain) => + val mapping = domain.toSeq.map { es => + val ev: Expr = p._2 match { + case RawArrayValue(_, mapping, dflt) => + mapping.collectFirst { + case (k,v) if evaluator.eval(Equals(k, tupleWrap(es))).result == Some(BooleanLiteral(true)) => v + } getOrElse dflt + case _ => scala.sys.error("Unexpected function encoding " + p._2) + } + es -> ev + } + + p._1 -> PartialLambda(mapping, p._1.getType.asInstanceOf[FunctionType]) + case None => p + }).toMap + + val typeGrouped = templateGenerator.manager.instantiations.groupBy(_._2.tpe) + val typeDomains = typeGrouped.mapValues(_.flatMap { case (b, m) => extract(b, m) }.toSet) + + val domains = new HenkinDomains(typeDomains) + new HenkinModel(asDMap, domains) + } + + def foundAnswer(res: Option[Boolean], model: Option[HenkinModel] = None) = { lastCheckResult = (true, res, model) } - def isValidModel(model: Map[Identifier, Expr], silenceErrors: Boolean = false): Boolean = { + def isValidModel(model: HenkinModel, silenceErrors: Boolean = false): Boolean = { import EvaluationResults._ val expr = andJoin(constraints.toSeq) - val allVars = freeVars.toSet - - val fullModel = allVars.map(v => v -> model.getOrElse(v, simplestValue(v.getType))).toMap + val fullModel = model fill freeVars.toSet evaluator.eval(expr, fullModel) match { case Successful(BooleanLiteral(true)) => @@ -152,7 +195,7 @@ class UnrollingSolver(val context: LeonContext, program: Program, underlying: So reporter.debug(" - Running search...") solver.push() - solver.assertCnstr(andJoin((assumptions ++ unrollingBank.currentBlockers ++ unrollingBank.quantificationAssumptions).toSeq)) + solver.assertCnstr(andJoin((assumptions ++ unrollingBank.satisfactionAssumptions).toSeq)) val res = solver.check reporter.debug(" - Finished search with blocked literals") @@ -167,7 +210,7 @@ class UnrollingSolver(val context: LeonContext, program: Program, underlying: So foundAnswer(None) case Some(true) => // SAT - val model = solver.getModel + val model = extractModel(solver.getModel) solver.pop() foundAnswer(Some(true), Some(model)) @@ -188,7 +231,7 @@ class UnrollingSolver(val context: LeonContext, program: Program, underlying: So } solver.push() - solver.assertCnstr(andJoin(assumptions.toSeq ++ unrollingBank.quantificationAssumptions)) + solver.assertCnstr(andJoin(assumptions.toSeq ++ unrollingBank.refutationAssumptions)) val res2 = solver.check res2 match { @@ -198,7 +241,7 @@ class UnrollingSolver(val context: LeonContext, program: Program, underlying: So case Some(true) => if (feelingLucky && !interrupted) { - val model = solver.getModel + val model = extractModel(solver.getModel) // we might have been lucky :D if (isValidModel(model, silenceErrors = true)) { @@ -241,13 +284,12 @@ class UnrollingSolver(val context: LeonContext, program: Program, underlying: So } } - def getModel: Map[Identifier,Expr] = { - val allVars = freeVars.toSet + def getModel: HenkinModel = { lastCheckResult match { case (true, Some(true), Some(m)) => - m.filterKeys(allVars) + m.filter(freeVars.toSet) case _ => - Map() + HenkinModel.empty } } diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4ProofSolver.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4ProofSolver.scala index 86da0be80..1f0b760d7 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4ProofSolver.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4ProofSolver.scala @@ -37,7 +37,7 @@ class SMTLIBCVC4ProofSolver(context: LeonContext, program: Program) extends SMTL } // This solver does not support model extraction - override def getModel: Map[Identifier, Expr] = { + override def getModel: solvers.Model = { // We don't send the error through reporter because it may be caught by PortfolioSolver throw LeonFatalError(Some(s"Solver $name does not support model extraction.")) } diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Solver.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Solver.scala index fa11aa421..7df712491 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Solver.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Solver.scala @@ -76,16 +76,16 @@ class SMTLIBCVC4Solver(context: LeonContext, program: Program) extends SMTLIBSol case (FunctionApplication(SimpleSymbol(SSymbol("__array_store_all__")), Seq(_, elem)), RawArrayType(k,v)) => RawArrayValue(k, Map(), fromSMT(elem, v)) - case (FunctionApplication(SimpleSymbol(SSymbol("__array_store_all__")), Seq(_, elem)), ft @ FunctionType(from,to)) => - finiteLambda(fromSMT(elem, to), Seq.empty, from) + case (FunctionApplication(SimpleSymbol(SSymbol("__array_store_all__")), Seq(_, elem)), FunctionType(from,to)) => + RawArrayValue(tupleTypeWrap(from), Map(), fromSMT(elem, to)) case (FunctionApplication(SimpleSymbol(SSymbol("store")), Seq(arr, key, elem)), RawArrayType(k,v)) => val RawArrayValue(_, elems, base) = fromSMT(arr, tpe) RawArrayValue(k, elems + (fromSMT(key, k) -> fromSMT(elem, v)), base) - case (FunctionApplication(SimpleSymbol(SSymbol("store")), Seq(arr, key, elem)), ft @ FunctionType(from,to)) => - val FiniteLambda(dflt, mapping) = fromSMT(arr, tpe) - finiteLambda(dflt, mapping :+ (fromSMT(key, tupleTypeWrap(from)) -> fromSMT(elem, to)), from) + case (FunctionApplication(SimpleSymbol(SSymbol("store")), Seq(arr, key, elem)), FunctionType(from,to)) => + val RawArrayValue(k, elems, base) = fromSMT(arr, tpe) + RawArrayValue(k, elems + (fromSMT(key, k) -> fromSMT(elem, to)), base) case (FunctionApplication(SimpleSymbol(SSymbol("singleton")), elems), SetType(base)) => FiniteSet(elems.map(fromSMT(_, base)).toSet, base) @@ -103,8 +103,8 @@ class SMTLIBCVC4Solver(context: LeonContext, program: Program) extends SMTLIBSol // FIXME (nicolas) // some versions of CVC4 seem to generate array constants with "as const" notation instead of the __array_store_all__ // one I've witnessed up to now. Don't know why this is happening... - case (FunctionApplication(QualifiedIdentifier(SMTIdentifier(SSymbol("const"), _), _), Seq(elem)), ft @ FunctionType(from, to)) => - finiteLambda(fromSMT(elem, to), Seq.empty, from) + case (FunctionApplication(QualifiedIdentifier(SMTIdentifier(SSymbol("const"), _), _), Seq(elem)), FunctionType(from, to)) => + RawArrayValue(tupleTypeWrap(from), Map(), fromSMT(elem, to)) case (FunctionApplication(QualifiedIdentifier(SMTIdentifier(SSymbol("const"), _), _), Seq(elem)), RawArrayType(k, v)) => RawArrayValue(k, Map(), fromSMT(elem, v)) diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBQuantifiedSolver.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBQuantifiedSolver.scala index 0d761fe21..f27838071 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBQuantifiedSolver.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBQuantifiedSolver.scala @@ -54,7 +54,7 @@ trait SMTLIBQuantifiedSolver extends SMTLIBSolver { // Normally, UnrollingSolver tracks the input variable, but this one // is invoked alone so we have to filter them here - override def getModel: Map[Identifier, Expr] = { + override def getModel: leon.solvers.Model = { val filter = currentFunDef.map{ _.params.map{_.id}.toSet }.getOrElse( (_:Identifier) => true ) getModel(filter) } diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala index f6e00625a..aa18becd2 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala @@ -28,8 +28,9 @@ import _root_.smtlib.parser.CommandsResponses.{Error => ErrorResponse, _} import _root_.smtlib.theories._ import _root_.smtlib.{Interpreter => SMTInterpreter} -abstract class SMTLIBSolver(val context: LeonContext, - val program: Program) extends Solver with NaiveAssumptionSolver { +abstract class SMTLIBSolver(val context: LeonContext, val program: Program) + extends Solver + with NaiveAssumptionSolver { /* Solver name */ def targetName: String @@ -46,7 +47,6 @@ abstract class SMTLIBSolver(val context: LeonContext, protected val interpreter = getNewInterpreter(context) - /* Printing VCs */ protected lazy val out: Option[java.io.FileWriter] = if (reporter.isDebugEnabled) Some { val file = context.files.headOption.map(_.getName).getOrElse("NA") @@ -171,8 +171,8 @@ abstract class SMTLIBSolver(val context: LeonContext, case RawArrayType(from, to) => r - case ft @ FunctionType(from, to) => - finiteLambda(r.default, r.elems.toSeq, from) + case FunctionType(from, to) => + r case MapType(from, to) => // We expect a RawArrayValue with keys in from and values in Option[to], @@ -742,10 +742,10 @@ abstract class SMTLIBSolver(val context: LeonContext, } } - protected def getModel(filter: Identifier => Boolean): Map[Identifier, Expr] = { + protected def getModel(filter: Identifier => Boolean): Model = { val syms = variables.aSet.filter(filter).toList.map(variables.aToB) if (syms.isEmpty) { - Map() + Model.empty } else { val cmd: Command = GetValue( syms.head, @@ -754,20 +754,20 @@ abstract class SMTLIBSolver(val context: LeonContext, sendCommand(cmd) match { case GetValueResponseSuccess(valuationPairs) => - - valuationPairs.collect { + new Model(valuationPairs.collect { case (SimpleSymbol(sym), value) if variables.containsB(sym) => val id = variables.toA(sym) (id, fromSMT(value, id.getType)(Map(), Map())) - }.toMap + }.toMap) + case _ => - Map() //FIXME improve this + Model.empty //FIXME improve this } } } - override def getModel: Map[Identifier, Expr] = getModel( _ => true) + override def getModel: Model = getModel( _ => true) override def push(): Unit = { constructors.push() diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Solver.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Solver.scala index 759dedbd8..ec91138e3 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Solver.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Solver.scala @@ -159,7 +159,7 @@ class SMTLIBZ3Solver(context: LeonContext, program: Program) extends SMTLIBSolve } // EK: We use get-model instead in order to extract models for arrays - override def getModel: Map[Identifier, Expr] = { + override def getModel: Model = { val cmd = GetModel() @@ -199,8 +199,7 @@ class SMTLIBZ3Solver(context: LeonContext, program: Program) extends SMTLIBSolve case _ => } - - model + new Model(model) } object ArrayMap { diff --git a/src/main/scala/leon/solvers/templates/LambdaManager.scala b/src/main/scala/leon/solvers/templates/LambdaManager.scala index 3b355a4b7..3d5eec72c 100644 --- a/src/main/scala/leon/solvers/templates/LambdaManager.scala +++ b/src/main/scala/leon/solvers/templates/LambdaManager.scala @@ -9,51 +9,23 @@ import purescala.Expressions._ import purescala.ExprOps._ import purescala.Types._ +import utils._ import Instantiation._ -class LambdaManager[T](protected val encoder: TemplateEncoder[T]) { +class LambdaManager[T](protected val encoder: TemplateEncoder[T]) extends IncrementalState { - protected type IdMap = Map[T, LambdaTemplate[T]] - protected def byID : IdMap = byIDStack.head - private var byIDStack : List[IdMap] = List(Map.empty) - private def byID_=(map: IdMap) : Unit = { - byIDStack = map :: byIDStack.tail - } + protected val byID = new IncrementalMap[T, LambdaTemplate[T]] + protected val byType = new IncrementalMap[FunctionType, Set[(T, LambdaTemplate[T])]].withDefaultValue(Set.empty) + protected val applications = new IncrementalMap[FunctionType, Set[(T, App[T])]].withDefaultValue(Set.empty) + protected val freeLambdas = new IncrementalMap[FunctionType, Set[T]].withDefaultValue(Set.empty) - protected type TypeMap = Map[FunctionType, Set[(T, LambdaTemplate[T])]] - protected def byType : TypeMap = byTypeStack.head - private var byTypeStack : List[TypeMap] = List(Map.empty.withDefaultValue(Set.empty)) - private def byType_=(map: TypeMap) : Unit = { - byTypeStack = map :: byTypeStack.tail - } + protected def incrementals: List[IncrementalState] = + List(byID, byType, applications, freeLambdas) - protected type ApplicationMap = Map[FunctionType, Set[(T, App[T])]] - protected def applications : ApplicationMap = applicationsStack.head - private var applicationsStack : List[ApplicationMap] = List(Map.empty.withDefaultValue(Set.empty)) - private def applications_=(map: ApplicationMap) : Unit = { - applicationsStack = map :: applicationsStack.tail - } - - protected type FreeMap = Map[FunctionType, Set[T]] - protected def freeLambdas : FreeMap = freeLambdasStack.head - private var freeLambdasStack : List[FreeMap] = List(Map.empty.withDefaultValue(Set.empty)) - private def freeLambdas_=(map: FreeMap) : Unit = { - freeLambdasStack = map :: freeLambdasStack.tail - } - - def push(): Unit = { - byIDStack = byID :: byIDStack - byTypeStack = byType :: byTypeStack - applicationsStack = applications :: applicationsStack - freeLambdasStack = freeLambdas :: freeLambdasStack - } - - def pop(lvl: Int): Unit = { - byIDStack = byIDStack.drop(lvl) - byTypeStack = byTypeStack.drop(lvl) - applicationsStack = applicationsStack.drop(lvl) - freeLambdasStack = freeLambdasStack.drop(lvl) - } + def clear(): Unit = incrementals.foreach(_.clear()) + def reset(): Unit = incrementals.foreach(_.reset()) + def push(): Unit = incrementals.foreach(_.push()) + def pop(): Unit = incrementals.foreach(_.pop()) def registerFree(lambdas: Seq[(TypeTree, T)]): Unit = { for ((tpe, idT) <- lambdas) tpe match { diff --git a/src/main/scala/leon/solvers/templates/QuantificationManager.scala b/src/main/scala/leon/solvers/templates/QuantificationManager.scala index a420b817b..74ffe69e1 100644 --- a/src/main/scala/leon/solvers/templates/QuantificationManager.scala +++ b/src/main/scala/leon/solvers/templates/QuantificationManager.scala @@ -4,8 +4,10 @@ package leon package solvers package templates +import leon.utils._ import purescala.Common._ import purescala.Extractors._ +import purescala.Constructors._ import purescala.Expressions._ import purescala.ExprOps._ import purescala.Types._ @@ -88,47 +90,29 @@ object QuantificationTemplate { class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManager[T](encoder) { - private val nextQ: () => T = { - val id: Identifier = FreshIdentifier("q", BooleanType, true) - () => encoder.encodeId(id) - } + private val quantifications = new IncrementalSeq[Quantification] + private val instantiated = new IncrementalSet[(T, Matcher[T])] + private val known = new IncrementalSet[T] - private var quantificationsStack: List[Seq[Quantification]] = List(Seq.empty) - private def quantifications: Seq[Quantification] = quantificationsStack.head - private def quantifications_=(qs: Seq[Quantification]): Unit = { - quantificationsStack = qs :: quantificationsStack.tail + private def correspond(qm: Matcher[T], m: Matcher[T]): Boolean = correspond(qm, m.caller, m.tpe) + private def correspond(qm: Matcher[T], caller: T, tpe: TypeTree): Boolean = qm.tpe match { + case _: FunctionType => qm.tpe == tpe && (qm.caller == caller || !known(caller)) + case _ => qm.tpe == tpe } - private var instantiatedStack: List[Set[(T, Matcher[T])]] = List(Set.empty) - private def instantiated: Set[(T, Matcher[T])] = instantiatedStack.head - private def instantiated_=(ias: Set[(T, Matcher[T])]): Unit = { - instantiatedStack = ias :: instantiatedStack.tail - } + override protected def incrementals: List[IncrementalState] = + List(quantifications, instantiated, known) ++ super.incrementals - private var knownStack: List[Set[T]] = List(Set.empty) - private def known(idT: T): Boolean = knownStack.head(idT) || byID.isDefinedAt(idT) - private def correspond(qm: Matcher[T], m: Matcher[T]): Boolean = qm.tpe match { - case _: FunctionType => qm.tpe == m.tpe && (qm.caller == m.caller || !known(m.caller)) - case _ => qm.tpe == m.tpe - } + def assumptions: Seq[T] = quantifications.map(_.currentQ2Var).toSeq - override def push(): Unit = { - quantificationsStack = quantifications :: quantificationsStack - instantiatedStack = instantiated :: instantiatedStack - knownStack = knownStack.head :: knownStack - } + def instantiations: Seq[(T, Matcher[T])] = instantiated.toSeq - override def pop(lvl: Int): Unit = { - quantificationsStack = quantificationsStack.drop(lvl) - instantiatedStack = instantiatedStack.drop(lvl) - knownStack = knownStack.drop(lvl) - } - - def assumptions: Seq[T] = quantifications.map(_.currentQ2Var) + def instantiations(caller: T, tpe: TypeTree): Seq[(T, Matcher[T])] = + instantiated.toSeq.filter { case (b,m) => correspond(m, caller, tpe) } override def registerFree(ids: Seq[(TypeTree, T)]): Unit = { super.registerFree(ids) - knownStack = (knownStack.head ++ ids.map(_._2)) :: knownStack.tail + known ++= ids.map(_._2) } private class Quantification ( @@ -167,7 +151,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage .map(qm => if (qm == bindingMatcher) { bindingMatcher -> Set(blocker -> matcher) } else { - val instances: Set[(T, Matcher[T])] = instantiated.filter { case (b, m) => correspond(qm, m) } + val instances: Set[(T, Matcher[T])] = instantiated.filter { case (b, m) => correspond(qm, m) }.toSet // concrete applications can appear multiple times in the constraint, and this is also the case // for the current application for which we are generating the constraints @@ -303,23 +287,16 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage template.matchers merge rec(template.lambdas) } - val quantifiedMatchers = for { + val quantifiedMatchers = (for { (_, ms) <- allMatchers m @ Matcher(_, _, args, _) <- ms if args exists (_.left.exists(quantified)) - } yield m - - val matchQuorums: Seq[Set[Matcher[T]]] = quantifiedMatchers.toSet.subsets.filter { ms => - var doubled: Boolean = false - var qs: Set[T] = Set.empty - for (m @ Matcher(_, _, args, _) <- ms) { - val qargs = (args collect { case Left(a) if quantified(a) => a }).toSet - if ((qs & qargs).nonEmpty) doubled = true - qs ++= qargs - } + } yield m).toSet - !doubled && (qs == quantified) - }.toList + val matchQuorums: Seq[Set[Matcher[T]]] = purescala.Quantification.extractQuorums( + quantifiedMatchers, quantified, + (m: Matcher[T]) => m.args.collect { case Right(m) if quantifiedMatchers(m) => m }.toSet, + (m: Matcher[T]) => m.args.collect { case Left(a) if quantified(a) => a }.toSet) var instantiation = Instantiation.empty[T] @@ -387,7 +364,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage instantiation ++= quantification.instantiate(b, m) } - quantifications :+= quantification + quantifications += quantification quantification.qs._2 } diff --git a/src/main/scala/leon/solvers/templates/Templates.scala b/src/main/scala/leon/solvers/templates/Templates.scala index 678050fa1..f4714bd9e 100644 --- a/src/main/scala/leon/solvers/templates/Templates.scala +++ b/src/main/scala/leon/solvers/templates/Templates.scala @@ -5,11 +5,12 @@ package solvers package templates import purescala.Common._ +import purescala.Definitions._ import purescala.Expressions._ +import purescala.Quantification._ import purescala.Extractors._ import purescala.ExprOps._ import purescala.Types._ -import purescala.Definitions._ case class App[T](caller: T, tpe: FunctionType, args: Seq[T]) { override def toString = "(" + caller + " : " + tpe + ")" + args.mkString("(", ",", ")") @@ -122,15 +123,6 @@ object Template { } } - private object MatchExtractor { - def unapply(expr: Expr): Option[(Expr, Seq[Expr])] = expr match { - case ApplicationExtractor(caller, args) => Some(caller -> args) - case ArraySelect(arr, index) => Some(arr -> Seq(index)) - case MapGet(map, key) => Some(map -> Seq(key)) - case _ => None - } - } - private def invocationMatcher[T](encodeExpr: Expr => T)(tfd: TypedFunDef, args: Seq[Expr]): Matcher[T] = { assert(tfd.returnType.isInstanceOf[FunctionType], "invocationMatcher() is only defined on function-typed defs") @@ -174,7 +166,7 @@ object Template { val optIdCall = optCall.map(tfd => TemplateCallInfo[T](tfd, arguments.map(_._2))) val optIdApp = optApp.map { case (idT, tpe) => App(idT, tpe, arguments.map(_._2)) } - val invocMatcher = optCall.filter(_.returnType.isInstanceOf[FunctionType]) + lazy val invocMatcher = optCall.filter(_.returnType.isInstanceOf[FunctionType]) .map(tfd => invocationMatcher(encodeExpr)(tfd, arguments.map(_._1.toVariable))) val (blockers, applications, matchers) = { @@ -204,7 +196,7 @@ object Template { val result = res.flatten.toMap result ++ (expr match { - case MatchExtractor(c, args) => + case QuantificationMatcher(c, args) => // Note that we rely here on the fact that foldRight visits the matcher's arguments first, // so any Matcher in arguments will belong to the `result` map val encodedArgs = args.map(arg => result.get(arg) match { @@ -382,47 +374,6 @@ class FunctionTemplate[T] private( object LambdaTemplate { - private var typedIds : Map[TypeTree, List[Identifier]] = Map.empty.withDefaultValue(List.empty) - - private def structuralKey[T](lambda: Lambda, dependencies: Map[Identifier, T]): (Lambda, Map[Identifier,T]) = { - - def closureIds(expr: Expr): Seq[Identifier] = { - val allVars : Seq[Identifier] = foldRight[Seq[Identifier]] { - (expr, idSeqs) => idSeqs.foldLeft(expr match { - case Variable(id) => Seq(id) - case _ => Seq.empty[Identifier] - })((acc, seq) => acc ++ seq) - } (expr) - - val vars = variablesOf(expr) - allVars.filter(vars(_)).distinct - } - - val grouped : Map[TypeTree, Seq[Identifier]] = (lambda.args.map(_.id) ++ closureIds(lambda)).groupBy(_.getType) - val subst : Map[Identifier, Identifier] = grouped.foldLeft(Map.empty[Identifier,Identifier]) { case (subst, (tpe, ids)) => - val currentVars = typedIds(tpe) - - val freshCount = ids.size - currentVars.size - val typedVars = if (freshCount > 0) { - val allIds = currentVars ++ List.range(0, freshCount).map(_ => FreshIdentifier("x", tpe, true)) - typedIds += tpe -> allIds - allIds - } else { - currentVars - } - - subst ++ (ids zip typedVars) - } - - val newArgs = lambda.args.map(vd => ValDef(subst(vd.id), vd.tpe)) - val newBody = replaceFromIDs(subst.mapValues(_.toVariable), lambda.body) - val structuralLambda = Lambda(newArgs, newBody) - - val newDeps = dependencies.map { case (id, idT) => subst(id) -> idT } - - structuralLambda -> newDeps - } - def apply[T]( ids: (Identifier, T), encoder: TemplateEncoder[T], @@ -448,7 +399,9 @@ object LambdaTemplate { "Template for lambda " + ids._1 + ": " + lambda + " is :\n" + templateString() } - val (key, keyDeps) = structuralKey(lambda, dependencies) + val (structuralLambda, structSubst) = normalizeStructure(lambda) + val keyDeps = dependencies.map { case (id, idT) => structSubst(id) -> idT } + val key = structuralLambda.asInstanceOf[Lambda] new LambdaTemplate[T]( ids._1, diff --git a/src/main/scala/leon/solvers/templates/UnrollingBank.scala b/src/main/scala/leon/solvers/templates/UnrollingBank.scala index 617f1c5f4..ddfb22b0b 100644 --- a/src/main/scala/leon/solvers/templates/UnrollingBank.scala +++ b/src/main/scala/leon/solvers/templates/UnrollingBank.scala @@ -18,90 +18,84 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat private val encoder = templateGenerator.encoder private val manager = templateGenerator.manager - // Keep which function invocation is guarded by which guard, - // also specify the generation of the blocker. - private val callInfos = new IncrementalMap[T, (Int, Int, T, Set[TemplateCallInfo[T]])]() - private def callInfo = callInfos.toMap - // Function instantiations have their own defblocker - private val defBlockerss = new IncrementalMap[TemplateCallInfo[T], T]() - private def defBlockers = defBlockerss.toMap - - private val appInfos = new IncrementalMap[(T, App[T]), (Int, Int, T, T, Set[TemplateAppInfo[T]])]() - private def appInfo = appInfos.toMap - - private val appBlockerss = new IncrementalMap[(T, App[T]), T]() - private def appBlockers = appBlockerss.toMap + private val defBlockers = new IncrementalMap[TemplateCallInfo[T], T]() + // Keep which function invocation is guarded by which guard, + // also specify the generation of the blocker. + private val callInfos = new IncrementalMap[T, (Int, Int, T, Set[TemplateCallInfo[T]])]() + private val appInfos = new IncrementalMap[(T, App[T]), (Int, Int, T, T, Set[TemplateAppInfo[T]])]() + private val appBlockers = new IncrementalMap[(T, App[T]), T]() private val blockerToApps = new IncrementalMap[T, (T, App[T])]() - private def blockerToApp = blockerToApps.toMap - - private val functionVarss = new IncrementalMap[TypeTree, Set[T]]() - private def functionVars = functionVarss.toMap + private val functionVars = new IncrementalMap[TypeTree, Set[T]]() def push() { callInfos.push() - defBlockerss.push() + defBlockers.push() appInfos.push() - appBlockerss.push() + appBlockers.push() blockerToApps.push() - functionVarss.push() + functionVars.push() } def pop() { callInfos.pop() - defBlockerss.pop() + defBlockers.pop() appInfos.pop() - appBlockerss.pop() + appBlockers.pop() blockerToApps.pop() - functionVarss.pop() + functionVars.pop() } def clear() { callInfos.clear() - defBlockerss.clear() + defBlockers.clear() appInfos.clear() - appBlockerss.clear() - functionVarss.clear() + appBlockers.clear() + blockerToApps.clear() + functionVars.clear() } def reset() { callInfos.reset() - defBlockerss.reset() + defBlockers.reset() appInfos.reset() - appBlockerss.reset() - functionVarss.reset() + appBlockers.reset() + blockerToApps.clear() + functionVars.reset() } def dumpBlockers() = { - val generations = (callInfo.map(_._2._1).toSet ++ appInfo.map(_._2._1).toSet).toSeq.sorted + val generations = (callInfos.map(_._2._1).toSet ++ appInfos.map(_._2._1).toSet).toSeq.sorted generations.foreach { generation => reporter.debug("--- " + generation) - for ((b, (gen, origGen, ast, fis)) <- callInfo if gen == generation) { + for ((b, (gen, origGen, ast, fis)) <- callInfos if gen == generation) { reporter.debug(f". $b%15s ~> "+fis.mkString(", ")) } - for ((app, (gen, origGen, b, notB, infos)) <- appInfo if gen == generation) { + for ((app, (gen, origGen, b, notB, infos)) <- appInfos if gen == generation) { reporter.debug(f". $b%15s ~> "+infos.mkString(", ")) } } } - def canUnroll = callInfo.nonEmpty || appInfo.nonEmpty + def satisfactionAssumptions = currentBlockers ++ manager.assumptions + + def refutationAssumptions = manager.assumptions - def currentBlockers = callInfo.map(_._2._3).toSeq ++ appInfo.map(_._2._4).toSeq + def canUnroll = callInfos.nonEmpty || appInfos.nonEmpty - def quantificationAssumptions = manager.assumptions + def currentBlockers = callInfos.map(_._2._3).toSeq ++ appInfos.map(_._2._4).toSeq def getBlockersToUnlock: Seq[T] = { - if (callInfo.isEmpty && appInfo.isEmpty) { + if (callInfos.isEmpty && appInfos.isEmpty) { Seq.empty } else { - val minGeneration = (callInfo.values.map(_._1) ++ appInfo.values.map(_._1)).min - val callBlocks = callInfo.filter(_._2._1 == minGeneration).toSeq.map(_._1) - val appBlocks = appInfo.values.filter(_._1 == minGeneration).toSeq.map(_._3) + val minGeneration = (callInfos.values.map(_._1) ++ appInfos.values.map(_._1)).min + val callBlocks = callInfos.filter(_._2._1 == minGeneration).toSeq.map(_._1) + val appBlocks = appInfos.values.filter(_._1 == minGeneration).toSeq.map(_._3) callBlocks ++ appBlocks } } @@ -109,7 +103,7 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat private def registerCallBlocker(gen: Int, id: T, fis: Set[TemplateCallInfo[T]]) { val notId = encoder.mkNot(id) - callInfo.get(id) match { + callInfos.get(id) match { case Some((exGen, origGen, _, exFis)) => // PS: when recycling `b`s, this assertion becomes dangerous. // It's better to simply take the max of the generations. @@ -117,17 +111,17 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat val minGen = gen min exGen - callInfo += id -> (minGen, origGen, notId, fis++exFis) + callInfos += id -> (minGen, origGen, notId, fis++exFis) case None => - callInfo += id -> (gen, gen, notId, fis) + callInfos += id -> (gen, gen, notId, fis) } } private def registerAppBlocker(gen: Int, app: (T, App[T]), info: Set[TemplateAppInfo[T]]) : Unit = { - appInfo.get(app) match { + appInfos.get(app) match { case Some((exGen, origGen, b, notB, exInfo)) => val minGen = gen min exGen - appInfo += app -> (minGen, origGen, b, notB, exInfo ++ info) + appInfos += app -> (minGen, origGen, b, notB, exInfo ++ info) case None => val b = appBlockers.get(app) match { @@ -136,8 +130,8 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat } val notB = encoder.mkNot(b) - appInfo += app -> (gen, gen, b, notB, info) - blockerToApp += b -> app + appInfos += app -> (gen, gen, b, notB, info) + blockerToApps += b -> app } } @@ -154,7 +148,7 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat } private def extendAppBlock(app: (T, App[T]), infos: Set[TemplateAppInfo[T]]) : T = { - assert(!appInfo.isDefinedAt(app), "appInfo -= app must have been called to ensure blocker freshness") + assert(!appInfos.isDefinedAt(app), "appInfo -= app must have been called to ensure blocker freshness") assert(appBlockers.isDefinedAt(app), "freshAppBlocks must have been called on app before it can be unlocked") assert(infos.nonEmpty, "No point in extending blockers if no templates have been unrolled!") @@ -207,45 +201,45 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat def nextGeneration(gen: Int) = gen + 3 def decreaseAllGenerations() = { - for ((block, (gen, origGen, ast, infos)) <- callInfo) { + for ((block, (gen, origGen, ast, infos)) <- callInfos) { // We also decrease the original generation here - callInfo += block -> (math.max(1,gen-1), math.max(1,origGen-1), ast, infos) + callInfos += block -> (math.max(1,gen-1), math.max(1,origGen-1), ast, infos) } - for ((app, (gen, origGen, b, notB, infos)) <- appInfo) { - appInfo += app -> (math.max(1,gen-1), math.max(1,origGen-1), b, notB, infos) + for ((app, (gen, origGen, b, notB, infos)) <- appInfos) { + appInfos += app -> (math.max(1,gen-1), math.max(1,origGen-1), b, notB, infos) } } def promoteBlocker(b: T) = { - if (callInfo contains b) { - val (_, origGen, ast, fis) = callInfo(b) + if (callInfos contains b) { + val (_, origGen, ast, fis) = callInfos(b) - callInfo += b -> (1, origGen, ast, fis) + callInfos += b -> (1, origGen, ast, fis) } - if (blockerToApp contains b) { - val app = blockerToApp(b) - val (_, origGen, _, notB, infos) = appInfo(app) + if (blockerToApps contains b) { + val app = blockerToApps(b) + val (_, origGen, _, notB, infos) = appInfos(app) - appInfo += app -> (1, origGen, b, notB, infos) + appInfos += app -> (1, origGen, b, notB, infos) } } def unrollBehind(ids: Seq[T]): Seq[T] = { - assert(ids.forall(id => (callInfo contains id) || (blockerToApp contains id))) + assert(ids.forall(id => (callInfos contains id) || (blockerToApps contains id))) var newClauses : Seq[T] = Seq.empty - val newCallInfos = ids.flatMap(id => callInfo.get(id).map(id -> _)) - callInfo --= ids + val newCallInfos = ids.flatMap(id => callInfos.get(id).map(id -> _)) + callInfos --= ids - val apps = ids.flatMap(id => blockerToApp.get(id)) - val appInfos = apps.map(app => app -> appInfo(app)) - blockerToApp --= ids - appInfo --= apps + val apps = ids.flatMap(id => blockerToApps.get(id)) + val thisAppInfos = apps.map(app => app -> appInfos(app)) + blockerToApps --= ids + appInfos --= apps - for ((app, (_, _, _, _, infos)) <- appInfos if infos.nonEmpty) { + for ((app, (_, _, _, _, infos)) <- thisAppInfos if infos.nonEmpty) { val extension = extendAppBlock(app, infos) reporter.debug(" -> extending lambda blocker: " + extension) newClauses :+= extension @@ -296,7 +290,7 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat newClauses ++= newCls } - for ((app @ (b, _), (gen, _, _, _, infos)) <- appInfos; info @ TemplateAppInfo(template, equals, args) <- infos) { + for ((app @ (b, _), (gen, _, _, _, infos)) <- thisAppInfos; info @ TemplateAppInfo(template, equals, args) <- infos) { var newCls = Seq.empty[T] val nb = encoder.encodeId(FreshIdentifier("b", BooleanType, true)) diff --git a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala index d2dc57f50..44df7eb67 100644 --- a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala @@ -544,25 +544,6 @@ trait AbstractZ3Solver extends Solver { rec(expr) } - protected def fromRawArray(r: Expr, tpe: TypeTree): Expr = r match { - case rav: RawArrayValue => - fromRawArray(rav, tpe) - case _ => - scala.sys.error("Unable to extract from raw array for "+r.asString) - } - - protected def fromRawArray(r: RawArrayValue, tpe: TypeTree): Expr = tpe match { - case RawArrayType(from, to) => - r - - case ft @ FunctionType(from, to) => - finiteLambda(r.default, r.elems.toSeq, from) - - - case _ => - scala.sys.error("Unable to extract from raw array for "+tpe.asString) - } - protected[leon] def fromZ3Formula(model: Z3Model, tree: Z3AST, tpe: TypeTree): Expr = { def rec(t: Z3AST, tpe: TypeTree): Expr = { @@ -678,15 +659,8 @@ trait AbstractZ3Solver extends Solver { FiniteMap(elems, from, to) } - case FunctionType(fts, tt) => - model.getArrayValue(t) match { - case None => reporter.fatalError("Translation from Z3 to function value failed") - case Some((map, elseZ3Value)) => - val leonElseValue = rec(elseZ3Value, tt) - val leonMap = map.toSeq.map(p => rec(p._1, tupleTypeWrap(fts)) -> rec(p._2, tt)) - finiteLambda(leonElseValue, leonMap, fts) - } + rec(t, RawArrayType(tupleTypeWrap(fts), tt)) case tpe @ SetType(dt) => model.getSetValue(t) match { diff --git a/src/main/scala/leon/solvers/z3/FairZ3Solver.scala b/src/main/scala/leon/solvers/z3/FairZ3Solver.scala index b8eeeafab..406dcde0a 100644 --- a/src/main/scala/leon/solvers/z3/FairZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/FairZ3Solver.scala @@ -12,6 +12,7 @@ import purescala.Common._ import purescala.Definitions._ import purescala.Expressions._ import purescala.Constructors._ +import purescala.Quantification._ import purescala.ExprOps._ import purescala.Types._ @@ -24,7 +25,8 @@ import termination._ class FairZ3Solver(val context: LeonContext, val program: Program) extends AbstractZ3Solver with Z3ModelReconstruction - with FairZ3Component { + with FairZ3Component + with EvaluatingSolver { enclosing => @@ -39,15 +41,6 @@ class FairZ3Solver(val context: LeonContext, val program: Program) protected def hasError = errors.getB(()) contains true protected def addError() = errors += () -> true - private val evaluator: Evaluator = - if(useCodeGen) { - // TODO If somehow we could not recompile each time we create a solver, - // that would be good? - new CodeGenEvaluator(context, program) - } else { - new DefaultEvaluator(context, program) - } - protected[z3] def getEvaluator : Evaluator = evaluator private val terminator : TerminationChecker = new SimpleTerminationChecker(context, program) @@ -62,8 +55,57 @@ class FairZ3Solver(val context: LeonContext, val program: Program) )} toggleWarningMessages(true) + private def extractModel(model: Z3Model, ids: Set[Identifier]): HenkinModel = { + val asMap = modelToMap(model, ids) + + def extract(b: Z3AST, m: Matcher[Z3AST]): Set[Seq[Expr]] = { + val QuantificationTypeMatcher(fromTypes, _) = m.tpe + val optEnabler = model.evalAs[Boolean](b) + val optArgs = (m.args zip fromTypes).map { + p => softFromZ3Formula(model, model.eval(Matcher.argValue(p._1), true).get, p._2) + } + + if (optEnabler == Some(true) && optArgs.forall(_.isDefined)) { + Set(optArgs.map(_.get)) + } else { + Set.empty + } + } + + val funDomains = ids.flatMap(id => id.getType match { + case ft @ FunctionType(fromTypes, _) => variables.getB(id.toVariable) match { + case Some(z3ID) => Some(id -> templateGenerator.manager.instantiations(z3ID, ft).flatMap { + case (b, m) => extract(b, m) + }) + case _ => None + } + case _ => None + }).toMap.mapValues(_.toSet) + + val asDMap = asMap.map(p => funDomains.get(p._1) match { + case Some(domain) => + val mapping = domain.toSeq.map { es => + val ev: Expr = p._2 match { + case RawArrayValue(_, mapping, dflt) => + mapping.collectFirst { + case (k,v) if evaluator.eval(Equals(k, tupleWrap(es))).result == Some(BooleanLiteral(true)) => v + } getOrElse dflt + case _ => scala.sys.error("Unexpected function encoding " + p._2) + } + es -> ev + } + p._1 -> PartialLambda(mapping, p._1.getType.asInstanceOf[FunctionType]) + case None => p + }) + + val typeGrouped = templateGenerator.manager.instantiations.groupBy(_._2.tpe) + val typeDomains = typeGrouped.mapValues(_.flatMap { case (b, m) => extract(b, m) }.toSet) + + val domain = new HenkinDomains(typeDomains) + new HenkinModel(asDMap, domain) + } - private def validateModel(model: Z3Model, formula: Expr, variables: Set[Identifier], silenceErrors: Boolean) : (Boolean, Map[Identifier,Expr]) = { + private def validateModel(model: Z3Model, formula: Expr, variables: Set[Identifier], silenceErrors: Boolean) : (Boolean, HenkinModel) = { if(!interrupted) { val functionsModel: Map[Z3FuncDecl, (Seq[(Seq[Z3AST], Z3AST)], Z3AST)] = model.getModelFuncInterpretations.map(i => (i._1, (i._2, i._3))).toMap @@ -92,22 +134,23 @@ class FairZ3Solver(val context: LeonContext, val program: Program) } else Seq() }).toMap - val asMap = modelToMap(model, variables) ++ functionsAsMap ++ constantFunctionsAsMap - val evalResult = evaluator.eval(formula, asMap) + val leonModel = extractModel(model, variables) + val fullModel = leonModel ++ (functionsAsMap ++ constantFunctionsAsMap) + val evalResult = evaluator.eval(formula, fullModel) evalResult match { case EvaluationResults.Successful(BooleanLiteral(true)) => reporter.debug("- Model validated.") - (true, asMap) + (true, fullModel) case EvaluationResults.Successful(res) => assert(res == BooleanLiteral(false), "Checking model returned non-boolean") reporter.debug("- Invalid model.") - (false, asMap) + (false, fullModel) case EvaluationResults.RuntimeError(msg) => reporter.debug("- Model leads to runtime error.") - (false, asMap) + (false, fullModel) case EvaluationResults.EvaluatorError(msg) => if (silenceErrors) { @@ -115,11 +158,11 @@ class FairZ3Solver(val context: LeonContext, val program: Program) } else { reporter.warning("Something went wrong. While evaluating the model, we got this : " + msg) } - (false, asMap) + (false, fullModel) } } else { - (false, Map.empty) + (false, HenkinModel.empty) } } @@ -156,7 +199,7 @@ class FairZ3Solver(val context: LeonContext, val program: Program) val solver = z3.mkSolver() private val freeVars = new IncrementalSet[Identifier]() - private var constraints = new IncrementalSeq[Expr]() + private val constraints = new IncrementalSeq[Expr]() val unrollingBank = new UnrollingBank(context, templateGenerator) @@ -195,7 +238,7 @@ class FairZ3Solver(val context: LeonContext, val program: Program) var foundDefinitiveAnswer = false var definitiveAnswer : Option[Boolean] = None - var definitiveModel : Map[Identifier,Expr] = Map.empty + var definitiveModel : HenkinModel = HenkinModel.empty var definitiveCore : Set[Expr] = Set.empty def assertCnstr(expression: Expr) { @@ -204,7 +247,7 @@ class FairZ3Solver(val context: LeonContext, val program: Program) freeVars ++= newFreeVars // We make sure all free variables are registered as variables - freeVars.toSet.foreach { v => + freeVars.foreach { v => variables.cachedB(Variable(v)) { templateGenerator.encoder.encodeId(v) } @@ -236,7 +279,7 @@ class FairZ3Solver(val context: LeonContext, val program: Program) def entireFormula = andJoin(assumptions.toSeq ++ constraints.toSeq) - def foundAnswer(answer : Option[Boolean], model : Map[Identifier,Expr] = Map.empty, core: Set[Expr] = Set.empty) : Unit = { + def foundAnswer(answer: Option[Boolean], model: HenkinModel = HenkinModel.empty, core: Set[Expr] = Set.empty) : Unit = { foundDefinitiveAnswer = true definitiveAnswer = answer definitiveModel = model @@ -268,13 +311,13 @@ class FairZ3Solver(val context: LeonContext, val program: Program) val timer = context.timers.solvers.z3.check.start() solver.push() // FIXME: remove when z3 bug is fixed - val res = solver.checkAssumptions((assumptionsAsZ3 ++ unrollingBank.currentBlockers ++ unrollingBank.quantificationAssumptions) :_*) + val res = solver.checkAssumptions((assumptionsAsZ3 ++ unrollingBank.satisfactionAssumptions) :_*) solver.pop() // FIXME: remove when z3 bug is fixed timer.stop() reporter.debug(" - Finished search with blocked literals") - lazy val allVars = freeVars.toSet + lazy val allVars: Set[Identifier] = freeVars.toSet res match { case None => @@ -300,7 +343,7 @@ class FairZ3Solver(val context: LeonContext, val program: Program) foundAnswer(None, model) } } else { - val model = modelToMap(z3model, allVars) + val model = extractModel(z3model, allVars) //lazy val modelAsString = model.toList.map(p => p._1 + " -> " + p._2).mkString("\n") //reporter.debug("- Found a model:") @@ -359,7 +402,7 @@ class FairZ3Solver(val context: LeonContext, val program: Program) val timer = context.timers.solvers.z3.check.start() solver.push() // FIXME: remove when z3 bug is fixed - val res2 = solver.checkAssumptions((assumptionsAsZ3 ++ unrollingBank.quantificationAssumptions) : _*) + val res2 = solver.checkAssumptions((assumptionsAsZ3 ++ unrollingBank.refutationAssumptions) : _*) solver.pop() // FIXME: remove when z3 bug is fixed timer.stop() diff --git a/src/main/scala/leon/solvers/z3/UninterpretedZ3Solver.scala b/src/main/scala/leon/solvers/z3/UninterpretedZ3Solver.scala index 110677737..f755f2a4a 100644 --- a/src/main/scala/leon/solvers/z3/UninterpretedZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/UninterpretedZ3Solver.scala @@ -66,7 +66,7 @@ class UninterpretedZ3Solver(val context : LeonContext, val program: Program) } def getModel = { - modelToMap(solver.getModel(), freeVariables.toSet) + new Model(modelToMap(solver.getModel(), freeVariables.toSet)) } def getUnsatCore = { diff --git a/src/main/scala/leon/solvers/z3/Z3ModelReconstruction.scala b/src/main/scala/leon/solvers/z3/Z3ModelReconstruction.scala index 62f43d736..4383f8f31 100644 --- a/src/main/scala/leon/solvers/z3/Z3ModelReconstruction.scala +++ b/src/main/scala/leon/solvers/z3/Z3ModelReconstruction.scala @@ -6,6 +6,7 @@ package solvers.z3 import z3.scala._ import purescala.Common._ import purescala.Expressions._ +import purescala.Constructors._ import purescala.ExprOps._ import purescala.Types._ @@ -18,7 +19,7 @@ trait Z3ModelReconstruction { def modelValue(model: Z3Model, id: Identifier, tpe: TypeTree = null) : Option[Expr] = { val expectedType = if(tpe == null) id.getType else tpe - + variables.getB(id.toVariable).flatMap { z3ID => expectedType match { case BooleanType => model.evalAs[Boolean](z3ID).map(BooleanLiteral) @@ -43,7 +44,7 @@ trait Z3ModelReconstruction { reporter.debug("Completing variable '" + id + "' to simplest value") } - for(id <- ids) { + for (id <- ids) { modelValue(model, id) match { case None if AUTOCOMPLETEMODELS => completeID(id) case None => ; @@ -51,6 +52,7 @@ trait Z3ModelReconstruction { case Some(ex) => asMap = asMap + (id -> ex) } } + asMap } diff --git a/src/main/scala/leon/synthesis/rules/BottomUpTegis.scala b/src/main/scala/leon/synthesis/rules/BottomUpTegis.scala index c363bcfa5..990be35e1 100644 --- a/src/main/scala/leon/synthesis/rules/BottomUpTegis.scala +++ b/src/main/scala/leon/synthesis/rules/BottomUpTegis.scala @@ -8,6 +8,7 @@ import purescala.Expressions._ import purescala.Common._ import purescala.Types._ import purescala.Constructors._ +import purescala.Quantification._ import evaluators._ import codegen.CodeGenParams @@ -85,7 +86,7 @@ abstract class BottomUpTEGISLike[T <% Typed](name: String) extends Rule(name) { { (vecs: Vector[Vector[Expr]]) => val res = (0 to nTests-1).toVector.flatMap { i => - val inputs = vecs.map(_(i)) + val inputs = new solvers.Model((args zip vecs.map(_(i))).toMap) ev(inputs) match { case EvaluationResults.Successful(out) => Some(out) case _ => diff --git a/src/main/scala/leon/termination/LoopProcessor.scala b/src/main/scala/leon/termination/LoopProcessor.scala index 930771443..790db9106 100644 --- a/src/main/scala/leon/termination/LoopProcessor.scala +++ b/src/main/scala/leon/termination/LoopProcessor.scala @@ -35,8 +35,8 @@ class LoopProcessor(val checker: TerminationChecker, val modules: ChainBuilder w val resTuple = tupleWrap(freshParams.map(_.toVariable)) definitiveSATwithModel(andJoin(path :+ Equals(srcTuple, resTuple))) match { - case Some(map) => - val args = chain.funDef.params.map(arg => map(arg.id)) + case Some(model) => + val args = chain.funDef.params.map(arg => model(arg.id)) val res = if (chain.relations.exists(_.inLambda)) MaybeBroken(chain.funDef, args) else Broken(chain.funDef, args) nonTerminating(chain.funDef) = res case None => diff --git a/src/main/scala/leon/termination/Processor.scala b/src/main/scala/leon/termination/Processor.scala index 91590dd4a..99124c5e6 100644 --- a/src/main/scala/leon/termination/Processor.scala +++ b/src/main/scala/leon/termination/Processor.scala @@ -64,7 +64,7 @@ trait Solvable extends Processor { } } - def definitiveSATwithModel(problem: Expr): Option[Map[Identifier, Expr]] = { + def definitiveSATwithModel(problem: Expr): Option[Model] = { withoutPosts { val (sat, model) = SimpleSolverAPI(solver).solveSAT(problem) if (sat.isDefined && sat.get) Some(model) else None diff --git a/src/main/scala/leon/utils/IncrementalMap.scala b/src/main/scala/leon/utils/IncrementalMap.scala index 4515351de..b07d8adc1 100644 --- a/src/main/scala/leon/utils/IncrementalMap.scala +++ b/src/main/scala/leon/utils/IncrementalMap.scala @@ -2,12 +2,23 @@ package leon.utils -import scala.collection.mutable.{Stack, Map => MMap} +import scala.collection.mutable.{Stack, Map => MMap, Builder} +import scala.collection.generic.Shrinkable +import scala.collection.IterableLike -class IncrementalMap[A, B] extends IncrementalState { - private[this] val stack = new Stack[MMap[A, B]]() +class IncrementalMap[A, B] private(dflt: Option[B]) + extends IncrementalState + with Iterable[(A,B)] + with IterableLike[(A,B), Map[A,B]] + with Builder[(A,B), IncrementalMap[A, B]] + with Shrinkable[A] + with (A => B) { - def clear(): Unit = { + def this() = this(None) + + private val stack = new Stack[MMap[A, B]]() + + override def clear(): Unit = { stack.clear() } @@ -22,22 +33,40 @@ class IncrementalMap[A, B] extends IncrementalState { } else { MMap[A,B]() ++ stack.head } - stack.push(last) + + val withDefault = dflt match { + case Some(value) => last.withDefaultValue(value) + case None => last + } + + stack.push(withDefault) } def pop(): Unit = { stack.pop() } - def +=(a: A, b: B): Unit = { - stack.head += a -> b + def withDefaultValue[B1 >: B](dflt: B1) = { + val map = new IncrementalMap[A, B1](Some(dflt)) + map.stack.pop() + for (s <- stack.toList) map.stack.push(MMap[A,B1]().withDefaultValue(dflt) ++ s) + map } - def ++=(as: Traversable[(A, B)]): Unit = { - stack.head ++= as - } + def get(k: A) = stack.head.get(k) + def apply(k: A) = stack.head.apply(k) + def contains(k: A) = stack.head.contains(k) + def isDefinedAt(k: A) = stack.head.isDefinedAt(k) + def getOrElse[B1 >: B](k: A, e: => B1) = stack.head.getOrElse(k, e) + def values = stack.head.values + + def iterator = stack.head.iterator + def +=(kv: (A, B)) = { stack.head += kv; this } + def -=(k: A) = { stack.head -= k; this } - def toMap = stack.head + override def seq = stack.head.seq + override def newBuilder = new scala.collection.mutable.MapBuilder(Map.empty[A,B]) + def result = this push() } diff --git a/src/main/scala/leon/utils/IncrementalSeq.scala b/src/main/scala/leon/utils/IncrementalSeq.scala index 9f66d66a8..4ec9290b5 100644 --- a/src/main/scala/leon/utils/IncrementalSeq.scala +++ b/src/main/scala/leon/utils/IncrementalSeq.scala @@ -3,9 +3,15 @@ package leon.utils import scala.collection.mutable.Stack +import scala.collection.mutable.Builder import scala.collection.mutable.ArrayBuffer +import scala.collection.{Iterable, IterableLike} + +class IncrementalSeq[A] extends IncrementalState + with Iterable[A] + with IterableLike[A, Seq[A]] + with Builder[A, IncrementalSeq[A]] { -class IncrementalSeq[A] extends IncrementalState { private[this] val stack = new Stack[ArrayBuffer[A]]() def clear() : Unit = { @@ -25,11 +31,11 @@ class IncrementalSeq[A] extends IncrementalState { stack.pop() } - def +=(e: A): Unit = { - stack.head += e - } + def iterator = stack.flatten.iterator + def +=(e: A) = { stack.head += e; this } - def toSeq = stack.toSeq.flatten + override def newBuilder = new scala.collection.mutable.ArrayBuffer() + def result = this push() } diff --git a/src/main/scala/leon/utils/IncrementalSet.scala b/src/main/scala/leon/utils/IncrementalSet.scala index 95b473c75..b88dcf840 100644 --- a/src/main/scala/leon/utils/IncrementalSet.scala +++ b/src/main/scala/leon/utils/IncrementalSet.scala @@ -3,11 +3,17 @@ package leon.utils import scala.collection.mutable.{Stack, Set => MSet} +import scala.collection.mutable.Builder +import scala.collection.{Iterable, IterableLike} + +class IncrementalSet[A] extends IncrementalState + with Iterable[A] + with IterableLike[A, Set[A]] + with Builder[A, IncrementalSet[A]] { -class IncrementalSet[A] extends IncrementalState { private[this] val stack = new Stack[MSet[A]]() - def clear(): Unit = { + override def clear(): Unit = { stack.clear() } @@ -24,15 +30,15 @@ class IncrementalSet[A] extends IncrementalState { stack.pop() } - def +=(a: A): Unit = { - stack.head += a - } + def apply(elem: A) = toSet.contains(elem) + def contains(elem: A) = toSet.contains(elem) - def ++=(as: Traversable[A]): Unit = { - stack.head ++= as - } + def iterator = stack.flatten.iterator + def += (elem: A) = { stack.head += elem; this } + def -= (elem: A) = { stack.foreach(_ -= elem); this } - def toSet = stack.toSet.flatten + override def newBuilder = new scala.collection.mutable.SetBuilder(Set.empty[A]) + def result = this push() } diff --git a/src/main/scala/leon/utils/IncrementalState.scala b/src/main/scala/leon/utils/IncrementalState.scala index b84606af2..32f6b7c2b 100644 --- a/src/main/scala/leon/utils/IncrementalState.scala +++ b/src/main/scala/leon/utils/IncrementalState.scala @@ -4,6 +4,8 @@ trait IncrementalState { def push(): Unit def pop(): Unit + final def pop(lvl: Int): Unit = List.range(0, lvl).foreach(_ => pop()) + def clear(): Unit def reset(): Unit } diff --git a/src/main/scala/leon/utils/PreprocessingPhase.scala b/src/main/scala/leon/utils/PreprocessingPhase.scala index 7c91fa436..148dcba64 100644 --- a/src/main/scala/leon/utils/PreprocessingPhase.scala +++ b/src/main/scala/leon/utils/PreprocessingPhase.scala @@ -5,6 +5,7 @@ package utils import leon.purescala._ import leon.purescala.Definitions.Program +import leon.purescala.Quantification.CheckForalls import leon.solvers.isabelle.AdaptationPhase import leon.synthesis.{ConvertWithOracle, ConvertHoles} import leon.verification.InjectAsserts diff --git a/src/main/scala/leon/verification/VerificationCondition.scala b/src/main/scala/leon/verification/VerificationCondition.scala index 8ddbf97ab..33511b23c 100644 --- a/src/main/scala/leon/verification/VerificationCondition.scala +++ b/src/main/scala/leon/verification/VerificationCondition.scala @@ -88,7 +88,7 @@ sealed abstract class VCStatus(val name: String) { } object VCStatus { - case class Invalid(cex: Map[Identifier, Expr]) extends VCStatus("invalid") + case class Invalid(cex: Model) extends VCStatus("invalid") case object Valid extends VCStatus("valid") case object Unknown extends VCStatus("unknown") case object Timeout extends VCStatus("timeout") diff --git a/src/test/resources/regression/verification/purescala/invalid/HOInvocations2.scala b/src/test/resources/regression/verification/purescala/invalid/HOInvocations2.scala index 5cd0afc42..1639def5d 100644 --- a/src/test/resources/regression/verification/purescala/invalid/HOInvocations2.scala +++ b/src/test/resources/regression/verification/purescala/invalid/HOInvocations2.scala @@ -1,6 +1,6 @@ import leon.lang._ -object HOInvocations { +object HOInvocations2 { def switch(x: BigInt, f: (BigInt) => BigInt, g: (BigInt) => BigInt) = if(x > 0) f else g def failling_1(f: (BigInt) => BigInt) = { diff --git a/src/test/resources/regression/verification/purescala/invalid/PositiveMap2.scala b/src/test/resources/regression/verification/purescala/invalid/PositiveMap2.scala index 7b15cb257..713f901c9 100644 --- a/src/test/resources/regression/verification/purescala/invalid/PositiveMap2.scala +++ b/src/test/resources/regression/verification/purescala/invalid/PositiveMap2.scala @@ -1,6 +1,6 @@ import leon.lang._ -object PositiveMap { +object PositiveMap2 { abstract class List case class Cons(head: BigInt, tail: List) extends List diff --git a/src/test/resources/regression/verification/purescala/valid/HOInvocations2.scala b/src/test/resources/regression/verification/purescala/valid/HOInvocations2.scala index ba44b8fdc..0af2a6a06 100644 --- a/src/test/resources/regression/verification/purescala/valid/HOInvocations2.scala +++ b/src/test/resources/regression/verification/purescala/valid/HOInvocations2.scala @@ -1,6 +1,6 @@ import leon.lang._ -object HOInvocations { +object HOInvocations2 { def switch(x: Int, f: (Int) => Int, g: (Int) => Int) = if(x > 0) f else g def passing_1(f: (Int) => Int) = { diff --git a/src/test/resources/regression/verification/purescala/valid/PositiveMap2.scala b/src/test/resources/regression/verification/purescala/valid/PositiveMap2.scala index 030e70959..eb2262f02 100644 --- a/src/test/resources/regression/verification/purescala/valid/PositiveMap2.scala +++ b/src/test/resources/regression/verification/purescala/valid/PositiveMap2.scala @@ -1,7 +1,7 @@ import leon.lang._ -object PositiveMap { +object PositiveMap2 { abstract class List case class Cons(head: BigInt, tail: List) extends List diff --git a/src/test/scala/leon/regression/synthesis/StablePrintingSuite.scala b/src/test/scala/leon/regression/synthesis/StablePrintingSuite.scala index 4404b2acf..c7f33afba 100644 --- a/src/test/scala/leon/regression/synthesis/StablePrintingSuite.scala +++ b/src/test/scala/leon/regression/synthesis/StablePrintingSuite.scala @@ -71,6 +71,7 @@ class StablePrintingSuite extends LeonRegressionSuite { for (e <- reporter.lastErrors) { info(e) } + println(e) info(e.getMessage) fail("Compilation failed") } diff --git a/src/test/scala/leon/regression/verification/VerificationSuite.scala b/src/test/scala/leon/regression/verification/VerificationSuite.scala index 411816245..4c9602c00 100644 --- a/src/test/scala/leon/regression/verification/VerificationSuite.scala +++ b/src/test/scala/leon/regression/verification/VerificationSuite.scala @@ -45,7 +45,8 @@ trait VerificationSuite extends LeonRegressionSuite { user map { u => (u.id, Program(u :: lib)) } } for ((id, p) <- programs; options <- optionVariants) { - test(f"${nextInt()}%3d: ${id.name} ${options.mkString(" ")}") { + val index = nextInt() + test(f"$index%3d: ${id.name} ${options.mkString(" ")}") { val ctx = createLeonContext(options: _*) if (forError) { intercept[LeonFatalError] { -- GitLab