From 75fc27a3abae105ee503692d0fc84fc13eb2dcb1 Mon Sep 17 00:00:00 2001
From: Manos Koukoutos <emmanouil.koukoutos@epfl.ch>
Date: Fri, 11 Sep 2015 16:31:55 +0200
Subject: [PATCH] State monad

---
 library/monads/state/State.scala | 421 +++++++++++++++++++++++++++++++
 1 file changed, 421 insertions(+)
 create mode 100644 library/monads/state/State.scala

diff --git a/library/monads/state/State.scala b/library/monads/state/State.scala
new file mode 100644
index 000000000..8bdb38320
--- /dev/null
+++ b/library/monads/state/State.scala
@@ -0,0 +1,421 @@
+package leon.monads.state
+
+import leon.collection._
+import leon.lang._
+import leon.annotation._
+
+case class State[S, A](runState: S => (A, S)) {
+
+  /** Basic monadic methods */
+
+  // Monadic bind
+  @inline
+  def flatMap[B](f: A => State[S, B]): State[S, B] = {
+    State { s =>
+      val (a, newS) = runState(s)
+      f(a).runState(newS)
+    }
+  }
+
+  // All monads are also functors, and they define the map function
+  @inline
+  def map[B](f: A => B): State[S, B] = {
+    State { s =>
+      val (a, newS) = runState(s)
+      (f(a), newS)
+    }
+  }
+
+  @inline
+  def >>=[B](f: A => State[S, B]): State[S, B] = flatMap(f)
+
+  @inline
+  def >>[B](that: State[S, B]) = >>= ( _ => that )
+
+  @inline
+  // This is unfortunate, but implementing properly would just lead to an error
+  def withFilter(f: A => Boolean): State[S, A] = this
+
+  @inline
+  def eval(s: S) = runState(s)._1
+  @inline
+  def exec(s: S) = runState(s)._2
+
+  // eval with arguments flipped
+  @inline
+  def >:: (s: S) = eval(s)
+
+  /** Helpers */
+  def forever[B]: State[S, B] = this >> forever
+
+  def apply(s: S) = runState(s)
+
+}
+
+object State {
+
+  @inline
+  def get[A]: State[A, A] = State( a => (a, a) )
+
+  @inline
+  def gets[S, A](f: S => A): State[S, A] = for {
+    s <- get[S]
+  } yield f(s)
+
+  @inline
+  def put[S](s: S): State[S, Unit] = State { _ => ((), s) }
+
+  @inline
+  def modify[S](f: S => S): State[S, Unit] = for {
+    s <- get[S]
+    s2 <- put(f(s))
+  } yield s2
+
+  // Monadic unit
+  @inline
+  def unit[S, A](a: A): State[S, A] = State { s => (a, s) }
+
+  @inline
+  def state[S]: State[S, S] = State(s => (s, s))
+
+  @inline
+  // Left-to-right Kleisli composition of monads
+  def >=>[S, A, B, C](f: A => State[S, B], g: B => State[S, C], a: A): State[S, C] = {
+    for {
+      b <- f(a)
+      c <- g(b)
+    } yield c
+  }
+
+  /* Monadic-List functions */
+
+  def mapM[A, B, S](f: A => State[S, B], l: List[A]): State[S, List[B]] = {
+    l match {
+      case Nil() => unit(Nil[B]())
+      case x :: xs => for {
+        mx <- f(x)
+        mxs <- mapM(f,xs)
+      } yield mx :: mxs
+    }
+  }
+
+  def mapM_ [A, B, S] (f: A => State[S, B], l: List[A]): State[S, Unit] = {
+    l match {
+      case Nil() => unit(())
+      case x :: xs => for {
+        mx <- f(x)
+        mxs <- mapM_ (f, xs)
+      } yield ()
+    }
+  }
+
+  def sequence[S, A](l: List[State[S, A]]): State[S, List[A]] = {
+    l.foldRight[State[S, List[A]]](unit(Nil[A]())){
+      case (x, xs) => for {
+        v <- x
+        vs <- xs
+      } yield v :: vs
+    }
+  }
+
+  def filterM[S, A](f: A => State[S, Boolean], l: List[A]): State[S, List[A]] = {
+    l match {
+      case Nil() => unit(Nil[A]())
+      case x :: xs => for {
+        flg <- f(x)
+        rest <- filterM(f, xs)
+      } yield (if (flg) x :: rest else rest)
+    }
+  }
+
+  //(b -> a -> m b) -> b -> t a -> m b
+  def foldLeftM[S, A, B](f: (B, A) => State[S, B], z: B, l: List[A]): State[S, B] = {
+    l match {
+      case Nil() => unit(z)
+      case x :: xs =>
+        f(z, x) >>= ( z0 =>  foldLeftM(f, z0,xs) )
+    }
+  }
+
+  //(b -> a -> m b) -> b -> t a -> m ()
+  def foldLeftM_ [S, A, B](f: A => State[S, Unit], l: List[A]): State[S, Unit] = {
+    l match {
+      case Nil() => unit(())
+      case x :: xs =>
+        f(x) >> foldLeftM_ (f, xs)
+    }
+  }
+
+}
+
+
+object MonadStateLaws {
+  import State._
+
+  /* Monadic laws:
+   *
+   * return a >>= k  =  k a
+   * m >>= return  =  m
+   * m >>= (x -> k x >>= h)  =  (m >>= k) >>= h
+   *
+   * Note: We cannot compare State's directly because State contains an anonymous function.
+   * So we have to run them on an initial state.
+   */
+
+  def unitLeftId[A, B, S](a: A, k: A => State[S, B])(s: S): Boolean = {
+    (unit(a) >>= (k))(s) == k(a)(s)
+  }.holds
+
+  def unitRightId[S, A](m: State[S,A])(s:S): Boolean = {
+    (m >>= unit).runState(s) == m.runState(s)
+  }.holds
+
+  def assoc[S, A, B, C](m: State[S,A], k: A => State[S, B], h: B => State[S, C])(s: S) = {
+    (m >>= (x => k(x) >>= h)).runState(s) == (m >>= k >>= h).runState(s)
+  }.holds
+
+  /* This is the same as assoc, but proves in 42sec instead of 50ms!
+  def assoc2[S, A, B, C](m: State[S,A], k: A => State[S, B], h: B => State[S, C])(s: S) = {
+    val lhs = for {
+      x <- m
+      y <- for {
+        z <- k(x)
+        w <- h(z)
+      } yield w
+    } yield y
+
+    val rhs = for {
+      x <- m
+      y <- k(x)
+      z <- h(y)
+    } yield z
+
+    lhs.runState(s) == rhs.runState(s)
+  }.holds
+  */
+
+
+  /* The law connecting monad and functor instances:
+   *
+   * m >>= (return . f)  =  m map f
+   */
+
+  def mapToFlatMap[S, A, B](m: State[S, A], f: A => B)(s: S): Boolean = {
+    (m >>= (x => unit(f(x)))).runState(s) == (m map f).runState(s)
+  }.holds
+
+
+  /* Different theorems */
+
+  //@induct
+  //def mapMVsSequenceM[S, A, B](f: A => State[S, B], l: List[A])(s: S) = {
+  //  mapM(f, l).runState(s) == sequence(l map f).runState(s)
+  //}.holds
+
+
+}
+
+object ExampleCounter {
+
+  import State._
+
+  def counter(init: BigInt) = {
+
+    @inline
+    def tick = modify[BigInt](_ + 1)
+
+    init >:: (for {
+      _ <- tick
+      _ <- tick
+      _ <- tick
+      r <- get
+    } yield r)
+
+  } ensuring( _ == init + 3 )
+
+
+}
+
+object ExampleTicTacToe {
+
+  import State._
+
+  abstract class Square
+  case object UL extends Square
+  case object U  extends Square
+  case object UR extends Square
+  case object L  extends Square
+  case object C  extends Square
+  case object R  extends Square
+  case object DL extends Square
+  case object D  extends Square
+  case object DR extends Square
+
+  @inline
+  val winningSets: List[Set[Square]] = List(
+    Set(UL, U, UR),
+    Set( L, C,  R),
+    Set(DL, D, DR),
+    Set(UL, L, DL),
+    Set( U, C,  D),
+    Set(UR, R, DR),
+    Set(UR, C, DL),
+    Set(UL, C, DR)
+  )
+
+  @inline
+  def wins(sqs: Set[Square]) = winningSets exists { _ subsetOf sqs}
+
+  case class Board(xs: Set[Square], os: Set[Square])
+  case class TState(b: Board, xHasMove: Boolean) // true = X, false = O
+
+  @inline
+  val init = TState(Board(Set.empty[Square], Set.empty[Square]), true)
+
+  @inline
+  def legalMove(b: Board, sq: Square) = !(b.xs ++ b.os).contains(sq)
+
+  @inline
+  def move(sq: Square): State[TState, Unit] = {
+    modify[TState] { case TState(board, xHasMove) =>
+      TState(
+        if (xHasMove)
+          Board(board.xs ++ Set(sq), board.os)
+        else
+          Board(board.xs, board.os ++ Set(sq)),
+        if (legalMove(board, sq))
+          !xHasMove
+        else
+          xHasMove
+      )
+    }
+  }
+
+  @inline
+  def play(moves: List[Square]): Option[Boolean] = {
+    val b = foldLeftM_ (move, moves).exec(init).b
+    if (wins(b.xs)) Some(true)
+    else if (wins(b.os)) Some(false)
+    else None[Boolean]()
+  }
+
+  //def ex = {
+  //  play(List(UL, UR, C, D, DR)) == Some(true)
+  //}.holds
+
+  //def gameEnds(moves: List[Square], oneMore: Square) = {
+  //  val res1 = play(moves)
+  //  val res2 = play(moves :+ oneMore)
+  //  res1.isDefined ==> (res1 == res2)
+  //}.holds
+}
+
+object ExampleFreshen {
+
+  import State._
+  import leon.lang.string._
+
+  case class Id(name: String, id: BigInt)
+
+  abstract class Expr
+  case class EVar(id: Id) extends Expr
+  case class EAbs(v: Id, body: Expr) extends Expr
+  case class EApp(f: Expr, arg: Expr) extends Expr
+
+  case class EState(counters: Map[String, BigInt])
+
+  def addVar(name: String): State[EState, Id] = for {
+    s <- get[EState]
+    newId = if (s.counters.contains(name)) s.counters(name) + 1 else BigInt(0)
+    _ <- put( EState(s.counters + (name -> newId) ) )
+  } yield Id(name, newId)
+
+  def freshen(e: Expr): State[EState, Expr] = e match {
+    case EVar(Id(name, id)) =>
+      for {
+        s <- get
+      } yield {
+        val newId = if (s.counters.contains(name)) s.counters(name) else id
+        EVar(Id(name, newId))
+      }
+    case EAbs(v, body) =>
+      for {
+        v2    <- addVar(v.name)
+        body2 <- freshen(body)
+      } yield EAbs(v2, body2)
+    case EApp(f, arg) =>
+      for {
+        // We need to freshen both f and arg with the initial state,
+        // so we save it here
+        init <- get
+        f2   <- freshen(f)
+        _    <- put(init)
+        arg2 <- freshen(arg)
+      } yield EApp(f2, arg2)
+  }
+
+  val empty = EState(Map[String, BigInt]())
+
+  def freshenNames(e: Expr): Expr = empty >:: freshen(e)
+
+}
+
+
+object ExampleStackEval {
+
+  import State._
+
+  abstract class Expr
+  case class Plus(e1: Expr, e2: Expr) extends Expr
+  case class Minus(e1: Expr, e2: Expr) extends Expr
+  case class UMinus(e: Expr) extends Expr
+  case class Lit(i: BigInt) extends Expr
+
+  def push(i: BigInt) = State[List[BigInt], Unit]( s => ( (), i :: s) )
+  def pop: State[List[BigInt], BigInt] = for {
+    l <- get
+    _ <- put(l.tailOption.getOrElse(Nil[BigInt]()))
+  } yield l.headOption.getOrElse(0)
+
+
+  def evalM(e: Expr): State[List[BigInt], Unit] = e match {
+    case Plus(e1, e2) =>
+      for {
+        _ <- evalM(e1)
+        _ <- evalM(e2)
+        i2 <- pop
+        i1 <- pop
+        _ <- push(i1 + i2)
+      } yield(())
+    case Minus(e1, e2) =>
+      for {
+        _ <- evalM(e1)
+        _ <- evalM(e2)
+        i2 <- pop
+        i1 <- pop
+        _ <- push(i1 - i2)
+      } yield(())
+    case UMinus(e) =>
+      evalM(e) >> pop >>= (i => push(-i))
+    case Lit(i) =>
+      push(i)
+  }
+
+  def eval(e: Expr): BigInt = e match {
+    case Plus(e1, e2) =>
+      eval(e1) + eval(e2)
+    case Minus(e1, e2) =>
+      eval(e1) - eval(e2)
+    case UMinus(e) =>
+      -eval(e)
+    case Lit(i) =>
+      i
+  }
+
+  def empty = List[BigInt]()
+
+  //def evalVsEval(e: Expr) = {
+  //  evalM(e).exec(empty).head == eval(e)
+  //}.holds
+
+}
\ No newline at end of file
-- 
GitLab