From 80de017453dd9f7a2e6ebe33ec6345cc07639af3 Mon Sep 17 00:00:00 2001
From: Nicolas Voirol <voirol.nicolas@gmail.com>
Date: Thu, 31 Mar 2016 01:00:13 +0200
Subject: [PATCH] Fix for ADT invariant extraction and new testcase

---
 .../frontends/scalac/CodeExtraction.scala     | 13 ++++---
 .../scala/leon/purescala/Definitions.scala    |  2 +-
 .../scala/leon/purescala/MethodLifting.scala  |  2 ++
 .../valid/BinarySearchTreeQuant2.scala        | 35 +++++++++++++++++++
 .../valid/BinarySearchTreeQuant2.scala        | 35 +++++++++++++++++++
 5 files changed, 82 insertions(+), 5 deletions(-)
 create mode 100644 src/test/resources/regression/verification/purescala/valid/BinarySearchTreeQuant2.scala
 create mode 100644 testcases/verification/quantification/valid/BinarySearchTreeQuant2.scala

diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala
index 9dc3514a4..993bc4fd0 100644
--- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala
+++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala
@@ -334,11 +334,16 @@ trait CodeExtraction extends ASTExtractors {
 
         classToInvariants.get(sym).foreach { bodies =>
           val cd = classesToClasses(sym)
-          val fd = new FunDef(invId, Seq.empty, Seq.empty, BooleanType)
-          fd.addFlag(IsADTInvariant)
-          fd.addFlags(cd.flags.collect { case annot : purescala.Definitions.Annotation => annot })
 
-          cd.registerMethod(fd)
+          for (c <- (cd.ancestors.toSet ++ cd.root.knownDescendants + cd) if !c.methods.exists(_.isInvariant)) {
+            val fd = new FunDef(invId, Seq.empty, Seq.empty, BooleanType)
+            fd.addFlag(IsADTInvariant)
+            fd.addFlags(c.flags.collect { case annot : purescala.Definitions.Annotation => annot })
+            fd.fullBody = BooleanLiteral(true)
+            c.registerMethod(fd)
+          }
+
+          val fd = cd.methods.find(_.isInvariant).get
           val ctparams = sym.tpe match {
             case TypeRef(_, _, tps) =>
               extractTypeParams(tps).map(_._1)
diff --git a/src/main/scala/leon/purescala/Definitions.scala b/src/main/scala/leon/purescala/Definitions.scala
index 0103bd950..3f00b8e69 100644
--- a/src/main/scala/leon/purescala/Definitions.scala
+++ b/src/main/scala/leon/purescala/Definitions.scala
@@ -286,7 +286,7 @@ object Definitions {
       case None => _invariant = Some(fd)
     }
 
-    def hasInvariant: Boolean = invariant.isDefined || (root.knownChildren.exists(cd => cd.methods.exists(_.isInvariant)))
+    def hasInvariant: Boolean = invariant.isDefined || (this +: root.knownDescendants).exists(cd => cd.methods.exists(_.isInvariant))
 
     def annotations: Set[String] = extAnnotations.keySet
     def extAnnotations: Map[String, Seq[Option[Any]]] = flags.collect { case Annotation(s, args) => s -> args }.toMap
diff --git a/src/main/scala/leon/purescala/MethodLifting.scala b/src/main/scala/leon/purescala/MethodLifting.scala
index 13810f83d..5c297bc8b 100644
--- a/src/main/scala/leon/purescala/MethodLifting.scala
+++ b/src/main/scala/leon/purescala/MethodLifting.scala
@@ -223,6 +223,7 @@ object MethodLifting extends TransformationPhase {
 
             Some(and(classPre(fd), compositePre))
           }
+
           val postSimple = {
             val trivial = post.forall {
               case SimpleCase(_, Lambda(_, BooleanLiteral(true))) => true
@@ -239,6 +240,7 @@ object MethodLifting extends TransformationPhase {
               ).setPos(fd))
             }
           }
+
           val bodySimple = {
             val trivial = body forall {
               case SimpleCase(_, NoTree(_)) => true
diff --git a/src/test/resources/regression/verification/purescala/valid/BinarySearchTreeQuant2.scala b/src/test/resources/regression/verification/purescala/valid/BinarySearchTreeQuant2.scala
new file mode 100644
index 000000000..46fb2bbeb
--- /dev/null
+++ b/src/test/resources/regression/verification/purescala/valid/BinarySearchTreeQuant2.scala
@@ -0,0 +1,35 @@
+import leon.lang._
+import leon.collection._
+
+object BSTSimpler2 {
+
+  sealed abstract class Tree {
+    def content: Set[BigInt] = this match {
+      case Leaf() => Set.empty[BigInt]
+      case Node(l, v, r) => l.content ++ Set(v) ++ r.content
+    }
+  }
+
+  case class Node(left: Tree, value: BigInt, right: Tree) extends Tree {
+    require(forall((x:BigInt) => (left.content.contains(x) ==> x < value)) &&
+      forall((x:BigInt) => (right.content.contains(x) ==> value < x)))
+  }
+
+  case class Leaf() extends Tree
+
+  def emptySet(): Tree = Leaf()
+
+  def insert(tree: Tree, value: BigInt): Node = {
+    tree match {
+      case Leaf() => Node(Leaf(), value, Leaf())
+      case Node(l, v, r) => (if (v < value) {
+        Node(l, v, insert(r, value))
+      } else if (v > value) {
+        Node(insert(l, value), v, r)
+      } else {
+        Node(l, v, r)
+      })
+    }
+  } ensuring(res => res.content == tree.content ++ Set(value))
+
+}
diff --git a/testcases/verification/quantification/valid/BinarySearchTreeQuant2.scala b/testcases/verification/quantification/valid/BinarySearchTreeQuant2.scala
new file mode 100644
index 000000000..0c72b1018
--- /dev/null
+++ b/testcases/verification/quantification/valid/BinarySearchTreeQuant2.scala
@@ -0,0 +1,35 @@
+import leon.lang._
+import leon.collection._
+
+object BSTSimpler {
+
+  sealed abstract class Tree {
+    def content: Set[BigInt] = this match {
+      case Leaf() => Set.empty[BigInt]
+      case Node(l, v, r) => l.content ++ Set(v) ++ r.content
+    }
+  }
+
+  case class Node(left: Tree, value: BigInt, right: Tree) extends Tree {
+    require(forall((x:BigInt) => (left.content.contains(x) ==> x < value)) &&
+      forall((x:BigInt) => (right.content.contains(x) ==> value < x)))
+  }
+
+  case class Leaf() extends Tree
+
+  def emptySet(): Tree = Leaf()
+
+  def insert(tree: Tree, value: BigInt): Node = {
+    tree match {
+      case Leaf() => Node(Leaf(), value, Leaf())
+      case Node(l, v, r) => (if (v < value) {
+        Node(l, v, insert(r, value))
+      } else if (v > value) {
+        Node(insert(l, value), v, r)
+      } else {
+        Node(l, v, r)
+      })
+    }
+  } ensuring(res => res.content == tree.content ++ Set(value))
+
+}
-- 
GitLab