diff --git a/jvm/src/test/scala/cs214/webapp/AppSuite.scala b/jvm/src/test/scala/cs214/webapp/AppSuite.scala index 7e16f45eb23bc15abbb13f453da5cc450530ee0c..635e934a61f414378425b0ac3825036951afbf96 100644 --- a/jvm/src/test/scala/cs214/webapp/AppSuite.scala +++ b/jvm/src/test/scala/cs214/webapp/AppSuite.scala @@ -60,3 +60,37 @@ class PagesSuite extends PingPongSuite: for _ <- 1 to 10 do val inst = createInstance(USER_IDS) instanceInfo(inst.instanceId) + +class WebSocketSuite extends PingPongSuite: + import PingPongSuite.* + import ClientSuite.* + import castor.Context.Simple.global + import cask.Logger.Console.globalLogger + + def decodeActions(wsMessage: String) = + val js = ujson.read(wsMessage) + val evt = TryWire(EventResponse.Wire).decode(js).flatten + evt match + case Failure(msg) => fail(msg.getMessage) + case Success(Failure(msg)) => fail(msg.getMessage) + case Success(Success(actions)) => + actions.map: action => + action.map: viewJs => + app.wire.viewFormat.decode(viewJs).getOrElse: + fail(f"Cannot decode view $viewJs") + + def withWs[T](instanceId: InstanceId)(body: SyncWebSocket => T)(using server: WebServerInfo): Response[T] = + val instInfo = instanceInfo(instanceId) + val wsEndpoint = substituteInstanceInfo(instInfo.wsEndpoint, ProtocolInfo.WebSocket, UID0) + quickRequest + .get(uri"$wsEndpoint") + .header("Origin", f"$server") + .response(asWebSocketAlways(body)) + .send(backend) + + test("ws: The server sends a welcome message over web sockets"): + withServer: server ?=> + val inst = createInstance(USER_IDS) + withWs(inst.instanceId): ws => + val initial = decodeActions(ws.receiveText()) + assertEquals(initial, Seq(Action.Render("")))