Skip to content

Commit

Permalink
fix: CQDG-793 review comments and check client only for device method
Browse files Browse the repository at this point in the history
  • Loading branch information
adipaul1981 committed Jul 4, 2024
1 parent f6a48da commit 2d96dd2
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 54 deletions.
12 changes: 6 additions & 6 deletions src/main/scala/bio/ferlab/ferload/endpoints/DrsEndpoints.scala
Original file line number Diff line number Diff line change
Expand Up @@ -50,23 +50,23 @@ object DrsEndpoints:
.errorOut(statusCode.and(jsonBody[ErrorResponse]))
.out(jsonBody[Authorizations])

private def getObject(authorizationService: AuthorizationService) =
private def getObject(authorizationService: AuthorizationService, method: String) =
objectEnpoint
.securityIn(auth.bearer[String]())
.securityIn(path[String].name("object_id"))
.errorOut(statusCode.and(jsonBody[ErrorResponse]))
.serverSecurityLogic((token, objectId) => authorizationService.authLogic(token, Seq(objectId)))
.serverSecurityLogic((token, objectId) => authorizationService.authLogic(token, Seq(objectId), method))
.get
.out(jsonBody[DrsObject])


private def getAccessMethod(authorizationService: AuthorizationService) =
private def getAccessMethod(authorizationService: AuthorizationService, method: String) =
objectEnpoint
.securityIn(auth.bearer[String]())
.securityIn(path[String].name("object_id"))
.securityIn("access" / path[String].name("access_id"))
.errorOut(statusCode.and(jsonBody[ErrorResponse]))
.serverSecurityLogic((token, objectId, accessId) => authorizationService.authLogic(token, Seq(objectId), Some(accessId)))
.serverSecurityLogic((token, objectId, accessId) => authorizationService.authLogic(token, Seq(objectId), method, Some(accessId)))
.get
.out(jsonBody[AccessURL])

Expand All @@ -89,7 +89,7 @@ object DrsEndpoints:
}
}

