diff --git a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala index dcc8c493f067cb9a3018abb75784dec1b04519c4..68ce7c5c80f98b93db1281c488d701c1c808927c 100644 --- a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala +++ b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala @@ -9,6 +9,7 @@ import purescala.TreeOps._ import purescala.Trees._ import purescala.TypeTrees._ import purescala.Constructors._ +import purescala.Extractors._ import solvers.TimeoutSolver @@ -523,7 +524,6 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int (caze, r) ) - case GuardedCase(p, g, rhs) => matchesPattern(p, scrut).flatMap( r => e(g)(rctx.withNewVars(r), gctx) match { diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala index 0b5bf20a7316bf9597a98fd32eb3d1f4bf653243..19c2daf00af172190286c5f914c32d4009a5e7d7 100644 --- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala +++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala @@ -20,7 +20,7 @@ import purescala.Definitions.{ import purescala.Trees.{Expr => LeonExpr, This => LeonThis, _} import purescala.TypeTrees.{TypeTree => LeonType, _} import purescala.Common._ -import purescala.Extractors.{IsTyped,UnwrapTuple} +import purescala.Extractors._ import purescala.Constructors._ import purescala.TreeOps._ import purescala.TypeTreeOps._ diff --git a/src/main/scala/leon/purescala/Constructors.scala b/src/main/scala/leon/purescala/Constructors.scala index bb3b8156f2105940dce7a4914dee25f235594cea..6d4d1575ec98e0c7adb4669624065fe3fa8988c4 100644 --- a/src/main/scala/leon/purescala/Constructors.scala +++ b/src/main/scala/leon/purescala/Constructors.scala @@ -122,4 +122,5 @@ object Constructors { case (l1, Implies(l2, r2)) => implies(and(l1, l2), r2) case _ => Implies(lhs, rhs) } + } diff --git a/src/main/scala/leon/purescala/Extractors.scala b/src/main/scala/leon/purescala/Extractors.scala index bfbcf02676071f998fba1856555630f2801392c9..5ee3980484a788a6c32db86511eac260e5f79d33 100644 --- a/src/main/scala/leon/purescala/Extractors.scala +++ b/src/main/scala/leon/purescala/Extractors.scala @@ -241,6 +241,22 @@ object Extractors { } } + object SimpleCase { + def apply(p : Pattern, rhs : Expr) = MatchCase(p, None, rhs) + def unapply(c : MatchCase) = c match { + case MatchCase(p, None, rhs) => Some((p, rhs)) + case _ => None + } + } + + object GuardedCase { + def apply(p : Pattern, g: Expr, rhs : Expr) = MatchCase(p, Some(g), rhs) + def unapply(c : MatchCase) = c match { + case MatchCase(p, Some(g), rhs) => Some((p, g, rhs)) + case _ => None + } + } + object Pattern { def unapply(p : Pattern) : Option[( Option[Identifier], diff --git a/src/main/scala/leon/purescala/FunctionClosure.scala b/src/main/scala/leon/purescala/FunctionClosure.scala index 61fca12d3b573884101cc0d82a15cd242e5c8daf..7cdfce833fac2e957c04bb16b51426442234bbd1 100644 --- a/src/main/scala/leon/purescala/FunctionClosure.scala +++ b/src/main/scala/leon/purescala/FunctionClosure.scala @@ -143,24 +143,15 @@ object FunctionClosure extends TransformationPhase { } case m @ MatchExpr(scrut,cses) => { val scrutRec = functionClosure(scrut, bindedVars, id2freshId, fd2FreshFd) - val csesRec = cses.map{ - case SimpleCase(pat, rhs) => { - val binders = pat.binders - val cond = conditionForPattern(scrut, pat) - pathConstraints ::= cond - val rRhs = functionClosure(rhs, bindedVars ++ binders, id2freshId, fd2FreshFd) - pathConstraints = pathConstraints.tail - SimpleCase(pat, rRhs) - } - case GuardedCase(pat, guard, rhs) => { - val binders = pat.binders - val cond = conditionForPattern(scrut, pat) - pathConstraints ::= cond - val rRhs = functionClosure(rhs, bindedVars ++ binders, id2freshId, fd2FreshFd) - val rGuard = functionClosure(guard, bindedVars ++ binders, id2freshId, fd2FreshFd) - pathConstraints = pathConstraints.tail - GuardedCase(pat, rGuard, rRhs) - } + val csesRec = cses.map{ cse => + import cse._ + val binders = pattern.binders + val cond = conditionForPattern(scrut, pattern) + pathConstraints ::= cond + val rRhs = functionClosure(rhs, bindedVars ++ binders, id2freshId, fd2FreshFd) + val rGuard = optGuard map { functionClosure(_, bindedVars ++ binders, id2freshId, fd2FreshFd) } + pathConstraints = pathConstraints.tail + MatchCase(pattern, rGuard, rRhs) } val tpe = csesRec.head.rhs.getType matchExpr(scrutRec, csesRec).copiedFrom(m) diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala index 08dee1c712eed27f09037d393c9c043b38b341fb..b31c3db9fa7a21618865296a6998b6816b2fc4c4 100644 --- a/src/main/scala/leon/purescala/PrettyPrinter.scala +++ b/src/main/scala/leon/purescala/PrettyPrinter.scala @@ -407,13 +407,9 @@ class PrettyPrinter(opts: PrinterOptions, val sb: StringBuffer = new StringBuffe } // Cases - case SimpleCase(pat, rhs) => - p"""|case $pat => - | $rhs""" - - case GuardedCase(pat, guard, rhs) => - p"""|case $pat if $guard => - | $rhs""" + case MatchCase(pat, optG, rhs) => + p"|case $pat "; optG foreach { g => p"if $g "}; p"""=> + | $rhs""" // Patterns case WildcardPattern(None) => p"_" diff --git a/src/main/scala/leon/purescala/ScopeSimplifier.scala b/src/main/scala/leon/purescala/ScopeSimplifier.scala index 010c708125ddb9a3360f0c03a15f0c2fe605f256..97ba8a1fe48171c19414f678efca5bc44c0deaa4 100644 --- a/src/main/scala/leon/purescala/ScopeSimplifier.scala +++ b/src/main/scala/leon/purescala/ScopeSimplifier.scala @@ -114,13 +114,7 @@ class ScopeSimplifier extends Transformer { MatchExpr(rs, cases.map { c => val (newP, newScope) = trPattern(c.pattern, scope) - - c match { - case SimpleCase(p, rhs) => - SimpleCase(newP, rec(rhs, newScope)) - case GuardedCase(p, g, rhs) => - GuardedCase(newP, rec(g, newScope), rec(rhs, newScope)) - } + MatchCase(newP, c.optGuard map {rec(_, newScope)}, rec(c.rhs, newScope)) }) case Variable(id) => diff --git a/src/main/scala/leon/purescala/TransformerWithPC.scala b/src/main/scala/leon/purescala/TransformerWithPC.scala index 858f02a5f5eda3f2c4c34ec39bd2e80336d04520..41bf10a03a50b3cc3b182918617901e29dcc027a 100644 --- a/src/main/scala/leon/purescala/TransformerWithPC.scala +++ b/src/main/scala/leon/purescala/TransformerWithPC.scala @@ -35,13 +35,9 @@ abstract class TransformerWithPC extends Transformer { val subPath = register(patternExprPos, soFar) soFar = register(Not(patternExprNeg), soFar) + + MatchCase(c.pattern, c.optGuard, rec(c.rhs,subPath)).copiedFrom(c) - c match { - case SimpleCase(p, rhs) => - SimpleCase(p, rec(rhs, subPath)).copiedFrom(c) - case GuardedCase(p, g, rhs) => - GuardedCase(p, g, rec(rhs, subPath)).copiedFrom(c) - } }).copiedFrom(e) case LetTuple(is, e, b) => diff --git a/src/main/scala/leon/purescala/TreeOps.scala b/src/main/scala/leon/purescala/TreeOps.scala index 4eddf7cdb5937030611a9a3ed1517f4a0db93ffd..33d2b564ceaf7c3dc386a65b853484225555efb5 100644 --- a/src/main/scala/leon/purescala/TreeOps.scala +++ b/src/main/scala/leon/purescala/TreeOps.scala @@ -418,11 +418,12 @@ object TreeOps { val allBinders: Set[Identifier] = cse.pattern.binders val subMap: Map[Identifier,Identifier] = Map(allBinders.map(i => (i, FreshIdentifier(i.name, true).setType(i.getType))).toSeq : _*) val subVarMap: Map[Expr,Expr] = subMap.map(kv => (Variable(kv._1) -> Variable(kv._2))) - - cse match { - case SimpleCase(pattern, rhs) => SimpleCase(rewritePattern(pattern, subMap), replace(subVarMap, rhs)) - case GuardedCase(pattern, guard, rhs) => GuardedCase(rewritePattern(pattern, subMap), replace(subVarMap, guard), replace(subVarMap, rhs)) - } + + MatchCase( + rewritePattern(cse.pattern, subMap), + cse.optGuard map { replace(subVarMap, _)}, + replace(subVarMap,cse.rhs) + ) } @@ -644,9 +645,9 @@ object TreeOps { case unhandled => scala.sys.error("Unhandled case in expandLets: " + unhandled) } - def inCase(cse: MatchCase, s: Map[Identifier,Expr]) : MatchCase = cse match { - case SimpleCase(pat, rhs) => SimpleCase(pat, rec(rhs, s)) - case GuardedCase(pat, guard, rhs) => GuardedCase(pat, rec(guard, s), rec(rhs, s)) + def inCase(cse: MatchCase, s: Map[Identifier,Expr]) : MatchCase = { + import cse._ + MatchCase(pattern, optGuard map { rec(_, s) }, rec(rhs,s)) } rec(expr, Map.empty) @@ -1575,18 +1576,16 @@ object TreeOps { } (cs1 zip cs2).forall { - case (SimpleCase(p1, e1), SimpleCase(p2, e2)) => + case (MatchCase(p1, g1, e1), MatchCase(p2, g2, e2)) => val (h, nm) = patternHomo(p1, p2) + val g = (g1, g2) match { + case (Some(g1), Some(g2)) => isHomo(g1,g2)(map ++ nm) + case (None, None) => true + case _ => false + } + val e = isHomo(e1, e2)(map ++ nm) - h && isHomo(e1, e2)(map ++ nm) - - case (GuardedCase(p1, g1, e1), GuardedCase(p2, g2, e2)) => - val (h, nm) = patternHomo(p1, p2) - - h && isHomo(g1, g2)(map ++ nm) && isHomo(e1, e2)(map ++ nm) - - case _ => - false + g && e && h } } diff --git a/src/main/scala/leon/purescala/Trees.scala b/src/main/scala/leon/purescala/Trees.scala index 70b9c2116c68aa0b1717a6cc3cc5f873389982f3..88fd34b3e63ebb43ef16ae3cb7726483d44ef582 100644 --- a/src/main/scala/leon/purescala/Trees.scala +++ b/src/main/scala/leon/purescala/Trees.scala @@ -212,21 +212,10 @@ object Trees { } - sealed abstract class MatchCase extends Tree { - val pattern: Pattern - val rhs: Expr - val optGuard: Option[Expr] + case class MatchCase(pattern : Pattern, optGuard : Option[Expr], rhs: Expr) extends Tree { def expressions: Seq[Expr] = List(rhs) ++ optGuard } - case class SimpleCase(pattern: Pattern, rhs: Expr) extends MatchCase { - val optGuard = None - } - - case class GuardedCase(pattern: Pattern, guard: Expr, rhs: Expr) extends MatchCase { - val optGuard = Some(guard) - } - sealed abstract class Pattern extends Tree { val subPatterns: Seq[Pattern] val binder: Option[Identifier] diff --git a/src/main/scala/leon/refactor/Repairman.scala b/src/main/scala/leon/refactor/Repairman.scala index e34e20fdfbcd192c83b3a9052e5fb88eff5315f7..37bf364613d2c08d9bcb267d1b6da367fb196aa4 100644 --- a/src/main/scala/leon/refactor/Repairman.scala +++ b/src/main/scala/leon/refactor/Repairman.scala @@ -46,10 +46,10 @@ class Repairman(ctx: LeonContext, program: Program, fd: FunDef) { val testsCases = inouts.collect { case InOutExample(ins, outs) => val (patt, optGuard) = expressionToPattern(tupleWrap(ins)) - optGuard match { - case BooleanLiteral(true) => SimpleCase(patt, tupleWrap(outs)) - case guard => GuardedCase(WildcardPattern(None), guard, tupleWrap(outs)) - } + MatchCase(patt, optGuard match { + case BooleanLiteral(true) => None + case guard => Some(guard) + }, tupleWrap(outs)) }.toList val passes = if (testsCases.nonEmpty) { diff --git a/src/main/scala/leon/synthesis/rules/GuidedDecomp.scala b/src/main/scala/leon/synthesis/rules/GuidedDecomp.scala index 03550325e33851669d9035b87a24c62d16ce021b..af2ad3b6d5e5d84ad4eaf037c300a77bd01c6a98 100644 --- a/src/main/scala/leon/synthesis/rules/GuidedDecomp.scala +++ b/src/main/scala/leon/synthesis/rules/GuidedDecomp.scala @@ -80,10 +80,7 @@ case object GuidedDecomp extends Rule("Guided Decomp") { val onSuccess: List[Solution] => Option[Solution] = { subs => val cases = for ((c, s) <- cs zip subs) yield { - c match { - case SimpleCase(c, rhs) => SimpleCase(c, s.term) - case GuardedCase(c, g, rhs) => GuardedCase(c, g, s.term) - } + c.copy(rhs = s.term) } Some(Solution( diff --git a/src/main/scala/leon/termination/StructuralSize.scala b/src/main/scala/leon/termination/StructuralSize.scala index ff9c79fbc0ccb9d4c9e5345cc3e10794f9326af4..25965c4ad6c71a9fa09c41a3ca5df674d0885f69 100644 --- a/src/main/scala/leon/termination/StructuralSize.scala +++ b/src/main/scala/leon/termination/StructuralSize.scala @@ -53,7 +53,7 @@ trait StructuralSize { val argumentPatterns = arguments.map(id => WildcardPattern(Some(id))) val sizes = arguments.map(id => size(Variable(id))) val result = sizes.foldLeft[Expr](IntLiteral(1))(Plus(_,_)) - SimpleCase(CaseClassPattern(None, c, argumentPatterns), result) + purescala.Extractors.SimpleCase(CaseClassPattern(None, c, argumentPatterns), result) } expr.getType match { diff --git a/src/main/scala/leon/utils/UnitElimination.scala b/src/main/scala/leon/utils/UnitElimination.scala index 9761a7b6007c0a885ef61491f5393bdc65115e40..82bb880e393e7fc48402e8897a475b9fbf8940e0 100644 --- a/src/main/scala/leon/utils/UnitElimination.scala +++ b/src/main/scala/leon/utils/UnitElimination.scala @@ -142,9 +142,8 @@ object UnitElimination extends TransformationPhase { case (t: Terminal) => t case m @ MatchExpr(scrut, cses) => { val scrutRec = removeUnit(scrut) - val csesRec = cses.map{ - case SimpleCase(pat, rhs) => SimpleCase(pat, removeUnit(rhs)) - case GuardedCase(pat, guard, rhs) => GuardedCase(pat, removeUnit(guard), removeUnit(rhs)) + val csesRec = cses.map{ cse => + MatchCase(cse.pattern, cse.optGuard map removeUnit, removeUnit(cse.rhs)) } val tpe = csesRec.head.rhs.getType matchExpr(scrutRec, csesRec).setPos(m) diff --git a/src/main/scala/leon/xlang/ArrayTransformation.scala b/src/main/scala/leon/xlang/ArrayTransformation.scala index a66541178b7e0bd962d842d005454d3a5ee1c786..c8688cc0f1f2a66a6b9adb631e43e5716c648ff5 100644 --- a/src/main/scala/leon/xlang/ArrayTransformation.scala +++ b/src/main/scala/leon/xlang/ArrayTransformation.scala @@ -77,10 +77,7 @@ object ArrayTransformation extends TransformationPhase { case m @ MatchExpr(scrut, cses) => { val scrutRec = transform(scrut) - val csesRec = cses.map{ - case SimpleCase(pat, rhs) => SimpleCase(pat, transform(rhs)) - case GuardedCase(pat, guard, rhs) => GuardedCase(pat, transform(guard), transform(rhs)) - } + val csesRec = cses.map{ cse => MatchCase(cse.pattern, cse.optGuard map transform, transform(cse.rhs)) } val tpe = csesRec.head.rhs.getType matchExpr(scrutRec, csesRec).setPos(m) } diff --git a/src/main/scala/leon/xlang/ImperativeCodeElimination.scala b/src/main/scala/leon/xlang/ImperativeCodeElimination.scala index cb829bece8aea13d3704974f74434d409cf569f3..33056d81bd70247c1f1cc906b81f7e65eae11288 100644 --- a/src/main/scala/leon/xlang/ImperativeCodeElimination.scala +++ b/src/main/scala/leon/xlang/ImperativeCodeElimination.scala @@ -123,8 +123,7 @@ object ImperativeCodeElimination extends LeonPhase[Program, (Program, Set[FunDef case (cVal, cScope) => replaceNames(scrutFun, cScope(cVal)) } val matchE = matchExpr(scrutRes, cses.zip(newRhs).map{ - case (sc @ SimpleCase(pat, _), newRhs) => SimpleCase(pat, newRhs).setPos(sc) - case (gc @ GuardedCase(pat, guard, _), newRhs) => GuardedCase(pat, replaceNames(scrutFun, guard), newRhs).setPos(gc) + case (mc @ MatchCase(pat, guard, _), newRhs) => MatchCase(pat, guard map { replaceNames(scrutFun, _)}, newRhs).setPos(mc) }).setPos(m) val scope = ((body: Expr) => {