diff --git a/src/main/scala/leon/purescala/MethodLifting.scala b/src/main/scala/leon/purescala/MethodLifting.scala index 4848bf10effc93b0b8db481a26d2c53db297e5f0..f94272e9b0923aa493681ed320f07b8a1a2005ee 100644 --- a/src/main/scala/leon/purescala/MethodLifting.scala +++ b/src/main/scala/leon/purescala/MethodLifting.scala @@ -95,49 +95,43 @@ object MethodLifting extends TransformationPhase { // First we create the appropriate functions from methods: var mdToFds = Map[FunDef, FunDef]() + var mdToCls = Map[FunDef, ClassDef]() + + // Lift methods to the root class + for { + u <- program.units + ch <- u.classHierarchies + c <- ch + if c.parent.isDefined + fd <- c.methods + if c.ancestors.forall(!_.methods.map{_.id}.contains(fd.id)) + } { + val root = c.ancestors.last + val tMap = c.tparams.zip(root.tparams.map{_.tp}).toMap + val tSubst: TypeTree => TypeTree = instantiateType(_, tMap) + + val fdParams = fd.params map { vd => + val newId = FreshIdentifier(vd.id.name, tSubst(vd.id.getType)) + ValDef(newId).setPos(vd.getPos) + } + val paramsMap = fd.params.zip(fdParams).map{ case (from, to) => from.id -> to.id }.toMap + val eSubst: Expr => Expr = instantiateType(_, tMap, paramsMap) - val newUnits = for (u <- program.units) yield { - var fdsOf = Map[String, Set[FunDef]]() + val newFd = new FunDef(fd.id, fd.tparams, tSubst(fd.returnType), fdParams).copiedFrom(fd) + newFd.copyContentFrom(fd) - // Lift methods to the root class - for { - ch <- u.classHierarchies - c <- ch - if c.parent.isDefined - fd <- c.methods - if c.ancestors.forall(!_.methods.map{_.id}.contains(fd.id)) - } { - val root = c.ancestors.last - val tMap = c.tparams.zip(root.tparams.map{_.tp}).toMap - val tSubst: TypeTree => TypeTree = instantiateType(_, tMap) + mdToCls += newFd -> c - val fdParams = fd.params map { vd => - val newId = FreshIdentifier(vd.id.name, tSubst(vd.id.getType)) - ValDef(newId).setPos(vd.getPos) - } - val paramsMap = fd.params.zip(fdParams).map{ case (from, to) => from.id -> to.id }.toMap - val eSubst: Expr => Expr = instantiateType(_, tMap, paramsMap) - - val newFd = new FunDef(fd.id, fd.tparams, tSubst(fd.returnType), fdParams).copiedFrom(fd) - newFd.copyContentFrom(fd) - val prec = fd.precondition.getOrElse(BooleanLiteral(true)) - newFd.fullBody = eSubst(withPrecondition( - newFd.fullBody, - Some(and( - prec, - isInstOf( - This(root.typed), - c.typed(root.tparams.map{ _.tp }) - ) - )) - )) + newFd.fullBody = eSubst(newFd.fullBody) - c.unregisterMethod(fd.id) - root.registerMethod(newFd) - } + c.unregisterMethod(fd.id) + root.registerMethod(newFd) + } + val newUnits = for (u <- program.units) yield { + var fdsOf = Map[String, Set[FunDef]]() // 1) Create one function for each method - for { cd <- u.classHierarchyRoots if cd.methods.nonEmpty; fd <- cd.methods } { + for { cd <- u.classHierarchyRoots; fd <- cd.methods } { // We import class type params and freshen them val ctParams = cd.tparams map { _.freshen } val tparamsMap = cd.tparams.zip(ctParams map { _.tp }).toMap @@ -157,6 +151,14 @@ object MethodLifting extends TransformationPhase { nfd.setPos(fd) nfd.addFlag(IsMethod(cd)) + def classPre(fd: FunDef) = mdToCls.get(fd) match { + case None => + BooleanLiteral(true) + case Some(cl) => + isInstOf(Variable(receiver), cl.typed(ctParams map { _.tp })) + } + + if (cd.knownDescendants.forall( cd => (cd.methods ++ cd.fields).forall(_.id != fd.id))) { // Don't need to compose methods val paramsMap = fd.params.zip(fdParams).map{case (x,y) => (x.id, y.id)}.toMap @@ -168,8 +170,15 @@ object MethodLifting extends TransformationPhase { } val insTp: Expr => Expr = instantiateType(_, tparamsMap, paramsMap) - nfd.fullBody = insTp( postMap(thisToReceiver)(insTp(nfd.fullBody)) ) + + // Add precondition if the method was defined in a subclass + val pre = and( + classPre(fd), + nfd.precondition.getOrElse(BooleanLiteral(true)) + ) + nfd.fullBody = withPrecondition(nfd.fullBody, Some(pre)) + } else { // We need to compose methods of subclasses @@ -189,7 +198,7 @@ object MethodLifting extends TransformationPhase { (from,to) <- m.tparams zip fd.tparams } yield (from, to.tp)).toMap def inst(cs: Seq[MatchCase]) = instantiateType( - MatchExpr(Variable(receiver), cs).setPos(fd), + matchExpr(Variable(receiver), cs).setPos(fd), classParamsMap ++ methodParamsMap, paramsMap ) @@ -201,10 +210,18 @@ object MethodLifting extends TransformationPhase { )) val (body, _) = makeCases(cd, fd.id, withoutSpec(_).getOrElse(NoTree(retType))) - /* Some obvious simplifications */ + // Some simplifications val preSimple = { - val trivial = pre.forall { _.rhs == BooleanLiteral(true) } - if (trivial) None else Some(inst(pre).setPos(fd.getPos)) + val nonTrivial = pre.count{ _.rhs != BooleanLiteral(true) } + + val compositePre = + if (nonTrivial == 0) { + BooleanLiteral(true) + } else { + inst(pre).setPos(fd.getPos) + } + + Some(and(classPre(fd), compositePre)) } val postSimple = { val trivial = post.forall { @@ -235,8 +252,8 @@ object MethodLifting extends TransformationPhase { withPrecondition(bodySimple, preSimple), postSimple ) - } + mdToFds += fd -> nfd fdsOf += cd.id.name -> (fdsOf.getOrElse(cd.id.name, Set()) + nfd) }