From e37d54cef8f9005b64297f5f10904b6dc493046a Mon Sep 17 00:00:00 2001
From: Nicolas Voirol <voirol.nicolas@gmail.com>
Date: Sat, 15 Jun 2013 00:58:10 +0200
Subject: [PATCH] Fixed small but in ChainProcessor

---
 .../scala/leon/purescala/PrettyPrinter.scala  |  2 +-
 src/main/scala/leon/purescala/Trees.scala     |  5 +-
 src/main/scala/leon/purescala/TypeTrees.scala |  4 +-
 .../scala/leon/termination/ChainBuilder.scala | 17 ++---
 .../leon/termination/ChainComparator.scala    |  2 +
 .../leon/termination/ChainProcessor.scala     | 69 ++++++++++++++-----
 .../leon/termination/LoopProcessor.scala      | 44 +++++-------
 .../scala/leon/termination/Processor.scala    | 65 +++++++++--------
 .../leon/termination/RecursionProcessor.scala | 25 +++----
 .../leon/termination/RelationBuilder.scala    |  6 +-
 .../leon/termination/RelationComparator.scala |  2 +
 .../leon/termination/RelationProcessor.scala  | 22 ++----
 .../leon/termination/StructuralSize.scala     | 36 ++++++----
 13 files changed, 170 insertions(+), 129 deletions(-)

diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala
index a5c9d9a89..f8b91d2b9 100644
--- a/src/main/scala/leon/purescala/PrettyPrinter.scala
+++ b/src/main/scala/leon/purescala/PrettyPrinter.scala
@@ -285,7 +285,7 @@ class PrettyPrinter(sb: StringBuffer = new StringBuffer) {
 
     case (expr: PrettyPrintable) => expr.printWith(lvl, this)
 
-    case _ => sb.append("Expr?")
+    case _ => sb.append("Expr? (" + tree.getClass + ")")
   }
 
   // TYPE TREES
