From 4e280ae927e68ac3da2d2e8adfad238493b4fc36 Mon Sep 17 00:00:00 2001 From: Abduh Date: Thu, 4 Aug 2022 17:07:23 +0700 Subject: [PATCH] refactor: proxy middleware (#135) * refactor: proxy middleware * fix: update coverage badge * fix: slug generation * refactor: update create urn to use function receiver * refactor: clean up identity proxy header --- README.md | 2 +- cmd/serve.go | 28 +++---- cmd/serve_proxy.go | 14 ++-- core/group/group.go | 3 +- core/organization/organization.go | 3 +- core/project/project.go | 3 +- core/resource/resource.go | 26 +++--- core/resource/service.go | 4 +- core/role/role.go | 3 +- core/rule/context.go | 16 ++++ core/rule/rule.go | 14 ++-- core/user/context.go | 14 ++++ core/user/service.go | 50 ++--------- core/user/user.go | 4 +- internal/api/api.go | 23 +++--- internal/api/v1beta1/action.go | 13 ++- internal/api/v1beta1/errors.go | 21 +++++ internal/api/v1beta1/group.go | 38 +++++---- internal/api/v1beta1/namespace.go | 8 +- internal/api/v1beta1/org.go | 38 +++++---- internal/api/v1beta1/org_test.go | 3 +- internal/api/v1beta1/permission_check.go | 10 +-- internal/api/v1beta1/policy.go | 12 +-- internal/api/v1beta1/policy_test.go | 5 +- internal/api/v1beta1/project.go | 30 +++---- internal/api/v1beta1/project_test.go | 5 +- internal/api/v1beta1/relation.go | 24 +++--- internal/api/v1beta1/resource.go | 33 ++++---- internal/api/v1beta1/role.go | 24 +++--- internal/api/v1beta1/user.go | 79 ++++++++---------- internal/api/v1beta1/user_test.go | 23 +++--- internal/api/v1beta1/util.go | 52 ------------ internal/api/v1beta1/v1beta1.go | 53 +++++------- internal/proxy/director.go | 5 +- internal/proxy/hook/authz/authz.go | 20 ++--- internal/proxy/middleware/attribute.go | 21 +++++ internal/proxy/middleware/authz/authz.go | 51 ++++++------ internal/proxy/middleware/basic_auth/auth.go | 3 +- internal/proxy/middleware/context.go | 56 +++++++++++++ internal/proxy/middleware/middleware.go | 82 ------------------- internal/server/config.go | 2 +- .../grpc_interceptors/grpc_interceptors.go | 12 +-- internal/server/server.go | 2 +- pkg/httputil/context.go | 28 +++++++ pkg/httputil/header.go | 6 ++ pkg/metadata/metadata.go | 34 ++++++++ pkg/str/slug.go | 32 ++++++++ pkg/str/utils.go | 25 ------ 48 files changed, 537 insertions(+), 520 deletions(-) create mode 100644 core/rule/context.go create mode 100644 core/user/context.go create mode 100644 internal/api/v1beta1/errors.go delete mode 100644 internal/api/v1beta1/util.go create mode 100644 internal/proxy/middleware/attribute.go create mode 100644 internal/proxy/middleware/context.go rename {pkg => internal/server}/grpc_interceptors/grpc_interceptors.go (73%) create mode 100644 pkg/httputil/context.go create mode 100644 pkg/httputil/header.go create mode 100644 pkg/metadata/metadata.go create mode 100644 pkg/str/slug.go diff --git a/README.md b/README.md index 8fa6f68e2..c7c925380 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ ![package workflow](https://github.com/odpf/shield/actions/workflows/release.yml/badge.svg) [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg?logo=apache)](LICENSE) [![Version](https://img.shields.io/github/v/release/odpf/shield?logo=semantic-release)](Version) -[![Coverage Status](https://coveralls.io/repos/github/odpf/shield/badge.svg?branch=main)] +[![Coverage Status](https://coveralls.io/repos/github/odpf/shield/badge.svg?branch=main)](https://coveralls.io/github/odpf/shield?branch=main) Shield is a cloud native role-based authorization aware reverse-proxy service. With Shield, you can assign roles to users or groups of users to configure policies that determine whether a particular user has the ability to perform a certain action on a given resource. diff --git a/cmd/serve.go b/cmd/serve.go index e560cd1a6..9f6e272b9 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -99,13 +99,13 @@ func serve(logger log.Logger, cfg *config.Shield) error { return err } - deps, err := buildAPIDependencies(ctx, logger, cfg.App.IdentityProxyHeader, resourceBlobRepository, dbClient, spiceDBClient) + deps, err := buildAPIDependencies(ctx, logger, resourceBlobRepository, dbClient, spiceDBClient) if err != nil { return err } // serving proxies - cbs, cps, err := serveProxies(ctx, logger, cfg.App.IdentityProxyHeader, cfg.Proxy, deps.ResourceService, deps.UserService) + cbs, cps, err := serveProxies(ctx, logger, cfg.App.IdentityProxyHeader, cfg.App.UserIDHeader, cfg.Proxy, deps.ResourceService, deps.UserService) if err != nil { return err } @@ -160,7 +160,6 @@ func serve(logger log.Logger, cfg *config.Shield) error { func buildAPIDependencies( ctx context.Context, logger log.Logger, - identityProxyHeader string, resourceBlobRepository *blob.ResourcesRepository, dbc *db.Client, sdb *spicedb.SpiceDB, @@ -172,7 +171,7 @@ func buildAPIDependencies( namespaceService := namespace.NewService(namespaceRepository) userRepository := postgres.NewUserRepository(dbc) - userService := user.NewService(identityProxyHeader, userRepository) + userService := user.NewService(userRepository) relationPGRepository := postgres.NewRelationRepository(dbc) relationSpiceRepository := spicedb.NewRelationRepository(sdb) @@ -216,17 +215,16 @@ func buildAPIDependencies( } dependencies := api.Deps{ - OrgService: organizationService, - UserService: userService, - ProjectService: projectService, - GroupService: groupService, - RelationService: relationService, - ResourceService: resourceService, - RoleService: roleService, - PolicyService: policyService, - ActionService: actionService, - NamespaceService: namespaceService, - IdentityProxyHeader: identityProxyHeader, + OrgService: organizationService, + UserService: userService, + ProjectService: projectService, + GroupService: groupService, + RelationService: relationService, + ResourceService: resourceService, + RoleService: roleService, + PolicyService: policyService, + ActionService: actionService, + NamespaceService: namespaceService, } return dependencies, nil } diff --git a/cmd/serve_proxy.go b/cmd/serve_proxy.go index 7b6ca8366..88a756a76 100644 --- a/cmd/serve_proxy.go +++ b/cmd/serve_proxy.go @@ -23,7 +23,7 @@ import ( func serveProxies( ctx context.Context, logger log.Logger, - identityProxyHeader string, + identityProxyHeaderKey, userIDHeaderKey string, cfg proxy.ServicesConfig, resourceService *resource.Service, userService *user.Service, @@ -32,7 +32,7 @@ func serveProxies( var cleanUpProxies []func(ctx context.Context) error for _, svcConfig := range cfg.Services { - hookPipeline := buildHookPipeline(logger, identityProxyHeader, resourceService) + hookPipeline := buildHookPipeline(logger, identityProxyHeaderKey, resourceService) h2cProxy := proxy.NewH2c( proxy.NewH2cRoundTripper(logger, hookPipeline), @@ -57,7 +57,7 @@ func serveProxies( ruleService := rule.NewService(ruleBlobRepository) - middlewarePipeline := buildMiddlewarePipeline(logger, h2cProxy, identityProxyHeader, resourceService, userService, ruleService) + middlewarePipeline := buildMiddlewarePipeline(logger, h2cProxy, identityProxyHeaderKey, userIDHeaderKey, resourceService, userService, ruleService) cps := proxy.Serve(ctx, logger, svcConfig, middlewarePipeline) cleanUpProxies = append(cleanUpProxies, cps) @@ -67,23 +67,23 @@ func serveProxies( return cleanUpBlobs, cleanUpProxies, nil } -func buildHookPipeline(log log.Logger, identityProxyHeader string, resourceService v1beta1.ResourceService) hook.Service { +func buildHookPipeline(log log.Logger, identityProxyHeaderKey string, resourceService v1beta1.ResourceService) hook.Service { rootHook := hook.New() - return authz_hook.New(log, rootHook, rootHook, identityProxyHeader, resourceService) + return authz_hook.New(log, rootHook, rootHook, identityProxyHeaderKey, resourceService) } // buildPipeline builds middleware sequence func buildMiddlewarePipeline( logger log.Logger, proxy http.Handler, - identityProxyHeader string, + identityProxyHeaderKey, userIDHeaderKey string, resourceService *resource.Service, userService *user.Service, ruleService *rule.Service, ) http.Handler { // Note: execution order is bottom up prefixWare := prefix.New(logger, proxy) - casbinAuthz := authz.New(logger, prefixWare, identityProxyHeader, resourceService, userService) + casbinAuthz := authz.New(logger, prefixWare, identityProxyHeaderKey, userIDHeaderKey, resourceService, userService) basicAuthn := basic_auth.New(logger, casbinAuthz) matchWare := rulematch.New(logger, basicAuthn, rulematch.NewRouteMatcher(ruleService)) return matchWare diff --git a/core/group/group.go b/core/group/group.go index eed2c460b..ca8b43296 100644 --- a/core/group/group.go +++ b/core/group/group.go @@ -7,6 +7,7 @@ import ( "github.com/odpf/shield/core/organization" "github.com/odpf/shield/core/relation" "github.com/odpf/shield/core/user" + "github.com/odpf/shield/pkg/metadata" ) type Repository interface { @@ -29,7 +30,7 @@ type Group struct { Slug string Organization organization.Organization OrganizationID string `json:"orgId"` - Metadata map[string]any + Metadata metadata.Metadata CreatedAt time.Time UpdatedAt time.Time } diff --git a/core/organization/organization.go b/core/organization/organization.go index bda9ceec9..52c3b8d4a 100644 --- a/core/organization/organization.go +++ b/core/organization/organization.go @@ -5,6 +5,7 @@ import ( "time" "github.com/odpf/shield/core/user" + "github.com/odpf/shield/pkg/metadata" ) type Repository interface { @@ -21,7 +22,7 @@ type Organization struct { ID string Name string Slug string - Metadata map[string]any + Metadata metadata.Metadata CreatedAt time.Time UpdatedAt time.Time } diff --git a/core/project/project.go b/core/project/project.go index 58a628336..4943ce91f 100644 --- a/core/project/project.go +++ b/core/project/project.go @@ -6,6 +6,7 @@ import ( "github.com/odpf/shield/core/organization" "github.com/odpf/shield/core/user" + "github.com/odpf/shield/pkg/metadata" ) type Repository interface { @@ -23,7 +24,7 @@ type Project struct { Name string Slug string Organization organization.Organization - Metadata map[string]any + Metadata metadata.Metadata CreatedAt time.Time UpdatedAt time.Time } diff --git a/core/resource/resource.go b/core/resource/resource.go index 6f2ceca9f..63206c8ea 100644 --- a/core/resource/resource.go +++ b/core/resource/resource.go @@ -45,22 +45,10 @@ type Resource struct { UpdatedAt time.Time } -type Filter struct { - ProjectID string - GroupID string - OrganizationID string - NamespaceID string -} - -type YAML struct { - Name string `json:"name" yaml:"name"` - Actions map[string][]string `json:"actions" yaml:"actions"` -} - /* /project/uuid/ */ -func CreateURN(res Resource) string { +func (res Resource) CreateURN() string { isSystemNS := namespace.IsSystemNamespaceID(res.NamespaceID) if isSystemNS { return res.Name @@ -70,3 +58,15 @@ func CreateURN(res Resource) string { } return fmt.Sprintf("r/%s/%s", res.NamespaceID, res.Name) } + +type Filter struct { + ProjectID string + GroupID string + OrganizationID string + NamespaceID string +} + +type YAML struct { + Name string `json:"name" yaml:"name"` + Actions map[string][]string `json:"actions" yaml:"actions"` +} diff --git a/core/resource/service.go b/core/resource/service.go index a781e1fa5..1e394a469 100644 --- a/core/resource/service.go +++ b/core/resource/service.go @@ -46,7 +46,7 @@ func (s Service) Get(ctx context.Context, id string) (Resource, error) { } func (s Service) Create(ctx context.Context, res Resource) (Resource, error) { - urn := CreateURN(res) + urn := res.CreateURN() usr, err := s.userService.FetchCurrentUser(ctx) if err != nil { @@ -218,7 +218,7 @@ func (s Service) CheckAuthz(ctx context.Context, res Resource, act action.Action return false, err } - res.URN = CreateURN(res) + res.URN = res.CreateURN() isSystemNS := namespace.IsSystemNamespaceID(res.NamespaceID) fetchedResource := res diff --git a/core/role/role.go b/core/role/role.go index e358ed7a3..5cd618feb 100644 --- a/core/role/role.go +++ b/core/role/role.go @@ -7,6 +7,7 @@ import ( "time" "github.com/odpf/shield/core/namespace" + "github.com/odpf/shield/pkg/metadata" ) type Repository interface { @@ -22,7 +23,7 @@ type Role struct { Types []string Namespace namespace.Namespace NamespaceID string - Metadata map[string]any + Metadata metadata.Metadata CreatedAt time.Time UpdatedAt time.Time } diff --git a/core/rule/context.go b/core/rule/context.go new file mode 100644 index 000000000..2b077aa70 --- /dev/null +++ b/core/rule/context.go @@ -0,0 +1,16 @@ +package rule + +import ( + "context" +) + +type contextRuleKey struct{} + +func WithContext(ctx context.Context, rule *Rule) context.Context { + return context.WithValue(ctx, contextRuleKey{}, rule) +} + +func GetFromContext(ctx context.Context) (*Rule, bool) { + rl, ok := ctx.Value(contextRuleKey{}).(*Rule) + return rl, ok +} diff --git a/core/rule/rule.go b/core/rule/rule.go index 5f8e03178..1118eb9cd 100644 --- a/core/rule/rule.go +++ b/core/rule/rule.go @@ -32,13 +32,6 @@ type MiddlewareSpec struct { type MiddlewareSpecs []MiddlewareSpec -type HookSpec struct { - Name string `yaml:"name"` - Config map[string]interface{} `yaml:"config"` -} - -type HookSpecs []HookSpec - func (m MiddlewareSpecs) Get(name string) (MiddlewareSpec, bool) { for _, n := range m { if n.Name == name { @@ -48,6 +41,13 @@ func (m MiddlewareSpecs) Get(name string) (MiddlewareSpec, bool) { return MiddlewareSpec{}, false } +type HookSpec struct { + Name string `yaml:"name"` + Config map[string]interface{} `yaml:"config"` +} + +type HookSpecs []HookSpec + func (m HookSpecs) Get(name string) (HookSpec, bool) { for _, n := range m { if n.Name == name { diff --git a/core/user/context.go b/core/user/context.go new file mode 100644 index 000000000..b6a6dcb08 --- /dev/null +++ b/core/user/context.go @@ -0,0 +1,14 @@ +package user + +import "context" + +type contextEmailKey struct{} + +func SetContextWithEmail(ctx context.Context, email string) context.Context { + return context.WithValue(ctx, contextEmailKey{}, email) +} + +func GetEmailFromContext(ctx context.Context) (string, bool) { + email, ok := ctx.Value(contextEmailKey{}).(string) + return email, ok +} diff --git a/core/user/service.go b/core/user/service.go index 5b31f5526..2f28d9080 100644 --- a/core/user/service.go +++ b/core/user/service.go @@ -2,22 +2,16 @@ package user import ( "context" - "fmt" - - "google.golang.org/grpc/metadata" + "errors" ) -var emailContext = struct{}{} - type Service struct { - repository Repository - identityProxyHeader string + repository Repository } -func NewService(identityProxyHeader string, repository Repository) *Service { +func NewService(repository Repository) *Service { return &Service{ - identityProxyHeader: identityProxyHeader, - repository: repository, + repository: repository, } } @@ -67,9 +61,9 @@ func (s Service) UpdateByEmail(ctx context.Context, toUpdate User) (User, error) } func (s Service) FetchCurrentUser(ctx context.Context) (User, error) { - email, err := fetchEmailFromMetadata(ctx, s.identityProxyHeader) - if err != nil { - return User{}, err + email, ok := GetEmailFromContext(ctx) + if !ok { + return User{}, errors.New("unable to fetch email from context") } fetchedUser, err := s.repository.GetByEmail(ctx, email) @@ -79,33 +73,3 @@ func (s Service) FetchCurrentUser(ctx context.Context) (User, error) { return fetchedUser, nil } - -// TODO need to simplify this, service package should not depend on grpc metadata -func fetchEmailFromMetadata(ctx context.Context, headerKey string) (string, error) { - md, ok := metadata.FromIncomingContext(ctx) - if !ok { - val, ok := GetEmailFromContext(ctx) - if !ok { - return "", fmt.Errorf("unable to fetch context from incoming") - } - - return val, nil - } - - var email string - metadataValues := md.Get(headerKey) - if len(metadataValues) > 0 { - email = metadataValues[0] - } - return email, nil -} - -func SetEmailToContext(ctx context.Context, email string) context.Context { - return context.WithValue(ctx, emailContext, email) -} - -func GetEmailFromContext(ctx context.Context) (string, bool) { - val, ok := ctx.Value(emailContext).(string) - - return val, ok -} diff --git a/core/user/user.go b/core/user/user.go index b85599363..fde42a4b7 100644 --- a/core/user/user.go +++ b/core/user/user.go @@ -3,6 +3,8 @@ package user import ( "context" "time" + + "github.com/odpf/shield/pkg/metadata" ) type Repository interface { @@ -19,7 +21,7 @@ type User struct { ID string Name string Email string - Metadata map[string]any + Metadata metadata.Metadata CreatedAt time.Time UpdatedAt time.Time } diff --git a/internal/api/api.go b/internal/api/api.go index a6a6c9426..d1d322431 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -15,16 +15,15 @@ import ( ) type Deps struct { - OrgService *organization.Service - ProjectService *project.Service - GroupService *group.Service - RoleService *role.Service - PolicyService *policy.Service - UserService *user.Service - NamespaceService *namespace.Service - ActionService *action.Service - RelationService *relation.Service - ResourceService *resource.Service - RuleService *rule.Service - IdentityProxyHeader string + OrgService *organization.Service + ProjectService *project.Service + GroupService *group.Service + RoleService *role.Service + PolicyService *policy.Service + UserService *user.Service + NamespaceService *namespace.Service + ActionService *action.Service + RelationService *relation.Service + ResourceService *resource.Service + RuleService *rule.Service } diff --git a/internal/api/v1beta1/action.go b/internal/api/v1beta1/action.go index f350b83aa..a4bacb418 100644 --- a/internal/api/v1beta1/action.go +++ b/internal/api/v1beta1/action.go @@ -23,7 +23,6 @@ var grpcActionNotFoundErr = status.Errorf(codes.NotFound, "action doesn't exist" func (h Handler) ListActions(ctx context.Context, request *shieldv1beta1.ListActionsRequest) (*shieldv1beta1.ListActionsResponse, error) { logger := grpczap.Extract(ctx) - var actions []*shieldv1beta1.Action actionsList, err := h.actionService.List(ctx) if err != nil { @@ -31,13 +30,13 @@ func (h Handler) ListActions(ctx context.Context, request *shieldv1beta1.ListAct return nil, grpcInternalServerError } + var actions []*shieldv1beta1.Action for _, act := range actionsList { actPB, err := transformActionToPB(act) if err != nil { logger.Error(err.Error()) return nil, grpcInternalServerError } - actions = append(actions, &actPB) } @@ -48,9 +47,9 @@ func (h Handler) CreateAction(ctx context.Context, request *shieldv1beta1.Create logger := grpczap.Extract(ctx) newAction, err := h.actionService.Create(ctx, action.Action{ - ID: request.GetBody().Id, - Name: request.GetBody().Name, - NamespaceID: request.GetBody().NamespaceId, + ID: request.GetBody().GetId(), + Name: request.GetBody().GetName(), + NamespaceID: request.GetBody().GetNamespaceId(), }) if err != nil { logger.Error(err.Error()) @@ -102,8 +101,8 @@ func (h Handler) UpdateAction(ctx context.Context, request *shieldv1beta1.Update updatedAction, err := h.actionService.Update(ctx, request.GetId(), action.Action{ ID: request.GetId(), - Name: request.GetBody().Name, - NamespaceID: request.GetBody().NamespaceId, + Name: request.GetBody().GetName(), + NamespaceID: request.GetBody().GetNamespaceId(), }) if err != nil { diff --git a/internal/api/v1beta1/errors.go b/internal/api/v1beta1/errors.go new file mode 100644 index 000000000..901925aed --- /dev/null +++ b/internal/api/v1beta1/errors.go @@ -0,0 +1,21 @@ +package v1beta1 + +import ( + "errors" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// HTTP Codes defined here: +// https://github.com/grpc-ecosystem/grpc-gateway/blob/master/runtime/errors.go#L36 +var ( + internalServerError = errors.New("internal server error") + badRequestError = errors.New("invalid syntax in body") + permissionDeniedError = errors.New("permission denied") + + grpcInternalServerError = status.Errorf(codes.Internal, internalServerError.Error()) + grpcConflictError = status.Errorf(codes.AlreadyExists, badRequestError.Error()) + grpcBadBodyError = status.Error(codes.InvalidArgument, badRequestError.Error()) + grpcPermissionDenied = status.Error(codes.PermissionDenied, permissionDeniedError.Error()) +) diff --git a/internal/api/v1beta1/group.go b/internal/api/v1beta1/group.go index eb74f1fbe..1bad20471 100644 --- a/internal/api/v1beta1/group.go +++ b/internal/api/v1beta1/group.go @@ -5,6 +5,8 @@ import ( "strings" "github.com/odpf/shield/pkg/errors" + "github.com/odpf/shield/pkg/metadata" + "github.com/odpf/shield/pkg/str" "github.com/odpf/shield/pkg/uuid" grpczap "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap/ctxzap" @@ -17,7 +19,6 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/status" - "google.golang.org/protobuf/types/known/structpb" "google.golang.org/protobuf/types/known/timestamppb" ) @@ -68,30 +69,31 @@ func (h Handler) ListGroups(ctx context.Context, request *shieldv1beta1.ListGrou func (h Handler) CreateGroup(ctx context.Context, request *shieldv1beta1.CreateGroupRequest) (*shieldv1beta1.CreateGroupResponse, error) { logger := grpczap.Extract(ctx) - metaDataMap, err := mapOfStringValues(request.GetBody().Metadata.AsMap()) + metaDataMap, err := metadata.Build(request.GetBody().GetMetadata().AsMap()) if err != nil { logger.Error(err.Error()) return nil, grpcBadBodyError } - slug := request.GetBody().Slug - if strings.TrimSpace(slug) == "" { - slug = generateSlug(request.GetBody().Name) + grp := group.Group{ + Name: request.GetBody().GetName(), + Slug: request.GetBody().GetSlug(), + OrganizationID: request.GetBody().GetOrgId(), + Metadata: metaDataMap, } - newGroup, err := h.groupService.Create(ctx, group.Group{ - Name: request.Body.Name, - Slug: slug, - OrganizationID: request.Body.OrgId, - Metadata: metaDataMap, - }) + if strings.TrimSpace(grp.Slug) == "" { + grp.Slug = str.GenerateSlug(grp.Name) + } + + newGroup, err := h.groupService.Create(ctx, grp) if err != nil { logger.Error(err.Error()) return nil, grpcInternalServerError } - metaData, err := structpb.NewStruct(mapOfInterfaceValues(newGroup.Metadata)) + metaData, err := newGroup.Metadata.ToStructPB() if err != nil { logger.Error(err.Error()) return nil, grpcInternalServerError @@ -160,9 +162,10 @@ func (h Handler) ListGroupUsers(ctx context.Context, request *shieldv1beta1.List func (h Handler) AddGroupUser(ctx context.Context, request *shieldv1beta1.AddGroupUserRequest) (*shieldv1beta1.AddGroupUserResponse, error) { logger := grpczap.Extract(ctx) - if request.Body == nil { + if request.GetBody() == nil { return nil, grpcBadBodyError } + updatedUsers, err := h.groupService.AddUsers(ctx, request.GetId(), request.GetBody().GetUserIds()) if err != nil { logger.Error(err.Error()) @@ -177,7 +180,6 @@ func (h Handler) AddGroupUser(ctx context.Context, request *shieldv1beta1.AddGro } var users []*shieldv1beta1.User - for _, u := range updatedUsers { userPB, err := transformUserToPB(u) if err != nil { @@ -214,11 +216,11 @@ func (h Handler) RemoveGroupUser(ctx context.Context, request *shieldv1beta1.Rem func (h Handler) UpdateGroup(ctx context.Context, request *shieldv1beta1.UpdateGroupRequest) (*shieldv1beta1.UpdateGroupResponse, error) { logger := grpczap.Extract(ctx) - if request.Body == nil { + if request.GetBody() == nil { return nil, grpcBadBodyError } - metaDataMap, err := mapOfStringValues(request.GetBody().Metadata.AsMap()) + metaDataMap, err := metadata.Build(request.GetBody().GetMetadata().AsMap()) if err != nil { return nil, grpcBadBodyError } @@ -285,7 +287,7 @@ func (h Handler) ListGroupAdmins(ctx context.Context, request *shieldv1beta1.Lis func (h Handler) AddGroupAdmin(ctx context.Context, request *shieldv1beta1.AddGroupAdminRequest) (*shieldv1beta1.AddGroupAdminResponse, error) { logger := grpczap.Extract(ctx) - if request.Body == nil { + if request.GetBody() == nil { return nil, grpcBadBodyError } updatedUsers, err := h.groupService.AddAdmins(ctx, request.GetId(), request.GetBody().GetUserIds()) @@ -340,7 +342,7 @@ func (h Handler) RemoveGroupAdmin(ctx context.Context, request *shieldv1beta1.Re } func transformGroupToPB(grp group.Group) (shieldv1beta1.Group, error) { - metaData, err := structpb.NewStruct(mapOfInterfaceValues(grp.Metadata)) + metaData, err := grp.Metadata.ToStructPB() if err != nil { return shieldv1beta1.Group{}, err } diff --git a/internal/api/v1beta1/namespace.go b/internal/api/v1beta1/namespace.go index 0d056f2cd..f97a89a41 100644 --- a/internal/api/v1beta1/namespace.go +++ b/internal/api/v1beta1/namespace.go @@ -48,8 +48,8 @@ func (h Handler) CreateNamespace(ctx context.Context, request *shieldv1beta1.Cre logger := grpczap.Extract(ctx) newNS, err := h.namespaceService.Create(ctx, namespace.Namespace{ - ID: request.GetBody().Id, - Name: request.GetBody().Name, + ID: request.GetBody().GetId(), + Name: request.GetBody().GetName(), }) if err != nil { @@ -100,8 +100,8 @@ func (h Handler) UpdateNamespace(ctx context.Context, request *shieldv1beta1.Upd logger := grpczap.Extract(ctx) updatedNS, err := h.namespaceService.Update(ctx, namespace.Namespace{ - ID: request.GetBody().Id, - Name: request.GetBody().Name, + ID: request.GetBody().GetId(), + Name: request.GetBody().GetName(), }) if err != nil { diff --git a/internal/api/v1beta1/org.go b/internal/api/v1beta1/org.go index d45a26f94..71ec16ea3 100644 --- a/internal/api/v1beta1/org.go +++ b/internal/api/v1beta1/org.go @@ -6,6 +6,8 @@ import ( "github.com/odpf/shield/core/user" "github.com/odpf/shield/pkg/errors" + "github.com/odpf/shield/pkg/metadata" + "github.com/odpf/shield/pkg/str" "github.com/odpf/shield/pkg/uuid" grpczap "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap/ctxzap" @@ -14,7 +16,6 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/status" - "google.golang.org/protobuf/types/known/structpb" "google.golang.org/protobuf/types/known/timestamppb" shieldv1beta1 "github.com/odpf/shield/proto/v1beta1" @@ -59,33 +60,34 @@ func (h Handler) CreateOrganization(ctx context.Context, request *shieldv1beta1. logger := grpczap.Extract(ctx) // TODO (@krtkvrm): Add validations using Proto - if request.Body == nil { + if request.GetBody() == nil { return nil, grpcBadBodyError } - metaDataMap, err := mapOfStringValues(request.GetBody().Metadata.AsMap()) + metaDataMap, err := metadata.Build(request.GetBody().GetMetadata().AsMap()) if err != nil { logger.Error(err.Error()) return nil, grpcBadBodyError } - slug := request.GetBody().Slug - if strings.TrimSpace(slug) == "" { - slug = generateSlug(request.GetBody().Name) + org := organization.Organization{ + Name: request.GetBody().GetName(), + Slug: request.GetBody().GetSlug(), + Metadata: metaDataMap, } - newOrg, err := h.orgService.Create(ctx, organization.Organization{ - Name: request.GetBody().Name, - Slug: slug, - Metadata: metaDataMap, - }) + if strings.TrimSpace(org.Slug) == "" { + org.Slug = str.GenerateSlug(org.Name) + } + + newOrg, err := h.orgService.Create(ctx, org) if err != nil { logger.Error(err.Error()) return nil, grpcInternalServerError } - metaData, err := structpb.NewStruct(mapOfInterfaceValues(newOrg.Metadata)) + metaData, err := newOrg.Metadata.ToStructPB() if err != nil { logger.Error(err.Error()) return nil, grpcInternalServerError @@ -131,11 +133,11 @@ func (h Handler) GetOrganization(ctx context.Context, request *shieldv1beta1.Get func (h Handler) UpdateOrganization(ctx context.Context, request *shieldv1beta1.UpdateOrganizationRequest) (*shieldv1beta1.UpdateOrganizationResponse, error) { logger := grpczap.Extract(ctx) - if request.Body == nil { + if request.GetBody() == nil { return nil, grpcBadBodyError } - metaDataMap, err := mapOfStringValues(request.GetBody().Metadata.AsMap()) + metaDataMap, err := metadata.Build(request.GetBody().GetMetadata().AsMap()) if err != nil { return nil, grpcBadBodyError } @@ -144,13 +146,13 @@ func (h Handler) UpdateOrganization(ctx context.Context, request *shieldv1beta1. if uuid.IsValid(request.GetId()) { updatedOrg, err = h.orgService.Update(ctx, organization.Organization{ ID: request.GetId(), - Name: request.GetBody().Name, - Slug: request.GetBody().Slug, + Name: request.GetBody().GetName(), + Slug: request.GetBody().GetSlug(), Metadata: metaDataMap, }) } else { updatedOrg, err = h.orgService.Update(ctx, organization.Organization{ - Name: request.GetBody().Name, + Name: request.GetBody().GetName(), Slug: request.GetId(), Metadata: metaDataMap, }) @@ -248,7 +250,7 @@ func (h Handler) RemoveOrganizationAdmin(ctx context.Context, request *shieldv1b } func transformOrgToPB(org organization.Organization) (shieldv1beta1.Organization, error) { - metaData, err := structpb.NewStruct(mapOfInterfaceValues(org.Metadata)) + metaData, err := org.Metadata.ToStructPB() if err != nil { return shieldv1beta1.Organization{}, err } diff --git a/internal/api/v1beta1/org_test.go b/internal/api/v1beta1/org_test.go index e9ce6ef83..e6d5215ad 100644 --- a/internal/api/v1beta1/org_test.go +++ b/internal/api/v1beta1/org_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/odpf/shield/core/user" + "github.com/odpf/shield/pkg/metadata" "github.com/stretchr/testify/assert" @@ -24,7 +25,7 @@ var testOrgMap = map[string]organization.Organization{ ID: "9f256f86-31a3-11ec-8d3d-0242ac130003", Name: "Org 1", Slug: "org-1", - Metadata: map[string]any{ + Metadata: metadata.Metadata{ "email": "org1@org1.com", "age": 21, "intern": true, diff --git a/internal/api/v1beta1/permission_check.go b/internal/api/v1beta1/permission_check.go index 1772f6f07..70b5e850c 100644 --- a/internal/api/v1beta1/permission_check.go +++ b/internal/api/v1beta1/permission_check.go @@ -19,18 +19,18 @@ var ( internalServerErr = fmt.Errorf("internal server error") ) -func (h Handler) CheckResourcePermission(ctx context.Context, in *shieldv1beta1.ResourceActionAuthzRequest) (*shieldv1beta1.ResourceActionAuthzResponse, error) { +func (h Handler) CheckResourcePermission(ctx context.Context, req *shieldv1beta1.ResourceActionAuthzRequest) (*shieldv1beta1.ResourceActionAuthzResponse, error) { logger := grpczap.Extract(ctx) - if err := in.ValidateAll(); err != nil { + if err := req.ValidateAll(); err != nil { formattedErr := getValidationErrorMessage(err) logger.Error(formattedErr.Error()) return nil, status.Errorf(codes.NotFound, formattedErr.Error()) } result, err := h.resourceService.CheckAuthz(ctx, resource.Resource{ - Name: in.ResourceId, - NamespaceID: in.NamespaceId, - }, action.Action{ID: in.ActionId}) + Name: req.GetResourceId(), + NamespaceID: req.GetNamespaceId(), + }, action.Action{ID: req.GetActionId()}) if err != nil { formattedErr := fmt.Errorf("%s: %w", internalServerErr, err) logger.Error(formattedErr.Error()) diff --git a/internal/api/v1beta1/policy.go b/internal/api/v1beta1/policy.go index 68ef1a5db..d0c4ef07c 100644 --- a/internal/api/v1beta1/policy.go +++ b/internal/api/v1beta1/policy.go @@ -50,9 +50,9 @@ func (h Handler) CreatePolicy(ctx context.Context, request *shieldv1beta1.Create var policies []*shieldv1beta1.Policy newPolicies, err := h.policyService.Create(ctx, policy.Policy{ - RoleID: request.GetBody().RoleId, - NamespaceID: request.GetBody().NamespaceId, - ActionID: request.GetBody().ActionId, + RoleID: request.GetBody().GetRoleId(), + NamespaceID: request.GetBody().GetNamespaceId(), + ActionID: request.GetBody().GetActionId(), }) if err != nil { @@ -114,9 +114,9 @@ func (h Handler) UpdatePolicy(ctx context.Context, request *shieldv1beta1.Update updatedPolices, err := h.policyService.Update(ctx, policy.Policy{ ID: request.GetId(), - RoleID: request.GetBody().RoleId, - NamespaceID: request.GetBody().NamespaceId, - ActionID: request.GetBody().ActionId, + RoleID: request.GetBody().GetRoleId(), + NamespaceID: request.GetBody().GetNamespaceId(), + ActionID: request.GetBody().GetActionId(), }) if err != nil { diff --git a/internal/api/v1beta1/policy_test.go b/internal/api/v1beta1/policy_test.go index 9cc59dc93..6a29a5365 100644 --- a/internal/api/v1beta1/policy_test.go +++ b/internal/api/v1beta1/policy_test.go @@ -10,6 +10,7 @@ import ( "github.com/odpf/shield/core/namespace" "github.com/odpf/shield/core/policy" "github.com/odpf/shield/core/role" + "github.com/odpf/shield/pkg/metadata" shieldv1beta1 "github.com/odpf/shield/proto/v1beta1" "github.com/stretchr/testify/assert" @@ -43,7 +44,7 @@ var testPolicyMap = map[string]policy.Policy{ Role: role.Role{ ID: "reader", Name: "Reader", - Metadata: map[string]any{}, + Metadata: metadata.Metadata{}, Namespace: namespace.Namespace{ ID: "resource-1", Name: "Resource 1", @@ -185,7 +186,7 @@ func TestCreatePolicy(t *testing.T) { Role: role.Role{ ID: "reader", Name: "Reader", - Metadata: map[string]any{}, + Metadata: metadata.Metadata{}, Namespace: namespace.Namespace{ ID: "resource-1", Name: "Resource 1", diff --git a/internal/api/v1beta1/project.go b/internal/api/v1beta1/project.go index 388d22588..d48e2fc96 100644 --- a/internal/api/v1beta1/project.go +++ b/internal/api/v1beta1/project.go @@ -6,6 +6,8 @@ import ( "strings" "github.com/odpf/shield/core/user" + "github.com/odpf/shield/pkg/metadata" + "github.com/odpf/shield/pkg/str" "github.com/odpf/shield/pkg/uuid" grpczap "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap/ctxzap" @@ -15,7 +17,6 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/status" - "google.golang.org/protobuf/types/known/structpb" "google.golang.org/protobuf/types/known/timestamppb" shieldv1beta1 "github.com/odpf/shield/proto/v1beta1" @@ -58,30 +59,31 @@ func (h Handler) ListProjects(ctx context.Context, request *shieldv1beta1.ListPr func (h Handler) CreateProject(ctx context.Context, request *shieldv1beta1.CreateProjectRequest) (*shieldv1beta1.CreateProjectResponse, error) { logger := grpczap.Extract(ctx) - metaDataMap, err := mapOfStringValues(request.GetBody().Metadata.AsMap()) + metaDataMap, err := metadata.Build(request.GetBody().GetMetadata().AsMap()) if err != nil { logger.Error(err.Error()) return nil, grpcBadBodyError } - slug := request.GetBody().Slug - if strings.TrimSpace(slug) == "" { - slug = generateSlug(request.GetBody().Name) + prj := project.Project{ + Name: request.GetBody().GetName(), + Slug: request.GetBody().GetSlug(), + Metadata: metaDataMap, + Organization: organization.Organization{ID: request.GetBody().GetOrgId()}, } - newProject, err := h.projectService.Create(ctx, project.Project{ - Name: request.GetBody().Name, - Slug: slug, - Metadata: metaDataMap, - Organization: organization.Organization{ID: request.GetBody().OrgId}, - }) + if strings.TrimSpace(prj.Slug) == "" { + prj.Slug = str.GenerateSlug(prj.Name) + } + + newProject, err := h.projectService.Create(ctx, prj) if err != nil { logger.Error(err.Error()) return nil, grpcInternalServerError } - metaData, err := structpb.NewStruct(mapOfInterfaceValues(newProject.Metadata)) + metaData, err := newProject.Metadata.ToStructPB() if err != nil { logger.Error(err.Error()) return nil, grpcInternalServerError @@ -125,7 +127,7 @@ func (h Handler) GetProject(ctx context.Context, request *shieldv1beta1.GetProje func (h Handler) UpdateProject(ctx context.Context, request *shieldv1beta1.UpdateProjectRequest) (*shieldv1beta1.UpdateProjectResponse, error) { logger := grpczap.Extract(ctx) - metaDataMap, err := mapOfStringValues(request.GetBody().GetMetadata().AsMap()) + metaDataMap, err := metadata.Build(request.GetBody().GetMetadata().AsMap()) if err != nil { return nil, grpcBadBodyError } @@ -239,7 +241,7 @@ func (h Handler) RemoveProjectAdmin(ctx context.Context, request *shieldv1beta1. } func transformProjectToPB(prj project.Project) (shieldv1beta1.Project, error) { - metaData, err := structpb.NewStruct(mapOfInterfaceValues(prj.Metadata)) + metaData, err := prj.Metadata.ToStructPB() if err != nil { return shieldv1beta1.Project{}, err } diff --git a/internal/api/v1beta1/project_test.go b/internal/api/v1beta1/project_test.go index 5a0afaf21..c0a560b01 100644 --- a/internal/api/v1beta1/project_test.go +++ b/internal/api/v1beta1/project_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/odpf/shield/core/user" + "github.com/odpf/shield/pkg/metadata" "github.com/stretchr/testify/assert" @@ -27,7 +28,7 @@ var testProjectMap = map[string]project.Project{ ID: "ab657ae7-8c9e-45eb-9862-dd9ceb6d5c71", Name: "Prj 1", Slug: "prj-1", - Metadata: map[string]any{ + Metadata: metadata.Metadata{ "email": "org1@org1.com", }, CreatedAt: time.Time{}, @@ -37,7 +38,7 @@ var testProjectMap = map[string]project.Project{ ID: "c7772c63-fca4-4c7c-bf93-c8f85115de4b", Name: "Prj 2", Slug: "prj-2", - Metadata: map[string]any{ + Metadata: metadata.Metadata{ "email": "org1@org2.com", }, CreatedAt: time.Time{}, diff --git a/internal/api/v1beta1/relation.go b/internal/api/v1beta1/relation.go index 53c46e27d..dff110678 100644 --- a/internal/api/v1beta1/relation.go +++ b/internal/api/v1beta1/relation.go @@ -49,16 +49,16 @@ func (h Handler) ListRelations(ctx context.Context, request *shieldv1beta1.ListR func (h Handler) CreateRelation(ctx context.Context, request *shieldv1beta1.CreateRelationRequest) (*shieldv1beta1.CreateRelationResponse, error) { logger := grpczap.Extract(ctx) - if request.Body == nil { + if request.GetBody() == nil { return nil, grpcBadBodyError } newRelation, err := h.relationService.Create(ctx, relation.Relation{ - SubjectNamespaceID: request.GetBody().SubjectType, - SubjectID: request.GetBody().SubjectId, - ObjectNamespaceID: request.GetBody().ObjectType, - ObjectID: request.GetBody().ObjectId, - RoleID: request.GetBody().RoleId, + SubjectNamespaceID: request.GetBody().GetSubjectType(), + SubjectID: request.GetBody().GetSubjectId(), + ObjectNamespaceID: request.GetBody().GetObjectType(), + ObjectID: request.GetBody().GetObjectId(), + RoleID: request.GetBody().GetRoleId(), }) if err != nil { @@ -113,17 +113,17 @@ func (h Handler) GetRelation(ctx context.Context, request *shieldv1beta1.GetRela func (h Handler) UpdateRelation(ctx context.Context, request *shieldv1beta1.UpdateRelationRequest) (*shieldv1beta1.UpdateRelationResponse, error) { logger := grpczap.Extract(ctx) - if request.Body == nil { + if request.GetBody() == nil { return nil, grpcBadBodyError } updatedRelation, err := h.relationService.Update(ctx, relation.Relation{ ID: request.GetId(), - SubjectNamespaceID: request.GetBody().SubjectType, - SubjectID: request.GetBody().SubjectId, - ObjectNamespaceID: request.GetBody().ObjectType, - ObjectID: request.GetBody().ObjectId, - RoleID: request.GetBody().RoleId, + SubjectNamespaceID: request.GetBody().GetSubjectType(), + SubjectID: request.GetBody().GetSubjectId(), + ObjectNamespaceID: request.GetBody().GetObjectType(), + ObjectID: request.GetBody().GetObjectId(), + RoleID: request.GetBody().GetRoleId(), }) if err != nil { diff --git a/internal/api/v1beta1/resource.go b/internal/api/v1beta1/resource.go index 73f254de6..45213e445 100644 --- a/internal/api/v1beta1/resource.go +++ b/internal/api/v1beta1/resource.go @@ -58,17 +58,17 @@ func (h Handler) ListResources(ctx context.Context, request *shieldv1beta1.ListR func (h Handler) CreateResource(ctx context.Context, request *shieldv1beta1.CreateResourceRequest) (*shieldv1beta1.CreateResourceResponse, error) { logger := grpczap.Extract(ctx) - if request.Body == nil { + if request.GetBody() == nil { return nil, grpcBadBodyError } newResource, err := h.resourceService.Create(ctx, resource.Resource{ - OrganizationID: request.GetBody().OrganizationId, - ProjectID: request.GetBody().ProjectId, - GroupID: request.GetBody().GroupId, - NamespaceID: request.GetBody().NamespaceId, - Name: request.GetBody().Name, - UserID: request.GetBody().UserId, + OrganizationID: request.GetBody().GetOrganizationId(), + ProjectID: request.GetBody().GetProjectId(), + GroupID: request.GetBody().GetGroupId(), + NamespaceID: request.GetBody().GetNamespaceId(), + Name: request.GetBody().GetName(), + UserID: request.GetBody().GetUserId(), }) if err != nil { @@ -91,8 +91,7 @@ func (h Handler) CreateResource(ctx context.Context, request *shieldv1beta1.Crea func (h Handler) GetResource(ctx context.Context, request *shieldv1beta1.GetResourceRequest) (*shieldv1beta1.GetResourceResponse, error) { logger := grpczap.Extract(ctx) - fetchedResource, err := h.resourceService.Get(ctx, request.Id) - + fetchedResource, err := h.resourceService.Get(ctx, request.GetId()) if err != nil { logger.Error(err.Error()) switch { @@ -119,17 +118,17 @@ func (h Handler) GetResource(ctx context.Context, request *shieldv1beta1.GetReso func (h Handler) UpdateResource(ctx context.Context, request *shieldv1beta1.UpdateResourceRequest) (*shieldv1beta1.UpdateResourceResponse, error) { logger := grpczap.Extract(ctx) - if request.Body == nil { + if request.GetBody() == nil { return nil, grpcBadBodyError } - updatedResource, err := h.resourceService.Update(ctx, request.Id, resource.Resource{ - OrganizationID: request.GetBody().OrganizationId, - ProjectID: request.GetBody().ProjectId, - GroupID: request.GetBody().GroupId, - NamespaceID: request.GetBody().NamespaceId, - Name: request.GetBody().Name, - UserID: request.GetBody().UserId, + updatedResource, err := h.resourceService.Update(ctx, request.GetId(), resource.Resource{ + OrganizationID: request.GetBody().GetOrganizationId(), + ProjectID: request.GetBody().GetProjectId(), + GroupID: request.GetBody().GetGroupId(), + NamespaceID: request.GetBody().GetNamespaceId(), + Name: request.GetBody().GetName(), + UserID: request.GetBody().GetUserId(), }) if err != nil { diff --git a/internal/api/v1beta1/role.go b/internal/api/v1beta1/role.go index 0789f4668..ea9f802d3 100644 --- a/internal/api/v1beta1/role.go +++ b/internal/api/v1beta1/role.go @@ -6,9 +6,9 @@ import ( grpczap "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap/ctxzap" "github.com/odpf/shield/core/role" + "github.com/odpf/shield/pkg/metadata" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" - "google.golang.org/protobuf/types/known/structpb" "google.golang.org/protobuf/types/known/timestamppb" shieldv1beta1 "github.com/odpf/shield/proto/v1beta1" @@ -48,17 +48,17 @@ func (h Handler) ListRoles(ctx context.Context, request *shieldv1beta1.ListRoles func (h Handler) CreateRole(ctx context.Context, request *shieldv1beta1.CreateRoleRequest) (*shieldv1beta1.CreateRoleResponse, error) { logger := grpczap.Extract(ctx) - metaDataMap, err := mapOfStringValues(request.GetBody().Metadata.AsMap()) + metaDataMap, err := metadata.Build(request.GetBody().GetMetadata().AsMap()) if err != nil { logger.Error(err.Error()) return nil, grpcBadBodyError } newRole, err := h.roleService.Create(ctx, role.Role{ - ID: request.GetBody().Id, - Name: request.GetBody().Name, - Types: request.GetBody().Types, - NamespaceID: request.GetBody().NamespaceId, + ID: request.GetBody().GetId(), + Name: request.GetBody().GetName(), + Types: request.GetBody().GetTypes(), + NamespaceID: request.GetBody().GetNamespaceId(), Metadata: metaDataMap, }) if err != nil { @@ -110,16 +110,16 @@ func (h Handler) GetRole(ctx context.Context, request *shieldv1beta1.GetRoleRequ func (h Handler) UpdateRole(ctx context.Context, request *shieldv1beta1.UpdateRoleRequest) (*shieldv1beta1.UpdateRoleResponse, error) { logger := grpczap.Extract(ctx) - metaDataMap, err := mapOfStringValues(request.GetBody().Metadata.AsMap()) + metaDataMap, err := metadata.Build(request.GetBody().GetMetadata().AsMap()) if err != nil { return nil, grpcBadBodyError } updatedRole, err := h.roleService.Update(ctx, role.Role{ - ID: request.GetBody().Id, - Name: request.GetBody().Name, - Types: request.GetBody().Types, - NamespaceID: request.GetBody().NamespaceId, + ID: request.GetBody().GetId(), + Name: request.GetBody().GetName(), + Types: request.GetBody().GetTypes(), + NamespaceID: request.GetBody().GetNamespaceId(), Metadata: metaDataMap, }) if err != nil { @@ -144,7 +144,7 @@ func (h Handler) UpdateRole(ctx context.Context, request *shieldv1beta1.UpdateRo } func transformRoleToPB(from role.Role) (shieldv1beta1.Role, error) { - metaData, err := structpb.NewStruct(mapOfInterfaceValues(from.Metadata)) + metaData, err := from.Metadata.ToStructPB() if err != nil { return shieldv1beta1.Role{}, err } diff --git a/internal/api/v1beta1/user.go b/internal/api/v1beta1/user.go index 20e989f02..77eedcc6c 100644 --- a/internal/api/v1beta1/user.go +++ b/internal/api/v1beta1/user.go @@ -7,12 +7,11 @@ import ( grpczap "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap/ctxzap" "google.golang.org/grpc/codes" - "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" - "google.golang.org/protobuf/types/known/structpb" "google.golang.org/protobuf/types/known/timestamppb" "github.com/odpf/shield/core/user" + "github.com/odpf/shield/pkg/metadata" "github.com/odpf/shield/pkg/str" shieldv1beta1 "github.com/odpf/shield/proto/v1beta1" ) @@ -67,24 +66,29 @@ func (h Handler) ListUsers(ctx context.Context, request *shieldv1beta1.ListUsers func (h Handler) CreateUser(ctx context.Context, request *shieldv1beta1.CreateUserRequest) (*shieldv1beta1.CreateUserResponse, error) { logger := grpczap.Extract(ctx) - if request.Body == nil { + if request.GetBody() == nil { return nil, grpcBadBodyError } - metaDataMap, err := mapOfStringValues(request.GetBody().Metadata.AsMap()) - if err != nil { - logger.Error(err.Error()) + currentUserEmail, ok := user.GetEmailFromContext(ctx) + if !ok { return nil, grpcBadBodyError } - currentUserEmail, _ := fetchEmailFromMetadata(ctx, h.identityProxyHeader) if len(currentUserEmail) == 0 { logger.Error(emptyEmailId.Error()) return nil, emptyEmailId } - email := str.DefaultStringIfEmpty(request.GetBody().Email, currentUserEmail) + + metaDataMap, err := metadata.Build(request.GetBody().GetMetadata().AsMap()) + if err != nil { + logger.Error(err.Error()) + return nil, grpcBadBodyError + } + + email := str.DefaultStringIfEmpty(request.GetBody().GetEmail(), currentUserEmail) userT := user.User{ - Name: request.GetBody().Name, + Name: request.GetBody().GetName(), Email: email, Metadata: metaDataMap, } @@ -99,7 +103,7 @@ func (h Handler) CreateUser(ctx context.Context, request *shieldv1beta1.CreateUs } } - metaData, err := structpb.NewStruct(mapOfInterfaceValues(newUser.Metadata)) + metaData, err := newUser.Metadata.ToStructPB() if err != nil { logger.Error(err.Error()) return nil, grpcInternalServerError @@ -145,10 +149,11 @@ func (h Handler) GetUser(ctx context.Context, request *shieldv1beta1.GetUserRequ func (h Handler) GetCurrentUser(ctx context.Context, request *shieldv1beta1.GetCurrentUserRequest) (*shieldv1beta1.GetCurrentUserResponse, error) { logger := grpczap.Extract(ctx) - email, err := fetchEmailFromMetadata(ctx, h.identityProxyHeader) - if err != nil { + email, ok := user.GetEmailFromContext(ctx) + if !ok { return nil, grpcBadBodyError } + if len(email) == 0 { logger.Error(emptyEmailId.Error()) return nil, emptyEmailId @@ -181,19 +186,19 @@ func (h Handler) GetCurrentUser(ctx context.Context, request *shieldv1beta1.GetC func (h Handler) UpdateUser(ctx context.Context, request *shieldv1beta1.UpdateUserRequest) (*shieldv1beta1.UpdateUserResponse, error) { logger := grpczap.Extract(ctx) - if request.Body == nil { + if request.GetBody() == nil { return nil, grpcBadBodyError } - metaDataMap, err := mapOfStringValues(request.GetBody().Metadata.AsMap()) + metaDataMap, err := metadata.Build(request.GetBody().GetMetadata().AsMap()) if err != nil { return nil, grpcBadBodyError } updatedUser, err := h.userService.UpdateByID(ctx, user.User{ ID: request.GetId(), - Name: request.GetBody().Name, - Email: request.GetBody().Email, + Name: request.GetBody().GetName(), + Email: request.GetBody().GetEmail(), Metadata: metaDataMap, }) if err != nil { @@ -220,12 +225,12 @@ func (h Handler) UpdateUser(ctx context.Context, request *shieldv1beta1.UpdateUs func (h Handler) UpdateCurrentUser(ctx context.Context, request *shieldv1beta1.UpdateCurrentUserRequest) (*shieldv1beta1.UpdateCurrentUserResponse, error) { logger := grpczap.Extract(ctx) - email, err := fetchEmailFromMetadata(ctx, h.identityProxyHeader) - if err != nil { + email, ok := user.GetEmailFromContext(ctx) + if !ok { return nil, grpcBadBodyError } - if request.Body == nil { + if request.GetBody() == nil { return nil, grpcBadBodyError } if len(email) == 0 { @@ -233,18 +238,18 @@ func (h Handler) UpdateCurrentUser(ctx context.Context, request *shieldv1beta1.U return nil, emptyEmailId } - metaDataMap, err := mapOfStringValues(request.GetBody().Metadata.AsMap()) + metaDataMap, err := metadata.Build(request.GetBody().GetMetadata().AsMap()) if err != nil { return nil, grpcBadBodyError } // if email in request body is different from the email in the header - if request.GetBody().Email != email { + if request.GetBody().GetEmail() != email { return nil, grpcBadBodyError } updatedUser, err := h.userService.UpdateByEmail(ctx, user.User{ - Name: request.GetBody().Name, + Name: request.GetBody().GetName(), Email: email, Metadata: metaDataMap, }) @@ -267,26 +272,26 @@ func (h Handler) UpdateCurrentUser(ctx context.Context, request *shieldv1beta1.U return &shieldv1beta1.UpdateCurrentUserResponse{User: &userPB}, nil } -func transformUserToPB(user user.User) (shieldv1beta1.User, error) { - metaData, err := structpb.NewStruct(mapOfInterfaceValues(user.Metadata)) +func transformUserToPB(usr user.User) (shieldv1beta1.User, error) { + metaData, err := usr.Metadata.ToStructPB() if err != nil { return shieldv1beta1.User{}, err } return shieldv1beta1.User{ - Id: user.ID, - Name: user.Name, - Email: user.Email, + Id: usr.ID, + Name: usr.Name, + Email: usr.Email, Metadata: metaData, - CreatedAt: timestamppb.New(user.CreatedAt), - UpdatedAt: timestamppb.New(user.UpdatedAt), + CreatedAt: timestamppb.New(usr.CreatedAt), + UpdatedAt: timestamppb.New(usr.UpdatedAt), }, nil } func (h Handler) ListUserGroups(ctx context.Context, request *shieldv1beta1.ListUserGroupsRequest) (*shieldv1beta1.ListUserGroupsResponse, error) { logger := grpczap.Extract(ctx) var groups []*shieldv1beta1.Group - groupsList, err := h.groupService.ListUserGroups(ctx, request.Id, request.Role) + groupsList, err := h.groupService.ListUserGroups(ctx, request.GetId(), request.GetRole()) if err != nil { logger.Error(err.Error()) @@ -307,17 +312,3 @@ func (h Handler) ListUserGroups(ctx context.Context, request *shieldv1beta1.List Groups: groups, }, nil } - -func fetchEmailFromMetadata(ctx context.Context, headerKey string) (string, error) { - md, ok := metadata.FromIncomingContext(ctx) - if !ok { - return "", grpcBadBodyError - } - - var email string - metadataValues := md.Get(headerKey) - if len(metadataValues) > 0 { - email = metadataValues[0] - } - return email, nil -} diff --git a/internal/api/v1beta1/user_test.go b/internal/api/v1beta1/user_test.go index 6db459c3c..faddb4b10 100644 --- a/internal/api/v1beta1/user_test.go +++ b/internal/api/v1beta1/user_test.go @@ -7,11 +7,11 @@ import ( "time" "github.com/odpf/shield/core/user" + "github.com/odpf/shield/pkg/metadata" "github.com/stretchr/testify/assert" "google.golang.org/grpc/codes" - "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/structpb" "google.golang.org/protobuf/types/known/timestamppb" @@ -24,7 +24,7 @@ var testUserMap = map[string]user.User{ ID: "9f256f86-31a3-11ec-8d3d-0242ac130003", Name: "User 1", Email: "test@test.com", - Metadata: map[string]any{ + Metadata: metadata.Metadata{ "foo": "bar", "age": 21, "intern": true, @@ -185,9 +185,8 @@ func TestCreateUser(t *testing.T) { var resp *shieldv1beta1.CreateUserResponse var err error if tt.title == "success" { - mockDep := Handler{userService: tt.mockUserSrv, identityProxyHeader: "x-auth-email"} - md := metadata.Pairs(mockDep.identityProxyHeader, tt.header) - ctx := metadata.NewIncomingContext(context.Background(), md) + mockDep := Handler{userService: tt.mockUserSrv} + ctx := user.SetContextWithEmail(context.Background(), tt.header) resp, err = mockDep.CreateUser(ctx, tt.req) } else { mockDep := Handler{userService: tt.mockUserSrv} @@ -226,7 +225,7 @@ func TestGetCurrentUser(t *testing.T) { ID: "user-id-1", Name: "some user", Email: "someuser@test.com", - Metadata: map[string]any{ + Metadata: metadata.Metadata{ "foo": "bar", }, CreatedAt: time.Time{}, @@ -254,9 +253,8 @@ func TestGetCurrentUser(t *testing.T) { t.Run(tt.title, func(t *testing.T) { t.Parallel() - mockDep := Handler{userService: tt.mockUserSrv, identityProxyHeader: "x-auth-email"} - md := metadata.Pairs(mockDep.identityProxyHeader, tt.header) - ctx := metadata.NewIncomingContext(context.Background(), md) + mockDep := Handler{userService: tt.mockUserSrv} + ctx := user.SetContextWithEmail(context.Background(), tt.header) resp, err := mockDep.GetCurrentUser(ctx, nil) assert.EqualValues(t, resp, tt.want) @@ -329,7 +327,7 @@ func TestUpdateCurrentUser(t *testing.T) { ID: "user-id-1", Name: "abc user", Email: "abcuser@test.com", - Metadata: map[string]any{ + Metadata: metadata.Metadata{ "foo": "bar", }, CreatedAt: time.Time{}, @@ -366,9 +364,8 @@ func TestUpdateCurrentUser(t *testing.T) { t.Run(tt.title, func(t *testing.T) { t.Parallel() - mockDep := Handler{userService: tt.mockUserSrv, identityProxyHeader: "x-auth-email"} - md := metadata.Pairs(mockDep.identityProxyHeader, tt.header) - ctx := metadata.NewIncomingContext(context.Background(), md) + mockDep := Handler{userService: tt.mockUserSrv} + ctx := user.SetContextWithEmail(context.Background(), tt.header) resp, err := mockDep.UpdateCurrentUser(ctx, tt.req) assert.EqualValues(t, resp, tt.want) diff --git a/internal/api/v1beta1/util.go b/internal/api/v1beta1/util.go deleted file mode 100644 index bf8e3d9e3..000000000 --- a/internal/api/v1beta1/util.go +++ /dev/null @@ -1,52 +0,0 @@ -package v1beta1 - -import ( - "fmt" - "strings" - - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" -) - -// HTTP Codes defined here: -// https://github.com/grpc-ecosystem/grpc-gateway/blob/master/runtime/errors.go#L36 - -var ( - grpcInternalServerError = status.Errorf(codes.Internal, internalServerError.Error()) - grpcConflictError = status.Errorf(codes.AlreadyExists, badRequestError.Error()) - grpcBadBodyError = status.Error(codes.InvalidArgument, badRequestError.Error()) - grpcPermissionDenied = status.Error(codes.PermissionDenied, permissionDeniedError.Error()) -) - -func mapOfStringValues(m map[string]interface{}) (map[string]any, error) { - newMap := make(map[string]any) - - for key, value := range m { - switch value := value.(type) { - case any: - newMap[key] = value - default: - return map[string]any{}, fmt.Errorf("value for %s key is not string", key) - } - } - - return newMap, nil -} - -func mapOfInterfaceValues(m map[string]any) map[string]interface{} { - newMap := make(map[string]interface{}) - - for key, value := range m { - newMap[key] = value - } - - return newMap -} - -func generateSlug(name string) string { - preProcessed := strings.ReplaceAll(strings.TrimSpace(strings.TrimSpace(name)), "_", "-") - return strings.Join( - strings.Split(preProcessed, " "), - "-", - ) -} diff --git a/internal/api/v1beta1/v1beta1.go b/internal/api/v1beta1/v1beta1.go index 31bb3a684..d793f1364 100644 --- a/internal/api/v1beta1/v1beta1.go +++ b/internal/api/v1beta1/v1beta1.go @@ -2,7 +2,6 @@ package v1beta1 import ( "context" - "errors" "github.com/odpf/salt/server" "github.com/odpf/shield/internal/api" @@ -11,44 +10,36 @@ import ( type Handler struct { shieldv1beta1.UnimplementedShieldServiceServer - orgService OrganizationService - projectService ProjectService - groupService GroupService - roleService RoleService - policyService PolicyService - userService UserService - namespaceService NamespaceService - actionService ActionService - relationService RelationService - resourceService ResourceService - ruleService RuleService - identityProxyHeader string + orgService OrganizationService + projectService ProjectService + groupService GroupService + roleService RoleService + policyService PolicyService + userService UserService + namespaceService NamespaceService + actionService ActionService + relationService RelationService + resourceService ResourceService + ruleService RuleService } -var ( - internalServerError = errors.New("internal server error") - badRequestError = errors.New("invalid syntax in body") - permissionDeniedError = errors.New("permission denied") -) - func Register(ctx context.Context, s *server.MuxServer, gw *server.GRPCGateway, deps api.Deps) { gw.RegisterHandler(ctx, shieldv1beta1.RegisterShieldServiceHandlerFromEndpoint) s.RegisterService( &shieldv1beta1.ShieldService_ServiceDesc, &Handler{ - orgService: deps.OrgService, - projectService: deps.ProjectService, - groupService: deps.GroupService, - roleService: deps.RoleService, - policyService: deps.PolicyService, - userService: deps.UserService, - namespaceService: deps.NamespaceService, - actionService: deps.ActionService, - relationService: deps.RelationService, - resourceService: deps.ResourceService, - ruleService: deps.RuleService, - identityProxyHeader: deps.IdentityProxyHeader, + orgService: deps.OrgService, + projectService: deps.ProjectService, + groupService: deps.GroupService, + roleService: deps.RoleService, + policyService: deps.PolicyService, + userService: deps.UserService, + namespaceService: deps.NamespaceService, + actionService: deps.ActionService, + relationService: deps.RelationService, + resourceService: deps.ResourceService, + ruleService: deps.RuleService, }, ) } diff --git a/internal/proxy/director.go b/internal/proxy/director.go index bb25ad90a..85bbd0d91 100644 --- a/internal/proxy/director.go +++ b/internal/proxy/director.go @@ -7,6 +7,7 @@ import ( "strings" "github.com/odpf/shield/internal/proxy/middleware" + "github.com/odpf/shield/pkg/httputil" ) var ctxRequestErrorKey = struct{}{} @@ -41,9 +42,9 @@ func (h Director) Direct(req *http.Request) { } else { req.URL.RawQuery = target.RawQuery + "&" + req.URL.RawQuery } - if _, ok := req.Header["User-Agent"]; !ok { + if _, ok := req.Header[httputil.HeaderUserAgent]; !ok { // explicitly disable User-Agent so it's not set to default value - req.Header.Set("User-Agent", "") + req.Header.Set(httputil.HeaderUserAgent, "") } req.Header.Set("proxy-by", "shield") } diff --git a/internal/proxy/hook/authz/authz.go b/internal/proxy/hook/authz/authz.go index c5c4dc61c..0c6ccac53 100644 --- a/internal/proxy/hook/authz/authz.go +++ b/internal/proxy/hook/authz/authz.go @@ -30,19 +30,18 @@ type Authz struct { // To skip all the next hooks and just respond back escape hook.Service - // TODO need to figure out what best to pass this - identityProxyHeader string + identityProxyHeaderKey string resourceService ResourceService } -func New(log log.Logger, next, escape hook.Service, identityProxyHeader string, resourceService ResourceService) Authz { +func New(log log.Logger, next, escape hook.Service, identityProxyHeaderKey string, resourceService ResourceService) Authz { return Authz{ - log: log, - next: next, - escape: escape, - identityProxyHeader: identityProxyHeader, - resourceService: resourceService, + log: log, + next: next, + escape: escape, + identityProxyHeaderKey: identityProxyHeaderKey, + resourceService: resourceService, } } @@ -85,8 +84,9 @@ func (a Authz) ServeHook(res *http.Response, err error) (*http.Response, error) attributes := map[string]interface{}{} attributes["namespace"] = ruleFromRequest.Backend.Namespace - attributes["user"] = res.Request.Header.Get(a.identityProxyHeader) - res.Request = res.Request.WithContext(user.SetEmailToContext(res.Request.Context(), res.Request.Header.Get(a.identityProxyHeader))) + identityProxyHeaderValue := res.Request.Header.Get(a.identityProxyHeaderKey) + attributes["user"] = identityProxyHeaderValue + res.Request = res.Request.WithContext(user.SetContextWithEmail(res.Request.Context(), identityProxyHeaderValue)) for id, attr := range config.Attributes { bdy, _ := middleware.ExtractRequestBody(res.Request) diff --git a/internal/proxy/middleware/attribute.go b/internal/proxy/middleware/attribute.go new file mode 100644 index 000000000..080efedfe --- /dev/null +++ b/internal/proxy/middleware/attribute.go @@ -0,0 +1,21 @@ +package middleware + +const ( + AttributeTypeQuery AttributeType = "query" + AttributeTypeHeader AttributeType = "header" + AttributeTypeJSONPayload AttributeType = "json_payload" + AttributeTypeGRPCPayload AttributeType = "grpc_payload" + AttributeTypePathParam AttributeType = "path_param" + AttributeTypeConstant AttributeType = "constant" +) + +type AttributeType string + +type Attribute struct { + Key string `yaml:"key" mapstructure:"key"` + Type AttributeType `yaml:"type" mapstructure:"type"` + Index string `yaml:"index" mapstructure:"index"` // proto index + Path string `yaml:"path" mapstructure:"path"` + Params []string `yaml:"params" mapstructure:"params"` + Value string `yaml:"value" mapstructure:"value"` +} diff --git a/internal/proxy/middleware/authz/authz.go b/internal/proxy/middleware/authz/authz.go index de9910dbd..4927f6173 100644 --- a/internal/proxy/middleware/authz/authz.go +++ b/internal/proxy/middleware/authz/authz.go @@ -17,10 +17,6 @@ import ( "github.com/odpf/salt/log" ) -const ( - userIDHeader = "X-Shield-User-Id" -) - type ResourceService interface { CheckAuthz(ctx context.Context, resource resource.Resource, act action.Action) (bool, error) } @@ -30,11 +26,12 @@ type UserService interface { } type Authz struct { - log log.Logger - identityProxyHeader string - next http.Handler - resourceService ResourceService - userService UserService + log log.Logger + identityProxyHeaderKey string + userIDHeaderKey string + next http.Handler + resourceService ResourceService + userService UserService } type Config struct { @@ -42,13 +39,19 @@ type Config struct { Attributes map[string]middleware.Attribute `yaml:"attributes" mapstructure:"attributes"` // auth field -> Attribute } -func New(log log.Logger, next http.Handler, identityProxyHeader string, resourceService ResourceService, userService UserService) *Authz { +func New( + log log.Logger, + next http.Handler, + identityProxyHeaderKey, userIDHeaderKey string, + resourceService ResourceService, + userService UserService) *Authz { return &Authz{ - log: log, - identityProxyHeader: identityProxyHeader, - next: next, - resourceService: resourceService, - userService: userService, + log: log, + identityProxyHeaderKey: identityProxyHeaderKey, + userIDHeaderKey: userIDHeaderKey, + next: next, + resourceService: resourceService, + userService: userService, } } @@ -60,7 +63,7 @@ func (c Authz) Info() *middleware.MiddlewareInfo { } func (c *Authz) ServeHTTP(rw http.ResponseWriter, req *http.Request) { - req = req.WithContext(user.SetEmailToContext(req.Context(), req.Header.Get(c.identityProxyHeader))) + req = req.WithContext(user.SetContextWithEmail(req.Context(), req.Header.Get(c.identityProxyHeaderKey))) usr, err := c.userService.FetchCurrentUser(req.Context()) if err != nil { @@ -69,7 +72,7 @@ func (c *Authz) ServeHTTP(rw http.ResponseWriter, req *http.Request) { return } - req.Header.Set(userIDHeader, usr.ID) + req.Header.Set(c.userIDHeaderKey, usr.ID) rule, ok := middleware.ExtractRule(req) if !ok { @@ -101,7 +104,7 @@ func (c *Authz) ServeHTTP(rw http.ResponseWriter, req *http.Request) { permissionAttributes["namespace"] = rule.Backend.Namespace - permissionAttributes["user"] = req.Header.Get(c.identityProxyHeader) + permissionAttributes["user"] = req.Header.Get(c.identityProxyHeaderKey) for res, attr := range config.Attributes { _ = res @@ -235,7 +238,6 @@ func (c *Authz) ServeHTTP(rw http.ResponseWriter, req *http.Request) { func (w Authz) notAllowed(rw http.ResponseWriter) { rw.WriteHeader(http.StatusUnauthorized) - return } func createResources(permissionAttributes map[string]interface{}) ([]resource.Resource, error) { @@ -276,15 +278,14 @@ func createResources(permissionAttributes map[string]interface{}) ([]resource.Re func getAttributesValues(attributes interface{}) ([]string, error) { var values []string - switch attributes.(type) { + + switch attributes := attributes.(type) { case []string: - for _, i := range attributes.([]string) { - values = append(values, i) - } + values = append(values, attributes...) case string: - values = append(values, attributes.(string)) + values = append(values, attributes) case []interface{}: - for _, i := range attributes.([]interface{}) { + for _, i := range attributes { values = append(values, i.(string)) } case interface{}: diff --git a/internal/proxy/middleware/basic_auth/auth.go b/internal/proxy/middleware/basic_auth/auth.go index 1ccac2cd9..e7e439b33 100644 --- a/internal/proxy/middleware/basic_auth/auth.go +++ b/internal/proxy/middleware/basic_auth/auth.go @@ -9,6 +9,7 @@ import ( "github.com/odpf/shield/internal/proxy/middleware" "github.com/odpf/shield/pkg/body_extractor" + "github.com/odpf/shield/pkg/httputil" goauth "github.com/abbot/go-http-auth" "github.com/mitchellh/mapstructure" @@ -95,7 +96,7 @@ func (w *BasicAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) { var authedUser string if authedUser = authenticator.CheckAuth(req); authedUser != "" { - req.Header.Set("X-User", authedUser) + req.Header.Set(httputil.HeaderXUser, authedUser) } else { w.notAllowed(rw) return diff --git a/internal/proxy/middleware/context.go b/internal/proxy/middleware/context.go new file mode 100644 index 000000000..0fc753e1f --- /dev/null +++ b/internal/proxy/middleware/context.go @@ -0,0 +1,56 @@ +package middleware + +import ( + "bytes" + "io" + "io/ioutil" + "net/http" + + "github.com/odpf/shield/core/rule" + "github.com/odpf/shield/pkg/httputil" +) + +func EnrichRule(req *http.Request, r *rule.Rule) { + *req = *req.WithContext(rule.WithContext(req.Context(), r)) +} + +func EnrichRequestBody(r *http.Request) error { + reqBody, err := ioutil.ReadAll(r.Body) + if err != nil { + return err + } + defer (r.Body).Close() + + // repopulate body + (*r).Body = ioutil.NopCloser(bytes.NewBuffer(reqBody)) + *r = *r.WithContext(httputil.SetContextWithRequestBody(r.Context(), reqBody)) + return nil +} + +func ExtractRequestBody(r *http.Request) (io.ReadCloser, bool) { + body, ok := httputil.GetRequestBodyFromContext(r.Context()) + if !ok { + return nil, false + } + return ioutil.NopCloser(bytes.NewBuffer(body)), true +} + +func ExtractRule(r *http.Request) (*rule.Rule, bool) { + return rule.GetFromContext(r.Context()) +} + +func ExtractMiddleware(r *http.Request, name string) (rule.MiddlewareSpec, bool) { + rl, ok := ExtractRule(r) + if !ok { + return rule.MiddlewareSpec{}, false + } + return rl.Middlewares.Get(name) +} + +func EnrichPathParams(r *http.Request, params map[string]string) { + *r = *r.WithContext(httputil.SetContextWithPathParams(r.Context(), params)) +} + +func ExtractPathParams(r *http.Request) (map[string]string, bool) { + return httputil.GetPathParamsFromContext(r.Context()) +} diff --git a/internal/proxy/middleware/middleware.go b/internal/proxy/middleware/middleware.go index 1f0a0eceb..3393d0f73 100644 --- a/internal/proxy/middleware/middleware.go +++ b/internal/proxy/middleware/middleware.go @@ -1,21 +1,9 @@ package middleware import ( - "bytes" - "context" "fmt" - "io" - "io/ioutil" "net/http" "time" - - "github.com/odpf/shield/core/rule" -) - -const ( - ctxRuleKey = "middleware_rule" - ctxPathParamsKey = "path_params" - ctxBodyKey = "body_ctx" ) type Middleware interface { @@ -28,76 +16,6 @@ type MiddlewareInfo struct { Description string } -func EnrichRule(r *http.Request, rule *rule.Rule) { - *r = *r.WithContext(context.WithValue(r.Context(), ctxRuleKey, rule)) -} - -func EnrichRequestBody(r *http.Request) error { - reqBody, err := ioutil.ReadAll(r.Body) - if err != nil { - return err - } - defer (r.Body).Close() - - // repopulate body - (*r).Body = ioutil.NopCloser(bytes.NewBuffer(reqBody)) - *r = *r.WithContext(context.WithValue(r.Context(), ctxBodyKey, reqBody)) - return nil -} - -func ExtractRequestBody(r *http.Request) (io.ReadCloser, bool) { - body, ok := r.Context().Value(ctxBodyKey).([]byte) - if !ok { - return nil, false - } - return ioutil.NopCloser(bytes.NewBuffer(body)), true -} - -func ExtractRule(r *http.Request) (*rule.Rule, bool) { - rl, ok := r.Context().Value(ctxRuleKey).(*rule.Rule) - return rl, ok -} - -func ExtractMiddleware(r *http.Request, name string) (rule.MiddlewareSpec, bool) { - rl, ok := r.Context().Value(ctxRuleKey).(*rule.Rule) - if !ok { - return rule.MiddlewareSpec{}, false - } - return rl.Middlewares.Get(name) -} - -func EnrichPathParams(r *http.Request, params map[string]string) { - *r = *r.WithContext(context.WithValue(r.Context(), ctxPathParamsKey, params)) -} - -func ExtractPathParams(r *http.Request) (map[string]string, bool) { - params, ok := r.Context().Value(ctxPathParamsKey).(map[string]string) - if !ok { - return nil, false - } - return params, true -} - -const ( - AttributeTypeQuery AttributeType = "query" - AttributeTypeHeader AttributeType = "header" - AttributeTypeJSONPayload AttributeType = "json_payload" - AttributeTypeGRPCPayload AttributeType = "grpc_payload" - AttributeTypePathParam AttributeType = "path_param" - AttributeTypeConstant AttributeType = "constant" -) - -type AttributeType string - -type Attribute struct { - Key string `yaml:"key" mapstructure:"key"` - Type AttributeType `yaml:"type" mapstructure:"type"` - Index string `yaml:"index" mapstructure:"index"` // proto index - Path string `yaml:"path" mapstructure:"path"` - Params []string `yaml:"params" mapstructure:"params"` - Value string `yaml:"value" mapstructure:"value"` -} - func Elapsed(what string) func() { start := time.Now() return func() { diff --git a/internal/server/config.go b/internal/server/config.go index dd51012cd..e09b5b631 100644 --- a/internal/server/config.go +++ b/internal/server/config.go @@ -15,7 +15,7 @@ type Config struct { // to access RulesPath files RulesPathSecret string `yaml:"ruleset_secret" mapstructure:"ruleset_secret"` - // TODO might not suitable here + // TODO might not suitable here because it is also being used by proxy // Headers which will have user's email id IdentityProxyHeader string `yaml:"identity_proxy_header" mapstructure:"identity_proxy_header" default:"X-Shield-Email"` diff --git a/pkg/grpc_interceptors/grpc_interceptors.go b/internal/server/grpc_interceptors/grpc_interceptors.go similarity index 73% rename from pkg/grpc_interceptors/grpc_interceptors.go rename to internal/server/grpc_interceptors/grpc_interceptors.go index 4d34db695..a0929bf57 100644 --- a/pkg/grpc_interceptors/grpc_interceptors.go +++ b/internal/server/grpc_interceptors/grpc_interceptors.go @@ -4,14 +4,11 @@ import ( "context" "fmt" + "github.com/odpf/shield/core/user" "google.golang.org/grpc" "google.golang.org/grpc/metadata" ) -const ( - identityCtx = "identityCtx" -) - func EnrichCtxWithIdentity(identityHeader string) grpc.UnaryServerInterceptor { return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { md, ok := metadata.FromIncomingContext(ctx) @@ -25,12 +22,7 @@ func EnrichCtxWithIdentity(identityHeader string) grpc.UnaryServerInterceptor { email = metadataValues[0] } - ctx = context.WithValue(ctx, identityCtx, email) + ctx = user.SetContextWithEmail(ctx, email) return handler(ctx, req) } } - -func GetIdentityHeader(ctx context.Context) (string, bool) { - identity, ok := ctx.Value(identityCtx).(string) - return identity, ok -} diff --git a/internal/server/server.go b/internal/server/server.go index 66daca2cc..26f174513 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -17,7 +17,7 @@ import ( "github.com/odpf/salt/server" "github.com/odpf/shield/internal/api" "github.com/odpf/shield/internal/api/v1beta1" - "github.com/odpf/shield/pkg/grpc_interceptors" + "github.com/odpf/shield/internal/server/grpc_interceptors" "go.uber.org/zap" "google.golang.org/grpc" "google.golang.org/grpc/codes" diff --git a/pkg/httputil/context.go b/pkg/httputil/context.go new file mode 100644 index 000000000..803b64de3 --- /dev/null +++ b/pkg/httputil/context.go @@ -0,0 +1,28 @@ +package httputil + +import ( + "context" +) + +type ( + contextRequestBodyKey struct{} + contextPathParamsKey struct{} +) + +func SetContextWithRequestBody(ctx context.Context, body []byte) context.Context { + return context.WithValue(ctx, contextRequestBodyKey{}, body) +} + +func GetRequestBodyFromContext(ctx context.Context) ([]byte, bool) { + body, ok := ctx.Value(contextRequestBodyKey{}).([]byte) + return body, ok +} + +func SetContextWithPathParams(ctx context.Context, params map[string]string) context.Context { + return context.WithValue(ctx, contextPathParamsKey{}, params) +} + +func GetPathParamsFromContext(ctx context.Context) (map[string]string, bool) { + params, ok := ctx.Value(contextPathParamsKey{}).(map[string]string) + return params, ok +} diff --git a/pkg/httputil/header.go b/pkg/httputil/header.go new file mode 100644 index 000000000..fbac59083 --- /dev/null +++ b/pkg/httputil/header.go @@ -0,0 +1,6 @@ +package httputil + +const ( + HeaderUserAgent = "User-Agent" + HeaderXUser = "X-User" +) diff --git a/pkg/metadata/metadata.go b/pkg/metadata/metadata.go new file mode 100644 index 000000000..57976302f --- /dev/null +++ b/pkg/metadata/metadata.go @@ -0,0 +1,34 @@ +package metadata + +import ( + "fmt" + + "google.golang.org/protobuf/types/known/structpb" +) + +type Metadata map[string]any + +func (m Metadata) ToStructPB() (*structpb.Struct, error) { + newMap := make(map[string]interface{}) + + for key, value := range m { + newMap[key] = value + } + + return structpb.NewStruct(newMap) +} + +func Build(m map[string]interface{}) (Metadata, error) { + newMap := make(Metadata) + + for key, value := range m { + switch value := value.(type) { + case any: + newMap[key] = value + default: + return Metadata{}, fmt.Errorf("value for %s key is not string", key) + } + } + + return newMap, nil +} diff --git a/pkg/str/slug.go b/pkg/str/slug.go new file mode 100644 index 000000000..e635cafa6 --- /dev/null +++ b/pkg/str/slug.go @@ -0,0 +1,32 @@ +package str + +import "strings" + +type SlugifyOptions struct { + KeepHyphen bool + KeepColon bool + KeepHash bool +} + +func Slugify(str string, options SlugifyOptions) string { + str = strings.ToLower(str) + str = strings.ReplaceAll(str, " ", "_") + if !options.KeepHyphen { + str = strings.ReplaceAll(str, "-", "_") + } + if !options.KeepColon { + str = strings.ReplaceAll(str, ":", "_") + } + if !options.KeepHash { + str = strings.ReplaceAll(str, "#", "_") + } + return str +} + +func GenerateSlug(name string) string { + preProcessed := strings.ReplaceAll(strings.TrimSpace(strings.TrimSpace(name)), "_", "-") + return strings.Join( + strings.Split(preProcessed, " "), + "-", + ) +} diff --git a/pkg/str/utils.go b/pkg/str/utils.go index f76029cc3..25e427e81 100644 --- a/pkg/str/utils.go +++ b/pkg/str/utils.go @@ -1,33 +1,8 @@ package str -import ( - "strings" -) - func DefaultStringIfEmpty(str string, defaultString string) string { if str != "" { return str } return defaultString } - -type SlugifyOptions struct { - KeepHyphen bool - KeepColon bool - KeepHash bool -} - -func Slugify(str string, options SlugifyOptions) string { - str = strings.ToLower(str) - str = strings.ReplaceAll(str, " ", "_") - if !options.KeepHyphen { - str = strings.ReplaceAll(str, "-", "_") - } - if !options.KeepColon { - str = strings.ReplaceAll(str, ":", "_") - } - if !options.KeepHash { - str = strings.ReplaceAll(str, "#", "_") - } - return str -}