From 5d8d8ea1f0800c351c87117353a495208f91bc39 Mon Sep 17 00:00:00 2001 From: Manos Koukoutos <emmanouil.koukoutos@epfl.ch> Date: Thu, 27 Aug 2015 12:03:58 +0200 Subject: [PATCH] Handle case-class field implementing method --- .../leon/frontends/scalac/ASTExtractors.scala | 19 +++++- .../frontends/scalac/CodeExtraction.scala | 54 +++++++---------- .../scala/leon/purescala/MethodLifting.scala | 60 ++++++++++++------- .../purescala/invalid/FieldInheritance.scala | 22 +++++++ .../purescala/valid/FieldInheritance.scala | 25 ++++++++ 5 files changed, 127 insertions(+), 53 deletions(-) create mode 100644 src/test/resources/regression/verification/purescala/invalid/FieldInheritance.scala create mode 100644 src/test/resources/regression/verification/purescala/valid/FieldInheritance.scala diff --git a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala b/src/main/scala/leon/frontends/scalac/ASTExtractors.scala index d8910012f..066955731 100644 --- a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala +++ b/src/main/scala/leon/frontends/scalac/ASTExtractors.scala @@ -350,7 +350,24 @@ trait ASTExtractors { case _ => false }.get.asInstanceOf[DefDef] - val args = constructor.vparamss.flatten.map(vd => ( vd.symbol, vd)) + val valDefs = constructor.vparamss.flatten + + //impl.children foreach println + + val symbols = impl.children.collect { + case df: DefDef if df.symbol.isStable && df.symbol.isAccessor && + df.symbol.isParamAccessor => + df.symbol + } + + //if (symbols.size != valDefs.size) { + // println(" >>>>> " + cd.name) + // symbols foreach println + // valDefs foreach println + //} + + val args = symbols zip valDefs + Some((name.toString, cd.symbol, args, impl)) } case _ => None diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala index 62627f8c8..9f9db555e 100644 --- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala +++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala @@ -518,10 +518,10 @@ trait CodeExtraction extends ASTExtractors { classesToClasses += sym -> ccd parent.foreach(_.classDef.registerChild(ccd)) - val fields = args.map { case (symbol, t) => - val tpt = t.tpt - val tpe = leonType(tpt.tpe)(defCtx, sym.pos) - LeonValDef(FreshIdentifier(symbol.name.toString, tpe).setPos(t.pos)).setPos(t.pos) + val fields = args.map { case (fsym, t) => + val tpe = leonType(t.tpt.tpe)(defCtx, fsym.pos) + val id = overridenOrFresh(fsym, Some(ccd), tpe) + LeonValDef(id.setPos(t.pos), Some(tpe)).setPos(t.pos) } ccd.setFields(fields) @@ -607,6 +607,20 @@ trait CodeExtraction extends ASTExtractors { cd } + // Returns the parent's method Identifier if sym overrides a symbol, otherwise a fresh Identifier + private def overridenOrFresh(sym: Symbol, within: Option[LeonClassDef], tpe: LeonType = Untyped) = { + val name = sym.name.toString + if (sym.overrideChain.length > 1) { + (for { + cd <- within + p <- cd.parent + m <- p.classDef.methods.find(_.id.name == name) + } yield m.id).getOrElse(FreshIdentifier(name, tpe)) + } else { + FreshIdentifier(name, tpe) + } + } + private var defsToDefs = Map[Symbol, FunDef]() private def defineFunDef(sym: Symbol, within: Option[LeonClassDef] = None)(implicit dctx: DefContext): FunDef = { @@ -626,19 +640,7 @@ trait CodeExtraction extends ASTExtractors { val returnType = leonType(sym.info.finalResultType)(nctx, sym.pos) - val name = sym.name.toString - - val id = { - if (sym.overrideChain.length > 1) { - (for { - cd <- within - p <- cd.parent - m <- p.classDef.methods.find(_.id.name == name) - } yield m.id).getOrElse(FreshIdentifier(name)) - } else { - FreshIdentifier(name) - } - } + val id = overridenOrFresh(sym, within) val fd = new FunDef(id.setPos(sym.pos), tparamsDef, returnType, newParams) @@ -661,19 +663,7 @@ trait CodeExtraction extends ASTExtractors { val returnType = leonType(sym.info.finalResultType)(nctx, sym.pos) - val name = sym.name.toString - - val id = - if (sym.overrideChain.length == 1) { - FreshIdentifier(name) - } else { - ( for { - cd <- within - p <- cd.parent - m <- p.classDef.methods.find(_.id.name == name) - } yield m.id).getOrElse(FreshIdentifier(name)) - } - + val id = overridenOrFresh(sym, within) val fd = new FunDef(id.setPos(sym.pos), Seq(), returnType, Seq()) fd.setPos(sym.pos) @@ -861,7 +851,7 @@ trait CodeExtraction extends ASTExtractors { // case Obj => extractType(s) match { case ct: CaseClassType => - assert(ct.classDef.fields.size == 0) + assert(ct.classDef.fields.isEmpty) (CaseClassPattern(binder, ct, Seq()).setPos(p.pos), dctx) case _ => outOfSubsetError(s, "Invalid type "+s.tpe+" for .isInstanceOf") @@ -1494,7 +1484,7 @@ trait CodeExtraction extends ASTExtractors { //println(s"symbol $sym with id ${sym.id}") //println(s"isMethod($sym) == ${isMethod(sym)}") - + (rrec, sym.name.decoded, rargs) match { case (null, _, args) => val fd = getFunDef(sym, c.pos) diff --git a/src/main/scala/leon/purescala/MethodLifting.scala b/src/main/scala/leon/purescala/MethodLifting.scala index 5f41bc811..38319eb22 100644 --- a/src/main/scala/leon/purescala/MethodLifting.scala +++ b/src/main/scala/leon/purescala/MethodLifting.scala @@ -24,25 +24,45 @@ object MethodLifting extends TransformationPhase { // A Seq of MatchCases is returned, along with a boolean that signifies if the matching is complete. private def makeCases(cd: ClassDef, fdId: Identifier, breakDown: Expr => Expr): (Seq[MatchCase], Boolean) = cd match { case ccd: CaseClassDef => - ccd.methods.find( _.id == fdId) match { - case None => - (List(), false) - case Some(m) => - val ct = ccd.typed - val binder = FreshIdentifier(ccd.id.name.toLowerCase, ct, true) - val fBinders = ct.fields.map{ f => f.id -> f.id.freshen }.toMap - def subst(e: Expr): Expr = e match { - case CaseClassSelector(`ct`, This(`ct`), i) => - Variable(fBinders(i)).setPos(e) - case This(`ct`) => - Variable(binder).setPos(e) - case e => - e - } - val newE = simplePreTransform(subst)(breakDown(m.fullBody)) - val subPatts = ct.fields map (f => WildcardPattern(Some(fBinders(f.id)))) - val cse = SimpleCase(CaseClassPattern(Some(binder), ct, subPatts), newE).setPos(newE) - (List(cse), true) + + // Common for both cases + val ct = ccd.typed + val binder = FreshIdentifier(ccd.id.name.toLowerCase, ct, true) + val fBinders = ct.fields.map{ f => f.id -> f.id.freshen }.toMap + def subst(e: Expr): Expr = e match { + case CaseClassSelector(`ct`, This(`ct`), i) => + Variable(fBinders(i)).setPos(e) + case This(`ct`) => + Variable(binder).setPos(e) + case e => + e + } + + ccd.methods.find( _.id == fdId).map { m => + + // Ancestor's method is a method in the case class + val subPatts = ct.fields map (f => WildcardPattern(Some(fBinders(f.id)))) + val patt = CaseClassPattern(Some(binder), ct, subPatts) + val newE = simplePreTransform(subst)(breakDown(m.fullBody)) + val cse = SimpleCase(patt, newE).setPos(newE) + (List(cse), true) + + } orElse ccd.fields.find( _.id == fdId).map { f => + + // Ancestor's method is a case class argument in the case class + val subPatts = ct.fields map (fld => + if (fld.id == f.id) + WildcardPattern(Some(fBinders(f.id))) + else + WildcardPattern(None) + ) + val patt = CaseClassPattern(Some(binder), ct, subPatts) + val newE = breakDown(Variable(fBinders(f.id))) + val cse = SimpleCase(patt, newE).setPos(newE) + (List(cse), true) + + } getOrElse { + (List(), false) } case acd: AbstractClassDef => val (r, c) = acd.knownChildren.map(makeCases(_, fdId, breakDown)).unzip @@ -139,7 +159,7 @@ object MethodLifting extends TransformationPhase { nfd.setPos(fd) nfd.addFlag(IsMethod(cd)) - if (cd.knownDescendants.forall( _.methods.forall(_.id != fd.id))) { + if (cd.knownDescendants.forall( cd => (cd.methods ++ cd.fields).forall(_.id != fd.id))) { val paramsMap = fd.params.zip(fdParams).map{case (x,y) => (x.id, y.id)}.toMap // Don't need to compose methods nfd.fullBody = postMap { diff --git a/src/test/resources/regression/verification/purescala/invalid/FieldInheritance.scala b/src/test/resources/regression/verification/purescala/invalid/FieldInheritance.scala new file mode 100644 index 000000000..7ef9d9522 --- /dev/null +++ b/src/test/resources/regression/verification/purescala/invalid/FieldInheritance.scala @@ -0,0 +1,22 @@ +import leon.lang._ + +object FieldInheritance { + + abstract class Foo[B] { + val thisIsIt: BigInt = 1 + val y: BigInt + val weird: B + } + + case class Bar[X](override val thisIsIt: BigInt, weird: X) extends Foo[X] { + val y = thisIsIt + } + + case class Baz[X](weird: X) extends Foo[X] { + val y = thisIsIt + 1 + } + + + def foo[A](f: Foo[A]) = (f.thisIsIt == 1).holds + +} diff --git a/src/test/resources/regression/verification/purescala/valid/FieldInheritance.scala b/src/test/resources/regression/verification/purescala/valid/FieldInheritance.scala new file mode 100644 index 000000000..9b99fc84c --- /dev/null +++ b/src/test/resources/regression/verification/purescala/valid/FieldInheritance.scala @@ -0,0 +1,25 @@ +import leon.lang._ + +object FieldInheritance { + + abstract class Foo[B] { + val thisIsIt: BigInt = 1 + val y: BigInt + val weird: B + } + + case class Bar[X](override val thisIsIt: BigInt, weird: X) extends Foo[X] { + val y = thisIsIt + } + + case class Baz[X](weird: X) extends Foo[X] { + val y = thisIsIt + 1 + } + + + def foo[A](f: Foo[A]) = { f match { + case Bar(t, _) => f.thisIsIt == t + case _ => true + }}.holds + +} -- GitLab