// UnitElimination.scala
package leon

import purescala.Common._
import purescala.Definitions._
import purescala.Trees._
import purescala.Extractors._
import purescala.TypeTrees._

object UnitElimination extends TransformationPhase {

  /** Short name of this phase. */
  val name: String = "Unit Elimination"

  /** One-line description of what this phase does. */
  val description: String = "Remove all usage of the Unit type and value"

  // Maps every function of the program being processed to the definition that
  // replaces it: a fresh definition without Unit parameters, or the function
  // itself when its signature is kept. Cleared at the start of each `apply`.
  private var fun2FreshFun: Map[FunDef, FunDef] = Map.empty

  // Identifier-to-identifier mapping; presumably records replacements for
  // identifiers whose type changes during the rewrite -- confirm against its
  // usages further down in this object.
  private var id2FreshId: Map[Identifier, Identifier] = Map.empty

  /**
   * Rewrites `pgm` so that its functions no longer take Unit-typed parameters.
   *
   * The work is done in two passes:
   *  1. every defined function is registered in `fun2FreshFun`. A function that
   *     does not return Unit but has at least one Unit-typed argument is mapped
   *     to a fresh definition whose signature omits those arguments; every
   *     other function is mapped to itself.
   *  2. the body of every function that does not return Unit is passed through
   *     `removeUnit` and stored on the definition registered for it.
   *
   * Functions returning Unit are left out of the resulting program. A function
   * that is mapped to itself has its body overwritten in place.
   *
   * @param ctx the Leon context (not used by this phase)
   * @param pgm the program to transform
   * @return a program with the same identifiers, classes and invariants as
   *         `pgm`, containing the rewritten functions
   */
  def apply(ctx: LeonContext, pgm: Program): Program = {
    fun2FreshFun = Map()
    val functions = pgm.definedFunctions

    // Returns the definition that stands in for `fd` after the transformation.
    def replacementFor(fd: FunDef): FunDef =
      if(fd.returnType == UnitType || !fd.args.exists(_.tpe == UnitType)) {
        fd //identity mapping, this will make the next step simpler
      } else {
        val keptArgs = fd.args.filterNot(_.tpe == UnitType)
        val stripped = new FunDef(FreshIdentifier(fd.id.name), fd.returnType, keptArgs).setPosInfo(fd)
        stripped.fromLoop = fd.fromLoop
        stripped.parent = fd.parent
        stripped.precondition = fd.precondition //TODO: maybe removing unit from the conditions as well..
        stripped.postcondition = fd.postcondition //TODO: maybe removing unit from the conditions as well..
        stripped.addAnnotation(fd.annotations.toSeq:_*)
        stripped
      }

    //pass 1: introduce the new signatures without Unit parameters
    fun2FreshFun ++= functions.map(fd => fd -> replacementFor(fd))

    //pass 2: rewrite the bodies, dropping the functions that return Unit
    val rewrittenFuns = for(fd <- functions if fd.returnType != UnitType) yield {
      val cleanedBody = fd.body.map(removeUnit)
      val target = fun2FreshFun(fd)
      target.body = cleanedBody
      target
    }

    val Program(programId, ObjectDef(objectId, _, invariants)) = pgm
    val classes = pgm.definedClasses
    Program(programId, ObjectDef(objectId, classes ++ rewrittenFuns, invariants))
  }

  /**
   * Removes Unit components from tuple types.
   *
   * The components of a tuple type are simplified recursively and those that
   * end up being Unit are discarded. What remains decides the result:
   *  - nothing left: the whole type is Unit;
   *  - one component left: that component, no longer wrapped in a tuple;
   *  - several components left: a tuple type over them.
   *
   * Any type that is not a tuple type is returned unchanged.
   */
  private def simplifyType(tpe: TypeTree): TypeTree = tpe match {
    case TupleType(components) =>
      val kept = components.map(simplifyType).filter(UnitType != _)
      kept match {
        case Seq() => UnitType
        case Seq(single) => single
        case _ => TupleType(kept)
      }
    case other => other
  }

  //remove unit value as soon as possible, so expr should never be equal to a unit
  private def removeUnit(expr: Expr): Expr = {
    assert(expr.getType != UnitType)
    expr match {
      case fi@FunctionInvocation(fd, args) => {
        val newArgs = args.filterNot(arg => arg.getType == UnitType)
        FunctionInvocation(fun2FreshFun(fd), newArgs).setPosInfo(fi)
      }
      case t@Tuple(args) => {
        val TupleType(tpes) = t.getType
        val (newTpes, newArgs) = tpes.zip(args).filterNot{ case (UnitType, _) => true case _ => false }.unzip
        Tuple(newArgs.map(removeUnit)).setType(TupleType(newTpes))
      }