diff --git a/http4s/src/main/scala/com.snowplowanalytics.snowplow.collectors.scalastream/CollectorRoutes.scala b/http4s/src/main/scala/com.snowplowanalytics.snowplow.collectors.scalastream/CollectorRoutes.scala index e6d04e578..aa5000d70 100644 --- a/http4s/src/main/scala/com.snowplowanalytics.snowplow.collectors.scalastream/CollectorRoutes.scala +++ b/http4s/src/main/scala/com.snowplowanalytics.snowplow.collectors.scalastream/CollectorRoutes.scala @@ -2,7 +2,7 @@ package com.snowplowanalytics.snowplow.collectors.scalastream import cats.implicits._ import cats.effect.Sync -import org.http4s.{HttpApp, HttpRoutes} +import org.http4s._ import org.http4s.dsl.Http4sDsl import org.http4s.implicits._ import com.comcast.ip4s.Dns @@ -16,6 +16,11 @@ class CollectorRoutes[F[_]: Sync](collectorService: Service[F]) extends Http4sDs Ok("OK") } + private val corsRoute = HttpRoutes.of[F] { + case req @ OPTIONS -> _ => + collectorService.preflightResponse(req) + } + private val cookieRoutes = HttpRoutes.of[F] { case req @ POST -> Root / vendor / version => val path = collectorService.determinePath(vendor, version) @@ -53,5 +58,5 @@ class CollectorRoutes[F[_]: Sync](collectorService: Service[F]) extends Http4sDs ) } - val value: HttpApp[F] = (healthRoutes <+> cookieRoutes).orNotFound + val value: HttpApp[F] = (healthRoutes <+> corsRoute <+> cookieRoutes).orNotFound } diff --git a/http4s/src/main/scala/com.snowplowanalytics.snowplow.collectors.scalastream/CollectorService.scala b/http4s/src/main/scala/com.snowplowanalytics.snowplow.collectors.scalastream/CollectorService.scala index 78d384e26..407b1856f 100644 --- a/http4s/src/main/scala/com.snowplowanalytics.snowplow.collectors.scalastream/CollectorService.scala +++ b/http4s/src/main/scala/com.snowplowanalytics.snowplow.collectors.scalastream/CollectorService.scala @@ -26,6 +26,7 @@ import com.snowplowanalytics.snowplow.CollectorPayload.thrift.model1.CollectorPa import com.snowplowanalytics.snowplow.collectors.scalastream.model._ trait Service[F[_]] { + def preflightResponse(req: Request[F]): F[Response[F]] def cookie( body: F[Option[String]], path: String, @@ -59,7 +60,7 @@ class CollectorService[F[_]: Sync]( private val splitBatch: SplitBatch = SplitBatch(appName, appVersion) - def cookie( + override def cookie( body: F[Option[String]], path: String, cookie: Option[RequestCookie], @@ -102,17 +103,30 @@ class CollectorService[F[_]: Sync]( ) headerList = List( setCookie.map(_.toRaw1), - cacheControl(pixelExpected).map(_.toRaw1) + cacheControl(pixelExpected).map(_.toRaw1), + accessControlAllowOriginHeader(request).some, + `Access-Control-Allow-Credentials`().toRaw1.some ).flatten responseHeaders = Headers(headerList) _ <- sinkEvent(event, partitionKey) } yield buildHttpResponse(responseHeaders, pixelExpected) - def determinePath(vendor: String, version: String): String = { + override def determinePath(vendor: String, version: String): String = { val original = s"/$vendor/$version" config.paths.getOrElse(original, original) } + override def preflightResponse(req: Request[F]): F[Response[F]] = Sync[F].pure { + Response[F]( + headers = Headers( + accessControlAllowOriginHeader(req), + `Access-Control-Allow-Credentials`(), + `Access-Control-Allow-Headers`(ci"Content-Type", ci"SP-Anonymous"), + `Access-Control-Max-Age`.Cache(config.cors.accessControlMaxAge.toSeconds).asInstanceOf[`Access-Control-Max-Age`] + ) + ) + } + def extractHeader(req: Request[F], headerName: String): Option[String] = req.headers.get(CIString(headerName)).map(_.head.value) @@ -240,6 +254,19 @@ class CollectorService[F[_]: Sync]( } } + /** + * Creates an Access-Control-Allow-Origin header which specifically allows the domain which made + * the request + * + * @param request Incoming request + * @return Header allowing only the domain which made the request or everything + */ + def accessControlAllowOriginHeader(request: Request[F]): Header.Raw = + Header.Raw( + ci"Access-Control-Allow-Origin", + extractHostsFromOrigin(request.headers).headOption.map(_.renderString).getOrElse("*") + ) + /** * Determines the cookie domain to be used by inspecting the Origin header of the request * and trying to find a match in the list of domains specified in the config file. @@ -259,12 +286,12 @@ class CollectorService[F[_]: Sync]( (domains match { case Nil => None case _ => - val originHosts = extractHosts(headers) - domains.find(domain => originHosts.exists(validMatch(_, domain))) + val originDomains = extractHostsFromOrigin(headers).map(_.host.value) + domains.find(domain => originDomains.exists(validMatch(_, domain))) }).orElse(fallbackDomain) /** Extracts the host names from a list of values in the request's Origin header. */ - def extractHosts(headers: Headers): List[String] = + def extractHostsFromOrigin(headers: Headers): List[Origin.Host] = (for { // We can't use 'headers.get[Origin]' function in here because of the bug // reported here: https://github.com/http4s/http4s/issues/7236 @@ -272,15 +299,12 @@ class CollectorService[F[_]: Sync]( // and parse items individually. originSplit <- headers.get(ci"Origin").map(_.head.value.split(' ')) parsed = originSplit.map(Origin.parse(_).toOption).toList.flatten - hosts = parsed.flatMap(extractHostFromOrigin) + hosts = parsed.flatMap { + case Origin.Null => List.empty + case Origin.HostList(hosts) => hosts.toList + } } yield hosts).getOrElse(List.empty) - private def extractHostFromOrigin(originHeader: Origin): List[String] = - originHeader match { - case Origin.Null => List.empty - case Origin.HostList(hosts) => hosts.map(_.host.value).toList - } - /** * Ensures a match is valid. * We only want matches where: diff --git a/http4s/src/main/scala/com.snowplowanalytics.snowplow.collectors.scalastream/model.scala b/http4s/src/main/scala/com.snowplowanalytics.snowplow.collectors.scalastream/model.scala index 18c0b4563..ff4eabfc9 100644 --- a/http4s/src/main/scala/com.snowplowanalytics.snowplow.collectors.scalastream/model.scala +++ b/http4s/src/main/scala/com.snowplowanalytics.snowplow.collectors.scalastream/model.scala @@ -42,9 +42,12 @@ object model { sameSite: Option[SameSite] ) + final case class CORSConfig(accessControlMaxAge: FiniteDuration) + final case class CollectorConfig( paths: Map[String, String], - cookie: CookieConfig + cookie: CookieConfig, + cors: CORSConfig ) { val cookieConfig = if (cookie.enabled) Some(cookie) else None } diff --git a/http4s/src/test/scala/com.snowplowanalytics.snowplow.collectors.scalastream/CollectorRoutesSpec.scala b/http4s/src/test/scala/com.snowplowanalytics.snowplow.collectors.scalastream/CollectorRoutesSpec.scala index 46428efe8..94682e3e8 100644 --- a/http4s/src/test/scala/com.snowplowanalytics.snowplow.collectors.scalastream/CollectorRoutesSpec.scala +++ b/http4s/src/test/scala/com.snowplowanalytics.snowplow.collectors.scalastream/CollectorRoutesSpec.scala @@ -28,6 +28,9 @@ class CollectorRoutesSpec extends Specification { def getCookieCalls: List[CookieParams] = cookieCalls.toList + override def preflightResponse(req: Request[IO]): IO[Response[IO]] = + IO.pure(Response[IO](status = Ok, body = Stream.emit("preflight response").through(text.utf8.encode))) + override def cookie( body: IO[Option[String]], path: String, @@ -69,6 +72,18 @@ class CollectorRoutesSpec extends Specification { response.as[String].unsafeRunSync() must beEqualTo("OK") } + "respond to the cors route with a preflight response" in { + val (_, routes) = createTestServices + def test(uri: Uri) = { + val request = Request[IO](method = Method.OPTIONS, uri = uri) + val response = routes.run(request).unsafeRunSync() + response.as[String].unsafeRunSync() shouldEqual "preflight response" + } + test(uri"/i") + test(uri"/health") + test(uri"/p3/p4") + } + "respond to the post cookie route with the cookie response" in { val (collectorService, routes) = createTestServices diff --git a/http4s/src/test/scala/com.snowplowanalytics.snowplow.collectors.scalastream/CollectorServiceSpec.scala b/http4s/src/test/scala/com.snowplowanalytics.snowplow.collectors.scalastream/CollectorServiceSpec.scala index a172574c4..e06c8f8a9 100644 --- a/http4s/src/test/scala/com.snowplowanalytics.snowplow.collectors.scalastream/CollectorServiceSpec.scala +++ b/http4s/src/test/scala/com.snowplowanalytics.snowplow.collectors.scalastream/CollectorServiceSpec.scala @@ -175,6 +175,73 @@ class CollectorServiceSpec extends Specification { ) r.body.compile.toList.unsafeRunSync().toArray shouldEqual CollectorService.pixel } + + "include CORS headers in the response" in { + val r = service + .cookie( + body = IO.pure(Some("b")), + path = "p", + cookie = None, + request = Request[IO](), + pixelExpected = true, + doNotTrack = false, + contentType = None + ) + .unsafeRunSync() + r.headers.get[`Access-Control-Allow-Credentials`] shouldEqual Some( + `Access-Control-Allow-Credentials`() + ) + r.headers.get(ci"Access-Control-Allow-Origin").map(_.head) shouldEqual Some( + Header.Raw(ci"Access-Control-Allow-Origin", "*") + ) + } + + "include the origin if given to CORS headers in the response" in { + val headers = Headers( + Origin + .HostList( + NonEmptyList.of( + Origin.Host(scheme = Uri.Scheme.http, host = Uri.Host.unsafeFromString("origin.com")), + Origin.Host( + scheme = Uri.Scheme.http, + host = Uri.Host.unsafeFromString("otherorigin.com"), + port = Some(8080) + ) + ) + ) + .asInstanceOf[Origin] + ) + val request = Request[IO](headers = headers) + val r = service + .cookie( + body = IO.pure(Some("b")), + path = "p", + cookie = None, + request = request, + pixelExpected = true, + doNotTrack = false, + contentType = None + ) + .unsafeRunSync() + r.headers.get[`Access-Control-Allow-Credentials`] shouldEqual Some( + `Access-Control-Allow-Credentials`() + ) + r.headers.get(ci"Access-Control-Allow-Origin").map(_.head) shouldEqual Some( + Header.Raw(ci"Access-Control-Allow-Origin", "http://origin.com") + ) + } + } + + "preflightResponse" in { + "return a response appropriate to cors preflight options requests" in { + val expected = Headers( + Header.Raw(ci"Access-Control-Allow-Origin", "*"), + `Access-Control-Allow-Credentials`(), + `Access-Control-Allow-Headers`(ci"Content-Type", ci"SP-Anonymous"), + `Access-Control-Max-Age`.Cache(60).asInstanceOf[`Access-Control-Max-Age`] + ) + service.preflightResponse(Request[IO]()).unsafeRunSync.headers shouldEqual expected + } } "buildEvent" in { @@ -379,6 +446,46 @@ class CollectorServiceSpec extends Specification { } } + "accessControlAllowOriginHeader" in { + "give a restricted ACAO header if there is an Origin header in the request" in { + val headers = Headers( + Origin + .HostList( + NonEmptyList.of( + Origin.Host(scheme = Uri.Scheme.http, host = Uri.Host.unsafeFromString("origin.com")) + ) + ) + .asInstanceOf[Origin] + ) + val request = Request[IO](headers = headers) + val expected = Header.Raw(ci"Access-Control-Allow-Origin", "http://origin.com") + service.accessControlAllowOriginHeader(request) shouldEqual expected + } + "give a restricted ACAO header if there are multiple Origin headers in the request" in { + val headers = Headers( + Origin + .HostList( + NonEmptyList.of( + Origin.Host(scheme = Uri.Scheme.http, host = Uri.Host.unsafeFromString("origin.com")), + Origin.Host( + scheme = Uri.Scheme.http, + host = Uri.Host.unsafeFromString("otherorigin.com"), + port = Some(8080) + ) + ) + ) + .asInstanceOf[Origin] + ) + val request = Request[IO](headers = headers) + val expected = Header.Raw(ci"Access-Control-Allow-Origin", "http://origin.com") + service.accessControlAllowOriginHeader(request) shouldEqual expected + } + "give an open ACAO header if there are no Origin headers in the request" in { + val expected = Header.Raw(ci"Access-Control-Allow-Origin", "*") + service.accessControlAllowOriginHeader(Request[IO]()) shouldEqual expected + } + } + "cookieDomain" in { val testCookieConfig = CookieConfig( enabled = true, @@ -496,18 +603,17 @@ class CollectorServiceSpec extends Specification { "extractHosts" in { "correctly extract the host names from a list of values in the request's Origin header" in { - val origin: Origin = Origin.HostList( - NonEmptyList.of( - Origin.Host(scheme = Uri.Scheme.https, host = Uri.Host.unsafeFromString("origin.com")), - Origin.Host( - scheme = Uri.Scheme.http, - host = Uri.Host.unsafeFromString("subdomain.otherorigin.gov.co.uk"), - port = Some(8080) - ) + val originHostList = NonEmptyList.of( + Origin.Host(scheme = Uri.Scheme.https, host = Uri.Host.unsafeFromString("origin.com")), + Origin.Host( + scheme = Uri.Scheme.http, + host = Uri.Host.unsafeFromString("subdomain.otherorigin.gov.co.uk"), + port = Some(8080) ) ) - val headers = Headers(origin.toRaw1) - service.extractHosts(headers) shouldEqual Seq("origin.com", "subdomain.otherorigin.gov.co.uk") + val origin: Origin = Origin.HostList(originHostList) + val headers = Headers(origin.toRaw1) + service.extractHostsFromOrigin(headers) shouldEqual originHostList.toList } } diff --git a/http4s/src/test/scala/com.snowplowanalytics.snowplow.collectors.scalastream/TestUtils.scala b/http4s/src/test/scala/com.snowplowanalytics.snowplow.collectors.scalastream/TestUtils.scala index a60a79c0a..a4aa99982 100644 --- a/http4s/src/test/scala/com.snowplowanalytics.snowplow.collectors.scalastream/TestUtils.scala +++ b/http4s/src/test/scala/com.snowplowanalytics.snowplow.collectors.scalastream/TestUtils.scala @@ -20,6 +20,7 @@ object TestUtils { secure = false, httpOnly = false, sameSite = None - ) + ), + cors = CORSConfig(60.seconds) ) } diff --git a/stdout/src/main/scala/com.snowplowanalytics.snowplow.collectors.scalastream/StdoutCollector.scala b/stdout/src/main/scala/com.snowplowanalytics.snowplow.collectors.scalastream/StdoutCollector.scala index 7a1a7f592..e70564142 100644 --- a/stdout/src/main/scala/com.snowplowanalytics.snowplow.collectors.scalastream/StdoutCollector.scala +++ b/stdout/src/main/scala/com.snowplowanalytics.snowplow.collectors.scalastream/StdoutCollector.scala @@ -39,7 +39,8 @@ object StdoutCollector extends IOApp { secure = false, httpOnly = false, sameSite = None - ) + ), + cors = CORSConfig(60.seconds) ), BuildInfo.shortName, BuildInfo.version