diff --git a/src/main/scala/inox/ast/TreeOps.scala b/src/main/scala/inox/ast/TreeOps.scala
index 973c25e22a566d0f5625bf2c1b9c352577be820e..2b3d76eb1b6b40f11ecbc8c701c1f442c34afc9f 100644
--- a/src/main/scala/inox/ast/TreeOps.scala
+++ b/src/main/scala/inox/ast/TreeOps.scala
@@ -230,31 +230,34 @@ trait SymbolTransformer {
     else tparams.map(tdef => t.TypeParameterDef(t.TypeParameter(tdef.id)))
   }
 
-  def transform(syms: s.Symbols): t.Symbols = t.NoSymbols.withFunctions {
-    syms.functions.values.toSeq.map(fd => new t.FunDef(
-      fd.id,
-      transformTypeParams(fd.tparams),
-      fd.params.map(vd => transformer.transform(vd)),
-      transformer.transform(fd.returnType),
-      transformer.transform(fd.fullBody),
-      fd.flags.map(f => transformer.transform(f))))
-  }.withADTs {
-    syms.adts.values.toSeq.map {
-      case sort: s.ADTSort if (s eq t) => sort.asInstanceOf[t.ADTSort]
-      case sort: s.ADTSort => new t.ADTSort(
-        sort.id,
-        transformTypeParams(sort.tparams),
-        sort.cons,
-        sort.flags.map(f => transformer.transform(f)))
-      case cons: s.ADTConstructor => new t.ADTConstructor(
-        cons.id,
-        transformTypeParams(cons.tparams),
-        cons.sort,
-        cons.fields.map(vd => transformer.transform(vd)),
-        cons.flags.map(f => transformer.transform(f)))
-    }
+  protected def transformFunction(fd: s.FunDef): t.FunDef = new t.FunDef(
+    fd.id,
+    transformTypeParams(fd.tparams),
+    fd.params.map(vd => transformer.transform(vd)),
+    transformer.transform(fd.returnType),
+    transformer.transform(fd.fullBody),
+    fd.flags.map(f => transformer.transform(f))
+  )
+
+  protected def transformADT(adt: s.ADTDefinition): t.ADTDefinition = adt match {
+    case sort: s.ADTSort if (s eq t) => sort.asInstanceOf[t.ADTSort]
+    case sort: s.ADTSort => new t.ADTSort(
+      sort.id,
+      transformTypeParams(sort.tparams),
+      sort.cons,
+      sort.flags.map(f => transformer.transform(f)))
+    case cons: s.ADTConstructor => new t.ADTConstructor(
+      cons.id,
+      transformTypeParams(cons.tparams),
+      cons.sort,
+      cons.fields.map(vd => transformer.transform(vd)),
+      cons.flags.map(f => transformer.transform(f)))
   }
 
+  def transform(syms: s.Symbols): t.Symbols = t.NoSymbols
+    .withFunctions(syms.functions.values.toSeq.map(transformFunction))
+    .withADTs(syms.adts.values.toSeq.map(transformADT))
+
   def compose(that: SymbolTransformer {
     val transformer: TreeTransformer { val t: SymbolTransformer.this.s.type }
   }): SymbolTransformer {