Skip to content

Commit

Permalink
feat: principal check before create relation
Browse files Browse the repository at this point in the history
  • Loading branch information
FemiNoviaLina committed Oct 3, 2024
1 parent d141af5 commit ae2a0d9
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 17 deletions.
33 changes: 16 additions & 17 deletions core/user/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"github.com/goto/shield/core/activity"
"github.com/goto/shield/core/user"
"github.com/goto/shield/core/user/mocks"
"github.com/goto/shield/pkg/logger"
shieldlogger "github.com/goto/shield/pkg/logger"
"github.com/goto/shield/pkg/uuid"
)
Expand Down Expand Up @@ -39,7 +38,7 @@ func TestService_Create(t *testing.T) {
t.Helper()
repository := &mocks.Repository{}
activityService := &mocks.ActivityService{}
logger := shieldlogger.InitLogger(logger.Config{})
logger := shieldlogger.InitLogger(shieldlogger.Config{})
repository.EXPECT().
GetByEmail(mock.Anything, "[email protected]").
Return(user.User{
Expand Down Expand Up @@ -80,7 +79,7 @@ func TestService_Create(t *testing.T) {
t.Helper()
repository := &mocks.Repository{}
activityService := &mocks.ActivityService{}
logger := shieldlogger.InitLogger(logger.Config{})
logger := shieldlogger.InitLogger(shieldlogger.Config{})
repository.EXPECT().
GetByEmail(mock.Anything, "[email protected]").
Return(user.User{}, user.ErrNotExist)
Expand Down Expand Up @@ -117,7 +116,7 @@ func TestService_Create(t *testing.T) {
t.Helper()
repository := &mocks.Repository{}
activityService := &mocks.ActivityService{}
logger := shieldlogger.InitLogger(logger.Config{})
logger := shieldlogger.InitLogger(shieldlogger.Config{})
repository.EXPECT().
GetByEmail(mock.Anything, "[email protected]").
Return(user.User{}, user.ErrNotExist)
Expand All @@ -135,7 +134,7 @@ func TestService_Create(t *testing.T) {
t.Helper()
repository := &mocks.Repository{}
activityService := &mocks.ActivityService{}
logger := shieldlogger.InitLogger(logger.Config{})
logger := shieldlogger.InitLogger(shieldlogger.Config{})
repository.EXPECT().
GetByEmail(mock.Anything, "").
Return(user.User{}, user.ErrMissingEmail)
Expand Down Expand Up @@ -188,7 +187,7 @@ func TestService_CreateMetadataKey(t *testing.T) {
t.Helper()
repository := &mocks.Repository{}
activityService := &mocks.ActivityService{}
logger := shieldlogger.InitLogger(logger.Config{})
logger := shieldlogger.InitLogger(shieldlogger.Config{})
repository.EXPECT().
GetByEmail(mock.Anything, "[email protected]").
Return(user.User{}, nil)
Expand Down Expand Up @@ -223,7 +222,7 @@ func TestService_CreateMetadataKey(t *testing.T) {
},
setup: func(t *testing.T) *user.Service {
t.Helper()
logger := shieldlogger.InitLogger(logger.Config{})
logger := shieldlogger.InitLogger(shieldlogger.Config{})
activityService := &mocks.ActivityService{}
repository := &mocks.Repository{}
repository.EXPECT().
Expand Down Expand Up @@ -280,7 +279,7 @@ func TestService_List(t *testing.T) {
t.Helper()
repository := &mocks.Repository{}
activityService := &mocks.ActivityService{}
logger := shieldlogger.InitLogger(logger.Config{})
logger := shieldlogger.InitLogger(shieldlogger.Config{})
repository.EXPECT().
GetByEmail(mock.Anything, "[email protected]").
Return(user.User{}, nil)
Expand Down Expand Up @@ -315,7 +314,7 @@ func TestService_List(t *testing.T) {
t.Helper()
repository := &mocks.Repository{}
activityService := &mocks.ActivityService{}
logger := shieldlogger.InitLogger(logger.Config{})
logger := shieldlogger.InitLogger(shieldlogger.Config{})
repository.EXPECT().
GetByEmail(mock.Anything, "[email protected]").
Return(user.User{}, nil)
Expand Down Expand Up @@ -375,7 +374,7 @@ func TestService_UpdateByID(t *testing.T) {
t.Helper()
repository := &mocks.Repository{}
activityService := &mocks.ActivityService{}
logger := shieldlogger.InitLogger(logger.Config{})
logger := shieldlogger.InitLogger(shieldlogger.Config{})
repository.EXPECT().
GetByEmail(mock.Anything, "[email protected]").
Return(user.User{}, nil)
Expand Down Expand Up @@ -446,7 +445,7 @@ func TestService_UpdateByEmail(t *testing.T) {
t.Helper()
repository := &mocks.Repository{}
activityService := &mocks.ActivityService{}
logger := shieldlogger.InitLogger(logger.Config{})
logger := shieldlogger.InitLogger(shieldlogger.Config{})
repository.EXPECT().
GetByEmail(mock.Anything, "[email protected]").
Return(user.User{}, nil)
Expand Down Expand Up @@ -511,7 +510,7 @@ func TestService_GetByEmail(t *testing.T) {
t.Helper()
repository := &mocks.Repository{}
activityService := &mocks.ActivityService{}
logger := shieldlogger.InitLogger(logger.Config{})
logger := shieldlogger.InitLogger(shieldlogger.Config{})
repository.EXPECT().
GetByEmail(mock.Anything, "[email protected]").
Return(user.User{
Expand Down Expand Up @@ -571,7 +570,7 @@ func TestService_GetByID(t *testing.T) {
t.Helper()
repository := &mocks.Repository{}
activityService := &mocks.ActivityService{}
logger := shieldlogger.InitLogger(logger.Config{})
logger := shieldlogger.InitLogger(shieldlogger.Config{})
repository.EXPECT().
GetByID(mock.Anything, "qwer-1234-tyui-5678-opas-90").
Return(user.User{
Expand Down Expand Up @@ -631,7 +630,7 @@ func TestService_FetchCurrentUser(t *testing.T) {
t.Helper()
repository := &mocks.Repository{}
activityService := &mocks.ActivityService{}
logger := shieldlogger.InitLogger(logger.Config{})
logger := shieldlogger.InitLogger(shieldlogger.Config{})
repository.EXPECT().
GetByEmail(mock.Anything, "[email protected]").
Return(user.User{
Expand All @@ -657,7 +656,7 @@ func TestService_FetchCurrentUser(t *testing.T) {
t.Helper()
repository := &mocks.Repository{}
activityService := &mocks.ActivityService{}
logger := shieldlogger.InitLogger(logger.Config{})
logger := shieldlogger.InitLogger(shieldlogger.Config{})
return user.NewService(logger, user.AppConfig{}, repository, activityService)
},
wantErr: user.ErrMissingEmail,
Expand Down Expand Up @@ -702,7 +701,7 @@ func TestService_DeleteUser(t *testing.T) {
name: "return error from delete by id",
setup: func(t *testing.T) *user.Service {
t.Helper()
logger := shieldlogger.InitLogger(logger.Config{})
logger := shieldlogger.InitLogger(shieldlogger.Config{})
repository := &mocks.Repository{}
activityService := &mocks.ActivityService{}
repository.EXPECT().
Expand All @@ -717,7 +716,7 @@ func TestService_DeleteUser(t *testing.T) {
name: "return error from delete by email",
setup: func(t *testing.T) *user.Service {
t.Helper()
logger := shieldlogger.InitLogger(logger.Config{})
logger := shieldlogger.InitLogger(shieldlogger.Config{})
repository := &mocks.Repository{}
activityService := &mocks.ActivityService{}
repository.EXPECT().
Expand Down
17 changes: 17 additions & 0 deletions internal/api/v1beta1/relation.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"strings"

"github.com/goto/shield/core/action"
"github.com/goto/shield/core/group"
"github.com/goto/shield/core/namespace"
"github.com/goto/shield/core/resource"
"github.com/goto/shield/core/user"
Expand Down Expand Up @@ -77,6 +78,22 @@ func (h Handler) CreateRelation(ctx context.Context, request *shieldv1beta1.Crea
}

principal, subjectID := extractSubjectFromPrincipal(request.GetBody().GetSubject())
var err error
switch principal {
case strings.Split(schema.UserPrincipal, "/")[1]:
_, err = h.userService.Get(ctx, subjectID)
case strings.Split(schema.GroupPrincipal, "/")[1]:
_, err = h.groupService.Get(ctx, subjectID)
}
if err != nil {
switch {
case errors.Is(err, user.ErrNotExist), errors.Is(err, group.ErrNotExist):
logger.Error(err.Error())
return nil, grpcBadBodyError
default:
return nil, grpcInternalServerError
}
}

result, err := h.resourceService.CheckAuthz(ctx, resource.Resource{
Name: request.GetBody().GetObjectId(),
Expand Down

0 comments on commit ae2a0d9

Please sign in to comment.