diff --git a/src/main/scala/leon/purescala/MethodLifting.scala b/src/main/scala/leon/purescala/MethodLifting.scala index 5c297bc8b0e962d99c22e0471ea2b68dd2a8f54a..8aaf2eafe8596626eac04bdbe15c6a7a5d56279f 100644 --- a/src/main/scala/leon/purescala/MethodLifting.scala +++ b/src/main/scala/leon/purescala/MethodLifting.scala @@ -64,6 +64,7 @@ object MethodLifting extends TransformationPhase { } getOrElse { (List(), false) } + case acd: AbstractClassDef => val (r, c) = acd.knownChildren.map(makeCases(_, fdId, breakDown)).unzip val recs = r.flatten @@ -71,26 +72,78 @@ object MethodLifting extends TransformationPhase { if (complete) { // Children define all cases completely, we don't need to add anything (recs, true) - } else if (!acd.methods.exists( m => m.id == fdId && m.body.nonEmpty)) { + } else if (!acd.methods.exists(m => m.id == fdId && m.body.nonEmpty)) { // We don't have anything to add (recs, false) } else { // We have something to add - val m = acd.methods.find( m => m.id == fdId ).get + val m = acd.methods.find(m => m.id == fdId).get val at = acd.typed val binder = FreshIdentifier(acd.id.name.toLowerCase, at, true) - def subst(e: Expr): Expr = e match { - case This(ct) => - asInstOf(Variable(binder), ct) - case e => - e - } - val newE = simplePreTransform(subst)(breakDown(m.fullBody)) + val newE = simplePreTransform { + case This(ct) => asInstOf(Variable(binder), ct) + case e => e + } (breakDown(m.fullBody)) + val cse = SimpleCase(InstanceOfPattern(Some(binder), at), newE).setPos(newE) (recs :+ cse, true) } } + def makeInvCases(cd: ClassDef): (Seq[MatchCase], Boolean) = { + val ct = cd.typed + val binder = FreshIdentifier(cd.id.name.toLowerCase, ct, true) + val fd = cd.methods.find(_.isInvariant).get + + cd match { + case ccd: CaseClassDef => + val fBinders = (ccd.fieldsIds zip ct.fields).map(p => p._1 -> p._2.id.freshen).toMap + + // Ancestor's method is a method in the case class + val subPatts = ccd.fields map (f => WildcardPattern(Some(fBinders(f.id)))) + val patt = CaseClassPattern(Some(binder), ct.asInstanceOf[CaseClassType], subPatts) + val newE = simplePreTransform { + case e @ CaseClassSelector(`ct`, This(`ct`), i) => + Variable(fBinders(i)).setPos(e) + case e @ This(`ct`) => + Variable(binder).setPos(e) + case e => + e + } (fd.fullBody) + + if (newE == BooleanLiteral(true)) { + (Nil, false) + } else { + val cse = SimpleCase(patt, newE).setPos(newE) + (List(cse), true) + } + + case acd: AbstractClassDef => + val (r, c) = acd.knownChildren.map(makeInvCases).unzip + val recs = r.flatten + val complete = !(c contains false) + + val newE = simplePreTransform { + case This(ct) => asInstOf(Variable(binder), ct) + case e => e + } (fd.fullBody) + + if (newE == BooleanLiteral(true)) { + (recs, false) + } else { + val rhs = if (recs.isEmpty) { + newE + } else { + val allCases = if (complete) recs else { + recs :+ SimpleCase(WildcardPattern(None), BooleanLiteral(true)) + } + and(newE, MatchExpr(binder.toVariable, allCases)).setPos(newE) + } + val cse = SimpleCase(InstanceOfPattern(Some(binder), ct), rhs).setPos(newE) + (Seq(cse), true) + } + } + } def apply(ctx: LeonContext, program: Program): Program = { @@ -156,7 +209,7 @@ object MethodLifting extends TransformationPhase { isInstOf(Variable(receiver), cl.typed(ctParams map { _.tp })) } - if (cd.knownDescendants.forall( cd => (cd.methods ++ cd.fields).forall(_.id != fd.id))) { + if (cd.knownDescendants.forall(cd => (cd.methods ++ cd.fields).forall(_.id != fd.id))) { // Don't need to compose methods val paramsMap = fd.params.zip(fdParams).map { case (x,y) => (x.id, y.id) }.toMap def thisToReceiver(e: Expr): Option[Expr] = e match { @@ -170,12 +223,9 @@ object MethodLifting extends TransformationPhase { nfd.fullBody = postMap(thisToReceiver)(insTp(nfd.fullBody)) // Add precondition if the method was defined in a subclass - val pre = and( - classPre(fd), - nfd.precOrTrue - ) + val pre = and(classPre(fd), nfd.precOrTrue) nfd.fullBody = withPrecondition(nfd.fullBody, Some(pre)) - + } else { // We need to compose methods of subclasses @@ -203,65 +253,72 @@ object MethodLifting extends TransformationPhase { paramsMap + (receiver -> receiver) ) - /* Separately handle pre, post, body */ - val (pre, _) = makeCases(cd, fd.id, preconditionOf(_).getOrElse(BooleanLiteral(true))) - val (post, _) = makeCases(cd, fd.id, postconditionOf(_).getOrElse( - Lambda(Seq(ValDef(FreshIdentifier("res", retType, true))), BooleanLiteral(true)) - )) - val (body, _) = makeCases(cd, fd.id, withoutSpec(_).getOrElse(NoTree(retType))) - - // Some simplifications - val preSimple = { - val nonTrivial = pre.count{ _.rhs != BooleanLiteral(true) } - - val compositePre = - if (nonTrivial == 0) { - BooleanLiteral(true) - } else { - inst(pre).setPos(fd.getPos) - } - - Some(and(classPre(fd), compositePre)) - } - - val postSimple = { - val trivial = post.forall { - case SimpleCase(_, Lambda(_, BooleanLiteral(true))) => true - case _ => false + if (fd.isInvariant) { + val (cases, complete) = makeInvCases(cd) + val allCases = if (complete) cases else { + cases :+ SimpleCase(WildcardPattern(None), BooleanLiteral(true)) } - if (trivial) None - else { - val resVal = FreshIdentifier("res", retType, true) - Some(Lambda( - Seq(ValDef(resVal)), - inst(post map { cs => cs.copy( rhs = - application(cs.rhs, Seq(Variable(resVal))) - )}) - ).setPos(fd)) + + nfd.fullBody = inst(allCases).setPos(fd.getPos) + } else { + /* Separately handle pre, post, body */ + val (pre, _) = makeCases(cd, fd.id, preconditionOf(_).getOrElse(BooleanLiteral(true))) + val (post, _) = makeCases(cd, fd.id, postconditionOf(_).getOrElse( + Lambda(Seq(ValDef(FreshIdentifier("res", retType, true))), BooleanLiteral(true)) + )) + val (body, _) = makeCases(cd, fd.id, withoutSpec(_).getOrElse(NoTree(retType))) + + // Some simplifications + val preSimple = { + val nonTrivial = pre.count{ _.rhs != BooleanLiteral(true) } + + val compositePre = + if (nonTrivial == 0) { + BooleanLiteral(true) + } else { + inst(pre).setPos(fd.getPos) + } + + Some(and(classPre(fd), compositePre)) } - } - val bodySimple = { - val trivial = body forall { - case SimpleCase(_, NoTree(_)) => true - case _ => false + val postSimple = { + val trivial = post.forall { + case SimpleCase(_, Lambda(_, BooleanLiteral(true))) => true + case _ => false + } + if (trivial) None + else { + val resVal = FreshIdentifier("res", retType, true) + Some(Lambda( + Seq(ValDef(resVal)), + inst(post map { cs => cs.copy( rhs = + application(cs.rhs, Seq(Variable(resVal))) + )}) + ).setPos(fd)) + } } - if (trivial) NoTree(retType) else inst(body) - } - /* Construct full body */ - nfd.fullBody = withPostcondition( - withPrecondition(bodySimple, preSimple), - postSimple - ) - } + val bodySimple = { + val trivial = body forall { + case SimpleCase(_, NoTree(_)) => true + case _ => false + } + if (trivial) NoTree(retType) else inst(body) + } - if (cd.methods.exists(md => md.id == fd.id && md.isInvariant)) { - cd.setInvariant(nfd) + /* Construct full body */ + nfd.fullBody = withPostcondition( + withPrecondition(bodySimple, preSimple), + postSimple + ) + } } mdToFds += fd -> nfd fdsOf += cd.id.name -> (fdsOf.getOrElse(cd.id.name, Set()) + nfd) + + if (fd.isInvariant) cd.setInvariant(nfd) } // 2) Place functions in existing companions: diff --git a/src/test/resources/regression/verification/purescala/valid/ADTInvariants3.scala b/src/test/resources/regression/verification/purescala/valid/ADTInvariants3.scala new file mode 100644 index 0000000000000000000000000000000000000000..317f0230c506f03e7b4668d7fba2bab2caa89d5e --- /dev/null +++ b/src/test/resources/regression/verification/purescala/valid/ADTInvariants3.scala @@ -0,0 +1,27 @@ +import leon.lang._ + +object ADTInvariants3 { + + sealed abstract class A + sealed abstract class B extends A { + require(this match { + case Cons(h, t) => h == size + case Nil(_) => true + }) + + def size: BigInt = (this match { + case Cons(h, t) => 1 + t.size + case Nil(_) => BigInt(0) + }) ensuring ((i: BigInt) => i >= 0) + } + + case class Cons(head: BigInt, tail: B) extends B + case class Nil(i: BigInt) extends B { + require(i >= 0) + } + + def sum(a: A): BigInt = (a match { + case Cons(head, tail) => head + sum(tail) + case Nil(i) => i + }) ensuring ((i: BigInt) => i >= 0) +}