diff --git a/jvm/src/test/scala/cs214/webapp/DecoratorsSuite.scala b/jvm/src/test/scala/cs214/webapp/DecoratorsSuite.scala index 335e2ec2f2d5eb90d887025b7532cc6a6b7f44ee..3ab706cfea35729c2a377d7a08b9ed9664b9797c 100644 --- a/jvm/src/test/scala/cs214/webapp/DecoratorsSuite.scala +++ b/jvm/src/test/scala/cs214/webapp/DecoratorsSuite.scala @@ -1,37 +1,34 @@ package cs214.webapp.server.web -import cask.router.Result -import cask.model.Request import io.undertow.util.Headers import cask.endpoints.WebsocketResult +import io.undertow.server.HttpServerExchange class DecoratorsSuite extends munit.FunSuite: - val exchange = io.undertow.server.HttpServerExchange(null) - - val request = Request(exchange, Nil) - val delegate = new decorator.Delegate: - def apply(v1: Map[String, Any]): cask.router.Result[WebsocketResult] = - cask.router.Result.Success: - cask.model.Response("OK", 200) - val decorator = decorators.checkOriginHeader() - def decoratorStatusCode: Int = - decorator.wrapFunction(request, delegate) match - case Result.Success(resp: WebsocketResult.Response[?]) => + val delegate: decorator.Delegate = _ => + cask.router.Result.Success: + cask.model.Response("OK", 200) + + def decoratorStatusCode(exchange: HttpServerExchange): Int = + decorator.wrapFunction(cask.model.Request(exchange, Nil), delegate) match + case cask.router.Result.Success(resp: WebsocketResult.Response[?]) => resp.value.statusCode case _ => fail("Unexpected return value from decorator") test("checkOriginHeader: Valid origin"): + val exchange = io.undertow.server.HttpServerExchange(null) exchange.getRequestHeaders.put(Headers.HOST, "example.com") exchange.getRequestHeaders.put(Headers.ORIGIN, s"http://example.com") - assertEquals(decoratorStatusCode, 200) + assertEquals(decoratorStatusCode(exchange), 200) exchange.getRequestHeaders.put(Headers.ORIGIN, s"https://example.com") - assertEquals(decoratorStatusCode, 200) + assertEquals(decoratorStatusCode(exchange), 200) test("checkOriginHeader: Invalid origin"): + val exchange = io.undertow.server.HttpServerExchange(null) exchange.getRequestHeaders.put(Headers.HOST, "example.com") exchange.getRequestHeaders.put(Headers.ORIGIN, s"http://cs-214.epfl.ch") - assertEquals(decoratorStatusCode, 403) + assertEquals(decoratorStatusCode(exchange), 403) exchange.getRequestHeaders.put(Headers.ORIGIN, s"https://cs-214.epfl.ch") - assertEquals(decoratorStatusCode, 403) + assertEquals(decoratorStatusCode(exchange), 403)