diff --git a/events/auth.py b/events/auth.py index b2571bfa0..b932e2071 100644 --- a/events/auth.py +++ b/events/auth.py @@ -2,12 +2,28 @@ from django.contrib.gis.db import models from django.utils.translation import gettext_lazy as _ from django_orghierarchy.models import Organization +from helusers.oidc import ApiTokenAuthentication as HelApiTokenAuthentication from rest_framework import authentication, exceptions from events.models import DataSource from helevents.models import UserModelPermissionMixin +class ApiTokenAuthentication(HelApiTokenAuthentication): + def authenticate(self, request): + """Extract the AMR claim from the authentication payload.""" + auth_data = super().authenticate(request) + if not auth_data: + return auth_data + + user, auth = auth_data + + if amr_claim := auth.data.get("amr"): + user.token_amr_claim = amr_claim + + return user, auth + + class ApiKeyAuthentication(authentication.BaseAuthentication): def authenticate(self, request): # django converts 'apikey' to 'HTTP_APIKEY' outside runserver diff --git a/events/tests/test_auth.py b/events/tests/test_auth.py index be917f5be..556deb9e9 100644 --- a/events/tests/test_auth.py +++ b/events/tests/test_auth.py @@ -7,11 +7,14 @@ from helusers.oidc import ApiTokenAuthentication from helusers.settings import api_token_auth_settings from jose import jwt +from rest_framework import status from events.models import DataSource +from helevents.tests.factories import UserFactory from ..auth import ApiKeyUser from .keys import rsa_key +from .utils import versioned_reverse DEFAULT_ORGANIZATION_ID = "others" @@ -27,7 +30,7 @@ def global_requests_mock(requests_mock): req_mock = None -def get_api_token_for_user_with_scopes(user_uuid, scopes: list): +def get_api_token_for_user_with_scopes(user_uuid, scopes: list, amr: str = None): """Build a proper auth token with desired scopes.""" audience = api_token_auth_settings.AUDIENCE issuer = api_token_auth_settings.ISSUER @@ -51,6 +54,7 @@ def get_api_token_for_user_with_scopes(user_uuid, scopes: list): "sub": str(user_uuid), "iat": int(now.timestamp()), "exp": int(expire.timestamp()), + "amr": amr if amr else "github", auth_field: scopes, } encoded_jwt = jwt.encode( @@ -122,3 +126,26 @@ def test_valid_jwt_is_accepted(): user_uuid = uuid.UUID("b7a35517-eb1f-46c9-88bf-3206fb659c3c") user, jwt_value = do_authentication(user_uuid) assert user.uuid == user_uuid + + +@pytest.mark.parametrize("login_using_ad", [True, False]) +@pytest.mark.django_db +def test_user_is_external_based_on_login_method(api_client, settings, login_using_ad): + """Using AD authentication forces the User.is_external to False.""" + user = UserFactory() + detail_url = versioned_reverse("user-detail", kwargs={"pk": user.uuid}) + ad_method = "helsinkiazuread" + settings.NON_EXTERNAL_AUTHENTICATION_METHODS = [ad_method] + if login_using_ad: + auth_method = ad_method + else: + auth_method = "non-ad_method" + auth_header = get_api_token_for_user_with_scopes( + user.uuid, [api_token_auth_settings.API_SCOPE_PREFIX], amr=auth_method + ) + api_client.credentials(HTTP_AUTHORIZATION=auth_header) + + response = api_client.get(detail_url, format="json") + + assert response.status_code == status.HTTP_200_OK, str(response.content) + assert response.data["is_external"] != login_using_ad diff --git a/events/tests/test_permissions.py b/events/tests/test_permissions.py index 022b54c5c..0b4d429fc 100644 --- a/events/tests/test_permissions.py +++ b/events/tests/test_permissions.py @@ -94,7 +94,9 @@ def test_can_edit_event(membership_status, expected_public, expected_draft): ], ) @pytest.mark.django_db -def test_is_external_user(is_admin, is_regular_user, expected): +def test_user_is_external_based_on_group_membership( + is_admin, is_regular_user, expected +): with ( patch.object( UserModelPermissionMixin, "organization_memberships", new_callable=MagicMock diff --git a/helevents/models.py b/helevents/models.py index 04f9272aa..023881b36 100644 --- a/helevents/models.py +++ b/helevents/models.py @@ -1,3 +1,4 @@ +import logging from functools import reduce from django.conf import settings @@ -7,6 +8,8 @@ from events.models import PublicationStatus +logger = logging.getLogger(__name__) + class UserModelPermissionMixin: """Permission mixin for user models @@ -15,9 +18,29 @@ class UserModelPermissionMixin: for user models. """ + @property + def token_amr_claim(self) -> str: + claim = getattr(self, "_token_amr_claim", None) + if claim is None: + logger.warning( + "User.token_amr_claim used without a request or authentication.", + stack_info=True, + stacklevel=2, + ) + + return claim + + @token_amr_claim.setter + def token_amr_claim(self, value: str): + self._token_amr_claim = value + @cached_property def is_external(self): """Check if the user is an external user""" + + if self.token_amr_claim in settings.NON_EXTERNAL_AUTHENTICATION_METHODS: + return False + return ( not self.organization_memberships.exists() and not self.admin_organizations.exists() diff --git a/linkedevents/settings.py b/linkedevents/settings.py index 9d120fc64..988797e2a 100644 --- a/linkedevents/settings.py +++ b/linkedevents/settings.py @@ -77,6 +77,8 @@ def get_git_revision_hash() -> str: MAILGUN_API_KEY=(str, ""), MEDIA_ROOT=(environ.Path(), root("media")), MEDIA_URL=(str, "/media/"), + # "helsinkiazuread" = Tunnistamo auth_backends.helsinki_azure_ad.HelsinkiAzureADTenantOAuth2 + NON_EXTERNAL_AUTHENTICATION_METHODS=(list, ["helsinkiazuread"]), REDIS_SENTINELS=(list, []), REDIS_URL=(str, None), REDIS_PASSWORD=(str, None), @@ -304,6 +306,9 @@ def get_git_revision_hash() -> str: # Publisher used for events created by users without organization (i.e. external users) EXTERNAL_USER_PUBLISHER_ID = env("EXTERNAL_USER_PUBLISHER_ID") +# Which OIDC authentication methods are never considered as external users +NON_EXTERNAL_AUTHENTICATION_METHODS = env("NON_EXTERNAL_AUTHENTICATION_METHODS") + # # REST Framework # @@ -328,7 +333,7 @@ def get_git_revision_hash() -> str: ), "DEFAULT_AUTHENTICATION_CLASSES": ( "events.auth.ApiKeyAuthentication", - "helusers.oidc.ApiTokenAuthentication", + "events.auth.ApiTokenAuthentication", ), "DEFAULT_VERSIONING_CLASS": "rest_framework.versioning.URLPathVersioning", "VIEW_NAME_FUNCTION": "linkedevents.utils.get_view_name",