From d9a82908388f1fe3dfc61c273c060f99d44e5dfe Mon Sep 17 00:00:00 2001
From: Etienne Kneuss <colder@php.net>
Date: Fri, 12 Sep 2014 14:58:47 +0200
Subject: [PATCH] Evaluate Pattern matching, matchToITE no longer necessary

More precise tracing, especially for in-IDE exploration of execution
---
 .../leon/evaluators/RecursiveEvaluator.scala  | 93 +++++++++++++++----
 .../leon/evaluators/TracingEvaluator.scala    | 23 ++++-
 .../test/evaluators/EvaluatorsTests.scala     | 50 ++++++++++
 3 files changed, 145 insertions(+), 21 deletions(-)

diff --git a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala
index 0a825de5f..da44836d2 100644
--- a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala
+++ b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala
@@ -71,11 +71,7 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int
     case Variable(id) =>
       rctx.mappings.get(id) match {
         case Some(v) =>
-          if(!isGround(v)) {
-            throw EvalError("Substitution for identifier " + id.name + " is not ground.")
-          } else {
-            v
-          }
+          e(v)
         case None =>
           throw EvalError("No value for identifier " + id.name + " in mapping.")
       }
@@ -121,11 +117,12 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int
       val frame = rctx.withVars((tfd.params.map(_.id) zip evArgs).toMap)
       
       if(tfd.hasPrecondition) {
-        e(matchToIfThenElse(tfd.precondition.get))(frame, gctx) match {
+        e(tfd.precondition.get)(frame, gctx) match {
           case BooleanLiteral(true) =>
           case BooleanLiteral(false) =>
             throw RuntimeError("Precondition violation for " + tfd.id.name + " reached in evaluation.: " + tfd.precondition.get)
-          case other => throw RuntimeError(typeErrorMsg(other, BooleanType))
+          case other =>
+            throw RuntimeError(typeErrorMsg(other, BooleanType))
         }
       }
 
@@ -134,15 +131,12 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int
       }
 
       val body = tfd.body.getOrElse(rctx.mappings(tfd.id))
