From 8c402795ed49a7d2740c7a394644818f993d0bce Mon Sep 17 00:00:00 2001
From: Manos Koukoutos <emmanouil.koukoutos@epfl.ch>
Date: Thu, 26 Mar 2015 15:02:10 +0100
Subject: [PATCH] Upgrade to SMTLIB 2.5

---
 project/Build.scala                           |   2 +-
 .../solvers/smtlib/SMTLIBCVC4Target.scala     |  31 +++-
 .../leon/solvers/smtlib/SMTLIBTarget.scala    | 146 ++++++++++++------
 .../leon/solvers/smtlib/SMTLIBZ3Target.scala  |  19 ++-
 src/main/scala/leon/utils/Library.scala       |   4 +
 5 files changed, 134 insertions(+), 68 deletions(-)

diff --git a/project/Build.scala b/project/Build.scala
index ea9934a42..be7d12251 100644
--- a/project/Build.scala
+++ b/project/Build.scala
@@ -80,6 +80,6 @@ object Leon extends Build {
     def project(repo: String, version: String) = RootProject(uri(s"${repo}#${version}"))
 
     lazy val bonsai      = project("git://github.com/colder/bonsai.git",     "0fec9f97f4220fa94b1f3f305f2e8b76a3cd1539")
-    lazy val scalaSmtLib = project("git://github.com/regb/scala-smtlib.git", "a7e4c4c1963cbf202c4cc12da0751ec56a498398")
+    lazy val scalaSmtLib = project("git://github.com/regb/scala-smtlib.git", "3dcd7ba20e5dcddd777c57466b60f7c4d86a3ff2")
   }
 }
diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala
index 2d88f4f34..8074d9f9a 100644
--- a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala
+++ b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala
@@ -6,14 +6,15 @@ package smtlib
 
 import purescala._
 import Common._
-import Expressions._
+import Expressions.{Assert => _, _}
 import Extractors._
 import Constructors._
 import Types._
 
-import _root_.smtlib.parser.Terms.{Identifier => SMTIdentifier, _}
+import _root_.smtlib.parser.Terms.{Identifier => SMTIdentifier, ForAll => SMTForall, _}
 import _root_.smtlib.parser.Commands._
 import _root_.smtlib.interpreters.CVC4Interpreter
