From 62fb2505cac10d14e1c75496f059c082bf7c8796 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Pit-Claudel?= <clement.pit-claudel@epfl.ch>
Date: Thu, 28 Nov 2024 00:06:26 +0100
Subject: [PATCH] Run the HTTP and WebSocket servers on the same port

Use Cask's built-in websocket implementation instead of a separate
`java_websocket` server.

* jvm/src/main/scala/cs214/webapp/server/web/WebServer.scala:
  Remove the `WS_PORT` parameter.
* jvm/src/main/scala/cs214/webapp/server/web/WebServerRoutes.scala:
  (getAppInfo): Use `HTTP_PORT` instead of `WS_PORT` in `wsEndpoint`
  (websocket): New endpoint.
* jvm/src/main/scala/cs214/webapp/server/web/WebSocketsCollection.scala:
  Remove `java_websocket` imports.
  Change from `WebSocket` to `cask.WsChannelActor`.
  Rename `appId` to `instanceId` where appropriate.
  (WebSocketsCollection): Remove the `port` parameter.
  (connect): New function, replacing the previous `WebSocketServer` instance.
* shared/src/main/scala/cs214/webapp/Common.scala:
  (WS_PORT): Remove.
  (WebSocket): New endpoint.
---
 .../cs214/webapp/server/web/WebServer.scala   |   2 +-
 .../webapp/server/web/WebServerRoutes.scala   |   9 +-
 .../server/web/WebSocketsCollection.scala     | 100 ++++++------------
 .../src/main/scala/cs214/webapp/Common.scala  |   9 +-
 4 files changed, 41 insertions(+), 79 deletions(-)

diff --git a/jvm/src/main/scala/cs214/webapp/server/web/WebServer.scala b/jvm/src/main/scala/cs214/webapp/server/web/WebServer.scala
index 9fc9753..81dcc93 100644
--- a/jvm/src/main/scala/cs214/webapp/server/web/WebServer.scala
+++ b/jvm/src/main/scala/cs214/webapp/server/web/WebServer.scala
@@ -33,7 +33,7 @@ object WebServer:
   /** Mapping from app instance ids to the attached running clocks */
   private[web] val clocks: concurrent.Map[InstanceId, RunningClock] = concurrent.TrieMap()
 
-  private[web] lazy val webSocketServer: WebSocketsCollection = new WebSocketsCollection(Config.WS_PORT)
+  private[web] lazy val webSocketServer: WebSocketsCollection = WebSocketsCollection()
 
   /** Registers the given state-machine based app. */
   def register[E, V, S](sm: StateMachine[E, V, S]): Unit =
