diff --git a/src/main/java/leon/codegen/runtime/FiniteLambda.java b/src/main/java/leon/codegen/runtime/FiniteLambda.java deleted file mode 100644 index cefc755227326e5ba701defed954cb6549aa7666..0000000000000000000000000000000000000000 --- a/src/main/java/leon/codegen/runtime/FiniteLambda.java +++ /dev/null @@ -1,29 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - -package leon.codegen.runtime; - -import java.util.HashMap; - -public final class FiniteLambda extends Lambda { - private final HashMap<Tuple, Object> _underlying = new HashMap<Tuple, Object>(); - private final Object dflt; - - public FiniteLambda(Object dflt) { - super(); - this.dflt = dflt; - } - - public void add(Tuple key, Object value) { - _underlying.put(key, value); - } - - @Override - public Object apply(Object[] args) { - Tuple tuple = new Tuple(args); - if (_underlying.containsKey(tuple)) { - return _underlying.get(tuple); - } else { - return dflt; - } - } -} diff --git a/src/main/java/leon/codegen/runtime/Lambda.java b/src/main/java/leon/codegen/runtime/Lambda.java index b266f83c9cf633b45dde24f2f2433be802c5becf..a6abbef37edbe8f87f480a21a6200e32a9e0206b 100644 --- a/src/main/java/leon/codegen/runtime/Lambda.java +++ b/src/main/java/leon/codegen/runtime/Lambda.java @@ -3,5 +3,5 @@ package leon.codegen.runtime; public abstract class Lambda { - public abstract Object apply(Object[] args); + public abstract Object apply(Object[] args) throws LeonCodeGenRuntimeException; } diff --git a/src/main/java/leon/codegen/runtime/LeonCodeGenRuntimeHenkinMonitor.java b/src/main/java/leon/codegen/runtime/LeonCodeGenRuntimeHenkinMonitor.java new file mode 100644 index 0000000000000000000000000000000000000000..62e7a7b3bbc3053b4515159f1814f213edbc3a58 --- /dev/null +++ b/src/main/java/leon/codegen/runtime/LeonCodeGenRuntimeHenkinMonitor.java @@ -0,0 +1,33 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon.codegen.runtime; + +import java.util.List; +import java.util.LinkedList; +import java.util.HashMap; + +public class LeonCodeGenRuntimeHenkinMonitor extends LeonCodeGenRuntimeMonitor { + private final HashMap<Integer, List<Tuple>> domains = new HashMap<Integer, List<Tuple>>(); + + public LeonCodeGenRuntimeHenkinMonitor(int maxInvocations) { + super(maxInvocations); + } + + public void add(int type, Tuple input) { + if (!domains.containsKey(type)) domains.put(type, new LinkedList<Tuple>()); + domains.get(type).add(input); + } + + public List<Tuple> domain(Object obj, int type) { + List<Tuple> domain = new LinkedList<Tuple>(); + if (obj instanceof PartialLambda) { + for (Tuple key : ((PartialLambda) obj).mapping.keySet()) { + domain.add(key); + } + } + + domain.addAll(domains.get(type)); + + return domain; + } +} diff --git a/src/main/java/leon/codegen/runtime/PartialLambda.java b/src/main/java/leon/codegen/runtime/PartialLambda.java new file mode 100644 index 0000000000000000000000000000000000000000..826cc5ed9930e54bc2f50d7f09e6fa09be3fa307 --- /dev/null +++ b/src/main/java/leon/codegen/runtime/PartialLambda.java @@ -0,0 +1,41 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon.codegen.runtime; + +import java.util.HashMap; + +public final class PartialLambda extends Lambda { + final HashMap<Tuple, Object> mapping = new HashMap<Tuple, Object>(); + + public PartialLambda() { + super(); + } + + public void add(Tuple key, Object value) { + mapping.put(key, value); + } + + @Override + public Object apply(Object[] args) throws LeonCodeGenRuntimeException { + Tuple tuple = new Tuple(args); + if (mapping.containsKey(tuple)) { + return mapping.get(tuple); + } else { + throw new LeonCodeGenRuntimeException("Partial function apply on undefined arguments"); + } + } + + @Override + public boolean equals(Object that) { + if (that != null && (that instanceof PartialLambda)) { + return mapping.equals(((PartialLambda) that).mapping); + } else { + return false; + } + } + + @Override + public int hashCode() { + return 63 + 11 * mapping.hashCode(); + } +} diff --git a/src/main/scala/leon/codegen/CodeGeneration.scala b/src/main/scala/leon/codegen/CodeGeneration.scala index 985c5970ef79d89102eb008f8ccf7df44f347265..e8b232715b1337f36355f8b40ccdb451b9ac65bb 100644 --- a/src/main/scala/leon/codegen/CodeGeneration.scala +++ b/src/main/scala/leon/codegen/CodeGeneration.scala @@ -6,10 +6,11 @@ package codegen import purescala.Common._ import purescala.Definitions._ import purescala.Expressions._ -import purescala.ExprOps.{simplestValue, matchToIfThenElse} +import purescala.ExprOps.{simplestValue, matchToIfThenElse, collect} import purescala.Types._ import purescala.Constructors._ import purescala.Extractors._ +import purescala.Quantification._ import cafebabe._ import cafebabe.AbstractByteCodes._ @@ -47,22 +48,24 @@ trait CodeGeneration { def withArgs(newArgs: Map[Identifier, Int]) = Locals(vars, args ++ newArgs, closures, isStatic) def withClosures(newClosures: Map[Identifier,(String,String,String)]) = Locals(vars, args, closures ++ newClosures, isStatic) - - /** The index of the monitor object in this function */ - def monitorIndex = if (isStatic) 0 else 1 } - + object NoLocals { /** Make a [[Locals]] object without any local variables */ def apply(isStatic : Boolean) = new Locals(Map(), Map(), Map(), isStatic) } + lazy val monitorID = FreshIdentifier("__$monitor") + private[codegen] val ObjectClass = "java/lang/Object" private[codegen] val BoxedIntClass = "java/lang/Integer" private[codegen] val BoxedBoolClass = "java/lang/Boolean" private[codegen] val BoxedCharClass = "java/lang/Character" private[codegen] val BoxedArrayClass = "leon/codegen/runtime/ArrayBox" + private[codegen] val JavaListClass = "java/util/List" + private[codegen] val JavaIteratorClass = "java/util/Iterator" + private[codegen] val TupleClass = "leon/codegen/runtime/Tuple" private[codegen] val SetClass = "leon/codegen/runtime/Set" private[codegen] val MapClass = "leon/codegen/runtime/Map" @@ -76,6 +79,7 @@ trait CodeGeneration { private[codegen] val ChooseEntryPointClass = "leon/codegen/runtime/ChooseEntryPoint" private[codegen] val GenericValuesClass = "leon/codegen/runtime/GenericValues" private[codegen] val MonitorClass = "leon/codegen/runtime/LeonCodeGenRuntimeMonitor" + private[codegen] val HenkinClass = "leon/codegen/runtime/LeonCodeGenRuntimeHenkinMonitor" def idToSafeJVMName(id: Identifier) = { scala.reflect.NameTransformer.encode(id.uniqueName).replaceAll("\\.", "\\$") @@ -147,15 +151,15 @@ trait CodeGeneration { * @param funDef The function definition to be compiled * @param owner The module/class that contains `funDef` */ - def compileFunDef(funDef : FunDef, owner : Definition) { + def compileFunDef(funDef: FunDef, owner: Definition) { val isStatic = owner.isInstanceOf[ModuleDef] - + val cf = classes(owner) val (_,mn,_) = leonFunDefToJVMInfo(funDef).get val paramsTypes = funDef.params.map(a => typeToJVM(a.getType)) - val realParams = if (params.requireMonitor) { + val realParams = if (requireMonitor) { ("L" + MonitorClass + ";") +: paramsTypes } else { paramsTypes @@ -176,16 +180,15 @@ trait CodeGeneration { METHOD_ACC_PUBLIC | METHOD_ACC_FINAL ).asInstanceOf[U2]) - + val ch = m.codeHandler - + // An offset we introduce to the parameters: // 1 if this is a method, so we need "this" in position 0 of the stack // 1 if we are monitoring - val paramsOffset = Seq(!isStatic, params.requireMonitor).count(x => x) - val newMapping = - funDef.params.map(_.id).zipWithIndex.toMap.mapValues(_ + paramsOffset) - + val idParams = (if (requireMonitor) Seq(monitorID) else Seq.empty) ++ funDef.params.map(_.id) + val newMapping = idParams.zipWithIndex.toMap.mapValues(_ + (if (!isStatic) 1 else 0)) + val body = funDef.body.getOrElse(throw CompilationException("Can't compile a FunDef without body: "+funDef.id.name)) val bodyWithPre = if(funDef.hasPrecondition && params.checkContracts) { @@ -200,12 +203,14 @@ trait CodeGeneration { case _ => bodyWithPre } + val locals = Locals(newMapping, Map.empty, Map.empty, isStatic) + if (params.recordInvocations) { - // index of monitor object will be before the first Scala parameter - ch << ALoad(paramsOffset-1) << InvokeVirtual(MonitorClass, "onInvoke", "()V") + load(monitorID, ch)(locals) + ch << InvokeVirtual(MonitorClass, "onInvoke", "()V") } - mkExpr(bodyWithPost, ch)(Locals(newMapping, Map.empty, Map.empty, isStatic)) + mkExpr(bodyWithPost, ch)(locals) funDef.returnType match { case ValueType() => @@ -218,6 +223,318 @@ trait CodeGeneration { ch.freeze } + private[codegen] val lambdaToClass = scala.collection.mutable.Map.empty[Lambda, String] + private[codegen] val classToLambda = scala.collection.mutable.Map.empty[String, Lambda] + + private def compileLambda(l: Lambda, ch: CodeHandler)(implicit locals: Locals): Unit = { + val (normalized, structSubst) = purescala.ExprOps.normalizeStructure(l) + val reverseSubst = structSubst.map(p => p._2 -> p._1) + val nl = normalized.asInstanceOf[Lambda] + + val closureIDs = purescala.ExprOps.variablesOf(nl).toSeq.sortBy(_.uniqueName) + val closuresWithoutMonitor = closureIDs.map(id => id -> typeToJVM(id.getType)) + val closures = if (requireMonitor) { + (monitorID -> s"L$MonitorClass;") +: closuresWithoutMonitor + } else closuresWithoutMonitor + + val afName = lambdaToClass.getOrElse(nl, { + val afName = "Leon$CodeGen$Lambda$" + lambdaCounter.nextGlobal + lambdaToClass += nl -> afName + classToLambda += afName -> nl + + val cf = new ClassFile(afName, Some(LambdaClass)) + + cf.setFlags(( + CLASS_ACC_SUPER | + CLASS_ACC_PUBLIC | + CLASS_ACC_FINAL + ).asInstanceOf[U2]) + + if (closures.isEmpty) { + cf.addDefaultConstructor + } else { + for ((id, jvmt) <- closures) { + val fh = cf.addField(jvmt, id.uniqueName) + fh.setFlags(( + FIELD_ACC_PUBLIC | + FIELD_ACC_FINAL + ).asInstanceOf[U2]) + } + + val cch = cf.addConstructor(closures.map(_._2).toList).codeHandler + + cch << ALoad(0) + cch << InvokeSpecial(LambdaClass, constructorName, "()V") + + var c = 1 + for ((id, jvmt) <- closures) { + cch << ALoad(0) + cch << (jvmt match { + case "I" | "Z" => ILoad(c) + case _ => ALoad(c) + }) + cch << PutField(afName, id.uniqueName, jvmt) + c += 1 + } + + cch << RETURN + cch.freeze + } + + locally { + val apm = cf.addMethod(s"L$ObjectClass;", "apply", s"[L$ObjectClass;") + + apm.setFlags(( + METHOD_ACC_PUBLIC | + METHOD_ACC_FINAL + ).asInstanceOf[U2]) + + val argMapping = nl.args.map(_.id).zipWithIndex.toMap + val closureMapping = closures.map { case (id, jvmt) => id -> (afName, id.uniqueName, jvmt) }.toMap + + val newLocals = locals.withArgs(argMapping).withClosures(closureMapping) + + val apch = apm.codeHandler + + mkBoxedExpr(nl.body, apch)(newLocals) + + apch << ARETURN + + apch.freeze + } + + locally { + val emh = cf.addMethod("Z", "equals", s"L$ObjectClass;") + emh.setFlags(( + METHOD_ACC_PUBLIC | + METHOD_ACC_FINAL + ).asInstanceOf[U2]) + + val ech = emh.codeHandler + + val notRefEq = ech.getFreshLabel("notrefeq") + val notEq = ech.getFreshLabel("noteq") + val castSlot = ech.getFreshVar + + // If references are equal, trees are equal. + ech << ALoad(0) << ALoad(1) << If_ACmpNe(notRefEq) << Ldc(1) << IRETURN << Label(notRefEq) + + // We check the type (this also checks against null).... + ech << ALoad(1) << InstanceOf(afName) << IfEq(notEq) + + // ...finally, we compare fields one by one, shortcircuiting on disequalities. + if(closures.nonEmpty) { + ech << ALoad(1) << CheckCast(afName) << AStore(castSlot) + + for((id,jvmt) <- closures) { + ech << ALoad(0) << GetField(afName, id.uniqueName, jvmt) + ech << ALoad(castSlot) << GetField(afName, id.uniqueName, jvmt) + + jvmt match { + case "I" | "Z" => + ech << If_ICmpNe(notEq) + + case ot => + ech << InvokeVirtual(ObjectClass, "equals", s"(L$ObjectClass;)Z") << IfEq(notEq) + } + } + } + + ech << Ldc(1) << IRETURN << Label(notEq) << Ldc(0) << IRETURN + ech.freeze + } + + locally { + val hashFieldName = "$leon$hashCode" + cf.addField("I", hashFieldName).setFlags(FIELD_ACC_PRIVATE) + val hmh = cf.addMethod("I", "hashCode", "") + hmh.setFlags(( + METHOD_ACC_PUBLIC | + METHOD_ACC_FINAL + ).asInstanceOf[U2]) + + val hch = hmh.codeHandler + + val wasNotCached = hch.getFreshLabel("wasNotCached") + + hch << ALoad(0) << GetField(afName, hashFieldName, "I") << DUP + hch << IfEq(wasNotCached) + hch << IRETURN + hch << Label(wasNotCached) << POP + + hch << Ldc(closuresWithoutMonitor.size) << NewArray(s"$ObjectClass") + for (((id, jvmt),i) <- closuresWithoutMonitor.zipWithIndex) { + hch << DUP << Ldc(i) + hch << ALoad(0) << GetField(afName, id.uniqueName, jvmt) + mkBox(id.getType, hch) + hch << AASTORE + } + + hch << Ldc(afName.hashCode) + hch << InvokeStatic(HashingClass, "seqHash", s"([L$ObjectClass;I)I") << DUP + hch << ALoad(0) << SWAP << PutField(afName, hashFieldName, "I") + hch << IRETURN + + hch.freeze + } + + loader.register(cf) + + afName + }) + + val consSig = "(" + closures.map(_._2).mkString("") + ")V" + + ch << New(afName) << DUP + for ((id,jvmt) <- closures) { + if (id == monitorID) { + load(monitorID, ch) + } else { + mkExpr(Variable(reverseSubst(id)), ch) + } + } + ch << InvokeSpecial(afName, constructorName, consSig) + } + + private val typeIdCache = scala.collection.mutable.Map.empty[TypeTree, Int] + private[codegen] def typeId(tpe: TypeTree): Int = typeIdCache.get(tpe) match { + case Some(id) => id + case None => + val id = typeIdCache.size + typeIdCache += tpe -> id + id + } + + private def compileForall(f: Forall, ch: CodeHandler)(implicit locals: Locals): Unit = { + // make sure we have an available HenkinModel + val monitorOk = ch.getFreshLabel("monitorOk") + load(monitorID, ch) + ch << InstanceOf(HenkinClass) << IfNe(monitorOk) + ch << New(ImpossibleEvaluationClass) << DUP + ch << Ldc("Can't evaluate foralls without domain") + ch << InvokeSpecial(ImpossibleEvaluationClass, constructorName, "(Ljava/lang/String;)V") + ch << ATHROW + ch << Label(monitorOk) + + val Forall(fargs, TopLevelAnds(conjuncts)) = f + val endLabel = ch.getFreshLabel("forallEnd") + + for (conj <- conjuncts) { + val vars = purescala.ExprOps.variablesOf(conj) + val args = fargs.map(_.id).filter(vars) + val quantified = args.toSet + + val matchQuorums = extractQuorums(conj, quantified) + + val matcherIndexes = matchQuorums.flatten.distinct.zipWithIndex.toMap + + def buildLoops( + mis: List[(Expr, Seq[Expr], Int)], + localMapping: Map[Identifier, Int], + pointerMapping: Map[(Int, Int), Identifier] + ): Unit = mis match { + case (expr, args, qidx) :: rest => + load(monitorID, ch) + ch << CheckCast(HenkinClass) + + mkExpr(expr, ch) + ch << Ldc(typeId(expr.getType)) + ch << InvokeVirtual(HenkinClass, "domain", s"(L$ObjectClass;I)L$JavaListClass;") + ch << InvokeInterface(JavaListClass, "iterator", s"()L$JavaIteratorClass;") + + val loop = ch.getFreshLabel("loop") + val out = ch.getFreshLabel("out") + ch << Label(loop) + // it + ch << DUP + // it, it + ch << InvokeInterface(JavaIteratorClass, "hasNext", "()Z") + // it, hasNext + ch << IfEq(out) << DUP + // it, it + ch << InvokeInterface(JavaIteratorClass, "next", s"()L$ObjectClass;") + // it, elem + ch << CheckCast(TupleClass) + + val (newLoc, newPtr) = (for ((arg, aidx) <- args.zipWithIndex) yield { + val id = FreshIdentifier("q", arg.getType, true) + val slot = ch.getFreshVar + + ch << DUP << Ldc(aidx) << InvokeVirtual(TupleClass, "get", s"(I)L$ObjectClass;") + mkUnbox(arg.getType, ch) + ch << (typeToJVM(arg.getType) match { + case "I" | "Z" => IStore(slot) + case _ => AStore(slot) + }) + + (id -> slot, (qidx -> aidx) -> id) + }).unzip + + ch << POP + // it + + buildLoops(rest, localMapping ++ newLoc, pointerMapping ++ newPtr) + + ch << Goto(loop) + ch << Label(out) << POP + + case Nil => + var okLabel: Option[String] = None + for (quorum <- matchQuorums) { + okLabel.foreach(ok => ch << Label(ok)) + okLabel = Some(ch.getFreshLabel("quorumOk")) + + var mappings: Seq[(Identifier, Int, Int)] = Seq.empty + var constraints: Seq[(Expr, Int, Int)] = Seq.empty + var equalities: Seq[((Int, Int), (Int, Int))] = Seq.empty + + for (q @ (expr, args) <- quorum) { + val qidx = matcherIndexes(q) + val (qmappings, qconstraints) = args.zipWithIndex.partition { + case (Variable(id), aidx) => quantified(id) + case _ => false + } + + mappings ++= qmappings.map(p => (p._1.asInstanceOf[Variable].id, qidx, p._2)) + constraints ++= qconstraints.map(p => (p._1, qidx, p._2)) + } + + val mapping = for ((id, es) <- mappings.groupBy(_._1)) yield { + val base :: others = es.toList.map(p => (p._2, p._3)) + equalities ++= others.map(p => base -> p) + (id -> base) + } + + val enabler = andJoin(constraints.map { + case (e, qidx, aidx) => Equals(e, pointerMapping(qidx -> aidx).toVariable) + } ++ equalities.map { + case (k1, k2) => Equals(pointerMapping(k1).toVariable, pointerMapping(k2).toVariable) + }) + + mkExpr(enabler, ch)(locals.withVars(localMapping)) + ch << IfEq(okLabel.get) + + val varsMap = args.map(id => id -> localMapping(pointerMapping(mapping(id)))).toMap + mkExpr(conj, ch)(locals.withVars(varsMap)) + ch << IfNe(okLabel.get) + + // -- Forall is false! -- + // POP all the iterators... + for (_ <- List.range(0, matcherIndexes.size)) ch << POP + + // ... and return false + ch << Ldc(0) << Goto(endLabel) + } + + ch << Label(okLabel.get) + } + + buildLoops(matcherIndexes.toList.map { case ((e, as), idx) => (e, as, idx) }, Map.empty, Map.empty) + } + + ch << Ldc(1) << Label(endLabel) + } + private[codegen] def mkExpr(e: Expr, ch: CodeHandler, canDelegateToMkBranch: Boolean = true)(implicit locals: Locals) { e match { case Variable(id) => @@ -226,7 +543,7 @@ trait CodeGeneration { case Assert(cond, oerr, body) => mkExpr(IfExpr(Not(cond), Error(body.getType, oerr.getOrElse("Assertion failed @"+e.getPos)), body), ch) - case en@Ensuring(_, _) => + case en @ Ensuring(_, _) => mkExpr(en.toAssert, ch) case Let(i,d,b) => @@ -267,8 +584,10 @@ trait CodeGeneration { throw CompilationException("Unknown class : " + cct.id) } ch << New(ccName) << DUP - if (params.requireMonitor) - ch << ALoad(locals.monitorIndex) + if (requireMonitor) { + load(monitorID, ch) + } + for((a, vd) <- as zip cct.classDef.fields) { vd.getType match { case TypeParameter(_) => @@ -404,30 +723,29 @@ trait CodeGeneration { throw CompilationException("Unknown method : " + tfd.id) } - if (params.requireMonitor) { - // index of monitor object will be before the first Scala parameter - ch << ALoad(locals.monitorIndex) << InvokeVirtual(MonitorClass, "onInvoke", "()V") + if (requireMonitor) { + load(monitorID, ch) + ch << InvokeVirtual(MonitorClass, "onInvoke", "()V") } - + // Get static field ch << GetStatic(className, fieldName, typeToJVM(tfd.fd.returnType)) - + // unbox field (tfd.fd.returnType, tfd.returnType) match { case (TypeParameter(_), tpe) => mkUnbox(tpe, ch) case _ => } - + case FunctionInvocation(TypedFunDef(fd, Seq(tp)), Seq(set)) if fd == program.library.setToList.get => - val IteratorClass = "java/util/Iterator" val nil = CaseClass(CaseClassType(program.library.Nil.get, Seq(tp)), Seq()) val cons = program.library.Cons.get val (consName, ccApplySig) = leonClassToJVMInfo(cons).getOrElse { throw CompilationException("Unknown class : " + cons) } - + mkExpr(nil, ch) mkExpr(set, ch) //if (params.requireMonitor) { @@ -436,49 +754,50 @@ trait CodeGeneration { // No dynamic dispatching/overriding in Leon, // so no need to take care of own vs. "super" methods - ch << InvokeVirtual(SetClass, "getElements", s"()L$IteratorClass;") - + ch << InvokeVirtual(SetClass, "getElements", s"()L$JavaIteratorClass;") + val loop = ch.getFreshLabel("loop") val out = ch.getFreshLabel("out") ch << Label(loop) // list, it ch << DUP // list, it, it - ch << InvokeInterface(IteratorClass, "hasNext", "()Z") + ch << InvokeInterface(JavaIteratorClass, "hasNext", "()Z") // list, it, hasNext ch << IfEq(out) // list, it ch << DUP2 // list, it, list, it - ch << InvokeInterface(IteratorClass, "next", s"()L$ObjectClass;") << SWAP + ch << InvokeInterface(JavaIteratorClass, "next", s"()L$ObjectClass;") << SWAP // list, it, elem, list ch << New(consName) << DUP << DUP2_X2 // list, it, cons, cons, elem, list, cons, cons ch << POP << POP // list, it, cons, cons, elem, list - - if (params.requireMonitor) { - ch << ALoad(locals.monitorIndex) << DUP_X2 << POP + + if (requireMonitor) { + load(monitorID, ch) + ch << DUP_X2 << POP } ch << InvokeSpecial(consName, constructorName, ccApplySig) // list, it, newList ch << DUP_X2 << POP << SWAP << POP // newList, it ch << Goto(loop) - + ch << Label(out) // list, it ch << POP // list - + // Static lazy fields/ functions case fi @ FunctionInvocation(tfd, as) => val (cn, mn, ms) = leonFunDefToJVMInfo(tfd.fd).getOrElse { throw CompilationException("Unknown method : " + tfd.id) } - - if (params.requireMonitor) { - ch << ALoad(locals.monitorIndex) + + if (requireMonitor) { + load(monitorID, ch) } for((a, vd) <- as zip tfd.fd.params) { @@ -497,16 +816,16 @@ trait CodeGeneration { mkUnbox(tpe, ch) case _ => } - + // Strict fields are handled as fields case MethodInvocation(rec, _, tfd, _) if tfd.fd.canBeStrictField => val (className, fieldName, _) = leonFunDefToJVMInfo(tfd.fd).getOrElse { throw CompilationException("Unknown method : " + tfd.id) } - if (params.requireMonitor) { - // index of monitor object will be before the first Scala parameter - ch << ALoad(locals.monitorIndex) << InvokeVirtual(MonitorClass, "onInvoke", "()V") + if (requireMonitor) { + load(monitorID, ch) + ch << InvokeVirtual(MonitorClass, "onInvoke", "()V") } // Load receiver mkExpr(rec,ch) @@ -520,19 +839,19 @@ trait CodeGeneration { mkUnbox(tpe, ch) case _ => } - + // This is for lazy fields and real methods. // To access a lazy field, we call its accessor function. case MethodInvocation(rec, cd, tfd, as) => val (className, methodName, sig) = leonFunDefToJVMInfo(tfd.fd).getOrElse { throw CompilationException("Unknown method : " + tfd.id) } - + // Receiver of the method call mkExpr(rec,ch) - - if (params.requireMonitor) { - ch << ALoad(locals.monitorIndex) + + if (requireMonitor) { + load(monitorID, ch) } for((a, vd) <- as zip tfd.fd.params) { @@ -543,10 +862,10 @@ trait CodeGeneration { mkExpr(a, ch) } } - + // No interfaces in Leon, so no need to use InvokeInterface ch << InvokeVirtual(className, methodName, sig) - + (tfd.fd.returnType, tfd.returnType) match { case (TypeParameter(_), tpe) => mkUnbox(tpe, ch) @@ -566,84 +885,11 @@ trait CodeGeneration { mkUnbox(app.getType, ch) case l @ Lambda(args, body) => - val afName = "Leon$CodeGen$Lambda$" + lambdaCounter.nextGlobal - lambdas += afName -> l - - val cf = new ClassFile(afName, Some(LambdaClass)) - - cf.setFlags(( - CLASS_ACC_SUPER | - CLASS_ACC_PUBLIC | - CLASS_ACC_FINAL - ).asInstanceOf[U2]) - - val closures = purescala.ExprOps.variablesOf(l).toSeq.sortBy(_.uniqueName) - val closureTypes = closures.map(id => id.name -> typeToJVM(id.getType)) - - if (closureTypes.isEmpty) { - cf.addDefaultConstructor - } else { - for ((nme, jvmt) <- closureTypes) { - val fh = cf.addField(jvmt, nme) - fh.setFlags(( - FIELD_ACC_PUBLIC | - FIELD_ACC_FINAL - ).asInstanceOf[U2]) - } - - val cch = cf.addConstructor(closureTypes.map(_._2).toList).codeHandler - - cch << ALoad(0) - cch << InvokeSpecial(LambdaClass, constructorName, "()V") - - var c = 1 - for ((nme, jvmt) <- closureTypes) { - cch << ALoad(0) - cch << (jvmt match { - case "I" | "Z" => ILoad(c) - case _ => ALoad(c) - }) - cch << PutField(afName, nme, jvmt) - c += 1 - } - - cch << RETURN - cch.freeze - } - - locally { - - val apm = cf.addMethod(s"L$ObjectClass;", "apply", s"[L$ObjectClass;") - - apm.setFlags(( - METHOD_ACC_PUBLIC | - METHOD_ACC_FINAL - ).asInstanceOf[U2]) - - val argMapping = args.map(_.id).zipWithIndex.toMap - val closureMapping = (closures zip closureTypes).map { case (id, (name, tpe)) => id -> (afName, name, tpe) }.toMap - - val newLocals = locals.withArgs(argMapping).withClosures(closureMapping) - - val apch = apm.codeHandler + compileLambda(l, ch) - mkBoxedExpr(body, apch)(newLocals) + case f @ Forall(args, body) => + compileForall(f, ch) - apch << ARETURN - - apch.freeze - } - - loader.register(cf) - - val consSig = "(" + closures.map(id => typeToJVM(id.getType)).mkString("") + ")V" - - ch << New(afName) << DUP - for (a <- closures) { - mkExpr(Variable(a), ch) - } - ch << InvokeSpecial(afName, constructorName, consSig) - // Arithmetic case Plus(l, r) => mkExpr(l, ch) @@ -1133,7 +1379,7 @@ trait CodeGeneration { * @param lzy The lazy field to be compiled * @param owner The module/class containing `lzy` */ - def compileLazyField(lzy : FunDef, owner : Definition) { + def compileLazyField(lzy: FunDef, owner: Definition) { ctx.reporter.internalAssertion(lzy.canBeLazyField, s"Trying to compile non-lazy ${lzy.id.name} as a lazy field") val (_, accessorName, _ ) = leonFunDefToJVMInfo(lzy).get @@ -1158,7 +1404,7 @@ trait CodeGeneration { // accessor method locally { - val parameters = if (params.requireMonitor) { + val parameters = if (requireMonitor) { Seq("L" + MonitorClass + ";") } else Seq() @@ -1173,7 +1419,7 @@ trait CodeGeneration { val body = lzy.body.getOrElse(throw CompilationException("Lazy field without body?")) val initLabel = ch.getFreshLabel("isInitialized") - if (params.requireMonitor) { + if (requireMonitor) { ch << ALoad(if (isStatic) 0 else 1) << InvokeVirtual(MonitorClass, "onInvoke", "()V") } @@ -1271,9 +1517,8 @@ trait CodeGeneration { } else { ch << ALoad(0) << SWAP << PutField (className, name, jvmType) } - } - - + } + def compileAbstractClassDef(acd : AbstractClassDef) { val cName = defToJVMName(acd) @@ -1314,8 +1559,7 @@ trait CodeGeneration { // definition of the constructor locally { - - val constrParams = if (params.requireMonitor) { + val constrParams = if (requireMonitor) { Seq("L" + MonitorClass + ";") } else Seq() @@ -1330,8 +1574,8 @@ trait CodeGeneration { case Some(parent) => val pName = defToJVMName(parent.classDef) // Load monitor object - if (params.requireMonitor) cch << ALoad(1) - val constrSig = if (params.requireMonitor) "(L" + MonitorClass + ";)V" else "()V" + if (requireMonitor) cch << ALoad(1) + val constrSig = if (requireMonitor) "(L" + MonitorClass + ";)V" else "()V" cch << InvokeSpecial(pName, constructorName, constrSig) case None => @@ -1349,7 +1593,7 @@ trait CodeGeneration { cch << RETURN cch.freeze } - + } /** @@ -1385,8 +1629,6 @@ trait CodeGeneration { } } - - def compileCaseClassDef(ccd: CaseClassDef) { val cName = defToJVMName(ccd) @@ -1407,7 +1649,6 @@ trait CodeGeneration { } locally { - val (fields, methods) = ccd.methods partition { _.canBeField } val (strictFields, lazyFields) = fields partition { _.canBeStrictField } @@ -1430,7 +1671,7 @@ trait CodeGeneration { val namesTypes = ccd.fields.map { vd => (vd.id.name, typeToJVM(vd.getType)) } // definition of the constructor - if(!params.doInstrument && !params.requireMonitor && ccd.fields.isEmpty && !ccd.methods.exists(_.canBeField)) { + if(!params.doInstrument && !requireMonitor && ccd.fields.isEmpty && !ccd.methods.exists(_.canBeField)) { cf.addDefaultConstructor } else { for((nme, jvmt) <- namesTypes) { @@ -1447,13 +1688,10 @@ trait CodeGeneration { } // If we are monitoring function calls, we have an extra argument on the constructor - val realArgs = if (params.requireMonitor) { + val realArgs = if (requireMonitor) { ("L" + MonitorClass + ";") +: (namesTypes map (_._2)) } else namesTypes map (_._2) - // Offset of the first Scala parameter of the constructor - val paramOffset = if (params.requireMonitor) 2 else 1 - val cch = cf.addConstructor(realArgs.toList).codeHandler if (params.doInstrument) { @@ -1462,7 +1700,7 @@ trait CodeGeneration { cch << PutField(cName, instrumentedField, "I") } - var c = paramOffset + var c = if (requireMonitor) 2 else 1 for((nme, jvmt) <- namesTypes) { cch << ALoad(0) cch << (jvmt match { @@ -1478,8 +1716,8 @@ trait CodeGeneration { // Load this cch << ALoad(0) // Load monitor object - if (params.requireMonitor) cch << ALoad(1) - val constrSig = if (params.requireMonitor) "(L" + MonitorClass + ";)V" else "()V" + if (requireMonitor) cch << ALoad(1) + val constrSig = if (requireMonitor) "(L" + MonitorClass + ";)V" else "()V" cch << InvokeSpecial(pName.get, constructorName, constrSig) } else { // Call constructor of java.lang.Object @@ -1487,7 +1725,6 @@ trait CodeGeneration { cch << InvokeSpecial(ObjectClass, constructorName, "()V") } - // Now initialize fields for (lzy <- lazyFields) { initLazyField(cch, cName, lzy, isStatic = false)} for (field <- strictFields) { initStrictField(cch, cName , field, isStatic = false)} @@ -1612,7 +1849,7 @@ trait CodeGeneration { ).asInstanceOf[U2]) val hch = hmh.codeHandler - + val wasNotCached = hch.getFreshLabel("wasNotCached") hch << ALoad(0) << GetField(cName, hashFieldName, "I") << DUP @@ -1625,7 +1862,7 @@ trait CodeGeneration { hch << InvokeStatic(HashingClass, "seqHash", s"([L$ObjectClass;I)I") << DUP hch << ALoad(0) << SWAP << PutField(cName, hashFieldName, "I") hch << IRETURN - + hch.freeze } diff --git a/src/main/scala/leon/codegen/CompilationUnit.scala b/src/main/scala/leon/codegen/CompilationUnit.scala index a4dc5c0e3ac05aa58f16b7363a1b61a419737e3b..a4597b88b7cd5111f12b1bc4d8b97fcda2487a86 100644 --- a/src/main/scala/leon/codegen/CompilationUnit.scala +++ b/src/main/scala/leon/codegen/CompilationUnit.scala @@ -6,11 +6,12 @@ package codegen import purescala.Common._ import purescala.Definitions._ import purescala.Expressions._ +import purescala.ExprOps._ import purescala.Types._ import purescala.Extractors._ import purescala.Constructors._ - -import runtime.LeonCodeGenRuntimeMonitor +import codegen.runtime.LeonCodeGenRuntimeMonitor +import codegen.runtime.LeonCodeGenRuntimeHenkinMonitor import utils.UniqueCounter import cafebabe._ @@ -28,12 +29,17 @@ class CompilationUnit(val ctx: LeonContext, val program: Program, val params: CodeGenParams = CodeGenParams.default) extends CodeGeneration { + protected[codegen] val requireQuantification = program.definedFunctions.exists { fd => + exists { case _: Forall => true case _ => false } (fd.fullBody) + } + + protected[codegen] val requireMonitor = params.requireMonitor || requireQuantification + val loader = new CafebabeClassLoader(classOf[CompilationUnit].getClassLoader) - var lambdas = Map[String, Lambda]() - var classes = Map[Definition, ClassFile]() + var classes = Map[Definition, ClassFile]() var defToModuleOrClass = Map[Definition, Definition]() - + def defineClass(df: Definition) { val cName = defToJVMName(df) @@ -59,7 +65,7 @@ class CompilationUnit(val ctx: LeonContext, def leonClassToJVMInfo(cd: ClassDef): Option[(String, String)] = { classes.get(cd) match { case Some(cf) => - val monitorType = if (params.requireMonitor) "L"+MonitorClass+";" else "" + val monitorType = if (requireMonitor) "L"+MonitorClass+";" else "" val sig = "(" + monitorType + cd.fields.map(f => typeToJVM(f.getType)).mkString("") + ")V" Some((cf.className, sig)) case _ => None @@ -78,7 +84,7 @@ class CompilationUnit(val ctx: LeonContext, */ def leonFunDefToJVMInfo(fd: FunDef): Option[(String, String, String)] = { funDefInfo.get(fd).orElse { - val monitorType = if (params.requireMonitor) "L"+MonitorClass+";" else "" + val monitorType = if (requireMonitor) "L"+MonitorClass+";" else "" val sig = "(" + monitorType + fd.params.map(a => typeToJVM(a.getType)).mkString("") + ")" + typeToJVM(fd.returnType) @@ -121,6 +127,21 @@ class CompilationUnit(val ctx: LeonContext, conss.last } + def modelToJVM(model: solvers.Model, maxInvocations: Int): LeonCodeGenRuntimeMonitor = model match { + case hModel: solvers.HenkinModel => + val lhm = new LeonCodeGenRuntimeHenkinMonitor(maxInvocations) + for ((tpe, domain) <- hModel.domains; args <- domain) { + val tpeId = typeId(tpe) + // note here that it doesn't matter that `lhm` doesn't yet have its domains + // filled since all values in `args` should be grounded + val inputJvm = tupleConstructor.newInstance(args.map(valueToJVM(_)(lhm)).toArray).asInstanceOf[leon.codegen.runtime.Tuple] + lhm.add(tpeId, inputJvm) + } + lhm + case _ => + new LeonCodeGenRuntimeMonitor(maxInvocations) + } + /** Translates Leon values (not generic expressions) to JVM compatible objects. * * Currently, this method is only used to prepare arguments to reflective calls. @@ -155,7 +176,7 @@ class CompilationUnit(val ctx: LeonContext, case CaseClass(cct, args) => caseClassConstructor(cct.classDef) match { case Some(cons) => - val realArgs = if (params.requireMonitor) monitor +: args.map(valueToJVM) else args.map(valueToJVM) + val realArgs = if (requireMonitor) monitor +: args.map(valueToJVM) else args.map(valueToJVM) cons.newInstance(realArgs.toArray : _*).asInstanceOf[AnyRef] case None => ctx.reporter.fatalError("Case class constructor not found?!?") @@ -180,11 +201,9 @@ class CompilationUnit(val ctx: LeonContext, } m - case f @ purescala.Extractors.FiniteLambda(dflt, els) => - val l = new leon.codegen.runtime.FiniteLambda(valueToJVM(dflt)) - - for ((k,v) <- els) { - val ks = unwrapTuple(k, f.getType.asInstanceOf[FunctionType].from.size) + case f @ PartialLambda(mapping, _) => + val l = new leon.codegen.runtime.PartialLambda() + for ((ks,v) <- mapping) { // Force tuple even with 1/0 elems. val kJvm = tupleConstructor.newInstance(ks.map(valueToJVM).toArray).asInstanceOf[leon.codegen.runtime.Tuple] val vJvm = valueToJVM(v) @@ -261,7 +280,7 @@ class CompilationUnit(val ctx: LeonContext, case (lambda: runtime.Lambda, _: FunctionType) => val cls = lambda.getClass - val l = lambdas(cls.getName) + val l = classToLambda(cls.getName) val closures = purescala.ExprOps.variablesOf(l).toSeq.sortBy(_.uniqueName) val closureVals = closures.map { id => val fieldVal = lambda.getClass.getField(id.name).get(lambda) @@ -308,7 +327,7 @@ class CompilationUnit(val ctx: LeonContext, val argsTypes = args.map(a => typeToJVM(a.getType)) - val realArgs = if (params.requireMonitor) { + val realArgs = if (requireMonitor) { ("L" + MonitorClass + ";") +: argsTypes } else { argsTypes @@ -328,11 +347,11 @@ class CompilationUnit(val ctx: LeonContext, val ch = m.codeHandler - val newMapping = if (params.requireMonitor) { - args.zipWithIndex.toMap.mapValues(_ + 1) - } else { - args.zipWithIndex.toMap - } + val newMapping = if (requireMonitor) { + args.zipWithIndex.toMap.mapValues(_ + 1) + (monitorID -> 0) + } else { + args.zipWithIndex.toMap + } mkExpr(e, ch)(Locals(newMapping, Map.empty, Map.empty, isStatic = true)) @@ -454,7 +473,6 @@ class CompilationUnit(val ctx: LeonContext, } for (m <- u.modules) compileModule(m) - } classes.values.foreach(loader.register) diff --git a/src/main/scala/leon/codegen/CompiledExpression.scala b/src/main/scala/leon/codegen/CompiledExpression.scala index e31e28e92922875a72a3a9cd233fade73a25223c..ad012bb74f4606bfeb53924c97015a2c3ce54fe5 100644 --- a/src/main/scala/leon/codegen/CompiledExpression.scala +++ b/src/main/scala/leon/codegen/CompiledExpression.scala @@ -8,11 +8,11 @@ import purescala.Expressions._ import cafebabe._ -import runtime.{LeonCodeGenRuntimeMonitor => LM } +import runtime.{LeonCodeGenRuntimeMonitor => LM, LeonCodeGenRuntimeHenkinMonitor => LHM} import java.lang.reflect.InvocationTargetException -class CompiledExpression(unit: CompilationUnit, cf: ClassFile, expression : Expr, argsDecl: Seq[Identifier]) { +class CompiledExpression(unit: CompilationUnit, cf: ClassFile, expression: Expr, argsDecl: Seq[Identifier]) { private lazy val cl = unit.loader.loadClass(cf.className) private lazy val meth = cl.getMethods()(0) @@ -28,7 +28,7 @@ class CompiledExpression(unit: CompilationUnit, cf: ClassFile, expression : Expr def evalToJVM(args: Seq[AnyRef], monitor: LM): AnyRef = { assert(args.size == argsDecl.size) - val realArgs = if (params.requireMonitor) { + val realArgs = if (unit.requireMonitor) { monitor +: args } else { args @@ -51,11 +51,10 @@ class CompiledExpression(unit: CompilationUnit, cf: ClassFile, expression : Expr } } - def eval(args: Seq[Expr]) : Expr = { + def eval(model: solvers.Model) : Expr = { try { - val monitor = - new LM(params.maxFunctionInvocations) - evalFromJVM(argsToJVM(args, monitor),monitor) + val monitor = unit.modelToJVM(model, params.maxFunctionInvocations) + evalFromJVM(argsToJVM(argsDecl.map(model), monitor), monitor) } catch { case ite : InvocationTargetException => throw ite.getCause } diff --git a/src/main/scala/leon/datagen/NaiveDataGen.scala b/src/main/scala/leon/datagen/NaiveDataGen.scala index 3949935200be25d5c219e1752ea1b197c173ed02..121c42b8cd561fe5faecd6d59662e66f1c1c3e02 100644 --- a/src/main/scala/leon/datagen/NaiveDataGen.scala +++ b/src/main/scala/leon/datagen/NaiveDataGen.scala @@ -7,6 +7,7 @@ import purescala.Common._ import purescala.Expressions._ import purescala.Types._ import purescala.Definitions._ +import purescala.Quantification._ import utils.StreamUtils._ import evaluators._ @@ -87,7 +88,7 @@ class NaiveDataGen(ctx: LeonContext, p: Program, evaluator: Evaluator, _bounds : evaluator.compile(satisfying, ins).map { evalFun => val sat = EvaluationResults.Successful(BooleanLiteral(true)) - { (e: Seq[Expr]) => evalFun(e) == sat } + { (e: Seq[Expr]) => evalFun(new solvers.Model((ins zip e).toMap)) == sat } } getOrElse { { (e: Seq[Expr]) => false } } diff --git a/src/main/scala/leon/datagen/VanuatooDataGen.scala b/src/main/scala/leon/datagen/VanuatooDataGen.scala index 0cf311562c9ffe37020186cf96365ba5b15b8178..b742ec58f68dc34b35cd2e093578c4e19b2462ef 100644 --- a/src/main/scala/leon/datagen/VanuatooDataGen.scala +++ b/src/main/scala/leon/datagen/VanuatooDataGen.scala @@ -224,7 +224,7 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { type InstrumentedResult = (EvaluationResults.Result, Option[vanuatoo.Pattern[Expr, TypeTree]]) - def compile(expression : Expr, argorder : Seq[Identifier]) : Option[Expr=>InstrumentedResult] = { + def compile(expression: Expr, argorder: Seq[Identifier]) : Option[Expr=>InstrumentedResult] = { import leon.codegen.runtime.LeonCodeGenRuntimeException import leon.codegen.runtime.LeonCodeGenEvaluationException @@ -241,7 +241,7 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { Some((args : Expr) => { try { val monitor = new LeonCodeGenRuntimeMonitor(unit.params.maxFunctionInvocations) - val jvmArgs = ce.argsToJVM(Seq(args), monitor ) + val jvmArgs = ce.argsToJVM(Seq(args), monitor) val result = ce.evalFromJVM(jvmArgs, monitor) diff --git a/src/main/scala/leon/evaluators/CodeGenEvaluator.scala b/src/main/scala/leon/evaluators/CodeGenEvaluator.scala index 64456226b16815c056202738c7946c2e0cef9968..769748dff385e23686681eabfe4fe130f70e8832 100644 --- a/src/main/scala/leon/evaluators/CodeGenEvaluator.scala +++ b/src/main/scala/leon/evaluators/CodeGenEvaluator.scala @@ -6,11 +6,12 @@ package evaluators import purescala.Common._ import purescala.Definitions._ import purescala.Expressions._ +import purescala.Quantification._ import codegen.CompilationUnit import codegen.CodeGenParams -class CodeGenEvaluator(ctx : LeonContext, val unit : CompilationUnit) extends Evaluator(ctx, unit.program) { +class CodeGenEvaluator(ctx: LeonContext, val unit : CompilationUnit) extends Evaluator(ctx, unit.program) { val name = "codegen-eval" val description = "Evaluator for PureScala expressions based on compilation to JVM" @@ -19,28 +20,29 @@ class CodeGenEvaluator(ctx : LeonContext, val unit : CompilationUnit) extends Ev this(ctx, new CompilationUnit(ctx, prog, params)) } - def eval(expression : Expr, mapping : Map[Identifier,Expr]) : EvaluationResult = { - val toPairs = mapping.toSeq + def eval(expression: Expr, model: solvers.Model) : EvaluationResult = { + val toPairs = model.toSeq compile(expression, toPairs.map(_._1)).map { e => - ctx.timers.evaluators.codegen.runtime.start() - val res = e(toPairs.map(_._2)) + val res = e(model) ctx.timers.evaluators.codegen.runtime.stop() res }.getOrElse(EvaluationResults.EvaluatorError("Couldn't compile expression.")) } - override def compile(expression : Expr, argorder : Seq[Identifier]) : Option[Seq[Expr]=>EvaluationResult] = { + override def compile(expression: Expr, args: Seq[Identifier]) : Option[solvers.Model=>EvaluationResult] = { import leon.codegen.runtime.LeonCodeGenRuntimeException import leon.codegen.runtime.LeonCodeGenEvaluationException ctx.timers.evaluators.codegen.compilation.start() try { - val ce = unit.compileExpression(expression, argorder)(ctx) + val ce = unit.compileExpression(expression, args)(ctx) - Some((args : Seq[Expr]) => { - try { - EvaluationResults.Successful(ce.eval(args)) + Some((model: solvers.Model) => { + if (args.exists(arg => !model.isDefinedAt(arg))) { + EvaluationResults.EvaluatorError("Model undefined for free arguments") + } else try { + EvaluationResults.Successful(ce.eval(model)) } catch { case e : ArithmeticException => EvaluationResults.RuntimeError(e.getMessage) @@ -65,6 +67,7 @@ class CodeGenEvaluator(ctx : LeonContext, val unit : CompilationUnit) extends Ev } catch { case t: Throwable => ctx.reporter.warning(expression.getPos, "Error while compiling expression: "+t.getMessage) + t.printStackTrace() None } finally { ctx.timers.evaluators.codegen.compilation.stop() diff --git a/src/main/scala/leon/evaluators/DefaultEvaluator.scala b/src/main/scala/leon/evaluators/DefaultEvaluator.scala index 148ac359a6e7db5b61db55bd8b3a517776a39609..d732d48c0d40aaacf97cd8125b077c1cef397148 100644 --- a/src/main/scala/leon/evaluators/DefaultEvaluator.scala +++ b/src/main/scala/leon/evaluators/DefaultEvaluator.scala @@ -6,13 +6,14 @@ package evaluators import purescala.Common._ import purescala.Expressions._ import purescala.Definitions._ +import purescala.Quantification._ class DefaultEvaluator(ctx: LeonContext, prog: Program) extends RecursiveEvaluator(ctx, prog, 50000) { type RC = DefaultRecContext type GC = GlobalContext def initRC(mappings: Map[Identifier, Expr]) = DefaultRecContext(mappings) - def initGC = new GlobalContext() + def initGC(model: solvers.Model) = new GlobalContext(model) case class DefaultRecContext(mappings: Map[Identifier, Expr]) extends RecContext { def newVars(news: Map[Identifier, Expr]) = copy(news) diff --git a/src/main/scala/leon/evaluators/DualEvaluator.scala b/src/main/scala/leon/evaluators/DualEvaluator.scala index a058dd3a04fe53700aa71265ba5952b9fc82f6ce..cd843fbb145e4f9220b2e9fe3d91e30d8ff3c1be 100644 --- a/src/main/scala/leon/evaluators/DualEvaluator.scala +++ b/src/main/scala/leon/evaluators/DualEvaluator.scala @@ -6,6 +6,7 @@ package evaluators import purescala.Common._ import purescala.Expressions._ import purescala.Definitions._ +import purescala.Quantification._ import purescala.Types._ import codegen._ @@ -17,7 +18,7 @@ class DualEvaluator(ctx: LeonContext, prog: Program, params: CodeGenParams) exte implicit val debugSection = utils.DebugSectionEvaluation def initRC(mappings: Map[Identifier, Expr]) = DefaultRecContext(mappings) - def initGC = new GlobalContext() + def initGC(model: solvers.Model) = new GlobalContext(model) var monitor = new runtime.LeonCodeGenRuntimeMonitor(params.maxFunctionInvocations) @@ -125,9 +126,9 @@ class DualEvaluator(ctx: LeonContext, prog: Program, params: CodeGenParams) exte } - override def eval(ex: Expr, mappings: Map[Identifier, Expr]) = { + override def eval(ex: Expr, model: solvers.Model) = { monitor = new runtime.LeonCodeGenRuntimeMonitor(params.maxFunctionInvocations) - super.eval(ex, mappings) + super.eval(ex, model) } } diff --git a/src/main/scala/leon/evaluators/Evaluator.scala b/src/main/scala/leon/evaluators/Evaluator.scala index fa1f18352381c220610184c147a52b8537335fed..e24c0e364dfca43be0da185aab0a4d549d2ff785 100644 --- a/src/main/scala/leon/evaluators/Evaluator.scala +++ b/src/main/scala/leon/evaluators/Evaluator.scala @@ -6,27 +6,35 @@ package evaluators import purescala.Common._ import purescala.Definitions._ import purescala.Expressions._ +import purescala.Quantification._ +import purescala.ExprOps._ -abstract class Evaluator(val context : LeonContext, val program : Program) extends LeonComponent { +import solvers.Model + +abstract class Evaluator(val context: LeonContext, val program: Program) extends LeonComponent { type EvaluationResult = EvaluationResults.Result /** Evaluates an expression, using `mapping` as a valuation function for the free variables. */ - def eval(expr: Expr, mapping: Map[Identifier,Expr]) : EvaluationResult + def eval(expr: Expr, model: Model) : EvaluationResult + + /** Evaluates an expression given a simple model (assumes expr is quantifier-free). + * Mainly useful for compatibility reasons. + */ + final def eval(expr: Expr, mapping: Map[Identifier, Expr]) : EvaluationResult = eval(expr, new Model(mapping)) /** Evaluates a ground expression. */ - final def eval(expr: Expr) : EvaluationResult = eval(expr, Map.empty) + final def eval(expr: Expr) : EvaluationResult = eval(expr, Model.empty) /** Compiles an expression into a function, where the arguments are the free variables in the expression. * `argorder` specifies in which order the arguments should be passed. * The default implementation uses the evaluation function each time, but evaluators are free * to (and encouraged to) apply any specialization. */ - def compile(expr : Expr, argorder : Seq[Identifier]) : Option[Seq[Expr]=>EvaluationResult] = Some( - (args : Seq[Expr]) => if(args.size != argorder.size) { - EvaluationResults.EvaluatorError("Wrong number of arguments for evaluation.") + def compile(expr: Expr, args: Seq[Identifier]) : Option[Model => EvaluationResult] = Some( + (model: Model) => if(args.exists(arg => !model.isDefinedAt(arg))) { + EvaluationResults.EvaluatorError("Wrong number of arguments for evaluation.") } else { - val mapping = argorder.zip(args).toMap - eval(expr, mapping) + eval (expr, model) } ) } diff --git a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala index 38710837fb255db8596334a1fe818b62835556fe..8ca885dbed64a3b51426182fd2e8a97f65f2f421 100644 --- a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala +++ b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala @@ -11,7 +11,9 @@ import purescala.Types._ import purescala.TypeOps.isSubtypeOf import purescala.Constructors._ import purescala.Extractors._ +import purescala.Quantification._ +import solvers.{Model, HenkinModel} import solvers.SolverFactory import synthesis.ConvertHoles.convertHoles @@ -41,25 +43,25 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int } } - class GlobalContext { + class GlobalContext(val model: Model) { def maxSteps = RecursiveEvaluator.this.maxSteps var stepsLeft = maxSteps } def initRC(mappings: Map[Identifier, Expr]): RC - def initGC(): GC + def initGC(model: Model): GC // Used by leon-web, please do not delete var lastGC: Option[GC] = None private[this] var clpCache = Map[(Choose, Seq[Expr]), Expr]() - def eval(ex: Expr, mappings: Map[Identifier, Expr]) = { + def eval(ex: Expr, model: Model) = { try { - lastGC = Some(initGC()) + lastGC = Some(initGC(model)) ctx.timers.evaluators.recursive.runtime.start() - EvaluationResults.Successful(e(ex)(initRC(mappings), lastGC.get)) + EvaluationResults.Successful(e(ex)(initRC(model.toMap), lastGC.get)) } catch { case so: StackOverflowError => EvaluationResults.EvaluatorError("Stack overflow") @@ -87,10 +89,16 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int case Application(caller, args) => e(caller) match { - case l@Lambda(params, body) => + case l @ Lambda(params, body) => val newArgs = args.map(e) val mapping = l.paramSubst(newArgs) e(body)(rctx.withNewVars(mapping), gctx) + case PartialLambda(mapping, _) => + mapping.find { case (pargs, res) => + (args zip pargs).forall(p => e(Equals(p._1, p._2)) == BooleanLiteral(true)) + }.map(_._2).getOrElse { + throw EvalError("Cannot apply partial lambda outside of domain") + } case f => throw EvalError("Cannot apply non-lambda function " + f.asString) } @@ -217,6 +225,7 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int (lv,rv) match { case (FiniteSet(el1, _),FiniteSet(el2, _)) => BooleanLiteral(el1 == el2) case (FiniteMap(el1, _, _),FiniteMap(el2, _, _)) => BooleanLiteral(el1.toSet == el2.toSet) + case (PartialLambda(m1, _), PartialLambda(m2, _)) => BooleanLiteral(m1.toSet == m2.toSet) case _ => BooleanLiteral(lv == rv) } @@ -487,8 +496,71 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int FiniteSet(els.map(e), base) case l @ Lambda(_, _) => - val mapping = variablesOf(l).map(id => id -> e(Variable(id))).toMap - replaceFromIDs(mapping, l) + val (nl, structSubst) = normalizeStructure(l) + val mapping = variablesOf(l).map(id => structSubst(id) -> e(Variable(id))).toMap + replaceFromIDs(mapping, nl) + + case PartialLambda(mapping, tpe) => + PartialLambda(mapping.map(p => p._1.map(e) -> e(p._2)), tpe) + + case f @ Forall(fargs, TopLevelAnds(conjuncts)) => + val henkinModel: HenkinModel = gctx.model match { + case hm: HenkinModel => hm + case _ => throw EvalError("Can't evaluate foralls without henkin model") + } + + e(andJoin(for (conj <- conjuncts) yield { + val vars = variablesOf(conj) + val args = fargs.map(_.id).filter(vars) + val quantified = args.toSet + + val matcherQuorums = extractQuorums(conj, quantified) + + val instantiations = matcherQuorums.flatMap { quorum => + var mappings: Seq[(Identifier, Int, Int)] = Seq.empty + var constraints: Seq[(Expr, Int, Int)] = Seq.empty + + for (((expr, args), qidx) <- quorum.zipWithIndex) { + val (qmappings, qconstraints) = args.zipWithIndex.partition { + case (Variable(id),aidx) => quantified(id) + case _ => false + } + + mappings ++= qmappings.map(p => (p._1.asInstanceOf[Variable].id, qidx, p._2)) + constraints ++= qconstraints.map(p => (p._1, qidx, p._2)) + } + + var equalities: Seq[((Int, Int), (Int, Int))] = Seq.empty + val mapping = for ((id, es) <- mappings.groupBy(_._1)) yield { + val base :: others = es.toList.map(p => (p._2, p._3)) + equalities ++= others.map(p => base -> p) + (id -> base) + } + + val argSets = quorum.foldLeft[List[Seq[Seq[Expr]]]](List(Seq.empty)) { + case (acc, (expr, _)) => acc.flatMap(s => henkinModel.domain(expr).map(d => s :+ d)) + } + + argSets.map { args => + val argMap: Map[(Int, Int), Expr] = args.zipWithIndex.flatMap { + case (a, qidx) => a.zipWithIndex.map { case (e, aidx) => (qidx, aidx) -> e } + }.toMap + + val map = mapping.map { case (id, key) => id -> argMap(key) } + val enabler = andJoin(constraints.map { + case (e, qidx, aidx) => Equals(e, argMap(qidx -> aidx)) + } ++ equalities.map { + case (k1, k2) => Equals(argMap(k1), argMap(k2)) + }) + + (enabler, map) + } + } + + e(andJoin(instantiations.map { case (enabler, mapping) => + e(Implies(enabler, conj))(rctx.withNewVars(mapping), gctx) + })) + })) case ArrayLength(a) => val FiniteArray(_, _, IntLiteral(length)) = e(a) diff --git a/src/main/scala/leon/evaluators/TracingEvaluator.scala b/src/main/scala/leon/evaluators/TracingEvaluator.scala index ea5ad3e0d76b5fa5230aa203d29ba214c7d3f15f..ec977763f3da3f583d225cbdde8dd98aa33f312b 100644 --- a/src/main/scala/leon/evaluators/TracingEvaluator.scala +++ b/src/main/scala/leon/evaluators/TracingEvaluator.scala @@ -6,6 +6,7 @@ package evaluators import purescala.Common._ import purescala.Expressions._ import purescala.Definitions._ +import purescala.Quantification._ import purescala.Types._ class TracingEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int = 1000) extends RecursiveEvaluator(ctx, prog, maxSteps) { @@ -14,9 +15,9 @@ class TracingEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int = 1000) ex def initRC(mappings: Map[Identifier, Expr]) = TracingRecContext(mappings, 2) - def initGC() = new TracingGlobalContext(Nil) + def initGC(model: solvers.Model) = new TracingGlobalContext(Nil, model) - class TracingGlobalContext(var values: List[(Tree, Expr)]) extends GlobalContext + class TracingGlobalContext(var values: List[(Tree, Expr)], model: solvers.Model) extends GlobalContext(model) case class TracingRecContext(mappings: Map[Identifier, Expr], tracingFrames: Int) extends RecContext { def newVars(news: Map[Identifier, Expr]) = copy(mappings = news) diff --git a/src/main/scala/leon/purescala/CheckForalls.scala b/src/main/scala/leon/purescala/CheckForalls.scala deleted file mode 100644 index bb9874373cc560794ec429ae2bbad6ea21a6c1dc..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/purescala/CheckForalls.scala +++ /dev/null @@ -1,106 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - -package leon -package purescala - -import Common._ -import Definitions._ -import Expressions._ -import Extractors._ -import ExprOps._ - -object CheckForalls extends UnitPhase[Program] { - - val name = "Foralls" - val description = "Check syntax of foralls to guarantee sound instantiations" - - def apply(ctx: LeonContext, program: Program) = { - program.definedFunctions.foreach { fd => - if (fd.body.exists(b => exists { - case f: Forall => true - case _ => false - } (b))) ctx.reporter.warning("Universal quantification in function bodies is not supported in " + fd) - - val foralls = (fd.precondition.toSeq ++ fd.postcondition.toSeq).flatMap { prec => - collect[Forall] { - case f: Forall => Set(f) - case _ => Set.empty - } (prec) - } - - val free = fd.params.map(_.id).toSet ++ (fd.postcondition match { - case Some(Lambda(args, _)) => args.map(_.id) - case _ => Seq.empty - }) - - object Matcher { - def unapply(e: Expr): Option[(Identifier, Seq[Expr])] = e match { - case Application(Variable(id), args) if free(id) => Some(id -> args) - case ArraySelect(Variable(id), index) if free(id) => Some(id -> Seq(index)) - case MapGet(Variable(id), key) if free(id) => Some(id -> Seq(key)) - case _ => None - } - } - - for (Forall(args, TopLevelAnds(conjuncts)) <- foralls) { - val quantified = args.map(_.id).toSet - - for (conjunct <- conjuncts) { - val matchers = collect[(Identifier, Seq[Expr])] { - case Matcher(id, args) => Set(id -> args) - case _ => Set.empty - } (conjunct) - - if (matchers.exists { case (id, args) => - args.exists(arg => arg match { - case Matcher(_, _) => false - case Variable(id) => false - case _ if (variablesOf(arg) & quantified).nonEmpty => true - case _ => false - }) - }) ctx.reporter.warning("Matcher arguments must have simple form in " + conjunct) - - val id2Quant = matchers.foldLeft(Map.empty[Identifier, Set[Identifier]]) { - case (acc, (m, args)) => acc + (m -> (acc.getOrElse(m, Set.empty) ++ args.flatMap { - case Variable(id) if quantified(id) => Set(id) - case _ => Set.empty[Identifier] - })) - } - - if (id2Quant.filter(_._2.nonEmpty).groupBy(_._2).size != 1) - ctx.reporter.warning("Multiple matchers must provide bijective matching in " + conjunct) - - foldRight[Set[Identifier]] { case (m, children) => - val q = children.toSet.flatten - - m match { - case Matcher(_, args) => - q -- args.flatMap { - case Variable(id) if quantified(id) => Set(id) - case _ => Set.empty[Identifier] - } - case LessThan(_: Variable, _: Variable) => q - case LessEquals(_: Variable, _: Variable) => q - case GreaterThan(_: Variable, _: Variable) => q - case GreaterEquals(_: Variable, _: Variable) => q - case And(_) => q - case Or(_) => q - case Implies(_, _) => q - case Operator(es, _) => - val vars = es.flatMap { - case Variable(id) => Set(id) - case _ => Set.empty[Identifier] - }.toSet - - if (!(q.isEmpty || (q.size == 1 && (vars & free).isEmpty))) - ctx.reporter.warning("Invalid operation " + m + " on quantified variables") - q -- vars - case Variable(id) if quantified(id) => Set(id) - case _ => q - } - } (conjunct) - } - } - } - } -} diff --git a/src/main/scala/leon/purescala/Constructors.scala b/src/main/scala/leon/purescala/Constructors.scala index 9d347eccf8797ec3157929f5a58796ffaaae7cad..7747ec680948859ea6099c595578705ad33f27ed 100644 --- a/src/main/scala/leon/purescala/Constructors.scala +++ b/src/main/scala/leon/purescala/Constructors.scala @@ -273,23 +273,6 @@ object Constructors { NonemptyArray(els.zipWithIndex.map{ _.swap }.toMap, defaultLength) } - /** Takes a mapping from keys to values and a default expression and return a lambda of the form - * {{{ - * (x1, ..., xn) => - * if ( key1 == (x1, ..., xn) ) value1 - * else if ( key2 == (x1, ..., xn) ) value2 - * ... - * else default - * }}} - */ - def finiteLambda(default: Expr, els: Seq[(Expr, Expr)], inputTypes: Seq[TypeTree]): Lambda = { - val args = inputTypes map { tpe => ValDef(FreshIdentifier("x", tpe, true)) } - val argsExpr = tupleWrap(args map { _.toVariable }) - val body = els.foldRight(default) { case ((key, value), default) => - IfExpr(Equals(argsExpr, key), value, default) - } - Lambda(args, body) - } /** $encodingof simplified `... == ...` (equality). * @see [[purescala.Expressions.Equals Equals]] */ diff --git a/src/main/scala/leon/purescala/Definitions.scala b/src/main/scala/leon/purescala/Definitions.scala index 319d0d6102602dfc982444f02259fbb23097129f..7a8056c6a9492c723a81b6ef1854f4a564816d7e 100644 --- a/src/main/scala/leon/purescala/Definitions.scala +++ b/src/main/scala/leon/purescala/Definitions.scala @@ -71,7 +71,7 @@ object Definitions { lazy val library = Library(this) def subDefinitions = units - + def definedFunctions = units.flatMap(_.definedFunctions) def definedClasses = units.flatMap(_.definedClasses) def classHierarchyRoots = units.flatMap(_.classHierarchyRoots) @@ -81,7 +81,7 @@ object Definitions { case md: ModuleDef => md }) } - + lazy val callGraph = new CallGraph(this) def caseClassDef(name: String) = definedClasses.collectFirst { diff --git a/src/main/scala/leon/purescala/ExprOps.scala b/src/main/scala/leon/purescala/ExprOps.scala index 3f641000ae024b6482a42f6781e8d641d431fde4..1d98f57b87c86d91e61bee3e68c0894feac6780f 100644 --- a/src/main/scala/leon/purescala/ExprOps.scala +++ b/src/main/scala/leon/purescala/ExprOps.scala @@ -26,7 +26,7 @@ import solvers._ * - [[ExprOps.postMap postMap]] * - [[ExprOps.genericTransform genericTransform]] * - * These operations usually take a higher order function that gets apply to the + * These operations usually take a higher order function that gets applied to the * expression tree in some strategy. They provide an expressive way to build complex * operations on Leon expressions. * @@ -317,8 +317,8 @@ object ExprOps { case Variable(i) => subvs + i case LetDef(fd,_) => subvs -- fd.params.map(_.id) case Let(i,_,_) => subvs - i - case MatchExpr(_, cses) => subvs -- cses.flatMap(_.pattern.binders) - case Passes(_, _ , cses) => subvs -- cses.flatMap(_.pattern.binders) + case MatchExpr(_, cses) => subvs -- cses.flatMap(_.pattern.binders) + case Passes(_, _, cses) => subvs -- cses.flatMap(_.pattern.binders) case Lambda(args, _) => subvs -- args.map(_.id) case Forall(args, _) => subvs -- args.map(_.id) case _ => subvs @@ -369,19 +369,23 @@ object ExprOps { }).setPos(expr) } + def replacePatternBinders(pat: Pattern, subst: Map[Identifier, Identifier]): Pattern = { + def rec(p: Pattern): Pattern = p match { + case InstanceOfPattern(ob, ctd) => InstanceOfPattern(ob map subst, ctd) + case WildcardPattern(ob) => WildcardPattern(ob map subst) + case TuplePattern(ob, sps) => TuplePattern(ob map subst, sps map rec) + case CaseClassPattern(ob, ccd, sps) => CaseClassPattern(ob map subst, ccd, sps map rec) + case UnapplyPattern(ob, obj, sps) => UnapplyPattern(ob map subst, obj, sps map rec) + case LiteralPattern(ob, lit) => LiteralPattern(ob map subst, lit) + } + + rec(pat) + } + /** ATTENTION: Unused, and untested * rewrites pattern-matching expressions to use fresh variables for the binders */ def freshenLocals(expr: Expr) : Expr = { - def rewritePattern(p: Pattern, sm: Map[Identifier,Identifier]) : Pattern = p match { - case InstanceOfPattern(ob, ctd) => InstanceOfPattern(ob map sm, ctd) - case WildcardPattern(ob) => WildcardPattern(ob map sm) - case TuplePattern(ob, sps) => TuplePattern(ob.map(sm(_)), sps.map(rewritePattern(_, sm))) - case CaseClassPattern(ob, ccd, sps) => CaseClassPattern(ob.map(sm(_)), ccd, sps.map(rewritePattern(_, sm))) - case UnapplyPattern(ob, obj, sps) => UnapplyPattern(ob.map(sm(_)), obj, sps.map(rewritePattern(_, sm))) - case LiteralPattern(ob, lit) => LiteralPattern(ob map sm, lit) - } - def freshenCase(cse: MatchCase) : MatchCase = { val allBinders: Set[Identifier] = cse.pattern.binders val subMap: Map[Identifier,Identifier] = @@ -389,7 +393,7 @@ object ExprOps { val subVarMap: Map[Expr,Expr] = subMap.map(kv => Variable(kv._1) -> Variable(kv._2)) MatchCase( - rewritePattern(cse.pattern, subMap), + replacePatternBinders(cse.pattern, subMap), cse.optGuard map { replace(subVarMap, _)}, replace(subVarMap,cse.rhs) ) @@ -463,6 +467,63 @@ object ExprOps { fixpoint(postMap(rec))(expr) } + private val typedIds: scala.collection.mutable.Map[TypeTree, List[Identifier]] = + scala.collection.mutable.Map.empty.withDefaultValue(List.empty) + + /** Normalizes identifiers in an expression to enable some notion of structural + * equality between expressions on which usual equality doesn't make sense + * (i.e. closures). + * + * This function relies on the static map `typedIds` to ensure identical + * structures and must therefore be synchronized. + */ + def normalizeStructure(expr: Expr): (Expr, Map[Identifier, Identifier]) = synchronized { + val allVars : Seq[Identifier] = foldRight[Seq[Identifier]] { + (expr, idSeqs) => idSeqs.foldLeft(expr match { + case Lambda(args, _) => args.map(_.id) + case Forall(args, _) => args.map(_.id) + case LetDef(fd, _) => fd.params.map(_.id) + case Let(i, _, _) => Seq(i) + case MatchExpr(_, cses) => cses.flatMap(_.pattern.binders) + case Passes(_, _, cses) => cses.flatMap(_.pattern.binders) + case Variable(id) => Seq(id) + case _ => Seq.empty[Identifier] + })((acc, seq) => acc ++ seq) + } (expr).distinct + + val grouped : Map[TypeTree, Seq[Identifier]] = allVars.groupBy(_.getType) + val subst = grouped.foldLeft(Map.empty[Identifier, Identifier]) { case (subst, (tpe, ids)) => + val currentVars = typedIds(tpe) + + val freshCount = ids.size - currentVars.size + val typedVars = if (freshCount > 0) { + val allIds = currentVars ++ List.range(0, freshCount).map(_ => FreshIdentifier("x", tpe, true)) + typedIds += tpe -> allIds + allIds + } else { + currentVars + } + + subst ++ (ids zip typedVars) + } + + val normalized = postMap { + case Lambda(args, body) => Some(Lambda(args.map(vd => ValDef(subst(vd.id), vd.tpe)), body)) + case Forall(args, body) => Some(Forall(args.map(vd => ValDef(subst(vd.id), vd.tpe)), body)) + case Let(i, e, b) => Some(Let(subst(i), e, b)) + case MatchExpr(scrut, cses) => Some(MatchExpr(scrut, cses.map { cse => + cse.copy(pattern = replacePatternBinders(cse.pattern, subst)) + })) + case Passes(in, out, cses) => Some(Passes(in, out, cses.map { cse => + cse.copy(pattern = replacePatternBinders(cse.pattern, subst)) + })) + case Variable(id) => Some(Variable(subst(id))) + case _ => None + } (expr) + + (normalized, subst) + } + /** Returns '''true''' if the formula is Ground, * which means that it does not contain any variable ([[purescala.ExprOps#variablesOf]] e is empty) * and [[purescala.ExprOps#isDeterministic isDeterministic]] @@ -1244,7 +1305,7 @@ object ExprOps { } /** Returns the value for an identifier given a model. */ - def valuateWithModel(model: Map[Identifier, Expr])(id: Identifier): Expr = { + def valuateWithModel(model: Model)(id: Identifier): Expr = { model.getOrElse(id, simplestValue(id.getType)) } @@ -1252,7 +1313,7 @@ object ExprOps { * * Complete with simplest values in case of incomplete model. */ - def valuateWithModelIn(expr: Expr, vars: Set[Identifier], model: Map[Identifier, Expr]): Expr = { + def valuateWithModelIn(expr: Expr, vars: Set[Identifier], model: Model): Expr = { val valuator = valuateWithModel(model) _ replace(vars.map(id => Variable(id) -> valuator(id)).toMap, expr) } diff --git a/src/main/scala/leon/purescala/Expressions.scala b/src/main/scala/leon/purescala/Expressions.scala index f738c3f380e4f919b2659c5fe85a8798ae349eba..c807ef81601c672ffb335013eae264724116b9d6 100644 --- a/src/main/scala/leon/purescala/Expressions.scala +++ b/src/main/scala/leon/purescala/Expressions.scala @@ -226,6 +226,10 @@ object Expressions { } } + case class PartialLambda(mapping: Seq[(Seq[Expr], Expr)], tpe: FunctionType) extends Expr { + val getType = tpe + } + /* Universal Quantification */ case class Forall(args: Seq[ValDef], body: Expr) extends Expr { diff --git a/src/main/scala/leon/purescala/Extractors.scala b/src/main/scala/leon/purescala/Extractors.scala index c1e1ff4399f3f6ea45d16a81b7cd7c72b02b849e..d2ccfa7da010c732fdf3784b614a218e6198fada 100644 --- a/src/main/scala/leon/purescala/Extractors.scala +++ b/src/main/scala/leon/purescala/Extractors.scala @@ -41,6 +41,20 @@ object Extractors { Some((Seq(a), (es: Seq[Expr]) => ArrayLength(es.head))) case Lambda(args, body) => Some((Seq(body), (es: Seq[Expr]) => Lambda(args, es.head))) + case PartialLambda(mapping, tpe) => + val sze = tpe.from.size + 1 + val subArgs = mapping.flatMap { case (args, v) => args :+ v } + val builder = (as: Seq[Expr]) => { + def rec(kvs: Seq[Expr]): Seq[(Seq[Expr], Expr)] = kvs match { + case seq if seq.size >= sze => + val ((args :+ res), rest) = seq.splitAt(sze) + (args -> res) +: rec(rest) + case Seq() => Seq.empty + case _ => sys.error("unexpected number of key/value expressions") + } + PartialLambda(rec(as), tpe) + } + Some((subArgs, builder)) case Forall(args, body) => Some((Seq(body), (es: Seq[Expr]) => Forall(args, es.head))) @@ -261,39 +275,6 @@ object Extractors { def unapply[T <: Typed](e: T): Option[(T, TypeTree)] = Some((e, e.getType)) } - /* - * Extract a default expression and key-value pairs from a lambda constructed with - * Constructors.finiteLambda - */ - object FiniteLambda { - def unapply(lambda: Lambda): Option[(Expr, Seq[(Expr, Expr)])] = { - val inSize = lambda.getType.asInstanceOf[FunctionType].from.size - val Lambda(args, body) = lambda - def step(e: Expr): (Option[(Expr, Expr)], Expr) = e match { - case IfExpr(Equals(argsExpr, key), value, default) if { - val formal = args.map{ _.id } - val real = unwrapTuple(argsExpr, inSize).collect{ case Variable(id) => id} - formal == real - } => - (Some((key, value)), default) - case other => - (None, other) - } - - def rec(e: Expr): (Expr, Seq[(Expr, Expr)]) = { - step(e) match { - case (None, default) => (default, Seq()) - case (Some(pair), default) => - val (defaultRest, pairs) = rec(default) - (defaultRest, pair +: pairs) - } - } - - Some(rec(body)) - - } - } - object FiniteArray { def unapply(e: Expr): Option[(Map[Int, Expr], Option[Expr], Expr)] = e match { case EmptyArray(_) => diff --git a/src/main/scala/leon/purescala/Quantification.scala b/src/main/scala/leon/purescala/Quantification.scala new file mode 100644 index 0000000000000000000000000000000000000000..34392526d76ed4162510eddb2b685d06b3cd6c0f --- /dev/null +++ b/src/main/scala/leon/purescala/Quantification.scala @@ -0,0 +1,192 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package purescala + +import Common._ +import Definitions._ +import Expressions._ +import Extractors._ +import ExprOps._ +import Types._ + +object Quantification { + + def extractQuorums[A,B]( + matchers: Set[A], + quantified: Set[B], + margs: A => Set[A], + qargs: A => Set[B] + ): Seq[Set[A]] = { + def rec(oms: Seq[A], mSet: Set[A], qss: Seq[Set[B]]): Seq[Set[A]] = { + if (qss.exists(_ == quantified)) { + Seq(mSet) + } else { + var res = Seq.empty[Set[A]] + val rest = oms.scanLeft(List.empty[A])((acc, o) => o :: acc).drop(1) + for ((m :: ms) <- rest if margs(m).forall(mSet)) { + val qas = qargs(m) + if (!qss.exists(qs => qs.subsetOf(qas) || qas.subsetOf(qs))) { + res ++= rec(ms, mSet + m, qss ++ qss.map(_ ++ qas) :+ qas) + } + } + res + } + } + + def expand(m: A): Set[A] = Set(m) ++ margs(m).flatMap(expand) + val oms = matchers.toSeq.sortBy(m => -expand(m).size) + rec(oms, Set.empty, Seq.empty) + } + + def extractQuorums(expr: Expr, quantified: Set[Identifier]): Seq[Set[(Expr, Seq[Expr])]] = { + object QMatcher { + def unapply(e: Expr): Option[(Expr, Seq[Expr])] = e match { + case QuantificationMatcher(expr, args) => + if (args.exists { case Variable(id) => quantified(id) case _ => false }) { + Some(expr -> args) + } else { + None + } + case _ => None + } + } + + extractQuorums(collect { + case QMatcher(e, a) => Set(e -> a) + case _ => Set.empty[(Expr, Seq[Expr])] + } (expr), quantified, + (p: (Expr, Seq[Expr])) => p._2.collect { case QMatcher(e, a) => e -> a }.toSet, + (p: (Expr, Seq[Expr])) => p._2.collect { case Variable(id) if quantified(id) => id }.toSet) + } + + object HenkinDomains { + def empty = new HenkinDomains(Map.empty) + def apply(domains: Map[TypeTree, Set[Seq[Expr]]]) = new HenkinDomains(domains) + } + + class HenkinDomains (val domains: Map[TypeTree, Set[Seq[Expr]]]) { + def get(e: Expr): Set[Seq[Expr]] = e match { + case PartialLambda(mapping, _) => mapping.map(_._1).toSet + case _ => domains.get(e.getType) match { + case Some(domain) => domain + case None => scala.sys.error("Undefined Henkin domain for " + e) + } + } + } + + object QuantificationMatcher { + def unapply(expr: Expr): Option[(Expr, Seq[Expr])] = expr match { + case Application(_: Application | _: FunctionInvocation, _) => None + case Application(e, args) => Some(e -> args) + case ArraySelect(arr, index) => Some(arr -> Seq(index)) + case MapApply(map, key) => Some(map -> Seq(key)) + // case ElementOfSet(set, elem) => Some(set -> Seq(elem)) + case _ => None + } + } + + object QuantificationTypeMatcher { + def unapply(tpe: TypeTree): Option[(Seq[TypeTree], TypeTree)] = tpe match { + case FunctionType(from, to) => Some(from -> to) + case ArrayType(base) => Some(Seq(Int32Type) -> base) + case MapType(from, to) => Some(Seq(from) -> to) + case SetType(base) => Some(Seq(base) -> BooleanType) + case _ => None + } + } + + object CheckForalls extends UnitPhase[Program] { + + val name = "Foralls" + val description = "Check syntax of foralls to guarantee sound instantiations" + + def apply(ctx: LeonContext, program: Program) = { + program.definedFunctions.foreach { fd => + if (fd.body.exists(b => exists { + case f: Forall => true + case _ => false + } (b))) ctx.reporter.warning("Universal quantification in function bodies is not supported in " + fd) + + val foralls = (fd.precondition.toSeq ++ fd.postcondition.toSeq).flatMap { prec => + collect[Forall] { + case f: Forall => Set(f) + case _ => Set.empty + } (prec) + } + + val free = fd.params.map(_.id).toSet ++ (fd.postcondition match { + case Some(Lambda(args, _)) => args.map(_.id) + case _ => Seq.empty + }) + + object Matcher { + def unapply(e: Expr): Option[(Identifier, Seq[Expr])] = e match { + case QuantificationMatcher(Variable(id), args) if free(id) => Some(id -> args) + case _ => None + } + } + + for (Forall(args, TopLevelAnds(conjuncts)) <- foralls) { + val quantified = args.map(_.id).toSet + + for (conjunct <- conjuncts) { + val matchers = collect[(Identifier, Seq[Expr])] { + case Matcher(id, args) => Set(id -> args) + case _ => Set.empty + } (conjunct) + + if (matchers.exists { case (id, args) => + args.exists(arg => arg match { + case Matcher(_, _) => false + case Variable(id) => false + case _ if (variablesOf(arg) & quantified).nonEmpty => true + case _ => false + }) + }) ctx.reporter.warning("Matcher arguments must have simple form in " + conjunct) + + val id2Quant = matchers.foldLeft(Map.empty[Identifier, Set[Identifier]]) { + case (acc, (m, args)) => acc + (m -> (acc.getOrElse(m, Set.empty) ++ args.flatMap { + case Variable(id) if quantified(id) => Set(id) + case _ => Set.empty[Identifier] + })) + } + + if (id2Quant.filter(_._2.nonEmpty).groupBy(_._2).size != 1) + ctx.reporter.warning("Multiple matchers must provide bijective matching in " + conjunct) + + foldRight[Set[Identifier]] { case (m, children) => + val q = children.toSet.flatten + + m match { + case Matcher(_, args) => + q -- args.flatMap { + case Variable(id) if quantified(id) => Set(id) + case _ => Set.empty[Identifier] + } + case LessThan(_: Variable, _: Variable) => q + case LessEquals(_: Variable, _: Variable) => q + case GreaterThan(_: Variable, _: Variable) => q + case GreaterEquals(_: Variable, _: Variable) => q + case And(_) => q + case Or(_) => q + case Implies(_, _) => q + case Operator(es, _) => + val vars = es.flatMap { + case Variable(id) => Set(id) + case _ => Set.empty[Identifier] + }.toSet + + if (!(q.isEmpty || (q.size == 1 && (vars & free).isEmpty))) + ctx.reporter.warning("Invalid operation " + m + " on quantified variables") + q -- vars + case Variable(id) if quantified(id) => Set(id) + case _ => q + } + } (conjunct) + } + } + } + } + } +} diff --git a/src/main/scala/leon/repair/RepairTrackingEvaluator.scala b/src/main/scala/leon/repair/RepairTrackingEvaluator.scala index 87905111f0ee42180a3ef2ab50f7975bcac2e5d3..664b9e3b26f0229cfb0ec12cf52bef7a53d6aa6a 100644 --- a/src/main/scala/leon/repair/RepairTrackingEvaluator.scala +++ b/src/main/scala/leon/repair/RepairTrackingEvaluator.scala @@ -8,6 +8,7 @@ import leon.purescala.Common._ import leon.purescala.Expressions._ import leon.purescala.Types._ import leon.purescala.Definitions._ +import leon.purescala.Quantification._ import leon.LeonContext import leon.evaluators.RecursiveEvaluator @@ -21,7 +22,7 @@ class RepairTrackingEvaluator(ctx: LeonContext, prog: Program) extends Recursive type GC = GlobalContext def initRC(mappings: Map[Identifier, Expr]) = CollectingRecContext(mappings, None) - def initGC() = new GlobalContext() + def initGC(model: leon.solvers.Model) = new GlobalContext(model) type FI = (FunDef, Seq[Expr]) diff --git a/src/main/scala/leon/repair/Repairman.scala b/src/main/scala/leon/repair/Repairman.scala index e4cfb8aafba33e82ef24806bd7b68de334c237b2..d38af251c0faadd29fcd5b7bd48f7bea21fa2691 100644 --- a/src/main/scala/leon/repair/Repairman.scala +++ b/src/main/scala/leon/repair/Repairman.scala @@ -9,6 +9,7 @@ import purescala.Expressions._ import purescala.ExprOps._ import purescala.Types._ import purescala.DefOps._ +import purescala.Quantification._ import purescala.Constructors._ import purescala.Extractors.unwrapTuple @@ -202,10 +203,11 @@ class Repairman(ctx0: LeonContext, initProgram: Program, fd: FunDef, verifTimeou case None => _ => true case Some(pre) => - evaluator.compile(pre, fd.params map { _.id }) match { + val argIds = fd.params.map(_.id) + evaluator.compile(pre, argIds) match { case Some(evalFun) => val sat = EvaluationResults.Successful(BooleanLiteral(true)); - { (e: Seq[Expr]) => evalFun(e) == sat } + { (es: Seq[Expr]) => evalFun(new solvers.Model((argIds zip es).toMap)) == sat } case None => { _ => false } } diff --git a/src/main/scala/leon/solvers/EnumerationSolver.scala b/src/main/scala/leon/solvers/EnumerationSolver.scala index 3a2db100e7fab9568c5f43618d4104925ec9f735..d7d52d371ab45f7a0cbddab50bab14bc552de62e 100644 --- a/src/main/scala/leon/solvers/EnumerationSolver.scala +++ b/src/main/scala/leon/solvers/EnumerationSolver.scala @@ -46,7 +46,7 @@ class EnumerationSolver(val context: LeonContext, val program: Program) extends datagen = None } - private var modelMap = Map[Identifier, Expr]() + private var model = Model.empty def check: Option[Boolean] = { val timer = context.timers.solvers.enum.check.start() @@ -55,15 +55,15 @@ class EnumerationSolver(val context: LeonContext, val program: Program) extends if (interrupted) { None } else { - modelMap = Map() - val allFreeVars = freeVars.toSet.toSeq.sortBy(_.name) + model = Model.empty + val allFreeVars = freeVars.toSeq.sortBy(_.name) val allConstraints = constraints.toSeq val it = datagen.get.generateFor(allFreeVars, andJoin(allConstraints), 1, maxTried) if (it.hasNext) { - val model = it.next - modelMap = (allFreeVars zip model).toMap + val varModels = it.next + model = new Model((allFreeVars zip varModels).toMap) Some(true) } else { None @@ -78,8 +78,8 @@ class EnumerationSolver(val context: LeonContext, val program: Program) extends res } - def getModel: Map[Identifier, Expr] = { - modelMap + def getModel: Model = { + model } def free() = { diff --git a/src/main/scala/leon/solvers/EvaluatingSolver.scala b/src/main/scala/leon/solvers/EvaluatingSolver.scala new file mode 100644 index 0000000000000000000000000000000000000000..3463235c918d4872b0136d84b54f44e11b36cd69 --- /dev/null +++ b/src/main/scala/leon/solvers/EvaluatingSolver.scala @@ -0,0 +1,19 @@ +package leon +package solvers + +import purescala.Definitions._ +import evaluators._ + +trait EvaluatingSolver extends Solver { + val context: LeonContext + val program: Program + + val useCodeGen: Boolean + + lazy val evaluator: Evaluator = + if (useCodeGen) { + new CodeGenEvaluator(context, program) + } else { + new DefaultEvaluator(context, program) + } +} diff --git a/src/main/scala/leon/solvers/GroundSolver.scala b/src/main/scala/leon/solvers/GroundSolver.scala index 29ee752383afad17134d7d6aead3eb12c274da72..f38ddd188f15f40058606e9d018f4b2b5b6f186b 100644 --- a/src/main/scala/leon/solvers/GroundSolver.scala +++ b/src/main/scala/leon/solvers/GroundSolver.scala @@ -26,7 +26,7 @@ class GroundSolver(val context: LeonContext, val program: Program) extends Solve private val assertions = new IncrementalSeq[Expr]() // Ground terms will always have the empty model - def getModel: Map[Identifier, Expr] = Map() + def getModel: Model = Model.empty def assertCnstr(expression: Expr): Unit = { assertions += expression diff --git a/src/main/scala/leon/solvers/QuantificationSolver.scala b/src/main/scala/leon/solvers/QuantificationSolver.scala new file mode 100644 index 0000000000000000000000000000000000000000..dc3e8584fd74578ac2173e1819c059da9f0ec99b --- /dev/null +++ b/src/main/scala/leon/solvers/QuantificationSolver.scala @@ -0,0 +1,30 @@ +package leon +package solvers + +import purescala.Common._ +import purescala.Expressions._ +import purescala.Quantification._ +import purescala.Types._ + +class HenkinModel(mapping: Map[Identifier, Expr], doms: HenkinDomains) + extends Model(mapping) + with AbstractModel[HenkinModel] { + override def newBuilder = new HenkinModelBuilder(doms) + + def domains: Map[TypeTree, Set[Seq[Expr]]] = doms.domains + def domain(expr: Expr) = doms.get(expr) +} + +object HenkinModel { + def empty = new HenkinModel(Map.empty, HenkinDomains.empty) +} + +class HenkinModelBuilder(domains: HenkinDomains) + extends ModelBuilder + with AbstractModelBuilder[HenkinModel] { + override def result = new HenkinModel(mapBuilder.result, domains) +} + +trait QuantificationSolver { + def getModel: HenkinModel +} diff --git a/src/main/scala/leon/solvers/SimpleSolverAPI.scala b/src/main/scala/leon/solvers/SimpleSolverAPI.scala index 37aefc1b5507f087981570b11fe153237d9832de..33f6f133644f55d24efddea8e4ce010516b0cba9 100644 --- a/src/main/scala/leon/solvers/SimpleSolverAPI.scala +++ b/src/main/scala/leon/solvers/SimpleSolverAPI.scala @@ -17,7 +17,7 @@ class SimpleSolverAPI(sf: SolverFactory[Solver]) { } } - def solveSAT(expression: Expr): (Option[Boolean], Map[Identifier, Expr]) = { + def solveSAT(expression: Expr): (Option[Boolean], Model) = { val s = sf.getNewSolver() try { s.assertCnstr(expression) @@ -25,16 +25,16 @@ class SimpleSolverAPI(sf: SolverFactory[Solver]) { case Some(true) => (Some(true), s.getModel) case Some(false) => - (Some(false), Map()) + (Some(false), Model.empty) case None => - (None, Map()) + (None, Model.empty) } } finally { sf.reclaim(s) } } - def solveSATWithCores(expression: Expr, assumptions: Set[Expr]): (Option[Boolean], Map[Identifier, Expr], Set[Expr]) = { + def solveSATWithCores(expression: Expr, assumptions: Set[Expr]): (Option[Boolean], Model, Set[Expr]) = { val s = sf.getNewSolver() try { s.assertCnstr(expression) @@ -42,9 +42,9 @@ class SimpleSolverAPI(sf: SolverFactory[Solver]) { case Some(true) => (Some(true), s.getModel, Set()) case Some(false) => - (Some(false), Map(), s.getUnsatCore) + (Some(false), Model.empty, s.getUnsatCore) case None => - (None, Map(), Set()) + (None, Model.empty, Set()) } } finally { sf.reclaim(s) diff --git a/src/main/scala/leon/solvers/Solver.scala b/src/main/scala/leon/solvers/Solver.scala index a9018b2defa113783cd88dd166d5a01c1e6256a4..3188031e9ca0ee918e419928c4416d6349c39b8d 100644 --- a/src/main/scala/leon/solvers/Solver.scala +++ b/src/main/scala/leon/solvers/Solver.scala @@ -5,9 +5,72 @@ package solvers import utils.{DebugSectionSolver, Interruptible} import purescala.Expressions._ -import leon.purescala.Common.{Tree, Identifier} +import purescala.Common.{Tree, Identifier} +import purescala.ExprOps._ import verification.VC +trait AbstractModel[+This <: Model with AbstractModel[This]] + extends scala.collection.IterableLike[(Identifier, Expr), This] { + + protected val mapping: Map[Identifier, Expr] + + def fill(allVars: Iterable[Identifier]): This = { + val builder = newBuilder + builder ++= mapping ++ (allVars.toSet -- mapping.keys).map(id => id -> simplestValue(id.getType)) + builder.result + } + + def ++(mapping: Map[Identifier, Expr]): This = { + val builder = newBuilder + builder ++= this.mapping ++ mapping + builder.result + } + + def filter(allVars: Iterable[Identifier]): This = { + val builder = newBuilder + for (p <- mapping.filterKeys(allVars.toSet)) { + builder += p + } + builder.result + } + + def iterator = mapping.iterator + def seq = mapping.seq +} + +trait AbstractModelBuilder[+This <: Model with AbstractModel[This]] + extends scala.collection.mutable.Builder[(Identifier, Expr), This] { + + import scala.collection.mutable.MapBuilder + protected val mapBuilder = new MapBuilder[Identifier, Expr, Map[Identifier, Expr]](Map.empty) + + def +=(elem: (Identifier, Expr)): this.type = { + mapBuilder += elem + this + } + + def clear(): Unit = mapBuilder.clear +} + +class Model(protected val mapping: Map[Identifier, Expr]) + extends AbstractModel[Model] + with (Identifier => Expr) { + + def newBuilder = new ModelBuilder + def isDefinedAt(id: Identifier): Boolean = mapping.isDefinedAt(id) + def get(id: Identifier): Option[Expr] = mapping.get(id) + def getOrElse[E >: Expr](id: Identifier, e: E): E = get(id).getOrElse(e) + def apply(id: Identifier): Expr = get(id).getOrElse { throw new IllegalArgumentException } +} + +object Model { + def empty = new Model(Map.empty) +} + +class ModelBuilder extends AbstractModelBuilder[Model] { + def result = new Model(mapBuilder.result) +} + trait Solver extends Interruptible { def name: String val context: LeonContext @@ -20,7 +83,7 @@ trait Solver extends Interruptible { } def check: Option[Boolean] - def getModel: Map[Identifier, Expr] + def getModel: Model def getResultSolver: Option[Solver] = Some(this) def free() diff --git a/src/main/scala/leon/solvers/combinators/PortfolioSolver.scala b/src/main/scala/leon/solvers/combinators/PortfolioSolver.scala index 9997d5176d0389ad0d0a790e228fbd77cace93b9..429bc24b8d9389cd97782237d4b69cabd562ef64 100644 --- a/src/main/scala/leon/solvers/combinators/PortfolioSolver.scala +++ b/src/main/scala/leon/solvers/combinators/PortfolioSolver.scala @@ -21,7 +21,7 @@ class PortfolioSolver[S <: Solver with Interruptible](val context: LeonContext, var constraints = List[Expr]() - protected var modelMap = Map[Identifier, Expr]() + protected var model = Model.empty protected var resultSolver: Option[Solver] = None override def getResultSolver = resultSolver @@ -35,17 +35,17 @@ class PortfolioSolver[S <: Solver with Interruptible](val context: LeonContext, } def check: Option[Boolean] = { - modelMap = Map() + model = Model.empty context.reporter.debug("Running portfolio check") // solving val fs = solvers.map { s => Future { val result = s.check - val model: Map[Identifier, Expr] = if (result == Some(true)) { + val model: Model = if (result == Some(true)) { s.getModel } else { - Map() + Model.empty } (s, result, model) } @@ -55,7 +55,7 @@ class PortfolioSolver[S <: Solver with Interruptible](val context: LeonContext, val res = Await.result(result, Duration.Inf) match { case Some((s, r, m)) => - modelMap = m + model = m resultSolver = s.getResultSolver resultSolver.foreach { solv => context.reporter.debug("Solved with "+solv) @@ -80,12 +80,12 @@ class PortfolioSolver[S <: Solver with Interruptible](val context: LeonContext, def free() = { solvers.foreach(_.free) - modelMap = Map() + model = Model.empty constraints = Nil } - def getModel: Map[Identifier, Expr] = { - modelMap + def getModel: Model = { + model } def interrupt(): Unit = { @@ -98,7 +98,7 @@ class PortfolioSolver[S <: Solver with Interruptible](val context: LeonContext, def reset() = { solvers.foreach(_.reset) - modelMap = Map() + model = Model.empty constraints = Nil } } diff --git a/src/main/scala/leon/solvers/combinators/RewritingSolver.scala b/src/main/scala/leon/solvers/combinators/RewritingSolver.scala index 8aeb812fcffd38b5e2d65527aaa646706f001079..2c414f2c0bf9a6f14e6f571b98eed2bfc3ce49aa 100644 --- a/src/main/scala/leon/solvers/combinators/RewritingSolver.scala +++ b/src/main/scala/leon/solvers/combinators/RewritingSolver.scala @@ -12,19 +12,19 @@ abstract class RewritingSolver[+S <: Solver, T](underlying: S) { /** The type T is used to encode any meta information useful, for instance, to reconstruct * models. */ - def rewriteCnstr(expression : Expr) : (Expr,T) + def rewriteCnstr(expression: Expr) : (Expr,T) - def reconstructModel(model : Map[Identifier,Expr], meta : T) : Map[Identifier,Expr] + def reconstructModel(model: Model, meta: T) : Model private var storedMeta : List[T] = Nil - def assertCnstr(expression : Expr) { + def assertCnstr(expression: Expr) { val (rewritten, meta) = rewriteCnstr(expression) storedMeta = meta :: storedMeta underlying.assertCnstr(rewritten) } - def getModel : Map[Identifier,Expr] = { + def getModel: Model = { storedMeta match { case Nil => underlying.getModel case m :: _ => reconstructModel(underlying.getModel, m) diff --git a/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala b/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala index aa936b2838f6b1a2818ce091ce9f520987259920..b5cd0a04a3d6d050f7474e01e1a53b16c66119b5 100644 --- a/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala +++ b/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala @@ -6,9 +6,11 @@ package combinators import purescala.Common._ import purescala.Definitions._ +import purescala.Quantification._ import purescala.Constructors._ import purescala.Expressions._ import purescala.ExprOps._ +import purescala.Types._ import utils._ import z3.FairZ3Component.{optFeelingLucky, optUseCodeGen, optAssumePre} @@ -16,20 +18,17 @@ import templates._ import utils.Interruptible import evaluators._ -class UnrollingSolver(val context: LeonContext, program: Program, underlying: Solver) extends Solver with NaiveAssumptionSolver { +class UnrollingSolver(val context: LeonContext, val program: Program, underlying: Solver) + extends Solver + with NaiveAssumptionSolver + with EvaluatingSolver + with QuantificationSolver { + val feelingLucky = context.findOptionOrDefault(optFeelingLucky) val useCodeGen = context.findOptionOrDefault(optUseCodeGen) val assumePreHolds = context.findOptionOrDefault(optAssumePre) - private val evaluator : Evaluator = { - if(useCodeGen) { - new CodeGenEvaluator(context, program) - } else { - new DefaultEvaluator(context, program) - } - } - - protected var lastCheckResult : (Boolean, Option[Boolean], Option[Map[Identifier,Expr]]) = (false, None, None) + protected var lastCheckResult : (Boolean, Option[Boolean], Option[HenkinModel]) = (false, None, None) private val freeVars = new IncrementalSet[Identifier]() private val constraints = new IncrementalSeq[Expr]() @@ -106,17 +105,61 @@ class UnrollingSolver(val context: LeonContext, program: Program, underlying: So def hasFoundAnswer = lastCheckResult._1 - def foundAnswer(res: Option[Boolean], model: Option[Map[Identifier, Expr]] = None) = { + private def extractModel(model: Model): HenkinModel = { + val allVars = freeVars.toSet + + def extract(b: Expr, m: Matcher[Expr]): Set[Seq[Expr]] = { + val QuantificationTypeMatcher(fromTypes, _) = m.tpe + val optEnabler = evaluator.eval(b).result + val optArgs = m.args.map(arg => evaluator.eval(Matcher.argValue(arg)).result) + if (optEnabler == Some(BooleanLiteral(true)) && optArgs.forall(_.isDefined)) { + Set(optArgs.map(_.get)) + } else { + Set.empty + } + } + + val funDomains = allVars.flatMap(id => id.getType match { + case ft @ FunctionType(fromTypes, _) => + Some(id -> templateGenerator.manager.instantiations(Variable(id), ft).flatMap { + case (b, m) => extract(b, m) + }) + case _ => None + }).toMap.mapValues(_.toSet) + + val asDMap = model.map(p => funDomains.get(p._1) match { + case Some(domain) => + val mapping = domain.toSeq.map { es => + val ev: Expr = p._2 match { + case RawArrayValue(_, mapping, dflt) => + mapping.collectFirst { + case (k,v) if evaluator.eval(Equals(k, tupleWrap(es))).result == Some(BooleanLiteral(true)) => v + } getOrElse dflt + case _ => scala.sys.error("Unexpected function encoding " + p._2) + } + es -> ev + } + + p._1 -> PartialLambda(mapping, p._1.getType.asInstanceOf[FunctionType]) + case None => p + }).toMap + + val typeGrouped = templateGenerator.manager.instantiations.groupBy(_._2.tpe) + val typeDomains = typeGrouped.mapValues(_.flatMap { case (b, m) => extract(b, m) }.toSet) + + val domains = new HenkinDomains(typeDomains) + new HenkinModel(asDMap, domains) + } + + def foundAnswer(res: Option[Boolean], model: Option[HenkinModel] = None) = { lastCheckResult = (true, res, model) } - def isValidModel(model: Map[Identifier, Expr], silenceErrors: Boolean = false): Boolean = { + def isValidModel(model: HenkinModel, silenceErrors: Boolean = false): Boolean = { import EvaluationResults._ val expr = andJoin(constraints.toSeq) - val allVars = freeVars.toSet - - val fullModel = allVars.map(v => v -> model.getOrElse(v, simplestValue(v.getType))).toMap + val fullModel = model fill freeVars.toSet evaluator.eval(expr, fullModel) match { case Successful(BooleanLiteral(true)) => @@ -152,7 +195,7 @@ class UnrollingSolver(val context: LeonContext, program: Program, underlying: So reporter.debug(" - Running search...") solver.push() - solver.assertCnstr(andJoin((assumptions ++ unrollingBank.currentBlockers ++ unrollingBank.quantificationAssumptions).toSeq)) + solver.assertCnstr(andJoin((assumptions ++ unrollingBank.satisfactionAssumptions).toSeq)) val res = solver.check reporter.debug(" - Finished search with blocked literals") @@ -167,7 +210,7 @@ class UnrollingSolver(val context: LeonContext, program: Program, underlying: So foundAnswer(None) case Some(true) => // SAT - val model = solver.getModel + val model = extractModel(solver.getModel) solver.pop() foundAnswer(Some(true), Some(model)) @@ -188,7 +231,7 @@ class UnrollingSolver(val context: LeonContext, program: Program, underlying: So } solver.push() - solver.assertCnstr(andJoin(assumptions.toSeq ++ unrollingBank.quantificationAssumptions)) + solver.assertCnstr(andJoin(assumptions.toSeq ++ unrollingBank.refutationAssumptions)) val res2 = solver.check res2 match { @@ -198,7 +241,7 @@ class UnrollingSolver(val context: LeonContext, program: Program, underlying: So case Some(true) => if (feelingLucky && !interrupted) { - val model = solver.getModel + val model = extractModel(solver.getModel) // we might have been lucky :D if (isValidModel(model, silenceErrors = true)) { @@ -241,13 +284,12 @@ class UnrollingSolver(val context: LeonContext, program: Program, underlying: So } } - def getModel: Map[Identifier,Expr] = { - val allVars = freeVars.toSet + def getModel: HenkinModel = { lastCheckResult match { case (true, Some(true), Some(m)) => - m.filterKeys(allVars) + m.filter(freeVars.toSet) case _ => - Map() + HenkinModel.empty } } diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4ProofSolver.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4ProofSolver.scala index 86da0be80db1c8eec85f6dedc6972871ced49ad3..1f0b760d70c59796cd7cdb5a9272526077bda087 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4ProofSolver.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4ProofSolver.scala @@ -37,7 +37,7 @@ class SMTLIBCVC4ProofSolver(context: LeonContext, program: Program) extends SMTL } // This solver does not support model extraction - override def getModel: Map[Identifier, Expr] = { + override def getModel: solvers.Model = { // We don't send the error through reporter because it may be caught by PortfolioSolver throw LeonFatalError(Some(s"Solver $name does not support model extraction.")) } diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Solver.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Solver.scala index fa11aa42159144eb8423a8dab8f5c98fe07d6aa7..7df71249185b5727c9144526d945ddb789920882 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Solver.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Solver.scala @@ -76,16 +76,16 @@ class SMTLIBCVC4Solver(context: LeonContext, program: Program) extends SMTLIBSol case (FunctionApplication(SimpleSymbol(SSymbol("__array_store_all__")), Seq(_, elem)), RawArrayType(k,v)) => RawArrayValue(k, Map(), fromSMT(elem, v)) - case (FunctionApplication(SimpleSymbol(SSymbol("__array_store_all__")), Seq(_, elem)), ft @ FunctionType(from,to)) => - finiteLambda(fromSMT(elem, to), Seq.empty, from) + case (FunctionApplication(SimpleSymbol(SSymbol("__array_store_all__")), Seq(_, elem)), FunctionType(from,to)) => + RawArrayValue(tupleTypeWrap(from), Map(), fromSMT(elem, to)) case (FunctionApplication(SimpleSymbol(SSymbol("store")), Seq(arr, key, elem)), RawArrayType(k,v)) => val RawArrayValue(_, elems, base) = fromSMT(arr, tpe) RawArrayValue(k, elems + (fromSMT(key, k) -> fromSMT(elem, v)), base) - case (FunctionApplication(SimpleSymbol(SSymbol("store")), Seq(arr, key, elem)), ft @ FunctionType(from,to)) => - val FiniteLambda(dflt, mapping) = fromSMT(arr, tpe) - finiteLambda(dflt, mapping :+ (fromSMT(key, tupleTypeWrap(from)) -> fromSMT(elem, to)), from) + case (FunctionApplication(SimpleSymbol(SSymbol("store")), Seq(arr, key, elem)), FunctionType(from,to)) => + val RawArrayValue(k, elems, base) = fromSMT(arr, tpe) + RawArrayValue(k, elems + (fromSMT(key, k) -> fromSMT(elem, to)), base) case (FunctionApplication(SimpleSymbol(SSymbol("singleton")), elems), SetType(base)) => FiniteSet(elems.map(fromSMT(_, base)).toSet, base) @@ -103,8 +103,8 @@ class SMTLIBCVC4Solver(context: LeonContext, program: Program) extends SMTLIBSol // FIXME (nicolas) // some versions of CVC4 seem to generate array constants with "as const" notation instead of the __array_store_all__ // one I've witnessed up to now. Don't know why this is happening... - case (FunctionApplication(QualifiedIdentifier(SMTIdentifier(SSymbol("const"), _), _), Seq(elem)), ft @ FunctionType(from, to)) => - finiteLambda(fromSMT(elem, to), Seq.empty, from) + case (FunctionApplication(QualifiedIdentifier(SMTIdentifier(SSymbol("const"), _), _), Seq(elem)), FunctionType(from, to)) => + RawArrayValue(tupleTypeWrap(from), Map(), fromSMT(elem, to)) case (FunctionApplication(QualifiedIdentifier(SMTIdentifier(SSymbol("const"), _), _), Seq(elem)), RawArrayType(k, v)) => RawArrayValue(k, Map(), fromSMT(elem, v)) diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBQuantifiedSolver.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBQuantifiedSolver.scala index 0d761fe214656eb145047769a2adc8aef8e5232a..f27838071238dcfb31f14f32c2e27c9cead486b5 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBQuantifiedSolver.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBQuantifiedSolver.scala @@ -54,7 +54,7 @@ trait SMTLIBQuantifiedSolver extends SMTLIBSolver { // Normally, UnrollingSolver tracks the input variable, but this one // is invoked alone so we have to filter them here - override def getModel: Map[Identifier, Expr] = { + override def getModel: leon.solvers.Model = { val filter = currentFunDef.map{ _.params.map{_.id}.toSet }.getOrElse( (_:Identifier) => true ) getModel(filter) } diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala index f6e00625a2a5a34d2a517c5d4450c17eb7ccd2fd..aa18becd2b2416fa7a72d74573f28652217b13ee 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala @@ -28,8 +28,9 @@ import _root_.smtlib.parser.CommandsResponses.{Error => ErrorResponse, _} import _root_.smtlib.theories._ import _root_.smtlib.{Interpreter => SMTInterpreter} -abstract class SMTLIBSolver(val context: LeonContext, - val program: Program) extends Solver with NaiveAssumptionSolver { +abstract class SMTLIBSolver(val context: LeonContext, val program: Program) + extends Solver + with NaiveAssumptionSolver { /* Solver name */ def targetName: String @@ -46,7 +47,6 @@ abstract class SMTLIBSolver(val context: LeonContext, protected val interpreter = getNewInterpreter(context) - /* Printing VCs */ protected lazy val out: Option[java.io.FileWriter] = if (reporter.isDebugEnabled) Some { val file = context.files.headOption.map(_.getName).getOrElse("NA") @@ -171,8 +171,8 @@ abstract class SMTLIBSolver(val context: LeonContext, case RawArrayType(from, to) => r - case ft @ FunctionType(from, to) => - finiteLambda(r.default, r.elems.toSeq, from) + case FunctionType(from, to) => + r case MapType(from, to) => // We expect a RawArrayValue with keys in from and values in Option[to], @@ -742,10 +742,10 @@ abstract class SMTLIBSolver(val context: LeonContext, } } - protected def getModel(filter: Identifier => Boolean): Map[Identifier, Expr] = { + protected def getModel(filter: Identifier => Boolean): Model = { val syms = variables.aSet.filter(filter).toList.map(variables.aToB) if (syms.isEmpty) { - Map() + Model.empty } else { val cmd: Command = GetValue( syms.head, @@ -754,20 +754,20 @@ abstract class SMTLIBSolver(val context: LeonContext, sendCommand(cmd) match { case GetValueResponseSuccess(valuationPairs) => - - valuationPairs.collect { + new Model(valuationPairs.collect { case (SimpleSymbol(sym), value) if variables.containsB(sym) => val id = variables.toA(sym) (id, fromSMT(value, id.getType)(Map(), Map())) - }.toMap + }.toMap) + case _ => - Map() //FIXME improve this + Model.empty //FIXME improve this } } } - override def getModel: Map[Identifier, Expr] = getModel( _ => true) + override def getModel: Model = getModel( _ => true) override def push(): Unit = { constructors.push() diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Solver.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Solver.scala index 759dedbd8229c909a5755be064d369cf6332762f..ec91138e3725af9ff88e19a7b95a1dbedeb74cc9 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Solver.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Solver.scala @@ -159,7 +159,7 @@ class SMTLIBZ3Solver(context: LeonContext, program: Program) extends SMTLIBSolve } // EK: We use get-model instead in order to extract models for arrays - override def getModel: Map[Identifier, Expr] = { + override def getModel: Model = { val cmd = GetModel() @@ -199,8 +199,7 @@ class SMTLIBZ3Solver(context: LeonContext, program: Program) extends SMTLIBSolve case _ => } - - model + new Model(model) } object ArrayMap { diff --git a/src/main/scala/leon/solvers/templates/LambdaManager.scala b/src/main/scala/leon/solvers/templates/LambdaManager.scala index 3b355a4b7a960279f4dfd9d65ebfbdec4337147e..3d5eec72c809a7ba9459b4b46752835b63bd6011 100644 --- a/src/main/scala/leon/solvers/templates/LambdaManager.scala +++ b/src/main/scala/leon/solvers/templates/LambdaManager.scala @@ -9,51 +9,23 @@ import purescala.Expressions._ import purescala.ExprOps._ import purescala.Types._ +import utils._ import Instantiation._ -class LambdaManager[T](protected val encoder: TemplateEncoder[T]) { +class LambdaManager[T](protected val encoder: TemplateEncoder[T]) extends IncrementalState { - protected type IdMap = Map[T, LambdaTemplate[T]] - protected def byID : IdMap = byIDStack.head - private var byIDStack : List[IdMap] = List(Map.empty) - private def byID_=(map: IdMap) : Unit = { - byIDStack = map :: byIDStack.tail - } + protected val byID = new IncrementalMap[T, LambdaTemplate[T]] + protected val byType = new IncrementalMap[FunctionType, Set[(T, LambdaTemplate[T])]].withDefaultValue(Set.empty) + protected val applications = new IncrementalMap[FunctionType, Set[(T, App[T])]].withDefaultValue(Set.empty) + protected val freeLambdas = new IncrementalMap[FunctionType, Set[T]].withDefaultValue(Set.empty) - protected type TypeMap = Map[FunctionType, Set[(T, LambdaTemplate[T])]] - protected def byType : TypeMap = byTypeStack.head - private var byTypeStack : List[TypeMap] = List(Map.empty.withDefaultValue(Set.empty)) - private def byType_=(map: TypeMap) : Unit = { - byTypeStack = map :: byTypeStack.tail - } + protected def incrementals: List[IncrementalState] = + List(byID, byType, applications, freeLambdas) - protected type ApplicationMap = Map[FunctionType, Set[(T, App[T])]] - protected def applications : ApplicationMap = applicationsStack.head - private var applicationsStack : List[ApplicationMap] = List(Map.empty.withDefaultValue(Set.empty)) - private def applications_=(map: ApplicationMap) : Unit = { - applicationsStack = map :: applicationsStack.tail - } - - protected type FreeMap = Map[FunctionType, Set[T]] - protected def freeLambdas : FreeMap = freeLambdasStack.head - private var freeLambdasStack : List[FreeMap] = List(Map.empty.withDefaultValue(Set.empty)) - private def freeLambdas_=(map: FreeMap) : Unit = { - freeLambdasStack = map :: freeLambdasStack.tail - } - - def push(): Unit = { - byIDStack = byID :: byIDStack - byTypeStack = byType :: byTypeStack - applicationsStack = applications :: applicationsStack - freeLambdasStack = freeLambdas :: freeLambdasStack - } - - def pop(lvl: Int): Unit = { - byIDStack = byIDStack.drop(lvl) - byTypeStack = byTypeStack.drop(lvl) - applicationsStack = applicationsStack.drop(lvl) - freeLambdasStack = freeLambdasStack.drop(lvl) - } + def clear(): Unit = incrementals.foreach(_.clear()) + def reset(): Unit = incrementals.foreach(_.reset()) + def push(): Unit = incrementals.foreach(_.push()) + def pop(): Unit = incrementals.foreach(_.pop()) def registerFree(lambdas: Seq[(TypeTree, T)]): Unit = { for ((tpe, idT) <- lambdas) tpe match { diff --git a/src/main/scala/leon/solvers/templates/QuantificationManager.scala b/src/main/scala/leon/solvers/templates/QuantificationManager.scala index a420b817b20a5f8fc3896217985729febd75ce89..74ffe69e103d592767825a8516c9d252698d481c 100644 --- a/src/main/scala/leon/solvers/templates/QuantificationManager.scala +++ b/src/main/scala/leon/solvers/templates/QuantificationManager.scala @@ -4,8 +4,10 @@ package leon package solvers package templates +import leon.utils._ import purescala.Common._ import purescala.Extractors._ +import purescala.Constructors._ import purescala.Expressions._ import purescala.ExprOps._ import purescala.Types._ @@ -88,47 +90,29 @@ object QuantificationTemplate { class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManager[T](encoder) { - private val nextQ: () => T = { - val id: Identifier = FreshIdentifier("q", BooleanType, true) - () => encoder.encodeId(id) - } + private val quantifications = new IncrementalSeq[Quantification] + private val instantiated = new IncrementalSet[(T, Matcher[T])] + private val known = new IncrementalSet[T] - private var quantificationsStack: List[Seq[Quantification]] = List(Seq.empty) - private def quantifications: Seq[Quantification] = quantificationsStack.head - private def quantifications_=(qs: Seq[Quantification]): Unit = { - quantificationsStack = qs :: quantificationsStack.tail + private def correspond(qm: Matcher[T], m: Matcher[T]): Boolean = correspond(qm, m.caller, m.tpe) + private def correspond(qm: Matcher[T], caller: T, tpe: TypeTree): Boolean = qm.tpe match { + case _: FunctionType => qm.tpe == tpe && (qm.caller == caller || !known(caller)) + case _ => qm.tpe == tpe } - private var instantiatedStack: List[Set[(T, Matcher[T])]] = List(Set.empty) - private def instantiated: Set[(T, Matcher[T])] = instantiatedStack.head - private def instantiated_=(ias: Set[(T, Matcher[T])]): Unit = { - instantiatedStack = ias :: instantiatedStack.tail - } + override protected def incrementals: List[IncrementalState] = + List(quantifications, instantiated, known) ++ super.incrementals - private var knownStack: List[Set[T]] = List(Set.empty) - private def known(idT: T): Boolean = knownStack.head(idT) || byID.isDefinedAt(idT) - private def correspond(qm: Matcher[T], m: Matcher[T]): Boolean = qm.tpe match { - case _: FunctionType => qm.tpe == m.tpe && (qm.caller == m.caller || !known(m.caller)) - case _ => qm.tpe == m.tpe - } + def assumptions: Seq[T] = quantifications.map(_.currentQ2Var).toSeq - override def push(): Unit = { - quantificationsStack = quantifications :: quantificationsStack - instantiatedStack = instantiated :: instantiatedStack - knownStack = knownStack.head :: knownStack - } + def instantiations: Seq[(T, Matcher[T])] = instantiated.toSeq - override def pop(lvl: Int): Unit = { - quantificationsStack = quantificationsStack.drop(lvl) - instantiatedStack = instantiatedStack.drop(lvl) - knownStack = knownStack.drop(lvl) - } - - def assumptions: Seq[T] = quantifications.map(_.currentQ2Var) + def instantiations(caller: T, tpe: TypeTree): Seq[(T, Matcher[T])] = + instantiated.toSeq.filter { case (b,m) => correspond(m, caller, tpe) } override def registerFree(ids: Seq[(TypeTree, T)]): Unit = { super.registerFree(ids) - knownStack = (knownStack.head ++ ids.map(_._2)) :: knownStack.tail + known ++= ids.map(_._2) } private class Quantification ( @@ -167,7 +151,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage .map(qm => if (qm == bindingMatcher) { bindingMatcher -> Set(blocker -> matcher) } else { - val instances: Set[(T, Matcher[T])] = instantiated.filter { case (b, m) => correspond(qm, m) } + val instances: Set[(T, Matcher[T])] = instantiated.filter { case (b, m) => correspond(qm, m) }.toSet // concrete applications can appear multiple times in the constraint, and this is also the case // for the current application for which we are generating the constraints @@ -303,23 +287,16 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage template.matchers merge rec(template.lambdas) } - val quantifiedMatchers = for { + val quantifiedMatchers = (for { (_, ms) <- allMatchers m @ Matcher(_, _, args, _) <- ms if args exists (_.left.exists(quantified)) - } yield m - - val matchQuorums: Seq[Set[Matcher[T]]] = quantifiedMatchers.toSet.subsets.filter { ms => - var doubled: Boolean = false - var qs: Set[T] = Set.empty - for (m @ Matcher(_, _, args, _) <- ms) { - val qargs = (args collect { case Left(a) if quantified(a) => a }).toSet - if ((qs & qargs).nonEmpty) doubled = true - qs ++= qargs - } + } yield m).toSet - !doubled && (qs == quantified) - }.toList + val matchQuorums: Seq[Set[Matcher[T]]] = purescala.Quantification.extractQuorums( + quantifiedMatchers, quantified, + (m: Matcher[T]) => m.args.collect { case Right(m) if quantifiedMatchers(m) => m }.toSet, + (m: Matcher[T]) => m.args.collect { case Left(a) if quantified(a) => a }.toSet) var instantiation = Instantiation.empty[T] @@ -387,7 +364,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage instantiation ++= quantification.instantiate(b, m) } - quantifications :+= quantification + quantifications += quantification quantification.qs._2 } diff --git a/src/main/scala/leon/solvers/templates/Templates.scala b/src/main/scala/leon/solvers/templates/Templates.scala index 678050fa131dd26713836860d3a0e57031eb9b95..f4714bd9e561335d90d75860770b7bede75d04ee 100644 --- a/src/main/scala/leon/solvers/templates/Templates.scala +++ b/src/main/scala/leon/solvers/templates/Templates.scala @@ -5,11 +5,12 @@ package solvers package templates import purescala.Common._ +import purescala.Definitions._ import purescala.Expressions._ +import purescala.Quantification._ import purescala.Extractors._ import purescala.ExprOps._ import purescala.Types._ -import purescala.Definitions._ case class App[T](caller: T, tpe: FunctionType, args: Seq[T]) { override def toString = "(" + caller + " : " + tpe + ")" + args.mkString("(", ",", ")") @@ -122,15 +123,6 @@ object Template { } } - private object MatchExtractor { - def unapply(expr: Expr): Option[(Expr, Seq[Expr])] = expr match { - case ApplicationExtractor(caller, args) => Some(caller -> args) - case ArraySelect(arr, index) => Some(arr -> Seq(index)) - case MapGet(map, key) => Some(map -> Seq(key)) - case _ => None - } - } - private def invocationMatcher[T](encodeExpr: Expr => T)(tfd: TypedFunDef, args: Seq[Expr]): Matcher[T] = { assert(tfd.returnType.isInstanceOf[FunctionType], "invocationMatcher() is only defined on function-typed defs") @@ -174,7 +166,7 @@ object Template { val optIdCall = optCall.map(tfd => TemplateCallInfo[T](tfd, arguments.map(_._2))) val optIdApp = optApp.map { case (idT, tpe) => App(idT, tpe, arguments.map(_._2)) } - val invocMatcher = optCall.filter(_.returnType.isInstanceOf[FunctionType]) + lazy val invocMatcher = optCall.filter(_.returnType.isInstanceOf[FunctionType]) .map(tfd => invocationMatcher(encodeExpr)(tfd, arguments.map(_._1.toVariable))) val (blockers, applications, matchers) = { @@ -204,7 +196,7 @@ object Template { val result = res.flatten.toMap result ++ (expr match { - case MatchExtractor(c, args) => + case QuantificationMatcher(c, args) => // Note that we rely here on the fact that foldRight visits the matcher's arguments first, // so any Matcher in arguments will belong to the `result` map val encodedArgs = args.map(arg => result.get(arg) match { @@ -382,47 +374,6 @@ class FunctionTemplate[T] private( object LambdaTemplate { - private var typedIds : Map[TypeTree, List[Identifier]] = Map.empty.withDefaultValue(List.empty) - - private def structuralKey[T](lambda: Lambda, dependencies: Map[Identifier, T]): (Lambda, Map[Identifier,T]) = { - - def closureIds(expr: Expr): Seq[Identifier] = { - val allVars : Seq[Identifier] = foldRight[Seq[Identifier]] { - (expr, idSeqs) => idSeqs.foldLeft(expr match { - case Variable(id) => Seq(id) - case _ => Seq.empty[Identifier] - })((acc, seq) => acc ++ seq) - } (expr) - - val vars = variablesOf(expr) - allVars.filter(vars(_)).distinct - } - - val grouped : Map[TypeTree, Seq[Identifier]] = (lambda.args.map(_.id) ++ closureIds(lambda)).groupBy(_.getType) - val subst : Map[Identifier, Identifier] = grouped.foldLeft(Map.empty[Identifier,Identifier]) { case (subst, (tpe, ids)) => - val currentVars = typedIds(tpe) - - val freshCount = ids.size - currentVars.size - val typedVars = if (freshCount > 0) { - val allIds = currentVars ++ List.range(0, freshCount).map(_ => FreshIdentifier("x", tpe, true)) - typedIds += tpe -> allIds - allIds - } else { - currentVars - } - - subst ++ (ids zip typedVars) - } - - val newArgs = lambda.args.map(vd => ValDef(subst(vd.id), vd.tpe)) - val newBody = replaceFromIDs(subst.mapValues(_.toVariable), lambda.body) - val structuralLambda = Lambda(newArgs, newBody) - - val newDeps = dependencies.map { case (id, idT) => subst(id) -> idT } - - structuralLambda -> newDeps - } - def apply[T]( ids: (Identifier, T), encoder: TemplateEncoder[T], @@ -448,7 +399,9 @@ object LambdaTemplate { "Template for lambda " + ids._1 + ": " + lambda + " is :\n" + templateString() } - val (key, keyDeps) = structuralKey(lambda, dependencies) + val (structuralLambda, structSubst) = normalizeStructure(lambda) + val keyDeps = dependencies.map { case (id, idT) => structSubst(id) -> idT } + val key = structuralLambda.asInstanceOf[Lambda] new LambdaTemplate[T]( ids._1, diff --git a/src/main/scala/leon/solvers/templates/UnrollingBank.scala b/src/main/scala/leon/solvers/templates/UnrollingBank.scala index 617f1c5f47c538997801d3b43ce3d15e995c5a82..ddfb22b0bdbd2a3be90e8950b7a21144b472d299 100644 --- a/src/main/scala/leon/solvers/templates/UnrollingBank.scala +++ b/src/main/scala/leon/solvers/templates/UnrollingBank.scala @@ -18,90 +18,84 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat private val encoder = templateGenerator.encoder private val manager = templateGenerator.manager - // Keep which function invocation is guarded by which guard, - // also specify the generation of the blocker. - private val callInfos = new IncrementalMap[T, (Int, Int, T, Set[TemplateCallInfo[T]])]() - private def callInfo = callInfos.toMap - // Function instantiations have their own defblocker - private val defBlockerss = new IncrementalMap[TemplateCallInfo[T], T]() - private def defBlockers = defBlockerss.toMap - - private val appInfos = new IncrementalMap[(T, App[T]), (Int, Int, T, T, Set[TemplateAppInfo[T]])]() - private def appInfo = appInfos.toMap - - private val appBlockerss = new IncrementalMap[(T, App[T]), T]() - private def appBlockers = appBlockerss.toMap + private val defBlockers = new IncrementalMap[TemplateCallInfo[T], T]() + // Keep which function invocation is guarded by which guard, + // also specify the generation of the blocker. + private val callInfos = new IncrementalMap[T, (Int, Int, T, Set[TemplateCallInfo[T]])]() + private val appInfos = new IncrementalMap[(T, App[T]), (Int, Int, T, T, Set[TemplateAppInfo[T]])]() + private val appBlockers = new IncrementalMap[(T, App[T]), T]() private val blockerToApps = new IncrementalMap[T, (T, App[T])]() - private def blockerToApp = blockerToApps.toMap - - private val functionVarss = new IncrementalMap[TypeTree, Set[T]]() - private def functionVars = functionVarss.toMap + private val functionVars = new IncrementalMap[TypeTree, Set[T]]() def push() { callInfos.push() - defBlockerss.push() + defBlockers.push() appInfos.push() - appBlockerss.push() + appBlockers.push() blockerToApps.push() - functionVarss.push() + functionVars.push() } def pop() { callInfos.pop() - defBlockerss.pop() + defBlockers.pop() appInfos.pop() - appBlockerss.pop() + appBlockers.pop() blockerToApps.pop() - functionVarss.pop() + functionVars.pop() } def clear() { callInfos.clear() - defBlockerss.clear() + defBlockers.clear() appInfos.clear() - appBlockerss.clear() - functionVarss.clear() + appBlockers.clear() + blockerToApps.clear() + functionVars.clear() } def reset() { callInfos.reset() - defBlockerss.reset() + defBlockers.reset() appInfos.reset() - appBlockerss.reset() - functionVarss.reset() + appBlockers.reset() + blockerToApps.clear() + functionVars.reset() } def dumpBlockers() = { - val generations = (callInfo.map(_._2._1).toSet ++ appInfo.map(_._2._1).toSet).toSeq.sorted + val generations = (callInfos.map(_._2._1).toSet ++ appInfos.map(_._2._1).toSet).toSeq.sorted generations.foreach { generation => reporter.debug("--- " + generation) - for ((b, (gen, origGen, ast, fis)) <- callInfo if gen == generation) { + for ((b, (gen, origGen, ast, fis)) <- callInfos if gen == generation) { reporter.debug(f". $b%15s ~> "+fis.mkString(", ")) } - for ((app, (gen, origGen, b, notB, infos)) <- appInfo if gen == generation) { + for ((app, (gen, origGen, b, notB, infos)) <- appInfos if gen == generation) { reporter.debug(f". $b%15s ~> "+infos.mkString(", ")) } } } - def canUnroll = callInfo.nonEmpty || appInfo.nonEmpty + def satisfactionAssumptions = currentBlockers ++ manager.assumptions + + def refutationAssumptions = manager.assumptions - def currentBlockers = callInfo.map(_._2._3).toSeq ++ appInfo.map(_._2._4).toSeq + def canUnroll = callInfos.nonEmpty || appInfos.nonEmpty - def quantificationAssumptions = manager.assumptions + def currentBlockers = callInfos.map(_._2._3).toSeq ++ appInfos.map(_._2._4).toSeq def getBlockersToUnlock: Seq[T] = { - if (callInfo.isEmpty && appInfo.isEmpty) { + if (callInfos.isEmpty && appInfos.isEmpty) { Seq.empty } else { - val minGeneration = (callInfo.values.map(_._1) ++ appInfo.values.map(_._1)).min - val callBlocks = callInfo.filter(_._2._1 == minGeneration).toSeq.map(_._1) - val appBlocks = appInfo.values.filter(_._1 == minGeneration).toSeq.map(_._3) + val minGeneration = (callInfos.values.map(_._1) ++ appInfos.values.map(_._1)).min + val callBlocks = callInfos.filter(_._2._1 == minGeneration).toSeq.map(_._1) + val appBlocks = appInfos.values.filter(_._1 == minGeneration).toSeq.map(_._3) callBlocks ++ appBlocks } } @@ -109,7 +103,7 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat private def registerCallBlocker(gen: Int, id: T, fis: Set[TemplateCallInfo[T]]) { val notId = encoder.mkNot(id) - callInfo.get(id) match { + callInfos.get(id) match { case Some((exGen, origGen, _, exFis)) => // PS: when recycling `b`s, this assertion becomes dangerous. // It's better to simply take the max of the generations. @@ -117,17 +111,17 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat val minGen = gen min exGen - callInfo += id -> (minGen, origGen, notId, fis++exFis) + callInfos += id -> (minGen, origGen, notId, fis++exFis) case None => - callInfo += id -> (gen, gen, notId, fis) + callInfos += id -> (gen, gen, notId, fis) } } private def registerAppBlocker(gen: Int, app: (T, App[T]), info: Set[TemplateAppInfo[T]]) : Unit = { - appInfo.get(app) match { + appInfos.get(app) match { case Some((exGen, origGen, b, notB, exInfo)) => val minGen = gen min exGen - appInfo += app -> (minGen, origGen, b, notB, exInfo ++ info) + appInfos += app -> (minGen, origGen, b, notB, exInfo ++ info) case None => val b = appBlockers.get(app) match { @@ -136,8 +130,8 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat } val notB = encoder.mkNot(b) - appInfo += app -> (gen, gen, b, notB, info) - blockerToApp += b -> app + appInfos += app -> (gen, gen, b, notB, info) + blockerToApps += b -> app } } @@ -154,7 +148,7 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat } private def extendAppBlock(app: (T, App[T]), infos: Set[TemplateAppInfo[T]]) : T = { - assert(!appInfo.isDefinedAt(app), "appInfo -= app must have been called to ensure blocker freshness") + assert(!appInfos.isDefinedAt(app), "appInfo -= app must have been called to ensure blocker freshness") assert(appBlockers.isDefinedAt(app), "freshAppBlocks must have been called on app before it can be unlocked") assert(infos.nonEmpty, "No point in extending blockers if no templates have been unrolled!") @@ -207,45 +201,45 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat def nextGeneration(gen: Int) = gen + 3 def decreaseAllGenerations() = { - for ((block, (gen, origGen, ast, infos)) <- callInfo) { + for ((block, (gen, origGen, ast, infos)) <- callInfos) { // We also decrease the original generation here - callInfo += block -> (math.max(1,gen-1), math.max(1,origGen-1), ast, infos) + callInfos += block -> (math.max(1,gen-1), math.max(1,origGen-1), ast, infos) } - for ((app, (gen, origGen, b, notB, infos)) <- appInfo) { - appInfo += app -> (math.max(1,gen-1), math.max(1,origGen-1), b, notB, infos) + for ((app, (gen, origGen, b, notB, infos)) <- appInfos) { + appInfos += app -> (math.max(1,gen-1), math.max(1,origGen-1), b, notB, infos) } } def promoteBlocker(b: T) = { - if (callInfo contains b) { - val (_, origGen, ast, fis) = callInfo(b) + if (callInfos contains b) { + val (_, origGen, ast, fis) = callInfos(b) - callInfo += b -> (1, origGen, ast, fis) + callInfos += b -> (1, origGen, ast, fis) } - if (blockerToApp contains b) { - val app = blockerToApp(b) - val (_, origGen, _, notB, infos) = appInfo(app) + if (blockerToApps contains b) { + val app = blockerToApps(b) + val (_, origGen, _, notB, infos) = appInfos(app) - appInfo += app -> (1, origGen, b, notB, infos) + appInfos += app -> (1, origGen, b, notB, infos) } } def unrollBehind(ids: Seq[T]): Seq[T] = { - assert(ids.forall(id => (callInfo contains id) || (blockerToApp contains id))) + assert(ids.forall(id => (callInfos contains id) || (blockerToApps contains id))) var newClauses : Seq[T] = Seq.empty - val newCallInfos = ids.flatMap(id => callInfo.get(id).map(id -> _)) - callInfo --= ids + val newCallInfos = ids.flatMap(id => callInfos.get(id).map(id -> _)) + callInfos --= ids - val apps = ids.flatMap(id => blockerToApp.get(id)) - val appInfos = apps.map(app => app -> appInfo(app)) - blockerToApp --= ids - appInfo --= apps + val apps = ids.flatMap(id => blockerToApps.get(id)) + val thisAppInfos = apps.map(app => app -> appInfos(app)) + blockerToApps --= ids + appInfos --= apps - for ((app, (_, _, _, _, infos)) <- appInfos if infos.nonEmpty) { + for ((app, (_, _, _, _, infos)) <- thisAppInfos if infos.nonEmpty) { val extension = extendAppBlock(app, infos) reporter.debug(" -> extending lambda blocker: " + extension) newClauses :+= extension @@ -296,7 +290,7 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat newClauses ++= newCls } - for ((app @ (b, _), (gen, _, _, _, infos)) <- appInfos; info @ TemplateAppInfo(template, equals, args) <- infos) { + for ((app @ (b, _), (gen, _, _, _, infos)) <- thisAppInfos; info @ TemplateAppInfo(template, equals, args) <- infos) { var newCls = Seq.empty[T] val nb = encoder.encodeId(FreshIdentifier("b", BooleanType, true)) diff --git a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala index d2dc57f508c42b431e5682efb211b01289064fcf..44df7eb678380b27c93439d19eb65b3626202a80 100644 --- a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala @@ -544,25 +544,6 @@ trait AbstractZ3Solver extends Solver { rec(expr) } - protected def fromRawArray(r: Expr, tpe: TypeTree): Expr = r match { - case rav: RawArrayValue => - fromRawArray(rav, tpe) - case _ => - scala.sys.error("Unable to extract from raw array for "+r.asString) - } - - protected def fromRawArray(r: RawArrayValue, tpe: TypeTree): Expr = tpe match { - case RawArrayType(from, to) => - r - - case ft @ FunctionType(from, to) => - finiteLambda(r.default, r.elems.toSeq, from) - - - case _ => - scala.sys.error("Unable to extract from raw array for "+tpe.asString) - } - protected[leon] def fromZ3Formula(model: Z3Model, tree: Z3AST, tpe: TypeTree): Expr = { def rec(t: Z3AST, tpe: TypeTree): Expr = { @@ -678,15 +659,8 @@ trait AbstractZ3Solver extends Solver { FiniteMap(elems, from, to) } - case FunctionType(fts, tt) => - model.getArrayValue(t) match { - case None => reporter.fatalError("Translation from Z3 to function value failed") - case Some((map, elseZ3Value)) => - val leonElseValue = rec(elseZ3Value, tt) - val leonMap = map.toSeq.map(p => rec(p._1, tupleTypeWrap(fts)) -> rec(p._2, tt)) - finiteLambda(leonElseValue, leonMap, fts) - } + rec(t, RawArrayType(tupleTypeWrap(fts), tt)) case tpe @ SetType(dt) => model.getSetValue(t) match { diff --git a/src/main/scala/leon/solvers/z3/FairZ3Solver.scala b/src/main/scala/leon/solvers/z3/FairZ3Solver.scala index b8eeeafabe977d3716fcd4cb59113bf108f3fbea..406dcde0a303db2cd981a0d12df8420d049db926 100644 --- a/src/main/scala/leon/solvers/z3/FairZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/FairZ3Solver.scala @@ -12,6 +12,7 @@ import purescala.Common._ import purescala.Definitions._ import purescala.Expressions._ import purescala.Constructors._ +import purescala.Quantification._ import purescala.ExprOps._ import purescala.Types._ @@ -24,7 +25,8 @@ import termination._ class FairZ3Solver(val context: LeonContext, val program: Program) extends AbstractZ3Solver with Z3ModelReconstruction - with FairZ3Component { + with FairZ3Component + with EvaluatingSolver { enclosing => @@ -39,15 +41,6 @@ class FairZ3Solver(val context: LeonContext, val program: Program) protected def hasError = errors.getB(()) contains true protected def addError() = errors += () -> true - private val evaluator: Evaluator = - if(useCodeGen) { - // TODO If somehow we could not recompile each time we create a solver, - // that would be good? - new CodeGenEvaluator(context, program) - } else { - new DefaultEvaluator(context, program) - } - protected[z3] def getEvaluator : Evaluator = evaluator private val terminator : TerminationChecker = new SimpleTerminationChecker(context, program) @@ -62,8 +55,57 @@ class FairZ3Solver(val context: LeonContext, val program: Program) )} toggleWarningMessages(true) + private def extractModel(model: Z3Model, ids: Set[Identifier]): HenkinModel = { + val asMap = modelToMap(model, ids) + + def extract(b: Z3AST, m: Matcher[Z3AST]): Set[Seq[Expr]] = { + val QuantificationTypeMatcher(fromTypes, _) = m.tpe + val optEnabler = model.evalAs[Boolean](b) + val optArgs = (m.args zip fromTypes).map { + p => softFromZ3Formula(model, model.eval(Matcher.argValue(p._1), true).get, p._2) + } + + if (optEnabler == Some(true) && optArgs.forall(_.isDefined)) { + Set(optArgs.map(_.get)) + } else { + Set.empty + } + } + + val funDomains = ids.flatMap(id => id.getType match { + case ft @ FunctionType(fromTypes, _) => variables.getB(id.toVariable) match { + case Some(z3ID) => Some(id -> templateGenerator.manager.instantiations(z3ID, ft).flatMap { + case (b, m) => extract(b, m) + }) + case _ => None + } + case _ => None + }).toMap.mapValues(_.toSet) + + val asDMap = asMap.map(p => funDomains.get(p._1) match { + case Some(domain) => + val mapping = domain.toSeq.map { es => + val ev: Expr = p._2 match { + case RawArrayValue(_, mapping, dflt) => + mapping.collectFirst { + case (k,v) if evaluator.eval(Equals(k, tupleWrap(es))).result == Some(BooleanLiteral(true)) => v + } getOrElse dflt + case _ => scala.sys.error("Unexpected function encoding " + p._2) + } + es -> ev + } + p._1 -> PartialLambda(mapping, p._1.getType.asInstanceOf[FunctionType]) + case None => p + }) + + val typeGrouped = templateGenerator.manager.instantiations.groupBy(_._2.tpe) + val typeDomains = typeGrouped.mapValues(_.flatMap { case (b, m) => extract(b, m) }.toSet) + + val domain = new HenkinDomains(typeDomains) + new HenkinModel(asDMap, domain) + } - private def validateModel(model: Z3Model, formula: Expr, variables: Set[Identifier], silenceErrors: Boolean) : (Boolean, Map[Identifier,Expr]) = { + private def validateModel(model: Z3Model, formula: Expr, variables: Set[Identifier], silenceErrors: Boolean) : (Boolean, HenkinModel) = { if(!interrupted) { val functionsModel: Map[Z3FuncDecl, (Seq[(Seq[Z3AST], Z3AST)], Z3AST)] = model.getModelFuncInterpretations.map(i => (i._1, (i._2, i._3))).toMap @@ -92,22 +134,23 @@ class FairZ3Solver(val context: LeonContext, val program: Program) } else Seq() }).toMap - val asMap = modelToMap(model, variables) ++ functionsAsMap ++ constantFunctionsAsMap - val evalResult = evaluator.eval(formula, asMap) + val leonModel = extractModel(model, variables) + val fullModel = leonModel ++ (functionsAsMap ++ constantFunctionsAsMap) + val evalResult = evaluator.eval(formula, fullModel) evalResult match { case EvaluationResults.Successful(BooleanLiteral(true)) => reporter.debug("- Model validated.") - (true, asMap) + (true, fullModel) case EvaluationResults.Successful(res) => assert(res == BooleanLiteral(false), "Checking model returned non-boolean") reporter.debug("- Invalid model.") - (false, asMap) + (false, fullModel) case EvaluationResults.RuntimeError(msg) => reporter.debug("- Model leads to runtime error.") - (false, asMap) + (false, fullModel) case EvaluationResults.EvaluatorError(msg) => if (silenceErrors) { @@ -115,11 +158,11 @@ class FairZ3Solver(val context: LeonContext, val program: Program) } else { reporter.warning("Something went wrong. While evaluating the model, we got this : " + msg) } - (false, asMap) + (false, fullModel) } } else { - (false, Map.empty) + (false, HenkinModel.empty) } } @@ -156,7 +199,7 @@ class FairZ3Solver(val context: LeonContext, val program: Program) val solver = z3.mkSolver() private val freeVars = new IncrementalSet[Identifier]() - private var constraints = new IncrementalSeq[Expr]() + private val constraints = new IncrementalSeq[Expr]() val unrollingBank = new UnrollingBank(context, templateGenerator) @@ -195,7 +238,7 @@ class FairZ3Solver(val context: LeonContext, val program: Program) var foundDefinitiveAnswer = false var definitiveAnswer : Option[Boolean] = None - var definitiveModel : Map[Identifier,Expr] = Map.empty + var definitiveModel : HenkinModel = HenkinModel.empty var definitiveCore : Set[Expr] = Set.empty def assertCnstr(expression: Expr) { @@ -204,7 +247,7 @@ class FairZ3Solver(val context: LeonContext, val program: Program) freeVars ++= newFreeVars // We make sure all free variables are registered as variables - freeVars.toSet.foreach { v => + freeVars.foreach { v => variables.cachedB(Variable(v)) { templateGenerator.encoder.encodeId(v) } @@ -236,7 +279,7 @@ class FairZ3Solver(val context: LeonContext, val program: Program) def entireFormula = andJoin(assumptions.toSeq ++ constraints.toSeq) - def foundAnswer(answer : Option[Boolean], model : Map[Identifier,Expr] = Map.empty, core: Set[Expr] = Set.empty) : Unit = { + def foundAnswer(answer: Option[Boolean], model: HenkinModel = HenkinModel.empty, core: Set[Expr] = Set.empty) : Unit = { foundDefinitiveAnswer = true definitiveAnswer = answer definitiveModel = model @@ -268,13 +311,13 @@ class FairZ3Solver(val context: LeonContext, val program: Program) val timer = context.timers.solvers.z3.check.start() solver.push() // FIXME: remove when z3 bug is fixed - val res = solver.checkAssumptions((assumptionsAsZ3 ++ unrollingBank.currentBlockers ++ unrollingBank.quantificationAssumptions) :_*) + val res = solver.checkAssumptions((assumptionsAsZ3 ++ unrollingBank.satisfactionAssumptions) :_*) solver.pop() // FIXME: remove when z3 bug is fixed timer.stop() reporter.debug(" - Finished search with blocked literals") - lazy val allVars = freeVars.toSet + lazy val allVars: Set[Identifier] = freeVars.toSet res match { case None => @@ -300,7 +343,7 @@ class FairZ3Solver(val context: LeonContext, val program: Program) foundAnswer(None, model) } } else { - val model = modelToMap(z3model, allVars) + val model = extractModel(z3model, allVars) //lazy val modelAsString = model.toList.map(p => p._1 + " -> " + p._2).mkString("\n") //reporter.debug("- Found a model:") @@ -359,7 +402,7 @@ class FairZ3Solver(val context: LeonContext, val program: Program) val timer = context.timers.solvers.z3.check.start() solver.push() // FIXME: remove when z3 bug is fixed - val res2 = solver.checkAssumptions((assumptionsAsZ3 ++ unrollingBank.quantificationAssumptions) : _*) + val res2 = solver.checkAssumptions((assumptionsAsZ3 ++ unrollingBank.refutationAssumptions) : _*) solver.pop() // FIXME: remove when z3 bug is fixed timer.stop() diff --git a/src/main/scala/leon/solvers/z3/UninterpretedZ3Solver.scala b/src/main/scala/leon/solvers/z3/UninterpretedZ3Solver.scala index 1106777377fa9de534bb0f1b60b1151197592f83..f755f2a4a95c6b55774d966abda7b659206c96d6 100644 --- a/src/main/scala/leon/solvers/z3/UninterpretedZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/UninterpretedZ3Solver.scala @@ -66,7 +66,7 @@ class UninterpretedZ3Solver(val context : LeonContext, val program: Program) } def getModel = { - modelToMap(solver.getModel(), freeVariables.toSet) + new Model(modelToMap(solver.getModel(), freeVariables.toSet)) } def getUnsatCore = { diff --git a/src/main/scala/leon/solvers/z3/Z3ModelReconstruction.scala b/src/main/scala/leon/solvers/z3/Z3ModelReconstruction.scala index 62f43d736038798be92f55cad7ea83db0f3e9513..4383f8f31fe6eeaccb8e174aa0c0251638f581ba 100644 --- a/src/main/scala/leon/solvers/z3/Z3ModelReconstruction.scala +++ b/src/main/scala/leon/solvers/z3/Z3ModelReconstruction.scala @@ -6,6 +6,7 @@ package solvers.z3 import z3.scala._ import purescala.Common._ import purescala.Expressions._ +import purescala.Constructors._ import purescala.ExprOps._ import purescala.Types._ @@ -18,7 +19,7 @@ trait Z3ModelReconstruction { def modelValue(model: Z3Model, id: Identifier, tpe: TypeTree = null) : Option[Expr] = { val expectedType = if(tpe == null) id.getType else tpe - + variables.getB(id.toVariable).flatMap { z3ID => expectedType match { case BooleanType => model.evalAs[Boolean](z3ID).map(BooleanLiteral) @@ -43,7 +44,7 @@ trait Z3ModelReconstruction { reporter.debug("Completing variable '" + id + "' to simplest value") } - for(id <- ids) { + for (id <- ids) { modelValue(model, id) match { case None if AUTOCOMPLETEMODELS => completeID(id) case None => ; @@ -51,6 +52,7 @@ trait Z3ModelReconstruction { case Some(ex) => asMap = asMap + (id -> ex) } } + asMap } diff --git a/src/main/scala/leon/synthesis/rules/BottomUpTegis.scala b/src/main/scala/leon/synthesis/rules/BottomUpTegis.scala index c363bcfa5f738543cd9f6ee7fc4572cf4a69623d..990be35e1801b5145997ec3da71471a504febec5 100644 --- a/src/main/scala/leon/synthesis/rules/BottomUpTegis.scala +++ b/src/main/scala/leon/synthesis/rules/BottomUpTegis.scala @@ -8,6 +8,7 @@ import purescala.Expressions._ import purescala.Common._ import purescala.Types._ import purescala.Constructors._ +import purescala.Quantification._ import evaluators._ import codegen.CodeGenParams @@ -85,7 +86,7 @@ abstract class BottomUpTEGISLike[T <% Typed](name: String) extends Rule(name) { { (vecs: Vector[Vector[Expr]]) => val res = (0 to nTests-1).toVector.flatMap { i => - val inputs = vecs.map(_(i)) + val inputs = new solvers.Model((args zip vecs.map(_(i))).toMap) ev(inputs) match { case EvaluationResults.Successful(out) => Some(out) case _ => diff --git a/src/main/scala/leon/termination/LoopProcessor.scala b/src/main/scala/leon/termination/LoopProcessor.scala index 930771443ace087f38a3f9653289d9b65ef9a9b0..790db91062742f6d2b30ec731677c3bf168ecc7e 100644 --- a/src/main/scala/leon/termination/LoopProcessor.scala +++ b/src/main/scala/leon/termination/LoopProcessor.scala @@ -35,8 +35,8 @@ class LoopProcessor(val checker: TerminationChecker, val modules: ChainBuilder w val resTuple = tupleWrap(freshParams.map(_.toVariable)) definitiveSATwithModel(andJoin(path :+ Equals(srcTuple, resTuple))) match { - case Some(map) => - val args = chain.funDef.params.map(arg => map(arg.id)) + case Some(model) => + val args = chain.funDef.params.map(arg => model(arg.id)) val res = if (chain.relations.exists(_.inLambda)) MaybeBroken(chain.funDef, args) else Broken(chain.funDef, args) nonTerminating(chain.funDef) = res case None => diff --git a/src/main/scala/leon/termination/Processor.scala b/src/main/scala/leon/termination/Processor.scala index 91590dd4aca8e82c846b8c970475b500db07e785..99124c5e64ed8bb61e3c44a015775280d06c584e 100644 --- a/src/main/scala/leon/termination/Processor.scala +++ b/src/main/scala/leon/termination/Processor.scala @@ -64,7 +64,7 @@ trait Solvable extends Processor { } } - def definitiveSATwithModel(problem: Expr): Option[Map[Identifier, Expr]] = { + def definitiveSATwithModel(problem: Expr): Option[Model] = { withoutPosts { val (sat, model) = SimpleSolverAPI(solver).solveSAT(problem) if (sat.isDefined && sat.get) Some(model) else None diff --git a/src/main/scala/leon/utils/IncrementalMap.scala b/src/main/scala/leon/utils/IncrementalMap.scala index 4515351de19cf83301f8888d7c52c35b386dc73b..b07d8adc19d51573ef63784b2fc49dbea9c38ff8 100644 --- a/src/main/scala/leon/utils/IncrementalMap.scala +++ b/src/main/scala/leon/utils/IncrementalMap.scala @@ -2,12 +2,23 @@ package leon.utils -import scala.collection.mutable.{Stack, Map => MMap} +import scala.collection.mutable.{Stack, Map => MMap, Builder} +import scala.collection.generic.Shrinkable +import scala.collection.IterableLike -class IncrementalMap[A, B] extends IncrementalState { - private[this] val stack = new Stack[MMap[A, B]]() +class IncrementalMap[A, B] private(dflt: Option[B]) + extends IncrementalState + with Iterable[(A,B)] + with IterableLike[(A,B), Map[A,B]] + with Builder[(A,B), IncrementalMap[A, B]] + with Shrinkable[A] + with (A => B) { - def clear(): Unit = { + def this() = this(None) + + private val stack = new Stack[MMap[A, B]]() + + override def clear(): Unit = { stack.clear() } @@ -22,22 +33,40 @@ class IncrementalMap[A, B] extends IncrementalState { } else { MMap[A,B]() ++ stack.head } - stack.push(last) + + val withDefault = dflt match { + case Some(value) => last.withDefaultValue(value) + case None => last + } + + stack.push(withDefault) } def pop(): Unit = { stack.pop() } - def +=(a: A, b: B): Unit = { - stack.head += a -> b + def withDefaultValue[B1 >: B](dflt: B1) = { + val map = new IncrementalMap[A, B1](Some(dflt)) + map.stack.pop() + for (s <- stack.toList) map.stack.push(MMap[A,B1]().withDefaultValue(dflt) ++ s) + map } - def ++=(as: Traversable[(A, B)]): Unit = { - stack.head ++= as - } + def get(k: A) = stack.head.get(k) + def apply(k: A) = stack.head.apply(k) + def contains(k: A) = stack.head.contains(k) + def isDefinedAt(k: A) = stack.head.isDefinedAt(k) + def getOrElse[B1 >: B](k: A, e: => B1) = stack.head.getOrElse(k, e) + def values = stack.head.values + + def iterator = stack.head.iterator + def +=(kv: (A, B)) = { stack.head += kv; this } + def -=(k: A) = { stack.head -= k; this } - def toMap = stack.head + override def seq = stack.head.seq + override def newBuilder = new scala.collection.mutable.MapBuilder(Map.empty[A,B]) + def result = this push() } diff --git a/src/main/scala/leon/utils/IncrementalSeq.scala b/src/main/scala/leon/utils/IncrementalSeq.scala index 9f66d66a895aeec9f11c64d7a59eef5268b39eb6..4ec9290b5eb5c2672b0f4fae44760081ca14ba80 100644 --- a/src/main/scala/leon/utils/IncrementalSeq.scala +++ b/src/main/scala/leon/utils/IncrementalSeq.scala @@ -3,9 +3,15 @@ package leon.utils import scala.collection.mutable.Stack +import scala.collection.mutable.Builder import scala.collection.mutable.ArrayBuffer +import scala.collection.{Iterable, IterableLike} + +class IncrementalSeq[A] extends IncrementalState + with Iterable[A] + with IterableLike[A, Seq[A]] + with Builder[A, IncrementalSeq[A]] { -class IncrementalSeq[A] extends IncrementalState { private[this] val stack = new Stack[ArrayBuffer[A]]() def clear() : Unit = { @@ -25,11 +31,11 @@ class IncrementalSeq[A] extends IncrementalState { stack.pop() } - def +=(e: A): Unit = { - stack.head += e - } + def iterator = stack.flatten.iterator + def +=(e: A) = { stack.head += e; this } - def toSeq = stack.toSeq.flatten + override def newBuilder = new scala.collection.mutable.ArrayBuffer() + def result = this push() } diff --git a/src/main/scala/leon/utils/IncrementalSet.scala b/src/main/scala/leon/utils/IncrementalSet.scala index 95b473c756cb0a7a51835112c04d52bc4b404a46..b88dcf840534228c9dabba551a271691eb74f356 100644 --- a/src/main/scala/leon/utils/IncrementalSet.scala +++ b/src/main/scala/leon/utils/IncrementalSet.scala @@ -3,11 +3,17 @@ package leon.utils import scala.collection.mutable.{Stack, Set => MSet} +import scala.collection.mutable.Builder +import scala.collection.{Iterable, IterableLike} + +class IncrementalSet[A] extends IncrementalState + with Iterable[A] + with IterableLike[A, Set[A]] + with Builder[A, IncrementalSet[A]] { -class IncrementalSet[A] extends IncrementalState { private[this] val stack = new Stack[MSet[A]]() - def clear(): Unit = { + override def clear(): Unit = { stack.clear() } @@ -24,15 +30,15 @@ class IncrementalSet[A] extends IncrementalState { stack.pop() } - def +=(a: A): Unit = { - stack.head += a - } + def apply(elem: A) = toSet.contains(elem) + def contains(elem: A) = toSet.contains(elem) - def ++=(as: Traversable[A]): Unit = { - stack.head ++= as - } + def iterator = stack.flatten.iterator + def += (elem: A) = { stack.head += elem; this } + def -= (elem: A) = { stack.foreach(_ -= elem); this } - def toSet = stack.toSet.flatten + override def newBuilder = new scala.collection.mutable.SetBuilder(Set.empty[A]) + def result = this push() } diff --git a/src/main/scala/leon/utils/IncrementalState.scala b/src/main/scala/leon/utils/IncrementalState.scala index b84606af2a7f8e0975bc37cc0f5d037e62c4e2a7..32f6b7c2b11ec3405aeeb0a9f81b9bfb08076ad8 100644 --- a/src/main/scala/leon/utils/IncrementalState.scala +++ b/src/main/scala/leon/utils/IncrementalState.scala @@ -4,6 +4,8 @@ trait IncrementalState { def push(): Unit def pop(): Unit + final def pop(lvl: Int): Unit = List.range(0, lvl).foreach(_ => pop()) + def clear(): Unit def reset(): Unit } diff --git a/src/main/scala/leon/utils/PreprocessingPhase.scala b/src/main/scala/leon/utils/PreprocessingPhase.scala index 7c91fa436fd5e804fb35830e113913264e27d235..148dcba64b12eb96949d109ec5672ccaef3ac437 100644 --- a/src/main/scala/leon/utils/PreprocessingPhase.scala +++ b/src/main/scala/leon/utils/PreprocessingPhase.scala @@ -5,6 +5,7 @@ package utils import leon.purescala._ import leon.purescala.Definitions.Program +import leon.purescala.Quantification.CheckForalls import leon.solvers.isabelle.AdaptationPhase import leon.synthesis.{ConvertWithOracle, ConvertHoles} import leon.verification.InjectAsserts diff --git a/src/main/scala/leon/verification/VerificationCondition.scala b/src/main/scala/leon/verification/VerificationCondition.scala index 8ddbf97abfb7c28c2aafaef446c7c0bed422ffe5..33511b23c520f980969934e50a9965a3378816e3 100644 --- a/src/main/scala/leon/verification/VerificationCondition.scala +++ b/src/main/scala/leon/verification/VerificationCondition.scala @@ -88,7 +88,7 @@ sealed abstract class VCStatus(val name: String) { } object VCStatus { - case class Invalid(cex: Map[Identifier, Expr]) extends VCStatus("invalid") + case class Invalid(cex: Model) extends VCStatus("invalid") case object Valid extends VCStatus("valid") case object Unknown extends VCStatus("unknown") case object Timeout extends VCStatus("timeout") diff --git a/src/test/resources/regression/verification/purescala/invalid/HOInvocations2.scala b/src/test/resources/regression/verification/purescala/invalid/HOInvocations2.scala index 5cd0afc42e447404d227dbfa43b3ea3f7f260678..1639def5dba3c6658fd4348dd6d0d96300cd4770 100644 --- a/src/test/resources/regression/verification/purescala/invalid/HOInvocations2.scala +++ b/src/test/resources/regression/verification/purescala/invalid/HOInvocations2.scala @@ -1,6 +1,6 @@ import leon.lang._ -object HOInvocations { +object HOInvocations2 { def switch(x: BigInt, f: (BigInt) => BigInt, g: (BigInt) => BigInt) = if(x > 0) f else g def failling_1(f: (BigInt) => BigInt) = { diff --git a/src/test/resources/regression/verification/purescala/invalid/PositiveMap2.scala b/src/test/resources/regression/verification/purescala/invalid/PositiveMap2.scala index 7b15cb2576e9119d27f77e158fddfee0e6e4bc98..713f901c9484184884e176d1a5ab0a51ced6d10f 100644 --- a/src/test/resources/regression/verification/purescala/invalid/PositiveMap2.scala +++ b/src/test/resources/regression/verification/purescala/invalid/PositiveMap2.scala @@ -1,6 +1,6 @@ import leon.lang._ -object PositiveMap { +object PositiveMap2 { abstract class List case class Cons(head: BigInt, tail: List) extends List diff --git a/src/test/resources/regression/verification/purescala/valid/HOInvocations2.scala b/src/test/resources/regression/verification/purescala/valid/HOInvocations2.scala index ba44b8fdc3e38542dd6e84664a717517a7392f5f..0af2a6a0653ef86ac513ae29623f90ffb570a910 100644 --- a/src/test/resources/regression/verification/purescala/valid/HOInvocations2.scala +++ b/src/test/resources/regression/verification/purescala/valid/HOInvocations2.scala @@ -1,6 +1,6 @@ import leon.lang._ -object HOInvocations { +object HOInvocations2 { def switch(x: Int, f: (Int) => Int, g: (Int) => Int) = if(x > 0) f else g def passing_1(f: (Int) => Int) = { diff --git a/src/test/resources/regression/verification/purescala/valid/PositiveMap2.scala b/src/test/resources/regression/verification/purescala/valid/PositiveMap2.scala index 030e7095954a063b0d9b0e19ed15c0d6958e2304..eb2262f02e3ebfc65dd7a40fa4db4741baafeb34 100644 --- a/src/test/resources/regression/verification/purescala/valid/PositiveMap2.scala +++ b/src/test/resources/regression/verification/purescala/valid/PositiveMap2.scala @@ -1,7 +1,7 @@ import leon.lang._ -object PositiveMap { +object PositiveMap2 { abstract class List case class Cons(head: BigInt, tail: List) extends List diff --git a/src/test/scala/leon/regression/synthesis/StablePrintingSuite.scala b/src/test/scala/leon/regression/synthesis/StablePrintingSuite.scala index 4404b2acf3734ff8664f309a709835564e73ccc5..c7f33afba55f14607b350b631f5dbaf6a028890a 100644 --- a/src/test/scala/leon/regression/synthesis/StablePrintingSuite.scala +++ b/src/test/scala/leon/regression/synthesis/StablePrintingSuite.scala @@ -71,6 +71,7 @@ class StablePrintingSuite extends LeonRegressionSuite { for (e <- reporter.lastErrors) { info(e) } + println(e) info(e.getMessage) fail("Compilation failed") } diff --git a/src/test/scala/leon/regression/verification/VerificationSuite.scala b/src/test/scala/leon/regression/verification/VerificationSuite.scala index 4118162453b9e13ec02020fcc3cb912bca9fb624..4c9602c00aeb363a40e725f55c81270dd9b4cf48 100644 --- a/src/test/scala/leon/regression/verification/VerificationSuite.scala +++ b/src/test/scala/leon/regression/verification/VerificationSuite.scala @@ -45,7 +45,8 @@ trait VerificationSuite extends LeonRegressionSuite { user map { u => (u.id, Program(u :: lib)) } } for ((id, p) <- programs; options <- optionVariants) { - test(f"${nextInt()}%3d: ${id.name} ${options.mkString(" ")}") { + val index = nextInt() + test(f"$index%3d: ${id.name} ${options.mkString(" ")}") { val ctx = createLeonContext(options: _*) if (forError) { intercept[LeonFatalError] {