From c8bd0066b060b38f0ffec8a686e01619a6edc5f6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ali=20Sinan=20K=C3=B6ksal?= <alisinan@gmail.com>
Date: Thu, 18 Nov 2010 15:27:31 +0000
Subject: [PATCH] Inductive tactic performs induction on first argument with
 abstract type. Corresponding modification to ListWithSize to prove
 associativity of append automatically.

---
 pldi2011-testcases/MergeSort.scala  | 51 +++++++++++++++--------------
 src/purescala/InductionTactic.scala | 17 ++++++++--
 testcases/ListWithSize.scala        |  8 +++--
 3 files changed, 47 insertions(+), 29 deletions(-)

diff --git a/pldi2011-testcases/MergeSort.scala b/pldi2011-testcases/MergeSort.scala
index da575c8bb..4e0121ac4 100644
--- a/pldi2011-testcases/MergeSort.scala
+++ b/pldi2011-testcases/MergeSort.scala
@@ -2,30 +2,31 @@ import scala.collection.immutable.Set
 
 object MergeSort {
   sealed abstract class List
-  case class Cons(head:Int,tail:List) extends List
+  case class Cons(head : Int, tail : List) extends List
   case class Nil() extends List
 
-  case class Pair(fst:List,snd:List)
+  sealed abstract class PairAbs
+  case class Pair(fst : List, snd : List) extends PairAbs
 
-  def contents(l: List): Set[Int] = l match {
+  def contents(l : List) : Set[Int] = l match {
     case Nil() => Set.empty
     case Cons(x,xs) => contents(xs) ++ Set(x)
   }
 
-  def is_sorted(l: List): Boolean = l match {
+  def isSorted(l : List) : Boolean = l match {
     case Nil() => true
     case Cons(x,xs) => xs match {
       case Nil() => true
-      case Cons(y, ys) => x <= y && is_sorted(Cons(y, ys))
+      case Cons(y, ys) => x <= y && isSorted(Cons(y, ys))
     }
   }    
 
-  def length(list:List): Int = list match {
+  def size(list : List) : Int = list match {
     case Nil() => 0
-    case Cons(x,xs) => 1 + length(xs)
+    case Cons(x,xs) => 1 + size(xs)
   }
 
-  def splithelper(aList:List,bList:List,n:Int): Pair = 
+  def splithelper(aList : List, bList : List, n : Int) : Pair = 
     if (n <= 0) Pair(aList,bList)
     else
 	bList match {
@@ -33,28 +34,30 @@ object MergeSort {
     	      case Cons(x,xs) => splithelper(Cons(x,aList),xs,n-1)
 	}
 
-  def split(list:List,n:Int): Pair = splithelper(Nil(),list,n)
+  def split(list : List, n : Int): Pair = splithelper(Nil(),list,n)
 
-  def merge(aList:List, bList:List):List = (bList match {       
-    case Nil() => aList
-    case Cons(x,xs) =>
-    	 aList match {
-   	       case Nil() => bList
-   	       case Cons(y,ys) =>
-    	        if (y < x)
-    		   Cons(y,merge(ys, bList))
-     		else
-		   Cons(x,merge(aList, xs))               
-   	 }   
-  }) ensuring(res => contents(res) == contents(aList) ++ contents(bList))
+  def merge(aList : List, bList : List) : List = {
+    bList match {       
+      case Nil() => aList
+      case Cons(x,xs) =>
+        aList match {
+              case Nil() => bList
+              case Cons(y,ys) =>
+               if (y < x)
+                  Cons(y,merge(ys, bList))
+               else
+                  Cons(x,merge(aList, xs))               
+        }   
+    }
+  } ensuring(res => contents(res) == contents(aList) ++ contents(bList))
 
-  def mergeSort(list:List):List = (list match {
+  def mergeSort(list : List) : List = (list match {
     case Nil() => list
     case Cons(x,Nil()) => list
     case _ =>
-    	 val p = split(list,length(list)/2)
+    	 val p = split(list,size(list)/2)
    	 merge(mergeSort(p.fst), mergeSort(p.snd))     
-  }) ensuring(res => contents(res) == contents(list) && is_sorted(res))
+  }) ensuring(res => contents(res) == contents(list) && isSorted(res))
  
 
   def main(args: Array[String]): Unit = {
diff --git a/src/purescala/InductionTactic.scala b/src/purescala/InductionTactic.scala
index d0567d4a4..11d329519 100644
--- a/src/purescala/InductionTactic.scala
+++ b/src/purescala/InductionTactic.scala
@@ -21,6 +21,18 @@ class InductionTactic(reporter: Reporter) extends DefaultTactic(reporter) {
     })
   } 
 
+  private def firstAbsClassDef(args: VarDecls) : Option[(AbstractClassDef, VarDecl)] = {
+    val filtered = args.filter(arg =>
+      arg.getType match {
+        case AbstractClassType(_) => true
+        case _ => false
+      })
+    if (filtered.size == 0) None else (filtered.head.getType match {
+      case AbstractClassType(classDef) => Some((classDef, filtered.head))
+      case _ => scala.Predef.error("This should not happen.")
+    })
+  } 
+
   private def selectorsOfParentType(parentType: ClassType, ccd: CaseClassDef, expr: Expr) : Seq[Expr] = {
     val childrenOfSameType = ccd.fields.filter(field => field.getType == parentType)
     for (field <- childrenOfSameType) yield {
@@ -31,12 +43,11 @@ class InductionTactic(reporter: Reporter) extends DefaultTactic(reporter) {
   override def generatePostconditions(funDef: FunDef) : Seq[VerificationCondition] = {
     assert(funDef.body.isDefined)
     val defaultPost = super.generatePostconditions(funDef)
-    singleAbsClassDef(funDef.args) match {
-      case Some(classDef) =>
+    firstAbsClassDef(funDef.args) match {
+      case Some((classDef, arg)) =>
         val prec = funDef.precondition
         val post = funDef.postcondition
         val body = matchToIfThenElse(funDef.body.get)
-        val arg = funDef.args.head
         val argAsVar = arg.toVariable
 
         if (post.isEmpty) {
diff --git a/testcases/ListWithSize.scala b/testcases/ListWithSize.scala
index 7e3523bb5..4e3be387f 100644
--- a/testcases/ListWithSize.scala
+++ b/testcases/ListWithSize.scala
@@ -59,12 +59,16 @@ object ListWithSize {
     } ensuring (res => res && Cons(x,append(xs, ys)) == append(Cons(x,xs), ys))
 
     def appendAssoc(xs : List, ys : List, zs : List) : Boolean = (xs match {
-      case Nil() => (nilAppend(append(ys,zs)) && nilAppend(ys))
+      case Nil() => (nilAppendInductive(append(ys,zs)) && nilAppendInductive(ys))
       case Cons(x,xs1) => appendAssoc(xs1, ys, zs)
     }) ensuring (res => res && append(xs, append(ys, zs)) == append(append(xs,ys), zs))
 
+    @induct
+    def appendAssocInductive(xs : List, ys : List, zs : List) : Boolean =
+      (append(append(xs, ys), zs) == append(xs, append(ys, zs))) holds
+
     def sizeAppend(l1 : List, l2 : List) : Boolean = (l1 match {
-      case Nil() => nilAppend(l2)
+      case Nil() => nilAppendInductive(l2)
       case Cons(x,xs) => sizeAppend(xs, l2)
     }) ensuring(res => res && size(append(l1,l2)) == size(l1) + size(l2))
 
-- 
GitLab