diff --git a/mytest/Match.scala b/mytest/Match.scala new file mode 100644 index 0000000000000000000000000000000000000000..5376cc61d1df009da2c8ba09ea2c57c3bf248cd7 --- /dev/null +++ b/mytest/Match.scala @@ -0,0 +1,18 @@ +object Match { + + sealed abstract class A + case class B(b: Int) extends A + case class C(c: Int) extends A + + def foo(a: A): Int = ({ + + var i = 0 + var j = 0 + + {i = i + 1; a} match { + case B(b) => {i = i + 1; b} + case C(c) => {j = j + 1; i = i + 1; c} + } + i + }) ensuring(_ == 2) +} diff --git a/src/main/scala/leon/FunctionClosure.scala b/src/main/scala/leon/FunctionClosure.scala index 489568a889f3bf9b56f1a45643286e44ed7e8383..29ff5341da5b97ae45c6940c1a4575103a506cd6 100644 --- a/src/main/scala/leon/FunctionClosure.scala +++ b/src/main/scala/leon/FunctionClosure.scala @@ -112,7 +112,10 @@ object FunctionClosure extends Pass { pathConstraints = pathConstraints.tail IfExpr(rCond, rThen, rElze).setType(i.getType) } - case m @ MatchExpr(scrut,cses) => sys.error("Will see")//MatchExpr(rec(scrut), cses.map(inCase(_))).setType(m.getType).setPosInfo(m) + case m @ MatchExpr(scrut,cses) => { + //val rScrut = functionClosure(scrut, bindedVars) + m + } case t if t.isInstanceOf[Terminal] => t case unhandled => scala.sys.error("Non-terminal case should be handled in searchAndReplace: " + unhandled) } diff --git a/src/main/scala/leon/FunctionHoisting.scala b/src/main/scala/leon/FunctionHoisting.scala index 88881e2ceebcc0957f77dc928e15db448bd72f43..ec02d43360c0180925252ab6ab8f32590dac038b 100644 --- a/src/main/scala/leon/FunctionHoisting.scala +++ b/src/main/scala/leon/FunctionHoisting.scala @@ -53,7 +53,20 @@ object FunctionHoisting extends Pass { val (r3, s3) = hoist(t3) (IfExpr(r1, r2, r3).setType(i.getType), s1 ++ s2 ++ s3) } - case m @ MatchExpr(scrut,cses) => sys.error("We'll see")//MatchExpr(rec(scrut), cses.map(inCase(_))).setType(m.getType).setPosInfo(m) + case m @ MatchExpr(scrut,cses) => { + val (scrutRes, scrutSet) = hoist(scrut) + val (csesRes, csesSets) = cses.map{ + case SimpleCase(pat, rhs) => { + val (r, s) = hoist(rhs) + (SimpleCase(pat, r), s) + } + case GuardedCase(pat, guard, rhs) => { + val (r, s) = hoist(rhs) + (GuardedCase(pat, guard, r), s) + } + }.unzip + (MatchExpr(scrutRes, csesRes).setType(m.getType), csesSets.toSet.flatten ++ scrutSet) + } case t if t.isInstanceOf[Terminal] => (t, Set()) case unhandled => scala.sys.error("Non-terminal case should be handled in searchAndReplace: " + unhandled) } diff --git a/src/main/scala/leon/ImperativeCodeElimination.scala b/src/main/scala/leon/ImperativeCodeElimination.scala index 744c0b383771f47bfc70b7361d3b2716b093fb42..1ec71eea823b2b459a702302feb65e25aac3f806 100644 --- a/src/main/scala/leon/ImperativeCodeElimination.scala +++ b/src/main/scala/leon/ImperativeCodeElimination.scala @@ -197,13 +197,53 @@ object ImperativeCodeElimination extends Pass { } case (t: Terminal) => (t, (body: Expr) => body, Map()) - case m @ MatchExpr(scrut, cses) => sys.error("not supported: " + expr) + case m @ MatchExpr(scrut, cses) => { + val csesRhs = cses.map(_.rhs) //we can ignore pattern, and the guard is required to be pure + val (csesRes, csesScope, csesFun) = csesRhs.map(toFunction).unzip3 + val (scrutRes, scrutScope, scrutFun) = toFunction(scrut) + + val modifiedVars: Seq[Identifier] = csesFun.toSet.flatMap((m: Map[Identifier, Identifier]) => m.keys).toSeq + val resId = FreshIdentifier("res").setType(m.getType) + val freshIds = modifiedVars.map(id => FreshIdentifier(id.name).setType(id.getType)) + val matchType = if(modifiedVars.isEmpty) resId.getType else TupleType(resId.getType +: freshIds.map(_.getType)) + + val csesVals = csesRes.zip(csesFun).map{ + case (cRes, cFun) => (if(modifiedVars.isEmpty) cRes else Tuple(cRes +: modifiedVars.map(vId => cFun.get(vId) match { + case Some(newId) => newId.toVariable + case None => vId.toVariable + }))).setType(matchType) + } + + val newRhs = csesVals.zip(csesScope).map{ + case (cVal, cScope) => replaceNames(scrutFun, cScope(cVal)).setType(matchType) + } + val matchExpr = MatchExpr(scrutRes, cses.zip(newRhs).map{ + case (SimpleCase(pat, _), newRhs) => SimpleCase(pat, newRhs) + case (GuardedCase(pat, guard, _), newRhs) => GuardedCase(pat, replaceNames(scrutFun, guard), newRhs) + }).setType(matchType) + + val scope = ((body: Expr) => { + val tupleId = FreshIdentifier("t").setType(matchType) + scrutScope( + Let(tupleId, matchExpr, + if(freshIds.isEmpty) + Let(resId, tupleId.toVariable, body) + else + Let(resId, TupleSelect(tupleId.toVariable, 1), + freshIds.zipWithIndex.foldLeft(body)((b, id) => + Let(id._1, + TupleSelect(tupleId.toVariable, id._2 + 2).setType(id._1.getType), + b))))) + }) + + (resId.toVariable, scope, scrutFun ++ modifiedVars.zip(freshIds).toMap) + } case _ => sys.error("not supported: " + expr) } //val codeRepresentation = res._2(Block(res._3.map{ case (id1, id2) => Assignment(id1, id2.toVariable)}.toSeq, res._1)) //println("res of toFunction on: " + expr + " IS: " + codeRepresentation) - res.asInstanceOf[(Expr, (Expr) => Expr, Map[Identifier, Identifier])] + res.asInstanceOf[(Expr, (Expr) => Expr, Map[Identifier, Identifier])] //need cast because it seems that res first map type is _ <: Identifier instead of Identifier } def replaceNames(fun: Map[Identifier, Identifier], expr: Expr) = replace(fun.map(ids => (ids._1.toVariable, ids._2.toVariable)), expr) diff --git a/src/main/scala/leon/UnitElimination.scala b/src/main/scala/leon/UnitElimination.scala index 7200a26ecb6f9d913971fbe29897efb4daa33529..9417f7bfb97e69c43da0b9ca940af2f069d21159 100644 --- a/src/main/scala/leon/UnitElimination.scala +++ b/src/main/scala/leon/UnitElimination.scala @@ -119,7 +119,15 @@ object UnitElimination extends Pass { } case v @ Variable(id) => if(id2FreshId.isDefinedAt(id)) Variable(id2FreshId(id)) else v case (t: Terminal) => t - case m @ MatchExpr(scrut, cses) => sys.error("not supported: " + expr) + case m @ MatchExpr(scrut, cses) => { + val scrutRec = removeUnit(scrut) + val csesRec = cses.map{ + case SimpleCase(pat, rhs) => SimpleCase(pat, removeUnit(rhs)) + case GuardedCase(pat, guard, rhs) => GuardedCase(pat, removeUnit(guard), removeUnit(rhs)) + } + val tpe = csesRec.head.rhs.getType + MatchExpr(scrutRec, csesRec).setType(tpe) + } case _ => sys.error("not supported: " + expr) } }