From 22e8927829dd9ffdb5a8934e0fa6f0d564c641c0 Mon Sep 17 00:00:00 2001
From: Nicolas Voirol <voirol.nicolas@gmail.com>
Date: Mon, 29 Aug 2016 16:04:37 +0200
Subject: [PATCH] Definition deconstruction for stainless VarDefs

---
 src/main/scala/inox/Reporter.scala        |  4 ++--
 src/main/scala/inox/ast/Definitions.scala | 21 +++++++++++++--------
 src/main/scala/inox/ast/Extractors.scala  |  6 ++++++
 src/main/scala/inox/ast/TreeOps.scala     | 15 ++++++++++++---
 4 files changed, 33 insertions(+), 13 deletions(-)

diff --git a/src/main/scala/inox/Reporter.scala b/src/main/scala/inox/Reporter.scala
index 597f65517..32347fab9 100644
--- a/src/main/scala/inox/Reporter.scala
+++ b/src/main/scala/inox/Reporter.scala
@@ -150,8 +150,8 @@ class DefaultReporter(debugSections: Set[DebugSection]) extends Reporter(debugSe
         lines
     }
 
-    if (lines.size > pos.line-1 && pos.line >= 0) {
-      Some(lines(pos.line-1))
+    if (lines.size > pos.line && pos.line >= 0) {
+      Some(lines(pos.line))
     } else {
       None
     }
diff --git a/src/main/scala/inox/ast/Definitions.scala b/src/main/scala/inox/ast/Definitions.scala
index 8e2417f3f..e54a768cc 100644
--- a/src/main/scala/inox/ast/Definitions.scala
+++ b/src/main/scala/inox/ast/Definitions.scala
@@ -173,17 +173,22 @@ trait Definitions { self: Trees =>
   }
  
   // Compiler annotations given in the source code as @annot
-  class Annotation(val annot: String, val args: Seq[Option[Any]]) {
-    override def equals(that: Any): Boolean = that match {
-      case o: Annotation => annot == o.annot && args == o.args
-      case _ => false
-    }
-
-    override def hashCode: Int = annot.hashCode + 31 * args.hashCode
+  case class Annotation(val annot: String, val args: Seq[Option[Any]]) extends Printable {
+    def asString(implicit opts: PrinterOptions): String = annot + (if (args.isEmpty) "" else {
+      args.map { case p: Printable => p.asString case arg => arg.toString }.mkString("(", ",", ")")
+    })
   }
 
   /** Denotes that this adt is refined by invariant ''id'' */
-  case class HasADTInvariant(id: Identifier) extends Annotation("invariant", Seq(Some(id)))
+  class HasADTInvariant(id: Identifier) extends Annotation("invariant", Seq(Some(id)))
+
+  object HasADTInvariant {
+    def apply(id: Identifier): HasADTInvariant = new HasADTInvariant(id)
+    def unapply(annot: Annotation): Option[Identifier] = annot match {
+      case Annotation("invariant", Seq(Some(id: Identifier))) => Some(id)
+      case _ => None
+    }
+  }
 
   /** Represents an ADT definition (either the ADT sort or a constructor). */
   sealed trait ADTDefinition extends Definition {
diff --git a/src/main/scala/inox/ast/Extractors.scala b/src/main/scala/inox/ast/Extractors.scala
index 7e5a18d1c..baf93ab3e 100644
--- a/src/main/scala/inox/ast/Extractors.scala
+++ b/src/main/scala/inox/ast/Extractors.scala
@@ -8,6 +8,12 @@ trait TreeDeconstructor {
   protected val s: Trees
   protected val t: Trees
 
+  // Basically only provided for ValDefs, but could be extended to other
+  // definitions if the users wish to
+  def deconstruct(d: s.Definition): (Identifier, Seq[s.Expr], Seq[s.Type], (Identifier, Seq[t.Expr], Seq[t.Type]) => t.Definition) = d match {
+    case s.ValDef(id, tpe) => (id, Seq.empty, Seq(tpe), (id, es, tps) => t.ValDef(id, tps.head))
+  }
+
   def deconstruct(expr: s.Expr): (Seq[s.Expr], Seq[s.Type], (Seq[t.Expr], Seq[t.Type]) => t.Expr) = expr match {
     /* Unary operators */
     case s.Not(e) =>
diff --git a/src/main/scala/inox/ast/TreeOps.scala b/src/main/scala/inox/ast/TreeOps.scala
index a43920d69..861c946ff 100644
--- a/src/main/scala/inox/ast/TreeOps.scala
+++ b/src/main/scala/inox/ast/TreeOps.scala
@@ -17,9 +17,18 @@ trait TreeOps { self: Trees =>
     }
 
     def transform(vd: ValDef): ValDef = {
-      val (id, tpe) = transform(vd.id, vd.tpe)
-      if ((id ne vd.id) || (tpe ne vd.tpe)) {
-        ValDef(id, tpe).copiedFrom(vd)
+      val (id, es, Seq(tpe), builder) = deconstructor.deconstruct(vd)
+      val (newId, newTpe) = transform(id, tpe)
+
+      var changed = false
+      val newEs = for (e <- es) yield {
+        val newE = transform(e)
+        if (e ne newE) changed = true
+        newE
+      }
+
+      if ((id ne newId) || (tpe ne newTpe) || changed) {
+        builder(newId, newEs, Seq(newTpe)).copiedFrom(vd).asInstanceOf[ValDef]
       } else {
         vd
       }
-- 
GitLab