diff --git a/library/lang/Rational.scala b/library/lang/Rational.scala index e2d6f758873623a8ed7c1ccee8a973b4264dca6e..c8027dee9c801e3d59505e4a307ca21d901fedba 100644 --- a/library/lang/Rational.scala +++ b/library/lang/Rational.scala @@ -12,32 +12,34 @@ case class Rational(numerator: BigInt, denominator: BigInt) { def +(that: Rational): Rational = { require(this.isRational && that.isRational) Rational(this.numerator*that.denominator + that.numerator*this.denominator, this.denominator*that.denominator) - } + } ensuring(res => res.isRational) def -(that: Rational): Rational = { require(this.isRational && that.isRational) Rational(this.numerator*that.denominator - that.numerator*this.denominator, this.denominator*that.denominator) - } + } ensuring(res => res.isRational) def unary_- : Rational = { require(this.isRational) Rational(-this.numerator, this.denominator) - } + } ensuring(res => res.isRational) def *(that: Rational): Rational = { require(this.isRational && that.isRational) Rational(this.numerator*that.numerator, this.denominator*that.denominator) - } + } ensuring(res => res.isRational) def /(that: Rational): Rational = { require(this.isRational && that.isRational && that.nonZero) - Rational(this.numerator*that.denominator, this.denominator*that.numerator) - } + val newNumerator = this.numerator*that.denominator + val newDenominator = this.denominator*that.numerator + normalize(newNumerator, newDenominator) + } ensuring(res => res.isRational) def reciprocal: Rational = { require(this.isRational && this.nonZero) - Rational(this.denominator, this.numerator) - } + normalize(this.denominator, this.numerator) + } ensuring(res => res.isRational) def ~(that: Rational): Boolean = { @@ -47,50 +49,22 @@ case class Rational(numerator: BigInt, denominator: BigInt) { def <(that: Rational): Boolean = { require(this.isRational && that.isRational) - if(this.denominator >= 0 && that.denominator >= 0) - this.numerator*that.denominator < that.numerator*this.denominator - else if(this.denominator >= 0 && that.denominator < 0) - this.numerator*that.denominator > that.numerator*this.denominator - else if(this.denominator < 0 && that.denominator >= 0) - this.numerator*that.denominator > that.numerator*this.denominator - else - this.numerator*that.denominator < that.numerator*this.denominator + this.numerator*that.denominator < that.numerator*this.denominator } def <=(that: Rational): Boolean = { require(this.isRational && that.isRational) - if(this.denominator >= 0 && that.denominator >= 0) - this.numerator*that.denominator <= that.numerator*this.denominator - else if(this.denominator >= 0 && that.denominator < 0) - this.numerator*that.denominator >= that.numerator*this.denominator - else if(this.denominator < 0 && that.denominator >= 0) - this.numerator*that.denominator >= that.numerator*this.denominator - else - this.numerator*that.denominator <= that.numerator*this.denominator + this.numerator*that.denominator <= that.numerator*this.denominator } def >(that: Rational): Boolean = { require(this.isRational && that.isRational) - if(this.denominator >= 0 && that.denominator >= 0) - this.numerator*that.denominator > that.numerator*this.denominator - else if(this.denominator >= 0 && that.denominator < 0) - this.numerator*that.denominator < that.numerator*this.denominator - else if(this.denominator < 0 && that.denominator >= 0) - this.numerator*that.denominator < that.numerator*this.denominator - else - this.numerator*that.denominator > that.numerator*this.denominator + this.numerator*that.denominator > that.numerator*this.denominator } def >=(that: Rational): Boolean = { require(this.isRational && that.isRational) - if(this.denominator >= 0 && that.denominator >= 0) - this.numerator*that.denominator >= that.numerator*this.denominator - else if(this.denominator >= 0 && that.denominator < 0) - this.numerator*that.denominator <= that.numerator*this.denominator - else if(this.denominator < 0 && that.denominator >= 0) - this.numerator*that.denominator <= that.numerator*this.denominator - else - this.numerator*that.denominator >= that.numerator*this.denominator + this.numerator*that.denominator >= that.numerator*this.denominator } def nonZero: Boolean = { @@ -98,7 +72,14 @@ case class Rational(numerator: BigInt, denominator: BigInt) { numerator != 0 } - def isRational: Boolean = denominator != 0 + def isRational: Boolean = denominator > 0 + + private def normalize(num: BigInt, den: BigInt): Rational = { + if(den < 0) + Rational(-num, -den) + else + Rational(num, den) + } } object Rational { diff --git a/testcases/verification/math/RationalProps.scala b/testcases/verification/math/RationalProps.scala index e84b92361f1ca79a7ae23d9b149d6cbeeb48620b..60d30a6c69d361ef3185e572357f8fcb52708227 100644 --- a/testcases/verification/math/RationalProps.scala +++ b/testcases/verification/math/RationalProps.scala @@ -2,6 +2,8 @@ import leon.lang._ import leon.collection._ import leon._ +import scala.language.postfixOps + object RationalProps { def squarePos(r: Rational): Rational = {