From 06ad831e3e50e3fadd43efbb9210858f47d2063c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Pit-Claudel?= <clement.pit-claudel@epfl.ch>
Date: Fri, 20 Dec 2024 00:37:28 +0100
Subject: [PATCH] tests: Test websocket initialization

---
 .../test/scala/cs214/webapp/AppSuite.scala    | 34 +++++++++++++++++++
 1 file changed, 34 insertions(+)

diff --git a/jvm/src/test/scala/cs214/webapp/AppSuite.scala b/jvm/src/test/scala/cs214/webapp/AppSuite.scala
index 7e16f45..635e934 100644
--- a/jvm/src/test/scala/cs214/webapp/AppSuite.scala
+++ b/jvm/src/test/scala/cs214/webapp/AppSuite.scala
@@ -60,3 +60,37 @@ class PagesSuite extends PingPongSuite:
         for _ <- 1 to 10 do
           val inst = createInstance(USER_IDS)
           instanceInfo(inst.instanceId)
+
+class WebSocketSuite extends PingPongSuite:
+  import PingPongSuite.*
+  import ClientSuite.*
+  import castor.Context.Simple.global
+  import cask.Logger.Console.globalLogger
+
+  def decodeActions(wsMessage: String) =
+    val js = ujson.read(wsMessage)
+    val evt = TryWire(EventResponse.Wire).decode(js).flatten
+    evt match
+      case Failure(msg)              => fail(msg.getMessage)
+      case Success(Failure(msg))     => fail(msg.getMessage)
+      case Success(Success(actions)) =>
+        actions.map: action =>
+          action.map: viewJs =>
+            app.wire.viewFormat.decode(viewJs).getOrElse:
+              fail(f"Cannot decode view $viewJs")
+
+  def withWs[T](instanceId: InstanceId)(body: SyncWebSocket => T)(using server: WebServerInfo): Response[T] =
+    val instInfo = instanceInfo(instanceId)
+    val wsEndpoint = substituteInstanceInfo(instInfo.wsEndpoint, ProtocolInfo.WebSocket, UID0)
+    quickRequest
+      .get(uri"$wsEndpoint")
+      .header("Origin", f"$server")
+      .response(asWebSocketAlways(body))
+      .send(backend)
+
+  test("ws: The server sends a welcome message over web sockets"):
+    withServer: server ?=>
+      val inst = createInstance(USER_IDS)
+      withWs(inst.instanceId): ws =>
+          val initial = decodeActions(ws.receiveText())
+          assertEquals(initial, Seq(Action.Render("")))
-- 
GitLab