Skip to content

Commit

Permalink
Merge pull request #30 from onaio/patch-28
Browse files Browse the repository at this point in the history
Patch #28: validate provider names irregardless of case
  • Loading branch information
moshthepitt authored May 14, 2019
2 parents 1960eff + 30048c5 commit 37375ba
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 7 deletions.
8 changes: 4 additions & 4 deletions superset_patchup/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import jwt
from flask_login import login_user

from superset_patchup.utils import is_safe_url
from superset_patchup.utils import is_safe_url, is_valid_provider


class AuthOAuthView(SupersetAuthOAuthView):
Expand Down Expand Up @@ -229,7 +229,7 @@ def oauth_user_info(self, provider, response=None):
# above)
email_base = app.config.get("PATCHUP_EMAIL_BASE")

if provider == "onadata":
if is_valid_provider(provider, "onadata"):
user = (self.appbuilder.sm.oauth_remotes[provider].get(
"api/v1/user.json").data)

Expand All @@ -245,7 +245,7 @@ def oauth_user_info(self, provider, response=None):
"last_name": user_data["last_name"],
}

if provider == "OpenSRP":
if is_valid_provider(provider, "OpenSRP"):
user_object = (self.appbuilder.sm.oauth_remotes[provider].get(
"user-details").data)
username = user_object["userName"]
Expand All @@ -261,7 +261,7 @@ def oauth_user_info(self, provider, response=None):

return result

if provider == "openlmis":
if is_valid_provider(provider, "openlmis"):
# get access token
my_token = self.oauth_tokengetter()[0]
# get referenceDataUserId
Expand Down
10 changes: 10 additions & 0 deletions superset_patchup/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,13 @@ def is_safe_url(target_url):
test_url = urlparse(urljoin(request.host_url, target_url))
return test_url.scheme in ("http",
"https") and ref_url.netloc == test_url.netloc


def is_valid_provider(user_input: str, static_provider: str) -> bool:
"""
Validate a user's provider input irrespectve of case
"""
try:
return user_input.lower() == static_provider.lower()
except AttributeError:
return False
17 changes: 16 additions & 1 deletion tests/oauth/test_oauth.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
This module tests oauth
"""
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, patch, call

from superset import app

Expand Down Expand Up @@ -325,3 +325,18 @@ def test_login_redirec(

oauth_view.login(provider="onadata")
mock_redirect.assert_called_once_with("/superset/dashboard/3")

@patch('superset_patchup.oauth.is_valid_provider')
def test_is_valid_provider_is_called_for_opendata(self, function_mock):
"""
Test that is_valid_provider function is called for all provider names
"""
function_mock.return_value = False
appbuilder = MagicMock()
csm = CustomSecurityManager(appbuilder=appbuilder)
csm.oauth_user_info(provider="Onadata")
assert call("Onadata", "onadata") in function_mock.call_args_list
csm.oauth_user_info(provider="opensrp")
assert call("opensrp", "OpenSRP") in function_mock.call_args_list
csm.oauth_user_info(provider="OPENLMIS")
assert call("OPENLMIS", "openlmis") in function_mock.call_args_list
15 changes: 13 additions & 2 deletions tests/utils/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""
This module tests utils
"""
from unittest.mock import patch
from unittest.mock import patch, MagicMock

from superset_patchup.utils import get_complex_env_var, is_safe_url
from superset_patchup.utils import get_complex_env_var, is_safe_url, is_valid_provider
from superset_patchup.oauth import CustomSecurityManager


class TestUtils:
Expand Down Expand Up @@ -55,3 +56,13 @@ def test_get_complex_env_var(self, mock):
bool_params = get_complex_env_var("PARAMS", default_params)
assert isinstance(bool_params, bool)
assert bool_params is True

def test_case_insensitivity_for_provider(self):
"""
Test that provider information form user can be case insesitive,
to static standard strings that they will be checked against
"""
assert is_valid_provider("opensrp", "OpenSRP")
assert is_valid_provider("OnaData", 'onadata')
assert is_valid_provider("OpenlMis", "openlmis")
assert not is_valid_provider("oensrp", "OpenSrp")

0 comments on commit 37375ba

Please sign in to comment.