Skip to content

Commit

Permalink
Add CORS support
Browse files Browse the repository at this point in the history
  • Loading branch information
spenes committed Aug 14, 2023
1 parent fa995ae commit adf8035
Show file tree
Hide file tree
Showing 7 changed files with 183 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package com.snowplowanalytics.snowplow.collectors.scalastream

import cats.implicits._
import cats.effect.Sync
import org.http4s.{HttpApp, HttpRoutes}
import org.http4s._
import org.http4s.dsl.Http4sDsl
import org.http4s.implicits._
import com.comcast.ip4s.Dns
Expand All @@ -16,6 +16,11 @@ class CollectorRoutes[F[_]: Sync](collectorService: Service[F]) extends Http4sDs
Ok("OK")
}

private val corsRoute = HttpRoutes.of[F] {
case req @ OPTIONS -> _ =>
collectorService.preflightResponse(req)
}

private val cookieRoutes = HttpRoutes.of[F] {
case req @ POST -> Root / vendor / version =>
val path = collectorService.determinePath(vendor, version)
Expand Down Expand Up @@ -53,5 +58,5 @@ class CollectorRoutes[F[_]: Sync](collectorService: Service[F]) extends Http4sDs
)
}

val value: HttpApp[F] = (healthRoutes <+> cookieRoutes).orNotFound
val value: HttpApp[F] = (healthRoutes <+> corsRoute <+> cookieRoutes).orNotFound
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import com.snowplowanalytics.snowplow.CollectorPayload.thrift.model1.CollectorPa
import com.snowplowanalytics.snowplow.collectors.scalastream.model._

