From bac2daf1e5f2beacc07f46e720f7f1cc5a12d2af Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9gis=20Blanc?= <regwblanc@gmail.com>
Date: Wed, 9 May 2012 19:43:29 +0200
Subject: [PATCH] correct bug when writing if(..) a(i) = .. else 0

---
 src/main/scala/leon/ArrayTransformation.scala |  6 +++---
 src/main/scala/leon/FairZ3Solver.scala        |  2 +-
 .../leon/ImperativeCodeElimination.scala      |  5 +++++
 .../scala/leon/plugin/CodeExtraction.scala    | 12 ++++++++---
 .../scala/leon/purescala/Definitions.scala    |  2 +-
 .../scala/leon/purescala/PrettyPrinter.scala  |  8 ++++----
 src/main/scala/leon/purescala/Trees.scala     |  8 +++++++-
 src/main/scala/leon/purescala/TypeTrees.scala | 20 +++++++++----------
 testcases/regression/error/Array3.scala       | 11 ++++++++++
 9 files changed, 51 insertions(+), 23 deletions(-)
 create mode 100644 testcases/regression/error/Array3.scala

diff --git a/src/main/scala/leon/ArrayTransformation.scala b/src/main/scala/leon/ArrayTransformation.scala
index db6a7fe2f..6c4382ff7 100644
--- a/src/main/scala/leon/ArrayTransformation.scala
+++ b/src/main/scala/leon/ArrayTransformation.scala
@@ -42,10 +42,9 @@ object ArrayTransformation extends Pass {
       val rv = transform(v)
       val Variable(id) = ra
       val length = ArrayLength(ra)
-      val array = TupleSelect(ra, 1).setType(ArrayType(v.getType))
       val res = IfExpr(
         And(LessEquals(IntLiteral(0), ri), LessThan(ri, length)),
-        Assignment(id, ArrayUpdated(ra, ri, rv).setType(a.getType).setPosInfo(up)),
+        Assignment(id, ArrayUpdated(ra, ri, rv).setType(ra.getType).setPosInfo(up)),
         Error("Index out of bound").setType(UnitType).setPosInfo(up)
       ).setType(UnitType)
       res
@@ -53,7 +52,7 @@ object ArrayTransformation extends Pass {
     case Let(i, v, b) => {
       v.getType match {
         case ArrayType(_) => {
-          val freshIdentifier = FreshIdentifier("t").setType(v.getType)
+          val freshIdentifier = FreshIdentifier("t").setType(i.getType)
           id2FreshId += (i -> freshIdentifier)
           LetVar(freshIdentifier, transform(v), transform(b))
         }
@@ -74,6 +73,7 @@ object ArrayTransformation extends Pass {
       val newWh = While(transform(c), transform(e))
       newWh.invariant = wh.invariant.map(i => transform(i))
       newWh.setPosInfo(wh)
+      newWh
     }
 
     case ite@IfExpr(c, t, e) => {
diff --git a/src/main/scala/leon/FairZ3Solver.scala b/src/main/scala/leon/FairZ3Solver.scala
index 58cd49ff4..90f95153e 100644
--- a/src/main/scala/leon/FairZ3Solver.scala
+++ b/src/main/scala/leon/FairZ3Solver.scala
@@ -1237,7 +1237,7 @@ class FairZ3Solver(reporter: Reporter) extends Solver(reporter) with AbstractZ3S
                   val r1 = rargs(1)
                   val r2 = rargs(2)
                   try {
-                    IfExpr(r0, r1, r2).setType(leastUpperBound(r1.getType, r2.getType))
+                    IfExpr(r0, r1, r2).setType(leastUpperBound(r1.getType, r2.getType).get)
                   } catch {
                     case e => {
                       println("I was asking for lub because of this.")
diff --git a/src/main/scala/leon/ImperativeCodeElimination.scala b/src/main/scala/leon/ImperativeCodeElimination.scala
index f6937463e..0bffa9459 100644
--- a/src/main/scala/leon/ImperativeCodeElimination.scala
+++ b/src/main/scala/leon/ImperativeCodeElimination.scala
@@ -49,6 +49,11 @@ object ImperativeCodeElimination extends Pass {
         val (cRes, cScope, cFun) = toFunction(cond)
         val (tRes, tScope, tFun) = toFunction(tExpr)
         val (eRes, eScope, eFun) = toFunction(eExpr)
+        if(tRes.getType != eRes.getType) {
+          println("PROBLEM: tres: " + tRes + " has type: " + tRes.getType + " then was: " + tExpr)
+          println("PROBLEM: eres: " + eRes + " has type: " + eRes.getType + " else was: " + eExpr)
+          assert(false)
+        }
 
         val modifiedVars: Seq[Identifier] = (tFun.keys ++ eFun.keys).toSet.intersect(varInScope).toSeq
         val resId = FreshIdentifier("res").setType(ite.getType)
diff --git a/src/main/scala/leon/plugin/CodeExtraction.scala b/src/main/scala/leon/plugin/CodeExtraction.scala
index 6162db6b1..0d05298b0 100644
--- a/src/main/scala/leon/plugin/CodeExtraction.scala
+++ b/src/main/scala/leon/plugin/CodeExtraction.scala
@@ -816,7 +816,7 @@ trait CodeExtraction extends Extractors {
           }
           val indexRec = rec(index)
           val newValueRec = rec(newValue)
-          ArrayUpdate(lhsRec, indexRec, newValueRec).setType(newValueRec.getType).setPosInfo(update.pos.line, update.pos.column)
+          ArrayUpdate(lhsRec, indexRec, newValueRec).setPosInfo(update.pos.line, update.pos.column)
         }
         case ExArrayLength(t) => {
           val rt = rec(t)
@@ -836,7 +836,13 @@ trait CodeExtraction extends Extractors {
           }
           val r2 = rec(t2)
           val r3 = rec(t3)
-          IfExpr(r1, r2, r3).setType(leastUpperBound(r2.getType, r3.getType))
+          val lub = leastUpperBound(r2.getType, r3.getType)
+          lub match {
+            case Some(lub) => IfExpr(r1, r2, r3).setType(lub)
+            case None =>
+              unit.error(tree.pos, "Both branches of ifthenelse have incompatible types")
+              throw ImpureCodeEncounteredException(t1)
+          }
         }
 
         case ExIsInstanceOf(tt, cc) => {
@@ -879,7 +885,7 @@ trait CodeExtraction extends Extractors {
         case pm @ ExPatternMatching(sel, cses) => {
           val rs = rec(sel)
           val rc = cses.map(rewriteCaseDef(_))
-          val rt: purescala.TypeTrees.TypeTree = rc.map(_.rhs.getType).reduceLeft(leastUpperBound(_,_))
+          val rt: purescala.TypeTrees.TypeTree = rc.map(_.rhs.getType).reduceLeft(leastUpperBound(_,_).get)
           MatchExpr(rs, rc).setType(rt).setPosInfo(pm.pos.line,pm.pos.column)
         }
 
diff --git a/src/main/scala/leon/purescala/Definitions.scala b/src/main/scala/leon/purescala/Definitions.scala
index f73530c40..bc497dcdd 100644
--- a/src/main/scala/leon/purescala/Definitions.scala
+++ b/src/main/scala/leon/purescala/Definitions.scala
@@ -30,7 +30,7 @@ object Definitions {
       case _ => false
     }
 
-    def toVariable : Variable = Variable(id)//.setType(tpe)
+    def toVariable : Variable = Variable(id).setType(tpe)
   }
 
   type VarDecls = Seq[VarDecl]
diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala
index c4b6b4fc9..ae42766f9 100644
--- a/src/main/scala/leon/purescala/PrettyPrinter.scala
+++ b/src/main/scala/leon/purescala/PrettyPrinter.scala
@@ -66,7 +66,7 @@ object PrettyPrinter {
   }
 
   private def pp(tree: Expr, sb: StringBuffer, lvl: Int): StringBuffer = tree match {
-    case Variable(id) => sb.append(id)
+    case Variable(id) => sb.append(id + "#" + id.getType)
     case DeBruijnIndex(idx) => sb.append("_" + idx)
     case Let(b,d,e) => {
         //pp(e, pp(d, sb.append("(let (" + b + " := "), lvl).append(") in "), lvl).append(")")
@@ -139,10 +139,10 @@ object PrettyPrinter {
       sb.append("\n")
     }
 
-    case Tuple(exprs) => ppNary(sb, exprs, "(", ", ", ")", lvl)
-    case TupleSelect(t, i) => {
+    case t@Tuple(exprs) => ppNary(sb, exprs, "(", ", ", ")#" + t.getType, lvl)
+    case s@TupleSelect(t, i) => {
       pp(t, sb, lvl)
-      sb.append("._" + i)
+      sb.append("._" + i + "#" + s.getType)
       sb
     }
 
diff --git a/src/main/scala/leon/purescala/Trees.scala b/src/main/scala/leon/purescala/Trees.scala
index 33fb76780..8938048fb 100644
--- a/src/main/scala/leon/purescala/Trees.scala
+++ b/src/main/scala/leon/purescala/Trees.scala
@@ -76,7 +76,13 @@ object Trees {
   }
   case class IfExpr(cond: Expr, then: Expr, elze: Expr) extends Expr 
 
-  case class Tuple(exprs: Seq[Expr]) extends Expr
+  case class Tuple(exprs: Seq[Expr]) extends Expr {
+    val subTpes = exprs.map(_.getType)
+    if(!subTpes.exists(_ == Untyped)) {
+      setType(TupleType(subTpes))
+    }
+
+  }
   case class TupleSelect(tuple: Expr, index: Int) extends Expr
 
   object MatchExpr {
diff --git a/src/main/scala/leon/purescala/TypeTrees.scala b/src/main/scala/leon/purescala/TypeTrees.scala
index 0410fd11a..01aea89f9 100644
--- a/src/main/scala/leon/purescala/TypeTrees.scala
+++ b/src/main/scala/leon/purescala/TypeTrees.scala
@@ -18,7 +18,7 @@ object TypeTrees {
 
     def setType(tt: TypeTree): self.type = _type match {
       case None => _type = Some(tt); this
-      case Some(o) if o != tt => scala.sys.error("Resetting type information.")
+      case Some(o) if o != tt => scala.sys.error("Resetting type information! Type [" + o + "] is modified to [" + tt)
       case _ => this
     }
   }
@@ -47,7 +47,7 @@ object TypeTrees {
     case other => other
   }
 
-  def leastUpperBound(t1: TypeTree, t2: TypeTree): TypeTree = (t1,t2) match {
+  def leastUpperBound(t1: TypeTree, t2: TypeTree): Option[TypeTree] = (t1,t2) match {
     case (c1: ClassType, c2: ClassType) => {
       import scala.collection.immutable.Set
       var c: ClassTypeDef = c1.classDef
@@ -72,19 +72,19 @@ object TypeTrees {
       }
 
       if(found.isEmpty) {
-        scala.sys.error("Asking for lub of unrelated class types : " + t1 + " and " + t2)
+        None
       } else {
-        classDefToClassType(found.get)
+        Some(classDefToClassType(found.get))
       }
     }
 
-    case (o1, o2) if (o1 == o2) => o1
-    case (o1,BottomType) => o1
-    case (BottomType,o2) => o2
-    case (o1,AnyType) => AnyType
-    case (AnyType,o2) => AnyType
+    case (o1, o2) if (o1 == o2) => Some(o1)
+    case (o1,BottomType) => Some(o1)
+    case (BottomType,o2) => Some(o2)
+    case (o1,AnyType) => Some(AnyType)
+    case (AnyType,o2) => Some(AnyType)
 
-    case _ => scala.sys.error("Asking for lub of unrelated types: " + t1 + " and " + t2)
+    case _ => None
   }
 
   // returns the number of distinct values that inhabit a type
diff --git a/testcases/regression/error/Array3.scala b/testcases/regression/error/Array3.scala
new file mode 100644
index 000000000..451dd9db0
--- /dev/null
+++ b/testcases/regression/error/Array3.scala
@@ -0,0 +1,11 @@
+object Array3 {
+
+  def foo(): Int = {
+    val a = Array.fill(5)(5)
+    if(a.length > 2)
+      a(1) = 2
+    else 0
+    0
+  }
+
+}
-- 
GitLab