From 0994ddb6a66471b9da803d0b3b005ded23ed5092 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9gis=20Blanc?= <regwblanc@gmail.com>
Date: Sat, 9 Apr 2016 02:41:14 +0200
Subject: [PATCH] test and fix Bijection

---
 src/main/scala/leon/utils/Bijection.scala     | 41 +++++++----
 .../leon/unit/utils/BijectionSuite.scala      | 71 +++++++++++++++++++
 2 files changed, 100 insertions(+), 12 deletions(-)
 create mode 100644 src/test/scala/leon/unit/utils/BijectionSuite.scala

diff --git a/src/main/scala/leon/utils/Bijection.scala b/src/main/scala/leon/utils/Bijection.scala
index 251ebde95..b6f999034 100644
--- a/src/main/scala/leon/utils/Bijection.scala
+++ b/src/main/scala/leon/utils/Bijection.scala
@@ -9,6 +9,21 @@ object Bijection {
   def apply[A, B](a: (A, B)*): Bijection[A, B] = apply(a.toSeq)
 }
 
+/** Represents a Bijection between A and B.
+  *
+  * This basically maintains a bi-directional mapping
+  * between A and B. This is a common operation throughout
+  * Leon, that would be usually done by keeping a mapping and
+  * the corresponding reverse mapping. This class abstracts
+  * away the details.
+  *
+  * This is a true Bijection, which means that at any point in
+  * time there is a one-to-one correspondance between each element in
+  * A and B. In particular, adding two successive mapping from a1 to b1,
+  * and then from a1 to b2, will represent the final bijection of a1 <-> b2.
+  * calling getA(b1) should return None at that point, even though it used
+  * to map to a1 before.
+  */
 class Bijection[A, B] extends Iterable[(A, B)] {
   protected val a2b = MutableMap[A, B]()
   protected val b2a = MutableMap[B, A]()
@@ -16,6 +31,8 @@ class Bijection[A, B] extends Iterable[(A, B)] {
   def iterator = a2b.iterator
 
   def +=(a: A, b: B): Unit = {
+    getB(a).foreach(ob => b2a.remove(ob))
+    getA(b).foreach(oa => a2b.remove(oa))
     a2b += a -> b
     b2a += b -> a
   }
@@ -34,19 +51,19 @@ class Bijection[A, B] extends Iterable[(A, B)] {
     b2a.clear()
   }
 
-  def getA(b: B) = b2a.get(b)
-  def getB(a: A) = a2b.get(a)
+  def getA(b: B): Option[A] = b2a.get(b)
+  def getB(a: A): Option[B] = a2b.get(a)
   
-  def getAorElse(b: B, orElse: =>A) = b2a.getOrElse(b, orElse)
-  def getBorElse(a: A, orElse: =>B) = a2b.getOrElse(a, orElse)
+  def getAorElse(b: B, orElse: =>A): A = b2a.getOrElse(b, orElse)
+  def getBorElse(a: A, orElse: =>B): B = a2b.getOrElse(a, orElse)
 
-  def toA(b: B) = getA(b).get
-  def toB(a: A) = getB(a).get
+  def toA(b: B): A = getA(b).get
+  def toB(a: A): B = getB(a).get
 
-  def fromA(a: A) = getB(a).get
-  def fromB(b: B) = getA(b).get
+  def fromA(a: A): B = getB(a).get
+  def fromB(b: B): A = getA(b).get
 
-  def cachedB(a: A)(c: => B) = {
+  def cachedB(a: A)(c: => B): B = {
     getB(a).getOrElse {
       val res = c
       this += a -> res
@@ -54,7 +71,7 @@ class Bijection[A, B] extends Iterable[(A, B)] {
     }
   }
 
-  def cachedA(b: B)(c: => A) = {
+  def cachedA(b: B)(c: => A): A = {
     getA(b).getOrElse {
       val res = c
       this += res -> b
@@ -62,8 +79,8 @@ class Bijection[A, B] extends Iterable[(A, B)] {
     }
   }
 
-  def containsA(a: A) = a2b contains a
-  def containsB(b: B) = b2a contains b
+  def containsA(a: A): Boolean = a2b contains a
+  def containsB(b: B): Boolean = b2a contains b
 
   def aSet = a2b.keySet
   def bSet = b2a.keySet
diff --git a/src/test/scala/leon/unit/utils/BijectionSuite.scala b/src/test/scala/leon/unit/utils/BijectionSuite.scala
new file mode 100644
index 000000000..05be7f981
--- /dev/null
+++ b/src/test/scala/leon/unit/utils/BijectionSuite.scala
@@ -0,0 +1,71 @@
+/* Copyright 2009-2016 EPFL, Lausanne */
+
+package leon.unit.utils
+
+import leon.utils._
+import org.scalatest._
+
+class BijectionSuite extends FunSuite {
+
+  test("Empty Bijection returns None") {
+    val b = new Bijection[Int, Int]
+    assert(b.getA(0) === None)
+    assert(b.getA(1) === None)
+    assert(b.getB(0) === None)
+    assert(b.getB(1) === None)
+  }
+
+  test("Bijection with one element works both way") {
+    val b = new Bijection[Int, Int]
+    b += (12 -> 33)
+
+    assert(b.getA(33) === Some(12))
+    assert(b.getA(1) === None)
+    assert(b.getB(12) === Some(33))
+    assert(b.getB(1) === None)
+  }
+
+  test("Bijection latest update prevails") {
+    val b = new Bijection[Int, Int]
+    b += (12 -> 33)
+    b += (12 -> 34)
+
+    assert(b.getB(1) === None)
+    assert(b.getB(12) === Some(34))
+  }
+
+  test("Bijection latest update should delete all previous existing mappings") {
+    val b = new Bijection[Int, Int]
+    b += (12 -> 33)
+    b += (12 -> 34)
+
+    assert(b.getB(12) === Some(34))
+    assert(b.getA(34) === Some(12))
+    assert(b.getA(33) === None)
+
+    val b2 = new Bijection[Int, Int]
+    b2 += (12 -> 33)
+    b2 += (11 -> 33)
+
+    assert(b2.getB(12) === None)
+    assert(b2.getB(11) === Some(33))
+    assert(b2.getA(33) === Some(11))
+  }
+
+  test("Bijection multiple mixed updates should maintain the invariant") {
+    val b = new Bijection[Int, Int]
+    b += (12 -> 33)
+    b += (13 -> 34)
+    b += (12 -> 34)
+    b += (11 -> 33)
+    b += (13 -> 32)
+
+    assert(b.getB(12) == Some(34))
+    assert(b.getB(13) == Some(32))
+    assert(b.getB(11) == Some(33))
+    assert(b.getA(34) == Some(12))
+    assert(b.getA(32) == Some(13))
+    assert(b.getA(33) == Some(11))
+  }
+
+}
-- 
GitLab