-      val callResult = e(matchToIfThenElse(body))(frame, gctx)
+      val callResult = e(body)(frame, gctx)
 
       if(tfd.hasPostcondition) {
         val (id, post) = tfd.postcondition.get
 
-        val freshResID = FreshIdentifier("result").setType(tfd.returnType)
-        val postBody = replace(Map(Variable(id) -> Variable(freshResID)), matchToIfThenElse(post))
-
-        e(matchToIfThenElse(post))(frame.withNewVar(id, callResult), gctx) match {
+        e(post)(frame.withNewVar(id, callResult), gctx) match {
           case BooleanLiteral(true) =>
           case BooleanLiteral(false) => throw RuntimeError("Postcondition violation for " + tfd.id.name + " reached in evaluation.")
           case other => throw EvalError(typeErrorMsg(other, BooleanType))
@@ -430,15 +424,82 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int
         solver.free()
       }
 
+    case MatchExpr(scrut, cases) =>
+      val rscrut = e(scrut)
+
+      cases.toStream.map(c => matchesCase(rscrut, c)).find(_.nonEmpty) match {
+        case Some(Some((c, mappings))) =>
+          e(c.rhs)(rctx.withNewVars(mappings), gctx)
+        case _ =>
+          throw RuntimeError("MatchError: "+rscrut+" did not match any of the cases")
+      }
+
     case other =>
       context.reporter.error(other.getPos, "Error: don't know how to handle " + other + " in Evaluator.")
       throw EvalError("Unhandled case in Evaluator : " + other) 
   }
 
-  def typeErrorMsg(tree : Expr, expected : TypeTree) : String = "Type error : expected %s, found %s.".format(expected, tree)
+  def matchesCase(scrut: Expr, caze: MatchCase)(implicit rctx: RC, gctx: GC): Option[(MatchCase, Map[Identifier, Expr])] = {
+    import purescala.TypeTreeOps.isSubtypeOf
 
-  // quick and dirty.. don't overuse.
-  private def isGround(expr: Expr) : Boolean = {
-    variablesOf(expr) == Set.empty
+    def matchesPattern(pat: Pattern, e: Expr): Option[Map[Identifier, Expr]] = (pat, e) match {
+      case (InstanceOfPattern(ob, pct), CaseClass(ct, _)) =>
+        if (isSubtypeOf(ct, pct)) {
+          Some(obind(ob, e))
+        } else {
+          None
+        }
+      case (WildcardPattern(ob), e) =>
+        Some(obind(ob, e))
+
+      case (CaseClassPattern(ob, pct, subs), CaseClass(ct, args)) =>
+        if (pct == ct) {
+          val res = (subs zip args).map{ case (s, a) => matchesPattern(s, a) }
+          if (res.forall(_.isDefined)) {
+            Some(obind(ob, e) ++ res.flatten.flatten)
+          } else {
+            None
+          }
+        } else {
+          None
+        }
+      case (TuplePattern(ob, subs), Tuple(args)) =>
+        if (subs.size == args.size) {
+          val res = (subs zip args).map{ case (s, a) => matchesPattern(s, a) }
+          if (res.forall(_.isDefined)) {
+            Some(obind(ob, e) ++ res.flatten.flatten)
+          } else {
+            None
+          }
+        } else {
+          None
+        }
+      case _ => None
+    }
+
+    def obind(ob: Option[Identifier], e: Expr): Map[Identifier, Expr] = {
+      Map[Identifier, Expr]() ++ ob.map(id => id -> e)
+    }
+
+    caze match {
+      case SimpleCase(p, rhs) =>
+        matchesPattern(p, scrut).map( r =>
+          (caze, r)
+        )
+
+
+      case GuardedCase(p, g, rhs) =>
+        matchesPattern(p, scrut).flatMap( r =>
+          e(g)(rctx.withNewVars(r), gctx) match {
+            case BooleanLiteral(true) =>
+              Some((caze, r))
+            case _ =>
+              None
+          }
+        )
+    }
   }
+
+  def typeErrorMsg(tree : Expr, expected : TypeTree) : String = "Type error : expected %s, found %s.".format(expected, tree)
+
 }
diff --git a/src/main/scala/leon/evaluators/TracingEvaluator.scala b/src/main/scala/leon/evaluators/TracingEvaluator.scala
index 945a3a121..273270b46 100644
--- a/src/main/scala/leon/evaluators/TracingEvaluator.scala
+++ b/src/main/scala/leon/evaluators/TracingEvaluator.scala
@@ -32,6 +32,20 @@ class TracingEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int = 1000) ex
           val res = e(b)(rctx.withNewVar(i, first), gctx)
           (res, first)
 
