From 2a9d3d7a677dfd1953520eda78d8097f30f0fabf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mika=C3=ABl=20Mayer?= <mikael.mayer@epfl.ch> Date: Mon, 18 Apr 2016 10:51:29 +0200 Subject: [PATCH] Added support for abstract evaluation of MatchExpr expressions. Updated the test suite --- .../leon/evaluators/AbstractEvaluator.scala | 77 ++++++++++++++++++- .../evaluators/AbstractEvaluatorSuite.scala | 13 +++- 2 files changed, 84 insertions(+), 6 deletions(-) diff --git a/src/main/scala/leon/evaluators/AbstractEvaluator.scala b/src/main/scala/leon/evaluators/AbstractEvaluator.scala index 0befea526..b950aff5f 100644 --- a/src/main/scala/leon/evaluators/AbstractEvaluator.scala +++ b/src/main/scala/leon/evaluators/AbstractEvaluator.scala @@ -88,9 +88,9 @@ class AbstractEvaluator(ctx: LeonContext, prog: Program) extends ContextualEvalu case MatchExpr(scrut, cases) => val (escrut, tscrut) = e(scrut) - cases.toStream.map(c => underlying.matchesCase(escrut, c)).find(_.nonEmpty) match { + cases.toStream.map(c => matchesCaseAbstract(escrut, c)).find(_.nonEmpty) match { case Some(Some((c, mappings))) => - e(c.rhs)(rctx.withNewVars(mappings), gctx) + e(c.rhs)(rctx.withNewVars2(mappings), gctx) case _ => throw RuntimeError("MatchError(Abstract evaluation): "+escrut.asString+" did not match any of the cases :\n" + cases.mkString("\n")) } @@ -135,5 +135,78 @@ class AbstractEvaluator(ctx: LeonContext, prog: Program) extends ContextualEvalu } } + def matchesCaseAbstract(scrut: Expr, caze: MatchCase)(implicit rctx: RC, gctx: GC): Option[(MatchCase, Map[Identifier, (Expr, Expr)])] = { + import purescala.TypeOps.isSubtypeOf + import purescala.Extractors._ + def matchesPattern(pat: Pattern, expr: Expr, exprFromScrut: Expr): Option[Map[Identifier, (Expr, Expr)]] = (pat, expr) match { + case (InstanceOfPattern(ob, pct), e) => + if (isSubtypeOf(e.getType, pct)) { + Some(obind(ob, e, exprFromScrut)) + } else { + None + } + case (WildcardPattern(ob), e) => + Some(obind(ob, e, exprFromScrut)) + + case (CaseClassPattern(ob, pct, subs), CaseClass(ct, args)) => + if (pct == ct) { + val res = (subs zip args zip ct.classDef.fieldsIds).map{ case ((s, a), id) => matchesPattern(s, a, CaseClassSelector(ct, exprFromScrut, id)) } + if (res.forall(_.isDefined)) { + Some(obind(ob, expr, exprFromScrut) ++ res.flatten.flatten) + } else { + None + } + } else { + None + } + case (up @ UnapplyPattern(ob, _, subs), scrut) => + e(functionInvocation(up.unapplyFun.fd, Seq(scrut))) match { + case (CaseClass(CaseClassType(cd, _), Seq()), eBuilt) if cd == program.library.None.get => + None + case (CaseClass(CaseClassType(cd, _), Seq(arg)), eBuilt) if cd == program.library.Some.get => + val res = (subs zip unwrapTuple(arg, subs.size)).zipWithIndex map { + case ((s, a), i) => matchesPattern(s, a, tupleSelect(eBuilt, i + 1, subs.size)) + } + if (res.forall(_.isDefined)) { + Some(obind(ob, expr, eBuilt) ++ res.flatten.flatten) + } else { + None + } + case other => + throw EvalError(typeErrorMsg(other._1, up.unapplyFun.returnType)) + } + case (TuplePattern(ob, subs), Tuple(args)) => + if (subs.size == args.size) { + val res = (subs zip args).zipWithIndex.map{ case ((s, a), i) => matchesPattern(s, a, TupleSelect(exprFromScrut, i + 1)) } + if (res.forall(_.isDefined)) { + Some(obind(ob, expr, exprFromScrut) ++ res.flatten.flatten) + } else { + None + } + } else { + None + } + case (LiteralPattern(ob, l1) , l2 : Literal[_]) if l1 == l2 => + Some(obind(ob, l1, exprFromScrut)) + case _ => None + } + + def obind(ob: Option[Identifier], e: Expr, eBuilder: Expr): Map[Identifier, (Expr, Expr)] = { + Map[Identifier, (Expr, Expr)]() ++ ob.map(id => id -> ((e, eBuilder))) + } + + caze match { + case SimpleCase(p, rhs) => + matchesPattern(p, scrut, scrut).map(r => + (caze, r) + ) + + case GuardedCase(p, g, rhs) => + for { + r <- matchesPattern(p, scrut, scrut) + if e(g)(rctx.withNewVars2(r), gctx)._1 == BooleanLiteral(true) + } yield (caze, r) + } + } } diff --git a/src/test/scala/leon/integration/evaluators/AbstractEvaluatorSuite.scala b/src/test/scala/leon/integration/evaluators/AbstractEvaluatorSuite.scala index 99678f887..f7a47c0d7 100644 --- a/src/test/scala/leon/integration/evaluators/AbstractEvaluatorSuite.scala +++ b/src/test/scala/leon/integration/evaluators/AbstractEvaluatorSuite.scala @@ -69,21 +69,26 @@ object AbstractTests { val testFd = funDef("AbstractTests.test2") val Leaf = cc("AbstractTests.Leaf")() def Node(left: Expr, n: Expr, right: Expr) = cc("AbstractTests.Node")(left, n, right) + val NodeDef = classDef("AbstractTests.Node") + val NodeType = classType("AbstractTests.Node", Seq()).asInstanceOf[CaseClassType] val ae = new AbstractEvaluator(fix._1, fix._2) - val res = ae.eval(FunctionInvocation(testFd.typed, Seq(BooleanLiteral(true), Node(Leaf, IntLiteral(5), Leaf)))).result match { + val input = Node(Leaf, IntLiteral(5), Leaf) + + val res = ae.eval(FunctionInvocation(testFd.typed, Seq(BooleanLiteral(true), input))).result match { case Some((e, expr)) => e should equal (IntLiteral(5)) - expr should equal (IntLiteral(5)) + expr should equal (CaseClassSelector(NodeType, input, NodeDef.fieldsIds(1))) case None => fail("No result!") } val a = id("a", Int32Type) - ae.eval(FunctionInvocation(testFd.typed, Seq(BooleanLiteral(true), Node(Leaf, Variable(a), Leaf)))).result match { + val input2 = Node(Leaf, Variable(a), Leaf) + ae.eval(FunctionInvocation(testFd.typed, Seq(BooleanLiteral(true), input2))).result match { case Some((e, expr)) => e should equal (Variable(a)) - expr should equal (Variable(a)) + expr should equal (CaseClassSelector(NodeType, input2, NodeDef.fieldsIds(1))) case None => fail("No result!") } -- GitLab