From dd985716b0aa3eab5befe75eae9e57f2b8412a3c Mon Sep 17 00:00:00 2001
From: Etienne Kneuss <colder@php.net>
Date: Mon, 18 Nov 2013 15:27:56 +0100
Subject: [PATCH] Remove global state in Termination

---
 .../scala/leon/termination/ChainBuilder.scala     | 13 ++++++-------
 .../scala/leon/termination/ChainComparator.scala  |  6 ++----
 .../scala/leon/termination/ChainProcessor.scala   | 15 ++++++++-------
 .../termination/ComplexTerminationChecker.scala   | 14 ++++++++++----
 .../leon/termination/ComponentProcessor.scala     |  2 +-
 .../scala/leon/termination/LoopProcessor.scala    | 10 ++++++----
 src/main/scala/leon/termination/Processor.scala   | 15 +++++++++------
 .../leon/termination/RecursionProcessor.scala     |  6 ++----
 .../scala/leon/termination/RelationBuilder.scala  |  7 +++----
 .../leon/termination/RelationComparator.scala     |  6 ++----
 .../leon/termination/RelationProcessor.scala      | 15 ++++++++-------
 .../scala/leon/termination/StructuralSize.scala   |  4 +---
 12 files changed, 58 insertions(+), 55 deletions(-)

diff --git a/src/main/scala/leon/termination/ChainBuilder.scala b/src/main/scala/leon/termination/ChainBuilder.scala
index f6ec38673..936e6a9a0 100644
--- a/src/main/scala/leon/termination/ChainBuilder.scala
+++ b/src/main/scala/leon/termination/ChainBuilder.scala
@@ -6,6 +6,8 @@ import leon.purescala.Trees._
 import leon.purescala.TreeOps._
 import leon.purescala.Common._
 
