diff --git a/build.sbt b/build.sbt index 6aa4f413d5e52a86da44364d83e96e361ecd94f6..41f8a9005bac5ca86437ee8c281aca1b0c10d783 100644 --- a/build.sbt +++ b/build.sbt @@ -6,7 +6,6 @@ val webSocketVersion = "1.5.4" val caskVersion = "0.9.4" val slf4jVersion = "2.0.5" val reflectionsVersion = "0.10.2" -val scalaCheckVersion = "1.18.1" val options = List("-deprecation", "-feature", "-language:fewerBraces", "-Xfatal-warnings") @@ -32,6 +31,5 @@ lazy val webappLib = crossProject(JSPlatform, JVMPlatform).in(file(".")) "org.slf4j" % "slf4j-nop" % slf4jVersion, "org.reflections" % "reflections" % reflectionsVersion, "org.scala-lang" %% "toolkit-test" % toolkitVersion % Test, - "org.scalacheck" %% "scalacheck" % scalaCheckVersion % Test, ), ) diff --git a/jvm/src/main/scala/cs214/webapp/server/decorators/originValidation.scala b/jvm/src/main/scala/cs214/webapp/server/decorators/originValidation.scala deleted file mode 100644 index 874919d3eb641f8f51e270ff7c92edbe48366643..0000000000000000000000000000000000000000 --- a/jvm/src/main/scala/cs214/webapp/server/decorators/originValidation.scala +++ /dev/null @@ -1,41 +0,0 @@ -package cs214.webapp.server.decorators - -import cask.model.Response -import cask.router.{Decorator} -import cask.router.Result - - - -/** Decorator to validate the origin of the request. - * Cask Decorators enforce strict matching type signatures - * with the core function they are decorating. - * So for each new Return type T, a new Decorator class - * extending originValidation must be created. - * The only method to override is constructForbiddenResponse - * which might be different for each Return type T. - * -*/ -private class originValidation[T] extends Decorator[Any, T, Any] { - def wrapFunction(ctx: cask.Request, delegate: Delegate): Result[T] = { - - // Check if the Origin header is valid - val isSourceValid = ctx.headers.get("host").flatMap(_.headOption).exists: host => - ctx.headers.get("origin").flatMap(_.headOption).exists: origin => - origin == s"http://$host" || origin == s"https://$host" - - if (isSourceValid) { - // Call the core logic - delegate(Map.empty) - } else { - // Return a 403 Forbidden response - constructForbiddenResponse.asInstanceOf[Result[T]] - } - } - def constructForbiddenResponse: Result[T] = ??? -} -/* WebSocket origin validation */ -class originValidationWebSocket extends originValidation[cask.endpoints.WebsocketResult] { - override def constructForbiddenResponse: Result[cask.endpoints.WebsocketResult] = { - Result.Success(new cask.endpoints.WebsocketResult.Response(cask.Response("Forbidden", 403))) - } -} \ No newline at end of file diff --git a/jvm/src/main/scala/cs214/webapp/server/web/WebServerRoutes.scala b/jvm/src/main/scala/cs214/webapp/server/web/WebServerRoutes.scala index 3c6f8ceef56e5c73b28eca5a48181ef2ed7bb7a0..e8199ed9d4f55d4a998e77dcd8e6af087917dcd3 100644 --- a/jvm/src/main/scala/cs214/webapp/server/web/WebServerRoutes.scala +++ b/jvm/src/main/scala/cs214/webapp/server/web/WebServerRoutes.scala @@ -6,7 +6,8 @@ import java.net.InetAddress import scala.jdk.CollectionConverters.* import scala.util.Try import cask.endpoints.JsonData -import cs214.webapp.server.decorators.originValidationWebSocket + +import decorators.checkOriginHeader /** HTTP routes of the WebServer */ private[server] final case class WebServerRoutes()(using cc: castor.Context, log: cask.Logger) extends cask.Routes: @@ -34,7 +35,7 @@ private[server] final case class WebServerRoutes()(using cc: castor.Context, log _ = if WebServer.debug then println(f"[debug] found address ${addr.getHostAddress}") yield addr.getHostAddress Try(addresses.toList.head).getOrElse(InetAddress.getLocalHost.getHostAddress) - + @cask.get("/") def getIndexFile() = HTML_STATIC_FILE @@ -44,7 +45,7 @@ private[server] final case class WebServerRoutes()(using cc: castor.Context, log // For all /app subsegments, provide the HTML page @cask.get(f"${Endpoints.App}") def getApp(segments: cask.RemainingPathSegments) = HTML_STATIC_FILE - + @cask.getJson(f"${Endpoints.Api.listApps}") def getListApps() = ListAppsResponse.Wire.encode(ListAppsResponse(WebServer.appDirectory.values.map(_.appInfo).toSeq)) @@ -61,6 +62,7 @@ private[server] final case class WebServerRoutes()(using cc: castor.Context, log case None => cask.Response(f"Unknown instance id $instanceId", 400) response + @cask.post(f"${Endpoints.Api.createInstance}") def postInitApp(request: cask.Request) = val response: cask.Response[JsonData] = @@ -71,14 +73,12 @@ private[server] final case class WebServerRoutes()(using cc: castor.Context, log val appId = WebServer.createInstance(req.get.appName, req.get.userIds) CreateInstanceResponse.Wire.encode(CreateInstanceResponse(appId)) response - - @originValidationWebSocket() + + @checkOriginHeader @cask.websocket(f"${Endpoints.WebSocket}/:instanceId/:userId") def websocket(instanceId: String, userId: String, request: cask.Request): cask.WebsocketResult = - WebServer.instances.get(instanceId) match - case Some(app) => app.connect(userId) - case None => cask.Response(f"Unknown instance id $instanceId", 400) - - + WebServer.instances.get(instanceId) match + case Some(app) => app.connect(userId) + case None => cask.Response(f"Unknown instance id $instanceId", 400) initialize() diff --git a/jvm/src/main/scala/cs214/webapp/server/web/decorators.scala b/jvm/src/main/scala/cs214/webapp/server/web/decorators.scala new file mode 100644 index 0000000000000000000000000000000000000000..12a22cfffac1d0e6585dfbd9d8d20174ab0db8a6 --- /dev/null +++ b/jvm/src/main/scala/cs214/webapp/server/web/decorators.scala @@ -0,0 +1,19 @@ +package cs214.webapp +package server +package web + +object decorators: + extension (request: cask.Request) + def originHeaderMatchesHost: Boolean = + // request.headers.get("sec-fetch-site").exists(_.contains("same-origin")) // Not supported in Safari + request.headers.get("origin").flatMap(_.headOption).exists: origin => + request.headers.get("host").flatMap(_.headOption).exists: host => + origin == f"http://$host" || origin == f"https://$host" + + class checkOriginHeader extends cask.router.Decorator[cask.WebsocketResult, cask.WebsocketResult, Any]: + override def wrapFunction(request: cask.Request, delegate: Delegate) = + if request.originHeaderMatchesHost then + delegate(Map()) + else + cask.router.Result.Success: + cask.Response(f"Invalid or missing 'Origin' header: must match the Host", 403) diff --git a/jvm/src/test/scala/cs214/webapp/DecoratorsSuite.scala b/jvm/src/test/scala/cs214/webapp/DecoratorsSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..335e2ec2f2d5eb90d887025b7532cc6a6b7f44ee --- /dev/null +++ b/jvm/src/test/scala/cs214/webapp/DecoratorsSuite.scala @@ -0,0 +1,37 @@ +package cs214.webapp.server.web + +import cask.router.Result +import cask.model.Request +import io.undertow.util.Headers +import cask.endpoints.WebsocketResult + +class DecoratorsSuite extends munit.FunSuite: + val exchange = io.undertow.server.HttpServerExchange(null) + + val request = Request(exchange, Nil) + val delegate = new decorator.Delegate: + def apply(v1: Map[String, Any]): cask.router.Result[WebsocketResult] = + cask.router.Result.Success: + cask.model.Response("OK", 200) + + val decorator = decorators.checkOriginHeader() + + def decoratorStatusCode: Int = + decorator.wrapFunction(request, delegate) match + case Result.Success(resp: WebsocketResult.Response[?]) => + resp.value.statusCode + case _ => fail("Unexpected return value from decorator") + + test("checkOriginHeader: Valid origin"): + exchange.getRequestHeaders.put(Headers.HOST, "example.com") + exchange.getRequestHeaders.put(Headers.ORIGIN, s"http://example.com") + assertEquals(decoratorStatusCode, 200) + exchange.getRequestHeaders.put(Headers.ORIGIN, s"https://example.com") + assertEquals(decoratorStatusCode, 200) + + test("checkOriginHeader: Invalid origin"): + exchange.getRequestHeaders.put(Headers.HOST, "example.com") + exchange.getRequestHeaders.put(Headers.ORIGIN, s"http://cs-214.epfl.ch") + assertEquals(decoratorStatusCode, 403) + exchange.getRequestHeaders.put(Headers.ORIGIN, s"https://cs-214.epfl.ch") + assertEquals(decoratorStatusCode, 403) diff --git a/jvm/src/test/scala/cs214/webapp/decorators/originValidationTest.scala b/jvm/src/test/scala/cs214/webapp/decorators/originValidationTest.scala deleted file mode 100644 index 00ad20a58db96aa19a4d5a4181ad92bfa095ad71..0000000000000000000000000000000000000000 --- a/jvm/src/test/scala/cs214/webapp/decorators/originValidationTest.scala +++ /dev/null @@ -1,87 +0,0 @@ -package cs214.webapp.decorators - - -import cask.router.Result -import cask.model.Request -import io.undertow.server.HttpServerExchange -import io.undertow.util.Headers -import cask.endpoints.WebsocketResult -import scala.concurrent.duration.Duration -import cs214.webapp.server.decorators.originValidationWebSocket -import org.scalacheck.Gen - -val arbitraryString: Gen[String] = Gen.alphaStr - - - -class OriginValidationTest extends munit.FunSuite { - override val munitTimeout: Duration = Duration(1, "s") - - - val exchange = new HttpServerExchange(null) - val request = Request(exchange, Nil) - - - val ctx = Request(exchange, null) - val decorator = new originValidationWebSocket() - - val delegate = new decorator.Delegate { - def apply(v1: Map[String, Any]): cask.router.Result[WebsocketResult] = - cask.router.Result.Success( - cask.model.Response("OK", 200) - ) - } - - def checkValid(expectedStatus : Int) = { - // Simulate the decorator being called - val result = decorator.wrapFunction(request,delegate) - // Check expected status code - result match { - case suc: Result.Success[_] => - suc.value match - case resp : WebsocketResult.Response[?] => - if resp.value.statusCode != expectedStatus then - fail(s"Expected $expectedStatus") - case _ => fail("Expected Response") - case _ => fail("Expected Result.Success") - } - } - - test("originValidation - valid origin") { - // Check that for any similar host and origin the validation is successful - for i <- 0 to 1000 do { - val host = arbitraryString.sample.get - // Check for http - exchange.getRequestHeaders.put(Headers.HOST, host) - exchange.getRequestHeaders.put(Headers.ORIGIN, s"http://$host") - checkValid(200) - // Check for https - exchange.getRequestHeaders.put(Headers.ORIGIN, s"https://$host") - checkValid(200) - } - } - - test("originValidation - invalid origin") { - // Check that for any different host and origin the validation is unsuccessful - for i <- 0 to 1000 do { - - var host = arbitraryString.sample.get - var origin = arbitraryString.sample.get - - // Ensure that the host and origin are different - while host == origin do { - host = arbitraryString.sample.get - origin = arbitraryString.sample.get - } - // Check for http - exchange.getRequestHeaders.put(Headers.HOST, host) - exchange.getRequestHeaders.put(Headers.ORIGIN, s"http://$origin") - checkValid(403) - // Check for https - exchange.getRequestHeaders.put(Headers.ORIGIN, s"https://$origin") - checkValid(403) - - } - } - -}