From d9a82908388f1fe3dfc61c273c060f99d44e5dfe Mon Sep 17 00:00:00 2001 From: Etienne Kneuss <colder@php.net> Date: Fri, 12 Sep 2014 14:58:47 +0200 Subject: [PATCH] Evaluate Pattern matching, matchToITE no longer necessary More precise tracing, especially for in-IDE exploration of execution --- .../leon/evaluators/RecursiveEvaluator.scala | 93 +++++++++++++++---- .../leon/evaluators/TracingEvaluator.scala | 23 ++++- .../test/evaluators/EvaluatorsTests.scala | 50 ++++++++++ 3 files changed, 145 insertions(+), 21 deletions(-) diff --git a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala index 0a825de5f..da44836d2 100644 --- a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala +++ b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala @@ -71,11 +71,7 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int case Variable(id) => rctx.mappings.get(id) match { case Some(v) => - if(!isGround(v)) { - throw EvalError("Substitution for identifier " + id.name + " is not ground.") - } else { - v - } + e(v) case None => throw EvalError("No value for identifier " + id.name + " in mapping.") } @@ -121,11 +117,12 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int val frame = rctx.withVars((tfd.params.map(_.id) zip evArgs).toMap) if(tfd.hasPrecondition) { - e(matchToIfThenElse(tfd.precondition.get))(frame, gctx) match { + e(tfd.precondition.get)(frame, gctx) match { case BooleanLiteral(true) => case BooleanLiteral(false) => throw RuntimeError("Precondition violation for " + tfd.id.name + " reached in evaluation.: " + tfd.precondition.get) - case other => throw RuntimeError(typeErrorMsg(other, BooleanType)) + case other => + throw RuntimeError(typeErrorMsg(other, BooleanType)) } } @@ -134,15 +131,12 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int } val body = tfd.body.getOrElse(rctx.mappings(tfd.id)) - val callResult = e(matchToIfThenElse(body))(frame, gctx) + val callResult = e(body)(frame, gctx) if(tfd.hasPostcondition) { val (id, post) = tfd.postcondition.get - val freshResID = FreshIdentifier("result").setType(tfd.returnType) - val postBody = replace(Map(Variable(id) -> Variable(freshResID)), matchToIfThenElse(post)) - - e(matchToIfThenElse(post))(frame.withNewVar(id, callResult), gctx) match { + e(post)(frame.withNewVar(id, callResult), gctx) match { case BooleanLiteral(true) => case BooleanLiteral(false) => throw RuntimeError("Postcondition violation for " + tfd.id.name + " reached in evaluation.") case other => throw EvalError(typeErrorMsg(other, BooleanType)) @@ -430,15 +424,82 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int solver.free() } + case MatchExpr(scrut, cases) => + val rscrut = e(scrut) + + cases.toStream.map(c => matchesCase(rscrut, c)).find(_.nonEmpty) match { + case Some(Some((c, mappings))) => + e(c.rhs)(rctx.withNewVars(mappings), gctx) + case _ => + throw RuntimeError("MatchError: "+rscrut+" did not match any of the cases") + } + case other => context.reporter.error(other.getPos, "Error: don't know how to handle " + other + " in Evaluator.") throw EvalError("Unhandled case in Evaluator : " + other) } - def typeErrorMsg(tree : Expr, expected : TypeTree) : String = "Type error : expected %s, found %s.".format(expected, tree) + def matchesCase(scrut: Expr, caze: MatchCase)(implicit rctx: RC, gctx: GC): Option[(MatchCase, Map[Identifier, Expr])] = { + import purescala.TypeTreeOps.isSubtypeOf - // quick and dirty.. don't overuse. - private def isGround(expr: Expr) : Boolean = { - variablesOf(expr) == Set.empty + def matchesPattern(pat: Pattern, e: Expr): Option[Map[Identifier, Expr]] = (pat, e) match { + case (InstanceOfPattern(ob, pct), CaseClass(ct, _)) => + if (isSubtypeOf(ct, pct)) { + Some(obind(ob, e)) + } else { + None + } + case (WildcardPattern(ob), e) => + Some(obind(ob, e)) + + case (CaseClassPattern(ob, pct, subs), CaseClass(ct, args)) => + if (pct == ct) { + val res = (subs zip args).map{ case (s, a) => matchesPattern(s, a) } + if (res.forall(_.isDefined)) { + Some(obind(ob, e) ++ res.flatten.flatten) + } else { + None + } + } else { + None + } + case (TuplePattern(ob, subs), Tuple(args)) => + if (subs.size == args.size) { + val res = (subs zip args).map{ case (s, a) => matchesPattern(s, a) } + if (res.forall(_.isDefined)) { + Some(obind(ob, e) ++ res.flatten.flatten) + } else { + None + } + } else { + None + } + case _ => None + } + + def obind(ob: Option[Identifier], e: Expr): Map[Identifier, Expr] = { + Map[Identifier, Expr]() ++ ob.map(id => id -> e) + } + + caze match { + case SimpleCase(p, rhs) => + matchesPattern(p, scrut).map( r => + (caze, r) + ) + + + case GuardedCase(p, g, rhs) => + matchesPattern(p, scrut).flatMap( r => + e(g)(rctx.withNewVars(r), gctx) match { + case BooleanLiteral(true) => + Some((caze, r)) + case _ => + None + } + ) + } } + + def typeErrorMsg(tree : Expr, expected : TypeTree) : String = "Type error : expected %s, found %s.".format(expected, tree) + } diff --git a/src/main/scala/leon/evaluators/TracingEvaluator.scala b/src/main/scala/leon/evaluators/TracingEvaluator.scala index 945a3a121..273270b46 100644 --- a/src/main/scala/leon/evaluators/TracingEvaluator.scala +++ b/src/main/scala/leon/evaluators/TracingEvaluator.scala @@ -32,6 +32,20 @@ class TracingEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int = 1000) ex val res = e(b)(rctx.withNewVar(i, first), gctx) (res, first) + case MatchExpr(scrut, cases) => + val rscrut = e(scrut) + + val r = cases.toStream.map(c => matchesCase(rscrut, c)).find(_.nonEmpty) match { + case Some(Some((c, mappings))) => + gctx.values ++= mappings.map { case (id, v) => id.toVariable.setPos(id) -> v } + + e(c.rhs)(rctx.withNewVars(mappings), gctx) + case _ => + throw RuntimeError("MatchError: "+rscrut+" did not match any of the cases") + } + + (r, r) + case fi @ FunctionInvocation(tfd, args) => if (gctx.stepsLeft < 0) { throw RuntimeError("Exceeded number of allocated methods calls ("+gctx.maxSteps+")") @@ -44,7 +58,7 @@ class TracingEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int = 1000) ex val frame = new TracingRecContext((tfd.params.map(_.id) zip evArgs).toMap, rctx.tracingFrames-1) if(tfd.hasPrecondition) { - e(matchToIfThenElse(tfd.precondition.get))(frame, gctx) match { + e(tfd.precondition.get)(frame, gctx) match { case BooleanLiteral(true) => case BooleanLiteral(false) => throw RuntimeError("Precondition violation for " + tfd.id.name + " reached in evaluation.: " + tfd.precondition.get) @@ -57,15 +71,14 @@ class TracingEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int = 1000) ex } val body = tfd.body.getOrElse(rctx.mappings(tfd.id)) - val callResult = e(matchToIfThenElse(body))(frame, gctx) + val callResult = e(body)(frame, gctx) if(tfd.hasPostcondition) { val (id, post) = tfd.postcondition.get - val freshResID = FreshIdentifier("result").setType(tfd.returnType) - val postBody = replace(Map(Variable(id) -> Variable(freshResID)), matchToIfThenElse(post)) + gctx.values ::= id.toVariable.setPos(id) -> callResult - e(matchToIfThenElse(post))(frame.withNewVar(id, callResult), gctx) match { + e(post)(frame.withNewVar(id, callResult), gctx) match { case BooleanLiteral(true) => case BooleanLiteral(false) => throw RuntimeError("Postcondition violation for " + tfd.id.name + " reached in evaluation.") case other => throw EvalError(typeErrorMsg(other, BooleanType)) diff --git a/src/test/scala/leon/test/evaluators/EvaluatorsTests.scala b/src/test/scala/leon/test/evaluators/EvaluatorsTests.scala index 402cf0225..84c546c05 100644 --- a/src/test/scala/leon/test/evaluators/EvaluatorsTests.scala +++ b/src/test/scala/leon/test/evaluators/EvaluatorsTests.scala @@ -477,4 +477,54 @@ class EvaluatorsTests extends leon.test.LeonTestSuite { val e = new CodeGenEvaluator(leonContext, prog, CodeGenParams(checkContracts = true)) checkError(e, mkCall("c", IL(-42))) } + + test("Pattern Matching") { + val p = """|object Program { + | abstract class List; + | case class Cons(h: Int, t: List) extends List; + | case object Nil extends List; + | + | def f1: Int = (Cons(1, Nil): List) match { + | case Cons(h, t) => h + | case Nil => 0 + | } + | + | def f2: Int = (Cons(1, Nil): List) match { + | case Cons(h, _) => h + | case Nil => 0 + | } + | + | def f3: Int = (Nil: List) match { + | case _ => 1 + | } + | + | def f4: Int = (Cons(1, Cons(2, Nil)): List) match { + | case a: Cons => 1 + | case _ => 0 + | } + | + | def f5: Int = ((Cons(1, Nil), Nil): (List, List)) match { + | case (a: Cons, _) => 1 + | case _ => 0 + | } + | + | def f6: Int = (Cons(2, Nil): List) match { + | case Cons(h, t) if h > 0 => 1 + | case _ => 0 + | } + |}""".stripMargin + + implicit val prog = parseString(p) + val evaluators = prepareEvaluators + + for(e <- evaluators) { + // Some simple math. + checkComp(e, mkCall("f1"), IL(1)) + checkComp(e, mkCall("f2"), IL(1)) + checkComp(e, mkCall("f3"), IL(1)) + checkComp(e, mkCall("f4"), IL(1)) + checkComp(e, mkCall("f5"), IL(1)) + checkComp(e, mkCall("f6"), IL(1)) + } + } } -- GitLab