From ff3fdf04cd75577327bb290e1cfecd0538e31618 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Pit-Claudel?= <clement.pit-claudel@epfl.ch>
Date: Fri, 20 Dec 2024 00:30:13 +0100
Subject: [PATCH] tests: Simplify decorators suite

---
 .../scala/cs214/webapp/DecoratorsSuite.scala  | 31 +++++++++----------
 1 file changed, 14 insertions(+), 17 deletions(-)

diff --git a/jvm/src/test/scala/cs214/webapp/DecoratorsSuite.scala b/jvm/src/test/scala/cs214/webapp/DecoratorsSuite.scala
index 335e2ec..3ab706c 100644
--- a/jvm/src/test/scala/cs214/webapp/DecoratorsSuite.scala
+++ b/jvm/src/test/scala/cs214/webapp/DecoratorsSuite.scala
@@ -1,37 +1,34 @@
 package cs214.webapp.server.web
 
-import cask.router.Result
-import cask.model.Request
 import io.undertow.util.Headers
 import cask.endpoints.WebsocketResult
+import io.undertow.server.HttpServerExchange
 
 class DecoratorsSuite extends munit.FunSuite:
-  val exchange = io.undertow.server.HttpServerExchange(null)
-
-  val request = Request(exchange, Nil)
-  val delegate = new decorator.Delegate:
-    def apply(v1: Map[String, Any]): cask.router.Result[WebsocketResult] =
-      cask.router.Result.Success:
-        cask.model.Response("OK", 200)
-
   val decorator = decorators.checkOriginHeader()
 
-  def decoratorStatusCode: Int =
-    decorator.wrapFunction(request, delegate) match
-      case Result.Success(resp: WebsocketResult.Response[?]) =>
+  val delegate: decorator.Delegate = _ =>
+    cask.router.Result.Success:
+      cask.model.Response("OK", 200)
+
+  def decoratorStatusCode(exchange: HttpServerExchange): Int =
+    decorator.wrapFunction(cask.model.Request(exchange, Nil), delegate) match
+      case cask.router.Result.Success(resp: WebsocketResult.Response[?]) =>
         resp.value.statusCode
       case _ => fail("Unexpected return value from decorator")
 
   test("checkOriginHeader: Valid origin"):
+    val exchange = io.undertow.server.HttpServerExchange(null)
     exchange.getRequestHeaders.put(Headers.HOST, "example.com")
     exchange.getRequestHeaders.put(Headers.ORIGIN, s"http://example.com")
-    assertEquals(decoratorStatusCode, 200)
+    assertEquals(decoratorStatusCode(exchange), 200)
     exchange.getRequestHeaders.put(Headers.ORIGIN, s"https://example.com")
-    assertEquals(decoratorStatusCode, 200)
+    assertEquals(decoratorStatusCode(exchange), 200)
 
   test("checkOriginHeader: Invalid origin"):
+    val exchange = io.undertow.server.HttpServerExchange(null)
     exchange.getRequestHeaders.put(Headers.HOST, "example.com")
     exchange.getRequestHeaders.put(Headers.ORIGIN, s"http://cs-214.epfl.ch")
-    assertEquals(decoratorStatusCode, 403)
+    assertEquals(decoratorStatusCode(exchange), 403)
     exchange.getRequestHeaders.put(Headers.ORIGIN, s"https://cs-214.epfl.ch")
-    assertEquals(decoratorStatusCode, 403)
+    assertEquals(decoratorStatusCode(exchange), 403)
-- 
GitLab