From 496c4fa729b1388a070e393d4807e97aeedb7cdb Mon Sep 17 00:00:00 2001 From: Etienne Kneuss <colder@php.net> Date: Thu, 14 Feb 2013 16:07:14 +0100 Subject: [PATCH] Implement support for "case object" --- .../scala/leon/plugin/CodeExtraction.scala | 77 +++++++++++-------- src/main/scala/leon/plugin/Extractors.scala | 10 +++ .../scala/leon/purescala/Definitions.scala | 1 + .../scala/leon/purescala/PrettyPrinter.scala | 6 +- .../scala/leon/purescala/ScalaPrinter.scala | 6 +- .../purescala/valid/CaseObject1.scala | 21 +++++ 6 files changed, 85 insertions(+), 36 deletions(-) create mode 100644 src/test/resources/regression/verification/purescala/valid/CaseObject1.scala diff --git a/src/main/scala/leon/plugin/CodeExtraction.scala b/src/main/scala/leon/plugin/CodeExtraction.scala index ad3dc998a..06a6ac624 100644 --- a/src/main/scala/leon/plugin/CodeExtraction.scala +++ b/src/main/scala/leon/plugin/CodeExtraction.scala @@ -19,6 +19,7 @@ trait CodeExtraction extends Extractors { import global.definitions._ import StructuralExtractors._ import ExpressionExtractors._ + import ExtractorHelpers._ private lazy val setTraitSym = definitions.getClass("scala.collection.immutable.Set") @@ -156,7 +157,9 @@ trait CodeExtraction extends Extractors { if(p._1.isAbstractClass) { classesToClasses += (p._1 -> new AbstractClassDef(p._2)) } else if(p._1.isCase) { - classesToClasses += (p._1 -> new CaseClassDef(p._2)) + val ccd = new CaseClassDef(p._2) + ccd.isCaseObject = p._1.isModuleClass + classesToClasses += (p._1 -> ccd) } }) @@ -367,51 +370,48 @@ trait CodeExtraction extends Extractors { * errors. */ private def scala2PureScala(unit: CompilationUnit, silent: Boolean)(tree: Tree): Expr = { def rewriteCaseDef(cd: CaseDef): MatchCase = { - def pat2pat(p: Tree): Pattern = p match { - case Ident(nme.WILDCARD) => WildcardPattern(None) - case b @ Bind(name, Ident(nme.WILDCARD)) => { - val newID = FreshIdentifier(name.toString).setType(scalaType2PureScala(unit,silent)(b.symbol.tpe)) - varSubsts(b.symbol) = (() => Variable(newID)) - WildcardPattern(Some(newID)) - } - case b @ Bind(name, Typed(Ident(nme.WILDCARD), tpe)) => { + + def pat2pat(p: Tree, binder: Option[Identifier] = None): Pattern = p match { + case b @ Bind(name, Typed(pat, tpe)) => val newID = FreshIdentifier(name.toString).setType(scalaType2PureScala(unit,silent)(tpe.tpe)) varSubsts(b.symbol) = (() => Variable(newID)) - WildcardPattern(Some(newID)) - } - case a @ Apply(fn, args) if fn.isType && a.tpe.typeSymbol.isCase && classesToClasses.keySet.contains(a.tpe.typeSymbol) => { - val cd = classesToClasses(a.tpe.typeSymbol).asInstanceOf[CaseClassDef] - assert(args.size == cd.fields.size) - CaseClassPattern(None, cd, args.map(pat2pat(_))) - } - case b @ Bind(name, a @ Apply(fn, args)) if fn.isType && a.tpe.typeSymbol.isCase && classesToClasses.keySet.contains(a.tpe.typeSymbol) => { + pat2pat(pat, Some(newID)) + + case b @ Bind(name, pat) => val newID = FreshIdentifier(name.toString).setType(scalaType2PureScala(unit,silent)(b.symbol.tpe)) varSubsts(b.symbol) = (() => Variable(newID)) + pat2pat(pat, Some(newID)) + + case Ident(nme.WILDCARD) => + WildcardPattern(binder) + + case s @ Select(This(_), b) if s.tpe.typeSymbol.isCase && + classesToClasses.keySet.contains(s.tpe.typeSymbol) => + // case Obj => + val cd = classesToClasses(s.tpe.typeSymbol).asInstanceOf[CaseClassDef] + assert(cd.fields.size == 0) + CaseClassPattern(binder, cd, Seq()) + + case a @ Apply(fn, args) if fn.isType && + a.tpe.typeSymbol.isCase && + classesToClasses.keySet.contains(a.tpe.typeSymbol) => + val cd = classesToClasses(a.tpe.typeSymbol).asInstanceOf[CaseClassDef] assert(args.size == cd.fields.size) - CaseClassPattern(Some(newID), cd, args.map(pat2pat(_))) - } - case a@Apply(fn, args) => { - val pst = scalaType2PureScala(unit, silent)(a.tpe) - pst match { - case TupleType(argsTpes) => TuplePattern(None, args.map(pat2pat)) - case _ => throw ImpureCodeEncounteredException(p) - } - } - case b @ Bind(name, a @ Apply(fn, args)) => { - val newID = FreshIdentifier(name.toString).setType(scalaType2PureScala(unit,silent)(b.symbol.tpe)) - varSubsts(b.symbol) = (() => Variable(newID)) + CaseClassPattern(binder, cd, args.map(pat2pat(_))) + + case a @ Apply(fn, args) => val pst = scalaType2PureScala(unit, silent)(a.tpe) pst match { - case TupleType(argsTpes) => TuplePattern(Some(newID), args.map(pat2pat)) + case TupleType(argsTpes) => TuplePattern(binder, args.map(pat2pat(_))) case _ => throw ImpureCodeEncounteredException(p) } - } - case _ => { - if(!silent) + + case _ => + if (!silent) { unit.error(p.pos, "Unsupported pattern.") + } throw ImpureCodeEncounteredException(p) - } } if(cd.guard == EmptyTree) { @@ -508,6 +508,14 @@ trait CodeExtraction extends Extractors { } val e2: Option[Expr] = nextExpr match { + case ExCaseObject(sym) => + classesToClasses.get(sym) match { + case Some(ccd: CaseClassDef) => + Some(CaseClass(ccd, Seq())) + case _ => + None + } + case ExParameterlessMethodCall(t,n) => { val selector = rec(t) val selType = selector.getType @@ -726,6 +734,7 @@ trait CodeExtraction extends Extractors { val cct = cctype.asInstanceOf[CaseClassType] CaseClass(cct.classDef, nargs).setType(cct) } + case ExAnd(l, r) => And(rec(l), rec(r)).setType(BooleanType) case ExOr(l, r) => Or(rec(l), rec(r)).setType(BooleanType) case ExNot(e) => Not(rec(e)).setType(BooleanType) diff --git a/src/main/scala/leon/plugin/Extractors.scala b/src/main/scala/leon/plugin/Extractors.scala index b6b15a162..c333df79f 100644 --- a/src/main/scala/leon/plugin/Extractors.scala +++ b/src/main/scala/leon/plugin/Extractors.scala @@ -123,6 +123,16 @@ trait Extractors { } } + object ExCaseObject { + def unapply(s: Select): Option[Symbol] = { + if (s.tpe.typeSymbol.isModuleClass) { + Some(s.tpe.typeSymbol) + } else { + None + } + } + } + object ExCaseClassSyntheticJunk { def unapply(cd: ClassDef): Boolean = cd match { case ClassDef(_, _, _, _) if (cd.symbol.isSynthetic && cd.symbol.isFinal) => true diff --git a/src/main/scala/leon/purescala/Definitions.scala b/src/main/scala/leon/purescala/Definitions.scala index 6e3c0b926..fe9ae283a 100644 --- a/src/main/scala/leon/purescala/Definitions.scala +++ b/src/main/scala/leon/purescala/Definitions.scala @@ -253,6 +253,7 @@ object Definitions { class CaseClassDef(val id: Identifier, prnt: Option[AbstractClassDef] = None) extends ClassTypeDef with ExtractorTypeDef { private var parent_ = prnt var fields: VarDecls = Nil + var isCaseObject = false val isAbstract = false def setParent(newParent: AbstractClassDef) = { diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala index 8a244a5e8..fd8fa2ed0 100644 --- a/src/main/scala/leon/purescala/PrettyPrinter.scala +++ b/src/main/scala/leon/purescala/PrettyPrinter.scala @@ -117,7 +117,11 @@ object PrettyPrinter { case CaseClass(cd, args) => { var nsb = sb nsb.append(cd.id) - nsb = ppNary(nsb, args, "(", ", ", ")", lvl) + if (cd.isCaseObject) { + nsb = ppNary(nsb, args, "", "", "", lvl) + } else { + nsb = ppNary(nsb, args, "(", ", ", ")", lvl) + } nsb } case CaseClassInstanceOf(cd, e) => { diff --git a/src/main/scala/leon/purescala/ScalaPrinter.scala b/src/main/scala/leon/purescala/ScalaPrinter.scala index d9a82aa8c..9e2579f6c 100644 --- a/src/main/scala/leon/purescala/ScalaPrinter.scala +++ b/src/main/scala/leon/purescala/ScalaPrinter.scala @@ -124,7 +124,11 @@ object ScalaPrinter { case CaseClass(cd, args) => { sb.append(cd.id) - ppNary(sb, args, "(", ", ", ")", lvl) + if (cd.isCaseObject) { + ppNary(sb, args, "", "", "", lvl) + } else { + ppNary(sb, args, "(", ", ", ")", lvl) + } } case CaseClassInstanceOf(cd, e) => { pp(e, sb, lvl) diff --git a/src/test/resources/regression/verification/purescala/valid/CaseObject1.scala b/src/test/resources/regression/verification/purescala/valid/CaseObject1.scala new file mode 100644 index 000000000..90dcdf064 --- /dev/null +++ b/src/test/resources/regression/verification/purescala/valid/CaseObject1.scala @@ -0,0 +1,21 @@ +object CaseObject1 { + + abstract sealed class A + case class B(size: Int) extends A + case object C extends A + + def foo(): A = { + C + } + + def foo1(a: A): A = a match { + case C => a + case B(s) => a + } + + def foo2(a: A): A = a match { + case c @ C => c + case B(s) => a + } + +} -- GitLab