Skip to content
Snippets Groups Projects
ImperativeCodeElimination.scala 12.74 KiB
package leon

import purescala.Common._
import purescala.Definitions._
import purescala.Trees._
import purescala.TypeTrees._

/** Rewrites imperative constructs (`var` declarations, assignments, `while`
  * loops, and the `if`/`match`/block expressions containing them) into purely
  * functional code built from `Let`s, tuples and local recursive functions.
  *
  * The pass updates the bodies of the program's functions in place and returns
  * the same [[Program]] instance.
  */
object ImperativeCodeElimination extends Pass {

  val description = "Transform imperative constructs into purely functional code"

  // Mutable variables (introduced by LetVar) whose declaration encloses the
  // expression currently being translated. Only these may be assigned, and only
  // these are threaded out of if/match/while constructs.
  private var varInScope = Set[Identifier]()
  // The function whose body is currently being translated; recorded as the
  // parent of every function synthesized from a while loop.
  private var parent: FunDef = null //the enclosing fundef

  /** Replaces the body of every function defined in `pgm` with its purely
    * functional translation.
    *
    * @param pgm the program to transform; its function definitions are updated in place
    * @return the same program instance
    */
  def apply(pgm: Program): Program = {
    val allFuns = pgm.definedFunctions
    allFuns.foreach(fd => {
      parent = fd
      val (res, scope, _) = toFunction(fd.getBody)
      fd.body = Some(scope(res))
    })
    pgm
  }

