From 49d74675297ae71732db27d45c3158e49bbb1f88 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9gis=20Blanc?= <regwblanc@gmail.com>
Date: Thu, 3 May 2012 14:00:05 +0000
Subject: [PATCH] working progress in a pass to process array

---
 src/main/scala/leon/ArrayTransformation.scala | 95 +++++++++++++++++++
 src/main/scala/leon/Main.scala                |  2 +-
 .../scala/leon/plugin/CodeExtraction.scala    |  2 +-
 .../scala/leon/purescala/PrettyPrinter.scala  |  8 ++
 src/main/scala/leon/purescala/Trees.scala     |  4 +
 testcases/regression/valid/Array1.scala       |  9 ++
 6 files changed, 118 insertions(+), 2 deletions(-)
 create mode 100644 src/main/scala/leon/ArrayTransformation.scala
 create mode 100644 testcases/regression/valid/Array1.scala

diff --git a/src/main/scala/leon/ArrayTransformation.scala b/src/main/scala/leon/ArrayTransformation.scala
new file mode 100644
index 000000000..ffd27587d
--- /dev/null
+++ b/src/main/scala/leon/ArrayTransformation.scala
@@ -0,0 +1,95 @@
+package leon
+
+import purescala.Common._
+import purescala.Definitions._
+import purescala.Trees._
+import purescala.TypeTrees._
+
+object ArrayTransformation extends Pass {
+
+  val description = "Add bound checking for array access and remove side effect array update operations"
+
+  def apply(pgm: Program): Program = {
+
+    val allFuns = pgm.definedFunctions
+    allFuns.foreach(fd => fd.body.map(body => {
+      val newBody = searchAndReplaceDFS{
+        case Let(i, v, b) => {println("let i: " + v.getType); v.getType match {
+          case ArrayType(_) => {
+            println("this is array type")
+            Some(LetVar(i, v, b))
+          }
+          case _ => None
+        }}
+        case sel@ArraySelect(ar, i) => {
+          Some(IfExpr(
+            And(GreaterEquals(i, i), LessThan(i, i)),
+            sel, 
+            Error("Array out of bound access").setType(sel.getType)
+          ).setType(sel.getType))
+        }
+        case ArrayUpdate(ar, i, v) => {
+          Some(IfExpr(
+            And(GreaterEquals(i, i), LessThan(i, i)),
+            Assignment(ar.asInstanceOf[Variable].id, ArrayUpdated(ar, i, v).setType(ar.getType)),
+            Error("Array out of bound access").setType(UnitType)
+          ).setType(UnitType))
+        }
+        case _ => None
+      }(body)
+      fd.body = Some(newBody)
+    }))
+    pgm
+  }
+
+  private def transform(expr: Expr): Expr = expr match {
+    case fill@ArrayFill(length, default) => {
+      var rLength = transform(length)
+      val rDefault = transform(default)
+      val rFill = ArrayFill(length, default).setType(fill.getType)
+      Tuple(Seq(rFill, length)).setType(TupleType(Seq(fill.getType, Int32Type)))
+    }
+    case sel@ArraySelect(a, i) => {
+      val ar = transform(a)
+      val ir = transform(i)
+      val length = TupleSelect(ar, 2)
+      IfExpr(
+        And(GreaterEquals(i, IntLiteral(0)), LessThan(i, length)),
+        ArraySelect(TupleSelect(ar, 1), ir).setType(sel.getType),
+        Error("Array out of bound access").setType(sel.getType)
+      ).setType(sel.getType)
+    }
+    case ArrayUpdate(a, i, v) => {
+      val ar = transform(a)
+      val ir = transform(i)
+      val vr = transform(v)
+      val length = TupleSelect(ar, 2)
+      val  = Tuple
+      Some(IfExpr(
+        And(GreaterEquals(i, i), LessThan(i, i)),
+        Assignment(ar.asInstanceOf[Variable].id, ArrayUpdated(ar, i, v).setType(ar.getType)),
+        Error("Array out of bound access").setType(UnitType)
+      ).setType(UnitType))
+    }
+
+    case Let(i, v, b) => v.getType match {
+      case ArrayType(_) => Some(LetVar(i, v, b))
+      case _ => None
+    }
+    case LetVar(id, e, b) =>
+
+    case ite@IfExpr(cond, tExpr, eExpr) => 
+
+    case m @ MatchExpr(scrut, cses) => 
+    case LetDef(fd, b) => 
+
+    case n @ NAryOperator(args, recons) => {
+    case b @ BinaryOperator(a1, a2, recons) => 
+    case u @ UnaryOperator(a, recons) => 
+
+    case (t: Terminal) => 
+    case unhandled => scala.sys.error("Non-terminal case should be handled in ArrayTransformation: " + unhandled)
+
+  }
+
+}
diff --git a/src/main/scala/leon/Main.scala b/src/main/scala/leon/Main.scala
index e9818f16d..b3b3c17b4 100644
--- a/src/main/scala/leon/Main.scala
+++ b/src/main/scala/leon/Main.scala
@@ -32,7 +32,7 @@ object Main {
 
   private def defaultAction(program: Program, reporter: Reporter) : Unit = {
     Logger.debug("Default action on program: " + program, 3, "main")
-    val passManager = new PassManager(Seq(EpsilonElimination, ImperativeCodeElimination, UnitElimination, FunctionClosure, FunctionHoisting, Simplificator))
+    val passManager = new PassManager(Seq(ArrayTransformation, EpsilonElimination, ImperativeCodeElimination, UnitElimination, FunctionClosure, FunctionHoisting, Simplificator))
     val program2 = passManager.run(program)
     val analysis = new Analysis(program2, reporter)
     analysis.analyse
diff --git a/src/main/scala/leon/plugin/CodeExtraction.scala b/src/main/scala/leon/plugin/CodeExtraction.scala
index a83e819b8..1a760f3fe 100644
--- a/src/main/scala/leon/plugin/CodeExtraction.scala
+++ b/src/main/scala/leon/plugin/CodeExtraction.scala
@@ -809,7 +809,7 @@ trait CodeExtraction extends Extractors {
           val underlying = scalaType2PureScala(unit, silent)(baseType.tpe)
           val lengthRec = rec(length)
           val defaultValueRec = rec(defaultValue)
-          ArrayFill(lengthRec, defaultValueRec).setType(underlying)
+          ArrayFill(lengthRec, defaultValueRec).setType(ArrayType(underlying))
         }
         case ExIfThenElse(t1,t2,t3) => {
           val r1 = rec(t1)
diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala
index 7b3636015..3b2a0718f 100644
--- a/src/main/scala/leon/purescala/PrettyPrinter.scala
+++ b/src/main/scala/leon/purescala/PrettyPrinter.scala
@@ -269,6 +269,14 @@ object PrettyPrinter {
       sb.append(") = ")
       pp(v, sb, lvl)
     }
+    case ArrayUpdated(ar, i, v) => {
+      pp(ar, sb, lvl)
+      sb.append(".updated(")
+      pp(i, sb, lvl)
+      sb.append(", ")
+      pp(v, sb, lvl)
+      sb.append(")")
+    }
 
     case Distinct(exprs) => {
       var nsb = sb
diff --git a/src/main/scala/leon/purescala/Trees.scala b/src/main/scala/leon/purescala/Trees.scala
index b160e495e..6f103ad80 100644
--- a/src/main/scala/leon/purescala/Trees.scala
+++ b/src/main/scala/leon/purescala/Trees.scala
@@ -382,7 +382,10 @@ object Trees {
   /* Array operations */
   case class ArrayFill(length: Expr, defaultValue: Expr) extends Expr
   case class ArraySelect(array: Expr, index: Expr) extends Expr
+  //the difference between ArrayUpdate and ArrayUpdated is that the former has a side effect while the latter is the function variant
+  //ArrayUpdate should be eliminated soon in the analysis while ArrayUpdated is keep all the way to the backend
   case class ArrayUpdate(array: Expr, index: Expr, newValue: Expr) extends Expr
+  case class ArrayUpdated(array: Expr, index: Expr, newValue: Expr) extends Expr 
 
   /* List operations */
   case class NilList(baseType: TypeTree) extends Expr with Terminal
@@ -471,6 +474,7 @@ object Trees {
       case FiniteMap(args) => Some((args, (as : Seq[Expr]) => FiniteMap(as.asInstanceOf[Seq[SingletonMap]])))
       case FiniteMultiset(args) => Some((args, FiniteMultiset))
       case ArrayUpdate(t1, t2, t3) => Some((Seq(t1,t2,t3), (as: Seq[Expr]) => ArrayUpdate(as(0), as(1), as(2))))
+      case ArrayUpdated(t1, t2, t3) => Some((Seq(t1,t2,t3), (as: Seq[Expr]) => ArrayUpdated(as(0), as(1), as(2))))
       case Distinct(args) => Some((args, Distinct))
       case Block(args, rest) => Some((args :+ rest, exprs => Block(exprs.init, exprs.last)))
       case Tuple(args) => Some((args, Tuple))
diff --git a/testcases/regression/valid/Array1.scala b/testcases/regression/valid/Array1.scala
new file mode 100644
index 000000000..3815f0344
--- /dev/null
+++ b/testcases/regression/valid/Array1.scala
@@ -0,0 +1,9 @@
+object Array1 {
+
+  def foo(): Int = {
+    val a = Array.fill(5)(0)
+    a(2) = 3
+    a(2)
+  } ensuring(_ == 3)
+
+}
-- 
GitLab