From a451ad7b4dd601e86924ad12ea30b002609eb352 Mon Sep 17 00:00:00 2001
From: Etienne Kneuss <colder@php.net>
Date: Mon, 2 Jun 2014 16:24:38 +0200
Subject: [PATCH] Support get-model for arrays in z3-smt

---
 project/Build.scala                           |   2 +-
 .../solvers/combinators/UnrollingSolver.scala |   2 +-
 .../leon/solvers/smtlib/SMTLIBSolver.scala    |  41 +------
 .../leon/solvers/smtlib/SMTLIBTarget.scala    |  56 ++++++++--
 .../leon/solvers/smtlib/SMTLIBZ3Target.scala  | 102 +++++++++++++++++-
 .../scala/leon/solvers/z3/FairZ3Solver.scala  |   2 +-
 6 files changed, 157 insertions(+), 48 deletions(-)

diff --git a/project/Build.scala b/project/Build.scala
index a5205b3f0..4f6b4ddd0 100644
--- a/project/Build.scala
+++ b/project/Build.scala
@@ -79,7 +79,7 @@ object Leon extends Build {
   object Github {
     lazy val bonsai = RootProject(uri("git://github.com/colder/bonsai.git#8f485605785bda98ac61885b0c8036133783290a"))
 
-    private val scalaSmtLibVersion = "d4622f38a04a191798eb29f39d3c8b2ec312e811"
+    private val scalaSmtLibVersion = "90f66cf07aef34b05dc5585bb35aca773b3d0d43"
     lazy val scalaSmtLib = RootProject(uri("git://github.com/regb/scala-smtlib.git#%s".format(scalaSmtLibVersion)))
   }
 }
