From 90156b1f287db448be86b4a86da7438ebdf6e536 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9gis=20Blanc?= <regwblanc@gmail.com> Date: Mon, 2 Apr 2012 13:44:38 +0200 Subject: [PATCH] Can use pattern matching with side effects --- mytest/Match.scala | 18 ++++++++ src/main/scala/leon/FunctionClosure.scala | 5 ++- src/main/scala/leon/FunctionHoisting.scala | 15 ++++++- .../leon/ImperativeCodeElimination.scala | 44 ++++++++++++++++++- src/main/scala/leon/UnitElimination.scala | 10 ++++- 5 files changed, 87 insertions(+), 5 deletions(-) create mode 100644 mytest/Match.scala diff --git a/mytest/Match.scala b/mytest/Match.scala new file mode 100644 index 000000000..5376cc61d --- /dev/null +++ b/mytest/Match.scala @@ -0,0 +1,18 @@ +object Match { + + sealed abstract class A + case class B(b: Int) extends A + case class C(c: Int) extends A + + def foo(a: A): Int = ({ + + var i = 0 + var j = 0 + + {i = i + 1; a} match { + case B(b) => {i = i + 1; b} + case C(c) => {j = j + 1; i = i + 1; c} + } + i + }) ensuring(_ == 2) +} diff --git a/src/main/scala/leon/FunctionClosure.scala b/src/main/scala/leon/FunctionClosure.scala index 489568a88..29ff5341d 100644 --- a/src/main/scala/leon/FunctionClosure.scala +++ b/src/main/scala/leon/FunctionClosure.scala @@ -112,7 +112,10 @@ object FunctionClosure extends Pass { pathConstraints = pathConstraints.tail IfExpr(rCond, rThen, rElze).setType(i.getType) } - case m @ MatchExpr(scrut,cses) => sys.error("Will see")//MatchExpr(rec(scrut), cses.map(inCase(_))).setType(m.getType).setPosInfo(m) + case m @ MatchExpr(scrut,cses) => { + //val rScrut = functionClosure(scrut, bindedVars) + m + } case t if t.isInstanceOf[Terminal] => t case unhandled => scala.sys.error("Non-terminal case should be handled in searchAndReplace: " + unhandled) } diff --git a/src/main/scala/leon/FunctionHoisting.scala b/src/main/scala/leon/FunctionHoisting.scala index 88881e2ce..ec02d4336 100644 --- a/src/main/scala/leon/FunctionHoisting.scala +++ b/src/main/scala/leon/FunctionHoisting.scala @@ -53,7 +53,20 @@ object FunctionHoisting extends Pass { val (r3, s3) = hoist(t3) (IfExpr(r1, r2, r3).setType(i.getType), s1 ++ s2 ++ s3) } - case m @ MatchExpr(scrut,cses) => sys.error("We'll see")//MatchExpr(rec(scrut), cses.map(inCase(_))).setType(m.getType).setPosInfo(m) + case m @ MatchExpr(scrut,cses) => { + val (scrutRes, scrutSet) = hoist(scrut) + val (csesRes, csesSets) = cses.map{ + case SimpleCase(pat, rhs) => { + val (r, s) = hoist(rhs) + (SimpleCase(pat, r), s) + } + case GuardedCase(pat, guard, rhs) => { + val (r, s) = hoist(rhs) + (GuardedCase(pat, guard, r), s) + } + }.unzip + (MatchExpr(scrutRes, csesRes).setType(m.getType), csesSets.toSet.flatten ++ scrutSet) + } case t if t.isInstanceOf[Terminal] => (t, Set()) case unhandled => scala.sys.error("Non-terminal case should be handled in searchAndReplace: " + unhandled) } diff --git a/src/main/scala/leon/ImperativeCodeElimination.scala b/src/main/scala/leon/ImperativeCodeElimination.scala index 744c0b383..1ec71eea8 100644 --- a/src/main/scala/leon/ImperativeCodeElimination.scala +++ b/src/main/scala/leon/ImperativeCodeElimination.scala @@ -197,13 +197,53 @@ object ImperativeCodeElimination extends Pass { } case (t: Terminal) => (t, (body: Expr) => body, Map()) - case m @ MatchExpr(scrut, cses) => sys.error("not supported: " + expr) + case m @ MatchExpr(scrut, cses) => { + val csesRhs = cses.map(_.rhs) //we can ignore pattern, and the guard is required to be pure + val (csesRes, csesScope, csesFun) = csesRhs.map(toFunction).unzip3 + val (scrutRes, scrutScope, scrutFun) = toFunction(scrut) + + val modifiedVars: Seq[Identifier] = csesFun.toSet.flatMap((m: Map[Identifier, Identifier]) => m.keys).toSeq + val resId = FreshIdentifier("res").setType(m.getType) + val freshIds = modifiedVars.map(id => FreshIdentifier(id.name).setType(id.getType)) + val matchType = if(modifiedVars.isEmpty) resId.getType else TupleType(resId.getType +: freshIds.map(_.getType)) + + val csesVals = csesRes.zip(csesFun).map{ + case (cRes, cFun) => (if(modifiedVars.isEmpty) cRes else Tuple(cRes +: modifiedVars.map(vId => cFun.get(vId) match { + case Some(newId) => newId.toVariable + case None => vId.toVariable + }))).setType(matchType) + } + + val newRhs = csesVals.zip(csesScope).map{ + case (cVal, cScope) => replaceNames(scrutFun, cScope(cVal)).setType(matchType) + } + val matchExpr = MatchExpr(scrutRes, cses.zip(newRhs).map{ + case (SimpleCase(pat, _), newRhs) => SimpleCase(pat, newRhs) + case (GuardedCase(pat, guard, _), newRhs) => GuardedCase(pat, replaceNames(scrutFun, guard), newRhs) + }).setType(matchType) + + val scope = ((body: Expr) => { + val tupleId = FreshIdentifier("t").setType(matchType) + scrutScope( + Let(tupleId, matchExpr, + if(freshIds.isEmpty) + Let(resId, tupleId.toVariable, body) + else + Let(resId, TupleSelect(tupleId.toVariable, 1), + freshIds.zipWithIndex.foldLeft(body)((b, id) => + Let(id._1, + TupleSelect(tupleId.toVariable, id._2 + 2).setType(id._1.getType), + b))))) + }) + + (resId.toVariable, scope, scrutFun ++ modifiedVars.zip(freshIds).toMap) + } case _ => sys.error("not supported: " + expr) } //val codeRepresentation = res._2(Block(res._3.map{ case (id1, id2) => Assignment(id1, id2.toVariable)}.toSeq, res._1)) //println("res of toFunction on: " + expr + " IS: " + codeRepresentation) - res.asInstanceOf[(Expr, (Expr) => Expr, Map[Identifier, Identifier])] + res.asInstanceOf[(Expr, (Expr) => Expr, Map[Identifier, Identifier])] //need cast because it seems that res first map type is _ <: Identifier instead of Identifier } def replaceNames(fun: Map[Identifier, Identifier], expr: Expr) = replace(fun.map(ids => (ids._1.toVariable, ids._2.toVariable)), expr) diff --git a/src/main/scala/leon/UnitElimination.scala b/src/main/scala/leon/UnitElimination.scala index 7200a26ec..9417f7bfb 100644 --- a/src/main/scala/leon/UnitElimination.scala +++ b/src/main/scala/leon/UnitElimination.scala @@ -119,7 +119,15 @@ object UnitElimination extends Pass { } case v @ Variable(id) => if(id2FreshId.isDefinedAt(id)) Variable(id2FreshId(id)) else v case (t: Terminal) => t - case m @ MatchExpr(scrut, cses) => sys.error("not supported: " + expr) + case m @ MatchExpr(scrut, cses) => { + val scrutRec = removeUnit(scrut) + val csesRec = cses.map{ + case SimpleCase(pat, rhs) => SimpleCase(pat, removeUnit(rhs)) + case GuardedCase(pat, guard, rhs) => GuardedCase(pat, removeUnit(guard), removeUnit(rhs)) + } + val tpe = csesRec.head.rhs.getType + MatchExpr(scrutRec, csesRec).setType(tpe) + } case _ => sys.error("not supported: " + expr) } } -- GitLab