// ChainBuilder.scala
package leon
package termination

import leon.purescala.Definitions._
import leon.purescala.Trees._
import leon.purescala.TreeOps._
import leon.purescala.Common._

import scala.collection.mutable.{Map => MutableMap}

/** Hands out sequential integer identifiers, one per request. */
object ChainID {
  // Most recently issued identifier; stays at 0 until the first request.
  private var counter: Int = 0

  /** Advances the counter by one and returns the new value, so the first call yields 1. */
  def get: Int = {
    val next = counter + 1
    counter = next
    next
  }
}

/**
 * A sequence of call relations, each one describing a function invocation made under some path.
 *
 * Instances are compared by their `id` only, never by the relations they hold, so two chains built
 * from the same list are still distinct values.
 *
 * @param chain the relations of the chain, first call first; expected to be non-empty
 */
final case class Chain(chain: List[Relation]) {
  /** Identifier obtained from `ChainID` when this instance is constructed. */
  val id = ChainID.get

  /** Identity-style equality: true only for another `Chain` carrying the same `id`. */
  override def equals(obj: Any): Boolean = obj match {
    case that: Chain => that.id == id
    case _           => false
  }

  /** The `id` doubles as the hash code, consistent with `equals`. */
  override def hashCode(): Int = id

  /** Function definition attached to the first relation of the chain. */
  def funDef: FunDef = chain.head.funDef

  /** Every function definition attached to some relation of the chain. */
  def funDefs: Set[FunDef] = chain.map(relation => relation.funDef).toSet

  /** Number of relations in the chain (computed on first use). */
  lazy val size: Int = chain.size

  /**
   * Collects the constraints obtained by walking the chain from its first to its last relation.
   *
   * Every relation except the last contributes its path (after substitution) followed by
   * equalities between freshened copies of the callee's formal arguments and the substituted call
   * arguments; those fresh variables then form the substitution used for the next relation.
   * The last relation contributes its substituted path and, when `finalSubst` is non-empty,
   * equalities between the `finalSubst` expression of each callee formal and the matching
   * substituted call argument.
   *
   * @param initialSubst substitution applied to the first relation; when empty, each formal
   *                     argument of `funDef` is mapped to its own variable instead
   * @param finalSubst   expressions the arguments of the last call are equated with; when empty,
   *                     no such equalities are produced
   * @return the constraints, in chain order
   */
  def loop(initialSubst: Map[Identifier, Expr] = Map(), finalSubst: Map[Identifier, Expr] = Map()) : Seq[Expr] = {
    def walk(remaining: List[Relation], current: Map[Identifier, Expr]): Seq[Expr] = remaining match {
      case Nil =>
        sys.error("Empty chain shouldn't exist by construction")

      case Relation(_, path, FunctionInvocation(fd, args)) :: Nil =>
        // The last call of the chain has to target the function the chain starts from.
        assert(fd == funDef)
        val substitutedPath = path.map(expr => replaceFromIDs(current, expr))
        if (finalSubst.isEmpty) substitutedPath
        else {
          val substitutedArgs = args.map(expr => replaceFromIDs(current, expr))
          val targets = fd.args.map(arg => finalSubst(arg.id))
          val bindings = targets.zip(substitutedArgs).map { case (target, actual) => Equals(target, actual) }
          substitutedPath ++ bindings
        }

      case Relation(_, path, FunctionInvocation(fd, args)) :: rest =>
        // Fresh variables stand for the callee's formals; they are created before any
        // substitution is performed.
        val calleeFormals = fd.args.map(_.id)
        val freshVars = calleeFormals.map(formal => formal.freshen.toVariable)
        val nextSubst: Map[Identifier, Expr] = calleeFormals.zip(freshVars).toMap
        val substitutedPath = path.map(expr => replaceFromIDs(current, expr))
        val substitutedArgs = args.map(expr => replaceFromIDs(current, expr))
        val bindings = freshVars.zip(substitutedArgs).map { case (fresh, actual) => Equals(fresh, actual) }
        substitutedPath ++ bindings ++ walk(rest, nextSubst)
    }

    val startSubst: Map[Identifier, Expr] =
      if (initialSubst.isEmpty) funDef.args.map(arg => (arg.id, arg.toVariable)).toMap
      else initialSubst
    walk(chain, startSubst)
  }

  /**
   * Constraints for going through this chain and then through `other`.
   *
   * Both chains must start from the same function (checked by assertion). The arguments reached
   * at the end of this chain are bound to freshened variables, and `other` is walked starting
   * from those same variables.
   *
   * @param other the chain walked after this one
   * @return this chain's constraints followed by those of `other`
   */
  def reentrant(other: Chain) : Seq[Expr] = {
    assert(funDef == other.funDef)
    val sharedState: Map[Identifier, Expr] =
      funDef.args.map(arg => (arg.id, arg.id.freshen.toVariable)).toMap
    loop(finalSubst = sharedState) ++ other.loop(initialSubst = sharedState)
  }

  /**
   * The body of `funDef` followed by one inlined callee body per relation of the chain.
   *
   * Each body is first passed through `matchToIfThenElse`, `expandLets` and `hoistIte`. For a
   * relation, the call arguments are rewritten with the mapping of the previous step, and the
   * callee's formals are then replaced by those rewritten arguments inside the callee's body.
   */
  def inlined: TraversableOnce[Expr] = {
    // We assume we have a body at this point. It would be weird to have gotten here without one...
    def preparedBody(fd: FunDef): Expr = hoistIte(expandLets(matchToIfThenElse(fd.body.get)))

    val entryBody = preparedBody(funDef)
    val identityMapping: Map[Identifier, Expr] = funDef.args.map(arg => (arg.id, arg.toVariable)).toMap

    // Thread the current mapping through the relations, collecting inlined bodies in reverse.
    val (_, reversedBodies) = chain.foldLeft((identityMapping, List.empty[Expr])) {
      case ((mapping, collected), Relation(_, _, FunctionInvocation(fd, args))) =>
        val actuals = args.map(expr => replaceFromIDs(mapping, expr))
        val calleeMapping: Map[Identifier, Expr] = fd.args.map(_.id).zip(actuals).toMap
        val inlinedBody = replaceFromIDs(calleeMapping, preparedBody(fd))
        (calleeMapping, inlinedBody :: collected)
    }

    entryBody :: reversedBodies.reverse
  }
}

