Skip to content

Commit

Permalink
support modifying groups claim
Browse files Browse the repository at this point in the history
  • Loading branch information
markafarrell committed Nov 11, 2024
1 parent f5581de commit f94b2dd
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 2 deletions.
12 changes: 12 additions & 0 deletions ansible_base/authentication/authenticator_plugins/oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,14 @@ class OpenIdConnectConfiguration(BaseAuthenticatorConfiguration):
ui_field_label=_("Username Key"),
)

GROUPS_CLAIM = CharField(
help_text=_("The JSON key used to extract the user's groups from the ID token or userinfo endpoint."),
required=False,
allow_null=True,
default="Group",
ui_field_label=_("Groups Claim"),
)


class AuthenticatorPlugin(SocialAuthMixin, OpenIdConnectAuth, AbstractAuthenticatorPlugin):
configuration_class = OpenIdConnectConfiguration
Expand All @@ -209,6 +217,10 @@ class AuthenticatorPlugin(SocialAuthMixin, OpenIdConnectAuth, AbstractAuthentica
category = "sso"
configuration_encrypted_fields = ['SECRET']

@property
def groups_claim(self):
return self.setting('GROUPS_CLAIM')

def extra_data(self, user, backend, response, *args, **kwargs):
for perm in ["is_superuser", get_setting('ANSIBLE_BASE_SOCIAL_AUDITOR_FLAG')]:
if perm in response:
Expand Down
5 changes: 4 additions & 1 deletion ansible_base/authentication/social_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ def __init__(self, storage, request=None, tpl=None, additional_settings={}):
class SocialAuthMixin:
configuration_encrypted_fields = []
logger = None
groups_claim = "Group"

def __init__(self, *args, **kwargs):
# social auth expects the first arg to be a strategy instance. Since this has
Expand Down Expand Up @@ -190,7 +191,9 @@ def validate(self, serializer, data):
def create_user_claims_pipeline(*args, backend, response, **kwargs):
from ansible_base.authentication.utils.claims import update_user_claims

extra_groups = response["Group"] if "Group" in response else None
groups_claim = backend.groups_claim if backend.groups_claim is not None else "Group"

extra_groups = response[groups_claim] if groups_claim in response else []
user = update_user_claims(kwargs["user"], backend.database_instance, backend.get_user_groups(extra_groups))
if user is None:
return SOCIAL_AUTH_PIPELINE_FAILED_STATUS
56 changes: 55 additions & 1 deletion test_app/tests/authentication/test_social_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@
from django.conf import settings
from django.test import override_settings

from ansible_base.authentication.social_auth import AuthenticatorStorage, AuthenticatorStrategy, SocialAuthValidateCallbackMixin
from ansible_base.authentication.social_auth import (
AuthenticatorStorage,
AuthenticatorStrategy,
SocialAuthMixin,
SocialAuthValidateCallbackMixin,
create_user_claims_pipeline,
)


@mock.patch("ansible_base.authentication.social_auth.logger")
Expand Down Expand Up @@ -75,3 +81,51 @@ def test_social_auth_validate_callback_mixin(mocked_generate_slug, mocked_revers
# should always call reverse if no callback url
if has_instance and 'configuration' in test_data and not test_data.get('configuration', {}).get('CALLBACK_URL'):
assert mocked_reverse.called


@pytest.mark.parametrize(
"groups_claim,returned_groups,expected_groups",
[
(None, ["mygroup"], ["mygroup"]),
("groups", ["mygroup"], ["mygroup"]),
(None, None, []),
("groups", None, []),
],
)
@mock.patch("ansible_base.authentication.utils.claims.update_user_claims")
def test_create_user_claims_pipeline(mock_update_user_claims, groups_claim, returned_groups, expected_groups):
'''
We are testing to see if extracting groups from a claim is working correctly
'''

class MockBackend(SocialAuthMixin):
database_instance = None

def __init__(self, groups_claim=None):
if groups_claim is not None:
self.groups_claim = groups_claim

def get_user_groups(self, extra_groups=[]):
return extra_groups

backend = MockBackend(groups_claim=groups_claim)

rData = {}
if returned_groups is not None:
rData[backend.groups_claim] = returned_groups

user = {
'auth_time': "2024-11-07T05:19:08.224936Z",
'id_token': "asdf",
'refresh_token': None,
'id': "ccd2cf13-d927-41ad-cd8c-adb18b2e5f78",
'access_token': "asdf",
'token_type': "Bearer",
}

create_user_claims_pipeline(backend=backend, response=rData, user=user)

assert mock_update_user_claims.called
call_args = mock_update_user_claims.call_args

assert call_args == ((user, None, expected_groups),)

0 comments on commit f94b2dd

Please sign in to comment.