From 1b3e61b0e1b5372cf539afe9df46130ecaad091e Mon Sep 17 00:00:00 2001
From: Etienne Kneuss <ekneuss@gmail.com>
Date: Sat, 30 May 2015 02:11:07 +0200
Subject: [PATCH] Move Option to lang.Option, add Map.get/getOrElse/+

---
 doc/purescala.rst                             |  4 +++
 library/lang/Map.scala                        | 13 +++++++++-
 library/{collection => lang}/Option.scala     |  2 +-
 .../frontends/scalac/CodeExtraction.scala     | 25 +++++++++++++++++++
 .../leon/solvers/smtlib/SMTLIBSolver.scala    |  7 +++++-
 .../leon/solvers/smtlib/SMTLIBZ3Solver.scala  |  1 +
 src/main/scala/leon/utils/Library.scala       |  6 ++---
 .../frontends/passing/ImplicitDefs2.scala     |  1 -
 .../purescala/valid/MapGetOrElse2.scala       | 17 +++++++++++++
 .../purescala/valid/MapPlus.scala             | 18 +++++++++++++
 testcases/repair/Parser/Parser.scala          |  1 +
 testcases/repair/Parser/Parser1.scala         |  1 +
 testcases/repair/Parser/Parser2.scala         |  1 +
 testcases/repair/Parser/Parser3.scala         |  1 +
 testcases/repair/Parser/Parser4.scala         |  1 +
 testcases/repair/Parser/Parser5.scala         |  1 +
 16 files changed, 93 insertions(+), 7 deletions(-)
 rename library/{collection => lang}/Option.scala (98%)
 create mode 100644 src/test/resources/regression/verification/purescala/valid/MapGetOrElse2.scala
 create mode 100644 src/test/resources/regression/verification/purescala/valid/MapPlus.scala

diff --git a/doc/purescala.rst b/doc/purescala.rst
index 004d37a13..f3c40b5fd 100644
--- a/doc/purescala.rst
+++ b/doc/purescala.rst
@@ -297,6 +297,10 @@ Map
  m isDefinedAt index
  m contains index
  m.updated(index, value)
+ m + (index -> value)
+ m + (value, index)
+ m.get(index)
+ m.getOrElse(index, value2)
 
 
 Function
diff --git a/library/lang/Map.scala b/library/lang/Map.scala
index bd186b721..368c028ab 100644
--- a/library/lang/Map.scala
+++ b/library/lang/Map.scala
@@ -12,10 +12,21 @@ object Map {
 }
 
 @ignore
