From 0994ddb6a66471b9da803d0b3b005ded23ed5092 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9gis=20Blanc?= <regwblanc@gmail.com> Date: Sat, 9 Apr 2016 02:41:14 +0200 Subject: [PATCH] test and fix Bijection --- src/main/scala/leon/utils/Bijection.scala | 41 +++++++---- .../leon/unit/utils/BijectionSuite.scala | 71 +++++++++++++++++++ 2 files changed, 100 insertions(+), 12 deletions(-) create mode 100644 src/test/scala/leon/unit/utils/BijectionSuite.scala diff --git a/src/main/scala/leon/utils/Bijection.scala b/src/main/scala/leon/utils/Bijection.scala index 251ebde95..b6f999034 100644 --- a/src/main/scala/leon/utils/Bijection.scala +++ b/src/main/scala/leon/utils/Bijection.scala @@ -9,6 +9,21 @@ object Bijection { def apply[A, B](a: (A, B)*): Bijection[A, B] = apply(a.toSeq) } +/** Represents a Bijection between A and B. + * + * This basically maintains a bi-directional mapping + * between A and B. This is a common operation throughout + * Leon, that would be usually done by keeping a mapping and + * the corresponding reverse mapping. This class abstracts + * away the details. + * + * This is a true Bijection, which means that at any point in + * time there is a one-to-one correspondance between each element in + * A and B. In particular, adding two successive mapping from a1 to b1, + * and then from a1 to b2, will represent the final bijection of a1 <-> b2. + * calling getA(b1) should return None at that point, even though it used + * to map to a1 before. + */ class Bijection[A, B] extends Iterable[(A, B)] { protected val a2b = MutableMap[A, B]() protected val b2a = MutableMap[B, A]() @@ -16,6 +31,8 @@ class Bijection[A, B] extends Iterable[(A, B)] { def iterator = a2b.iterator def +=(a: A, b: B): Unit = { + getB(a).foreach(ob => b2a.remove(ob)) + getA(b).foreach(oa => a2b.remove(oa)) a2b += a -> b b2a += b -> a } @@ -34,19 +51,19 @@ class Bijection[A, B] extends Iterable[(A, B)] { b2a.clear() } - def getA(b: B) = b2a.get(b) - def getB(a: A) = a2b.get(a) + def getA(b: B): Option[A] = b2a.get(b) + def getB(a: A): Option[B] = a2b.get(a) - def getAorElse(b: B, orElse: =>A) = b2a.getOrElse(b, orElse) - def getBorElse(a: A, orElse: =>B) = a2b.getOrElse(a, orElse) + def getAorElse(b: B, orElse: =>A): A = b2a.getOrElse(b, orElse) + def getBorElse(a: A, orElse: =>B): B = a2b.getOrElse(a, orElse) - def toA(b: B) = getA(b).get - def toB(a: A) = getB(a).get + def toA(b: B): A = getA(b).get + def toB(a: A): B = getB(a).get - def fromA(a: A) = getB(a).get - def fromB(b: B) = getA(b).get + def fromA(a: A): B = getB(a).get + def fromB(b: B): A = getA(b).get - def cachedB(a: A)(c: => B) = { + def cachedB(a: A)(c: => B): B = { getB(a).getOrElse { val res = c this += a -> res @@ -54,7 +71,7 @@ class Bijection[A, B] extends Iterable[(A, B)] { } } - def cachedA(b: B)(c: => A) = { + def cachedA(b: B)(c: => A): A = { getA(b).getOrElse { val res = c this += res -> b @@ -62,8 +79,8 @@ class Bijection[A, B] extends Iterable[(A, B)] { } } - def containsA(a: A) = a2b contains a - def containsB(b: B) = b2a contains b + def containsA(a: A): Boolean = a2b contains a + def containsB(b: B): Boolean = b2a contains b def aSet = a2b.keySet def bSet = b2a.keySet diff --git a/src/test/scala/leon/unit/utils/BijectionSuite.scala b/src/test/scala/leon/unit/utils/BijectionSuite.scala new file mode 100644 index 000000000..05be7f981 --- /dev/null +++ b/src/test/scala/leon/unit/utils/BijectionSuite.scala @@ -0,0 +1,71 @@ +/* Copyright 2009-2016 EPFL, Lausanne */ + +package leon.unit.utils + +import leon.utils._ +import org.scalatest._ + +class BijectionSuite extends FunSuite { + + test("Empty Bijection returns None") { + val b = new Bijection[Int, Int] + assert(b.getA(0) === None) + assert(b.getA(1) === None) + assert(b.getB(0) === None) + assert(b.getB(1) === None) + } + + test("Bijection with one element works both way") { + val b = new Bijection[Int, Int] + b += (12 -> 33) + + assert(b.getA(33) === Some(12)) + assert(b.getA(1) === None) + assert(b.getB(12) === Some(33)) + assert(b.getB(1) === None) + } + + test("Bijection latest update prevails") { + val b = new Bijection[Int, Int] + b += (12 -> 33) + b += (12 -> 34) + + assert(b.getB(1) === None) + assert(b.getB(12) === Some(34)) + } + + test("Bijection latest update should delete all previous existing mappings") { + val b = new Bijection[Int, Int] + b += (12 -> 33) + b += (12 -> 34) + + assert(b.getB(12) === Some(34)) + assert(b.getA(34) === Some(12)) + assert(b.getA(33) === None) + + val b2 = new Bijection[Int, Int] + b2 += (12 -> 33) + b2 += (11 -> 33) + + assert(b2.getB(12) === None) + assert(b2.getB(11) === Some(33)) + assert(b2.getA(33) === Some(11)) + } + + test("Bijection multiple mixed updates should maintain the invariant") { + val b = new Bijection[Int, Int] + b += (12 -> 33) + b += (13 -> 34) + b += (12 -> 34) + b += (11 -> 33) + b += (13 -> 32) + + assert(b.getB(12) == Some(34)) + assert(b.getB(13) == Some(32)) + assert(b.getB(11) == Some(33)) + assert(b.getA(34) == Some(12)) + assert(b.getA(32) == Some(13)) + assert(b.getA(33) == Some(11)) + } + +} -- GitLab