From 387af5363eee4b86e8605b0f8b1ad0ea376e855d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Pit-Claudel?= <clement.pit-claudel@epfl.ch> Date: Fri, 27 Dec 2024 12:32:45 +0100 Subject: [PATCH] server: Rewrite the websocket server using undertow instead of cask While the previous commit solves the issue of improperly closed channels staying around forever, a dropped connection still logs one error message with its stack trace (this is the default behavior of cask actors). With this commits, we: - Forward send and receive errors to the user-supplied callback. - Correctly handle the error that appeared as a stack trace in 8a6361a. - Answer ping messages with a pong. Together with the previous commit, this allows us to avoid infinitely growing traces full of this message: ``` [error] java.io.IOException: UT002027: Could not send data, as the underlying web socket connection has been broken [error] at io.undertow.websockets.core.WebSocketChannel.send(WebSocketChannel.java:357) [error] at io.undertow.websockets.core.WebSockets.sendBlockingInternal(WebSockets.java:992) [error] at io.undertow.websockets.core.WebSockets.sendBlockingInternal(WebSockets.java:986) [error] at io.undertow.websockets.core.WebSockets.sendTextBlocking(WebSockets.java:200) [error] at cask.endpoints.WsChannelActor.run(WebSocketEndpoint.scala:81) [error] at cask.endpoints.WsChannelActor.run(WebSocketEndpoint.scala:80) [error] at castor.SimpleActor.runBatch0$$anonfun$5(Actors.scala:71) [error] at scala.runtime.function.JProcedure1.apply(JProcedure1.java:15) [error] at scala.runtime.function.JProcedure1.apply(JProcedure1.java:10) [error] at scala.collection.IterableOnceOps.foreach(IterableOnce.scala:619) [error] at scala.collection.IterableOnceOps.foreach$(IterableOnce.scala:617) [error] at scala.collection.AbstractIterable.foreach(Iterable.scala:935) [error] at scala.collection.IterableOps$WithFilter.foreach(Iterable.scala:905) [error] at castor.SimpleActor.runBatch0(Actors.scala:75) [error] at castor.BaseActor.castor$BaseActor$$runWithItems(Actors.scala:36) [error] at castor.BaseActor$$anon$1.run(Actors.scala:18) [error] at castor.Context$Impl$$anon$1.run(Context.scala:139) [error] at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1144) [error] at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:642) [error] at java.base/java.lang.Thread.run(Thread.java:1583) ``` --- .../cs214/webapp/server/web/ServerApp.scala | 43 +++++++---- .../cs214/webapp/server/web/WebSocket.scala | 71 +++++++++++++++++++ 2 files changed, 101 insertions(+), 13 deletions(-) create mode 100644 jvm/src/main/scala/cs214/webapp/server/web/WebSocket.scala diff --git a/jvm/src/main/scala/cs214/webapp/server/web/ServerApp.scala b/jvm/src/main/scala/cs214/webapp/server/web/ServerApp.scala index d57b4fe..f9c6534 100644 --- a/jvm/src/main/scala/cs214/webapp/server/web/ServerApp.scala +++ b/jvm/src/main/scala/cs214/webapp/server/web/ServerApp.scala @@ -10,6 +10,8 @@ import scala.collection.{concurrent, mutable} import scala.util.{Success, Failure, Try} import scala.concurrent.duration.Duration +import io.undertow.websockets.core.WebSocketChannel + case class BroadcastException(ex: Throwable) extends Throwable /** Server-side apps definition and abstractions */ @@ -64,7 +66,7 @@ private[web] abstract class ServerApp: def lastActivity = _lastActivity.get() private val channels = - concurrent.TrieMap[UserId, mutable.Set[cask.WsChannelActor]]().withDefault(_ => mutable.Set()) + concurrent.TrieMap[UserId, mutable.Set[WebSocketChannel]]().withDefault(_ => mutable.Set()) def instanceInfo = // No lock needed: read-only or thread-safe members AdminStatusInstanceInfo( @@ -76,13 +78,13 @@ private[web] abstract class ServerApp: def connect(userId: UserId) (implicit cc: castor.Context, log: cask.Logger): cask.WebsocketResult = instanceLock.synchronized: recordActivity() - cask.WsHandler: channel => - channels.getOrElseUpdate(userId, mutable.Set()).add(channel) + WebSocketConnectionAdapter: channel => println(f"[${appInfo.id}/$instanceId/$userId] client connected") + channels.getOrElseUpdate(userId, mutable.Set()).add(channel) send(userId): val view = respondToNewClient(userId) EventResponse.Wire.encode(Success(List(Action.Render(view)))) - cask.WsActor { + WebSocketEventListener { evt => evt match case cask.Ws.Error(e) => println(f"[${appInfo.id}/$instanceId/$userId] error: ${e.getMessage()}") disconnect(userId, channel) @@ -92,14 +94,20 @@ private[web] abstract class ServerApp: case cask.Ws.ChannelClosed() => println(f"[${appInfo.id}/$instanceId/$userId] channel closed") disconnect(userId, channel) + case cask.Ws.Binary(data) => + println(f"[${appInfo.id}/$instanceId/$userId] unsupported: binary data") + disconnect(userId, channel) case cask.Ws.Text(data) => handleMessage(userId, ujson.read(data)) + case cask.Ws.Ping(_) | cask.Ws.Pong(_) => + () } - def disconnect(userId: UserId, channel: cask.endpoints.WsChannelActor): Unit = instanceLock.synchronized: - channels(userId).remove(channel) - if channels(userId).isEmpty then channels.remove(userId) - println(f"[${appInfo.id}/$instanceId/$userId] client disconnected") + def disconnect(userId: UserId, channel: WebSocketChannel): Unit = instanceLock.synchronized: + if channels(userId).remove(channel) then + println(f"[${appInfo.id}/$instanceId/$userId] client disconnected") + if channels(userId).isEmpty then + channels.remove(userId) /** Enumerates clients connected to this instance. */ def connectedClients: Seq[UserId] = instanceLock.synchronized: @@ -109,16 +117,25 @@ private[web] abstract class ServerApp: channels.nonEmpty def shutdown() = instanceLock.synchronized: - for channels <- channels.values - channel <- channels + for (userId, userChannels) <- channels + channel <- userChannels do - channel.send(cask.Ws.Close(cask.Ws.Close.NormalClosure, "Shutdown")) + send(userId, channel)(cask.Ws.Close(cask.Ws.Close.NormalClosure, "Shutdown")) /** Sends a message to a specific client. */ private def send(userId: UserId)(message: ujson.Value): Unit = instanceLock.synchronized: val wrapped = SocketResponseWire.encode(util.Success(message)).toString for channel <- channels(userId) do - channel.send(cask.Ws.Text(wrapped)) + send(userId, channel)(cask.Ws.Text(wrapped)) + + private def send(userId: UserId, channel: WebSocketChannel) + (event: WebSocketEvent): Unit = instanceLock.synchronized: + channel.send(event) { + case util.Success(()) => + case util.Failure(ex) => + println(f"[${appInfo.id}/$instanceId/$userId] send error: ${ex.getMessage()}") + disconnect(userId, channel) + } def handleMessage(userId: UserId, msg: ujson.Value): Unit = instanceLock.synchronized: recordActivity() @@ -205,7 +222,7 @@ private[web] final class ClockDrivenStateMachineServerApp[E, S, V]( if !hasClients then clock.start() super.connect(userId) - override def disconnect(userId: UserId, channel: cask.endpoints.WsChannelActor) + override def disconnect(userId: UserId, channel: WebSocketChannel) : Unit = instanceLock.synchronized: super.disconnect(userId, channel) if !hasClients then clock.shutdown() diff --git a/jvm/src/main/scala/cs214/webapp/server/web/WebSocket.scala b/jvm/src/main/scala/cs214/webapp/server/web/WebSocket.scala new file mode 100644 index 0000000..1cb1bca --- /dev/null +++ b/jvm/src/main/scala/cs214/webapp/server/web/WebSocket.scala @@ -0,0 +1,71 @@ +package cs214.webapp.server.web + +import scala.util.{Try, Success, Failure} + +import io.undertow.websockets.* +import io.undertow.websockets.core.* +import io.undertow.websockets.spi.WebSocketHttpExchange + +// Cask offers a higher-level websocket API, but it doesn't propagate send +// errors back to individual channels (instead, the default error handler of the +// current actor context catches them). The implementation below uses the async +// API with a callback to make sure no error gets lost. + +type WebSocketEvent = cask.Ws.Event + +case class WebSocketEventListener(f: WebSocketEvent => Unit): + def receive(we: WebSocketEvent): Unit = f(we) + +case class WebSocketAdapter(listener: WebSocketEventListener) extends AbstractReceiveListener: + extension (message: BufferedBinaryMessage) + def getDataArray: Array[Byte] = WebSockets.mergeBuffers(message.getData.getResource*).array() + + // Missing from Cask: calls to default handlers (including deallocation of + // buffers and responding to pings) and `onError` callback. + + override def onFullTextMessage(channel: WebSocketChannel, message: BufferedTextMessage) = + listener.receive(cask.Ws.Text(message.getData)) + super.onFullTextMessage(channel, message) + override def onFullBinaryMessage(channel: WebSocketChannel, message: BufferedBinaryMessage) = + listener.receive(cask.Ws.Binary(message.getDataArray)) + super.onFullBinaryMessage(channel, message) + override def onFullPingMessage(channel: WebSocketChannel, message: BufferedBinaryMessage) = + listener.receive(cask.Ws.Ping(message.getDataArray)) + super.onFullPingMessage(channel, message) + override def onFullPongMessage(channel: WebSocketChannel, message: BufferedBinaryMessage) = + listener.receive(cask.Ws.Pong(message.getDataArray)) + super.onFullPongMessage(channel, message) + override def onCloseMessage(cm: CloseMessage, channel: WebSocketChannel) = + listener.receive(cask.Ws.Close(cm.getCode, cm.getReason)) + super.onCloseMessage(cm, channel) + override def onError(channel: WebSocketChannel, error: Throwable) = + listener.receive(cask.Ws.Error(error)) + super.onError(channel, error) + +case class WebSocketConnectionAdapter(f: WebSocketChannel => WebSocketEventListener) extends WebSocketConnectionCallback: + def onConnect(exchange: WebSocketHttpExchange, channel: WebSocketChannel): Unit = + channel.suspendReceives() + val listener = f(channel) + channel.addCloseTask(_ => listener.receive(cask.Ws.ChannelClosed())) + channel.getReceiveSetter.set(WebSocketAdapter(listener)) + channel.resumeReceives() + +// We don't use a Future to avoid having to pass an ExecutionContext everywhere +case class TryCallback(cb: Try[Unit] => Unit) extends WebSocketCallback[Void]: + override def complete(channel: WebSocketChannel, context: Void): Unit = + cb(Success(())) + override def onError(channel: WebSocketChannel, context: Void, throwable: Throwable): Unit = + cb(Failure(throwable)) + org.xnio.IoUtils.safeClose(channel) // Default behavior when no callback is given + +extension (channel: WebSocketChannel) + def send(event: WebSocketEvent)(callback: Try[Unit] => Unit) = + import java.nio.ByteBuffer + val cb = TryCallback(callback) + event match + case cask.Ws.Text(data) => WebSockets.sendText(data, channel, cb) + case cask.Ws.Binary(data) => WebSockets.sendBinary(ByteBuffer.wrap(data), channel, cb) + case cask.Ws.Ping(value) => WebSockets.sendPing(ByteBuffer.wrap(value), channel, cb) + case cask.Ws.Pong(value) => WebSockets.sendPong(ByteBuffer.wrap(value), channel, cb) + case cask.Ws.Close(code, reason) => WebSockets.sendClose(code, reason, channel, cb) + case _ => throw UnsupportedOperationException() -- GitLab