diff --git a/trino/auth.py b/trino/auth.py index 48128311..693da04e 100644 --- a/trino/auth.py +++ b/trino/auth.py @@ -12,29 +12,19 @@ import abc import os +from typing import Optional, Sequence, Type -from typing import Optional -from requests.auth import AuthBase +import requests +import requests.auth class Authentication(metaclass=abc.ABCMeta): @abc.abstractmethod - def set_http_session(self, http_session): + def set_http_session(self, http_session: requests.Session) -> None: pass - @abc.abstractmethod - def set_client_session(self, client_session): - pass - - @abc.abstractmethod - def setup(self): - pass - - def get_exceptions(self): - return tuple() - - def handle_err(self, error): - pass + def get_exceptions(self) -> Sequence[Type[Exception]]: + return () class KerberosAuthentication(Authentication): @@ -60,10 +50,7 @@ def __init__( self._delegate = delegate self._ca_bundle = ca_bundle - def set_client_session(self, client_session): - pass - - def set_http_session(self, http_session): + def set_http_session(self, http_session: requests.Session) -> None: try: import requests_kerberos except ImportError: @@ -83,13 +70,8 @@ def set_http_session(self, http_session): ) if self._ca_bundle: http_session.verify = self._ca_bundle - return http_session - def setup(self, trino_client): - self.set_client_session(trino_client.client_session) - self.set_http_session(trino_client.http_session) - - def get_exceptions(self): + def get_exceptions(self) -> Sequence[Type[Exception]]: try: from requests_kerberos.exceptions import KerberosExchangeError @@ -97,43 +79,22 @@ def get_exceptions(self): except ImportError: raise RuntimeError("unable to import requests_kerberos") - def handle_error(self, handle_error): - pass - class BasicAuthentication(Authentication): - def __init__(self, username, password): + def __init__(self, username: str, password: str) -> None: self._username = username self._password = password - def set_client_session(self, client_session): - pass - - def set_http_session(self, http_session): - try: - import requests.auth - except ImportError: - raise RuntimeError("unable to import requests.auth") - + def set_http_session(self, http_session: requests.Session) -> None: http_session.auth = requests.auth.HTTPBasicAuth(self._username, self._password) - return http_session - - def setup(self, trino_client): - self.set_client_session(trino_client.client_session) - self.set_http_session(trino_client.http_session) - - def get_exceptions(self): - return () - - def handle_error(self, handle_error): - pass -class _BearerAuth(AuthBase): +class _BearerAuth(requests.auth.AuthBase): """ - Custom implementation of Authentication class for bearer token + Custom implementation of AuthBase class for bearer token """ - def __init__(self, token): + + def __init__(self, token: str) -> None: self.token = token def __call__(self, r): @@ -142,23 +103,8 @@ def __call__(self, r): class JWTAuthentication(Authentication): - - def __init__(self, token): + def __init__(self, token: str) -> None: self.token = token - def set_client_session(self, client_session): - pass - - def set_http_session(self, http_session): + def set_http_session(self, http_session: requests.Session) -> None: http_session.auth = _BearerAuth(self.token) - return http_session - - def setup(self, trino_client): - self.set_client_session(trino_client.client_session) - self.set_http_session(trino_client.http_session) - - def get_exceptions(self): - return () - - def handle_error(self, handle_error): - pass diff --git a/trino/client.py b/trino/client.py index 4cd51236..e8af0afb 100644 --- a/trino/client.py +++ b/trino/client.py @@ -41,6 +41,7 @@ import trino.logging from trino import constants, exceptions +from trino.auth import Authentication from trino.transaction import NO_TRANSACTION __all__ = ["TrinoQuery", "TrinoRequest"] @@ -201,7 +202,7 @@ def __init__( http_headers: Optional[Dict[str, str]] = None, transaction_id: Optional[str] = NO_TRANSACTION, http_scheme: str = constants.HTTP, - auth: Optional[Any] = constants.DEFAULT_AUTH, + auth: Optional[Authentication] = constants.DEFAULT_AUTH, redirect_handler: Any = None, max_attempts: int = MAX_ATTEMPTS, request_timeout: Union[float, Tuple[float, float]] = constants.DEFAULT_REQUEST_TIMEOUT, diff --git a/trino/constants.py b/trino/constants.py index 9c81617f..5f08d43c 100644 --- a/trino/constants.py +++ b/trino/constants.py @@ -10,14 +10,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional +from typing import Optional +from trino.auth import Authentication DEFAULT_PORT = 8080 DEFAULT_SOURCE = "trino-python-client" DEFAULT_CATALOG: Optional[str] = None DEFAULT_SCHEMA: Optional[str] = None -DEFAULT_AUTH: Optional[Any] = None +DEFAULT_AUTH: Optional[Authentication] = None DEFAULT_MAX_ATTEMPTS = 3 DEFAULT_REQUEST_TIMEOUT: float = 30.0