diff --git a/src/main/scala/leon/purescala/ExprOps.scala b/src/main/scala/leon/purescala/ExprOps.scala index b5d15e4637ea84ed90b3dbbd644ac1cad761ab0f..369ac9e718c5785b1efbade3e58fffe4e1a36c84 100644 --- a/src/main/scala/leon/purescala/ExprOps.scala +++ b/src/main/scala/leon/purescala/ExprOps.scala @@ -65,7 +65,7 @@ object ExprOps { /** pre-traversal of the tree. * * invokes the input function on every node *before* visiting - * children. + * children. Traverse children from left to right subtrees. * * e.g. * {{{ @@ -73,8 +73,11 @@ object ExprOps { * }}} * will yield, in order: * {{{ - * f(Add(a, Minus(b, c))), f(a), f(Minus(b, c)), f(b), f(c) + * f(Add(a, Minus(b, c))); f(a); f(Minus(b, c)); f(b); f(c) * }}} + * + * @param f a function to apply on each node of the expression + * @param e the expression to traverse */ def preTraversal(f: Expr => Unit)(e: Expr): Unit = { val rec = preTraversal(f) _ @@ -96,6 +99,9 @@ object ExprOps { * {{{ * f(a), f(b), f(c), f(Minus(b, c)), f(Add(a, Minus(b, c))) * }}} + * + * @param f a function to apply on each node of the expression + * @param e the expression to traverse */ def postTraversal(f: Expr => Unit)(e: Expr): Unit = { val rec = postTraversal(f) _ diff --git a/src/unit-test/scala/leon/purescala/ExprOpsSuite.scala b/src/unit-test/scala/leon/purescala/ExprOpsSuite.scala index 40d442ef61c19a3fe85181c15f8e2e829294fdd8..91e14f1d58fd4dd48623ef66e7b286522f79f793 100644 --- a/src/unit-test/scala/leon/purescala/ExprOpsSuite.scala +++ b/src/unit-test/scala/leon/purescala/ExprOpsSuite.scala @@ -48,15 +48,110 @@ class ExprOpsSuite extends LeonTestSuite with WithLikelyEq with ExpressionsBuild assert(foldRight(foldConcatNames)(And(p, Or(q, r))) === (p.id.name + q.id.name + r.id.name)) } + private class LocalCounter { + private var c = 0 + def inc() = c += 1 + def get = c + } + + test("preTraversal works on a single node") { + val c = new LocalCounter + preTraversal(e => c.inc())(x) + assert(c.get === 1) + preTraversal(e => c.inc())(y) + assert(c.get === 2) + + var names: List[String] = List() + preTraversal({ + case Variable(id) => names ::= id.name + case _ => () + })(x) + assert(names === List(x.id.name)) + } + test("preTraversal correctly applies on every nodes on a simple expression") { + val c1 = new LocalCounter + preTraversal(e => c1.inc())(And(Seq(p, q, r))) + assert(c1.get === 4) + val c2 = new LocalCounter + preTraversal(e => c2.inc())(Or(p, q)) + assert(c2.get === 3) + preTraversal(e => c2.inc())(Plus(x, y)) + assert(c2.get === 6) + } - test("Path-aware simplifications") { - // TODO actually testing something here would be better, sorry - // PS + test("preTraversal visits children from left to right") { + var names: List[String] = List() + preTraversal({ + case Variable(id) => names ::= id.name + case _ => () + })(And(List(p, q, r))) + assert(names === List(r.id.name, q.id.name, p.id.name)) + } - assert(true) + test("preTraversal works on nexted expressions") { + val c = new LocalCounter + preTraversal(e => c.inc())(And(p, And(q, r))) + assert(c.get === 5) } + test("preTraversal traverses in pre-order") { + var nodes: List[Expr] = List() + val node = And(List(p, q, r)) + preTraversal(e => nodes ::= e)(node) + assert(nodes === List(r, q, p, node)) + } + + + test("postTraversal works on a single node") { + val c = new LocalCounter + postTraversal(e => c.inc())(x) + assert(c.get === 1) + postTraversal(e => c.inc())(y) + assert(c.get === 2) + + var names: List[String] = List() + postTraversal({ + case Variable(id) => names ::= id.name + case _ => () + })(x) + assert(names === List(x.id.name)) + } + + test("postTraversal correctly applies on every nodes on a simple expression") { + val c1 = new LocalCounter + postTraversal(e => c1.inc())(And(Seq(p, q, r))) + assert(c1.get === 4) + val c2 = new LocalCounter + postTraversal(e => c2.inc())(Or(p, q)) + assert(c2.get === 3) + postTraversal(e => c2.inc())(Plus(x, y)) + assert(c2.get === 6) + } + + test("postTraversal visits children from left to right") { + var names: List[String] = List() + postTraversal({ + case Variable(id) => names ::= id.name + case _ => () + })(And(List(p, q, r))) + assert(names === List(r.id.name, q.id.name, p.id.name)) + } + + test("postTraversal works on nexted expressions") { + val c = new LocalCounter + postTraversal(e => c.inc())(And(p, And(q, r))) + assert(c.get === 5) + } + + test("postTraversal traverses in pre-order") { + var nodes: List[Expr] = List() + val node = And(List(p, q, r)) + postTraversal(e => nodes ::= e)(node) + assert(nodes === List(node, r, q, p)) + } + + /** * If the formula consist of some top level AND, find a top level * Equals and extract it, return the remaining formula as well