From 90156b1f287db448be86b4a86da7438ebdf6e536 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9gis=20Blanc?= <regwblanc@gmail.com>
Date: Mon, 2 Apr 2012 13:44:38 +0200
Subject: [PATCH] Can use pattern matching with side effects

---
 mytest/Match.scala                            | 18 ++++++++
 src/main/scala/leon/FunctionClosure.scala     |  5 ++-
 src/main/scala/leon/FunctionHoisting.scala    | 15 ++++++-
 .../leon/ImperativeCodeElimination.scala      | 44 ++++++++++++++++++-
 src/main/scala/leon/UnitElimination.scala     | 10 ++++-
 5 files changed, 87 insertions(+), 5 deletions(-)
 create mode 100644 mytest/Match.scala

diff --git a/mytest/Match.scala b/mytest/Match.scala
new file mode 100644
index 000000000..5376cc61d
--- /dev/null
+++ b/mytest/Match.scala
@@ -0,0 +1,18 @@
+object Match {
+
+  sealed abstract class A
+  case class B(b: Int) extends A
+  case class C(c: Int) extends A
+
+  def foo(a: A): Int = ({
+
+    var i = 0
+    var j = 0
+
+    {i = i + 1; a} match {
+      case B(b) => {i = i + 1; b}
+      case C(c) => {j = j + 1; i = i + 1; c}
+    }
+    i
+  }) ensuring(_ == 2)
+}
diff --git a/src/main/scala/leon/FunctionClosure.scala b/src/main/scala/leon/FunctionClosure.scala
index 489568a88..29ff5341d 100644
--- a/src/main/scala/leon/FunctionClosure.scala
+++ b/src/main/scala/leon/FunctionClosure.scala
@@ -112,7 +112,10 @@ object FunctionClosure extends Pass {
       pathConstraints = pathConstraints.tail
       IfExpr(rCond, rThen, rElze).setType(i.getType)
     }
-    case m @ MatchExpr(scrut,cses) => sys.error("Will see")//MatchExpr(rec(scrut), cses.map(inCase(_))).setType(m.getType).setPosInfo(m)
+    case m @ MatchExpr(scrut,cses) => {
+      //val rScrut = functionClosure(scrut, bindedVars)
+      m
+    }
     case t if t.isInstanceOf[Terminal] => t
     case unhandled => scala.sys.error("Non-terminal case should be handled in searchAndReplace: " + unhandled)
   }
diff --git a/src/main/scala/leon/FunctionHoisting.scala b/src/main/scala/leon/FunctionHoisting.scala
index 88881e2ce..ec02d4336 100644
--- a/src/main/scala/leon/FunctionHoisting.scala
+++ b/src/main/scala/leon/FunctionHoisting.scala
@@ -53,7 +53,20 @@ object FunctionHoisting extends Pass {
       val (r3, s3) = hoist(t3)
       (IfExpr(r1, r2, r3).setType(i.getType), s1 ++ s2 ++ s3)
     }
-    case m @ MatchExpr(scrut,cses) => sys.error("We'll see")//MatchExpr(rec(scrut), cses.map(inCase(_))).setType(m.getType).setPosInfo(m)
+    case m @ MatchExpr(scrut,cses) => {
+      val (scrutRes, scrutSet) = hoist(scrut)
+      val (csesRes, csesSets) = cses.map{
+        case SimpleCase(pat, rhs) => {
+          val (r, s) = hoist(rhs)
+          (SimpleCase(pat, r), s)
+        }
+        case GuardedCase(pat, guard, rhs) => {
+          val (r, s) = hoist(rhs)
+          (GuardedCase(pat, guard, r), s)
+        }
+      }.unzip
+      (MatchExpr(scrutRes, csesRes).setType(m.getType), csesSets.toSet.flatten ++ scrutSet)
+    }
     case t if t.isInstanceOf[Terminal] => (t, Set())
     case unhandled => scala.sys.error("Non-terminal case should be handled in searchAndReplace: " + unhandled)
   }
diff --git a/src/main/scala/leon/ImperativeCodeElimination.scala b/src/main/scala/leon/ImperativeCodeElimination.scala
index 744c0b383..1ec71eea8 100644
--- a/src/main/scala/leon/ImperativeCodeElimination.scala
+++ b/src/main/scala/leon/ImperativeCodeElimination.scala
@@ -197,13 +197,53 @@ object ImperativeCodeElimination extends Pass {
       }
       case (t: Terminal) => (t, (body: Expr) => body, Map())
 