diff --git a/jvm/src/main/scala/cs214/webapp/server/web/WebServerRoutes.scala b/jvm/src/main/scala/cs214/webapp/server/web/WebServerRoutes.scala
index ae1ada9..e42a9ac 100644
--- a/jvm/src/main/scala/cs214/webapp/server/web/WebServerRoutes.scala
+++ b/jvm/src/main/scala/cs214/webapp/server/web/WebServerRoutes.scala
@@ -53,8 +53,8 @@ private[server] final case class WebServerRoutes()(using cc: castor.Context, log
     val response: cask.Response[JsonData] =
       if WebServer.apps.contains(instanceId) then
         val app = WebServer.apps(instanceId).instance
-        val shareUrl = f"http://$hostAddress:${Config.HTTP_PORT}/app/${app.appInfo.id}/$instanceId/"
-        val wsEndpoint = f"ws://{{hostName}}:${Config.WS_PORT}/$instanceId/{{userId}}"
+        val shareUrl = f"http://$hostAddress:${Config.HTTP_PORT}${Endpoints.Api.root}/${app.appInfo.id}/$instanceId/"
+        val wsEndpoint = f"ws://{{hostName}}:${Config.HTTP_PORT}${Endpoints.WebSocket}/$instanceId/{{userId}}"
         val response = InstanceInfoResponse(instanceId, app.registeredUsers, wsEndpoint, shareUrl)
         InstanceInfoResponse.Wire.encode(response)
       else
@@ -72,5 +72,8 @@ private[server] final case class WebServerRoutes()(using cc: castor.Context, log
         CreateInstanceResponse.Wire.encode(CreateInstanceResponse(appId))
     response
 
-  WebServer.webSocketServer.run()
+  @cask.websocket(f"/${Endpoints.WebSocket}/:instanceId/:userId")
+  def websocket(instanceId: String, userId: String): cask.WebsocketResult =
+    WebServer.webSocketServer.connect(instanceId, userId)
+
   initialize()
diff --git a/jvm/src/main/scala/cs214/webapp/server/web/WebSocketsCollection.scala b/jvm/src/main/scala/cs214/webapp/server/web/WebSocketsCollection.scala
index c874a46..1c2b614 100644
--- a/jvm/src/main/scala/cs214/webapp/server/web/WebSocketsCollection.scala
+++ b/jvm/src/main/scala/cs214/webapp/server/web/WebSocketsCollection.scala
@@ -1,20 +1,13 @@
 package cs214.webapp
 package server.web
 
-import org.java_websocket.WebSocket
-import org.java_websocket.handshake.ClientHandshake
-import org.java_websocket.server.WebSocketServer
-import java.net.{InetSocketAddress, URLDecoder}
 import scala.collection.mutable
 import scala.util.{Success, Failure, Try}
 
 /** A collection of websockets organized by app Id and client */
-private[web] final class WebSocketsCollection(val port: Int):
-  def onMessageReceive(appId: InstanceId, uid: UserId, msg: ujson.Value): Unit =
-    WebServer.handleMessage(appId, uid, msg)
-
+private[web] final class WebSocketsCollection():
   // When the user joins the app, the projection of the current state is sent to them
-  def onClientConnect(instanceId: InstanceId, userId: UserId): Unit =
+  private def onClientConnect(instanceId: InstanceId, userId: UserId): Unit =
     val instance = WebServer.apps(instanceId)
     val newInstance = instance.copy(connectedUsersCount = instance.connectedUsersCount + 1)
     WebServer.apps(instanceId) = newInstance
@@ -23,7 +16,7 @@ private[web] final class WebSocketsCollection(val port: Int):
       EventResponse.Wire.encode(Success(List(Action.Render(js))))
     println(f"[${newInstance.instance.appInfo.id}][$instanceId] client \"$userId\" connected")
 
-  def onClientDisconnect(instanceId: InstanceId, userId: UserId): Unit =
+  private def onClientDisconnect(instanceId: InstanceId, userId: UserId): Unit =
     val instance = WebServer.apps(instanceId)
     val newInstance = instance.copy(connectedUsersCount = instance.connectedUsersCount - 1)
     WebServer.apps(instanceId) = newInstance
@@ -32,9 +25,9 @@ private[web] final class WebSocketsCollection(val port: Int):
     println(f"[$appId][$instanceId] client \"$userId\" disconnected")
 
   /** All the sessions currently in use, mapping instance ids to an actual
-    * socket object that a client owns
+    * channel objects that a client owns
     */
-  private val sessions: mutable.Map[InstanceId, Seq[(UserId, WebSocket)]] =
+  private val sessions: mutable.Map[InstanceId, Seq[(UserId, cask.WsChannelActor)]] =
     mutable.Map()
 
   /** Initialize an empty session list for the given app instance */
@@ -42,67 +35,34 @@ private[web] final class WebSocketsCollection(val port: Int):
     require(!sessions.contains(instanceId))
     sessions(instanceId) = Seq.empty
 
-  /** Runs k with parameters parsed from the websocket connection path. [[k]] is
-    * run while synchronizing on [[sessions]]
-    *
-    * The connection should be on "ws://…/[app_instance_id]/[user_id]" for the
-    * parsing to function properly.
-    */
-  private def withSessionParams[T](socket: WebSocket)(k: (InstanceId, UserId) => T) =
-    sessions.synchronized:
-      try
-        val components = socket.getResourceDescriptor.split("/").takeRight(2)
-        val decoded = components.map(s => URLDecoder.decode(s, "UTF-8"))
-        decoded match
-          case Array(appId, userId) =>
-            if !sessions.contains(appId) then
-              throw IllegalArgumentException("Error: Invalid app ID")
-            k(appId, userId)
-          case _ => throw Exception("Error: Invalid path")
-      catch
-        case t =>
-          socket.send(SocketResponseWire.encode(util.Failure(t)).toString)
-          socket.close()
-
-  /** A single websocket server handling multiple apps and clients. */
-  private val server: WebSocketServer = new WebSocketServer(InetSocketAddress("0.0.0.0", port)):
-    override def onOpen(socket: WebSocket, handshake: ClientHandshake): Unit =
-      withSessionParams(socket): (appId, clientId) =>
-        sessions(appId) = sessions(appId) :+ (clientId, socket)
-        onClientConnect(appId, clientId)
-
-    override def onClose(socket: WebSocket, code: Int, reason: String, remote: Boolean): Unit =
-      withSessionParams(socket): (appId, clientId) =>
-        sessions(appId) = sessions(appId).filter(_._2 != socket) // Unregister the session
-        onClientDisconnect(appId, clientId)
-
-    override def onMessage(socket: WebSocket, message: String): Unit =
-      withSessionParams(socket): (appId, clientId) =>
-        onMessageReceive(appId, clientId, ujson.read(message))
-
-    override def onError(socket: WebSocket, ex: Exception): Unit =
-      // Only report the error, onClosed is called even when an error occurs
-      throw new RuntimeException(ex)
-
-    override def onStart(): Unit =
-      val addr = server.getAddress
-      println(s"[WebSocket] server started on ${addr.getHostName}:${addr.getPort}.")
-
-  /** Starts the server asynchronously. */
-  def run(): Unit =
-    server.setReuseAddr(true) // Ignore leftover connections from pending processes
-    Thread(() => server.run(), "Socket Thread").start()
-
-  /** Enumerates clients connected to [[appId]]. */
-  def connectedClients(appId: InstanceId): Seq[UserId] =
-    sessions.get(appId).map(_.map(_._1).distinct).getOrElse(Seq())
+  def connect(instanceId: String, userId: String)
+             (implicit cc: castor.Context, log: cask.Logger): cask.WebsocketResult =
+    if WebServer.apps.contains(instanceId) then
+      cask.WsHandler: channel =>
+        sessions(instanceId) = sessions(instanceId) :+ (userId, channel)
+        onClientConnect(instanceId, userId)
+        cask.WsActor {
+          case cask.Ws.Error(e) =>
+            throw e
+          case cask.Ws.Close(code, reason) =>
+            sessions(instanceId) = sessions(instanceId).filter(_._2 != channel) // Unregister the session
+            onClientDisconnect(instanceId, userId)
+          case cask.Ws.Text(data) =>
+            WebServer.handleMessage(instanceId, userId, ujson.read(data))
+        }
+    else
+      cask.Response(f"Unknown instance id $instanceId", 400)
+
+  /** Enumerates clients connected to [[instanceId]]. */
+  def connectedClients(instanceId: InstanceId): Seq[UserId] =
+    sessions.get(instanceId).map(_.map(_._1).distinct).getOrElse(Seq())
 
   /** Sends a message to a specific client. */
-  def send(appId: InstanceId, clientId: UserId)(message: ujson.Value) =
+  def send(instanceId: InstanceId, userId: UserId)(message: ujson.Value) =
     val wrapped = SocketResponseWire.encode(util.Success(message)).toString
     sessions.synchronized:
       for
-        (userId, socket) <- sessions.getOrElse(appId, Seq.empty)
-        if userId == clientId
+        (userId, channel) <- sessions.getOrElse(instanceId, Seq.empty)
+        if userId == userId
       do
-        socket.send(wrapped)
+        channel.send(cask.Ws.Text(wrapped))
diff --git a/shared/src/main/scala/cs214/webapp/Common.scala b/shared/src/main/scala/cs214/webapp/Common.scala
index 18b5597..b868339 100644
--- a/shared/src/main/scala/cs214/webapp/Common.scala
+++ b/shared/src/main/scala/cs214/webapp/Common.scala
@@ -1,9 +1,6 @@
 package cs214.webapp
 
 object Config:
-  /** Which port the websocket server uses */
-  val WS_PORT = 9090
-
   /** Which port the HTTP server uses */
   val HTTP_PORT = 8080
 
@@ -33,14 +30,16 @@ trait WireFormat[T] extends Encoder[T] with Decoder[T]
 /** HTTP endpoints */
 object Endpoints:
   sealed abstract case class Api(path: String):
-    val root = "/api"
-    override def toString = f"$root/$path"
+    override def toString = f"${Api.root}/$path"
 
   object Api:
+    val root = "/api"
     object listApps extends Api("list-apps")
     object createInstance extends Api("create-instance")
     object instanceInfo extends Api("instance-info")
 
+  val WebSocket = "/ws"
+
 trait RegistrationProvider:
   def register(): Unit
 
-- 
GitLab