From 89c64e8ba610e737db95f73667c00e24b3111f99 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Pit-Claudel?= <clement.pit-claudel@epfl.ch>
Date: Fri, 27 Dec 2024 01:32:09 +0100
Subject: [PATCH] server: Add a new endpoint (/admin/status)

---
 .../cs214/webapp/server/web/ServerApp.scala   | 15 +++++++---
 .../cs214/webapp/server/web/WebServer.scala   |  4 +++
 .../webapp/server/web/WebServerRoutes.scala   |  7 +++++
 .../src/main/scala/cs214/webapp/Common.scala  |  7 +++++
 .../main/scala/cs214/webapp/Messages.scala    | 29 +++++++++++++++++++
 5 files changed, 58 insertions(+), 4 deletions(-)

diff --git a/jvm/src/main/scala/cs214/webapp/server/web/ServerApp.scala b/jvm/src/main/scala/cs214/webapp/server/web/ServerApp.scala
index 78d9022..e97b6d0 100644
--- a/jvm/src/main/scala/cs214/webapp/server/web/ServerApp.scala
+++ b/jvm/src/main/scala/cs214/webapp/server/web/ServerApp.scala
@@ -6,7 +6,7 @@ import java.util.concurrent.ScheduledExecutorService
 import java.util.concurrent.atomic.AtomicReference
 import java.time.Instant
 
-import scala.collection.mutable
+import scala.collection.{concurrent, mutable}
 import scala.util.{Success, Failure, Try}
 import scala.concurrent.duration.Duration
 
@@ -64,14 +64,21 @@ private[web] abstract class ServerApp:
   def lastActivity = _lastActivity.get()
 
   private val channels =
-    mutable.Map[UserId, mutable.Set[cask.WsChannelActor]]().withDefault(_ => mutable.Set())
+    concurrent.TrieMap[UserId, mutable.Set[cask.WsChannelActor]]().withDefault(_ => mutable.Set())
+
+  def instanceInfo = // No lock needed: read-only or thread-safe members
+    AdminStatusInstanceInfo(
+      id = instanceId, info = appInfo,
+      creationTime = creationTime, lastActivity = lastActivity,
+      registeredUserIds = registeredUserIds, connectedClients = channels.keySet.toList.sorted
+    )
 
   def connect(userId: UserId)
              (implicit cc: castor.Context, log: cask.Logger): cask.WebsocketResult = instanceLock.synchronized:
     recordActivity()
     cask.WsHandler: channel =>
       channels.getOrElseUpdate(userId, mutable.Set()).add(channel)
-      println(f"[${appInfo.id}/$instanceId] client \"$userId\" connected")
+      println(f"[${appInfo.id}/$instanceId/$userId] client connected")
       send(userId):
         val view = respondToNewClient(userId)
         EventResponse.Wire.encode(Success(List(Action.Render(view))))
@@ -87,7 +94,7 @@ private[web] abstract class ServerApp:
   def disconnect(userId: UserId, channel: cask.endpoints.WsChannelActor): Unit = instanceLock.synchronized:
     channels(userId).remove(channel)
     if channels(userId).isEmpty then channels.remove(userId)
-    println(f"[${appInfo.id}/$instanceId] client \"$userId\" disconnected")
+    println(f"[${appInfo.id}/$instanceId/$userId] client disconnected")
 
   /** Enumerates clients connected to this instance. */
   def connectedClients: Seq[UserId] = instanceLock.synchronized:
diff --git a/jvm/src/main/scala/cs214/webapp/server/web/WebServer.scala b/jvm/src/main/scala/cs214/webapp/server/web/WebServer.scala
index 5906c37..edc86aa 100644
--- a/jvm/src/main/scala/cs214/webapp/server/web/WebServer.scala
+++ b/jvm/src/main/scala/cs214/webapp/server/web/WebServer.scala
@@ -55,6 +55,10 @@ object WebServer:
     case _ =>
       register(StateMachineServerAppFactory(sm))
 
