Skip to content

Commit

Permalink
Merge pull request #12873 from markylaing/pre-auth-resolve-resources
Browse files Browse the repository at this point in the history
Auth: Expand certificate/image fingerprints and handle effective projects in authorization check.
  • Loading branch information
tomponline authored Feb 21, 2024
2 parents aee951a + 40d791b commit 35aa763
Show file tree
Hide file tree
Showing 12 changed files with 166 additions and 33 deletions.
23 changes: 19 additions & 4 deletions lxd/auth/driver_tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"net/http"

"github.com/canonical/lxd/lxd/identity"
"github.com/canonical/lxd/lxd/request"
"github.com/canonical/lxd/shared"
"github.com/canonical/lxd/shared/api"
"github.com/canonical/lxd/shared/entity"
Expand Down Expand Up @@ -68,11 +69,15 @@ func (t *tls) CheckPermission(ctx context.Context, r *http.Request, entityURL *a
return api.StatusErrorf(http.StatusForbidden, "Certificate is restricted")
}

entityType, projectName, _, _, err := entity.ParseURL(entityURL.URL)
entityType, projectName, _, pathArgs, err := entity.ParseURL(entityURL.URL)
if err != nil {
return fmt.Errorf("Failed to parse entity URL: %w", err)
}

if entityType == entity.TypeProject {
projectName = pathArgs[0]
}

