Skip to content
Snippets Groups Projects
Commit 1bcd4533 authored by Manos Koukoutos's avatar Manos Koukoutos
Browse files

Finish pattern/guard support, improvements

parent c583cc94
No related branches found
No related tags found
No related merge requests found
...@@ -20,7 +20,7 @@ trait DSL { ...@@ -20,7 +20,7 @@ trait DSL {
case SafeSimplify => e2 case SafeSimplify => e2
} }
implicit class ExprDSL(e: Expr)(implicit simpLvl: SimplificationLevel) { implicit class ExprOps(e: Expr)(implicit simpLvl: SimplificationLevel) {
private def binOp( private def binOp(
e1: (Expr, Expr) => Expr, e1: (Expr, Expr) => Expr,
...@@ -86,8 +86,6 @@ trait DSL { ...@@ -86,8 +86,6 @@ trait DSL {
def asInstOf(tp: ClassType) = unOp(AsInstanceOf(_, tp), symbols.asInstOf(_, tp)) def asInstOf(tp: ClassType) = unOp(AsInstanceOf(_, tp), symbols.asInstOf(_, tp))
} }
private def tpl(es: Expr*) = Tuple(es.toSeq)
// Literals // Literals
def L(i: Int): Expr = IntLiteral(i) def L(i: Int): Expr = IntLiteral(i)
def L(b: BigInt): Expr = IntegerLiteral(b) def L(b: BigInt): Expr = IntegerLiteral(b)
...@@ -96,9 +94,7 @@ trait DSL { ...@@ -96,9 +94,7 @@ trait DSL {
def L(): Expr = UnitLiteral() def L(): Expr = UnitLiteral()
def L(n: BigInt, d: BigInt) = FractionLiteral(n, d) def L(n: BigInt, d: BigInt) = FractionLiteral(n, d)
def L(s: String): Expr = StringLiteral(s) def L(s: String): Expr = StringLiteral(s)
def L(e1: Expr, e2: Expr): Expr = tpl(e1, e2) def L(e1: Expr, e2: Expr, es: Expr*): Expr = Tuple(e1 :: e2 :: es.toList)
def L(e1: Expr, e2: Expr, e3: Expr): Expr = tpl(e1, e2, e3)
def L(e1: Expr, e2: Expr, e3: Expr, e4: Expr): Expr = tpl(e1, e2, e3, e4)
def L(s: Set[Expr]) = { def L(s: Set[Expr]) = {
require(s.nonEmpty) require(s.nonEmpty)
FiniteSet(s.toSeq, leastUpperBound(s.toSeq map (_.getType)).get) FiniteSet(s.toSeq, leastUpperBound(s.toSeq map (_.getType)).get)
...@@ -167,80 +163,105 @@ trait DSL { ...@@ -167,80 +163,105 @@ trait DSL {
def assertion(e: Expr) = new BlockSuspension(Assert(e, None, _)) def assertion(e: Expr) = new BlockSuspension(Assert(e, None, _))
def assertion(e: Expr, msg: String) = new BlockSuspension(Assert(e, Some(msg), _)) def assertion(e: Expr, msg: String) = new BlockSuspension(Assert(e, Some(msg), _))
// Pattern-matching
implicit class PatternMatch(scrut: Expr) { implicit class PatternMatch(scrut: Expr) {
def matchOn(cases: MatchCase* ) = { def matchOn(cases: MatchCase* ) = {
MatchExpr(scrut, cases.toList) MatchExpr(scrut, cases.toList)
} }
} }
implicit class PatternSuspension(pat: Pattern) { //Patterns
// This introduces the rhs of a case given a pattern
implicit class PatternOps(pat: Pattern) {
val guard: Option[Expr] = None
def ~> (rhs: => Expr) = { def ==>(rhs: => Expr) = {
val Seq() = pat.binders val Seq() = pat.binders
MatchCase(pat, None, rhs) MatchCase(pat, guard, rhs)
} }
def ~> (rhs: Expr => Expr) = { def ==>(rhs: Expr => Expr) = {
val Seq(b1) = pat.binders val Seq(b1) = pat.binders
MatchCase(pat, None, rhs(b1.toVariable)) MatchCase(pat, guard, rhs(b1.toVariable))
} }
def ~> (rhs: (Expr, Expr) => Expr) = { def ==>(rhs: (Expr, Expr) => Expr) = {
val Seq(b1, b2) = pat.binders val Seq(b1, b2) = pat.binders
MatchCase(pat, None, rhs(b1.toVariable, b2.toVariable)) MatchCase(pat, guard, rhs(b1.toVariable, b2.toVariable))
} }
def ~> (rhs: (Expr, Expr, Expr) => Expr) = { def ==>(rhs: (Expr, Expr, Expr) => Expr) = {
val Seq(b1, b2, b3) = pat.binders val Seq(b1, b2, b3) = pat.binders
MatchCase(pat, None, rhs(b1.toVariable, b2.toVariable, b3.toVariable)) MatchCase(pat, guard, rhs(b1.toVariable, b2.toVariable, b3.toVariable))
} }
def ~> (rhs: (Expr, Expr, Expr, Expr) => Expr) = { def ==>(rhs: (Expr, Expr, Expr, Expr) => Expr) = {
val Seq(b1, b2, b3, b4) = pat.binders val Seq(b1, b2, b3, b4) = pat.binders
MatchCase(pat, None, MatchCase(pat, guard,
rhs(b1.toVariable, b2.toVariable, b3.toVariable, b4.toVariable)) rhs(b1.toVariable, b2.toVariable, b3.toVariable, b4.toVariable))
} }
def ~|~(g: Expr) = new PatternOpsWithGuard(pat, g)
}
class PatternOpsWithGuard(pat: Pattern, g: Expr) extends PatternOps(pat) {
override val guard = Some(g)
override def ~|~(g: Expr) = sys.error("Redefining guard!")
} }
private def l2p[T](l: Literal[T]) = LiteralPattern(None, l) private def l2p[T](l: Literal[T]) = LiteralPattern(None, l)
// Literal patterns
def P(i: Int) = l2p(IntLiteral(i)) def P(i: Int) = l2p(IntLiteral(i))
def P(b: BigInt) = l2p(IntegerLiteral(b)) def P(b: BigInt) = l2p(IntegerLiteral(b))
def P(b: Boolean) = l2p(BooleanLiteral(b)) def P(b: Boolean) = l2p(BooleanLiteral(b))
def P(c: Char) = l2p(CharLiteral(c)) def P(c: Char) = l2p(CharLiteral(c))
def P() = l2p(UnitLiteral()) def P() = l2p(UnitLiteral())
def P(ps: Pattern*) = TuplePattern(None, ps.toSeq) // Binder-only patterns
def P(vd: ValDef) = WildcardPattern(Some(vd))
class CaseClassToPattern(ct: ClassType) {
def apply(ps: Pattern*) = CaseClassPattern(None, ct, ps.toSeq)
}
// case class patterns
def P(ct: ClassType) = new CaseClassToPattern(ct)
// Tuple patterns
def P(p1:Pattern, p2: Pattern, ps: Pattern*) = TuplePattern(None, p1 :: p2 :: ps.toList)
// Wildcard pattern
def __ = WildcardPattern(None) def __ = WildcardPattern(None)
// Attach binder to pattern
implicit class BinderToPattern(b: ValDef) { implicit class BinderToPattern(b: ValDef) {
def @@ (p: Pattern) = p.withBinder(b) def @@ (p: Pattern) = p.withBinder(b)
} }
implicit class CaseClassToPattern(ct: ClassType) { // Instance-of patterns
def pat(ps: Pattern*) = CaseClassPattern(None, ct, ps.toSeq)
}
implicit class TypeToInstanceOfPattern(ct: ClassType) { implicit class TypeToInstanceOfPattern(ct: ClassType) {
def :: (vd: Option[ValDef]) = InstanceOfPattern(vd, ct)
def :: (vd: ValDef) = InstanceOfPattern(Some(vd), ct) def :: (vd: ValDef) = InstanceOfPattern(Some(vd), ct)
def :: (wp: WildcardPattern) = {
if (wp.binder.nonEmpty) sys.error("Instance of pattern with named wildcardpattern?")
else InstanceOfPattern(None, ct)
} // TODO Kinda dodgy...
} }
// TODO: Remove this at some point // TODO: Remove this at some point
private def test(e1: Expr, e2: Expr, ct: ClassType)(implicit simpl: SimplificationLevel) = { private def test(e1: Expr, e2: Expr, ct: ClassType)(implicit simpl: SimplificationLevel) = {
prec(e1) in prec(e1) in
let("i" :: Untyped, e1) { i => let("i" :: Untyped, e1) { i =>
if_ (\("j" :: Untyped)(j => e1(j))) { if_ (\("j" :: Untyped)(j => e1(j))) {
e1 + e2 + i + L(42) e1 + e2 + i + L(42)
} else_ { } else_ {
assertion(L(true), "Weird things") in assertion(L(true), "Weird things") in
ct(e1, e2) matchOn ( ct(e1, e2) matchOn (
ct.pat( P(ct)(
("i" :: Untyped) :: ct, P(42), ("i" :: Untyped) :: ct,
P(P(__), ( "j" :: Untyped) @@ P(42)) P(42),
) ~> { __ :: ct,
(i, j) => e1 P("k" :: Untyped),
}, P(__, ( "j" :: Untyped) @@ P(42))
__ ~> e2 ) ==> {
) (i, j, k) => e1
} },
__ ~|~ e1 ==> e2
)
} }
}
} ensures e2 } ensures e2
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment