From f81cd4e6f6c0d0b3227e0def4ba32119c8244e7f Mon Sep 17 00:00:00 2001 From: Regis Blanc <regwblanc@gmail.com> Date: Wed, 8 Jul 2015 18:10:47 +0200 Subject: [PATCH] conditional probabilities --- testcases/proof/proba/Coins.scala | 58 +++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/testcases/proof/proba/Coins.scala b/testcases/proof/proba/Coins.scala index c35115372..ff8bf9122 100644 --- a/testcases/proof/proba/Coins.scala +++ b/testcases/proof/proba/Coins.scala @@ -38,6 +38,55 @@ object Coins { def isIndependent(dist: CoinsJoinDist): Boolean = join(firstCoin(dist), secondCoin(dist)) == dist + def isEquivalent(dist1: CoinsJoinDist, dist2: CoinsJoinDist): Boolean = { + require(isDist(dist1) && dist1.hh > 0 && isDist(dist2) && dist2.hh > 0) + + dist1.hh*dist2.hh == dist2.hh*dist1.hh && + dist1.ht*dist2.hh == dist2.ht*dist1.hh && + dist1.th*dist2.hh == dist2.th*dist1.hh && + dist1.tt*dist2.hh == dist2.tt*dist1.hh + } + + + case class CoinCondDist(ifHead: CoinDist, ifTail: CoinDist) + + def condByFirstCoin(dist: CoinsJoinDist): CoinCondDist = { + CoinCondDist( + CoinDist( + dist.hh*(dist.th + dist.tt), //probability of head if head + dist.ht*(dist.th + dist.tt) //probability of tail if head + ), + CoinDist( + dist.th*(dist.hh + dist.ht), //probability of head if tail + dist.tt*(dist.hh + dist.ht) //probability of tail if tail + ) + ) + } + + def combine(cond: CoinCondDist, dist: CoinDist): CoinsJoinDist = { + require(isDist(dist) && dist.pHead > 0 && dist.pTail > 0) + + val hh = cond.ifHead.pHead * dist.pHead + val ht = cond.ifHead.pTail * dist.pHead + val th = cond.ifTail.pHead * dist.pTail + val tt = cond.ifTail.pTail * dist.pTail + CoinsJoinDist(hh, ht, th, tt) + } + + def condIsSound(dist: CoinsJoinDist): Boolean = { + require(isDist(dist) && dist.hh > 0 && dist.ht > 0 && dist.th > 0 && dist.tt > 0) + + val computedDist = combine(condByFirstCoin(dist), firstCoin(dist)) + isEquivalent(dist, computedDist) + } holds + + + //should be INVALID + def anyDistributionsNotEquivalent(dist1: CoinsJoinDist, dist2: CoinsJoinDist): Boolean = { + require(isDist(dist1) && isDist(dist2) && dist1.hh > 0 && dist2.hh > 0) + isEquivalent(dist1, dist2) + } holds + //sum modulo: face is 0, tail is 1 def sum(coin1: CoinDist, coin2: CoinDist): CoinDist = { require(isDist(coin1) && isDist(coin2)) @@ -55,6 +104,15 @@ object Coins { ) } + + + + + + /*************************************************** + * properties of sum operation * + ***************************************************/ + def sumIsUniform1(coin1: CoinDist, coin2: CoinDist): Boolean = { require(isDist(coin1) && isDist(coin2) && isUniform(coin1) && isUniform(coin2)) val dist = sum(coin1, coin2) -- GitLab