From ab98b3eebdd413e99e0dd750430145a1e65db8a3 Mon Sep 17 00:00:00 2001 From: Etienne Kneuss <ekneuss@gmail.com> Date: Tue, 29 Sep 2015 18:00:06 +0200 Subject: [PATCH] Fix model extraction of solvers and test them --- src/main/scala/leon/solvers/Solver.scala | 11 ++ .../solvers/smtlib/SMTLIBCVC4Target.scala | 107 +++++++++++------- .../leon/solvers/smtlib/SMTLIBTarget.scala | 40 +++---- .../leon/solvers/smtlib/SMTLIBZ3Target.scala | 51 +++++---- .../integration/solvers/SolversSuite.scala | 88 ++++++++++++++ 5 files changed, 208 insertions(+), 89 deletions(-) create mode 100644 src/test/scala/leon/integration/solvers/SolversSuite.scala diff --git a/src/main/scala/leon/solvers/Solver.scala b/src/main/scala/leon/solvers/Solver.scala index 3188031e9..1e20b15fd 100644 --- a/src/main/scala/leon/solvers/Solver.scala +++ b/src/main/scala/leon/solvers/Solver.scala @@ -36,6 +36,16 @@ trait AbstractModel[+This <: Model with AbstractModel[This]] def iterator = mapping.iterator def seq = mapping.seq + + def asString(implicit ctx: LeonContext) = { + if (mapping.isEmpty) { + "Model()" + } else { + (for ((k,v) <- mapping.toSeq.sortBy(_._1)) yield { + f" ${k.asString}%-20s -> ${v.asString}" + }).mkString("Model(\n", ",\n", ")") + } + } } trait AbstractModelBuilder[+This <: Model with AbstractModel[This]] @@ -101,6 +111,7 @@ trait Solver extends Interruptible { leonContext.reporter.warning(err.getMessage) throw err } + protected def unsupported(t: Tree, str: String): Nothing = { val err = SolverUnsupportedError(t, this, Some(str)) leonContext.reporter.warning(err.getMessage) diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala index 679e42ac7..d5515b42e 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala @@ -36,61 +36,82 @@ trait SMTLIBCVC4Target extends SMTLIBTarget { } } - override protected def fromSMT(s: Term, tpe: TypeTree)(implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = (s, tpe) match { - // EK: This hack is necessary for sygus which does not strictly follow smt-lib for negative literals - case (SimpleSymbol(SSymbol(v)), IntegerType) if v.startsWith("-") => - try { - InfiniteIntegerLiteral(v.toInt) - } catch { - case t: Throwable => - super.fromSMT(s, tpe) - } + override protected def fromSMT(t: Term, otpe: Option[TypeTree] = None) + (implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = { + (t, otpe) match { + // EK: This hack is necessary for sygus which does not strictly follow smt-lib for negative literals + case (SimpleSymbol(SSymbol(v)), Some(IntegerType)) if v.startsWith("-") => + try { + InfiniteIntegerLiteral(v.toInt) + } catch { + case _: Throwable => + super.fromSMT(t, otpe) + } + + case (SimpleSymbol(s), Some(tp: TypeParameter)) => + val n = s.name.split("_").toList.last + GenericValue(tp, n.toInt) - case (SimpleSymbol(s), tp: TypeParameter) => - val n = s.name.split("_").toList.last - GenericValue(tp, n.toInt) + case (QualifiedIdentifier(SMTIdentifier(SSymbol("emptyset"), Seq()), _), Some(SetType(base))) => + FiniteSet(Set(), base) - case (QualifiedIdentifier(SMTIdentifier(SSymbol("emptyset"), Seq()), _), SetType(base)) => - FiniteSet(Set(), base) + case (FunctionApplication(QualifiedIdentifier(SMTIdentifier(SSymbol("const"), _), _), Seq(elem)), Some(tpe)) => + tpe match { + case RawArrayType(k, v) => + RawArrayValue(k, Map(), fromSMT(elem, v)) - case (FunctionApplication(SimpleSymbol(SSymbol("__array_store_all__")), Seq(_, elem)), RawArrayType(k,v)) => - RawArrayValue(k, Map(), fromSMT(elem, v)) + case FunctionType(from, to) => + RawArrayValue(tupleTypeWrap(from), Map(), fromSMT(elem, to)) - case (FunctionApplication(SimpleSymbol(SSymbol("__array_store_all__")), Seq(_, elem)), FunctionType(from,to)) => - RawArrayValue(tupleTypeWrap(from), Map(), fromSMT(elem, to)) + case MapType(k, v) => + FiniteMap(Nil, k, v) - case (FunctionApplication(SimpleSymbol(SSymbol("store")), Seq(arr, key, elem)), RawArrayType(k,v)) => - val RawArrayValue(_, elems, base) = fromSMT(arr, tpe) - RawArrayValue(k, elems + (fromSMT(key, k) -> fromSMT(elem, v)), base) + } - case (FunctionApplication(SimpleSymbol(SSymbol("store")), Seq(arr, key, elem)), FunctionType(from,to)) => - val RawArrayValue(k, elems, base) = fromSMT(arr, tpe) - RawArrayValue(k, elems + (fromSMT(key, k) -> fromSMT(elem, to)), base) + case (FunctionApplication(SimpleSymbol(SSymbol("__array_store_all__")), Seq(_, elem)), Some(tpe)) => + tpe match { + case RawArrayType(k, v) => + RawArrayValue(k, Map(), fromSMT(elem, v)) - case (FunctionApplication(SimpleSymbol(SSymbol("singleton")), elems), SetType(base)) => - FiniteSet(elems.map(fromSMT(_, base)).toSet, base) + case FunctionType(from, to) => + RawArrayValue(tupleTypeWrap(from), Map(), fromSMT(elem, to)) - case (FunctionApplication(SimpleSymbol(SSymbol("insert")), elems), SetType(base)) => - val selems = elems.init.map(fromSMT(_, base)) - val FiniteSet(se, _) = fromSMT(elems.last, tpe) - FiniteSet(se ++ selems, base) + case MapType(k, v) => + FiniteMap(Nil, k, v) - case (FunctionApplication(SimpleSymbol(SSymbol("union")), elems), SetType(base)) => - FiniteSet(elems.flatMap(fromSMT(_, tpe) match { - case FiniteSet(elems, _) => elems - }).toSet, base) + } - case (FunctionApplication(QualifiedIdentifier(SMTIdentifier(SSymbol("const"), _), _), Seq(elem)), RawArrayType(k, v)) => - RawArrayValue(k, Map(), fromSMT(elem, v)) + case (FunctionApplication(SimpleSymbol(SSymbol("store")), Seq(arr, key, elem)), Some(tpe)) => + tpe match { + case RawArrayType(_, v) => + val RawArrayValue(k, elems, base) = fromSMT(arr, otpe) + RawArrayValue(k, elems + (fromSMT(key, k) -> fromSMT(elem, v)), base) - // FIXME (nicolas) - // some versions of CVC4 seem to generate array constants with "as const" notation instead of the __array_store_all__ - // one I've witnessed up to now. Don't know why this is happening... - case (FunctionApplication(QualifiedIdentifier(SMTIdentifier(SSymbol("const"), _), _), Seq(elem)), FunctionType(from, to)) => - RawArrayValue(tupleTypeWrap(from), Map(), fromSMT(elem, to)) + case FunctionType(_, v) => + val RawArrayValue(k, elems, base) = fromSMT(arr, otpe) + RawArrayValue(k, elems + (fromSMT(key, k) -> fromSMT(elem, v)), base) - case _ => - super.fromSMT(s, tpe) + case MapType(k, v) => + val FiniteMap(elems, k, v) = fromSMT(arr, otpe) + FiniteMap(elems :+ (fromSMT(key, k) -> fromSMT(elem, v)), k, v) + } + + case (FunctionApplication(SimpleSymbol(SSymbol("singleton")), elems), Some(SetType(base))) => + FiniteSet(elems.map(fromSMT(_, base)).toSet, base) + + case (FunctionApplication(SimpleSymbol(SSymbol("insert")), elems), Some(SetType(base))) => + val selems = elems.init.map(fromSMT(_, base)) + val FiniteSet(se, _) = fromSMT(elems.last, otpe) + FiniteSet(se ++ selems, base) + + case (FunctionApplication(SimpleSymbol(SSymbol("union")), elems), Some(SetType(base))) => + FiniteSet(elems.flatMap(fromSMT(_, otpe) match { + case FiniteSet(elems, _) => elems + }).toSet, base) + + case _ => + super.fromSMT(t, otpe) + } } override protected def toSMT(e: Expr)(implicit bindings: Map[Identifier, Term]) = e match { diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala index b5b81a3e3..17faa8737 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala @@ -44,8 +44,7 @@ trait SMTLIBTarget extends Interruptible { protected def getNewInterpreter(ctx: LeonContext): ProcessInterpreter - protected def unsupported(t: Tree, str: String): Nothing; - + protected def unsupported(t: Tree, str: String): Nothing protected lazy val interpreter = getNewInterpreter(context) @@ -622,15 +621,6 @@ trait SMTLIBTarget extends Interruptible { } /* Translate an SMTLIB term back to a Leon Expr */ - - protected def fromSMT(pair: (Term, TypeTree))(implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = { - fromSMT(pair._1, Some(pair._2)) - } - - protected def fromSMT(s: Term, tpe: TypeTree)(implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = { - fromSMT(s, Some(tpe)) - } - protected def fromSMT(t: Term, otpe: Option[TypeTree] = None) (implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = { @@ -740,22 +730,22 @@ trait SMTLIBTarget extends Interruptible { LessThan(fromSMT(a, IntegerType), fromSMT(b, IntegerType)) case ("+", args) => - args.map(fromSMT(_, IntegerType)).reduceLeft(plus _) + args.map(fromSMT(_, otpe)).reduceLeft(plus _) case ("-", List(a)) => - UMinus(fromSMT(a, IntegerType)) + UMinus(fromSMT(a, otpe)) case ("-", List(a, b)) => - Minus(fromSMT(a, IntegerType), fromSMT(b, IntegerType)) + Minus(fromSMT(a, otpe), fromSMT(b, otpe)) case ("*", args) => - args.map(fromSMT(_, IntegerType)).reduceLeft(times _) + args.map(fromSMT(_, otpe)).reduceLeft(times _) case ("/", List(a, b)) => - Division(fromSMT(a, IntegerType), fromSMT(b, IntegerType)) + Division(fromSMT(a, otpe), fromSMT(b, otpe)) case ("div", List(a, b)) => - Division(fromSMT(a, IntegerType), fromSMT(b, IntegerType)) + Division(fromSMT(a, otpe), fromSMT(b, otpe)) case ("not", List(a)) => Not(fromSMT(a, BooleanType)) @@ -774,24 +764,30 @@ trait SMTLIBTarget extends Interruptible { reporter.fatalError("Function "+app+" not handled in fromSMT: "+s) } + case (Core.True(), Some(BooleanType)) => BooleanLiteral(true) + case (Core.False(), Some(BooleanType)) => BooleanLiteral(false) + case (SimpleSymbol(s), otpe) if lets contains s => fromSMT(lets(s), otpe) case (SimpleSymbol(s), otpe) => variables.getA(s).map(_.toVariable).getOrElse { - reporter.fatalError("Unknown symbol: "+s) + throw new Exception() } - case (Core.True(), Some(BooleanType)) => BooleanLiteral(true) - case (Core.False(), Some(BooleanType)) => BooleanLiteral(false) - case _ => - reporter.fatalError("Unhandled case in fromSMT: " + t+" (_ :"+otpe+")") + reporter.fatalError(s"Unhandled case in fromSMT: $t : ${otpe.map(_.asString(context)).getOrElse("?")} (${t.getClass})") } } + final protected def fromSMT(pair: (Term, TypeTree))(implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = { + fromSMT(pair._1, Some(pair._2)) + } + final protected def fromSMT(s: Term, tpe: TypeTree)(implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = { + fromSMT(s, Some(tpe)) + } } // Unique numbers diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala index c0c696617..657100814 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala @@ -72,31 +72,34 @@ trait SMTLIBZ3Target extends SMTLIBTarget { Sort(SMTIdentifier(setSort.get), Seq(declareSort(of))) } - override protected def fromSMT(s: Term, tpe: TypeTree)(implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = (s, tpe) match { - case (SimpleSymbol(s), tp: TypeParameter) => - val n = s.name.split("!").toList.last - GenericValue(tp, n.toInt) - - - case (QualifiedIdentifier(ExtendedIdentifier(SSymbol("as-array"), k: SSymbol), _), tpe) => - if (letDefs contains k) { - // Need to recover value form function model - fromRawArray(extractRawArray(letDefs(k), tpe), tpe) - } else { - throw LeonFatalError("Array on non-function or unknown symbol "+k) - } - - case (FunctionApplication( - QualifiedIdentifier(SMTIdentifier(SSymbol("const"), _), Some(ArraysEx.ArraySort(k, v))), - Seq(defV) - ), tpe) => - val ktpe = sorts.fromB(k) - val vtpe = sorts.fromB(v) - - fromRawArray(RawArrayValue(ktpe, Map(), fromSMT(defV, vtpe)), tpe) + override protected def fromSMT(t: Term, otpe: Option[TypeTree] = None) + (implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = { + (t, otpe) match { + case (SimpleSymbol(s), Some(tp: TypeParameter)) => + val n = s.name.split("!").toList.last + GenericValue(tp, n.toInt) + + + case (QualifiedIdentifier(ExtendedIdentifier(SSymbol("as-array"), k: SSymbol), _), Some(tpe)) => + if (letDefs contains k) { + // Need to recover value form function model + fromRawArray(extractRawArray(letDefs(k), tpe), tpe) + } else { + throw LeonFatalError("Array on non-function or unknown symbol "+k) + } + + case (FunctionApplication( + QualifiedIdentifier(SMTIdentifier(SSymbol("const"), _), Some(ArraysEx.ArraySort(k, v))), + Seq(defV) + ), Some(tpe)) => + val ktpe = sorts.fromB(k) + val vtpe = sorts.fromB(v) + + fromRawArray(RawArrayValue(ktpe, Map(), fromSMT(defV, vtpe)), tpe) - case _ => - super.fromSMT(s, tpe) + case _ => + super.fromSMT(t, otpe) + } } override protected def toSMT(e: Expr)(implicit bindings: Map[Identifier, Term]): Term = e match { diff --git a/src/test/scala/leon/integration/solvers/SolversSuite.scala b/src/test/scala/leon/integration/solvers/SolversSuite.scala new file mode 100644 index 000000000..e09177804 --- /dev/null +++ b/src/test/scala/leon/integration/solvers/SolversSuite.scala @@ -0,0 +1,88 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon.integration.solvers + +import leon.test._ +import leon.purescala.Common._ +import leon.purescala.Definitions._ +import leon.purescala.ExprOps._ +import leon.purescala.Constructors._ +import leon.purescala.Expressions._ +import leon.purescala.Types._ +import leon.LeonContext + +import leon.solvers._ +import leon.solvers.smtlib._ +import leon.solvers.combinators._ +import leon.solvers.z3._ + +class SolversSuite extends LeonTestSuiteWithProgram { + + val sources = List() + + val getFactories: Seq[(String, (LeonContext, Program) => Solver)] = { + (if (SolverFactory.hasNativeZ3) Seq( + ("fairz3", (ctx: LeonContext, pgm: Program) => new FairZ3Solver(ctx, pgm)) + ) else Nil) ++ + (if (SolverFactory.hasZ3) Seq( + ("smt-z3", (ctx: LeonContext, pgm: Program) => new UnrollingSolver(ctx, pgm, new SMTLIBZ3Solver(ctx, pgm))) + ) else Nil) ++ + (if (SolverFactory.hasCVC4) Seq( + ("smt-cvc4", (ctx: LeonContext, pgm: Program) => new UnrollingSolver(ctx, pgm, new SMTLIBCVC4Solver(ctx, pgm))) + ) else Nil) + } + + // Check that we correctly extract several types from solver models + for ((sname, sf) <- getFactories) { + test(s"Model Extraction in $sname") { implicit fix => + val ctx = fix._1 + val pgm = fix._2 + + val solver = sf(ctx, pgm) + + val types = Seq( + BooleanType, + UnitType, + CharType, + IntegerType, + Int32Type, + TypeParameter.fresh("T"), + SetType(IntegerType), + MapType(IntegerType, IntegerType), + TupleType(Seq(IntegerType, BooleanType, Int32Type)) + ) + + val vs = types.map(FreshIdentifier("v", _).toVariable) + + + // We need to make sure models are not co-finite + val cnstr = andJoin(vs.map(v => v.getType match { + case UnitType => + Equals(v, simplestValue(v.getType)) + case SetType(base) => + Not(ElementOfSet(simplestValue(base), v)) + case MapType(from, to) => + Not(Equals(MapApply(v, simplestValue(from)), simplestValue(to))) + case _ => + not(Equals(v, simplestValue(v.getType))) + })) + + solver.assertCnstr(cnstr) + + solver.check match { + case Some(true) => + val model = solver.getModel + for (v <- vs) { + if (model.isDefinedAt(v.id)) { + assert(model(v.id).getType === v.getType, "Extracting value of type "+v.getType) + } else { + fail("Model does not contain "+v.id+" of type "+v.getType) + } + } + case _ => + fail("Constraint "+cnstr.asString+" is unsat!?") + } + + } + } +} -- GitLab