// SubTreeOps.scala (9.62 KiB) -- authored by Regis Blanc
/* Copyright 2009-2015 EPFL, Lausanne */
package leon
package purescala
import Common._
import utils._
/** Companion of the `SubTreeOps` trait; holds the extractor interface that the trait relies on. */
object SubTreeOps {
/** Extractor that splits a tree node into its direct sub-trees and a rebuilding function.
  *
  * `unapply` yields, wrapped in an `Option`, the sequence of sub-trees of `e`
  * together with a function that builds a `SubTree` out of a sequence of sub-trees.
  *
  * The traversals of the `SubTreeOps` trait bind this result with an irrefutable
  * pattern (`val Deconstructor(es, builder) = e`), so a `None` result makes them
  * fail with a `MatchError`.
  *
  * NOTE(review): presumably the builder expects as many sub-trees as were
  * extracted, in the same order -- confirm against the implementations.
  */
trait Extractor[SubTree <: Tree] {
def unapply(e: SubTree): Option[(Seq[SubTree], (Seq[SubTree]) => SubTree)]
}
}
trait SubTreeOps[SubTree <: Tree] {
val Deconstructor: SubTreeOps.Extractor[SubTree]
/* ========
* Core API
* ========
*
* All these functions should be stable, tested, and used everywhere. Modify
* with care.
*/
/** Folds a tree bottom-up into a single value.
  *
  * The sub-trees of `e` are folded recursively (in left-to-right order when
  * they are consumed), and `f` combines the current node with the results
  * obtained for its sub-trees.
  *
  * @param f a function that takes the current node and the sequence of
  *          results from its sub-trees
  * @param e the tree on which to apply the fold
  * @return the value computed by `f` for the root `e`
  * @note the results of the sub-trees are handed to `f` as a lazy view: a
  *       sub-tree is only folded when `f` actually reads its result, so do
  *       not rely on side-effects of `f`
  */
def fold[T](f: (SubTree, Seq[T]) => T)(e: SubTree): T = {
  val Deconstructor(children, _) = e
  // Going through a view keeps the recursion on demand, which lets
  // contains-like operations stop early without visiting every sub-tree.
  val childResults = children.view.map(child => fold(f)(child))
  f(e, childResults)
}
/** Pre-order traversal of the tree.
  *
  * Calls `f` on a node '''before''' descending into its sub-trees, which
  * are then visited from left to right.
  *
  * e.g.
  * {{{
  *   Add(a, Minus(b, c))
  * }}}
  * results in the following calls, in order:
  * {{{
  *   f(Add(a, Minus(b, c))); f(a); f(Minus(b, c)); f(b); f(c)
  * }}}
  *
  * @param f the function to apply to each node
  * @param e the tree to traverse
  */
def preTraversal(f: SubTree => Unit)(e: SubTree): Unit = {
  // The node is deconstructed before `f` runs, so `f` cannot influence
  // which sub-trees get visited.
  val Deconstructor(children, _) = e
  f(e)
  for (child <- children) preTraversal(f)(child)
}
/** Post-order traversal of the tree.
  *
  * Calls `f` on a node '''after''' all of its sub-trees have been visited
  * (from left to right).
  *
  * e.g.
  * {{{
  *   Add(a, Minus(b, c))
  * }}}
  * results in the following calls, in order:
  * {{{
  *   f(a), f(b), f(c), f(Minus(b, c)), f(Add(a, Minus(b, c)))
  * }}}
  *
  * @param f the function to apply to each node
  * @param e the tree to traverse
  */
def postTraversal(f: SubTree => Unit)(e: SubTree): Unit = {
  val Deconstructor(children, _) = e
  for (child <- children) postTraversal(f)(child)
  f(e)
}
/** Top-down transformation of the tree.
  *
  * `f` proposes a replacement for a node (`None` meaning "keep the node").
  * The replacement happens '''before''' recursing, and the recursion then
  * continues in the sub-trees of the replaced node.
  *
  * Two modes are supported:
  *
  * - If `applyRec` is false (default), `f` is applied only once per node.
  *
  *   e.g.
  *   {{{
  *     Add(a, Minus(b, c)) with replacements: Minus(b,c) -> d, b -> e, d -> f
  *   }}}
  *   yields:
  *   {{{
  *     Add(a, d) // and not Add(a, f): `d` is not submitted to the replacements again
  *   }}}
  *
  * - If `applyRec` is true, `f` is re-applied to its own result through `fixpoint`:
  *
  *   e.g.
  *   {{{
  *     Add(a, Minus(b, c)) with replacements: Minus(b,c) -> d, b -> e, d -> f
  *   }}}
  *   yields:
  *   {{{
  *     Add(a, f)
  *   }}}
  *
  * A node whose sub-trees are all returned unchanged (reference equality) is
  * not rebuilt; the same instance is returned.
  *
  * @param f        the replacement function
  * @param applyRec whether `f` is re-applied to its own results
  * @param e        the tree to transform
  * @note The mode with `applyRec` true can diverge if `f` is not well formed
  */
def preMap(f: SubTree => Option[SubTree], applyRec : Boolean = false)(e: SubTree): SubTree = {
  // One substitution step: the replacement if `f` provides one, else the node itself.
  val step: SubTree => SubTree = t => f(t).getOrElse(t)

  // In recursive mode, apply f as long as it returns Some().
  val replaced = if (applyRec) fixpoint(step)(e) else step(e)

  val Deconstructor(children, rebuild) = replaced
  val mappedChildren = children.map(child => preMap(f, applyRec)(child))

  // Keep the existing instance unless at least one sub-tree was replaced.
  val untouched = (children zip mappedChildren).forall { case (before, after) => before eq after }
  if (untouched) replaced else rebuild(mappedChildren).copiedFrom(replaced)
}
/** Bottom-up transformation of the tree.
  *
  * `f` proposes a replacement for a node (`None` meaning "keep the node").
  * The sub-trees are transformed first; the node is rebuilt from them if any
  * of them changed, and `f` is applied '''after''' that, on the rebuilt node.
  *
  * Two modes are supported:
  *
  * - If `applyRec` is false (default), `f` is applied only once per node.
  *   e.g.
  *   {{{
  *     Add(a, Minus(b, c)) with replacements: Minus(e,c) -> d, b -> e, d -> f
  *   }}}
  *   yields:
  *   {{{
  *     Add(a, d) // b becomes e, then the rebuilt Minus(e, c) becomes d
  *   }}}
  *
  * - If `applyRec` is true, `f` is re-applied to its own result through `fixpoint`:
  *   e.g.
  *   {{{
  *     Add(a, Minus(b, c)) with replacements: Minus(e,c) -> d, b -> e, d -> f
  *   }}}
  *   yields:
  *   {{{
  *     Add(a, f)
  *   }}}
  *
  * @param f        the replacement function
  * @param applyRec whether `f` is re-applied to its own results
  * @param e        the tree to transform
  * @note The mode with `applyRec` true can diverge if `f` is not well formed (i.e. not convergent)
  */
def postMap(f: SubTree => Option[SubTree], applyRec : Boolean = false)(e: SubTree): SubTree = {
  val Deconstructor(children, rebuild) = e
  val mappedChildren = children.map(child => postMap(f, applyRec)(child))

  // Only rebuild the node when at least one sub-tree is a different instance.
  val anyChildChanged = (children zip mappedChildren).exists { case (before, after) => before ne after }
  val rebuilt = if (anyChildChanged) rebuild(mappedChildren).copiedFrom(e) else e

  // One substitution step: the replacement if `f` provides one, else the node itself.
  val step: SubTree => SubTree = t => f(t).getOrElse(t)

  // In recursive mode, apply f as long as it returns Some().
  if (applyRec) fixpoint(step)(rebuilt) else step(rebuilt)
}
/** Applies functions and combines results in a generic way
  *
  * Start with an initial value, and apply functions to nodes before
  * and after the recursion in the children. Combine the results of
  * all children and apply a final function on the resulting node.
  *
  * For each node, in order:
  *  1. `pre` is applied to the node and the incoming context;
  *  1. every sub-tree of the node returned by `pre` is transformed
  *     recursively, each one receiving the context returned by `pre`;
  *  1. the node is rebuilt from the transformed sub-trees (it is always
  *     rebuilt, even when no sub-tree changed);
  *  1. `combiner` merges the contexts returned by the sub-trees;
  *  1. `post` is applied to the rebuilt node and the combined context.
  *
  * @param pre a function applied on a node before doing a recursion in the children
  * @param post a function applied to the node built from the recursive application to
  *             all children
  * @param combiner a function to combine the resulting values from all children with
  *                 the current node
  * @param init the initial value
  * @param expr the expression on which to apply the transform
  * @return the transformed tree together with the final context value
  *
  * @see [[simpleTransform]]
  * @see [[simplePreTransform]]
  * @see [[simplePostTransform]]
  */
def genericTransform[C](pre:  (SubTree, C) => (SubTree, C),
                        post: (SubTree, C) => (SubTree, C),
                        combiner: (SubTree, Seq[C]) => C)(init: C)(expr: SubTree): (SubTree, C) = {

  def rec(eIn: SubTree, cIn: C): (SubTree, C) = {
    // Named `preExpr` so that it does not shadow the `expr` parameter.
    val (preExpr, ctx) = pre(eIn, cIn)
    val Deconstructor(es, builder) = preExpr

    // Every child starts from the same context `ctx`; siblings do not see each other's results.
    val (newEs, childCtxs) = es.map(rec(_, ctx)).unzip
    val rebuilt = builder(newEs).copiedFrom(preExpr)

    post(rebuilt, combiner(rebuilt, childCtxs))
  }

  rec(expr, init)
}
/** Top-down transformation of the tree, threading a context value downwards.
  *
  * `f` receives a node and the current context, and returns an optional
  * replacement for the node as well as the context to use when recursing
  * into the sub-trees. The substitution happens '''before''' recursing.
  * The context is "lost" when going back up, so changes made by one node
  * will not be seen by its siblings.
  *
  * If `applyRec` is true, `f` is re-applied to its own result through
  * `fixpoint`, each application receiving the context returned by the
  * previous one; the sub-trees get the context of the last application.
  *
  * A node whose sub-trees are all returned unchanged (reference equality) is
  * not rebuilt.
  *
  * @param f        the replacement function, also producing the context for the sub-trees
  * @param applyRec whether `f` is re-applied to its own results
  * @param e        the tree to transform
  * @param c        the initial context
  */
def preMapWithContext[C](f: (SubTree, C) => (Option[SubTree], C), applyRec: Boolean = false)
                        (e: SubTree, c: C): SubTree = {

  def rec(expr: SubTree, context: C): SubTree = {
    val (replaced, childCtx) =
      if (applyRec) {
        // The context is updated as a side effect of every step, so that it
        // can be read back once the fixpoint computation has finished.
        var latestCtx = context
        val step: SubTree => SubTree = { t =>
          val (replacement, nextCtx) = f(t, latestCtx)
          latestCtx = nextCtx
          replacement.getOrElse(t)
        }
        val settled = fixpoint(step)(expr)
        (settled, latestCtx)
      } else {
        val (replacement, nextCtx) = f(expr, context)
        (replacement.getOrElse(expr), nextCtx)
      }

    val Deconstructor(children, rebuild) = replaced
    val mappedChildren = children.map(rec(_, childCtx))

    // Only rebuild the node when at least one sub-tree is a different instance.
    val anyChildChanged = (children zip mappedChildren).exists { case (before, after) => before ne after }
    if (anyChildChanged) rebuild(mappedChildren).copiedFrom(replaced) else replaced
  }

  rec(e, c)
}
/** Folds the tree into a context value, passing a context from the top down.
  *
  * For each node, `f` first computes the node's context from the node and
  * the context inherited from its parent. Every sub-tree is then folded with
  * that context, and `combiner` merges the node, its context and the values
  * obtained for its sub-trees.
  *
  * @param f        computes the context of a node from the inherited context
  * @param combiner combines a node, its context and the results of its sub-trees
  * @param e        the tree to fold
  * @param c        the initial context
  * @return the value produced by `combiner` for the root `e`
  */
def preFoldWithContext[C](f: (SubTree, C) => C, combiner: (SubTree, C, Seq[C]) => C)
                         (e: SubTree, c: C): C = {

  def visit(node: SubTree, inherited: C): C = {
    val nodeCtx = f(node, inherited)
    val Deconstructor(children, _) = node
    val childValues = children.map(child => visit(child, nodeCtx))
    combiner(node, nodeCtx, childValues)
  }

  visit(e, c)
}
/*
* =============
* Auxiliary API
* =============
*
* Convenient methods using the Core API.
*/
/** Tells whether `matcher` holds for `e` or for at least one of its sub-trees.
  *
  * The node itself is tested first; sub-trees are only explored when needed.
  */
def exists(matcher: SubTree => Boolean)(e: SubTree): Boolean = {
  fold[Boolean] { (node, subResults) =>
    matcher(node) || subResults.exists(found => found)
  }(e)
}
/** Gathers, in a single set, the objects that `matcher` returns for `e` and all of its sub-trees. */
def collect[T](matcher: SubTree => Set[T])(e: SubTree): Set[T] = {
  fold[Set[T]] { (node, subSets) =>
    // Start from the objects of the node itself, then add those of each sub-tree.
    subSets.foldLeft(matcher(node))(_ ++ _)
  }(e)
}
/** Gathers the objects that `matcher` returns for `e` and all of its sub-trees, as a sequence.
  *
  * The objects of a node come before those of its sub-trees, and sub-trees
  * contribute from left to right (pre-order).
  */
def collectPreorder[T](matcher: SubTree => Seq[T])(e: SubTree): Seq[T] = {
  fold[Seq[T]] { (node, subSeqs) =>
    matcher(node) ++ subSeqs.flatMap(items => items)
  }(e)
}
/** Returns the set of all nodes, among `e` and its sub-trees, for which `matcher` holds. */
def filter(matcher: SubTree => Boolean)(e: SubTree): Set[SubTree] = {
  collect[SubTree] { node =>
    if (matcher(node)) Set(node) else Set.empty[SubTree]
  }(e)
}
/** Sums the values that `matcher` returns for `e` and all of its sub-trees.
  *
  * `matcher` gives the contribution of a single node (e.g. 1 when the node
  * should be counted and 0 otherwise).
  */
def count(matcher: SubTree => Int)(e: SubTree): Int = {
  fold[Int] { (node, subCounts) =>
    subCounts.foldLeft(matcher(node))(_ + _)
  }(e)
}
/** Replaces sub-trees bottom-up, by looking them up in a map.
  *
  * Each node (after its own sub-trees have been processed) that is a key of
  * `substs` is replaced by the associated value; a replacement is not looked
  * up again.
  */
def replace(substs: Map[SubTree,SubTree], expr: SubTree) : SubTree = {
  val lookup: SubTree => Option[SubTree] = substs.lift
  postMap(lookup)(expr)
}
/** Replaces sub-trees bottom-up, applying the substitutions one after the other.
  *
  * Each pair of `substs` triggers a full bottom-up replacement, in the
  * provided order, on the result of the previous one.
  */
def replaceSeq(substs: Seq[(SubTree, SubTree)], expr: SubTree): SubTree = {
  substs.foldLeft(expr) { (current, subst) =>
    replace(Map(subst), current)
  }
}
}