diff --git a/src/main/scala/leon/xlang/ArrayTransformation.scala b/src/main/scala/leon/xlang/ArrayTransformation.scala index 32d64e95f0a6e219e24b9b130becabd9e114a2b6..5c4b314ea7426211edb23de946372a07c959c682 100644 --- a/src/main/scala/leon/xlang/ArrayTransformation.scala +++ b/src/main/scala/leon/xlang/ArrayTransformation.scala @@ -46,36 +46,6 @@ object ArrayTransformation extends TransformationPhase { Variable(env.getOrElse(i, i)) } - case LetVar(id, e, b) => { - val er = transform(e) - val br = transform(b) - LetVar(id, er, br) - } - case wh@While(c, e) => { - val newWh = While(transform(c), transform(e)) - newWh.invariant = wh.invariant.map(i => transform(i)) - newWh.setPos(wh) - newWh - } - - case ite@IfExpr(c, t, e) => { - val rc = transform(c) - val rt = transform(t) - val re = transform(e) - IfExpr(rc, rt, re) - } - - case m @ MatchExpr(scrut, cses) => { - val scrutRec = transform(scrut) - val csesRec = cses.map{ cse => MatchCase(cse.pattern, cse.optGuard map transform, transform(cse.rhs)) } - val tpe = csesRec.head.rhs.getType - matchExpr(scrutRec, csesRec).setPos(m) - } - case LetDef(fd, b) => { - fd.fullBody = transform(fd.fullBody) - val rb = transform(b) - LetDef(fd, rb) - } case n @ NAryOperator(args, recons) => recons(args.map(transform)) case b @ BinaryOperator(a1, a2, recons) => recons(transform(a1), transform(a2)) case u @ UnaryOperator(a, recons) => recons(transform(a)) diff --git a/src/main/scala/leon/xlang/NoXLangFeaturesChecking.scala b/src/main/scala/leon/xlang/NoXLangFeaturesChecking.scala index 08425aada0b910f6e26ab11134884bf35a08455e..05b9c255387365172e85804ab7ea9c8ee528f55d 100644 --- a/src/main/scala/leon/xlang/NoXLangFeaturesChecking.scala +++ b/src/main/scala/leon/xlang/NoXLangFeaturesChecking.scala @@ -9,39 +9,17 @@ import purescala.Trees._ import purescala.Definitions._ import purescala.Constructors._ -import purescala.TreeOps.exists import xlang.Trees._ +import xlang.TreeOps.isXLang object NoXLangFeaturesChecking extends UnitPhase[Program] { val name = "no-xlang" val description = "Ensure and enforce that no xlang features are used" - private def isXlangExpr(expr: Expr): Boolean = expr match { - - case Block(_, _) => true - - case Assignment(_, _) => true - - case While(_, _) => true - - case Epsilon(_, _) => true - - case EpsilonVariable(_, _) => true - - case LetVar(_, _, _) => true - - case Waypoint(_, _, _) => true - - case ArrayUpdate(_, _, _) => true - - case _ => false - - } - override def apply(ctx: LeonContext, pgm: Program): Unit = { pgm.definedFunctions.foreach(fd => { - if(exists(isXlangExpr)(fd.fullBody)) { + if(isXLang(fd.fullBody)) { ctx.reporter.fatalError(fd.fullBody.getPos, "Expr is using xlang features") } }) diff --git a/src/main/scala/leon/xlang/TreeOps.scala b/src/main/scala/leon/xlang/TreeOps.scala index d8d283b88d41b3e5bc9348eb3c4ae5c7284e5e0b..ce2944a8f0a575e1a285e39409e0834ac7bab0b0 100644 --- a/src/main/scala/leon/xlang/TreeOps.scala +++ b/src/main/scala/leon/xlang/TreeOps.scala @@ -12,26 +12,21 @@ import purescala.TreeOps._ import purescala.Extractors._ object TreeOps { + + def isXLang(expr: Expr): Boolean = exists { + case Block(_, _) | Assignment(_, _) | + While(_, _) | Epsilon(_, _) | + EpsilonVariable(_, _) | + LetVar(_, _, _) | Waypoint(_, _, _) | + ArrayUpdate(_, _, _) + => true + case _ => false + }(expr) - //checking whether the expr is not pure, that is it contains any non-pure construct: - // assign, while, blocks, array, ... - def isXLang(expr: Expr): Boolean = { - exists { _ match { - case Block(_, _) => true - case Assignment(_, _) => true - case While(_, _) => true - case LetVar(_, _, _) => true - case LetDef(_, _) => true - case ArrayUpdate(_, _, _) => true - case Epsilon(_, _) => true - case _ => false - }}(expr) - } - - def containsEpsilon(e: Expr) = exists{ _ match { - case (l: Epsilon) => true - case _ => false - }}(e) + def containsEpsilon(e: Expr) = exists{ + case (l: Epsilon) => true + case _ => false + }(e) def flattenBlocks(expr: Expr): Expr = { postMap({ @@ -51,3 +46,4 @@ object TreeOps { })(expr) } } + diff --git a/src/main/scala/leon/xlang/Trees.scala b/src/main/scala/leon/xlang/Trees.scala index 50e87706f39ab134c5a385f19211e01199bc03cd..2740b3d5c8e4de56583f75c2d9cfda7e540e62bb 100644 --- a/src/main/scala/leon/xlang/Trees.scala +++ b/src/main/scala/leon/xlang/Trees.scala @@ -45,7 +45,7 @@ object Trees { } } - case class While(cond: Expr, body: Expr) extends Expr with BinaryExtractable with PrettyPrintable { + case class While(cond: Expr, body: Expr) extends Expr with NAryExtractable with PrettyPrintable { val getType = UnitType var invariant: Option[Expr] = None @@ -53,8 +53,11 @@ object Trees { def setInvariant(inv: Expr) = { invariant = Some(inv); this } def setInvariant(inv: Option[Expr]) = { invariant = inv; this } - def extract: Option[(Expr, Expr, (Expr, Expr)=>Expr)] = { - Some((cond, body, (t1, t2) => While(t1, t2).setInvariant(this.invariant).setPos(this))) + def extract: Option[(Seq[Expr], Seq[Expr]=>Expr)] = { + Some((Seq(cond, body) ++ invariant, { + case Seq(e1, e2) => While(e1, e2).setPos(this) + case Seq(e1, e2, e3) => While(e1, e2).setInvariant(e3).setPos(this) + })) } def printWith(implicit pctx: PrinterContext) { @@ -97,8 +100,7 @@ object Trees { val getType = body.getType def extract: Option[(Expr, Expr, (Expr, Expr)=>Expr)] = { - val LetVar(binders, expr, body) = this - Some((expr, body, (e: Expr, b: Expr) => LetVar(binders, e, b))) + Some((value, body, (e: Expr, b: Expr) => LetVar(binder, e, b))) } def printWith(implicit pctx: PrinterContext) {