diff --git a/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala b/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala
index ce527cfcc..43b794e2d 100644
--- a/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala
+++ b/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala
@@ -100,7 +100,7 @@ class UnrollingSolver(val context: LeonContext, underlyings: SolverFactory[Incre
               reporter.debug(" - more unrollings")
               val newClauses = unrollOneStep()
               reporter.debug(s"   - ${newClauses.size} new clauses")
-              readLine()
+              //readLine()
               solver.assertCnstr(And(newClauses))
           }
 
diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala
index a1de9e9e2..51ac76726 100644
--- a/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala
+++ b/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala
@@ -13,11 +13,6 @@ import TreeOps._
 import TypeTrees._
 import Definitions._
 
-import _root_.smtlib.{PrettyPrinter => SMTPrinter, Interpreter => SMTInterpreter}
-import _root_.smtlib.Commands.{Identifier => _, Assert => SMTAssert, _}
-import _root_.smtlib.CommandResponses.{Error => ErrorResponse, _}
-import _root_.smtlib.sexpr.SExprs._
-import _root_.smtlib.interpreters._
 
 abstract class SMTLIBSolver(val context: LeonContext,
                             val program: Program)
@@ -28,41 +23,15 @@ abstract class SMTLIBSolver(val context: LeonContext,
 
   override def name: String = "smt-"+targetName
 
-  override def assertCnstr(expr: Expr): Unit = {
-    variablesOf(expr).foreach(declareVariable)
-    val sexpr = toSMT(expr)(Map())
-    sendCommand(SMTAssert(sexpr))
-  }
-
-  override def check: Option[Boolean] = sendCommand(CheckSat) match {
-    case CheckSatResponse(SatStatus)     => Some(true)
-    case CheckSatResponse(UnsatStatus)   => Some(false)
-    case CheckSatResponse(UnknownStatus) => None
-  }
-
-  override def getModel: Map[Identifier, Expr] = {
-    val syms = variables.bSet.toList
-    val cmd: Command = GetValue(syms.head, syms.tail)
-
-    val GetValueResponse(valuationPairs) = sendCommand(cmd)
-
-    valuationPairs.collect {
-      case (sym: SSymbol, value) if variables.containsB(sym) =>
-        //println("Getting model for "+sym)
-        //println("Value "+value)
-        (variables.toA(sym), fromSMT(value)(Map()))
-    }.toMap
-  }
+  /**
+   * Most of the solver functionality is defined, and thus extensible, in
+   * SMTLIBTarget, which gets specialized based on how individual solvers
+   * diverge from the SMTLib standard.
+   */
 
   override def free() = {
     interpreter.free()
     out.close
   }
 
-  override def push(): Unit = {
-    sendCommand(Push(1))
-  }
-  override def pop(lvl: Int = 1): Unit = {
-    sendCommand(Pop(1))
-  }
 }
diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala
index dc6e9adc3..87196540d 100644
--- a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala
+++ b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala
@@ -28,12 +28,12 @@ trait SMTLIBTarget {
 
   val interpreter = getNewInterpreter()
 
-  val out = new java.io.FileWriter(s"vcs-$targetName.log", true)
+  val out = new java.io.FileWriter(s"vcs-$targetName.smt2", true)
   reporter.ifDebug { debug =>
     out.write("; -----------------------------------------------------\n")
   }
 
-  def id2sym(id: Identifier): SSymbol = SSymbol(id.name.toUpperCase+"!"+id.globalId)
+  def id2sym(id: Identifier): SSymbol = SSymbol(id.name+"!"+id.globalId)
 
   // metadata for CC, and variables
   val constructors = new Bijection[TypeTree, SSymbol]()
@@ -57,6 +57,11 @@ trait SMTLIBTarget {
       tpe match {
         case BooleanType => SSymbol("Bool")
         case Int32Type => SSymbol("Int")
+        case UnitType =>
+          val s = SSymbol("Unit")
+          val cmd = NonStandardCommand(SList(SSymbol("declare-sort"), s))
+          sendCommand(cmd)
+          s
         case TypeParameter(id) =>
           val s = id2sym(id)
           val cmd = NonStandardCommand(SList(SSymbol("declare-sort"), s))
@@ -276,10 +281,14 @@ trait SMTLIBTarget {
     }
   }
 
+  def extractSpecialSymbols(s: SSymbol): Option[Expr] = {
+    None
+  }
+
   def fromSMT(s: SExpr)(implicit bindings: Map[SSymbol, Expr]): Expr = s match {
     case SInt(n)          => IntLiteral(n.toInt)
-    case SSymbol("TRUE")  => BooleanLiteral(true)
-    case SSymbol("FALSE") => BooleanLiteral(false)
+    case SSymbol("true")  => BooleanLiteral(true)
+    case SSymbol("false") => BooleanLiteral(false)
     case s: SSymbol if constructors.containsB(s) =>
       constructors.toA(s) match {
         case cct: CaseClassType =>
@@ -289,8 +298,10 @@ trait SMTLIBTarget {
       }
 
     case s: SSymbol =>
-      println(s)
-      bindings.getOrElse(s, variables.fromB(s).toVariable)
+      (bindings.get(s) orElse variables.getA(s).map(_.toVariable)
+                       orElse extractSpecialSymbols(s)).getOrElse {
+        unsupported("Unknown symbol: "+s)
+      }
 
     case SList((s: SSymbol) :: args) if(constructors.containsB(s)) => 
       val rargs = args.map(fromSMT)
@@ -303,7 +314,7 @@ trait SMTLIBTarget {
           unsupported("Woot? structural type that is non-structural: "+t)
       }
 
-    case SList(List(SSymbol("LET"), SList(defs), body)) =>
+    case SList(List(SSymbol("let"), SList(defs), body)) =>
       val leonDefs: Seq[(SSymbol, Identifier, Expr)] = defs.map {
         case SList(List(s : SSymbol, value)) =>
           (s, FreshIdentifier(s.s), fromSMT(value))
@@ -343,4 +354,35 @@ trait SMTLIBTarget {
     assert(!response.isInstanceOf[Error])
     response
   }
+
+  override def assertCnstr(expr: Expr): Unit = {
+    variablesOf(expr).foreach(declareVariable)
+    val sexpr = toSMT(expr)(Map())
+    sendCommand(Assert(sexpr))
+  }
+
+  override def check: Option[Boolean] = sendCommand(CheckSat) match {
+    case CheckSatResponse(SatStatus)     => Some(true)
+    case CheckSatResponse(UnsatStatus)   => Some(false)
+    case CheckSatResponse(UnknownStatus) => None
+  }
+
+  override def getModel: Map[Identifier, Expr] = {
+    val syms = variables.bSet.toList
+    val cmd: Command = GetValue(syms.head, syms.tail)
+
+    val GetValueResponse(valuationPairs) = sendCommand(cmd)
+
+    valuationPairs.collect {
+      case (sym: SSymbol, value) if variables.containsB(sym) =>
+        (variables.toA(sym), fromSMT(value)(Map()))
+    }.toMap
+  }
+
+  override def push(): Unit = {
+    sendCommand(Push(1))
+  }
+  override def pop(lvl: Int = 1): Unit = {
+    sendCommand(Pop(1))
+  }
 }
diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala
index 0a6198af4..127f1e79c 100644
--- a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala
+++ b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala
@@ -13,8 +13,8 @@ import TypeTrees._
 
 import _root_.smtlib.sexpr.SExprs._
 import _root_.smtlib.interpreters.Z3Interpreter
-
-import _root_.smtlib.Commands.NonStandardCommand
+import _root_.smtlib.Commands.{GetValue, NonStandardCommand}
+import _root_.smtlib.CommandResponses.SExprResponse
 
 trait SMTLIBZ3Target extends SMTLIBTarget {
   this: SMTLIBSolver =>
@@ -53,6 +53,33 @@ trait SMTLIBZ3Target extends SMTLIBTarget {
     SList(setSort.get, declareSort(of))
   }
 
+  override def extractSpecialSymbols(s: SSymbol): Option[Expr] = {
+    s.s.split("!").toList.reverse match {
+      case n :: "val" :: rest =>
+        val sort = rest.reverse.mkString("!")
+
+        sorts.getA(SSymbol(sort)) match {
+          case Some(tp : TypeParameter) =>
+            Some(GenericValue(tp, n.toInt))
+          case _ =>
+            None
+        }
+      case _ =>
+        None
+    }
+  }
+
+  override def fromSMT(s: SExpr)(implicit bindings: Map[SSymbol, Expr]): Expr = s match {
+    case SList(List(`extSym`, SSymbol("as-array"), k: SSymbol)) =>
+      bindings(k)
+
+    // SMT representation for empty sets: Array(* -> false)
+    case SList(List(SList(List(SSymbol("as"), SSymbol("const"), SList(List(SSymbol("Array"), s, SSymbol("Bool"))))), SSymbol("false"))) =>
+      FiniteSet(Nil).setType(sorts.fromB(s))
+
+    case _ =>
+      super.fromSMT(s)
+  }
 
   override def toSMT(e: Expr)(implicit bindings: Map[Identifier, SExpr]) = e match {
       case fs @ FiniteSet(elems) =>
@@ -79,4 +106,75 @@ trait SMTLIBZ3Target extends SMTLIBTarget {
       case _ =>
         super.toSMT(e)
   }
+
+  // We use get-model instead
+  override def getModel: Map[Identifier, Expr] = {
+
+    val cmd = NonStandardCommand(SList(SSymbol("get-model")))
+
+    val res = sendCommand(cmd)
+
+    val smodel = res match {
+      case SExprResponse(SList(SSymbol("model") :: entries)) => entries
+      case _ => Nil
+    }
+
+    var maps = Map[SSymbol, TypeTree]()
+
+    // First pass to gather functions (arrays defs)
+    for (me <- smodel) me match {
+      case SList(List(SSymbol("define-fun"), a: SSymbol, SList(Nil), _, SList(List(`extSym`, SSymbol("as-array"), k: SSymbol)))) =>
+        maps += k -> variables.toA(a).getType
+      case _ =>
+    }
+
+    var bindings = Map[SSymbol, Expr]()
+
+    // Second pass to gather functions (arrays defs)
+    for (me <- smodel) me match {
+      case SList(List(SSymbol("define-fun"), s: SSymbol, SList(SList(List(arg, _)) :: Nil), _, e)) if maps contains s =>
+        def extractCases(e: SExpr): (Map[Expr, Expr], Expr) = e match {
+          case SList(SSymbol("ite") :: SList(SSymbol("=") :: `arg` :: k :: Nil) :: v :: e :: Nil) =>
+            val (cs, d) = extractCases(e)
+            (Map(fromSMT(k)(Map()) -> fromSMT(v)(Map())) ++ cs, d)
+          case e =>
+            (Map(),fromSMT(e)(Map()))
+        }
+
+        def buildValue(cases: Map[Expr, Expr], default: Expr, tpe: TypeTree): Expr = tpe match {
+          case SetType(base) =>
+            assert(default == BooleanLiteral(false))
+            FiniteSet(cases.keySet.toSeq).setType(tpe)
+          case _ =>
+            unsupported("Cannot build array/map model to "+tpe)
+        }
+
+        val tpe = maps(s)
+        val (cases, default) = extractCases(e)
+
+        bindings += s -> buildValue(cases, default, tpe)
+      case _ =>
+    }
+
+    var model = Map[Identifier, Expr]()
+
+    for (me <- smodel) me match {
+      case SList(List(SSymbol("define-fun"), s: SSymbol, SList(args), kind, e)) =>
+        if (args.isEmpty) {
+          model += variables.toA(s) -> fromSMT(e)(bindings)
+        }
+
+      case SList(SSymbol("forall") :: _) => // no body
+        // Ignore
+
+      case SList(SSymbol("declare-fun") :: _) => // no body
+        // Ignore
+
+      case _ =>
+        unsupported("Unknown model entry: "+me)
+    }
+
+
+    model
+  }
 }
diff --git a/src/main/scala/leon/solvers/z3/FairZ3Solver.scala b/src/main/scala/leon/solvers/z3/FairZ3Solver.scala
index f61186f07..79d60b377 100644
--- a/src/main/scala/leon/solvers/z3/FairZ3Solver.scala
+++ b/src/main/scala/leon/solvers/z3/FairZ3Solver.scala
@@ -558,7 +558,7 @@ class FairZ3Solver(val context : LeonContext, val program: Program)
               solver.assertCnstr(ncl)
             }
 
-            readLine()
+            //readLine()
 
             reporter.debug(" - finished unrolling")
           }
-- 
GitLab