From f34d906ffc40a3688de866d01309e047c45d50b9 Mon Sep 17 00:00:00 2001
From: Nicolas Voirol <voirol.nicolas@gmail.com>
Date: Fri, 28 Oct 2016 12:08:43 +0200
Subject: [PATCH] FiniteMap with to type + fixes in ProgramEncoder

---
 src/main/scala/inox/Main.scala                   |  2 +-
 src/main/scala/inox/ast/Expressions.scala        |  4 ++--
 src/main/scala/inox/ast/Extractors.scala         |  8 ++++----
 src/main/scala/inox/ast/Printers.scala           |  2 +-
 src/main/scala/inox/ast/ProgramEncoder.scala     | 14 ++++++++++++++
 src/main/scala/inox/ast/SymbolOps.scala          |  9 +++++----
 .../inox/evaluators/RecursiveEvaluator.scala     | 12 ++++++------
 .../scala/inox/solvers/smtlib/CVC4Target.scala   |  8 ++++----
 .../scala/inox/solvers/smtlib/SMTLIBTarget.scala |  2 +-
 .../scala/inox/solvers/smtlib/Z3Target.scala     | 16 +++++++---------
 .../scala/inox/solvers/z3/AbstractZ3Solver.scala |  8 ++++----
 src/main/scala/inox/tip/Parser.scala             |  3 ++-
 12 files changed, 51 insertions(+), 37 deletions(-)

diff --git a/src/main/scala/inox/Main.scala b/src/main/scala/inox/Main.scala
index 625152626..c6be43a62 100644
--- a/src/main/scala/inox/Main.scala
+++ b/src/main/scala/inox/Main.scala
@@ -110,7 +110,7 @@ trait MainHelpers {
 
     Context(
       reporter = reporter,
-      options = Options(inoxOptions),
+      options = Options(inoxOptions :+ optFiles(files)),
       interruptManager = new utils.InterruptManager(reporter)
     )
   }
diff --git a/src/main/scala/inox/ast/Expressions.scala b/src/main/scala/inox/ast/Expressions.scala
index 66f2f55c5..bf4ebabd6 100644
--- a/src/main/scala/inox/ast/Expressions.scala
+++ b/src/main/scala/inox/ast/Expressions.scala
@@ -670,10 +670,10 @@ trait Expressions { self: Trees =>
   /* Total map operations */
 
   /** $encodingof `Map[keyType, valueType](key1 -> value1, key2 -> value2 ...)` */
-  case class FiniteMap(pairs: Seq[(Expr, Expr)], default: Expr, keyType: Type) extends Expr with CachingTyped {
+  case class FiniteMap(pairs: Seq[(Expr, Expr)], default: Expr, keyType: Type, valueType: Type) extends Expr with CachingTyped {
     protected def computeType(implicit s: Symbols): Type = MapType(
       checkParamTypes(pairs.map(_._1.getType), List.fill(pairs.size)(keyType), keyType),
-      s.leastUpperBound(pairs.map(_._2.getType) :+ default.getType).getOrElse(Untyped)
+      checkParamTypes(pairs.map(_._2.getType) :+ default.getType, List.fill(pairs.size + 1)(valueType), valueType)
     ).unveilUntyped
   }
 
diff --git a/src/main/scala/inox/ast/Extractors.scala b/src/main/scala/inox/ast/Extractors.scala
index f7e261e88..f09ba2b8b 100644
--- a/src/main/scala/inox/ast/Extractors.scala
+++ b/src/main/scala/inox/ast/Extractors.scala
@@ -150,9 +150,9 @@ trait TreeDeconstructor {
         t.FiniteBag(rec(as), tps.head)
       }
       (Seq(), subArgs, Seq(base), builder)
-    case s.FiniteMap(elems, default, kT) =>
+    case s.FiniteMap(elems, default, kT, vT) =>
       val subArgs = elems.flatMap { case (k, v) => Seq(k, v) } :+ default
-      val builder = (vs: Seq[t.Variable], as: Seq[t.Expr], kT: Seq[t.Type]) => {
+      val builder = (vs: Seq[t.Variable], as: Seq[t.Expr], tps: Seq[t.Type]) => {
         def rec(kvs: Seq[t.Expr]): (Seq[(t.Expr, t.Expr)], t.Expr) = kvs match {
           case Seq(k, v, t @ _*) =>
             val (kvs, default) = rec(t)
@@ -160,9 +160,9 @@ trait TreeDeconstructor {
           case Seq(default) => (Seq(), default)
         }
         val (pairs, default) = rec(as)
-        t.FiniteMap(pairs, default, kT.head)
+        t.FiniteMap(pairs, default, tps(0), tps(1))
       }
-      (Seq(), subArgs, Seq(kT), builder)
+      (Seq(), subArgs, Seq(kT, vT), builder)
     case s.Tuple(args) =>
       (Seq(), args, Seq(), (_, es, _) => t.Tuple(es))
     case s.IfExpr(cond, thenn, elze) => (
diff --git a/src/main/scala/inox/ast/Printers.scala b/src/main/scala/inox/ast/Printers.scala
index e180bc0f3..27ac8db61 100644
--- a/src/main/scala/inox/ast/Printers.scala
+++ b/src/main/scala/inox/ast/Printers.scala
@@ -239,7 +239,7 @@ trait Printers {
     }
     case fs @ FiniteSet(rs, _) => p"{${rs.distinct}}"
     case fs @ FiniteBag(rs, _) => p"{${rs.toMap.toSeq}}"
-    case fm @ FiniteMap(rs, _, _) => p"{${rs.toMap.toSeq}}"
+    case fm @ FiniteMap(rs, _, _, _) => p"{${rs.toMap.toSeq}}"
     case Not(ElementOfSet(e, s)) => p"$e \u2209 $s"
     case ElementOfSet(e, s) => p"$e \u2208 $s"
     case SubsetOf(l, r) => p"$l \u2286 $r"
diff --git a/src/main/scala/inox/ast/ProgramEncoder.scala b/src/main/scala/inox/ast/ProgramEncoder.scala
index 0f8edaf9d..0d60a6ed6 100644
--- a/src/main/scala/inox/ast/ProgramEncoder.scala
+++ b/src/main/scala/inox/ast/ProgramEncoder.scala
@@ -42,6 +42,13 @@ trait ProgramEncoder { self =>
     val sourceProgram: that.sourceProgram.type = that.sourceProgram
     val t: self.t.type = self.t
 
+    // make sure we don't ignore potential `encodedProgram` overrides
+    // note that we don't actually need to look at `that.encodedProgram` since the type
+    // of the compose method ensures the override is not ignored
+    override protected def encodedProgram: Program { val trees: self.t.type } = self.encodedProgram
+    override protected val extraFunctions: Seq[t.FunDef] = self.extraFunctions
+    override protected val extraADTs: Seq[t.ADTDefinition] = self.extraADTs
+
     val encoder = self.encoder compose that.encoder
     val decoder = that.decoder compose self.decoder
   }
@@ -53,6 +60,13 @@ trait ProgramEncoder { self =>
     val sourceProgram: self.sourceProgram.type = self.sourceProgram
     val t: that.t.type = that.t
 
+    // make sure we don't ignore potential `encodedProgram` overrides
+    // note that we don't actually need to look at `that.encodedProgram` since the type
+    // of the andThen method ensures the override is not ignored
+    override protected def encodedProgram: Program { val trees: that.t.type } = that.encodedProgram
+    override protected val extraFunctions: Seq[t.FunDef] = that.extraFunctions
+    override protected val extraADTs: Seq[t.ADTDefinition] = that.extraADTs
+
     val encoder = self.encoder andThen that.encoder
     val decoder = that.decoder andThen self.decoder
   }
diff --git a/src/main/scala/inox/ast/SymbolOps.scala b/src/main/scala/inox/ast/SymbolOps.scala
index b9f5f56b7..7bd557dae 100644
--- a/src/main/scala/inox/ast/SymbolOps.scala
+++ b/src/main/scala/inox/ast/SymbolOps.scala
@@ -327,7 +327,7 @@ trait SymbolOps { self: TypeOps =>
     case UnitType                   => UnitLiteral()
     case SetType(baseType)          => FiniteSet(Seq(), baseType)
     case BagType(baseType)          => FiniteBag(Seq(), baseType)
-    case MapType(fromType, toType)  => FiniteMap(Seq(), simplestValue(toType), fromType)
+    case MapType(fromType, toType)  => FiniteMap(Seq(), simplestValue(toType), fromType, toType)
     case TupleType(tpes)            => Tuple(tpes.map(simplestValue))
 
     case adt @ ADTType(id, tps) =>
@@ -389,7 +389,7 @@ trait SymbolOps { self: TypeOps =>
         val seqs = elems.scanLeft(Stream(Seq[(Expr, Expr)]())) { (prev, curr) =>
           prev flatMap { case seq => Stream(seq, seq :+ curr) }
         }.flatten
-        cartesianProduct(seqs, valuesOf(to)) map { case (values, default) => FiniteMap(values, default, from) }
+        cartesianProduct(seqs, valuesOf(to)) map { case (values, default) => FiniteMap(values, default, from, to) }
       case adt: ADTType => adt.getADT match {
         case tcons: TypedADTConstructor =>
           cartesianProduct(tcons.fieldsTypes map valuesOf) map (ADT(adt, _))
@@ -620,8 +620,9 @@ trait SymbolOps { self: TypeOps =>
       case (FiniteBag(elements, fbtpe), BagType(tpe)) =>
         fbtpe == tpe &&
         elements.forall{ case (key, value) => isValueOfType(key, tpe) && isValueOfType(value, IntegerType) }
-      case (FiniteMap(elems, default, kt), MapType(from, to)) =>
-        (kt == from) < s"$kt not equal to $from" && (default.getType == to) < s"${default.getType} not equal to $to" &&
+      case (FiniteMap(elems, default, kt, vt), MapType(from, to)) =>
+        (kt == from) < s"$kt not equal to $from" && (vt == to) < s"${default.getType} not equal to $to" &&
+        isValueOfType(default, to) < s"${default} not a value of type $to" &&
         (elems forall (kv => isValueOfType(kv._1, from) < s"${kv._1} not a value of type $from" && isValueOfType(unWrapSome(kv._2), to) < s"${unWrapSome(kv._2)} not a value of type ${to}" ))
       case (ADT(adt, args), adt2: ADTType) =>
         isSubtypeOf(adt, adt2) < s"$adt not a subtype of $adt2" &&
diff --git a/src/main/scala/inox/evaluators/RecursiveEvaluator.scala b/src/main/scala/inox/evaluators/RecursiveEvaluator.scala
index 56ea64bf6..6675d5915 100644
--- a/src/main/scala/inox/evaluators/RecursiveEvaluator.scala
+++ b/src/main/scala/inox/evaluators/RecursiveEvaluator.scala
@@ -118,7 +118,7 @@ trait RecursiveEvaluator
           BooleanLiteral(el1.toSet == el2.toSet)
         case (FiniteBag(el1, _),FiniteBag(el2, _)) =>
           BooleanLiteral(el1.toMap == el2.toMap)
-        case (FiniteMap(el1, dflt1, _),FiniteMap(el2, dflt2, _)) =>
+        case (FiniteMap(el1, dflt1, _, _),FiniteMap(el2, dflt2, _, _)) =>
           BooleanLiteral(el1.toMap == el2.toMap && dflt1 == dflt2)
         case (l1: Lambda, l2: Lambda) =>
           val (nl1, subst1) = normalizeStructure(l1)
@@ -481,20 +481,20 @@ trait RecursiveEvaluator
       replaceFromSymbols(variablesOf(c).map(v => v -> e(v)).toMap, c).asInstanceOf[Choose]
     }
 
-    case f @ FiniteMap(ss, dflt, vT) =>
+    case f @ FiniteMap(ss, dflt, kT, vT) =>
       // we use toMap.toSeq to reduce dupplicate keys
-      FiniteMap(ss.map{ case (k, v) => (e(k), e(v)) }.toMap.toSeq, e(dflt), vT)
+      FiniteMap(ss.map{ case (k, v) => (e(k), e(v)) }.toMap.toSeq, e(dflt), kT, vT)
 
     case g @ MapApply(m,k) => (e(m), e(k)) match {
-      case (FiniteMap(ss, dflt, _), e) =>
+      case (FiniteMap(ss, dflt, _, _), e) =>
         ss.toMap.getOrElse(e, dflt)
       case (l,r) =>
         throw EvalError(typeErrorMsg(l, MapType(r.getType, g.getType)))
     }
 
     case g @ MapUpdated(m, k, v) => (e(m), e(k), e(v)) match {
-      case (FiniteMap(ss, dflt, tpe), ek, ev) =>
-        FiniteMap((ss.toMap + (ek -> ev)).toSeq, dflt, tpe)
+      case (FiniteMap(ss, dflt, kT, vT), ek, ev) =>
+        FiniteMap((ss.toMap + (ek -> ev)).toSeq, dflt, kT, vT)
       case (m,l,r) =>
         throw EvalError("Unexpected operation: " + m.asString +
           ".updated(" + l.asString + ", " + r.asString + ")")
diff --git a/src/main/scala/inox/solvers/smtlib/CVC4Target.scala b/src/main/scala/inox/solvers/smtlib/CVC4Target.scala
index b28755ba2..5e07a5d49 100644
--- a/src/main/scala/inox/solvers/smtlib/CVC4Target.scala
+++ b/src/main/scala/inox/solvers/smtlib/CVC4Target.scala
@@ -54,17 +54,17 @@ trait CVC4Target extends SMTLIBTarget with SMTLIBDebugger {
         FiniteSet(Seq(), base)
 
       case (FunctionApplication(QualifiedIdentifier(SMTIdentifier(SSymbol("const"), _), _), Seq(elem)), Some(MapType(k, v))) =>
-        FiniteMap(Seq(), fromSMT(elem, v), k)
+        FiniteMap(Seq(), fromSMT(elem, v), k, v)
 
       case (FunctionApplication(SimpleSymbol(SSymbol("__array_store_all__")), Seq(_, elem)), Some(MapType(k, v))) =>
-        FiniteMap(Seq(), fromSMT(elem, v), k)
+        FiniteMap(Seq(), fromSMT(elem, v), k, v)
 
       case (FunctionApplication(SimpleSymbol(SSymbol("store")), Seq(arr, key, elem)), Some(MapType(kT, vT))) =>
-        val FiniteMap(elems, default, _) = fromSMT(arr, otpe)
+        val FiniteMap(elems, default, _, _) = fromSMT(arr, otpe)
         val newKey = fromSMT(key, kT)
         val newV   = fromSMT(elem, vT)
         val newElems = elems.filterNot(_._1 == newKey) :+ (newKey -> newV)
-        FiniteMap(newElems, default, kT)
+        FiniteMap(newElems, default, kT, vT)
 
       case (FunctionApplication(SimpleSymbol(SSymbol("singleton")), elems), Some(SetType(base))) =>
         FiniteSet(elems.map(fromSMT(_, base)), base)
diff --git a/src/main/scala/inox/solvers/smtlib/SMTLIBTarget.scala b/src/main/scala/inox/solvers/smtlib/SMTLIBTarget.scala
index 890761d96..2b71f2ccf 100644
--- a/src/main/scala/inox/solvers/smtlib/SMTLIBTarget.scala
+++ b/src/main/scala/inox/solvers/smtlib/SMTLIBTarget.scala
@@ -324,7 +324,7 @@ trait SMTLIBTarget extends Interruptible with ADTManagers {
         ArraysEx.Select(toSMT(a), toSMT(i))
       case al @ MapUpdated(map, k, v) =>
         ArraysEx.Store(toSMT(map), toSMT(k), toSMT(v))
-      case ra @ FiniteMap(elems, default, keyTpe) =>
+      case ra @ FiniteMap(elems, default, keyTpe, valueType) =>
         val s = declareSort(ra.getType)
 
         var res: Term = FunctionApplication(
diff --git a/src/main/scala/inox/solvers/smtlib/Z3Target.scala b/src/main/scala/inox/solvers/smtlib/Z3Target.scala
index e803c54a3..c9b7fc76e 100644
--- a/src/main/scala/inox/solvers/smtlib/Z3Target.scala
+++ b/src/main/scala/inox/solvers/smtlib/Z3Target.scala
@@ -75,28 +75,26 @@ trait Z3Target extends SMTLIBTarget with SMTLIBDebugger {
           }
           // Need to recover value form function model
           val (cases, default) = extractCases(body)
-          FiniteMap(cases.toSeq, default, keyType)
+          FiniteMap(cases.toSeq, default, keyType, valueType)
         } else {
           throw FatalError("Array on non-function or unknown symbol "+k)
         }
 
       case (QualifiedIdentifier(ExtendedIdentifier(SSymbol("as-array"), k: SSymbol), _), Some(tpe @ SetType(base))) =>
-        val fm @ FiniteMap(cases, dflt, _) = fromSMT(t, Some(MapType(base, BooleanType)))
+        val fm @ FiniteMap(cases, dflt, _, _) = fromSMT(t, Some(MapType(base, BooleanType)))
         if (dflt != BooleanLiteral(false)) unsupported(fm, "Solver returned a co-finite set which is not supported")
         FiniteSet(cases.collect { case (k, BooleanLiteral(true)) => k }, base)
 
       case (QualifiedIdentifier(ExtendedIdentifier(SSymbol("as-array"), k: SSymbol), _), Some(tpe @ BagType(base))) =>
-        val fm @ FiniteMap(cases, dflt, _) = fromSMT(t, Some(MapType(base, IntegerType)))
+        val fm @ FiniteMap(cases, dflt, _, _) = fromSMT(t, Some(MapType(base, IntegerType)))
         if (dflt != IntegerLiteral(0)) unsupported(fm, "Solver returned a co-finite bag which is not supported")
         FiniteBag(cases.filter(_._2 != IntegerLiteral(BigInt(0))), base)
 
       case (FunctionApplication(
         QualifiedIdentifier(SMTIdentifier(SSymbol("const"), _), Some(ArraysEx.ArraySort(k, v))),
         Seq(defV)
-      ), Some(tpe: MapType)) =>
-        val ktpe = sorts.fromB(k)
-        val vtpe = sorts.fromB(v)
-        FiniteMap(Seq(), fromSMT(defV, Some(vtpe)), ktpe)
+      ), Some(MapType(from, to))) =>
+        FiniteMap(Seq(), fromSMT(defV, Some(to)), from, to)
 
       case (FunctionApplication(
         QualifiedIdentifier(SMTIdentifier(SSymbol("const"), _), Some(ArraysEx.ArraySort(k, v))),
@@ -126,7 +124,7 @@ trait Z3Target extends SMTLIBTarget with SMTLIBDebugger {
      */
     case fs @ FiniteSet(elems, base) =>
       declareSort(fs.getType)
-      toSMT(FiniteMap(elems map ((_, BooleanLiteral(true))), BooleanLiteral(false), base))
+      toSMT(FiniteMap(elems map ((_, BooleanLiteral(true))), BooleanLiteral(false), base, BooleanType))
 
     case SubsetOf(ss, s) =>
       // a isSubset b   ==>   (a zip b).map(implies) == (* => true)
@@ -155,7 +153,7 @@ trait Z3Target extends SMTLIBTarget with SMTLIBDebugger {
     case fb @ FiniteBag(elems, base) =>
       val BagType(t) = fb.getType
       declareSort(BagType(t))
-      toSMT(FiniteMap(elems, IntegerLiteral(0), t))
+      toSMT(FiniteMap(elems, IntegerLiteral(0), t, IntegerType))
 
     case BagAdd(b, e) =>
       val bid = FreshIdentifier("b", true)
diff --git a/src/main/scala/inox/solvers/z3/AbstractZ3Solver.scala b/src/main/scala/inox/solvers/z3/AbstractZ3Solver.scala
index a2fc56cf3..fd394dfb8 100644
--- a/src/main/scala/inox/solvers/z3/AbstractZ3Solver.scala
+++ b/src/main/scala/inox/solvers/z3/AbstractZ3Solver.scala
@@ -441,7 +441,7 @@ trait AbstractZ3Solver
        */
       case fb @ FiniteBag(elems, base) =>
         typeToSort(fb.getType)
-        rec(FiniteMap(elems, IntegerLiteral(0), base))
+        rec(FiniteMap(elems, IntegerLiteral(0), base, IntegerType))
 
       case BagAdd(b, e) =>
         val (bag, elem) = (rec(b), rec(e))
@@ -476,7 +476,7 @@ trait AbstractZ3Solver
       case al @ MapUpdated(a, i, e) =>
         z3.mkStore(rec(a), rec(i), rec(e))
 
-      case FiniteMap(elems, default, keyTpe) =>
+      case FiniteMap(elems, default, keyTpe, valueType) =>
         val ar = z3.mkConstArray(typeToSort(keyTpe), rec(default))
 
         elems.foldLeft(ar) {
@@ -612,12 +612,12 @@ trait AbstractZ3Solver
                       case (k,v) => (rec(k, from), rec(v, to))
                     }
 
-                    FiniteMap(entries.toSeq, default, from)
+                    FiniteMap(entries.toSeq, default, from, to)
                   case None => unsound(t, "invalid array AST")
                 }
 
               case BagType(base) =>
-                val fm @ FiniteMap(entries, default, from) = rec(t, MapType(base, IntegerType))
+                val fm @ FiniteMap(entries, default, from, IntegerType) = rec(t, MapType(base, IntegerType))
                 if (default != IntegerLiteral(0)) {
                   unsound(t, "co-finite bag AST")
                 }
diff --git a/src/main/scala/inox/tip/Parser.scala b/src/main/scala/inox/tip/Parser.scala
index 72245aeb8..2b6122036 100644
--- a/src/main/scala/inox/tip/Parser.scala
+++ b/src/main/scala/inox/tip/Parser.scala
@@ -513,7 +513,8 @@ class Parser(file: File) {
     case ArraysEx.Select(e1, e2) => MapApply(extractTerm(e1), extractTerm(e2))
     case ArraysEx.Store(e1, e2, e3) => MapUpdated(extractTerm(e1), extractTerm(e2), extractTerm(e3))
     case FunctionApplication(QualifiedIdentifier(SimpleIdentifier(SSymbol("const")), Some(sort)), Seq(dflt)) =>
-      FiniteMap(Seq.empty, extractTerm(dflt), extractSort(sort))
+      val d = extractTerm(dflt)
+      FiniteMap(Seq.empty, d, extractSort(sort), locals.symbols.bestRealType(d.getType(locals.symbols)))
 
     case Sets.Union(e1, e2) => SetUnion(extractTerm(e1), extractTerm(e2))
     case Sets.Intersection(e1, e2) => SetIntersection(extractTerm(e1), extractTerm(e2))
-- 
GitLab