Skip to content
Snippets Groups Projects
Commit 62fb2505 authored by Clément Pit-Claudel's avatar Clément Pit-Claudel
Browse files

Run the HTTP and WebSocket servers on the same port

Use Cask's built-in websocket implementation instead of a separate
`java_websocket` server.

* jvm/src/main/scala/cs214/webapp/server/web/WebServer.scala:
  Remove the `WS_PORT` parameter.
* jvm/src/main/scala/cs214/webapp/server/web/WebServerRoutes.scala:
  (getAppInfo): Use `HTTP_PORT` instead of `WS_PORT` in `wsEndpoint`
  (websocket): New endpoint.
* jvm/src/main/scala/cs214/webapp/server/web/WebSocketsCollection.scala:
  Remove `java_websocket` imports.
  Change from `WebSocket` to `cask.WsChannelActor`.
  Rename `appId` to `instanceId` where appropriate.
  (WebSocketsCollection): Remove the `port` parameter.
  (connect): New function, replacing the previous `WebSocketServer` instance.
* shared/src/main/scala/cs214/webapp/Common.scala:
  (WS_PORT): Remove.
  (WebSocket): New endpoint.
parent 8970451f
No related branches found
No related tags found
1 merge request!19Add support for https and wss when using domain names
......@@ -33,7 +33,7 @@ object WebServer:
/** Mapping from app instance ids to the attached running clocks */
private[web] val clocks: concurrent.Map[InstanceId, RunningClock] = concurrent.TrieMap()
private[web] lazy val webSocketServer: WebSocketsCollection = new WebSocketsCollection(Config.WS_PORT)
private[web] lazy val webSocketServer: WebSocketsCollection = WebSocketsCollection()
/** Registers the given state-machine based app. */
def register[E, V, S](sm: StateMachine[E, V, S]): Unit =
......
......@@ -53,8 +53,8 @@ private[server] final case class WebServerRoutes()(using cc: castor.Context, log
val response: cask.Response[JsonData] =
if WebServer.apps.contains(instanceId) then
val app = WebServer.apps(instanceId).instance
val shareUrl = f"http://$hostAddress:${Config.HTTP_PORT}/app/${app.appInfo.id}/$instanceId/"
val wsEndpoint = f"ws://{{hostName}}:${Config.WS_PORT}/$instanceId/{{userId}}"
val shareUrl = f"http://$hostAddress:${Config.HTTP_PORT}${Endpoints.Api.root}/${app.appInfo.id}/$instanceId/"
val wsEndpoint = f"ws://{{hostName}}:${Config.HTTP_PORT}${Endpoints.WebSocket}/$instanceId/{{userId}}"
val response = InstanceInfoResponse(instanceId, app.registeredUsers, wsEndpoint, shareUrl)
InstanceInfoResponse.Wire.encode(response)
else
......@@ -72,5 +72,8 @@ private[server] final case class WebServerRoutes()(using cc: castor.Context, log
CreateInstanceResponse.Wire.encode(CreateInstanceResponse(appId))
response
WebServer.webSocketServer.run()
@cask.websocket(f"/${Endpoints.WebSocket}/:instanceId/:userId")
def websocket(instanceId: String, userId: String): cask.WebsocketResult =
WebServer.webSocketServer.connect(instanceId, userId)
initialize()
package cs214.webapp
package server.web
import org.java_websocket.WebSocket
import org.java_websocket.handshake.ClientHandshake
import org.java_websocket.server.WebSocketServer
import java.net.{InetSocketAddress, URLDecoder}
import scala.collection.mutable
import scala.util.{Success, Failure, Try}
/** A collection of websockets organized by app Id and client */
private[web] final class WebSocketsCollection(val port: Int):
def onMessageReceive(appId: InstanceId, uid: UserId, msg: ujson.Value): Unit =
WebServer.handleMessage(appId, uid, msg)
private[web] final class WebSocketsCollection():
// When the user joins the app, the projection of the current state is sent to them
def onClientConnect(instanceId: InstanceId, userId: UserId): Unit =
private def onClientConnect(instanceId: InstanceId, userId: UserId): Unit =
val instance = WebServer.apps(instanceId)
val newInstance = instance.copy(connectedUsersCount = instance.connectedUsersCount + 1)
WebServer.apps(instanceId) = newInstance
......@@ -23,7 +16,7 @@ private[web] final class WebSocketsCollection(val port: Int):
EventResponse.Wire.encode(Success(List(Action.Render(js))))
println(f"[${newInstance.instance.appInfo.id}][$instanceId] client \"$userId\" connected")
def onClientDisconnect(instanceId: InstanceId, userId: UserId): Unit =
private def onClientDisconnect(instanceId: InstanceId, userId: UserId): Unit =
val instance = WebServer.apps(instanceId)
val newInstance = instance.copy(connectedUsersCount = instance.connectedUsersCount - 1)
WebServer.apps(instanceId) = newInstance
......@@ -32,9 +25,9 @@ private[web] final class WebSocketsCollection(val port: Int):
println(f"[$appId][$instanceId] client \"$userId\" disconnected")
/** All the sessions currently in use, mapping instance ids to an actual
* socket object that a client owns
* channel objects that a client owns
*/
private val sessions: mutable.Map[InstanceId, Seq[(UserId, WebSocket)]] =
private val sessions: mutable.Map[InstanceId, Seq[(UserId, cask.WsChannelActor)]] =
mutable.Map()
/** Initialize an empty session list for the given app instance */
......@@ -42,67 +35,34 @@ private[web] final class WebSocketsCollection(val port: Int):
require(!sessions.contains(instanceId))
sessions(instanceId) = Seq.empty
/** Runs k with parameters parsed from the websocket connection path. [[k]] is
* run while synchronizing on [[sessions]]
*
* The connection should be on "ws://…/[app_instance_id]/[user_id]" for the
* parsing to function properly.
*/
private def withSessionParams[T](socket: WebSocket)(k: (InstanceId, UserId) => T) =
sessions.synchronized:
try
val components = socket.getResourceDescriptor.split("/").takeRight(2)
val decoded = components.map(s => URLDecoder.decode(s, "UTF-8"))
decoded match
case Array(appId, userId) =>
if !sessions.contains(appId) then
throw IllegalArgumentException("Error: Invalid app ID")
k(appId, userId)
case _ => throw Exception("Error: Invalid path")
catch
case t =>
socket.send(SocketResponseWire.encode(util.Failure(t)).toString)
socket.close()
/** A single websocket server handling multiple apps and clients. */
private val server: WebSocketServer = new WebSocketServer(InetSocketAddress("0.0.0.0", port)):
override def onOpen(socket: WebSocket, handshake: ClientHandshake): Unit =
withSessionParams(socket): (appId, clientId) =>
sessions(appId) = sessions(appId) :+ (clientId, socket)
onClientConnect(appId, clientId)
override def onClose(socket: WebSocket, code: Int, reason: String, remote: Boolean): Unit =
withSessionParams(socket): (appId, clientId) =>
sessions(appId) = sessions(appId).filter(_._2 != socket) // Unregister the session
onClientDisconnect(appId, clientId)
override def onMessage(socket: WebSocket, message: String): Unit =
withSessionParams(socket): (appId, clientId) =>
onMessageReceive(appId, clientId, ujson.read(message))
override def onError(socket: WebSocket, ex: Exception): Unit =
// Only report the error, onClosed is called even when an error occurs
throw new RuntimeException(ex)
override def onStart(): Unit =
val addr = server.getAddress
println(s"[WebSocket] server started on ${addr.getHostName}:${addr.getPort}.")
/** Starts the server asynchronously. */
def run(): Unit =
server.setReuseAddr(true) // Ignore leftover connections from pending processes
Thread(() => server.run(), "Socket Thread").start()
/** Enumerates clients connected to [[appId]]. */
def connectedClients(appId: InstanceId): Seq[UserId] =
sessions.get(appId).map(_.map(_._1).distinct).getOrElse(Seq())
def connect(instanceId: String, userId: String)
(implicit cc: castor.Context, log: cask.Logger): cask.WebsocketResult =
if WebServer.apps.contains(instanceId) then
cask.WsHandler: channel =>
sessions(instanceId) = sessions(instanceId) :+ (userId, channel)
onClientConnect(instanceId, userId)
cask.WsActor {
case cask.Ws.Error(e) =>
throw e
case cask.Ws.Close(code, reason) =>
sessions(instanceId) = sessions(instanceId).filter(_._2 != channel) // Unregister the session
onClientDisconnect(instanceId, userId)
case cask.Ws.Text(data) =>
WebServer.handleMessage(instanceId, userId, ujson.read(data))
}
else
cask.Response(f"Unknown instance id $instanceId", 400)
/** Enumerates clients connected to [[instanceId]]. */
def connectedClients(instanceId: InstanceId): Seq[UserId] =
sessions.get(instanceId).map(_.map(_._1).distinct).getOrElse(Seq())
/** Sends a message to a specific client. */
def send(appId: InstanceId, clientId: UserId)(message: ujson.Value) =
def send(instanceId: InstanceId, userId: UserId)(message: ujson.Value) =
val wrapped = SocketResponseWire.encode(util.Success(message)).toString
sessions.synchronized:
for
(userId, socket) <- sessions.getOrElse(appId, Seq.empty)
if userId == clientId
(userId, channel) <- sessions.getOrElse(instanceId, Seq.empty)
if userId == userId
do
socket.send(wrapped)
channel.send(cask.Ws.Text(wrapped))
package cs214.webapp
object Config:
/** Which port the websocket server uses */
val WS_PORT = 9090
/** Which port the HTTP server uses */
val HTTP_PORT = 8080
......@@ -33,14 +30,16 @@ trait WireFormat[T] extends Encoder[T] with Decoder[T]
/** HTTP endpoints */
object Endpoints:
sealed abstract case class Api(path: String):
val root = "/api"
override def toString = f"$root/$path"
override def toString = f"${Api.root}/$path"
object Api:
val root = "/api"
object listApps extends Api("list-apps")
object createInstance extends Api("create-instance")
object instanceInfo extends Api("instance-info")
val WebSocket = "/ws"
trait RegistrationProvider:
def register(): Unit
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment