diff --git a/src/main/scala/leon/purescala/TreeOps.scala b/src/main/scala/leon/purescala/TreeOps.scala index 54bc87f1df0b9ba74292446b32266ab0fc674024..dc2352d4adeda982c03b73592265e51970435f0b 100644 --- a/src/main/scala/leon/purescala/TreeOps.scala +++ b/src/main/scala/leon/purescala/TreeOps.scala @@ -1317,21 +1317,24 @@ object TreeOps { es.map(formulaSize).foldRight(0)(_ + _)+1 } - def collectChooses(e: Expr): List[Choose] = { - def post(e: Expr, cs: List[Choose]) = { - val newCs = e match { - case c: Choose => c :: cs - case _ => cs + def collect[C](f: PartialFunction[Expr, C])(e: Expr): List[C] = { + def post(e: Expr, cs: List[C]) = { + if (f.isDefinedAt(e)) { + (e, f(e) :: cs) + } else { + (e, cs) } - - (e, newCs) } - def combiner(cs: Seq[List[Choose]]) = { - cs.foldLeft(List[Choose]())(_ ::: _) + def combiner(cs: Seq[List[C]]) = { + cs.foldLeft(List[C]())(_ ::: _) } - genericTransform[List[Choose]]((_, _), post, combiner)(List())(e)._2 + genericTransform[List[C]]((_, _), post, combiner)(List())(e)._2 + } + + def collectChooses(e: Expr): List[Choose] = { + collect({ case c: Choose => c })(e) } def valuateWithModel(model: Map[Identifier, Expr])(id: Identifier): Expr = {