/**
 * Enumerates the chains of relations that start at a function and close back onto their first
 * relation, caching the result per function.
 *
 * @param relationBuilder provides the set of relations of a function definition
 */
class ChainBuilder(relationBuilder: RelationBuilder) {

  // Chains already computed, keyed by the function they were requested for.
  private val chainCache : MutableMap[FunDef, Set[Chain]] = MutableMap()

  /**
   * Returns the chains of `funDef`, computing and caching them on the first request.
   *
   * Starting from every relation of `funDef`, partial chains are extended breadth-first with the
   * relations of the last callee. A partial chain is reported as soon as the relations of its last
   * callee contain the chain's first relation.
   *
   * @param funDef the function whose chains are requested
   * @return one `Chain` per discovered cycle representation
   */
  def run(funDef: FunDef): Set[Chain] = chainCache.getOrElseUpdate(funDef, {
    // Note that this method will generate duplicate cycles (in fact, it will generate all list
    // representations of a cycle)
    def chains(partials: List[(Relation, List[Relation])]): List[List[Relation]] =
      if (partials.isEmpty) Nil else {
        // Note that chains in partials are reversed to profit from O(1) insertion
        val expanded = partials.map {
          case (first, chain @ Relation(_, _, FunctionInvocation(fd, _)) :: _) =>
            // Look the callee's relations up once; both the cycle test and the extension step use them.
            val outgoing = relationBuilder.run(fd)

            // reverse the chain when "returning" it since we're working on reversed chains
            val closed = if (outgoing.contains(first)) Some(chain.reverse) else None

            // Partial chains can fall back onto a transition that was already taken (thus creating a
            // cycle inside the chain). Since this cycle will be discovered elsewhere, such partial
            // chains should be dropped from the partial chain list
            val extensions = (outgoing -- chain.toSet).map(transition => (first, transition :: chain)).toList

            (closed, extensions)
          case (_, Nil) => scala.sys.error("Empty partial chain shouldn't exist by construction")
        }

        // Completed chains are listed latest-first within one round, extensions in partial order.
        val results = expanded.collect({ case (Some(cycle), _) => cycle }).reverse
        val newPartials = expanded.flatMap(_._2)

        results ++ chains(newPartials)
      }

    val initialPartials = relationBuilder.run(funDef).map(r => (r, r :: Nil)).toList
    chains(initialPartials).map(Chain(_)).toSet
  })
}