Skip to content
Snippets Groups Projects
Commit 5c500290 authored by Regis Blanc's avatar Regis Blanc
Browse files

rewrite coins with rational

parent f81cd4e6
No related branches found
No related tags found
No related merge requests found
...@@ -3,24 +3,29 @@ import leon.lang._ ...@@ -3,24 +3,29 @@ import leon.lang._
object Coins { object Coins {
/* case class CoinDist(pHead: Rational) {
* number of outcomes for each face def pTail: Rational = Rational(1) - pHead
*/ }
case class CoinDist(pHead: BigInt, pTail: BigInt)
def isDist(dist: CoinDist): Boolean =
dist.pHead.isRational && dist.pHead >= Rational(0) && dist.pHead <= Rational(1)
case class CoinsJoinDist(hh: Rational, ht: Rational, th: Rational, tt: Rational)
def isDist(dist: CoinsJoinDist): Boolean =
dist.hh.isRational && dist.hh >= Rational(0) && dist.hh <= Rational(1) &&
dist.ht.isRational && dist.ht >= Rational(0) && dist.ht <= Rational(1) &&
dist.th.isRational && dist.th >= Rational(0) && dist.th <= Rational(1) &&
dist.tt.isRational && dist.tt >= Rational(0) && dist.tt <= Rational(1) &&
(dist.hh + dist.ht + dist.th + dist.tt) ~ Rational(1)
case class CoinsJoinDist(hh: BigInt, ht: BigInt, th: BigInt, tt: BigInt)
def isUniform(dist: CoinDist): Boolean = { def isUniform(dist: CoinDist): Boolean = {
require(isDist(dist)) require(isDist(dist))
dist.pHead == dist.pTail dist.pHead ~ Rational(1, 2)
} }
def isDist(dist: CoinDist): Boolean =
dist.pHead >= 0 && dist.pTail >= 0 && (dist.pHead > 0 || dist.pTail > 0)
def isDist(dist: CoinsJoinDist): Boolean =
dist.hh >= 0 && dist.ht >= 0 && dist.th >= 0 && dist.tt >= 0 &&
(dist.hh > 0 || dist.ht > 0 || dist.th > 0 || dist.tt > 0)
def join(c1: CoinDist, c2: CoinDist): CoinsJoinDist = def join(c1: CoinDist, c2: CoinDist): CoinsJoinDist =
CoinsJoinDist( CoinsJoinDist(
...@@ -29,82 +34,76 @@ object Coins { ...@@ -29,82 +34,76 @@ object Coins {
c1.pTail*c2.pHead, c1.pTail*c2.pHead,
c1.pTail*c2.pTail) c1.pTail*c2.pTail)
def firstCoin(dist: CoinsJoinDist): CoinDist = def firstCoin(dist: CoinsJoinDist): CoinDist = {
CoinDist(dist.hh + dist.ht, dist.th + dist.tt) CoinDist(dist.hh + dist.ht)
} ensuring(res => res.pTail ~ (dist.th + dist.tt))
def secondCoin(dist: CoinsJoinDist): CoinDist = def secondCoin(dist: CoinsJoinDist): CoinDist = {
CoinDist(dist.hh + dist.th, dist.ht + dist.tt) CoinDist(dist.hh + dist.th)
} ensuring(res => res.pTail ~ (dist.ht + dist.tt))
def isIndependent(dist: CoinsJoinDist): Boolean = def isIndependent(dist: CoinsJoinDist): Boolean =
join(firstCoin(dist), secondCoin(dist)) == dist join(firstCoin(dist), secondCoin(dist)) == dist
def isEquivalent(dist1: CoinsJoinDist, dist2: CoinsJoinDist): Boolean = { def isEquivalent(dist1: CoinsJoinDist, dist2: CoinsJoinDist): Boolean = {
require(isDist(dist1) && dist1.hh > 0 && isDist(dist2) && dist2.hh > 0) require(isDist(dist1) && isDist(dist2))
dist1.hh*dist2.hh == dist2.hh*dist1.hh && (dist1.hh ~ dist2.hh) &&
dist1.ht*dist2.hh == dist2.ht*dist1.hh && (dist1.ht ~ dist2.ht) &&
dist1.th*dist2.hh == dist2.th*dist1.hh && (dist1.th ~ dist2.th) &&
dist1.tt*dist2.hh == dist2.tt*dist1.hh (dist1.tt ~ dist2.tt)
} }
case class CoinCondDist(ifHead: CoinDist, ifTail: CoinDist) //case class CoinCondDist(ifHead: CoinDist, ifTail: CoinDist)
def condByFirstCoin(dist: CoinsJoinDist): CoinCondDist = { //def condByFirstCoin(dist: CoinsJoinDist): CoinCondDist = {
CoinCondDist( // CoinCondDist(
CoinDist( // CoinDist(
dist.hh*(dist.th + dist.tt), //probability of head if head // dist.hh*(dist.th + dist.tt), //probability of head if head
dist.ht*(dist.th + dist.tt) //probability of tail if head // dist.ht*(dist.th + dist.tt) //probability of tail if head
), // ),
CoinDist( // CoinDist(
dist.th*(dist.hh + dist.ht), //probability of head if tail // dist.th*(dist.hh + dist.ht), //probability of head if tail
dist.tt*(dist.hh + dist.ht) //probability of tail if tail // dist.tt*(dist.hh + dist.ht) //probability of tail if tail
) // )
) // )
} //}
def combine(cond: CoinCondDist, dist: CoinDist): CoinsJoinDist = { //def combine(cond: CoinCondDist, dist: CoinDist): CoinsJoinDist = {
require(isDist(dist) && dist.pHead > 0 && dist.pTail > 0) // require(isDist(dist) && dist.pHead > 0 && dist.pTail > 0)
val hh = cond.ifHead.pHead * dist.pHead // val hh = cond.ifHead.pHead * dist.pHead
val ht = cond.ifHead.pTail * dist.pHead // val ht = cond.ifHead.pTail * dist.pHead
val th = cond.ifTail.pHead * dist.pTail // val th = cond.ifTail.pHead * dist.pTail
val tt = cond.ifTail.pTail * dist.pTail // val tt = cond.ifTail.pTail * dist.pTail
CoinsJoinDist(hh, ht, th, tt) // CoinsJoinDist(hh, ht, th, tt)
} //}
//
def condIsSound(dist: CoinsJoinDist): Boolean = { //def condIsSound(dist: CoinsJoinDist): Boolean = {
require(isDist(dist) && dist.hh > 0 && dist.ht > 0 && dist.th > 0 && dist.tt > 0) // require(isDist(dist) && dist.hh > 0 && dist.ht > 0 && dist.th > 0 && dist.tt > 0)
val computedDist = combine(condByFirstCoin(dist), firstCoin(dist)) // val computedDist = combine(condByFirstCoin(dist), firstCoin(dist))
isEquivalent(dist, computedDist) // isEquivalent(dist, computedDist)
} holds //} holds
//should be INVALID //should be INVALID
def anyDistributionsNotEquivalent(dist1: CoinsJoinDist, dist2: CoinsJoinDist): Boolean = { def anyDistributionsNotEquivalent(dist1: CoinsJoinDist, dist2: CoinsJoinDist): Boolean = {
require(isDist(dist1) && isDist(dist2) && dist1.hh > 0 && dist2.hh > 0) require(isDist(dist1) && isDist(dist2))
isEquivalent(dist1, dist2) isEquivalent(dist1, dist2)
} holds } holds
//sum modulo: face is 0, tail is 1 //sum modulo: face is 0, tail is 1
def sum(coin1: CoinDist, coin2: CoinDist): CoinDist = { def sum(coin1: CoinDist, coin2: CoinDist): CoinDist = {
require(isDist(coin1) && isDist(coin2)) require(isDist(coin1) && isDist(coin2))
CoinDist( CoinDist(coin1.pHead*coin2.pHead + coin1.pTail*coin2.pTail)
coin1.pHead*coin2.pHead + coin1.pTail*coin2.pTail, } ensuring(res => res.pTail ~ (coin1.pHead*coin2.pTail + coin1.pTail*coin2.pHead))
coin1.pHead*coin2.pTail + coin1.pTail*coin2.pHead
)
}
def sum(dist: CoinsJoinDist): CoinDist = { def sum(dist: CoinsJoinDist): CoinDist = {
require(isDist(dist)) require(isDist(dist))
CoinDist( CoinDist(dist.hh + dist.tt)
dist.hh + dist.tt, } ensuring(res => res.pTail ~ (dist.ht + dist.th))
dist.ht + dist.th
)
}
...@@ -113,73 +112,73 @@ object Coins { ...@@ -113,73 +112,73 @@ object Coins {
* properties of sum operation * * properties of sum operation *
***************************************************/ ***************************************************/
def sumIsUniform1(coin1: CoinDist, coin2: CoinDist): Boolean = { //def sumIsUniform1(coin1: CoinDist, coin2: CoinDist): Boolean = {
require(isDist(coin1) && isDist(coin2) && isUniform(coin1) && isUniform(coin2)) // require(isDist(coin1) && isDist(coin2) && isUniform(coin1) && isUniform(coin2))
val dist = sum(coin1, coin2) // val dist = sum(coin1, coin2)
isUniform(dist) // isUniform(dist)
} holds //} holds
def sumIsUniform2(coin1: CoinDist, coin2: CoinDist): Boolean = { //def sumIsUniform2(coin1: CoinDist, coin2: CoinDist): Boolean = {
require(isDist(coin1) && isDist(coin2) && isUniform(coin1)) // require(isDist(coin1) && isDist(coin2) && isUniform(coin1))
val dist = sum(coin1, coin2) // val dist = sum(coin1, coin2)
isUniform(dist) // isUniform(dist)
} holds //} holds
def sumUniform3(coin1: CoinDist, coin2: CoinDist, coin3: CoinDist): Boolean = { //def sumUniform3(coin1: CoinDist, coin2: CoinDist, coin3: CoinDist): Boolean = {
require(isDist(coin1) && isDist(coin2) && isDist(coin3) && isUniform(coin1)) // require(isDist(coin1) && isDist(coin2) && isDist(coin3) && isUniform(coin1))
val dist = sum(sum(coin1, coin2), coin3) // val dist = sum(sum(coin1, coin2), coin3)
isUniform(dist) // isUniform(dist)
} holds //} holds
def sumUniformWithIndependence(dist: CoinsJoinDist): Boolean = { //def sumUniformWithIndependence(dist: CoinsJoinDist): Boolean = {
require(isDist(dist) && isIndependent(dist) && (isUniform(firstCoin(dist)) || isUniform(secondCoin(dist)))) // require(isDist(dist) && isIndependent(dist) && (isUniform(firstCoin(dist)) || isUniform(secondCoin(dist))))
val res = sum(dist) // val res = sum(dist)
isUniform(res) // isUniform(res)
} holds //} holds
//should find counterexample, indepenence is required ////should find counterexample, indepenence is required
def sumUniformWithoutIndependence(dist: CoinsJoinDist): Boolean = { //def sumUniformWithoutIndependence(dist: CoinsJoinDist): Boolean = {
require(isDist(dist) && (isUniform(firstCoin(dist)) || isUniform(secondCoin(dist)))) // require(isDist(dist) && (isUniform(firstCoin(dist)) || isUniform(secondCoin(dist))))
val res = sum(dist) // val res = sum(dist)
isUniform(res) // isUniform(res)
} holds //} holds
//sum of two non-uniform dices is potentially uniform (no result) ////sum of two non-uniform dices is potentially uniform (no result)
def sumNonUniform1(coin1: CoinDist, coin2: CoinDist): Boolean = { //def sumNonUniform1(coin1: CoinDist, coin2: CoinDist): Boolean = {
require(isDist(coin1) && isDist(coin2) && !isUniform(coin1) && !isUniform(coin2)) // require(isDist(coin1) && isDist(coin2) && !isUniform(coin1) && !isUniform(coin2))
val dist = sum(coin1, coin2) // val dist = sum(coin1, coin2)
!isUniform(dist) // !isUniform(dist)
} holds //} holds
def sumNonUniform2(coin1: CoinDist, coin2: CoinDist, coin3: CoinDist): Boolean = { //def sumNonUniform2(coin1: CoinDist, coin2: CoinDist, coin3: CoinDist): Boolean = {
require(isDist(coin1) && isDist(coin2) && isDist(coin3) && !isUniform(coin1) && !isUniform(coin2) && !isUniform(coin3)) // require(isDist(coin1) && isDist(coin2) && isDist(coin3) && !isUniform(coin1) && !isUniform(coin2) && !isUniform(coin3))
val dist = sum(sum(coin1, coin2), coin3) // val dist = sum(sum(coin1, coin2), coin3)
!isUniform(dist) // !isUniform(dist)
} //holds //} //holds
def sumNonUniformWithIndependence(dist: CoinsJoinDist): Boolean = { //def sumNonUniformWithIndependence(dist: CoinsJoinDist): Boolean = {
require(isDist(dist) && isIndependent(dist) && !isUniform(firstCoin(dist)) && !isUniform(secondCoin(dist))) // require(isDist(dist) && isIndependent(dist) && !isUniform(firstCoin(dist)) && !isUniform(secondCoin(dist)))
val res = sum(dist) // val res = sum(dist)
!isUniform(res) // !isUniform(res)
} holds //} holds
//independence is required ////independence is required
def sumNonUniformWithoutIndependence(dist: CoinsJoinDist): Boolean = { //def sumNonUniformWithoutIndependence(dist: CoinsJoinDist): Boolean = {
require(isDist(dist) && !isUniform(firstCoin(dist)) && !isUniform(secondCoin(dist))) // require(isDist(dist) && !isUniform(firstCoin(dist)) && !isUniform(secondCoin(dist)))
val res = sum(dist) // val res = sum(dist)
!isUniform(res) // !isUniform(res)
} holds //} holds
def sumIsCommutative(coin1: CoinDist, coin2: CoinDist, coin3: CoinDist): Boolean = { //def sumIsCommutative(coin1: CoinDist, coin2: CoinDist, coin3: CoinDist): Boolean = {
require(isDist(coin1) && isDist(coin2)) // require(isDist(coin1) && isDist(coin2))
sum(coin1, coin2) == sum(coin2, coin1) // sum(coin1, coin2) == sum(coin2, coin1)
} holds //} holds
def sumIsAssociative(coin1: CoinDist, coin2: CoinDist, coin3: CoinDist): Boolean = { //def sumIsAssociative(coin1: CoinDist, coin2: CoinDist, coin3: CoinDist): Boolean = {
require(isDist(coin1) && isDist(coin2) && isDist(coin3)) // require(isDist(coin1) && isDist(coin2) && isDist(coin3))
sum(sum(coin1, coin2), coin3) == sum(coin1, sum(coin2, coin3)) // sum(sum(coin1, coin2), coin3) == sum(coin1, sum(coin2, coin3))
} //holds //} //holds
} }
import leon.annotation._
import leon.lang._
object Coins {
/*
* number of outcomes for each face
*/
case class CoinDist(pHead: BigInt, pTail: BigInt)
case class CoinsJoinDist(hh: BigInt, ht: BigInt, th: BigInt, tt: BigInt)
def isUniform(dist: CoinDist): Boolean = {
require(isDist(dist))
dist.pHead == dist.pTail
}
def isDist(dist: CoinDist): Boolean =
dist.pHead >= 0 && dist.pTail >= 0 && (dist.pHead > 0 || dist.pTail > 0)
def isDist(dist: CoinsJoinDist): Boolean =
dist.hh >= 0 && dist.ht >= 0 && dist.th >= 0 && dist.tt >= 0 &&
(dist.hh > 0 || dist.ht > 0 || dist.th > 0 || dist.tt > 0)
def join(c1: CoinDist, c2: CoinDist): CoinsJoinDist =
CoinsJoinDist(
c1.pHead*c2.pHead,
c1.pHead*c2.pTail,
c1.pTail*c2.pHead,
c1.pTail*c2.pTail)
def firstCoin(dist: CoinsJoinDist): CoinDist =
CoinDist(dist.hh + dist.ht, dist.th + dist.tt)
def secondCoin(dist: CoinsJoinDist): CoinDist =
CoinDist(dist.hh + dist.th, dist.ht + dist.tt)
def isIndependent(dist: CoinsJoinDist): Boolean =
join(firstCoin(dist), secondCoin(dist)) == dist
def isEquivalent(dist1: CoinsJoinDist, dist2: CoinsJoinDist): Boolean = {
require(isDist(dist1) && dist1.hh > 0 && isDist(dist2) && dist2.hh > 0)
dist1.hh*dist2.hh == dist2.hh*dist1.hh &&
dist1.ht*dist2.hh == dist2.ht*dist1.hh &&
dist1.th*dist2.hh == dist2.th*dist1.hh &&
dist1.tt*dist2.hh == dist2.tt*dist1.hh
}
case class CoinCondDist(ifHead: CoinDist, ifTail: CoinDist)
def condByFirstCoin(dist: CoinsJoinDist): CoinCondDist = {
CoinCondDist(
CoinDist(
dist.hh*(dist.th + dist.tt), //probability of head if head
dist.ht*(dist.th + dist.tt) //probability of tail if head
),
CoinDist(
dist.th*(dist.hh + dist.ht), //probability of head if tail
dist.tt*(dist.hh + dist.ht) //probability of tail if tail
)
)
}
def combine(cond: CoinCondDist, dist: CoinDist): CoinsJoinDist = {
require(isDist(dist) && dist.pHead > 0 && dist.pTail > 0)
val hh = cond.ifHead.pHead * dist.pHead
val ht = cond.ifHead.pTail * dist.pHead
val th = cond.ifTail.pHead * dist.pTail
val tt = cond.ifTail.pTail * dist.pTail
CoinsJoinDist(hh, ht, th, tt)
}
def condIsSound(dist: CoinsJoinDist): Boolean = {
require(isDist(dist) && dist.hh > 0 && dist.ht > 0 && dist.th > 0 && dist.tt > 0)
val computedDist = combine(condByFirstCoin(dist), firstCoin(dist))
isEquivalent(dist, computedDist)
} holds
//should be INVALID
def anyDistributionsNotEquivalent(dist1: CoinsJoinDist, dist2: CoinsJoinDist): Boolean = {
require(isDist(dist1) && isDist(dist2) && dist1.hh > 0 && dist2.hh > 0)
isEquivalent(dist1, dist2)
} holds
//sum modulo: face is 0, tail is 1
def sum(coin1: CoinDist, coin2: CoinDist): CoinDist = {
require(isDist(coin1) && isDist(coin2))
CoinDist(
coin1.pHead*coin2.pHead + coin1.pTail*coin2.pTail,
coin1.pHead*coin2.pTail + coin1.pTail*coin2.pHead
)
}
def sum(dist: CoinsJoinDist): CoinDist = {
require(isDist(dist))
CoinDist(
dist.hh + dist.tt,
dist.ht + dist.th
)
}
/***************************************************
* properties of sum operation *
***************************************************/
def sumIsUniform1(coin1: CoinDist, coin2: CoinDist): Boolean = {
require(isDist(coin1) && isDist(coin2) && isUniform(coin1) && isUniform(coin2))
val dist = sum(coin1, coin2)
isUniform(dist)
} holds
def sumIsUniform2(coin1: CoinDist, coin2: CoinDist): Boolean = {
require(isDist(coin1) && isDist(coin2) && isUniform(coin1))
val dist = sum(coin1, coin2)
isUniform(dist)
} holds
def sumUniform3(coin1: CoinDist, coin2: CoinDist, coin3: CoinDist): Boolean = {
require(isDist(coin1) && isDist(coin2) && isDist(coin3) && isUniform(coin1))
val dist = sum(sum(coin1, coin2), coin3)
isUniform(dist)
} holds
def sumUniformWithIndependence(dist: CoinsJoinDist): Boolean = {
require(isDist(dist) && isIndependent(dist) && (isUniform(firstCoin(dist)) || isUniform(secondCoin(dist))))
val res = sum(dist)
isUniform(res)
} holds
//should find counterexample, indepenence is required
def sumUniformWithoutIndependence(dist: CoinsJoinDist): Boolean = {
require(isDist(dist) && (isUniform(firstCoin(dist)) || isUniform(secondCoin(dist))))
val res = sum(dist)
isUniform(res)
} holds
//sum of two non-uniform dices is potentially uniform (no result)
def sumNonUniform1(coin1: CoinDist, coin2: CoinDist): Boolean = {
require(isDist(coin1) && isDist(coin2) && !isUniform(coin1) && !isUniform(coin2))
val dist = sum(coin1, coin2)
!isUniform(dist)
} holds
def sumNonUniform2(coin1: CoinDist, coin2: CoinDist, coin3: CoinDist): Boolean = {
require(isDist(coin1) && isDist(coin2) && isDist(coin3) && !isUniform(coin1) && !isUniform(coin2) && !isUniform(coin3))
val dist = sum(sum(coin1, coin2), coin3)
!isUniform(dist)
} //holds
def sumNonUniformWithIndependence(dist: CoinsJoinDist): Boolean = {
require(isDist(dist) && isIndependent(dist) && !isUniform(firstCoin(dist)) && !isUniform(secondCoin(dist)))
val res = sum(dist)
!isUniform(res)
} holds
//independence is required
def sumNonUniformWithoutIndependence(dist: CoinsJoinDist): Boolean = {
require(isDist(dist) && !isUniform(firstCoin(dist)) && !isUniform(secondCoin(dist)))
val res = sum(dist)
!isUniform(res)
} holds
def sumIsCommutative(coin1: CoinDist, coin2: CoinDist, coin3: CoinDist): Boolean = {
require(isDist(coin1) && isDist(coin2))
sum(coin1, coin2) == sum(coin2, coin1)
} holds
def sumIsAssociative(coin1: CoinDist, coin2: CoinDist, coin3: CoinDist): Boolean = {
require(isDist(coin1) && isDist(coin2) && isDist(coin3))
sum(sum(coin1, coin2), coin3) == sum(coin1, sum(coin2, coin3))
} //holds
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment