Skip to content
Snippets Groups Projects

server: Protect against cross-origin websocket requests

Merged Clément Pit-Claudel requested to merge cpc/websocket-origin into main
1 file
+ 7
13
Compare changes
  • Side-by-side
  • Inline
@@ -6,6 +6,7 @@ import java.net.InetAddress
import scala.jdk.CollectionConverters.*
import scala.util.Try
import cask.endpoints.JsonData
import cs214.webapp.server.decorators.originValidationWebSocket
/** HTTP routes of the WebServer */
private[server] final case class WebServerRoutes()(using cc: castor.Context, log: cask.Logger) extends cask.Routes:
@@ -33,7 +34,7 @@ private[server] final case class WebServerRoutes()(using cc: castor.Context, log
_ = if WebServer.debug then println(f"[debug] found address ${addr.getHostAddress}")
yield addr.getHostAddress
Try(addresses.toList.head).getOrElse(InetAddress.getLocalHost.getHostAddress)
@cask.get("/")
def getIndexFile() = HTML_STATIC_FILE
@@ -43,7 +44,7 @@ private[server] final case class WebServerRoutes()(using cc: castor.Context, log
// For all /app subsegments, provide the HTML page
@cask.get(f"${Endpoints.App}")
def getApp(segments: cask.RemainingPathSegments) = HTML_STATIC_FILE
@cask.getJson(f"${Endpoints.Api.listApps}")
def getListApps() =
ListAppsResponse.Wire.encode(ListAppsResponse(WebServer.appDirectory.values.map(_.appInfo).toSeq))
@@ -59,7 +60,6 @@ private[server] final case class WebServerRoutes()(using cc: castor.Context, log
case None =>
cask.Response(f"Unknown instance id $instanceId", 400)
response
@cask.post(f"${Endpoints.Api.createInstance}")
def postInitApp(request: cask.Request) =
val response: cask.Response[JsonData] =
@@ -70,20 +70,14 @@ private[server] final case class WebServerRoutes()(using cc: castor.Context, log
val appId = WebServer.createInstance(req.get.appName, req.get.userIds)
CreateInstanceResponse.Wire.encode(CreateInstanceResponse(appId))
response
def hasAcceptableOriginHeader(request: cask.Request): Boolean =
// request.headers.get("sec-fetch-site").exists(_.contains("same-origin")) // Not supported in Safari
request.headers.get("host").flatMap(_.headOption).exists: host =>
request.headers.get("origin").flatMap(_.headOption).exists: origin =>
origin == f"http://$host" || origin == f"https://$host"
@originValidationWebSocket()
@cask.websocket(f"${Endpoints.WebSocket}/:instanceId/:userId")
def websocket(instanceId: String, userId: String, request: cask.Request): cask.WebsocketResult =
if hasAcceptableOriginHeader(request) then
WebServer.instances.get(instanceId) match
case Some(app) => app.connect(userId)
case None => cask.Response(f"Unknown instance id $instanceId", 400)
else
cask.Response(f"Invalid or missing 'Origin' header", 403)
initialize()
Loading