diff --git a/src/main/scala/leon/purescala/TreeOps.scala b/src/main/scala/leon/purescala/TreeOps.scala index d31c7a3929c82072e410cbdb766fdf493aee1654..09f6a01c89d6251140aab0e3b9c19f73fadb0f5e 100644 --- a/src/main/scala/leon/purescala/TreeOps.scala +++ b/src/main/scala/leon/purescala/TreeOps.scala @@ -1040,7 +1040,7 @@ object TreeOps { // This transformation assumes IfExpr of the form generated by decomposeIfs def patternMatchReconstruction(e: Expr): Expr = { - def pre(e: Expr): Expr = e match { + def post(e: Expr): Expr = e match { case IfExpr(cond, thenn, elze) => val TopLevelAnds(cases) = cond @@ -1133,7 +1133,20 @@ object TreeOps { p } - MatchExpr(scrutinee, Seq(SimpleCase(simplifyPattern(pattern), newThen), SimpleCase(WildcardPattern(None), elze))).setType(e.getType) + elze match { + case MatchExpr(scrut, cases) if scrut == scrutinee => + MatchExpr(scrutinee, + SimpleCase(simplifyPattern(pattern), newThen) +: + cases + ).setType(e.getType) + + case _ => + MatchExpr(scrutinee, + Seq(SimpleCase(simplifyPattern(pattern), newThen), + SimpleCase(WildcardPattern(None), elze) + ) + ).setType(e.getType) + } } else { e } @@ -1141,7 +1154,7 @@ object TreeOps { e } - simplePreTransform(pre)(e) + simplePostTransform(post)(e) } def simplifyTautologies(sf: SolverFactory[Solver])(expr : Expr) : Expr = {