From f997139f0f41bb6a47a7b71cc6bf673994553e02 Mon Sep 17 00:00:00 2001
From: Manos Koukoutos <emmanouil.koukoutos@epfl.ch>
Date: Wed, 7 Oct 2015 17:43:19 +0200
Subject: [PATCH] FiniteMap is implemented with a Map

---
 .../scala/leon/codegen/CompilationUnit.scala  |  4 ++--
 .../scala/leon/datagen/VanuatooDataGen.scala  |  2 +-
 .../leon/evaluators/RecursiveEvaluator.scala  | 22 +++++++++----------
 .../frontends/scalac/CodeExtraction.scala     | 12 +++++-----
 src/main/scala/leon/purescala/ExprOps.scala   |  4 ++--
 .../scala/leon/purescala/Expressions.scala    |  2 +-
 .../scala/leon/purescala/Extractors.scala     |  8 +++----
 .../solvers/smtlib/SMTLIBCVC4Target.scala     |  6 ++---
 .../leon/solvers/smtlib/SMTLIBTarget.scala    |  2 +-
 .../leon/solvers/z3/AbstractZ3Solver.scala    |  2 +-
 .../evaluators/EvaluatorSuite.scala           |  9 --------
 11 files changed, 31 insertions(+), 42 deletions(-)

diff --git a/src/main/scala/leon/codegen/CompilationUnit.scala b/src/main/scala/leon/codegen/CompilationUnit.scala
index d14b02f95..69f95f455 100644
--- a/src/main/scala/leon/codegen/CompilationUnit.scala
+++ b/src/main/scala/leon/codegen/CompilationUnit.scala
@@ -315,8 +315,8 @@ class CompilationUnit(val ctx: LeonContext,
         val k = jvmToValue(entry.getKey, from)
         val v = jvmToValue(entry.getValue, to)
         (k, v)
-      }
-      FiniteMap(pairs.toSeq, from, to)
+      }.toMap
+      FiniteMap(pairs, from, to)
 
     case (lambda: runtime.Lambda, _: FunctionType) =>
       val cls = lambda.getClass
diff --git a/src/main/scala/leon/datagen/VanuatooDataGen.scala b/src/main/scala/leon/datagen/VanuatooDataGen.scala
index b742ec58f..e68e21f01 100644
--- a/src/main/scala/leon/datagen/VanuatooDataGen.scala
+++ b/src/main/scala/leon/datagen/VanuatooDataGen.scala
@@ -99,7 +99,7 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator {
         val cs = for (size <- List(0, 1, 2, 5)) yield {
           val subs   = (1 to size).flatMap(i => List(from, to)).toList
 
-          Constructor[Expr, TypeTree](subs, mt, s => FiniteMap(s.grouped(2).map(t => (t(0), t(1))).toSeq, from, to), mt.asString(ctx)+"@"+size)
+          Constructor[Expr, TypeTree](subs, mt, s => FiniteMap(s.grouped(2).map(t => (t(0), t(1))).toMap, from, to), mt.asString(ctx)+"@"+size)
         }
         constructors += mt -> cs
         cs
diff --git a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala
index bd2a594ec..f17f9caef 100644
--- a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala
+++ b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala
@@ -603,26 +603,24 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int
       )
 
     case f @ FiniteMap(ss, kT, vT) =>
-      FiniteMap(ss.map{ case (k, v) => (e(k), e(v)) }.distinct, kT, vT)
+      FiniteMap(ss.map{ case (k, v) => (e(k), e(v)) }, kT, vT)
 
     case g @ MapApply(m,k) => (e(m), e(k)) match {
-      case (FiniteMap(ss, _, _), e) => ss.find(_._1 == e) match {
-        case Some((_, v0)) => v0
-        case None => throw RuntimeError("Key not found: " + e.asString)
-      }
-      case (l,r) => throw EvalError(typeErrorMsg(l, MapType(r.getType, g.getType)))
+      case (FiniteMap(ss, _, _), e) =>
+        ss.getOrElse(e, throw RuntimeError("Key not found: " + e.asString))
+      case (l,r) =>
+        throw EvalError(typeErrorMsg(l, MapType(r.getType, g.getType)))
     }
     case u @ MapUnion(m1,m2) => (e(m1), e(m2)) match {
-      case (f1@FiniteMap(ss1, _, _), FiniteMap(ss2, _, _)) => {
-        val filtered1 = ss1.filterNot(s1 => ss2.exists(s2 => s2._1 == s1._1))
-        val newSs = filtered1 ++ ss2
+      case (f1@FiniteMap(ss1, _, _), FiniteMap(ss2, _, _)) =>
+        val newSs = ss1 ++ ss2
         val MapType(kT, vT) = u.getType
         FiniteMap(newSs, kT, vT)
-      }
-      case (l, r) => throw EvalError(typeErrorMsg(l, m1.getType))
+      case (l, r) =>
+        throw EvalError(typeErrorMsg(l, m1.getType))
     }
     case i @ MapIsDefinedAt(m,k) => (e(m), e(k)) match {
-      case (FiniteMap(ss, _, _), e) => BooleanLiteral(ss.exists(_._1 == e))
+      case (FiniteMap(ss, _, _), e) => BooleanLiteral(ss.contains(e))
       case (l, r) => throw EvalError(typeErrorMsg(l, m.getType))
     }
 
diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala
index be64c2f77..9fa7606da 100644
--- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala
+++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala
@@ -1329,10 +1329,10 @@ trait CodeExtraction extends ASTExtractors {
           Forall(vds, exBody)
 
         case ExFiniteMap(tptFrom, tptTo, args) =>
-          val singletons: Seq[(LeonExpr, LeonExpr)] = args.collect {
+          val singletons = args.collect {
             case ExTuple(tpes, trees) if trees.size == 2 =>
               (extractTree(trees(0)), extractTree(trees(1)))
-          }
+          }.toMap
 
           if (singletons.size != args.size) {
             outOfSubsetError(tr, "Some map elements could not be extracted as Tuple2")
@@ -1678,17 +1678,17 @@ trait CodeExtraction extends ASTExtractors {
               MapIsDefinedAt(a1, a2)
 
             case (IsTyped(a1, mt: MapType), "updated", List(k, v)) =>
-              MapUnion(a1, FiniteMap(Seq((k, v)), mt.from, mt.to))
+              MapUnion(a1, FiniteMap(Map(k -> v), mt.from, mt.to))
 
             case (IsTyped(a1, mt: MapType), "+", List(k, v)) =>
-              MapUnion(a1, FiniteMap(Seq((k, v)), mt.from, mt.to))
+              MapUnion(a1, FiniteMap(Map(k -> v), mt.from, mt.to))
 
             case (IsTyped(a1, mt: MapType), "+", List(IsTyped(kv, TupleType(List(_, _))))) =>
               kv match {
                 case Tuple(List(k, v)) =>
-                  MapUnion(a1, FiniteMap(Seq((k, v)), mt.from, mt.to))
+                  MapUnion(a1, FiniteMap(Map(k -> v), mt.from, mt.to))
                 case kv =>
-                  MapUnion(a1, FiniteMap(Seq((TupleSelect(kv, 1), TupleSelect(kv, 2))), mt.from, mt.to))
+                  MapUnion(a1, FiniteMap(Map(TupleSelect(kv, 1) -> TupleSelect(kv, 2)), mt.from, mt.to))
               }
 
             case (IsTyped(a1, mt1: MapType), "++", List(IsTyped(a2, mt2: MapType)))  if mt1 == mt2 =>
diff --git a/src/main/scala/leon/purescala/ExprOps.scala b/src/main/scala/leon/purescala/ExprOps.scala
index 2f703312a..200a9e167 100644
--- a/src/main/scala/leon/purescala/ExprOps.scala
+++ b/src/main/scala/leon/purescala/ExprOps.scala
@@ -1107,7 +1107,7 @@ object ExprOps {
     case BooleanType                => BooleanLiteral(false)
     case UnitType                   => UnitLiteral()
     case SetType(baseType)          => FiniteSet(Set(), baseType)
-    case MapType(fromType, toType)  => FiniteMap(Nil, fromType, toType)
+    case MapType(fromType, toType)  => FiniteMap(Map(), fromType, toType)
     case TupleType(tpes)            => Tuple(tpes.map(simplestValue))
     case ArrayType(tpe)             => EmptyArray(tpe)
 
@@ -1304,7 +1304,7 @@ object ExprOps {
     case wp: WildcardPattern =>
       1
     case _ =>
-      1 + (if(p.binder.isDefined) 1 else 0) + p.subPatterns.map(patternSize).sum
+      1 + p.binder.size + p.subPatterns.map(patternSize).sum
   }
 
   def formulaSize(e: Expr): Int = e match {
diff --git a/src/main/scala/leon/purescala/Expressions.scala b/src/main/scala/leon/purescala/Expressions.scala
index 3db4d0bb7..7216652a4 100644
--- a/src/main/scala/leon/purescala/Expressions.scala
+++ b/src/main/scala/leon/purescala/Expressions.scala
@@ -788,7 +788,7 @@ object Expressions {
 
   /* Map operations */
   /** $encodingof `Map[keyType, valueType](key1 -> value1, key2 -> value2 ...)` */
-  case class FiniteMap(singletons: Seq[(Expr, Expr)], keyType: TypeTree, valueType: TypeTree) extends Expr {
+  case class FiniteMap(pairs: Map[Expr, Expr], keyType: TypeTree, valueType: TypeTree) extends Expr {
     val getType = MapType(keyType, valueType).unveilUntyped
   }
   /** $encodingof `map.apply(key)` (or `map(key)`)*/
diff --git a/src/main/scala/leon/purescala/Extractors.scala b/src/main/scala/leon/purescala/Extractors.scala
index d2ccfa7da..cd57d2187 100644
--- a/src/main/scala/leon/purescala/Extractors.scala
+++ b/src/main/scala/leon/purescala/Extractors.scala
@@ -159,12 +159,12 @@ object Extractors {
       case FiniteSet(els, base) =>
         Some((els.toSeq, els => FiniteSet(els.toSet, base)))
       case FiniteMap(args, f, t) => {
-        val subArgs = args.flatMap { case (k, v) => Seq(k, v) }
+        val subArgs = args.flatMap { case (k, v) => Seq(k, v) }.toSeq
         val builder = (as: Seq[Expr]) => {
-          def rec(kvs: Seq[Expr]): Seq[(Expr, Expr)] = kvs match {
+          def rec(kvs: Seq[Expr]): Map[Expr, Expr] = kvs match {
             case Seq(k, v, t@_*) =>
-              (k, v) +: rec(t)
-            case Seq() => Seq()
+              Map(k -> v) ++ rec(t)
+            case Seq() => Map()
             case _ => sys.error("odd number of key/value expressions")
           }
           FiniteMap(rec(as), f, t)
diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala
index 620a294ed..cb1c3246d 100644
--- a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala
+++ b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala
@@ -63,7 +63,7 @@ trait SMTLIBCVC4Target extends SMTLIBTarget {
             RawArrayValue(tupleTypeWrap(from), Map(), fromSMT(elem, to))
 
           case MapType(k, v) =>
-            FiniteMap(Nil, k, v)
+            FiniteMap(Map(), k, v)
 
         }
 
@@ -76,7 +76,7 @@ trait SMTLIBCVC4Target extends SMTLIBTarget {
             RawArrayValue(tupleTypeWrap(from), Map(), fromSMT(elem, to))
 
           case MapType(k, v) =>
-            FiniteMap(Nil, k, v)
+            FiniteMap(Map(), k, v)
 
         }
 
@@ -92,7 +92,7 @@ trait SMTLIBCVC4Target extends SMTLIBTarget {
 
           case MapType(k, v) =>
             val FiniteMap(elems, k, v) = fromSMT(arr, otpe)
-            FiniteMap(elems :+ (fromSMT(key, k) -> fromSMT(elem, v)), k, v)
+            FiniteMap(elems + (fromSMT(key, k) -> fromSMT(elem, v)), k, v)
         }
 
       case (FunctionApplication(SimpleSymbol(SSymbol("singleton")), elems), Some(SetType(base))) =>
diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala
index 952dbeadc..9035700a0 100644
--- a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala
+++ b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala
@@ -213,7 +213,7 @@ trait SMTLIBTarget extends Interruptible {
       val elems = r.elems.flatMap {
         case (k, CaseClass(leonSome, Seq(x))) => Some(k -> x)
         case (k, _)                           => None
-      }.toSeq
+      }.toMap
       FiniteMap(elems, from, to)
 
     case other =>
diff --git a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala
index abf4e4169..9a06cff26 100644
--- a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala
+++ b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala
@@ -658,7 +658,7 @@ trait AbstractZ3Solver extends Solver {
                     val elems = r.elems.flatMap {
                       case (k, CaseClass(leonSome, Seq(x))) => Some(k -> x)
                       case (k, _) => None
-                    }.toSeq
+                    }.toMap
 
                     FiniteMap(elems, from, to)
                 }
diff --git a/src/test/scala/leon/integration/evaluators/EvaluatorSuite.scala b/src/test/scala/leon/integration/evaluators/EvaluatorSuite.scala
index 6eaf48138..810d9d40e 100644
--- a/src/test/scala/leon/integration/evaluators/EvaluatorSuite.scala
+++ b/src/test/scala/leon/integration/evaluators/EvaluatorSuite.scala
@@ -6,17 +6,10 @@ import leon._
 import leon.test._
 import leon.test.helpers._
 import leon.evaluators._
-
-import leon.utils.{TemporaryInputPhase, PreprocessingPhase}
-import leon.frontends.scalac.ExtractionPhase
-
 import leon.purescala.Common._
 import leon.purescala.Definitions._
 import leon.purescala.Expressions._
-import leon.purescala.DefOps._
 import leon.purescala.Types._
-import leon.purescala.Extractors._
-import leon.purescala.Constructors._
 import leon.codegen._
 
 class EvaluatorSuite extends LeonTestSuiteWithProgram with ExpressionsDSL {
@@ -292,8 +285,6 @@ class EvaluatorSuite extends LeonTestSuiteWithProgram with ExpressionsDSL {
         fail("Expected FiniteMap, got "+e)
     }
 
-    val em     = FiniteMap(Seq(), Int32Type, Int32Type)
-
     for(e <- allEvaluators) {
       eqMap(eval(e, fcall("Maps.finite0")()).res, Map())
       eqMap(eval(e, fcall("Maps.finite1")()).res, Map(i(1) -> i(2)))
-- 
GitLab