diff --git a/jvm/src/main/scala/cs214/webapp/server/web/ServerApp.scala b/jvm/src/main/scala/cs214/webapp/server/web/ServerApp.scala index d57b4fe9d61b166a438aa79e2bbf208fac87a572..f9c65341ba85252f22e2178bce420f1cddd657e5 100644 --- a/jvm/src/main/scala/cs214/webapp/server/web/ServerApp.scala +++ b/jvm/src/main/scala/cs214/webapp/server/web/ServerApp.scala @@ -10,6 +10,8 @@ import scala.collection.{concurrent, mutable} import scala.util.{Success, Failure, Try} import scala.concurrent.duration.Duration +import io.undertow.websockets.core.WebSocketChannel + case class BroadcastException(ex: Throwable) extends Throwable /** Server-side apps definition and abstractions */ @@ -64,7 +66,7 @@ private[web] abstract class ServerApp: def lastActivity = _lastActivity.get() private val channels = - concurrent.TrieMap[UserId, mutable.Set[cask.WsChannelActor]]().withDefault(_ => mutable.Set()) + concurrent.TrieMap[UserId, mutable.Set[WebSocketChannel]]().withDefault(_ => mutable.Set()) def instanceInfo = // No lock needed: read-only or thread-safe members AdminStatusInstanceInfo( @@ -76,13 +78,13 @@ private[web] abstract class ServerApp: def connect(userId: UserId) (implicit cc: castor.Context, log: cask.Logger): cask.WebsocketResult = instanceLock.synchronized: recordActivity() - cask.WsHandler: channel => - channels.getOrElseUpdate(userId, mutable.Set()).add(channel) + WebSocketConnectionAdapter: channel => println(f"[${appInfo.id}/$instanceId/$userId] client connected") + channels.getOrElseUpdate(userId, mutable.Set()).add(channel) send(userId): val view = respondToNewClient(userId) EventResponse.Wire.encode(Success(List(Action.Render(view)))) - cask.WsActor { + WebSocketEventListener { evt => evt match case cask.Ws.Error(e) => println(f"[${appInfo.id}/$instanceId/$userId] error: ${e.getMessage()}") disconnect(userId, channel) @@ -92,14 +94,20 @@ private[web] abstract class ServerApp: case cask.Ws.ChannelClosed() => println(f"[${appInfo.id}/$instanceId/$userId] channel closed") disconnect(userId, channel) + case cask.Ws.Binary(data) => + println(f"[${appInfo.id}/$instanceId/$userId] unsupported: binary data") + disconnect(userId, channel) case cask.Ws.Text(data) => handleMessage(userId, ujson.read(data)) + case cask.Ws.Ping(_) | cask.Ws.Pong(_) => + () } - def disconnect(userId: UserId, channel: cask.endpoints.WsChannelActor): Unit = instanceLock.synchronized: - channels(userId).remove(channel) - if channels(userId).isEmpty then channels.remove(userId) - println(f"[${appInfo.id}/$instanceId/$userId] client disconnected") + def disconnect(userId: UserId, channel: WebSocketChannel): Unit = instanceLock.synchronized: + if channels(userId).remove(channel) then + println(f"[${appInfo.id}/$instanceId/$userId] client disconnected") + if channels(userId).isEmpty then + channels.remove(userId) /** Enumerates clients connected to this instance. */ def connectedClients: Seq[UserId] = instanceLock.synchronized: @@ -109,16 +117,25 @@ private[web] abstract class ServerApp: channels.nonEmpty def shutdown() = instanceLock.synchronized: - for channels <- channels.values - channel <- channels + for (userId, userChannels) <- channels + channel <- userChannels do - channel.send(cask.Ws.Close(cask.Ws.Close.NormalClosure, "Shutdown")) + send(userId, channel)(cask.Ws.Close(cask.Ws.Close.NormalClosure, "Shutdown")) /** Sends a message to a specific client. */ private def send(userId: UserId)(message: ujson.Value): Unit = instanceLock.synchronized: val wrapped = SocketResponseWire.encode(util.Success(message)).toString for channel <- channels(userId) do - channel.send(cask.Ws.Text(wrapped)) + send(userId, channel)(cask.Ws.Text(wrapped)) + + private def send(userId: UserId, channel: WebSocketChannel) + (event: WebSocketEvent): Unit = instanceLock.synchronized: + channel.send(event) { + case util.Success(()) => + case util.Failure(ex) => + println(f"[${appInfo.id}/$instanceId/$userId] send error: ${ex.getMessage()}") + disconnect(userId, channel) + } def handleMessage(userId: UserId, msg: ujson.Value): Unit = instanceLock.synchronized: recordActivity() @@ -205,7 +222,7 @@ private[web] final class ClockDrivenStateMachineServerApp[E, S, V]( if !hasClients then clock.start() super.connect(userId) - override def disconnect(userId: UserId, channel: cask.endpoints.WsChannelActor) + override def disconnect(userId: UserId, channel: WebSocketChannel) : Unit = instanceLock.synchronized: super.disconnect(userId, channel) if !hasClients then clock.shutdown() diff --git a/jvm/src/main/scala/cs214/webapp/server/web/WebSocket.scala b/jvm/src/main/scala/cs214/webapp/server/web/WebSocket.scala new file mode 100644 index 0000000000000000000000000000000000000000..1cb1bcad5dcd444a482d3f8470042f5664e92706 --- /dev/null +++ b/jvm/src/main/scala/cs214/webapp/server/web/WebSocket.scala @@ -0,0 +1,71 @@ +package cs214.webapp.server.web + +import scala.util.{Try, Success, Failure} + +import io.undertow.websockets.* +import io.undertow.websockets.core.* +import io.undertow.websockets.spi.WebSocketHttpExchange + +// Cask offers a higher-level websocket API, but it doesn't propagate send +// errors back to individual channels (instead, the default error handler of the +// current actor context catches them). The implementation below uses the async +// API with a callback to make sure no error gets lost. + +type WebSocketEvent = cask.Ws.Event + +case class WebSocketEventListener(f: WebSocketEvent => Unit): + def receive(we: WebSocketEvent): Unit = f(we) + +case class WebSocketAdapter(listener: WebSocketEventListener) extends AbstractReceiveListener: + extension (message: BufferedBinaryMessage) + def getDataArray: Array[Byte] = WebSockets.mergeBuffers(message.getData.getResource*).array() + + // Missing from Cask: calls to default handlers (including deallocation of + // buffers and responding to pings) and `onError` callback. + + override def onFullTextMessage(channel: WebSocketChannel, message: BufferedTextMessage) = + listener.receive(cask.Ws.Text(message.getData)) + super.onFullTextMessage(channel, message) + override def onFullBinaryMessage(channel: WebSocketChannel, message: BufferedBinaryMessage) = + listener.receive(cask.Ws.Binary(message.getDataArray)) + super.onFullBinaryMessage(channel, message) + override def onFullPingMessage(channel: WebSocketChannel, message: BufferedBinaryMessage) = + listener.receive(cask.Ws.Ping(message.getDataArray)) + super.onFullPingMessage(channel, message) + override def onFullPongMessage(channel: WebSocketChannel, message: BufferedBinaryMessage) = + listener.receive(cask.Ws.Pong(message.getDataArray)) + super.onFullPongMessage(channel, message) + override def onCloseMessage(cm: CloseMessage, channel: WebSocketChannel) = + listener.receive(cask.Ws.Close(cm.getCode, cm.getReason)) + super.onCloseMessage(cm, channel) + override def onError(channel: WebSocketChannel, error: Throwable) = + listener.receive(cask.Ws.Error(error)) + super.onError(channel, error) + +case class WebSocketConnectionAdapter(f: WebSocketChannel => WebSocketEventListener) extends WebSocketConnectionCallback: + def onConnect(exchange: WebSocketHttpExchange, channel: WebSocketChannel): Unit = + channel.suspendReceives() + val listener = f(channel) + channel.addCloseTask(_ => listener.receive(cask.Ws.ChannelClosed())) + channel.getReceiveSetter.set(WebSocketAdapter(listener)) + channel.resumeReceives() + +// We don't use a Future to avoid having to pass an ExecutionContext everywhere +case class TryCallback(cb: Try[Unit] => Unit) extends WebSocketCallback[Void]: + override def complete(channel: WebSocketChannel, context: Void): Unit = + cb(Success(())) + override def onError(channel: WebSocketChannel, context: Void, throwable: Throwable): Unit = + cb(Failure(throwable)) + org.xnio.IoUtils.safeClose(channel) // Default behavior when no callback is given + +extension (channel: WebSocketChannel) + def send(event: WebSocketEvent)(callback: Try[Unit] => Unit) = + import java.nio.ByteBuffer + val cb = TryCallback(callback) + event match + case cask.Ws.Text(data) => WebSockets.sendText(data, channel, cb) + case cask.Ws.Binary(data) => WebSockets.sendBinary(ByteBuffer.wrap(data), channel, cb) + case cask.Ws.Ping(value) => WebSockets.sendPing(ByteBuffer.wrap(value), channel, cb) + case cask.Ws.Pong(value) => WebSockets.sendPong(ByteBuffer.wrap(value), channel, cb) + case cask.Ws.Close(code, reason) => WebSockets.sendClose(code, reason, channel, cb) + case _ => throw UnsupportedOperationException()