From a8cf4b49bc08742c92606318d66314fc432771dd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Pit-Claudel?= <clement.pit-claudel@epfl.ch>
Date: Wed, 18 Dec 2024 23:09:51 +0100
Subject: [PATCH] server: Refactor websocket decorator implementation

---
 build.sbt                                     |  2 -
 .../server/decorators/originValidation.scala  | 41 ---------
 .../webapp/server/web/WebServerRoutes.scala   | 20 ++---
 .../cs214/webapp/server/web/decorators.scala  | 19 ++++
 .../scala/cs214/webapp/DecoratorsSuite.scala  | 37 ++++++++
 .../decorators/originValidationTest.scala     | 87 -------------------
 6 files changed, 66 insertions(+), 140 deletions(-)
 delete mode 100644 jvm/src/main/scala/cs214/webapp/server/decorators/originValidation.scala
 create mode 100644 jvm/src/main/scala/cs214/webapp/server/web/decorators.scala
 create mode 100644 jvm/src/test/scala/cs214/webapp/DecoratorsSuite.scala
 delete mode 100644 jvm/src/test/scala/cs214/webapp/decorators/originValidationTest.scala

diff --git a/build.sbt b/build.sbt
index 6aa4f41..41f8a90 100644
--- a/build.sbt
+++ b/build.sbt
@@ -6,7 +6,6 @@ val webSocketVersion = "1.5.4"
 val caskVersion = "0.9.4"
 val slf4jVersion = "2.0.5"
 val reflectionsVersion = "0.10.2"
-val scalaCheckVersion = "1.18.1"
 
 val options = List("-deprecation", "-feature", "-language:fewerBraces", "-Xfatal-warnings")
 
@@ -32,6 +31,5 @@ lazy val webappLib = crossProject(JSPlatform, JVMPlatform).in(file("."))
       "org.slf4j" % "slf4j-nop" % slf4jVersion,
       "org.reflections" % "reflections" % reflectionsVersion,
       "org.scala-lang" %% "toolkit-test" % toolkitVersion % Test,
-      "org.scalacheck" %% "scalacheck" % scalaCheckVersion % Test,
     ),
   )
diff --git a/jvm/src/main/scala/cs214/webapp/server/decorators/originValidation.scala b/jvm/src/main/scala/cs214/webapp/server/decorators/originValidation.scala
deleted file mode 100644
index 874919d..0000000
--- a/jvm/src/main/scala/cs214/webapp/server/decorators/originValidation.scala
+++ /dev/null
@@ -1,41 +0,0 @@
-package cs214.webapp.server.decorators
-
-import cask.model.Response
-import cask.router.{Decorator}
-import cask.router.Result
-
-
-
-/** Decorator to validate the origin of the request. 
- * Cask Decorators enforce strict matching type signatures
- * with the core function they are decorating.
- * So for each new Return type T, a new Decorator class 
- * extending originValidation must be created.
- * The only method to override is constructForbiddenResponse
- * which might be different for each Return type T.
- * 
-*/
-private class originValidation[T] extends Decorator[Any, T, Any] {
-  def wrapFunction(ctx: cask.Request, delegate: Delegate): Result[T] = {
-
-    // Check if the Origin header is valid
-    val isSourceValid = ctx.headers.get("host").flatMap(_.headOption).exists: host =>
-      ctx.headers.get("origin").flatMap(_.headOption).exists: origin =>
-        origin == s"http://$host" || origin == s"https://$host"
-
-    if (isSourceValid) {
-      // Call the core logic
-      delegate(Map.empty)
-    } else {
-      // Return a 403 Forbidden response
-      constructForbiddenResponse.asInstanceOf[Result[T]] 
-    }
-  }
-  def constructForbiddenResponse: Result[T] = ???
-}
-/* WebSocket origin validation */
-class originValidationWebSocket extends originValidation[cask.endpoints.WebsocketResult] {
-  override def constructForbiddenResponse: Result[cask.endpoints.WebsocketResult] = {
-    Result.Success(new cask.endpoints.WebsocketResult.Response(cask.Response("Forbidden", 403)))
-  }
-}
\ No newline at end of file
diff --git a/jvm/src/main/scala/cs214/webapp/server/web/WebServerRoutes.scala b/jvm/src/main/scala/cs214/webapp/server/web/WebServerRoutes.scala
index 3c6f8ce..e8199ed 100644
--- a/jvm/src/main/scala/cs214/webapp/server/web/WebServerRoutes.scala
+++ b/jvm/src/main/scala/cs214/webapp/server/web/WebServerRoutes.scala
@@ -6,7 +6,8 @@ import java.net.InetAddress
 import scala.jdk.CollectionConverters.*
 import scala.util.Try
 import cask.endpoints.JsonData
