From ab98b3eebdd413e99e0dd750430145a1e65db8a3 Mon Sep 17 00:00:00 2001
From: Etienne Kneuss <ekneuss@gmail.com>
Date: Tue, 29 Sep 2015 18:00:06 +0200
Subject: [PATCH] Fix model extraction of solvers and test them

---
 src/main/scala/leon/solvers/Solver.scala      |  11 ++
 .../solvers/smtlib/SMTLIBCVC4Target.scala     | 107 +++++++++++-------
 .../leon/solvers/smtlib/SMTLIBTarget.scala    |  40 +++----
 .../leon/solvers/smtlib/SMTLIBZ3Target.scala  |  51 +++++----
 .../integration/solvers/SolversSuite.scala    |  88 ++++++++++++++
 5 files changed, 208 insertions(+), 89 deletions(-)
 create mode 100644 src/test/scala/leon/integration/solvers/SolversSuite.scala

diff --git a/src/main/scala/leon/solvers/Solver.scala b/src/main/scala/leon/solvers/Solver.scala
index 3188031e9..1e20b15fd 100644
--- a/src/main/scala/leon/solvers/Solver.scala
+++ b/src/main/scala/leon/solvers/Solver.scala
@@ -36,6 +36,16 @@ trait AbstractModel[+This <: Model with AbstractModel[This]]
 
   def iterator = mapping.iterator
   def seq = mapping.seq
+
+  def asString(implicit ctx: LeonContext) = {
+    if (mapping.isEmpty) {
+      "Model()"
+    } else {
+      (for ((k,v) <- mapping.toSeq.sortBy(_._1)) yield {
+        f"  ${k.asString}%-20s -> ${v.asString}"
+      }).mkString("Model(\n", ",\n", ")")
+    }
+  }
 }
 
 trait AbstractModelBuilder[+This <: Model with AbstractModel[This]]
@@ -101,6 +111,7 @@ trait Solver extends Interruptible {
     leonContext.reporter.warning(err.getMessage)
     throw err
   }
+
   protected def unsupported(t: Tree, str: String): Nothing = {
     val err = SolverUnsupportedError(t, this, Some(str))
     leonContext.reporter.warning(err.getMessage)
diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala
index 679e42ac7..d5515b42e 100644
--- a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala
+++ b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala
@@ -36,61 +36,82 @@ trait SMTLIBCVC4Target extends SMTLIBTarget {
     }
   }
 