trait Service[F[_]] {
def preflightResponse(req: Request[F]): F[Response[F]]
def cookie(
body: F[Option[String]],
path: String,
Expand Down Expand Up @@ -59,7 +60,7 @@ class CollectorService[F[_]: Sync](

private val splitBatch: SplitBatch = SplitBatch(appName, appVersion)

def cookie(
override def cookie(
body: F[Option[String]],
path: String,
cookie: Option[RequestCookie],
Expand Down Expand Up @@ -102,17 +103,30 @@ class CollectorService[F[_]: Sync](
)
headerList = List(
setCookie.map(_.toRaw1),
cacheControl(pixelExpected).map(_.toRaw1)
cacheControl(pixelExpected).map(_.toRaw1),
accessControlAllowOriginHeader(request).some,
`Access-Control-Allow-Credentials`().toRaw1.some
).flatten
responseHeaders = Headers(headerList)
_ <- sinkEvent(event, partitionKey)
} yield buildHttpResponse(responseHeaders, pixelExpected)

def determinePath(vendor: String, version: String): String = {
override def determinePath(vendor: String, version: String): String = {
val original = s"/$vendor/$version"
config.paths.getOrElse(original, original)
}

override def preflightResponse(req: Request[F]): F[Response[F]] = Sync[F].pure {
Response[F](
headers = Headers(
accessControlAllowOriginHeader(req),
`Access-Control-Allow-Credentials`(),
`Access-Control-Allow-Headers`(ci"Content-Type", ci"SP-Anonymous"),
`Access-Control-Max-Age`.Cache(config.cors.accessControlMaxAge.toSeconds).asInstanceOf[`Access-Control-Max-Age`]
)
)
}

def extractHeader(req: Request[F], headerName: String): Option[String] =
req.headers.get(CIString(headerName)).map(_.head.value)

Expand Down Expand Up @@ -240,6 +254,19 @@ class CollectorService[F[_]: Sync](
}
}

/**
* Creates an Access-Control-Allow-Origin header which specifically allows the domain which made
* the request
*
* @param request Incoming request
* @return Header allowing only the domain which made the request or everything
*/
def accessControlAllowOriginHeader(request: Request[F]): Header.Raw =
Header.Raw(
ci"Access-Control-Allow-Origin",
extractHostsFromOrigin(request.headers).headOption.map(_.renderString).getOrElse("*")
)

/**
* Determines the cookie domain to be used by inspecting the Origin header of the request
* and trying to find a match in the list of domains specified in the config file.
Expand All @@ -259,28 +286,25 @@ class CollectorService[F[_]: Sync](
(domains match {
case Nil => None
case _ =>
val originHosts = extractHosts(headers)
domains.find(domain => originHosts.exists(validMatch(_, domain)))
val originDomains = extractHostsFromOrigin(headers).map(_.host.value)
domains.find(domain => originDomains.exists(validMatch(_, domain)))
}).orElse(fallbackDomain)

/** Extracts the host names from a list of values in the request's Origin header. */
def extractHosts(headers: Headers): List[String] =
def extractHostsFromOrigin(headers: Headers): List[Origin.Host] =
(for {
// We can't use 'headers.get[Origin]' function in here because of the bug
// reported here: https://github.com/http4s/http4s/issues/7236
// To circumvent the bug, we split the the Origin header value with blank char
// and parse items individually.
originSplit <- headers.get(ci"Origin").map(_.head.value.split(' '))
parsed = originSplit.map(Origin.parse(_).toOption).toList.flatten
hosts = parsed.flatMap(extractHostFromOrigin)
hosts = parsed.flatMap {
case Origin.Null => List.empty
case Origin.HostList(hosts) => hosts.toList
}
} yield hosts).getOrElse(List.empty)

private def extractHostFromOrigin(originHeader: Origin): List[String] =
originHeader match {
case Origin.Null => List.empty
case Origin.HostList(hosts) => hosts.map(_.host.value).toList
}

/**
* Ensures a match is valid.
* We only want matches where:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,12 @@ object model {
sameSite: Option[SameSite]
)

final case class CORSConfig(accessControlMaxAge: FiniteDuration)

final case class CollectorConfig(
paths: Map[String, String],
cookie: CookieConfig
cookie: CookieConfig,
cors: CORSConfig
) {
val cookieConfig = if (cookie.enabled) Some(cookie) else None
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ class CollectorRoutesSpec extends Specification {

def getCookieCalls: List[CookieParams] = cookieCalls.toList

override def preflightResponse(req: Request[IO]): IO[Response[IO]] =
IO.pure(Response[IO](status = Ok, body = Stream.emit("preflight response").through(text.utf8.encode)))

override def cookie(
body: IO[Option[String]],
path: String,
Expand Down Expand Up @@ -69,6 +72,18 @@ class CollectorRoutesSpec extends Specification {
response.as[String].unsafeRunSync() must beEqualTo("OK")
}

"respond to the cors route with a preflight response" in {
val (_, routes) = createTestServices
def test(uri: Uri) = {
val request = Request[IO](method = Method.OPTIONS, uri = uri)
val response = routes.run(request).unsafeRunSync()
response.as[String].unsafeRunSync() shouldEqual "preflight response"
}
test(uri"/i")
test(uri"/health")
test(uri"/p3/p4")
}

"respond to the post cookie route with the cookie response" in {
val (collectorService, routes) = createTestServices

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,73 @@ class CollectorServiceSpec extends Specification {
)
r.body.compile.toList.unsafeRunSync().toArray shouldEqual CollectorService.pixel
}

"include CORS headers in the response" in {
val r = service
.cookie(
body = IO.pure(Some("b")),
path = "p",
cookie = None,
request = Request[IO](),
pixelExpected = true,
doNotTrack = false,
contentType = None
)
.unsafeRunSync()
r.headers.get[`Access-Control-Allow-Credentials`] shouldEqual Some(
`Access-Control-Allow-Credentials`()
)
r.headers.get(ci"Access-Control-Allow-Origin").map(_.head) shouldEqual Some(
Header.Raw(ci"Access-Control-Allow-Origin", "*")
)
}

"include the origin if given to CORS headers in the response" in {
val headers = Headers(
Origin
.HostList(
NonEmptyList.of(
Origin.Host(scheme = Uri.Scheme.http, host = Uri.Host.unsafeFromString("origin.com")),
Origin.Host(
scheme = Uri.Scheme.http,
host = Uri.Host.unsafeFromString("otherorigin.com"),
port = Some(8080)
)
)
)
.asInstanceOf[Origin]
)
val request = Request[IO](headers = headers)
val r = service
.cookie(
body = IO.pure(Some("b")),
path = "p",
cookie = None,
request = request,
pixelExpected = true,
doNotTrack = false,
contentType = None
)
.unsafeRunSync()
r.headers.get[`Access-Control-Allow-Credentials`] shouldEqual Some(
`Access-Control-Allow-Credentials`()
)
r.headers.get(ci"Access-Control-Allow-Origin").map(_.head) shouldEqual Some(
Header.Raw(ci"Access-Control-Allow-Origin", "http://origin.com")
)
}
}

"preflightResponse" in {
"return a response appropriate to cors preflight options requests" in {
val expected = Headers(
Header.Raw(ci"Access-Control-Allow-Origin", "*"),
`Access-Control-Allow-Credentials`(),
`Access-Control-Allow-Headers`(ci"Content-Type", ci"SP-Anonymous"),
`Access-Control-Max-Age`.Cache(60).asInstanceOf[`Access-Control-Max-Age`]
)
service.preflightResponse(Request[IO]()).unsafeRunSync.headers shouldEqual expected
}
}

"buildEvent" in {
Expand Down Expand Up @@ -379,6 +446,46 @@ class CollectorServiceSpec extends Specification {
}
}

"accessControlAllowOriginHeader" in {
"give a restricted ACAO header if there is an Origin header in the request" in {
val headers = Headers(
Origin
.HostList(
NonEmptyList.of(
Origin.Host(scheme = Uri.Scheme.http, host = Uri.Host.unsafeFromString("origin.com"))
)
)
.asInstanceOf[Origin]
)
val request = Request[IO](headers = headers)
val expected = Header.Raw(ci"Access-Control-Allow-Origin", "http://origin.com")
service.accessControlAllowOriginHeader(request) shouldEqual expected
}
"give a restricted ACAO header if there are multiple Origin headers in the request" in {
val headers = Headers(
Origin
.HostList(
NonEmptyList.of(
Origin.Host(scheme = Uri.Scheme.http, host = Uri.Host.unsafeFromString("origin.com")),
Origin.Host(
scheme = Uri.Scheme.http,
host = Uri.Host.unsafeFromString("otherorigin.com"),
port = Some(8080)
)
)
)
.asInstanceOf[Origin]
)
val request = Request[IO](headers = headers)
val expected = Header.Raw(ci"Access-Control-Allow-Origin", "http://origin.com")
service.accessControlAllowOriginHeader(request) shouldEqual expected
}
"give an open ACAO header if there are no Origin headers in the request" in {
val expected = Header.Raw(ci"Access-Control-Allow-Origin", "*")
service.accessControlAllowOriginHeader(Request[IO]()) shouldEqual expected
}
}

"cookieDomain" in {
val testCookieConfig = CookieConfig(
enabled = true,
Expand Down Expand Up @@ -496,18 +603,17 @@ class CollectorServiceSpec extends Specification {

"extractHosts" in {
"correctly extract the host names from a list of values in the request's Origin header" in {
val origin: Origin = Origin.HostList(
NonEmptyList.of(
Origin.Host(scheme = Uri.Scheme.https, host = Uri.Host.unsafeFromString("origin.com")),
Origin.Host(
scheme = Uri.Scheme.http,
host = Uri.Host.unsafeFromString("subdomain.otherorigin.gov.co.uk"),
port = Some(8080)
)
val originHostList = NonEmptyList.of(
Origin.Host(scheme = Uri.Scheme.https, host = Uri.Host.unsafeFromString("origin.com")),
Origin.Host(
scheme = Uri.Scheme.http,
host = Uri.Host.unsafeFromString("subdomain.otherorigin.gov.co.uk"),
port = Some(8080)
)
)
val headers = Headers(origin.toRaw1)
service.extractHosts(headers) shouldEqual Seq("origin.com", "subdomain.otherorigin.gov.co.uk")
val origin: Origin = Origin.HostList(originHostList)
val headers = Headers(origin.toRaw1)
service.extractHostsFromOrigin(headers) shouldEqual originHostList.toList
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ object TestUtils {
secure = false,
httpOnly = false,
sameSite = None
)
),
cors = CORSConfig(60.seconds)
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ object StdoutCollector extends IOApp {
secure = false,
httpOnly = false,
sameSite = None
)
),
cors = CORSConfig(60.seconds)
),
BuildInfo.shortName,
BuildInfo.version
Expand Down

0 comments on commit adf8035

Please sign in to comment.