From e37d54cef8f9005b64297f5f10904b6dc493046a Mon Sep 17 00:00:00 2001 From: Nicolas Voirol <voirol.nicolas@gmail.com> Date: Sat, 15 Jun 2013 00:58:10 +0200 Subject: [PATCH] Fixed small but in ChainProcessor --- .../scala/leon/purescala/PrettyPrinter.scala | 2 +- src/main/scala/leon/purescala/Trees.scala | 5 +- src/main/scala/leon/purescala/TypeTrees.scala | 4 +- .../scala/leon/termination/ChainBuilder.scala | 17 ++--- .../leon/termination/ChainComparator.scala | 2 + .../leon/termination/ChainProcessor.scala | 69 ++++++++++++++----- .../leon/termination/LoopProcessor.scala | 44 +++++------- .../scala/leon/termination/Processor.scala | 65 +++++++++-------- .../leon/termination/RecursionProcessor.scala | 25 +++---- .../leon/termination/RelationBuilder.scala | 6 +- .../leon/termination/RelationComparator.scala | 2 + .../leon/termination/RelationProcessor.scala | 22 ++---- .../leon/termination/StructuralSize.scala | 36 ++++++---- 13 files changed, 170 insertions(+), 129 deletions(-) diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala index a5c9d9a89..f8b91d2b9 100644 --- a/src/main/scala/leon/purescala/PrettyPrinter.scala +++ b/src/main/scala/leon/purescala/PrettyPrinter.scala @@ -285,7 +285,7 @@ class PrettyPrinter(sb: StringBuffer = new StringBuffer) { case (expr: PrettyPrintable) => expr.printWith(lvl, this) - case _ => sb.append("Expr?") + case _ => sb.append("Expr? (" + tree.getClass + ")") } // TYPE TREES diff --git a/src/main/scala/leon/purescala/Trees.scala b/src/main/scala/leon/purescala/Trees.scala index a6be5fbbd..bbcfa9bb0 100644 --- a/src/main/scala/leon/purescala/Trees.scala +++ b/src/main/scala/leon/purescala/Trees.scala @@ -63,7 +63,10 @@ object Trees { case class FunctionInvocation(funDef: FunDef, args: Seq[Expr]) extends Expr with FixedType with ScalacPositional { val fixedType = funDef.returnType - // funDef.args.zip(args).foreach { case (a, c) => typeCheck(c, a.tpe) } + funDef.args.zip(args).foreach { + case (a, ResultVariable()) => true // assume its correct... don't know how to get context to really check + case (a, c) => typeCheck(c, a.tpe) + } } case class IfExpr(cond: Expr, thenn: Expr, elze: Expr) extends Expr with FixedType { val fixedType = leastUpperBound(thenn.getType, elze.getType).getOrElse(AnyType) diff --git a/src/main/scala/leon/purescala/TypeTrees.scala b/src/main/scala/leon/purescala/TypeTrees.scala index 74b35aeb6..ab7c03cae 100644 --- a/src/main/scala/leon/purescala/TypeTrees.scala +++ b/src/main/scala/leon/purescala/TypeTrees.scala @@ -101,7 +101,9 @@ object TypeTrees { Some(classDefToClassType(found.get)) } } - + case (TupleType(args1), TupleType(args2)) => + val args = (args1 zip args2).map(p => leastUpperBound(p._1, p._2)) + if (args.forall(_.isDefined)) Some(TupleType(args.map(_.get))) else None case (o1, o2) if (o1 == o2) => Some(o1) case (o1,BottomType) => Some(o1) case (BottomType,o2) => Some(o2) diff --git a/src/main/scala/leon/termination/ChainBuilder.scala b/src/main/scala/leon/termination/ChainBuilder.scala index abd158530..065c54db8 100644 --- a/src/main/scala/leon/termination/ChainBuilder.scala +++ b/src/main/scala/leon/termination/ChainBuilder.scala @@ -10,9 +10,9 @@ final case class Chain(chain: List[Relation]) { def funDef : FunDef = chain.head.funDef def funDefs : Set[FunDef] = chain.map(_.funDef) toSet - def loop(initialSubst: Map[Identifier, Expr] = Map(), finalSubst: Map[Identifier, Expr] = Map()) : Seq[Expr] = { - assert(initialSubst.nonEmpty || finalSubst.nonEmpty) + lazy val size: Int = chain.size + def loop(initialSubst: Map[Identifier, Expr] = Map(), finalSubst: Map[Identifier, Expr] = Map()) : Seq[Expr] = { def rec(relations: List[Relation], subst: Map[Identifier, Expr]): Seq[Expr] = relations match { case Relation(_, path, FunctionInvocation(fd, args)) :: Nil => assert(fd == funDef) @@ -49,16 +49,6 @@ final case class Chain(chain: List[Relation]) { firstLoop ++ secondLoop } - def times(k: Int, initialSubst: Map[Identifier, Expr] = Map(), finalSubst: Map[Identifier, Expr] = Map()) : Seq[Expr] = { - def rec(bindingSubst: Map[Identifier, Expr], count: Int) : Seq[Expr] = if (count == k) loop(initialSubst = bindingSubst, finalSubst = finalSubst) else { - val nextSubst : Map[Identifier, Expr] = funDef.args.map(arg => arg.id -> arg.id.freshen.toVariable).toMap - val currentLoop = loop(initialSubst = bindingSubst, finalSubst = nextSubst) - val rest = rec(nextSubst, count + 1) - currentLoop ++ rest - } - rec(initialSubst, 1) - } - def inlined: TraversableOnce[Expr] = { def rec(list: List[Relation], mapping: Map[Identifier, Expr]): List[Expr] = list match { case Relation(_, _, FunctionInvocation(fd, args)) :: xs => @@ -80,6 +70,9 @@ object ChainBuilder { import scala.collection.mutable.{Map => MutableMap} private val chainCache : MutableMap[FunDef, Set[Chain]] = MutableMap() + + def init : Unit = chainCache.clear + def run(funDef: FunDef): Set[Chain] = chainCache.get(funDef) match { case Some(chains) => chains case None => { diff --git a/src/main/scala/leon/termination/ChainComparator.scala b/src/main/scala/leon/termination/ChainComparator.scala index 901d72895..9a6352a2a 100644 --- a/src/main/scala/leon/termination/ChainComparator.scala +++ b/src/main/scala/leon/termination/ChainComparator.scala @@ -10,6 +10,8 @@ import purescala.Common._ object ChainComparator { import StructuralSize._ + def init : Unit = StructuralSize.init + def sizeDecreasing(e1: TypedExpr, e2s: Seq[(Seq[Expr], Expr)]) = _sizeDecreasing(e1, e2s map { case (path, e2) => (path, exprToTypedExpr(e2)) }) diff --git a/src/main/scala/leon/termination/ChainProcessor.scala b/src/main/scala/leon/termination/ChainProcessor.scala index b40ccf166..e50ab2e2a 100644 --- a/src/main/scala/leon/termination/ChainProcessor.scala +++ b/src/main/scala/leon/termination/ChainProcessor.scala @@ -12,25 +12,35 @@ class ChainProcessor(checker: TerminationChecker) extends Processor(checker) wit val name: String = "Chain Processor" - def run(problem: Problem) = { - val allChainMap : Map[FunDef, Set[Chain]] = problem.funDefs.map(funDef => funDef -> ChainBuilder.run(funDef)).toMap - val allChains : Set[Chain] = allChainMap.values.flatten.toSet + ChainBuilder.init + ChainComparator.init - // We check that loops can reenter themselves after a run. If not, then this is not a chain (since it will - // enter another chain and their conjunction is contained elsewhere in the chains set) - // Note: We are checking reentrance SAT, not looking for a counter example so we negate the formula! - val validChains : Set[Chain] = allChains.filter(chain => !solve(Not(And(chain reentrant chain))).isValid) - val chainMap : Map[FunDef, Set[Chain]] = allChainMap.mapValues(chains => chains intersect validChains) + def run(problem: Problem): (TraversableOnce[Result], TraversableOnce[Problem]) = { + reporter.info("- Running ChainProcessor") + val allChainMap : Map[FunDef, Set[Chain]] = problem.funDefs.map(funDef => funDef -> ChainBuilder.run(funDef)).toMap + reporter.info("- Computing all possible Chains") + val possibleChainMap : Map[FunDef, Set[Chain]] = allChainMap.mapValues(chains => chains.filter(chain => isSAT(And(chain.loop())))) + reporter.info("- Collecting re-entrant Chains") + val reentrantChainMap : Map[FunDef, Set[Chain]] = possibleChainMap.mapValues(chains => chains.filter(chain => isSAT(And(chain reentrant chain)))) // We build a cross-chain map that determines which chains can reenter into another one after a loop. // Note: We are also checking reentrance SAT here, so again, we negate the formula! - val crossChains : Map[Chain, Set[Chain]] = chainMap.map({ case (funDef, chains) => - chains.map(chain => chain -> (chains - chain).filter(other => !solve(Not(And(chain reentrant other))).isValid)) + reporter.info("- Computing cross-chain map") + val crossChains : Map[Chain, Set[Chain]] = possibleChainMap.map({ case (funDef, chains) => + val reentrant = reentrantChainMap(funDef) + val reentrantPairs = reentrant.map(chain => chain -> Set(chain)) + val crosswise = (chains -- reentrant).map(chain => chain -> { + reentrant.filter(other => isSAT(And(chain reentrant other))) + }) + reentrantPairs ++ crosswise }).flatten.toMap + val validChainMap : Map[FunDef, Set[Chain]] = possibleChainMap.map({ case (funDef, chains) => funDef -> chains.filter(crossChains(_).nonEmpty) }) + // We use the cross-chains to build chain clusters. For each cluster, we must prove that the SAME argument // decreases in each of the chains in the cluster! - val clusters : Map[FunDef, Set[Set[Chain]]] = { + reporter.info("- Building initial cluster estimation by fix-point iteration") + val generalClusters : Map[FunDef, Set[Set[Chain]]] = { def cluster(set: Set[Chain]): Set[Chain] = { set ++ set.map(crossChains(_)).flatten } @@ -51,9 +61,19 @@ class ChainProcessor(checker: TerminationChecker) extends Processor(checker) wit filterClusters(allClusters.toList.sortBy(- _.size)).toSet } - chainMap.map({ case (funDef, chains) => funDef -> build(chains) }) + validChainMap.map({ case (funDef, chains) => funDef -> build(chains) }) } + reporter.info("- Trimming down to final clusters") + val clusters : Map[FunDef, Set[Set[Chain]]] = generalClusters.map({ case (funDef, clusters) => + funDef -> clusters.map(cluster => cluster.toSeq.sortBy(_.size).foldLeft(Set[Chain]())({ case (acc, chain) => + val chainElements : Set[Relation] = chain.chain.toSet + val seenElements : Set[Relation] = acc.map(_.chain).flatten.toSet + if (chainElements -- seenElements nonEmpty) acc + chain else acc + })).filter(_.nonEmpty) + }) + + reporter.info("- Strengthening postconditions") strengthenPostconditions(problem.funDefs) def buildLoops(fd: FunDef, cluster: Set[Chain]): (Expr, Seq[(Seq[Expr], Expr)]) = { @@ -71,23 +91,36 @@ class ChainProcessor(checker: TerminationChecker) extends Processor(checker) wit type ClusterMap = Map[FunDef, Set[Set[Chain]]] type FormulaGenerator = (FunDef, Set[Chain]) => Expr - def clear(clusters: ClusterMap, gen: FormulaGenerator): ClusterMap = clusters.map({ case (fd, clusters) => - val remaining = clusters.filter(cluster => !solve(gen(fd, cluster)).isValid) - fd -> remaining - }) + def clear(clusters: ClusterMap, gen: FormulaGenerator): ClusterMap = { + val formulas = clusters.map({ case (fd, clusters) => + (fd, clusters.map(cluster => cluster -> gen(fd, cluster))) + }) + initSolvers // add structural size functions to solver + formulas.map({ case (fd, clustersWithFormulas) => + fd -> clustersWithFormulas.filter({ case (cluster, formula) => !isAlwaysSAT(formula) }).map(_._1) + }) + } + + reporter.info("- Searching for structural size decrease") val sizeCleared : ClusterMap = clear(clusters, (fd, cluster) => { val (e1, e2s) = buildLoops(fd, cluster) ChainComparator.sizeDecreasing(e1, e2s) }) + reporter.info("- Searching for numeric convergence") val numericCleared : ClusterMap = clear(sizeCleared, (fd, cluster) => { val (e1, e2s) = buildLoops(fd, cluster) ChainComparator.numericConverging(e1, e2s, cluster, checker) }) val (okPairs, nokPairs) = numericCleared.partition(_._2.isEmpty) - val newProblems = if (nokPairs nonEmpty) List(Problem(nokPairs.map(_._1).toSet)) else Nil - (okPairs.map(p => Cleared(p._1)), newProblems) + val nok = nokPairs.map(_._1).toSet + val (ok, transitiveNok) = okPairs.map(_._1).partition({ fd => + checker.program.transitiveCallees(fd) intersect nok isEmpty + }) + val allNok = nok ++ transitiveNok + val newProblems = if (allNok nonEmpty) List(Problem(allNok)) else Nil + (ok.map(Cleared(_)), newProblems) } } diff --git a/src/main/scala/leon/termination/LoopProcessor.scala b/src/main/scala/leon/termination/LoopProcessor.scala index f0791056f..5f0c6db3c 100644 --- a/src/main/scala/leon/termination/LoopProcessor.scala +++ b/src/main/scala/leon/termination/LoopProcessor.scala @@ -9,37 +9,29 @@ class LoopProcessor(checker: TerminationChecker, k: Int = 10) extends Processor( val name: String = "Loop Processor" + ChainBuilder.init + def run(problem: Problem) = { val allChains : Set[Chain] = problem.funDefs.map(fd => ChainBuilder.run(fd)).flatten // Get reentrant loops (see ChainProcessor for more details) - val chains : Set[Chain] = allChains.filter(chain => !solve(Not(And(chain reentrant chain))).isValid) - - def findLoops(chains: Set[Chain]) = { - def rec(chains: Set[Chain], count: Int): Map[FunDef, Seq[Expr]] = if (count == k) Map() else { - val nonTerminating = chains.flatMap({ chain => - val freshArgs : Seq[Expr] = chain.funDef.args.map(arg => arg.id.freshen.toVariable) - val finalBindings = (chain.funDef.args.map(_.id) zip freshArgs).toMap - val path = chain.times(count, finalSubst = finalBindings) - val formula = And(path :+ Equals(Tuple(chain.funDef.args.map(_.toVariable)), Tuple(freshArgs))) - - val solvable = functionCallsOf(formula).forall({ - case FunctionInvocation(fd, args) => checker.terminates(fd).isGuaranteed - }) - - if (!solvable) None else solve(Not(formula)) match { - case Solution(false, model) => Some(chain.funDef, chain.funDef.args.map(arg => model(arg.id))) - case _ => None - } - }).toMap - - val remainingChains = chains.filter(chain => nonTerminating.contains(chain.funDef)) - nonTerminating ++ rec(remainingChains, count + 1) - } + val chains : Set[Chain] = allChains.filter(chain => isSAT(And(chain reentrant chain))) + + val nonTerminating = chains.flatMap({ chain => + val freshArgs : Seq[Expr] = chain.funDef.args.map(arg => arg.id.freshen.toVariable) + val finalBindings = (chain.funDef.args.map(_.id) zip freshArgs).toMap + val path = chain.loop(finalSubst = finalBindings) + val formula = And(path :+ Equals(Tuple(chain.funDef.args.map(_.toVariable)), Tuple(freshArgs))) - rec(chains, 1) - } + val solvable = functionCallsOf(formula).forall({ + case FunctionInvocation(fd, args) => checker.terminates(fd).isGuaranteed + }) + + if (!solvable) None else getModel(formula) match { + case Some(map) => Some(chain.funDef -> chain.funDef.args.map(arg => map(arg.id))) + case _ => None + } + }).toMap - val nonTerminating = findLoops(chains) val results = nonTerminating.map({ case (funDef, args) => Broken(funDef, args) }) val remaining = problem.funDefs -- nonTerminating.keys val newProblems = if (remaining.nonEmpty) List(Problem(remaining)) else Nil diff --git a/src/main/scala/leon/termination/Processor.scala b/src/main/scala/leon/termination/Processor.scala index 760a44e37..95ec075f5 100644 --- a/src/main/scala/leon/termination/Processor.scala +++ b/src/main/scala/leon/termination/Processor.scala @@ -21,20 +21,9 @@ abstract class Processor(val checker: TerminationChecker) { val name: String - def run(problem: Problem): (TraversableOnce[Result], TraversableOnce[Problem]) -} - -class Solution(solution: Option[Boolean], val model: Map[Identifier, Expr]) { - lazy val isValid : Boolean = solution getOrElse false -} + val reporter = checker.context.reporter -object NoSolution extends Solution(None, Map()) - -object Solution { - def unapply(s: Solution): Option[(Boolean, Map[Identifier, Expr])] = { - if (s == NoSolution) None - else Some(s.isValid, s.model) - } + def run(problem: Problem): (TraversableOnce[Result], TraversableOnce[Problem]) } object Solvable { @@ -57,7 +46,7 @@ object Solvable { val resFresh = FreshIdentifier("result", true).setType(body.getType) val formula = Implies(prec, Let(resFresh, body, replace(Map(ResultVariable() -> Variable(resFresh)), post))) - if (!solver.solve(formula).isValid) { + if (!solver.isAlwaysSAT(formula)) { funDef.postcondition = postcondition strengthened.add(funDef) false @@ -85,17 +74,25 @@ object Solvable { trait Solvable { self: Processor => + private var solvers: List[Solver] = null + def strengthenPostconditions(funDefs: Set[FunDef]) = Solvable.strengthenPostconditions(funDefs)(this) - def solve(problem: Expr): Solution = { + def initSolvers { val program : Program = self.checker.program val allDefs : Seq[Definition] = program.mainObject.defs ++ StructuralSize.defs val newProgram : Program = program.copy(mainObject = program.mainObject.copy(defs = allDefs)) + val context : LeonContext = self.checker.context.copy(reporter = new QuietReporter()) - val solvers0 = new TrivialSolver(self.checker.context) :: new FairZ3Solver(self.checker.context) :: Nil - val solvers = solvers0.map(new TimeoutSolver(_, 500)) + val solvers0 = new TrivialSolver(context) :: new FairZ3Solver(context) :: Nil + solvers = solvers0.map(new TimeoutSolver(_, 500)) solvers.foreach(_.setProgram(newProgram)) + } + + type Solution = (Option[Boolean], Map[Identifier, Expr]) + private def solve(problem: Expr): Solution = { + if (solvers == null) initSolvers // drop functions from constraints that might not terminate (and may therefore // make Leon unroll them forever...) val dangerousCallsMap : Map[Expr, Expr] = functionCallsOf(problem).collect({ @@ -115,16 +112,29 @@ trait Solvable { self: Processor => superseeded = superseeded ++ Set(se.superseeds: _*) se.init() - val (satResult, model) = se.solveSAT(Not(expr)) - val solverResult = satResult.map(!_) + val (satResult, model) = se.solveSAT(expr) - if (!solverResult.isDefined) None - else Some(new Solution(solverResult, model)) + if (!satResult.isDefined) None + else Some(satResult, model) } } } - solvers.collectFirst({ case Solved(result) => result }) getOrElse NoSolution + solvers.collectFirst({ case Solved(s, model) => (s, model) }) getOrElse (None, Map()) + } + + def isSAT(problem: Expr): Boolean = { + solve(problem)._1 getOrElse false + } + + def isAlwaysSAT(problem: Expr): Boolean = { + solve(Not(problem))._1.map(!_) getOrElse false + } + + def getModel(problem: Expr): Option[Map[Identifier, Expr]] = { + val solution = solve(problem) + if (solution._1 getOrElse false) Some(solution._2) + else None } } @@ -157,7 +167,7 @@ class ProcessingPipeline(program: Program, context: LeonContext, _processors: Pr private def printResult(results: List[Result]) { val sb = new StringBuilder() - sb.append("- Queue.head Processing Result:\n") + sb.append("- Processing Result:\n") for(result <- results) result match { case Cleared(fd) => sb.append(" %-10s %s\n".format(fd.id, "Cleared")) case Broken(fd, args) => sb.append(" %-10s %s\n".format(fd.id, "Broken for arguments: " + args.mkString("(", ",", ")"))) @@ -173,10 +183,11 @@ class ProcessingPipeline(program: Program, context: LeonContext, _processors: Pr } def run : Iterator[(String, List[Result])] = new Iterator[(String, List[Result])] { - // basic sanity check, funDefs can't call themselves in precondition! - assert(initialProblem.funDefs.forall(fd => !fd.precondition.map({ precondition => - functionCallsOf(precondition).map(fi => program.transitiveCallees(fi.funDef)).flatten - }).flatten.toSet(fd))) + // basic sanity check, funDefs shouldn't call themselves in precondition! + // XXX: it seems like some do... + // assert(initialProblem.funDefs.forall(fd => !fd.precondition.map({ precondition => + // functionCallsOf(precondition).map(fi => program.transitiveCallees(fi.funDef)).flatten + // }).flatten.toSet(fd))) def hasNext : Boolean = problems.nonEmpty def next : (String, List[Result]) = { diff --git a/src/main/scala/leon/termination/RecursionProcessor.scala b/src/main/scala/leon/termination/RecursionProcessor.scala index 569e660c7..78f5ddbdb 100644 --- a/src/main/scala/leon/termination/RecursionProcessor.scala +++ b/src/main/scala/leon/termination/RecursionProcessor.scala @@ -11,6 +11,8 @@ class RecursionProcessor(checker: TerminationChecker) extends Processor(checker) val name: String = "Recursion Processor" + RelationBuilder.init + private def isSubtreeOf(expr: Expr, id: Identifier) : Boolean = { @tailrec def rec(e: Expr, fst: Boolean): Boolean = e match { @@ -23,20 +25,19 @@ class RecursionProcessor(checker: TerminationChecker) extends Processor(checker) def run(problem: Problem) = if (problem.funDefs.size > 1) (Nil, List(problem)) else { val funDef = problem.funDefs.head - - val selfRecursiveRelations = RelationBuilder.run(funDef).filter({ - case Relation(_, _, FunctionInvocation(fd, _)) => - fd == funDef || checker.terminates(fd).isGuaranteed - }) - - val decreases = funDef.args.zipWithIndex.exists({ case (arg, index) => - selfRecursiveRelations.forall({ case Relation(_, _, FunctionInvocation(_, args)) => - isSubtreeOf(args(index), arg.id) + val relations = RelationBuilder.run(funDef) + val (recursive, others) = relations.partition({ case Relation(_, _, FunctionInvocation(fd, _)) => fd == funDef }) + + if (others.exists({ case Relation(_, _, FunctionInvocation(fd, _)) => !checker.terminates(fd).isGuaranteed })) (Nil, List(problem)) else { + val decreases = funDef.args.zipWithIndex.exists({ case (arg, index) => + recursive.forall({ case Relation(_, _, FunctionInvocation(_, args)) => + isSubtreeOf(args(index), arg.id) + }) }) - }) - if (!decreases) (Nil, List(problem)) - else (Cleared(funDef) :: Nil, Nil) + if (!decreases) (Nil, List(problem)) + else (Cleared(funDef) :: Nil, Nil) + } } } diff --git a/src/main/scala/leon/termination/RelationBuilder.scala b/src/main/scala/leon/termination/RelationBuilder.scala index 3483197e8..8e50fcecb 100644 --- a/src/main/scala/leon/termination/RelationBuilder.scala +++ b/src/main/scala/leon/termination/RelationBuilder.scala @@ -13,7 +13,9 @@ final case class Relation(funDef: FunDef, path: Seq[Expr], call: FunctionInvocat object RelationBuilder { import scala.collection.mutable.{Map => MutableMap} - val relationCache : MutableMap[FunDef, Set[Relation]] = MutableMap() + private val relationCache : MutableMap[FunDef, Set[Relation]] = MutableMap() + + def init : Unit = relationCache.clear def run(funDef: FunDef): Set[Relation] = relationCache.get(funDef) match { case Some(relations) => relations @@ -62,8 +64,6 @@ object RelationBuilder { case _ => sys.error("Expression "+e+" ["+e.getClass+"] is not extractable") } - // TODO: throw error if we see funDef in precondition or postcondition - val precondition = funDef.precondition getOrElse BooleanLiteral(true) val precRelations = funDef.precondition.map(e => visit(simplifyLets(matchToIfThenElse(e)), Nil)).flatten.toSet val bodyRelations = funDef.body.map(e => visit(simplifyLets(matchToIfThenElse(e)), List(precondition))).flatten.toSet diff --git a/src/main/scala/leon/termination/RelationComparator.scala b/src/main/scala/leon/termination/RelationComparator.scala index 6266e1586..e6a6c9721 100644 --- a/src/main/scala/leon/termination/RelationComparator.scala +++ b/src/main/scala/leon/termination/RelationComparator.scala @@ -10,6 +10,8 @@ import purescala.Common._ object RelationComparator { import StructuralSize._ + def init : Unit = StructuralSize.init + def sizeDecreasing(e1: TypedExpr, e2: TypedExpr) = GreaterThan(size(e1), size(e2)) def sizeDecreasing(e1: Expr, e2: TypedExpr) = GreaterThan(size(e1), size(e2)) def sizeDecreasing(e1: TypedExpr, e2: Expr) = GreaterThan(size(e1), size(e2)) diff --git a/src/main/scala/leon/termination/RelationProcessor.scala b/src/main/scala/leon/termination/RelationProcessor.scala index 973ca4105..2f094c838 100644 --- a/src/main/scala/leon/termination/RelationProcessor.scala +++ b/src/main/scala/leon/termination/RelationProcessor.scala @@ -12,6 +12,9 @@ class RelationProcessor(checker: TerminationChecker) extends Processor(checker) val name: String = "Relation Processor" + RelationBuilder.init + RelationComparator.init + def run(problem: Problem) = { strengthenPostconditions(problem.funDefs) @@ -32,10 +35,12 @@ class RelationProcessor(checker: TerminationChecker) extends Processor(checker) case class Dep(deps: Set[FunDef]) extends Result case object Failure extends Result + initSolvers + val decreasing = formulas.map({ case (fd, formulas) => val solved = formulas.map({ case (fid, (gt, ge)) => - if(solve(gt).isValid) Success - else if(solve(ge).isValid) Dep(Set(fid)) + if(isAlwaysSAT(gt)) Success + else if(isAlwaysSAT(ge)) Dep(Set(fid)) else Failure }) val result = if(solved.contains(Failure)) Failure else { @@ -69,18 +74,5 @@ class RelationProcessor(checker: TerminationChecker) extends Processor(checker) val results = terminating.map(Cleared(_)).toList val newProblems = if (problem.funDefs intersect nonTerminating nonEmpty) List(Problem(nonTerminating)) else Nil (results, newProblems) - - /* - val noIncrease = gtformulas.forall(solvers.solve(_._2)) - if(noIncrease) { - val isReducing = eqformulas.map(x => x._1 -> solvers.solve(x._2)) - if(isReducing.exists(!_._2)) { - val (ok,nok) = isReducing.partition(_._2) match { case (xs, ys) => (xs.map(_._1), ys.map(_._1)) } - ProcessingResult(Nil, ok.map(Conditional(_, nok)) toList, List(problem filter nok)) - } else if(noArgs.nonEmpty) { - ProcessingResult(Nil, functionsOfInterest.map(Conditional(_, noArgs)) toList, List(problem filter noArgs)) - } else ProcessingResult(problem.callers.map(Cleared(_, "size relation formula solved")) toList, Nil, Nil) - } else ProcessingResult(Nil, Nil, List(problem)) - */ } } diff --git a/src/main/scala/leon/termination/StructuralSize.scala b/src/main/scala/leon/termination/StructuralSize.scala index 60bc506b7..412c48660 100644 --- a/src/main/scala/leon/termination/StructuralSize.scala +++ b/src/main/scala/leon/termination/StructuralSize.scala @@ -19,21 +19,29 @@ object StructuralSize { private val sizeFunctionCache : MutableMap[TypeTree, FunDef] = MutableMap() def size(typedExpr: TypedExpr) : Expr = { - def funDef(tpe: TypeTree, cases: => Seq[MatchCase]) = sizeFunctionCache.get(tpe) match { - case Some(fd) => fd - case None => - val argument = VarDecl(FreshIdentifier("x"), tpe) - val fd = new FunDef(FreshIdentifier("size", true), Int32Type, Seq(argument)) - sizeFunctionCache(tpe) = fd + def funDef(tpe: TypeTree, cases: => Seq[MatchCase]) = { + // we want to reuse generic size functions for sub-types + val argumentType = tpe match { + case CaseClassType(cd) if cd.parent.isDefined => classDefToClassType(cd.parent.get) + case _ => tpe + } - val body = simplifyLets(matchToIfThenElse(MatchExpr(argument.toVariable, cases))) - val postSubcalls = functionCallsOf(body).map(GreaterThan(_, IntLiteral(0))).toSeq - val postRecursive = GreaterThan(ResultVariable(), IntLiteral(0)) - val postcondition = And(postSubcalls :+ postRecursive) + sizeFunctionCache.get(argumentType) match { + case Some(fd) => fd + case None => + val argument = VarDecl(FreshIdentifier("x"), argumentType) + val fd = new FunDef(FreshIdentifier("size", true), Int32Type, Seq(argument)) + sizeFunctionCache(argumentType) = fd - fd.body = Some(body) - fd.postcondition = Some(postcondition) - fd + val body = simplifyLets(matchToIfThenElse(MatchExpr(argument.toVariable, cases))) + val postSubcalls = functionCallsOf(body).map(GreaterThan(_, IntLiteral(0))).toSeq + val postRecursive = GreaterThan(ResultVariable(), IntLiteral(0)) + val postcondition = And(postSubcalls :+ postRecursive) + + fd.body = Some(body) + fd.postcondition = Some(postcondition) + fd + } } def caseClassType2MatchCase(_c: ClassTypeDef): MatchCase = { @@ -60,6 +68,8 @@ object StructuralSize { } def defs : Set[FunDef] = Set(sizeFunctionCache.values.toSeq : _*) + + def init : Unit = sizeFunctionCache.clear } // vim: set ts=4 sw=4 et: -- GitLab