Skip to content
Snippets Groups Projects
Commit f386da46 authored by Clément Pit-Claudel's avatar Clément Pit-Claudel
Browse files

tests: Check that server survives abrupt disconnection

This test currently shows exceptions while running; next commit will fix.
parent 67b813cd
Branches
Tags
1 merge request!40Properly handle sudden websocket disconnections
...@@ -2,6 +2,8 @@ package cs214.webapp ...@@ -2,6 +2,8 @@ package cs214.webapp
package server package server
package web package web
import java.net.http.{WebSocket => JWebSocket}
import scala.util.{Try, Success, Failure} import scala.util.{Try, Success, Failure}
import scala.concurrent.duration.Duration import scala.concurrent.duration.Duration
import scala.concurrent.{Await, Future} import scala.concurrent.{Await, Future}
...@@ -11,10 +13,10 @@ import ujson.Value ...@@ -11,10 +13,10 @@ import ujson.Value
import sttp.ws.WebSocket import sttp.ws.WebSocket
import sttp.client4.* import sttp.client4.*
import sttp.client4.ws.sync.* import sttp.client4.ws.sync.*
import sttp.client4.ws.SyncWebSocket
import cs214.webapp.utils.{*, given} import cs214.webapp.utils.{*, given}
import cs214.webapp.server.StateMachine import cs214.webapp.server.StateMachine
import sttp.client4.ws.SyncWebSocket
object PingPongSuite: object PingPongSuite:
type Ping = String type Ping = String
...@@ -79,29 +81,91 @@ class WebSocketSuite extends PingPongSuite: ...@@ -79,29 +81,91 @@ class WebSocketSuite extends PingPongSuite:
app.wire.viewFormat.decode(viewJs).getOrElse: app.wire.viewFormat.decode(viewJs).getOrElse:
fail(f"Cannot decode view $viewJs") fail(f"Cannot decode view $viewJs")
def withWs[T](instanceId: InstanceId)(body: SyncWebSocket => T)(using server: WebServerInfo): Response[T] = def withWs[T](instanceId: InstanceId, userId: String)
(body: SyncWebSocket => T)
(using server: WebServerInfo): Response[T] =
val instInfo = instanceInfo(instanceId) val instInfo = instanceInfo(instanceId)
val wsEndpoint = substituteInstanceInfo(instInfo.wsEndpoint, ProtocolInfo.WebSocket, UID0) val wsEndpoint = substituteInstanceInfo(instInfo.wsEndpoint, ProtocolInfo.WebSocket, userId)
quickRequest quickRequest
.get(uri"$wsEndpoint") .get(uri"$wsEndpoint")
.header("Origin", f"$server") .header("Origin", f"$server")
.response(asWebSocketAlways(body)) .response(asWebSocketAlways(body))
.send(backend) .send(backend)
extension (ws: SyncWebSocket)
def assertHello()(implicit loc: munit.Location) =
val initial = decodeActions(ws.receiveText())
assertEquals(initial, Seq(Action.Render("hello")))
test("ws: The ping server sends a welcome message over web sockets"): test("ws: The ping server sends a welcome message over web sockets"):
withServer: server ?=> withServer: server ?=>
val inst = createInstance(USER_IDS) val inst = createInstance(USER_IDS)
withWs(inst.instanceId): ws => withWs(inst.instanceId, UID0): ws =>
val initial = decodeActions(ws.receiveText()) ws.assertHello()
assertEquals(initial, Seq(Action.Render("hello")))
test("ws: The ping server echoes events"): test("ws: The ping server echoes events"):
withServer: server ?=> withServer: server ?=>
val inst = createInstance(USER_IDS) val inst = createInstance(USER_IDS)
withWs(inst.instanceId): ws => withWs(inst.instanceId, UID0): ws =>
ws.receiveText() ws.assertHello()
time: time:
for i <- 0 to 1000 do for i <- 0 to 1000 do
ws.sendText(app.wire.eventFormat.encode(i.toString).toString) ws.sendText(app.wire.eventFormat.encode(i.toString).toString)
val resp = decodeActions(ws.receiveText()) val resp = decodeActions(ws.receiveText())
assertEquals(resp, Seq(Action.Render(i.toString))) assertEquals(resp, Seq(Action.Render(i.toString)))
test("ws: Two websockets see each other"):
withServer: server ?=>
val inst = createInstance(USER_IDS)
withWs(inst.instanceId, UID0): ws0 =>
withWs(inst.instanceId, UID1): ws1 =>
ws0.assertHello()
ws1.assertHello()
val (ma, mb) = ("a", "b")
ws0.sendText(app.wire.eventFormat.encode("a").toString)
val r0a = decodeActions(ws0.receiveText())
ws1.sendText(app.wire.eventFormat.encode("b").toString)
val r0b = decodeActions(ws0.receiveText())
val (r1a, r1b) = (decodeActions(ws1.receiveText()), decodeActions(ws1.receiveText()))
assertEquals(r0a, r1a)
assertEquals(r0b, r1b)
assertEquals(r0a, Seq(Action.Render(ma)))
assertEquals(r0b, Seq(Action.Render(mb)))
def withJavaWs[T](instanceId: InstanceId, userId: String, listener: JWebSocket.Listener)
(body: JWebSocket => T)
(using server: WebServerInfo): T =
val instInfo = instanceInfo(instanceId)
val wsEndpoint = substituteInstanceInfo(instInfo.wsEndpoint, ProtocolInfo.WebSocket, userId)
val client = java.net.http.HttpClient.newHttpClient()
val ws = client.newWebSocketBuilder()
.header("Origin", f"$server")
.buildAsync(java.net.URI.create(wsEndpoint), listener)
.join()
body(ws)
test("ws: Webserver survives abrupt disconnection"):
withServer: server ?=>
val inst = createInstance(USER_IDS)
val listener = new JWebSocket.Listener {}
withJavaWs(inst.instanceId, UID1, listener): wsJ =>
withWs(inst.instanceId, UID0): ws =>
ws.assertHello()
ws.sendText(app.wire.eventFormat.encode("a").toString)
val r0a = decodeActions(ws.receiveText())
assertEquals(r0a, Seq(Action.Render("a")))
// Unclean disconnect
wsJ.abort()
ws.sendText(app.wire.eventFormat.encode("b").toString)
val r0b = decodeActions(ws.receiveText())
assertEquals(r0b, Seq(Action.Render("b")))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment