diff --git a/src/main/scala/leon/purescala/Extractors.scala b/src/main/scala/leon/purescala/Extractors.scala index c169ba5f913b100c4118baaed4cb6cc4b574345c..5faad8916234c7021a9d8a8c79069bf3882deeca 100644 --- a/src/main/scala/leon/purescala/Extractors.scala +++ b/src/main/scala/leon/purescala/Extractors.scala @@ -30,10 +30,15 @@ object Extractors { case ArrayMake(t) => Some((t, ArrayMake)) case Waypoint(i, t) => Some((t, (expr: Expr) => Waypoint(i, expr))) case e@Epsilon(t) => Some((t, (expr: Expr) => Epsilon(expr).setType(e.getType).setPosInfo(e))) + case ue: UnaryExtractable => ue.extract case _ => None } } + trait UnaryExtractable { + def extract: Option[(Expr, (Expr)=>Expr)]; + } + object BinaryOperator { def unapply(expr: Expr) : Option[(Expr,Expr,(Expr,Expr)=>Expr)] = expr match { case Equals(t1,t2) => Some((t1,t2,Equals.apply)) @@ -69,10 +74,15 @@ object Extractors { case Concat(t1,t2) => Some((t1,t2,Concat)) case ListAt(t1,t2) => Some((t1,t2,ListAt)) case wh@While(t1, t2) => Some((t1,t2, (t1, t2) => While(t1, t2).setInvariant(wh.invariant).setPosInfo(wh))) + case ex: BinaryExtractable => ex.extract case _ => None } } + trait BinaryExtractable { + def extract: Option[(Expr, Expr, (Expr, Expr)=>Expr)]; + } + object NAryOperator { def unapply(expr: Expr) : Option[(Seq[Expr],(Seq[Expr])=>Expr)] = expr match { case fi @ FunctionInvocation(fd, args) => Some((args, (as => FunctionInvocation(fd, as).setPosInfo(fi)))) @@ -89,10 +99,16 @@ object Extractors { case Distinct(args) => Some((args, Distinct)) case Block(args, rest) => Some((args :+ rest, exprs => Block(exprs.init, exprs.last))) case Tuple(args) => Some((args, Tuple)) + case IfExpr(cond, then, elze) => Some((Seq(cond, then, elze), (as: Seq[Expr]) => IfExpr(as(0), as(1), as(2)))) + case ex: NAryExtractable => ex.extract case _ => None } } + trait NAryExtractable { + def extract: Option[(Seq[Expr], (Seq[Expr])=>Expr)]; + } + object SimplePatternMatching { def isSimple(me: MatchExpr) : Boolean = unapply(me).isDefined diff --git a/src/main/scala/leon/purescala/TreeOps.scala b/src/main/scala/leon/purescala/TreeOps.scala index d99462fc669f974a11ee3260232aeaa27d7377e7..2938be18e9524f4a9684ffffdd2e51e978125b51 100644 --- a/src/main/scala/leon/purescala/TreeOps.scala +++ b/src/main/scala/leon/purescala/TreeOps.scala @@ -951,33 +951,56 @@ object TreeOps { fix(searchAndReplaceDFS(transform), expr) } - def genericTransform[C](down: PartialFunction[(Expr, C),(Expr, C)], up: PartialFunction[(Expr, C),(Expr, C)])(init: C)(expr: Expr) = { - val fDown = { x: (Expr, C) => if (down.isDefinedAt(x)) down(x) else x } - val fUp = { x: (Expr, C) => if (up.isDefinedAt(x)) up(x) else x } + def genericTransform[C](pre: (Expr, C) => (Expr, C), + post: (Expr, C) => (Expr, C), + combiner: (Expr, C, Seq[C]) => C)(init: C)(expr: Expr) = { - def rec(in: (Expr, C)): (Expr, C) = { + def rec(eIn: Expr, cIn: C): (Expr, C) = { - val (expr, ctx) = fDown(in) + val (expr, ctx) = pre(eIn, cIn) - val newExpr = expr match { - case UnaryOperator(e, builder) => builder(rec((e, ctx))._1) - case BinaryOperator(e1, e2, builder) => builder(rec((e1, ctx))._1, rec((e2, ctx))._1) - case NAryOperator(es, builder) => builder(es.map(e => rec((e, ctx))._1)) + val (newExpr, newC) = expr match { + case UnaryOperator(e, builder) => + val (e1, c) = rec(e, ctx) + + val newE = builder(e1) + (newE, combiner(newE, ctx, Seq(c))) + case BinaryOperator(e1, e2, builder) => + val (ne1, c1) = rec(e1, ctx) + val (ne2, c2) = rec(e2, ctx) + + val newE = builder(ne1, ne2) + (newE, combiner(newE, ctx, Seq(c1, c2))) + case NAryOperator(es, builder) => + val (nes, cs) = es.map(e => rec(e, ctx)).unzip + + val newE = builder(nes) + (newE, combiner(newE, ctx, cs)) + case e => + sys.error("Expression "+e+" ["+e.getClass+"] has no defined extractor") } - fUp((newExpr, in._2)) + post(newExpr, newC) } - rec((expr, init)) + rec(expr, init) } - def genericDFS[C](up: PartialFunction[(Expr, C), (Expr, C)])(init: C)(e: Expr) = - genericTransform[C](Map.empty, up)(init)(e) - - def genericBFS[C](down: PartialFunction[(Expr, C), (Expr, C)])(init: C)(e: Expr) = - genericTransform[C](down, Map.empty)(init)(e) + def noPre[C] (e: Expr, c: C) = (e, c) + def noPost[C](e: Expr, c: C) = (e, c) + def noCombiner[C](e: Expr, initC: C, subCs: Seq[C]) = initC def patternMatchReconstruction(e: Expr): Expr = { - e + case class Context() + + def pre(e: Expr, c: Context): (Expr, Context) = e match { + case IfExpr(cond, then, elze) => + println("Found IF: "+e) + (e, c) + case _ => + (e, c) + } + + genericTransform[Context](pre, noPost, noCombiner)(Context())(e)._1 } }