From 006758d00b6c370c3c588ecfc4329a303ccb1908 Mon Sep 17 00:00:00 2001 From: Nicolas Voirol <voirol.nicolas@gmail.com> Date: Tue, 11 Mar 2014 08:47:34 +0100 Subject: [PATCH] Corrected some issues with termination proving. --- src/main/scala/leon/purescala/TreeOps.scala | 52 ++++- .../scala/leon/termination/ChainBuilder.scala | 219 +++++++++++------- .../leon/termination/ChainComparator.scala | 140 ++++++----- .../leon/termination/ChainProcessor.scala | 154 ++++-------- .../ComplexTerminationChecker.scala | 44 ++-- .../leon/termination/ComponentBuilder.scala | 51 +--- .../leon/termination/ComponentProcessor.scala | 5 +- .../leon/termination/LoopProcessor.scala | 59 +++-- .../scala/leon/termination/Processor.scala | 146 ++++-------- .../leon/termination/RecursionProcessor.scala | 10 +- .../leon/termination/RelationBuilder.scala | 88 +++---- .../leon/termination/RelationComparator.scala | 7 +- .../leon/termination/RelationProcessor.scala | 28 +-- .../SimpleTerminationChecker.scala | 4 +- .../scala/leon/termination/Strengthener.scala | 174 ++++++++++++++ .../leon/termination/StructuralSize.scala | 69 +++--- .../leon/termination/TerminationChecker.scala | 7 +- .../leon/{termination => utils}/SCC.scala | 4 +- .../leon/verification/DefaultTactic.scala | 6 +- .../termination/looping/Numeric3.scala | 8 + .../termination/unknown/Numeric3.scala | 10 - .../termination/valid/ComplexChains.scala | 28 --- .../valid/Termination_passing2.scala | 22 -- .../termination/TerminationRegression.scala | 3 +- 24 files changed, 702 insertions(+), 636 deletions(-) create mode 100644 src/main/scala/leon/termination/Strengthener.scala rename src/main/scala/leon/{termination => utils}/SCC.scala (97%) create mode 100644 src/test/resources/regression/termination/looping/Numeric3.scala delete mode 100644 src/test/resources/regression/termination/unknown/Numeric3.scala delete mode 100644 src/test/resources/regression/termination/valid/ComplexChains.scala delete mode 100644 src/test/resources/regression/termination/valid/Termination_passing2.scala diff --git a/src/main/scala/leon/purescala/TreeOps.scala b/src/main/scala/leon/purescala/TreeOps.scala index e68f47a48..b04ad6cef 100644 --- a/src/main/scala/leon/purescala/TreeOps.scala +++ b/src/main/scala/leon/purescala/TreeOps.scala @@ -767,6 +767,8 @@ object TreeOps { */ def hoistIte(expr: Expr): Expr = { def transform(expr: Expr): Option[Expr] = expr match { + case IfExpr(c, t, e) => None + case uop@UnaryOperator(IfExpr(c, t, e), op) => Some(IfExpr(c, op(t).copiedFrom(uop), op(e).copiedFrom(uop)).copiedFrom(uop)) @@ -1142,24 +1144,48 @@ object TreeOps { def traverse(e: Expr): T } - class CollectorWithPaths[T](matcher: PartialFunction[Expr, T]) extends TransformerWithPC with Traverser[Seq[(T, Expr)]]{ + object CollectorWithPaths { + def apply[T](p: PartialFunction[Expr,T]): CollectorWithPaths[(T, Expr)] = new CollectorWithPaths[(T, Expr)] { + def collect(e: Expr, path: Seq[Expr]): Option[(T, Expr)] = if (!p.isDefinedAt(e)) None else { + Some(p(e) -> And(path)) + } + } + } + + trait CollectorWithPaths[T] extends TransformerWithPC with Traverser[Seq[T]] { type C = Seq[Expr] - val initC = Nil + val initC : C = Nil def register(e: Expr, path: C) = path :+ e - var results: Seq[(T, Expr)] = Nil + private var results: Seq[T] = Nil + + def collect(e: Expr, path: Seq[Expr]): Option[T] + + def walk(e: Expr, path: Seq[Expr]): Option[Expr] = None - override def rec(e: Expr, path: C) = { - if(matcher.isDefinedAt(e)) { - val res = matcher(e) - results = results :+ (res, And(path)) + override final def rec(e: Expr, path: Seq[Expr]) = { + collect(e, path).foreach { results :+= _ } + walk(e, path) match { + case Some(r) => r + case _ => super.rec(e, path) } - super.rec(e, path) } - def traverse(e: Expr) = { + def traverse(funDef: FunDef): Seq[T] = { + val precondition = funDef.precondition.map(e => matchToIfThenElse(e)).toSeq + val precTs = funDef.precondition.map(e => traverse(e)).toSeq.flatten + val bodyTs = funDef.body.map(e => traverse(e, precondition)).toSeq.flatten + val postTs = funDef.postcondition.map(p => traverse(p._2)).toSeq.flatten + precTs ++ bodyTs ++ postTs + } + + def traverse(e: Expr): Seq[T] = traverse(e, initC) + + def traverse(e: Expr, init: Expr): Seq[T] = traverse(e, Seq(init)) + + def traverse(e: Expr, init: Seq[Expr]): Seq[T] = { results = Nil - rec(e, initC) + rec(e, init) results } } @@ -1173,7 +1199,11 @@ object TreeOps { case _ => false } } - class ChooseCollectorWithPaths extends CollectorWithPaths[Choose](ChooseMatch) + + class ChooseCollectorWithPaths extends CollectorWithPaths[(Choose,Expr)] { + val matcher = ChooseMatch.lift + def collect(e: Expr, path: Seq[Expr]) = matcher(e).map(_ -> And(path)) + } /** * Eliminates tuples of arity 0 and 1. diff --git a/src/main/scala/leon/termination/ChainBuilder.scala b/src/main/scala/leon/termination/ChainBuilder.scala index b6806f45c..26c557608 100644 --- a/src/main/scala/leon/termination/ChainBuilder.scala +++ b/src/main/scala/leon/termination/ChainBuilder.scala @@ -6,116 +6,171 @@ package termination import leon.purescala.Definitions._ import leon.purescala.Trees._ import leon.purescala.TreeOps._ +import leon.purescala.TypeTrees._ +import leon.purescala.TypeTreeOps._ import leon.purescala.Common._ import scala.collection.mutable.{Map => MutableMap} -object ChainID { - private var counter: Int = 0 - def get: Int = { - counter = counter + 1 - counter +final case class Chain(relations: List[Relation]) { + + private def identifier: Map[(Relation, Relation), Int] = { + (relations zip (relations.tail :+ relations.head)).groupBy(p => p).mapValues(_.size) } -} -final case class Chain(chain: List[Relation]) { - val id = ChainID.get + override def equals(obj: Any): Boolean = obj match { + case (chain : Chain) => chain.identifier == identifier + case _ => false + } - override def equals(obj: Any): Boolean = obj.isInstanceOf[Chain] && obj.asInstanceOf[Chain].id == id - override def hashCode(): Int = id + override def hashCode(): Int = identifier.hashCode - def funDef : FunDef = chain.head.funDef - def funDefs : Set[FunDef] = chain.map(_.funDef).toSet + lazy val funDef : FunDef = relations.head.funDef + lazy val funDefs : Set[FunDef] = relations.map(_.funDef).toSet - lazy val size: Int = chain.size + lazy val size: Int = relations.size - def loop(initialSubst: Map[Identifier, Expr] = Map(), finalSubst: Map[Identifier, Expr] = Map()) : Seq[Expr] = { - def rec(relations: List[Relation], subst: Map[Identifier, Expr]): Seq[Expr] = relations match { - case Relation(_, path, FunctionInvocation(tfd, args)) :: Nil => - assert(tfd.fd == funDef) - val newPath = path.map(replaceFromIDs(subst, _)) - val equalityConstraints = if (finalSubst.isEmpty) Seq() else { - val newArgs = args.map(replaceFromIDs(subst, _)) - (tfd.params.map(arg => finalSubst(arg.id)) zip newArgs).map(p => Equals(p._1, p._2)) - } - newPath ++ equalityConstraints - case Relation(_, path, FunctionInvocation(tfd, args)) :: xs => - val formalArgs = tfd.params.map(_.id) - val freshFormalArgVars = formalArgs.map(_.freshen.toVariable) - val formalArgsMap: Map[Identifier, Expr] = (formalArgs zip freshFormalArgVars).toMap - val (newPath, newArgs) = (path.map(replaceFromIDs(subst, _)), args.map(replaceFromIDs(subst, _))) - val constraints = newPath ++ (freshFormalArgVars zip newArgs).map(p => Equals(p._1, p._2)) - constraints ++ rec(xs, formalArgsMap) - case Nil => sys.error("Empty chain shouldn't exist by construction") + private lazy val inlining : Seq[(Seq[ValDef], Expr)] = { + def rec(list: List[Relation], funDef: TypedFunDef, subst: Map[Identifier, Expr]): Seq[(Seq[ValDef], Expr)] = list match { + case Relation(_, _, fi @ FunctionInvocation(fitfd, args), _) :: xs => + val tfd = TypedFunDef(fitfd.fd, fitfd.tps.map(funDef.translated(_))) + val expr = replaceFromIDs(subst, hoistIte(expandLets(matchToIfThenElse(tfd.body.get)))) + + val mappedArgs = args.map(e => replaceFromIDs(subst, tfd.translated(e))) + val newSubst = (tfd.params.map(_.id) zip mappedArgs).toMap + (tfd.params, expr) +: rec(xs, tfd, newSubst) + case Nil => Seq.empty } - val subst : Map[Identifier, Expr] = if (initialSubst.nonEmpty) initialSubst else { - funDef.params.map(arg => arg.id -> arg.toVariable).toMap + val body = hoistIte(expandLets(matchToIfThenElse(funDef.body.get))) + val tfd = funDef.typed(funDef.tparams.map(_.tp)) + (tfd.params, body) +: rec(relations, tfd, funDef.params.map(arg => arg.id -> arg.toVariable).toMap) + } + + lazy val finalParams : Seq[ValDef] = inlining.last._1 + + def loop(initialSubst: Map[Identifier, Identifier] = Map(), finalSubst: Map[Identifier, Identifier] = Map()) : Seq[Expr] = { + def rec(relations: List[Relation], funDef: TypedFunDef, subst: Map[Identifier, Identifier]): Seq[Expr] = { + val translate : Expr => Expr = { + val map : Map[Expr, Expr] = subst.map(p => p._1.toVariable -> p._2.toVariable) + (e: Expr) => replace(map, funDef.translated(e)) + } + + val Relation(_, path, fi @ FunctionInvocation(fitfd, args), _) = relations.head + val tfd = TypedFunDef(fitfd.fd, fitfd.tps.map(funDef.translated(_))) + + lazy val newArgs = args.map(translate(_)) + + path.map(translate(_)) ++ (relations.tail match { + case Nil => if (finalSubst.isEmpty) Seq.empty else { + (tfd.params.map(vd => finalSubst(vd.id).toVariable) zip newArgs).map(p => Equals(p._1, p._2)) + } + case xs => + val params = tfd.params.map(_.id) + val freshParams = tfd.params.map(arg => FreshIdentifier(arg.id.name, true).setType(arg.tpe)) + val bindings = (freshParams.map(_.toVariable) zip newArgs).map(p => Equals(p._1, p._2)) + bindings ++ rec(xs, tfd, (params zip freshParams).toMap) + }) } - val Chain(relations) = this - rec(relations, subst) + + rec(relations, funDef.typed(funDef.tparams.map(_.tp)), initialSubst) } + /* def reentrant(other: Chain) : Seq[Expr] = { assert(funDef == other.funDef) - val bindingSubst : Map[Identifier, Expr] = funDef.params.map({ - arg => arg.id -> arg.id.freshen.toVariable - }).toMap + val bindingSubst = funDef.params.map(vd => vd.id -> vd.id.freshen).toMap val firstLoop = loop(finalSubst = bindingSubst) val secondLoop = other.loop(initialSubst = bindingSubst) firstLoop ++ secondLoop } + */ - def inlined: TraversableOnce[Expr] = { - def rec(list: List[Relation], mapping: Map[Identifier, Expr]): List[Expr] = list match { - case Relation(_, _, FunctionInvocation(fd, args)) :: xs => - val mappedArgs = args.map(replaceFromIDs(mapping, _)) - val newMapping = fd.params.map(_.id).zip(mappedArgs).toMap - // We assume we have a body at this point. It would be weird to have gotten here without one... - val expr = hoistIte(expandLets(matchToIfThenElse(fd.body.get))) - val inlinedExpr = replaceFromIDs(newMapping, expr) - inlinedExpr:: rec(xs, newMapping) - case Nil => Nil - } + lazy val cycles : Seq[List[Relation]] = (0 to relations.size - 1).map { index => + val (start, end) = relations.splitAt(index) + end ++ start + } - val body = hoistIte(expandLets(matchToIfThenElse(funDef.body.get))) - body :: rec(chain, funDef.params.map(arg => arg.id -> arg.toVariable).toMap) + def compose(that: Chain) : Set[Chain] = { + val map = relations.zipWithIndex.map(p => p._1.call.tfd.fd -> ((p._2 + 1) % relations.size)).groupBy(_._1).mapValues(_.map(_._2)) + val tmap = that.relations.zipWithIndex.map(p => p._1.funDef -> p._2).groupBy(_._1).mapValues(_.map(_._2)) + val keys = map.keys.toSet & tmap.keys.toSet + + keys.flatMap(fd => map(fd).flatMap { i1 => + val (start1, end1) = relations.splitAt(i1) + val called = if (start1.isEmpty) relations.head.funDef else start1.last.call.tfd.fd + tmap(called).map { i2 => + val (start2, end2) = that.relations.splitAt(i2) + Chain(start1 ++ end2 ++ start2 ++ end1) + } + }).toSet } + + lazy val inlined: Seq[Expr] = inlining.map(_._2) } -class ChainBuilder(relationBuilder: RelationBuilder) { - - private val chainCache : MutableMap[FunDef, Set[Chain]] = MutableMap() - - def run(funDef: FunDef): Set[Chain] = chainCache.get(funDef) match { - case Some(chains) => chains - case None => { - // Note that this method will generate duplicate cycles (in fact, it will generate all list representations of a cycle) - def chains(partials: List[(Relation, List[Relation])]): List[List[Relation]] = if (partials.isEmpty) Nil else { - // Note that chains in partials are reversed to profit from O(1) insertion - val (results, newPartials) = partials.foldLeft(List[List[Relation]](),List[(Relation, List[Relation])]())({ - case ((results, partials), (first, chain @ Relation(_, _, FunctionInvocation(tfd, _)) :: xs)) => - val cycle = relationBuilder.run(tfd.fd).contains(first) - // reverse the chain when "returning" it since we're working on reversed chains - val newResults = if (cycle) chain.reverse :: results else results - - // Partial chains can fall back onto a transition that was already taken (thus creating a cycle - // inside the chain). Since this cycle will be discovered elsewhere, such partial chains should be - // dropped from the partial chain list - val transitions = relationBuilder.run(tfd.fd) -- chain.toSet - val newPartials = transitions.map(transition => (first, transition :: chain)).toList - - (newResults, partials ++ newPartials) - case (_, (_, Nil)) => scala.sys.error("Empty partial chain shouldn't exist by construction") - }) - - results ++ chains(newPartials) +trait ChainBuilder extends RelationBuilder { self: TerminationChecker with Strengthener with RelationComparator => + + protected type ChainSignature = (FunDef, Set[RelationSignature]) + + protected def funDefChainSignature(funDef: FunDef): ChainSignature = { + funDef -> (self.program.callGraph.transitiveCallees(funDef) + funDef).map(funDefRelationSignature(_)) + } + + private val chainCache : MutableMap[FunDef, (Set[FunDef], Set[Chain], ChainSignature)] = MutableMap.empty + + def getChains(funDef: FunDef)(implicit solver: Processor with Solvable): (Set[FunDef], Set[Chain]) = chainCache.get(funDef) match { + case Some((subloop, chains, signature)) if signature == funDefChainSignature(funDef) => subloop -> chains + case _ => { + val relationConstraints : MutableMap[Relation, SizeConstraint] = MutableMap.empty + + def decreasing(relations: List[Relation]): Boolean = { + val constraints = relations.map(relation => relationConstraints.get(relation).getOrElse { + val Relation(funDef, path, FunctionInvocation(fd, args), _) = relation + val (e1, e2) = (Tuple(funDef.params.map(_.toVariable)), Tuple(args)) + val constraint = if (solver.definitiveALL(Implies(And(path), self.softDecreasing(e1, e2)))) { + if (solver.definitiveALL(Implies(And(path), self.sizeDecreasing(e1, e2)))) { + StrongDecreasing + } else { + WeakDecreasing + } + } else { + NoConstraint + } + + relationConstraints(relation) = constraint + constraint + }).toSet + + !constraints(NoConstraint) && constraints(StrongDecreasing) + } + + def chains(seen: Set[FunDef], chain: List[Relation]) : (Set[FunDef], Set[Chain]) = { + val Relation(_, _, FunctionInvocation(tfd, _), _) :: xs = chain + val fd = tfd.fd + + if (!self.program.callGraph.transitivelyCalls(fd, funDef)) { + Set.empty[FunDef] -> Set.empty[Chain] + } else if (fd == funDef) { + Set.empty[FunDef] -> Set(Chain(chain.reverse)) + } else if (seen(fd)) { + Set(fd) -> Set.empty[Chain] + } else { + val results = getRelations(fd).map(r => chains(seen + fd, r :: chain)) + val (funDefs, allChains) = results.unzip + (funDefs.flatten, allChains.flatten) + } } - val initialPartials = relationBuilder.run(funDef).map(r => (r, r :: Nil)).toList - val result = chains(initialPartials).map(Chain(_)).toSet - chainCache(funDef) = result - result + val results = getRelations(funDef).map(r => chains(Set.empty, r :: Nil)) + val (funDefs, allChains) = results.unzip + + val loops = funDefs.flatten + val filteredChains = allChains.flatten.filter(chain => !decreasing(chain.relations)) + + chainCache(funDef) = (loops, filteredChains, funDefChainSignature(funDef)) + + loops -> filteredChains } } } diff --git a/src/main/scala/leon/termination/ChainComparator.scala b/src/main/scala/leon/termination/ChainComparator.scala index 376943e7c..743db6335 100644 --- a/src/main/scala/leon/termination/ChainComparator.scala +++ b/src/main/scala/leon/termination/ChainComparator.scala @@ -10,37 +10,75 @@ import purescala.TypeTreeOps._ import purescala.Definitions._ import purescala.Common._ -class ChainComparator(structuralSize: StructuralSize) { - import structuralSize.size +trait ChainComparator { self : StructuralSize with TerminationChecker => private object ContainerType { def unapply(c: ClassType): Option[(CaseClassType, Seq[(Identifier, TypeTree)])] = c match { - case act @ CaseClassType(classDef, tpes) => - val ftps = act.fields - val parentType = classDef.parent.getOrElse(c) - - if (ftps.exists(ad => isSubtypeOf(ad.tpe, parentType))) { - None - } else if (classDef.parent.map(_.classDef.knownChildren.size > 1).getOrElse(false)) { - None - } else { - Some((act, ftps.map{ ad => ad.id -> ad.tpe })) - } + case cct @ CaseClassType(ccd, _) => + if (cct.fields.exists(arg => isSubtypeOf(arg.tpe, cct.parent.getOrElse(c)))) None + else if (ccd.hasParent && ccd.parent.get.knownDescendents.size > 1) None + else Some((cct, cct.fields.map(arg => arg.id -> arg.tpe))) case _ => None } } - def sizeDecreasing(e1: Expr, e2s: Seq[(Seq[Expr], Expr)]) : Expr = e1.getType match { - case ContainerType(ct1, fields1) => Or(fields1.zipWithIndex map { case ((id1, type1), index) => - sizeDecreasing(CaseClassSelector(ct1, e1, id1), e2s.map { case (path, e2) => + private def flatTypesPowerset(tpe: TypeTree): Set[Expr => Expr] = { + def powerSetToFunSet(l: TraversableOnce[Expr => Expr]): Set[Expr => Expr] = { + l.toSet.subsets.filter(_.nonEmpty).map((reconss : Set[Expr => Expr]) => reconss.toSeq match { + case Seq(x) => x + case seq => (e: Expr) => Tuple(seq.map(r => r(e))) + }).toSet + } + + def rec(tpe: TypeTree): Set[Expr => Expr] = tpe match { + case ContainerType(cct, fields) => + powerSetToFunSet(fields.zipWithIndex.flatMap { case ((fieldId, fieldTpe), index) => + rec(fieldTpe).map(recons => (e: Expr) => recons(CaseClassSelector(cct, e, fieldId))) + }) + case TupleType(tpes) => + powerSetToFunSet((0 until tpes.length).flatMap { case index => + rec(tpes(index)).map(recons => (e: Expr) => recons(TupleSelect(e, index + 1))) + }) + case _ => Set((e: Expr) => e) + } + + rec(tpe) + } + + private def flatType(tpe: TypeTree): Set[Expr => Expr] = { + def rec(tpe: TypeTree): Set[Expr => Expr] = tpe match { + case ContainerType(cct, fields) => + fields.zipWithIndex.flatMap { case ((fieldId, fieldTpe), index) => + rec(fieldTpe).map(recons => (e: Expr) => recons(CaseClassSelector(cct, e, fieldId))) + }.toSet + case TupleType(tpes) => + (0 until tpes.length).flatMap { case index => + rec(tpes(index)).map(recons => (e: Expr) => recons(TupleSelect(e, index + 1))) + }.toSet + case _ => Set((e: Expr) => e) + } + + rec(tpe) + } + + def structuralDecreasing(e1: Expr, e2s: Seq[(Seq[Expr], Expr)]) : Seq[Expr] = flatTypesPowerset(e1.getType).toSeq.map { + recons => And(e2s.map { case (path, e2) => + Implies(And(path), GreaterThan(self.size(recons(e1)), self.size(recons(e2)))) + }) + } + + /* + def structuralDecreasing(e1: Expr, e2s: Seq[(Seq[Expr], Expr)]) : Expr = e1.getType match { + case ContainerType(def1, fields1) => Or(fields1.zipWithIndex map { case ((id1, type1), index) => + structuralDecreasing(CaseClassSelector(def1, e1, id1), e2s.map { case (path, e2) => e2.getType match { - case ContainerType(ct2, fields2) => (path, CaseClassSelector(ct2, e2, fields2(index)._1)) + case ContainerType(def2, fields2) => (path, CaseClassSelector(def2, e2, fields2(index)._1)) case _ => scala.sys.error("Unexpected input combinations: " + e1 + " " + e2) } }) }) case TupleType(types1) => Or((0 until types1.length) map { case index => - sizeDecreasing(TupleSelect(e1, index + 1), e2s.map { case (path, e2) => + structuralDecreasing(TupleSelect(e1, index + 1), e2s.map { case (path, e2) => e2.getType match { case TupleType(_) => (path, TupleSelect(e2, index + 1)) case _ => scala.sys.error("Unexpected input combination: " + e1 + " " + e2) @@ -49,14 +87,13 @@ class ChainComparator(structuralSize: StructuralSize) { }) case c: ClassType => And(e2s map { case (path, e2) => e2.getType match { - case c2: ClassType => Implies(And(path), GreaterThan(size(e1), size(e2))) + case c2: ClassType => Implies(And(path), GreaterThan(self.size(e1), self.size(e2))) case _ => scala.sys.error("Unexpected input combination: " + e1 + " " + e2) } }) - case BooleanType => BooleanLiteral(false) - case Int32Type => BooleanLiteral(false) - case tpe => scala.sys.error("Unexpected type " + tpe) + case _ => BooleanLiteral(false) } + */ private sealed abstract class NumericEndpoint { def inverse: NumericEndpoint = this match { @@ -97,7 +134,7 @@ class ChainComparator(structuralSize: StructuralSize) { private case object AnyEndpoint extends NumericEndpoint private case object NoEndpoint extends NumericEndpoint - private def numericEndpoint(value: Expr, cluster: Set[Chain], checker: TerminationChecker) = { + private def numericEndpoint(value: Expr, cluster: Set[Chain]) = { object Value { val vars = variablesOf(value) @@ -138,8 +175,8 @@ class ChainComparator(structuralSize: StructuralSize) { case NoEndpoint => endpoint(thenn) min endpoint(elze) case ep => - val terminatingThen = functionCallsOf(thenn).forall(fi => checker.terminates(fi.tfd.fd).isGuaranteed) - val terminatingElze = functionCallsOf(elze).forall(fi => checker.terminates(fi.tfd.fd).isGuaranteed) + val terminatingThen = functionCallsOf(thenn).forall(fi => self.terminates(fi.tfd.fd).isGuaranteed) + val terminatingElze = functionCallsOf(elze).forall(fi => self.terminates(fi.tfd.fd).isGuaranteed) val thenEndpoint = if (terminatingThen) ep max endpoint(thenn) else endpoint(thenn) val elzeEndpoint = if (terminatingElze) ep.inverse max endpoint(elze) else endpoint(elze) thenEndpoint max elzeEndpoint @@ -152,45 +189,26 @@ class ChainComparator(structuralSize: StructuralSize) { }) } - def numericConverging(e1: Expr, e2s: Seq[(Seq[Expr], Expr)], cluster: Set[Chain], checker: TerminationChecker) : Expr = e1.getType match { - case ContainerType(def1, fields1) => Or(fields1.zipWithIndex map { case ((id1, type1), index) => - numericConverging(CaseClassSelector(def1, e1, id1), e2s.map { case (path, e2) => - e2.getType match { - case ContainerType(def2, fields2) => (path, CaseClassSelector(def2, e2, fields2(index)._1)) - case _ => scala.sys.error("Unexpected input combination: " + e1 + " " + e2) + def numericConverging(e1: Expr, e2s: Seq[(Seq[Expr], Expr)], cluster: Set[Chain]) : Seq[Expr] = flatType(e1.getType).toSeq.flatMap { + recons => recons(e1) match { + case e if e.getType == Int32Type => + val endpoint = numericEndpoint(e, cluster) + + val uppers = if (endpoint == UpperBoundEndpoint || endpoint == AnyEndpoint) { + Some(And(e2s map { case (path, e2) => Implies(And(path), GreaterThan(e, recons(e2))) })) + } else { + None } - }, cluster, checker) - }) - case TupleType(types) => Or((0 until types.length) map { case index => - numericConverging(TupleSelect(e1, index + 1), e2s.map { case (path, e2) => - e2.getType match { - case TupleType(_) => (path, TupleSelect(e2, index + 1)) - case _ => scala.sys.error("Unexpected input combination: " + e1 + " " + e2) + + val lowers = if (endpoint == LowerBoundEndpoint || endpoint == AnyEndpoint) { + Some(And(e2s map { case (path, e2) => Implies(And(path), LessThan(e, recons(e2))) })) + } else { + None } - }, cluster, checker) - }) - case Int32Type => numericEndpoint(e1, cluster, checker) match { - case UpperBoundEndpoint => And(e2s map { - case (path, e2) if e2.getType == Int32Type => Implies(And(path), GreaterThan(e1, e2)) - case (_, e2) => scala.sys.error("Unexpected input combinations: " + e1 + " " + e2) - }) - case LowerBoundEndpoint => And(e2s map { - case (path, e2) if e2.getType == Int32Type => Implies(And(path), LessThan(e1, e2)) - case (_, e2) => scala.sys.error("Unexpected input combinations: " + e1 + " " + e2) - }) - case AnyEndpoint => Or(And(e2s map { - case (path, e2) if e2.getType == Int32Type => Implies(And(path), GreaterThan(e1, e2)) - case (_, e2) => scala.sys.error("Unexpected input combinations: " + e1 + " " + e2) - }), And(e2s map { - case (path, e2) if e2.getType == Int32Type => Implies(And(path), LessThan(e1, e2)) - case (_, e2) => scala.sys.error("Unexpected input combinations: " + e1 + " " + e2) - })) - case InnerEndpoint => BooleanLiteral(false) - case NoEndpoint => BooleanLiteral(false) + + uppers ++ lowers + case _ => Seq.empty } - case _: ClassType => BooleanLiteral(false) - case BooleanType => BooleanLiteral(false) - case tpe => scala.sys.error("Unexpected type " + tpe) } } diff --git a/src/main/scala/leon/termination/ChainProcessor.scala b/src/main/scala/leon/termination/ChainProcessor.scala index 3b640b1de..0c4e1608a 100644 --- a/src/main/scala/leon/termination/ChainProcessor.scala +++ b/src/main/scala/leon/termination/ChainProcessor.scala @@ -10,124 +10,68 @@ import purescala.Common._ import purescala.Extractors._ import purescala.Definitions._ -class ChainProcessor(checker: TerminationChecker, - chainBuilder: ChainBuilder, - val structuralSize: StructuralSize, - val strengthener: Strengthener) extends Processor(checker) with Solvable { +import scala.collection.mutable.{Map => MutableMap} + +class ChainProcessor(val checker: TerminationChecker with ChainBuilder with ChainComparator with Strengthener with StructuralSize) extends Processor with Solvable { val name: String = "Chain Processor" - val chainComparator = new ChainComparator(structuralSize) - - def run(problem: Problem): (TraversableOnce[Result], TraversableOnce[Problem]) = { - implicit val debugSection = utils.DebugSectionTermination - - reporter.debug("- Running ChainProcessor") - val allChainMap : Map[FunDef, Set[Chain]] = problem.funDefs.map(funDef => funDef -> chainBuilder.run(funDef)).toMap - reporter.debug("- Computing all possible Chains") - var counter = 0 - val possibleChainMap : Map[FunDef, Set[Chain]] = allChainMap.mapValues(chains => chains.filter(chain => isWeakSAT(And(chain.loop())))) - reporter.debug("- Collecting re-entrant Chains") - val reentrantChainMap : Map[FunDef, Set[Chain]] = possibleChainMap.mapValues(chains => chains.filter(chain => isWeakSAT(And(chain reentrant chain)))) - - // We build a cross-chain map that determines which chains can reenter into another one after a loop. - // Note: We are also checking reentrance SAT here, so again, we negate the formula! - reporter.debug("- Computing cross-chain map") - val crossChains : Map[Chain, Set[Chain]] = possibleChainMap.toSeq.map({ case (funDef, chains) => - val reentrant = reentrantChainMap(funDef) - chains.map(chain => chain -> { - val cross = (reentrant - chain).filter(other => isWeakSAT(And(chain reentrant other))) - val self = if (reentrant(chain)) Set(chain) else Set() - cross ++ self - }) - }).flatten.toMap - - val validChainMap : Map[FunDef, Set[Chain]] = possibleChainMap.map({ case (funDef, chains) => funDef -> chains.filter(crossChains(_).nonEmpty) }).toMap - - // We use the cross-chains to build chain clusters. For each cluster, we must prove that the SAME argument - // decreases in each of the chains in the cluster! - reporter.debug("- Building cluster estimation by fix-point iteration") - val clusters : Map[FunDef, Set[Set[Chain]]] = { - def cluster(set: Set[Chain]): Set[Chain] = { - set ++ set.map(crossChains(_)).flatten - } + def run(problem: Problem) = { + reporter.debug("- Strengthening postconditions") + checker.strengthenPostconditions(problem.funDefs)(this) - def fix[A](f: A => A, a: A): A = { - val na = f(a) - if (a == na) a else fix(f, na) - } +// reporter.debug("- Strengthening applications") +// checker.strengthenApplications(problem.funDefs)(this) - def reduceClusters(all: List[Set[Chain]]): List[Set[Chain]] = { - all.map(cluster => cluster.toSeq.sortBy(_.size).foldLeft(Set[Chain]())({ case (acc, chain) => - val chainElements : Set[Relation] = chain.chain.toSet - val seenElements : Set[Relation] = acc.map(_.chain).flatten.toSet - if ((chainElements -- seenElements).nonEmpty) acc + chain else acc - })).filter(_.nonEmpty) - } + reporter.debug("- Running ChainBuilder") + val chainsMap : Map[FunDef, (Set[FunDef], Set[Chain])] = problem.funDefs.map { funDef => + funDef -> checker.getChains(funDef)(this) + }.toMap - def filterClusters(all: List[Set[Chain]]): List[Set[Chain]] = if (all.isEmpty) Nil else { - val newCluster = all.head - val rest = all.tail.filter(set => !set.subsetOf(newCluster)) - newCluster :: filterClusters(rest) - } + val loopPoints = chainsMap.foldLeft(Set.empty[FunDef]) { case (set, (fd, (fds, chains))) => set ++ fds } + + if (loopPoints.size > 1) { + reporter.debug("-+> Multiple looping points, can't build chain proof") + (Nil, List(problem)) + } else { + + def exprs(fd: FunDef): (Expr, Seq[(Seq[Expr], Expr)], Set[Chain]) = { + val fdChains = chainsMap(fd)._2 + val nfdChains = chainsMap.filter(_._1 != fd).values.foldLeft(Set.empty[Chain])((set, p) => set ++ p._2) + assert(nfdChains.subsetOf(fdChains)) - def build(chains: Set[Chain]): Set[Set[Chain]] = { - val allClusters = chains.map(chain => fix(cluster, Set(chain))) - val reducedClusters = reduceClusters(allClusters.toList) - val filteredClusters = filterClusters(reducedClusters.sortBy(- _.size)) - filteredClusters.toSet + val e1 = Tuple(fd.params.map(_.toVariable)) + val e2s = fdChains.toSeq.map { chain => + val freshParams = chain.finalParams.map(arg => FreshIdentifier(arg.id.name, true).setType(arg.id.getType)) + val finalBindings = (chain.finalParams.map(_.id) zip freshParams).toMap + (chain.loop(finalSubst = finalBindings), Tuple(freshParams.map(_.toVariable))) + } + + (e1, e2s, fdChains) } - validChainMap.map({ case (funDef, chains) => funDef -> build(chains) }) - } + val funDefs = if (loopPoints.size == 1) Set(loopPoints.head) else problem.funDefs - reporter.debug("- Strengthening postconditions") - strengthenPostconditions(problem.funDefs) - - def buildLoops(fd: FunDef, cluster: Set[Chain]): (Expr, Seq[(Seq[Expr], Expr)]) = { - val e1 = Tuple(fd.params.map(_.toVariable)) - val e2s = cluster.toSeq.map({ chain => - val freshArgs : Seq[Expr] = fd.params.map(arg => arg.id.freshen.toVariable) - val finalBindings = (fd.params.map(_.id) zip freshArgs).toMap - val path = chain.loop(finalSubst = finalBindings) - path -> Tuple(freshArgs) - }) - - (e1, e2s) - } + reporter.debug("-+> Searching for structural size decrease") - type ClusterMap = Map[FunDef, Set[Set[Chain]]] - type FormulaGenerator = (FunDef, Set[Chain]) => Expr + val (se1, se2s, _) = exprs(funDefs.head) + val structuralFormulas = checker.structuralDecreasing(se1, se2s) + val structuralDecreasing = structuralFormulas.exists(formula => definitiveALL(formula)) - def clear(clusters: ClusterMap, gen: FormulaGenerator): ClusterMap = { - val formulas = clusters.map({ case (fd, clusters) => - (fd, clusters.map(cluster => cluster -> gen(fd, cluster))) - }) + reporter.debug("-+> Searching for numerical converging") - formulas.map({ case (fd, clustersWithFormulas) => - fd -> clustersWithFormulas.filter({ case (cluster, formula) => !isAlwaysSAT(formula) }).map(_._1) - }) - } + // worth checking multiple funDefs as the endpoint discovery can be context sensitive + val numericDecreasing = funDefs.exists { fd => + val (ne1, ne2s, fdChains) = exprs(fd) + val numericFormulas = checker.numericConverging(ne1, ne2s, fdChains) + numericFormulas.exists(formula => definitiveALL(formula)) + } - reporter.debug("- Searching for structural size decrease") - val sizeCleared : ClusterMap = clear(clusters, (fd, cluster) => { - val (e1, e2s) = buildLoops(fd, cluster) - chainComparator.sizeDecreasing(e1, e2s) - }) - - reporter.debug("- Searching for numeric convergence") - val numericCleared : ClusterMap = clear(sizeCleared, (fd, cluster) => { - val (e1, e2s) = buildLoops(fd, cluster) - chainComparator.numericConverging(e1, e2s, cluster, checker) - }) - - val (okPairs, nokPairs) = numericCleared.partition(_._2.isEmpty) - val nok = nokPairs.map(_._1).toSet - val (ok, transitiveNok) = okPairs.map(_._1).partition({ fd => - (checker.program.callGraph.transitiveCallees(fd) intersect nok).isEmpty - }) - val allNok = nok ++ transitiveNok - val newProblems = if (allNok.nonEmpty) List(Problem(allNok)) else Nil - (ok.map(Cleared(_)), newProblems) + if (structuralDecreasing || numericDecreasing) { + (problem.funDefs.map(Cleared(_)), Nil) + } else { + (Nil, List(problem)) + } + } } } diff --git a/src/main/scala/leon/termination/ComplexTerminationChecker.scala b/src/main/scala/leon/termination/ComplexTerminationChecker.scala index c9815cb5a..127ff6681 100644 --- a/src/main/scala/leon/termination/ComplexTerminationChecker.scala +++ b/src/main/scala/leon/termination/ComplexTerminationChecker.scala @@ -6,34 +6,39 @@ package termination import purescala.Definitions._ import purescala.Trees._ -class ComplexTerminationChecker(context: LeonContext, _program: Program) extends TerminationChecker(context, _program.duplicate) { +import scala.collection.mutable.{Map => MutableMap} - import scala.collection.mutable.{Map => MutableMap} +class ComplexTerminationChecker(context: LeonContext, program: Program) + extends TerminationChecker(context, program) + with StructuralSize + with RelationComparator + with ChainComparator + with Strengthener + with ComponentBuilder + with RelationBuilder + with ChainBuilder { val name = "Complex Termination Checker" val description = "A modular termination checker with a few basic modules™" - val structuralSize = new StructuralSize() - val relationBuilder = new RelationBuilder() - val chainBuilder = new ChainBuilder(relationBuilder) - val relationComparator = new RelationComparator(structuralSize) - val strengthener = new Strengthener(relationComparator) - private val pipeline = new ProcessingPipeline( program, context, // required for solvers and reporting new ComponentProcessor(this), - new RecursionProcessor(this, relationBuilder), - new RelationProcessor(this, relationBuilder, structuralSize, relationComparator, strengthener), - new ChainProcessor(this, chainBuilder, structuralSize, strengthener), - new LoopProcessor(this, chainBuilder, structuralSize, strengthener) + new RecursionProcessor(this), + new RelationProcessor(this), + new ChainProcessor(this), + new LoopProcessor(this) ) private val clearedMap : MutableMap[FunDef, String] = MutableMap() private val brokenMap : MutableMap[FunDef, (String, Seq[Expr])] = MutableMap() + private val maybeBrokenMap : MutableMap[FunDef, (String, Seq[Expr])] = MutableMap() + def initialize() { for ((reason, results) <- pipeline.run; result <- results) result match { case Cleared(fd) => clearedMap(fd) = reason case Broken(fd, args) => brokenMap(fd) = (reason, args) + case MaybeBroken(fd, args) => maybeBrokenMap(fd) = (reason, args) } } @@ -43,12 +48,15 @@ class ComplexTerminationChecker(context: LeonContext, _program: Program) extends case None => { val guarantee = brokenMap.get(funDef) match { case Some((reason, args)) => LoopsGivenInputs(reason, args) - case None => program.callGraph.transitiveCallees(funDef) intersect brokenMap.keys.toSet match { - case set if set.nonEmpty => CallsNonTerminating(set) - case _ => if (pipeline.clear(funDef)) clearedMap.get(funDef) match { - case Some(reason) => Terminates(reason) - case None => scala.sys.error(funDef.id + " -> not problem, but not cleared or broken ??") - } else NoGuarantee + case None => maybeBrokenMap.get(funDef) match { + case Some((reason, args)) => MaybeLoopsGivenInputs(reason, args) + case None => program.callGraph.transitiveCallees(funDef) intersect (brokenMap.keys.toSet ++ maybeBrokenMap.keys) match { + case set if set.nonEmpty => CallsNonTerminating(set) + case _ => if (pipeline.clear(funDef)) clearedMap.get(funDef) match { + case Some(reason) => Terminates(reason) + case None => scala.sys.error(funDef.id + " -> not problem, but not cleared or broken ??") + } else NoGuarantee + } } } diff --git a/src/main/scala/leon/termination/ComponentBuilder.scala b/src/main/scala/leon/termination/ComponentBuilder.scala index ea65f5d7b..bfb6a239b 100644 --- a/src/main/scala/leon/termination/ComponentBuilder.scala +++ b/src/main/scala/leon/termination/ComponentBuilder.scala @@ -3,53 +3,8 @@ package leon package termination -/** This could be defined anywhere, it's just that the - termination checker is the only place where it is used. */ -object ComponentBuilder { - def run[T](graph : Map[T,Set[T]]) : List[Set[T]] = { - // The first part is a shameless adaptation from Wikipedia - val allVertices : Set[T] = graph.keySet ++ graph.values.flatten +import utils._ - var index = 0 - var indices : Map[T,Int] = Map.empty - var lowLinks : Map[T,Int] = Map.empty - var components : List[Set[T]] = Nil - var s : List[T] = Nil - - def strongConnect(v : T) { - indices = indices.updated(v, index) - lowLinks = lowLinks.updated(v, index) - index += 1 - s = v :: s - - for(w <- graph.getOrElse(v, Set.empty)) { - if(!indices.isDefinedAt(w)) { - strongConnect(w) - lowLinks = lowLinks.updated(v, lowLinks(v) min lowLinks(w)) - } else if(s.contains(w)) { - lowLinks = lowLinks.updated(v, lowLinks(v) min indices(w)) - } - } - - if(lowLinks(v) == indices(v)) { - var c : Set[T] = Set.empty - var stop = false - do { - val x :: xs = s - c = c + x - s = xs - stop = (x == v) - } while(!stop); - components = c :: components - } - } - - for(v <- allVertices) { - if(!indices.isDefinedAt(v)) { - strongConnect(v) - } - } - - components - } +trait ComponentBuilder { + def getComponents[T](graph : Map[T,Set[T]]) : List[Set[T]] = SCC.scc(graph) } diff --git a/src/main/scala/leon/termination/ComponentProcessor.scala b/src/main/scala/leon/termination/ComponentProcessor.scala index 232368ea2..b541a7477 100644 --- a/src/main/scala/leon/termination/ComponentProcessor.scala +++ b/src/main/scala/leon/termination/ComponentProcessor.scala @@ -7,7 +7,7 @@ import purescala.TreeOps._ import purescala.Definitions._ import scala.collection.mutable.{Map => MutableMap} -class ComponentProcessor(checker: TerminationChecker) extends Processor(checker) { +class ComponentProcessor(val checker: TerminationChecker with ComponentBuilder) extends Processor { val name: String = "Component Processor" @@ -15,8 +15,9 @@ class ComponentProcessor(checker: TerminationChecker) extends Processor(checker) val pairs : Set[(FunDef, FunDef)] = checker.program.callGraph.allCalls.filter({ case (fd1, fd2) => problem.funDefs(fd1) && problem.funDefs(fd2) }) + val callGraph : Map[FunDef,Set[FunDef]] = pairs.groupBy(_._1).mapValues(_.map(_._2)) - val components : List[Set[FunDef]] = ComponentBuilder.run(callGraph) + val components : List[Set[FunDef]] = checker.getComponents(callGraph) val fdToSCC : Map[FunDef, Set[FunDef]] = components.map(set => set.map(fd => fd -> set)).flatten.toMap val terminationCache : MutableMap[FunDef, Boolean] = MutableMap() diff --git a/src/main/scala/leon/termination/LoopProcessor.scala b/src/main/scala/leon/termination/LoopProcessor.scala index 1d30ceca8..7ff9039af 100644 --- a/src/main/scala/leon/termination/LoopProcessor.scala +++ b/src/main/scala/leon/termination/LoopProcessor.scala @@ -4,39 +4,50 @@ package leon package termination import purescala.Definitions._ +import purescala.Common._ import purescala.Trees._ import purescala.TreeOps._ -class LoopProcessor(checker: TerminationChecker, - chainBuilder: ChainBuilder, - val structuralSize: StructuralSize, - val strengthener: Strengthener, - k: Int = 10) extends Processor(checker) with Solvable { +import scala.collection.mutable.{Map => MutableMap} + +class LoopProcessor(val checker: TerminationChecker with ChainBuilder with Strengthener with StructuralSize, k: Int = 10) extends Processor with Solvable { val name: String = "Loop Processor" def run(problem: Problem) = { - val allChains : Set[Chain] = problem.funDefs.map(fd => chainBuilder.run(fd)).flatten - // Get reentrant loops (see ChainProcessor for more details) - val chains : Set[Chain] = allChains.filter(chain => isWeakSAT(And(chain reentrant chain))) - - val nonTerminating = chains.flatMap({ chain => - val freshArgs : Seq[Expr] = chain.funDef.params.map(arg => arg.id.freshen.toVariable) - val finalBindings = (chain.funDef.params.map(_.id) zip freshArgs).toMap - val path = chain.loop(finalSubst = finalBindings) - val formula = And(path :+ Equals(Tuple(chain.funDef.params.map(_.toVariable)), Tuple(freshArgs))) - - val solvable = functionCallsOf(formula).forall({ - case FunctionInvocation(tfd, args) => checker.terminates(tfd.fd).isGuaranteed - }) - - if (!solvable) None else getModel(formula) match { - case Some(map) => Some(chain.funDef -> chain.funDef.params.map(arg => map(arg.id))) - case _ => None +// reporter.debug("- Strengthening applications") +// checker.strengthenApplications(problem.funDefs)(this) + + reporter.debug("- Running ChainBuilder") + val chains : Set[Chain] = problem.funDefs.flatMap(fd => checker.getChains(fd)(this)._2) + + reporter.debug("- Searching for loops") + val nonTerminating: MutableMap[FunDef, Result] = MutableMap.empty + + (0 to k).foldLeft(chains) { (cs, index) => + reporter.debug("-+> Iteration #" + index) + for (chain <- cs if !nonTerminating.isDefinedAt(chain.funDef) && + (chain.funDef.params zip chain.finalParams).forall(p => p._1.getType == p._2.getType)) { + val freshParams = chain.funDef.params.map(arg => FreshIdentifier(arg.id.name, true).setType(arg.tpe)) + val finalBindings = (chain.funDef.params.map(_.id) zip freshParams).toMap + val path = chain.loop(finalSubst = finalBindings) + + val srcTuple = Tuple(chain.funDef.params.map(_.toVariable)) + val resTuple = Tuple(freshParams.map(_.toVariable)) + + definitiveSATwithModel(And(path :+ Equals(srcTuple, resTuple))) match { + case Some(map) => + val args = chain.funDef.params.map(arg => map(arg.id)) + val res = if (chain.relations.exists(_.inAnon)) MaybeBroken(chain.funDef, args) else Broken(chain.funDef, args) + nonTerminating(chain.funDef) = res + case None => + } } - }).toMap - val results = nonTerminating.map({ case (funDef, args) => Broken(funDef, args) }) + cs.flatMap(c1 => chains.flatMap(c2 => c1.compose(c2))) + } + + val results = nonTerminating.values.toSet val remaining = problem.funDefs -- nonTerminating.keys val newProblems = if (remaining.nonEmpty) List(Problem(remaining)) else Nil (results, newProblems) diff --git a/src/main/scala/leon/termination/Processor.scala b/src/main/scala/leon/termination/Processor.scala index dab5cf346..fb85214a7 100644 --- a/src/main/scala/leon/termination/Processor.scala +++ b/src/main/scala/leon/termination/Processor.scala @@ -3,8 +3,6 @@ package leon package termination -import utils._ - import purescala.Trees._ import purescala.TreeOps._ import purescala.Common._ @@ -20,135 +18,71 @@ case class Problem(funDefs: Set[FunDef]) { sealed abstract class Result(funDef: FunDef) case class Cleared(funDef: FunDef) extends Result(funDef) case class Broken(funDef: FunDef, args: Seq[Expr]) extends Result(funDef) +case class MaybeBroken(funDef: FunDef, args: Seq[Expr]) extends Result(funDef) -abstract class Processor(val checker: TerminationChecker) { +trait Processor { val name: String - val reporter = checker.context.reporter + val checker : TerminationChecker - protected def run(problem: Problem): (TraversableOnce[Result], TraversableOnce[Problem]) + implicit val debugSection = utils.DebugSectionTermination + val reporter = checker.context.reporter - def process(problem: Problem): (TraversableOnce[Result], TraversableOnce[Problem]) = run(problem) + def run(problem: Problem): (Traversable[Result], Traversable[Problem]) } -class Strengthener(relationComparator: RelationComparator) { - import scala.collection.mutable.{Set => MutableSet} +trait Solvable extends Processor { - private val strengthened : MutableSet[FunDef] = MutableSet() - private def strengthenPostcondition(funDef: FunDef, cmp: (Expr, Expr) => Expr) - (implicit solver: Processor with Solvable) : Boolean = if (!funDef.hasBody) false else { - assert(solver.checker.terminates(funDef).isGuaranteed) + val checker : TerminationChecker with Strengthener with StructuralSize - val old = funDef.postcondition - val (res, postcondition) = { - val (res, post) = old.getOrElse(FreshIdentifier("res").setType(funDef.returnType) -> BooleanLiteral(true)) - val args = funDef.params.map(_.toVariable) - val sizePost = cmp(Tuple(funDef.params.map(_.toVariable)), res.toVariable) - (res, And(post, sizePost)) - } - - funDef.postcondition = Some(res -> postcondition) + private val solver: SolverFactory[Solver] = SolverFactory(() => { + val structDefs = checker.defs + val program : Program = checker.program + val context : LeonContext = checker.context + val sizeModule : ModuleDef = ModuleDef(FreshIdentifier("$size", false), checker.defs.toSeq) + val newProgram : Program = program.copy(modules = sizeModule :: program.modules) - val prec = matchToIfThenElse(funDef.precondition.getOrElse(BooleanLiteral(true))) - val body = matchToIfThenElse(funDef.body.get) - val post = matchToIfThenElse(postcondition) - val formula = Implies(prec, Let(res, body, post)) + (new FairZ3Solver(context, newProgram) with TimeoutAssumptionSolver).setTimeout(500L) + }) - if (!solver.isAlwaysSAT(formula)) { - funDef.postcondition = old - strengthened.add(funDef) - false - } else { - strengthened.add(funDef) - true - } - } + type Solution = (Option[Boolean], Map[Identifier, Expr]) - def strengthenPostconditions(funDefs: Set[FunDef])(implicit solver: Processor with Solvable) { - // Strengthen postconditions on all accessible functions by adding size constraints - val callees : Set[FunDef] = funDefs.map(fd => solver.checker.program.callGraph.transitiveCallees(fd)).flatten - val sortedCallees : Seq[FunDef] = callees.toSeq.sortWith((fd1, fd2) => solver.checker.program.callGraph.transitivelyCalls(fd2, fd1)) - for (funDef <- sortedCallees if !strengthened(funDef) && funDef.hasBody && solver.checker.terminates(funDef).isGuaranteed) { - // test if size is smaller or equal to input - val weekConstraintHolds = strengthenPostcondition(funDef, relationComparator.softDecreasing) - - if (weekConstraintHolds) { - // try to improve postcondition with strictly smaller - strengthenPostcondition(funDef, relationComparator.sizeDecreasing) - } + private def withoutPosts[T](block: => T): T = { + val dangerousFunDefs = checker.functions.filter(fd => !checker.terminates(fd).isGuaranteed) + val backups = dangerousFunDefs.toList map { fd => + val p = fd.postcondition + fd.postcondition = None + () => fd.postcondition = p } - } -} - -trait Solvable { self: Processor => - val structuralSize: StructuralSize - val strengthener: Strengthener - - override def process(problem: Problem): (TraversableOnce[Result], TraversableOnce[Problem]) = { - self.run(problem) + val res : T = block // force evaluation now + backups.foreach(_()) + res } - private var solvers: List[SolverFactory[Solver]] = null - private var lastDefs: Set[FunDef] = Set() - - def strengthenPostconditions(funDefs: Set[FunDef]) = strengthener.strengthenPostconditions(funDefs)(this) - - private def initSolvers { - val structDefs = structuralSize.defs - if (structDefs != lastDefs || solvers == null) { - val program : Program = self.checker.program - val newProgram : Program = program.copy(modules = ModuleDef(FreshIdentifier("structDefs"), structDefs.toSeq) :: program.modules) - val context : LeonContext = self.checker.context - - solvers = new TimeoutSolverFactory(SolverFactory(() => new FairZ3Solver(context, newProgram) with TimeoutSolver), 500) :: Nil + def maybeSAT(problem: Expr): Boolean = { + withoutPosts { + SimpleSolverAPI(solver).solveSAT(problem)._1 getOrElse true } } - type Solution = (Option[Boolean], Map[Identifier, Expr]) - - private def solve(problem: Expr): Solution = { - initSolvers - // drop functions from constraints that might not terminate (and may therefore - // make Leon unroll them forever...) - val dangerousCallsMap : Map[Expr, Expr] = functionCallsOf(problem).collect({ - // extra definitions (namely size functions) are quaranteed to terminate because structures are non-looping - case fi @ FunctionInvocation(tfd, args) if !structuralSize.defs(tfd.fd) && !self.checker.terminates(tfd.fd).isGuaranteed => - fi -> FreshIdentifier("noRun", true).setType(fi.getType).toVariable - }).toMap - - // TODO: Fix&check without the recursive=false - //val expr = searchAndReplace(dangerousCallsMap.get, recursive=false)(problem) - val expr = postMap(dangerousCallsMap.lift)(problem) - - object Solved { - def unapply(se: SolverFactory[Solver]): Option[Solution] = { - val (satResult, model) = SimpleSolverAPI(se).solveSAT(expr) - - if (!satResult.isDefined) None - else Some(satResult -> model) - } + def definitiveALL(problem: Expr): Boolean = { + withoutPosts { + SimpleSolverAPI(solver).solveSAT(Not(problem))._1.map(!_) getOrElse false } - - solvers.collectFirst({ case Solved(s, model) => (s, model) }) getOrElse (None, Map()) } - def isStrongSAT(problem: Expr): Boolean = solve(problem)._1 getOrElse false - - def isWeakSAT(problem: Expr): Boolean = solve(problem)._1 getOrElse true - - def isAlwaysSAT(problem: Expr): Boolean = solve(Not(problem))._1.map(!_) getOrElse false - - def getModel(problem: Expr): Option[Map[Identifier, Expr]] = { - val solution = solve(problem) - if (solution._1 getOrElse false) Some(solution._2) - else None + def definitiveSATwithModel(problem: Expr): Option[Map[Identifier, Expr]] = { + withoutPosts { + val (sat, model) = SimpleSolverAPI(solver).solveSAT(problem) + if (sat.isDefined && sat.get) Some(model) else None + } } } class ProcessingPipeline(program: Program, context: LeonContext, _processors: Processor*) { - implicit val debugSection = DebugSectionTermination + implicit val debugSection = utils.DebugSectionTermination import scala.collection.mutable.{Queue => MutableQueue} @@ -182,6 +116,7 @@ class ProcessingPipeline(program: Program, context: LeonContext, _processors: Pr for(result <- results) result match { case Cleared(fd) => sb.append(" %-10s %s\n".format(fd.id, "Cleared")) case Broken(fd, args) => sb.append(" %-10s %s\n".format(fd.id, "Broken for arguments: " + args.mkString("(", ",", ")"))) + case MaybeBroken(fd, args) => sb.append(" %-10s %s\n".format(fd.id, "HO construct application breaks for arguments: " + args.mkString("(", ",", ")"))) } reporter.debug(sb.toString) } @@ -205,7 +140,8 @@ class ProcessingPipeline(program: Program, context: LeonContext, _processors: Pr printQueue val (problem, index) = problems.head val processor : Processor = processors(index) - val (_results, nextProblems) = processor.process(problem) + reporter.debug("Running " + processor.name) + val (_results, nextProblems) = processor.run(problem) val results = _results.toList printResult(results) diff --git a/src/main/scala/leon/termination/RecursionProcessor.scala b/src/main/scala/leon/termination/RecursionProcessor.scala index f7bf287b1..0238612e4 100644 --- a/src/main/scala/leon/termination/RecursionProcessor.scala +++ b/src/main/scala/leon/termination/RecursionProcessor.scala @@ -9,7 +9,7 @@ import purescala.Definitions._ import scala.annotation.tailrec -class RecursionProcessor(checker: TerminationChecker, relationBuilder: RelationBuilder) extends Processor(checker) { +class RecursionProcessor(val checker: TerminationChecker with RelationBuilder) extends Processor { val name: String = "Recursion Processor" @@ -25,12 +25,12 @@ class RecursionProcessor(checker: TerminationChecker, relationBuilder: RelationB def run(problem: Problem) = if (problem.funDefs.size > 1) (Nil, List(problem)) else { val funDef = problem.funDefs.head - val relations = relationBuilder.run(funDef) - val (recursive, others) = relations.partition({ case Relation(_, _, FunctionInvocation(fd, _)) => fd == funDef }) + val relations = checker.getRelations(funDef) + val (recursive, others) = relations.partition({ case Relation(_, _, FunctionInvocation(tfd, _), _) => tfd.fd == funDef }) - if (others.exists({ case Relation(_, _, FunctionInvocation(tfd, _)) => !checker.terminates(tfd.fd).isGuaranteed })) (Nil, List(problem)) else { + if (others.exists({ case Relation(_, _, FunctionInvocation(tfd, _), _) => !checker.terminates(tfd.fd).isGuaranteed })) (Nil, List(problem)) else { val decreases = funDef.params.zipWithIndex.exists({ case (arg, index) => - recursive.forall({ case Relation(_, _, FunctionInvocation(_, args)) => + recursive.forall({ case Relation(_, _, FunctionInvocation(_, args), _) => isSubtreeOf(args(index), arg.id) }) }) diff --git a/src/main/scala/leon/termination/RelationBuilder.scala b/src/main/scala/leon/termination/RelationBuilder.scala index 1eab5be8c..a597ebde8 100644 --- a/src/main/scala/leon/termination/RelationBuilder.scala +++ b/src/main/scala/leon/termination/RelationBuilder.scala @@ -3,74 +3,54 @@ package leon package termination -import purescala.Definitions._ import purescala.Trees._ import purescala.TreeOps._ import purescala.Extractors._ import purescala.Common._ +import purescala.Definitions._ import scala.collection.mutable.{Map => MutableMap} -final case class Relation(funDef: FunDef, path: Seq[Expr], call: FunctionInvocation) { - override def toString : String = "Relation(" + funDef.id + "," + path + ", " + call.tfd.fd.id + call.args.mkString("(",",",")") + ")" +final case class Relation(funDef: FunDef, path: Seq[Expr], call: FunctionInvocation, inAnon: Boolean) { + override def toString : String = "Relation(" + funDef.id + "," + path + ", " + call.tfd.id + call.args.mkString("(",",",")") + "," + inAnon + ")" } -class RelationBuilder { - private val relationCache : MutableMap[FunDef, Set[Relation]] = MutableMap() - - def run(funDef: FunDef): Set[Relation] = relationCache.get(funDef) match { - case Some(relations) => relations - case None => { - def visit(e: Expr, path: List[Expr]): Set[Relation] = e match { - - // we skip functions that aren't in this SCC since the call relations - // associated to them don't interest us. - case fi @ FunctionInvocation(f, args) => - val argRelations = args.flatMap(visit(_, path)).toSet - argRelations + Relation(funDef, path.reverse.toSeq, fi) - - case Let(i, e, b) => - val ve = visit(e, path) - val vb = visit(b, Equals(Variable(i), e) :: path) - ve ++ vb - - case IfExpr(cond, thenn, elze) => - val vc = visit(cond, path) - val vt = visit(thenn, cond :: path) - val ve = visit(elze, Not(cond) :: path) - vc ++ vt ++ ve +trait RelationBuilder { self: TerminationChecker with Strengthener => - case And(es) => - def resolveAnds(ands: List[Expr], p: List[Expr]): Set[Relation] = ands match { - case x :: xs => visit(x, p ++ path) ++ resolveAnds(xs, x :: p) - case Nil => Set() - } - resolveAnds(es.toList, Nil) + protected type RelationSignature = (FunDef, Option[Expr], Option[Expr], Option[Expr], Boolean, Set[(FunDef, Boolean)]) - case Or(es) => - def resolveOrs(ors: List[Expr], p: List[Expr]): Set[Relation] = ors match { - case x :: xs => visit(x, p ++ path) ++ resolveOrs(xs, Not(x) :: p) - case Nil => Set() - } - resolveOrs(es.toList, Nil) - - case UnaryOperator(e, _) => visit(e, path) - - case BinaryOperator(e1, e2, _) => visit(e1, path) ++ visit(e2, path) - - case NAryOperator(es, _) => es.map(visit(_, path)).flatten.toSet - - case t : Terminal => Set() + protected def funDefRelationSignature(fd: FunDef): RelationSignature = { + val strengthenedCallees = Set.empty[(FunDef, Boolean)] // self.program.callGraph.callees(fd).map(fd => fd -> strengthened(fd)) + (fd, fd.precondition, fd.body, fd.postcondition.map(_._2), self.terminates(fd).isGuaranteed, strengthenedCallees) + } - case _ => sys.error("Expression "+e+" ["+e.getClass+"] is not extractable") + private val relationCache : MutableMap[FunDef, (Set[Relation], RelationSignature)] = MutableMap.empty + + def getRelations(funDef: FunDef): Set[Relation] = relationCache.get(funDef) match { + case Some((relations, signature)) if signature == funDefRelationSignature(funDef) => relations + case _ => { + val collector = new CollectorWithPaths[Relation] { + var inAnon: Boolean = false + def collect(e: Expr, path: Seq[Expr]): Option[Relation] = e match { + case fi @ FunctionInvocation(f, args) if self.functions(f.fd) => + Some(Relation(funDef, path, fi, inAnon)) +// case af @ AnonymousFunction(args, body) => +// inAnon = true +// None + case _ => None + } + + override def walk(e: Expr, path: Seq[Expr]) = e match { + case FunctionInvocation(fd, args) => + Some(FunctionInvocation(fd, (fd.params.map(_.id) zip args) map { case (id, arg) => + rec(arg, /* register(self.applicationConstraint(fd, id, arg, args), path) */ path) + })) + case _ => None + } } - val precondition = funDef.precondition getOrElse BooleanLiteral(true) - val precRelations = funDef.precondition.map(e => visit(simplifyLets(matchToIfThenElse(e)), Nil)) getOrElse Set() - val bodyRelations = funDef.body.map(e => visit(simplifyLets(matchToIfThenElse(e)), List(precondition))) getOrElse Set() - val postRelations = funDef.postcondition.map(e => visit(simplifyLets(matchToIfThenElse(e._2)), Nil)) getOrElse Set() - val relations = precRelations ++ bodyRelations ++ postRelations - relationCache(funDef) = relations + val relations = collector.traverse(funDef).toSet + relationCache(funDef) = (relations, funDefRelationSignature(funDef)) relations } } diff --git a/src/main/scala/leon/termination/RelationComparator.scala b/src/main/scala/leon/termination/RelationComparator.scala index c54ed7d05..eb886d5ea 100644 --- a/src/main/scala/leon/termination/RelationComparator.scala +++ b/src/main/scala/leon/termination/RelationComparator.scala @@ -9,12 +9,11 @@ import purescala.TypeTrees._ import purescala.Definitions._ import purescala.Common._ -class RelationComparator(structuralSize: StructuralSize) { - import structuralSize.size +trait RelationComparator { self : StructuralSize => - def sizeDecreasing(e1: Expr, e2: Expr) = GreaterThan(size(e1), size(e2)) + def sizeDecreasing(e1: Expr, e2: Expr) = GreaterThan(self.size(e1), self.size(e2)) - def softDecreasing(e1: Expr, e2: Expr) = GreaterEquals(size(e1), size(e2)) + def softDecreasing(e1: Expr, e2: Expr) = GreaterEquals(self.size(e1), self.size(e2)) } // vim: set ts=4 sw=4 et: diff --git a/src/main/scala/leon/termination/RelationProcessor.scala b/src/main/scala/leon/termination/RelationProcessor.scala index 44fb0c8f5..4982968e2 100644 --- a/src/main/scala/leon/termination/RelationProcessor.scala +++ b/src/main/scala/leon/termination/RelationProcessor.scala @@ -10,25 +10,26 @@ import leon.purescala.Common._ import leon.purescala.Extractors._ import leon.purescala.Definitions._ -class RelationProcessor(checker: TerminationChecker, - relationBuilder: RelationBuilder, - val structuralSize: StructuralSize, - relationComparator: RelationComparator, - val strengthener: Strengthener) extends Processor(checker) with Solvable { +class RelationProcessor( + val checker: TerminationChecker with RelationBuilder with RelationComparator with Strengthener with StructuralSize + ) extends Processor with Solvable { val name: String = "Relation Processor" def run(problem: Problem) = { + reporter.debug("- Strengthening postconditions") + checker.strengthenPostconditions(problem.funDefs)(this) - strengthenPostconditions(problem.funDefs) +// reporter.debug("- Strengthening applications") +// checker.strengthenApplications(problem.funDefs)(this) val formulas = problem.funDefs.map({ funDef => - funDef -> relationBuilder.run(funDef).collect({ - case Relation(_, path, FunctionInvocation(tfd, args)) if problem.funDefs(tfd.fd) => + funDef -> checker.getRelations(funDef).collect({ + case Relation(_, path, FunctionInvocation(tfd, args), _) if problem.funDefs(tfd.fd) => val (e1, e2) = (Tuple(funDef.params.map(_.toVariable)), Tuple(args)) def constraint(expr: Expr) = Implies(And(path.toSeq), expr) - val greaterThan = relationComparator.sizeDecreasing(e1, e2) - val greaterEquals = relationComparator.softDecreasing(e1, e2) + val greaterThan = checker.sizeDecreasing(e1, e2) + val greaterEquals = checker.softDecreasing(e1, e2) (tfd.fd, (constraint(greaterThan), constraint(greaterEquals))) }) }) @@ -38,15 +39,16 @@ class RelationProcessor(checker: TerminationChecker, case class Dep(deps: Set[FunDef]) extends Result case object Failure extends Result + reporter.debug("- Searching for structural size decrease") val decreasing = formulas.map({ case (fd, formulas) => val solved = formulas.map({ case (fid, (gt, ge)) => - if(isAlwaysSAT(gt)) Success - else if(isAlwaysSAT(ge)) Dep(Set(fid)) + if (definitiveALL(gt)) Success + else if (definitiveALL(ge)) Dep(Set(fid)) else Failure }) val result = if(solved.contains(Failure)) Failure else { val deps = solved.collect({ case Dep(fds) => fds }).flatten - if(deps.isEmpty) Success + if (deps.isEmpty) Success else Dep(deps) } fd -> result diff --git a/src/main/scala/leon/termination/SimpleTerminationChecker.scala b/src/main/scala/leon/termination/SimpleTerminationChecker.scala index e87e3adf6..1a8f6aaef 100644 --- a/src/main/scala/leon/termination/SimpleTerminationChecker.scala +++ b/src/main/scala/leon/termination/SimpleTerminationChecker.scala @@ -12,7 +12,7 @@ import scala.collection.mutable.{ Map => MutableMap } import scala.annotation.tailrec -class SimpleTerminationChecker(context: LeonContext, program: Program) extends TerminationChecker(context, program) { +class SimpleTerminationChecker(context: LeonContext, program: Program) extends TerminationChecker(context, program) with ComponentBuilder { val name = "T1" val description = "The simplest form of Terminator™" @@ -20,7 +20,7 @@ class SimpleTerminationChecker(context: LeonContext, program: Program) extends T private lazy val callGraph: Map[FunDef, Set[FunDef]] = program.callGraph.allCalls.groupBy(_._1).mapValues(_.map(_._2)) // one liner from hell - private lazy val components = ComponentBuilder.run(callGraph) + private lazy val components = getComponents(callGraph) val allVertices = callGraph.keySet ++ callGraph.values.flatten val sccArray = components.toArray diff --git a/src/main/scala/leon/termination/Strengthener.scala b/src/main/scala/leon/termination/Strengthener.scala new file mode 100644 index 000000000..31f076498 --- /dev/null +++ b/src/main/scala/leon/termination/Strengthener.scala @@ -0,0 +1,174 @@ +package leon +package termination + +import purescala.Trees._ +import purescala.TypeTrees._ +import purescala.TreeOps._ +import purescala.Common._ +import purescala.Definitions._ + +import scala.collection.mutable.{Set => MutableSet, Map => MutableMap} + +trait Strengthener { self : TerminationChecker with RelationComparator with RelationBuilder => + + private val strengthenedPost : MutableSet[FunDef] = MutableSet.empty + + def strengthenPostconditions(funDefs: Set[FunDef])(implicit solver: Processor with Solvable) { + // Strengthen postconditions on all accessible functions by adding size constraints + val callees : Set[FunDef] = funDefs.map(fd => self.program.callGraph.transitiveCallees(fd)).flatten + val sortedCallees : Seq[FunDef] = callees.toSeq.sortWith((fd1, fd2) => self.program.callGraph.transitivelyCalls(fd2, fd1)) + + for (funDef <- sortedCallees if !strengthenedPost(funDef) && funDef.hasBody && self.terminates(funDef).isGuaranteed) { + def strengthen(cmp: (Expr, Expr) => Expr): Boolean = { + val old = funDef.postcondition + val (res, postcondition) = { + val (res, post) = old.getOrElse(FreshIdentifier("res").setType(funDef.returnType) -> BooleanLiteral(true)) + val args = funDef.params.map(_.toVariable) + val sizePost = cmp(Tuple(funDef.params.map(_.toVariable)), res.toVariable) + (res, And(post, sizePost)) + } + + funDef.postcondition = Some(res -> postcondition) + + val prec = matchToIfThenElse(funDef.precondition.getOrElse(BooleanLiteral(true))) + val body = matchToIfThenElse(funDef.body.get) + val post = matchToIfThenElse(postcondition) + val formula = Implies(prec, Let(res, body, post)) + + if (!solver.definitiveALL(formula)) { + funDef.postcondition = old + false + } else { + true + } + } + + // test if size is smaller or equal to input + val weekConstraintHolds = strengthen(self.softDecreasing) + + if (weekConstraintHolds) { + // try to improve postcondition with strictly smaller + strengthen(self.sizeDecreasing) + } + + strengthenedPost += funDef + } + } + + sealed abstract class SizeConstraint + case object StrongDecreasing extends SizeConstraint + case object WeakDecreasing extends SizeConstraint + case object NoConstraint extends SizeConstraint + + /* + private val strengthenedApp : MutableSet[FunDef] = MutableSet.empty + + protected def strengthened(fd: FunDef): Boolean = strengthenedApp(fd) + + private val appConstraint : MutableMap[(FunDef, Identifier), SizeConstraint] = MutableMap.empty + + def applicationConstraint(fd: FunDef, id: Identifier, arg: Expr, args: Seq[Expr]): Expr = arg match { + case AnonymousFunction(fargs, body) => appConstraint.get(fd -> id) match { + case Some(StrongDecreasing) => self.sizeDecreasing(Tuple(args), Tuple(fargs.map(_.toVariable))) + case Some(WeakDecreasing) => self.softDecreasing(Tuple(args), Tuple(fargs.map(_.toVariable))) + case _ => BooleanLiteral(true) + } + case _ => BooleanLiteral(true) + } + + def strengthenApplications(funDefs: Set[FunDef])(implicit solver: Processor with Solvable) { + val transitiveFunDefs = funDefs ++ funDefs.flatMap(fd => self.program.callGraph.transitiveCallees(fd)) + val sortedFunDefs = transitiveFunDefs.toSeq.sortWith((fd1, fd2) => self.program.callGraph.transitivelyCalls(fd2, fd1)) + + for (funDef <- sortedFunDefs if !strengthenedApp(funDef) && funDef.hasBody && self.terminates(funDef).isGuaranteed) { + + val appCollector = new CollectorWithPaths[(Identifier,Expr,Expr)] { + def collect(e: Expr, path: Seq[Expr]): Option[(Identifier, Expr, Expr)] = e match { + case FunctionApplication(Variable(id), args) => Some((id, And(path), Tuple(args))) + case _ => None + } + } + + val applications = appCollector.traverse(funDef).distinct + + val funDefArgTuple = Tuple(funDef.args.map(_.toVariable)) + + val allFormulas = for ((id, path, appArgs) <- applications) yield { + val soft = Implies(path, self.softDecreasing(funDefArgTuple, appArgs)) + val hard = Implies(path, self.sizeDecreasing(funDefArgTuple, appArgs)) + id -> ((soft, hard)) + } + + val formulaMap = allFormulas.groupBy(_._1).mapValues(_.map(_._2).unzip) + + val constraints = for ((id, (weakFormulas, strongFormulas)) <- formulaMap) yield id -> { + if (solver.definitiveALL(And(weakFormulas.toSeq), funDef.precondition)) { + if (solver.definitiveALL(And(strongFormulas.toSeq), funDef.precondition)) { + StrongDecreasing + } else { + WeakDecreasing + } + } else { + NoConstraint + } + } + + val funDefHOArgs = funDef.args.map(_.id).filter(_.getType.isInstanceOf[FunctionType]).toSet + + val fiCollector = new CollectorWithPaths[(Expr, Expr, Seq[(Identifier,(FunDef, Identifier))])] { + def collect(e: Expr, path: Seq[Expr]): Option[(Expr, Expr, Seq[(Identifier,(FunDef, Identifier))])] = e match { + case FunctionInvocation(fd, args) if (funDefHOArgs intersect args.collect({ case Variable(id) => id }).toSet).nonEmpty => + Some((And(path), Tuple(args), (args zip fd.args).collect { + case (Variable(id), vd) if funDefHOArgs(id) => id -> ((fd, vd.id)) + })) + case _ => None + } + } + + val invocations = fiCollector.traverse(funDef) + val id2invocations : Seq[(Identifier, ((FunDef, Identifier), Expr, Expr))] = + invocations.flatMap(p => p._3.map(c => c._1 -> ((c._2, p._1, p._2)))) + val invocationMap : Map[Identifier, Seq[((FunDef, Identifier), Expr, Expr)]] = + id2invocations.groupBy(_._1).mapValues(_.map(_._2)) + + def constraint(id: Identifier, passings: Seq[((FunDef, Identifier), Expr, Expr)]): SizeConstraint = { + if (constraints.get(id) == Some(NoConstraint)) NoConstraint + else if (passings.exists(p => appConstraint.get(p._1) == Some(NoConstraint))) NoConstraint + else passings.foldLeft[SizeConstraint](constraints.getOrElse(id, StrongDecreasing)) { + case (constraint, (key, path, args)) => + + lazy val strongFormula = Implies(path, self.sizeDecreasing(funDefArgTuple, args)) + lazy val weakFormula = Implies(path, self.softDecreasing(funDefArgTuple, args)) + + (constraint, appConstraint.get(key)) match { + case (_, Some(NoConstraint)) => scala.sys.error("Whaaaat!?!? This shouldn't happen...") + case (_, None) => NoConstraint + case (NoConstraint, _) => NoConstraint + case (StrongDecreasing | WeakDecreasing, Some(StrongDecreasing)) => + if (solver.definitiveALL(weakFormula, funDef.precondition)) StrongDecreasing + else NoConstraint + case (StrongDecreasing, Some(WeakDecreasing)) => + if (solver.definitiveALL(strongFormula, funDef.precondition)) StrongDecreasing + else if (solver.definitiveALL(weakFormula, funDef.precondition)) WeakDecreasing + else NoConstraint + case (WeakDecreasing, Some(WeakDecreasing)) => + if (solver.definitiveALL(weakFormula, funDef.precondition)) WeakDecreasing + else NoConstraint + } + } + } + + val outers = invocationMap.mapValues(_.filter(_._1._1 != funDef)) + funDefHOArgs.foreach { id => appConstraint(funDef -> id) = constraint(id, outers.getOrElse(id, Seq.empty)) } + + val selfs = invocationMap.mapValues(_.filter(_._1._1 == funDef)) + funDefHOArgs.foreach { id => appConstraint(funDef -> id) = constraint(id, selfs.getOrElse(id, Seq.empty)) } + + strengthenedApp += funDef + } + } + */ +} + + +// vim: set ts=4 sw=4 et: diff --git a/src/main/scala/leon/termination/StructuralSize.scala b/src/main/scala/leon/termination/StructuralSize.scala index 9c0b058a7..a6f221500 100644 --- a/src/main/scala/leon/termination/StructuralSize.scala +++ b/src/main/scala/leon/termination/StructuralSize.scala @@ -9,58 +9,59 @@ import purescala.TypeTrees._ import purescala.Definitions._ import purescala.Common._ -class StructuralSize() { - import scala.collection.mutable.{Map => MutableMap} +import scala.collection.mutable.{Map => MutableMap} - private val sizeFunctionCache : MutableMap[TypeTree, TypedFunDef] = MutableMap() +trait StructuralSize { + + private val sizeCache : MutableMap[TypeTree, FunDef] = MutableMap.empty def size(expr: Expr) : Expr = { - def funDef(tpe: TypeTree, cases: => Seq[MatchCase]) = { + def funDef(ct: ClassType, cases: ClassType => Seq[MatchCase]): FunDef = { // we want to reuse generic size functions for sub-types - val argumentType = tpe match { - case CaseClassType(cd, tpes) if cd.parent.isDefined => classDefToClassType(cd.parent.get.classDef, tpes) - case _ => tpe + val (argumentType, typeParams) = ct match { + case (cct : CaseClassType) if cct.parent.isDefined => + val classDef = cct.parent.get.classDef + val tparams = classDef.tparams.map(_.tp) + (classDefToClassType(classDef, tparams), tparams) + case (ct : ClassType) => + val tparams = ct.classDef.tparams.map(_.tp) + (classDefToClassType(ct.classDef, tparams), tparams) } - sizeFunctionCache.get(argumentType) match { + sizeCache.get(argumentType) match { case Some(fd) => fd case None => - val argument = ValDef(FreshIdentifier("x"), argumentType) - val fd = new FunDef(FreshIdentifier("size", true), Nil, Int32Type, Seq(argument)) - val tfd = fd.typed(Nil) - sizeFunctionCache(argumentType) = tfd + val argument = ValDef(FreshIdentifier("x").setType(argumentType), argumentType) + val formalTParams = typeParams.map(TypeParameterDef(_)) + val fd = new FunDef(FreshIdentifier("size", true), formalTParams, Int32Type, Seq(argument)) + sizeCache(argumentType) = fd - val body = simplifyLets(matchToIfThenElse(MatchExpr(argument.toVariable, cases))) + val body = simplifyLets(matchToIfThenElse(MatchExpr(argument.toVariable, cases(argumentType)))) val postId = FreshIdentifier("res", false).setType(Int32Type) - val postSubcalls = functionCallsOf(body).map(GreaterThan(_, IntLiteral(0))).toSeq - val postRecursive = GreaterThan(Variable(postId), IntLiteral(0)) - val postcondition = And(postSubcalls :+ postRecursive) + val postcondition = GreaterThan(Variable(postId), IntLiteral(0)) fd.body = Some(body) fd.postcondition = Some(postId, postcondition) - - tfd + fd } } - def caseClassType2MatchCase(ct: ClassType): MatchCase = ct match { - case cct: CaseClassType => - val arguments = cct.fields.map(f => f -> f.id.freshen) - val argumentPatterns = arguments.map(p => WildcardPattern(Some(p._2))) - val sizes = arguments.map(p => size(Variable(p._2))) - val result = sizes.foldLeft[Expr](IntLiteral(1))(Plus(_,_)) - SimpleCase(CaseClassPattern(None, cct, argumentPatterns), result) - case _ => - sys.error("woot?") + def caseClassType2MatchCase(_c: ClassType): MatchCase = { + val c = _c.asInstanceOf[CaseClassType] // required by leon framework + val arguments = c.fields.map(f => f -> f.id.freshen) + val argumentPatterns = arguments.map(p => WildcardPattern(Some(p._2))) + val sizes = arguments.map(p => size(Variable(p._2))) + val result = sizes.foldLeft[Expr](IntLiteral(1))(Plus(_,_)) + SimpleCase(CaseClassPattern(None, c, argumentPatterns), result) } expr.getType match { - case a: AbstractClassType => - val sizeFd = funDef(a, a.knownCCDescendents.map(caseClassType2MatchCase)) - FunctionInvocation(sizeFd, Seq(expr)) - case c: CaseClassType => - val sizeFd = funDef(c, Seq(caseClassType2MatchCase(c))) - FunctionInvocation(sizeFd, Seq(expr)) + case (ct: ClassType) => + val fd = funDef(ct, _ match { + case (act: AbstractClassType) => act.knownCCDescendents map caseClassType2MatchCase + case (cct: CaseClassType) => Seq(caseClassType2MatchCase(cct)) + }) + FunctionInvocation(TypedFunDef(fd, ct.tps), Seq(expr)) case TupleType(argTypes) => argTypes.zipWithIndex.map({ case (_, index) => size(TupleSelect(expr, index + 1)) }).foldLeft[Expr](IntLiteral(0))(Plus(_,_)) @@ -68,7 +69,7 @@ class StructuralSize() { } } - def defs : Set[FunDef] = sizeFunctionCache.values.map(_.fd).toSet + def defs : Set[FunDef] = Set(sizeCache.values.toSeq : _*) } // vim: set ts=4 sw=4 et: diff --git a/src/main/scala/leon/termination/TerminationChecker.scala b/src/main/scala/leon/termination/TerminationChecker.scala index 41c46e07b..d90333484 100644 --- a/src/main/scala/leon/termination/TerminationChecker.scala +++ b/src/main/scala/leon/termination/TerminationChecker.scala @@ -6,8 +6,10 @@ package termination import purescala.Definitions._ import purescala.Trees._ -abstract class TerminationChecker(val context : LeonContext, val program : Program) extends LeonComponent { - +abstract class TerminationChecker(val context: LeonContext, val program: Program) extends LeonComponent { + + val functions = program.definedFunctions.toSet + def initialize() : Unit def terminates(funDef : FunDef) : TerminationGuarantee } @@ -27,6 +29,7 @@ abstract class NonTerminating extends TerminationGuarantee { } case class LoopsGivenInputs(justification: String, args: Seq[Expr]) extends NonTerminating +case class MaybeLoopsGivenInputs(justification: String, args: Seq[Expr]) extends NonTerminating case class CallsNonTerminating(calls: Set[FunDef]) extends NonTerminating diff --git a/src/main/scala/leon/termination/SCC.scala b/src/main/scala/leon/utils/SCC.scala similarity index 97% rename from src/main/scala/leon/termination/SCC.scala rename to src/main/scala/leon/utils/SCC.scala index 0188a4cab..f82b6262e 100644 --- a/src/main/scala/leon/termination/SCC.scala +++ b/src/main/scala/leon/utils/SCC.scala @@ -1,7 +1,7 @@ /* Copyright 2009-2014 EPFL, Lausanne */ package leon -package termination +package utils /** This could be defined anywhere, it's just that the termination checker is the only place where it is used. */ @@ -47,7 +47,7 @@ object SCC { for(v <- allVertices) { if(!indices.isDefinedAt(v)) { strongConnect(v) - } + } } components diff --git a/src/main/scala/leon/verification/DefaultTactic.scala b/src/main/scala/leon/verification/DefaultTactic.scala index cd7cd2d68..f947286b8 100644 --- a/src/main/scala/leon/verification/DefaultTactic.scala +++ b/src/main/scala/leon/verification/DefaultTactic.scala @@ -159,10 +159,10 @@ class DefaultTactic(reporter: Reporter) extends Tactic(reporter) { val toRet = if (function.hasBody) { val cleanBody = matchToIfThenElse(function.body.get) - val allPathConds = new CollectorWithPaths({ + val allPathConds = CollectorWithPaths { case expr@ArraySelect(a, i) => (expr, a, i) case expr@ArrayUpdated(a, i, _) => (expr, a, i) - }).traverse(cleanBody) + }.traverse(cleanBody) val arrayAccessConditions = allPathConds.map{ case ((expr, array, index), pathCond) => { @@ -199,7 +199,7 @@ class DefaultTactic(reporter: Reporter) extends Tactic(reporter) { } def collectWithPathCondition(matcher: Expr=>Boolean, expression: Expr) : Set[(Seq[Expr],Expr)] = { - new CollectorWithPaths({ + CollectorWithPaths({ case e if matcher(e) => e }).traverse(expression).map{ case (e, And(es)) => (es, e) diff --git a/src/test/resources/regression/termination/looping/Numeric3.scala b/src/test/resources/regression/termination/looping/Numeric3.scala new file mode 100644 index 000000000..e0616cc13 --- /dev/null +++ b/src/test/resources/regression/termination/looping/Numeric3.scala @@ -0,0 +1,8 @@ +import leon.Utils._ + +object Numeric3 { + def looping(x: Int) : Int = if (x > 0) looping(x - 1) else looping(5) +} + + +// vim: set ts=4 sw=4 et: diff --git a/src/test/resources/regression/termination/unknown/Numeric3.scala b/src/test/resources/regression/termination/unknown/Numeric3.scala deleted file mode 100644 index a375765d5..000000000 --- a/src/test/resources/regression/termination/unknown/Numeric3.scala +++ /dev/null @@ -1,10 +0,0 @@ -/* Copyright 2009-2014 EPFL, Lausanne */ - -import leon.lang._ - -object Numeric3 { - def unknown(x: Int) : Int = if (x > 0) unknown(x - 1) else unknown(5) -} - - -// vim: set ts=4 sw=4 et: diff --git a/src/test/resources/regression/termination/valid/ComplexChains.scala b/src/test/resources/regression/termination/valid/ComplexChains.scala deleted file mode 100644 index f6316acc0..000000000 --- a/src/test/resources/regression/termination/valid/ComplexChains.scala +++ /dev/null @@ -1,28 +0,0 @@ -/* Copyright 2009-2014 EPFL, Lausanne */ - -import leon.lang._ - -object ComplexChains { - - abstract class List - case class Cons(head: Int, tail: List) extends List - case class Nil() extends List - - def f1(list: List): List = list match { - case Cons(head, tail) if head > 0 => f2(Cons(1, list)) - case Cons(head, tail) if head < 0 => f3(Cons(-1, list)) - case Cons(head, tail) => f1(tail) - case Nil() => Nil() - } - - def f2(list: List): List = f3(Cons(0, list)) - - def f3(list: List): List = f1(list match { - case Cons(head, Cons(head2, Cons(head3, tail))) => tail - case Cons(head, Cons(head2, tail)) => tail - case Cons(head, tail) => tail - case Nil() => Nil() - }) -} - -// vim: set ts=4 sw=4 et: diff --git a/src/test/resources/regression/termination/valid/Termination_passing2.scala b/src/test/resources/regression/termination/valid/Termination_passing2.scala deleted file mode 100644 index 296104dca..000000000 --- a/src/test/resources/regression/termination/valid/Termination_passing2.scala +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright 2009-2014 EPFL, Lausanne */ - -import leon.lang._ - -object Termination { - def f1(x: Int, b: Boolean) : Int = { - if (b) f2(x) - else f3(x) - } - - def f2(x: Int) : Int = { - if (x < 0) 0 - else f1(x - 1, true) - } - - def f3(x: Int) : Int = { - if (x > 0) 0 - else f1(x + 1, false) - } -} - -// vim: set ts=4 sw=4 et: diff --git a/src/test/scala/leon/test/termination/TerminationRegression.scala b/src/test/scala/leon/test/termination/TerminationRegression.scala index 65f20d036..964202970 100644 --- a/src/test/scala/leon/test/termination/TerminationRegression.scala +++ b/src/test/scala/leon/test/termination/TerminationRegression.scala @@ -81,7 +81,8 @@ class TerminationRegression extends LeonTestSuite { forEachFileIn("looping") { output => val Output(report, reporter) = output val looping = report.results.filter { case (fd, guarantee) => fd.id.name.startsWith("looping") } - assert(looping.forall(_._2.isInstanceOf[LoopsGivenInputs]), "Functions " + looping.filter(!_._2.isInstanceOf[LoopsGivenInputs]).map(_._1.id) + " should loop") + assert(looping.forall(p => p._2.isInstanceOf[LoopsGivenInputs] || p._2.isInstanceOf[CallsNonTerminating]), + "Functions " + looping.filter(p => !p._2.isInstanceOf[LoopsGivenInputs] && !p._2.isInstanceOf[CallsNonTerminating]).map(_._1.id) + " should loop") val calling = report.results.filter { case (fd, guarantee) => fd.id.name.startsWith("calling") } assert(calling.forall(_._2.isInstanceOf[CallsNonTerminating]), "Functions " + calling.filter(!_._2.isInstanceOf[CallsNonTerminating]).map(_._1.id) + " should call non-terminating") val ok = report.results.filter { case (fd, guarantee) => fd.id.name.startsWith("ok") } -- GitLab