From f386da46ab5269745471b947ba0b813c155f5afd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Pit-Claudel?= <clement.pit-claudel@epfl.ch>
Date: Fri, 27 Dec 2024 17:03:16 +0100
Subject: [PATCH] tests: Check that server survives abrupt disconnection

This test currently shows exceptions while running; next commit will fix.
---
 .../test/scala/cs214/webapp/AppSuite.scala    | 80 +++++++++++++++++--
 1 file changed, 72 insertions(+), 8 deletions(-)

diff --git a/jvm/src/test/scala/cs214/webapp/AppSuite.scala b/jvm/src/test/scala/cs214/webapp/AppSuite.scala
index d36717e..9549259 100644
--- a/jvm/src/test/scala/cs214/webapp/AppSuite.scala
+++ b/jvm/src/test/scala/cs214/webapp/AppSuite.scala
@@ -2,6 +2,8 @@ package cs214.webapp
 package server
 package web
 
+import java.net.http.{WebSocket => JWebSocket}
+
 import scala.util.{Try, Success, Failure}
 import scala.concurrent.duration.Duration
 import scala.concurrent.{Await, Future}
@@ -11,10 +13,10 @@ import ujson.Value
 import sttp.ws.WebSocket
 import sttp.client4.*
 import sttp.client4.ws.sync.*
+import sttp.client4.ws.SyncWebSocket
 
 import cs214.webapp.utils.{*, given}
 import cs214.webapp.server.StateMachine
-import sttp.client4.ws.SyncWebSocket
 
 object PingPongSuite:
   type Ping = String
@@ -79,29 +81,91 @@ class WebSocketSuite extends PingPongSuite:
             app.wire.viewFormat.decode(viewJs).getOrElse:
               fail(f"Cannot decode view $viewJs")
 
-  def withWs[T](instanceId: InstanceId)(body: SyncWebSocket => T)(using server: WebServerInfo): Response[T] =
+  def withWs[T](instanceId: InstanceId, userId: String)
+               (body: SyncWebSocket => T)
+               (using server: WebServerInfo): Response[T] =
     val instInfo = instanceInfo(instanceId)
-    val wsEndpoint = substituteInstanceInfo(instInfo.wsEndpoint, ProtocolInfo.WebSocket, UID0)
+    val wsEndpoint = substituteInstanceInfo(instInfo.wsEndpoint, ProtocolInfo.WebSocket, userId)
     quickRequest
       .get(uri"$wsEndpoint")
       .header("Origin", f"$server")
       .response(asWebSocketAlways(body))
       .send(backend)
 
+  extension (ws: SyncWebSocket)
+    def assertHello()(implicit loc: munit.Location) =
+      val initial = decodeActions(ws.receiveText())
+      assertEquals(initial, Seq(Action.Render("hello")))
+
   test("ws: The ping server sends a welcome message over web sockets"):
     withServer: server ?=>
       val inst = createInstance(USER_IDS)
-      withWs(inst.instanceId): ws =>
-        val initial = decodeActions(ws.receiveText())
-        assertEquals(initial, Seq(Action.Render("hello")))
+      withWs(inst.instanceId, UID0): ws =>
+        ws.assertHello()
 
   test("ws: The ping server echoes events"):
     withServer: server ?=>
       val inst = createInstance(USER_IDS)
-      withWs(inst.instanceId): ws =>
-        ws.receiveText()
+      withWs(inst.instanceId, UID0): ws =>
+        ws.assertHello()
         time:
           for i <- 0 to 1000 do
             ws.sendText(app.wire.eventFormat.encode(i.toString).toString)
             val resp = decodeActions(ws.receiveText())
             assertEquals(resp, Seq(Action.Render(i.toString)))
+
+  test("ws: Two websockets see each other"):
+    withServer: server ?=>
+      val inst = createInstance(USER_IDS)
+      withWs(inst.instanceId, UID0): ws0 =>
+        withWs(inst.instanceId, UID1): ws1 =>
+          ws0.assertHello()
+          ws1.assertHello()
+
+          val (ma, mb) = ("a", "b")
+          ws0.sendText(app.wire.eventFormat.encode("a").toString)
+          val r0a = decodeActions(ws0.receiveText())
+          ws1.sendText(app.wire.eventFormat.encode("b").toString)
+          val r0b = decodeActions(ws0.receiveText())
+
+          val (r1a, r1b) = (decodeActions(ws1.receiveText()), decodeActions(ws1.receiveText()))
+
+          assertEquals(r0a, r1a)
+          assertEquals(r0b, r1b)
+
+          assertEquals(r0a, Seq(Action.Render(ma)))
+          assertEquals(r0b, Seq(Action.Render(mb)))
+
+  def withJavaWs[T](instanceId: InstanceId, userId: String, listener: JWebSocket.Listener)
+                   (body: JWebSocket => T)
+                   (using server: WebServerInfo): T =
+    val instInfo = instanceInfo(instanceId)
+    val wsEndpoint = substituteInstanceInfo(instInfo.wsEndpoint, ProtocolInfo.WebSocket, userId)
+
+    val client = java.net.http.HttpClient.newHttpClient()
+    val ws = client.newWebSocketBuilder()
+      .header("Origin", f"$server")
+      .buildAsync(java.net.URI.create(wsEndpoint), listener)
+      .join()
+    body(ws)
+
+  test("ws: Webserver survives abrupt disconnection"):
+    withServer: server ?=>
+      val inst = createInstance(USER_IDS)
+
+      val listener = new JWebSocket.Listener {}
+
+      withJavaWs(inst.instanceId, UID1, listener): wsJ =>
+        withWs(inst.instanceId, UID0): ws =>
+          ws.assertHello()
+
+          ws.sendText(app.wire.eventFormat.encode("a").toString)
+          val r0a = decodeActions(ws.receiveText())
+          assertEquals(r0a, Seq(Action.Render("a")))
+
+          // Unclean disconnect
+          wsJ.abort()
+
+          ws.sendText(app.wire.eventFormat.encode("b").toString)
+          val r0b = decodeActions(ws.receiveText())
+          assertEquals(r0b, Seq(Action.Render("b")))
-- 
GitLab