Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CORS support #330

Merged
merged 1 commit into from
Aug 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package com.snowplowanalytics.snowplow.collectors.scalastream

import cats.implicits._
import cats.effect.Sync
import org.http4s.{HttpApp, HttpRoutes}
import org.http4s._
import org.http4s.dsl.Http4sDsl
import org.http4s.implicits._
import com.comcast.ip4s.Dns
Expand All @@ -16,6 +16,11 @@ class CollectorRoutes[F[_]: Sync](collectorService: Service[F]) extends Http4sDs
Ok("OK")
}

private val corsRoute = HttpRoutes.of[F] {
case req @ OPTIONS -> _ =>
collectorService.preflightResponse(req)
}

private val cookieRoutes = HttpRoutes.of[F] {
case req @ POST -> Root / vendor / version =>
val path = collectorService.determinePath(vendor, version)
Expand Down Expand Up @@ -53,5 +58,5 @@ class CollectorRoutes[F[_]: Sync](collectorService: Service[F]) extends Http4sDs
)
}

val value: HttpApp[F] = (healthRoutes <+> cookieRoutes).orNotFound
val value: HttpApp[F] = (healthRoutes <+> corsRoute <+> cookieRoutes).orNotFound
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import com.snowplowanalytics.snowplow.CollectorPayload.thrift.model1.CollectorPa
import com.snowplowanalytics.snowplow.collectors.scalastream.model._

