From 2a9d3d7a677dfd1953520eda78d8097f30f0fabf Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Mika=C3=ABl=20Mayer?= <mikael.mayer@epfl.ch>
Date: Mon, 18 Apr 2016 10:51:29 +0200
Subject: [PATCH] Added support for abstract evaluation of MatchExpr
 expressions. Updated the test suite

---
 .../leon/evaluators/AbstractEvaluator.scala   | 77 ++++++++++++++++++-
 .../evaluators/AbstractEvaluatorSuite.scala   | 13 +++-
 2 files changed, 84 insertions(+), 6 deletions(-)

diff --git a/src/main/scala/leon/evaluators/AbstractEvaluator.scala b/src/main/scala/leon/evaluators/AbstractEvaluator.scala
index 0befea526..b950aff5f 100644
--- a/src/main/scala/leon/evaluators/AbstractEvaluator.scala
+++ b/src/main/scala/leon/evaluators/AbstractEvaluator.scala
@@ -88,9 +88,9 @@ class AbstractEvaluator(ctx: LeonContext, prog: Program) extends ContextualEvalu
       
     case MatchExpr(scrut, cases) =>
       val (escrut, tscrut) = e(scrut)
-      cases.toStream.map(c => underlying.matchesCase(escrut, c)).find(_.nonEmpty) match {
+      cases.toStream.map(c => matchesCaseAbstract(escrut, c)).find(_.nonEmpty) match {
         case Some(Some((c, mappings))) =>
-          e(c.rhs)(rctx.withNewVars(mappings), gctx)
+          e(c.rhs)(rctx.withNewVars2(mappings), gctx)
         case _ =>
           throw RuntimeError("MatchError(Abstract evaluation): "+escrut.asString+" did not match any of the cases :\n" + cases.mkString("\n"))
       }
@@ -135,5 +135,78 @@ class AbstractEvaluator(ctx: LeonContext, prog: Program) extends ContextualEvalu
     }
   }
 
+  def matchesCaseAbstract(scrut: Expr, caze: MatchCase)(implicit rctx: RC, gctx: GC): Option[(MatchCase, Map[Identifier, (Expr, Expr)])] = {
+    import purescala.TypeOps.isSubtypeOf
+    import purescala.Extractors._
 
+    def matchesPattern(pat: Pattern, expr: Expr, exprFromScrut: Expr): Option[Map[Identifier, (Expr, Expr)]] = (pat, expr) match {
+      case (InstanceOfPattern(ob, pct), e) =>
+        if (isSubtypeOf(e.getType, pct)) {
+          Some(obind(ob, e, exprFromScrut))
+        } else {
+          None
+        }
+      case (WildcardPattern(ob), e) =>
+        Some(obind(ob, e, exprFromScrut))
+
+      case (CaseClassPattern(ob, pct, subs), CaseClass(ct, args)) =>
+        if (pct == ct) {
+          val res = (subs zip args zip ct.classDef.fieldsIds).map{ case ((s, a), id) => matchesPattern(s, a, CaseClassSelector(ct, exprFromScrut, id)) }
+          if (res.forall(_.isDefined)) {
+            Some(obind(ob, expr, exprFromScrut) ++ res.flatten.flatten)
+          } else {
+            None
+          }
+        } else {
+          None
+        }
+      case (up @ UnapplyPattern(ob, _, subs), scrut) =>
+        e(functionInvocation(up.unapplyFun.fd, Seq(scrut))) match {
+          case (CaseClass(CaseClassType(cd, _), Seq()), eBuilt) if cd == program.library.None.get =>
+            None
+          case (CaseClass(CaseClassType(cd, _), Seq(arg)), eBuilt) if cd == program.library.Some.get =>
+            val res = (subs zip unwrapTuple(arg, subs.size)).zipWithIndex map {
+              case ((s, a), i) => matchesPattern(s, a, tupleSelect(eBuilt, i + 1, subs.size))
+            }
+            if (res.forall(_.isDefined)) {
+              Some(obind(ob, expr, eBuilt) ++ res.flatten.flatten)
+            } else {
+              None
+            }
+          case other =>
+            throw EvalError(typeErrorMsg(other._1, up.unapplyFun.returnType))
+        }
+      case (TuplePattern(ob, subs), Tuple(args)) =>
+        if (subs.size == args.size) {
+          val res = (subs zip args).zipWithIndex.map{ case ((s, a), i) => matchesPattern(s, a, TupleSelect(exprFromScrut, i + 1)) }
+          if (res.forall(_.isDefined)) {
+            Some(obind(ob, expr, exprFromScrut) ++ res.flatten.flatten)
+          } else {
+            None
+          }
+        } else {
+          None
+        }
+      case (LiteralPattern(ob, l1) , l2 : Literal[_]) if l1 == l2 =>
+        Some(obind(ob, l1, exprFromScrut))
+      case _ => None
+    }
+
+    def obind(ob: Option[Identifier], e: Expr, eBuilder: Expr): Map[Identifier, (Expr, Expr)] = {
+      Map[Identifier, (Expr, Expr)]() ++ ob.map(id => id -> ((e, eBuilder)))
+    }
+
+    caze match {
+      case SimpleCase(p, rhs) =>
+        matchesPattern(p, scrut, scrut).map(r =>
+          (caze, r)
+        )
+
+      case GuardedCase(p, g, rhs) =>
+        for {
+          r <- matchesPattern(p, scrut, scrut)
+          if e(g)(rctx.withNewVars2(r), gctx)._1 == BooleanLiteral(true)
+        } yield (caze, r)
+    }
+  }
 }
