From a8cf4b49bc08742c92606318d66314fc432771dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Pit-Claudel?= <clement.pit-claudel@epfl.ch> Date: Wed, 18 Dec 2024 23:09:51 +0100 Subject: [PATCH] server: Refactor websocket decorator implementation --- build.sbt | 2 - .../server/decorators/originValidation.scala | 41 --------- .../webapp/server/web/WebServerRoutes.scala | 20 ++--- .../cs214/webapp/server/web/decorators.scala | 19 ++++ .../scala/cs214/webapp/DecoratorsSuite.scala | 37 ++++++++ .../decorators/originValidationTest.scala | 87 ------------------- 6 files changed, 66 insertions(+), 140 deletions(-) delete mode 100644 jvm/src/main/scala/cs214/webapp/server/decorators/originValidation.scala create mode 100644 jvm/src/main/scala/cs214/webapp/server/web/decorators.scala create mode 100644 jvm/src/test/scala/cs214/webapp/DecoratorsSuite.scala delete mode 100644 jvm/src/test/scala/cs214/webapp/decorators/originValidationTest.scala diff --git a/build.sbt b/build.sbt index 6aa4f41..41f8a90 100644 --- a/build.sbt +++ b/build.sbt @@ -6,7 +6,6 @@ val webSocketVersion = "1.5.4" val caskVersion = "0.9.4" val slf4jVersion = "2.0.5" val reflectionsVersion = "0.10.2" -val scalaCheckVersion = "1.18.1" val options = List("-deprecation", "-feature", "-language:fewerBraces", "-Xfatal-warnings") @@ -32,6 +31,5 @@ lazy val webappLib = crossProject(JSPlatform, JVMPlatform).in(file(".")) "org.slf4j" % "slf4j-nop" % slf4jVersion, "org.reflections" % "reflections" % reflectionsVersion, "org.scala-lang" %% "toolkit-test" % toolkitVersion % Test, - "org.scalacheck" %% "scalacheck" % scalaCheckVersion % Test, ), ) diff --git a/jvm/src/main/scala/cs214/webapp/server/decorators/originValidation.scala b/jvm/src/main/scala/cs214/webapp/server/decorators/originValidation.scala deleted file mode 100644 index 874919d..0000000 --- a/jvm/src/main/scala/cs214/webapp/server/decorators/originValidation.scala +++ /dev/null @@ -1,41 +0,0 @@ -package cs214.webapp.server.decorators - -import cask.model.Response -import cask.router.{Decorator} -import cask.router.Result - - - -/** Decorator to validate the origin of the request. - * Cask Decorators enforce strict matching type signatures - * with the core function they are decorating. - * So for each new Return type T, a new Decorator class - * extending originValidation must be created. - * The only method to override is constructForbiddenResponse - * which might be different for each Return type T. - * -*/ -private class originValidation[T] extends Decorator[Any, T, Any] { - def wrapFunction(ctx: cask.Request, delegate: Delegate): Result[T] = { - - // Check if the Origin header is valid - val isSourceValid = ctx.headers.get("host").flatMap(_.headOption).exists: host => - ctx.headers.get("origin").flatMap(_.headOption).exists: origin => - origin == s"http://$host" || origin == s"https://$host" - - if (isSourceValid) { - // Call the core logic - delegate(Map.empty) - } else { - // Return a 403 Forbidden response - constructForbiddenResponse.asInstanceOf[Result[T]] - } - } - def constructForbiddenResponse: Result[T] = ??? -} -/* WebSocket origin validation */ -class originValidationWebSocket extends originValidation[cask.endpoints.WebsocketResult] { - override def constructForbiddenResponse: Result[cask.endpoints.WebsocketResult] = { - Result.Success(new cask.endpoints.WebsocketResult.Response(cask.Response("Forbidden", 403))) - } -} \ No newline at end of file diff --git a/jvm/src/main/scala/cs214/webapp/server/web/WebServerRoutes.scala b/jvm/src/main/scala/cs214/webapp/server/web/WebServerRoutes.scala index 3c6f8ce..e8199ed 100644 --- a/jvm/src/main/scala/cs214/webapp/server/web/WebServerRoutes.scala +++ b/jvm/src/main/scala/cs214/webapp/server/web/WebServerRoutes.scala @@ -6,7 +6,8 @@ import java.net.InetAddress import scala.jdk.CollectionConverters.* import scala.util.Try import cask.endpoints.JsonData -import cs214.webapp.server.decorators.originValidationWebSocket + +import decorators.checkOriginHeader /** HTTP routes of the WebServer */ private[server] final case class WebServerRoutes()(using cc: castor.Context, log: cask.Logger) extends cask.Routes: @@ -34,7 +35,7 @@ private[server] final case class WebServerRoutes()(using cc: castor.Context, log _ = if WebServer.debug then println(f"[debug] found address ${addr.getHostAddress}") yield addr.getHostAddress Try(addresses.toList.head).getOrElse(InetAddress.getLocalHost.getHostAddress) - + @cask.get("/") def getIndexFile() = HTML_STATIC_FILE @@ -44,7 +45,7 @@ private[server] final case class WebServerRoutes()(using cc: castor.Context, log // For all /app subsegments, provide the HTML page @cask.get(f"${Endpoints.App}") def getApp(segments: cask.RemainingPathSegments) = HTML_STATIC_FILE - + @cask.getJson(f"${Endpoints.Api.listApps}") def getListApps() = ListAppsResponse.Wire.encode(ListAppsResponse(WebServer.appDirectory.values.map(_.appInfo).toSeq)) @@ -61,6 +62,7 @@ private[server] final case class WebServerRoutes()(using cc: castor.Context, log case None => cask.Response(f"Unknown instance id $instanceId", 400) response + @cask.post(f"${Endpoints.Api.createInstance}") def postInitApp(request: cask.Request) = val response: cask.Response[JsonData] = @@ -71,14 +73,12 @@ private[server] final case class WebServerRoutes()(using cc: castor.Context, log val appId = WebServer.createInstance(req.get.appName, req.get.userIds) CreateInstanceResponse.Wire.encode(CreateInstanceResponse(appId)) response - - @originValidationWebSocket() + + @checkOriginHeader @cask.websocket(f"${Endpoints.WebSocket}/:instanceId/:userId") def websocket(instanceId: String, userId: String, request: cask.Request): cask.WebsocketResult = - WebServer.instances.get(instanceId) match - case Some(app) => app.connect(userId) - case None => cask.Response(f"Unknown instance id $instanceId", 400) - - + WebServer.instances.get(instanceId) match + case Some(app) => app.connect(userId) + case None => cask.Response(f"Unknown instance id $instanceId", 400) initialize() diff --git a/jvm/src/main/scala/cs214/webapp/server/web/decorators.scala b/jvm/src/main/scala/cs214/webapp/server/web/decorators.scala new file mode 100644 index 0000000..12a22cf --- /dev/null +++ b/jvm/src/main/scala/cs214/webapp/server/web/decorators.scala @@ -0,0 +1,19 @@ +package cs214.webapp +package server +package web + +object decorators: + extension (request: cask.Request) + def originHeaderMatchesHost: Boolean = + // request.headers.get("sec-fetch-site").exists(_.contains("same-origin")) // Not supported in Safari + request.headers.get("origin").flatMap(_.headOption).exists: origin => + request.headers.get("host").flatMap(_.headOption).exists: host => + origin == f"http://$host" || origin == f"https://$host" + + class checkOriginHeader extends cask.router.Decorator[cask.WebsocketResult, cask.WebsocketResult, Any]: + override def wrapFunction(request: cask.Request, delegate: Delegate) = + if request.originHeaderMatchesHost then + delegate(Map()) + else + cask.router.Result.Success: + cask.Response(f"Invalid or missing 'Origin' header: must match the Host", 403) diff --git a/jvm/src/test/scala/cs214/webapp/DecoratorsSuite.scala b/jvm/src/test/scala/cs214/webapp/DecoratorsSuite.scala new file mode 100644 index 0000000..335e2ec --- /dev/null +++ b/jvm/src/test/scala/cs214/webapp/DecoratorsSuite.scala @@ -0,0 +1,37 @@ +package cs214.webapp.server.web + +import cask.router.Result +import cask.model.Request +import io.undertow.util.Headers +import cask.endpoints.WebsocketResult + +class DecoratorsSuite extends munit.FunSuite: + val exchange = io.undertow.server.HttpServerExchange(null) + + val request = Request(exchange, Nil) + val delegate = new decorator.Delegate: + def apply(v1: Map[String, Any]): cask.router.Result[WebsocketResult] = + cask.router.Result.Success: + cask.model.Response("OK", 200) + + val decorator = decorators.checkOriginHeader() + + def decoratorStatusCode: Int = + decorator.wrapFunction(request, delegate) match + case Result.Success(resp: WebsocketResult.Response[?]) => + resp.value.statusCode + case _ => fail("Unexpected return value from decorator") + + test("checkOriginHeader: Valid origin"): + exchange.getRequestHeaders.put(Headers.HOST, "example.com") + exchange.getRequestHeaders.put(Headers.ORIGIN, s"http://example.com") + assertEquals(decoratorStatusCode, 200) + exchange.getRequestHeaders.put(Headers.ORIGIN, s"https://example.com") + assertEquals(decoratorStatusCode, 200) + + test("checkOriginHeader: Invalid origin"): + exchange.getRequestHeaders.put(Headers.HOST, "example.com") + exchange.getRequestHeaders.put(Headers.ORIGIN, s"http://cs-214.epfl.ch") + assertEquals(decoratorStatusCode, 403) + exchange.getRequestHeaders.put(Headers.ORIGIN, s"https://cs-214.epfl.ch") + assertEquals(decoratorStatusCode, 403) diff --git a/jvm/src/test/scala/cs214/webapp/decorators/originValidationTest.scala b/jvm/src/test/scala/cs214/webapp/decorators/originValidationTest.scala deleted file mode 100644 index 00ad20a..0000000 --- a/jvm/src/test/scala/cs214/webapp/decorators/originValidationTest.scala +++ /dev/null @@ -1,87 +0,0 @@ -package cs214.webapp.decorators - - -import cask.router.Result -import cask.model.Request -import io.undertow.server.HttpServerExchange -import io.undertow.util.Headers -import cask.endpoints.WebsocketResult -import scala.concurrent.duration.Duration -import cs214.webapp.server.decorators.originValidationWebSocket -import org.scalacheck.Gen - -val arbitraryString: Gen[String] = Gen.alphaStr - - - -class OriginValidationTest extends munit.FunSuite { - override val munitTimeout: Duration = Duration(1, "s") - - - val exchange = new HttpServerExchange(null) - val request = Request(exchange, Nil) - - - val ctx = Request(exchange, null) - val decorator = new originValidationWebSocket() - - val delegate = new decorator.Delegate { - def apply(v1: Map[String, Any]): cask.router.Result[WebsocketResult] = - cask.router.Result.Success( - cask.model.Response("OK", 200) - ) - } - - def checkValid(expectedStatus : Int) = { - // Simulate the decorator being called - val result = decorator.wrapFunction(request,delegate) - // Check expected status code - result match { - case suc: Result.Success[_] => - suc.value match - case resp : WebsocketResult.Response[?] => - if resp.value.statusCode != expectedStatus then - fail(s"Expected $expectedStatus") - case _ => fail("Expected Response") - case _ => fail("Expected Result.Success") - } - } - - test("originValidation - valid origin") { - // Check that for any similar host and origin the validation is successful - for i <- 0 to 1000 do { - val host = arbitraryString.sample.get - // Check for http - exchange.getRequestHeaders.put(Headers.HOST, host) - exchange.getRequestHeaders.put(Headers.ORIGIN, s"http://$host") - checkValid(200) - // Check for https - exchange.getRequestHeaders.put(Headers.ORIGIN, s"https://$host") - checkValid(200) - } - } - - test("originValidation - invalid origin") { - // Check that for any different host and origin the validation is unsuccessful - for i <- 0 to 1000 do { - - var host = arbitraryString.sample.get - var origin = arbitraryString.sample.get - - // Ensure that the host and origin are different - while host == origin do { - host = arbitraryString.sample.get - origin = arbitraryString.sample.get - } - // Check for http - exchange.getRequestHeaders.put(Headers.HOST, host) - exchange.getRequestHeaders.put(Headers.ORIGIN, s"http://$origin") - checkValid(403) - // Check for https - exchange.getRequestHeaders.put(Headers.ORIGIN, s"https://$origin") - checkValid(403) - - } - } - -} -- GitLab