From fa9183567bda1a0b6ab32df50cc36516a8706ddf Mon Sep 17 00:00:00 2001
From: Samuel Gruetter <samuel.gruetter@epfl.ch>
Date: Tue, 2 Jun 2015 20:05:10 +0200
Subject: [PATCH] split the Termination Cake into two parts:

1) The termination checker (of which only 1 instance exists per run)
2) The modules which depend on a size function (of which several
   instances per run might exist, one per size function impl)
---
 .../scala/leon/termination/ChainBuilder.scala |  6 ++---
 .../leon/termination/ChainComparator.scala    |  7 ++---
 .../leon/termination/ChainProcessor.scala     | 15 ++++++-----
 .../ComplexTerminationChecker.scala           | 27 ++++++++++---------
 .../leon/termination/LoopProcessor.scala      |  6 ++---
 .../scala/leon/termination/Processor.scala    |  4 +--
 .../leon/termination/RecursionProcessor.scala |  4 +--
 .../leon/termination/RelationBuilder.scala    | 10 ++++---
 .../leon/termination/RelationProcessor.scala  | 13 ++++-----
 .../leon/termination/SelfCallsProcessor.scala |  2 +-
 .../scala/leon/termination/Strengthener.scala | 16 ++++++-----
 11 files changed, 61 insertions(+), 49 deletions(-)

diff --git a/src/main/scala/leon/termination/ChainBuilder.scala b/src/main/scala/leon/termination/ChainBuilder.scala
index 3249656ef..678796e1b 100644
--- a/src/main/scala/leon/termination/ChainBuilder.scala
+++ b/src/main/scala/leon/termination/ChainBuilder.scala
@@ -113,12 +113,12 @@ final case class Chain(relations: List[Relation]) {
   lazy val inlined: Seq[Expr] = inlining.map(_._2)
 }
 