-case class Map[A, B](val theMap: scala.collection.immutable.Map[A,B]) {
+case class Map[A, B] (theMap: scala.collection.immutable.Map[A,B]) {
   def apply(k: A): B = theMap.apply(k)
   def ++(b: Map[A, B]): Map[A,B] = new Map (theMap ++ b.theMap)
   def updated(k: A, v: B): Map[A,B] = new Map(theMap.updated(k, v))
   def contains(a: A): Boolean = theMap.contains(a)
   def isDefinedAt(a: A): Boolean = contains(a)
+
+  def +(kv: (A, B)): Map[A,B] = updated(kv._1, kv._2)
+  def +(k: A, v: B): Map[A,B] = updated(k, v)
+
+  def getOrElse(k: A, default: B): B = get(k).getOrElse(default)
+
+  def get(k: A): Option[B] = if (contains(k)) {
+    Some[B](apply(k))
+  } else {
+    None[B]()
+  }
 }
diff --git a/library/collection/Option.scala b/library/lang/Option.scala
similarity index 98%
rename from library/collection/Option.scala
rename to library/lang/Option.scala
index 56feffdc0..0d77ef74e 100644
--- a/library/collection/Option.scala
+++ b/library/lang/Option.scala
@@ -1,6 +1,6 @@
 /* Copyright 2009-2015 EPFL, Lausanne */
 
-package leon.collection
+package leon.lang
 
 import leon.annotation._
 
diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala
index b37b6672f..204720092 100644
--- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala
+++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala
@@ -520,6 +520,7 @@ trait CodeExtraction extends ASTExtractors {
 
         parent.foreach(_.classDef.registerChildren(ccd))
 
+
         classesToClasses += sym -> ccd
 
         val fields = args.map { case (symbol, t) =>
@@ -1704,6 +1705,19 @@ trait CodeExtraction extends ASTExtractors {
             case (IsTyped(a1, MapType(_, vt)), "apply", List(a2)) =>
               MapGet(a1, a2)
 
+            case (IsTyped(a1, MapType(_, vt)), "get", List(a2)) =>
+              val someClass = CaseClassType(libraryCaseClass(sym.pos, "leon.lang.Some"), Seq(vt))
+              val noneClass = CaseClassType(libraryCaseClass(sym.pos, "leon.lang.None"), Seq(vt))
+
+              IfExpr(MapIsDefinedAt(a1, a2).setPos(current.pos),
+                CaseClass(someClass, Seq(MapGet(a1, a2).setPos(current.pos))).setPos(current.pos),
+                CaseClass(noneClass, Seq()).setPos(current.pos))
+
+            case (IsTyped(a1, MapType(_, vt)), "getOrElse", List(a2, a3)) =>
+              IfExpr(MapIsDefinedAt(a1, a2).setPos(current.pos),
+                MapGet(a1, a2).setPos(current.pos),
+                a3)
+
             case (IsTyped(a1, mt: MapType), "isDefinedAt", List(a2)) =>
               MapIsDefinedAt(a1, a2)
 
@@ -1713,6 +1727,17 @@ trait CodeExtraction extends ASTExtractors {
             case (IsTyped(a1, mt: MapType), "updated", List(k, v)) =>
               MapUnion(a1, FiniteMap(Seq((k, v)), mt.from, mt.to))
 
+            case (IsTyped(a1, mt: MapType), "+", List(k, v)) =>
+              MapUnion(a1, FiniteMap(Seq((k, v)), mt.from, mt.to))
+
+            case (IsTyped(a1, mt: MapType), "+", List(IsTyped(kv, TupleType(List(_, _))))) =>
+              kv match {
+                case Tuple(List(k, v)) =>
+                  MapUnion(a1, FiniteMap(Seq((k, v)), mt.from, mt.to))
+                case kv =>
+                  MapUnion(a1, FiniteMap(Seq((TupleSelect(kv, 1), TupleSelect(kv, 2))), mt.from, mt.to))
+              }
+
             case (IsTyped(a1, mt1: MapType), "++", List(IsTyped(a2, mt2: MapType)))  if mt1 == mt2 =>
               MapUnion(a1, a2)
 
diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala
index 3df3f89e5..ae77673d8 100644
--- a/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala
+++ b/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala
@@ -225,7 +225,12 @@ abstract class SMTLIBSolver(val context: LeonContext,
     case MapType(from, to) =>
       // We expect a RawArrayValue with keys in from and values in Option[to],
       // with default value == None
-      require(from == r.keyTpe && r.default == OptionManager.mkLeonNone(to))
+      if (r.default != OptionManager.mkLeonNone(to)) {
+        reporter.warning("Co-finite maps are not supported. (Default was "+r.default+")")
+        throw new IllegalArgumentException
+      }
+      require(r.keyTpe == from, s"Type error in solver model, expected $from, found ${r.keyTpe}")
+
       val elems = r.elems.flatMap {
         case (k, CaseClass(leonSome, Seq(x))) => Some(k -> x)
         case (k, _) => None
diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Solver.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Solver.scala
index 37efdd120..cc6e2161c 100644
--- a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Solver.scala
+++ b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Solver.scala
@@ -156,6 +156,7 @@ class SMTLIBZ3Solver(context: LeonContext, program: Program) extends SMTLIBSolve
     case DefineFun(SMTFunDef(a, Seq(SortedVar(arg, akind)), rkind, body)) =>
       val (argTpe, retTpe) = tpe match {
         case SetType(base) => (base, BooleanType)
+        case MapType(from, to) => (from, OptionManager.leonOptionType(to))
         case ArrayType(base) => (Int32Type, base)
         case FunctionType(args, ret) => (tupleTypeWrap(args), ret)
         case RawArrayType(from, to) => (from, to)
diff --git a/src/main/scala/leon/utils/Library.scala b/src/main/scala/leon/utils/Library.scala
index f22169ba9..e7c76c831 100644
--- a/src/main/scala/leon/utils/Library.scala
+++ b/src/main/scala/leon/utils/Library.scala
@@ -11,9 +11,9 @@ case class Library(pgm: Program) {
   lazy val Cons = lookup("leon.collection.Cons") collect { case ccd : CaseClassDef => ccd }
   lazy val Nil  = lookup("leon.collection.Nil") collect { case ccd : CaseClassDef => ccd }
 
-  lazy val Option = lookup("leon.collection.Option") collect { case acd : AbstractClassDef => acd }
-  lazy val Some = lookup("leon.collection.Some") collect { case ccd : CaseClassDef => ccd }
-  lazy val None = lookup("leon.collection.None") collect { case ccd : CaseClassDef => ccd }
+  lazy val Option = lookup("leon.lang.Option") collect { case acd : AbstractClassDef => acd }
+  lazy val Some = lookup("leon.lang.Some") collect { case ccd : CaseClassDef => ccd }
+  lazy val None = lookup("leon.lang.None") collect { case ccd : CaseClassDef => ccd }
 
   lazy val String = lookup("leon.lang.string.String") collect { case ccd : CaseClassDef => ccd }
 
diff --git a/src/test/resources/regression/frontends/passing/ImplicitDefs2.scala b/src/test/resources/regression/frontends/passing/ImplicitDefs2.scala
index a157e5dd6..9f4b1a83f 100644
--- a/src/test/resources/regression/frontends/passing/ImplicitDefs2.scala
+++ b/src/test/resources/regression/frontends/passing/ImplicitDefs2.scala
@@ -3,7 +3,6 @@ package leon.blup
 import leon._
 import leon.lang._
 import leon.annotation._
-import leon.collection.{Option,None,Some}
 //import leon.proof._
 
 // FIXME: the following should go into the leon.proof package object.
diff --git a/src/test/resources/regression/verification/purescala/valid/MapGetOrElse2.scala b/src/test/resources/regression/verification/purescala/valid/MapGetOrElse2.scala
new file mode 100644
index 000000000..7b1e19418
--- /dev/null
+++ b/src/test/resources/regression/verification/purescala/valid/MapGetOrElse2.scala
@@ -0,0 +1,17 @@
+import leon.lang._
+
+object MapGetOrElse2 {
+  def test1(a: Map[BigInt, BigInt]) = {
+    require(!(a contains 0))
+    a.get(0)
+  } ensuring {
+    _ == None[BigInt]()
+  }
+
+  def test2(a: Map[BigInt, BigInt]) = {
+    require(!(a contains 0))
+    a.getOrElse(0, 0)
+  } ensuring {
+    _ == 0
+  }
+}
diff --git a/src/test/resources/regression/verification/purescala/valid/MapPlus.scala b/src/test/resources/regression/verification/purescala/valid/MapPlus.scala
new file mode 100644
index 000000000..86a3d6d22
--- /dev/null
+++ b/src/test/resources/regression/verification/purescala/valid/MapPlus.scala
@@ -0,0 +1,18 @@
+import leon.lang._
+
+object MapGetPlus {
+  def test1(a: Map[BigInt, BigInt]) = {
+    require(!(a contains 0))
+    val b = a + (0, 1)
+    val c = a + (BigInt(0) -> BigInt(1))
+    b(0) == c(0)
+  }.holds
+
+  def test2(a: Map[BigInt, BigInt]) = {
+    require(!(a contains 0))
+    val t = (BigInt(0) -> BigInt(1))
+    val b = a + (0, 1)
+    val c = a + t
+    b(0) == c(0)
+  }.holds
+}
diff --git a/testcases/repair/Parser/Parser.scala b/testcases/repair/Parser/Parser.scala
index 876652d8e..8d838f470 100644
--- a/testcases/repair/Parser/Parser.scala
+++ b/testcases/repair/Parser/Parser.scala
@@ -1,4 +1,5 @@
 import leon._
+import leon.lang._
 import leon.collection._
 
 object Parser {
diff --git a/testcases/repair/Parser/Parser1.scala b/testcases/repair/Parser/Parser1.scala
index 1a12b820c..5fd18245f 100644
--- a/testcases/repair/Parser/Parser1.scala
+++ b/testcases/repair/Parser/Parser1.scala
@@ -1,4 +1,5 @@
 import leon._
+import leon.lang._
 import leon.collection._
 
 object Parser {
diff --git a/testcases/repair/Parser/Parser2.scala b/testcases/repair/Parser/Parser2.scala
index 24fef5686..919132e04 100644
--- a/testcases/repair/Parser/Parser2.scala
+++ b/testcases/repair/Parser/Parser2.scala
@@ -1,4 +1,5 @@
 import leon._
+import leon.lang._
 import leon.collection._
 
 object Parser {
diff --git a/testcases/repair/Parser/Parser3.scala b/testcases/repair/Parser/Parser3.scala
index dcc8107f3..cf11c4d3f 100644
--- a/testcases/repair/Parser/Parser3.scala
+++ b/testcases/repair/Parser/Parser3.scala
@@ -1,4 +1,5 @@
 import leon._
+import leon.lang._
 import leon.collection._
 
 object Parser {
diff --git a/testcases/repair/Parser/Parser4.scala b/testcases/repair/Parser/Parser4.scala
index dcc8107f3..cf11c4d3f 100644
--- a/testcases/repair/Parser/Parser4.scala
+++ b/testcases/repair/Parser/Parser4.scala
@@ -1,4 +1,5 @@
 import leon._
+import leon.lang._
 import leon.collection._
 
 object Parser {
diff --git a/testcases/repair/Parser/Parser5.scala b/testcases/repair/Parser/Parser5.scala
index dcc8107f3..cf11c4d3f 100644
--- a/testcases/repair/Parser/Parser5.scala
+++ b/testcases/repair/Parser/Parser5.scala
@@ -1,4 +1,5 @@
 import leon._
+import leon.lang._
 import leon.collection._
 
 object Parser {
-- 
GitLab