diff --git a/jvm/src/main/scala/cs214/webapp/server/web/WebServerRoutes.scala b/jvm/src/main/scala/cs214/webapp/server/web/WebServerRoutes.scala index 9739aababb2e551cc8e627ba76d7a9ce66aed853..0fa2c61701aebd879a9051e480d458c6b90d94d8 100644 --- a/jvm/src/main/scala/cs214/webapp/server/web/WebServerRoutes.scala +++ b/jvm/src/main/scala/cs214/webapp/server/web/WebServerRoutes.scala @@ -8,11 +8,38 @@ import scala.util.{Try, Success, Failure} import cask.endpoints.JsonData import decorators.checkOriginHeader +import decorators.adminAuth +import java.io.File +import scala.io.Source /** HTTP routes of the WebServer */ private[server] final case class WebServerRoutes()(using cc: castor.Context, log: cask.Logger) extends cask.Routes: /** Paths where the static content served by the server is stored */ private val WEB_SRC_PATH = "www/static/" + + /** Read admin key from secret file */ + final private val ADMIN_AUTH = { + val secretPath = "/run/secrets/admin_api_key" + val secretFile = new File(secretPath) + if (secretFile.exists && secretFile.canRead) { + try { + // Read the secret + val secret = Source.fromFile(secretPath).mkString.trim + if (secret.isEmpty) { + println(s"Warning: Secret file at $secretPath is empty!") + None + } else { + Some(secret) + } + } catch { + case e: Exception => + println(s"Error reading secret file: ${e.getMessage}") + None + } + } else{ + None + } + } /** HTML page to serve when accessing the server `/` and `/app/...` path */ private def HTML_STATIC_FILE = @@ -77,6 +104,7 @@ private[server] final case class WebServerRoutes()(using cc: castor.Context, log else cask.Response(f"Unknown app '$appName'", 400) + @adminAuth(ADMIN_AUTH) @cask.getJson(f"${Endpoints.Admin.status}") def adminStatus(): cask.Response[JsonData] = AdminStatusResponseEncoder.encode: diff --git a/jvm/src/main/scala/cs214/webapp/server/web/decorators.scala b/jvm/src/main/scala/cs214/webapp/server/web/decorators.scala index 12a22cfffac1d0e6585dfbd9d8d20174ab0db8a6..4b057556231157d605b0e8c2d660fa0d9793951a 100644 --- a/jvm/src/main/scala/cs214/webapp/server/web/decorators.scala +++ b/jvm/src/main/scala/cs214/webapp/server/web/decorators.scala @@ -2,6 +2,10 @@ package cs214.webapp package server package web +import java.security.MessageDigest +import java.util.Base64 +import cask.endpoints.JsonData + object decorators: extension (request: cask.Request) def originHeaderMatchesHost: Boolean = @@ -17,3 +21,30 @@ object decorators: else cask.router.Result.Success: cask.Response(f"Invalid or missing 'Origin' header: must match the Host", 403) + + class adminAuth(expectedAdminKeyOpt: Option[String]) extends cask.router.Decorator[cask.Response[JsonData], cask.Response[JsonData], Any]: + private val expectedAuthOpt: Option[String] = expectedAdminKeyOpt.map: + key => s"Basic ${Base64.getEncoder.encodeToString(s"$key:".getBytes("UTF-8"))}" + + private val expectedAuthBytesOpt: Option[Array[Byte]] = expectedAuthOpt.map(_.getBytes("UTF-8")) + override def wrapFunction(request: cask.Request, delegate: Delegate) = + val authHeaderOpt: Option[String] = request.headers.get("authorization").flatMap(_.headOption) + val authHeaderBytesOpt: Option[Array[Byte]] = authHeaderOpt.map(_.getBytes("UTF-8")) + + // Perform the check using constant time comparison + val authorized = (expectedAuthBytesOpt, authHeaderBytesOpt) match + case (Some(expectedBytes), Some(actualBytes)) => + MessageDigest.isEqual(actualBytes, expectedBytes) + case _ => + // Either expected key wasn't configured OR header was missing/malformed + false + if (authorized){ + delegate(Map()) + } else { + cask.router.Result.Success: + cask.Response( + data = ujson.Obj("error" -> "Unauthorized"), + statusCode = 401, + headers = Seq("WWW-Authenticate" -> "Basic realm=\"Admin API\"") + ) + } \ No newline at end of file diff --git a/jvm/src/test/scala/cs214/webapp/DecoratorsSuite.scala b/jvm/src/test/scala/cs214/webapp/DecoratorsSuite.scala index 3ab706cfea35729c2a377d7a08b9ed9664b9797c..e6b78cd30165cf458bbd142880e9544cce2c8cb9 100644 --- a/jvm/src/test/scala/cs214/webapp/DecoratorsSuite.scala +++ b/jvm/src/test/scala/cs214/webapp/DecoratorsSuite.scala @@ -3,32 +3,64 @@ package cs214.webapp.server.web import io.undertow.util.Headers import cask.endpoints.WebsocketResult import io.undertow.server.HttpServerExchange +import cask.model.Response +import java.util.Base64 class DecoratorsSuite extends munit.FunSuite: - val decorator = decorators.checkOriginHeader() + val originDecorator = decorators.checkOriginHeader() - val delegate: decorator.Delegate = _ => - cask.router.Result.Success: - cask.model.Response("OK", 200) - - def decoratorStatusCode(exchange: HttpServerExchange): Int = - decorator.wrapFunction(cask.model.Request(exchange, Nil), delegate) match - case cask.router.Result.Success(resp: WebsocketResult.Response[?]) => - resp.value.statusCode - case _ => fail("Unexpected return value from decorator") + def decoratorStatusCode[O]( + decorator: cask.router.Decorator[?, O, ?], + exchange: HttpServerExchange, + innerReturnFunc: O +): Int = { + decorator.wrapFunction(cask.model.Request(exchange, Nil), _ => + cask.router.Result.Success(innerReturnFunc) + ) match { + case cask.router.Result.Success(resp: cask.model.Response[?]) => + resp.statusCode + case cask.router.Result.Success(resp: cask.endpoints.WebsocketResult.Response[_]) => + resp.value.statusCode + case other => + fail(s"Unexpected return value from decorator: $other") + } +} test("checkOriginHeader: Valid origin"): val exchange = io.undertow.server.HttpServerExchange(null) exchange.getRequestHeaders.put(Headers.HOST, "example.com") exchange.getRequestHeaders.put(Headers.ORIGIN, s"http://example.com") - assertEquals(decoratorStatusCode(exchange), 200) + assertEquals(decoratorStatusCode(originDecorator, exchange, Response("OK", 200)), 200) exchange.getRequestHeaders.put(Headers.ORIGIN, s"https://example.com") - assertEquals(decoratorStatusCode(exchange), 200) + assertEquals(decoratorStatusCode(originDecorator, exchange, Response("OK", 200)), 200) test("checkOriginHeader: Invalid origin"): val exchange = io.undertow.server.HttpServerExchange(null) exchange.getRequestHeaders.put(Headers.HOST, "example.com") exchange.getRequestHeaders.put(Headers.ORIGIN, s"http://cs-214.epfl.ch") - assertEquals(decoratorStatusCode(exchange), 403) + assertEquals(decoratorStatusCode(originDecorator, exchange, Response("OK", 200)), 403) exchange.getRequestHeaders.put(Headers.ORIGIN, s"https://cs-214.epfl.ch") - assertEquals(decoratorStatusCode(exchange), 403) + assertEquals(decoratorStatusCode(originDecorator, exchange, Response("OK", 200)), 403) + + + val adminAuthDecorator = decorators.adminAuth(Some("DEADBEEF")) + + test("authorized access: valid authorization header"): + val adminKey = "DEADBEEF" + val expectedHeader = + s"Basic ${Base64.getEncoder.encodeToString(s"$adminKey:" .getBytes("UTF-8"))}" + val exchange = new HttpServerExchange(null) + exchange.getRequestHeaders.put(Headers.AUTHORIZATION, expectedHeader) + // EXPECT 200 AUTHORIZED + assertEquals(decoratorStatusCode(adminAuthDecorator, exchange, Response("OK", 200)), 200) + + test("unauthorized access: invalid authentication"): + val adminKey = "wrongpassword" + val expectedHeader = + s"Basic ${Base64.getEncoder.encodeToString(s"$adminKey:" .getBytes("UTF-8"))}" + val exchange = new HttpServerExchange(null) + exchange.getRequestHeaders.put(Headers.AUTHORIZATION, expectedHeader) + // EXPECT 401 UNAUTHORIZED + assertEquals(decoratorStatusCode(adminAuthDecorator, exchange, Response("OK", 200)), 401) + +