From f386da46ab5269745471b947ba0b813c155f5afd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Pit-Claudel?= <clement.pit-claudel@epfl.ch> Date: Fri, 27 Dec 2024 17:03:16 +0100 Subject: [PATCH] tests: Check that server survives abrupt disconnection This test currently shows exceptions while running; next commit will fix. --- .../test/scala/cs214/webapp/AppSuite.scala | 80 +++++++++++++++++-- 1 file changed, 72 insertions(+), 8 deletions(-) diff --git a/jvm/src/test/scala/cs214/webapp/AppSuite.scala b/jvm/src/test/scala/cs214/webapp/AppSuite.scala index d36717e..9549259 100644 --- a/jvm/src/test/scala/cs214/webapp/AppSuite.scala +++ b/jvm/src/test/scala/cs214/webapp/AppSuite.scala @@ -2,6 +2,8 @@ package cs214.webapp package server package web +import java.net.http.{WebSocket => JWebSocket} + import scala.util.{Try, Success, Failure} import scala.concurrent.duration.Duration import scala.concurrent.{Await, Future} @@ -11,10 +13,10 @@ import ujson.Value import sttp.ws.WebSocket import sttp.client4.* import sttp.client4.ws.sync.* +import sttp.client4.ws.SyncWebSocket import cs214.webapp.utils.{*, given} import cs214.webapp.server.StateMachine -import sttp.client4.ws.SyncWebSocket object PingPongSuite: type Ping = String @@ -79,29 +81,91 @@ class WebSocketSuite extends PingPongSuite: app.wire.viewFormat.decode(viewJs).getOrElse: fail(f"Cannot decode view $viewJs") - def withWs[T](instanceId: InstanceId)(body: SyncWebSocket => T)(using server: WebServerInfo): Response[T] = + def withWs[T](instanceId: InstanceId, userId: String) + (body: SyncWebSocket => T) + (using server: WebServerInfo): Response[T] = val instInfo = instanceInfo(instanceId) - val wsEndpoint = substituteInstanceInfo(instInfo.wsEndpoint, ProtocolInfo.WebSocket, UID0) + val wsEndpoint = substituteInstanceInfo(instInfo.wsEndpoint, ProtocolInfo.WebSocket, userId) quickRequest .get(uri"$wsEndpoint") .header("Origin", f"$server") .response(asWebSocketAlways(body)) .send(backend) + extension (ws: SyncWebSocket) + def assertHello()(implicit loc: munit.Location) = + val initial = decodeActions(ws.receiveText()) + assertEquals(initial, Seq(Action.Render("hello"))) + test("ws: The ping server sends a welcome message over web sockets"): withServer: server ?=> val inst = createInstance(USER_IDS) - withWs(inst.instanceId): ws => - val initial = decodeActions(ws.receiveText()) - assertEquals(initial, Seq(Action.Render("hello"))) + withWs(inst.instanceId, UID0): ws => + ws.assertHello() test("ws: The ping server echoes events"): withServer: server ?=> val inst = createInstance(USER_IDS) - withWs(inst.instanceId): ws => - ws.receiveText() + withWs(inst.instanceId, UID0): ws => + ws.assertHello() time: for i <- 0 to 1000 do ws.sendText(app.wire.eventFormat.encode(i.toString).toString) val resp = decodeActions(ws.receiveText()) assertEquals(resp, Seq(Action.Render(i.toString))) + + test("ws: Two websockets see each other"): + withServer: server ?=> + val inst = createInstance(USER_IDS) + withWs(inst.instanceId, UID0): ws0 => + withWs(inst.instanceId, UID1): ws1 => + ws0.assertHello() + ws1.assertHello() + + val (ma, mb) = ("a", "b") + ws0.sendText(app.wire.eventFormat.encode("a").toString) + val r0a = decodeActions(ws0.receiveText()) + ws1.sendText(app.wire.eventFormat.encode("b").toString) + val r0b = decodeActions(ws0.receiveText()) + + val (r1a, r1b) = (decodeActions(ws1.receiveText()), decodeActions(ws1.receiveText())) + + assertEquals(r0a, r1a) + assertEquals(r0b, r1b) + + assertEquals(r0a, Seq(Action.Render(ma))) + assertEquals(r0b, Seq(Action.Render(mb))) + + def withJavaWs[T](instanceId: InstanceId, userId: String, listener: JWebSocket.Listener) + (body: JWebSocket => T) + (using server: WebServerInfo): T = + val instInfo = instanceInfo(instanceId) + val wsEndpoint = substituteInstanceInfo(instInfo.wsEndpoint, ProtocolInfo.WebSocket, userId) + + val client = java.net.http.HttpClient.newHttpClient() + val ws = client.newWebSocketBuilder() + .header("Origin", f"$server") + .buildAsync(java.net.URI.create(wsEndpoint), listener) + .join() + body(ws) + + test("ws: Webserver survives abrupt disconnection"): + withServer: server ?=> + val inst = createInstance(USER_IDS) + + val listener = new JWebSocket.Listener {} + + withJavaWs(inst.instanceId, UID1, listener): wsJ => + withWs(inst.instanceId, UID0): ws => + ws.assertHello() + + ws.sendText(app.wire.eventFormat.encode("a").toString) + val r0a = decodeActions(ws.receiveText()) + assertEquals(r0a, Seq(Action.Render("a"))) + + // Unclean disconnect + wsJ.abort() + + ws.sendText(app.wire.eventFormat.encode("b").toString) + val r0b = decodeActions(ws.receiveText()) + assertEquals(r0b, Seq(Action.Render("b"))) -- GitLab