diff --git a/src/main/scala/leon/purescala/MethodLifting.scala b/src/main/scala/leon/purescala/MethodLifting.scala
index 4848bf10effc93b0b8db481a26d2c53db297e5f0..f94272e9b0923aa493681ed320f07b8a1a2005ee 100644
--- a/src/main/scala/leon/purescala/MethodLifting.scala
+++ b/src/main/scala/leon/purescala/MethodLifting.scala
@@ -95,49 +95,43 @@ object MethodLifting extends TransformationPhase {
 
     // First we create the appropriate functions from methods:
     var mdToFds = Map[FunDef, FunDef]()
+    var mdToCls = Map[FunDef, ClassDef]()
+
+    // Lift methods to the root class
+    for {
+      u <- program.units
+      ch <- u.classHierarchies
+      c  <- ch
+      if c.parent.isDefined
+      fd <- c.methods
+      if c.ancestors.forall(!_.methods.map{_.id}.contains(fd.id))
+    } {
+      val root = c.ancestors.last
+      val tMap = c.tparams.zip(root.tparams.map{_.tp}).toMap
+      val tSubst: TypeTree => TypeTree  = instantiateType(_, tMap)
+
+      val fdParams = fd.params map { vd =>
+        val newId = FreshIdentifier(vd.id.name, tSubst(vd.id.getType))
+        ValDef(newId).setPos(vd.getPos)
+      }
+      val paramsMap = fd.params.zip(fdParams).map{ case (from, to) => from.id -> to.id }.toMap
+      val eSubst: Expr => Expr = instantiateType(_, tMap, paramsMap)
 
-    val newUnits = for (u <- program.units) yield {
-      var fdsOf = Map[String, Set[FunDef]]()
+      val newFd = new FunDef(fd.id, fd.tparams, tSubst(fd.returnType), fdParams).copiedFrom(fd)
+      newFd.copyContentFrom(fd)
 
-      // Lift methods to the root class
-      for {
-        ch <- u.classHierarchies
-        c  <- ch
-        if c.parent.isDefined
-        fd <- c.methods
-        if c.ancestors.forall(!_.methods.map{_.id}.contains(fd.id))
-      } {
-        val root = c.ancestors.last
-        val tMap = c.tparams.zip(root.tparams.map{_.tp}).toMap
-        val tSubst: TypeTree => TypeTree  = instantiateType(_, tMap)
+      mdToCls += newFd -> c
 
-        val fdParams = fd.params map { vd =>
-          val newId = FreshIdentifier(vd.id.name, tSubst(vd.id.getType))
-          ValDef(newId).setPos(vd.getPos)
-        }
-        val paramsMap = fd.params.zip(fdParams).map{ case (from, to) => from.id -> to.id }.toMap
-        val eSubst: Expr => Expr = instantiateType(_, tMap, paramsMap)
-
-        val newFd = new FunDef(fd.id, fd.tparams, tSubst(fd.returnType), fdParams).copiedFrom(fd)
-        newFd.copyContentFrom(fd)
-        val prec = fd.precondition.getOrElse(BooleanLiteral(true))
-        newFd.fullBody = eSubst(withPrecondition(
-          newFd.fullBody,
-          Some(and(
-            prec,
-            isInstOf(
-              This(root.typed),
-              c.typed(root.tparams.map{ _.tp })
-            )
-          ))
-        ))
+      newFd.fullBody = eSubst(newFd.fullBody)
 
-        c.unregisterMethod(fd.id)
-        root.registerMethod(newFd)
-      }
+      c.unregisterMethod(fd.id)
+      root.registerMethod(newFd)
+    }
 
+    val newUnits = for (u <- program.units) yield {
+      var fdsOf = Map[String, Set[FunDef]]()
       // 1) Create one function for each method
-      for { cd <- u.classHierarchyRoots if cd.methods.nonEmpty; fd <- cd.methods } {
+      for { cd <- u.classHierarchyRoots; fd <- cd.methods } {
         // We import class type params and freshen them
         val ctParams = cd.tparams map { _.freshen }
         val tparamsMap = cd.tparams.zip(ctParams map { _.tp }).toMap
@@ -157,6 +151,14 @@ object MethodLifting extends TransformationPhase {
         nfd.setPos(fd)
         nfd.addFlag(IsMethod(cd))
 
+        def classPre(fd: FunDef) = mdToCls.get(fd) match {
+          case None =>
+            BooleanLiteral(true)
+          case Some(cl) =>
+            isInstOf(Variable(receiver), cl.typed(ctParams map { _.tp }))
+        }
+
+
         if (cd.knownDescendants.forall( cd => (cd.methods ++ cd.fields).forall(_.id != fd.id))) {
           // Don't need to compose methods
           val paramsMap = fd.params.zip(fdParams).map{case (x,y) => (x.id, y.id)}.toMap
@@ -168,8 +170,15 @@ object MethodLifting extends TransformationPhase {
           }
 
           val insTp: Expr => Expr = instantiateType(_, tparamsMap, paramsMap)
-
           nfd.fullBody = insTp( postMap(thisToReceiver)(insTp(nfd.fullBody)) )
+
+          // Add precondition if the method was defined in a subclass
+          val pre = and(
+            classPre(fd),
+            nfd.precondition.getOrElse(BooleanLiteral(true))
+          )
+          nfd.fullBody = withPrecondition(nfd.fullBody, Some(pre))
+        
         } else {
           // We need to compose methods of subclasses
 
@@ -189,7 +198,7 @@ object MethodLifting extends TransformationPhase {
             (from,to) <- m.tparams zip fd.tparams
           } yield (from, to.tp)).toMap
           def inst(cs: Seq[MatchCase]) = instantiateType(
-            MatchExpr(Variable(receiver), cs).setPos(fd),
+            matchExpr(Variable(receiver), cs).setPos(fd),
             classParamsMap ++ methodParamsMap,
             paramsMap
           )
@@ -201,10 +210,18 @@ object MethodLifting extends TransformationPhase {
           ))
           val (body, _) = makeCases(cd, fd.id, withoutSpec(_).getOrElse(NoTree(retType)))
 
-          /* Some obvious simplifications */
+          // Some simplifications
           val preSimple = {
-            val trivial = pre.forall { _.rhs == BooleanLiteral(true) }
-            if (trivial) None else Some(inst(pre).setPos(fd.getPos))
+            val nonTrivial = pre.count{ _.rhs != BooleanLiteral(true) }
+
+            val compositePre =
+              if (nonTrivial == 0) {
+                BooleanLiteral(true)
+              } else {
+                inst(pre).setPos(fd.getPos)
+              }
+
+            Some(and(classPre(fd), compositePre))
           }
           val postSimple = {
             val trivial = post.forall {
@@ -235,8 +252,8 @@ object MethodLifting extends TransformationPhase {
             withPrecondition(bodySimple, preSimple),
             postSimple
           )
-
         }
+
         mdToFds += fd -> nfd
         fdsOf += cd.id.name -> (fdsOf.getOrElse(cd.id.name, Set()) + nfd)
       }