From ae577b37bcd724c90c9bc97401f123121cdcaab4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9gis=20Blanc?= <regwblanc@gmail.com>
Date: Tue, 12 Apr 2016 23:02:28 +0200
Subject: [PATCH] fix flatten blocks

---
 src/main/scala/leon/xlang/ExprOps.scala       |  7 ++-
 src/main/scala/leon/xlang/Expressions.scala   |  5 +-
 .../verification/xlang/valid/Blocks1.scala    | 10 ++++
 .../scala/leon/unit/xlang/ExprOpsSuite.scala  | 52 +++++++++++++++++++
 4 files changed, 70 insertions(+), 4 deletions(-)
 create mode 100644 src/test/resources/regression/verification/xlang/valid/Blocks1.scala
 create mode 100644 src/test/scala/leon/unit/xlang/ExprOpsSuite.scala

diff --git a/src/main/scala/leon/xlang/ExprOps.scala b/src/main/scala/leon/xlang/ExprOps.scala
index a5013e5f8..e4680286f 100644
--- a/src/main/scala/leon/xlang/ExprOps.scala
+++ b/src/main/scala/leon/xlang/ExprOps.scala
@@ -21,9 +21,12 @@ object ExprOps {
   def flattenBlocks(expr: Expr): Expr = {
     postMap({
       case Block(exprs, last) =>
-        val nexprs = (exprs :+ last).flatMap{
+        val filtered = exprs.filter{
+          case UnitLiteral() => false
+          case _ => true
+        }
+        val nexprs = (filtered :+ last).flatMap{
           case Block(es2, el) => es2 :+ el
-          case UnitLiteral() => Seq()
           case e2 => Seq(e2)
         }
         Some(nexprs match {
diff --git a/src/main/scala/leon/xlang/Expressions.scala b/src/main/scala/leon/xlang/Expressions.scala
index 642655062..c0e2d4deb 100644
--- a/src/main/scala/leon/xlang/Expressions.scala
+++ b/src/main/scala/leon/xlang/Expressions.scala
@@ -28,8 +28,9 @@ object Expressions {
       Some((exprs :+ last, exprs => Block(exprs.init, exprs.last)))
     }
 
-    override def getPos = {
-      Position.between(exprs.head.getPos, last.getPos)
+    override def getPos = exprs.headOption match {
+      case Some(head) => Position.between(head.getPos, last.getPos)
+      case None => last.getPos
     }
 
     def printWith(implicit pctx: PrinterContext) {
diff --git a/src/test/resources/regression/verification/xlang/valid/Blocks1.scala b/src/test/resources/regression/verification/xlang/valid/Blocks1.scala
new file mode 100644
index 000000000..a2266bcc6
--- /dev/null
+++ b/src/test/resources/regression/verification/xlang/valid/Blocks1.scala
@@ -0,0 +1,10 @@
+object Blocks1 {
+
+  //this used to crash as we would simplify away the final Unit, and get a typing
+  //error during the solving part
+  def test(a: BigInt): Unit = {
+    42
+    ()
+  } ensuring(_ => a == (a + a - a))
+
+}
diff --git a/src/test/scala/leon/unit/xlang/ExprOpsSuite.scala b/src/test/scala/leon/unit/xlang/ExprOpsSuite.scala
new file mode 100644
index 000000000..df9b80225
--- /dev/null
+++ b/src/test/scala/leon/unit/xlang/ExprOpsSuite.scala
@@ -0,0 +1,52 @@
+/* Copyright 2009-2016 EPFL, Lausanne */
+
+package leon.unit.xlang
+
+import org.scalatest._
+
+import leon.test._
+import leon.purescala.Common._
+import leon.purescala.Expressions._
+import leon.purescala.Types._
+import leon.purescala.TypeOps.isSubtypeOf
+import leon.purescala.Definitions._
+import leon.xlang.Expressions._
+import leon.xlang.ExprOps._
+
+class ExprOpsSuite extends FunSuite with helpers.ExpressionsDSL {
+
+  test("flattenBlocks does not change a simple block") {
+    assert(flattenBlocks(Block(Seq(y), x)) === Block(Seq(y), x))
+    assert(flattenBlocks(Block(Seq(y, z), x)) === Block(Seq(y, z), x))
+  }
+
+  test("flattenBlocks of a single statement removes the block") {
+    assert(flattenBlocks(Block(Seq(), x)) === x)
+    assert(flattenBlocks(Block(Seq(), y)) === y)
+  }
+
+  test("flattenBlocks of a one nested block flatten everything") {
+    assert(flattenBlocks(Block(Seq(Block(Seq(y), z)), x)) === Block(Seq(y, z), x))
+    assert(flattenBlocks(Block(Seq(y, Block(Seq(), z)), x)) === Block(Seq(y, z), x))
+  }
+
+  test("flattenBlocks of a several nested blocks flatten everything") {
+    assert(flattenBlocks(Block(Seq(Block(Seq(), y), Block(Seq(), z)), x)) === Block(Seq(y, z), x))
+  }
+
+  test("flattenBlocks of a nested block in last expr should flatten") {
+    assert(flattenBlocks(Block(Seq(Block(Seq(), y)), Block(Seq(z), x))) === Block(Seq(y, z), x))
+  }
+
+  test("flattenBlocks should eliminate intermediate UnitLiteral") {
+    assert(flattenBlocks(Block(Seq(UnitLiteral(), y, z), x)) === Block(Seq(y, z), x))
+    assert(flattenBlocks(Block(Seq(y, UnitLiteral(), z), x)) === Block(Seq(y, z), x))
+    assert(flattenBlocks(Block(Seq(UnitLiteral(), UnitLiteral(), z), x)) === Block(Seq(z), x))
+    assert(flattenBlocks(Block(Seq(UnitLiteral()), x)) === x)
+  }
+
+  test("flattenBlocks should not eliminate trailing UnitLiteral") {
+    assert(flattenBlocks(Block(Seq(x), UnitLiteral())) === Block(Seq(x), UnitLiteral()))
+  }
+
+}
-- 
GitLab