+import scala.collection.mutable.{Map => MutableMap}
+
 object ChainID {
   private var counter: Int = 0
   def get: Int = {
@@ -79,13 +81,10 @@ final case class Chain(chain: List[Relation]) {
   }
 }
 
-object ChainBuilder {
-  import scala.collection.mutable.{Map => MutableMap}
+class ChainBuilder(relationBuilder: RelationBuilder) {
 
   private val chainCache : MutableMap[FunDef, Set[Chain]] = MutableMap()
 
-  def init : Unit = chainCache.clear
-
   def run(funDef: FunDef): Set[Chain] = chainCache.get(funDef) match {
     case Some(chains) => chains
     case None => {
@@ -94,14 +93,14 @@ object ChainBuilder {
         // Note that chains in partials are reversed to profit from O(1) insertion
         val (results, newPartials) = partials.foldLeft(List[List[Relation]](),List[(Relation, List[Relation])]())({
           case ((results, partials), (first, chain @ Relation(_, _, FunctionInvocation(fd, _)) :: xs)) =>
-            val cycle = RelationBuilder.run(fd).contains(first)
+            val cycle = relationBuilder.run(fd).contains(first)
             // reverse the chain when "returning" it since we're working on reversed chains
             val newResults = if (cycle) chain.reverse :: results else results
 
             // Partial chains can fall back onto a transition that was already taken (thus creating a cycle
             // inside the chain). Since this cycle will be discovered elsewhere, such partial chains should be
             // dropped from the partial chain list
-            val transitions = RelationBuilder.run(fd) -- chain.toSet
+            val transitions = relationBuilder.run(fd) -- chain.toSet
             val newPartials = transitions.map(transition => (first, transition :: chain)).toList
 
             (newResults, partials ++ newPartials)
@@ -111,7 +110,7 @@ object ChainBuilder {
         results ++ chains(newPartials)
       }
 
-      val initialPartials = RelationBuilder.run(funDef).map(r => (r, r :: Nil)).toList
+      val initialPartials = relationBuilder.run(funDef).map(r => (r, r :: Nil)).toList
       val result = chains(initialPartials).map(Chain(_)).toSet
       chainCache(funDef) = result
       result
diff --git a/src/main/scala/leon/termination/ChainComparator.scala b/src/main/scala/leon/termination/ChainComparator.scala
index 278b4bc50..dc61531d3 100644
--- a/src/main/scala/leon/termination/ChainComparator.scala
+++ b/src/main/scala/leon/termination/ChainComparator.scala
@@ -7,10 +7,8 @@ import purescala.TypeTrees._
 import purescala.Definitions._
 import purescala.Common._
 
-object ChainComparator {
-  import StructuralSize._
-
-  def init : Unit = StructuralSize.init
+class ChainComparator(structuralSize: StructuralSize) {
+  import structuralSize.size
 
   private object ContainerType {
     def unapply(c: ClassType): Option[(CaseClassDef, Seq[(Identifier, TypeTree)])] = c match {
diff --git a/src/main/scala/leon/termination/ChainProcessor.scala b/src/main/scala/leon/termination/ChainProcessor.scala
index 83e20354b..6ac33d59b 100644
--- a/src/main/scala/leon/termination/ChainProcessor.scala
+++ b/src/main/scala/leon/termination/ChainProcessor.scala
@@ -8,19 +8,20 @@ import purescala.Common._
 import purescala.Extractors._
 import purescala.Definitions._
 
-class ChainProcessor(checker: TerminationChecker) extends Processor(checker) with Solvable {
+class ChainProcessor(checker: TerminationChecker,
+                     chainBuilder: ChainBuilder,
+                     val structuralSize: StructuralSize,
+                     val strengthener: Strengthener) extends Processor(checker) with Solvable {
 
   val name: String = "Chain Processor"
 
-  ChainBuilder.init
-  ChainComparator.init
+  val chainComparator = new ChainComparator(structuralSize)
 
   def run(problem: Problem): (TraversableOnce[Result], TraversableOnce[Problem]) = {
-    import scala.collection.mutable.{Map => MutableMap}
     implicit val debugSection = DebugSectionTermination
 
     reporter.debug("- Running ChainProcessor")
-    val allChainMap       : Map[FunDef, Set[Chain]] = problem.funDefs.map(funDef => funDef -> ChainBuilder.run(funDef)).toMap
+    val allChainMap       : Map[FunDef, Set[Chain]] = problem.funDefs.map(funDef => funDef -> chainBuilder.run(funDef)).toMap
     reporter.debug("- Computing all possible Chains")
     var counter = 0
     val possibleChainMap  : Map[FunDef, Set[Chain]] = allChainMap.mapValues(chains => chains.filter(chain => isWeakSAT(And(chain.loop()))))
@@ -109,13 +110,13 @@ class ChainProcessor(checker: TerminationChecker) extends Processor(checker) wit
     reporter.debug("- Searching for structural size decrease")
     val sizeCleared : ClusterMap = clear(clusters, (fd, cluster) => {
       val (e1, e2s) = buildLoops(fd, cluster)
-      ChainComparator.sizeDecreasing(e1, e2s)
+      chainComparator.sizeDecreasing(e1, e2s)
     })
 
     reporter.debug("- Searching for numeric convergence")
     val numericCleared : ClusterMap = clear(sizeCleared, (fd, cluster) => {
       val (e1, e2s) = buildLoops(fd, cluster)
-      ChainComparator.numericConverging(e1, e2s, cluster, checker)
+      chainComparator.numericConverging(e1, e2s, cluster, checker)
     })
 
     val (okPairs, nokPairs) = numericCleared.partition(_._2.isEmpty)
diff --git a/src/main/scala/leon/termination/ComplexTerminationChecker.scala b/src/main/scala/leon/termination/ComplexTerminationChecker.scala
index 8dcd43fb2..b808d8c8d 100644
--- a/src/main/scala/leon/termination/ComplexTerminationChecker.scala
+++ b/src/main/scala/leon/termination/ComplexTerminationChecker.scala
@@ -11,13 +11,19 @@ class ComplexTerminationChecker(context: LeonContext, _program: Program) extends
   val name = "Complex Termination Checker"
   val description = "A modular termination checker with a few basic modules™"
 
+  val structuralSize     = new StructuralSize()
+  val relationBuilder    = new RelationBuilder()
+  val chainBuilder       = new ChainBuilder(relationBuilder)
+  val relationComparator = new RelationComparator(structuralSize)
+  val strengthener       = new Strengthener(relationComparator)
+
   private val pipeline = new ProcessingPipeline(
     program, context, // required for solvers and reporting
     new ComponentProcessor(this),
-    new RecursionProcessor(this),
-    new RelationProcessor(this),
-    new ChainProcessor(this),
-    new LoopProcessor(this)
+    new RecursionProcessor(this, relationBuilder),
+    new RelationProcessor(this, relationBuilder, structuralSize, relationComparator, strengthener),
+    new ChainProcessor(this, chainBuilder, structuralSize, strengthener),
+    new LoopProcessor(this, chainBuilder, structuralSize, strengthener)
   )
 
   private val clearedMap     : MutableMap[FunDef, String]               = MutableMap()
diff --git a/src/main/scala/leon/termination/ComponentProcessor.scala b/src/main/scala/leon/termination/ComponentProcessor.scala
index 1e980370b..2cf37cd78 100644
--- a/src/main/scala/leon/termination/ComponentProcessor.scala
+++ b/src/main/scala/leon/termination/ComponentProcessor.scala
@@ -3,6 +3,7 @@ package termination
 
 import purescala.TreeOps._
 import purescala.Definitions._
+import scala.collection.mutable.{Map => MutableMap}
 
 class ComponentProcessor(checker: TerminationChecker) extends Processor(checker) {
 
@@ -16,7 +17,6 @@ class ComponentProcessor(checker: TerminationChecker) extends Processor(checker)
     val components : List[Set[FunDef]]        = ComponentBuilder.run(callGraph)
     val fdToSCC    : Map[FunDef, Set[FunDef]] = components.map(set => set.map(fd => fd -> set)).flatten.toMap
 
-    import scala.collection.mutable.{Map => MutableMap}
     val terminationCache : MutableMap[FunDef, Boolean] = MutableMap()
     def terminates(fd: FunDef) : Boolean = terminationCache.getOrElse(fd, {
       val scc = fdToSCC.getOrElse(fd, Set()) // functions that aren't called and don't call belong to no SCC
diff --git a/src/main/scala/leon/termination/LoopProcessor.scala b/src/main/scala/leon/termination/LoopProcessor.scala
index a45cc3c7d..e33f1c055 100644
--- a/src/main/scala/leon/termination/LoopProcessor.scala
+++ b/src/main/scala/leon/termination/LoopProcessor.scala
@@ -5,14 +5,16 @@ import purescala.Definitions._
 import purescala.Trees._
 import purescala.TreeOps._
 
-class LoopProcessor(checker: TerminationChecker, k: Int = 10) extends Processor(checker) with Solvable {
+class LoopProcessor(checker: TerminationChecker,
+                    chainBuilder: ChainBuilder,
+                    val structuralSize: StructuralSize,
+                    val strengthener: Strengthener,
+                    k: Int = 10) extends Processor(checker) with Solvable {
 
   val name: String = "Loop Processor"
 
-  ChainBuilder.init
-
   def run(problem: Problem) = {
-    val allChains : Set[Chain] = problem.funDefs.map(fd => ChainBuilder.run(fd)).flatten
+    val allChains : Set[Chain] = problem.funDefs.map(fd => chainBuilder.run(fd)).flatten
     // Get reentrant loops (see ChainProcessor for more details)
     val chains    : Set[Chain] = allChains.filter(chain => isWeakSAT(And(chain reentrant chain)))
 
diff --git a/src/main/scala/leon/termination/Processor.scala b/src/main/scala/leon/termination/Processor.scala
index dd826a56a..3b3311e22 100644
--- a/src/main/scala/leon/termination/Processor.scala
+++ b/src/main/scala/leon/termination/Processor.scala
@@ -28,7 +28,7 @@ abstract class Processor(val checker: TerminationChecker) {
   def process(problem: Problem): (TraversableOnce[Result], TraversableOnce[Problem]) = run(problem)
 }
 
-object Solvable {
+class Strengthener(relationComparator: RelationComparator) {
   import scala.collection.mutable.{Set => MutableSet}
 
   private val strengthened : MutableSet[FunDef] = MutableSet()
@@ -67,11 +67,11 @@ object Solvable {
     val sortedCallees : Seq[FunDef] = callees.toSeq.sortWith((fd1, fd2) => solver.checker.program.transitivelyCalls(fd2, fd1))
     for (funDef <- sortedCallees if !strengthened(funDef) && funDef.hasBody && solver.checker.terminates(funDef).isGuaranteed) {
       // test if size is smaller or equal to input
-      val weekConstraintHolds = strengthenPostcondition(funDef, RelationComparator.softDecreasing)
+      val weekConstraintHolds = strengthenPostcondition(funDef, relationComparator.softDecreasing)
 
       if (weekConstraintHolds) {
         // try to improve postcondition with strictly smaller
-        strengthenPostcondition(funDef, RelationComparator.sizeDecreasing)
+        strengthenPostcondition(funDef, relationComparator.sizeDecreasing)
       }
     }
   }
@@ -79,6 +79,9 @@ object Solvable {
 
 trait Solvable { self: Processor =>
 
+  val structuralSize: StructuralSize
+  val strengthener: Strengthener
+
   override def process(problem: Problem): (TraversableOnce[Result], TraversableOnce[Problem]) = {
     self.run(problem)
   }
@@ -86,10 +89,10 @@ trait Solvable { self: Processor =>
   private var solvers: List[SolverFactory[Solver]] = null
   private var lastDefs: Set[FunDef] = Set()
 
-  def strengthenPostconditions(funDefs: Set[FunDef]) = Solvable.strengthenPostconditions(funDefs)(this)
+  def strengthenPostconditions(funDefs: Set[FunDef]) = strengthener.strengthenPostconditions(funDefs)(this)
 
   private def initSolvers {
-    val structDefs = StructuralSize.defs
+    val structDefs = structuralSize.defs
     if (structDefs != lastDefs || solvers == null) {
       val program     : Program         = self.checker.program
       val allDefs     : Seq[Definition] = program.mainObject.defs ++ structDefs
@@ -108,7 +111,7 @@ trait Solvable { self: Processor =>
     // make Leon unroll them forever...)
     val dangerousCallsMap : Map[Expr, Expr] = functionCallsOf(problem).collect({
       // extra definitions (namely size functions) are quaranteed to terminate because structures are non-looping
-      case fi @ FunctionInvocation(fd, args) if !StructuralSize.defs(fd) && !self.checker.terminates(fd).isGuaranteed =>
+      case fi @ FunctionInvocation(fd, args) if !structuralSize.defs(fd) && !self.checker.terminates(fd).isGuaranteed =>
         fi -> FreshIdentifier("noRun", true).setType(fi.getType).toVariable
     }).toMap
 
diff --git a/src/main/scala/leon/termination/RecursionProcessor.scala b/src/main/scala/leon/termination/RecursionProcessor.scala
index 78f5ddbdb..1bbd88fd4 100644
--- a/src/main/scala/leon/termination/RecursionProcessor.scala
+++ b/src/main/scala/leon/termination/RecursionProcessor.scala
@@ -7,12 +7,10 @@ import purescala.Definitions._
 
 import scala.annotation.tailrec
 
-class RecursionProcessor(checker: TerminationChecker) extends Processor(checker) {
+class RecursionProcessor(checker: TerminationChecker, relationBuilder: RelationBuilder) extends Processor(checker) {
 
   val name: String = "Recursion Processor"
 
-  RelationBuilder.init
-
   private def isSubtreeOf(expr: Expr, id: Identifier) : Boolean = {
     @tailrec
     def rec(e: Expr, fst: Boolean): Boolean = e match {
@@ -25,7 +23,7 @@ class RecursionProcessor(checker: TerminationChecker) extends Processor(checker)
 
   def run(problem: Problem) = if (problem.funDefs.size > 1) (Nil, List(problem)) else {
     val funDef = problem.funDefs.head
-    val relations = RelationBuilder.run(funDef)
+    val relations = relationBuilder.run(funDef)
     val (recursive, others) = relations.partition({ case Relation(_, _, FunctionInvocation(fd, _)) => fd == funDef })
 
     if (others.exists({ case Relation(_, _, FunctionInvocation(fd, _)) => !checker.terminates(fd).isGuaranteed })) (Nil, List(problem)) else {
diff --git a/src/main/scala/leon/termination/RelationBuilder.scala b/src/main/scala/leon/termination/RelationBuilder.scala
index 11b411b0a..94e10516c 100644
--- a/src/main/scala/leon/termination/RelationBuilder.scala
+++ b/src/main/scala/leon/termination/RelationBuilder.scala
@@ -7,16 +7,15 @@ import purescala.TreeOps._
 import purescala.Extractors._
 import purescala.Common._
 
+import scala.collection.mutable.{Map => MutableMap}
+
 final case class Relation(funDef: FunDef, path: Seq[Expr], call: FunctionInvocation) {
   override def toString : String = "Relation(" + funDef.id + "," + path + ", " + call.funDef.id + call.args.mkString("(",",",")") + ")"
 }
 
-object RelationBuilder {
-  import scala.collection.mutable.{Map => MutableMap}
+class RelationBuilder {
   private val relationCache : MutableMap[FunDef, Set[Relation]] = MutableMap()
 
-  def init : Unit = relationCache.clear
-
   def run(funDef: FunDef): Set[Relation] = relationCache.get(funDef) match {
     case Some(relations) => relations
     case None => {
diff --git a/src/main/scala/leon/termination/RelationComparator.scala b/src/main/scala/leon/termination/RelationComparator.scala
index 9f3985b2c..213301275 100644
--- a/src/main/scala/leon/termination/RelationComparator.scala
+++ b/src/main/scala/leon/termination/RelationComparator.scala
@@ -7,10 +7,8 @@ import purescala.TypeTrees._
 import purescala.Definitions._
 import purescala.Common._
 
-object RelationComparator {
-  import StructuralSize._
-
-  def init : Unit = StructuralSize.init
+class RelationComparator(structuralSize: StructuralSize) {
+  import structuralSize.size
 
   def sizeDecreasing(e1: Expr, e2: Expr) = GreaterThan(size(e1), size(e2))
 
diff --git a/src/main/scala/leon/termination/RelationProcessor.scala b/src/main/scala/leon/termination/RelationProcessor.scala
index 24527eb92..7cb0b9111 100644
--- a/src/main/scala/leon/termination/RelationProcessor.scala
+++ b/src/main/scala/leon/termination/RelationProcessor.scala
@@ -8,24 +8,25 @@ import leon.purescala.Common._
 import leon.purescala.Extractors._
 import leon.purescala.Definitions._
 
-class RelationProcessor(checker: TerminationChecker) extends Processor(checker) with Solvable {
+class RelationProcessor(checker: TerminationChecker,
+                        relationBuilder: RelationBuilder,
+                        val structuralSize: StructuralSize,
+                        relationComparator: RelationComparator,
+                        val strengthener: Strengthener) extends Processor(checker) with Solvable {
 
   val name: String = "Relation Processor"
 
-  RelationBuilder.init
-  RelationComparator.init
-
   def run(problem: Problem) = {
 
     strengthenPostconditions(problem.funDefs)
 
     val formulas = problem.funDefs.map({ funDef =>
-      funDef -> RelationBuilder.run(funDef).collect({
+      funDef -> relationBuilder.run(funDef).collect({
         case Relation(_, path, FunctionInvocation(fd, args)) if problem.funDefs(fd) =>
           val (e1, e2) = (Tuple(funDef.args.map(_.toVariable)), Tuple(args))
           def constraint(expr: Expr) = Implies(And(path.toSeq), expr)
-          val greaterThan = RelationComparator.sizeDecreasing(e1, e2)
-          val greaterEquals = RelationComparator.softDecreasing(e1, e2)
+          val greaterThan = relationComparator.sizeDecreasing(e1, e2)
+          val greaterEquals = relationComparator.softDecreasing(e1, e2)
           (fd, (constraint(greaterThan), constraint(greaterEquals)))
       })
     })
diff --git a/src/main/scala/leon/termination/StructuralSize.scala b/src/main/scala/leon/termination/StructuralSize.scala
index c85f24085..e10a30c28 100644
--- a/src/main/scala/leon/termination/StructuralSize.scala
+++ b/src/main/scala/leon/termination/StructuralSize.scala
@@ -7,7 +7,7 @@ import purescala.TypeTrees._
 import purescala.Definitions._
 import purescala.Common._
 
-object StructuralSize {
+class StructuralSize {
   import scala.collection.mutable.{Map => MutableMap}
 
   private val sizeFunctionCache : MutableMap[TypeTree, FunDef] = MutableMap()
@@ -62,8 +62,6 @@ object StructuralSize {
   }
 
   def defs : Set[FunDef] = Set(sizeFunctionCache.values.toSeq : _*)
-
-  def init : Unit = sizeFunctionCache.clear
 }
 
 // vim: set ts=4 sw=4 et:
-- 
GitLab