-
Notifications
You must be signed in to change notification settings - Fork 20
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
1 changed file
with
42 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,6 @@ | ||
import base64, random, hashlib | ||
import base64 | ||
import random | ||
import hashlib | ||
from datetime import timedelta | ||
from django.test import TestCase | ||
from django.utils import timezone | ||
|
@@ -19,8 +21,15 @@ def setUp(self, oauth2_settings=oauth2_settings): | |
""" | ||
Create a demo user, an OAuth Application and an access token for use in testing. | ||
""" | ||
self.test_user = User.objects.create_user(username="oauth_test_user",email= "[email protected]", password="123456") | ||
self.dev_user = User.objects.create_user(username="oauth_dev_user",email= "[email protected]", password="123456") | ||
self.test_user = User.objects.create_user( | ||
username="oauth_test_user", | ||
email="[email protected]", | ||
password="123456", | ||
) | ||
|
||
self.dev_user = User.objects.create_user( | ||
username="oauth_dev_user", email="[email protected]", password="123456" | ||
) | ||
|
||
self.oauth2_settings = oauth2_settings | ||
|
||
|
@@ -30,7 +39,7 @@ def setUp(self, oauth2_settings=oauth2_settings): | |
user=self.dev_user, | ||
client_type=Application.CLIENT_CONFIDENTIAL, | ||
authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE, | ||
client_secret = CLEARTEXT_SECRET, | ||
client_secret=CLEARTEXT_SECRET, | ||
) | ||
|
||
self.access_token = AccessToken.objects.create( | ||
|
@@ -39,7 +48,7 @@ def setUp(self, oauth2_settings=oauth2_settings): | |
expires=timezone.now() + timedelta(seconds=300), | ||
token="secret-access-token-key", | ||
application=self.application, | ||
) | ||
) | ||
|
||
def _create_authorization_header(self, token): | ||
return "Bearer {0}".format(token) | ||
|
@@ -56,8 +65,12 @@ def get_basic_auth_header(self, user, password): | |
} | ||
|
||
return auth_headers | ||
|
||
def get_random_string(self, length=12, allowed_chars="abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"): | ||
|
||
def get_random_string( | ||
self, | ||
length=12, | ||
allowed_chars="abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789", | ||
): | ||
""" | ||
Return a securely generated random string. | ||
The default length of 12 with the a-z, A-Z, 0-9 character set returns | ||
|
@@ -91,7 +104,9 @@ def generate_pkce_codes(self, algorithm, length=43): | |
verifier = self.get_random_string(length=length) | ||
if algorithm == "S256": | ||
challenge = ( | ||
base64.urlsafe_b64encode(hashlib.sha256(verifier.encode()).digest()).decode().rstrip("=") | ||
base64.urlsafe_b64encode(hashlib.sha256(verifier.encode()).digest()) | ||
.decode() | ||
.rstrip("=") | ||
) | ||
elif algorithm == "plain": | ||
challenge = verifier | ||
|
@@ -114,7 +129,9 @@ def get_auth(self): | |
"allow": True, | ||
} | ||
|
||
response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) | ||
response = self.client.post( | ||
reverse("oauth2_provider:authorize"), data=authcode_data | ||
) | ||
query_dict = parse_qs(urlparse(response["Location"]).query) | ||
return query_dict["code"].pop() | ||
|
||
|
@@ -133,7 +150,9 @@ def get_auth_pkce(self, code_challenge, code_challenge_method): | |
"code_challenge_method": code_challenge_method, | ||
} | ||
|
||
response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) | ||
response = self.client.post( | ||
reverse("oauth2_provider:authorize"), data=authcode_data | ||
) | ||
query_dict = parse_qs(urlparse(response["Location"]).query) | ||
return query_dict["code"].pop() | ||
|
||
|
@@ -156,9 +175,13 @@ def test_basic_auth(self): | |
"code": authorization_code, | ||
"redirect_uri": "http://example.org", | ||
} | ||
auth_headers = self.get_basic_auth_header(self.application.client_id, CLEARTEXT_SECRET) | ||
auth_headers = self.get_basic_auth_header( | ||
self.application.client_id, CLEARTEXT_SECRET | ||
) | ||
|
||
response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) | ||
response = self.client.post( | ||
reverse("oauth2_provider:token"), data=token_request_data, **auth_headers | ||
) | ||
self.assertEqual(response.status_code, 200) | ||
|
||
def test_secure_auth_pkce(self): | ||
|
@@ -180,7 +203,11 @@ def test_secure_auth_pkce(self): | |
"redirect_uri": "http://example.org", | ||
"code_verifier": code_verifier, | ||
} | ||
auth_headers = self.get_basic_auth_header(self.application.client_id, CLEARTEXT_SECRET) | ||
auth_headers = self.get_basic_auth_header( | ||
self.application.client_id, CLEARTEXT_SECRET | ||
) | ||
|
||
response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) | ||
self.assertEqual(response.status_code, 200) | ||
response = self.client.post( | ||
reverse("oauth2_provider:token"), data=token_request_data, **auth_headers | ||
) | ||
self.assertEqual(response.status_code, 200) |