diff --git a/src/test/scala/leon/unit/purescala/DefinitionTransformerSuite.scala b/src/test/scala/leon/unit/purescala/DefinitionTransformerSuite.scala
index 88e9aea53c295baa1190eda284a323c471a89c2b..cf7a1dbed2bc5c12c63144d41231dd7c61c7d04c 100644
--- a/src/test/scala/leon/unit/purescala/DefinitionTransformerSuite.scala
+++ b/src/test/scala/leon/unit/purescala/DefinitionTransformerSuite.scala
@@ -17,10 +17,10 @@ class DefinitionTransformerSuite extends FunSuite with ExpressionsDSL {
   private val fd1 = new FunDef(FreshIdentifier("f1"), Seq(), Seq(ValDef(x.id)), IntegerType)
   fd1.body = Some(x)
 
-  private val fd2 = new FunDef(FreshIdentifier("f1"), Seq(), Seq(ValDef(x.id)), IntegerType)
+  private val fd2 = new FunDef(FreshIdentifier("f2"), Seq(), Seq(ValDef(x.id)), IntegerType)
   fd2.body = Some(Plus(x, bi(1)))
 
-  private val fd3 = new FunDef(FreshIdentifier("f1"), Seq(), Seq(ValDef(x.id)), IntegerType)
+  private val fd3 = new FunDef(FreshIdentifier("f3"), Seq(), Seq(ValDef(x.id)), IntegerType)
   fd3.body = Some(Times(x, bi(1)))
 
   test("transformation with no rewriting should not change FunDef") {
@@ -30,4 +30,33 @@ class DefinitionTransformerSuite extends FunSuite with ExpressionsDSL {
     assert(tr1.transform(fd3) === fd3)
   }
 
+
+  private val classA = new CaseClassDef(FreshIdentifier("A"), Seq(), None, false)
+  classA.setFields(Seq(ValDef(FreshIdentifier("x", IntegerType))))
+  private val classB = new CaseClassDef(FreshIdentifier("B"), Seq(), None, false)
+  classB.setFields(Seq(ValDef(FreshIdentifier("a", classA.typed))))
+
+  test("transformating type of a nested case class change all related case classes") {
+    val tr1 = new DefinitionTransformer {
+      override def transformType(t: TypeTree): Option[TypeTree] = t match {
+        case IntegerType => Some(BooleanType)
+        case _ => None
+      }
+    }
+    val classA2 = tr1.transform(classA)
+    assert(classA.id !== classA2.id)
+    assert(classA.fields.head.id !== classA2.fields.head.id)
+    assert(classA.fields.head.getType === IntegerType)
+    assert(classA2.fields.head.getType === BooleanType)
+    assert(tr1.transform(classA) === classA2)
+
+    val classB2 = tr1.transform(classB)
+    assert(tr1.transform(classA) === classA2)
+    assert(classB.id !== classB2.id)
+    assert(classB.fields.head.id !== classB2.fields.head.id)
+    assert(classB.fields.head.getType === classA.typed)
+    assert(classB2.fields.head.getType === classA2.typed)
+
+  }
+
 }