From f513ef074a85476999294e2bc43acb3e4c31e4cf Mon Sep 17 00:00:00 2001
From: manoskouk <emmanouil.koukoutos@epfl.ch>
Date: Mon, 16 Feb 2015 20:00:09 +0100
Subject: [PATCH] Richer collection API, including HOFs

---
 library/Option.scala          |  26 +++++++
 library/collection/List.scala | 140 +++++++++++++++++++++++++++-------
 2 files changed, 140 insertions(+), 26 deletions(-)

diff --git a/library/Option.scala b/library/Option.scala
index ccf945eb1..6a687490a 100644
--- a/library/Option.scala
+++ b/library/Option.scala
@@ -33,6 +33,32 @@ sealed abstract class Option[T] {
 
   def isDefined = !isEmpty
 
+
+  // Higher-order API
+  def map[R](f: T => R) = { this match {
+    case None() => None[R]()
+    case Some(x) => Some(f(x))
+  }} ensuring { _.isDefined == this.isDefined }
+
+  def flatMap[R](f: T => Option[R]) = this match {
+    case None() => None[R]()
+    case Some(x) => f(x)
+  }
+
+  def filter(p: T => Boolean) = this match {
+    case Some(x) if p(x) => Some(x)
+    case _ => None[T]()
+  }
+
+  def withFilter(p: T => Boolean) = filter(p)
+
+  def forall(p: T => Boolean) = this match {
+    case Some(x) if !p(x) => false 
+    case _ => true
+  }
+
+  def exists(p: T => Boolean) = !forall(!p(_))
+
 }
 
 case class Some[T](v: T) extends Option[T]
diff --git a/library/collection/List.scala b/library/collection/List.scala
index f4ab47a45..a35e027c4 100644
--- a/library/collection/List.scala
+++ b/library/collection/List.scala
@@ -1,5 +1,4 @@
 /* Copyright 2009-2014 EPFL, Lausanne */
-
 package leon.collection
 
 import leon._