// Check server level object types
switch entityType {
case entity.TypeServer:
Expand Down Expand Up @@ -165,6 +170,8 @@ func (t *tls) GetPermissionChecker(ctx context.Context, r *http.Request, entitle
return nil, api.StatusErrorf(http.StatusForbidden, "User does not have permissions for project %q", details.projectName)
}

effectiveProject, _ := request.GetCtxValue[string](r.Context(), request.CtxEffectiveProjectName)

// Filter objects by project.
return func(entityURL *api.URL) bool {
eType, project, _, pathArgs, err := entity.ParseURL(entityURL.URL)
Expand All @@ -173,15 +180,23 @@ func (t *tls) GetPermissionChecker(ctx context.Context, r *http.Request, entitle
return false
}

// GetPermissionChecker can only be used to check permissions on entities of the same type, e.g. a list of instances.
if eType != entityType {
logger.Warn("Permission checker received URL with unexpected entity type", logger.Ctx{"expected": entityType, "actual": eType, "entity_url": entityURL})
return false
}

// If it's a project URL, the project name is in the path, not the query parameter.
if eType == entity.TypeProject {
project = pathArgs[0]
}

if eType != entityType {
logger.Warn("Permission checker received URL with unexpected entity type", logger.Ctx{"expected": entityType, "actual": eType, "entity_url": entityURL})
return false
// If an effective project has been set in the request context. We expect all entities to be in that project.
if effectiveProject != "" {
return project == effectiveProject
}

// Otherwise, check if the project is in the list of allowed projects for the entity.
return shared.ValueInSlice(project, id.Projects)
}, nil
}
12 changes: 9 additions & 3 deletions lxd/certificates.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ var certificateCmd = APIEndpoint{
Path: "certificates/{fingerprint}",

Delete: APIEndpointAction{Handler: certificateDelete, AccessHandler: allowAuthenticated},
Get: APIEndpointAction{Handler: certificateGet, AccessHandler: allowPermission(entity.TypeCertificate, auth.EntitlementCanView, "fingerprint")},
Get: APIEndpointAction{Handler: certificateGet, AccessHandler: allowAuthenticated},
Patch: APIEndpointAction{Handler: certificatePatch, AccessHandler: allowAuthenticated},
Put: APIEndpointAction{Handler: certificatePut, AccessHandler: allowAuthenticated},
}
Expand Down Expand Up @@ -697,13 +697,14 @@ func certificatesPost(d *Daemon, r *http.Request) response.Response {
// "500":
// $ref: "#/responses/InternalServerError"
func certificateGet(d *Daemon, r *http.Request) response.Response {
s := d.State()
fingerprint, err := url.PathUnescape(mux.Vars(r)["fingerprint"])
if err != nil {
return response.SmartError(err)
}

var cert *api.Certificate
err = d.State().DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error {
err = s.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error {
dbCertInfo, err := dbCluster.GetCertificateByFingerprintPrefix(ctx, tx.Tx(), fingerprint)
if err != nil {
return err
Expand All @@ -716,6 +717,11 @@ func certificateGet(d *Daemon, r *http.Request) response.Response {
return response.SmartError(err)
}

err = s.Authorizer.CheckPermission(r.Context(), r, entity.CertificateURL(cert.Fingerprint), auth.EntitlementCanView)
if err != nil {
return response.SmartError(err)
}

return response.SyncResponseETag(true, cert, cert)
}

Expand Down Expand Up @@ -1031,7 +1037,7 @@ func certificateDelete(d *Daemon, r *http.Request) response.Response {
}

var userCanEditCertificate bool
err = s.Authorizer.CheckPermission(r.Context(), r, entity.CertificateURL(fingerprint), auth.EntitlementCanDelete)
err = s.Authorizer.CheckPermission(r.Context(), r, entity.CertificateURL(certInfo.Fingerprint), auth.EntitlementCanDelete)
if err == nil {
userCanEditCertificate = true
} else if api.StatusErrorCheck(err, http.StatusForbidden) {
Expand Down
87 changes: 62 additions & 25 deletions lxd/images.go
Original file line number Diff line number Diff line change
Expand Up @@ -1617,7 +1617,24 @@ func imagesGet(d *Daemon, r *http.Request) response.Response {
filterStr := r.FormValue("filter")

s := d.State()
var effectiveProjectName string
err := s.DB.Cluster.Transaction(r.Context(), func(ctx context.Context, tx *db.ClusterTx) error {
hasImages, err := dbCluster.ProjectHasImages(ctx, tx.Tx(), projectName)
if err != nil {
return err
}

if !hasImages {
effectiveProjectName = api.ProjectDefaultName
}

return nil
})
if err != nil {
return response.SmartError(err)
}

request.SetCtxValue(r, request.CtxEffectiveProjectName, effectiveProjectName)
hasPermission, authorizationErr := s.Authorizer.GetPermissionChecker(r.Context(), r, auth.EntitlementCanView, entity.TypeImage)
if authorizationErr != nil && !api.StatusErrorCheck(authorizationErr, http.StatusForbidden) {
return response.SmartError(authorizationErr)
Expand Down Expand Up @@ -2980,17 +2997,7 @@ func imageGet(d *Daemon, r *http.Request) response.Response {
return response.SmartError(err)
}

var userCanViewImage bool
err = s.Authorizer.CheckPermission(r.Context(), r, entity.ImageURL(projectName, fingerprint), auth.EntitlementCanView)
if err == nil {
userCanViewImage = true
} else if !api.StatusErrorCheck(err, http.StatusForbidden) {
return response.SmartError(err)
}

public := d.checkTrustedClient(r) != nil || !userCanViewImage
secret := r.FormValue("secret")

// Get the image (expand partial fingerprints).
var info *api.Image
err = s.DB.Cluster.Transaction(r.Context(), func(ctx context.Context, tx *db.ClusterTx) error {
info, err = doImageGet(ctx, tx, projectName, fingerprint, false)
Expand All @@ -3004,6 +3011,17 @@ func imageGet(d *Daemon, r *http.Request) response.Response {
return response.SmartError(err)
}

var userCanViewImage bool
err = s.Authorizer.CheckPermission(r.Context(), r, entity.ImageURL(projectName, info.Fingerprint), auth.EntitlementCanView)
if err == nil {
userCanViewImage = true
} else if !api.StatusErrorCheck(err, http.StatusForbidden) {
return response.SmartError(err)
}

public := d.checkTrustedClient(r) != nil || !userCanViewImage
secret := r.FormValue("secret")

op, err := imageValidSecret(s, r, projectName, info.Fingerprint, secret)
if err != nil {
return response.SmartError(err)
Expand Down Expand Up @@ -3426,10 +3444,29 @@ func imageAliasesPost(d *Daemon, r *http.Request) response.Response {
// "500":
// $ref: "#/responses/InternalServerError"
func imageAliasesGet(d *Daemon, r *http.Request) response.Response {
projectName := request.ProjectParam(r)
recursion := util.IsRecursionRequest(r)

s := d.State()

projectName := request.ProjectParam(r)
var effectiveProjectName string
err := s.DB.Cluster.Transaction(r.Context(), func(ctx context.Context, tx *db.ClusterTx) error {
projectHasImages, err := dbCluster.ProjectHasImages(ctx, tx.Tx(), projectName)
if err != nil {
return err
}

if !projectHasImages {
effectiveProjectName = api.ProjectDefaultName
}

return nil
})
if err != nil {
return response.SmartError(err)
}

request.SetCtxValue(r, request.CtxEffectiveProjectName, effectiveProjectName)
userHasPermission, err := s.Authorizer.GetPermissionChecker(r.Context(), r, auth.EntitlementCanView, entity.TypeImageAlias)
if err != nil {
return response.InternalError(fmt.Errorf("Failed to get a permission checker: %w", err))
Expand Down Expand Up @@ -3997,20 +4034,8 @@ func imageExport(d *Daemon, r *http.Request) response.Response {
return response.SmartError(err)
}

// Access control.
var userCanViewImage bool
err = s.Authorizer.CheckPermission(r.Context(), r, entity.ImageURL(projectName, fingerprint), auth.EntitlementCanView)
if err == nil {
userCanViewImage = true
} else if !api.StatusErrorCheck(err, http.StatusForbidden) {
return response.SmartError(err)
}

public := d.checkTrustedClient(r) != nil || !userCanViewImage
secret := r.FormValue("secret")

// Get the image (expand the fingerprint).
var imgInfo *api.Image

err = s.DB.Cluster.Transaction(context.TODO(), func(ctx context.Context, tx *db.ClusterTx) error {
// Get the image (expand the fingerprint).
_, imgInfo, err = tx.GetImage(ctx, fingerprint, dbCluster.ImageFilter{Project: &projectName})
Expand All @@ -4021,6 +4046,18 @@ func imageExport(d *Daemon, r *http.Request) response.Response {
return response.SmartError(err)
}

// Access control.
var userCanViewImage bool
err = s.Authorizer.CheckPermission(r.Context(), r, entity.ImageURL(projectName, imgInfo.Fingerprint), auth.EntitlementCanView)
if err == nil {
userCanViewImage = true
} else if !api.StatusErrorCheck(err, http.StatusForbidden) {
return response.SmartError(err)
}

public := d.checkTrustedClient(r) != nil || !userCanViewImage
secret := r.FormValue("secret")

if r.RemoteAddr == "@devlxd" {
if !imgInfo.Public && !imgInfo.Cached {
return response.NotFound(fmt.Errorf("Image %q not found", fingerprint))
Expand Down
1 change: 1 addition & 0 deletions lxd/network_acls.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ func networkACLsGet(d *Daemon, r *http.Request) response.Response {
return response.InternalError(err)
}

request.SetCtxValue(r, request.CtxEffectiveProjectName, projectName)
userHasPermission, err := s.Authorizer.GetPermissionChecker(r.Context(), r, auth.EntitlementCanView, entity.TypeNetworkACL)
if err != nil {
return response.SmartError(err)
Expand Down
1 change: 1 addition & 0 deletions lxd/network_zones.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ func networkZonesGet(d *Daemon, r *http.Request) response.Response {
return response.InternalError(err)
}

request.SetCtxValue(r, request.CtxEffectiveProjectName, projectName)
userHasPermission, err := s.Authorizer.GetPermissionChecker(r.Context(), r, auth.EntitlementCanView, entity.TypeNetworkZone)
if err != nil {
return response.InternalError(err)
Expand Down
2 changes: 2 additions & 0 deletions lxd/networks.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,8 @@ func networksGet(d *Daemon, r *http.Request) response.Response {
return response.SmartError(err)
}

request.SetCtxValue(r, request.CtxEffectiveProjectName, projectName)

recursion := util.IsRecursionRequest(r)

var networkNames []string
Expand Down
1 change: 1 addition & 0 deletions lxd/profiles.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ func profilesGet(d *Daemon, r *http.Request) response.Response {

recursion := util.IsRecursionRequest(r)

request.SetCtxValue(r, request.CtxEffectiveProjectName, p.Name)
userHasPermission, err := s.Authorizer.GetPermissionChecker(r.Context(), r, auth.EntitlementCanView, entity.TypeProfile)
if err != nil {
return response.InternalError(err)
Expand Down
5 changes: 5 additions & 0 deletions lxd/request/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ const (
// This contains groups defined by the identity provider if the identity authenticated with OIDC on another cluster
// member.
CtxForwardedIdentityProviderGroups CtxKey = "identity_provider_groups"

// CtxEffectiveProjectName is used to indicate that the effective project of a resource is different from the project
// specified in the URL. (For example, if a project has `features.networks=false`, any networks in this project actually
// belong to the default project).
CtxEffectiveProjectName CtxKey = "effective_project_name"
)

// Headers.
Expand Down
29 changes: 29 additions & 0 deletions lxd/request/context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package request

import (
"context"
"fmt"
"net/http"
)

// GetCtxValue gets a value of type T from the context using the given key.
func GetCtxValue[T any](ctx context.Context, key CtxKey) (T, error) {
var empty T
valueAny := ctx.Value(key)
if valueAny == nil {
return empty, fmt.Errorf("Failed to get expected value %q from context", key)
}

value, ok := valueAny.(T)
if !ok {
return empty, fmt.Errorf("Value for context key %q has incorrect type (expected %T, got %T)", key, empty, valueAny)
}

return value, nil
}

// SetCtxValue sets the given value in the request context with the given key.
func SetCtxValue(r *http.Request, key CtxKey, value any) {
rWithCtx := r.WithContext(context.WithValue(r.Context(), key, value))
*r = *rWithCtx
}
1 change: 1 addition & 0 deletions lxd/storage_buckets.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ func storagePoolBucketsGet(d *Daemon, r *http.Request) response.Response {
return response.SmartError(err)
}

request.SetCtxValue(r, request.CtxEffectiveProjectName, bucketProjectName)
userHasPermission, err := s.Authorizer.GetPermissionChecker(r.Context(), r, auth.EntitlementCanView, entity.TypeStorageBucket)
if err != nil {
return response.SmartError(err)
Expand Down
7 changes: 6 additions & 1 deletion lxd/storage_volumes.go
Original file line number Diff line number Diff line change
Expand Up @@ -345,9 +345,9 @@ func storagePoolVolumesGet(d *Daemon, r *http.Request) response.Response {

var dbVolumes []*db.StorageVolume
var projectImages []string
var customVolProjectName string

err = s.DB.Cluster.Transaction(r.Context(), func(ctx context.Context, tx *db.ClusterTx) error {
var customVolProjectName string

if !allProjects {
dbProject, err := cluster.GetProject(ctx, tx.Tx(), requestProjectName)
Expand Down Expand Up @@ -453,6 +453,11 @@ func storagePoolVolumesGet(d *Daemon, r *http.Request) response.Response {
return volA.Name < volB.Name
})

// If we're requesting for just one project, set the effective project name of volumes in this project.
if !allProjects {
request.SetCtxValue(r, request.CtxEffectiveProjectName, customVolProjectName)
}

userHasPermission, err := s.Authorizer.GetPermissionChecker(r.Context(), r, auth.EntitlementCanView, entity.TypeStorageVolume)
if err != nil {
return response.SmartError(err)
Expand Down
30 changes: 30 additions & 0 deletions test/suites/tls_restrictions.sh
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,36 @@ test_tls_restrictions() {

! lxc_remote project create localhost:blah1 || false

# Ensure we can create and view resources that are not enabled for the project (e.g. their effective project is
# the default project).

# Networks are disabled when projects are created.
lxc_remote network create localhost:blah-network --project blah
lxc_remote network show localhost:blah-network --project blah
lxc_remote network list localhost: --project blah | grep blah-network
lxc_remote network rm localhost:blah-network --project blah

# Network zones are disabled when projects are created.
lxc_remote network zone create localhost:blah-zone --project blah
lxc_remote network zone show localhost:blah-zone --project blah
lxc_remote network zone list localhost: --project blah | grep blah-zone
lxc_remote network zone delete localhost:blah-zone --project blah

# Unset the profiles feature (the default is false).
lxc project unset blah features.profiles
lxc_remote profile create localhost:blah-profile --project blah
lxc_remote profile show localhost:blah-profile --project blah
lxc_remote profile list localhost: --project blah | grep blah-profile
lxc_remote profile delete localhost:blah-profile --project blah

# Unset the storage volumes feature (the default is false).
lxc project unset blah features.storage.volumes
lxc_remote storage volume create "localhost:${pool_name}" blah-volume --project blah
lxc_remote storage volume show "localhost:${pool_name}" blah-volume --project blah
lxc_remote storage volume list "localhost:${pool_name}" --project blah
lxc_remote storage volume list "localhost:${pool_name}" --project blah | grep blah-volume
lxc_remote storage volume delete "localhost:${pool_name}" blah-volume --project blah

# Cleanup
lxc config trust show "${FINGERPRINT}" | sed -e "s/restricted: true/restricted: false/" | lxc config trust edit "${FINGERPRINT}"
lxc project delete blah
Expand Down

0 comments on commit 35aa763

Please sign in to comment.