-  override protected def fromSMT(s: Term, tpe: TypeTree)(implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = (s, tpe) match {
-    // EK: This hack is necessary for sygus which does not strictly follow smt-lib for negative literals
-    case (SimpleSymbol(SSymbol(v)), IntegerType) if v.startsWith("-") =>
-      try {
-        InfiniteIntegerLiteral(v.toInt)
-      } catch {
-        case t: Throwable =>
-          super.fromSMT(s, tpe)
-      }
+  override protected def fromSMT(t: Term, otpe: Option[TypeTree] = None)
+                                (implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = {
+    (t, otpe) match {
+      // EK: This hack is necessary for sygus which does not strictly follow smt-lib for negative literals
+      case (SimpleSymbol(SSymbol(v)), Some(IntegerType)) if v.startsWith("-") =>
+        try {
+          InfiniteIntegerLiteral(v.toInt)
+        } catch {
+          case _: Throwable =>
+            super.fromSMT(t, otpe)
+        }
+
+      case (SimpleSymbol(s), Some(tp: TypeParameter)) =>
+        val n = s.name.split("_").toList.last
+        GenericValue(tp, n.toInt)
 
-    case (SimpleSymbol(s), tp: TypeParameter) =>
-      val n = s.name.split("_").toList.last
-      GenericValue(tp, n.toInt)
+      case (QualifiedIdentifier(SMTIdentifier(SSymbol("emptyset"), Seq()), _), Some(SetType(base))) =>
+        FiniteSet(Set(), base)
 
-    case (QualifiedIdentifier(SMTIdentifier(SSymbol("emptyset"), Seq()), _), SetType(base)) =>
-      FiniteSet(Set(), base)
+      case (FunctionApplication(QualifiedIdentifier(SMTIdentifier(SSymbol("const"), _), _), Seq(elem)), Some(tpe)) =>
+        tpe match {
+          case RawArrayType(k, v) =>
+            RawArrayValue(k, Map(), fromSMT(elem, v))
 
-    case (FunctionApplication(SimpleSymbol(SSymbol("__array_store_all__")), Seq(_, elem)), RawArrayType(k,v)) =>
-      RawArrayValue(k, Map(), fromSMT(elem, v))
+          case FunctionType(from, to) =>
+            RawArrayValue(tupleTypeWrap(from), Map(), fromSMT(elem, to))
 
-    case (FunctionApplication(SimpleSymbol(SSymbol("__array_store_all__")), Seq(_, elem)), FunctionType(from,to)) =>
-      RawArrayValue(tupleTypeWrap(from), Map(), fromSMT(elem, to))
+          case MapType(k, v) =>
+            FiniteMap(Nil, k, v)
 
-    case (FunctionApplication(SimpleSymbol(SSymbol("store")), Seq(arr, key, elem)), RawArrayType(k,v)) =>
-      val RawArrayValue(_, elems, base) = fromSMT(arr, tpe)
-      RawArrayValue(k, elems + (fromSMT(key, k) -> fromSMT(elem, v)), base)
+        }
 
-    case (FunctionApplication(SimpleSymbol(SSymbol("store")), Seq(arr, key, elem)), FunctionType(from,to)) =>
-      val RawArrayValue(k, elems, base) = fromSMT(arr, tpe)
-      RawArrayValue(k, elems + (fromSMT(key, k) -> fromSMT(elem, to)), base)
+      case (FunctionApplication(SimpleSymbol(SSymbol("__array_store_all__")), Seq(_, elem)), Some(tpe)) =>
+        tpe match {
+          case RawArrayType(k, v) =>
+            RawArrayValue(k, Map(), fromSMT(elem, v))
 
-    case (FunctionApplication(SimpleSymbol(SSymbol("singleton")), elems), SetType(base)) =>
-      FiniteSet(elems.map(fromSMT(_, base)).toSet, base)
+          case FunctionType(from, to) =>
+            RawArrayValue(tupleTypeWrap(from), Map(), fromSMT(elem, to))
 
-    case (FunctionApplication(SimpleSymbol(SSymbol("insert")), elems), SetType(base)) =>
-      val selems = elems.init.map(fromSMT(_, base))
-      val FiniteSet(se, _) = fromSMT(elems.last, tpe)
-      FiniteSet(se ++ selems, base)
+          case MapType(k, v) =>
+            FiniteMap(Nil, k, v)
 
-    case (FunctionApplication(SimpleSymbol(SSymbol("union")), elems), SetType(base)) =>
-      FiniteSet(elems.flatMap(fromSMT(_, tpe) match {
-        case FiniteSet(elems, _) => elems
-      }).toSet, base)
+        }
 
-    case (FunctionApplication(QualifiedIdentifier(SMTIdentifier(SSymbol("const"), _), _), Seq(elem)), RawArrayType(k, v)) =>
-      RawArrayValue(k, Map(), fromSMT(elem, v))
+      case (FunctionApplication(SimpleSymbol(SSymbol("store")), Seq(arr, key, elem)), Some(tpe)) =>
+        tpe match {
+          case RawArrayType(_, v) =>
+            val RawArrayValue(k, elems, base) = fromSMT(arr, otpe)
+            RawArrayValue(k, elems + (fromSMT(key, k) -> fromSMT(elem, v)), base)
 
-    // FIXME (nicolas)
-    // some versions of CVC4 seem to generate array constants with "as const" notation instead of the __array_store_all__
-    // one I've witnessed up to now. Don't know why this is happening...
-    case (FunctionApplication(QualifiedIdentifier(SMTIdentifier(SSymbol("const"), _), _), Seq(elem)), FunctionType(from, to)) =>
-      RawArrayValue(tupleTypeWrap(from), Map(), fromSMT(elem, to))
+          case FunctionType(_, v) =>
+            val RawArrayValue(k, elems, base) = fromSMT(arr, otpe)
+            RawArrayValue(k, elems + (fromSMT(key, k) -> fromSMT(elem, v)), base)
 
-    case _ =>
-      super.fromSMT(s, tpe)
+          case MapType(k, v) =>
+            val FiniteMap(elems, k, v) = fromSMT(arr, otpe)
+            FiniteMap(elems :+ (fromSMT(key, k) -> fromSMT(elem, v)), k, v)
+        }
+
+      case (FunctionApplication(SimpleSymbol(SSymbol("singleton")), elems), Some(SetType(base))) =>
+        FiniteSet(elems.map(fromSMT(_, base)).toSet, base)
+
+      case (FunctionApplication(SimpleSymbol(SSymbol("insert")), elems), Some(SetType(base))) =>
+        val selems = elems.init.map(fromSMT(_, base))
+        val FiniteSet(se, _) = fromSMT(elems.last, otpe)
+        FiniteSet(se ++ selems, base)
+
+      case (FunctionApplication(SimpleSymbol(SSymbol("union")), elems), Some(SetType(base))) =>
+        FiniteSet(elems.flatMap(fromSMT(_, otpe) match {
+          case FiniteSet(elems, _) => elems
+        }).toSet, base)
+
+      case _ =>
+        super.fromSMT(t, otpe)
+    }
   }
 
   override protected def toSMT(e: Expr)(implicit bindings: Map[Identifier, Term]) = e match {
diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala
index b5b81a3e3..17faa8737 100644
--- a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala
+++ b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala
@@ -44,8 +44,7 @@ trait SMTLIBTarget extends Interruptible {
 
   protected def getNewInterpreter(ctx: LeonContext): ProcessInterpreter
 
-  protected def unsupported(t: Tree, str: String): Nothing;
-
+  protected def unsupported(t: Tree, str: String): Nothing
 
   protected lazy val interpreter = getNewInterpreter(context)
 
@@ -622,15 +621,6 @@ trait SMTLIBTarget extends Interruptible {
   }
 
   /* Translate an SMTLIB term back to a Leon Expr */
-
-  protected def fromSMT(pair: (Term, TypeTree))(implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = {
-    fromSMT(pair._1, Some(pair._2))
-  }
-
-  protected def fromSMT(s: Term, tpe: TypeTree)(implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = {
-    fromSMT(s, Some(tpe))
-  }
-
   protected def fromSMT(t: Term, otpe: Option[TypeTree] = None)
                        (implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = {
 
@@ -740,22 +730,22 @@ trait SMTLIBTarget extends Interruptible {
             LessThan(fromSMT(a, IntegerType), fromSMT(b, IntegerType))
 
           case ("+", args) =>
-            args.map(fromSMT(_, IntegerType)).reduceLeft(plus _)
+            args.map(fromSMT(_, otpe)).reduceLeft(plus _)
 
           case ("-", List(a)) =>
-            UMinus(fromSMT(a, IntegerType))
+            UMinus(fromSMT(a, otpe))
 
           case ("-", List(a, b)) =>
-            Minus(fromSMT(a, IntegerType), fromSMT(b, IntegerType))
+            Minus(fromSMT(a, otpe), fromSMT(b, otpe))
 
           case ("*", args) =>
-            args.map(fromSMT(_, IntegerType)).reduceLeft(times _)
+            args.map(fromSMT(_, otpe)).reduceLeft(times _)
 
           case ("/", List(a, b)) =>
-            Division(fromSMT(a, IntegerType), fromSMT(b, IntegerType))
+            Division(fromSMT(a, otpe), fromSMT(b, otpe))
 
           case ("div", List(a, b)) =>
-            Division(fromSMT(a, IntegerType), fromSMT(b, IntegerType))
+            Division(fromSMT(a, otpe), fromSMT(b, otpe))
 
           case ("not", List(a)) =>
             Not(fromSMT(a, BooleanType))
@@ -774,24 +764,30 @@ trait SMTLIBTarget extends Interruptible {
             reporter.fatalError("Function "+app+" not handled in fromSMT: "+s)
         }
 
+      case (Core.True(), Some(BooleanType))  => BooleanLiteral(true)
+      case (Core.False(), Some(BooleanType)) => BooleanLiteral(false)
+
       case (SimpleSymbol(s), otpe) if lets contains s =>
         fromSMT(lets(s), otpe)
 
       case (SimpleSymbol(s), otpe) =>
         variables.getA(s).map(_.toVariable).getOrElse {
-          reporter.fatalError("Unknown symbol: "+s)
+          throw new Exception()
         }
 
-      case (Core.True(), Some(BooleanType))  => BooleanLiteral(true)
-      case (Core.False(), Some(BooleanType)) => BooleanLiteral(false)
-
       case _ =>
-        reporter.fatalError("Unhandled case in fromSMT: " + t+" (_ :"+otpe+")")
+        reporter.fatalError(s"Unhandled case in fromSMT: $t : ${otpe.map(_.asString(context)).getOrElse("?")} (${t.getClass})")
 
     }
   }
 
+  final protected def fromSMT(pair: (Term, TypeTree))(implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = {
+    fromSMT(pair._1, Some(pair._2))
+  }
 
+  final protected def fromSMT(s: Term, tpe: TypeTree)(implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = {
+    fromSMT(s, Some(tpe))
+  }
 }
 
 // Unique numbers
diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala
index c0c696617..657100814 100644
--- a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala
+++ b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala
@@ -72,31 +72,34 @@ trait SMTLIBZ3Target extends SMTLIBTarget {
     Sort(SMTIdentifier(setSort.get), Seq(declareSort(of)))
   }
 
-  override protected def fromSMT(s: Term, tpe: TypeTree)(implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = (s, tpe) match {
-    case (SimpleSymbol(s), tp: TypeParameter) =>
-      val n = s.name.split("!").toList.last
-      GenericValue(tp, n.toInt)
-
-
-    case (QualifiedIdentifier(ExtendedIdentifier(SSymbol("as-array"), k: SSymbol), _), tpe) =>
-      if (letDefs contains k) {
-        // Need to recover value form function model
-        fromRawArray(extractRawArray(letDefs(k), tpe), tpe)
-      } else {
-        throw LeonFatalError("Array on non-function or unknown symbol "+k)
-      }
-
-    case (FunctionApplication(
-      QualifiedIdentifier(SMTIdentifier(SSymbol("const"), _), Some(ArraysEx.ArraySort(k, v))),
-      Seq(defV)
-    ), tpe) =>
-      val ktpe = sorts.fromB(k)
-      val vtpe = sorts.fromB(v)
-
-      fromRawArray(RawArrayValue(ktpe, Map(), fromSMT(defV, vtpe)), tpe)
+  override protected def fromSMT(t: Term, otpe: Option[TypeTree] = None)
+                                (implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = {
+    (t, otpe) match {
+      case (SimpleSymbol(s), Some(tp: TypeParameter)) =>
+        val n = s.name.split("!").toList.last
+        GenericValue(tp, n.toInt)
+
+
+      case (QualifiedIdentifier(ExtendedIdentifier(SSymbol("as-array"), k: SSymbol), _), Some(tpe)) =>
+        if (letDefs contains k) {
+          // Need to recover value form function model
+          fromRawArray(extractRawArray(letDefs(k), tpe), tpe)
+        } else {
+          throw LeonFatalError("Array on non-function or unknown symbol "+k)
+        }
+
+      case (FunctionApplication(
+        QualifiedIdentifier(SMTIdentifier(SSymbol("const"), _), Some(ArraysEx.ArraySort(k, v))),
+        Seq(defV)
+      ), Some(tpe)) =>
+        val ktpe = sorts.fromB(k)
+        val vtpe = sorts.fromB(v)
+
+        fromRawArray(RawArrayValue(ktpe, Map(), fromSMT(defV, vtpe)), tpe)
 
-    case _ =>
-      super.fromSMT(s, tpe)
+      case _ =>
+        super.fromSMT(t, otpe)
+    }
   }
 
   override protected def toSMT(e: Expr)(implicit bindings: Map[Identifier, Term]): Term = e match {
diff --git a/src/test/scala/leon/integration/solvers/SolversSuite.scala b/src/test/scala/leon/integration/solvers/SolversSuite.scala
new file mode 100644
index 000000000..e09177804
--- /dev/null
+++ b/src/test/scala/leon/integration/solvers/SolversSuite.scala
@@ -0,0 +1,88 @@
+/* Copyright 2009-2015 EPFL, Lausanne */
+
+package leon.integration.solvers
+
+import leon.test._
+import leon.purescala.Common._
+import leon.purescala.Definitions._
+import leon.purescala.ExprOps._
+import leon.purescala.Constructors._
+import leon.purescala.Expressions._
+import leon.purescala.Types._
+import leon.LeonContext
+
+import leon.solvers._
+import leon.solvers.smtlib._
+import leon.solvers.combinators._
+import leon.solvers.z3._
+
+class SolversSuite extends LeonTestSuiteWithProgram {
+
+  val sources = List()
+
+  val getFactories: Seq[(String, (LeonContext, Program) => Solver)] = {
+    (if (SolverFactory.hasNativeZ3) Seq(
+      ("fairz3",   (ctx: LeonContext, pgm: Program) => new FairZ3Solver(ctx, pgm))
+    ) else Nil) ++
+    (if (SolverFactory.hasZ3)       Seq(
+      ("smt-z3",   (ctx: LeonContext, pgm: Program) => new UnrollingSolver(ctx, pgm, new SMTLIBZ3Solver(ctx, pgm)))
+    ) else Nil) ++
+    (if (SolverFactory.hasCVC4)     Seq(
+      ("smt-cvc4", (ctx: LeonContext, pgm: Program) => new UnrollingSolver(ctx, pgm, new SMTLIBCVC4Solver(ctx, pgm)))
+    ) else Nil)
+  }
+
+  // Check that we correctly extract several types from solver models
+  for ((sname, sf) <- getFactories) {
+    test(s"Model Extraction in $sname") { implicit fix =>
+      val ctx = fix._1
+      val pgm = fix._2
+
+      val solver = sf(ctx, pgm)
+
+      val types = Seq(
+        BooleanType,
+        UnitType,
+        CharType,
+        IntegerType,
+        Int32Type,
+        TypeParameter.fresh("T"),
+        SetType(IntegerType),
+        MapType(IntegerType, IntegerType),
+        TupleType(Seq(IntegerType, BooleanType, Int32Type))
+      )
+
+      val vs = types.map(FreshIdentifier("v", _).toVariable)
+
+
+      // We need to make sure models are not co-finite
+      val cnstr = andJoin(vs.map(v => v.getType match {
+        case UnitType =>
+          Equals(v, simplestValue(v.getType))
+        case SetType(base) =>
+          Not(ElementOfSet(simplestValue(base), v))
+        case MapType(from, to) =>
+          Not(Equals(MapApply(v, simplestValue(from)), simplestValue(to)))
+        case _ =>
+          not(Equals(v, simplestValue(v.getType)))
+      }))
+
+      solver.assertCnstr(cnstr)
+
+      solver.check match {
+        case Some(true) =>
+          val model = solver.getModel
+          for (v <- vs) {
+            if (model.isDefinedAt(v.id)) {
+              assert(model(v.id).getType === v.getType, "Extracting value of type "+v.getType)
+            } else {
+              fail("Model does not contain "+v.id+" of type "+v.getType)
+            }
+          }
+        case _ =>
+          fail("Constraint "+cnstr.asString+" is unsat!?")
+      }
+
+    }
+  }
+}
-- 
GitLab