From 09af25f3746cbb906d86dd1c43276a1ee6eb3d7b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Pit-Claudel?= <clement.pit-claudel@epfl.ch>
Date: Sun, 22 Dec 2024 00:57:25 +0100
Subject: [PATCH] server: Limit total number of instances

---
 .../scala/cs214/webapp/server/web/WebServer.scala   | 13 +++++++++----
 .../cs214/webapp/server/web/WebServerRoutes.scala   |  7 +++++--
 2 files changed, 14 insertions(+), 6 deletions(-)

diff --git a/jvm/src/main/scala/cs214/webapp/server/web/WebServer.scala b/jvm/src/main/scala/cs214/webapp/server/web/WebServer.scala
index 69e96b5..27dffdd 100644
--- a/jvm/src/main/scala/cs214/webapp/server/web/WebServer.scala
+++ b/jvm/src/main/scala/cs214/webapp/server/web/WebServer.scala
@@ -2,7 +2,7 @@ package cs214.webapp
 package server
 package web
 
-import scala.collection.{concurrent, mutable, immutable}
+import scala.collection.{concurrent, mutable}
 import scala.util.{Failure, Success, Try}
 import cask.endpoints.WsChannelActor
 
@@ -21,8 +21,12 @@ private case class Wordlist(fname: String):
     import scala.util.Random
     words(Random.nextInt(words.size))
 
+class TooManyInstances extends Exception("Too many instances.")
+
 /** Contains the web server state and functionalities */
 object WebServer:
+  val MAX_INSTANCES = 64000
+
   private val instanceIdsFileName = "eff_short_wordlist_1.txt" // Where to find the list of possible instance ids
   private val instanceIdLength = 4 // How many words do we concatenate for the instance id
   private val instanceIdWords = Wordlist(instanceIdsFileName)
@@ -30,11 +34,10 @@ object WebServer:
   private[web] val debug = false
 
   /** Mapping from app ids to their app class */
-  private[web] val appDirectory: mutable.Map[AppId, ServerAppFactory] = mutable.Map()
+  private[web] val appDirectory = mutable.Map[AppId, ServerAppFactory]()
 
   /** Mapping from app instance ids to their running instances */
-  private[web] val instances: concurrent.TrieMap[InstanceId, ServerApp] =
-    concurrent.TrieMap()
+  private[web] val instances = concurrent.TrieMap[InstanceId, ServerApp]()
 
   /** Registers the given app */
   private def register(app: ServerAppFactory): Unit =
@@ -58,6 +61,8 @@ object WebServer:
 
   /** Creates a new instance of `appId`. */
   private[web] def createInstance(appId: AppId, registeredUserIds: Seq[UserId]): InstanceId =
+    if instances.size >= MAX_INSTANCES then
+      throw TooManyInstances()
     withFreshInstanceId: instanceId =>
       instances(instanceId) = appDirectory(appId).init(instanceId, registeredUserIds)
       println(f"[$appId/$instanceId] instance created")
diff --git a/jvm/src/main/scala/cs214/webapp/server/web/WebServerRoutes.scala b/jvm/src/main/scala/cs214/webapp/server/web/WebServerRoutes.scala
index 24ab947..631e620 100644
--- a/jvm/src/main/scala/cs214/webapp/server/web/WebServerRoutes.scala
+++ b/jvm/src/main/scala/cs214/webapp/server/web/WebServerRoutes.scala
@@ -69,8 +69,11 @@ private[server] final case class WebServerRoutes()(using cc: castor.Context, log
         cask.Response(f"Unable to decode data: ${request.text()}", 400)
       case Success(CreateInstanceRequest(appName, userIds)) =>
         if WebServer.appDirectory.contains(appName) then
-          val appId = WebServer.createInstance(appName, userIds)
-          CreateInstanceResponse.Wire.encode(CreateInstanceResponse(appId))
+          try
+            val appId = WebServer.createInstance(appName, userIds)
+            CreateInstanceResponse.Wire.encode(CreateInstanceResponse(appId))
+          catch case tmi: TooManyInstances =>
+            cask.Response(tmi.getMessage(), 429)
         else
           cask.Response(f"Unknown app '$appName'", 400)
 
-- 
GitLab