diff --git a/src/main/scala/leon/purescala/Trees.scala b/src/main/scala/leon/purescala/Trees.scala
index a6be5fbbd..bbcfa9bb0 100644
--- a/src/main/scala/leon/purescala/Trees.scala
+++ b/src/main/scala/leon/purescala/Trees.scala
@@ -63,7 +63,10 @@ object Trees {
   case class FunctionInvocation(funDef: FunDef, args: Seq[Expr]) extends Expr with FixedType with ScalacPositional {
     val fixedType = funDef.returnType
 
-    // funDef.args.zip(args).foreach { case (a, c) => typeCheck(c, a.tpe) }
+    funDef.args.zip(args).foreach {
+      case (a, ResultVariable()) => true // assume its correct... don't know how to get context to really check
+      case (a, c) => typeCheck(c, a.tpe)
+    }
   }
   case class IfExpr(cond: Expr, thenn: Expr, elze: Expr) extends Expr with FixedType {
     val fixedType = leastUpperBound(thenn.getType, elze.getType).getOrElse(AnyType)
diff --git a/src/main/scala/leon/purescala/TypeTrees.scala b/src/main/scala/leon/purescala/TypeTrees.scala
index 74b35aeb6..ab7c03cae 100644
--- a/src/main/scala/leon/purescala/TypeTrees.scala
+++ b/src/main/scala/leon/purescala/TypeTrees.scala
@@ -101,7 +101,9 @@ object TypeTrees {
         Some(classDefToClassType(found.get))
       }
     }
-
+    case (TupleType(args1), TupleType(args2)) =>
+      val args = (args1 zip args2).map(p => leastUpperBound(p._1, p._2))
+      if (args.forall(_.isDefined)) Some(TupleType(args.map(_.get))) else None
     case (o1, o2) if (o1 == o2) => Some(o1)
     case (o1,BottomType) => Some(o1)
     case (BottomType,o2) => Some(o2)
diff --git a/src/main/scala/leon/termination/ChainBuilder.scala b/src/main/scala/leon/termination/ChainBuilder.scala
index abd158530..065c54db8 100644
--- a/src/main/scala/leon/termination/ChainBuilder.scala
+++ b/src/main/scala/leon/termination/ChainBuilder.scala
@@ -10,9 +10,9 @@ final case class Chain(chain: List[Relation]) {
   def funDef  : FunDef                    = chain.head.funDef
   def funDefs : Set[FunDef]               = chain.map(_.funDef) toSet
 
-  def loop(initialSubst: Map[Identifier, Expr] = Map(), finalSubst: Map[Identifier, Expr] = Map()) : Seq[Expr] = {
-    assert(initialSubst.nonEmpty || finalSubst.nonEmpty)
+  lazy val size: Int = chain.size
 
+  def loop(initialSubst: Map[Identifier, Expr] = Map(), finalSubst: Map[Identifier, Expr] = Map()) : Seq[Expr] = {
     def rec(relations: List[Relation], subst: Map[Identifier, Expr]): Seq[Expr] = relations match {
       case Relation(_, path, FunctionInvocation(fd, args)) :: Nil =>
         assert(fd == funDef)
@@ -49,16 +49,6 @@ final case class Chain(chain: List[Relation]) {
     firstLoop ++ secondLoop
   }
 
-  def times(k: Int, initialSubst: Map[Identifier, Expr] = Map(), finalSubst: Map[Identifier, Expr] = Map()) : Seq[Expr] = {
-    def rec(bindingSubst: Map[Identifier, Expr], count: Int) : Seq[Expr] = if (count == k) loop(initialSubst = bindingSubst, finalSubst = finalSubst) else {
-      val nextSubst : Map[Identifier, Expr] = funDef.args.map(arg => arg.id -> arg.id.freshen.toVariable).toMap
-      val currentLoop = loop(initialSubst = bindingSubst, finalSubst = nextSubst)
-      val rest = rec(nextSubst, count + 1)
-      currentLoop ++ rest
-    }
-    rec(initialSubst, 1)
-  }
-
   def inlined: TraversableOnce[Expr] = {
     def rec(list: List[Relation], mapping: Map[Identifier, Expr]): List[Expr] = list match {
       case Relation(_, _, FunctionInvocation(fd, args)) :: xs =>
@@ -80,6 +70,9 @@ object ChainBuilder {
   import scala.collection.mutable.{Map => MutableMap}
 
   private val chainCache : MutableMap[FunDef, Set[Chain]] = MutableMap()
+
+  def init : Unit = chainCache.clear
+
   def run(funDef: FunDef): Set[Chain] = chainCache.get(funDef) match {
     case Some(chains) => chains
     case None => {
diff --git a/src/main/scala/leon/termination/ChainComparator.scala b/src/main/scala/leon/termination/ChainComparator.scala
index 901d72895..9a6352a2a 100644
--- a/src/main/scala/leon/termination/ChainComparator.scala
+++ b/src/main/scala/leon/termination/ChainComparator.scala
@@ -10,6 +10,8 @@ import purescala.Common._
 object ChainComparator {
   import StructuralSize._
 
+  def init : Unit = StructuralSize.init
+
   def sizeDecreasing(e1: TypedExpr, e2s: Seq[(Seq[Expr], Expr)]) = _sizeDecreasing(e1, e2s map {
     case (path, e2) => (path, exprToTypedExpr(e2))
   })
diff --git a/src/main/scala/leon/termination/ChainProcessor.scala b/src/main/scala/leon/termination/ChainProcessor.scala
index b40ccf166..e50ab2e2a 100644
--- a/src/main/scala/leon/termination/ChainProcessor.scala
+++ b/src/main/scala/leon/termination/ChainProcessor.scala
@@ -12,25 +12,35 @@ class ChainProcessor(checker: TerminationChecker) extends Processor(checker) wit
 
   val name: String = "Chain Processor"
 
-  def run(problem: Problem) = {
-    val allChainMap : Map[FunDef, Set[Chain]] = problem.funDefs.map(funDef => funDef -> ChainBuilder.run(funDef)).toMap
-    val allChains   : Set[Chain]              = allChainMap.values.flatten.toSet
+  ChainBuilder.init
+  ChainComparator.init
 
-    // We check that loops can reenter themselves after a run. If not, then this is not a chain (since it will
-    // enter another chain and their conjunction is contained elsewhere in the chains set)
-    // Note: We are checking reentrance SAT, not looking for a counter example so we negate the formula!
-    val validChains : Set[Chain]              = allChains.filter(chain => !solve(Not(And(chain reentrant chain))).isValid)
-    val chainMap    : Map[FunDef, Set[Chain]] = allChainMap.mapValues(chains => chains intersect validChains)
+  def run(problem: Problem): (TraversableOnce[Result], TraversableOnce[Problem]) = {
+    reporter.info("- Running ChainProcessor")
+    val allChainMap       : Map[FunDef, Set[Chain]] = problem.funDefs.map(funDef => funDef -> ChainBuilder.run(funDef)).toMap
+    reporter.info("- Computing all possible Chains")
+    val possibleChainMap  : Map[FunDef, Set[Chain]] = allChainMap.mapValues(chains => chains.filter(chain => isSAT(And(chain.loop()))))
+    reporter.info("- Collecting re-entrant Chains")
+    val reentrantChainMap : Map[FunDef, Set[Chain]] = possibleChainMap.mapValues(chains => chains.filter(chain => isSAT(And(chain reentrant chain))))
 
     // We build a cross-chain map that determines which chains can reenter into another one after a loop.
     // Note: We are also checking reentrance SAT here, so again, we negate the formula!
-    val crossChains : Map[Chain, Set[Chain]] = chainMap.map({ case (funDef, chains) =>
-      chains.map(chain => chain -> (chains - chain).filter(other => !solve(Not(And(chain reentrant other))).isValid))
+    reporter.info("- Computing cross-chain map")
+    val crossChains       : Map[Chain, Set[Chain]]  = possibleChainMap.map({ case (funDef, chains) =>
+      val reentrant = reentrantChainMap(funDef)
+      val reentrantPairs = reentrant.map(chain => chain -> Set(chain))
+      val crosswise = (chains -- reentrant).map(chain => chain -> {
+        reentrant.filter(other => isSAT(And(chain reentrant other)))
+      })
+      reentrantPairs ++ crosswise
     }).flatten.toMap
 
+    val validChainMap     : Map[FunDef, Set[Chain]] = possibleChainMap.map({ case (funDef, chains) => funDef -> chains.filter(crossChains(_).nonEmpty) })
+
     // We use the cross-chains to build chain clusters. For each cluster, we must prove that the SAME argument
     // decreases in each of the chains in the cluster!
-    val clusters : Map[FunDef, Set[Set[Chain]]] = {
+    reporter.info("- Building initial cluster estimation by fix-point iteration")
+    val generalClusters : Map[FunDef, Set[Set[Chain]]] = {
       def cluster(set: Set[Chain]): Set[Chain] = {
         set ++ set.map(crossChains(_)).flatten
       }
@@ -51,9 +61,19 @@ class ChainProcessor(checker: TerminationChecker) extends Processor(checker) wit
         filterClusters(allClusters.toList.sortBy(- _.size)).toSet
       }
 
-      chainMap.map({ case (funDef, chains) => funDef -> build(chains) })
+      validChainMap.map({ case (funDef, chains) => funDef -> build(chains) })
     }
 
+    reporter.info("- Trimming down to final clusters")
+    val clusters : Map[FunDef, Set[Set[Chain]]] = generalClusters.map({ case (funDef, clusters) =>
+      funDef -> clusters.map(cluster => cluster.toSeq.sortBy(_.size).foldLeft(Set[Chain]())({ case (acc, chain) =>
+        val chainElements : Set[Relation] = chain.chain.toSet
+        val seenElements  : Set[Relation] = acc.map(_.chain).flatten.toSet
+        if (chainElements -- seenElements nonEmpty) acc + chain else acc
+      })).filter(_.nonEmpty)
+    })
+
+    reporter.info("- Strengthening postconditions")
     strengthenPostconditions(problem.funDefs)
 
     def buildLoops(fd: FunDef, cluster: Set[Chain]): (Expr, Seq[(Seq[Expr], Expr)]) = {
@@ -71,23 +91,36 @@ class ChainProcessor(checker: TerminationChecker) extends Processor(checker) wit
     type ClusterMap = Map[FunDef, Set[Set[Chain]]]
     type FormulaGenerator = (FunDef, Set[Chain]) => Expr
 
-    def clear(clusters: ClusterMap, gen: FormulaGenerator): ClusterMap = clusters.map({ case (fd, clusters) =>
-      val remaining = clusters.filter(cluster => !solve(gen(fd, cluster)).isValid)
-      fd -> remaining
-    })
+    def clear(clusters: ClusterMap, gen: FormulaGenerator): ClusterMap = {
+      val formulas = clusters.map({ case (fd, clusters) =>
+        (fd, clusters.map(cluster => cluster -> gen(fd, cluster)))
+      })
 
+      initSolvers // add structural size functions to solver
+      formulas.map({ case (fd, clustersWithFormulas) =>
+        fd -> clustersWithFormulas.filter({ case (cluster, formula) => !isAlwaysSAT(formula) }).map(_._1)
+      })
+    }
+
+    reporter.info("- Searching for structural size decrease")
     val sizeCleared : ClusterMap = clear(clusters, (fd, cluster) => {
       val (e1, e2s) = buildLoops(fd, cluster)
       ChainComparator.sizeDecreasing(e1, e2s)
     })
 
+    reporter.info("- Searching for numeric convergence")
     val numericCleared : ClusterMap = clear(sizeCleared, (fd, cluster) => {
       val (e1, e2s) = buildLoops(fd, cluster)
       ChainComparator.numericConverging(e1, e2s, cluster, checker)
     })
 
     val (okPairs, nokPairs) = numericCleared.partition(_._2.isEmpty)
-    val newProblems = if (nokPairs nonEmpty) List(Problem(nokPairs.map(_._1).toSet)) else Nil
-    (okPairs.map(p => Cleared(p._1)), newProblems)
+    val nok = nokPairs.map(_._1).toSet
+    val (ok, transitiveNok) = okPairs.map(_._1).partition({ fd =>
+      checker.program.transitiveCallees(fd) intersect nok isEmpty
+    })
+    val allNok = nok ++ transitiveNok
+    val newProblems = if (allNok nonEmpty) List(Problem(allNok)) else Nil
+    (ok.map(Cleared(_)), newProblems)
   }
 }
diff --git a/src/main/scala/leon/termination/LoopProcessor.scala b/src/main/scala/leon/termination/LoopProcessor.scala
index f0791056f..5f0c6db3c 100644
--- a/src/main/scala/leon/termination/LoopProcessor.scala
+++ b/src/main/scala/leon/termination/LoopProcessor.scala
@@ -9,37 +9,29 @@ class LoopProcessor(checker: TerminationChecker, k: Int = 10) extends Processor(
 
   val name: String = "Loop Processor"
 
+  ChainBuilder.init
+
   def run(problem: Problem) = {
     val allChains : Set[Chain] = problem.funDefs.map(fd => ChainBuilder.run(fd)).flatten
     // Get reentrant loops (see ChainProcessor for more details)
-    val chains    : Set[Chain] = allChains.filter(chain => !solve(Not(And(chain reentrant chain))).isValid)
-
-    def findLoops(chains: Set[Chain]) = {
-      def rec(chains: Set[Chain], count: Int): Map[FunDef, Seq[Expr]] = if (count == k) Map() else {
-        val nonTerminating = chains.flatMap({ chain =>
-          val freshArgs : Seq[Expr] = chain.funDef.args.map(arg => arg.id.freshen.toVariable)
-          val finalBindings = (chain.funDef.args.map(_.id) zip freshArgs).toMap
-          val path = chain.times(count, finalSubst = finalBindings)
-          val formula = And(path :+ Equals(Tuple(chain.funDef.args.map(_.toVariable)), Tuple(freshArgs)))
-
-          val solvable = functionCallsOf(formula).forall({
-            case FunctionInvocation(fd, args) => checker.terminates(fd).isGuaranteed
-          })
-
-          if (!solvable) None else solve(Not(formula)) match {
-            case Solution(false, model) => Some(chain.funDef, chain.funDef.args.map(arg => model(arg.id)))
-            case _ => None
-          }
-        }).toMap
-
-        val remainingChains = chains.filter(chain => nonTerminating.contains(chain.funDef))
-        nonTerminating ++ rec(remainingChains, count + 1)
-      }
+    val chains    : Set[Chain] = allChains.filter(chain => isSAT(And(chain reentrant chain)))
+
+    val nonTerminating = chains.flatMap({ chain =>
+      val freshArgs : Seq[Expr] = chain.funDef.args.map(arg => arg.id.freshen.toVariable)
+      val finalBindings = (chain.funDef.args.map(_.id) zip freshArgs).toMap
+      val path = chain.loop(finalSubst = finalBindings)
+      val formula = And(path :+ Equals(Tuple(chain.funDef.args.map(_.toVariable)), Tuple(freshArgs)))
 
-      rec(chains, 1)
-    }
+      val solvable = functionCallsOf(formula).forall({
+        case FunctionInvocation(fd, args) => checker.terminates(fd).isGuaranteed
+      })
+
+      if (!solvable) None else getModel(formula) match {
+        case Some(map) => Some(chain.funDef -> chain.funDef.args.map(arg => map(arg.id)))
+        case _ => None
+      }
+    }).toMap
 
-    val nonTerminating = findLoops(chains)
     val results = nonTerminating.map({ case (funDef, args) => Broken(funDef, args) })
     val remaining = problem.funDefs -- nonTerminating.keys
     val newProblems = if (remaining.nonEmpty) List(Problem(remaining)) else Nil
diff --git a/src/main/scala/leon/termination/Processor.scala b/src/main/scala/leon/termination/Processor.scala
index 760a44e37..95ec075f5 100644
--- a/src/main/scala/leon/termination/Processor.scala
+++ b/src/main/scala/leon/termination/Processor.scala
@@ -21,20 +21,9 @@ abstract class Processor(val checker: TerminationChecker) {
 
   val name: String
 
-  def run(problem: Problem): (TraversableOnce[Result], TraversableOnce[Problem])
-}
-
-class Solution(solution: Option[Boolean], val model: Map[Identifier, Expr]) {
-  lazy val isValid : Boolean = solution getOrElse false
-}
+  val reporter = checker.context.reporter
 
-object NoSolution extends Solution(None, Map())
-
-object Solution {
-  def unapply(s: Solution): Option[(Boolean, Map[Identifier, Expr])] = {
-    if (s == NoSolution) None
-    else Some(s.isValid, s.model)
-  }
+  def run(problem: Problem): (TraversableOnce[Result], TraversableOnce[Problem])
 }
 
 object Solvable {
@@ -57,7 +46,7 @@ object Solvable {
     val resFresh = FreshIdentifier("result", true).setType(body.getType)
     val formula = Implies(prec, Let(resFresh, body, replace(Map(ResultVariable() -> Variable(resFresh)), post)))
 
-    if (!solver.solve(formula).isValid) {
+    if (!solver.isAlwaysSAT(formula)) {
       funDef.postcondition = postcondition
       strengthened.add(funDef)
       false
@@ -85,17 +74,25 @@ object Solvable {
 
 trait Solvable { self: Processor =>
 
+  private var solvers: List[Solver] = null
+
   def strengthenPostconditions(funDefs: Set[FunDef]) = Solvable.strengthenPostconditions(funDefs)(this)
 
-  def solve(problem: Expr): Solution = {
+  def initSolvers {
     val program     : Program         = self.checker.program
     val allDefs     : Seq[Definition] = program.mainObject.defs ++ StructuralSize.defs
     val newProgram  : Program         = program.copy(mainObject = program.mainObject.copy(defs = allDefs))
+    val context     : LeonContext     = self.checker.context.copy(reporter = new QuietReporter())
 
-    val solvers0 = new TrivialSolver(self.checker.context) :: new FairZ3Solver(self.checker.context) :: Nil
-    val solvers = solvers0.map(new TimeoutSolver(_, 500))
+    val solvers0 = new TrivialSolver(context) :: new FairZ3Solver(context) :: Nil
+    solvers = solvers0.map(new TimeoutSolver(_, 500))
     solvers.foreach(_.setProgram(newProgram))
+  }
+
+  type Solution = (Option[Boolean], Map[Identifier, Expr])
 
+  private def solve(problem: Expr): Solution = {
+    if (solvers == null) initSolvers
     // drop functions from constraints that might not terminate (and may therefore
     // make Leon unroll them forever...)
     val dangerousCallsMap : Map[Expr, Expr] = functionCallsOf(problem).collect({
@@ -115,16 +112,29 @@ trait Solvable { self: Processor =>
           superseeded = superseeded ++ Set(se.superseeds: _*)
 
           se.init()
-          val (satResult, model) = se.solveSAT(Not(expr))
-          val solverResult = satResult.map(!_)
+          val (satResult, model) = se.solveSAT(expr)
 
-          if (!solverResult.isDefined) None
-          else Some(new Solution(solverResult, model))
+          if (!satResult.isDefined) None
+          else Some(satResult, model)
         }
       }
     }
 
-    solvers.collectFirst({ case Solved(result) => result }) getOrElse NoSolution
+    solvers.collectFirst({ case Solved(s, model) => (s, model) }) getOrElse (None, Map())
+  }
+
+  def isSAT(problem: Expr): Boolean = {
+    solve(problem)._1 getOrElse false
+  }
+
+  def isAlwaysSAT(problem: Expr): Boolean = {
+    solve(Not(problem))._1.map(!_) getOrElse false
+  }
+
+  def getModel(problem: Expr): Option[Map[Identifier, Expr]] = {
+    val solution = solve(problem)
+    if (solution._1 getOrElse false) Some(solution._2)
+    else None
   }
 }
 
@@ -157,7 +167,7 @@ class ProcessingPipeline(program: Program, context: LeonContext, _processors: Pr
 
   private def printResult(results: List[Result]) {
     val sb = new StringBuilder()
-    sb.append("- Queue.head Processing Result:\n")
+    sb.append("- Processing Result:\n")
     for(result <- results) result match {
       case Cleared(fd) => sb.append("    %-10s %s\n".format(fd.id, "Cleared"))
       case Broken(fd, args) => sb.append("    %-10s %s\n".format(fd.id, "Broken for arguments: " + args.mkString("(", ",", ")")))
@@ -173,10 +183,11 @@ class ProcessingPipeline(program: Program, context: LeonContext, _processors: Pr
   }
 
   def run : Iterator[(String, List[Result])] = new Iterator[(String, List[Result])] {
-    // basic sanity check, funDefs can't call themselves in precondition!
-    assert(initialProblem.funDefs.forall(fd => !fd.precondition.map({ precondition =>
-      functionCallsOf(precondition).map(fi => program.transitiveCallees(fi.funDef)).flatten
-    }).flatten.toSet(fd)))
+    // basic sanity check, funDefs shouldn't call themselves in precondition!
+    // XXX: it seems like some do...
+    // assert(initialProblem.funDefs.forall(fd => !fd.precondition.map({ precondition =>
+    //   functionCallsOf(precondition).map(fi => program.transitiveCallees(fi.funDef)).flatten
+    // }).flatten.toSet(fd)))
 
     def hasNext : Boolean      = problems.nonEmpty
     def next    : (String, List[Result]) = {
diff --git a/src/main/scala/leon/termination/RecursionProcessor.scala b/src/main/scala/leon/termination/RecursionProcessor.scala
index 569e660c7..78f5ddbdb 100644
--- a/src/main/scala/leon/termination/RecursionProcessor.scala
+++ b/src/main/scala/leon/termination/RecursionProcessor.scala
@@ -11,6 +11,8 @@ class RecursionProcessor(checker: TerminationChecker) extends Processor(checker)
 
   val name: String = "Recursion Processor"
 
+  RelationBuilder.init
+
   private def isSubtreeOf(expr: Expr, id: Identifier) : Boolean = {
     @tailrec
     def rec(e: Expr, fst: Boolean): Boolean = e match {
@@ -23,20 +25,19 @@ class RecursionProcessor(checker: TerminationChecker) extends Processor(checker)
 
   def run(problem: Problem) = if (problem.funDefs.size > 1) (Nil, List(problem)) else {
     val funDef = problem.funDefs.head
-
-    val selfRecursiveRelations = RelationBuilder.run(funDef).filter({
-      case Relation(_, _, FunctionInvocation(fd, _)) =>
-        fd == funDef || checker.terminates(fd).isGuaranteed
-    })
-
-    val decreases = funDef.args.zipWithIndex.exists({ case (arg, index) =>
-      selfRecursiveRelations.forall({ case Relation(_, _, FunctionInvocation(_, args)) =>
-        isSubtreeOf(args(index), arg.id)
+    val relations = RelationBuilder.run(funDef)
+    val (recursive, others) = relations.partition({ case Relation(_, _, FunctionInvocation(fd, _)) => fd == funDef })
+
+    if (others.exists({ case Relation(_, _, FunctionInvocation(fd, _)) => !checker.terminates(fd).isGuaranteed })) (Nil, List(problem)) else {
+      val decreases = funDef.args.zipWithIndex.exists({ case (arg, index) =>
+        recursive.forall({ case Relation(_, _, FunctionInvocation(_, args)) =>
+          isSubtreeOf(args(index), arg.id)
+        })
       })
-    })
 
-    if (!decreases) (Nil, List(problem))
-    else (Cleared(funDef) :: Nil, Nil)
+      if (!decreases) (Nil, List(problem))
+      else (Cleared(funDef) :: Nil, Nil)
+    }
   }
 }
 
diff --git a/src/main/scala/leon/termination/RelationBuilder.scala b/src/main/scala/leon/termination/RelationBuilder.scala
index 3483197e8..8e50fcecb 100644
--- a/src/main/scala/leon/termination/RelationBuilder.scala
+++ b/src/main/scala/leon/termination/RelationBuilder.scala
@@ -13,7 +13,9 @@ final case class Relation(funDef: FunDef, path: Seq[Expr], call: FunctionInvocat
 
 object RelationBuilder {
   import scala.collection.mutable.{Map => MutableMap}
-  val relationCache : MutableMap[FunDef, Set[Relation]] = MutableMap()
+  private val relationCache : MutableMap[FunDef, Set[Relation]] = MutableMap()
+
+  def init : Unit = relationCache.clear
 
   def run(funDef: FunDef): Set[Relation] = relationCache.get(funDef) match {
     case Some(relations) => relations
@@ -62,8 +64,6 @@ object RelationBuilder {
         case _ => sys.error("Expression "+e+" ["+e.getClass+"] is not extractable")
       }
 
-      // TODO: throw error if we see funDef in precondition or postcondition
-
       val precondition = funDef.precondition getOrElse BooleanLiteral(true)
       val precRelations = funDef.precondition.map(e => visit(simplifyLets(matchToIfThenElse(e)), Nil)).flatten.toSet
       val bodyRelations = funDef.body.map(e => visit(simplifyLets(matchToIfThenElse(e)), List(precondition))).flatten.toSet
diff --git a/src/main/scala/leon/termination/RelationComparator.scala b/src/main/scala/leon/termination/RelationComparator.scala
index 6266e1586..e6a6c9721 100644
--- a/src/main/scala/leon/termination/RelationComparator.scala
+++ b/src/main/scala/leon/termination/RelationComparator.scala
@@ -10,6 +10,8 @@ import purescala.Common._
 object RelationComparator {
   import StructuralSize._
 
+  def init : Unit = StructuralSize.init
+
   def sizeDecreasing(e1: TypedExpr, e2: TypedExpr) = GreaterThan(size(e1), size(e2))
   def sizeDecreasing(e1:      Expr, e2: TypedExpr) = GreaterThan(size(e1), size(e2))
   def sizeDecreasing(e1: TypedExpr, e2:      Expr) = GreaterThan(size(e1), size(e2))
diff --git a/src/main/scala/leon/termination/RelationProcessor.scala b/src/main/scala/leon/termination/RelationProcessor.scala
index 973ca4105..2f094c838 100644
--- a/src/main/scala/leon/termination/RelationProcessor.scala
+++ b/src/main/scala/leon/termination/RelationProcessor.scala
@@ -12,6 +12,9 @@ class RelationProcessor(checker: TerminationChecker) extends Processor(checker)
 
   val name: String = "Relation Processor"
 
+  RelationBuilder.init
+  RelationComparator.init
+
   def run(problem: Problem) = {
 
     strengthenPostconditions(problem.funDefs)
@@ -32,10 +35,12 @@ class RelationProcessor(checker: TerminationChecker) extends Processor(checker)
     case class Dep(deps: Set[FunDef]) extends Result
     case object Failure extends Result
 
+    initSolvers
+
     val decreasing = formulas.map({ case (fd, formulas) =>
       val solved = formulas.map({ case (fid, (gt, ge)) =>
-        if(solve(gt).isValid) Success
-        else if(solve(ge).isValid) Dep(Set(fid))
+        if(isAlwaysSAT(gt)) Success
+        else if(isAlwaysSAT(ge)) Dep(Set(fid))
         else Failure
       })
       val result = if(solved.contains(Failure)) Failure else {
@@ -69,18 +74,5 @@ class RelationProcessor(checker: TerminationChecker) extends Processor(checker)
     val results = terminating.map(Cleared(_)).toList
     val newProblems = if (problem.funDefs intersect nonTerminating nonEmpty) List(Problem(nonTerminating)) else Nil
     (results, newProblems)
-
-    /*
-    val noIncrease = gtformulas.forall(solvers.solve(_._2))
-    if(noIncrease) {
-      val isReducing = eqformulas.map(x => x._1 -> solvers.solve(x._2))
-      if(isReducing.exists(!_._2)) {
-        val (ok,nok) = isReducing.partition(_._2) match { case (xs, ys) => (xs.map(_._1), ys.map(_._1)) }
-        ProcessingResult(Nil, ok.map(Conditional(_, nok)) toList, List(problem filter nok))
-      } else if(noArgs.nonEmpty) {
-        ProcessingResult(Nil, functionsOfInterest.map(Conditional(_, noArgs)) toList, List(problem filter noArgs))
-      } else ProcessingResult(problem.callers.map(Cleared(_, "size relation formula solved")) toList, Nil, Nil)
-    } else ProcessingResult(Nil, Nil, List(problem))
-    */
   }
 }
diff --git a/src/main/scala/leon/termination/StructuralSize.scala b/src/main/scala/leon/termination/StructuralSize.scala
index 60bc506b7..412c48660 100644
--- a/src/main/scala/leon/termination/StructuralSize.scala
+++ b/src/main/scala/leon/termination/StructuralSize.scala
@@ -19,21 +19,29 @@ object StructuralSize {
 
   private val sizeFunctionCache : MutableMap[TypeTree, FunDef] = MutableMap()
   def size(typedExpr: TypedExpr) : Expr = {
-    def funDef(tpe: TypeTree, cases: => Seq[MatchCase]) = sizeFunctionCache.get(tpe) match {
-      case Some(fd) => fd
-      case None =>
-        val argument = VarDecl(FreshIdentifier("x"), tpe)
-        val fd = new FunDef(FreshIdentifier("size", true), Int32Type, Seq(argument))
-        sizeFunctionCache(tpe) = fd
+    def funDef(tpe: TypeTree, cases: => Seq[MatchCase]) = {
+      // we want to reuse generic size functions for sub-types
+      val argumentType = tpe match {
+        case CaseClassType(cd) if cd.parent.isDefined => classDefToClassType(cd.parent.get)
+        case _ => tpe
+      }
 
-        val body = simplifyLets(matchToIfThenElse(MatchExpr(argument.toVariable, cases)))
-        val postSubcalls = functionCallsOf(body).map(GreaterThan(_, IntLiteral(0))).toSeq
-        val postRecursive = GreaterThan(ResultVariable(), IntLiteral(0))
-        val postcondition = And(postSubcalls :+ postRecursive)
+      sizeFunctionCache.get(argumentType) match {
+        case Some(fd) => fd
+        case None =>
+          val argument = VarDecl(FreshIdentifier("x"), argumentType)
+          val fd = new FunDef(FreshIdentifier("size", true), Int32Type, Seq(argument))
+          sizeFunctionCache(argumentType) = fd
 
-        fd.body = Some(body)
-        fd.postcondition = Some(postcondition)
-        fd
+          val body = simplifyLets(matchToIfThenElse(MatchExpr(argument.toVariable, cases)))
+          val postSubcalls = functionCallsOf(body).map(GreaterThan(_, IntLiteral(0))).toSeq
+          val postRecursive = GreaterThan(ResultVariable(), IntLiteral(0))
+          val postcondition = And(postSubcalls :+ postRecursive)
+
+          fd.body = Some(body)
+          fd.postcondition = Some(postcondition)
+          fd
+      }
     }
 
     def caseClassType2MatchCase(_c: ClassTypeDef): MatchCase = {
@@ -60,6 +68,8 @@ object StructuralSize {
   }
 
   def defs : Set[FunDef] = Set(sizeFunctionCache.values.toSeq : _*)
+
+  def init : Unit = sizeFunctionCache.clear
 }
 
 // vim: set ts=4 sw=4 et:
-- 
GitLab