diff --git a/jvm/src/main/scala/cs214/webapp/server/web/ServerApp.scala b/jvm/src/main/scala/cs214/webapp/server/web/ServerApp.scala
index d57b4fe9d61b166a438aa79e2bbf208fac87a572..f9c65341ba85252f22e2178bce420f1cddd657e5 100644
--- a/jvm/src/main/scala/cs214/webapp/server/web/ServerApp.scala
+++ b/jvm/src/main/scala/cs214/webapp/server/web/ServerApp.scala
@@ -10,6 +10,8 @@ import scala.collection.{concurrent, mutable}
 import scala.util.{Success, Failure, Try}
 import scala.concurrent.duration.Duration
 
+import io.undertow.websockets.core.WebSocketChannel
+
 case class BroadcastException(ex: Throwable) extends Throwable
 
 /** Server-side apps definition and abstractions */
@@ -64,7 +66,7 @@ private[web] abstract class ServerApp:
   def lastActivity = _lastActivity.get()
 
   private val channels =
-    concurrent.TrieMap[UserId, mutable.Set[cask.WsChannelActor]]().withDefault(_ => mutable.Set())
+    concurrent.TrieMap[UserId, mutable.Set[WebSocketChannel]]().withDefault(_ => mutable.Set())
 
   def instanceInfo = // No lock needed: read-only or thread-safe members
     AdminStatusInstanceInfo(
@@ -76,13 +78,13 @@ private[web] abstract class ServerApp:
   def connect(userId: UserId)
              (implicit cc: castor.Context, log: cask.Logger): cask.WebsocketResult = instanceLock.synchronized:
     recordActivity()
-    cask.WsHandler: channel =>
-      channels.getOrElseUpdate(userId, mutable.Set()).add(channel)
+    WebSocketConnectionAdapter: channel =>
       println(f"[${appInfo.id}/$instanceId/$userId] client connected")
+      channels.getOrElseUpdate(userId, mutable.Set()).add(channel)
       send(userId):
         val view = respondToNewClient(userId)
         EventResponse.Wire.encode(Success(List(Action.Render(view))))
-      cask.WsActor {
+      WebSocketEventListener { evt => evt match
         case cask.Ws.Error(e) =>
           println(f"[${appInfo.id}/$instanceId/$userId] error: ${e.getMessage()}")
           disconnect(userId, channel)
@@ -92,14 +94,20 @@ private[web] abstract class ServerApp:
         case cask.Ws.ChannelClosed() =>
           println(f"[${appInfo.id}/$instanceId/$userId] channel closed")
           disconnect(userId, channel)
+        case cask.Ws.Binary(data) =>
+          println(f"[${appInfo.id}/$instanceId/$userId] unsupported: binary data")
+          disconnect(userId, channel)
         case cask.Ws.Text(data) =>
           handleMessage(userId, ujson.read(data))
+        case cask.Ws.Ping(_) | cask.Ws.Pong(_) =>
+          ()
       }
 
-  def disconnect(userId: UserId, channel: cask.endpoints.WsChannelActor): Unit = instanceLock.synchronized:
-    channels(userId).remove(channel)
-    if channels(userId).isEmpty then channels.remove(userId)
-    println(f"[${appInfo.id}/$instanceId/$userId] client disconnected")
+  def disconnect(userId: UserId, channel: WebSocketChannel): Unit = instanceLock.synchronized:
+    if channels(userId).remove(channel) then
+      println(f"[${appInfo.id}/$instanceId/$userId] client disconnected")
+    if channels(userId).isEmpty then
+      channels.remove(userId)
 
   /** Enumerates clients connected to this instance. */
   def connectedClients: Seq[UserId] = instanceLock.synchronized:
@@ -109,16 +117,25 @@ private[web] abstract class ServerApp:
     channels.nonEmpty
 
   def shutdown() = instanceLock.synchronized:
-    for channels <- channels.values
-        channel <- channels
+    for (userId, userChannels) <- channels
+        channel <- userChannels
     do
-      channel.send(cask.Ws.Close(cask.Ws.Close.NormalClosure, "Shutdown"))
+      send(userId, channel)(cask.Ws.Close(cask.Ws.Close.NormalClosure, "Shutdown"))
 
   /** Sends a message to a specific client. */
   private def send(userId: UserId)(message: ujson.Value): Unit = instanceLock.synchronized:
     val wrapped = SocketResponseWire.encode(util.Success(message)).toString
     for channel <- channels(userId) do
-      channel.send(cask.Ws.Text(wrapped))
+      send(userId, channel)(cask.Ws.Text(wrapped))
+
+  private def send(userId: UserId, channel: WebSocketChannel)
+                  (event: WebSocketEvent): Unit = instanceLock.synchronized:
+    channel.send(event) {
+      case util.Success(()) =>
+      case util.Failure(ex) =>
+        println(f"[${appInfo.id}/$instanceId/$userId] send error: ${ex.getMessage()}")
+        disconnect(userId, channel)
+    }
 
   def handleMessage(userId: UserId, msg: ujson.Value): Unit = instanceLock.synchronized:
     recordActivity()
@@ -205,7 +222,7 @@ private[web] final class ClockDrivenStateMachineServerApp[E, S, V](
     if !hasClients then clock.start()
     super.connect(userId)
 
-  override def disconnect(userId: UserId, channel: cask.endpoints.WsChannelActor)
+  override def disconnect(userId: UserId, channel: WebSocketChannel)
     : Unit = instanceLock.synchronized:
     super.disconnect(userId, channel)
     if !hasClients then clock.shutdown()
diff --git a/jvm/src/main/scala/cs214/webapp/server/web/WebSocket.scala b/jvm/src/main/scala/cs214/webapp/server/web/WebSocket.scala
new file mode 100644
index 0000000000000000000000000000000000000000..1cb1bcad5dcd444a482d3f8470042f5664e92706
--- /dev/null
+++ b/jvm/src/main/scala/cs214/webapp/server/web/WebSocket.scala
@@ -0,0 +1,71 @@
+package cs214.webapp.server.web
+
+import scala.util.{Try, Success, Failure}
+
+import io.undertow.websockets.*
+import io.undertow.websockets.core.*
+import io.undertow.websockets.spi.WebSocketHttpExchange
+
+// Cask offers a higher-level websocket API, but it doesn't propagate send
+// errors back to individual channels (instead, the default error handler of the
+// current actor context catches them).  The implementation below uses the async
+// API with a callback to make sure no error gets lost.
+
+type WebSocketEvent = cask.Ws.Event
+
+case class WebSocketEventListener(f: WebSocketEvent => Unit):
+  def receive(we: WebSocketEvent): Unit = f(we)
+
+case class WebSocketAdapter(listener: WebSocketEventListener) extends AbstractReceiveListener:
+  extension (message: BufferedBinaryMessage)
+    def getDataArray: Array[Byte] = WebSockets.mergeBuffers(message.getData.getResource*).array()
+
+  // Missing from Cask: calls to default handlers (including deallocation of
+  // buffers and responding to pings) and `onError` callback.
+
+  override def onFullTextMessage(channel: WebSocketChannel, message: BufferedTextMessage) =
+    listener.receive(cask.Ws.Text(message.getData))
+    super.onFullTextMessage(channel, message)
+  override def onFullBinaryMessage(channel: WebSocketChannel, message: BufferedBinaryMessage) =
+    listener.receive(cask.Ws.Binary(message.getDataArray))
+    super.onFullBinaryMessage(channel, message)
+  override def onFullPingMessage(channel: WebSocketChannel, message: BufferedBinaryMessage) =
+    listener.receive(cask.Ws.Ping(message.getDataArray))
+    super.onFullPingMessage(channel, message)
+  override def onFullPongMessage(channel: WebSocketChannel, message: BufferedBinaryMessage) =
+    listener.receive(cask.Ws.Pong(message.getDataArray))
+    super.onFullPongMessage(channel, message)
+  override def onCloseMessage(cm: CloseMessage, channel: WebSocketChannel) =
+    listener.receive(cask.Ws.Close(cm.getCode, cm.getReason))
+    super.onCloseMessage(cm, channel)
+  override def onError(channel: WebSocketChannel, error: Throwable) =
+    listener.receive(cask.Ws.Error(error))
+    super.onError(channel, error)
+
+case class WebSocketConnectionAdapter(f: WebSocketChannel => WebSocketEventListener) extends WebSocketConnectionCallback:
+  def onConnect(exchange: WebSocketHttpExchange, channel: WebSocketChannel): Unit =
+    channel.suspendReceives()
+    val listener = f(channel)
+    channel.addCloseTask(_ => listener.receive(cask.Ws.ChannelClosed()))
+    channel.getReceiveSetter.set(WebSocketAdapter(listener))
+    channel.resumeReceives()
+
+// We don't use a Future to avoid having to pass an ExecutionContext everywhere
+case class TryCallback(cb: Try[Unit] => Unit) extends WebSocketCallback[Void]:
+  override def complete(channel: WebSocketChannel, context: Void): Unit =
+    cb(Success(()))
+  override def onError(channel: WebSocketChannel, context: Void, throwable: Throwable): Unit =
+    cb(Failure(throwable))
+    org.xnio.IoUtils.safeClose(channel) // Default behavior when no callback is given
+
+extension (channel: WebSocketChannel)
+  def send(event: WebSocketEvent)(callback: Try[Unit] => Unit) =
+    import java.nio.ByteBuffer
+    val cb = TryCallback(callback)
+    event match
+      case cask.Ws.Text(data) => WebSockets.sendText(data, channel, cb)
+      case cask.Ws.Binary(data) => WebSockets.sendBinary(ByteBuffer.wrap(data), channel, cb)
+      case cask.Ws.Ping(value) => WebSockets.sendPing(ByteBuffer.wrap(value), channel, cb)
+      case cask.Ws.Pong(value) => WebSockets.sendPong(ByteBuffer.wrap(value), channel, cb)
+      case cask.Ws.Close(code, reason) => WebSockets.sendClose(code, reason, channel, cb)
+      case _ => throw UnsupportedOperationException()