Skip to content
Snippets Groups Projects
Commit ddf8d882 authored by Mikaël Mayer's avatar Mikaël Mayer
Browse files

Added support for abstract evaluation of MatchExpr expressions.

Updated the test suite
parent 8ea34c35
Branches
Tags
No related merge requests found
......@@ -88,9 +88,9 @@ class AbstractEvaluator(ctx: LeonContext, prog: Program) extends ContextualEvalu
case MatchExpr(scrut, cases) =>
val (escrut, tscrut) = e(scrut)
cases.toStream.map(c => underlying.matchesCase(escrut, c)).find(_.nonEmpty) match {
cases.toStream.map(c => matchesCaseAbstract(escrut, c)).find(_.nonEmpty) match {
case Some(Some((c, mappings))) =>
e(c.rhs)(rctx.withNewVars(mappings), gctx)
e(c.rhs)(rctx.withNewVars2(mappings), gctx)
case _ =>
throw RuntimeError("MatchError(Abstract evaluation): "+escrut.asString+" did not match any of the cases :\n" + cases.mkString("\n"))
}
......@@ -135,5 +135,78 @@ class AbstractEvaluator(ctx: LeonContext, prog: Program) extends ContextualEvalu
}
}
def matchesCaseAbstract(scrut: Expr, caze: MatchCase)(implicit rctx: RC, gctx: GC): Option[(MatchCase, Map[Identifier, (Expr, Expr)])] = {
import purescala.TypeOps.isSubtypeOf
import purescala.Extractors._
def matchesPattern(pat: Pattern, expr: Expr, exprFromScrut: Expr): Option[Map[Identifier, (Expr, Expr)]] = (pat, expr) match {
case (InstanceOfPattern(ob, pct), e) =>
if (isSubtypeOf(e.getType, pct)) {
Some(obind(ob, e, exprFromScrut))
} else {
None
}
case (WildcardPattern(ob), e) =>
Some(obind(ob, e, exprFromScrut))
case (CaseClassPattern(ob, pct, subs), CaseClass(ct, args)) =>
if (pct == ct) {
val res = (subs zip args zip ct.classDef.fieldsIds).map{ case ((s, a), id) => matchesPattern(s, a, CaseClassSelector(ct, exprFromScrut, id)) }
if (res.forall(_.isDefined)) {
Some(obind(ob, expr, exprFromScrut) ++ res.flatten.flatten)
} else {
None
}
} else {
None
}
case (up @ UnapplyPattern(ob, _, subs), scrut) =>
e(functionInvocation(up.unapplyFun.fd, Seq(scrut))) match {
case (CaseClass(CaseClassType(cd, _), Seq()), eBuilt) if cd == program.library.None.get =>
None
case (CaseClass(CaseClassType(cd, _), Seq(arg)), eBuilt) if cd == program.library.Some.get =>
val res = (subs zip unwrapTuple(arg, subs.size)).zipWithIndex map {
case ((s, a), i) => matchesPattern(s, a, tupleSelect(eBuilt, i + 1, subs.size))
}
if (res.forall(_.isDefined)) {
Some(obind(ob, expr, eBuilt) ++ res.flatten.flatten)
} else {
None
}
case other =>
throw EvalError(typeErrorMsg(other._1, up.unapplyFun.returnType))
}
case (TuplePattern(ob, subs), Tuple(args)) =>
if (subs.size == args.size) {
val res = (subs zip args).zipWithIndex.map{ case ((s, a), i) => matchesPattern(s, a, TupleSelect(exprFromScrut, i + 1)) }
if (res.forall(_.isDefined)) {
Some(obind(ob, expr, exprFromScrut) ++ res.flatten.flatten)
} else {
None
}
} else {
None
}
case (LiteralPattern(ob, l1) , l2 : Literal[_]) if l1 == l2 =>
Some(obind(ob, l1, exprFromScrut))
case _ => None
}
def obind(ob: Option[Identifier], e: Expr, eBuilder: Expr): Map[Identifier, (Expr, Expr)] = {
Map[Identifier, (Expr, Expr)]() ++ ob.map(id => id -> ((e, eBuilder)))
}
caze match {
case SimpleCase(p, rhs) =>
matchesPattern(p, scrut, scrut).map(r =>
(caze, r)
)
case GuardedCase(p, g, rhs) =>
for {
r <- matchesPattern(p, scrut, scrut)
if e(g)(rctx.withNewVars2(r), gctx)._1 == BooleanLiteral(true)
} yield (caze, r)
}
}
}
......@@ -69,21 +69,26 @@ object AbstractTests {
val testFd = funDef("AbstractTests.test2")
val Leaf = cc("AbstractTests.Leaf")()
def Node(left: Expr, n: Expr, right: Expr) = cc("AbstractTests.Node")(left, n, right)
val NodeDef = classDef("AbstractTests.Node")
val NodeType = classType("AbstractTests.Node", Seq()).asInstanceOf[CaseClassType]
val ae = new AbstractEvaluator(fix._1, fix._2)
val res = ae.eval(FunctionInvocation(testFd.typed, Seq(BooleanLiteral(true), Node(Leaf, IntLiteral(5), Leaf)))).result match {
val input = Node(Leaf, IntLiteral(5), Leaf)
val res = ae.eval(FunctionInvocation(testFd.typed, Seq(BooleanLiteral(true), input))).result match {
case Some((e, expr)) =>
e should equal (IntLiteral(5))
expr should equal (IntLiteral(5))
expr should equal (CaseClassSelector(NodeType, input, NodeDef.fieldsIds(1)))
case None =>
fail("No result!")
}
val a = id("a", Int32Type)
ae.eval(FunctionInvocation(testFd.typed, Seq(BooleanLiteral(true), Node(Leaf, Variable(a), Leaf)))).result match {
val input2 = Node(Leaf, Variable(a), Leaf)
ae.eval(FunctionInvocation(testFd.typed, Seq(BooleanLiteral(true), input2))).result match {
case Some((e, expr)) =>
e should equal (Variable(a))
expr should equal (Variable(a))
expr should equal (CaseClassSelector(NodeType, input2, NodeDef.fieldsIds(1)))
case None =>
fail("No result!")
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment