From 42ffffa3949d298039bf1e5b5362ed935617c1b2 Mon Sep 17 00:00:00 2001
From: Regis Blanc <regwblanc@gmail.com>
Date: Tue, 14 Jul 2015 14:32:41 +0200
Subject: [PATCH] real complete enough to prove theorems

---
 build.sbt                                     |  2 +-
 library/lang/Real.scala                       |  8 +++
 .../leon/frontends/scalac/ASTExtractors.scala |  8 +++
 .../frontends/scalac/CodeExtraction.scala     | 22 +++++-
 .../scala/leon/purescala/Constructors.scala   |  5 ++
 .../scala/leon/purescala/Expressions.scala    |  2 +-
 .../scala/leon/purescala/Extractors.scala     | 10 +++
 .../scala/leon/purescala/PrettyPrinter.scala  |  7 ++
 .../leon/solvers/smtlib/SMTLIBSolver.scala    | 27 ++++++-
 testcases/verification/math/RealProps.scala   | 72 +++++++++++++++++++
 10 files changed, 156 insertions(+), 7 deletions(-)
 create mode 100644 testcases/verification/math/RealProps.scala

diff --git a/build.sbt b/build.sbt
index 58a1d325b..fae3a337f 100644
--- a/build.sbt
+++ b/build.sbt
@@ -118,7 +118,7 @@ def ghProject(repo: String, version: String) = RootProject(uri(s"${repo}#${versi
 
 lazy val bonsai      = ghProject("git://github.com/colder/bonsai.git",     "0fec9f97f4220fa94b1f3f305f2e8b76a3cd1539")
 
-lazy val scalaSmtLib = ghProject("git://github.com/regb/scala-smtlib.git", "711e9a1ef994935482bc83ff3795a94f637f0a04")
+lazy val scalaSmtLib = ghProject("git://github.com/regb/scala-smtlib.git", "6b74fb332416d470e0be7bf6e94ebc18fd300c2f")
 
 lazy val root = (project in file(".")).
   configs(PerfTest).
diff --git a/library/lang/Real.scala b/library/lang/Real.scala
index 3e94b27c0..9d600b3b6 100644
--- a/library/lang/Real.scala
+++ b/library/lang/Real.scala
@@ -10,8 +10,16 @@ class Real {
    def /(a: Real): Real = ???
 
    def unary_- : Real = ???
+
+   def > (a: Real): Boolean = ???
+   def >=(a: Real): Boolean = ???
+   def < (a: Real): Boolean = ???
+   def <=(a: Real): Boolean = ???
+
 }
 
+@ignore
 object Real {
   def apply(n: BigInt, d: BigInt): Real = ???
+  def apply(n: BigInt): Real = ???
 }
diff --git a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala b/src/main/scala/leon/frontends/scalac/ASTExtractors.scala
index 348680e6e..7765eacd0 100644
--- a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala
+++ b/src/main/scala/leon/frontends/scalac/ASTExtractors.scala
@@ -255,6 +255,14 @@ trait ASTExtractors {
           None
       }
     }
+    object ExRealIntLiteral {
+      def unapply(tree: Tree): Option[Tree] = tree  match {
+        case Apply(ExSelected("leon", "lang", "Real", "apply"), n :: Nil) =>
+          Some(n)
+        case _ =>
+          None
+      }
+    }
 
 
     object ExIntToBigInt {
diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala
index d7cab0c0a..8ede9e830 100644
--- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala
+++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala
@@ -1213,8 +1213,23 @@ trait CodeExtraction extends ASTExtractors {
           }
         }
 
-        case ExRealLiteral(n: Literal, d: Literal) =>
-          RealLiteral((BigInt(n.value.stringValue), BigInt(d.value.stringValue)))
+        case ExRealLiteral(n, d) =>
+          val rn = extractTree(n)
+          val rd = extractTree(d)
+          (rn, rd) match {
+            case (InfiniteIntegerLiteral(n), InfiniteIntegerLiteral(d)) =>
+              RealLiteral(BigDecimal(n) / BigDecimal(d))
+            case _ =>
+              outOfSubsetError(tr, "Real not build from literals")
+          }
+        case ExRealIntLiteral(n) =>
+          val rn = extractTree(n)
+          rn match {
+            case InfiniteIntegerLiteral(n) =>
+              RealLiteral(BigDecimal(n))
+            case _ =>
+              outOfSubsetError(tr, "Real not build from literals")
+          }
 
         case ExInt32Literal(v) =>
           IntLiteral(v)
@@ -1698,6 +1713,9 @@ trait CodeExtraction extends ASTExtractors {
       case TypeRef(_, sym, _) if isBigIntSym(sym) =>
         IntegerType
 
+      case TypeRef(_, sym, _) if isRealSym(sym) =>
+        RealType
+
       case TypeRef(_, sym, btt :: Nil) if isScalaSetSym(sym) =>
         outOfSubsetError(pos, "Scala's Set API is no longer extracted. Make sure you import leon.lang.Set that defines supported Set operations.")
 
diff --git a/src/main/scala/leon/purescala/Constructors.scala b/src/main/scala/leon/purescala/Constructors.scala
index 05ed1e034..e520e049e 100644
--- a/src/main/scala/leon/purescala/Constructors.scala
+++ b/src/main/scala/leon/purescala/Constructors.scala
@@ -262,8 +262,11 @@ object Constructors {
     case (_, InfiniteIntegerLiteral(bi)) if bi == 0 => lhs
     case (IntLiteral(0), _) => rhs
     case (_, IntLiteral(0)) => lhs
+    case (RealLiteral(d), _) if d == 0 => rhs
+    case (_, RealLiteral(d)) if d == 0 => lhs
     case (IsTyped(_, IntegerType), IsTyped(_, IntegerType)) => Plus(lhs, rhs)
     case (IsTyped(_, Int32Type), IsTyped(_, Int32Type)) => BVPlus(lhs, rhs)
+    case (IsTyped(_, RealType), IsTyped(_, RealType)) => RealPlus(lhs, rhs)
   }
 
   def minus(lhs: Expr, rhs: Expr): Expr = (lhs, rhs) match {
@@ -273,6 +276,7 @@ object Constructors {
     case (IntLiteral(0), _) => BVUMinus(rhs)
     case (IsTyped(_, IntegerType), IsTyped(_, IntegerType)) => Minus(lhs, rhs)
     case (IsTyped(_, Int32Type), IsTyped(_, Int32Type)) => BVMinus(lhs, rhs)
+    case (IsTyped(_, RealType), IsTyped(_, RealType)) => RealMinus(lhs, rhs)
   }
 
   def times(lhs: Expr, rhs: Expr): Expr = (lhs, rhs) match {
@@ -286,6 +290,7 @@ object Constructors {
     case (_, IntLiteral(0)) => IntLiteral(0)
     case (IsTyped(_, IntegerType), IsTyped(_, IntegerType)) => Times(lhs, rhs)
     case (IsTyped(_, Int32Type), IsTyped(_, Int32Type)) => BVTimes(lhs, rhs)
+    case (IsTyped(_, RealType), IsTyped(_, RealType)) => RealTimes(lhs, rhs)
   }
 
 }
diff --git a/src/main/scala/leon/purescala/Expressions.scala b/src/main/scala/leon/purescala/Expressions.scala
index 12900e3ec..5c02df406 100644
--- a/src/main/scala/leon/purescala/Expressions.scala
+++ b/src/main/scala/leon/purescala/Expressions.scala
@@ -274,7 +274,7 @@ object Expressions {
   case class InfiniteIntegerLiteral(value: BigInt) extends Literal[BigInt] {
     val getType = IntegerType
   }
-  case class RealLiteral(value: (BigInt, BigInt)) extends Literal[(BigInt, BigInt)] {
+  case class RealLiteral(value: BigDecimal) extends Literal[BigDecimal] {
     val getType = RealType
   }
 
diff --git a/src/main/scala/leon/purescala/Extractors.scala b/src/main/scala/leon/purescala/Extractors.scala
index c22a17575..f2c1d1804 100644
--- a/src/main/scala/leon/purescala/Extractors.scala
+++ b/src/main/scala/leon/purescala/Extractors.scala
@@ -23,6 +23,8 @@ object Extractors {
         Some((Seq(t), (es: Seq[Expr]) => UMinus(es.head)))
       case BVUMinus(t) =>
         Some((Seq(t), (es: Seq[Expr]) => BVUMinus(es.head)))
+      case RealUMinus(t) =>
+        Some((Seq(t), (es: Seq[Expr]) => RealUMinus(es.head)))
       case BVNot(t) =>
         Some((Seq(t), (es: Seq[Expr]) => BVNot(es.head)))
       case SetCardinality(t) =>
@@ -92,6 +94,14 @@ object Extractors {
         Some(Seq(t1, t2), (es: Seq[Expr]) => BVAShiftRight(es(0), es(1)))
       case BVLShiftRight(t1, t2) =>
         Some(Seq(t1, t2), (es: Seq[Expr]) => BVLShiftRight(es(0), es(1)))
+      case RealPlus(t1, t2) =>
+        Some(Seq(t1, t2), (es: Seq[Expr]) => plus(es(0), es(1)))
+      case RealMinus(t1, t2) =>
+        Some(Seq(t1, t2), (es: Seq[Expr]) => minus(es(0), es(1)))
+      case RealTimes(t1, t2) =>
+        Some(Seq(t1, t2), (es: Seq[Expr]) => times(es(0), es(1)))
+      case RealDivision(t1, t2) =>
+        Some(Seq(t1, t2), (es: Seq[Expr]) => RealDivision(es(0), es(1)))
       case ElementOfSet(t1, t2) =>
         Some(Seq(t1, t2), (es: Seq[Expr]) => ElementOfSet(es(0), es(1)))
       case SubsetOf(t1, t2) =>
diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala
index d179cc7eb..075402f13 100644
--- a/src/main/scala/leon/purescala/PrettyPrinter.scala
+++ b/src/main/scala/leon/purescala/PrettyPrinter.scala
@@ -161,9 +161,11 @@ class PrettyPrinter(opts: PrinterOptions,
       case Implies(l,r)         => optP { p"$l ==> $r" }
       case UMinus(expr)         => p"-$expr"
       case BVUMinus(expr)       => p"-$expr"
+      case RealUMinus(expr)     => p"-$expr"
       case Equals(l,r)          => optP { p"$l == $r" }
       case IntLiteral(v)        => p"$v"
       case InfiniteIntegerLiteral(v) => p"$v"
+      case RealLiteral(d)       => p"$d"
       case CharLiteral(v)       => p"$v"
       case BooleanLiteral(v)    => p"$v"
       case UnitLiteral()        => p"()"
@@ -262,6 +264,10 @@ class PrettyPrinter(opts: PrinterOptions,
       case BVShiftLeft(l,r)          => optP { p"$l << $r" }
       case BVAShiftRight(l,r)        => optP { p"$l >> $r" }
       case BVLShiftRight(l,r)        => optP { p"$l >>> $r" }
+      case RealPlus(l,r)             => optP { p"$l + $r" }
+      case RealMinus(l,r)            => optP { p"$l - $r" }
+      case RealTimes(l,r)            => optP { p"$l * $r" }
+      case RealDivision(l,r)         => optP { p"$l / $r" }
       case fs @ FiniteSet(rs, _)     => p"{${rs.toSeq}}"
       case fm @ FiniteMap(rs, _, _)  => p"{$rs}"
       case Not(ElementOfSet(e,s))    => p"$e \u2209 $s"
@@ -411,6 +417,7 @@ class PrettyPrinter(opts: PrinterOptions,
       case UnitType              => p"Unit"
       case Int32Type             => p"Int"
       case IntegerType           => p"BigInt"
+      case RealType              => p"Real"
       case CharType              => p"Char"
       case BooleanType           => p"Boolean"
       case ArrayType(bt)         => p"Array[$bt]"
diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala
index 4fdf7b6e1..51df2bd3f 100644
--- a/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala
+++ b/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala
@@ -198,6 +198,7 @@ abstract class SMTLIBSolver(val context: LeonContext,
       tpe match {
         case BooleanType => Core.BoolSort()
         case IntegerType => Ints.IntSort()
+        case RealType => Reals.RealSort()
         case Int32Type => FixedSizeBitVectors.BitVectorSort(32)
         case CharType => FixedSizeBitVectors.BitVectorSort(32)
 
@@ -310,6 +311,7 @@ abstract class SMTLIBSolver(val context: LeonContext,
 
       case InfiniteIntegerLiteral(i) => if (i >= 0) Ints.NumeralLit(i) else Ints.Neg(Ints.NumeralLit(-i))
       case IntLiteral(i) => FixedSizeBitVectors.BitVectorLit(Hexadecimal.fromInt(i))
+      case RealLiteral(d) => if (d >= 0) Reals.DecimalLit(d) else Reals.Neg(Reals.DecimalLit(-d))
       case CharLiteral(c) => FixedSizeBitVectors.BitVectorLit(Hexadecimal.fromInt(c.toInt))
       case BooleanLiteral(v) => Core.BoolConst(v)
       case Let(b,d,e) =>
@@ -507,21 +509,25 @@ abstract class SMTLIBSolver(val context: LeonContext,
       case LessThan(a,b) => a.getType match {
         case Int32Type => FixedSizeBitVectors.SLessThan(toSMT(a), toSMT(b))
         case IntegerType => Ints.LessThan(toSMT(a), toSMT(b))
+        case RealType => Reals.LessThan(toSMT(a), toSMT(b))
         case CharType => FixedSizeBitVectors.SLessThan(toSMT(a), toSMT(b))
       }
       case LessEquals(a,b) => a.getType match {
         case Int32Type => FixedSizeBitVectors.SLessEquals(toSMT(a), toSMT(b))
         case IntegerType => Ints.LessEquals(toSMT(a), toSMT(b))
+        case RealType => Reals.LessEquals(toSMT(a), toSMT(b))
         case CharType => FixedSizeBitVectors.SLessEquals(toSMT(a), toSMT(b))
       }
       case GreaterThan(a,b) => a.getType match {
         case Int32Type => FixedSizeBitVectors.SGreaterThan(toSMT(a), toSMT(b))
         case IntegerType => Ints.GreaterThan(toSMT(a), toSMT(b))
+        case RealType => Reals.GreaterThan(toSMT(a), toSMT(b))
         case CharType => FixedSizeBitVectors.SGreaterThan(toSMT(a), toSMT(b))
       }
       case GreaterEquals(a,b) => a.getType match {
         case Int32Type => FixedSizeBitVectors.SGreaterEquals(toSMT(a), toSMT(b))
         case IntegerType => Ints.GreaterEquals(toSMT(a), toSMT(b))
+        case RealType => Reals.GreaterEquals(toSMT(a), toSMT(b))
         case CharType => FixedSizeBitVectors.SGreaterEquals(toSMT(a), toSMT(b))
       }
       case BVPlus(a,b) => FixedSizeBitVectors.Add(toSMT(a), toSMT(b))
@@ -535,6 +541,12 @@ abstract class SMTLIBSolver(val context: LeonContext,
       case BVShiftLeft(a,b) => FixedSizeBitVectors.ShiftLeft(toSMT(a), toSMT(b))
       case BVAShiftRight(a,b) => FixedSizeBitVectors.AShiftRight(toSMT(a), toSMT(b))
       case BVLShiftRight(a,b) => FixedSizeBitVectors.LShiftRight(toSMT(a), toSMT(b))
+
+      case RealPlus(a,b) => Reals.Add(toSMT(a), toSMT(b))
+      case RealMinus(a,b) => Reals.Sub(toSMT(a), toSMT(b))
+      case RealTimes(a,b) => Reals.Mul(toSMT(a), toSMT(b))
+      case RealDivision(a,b) => Reals.Div(toSMT(a), toSMT(b))
+
       case And(sub) => Core.And(sub.map(toSMT): _*)
       case Or(sub) => Core.Or(sub.map(toSMT): _*)
       case IfExpr(cond, thenn, elze) => Core.ITE(toSMT(cond), toSMT(thenn), toSMT(elze))
@@ -570,6 +582,9 @@ abstract class SMTLIBSolver(val context: LeonContext,
     case (SNumeral(n), IntegerType) =>
       InfiniteIntegerLiteral(n)
 
+    case (SDecimal(d), RealType) =>
+      RealLiteral(d)
+
     case (Core.True(), BooleanType)  => BooleanLiteral(true)
     case (Core.False(), BooleanType)  => BooleanLiteral(false)
 
@@ -640,9 +655,15 @@ abstract class SMTLIBSolver(val context: LeonContext,
     case (FunctionApplication(SimpleSymbol(SSymbol(app)), args), tpe) => {
       app match {
         case "-" =>
-          args match {
-            case List(a) => UMinus(fromSMT(a, IntegerType))
-            case List(a, b) => Minus(fromSMT(a, IntegerType), fromSMT(b, IntegerType))
+          (args, tpe) match {
+            case (List(a), IntegerType) => UMinus(fromSMT(a, IntegerType))
+            case (List(a, b), IntegerType) => Minus(fromSMT(a, IntegerType), fromSMT(b, IntegerType))
+            case (List(a), RealType) => RealUMinus(fromSMT(a, RealType))
+            case (List(a, b), RealType) => RealMinus(fromSMT(a, RealType), fromSMT(b, RealType))
+          }
+        case "/" =>
+          (args, tpe) match {
+            case (List(a, b), RealType) => RealDivision(fromSMT(a, RealType), fromSMT(b, RealType))
           }
 
         case _ =>
diff --git a/testcases/verification/math/RealProps.scala b/testcases/verification/math/RealProps.scala
new file mode 100644
index 000000000..0530ab97a
--- /dev/null
+++ b/testcases/verification/math/RealProps.scala
@@ -0,0 +1,72 @@
+import leon.lang._
+import leon.collection._
+import leon._
+
+import scala.language.postfixOps
+
+object RealProps {
+
+  def plusIsCommutative(a: Real, b: Real): Boolean = {
+    a + b == b + a
+  } holds
+
+  def plusIsAssociative(a: Real, b: Real, c: Real): Boolean = {
+    (a + b) + c == a + (b + c)
+  } holds
+
+  def timesIsCommutative(a: Real, b: Real): Boolean = {
+    a * b == b * a
+  } holds
+
+  def timesIsAssociative(a: Real, b: Real, c: Real): Boolean = {
+    (a * b) * c == a * (b * c)
+  } holds
+
+  def distributivity(a: Real, b: Real, c: Real): Boolean = {
+    a*(b + c) == a*b + a*c
+  } holds
+
+  def lessEqualsTransitive(a: Real, b: Real, c: Real): Boolean = {
+    require(a <= b && b <= c)
+    a <= c
+  } holds
+
+  def lessThanTransitive(a: Real, b: Real, c: Real): Boolean = {
+    require(a < b && b < c)
+    a < c
+  } holds
+
+  def greaterEqualsTransitive(a: Real, b: Real, c: Real): Boolean = {
+    require(a >= b && b >= c)
+    a >= c
+  } holds
+
+  def greaterThanTransitive(a: Real, b: Real, c: Real): Boolean = {
+    require(a > b && b > c)
+    a > c
+  } holds
+
+  /* between any two real, there is another real */
+  def density(a: Real, b: Real): Boolean = {
+    require(a < b)
+    val mid = (a + b) / Real(2)
+    a < mid && mid < b
+  } holds
+
+  def identity1(a: Real): Boolean = {
+    require(a != Real(0))
+    a/a == Real(1)
+  } holds
+
+  def test(r: Real): Real = {
+    r
+  } ensuring(res => res >= Real(0))
+
+  def findRoot(r: Real): Real = {
+    r * r
+  } ensuring(res => res != Real(4))
+
+  //def findSqrt2(r: Real): Real = {
+  //  r * r
+  //} ensuring(res => res != Real(2))
+}
-- 
GitLab