From 060e8e3b0f5d2145fcdf9b7dcb935be72a98ff23 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Pit-Claudel?= <clement.pit-claudel@epfl.ch>
Date: Sun, 22 Dec 2024 21:59:20 +0100
Subject: [PATCH] client: Automatically reconnect if websocket connection drops

---
 .../webapp/client/StateMachineClientApp.scala | 54 ++++++++++++++++---
 .../cs214/webapp/server/web/ServerApp.scala   |  2 +-
 2 files changed, 47 insertions(+), 9 deletions(-)

diff --git a/js/src/main/scala/cs214/webapp/client/StateMachineClientApp.scala b/js/src/main/scala/cs214/webapp/client/StateMachineClientApp.scala
index 5c59937..3155739 100644
--- a/js/src/main/scala/cs214/webapp/client/StateMachineClientApp.scala
+++ b/js/src/main/scala/cs214/webapp/client/StateMachineClientApp.scala
@@ -9,23 +9,61 @@ import org.scalajs.dom
 
 type Target = dom.Element
 
+private class WebSocket(endpoint: String):
+  var socket: Option[dom.WebSocket] = None
+
+  val MIN_DELAY_MS = 10d
+  val MAX_DELAY_MS = 5000d
+  var reconnect_delay_ms = MIN_DELAY_MS
+
+  private var listeners: List[String => Unit] = Nil
+  def addListener(listener: String => Unit) =
+    listeners = listener :: listeners
+
+  def send(str: String) = socket match
+    case Some(s) => s.send(str)
+    case None => dom.window.alert("Disconnected!")
+
+  private def onmessage(event: dom.MessageEvent) =
+    listeners.foreach(_(event.data.toString))
+
+  private def reconnect(): Unit =
+    println("[ws] Attempting to connect")
+    val _socket = dom.WebSocket(endpoint)
+
+    def onopen(event: dom.Event) =
+      reconnect_delay_ms = MIN_DELAY_MS
+      println("[ws] WebSocket connection opened")
+    def onclose(event: dom.CloseEvent) =
+      if socket == Some(_socket) then socket = None
+      println(s"[ws] WebSocket connection closed (${event.code}): ${event.reason}")
+      if event.code != 1000 then // Normal closure
+        dom.window.setTimeout(() => reconnect(), reconnect_delay_ms)
+        reconnect_delay_ms = math.min(1.5 * reconnect_delay_ms, MAX_DELAY_MS)
+    def onerror(event: dom.Event) =
+      println(s"[ws] WebSocket error")
+      _socket.close(3000) // Lowest custom error code
+
+    _socket.onopen = evt => onopen(evt)
+    _socket.onclose = evt => onclose(evt)
+    _socket.onerror = evt => onerror(evt)
+    _socket.onmessage = msg => onmessage(msg)
+    socket = Some(_socket)
+
+  reconnect()
+
 abstract class WSClientApp extends ClientApp:
   WebClient.register(this)
 
   protected def init(userId: UserId, sendMessage: ujson.Value => Unit, target: Target): ClientAppInstance
 
   def init(instanceId: InstanceId, userId: UserId, endpoint: String, target: Target): ClientAppInstance =
-    val socket = new dom.WebSocket(endpoint)
-    socket.onopen = (event: dom.Event) => println("WebSocket connection opened")
-    socket.onclose = (event: dom.CloseEvent) => println(s"WebSocket connection closed: ${event.reason}")
-    socket.onerror = (event: dom.Event) => println(s"WebSocket error: ${event.`type`}")
-
+    val socket = WebSocket(endpoint)
     val sendMessage = (js: ujson.Value) => socket.send(js.toString)
     val client = init(userId, sendMessage, target)
-    socket.onmessage = msg =>
-      val js = ujson.read(msg.data.toString)
+    socket.addListener: msg =>
+      val js = ujson.read(msg)
       client.onMessage(SocketResponseWire.decode(js).flatten)
-
     client
 
 /** Instance of a client-side state machine application.
diff --git a/jvm/src/main/scala/cs214/webapp/server/web/ServerApp.scala b/jvm/src/main/scala/cs214/webapp/server/web/ServerApp.scala
index c468c3d..78d9022 100644
--- a/jvm/src/main/scala/cs214/webapp/server/web/ServerApp.scala
+++ b/jvm/src/main/scala/cs214/webapp/server/web/ServerApp.scala
@@ -100,7 +100,7 @@ private[web] abstract class ServerApp:
     for channels <- channels.values
         channel <- channels
     do
-      channel.send(cask.Ws.Close())
+      channel.send(cask.Ws.Close(cask.Ws.Close.NormalClosure, "Shutdown"))
 
   /** Sends a message to a specific client. */
   private def send(userId: UserId)(message: ujson.Value): Unit = instanceLock.synchronized:
-- 
GitLab