diff --git a/src/test/scala/leon/integration/evaluators/AbstractEvaluatorSuite.scala b/src/test/scala/leon/integration/evaluators/AbstractEvaluatorSuite.scala
index 99678f887..f7a47c0d7 100644
--- a/src/test/scala/leon/integration/evaluators/AbstractEvaluatorSuite.scala
+++ b/src/test/scala/leon/integration/evaluators/AbstractEvaluatorSuite.scala
@@ -69,21 +69,26 @@ object AbstractTests {
     val testFd = funDef("AbstractTests.test2")
     val Leaf = cc("AbstractTests.Leaf")()
     def Node(left: Expr, n: Expr, right: Expr) = cc("AbstractTests.Node")(left, n, right)
+    val NodeDef = classDef("AbstractTests.Node")
+    val NodeType = classType("AbstractTests.Node", Seq()).asInstanceOf[CaseClassType]
     
     val ae = new AbstractEvaluator(fix._1, fix._2)
     
-    val res = ae.eval(FunctionInvocation(testFd.typed, Seq(BooleanLiteral(true), Node(Leaf, IntLiteral(5), Leaf)))).result match {
+    val input = Node(Leaf, IntLiteral(5), Leaf)
+    
+    val res = ae.eval(FunctionInvocation(testFd.typed, Seq(BooleanLiteral(true), input))).result match {
       case Some((e, expr)) =>
         e should equal (IntLiteral(5))
-        expr should equal (IntLiteral(5))
+        expr should equal (CaseClassSelector(NodeType, input, NodeDef.fieldsIds(1)))
       case None =>
         fail("No result!")
     }
     val a = id("a", Int32Type)
-    ae.eval(FunctionInvocation(testFd.typed, Seq(BooleanLiteral(true), Node(Leaf, Variable(a), Leaf)))).result match {
+    val input2 = Node(Leaf, Variable(a), Leaf)
+    ae.eval(FunctionInvocation(testFd.typed, Seq(BooleanLiteral(true), input2))).result match {
       case Some((e, expr)) =>
         e should equal (Variable(a))
-        expr should equal (Variable(a))
+        expr should equal (CaseClassSelector(NodeType, input2, NodeDef.fieldsIds(1)))
       case None =>
         fail("No result!")
     }
-- 
GitLab