private def getObjectServer(config: Config, authorizationService: AuthorizationService, resourceService: ResourceService) = getObject(authorizationService).serverLogicSuccess { (user, _) =>
private def getObjectServer(config: Config, authorizationService: AuthorizationService, resourceService: ResourceService) = getObject(authorizationService, config.ferloadClientConfig.method).serverLogicSuccess { (user, _) =>
_ =>
for {
resource <- resourceService.getResourceById(user.permissions.head.rsid)
Expand All @@ -98,7 +98,7 @@ object DrsEndpoints:
}


private def getAccessMethodServer(config: Config, authorizationService: AuthorizationService, resourceService: ResourceService, s3Service: S3Service) = getAccessMethod(authorizationService).serverLogicSuccess { (user, accessId) =>
private def getAccessMethodServer(config: Config, authorizationService: AuthorizationService, resourceService: ResourceService, s3Service: S3Service) = getAccessMethod(authorizationService, config.ferloadClientConfig.method).serverLogicSuccess { (user, accessId) =>
_ =>
//fetch according to accessId, it is unique for now
if(accessId.isEmpty || config.drsConfig.accessId != accessId.get){
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,22 @@ import sttp.tapir.server.*
object LegacyObjectEndpoints:


private def securedGlobalEndpoint(authorizationService: AuthorizationService, resourceGlobalName: String): PartialServerEndpoint[String, (User, Option[String]), Unit, (StatusCode, ErrorResponse), Unit, Any, IO] =
private def securedGlobalEndpoint(authorizationService: AuthorizationService, resourceGlobalName: String, method: String): PartialServerEndpoint[String, (User, Option[String]), Unit, (StatusCode, ErrorResponse), Unit, Any, IO] =
endpoint
.securityIn(auth.bearer[String]())
.errorOut(statusCode.and(jsonBody[ErrorResponse]))
.serverSecurityLogic(token => authorizationService.authLogic(token, Seq(resourceGlobalName)))
.serverSecurityLogic(token => authorizationService.authLogic(token, Seq(resourceGlobalName), method: String))

private def objectByPath(authorizationService: AuthorizationService, resourceGlobalName: String): PartialServerEndpoint[String, (User, Option[String]), List[String], (StatusCode, ErrorResponse), ObjectUrl, Any, IO] =
securedGlobalEndpoint(authorizationService, resourceGlobalName)
private def objectByPath(authorizationService: AuthorizationService, resourceGlobalName: String, method: String): PartialServerEndpoint[String, (User, Option[String]), List[String], (StatusCode, ErrorResponse), ObjectUrl, Any, IO] =
securedGlobalEndpoint(authorizationService, resourceGlobalName, method)
.get
.description("Retrieve an object by its path and return an url to download it")
.deprecated()
.in(paths.description("Path of the object to retrieve"))
.out(jsonBody[ObjectUrl])

private def objectsByPaths(authorizationService: AuthorizationService, resourceGlobalName: String): PartialServerEndpoint[String, (User, Option[String]), String, (StatusCode, ErrorResponse), Map[String, String], Any, IO] =
securedGlobalEndpoint(authorizationService, resourceGlobalName)
private def objectsByPaths(authorizationService: AuthorizationService, resourceGlobalName: String, method: String): PartialServerEndpoint[String, (User, Option[String]), String, (StatusCode, ErrorResponse), Map[String, String], Any, IO] =
securedGlobalEndpoint(authorizationService, resourceGlobalName, method)
.description("Retrieve a list of objects by their paths and return a list of download URLs for each object")
.deprecated()
.post
Expand All @@ -41,13 +41,13 @@ object LegacyObjectEndpoints:
.example(Map("file1.vcf" -> "https://file1.vcf", "file2.vcf" -> "https://file2.vcf"))
)

def objectByPathServer(authorizationService: AuthorizationService, s3Service: S3Service, resourceGlobalName: String, defaultBucket: String): ServerEndpoint[Any, IO] =
objectByPath(authorizationService, resourceGlobalName).serverLogicSuccess { user =>
def objectByPathServer(authorizationService: AuthorizationService, s3Service: S3Service, resourceGlobalName: String, defaultBucket: String, method: String): ServerEndpoint[Any, IO] =
objectByPath(authorizationService, resourceGlobalName, method).serverLogicSuccess { user =>
file => s3Service.presignedUrl(defaultBucket, file.mkString("/")).pure[IO].map(ObjectUrl.apply)
}

def listObjectsByPathServer(authorizationService: AuthorizationService, s3Service: S3Service, resourceGlobalName: String, defaultBucket: String): ServerEndpoint[Any, IO] =
objectsByPaths(authorizationService, resourceGlobalName).serverLogicSuccess { user =>
def listObjectsByPathServer(authorizationService: AuthorizationService, s3Service: S3Service, resourceGlobalName: String, defaultBucket: String, method: String): ServerEndpoint[Any, IO] =
objectsByPaths(authorizationService, resourceGlobalName, method).serverLogicSuccess { user =>
files =>
files.split("\n")
.toList
Expand All @@ -59,8 +59,8 @@ object LegacyObjectEndpoints:
b <- config.s3Config.defaultBucket
r <- config.auth.resourcesGlobalName
servers = List(
listObjectsByPathServer(authorizationService, s3Service, r, b),
objectByPathServer(authorizationService, s3Service, r, b)
listObjectsByPathServer(authorizationService, s3Service, r, b, config.ferloadClientConfig.method),
objectByPathServer(authorizationService, s3Service, r, b, config.ferloadClientConfig.method)
)
} yield servers
s.getOrElse(Nil)
Expand Down
44 changes: 22 additions & 22 deletions src/main/scala/bio/ferlab/ferload/endpoints/ObjectsEndpoints.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,26 +20,26 @@ object ObjectsEndpoints:

private val byIdEndpoint = baseEndpoint.securityIn("objects")

private def singleObject(authorizationService: AuthorizationService): PartialServerEndpoint[(String, String), (User, Option[String]), Unit, (StatusCode, ErrorResponse), ObjectUrl, Any, IO] = byIdEndpoint
private def singleObject(authorizationService: AuthorizationService, method: String): PartialServerEndpoint[(String, String), (User, Option[String]), Unit, (StatusCode, ErrorResponse), ObjectUrl, Any, IO] = byIdEndpoint
.get
.securityIn(path[String].name("object_id"))
.serverSecurityLogic((token, objectId) => authorizationService.authLogic(token, Seq(objectId)))
.serverSecurityLogic((token, objectId) => authorizationService.authLogic(token, Seq(objectId), method))
.description("Retrieve an object by its id and return an url to download it")
.out(jsonBody[ObjectUrl])

private def listObjects(authorizationService: AuthorizationService): PartialServerEndpoint[(String, String), (User, Option[String]), Unit, (StatusCode, ErrorResponse), Map[String, String], Any, IO] = byIdEndpoint
private def listObjects(authorizationService: AuthorizationService, method: String): PartialServerEndpoint[(String, String), (User, Option[String]), Unit, (StatusCode, ErrorResponse), Map[String, String], Any, IO] = byIdEndpoint
.post
.securityIn("list")
.securityIn(stringBody.description("List of ids of objects to retrieve").example("FI1\nFI2"))
.serverSecurityLogic((token, objects) => authorizationService.authLogic(token, objects.split("\n")))
.serverSecurityLogic((token, objects) => authorizationService.authLogic(token, objects.split("\n"), method))
.description("Retrieve an object by its id and return an url to download it")
.out(jsonBody[Map[String, String]]
.description("List of files URLs by object id")
.example(Map("FI1" -> "https://file1.vcf", "FI2" -> "https://file2.vcf")))


def singleObjectServer(authorizationService: AuthorizationService, resourceService: ResourceService, s3Service: S3Service): ServerEndpoint[Any, IO] =
singleObject(authorizationService).serverLogicSuccess { (user, _) =>
def singleObjectServer(authorizationService: AuthorizationService, resourceService: ResourceService, s3Service: S3Service, method: String): ServerEndpoint[Any, IO] =
singleObject(authorizationService, method).serverLogicSuccess { (user, _) =>
_ =>
for {
resource <- resourceService.getResourceById(user.permissions.head.rsid)
Expand All @@ -51,8 +51,8 @@ object ObjectsEndpoints:
}


def listObjectsServer(authorizationService: AuthorizationService, resourceService: ResourceService, s3Service: S3Service): ServerEndpoint[Any, IO] =
listObjects(authorizationService).serverLogicSuccess { (user, _) =>
def listObjectsServer(authorizationService: AuthorizationService, resourceService: ResourceService, s3Service: S3Service, method: String): ServerEndpoint[Any, IO] =
listObjects(authorizationService, method).serverLogicSuccess { (user, _) =>
_ =>
val resourcesIO: IO[List[ReadResource]] = user.permissions.toList.traverse(p => resourceService.getResourceById(p.rsid))
resourcesIO.map { resources =>
Expand All @@ -64,31 +64,31 @@ object ObjectsEndpoints:

}

def all(authorizationService: AuthorizationService, resourceService: ResourceService, s3Service: S3Service): Seq[ServerEndpoint[Any, IO]] = List(
singleObjectServer(authorizationService, resourceService, s3Service),
listObjectsServer(authorizationService, resourceService, s3Service)
def all(authorizationService: AuthorizationService, resourceService: ResourceService, s3Service: S3Service, method: String): Seq[ServerEndpoint[Any, IO]] = List(
singleObjectServer(authorizationService, resourceService, s3Service, method),
listObjectsServer(authorizationService, resourceService, s3Service, method)
)

object ByPath:
private def byPathEndpoint(authorizationService: AuthorizationService, resourceGlobalName: String): PartialServerEndpoint[String, (User, Option[String]), Unit, (StatusCode, ErrorResponse), Unit, Any, IO] =
private def byPathEndpoint(authorizationService: AuthorizationService, resourceGlobalName: String, method: String): PartialServerEndpoint[String, (User, Option[String]), Unit, (StatusCode, ErrorResponse), Unit, Any, IO] =
baseEndpoint
.securityIn("objects")
.securityIn("bypath")
.serverSecurityLogic(token => authorizationService.authLogic(token, Seq(resourceGlobalName)))
.serverSecurityLogic(token => authorizationService.authLogic(token, Seq(resourceGlobalName), method))

private def singleObject(authorizationService: AuthorizationService, resourceGlobalName: String): PartialServerEndpoint[String, (User, Option[String]), String, (StatusCode, ErrorResponse), ObjectUrl, Any, IO] =
byPathEndpoint(authorizationService, resourceGlobalName)
private def singleObject(authorizationService: AuthorizationService, resourceGlobalName: String, method: String): PartialServerEndpoint[String, (User, Option[String]), String, (StatusCode, ErrorResponse), ObjectUrl, Any, IO] =
byPathEndpoint(authorizationService, resourceGlobalName, method)
.get
.description("Retrieve an object by its path and return an url to download it")
.in(query[String]("path").description("Path of the object to retrieve").example("dir1/file1.vcf"))
.out(jsonBody[ObjectUrl])

def singleObjectServer(authorizationService: AuthorizationService, s3Service: S3Service, resourceGlobalName: String, defaultBucket: String): ServerEndpoint[Any, IO] =
singleObject(authorizationService, resourceGlobalName).serverLogicSuccess { user =>
def singleObjectServer(authorizationService: AuthorizationService, s3Service: S3Service, resourceGlobalName: String, defaultBucket: String, method: String): ServerEndpoint[Any, IO] =
singleObject(authorizationService, resourceGlobalName, method).serverLogicSuccess { user =>
file => s3Service.presignedUrl(defaultBucket, file).pure[IO].map(ObjectUrl.apply)
}

private def listObjects(authorizationService: AuthorizationService, resourceGlobalName: String): PartialServerEndpoint[String, (User, Option[String]), String, (StatusCode, ErrorResponse), Map[String, String], Any, IO] = byPathEndpoint(authorizationService, resourceGlobalName)
private def listObjects(authorizationService: AuthorizationService, resourceGlobalName: String, method: String): PartialServerEndpoint[String, (User, Option[String]), String, (StatusCode, ErrorResponse), Map[String, String], Any, IO] = byPathEndpoint(authorizationService, resourceGlobalName, method)
.description("Retrieve a list of objects by their path and return a list of download URLs for each object")
.post
.in("list")
Expand All @@ -98,7 +98,7 @@ object ObjectsEndpoints:
.example(Map("dir1/file1.vcf" -> "https://file1.vcf", "dir1/file2.vcf" -> "https://file2.vcf"))
)

def listObjectsServer(authorizationService: AuthorizationService, s3Service: S3Service, resourceGlobalName: String, defaultBucket: String): ServerEndpoint[Any, IO] = listObjects(authorizationService, resourceGlobalName).serverLogicSuccess { user =>
def listObjectsServer(authorizationService: AuthorizationService, s3Service: S3Service, resourceGlobalName: String, defaultBucket: String, method: String): ServerEndpoint[Any, IO] = listObjects(authorizationService, resourceGlobalName, method).serverLogicSuccess { user =>
files =>
files.split("\n")
.toList
Expand All @@ -110,14 +110,14 @@ object ObjectsEndpoints:
b <- config.s3Config.defaultBucket
r <- config.auth.resourcesGlobalName
servers = List(
singleObjectServer(authorizationService, s3Service, r, b),
listObjectsServer(authorizationService, s3Service, r, b)
singleObjectServer(authorizationService, s3Service, r, b, config.ferloadClientConfig.method),
listObjectsServer(authorizationService, s3Service, r, b, config.ferloadClientConfig.method)
)
} yield servers
s.getOrElse(Nil)

}

def all(config: Config, authorizationService: AuthorizationService, resourceService: ResourceService, s3Service: S3Service): Seq[ServerEndpoint[Any, IO]] =
ByPath.all(config, authorizationService, s3Service) ++ ById.all(authorizationService, resourceService, s3Service)
ByPath.all(config, authorizationService, s3Service) ++ ById.all(authorizationService, resourceService, s3Service, config.ferloadClientConfig.method)

Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package bio.ferlab.ferload.services

import bio.ferlab.ferload.AuthConfig
import bio.ferlab.ferload.{AuthConfig, FerloadClientConfig}
import bio.ferlab.ferload.endpoints.PermissionsEndpoints.InputPermissions
import bio.ferlab.ferload.endpoints.RawPermissions
import bio.ferlab.ferload.model.*
Expand Down Expand Up @@ -100,17 +100,20 @@ class AuthorizationService(authConfig: AuthConfig, backend: SttpBackend[IO, Fs2S
* @param resources the resources to access
* @return the user with permissions if the token is valid and if user have access to the resources. Otherwise, return errors (Unauthorized, Forbidden, NotFound).
*/
def authLogic(token: String, resources: Seq[String], accessId: Option[String] = None): IO[Either[(StatusCode, ErrorResponse), (User, Option[String])]] = {
def authLogic(token: String, resources: Seq[String], method: String, accessId: Option[String] = None): IO[Either[(StatusCode, ErrorResponse), (User, Option[String])]] = {
val r: IO[(User, Option[String])] = for {
accessToken <- introspectToken(token)
partyToken <- requestPartyToken(token, resources)
permissionToken <- introspectToken(partyToken)
} yield {

//Only request with token from the audience client is authorized
val isAuthorizedClientAccessToken = accessToken.azp.exists(_.equalsIgnoreCase(authConfig.audience.get))
if(!isAuthorizedClientAccessToken){
throw HttpError(s"Unauthorized client: ${accessToken.azp.getOrElse("Nothing")}", StatusCode.Forbidden)

// For device method we only authorize tokens from a specific client
if(method == FerloadClientConfig.DEVICE) {
val isAuthorizedClientAccessToken = accessToken.azp.exists(_.equalsIgnoreCase(authConfig.audience.get))
if (!isAuthorizedClientAccessToken) {
throw HttpError(s"Unauthorized client: ${accessToken.azp.getOrElse("Nothing")}", StatusCode.Forbidden)
}
}

val value: Set[Permissions] = permissionToken.authorization.map(_.permissions.toSet).getOrElse(Set.empty)
Expand Down
Loading

0 comments on commit 2d96dd2

Please sign in to comment.