Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[oauth2_provider] Hash access and refresh tokens #641

Merged
merged 6 commits into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions ansible_base/lib/utils/hashing.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,18 @@ def hash_serializer_data(instance: Model, serializer: Type[Serializer], field: O
serialized_data = serialized_data[field]
metadata_json = json.dumps(serialized_data, sort_keys=True).encode("utf-8")
return hasher(metadata_json).hexdigest()


def hash_string(inp: str, hasher: Callable = hashlib.sha256, algo=""):
"""
Takes a string and hashes it with the given hasher function.
If algo is given, it is prepended to the hash between dollar signs ($)
before the hash is returned.

NOTE: There is no salt or pepper here, so this is not secure for passwords.
It is, however, useful for *random* strings like tokens, that need to be secured.
"""
hash = hasher(inp.encode("utf-8")).hexdigest()
if algo:
return f"${algo}${hash}"
return hash
22 changes: 21 additions & 1 deletion ansible_base/oauth2_provider/authentication.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import hashlib
import logging

from django.utils.encoding import smart_str
from oauth2_provider.contrib.rest_framework import OAuth2Authentication
from oauth2_provider.oauth2_backends import OAuthLibCore as _OAuthLibCore
from rest_framework.exceptions import UnsupportedMediaType

from ansible_base.lib.utils.hashing import hash_string

logger = logging.getLogger('ansible_base.oauth2_provider.authentication')


Expand All @@ -18,7 +21,24 @@ def extract_body(self, request):

class LoggedOAuth2Authentication(OAuth2Authentication):
def authenticate(self, request):
ret = super().authenticate(request)
# sha256 the bearer token. We store the hash in the database
# and this gives us a place to hash the incoming token for comparison
did_hash_token = False
bearer_token = request.META.get('HTTP_AUTHORIZATION')
if bearer_token and bearer_token.lower().startswith('bearer '):
token_component = bearer_token.split(' ', 1)[1]
hashed = hash_string(token_component, hasher=hashlib.sha256, algo="sha256")
did_hash_token = True
request.META['HTTP_AUTHORIZATION'] = f"Bearer {hashed}"

# We don't /really/ want to modify the request, so after we're done authing,
# revert what we did above.
try:
ret = super().authenticate(request)
finally:
if did_hash_token:
request.META['HTTP_AUTHORIZATION'] = bearer_token

if ret:
user, token = ret
username = user.username if user else '<none>'
Expand Down
12 changes: 11 additions & 1 deletion ansible_base/oauth2_provider/fixtures.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import hashlib
from datetime import datetime, timezone

import pytest
from oauthlib.common import generate_token

from ansible_base.lib.testing.fixtures import copy_fixture
from ansible_base.lib.utils.hashing import hash_string
from ansible_base.lib.utils.response import get_relative_url
from ansible_base.oauth2_provider.models import OAuth2AccessToken, OAuth2Application

Expand Down Expand Up @@ -62,10 +64,18 @@ def oauth2_application_password(randname):

@pytest.fixture
def oauth2_admin_access_token(oauth2_application, admin_api_client, admin_user):
"""
3-tuple with (token object with hashed token, plaintext token, plaintext_refresh_token)
"""
url = get_relative_url('token-list')
response = admin_api_client.post(url, {'application': oauth2_application[0].pk})
assert response.status_code == 201
return OAuth2AccessToken.objects.get(token=response.data['token'])

plaintext_token = response.data['token']
plaintext_refresh_token = response.data['refresh_token']
hashed_token = hash_string(plaintext_token, hasher=hashlib.sha256, algo="sha256")
token = OAuth2AccessToken.objects.get(token=hashed_token)
return (token, plaintext_token, plaintext_refresh_token)


@copy_fixture(copies=3)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,5 @@ def __init__(self):
self.user = user

serializer_obj.context['request'] = FakeRequest()
token_record = serializer_obj.create(config)
self.stdout.write(token_record.token)
serializer_obj.create(config)
self.stdout.write(serializer_obj.unencrypted_token)
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from django.db import migrations

from ansible_base.oauth2_provider.migrations._utils import hash_tokens


class Migration(migrations.Migration):
dependencies = [
("dab_oauth2_provider", "0004_alter_oauth2accesstoken_scope"),
]

operations = [
migrations.RunPython(hash_tokens),
]
16 changes: 16 additions & 0 deletions ansible_base/oauth2_provider/migrations/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import hashlib

from ansible_base.lib.utils.hashing import hash_string


def hash_tokens(apps, schema_editor):
OAuth2AccessToken = apps.get_model("dab_oauth2_provider", "OAuth2AccessToken")
OAuth2RefreshToken = apps.get_model("dab_oauth2_provider", "OAuth2RefreshToken")
for model in (OAuth2AccessToken, OAuth2RefreshToken):
for token in model.objects.all():
# Never re-hash a hashed token
if token.token.startswith("$"):
continue
hashed = hash_string(token.token, hasher=hashlib.sha256, algo="sha256")
token.token = hashed
token.save()
4 changes: 4 additions & 0 deletions ansible_base/oauth2_provider/models/access_token.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import hashlib

import oauth2_provider.models as oauth2_models
from django.conf import settings
from django.core.exceptions import ValidationError
Expand All @@ -7,6 +9,7 @@
from oauthlib import oauth2

from ansible_base.lib.abstract_models.common import CommonModel
from ansible_base.lib.utils.hashing import hash_string
from ansible_base.lib.utils.models import prevent_search
from ansible_base.lib.utils.settings import get_setting
from ansible_base.oauth2_provider.utils import is_external_account
Expand Down Expand Up @@ -103,4 +106,5 @@ def validate_external_users(self):
def save(self, *args, **kwargs):
if not self.pk:
self.validate_external_users()
self.token = hash_string(self.token, hasher=hashlib.sha256, algo="sha256")
super().save(*args, **kwargs)
8 changes: 8 additions & 0 deletions ansible_base/oauth2_provider/models/refresh_token.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import hashlib

import oauth2_provider.models as oauth2_models
from django.conf import settings
from django.db import models
from django.utils.translation import gettext_lazy as _

from ansible_base.lib.abstract_models.common import CommonModel
from ansible_base.lib.utils.hashing import hash_string
from ansible_base.lib.utils.models import prevent_search

activitystream = object
Expand All @@ -21,3 +24,8 @@ class Meta(oauth2_models.AbstractRefreshToken.Meta):

token = prevent_search(models.CharField(max_length=255))
updated = None # Tracked in CommonModel with 'modified', no need for this

def save(self, *args, **kwargs):
if not self.pk:
self.token = hash_string(self.token, hasher=hashlib.sha256, algo="sha256")
super().save(*args, **kwargs)
52 changes: 29 additions & 23 deletions ansible_base/oauth2_provider/serializers/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@
logger = logging.getLogger("ansible_base.oauth2_provider.serializers.token")


class BaseOAuth2TokenSerializer(CommonModelSerializer):
class OAuth2TokenSerializer(CommonModelSerializer):
refresh_token = SerializerMethodField()
token = SerializerMethodField()

unencrypted_token = None # Only used in POST so we can return the token in the response
unencrypted_refresh_token = None # Only used in POST so we can return the refresh token in the response

class Meta:
model = OAuth2AccessToken
Expand All @@ -40,23 +42,23 @@ class Meta:
read_only_fields = ('user', 'token', 'expires', 'refresh_token')
extra_kwargs = {'scope': {'allow_null': False, 'required': False}, 'user': {'allow_null': False, 'required': True}}

def get_token(self, obj) -> str:
request = self.context.get('request')
try:
if request and request.method == 'POST':
return obj.token
else:
return ENCRYPTED_STRING
except ObjectDoesNotExist:
return ''
def to_representation(self, instance):
request = self.context.get('request', None)
ret = super().to_representation(instance)
if request and request.method == 'POST':
# If we're creating the token, show it. Otherwise, show the encrypted string.
ret['token'] = self.unencrypted_token
else:
ret['token'] = ENCRYPTED_STRING
return ret

def get_refresh_token(self, obj) -> Optional[str]:
request = self.context.get('request')
try:
if not obj.refresh_token:
return None
elif request and request.method == 'POST':
return getattr(obj.refresh_token, 'token', '')
return self.unencrypted_refresh_token
else:
return ENCRYPTED_STRING
except ObjectDoesNotExist:
Expand All @@ -78,26 +80,30 @@ def validate_scope(self, value):
raise ValidationError(_('Must be a simple space-separated string with allowed scopes {}.').format(SCOPES))
return value

def create(self, validated_data):
validated_data['user'] = self.context['request'].user
try:
return super().create(validated_data)
except AccessDeniedError as e:
raise PermissionDenied(str(e))


class OAuth2TokenSerializer(BaseOAuth2TokenSerializer):
def create(self, validated_data):
current_user = get_current_user()
validated_data['token'] = generate_token()
expires_delta = get_setting('OAUTH2_PROVIDER', {}).get('ACCESS_TOKEN_EXPIRE_SECONDS', 0)
if expires_delta == 0:
logger.warning("OAUTH2_PROVIDER.ACCESS_TOKEN_EXPIRE_SECONDS was set to 0, creating token that has already expired")
validated_data['expires'] = now() + timedelta(seconds=expires_delta)
obj = super().create(validated_data)
validated_data['user'] = self.context['request'].user
self.unencrypted_token = validated_data.get('token') # Before it is hashed

try:
obj = super().create(validated_data)
except AccessDeniedError as e:
raise PermissionDenied(str(e))

if obj.application and obj.application.user:
obj.user = obj.application.user
obj.save()
if obj.application:
OAuth2RefreshToken.objects.create(user=current_user, token=generate_token(), application=obj.application, access_token=obj)
self.unencrypted_refresh_token = generate_token()
OAuth2RefreshToken.objects.create(
user=current_user,
token=self.unencrypted_refresh_token,
application=obj.application,
access_token=obj,
)
return obj
23 changes: 21 additions & 2 deletions ansible_base/oauth2_provider/views/token.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import hashlib
from datetime import timedelta

from django.utils.timezone import now
from oauth2_provider import views as oauth_views
from oauthlib import oauth2
from rest_framework.viewsets import ModelViewSet

from ansible_base.lib.utils.hashing import hash_string
from ansible_base.lib.utils.settings import get_setting
from ansible_base.lib.utils.views.django_app_api import AnsibleBaseDjangoAppApiView
from ansible_base.oauth2_provider.models import OAuth2AccessToken, OAuth2RefreshToken
Expand All @@ -28,7 +30,8 @@ def create_token_response(self, request):
# This code detects and auto-expires them on refresh grant
# requests.
if request.POST.get('grant_type') == 'refresh_token' and 'refresh_token' in request.POST:
refresh_token = OAuth2RefreshToken.objects.filter(token=request.POST['refresh_token']).first()
hashed_refresh_token = hash_string(request.POST['refresh_token'], hasher=hashlib.sha256, algo="sha256")
refresh_token = OAuth2RefreshToken.objects.filter(token=hashed_refresh_token).first()
if refresh_token:
expire_seconds = get_setting('OAUTH2_PROVIDER', {}).get('REFRESH_TOKEN_EXPIRE_SECONDS', 0)
if refresh_token.created + timedelta(seconds=expire_seconds) < now():
Expand All @@ -38,7 +41,23 @@ def create_token_response(self, request):

# oauth2_provider.oauth2_backends.OAuthLibCore.create_token_response
# (we override this so we can implement our own error handling to be compatible with AWX)
uri, http_method, body, headers = core._extract_params(request)

# This is really, really ugly. Modify the request to hash the refresh_token
# but only long enough for the oauth lib to do its magic.
did_hash_refresh_token = False
old_post = request.POST
if 'refresh_token' in request.POST:
did_hash_refresh_token = True
request.POST = request.POST.copy() # so it's mutable
hashed_refresh_token = hash_string(request.POST['refresh_token'], hasher=hashlib.sha256, algo="sha256")
request.POST['refresh_token'] = hashed_refresh_token

try:
uri, http_method, body, headers = core._extract_params(request)
finally:
if did_hash_refresh_token:
request.POST = old_post

extra_credentials = core._get_extra_credentials(request)
try:
headers, body, status = core.server.create_token_response(uri, http_method, body, headers, extra_credentials)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def test_cleanup_expired_tokens(self, oauth2_admin_access_token):
attempt_cleanup(0, 0)

# Manually expire admin token
oauth2_admin_access_token.expires = datetime.datetime.fromtimestamp(0)
oauth2_admin_access_token.save()
oauth2_admin_access_token[0].expires = datetime.datetime.fromtimestamp(0)
oauth2_admin_access_token[0].save()

attempt_cleanup(1, 1)
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Python
import hashlib
import random
import string
from io import StringIO
Expand All @@ -10,6 +11,8 @@
from django.core.management import call_command
from django.core.management.base import CommandError

from ansible_base.lib.utils.hashing import hash_string
from ansible_base.lib.utils.response import get_relative_url
from ansible_base.oauth2_provider.models import OAuth2AccessToken

User = get_user_model()
Expand All @@ -34,11 +37,21 @@ def test_non_existing_user(self):
call_command('create_oauth2_token', arg, stdout=out)
assert 'The user does not exist.' in str(excinfo.value)

def test_correct_user(self, random_user):
def test_correct_user(self, random_user, unauthenticated_api_client):
user_username = random_user.username
with StringIO() as out:
arg = '--user=' + user_username
call_command('create_oauth2_token', arg, stdout=out)
generated_token = out.getvalue().strip()
assert OAuth2AccessToken.objects.filter(user=random_user, token=generated_token).count() == 1
assert OAuth2AccessToken.objects.get(user=random_user, token=generated_token).scope == 'write'

hashed_token = hash_string(generated_token, hasher=hashlib.sha256, algo="sha256")
assert OAuth2AccessToken.objects.filter(user=random_user, token=hashed_token).count() == 1
assert OAuth2AccessToken.objects.get(user=random_user, token=hashed_token).scope == 'write'

url = get_relative_url("user-me")
response = unauthenticated_api_client.get(
url,
headers={'Authorization': f'Bearer {generated_token}'},
)
assert response.status_code == 200
assert response.data['username'] == user_username
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_revoke_all_access_tokens(self, oauth2_admin_access_token, oauth2_user_a

def test_revoke_access_token_for_user(self, oauth2_admin_access_token, oauth2_user_application_token):
with StringIO() as out:
admin_username = oauth2_admin_access_token.user.username
admin_username = oauth2_admin_access_token[0].user.username
user_username = oauth2_user_application_token.user.username

assert OAuth2AccessToken.objects.count() == 2
Expand Down
Empty file.
Loading
Loading