+import _root_.smtlib.theories._
 
 trait SMTLIBCVC4Target extends SMTLIBTarget {
   this: SMTLIBSolver =>
@@ -89,12 +90,6 @@ trait SMTLIBCVC4Target extends SMTLIBTarget {
       super[SMTLIBTarget].fromSMT(s, tpe)
   }
 
-  def encodeMapType(tpe: TypeTree): TypeTree = tpe match {
-    case MapType(from, to) =>
-      tupleTypeWrap(Seq(SetType(from), RawArrayType(from, to)))
-    case _ => sys.error("Woot")
-  }
-
   override def toSMT(e: Expr)(implicit bindings: Map[Identifier, Term]) = e match {
     case a @ FiniteArray(elems, default, size) =>
       val tpe @ ArrayType(base) = normalizeType(a.getType)
@@ -108,6 +103,26 @@ trait SMTLIBCVC4Target extends SMTLIBTarget {
 
       FunctionApplication(constructors.toB(tpe), Seq(toSMT(size), ar))
 
+    case fm @ FiniteMap(elems) =>
+      import OptionManager._
+      val mt @ MapType(from, to) = fm.getType
+      val ms = declareSort(mt)
+
+      var m: Term = declareVariable(FreshIdentifier("mapconst", RawArrayType(from, leonOptionType(to))))
+
+      sendCommand(Assert(SMTForall(
+        SortedVar(SSymbol("mapelem"), declareSort(from)), Seq(),
+        Core.Equals(
+          ArraysEx.Select(m, SSymbol("mapelem")),
+          toSMT(mkLeonNone(to))
+        )
+      )))
+
+      for ((k, v) <- elems) {
+        m = FunctionApplication(SSymbol("store"), Seq(m, toSMT(k), toSMT(mkLeonSome(v))))
+      }
+
+      m
     /**
      * ===== Set operations =====
      */
diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala
index 3566b165b..d3700d15e 100644
--- a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala
+++ b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala
@@ -14,7 +14,7 @@ import utils.IncrementalBijection
 
 import _root_.smtlib.common._
 import _root_.smtlib.printer.{RecursivePrinter => SMTPrinter}
-import _root_.smtlib.parser.Commands.{Constructor => SMTConstructor, _}
+import _root_.smtlib.parser.Commands.{Constructor => SMTConstructor, FunDef => _, _}
 import _root_.smtlib.parser.Terms.{Identifier => SMTIdentifier, Let => SMTLet, _}
 import _root_.smtlib.parser.CommandsResponses.{Error => ErrorResponse, _}
 import _root_.smtlib.theories._
@@ -52,9 +52,58 @@ trait SMTLIBTarget {
   val selectors    = new IncrementalBijection[(TypeTree, Int), SSymbol]()
   val testers      = new IncrementalBijection[TypeTree, SSymbol]()
   val variables    = new IncrementalBijection[Identifier, SSymbol]()
+  val classes      = new IncrementalBijection[CaseClassDef, SSymbol]()
   val sorts        = new IncrementalBijection[TypeTree, Sort]()
   val functions    = new IncrementalBijection[TypedFunDef, SSymbol]()
 
+  protected object OptionManager {
+    lazy val leonOption = program.library.Option.get
+    lazy val leonSome = program.library.Some.get
+    lazy val leonNone = program.library.None.get
+    def leonOptionType(tp: TypeTree) = AbstractClassType(leonOption, Seq(tp))
+
+    def mkLeonSome(e: Expr) = CaseClass(CaseClassType(leonSome, Seq(e.getType)), Seq(e))
+    def mkLeonNone(tp: TypeTree) = CaseClass(CaseClassType(leonNone, Seq(tp)), Seq())
+
+    def someTester(tp: TypeTree): SSymbol = {
+      val someTp = CaseClassType(leonSome, Seq(tp))
+      testers.getB(someTp) match {
+        case Some(s) => s
+        case None =>
+          declareOptionSort(tp)
+          someTester(tp)
+      }
+    }
+    def someConstructor(tp: TypeTree): SSymbol = {
+      val someTp = CaseClassType(leonSome, Seq(tp))
+      constructors.getB(someTp) match {
+        case Some(s) => s
+        case None =>
+          declareOptionSort(tp)
+          someConstructor(tp)
+      }
+    }
+    def someSelector(tp: TypeTree): SSymbol = {
+      val someTp = CaseClassType(leonSome, Seq(tp))
+      selectors.getB(someTp,0) match {
+        case Some(s) => s
+        case None =>
+          declareOptionSort(tp)
+          someSelector(tp)
+      }
+    }
+
+    def inlinedOptionGet(t : Term, tp: TypeTree): Term = {
+      FunctionApplication(SSymbol("ite"), Seq(
+          FunctionApplication(someTester(tp), Seq(t)),
+          FunctionApplication(someSelector(tp), Seq(t)),
+          declareVariable(FreshIdentifier("error_value", tp))
+        )
+      )
+    }
+
+  }
+
   def normalizeType(t: TypeTree): TypeTree = t match {
     case ct: ClassType if ct.parent.isDefined => ct.parent.get
     case tt: TupleType => tupleTypeWrap(tt.bases.map(normalizeType))
@@ -83,6 +132,15 @@ trait SMTLIBTarget {
     case ft @ FunctionType(from, to) =>
       finiteLambda(r.default, r.elems.toSeq, from)
 
+    case MapType(from, to) =>
+      // We expect a RawArrayValue with keys in from and values in Option[to],
+      // with default value == None
+      require(from == r.keyTpe && r.default == OptionManager.mkLeonNone(to))
+      val elems = r.elems.mapValues {
+        case CaseClass(leonSome, Seq(x)) => x
+      }.toSeq
+      finiteMap(elems, from, to)
+
     case _ =>
       unsupported("Unable to extract from raw array for "+tpe)
   }
@@ -122,55 +180,26 @@ trait SMTLIBTarget {
     }
   }
 
-  var mapSort: Option[SSymbol] = None
-  var optionSort: Option[SSymbol] = None
-
   def declareOptionSort(of: TypeTree): Sort = {
-    optionSort match {
-      case None =>
-        val t      = SSymbol("T")
-
-        val s      = SSymbol("Option")
-        val some   = SSymbol("Some")
-        val some_v = SSymbol("Some_v")
-        val none   = SSymbol("None")
-
-        val caseSome = SList(some, SList(some_v, t))
-        val caseNone = SList(none)
-
-        val cmd = NonStandardCommand(SList(SSymbol("declare-datatypes"), SList(t), SList(SList(s, caseSome, caseNone))))
-        sendCommand(cmd)
-
-        optionSort = Some(s)
-      case _ =>
-    }
-
-    Sort(SMTIdentifier(optionSort.get), Seq(declareSort(of)))
+    declareSort(OptionManager.leonOptionType(of))
   }
 
   def declareMapSort(from: TypeTree, to: TypeTree): Sort = {
-    mapSort match {
-      case None =>
-        val m = SSymbol("Map")
-        val a = SSymbol("A")
-        val b = SSymbol("B")
-        mapSort = Some(m)
+    sorts.cachedB(MapType(from, to)) {
+        val m = freshSym("Map")
 
-        val optSort = declareOptionSort(to)
+        val toSort = declareOptionSort(to)
+        val fromSort = declareSort(from)
 
         val arraySort = Sort(SMTIdentifier(SSymbol("Array")),
-                             Seq(Sort(SMTIdentifier(a)), optSort))
+                             Seq(fromSort, toSort))
+        val cmd = DefineSort(m, Seq(), arraySort)
 
-        val cmd = DefineSort(m, Seq(a, b), arraySort)
         sendCommand(cmd)
-      case _ =>
+        Sort(SMTIdentifier(m), Seq())
     }
-
-    Sort(SMTIdentifier(mapSort.get), Seq(declareSort(from), declareSort(to)))
   }
 
-
-
   def freshSym(id: Identifier): SSymbol = freshSym(id.name)
   def freshSym(name: String): SSymbol = id2sym(FreshIdentifier(name))
 
@@ -315,7 +344,11 @@ trait SMTLIBTarget {
         FreshIdentifier(tfd.id.name)
       }
       val s = id2sym(id)
-      sendCommand(DeclareFun(s, tfd.params.map(p => declareSort(p.getType)), declareSort(tfd.returnType)))
+      sendCommand(DeclareFun(
+        s,
+        tfd.params.map( (p: ValDef) => declareSort(p.getType)),
+        declareSort(tfd.returnType)
+      ))
       s
     }
   }
@@ -348,8 +381,7 @@ trait SMTLIBTarget {
         )
 
       case er @ Error(tpe, _) =>
-        val s = declareVariable(FreshIdentifier("error_value", tpe))
-        s
+        declareVariable(FreshIdentifier("error_value", tpe))
 
       case s @ CaseClassSelector(cct, e, id) =>
         declareSort(cct)
@@ -411,26 +443,38 @@ trait SMTLIBTarget {
        * ===== Map operations =====
        */
       case m @ FiniteMap(elems) =>
+        import OptionManager._
         val mt @ MapType(_, to) = m.getType
         val ms = declareSort(mt)
 
-        val opt = declareOptionSort(to)
-
-        var res: Term = FunctionApplication(QualifiedIdentifier(SMTIdentifier(SSymbol("const")), Some(ms)), List(QualifiedIdentifier(SMTIdentifier(SSymbol("None")), Some(opt))))
+        var res: Term = FunctionApplication(
+          QualifiedIdentifier(SMTIdentifier(SSymbol("const")), Some(ms)),
+          List(toSMT(mkLeonNone(to)))
+        )
         for ((k, v) <- elems) {
-          res = ArraysEx.Store(res, toSMT(k), FunctionApplication(SSymbol("Some"), List(toSMT(v))))
+          res = ArraysEx.Store(res, toSMT(k), toSMT(mkLeonSome(v)))
         }
 
         res
 
       case MapGet(m, k) =>
-        declareSort(m.getType)
-        FunctionApplication(SSymbol("Some_v"), List(ArraysEx.Select(toSMT(m), toSMT(k))))
+        import OptionManager._
+        val mt@MapType(_, vt) = m.getType
+        declareSort(mt)
+        // m(k) becomes
+        // (Option$get (select m k))
+        inlinedOptionGet(ArraysEx.Select(toSMT(m), toSMT(k)), vt)
 
       case MapIsDefinedAt(m, k) =>
-        declareSort(m.getType)
-        FunctionApplication(SSymbol("is-Some"), List(ArraysEx.Select(toSMT(m), toSMT(k))))
-
+        import OptionManager._
+        val mt@MapType(_, vt) = m.getType
+        declareSort(mt)
+        // m.isDefinedAt(k) becomes
+        // (Option$isDefined (select m k))
+        FunctionApplication(
+          someTester(vt),
+          Seq(ArraysEx.Select(toSMT(m), toSMT(k)))
+        )
       /**
        * ===== Everything else =====
        */
@@ -522,7 +566,7 @@ trait SMTLIBTarget {
     case (Core.True(), BooleanType)  => BooleanLiteral(true)
     case (Core.False(), BooleanType)  => BooleanLiteral(false)
 
-    case (FixedSizeBitVectors.BitVectorConstant(n, 32), Int32Type) => IntLiteral(n.toInt)
+    case (FixedSizeBitVectors.BitVectorConstant(n, b), Int32Type) if b == BigInt(32) => IntLiteral(n.toInt)
     case (SHexadecimal(hexa), Int32Type) => IntLiteral(hexa.toInt)
 
     case (SimpleSymbol(s), _: ClassType) if constructors.containsB(s) =>
diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala
index 24e435e8d..1aed748ea 100644
--- a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala
+++ b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala
@@ -12,7 +12,7 @@ import Types._
 import ExprOps.simplestValue
 
 import _root_.smtlib.parser.Terms.{Identifier => SMTIdentifier, _}
-import _root_.smtlib.parser.Commands.{DefineSort, GetModel, DefineFun}
+import _root_.smtlib.parser.Commands.{FunDef => SMTFunDef, _}
 import _root_.smtlib.interpreters.Z3Interpreter
 import _root_.smtlib.parser.CommandsResponses.GetModelResponseSuccess
 import _root_.smtlib.theories.Core.{Equals => SMTEquals, _}
@@ -73,7 +73,10 @@ trait SMTLIBZ3Target extends SMTLIBTarget {
         unsupported(" as-array on non-function or unknown symbol "+k)
       }
 
-    case (FunctionApplication(QualifiedIdentifier(SimpleSymbol(SSymbol("const")), Some(ArraysEx.ArraySort(k, v))), Seq(defV)), tpe) =>
+    case (FunctionApplication(
+      QualifiedIdentifier(SMTIdentifier(SSymbol("const"), _), Some(ArraysEx.ArraySort(k, v))),
+      Seq(defV)
+    ), tpe) =>
       val ktpe = sorts.fromB(k)
       val vtpe = sorts.fromB(v)
 
@@ -141,8 +144,7 @@ trait SMTLIBZ3Target extends SMTLIBTarget {
   }
 
   def extractRawArray(s: DefineFun)(implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): RawArrayValue = s match {
-    case DefineFun(a, List(SortedVar(arg, akind)), rkind, body) =>
-
+    case DefineFun(SMTFunDef(a, Seq(SortedVar(arg, akind)), rkind, body)) =>
       val argTpe = sorts.toA(akind)
       val retTpe = sorts.toA(rkind)
 
@@ -178,7 +180,7 @@ trait SMTLIBZ3Target extends SMTLIBTarget {
 
     // First pass to gather functions (arrays defs)
     for (me <- smodel) me match {
-      case me @ DefineFun(a, args, _, _) if args.nonEmpty =>
+      case me @ DefineFun(SMTFunDef(a, args, _, _)) if args.nonEmpty =>
         modelFunDefs += a -> me
       case _ =>
     }
@@ -186,7 +188,7 @@ trait SMTLIBZ3Target extends SMTLIBTarget {
     var model = Map[Identifier, Expr]()
 
     for (me <- smodel) me match {
-      case DefineFun(s, args, kind, e) =>
+      case DefineFun(SMTFunDef(s, args, kind, e)) =>
         if(args.isEmpty) {
           variables.getA(s) match {
             case Some(id) =>
@@ -205,8 +207,9 @@ trait SMTLIBZ3Target extends SMTLIBTarget {
   object ArrayMap {
     def apply(op: SSymbol, arrs: Term*) = {
       FunctionApplication(
-        QualifiedIdentifier(SMTIdentifier(SSymbol("(_ map "+op.name+")"))), //hack to get around Z3 syntax
-        arrs)
+        QualifiedIdentifier(SMTIdentifier(SSymbol("map"), List(op))),
+        arrs
+      )
     }
   }
 
diff --git a/src/main/scala/leon/utils/Library.scala b/src/main/scala/leon/utils/Library.scala
index 88ed6e5d8..a716d5c23 100644
--- a/src/main/scala/leon/utils/Library.scala
+++ b/src/main/scala/leon/utils/Library.scala
@@ -11,6 +11,10 @@ case class Library(pgm: Program) {
   lazy val Cons = lookup("leon.collection.Cons") collect { case ccd : CaseClassDef => ccd }
   lazy val Nil  = lookup("leon.collection.Nil") collect { case ccd : CaseClassDef => ccd }
 
+  lazy val Option = lookup("leon.collection.Option") collect { case acd : AbstractClassDef => acd }
+  lazy val Some = lookup("leon.collection.Some") collect { case ccd : CaseClassDef => ccd }
+  lazy val None = lookup("leon.collection.None") collect { case ccd : CaseClassDef => ccd }
+
   lazy val String = lookup("leon.lang.string.String") collect { case ccd : CaseClassDef => ccd }
 
   lazy val setToList = lookup("leon.collection.setToList") collect { case fd : FunDef => fd }
-- 
GitLab