From f0cbe2ceae408915471cda7f7cfaef6b7e3f5103 Mon Sep 17 00:00:00 2001 From: Nicolas Voirol <voirol.nicolas@gmail.com> Date: Wed, 13 Mar 2013 13:59:33 +0100 Subject: [PATCH] Implement termination checker --- .../scala/leon/purescala/Definitions.scala | 2 +- src/main/scala/leon/purescala/Trees.scala | 2 +- .../scala/leon/termination/ChainBuilder.scala | 114 ++++++++++ .../leon/termination/ChainComparator.scala | 203 +++++++++++++++++ .../leon/termination/ChainProcessor.scala | 93 ++++++++ .../ComplexTerminationChecker.scala | 53 +++++ .../leon/termination/ComponentBuilder.scala | 53 +++++ .../leon/termination/ComponentProcessor.scala | 33 +++ .../leon/termination/LoopProcessor.scala | 50 +++++ .../scala/leon/termination/Processor.scala | 210 ++++++++++++++++++ .../leon/termination/RecursionProcessor.scala | 43 ++++ .../leon/termination/RelationBuilder.scala | 76 +++++++ .../leon/termination/RelationComparator.scala | 24 ++ .../leon/termination/RelationProcessor.scala | 86 +++++++ src/main/scala/leon/termination/SCC.scala | 19 +- .../SimpleTerminationChecker.scala | 68 +++--- .../leon/termination/StructuralSize.scala | 65 ++++++ .../leon/termination/TerminationChecker.scala | 23 +- .../leon/termination/TerminationPhase.scala | 11 +- .../leon/termination/TerminationReport.scala | 23 +- 20 files changed, 1194 insertions(+), 57 deletions(-) create mode 100644 src/main/scala/leon/termination/ChainBuilder.scala create mode 100644 src/main/scala/leon/termination/ChainComparator.scala create mode 100644 src/main/scala/leon/termination/ChainProcessor.scala create mode 100644 src/main/scala/leon/termination/ComplexTerminationChecker.scala create mode 100644 src/main/scala/leon/termination/ComponentBuilder.scala create mode 100644 src/main/scala/leon/termination/ComponentProcessor.scala create mode 100644 src/main/scala/leon/termination/LoopProcessor.scala create mode 100644 src/main/scala/leon/termination/Processor.scala create mode 100644 src/main/scala/leon/termination/RecursionProcessor.scala create mode 100644 src/main/scala/leon/termination/RelationBuilder.scala create mode 100644 src/main/scala/leon/termination/RelationComparator.scala create mode 100644 src/main/scala/leon/termination/RelationProcessor.scala create mode 100644 src/main/scala/leon/termination/StructuralSize.scala diff --git a/src/main/scala/leon/purescala/Definitions.scala b/src/main/scala/leon/purescala/Definitions.scala index b62992618..cc8003c5a 100644 --- a/src/main/scala/leon/purescala/Definitions.scala +++ b/src/main/scala/leon/purescala/Definitions.scala @@ -113,7 +113,7 @@ object Definitions { funDef.precondition.map(treeCatamorphism[CallGraph](convert, combine, compute(funDef)_, _)).getOrElse(Set.empty) ++ funDef.body.map(treeCatamorphism[CallGraph](convert, combine, compute(funDef)_, _)).getOrElse(Set.empty) ++ funDef.postcondition.map( pc => treeCatamorphism[CallGraph](convert, combine, compute(funDef)_, pc._2)).getOrElse(Set.empty) - }).reduceLeft(_ ++ _) + }).foldLeft(Set[(FunDef, FunDef)]())(_ ++ _) var callers: Map[FunDef,Set[FunDef]] = new scala.collection.immutable.HashMap[FunDef,Set[FunDef]] diff --git a/src/main/scala/leon/purescala/Trees.scala b/src/main/scala/leon/purescala/Trees.scala index 65b25f25e..a6be5fbbd 100644 --- a/src/main/scala/leon/purescala/Trees.scala +++ b/src/main/scala/leon/purescala/Trees.scala @@ -63,7 +63,7 @@ object Trees { case class FunctionInvocation(funDef: FunDef, args: Seq[Expr]) extends Expr with FixedType with ScalacPositional { val fixedType = funDef.returnType - funDef.args.zip(args).foreach { case (a, c) => typeCheck(c, a.tpe) } + // funDef.args.zip(args).foreach { case (a, c) => typeCheck(c, a.tpe) } } case class IfExpr(cond: Expr, thenn: Expr, elze: Expr) extends Expr with FixedType { val fixedType = leastUpperBound(thenn.getType, elze.getType).getOrElse(AnyType) diff --git a/src/main/scala/leon/termination/ChainBuilder.scala b/src/main/scala/leon/termination/ChainBuilder.scala new file mode 100644 index 000000000..abd158530 --- /dev/null +++ b/src/main/scala/leon/termination/ChainBuilder.scala @@ -0,0 +1,114 @@ +package leon +package termination + +import leon.purescala.Definitions._ +import leon.purescala.Trees._ +import leon.purescala.TreeOps._ +import leon.purescala.Common._ + +final case class Chain(chain: List[Relation]) { + def funDef : FunDef = chain.head.funDef + def funDefs : Set[FunDef] = chain.map(_.funDef) toSet + + def loop(initialSubst: Map[Identifier, Expr] = Map(), finalSubst: Map[Identifier, Expr] = Map()) : Seq[Expr] = { + assert(initialSubst.nonEmpty || finalSubst.nonEmpty) + + def rec(relations: List[Relation], subst: Map[Identifier, Expr]): Seq[Expr] = relations match { + case Relation(_, path, FunctionInvocation(fd, args)) :: Nil => + assert(fd == funDef) + val newPath = path.map(replaceFromIDs(subst, _)) + val equalityConstraints = if (finalSubst.isEmpty) Seq() else { + val newArgs = args.map(replaceFromIDs(subst, _)) + (fd.args.map(arg => finalSubst(arg.id)) zip newArgs).map(p => Equals(p._1, p._2)) + } + newPath ++ equalityConstraints + case Relation(_, path, FunctionInvocation(fd, args)) :: xs => + val formalArgs = fd.args.map(_.id) + val freshFormalArgVars = formalArgs.map(_.freshen.toVariable) + val formalArgsMap: Map[Identifier, Expr] = formalArgs zip freshFormalArgVars toMap + val (newPath, newArgs) = (path.map(replaceFromIDs(subst, _)), args.map(replaceFromIDs(subst, _))) + val constraints = newPath ++ (freshFormalArgVars zip newArgs).map(p => Equals(p._1, p._2)) + constraints ++ rec(xs, formalArgsMap) + case Nil => sys.error("Empty chain shouldn't exist by construction") + } + + val subst : Map[Identifier, Expr] = if (initialSubst.nonEmpty) initialSubst else { + funDef.args.map(arg => arg.id -> arg.toVariable).toMap + } + val Chain(relations) = this + rec(relations, subst) + } + + def reentrant(other: Chain) : Seq[Expr] = { + assert(funDef == other.funDef) + val bindingSubst : Map[Identifier, Expr] = funDef.args.map({ + arg => arg.id -> arg.id.freshen.toVariable + }).toMap + val firstLoop = loop(finalSubst = bindingSubst) + val secondLoop = other.loop(initialSubst = bindingSubst) + firstLoop ++ secondLoop + } + + def times(k: Int, initialSubst: Map[Identifier, Expr] = Map(), finalSubst: Map[Identifier, Expr] = Map()) : Seq[Expr] = { + def rec(bindingSubst: Map[Identifier, Expr], count: Int) : Seq[Expr] = if (count == k) loop(initialSubst = bindingSubst, finalSubst = finalSubst) else { + val nextSubst : Map[Identifier, Expr] = funDef.args.map(arg => arg.id -> arg.id.freshen.toVariable).toMap + val currentLoop = loop(initialSubst = bindingSubst, finalSubst = nextSubst) + val rest = rec(nextSubst, count + 1) + currentLoop ++ rest + } + rec(initialSubst, 1) + } + + def inlined: TraversableOnce[Expr] = { + def rec(list: List[Relation], mapping: Map[Identifier, Expr]): List[Expr] = list match { + case Relation(_, _, FunctionInvocation(fd, args)) :: xs => + val mappedArgs = args.map(replaceFromIDs(mapping, _)) + val newMapping = fd.args.map(_.id).zip(mappedArgs).toMap + // We assume we have a body at this point. It would be weird to have gotten here without one... + val expr = hoistIte(expandLets(matchToIfThenElse(fd.getBody))) + val inlinedExpr = replaceFromIDs(newMapping, expr) + inlinedExpr:: rec(xs, newMapping) + case Nil => Nil + } + + val body = hoistIte(expandLets(matchToIfThenElse(funDef.getBody))) + body :: rec(chain, funDef.args.map(arg => arg.id -> arg.toVariable) toMap) + } +} + +object ChainBuilder { + import scala.collection.mutable.{Map => MutableMap} + + private val chainCache : MutableMap[FunDef, Set[Chain]] = MutableMap() + def run(funDef: FunDef): Set[Chain] = chainCache.get(funDef) match { + case Some(chains) => chains + case None => { + // Note that this method will generate duplicate cycles (in fact, it will generate all list representations of a cycle) + def chains(partials: List[(Relation, List[Relation])]): List[List[Relation]] = if (partials.isEmpty) Nil else { + // Note that chains in partials are reversed to profit from O(1) insertion + val (results, newPartials) = partials.foldLeft(List[List[Relation]](),List[(Relation, List[Relation])]())({ + case ((results, partials), (first, chain @ Relation(_, _, FunctionInvocation(fd, _)) :: xs)) => + val cycle = RelationBuilder.run(fd).contains(first) + // reverse the chain when "returning" it since we're working on reversed chains + val newResults = if (cycle) chain.reverse :: results else results + + // Partial chains can fall back onto a transition that was already taken (thus creating a cycle + // inside the chain). Since this cycle will be discovered elsewhere, such partial chains should be + // dropped from the partial chain list + val transitions = RelationBuilder.run(fd) -- chain.toSet + val newPartials = transitions.map(transition => (first, transition :: chain)).toList + + (newResults, partials ++ newPartials) + case (_, (_, Nil)) => scala.sys.error("Empty partial chain shouldn't exist by construction") + }) + + results ++ chains(newPartials) + } + + val initialPartials = RelationBuilder.run(funDef).map(r => (r, r :: Nil)).toList + val result = chains(initialPartials).map(Chain(_)).toSet + chainCache(funDef) = result + result + } + } +} diff --git a/src/main/scala/leon/termination/ChainComparator.scala b/src/main/scala/leon/termination/ChainComparator.scala new file mode 100644 index 000000000..901d72895 --- /dev/null +++ b/src/main/scala/leon/termination/ChainComparator.scala @@ -0,0 +1,203 @@ +package leon +package termination + +import purescala.Trees._ +import purescala.TreeOps._ +import purescala.TypeTrees._ +import purescala.Definitions._ +import purescala.Common._ + +object ChainComparator { + import StructuralSize._ + + def sizeDecreasing(e1: TypedExpr, e2s: Seq[(Seq[Expr], Expr)]) = _sizeDecreasing(e1, e2s map { + case (path, e2) => (path, exprToTypedExpr(e2)) + }) + def sizeDecreasing(e1: Expr, e2s: Seq[(Seq[Expr], Expr)]) = _sizeDecreasing(e1, e2s map { + case (path, e2) => (path, exprToTypedExpr(e2)) + }) + + private object ContainerType { + def unapply(c: ClassType): Option[(CaseClassDef, Seq[(Identifier, TypeTree)])] = c match { + case CaseClassType(classDef) => + if (classDef.fields.exists(arg => isSubtypeOf(arg.tpe, classDef.parent.map(AbstractClassType(_)).getOrElse(c)))) None + else if (classDef.hasParent && classDef.parent.get.knownChildren.size > 1) None + else Some((classDef, classDef.fields.map(arg => arg.id -> arg.tpe))) + case _ => None + } + } + + private def _sizeDecreasing(te1: TypedExpr, te2s: Seq[(Seq[Expr], TypedExpr)]) : Expr = te1 match { + case TypedExpr(e1, ContainerType(def1, types1)) => Or(types1.zipWithIndex map { case ((id1, type1), index) => + val newTe1 = TypedExpr(CaseClassSelector(def1, e1, id1), type1) + val newTe2s = te2s.map({ + case (path, TypedExpr(e2, ContainerType(def2, types2))) => + val (id2, type2) = types2(index) + (path, TypedExpr(CaseClassSelector(def2, e2, id2), type2)) + case (_, te2) => scala.sys.error("Unexpected input combinations: " + te1 + " " + te2) + }) + _sizeDecreasing(newTe1, newTe2s) + }) + case TypedExpr(e1, TupleType(types1)) => Or(types1.zipWithIndex map { case (type1, index) => + val newTe1 = TypedExpr(TupleSelect(e1, index + 1), type1) + val newTe2s = te2s.map({ + case (path, TypedExpr(e2, TupleType(types2))) => (path, TypedExpr(TupleSelect(e2, index + 1), types2(index))) + case (_, te2) => scala.sys.error("Unexpected input combinations: " + te1 + " " + te2) + }) + _sizeDecreasing(newTe1, newTe2s) + }) + case TypedExpr(_, _: ClassType) => And(te2s map { + case (path, te2 @ TypedExpr(_, _: ClassType)) => Implies(And(path), GreaterThan(size(te1), size(te2))) + case (_, te2) => scala.sys.error("Unexpected input combinations: " + te1 + " " + te2) + }) + case TypedExpr(_, BooleanType) => BooleanLiteral(false) + case TypedExpr(_, Int32Type) => BooleanLiteral(false) + case _ => scala.sys.error("Unexpected type " + te1.tpe) + } + + private sealed abstract class NumericEndpoint { + def inverse: NumericEndpoint = this match { + case UpperBoundEndpoint => LowerBoundEndpoint + case LowerBoundEndpoint => UpperBoundEndpoint + case InnerEndpoint => AnyEndpoint + case AnyEndpoint => InnerEndpoint + case NoEndpoint => NoEndpoint + } + def <(that: NumericEndpoint) : Boolean = (this, that) match { + case (UpperBoundEndpoint, AnyEndpoint) => true + case (LowerBoundEndpoint, AnyEndpoint) => true + case (InnerEndpoint, AnyEndpoint) => true + case (NoEndpoint, AnyEndpoint) => true + case (InnerEndpoint, UpperBoundEndpoint) => true + case (InnerEndpoint, LowerBoundEndpoint) => true + case (NoEndpoint, UpperBoundEndpoint) => true + case (NoEndpoint, LowerBoundEndpoint) => true + case (NoEndpoint, InnerEndpoint) => true + case _ => false + } + def <=(that: NumericEndpoint) : Boolean = (this, that) match { + case (t1, t2) if t1 < t2 => true + case (t1, t2) if t1 == t2 => true + case _ => false + } + def min(that: NumericEndpoint) : NumericEndpoint = { + if (this <= that) this else if (that <= this) that else InnerEndpoint + } + def max(that: NumericEndpoint) : NumericEndpoint = { + if (this <= that) that else if (that <= this) this else AnyEndpoint + } + } + + private case object UpperBoundEndpoint extends NumericEndpoint + private case object LowerBoundEndpoint extends NumericEndpoint + private case object InnerEndpoint extends NumericEndpoint + private case object AnyEndpoint extends NumericEndpoint + private case object NoEndpoint extends NumericEndpoint + + private def numericEndpoint(value: Expr, cluster: Set[Chain], checker: TerminationChecker) = { + + object Value { + val vars = variablesOf(value) + assert(vars.size == 1) + + def simplifyBinaryArithmetic(e1: Expr, e2: Expr) : Boolean = { + val (inE1, inE2) = (variablesOf(e1) == vars, variablesOf(e2) == vars) + if (inE1 && inE2) false else if (inE1) unapply(e1) else if (inE2) unapply(e2) else { + scala.sys.error("How the heck did we get here?!?") + } + } + + def unapply(expr: Expr): Boolean = if (variablesOf(expr) != vars) false else expr match { + case Plus(e1, e2) => simplifyBinaryArithmetic(e1, e2) + case Minus(e1, e2) => simplifyBinaryArithmetic(e1, e2) + // case Times(e1, e2) => ... Need to make sure multiplier is not negative! + case e => e == value + } + } + + def matches(expr: Expr) : NumericEndpoint = expr match { + case And(es) => es.map(matches(_)).foldLeft[NumericEndpoint](AnyEndpoint)(_ min _) + case Or(es) => es.map(matches(_)).foldLeft[NumericEndpoint](NoEndpoint)(_ max _) + case Not(e) => matches(e).inverse + case GreaterThan(Value(), e) if variablesOf(e).isEmpty => LowerBoundEndpoint + case GreaterThan(e, Value()) if variablesOf(e).isEmpty => UpperBoundEndpoint + case GreaterEquals(Value(), e) if variablesOf(e).isEmpty => LowerBoundEndpoint + case GreaterEquals(e, Value()) if variablesOf(e).isEmpty => UpperBoundEndpoint + case Equals(Value(), e) if variablesOf(e).isEmpty => InnerEndpoint + case Equals(e, Value()) if variablesOf(e).isEmpty => InnerEndpoint + case LessThan(e1, e2) => matches(GreaterThan(e2, e1)) + case LessEquals(e1, e2) => matches(GreaterEquals(e2, e1)) + case _ => NoEndpoint + } + + def endpoint(expr: Expr) : NumericEndpoint = expr match { + case IfExpr(cond, then, elze) => matches(cond) match { + case NoEndpoint => + endpoint(then) min endpoint(elze) + case ep => + val terminatingThen = functionCallsOf(then).forall(fi => checker.terminates(fi.funDef).isGuaranteed) + val terminatingElze = functionCallsOf(elze).forall(fi => checker.terminates(fi.funDef).isGuaranteed) + val thenEndpoint = if (terminatingThen) ep max endpoint(then) else endpoint(then) + val elzeEndpoint = if (terminatingElze) ep.inverse max endpoint(elze) else endpoint(elze) + thenEndpoint max elzeEndpoint + } + case _ => NoEndpoint + } + + cluster.foldLeft[NumericEndpoint](AnyEndpoint)((acc, chain) => { + acc min chain.inlined.foldLeft[NumericEndpoint](NoEndpoint)((acc, expr) => acc max endpoint(expr)) + }) + } + + def numericConverging(e1: TypedExpr, e2s: Seq[(Seq[Expr], Expr)], cluster: Set[Chain], checker: TerminationChecker) = _numericConverging(e1, e2s map { + case (path, e2) => (path, exprToTypedExpr(e2)) + }, cluster, checker) + def numericConverging(e1: Expr, e2s: Seq[(Seq[Expr], Expr)], cluster: Set[Chain], checker: TerminationChecker) = _numericConverging(e1, e2s map { + case (path, e2) => (path, exprToTypedExpr(e2)) + }, cluster, checker) + + private def _numericConverging(te1: TypedExpr, te2s: Seq[(Seq[Expr], TypedExpr)], cluster: Set[Chain], checker: TerminationChecker) : Expr = te1 match { + case TypedExpr(e1, ContainerType(def1, types1)) => Or(types1.zipWithIndex map { case ((id1, type1), index) => + val newTe1 = TypedExpr(CaseClassSelector(def1, e1, id1), type1) + val newTe2s = te2s.map({ + case (path, TypedExpr(e2, ContainerType(def2, types2))) => + val (id2, type2) = types2(index) + (path, TypedExpr(CaseClassSelector(def2, e2, id2), type2)) + case (_, te2) => scala.sys.error("Unexpected input combinations: " + te1 + " " + te2) + }) + _numericConverging(newTe1, newTe2s, cluster, checker) + }) + case TypedExpr(e1, TupleType(types1)) => Or(types1.zipWithIndex map { case (type1, index) => + val newTe1 = TypedExpr(TupleSelect(e1, index + 1), type1) + val newTe2s = te2s.map({ + case (path, TypedExpr(e2, TupleType(types2))) => (path, TypedExpr(TupleSelect(e2, index + 1), types2(index))) + case (_, te2) => scala.sys.error("Unexpected input combinations: " + te1 + " " + te2) + }) + _numericConverging(newTe1, newTe2s, cluster, checker) + }) + case TypedExpr(e1, Int32Type) => numericEndpoint(e1, cluster, checker) match { + case UpperBoundEndpoint => And(te2s map { + case (path, TypedExpr(e2, Int32Type)) => Implies(And(path), GreaterThan(e1, e2)) + case (_, te2) => scala.sys.error("Unexpected input combinations: " + te1 + " " + te2) + }) + case LowerBoundEndpoint => And(te2s map { + case (path, TypedExpr(e2, Int32Type)) => Implies(And(path), LessThan(e1, e2)) + case (_, te2) => scala.sys.error("Unexpected input combinations: " + te1 + " " + te2) + }) + case AnyEndpoint => Or(And(te2s map { + case (path, TypedExpr(e2, Int32Type)) => Implies(And(path), GreaterThan(e1, e2)) + case (_, te2) => scala.sys.error("Unexpected input combinations: " + te1 + " " + te2) + }), And(te2s map { + case (path, TypedExpr(e2, Int32Type)) => Implies(And(path), LessThan(e1, e2)) + case (_, te2) => scala.sys.error("Unexpected input combinations: " + te1 + " " + te2) + })) + case InnerEndpoint => BooleanLiteral(false) + case NoEndpoint => BooleanLiteral(false) + } + case TypedExpr(_, _: ClassType) => BooleanLiteral(false) + case TypedExpr(_, BooleanType) => BooleanLiteral(false) + case _ => scala.sys.error("Unexpected type " + te1.tpe) + } +} + +// vim: set ts=4 sw=4 et: diff --git a/src/main/scala/leon/termination/ChainProcessor.scala b/src/main/scala/leon/termination/ChainProcessor.scala new file mode 100644 index 000000000..b40ccf166 --- /dev/null +++ b/src/main/scala/leon/termination/ChainProcessor.scala @@ -0,0 +1,93 @@ +package leon +package termination + +import purescala.Trees._ +import purescala.TreeOps._ +import purescala.TypeTrees._ +import purescala.Common._ +import purescala.Extractors._ +import purescala.Definitions._ + +class ChainProcessor(checker: TerminationChecker) extends Processor(checker) with Solvable { + + val name: String = "Chain Processor" + + def run(problem: Problem) = { + val allChainMap : Map[FunDef, Set[Chain]] = problem.funDefs.map(funDef => funDef -> ChainBuilder.run(funDef)).toMap + val allChains : Set[Chain] = allChainMap.values.flatten.toSet + + // We check that loops can reenter themselves after a run. If not, then this is not a chain (since it will + // enter another chain and their conjunction is contained elsewhere in the chains set) + // Note: We are checking reentrance SAT, not looking for a counter example so we negate the formula! + val validChains : Set[Chain] = allChains.filter(chain => !solve(Not(And(chain reentrant chain))).isValid) + val chainMap : Map[FunDef, Set[Chain]] = allChainMap.mapValues(chains => chains intersect validChains) + + // We build a cross-chain map that determines which chains can reenter into another one after a loop. + // Note: We are also checking reentrance SAT here, so again, we negate the formula! + val crossChains : Map[Chain, Set[Chain]] = chainMap.map({ case (funDef, chains) => + chains.map(chain => chain -> (chains - chain).filter(other => !solve(Not(And(chain reentrant other))).isValid)) + }).flatten.toMap + + // We use the cross-chains to build chain clusters. For each cluster, we must prove that the SAME argument + // decreases in each of the chains in the cluster! + val clusters : Map[FunDef, Set[Set[Chain]]] = { + def cluster(set: Set[Chain]): Set[Chain] = { + set ++ set.map(crossChains(_)).flatten + } + + def fix[A](f: A => A, a: A): A = { + val na = f(a) + if (a == na) a else fix(f, na) + } + + def filterClusters(all: List[Set[Chain]]): List[Set[Chain]] = if (all.isEmpty) Nil else { + val newCluster = all.head + val rest = all.tail.filter(set => !set.subsetOf(newCluster)) + newCluster :: filterClusters(rest) + } + + def build(chains: Set[Chain]): Set[Set[Chain]] = { + val allClusters = chains.map(chain => fix(cluster, Set(chain))) + filterClusters(allClusters.toList.sortBy(- _.size)).toSet + } + + chainMap.map({ case (funDef, chains) => funDef -> build(chains) }) + } + + strengthenPostconditions(problem.funDefs) + + def buildLoops(fd: FunDef, cluster: Set[Chain]): (Expr, Seq[(Seq[Expr], Expr)]) = { + val e1 = Tuple(fd.args.map(_.toVariable)) + val e2s = cluster.toSeq.map({ chain => + val freshArgs : Seq[Expr] = fd.args.map(arg => arg.id.freshen.toVariable) + val finalBindings = (fd.args.map(_.id) zip freshArgs).toMap + val path = chain.loop(finalSubst = finalBindings) + path -> Tuple(freshArgs) + }) + + (e1, e2s) + } + + type ClusterMap = Map[FunDef, Set[Set[Chain]]] + type FormulaGenerator = (FunDef, Set[Chain]) => Expr + + def clear(clusters: ClusterMap, gen: FormulaGenerator): ClusterMap = clusters.map({ case (fd, clusters) => + val remaining = clusters.filter(cluster => !solve(gen(fd, cluster)).isValid) + fd -> remaining + }) + + val sizeCleared : ClusterMap = clear(clusters, (fd, cluster) => { + val (e1, e2s) = buildLoops(fd, cluster) + ChainComparator.sizeDecreasing(e1, e2s) + }) + + val numericCleared : ClusterMap = clear(sizeCleared, (fd, cluster) => { + val (e1, e2s) = buildLoops(fd, cluster) + ChainComparator.numericConverging(e1, e2s, cluster, checker) + }) + + val (okPairs, nokPairs) = numericCleared.partition(_._2.isEmpty) + val newProblems = if (nokPairs nonEmpty) List(Problem(nokPairs.map(_._1).toSet)) else Nil + (okPairs.map(p => Cleared(p._1)), newProblems) + } +} diff --git a/src/main/scala/leon/termination/ComplexTerminationChecker.scala b/src/main/scala/leon/termination/ComplexTerminationChecker.scala new file mode 100644 index 000000000..de9d00cc2 --- /dev/null +++ b/src/main/scala/leon/termination/ComplexTerminationChecker.scala @@ -0,0 +1,53 @@ +package leon +package termination + +import purescala.Definitions._ +import purescala.Trees._ + +class ComplexTerminationChecker(context: LeonContext, program: Program) extends TerminationChecker(context, program) { + import scala.collection.mutable.{Map => MutableMap} + + val name = "Complex Termination Checker" + val description = "A modular termination checker with a few basic modules™" + + private val pipeline = new ProcessingPipeline( + program, context, // required for solvers and reporting + new ComponentProcessor(this), + new RecursionProcessor(this), + new RelationProcessor(this), + new ChainProcessor(this), + new LoopProcessor(this) + ) + + private val clearedMap : MutableMap[FunDef, String] = MutableMap() + private val brokenMap : MutableMap[FunDef, (String, Seq[Expr])] = MutableMap() + def initialize() { + for ((reason, results) <- pipeline.run; result <- results) result match { + case Cleared(fd) => clearedMap(fd) = reason + case Broken(fd, args) => brokenMap(fd) = (reason, args) + } + } + + private val terminationMap : MutableMap[FunDef, TerminationGuarantee] = MutableMap() + def terminates(funDef: FunDef): TerminationGuarantee = terminationMap.get(funDef) match { + case Some(guarantee) => guarantee + case None => { + val guarantee = brokenMap.get(funDef) match { + case Some((reason, args)) => LoopsGivenInputs(reason, args) + case None => program.transitiveCallees(funDef) intersect brokenMap.keys.toSet match { + case set if set.nonEmpty => CallsNonTerminating(set) + case _ => if (pipeline.clear(funDef)) clearedMap.get(funDef) match { + case Some(reason) => Terminates(reason) + case None => scala.sys.error(funDef.id + " -> not problem, but not cleared or broken ??") + } else NoGuarantee + } + } + + if (guarantee != NoGuarantee) { + terminationMap(funDef) = guarantee + } + + guarantee + } + } +} diff --git a/src/main/scala/leon/termination/ComponentBuilder.scala b/src/main/scala/leon/termination/ComponentBuilder.scala new file mode 100644 index 000000000..b70ea455d --- /dev/null +++ b/src/main/scala/leon/termination/ComponentBuilder.scala @@ -0,0 +1,53 @@ +package leon +package termination + +/** This could be defined anywhere, it's just that the + termination checker is the only place where it is used. */ +object ComponentBuilder { + def run[T](graph : Map[T,Set[T]]) : List[Set[T]] = { + // The first part is a shameless adaptation from Wikipedia + val allVertices : Set[T] = graph.keySet ++ graph.values.flatten + + var index = 0 + var indices : Map[T,Int] = Map.empty + var lowLinks : Map[T,Int] = Map.empty + var components : List[Set[T]] = Nil + var s : List[T] = Nil + + def strongConnect(v : T) { + indices = indices.updated(v, index) + lowLinks = lowLinks.updated(v, index) + index += 1 + s = v :: s + + for(w <- graph.getOrElse(v, Set.empty)) { + if(!indices.isDefinedAt(w)) { + strongConnect(w) + lowLinks = lowLinks.updated(v, lowLinks(v) min lowLinks(w)) + } else if(s.contains(w)) { + lowLinks = lowLinks.updated(v, lowLinks(v) min indices(w)) + } + } + + if(lowLinks(v) == indices(v)) { + var c : Set[T] = Set.empty + var stop = false + do { + val x :: xs = s + c = c + x + s = xs + stop = (x == v) + } while(!stop); + components = c :: components + } + } + + for(v <- allVertices) { + if(!indices.isDefinedAt(v)) { + strongConnect(v) + } + } + + components + } +} diff --git a/src/main/scala/leon/termination/ComponentProcessor.scala b/src/main/scala/leon/termination/ComponentProcessor.scala new file mode 100644 index 000000000..58d1da4a5 --- /dev/null +++ b/src/main/scala/leon/termination/ComponentProcessor.scala @@ -0,0 +1,33 @@ +package leon +package termination + +import purescala.TreeOps._ +import purescala.Definitions._ + +class ComponentProcessor(checker: TerminationChecker) extends Processor(checker) { + + val name: String = "Component Processor" + + def run(problem: Problem) = { + val pairs : Set[(FunDef, FunDef)] = checker.program.callGraph.filter({ + case (fd1, fd2) => problem.funDefs(fd1) && problem.funDefs(fd2) + }) + val callGraph : Map[FunDef,Set[FunDef]] = pairs.groupBy(_._1).mapValues(_.map(_._2)) + val components : List[Set[FunDef]] = ComponentBuilder.run(callGraph) + val fdToSCC : Map[FunDef, Set[FunDef]] = components.map(set => set.map(fd => fd -> set)).flatten.toMap + + import scala.collection.mutable.{Map => MutableMap} + val terminationCache : MutableMap[FunDef, Boolean] = MutableMap() + def terminates(fd: FunDef) : Boolean = terminationCache.getOrElse(fd, { + val scc = fdToSCC.getOrElse(fd, Set()) // functions that aren't called and don't call belong to no SCC + val result = if (scc(fd)) false else scc.forall(terminates(_)) + terminationCache(fd) = result + result + }) + + val terminating = problem.funDefs.filter(terminates(_)) + assert(components.forall(scc => (scc subsetOf terminating) || (scc intersect terminating isEmpty))) + val newProblems = components.filter(scc => scc intersect terminating isEmpty).map(Problem(_)) + (terminating.map(Cleared(_)), newProblems) + } +} diff --git a/src/main/scala/leon/termination/LoopProcessor.scala b/src/main/scala/leon/termination/LoopProcessor.scala new file mode 100644 index 000000000..f0791056f --- /dev/null +++ b/src/main/scala/leon/termination/LoopProcessor.scala @@ -0,0 +1,50 @@ +package leon +package termination + +import purescala.Definitions._ +import purescala.Trees._ +import purescala.TreeOps._ + +class LoopProcessor(checker: TerminationChecker, k: Int = 10) extends Processor(checker) with Solvable { + + val name: String = "Loop Processor" + + def run(problem: Problem) = { + val allChains : Set[Chain] = problem.funDefs.map(fd => ChainBuilder.run(fd)).flatten + // Get reentrant loops (see ChainProcessor for more details) + val chains : Set[Chain] = allChains.filter(chain => !solve(Not(And(chain reentrant chain))).isValid) + + def findLoops(chains: Set[Chain]) = { + def rec(chains: Set[Chain], count: Int): Map[FunDef, Seq[Expr]] = if (count == k) Map() else { + val nonTerminating = chains.flatMap({ chain => + val freshArgs : Seq[Expr] = chain.funDef.args.map(arg => arg.id.freshen.toVariable) + val finalBindings = (chain.funDef.args.map(_.id) zip freshArgs).toMap + val path = chain.times(count, finalSubst = finalBindings) + val formula = And(path :+ Equals(Tuple(chain.funDef.args.map(_.toVariable)), Tuple(freshArgs))) + + val solvable = functionCallsOf(formula).forall({ + case FunctionInvocation(fd, args) => checker.terminates(fd).isGuaranteed + }) + + if (!solvable) None else solve(Not(formula)) match { + case Solution(false, model) => Some(chain.funDef, chain.funDef.args.map(arg => model(arg.id))) + case _ => None + } + }).toMap + + val remainingChains = chains.filter(chain => nonTerminating.contains(chain.funDef)) + nonTerminating ++ rec(remainingChains, count + 1) + } + + rec(chains, 1) + } + + val nonTerminating = findLoops(chains) + val results = nonTerminating.map({ case (funDef, args) => Broken(funDef, args) }) + val remaining = problem.funDefs -- nonTerminating.keys + val newProblems = if (remaining.nonEmpty) List(Problem(remaining)) else Nil + (results, newProblems) + } +} + +// vim: set ts=4 sw=4 et: diff --git a/src/main/scala/leon/termination/Processor.scala b/src/main/scala/leon/termination/Processor.scala new file mode 100644 index 000000000..760a44e37 --- /dev/null +++ b/src/main/scala/leon/termination/Processor.scala @@ -0,0 +1,210 @@ +package leon +package termination + +import purescala.Trees._ +import purescala.TreeOps._ +import purescala.Common._ +import purescala.Definitions._ + +import leon.solvers._ +import leon.solvers.z3._ + +case class Problem(funDefs: Set[FunDef]) { + override def toString : String = funDefs.map(_.id).mkString("Problem(", ",", ")") +} + +sealed abstract class Result(funDef: FunDef) +case class Cleared(funDef: FunDef) extends Result(funDef) +case class Broken(funDef: FunDef, args: Seq[Expr]) extends Result(funDef) + +abstract class Processor(val checker: TerminationChecker) { + + val name: String + + def run(problem: Problem): (TraversableOnce[Result], TraversableOnce[Problem]) +} + +class Solution(solution: Option[Boolean], val model: Map[Identifier, Expr]) { + lazy val isValid : Boolean = solution getOrElse false +} + +object NoSolution extends Solution(None, Map()) + +object Solution { + def unapply(s: Solution): Option[(Boolean, Map[Identifier, Expr])] = { + if (s == NoSolution) None + else Some(s.isValid, s.model) + } +} + +object Solvable { + import scala.collection.mutable.{Set => MutableSet} + + private val strengthened : MutableSet[FunDef] = MutableSet() + private def strengthenPostcondition(funDef: FunDef, cmp: (Expr, TypedExpr) => Expr) + (implicit solver: Processor with Solvable) : Boolean = if (!funDef.hasBody) false else { + assert(solver.checker.terminates(funDef).isGuaranteed) + + val postcondition = funDef.postcondition + val args = funDef.args.map(_.toVariable) + val typedResult = TypedExpr(ResultVariable(), funDef.returnType) + val sizePost = cmp(Tuple(args), typedResult) + funDef.postcondition = Some(And(postcondition.toSeq :+ sizePost)) + + val prec = matchToIfThenElse(funDef.precondition.getOrElse(BooleanLiteral(true))) + val post = matchToIfThenElse(funDef.postcondition.get) + val body = matchToIfThenElse(funDef.body.get) + val resFresh = FreshIdentifier("result", true).setType(body.getType) + val formula = Implies(prec, Let(resFresh, body, replace(Map(ResultVariable() -> Variable(resFresh)), post))) + + if (!solver.solve(formula).isValid) { + funDef.postcondition = postcondition + strengthened.add(funDef) + false + } else { + strengthened.add(funDef) + true + } + } + + def strengthenPostconditions(funDefs: Set[FunDef])(implicit solver: Processor with Solvable) { + // Strengthen postconditions on all accessible functions by adding size constraints + val callees : Set[FunDef] = funDefs.map(fd => solver.checker.program.transitiveCallees(fd)).flatten + val sortedCallees : Seq[FunDef] = callees.toSeq.sortWith((fd1, fd2) => solver.checker.program.transitivelyCalls(fd2, fd1)) + for (funDef <- sortedCallees if !strengthened(funDef) && funDef.hasBody && solver.checker.terminates(funDef).isGuaranteed) { + // test if size is smaller or equal to input + val weekConstraintHolds = strengthenPostcondition(funDef, RelationComparator.softDecreasing) + + if (weekConstraintHolds) { + // try to improve postcondition with strictly smaller + strengthenPostcondition(funDef, RelationComparator.sizeDecreasing) + } + } + } +} + +trait Solvable { self: Processor => + + def strengthenPostconditions(funDefs: Set[FunDef]) = Solvable.strengthenPostconditions(funDefs)(this) + + def solve(problem: Expr): Solution = { + val program : Program = self.checker.program + val allDefs : Seq[Definition] = program.mainObject.defs ++ StructuralSize.defs + val newProgram : Program = program.copy(mainObject = program.mainObject.copy(defs = allDefs)) + + val solvers0 = new TrivialSolver(self.checker.context) :: new FairZ3Solver(self.checker.context) :: Nil + val solvers = solvers0.map(new TimeoutSolver(_, 500)) + solvers.foreach(_.setProgram(newProgram)) + + // drop functions from constraints that might not terminate (and may therefore + // make Leon unroll them forever...) + val dangerousCallsMap : Map[Expr, Expr] = functionCallsOf(problem).collect({ + // extra definitions (namely size functions) are quaranteed to terminate because structures are non-looping + case fi @ FunctionInvocation(fd, args) if !StructuralSize.defs(fd) && !self.checker.terminates(fd).isGuaranteed => + fi -> FreshIdentifier("noRun", true).setType(fi.getType).toVariable + }).toMap + + val expr = searchAndReplace(dangerousCallsMap.get, recursive=false)(problem) + + object Solved { + var superseeded : Set[String] = Set.empty[String] + def unapply(se: Solver): Option[Solution] = { + if(superseeded(se.name) || superseeded(se.description)) { + None + } else { + superseeded = superseeded ++ Set(se.superseeds: _*) + + se.init() + val (satResult, model) = se.solveSAT(Not(expr)) + val solverResult = satResult.map(!_) + + if (!solverResult.isDefined) None + else Some(new Solution(solverResult, model)) + } + } + } + + solvers.collectFirst({ case Solved(result) => result }) getOrElse NoSolution + } +} + +class ProcessingPipeline(program: Program, context: LeonContext, _processors: Processor*) { + import scala.collection.mutable.{Queue => MutableQueue} + + assert(_processors.nonEmpty) + private val processors: Array[Processor] = _processors.toArray + private val reporter: Reporter = context.reporter + + private val initialProblem : Problem = Problem(program.definedFunctions.toSet) + private val problems : MutableQueue[(Problem,Int)] = MutableQueue((initialProblem, 0)) + private var unsolved : Set[Problem] = Set() + + private def printQueue { + val sb = new StringBuilder() + sb.append("- Problems in Queue:\n") + for((problem, index) <- problems) { + sb.append(" -> Problem awaiting processor #") + sb.append(index + 1) + sb.append(" (") + sb.append(processors(index).name) + sb.append(")\n") + for(funDef <- problem.funDefs) { + sb.append(" " + funDef.id + "\n") + } + } + reporter.info(sb.toString) + } + + private def printResult(results: List[Result]) { + val sb = new StringBuilder() + sb.append("- Queue.head Processing Result:\n") + for(result <- results) result match { + case Cleared(fd) => sb.append(" %-10s %s\n".format(fd.id, "Cleared")) + case Broken(fd, args) => sb.append(" %-10s %s\n".format(fd.id, "Broken for arguments: " + args.mkString("(", ",", ")"))) + } + reporter.info(sb.toString) + } + + def clear(fd: FunDef) : Boolean = { + lazy val unsolvedDefs = unsolved.map(_.funDefs).flatten.toSet + lazy val problemDefs = problems.map({ case (problem, _) => problem.funDefs }).flatten.toSet + def issue(defs: Set[FunDef]) : Boolean = defs(fd) || (defs intersect program.transitiveCallees(fd) nonEmpty) + ! (issue(unsolvedDefs) || issue(problemDefs)) + } + + def run : Iterator[(String, List[Result])] = new Iterator[(String, List[Result])] { + // basic sanity check, funDefs can't call themselves in precondition! + assert(initialProblem.funDefs.forall(fd => !fd.precondition.map({ precondition => + functionCallsOf(precondition).map(fi => program.transitiveCallees(fi.funDef)).flatten + }).flatten.toSet(fd))) + + def hasNext : Boolean = problems.nonEmpty + def next : (String, List[Result]) = { + printQueue + val (problem, index) = problems.head + val processor : Processor = processors(index) + val (_results, nextProblems) = processor.run(problem) + val results = _results.toList + printResult(results) + + // dequeue and enqueue atomically to make sure the queue always + // makes sense (necessary for calls to clear(fd)) + problems.dequeue + nextProblems match { + case x :: xs if x == problem => + assert(xs.isEmpty) + if (index == processors.size - 1) unsolved += x + else problems.enqueue(x -> (index + 1)) + case list @ x :: xs => + problems.enqueue(list.map(p => (p -> 0)) : _*) + problems.enqueue(unsolved.map(p => (p -> 0)).toSeq : _*) + unsolved = Set() + case Nil => // no problem => do nothing! + } + + processor.name -> results.toList + } + } +} + +// vim: set ts=4 sw=4 et: diff --git a/src/main/scala/leon/termination/RecursionProcessor.scala b/src/main/scala/leon/termination/RecursionProcessor.scala new file mode 100644 index 000000000..569e660c7 --- /dev/null +++ b/src/main/scala/leon/termination/RecursionProcessor.scala @@ -0,0 +1,43 @@ +package leon +package termination + +import purescala.Trees._ +import purescala.Common._ +import purescala.Definitions._ + +import scala.annotation.tailrec + +class RecursionProcessor(checker: TerminationChecker) extends Processor(checker) { + + val name: String = "Recursion Processor" + + private def isSubtreeOf(expr: Expr, id: Identifier) : Boolean = { + @tailrec + def rec(e: Expr, fst: Boolean): Boolean = e match { + case Variable(aid) if aid == id => !fst + case CaseClassSelector(_, cc, _) => rec(cc, false) + case _ => false + } + rec(expr, true) + } + + def run(problem: Problem) = if (problem.funDefs.size > 1) (Nil, List(problem)) else { + val funDef = problem.funDefs.head + + val selfRecursiveRelations = RelationBuilder.run(funDef).filter({ + case Relation(_, _, FunctionInvocation(fd, _)) => + fd == funDef || checker.terminates(fd).isGuaranteed + }) + + val decreases = funDef.args.zipWithIndex.exists({ case (arg, index) => + selfRecursiveRelations.forall({ case Relation(_, _, FunctionInvocation(_, args)) => + isSubtreeOf(args(index), arg.id) + }) + }) + + if (!decreases) (Nil, List(problem)) + else (Cleared(funDef) :: Nil, Nil) + } +} + +// vim: set ts=4 sw=4 et: diff --git a/src/main/scala/leon/termination/RelationBuilder.scala b/src/main/scala/leon/termination/RelationBuilder.scala new file mode 100644 index 000000000..3483197e8 --- /dev/null +++ b/src/main/scala/leon/termination/RelationBuilder.scala @@ -0,0 +1,76 @@ +package leon +package termination + +import purescala.Definitions._ +import purescala.Trees._ +import purescala.TreeOps._ +import purescala.Extractors._ +import purescala.Common._ + +final case class Relation(funDef: FunDef, path: Seq[Expr], call: FunctionInvocation) { + override def toString : String = "Relation(" + funDef.id + "," + path + ", " + call.funDef.id + call.args.mkString("(",",",")") + ")" +} + +object RelationBuilder { + import scala.collection.mutable.{Map => MutableMap} + val relationCache : MutableMap[FunDef, Set[Relation]] = MutableMap() + + def run(funDef: FunDef): Set[Relation] = relationCache.get(funDef) match { + case Some(relations) => relations + case None => { + def visit(e: Expr, path: List[Expr]): Set[Relation] = e match { + + // we skip functions that aren't in this SCC since the call relations + // associated to them don't interest us. + case fi @ FunctionInvocation(f, args) => + val argRelations = args.flatMap(visit(_, path)).toSet + argRelations + Relation(funDef, path.reverse.toSeq, fi) + + case Let(i, e, b) => + val ve = visit(e, path) + val vb = visit(b, Equals(Variable(i), e) :: path) + ve ++ vb + + case IfExpr(cond, then, elze) => + val vc = visit(cond, path) + val vt = visit(then, cond :: path) + val ve = visit(elze, Not(cond) :: path) + vc ++ vt ++ ve + + case And(es) => + def resolveAnds(ands: List[Expr], p: List[Expr]): Set[Relation] = ands match { + case x :: xs => visit(x, p ++ path) ++ resolveAnds(xs, x :: p) + case Nil => Set() + } + resolveAnds(es toList, Nil) + + case Or(es) => + def resolveOrs(ors: List[Expr], p: List[Expr]): Set[Relation] = ors match { + case x :: xs => visit(x, p ++ path) ++ resolveOrs(xs, Not(x) :: p) + case Nil => Set() + } + resolveOrs(es toList, Nil) + + case UnaryOperator(e, _) => visit(e, path) + + case BinaryOperator(e1, e2, _) => visit(e1, path) ++ visit(e2, path) + + case NAryOperator(es, _) => es.map(visit(_, path)).flatten.toSet + + case t : Terminal => Set() + + case _ => sys.error("Expression "+e+" ["+e.getClass+"] is not extractable") + } + + // TODO: throw error if we see funDef in precondition or postcondition + + val precondition = funDef.precondition getOrElse BooleanLiteral(true) + val precRelations = funDef.precondition.map(e => visit(simplifyLets(matchToIfThenElse(e)), Nil)).flatten.toSet + val bodyRelations = funDef.body.map(e => visit(simplifyLets(matchToIfThenElse(e)), List(precondition))).flatten.toSet + val postRelations = funDef.postcondition.map(e => visit(simplifyLets(matchToIfThenElse(e)), Nil)).flatten.toSet + val relations = precRelations ++ bodyRelations ++ postRelations + relationCache(funDef) = relations + relations + } + } +} diff --git a/src/main/scala/leon/termination/RelationComparator.scala b/src/main/scala/leon/termination/RelationComparator.scala new file mode 100644 index 000000000..6266e1586 --- /dev/null +++ b/src/main/scala/leon/termination/RelationComparator.scala @@ -0,0 +1,24 @@ +package leon +package termination + +import purescala.Trees._ +import purescala.TreeOps._ +import purescala.TypeTrees._ +import purescala.Definitions._ +import purescala.Common._ + +object RelationComparator { + import StructuralSize._ + + def sizeDecreasing(e1: TypedExpr, e2: TypedExpr) = GreaterThan(size(e1), size(e2)) + def sizeDecreasing(e1: Expr, e2: TypedExpr) = GreaterThan(size(e1), size(e2)) + def sizeDecreasing(e1: TypedExpr, e2: Expr) = GreaterThan(size(e1), size(e2)) + def sizeDecreasing(e1: Expr, e2: Expr) = GreaterThan(size(e1), size(e2)) + + def softDecreasing(e1: TypedExpr, e2: TypedExpr) = GreaterEquals(size(e1), size(e2)) + def softDecreasing(e1: Expr, e2: TypedExpr) = GreaterEquals(size(e1), size(e2)) + def softDecreasing(e1: TypedExpr, e2: Expr) = GreaterEquals(size(e1), size(e2)) + def softDecreasing(e1: Expr, e2: Expr) = GreaterEquals(size(e1), size(e2)) +} + +// vim: set ts=4 sw=4 et: diff --git a/src/main/scala/leon/termination/RelationProcessor.scala b/src/main/scala/leon/termination/RelationProcessor.scala new file mode 100644 index 000000000..973ca4105 --- /dev/null +++ b/src/main/scala/leon/termination/RelationProcessor.scala @@ -0,0 +1,86 @@ +package leon +package termination + +import leon.purescala.Trees._ +import leon.purescala.TreeOps._ +import leon.purescala.TypeTrees._ +import leon.purescala.Common._ +import leon.purescala.Extractors._ +import leon.purescala.Definitions._ + +class RelationProcessor(checker: TerminationChecker) extends Processor(checker) with Solvable { + + val name: String = "Relation Processor" + + def run(problem: Problem) = { + + strengthenPostconditions(problem.funDefs) + + val formulas = problem.funDefs.map({ funDef => + funDef -> RelationBuilder.run(funDef).collect({ + case Relation(_, path, FunctionInvocation(fd, args)) if problem.funDefs(fd) => + val (e1, e2) = (Tuple(funDef.args.map(_.toVariable)), Tuple(args)) + def constraint(expr: Expr) = Implies(And(path.toSeq), expr) + val greaterThan = RelationComparator.sizeDecreasing(e1, e2) + val greaterEquals = RelationComparator.softDecreasing(e1, e2) + (fd, (constraint(greaterThan), constraint(greaterEquals))) + }) + }) + + sealed abstract class Result + case object Success extends Result + case class Dep(deps: Set[FunDef]) extends Result + case object Failure extends Result + + val decreasing = formulas.map({ case (fd, formulas) => + val solved = formulas.map({ case (fid, (gt, ge)) => + if(solve(gt).isValid) Success + else if(solve(ge).isValid) Dep(Set(fid)) + else Failure + }) + val result = if(solved.contains(Failure)) Failure else { + val deps = solved.collect({ case Dep(fds) => fds }).flatten + if(deps.isEmpty) Success + else Dep(deps) + } + fd -> result + }) + + val (terminating, nonTerminating) = { + def currentReducing(fds: Set[FunDef], deps: List[(FunDef, Set[FunDef])]): (Set[FunDef], List[(FunDef, Set[FunDef])]) = { + val (okDeps, nokDeps) = deps.partition({ case (fd, deps) => deps.subsetOf(fds) }) + val newFds = fds ++ okDeps.map(_._1) + (newFds, nokDeps) + } + + def fix[A,B](f: (A,B) => (A,B), a: A, b: B): (A,B) = { + val (na, nb) = f(a, b) + if(na == a && nb == b) (a,b) else fix(f, na, nb) + } + + val ok = decreasing.collect({ case (fd, Success) => fd }).toSet + val nok = decreasing.collect({ case (fd, Dep(fds)) => fd -> fds }).toList + val (allOk, allNok) = fix(currentReducing, ok, nok) + (allOk, allNok.map(_._1).toSet ++ decreasing.collect({ case (fd, Failure) => fd })) + } + + assert(terminating ++ nonTerminating == problem.funDefs) + + val results = terminating.map(Cleared(_)).toList + val newProblems = if (problem.funDefs intersect nonTerminating nonEmpty) List(Problem(nonTerminating)) else Nil + (results, newProblems) + + /* + val noIncrease = gtformulas.forall(solvers.solve(_._2)) + if(noIncrease) { + val isReducing = eqformulas.map(x => x._1 -> solvers.solve(x._2)) + if(isReducing.exists(!_._2)) { + val (ok,nok) = isReducing.partition(_._2) match { case (xs, ys) => (xs.map(_._1), ys.map(_._1)) } + ProcessingResult(Nil, ok.map(Conditional(_, nok)) toList, List(problem filter nok)) + } else if(noArgs.nonEmpty) { + ProcessingResult(Nil, functionsOfInterest.map(Conditional(_, noArgs)) toList, List(problem filter noArgs)) + } else ProcessingResult(problem.callers.map(Cleared(_, "size relation formula solved")) toList, Nil, Nil) + } else ProcessingResult(Nil, Nil, List(problem)) + */ + } +} diff --git a/src/main/scala/leon/termination/SCC.scala b/src/main/scala/leon/termination/SCC.scala index 54a13ec41..201d55d96 100644 --- a/src/main/scala/leon/termination/SCC.scala +++ b/src/main/scala/leon/termination/SCC.scala @@ -6,7 +6,7 @@ package termination /** This could be defined anywhere, it's just that the termination checker is the only place where it is used. */ object SCC { - def scc[T](graph : Map[T,Set[T]]) : (Array[Set[T]],Map[T,Int],Map[Int,Set[Int]]) = { + def scc[T](graph : Map[T,Set[T]]) : List[Set[T]] = { // The first part is a shameless adaptation from Wikipedia val allVertices : Set[T] = graph.keySet ++ graph.values.flatten @@ -50,21 +50,6 @@ object SCC { } } - // At this point, we have our components. - // We finish by building a graph between them. - // In the graph, components are represented as arrays indices. - val asArray = components.toArray - val cSize = asArray.length - - val vertIDs : Map[T,Int] = allVertices.map(v => - v -> (0 until cSize).find(i => asArray(i)(v)).get - ).toMap - - val bigCallGraph : Map[Int,Set[Int]] = (0 until cSize).map({ i => - val dsts = asArray(i).flatMap(v => graph.getOrElse(v, Set.empty)).map(vertIDs(_)) - i -> dsts - }).toMap - - (asArray,vertIDs,bigCallGraph) + components } } diff --git a/src/main/scala/leon/termination/SimpleTerminationChecker.scala b/src/main/scala/leon/termination/SimpleTerminationChecker.scala index a72d12d2c..d8c5fce24 100644 --- a/src/main/scala/leon/termination/SimpleTerminationChecker.scala +++ b/src/main/scala/leon/termination/SimpleTerminationChecker.scala @@ -8,45 +8,55 @@ import purescala.Definitions._ import purescala.Trees._ import purescala.TreeOps._ -import scala.collection.mutable.{Map=>MutableMap} +import scala.collection.mutable.{ Map => MutableMap } import scala.annotation.tailrec -class SimpleTerminationChecker(context : LeonContext, program : Program) extends TerminationChecker(context, program) { +class SimpleTerminationChecker(context: LeonContext, program: Program) extends TerminationChecker(context, program) { val name = "T1" val description = "The simplest form of Terminator™" - private lazy val callGraph : Map[FunDef,Set[FunDef]] = + private lazy val callGraph: Map[FunDef, Set[FunDef]] = program.callGraph.groupBy(_._1).mapValues(_.map(_._2)) // one liner from hell - private lazy val sccTriple = SCC.scc(callGraph) - private lazy val sccArray : Array[Set[FunDef]] = sccTriple._1 - private lazy val funDefToSCCIndex : Map[FunDef,Int] = sccTriple._2 - private lazy val sccGraph : Map[Int,Set[Int]] = sccTriple._3 + private lazy val components = ComponentBuilder.run(callGraph) + val allVertices = callGraph.keySet ++ callGraph.values.flatten - private def callees(funDef : FunDef) : Set[FunDef] = callGraph.getOrElse(funDef, Set.empty) + val sccArray = components.toArray + val cSize = sccArray.length - private val answerCache = MutableMap.empty[FunDef,TerminationGuarantee] + val funDefToSCCIndex = (callGraph.keySet ++ callGraph.values.flatten).map(v => + v -> (0 until cSize).find(i => sccArray(i)(v)).get).toMap - def terminates(funDef : FunDef) = answerCache.getOrElse(funDef, { + val sccGraph = (0 until cSize).map({ i => + val dsts = sccArray(i).flatMap(v => callGraph.getOrElse(v, Set.empty)).map(funDefToSCCIndex(_)) + i -> dsts + }).toMap + + private def callees(funDef: FunDef): Set[FunDef] = callGraph.getOrElse(funDef, Set.empty) + + private val answerCache = MutableMap.empty[FunDef, TerminationGuarantee] + + def initialize() {} + def terminates(funDef: FunDef) = answerCache.getOrElse(funDef, { val g = forceCheckTermination(funDef) answerCache(funDef) = g g }) - private def forceCheckTermination(funDef : FunDef) : TerminationGuarantee = { + private def forceCheckTermination(funDef: FunDef): TerminationGuarantee = { // We would have to clarify what it means to terminate. // We would probably need something along the lines of: // "Terminates for all values satisfying prec." - if(funDef.hasPrecondition) + if (funDef.hasPrecondition) return NoGuarantee // This is also too confusing for me to think about now. - if(!funDef.hasImplementation) + if (!funDef.hasImplementation) return NoGuarantee - val sccIndex = funDefToSCCIndex.getOrElse(funDef, { + val sccIndex = funDefToSCCIndex.getOrElse(funDef, { return NoGuarantee }) val sccCallees = sccGraph(sccIndex) @@ -56,34 +66,32 @@ class SimpleTerminationChecker(context : LeonContext, program : Program) extends val sccLowerCallees = sccCallees.filterNot(_ == sccIndex) val lowerDefs = sccLowerCallees.map(sccArray(_)).foldLeft(Set.empty[FunDef])(_ ++ _) val lowerOK = lowerDefs.forall(terminates(_).isGuaranteed) - if(!lowerOK) + if (!lowerOK) return NoGuarantee // Now all we need to do is check the functions in the same // scc. But who knows, maybe none of these are called? - if(!sccCallees(sccIndex)) { + if (!sccCallees(sccIndex)) { // (the distinction isn't exactly useful...) - if(sccCallees.isEmpty) - return TerminatesForAllInputs("no calls") + if (sccCallees.isEmpty) + return Terminates("no calls") else - return TerminatesForAllInputs("by subcalls") + return Terminates("by subcalls") } // So now we know the function is recursive (or mutually // recursive). Maybe it's just self-recursive? - if(sccArray(sccIndex).size == 1) { + if (sccArray(sccIndex).size == 1) { assert(sccArray(sccIndex) == Set(funDef)) // Yes it is ! // Now we apply a simple recipe: we check that in each (self) // call, at least one argument is of an ADT type and decreases. // Yes, it's that restrictive. - val callsOfInterest = { (oe : Option[Expr]) => + val callsOfInterest = { (oe: Option[Expr]) => oe.map { e => functionCallsOf( simplifyLets( - matchToIfThenElse(e) - ) - ).filter(_.funDef == funDef) + matchToIfThenElse(e))).filter(_.funDef == funDef) } getOrElse Set.empty[FunctionInvocation] } @@ -91,25 +99,25 @@ class SimpleTerminationChecker(context : LeonContext, program : Program) extends val funDefArgsIDs = funDef.args.map(_.id).toSet - if(callsToAnalyze.forall { fi => + if (callsToAnalyze.forall { fi => fi.args.exists { arg => isSubTreeOfArg(arg, funDefArgsIDs) } }) { - return TerminatesForAllInputs("decreasing") + return Terminates("decreasing") } else { return NoGuarantee } } // Handling mutually recursive functions is beyond my willpower. - NoGuarantee + NoGuarantee } - private def isSubTreeOfArg(expr : Expr, args : Set[Identifier]) : Boolean = { + private def isSubTreeOfArg(expr: Expr, args: Set[Identifier]): Boolean = { @tailrec - def rec(e : Expr, fst : Boolean) : Boolean = e match { - case Variable(id) if(args(id)) => !fst + def rec(e: Expr, fst: Boolean): Boolean = e match { + case Variable(id) if (args(id)) => !fst case CaseClassSelector(_, cc, _) => rec(cc, false) case _ => false } diff --git a/src/main/scala/leon/termination/StructuralSize.scala b/src/main/scala/leon/termination/StructuralSize.scala new file mode 100644 index 000000000..60bc506b7 --- /dev/null +++ b/src/main/scala/leon/termination/StructuralSize.scala @@ -0,0 +1,65 @@ +package leon +package termination + +import purescala.Trees._ +import purescala.TreeOps._ +import purescala.TypeTrees._ +import purescala.Definitions._ +import purescala.Common._ + +case class TypedExpr(expr: Expr, tpe: TypeTree) + +object StructuralSize { + import scala.collection.mutable.{Map => MutableMap} + + implicit def exprToTypedExpr(expr: Expr): TypedExpr = { + assert(expr.getType != Untyped) + TypedExpr(expr, expr.getType) + } + + private val sizeFunctionCache : MutableMap[TypeTree, FunDef] = MutableMap() + def size(typedExpr: TypedExpr) : Expr = { + def funDef(tpe: TypeTree, cases: => Seq[MatchCase]) = sizeFunctionCache.get(tpe) match { + case Some(fd) => fd + case None => + val argument = VarDecl(FreshIdentifier("x"), tpe) + val fd = new FunDef(FreshIdentifier("size", true), Int32Type, Seq(argument)) + sizeFunctionCache(tpe) = fd + + val body = simplifyLets(matchToIfThenElse(MatchExpr(argument.toVariable, cases))) + val postSubcalls = functionCallsOf(body).map(GreaterThan(_, IntLiteral(0))).toSeq + val postRecursive = GreaterThan(ResultVariable(), IntLiteral(0)) + val postcondition = And(postSubcalls :+ postRecursive) + + fd.body = Some(body) + fd.postcondition = Some(postcondition) + fd + } + + def caseClassType2MatchCase(_c: ClassTypeDef): MatchCase = { + val c = _c.asInstanceOf[CaseClassDef] // required by leon framework + val arguments = c.fields.map(f => f -> f.id.freshen) + val argumentPatterns = arguments.map(p => WildcardPattern(Some(p._2))) + val sizes = arguments.map(p => size(TypedExpr(Variable(p._2), p._1.tpe))) + val result = sizes.foldLeft[Expr](IntLiteral(1))(Plus(_,_)) + SimpleCase(CaseClassPattern(None, c, argumentPatterns), result) + } + + typedExpr match { + case TypedExpr(expr, a: AbstractClassType) => + val sizeFd = funDef(a, a.classDef.knownChildren map caseClassType2MatchCase) + FunctionInvocation(sizeFd, Seq(expr)) + case TypedExpr(expr, c: CaseClassType) => + val sizeFd = funDef(c, Seq(caseClassType2MatchCase(c.classDef))) + FunctionInvocation(sizeFd, Seq(expr)) + case TypedExpr(expr, TupleType(argTypes)) => argTypes.zipWithIndex.map({ + case (tpe, index) => size(TypedExpr(TupleSelect(expr, index + 1), tpe)) + }).foldLeft[Expr](IntLiteral(0))(Plus(_,_)) + case _ => IntLiteral(0) + } + } + + def defs : Set[FunDef] = Set(sizeFunctionCache.values.toSeq : _*) +} + +// vim: set ts=4 sw=4 et: diff --git a/src/main/scala/leon/termination/TerminationChecker.scala b/src/main/scala/leon/termination/TerminationChecker.scala index 442f59ab8..0b4d7878b 100644 --- a/src/main/scala/leon/termination/TerminationChecker.scala +++ b/src/main/scala/leon/termination/TerminationChecker.scala @@ -4,17 +4,32 @@ package leon package termination import purescala.Definitions._ +import purescala.Trees._ abstract class TerminationChecker(val context : LeonContext, val program : Program) extends LeonComponent { + def initialize() : Unit def terminates(funDef : FunDef) : TerminationGuarantee } sealed abstract class TerminationGuarantee { - val isGuaranteed : Boolean = false + def isGuaranteed: Boolean } -case class TerminatesForAllInputs(justification : String) extends TerminationGuarantee { - override val isGuaranteed : Boolean = true +abstract class Terminating(justification: String) extends TerminationGuarantee { + override def isGuaranteed: Boolean = true +} + +case class Terminates(justification: String) extends Terminating(justification) + +abstract class NonTerminating extends TerminationGuarantee { + override def isGuaranteed: Boolean = false +} + +case class LoopsGivenInputs(justification: String, args: Seq[Expr]) extends NonTerminating + +case class CallsNonTerminating(calls: Set[FunDef]) extends NonTerminating + +case object NoGuarantee extends TerminationGuarantee { + override def isGuaranteed: Boolean = false } -case object NoGuarantee extends TerminationGuarantee diff --git a/src/main/scala/leon/termination/TerminationPhase.scala b/src/main/scala/leon/termination/TerminationPhase.scala index ca44e6433..7f755e97e 100644 --- a/src/main/scala/leon/termination/TerminationPhase.scala +++ b/src/main/scala/leon/termination/TerminationPhase.scala @@ -10,13 +10,18 @@ object TerminationPhase extends LeonPhase[Program,TerminationReport] { val description = "Check termination of PureScala functions" def run(ctx : LeonContext)(program : Program) : TerminationReport = { - val tc = new SimpleTerminationChecker(ctx, program) - val startTime = System.currentTimeMillis + +// val tc = new SimpleTerminationChecker(ctx, program) + val tc = new ComplexTerminationChecker(ctx, program) + + tc.initialize() + val results = program.definedFunctions.toList.sortWith(_ < _).map { funDef => (funDef -> tc.terminates(funDef)) } val endTime = System.currentTimeMillis + new TerminationReport(results, (endTime - startTime).toDouble / 1000.0d) - } + } } diff --git a/src/main/scala/leon/termination/TerminationReport.scala b/src/main/scala/leon/termination/TerminationReport.scala index 9004d3f10..e38e8fa70 100644 --- a/src/main/scala/leon/termination/TerminationReport.scala +++ b/src/main/scala/leon/termination/TerminationReport.scala @@ -12,9 +12,30 @@ case class TerminationReport(val results : Seq[(FunDef,TerminationGuarantee)], v sb.append(" Termination summary \n") sb.append("─────────────────────\n\n") for((fd,g) <- results) { - sb.append("- %-30s %-30s\n".format(fd.id.name, g.toString)) + val result = if (g.isGuaranteed) "\u2713" else "\u2717" + val toPrint = g match { + case LoopsGivenInputs(reason, args) => + "Non-terminating for call: " + args.mkString(fd.id+"(", ",", ")") + case CallsNonTerminating(funDefs) => + "Calls non-terminating functions " + funDefs.map(_.id).mkString(",") + case Terminates(reason) => + "Terminates (" + reason + ")" + case _ => g.toString + } + sb.append("- %-30s %s %-30s\n".format(fd.id.name, result, toPrint)) } sb.append("\n[Analysis time: %7.3f]\n".format(time)) sb.toString } + + def evaluationString : String = { + val sb = new StringBuilder + for((fd,g) <- results) { + sb.append("- %-30s %s\n".format(fd.id.name, g match { + case NoGuarantee => "u" + case t => if (t.isGuaranteed) "t" else "n" + })) + } + sb.toString + } } -- GitLab