From 42ffffa3949d298039bf1e5b5362ed935617c1b2 Mon Sep 17 00:00:00 2001 From: Regis Blanc <regwblanc@gmail.com> Date: Tue, 14 Jul 2015 14:32:41 +0200 Subject: [PATCH] real complete enough to prove theorems --- build.sbt | 2 +- library/lang/Real.scala | 8 +++ .../leon/frontends/scalac/ASTExtractors.scala | 8 +++ .../frontends/scalac/CodeExtraction.scala | 22 +++++- .../scala/leon/purescala/Constructors.scala | 5 ++ .../scala/leon/purescala/Expressions.scala | 2 +- .../scala/leon/purescala/Extractors.scala | 10 +++ .../scala/leon/purescala/PrettyPrinter.scala | 7 ++ .../leon/solvers/smtlib/SMTLIBSolver.scala | 27 ++++++- testcases/verification/math/RealProps.scala | 72 +++++++++++++++++++ 10 files changed, 156 insertions(+), 7 deletions(-) create mode 100644 testcases/verification/math/RealProps.scala diff --git a/build.sbt b/build.sbt index 58a1d325b..fae3a337f 100644 --- a/build.sbt +++ b/build.sbt @@ -118,7 +118,7 @@ def ghProject(repo: String, version: String) = RootProject(uri(s"${repo}#${versi lazy val bonsai = ghProject("git://github.com/colder/bonsai.git", "0fec9f97f4220fa94b1f3f305f2e8b76a3cd1539") -lazy val scalaSmtLib = ghProject("git://github.com/regb/scala-smtlib.git", "711e9a1ef994935482bc83ff3795a94f637f0a04") +lazy val scalaSmtLib = ghProject("git://github.com/regb/scala-smtlib.git", "6b74fb332416d470e0be7bf6e94ebc18fd300c2f") lazy val root = (project in file(".")). configs(PerfTest). diff --git a/library/lang/Real.scala b/library/lang/Real.scala index 3e94b27c0..9d600b3b6 100644 --- a/library/lang/Real.scala +++ b/library/lang/Real.scala @@ -10,8 +10,16 @@ class Real { def /(a: Real): Real = ??? def unary_- : Real = ??? + + def > (a: Real): Boolean = ??? + def >=(a: Real): Boolean = ??? + def < (a: Real): Boolean = ??? + def <=(a: Real): Boolean = ??? + } +@ignore object Real { def apply(n: BigInt, d: BigInt): Real = ??? + def apply(n: BigInt): Real = ??? } diff --git a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala b/src/main/scala/leon/frontends/scalac/ASTExtractors.scala index 348680e6e..7765eacd0 100644 --- a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala +++ b/src/main/scala/leon/frontends/scalac/ASTExtractors.scala @@ -255,6 +255,14 @@ trait ASTExtractors { None } } + object ExRealIntLiteral { + def unapply(tree: Tree): Option[Tree] = tree match { + case Apply(ExSelected("leon", "lang", "Real", "apply"), n :: Nil) => + Some(n) + case _ => + None + } + } object ExIntToBigInt { diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala index d7cab0c0a..8ede9e830 100644 --- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala +++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala @@ -1213,8 +1213,23 @@ trait CodeExtraction extends ASTExtractors { } } - case ExRealLiteral(n: Literal, d: Literal) => - RealLiteral((BigInt(n.value.stringValue), BigInt(d.value.stringValue))) + case ExRealLiteral(n, d) => + val rn = extractTree(n) + val rd = extractTree(d) + (rn, rd) match { + case (InfiniteIntegerLiteral(n), InfiniteIntegerLiteral(d)) => + RealLiteral(BigDecimal(n) / BigDecimal(d)) + case _ => + outOfSubsetError(tr, "Real not build from literals") + } + case ExRealIntLiteral(n) => + val rn = extractTree(n) + rn match { + case InfiniteIntegerLiteral(n) => + RealLiteral(BigDecimal(n)) + case _ => + outOfSubsetError(tr, "Real not build from literals") + } case ExInt32Literal(v) => IntLiteral(v) @@ -1698,6 +1713,9 @@ trait CodeExtraction extends ASTExtractors { case TypeRef(_, sym, _) if isBigIntSym(sym) => IntegerType + case TypeRef(_, sym, _) if isRealSym(sym) => + RealType + case TypeRef(_, sym, btt :: Nil) if isScalaSetSym(sym) => outOfSubsetError(pos, "Scala's Set API is no longer extracted. Make sure you import leon.lang.Set that defines supported Set operations.") diff --git a/src/main/scala/leon/purescala/Constructors.scala b/src/main/scala/leon/purescala/Constructors.scala index 05ed1e034..e520e049e 100644 --- a/src/main/scala/leon/purescala/Constructors.scala +++ b/src/main/scala/leon/purescala/Constructors.scala @@ -262,8 +262,11 @@ object Constructors { case (_, InfiniteIntegerLiteral(bi)) if bi == 0 => lhs case (IntLiteral(0), _) => rhs case (_, IntLiteral(0)) => lhs + case (RealLiteral(d), _) if d == 0 => rhs + case (_, RealLiteral(d)) if d == 0 => lhs case (IsTyped(_, IntegerType), IsTyped(_, IntegerType)) => Plus(lhs, rhs) case (IsTyped(_, Int32Type), IsTyped(_, Int32Type)) => BVPlus(lhs, rhs) + case (IsTyped(_, RealType), IsTyped(_, RealType)) => RealPlus(lhs, rhs) } def minus(lhs: Expr, rhs: Expr): Expr = (lhs, rhs) match { @@ -273,6 +276,7 @@ object Constructors { case (IntLiteral(0), _) => BVUMinus(rhs) case (IsTyped(_, IntegerType), IsTyped(_, IntegerType)) => Minus(lhs, rhs) case (IsTyped(_, Int32Type), IsTyped(_, Int32Type)) => BVMinus(lhs, rhs) + case (IsTyped(_, RealType), IsTyped(_, RealType)) => RealMinus(lhs, rhs) } def times(lhs: Expr, rhs: Expr): Expr = (lhs, rhs) match { @@ -286,6 +290,7 @@ object Constructors { case (_, IntLiteral(0)) => IntLiteral(0) case (IsTyped(_, IntegerType), IsTyped(_, IntegerType)) => Times(lhs, rhs) case (IsTyped(_, Int32Type), IsTyped(_, Int32Type)) => BVTimes(lhs, rhs) + case (IsTyped(_, RealType), IsTyped(_, RealType)) => RealTimes(lhs, rhs) } } diff --git a/src/main/scala/leon/purescala/Expressions.scala b/src/main/scala/leon/purescala/Expressions.scala index 12900e3ec..5c02df406 100644 --- a/src/main/scala/leon/purescala/Expressions.scala +++ b/src/main/scala/leon/purescala/Expressions.scala @@ -274,7 +274,7 @@ object Expressions { case class InfiniteIntegerLiteral(value: BigInt) extends Literal[BigInt] { val getType = IntegerType } - case class RealLiteral(value: (BigInt, BigInt)) extends Literal[(BigInt, BigInt)] { + case class RealLiteral(value: BigDecimal) extends Literal[BigDecimal] { val getType = RealType } diff --git a/src/main/scala/leon/purescala/Extractors.scala b/src/main/scala/leon/purescala/Extractors.scala index c22a17575..f2c1d1804 100644 --- a/src/main/scala/leon/purescala/Extractors.scala +++ b/src/main/scala/leon/purescala/Extractors.scala @@ -23,6 +23,8 @@ object Extractors { Some((Seq(t), (es: Seq[Expr]) => UMinus(es.head))) case BVUMinus(t) => Some((Seq(t), (es: Seq[Expr]) => BVUMinus(es.head))) + case RealUMinus(t) => + Some((Seq(t), (es: Seq[Expr]) => RealUMinus(es.head))) case BVNot(t) => Some((Seq(t), (es: Seq[Expr]) => BVNot(es.head))) case SetCardinality(t) => @@ -92,6 +94,14 @@ object Extractors { Some(Seq(t1, t2), (es: Seq[Expr]) => BVAShiftRight(es(0), es(1))) case BVLShiftRight(t1, t2) => Some(Seq(t1, t2), (es: Seq[Expr]) => BVLShiftRight(es(0), es(1))) + case RealPlus(t1, t2) => + Some(Seq(t1, t2), (es: Seq[Expr]) => plus(es(0), es(1))) + case RealMinus(t1, t2) => + Some(Seq(t1, t2), (es: Seq[Expr]) => minus(es(0), es(1))) + case RealTimes(t1, t2) => + Some(Seq(t1, t2), (es: Seq[Expr]) => times(es(0), es(1))) + case RealDivision(t1, t2) => + Some(Seq(t1, t2), (es: Seq[Expr]) => RealDivision(es(0), es(1))) case ElementOfSet(t1, t2) => Some(Seq(t1, t2), (es: Seq[Expr]) => ElementOfSet(es(0), es(1))) case SubsetOf(t1, t2) => diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala index d179cc7eb..075402f13 100644 --- a/src/main/scala/leon/purescala/PrettyPrinter.scala +++ b/src/main/scala/leon/purescala/PrettyPrinter.scala @@ -161,9 +161,11 @@ class PrettyPrinter(opts: PrinterOptions, case Implies(l,r) => optP { p"$l ==> $r" } case UMinus(expr) => p"-$expr" case BVUMinus(expr) => p"-$expr" + case RealUMinus(expr) => p"-$expr" case Equals(l,r) => optP { p"$l == $r" } case IntLiteral(v) => p"$v" case InfiniteIntegerLiteral(v) => p"$v" + case RealLiteral(d) => p"$d" case CharLiteral(v) => p"$v" case BooleanLiteral(v) => p"$v" case UnitLiteral() => p"()" @@ -262,6 +264,10 @@ class PrettyPrinter(opts: PrinterOptions, case BVShiftLeft(l,r) => optP { p"$l << $r" } case BVAShiftRight(l,r) => optP { p"$l >> $r" } case BVLShiftRight(l,r) => optP { p"$l >>> $r" } + case RealPlus(l,r) => optP { p"$l + $r" } + case RealMinus(l,r) => optP { p"$l - $r" } + case RealTimes(l,r) => optP { p"$l * $r" } + case RealDivision(l,r) => optP { p"$l / $r" } case fs @ FiniteSet(rs, _) => p"{${rs.toSeq}}" case fm @ FiniteMap(rs, _, _) => p"{$rs}" case Not(ElementOfSet(e,s)) => p"$e \u2209 $s" @@ -411,6 +417,7 @@ class PrettyPrinter(opts: PrinterOptions, case UnitType => p"Unit" case Int32Type => p"Int" case IntegerType => p"BigInt" + case RealType => p"Real" case CharType => p"Char" case BooleanType => p"Boolean" case ArrayType(bt) => p"Array[$bt]" diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala index 4fdf7b6e1..51df2bd3f 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala @@ -198,6 +198,7 @@ abstract class SMTLIBSolver(val context: LeonContext, tpe match { case BooleanType => Core.BoolSort() case IntegerType => Ints.IntSort() + case RealType => Reals.RealSort() case Int32Type => FixedSizeBitVectors.BitVectorSort(32) case CharType => FixedSizeBitVectors.BitVectorSort(32) @@ -310,6 +311,7 @@ abstract class SMTLIBSolver(val context: LeonContext, case InfiniteIntegerLiteral(i) => if (i >= 0) Ints.NumeralLit(i) else Ints.Neg(Ints.NumeralLit(-i)) case IntLiteral(i) => FixedSizeBitVectors.BitVectorLit(Hexadecimal.fromInt(i)) + case RealLiteral(d) => if (d >= 0) Reals.DecimalLit(d) else Reals.Neg(Reals.DecimalLit(-d)) case CharLiteral(c) => FixedSizeBitVectors.BitVectorLit(Hexadecimal.fromInt(c.toInt)) case BooleanLiteral(v) => Core.BoolConst(v) case Let(b,d,e) => @@ -507,21 +509,25 @@ abstract class SMTLIBSolver(val context: LeonContext, case LessThan(a,b) => a.getType match { case Int32Type => FixedSizeBitVectors.SLessThan(toSMT(a), toSMT(b)) case IntegerType => Ints.LessThan(toSMT(a), toSMT(b)) + case RealType => Reals.LessThan(toSMT(a), toSMT(b)) case CharType => FixedSizeBitVectors.SLessThan(toSMT(a), toSMT(b)) } case LessEquals(a,b) => a.getType match { case Int32Type => FixedSizeBitVectors.SLessEquals(toSMT(a), toSMT(b)) case IntegerType => Ints.LessEquals(toSMT(a), toSMT(b)) + case RealType => Reals.LessEquals(toSMT(a), toSMT(b)) case CharType => FixedSizeBitVectors.SLessEquals(toSMT(a), toSMT(b)) } case GreaterThan(a,b) => a.getType match { case Int32Type => FixedSizeBitVectors.SGreaterThan(toSMT(a), toSMT(b)) case IntegerType => Ints.GreaterThan(toSMT(a), toSMT(b)) + case RealType => Reals.GreaterThan(toSMT(a), toSMT(b)) case CharType => FixedSizeBitVectors.SGreaterThan(toSMT(a), toSMT(b)) } case GreaterEquals(a,b) => a.getType match { case Int32Type => FixedSizeBitVectors.SGreaterEquals(toSMT(a), toSMT(b)) case IntegerType => Ints.GreaterEquals(toSMT(a), toSMT(b)) + case RealType => Reals.GreaterEquals(toSMT(a), toSMT(b)) case CharType => FixedSizeBitVectors.SGreaterEquals(toSMT(a), toSMT(b)) } case BVPlus(a,b) => FixedSizeBitVectors.Add(toSMT(a), toSMT(b)) @@ -535,6 +541,12 @@ abstract class SMTLIBSolver(val context: LeonContext, case BVShiftLeft(a,b) => FixedSizeBitVectors.ShiftLeft(toSMT(a), toSMT(b)) case BVAShiftRight(a,b) => FixedSizeBitVectors.AShiftRight(toSMT(a), toSMT(b)) case BVLShiftRight(a,b) => FixedSizeBitVectors.LShiftRight(toSMT(a), toSMT(b)) + + case RealPlus(a,b) => Reals.Add(toSMT(a), toSMT(b)) + case RealMinus(a,b) => Reals.Sub(toSMT(a), toSMT(b)) + case RealTimes(a,b) => Reals.Mul(toSMT(a), toSMT(b)) + case RealDivision(a,b) => Reals.Div(toSMT(a), toSMT(b)) + case And(sub) => Core.And(sub.map(toSMT): _*) case Or(sub) => Core.Or(sub.map(toSMT): _*) case IfExpr(cond, thenn, elze) => Core.ITE(toSMT(cond), toSMT(thenn), toSMT(elze)) @@ -570,6 +582,9 @@ abstract class SMTLIBSolver(val context: LeonContext, case (SNumeral(n), IntegerType) => InfiniteIntegerLiteral(n) + case (SDecimal(d), RealType) => + RealLiteral(d) + case (Core.True(), BooleanType) => BooleanLiteral(true) case (Core.False(), BooleanType) => BooleanLiteral(false) @@ -640,9 +655,15 @@ abstract class SMTLIBSolver(val context: LeonContext, case (FunctionApplication(SimpleSymbol(SSymbol(app)), args), tpe) => { app match { case "-" => - args match { - case List(a) => UMinus(fromSMT(a, IntegerType)) - case List(a, b) => Minus(fromSMT(a, IntegerType), fromSMT(b, IntegerType)) + (args, tpe) match { + case (List(a), IntegerType) => UMinus(fromSMT(a, IntegerType)) + case (List(a, b), IntegerType) => Minus(fromSMT(a, IntegerType), fromSMT(b, IntegerType)) + case (List(a), RealType) => RealUMinus(fromSMT(a, RealType)) + case (List(a, b), RealType) => RealMinus(fromSMT(a, RealType), fromSMT(b, RealType)) + } + case "/" => + (args, tpe) match { + case (List(a, b), RealType) => RealDivision(fromSMT(a, RealType), fromSMT(b, RealType)) } case _ => diff --git a/testcases/verification/math/RealProps.scala b/testcases/verification/math/RealProps.scala new file mode 100644 index 000000000..0530ab97a --- /dev/null +++ b/testcases/verification/math/RealProps.scala @@ -0,0 +1,72 @@ +import leon.lang._ +import leon.collection._ +import leon._ + +import scala.language.postfixOps + +object RealProps { + + def plusIsCommutative(a: Real, b: Real): Boolean = { + a + b == b + a + } holds + + def plusIsAssociative(a: Real, b: Real, c: Real): Boolean = { + (a + b) + c == a + (b + c) + } holds + + def timesIsCommutative(a: Real, b: Real): Boolean = { + a * b == b * a + } holds + + def timesIsAssociative(a: Real, b: Real, c: Real): Boolean = { + (a * b) * c == a * (b * c) + } holds + + def distributivity(a: Real, b: Real, c: Real): Boolean = { + a*(b + c) == a*b + a*c + } holds + + def lessEqualsTransitive(a: Real, b: Real, c: Real): Boolean = { + require(a <= b && b <= c) + a <= c + } holds + + def lessThanTransitive(a: Real, b: Real, c: Real): Boolean = { + require(a < b && b < c) + a < c + } holds + + def greaterEqualsTransitive(a: Real, b: Real, c: Real): Boolean = { + require(a >= b && b >= c) + a >= c + } holds + + def greaterThanTransitive(a: Real, b: Real, c: Real): Boolean = { + require(a > b && b > c) + a > c + } holds + + /* between any two real, there is another real */ + def density(a: Real, b: Real): Boolean = { + require(a < b) + val mid = (a + b) / Real(2) + a < mid && mid < b + } holds + + def identity1(a: Real): Boolean = { + require(a != Real(0)) + a/a == Real(1) + } holds + + def test(r: Real): Real = { + r + } ensuring(res => res >= Real(0)) + + def findRoot(r: Real): Real = { + r * r + } ensuring(res => res != Real(4)) + + //def findSqrt2(r: Real): Real = { + // r * r + //} ensuring(res => res != Real(2)) +} -- GitLab