From 47f8001ce8b8980c5eee5521cf45f3a67d7c3d90 Mon Sep 17 00:00:00 2001
From: Philippe Suter <philippe.suter@gmail.com>
Date: Fri, 11 Jan 2013 20:58:49 +0100
Subject: [PATCH] Termination checker.

This commit introduces a termination checker. Needless to say, it is
rather primitive. The goal is rather to set up the interfaces, and to
have something that can immediately prove the most obvious cases. The
current `SimpleTerminationChecker` implementation computes
strongly-connected components, and proves that a function `f` terminates
for all inputs if and only if:

  1. `f` has a body
  2. `f` has no precondition
  3. `f` calls only functions that terminate for all inputs or itself
      and, whenever `f` calls itself, it decreases one of its algebraic
      data type arguments.

The astute reader will note that in particular,
`SimpleTerminationChecker` cannot prove anything about:

  1. functions with a precondition
  2. mutually recursive functions
  3. recursive functions that operate on integers only

I am confident that this simple termination checker will pave the way
for future implementations, though, and that we will end up re-inventing
the wheel so many times that we'll be able to equip many trains.
---
 src/main/scala/leon/Main.scala                |  18 ++-
 src/main/scala/leon/Settings.scala            |   1 +
 src/main/scala/leon/termination/SCC.scala     |  68 ++++++++++
 .../SimpleTerminationChecker.scala            | 116 ++++++++++++++++++
 .../leon/termination/TerminationChecker.scala |  18 +++
 .../leon/termination/TerminationPhase.scala   |  20 +++
 .../leon/termination/TerminationReport.scala  |  18 +++
 7 files changed, 256 insertions(+), 3 deletions(-)
 create mode 100644 src/main/scala/leon/termination/SCC.scala
 create mode 100644 src/main/scala/leon/termination/SimpleTerminationChecker.scala
 create mode 100644 src/main/scala/leon/termination/TerminationChecker.scala
 create mode 100644 src/main/scala/leon/termination/TerminationPhase.scala
 create mode 100644 src/main/scala/leon/termination/TerminationReport.scala

diff --git a/src/main/scala/leon/Main.scala b/src/main/scala/leon/Main.scala
index 9c81f0ce7..354a91659 100644
--- a/src/main/scala/leon/Main.scala
+++ b/src/main/scala/leon/Main.scala
@@ -11,6 +11,7 @@ object Main {
       xlang.FunctionClosure,
       xlang.XlangAnalysisPhase,
       synthesis.SynthesisPhase,
+      termination.TerminationPhase,
       verification.AnalysisPhase
     )
   }
