diff --git a/src/main/scala/leon/codegen/CodeGeneration.scala b/src/main/scala/leon/codegen/CodeGeneration.scala index 11e57f36e328139b633a2dd66668c5e6f13099b8..c495ccb5d54e19e65494b36387a467a6f6a9536a 100644 --- a/src/main/scala/leon/codegen/CodeGeneration.scala +++ b/src/main/scala/leon/codegen/CodeGeneration.scala @@ -465,18 +465,21 @@ object CodeGeneration { } def compileCaseClassDef(p : Program, ccd : CaseClassDef)(implicit env : CompilationEnvironment) : ClassFile = { - assert(ccd.hasParent) val cName = defToJVMName(p, ccd) - val pName = defToJVMName(p, ccd.parent.get) + val pName = ccd.parent.map(parent => defToJVMName(p, parent)) - val cf = new ClassFile(cName, Some(pName)) + val cf = new ClassFile(cName, pName) cf.setFlags(( CLASS_ACC_SUPER | CLASS_ACC_PUBLIC | CLASS_ACC_FINAL ).asInstanceOf[U2]) + if(ccd.parent.isEmpty) { + cf.addInterface(CaseClassClass) + } + // definition of the constructor if(ccd.fields.isEmpty) { cf.addDefaultConstructor @@ -493,7 +496,8 @@ object CodeGeneration { val cch = cf.addConstructor(namesTypes.map(_._2).toList).codeHandler - cch << ALoad(0) << InvokeSpecial(pName, constructorName, "()V") + cch << ALoad(0) + cch << InvokeSpecial(pName.getOrElse("java/lang/Object"), constructorName, "()V") var c = 1 for((nme, jvmt) <- namesTypes) { diff --git a/src/main/scala/leon/codegen/CompilationUnit.scala b/src/main/scala/leon/codegen/CompilationUnit.scala index 4cad603f7f6a5b44fd28e3721c89461fa869f6d5..548e2ab000cccc0459245f76a79809669bcc5ffe 100644 --- a/src/main/scala/leon/codegen/CompilationUnit.scala +++ b/src/main/scala/leon/codegen/CompilationUnit.scala @@ -179,6 +179,10 @@ object CompilationUnit { } } + for(single <- p.singleCaseClasses) { + classes += single -> compileCaseClassDef(p, single) + } + val mainClassName = defToJVMName(p, p.mainObject) val cf = new ClassFile(mainClassName, None) diff --git a/src/main/scala/leon/purescala/Definitions.scala b/src/main/scala/leon/purescala/Definitions.scala index 80f1df9843c89bb258483d177bb159238e2fe69b..6e3c0b9267e99af169f1238c062a7669a8c09349 100644 --- a/src/main/scala/leon/purescala/Definitions.scala +++ b/src/main/scala/leon/purescala/Definitions.scala @@ -44,6 +44,7 @@ object Definitions { def definedClasses = mainObject.definedClasses def classHierarchyRoots = mainObject.classHierarchyRoots def algebraicDataTypes = mainObject.algebraicDataTypes + def singleCaseClasses = mainObject.singleCaseClasses def callGraph = mainObject.callGraph def calls(f1: FunDef, f2: FunDef) = mainObject.calls(f1, f2) def callers(f1: FunDef) = mainObject.callers(f1) @@ -100,6 +101,10 @@ object Definitions { case c @ CaseClassDef(_, Some(_), _) => c }).groupBy(_.parent.get) + lazy val singleCaseClasses : Seq[CaseClassDef] = defs.collect { + case c @ CaseClassDef(_, None, _) => c + } + lazy val (callGraph, callers, callees) = { type CallGraph = Set[(FunDef,FunDef)] diff --git a/src/main/scala/leon/purescala/TypeTrees.scala b/src/main/scala/leon/purescala/TypeTrees.scala index d8c673fcc627b6c57d158161c5f9d706724f0788..31906fb14b9316cf406167af096d64962c60ef62 100644 --- a/src/main/scala/leon/purescala/TypeTrees.scala +++ b/src/main/scala/leon/purescala/TypeTrees.scala @@ -60,7 +60,7 @@ object TypeTrees { def bestRealType(t: TypeTree) : TypeTree = t match { case c: ClassType if c.classDef.isInstanceOf[CaseClassDef] => { c.classDef.parent match { - case None => scala.sys.error("Asking for real type of a case class without abstract parent") + case None => CaseClassType(c.classDef.asInstanceOf[CaseClassDef]) case Some(p) => AbstractClassType(p) } } diff --git a/src/test/resources/regression/verification/purescala/valid/BestRealTypes.scala b/src/test/resources/regression/verification/purescala/valid/BestRealTypes.scala new file mode 100644 index 0000000000000000000000000000000000000000..e78cb0b0d1bce0cb54a13a839f9e4032a3eef31c --- /dev/null +++ b/src/test/resources/regression/verification/purescala/valid/BestRealTypes.scala @@ -0,0 +1,25 @@ +import leon.Utils._ + +/** This benchmarks tests some potential issues with the legacy "bestRealType" function, which was original introduced to work around + * Scala's well-too-precise-for-Leon type inference. */ +object BestRealTypes { + sealed abstract class Num + case class Zero() extends Num + case class Succ(pred : Num) extends Num + + case class Wrapper(num : Num) + + def boolToNum(b : Boolean) : Num = if(b) { + Zero() + } else { + Succ(Zero()) + } + + // This requires computing the "bestRealTypes" of w1 and w2. + def zipWrap(w1 : Wrapper, w2 : Wrapper) : (Wrapper,Wrapper) = (w1, w2) + + def somethingToProve(b : Boolean) : Boolean = { + val (z1,z2) = zipWrap(Wrapper(boolToNum(b)), Wrapper(boolToNum(!b))) + z1.num == Zero() || z2.num == Zero() + } holds +} diff --git a/src/test/scala/leon/test/evaluators/EvaluatorsTests.scala b/src/test/scala/leon/test/evaluators/EvaluatorsTests.scala index d832950cc131c84d874c299aa00e0df94719aefb..36edcf9f0bc4b221e6c696adb46b5d47ba2fed63 100644 --- a/src/test/scala/leon/test/evaluators/EvaluatorsTests.scala +++ b/src/test/scala/leon/test/evaluators/EvaluatorsTests.scala @@ -257,6 +257,8 @@ class EvaluatorsTests extends FunSuite { | case class Nil() extends List | case class Cons(head : Int, tail : List) extends List | + | case class MySingleton(i : Int) + | | def size(l : List) : Int = l match { | case Nil() => 0 | case Cons(_, xs) => 1 + size(xs) @@ -267,6 +269,8 @@ class EvaluatorsTests extends FunSuite { | def head(l : List) : Int = l match { | case Cons(h, _) => h | } + | + | def wrap(i : Int) : MySingleton = MySingleton(i) |}""".stripMargin implicit val prog = parseString(p) @@ -276,6 +280,8 @@ class EvaluatorsTests extends FunSuite { val nil = mkCaseClass("Nil") val cons12a = mkCaseClass("Cons", IL(1), mkCaseClass("Cons", IL(2), mkCaseClass("Nil"))) val cons12b = mkCaseClass("Cons", IL(1), mkCaseClass("Cons", IL(2), mkCaseClass("Nil"))) + val sing1 = mkCaseClass("MySingleton", IL(1)) + val sing2 = mkCaseClass("MySingleton", IL(2)) for(e <- evaluators) { checkComp(e, mkCall("size", nil), IL(0)) @@ -284,6 +290,8 @@ class EvaluatorsTests extends FunSuite { checkComp(e, mkCall("compare", cons12a, cons12b), T) checkComp(e, mkCall("head", cons12a), IL(1)) + checkComp(e, Equals(mkCall("wrap", IL(1)), sing1), T) + // Match error checkError(e, mkCall("head", nil)) }