From c8be608f4ac9d609193d5b3b9fe246b186973be5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Pit-Claudel?= <clement.pit-claudel@epfl.ch>
Date: Thu, 5 Dec 2024 15:02:26 +0100
Subject: [PATCH] server: Synchronize accesses to the `sessions` map

This is mostly a stopgap measure until we merge !25.
---
 .../cs214/webapp/server/web/WebSocketsCollection.scala | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/jvm/src/main/scala/cs214/webapp/server/web/WebSocketsCollection.scala b/jvm/src/main/scala/cs214/webapp/server/web/WebSocketsCollection.scala
index 1c2b614..23aab01 100644
--- a/jvm/src/main/scala/cs214/webapp/server/web/WebSocketsCollection.scala
+++ b/jvm/src/main/scala/cs214/webapp/server/web/WebSocketsCollection.scala
@@ -7,7 +7,7 @@ import scala.util.{Success, Failure, Try}
 /** A collection of websockets organized by app Id and client */
 private[web] final class WebSocketsCollection():
   // When the user joins the app, the projection of the current state is sent to them
-  private def onClientConnect(instanceId: InstanceId, userId: UserId): Unit =
+  private def onClientConnect(instanceId: InstanceId, userId: UserId): Unit = sessions.synchronized:
     val instance = WebServer.apps(instanceId)
     val newInstance = instance.copy(connectedUsersCount = instance.connectedUsersCount + 1)
     WebServer.apps(instanceId) = newInstance
@@ -16,7 +16,7 @@ private[web] final class WebSocketsCollection():
       EventResponse.Wire.encode(Success(List(Action.Render(js))))
     println(f"[${newInstance.instance.appInfo.id}][$instanceId] client \"$userId\" connected")
 
-  private def onClientDisconnect(instanceId: InstanceId, userId: UserId): Unit =
+  private def onClientDisconnect(instanceId: InstanceId, userId: UserId): Unit = sessions.synchronized:
     val instance = WebServer.apps(instanceId)
     val newInstance = instance.copy(connectedUsersCount = instance.connectedUsersCount - 1)
     WebServer.apps(instanceId) = newInstance
@@ -31,12 +31,12 @@ private[web] final class WebSocketsCollection():
     mutable.Map()
 
   /** Initialize an empty session list for the given app instance */
-  def initializeApp(instanceId: InstanceId) =
+  def initializeApp(instanceId: InstanceId) = sessions.synchronized:
     require(!sessions.contains(instanceId))
     sessions(instanceId) = Seq.empty
 
   def connect(instanceId: String, userId: String)
-             (implicit cc: castor.Context, log: cask.Logger): cask.WebsocketResult =
+             (implicit cc: castor.Context, log: cask.Logger): cask.WebsocketResult = sessions.synchronized:
     if WebServer.apps.contains(instanceId) then
       cask.WsHandler: channel =>
         sessions(instanceId) = sessions(instanceId) :+ (userId, channel)
@@ -54,7 +54,7 @@ private[web] final class WebSocketsCollection():
       cask.Response(f"Unknown instance id $instanceId", 400)
 
   /** Enumerates clients connected to [[instanceId]]. */
-  def connectedClients(instanceId: InstanceId): Seq[UserId] =
+  def connectedClients(instanceId: InstanceId): Seq[UserId] = sessions.synchronized:
     sessions.get(instanceId).map(_.map(_._1).distinct).getOrElse(Seq())
 
   /** Sends a message to a specific client. */
-- 
GitLab