@@ -8,6 +7,7 @@ import leon.annotation._
 
 @library
 sealed abstract class List[T] {
+
   def size: BigInt = (this match {
     case Nil() => BigInt(0)
     case Cons(h, t) => 1 + t.size
@@ -27,20 +27,21 @@ sealed abstract class List[T] {
   def ++(that: List[T]): List[T] = (this match {
     case Nil() => that
     case Cons(x, xs) => Cons(x, xs ++ that)
-  }) ensuring { res => (res.content == this.content ++ that.content) && (res.size == this.size + that.size)}
+  }) ensuring { res => 
+    (res.content == this.content ++ that.content) && 
+    (res.size == this.size + that.size)
+  }
 
   def head: T = {
     require(this != Nil[T]())
-    this match {
-      case Cons(h, t) => h
-    }
+    val Cons(h, _) = this
+    h
   }
 
   def tail: List[T] = {
     require(this != Nil[T]())
-    this match {
-      case Cons(h, t) => t
-    }
+    val Cons(_, t) = this
+    t
   }
 
   def apply(index: BigInt): T = {
@@ -68,32 +69,40 @@ sealed abstract class List[T] {
     }
   } ensuring (res => (res.size == size) && (res.content == content))
 
-  def take(i: BigInt): List[T] = (this, i) match {
+  def take(i: BigInt): List[T] = { (this, i) match {
     case (Nil(), _) => Nil()
     case (Cons(h, t), i) =>
-      if (i == BigInt(0)) {
+      if (i <= BigInt(0)) {
         Nil()
       } else {
         Cons(h, t.take(i-1))
       }
-  }
+  }} ensuring { _.size == (
+    if      (i <= 0)         BigInt(0)
+    else if (i >= this.size) this.size 
+    else                     i
+  )}
 
-  def drop(i: BigInt): List[T] = (this, i) match {
+  def drop(i: BigInt): List[T] = { (this, i) match {
     case (Nil(), _) => Nil()
     case (Cons(h, t), i) =>
-      if (i == BigInt(0)) {
+      if (i <= BigInt(0)) {
         Cons(h, t)
       } else {
         t.drop(i-1)
       }
-  }
+  }} ensuring { _.size == (
+    if      (i <= 0)         this.size
+    else if (i >= this.size) BigInt(0)
+    else                     this.size - i
+  )}
 
   def slice(from: BigInt, to: BigInt): List[T] = {
     require(from < to && to < size && from >= 0)
     drop(from).take(to-from)
   }
 
-  def replace(from: T, to: T): List[T] = this match {
+  def replace(from: T, to: T): List[T] = { this match {
     case Nil() => Nil()
     case Cons(h, t) =>
       val r = t.replace(from, to)
@@ -102,6 +111,12 @@ sealed abstract class List[T] {
       } else {
         Cons(h, r)
       }
+  }} ensuring { res =>
+    res.size == this.size && 
+    res.content == (
+      (this.content -- Set(from)) ++
+      (if (this.content contains from) Set(to) else Set[T]())
+    )
   }
 
   private def chunk0(s: BigInt, l: List[T], acc: List[T], res: List[List[T]], s0: BigInt): List[List[T]] = l match {
@@ -125,14 +140,16 @@ sealed abstract class List[T] {
     chunk0(s, this, Nil(), Nil(), s)
   }
 
-  def zip[B](that: List[B]): List[(T, B)] = (this, that) match {
+  def zip[B](that: List[B]): List[(T, B)] = { (this, that) match {
     case (Cons(h1, t1), Cons(h2, t2)) =>
       Cons((h1, h2), t1.zip(t2))
     case (_) =>
       Nil()
-  }
+  }} ensuring { _.size == (
+    if (this.size <= that.size) this.size else that.size
+  )}
 
-  def -(e: T): List[T] = this match {
+  def -(e: T): List[T] = { this match {
     case Cons(h, t) =>
       if (e == h) {
         t - e
@@ -141,9 +158,9 @@ sealed abstract class List[T] {
       }
     case Nil() =>
       Nil()
-  }
+  }} ensuring { _.content == this.content -- Set(e) }
 
-  def --(that: List[T]): List[T] = this match {
+  def --(that: List[T]): List[T] = { this match {
     case Cons(h, t) =>
       if (that.contains(h)) {
         t -- that
@@ -152,9 +169,9 @@ sealed abstract class List[T] {
       }
     case Nil() =>
       Nil()
-  }
+  }} ensuring { _.content == this.content -- that.content }
 
-  def &(that: List[T]): List[T] = this match {
+  def &(that: List[T]): List[T] = { this match {
     case Cons(h, t) =>
       if (that.contains(h)) {
         Cons(h, t & that)
@@ -163,7 +180,7 @@ sealed abstract class List[T] {
       }
     case Nil() =>
       Nil()
-  }
+  }} ensuring { _.content == (this.content & that.content) }
 
   def pad(s: BigInt, e: T): List[T] = (this, s) match {
     case (_, s) if s <= 0 =>
@@ -174,7 +191,7 @@ sealed abstract class List[T] {
       Cons(h, t.pad(s-1, e))
   }
 
-  def find(e: T): Option[BigInt] = this match {
+  def find(e: T): Option[BigInt] = { this match {
     case Nil() => None()
     case Cons(h, t) =>
       if (h == e) {
@@ -185,7 +202,7 @@ sealed abstract class List[T] {
           case Some(i) => Some(i+1)
         }
       }
-  }
+  }} ensuring { _.isDefined == this.contains(e) }
 
   def init: List[T] = (this match {
     case Cons(h, Nil()) =>
@@ -196,6 +213,14 @@ sealed abstract class List[T] {
       Nil[T]()
   }) ensuring ( (r: List[T]) => ((r.size < this.size) || (this.size == BigInt(0))) )
 
+  def last: T = {
+    require(!isEmpty)
+    this match {
+      case Cons(h, Nil()) => h
+      case Cons(_, t) => t.last
+    }
+  }
+
   def lastOption: Option[T] = this match {
     case Cons(h, t) =>
       t.lastOption.orElse(Some(h))
@@ -290,6 +315,61 @@ sealed abstract class List[T] {
     case _ => false 
   }
 
+  // Higher-order API
+  def map[R](f: T => R): List[R] = { this match {
+    case Nil() => Nil()
+    case Cons(h, t) => f(h) :: t.map(f)
+  }} ensuring { _.size == this.size}
+
+  def foldLeft[R](z: R)(f: (R,T) => R): R = this match {
+    case Nil() => z
+    case Cons(h,t) => t.foldLeft(f(z,h))(f)
+  }
+
+  def foldRight[R](f: (T,R) => R)(z: R): R = this match {
+    case Nil() => z
+    case Cons(h, t) => f(h, t.foldRight(f)(z))
+  }
+
+  def flatMap[R](f: T => List[R]): List[R] = 
+    ListOps.flatten(this map f)
+
+  def filter(p: T => Boolean): List[T] = { this match {
+    case Nil() => Nil()
+    case Cons(h, t) if p(h) => Cons(h, t.filter(p))
+    case Cons(_, t) => t.filter(p)
+  }} ensuring { res => res.size <= this.size && res.forall(p) }
+
+  // In case we implement for-comprehensions
+  def withFilter(p: T => Boolean) = filter(p)
+
+  def forall(p: T => Boolean): Boolean = this match {
+    case Nil() => true
+    case Cons(h, t) => p(h) && t.forall(p)
+  }
+
+  def exists(p: T => Boolean) = !forall(!p(_))
+
+  def find(p: T => Boolean): Option[T] = { this match {
+    case Nil() => None()
+    case Cons(h, t) if p(h) => Some(h)
+    case Cons(_, t) => t.find(p)
+  }} ensuring { _.isDefined == exists(p) }
+
+  // FIXME: I keep getting these weird type errors
+  //def groupBy[R](f: T => R): Map[R, List[T]] = this match {
+  //  case Nil() => Map.empty[R, List[T]]
+  //  case Cons(h, t) =>
+  //    val key: R = f(h)
+  //    val rest: Map[R, List[T]] = t.groupBy(f)
+  //    val prev: List[T] = if (rest isDefinedAt key) rest(key) else Nil[T]()
+  //    (rest ++ Map((key, h :: prev))) : Map[R, List[T]]
+  //}
+
+  def takeWhile(p: T => Boolean): List[T] = { this match {
+    case Cons(h,t) if p(h) => Cons(h, t.takeWhile(p))
+    case _ => Nil[T]()
+  }} ensuring { _ forall p }
 }
 
 @ignore
@@ -327,7 +407,6 @@ object ListOps {
   }
 }
 
-
 case class Cons[T](h: T, t: List[T]) extends List[T]
 case class Nil[T]() extends List[T]
 
@@ -414,4 +493,13 @@ object ListSpecs {
   //  }) &&
   //  ((l1 ++ l2).reverse == (l2.reverse ++ l1.reverse))
   //}.holds
+  
+  //@induct
+  //def folds[T,R](l : List[T], z : R, f : (R,T) => R) = {
+  //  { l match {
+  //    case Nil() => true
+  //    case Cons(h,t) => snocReverse[T](t, h)
+  //  }} &&
+  //  l.foldLeft(z)(f) == l.reverse.foldRight((x:T,y:R) => f(y,x))(z)
+  //}.holds
 }
-- 
GitLab