diff --git a/src/main/scala/leon/Main.scala b/src/main/scala/leon/Main.scala index 9c81f0ce7f6aa94562855c3122eafc6af0d3d5e1..354a91659afba9cdf47b0d3d58305cfee84d1927 100644 --- a/src/main/scala/leon/Main.scala +++ b/src/main/scala/leon/Main.scala @@ -11,6 +11,7 @@ object Main { xlang.FunctionClosure, xlang.XlangAnalysisPhase, synthesis.SynthesisPhase, + termination.TerminationPhase, verification.AnalysisPhase ) } @@ -22,6 +23,7 @@ object Main { ) lazy val topLevelOptions : Set[LeonOptionDef] = Set( + LeonFlagOptionDef ("termination", "--termination", "Check program termination"), LeonFlagOptionDef ("synthesis", "--synthesis", "Partial synthesis of choose() constructs"), LeonFlagOptionDef ("xlang", "--xlang", "Support for extra program constructs (imperative,...)"), LeonFlagOptionDef ("parse", "--parse", "Checks only whether the program is valid PureScala"), @@ -37,7 +39,7 @@ object Main { lazy val allOptions = allComponents.flatMap(_.definedOptions) ++ topLevelOptions def displayHelp(reporter: Reporter) { - reporter.info("usage: leon [--xlang] [--synthesis] [--help] [--debug=<N>] [..] <files>") + reporter.info("usage: leon [--xlang] [--termination] [--synthesis] [--help] [--debug=<N>] [..] <files>") reporter.info("") for (opt <- topLevelOptions.toSeq.sortBy(_.name)) { reporter.info("%-20s %s".format(opt.usageOption, opt.usageDesc)) @@ -100,6 +102,8 @@ object Main { // Process options we understand: for(opt <- leonOptions) opt match { + case LeonFlagOption("termination") => + settings = settings.copy(termination = true, xlang = false, verify = false, synthesis = false) case LeonFlagOption("synthesis") => settings = settings.copy(synthesis = true, xlang = false, verify = false) case LeonFlagOption("xlang") => @@ -127,7 +131,10 @@ object Main { } val pipeVerify: Pipeline[Program, Any] = - if (settings.xlang) { + if (settings.termination) { + // OK, OK, that's not really verification... + termination.TerminationPhase + } else if (settings.xlang) { xlang.XlangAnalysisPhase } else if (settings.verify) { verification.AnalysisPhase @@ -152,7 +159,12 @@ object Main { try { // Run pipeline pipeline.run(ctx)(args.toList) match { - case (report: verification.VerificationReport) => reporter.info(report.summaryString) + case report: verification.VerificationReport => + reporter.info(report.summaryString) + + case report: termination.TerminationReport => + reporter.info(report.summaryString) + case _ => } } catch { diff --git a/src/main/scala/leon/Settings.scala b/src/main/scala/leon/Settings.scala index 754fbde7fff21aefc8398598e4fa2c685616e31f..c267ba4d42d76136e30cd635ba056af5da77729a 100644 --- a/src/main/scala/leon/Settings.scala +++ b/src/main/scala/leon/Settings.scala @@ -33,6 +33,7 @@ object Settings { } case class Settings( + val termination: Boolean = false, val synthesis: Boolean = false, val xlang: Boolean = false, val verify: Boolean = true, diff --git a/src/main/scala/leon/termination/SCC.scala b/src/main/scala/leon/termination/SCC.scala new file mode 100644 index 0000000000000000000000000000000000000000..c0c8f75887133be3897123b66bea9f9d7b9d6464 --- /dev/null +++ b/src/main/scala/leon/termination/SCC.scala @@ -0,0 +1,68 @@ +package leon +package termination + +/** This could be defined anywhere, it's just that the + termination checker is the only place where it is used. */ +object SCC { + def scc[T](graph : Map[T,Set[T]]) : (Array[Set[T]],Map[T,Int],Map[Int,Set[Int]]) = { + // The first part is a shameless adaptation from Wikipedia + val allVertices : Set[T] = graph.keySet ++ graph.values.flatten + + var index = 0 + var indices : Map[T,Int] = Map.empty + var lowLinks : Map[T,Int] = Map.empty + var components : List[Set[T]] = Nil + var s : List[T] = Nil + + def strongConnect(v : T) { + indices = indices.updated(v, index) + lowLinks = lowLinks.updated(v, index) + index += 1 + s = v :: s + + for(w <- graph.getOrElse(v, Set.empty)) { + if(!indices.isDefinedAt(w)) { + strongConnect(w) + lowLinks = lowLinks.updated(v, lowLinks(v) min lowLinks(w)) + } else if(s.contains(w)) { + lowLinks = lowLinks.updated(v, lowLinks(v) min indices(w)) + } + } + + if(lowLinks(v) == indices(v)) { + var c : Set[T] = Set.empty + var stop = false + do { + val x :: xs = s + c = c + x + s = xs + stop = (x == v) + } while(!stop); + components = c :: components + } + } + + for(v <- allVertices) { + if(!indices.isDefinedAt(v)) { + strongConnect(v) + } + } + + // At this point, we have our components. + // We finish by building a graph between them. + // In the graph, components are represented as arrays indices. + val asArray = components.toArray + val cSize = asArray.length + + val vertIDs : Map[T,Int] = allVertices.map(v => + v -> (0 until cSize).find(i => asArray(i)(v)).get + ).toMap + + val bigCallGraph : Map[Int,Set[Int]] = (0 until cSize).map({ i => + val dsts = asArray(i).flatMap(v => graph.getOrElse(v, Set.empty)).map(vertIDs(_)) + i -> dsts + }).toMap + + (asArray,vertIDs,bigCallGraph) + } +} diff --git a/src/main/scala/leon/termination/SimpleTerminationChecker.scala b/src/main/scala/leon/termination/SimpleTerminationChecker.scala new file mode 100644 index 0000000000000000000000000000000000000000..840c16f2518de7df54b3f627d5ed7146dc4cf158 --- /dev/null +++ b/src/main/scala/leon/termination/SimpleTerminationChecker.scala @@ -0,0 +1,116 @@ +package leon +package termination + +import purescala.Common._ +import purescala.Definitions._ +import purescala.Trees._ +import purescala.TreeOps._ + +import scala.collection.mutable.{Map=>MutableMap} + +import scala.annotation.tailrec + +class SimpleTerminationChecker(context : LeonContext, program : Program) extends TerminationChecker(context, program) { + + val name = "T1" + val description = "The simplest form of Terminator™" + + private lazy val callGraph : Map[FunDef,Set[FunDef]] = + program.callGraph.groupBy(_._1).mapValues(_.map(_._2)) // one liner from hell + + private lazy val sccTriple = SCC.scc(callGraph) + private lazy val sccArray : Array[Set[FunDef]] = sccTriple._1 + private lazy val funDefToSCCIndex : Map[FunDef,Int] = sccTriple._2 + private lazy val sccGraph : Map[Int,Set[Int]] = sccTriple._3 + + private def callees(funDef : FunDef) : Set[FunDef] = callGraph.getOrElse(funDef, Set.empty) + + private val answerCache = MutableMap.empty[FunDef,TerminationGuarantee] + + def terminates(funDef : FunDef) = answerCache.getOrElse(funDef, { + val g = forceCheckTermination(funDef) + answerCache(funDef) = g + g + }) + + private def forceCheckTermination(funDef : FunDef) : TerminationGuarantee = { + // We would have to clarify what it means to terminate. + // We would probably need something along the lines of: + // "Terminates for all values satisfying prec." + if(funDef.hasPrecondition) + return NoGuarantee + + // This is also too confusing for me to think about now. + if(!funDef.hasImplementation) + return NoGuarantee + + val sccIndex = funDefToSCCIndex(funDef) + val sccCallees = sccGraph(sccIndex) + + // We check all functions that are in a "lower" scc. These must + // terminate for all inputs in any case. + val sccLowerCallees = sccCallees.filterNot(_ == sccIndex) + val lowerDefs = sccLowerCallees.map(sccArray(_)).foldLeft(Set.empty[FunDef])(_ ++ _) + val lowerOK = lowerDefs.forall(terminates(_).isGuaranteed) + if(!lowerOK) + return NoGuarantee + + // Now all we need to do is check the functions in the same + // scc. But who knows, maybe none of these are called? + if(!sccCallees(sccIndex)) { + // (the distinction isn't exactly useful...) + if(sccCallees.isEmpty) + return TerminatesForAllInputs("no calls") + else + return TerminatesForAllInputs("by subcalls") + } + + // So now we know the function is recursive (or mutually + // recursive). Maybe it's just self-recursive? + if(sccArray(sccIndex).size == 1) { + assert(sccArray(sccIndex) == Set(funDef)) + // Yes it is ! + // Now we apply a simple recipe: we check that in each (self) + // call, at least one argument is of an ADT type and decreases. + // Yes, it's that restrictive. + val callsOfInterest = { (oe : Option[Expr]) => + oe.map { e => + functionCallsOf( + simplifyLets( + matchToIfThenElse(e) + ) + ).filter(_.funDef == funDef) + } getOrElse Set.empty[FunctionInvocation] + } + + val callsToAnalyze = callsOfInterest(funDef.body) ++ callsOfInterest(funDef.precondition) ++ callsOfInterest(funDef.postcondition) + + assert(!callsToAnalyze.isEmpty) + + val funDefArgsIDs = funDef.args.map(_.id).toSet + + if(callsToAnalyze.forall { fi => + fi.args.exists { arg => + isSubTreeOfArg(arg, funDefArgsIDs) + } + }) { + return TerminatesForAllInputs("decreasing") + } else { + return NoGuarantee + } + } + + // Handling mutually recursive functions is beyond my willpower. + NoGuarantee + } + + private def isSubTreeOfArg(expr : Expr, args : Set[Identifier]) : Boolean = { + @tailrec + def rec(e : Expr, fst : Boolean) : Boolean = e match { + case Variable(id) if(args(id)) => !fst + case CaseClassSelector(_, cc, _) => rec(cc, false) + case _ => false + } + rec(expr, true) + } +} diff --git a/src/main/scala/leon/termination/TerminationChecker.scala b/src/main/scala/leon/termination/TerminationChecker.scala new file mode 100644 index 0000000000000000000000000000000000000000..ed54acea4cdb6ebdd10041bedae9e104c606a7b7 --- /dev/null +++ b/src/main/scala/leon/termination/TerminationChecker.scala @@ -0,0 +1,18 @@ +package leon +package termination + +import purescala.Definitions._ + +abstract class TerminationChecker(val context : LeonContext, val program : Program) extends LeonComponent { + + def terminates(funDef : FunDef) : TerminationGuarantee +} + +sealed abstract class TerminationGuarantee { + val isGuaranteed : Boolean = false +} + +case class TerminatesForAllInputs(justification : String) extends TerminationGuarantee { + override val isGuaranteed : Boolean = true +} +case object NoGuarantee extends TerminationGuarantee diff --git a/src/main/scala/leon/termination/TerminationPhase.scala b/src/main/scala/leon/termination/TerminationPhase.scala new file mode 100644 index 0000000000000000000000000000000000000000..b59be31842d3ce4a49a7cff98a185c8f3f126b85 --- /dev/null +++ b/src/main/scala/leon/termination/TerminationPhase.scala @@ -0,0 +1,20 @@ +package leon +package termination + +import purescala.Definitions._ + +object TerminationPhase extends LeonPhase[Program,TerminationReport] { + val name = "Termination" + val description = "Check termination of PureScala functions" + + def run(ctx : LeonContext)(program : Program) : TerminationReport = { + val tc = new SimpleTerminationChecker(ctx, program) + + val startTime = System.currentTimeMillis + val results = program.definedFunctions.toList.sortWith(_ < _).map { funDef => + (funDef -> tc.terminates(funDef)) + } + val endTime = System.currentTimeMillis + new TerminationReport(results, (endTime - startTime).toDouble / 1000.0d) + } +} diff --git a/src/main/scala/leon/termination/TerminationReport.scala b/src/main/scala/leon/termination/TerminationReport.scala new file mode 100644 index 0000000000000000000000000000000000000000..4a4af2a6e95f091f65abf038e3bb2c6a8a120461 --- /dev/null +++ b/src/main/scala/leon/termination/TerminationReport.scala @@ -0,0 +1,18 @@ +package leon +package termination + +import purescala.Definitions._ + +case class TerminationReport(val results : Seq[(FunDef,TerminationGuarantee)], val time : Double) { + def summaryString : String = { + val sb = new StringBuilder + sb.append("─────────────────────\n") + sb.append(" Termination summary \n") + sb.append("─────────────────────\n\n") + for((fd,g) <- results) { + sb.append("- %-30s %-30s\n".format(fd.id.name, g.toString)) + } + sb.append("\n[Analysis time: %7.3f]\n".format(time)) + sb.toString + } +}