From 62fb2505cac10d14e1c75496f059c082bf7c8796 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Pit-Claudel?= <clement.pit-claudel@epfl.ch> Date: Thu, 28 Nov 2024 00:06:26 +0100 Subject: [PATCH] Run the HTTP and WebSocket servers on the same port Use Cask's built-in websocket implementation instead of a separate `java_websocket` server. * jvm/src/main/scala/cs214/webapp/server/web/WebServer.scala: Remove the `WS_PORT` parameter. * jvm/src/main/scala/cs214/webapp/server/web/WebServerRoutes.scala: (getAppInfo): Use `HTTP_PORT` instead of `WS_PORT` in `wsEndpoint` (websocket): New endpoint. * jvm/src/main/scala/cs214/webapp/server/web/WebSocketsCollection.scala: Remove `java_websocket` imports. Change from `WebSocket` to `cask.WsChannelActor`. Rename `appId` to `instanceId` where appropriate. (WebSocketsCollection): Remove the `port` parameter. (connect): New function, replacing the previous `WebSocketServer` instance. * shared/src/main/scala/cs214/webapp/Common.scala: (WS_PORT): Remove. (WebSocket): New endpoint. --- .../cs214/webapp/server/web/WebServer.scala | 2 +- .../webapp/server/web/WebServerRoutes.scala | 9 +- .../server/web/WebSocketsCollection.scala | 100 ++++++------------ .../src/main/scala/cs214/webapp/Common.scala | 9 +- 4 files changed, 41 insertions(+), 79 deletions(-) diff --git a/jvm/src/main/scala/cs214/webapp/server/web/WebServer.scala b/jvm/src/main/scala/cs214/webapp/server/web/WebServer.scala index 9fc9753..81dcc93 100644 --- a/jvm/src/main/scala/cs214/webapp/server/web/WebServer.scala +++ b/jvm/src/main/scala/cs214/webapp/server/web/WebServer.scala @@ -33,7 +33,7 @@ object WebServer: /** Mapping from app instance ids to the attached running clocks */ private[web] val clocks: concurrent.Map[InstanceId, RunningClock] = concurrent.TrieMap() - private[web] lazy val webSocketServer: WebSocketsCollection = new WebSocketsCollection(Config.WS_PORT) + private[web] lazy val webSocketServer: WebSocketsCollection = WebSocketsCollection() /** Registers the given state-machine based app. */ def register[E, V, S](sm: StateMachine[E, V, S]): Unit = diff --git a/jvm/src/main/scala/cs214/webapp/server/web/WebServerRoutes.scala b/jvm/src/main/scala/cs214/webapp/server/web/WebServerRoutes.scala index ae1ada9..e42a9ac 100644 --- a/jvm/src/main/scala/cs214/webapp/server/web/WebServerRoutes.scala +++ b/jvm/src/main/scala/cs214/webapp/server/web/WebServerRoutes.scala @@ -53,8 +53,8 @@ private[server] final case class WebServerRoutes()(using cc: castor.Context, log val response: cask.Response[JsonData] = if WebServer.apps.contains(instanceId) then val app = WebServer.apps(instanceId).instance - val shareUrl = f"http://$hostAddress:${Config.HTTP_PORT}/app/${app.appInfo.id}/$instanceId/" - val wsEndpoint = f"ws://{{hostName}}:${Config.WS_PORT}/$instanceId/{{userId}}" + val shareUrl = f"http://$hostAddress:${Config.HTTP_PORT}${Endpoints.Api.root}/${app.appInfo.id}/$instanceId/" + val wsEndpoint = f"ws://{{hostName}}:${Config.HTTP_PORT}${Endpoints.WebSocket}/$instanceId/{{userId}}" val response = InstanceInfoResponse(instanceId, app.registeredUsers, wsEndpoint, shareUrl) InstanceInfoResponse.Wire.encode(response) else @@ -72,5 +72,8 @@ private[server] final case class WebServerRoutes()(using cc: castor.Context, log CreateInstanceResponse.Wire.encode(CreateInstanceResponse(appId)) response - WebServer.webSocketServer.run() + @cask.websocket(f"/${Endpoints.WebSocket}/:instanceId/:userId") + def websocket(instanceId: String, userId: String): cask.WebsocketResult = + WebServer.webSocketServer.connect(instanceId, userId) + initialize() diff --git a/jvm/src/main/scala/cs214/webapp/server/web/WebSocketsCollection.scala b/jvm/src/main/scala/cs214/webapp/server/web/WebSocketsCollection.scala index c874a46..1c2b614 100644 --- a/jvm/src/main/scala/cs214/webapp/server/web/WebSocketsCollection.scala +++ b/jvm/src/main/scala/cs214/webapp/server/web/WebSocketsCollection.scala @@ -1,20 +1,13 @@ package cs214.webapp package server.web -import org.java_websocket.WebSocket -import org.java_websocket.handshake.ClientHandshake -import org.java_websocket.server.WebSocketServer -import java.net.{InetSocketAddress, URLDecoder} import scala.collection.mutable import scala.util.{Success, Failure, Try} /** A collection of websockets organized by app Id and client */ -private[web] final class WebSocketsCollection(val port: Int): - def onMessageReceive(appId: InstanceId, uid: UserId, msg: ujson.Value): Unit = - WebServer.handleMessage(appId, uid, msg) - +private[web] final class WebSocketsCollection(): // When the user joins the app, the projection of the current state is sent to them - def onClientConnect(instanceId: InstanceId, userId: UserId): Unit = + private def onClientConnect(instanceId: InstanceId, userId: UserId): Unit = val instance = WebServer.apps(instanceId) val newInstance = instance.copy(connectedUsersCount = instance.connectedUsersCount + 1) WebServer.apps(instanceId) = newInstance @@ -23,7 +16,7 @@ private[web] final class WebSocketsCollection(val port: Int): EventResponse.Wire.encode(Success(List(Action.Render(js)))) println(f"[${newInstance.instance.appInfo.id}][$instanceId] client \"$userId\" connected") - def onClientDisconnect(instanceId: InstanceId, userId: UserId): Unit = + private def onClientDisconnect(instanceId: InstanceId, userId: UserId): Unit = val instance = WebServer.apps(instanceId) val newInstance = instance.copy(connectedUsersCount = instance.connectedUsersCount - 1) WebServer.apps(instanceId) = newInstance @@ -32,9 +25,9 @@ private[web] final class WebSocketsCollection(val port: Int): println(f"[$appId][$instanceId] client \"$userId\" disconnected") /** All the sessions currently in use, mapping instance ids to an actual - * socket object that a client owns + * channel objects that a client owns */ - private val sessions: mutable.Map[InstanceId, Seq[(UserId, WebSocket)]] = + private val sessions: mutable.Map[InstanceId, Seq[(UserId, cask.WsChannelActor)]] = mutable.Map() /** Initialize an empty session list for the given app instance */ @@ -42,67 +35,34 @@ private[web] final class WebSocketsCollection(val port: Int): require(!sessions.contains(instanceId)) sessions(instanceId) = Seq.empty - /** Runs k with parameters parsed from the websocket connection path. [[k]] is - * run while synchronizing on [[sessions]] - * - * The connection should be on "ws://…/[app_instance_id]/[user_id]" for the - * parsing to function properly. - */ - private def withSessionParams[T](socket: WebSocket)(k: (InstanceId, UserId) => T) = - sessions.synchronized: - try - val components = socket.getResourceDescriptor.split("/").takeRight(2) - val decoded = components.map(s => URLDecoder.decode(s, "UTF-8")) - decoded match - case Array(appId, userId) => - if !sessions.contains(appId) then - throw IllegalArgumentException("Error: Invalid app ID") - k(appId, userId) - case _ => throw Exception("Error: Invalid path") - catch - case t => - socket.send(SocketResponseWire.encode(util.Failure(t)).toString) - socket.close() - - /** A single websocket server handling multiple apps and clients. */ - private val server: WebSocketServer = new WebSocketServer(InetSocketAddress("0.0.0.0", port)): - override def onOpen(socket: WebSocket, handshake: ClientHandshake): Unit = - withSessionParams(socket): (appId, clientId) => - sessions(appId) = sessions(appId) :+ (clientId, socket) - onClientConnect(appId, clientId) - - override def onClose(socket: WebSocket, code: Int, reason: String, remote: Boolean): Unit = - withSessionParams(socket): (appId, clientId) => - sessions(appId) = sessions(appId).filter(_._2 != socket) // Unregister the session - onClientDisconnect(appId, clientId) - - override def onMessage(socket: WebSocket, message: String): Unit = - withSessionParams(socket): (appId, clientId) => - onMessageReceive(appId, clientId, ujson.read(message)) - - override def onError(socket: WebSocket, ex: Exception): Unit = - // Only report the error, onClosed is called even when an error occurs - throw new RuntimeException(ex) - - override def onStart(): Unit = - val addr = server.getAddress - println(s"[WebSocket] server started on ${addr.getHostName}:${addr.getPort}.") - - /** Starts the server asynchronously. */ - def run(): Unit = - server.setReuseAddr(true) // Ignore leftover connections from pending processes - Thread(() => server.run(), "Socket Thread").start() - - /** Enumerates clients connected to [[appId]]. */ - def connectedClients(appId: InstanceId): Seq[UserId] = - sessions.get(appId).map(_.map(_._1).distinct).getOrElse(Seq()) + def connect(instanceId: String, userId: String) + (implicit cc: castor.Context, log: cask.Logger): cask.WebsocketResult = + if WebServer.apps.contains(instanceId) then + cask.WsHandler: channel => + sessions(instanceId) = sessions(instanceId) :+ (userId, channel) + onClientConnect(instanceId, userId) + cask.WsActor { + case cask.Ws.Error(e) => + throw e + case cask.Ws.Close(code, reason) => + sessions(instanceId) = sessions(instanceId).filter(_._2 != channel) // Unregister the session + onClientDisconnect(instanceId, userId) + case cask.Ws.Text(data) => + WebServer.handleMessage(instanceId, userId, ujson.read(data)) + } + else + cask.Response(f"Unknown instance id $instanceId", 400) + + /** Enumerates clients connected to [[instanceId]]. */ + def connectedClients(instanceId: InstanceId): Seq[UserId] = + sessions.get(instanceId).map(_.map(_._1).distinct).getOrElse(Seq()) /** Sends a message to a specific client. */ - def send(appId: InstanceId, clientId: UserId)(message: ujson.Value) = + def send(instanceId: InstanceId, userId: UserId)(message: ujson.Value) = val wrapped = SocketResponseWire.encode(util.Success(message)).toString sessions.synchronized: for - (userId, socket) <- sessions.getOrElse(appId, Seq.empty) - if userId == clientId + (userId, channel) <- sessions.getOrElse(instanceId, Seq.empty) + if userId == userId do - socket.send(wrapped) + channel.send(cask.Ws.Text(wrapped)) diff --git a/shared/src/main/scala/cs214/webapp/Common.scala b/shared/src/main/scala/cs214/webapp/Common.scala index 18b5597..b868339 100644 --- a/shared/src/main/scala/cs214/webapp/Common.scala +++ b/shared/src/main/scala/cs214/webapp/Common.scala @@ -1,9 +1,6 @@ package cs214.webapp object Config: - /** Which port the websocket server uses */ - val WS_PORT = 9090 - /** Which port the HTTP server uses */ val HTTP_PORT = 8080 @@ -33,14 +30,16 @@ trait WireFormat[T] extends Encoder[T] with Decoder[T] /** HTTP endpoints */ object Endpoints: sealed abstract case class Api(path: String): - val root = "/api" - override def toString = f"$root/$path" + override def toString = f"${Api.root}/$path" object Api: + val root = "/api" object listApps extends Api("list-apps") object createInstance extends Api("create-instance") object instanceInfo extends Api("instance-info") + val WebSocket = "/ws" + trait RegistrationProvider: def register(): Unit -- GitLab