+        case MatchExpr(scrut, cases) =>
+          val rscrut = e(scrut)
+
+          val r = cases.toStream.map(c => matchesCase(rscrut, c)).find(_.nonEmpty) match {
+            case Some(Some((c, mappings))) =>
+              gctx.values ++= mappings.map { case (id, v) => id.toVariable.setPos(id) -> v }
+
+              e(c.rhs)(rctx.withNewVars(mappings), gctx)
+            case _ =>
+              throw RuntimeError("MatchError: "+rscrut+" did not match any of the cases")
+          }
+
+          (r, r)
+
         case fi @ FunctionInvocation(tfd, args) =>
           if (gctx.stepsLeft < 0) {
             throw RuntimeError("Exceeded number of allocated methods calls ("+gctx.maxSteps+")")
@@ -44,7 +58,7 @@ class TracingEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int = 1000) ex
           val frame = new TracingRecContext((tfd.params.map(_.id) zip evArgs).toMap, rctx.tracingFrames-1)
 
           if(tfd.hasPrecondition) {
-            e(matchToIfThenElse(tfd.precondition.get))(frame, gctx) match {
+            e(tfd.precondition.get)(frame, gctx) match {
               case BooleanLiteral(true) =>
               case BooleanLiteral(false) =>
                 throw RuntimeError("Precondition violation for " + tfd.id.name + " reached in evaluation.: " + tfd.precondition.get)
@@ -57,15 +71,14 @@ class TracingEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int = 1000) ex
           }
 
           val body = tfd.body.getOrElse(rctx.mappings(tfd.id))
-          val callResult = e(matchToIfThenElse(body))(frame, gctx)
+          val callResult = e(body)(frame, gctx)
 
           if(tfd.hasPostcondition) {
             val (id, post) = tfd.postcondition.get
 
-            val freshResID = FreshIdentifier("result").setType(tfd.returnType)
-            val postBody = replace(Map(Variable(id) -> Variable(freshResID)), matchToIfThenElse(post))
+            gctx.values ::= id.toVariable.setPos(id) -> callResult
 
-            e(matchToIfThenElse(post))(frame.withNewVar(id, callResult), gctx) match {
+            e(post)(frame.withNewVar(id, callResult), gctx) match {
               case BooleanLiteral(true) =>
               case BooleanLiteral(false) => throw RuntimeError("Postcondition violation for " + tfd.id.name + " reached in evaluation.")
               case other => throw EvalError(typeErrorMsg(other, BooleanType))
diff --git a/src/test/scala/leon/test/evaluators/EvaluatorsTests.scala b/src/test/scala/leon/test/evaluators/EvaluatorsTests.scala
index 402cf0225..84c546c05 100644
--- a/src/test/scala/leon/test/evaluators/EvaluatorsTests.scala
+++ b/src/test/scala/leon/test/evaluators/EvaluatorsTests.scala
@@ -477,4 +477,54 @@ class EvaluatorsTests extends leon.test.LeonTestSuite {
     val e = new CodeGenEvaluator(leonContext, prog, CodeGenParams(checkContracts = true))
     checkError(e, mkCall("c", IL(-42)))
   }
+
+  test("Pattern Matching") {
+    val p = """|object Program {
+               |  abstract class List;
+               |  case class Cons(h: Int, t: List) extends List;
+               |  case object Nil extends List;
+               |
+               |  def f1: Int = (Cons(1, Nil): List) match {
+               |    case Cons(h, t) => h
+               |    case Nil => 0
+               |  }
+               |
+               |  def f2: Int = (Cons(1, Nil): List) match {
+               |    case Cons(h, _) => h
+               |    case Nil => 0
+               |  }
+               |
+               |  def f3: Int = (Nil: List) match {
+               |    case _ => 1
+               |  }
+               |
+               |  def f4: Int = (Cons(1, Cons(2, Nil)): List) match {
+               |    case a: Cons => 1
+               |    case _ => 0
+               |  }
+               |
+               |  def f5: Int = ((Cons(1, Nil), Nil): (List, List)) match {
+               |    case (a: Cons, _) => 1
+               |    case _ => 0
+               |  }
+               |
+               |  def f6: Int = (Cons(2, Nil): List) match {
+               |    case Cons(h, t) if h > 0 => 1
+               |    case _ => 0
+               |  }
+               |}""".stripMargin
+
+    implicit val prog = parseString(p)
+    val evaluators = prepareEvaluators
+
+    for(e <- evaluators) {
+      // Some simple math.
+      checkComp(e, mkCall("f1"), IL(1))
+      checkComp(e, mkCall("f2"), IL(1))
+      checkComp(e, mkCall("f3"), IL(1))
+      checkComp(e, mkCall("f4"), IL(1))
+      checkComp(e, mkCall("f5"), IL(1))
+      checkComp(e, mkCall("f6"), IL(1))
+    }
+  }
 }
-- 
GitLab