trait Service[F[_]] {
def preflightResponse(req: Request[F]): F[Response[F]]
def cookie(
body: F[Option[String]],
path: String,
Expand Down Expand Up @@ -59,7 +60,7 @@ class CollectorService[F[_]: Sync](

private val splitBatch: SplitBatch = SplitBatch(appName, appVersion)

def cookie(
override def cookie(
body: F[Option[String]],
path: String,
cookie: Option[RequestCookie],
Expand Down Expand Up @@ -102,17 +103,30 @@ class CollectorService[F[_]: Sync](
)
headerList = List(
setCookie.map(_.toRaw1),
cacheControl(pixelExpected).map(_.toRaw1)
cacheControl(pixelExpected).map(_.toRaw1),
accessControlAllowOriginHeader(request).some,
`Access-Control-Allow-Credentials`().toRaw1.some
benjben marked this conversation as resolved.
Show resolved Hide resolved
).flatten
responseHeaders = Headers(headerList)
_ <- sinkEvent(event, partitionKey)
} yield buildHttpResponse(responseHeaders, pixelExpected)

def determinePath(vendor: String, version: String): String = {
override def determinePath(vendor: String, version: String): String = {
val original = s"/$vendor/$version"
config.paths.getOrElse(original, original)
}

override def preflightResponse(req: Request[F]): F[Response[F]] = Sync[F].pure {
Response[F](
headers = Headers(
accessControlAllowOriginHeader(req),
`Access-Control-Allow-Credentials`(),
`Access-Control-Allow-Headers`(ci"Content-Type", ci"SP-Anonymous"),
benjben marked this conversation as resolved.
Show resolved Hide resolved
`Access-Control-Max-Age`.Cache(config.cors.accessControlMaxAge.toSeconds).asInstanceOf[`Access-Control-Max-Age`]
)
)
}

def extractHeader(req: Request[F], headerName: String): Option[String] =
req.headers.get(CIString(headerName)).map(_.head.value)

Expand Down Expand Up @@ -240,6 +254,19 @@ class CollectorService[F[_]: Sync](
}
}

/**
* Creates an Access-Control-Allow-Origin header which specifically allows the domain which made
* the request
*
* @param request Incoming request
* @return Header allowing only the domain which made the request or everything
*/
def accessControlAllowOriginHeader(request: Request[F]): Header.Raw =
Header.Raw(
ci"Access-Control-Allow-Origin",
extractHostsFromOrigin(request.headers).headOption.map(_.renderString).getOrElse("*")
benjben marked this conversation as resolved.
Show resolved Hide resolved
)

/**
* Determines the cookie domain to be used by inspecting the Origin header of the request
* and trying to find a match in the list of domains specified in the config file.
Expand All @@ -259,28 +286,25 @@ class CollectorService[F[_]: Sync](
(domains match {
case Nil => None
case _ =>
val originHosts = extractHosts(headers)
domains.find(domain => originHosts.exists(validMatch(_, domain)))
val originDomains = extractHostsFromOrigin(headers).map(_.host.value)
domains.find(domain => originDomains.exists(validMatch(_, domain)))
}).orElse(fallbackDomain)

/** Extracts the host names from a list of values in the request's Origin header. */
def extractHosts(headers: Headers): List[String] =
def extractHostsFromOrigin(headers: Headers): List[Origin.Host] =
(for {
// We can't use 'headers.get[Origin]' function in here because of the bug
// reported here: https://github.com/http4s/http4s/issues/7236
// To circumvent the bug, we split the the Origin header value with blank char
// and parse items individually.
originSplit <- headers.get(ci"Origin").map(_.head.value.split(' '))
parsed = originSplit.map(Origin.parse(_).toOption).toList.flatten
hosts = parsed.flatMap(extractHostFromOrigin)
hosts = parsed.flatMap {
case Origin.Null => List.empty
case Origin.HostList(hosts) => hosts.toList
}
} yield hosts).getOrElse(List.empty)

private def extractHostFromOrigin(originHeader: Origin): List[String] =
originHeader match {
case Origin.Null => List.empty
case Origin.HostList(hosts) => hosts.map(_.host.value).toList
}

/**
* Ensures a match is valid.
* We only want matches where:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,12 @@ object model {
sameSite: Option[SameSite]
)

final case class CORSConfig(accessControlMaxAge: FiniteDuration)

final case class CollectorConfig(
paths: Map[String, String],
cookie: CookieConfig
cookie: CookieConfig,
cors: CORSConfig
) {
val cookieConfig = if (cookie.enabled) Some(cookie) else None
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ class CollectorRoutesSpec extends Specification {

def getCookieCalls: List[CookieParams] = cookieCalls.toList

override def preflightResponse(req: Request[IO]): IO[Response[IO]] =
IO.pure(Response[IO](status = Ok, body = Stream.emit("preflight response").through(text.utf8.encode)))

override def cookie(
body: IO[Option[String]],
path: String,
Expand Down Expand Up @@ -69,6 +72,18 @@ class CollectorRoutesSpec extends Specification {
response.as[String].unsafeRunSync() must beEqualTo("OK")
}

"respond to the cors route with a preflight response" in {
val (_, routes) = createTestServices
def test(uri: Uri) = {
val request = Request[IO](method = Method.OPTIONS, uri = uri)
val response = routes.run(request).unsafeRunSync()
benjben marked this conversation as resolved.
Show resolved Hide resolved
response.as[String].unsafeRunSync() shouldEqual "preflight response"
}
test(uri"/i")
test(uri"/health")
test(uri"/p3/p4")
}

"respond to the post cookie route with the cookie response" in {
val (collectorService, routes) = createTestServices

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,73 @@ class CollectorServiceSpec extends Specification {
)
r.body.compile.toList.unsafeRunSync().toArray shouldEqual CollectorService.pixel
}

"include CORS headers in the response" in {
val r = service
.cookie(
body = IO.pure(Some("b")),
path = "p",
cookie = None,
request = Request[IO](),
pixelExpected = true,
doNotTrack = false,
contentType = None
)
.unsafeRunSync()
r.headers.get[`Access-Control-Allow-Credentials`] shouldEqual Some(
`Access-Control-Allow-Credentials`()
)
r.headers.get(ci"Access-Control-Allow-Origin").map(_.head) shouldEqual Some(
Header.Raw(ci"Access-Control-Allow-Origin", "*")
)
}

"include the origin if given to CORS headers in the response" in {
val headers = Headers(
Origin
.HostList(
NonEmptyList.of(
Origin.Host(scheme = Uri.Scheme.http, host = Uri.Host.unsafeFromString("origin.com")),
Origin.Host(
scheme = Uri.Scheme.http,
host = Uri.Host.unsafeFromString("otherorigin.com"),
port = Some(8080)
)
)
)
.asInstanceOf[Origin]
)
val request = Request[IO](headers = headers)
val r = service
.cookie(
body = IO.pure(Some("b")),
path = "p",
cookie = None,
request = request,
pixelExpected = true,
doNotTrack = false,
contentType = None
)
.unsafeRunSync()
r.headers.get[`Access-Control-Allow-Credentials`] shouldEqual Some(
`Access-Control-Allow-Credentials`()
)
r.headers.get(ci"Access-Control-Allow-Origin").map(_.head) shouldEqual Some(
Header.Raw(ci"Access-Control-Allow-Origin", "http://origin.com")
)
}
}

"preflightResponse" in {
"return a response appropriate to cors preflight options requests" in {
val expected = Headers(
Header.Raw(ci"Access-Control-Allow-Origin", "*"),
`Access-Control-Allow-Credentials`(),
`Access-Control-Allow-Headers`(ci"Content-Type", ci"SP-Anonymous"),
`Access-Control-Max-Age`.Cache(60).asInstanceOf[`Access-Control-Max-Age`]
)
service.preflightResponse(Request[IO]()).unsafeRunSync.headers shouldEqual expected
}
}

"buildEvent" in {
Expand Down Expand Up @@ -379,6 +446,46 @@ class CollectorServiceSpec extends Specification {
}
}

"accessControlAllowOriginHeader" in {
"give a restricted ACAO header if there is an Origin header in the request" in {
val headers = Headers(
Origin
.HostList(
NonEmptyList.of(
Origin.Host(scheme = Uri.Scheme.http, host = Uri.Host.unsafeFromString("origin.com"))
)
)
.asInstanceOf[Origin]
)
val request = Request[IO](headers = headers)
val expected = Header.Raw(ci"Access-Control-Allow-Origin", "http://origin.com")
service.accessControlAllowOriginHeader(request) shouldEqual expected
}
"give a restricted ACAO header if there are multiple Origin headers in the request" in {
val headers = Headers(
Origin
.HostList(
NonEmptyList.of(
Origin.Host(scheme = Uri.Scheme.http, host = Uri.Host.unsafeFromString("origin.com")),
Origin.Host(
scheme = Uri.Scheme.http,
host = Uri.Host.unsafeFromString("otherorigin.com"),
port = Some(8080)
)
)
)
.asInstanceOf[Origin]
)
val request = Request[IO](headers = headers)
val expected = Header.Raw(ci"Access-Control-Allow-Origin", "http://origin.com")
service.accessControlAllowOriginHeader(request) shouldEqual expected
}
"give an open ACAO header if there are no Origin headers in the request" in {
val expected = Header.Raw(ci"Access-Control-Allow-Origin", "*")
service.accessControlAllowOriginHeader(Request[IO]()) shouldEqual expected
}
}

"cookieDomain" in {
val testCookieConfig = CookieConfig(
enabled = true,
Expand Down Expand Up @@ -496,18 +603,17 @@ class CollectorServiceSpec extends Specification {

"extractHosts" in {
"correctly extract the host names from a list of values in the request's Origin header" in {
val origin: Origin = Origin.HostList(
NonEmptyList.of(
Origin.Host(scheme = Uri.Scheme.https, host = Uri.Host.unsafeFromString("origin.com")),
Origin.Host(
scheme = Uri.Scheme.http,
host = Uri.Host.unsafeFromString("subdomain.otherorigin.gov.co.uk"),
port = Some(8080)
)
val originHostList = NonEmptyList.of(
Origin.Host(scheme = Uri.Scheme.https, host = Uri.Host.unsafeFromString("origin.com")),
Origin.Host(
scheme = Uri.Scheme.http,
host = Uri.Host.unsafeFromString("subdomain.otherorigin.gov.co.uk"),
port = Some(8080)
)
)
val headers = Headers(origin.toRaw1)
service.extractHosts(headers) shouldEqual Seq("origin.com", "subdomain.otherorigin.gov.co.uk")
val origin: Origin = Origin.HostList(originHostList)
val headers = Headers(origin.toRaw1)
service.extractHostsFromOrigin(headers) shouldEqual originHostList.toList
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ object TestUtils {
secure = false,
httpOnly = false,
sameSite = None
)
),
cors = CORSConfig(60.seconds)
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ object StdoutCollector extends IOApp {
secure = false,
httpOnly = false,
sameSite = None
)
),
cors = CORSConfig(60.seconds)
),
BuildInfo.shortName,
BuildInfo.version
Expand Down
Loading