Skip to content
Snippets Groups Projects
Commit f3aadd28 authored by Nicolas Voirol's avatar Nicolas Voirol
Browse files

ADT invariants from abstract classes (no override for invariants)

parent aec3b935
No related branches found
No related tags found
No related merge requests found
......@@ -64,6 +64,7 @@ object MethodLifting extends TransformationPhase {
} getOrElse {
(List(), false)
}
case acd: AbstractClassDef =>
val (r, c) = acd.knownChildren.map(makeCases(_, fdId, breakDown)).unzip
val recs = r.flatten
......@@ -71,26 +72,78 @@ object MethodLifting extends TransformationPhase {
if (complete) {
// Children define all cases completely, we don't need to add anything
(recs, true)
} else if (!acd.methods.exists( m => m.id == fdId && m.body.nonEmpty)) {
} else if (!acd.methods.exists(m => m.id == fdId && m.body.nonEmpty)) {
// We don't have anything to add
(recs, false)
} else {
// We have something to add
val m = acd.methods.find( m => m.id == fdId ).get
val m = acd.methods.find(m => m.id == fdId).get
val at = acd.typed
val binder = FreshIdentifier(acd.id.name.toLowerCase, at, true)
def subst(e: Expr): Expr = e match {
case This(ct) =>
asInstOf(Variable(binder), ct)
case e =>
e
}
val newE = simplePreTransform(subst)(breakDown(m.fullBody))
val newE = simplePreTransform {
case This(ct) => asInstOf(Variable(binder), ct)
case e => e
} (breakDown(m.fullBody))
val cse = SimpleCase(InstanceOfPattern(Some(binder), at), newE).setPos(newE)
(recs :+ cse, true)
}
}
def makeInvCases(cd: ClassDef): (Seq[MatchCase], Boolean) = {
val ct = cd.typed
val binder = FreshIdentifier(cd.id.name.toLowerCase, ct, true)
val fd = cd.methods.find(_.isInvariant).get
cd match {
case ccd: CaseClassDef =>
val fBinders = (ccd.fieldsIds zip ct.fields).map(p => p._1 -> p._2.id.freshen).toMap
// Ancestor's method is a method in the case class
val subPatts = ccd.fields map (f => WildcardPattern(Some(fBinders(f.id))))
val patt = CaseClassPattern(Some(binder), ct.asInstanceOf[CaseClassType], subPatts)
val newE = simplePreTransform {
case e @ CaseClassSelector(`ct`, This(`ct`), i) =>
Variable(fBinders(i)).setPos(e)
case e @ This(`ct`) =>
Variable(binder).setPos(e)
case e =>
e
} (fd.fullBody)
if (newE == BooleanLiteral(true)) {
(Nil, false)
} else {
val cse = SimpleCase(patt, newE).setPos(newE)
(List(cse), true)
}
case acd: AbstractClassDef =>
val (r, c) = acd.knownChildren.map(makeInvCases).unzip
val recs = r.flatten
val complete = !(c contains false)
val newE = simplePreTransform {
case This(ct) => asInstOf(Variable(binder), ct)
case e => e
} (fd.fullBody)
if (newE == BooleanLiteral(true)) {
(recs, false)
} else {
val rhs = if (recs.isEmpty) {
newE
} else {
val allCases = if (complete) recs else {
recs :+ SimpleCase(WildcardPattern(None), BooleanLiteral(true))
}
and(newE, MatchExpr(binder.toVariable, allCases)).setPos(newE)
}
val cse = SimpleCase(InstanceOfPattern(Some(binder), ct), rhs).setPos(newE)
(Seq(cse), true)
}
}
}
def apply(ctx: LeonContext, program: Program): Program = {
......@@ -156,7 +209,7 @@ object MethodLifting extends TransformationPhase {
isInstOf(Variable(receiver), cl.typed(ctParams map { _.tp }))
}
if (cd.knownDescendants.forall( cd => (cd.methods ++ cd.fields).forall(_.id != fd.id))) {
if (cd.knownDescendants.forall(cd => (cd.methods ++ cd.fields).forall(_.id != fd.id))) {
// Don't need to compose methods
val paramsMap = fd.params.zip(fdParams).map { case (x,y) => (x.id, y.id) }.toMap
def thisToReceiver(e: Expr): Option[Expr] = e match {
......@@ -170,12 +223,9 @@ object MethodLifting extends TransformationPhase {
nfd.fullBody = postMap(thisToReceiver)(insTp(nfd.fullBody))
// Add precondition if the method was defined in a subclass
val pre = and(
classPre(fd),
nfd.precOrTrue
)
val pre = and(classPre(fd), nfd.precOrTrue)
nfd.fullBody = withPrecondition(nfd.fullBody, Some(pre))
} else {
// We need to compose methods of subclasses
......@@ -203,65 +253,72 @@ object MethodLifting extends TransformationPhase {
paramsMap + (receiver -> receiver)
)
/* Separately handle pre, post, body */
val (pre, _) = makeCases(cd, fd.id, preconditionOf(_).getOrElse(BooleanLiteral(true)))
val (post, _) = makeCases(cd, fd.id, postconditionOf(_).getOrElse(
Lambda(Seq(ValDef(FreshIdentifier("res", retType, true))), BooleanLiteral(true))
))
val (body, _) = makeCases(cd, fd.id, withoutSpec(_).getOrElse(NoTree(retType)))
// Some simplifications
val preSimple = {
val nonTrivial = pre.count{ _.rhs != BooleanLiteral(true) }
val compositePre =
if (nonTrivial == 0) {
BooleanLiteral(true)
} else {
inst(pre).setPos(fd.getPos)
}
Some(and(classPre(fd), compositePre))
}
val postSimple = {
val trivial = post.forall {
case SimpleCase(_, Lambda(_, BooleanLiteral(true))) => true
case _ => false
if (fd.isInvariant) {
val (cases, complete) = makeInvCases(cd)
val allCases = if (complete) cases else {
cases :+ SimpleCase(WildcardPattern(None), BooleanLiteral(true))
}
if (trivial) None
else {
val resVal = FreshIdentifier("res", retType, true)
Some(Lambda(
Seq(ValDef(resVal)),
inst(post map { cs => cs.copy( rhs =
application(cs.rhs, Seq(Variable(resVal)))
)})
).setPos(fd))
nfd.fullBody = inst(allCases).setPos(fd.getPos)
} else {
/* Separately handle pre, post, body */
val (pre, _) = makeCases(cd, fd.id, preconditionOf(_).getOrElse(BooleanLiteral(true)))
val (post, _) = makeCases(cd, fd.id, postconditionOf(_).getOrElse(
Lambda(Seq(ValDef(FreshIdentifier("res", retType, true))), BooleanLiteral(true))
))
val (body, _) = makeCases(cd, fd.id, withoutSpec(_).getOrElse(NoTree(retType)))
// Some simplifications
val preSimple = {
val nonTrivial = pre.count{ _.rhs != BooleanLiteral(true) }
val compositePre =
if (nonTrivial == 0) {
BooleanLiteral(true)
} else {
inst(pre).setPos(fd.getPos)
}
Some(and(classPre(fd), compositePre))
}
}
val bodySimple = {
val trivial = body forall {
case SimpleCase(_, NoTree(_)) => true
case _ => false
val postSimple = {
val trivial = post.forall {
case SimpleCase(_, Lambda(_, BooleanLiteral(true))) => true
case _ => false
}
if (trivial) None
else {
val resVal = FreshIdentifier("res", retType, true)
Some(Lambda(
Seq(ValDef(resVal)),
inst(post map { cs => cs.copy( rhs =
application(cs.rhs, Seq(Variable(resVal)))
)})
).setPos(fd))
}
}
if (trivial) NoTree(retType) else inst(body)
}
/* Construct full body */
nfd.fullBody = withPostcondition(
withPrecondition(bodySimple, preSimple),
postSimple
)
}
val bodySimple = {
val trivial = body forall {
case SimpleCase(_, NoTree(_)) => true
case _ => false
}
if (trivial) NoTree(retType) else inst(body)
}
if (cd.methods.exists(md => md.id == fd.id && md.isInvariant)) {
cd.setInvariant(nfd)
/* Construct full body */
nfd.fullBody = withPostcondition(
withPrecondition(bodySimple, preSimple),
postSimple
)
}
}
mdToFds += fd -> nfd
fdsOf += cd.id.name -> (fdsOf.getOrElse(cd.id.name, Set()) + nfd)
if (fd.isInvariant) cd.setInvariant(nfd)
}
// 2) Place functions in existing companions:
......
import leon.lang._
object ADTInvariants3 {
sealed abstract class A
sealed abstract class B extends A {
require(this match {
case Cons(h, t) => h == size
case Nil(_) => true
})
def size: BigInt = (this match {
case Cons(h, t) => 1 + t.size
case Nil(_) => BigInt(0)
}) ensuring ((i: BigInt) => i >= 0)
}
case class Cons(head: BigInt, tail: B) extends B
case class Nil(i: BigInt) extends B {
require(i >= 0)
}
def sum(a: A): BigInt = (a match {
case Cons(head, tail) => head + sum(tail)
case Nil(i) => i
}) ensuring ((i: BigInt) => i >= 0)
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment