diff --git a/src/main/scala/leon/purescala/MethodLifting.scala b/src/main/scala/leon/purescala/MethodLifting.scala
index 5c297bc8b0e962d99c22e0471ea2b68dd2a8f54a..8aaf2eafe8596626eac04bdbe15c6a7a5d56279f 100644
--- a/src/main/scala/leon/purescala/MethodLifting.scala
+++ b/src/main/scala/leon/purescala/MethodLifting.scala
@@ -64,6 +64,7 @@ object MethodLifting extends TransformationPhase {
       } getOrElse {
         (List(), false)
       }
+
     case acd: AbstractClassDef =>
       val (r, c) = acd.knownChildren.map(makeCases(_, fdId, breakDown)).unzip
       val recs = r.flatten
@@ -71,26 +72,78 @@ object MethodLifting extends TransformationPhase {
       if (complete) {
         // Children define all cases completely, we don't need to add anything
         (recs, true)
-      } else if (!acd.methods.exists( m => m.id == fdId && m.body.nonEmpty)) {
+      } else if (!acd.methods.exists(m => m.id == fdId && m.body.nonEmpty)) {
         // We don't have anything to add
         (recs, false)
       } else {
         // We have something to add
-        val m = acd.methods.find( m => m.id == fdId ).get
+        val m = acd.methods.find(m => m.id == fdId).get
         val at = acd.typed
         val binder = FreshIdentifier(acd.id.name.toLowerCase, at, true)
-        def subst(e: Expr): Expr = e match {
-          case This(ct) =>
-            asInstOf(Variable(binder), ct)
-          case e =>
-            e
-        }
-        val newE = simplePreTransform(subst)(breakDown(m.fullBody))
+        val newE = simplePreTransform {
+          case This(ct) => asInstOf(Variable(binder), ct)
+          case e => e
+        } (breakDown(m.fullBody))
+
         val cse = SimpleCase(InstanceOfPattern(Some(binder), at), newE).setPos(newE)
         (recs :+ cse, true)
       }
   }
 
+  def makeInvCases(cd: ClassDef): (Seq[MatchCase], Boolean) = {
+    val ct = cd.typed
+    val binder = FreshIdentifier(cd.id.name.toLowerCase, ct, true)
+    val fd = cd.methods.find(_.isInvariant).get
+
+    cd match {
+      case ccd: CaseClassDef =>
+        val fBinders = (ccd.fieldsIds zip ct.fields).map(p => p._1 -> p._2.id.freshen).toMap
+
+        // Ancestor's method is a method in the case class
+        val subPatts = ccd.fields map (f => WildcardPattern(Some(fBinders(f.id))))
+        val patt = CaseClassPattern(Some(binder), ct.asInstanceOf[CaseClassType], subPatts)
+        val newE = simplePreTransform {
+          case e @ CaseClassSelector(`ct`, This(`ct`), i) =>
+            Variable(fBinders(i)).setPos(e)
+          case e @ This(`ct`) =>
+            Variable(binder).setPos(e)
+          case e =>
+            e
+        } (fd.fullBody)
+
+        if (newE == BooleanLiteral(true)) {
+          (Nil, false)
+        } else {
+          val cse = SimpleCase(patt, newE).setPos(newE)
+          (List(cse), true)
+        }
+
+      case acd: AbstractClassDef =>
+        val (r, c) = acd.knownChildren.map(makeInvCases).unzip
+        val recs = r.flatten
+        val complete = !(c contains false)
+
+        val newE = simplePreTransform {
+          case This(ct) => asInstOf(Variable(binder), ct)
+          case e => e
+        } (fd.fullBody)
+
+        if (newE == BooleanLiteral(true)) {
+          (recs, false)
+        } else {
+          val rhs = if (recs.isEmpty) {
+            newE
+          } else {
+            val allCases = if (complete) recs else {
+              recs :+ SimpleCase(WildcardPattern(None), BooleanLiteral(true))
+            }
+            and(newE, MatchExpr(binder.toVariable, allCases)).setPos(newE)
+          }
+          val cse = SimpleCase(InstanceOfPattern(Some(binder), ct), rhs).setPos(newE)
+          (Seq(cse), true)
+        }
+    }
+  }
 
   def apply(ctx: LeonContext, program: Program): Program = {
 
@@ -156,7 +209,7 @@ object MethodLifting extends TransformationPhase {
             isInstOf(Variable(receiver), cl.typed(ctParams map { _.tp }))
         }
 
-        if (cd.knownDescendants.forall( cd => (cd.methods ++ cd.fields).forall(_.id != fd.id))) {
+        if (cd.knownDescendants.forall(cd => (cd.methods ++ cd.fields).forall(_.id != fd.id))) {
           // Don't need to compose methods
           val paramsMap = fd.params.zip(fdParams).map { case (x,y) => (x.id, y.id) }.toMap
           def thisToReceiver(e: Expr): Option[Expr] = e match {
@@ -170,12 +223,9 @@ object MethodLifting extends TransformationPhase {
           nfd.fullBody = postMap(thisToReceiver)(insTp(nfd.fullBody))
 
           // Add precondition if the method was defined in a subclass
-          val pre = and(
-            classPre(fd),
-            nfd.precOrTrue
-          )
+          val pre = and(classPre(fd), nfd.precOrTrue)
           nfd.fullBody = withPrecondition(nfd.fullBody, Some(pre))
-        
+
         } else {
           // We need to compose methods of subclasses
 
@@ -203,65 +253,72 @@ object MethodLifting extends TransformationPhase {
             paramsMap + (receiver -> receiver)
           )
 
-          /* Separately handle pre, post, body */
-          val (pre, _)  = makeCases(cd, fd.id, preconditionOf(_).getOrElse(BooleanLiteral(true)))
-          val (post, _) = makeCases(cd, fd.id, postconditionOf(_).getOrElse(
-            Lambda(Seq(ValDef(FreshIdentifier("res", retType, true))), BooleanLiteral(true))
-          ))
-          val (body, _) = makeCases(cd, fd.id, withoutSpec(_).getOrElse(NoTree(retType)))
-
-          // Some simplifications
-          val preSimple = {
-            val nonTrivial = pre.count{ _.rhs != BooleanLiteral(true) }
-
-            val compositePre =
-              if (nonTrivial == 0) {
-                BooleanLiteral(true)
-              } else {
-                inst(pre).setPos(fd.getPos)
-              }
-
-            Some(and(classPre(fd), compositePre))
-          }
-
-          val postSimple = {
-            val trivial = post.forall {
-              case SimpleCase(_, Lambda(_, BooleanLiteral(true))) => true
-              case _ => false
+          if (fd.isInvariant) {
+            val (cases, complete) = makeInvCases(cd)
+            val allCases = if (complete) cases else {
+              cases :+ SimpleCase(WildcardPattern(None), BooleanLiteral(true))
             }
-            if (trivial) None
-            else {
-              val resVal = FreshIdentifier("res", retType, true)
-              Some(Lambda(
-                Seq(ValDef(resVal)),
-                inst(post map { cs => cs.copy( rhs =
-                  application(cs.rhs, Seq(Variable(resVal)))
-                )})
-              ).setPos(fd))
+
+            nfd.fullBody = inst(allCases).setPos(fd.getPos)
+          } else {
+            /* Separately handle pre, post, body */
+            val (pre, _)  = makeCases(cd, fd.id, preconditionOf(_).getOrElse(BooleanLiteral(true)))
+            val (post, _) = makeCases(cd, fd.id, postconditionOf(_).getOrElse(
+              Lambda(Seq(ValDef(FreshIdentifier("res", retType, true))), BooleanLiteral(true))
+            ))
+            val (body, _) = makeCases(cd, fd.id, withoutSpec(_).getOrElse(NoTree(retType)))
+
+            // Some simplifications
+            val preSimple = {
+              val nonTrivial = pre.count{ _.rhs != BooleanLiteral(true) }
+
+              val compositePre =
+                if (nonTrivial == 0) {
+                  BooleanLiteral(true)
+                } else {
+                  inst(pre).setPos(fd.getPos)
+                }
+
+              Some(and(classPre(fd), compositePre))
             }
-          }
 
-          val bodySimple = {
-            val trivial = body forall {
-              case SimpleCase(_, NoTree(_)) => true
-              case _ => false
+            val postSimple = {
+              val trivial = post.forall {
+                case SimpleCase(_, Lambda(_, BooleanLiteral(true))) => true
+                case _ => false
+              }
+              if (trivial) None
+              else {
+                val resVal = FreshIdentifier("res", retType, true)
+                Some(Lambda(
+                  Seq(ValDef(resVal)),
+                  inst(post map { cs => cs.copy( rhs =
+                    application(cs.rhs, Seq(Variable(resVal)))
+                  )})
+                ).setPos(fd))
+              }
             }
-            if (trivial) NoTree(retType) else inst(body)
-          }
 
-          /* Construct full body */
-          nfd.fullBody = withPostcondition(
-            withPrecondition(bodySimple, preSimple),
-            postSimple
-          )
-        }
+            val bodySimple = {
+              val trivial = body forall {
+                case SimpleCase(_, NoTree(_)) => true
+                case _ => false
+              }
+              if (trivial) NoTree(retType) else inst(body)
+            }
 
-        if (cd.methods.exists(md => md.id == fd.id && md.isInvariant)) {
-          cd.setInvariant(nfd)
+            /* Construct full body */
+            nfd.fullBody = withPostcondition(
+              withPrecondition(bodySimple, preSimple),
+              postSimple
+            )
+          }
         }
 
         mdToFds += fd -> nfd
         fdsOf += cd.id.name -> (fdsOf.getOrElse(cd.id.name, Set()) + nfd)
+
+        if (fd.isInvariant) cd.setInvariant(nfd)
       }
 
       // 2) Place functions in existing companions:
diff --git a/src/test/resources/regression/verification/purescala/valid/ADTInvariants3.scala b/src/test/resources/regression/verification/purescala/valid/ADTInvariants3.scala
new file mode 100644
index 0000000000000000000000000000000000000000..317f0230c506f03e7b4668d7fba2bab2caa89d5e
--- /dev/null
+++ b/src/test/resources/regression/verification/purescala/valid/ADTInvariants3.scala
@@ -0,0 +1,27 @@
+import leon.lang._
+
+object ADTInvariants3 {
+
+  sealed abstract class A
+  sealed abstract class B extends A {
+    require(this match {
+      case Cons(h, t) => h == size
+      case Nil(_) => true
+    })
+
+    def size: BigInt = (this match {
+      case Cons(h, t) => 1 + t.size
+      case Nil(_) => BigInt(0)
+    }) ensuring ((i: BigInt) => i >= 0)
+  }
+
+  case class Cons(head: BigInt, tail: B) extends B
+  case class Nil(i: BigInt) extends B {
+    require(i >= 0)
+  }
+
+  def sum(a: A): BigInt = (a match {
+    case Cons(head, tail) => head + sum(tail)
+    case Nil(i) => i
+  }) ensuring ((i: BigInt) => i >= 0)
+}