From dd985716b0aa3eab5befe75eae9e57f2b8412a3c Mon Sep 17 00:00:00 2001 From: Etienne Kneuss <colder@php.net> Date: Mon, 18 Nov 2013 15:27:56 +0100 Subject: [PATCH] Remove global state in Termination --- .../scala/leon/termination/ChainBuilder.scala | 13 ++++++------- .../scala/leon/termination/ChainComparator.scala | 6 ++---- .../scala/leon/termination/ChainProcessor.scala | 15 ++++++++------- .../termination/ComplexTerminationChecker.scala | 14 ++++++++++---- .../leon/termination/ComponentProcessor.scala | 2 +- .../scala/leon/termination/LoopProcessor.scala | 10 ++++++---- src/main/scala/leon/termination/Processor.scala | 15 +++++++++------ .../leon/termination/RecursionProcessor.scala | 6 ++---- .../scala/leon/termination/RelationBuilder.scala | 7 +++---- .../leon/termination/RelationComparator.scala | 6 ++---- .../leon/termination/RelationProcessor.scala | 15 ++++++++------- .../scala/leon/termination/StructuralSize.scala | 4 +--- 12 files changed, 58 insertions(+), 55 deletions(-) diff --git a/src/main/scala/leon/termination/ChainBuilder.scala b/src/main/scala/leon/termination/ChainBuilder.scala index f6ec38673..936e6a9a0 100644 --- a/src/main/scala/leon/termination/ChainBuilder.scala +++ b/src/main/scala/leon/termination/ChainBuilder.scala @@ -6,6 +6,8 @@ import leon.purescala.Trees._ import leon.purescala.TreeOps._ import leon.purescala.Common._ +import scala.collection.mutable.{Map => MutableMap} + object ChainID { private var counter: Int = 0 def get: Int = { @@ -79,13 +81,10 @@ final case class Chain(chain: List[Relation]) { } } -object ChainBuilder { - import scala.collection.mutable.{Map => MutableMap} +class ChainBuilder(relationBuilder: RelationBuilder) { private val chainCache : MutableMap[FunDef, Set[Chain]] = MutableMap() - def init : Unit = chainCache.clear - def run(funDef: FunDef): Set[Chain] = chainCache.get(funDef) match { case Some(chains) => chains case None => { @@ -94,14 +93,14 @@ object ChainBuilder { // Note that chains in partials are reversed to profit from O(1) insertion val (results, newPartials) = partials.foldLeft(List[List[Relation]](),List[(Relation, List[Relation])]())({ case ((results, partials), (first, chain @ Relation(_, _, FunctionInvocation(fd, _)) :: xs)) => - val cycle = RelationBuilder.run(fd).contains(first) + val cycle = relationBuilder.run(fd).contains(first) // reverse the chain when "returning" it since we're working on reversed chains val newResults = if (cycle) chain.reverse :: results else results // Partial chains can fall back onto a transition that was already taken (thus creating a cycle // inside the chain). Since this cycle will be discovered elsewhere, such partial chains should be // dropped from the partial chain list - val transitions = RelationBuilder.run(fd) -- chain.toSet + val transitions = relationBuilder.run(fd) -- chain.toSet val newPartials = transitions.map(transition => (first, transition :: chain)).toList (newResults, partials ++ newPartials) @@ -111,7 +110,7 @@ object ChainBuilder { results ++ chains(newPartials) } - val initialPartials = RelationBuilder.run(funDef).map(r => (r, r :: Nil)).toList + val initialPartials = relationBuilder.run(funDef).map(r => (r, r :: Nil)).toList val result = chains(initialPartials).map(Chain(_)).toSet chainCache(funDef) = result result diff --git a/src/main/scala/leon/termination/ChainComparator.scala b/src/main/scala/leon/termination/ChainComparator.scala index 278b4bc50..dc61531d3 100644 --- a/src/main/scala/leon/termination/ChainComparator.scala +++ b/src/main/scala/leon/termination/ChainComparator.scala @@ -7,10 +7,8 @@ import purescala.TypeTrees._ import purescala.Definitions._ import purescala.Common._ -object ChainComparator { - import StructuralSize._ - - def init : Unit = StructuralSize.init +class ChainComparator(structuralSize: StructuralSize) { + import structuralSize.size private object ContainerType { def unapply(c: ClassType): Option[(CaseClassDef, Seq[(Identifier, TypeTree)])] = c match { diff --git a/src/main/scala/leon/termination/ChainProcessor.scala b/src/main/scala/leon/termination/ChainProcessor.scala index 83e20354b..6ac33d59b 100644 --- a/src/main/scala/leon/termination/ChainProcessor.scala +++ b/src/main/scala/leon/termination/ChainProcessor.scala @@ -8,19 +8,20 @@ import purescala.Common._ import purescala.Extractors._ import purescala.Definitions._ -class ChainProcessor(checker: TerminationChecker) extends Processor(checker) with Solvable { +class ChainProcessor(checker: TerminationChecker, + chainBuilder: ChainBuilder, + val structuralSize: StructuralSize, + val strengthener: Strengthener) extends Processor(checker) with Solvable { val name: String = "Chain Processor" - ChainBuilder.init - ChainComparator.init + val chainComparator = new ChainComparator(structuralSize) def run(problem: Problem): (TraversableOnce[Result], TraversableOnce[Problem]) = { - import scala.collection.mutable.{Map => MutableMap} implicit val debugSection = DebugSectionTermination reporter.debug("- Running ChainProcessor") - val allChainMap : Map[FunDef, Set[Chain]] = problem.funDefs.map(funDef => funDef -> ChainBuilder.run(funDef)).toMap + val allChainMap : Map[FunDef, Set[Chain]] = problem.funDefs.map(funDef => funDef -> chainBuilder.run(funDef)).toMap reporter.debug("- Computing all possible Chains") var counter = 0 val possibleChainMap : Map[FunDef, Set[Chain]] = allChainMap.mapValues(chains => chains.filter(chain => isWeakSAT(And(chain.loop())))) @@ -109,13 +110,13 @@ class ChainProcessor(checker: TerminationChecker) extends Processor(checker) wit reporter.debug("- Searching for structural size decrease") val sizeCleared : ClusterMap = clear(clusters, (fd, cluster) => { val (e1, e2s) = buildLoops(fd, cluster) - ChainComparator.sizeDecreasing(e1, e2s) + chainComparator.sizeDecreasing(e1, e2s) }) reporter.debug("- Searching for numeric convergence") val numericCleared : ClusterMap = clear(sizeCleared, (fd, cluster) => { val (e1, e2s) = buildLoops(fd, cluster) - ChainComparator.numericConverging(e1, e2s, cluster, checker) + chainComparator.numericConverging(e1, e2s, cluster, checker) }) val (okPairs, nokPairs) = numericCleared.partition(_._2.isEmpty) diff --git a/src/main/scala/leon/termination/ComplexTerminationChecker.scala b/src/main/scala/leon/termination/ComplexTerminationChecker.scala index 8dcd43fb2..b808d8c8d 100644 --- a/src/main/scala/leon/termination/ComplexTerminationChecker.scala +++ b/src/main/scala/leon/termination/ComplexTerminationChecker.scala @@ -11,13 +11,19 @@ class ComplexTerminationChecker(context: LeonContext, _program: Program) extends val name = "Complex Termination Checker" val description = "A modular termination checker with a few basic modules™" + val structuralSize = new StructuralSize() + val relationBuilder = new RelationBuilder() + val chainBuilder = new ChainBuilder(relationBuilder) + val relationComparator = new RelationComparator(structuralSize) + val strengthener = new Strengthener(relationComparator) + private val pipeline = new ProcessingPipeline( program, context, // required for solvers and reporting new ComponentProcessor(this), - new RecursionProcessor(this), - new RelationProcessor(this), - new ChainProcessor(this), - new LoopProcessor(this) + new RecursionProcessor(this, relationBuilder), + new RelationProcessor(this, relationBuilder, structuralSize, relationComparator, strengthener), + new ChainProcessor(this, chainBuilder, structuralSize, strengthener), + new LoopProcessor(this, chainBuilder, structuralSize, strengthener) ) private val clearedMap : MutableMap[FunDef, String] = MutableMap() diff --git a/src/main/scala/leon/termination/ComponentProcessor.scala b/src/main/scala/leon/termination/ComponentProcessor.scala index 1e980370b..2cf37cd78 100644 --- a/src/main/scala/leon/termination/ComponentProcessor.scala +++ b/src/main/scala/leon/termination/ComponentProcessor.scala @@ -3,6 +3,7 @@ package termination import purescala.TreeOps._ import purescala.Definitions._ +import scala.collection.mutable.{Map => MutableMap} class ComponentProcessor(checker: TerminationChecker) extends Processor(checker) { @@ -16,7 +17,6 @@ class ComponentProcessor(checker: TerminationChecker) extends Processor(checker) val components : List[Set[FunDef]] = ComponentBuilder.run(callGraph) val fdToSCC : Map[FunDef, Set[FunDef]] = components.map(set => set.map(fd => fd -> set)).flatten.toMap - import scala.collection.mutable.{Map => MutableMap} val terminationCache : MutableMap[FunDef, Boolean] = MutableMap() def terminates(fd: FunDef) : Boolean = terminationCache.getOrElse(fd, { val scc = fdToSCC.getOrElse(fd, Set()) // functions that aren't called and don't call belong to no SCC diff --git a/src/main/scala/leon/termination/LoopProcessor.scala b/src/main/scala/leon/termination/LoopProcessor.scala index a45cc3c7d..e33f1c055 100644 --- a/src/main/scala/leon/termination/LoopProcessor.scala +++ b/src/main/scala/leon/termination/LoopProcessor.scala @@ -5,14 +5,16 @@ import purescala.Definitions._ import purescala.Trees._ import purescala.TreeOps._ -class LoopProcessor(checker: TerminationChecker, k: Int = 10) extends Processor(checker) with Solvable { +class LoopProcessor(checker: TerminationChecker, + chainBuilder: ChainBuilder, + val structuralSize: StructuralSize, + val strengthener: Strengthener, + k: Int = 10) extends Processor(checker) with Solvable { val name: String = "Loop Processor" - ChainBuilder.init - def run(problem: Problem) = { - val allChains : Set[Chain] = problem.funDefs.map(fd => ChainBuilder.run(fd)).flatten + val allChains : Set[Chain] = problem.funDefs.map(fd => chainBuilder.run(fd)).flatten // Get reentrant loops (see ChainProcessor for more details) val chains : Set[Chain] = allChains.filter(chain => isWeakSAT(And(chain reentrant chain))) diff --git a/src/main/scala/leon/termination/Processor.scala b/src/main/scala/leon/termination/Processor.scala index dd826a56a..3b3311e22 100644 --- a/src/main/scala/leon/termination/Processor.scala +++ b/src/main/scala/leon/termination/Processor.scala @@ -28,7 +28,7 @@ abstract class Processor(val checker: TerminationChecker) { def process(problem: Problem): (TraversableOnce[Result], TraversableOnce[Problem]) = run(problem) } -object Solvable { +class Strengthener(relationComparator: RelationComparator) { import scala.collection.mutable.{Set => MutableSet} private val strengthened : MutableSet[FunDef] = MutableSet() @@ -67,11 +67,11 @@ object Solvable { val sortedCallees : Seq[FunDef] = callees.toSeq.sortWith((fd1, fd2) => solver.checker.program.transitivelyCalls(fd2, fd1)) for (funDef <- sortedCallees if !strengthened(funDef) && funDef.hasBody && solver.checker.terminates(funDef).isGuaranteed) { // test if size is smaller or equal to input - val weekConstraintHolds = strengthenPostcondition(funDef, RelationComparator.softDecreasing) + val weekConstraintHolds = strengthenPostcondition(funDef, relationComparator.softDecreasing) if (weekConstraintHolds) { // try to improve postcondition with strictly smaller - strengthenPostcondition(funDef, RelationComparator.sizeDecreasing) + strengthenPostcondition(funDef, relationComparator.sizeDecreasing) } } } @@ -79,6 +79,9 @@ object Solvable { trait Solvable { self: Processor => + val structuralSize: StructuralSize + val strengthener: Strengthener + override def process(problem: Problem): (TraversableOnce[Result], TraversableOnce[Problem]) = { self.run(problem) } @@ -86,10 +89,10 @@ trait Solvable { self: Processor => private var solvers: List[SolverFactory[Solver]] = null private var lastDefs: Set[FunDef] = Set() - def strengthenPostconditions(funDefs: Set[FunDef]) = Solvable.strengthenPostconditions(funDefs)(this) + def strengthenPostconditions(funDefs: Set[FunDef]) = strengthener.strengthenPostconditions(funDefs)(this) private def initSolvers { - val structDefs = StructuralSize.defs + val structDefs = structuralSize.defs if (structDefs != lastDefs || solvers == null) { val program : Program = self.checker.program val allDefs : Seq[Definition] = program.mainObject.defs ++ structDefs @@ -108,7 +111,7 @@ trait Solvable { self: Processor => // make Leon unroll them forever...) val dangerousCallsMap : Map[Expr, Expr] = functionCallsOf(problem).collect({ // extra definitions (namely size functions) are quaranteed to terminate because structures are non-looping - case fi @ FunctionInvocation(fd, args) if !StructuralSize.defs(fd) && !self.checker.terminates(fd).isGuaranteed => + case fi @ FunctionInvocation(fd, args) if !structuralSize.defs(fd) && !self.checker.terminates(fd).isGuaranteed => fi -> FreshIdentifier("noRun", true).setType(fi.getType).toVariable }).toMap diff --git a/src/main/scala/leon/termination/RecursionProcessor.scala b/src/main/scala/leon/termination/RecursionProcessor.scala index 78f5ddbdb..1bbd88fd4 100644 --- a/src/main/scala/leon/termination/RecursionProcessor.scala +++ b/src/main/scala/leon/termination/RecursionProcessor.scala @@ -7,12 +7,10 @@ import purescala.Definitions._ import scala.annotation.tailrec -class RecursionProcessor(checker: TerminationChecker) extends Processor(checker) { +class RecursionProcessor(checker: TerminationChecker, relationBuilder: RelationBuilder) extends Processor(checker) { val name: String = "Recursion Processor" - RelationBuilder.init - private def isSubtreeOf(expr: Expr, id: Identifier) : Boolean = { @tailrec def rec(e: Expr, fst: Boolean): Boolean = e match { @@ -25,7 +23,7 @@ class RecursionProcessor(checker: TerminationChecker) extends Processor(checker) def run(problem: Problem) = if (problem.funDefs.size > 1) (Nil, List(problem)) else { val funDef = problem.funDefs.head - val relations = RelationBuilder.run(funDef) + val relations = relationBuilder.run(funDef) val (recursive, others) = relations.partition({ case Relation(_, _, FunctionInvocation(fd, _)) => fd == funDef }) if (others.exists({ case Relation(_, _, FunctionInvocation(fd, _)) => !checker.terminates(fd).isGuaranteed })) (Nil, List(problem)) else { diff --git a/src/main/scala/leon/termination/RelationBuilder.scala b/src/main/scala/leon/termination/RelationBuilder.scala index 11b411b0a..94e10516c 100644 --- a/src/main/scala/leon/termination/RelationBuilder.scala +++ b/src/main/scala/leon/termination/RelationBuilder.scala @@ -7,16 +7,15 @@ import purescala.TreeOps._ import purescala.Extractors._ import purescala.Common._ +import scala.collection.mutable.{Map => MutableMap} + final case class Relation(funDef: FunDef, path: Seq[Expr], call: FunctionInvocation) { override def toString : String = "Relation(" + funDef.id + "," + path + ", " + call.funDef.id + call.args.mkString("(",",",")") + ")" } -object RelationBuilder { - import scala.collection.mutable.{Map => MutableMap} +class RelationBuilder { private val relationCache : MutableMap[FunDef, Set[Relation]] = MutableMap() - def init : Unit = relationCache.clear - def run(funDef: FunDef): Set[Relation] = relationCache.get(funDef) match { case Some(relations) => relations case None => { diff --git a/src/main/scala/leon/termination/RelationComparator.scala b/src/main/scala/leon/termination/RelationComparator.scala index 9f3985b2c..213301275 100644 --- a/src/main/scala/leon/termination/RelationComparator.scala +++ b/src/main/scala/leon/termination/RelationComparator.scala @@ -7,10 +7,8 @@ import purescala.TypeTrees._ import purescala.Definitions._ import purescala.Common._ -object RelationComparator { - import StructuralSize._ - - def init : Unit = StructuralSize.init +class RelationComparator(structuralSize: StructuralSize) { + import structuralSize.size def sizeDecreasing(e1: Expr, e2: Expr) = GreaterThan(size(e1), size(e2)) diff --git a/src/main/scala/leon/termination/RelationProcessor.scala b/src/main/scala/leon/termination/RelationProcessor.scala index 24527eb92..7cb0b9111 100644 --- a/src/main/scala/leon/termination/RelationProcessor.scala +++ b/src/main/scala/leon/termination/RelationProcessor.scala @@ -8,24 +8,25 @@ import leon.purescala.Common._ import leon.purescala.Extractors._ import leon.purescala.Definitions._ -class RelationProcessor(checker: TerminationChecker) extends Processor(checker) with Solvable { +class RelationProcessor(checker: TerminationChecker, + relationBuilder: RelationBuilder, + val structuralSize: StructuralSize, + relationComparator: RelationComparator, + val strengthener: Strengthener) extends Processor(checker) with Solvable { val name: String = "Relation Processor" - RelationBuilder.init - RelationComparator.init - def run(problem: Problem) = { strengthenPostconditions(problem.funDefs) val formulas = problem.funDefs.map({ funDef => - funDef -> RelationBuilder.run(funDef).collect({ + funDef -> relationBuilder.run(funDef).collect({ case Relation(_, path, FunctionInvocation(fd, args)) if problem.funDefs(fd) => val (e1, e2) = (Tuple(funDef.args.map(_.toVariable)), Tuple(args)) def constraint(expr: Expr) = Implies(And(path.toSeq), expr) - val greaterThan = RelationComparator.sizeDecreasing(e1, e2) - val greaterEquals = RelationComparator.softDecreasing(e1, e2) + val greaterThan = relationComparator.sizeDecreasing(e1, e2) + val greaterEquals = relationComparator.softDecreasing(e1, e2) (fd, (constraint(greaterThan), constraint(greaterEquals))) }) }) diff --git a/src/main/scala/leon/termination/StructuralSize.scala b/src/main/scala/leon/termination/StructuralSize.scala index c85f24085..e10a30c28 100644 --- a/src/main/scala/leon/termination/StructuralSize.scala +++ b/src/main/scala/leon/termination/StructuralSize.scala @@ -7,7 +7,7 @@ import purescala.TypeTrees._ import purescala.Definitions._ import purescala.Common._ -object StructuralSize { +class StructuralSize { import scala.collection.mutable.{Map => MutableMap} private val sizeFunctionCache : MutableMap[TypeTree, FunDef] = MutableMap() @@ -62,8 +62,6 @@ object StructuralSize { } def defs : Set[FunDef] = Set(sizeFunctionCache.values.toSeq : _*) - - def init : Unit = sizeFunctionCache.clear } // vim: set ts=4 sw=4 et: -- GitLab