Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up authentication #93

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 16 additions & 70 deletions trino/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,29 +12,19 @@

import abc
import os
from typing import Optional, Sequence, Type

from typing import Optional
from requests.auth import AuthBase
import requests
import requests.auth


class Authentication(metaclass=abc.ABCMeta):
@abc.abstractmethod
def set_http_session(self, http_session):
def set_http_session(self, http_session: requests.Session) -> None:
pass

@abc.abstractmethod
def set_client_session(self, client_session):
pass

@abc.abstractmethod
def setup(self):
pass

def get_exceptions(self):
return tuple()

def handle_err(self, error):
pass
def get_exceptions(self) -> Sequence[Type[Exception]]:
return ()


class KerberosAuthentication(Authentication):
Expand All @@ -60,10 +50,7 @@ def __init__(
self._delegate = delegate
self._ca_bundle = ca_bundle

def set_client_session(self, client_session):
pass

def set_http_session(self, http_session):
def set_http_session(self, http_session: requests.Session) -> None:
try:
import requests_kerberos
except ImportError:
Expand All @@ -83,57 +70,31 @@ def set_http_session(self, http_session):
)
if self._ca_bundle:
http_session.verify = self._ca_bundle
return http_session

def setup(self, trino_client):
self.set_client_session(trino_client.client_session)
self.set_http_session(trino_client.http_session)

def get_exceptions(self):
def get_exceptions(self) -> Sequence[Type[Exception]]:
try:
from requests_kerberos.exceptions import KerberosExchangeError

return (KerberosExchangeError,)
except ImportError:
raise RuntimeError("unable to import requests_kerberos")

def handle_error(self, handle_error):
pass


class BasicAuthentication(Authentication):
def __init__(self, username, password):
def __init__(self, username: str, password: str) -> None:
self._username = username
self._password = password

def set_client_session(self, client_session):
pass

def set_http_session(self, http_session):
try:
import requests.auth
except ImportError:
raise RuntimeError("unable to import requests.auth")

def set_http_session(self, http_session: requests.Session) -> None:
http_session.auth = requests.auth.HTTPBasicAuth(self._username, self._password)
return http_session

def setup(self, trino_client):
self.set_client_session(trino_client.client_session)
self.set_http_session(trino_client.http_session)

def get_exceptions(self):
return ()

def handle_error(self, handle_error):
pass


class _BearerAuth(AuthBase):
class _BearerAuth(requests.auth.AuthBase):
"""
Custom implementation of Authentication class for bearer token
Custom implementation of AuthBase class for bearer token
"""
def __init__(self, token):

def __init__(self, token: str) -> None:
self.token = token

def __call__(self, r):
Expand All @@ -142,23 +103,8 @@ def __call__(self, r):


class JWTAuthentication(Authentication):

def __init__(self, token):
def __init__(self, token: str) -> None:
self.token = token

def set_client_session(self, client_session):
pass

def set_http_session(self, http_session):
def set_http_session(self, http_session: requests.Session) -> None:
http_session.auth = _BearerAuth(self.token)
return http_session

def setup(self, trino_client):
self.set_client_session(trino_client.client_session)
self.set_http_session(trino_client.http_session)

def get_exceptions(self):
return ()

def handle_error(self, handle_error):
pass
3 changes: 2 additions & 1 deletion trino/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@

import trino.logging
from trino import constants, exceptions
from trino.auth import Authentication
from trino.transaction import NO_TRANSACTION

__all__ = ["TrinoQuery", "TrinoRequest"]
Expand Down Expand Up @@ -201,7 +202,7 @@ def __init__(
http_headers: Optional[Dict[str, str]] = None,
transaction_id: Optional[str] = NO_TRANSACTION,
http_scheme: str = constants.HTTP,
auth: Optional[Any] = constants.DEFAULT_AUTH,
auth: Optional[Authentication] = constants.DEFAULT_AUTH,
redirect_handler: Any = None,
max_attempts: int = MAX_ATTEMPTS,
request_timeout: Union[float, Tuple[float, float]] = constants.DEFAULT_REQUEST_TIMEOUT,
Expand Down
5 changes: 3 additions & 2 deletions trino/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Optional
from typing import Optional

from trino.auth import Authentication

DEFAULT_PORT = 8080
DEFAULT_SOURCE = "trino-python-client"
DEFAULT_CATALOG: Optional[str] = None
DEFAULT_SCHEMA: Optional[str] = None
DEFAULT_AUTH: Optional[Any] = None
DEFAULT_AUTH: Optional[Authentication] = None
DEFAULT_MAX_ATTEMPTS = 3
DEFAULT_REQUEST_TIMEOUT: float = 30.0

Expand Down