-import cs214.webapp.server.decorators.originValidationWebSocket
+
+import decorators.checkOriginHeader
 
 /** HTTP routes of the WebServer */
 private[server] final case class WebServerRoutes()(using cc: castor.Context, log: cask.Logger) extends cask.Routes:
@@ -34,7 +35,7 @@ private[server] final case class WebServerRoutes()(using cc: castor.Context, log
         _ = if WebServer.debug then println(f"[debug] found address ${addr.getHostAddress}")
       yield addr.getHostAddress
     Try(addresses.toList.head).getOrElse(InetAddress.getLocalHost.getHostAddress)
-  
+
   @cask.get("/")
   def getIndexFile() = HTML_STATIC_FILE
 
@@ -44,7 +45,7 @@ private[server] final case class WebServerRoutes()(using cc: castor.Context, log
   // For all /app subsegments, provide the HTML page
   @cask.get(f"${Endpoints.App}")
   def getApp(segments: cask.RemainingPathSegments) = HTML_STATIC_FILE
-  
+
   @cask.getJson(f"${Endpoints.Api.listApps}")
   def getListApps() =
     ListAppsResponse.Wire.encode(ListAppsResponse(WebServer.appDirectory.values.map(_.appInfo).toSeq))
@@ -61,6 +62,7 @@ private[server] final case class WebServerRoutes()(using cc: castor.Context, log
       case None =>
         cask.Response(f"Unknown instance id $instanceId", 400)
     response
+
   @cask.post(f"${Endpoints.Api.createInstance}")
   def postInitApp(request: cask.Request) =
     val response: cask.Response[JsonData] =
@@ -71,14 +73,12 @@ private[server] final case class WebServerRoutes()(using cc: castor.Context, log
         val appId = WebServer.createInstance(req.get.appName, req.get.userIds)
         CreateInstanceResponse.Wire.encode(CreateInstanceResponse(appId))
     response
-    
-  @originValidationWebSocket()
+
+  @checkOriginHeader
   @cask.websocket(f"${Endpoints.WebSocket}/:instanceId/:userId")
   def websocket(instanceId: String, userId: String, request: cask.Request): cask.WebsocketResult =
-      WebServer.instances.get(instanceId) match
-        case Some(app) => app.connect(userId)
-        case None => cask.Response(f"Unknown instance id $instanceId", 400)
-
-  
+    WebServer.instances.get(instanceId) match
+      case Some(app) => app.connect(userId)
+      case None => cask.Response(f"Unknown instance id $instanceId", 400)
 
   initialize()
diff --git a/jvm/src/main/scala/cs214/webapp/server/web/decorators.scala b/jvm/src/main/scala/cs214/webapp/server/web/decorators.scala
new file mode 100644
index 0000000..12a22cf
--- /dev/null
+++ b/jvm/src/main/scala/cs214/webapp/server/web/decorators.scala
@@ -0,0 +1,19 @@
+package cs214.webapp
+package server
+package web
+
+object decorators:
+  extension (request: cask.Request)
+    def originHeaderMatchesHost: Boolean =
+      // request.headers.get("sec-fetch-site").exists(_.contains("same-origin")) // Not supported in Safari
+      request.headers.get("origin").flatMap(_.headOption).exists: origin =>
+        request.headers.get("host").flatMap(_.headOption).exists: host =>
+          origin == f"http://$host" || origin == f"https://$host"
+
+  class checkOriginHeader extends cask.router.Decorator[cask.WebsocketResult, cask.WebsocketResult, Any]:
+    override def wrapFunction(request: cask.Request, delegate: Delegate) =
+      if request.originHeaderMatchesHost then
+        delegate(Map())
+      else
+        cask.router.Result.Success:
+          cask.Response(f"Invalid or missing 'Origin' header: must match the Host", 403)
diff --git a/jvm/src/test/scala/cs214/webapp/DecoratorsSuite.scala b/jvm/src/test/scala/cs214/webapp/DecoratorsSuite.scala
new file mode 100644
index 0000000..335e2ec
--- /dev/null
+++ b/jvm/src/test/scala/cs214/webapp/DecoratorsSuite.scala
@@ -0,0 +1,37 @@
+package cs214.webapp.server.web
+
+import cask.router.Result
+import cask.model.Request
+import io.undertow.util.Headers
+import cask.endpoints.WebsocketResult
+
+class DecoratorsSuite extends munit.FunSuite:
+  val exchange = io.undertow.server.HttpServerExchange(null)
+
+  val request = Request(exchange, Nil)
+  val delegate = new decorator.Delegate:
+    def apply(v1: Map[String, Any]): cask.router.Result[WebsocketResult] =
+      cask.router.Result.Success:
+        cask.model.Response("OK", 200)
+
+  val decorator = decorators.checkOriginHeader()
+
+  def decoratorStatusCode: Int =
+    decorator.wrapFunction(request, delegate) match
+      case Result.Success(resp: WebsocketResult.Response[?]) =>
+        resp.value.statusCode
+      case _ => fail("Unexpected return value from decorator")
+
+  test("checkOriginHeader: Valid origin"):
+    exchange.getRequestHeaders.put(Headers.HOST, "example.com")
+    exchange.getRequestHeaders.put(Headers.ORIGIN, s"http://example.com")
+    assertEquals(decoratorStatusCode, 200)
+    exchange.getRequestHeaders.put(Headers.ORIGIN, s"https://example.com")
+    assertEquals(decoratorStatusCode, 200)
+
+  test("checkOriginHeader: Invalid origin"):
+    exchange.getRequestHeaders.put(Headers.HOST, "example.com")
+    exchange.getRequestHeaders.put(Headers.ORIGIN, s"http://cs-214.epfl.ch")
+    assertEquals(decoratorStatusCode, 403)
+    exchange.getRequestHeaders.put(Headers.ORIGIN, s"https://cs-214.epfl.ch")
+    assertEquals(decoratorStatusCode, 403)
diff --git a/jvm/src/test/scala/cs214/webapp/decorators/originValidationTest.scala b/jvm/src/test/scala/cs214/webapp/decorators/originValidationTest.scala
deleted file mode 100644
index 00ad20a..0000000
--- a/jvm/src/test/scala/cs214/webapp/decorators/originValidationTest.scala
+++ /dev/null
@@ -1,87 +0,0 @@
-package cs214.webapp.decorators
-
-
-import cask.router.Result
-import cask.model.Request
-import io.undertow.server.HttpServerExchange
-import io.undertow.util.Headers
-import cask.endpoints.WebsocketResult
-import scala.concurrent.duration.Duration
-import cs214.webapp.server.decorators.originValidationWebSocket
-import org.scalacheck.Gen
-
-val arbitraryString: Gen[String] = Gen.alphaStr
-
-
-
-class OriginValidationTest extends munit.FunSuite {
-    override val munitTimeout: Duration = Duration(1, "s")
-
-    
-    val exchange = new HttpServerExchange(null)
-    val request = Request(exchange, Nil)
-
-
-    val ctx = Request(exchange, null)
-    val decorator = new originValidationWebSocket()
-
-    val delegate = new decorator.Delegate {
-        def apply(v1: Map[String, Any]): cask.router.Result[WebsocketResult] =
-        cask.router.Result.Success(
-            cask.model.Response("OK", 200)
-        )
-    }
-
-    def checkValid(expectedStatus : Int) = {
-        // Simulate the decorator being called
-        val result = decorator.wrapFunction(request,delegate)
-        // Check expected status code
-        result match {
-            case suc: Result.Success[_] => 
-                suc.value match
-                    case resp : WebsocketResult.Response[?] => 
-                        if resp.value.statusCode != expectedStatus then
-                            fail(s"Expected $expectedStatus")
-                    case _ => fail("Expected Response")
-            case _ => fail("Expected Result.Success")
-        }
-    }
-
-    test("originValidation - valid origin") {
-        // Check that for any similar host and origin the validation is successful
-        for i <- 0 to 1000 do {
-            val host = arbitraryString.sample.get
-            // Check for http
-            exchange.getRequestHeaders.put(Headers.HOST, host)
-            exchange.getRequestHeaders.put(Headers.ORIGIN, s"http://$host")
-            checkValid(200)
-            // Check for https
-            exchange.getRequestHeaders.put(Headers.ORIGIN, s"https://$host")
-            checkValid(200)
-        }
-    }
-
-    test("originValidation - invalid origin") {
-        // Check that for any different host and origin the validation is unsuccessful
-        for i <- 0 to 1000 do {
-
-            var host = arbitraryString.sample.get
-            var origin = arbitraryString.sample.get
-
-            // Ensure that the host and origin are different
-            while host == origin do {
-                host = arbitraryString.sample.get
-                origin = arbitraryString.sample.get
-            }   
-            // Check for http
-            exchange.getRequestHeaders.put(Headers.HOST, host)
-            exchange.getRequestHeaders.put(Headers.ORIGIN, s"http://$origin")
-            checkValid(403)
-            // Check for https
-            exchange.getRequestHeaders.put(Headers.ORIGIN, s"https://$origin")
-            checkValid(403)
-
-        }
-    }
-  
-}
-- 
GitLab