diff --git a/ansible_base/lib/utils/hashing.py b/ansible_base/lib/utils/hashing.py index 63b1a870b..933579ef4 100644 --- a/ansible_base/lib/utils/hashing.py +++ b/ansible_base/lib/utils/hashing.py @@ -16,3 +16,18 @@ def hash_serializer_data(instance: Model, serializer: Type[Serializer], field: O serialized_data = serialized_data[field] metadata_json = json.dumps(serialized_data, sort_keys=True).encode("utf-8") return hasher(metadata_json).hexdigest() + + +def hash_string(inp: str, hasher: Callable = hashlib.sha256, algo=""): + """ + Takes a string and hashes it with the given hasher function. + If algo is given, it is prepended to the hash between dollar signs ($) + before the hash is returned. + + NOTE: There is no salt or pepper here, so this is not secure for passwords. + It is, however, useful for *random* strings like tokens, that need to be secured. + """ + hash = hasher(inp.encode("utf-8")).hexdigest() + if algo: + return f"${algo}${hash}" + return hash diff --git a/ansible_base/oauth2_provider/authentication.py b/ansible_base/oauth2_provider/authentication.py index 6df5c6de2..3329592f5 100644 --- a/ansible_base/oauth2_provider/authentication.py +++ b/ansible_base/oauth2_provider/authentication.py @@ -1,3 +1,4 @@ +import hashlib import logging from django.utils.encoding import smart_str @@ -5,6 +6,8 @@ from oauth2_provider.oauth2_backends import OAuthLibCore as _OAuthLibCore from rest_framework.exceptions import UnsupportedMediaType +from ansible_base.lib.utils.hashing import hash_string + logger = logging.getLogger('ansible_base.oauth2_provider.authentication') @@ -18,7 +21,24 @@ def extract_body(self, request): class LoggedOAuth2Authentication(OAuth2Authentication): def authenticate(self, request): - ret = super().authenticate(request) + # sha256 the bearer token. We store the hash in the database + # and this gives us a place to hash the incoming token for comparison + did_hash_token = False + bearer_token = request.META.get('HTTP_AUTHORIZATION') + if bearer_token and bearer_token.lower().startswith('bearer '): + token_component = bearer_token.split(' ', 1)[1] + hashed = hash_string(token_component, hasher=hashlib.sha256, algo="sha256") + did_hash_token = True + request.META['HTTP_AUTHORIZATION'] = f"Bearer {hashed}" + + # We don't /really/ want to modify the request, so after we're done authing, + # revert what we did above. + try: + ret = super().authenticate(request) + finally: + if did_hash_token: + request.META['HTTP_AUTHORIZATION'] = bearer_token + if ret: user, token = ret username = user.username if user else '' diff --git a/ansible_base/oauth2_provider/fixtures.py b/ansible_base/oauth2_provider/fixtures.py index 06580714a..e96dd230c 100644 --- a/ansible_base/oauth2_provider/fixtures.py +++ b/ansible_base/oauth2_provider/fixtures.py @@ -1,9 +1,11 @@ +import hashlib from datetime import datetime, timezone import pytest from oauthlib.common import generate_token from ansible_base.lib.testing.fixtures import copy_fixture +from ansible_base.lib.utils.hashing import hash_string from ansible_base.lib.utils.response import get_relative_url from ansible_base.oauth2_provider.models import OAuth2AccessToken, OAuth2Application @@ -62,10 +64,18 @@ def oauth2_application_password(randname): @pytest.fixture def oauth2_admin_access_token(oauth2_application, admin_api_client, admin_user): + """ + 3-tuple with (token object with hashed token, plaintext token, plaintext_refresh_token) + """ url = get_relative_url('token-list') response = admin_api_client.post(url, {'application': oauth2_application[0].pk}) assert response.status_code == 201 - return OAuth2AccessToken.objects.get(token=response.data['token']) + + plaintext_token = response.data['token'] + plaintext_refresh_token = response.data['refresh_token'] + hashed_token = hash_string(plaintext_token, hasher=hashlib.sha256, algo="sha256") + token = OAuth2AccessToken.objects.get(token=hashed_token) + return (token, plaintext_token, plaintext_refresh_token) @copy_fixture(copies=3) diff --git a/ansible_base/oauth2_provider/management/commands/create_oauth2_token.py b/ansible_base/oauth2_provider/management/commands/create_oauth2_token.py index d50e4291b..12f298e9d 100644 --- a/ansible_base/oauth2_provider/management/commands/create_oauth2_token.py +++ b/ansible_base/oauth2_provider/management/commands/create_oauth2_token.py @@ -31,5 +31,5 @@ def __init__(self): self.user = user serializer_obj.context['request'] = FakeRequest() - token_record = serializer_obj.create(config) - self.stdout.write(token_record.token) + serializer_obj.create(config) + self.stdout.write(serializer_obj.unencrypted_token) diff --git a/ansible_base/oauth2_provider/migrations/0005_hash_existing_tokens.py b/ansible_base/oauth2_provider/migrations/0005_hash_existing_tokens.py new file mode 100644 index 000000000..13849bd4f --- /dev/null +++ b/ansible_base/oauth2_provider/migrations/0005_hash_existing_tokens.py @@ -0,0 +1,13 @@ +from django.db import migrations + +from ansible_base.oauth2_provider.migrations._utils import hash_tokens + + +class Migration(migrations.Migration): + dependencies = [ + ("dab_oauth2_provider", "0004_alter_oauth2accesstoken_scope"), + ] + + operations = [ + migrations.RunPython(hash_tokens), + ] diff --git a/ansible_base/oauth2_provider/migrations/_utils.py b/ansible_base/oauth2_provider/migrations/_utils.py new file mode 100644 index 000000000..fa2f08bb8 --- /dev/null +++ b/ansible_base/oauth2_provider/migrations/_utils.py @@ -0,0 +1,16 @@ +import hashlib + +from ansible_base.lib.utils.hashing import hash_string + + +def hash_tokens(apps, schema_editor): + OAuth2AccessToken = apps.get_model("dab_oauth2_provider", "OAuth2AccessToken") + OAuth2RefreshToken = apps.get_model("dab_oauth2_provider", "OAuth2RefreshToken") + for model in (OAuth2AccessToken, OAuth2RefreshToken): + for token in model.objects.all(): + # Never re-hash a hashed token + if token.token.startswith("$"): + continue + hashed = hash_string(token.token, hasher=hashlib.sha256, algo="sha256") + token.token = hashed + token.save() diff --git a/ansible_base/oauth2_provider/models/access_token.py b/ansible_base/oauth2_provider/models/access_token.py index 05d3cc21f..e5cd6c930 100644 --- a/ansible_base/oauth2_provider/models/access_token.py +++ b/ansible_base/oauth2_provider/models/access_token.py @@ -1,3 +1,5 @@ +import hashlib + import oauth2_provider.models as oauth2_models from django.conf import settings from django.core.exceptions import ValidationError @@ -7,6 +9,7 @@ from oauthlib import oauth2 from ansible_base.lib.abstract_models.common import CommonModel +from ansible_base.lib.utils.hashing import hash_string from ansible_base.lib.utils.models import prevent_search from ansible_base.lib.utils.settings import get_setting from ansible_base.oauth2_provider.utils import is_external_account @@ -103,4 +106,5 @@ def validate_external_users(self): def save(self, *args, **kwargs): if not self.pk: self.validate_external_users() + self.token = hash_string(self.token, hasher=hashlib.sha256, algo="sha256") super().save(*args, **kwargs) diff --git a/ansible_base/oauth2_provider/models/refresh_token.py b/ansible_base/oauth2_provider/models/refresh_token.py index 078a87cf9..a782b35b2 100644 --- a/ansible_base/oauth2_provider/models/refresh_token.py +++ b/ansible_base/oauth2_provider/models/refresh_token.py @@ -1,9 +1,12 @@ +import hashlib + import oauth2_provider.models as oauth2_models from django.conf import settings from django.db import models from django.utils.translation import gettext_lazy as _ from ansible_base.lib.abstract_models.common import CommonModel +from ansible_base.lib.utils.hashing import hash_string from ansible_base.lib.utils.models import prevent_search activitystream = object @@ -21,3 +24,8 @@ class Meta(oauth2_models.AbstractRefreshToken.Meta): token = prevent_search(models.CharField(max_length=255)) updated = None # Tracked in CommonModel with 'modified', no need for this + + def save(self, *args, **kwargs): + if not self.pk: + self.token = hash_string(self.token, hasher=hashlib.sha256, algo="sha256") + super().save(*args, **kwargs) diff --git a/ansible_base/oauth2_provider/serializers/token.py b/ansible_base/oauth2_provider/serializers/token.py index 9c9a9aa0e..427e96e18 100644 --- a/ansible_base/oauth2_provider/serializers/token.py +++ b/ansible_base/oauth2_provider/serializers/token.py @@ -20,9 +20,11 @@ logger = logging.getLogger("ansible_base.oauth2_provider.serializers.token") -class BaseOAuth2TokenSerializer(CommonModelSerializer): +class OAuth2TokenSerializer(CommonModelSerializer): refresh_token = SerializerMethodField() - token = SerializerMethodField() + + unencrypted_token = None # Only used in POST so we can return the token in the response + unencrypted_refresh_token = None # Only used in POST so we can return the refresh token in the response class Meta: model = OAuth2AccessToken @@ -40,15 +42,15 @@ class Meta: read_only_fields = ('user', 'token', 'expires', 'refresh_token') extra_kwargs = {'scope': {'allow_null': False, 'required': False}, 'user': {'allow_null': False, 'required': True}} - def get_token(self, obj) -> str: - request = self.context.get('request') - try: - if request and request.method == 'POST': - return obj.token - else: - return ENCRYPTED_STRING - except ObjectDoesNotExist: - return '' + def to_representation(self, instance): + request = self.context.get('request', None) + ret = super().to_representation(instance) + if request and request.method == 'POST': + # If we're creating the token, show it. Otherwise, show the encrypted string. + ret['token'] = self.unencrypted_token + else: + ret['token'] = ENCRYPTED_STRING + return ret def get_refresh_token(self, obj) -> Optional[str]: request = self.context.get('request') @@ -56,7 +58,7 @@ def get_refresh_token(self, obj) -> Optional[str]: if not obj.refresh_token: return None elif request and request.method == 'POST': - return getattr(obj.refresh_token, 'token', '') + return self.unencrypted_refresh_token else: return ENCRYPTED_STRING except ObjectDoesNotExist: @@ -78,15 +80,6 @@ def validate_scope(self, value): raise ValidationError(_('Must be a simple space-separated string with allowed scopes {}.').format(SCOPES)) return value - def create(self, validated_data): - validated_data['user'] = self.context['request'].user - try: - return super().create(validated_data) - except AccessDeniedError as e: - raise PermissionDenied(str(e)) - - -class OAuth2TokenSerializer(BaseOAuth2TokenSerializer): def create(self, validated_data): current_user = get_current_user() validated_data['token'] = generate_token() @@ -94,10 +87,23 @@ def create(self, validated_data): if expires_delta == 0: logger.warning("OAUTH2_PROVIDER.ACCESS_TOKEN_EXPIRE_SECONDS was set to 0, creating token that has already expired") validated_data['expires'] = now() + timedelta(seconds=expires_delta) - obj = super().create(validated_data) + validated_data['user'] = self.context['request'].user + self.unencrypted_token = validated_data.get('token') # Before it is hashed + + try: + obj = super().create(validated_data) + except AccessDeniedError as e: + raise PermissionDenied(str(e)) + if obj.application and obj.application.user: obj.user = obj.application.user obj.save() if obj.application: - OAuth2RefreshToken.objects.create(user=current_user, token=generate_token(), application=obj.application, access_token=obj) + self.unencrypted_refresh_token = generate_token() + OAuth2RefreshToken.objects.create( + user=current_user, + token=self.unencrypted_refresh_token, + application=obj.application, + access_token=obj, + ) return obj diff --git a/ansible_base/oauth2_provider/views/token.py b/ansible_base/oauth2_provider/views/token.py index 39e087089..4b8db3844 100644 --- a/ansible_base/oauth2_provider/views/token.py +++ b/ansible_base/oauth2_provider/views/token.py @@ -1,3 +1,4 @@ +import hashlib from datetime import timedelta from django.utils.timezone import now @@ -5,6 +6,7 @@ from oauthlib import oauth2 from rest_framework.viewsets import ModelViewSet +from ansible_base.lib.utils.hashing import hash_string from ansible_base.lib.utils.settings import get_setting from ansible_base.lib.utils.views.django_app_api import AnsibleBaseDjangoAppApiView from ansible_base.oauth2_provider.models import OAuth2AccessToken, OAuth2RefreshToken @@ -28,7 +30,8 @@ def create_token_response(self, request): # This code detects and auto-expires them on refresh grant # requests. if request.POST.get('grant_type') == 'refresh_token' and 'refresh_token' in request.POST: - refresh_token = OAuth2RefreshToken.objects.filter(token=request.POST['refresh_token']).first() + hashed_refresh_token = hash_string(request.POST['refresh_token'], hasher=hashlib.sha256, algo="sha256") + refresh_token = OAuth2RefreshToken.objects.filter(token=hashed_refresh_token).first() if refresh_token: expire_seconds = get_setting('OAUTH2_PROVIDER', {}).get('REFRESH_TOKEN_EXPIRE_SECONDS', 0) if refresh_token.created + timedelta(seconds=expire_seconds) < now(): @@ -38,7 +41,23 @@ def create_token_response(self, request): # oauth2_provider.oauth2_backends.OAuthLibCore.create_token_response # (we override this so we can implement our own error handling to be compatible with AWX) - uri, http_method, body, headers = core._extract_params(request) + + # This is really, really ugly. Modify the request to hash the refresh_token + # but only long enough for the oauth lib to do its magic. + did_hash_refresh_token = False + old_post = request.POST + if 'refresh_token' in request.POST: + did_hash_refresh_token = True + request.POST = request.POST.copy() # so it's mutable + hashed_refresh_token = hash_string(request.POST['refresh_token'], hasher=hashlib.sha256, algo="sha256") + request.POST['refresh_token'] = hashed_refresh_token + + try: + uri, http_method, body, headers = core._extract_params(request) + finally: + if did_hash_refresh_token: + request.POST = old_post + extra_credentials = core._get_extra_credentials(request) try: headers, body, status = core.server.create_token_response(uri, http_method, body, headers, extra_credentials) diff --git a/test_app/tests/oauth2_provider/management/commands/test_cleanup_tokens.py b/test_app/tests/oauth2_provider/management/commands/test_cleanup_tokens.py index 0e02854d9..715e426d5 100644 --- a/test_app/tests/oauth2_provider/management/commands/test_cleanup_tokens.py +++ b/test_app/tests/oauth2_provider/management/commands/test_cleanup_tokens.py @@ -35,7 +35,7 @@ def test_cleanup_expired_tokens(self, oauth2_admin_access_token): attempt_cleanup(0, 0) # Manually expire admin token - oauth2_admin_access_token.expires = datetime.datetime.fromtimestamp(0) - oauth2_admin_access_token.save() + oauth2_admin_access_token[0].expires = datetime.datetime.fromtimestamp(0) + oauth2_admin_access_token[0].save() attempt_cleanup(1, 1) diff --git a/test_app/tests/oauth2_provider/management/commands/test_create_oauth2_token.py b/test_app/tests/oauth2_provider/management/commands/test_create_oauth2_token.py index 83eb4c728..e60fc0e02 100644 --- a/test_app/tests/oauth2_provider/management/commands/test_create_oauth2_token.py +++ b/test_app/tests/oauth2_provider/management/commands/test_create_oauth2_token.py @@ -1,4 +1,5 @@ # Python +import hashlib import random import string from io import StringIO @@ -10,6 +11,8 @@ from django.core.management import call_command from django.core.management.base import CommandError +from ansible_base.lib.utils.hashing import hash_string +from ansible_base.lib.utils.response import get_relative_url from ansible_base.oauth2_provider.models import OAuth2AccessToken User = get_user_model() @@ -34,11 +37,21 @@ def test_non_existing_user(self): call_command('create_oauth2_token', arg, stdout=out) assert 'The user does not exist.' in str(excinfo.value) - def test_correct_user(self, random_user): + def test_correct_user(self, random_user, unauthenticated_api_client): user_username = random_user.username with StringIO() as out: arg = '--user=' + user_username call_command('create_oauth2_token', arg, stdout=out) generated_token = out.getvalue().strip() - assert OAuth2AccessToken.objects.filter(user=random_user, token=generated_token).count() == 1 - assert OAuth2AccessToken.objects.get(user=random_user, token=generated_token).scope == 'write' + + hashed_token = hash_string(generated_token, hasher=hashlib.sha256, algo="sha256") + assert OAuth2AccessToken.objects.filter(user=random_user, token=hashed_token).count() == 1 + assert OAuth2AccessToken.objects.get(user=random_user, token=hashed_token).scope == 'write' + + url = get_relative_url("user-me") + response = unauthenticated_api_client.get( + url, + headers={'Authorization': f'Bearer {generated_token}'}, + ) + assert response.status_code == 200 + assert response.data['username'] == user_username diff --git a/test_app/tests/oauth2_provider/management/commands/test_revoke_oauth2_tokens.py b/test_app/tests/oauth2_provider/management/commands/test_revoke_oauth2_tokens.py index 173423f19..47a4bf844 100644 --- a/test_app/tests/oauth2_provider/management/commands/test_revoke_oauth2_tokens.py +++ b/test_app/tests/oauth2_provider/management/commands/test_revoke_oauth2_tokens.py @@ -29,7 +29,7 @@ def test_revoke_all_access_tokens(self, oauth2_admin_access_token, oauth2_user_a def test_revoke_access_token_for_user(self, oauth2_admin_access_token, oauth2_user_application_token): with StringIO() as out: - admin_username = oauth2_admin_access_token.user.username + admin_username = oauth2_admin_access_token[0].user.username user_username = oauth2_user_application_token.user.username assert OAuth2AccessToken.objects.count() == 2 diff --git a/test_app/tests/oauth2_provider/migrations/__init__.py b/test_app/tests/oauth2_provider/migrations/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/test_app/tests/oauth2_provider/migrations/test_utils.py b/test_app/tests/oauth2_provider/migrations/test_utils.py new file mode 100644 index 000000000..d95976b5d --- /dev/null +++ b/test_app/tests/oauth2_provider/migrations/test_utils.py @@ -0,0 +1,44 @@ +from django.apps import apps + +from ansible_base.lib.utils.response import get_relative_url +from ansible_base.oauth2_provider.migrations._utils import hash_tokens + + +def test_oauth2_migrations_hash_tokens(unauthenticated_api_client, oauth2_admin_access_token): + """ + Force an unhashed token, run the migration function, and ensure the token is hashed. + """ + unhashed_token = oauth2_admin_access_token[1] + oauth2_admin_access_token[0].token = unhashed_token + oauth2_admin_access_token[0].save() + + url = get_relative_url("user-me") + response = unauthenticated_api_client.get( + url, + headers={'Authorization': f'Bearer {oauth2_admin_access_token[1]}'}, + ) + # When we set the token back to unhashed, we shouldn't be able to auth with it. + assert response.status_code == 401 + + hash_tokens(apps, None) + + url = get_relative_url("user-me") + response = unauthenticated_api_client.get( + url, + headers={'Authorization': f'Bearer {oauth2_admin_access_token[1]}'}, + ) + # Now it's been hashed, so we can auth + assert response.status_code == 200 + assert response.data['username'] == oauth2_admin_access_token[0].user.username + + # And if we re-run the hash function again for some reason, we never double-hash + hash_tokens(apps, None) + + url = get_relative_url("user-me") + response = unauthenticated_api_client.get( + url, + headers={'Authorization': f'Bearer {oauth2_admin_access_token[1]}'}, + ) + # We can still auth + assert response.status_code == 200 + assert response.data['username'] == oauth2_admin_access_token[0].user.username diff --git a/test_app/tests/oauth2_provider/test_authentication.py b/test_app/tests/oauth2_provider/test_authentication.py index 6ae938c22..4efacd0c3 100644 --- a/test_app/tests/oauth2_provider/test_authentication.py +++ b/test_app/tests/oauth2_provider/test_authentication.py @@ -22,10 +22,10 @@ def test_oauth2_bearer_get_user_correct(unauthenticated_api_client, oauth2_admin url = get_relative_url("user-me") response = unauthenticated_api_client.get( url, - headers={'Authorization': f'Bearer {oauth2_admin_access_token.token}'}, + headers={'Authorization': f'Bearer {oauth2_admin_access_token[1]}'}, ) assert response.status_code == 200 - assert response.data['username'] == oauth2_admin_access_token.user.username + assert response.data['username'] == oauth2_admin_access_token[0].user.username @pytest.mark.parametrize( @@ -40,7 +40,7 @@ def test_oauth2_bearer_get(unauthenticated_api_client, oauth2_admin_access_token GET an animal with a bearer token. """ url = get_relative_url("animal-detail", kwargs={"pk": animal.pk}) - token = oauth2_admin_access_token.token if token == 'fixture' else generate_token() + token = oauth2_admin_access_token[1] if token == 'fixture' else generate_token() response = unauthenticated_api_client.get( url, headers={'Authorization': f'Bearer {token}'}, @@ -62,7 +62,7 @@ def test_oauth2_bearer_post(unauthenticated_api_client, oauth2_admin_access_toke POST an animal with a bearer token. """ url = get_relative_url("animal-list") - token = oauth2_admin_access_token.token if token == 'fixture' else generate_token() + token = oauth2_admin_access_token[1] if token == 'fixture' else generate_token() data = { "name": "Fido", "owner": admin_user.pk, @@ -89,7 +89,7 @@ def test_oauth2_bearer_patch(unauthenticated_api_client, oauth2_admin_access_tok PATCH an animal with a bearer token. """ url = get_relative_url("animal-detail", kwargs={"pk": animal.pk}) - token = oauth2_admin_access_token.token if token == 'fixture' else generate_token() + token = oauth2_admin_access_token[1] if token == 'fixture' else generate_token() data = { "name": "Fido", } @@ -115,7 +115,7 @@ def test_oauth2_bearer_put(unauthenticated_api_client, oauth2_admin_access_token PUT an animal with a bearer token. """ url = get_relative_url("animal-detail", kwargs={"pk": animal.pk}) - token = oauth2_admin_access_token.token if token == 'fixture' else generate_token() + token = oauth2_admin_access_token[1] if token == 'fixture' else generate_token() data = { "name": "Fido", "owner": admin_user.pk, @@ -135,8 +135,8 @@ def test_oauth2_bearer_no_activitystream(unauthenticated_api_client, oauth2_admi Ensure no activitystream entries for bearer token based auth """ url = get_relative_url("animal-detail", kwargs={"pk": animal.pk}) - token = oauth2_admin_access_token.token - existing_as_count = len(oauth2_admin_access_token.activity_stream_entries) + token = oauth2_admin_access_token[1] + existing_as_count = len(oauth2_admin_access_token[0].activity_stream_entries) response = unauthenticated_api_client.get( url, @@ -145,7 +145,7 @@ def test_oauth2_bearer_no_activitystream(unauthenticated_api_client, oauth2_admi assert response.status_code == 200 assert response.data['name'] == animal.name - updated_token = OAuth2AccessToken.objects.get(token=token) + updated_token = OAuth2AccessToken.objects.get(token=oauth2_admin_access_token[0].token) assert len(updated_token.activity_stream_entries) == existing_as_count @@ -163,8 +163,8 @@ def test_oauth2_scope_permission(request, admin_user, oauth2_admin_access_token, """ Ensure that scopes are adhered to for PATs """ - oauth2_admin_access_token.scope = scope - oauth2_admin_access_token.save() + oauth2_admin_access_token[0].scope = scope + oauth2_admin_access_token[0].save() url = get_relative_url("animal-list") data = { @@ -174,7 +174,7 @@ def test_oauth2_scope_permission(request, admin_user, oauth2_admin_access_token, response = unauthenticated_api_client.post( url, data=data, - headers={'Authorization': f'Bearer {oauth2_admin_access_token.token}'}, + headers={'Authorization': f'Bearer {oauth2_admin_access_token[1]}'}, ) assert response.status_code == status, response.status_code diff --git a/test_app/tests/oauth2_provider/test_models.py b/test_app/tests/oauth2_provider/test_models.py index aaccfc2e0..ddc03b1c2 100644 --- a/test_app/tests/oauth2_provider/test_models.py +++ b/test_app/tests/oauth2_provider/test_models.py @@ -5,8 +5,8 @@ @pytest.mark.django_db def test_oauth2_revoke_access_then_refresh_token(oauth2_admin_access_token): - token = oauth2_admin_access_token - refresh_token = oauth2_admin_access_token.refresh_token + token = oauth2_admin_access_token[0] + refresh_token = oauth2_admin_access_token[0].refresh_token assert OAuth2AccessToken.objects.count() == 1 assert OAuth2RefreshToken.objects.count() == 1 @@ -22,7 +22,7 @@ def test_oauth2_revoke_access_then_refresh_token(oauth2_admin_access_token): @pytest.mark.django_db def test_oauth2_revoke_refresh_token(oauth2_admin_access_token): - refresh_token = oauth2_admin_access_token.refresh_token + refresh_token = oauth2_admin_access_token[0].refresh_token assert OAuth2AccessToken.objects.count() == 1 assert OAuth2RefreshToken.objects.count() == 1 diff --git a/test_app/tests/oauth2_provider/views/test_token.py b/test_app/tests/oauth2_provider/views/test_token.py index 7de79af5a..0c521b282 100644 --- a/test_app/tests/oauth2_provider/views/test_token.py +++ b/test_app/tests/oauth2_provider/views/test_token.py @@ -1,4 +1,5 @@ import base64 +import hashlib import json import time @@ -7,6 +8,7 @@ from ansible_base.authentication.models import AuthenticatorUser from ansible_base.lib.utils.encryption import ENCRYPTED_STRING +from ansible_base.lib.utils.hashing import hash_string from ansible_base.lib.utils.response import get_relative_url from ansible_base.oauth2_provider.models import OAuth2AccessToken, OAuth2RefreshToken @@ -217,7 +219,7 @@ def test_oauth2_application_token_summary_fields(admin_api_client, oauth2_admin_ response = admin_api_client.get(url) assert response.status_code == 200 assert response.data['summary_fields']['tokens']['count'] == 1 - assert response.data['summary_fields']['tokens']['results'][0] == {'id': oauth2_admin_access_token.pk, 'scope': 'write', 'token': ENCRYPTED_STRING} + assert response.data['summary_fields']['tokens']['results'][0] == {'id': oauth2_admin_access_token[0].pk, 'scope': 'write', 'token': ENCRYPTED_STRING} @pytest.mark.django_db @@ -262,15 +264,17 @@ def test_oauth2_authorized_list_is_user_related_field(user, admin_api_client): @pytest.mark.django_db -def test_oauth2_token_createn(oauth2_application, admin_api_client, admin_user): +def test_oauth2_token_create(oauth2_application, admin_api_client, admin_user): oauth2_application = oauth2_application[0] url = get_relative_url('token-list') response = admin_api_client.post(url, {'scope': 'read', 'application': oauth2_application.pk}) assert response.status_code == 201 assert 'modified' in response.data and response.data['modified'] is not None assert 'updated' not in response.data - token = OAuth2AccessToken.objects.get(token=response.data['token']) - refresh_token = OAuth2RefreshToken.objects.get(token=response.data['refresh_token']) + hashed_token = hash_string(response.data['token'], hasher=hashlib.sha256, algo="sha256") + token = OAuth2AccessToken.objects.get(token=hashed_token) + hashed_refresh_token = hash_string(response.data['refresh_token'], hasher=hashlib.sha256, algo="sha256") + refresh_token = OAuth2RefreshToken.objects.get(token=hashed_refresh_token) assert token.application == oauth2_application assert refresh_token.application == oauth2_application assert token.user == admin_user @@ -308,28 +312,28 @@ def test_oauth2_token_createn(oauth2_application, admin_api_client, admin_user): @pytest.mark.django_db def test_oauth2_token_update(oauth2_admin_access_token, admin_api_client): - assert oauth2_admin_access_token.scope == 'write' - url = get_relative_url('token-detail', kwargs={'pk': oauth2_admin_access_token.pk}) + assert oauth2_admin_access_token[0].scope == 'write' + url = get_relative_url('token-detail', kwargs={'pk': oauth2_admin_access_token[0].pk}) response = admin_api_client.patch(url, {'scope': 'read'}) assert response.status_code == 200 - oauth2_admin_access_token.refresh_from_db() - assert oauth2_admin_access_token.scope == 'read' + oauth2_admin_access_token[0].refresh_from_db() + assert oauth2_admin_access_token[0].scope == 'read' @pytest.mark.django_db def test_oauth2_token_delete(oauth2_admin_access_token, admin_api_client): - url = get_relative_url('token-detail', kwargs={'pk': oauth2_admin_access_token.pk}) + url = get_relative_url('token-detail', kwargs={'pk': oauth2_admin_access_token[0].pk}) response = admin_api_client.delete(url) assert response.status_code == 204 assert OAuth2AccessToken.objects.count() == 0 assert OAuth2RefreshToken.objects.count() == 1 - url = get_relative_url('application-access_tokens-list', kwargs={'pk': oauth2_admin_access_token.application.pk}) + url = get_relative_url('application-access_tokens-list', kwargs={'pk': oauth2_admin_access_token[0].application.pk}) response = admin_api_client.get(url) assert response.status_code == 200 assert response.data['count'] == 0 - url = get_relative_url('application-detail', kwargs={'pk': oauth2_admin_access_token.application.pk}) + url = get_relative_url('application-detail', kwargs={'pk': oauth2_admin_access_token[0].application.pk}) response = admin_api_client.get(url) assert response.status_code == 200 assert response.data['summary_fields']['tokens']['count'] == 0 @@ -342,12 +346,13 @@ def test_oauth2_refresh_access_token(oauth2_application, oauth2_admin_access_tok """ app = oauth2_application[0] secret = oauth2_application[1] - refresh_token = oauth2_admin_access_token.refresh_token + refresh_token = oauth2_admin_access_token[2] + refresh_token_obj = oauth2_admin_access_token[0].refresh_token url = get_relative_url('token') data = { 'grant_type': 'refresh_token', - 'refresh_token': refresh_token.token, + 'refresh_token': refresh_token, } resp = unauthenticated_api_client.post( url, @@ -356,8 +361,8 @@ def test_oauth2_refresh_access_token(oauth2_application, oauth2_admin_access_tok headers={'Authorization': 'Basic ' + base64.b64encode(f"{app.client_id}:{secret}".encode()).decode()}, ) assert resp.status_code == 201 - assert OAuth2RefreshToken.objects.filter(token=refresh_token).exists() - original_refresh_token = OAuth2RefreshToken.objects.get(token=refresh_token) + assert OAuth2RefreshToken.objects.filter(token=refresh_token_obj.token).exists() + original_refresh_token = OAuth2RefreshToken.objects.get(token=refresh_token_obj.token) assert oauth2_admin_access_token not in OAuth2AccessToken.objects.all() assert OAuth2AccessToken.objects.count() == 1 @@ -367,12 +372,14 @@ def test_oauth2_refresh_access_token(oauth2_application, oauth2_admin_access_tok json_resp = json.loads(resp.content) new_token = json_resp['access_token'] + new_token_hashed = hash_string(new_token, hasher=hashlib.sha256, algo="sha256") new_refresh_token = json_resp['refresh_token'] + new_refresh_token_hashed = hash_string(new_refresh_token, hasher=hashlib.sha256, algo="sha256") - assert OAuth2AccessToken.objects.filter(token=new_token).count() == 1 + assert OAuth2AccessToken.objects.filter(token=new_token_hashed).count() == 1 # checks that RefreshTokens are rotated (new RefreshToken issued) - assert OAuth2RefreshToken.objects.filter(token=new_refresh_token).count() == 1 - new_refresh_obj = OAuth2RefreshToken.objects.get(token=new_refresh_token) + assert OAuth2RefreshToken.objects.filter(token=new_refresh_token_hashed).count() == 1 + new_refresh_obj = OAuth2RefreshToken.objects.get(token=new_refresh_token_hashed) assert not new_refresh_obj.revoked @@ -383,7 +390,8 @@ def test_oauth2_refresh_token_expiration_is_respected(oauth2_application, oauth2 """ app = oauth2_application[0] secret = oauth2_application[1] - refresh_token = oauth2_admin_access_token.refresh_token + refresh_token = oauth2_admin_access_token[2] + refresh_token_obj = oauth2_admin_access_token[0].refresh_token settings.OAUTH2_PROVIDER['REFRESH_TOKEN_EXPIRE_SECONDS'] = 1 settings.OAUTH2_PROVIDER['ACCESS_TOKEN_EXPIRE_SECONDS'] = 1 @@ -393,7 +401,7 @@ def test_oauth2_refresh_token_expiration_is_respected(oauth2_application, oauth2 url = get_relative_url('token') data = { 'grant_type': 'refresh_token', - 'refresh_token': refresh_token.token, + 'refresh_token': refresh_token, } response = admin_api_client.post( url, @@ -403,7 +411,7 @@ def test_oauth2_refresh_token_expiration_is_respected(oauth2_application, oauth2 ) assert response.status_code == 403 assert b'The refresh token has expired.' in response.content - assert OAuth2RefreshToken.objects.filter(token=refresh_token).exists() + assert OAuth2RefreshToken.objects.filter(token=refresh_token_obj.token).exists() assert OAuth2AccessToken.objects.count() == 1 assert OAuth2RefreshToken.objects.count() == 1