Skip to content

Commit

Permalink
Prefix token hashes
Browse files Browse the repository at this point in the history
We don't decode differently depending on the prefix right now,
but this gives us the ability to in the future if we ever need to.

Signed-off-by: Rick Elrod <[email protected]>
  • Loading branch information
relrod committed Nov 15, 2024
1 parent e9148d6 commit c78e1ea
Show file tree
Hide file tree
Showing 9 changed files with 20 additions and 15 deletions.
9 changes: 7 additions & 2 deletions ansible_base/lib/utils/hashing.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,16 @@ def hash_serializer_data(instance: Model, serializer: Type[Serializer], field: O
return hasher(metadata_json).hexdigest()


def hash_string(inp: str, hasher: Callable = hashlib.sha256):
def hash_string(inp: str, hasher: Callable = hashlib.sha256, algo=""):
"""
Takes a string and hashes it with the given hasher function.
If algo is given, it is prepended to the hash between dollar signs ($)
before the hash is returned.
NOTE: There is no salt or pepper here, so this is not secure for passwords.
It is, however, useful for *random* strings like tokens, that need to be secured.
"""
return hasher(inp.encode("utf-8")).hexdigest()
hash = hasher(inp.encode("utf-8")).hexdigest()
if algo:
return f"${algo}${hash}"
return hash
2 changes: 1 addition & 1 deletion ansible_base/oauth2_provider/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def authenticate(self, request):
bearer_token = request.META.get('HTTP_AUTHORIZATION')
if bearer_token and bearer_token.lower().startswith('bearer '):
token_component = bearer_token.split(' ', 1)[1]
hashed = hash_string(token_component, hasher=hashlib.sha256)
hashed = hash_string(token_component, hasher=hashlib.sha256, algo="sha256")
did_hash_token = True
request.META['HTTP_AUTHORIZATION'] = f"Bearer {hashed}"

Expand Down
2 changes: 1 addition & 1 deletion ansible_base/oauth2_provider/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def oauth2_admin_access_token(oauth2_application, admin_api_client, admin_user):

plaintext_token = response.data['token']
plaintext_refresh_token = response.data['refresh_token']
hashed_token = hash_string(plaintext_token, hasher=hashlib.sha256)
hashed_token = hash_string(plaintext_token, hasher=hashlib.sha256, algo="sha256")
token = OAuth2AccessToken.objects.get(token=hashed_token)
return (token, plaintext_token, plaintext_refresh_token)

Expand Down
4 changes: 2 additions & 2 deletions ansible_base/oauth2_provider/migrations/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ def hash_tokens(apps, schema_editor):
for model in (OAuth2AccessToken, OAuth2RefreshToken):
for token in model.objects.all():
# Never re-hash a hashed token
if len(token.token) == 64:
if token.token.startswith("$"):
continue
hashed = hash_string(token.token, hasher=hashlib.sha256)
hashed = hash_string(token.token, hasher=hashlib.sha256, algo="sha256")
token.token = hashed
token.save()
2 changes: 1 addition & 1 deletion ansible_base/oauth2_provider/models/access_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,5 +106,5 @@ def validate_external_users(self):
def save(self, *args, **kwargs):
if not self.pk:
self.validate_external_users()
self.token = hash_string(self.token, hasher=hashlib.sha256)
self.token = hash_string(self.token, hasher=hashlib.sha256, algo="sha256")
super().save(*args, **kwargs)
2 changes: 1 addition & 1 deletion ansible_base/oauth2_provider/models/refresh_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,5 @@ class Meta(oauth2_models.AbstractRefreshToken.Meta):

def save(self, *args, **kwargs):
if not self.pk:
self.token = hash_string(self.token, hasher=hashlib.sha256)
self.token = hash_string(self.token, hasher=hashlib.sha256, algo="sha256")
super().save(*args, **kwargs)
4 changes: 2 additions & 2 deletions ansible_base/oauth2_provider/views/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def create_token_response(self, request):
# This code detects and auto-expires them on refresh grant
# requests.
if request.POST.get('grant_type') == 'refresh_token' and 'refresh_token' in request.POST:
hashed_refresh_token = hash_string(request.POST['refresh_token'], hasher=hashlib.sha256)
hashed_refresh_token = hash_string(request.POST['refresh_token'], hasher=hashlib.sha256, algo="sha256")
refresh_token = OAuth2RefreshToken.objects.filter(token=hashed_refresh_token).first()
if refresh_token:
expire_seconds = get_setting('OAUTH2_PROVIDER', {}).get('REFRESH_TOKEN_EXPIRE_SECONDS', 0)
Expand All @@ -49,7 +49,7 @@ def create_token_response(self, request):
if 'refresh_token' in request.POST:
did_hash_refresh_token = True
request.POST = request.POST.copy() # so it's mutable
hashed_refresh_token = hash_string(request.POST['refresh_token'], hasher=hashlib.sha256)
hashed_refresh_token = hash_string(request.POST['refresh_token'], hasher=hashlib.sha256, algo="sha256")
request.POST['refresh_token'] = hashed_refresh_token

try:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_correct_user(self, random_user, unauthenticated_api_client):
call_command('create_oauth2_token', arg, stdout=out)
generated_token = out.getvalue().strip()

hashed_token = hash_string(generated_token, hasher=hashlib.sha256)
hashed_token = hash_string(generated_token, hasher=hashlib.sha256, algo="sha256")
assert OAuth2AccessToken.objects.filter(user=random_user, token=hashed_token).count() == 1
assert OAuth2AccessToken.objects.get(user=random_user, token=hashed_token).scope == 'write'

Expand Down
8 changes: 4 additions & 4 deletions test_app/tests/oauth2_provider/views/test_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,9 +271,9 @@ def test_oauth2_token_create(oauth2_application, admin_api_client, admin_user):
assert response.status_code == 201
assert 'modified' in response.data and response.data['modified'] is not None
assert 'updated' not in response.data
hashed_token = hash_string(response.data['token'], hasher=hashlib.sha256)
hashed_token = hash_string(response.data['token'], hasher=hashlib.sha256, algo="sha256")
token = OAuth2AccessToken.objects.get(token=hashed_token)
hashed_refresh_token = hash_string(response.data['refresh_token'], hasher=hashlib.sha256)
hashed_refresh_token = hash_string(response.data['refresh_token'], hasher=hashlib.sha256, algo="sha256")
refresh_token = OAuth2RefreshToken.objects.get(token=hashed_refresh_token)
assert token.application == oauth2_application
assert refresh_token.application == oauth2_application
Expand Down Expand Up @@ -372,9 +372,9 @@ def test_oauth2_refresh_access_token(oauth2_application, oauth2_admin_access_tok

json_resp = json.loads(resp.content)
new_token = json_resp['access_token']
new_token_hashed = hash_string(new_token, hasher=hashlib.sha256)
new_token_hashed = hash_string(new_token, hasher=hashlib.sha256, algo="sha256")
new_refresh_token = json_resp['refresh_token']
new_refresh_token_hashed = hash_string(new_refresh_token, hasher=hashlib.sha256)
new_refresh_token_hashed = hash_string(new_refresh_token, hasher=hashlib.sha256, algo="sha256")

assert OAuth2AccessToken.objects.filter(token=new_token_hashed).count() == 1
# checks that RefreshTokens are rotated (new RefreshToken issued)
Expand Down

0 comments on commit c78e1ea

Please sign in to comment.