From d897d93d022a2b9caa2b969303f832c953c764cc Mon Sep 17 00:00:00 2001
From: Nicolas Voirol <voirol.nicolas@gmail.com>
Date: Wed, 9 Nov 2016 11:24:45 +0100
Subject: [PATCH] Moved model extraction to the right level in the encoding
 pipeline

---
 .../scala/inox/solvers/SolverFactory.scala    |   3 +
 .../solvers/unrolling/UnrollingSolver.scala   | 156 ++++++++++++------
 .../inox/solvers/z3/AbstractZ3Solver.scala    |  25 ++-
 .../inox/solvers/z3/NativeZ3Solver.scala      |  28 ++++
 4 files changed, 150 insertions(+), 62 deletions(-)

diff --git a/src/main/scala/inox/solvers/SolverFactory.scala b/src/main/scala/inox/solvers/SolverFactory.scala
index a98f7776d..211ad6aad 100644
--- a/src/main/scala/inox/solvers/SolverFactory.scala
+++ b/src/main/scala/inox/solvers/SolverFactory.scala
@@ -93,6 +93,7 @@ object SolverFactory {
         val encoder = enc
       } with unrolling.UnrollingSolver with theories.Z3Theories with TimeoutSolver {
         val evaluator = ev
+        lazy val modelEvaluator = RecursiveEvaluator(targetProgram, options)
 
         object underlying extends {
           val program: targetProgram.type = targetProgram
@@ -106,6 +107,7 @@ object SolverFactory {
         val encoder = enc
       } with unrolling.UnrollingSolver with theories.CVC4Theories with TimeoutSolver {
         val evaluator = ev
+        lazy val modelEvaluator = RecursiveEvaluator(targetProgram, options)
 
         object underlying extends {
           val program: targetProgram.type = targetProgram
@@ -119,6 +121,7 @@ object SolverFactory {
         val encoder = enc
       } with unrolling.UnrollingSolver with theories.Z3Theories with TimeoutSolver {
         val evaluator = ev
+        lazy val modelEvaluator = RecursiveEvaluator(targetProgram, options)
 
         object underlying extends {
           val program: targetProgram.type = targetProgram
diff --git a/src/main/scala/inox/solvers/unrolling/UnrollingSolver.scala b/src/main/scala/inox/solvers/unrolling/UnrollingSolver.scala
index fae2c9e23..c40001bb2 100644
--- a/src/main/scala/inox/solvers/unrolling/UnrollingSolver.scala
+++ b/src/main/scala/inox/solvers/unrolling/UnrollingSolver.scala
@@ -46,8 +46,11 @@ trait AbstractUnrollingSolver extends Solver { self =>
 
   protected final def encode(tpe: Type): t.Type = programEncoder.encode(tpe)
   protected final def decode(tpe: t.Type): Type = programEncoder.decode(tpe)
+
+  /*
   protected final def encode(ft: FunctionType): t.FunctionType =
     programEncoder.encode(ft).asInstanceOf[t.FunctionType]
+  */
 
   protected val templates: Templates {
     val program: targetProgram.type
@@ -124,9 +127,14 @@ trait AbstractUnrollingSolver extends Solver { self =>
   protected def wrapModel(model: underlying.Model): ModelWrapper
 
   trait ModelWrapper {
-    protected def modelEval(elem: Encoded, tpe: t.Type): Option[t.Expr]
+    def modelEval(elem: Encoded, tpe: t.Type): Option[t.Expr]
+
+    def extractConstructor(elem: Encoded): Option[Identifier]
+    def extractSet(elem: Encoded): Option[Seq[Encoded]]
+    def extractMap(elem: Encoded): Option[(Seq[(Encoded, Encoded)], Encoded)]
+    def extractBag(elem: Encoded): Option[Seq[(Encoded, Encoded)]]
 
-    def eval(elem: Encoded, tpe: s.Type): Option[Expr] = modelEval(elem, encode(tpe)).flatMap {
+    def eval(elem: Encoded, tpe: Type): Option[Expr] = modelEval(elem, encode(tpe)).flatMap {
       expr => try {
         Some(decode(expr))
       } catch {
@@ -176,13 +184,18 @@ trait AbstractUnrollingSolver extends Solver { self =>
   private def extractTotalModel(model: underlying.Model): Map[ValDef, Expr] = {
     val wrapped = wrapModel(model)
 
+    import targetProgram._
+    import targetProgram.trees._
+    import targetProgram.symbols._
+
     // maintain extracted functions to make sure equality is well-defined
     var funExtractions: Seq[(Encoded, Lambda)] = Seq.empty
 
     def extractValue(v: Encoded, tpe: Type): Expr = {
-      def functionsOf(expr: Expr, selector: Expr): (Seq[(Expr, Expr)], Seq[Expr] => Expr) = {
-        def reconstruct(subs: Seq[(Seq[(Expr, Expr)], Seq[Expr] => Expr)],
-                        recons: Seq[Expr] => Expr): (Seq[(Expr, Expr)], Seq[Expr] => Expr) =
+
+      def functionsOf(v: Encoded, tpe: Type): (Seq[(Encoded, FunctionType)], Seq[Expr] => Expr) = {
+        def reconstruct(subs: Seq[(Seq[(Encoded, FunctionType)], Seq[Expr] => Expr)],
+                        recons: Seq[Expr] => Expr): (Seq[(Encoded, FunctionType)], Seq[Expr] => Expr) =
           (subs.flatMap(_._1), (exprs: Seq[Expr]) => {
             var curr = exprs
             recons(subs.map { case (es, recons) =>
@@ -192,32 +205,55 @@ trait AbstractUnrollingSolver extends Solver { self =>
             })
           })
 
-        def rec(expr: Expr, selector: Expr): (Seq[(Expr, Expr)], Seq[Expr] => Expr) = expr match {
-          case (_: Lambda) =>
-            (Seq(expr -> selector), (es: Seq[Expr]) => es.head)
-
-          case Tuple(es) => reconstruct(es.zipWithIndex.map {
-            case (e, i) => rec(e, TupleSelect(selector, i + 1))
-          }, Tuple)
+        def rec(v: Encoded, tpe: Type): (Seq[(Encoded, FunctionType)], Seq[Expr] => Expr) = tpe match {
+          case ft: FunctionType =>
+            (Seq(v -> ft), es => es.head)
+
+          case TupleType(tps) =>
+            val id = Variable(FreshIdentifier("tuple"), tpe)
+            val encoder = templates.mkEncoder(Map(id -> v)) _
+            reconstruct(tps.zipWithIndex.map {
+              case (tpe, index) => rec(encoder(TupleSelect(id, index + 1)), tpe)
+            }, Tuple)
+
+          case ADTType(sid, tps) =>
+            val adt = ADTType(wrapped.extractConstructor(v).get, tps)
+            val id = Variable(FreshIdentifier("adt"), adt)
+            val encoder = templates.mkEncoder(Map(id -> v)) _
+            reconstruct(adt.getADT.toConstructor.fields.map {
+              vd => rec(encoder(ADTSelector(id, vd.id)), vd.tpe)
+            }, ADT(adt, _))
+
+          case SetType(base) =>
+            val vs = wrapped.extractSet(v).get
+            reconstruct(vs.map(rec(_, base)), FiniteSet(_, base))
+
+          case MapType(from, to) =>
+            val (vs, dflt) = wrapped.extractMap(v).get
+            reconstruct(vs.flatMap(p => Seq(rec(p._1, from), rec(p._2, to))) :+ rec(dflt, to), {
+              case es :+ default => FiniteMap(es.grouped(2).map(s => s(0) -> s(1)).toSeq, default, from, to)
+            })
 
-          case ADT(adt, es) => reconstruct((adt.getADT.toConstructor.fields zip es).map {
-            case (vd, e) => rec(e, ADTSelector(selector, vd.id))
-          }, ADT(adt, _))
+          case BagType(base) =>
+            val vs = wrapped.extractBag(v).get
+            reconstruct(vs.map(p => rec(p._1, base)), es => FiniteBag((es zip vs).map {
+              case (k, (_, v)) => k -> wrapped.modelEval(v, IntegerType).get
+            }, base))
 
-          case _ => (Seq.empty, (es: Seq[Expr]) => expr)
+          case _ => (Seq.empty, (es: Seq[Expr]) => wrapped.modelEval(v, tpe).get)
         }
 
-        rec(expr, selector)
+        rec(v, tpe)
       }
 
-      val value = wrapped.eval(v, tpe).getOrElse(simplestValue(tpe))
-      val id = Variable(FreshIdentifier("v"), tpe)
-      val (functions, recons) = functionsOf(value, id)
-      recons(functions.map { case (f, selector) =>
-        val encoded = templates.mkEncoder(Map(encode(id) -> v))(encode(selector))
-        val tpe = bestRealType(f.getType).asInstanceOf[FunctionType]
-        extractFunction(encoded, tpe)
-      })
+      if (wrapped.modelEval(v, tpe).isDefined) {
+        val (functions, recons) = functionsOf(v, tpe)
+        recons(functions.map { case (f, tpe) =>
+          extractFunction(f, bestRealType(tpe).asInstanceOf[FunctionType])
+        })
+      } else {
+        simplestValue(tpe)
+      }
     }
 
     object FiniteLambda {
@@ -257,21 +293,20 @@ trait AbstractUnrollingSolver extends Solver { self =>
 
     def extractFunction(f: Encoded, tpe: FunctionType): Expr = {
       def extractLambda(f: Encoded, tpe: FunctionType): Option[Lambda] = {
-        val optEqTemplate = templates.getLambdaTemplates(encode(tpe)).find { tmpl =>
-          wrapped.eval(tmpl.start, BooleanType) == Some(BooleanLiteral(true)) &&
-          wrapped.eval(templates.mkEquals(tmpl.ids._2, f), BooleanType) == Some(BooleanLiteral(true))
+        val optEqTemplate = templates.getLambdaTemplates(tpe).find { tmpl =>
+          wrapped.modelEval(tmpl.start, BooleanType) == Some(BooleanLiteral(true)) &&
+          wrapped.modelEval(templates.mkEquals(tmpl.ids._2, f), BooleanType) == Some(BooleanLiteral(true))
         }
 
         optEqTemplate.map { tmpl =>
           val localsSubst = tmpl.structure.locals.map { case (v, ev) =>
-            val dv = decode(v)
-            dv -> wrapped.eval(ev, dv.tpe).getOrElse {
+            v -> wrapped.modelEval(ev, v.tpe).getOrElse {
               scala.sys.error("Unexpectedly failed to extract " + templates.asString(ev) +
-                " with expected type " + dv.tpe.asString)
+                " with expected type " + v.tpe.asString)
             }
           }.toMap
 
-          exprOps.replaceFromSymbols(localsSubst, decode(tmpl.structure.lambda)).asInstanceOf[Lambda]
+          exprOps.replaceFromSymbols(localsSubst, tmpl.structure.lambda).asInstanceOf[Lambda]
         }
       }
 
@@ -287,9 +322,9 @@ trait AbstractUnrollingSolver extends Solver { self =>
             case ft: FunctionType =>
               val nextParams = params.tail
               val nextArguments = arguments.map(_.tail)
-              extract(templates.mkApp(caller, encode(tpe), Seq.empty), ft, nextParams, nextArguments, dflt)
+              extract(templates.mkApp(caller, tpe, Seq.empty), ft, nextParams, nextArguments, dflt)
             case _ =>
-              (extractValue(templates.mkApp(caller, encode(tpe), Seq.empty), tpe.to), false)
+              (extractValue(templates.mkApp(caller, tpe, Seq.empty), tpe.to), false)
           }
 
           (Lambda(Seq.empty, result), real)
@@ -300,7 +335,7 @@ trait AbstractUnrollingSolver extends Solver { self =>
               case (currCond, arguments) => tpe.to match {
                 case ft: FunctionType =>
                   val (currArgs, restArgs) = (arguments.head.head._1, arguments.map(_.tail))
-                  val newCaller = templates.mkApp(caller, encode(tpe), currArgs)
+                  val newCaller = templates.mkApp(caller, tpe, currArgs)
                   val (res, real) = extract(newCaller, ft, params.tail, restArgs, dflt)
                   val mappings: Seq[(Expr, Expr)] = if (real) {
                     Seq(BooleanLiteral(true) -> res)
@@ -312,7 +347,7 @@ trait AbstractUnrollingSolver extends Solver { self =>
 
                 case _ =>
                   val currArgs = arguments.head.head._1
-                  val res = extractValue(templates.mkApp(caller, encode(tpe), currArgs), tpe.to)
+                  val res = extractValue(templates.mkApp(caller, tpe, currArgs), tpe.to)
                   Seq(currCond -> res)
               }
             }
@@ -323,10 +358,10 @@ trait AbstractUnrollingSolver extends Solver { self =>
               case (encoded, `lambda`) => Right(encoded)
               case (e, img) if (
                 bestRealType(img.getType) == bestRealType(lambda.getType) &&
-                wrapped.eval(templates.mkEquals(e, f), BooleanType) == Some(BooleanLiteral(true))
+                wrapped.modelEval(templates.mkEquals(e, f), BooleanType) == Some(BooleanLiteral(true))
               )=> Left(img)
             }) match {
-              case Some(Right(enc)) => wrapped.eval(enc, tpe).get match {
+              case Some(Right(enc)) => wrapped.modelEval(enc, tpe).get match {
                 case Lambda(_, Let(_, Tuple(es), _)) =>
                   uniquateClosure(if (es.size % 2 == 0) -es.size / 2 else es.size / 2, lambda)
                 case l => scala.sys.error("Unexpected extracted lambda format: " + l)
@@ -351,13 +386,13 @@ trait AbstractUnrollingSolver extends Solver { self =>
         rec(tpe)
       }
 
-      val arguments = templates.getGroundInstantiations(f, encode(tpe)).flatMap { case (b, eArgs) =>
-        wrapped.eval(b, BooleanType).filter(_ == BooleanLiteral(true)).map(_ => eArgs)
+      val arguments = templates.getGroundInstantiations(f, tpe).flatMap { case (b, eArgs) =>
+        wrapped.modelEval(b, BooleanType).filter(_ == BooleanLiteral(true)).map(_ => eArgs)
       }.distinct
 
       extractLambda(f, tpe).getOrElse {
         if (arguments.isEmpty) {
-          wrapped.eval(f, tpe).get
+          wrapped.modelEval(f, tpe).get
         } else {
           val projection: Encoded = arguments.head.head
 
@@ -389,7 +424,7 @@ trait AbstractUnrollingSolver extends Solver { self =>
           }
 
           val (app, to) = unflatten(flatArguments.last._1).foldLeft(f -> (tpe: Type)) {
-            case ((f, tpe: FunctionType), args) => (templates.mkApp(f, encode(tpe), args), tpe.to)
+            case ((f, tpe: FunctionType), args) => (templates.mkApp(f, tpe, args), tpe.to)
           }
           val default = extractValue(app, to)
 
@@ -398,7 +433,7 @@ trait AbstractUnrollingSolver extends Solver { self =>
       }
     }
 
-    freeVars.toMap.map { case (v, idT) => v.toVal -> extractValue(idT, v.tpe) }
+    freeVars.toMap.map { case (v, idT) => v.toVal -> decode(extractValue(idT, encode(v.tpe))) }
   }
 
   def checkAssumptions(config: Configuration)(assumptions: Set[Expr]): config.Response[Model, Assumptions] = {
@@ -646,13 +681,38 @@ trait UnrollingSolver extends AbstractUnrollingSolver { self =>
     def mkImplies(l: Expr, r: Expr) = implies(l, r)
   }
 
+  protected val modelEvaluator: DeterministicEvaluator {
+    val program: self.targetProgram.type
+  }
+
   protected def declareVariable(v: t.Variable): t.Variable = v
-  protected def wrapModel(model: Map[t.ValDef, t.Expr]): super.ModelWrapper =
-    ModelWrapper(model.map(p => decode(p._1) -> decode(p._2)))
+  protected def wrapModel(model: Map[t.ValDef, t.Expr]): super.ModelWrapper = ModelWrapper(model)
+
+  private case class ModelWrapper(model: Map[t.ValDef, t.Expr]) extends super.ModelWrapper {
+    private def e(expr: t.Expr): Option[t.Expr] = modelEvaluator.eval(expr, model).result
+
+    def extractConstructor(elem: t.Expr): Option[Identifier] = e(elem) match {
+      case Some(t.ADT(t.ADTType(id, _), _)) => Some(id)
+      case _ => None
+    }
+
+    def extractSet(elem: t.Expr): Option[Seq[t.Expr]] = e(elem) match {
+      case Some(t.FiniteSet(elems, _)) => Some(elems)
+      case _ => None
+    }
+
+    def extractBag(elem: t.Expr): Option[Seq[(t.Expr, t.Expr)]] = e(elem) match {
+      case Some(t.FiniteBag(elems, _)) => Some(elems)
+      case _ => None
+    }
+
+    def extractMap(elem: t.Expr): Option[(Seq[(t.Expr, t.Expr)], t.Expr)] = e(elem) match {
+      case Some(t.FiniteMap(elems, default, _, _)) => Some((elems, default))
+      case _ => None
+    }
+
+    def modelEval(elem: t.Expr, tpe: t.Type): Option[t.Expr] = e(elem)
 
-  private case class ModelWrapper(model: Map[ValDef, Expr]) extends super.ModelWrapper {
-    def modelEval(elem: t.Expr, tpe: t.Type): Option[t.Expr] =
-      evaluator.eval(decode(elem), model).result.map(encode)
     override def toString = model.mkString("\n")
   }
 
diff --git a/src/main/scala/inox/solvers/z3/AbstractZ3Solver.scala b/src/main/scala/inox/solvers/z3/AbstractZ3Solver.scala
index 29987c499..87cf23fbe 100644
--- a/src/main/scala/inox/solvers/z3/AbstractZ3Solver.scala
+++ b/src/main/scala/inox/solvers/z3/AbstractZ3Solver.scala
@@ -88,18 +88,18 @@ trait AbstractZ3Solver
   }
 
   // ADT Manager
-  private val adtManager = new ADTManager
+  private[z3] val adtManager = new ADTManager
 
-  // Bijections between Inox Types/Functions/Ids to Z3 Sorts/Decls/ASTs
-  private val functions = new IncrementalBijection[TypedFunDef, Z3FuncDecl]()
-  private val lambdas   = new IncrementalBijection[FunctionType, Z3FuncDecl]()
-  private val variables = new IncrementalBijection[Variable, Z3AST]()
+  // Bije[z3]ctions between Inox Types/Functions/Ids to Z3 Sorts/Decls/ASTs
+  private[z3] val functions = new IncrementalBijection[TypedFunDef, Z3FuncDecl]()
+  private[z3] val lambdas   = new IncrementalBijection[FunctionType, Z3FuncDecl]()
+  private[z3] val variables = new IncrementalBijection[Variable, Z3AST]()
 
-  private val constructors = new IncrementalBijection[Type, Z3FuncDecl]()
-  private val selectors    = new IncrementalBijection[(Type, Int), Z3FuncDecl]()
-  private val testers      = new IncrementalBijection[Type, Z3FuncDecl]()
+  private[z3] val constructors = new IncrementalBijection[Type, Z3FuncDecl]()
+  private[z3] val selectors    = new IncrementalBijection[(Type, Int), Z3FuncDecl]()
+  private[z3] val testers      = new IncrementalBijection[Type, Z3FuncDecl]()
 
-  private val sorts     = new IncrementalMap[Type, Z3Sort]()
+  private[z3] val sorts     = new IncrementalMap[Type, Z3Sort]()
 
   def push(): Unit = {
     adtManager.push()
@@ -207,7 +207,6 @@ trait AbstractZ3Solver
   // Prepares some of the Z3 sorts, but *not* the tuple sorts; these are created on-demand.
   private def prepareSorts(): Unit = {
 
-    //TODO: mkBitVectorType
     sorts += Int32Type -> z3.mkBVSort(32)
     sorts += CharType -> z3.mkBVSort(32)
     sorts += IntegerType -> z3.mkIntSort
@@ -534,11 +533,9 @@ trait AbstractZ3Solver
           if(ts.length > 4 && ts.substring(0, 2) == "bv" && ts.substring(ts.length - 4) == "[32]") {
             val integer = ts.substring(2, ts.length - 4)
             tpe match {
-              case Int32Type => 
-                IntLiteral(integer.toLong.toInt)
+              case Int32Type => IntLiteral(integer.toLong.toInt)
               case CharType  => CharLiteral(integer.toInt.toChar)
-              case IntegerType => 
-                IntegerLiteral(BigInt(integer))
+              // @nv XXX: why would we have this!? case IntegerType => IntegerLiteral(BigInt(integer))
               case _ =>
                 reporter.fatalError("Unexpected target type for BV value: " + tpe.asString)
             }
diff --git a/src/main/scala/inox/solvers/z3/NativeZ3Solver.scala b/src/main/scala/inox/solvers/z3/NativeZ3Solver.scala
index 3043b4c89..55570ff79 100644
--- a/src/main/scala/inox/solvers/z3/NativeZ3Solver.scala
+++ b/src/main/scala/inox/solvers/z3/NativeZ3Solver.scala
@@ -72,6 +72,34 @@ trait NativeZ3Solver extends AbstractUnrollingSolver { self =>
   protected def wrapModel(model: Z3Model): super.ModelWrapper = ModelWrapper(model)
 
   private case class ModelWrapper(model: Z3Model) extends super.ModelWrapper {
+    def extractConstructor(v: Z3AST): Option[Identifier] = model.eval(v).flatMap {
+      elem => z3.getASTKind(elem) match {
+        case Z3AppAST(decl, args) if underlying.constructors containsB decl =>
+          underlying.constructors.toA(decl) match {
+            case t.ADTType(id, _) => Some(id)
+            case _ => None
+          }
+        case _ => None
+      }
+    }
+
+    def extractSet(v: Z3AST): Option[Seq[Z3AST]] = model.eval(v).flatMap {
+      elem => model.getSetValue(elem) collect { case (set, true) => set.toSeq }
+    }
+
+    def extractBag(v: Z3AST): Option[Seq[(Z3AST, Z3AST)]] = model.eval(v).flatMap {
+      elem => model.getArrayValue(elem) flatMap { case (z3Map, z3Default) =>
+        z3.getASTKind(z3Default) match {
+          case Z3NumeralIntAST(Some(0)) => Some(z3Map.toSeq)
+          case _ => None
+        }
+      }
+    }
+
+    def extractMap(v: Z3AST): Option[(Seq[(Z3AST, Z3AST)], Z3AST)] = model.eval(v).flatMap {
+      elem => model.getArrayValue(elem).map(p => p._1.toSeq -> p._2)
+    }
+
     def modelEval(elem: Z3AST, tpe: t.Type): Option[t.Expr] = {
       val timer = ctx.timers.solvers.z3.eval.start()
       val res = tpe match {
-- 
GitLab