diff --git a/src/main/scala/leon/FunctionClosure.scala b/src/main/scala/leon/FunctionClosure.scala index 85dbdf0030e3fa0b5a4cebb3ce78046e2676b56e..661a052872badf0cd6389beaad0093eec8891d54 100644 --- a/src/main/scala/leon/FunctionClosure.scala +++ b/src/main/scala/leon/FunctionClosure.scala @@ -114,9 +114,27 @@ object FunctionClosure extends Pass { val r = functionClosure(t, bindedVars, id2freshId, fd2FreshFd) recons(r).setType(u.getType) } - case m @ MatchExpr(scrut,cses) => { //TODO: will not work if there are actual nested function in cases - //val rScrut = functionClosure(scrut, bindedVars, id2freshId, fd2FreshFd) - m + case m @ MatchExpr(scrut,cses) => { //still needs to handle the new ids introduced by the patterns + val scrutRec = functionClosure(scrut, bindedVars, id2freshId, fd2FreshFd) + val csesRec = cses.map{ + case SimpleCase(pat, rhs) => { + val cond = conditionForPattern(scrut, pat) + pathConstraints ::= cond + val rRhs = functionClosure(rhs, bindedVars, id2freshId, fd2FreshFd) + pathConstraints = pathConstraints.tail + SimpleCase(pat, rRhs) + } + case GuardedCase(pat, guard, rhs) => { + val cond = conditionForPattern(scrut, pat) + pathConstraints ::= cond + val rRhs = functionClosure(rhs, bindedVars, id2freshId, fd2FreshFd) + val rGuard = functionClosure(guard, bindedVars, id2freshId, fd2FreshFd) + pathConstraints = pathConstraints.tail + GuardedCase(pat, rGuard, rRhs) + } + } + val tpe = csesRec.head.rhs.getType + MatchExpr(scrutRec, csesRec).setType(tpe).setPosInfo(m) } case v @ Variable(id) => id2freshId.get(id) match { case None => v