diff --git a/src/main/scala/leon/plugin/CodeExtraction.scala b/src/main/scala/leon/plugin/CodeExtraction.scala index d410278d92782dc1a93c7e381a64706e81d7d207..5f3025c3bee35eff0922fb4f2e934096c4f4f325 100644 --- a/src/main/scala/leon/plugin/CodeExtraction.scala +++ b/src/main/scala/leon/plugin/CodeExtraction.scala @@ -607,7 +607,23 @@ trait CodeExtraction extends Extractors { val ccRec = rec(cc) val checkType = scalaType2PureScala(unit, silent)(tt.tpe) checkType match { - case CaseClassType(cd) => CaseClassInstanceOf(cd, ccRec) + case CaseClassType(ccd) => { + val rootType: ClassTypeDef = if(ccd.parent != None) ccd.parent.get else ccd + if(!ccRec.getType.isInstanceOf[ClassType]) { + unit.error(tr.pos, "isInstanceOf can only be used with a case class") + throw ImpureCodeEncounteredException(tr) + } else { + val testedExprType = ccRec.getType.asInstanceOf[ClassType].classDef + val testedExprRootType: ClassTypeDef = if(testedExprType.parent != None) testedExprType.parent.get else testedExprType + + if(rootType != testedExprRootType) { + unit.error(tr.pos, "isInstanceOf can only be used with compatible case classes") + throw ImpureCodeEncounteredException(tr) + } else { + CaseClassInstanceOf(ccd, ccRec) + } + } + } case _ => { unit.error(tr.pos, "isInstanceOf can only be used with a case class") throw ImpureCodeEncounteredException(tr)