From 8b18f6816a102e38669fc5a932365213b080ea0b Mon Sep 17 00:00:00 2001
From: Nicolas Voirol <voirol.nicolas@gmail.com>
Date: Mon, 7 Nov 2016 09:41:26 +0100
Subject: [PATCH] Bag encoder that is sound for models

---
 .../inox/solvers/theories/BagEncoder.scala    | 37 ++++++++++++-------
 1 file changed, 24 insertions(+), 13 deletions(-)

diff --git a/src/main/scala/inox/solvers/theories/BagEncoder.scala b/src/main/scala/inox/solvers/theories/BagEncoder.scala
index f984320e9..86e527993 100644
--- a/src/main/scala/inox/solvers/theories/BagEncoder.scala
+++ b/src/main/scala/inox/solvers/theories/BagEncoder.scala
@@ -22,16 +22,18 @@ trait BagEncoder extends TheoryEncoder {
 
   val Bag = T(BagID)
 
+  private def get(bag: Expr, x: Expr): Expr = {
+    if_ (bag.getField(keys) contains x) {
+      bag.getField(f)(x)
+    } else_ {
+      E(BigInt(0))
+    }
+  }
+
   val GetID = FreshIdentifier("get")
   val Get = mkFunDef(GetID)("T") { case Seq(aT) => (
     Seq("bag" :: Bag(aT), "x" :: aT),
-    IntegerType, { case Seq(bag, x) =>
-      if_ (bag.getField(keys) contains x) {
-        bag.getField(f)(x)
-      } else_ {
-        E(BigInt(0))
-      }
-    })
+    IntegerType, { case Seq(bag, x) => get(bag, x) })
   }
 
   val AddID = FreshIdentifier("add")
@@ -39,7 +41,7 @@ trait BagEncoder extends TheoryEncoder {
     Seq("bag" :: Bag(aT), "x" :: aT),
     Bag(aT), { case Seq(bag, x) => Bag(aT)(
       bag.getField(keys) insert x,
-      \("y" :: aT)(y => bag.getField(f)(y) + {
+      \("y" :: aT)(y => get(bag, y) + {
         if_ (y === x) { E(BigInt(1)) } else_ { E(BigInt(0)) }
       }))
     })
@@ -50,7 +52,7 @@ trait BagEncoder extends TheoryEncoder {
     Seq("b1" :: Bag(aT), "b2" :: Bag(aT)),
     Bag(aT), { case Seq(b1, b2) => Bag(aT)(
       b1.getField(keys) ++ b2.getField(keys),
-      \("y" :: aT)(y => b1.getField(f)(y) + b2.getField(f)(y)))
+      \("y" :: aT)(y => get(b1, y) + get(b2, y)))
     })
   }
 
@@ -59,7 +61,7 @@ trait BagEncoder extends TheoryEncoder {
     Seq("b1" :: Bag(aT), "b2" :: Bag(aT)),
     Bag(aT), { case Seq(b1, b2) => Bag(aT)(
       b1.getField(keys),
-      \("y" :: aT)(y => let("res" :: IntegerType, b1.getField(f)(y) - b2.getField(f)(y)) {
+      \("y" :: aT)(y => let("res" :: IntegerType, get(b1, y) - get(b2, y)) {
         res => if_ (res < E(BigInt(0))) { E(BigInt(0)) } else_ { res }
       }))
     })
@@ -70,8 +72,8 @@ trait BagEncoder extends TheoryEncoder {
     Seq("b1" :: Bag(aT), "b2" :: Bag(aT)),
     Bag(aT), { case Seq(b1, b2) => Bag(aT)(
       b1.getField(keys) & b2.getField(keys),
-      \("y" :: aT)(y => let("r1" :: IntegerType, b1.getField(f)(y)) { r1 =>
-        let("r2" :: IntegerType, b2.getField(f)(y)) { r2 =>
+      \("y" :: aT)(y => let("r1" :: IntegerType, get(b1, y)) { r1 =>
+        let("r2" :: IntegerType, get(b2, y)) { r2 =>
           if_ (r1 > r2) { r2 } else_ { r1 }
         }
       }))
@@ -82,7 +84,16 @@ trait BagEncoder extends TheoryEncoder {
   val BagEquals = mkFunDef(EqualsID)("T") { case Seq(aT) => (
     Seq("b1" :: Bag(aT), "b2" :: Bag(aT)),
     BooleanType, { case Seq(b1, b2) =>
-      forall("x" :: aT)(x => b1.getField(f)(x) === b2.getField(f)(x))
+      forall("x" :: aT) { x =>
+        let("f1x" :: aT, b1.getField(f)(x)) { f1x =>
+          let("f2x" :: aT, b2.getField(f)(x)) { f2x =>
+            f1x === f2x ||
+            (!(b1.getField(keys) contains x) && f2x === E(BigInt(0))) ||
+            (f1x === E(BigInt(0)) && !(b2.getField(keys) contains x)) ||
+            (!(b1.getField(keys) contains x) && !(b2.getField(keys) contains x))
+          }
+        }
+      }
     })
   }
 
-- 
GitLab