+  /** Get the collection of all running instances. */
+  def runningInstances: concurrent.TrieMap[AppId, ServerApp] =
+    instances.snapshot()
+
   /** Computes an unused instance ID and calls `body` with it. */
   private def withFreshInstanceId[T](body: InstanceId => T): T =
     instances.synchronized:
diff --git a/jvm/src/main/scala/cs214/webapp/server/web/WebServerRoutes.scala b/jvm/src/main/scala/cs214/webapp/server/web/WebServerRoutes.scala
index 631e620..9739aab 100644
--- a/jvm/src/main/scala/cs214/webapp/server/web/WebServerRoutes.scala
+++ b/jvm/src/main/scala/cs214/webapp/server/web/WebServerRoutes.scala
@@ -77,6 +77,13 @@ private[server] final case class WebServerRoutes()(using cc: castor.Context, log
         else
           cask.Response(f"Unknown app '$appName'", 400)
 
+  @cask.getJson(f"${Endpoints.Admin.status}")
+  def adminStatus(): cask.Response[JsonData] =
+    AdminStatusResponseEncoder.encode:
+      AdminStatusResponse(
+        apps = WebServer.runningInstances.values.toSeq.map(_.instanceInfo)
+      )
+
   @checkOriginHeader
   @cask.websocket(f"${Endpoints.WebSocket}/:instanceId/:userId")
   def websocket(instanceId: String, userId: String, request: cask.Request): cask.WebsocketResult =
diff --git a/shared/src/main/scala/cs214/webapp/Common.scala b/shared/src/main/scala/cs214/webapp/Common.scala
index abf60db..ef8461a 100644
--- a/shared/src/main/scala/cs214/webapp/Common.scala
+++ b/shared/src/main/scala/cs214/webapp/Common.scala
@@ -42,6 +42,13 @@ object Endpoints:
     object createInstance extends Api("create-instance")
     object instanceInfo extends Api("instance-info")
 
+  sealed abstract case class Admin(path: String):
+    override def toString = f"${Admin.root}/$path"
+
+  object Admin:
+    val root = "/admin"
+    object status extends Admin("status")
+
   val App = "/app"
   val WebSocket = "/ws"
 
diff --git a/shared/src/main/scala/cs214/webapp/Messages.scala b/shared/src/main/scala/cs214/webapp/Messages.scala
index 46954f3..94b3b89 100644
--- a/shared/src/main/scala/cs214/webapp/Messages.scala
+++ b/shared/src/main/scala/cs214/webapp/Messages.scala
@@ -108,3 +108,32 @@ object InstanceInfoResponse:
         js("wsEndpoint").str,
         js("shareUrl").str
       )
+
+/** A response to the admin/status query */
+
+case class AdminStatusInstanceInfo(
+  id: AppId, info: AppInfo,
+  creationTime: java.time.Instant, lastActivity: java.time.Instant,
+  registeredUserIds: Seq[UserId], connectedClients: Seq[UserId]
+)
+
+object AdminStatusInstanceInfoEncoder extends Encoder[AdminStatusInstanceInfo]:
+  def encode(t: AdminStatusInstanceInfo): ujson.Value = t match
+    case AdminStatusInstanceInfo(id, info, creationTime, lastActivity, registeredUserIds, connectedClients) =>
+      Obj(
+        "id" -> id,
+        "info" -> AppInfoWire.encode(info),
+        "creationTime" -> f"$creationTime",
+        "lastActivity" -> f"$lastActivity",
+        "registeredUserIds" -> ujson.Arr(registeredUserIds.map(ujson.Str(_))*),
+        "connectedClients" -> ujson.Arr(connectedClients.map(ujson.Str(_))*),
+      )
+
+case class AdminStatusResponse(apps: Seq[AdminStatusInstanceInfo])
+
+object AdminStatusResponseEncoder extends Encoder[AdminStatusResponse]:
+  def encode(t: AdminStatusResponse): ujson.Value = t match
+    case AdminStatusResponse(apps) =>
+      Obj("apps" -> ujson.Arr(
+        apps.sortBy(_.id).map(AdminStatusInstanceInfoEncoder.encode)*
+      ))
-- 
GitLab