@@ -22,6 +23,7 @@ object Main {
   )
 
   lazy val topLevelOptions : Set[LeonOptionDef] = Set(
+      LeonFlagOptionDef ("termination",  "--termination", "Check program termination"),
       LeonFlagOptionDef ("synthesis",    "--synthesis",   "Partial synthesis of choose() constructs"),
       LeonFlagOptionDef ("xlang",        "--xlang",       "Support for extra program constructs (imperative,...)"),
       LeonFlagOptionDef ("parse",        "--parse",       "Checks only whether the program is valid PureScala"),
@@ -37,7 +39,7 @@ object Main {
   lazy val allOptions = allComponents.flatMap(_.definedOptions) ++ topLevelOptions
 
   def displayHelp(reporter: Reporter) {
-    reporter.info("usage: leon [--xlang] [--synthesis] [--help] [--debug=<N>] [..] <files>")
+    reporter.info("usage: leon [--xlang] [--termination] [--synthesis] [--help] [--debug=<N>] [..] <files>")
     reporter.info("")
     for (opt <- topLevelOptions.toSeq.sortBy(_.name)) {
       reporter.info("%-20s %s".format(opt.usageOption, opt.usageDesc))
@@ -100,6 +102,8 @@ object Main {
 
     // Process options we understand:
     for(opt <- leonOptions) opt match {
+      case LeonFlagOption("termination") =>
+        settings = settings.copy(termination = true, xlang = false, verify = false, synthesis = false)
       case LeonFlagOption("synthesis") =>
         settings = settings.copy(synthesis = true, xlang = false, verify = false)
       case LeonFlagOption("xlang") =>
@@ -127,7 +131,10 @@ object Main {
       }
 
     val pipeVerify: Pipeline[Program, Any] =
-      if (settings.xlang) {
+      if (settings.termination) {
+        // OK, OK, that's not really verification...
+        termination.TerminationPhase
+      } else if (settings.xlang) {
         xlang.XlangAnalysisPhase
       } else if (settings.verify) {
         verification.AnalysisPhase
@@ -152,7 +159,12 @@ object Main {
     try {
       // Run pipeline
       pipeline.run(ctx)(args.toList) match {
-        case (report: verification.VerificationReport) => reporter.info(report.summaryString)
+        case report: verification.VerificationReport =>
+          reporter.info(report.summaryString)
+
+        case report: termination.TerminationReport =>
+          reporter.info(report.summaryString)
+
         case _ =>
       }
     } catch {
diff --git a/src/main/scala/leon/Settings.scala b/src/main/scala/leon/Settings.scala
index 754fbde7f..c267ba4d4 100644
--- a/src/main/scala/leon/Settings.scala
+++ b/src/main/scala/leon/Settings.scala
@@ -33,6 +33,7 @@ object Settings {
 }
 
 case class Settings(
+  val termination: Boolean    = false,
   val synthesis: Boolean      = false,
   val xlang: Boolean          = false,
   val verify: Boolean         = true,
diff --git a/src/main/scala/leon/termination/SCC.scala b/src/main/scala/leon/termination/SCC.scala
new file mode 100644
index 000000000..c0c8f7588
--- /dev/null
+++ b/src/main/scala/leon/termination/SCC.scala
@@ -0,0 +1,68 @@
+package leon
+package termination
+
+/** This could be defined anywhere, it's just that the
+    termination checker is the only place where it is used. */
+object SCC {
+  def scc[T](graph : Map[T,Set[T]]) : (Array[Set[T]],Map[T,Int],Map[Int,Set[Int]]) = {
+    // The first part is a shameless adaptation from Wikipedia
+    val allVertices : Set[T] = graph.keySet ++ graph.values.flatten
+
+    var index = 0
+    var indices  : Map[T,Int] = Map.empty
+    var lowLinks : Map[T,Int] = Map.empty
+    var components : List[Set[T]] = Nil
+    var s : List[T] = Nil
+
+    def strongConnect(v : T) {
+      indices  = indices.updated(v, index)
+      lowLinks = lowLinks.updated(v, index)
+      index += 1
+      s = v :: s
+
+      for(w <- graph.getOrElse(v, Set.empty)) {
+        if(!indices.isDefinedAt(w)) {
+          strongConnect(w)
+          lowLinks = lowLinks.updated(v, lowLinks(v) min lowLinks(w))
+        } else if(s.contains(w)) {
+          lowLinks = lowLinks.updated(v, lowLinks(v) min indices(w))
+        }
+      }
+
+      if(lowLinks(v) == indices(v)) {
+        var c : Set[T] = Set.empty
+        var stop = false
+        do {
+          val x :: xs = s
+          c = c + x
+          s = xs
+          stop = (x == v)
+        } while(!stop);
+        components = c :: components
+      }
+    }
+
+    for(v <- allVertices) {
+      if(!indices.isDefinedAt(v)) {
+        strongConnect(v)
+      }  
+    }
+
+    // At this point, we have our components.
+    // We finish by building a graph between them.
+    // In the graph, components are represented as arrays indices.
+    val asArray = components.toArray
+    val cSize = asArray.length
+
+    val vertIDs : Map[T,Int] = allVertices.map(v =>
+      v -> (0 until cSize).find(i => asArray(i)(v)).get
+    ).toMap
+
+    val bigCallGraph : Map[Int,Set[Int]] = (0 until cSize).map({ i =>
+      val dsts = asArray(i).flatMap(v => graph.getOrElse(v, Set.empty)).map(vertIDs(_))
+      i -> dsts
+    }).toMap
+
+    (asArray,vertIDs,bigCallGraph)
+  }
+}
diff --git a/src/main/scala/leon/termination/SimpleTerminationChecker.scala b/src/main/scala/leon/termination/SimpleTerminationChecker.scala
new file mode 100644
index 000000000..840c16f25
--- /dev/null
+++ b/src/main/scala/leon/termination/SimpleTerminationChecker.scala
@@ -0,0 +1,116 @@
+package leon
+package termination
+
+import purescala.Common._
+import purescala.Definitions._
+import purescala.Trees._
+import purescala.TreeOps._
+
+import scala.collection.mutable.{Map=>MutableMap}
+
+import scala.annotation.tailrec
+
+class SimpleTerminationChecker(context : LeonContext, program : Program) extends TerminationChecker(context, program) {
+
+  val name = "T1"
+  val description = "The simplest form of Terminatorâ„¢"
+
+  private lazy val callGraph : Map[FunDef,Set[FunDef]] =
+    program.callGraph.groupBy(_._1).mapValues(_.map(_._2)) // one liner from hell
+
+  private lazy val sccTriple = SCC.scc(callGraph)
+  private lazy val sccArray : Array[Set[FunDef]] = sccTriple._1
+  private lazy val funDefToSCCIndex : Map[FunDef,Int] = sccTriple._2
+  private lazy val sccGraph : Map[Int,Set[Int]] = sccTriple._3
+
+  private def callees(funDef : FunDef) : Set[FunDef] = callGraph.getOrElse(funDef, Set.empty)
+
+  private val answerCache = MutableMap.empty[FunDef,TerminationGuarantee]
+
+  def terminates(funDef : FunDef) = answerCache.getOrElse(funDef, {
+    val g = forceCheckTermination(funDef)
+    answerCache(funDef) = g
+    g
+  })
+
+  private def forceCheckTermination(funDef : FunDef) : TerminationGuarantee = {
+    // We would have to clarify what it means to terminate.
+    // We would probably need something along the lines of:
+    //   "Terminates for all values satisfying prec."
+    if(funDef.hasPrecondition)
+      return NoGuarantee
+
+    // This is also too confusing for me to think about now.
+    if(!funDef.hasImplementation)
+      return NoGuarantee
+
+    val sccIndex   = funDefToSCCIndex(funDef)
+    val sccCallees = sccGraph(sccIndex)
+
+    // We check all functions that are in a "lower" scc. These must
+    // terminate for all inputs in any case.
+    val sccLowerCallees = sccCallees.filterNot(_ == sccIndex)
+    val lowerDefs = sccLowerCallees.map(sccArray(_)).foldLeft(Set.empty[FunDef])(_ ++ _)
+    val lowerOK = lowerDefs.forall(terminates(_).isGuaranteed)
+    if(!lowerOK)
+      return NoGuarantee
+
+    // Now all we need to do is check the functions in the same
+    // scc. But who knows, maybe none of these are called?
+    if(!sccCallees(sccIndex)) {
+      // (the distinction isn't exactly useful...)
+      if(sccCallees.isEmpty)
+        return TerminatesForAllInputs("no calls")
+      else
+        return TerminatesForAllInputs("by subcalls")
+    }
+
+    // So now we know the function is recursive (or mutually
+    // recursive). Maybe it's just self-recursive?
+    if(sccArray(sccIndex).size == 1) {
+      assert(sccArray(sccIndex) == Set(funDef))
+      // Yes it is !
+      // Now we apply a simple recipe: we check that in each (self)
+      // call, at least one argument is of an ADT type and decreases.
+      // Yes, it's that restrictive.
+      val callsOfInterest = { (oe : Option[Expr]) => 
+        oe.map { e =>
+          functionCallsOf(
+            simplifyLets(
+              matchToIfThenElse(e)
+            )
+          ).filter(_.funDef == funDef)
+        } getOrElse Set.empty[FunctionInvocation]
+      }
+
+      val callsToAnalyze = callsOfInterest(funDef.body) ++ callsOfInterest(funDef.precondition) ++ callsOfInterest(funDef.postcondition)
+
+      assert(!callsToAnalyze.isEmpty)
+
+      val funDefArgsIDs = funDef.args.map(_.id).toSet
+
+      if(callsToAnalyze.forall { fi =>
+        fi.args.exists { arg =>
+          isSubTreeOfArg(arg, funDefArgsIDs)
+        }
+      }) {
+        return TerminatesForAllInputs("decreasing")
+      } else {
+        return NoGuarantee
+      }
+    }
+
+    // Handling mutually recursive functions is beyond my willpower.
+    NoGuarantee 
+  }
+
+  private def isSubTreeOfArg(expr : Expr, args : Set[Identifier]) : Boolean = {
+    @tailrec
+    def rec(e : Expr, fst : Boolean) : Boolean = e match {
+      case Variable(id) if(args(id)) => !fst
+      case CaseClassSelector(_, cc, _) => rec(cc, false)
+      case _ => false
+    }
+    rec(expr, true)
+  }
+}
diff --git a/src/main/scala/leon/termination/TerminationChecker.scala b/src/main/scala/leon/termination/TerminationChecker.scala
new file mode 100644
index 000000000..ed54acea4
--- /dev/null
+++ b/src/main/scala/leon/termination/TerminationChecker.scala
@@ -0,0 +1,18 @@
+package leon
+package termination
+
+import purescala.Definitions._
+
+abstract class TerminationChecker(val context : LeonContext, val program : Program) extends LeonComponent {
+  
+  def terminates(funDef : FunDef) : TerminationGuarantee
+}
+
+sealed abstract class TerminationGuarantee {
+  val isGuaranteed : Boolean = false
+}
+
+case class TerminatesForAllInputs(justification : String) extends TerminationGuarantee {
+  override val isGuaranteed : Boolean = true
+}
+case object NoGuarantee extends TerminationGuarantee
diff --git a/src/main/scala/leon/termination/TerminationPhase.scala b/src/main/scala/leon/termination/TerminationPhase.scala
new file mode 100644
index 000000000..b59be3184
--- /dev/null
+++ b/src/main/scala/leon/termination/TerminationPhase.scala
@@ -0,0 +1,20 @@
+package leon
+package termination
+
+import purescala.Definitions._
+
+object TerminationPhase extends LeonPhase[Program,TerminationReport] {
+  val name = "Termination"
+  val description = "Check termination of PureScala functions"
+
+  def run(ctx : LeonContext)(program : Program) : TerminationReport = {
+    val tc = new SimpleTerminationChecker(ctx, program)
+
+    val startTime = System.currentTimeMillis
+    val results = program.definedFunctions.toList.sortWith(_ < _).map { funDef =>
+      (funDef -> tc.terminates(funDef))
+    }
+    val endTime = System.currentTimeMillis
+    new TerminationReport(results, (endTime - startTime).toDouble / 1000.0d)
+  } 
+}
diff --git a/src/main/scala/leon/termination/TerminationReport.scala b/src/main/scala/leon/termination/TerminationReport.scala
new file mode 100644
index 000000000..4a4af2a6e
--- /dev/null
+++ b/src/main/scala/leon/termination/TerminationReport.scala
@@ -0,0 +1,18 @@
+package leon
+package termination
+
+import purescala.Definitions._
+
+case class TerminationReport(val results : Seq[(FunDef,TerminationGuarantee)], val time : Double) {
+  def summaryString : String = {
+    val sb = new StringBuilder
+    sb.append("─────────────────────\n")
+    sb.append(" Termination summary \n")
+    sb.append("─────────────────────\n\n")
+    for((fd,g) <- results) {
+      sb.append("- %-30s  %-30s\n".format(fd.id.name, g.toString))
+    }
+    sb.append("\n[Analysis time: %7.3f]\n".format(time))
+    sb.toString
+  }
+}
-- 
GitLab