diff --git a/jvm/src/test/scala/cs214/webapp/AppSuite.scala b/jvm/src/test/scala/cs214/webapp/AppSuite.scala
index 7e16f45eb23bc15abbb13f453da5cc450530ee0c..635e934a61f414378425b0ac3825036951afbf96 100644
--- a/jvm/src/test/scala/cs214/webapp/AppSuite.scala
+++ b/jvm/src/test/scala/cs214/webapp/AppSuite.scala
@@ -60,3 +60,37 @@ class PagesSuite extends PingPongSuite:
         for _ <- 1 to 10 do
           val inst = createInstance(USER_IDS)
           instanceInfo(inst.instanceId)
+
+class WebSocketSuite extends PingPongSuite:
+  import PingPongSuite.*
+  import ClientSuite.*
+  import castor.Context.Simple.global
+  import cask.Logger.Console.globalLogger
+
+  def decodeActions(wsMessage: String) =
+    val js = ujson.read(wsMessage)
+    val evt = TryWire(EventResponse.Wire).decode(js).flatten
+    evt match
+      case Failure(msg)              => fail(msg.getMessage)
+      case Success(Failure(msg))     => fail(msg.getMessage)
+      case Success(Success(actions)) =>
+        actions.map: action =>
+          action.map: viewJs =>
+            app.wire.viewFormat.decode(viewJs).getOrElse:
+              fail(f"Cannot decode view $viewJs")
+
+  def withWs[T](instanceId: InstanceId)(body: SyncWebSocket => T)(using server: WebServerInfo): Response[T] =
+    val instInfo = instanceInfo(instanceId)
+    val wsEndpoint = substituteInstanceInfo(instInfo.wsEndpoint, ProtocolInfo.WebSocket, UID0)
+    quickRequest
+      .get(uri"$wsEndpoint")
+      .header("Origin", f"$server")
+      .response(asWebSocketAlways(body))
+      .send(backend)
+
+  test("ws: The server sends a welcome message over web sockets"):
+    withServer: server ?=>
+      val inst = createInstance(USER_IDS)
+      withWs(inst.instanceId): ws =>
+          val initial = decodeActions(ws.receiveText())
+          assertEquals(initial, Seq(Action.Render("")))