Skip to content

Commit

Permalink
refactor: proxy middleware (#135)
Browse files Browse the repository at this point in the history
* refactor: proxy middleware

* fix: update coverage badge

* fix: slug generation

* refactor: update create urn to use function receiver

* refactor: clean up identity proxy header
  • Loading branch information
mabdh authored Aug 4, 2022
1 parent 591ea0d commit 4e280ae
Show file tree
Hide file tree
Showing 48 changed files with 537 additions and 520 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
![package workflow](https://github.com/odpf/shield/actions/workflows/release.yml/badge.svg)
[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg?logo=apache)](LICENSE)
[![Version](https://img.shields.io/github/v/release/odpf/shield?logo=semantic-release)](Version)
[![Coverage Status](https://coveralls.io/repos/github/odpf/shield/badge.svg?branch=main)]
[![Coverage Status](https://coveralls.io/repos/github/odpf/shield/badge.svg?branch=main)](https://coveralls.io/github/odpf/shield?branch=main)

Shield is a cloud native role-based authorization aware reverse-proxy service. With Shield, you can assign roles to users or groups of users to configure policies that determine whether a particular user has the ability to perform a certain action on a given resource.

Expand Down
28 changes: 13 additions & 15 deletions cmd/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,13 @@ func serve(logger log.Logger, cfg *config.Shield) error {
return err
}

deps, err := buildAPIDependencies(ctx, logger, cfg.App.IdentityProxyHeader, resourceBlobRepository, dbClient, spiceDBClient)
deps, err := buildAPIDependencies(ctx, logger, resourceBlobRepository, dbClient, spiceDBClient)
if err != nil {
return err
}

// serving proxies
cbs, cps, err := serveProxies(ctx, logger, cfg.App.IdentityProxyHeader, cfg.Proxy, deps.ResourceService, deps.UserService)
cbs, cps, err := serveProxies(ctx, logger, cfg.App.IdentityProxyHeader, cfg.App.UserIDHeader, cfg.Proxy, deps.ResourceService, deps.UserService)
if err != nil {
return err
}
Expand Down Expand Up @@ -160,7 +160,6 @@ func serve(logger log.Logger, cfg *config.Shield) error {
func buildAPIDependencies(
ctx context.Context,
logger log.Logger,
identityProxyHeader string,
resourceBlobRepository *blob.ResourcesRepository,
dbc *db.Client,
sdb *spicedb.SpiceDB,
Expand All @@ -172,7 +171,7 @@ func buildAPIDependencies(
namespaceService := namespace.NewService(namespaceRepository)

userRepository := postgres.NewUserRepository(dbc)
userService := user.NewService(identityProxyHeader, userRepository)
userService := user.NewService(userRepository)

relationPGRepository := postgres.NewRelationRepository(dbc)
relationSpiceRepository := spicedb.NewRelationRepository(sdb)
Expand Down Expand Up @@ -216,17 +215,16 @@ func buildAPIDependencies(
}

dependencies := api.Deps{
OrgService: organizationService,
UserService: userService,
ProjectService: projectService,
GroupService: groupService,
RelationService: relationService,
ResourceService: resourceService,
RoleService: roleService,
PolicyService: policyService,
ActionService: actionService,
NamespaceService: namespaceService,
IdentityProxyHeader: identityProxyHeader,
OrgService: organizationService,
UserService: userService,
ProjectService: projectService,
GroupService: groupService,
RelationService: relationService,
ResourceService: resourceService,
RoleService: roleService,
PolicyService: policyService,
ActionService: actionService,
NamespaceService: namespaceService,
}
return dependencies, nil
}
Expand Down
14 changes: 7 additions & 7 deletions cmd/serve_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import (
func serveProxies(
ctx context.Context,
logger log.Logger,
identityProxyHeader string,
identityProxyHeaderKey, userIDHeaderKey string,
cfg proxy.ServicesConfig,
resourceService *resource.Service,
userService *user.Service,
Expand All @@ -32,7 +32,7 @@ func serveProxies(
var cleanUpProxies []func(ctx context.Context) error

for _, svcConfig := range cfg.Services {
hookPipeline := buildHookPipeline(logger, identityProxyHeader, resourceService)
hookPipeline := buildHookPipeline(logger, identityProxyHeaderKey, resourceService)

h2cProxy := proxy.NewH2c(
proxy.NewH2cRoundTripper(logger, hookPipeline),
Expand All @@ -57,7 +57,7 @@ func serveProxies(

ruleService := rule.NewService(ruleBlobRepository)

middlewarePipeline := buildMiddlewarePipeline(logger, h2cProxy, identityProxyHeader, resourceService, userService, ruleService)
middlewarePipeline := buildMiddlewarePipeline(logger, h2cProxy, identityProxyHeaderKey, userIDHeaderKey, resourceService, userService, ruleService)

cps := proxy.Serve(ctx, logger, svcConfig, middlewarePipeline)
cleanUpProxies = append(cleanUpProxies, cps)
Expand All @@ -67,23 +67,23 @@ func serveProxies(
return cleanUpBlobs, cleanUpProxies, nil
}

func buildHookPipeline(log log.Logger, identityProxyHeader string, resourceService v1beta1.ResourceService) hook.Service {
func buildHookPipeline(log log.Logger, identityProxyHeaderKey string, resourceService v1beta1.ResourceService) hook.Service {
rootHook := hook.New()
return authz_hook.New(log, rootHook, rootHook, identityProxyHeader, resourceService)
return authz_hook.New(log, rootHook, rootHook, identityProxyHeaderKey, resourceService)
}

// buildPipeline builds middleware sequence
func buildMiddlewarePipeline(
logger log.Logger,
proxy http.Handler,
identityProxyHeader string,
identityProxyHeaderKey, userIDHeaderKey string,
resourceService *resource.Service,
userService *user.Service,
ruleService *rule.Service,
) http.Handler {
// Note: execution order is bottom up
prefixWare := prefix.New(logger, proxy)
casbinAuthz := authz.New(logger, prefixWare, identityProxyHeader, resourceService, userService)
casbinAuthz := authz.New(logger, prefixWare, identityProxyHeaderKey, userIDHeaderKey, resourceService, userService)
basicAuthn := basic_auth.New(logger, casbinAuthz)
matchWare := rulematch.New(logger, basicAuthn, rulematch.NewRouteMatcher(ruleService))
return matchWare
Expand Down
3 changes: 2 additions & 1 deletion core/group/group.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/odpf/shield/core/organization"
"github.com/odpf/shield/core/relation"
"github.com/odpf/shield/core/user"
"github.com/odpf/shield/pkg/metadata"
)

type Repository interface {
Expand All @@ -29,7 +30,7 @@ type Group struct {
Slug string
Organization organization.Organization
OrganizationID string `json:"orgId"`
Metadata map[string]any
Metadata metadata.Metadata
CreatedAt time.Time
UpdatedAt time.Time
}
3 changes: 2 additions & 1 deletion core/organization/organization.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"time"

"github.com/odpf/shield/core/user"
"github.com/odpf/shield/pkg/metadata"
)

type Repository interface {
Expand All @@ -21,7 +22,7 @@ type Organization struct {
ID string
Name string
Slug string
Metadata map[string]any
Metadata metadata.Metadata
CreatedAt time.Time
UpdatedAt time.Time
}
3 changes: 2 additions & 1 deletion core/project/project.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

"github.com/odpf/shield/core/organization"
"github.com/odpf/shield/core/user"
"github.com/odpf/shield/pkg/metadata"
)

type Repository interface {
Expand All @@ -23,7 +24,7 @@ type Project struct {
Name string
Slug string
Organization organization.Organization
Metadata map[string]any
Metadata metadata.Metadata
CreatedAt time.Time
UpdatedAt time.Time
}
26 changes: 13 additions & 13 deletions core/resource/resource.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,22 +45,10 @@ type Resource struct {
UpdatedAt time.Time
}

type Filter struct {
ProjectID string
GroupID string
OrganizationID string
NamespaceID string
}

type YAML struct {
Name string `json:"name" yaml:"name"`
Actions map[string][]string `json:"actions" yaml:"actions"`
}

/*
/project/uuid/
*/
func CreateURN(res Resource) string {
func (res Resource) CreateURN() string {
isSystemNS := namespace.IsSystemNamespaceID(res.NamespaceID)
if isSystemNS {
return res.Name
Expand All @@ -70,3 +58,15 @@ func CreateURN(res Resource) string {
}
return fmt.Sprintf("r/%s/%s", res.NamespaceID, res.Name)
}

type Filter struct {
ProjectID string
GroupID string
OrganizationID string
NamespaceID string
}

type YAML struct {
Name string `json:"name" yaml:"name"`
Actions map[string][]string `json:"actions" yaml:"actions"`
}
4 changes: 2 additions & 2 deletions core/resource/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func (s Service) Get(ctx context.Context, id string) (Resource, error) {
}

func (s Service) Create(ctx context.Context, res Resource) (Resource, error) {
urn := CreateURN(res)
urn := res.CreateURN()

usr, err := s.userService.FetchCurrentUser(ctx)
if err != nil {
Expand Down Expand Up @@ -218,7 +218,7 @@ func (s Service) CheckAuthz(ctx context.Context, res Resource, act action.Action
return false, err
}

res.URN = CreateURN(res)
res.URN = res.CreateURN()

isSystemNS := namespace.IsSystemNamespaceID(res.NamespaceID)
fetchedResource := res
Expand Down
3 changes: 2 additions & 1 deletion core/role/role.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"time"

"github.com/odpf/shield/core/namespace"
"github.com/odpf/shield/pkg/metadata"
)

type Repository interface {
Expand All @@ -22,7 +23,7 @@ type Role struct {
Types []string
Namespace namespace.Namespace
NamespaceID string
Metadata map[string]any
Metadata metadata.Metadata
CreatedAt time.Time
UpdatedAt time.Time
}
Expand Down
16 changes: 16 additions & 0 deletions core/rule/context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package rule

import (
"context"
)

type contextRuleKey struct{}

func WithContext(ctx context.Context, rule *Rule) context.Context {
return context.WithValue(ctx, contextRuleKey{}, rule)
}

func GetFromContext(ctx context.Context) (*Rule, bool) {
rl, ok := ctx.Value(contextRuleKey{}).(*Rule)
return rl, ok
}
14 changes: 7 additions & 7 deletions core/rule/rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,6 @@ type MiddlewareSpec struct {

type MiddlewareSpecs []MiddlewareSpec

type HookSpec struct {
Name string `yaml:"name"`
Config map[string]interface{} `yaml:"config"`
}

type HookSpecs []HookSpec

func (m MiddlewareSpecs) Get(name string) (MiddlewareSpec, bool) {
for _, n := range m {
if n.Name == name {
Expand All @@ -48,6 +41,13 @@ func (m MiddlewareSpecs) Get(name string) (MiddlewareSpec, bool) {
return MiddlewareSpec{}, false
}

type HookSpec struct {
Name string `yaml:"name"`
Config map[string]interface{} `yaml:"config"`
}

type HookSpecs []HookSpec

func (m HookSpecs) Get(name string) (HookSpec, bool) {
for _, n := range m {
if n.Name == name {
Expand Down
14 changes: 14 additions & 0 deletions core/user/context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package user

import "context"

type contextEmailKey struct{}

func SetContextWithEmail(ctx context.Context, email string) context.Context {
return context.WithValue(ctx, contextEmailKey{}, email)
}

func GetEmailFromContext(ctx context.Context) (string, bool) {
email, ok := ctx.Value(contextEmailKey{}).(string)
return email, ok
}
50 changes: 7 additions & 43 deletions core/user/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,16 @@ package user

import (
"context"
"fmt"

"google.golang.org/grpc/metadata"
"errors"
)

var emailContext = struct{}{}

type Service struct {
repository Repository
identityProxyHeader string
repository Repository
}

func NewService(identityProxyHeader string, repository Repository) *Service {
func NewService(repository Repository) *Service {
return &Service{
identityProxyHeader: identityProxyHeader,
repository: repository,
repository: repository,
}
}

Expand Down Expand Up @@ -67,9 +61,9 @@ func (s Service) UpdateByEmail(ctx context.Context, toUpdate User) (User, error)
}

func (s Service) FetchCurrentUser(ctx context.Context) (User, error) {
email, err := fetchEmailFromMetadata(ctx, s.identityProxyHeader)
if err != nil {
return User{}, err
email, ok := GetEmailFromContext(ctx)
if !ok {
return User{}, errors.New("unable to fetch email from context")
}

fetchedUser, err := s.repository.GetByEmail(ctx, email)
Expand All @@ -79,33 +73,3 @@ func (s Service) FetchCurrentUser(ctx context.Context) (User, error) {

return fetchedUser, nil
}

// TODO need to simplify this, service package should not depend on grpc metadata
func fetchEmailFromMetadata(ctx context.Context, headerKey string) (string, error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
val, ok := GetEmailFromContext(ctx)
if !ok {
return "", fmt.Errorf("unable to fetch context from incoming")
}

return val, nil
}

var email string
metadataValues := md.Get(headerKey)
if len(metadataValues) > 0 {
email = metadataValues[0]
}
return email, nil
}

func SetEmailToContext(ctx context.Context, email string) context.Context {
return context.WithValue(ctx, emailContext, email)
}

func GetEmailFromContext(ctx context.Context) (string, bool) {
val, ok := ctx.Value(emailContext).(string)

return val, ok
}
Loading

0 comments on commit 4e280ae

Please sign in to comment.