Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support modifying groups claim for social auth #640

Open
wants to merge 1 commit into
base: devel
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions ansible_base/authentication/authenticator_plugins/oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,14 @@ class OpenIdConnectConfiguration(BaseAuthenticatorConfiguration):
ui_field_label=_("Username Key"),
)

GROUPS_CLAIM = CharField(
bhavenst marked this conversation as resolved.
Show resolved Hide resolved
help_text=_("The JSON key used to extract the user's groups from the ID token or userinfo endpoint."),
bhavenst marked this conversation as resolved.
Show resolved Hide resolved
required=False,
allow_null=True,
default="Group",
ui_field_label=_("Groups Claim"),
)


class AuthenticatorPlugin(SocialAuthMixin, OpenIdConnectAuth, AbstractAuthenticatorPlugin):
configuration_class = OpenIdConnectConfiguration
Expand All @@ -209,6 +217,10 @@ class AuthenticatorPlugin(SocialAuthMixin, OpenIdConnectAuth, AbstractAuthentica
category = "sso"
configuration_encrypted_fields = ['SECRET']

@property
def groups_claim(self):
return self.setting('GROUPS_CLAIM')

def extra_data(self, user, backend, response, *args, **kwargs):
for perm in ["is_superuser", get_setting('ANSIBLE_BASE_SOCIAL_AUDITOR_FLAG')]:
if perm in response:
Expand Down
5 changes: 4 additions & 1 deletion ansible_base/authentication/social_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ def __init__(self, storage, request=None, tpl=None, additional_settings={}):
class SocialAuthMixin:
configuration_encrypted_fields = []
logger = None
groups_claim = "Group"

def __init__(self, *args, **kwargs):
# social auth expects the first arg to be a strategy instance. Since this has
Expand Down Expand Up @@ -190,7 +191,9 @@ def validate(self, serializer, data):
def create_user_claims_pipeline(*args, backend, response, **kwargs):
from ansible_base.authentication.utils.claims import update_user_claims

extra_groups = response["Group"] if "Group" in response else None
groups_claim = backend.groups_claim if backend.groups_claim is not None else "Group"

extra_groups = response[groups_claim] if groups_claim in response else []
user = update_user_claims(kwargs["user"], backend.database_instance, backend.get_user_groups(extra_groups))
if user is None:
return SOCIAL_AUTH_PIPELINE_FAILED_STATUS
56 changes: 55 additions & 1 deletion test_app/tests/authentication/test_social_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@
from django.conf import settings
from django.test import override_settings

from ansible_base.authentication.social_auth import AuthenticatorStorage, AuthenticatorStrategy, SocialAuthValidateCallbackMixin
from ansible_base.authentication.social_auth import (
AuthenticatorStorage,
AuthenticatorStrategy,
SocialAuthMixin,
SocialAuthValidateCallbackMixin,
create_user_claims_pipeline,
)


@mock.patch("ansible_base.authentication.social_auth.logger")
Expand Down Expand Up @@ -75,3 +81,51 @@ def test_social_auth_validate_callback_mixin(mocked_generate_slug, mocked_revers
# should always call reverse if no callback url
if has_instance and 'configuration' in test_data and not test_data.get('configuration', {}).get('CALLBACK_URL'):
assert mocked_reverse.called


@pytest.mark.parametrize(
"groups_claim,returned_groups,expected_groups",
[
(None, ["mygroup"], ["mygroup"]),
("groups", ["mygroup"], ["mygroup"]),
(None, None, []),
("groups", None, []),
],
)
@mock.patch("ansible_base.authentication.utils.claims.update_user_claims")
def test_create_user_claims_pipeline(mock_update_user_claims, groups_claim, returned_groups, expected_groups):
'''
We are testing to see if extracting groups from a claim is working correctly
'''

class MockBackend(SocialAuthMixin):
database_instance = None

def __init__(self, groups_claim=None):
if groups_claim is not None:
self.groups_claim = groups_claim

def get_user_groups(self, extra_groups=[]):
return extra_groups

backend = MockBackend(groups_claim=groups_claim)

rData = {}
if returned_groups is not None:
rData[backend.groups_claim] = returned_groups

user = {
'auth_time': "2024-11-07T05:19:08.224936Z",
'id_token': "asdf",
'refresh_token': None,
'id': "ccd2cf13-d927-41ad-cd8c-adb18b2e5f78",
'access_token': "asdf",
'token_type': "Bearer",
}

create_user_claims_pipeline(backend=backend, response=rData, user=user)

assert mock_update_user_claims.called
call_args = mock_update_user_claims.call_args

assert call_args == ((user, None, expected_groups),)
Loading