-      case m @ MatchExpr(scrut, cses) => sys.error("not supported: " + expr)
+      case m @ MatchExpr(scrut, cses) => {
+        val csesRhs = cses.map(_.rhs) //we can ignore pattern, and the guard is required to be pure
+        val (csesRes, csesScope, csesFun) = csesRhs.map(toFunction).unzip3
+        val (scrutRes, scrutScope, scrutFun) = toFunction(scrut)
+
+        val modifiedVars: Seq[Identifier] = csesFun.toSet.flatMap((m: Map[Identifier, Identifier]) => m.keys).toSeq
+        val resId = FreshIdentifier("res").setType(m.getType)
+        val freshIds = modifiedVars.map(id => FreshIdentifier(id.name).setType(id.getType))
+        val matchType = if(modifiedVars.isEmpty) resId.getType else TupleType(resId.getType +: freshIds.map(_.getType))
+
+        val csesVals = csesRes.zip(csesFun).map{ 
+          case (cRes, cFun) => (if(modifiedVars.isEmpty) cRes else Tuple(cRes +: modifiedVars.map(vId => cFun.get(vId) match {
+            case Some(newId) => newId.toVariable
+            case None => vId.toVariable
+          }))).setType(matchType)
+        }
+
+        val newRhs = csesVals.zip(csesScope).map{ 
+          case (cVal, cScope) => replaceNames(scrutFun, cScope(cVal)).setType(matchType)
+        }
+        val matchExpr = MatchExpr(scrutRes, cses.zip(newRhs).map{
+          case (SimpleCase(pat, _), newRhs) => SimpleCase(pat, newRhs)
+          case (GuardedCase(pat, guard, _), newRhs) => GuardedCase(pat, replaceNames(scrutFun, guard), newRhs)
+        }).setType(matchType)
+
+        val scope = ((body: Expr) => {
+          val tupleId = FreshIdentifier("t").setType(matchType)
+          scrutScope(
+            Let(tupleId, matchExpr, 
+              if(freshIds.isEmpty)
+                Let(resId, tupleId.toVariable, body)
+              else
+                Let(resId, TupleSelect(tupleId.toVariable, 1),
+                  freshIds.zipWithIndex.foldLeft(body)((b, id) => 
+                    Let(id._1, 
+                      TupleSelect(tupleId.toVariable, id._2 + 2).setType(id._1.getType), 
+                      b)))))
+        })
+
+        (resId.toVariable, scope, scrutFun ++ modifiedVars.zip(freshIds).toMap)
+      }
 
       case _ => sys.error("not supported: " + expr)
     }
     //val codeRepresentation = res._2(Block(res._3.map{ case (id1, id2) => Assignment(id1, id2.toVariable)}.toSeq, res._1))
     //println("res of toFunction on: " + expr + " IS: " + codeRepresentation)
-    res.asInstanceOf[(Expr, (Expr) => Expr, Map[Identifier, Identifier])]
+    res.asInstanceOf[(Expr, (Expr) => Expr, Map[Identifier, Identifier])] //need cast because it seems that res first map type is _ <: Identifier instead of Identifier
   }
 
   def replaceNames(fun: Map[Identifier, Identifier], expr: Expr) = replace(fun.map(ids => (ids._1.toVariable, ids._2.toVariable)), expr)
diff --git a/src/main/scala/leon/UnitElimination.scala b/src/main/scala/leon/UnitElimination.scala
index 7200a26ec..9417f7bfb 100644
--- a/src/main/scala/leon/UnitElimination.scala
+++ b/src/main/scala/leon/UnitElimination.scala
@@ -119,7 +119,15 @@ object UnitElimination extends Pass {
       }
       case v @ Variable(id) => if(id2FreshId.isDefinedAt(id)) Variable(id2FreshId(id)) else v
       case (t: Terminal) => t
-      case m @ MatchExpr(scrut, cses) => sys.error("not supported: " + expr)
+      case m @ MatchExpr(scrut, cses) => {
+        val scrutRec = removeUnit(scrut)
+        val csesRec = cses.map{
+          case SimpleCase(pat, rhs) => SimpleCase(pat, removeUnit(rhs))
+          case GuardedCase(pat, guard, rhs) => GuardedCase(pat, removeUnit(guard), removeUnit(rhs))
+        }
+        val tpe = csesRec.head.rhs.getType
+        MatchExpr(scrutRec, csesRec).setType(tpe)
+      }
       case _ => sys.error("not supported: " + expr)
     }
   }
-- 
GitLab