From f81cd4e6f6c0d0b3227e0def4ba32119c8244e7f Mon Sep 17 00:00:00 2001
From: Regis Blanc <regwblanc@gmail.com>
Date: Wed, 8 Jul 2015 18:10:47 +0200
Subject: [PATCH] conditional probabilities

---
 testcases/proof/proba/Coins.scala | 58 +++++++++++++++++++++++++++++++
 1 file changed, 58 insertions(+)

diff --git a/testcases/proof/proba/Coins.scala b/testcases/proof/proba/Coins.scala
index c35115372..ff8bf9122 100644
--- a/testcases/proof/proba/Coins.scala
+++ b/testcases/proof/proba/Coins.scala
@@ -38,6 +38,55 @@ object Coins {
   def isIndependent(dist: CoinsJoinDist): Boolean =
     join(firstCoin(dist), secondCoin(dist)) == dist
 
+  def isEquivalent(dist1: CoinsJoinDist, dist2: CoinsJoinDist): Boolean = {
+    require(isDist(dist1) && dist1.hh > 0 && isDist(dist2) && dist2.hh > 0)
+
+    dist1.hh*dist2.hh == dist2.hh*dist1.hh &&
+    dist1.ht*dist2.hh == dist2.ht*dist1.hh &&
+    dist1.th*dist2.hh == dist2.th*dist1.hh &&
+    dist1.tt*dist2.hh == dist2.tt*dist1.hh
+  }
+
+
+  case class CoinCondDist(ifHead: CoinDist, ifTail: CoinDist)
+
+  def condByFirstCoin(dist: CoinsJoinDist): CoinCondDist = {
+    CoinCondDist(
+      CoinDist(
+        dist.hh*(dist.th + dist.tt), //probability of head if head
+        dist.ht*(dist.th + dist.tt)  //probability of tail if head
+      ),
+      CoinDist(
+        dist.th*(dist.hh + dist.ht), //probability of head if tail
+        dist.tt*(dist.hh + dist.ht)  //probability of tail if tail
+      )
+    )
+  }
+
+  def combine(cond: CoinCondDist, dist: CoinDist): CoinsJoinDist = {
+    require(isDist(dist) && dist.pHead > 0 && dist.pTail > 0)
+
+    val hh = cond.ifHead.pHead * dist.pHead
+    val ht = cond.ifHead.pTail * dist.pHead
+    val th = cond.ifTail.pHead * dist.pTail
+    val tt = cond.ifTail.pTail * dist.pTail
+    CoinsJoinDist(hh, ht, th, tt)
+  }
+  
+  def condIsSound(dist: CoinsJoinDist): Boolean = {
+    require(isDist(dist) && dist.hh > 0 && dist.ht > 0 && dist.th > 0 && dist.tt > 0)
+
+    val computedDist = combine(condByFirstCoin(dist), firstCoin(dist))
+    isEquivalent(dist, computedDist)
+  } holds
+
+
+  //should be INVALID
+  def anyDistributionsNotEquivalent(dist1: CoinsJoinDist, dist2: CoinsJoinDist): Boolean = {
+    require(isDist(dist1) && isDist(dist2) && dist1.hh > 0 && dist2.hh > 0)
+    isEquivalent(dist1, dist2)
+  } holds
+
   //sum modulo: face is 0, tail is 1
   def sum(coin1: CoinDist, coin2: CoinDist): CoinDist = {
     require(isDist(coin1) && isDist(coin2))
@@ -55,6 +104,15 @@ object Coins {
     )
   }
 
+  
+
+
+
+
+  /***************************************************
+   *          properties of sum operation            *
+   ***************************************************/
+
   def sumIsUniform1(coin1: CoinDist, coin2: CoinDist): Boolean = {
     require(isDist(coin1) && isDist(coin2) && isUniform(coin1) && isUniform(coin2))
     val dist = sum(coin1, coin2)
-- 
GitLab