-trait ChainBuilder extends RelationBuilder { self: TerminationChecker with Strengthener with RelationComparator =>
+trait ChainBuilder extends RelationBuilder { self: Strengthener with RelationComparator =>
 
   protected type ChainSignature = (FunDef, Set[RelationSignature])
 
   protected def funDefChainSignature(funDef: FunDef): ChainSignature = {
-    funDef -> (self.program.callGraph.transitiveCallees(funDef) + funDef).map(funDefRelationSignature)
+    funDef -> (checker.program.callGraph.transitiveCallees(funDef) + funDef).map(funDefRelationSignature)
   }
 
   private val chainCache : MutableMap[FunDef, (Set[FunDef], Set[Chain], ChainSignature)] = MutableMap.empty
@@ -153,7 +153,7 @@ trait ChainBuilder extends RelationBuilder { self: TerminationChecker with Stren
         val Relation(_, _, FunctionInvocation(tfd, _), _) :: _ = chain
         val fd = tfd.fd
 
-        if (!self.program.callGraph.transitivelyCalls(fd, funDef)) {
+        if (!checker.program.callGraph.transitivelyCalls(fd, funDef)) {
           Set.empty[FunDef] -> Set.empty[Chain]
         } else if (fd == funDef) {
           Set.empty[FunDef] -> Set(Chain(chain.reverse))
diff --git a/src/main/scala/leon/termination/ChainComparator.scala b/src/main/scala/leon/termination/ChainComparator.scala
index 4e327ce84..bbee27665 100644
--- a/src/main/scala/leon/termination/ChainComparator.scala
+++ b/src/main/scala/leon/termination/ChainComparator.scala
@@ -10,7 +10,8 @@ import purescala.TypeOps._
 import purescala.Constructors._
 import purescala.Common._
 
-trait ChainComparator { self : StructuralSize with TerminationChecker =>
+trait ChainComparator { self : StructuralSize =>
+  val checker: TerminationChecker
 
   private object ContainerType {
     def unapply(c: ClassType): Option[(CaseClassType, Seq[(Identifier, TypeTree)])] = c match {
@@ -175,8 +176,8 @@ trait ChainComparator { self : StructuralSize with TerminationChecker =>
         case NoEndpoint =>
           endpoint(thenn) min endpoint(elze)
         case ep =>
-          val terminatingThen = functionCallsOf(thenn).forall(fi => self.terminates(fi.tfd.fd).isGuaranteed)
-          val terminatingElze = functionCallsOf(elze).forall(fi => self.terminates(fi.tfd.fd).isGuaranteed)
+          val terminatingThen = functionCallsOf(thenn).forall(fi => checker.terminates(fi.tfd.fd).isGuaranteed)
+          val terminatingElze = functionCallsOf(elze).forall(fi => checker.terminates(fi.tfd.fd).isGuaranteed)
           val thenEndpoint = if (terminatingThen) ep max endpoint(thenn) else endpoint(thenn)
           val elzeEndpoint = if (terminatingElze) ep.inverse max endpoint(elze) else endpoint(elze)
           thenEndpoint max elzeEndpoint
diff --git a/src/main/scala/leon/termination/ChainProcessor.scala b/src/main/scala/leon/termination/ChainProcessor.scala
index 5537d41e2..c5a8ccb69 100644
--- a/src/main/scala/leon/termination/ChainProcessor.scala
+++ b/src/main/scala/leon/termination/ChainProcessor.scala
@@ -8,20 +8,23 @@ import purescala.Common._
 import purescala.Definitions._
 import purescala.Constructors.tupleWrap
 
-class ChainProcessor(val checker: TerminationChecker with ChainBuilder with ChainComparator with Strengthener with StructuralSize) extends Processor with Solvable {
+class ChainProcessor(
+    val checker: TerminationChecker,
+    val modules: ChainBuilder with ChainComparator with Strengthener with StructuralSize
+) extends Processor with Solvable {
 
   val name: String = "Chain Processor"
 
   def run(problem: Problem) = {
     reporter.debug("- Strengthening postconditions")
-    checker.strengthenPostconditions(problem.funSet)(this)
+    modules.strengthenPostconditions(problem.funSet)(this)
 
     reporter.debug("- Strengthening applications")
-    checker.strengthenApplications(problem.funSet)(this)
+    modules.strengthenApplications(problem.funSet)(this)
 
     reporter.debug("- Running ChainBuilder")
     val chainsMap : Map[FunDef, (Set[FunDef], Set[Chain])] = problem.funSet.map { funDef =>
-      funDef -> checker.getChains(funDef)(this)
+      funDef -> modules.getChains(funDef)(this)
     }.toMap
 
     val loopPoints = chainsMap.foldLeft(Set.empty[FunDef]) { case (set, (fd, (fds, chains))) => set ++ fds }
@@ -48,7 +51,7 @@ class ChainProcessor(val checker: TerminationChecker with ChainBuilder with Chai
       reporter.debug("-+> Searching for structural size decrease")
 
       val (se1, se2s, _) = exprs(funDefs.head)
-      val structuralFormulas = checker.structuralDecreasing(se1, se2s)
+      val structuralFormulas = modules.structuralDecreasing(se1, se2s)
       val structuralDecreasing = structuralFormulas.exists(formula => definitiveALL(formula))
 
       reporter.debug("-+> Searching for numerical converging")
@@ -56,7 +59,7 @@ class ChainProcessor(val checker: TerminationChecker with ChainBuilder with Chai
       // worth checking multiple funDefs as the endpoint discovery can be context sensitive
       val numericDecreasing = funDefs.exists { fd =>
         val (ne1, ne2s, fdChains) = exprs(fd)
-        val numericFormulas = checker.numericConverging(ne1, ne2s, fdChains)
+        val numericFormulas = modules.numericConverging(ne1, ne2s, fdChains)
         numericFormulas.exists(formula => definitiveALL(formula))
       }
 
diff --git a/src/main/scala/leon/termination/ComplexTerminationChecker.scala b/src/main/scala/leon/termination/ComplexTerminationChecker.scala
index c889d6165..062f6af12 100644
--- a/src/main/scala/leon/termination/ComplexTerminationChecker.scala
+++ b/src/main/scala/leon/termination/ComplexTerminationChecker.scala
@@ -8,24 +8,27 @@ import purescala.Expressions._
 
 import scala.collection.mutable.{Map => MutableMap}
 
-class ComplexTerminationChecker(context: LeonContext, program: Program)
-       extends ProcessingPipeline(context, program)
-          with StructuralSize
-          with RelationComparator
-          with ChainComparator
-          with Strengthener
-          with RelationBuilder
-          with ChainBuilder {
+class ComplexTerminationChecker(context: LeonContext, program: Program) extends ProcessingPipeline(context, program) {
 
   val name = "Complex Termination Checker"
   val description = "A modular termination checker with a few basic modules™"
+  
+  val modules = new StructuralSize
+               with RelationComparator
+               with ChainComparator
+               with Strengthener
+               with RelationBuilder
+               with ChainBuilder 
+  {
+    val checker = ComplexTerminationChecker.this
+  }
 
   def processors = List(
-    new RecursionProcessor(this),
-    new RelationProcessor(this),
-    new ChainProcessor(this),
+    new RecursionProcessor(this, modules),
+    new RelationProcessor(this, modules),
+    new ChainProcessor(this, modules),
     new SelfCallsProcessor(this),
-    new LoopProcessor(this)
+    new LoopProcessor(this, modules)
   )
 
 }
diff --git a/src/main/scala/leon/termination/LoopProcessor.scala b/src/main/scala/leon/termination/LoopProcessor.scala
index 9f419e8ab..930771443 100644
--- a/src/main/scala/leon/termination/LoopProcessor.scala
+++ b/src/main/scala/leon/termination/LoopProcessor.scala
@@ -10,16 +10,16 @@ import purescala.Constructors._
 
 import scala.collection.mutable.{Map => MutableMap}
 
-class LoopProcessor(val checker: TerminationChecker with ChainBuilder with Strengthener with StructuralSize, k: Int = 10) extends Processor with Solvable {
+class LoopProcessor(val checker: TerminationChecker, val modules: ChainBuilder with Strengthener with StructuralSize, k: Int = 10) extends Processor with Solvable {
 
   val name: String = "Loop Processor"
 
   def run(problem: Problem) = {
     reporter.debug("- Strengthening applications")
-    checker.strengthenApplications(problem.funSet)(this)
+    modules.strengthenApplications(problem.funSet)(this)
 
     reporter.debug("- Running ChainBuilder")
-    val chains : Set[Chain] = problem.funSet.flatMap(fd => checker.getChains(fd)(this)._2)
+    val chains : Set[Chain] = problem.funSet.flatMap(fd => modules.getChains(fd)(this)._2)
 
     reporter.debug("- Searching for loops")
     val nonTerminating: MutableMap[FunDef, Result] = MutableMap.empty
diff --git a/src/main/scala/leon/termination/Processor.scala b/src/main/scala/leon/termination/Processor.scala
index 34ef279ff..db91c8602 100644
--- a/src/main/scala/leon/termination/Processor.scala
+++ b/src/main/scala/leon/termination/Processor.scala
@@ -26,12 +26,12 @@ trait Processor {
 
 trait Solvable extends Processor {
 
-  val checker : TerminationChecker with Strengthener with StructuralSize
+  val modules: Strengthener with StructuralSize
 
   private val solver: SolverFactory[Solver] = {
     val program     : Program     = checker.program
     val context     : LeonContext = checker.context
-    val sizeModule  : ModuleDef   = ModuleDef(FreshIdentifier("$size"), checker.defs.toSeq, false)
+    val sizeModule  : ModuleDef   = ModuleDef(FreshIdentifier("$size"), modules.defs.toSeq, false)
     val sizeUnit    : UnitDef     = UnitDef(FreshIdentifier("$size"),Seq(sizeModule)) 
     val newProgram  : Program     = program.copy( units = sizeUnit :: program.units)
 
diff --git a/src/main/scala/leon/termination/RecursionProcessor.scala b/src/main/scala/leon/termination/RecursionProcessor.scala
index d0b29d62f..2a93aa853 100644
--- a/src/main/scala/leon/termination/RecursionProcessor.scala
+++ b/src/main/scala/leon/termination/RecursionProcessor.scala
@@ -8,7 +8,7 @@ import purescala.Common._
 
 import scala.annotation.tailrec
 
-class RecursionProcessor(val checker: TerminationChecker with RelationBuilder) extends Processor {
+class RecursionProcessor(val checker: TerminationChecker, val rb: RelationBuilder) extends Processor {
 
   val name: String = "Recursion Processor"
 
@@ -24,7 +24,7 @@ class RecursionProcessor(val checker: TerminationChecker with RelationBuilder) e
 
   def run(problem: Problem) = if (problem.funDefs.size > 1) None else {
     val funDef = problem.funDefs.head
-    val relations = checker.getRelations(funDef)
+    val relations = rb.getRelations(funDef)
     val (recursive, others) = relations.partition({ case Relation(_, _, FunctionInvocation(tfd, _), _) => tfd.fd == funDef })
 
     if (others.exists({ case Relation(_, _, FunctionInvocation(tfd, _), _) => !checker.terminates(tfd.fd).isGuaranteed })) {
diff --git a/src/main/scala/leon/termination/RelationBuilder.scala b/src/main/scala/leon/termination/RelationBuilder.scala
index 30c5c01f6..d2edc6228 100644
--- a/src/main/scala/leon/termination/RelationBuilder.scala
+++ b/src/main/scala/leon/termination/RelationBuilder.scala
@@ -13,13 +13,15 @@ final case class Relation(funDef: FunDef, path: Seq[Expr], call: FunctionInvocat
   override def toString : String = "Relation(" + funDef.id + "," + path + ", " + call.tfd.id + call.args.mkString("(",",",")") + "," + inLambda + ")"
 }
 
-trait RelationBuilder { self: TerminationChecker with Strengthener =>
+trait RelationBuilder { self: Strengthener =>
+
+  val checker: TerminationChecker
 
   protected type RelationSignature = (FunDef, Option[Expr], Option[Expr], Option[Expr], Boolean, Set[(FunDef, Boolean)])
 
   protected def funDefRelationSignature(fd: FunDef): RelationSignature = {
-    val strengthenedCallees = self.program.callGraph.callees(fd).map(fd => fd -> strengthened(fd))
-    (fd, fd.precondition, fd.body, fd.postcondition, self.terminates(fd).isGuaranteed, strengthenedCallees)
+    val strengthenedCallees = checker.program.callGraph.callees(fd).map(fd => fd -> strengthened(fd))
+    (fd, fd.precondition, fd.body, fd.postcondition, checker.terminates(fd).isGuaranteed, strengthenedCallees)
   }
 
   private val relationCache : MutableMap[FunDef, (Set[Relation], RelationSignature)] = MutableMap.empty
@@ -42,7 +44,7 @@ trait RelationBuilder { self: TerminationChecker with Strengthener =>
         }
 
         def collect(e: Expr, path: Seq[Expr]): Option[Relation] = e match {
-          case fi @ FunctionInvocation(f, args) if self.functions(f.fd) =>
+          case fi @ FunctionInvocation(f, args) if checker.functions(f.fd) =>
             val flatPath = path flatMap {
               case And(es) => es
               case expr => Seq(expr)
diff --git a/src/main/scala/leon/termination/RelationProcessor.scala b/src/main/scala/leon/termination/RelationProcessor.scala
index 91fdce46e..0fd6dd380 100644
--- a/src/main/scala/leon/termination/RelationProcessor.scala
+++ b/src/main/scala/leon/termination/RelationProcessor.scala
@@ -12,25 +12,26 @@ import leon.purescala.Constructors._
 import leon.purescala.Definitions._
 
 class RelationProcessor(
-    val checker: TerminationChecker with RelationBuilder with RelationComparator with Strengthener with StructuralSize
+    val checker: TerminationChecker,
+    val modules: RelationBuilder with RelationComparator with Strengthener with StructuralSize
   ) extends Processor with Solvable {
 
   val name: String = "Relation Processor"
 
   def run(problem: Problem) = {
     reporter.debug("- Strengthening postconditions")
-    checker.strengthenPostconditions(problem.funSet)(this)
+    modules.strengthenPostconditions(problem.funSet)(this)
 
     reporter.debug("- Strengthening applications")
-    checker.strengthenApplications(problem.funSet)(this)
+    modules.strengthenApplications(problem.funSet)(this)
 
     val formulas = problem.funDefs.map({ funDef =>
-      funDef -> checker.getRelations(funDef).collect({
+      funDef -> modules.getRelations(funDef).collect({
         case Relation(_, path, FunctionInvocation(tfd, args), _) if problem.funSet(tfd.fd) =>
           val args0 = funDef.params.map(_.toVariable)
           def constraint(expr: Expr) = implies(andJoin(path.toSeq), expr)
-          val greaterThan = checker.sizeDecreasing(args0, args)
-          val greaterEquals = checker.softDecreasing(args0, args)
+          val greaterThan = modules.sizeDecreasing(args0, args)
+          val greaterEquals = modules.softDecreasing(args0, args)
           (tfd.fd, (constraint(greaterThan), constraint(greaterEquals)))
       })
     })
diff --git a/src/main/scala/leon/termination/SelfCallsProcessor.scala b/src/main/scala/leon/termination/SelfCallsProcessor.scala
index 67524e106..a18d292cd 100644
--- a/src/main/scala/leon/termination/SelfCallsProcessor.scala
+++ b/src/main/scala/leon/termination/SelfCallsProcessor.scala
@@ -6,7 +6,7 @@ import purescala.Common._
 import purescala.Expressions._
 import purescala.Constructors._
 
-class SelfCallsProcessor(val checker: TerminationChecker with ChainBuilder with Strengthener with StructuralSize) extends Processor with Solvable {
+class SelfCallsProcessor(val checker: TerminationChecker) extends Processor {
 
   val name: String = "Self Calls Processor"
 
diff --git a/src/main/scala/leon/termination/Strengthener.scala b/src/main/scala/leon/termination/Strengthener.scala
index 31ed7f002..ed77e8bef 100644
--- a/src/main/scala/leon/termination/Strengthener.scala
+++ b/src/main/scala/leon/termination/Strengthener.scala
@@ -12,16 +12,18 @@ import purescala.Constructors._
 
 import scala.collection.mutable.{Set => MutableSet, Map => MutableMap}
 
-trait Strengthener { self : TerminationChecker with RelationComparator with RelationBuilder =>
+trait Strengthener { self : RelationComparator =>
+
+  val checker: TerminationChecker
 
   private val strengthenedPost : MutableSet[FunDef] = MutableSet.empty
 
   def strengthenPostconditions(funDefs: Set[FunDef])(implicit solver: Processor with Solvable) {
     // Strengthen postconditions on all accessible functions by adding size constraints
-    val callees : Set[FunDef] = funDefs.map(fd => self.program.callGraph.transitiveCallees(fd)).flatten
-    val sortedCallees : Seq[FunDef] = callees.toSeq.sortWith((fd1, fd2) => self.program.callGraph.transitivelyCalls(fd2, fd1))
+    val callees : Set[FunDef] = funDefs.map(fd => checker.program.callGraph.transitiveCallees(fd)).flatten
+    val sortedCallees : Seq[FunDef] = callees.toSeq.sortWith((fd1, fd2) => checker.program.callGraph.transitivelyCalls(fd2, fd1))
 
-    for (funDef <- sortedCallees if !strengthenedPost(funDef) && funDef.hasBody && self.terminates(funDef).isGuaranteed) {
+    for (funDef <- sortedCallees if !strengthenedPost(funDef) && funDef.hasBody && checker.terminates(funDef).isGuaranteed) {
       def strengthen(cmp: (Seq[Expr], Seq[Expr]) => Expr): Boolean = {
         val old = funDef.postcondition
         val postcondition = {
@@ -79,10 +81,10 @@ trait Strengthener { self : TerminationChecker with RelationComparator with Rela
   }
 
   def strengthenApplications(funDefs: Set[FunDef])(implicit solver: Processor with Solvable) {
-    val transitiveFunDefs = funDefs ++ funDefs.flatMap(fd => self.program.callGraph.transitiveCallees(fd))
-    val sortedFunDefs = transitiveFunDefs.toSeq.sortWith((fd1, fd2) => self.program.callGraph.transitivelyCalls(fd2, fd1))
+    val transitiveFunDefs = funDefs ++ funDefs.flatMap(fd => checker.program.callGraph.transitiveCallees(fd))
+    val sortedFunDefs = transitiveFunDefs.toSeq.sortWith((fd1, fd2) => checker.program.callGraph.transitivelyCalls(fd2, fd1))
 
-    for (funDef <- sortedFunDefs if !strengthenedApp(funDef) && funDef.hasBody && self.terminates(funDef).isGuaranteed) {
+    for (funDef <- sortedFunDefs if !strengthenedApp(funDef) && funDef.hasBody && checker.terminates(funDef).isGuaranteed) {
 
       val appCollector = new CollectorWithPaths[(Identifier,Expr,Seq[Expr])] {
         def collect(e: Expr, path: Seq[Expr]): Option[(Identifier, Expr, Seq[Expr])] = e match {
-- 
GitLab