diff --git a/src/test/scala/leon/unit/purescala/DefinitionTransformerSuite.scala b/src/test/scala/leon/unit/purescala/DefinitionTransformerSuite.scala index 88e9aea53c295baa1190eda284a323c471a89c2b..cf7a1dbed2bc5c12c63144d41231dd7c61c7d04c 100644 --- a/src/test/scala/leon/unit/purescala/DefinitionTransformerSuite.scala +++ b/src/test/scala/leon/unit/purescala/DefinitionTransformerSuite.scala @@ -17,10 +17,10 @@ class DefinitionTransformerSuite extends FunSuite with ExpressionsDSL { private val fd1 = new FunDef(FreshIdentifier("f1"), Seq(), Seq(ValDef(x.id)), IntegerType) fd1.body = Some(x) - private val fd2 = new FunDef(FreshIdentifier("f1"), Seq(), Seq(ValDef(x.id)), IntegerType) + private val fd2 = new FunDef(FreshIdentifier("f2"), Seq(), Seq(ValDef(x.id)), IntegerType) fd2.body = Some(Plus(x, bi(1))) - private val fd3 = new FunDef(FreshIdentifier("f1"), Seq(), Seq(ValDef(x.id)), IntegerType) + private val fd3 = new FunDef(FreshIdentifier("f3"), Seq(), Seq(ValDef(x.id)), IntegerType) fd3.body = Some(Times(x, bi(1))) test("transformation with no rewriting should not change FunDef") { @@ -30,4 +30,33 @@ class DefinitionTransformerSuite extends FunSuite with ExpressionsDSL { assert(tr1.transform(fd3) === fd3) } + + private val classA = new CaseClassDef(FreshIdentifier("A"), Seq(), None, false) + classA.setFields(Seq(ValDef(FreshIdentifier("x", IntegerType)))) + private val classB = new CaseClassDef(FreshIdentifier("B"), Seq(), None, false) + classB.setFields(Seq(ValDef(FreshIdentifier("a", classA.typed)))) + + test("transformating type of a nested case class change all related case classes") { + val tr1 = new DefinitionTransformer { + override def transformType(t: TypeTree): Option[TypeTree] = t match { + case IntegerType => Some(BooleanType) + case _ => None + } + } + val classA2 = tr1.transform(classA) + assert(classA.id !== classA2.id) + assert(classA.fields.head.id !== classA2.fields.head.id) + assert(classA.fields.head.getType === IntegerType) + assert(classA2.fields.head.getType === BooleanType) + assert(tr1.transform(classA) === classA2) + + val classB2 = tr1.transform(classB) + assert(tr1.transform(classA) === classA2) + assert(classB.id !== classB2.id) + assert(classB.fields.head.id !== classB2.fields.head.id) + assert(classB.fields.head.getType === classA.typed) + assert(classB2.fields.head.getType === classA2.typed) + + } + }