diff --git a/src/main/scala/inox/ast/TreeOps.scala b/src/main/scala/inox/ast/TreeOps.scala index 973c25e22a566d0f5625bf2c1b9c352577be820e..2b3d76eb1b6b40f11ecbc8c701c1f442c34afc9f 100644 --- a/src/main/scala/inox/ast/TreeOps.scala +++ b/src/main/scala/inox/ast/TreeOps.scala @@ -230,31 +230,34 @@ trait SymbolTransformer { else tparams.map(tdef => t.TypeParameterDef(t.TypeParameter(tdef.id))) } - def transform(syms: s.Symbols): t.Symbols = t.NoSymbols.withFunctions { - syms.functions.values.toSeq.map(fd => new t.FunDef( - fd.id, - transformTypeParams(fd.tparams), - fd.params.map(vd => transformer.transform(vd)), - transformer.transform(fd.returnType), - transformer.transform(fd.fullBody), - fd.flags.map(f => transformer.transform(f)))) - }.withADTs { - syms.adts.values.toSeq.map { - case sort: s.ADTSort if (s eq t) => sort.asInstanceOf[t.ADTSort] - case sort: s.ADTSort => new t.ADTSort( - sort.id, - transformTypeParams(sort.tparams), - sort.cons, - sort.flags.map(f => transformer.transform(f))) - case cons: s.ADTConstructor => new t.ADTConstructor( - cons.id, - transformTypeParams(cons.tparams), - cons.sort, - cons.fields.map(vd => transformer.transform(vd)), - cons.flags.map(f => transformer.transform(f))) - } + protected def transformFunction(fd: s.FunDef): t.FunDef = new t.FunDef( + fd.id, + transformTypeParams(fd.tparams), + fd.params.map(vd => transformer.transform(vd)), + transformer.transform(fd.returnType), + transformer.transform(fd.fullBody), + fd.flags.map(f => transformer.transform(f)) + ) + + protected def transformADT(adt: s.ADTDefinition): t.ADTDefinition = adt match { + case sort: s.ADTSort if (s eq t) => sort.asInstanceOf[t.ADTSort] + case sort: s.ADTSort => new t.ADTSort( + sort.id, + transformTypeParams(sort.tparams), + sort.cons, + sort.flags.map(f => transformer.transform(f))) + case cons: s.ADTConstructor => new t.ADTConstructor( + cons.id, + transformTypeParams(cons.tparams), + cons.sort, + cons.fields.map(vd => transformer.transform(vd)), + cons.flags.map(f => transformer.transform(f))) } + def transform(syms: s.Symbols): t.Symbols = t.NoSymbols + .withFunctions(syms.functions.values.toSeq.map(transformFunction)) + .withADTs(syms.adts.values.toSeq.map(transformADT)) + def compose(that: SymbolTransformer { val transformer: TreeTransformer { val t: SymbolTransformer.this.s.type } }): SymbolTransformer {