Skip to content
Snippets Groups Projects
Commit 387af536 authored by Clément Pit-Claudel's avatar Clément Pit-Claudel
Browse files

server: Rewrite the websocket server using undertow instead of cask

While the previous commit solves the issue of improperly closed channels staying
around forever, a dropped connection still logs one error message with its stack
trace (this is the default behavior of cask actors).  With this commit, we:

- Forward send and receive errors to the user-supplied callback.
- Correctly handle the error that appeared as a stack trace in 8a6361a4.
- Answer ping messages with a pong.

Together with the previous commit, this allows us to avoid infinitely growing
traces full of this message:

```
[error] java.io.IOException: UT002027: Could not send data, as the underlying web socket connection has been broken
[error] 	at io.undertow.websockets.core.WebSocketChannel.send(WebSocketChannel.java:357)
[error] 	at io.undertow.websockets.core.WebSockets.sendBlockingInternal(WebSockets.java:992)
[error] 	at io.undertow.websockets.core.WebSockets.sendBlockingInternal(WebSockets.java:986)
[error] 	at io.undertow.websockets.core.WebSockets.sendTextBlocking(WebSockets.java:200)
[error] 	at cask.endpoints.WsChannelActor.run(WebSocketEndpoint.scala:81)
[error] 	at cask.endpoints.WsChannelActor.run(WebSocketEndpoint.scala:80)
[error] 	at castor.SimpleActor.runBatch0$$anonfun$5(Actors.scala:71)
[error] 	at scala.runtime.function.JProcedure1.apply(JProcedure1.java:15)
[error] 	at scala.runtime.function.JProcedure1.apply(JProcedure1.java:10)
[error] 	at scala.collection.IterableOnceOps.foreach(IterableOnce.scala:619)
[error] 	at scala.collection.IterableOnceOps.foreach$(IterableOnce.scala:617)
[error] 	at scala.collection.AbstractIterable.foreach(Iterable.scala:935)
[error] 	at scala.collection.IterableOps$WithFilter.foreach(Iterable.scala:905)
[error] 	at castor.SimpleActor.runBatch0(Actors.scala:75)
[error] 	at castor.BaseActor.castor$BaseActor$$runWithItems(Actors.scala:36)
[error] 	at castor.BaseActor$$anon$1.run(Actors.scala:18)
[error] 	at castor.Context$Impl$$anon$1.run(Context.scala:139)
[error] 	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1144)
[error] 	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:642)
[error] 	at java.base/java.lang.Thread.run(Thread.java:1583)
```
parent 9b314f9f
No related branches found
No related tags found
1 merge request!40Properly handle sudden websocket disconnections
...@@ -10,6 +10,8 @@ import scala.collection.{concurrent, mutable} ...@@ -10,6 +10,8 @@ import scala.collection.{concurrent, mutable}
import scala.util.{Success, Failure, Try} import scala.util.{Success, Failure, Try}
import scala.concurrent.duration.Duration import scala.concurrent.duration.Duration
import io.undertow.websockets.core.WebSocketChannel
case class BroadcastException(ex: Throwable) extends Throwable case class BroadcastException(ex: Throwable) extends Throwable
/** Server-side apps definition and abstractions */ /** Server-side apps definition and abstractions */
...@@ -64,7 +66,7 @@ private[web] abstract class ServerApp: ...@@ -64,7 +66,7 @@ private[web] abstract class ServerApp:
def lastActivity = _lastActivity.get() def lastActivity = _lastActivity.get()
private val channels = private val channels =
concurrent.TrieMap[UserId, mutable.Set[cask.WsChannelActor]]().withDefault(_ => mutable.Set()) concurrent.TrieMap[UserId, mutable.Set[WebSocketChannel]]().withDefault(_ => mutable.Set())
def instanceInfo = // No lock needed: read-only or thread-safe members def instanceInfo = // No lock needed: read-only or thread-safe members
AdminStatusInstanceInfo( AdminStatusInstanceInfo(
...@@ -76,13 +78,13 @@ private[web] abstract class ServerApp: ...@@ -76,13 +78,13 @@ private[web] abstract class ServerApp:
def connect(userId: UserId) def connect(userId: UserId)
(implicit cc: castor.Context, log: cask.Logger): cask.WebsocketResult = instanceLock.synchronized: (implicit cc: castor.Context, log: cask.Logger): cask.WebsocketResult = instanceLock.synchronized:
recordActivity() recordActivity()
cask.WsHandler: channel => WebSocketConnectionAdapter: channel =>
channels.getOrElseUpdate(userId, mutable.Set()).add(channel)
println(f"[${appInfo.id}/$instanceId/$userId] client connected") println(f"[${appInfo.id}/$instanceId/$userId] client connected")
channels.getOrElseUpdate(userId, mutable.Set()).add(channel)
send(userId): send(userId):
val view = respondToNewClient(userId) val view = respondToNewClient(userId)
EventResponse.Wire.encode(Success(List(Action.Render(view)))) EventResponse.Wire.encode(Success(List(Action.Render(view))))
cask.WsActor { WebSocketEventListener { evt => evt match
case cask.Ws.Error(e) => case cask.Ws.Error(e) =>
println(f"[${appInfo.id}/$instanceId/$userId] error: ${e.getMessage()}") println(f"[${appInfo.id}/$instanceId/$userId] error: ${e.getMessage()}")
disconnect(userId, channel) disconnect(userId, channel)
...@@ -92,14 +94,20 @@ private[web] abstract class ServerApp: ...@@ -92,14 +94,20 @@ private[web] abstract class ServerApp:
case cask.Ws.ChannelClosed() => case cask.Ws.ChannelClosed() =>
println(f"[${appInfo.id}/$instanceId/$userId] channel closed") println(f"[${appInfo.id}/$instanceId/$userId] channel closed")
disconnect(userId, channel) disconnect(userId, channel)
case cask.Ws.Binary(data) =>
println(f"[${appInfo.id}/$instanceId/$userId] unsupported: binary data")
disconnect(userId, channel)
case cask.Ws.Text(data) => case cask.Ws.Text(data) =>
handleMessage(userId, ujson.read(data)) handleMessage(userId, ujson.read(data))
case cask.Ws.Ping(_) | cask.Ws.Pong(_) =>
()
} }
def disconnect(userId: UserId, channel: cask.endpoints.WsChannelActor): Unit = instanceLock.synchronized: def disconnect(userId: UserId, channel: WebSocketChannel): Unit = instanceLock.synchronized:
channels(userId).remove(channel) if channels(userId).remove(channel) then
if channels(userId).isEmpty then channels.remove(userId) println(f"[${appInfo.id}/$instanceId/$userId] client disconnected")
println(f"[${appInfo.id}/$instanceId/$userId] client disconnected") if channels(userId).isEmpty then
channels.remove(userId)
/** Enumerates clients connected to this instance. */ /** Enumerates clients connected to this instance. */
def connectedClients: Seq[UserId] = instanceLock.synchronized: def connectedClients: Seq[UserId] = instanceLock.synchronized:
...@@ -109,16 +117,25 @@ private[web] abstract class ServerApp: ...@@ -109,16 +117,25 @@ private[web] abstract class ServerApp:
channels.nonEmpty channels.nonEmpty
def shutdown() = instanceLock.synchronized: def shutdown() = instanceLock.synchronized:
for channels <- channels.values for (userId, userChannels) <- channels
channel <- channels channel <- userChannels
do do
channel.send(cask.Ws.Close(cask.Ws.Close.NormalClosure, "Shutdown")) send(userId, channel)(cask.Ws.Close(cask.Ws.Close.NormalClosure, "Shutdown"))
/** Sends a message to a specific client. */ /** Sends a message to a specific client. */
private def send(userId: UserId)(message: ujson.Value): Unit = instanceLock.synchronized: private def send(userId: UserId)(message: ujson.Value): Unit = instanceLock.synchronized:
val wrapped = SocketResponseWire.encode(util.Success(message)).toString val wrapped = SocketResponseWire.encode(util.Success(message)).toString
for channel <- channels(userId) do for channel <- channels(userId) do
channel.send(cask.Ws.Text(wrapped)) send(userId, channel)(cask.Ws.Text(wrapped))
private def send(userId: UserId, channel: WebSocketChannel)
(event: WebSocketEvent): Unit = instanceLock.synchronized:
channel.send(event) {
case util.Success(()) =>
case util.Failure(ex) =>
println(f"[${appInfo.id}/$instanceId/$userId] send error: ${ex.getMessage()}")
disconnect(userId, channel)
}
def handleMessage(userId: UserId, msg: ujson.Value): Unit = instanceLock.synchronized: def handleMessage(userId: UserId, msg: ujson.Value): Unit = instanceLock.synchronized:
recordActivity() recordActivity()
...@@ -205,7 +222,7 @@ private[web] final class ClockDrivenStateMachineServerApp[E, S, V]( ...@@ -205,7 +222,7 @@ private[web] final class ClockDrivenStateMachineServerApp[E, S, V](
if !hasClients then clock.start() if !hasClients then clock.start()
super.connect(userId) super.connect(userId)
override def disconnect(userId: UserId, channel: cask.endpoints.WsChannelActor) override def disconnect(userId: UserId, channel: WebSocketChannel)
: Unit = instanceLock.synchronized: : Unit = instanceLock.synchronized:
super.disconnect(userId, channel) super.disconnect(userId, channel)
if !hasClients then clock.shutdown() if !hasClients then clock.shutdown()
......
package cs214.webapp.server.web
import scala.util.{Try, Success, Failure}
import io.undertow.websockets.*
import io.undertow.websockets.core.*
import io.undertow.websockets.spi.WebSocketHttpExchange
// Cask offers a higher-level websocket API, but it doesn't propagate send
// errors back to individual channels (instead, the default error handler of the
// current actor context catches them). The implementation below uses the async
// API with a callback to make sure no error gets lost.
// Alias for Cask's websocket event type; the listener, adapter and `send`
// extension in this file all exchange values of this type.
type WebSocketEvent = cask.Ws.Event
/** Wraps a plain function so that it can act as a receiver of websocket events.
  *
  * @param f function invoked with every event passed to `receive`
  */
case class WebSocketEventListener(f: WebSocketEvent => Unit):
  /** Hands `we` over to the wrapped function. */
  def receive(we: WebSocketEvent): Unit =
    f.apply(we)
/** Undertow receive listener that translates incoming frames and errors into
  * [[WebSocketEvent]]s, passes them to `listener`, and then runs the default
  * handler inherited from `AbstractReceiveListener`.
  *
  * The inherited handler always runs, even when `listener` throws, so that the
  * default processing of a frame is never skipped because of a failure in
  * user code.
  *
  * @param listener receiver of the translated events
  */
case class WebSocketAdapter(listener: WebSocketEventListener) extends AbstractReceiveListener:
  extension (message: BufferedBinaryMessage)
    /** Copies the payload of `message` into a byte array.
      *
      * The merge works on duplicates of the pooled buffers: duplicates share
      * the content but have their own position, so the original buffers are
      * left unread for the inherited handlers that run afterwards (e.g. the
      * ping handler, which replies using the message's data).
      */
    def getDataArray: Array[Byte] =
      val independentViews = message.getData.getResource.map(_.duplicate())
      WebSockets.mergeBuffers(independentViews*).array()

  // Missing from Cask: calls to default handlers (including deallocation of
  // buffers and responding to pings) and `onError` callback.

  override def onFullTextMessage(channel: WebSocketChannel, message: BufferedTextMessage): Unit =
    try listener.receive(cask.Ws.Text(message.getData))
    finally super.onFullTextMessage(channel, message)

  override def onFullBinaryMessage(channel: WebSocketChannel, message: BufferedBinaryMessage): Unit =
    try listener.receive(cask.Ws.Binary(message.getDataArray))
    finally super.onFullBinaryMessage(channel, message)

  override def onFullPingMessage(channel: WebSocketChannel, message: BufferedBinaryMessage): Unit =
    try listener.receive(cask.Ws.Ping(message.getDataArray))
    finally super.onFullPingMessage(channel, message)

  override def onFullPongMessage(channel: WebSocketChannel, message: BufferedBinaryMessage): Unit =
    try listener.receive(cask.Ws.Pong(message.getDataArray))
    finally super.onFullPongMessage(channel, message)

  override def onCloseMessage(cm: CloseMessage, channel: WebSocketChannel): Unit =
    try listener.receive(cask.Ws.Close(cm.getCode, cm.getReason))
    finally super.onCloseMessage(cm, channel)

  override def onError(channel: WebSocketChannel, error: Throwable): Unit =
    try listener.receive(cask.Ws.Error(error))
    finally super.onError(channel, error)
/** Connection callback that asks `f` for an event listener whenever a channel
  * connects and wires that listener to the channel.
  *
  * @param f builds the event listener for a newly connected channel
  */
case class WebSocketConnectionAdapter(f: WebSocketChannel => WebSocketEventListener) extends WebSocketConnectionCallback:
  def onConnect(exchange: WebSocketHttpExchange, channel: WebSocketChannel): Unit =
    // Receives are suspended while the listener is being installed and
    // resumed once everything is in place.
    channel.suspendReceives()
    val eventListener = f(channel)
    val receiveListener = WebSocketAdapter(eventListener)
    // Report the end of the channel's life to the listener as an event.
    channel.addCloseTask { _ =>
      eventListener.receive(cask.Ws.ChannelClosed())
    }
    channel.getReceiveSetter.set(receiveListener)
    channel.resumeReceives()
// We don't use a Future to avoid having to pass an ExecutionContext everywhere
/** Adapts a `Try`-based callback to Undertow's `WebSocketCallback` interface.
  *
  * @param cb invoked with `Success(())` when a send completes, and with
  *           `Failure(throwable)` when it fails
  */
case class TryCallback(cb: Try[Unit] => Unit) extends WebSocketCallback[Void]:
  override def complete(channel: WebSocketChannel, context: Void): Unit =
    cb(Success(()))

  override def onError(channel: WebSocketChannel, context: Void, throwable: Throwable): Unit =
    // The channel is closed after the callback has run, and also when the
    // callback itself throws, so a failed send never leaves it open.
    try cb(Failure(throwable))
    finally org.xnio.IoUtils.safeClose(channel) // Default behavior when no callback is given
extension (channel: WebSocketChannel)
  /** Sends `event` over this channel using Undertow's callback-based API.
    *
    * @param event    the event to send; text, binary, ping, pong and close
    *                 events are supported
    * @param callback receives `Success(())` once the send completes, or
    *                 `Failure(error)` if it fails
    * @throws UnsupportedOperationException if `event` is of any other kind
    */
  def send(event: WebSocketEvent)(callback: Try[Unit] => Unit): Unit =
    import java.nio.ByteBuffer
    val cb = TryCallback(callback)
    event match
      case cask.Ws.Text(data) => WebSockets.sendText(data, channel, cb)
      case cask.Ws.Binary(data) => WebSockets.sendBinary(ByteBuffer.wrap(data), channel, cb)
      case cask.Ws.Ping(value) => WebSockets.sendPing(ByteBuffer.wrap(value), channel, cb)
      case cask.Ws.Pong(value) => WebSockets.sendPong(ByteBuffer.wrap(value), channel, cb)
      case cask.Ws.Close(code, reason) => WebSockets.sendClose(code, reason, channel, cb)
      case other =>
        // Remaining event kinds cannot be sent; name the offending event so
        // the failure can be diagnosed.
        throw UnsupportedOperationException(s"Cannot send websocket event: $other")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment