diff --git a/src/main/scala/leon/ArrayTransformation.scala b/src/main/scala/leon/ArrayTransformation.scala index 7d0d94458291b71381d8719a25e13751ca3cdd9f..38f1fce44dd79089166c566a11b17e706b0f3215 100644 --- a/src/main/scala/leon/ArrayTransformation.scala +++ b/src/main/scala/leon/ArrayTransformation.scala @@ -98,6 +98,11 @@ object ArrayTransformation extends Pass { IfExpr(rc, rt, re).setType(rt.getType) } + case c @ Choose(args, body) => + val body2 = transform(body) + + Choose(args, body2).setType(c.getType).setPosInfo(c) + case m @ MatchExpr(scrut, cses) => { val scrutRec = transform(scrut) val csesRec = cses.map{ diff --git a/src/main/scala/leon/plugin/CodeExtraction.scala b/src/main/scala/leon/plugin/CodeExtraction.scala index 1ff95e4c922d08cf0b400c2049dc0d8139da6380..24f49dfab090b3b7735671faaa98a4b1b252b3ed 100644 --- a/src/main/scala/leon/plugin/CodeExtraction.scala +++ b/src/main/scala/leon/plugin/CodeExtraction.scala @@ -681,6 +681,23 @@ trait CodeExtraction extends Extractors { } Epsilon(c1).setType(pstpe).setPosInfo(epsi.pos.line, epsi.pos.column) } + + case chs @ ExChooseExpression(args, tpe, body) => { + val cTpe = scalaType2PureScala(unit, silent)(tpe) + + val vars = args map { case (tpe, sym) => + val aTpe = scalaType2PureScala(unit, silent)(tpe) + val newID = FreshIdentifier(sym.name.toString).setType(aTpe) + owners += (Variable(newID) -> None) + varSubsts(sym) = (() => Variable(newID)) + newID + } + + val cBody = rec(body) + + Choose(vars, cBody).setType(cTpe).setPosInfo(chs.pos.line, chs.pos.column) + } + case ExWaypointExpression(tpe, i, tree) => { val pstpe = scalaType2PureScala(unit, silent)(tpe) val IntLiteral(ri) = rec(i) diff --git a/src/main/scala/leon/plugin/Extractors.scala b/src/main/scala/leon/plugin/Extractors.scala index eae46482780a0362946b5ddacce1eae632ffc1e9..d3879226b4ade267f7cbbfe98d5f313bccce63ed 100644 --- a/src/main/scala/leon/plugin/Extractors.scala +++ b/src/main/scala/leon/plugin/Extractors.scala @@ -174,6 +174,21 @@ trait Extractors { case _ => None } } + + object ExChooseExpression { + def unapply(tree: Apply) : Option[(List[(Type, Symbol)], Type, Tree)] = tree match { + case a @ Apply( + TypeApply(Select(Select(funcheckIdent, utilsName), chooseName), types), + Function(vds, predicateBody) :: Nil) => { + if (utilsName.toString == "Utils" && chooseName.toString == "choose") + Some(((types.map(_.tpe) zip vds.map(_.symbol)).toList, a.tpe, predicateBody)) + else + None + } + case _ => None + } + } + object ExWaypointExpression { def unapply(tree: Apply) : Option[(Type, Tree, Tree)] = tree match { case Apply( diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala index b7b5bd343c2adc3bf973440776951a124a0ddd08..72d87cec11eade2afab25312c5438725a461073c 100644 --- a/src/main/scala/leon/purescala/PrettyPrinter.scala +++ b/src/main/scala/leon/purescala/PrettyPrinter.scala @@ -154,6 +154,14 @@ object PrettyPrinter { nsb } + case c@Choose(vars, pred) => { + var nsb = sb + nsb.append("choose("+vars.mkString(", ")+" => ") + nsb = pp(pred, nsb, lvl) + nsb.append(")") + nsb + } + case Waypoint(i, expr) => { sb.append("waypoint_" + i + "(") pp(expr, sb, lvl) diff --git a/src/main/scala/leon/purescala/Trees.scala b/src/main/scala/leon/purescala/Trees.scala index ea5b7ca0f9d6a4dea13d7687b9cb6c6ff953278b..4778c9b2dbc01d0291a408adb27ab520d8b93778 100644 --- a/src/main/scala/leon/purescala/Trees.scala +++ b/src/main/scala/leon/purescala/Trees.scala @@ -612,6 +612,16 @@ object Trees { i } case m @ MatchExpr(scrut,cses) => MatchExpr(rec(scrut), cses.map(inCase(_))).setType(m.getType).setPosInfo(m) + + case c @ Choose(args, body) => + val body2 = rec(body) + + if (body != body2) { + Choose(args, body2).setType(c.getType) + } else { + c + } + case t if t.isInstanceOf[Terminal] => t case unhandled => scala.sys.error("Non-terminal case should be handled in searchAndReplace: " + unhandled) } @@ -724,6 +734,16 @@ object Trees { m }) } + + case c @ Choose(args, body) => + val body2 = rec(body) + + applySubst(if (body != body2) { + Choose(args, body2).setType(c.getType).setPosInfo(c) + } else { + c + }) + case t if t.isInstanceOf[Terminal] => applySubst(t) case unhandled => scala.sys.error("Non-terminal case should be handled in searchAndReplaceDFS: " + unhandled) } @@ -850,6 +870,7 @@ object Trees { case ArrayMake(_) => false case ArrayClone(_) => false case Epsilon(_) => false + case Choose(_, _) => false case _ => true } def combine(b1: Boolean, b2: Boolean) = b1 && b2 @@ -863,6 +884,7 @@ object Trees { case ArrayMake(_) => false case ArrayClone(_) => false case Epsilon(_) => false + case Choose(_, _) => false case _ => b } treeCatamorphism(convert, combine, compute, expr) @@ -880,6 +902,7 @@ object Trees { } treeCatamorphism(convert, combine, compute, expr) } + def containsLetDef(expr: Expr): Boolean = { def convert(t : Expr) : Boolean = t match { case (l : LetDef) => true