Skip to content

Commit

Permalink
Redirect
Browse files Browse the repository at this point in the history
  • Loading branch information
spenes committed Aug 16, 2023
1 parent 8d2828e commit deef1a2
Show file tree
Hide file tree
Showing 8 changed files with 225 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ object CollectorApp {
_ <- withGracefulShutdown(610.seconds) {
val sinks = CollectorSinks(good, bad)
val collectorService: CollectorService[F] = new CollectorService[F](config, sinks, appName, appVersion)
buildHttpServer[F](new CollectorRoutes[F](collectorService).value)
buildHttpServer[F](new CollectorRoutes[F](config.enableDefaultRedirect, collectorService).value)
}
} yield ()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import org.http4s.dsl.Http4sDsl
import org.http4s.implicits._
import com.comcast.ip4s.Dns

class CollectorRoutes[F[_]: Sync](collectorService: Service[F]) extends Http4sDsl[F] {
class CollectorRoutes[F[_]: Sync](enableDefaultRedirect: Boolean, collectorService: Service[F]) extends Http4sDsl[F] {

implicit val dns: Dns[F] = Dns.forSync[F]

Expand Down Expand Up @@ -55,5 +55,14 @@ class CollectorRoutes[F[_]: Sync](collectorService: Service[F]) extends Http4sDs
)
}

val value: HttpApp[F] = (healthRoutes <+> corsRoute <+> cookieRoutes).orNotFound
def rejectRedirect = HttpRoutes.of[F] {
case _ -> Root / "r" / _ =>
NotFound("redirects disabled")
}