  /** Translates `expr` into purely functional code.
    *
    * The result is a triple `(value, scope, renaming)`:
    *  - `scope` wraps a body expression in the bindings (`Let`s and `LetDef`s)
    *    that carry out the side effects of `expr` in purely functional form and
    *    define any new immutable variables they need;
    *  - `renaming` maps every mutable variable modified by `expr` to the fresh
    *    immutable identifier that holds its value once `expr` has executed;
    *  - `value` is the value of `expr`. It is meant to be placed inside `scope`,
    *    i.e. the complete translation of `expr` is `scope(value)`.
    *
    * @param expr the expression to translate; may contain `LetVar`,
    *             `Assignment` and `While` nodes
    * @return the `(value, scope, renaming)` triple described above
    * @throws RuntimeException if `expr` contains a construct not handled below
    */
  private def toFunction(expr: Expr): (Expr, Expr => Expr, Map[Identifier, Identifier]) = {
    // The explicit type gives every branch (notably those returning `Map()`) the
    // intended map type, instead of the inferred Map[_ <: Identifier, Identifier]
    // that previously required a cast.
    val res: (Expr, Expr => Expr, Map[Identifier, Identifier]) = expr match {
      case LetVar(id, e, b) => {
        // The var becomes a val bound to a fresh name; uses of `id` in the body
        // are renamed to that name.
        val newId = FreshIdentifier(id.name).setType(id.getType)
        val (rhsVal, rhsScope, rhsFun) = toFunction(e)
        // `id` is assignable only within the body. The finally block keeps this
        // singleton's `varInScope` consistent even if translating the body throws.
        varInScope += id
        val (bodyRes, bodyScope, bodyFun) =
          try { toFunction(b) } finally { varInScope -= id }
        val scope = (body: Expr) => rhsScope(Let(newId, rhsVal, replaceNames(rhsFun + (id -> newId), bodyScope(body))))
        (bodyRes, scope, (rhsFun + (id -> newId)) ++ bodyFun)
      }
      case Assignment(id, e) => {
        assert(varInScope.contains(id))
        // Each assignment binds a new immutable version of the variable.
        val newId = FreshIdentifier(id.name).setType(id.getType)
        val (rhsVal, rhsScope, rhsFun) = toFunction(e)
        val scope = (body: Expr) => rhsScope(Let(newId, rhsVal, body))
        (UnitLiteral, scope, rhsFun + (id -> newId))
      }

      case ite@IfExpr(cond, tExpr, eExpr) => {
        val (cRes, cScope, cFun) = toFunction(cond)
        val (tRes, tScope, tFun) = toFunction(tExpr)
        val (eRes, eScope, eFun) = toFunction(eExpr)

        // Variables in scope that are assigned in either branch must flow out of
        // the if-expression: each branch then returns a tuple (value, v1, ..., vn)
        // with its result and the latest version of every such variable.
        val modifiedVars: Seq[Identifier] = (tFun.keys ++ eFun.keys).toSet.intersect(varInScope).toSeq
        val resId = FreshIdentifier("res").setType(ite.getType)
        val freshIds = modifiedVars.map(id => FreshIdentifier(id.name).setType(id.getType))
        val iteType = if(modifiedVars.isEmpty) resId.getType else TupleType(resId.getType +: freshIds.map(_.getType))

        // A branch that does not modify a variable returns the variable's current name.
        val thenVal = if(modifiedVars.isEmpty) tRes else Tuple(tRes +: modifiedVars.map(vId => tFun.get(vId) match {
          case Some(newId) => newId.toVariable
          case None => vId.toVariable
        }))
        thenVal.setType(iteType)

        val elseVal = if(modifiedVars.isEmpty) eRes else Tuple(eRes +: modifiedVars.map(vId => eFun.get(vId) match {
          case Some(newId) => newId.toVariable
          case None => vId.toVariable
        }))
        elseVal.setType(iteType)

        // Both branches run after the condition, so they see its renamings.
        val iteExpr = IfExpr(cRes, replaceNames(cFun, tScope(thenVal)), replaceNames(cFun, eScope(elseVal))).setType(iteType)

        // Bind the if-expression to a fresh name, then unpack it: component 1 is
        // the value, component i + 2 is the new version of modifiedVars(i).
        val scope = ((body: Expr) => {
          val tupleId = FreshIdentifier("t").setType(iteType)
          cScope(
            Let(tupleId, iteExpr, 
              if(freshIds.isEmpty)
                Let(resId, tupleId.toVariable, body)
              else
                Let(resId, TupleSelect(tupleId.toVariable, 1),
                  freshIds.zipWithIndex.foldLeft(body)((b, id) => 
                    Let(id._1, 
                      TupleSelect(tupleId.toVariable, id._2 + 2).setType(id._1.getType), 
                      b)))))
        })

        (resId.toVariable, scope, cFun ++ modifiedVars.zip(freshIds).toMap)
      }

      case m @ MatchExpr(scrut, cses) => {
        // Patterns bind no mutable state and guards are required to be pure, so
        // only the right-hand sides of the cases need translating.
        val csesRhs = cses.map(_.rhs)
        val (csesRes, csesScope, csesFun) = csesRhs.map(toFunction).unzip3
        val (scrutRes, scrutScope, scrutFun) = toFunction(scrut)

        // Same tupling scheme as for if-expressions: every case returns
        // (value, v1, ..., vn) for the variables in scope modified by any case.
        val modifiedVars: Seq[Identifier] = csesFun.toSet.flatMap((caseFun: Map[Identifier, Identifier]) => caseFun.keys).intersect(varInScope).toSeq
        val resId = FreshIdentifier("res").setType(m.getType)
        val freshIds = modifiedVars.map(id => FreshIdentifier(id.name).setType(id.getType))
        val matchType = if(modifiedVars.isEmpty) resId.getType else TupleType(resId.getType +: freshIds.map(_.getType))

        val csesVals = csesRes.zip(csesFun).map{ 
          case (cRes, cFun) => (if(modifiedVars.isEmpty) cRes else Tuple(cRes +: modifiedVars.map(vId => cFun.get(vId) match {
            case Some(newId) => newId.toVariable
            case None => vId.toVariable
          }))).setType(matchType)
        }

        // Cases (and their guards) run after the scrutinee, so they see its renamings.
        val newRhs = csesVals.zip(csesScope).map{ 
          case (cVal, cScope) => replaceNames(scrutFun, cScope(cVal)).setType(matchType)
        }
        val matchExpr = MatchExpr(scrutRes, cses.zip(newRhs).map{
          case (SimpleCase(pat, _), newRhs) => SimpleCase(pat, newRhs)
          case (GuardedCase(pat, guard, _), newRhs) => GuardedCase(pat, replaceNames(scrutFun, guard), newRhs)
        }).setType(matchType).setPosInfo(m)

        // Component 1 is the value, component i + 2 the new version of modifiedVars(i).
        val scope = ((body: Expr) => {
          val tupleId = FreshIdentifier("t").setType(matchType)
          scrutScope(
            Let(tupleId, matchExpr, 
              if(freshIds.isEmpty)
                Let(resId, tupleId.toVariable, body)
              else
                Let(resId, TupleSelect(tupleId.toVariable, 1),
                  freshIds.zipWithIndex.foldLeft(body)((b, id) => 
                    Let(id._1, 
                      TupleSelect(tupleId.toVariable, id._2 + 2).setType(id._1.getType), 
                      b)))))
        })

        (resId.toVariable, scope, scrutFun ++ modifiedVars.zip(freshIds).toMap)
      }
      case wh@While(cond, body) => {
        // The loop becomes a local recursive function "while" whose parameters are
        // the current versions of the variables it modifies, and whose result is
        // their final versions (a tuple when there are several).
        val (condRes, condScope, condFun) = toFunction(cond)
        val (_, bodyScope, bodyFun) = toFunction(body)
        val condBodyFun = condFun ++ bodyFun

        val modifiedVars: Seq[Identifier] = condBodyFun.keys.toSet.intersect(varInScope).toSeq

        // A loop that modifies no variable in scope is dropped entirely
        // (note: this also discards a loop that would not terminate).
        if(modifiedVars.isEmpty)
          (UnitLiteral, (b: Expr) => b, Map())
        else {
          val whileFunVars = modifiedVars.map(id => FreshIdentifier(id.name).setType(id.getType))
          val modifiedVars2WhileFunVars = modifiedVars.zip(whileFunVars).toMap
          val whileFunVarDecls = whileFunVars.map(id => VarDecl(id, id.getType))
          val whileFunReturnType = if(whileFunVars.size == 1) whileFunVars.head.getType else TupleType(whileFunVars.map(_.getType))
          val whileFunDef = new FunDef(FreshIdentifier("while"), whileFunReturnType, whileFunVarDecls).setPosInfo(wh)
          whileFunDef.fromLoop = true
          whileFunDef.parent = Some(parent)
          
          // If the condition holds: run the body, then recurse with the versions
          // produced by the condition and the body.
          val whileFunCond = condRes
          val whileFunRecursiveCall = replaceNames(condFun,
            bodyScope(FunctionInvocation(whileFunDef, modifiedVars.map(id => condBodyFun(id).toVariable)).setPosInfo(wh)))
          // Otherwise only the condition has run: return the versions it produced
          // (or the parameters themselves when it modified nothing).
          val whileFunBaseCase =
            (if(whileFunVars.size == 1) 
                condFun.get(modifiedVars.head).getOrElse(whileFunVars.head).toVariable
             else 
                Tuple(modifiedVars.map(id => condFun.get(id).getOrElse(modifiedVars2WhileFunVars(id)).toVariable))
            ).setType(whileFunReturnType)
          // Inside the function, the modified variables refer to its parameters.
          val whileFunBody = replaceNames(modifiedVars2WhileFunVars, 
            condScope(IfExpr(whileFunCond, whileFunRecursiveCall, whileFunBaseCase).setType(whileFunReturnType)))
          whileFunDef.body = Some(whileFunBody)

          // Map each modified variable to its component of the function's result.
          val resVar = ResultVariable().setType(whileFunReturnType)
          val whileFunVars2ResultVars: Map[Expr, Expr] = 
            if(whileFunVars.size == 1) 
              Map(whileFunVars.head.toVariable -> resVar)
            else
              whileFunVars.zipWithIndex.map{ case (v, i) => (v.toVariable, TupleSelect(resVar, i+1).setType(v.getType)) }.toMap
          val modifiedVars2ResultVars: Map[Expr, Expr]  = modifiedVars.map(id => 
            (id.toVariable, whileFunVars2ResultVars(modifiedVars2WhileFunVars(id).toVariable))).toMap

          // Postcondition: the loop condition is false on the result, plus the
          // invariant (if any). The condition refers to the versions produced by
          // its own side effects, hence the condFun lookup.
          val trivialPostcondition: Option[Expr] = Some(Not(replace(
            modifiedVars.map(id => (condFun.get(id).getOrElse(id).toVariable, modifiedVars2ResultVars(id.toVariable))).toMap,
            whileFunCond)))
          // The invariant (if any) is also the precondition, stated over the parameters.
          val invariantPrecondition: Option[Expr] = wh.invariant.map(expr => replaceNames(modifiedVars2WhileFunVars, expr))
          val invariantPostcondition: Option[Expr] = wh.invariant.map(expr => replace(modifiedVars2ResultVars, expr))
          whileFunDef.precondition = invariantPrecondition
          whileFunDef.postcondition = trivialPostcondition.map(expr => 
              And(expr, invariantPostcondition match { 
                case Some(e) => e
                case None => BooleanLiteral(true)
              }))

          // Define the function, call it with the current versions of the
          // variables, and bind fresh versions to its results.
          val finalVars = modifiedVars.map(id => FreshIdentifier(id.name).setType(id.getType))
          val finalScope = ((body: Expr) => {
            val tupleId = FreshIdentifier("t").setType(whileFunReturnType)
            LetDef(
              whileFunDef,
              Let(tupleId, 
                  FunctionInvocation(whileFunDef, modifiedVars.map(_.toVariable)).setPosInfo(wh), 
                  if(finalVars.size == 1)
                    Let(finalVars.head, tupleId.toVariable, body)
                  else
                    finalVars.zipWithIndex.foldLeft(body)((b, id) => 
                      Let(id._1, 
                        TupleSelect(tupleId.toVariable, id._2 + 1).setType(id._1.getType), 
                        b))))
          })

          (UnitLiteral, finalScope, modifiedVars.zip(finalVars).toMap)
        }
      }

      case Block(Seq(), expr) => toFunction(expr)
      case Block(exprs, expr) => {
        // Statements are chained right to left so that each statement's scope
        // encloses the code after it and its renamings are applied to that code.
        // The values of the statements themselves are discarded.
        val (scope, fun) = exprs.foldRight((body: Expr) => body, Map[Identifier, Identifier]())((e, acc) => {
          val (accScope, accFun) = acc
          val (_, rScope, rFun) = toFunction(e)
          val scope = (body: Expr) => rScope(replaceNames(rFun, accScope(body)))
          (scope, rFun ++ accFun)
        })
        val (lastRes, lastScope, lastFun) = toFunction(expr)
        // Later renamings win in `++`, so finalFun holds each variable's latest version.
        val finalFun = fun ++ lastFun
        (replaceNames(finalFun, lastRes),
         (body: Expr) => scope(replaceNames(fun, lastScope(body))),
         finalFun)
      }

      // The remaining expressions are pure themselves but may contain side effects
      // in subexpressions; these are translated with left-to-right evaluation order.
      case Let(id, e, b) => {
        val (bindRes, bindScope, bindFun) = toFunction(e)
        val (bodyRes, bodyScope, bodyFun) = toFunction(b)
        (bodyRes, 
         (b2: Expr) => bindScope(Let(id, bindRes, replaceNames(bindFun, bodyScope(b2)))), 
         bindFun ++ bodyFun)
      }
      case LetDef(fd, b) => {
        //Recall that here the nested function should not access mutable variables from an outside scope
        val (bodyRes, bodyScope, bodyFun) = toFunction(b)
        (bodyRes, (b2: Expr) => LetDef(fd, bodyScope(b2)), bodyFun)
      }
      case n @ NAryOperator(Seq(), recons) => (n, (body: Expr) => body, Map())
      case n @ NAryOperator(args, recons) => {
        // Each argument's scope encloses the scopes of the arguments to its right.
        // NOTE(review): the renamings of a later argument are also applied to
        // `body`, which holds the earlier arguments' values, so an earlier argument
        // reading a var that a later one assigns may see the new version — verify.
        val (recArgs, scope, fun) = args.foldRight((Seq[Expr](), (body: Expr) => body, Map[Identifier, Identifier]()))((arg, acc) => {
          val (accArgs, accScope, accFun) = acc
          val (argVal, argScope, argFun) = toFunction(arg)
          val newScope = (body: Expr) => argScope(replaceNames(argFun, accScope(body)))
          (argVal +: accArgs, newScope, argFun ++ accFun)
        })
        (recons(recArgs).setType(n.getType), scope, fun)
      }
      case b @ BinaryOperator(a1, a2, recons) => {
        val (argVal1, argScope1, argFun1) = toFunction(a1)
        val (argVal2, argScope2, argFun2) = toFunction(a2)
        // The first operand's scope encloses the second's (same NOTE(review) as
        // for n-ary operators regarding argFun2 being applied to argVal1).
        val scope = (body: Expr) => {
          val rhs = argScope2(replaceNames(argFun2, body))
          val lhs = argScope1(replaceNames(argFun1, rhs))
          lhs
        }
        (recons(argVal1, argVal2).setType(b.getType), scope, argFun1 ++ argFun2)
      }
      case u @ UnaryOperator(a, recons) => {
        val (argVal, argScope, argFun) = toFunction(a)
        (recons(argVal).setType(u.getType), argScope, argFun)
      }
      case (t: Terminal) => (t, (body: Expr) => body, Map())


      case _ => sys.error("not supported: " + expr)
    }
    //val codeRepresentation = res._2(Block(res._3.map{ case (id1, id2) => Assignment(id1, id2.toVariable)}.toSeq, res._1))
    //println("res of toFunction on: " + expr + " IS: " + codeRepresentation)
    res
  }

  /** Applies `replace` to `expr`, substituting for each entry `a -> b` of `fun`
    * the variable of `b` for the variable of `a`.
    */
  def replaceNames(fun: Map[Identifier, Identifier], expr: Expr) = replace(fun.map(ids => (ids._1.toVariable, ids._2.toVariable)), expr)

}