From 18f41bdcc1ba4522cf78905679e916ea4ffc2f67 Mon Sep 17 00:00:00 2001
From: Etienne Kneuss <colder@php.net>
Date: Wed, 13 Feb 2013 15:05:03 +0100
Subject: [PATCH] Strengthen type invariants in trees

---
 src/main/scala/leon/purescala/Trees.scala     | 72 +++++++++++++------
 .../leon/solvers/z3/FunctionTemplate.scala    |  4 +-
 2 files changed, 54 insertions(+), 22 deletions(-)

diff --git a/src/main/scala/leon/purescala/Trees.scala b/src/main/scala/leon/purescala/Trees.scala
index f1bd339c5..f6db8369e 100644
--- a/src/main/scala/leon/purescala/Trees.scala
+++ b/src/main/scala/leon/purescala/Trees.scala
@@ -23,11 +23,11 @@ object Trees {
    * the expected type. */
   case class Error(description: String) extends Expr with Terminal with ScalacPositional
 
-  case class Choose(vars: List[Identifier], pred: Expr) extends Expr with ScalacPositional with UnaryExtractable with FixedType {
+  case class Choose(vars: List[Identifier], pred: Expr) extends Expr with FixedType with ScalacPositional with UnaryExtractable {
 
     assert(!vars.isEmpty)
 
-    val fixedType = if (vars.size > 1) TupleType(vars.map(_.getType)) else  vars.head.getType
+    val fixedType = if (vars.size > 1) TupleType(vars.map(_.getType)) else vars.head.getType
 
     def extract = {
       Some((pred, (e: Expr) => Choose(vars, e).setPosInfo(this)))
@@ -35,21 +35,18 @@ object Trees {
   }
 
   /* Like vals */
-  case class Let(binder: Identifier, value: Expr, body: Expr) extends Expr {
+  case class Let(binder: Identifier, value: Expr, body: Expr) extends Expr with FixedType {
     binder.markAsLetBinder
-    val et = body.getType
-    if(et != Untyped)
-      setType(et)
+
+    val fixedType = body.getType
   }
 
-  case class LetTuple(binders: Seq[Identifier], value: Expr, body: Expr) extends Expr {
+  case class LetTuple(binders: Seq[Identifier], value: Expr, body: Expr) extends Expr with FixedType {
     binders.foreach(_.markAsLetBinder)
     assert(value.getType.isInstanceOf[TupleType],
            "The definition value in LetTuple must be of some tuple type; yet we got [%s]. In expr: \n%s".format(value.getType, this))
 
-    val et = body.getType
-    if(et != Untyped)
-      setType(et)
+    val fixedType = body.getType
   }
 
   /* Control flow */
@@ -62,11 +59,8 @@ object Trees {
     val fixedType = leastUpperBound(then.getType, elze.getType).getOrElse(AnyType)
   }
 
-  case class Tuple(exprs: Seq[Expr]) extends Expr {
-    val subTpes = exprs.map(_.getType)
-    if(subTpes.forall(_ != Untyped)) {
-      setType(TupleType(subTpes))
-    }
+  case class Tuple(exprs: Seq[Expr]) extends Expr with FixedType {
+    val fixedType = TupleType(exprs.map(_.getType))
   }
 
   object TupleSelect {
@@ -81,16 +75,20 @@ object Trees {
       if (e eq null) None else Some((e.tuple, e.index))
     }
   }
+
   // This must be 1-indexed ! (So are methods of Scala Tuples)
   class TupleSelect(val tuple: Expr, val index: Int) extends Expr with FixedType {
     assert(index >= 1)
 
+    assert(tuple.getType.isInstanceOf[TupleType], "Applying TupleSelect on a non-tuple tree [%s] of type [%s].".format(tuple, tuple.getType))
+
     val fixedType : TypeTree = tuple.getType match {
-      case TupleType(ts) => 
+      case TupleType(ts) =>
         assert(index <= ts.size)
         ts(index - 1)
 
-      case _ => scala.sys.error("Applying TupleSelect on a non-tuple tree [%s] of type [%s].".format(tuple, tuple.getType))
+      case _ =>
+        AnyType
     }
 
     override def equals(that: Any): Boolean = (that != null) && (that match {
@@ -392,9 +390,11 @@ object Trees {
   case class IntLiteral(value: Int) extends Literal[Int] with FixedType {
     val fixedType = Int32Type
   }
+
   case class BooleanLiteral(value: Boolean) extends Literal[Boolean] with FixedType {
     val fixedType = BooleanType
   }
+
   case class StringLiteral(value: String) extends Literal[String]
   case object UnitLiteral extends Literal[Unit] with FixedType {
     val fixedType = UnitType
@@ -404,6 +404,7 @@ object Trees {
   case class CaseClass(classDef: CaseClassDef, args: Seq[Expr]) extends Expr with FixedType {
     val fixedType = CaseClassType(classDef)
   }
+
   case class CaseClassInstanceOf(classDef: CaseClassDef, expr: Expr) extends Expr with FixedType {
     val fixedType = BooleanType
   }
@@ -508,10 +509,39 @@ object Trees {
   }
 
   /* Array operations */
-  case class ArrayFill(length: Expr, defaultValue: Expr) extends Expr
-  case class ArrayMake(defaultValue: Expr) extends Expr
-  case class ArraySelect(array: Expr, index: Expr) extends Expr with ScalacPositional
-  case class ArrayUpdated(array: Expr, index: Expr, newValue: Expr) extends Expr with ScalacPositional
+  case class ArrayFill(length: Expr, defaultValue: Expr) extends Expr with FixedType {
+    val fixedType = ArrayType(defaultValue.getType)
+  }
+
+  case class ArrayMake(defaultValue: Expr) extends Expr with FixedType {
+    val fixedType = ArrayType(defaultValue.getType)
+  }
+
+  case class ArraySelect(array: Expr, index: Expr) extends Expr with ScalacPositional with FixedType {
+    assert(array.getType.isInstanceOf[ArrayType],
+           "The array value in ArraySelect must of of array type; yet we got [%s]. In expr: \n%s".format(array.getType, array))
+
+    val fixedType = array.getType match {
+      case ArrayType(base) =>
+        base
+      case _ =>
+        AnyType
+    }
+
+  }
+
+  case class ArrayUpdated(array: Expr, index: Expr, newValue: Expr) extends Expr with ScalacPositional with FixedType {
+    assert(array.getType.isInstanceOf[ArrayType],
+           "The array value in ArrayUpdated must of of array type; yet we got [%s]. In expr: \n%s".format(array.getType, array))
+
+    val fixedType = array.getType match {
+      case ArrayType(base) =>
+        leastUpperBound(base, newValue.getType).map(ArrayType(_)).getOrElse(AnyType)
+      case _ =>
+        AnyType
+    }
+  }
+
   case class ArrayLength(array: Expr) extends Expr with FixedType {
     val fixedType = Int32Type
   }
diff --git a/src/main/scala/leon/solvers/z3/FunctionTemplate.scala b/src/main/scala/leon/solvers/z3/FunctionTemplate.scala
index 61454e38c..abb79d0ec 100644
--- a/src/main/scala/leon/solvers/z3/FunctionTemplate.scala
+++ b/src/main/scala/leon/solvers/z3/FunctionTemplate.scala
@@ -58,7 +58,9 @@ class FunctionTemplate private(
     zippedFunDefArgs
   }
 
-  val asZ3Clauses: Seq[Z3AST] = asClauses.map(solver.toZ3Formula(_, idToZ3Ids).get)
+  val asZ3Clauses: Seq[Z3AST] = asClauses.map {
+    solver.toZ3Formula(_, idToZ3Ids).getOrElse(sys.error("Could not translate to z3. Did you forget --xlang?"))
+  }
 
   private val blockers : Map[Identifier,Set[FunctionInvocation]] = {
     val idCall = FunctionInvocation(funDef, funDef.args.map(_.toVariable))
-- 
GitLab