val value: HttpApp[F] = {
val routes = healthRoutes <+> corsRoute <+> cookieRoutes
val res = if (enableDefaultRedirect) routes else rejectRedirect <+> routes
res.orNotFound
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ class CollectorService[F[_]: Sync](
): F[Response[F]] =
for {
body <- body
redirect = path.startsWith("/r/")
hostname = extractHostname(request)
userAgent = extractHeader(request, "User-Agent")
refererUri = extractHeader(request, "Referer")
Expand Down Expand Up @@ -107,7 +108,12 @@ class CollectorService[F[_]: Sync](
).flatten
responseHeaders = Headers(headerList)
_ <- sinkEvent(event, partitionKey)
resp = buildHttpResponse(responseHeaders, pixelExpected)
resp = buildHttpResponse(
queryParams = request.uri.query.params,
headers = responseHeaders,
redirect = redirect,
pixelExpected = pixelExpected
)
} yield resp

override def determinePath(vendor: String, version: String): String = {
Expand Down Expand Up @@ -173,26 +179,60 @@ class CollectorService[F[_]: Sync](
e
}

// TODO: Handle necessary cases to build http response in here
def buildHttpResponse(
queryParams: Map[String, String],
headers: Headers,
redirect: Boolean,
pixelExpected: Boolean
): Response[F] =
if (redirect)
buildRedirectHttpResponse(queryParams, headers)
else
buildUsualHttpResponse(pixelExpected, headers)

/** Builds the appropriate http response when not dealing with click redirects. */
def buildUsualHttpResponse(pixelExpected: Boolean, headers: Headers): Response[F] =
pixelExpected match {
case true =>
Response[F](
headers = headers.put(`Content-Type`(MediaType.image.gif)),
body = pixelStream
body = pixelStream
)
// See https://github.com/snowplow/snowplow-javascript-tracker/issues/482
case false =>
Response[F](
status = Ok,
status = Ok,
headers = headers,
body = Stream.emit("ok").through(fs2.text.utf8.encode)
body = Stream.emit("ok").through(fs2.text.utf8.encode)
)
}

/** Builds the appropriate http response when dealing with click redirects. */
def buildRedirectHttpResponse(queryParams: Map[String, String], headers: Headers): Response[F] = {
val targetUri = for {
target <- queryParams.get("u")
uri <- Uri.fromString(target).toOption
_ = uri if redirectTargetAllowed(uri)
} yield uri

targetUri match {
case Some(t) =>
Response[F](
status = Found,
headers = headers.put(Location(t))
)
case _ => Response[F](
status = BadRequest,
headers = headers
)
}
}

private def redirectTargetAllowed(target: Uri): Boolean =
if (config.redirectDomains.isEmpty) true
else config.redirectDomains.contains(target.host.map(_.renderString).getOrElse(""))


// TODO: Since Remote-Address and Raw-Request-URI is akka-specific headers,
// they aren't included in here. It might be good to search for counterparts in Http4s.
/** If the SP-Anonymous header is not present, retrieves all headers
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ object model {
final case class CollectorConfig(
paths: Map[String, String],
cookie: CookieConfig,
cors: CORSConfig
cors: CORSConfig,
enableDefaultRedirect: Boolean,
redirectDomains: Set[String]
) {
val cookieConfig = if (cookie.enabled) Some(cookie) else None
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,15 @@ class CollectorRoutesSpec extends Specification {
override def determinePath(vendor: String, version: String): String = "/p1/p2"
}

def createTestServices = {
def createTestServices(enabledDefaultRedirect: Boolean = true) = {
val collectorService = new TestService()
val routes = new CollectorRoutes[IO](collectorService).value
val routes = new CollectorRoutes[IO](enabledDefaultRedirect, collectorService).value
(collectorService, routes)
}

"The collector route" should {
"respond to the health route with an ok response" in {
val (_, routes) = createTestServices
val (_, routes) = createTestServices()
val request = Request[IO](method = Method.GET, uri = uri"/health")
val response = routes.run(request).unsafeRunSync()

Expand All @@ -70,7 +70,7 @@ class CollectorRoutesSpec extends Specification {
}

"respond to the cors route with a preflight response" in {
val (_, routes) = createTestServices
val (_, routes) = createTestServices()
def test(uri: Uri) = {
val request = Request[IO](method = Method.OPTIONS, uri = uri)
val response = routes.run(request).unsafeRunSync()
Expand All @@ -82,7 +82,7 @@ class CollectorRoutesSpec extends Specification {
}

"respond to the post cookie route with the cookie response" in {
val (collectorService, routes) = createTestServices
val (collectorService, routes) = createTestServices()

val request = Request[IO](method = Method.POST, uri = uri"/p3/p4")
.withEntity("testBody")
Expand All @@ -102,7 +102,7 @@ class CollectorRoutesSpec extends Specification {

"respond to the get or head cookie route with the cookie response" in {
def test(method: Method) = {
val (collectorService, routes) = createTestServices
val (collectorService, routes) = createTestServices()

val request = Request[IO](method = method, uri = uri"/p3/p4").withEntity("testBody")
val response = routes.run(request).unsafeRunSync()
Expand All @@ -124,7 +124,7 @@ class CollectorRoutesSpec extends Specification {

"respond to the get or head pixel route with the cookie response" in {
def test(method: Method, uri: String) = {
val (collectorService, routes) = createTestServices
val (collectorService, routes) = createTestServices()

val request = Request[IO](method = method, uri = Uri.unsafeFromString(uri)).withEntity("testBody")
val response = routes.run(request).unsafeRunSync()
Expand All @@ -145,6 +145,31 @@ class CollectorRoutesSpec extends Specification {
test(Method.GET, "/ice.png")
test(Method.HEAD, "/ice.png")
}

"allow redirect routes when redirects enabled" in {
val (_, routes) = createTestServices()

val request = Request[IO](method = Method.GET, uri = uri"/r/abc")
val response = routes.run(request).unsafeRunSync()

response.status must beEqualTo(Status.Ok)
response.bodyText.compile.string.unsafeRunSync() must beEqualTo("cookie")
}

"disallow redirect routes when redirects disabled" in {
def test(method: Method) = {
val (_, routes) = createTestServices(enabledDefaultRedirect = false)

val request = Request[IO](method = method, uri = uri"/r/abc")
val response = routes.run(request).unsafeRunSync()

response.status must beEqualTo(Status.NotFound)
response.bodyText.compile.string.unsafeRunSync() must beEqualTo("redirects disabled")
}

test(Method.GET)
test(Method.POST)
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@ class CollectorServiceSpec extends Specification {
secure = false
)

def probeService(): ProbeService = {
def probeService(config: CollectorConfig = TestUtils.testConf): ProbeService = {
val good = new TestSink
val bad = new TestSink
val service = new CollectorService[IO](
config = TestUtils.testConf,
config = config,
sinks = CollectorSinks[IO](good, bad),
appName = "appName",
appVersion = "appVersion"
Expand Down Expand Up @@ -358,6 +358,33 @@ class CollectorServiceSpec extends Specification {
Header.Raw(ci"Access-Control-Allow-Origin", "http://origin.com")
)
}

"redirect if path starts with '/r/'" in {
val testConf = TestUtils.testConf.copy(
redirectDomains = Set("snowplow.acme.com", "example.com")
)
val testPath = "/r/example?u=https://snowplow.acme.com/12"
val ProbeService(service, good, bad) = probeService(config = testConf)
val req = Request[IO](
method = Method.GET,
uri = Uri.unsafeFromString(testPath)
)
val r = service
.cookie(
body = IO.pure(Some("b")),
path = testPath,
request = req,
pixelExpected = false,
doNotTrack = false,
contentType = None
)
.unsafeRunSync()

r.status mustEqual Status.Found
r.headers.get[Location] must beSome(Location(Uri.unsafeFromString("https://snowplow.acme.com/12")))
good.storedRawEvents must have size 1
bad.storedRawEvents must have size 0
}
}

"preflightResponse" in {
Expand Down Expand Up @@ -444,20 +471,116 @@ class CollectorServiceSpec extends Specification {
}

"buildHttpResponse" in {
"rely on buildRedirectHttpResponse if redirect is true" in {
val testConfig = TestUtils.testConf.copy(
redirectDomains = Set("example1.com", "example2.com")
)
val ProbeService(service, _, _) = probeService(config = testConfig)
val res = service.buildHttpResponse(
queryParams = Map("u" -> "https://example1.com/12"),
headers = testHeaders,
redirect = true,
pixelExpected = true
)
res.status.code shouldEqual 302
res.headers shouldEqual testHeaders.put(Location(Uri.unsafeFromString("https://example1.com/12")))
}
"send back a gif if pixelExpected is true" in {
val res = service.buildHttpResponse(testHeaders, pixelExpected = true)
val res = service.buildHttpResponse(
queryParams = Map.empty,
headers = testHeaders,
redirect = false,
pixelExpected = true
)
res.status shouldEqual Status.Ok
res.headers shouldEqual testHeaders.put(`Content-Type`(MediaType.image.gif))
res.body.compile.toList.unsafeRunSync().toArray shouldEqual CollectorService.pixel
}
"send back ok otherwise" in {
val res = service.buildHttpResponse(testHeaders, pixelExpected = false)
val res = service.buildHttpResponse(
queryParams = Map.empty,
headers = testHeaders,
redirect = false,
pixelExpected = false
)
res.status shouldEqual Status.Ok
res.headers shouldEqual testHeaders
res.bodyText.compile.toList.unsafeRunSync() shouldEqual List("ok")
}
}

"buildUsualHttpResponse" in {
"send back a gif if pixelExpected is true" in {
val res = service.buildUsualHttpResponse(
headers = testHeaders,
pixelExpected = true
)
res.status shouldEqual Status.Ok
res.headers shouldEqual testHeaders.put(`Content-Type`(MediaType.image.gif))
res.body.compile.toList.unsafeRunSync().toArray shouldEqual CollectorService.pixel
}
"send back ok otherwise" in {
val res = service.buildUsualHttpResponse(
headers = testHeaders,
pixelExpected = false
)
res.status shouldEqual Status.Ok
res.headers shouldEqual testHeaders
res.bodyText.compile.toList.unsafeRunSync() shouldEqual List("ok")
}
}

"buildRedirectHttpResponse" in {
"give back a 302 if redirecting and there is a u query param" in {
val testConfig = TestUtils.testConf.copy(
redirectDomains = Set("example1.com", "example2.com")
)
val ProbeService(service, _, _) = probeService(config = testConfig)
val res = service.buildRedirectHttpResponse(
queryParams = Map("u" -> "https://example1.com/12"),
headers = testHeaders
)
res.status.code shouldEqual 302
res.headers shouldEqual testHeaders.put(Location(Uri.unsafeFromString("https://example1.com/12")))
}
"give back a 400 if redirecting and there are no u query params" in {
val testConfig = TestUtils.testConf.copy(
redirectDomains = Set("example1.com", "example2.com")
)
val ProbeService(service, _, _) = probeService(config = testConfig)
val res = service.buildRedirectHttpResponse(
queryParams = Map.empty,
headers = testHeaders
)
res.status.code shouldEqual 400
res.headers shouldEqual testHeaders
}
"give back a 400 if redirecting to a disallowed domain" in {
val testConfig = TestUtils.testConf.copy(
redirectDomains = Set("example1.com", "example2.com")
)
val ProbeService(service, _, _) = probeService(config = testConfig)
val res = service.buildRedirectHttpResponse(
queryParams = Map("u" -> "https://invalidexample1.com/12"),
headers = testHeaders
)
res.status.code shouldEqual 400
res.headers shouldEqual testHeaders
}
"give back a 302 if redirecting to an unknown domain, with no restrictions on domains" in {
val testConfig = TestUtils.testConf.copy(
redirectDomains = Set.empty
)
val ProbeService(service, _, _) = probeService(config = testConfig)
val res = service.buildRedirectHttpResponse(
queryParams = Map("u" -> "https://unknown.example.com/12"),
headers = testHeaders
)
res.status.code shouldEqual 302
res.headers shouldEqual testHeaders.put(Location(Uri.unsafeFromString("https://unknown.example.com/12")))
}
}

"ipAndPartitionkey" in {
"give back the ip and partition key as ip if remote address is defined" in {
val address = Some("127.0.0.1")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ object TestUtils {
httpOnly = false,
sameSite = None
),
cors = CORSConfig(60.seconds)
cors = CORSConfig(60.seconds),
enableDefaultRedirect = false,
redirectDomains = Set.empty
)
}
Loading

0 comments on commit deef1a2

Please sign in to comment.