From bb580e51ac2bd65375b096921d9a3a535db63fd0 Mon Sep 17 00:00:00 2001
From: Nicolas Voirol <voirol.nicolas@gmail.com>
Date: Mon, 25 Apr 2016 17:55:23 +0200
Subject: [PATCH] More structural size termination

---
 .../leon/termination/ChainComparator.scala    |  51 +------
 .../ComplexTerminationChecker.scala           |  36 ++---
 .../leon/termination/ProcessingPipeline.scala |  23 +++-
 .../leon/termination/RelationComparator.scala |  26 +++-
 .../leon/termination/RelationProcessor.scala  |   1 +
 .../leon/termination/StructuralSize.scala     | 128 ++++++++++++++++--
 src/main/scala/leon/utils/SCC.scala           |  19 +--
 7 files changed, 183 insertions(+), 101 deletions(-)

diff --git a/src/main/scala/leon/termination/ChainComparator.scala b/src/main/scala/leon/termination/ChainComparator.scala
index 0a4df39fd..5e68311c0 100644
--- a/src/main/scala/leon/termination/ChainComparator.scala
+++ b/src/main/scala/leon/termination/ChainComparator.scala
@@ -14,58 +14,9 @@ import purescala.Common._
 trait ChainComparator { self : StructuralSize =>
   val checker: TerminationChecker
 
-  private object ContainerType {
-    def unapply(c: ClassType): Option[(CaseClassType, Seq[(Identifier, TypeTree)])] = c match {
-      case cct @ CaseClassType(ccd, _) =>
-        if (cct.fields.exists(arg => isSubtypeOf(arg.getType, cct.root))) None
-        else if (ccd.hasParent && ccd.parent.get.knownDescendants.size > 1) None
-        else Some((cct, cct.fields.map(arg => arg.id -> arg.getType)))
-      case _ => None
-    }
-  }
-
-  private def flatTypesPowerset(tpe: TypeTree): Set[Expr => Expr] = {
-    def powerSetToFunSet(l: TraversableOnce[Expr => Expr]): Set[Expr => Expr] = {
-      l.toSet.subsets.filter(_.nonEmpty).map{
-        (reconss : Set[Expr => Expr]) => (e : Expr) => 
-          tupleWrap(reconss.toSeq map { f => f(e) })
-      }.toSet
-    }
-
-    def rec(tpe: TypeTree): Set[Expr => Expr] = tpe match  {
-      case ContainerType(cct, fields) =>
-        powerSetToFunSet(fields.zipWithIndex.flatMap { case ((fieldId, fieldTpe), index) =>
-          rec(fieldTpe).map(recons => (e: Expr) => recons(caseClassSelector(cct, e, fieldId)))
-        })
-      case TupleType(tpes) =>
-        powerSetToFunSet(tpes.indices.flatMap { case index =>
-          rec(tpes(index)).map(recons => (e: Expr) => recons(tupleSelect(e, index + 1, true)))
-        })
-      case _ => Set((e: Expr) => e)
-    }
-
-    rec(tpe)
-  }
-
-  private def flatType(tpe: TypeTree): Set[Expr => Expr] = {
-    def rec(tpe: TypeTree): Set[Expr => Expr] = tpe match {
-      case ContainerType(cct, fields) =>
-        fields.zipWithIndex.flatMap { case ((fieldId, fieldTpe), index) =>
-          rec(fieldTpe).map(recons => (e: Expr) => recons(caseClassSelector(cct, e, fieldId)))
-        }.toSet
-      case TupleType(tpes) =>
-        tpes.indices.flatMap { case index =>
-          rec(tpes(index)).map(recons => (e: Expr) => recons(tupleSelect(e, index + 1, true)))
-        }.toSet
-      case _ => Set((e: Expr) => e)
-    }
-
-    rec(tpe)
-  }
-
   def structuralDecreasing(e1: Expr, e2s: Seq[(Path, Expr)]): Seq[Expr] = flatTypesPowerset(e1.getType).toSeq.map {
     recons => andJoin(e2s.map { case (path, e2) =>
-      path implies GreaterThan(self.size(recons(e1)), self.size(recons(e2)))
+      path implies GreaterThan(self.fullSize(recons(e1)), self.fullSize(recons(e2)))
     })
   }
 
diff --git a/src/main/scala/leon/termination/ComplexTerminationChecker.scala b/src/main/scala/leon/termination/ComplexTerminationChecker.scala
index 6d634f088..059ebdafd 100644
--- a/src/main/scala/leon/termination/ComplexTerminationChecker.scala
+++ b/src/main/scala/leon/termination/ComplexTerminationChecker.scala
@@ -18,39 +18,41 @@ class ComplexTerminationChecker(context: LeonContext, initProgram: Program) exte
        with ChainComparator
        with Strengthener
        with RelationBuilder
-       with ChainBuilder 
-  {
-    val checker = ComplexTerminationChecker.this
-  }
+       with ChainBuilder {
+         val checker = ComplexTerminationChecker.this
+       }
 
   val modulesLexicographic =
         new StructuralSize
        with LexicographicRelationComparator
-       with ChainComparator
        with Strengthener
-       with RelationBuilder
-       with ChainBuilder 
-  {
-    val checker = ComplexTerminationChecker.this
-  }
+       with RelationBuilder {
+         val checker = ComplexTerminationChecker.this
+       }
+
+  val modulesOuter =
+        new StructuralSize
+       with ArgsOuterSizeRelationComparator
+       with Strengthener
+       with RelationBuilder {
+         val checker = ComplexTerminationChecker.this
+       }
 
   val modulesBV =
         new StructuralSize
        with BVRelationComparator
-       with ChainComparator
        with Strengthener
-       with RelationBuilder
-       with ChainBuilder 
-  {
-    val checker = ComplexTerminationChecker.this
-  }
+       with RelationBuilder {
+         val checker = ComplexTerminationChecker.this
+       }
 
   def processors = List(
     new RecursionProcessor(this, modules),
     // RelationProcessor is the only Processor which benefits from trying a different RelationComparator
     new RelationProcessor(this, modulesBV),
-    new RelationProcessor(this, modules),
+    new RelationProcessor(this, modulesOuter),
     new RelationProcessor(this, modulesLexicographic),
+    new RelationProcessor(this, modules),
     new ChainProcessor(this, modules),
     new SelfCallsProcessor(this),
     new LoopProcessor(this, modules)
diff --git a/src/main/scala/leon/termination/ProcessingPipeline.scala b/src/main/scala/leon/termination/ProcessingPipeline.scala
index 4c979cadc..94d5566d3 100644
--- a/src/main/scala/leon/termination/ProcessingPipeline.scala
+++ b/src/main/scala/leon/termination/ProcessingPipeline.scala
@@ -141,7 +141,8 @@ abstract class ProcessingPipeline(context: LeonContext, initProgram: Program) ex
         } else {
           clearedMap.get(funDef).map(Terminates).getOrElse(
             if (!running) {
-              verifyTermination(funDef)
+              val verified = verifyTermination(funDef)
+              for (fd <- verified) terminates(fd) // fill in terminationMap
               terminates(funDef)
             } else {
               if (!dependencies.exists(_.contains(funDef))) {
@@ -162,15 +163,21 @@ abstract class ProcessingPipeline(context: LeonContext, initProgram: Program) ex
     val funDefs = program.callGraph.transitiveCallees(funDef) + funDef
     val pairs = program.callGraph.allCalls.filter { case (fd1, fd2) => funDefs(fd1) && funDefs(fd2) }
     val callGraph = pairs.groupBy(_._1).mapValues(_.map(_._2))
-    val components = SCC.scc(callGraph)
+    val allComponents = SCC.scc(callGraph)
 
-    for (fd <- funDefs -- components.toSet.flatten) clearedMap(fd) = "Non-recursive"
-    val newProblems = components.filter(fds => fds.forall { fd => !terminationMap.isDefinedAt(fd) })
+    val (problemComponents, nonRec) = allComponents.partition { fds =>
+      fds.flatMap(fd => program.callGraph.transitiveCallees(fd)) exists fds
+    }
+
+    for (fd <- funDefs -- problemComponents.toSet.flatten) clearedMap(fd) = "Non-recursive"
+    val newProblems = problemComponents.filter(fds => fds.forall { fd => !terminationMap.isDefinedAt(fd) })
     newProblems.map(fds => Problem(fds.toSeq))
   }
 
-  def verifyTermination(funDef: FunDef): Unit = {
-    problems ++= generateProblems(funDef).map(_ -> 0)
+  def verifyTermination(funDef: FunDef): Set[FunDef] = {
+    reporter.debug("Verifying termination of " + funDef.id)
+    val terminationProblems = generateProblems(funDef)
+    problems ++= terminationProblems.map(_ -> 0)
 
     val it = new Iterator[(String, List[Result])] {
       def hasNext : Boolean      = running
@@ -180,7 +187,7 @@ abstract class ProcessingPipeline(context: LeonContext, initProgram: Program) ex
         val processor : Processor = processorArray(index)
         reporter.debug("Running " + processor.name)
         val result = processor.run(problem)
-        reporter.debug(" +-> " + (if (result.isDefined) "Success" else "Failure"))
+        reporter.debug(" +-> " + (if (result.isDefined) "Success" else "Failure")+ "\n")
 
         // dequeue and enqueue atomically to make sure the queue always
         // makes sense (necessary for calls to terminates(fd))
@@ -211,5 +218,7 @@ abstract class ProcessingPipeline(context: LeonContext, initProgram: Program) ex
       case Broken(fd, args) => brokenMap(fd) = (reason, args)
       case MaybeBroken(fd, args) => maybeBrokenMap(fd) = (reason, args)
     }
+
+    terminationProblems.flatMap(_.funDefs).toSet
   }
 }
diff --git a/src/main/scala/leon/termination/RelationComparator.scala b/src/main/scala/leon/termination/RelationComparator.scala
index 150759232..a7cba806d 100644
--- a/src/main/scala/leon/termination/RelationComparator.scala
+++ b/src/main/scala/leon/termination/RelationComparator.scala
@@ -18,7 +18,6 @@ trait RelationComparator { self : StructuralSize =>
 
   /** weakly decreasing: args1 >= args2 */
   def softDecreasing(args1: Seq[Expr], args2: Seq[Expr]): Expr
-
 }
 
 
@@ -29,13 +28,28 @@ trait ArgsSizeSumRelationComparator extends RelationComparator { self : Structur
   def isApplicableFor(p: Problem): Boolean = true
 
   def sizeDecreasing(args1: Seq[Expr], args2: Seq[Expr]): Expr = {
-    GreaterThan(self.size(tupleWrap(args1)), self.size(tupleWrap(args2)))
+    GreaterThan(self.fullSize(tupleWrap(args1)), self.fullSize(tupleWrap(args2)))
   }
 
   def softDecreasing(args1: Seq[Expr], args2: Seq[Expr]): Expr = {
-    GreaterEquals(self.size(tupleWrap(args1)), self.size(tupleWrap(args2)))
+    GreaterEquals(self.fullSize(tupleWrap(args1)), self.fullSize(tupleWrap(args2)))
+  }
+}
+
+
+trait ArgsOuterSizeRelationComparator extends RelationComparator { self : StructuralSize =>
+
+  val comparisonMethod = "comparing outer structural sizes of argument types"
+
+  def isApplicableFor(p: Problem): Boolean = true
+
+  def sizeDecreasing(args1: Seq[Expr], args2: Seq[Expr]): Expr = {
+    GreaterThan(self.outerSize(tupleWrap(args1)), self.outerSize(tupleWrap(args2)))
   }
 
+  def softDecreasing(args1: Seq[Expr], args2: Seq[Expr]): Expr = {
+    GreaterEquals(self.outerSize(tupleWrap(args1)), self.outerSize(tupleWrap(args2)))
+  }
 }
 
 
@@ -46,13 +60,12 @@ trait LexicographicRelationComparator extends RelationComparator { self : Struct
   def isApplicableFor(p: Problem): Boolean = true
 
   def sizeDecreasing(s1: Seq[Expr], s2: Seq[Expr]): Expr = {
-    lexicographicDecreasing(s1, s2, strict = true, sizeOfOneExpr = e => self.size(e))
+    lexicographicDecreasing(s1, s2, strict = true, sizeOfOneExpr = e => self.fullSize(e))
   }
 
   def softDecreasing(s1: Seq[Expr], s2: Seq[Expr]): Expr = {
-    lexicographicDecreasing(s1, s2, strict = false, sizeOfOneExpr = e => self.size(e))
+    lexicographicDecreasing(s1, s2, strict = false, sizeOfOneExpr = e => self.fullSize(e))
   }
-
 }
 
 // for bitvector Ints
@@ -93,7 +106,6 @@ trait BVRelationComparator extends RelationComparator { self : StructuralSize =>
     val s2 = s20.filter(_.getType == Int32Type)
     lexicographicDecreasing(s2, s1, strict = false, sizeOfOneExpr = bvSize)
   }
-
 }
 
 
diff --git a/src/main/scala/leon/termination/RelationProcessor.scala b/src/main/scala/leon/termination/RelationProcessor.scala
index c27a9a6e1..980c015f1 100644
--- a/src/main/scala/leon/termination/RelationProcessor.scala
+++ b/src/main/scala/leon/termination/RelationProcessor.scala
@@ -46,6 +46,7 @@ class RelationProcessor(
         else if (definitiveALL(ge)) Dep(Set(fid))
         else Failure
       })
+
       val result = if(solved.contains(Failure)) Failure else {
         val deps = solved.collect({ case Dep(fds) => fds }).flatten
         if (deps.isEmpty) Success
diff --git a/src/main/scala/leon/termination/StructuralSize.scala b/src/main/scala/leon/termination/StructuralSize.scala
index 9a1995365..f7554e7aa 100644
--- a/src/main/scala/leon/termination/StructuralSize.scala
+++ b/src/main/scala/leon/termination/StructuralSize.scala
@@ -14,8 +14,6 @@ import scala.collection.mutable.{Map => MutableMap}
 
 trait StructuralSize {
 
-  private val sizeCache: MutableMap[TypeTree, FunDef] = MutableMap.empty
-
   /* Absolute value for BigInt type
    *
    * def absBigInt(x: BigInt): BigInt = if (x >= 0) x else -x
@@ -79,20 +77,27 @@ trait StructuralSize {
     absFun.typed
   }
 
-  def size(expr: Expr) : Expr = {
+  private val fullCache: MutableMap[TypeTree, FunDef] = MutableMap.empty
+
+  /* Fully recursive size computation
+   *
+   * Computes (positive) size of leon types by summing up all sub-elements
+   * accessible in the type definition.
+   */
+  def fullSize(expr: Expr) : Expr = {
     def funDef(ct: ClassType, cases: ClassType => Seq[MatchCase]): FunDef = {
       // we want to reuse generic size functions for sub-types
       val classDef = ct.root.classDef
       val argumentType = classDef.typed
       val typeParams = classDef.tparams.map(_.tp)
 
-      sizeCache.get(argumentType) match {
+      fullCache.get(argumentType) match {
         case Some(fd) => fd
         case None =>
           val argument = ValDef(FreshIdentifier("x", argumentType, true))
           val formalTParams = typeParams.map(TypeParameterDef)
-          val fd = new FunDef(FreshIdentifier("size", alwaysShowUniqueID = true), formalTParams, Seq(argument), IntegerType)
-          sizeCache(argumentType) = fd
+          val fd = new FunDef(FreshIdentifier("fullSize", alwaysShowUniqueID = true), formalTParams, Seq(argument), IntegerType)
+          fullCache(argumentType) = fd
 
           val body = simplifyLets(matchToIfThenElse(matchExpr(argument.toVariable, cases(argumentType))))
           val postId = FreshIdentifier("res", IntegerType)
@@ -107,29 +112,130 @@ trait StructuralSize {
     def caseClassType2MatchCase(c: CaseClassType): MatchCase = {
       val arguments = c.fields.map(vd => FreshIdentifier(vd.id.name, vd.getType))
       val argumentPatterns = arguments.map(id => WildcardPattern(Some(id)))
-      val sizes = arguments.map(id => size(Variable(id)))
+      val sizes = arguments.map(id => fullSize(Variable(id)))
       val result = sizes.foldLeft[Expr](InfiniteIntegerLiteral(1))(Plus)
       purescala.Extractors.SimpleCase(CaseClassPattern(None, c, argumentPatterns), result)
     }
 
     expr.getType match {
       case (ct: ClassType) =>
-        val fd = funDef(ct, {
+        val fd = funDef(ct.root, {
           case (act: AbstractClassType) => act.knownCCDescendants map caseClassType2MatchCase
           case (cct: CaseClassType) => Seq(caseClassType2MatchCase(cct))
         })
         FunctionInvocation(TypedFunDef(fd, ct.tps), Seq(expr))
       case TupleType(argTypes) => argTypes.zipWithIndex.map({
-        case (_, index) => size(tupleSelect(expr, index + 1, true))
-      }).foldLeft[Expr](InfiniteIntegerLiteral(0))(Plus)
+        case (_, index) => fullSize(tupleSelect(expr, index + 1, true))
+      }).foldLeft[Expr](InfiniteIntegerLiteral(0))(plus)
       case IntegerType =>
-        FunctionInvocation(typedAbsBigIntFun, Seq(expr)) 
+        FunctionInvocation(typedAbsBigIntFun, Seq(expr))
       case Int32Type =>
         FunctionInvocation(typedAbsInt2IntegerFun, Seq(expr))
       case _ => InfiniteIntegerLiteral(0)
     }
   }
 
+  private val outerCache: MutableMap[TypeTree, FunDef] = MutableMap.empty
+
+  object ContainerType {
+    def unapply(c: ClassType): Option[(CaseClassType, Seq[(Identifier, TypeTree)])] = c match {
+      case cct @ CaseClassType(ccd, _) =>
+        if (cct.fields.exists(arg => purescala.TypeOps.isSubtypeOf(arg.getType, cct.root))) None
+        else if (ccd.hasParent && ccd.parent.get.knownDescendants.size > 1) None
+        else Some((cct, cct.fields.map(arg => arg.id -> arg.getType)))
+      case _ => None
+    }
+  }
+
+  def flatTypesPowerset(tpe: TypeTree): Set[Expr => Expr] = {
+    def powerSetToFunSet(l: TraversableOnce[Expr => Expr]): Set[Expr => Expr] = {
+      l.toSet.subsets.filter(_.nonEmpty).map{
+        (reconss : Set[Expr => Expr]) => (e : Expr) => 
+          tupleWrap(reconss.toSeq map { f => f(e) })
+      }.toSet
+    }
+
+    def rec(tpe: TypeTree): Set[Expr => Expr] = tpe match  {
+      case ContainerType(cct, fields) =>
+        powerSetToFunSet(fields.zipWithIndex.flatMap { case ((fieldId, fieldTpe), index) =>
+          rec(fieldTpe).map(recons => (e: Expr) => recons(caseClassSelector(cct, e, fieldId)))
+        })
+      case TupleType(tpes) =>
+        powerSetToFunSet(tpes.indices.flatMap { case index =>
+          rec(tpes(index)).map(recons => (e: Expr) => recons(tupleSelect(e, index + 1, true)))
+        })
+      case _ => Set((e: Expr) => e)
+    }
+
+    rec(tpe)
+  }
+
+  def flatType(tpe: TypeTree): Set[Expr => Expr] = {
+    def rec(tpe: TypeTree): Set[Expr => Expr] = tpe match {
+      case ContainerType(cct, fields) =>
+        fields.zipWithIndex.flatMap { case ((fieldId, fieldTpe), index) =>
+          rec(fieldTpe).map(recons => (e: Expr) => recons(caseClassSelector(cct, e, fieldId)))
+        }.toSet
+      case TupleType(tpes) =>
+        tpes.indices.flatMap { case index =>
+          rec(tpes(index)).map(recons => (e: Expr) => recons(tupleSelect(e, index + 1, true)))
+        }.toSet
+      case _ => Set((e: Expr) => e)
+    }
+
+    rec(tpe)
+  }
+
+  /* Recursively computes outer datastructure size
+   *
+   * Computes the structural size of a datastructure but only considers
+   * the outer-most datatype definition.
+   *
+   * eg. for List[List[T]], the inner list size is not considered
+   */
+  def outerSize(expr: Expr) : Expr = {
+    def dependencies(ct: ClassType): Set[ClassType] = {
+      def deps(ct: ClassType): Set[ClassType] = ct.fieldsTypes.collect { case ct: ClassType => ct.root }.toSet
+      utils.fixpoint((cts: Set[ClassType]) => cts ++ cts.flatMap(deps))(Set(ct))
+    }
+
+    flatType(expr.getType).foldLeft[Expr](InfiniteIntegerLiteral(0)) { case (i, f) =>
+      val e = f(expr)
+      plus(i, e.getType match {
+        case ct: ClassType =>
+          val root = ct.root
+          val fd = outerCache.getOrElse(root.classDef.typed, {
+            val id = FreshIdentifier("x", root.classDef.typed, true)
+            val fd = new FunDef(FreshIdentifier("outerSize", alwaysShowUniqueID = true),
+              root.classDef.tparams,
+              Seq(ValDef(id)),
+              IntegerType)
+            outerCache(root.classDef.typed) = fd
+
+            fd.body = Some(MatchExpr(Variable(id), root.knownCCDescendants map { cct =>
+              val args = cct.fields.map(_.id.freshen)
+              purescala.Extractors.SimpleCase(
+                CaseClassPattern(None, cct, args.map(id => WildcardPattern(Some(id)))),
+                args.foldLeft[Expr](InfiniteIntegerLiteral(1)) { case (e, id) =>
+                  plus(e, id.getType match {
+                    case ct: ClassType if dependencies(root)(ct.root) => outerSize(Variable(id))
+                    case _ => InfiniteIntegerLiteral(0)
+                  })
+                })
+            }))
+
+            val res = FreshIdentifier("res", IntegerType, true)
+            fd.postcondition = Some(Lambda(Seq(ValDef(res)), GreaterEquals(Variable(res), InfiniteIntegerLiteral(0))))
+            fd
+          })
+
+          FunctionInvocation(fd.typed(ct.tps), Seq(e))
+
+        case _ => InfiniteIntegerLiteral(0)
+      })
+    }
+  }
+
   def lexicographicDecreasing(s1: Seq[Expr], s2: Seq[Expr], strict: Boolean, sizeOfOneExpr: Expr => Expr): Expr = {
     // Note: The Equal and GreaterThan ASTs work for both BigInt and Bitvector
 
diff --git a/src/main/scala/leon/utils/SCC.scala b/src/main/scala/leon/utils/SCC.scala
index 7f5212d54..31dd7b4a5 100644
--- a/src/main/scala/leon/utils/SCC.scala
+++ b/src/main/scala/leon/utils/SCC.scala
@@ -9,7 +9,7 @@ package utils
   * This could be defined anywhere, it's just that the
   * termination checker is the only place where it is used. */
 object SCC {
-  def scc[T](graph : Map[T,Set[T]]) : List[Set[T]] = {
+  def scc[T](graph: Map[T,Set[T]]) : List[Set[T]] = {
     // The first part is a shameless adaptation from Wikipedia
     val allVertices : Set[T] = graph.keySet ++ graph.values.flatten
 
@@ -19,22 +19,22 @@ object SCC {
     var components : List[Set[T]] = Nil
     var s : List[T] = Nil
 
-    def strongConnect(v : T) {
+    def strongConnect(v: T) {
       indices  = indices.updated(v, index)
       lowLinks = lowLinks.updated(v, index)
       index += 1
       s = v :: s
 
-      for(w <- graph.getOrElse(v, Set.empty)) {
-        if(!indices.isDefinedAt(w)) {
+      for (w <- graph.getOrElse(v, Set.empty)) {
+        if (!indices.isDefinedAt(w)) {
           strongConnect(w)
           lowLinks = lowLinks.updated(v, lowLinks(v) min lowLinks(w))
-        } else if(s.contains(w)) {
+        } else if (s.contains(w)) {
           lowLinks = lowLinks.updated(v, lowLinks(v) min indices(w))
         }
       }
 
-      if(lowLinks(v) == indices(v)) {
+      if (lowLinks(v) == indices(v)) {
         var c : Set[T] = Set.empty
         var stop = false
         do {
@@ -42,13 +42,14 @@ object SCC {
           c = c + x
           s = xs
           stop = x == v
-        } while(!stop)
+        } while (!stop)
+
         components = c :: components
       }
     }
 
-    for(v <- allVertices) {
-      if(!indices.isDefinedAt(v)) {
+    for (v <- allVertices) {
+      if (!indices.isDefinedAt(v)) {
         strongConnect(v)
       }
     }
-- 
GitLab