From 65a6453c8380b3b1d9e25ace3cb17143ef5512ea Mon Sep 17 00:00:00 2001 From: patrykkotlowski-dsstream Date: Fri, 30 Aug 2024 14:51:50 +0200 Subject: [PATCH 1/9] Add option for custom auth --- backend/chainlit/auth.py | 2 ++ backend/chainlit/config.py | 4 ++++ backend/chainlit/oauth_providers.py | 29 ++++++++++++++++++----------- 3 files changed, 24 insertions(+), 11 deletions(-) diff --git a/backend/chainlit/auth.py b/backend/chainlit/auth.py index 981d33abe4..d7c478b7ee 100644 --- a/backend/chainlit/auth.py +++ b/backend/chainlit/auth.py @@ -88,4 +88,6 @@ async def get_current_user(token: str = Depends(reuseable_oauth)): if not require_login(): return None + if config.code.custom_authenticate_user: + return await config.code.custom_authenticate_user(token) return await authenticate_user(token) diff --git a/backend/chainlit/config.py b/backend/chainlit/config.py index 5700479677..26dc93b117 100644 --- a/backend/chainlit/config.py +++ b/backend/chainlit/config.py @@ -18,6 +18,7 @@ import tomli from chainlit.logger import logger +from chainlit.oauth_providers import OAuthProvider from chainlit.translations import lint_translation_json from chainlit.version import __version__ from dataclasses_json import DataClassJsonMixin @@ -291,6 +292,9 @@ class CodeSettings: oauth_callback: Optional[ Callable[[str, str, Dict[str, str], "User"], Awaitable[Optional["User"]]] ] = None + # Callbacks for authenticate mechanism + custom_authenticate_user: Optional[Callable[[str], "User"]] + custom_oauth_provider: Optional[OAuthProvider] on_logout: Optional[Callable[["Request", "Response"], Any]] = None on_stop: Optional[Callable[[], Any]] = None on_chat_start: Optional[Callable[[], Any]] = None diff --git a/backend/chainlit/oauth_providers.py b/backend/chainlit/oauth_providers.py index fe019859b1..89839e20b0 100644 --- a/backend/chainlit/oauth_providers.py +++ b/backend/chainlit/oauth_providers.py @@ -8,6 +8,8 @@ from chainlit.user import User from fastapi import HTTPException +from chainlit import config + class OAuthProvider: id: str @@ -621,17 +623,22 @@ async def get_user_info(self, token: str): return (gitlab_user, user) -providers = [ - GithubOAuthProvider(), - GoogleOAuthProvider(), - AzureADOAuthProvider(), - AzureADHybridOAuthProvider(), - OktaOAuthProvider(), - Auth0OAuthProvider(), - DescopeOAuthProvider(), - AWSCognitoOAuthProvider(), - GitlabOAuthProvider(), -] +providers = ( + [ + GithubOAuthProvider(), + GoogleOAuthProvider(), + AzureADOAuthProvider(), + AzureADHybridOAuthProvider(), + OktaOAuthProvider(), + Auth0OAuthProvider(), + DescopeOAuthProvider(), + AWSCognitoOAuthProvider(), + GitlabOAuthProvider(), + ] + + [config.code.custom_oauth_provider()] + if config.code.custom_oauth_provider + else [] +) def get_oauth_provider(provider: str) -> Optional[OAuthProvider]: From 477d81d115b68fc967ff1bc91d1d422a2a9b5e5a Mon Sep 17 00:00:00 2001 From: patrykkotlowski-dsstream Date: Mon, 2 Sep 2024 09:11:10 +0200 Subject: [PATCH 2/9] Add custom auth --- backend/chainlit/config.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/backend/chainlit/config.py b/backend/chainlit/config.py index 26dc93b117..874168ec1f 100644 --- a/backend/chainlit/config.py +++ b/backend/chainlit/config.py @@ -13,6 +13,7 @@ List, Literal, Optional, + Type, Union, ) @@ -293,8 +294,8 @@ class CodeSettings: Callable[[str, str, Dict[str, str], "User"], Awaitable[Optional["User"]]] ] = None # Callbacks for authenticate mechanism - custom_authenticate_user: Optional[Callable[[str], "User"]] - custom_oauth_provider: Optional[OAuthProvider] + custom_authenticate_user: Optional[Callable[[str], Awaitable["User"]]] = None + custom_oauth_provider: Optional[Type[OAuthProvider]] = None on_logout: Optional[Callable[["Request", "Response"], Any]] = None on_stop: Optional[Callable[[], Any]] = None on_chat_start: Optional[Callable[[], Any]] = None From d568c64e490d477d5c4c586d0c80c95643820f0d Mon Sep 17 00:00:00 2001 From: patrykkotlowski-dsstream Date: Mon, 2 Sep 2024 09:14:43 +0200 Subject: [PATCH 3/9] Refactoring --- backend/chainlit/auth.py | 6 +++--- backend/chainlit/oauth_providers.py | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/backend/chainlit/auth.py b/backend/chainlit/auth.py index d7c478b7ee..9bd2073b55 100644 --- a/backend/chainlit/auth.py +++ b/backend/chainlit/auth.py @@ -42,9 +42,9 @@ def get_configuration(): "requireLogin": require_login(), "passwordAuth": config.code.password_auth_callback is not None, "headerAuth": config.code.header_auth_callback is not None, - "oauthProviders": get_configured_oauth_providers() - if is_oauth_enabled() - else [], + "oauthProviders": ( + get_configured_oauth_providers() if is_oauth_enabled() else [] + ), } diff --git a/backend/chainlit/oauth_providers.py b/backend/chainlit/oauth_providers.py index 89839e20b0..eecf3547fc 100644 --- a/backend/chainlit/oauth_providers.py +++ b/backend/chainlit/oauth_providers.py @@ -4,12 +4,11 @@ from typing import Dict, List, Optional, Tuple import httpx +from chainlit.config import config from chainlit.secret import random_secret from chainlit.user import User from fastapi import HTTPException -from chainlit import config - class OAuthProvider: id: str @@ -623,6 +622,7 @@ async def get_user_info(self, token: str): return (gitlab_user, user) +custom_oauth = config.code.custom_oauth_provider # type: ignore providers = ( [ GithubOAuthProvider(), @@ -635,8 +635,8 @@ async def get_user_info(self, token: str): AWSCognitoOAuthProvider(), GitlabOAuthProvider(), ] - + [config.code.custom_oauth_provider()] - if config.code.custom_oauth_provider + + [custom_oauth()] + if custom_oauth else [] ) From d084f58326571416a40d20d0983774d8f53b980b Mon Sep 17 00:00:00 2001 From: patrykkotlowski-dsstream Date: Mon, 2 Sep 2024 09:33:41 +0200 Subject: [PATCH 4/9] Refactoring --- backend/chainlit/__init__.py | 1 + backend/chainlit/auth.py | 2 +- backend/chainlit/config.py | 2 +- backend/chainlit/oauth_providers.py | 652 ------------------ backend/chainlit/oauth_providers/__init__.py | 0 .../oauth_providers/auth0_oauth_provider.py | 69 ++ .../aws_cognito_oauth_provider.py | 72 ++ .../azure_ad_hubrid_oauth_provider.py | 92 +++ .../azure_ad_oauth_provider.py | 89 +++ .../oauth_providers/descope_oauth_provider.py | 61 ++ backend/chainlit/oauth_providers/github.py | 63 ++ .../oauth_providers/gitlab_oauth_provider.py | 67 ++ backend/chainlit/oauth_providers/google.py | 58 ++ .../oauth_providers/oauth_provider.py | 22 + .../oauth_providers/okta_oauth_provider.py | 78 +++ backend/chainlit/oauth_providers/providers.py | 44 ++ backend/chainlit/server.py | 10 + 17 files changed, 728 insertions(+), 654 deletions(-) delete mode 100644 backend/chainlit/oauth_providers.py create mode 100644 backend/chainlit/oauth_providers/__init__.py create mode 100644 backend/chainlit/oauth_providers/auth0_oauth_provider.py create mode 100644 backend/chainlit/oauth_providers/aws_cognito_oauth_provider.py create mode 100644 backend/chainlit/oauth_providers/azure_ad_hubrid_oauth_provider.py create mode 100644 backend/chainlit/oauth_providers/azure_ad_oauth_provider.py create mode 100644 backend/chainlit/oauth_providers/descope_oauth_provider.py create mode 100644 backend/chainlit/oauth_providers/github.py create mode 100644 backend/chainlit/oauth_providers/gitlab_oauth_provider.py create mode 100644 backend/chainlit/oauth_providers/google.py create mode 100644 backend/chainlit/oauth_providers/oauth_provider.py create mode 100644 backend/chainlit/oauth_providers/okta_oauth_provider.py create mode 100644 backend/chainlit/oauth_providers/providers.py diff --git a/backend/chainlit/__init__.py b/backend/chainlit/__init__.py index 0506ef38f3..f53c8f6bff 100644 --- a/backend/chainlit/__init__.py +++ b/backend/chainlit/__init__.py @@ -41,6 +41,7 @@ ErrorMessage, Message, ) +from chainlit.oauth_providers.providers import get_configured_oauth_providers from chainlit.step import Step, step from chainlit.sync import make_async, run_sync from chainlit.types import AudioChunk, ChatProfile, Starter diff --git a/backend/chainlit/auth.py b/backend/chainlit/auth.py index 9bd2073b55..049e94f158 100644 --- a/backend/chainlit/auth.py +++ b/backend/chainlit/auth.py @@ -5,7 +5,7 @@ import jwt from chainlit.config import config from chainlit.data import get_data_layer -from chainlit.oauth_providers import get_configured_oauth_providers +from chainlit.oauth_providers.providers import get_configured_oauth_providers from chainlit.user import User from fastapi import Depends, HTTPException from fastapi.security import OAuth2PasswordBearer diff --git a/backend/chainlit/config.py b/backend/chainlit/config.py index 874168ec1f..e8cefb9572 100644 --- a/backend/chainlit/config.py +++ b/backend/chainlit/config.py @@ -19,7 +19,7 @@ import tomli from chainlit.logger import logger -from chainlit.oauth_providers import OAuthProvider +from chainlit.oauth_providers.oauth_provider import OAuthProvider from chainlit.translations import lint_translation_json from chainlit.version import __version__ from dataclasses_json import DataClassJsonMixin diff --git a/backend/chainlit/oauth_providers.py b/backend/chainlit/oauth_providers.py deleted file mode 100644 index eecf3547fc..0000000000 --- a/backend/chainlit/oauth_providers.py +++ /dev/null @@ -1,652 +0,0 @@ -import base64 -import os -import urllib.parse -from typing import Dict, List, Optional, Tuple - -import httpx -from chainlit.config import config -from chainlit.secret import random_secret -from chainlit.user import User -from fastapi import HTTPException - - -class OAuthProvider: - id: str - env: List[str] - client_id: str - client_secret: str - authorize_url: str - authorize_params: Dict[str, str] - - def is_configured(self): - return all([os.environ.get(env) for env in self.env]) - - async def get_token(self, code: str, url: str) -> str: - raise NotImplementedError() - - async def get_user_info(self, token: str) -> Tuple[Dict[str, str], User]: - raise NotImplementedError() - - -class GithubOAuthProvider(OAuthProvider): - id = "github" - env = ["OAUTH_GITHUB_CLIENT_ID", "OAUTH_GITHUB_CLIENT_SECRET"] - authorize_url = "https://github.com/login/oauth/authorize" - - def __init__(self): - self.client_id = os.environ.get("OAUTH_GITHUB_CLIENT_ID") - self.client_secret = os.environ.get("OAUTH_GITHUB_CLIENT_SECRET") - self.authorize_params = { - "scope": "user:email", - } - - async def get_token(self, code: str, url: str): - payload = { - "client_id": self.client_id, - "client_secret": self.client_secret, - "code": code, - } - async with httpx.AsyncClient() as client: - response = await client.post( - "https://github.com/login/oauth/access_token", - data=payload, - ) - response.raise_for_status() - content = urllib.parse.parse_qs(response.text) - token = content.get("access_token", [""])[0] - if not token: - raise HTTPException( - status_code=400, detail="Failed to get the access token" - ) - return token - - async def get_user_info(self, token: str): - async with httpx.AsyncClient() as client: - user_response = await client.get( - "https://api.github.com/user", - headers={"Authorization": f"token {token}"}, - ) - user_response.raise_for_status() - github_user = user_response.json() - - emails_response = await client.get( - "https://api.github.com/user/emails", - headers={"Authorization": f"token {token}"}, - ) - emails_response.raise_for_status() - emails = emails_response.json() - - github_user.update({"emails": emails}) - user = User( - identifier=github_user["login"], - metadata={"image": github_user["avatar_url"], "provider": "github"}, - ) - return (github_user, user) - - -class GoogleOAuthProvider(OAuthProvider): - id = "google" - env = ["OAUTH_GOOGLE_CLIENT_ID", "OAUTH_GOOGLE_CLIENT_SECRET"] - authorize_url = "https://accounts.google.com/o/oauth2/v2/auth" - - def __init__(self): - self.client_id = os.environ.get("OAUTH_GOOGLE_CLIENT_ID") - self.client_secret = os.environ.get("OAUTH_GOOGLE_CLIENT_SECRET") - self.authorize_params = { - "scope": "https://www.googleapis.com/auth/userinfo.profile https://www.googleapis.com/auth/userinfo.email", - "response_type": "code", - "access_type": "offline", - } - - async def get_token(self, code: str, url: str): - payload = { - "client_id": self.client_id, - "client_secret": self.client_secret, - "code": code, - "grant_type": "authorization_code", - "redirect_uri": url, - } - async with httpx.AsyncClient() as client: - response = await client.post( - "https://oauth2.googleapis.com/token", - data=payload, - ) - response.raise_for_status() - json = response.json() - token = json.get("access_token") - if not token: - raise httpx.HTTPStatusError( - "Failed to get the access token", - request=response.request, - response=response, - ) - return token - - async def get_user_info(self, token: str): - async with httpx.AsyncClient() as client: - response = await client.get( - "https://www.googleapis.com/userinfo/v2/me", - headers={"Authorization": f"Bearer {token}"}, - ) - response.raise_for_status() - google_user = response.json() - user = User( - identifier=google_user["email"], - metadata={"image": google_user["picture"], "provider": "google"}, - ) - return (google_user, user) - - -class AzureADOAuthProvider(OAuthProvider): - id = "azure-ad" - env = [ - "OAUTH_AZURE_AD_CLIENT_ID", - "OAUTH_AZURE_AD_CLIENT_SECRET", - "OAUTH_AZURE_AD_TENANT_ID", - ] - authorize_url = ( - f"https://login.microsoftonline.com/{os.environ.get('OAUTH_AZURE_AD_TENANT_ID', '')}/oauth2/v2.0/authorize" - if os.environ.get("OAUTH_AZURE_AD_ENABLE_SINGLE_TENANT") - else "https://login.microsoftonline.com/common/oauth2/v2.0/authorize" - ) - token_url = ( - f"https://login.microsoftonline.com/{os.environ.get('OAUTH_AZURE_AD_TENANT_ID', '')}/oauth2/v2.0/token" - if os.environ.get("OAUTH_AZURE_AD_ENABLE_SINGLE_TENANT") - else "https://login.microsoftonline.com/common/oauth2/v2.0/token" - ) - - def __init__(self): - self.client_id = os.environ.get("OAUTH_AZURE_AD_CLIENT_ID") - self.client_secret = os.environ.get("OAUTH_AZURE_AD_CLIENT_SECRET") - self.authorize_params = { - "tenant": os.environ.get("OAUTH_AZURE_AD_TENANT_ID"), - "response_type": "code", - "scope": "https://graph.microsoft.com/User.Read", - "response_mode": "query", - } - - async def get_token(self, code: str, url: str): - payload = { - "client_id": self.client_id, - "client_secret": self.client_secret, - "code": code, - "grant_type": "authorization_code", - "redirect_uri": url, - } - async with httpx.AsyncClient() as client: - response = await client.post( - self.token_url, - data=payload, - ) - response.raise_for_status() - json = response.json() - - token = json["access_token"] - if not token: - raise HTTPException( - status_code=400, detail="Failed to get the access token" - ) - return token - - async def get_user_info(self, token: str): - async with httpx.AsyncClient() as client: - response = await client.get( - "https://graph.microsoft.com/v1.0/me", - headers={"Authorization": f"Bearer {token}"}, - ) - response.raise_for_status() - - azure_user = response.json() - - try: - photo_response = await client.get( - "https://graph.microsoft.com/v1.0/me/photos/48x48/$value", - headers={"Authorization": f"Bearer {token}"}, - ) - photo_data = await photo_response.aread() - base64_image = base64.b64encode(photo_data) - azure_user["image"] = ( - f"data:{photo_response.headers['Content-Type']};base64,{base64_image.decode('utf-8')}" - ) - except Exception as e: - # Ignore errors getting the photo - pass - - user = User( - identifier=azure_user["userPrincipalName"], - metadata={"image": azure_user.get("image"), "provider": "azure-ad"}, - ) - return (azure_user, user) - - -class AzureADHybridOAuthProvider(OAuthProvider): - id = "azure-ad-hybrid" - env = [ - "OAUTH_AZURE_AD_HYBRID_CLIENT_ID", - "OAUTH_AZURE_AD_HYBRID_CLIENT_SECRET", - "OAUTH_AZURE_AD_HYBRID_TENANT_ID", - ] - authorize_url = ( - f"https://login.microsoftonline.com/{os.environ.get('OAUTH_AZURE_AD_HYBRID_TENANT_ID', '')}/oauth2/v2.0/authorize" - if os.environ.get("OAUTH_AZURE_AD_HYBRID_ENABLE_SINGLE_TENANT") - else "https://login.microsoftonline.com/common/oauth2/v2.0/authorize" - ) - token_url = ( - f"https://login.microsoftonline.com/{os.environ.get('OAUTH_AZURE_AD_HYBRID_TENANT_ID', '')}/oauth2/v2.0/token" - if os.environ.get("OAUTH_AZURE_AD_HYBRID_ENABLE_SINGLE_TENANT") - else "https://login.microsoftonline.com/common/oauth2/v2.0/token" - ) - - def __init__(self): - self.client_id = os.environ.get("OAUTH_AZURE_AD_HYBRID_CLIENT_ID") - self.client_secret = os.environ.get("OAUTH_AZURE_AD_HYBRID_CLIENT_SECRET") - nonce = random_secret(16) - self.authorize_params = { - "tenant": os.environ.get("OAUTH_AZURE_AD_HYBRID_TENANT_ID"), - "response_type": "code id_token", - "scope": "https://graph.microsoft.com/User.Read https://graph.microsoft.com/openid", - "response_mode": "form_post", - "nonce": nonce, - } - - async def get_token(self, code: str, url: str): - payload = { - "client_id": self.client_id, - "client_secret": self.client_secret, - "code": code, - "grant_type": "authorization_code", - "redirect_uri": url, - } - async with httpx.AsyncClient() as client: - response = await client.post( - self.token_url, - data=payload, - ) - response.raise_for_status() - json = response.json() - - token = json["access_token"] - if not token: - raise HTTPException( - status_code=400, detail="Failed to get the access token" - ) - return token - - async def get_user_info(self, token: str): - async with httpx.AsyncClient() as client: - response = await client.get( - "https://graph.microsoft.com/v1.0/me", - headers={"Authorization": f"Bearer {token}"}, - ) - response.raise_for_status() - - azure_user = response.json() - - try: - photo_response = await client.get( - "https://graph.microsoft.com/v1.0/me/photos/48x48/$value", - headers={"Authorization": f"Bearer {token}"}, - ) - photo_data = await photo_response.aread() - base64_image = base64.b64encode(photo_data) - azure_user["image"] = ( - f"data:{photo_response.headers['Content-Type']};base64,{base64_image.decode('utf-8')}" - ) - except Exception as e: - # Ignore errors getting the photo - pass - - user = User( - identifier=azure_user["userPrincipalName"], - metadata={"image": azure_user.get("image"), "provider": "azure-ad"}, - ) - return (azure_user, user) - - -class OktaOAuthProvider(OAuthProvider): - id = "okta" - env = [ - "OAUTH_OKTA_CLIENT_ID", - "OAUTH_OKTA_CLIENT_SECRET", - "OAUTH_OKTA_DOMAIN", - ] - # Avoid trailing slash in domain if supplied - domain = f"https://{os.environ.get('OAUTH_OKTA_DOMAIN', '').rstrip('/')}" - - def __init__(self): - self.client_id = os.environ.get("OAUTH_OKTA_CLIENT_ID") - self.client_secret = os.environ.get("OAUTH_OKTA_CLIENT_SECRET") - self.authorization_server_id = os.environ.get( - "OAUTH_OKTA_AUTHORIZATION_SERVER_ID", "" - ) - self.authorize_url = ( - f"{self.domain}/oauth2{self.get_authorization_server_path()}/v1/authorize" - ) - self.authorize_params = { - "response_type": "code", - "scope": "openid profile email", - "response_mode": "query", - } - - def get_authorization_server_path(self): - if not self.authorization_server_id: - return "/default" - if self.authorization_server_id == "false": - return "" - return f"/{self.authorization_server_id}" - - async def get_token(self, code: str, url: str): - payload = { - "client_id": self.client_id, - "client_secret": self.client_secret, - "code": code, - "grant_type": "authorization_code", - "redirect_uri": url, - } - async with httpx.AsyncClient() as client: - response = await client.post( - f"{self.domain}/oauth2{self.get_authorization_server_path()}/v1/token", - data=payload, - ) - response.raise_for_status() - json_data = response.json() - - token = json_data.get("access_token") - if not token: - raise httpx.HTTPStatusError( - "Failed to get the access token", - request=response.request, - response=response, - ) - return token - - async def get_user_info(self, token: str): - async with httpx.AsyncClient() as client: - response = await client.get( - f"{self.domain}/oauth2{self.get_authorization_server_path()}/v1/userinfo", - headers={"Authorization": f"Bearer {token}"}, - ) - response.raise_for_status() - okta_user = response.json() - - user = User( - identifier=okta_user.get("email"), - metadata={"image": "", "provider": "okta"}, - ) - return (okta_user, user) - - -class Auth0OAuthProvider(OAuthProvider): - id = "auth0" - env = ["OAUTH_AUTH0_CLIENT_ID", "OAUTH_AUTH0_CLIENT_SECRET", "OAUTH_AUTH0_DOMAIN"] - - def __init__(self): - self.client_id = os.environ.get("OAUTH_AUTH0_CLIENT_ID") - self.client_secret = os.environ.get("OAUTH_AUTH0_CLIENT_SECRET") - # Ensure that the domain does not have a trailing slash - self.domain = f"https://{os.environ.get('OAUTH_AUTH0_DOMAIN', '').rstrip('/')}" - self.original_domain = ( - f"https://{os.environ.get('OAUTH_AUTH0_ORIGINAL_DOMAIN').rstrip('/')}" - if os.environ.get("OAUTH_AUTH0_ORIGINAL_DOMAIN") - else self.domain - ) - - self.authorize_url = f"{self.domain}/authorize" - - self.authorize_params = { - "response_type": "code", - "scope": "openid profile email", - "audience": f"{self.original_domain}/userinfo", - } - - async def get_token(self, code: str, url: str): - payload = { - "client_id": self.client_id, - "client_secret": self.client_secret, - "code": code, - "grant_type": "authorization_code", - "redirect_uri": url, - } - async with httpx.AsyncClient() as client: - response = await client.post( - f"{self.domain}/oauth/token", - data=payload, - ) - response.raise_for_status() - json_content = response.json() - token = json_content.get("access_token") - if not token: - raise HTTPException( - status_code=400, detail="Failed to get the access token" - ) - return token - - async def get_user_info(self, token: str): - async with httpx.AsyncClient() as client: - response = await client.get( - f"{self.original_domain}/userinfo", - headers={"Authorization": f"Bearer {token}"}, - ) - response.raise_for_status() - auth0_user = response.json() - user = User( - identifier=auth0_user.get("email"), - metadata={ - "image": auth0_user.get("picture", ""), - "provider": "auth0", - }, - ) - return (auth0_user, user) - - -class DescopeOAuthProvider(OAuthProvider): - id = "descope" - env = ["OAUTH_DESCOPE_CLIENT_ID", "OAUTH_DESCOPE_CLIENT_SECRET"] - # Ensure that the domain does not have a trailing slash - domain = f"https://api.descope.com/oauth2/v1" - - authorize_url = f"{domain}/authorize" - - def __init__(self): - self.client_id = os.environ.get("OAUTH_DESCOPE_CLIENT_ID") - self.client_secret = os.environ.get("OAUTH_DESCOPE_CLIENT_SECRET") - self.authorize_params = { - "response_type": "code", - "scope": "openid profile email", - "audience": f"{self.domain}/userinfo", - } - - async def get_token(self, code: str, url: str): - payload = { - "client_id": self.client_id, - "client_secret": self.client_secret, - "code": code, - "grant_type": "authorization_code", - "redirect_uri": url, - } - async with httpx.AsyncClient() as client: - response = await client.post( - f"{self.domain}/token", - data=payload, - ) - response.raise_for_status() - json_content = response.json() - token = json_content.get("access_token") - if not token: - raise httpx.HTTPStatusError( - "Failed to get the access token", - request=response.request, - response=response, - ) - return token - - async def get_user_info(self, token: str): - async with httpx.AsyncClient() as client: - response = await client.get( - f"{self.domain}/userinfo", headers={"Authorization": f"Bearer {token}"} - ) - response.raise_for_status() # This will raise an exception for 4xx/5xx responses - descope_user = response.json() - - user = User( - identifier=descope_user.get("email"), - metadata={"image": "", "provider": "descope"}, - ) - return (descope_user, user) - - -class AWSCognitoOAuthProvider(OAuthProvider): - id = "aws-cognito" - env = [ - "OAUTH_COGNITO_CLIENT_ID", - "OAUTH_COGNITO_CLIENT_SECRET", - "OAUTH_COGNITO_DOMAIN", - ] - authorize_url = f"https://{os.environ.get('OAUTH_COGNITO_DOMAIN')}/login" - token_url = f"https://{os.environ.get('OAUTH_COGNITO_DOMAIN')}/oauth2/token" - - def __init__(self): - self.client_id = os.environ.get("OAUTH_COGNITO_CLIENT_ID") - self.client_secret = os.environ.get("OAUTH_COGNITO_CLIENT_SECRET") - self.authorize_params = { - "response_type": "code", - "client_id": self.client_id, - "scope": "openid profile email", - } - - async def get_token(self, code: str, url: str): - payload = { - "client_id": self.client_id, - "client_secret": self.client_secret, - "code": code, - "grant_type": "authorization_code", - "redirect_uri": url, - } - async with httpx.AsyncClient() as client: - response = await client.post( - self.token_url, - data=payload, - ) - response.raise_for_status() - json = response.json() - - token = json.get("access_token") - if not token: - raise HTTPException( - status_code=400, detail="Failed to get the access token" - ) - return token - - async def get_user_info(self, token: str): - user_info_url = ( - f"https://{os.environ.get('OAUTH_COGNITO_DOMAIN')}/oauth2/userInfo" - ) - async with httpx.AsyncClient() as client: - response = await client.get( - user_info_url, - headers={"Authorization": f"Bearer {token}"}, - ) - response.raise_for_status() - - cognito_user = response.json() - - # Customize user metadata as needed - user = User( - identifier=cognito_user["email"], - metadata={ - "image": cognito_user.get("picture", ""), - "provider": "aws-cognito", - }, - ) - return (cognito_user, user) - - -class GitlabOAuthProvider(OAuthProvider): - id = "gitlab" - env = [ - "OAUTH_GITLAB_CLIENT_ID", - "OAUTH_GITLAB_CLIENT_SECRET", - "OAUTH_GITLAB_DOMAIN", - ] - - def __init__(self): - self.client_id = os.environ.get("OAUTH_GITLAB_CLIENT_ID") - self.client_secret = os.environ.get("OAUTH_GITLAB_CLIENT_SECRET") - # Ensure that the domain does not have a trailing slash - self.domain = f"https://{os.environ.get('OAUTH_GITLAB_DOMAIN', '').rstrip('/')}" - - self.authorize_url = f"{self.domain}/oauth/authorize" - - self.authorize_params = { - "scope": "openid profile email", - "response_type": "code", - } - - async def get_token(self, code: str, url: str): - payload = { - "client_id": self.client_id, - "client_secret": self.client_secret, - "code": code, - "grant_type": "authorization_code", - "redirect_uri": url, - } - async with httpx.AsyncClient() as client: - response = await client.post( - f"{self.domain}/oauth/token", - data=payload, - ) - response.raise_for_status() - json_content = response.json() - token = json_content.get("access_token") - if not token: - raise HTTPException( - status_code=400, detail="Failed to get the access token" - ) - return token - - async def get_user_info(self, token: str): - async with httpx.AsyncClient() as client: - response = await client.get( - f"{self.domain}/oauth/userinfo", - headers={"Authorization": f"Bearer {token}"}, - ) - response.raise_for_status() - gitlab_user = response.json() - user = User( - identifier=gitlab_user.get("email"), - metadata={ - "image": gitlab_user.get("picture", ""), - "provider": "gitlab", - }, - ) - return (gitlab_user, user) - - -custom_oauth = config.code.custom_oauth_provider # type: ignore -providers = ( - [ - GithubOAuthProvider(), - GoogleOAuthProvider(), - AzureADOAuthProvider(), - AzureADHybridOAuthProvider(), - OktaOAuthProvider(), - Auth0OAuthProvider(), - DescopeOAuthProvider(), - AWSCognitoOAuthProvider(), - GitlabOAuthProvider(), - ] - + [custom_oauth()] - if custom_oauth - else [] -) - - -def get_oauth_provider(provider: str) -> Optional[OAuthProvider]: - for p in providers: - if p.id == provider: - return p - return None - - -def get_configured_oauth_providers(): - return [p.id for p in providers if p.is_configured()] diff --git a/backend/chainlit/oauth_providers/__init__.py b/backend/chainlit/oauth_providers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/backend/chainlit/oauth_providers/auth0_oauth_provider.py b/backend/chainlit/oauth_providers/auth0_oauth_provider.py new file mode 100644 index 0000000000..97f7692c1b --- /dev/null +++ b/backend/chainlit/oauth_providers/auth0_oauth_provider.py @@ -0,0 +1,69 @@ +import os + +import httpx +from chainlit.oauth_providers.oauth_provider import OAuthProvider +from chainlit.user import User +from fastapi import HTTPException + + +class Auth0OAuthProvider(OAuthProvider): + id = "auth0" + env = ["OAUTH_AUTH0_CLIENT_ID", "OAUTH_AUTH0_CLIENT_SECRET", "OAUTH_AUTH0_DOMAIN"] + + def __init__(self): + self.client_id = os.environ.get("OAUTH_AUTH0_CLIENT_ID") + self.client_secret = os.environ.get("OAUTH_AUTH0_CLIENT_SECRET") + # Ensure that the domain does not have a trailing slash + self.domain = f"https://{os.environ.get('OAUTH_AUTH0_DOMAIN', '').rstrip('/')}" + self.original_domain = ( + f"https://{os.environ.get('OAUTH_AUTH0_ORIGINAL_DOMAIN').rstrip('/')}" + if os.environ.get("OAUTH_AUTH0_ORIGINAL_DOMAIN") + else self.domain + ) + + self.authorize_url = f"{self.domain}/authorize" + + self.authorize_params = { + "response_type": "code", + "scope": "openid profile email", + "audience": f"{self.original_domain}/userinfo", + } + + async def get_token(self, code: str, url: str): + payload = { + "client_id": self.client_id, + "client_secret": self.client_secret, + "code": code, + "grant_type": "authorization_code", + "redirect_uri": url, + } + async with httpx.AsyncClient() as client: + response = await client.post( + f"{self.domain}/oauth/token", + data=payload, + ) + response.raise_for_status() + json_content = response.json() + token = json_content.get("access_token") + if not token: + raise HTTPException( + status_code=400, detail="Failed to get the access token" + ) + return token + + async def get_user_info(self, token: str): + async with httpx.AsyncClient() as client: + response = await client.get( + f"{self.original_domain}/userinfo", + headers={"Authorization": f"Bearer {token}"}, + ) + response.raise_for_status() + auth0_user = response.json() + user = User( + identifier=auth0_user.get("email"), + metadata={ + "image": auth0_user.get("picture", ""), + "provider": "auth0", + }, + ) + return (auth0_user, user) diff --git a/backend/chainlit/oauth_providers/aws_cognito_oauth_provider.py b/backend/chainlit/oauth_providers/aws_cognito_oauth_provider.py new file mode 100644 index 0000000000..f42df632a3 --- /dev/null +++ b/backend/chainlit/oauth_providers/aws_cognito_oauth_provider.py @@ -0,0 +1,72 @@ +import os + +import httpx +from chainlit.oauth_providers.oauth_provider import OAuthProvider +from chainlit.user import User +from fastapi import HTTPException + + +class AWSCognitoOAuthProvider(OAuthProvider): + id = "aws-cognito" + env = [ + "OAUTH_COGNITO_CLIENT_ID", + "OAUTH_COGNITO_CLIENT_SECRET", + "OAUTH_COGNITO_DOMAIN", + ] + authorize_url = f"https://{os.environ.get('OAUTH_COGNITO_DOMAIN')}/login" + token_url = f"https://{os.environ.get('OAUTH_COGNITO_DOMAIN')}/oauth2/token" + + def __init__(self): + self.client_id = os.environ.get("OAUTH_COGNITO_CLIENT_ID") + self.client_secret = os.environ.get("OAUTH_COGNITO_CLIENT_SECRET") + self.authorize_params = { + "response_type": "code", + "client_id": self.client_id, + "scope": "openid profile email", + } + + async def get_token(self, code: str, url: str): + payload = { + "client_id": self.client_id, + "client_secret": self.client_secret, + "code": code, + "grant_type": "authorization_code", + "redirect_uri": url, + } + async with httpx.AsyncClient() as client: + response = await client.post( + self.token_url, + data=payload, + ) + response.raise_for_status() + json = response.json() + + token = json.get("access_token") + if not token: + raise HTTPException( + status_code=400, detail="Failed to get the access token" + ) + return token + + async def get_user_info(self, token: str): + user_info_url = ( + f"https://{os.environ.get('OAUTH_COGNITO_DOMAIN')}/oauth2/userInfo" + ) + async with httpx.AsyncClient() as client: + response = await client.get( + user_info_url, + headers={"Authorization": f"Bearer {token}"}, + ) + response.raise_for_status() + + cognito_user = response.json() + + # Customize user metadata as needed + user = User( + identifier=cognito_user["email"], + metadata={ + "image": cognito_user.get("picture", ""), + "provider": "aws-cognito", + }, + ) + return (cognito_user, user) diff --git a/backend/chainlit/oauth_providers/azure_ad_hubrid_oauth_provider.py b/backend/chainlit/oauth_providers/azure_ad_hubrid_oauth_provider.py new file mode 100644 index 0000000000..71ff92135a --- /dev/null +++ b/backend/chainlit/oauth_providers/azure_ad_hubrid_oauth_provider.py @@ -0,0 +1,92 @@ +import base64 +import os + +import httpx +from chainlit.oauth_providers.oauth_provider import OAuthProvider +from chainlit.secret import random_secret +from chainlit.user import User +from fastapi import HTTPException + + +class AzureADHybridOAuthProvider(OAuthProvider): + id = "azure-ad-hybrid" + env = [ + "OAUTH_AZURE_AD_HYBRID_CLIENT_ID", + "OAUTH_AZURE_AD_HYBRID_CLIENT_SECRET", + "OAUTH_AZURE_AD_HYBRID_TENANT_ID", + ] + authorize_url = ( + f"https://login.microsoftonline.com/{os.environ.get('OAUTH_AZURE_AD_HYBRID_TENANT_ID', '')}/oauth2/v2.0/authorize" + if os.environ.get("OAUTH_AZURE_AD_HYBRID_ENABLE_SINGLE_TENANT") + else "https://login.microsoftonline.com/common/oauth2/v2.0/authorize" + ) + token_url = ( + f"https://login.microsoftonline.com/{os.environ.get('OAUTH_AZURE_AD_HYBRID_TENANT_ID', '')}/oauth2/v2.0/token" + if os.environ.get("OAUTH_AZURE_AD_HYBRID_ENABLE_SINGLE_TENANT") + else "https://login.microsoftonline.com/common/oauth2/v2.0/token" + ) + + def __init__(self): + self.client_id = os.environ.get("OAUTH_AZURE_AD_HYBRID_CLIENT_ID") + self.client_secret = os.environ.get("OAUTH_AZURE_AD_HYBRID_CLIENT_SECRET") + nonce = random_secret(16) + self.authorize_params = { + "tenant": os.environ.get("OAUTH_AZURE_AD_HYBRID_TENANT_ID"), + "response_type": "code id_token", + "scope": "https://graph.microsoft.com/User.Read https://graph.microsoft.com/openid", + "response_mode": "form_post", + "nonce": nonce, + } + + async def get_token(self, code: str, url: str): + payload = { + "client_id": self.client_id, + "client_secret": self.client_secret, + "code": code, + "grant_type": "authorization_code", + "redirect_uri": url, + } + async with httpx.AsyncClient() as client: + response = await client.post( + self.token_url, + data=payload, + ) + response.raise_for_status() + json = response.json() + + token = json["access_token"] + if not token: + raise HTTPException( + status_code=400, detail="Failed to get the access token" + ) + return token + + async def get_user_info(self, token: str): + async with httpx.AsyncClient() as client: + response = await client.get( + "https://graph.microsoft.com/v1.0/me", + headers={"Authorization": f"Bearer {token}"}, + ) + response.raise_for_status() + + azure_user = response.json() + + try: + photo_response = await client.get( + "https://graph.microsoft.com/v1.0/me/photos/48x48/$value", + headers={"Authorization": f"Bearer {token}"}, + ) + photo_data = await photo_response.aread() + base64_image = base64.b64encode(photo_data) + azure_user["image"] = ( + f"data:{photo_response.headers['Content-Type']};base64,{base64_image.decode('utf-8')}" + ) + except Exception as e: + # Ignore errors getting the photo + pass + + user = User( + identifier=azure_user["userPrincipalName"], + metadata={"image": azure_user.get("image"), "provider": "azure-ad"}, + ) + return (azure_user, user) diff --git a/backend/chainlit/oauth_providers/azure_ad_oauth_provider.py b/backend/chainlit/oauth_providers/azure_ad_oauth_provider.py new file mode 100644 index 0000000000..6e25d086a1 --- /dev/null +++ b/backend/chainlit/oauth_providers/azure_ad_oauth_provider.py @@ -0,0 +1,89 @@ +import base64 +import os + +import httpx +from chainlit.oauth_providers.oauth_provider import OAuthProvider +from chainlit.user import User +from fastapi import HTTPException + + +class AzureADOAuthProvider(OAuthProvider): + id = "azure-ad" + env = [ + "OAUTH_AZURE_AD_CLIENT_ID", + "OAUTH_AZURE_AD_CLIENT_SECRET", + "OAUTH_AZURE_AD_TENANT_ID", + ] + authorize_url = ( + f"https://login.microsoftonline.com/{os.environ.get('OAUTH_AZURE_AD_TENANT_ID', '')}/oauth2/v2.0/authorize" + if os.environ.get("OAUTH_AZURE_AD_ENABLE_SINGLE_TENANT") + else "https://login.microsoftonline.com/common/oauth2/v2.0/authorize" + ) + token_url = ( + f"https://login.microsoftonline.com/{os.environ.get('OAUTH_AZURE_AD_TENANT_ID', '')}/oauth2/v2.0/token" + if os.environ.get("OAUTH_AZURE_AD_ENABLE_SINGLE_TENANT") + else "https://login.microsoftonline.com/common/oauth2/v2.0/token" + ) + + def __init__(self): + self.client_id = os.environ.get("OAUTH_AZURE_AD_CLIENT_ID") + self.client_secret = os.environ.get("OAUTH_AZURE_AD_CLIENT_SECRET") + self.authorize_params = { + "tenant": os.environ.get("OAUTH_AZURE_AD_TENANT_ID"), + "response_type": "code", + "scope": "https://graph.microsoft.com/User.Read", + "response_mode": "query", + } + + async def get_token(self, code: str, url: str): + payload = { + "client_id": self.client_id, + "client_secret": self.client_secret, + "code": code, + "grant_type": "authorization_code", + "redirect_uri": url, + } + async with httpx.AsyncClient() as client: + response = await client.post( + self.token_url, + data=payload, + ) + response.raise_for_status() + json = response.json() + + token = json["access_token"] + if not token: + raise HTTPException( + status_code=400, detail="Failed to get the access token" + ) + return token + + async def get_user_info(self, token: str): + async with httpx.AsyncClient() as client: + response = await client.get( + "https://graph.microsoft.com/v1.0/me", + headers={"Authorization": f"Bearer {token}"}, + ) + response.raise_for_status() + + azure_user = response.json() + + try: + photo_response = await client.get( + "https://graph.microsoft.com/v1.0/me/photos/48x48/$value", + headers={"Authorization": f"Bearer {token}"}, + ) + photo_data = await photo_response.aread() + base64_image = base64.b64encode(photo_data) + azure_user["image"] = ( + f"data:{photo_response.headers['Content-Type']};base64,{base64_image.decode('utf-8')}" + ) + except Exception as e: + # Ignore errors getting the photo + pass + + user = User( + identifier=azure_user["userPrincipalName"], + metadata={"image": azure_user.get("image"), "provider": "azure-ad"}, + ) + return (azure_user, user) diff --git a/backend/chainlit/oauth_providers/descope_oauth_provider.py b/backend/chainlit/oauth_providers/descope_oauth_provider.py new file mode 100644 index 0000000000..7cbf004819 --- /dev/null +++ b/backend/chainlit/oauth_providers/descope_oauth_provider.py @@ -0,0 +1,61 @@ +import os + +import httpx +from chainlit.oauth_providers.oauth_provider import OAuthProvider +from chainlit.user import User + + +class DescopeOAuthProvider(OAuthProvider): + id = "descope" + env = ["OAUTH_DESCOPE_CLIENT_ID", "OAUTH_DESCOPE_CLIENT_SECRET"] + # Ensure that the domain does not have a trailing slash + domain = f"https://api.descope.com/oauth2/v1" + + authorize_url = f"{domain}/authorize" + + def __init__(self): + self.client_id = os.environ.get("OAUTH_DESCOPE_CLIENT_ID") + self.client_secret = os.environ.get("OAUTH_DESCOPE_CLIENT_SECRET") + self.authorize_params = { + "response_type": "code", + "scope": "openid profile email", + "audience": f"{self.domain}/userinfo", + } + + async def get_token(self, code: str, url: str): + payload = { + "client_id": self.client_id, + "client_secret": self.client_secret, + "code": code, + "grant_type": "authorization_code", + "redirect_uri": url, + } + async with httpx.AsyncClient() as client: + response = await client.post( + f"{self.domain}/token", + data=payload, + ) + response.raise_for_status() + json_content = response.json() + token = json_content.get("access_token") + if not token: + raise httpx.HTTPStatusError( + "Failed to get the access token", + request=response.request, + response=response, + ) + return token + + async def get_user_info(self, token: str): + async with httpx.AsyncClient() as client: + response = await client.get( + f"{self.domain}/userinfo", headers={"Authorization": f"Bearer {token}"} + ) + response.raise_for_status() # This will raise an exception for 4xx/5xx responses + descope_user = response.json() + + user = User( + identifier=descope_user.get("email"), + metadata={"image": "", "provider": "descope"}, + ) + return (descope_user, user) diff --git a/backend/chainlit/oauth_providers/github.py b/backend/chainlit/oauth_providers/github.py new file mode 100644 index 0000000000..1b2a499b96 --- /dev/null +++ b/backend/chainlit/oauth_providers/github.py @@ -0,0 +1,63 @@ +import os +import urllib.parse + +import httpx +from chainlit.oauth_providers.oauth_provider import OAuthProvider +from chainlit.user import User +from fastapi import HTTPException + + +class GithubOAuthProvider(OAuthProvider): + id = "github" + env = ["OAUTH_GITHUB_CLIENT_ID", "OAUTH_GITHUB_CLIENT_SECRET"] + authorize_url = "https://github.com/login/oauth/authorize" + + def __init__(self): + self.client_id = os.environ.get("OAUTH_GITHUB_CLIENT_ID") + self.client_secret = os.environ.get("OAUTH_GITHUB_CLIENT_SECRET") + self.authorize_params = { + "scope": "user:email", + } + + async def get_token(self, code: str, url: str): + payload = { + "client_id": self.client_id, + "client_secret": self.client_secret, + "code": code, + } + async with httpx.AsyncClient() as client: + response = await client.post( + "https://github.com/login/oauth/access_token", + data=payload, + ) + response.raise_for_status() + content = urllib.parse.parse_qs(response.text) + token = content.get("access_token", [""])[0] + if not token: + raise HTTPException( + status_code=400, detail="Failed to get the access token" + ) + return token + + async def get_user_info(self, token: str): + async with httpx.AsyncClient() as client: + user_response = await client.get( + "https://api.github.com/user", + headers={"Authorization": f"token {token}"}, + ) + user_response.raise_for_status() + github_user = user_response.json() + + emails_response = await client.get( + "https://api.github.com/user/emails", + headers={"Authorization": f"token {token}"}, + ) + emails_response.raise_for_status() + emails = emails_response.json() + + github_user.update({"emails": emails}) + user = User( + identifier=github_user["login"], + metadata={"image": github_user["avatar_url"], "provider": "github"}, + ) + return (github_user, user) diff --git a/backend/chainlit/oauth_providers/gitlab_oauth_provider.py b/backend/chainlit/oauth_providers/gitlab_oauth_provider.py new file mode 100644 index 0000000000..26c1907402 --- /dev/null +++ b/backend/chainlit/oauth_providers/gitlab_oauth_provider.py @@ -0,0 +1,67 @@ +import os + +import httpx +from chainlit.oauth_providers.oauth_provider import OAuthProvider +from chainlit.user import User +from fastapi import HTTPException + + +class GitlabOAuthProvider(OAuthProvider): + id = "gitlab" + env = [ + "OAUTH_GITLAB_CLIENT_ID", + "OAUTH_GITLAB_CLIENT_SECRET", + "OAUTH_GITLAB_DOMAIN", + ] + + def __init__(self): + self.client_id = os.environ.get("OAUTH_GITLAB_CLIENT_ID") + self.client_secret = os.environ.get("OAUTH_GITLAB_CLIENT_SECRET") + # Ensure that the domain does not have a trailing slash + self.domain = f"https://{os.environ.get('OAUTH_GITLAB_DOMAIN', '').rstrip('/')}" + + self.authorize_url = f"{self.domain}/oauth/authorize" + + self.authorize_params = { + "scope": "openid profile email", + "response_type": "code", + } + + async def get_token(self, code: str, url: str): + payload = { + "client_id": self.client_id, + "client_secret": self.client_secret, + "code": code, + "grant_type": "authorization_code", + "redirect_uri": url, + } + async with httpx.AsyncClient() as client: + response = await client.post( + f"{self.domain}/oauth/token", + data=payload, + ) + response.raise_for_status() + json_content = response.json() + token = json_content.get("access_token") + if not token: + raise HTTPException( + status_code=400, detail="Failed to get the access token" + ) + return token + + async def get_user_info(self, token: str): + async with httpx.AsyncClient() as client: + response = await client.get( + f"{self.domain}/oauth/userinfo", + headers={"Authorization": f"Bearer {token}"}, + ) + response.raise_for_status() + gitlab_user = response.json() + user = User( + identifier=gitlab_user.get("email"), + metadata={ + "image": gitlab_user.get("picture", ""), + "provider": "gitlab", + }, + ) + return (gitlab_user, user) diff --git a/backend/chainlit/oauth_providers/google.py b/backend/chainlit/oauth_providers/google.py new file mode 100644 index 0000000000..6c0debd3fc --- /dev/null +++ b/backend/chainlit/oauth_providers/google.py @@ -0,0 +1,58 @@ +import os + +import httpx +from chainlit.oauth_providers.oauth_provider import OAuthProvider +from chainlit.user import User + + +class GoogleOAuthProvider(OAuthProvider): + id = "google" + env = ["OAUTH_GOOGLE_CLIENT_ID", "OAUTH_GOOGLE_CLIENT_SECRET"] + authorize_url = "https://accounts.google.com/o/oauth2/v2/auth" + + def __init__(self): + self.client_id = os.environ.get("OAUTH_GOOGLE_CLIENT_ID") + self.client_secret = os.environ.get("OAUTH_GOOGLE_CLIENT_SECRET") + self.authorize_params = { + "scope": "https://www.googleapis.com/auth/userinfo.profile https://www.googleapis.com/auth/userinfo.email", + "response_type": "code", + "access_type": "offline", + } + + async def get_token(self, code: str, url: str): + payload = { + "client_id": self.client_id, + "client_secret": self.client_secret, + "code": code, + "grant_type": "authorization_code", + "redirect_uri": url, + } + async with httpx.AsyncClient() as client: + response = await client.post( + "https://oauth2.googleapis.com/token", + data=payload, + ) + response.raise_for_status() + json = response.json() + token = json.get("access_token") + if not token: + raise httpx.HTTPStatusError( + "Failed to get the access token", + request=response.request, + response=response, + ) + return token + + async def get_user_info(self, token: str): + async with httpx.AsyncClient() as client: + response = await client.get( + "https://www.googleapis.com/userinfo/v2/me", + headers={"Authorization": f"Bearer {token}"}, + ) + response.raise_for_status() + google_user = response.json() + user = User( + identifier=google_user["email"], + metadata={"image": google_user["picture"], "provider": "google"}, + ) + return (google_user, user) diff --git a/backend/chainlit/oauth_providers/oauth_provider.py b/backend/chainlit/oauth_providers/oauth_provider.py new file mode 100644 index 0000000000..4db4d0d413 --- /dev/null +++ b/backend/chainlit/oauth_providers/oauth_provider.py @@ -0,0 +1,22 @@ +import os +from typing import Dict, List, Tuple + +from chainlit.user import User + + +class OAuthProvider: + id: str + env: List[str] + client_id: str + client_secret: str + authorize_url: str + authorize_params: Dict[str, str] + + def is_configured(self): + return all([os.environ.get(env) for env in self.env]) + + async def get_token(self, code: str, url: str) -> str: + raise NotImplementedError() + + async def get_user_info(self, token: str) -> Tuple[Dict[str, str], User]: + raise NotImplementedError() diff --git a/backend/chainlit/oauth_providers/okta_oauth_provider.py b/backend/chainlit/oauth_providers/okta_oauth_provider.py new file mode 100644 index 0000000000..86270c1407 --- /dev/null +++ b/backend/chainlit/oauth_providers/okta_oauth_provider.py @@ -0,0 +1,78 @@ +import os + +import httpx +from chainlit.oauth_providers.oauth_provider import OAuthProvider +from chainlit.user import User + + +class OktaOAuthProvider(OAuthProvider): + id = "okta" + env = [ + "OAUTH_OKTA_CLIENT_ID", + "OAUTH_OKTA_CLIENT_SECRET", + "OAUTH_OKTA_DOMAIN", + ] + # Avoid trailing slash in domain if supplied + domain = f"https://{os.environ.get('OAUTH_OKTA_DOMAIN', '').rstrip('/')}" + + def __init__(self): + self.client_id = os.environ.get("OAUTH_OKTA_CLIENT_ID") + self.client_secret = os.environ.get("OAUTH_OKTA_CLIENT_SECRET") + self.authorization_server_id = os.environ.get( + "OAUTH_OKTA_AUTHORIZATION_SERVER_ID", "" + ) + self.authorize_url = ( + f"{self.domain}/oauth2{self.get_authorization_server_path()}/v1/authorize" + ) + self.authorize_params = { + "response_type": "code", + "scope": "openid profile email", + "response_mode": "query", + } + + def get_authorization_server_path(self): + if not self.authorization_server_id: + return "/default" + if self.authorization_server_id == "false": + return "" + return f"/{self.authorization_server_id}" + + async def get_token(self, code: str, url: str): + payload = { + "client_id": self.client_id, + "client_secret": self.client_secret, + "code": code, + "grant_type": "authorization_code", + "redirect_uri": url, + } + async with httpx.AsyncClient() as client: + response = await client.post( + f"{self.domain}/oauth2{self.get_authorization_server_path()}/v1/token", + data=payload, + ) + response.raise_for_status() + json_data = response.json() + + token = json_data.get("access_token") + if not token: + raise httpx.HTTPStatusError( + "Failed to get the access token", + request=response.request, + response=response, + ) + return token + + async def get_user_info(self, token: str): + async with httpx.AsyncClient() as client: + response = await client.get( + f"{self.domain}/oauth2{self.get_authorization_server_path()}/v1/userinfo", + headers={"Authorization": f"Bearer {token}"}, + ) + response.raise_for_status() + okta_user = response.json() + + user = User( + identifier=okta_user.get("email"), + metadata={"image": "", "provider": "okta"}, + ) + return (okta_user, user) diff --git a/backend/chainlit/oauth_providers/providers.py b/backend/chainlit/oauth_providers/providers.py new file mode 100644 index 0000000000..afb59dd079 --- /dev/null +++ b/backend/chainlit/oauth_providers/providers.py @@ -0,0 +1,44 @@ +from typing import Optional + +from chainlit.config import config +from chainlit.oauth_providers.auth0_oauth_provider import Auth0OAuthProvider +from chainlit.oauth_providers.aws_cognito_oauth_provider import AWSCognitoOAuthProvider +from chainlit.oauth_providers.azure_ad_hubrid_oauth_provider import ( + AzureADHybridOAuthProvider, +) +from chainlit.oauth_providers.azure_ad_oauth_provider import AzureADOAuthProvider +from chainlit.oauth_providers.descope_oauth_provider import DescopeOAuthProvider +from chainlit.oauth_providers.github import GithubOAuthProvider +from chainlit.oauth_providers.gitlab_oauth_provider import GitlabOAuthProvider +from chainlit.oauth_providers.google import GoogleOAuthProvider +from chainlit.oauth_providers.oauth_provider import OAuthProvider +from chainlit.oauth_providers.okta_oauth_provider import OktaOAuthProvider + +custom_oauth = config.code.custom_oauth_provider +providers = ( + [ + GithubOAuthProvider(), + GoogleOAuthProvider(), + AzureADOAuthProvider(), + AzureADHybridOAuthProvider(), + OktaOAuthProvider(), + Auth0OAuthProvider(), + DescopeOAuthProvider(), + AWSCognitoOAuthProvider(), + GitlabOAuthProvider(), + ] + + [custom_oauth()] + if custom_oauth + else [] +) + + +def get_oauth_provider(provider: str) -> Optional[OAuthProvider]: + for p in providers: + if p.id == provider: + return p + return None + + +def get_configured_oauth_providers(): + return [p.id for p in providers if p.is_configured()] diff --git a/backend/chainlit/server.py b/backend/chainlit/server.py index 597830ee43..308b9fae59 100644 --- a/backend/chainlit/server.py +++ b/backend/chainlit/server.py @@ -6,6 +6,16 @@ import re import shutil import urllib.parse +from typing import Any, Optional, Union + +from chainlit.oauth_providers.providers import get_oauth_provider +from chainlit.secret import random_secret + +mimetypes.add_type("application/javascript", ".js") +mimetypes.add_type("text/css", ".css") + +import asyncio +import os import webbrowser from contextlib import asynccontextmanager from pathlib import Path From deea68d3adc94dc9156e05baacff17f379fbc177 Mon Sep 17 00:00:00 2001 From: patrykkotlowski-dsstream Date: Mon, 2 Sep 2024 10:02:08 +0200 Subject: [PATCH 5/9] Add decorators --- backend/chainlit/__init__.py | 198 ++++++++++++++++++++++++++++++++++- 1 file changed, 197 insertions(+), 1 deletion(-) diff --git a/backend/chainlit/__init__.py b/backend/chainlit/__init__.py index f53c8f6bff..e502d78b2a 100644 --- a/backend/chainlit/__init__.py +++ b/backend/chainlit/__init__.py @@ -12,7 +12,7 @@ logger.info("Loaded .env file") import asyncio -from typing import TYPE_CHECKING, Any, Dict +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional import chainlit.input_widget as input_widget from chainlit.action import Action @@ -76,6 +76,202 @@ from chainlit.langchain.callbacks import ( AsyncLangchainCallbackHandler, LangchainCallbackHandler, + config.code.oauth_callback = wrap_user_function(func) + return func + + +@trace +def custom_authenticate_user(func: Callable[[str], Awaitable[User]]) -> Callable: + """ + A decorator to authenticate the user via custom token validation. + + Args: + func (Callable[[str, str, Dict[str, str], User], Optional[User]]): The authentication callback to execute. + + Returns: + Callable[[str, str, Dict[str, str], User], Optional[User]]: The decorated authentication callback. + """ + + if len(get_configured_oauth_providers()) == 0: + raise ValueError( + "You must set the environment variable for at least one oauth provider to use oauth authentication." + ) + + config.code.custom_authenticate_user = wrap_user_function(func) + return func + + +@trace +def custom_oauth_provider(func: Callable[[str], Awaitable[User]]) -> Callable: + """ + A decorator to integrate custom OAuth provider logic for user authentication. + + Args: + func (Callable[[str, str, Dict[str, str], User], Optional[User]]): A function that returns an instance of the OAuthProvider class, encapsulating the logic and details for the custom OAuth provider. + + Returns: + Callable[[str, str, Dict[str, str], User], Optional[User]]: The decorated callback function that handles authentication via the custom OAuth provider. + """ + + if len(get_configured_oauth_providers()) == 0: + raise ValueError( + "You must set the environment variable for at least one oauth provider to use oauth authentication." + ) + + config.code.custom_oauth_provider = wrap_user_function(func) + return func + + +@trace +def on_logout(func: Callable[[Request, Response], Any]) -> Callable: + """ + Function called when the user logs out. + Takes the FastAPI request and response as parameters. + """ + + config.code.on_logout = wrap_user_function(func) + return func + + +@trace +def on_message(func: Callable) -> Callable: + """ + Framework agnostic decorator to react to messages coming from the UI. + The decorated function is called every time a new message is received. + + Args: + func (Callable[[Message], Any]): The function to be called when a new message is received. Takes a cl.Message. + + Returns: + Callable[[str], Any]: The decorated on_message function. + """ + + async def with_parent_id(message: Message): + async with Step(name="on_message", type="run", parent_id=message.id) as s: + s.input = message.content + if len(inspect.signature(func).parameters) > 0: + await func(message) + else: + await func() + + config.code.on_message = wrap_user_function(with_parent_id) + return func + + +@trace +def on_chat_start(func: Callable) -> Callable: + """ + Hook to react to the user websocket connection event. + + Args: + func (Callable[], Any]): The connection hook to execute. + + Returns: + Callable[], Any]: The decorated hook. + """ + + config.code.on_chat_start = wrap_user_function( + step(func, name="on_chat_start", type="run"), with_task=True + ) + return func + + +@trace +def on_chat_resume(func: Callable[[ThreadDict], Any]) -> Callable: + """ + Hook to react to resume websocket connection event. + + Args: + func (Callable[], Any]): The connection hook to execute. + + Returns: + Callable[], Any]: The decorated hook. + """ + + config.code.on_chat_resume = wrap_user_function(func, with_task=True) + return func + + +@trace +def set_chat_profiles( + func: Callable[[Optional["User"]], List["ChatProfile"]] +) -> Callable: + """ + Programmatic declaration of the available chat profiles (can depend on the User from the session if authentication is setup). + + Args: + func (Callable[[Optional["User"]], List["ChatProfile"]]): The function declaring the chat profiles. + + Returns: + Callable[[Optional["User"]], List["ChatProfile"]]: The decorated function. + """ + + config.code.set_chat_profiles = wrap_user_function(func) + return func + + +@trace +def set_starters(func: Callable[[Optional["User"]], List["Starter"]]) -> Callable: + """ + Programmatic declaration of the available starter (can depend on the User from the session if authentication is setup). + + Args: + func (Callable[[Optional["User"]], List["Starter"]]): The function declaring the starters. + + Returns: + Callable[[Optional["User"]], List["Starter"]]: The decorated function. + """ + + config.code.set_starters = wrap_user_function(func) + return func + + +@trace +def on_chat_end(func: Callable) -> Callable: + """ + Hook to react to the user websocket disconnect event. + + Args: + func (Callable[], Any]): The disconnect hook to execute. + + Returns: + Callable[], Any]: The decorated hook. + """ + + config.code.on_chat_end = wrap_user_function(func, with_task=True) + return func + + +@trace +def on_audio_chunk(func: Callable) -> Callable: + """ + Hook to react to the audio chunks being sent. + + Args: + chunk (AudioChunk): The audio chunk being sent. + + Returns: + Callable[], Any]: The decorated hook. + """ + + config.code.on_audio_chunk = wrap_user_function(func, with_task=False) + return func + + +@trace +def on_audio_end(func: Callable) -> Callable: + """ + Hook to react to the audio stream ending. This is called after the last audio chunk is sent. + + Args: + elements ([List[Element]): The files that were uploaded before starting the audio stream (if any). + + Returns: + Callable[], Any]: The decorated hook. + """ + + config.code.on_audio_end = wrap_user_function( + step(func, name="on_audio_end", type="run"), with_task=True ) from chainlit.llama_index.callbacks import LlamaIndexCallbackHandler from chainlit.mistralai import instrument_mistralai From 7be90b4570924cb864eae176da786f52b1247f5a Mon Sep 17 00:00:00 2001 From: patrykkotlowski-dsstream Date: Mon, 2 Sep 2024 10:20:28 +0200 Subject: [PATCH 6/9] Refactoring --- backend/chainlit/__init__.py | 2 +- backend/chainlit/auth.py | 2 +- backend/chainlit/config.py | 4 +- .../{oauth_providers => oauth}/__init__.py | 0 .../auth0_oauth_provider.py | 2 +- .../aws_cognito_oauth_provider.py | 2 +- .../azure_ad_hubrid_oauth_provider.py | 2 +- .../azure_ad_oauth_provider.py | 2 +- .../descope_oauth_provider.py | 2 +- .../{oauth_providers => oauth}/github.py | 2 +- .../gitlab_oauth_provider.py | 2 +- .../{oauth_providers => oauth}/google.py | 2 +- .../oauth_provider.py | 0 .../okta_oauth_provider.py | 2 +- backend/chainlit/oauth/providers.py | 42 ++++++++++++++++++ backend/chainlit/oauth_providers.py | 20 +++++++++ backend/chainlit/oauth_providers/providers.py | 44 ------------------- backend/chainlit/server.py | 2 +- 18 files changed, 76 insertions(+), 58 deletions(-) rename backend/chainlit/{oauth_providers => oauth}/__init__.py (100%) rename backend/chainlit/{oauth_providers => oauth}/auth0_oauth_provider.py (97%) rename backend/chainlit/{oauth_providers => oauth}/aws_cognito_oauth_provider.py (97%) rename backend/chainlit/{oauth_providers => oauth}/azure_ad_hubrid_oauth_provider.py (98%) rename backend/chainlit/{oauth_providers => oauth}/azure_ad_oauth_provider.py (97%) rename backend/chainlit/{oauth_providers => oauth}/descope_oauth_provider.py (96%) rename backend/chainlit/{oauth_providers => oauth}/github.py (97%) rename backend/chainlit/{oauth_providers => oauth}/gitlab_oauth_provider.py (97%) rename backend/chainlit/{oauth_providers => oauth}/google.py (96%) rename backend/chainlit/{oauth_providers => oauth}/oauth_provider.py (100%) rename backend/chainlit/{oauth_providers => oauth}/okta_oauth_provider.py (97%) create mode 100644 backend/chainlit/oauth/providers.py create mode 100644 backend/chainlit/oauth_providers.py delete mode 100644 backend/chainlit/oauth_providers/providers.py diff --git a/backend/chainlit/__init__.py b/backend/chainlit/__init__.py index e502d78b2a..71071d7e52 100644 --- a/backend/chainlit/__init__.py +++ b/backend/chainlit/__init__.py @@ -41,7 +41,7 @@ ErrorMessage, Message, ) -from chainlit.oauth_providers.providers import get_configured_oauth_providers +from chainlit.oauth.providers import get_configured_oauth_providers from chainlit.step import Step, step from chainlit.sync import make_async, run_sync from chainlit.types import AudioChunk, ChatProfile, Starter diff --git a/backend/chainlit/auth.py b/backend/chainlit/auth.py index 049e94f158..a4b1b326a7 100644 --- a/backend/chainlit/auth.py +++ b/backend/chainlit/auth.py @@ -5,7 +5,7 @@ import jwt from chainlit.config import config from chainlit.data import get_data_layer -from chainlit.oauth_providers.providers import get_configured_oauth_providers +from chainlit.oauth.providers import get_configured_oauth_providers from chainlit.user import User from fastapi import Depends, HTTPException from fastapi.security import OAuth2PasswordBearer diff --git a/backend/chainlit/config.py b/backend/chainlit/config.py index e8cefb9572..22fc0d6ec9 100644 --- a/backend/chainlit/config.py +++ b/backend/chainlit/config.py @@ -19,7 +19,7 @@ import tomli from chainlit.logger import logger -from chainlit.oauth_providers.oauth_provider import OAuthProvider +from chainlit.oauth.oauth_provider import OAuthProvider from chainlit.translations import lint_translation_json from chainlit.version import __version__ from dataclasses_json import DataClassJsonMixin @@ -295,7 +295,7 @@ class CodeSettings: ] = None # Callbacks for authenticate mechanism custom_authenticate_user: Optional[Callable[[str], Awaitable["User"]]] = None - custom_oauth_provider: Optional[Type[OAuthProvider]] = None + custom_oauth_provider: Optional[Callable[[], Type[OAuthProvider]]] = None on_logout: Optional[Callable[["Request", "Response"], Any]] = None on_stop: Optional[Callable[[], Any]] = None on_chat_start: Optional[Callable[[], Any]] = None diff --git a/backend/chainlit/oauth_providers/__init__.py b/backend/chainlit/oauth/__init__.py similarity index 100% rename from backend/chainlit/oauth_providers/__init__.py rename to backend/chainlit/oauth/__init__.py diff --git a/backend/chainlit/oauth_providers/auth0_oauth_provider.py b/backend/chainlit/oauth/auth0_oauth_provider.py similarity index 97% rename from backend/chainlit/oauth_providers/auth0_oauth_provider.py rename to backend/chainlit/oauth/auth0_oauth_provider.py index 97f7692c1b..91e70fb376 100644 --- a/backend/chainlit/oauth_providers/auth0_oauth_provider.py +++ b/backend/chainlit/oauth/auth0_oauth_provider.py @@ -1,7 +1,7 @@ import os import httpx -from chainlit.oauth_providers.oauth_provider import OAuthProvider +from chainlit.oauth.oauth_provider import OAuthProvider from chainlit.user import User from fastapi import HTTPException diff --git a/backend/chainlit/oauth_providers/aws_cognito_oauth_provider.py b/backend/chainlit/oauth/aws_cognito_oauth_provider.py similarity index 97% rename from backend/chainlit/oauth_providers/aws_cognito_oauth_provider.py rename to backend/chainlit/oauth/aws_cognito_oauth_provider.py index f42df632a3..d5d286d134 100644 --- a/backend/chainlit/oauth_providers/aws_cognito_oauth_provider.py +++ b/backend/chainlit/oauth/aws_cognito_oauth_provider.py @@ -1,7 +1,7 @@ import os import httpx -from chainlit.oauth_providers.oauth_provider import OAuthProvider +from chainlit.oauth.oauth_provider import OAuthProvider from chainlit.user import User from fastapi import HTTPException diff --git a/backend/chainlit/oauth_providers/azure_ad_hubrid_oauth_provider.py b/backend/chainlit/oauth/azure_ad_hubrid_oauth_provider.py similarity index 98% rename from backend/chainlit/oauth_providers/azure_ad_hubrid_oauth_provider.py rename to backend/chainlit/oauth/azure_ad_hubrid_oauth_provider.py index 71ff92135a..da935d1543 100644 --- a/backend/chainlit/oauth_providers/azure_ad_hubrid_oauth_provider.py +++ b/backend/chainlit/oauth/azure_ad_hubrid_oauth_provider.py @@ -2,7 +2,7 @@ import os import httpx -from chainlit.oauth_providers.oauth_provider import OAuthProvider +from chainlit.oauth.oauth_provider import OAuthProvider from chainlit.secret import random_secret from chainlit.user import User from fastapi import HTTPException diff --git a/backend/chainlit/oauth_providers/azure_ad_oauth_provider.py b/backend/chainlit/oauth/azure_ad_oauth_provider.py similarity index 97% rename from backend/chainlit/oauth_providers/azure_ad_oauth_provider.py rename to backend/chainlit/oauth/azure_ad_oauth_provider.py index 6e25d086a1..2ed32bbe28 100644 --- a/backend/chainlit/oauth_providers/azure_ad_oauth_provider.py +++ b/backend/chainlit/oauth/azure_ad_oauth_provider.py @@ -2,7 +2,7 @@ import os import httpx -from chainlit.oauth_providers.oauth_provider import OAuthProvider +from chainlit.oauth.oauth_provider import OAuthProvider from chainlit.user import User from fastapi import HTTPException diff --git a/backend/chainlit/oauth_providers/descope_oauth_provider.py b/backend/chainlit/oauth/descope_oauth_provider.py similarity index 96% rename from backend/chainlit/oauth_providers/descope_oauth_provider.py rename to backend/chainlit/oauth/descope_oauth_provider.py index 7cbf004819..08c96da0d5 100644 --- a/backend/chainlit/oauth_providers/descope_oauth_provider.py +++ b/backend/chainlit/oauth/descope_oauth_provider.py @@ -1,7 +1,7 @@ import os import httpx -from chainlit.oauth_providers.oauth_provider import OAuthProvider +from chainlit.oauth.oauth_provider import OAuthProvider from chainlit.user import User diff --git a/backend/chainlit/oauth_providers/github.py b/backend/chainlit/oauth/github.py similarity index 97% rename from backend/chainlit/oauth_providers/github.py rename to backend/chainlit/oauth/github.py index 1b2a499b96..5ab0d72059 100644 --- a/backend/chainlit/oauth_providers/github.py +++ b/backend/chainlit/oauth/github.py @@ -2,7 +2,7 @@ import urllib.parse import httpx -from chainlit.oauth_providers.oauth_provider import OAuthProvider +from chainlit.oauth.oauth_provider import OAuthProvider from chainlit.user import User from fastapi import HTTPException diff --git a/backend/chainlit/oauth_providers/gitlab_oauth_provider.py b/backend/chainlit/oauth/gitlab_oauth_provider.py similarity index 97% rename from backend/chainlit/oauth_providers/gitlab_oauth_provider.py rename to backend/chainlit/oauth/gitlab_oauth_provider.py index 26c1907402..22e993a77b 100644 --- a/backend/chainlit/oauth_providers/gitlab_oauth_provider.py +++ b/backend/chainlit/oauth/gitlab_oauth_provider.py @@ -1,7 +1,7 @@ import os import httpx -from chainlit.oauth_providers.oauth_provider import OAuthProvider +from chainlit.oauth.oauth_provider import OAuthProvider from chainlit.user import User from fastapi import HTTPException diff --git a/backend/chainlit/oauth_providers/google.py b/backend/chainlit/oauth/google.py similarity index 96% rename from backend/chainlit/oauth_providers/google.py rename to backend/chainlit/oauth/google.py index 6c0debd3fc..0d4cc1cfea 100644 --- a/backend/chainlit/oauth_providers/google.py +++ b/backend/chainlit/oauth/google.py @@ -1,7 +1,7 @@ import os import httpx -from chainlit.oauth_providers.oauth_provider import OAuthProvider +from chainlit.oauth.oauth_provider import OAuthProvider from chainlit.user import User diff --git a/backend/chainlit/oauth_providers/oauth_provider.py b/backend/chainlit/oauth/oauth_provider.py similarity index 100% rename from backend/chainlit/oauth_providers/oauth_provider.py rename to backend/chainlit/oauth/oauth_provider.py diff --git a/backend/chainlit/oauth_providers/okta_oauth_provider.py b/backend/chainlit/oauth/okta_oauth_provider.py similarity index 97% rename from backend/chainlit/oauth_providers/okta_oauth_provider.py rename to backend/chainlit/oauth/okta_oauth_provider.py index 86270c1407..a531ddadb0 100644 --- a/backend/chainlit/oauth_providers/okta_oauth_provider.py +++ b/backend/chainlit/oauth/okta_oauth_provider.py @@ -1,7 +1,7 @@ import os import httpx -from chainlit.oauth_providers.oauth_provider import OAuthProvider +from chainlit.oauth.oauth_provider import OAuthProvider from chainlit.user import User diff --git a/backend/chainlit/oauth/providers.py b/backend/chainlit/oauth/providers.py new file mode 100644 index 0000000000..a9510e67a7 --- /dev/null +++ b/backend/chainlit/oauth/providers.py @@ -0,0 +1,42 @@ +from typing import Optional + +from chainlit.config import config +from chainlit.oauth.auth0_oauth_provider import Auth0OAuthProvider +from chainlit.oauth.aws_cognito_oauth_provider import AWSCognitoOAuthProvider +from chainlit.oauth.azure_ad_hubrid_oauth_provider import AzureADHybridOAuthProvider +from chainlit.oauth.azure_ad_oauth_provider import AzureADOAuthProvider +from chainlit.oauth.descope_oauth_provider import DescopeOAuthProvider +from chainlit.oauth.github import GithubOAuthProvider +from chainlit.oauth.gitlab_oauth_provider import GitlabOAuthProvider +from chainlit.oauth.google import GoogleOAuthProvider +from chainlit.oauth.oauth_provider import OAuthProvider +from chainlit.oauth.okta_oauth_provider import OktaOAuthProvider + +custom_oauth = config.code.custom_oauth_provider +providers = ( + [ + GithubOAuthProvider(), + GoogleOAuthProvider(), + AzureADOAuthProvider(), + AzureADHybridOAuthProvider(), + OktaOAuthProvider(), + Auth0OAuthProvider(), + DescopeOAuthProvider(), + AWSCognitoOAuthProvider(), + GitlabOAuthProvider(), + ] + + [custom_oauth()] + if custom_oauth + else [] +) + + +def get_oauth_provider(provider: str) -> Optional[OAuthProvider]: + for p in providers: + if p.id == provider: + return p + return None + + +def get_configured_oauth_providers(): + return [p.id for p in providers if p.is_configured()] diff --git a/backend/chainlit/oauth_providers.py b/backend/chainlit/oauth_providers.py new file mode 100644 index 0000000000..40e90f85f3 --- /dev/null +++ b/backend/chainlit/oauth_providers.py @@ -0,0 +1,20 @@ +import warnings + +warnings.warn( + "The 'oauth_providers' module is deprecated and will be removed in a future version. " + "Please use 'oauth' instead.", + DeprecationWarning, + stacklevel=2, +) + +from chainlit.oauth.providers import ( + get_configured_oauth_providers, + get_oauth_provider, + providers, +) + +__all__ = [ + "providers", + "get_oauth_provider", + "get_configured_oauth_providers", +] diff --git a/backend/chainlit/oauth_providers/providers.py b/backend/chainlit/oauth_providers/providers.py deleted file mode 100644 index afb59dd079..0000000000 --- a/backend/chainlit/oauth_providers/providers.py +++ /dev/null @@ -1,44 +0,0 @@ -from typing import Optional - -from chainlit.config import config -from chainlit.oauth_providers.auth0_oauth_provider import Auth0OAuthProvider -from chainlit.oauth_providers.aws_cognito_oauth_provider import AWSCognitoOAuthProvider -from chainlit.oauth_providers.azure_ad_hubrid_oauth_provider import ( - AzureADHybridOAuthProvider, -) -from chainlit.oauth_providers.azure_ad_oauth_provider import AzureADOAuthProvider -from chainlit.oauth_providers.descope_oauth_provider import DescopeOAuthProvider -from chainlit.oauth_providers.github import GithubOAuthProvider -from chainlit.oauth_providers.gitlab_oauth_provider import GitlabOAuthProvider -from chainlit.oauth_providers.google import GoogleOAuthProvider -from chainlit.oauth_providers.oauth_provider import OAuthProvider -from chainlit.oauth_providers.okta_oauth_provider import OktaOAuthProvider - -custom_oauth = config.code.custom_oauth_provider -providers = ( - [ - GithubOAuthProvider(), - GoogleOAuthProvider(), - AzureADOAuthProvider(), - AzureADHybridOAuthProvider(), - OktaOAuthProvider(), - Auth0OAuthProvider(), - DescopeOAuthProvider(), - AWSCognitoOAuthProvider(), - GitlabOAuthProvider(), - ] - + [custom_oauth()] - if custom_oauth - else [] -) - - -def get_oauth_provider(provider: str) -> Optional[OAuthProvider]: - for p in providers: - if p.id == provider: - return p - return None - - -def get_configured_oauth_providers(): - return [p.id for p in providers if p.is_configured()] diff --git a/backend/chainlit/server.py b/backend/chainlit/server.py index 308b9fae59..f88ac68057 100644 --- a/backend/chainlit/server.py +++ b/backend/chainlit/server.py @@ -8,7 +8,7 @@ import urllib.parse from typing import Any, Optional, Union -from chainlit.oauth_providers.providers import get_oauth_provider +from chainlit.oauth.providers import get_oauth_provider from chainlit.secret import random_secret mimetypes.add_type("application/javascript", ".js") From 1c760feb590777caa3db00b502583ce0255e8b0e Mon Sep 17 00:00:00 2001 From: patrykkotlowski-dsstream Date: Mon, 2 Sep 2024 10:26:09 +0200 Subject: [PATCH 7/9] Refactoring --- backend/chainlit/oauth_providers.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/backend/chainlit/oauth_providers.py b/backend/chainlit/oauth_providers.py index 40e90f85f3..c8d45bcc4c 100644 --- a/backend/chainlit/oauth_providers.py +++ b/backend/chainlit/oauth_providers.py @@ -7,6 +7,16 @@ stacklevel=2, ) +from chainlit.oauth.auth0_oauth_provider import Auth0OAuthProvider +from chainlit.oauth.aws_cognito_oauth_provider import AWSCognitoOAuthProvider +from chainlit.oauth.azure_ad_hubrid_oauth_provider import AzureADHybridOAuthProvider +from chainlit.oauth.azure_ad_oauth_provider import AzureADOAuthProvider +from chainlit.oauth.descope_oauth_provider import DescopeOAuthProvider +from chainlit.oauth.github import GithubOAuthProvider +from chainlit.oauth.gitlab_oauth_provider import GitlabOAuthProvider +from chainlit.oauth.google import GoogleOAuthProvider +from chainlit.oauth.oauth_provider import OAuthProvider +from chainlit.oauth.okta_oauth_provider import OktaOAuthProvider from chainlit.oauth.providers import ( get_configured_oauth_providers, get_oauth_provider, @@ -17,4 +27,14 @@ "providers", "get_oauth_provider", "get_configured_oauth_providers", + "OAuthProvider", + "GithubOAuthProvider", + "GoogleOAuthProvider", + "AzureADOAuthProvider", + "AzureADHybridOAuthProvider", + "OktaOAuthProvider", + "Auth0OAuthProvider", + "DescopeOAuthProvider", + "AWSCognitoOAuthProvider", + "GitlabOAuthProvider", ] From fa98dd9e902eba155674f33478d031673a209097 Mon Sep 17 00:00:00 2001 From: patrykkotlowski-dsstream Date: Mon, 2 Sep 2024 11:06:35 +0200 Subject: [PATCH 8/9] Switch back to old structure --- backend/chainlit/__init__.py | 199 +---- backend/chainlit/auth.py | 2 +- backend/chainlit/callbacks.py | 38 +- backend/chainlit/chat_context.py | 4 +- backend/chainlit/config.py | 2 +- backend/chainlit/data/acl.py | 2 +- backend/chainlit/data/sql_alchemy.py | 17 +- backend/chainlit/oauth/__init__.py | 0 .../chainlit/oauth/auth0_oauth_provider.py | 69 -- .../oauth/aws_cognito_oauth_provider.py | 72 -- .../oauth/azure_ad_hubrid_oauth_provider.py | 92 --- .../chainlit/oauth/azure_ad_oauth_provider.py | 89 --- .../chainlit/oauth/descope_oauth_provider.py | 61 -- backend/chainlit/oauth/github.py | 63 -- .../chainlit/oauth/gitlab_oauth_provider.py | 67 -- backend/chainlit/oauth/google.py | 58 -- backend/chainlit/oauth/oauth_provider.py | 22 - backend/chainlit/oauth/okta_oauth_provider.py | 78 -- backend/chainlit/oauth/providers.py | 42 -- backend/chainlit/oauth_providers.py | 683 +++++++++++++++++- backend/chainlit/server.py | 10 - 21 files changed, 699 insertions(+), 971 deletions(-) delete mode 100644 backend/chainlit/oauth/__init__.py delete mode 100644 backend/chainlit/oauth/auth0_oauth_provider.py delete mode 100644 backend/chainlit/oauth/aws_cognito_oauth_provider.py delete mode 100644 backend/chainlit/oauth/azure_ad_hubrid_oauth_provider.py delete mode 100644 backend/chainlit/oauth/azure_ad_oauth_provider.py delete mode 100644 backend/chainlit/oauth/descope_oauth_provider.py delete mode 100644 backend/chainlit/oauth/github.py delete mode 100644 backend/chainlit/oauth/gitlab_oauth_provider.py delete mode 100644 backend/chainlit/oauth/google.py delete mode 100644 backend/chainlit/oauth/oauth_provider.py delete mode 100644 backend/chainlit/oauth/okta_oauth_provider.py delete mode 100644 backend/chainlit/oauth/providers.py diff --git a/backend/chainlit/__init__.py b/backend/chainlit/__init__.py index 71071d7e52..0506ef38f3 100644 --- a/backend/chainlit/__init__.py +++ b/backend/chainlit/__init__.py @@ -12,7 +12,7 @@ logger.info("Loaded .env file") import asyncio -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict import chainlit.input_widget as input_widget from chainlit.action import Action @@ -41,7 +41,6 @@ ErrorMessage, Message, ) -from chainlit.oauth.providers import get_configured_oauth_providers from chainlit.step import Step, step from chainlit.sync import make_async, run_sync from chainlit.types import AudioChunk, ChatProfile, Starter @@ -76,202 +75,6 @@ from chainlit.langchain.callbacks import ( AsyncLangchainCallbackHandler, LangchainCallbackHandler, - config.code.oauth_callback = wrap_user_function(func) - return func - - -@trace -def custom_authenticate_user(func: Callable[[str], Awaitable[User]]) -> Callable: - """ - A decorator to authenticate the user via custom token validation. - - Args: - func (Callable[[str, str, Dict[str, str], User], Optional[User]]): The authentication callback to execute. - - Returns: - Callable[[str, str, Dict[str, str], User], Optional[User]]: The decorated authentication callback. - """ - - if len(get_configured_oauth_providers()) == 0: - raise ValueError( - "You must set the environment variable for at least one oauth provider to use oauth authentication." - ) - - config.code.custom_authenticate_user = wrap_user_function(func) - return func - - -@trace -def custom_oauth_provider(func: Callable[[str], Awaitable[User]]) -> Callable: - """ - A decorator to integrate custom OAuth provider logic for user authentication. - - Args: - func (Callable[[str, str, Dict[str, str], User], Optional[User]]): A function that returns an instance of the OAuthProvider class, encapsulating the logic and details for the custom OAuth provider. - - Returns: - Callable[[str, str, Dict[str, str], User], Optional[User]]: The decorated callback function that handles authentication via the custom OAuth provider. - """ - - if len(get_configured_oauth_providers()) == 0: - raise ValueError( - "You must set the environment variable for at least one oauth provider to use oauth authentication." - ) - - config.code.custom_oauth_provider = wrap_user_function(func) - return func - - -@trace -def on_logout(func: Callable[[Request, Response], Any]) -> Callable: - """ - Function called when the user logs out. - Takes the FastAPI request and response as parameters. - """ - - config.code.on_logout = wrap_user_function(func) - return func - - -@trace -def on_message(func: Callable) -> Callable: - """ - Framework agnostic decorator to react to messages coming from the UI. - The decorated function is called every time a new message is received. - - Args: - func (Callable[[Message], Any]): The function to be called when a new message is received. Takes a cl.Message. - - Returns: - Callable[[str], Any]: The decorated on_message function. - """ - - async def with_parent_id(message: Message): - async with Step(name="on_message", type="run", parent_id=message.id) as s: - s.input = message.content - if len(inspect.signature(func).parameters) > 0: - await func(message) - else: - await func() - - config.code.on_message = wrap_user_function(with_parent_id) - return func - - -@trace -def on_chat_start(func: Callable) -> Callable: - """ - Hook to react to the user websocket connection event. - - Args: - func (Callable[], Any]): The connection hook to execute. - - Returns: - Callable[], Any]: The decorated hook. - """ - - config.code.on_chat_start = wrap_user_function( - step(func, name="on_chat_start", type="run"), with_task=True - ) - return func - - -@trace -def on_chat_resume(func: Callable[[ThreadDict], Any]) -> Callable: - """ - Hook to react to resume websocket connection event. - - Args: - func (Callable[], Any]): The connection hook to execute. - - Returns: - Callable[], Any]: The decorated hook. - """ - - config.code.on_chat_resume = wrap_user_function(func, with_task=True) - return func - - -@trace -def set_chat_profiles( - func: Callable[[Optional["User"]], List["ChatProfile"]] -) -> Callable: - """ - Programmatic declaration of the available chat profiles (can depend on the User from the session if authentication is setup). - - Args: - func (Callable[[Optional["User"]], List["ChatProfile"]]): The function declaring the chat profiles. - - Returns: - Callable[[Optional["User"]], List["ChatProfile"]]: The decorated function. - """ - - config.code.set_chat_profiles = wrap_user_function(func) - return func - - -@trace -def set_starters(func: Callable[[Optional["User"]], List["Starter"]]) -> Callable: - """ - Programmatic declaration of the available starter (can depend on the User from the session if authentication is setup). - - Args: - func (Callable[[Optional["User"]], List["Starter"]]): The function declaring the starters. - - Returns: - Callable[[Optional["User"]], List["Starter"]]: The decorated function. - """ - - config.code.set_starters = wrap_user_function(func) - return func - - -@trace -def on_chat_end(func: Callable) -> Callable: - """ - Hook to react to the user websocket disconnect event. - - Args: - func (Callable[], Any]): The disconnect hook to execute. - - Returns: - Callable[], Any]: The decorated hook. - """ - - config.code.on_chat_end = wrap_user_function(func, with_task=True) - return func - - -@trace -def on_audio_chunk(func: Callable) -> Callable: - """ - Hook to react to the audio chunks being sent. - - Args: - chunk (AudioChunk): The audio chunk being sent. - - Returns: - Callable[], Any]: The decorated hook. - """ - - config.code.on_audio_chunk = wrap_user_function(func, with_task=False) - return func - - -@trace -def on_audio_end(func: Callable) -> Callable: - """ - Hook to react to the audio stream ending. This is called after the last audio chunk is sent. - - Args: - elements ([List[Element]): The files that were uploaded before starting the audio stream (if any). - - Returns: - Callable[], Any]: The decorated hook. - """ - - config.code.on_audio_end = wrap_user_function( - step(func, name="on_audio_end", type="run"), with_task=True ) from chainlit.llama_index.callbacks import LlamaIndexCallbackHandler from chainlit.mistralai import instrument_mistralai diff --git a/backend/chainlit/auth.py b/backend/chainlit/auth.py index a4b1b326a7..9bd2073b55 100644 --- a/backend/chainlit/auth.py +++ b/backend/chainlit/auth.py @@ -5,7 +5,7 @@ import jwt from chainlit.config import config from chainlit.data import get_data_layer -from chainlit.oauth.providers import get_configured_oauth_providers +from chainlit.oauth_providers import get_configured_oauth_providers from chainlit.user import User from fastapi import Depends, HTTPException from fastapi.security import OAuth2PasswordBearer diff --git a/backend/chainlit/callbacks.py b/backend/chainlit/callbacks.py index b559049d7b..c9ac5fe43c 100644 --- a/backend/chainlit/callbacks.py +++ b/backend/chainlit/callbacks.py @@ -4,7 +4,11 @@ from chainlit.action import Action from chainlit.config import config from chainlit.message import Message -from chainlit.oauth_providers import get_configured_oauth_providers +from chainlit.oauth_providers import ( + OAuthProvider, + get_configured_oauth_providers, + providers, +) from chainlit.step import Step, step from chainlit.telemetry import trace from chainlit.types import ChatProfile, Starter, ThreadDict @@ -87,6 +91,38 @@ async def oauth_callback(provider_id: str, token: str, raw_user_data: Dict[str, return func +@trace +def custom_authenticate_user(func: Callable[[str], Awaitable[User]]) -> Callable: + """ + A decorator to authenticate the user via custom token validation. + + Args: + func (Callable[[str], Awaitable[User]]): The authentication callback to execute. + + Returns: + Callable[[str], Awaitable[User]]: The decorated authentication callback. + """ + + if len(get_configured_oauth_providers()) == 0: + raise ValueError( + "You must set the environment variable for at least one oauth provider to use oauth authentication." + ) + + config.code.custom_authenticate_user = wrap_user_function(func) + return func + + +def custom_oauth_provider(func: Callable[[], OAuthProvider]) -> None: + """ + A decorator to integrate custom OAuth provider logic for user authentication. + + Args: + func (Callable[[], OAuthProvider): A function that returns an instance of the OAuthProvider class, encapsulating the logic and details for the custom OAuth provider. + """ + + providers.append(func()) + + @trace def on_logout(func: Callable[[Request, Response], Any]) -> Callable: """ diff --git a/backend/chainlit/chat_context.py b/backend/chainlit/chat_context.py index 5f7215ba56..81bf66b3d2 100644 --- a/backend/chainlit/chat_context.py +++ b/backend/chainlit/chat_context.py @@ -25,10 +25,10 @@ def add(self, message: "Message"): if context.session.id not in chat_contexts: chat_contexts[context.session.id] = [] - + if message not in chat_contexts[context.session.id]: chat_contexts[context.session.id].append(message) - + return message def remove(self, message: "Message") -> bool: diff --git a/backend/chainlit/config.py b/backend/chainlit/config.py index 22fc0d6ec9..af54857d99 100644 --- a/backend/chainlit/config.py +++ b/backend/chainlit/config.py @@ -19,7 +19,7 @@ import tomli from chainlit.logger import logger -from chainlit.oauth.oauth_provider import OAuthProvider +from chainlit.oauth_providers import OAuthProvider from chainlit.translations import lint_translation_json from chainlit.version import __version__ from dataclasses_json import DataClassJsonMixin diff --git a/backend/chainlit/data/acl.py b/backend/chainlit/data/acl.py index 65c040170a..5264fb6951 100644 --- a/backend/chainlit/data/acl.py +++ b/backend/chainlit/data/acl.py @@ -8,7 +8,7 @@ async def is_thread_author(username: str, thread_id: str): raise HTTPException(status_code=400, detail="Data layer not initialized") thread_author = await data_layer.get_thread_author(thread_id) - + if not thread_author: raise HTTPException(status_code=404, detail="Thread not found") diff --git a/backend/chainlit/data/sql_alchemy.py b/backend/chainlit/data/sql_alchemy.py index 9a4f65b411..cbc0a01155 100644 --- a/backend/chainlit/data/sql_alchemy.py +++ b/backend/chainlit/data/sql_alchemy.py @@ -373,12 +373,18 @@ async def delete_feedback(self, feedback_id: str) -> bool: return True ###### Elements ###### - async def get_element(self, thread_id: str, element_id: str) -> Optional["ElementDict"]: + async def get_element( + self, thread_id: str, element_id: str + ) -> Optional["ElementDict"]: if self.show_logger: - logger.info(f"SQLAlchemy: get_element, thread_id={thread_id}, element_id={element_id}") + logger.info( + f"SQLAlchemy: get_element, thread_id={thread_id}, element_id={element_id}" + ) query = """SELECT * FROM elements WHERE "threadId" = :thread_id AND "id" = :element_id""" parameters = {"thread_id": thread_id, "element_id": element_id} - element: Union[List[Dict[str, Any]], int, None] = await self.execute_sql(query=query, parameters=parameters) + element: Union[List[Dict[str, Any]], int, None] = await self.execute_sql( + query=query, parameters=parameters + ) if isinstance(element, list) and element: element_dict: Dict[str, Any] = element[0] return ElementDict( @@ -396,7 +402,7 @@ async def get_element(self, thread_id: str, element_id: str) -> Optional["Elemen autoPlay=element_dict.get("autoPlay"), playerConfig=element_dict.get("playerConfig"), forId=element_dict.get("forId"), - mime=element_dict.get("mime") + mime=element_dict.get("mime"), ) else: return None @@ -607,7 +613,8 @@ async def get_all_user_threads( tags=step_feedback.get("step_tags"), input=( step_feedback.get("step_input", "") - if step_feedback.get("step_showinput") not in [None, "false"] + if step_feedback.get("step_showinput") + not in [None, "false"] else None ), output=step_feedback.get("step_output", ""), diff --git a/backend/chainlit/oauth/__init__.py b/backend/chainlit/oauth/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/backend/chainlit/oauth/auth0_oauth_provider.py b/backend/chainlit/oauth/auth0_oauth_provider.py deleted file mode 100644 index 91e70fb376..0000000000 --- a/backend/chainlit/oauth/auth0_oauth_provider.py +++ /dev/null @@ -1,69 +0,0 @@ -import os - -import httpx -from chainlit.oauth.oauth_provider import OAuthProvider -from chainlit.user import User -from fastapi import HTTPException - - -class Auth0OAuthProvider(OAuthProvider): - id = "auth0" - env = ["OAUTH_AUTH0_CLIENT_ID", "OAUTH_AUTH0_CLIENT_SECRET", "OAUTH_AUTH0_DOMAIN"] - - def __init__(self): - self.client_id = os.environ.get("OAUTH_AUTH0_CLIENT_ID") - self.client_secret = os.environ.get("OAUTH_AUTH0_CLIENT_SECRET") - # Ensure that the domain does not have a trailing slash - self.domain = f"https://{os.environ.get('OAUTH_AUTH0_DOMAIN', '').rstrip('/')}" - self.original_domain = ( - f"https://{os.environ.get('OAUTH_AUTH0_ORIGINAL_DOMAIN').rstrip('/')}" - if os.environ.get("OAUTH_AUTH0_ORIGINAL_DOMAIN") - else self.domain - ) - - self.authorize_url = f"{self.domain}/authorize" - - self.authorize_params = { - "response_type": "code", - "scope": "openid profile email", - "audience": f"{self.original_domain}/userinfo", - } - - async def get_token(self, code: str, url: str): - payload = { - "client_id": self.client_id, - "client_secret": self.client_secret, - "code": code, - "grant_type": "authorization_code", - "redirect_uri": url, - } - async with httpx.AsyncClient() as client: - response = await client.post( - f"{self.domain}/oauth/token", - data=payload, - ) - response.raise_for_status() - json_content = response.json() - token = json_content.get("access_token") - if not token: - raise HTTPException( - status_code=400, detail="Failed to get the access token" - ) - return token - - async def get_user_info(self, token: str): - async with httpx.AsyncClient() as client: - response = await client.get( - f"{self.original_domain}/userinfo", - headers={"Authorization": f"Bearer {token}"}, - ) - response.raise_for_status() - auth0_user = response.json() - user = User( - identifier=auth0_user.get("email"), - metadata={ - "image": auth0_user.get("picture", ""), - "provider": "auth0", - }, - ) - return (auth0_user, user) diff --git a/backend/chainlit/oauth/aws_cognito_oauth_provider.py b/backend/chainlit/oauth/aws_cognito_oauth_provider.py deleted file mode 100644 index d5d286d134..0000000000 --- a/backend/chainlit/oauth/aws_cognito_oauth_provider.py +++ /dev/null @@ -1,72 +0,0 @@ -import os - -import httpx -from chainlit.oauth.oauth_provider import OAuthProvider -from chainlit.user import User -from fastapi import HTTPException - - -class AWSCognitoOAuthProvider(OAuthProvider): - id = "aws-cognito" - env = [ - "OAUTH_COGNITO_CLIENT_ID", - "OAUTH_COGNITO_CLIENT_SECRET", - "OAUTH_COGNITO_DOMAIN", - ] - authorize_url = f"https://{os.environ.get('OAUTH_COGNITO_DOMAIN')}/login" - token_url = f"https://{os.environ.get('OAUTH_COGNITO_DOMAIN')}/oauth2/token" - - def __init__(self): - self.client_id = os.environ.get("OAUTH_COGNITO_CLIENT_ID") - self.client_secret = os.environ.get("OAUTH_COGNITO_CLIENT_SECRET") - self.authorize_params = { - "response_type": "code", - "client_id": self.client_id, - "scope": "openid profile email", - } - - async def get_token(self, code: str, url: str): - payload = { - "client_id": self.client_id, - "client_secret": self.client_secret, - "code": code, - "grant_type": "authorization_code", - "redirect_uri": url, - } - async with httpx.AsyncClient() as client: - response = await client.post( - self.token_url, - data=payload, - ) - response.raise_for_status() - json = response.json() - - token = json.get("access_token") - if not token: - raise HTTPException( - status_code=400, detail="Failed to get the access token" - ) - return token - - async def get_user_info(self, token: str): - user_info_url = ( - f"https://{os.environ.get('OAUTH_COGNITO_DOMAIN')}/oauth2/userInfo" - ) - async with httpx.AsyncClient() as client: - response = await client.get( - user_info_url, - headers={"Authorization": f"Bearer {token}"}, - ) - response.raise_for_status() - - cognito_user = response.json() - - # Customize user metadata as needed - user = User( - identifier=cognito_user["email"], - metadata={ - "image": cognito_user.get("picture", ""), - "provider": "aws-cognito", - }, - ) - return (cognito_user, user) diff --git a/backend/chainlit/oauth/azure_ad_hubrid_oauth_provider.py b/backend/chainlit/oauth/azure_ad_hubrid_oauth_provider.py deleted file mode 100644 index da935d1543..0000000000 --- a/backend/chainlit/oauth/azure_ad_hubrid_oauth_provider.py +++ /dev/null @@ -1,92 +0,0 @@ -import base64 -import os - -import httpx -from chainlit.oauth.oauth_provider import OAuthProvider -from chainlit.secret import random_secret -from chainlit.user import User -from fastapi import HTTPException - - -class AzureADHybridOAuthProvider(OAuthProvider): - id = "azure-ad-hybrid" - env = [ - "OAUTH_AZURE_AD_HYBRID_CLIENT_ID", - "OAUTH_AZURE_AD_HYBRID_CLIENT_SECRET", - "OAUTH_AZURE_AD_HYBRID_TENANT_ID", - ] - authorize_url = ( - f"https://login.microsoftonline.com/{os.environ.get('OAUTH_AZURE_AD_HYBRID_TENANT_ID', '')}/oauth2/v2.0/authorize" - if os.environ.get("OAUTH_AZURE_AD_HYBRID_ENABLE_SINGLE_TENANT") - else "https://login.microsoftonline.com/common/oauth2/v2.0/authorize" - ) - token_url = ( - f"https://login.microsoftonline.com/{os.environ.get('OAUTH_AZURE_AD_HYBRID_TENANT_ID', '')}/oauth2/v2.0/token" - if os.environ.get("OAUTH_AZURE_AD_HYBRID_ENABLE_SINGLE_TENANT") - else "https://login.microsoftonline.com/common/oauth2/v2.0/token" - ) - - def __init__(self): - self.client_id = os.environ.get("OAUTH_AZURE_AD_HYBRID_CLIENT_ID") - self.client_secret = os.environ.get("OAUTH_AZURE_AD_HYBRID_CLIENT_SECRET") - nonce = random_secret(16) - self.authorize_params = { - "tenant": os.environ.get("OAUTH_AZURE_AD_HYBRID_TENANT_ID"), - "response_type": "code id_token", - "scope": "https://graph.microsoft.com/User.Read https://graph.microsoft.com/openid", - "response_mode": "form_post", - "nonce": nonce, - } - - async def get_token(self, code: str, url: str): - payload = { - "client_id": self.client_id, - "client_secret": self.client_secret, - "code": code, - "grant_type": "authorization_code", - "redirect_uri": url, - } - async with httpx.AsyncClient() as client: - response = await client.post( - self.token_url, - data=payload, - ) - response.raise_for_status() - json = response.json() - - token = json["access_token"] - if not token: - raise HTTPException( - status_code=400, detail="Failed to get the access token" - ) - return token - - async def get_user_info(self, token: str): - async with httpx.AsyncClient() as client: - response = await client.get( - "https://graph.microsoft.com/v1.0/me", - headers={"Authorization": f"Bearer {token}"}, - ) - response.raise_for_status() - - azure_user = response.json() - - try: - photo_response = await client.get( - "https://graph.microsoft.com/v1.0/me/photos/48x48/$value", - headers={"Authorization": f"Bearer {token}"}, - ) - photo_data = await photo_response.aread() - base64_image = base64.b64encode(photo_data) - azure_user["image"] = ( - f"data:{photo_response.headers['Content-Type']};base64,{base64_image.decode('utf-8')}" - ) - except Exception as e: - # Ignore errors getting the photo - pass - - user = User( - identifier=azure_user["userPrincipalName"], - metadata={"image": azure_user.get("image"), "provider": "azure-ad"}, - ) - return (azure_user, user) diff --git a/backend/chainlit/oauth/azure_ad_oauth_provider.py b/backend/chainlit/oauth/azure_ad_oauth_provider.py deleted file mode 100644 index 2ed32bbe28..0000000000 --- a/backend/chainlit/oauth/azure_ad_oauth_provider.py +++ /dev/null @@ -1,89 +0,0 @@ -import base64 -import os - -import httpx -from chainlit.oauth.oauth_provider import OAuthProvider -from chainlit.user import User -from fastapi import HTTPException - - -class AzureADOAuthProvider(OAuthProvider): - id = "azure-ad" - env = [ - "OAUTH_AZURE_AD_CLIENT_ID", - "OAUTH_AZURE_AD_CLIENT_SECRET", - "OAUTH_AZURE_AD_TENANT_ID", - ] - authorize_url = ( - f"https://login.microsoftonline.com/{os.environ.get('OAUTH_AZURE_AD_TENANT_ID', '')}/oauth2/v2.0/authorize" - if os.environ.get("OAUTH_AZURE_AD_ENABLE_SINGLE_TENANT") - else "https://login.microsoftonline.com/common/oauth2/v2.0/authorize" - ) - token_url = ( - f"https://login.microsoftonline.com/{os.environ.get('OAUTH_AZURE_AD_TENANT_ID', '')}/oauth2/v2.0/token" - if os.environ.get("OAUTH_AZURE_AD_ENABLE_SINGLE_TENANT") - else "https://login.microsoftonline.com/common/oauth2/v2.0/token" - ) - - def __init__(self): - self.client_id = os.environ.get("OAUTH_AZURE_AD_CLIENT_ID") - self.client_secret = os.environ.get("OAUTH_AZURE_AD_CLIENT_SECRET") - self.authorize_params = { - "tenant": os.environ.get("OAUTH_AZURE_AD_TENANT_ID"), - "response_type": "code", - "scope": "https://graph.microsoft.com/User.Read", - "response_mode": "query", - } - - async def get_token(self, code: str, url: str): - payload = { - "client_id": self.client_id, - "client_secret": self.client_secret, - "code": code, - "grant_type": "authorization_code", - "redirect_uri": url, - } - async with httpx.AsyncClient() as client: - response = await client.post( - self.token_url, - data=payload, - ) - response.raise_for_status() - json = response.json() - - token = json["access_token"] - if not token: - raise HTTPException( - status_code=400, detail="Failed to get the access token" - ) - return token - - async def get_user_info(self, token: str): - async with httpx.AsyncClient() as client: - response = await client.get( - "https://graph.microsoft.com/v1.0/me", - headers={"Authorization": f"Bearer {token}"}, - ) - response.raise_for_status() - - azure_user = response.json() - - try: - photo_response = await client.get( - "https://graph.microsoft.com/v1.0/me/photos/48x48/$value", - headers={"Authorization": f"Bearer {token}"}, - ) - photo_data = await photo_response.aread() - base64_image = base64.b64encode(photo_data) - azure_user["image"] = ( - f"data:{photo_response.headers['Content-Type']};base64,{base64_image.decode('utf-8')}" - ) - except Exception as e: - # Ignore errors getting the photo - pass - - user = User( - identifier=azure_user["userPrincipalName"], - metadata={"image": azure_user.get("image"), "provider": "azure-ad"}, - ) - return (azure_user, user) diff --git a/backend/chainlit/oauth/descope_oauth_provider.py b/backend/chainlit/oauth/descope_oauth_provider.py deleted file mode 100644 index 08c96da0d5..0000000000 --- a/backend/chainlit/oauth/descope_oauth_provider.py +++ /dev/null @@ -1,61 +0,0 @@ -import os - -import httpx -from chainlit.oauth.oauth_provider import OAuthProvider -from chainlit.user import User - - -class DescopeOAuthProvider(OAuthProvider): - id = "descope" - env = ["OAUTH_DESCOPE_CLIENT_ID", "OAUTH_DESCOPE_CLIENT_SECRET"] - # Ensure that the domain does not have a trailing slash - domain = f"https://api.descope.com/oauth2/v1" - - authorize_url = f"{domain}/authorize" - - def __init__(self): - self.client_id = os.environ.get("OAUTH_DESCOPE_CLIENT_ID") - self.client_secret = os.environ.get("OAUTH_DESCOPE_CLIENT_SECRET") - self.authorize_params = { - "response_type": "code", - "scope": "openid profile email", - "audience": f"{self.domain}/userinfo", - } - - async def get_token(self, code: str, url: str): - payload = { - "client_id": self.client_id, - "client_secret": self.client_secret, - "code": code, - "grant_type": "authorization_code", - "redirect_uri": url, - } - async with httpx.AsyncClient() as client: - response = await client.post( - f"{self.domain}/token", - data=payload, - ) - response.raise_for_status() - json_content = response.json() - token = json_content.get("access_token") - if not token: - raise httpx.HTTPStatusError( - "Failed to get the access token", - request=response.request, - response=response, - ) - return token - - async def get_user_info(self, token: str): - async with httpx.AsyncClient() as client: - response = await client.get( - f"{self.domain}/userinfo", headers={"Authorization": f"Bearer {token}"} - ) - response.raise_for_status() # This will raise an exception for 4xx/5xx responses - descope_user = response.json() - - user = User( - identifier=descope_user.get("email"), - metadata={"image": "", "provider": "descope"}, - ) - return (descope_user, user) diff --git a/backend/chainlit/oauth/github.py b/backend/chainlit/oauth/github.py deleted file mode 100644 index 5ab0d72059..0000000000 --- a/backend/chainlit/oauth/github.py +++ /dev/null @@ -1,63 +0,0 @@ -import os -import urllib.parse - -import httpx -from chainlit.oauth.oauth_provider import OAuthProvider -from chainlit.user import User -from fastapi import HTTPException - - -class GithubOAuthProvider(OAuthProvider): - id = "github" - env = ["OAUTH_GITHUB_CLIENT_ID", "OAUTH_GITHUB_CLIENT_SECRET"] - authorize_url = "https://github.com/login/oauth/authorize" - - def __init__(self): - self.client_id = os.environ.get("OAUTH_GITHUB_CLIENT_ID") - self.client_secret = os.environ.get("OAUTH_GITHUB_CLIENT_SECRET") - self.authorize_params = { - "scope": "user:email", - } - - async def get_token(self, code: str, url: str): - payload = { - "client_id": self.client_id, - "client_secret": self.client_secret, - "code": code, - } - async with httpx.AsyncClient() as client: - response = await client.post( - "https://github.com/login/oauth/access_token", - data=payload, - ) - response.raise_for_status() - content = urllib.parse.parse_qs(response.text) - token = content.get("access_token", [""])[0] - if not token: - raise HTTPException( - status_code=400, detail="Failed to get the access token" - ) - return token - - async def get_user_info(self, token: str): - async with httpx.AsyncClient() as client: - user_response = await client.get( - "https://api.github.com/user", - headers={"Authorization": f"token {token}"}, - ) - user_response.raise_for_status() - github_user = user_response.json() - - emails_response = await client.get( - "https://api.github.com/user/emails", - headers={"Authorization": f"token {token}"}, - ) - emails_response.raise_for_status() - emails = emails_response.json() - - github_user.update({"emails": emails}) - user = User( - identifier=github_user["login"], - metadata={"image": github_user["avatar_url"], "provider": "github"}, - ) - return (github_user, user) diff --git a/backend/chainlit/oauth/gitlab_oauth_provider.py b/backend/chainlit/oauth/gitlab_oauth_provider.py deleted file mode 100644 index 22e993a77b..0000000000 --- a/backend/chainlit/oauth/gitlab_oauth_provider.py +++ /dev/null @@ -1,67 +0,0 @@ -import os - -import httpx -from chainlit.oauth.oauth_provider import OAuthProvider -from chainlit.user import User -from fastapi import HTTPException - - -class GitlabOAuthProvider(OAuthProvider): - id = "gitlab" - env = [ - "OAUTH_GITLAB_CLIENT_ID", - "OAUTH_GITLAB_CLIENT_SECRET", - "OAUTH_GITLAB_DOMAIN", - ] - - def __init__(self): - self.client_id = os.environ.get("OAUTH_GITLAB_CLIENT_ID") - self.client_secret = os.environ.get("OAUTH_GITLAB_CLIENT_SECRET") - # Ensure that the domain does not have a trailing slash - self.domain = f"https://{os.environ.get('OAUTH_GITLAB_DOMAIN', '').rstrip('/')}" - - self.authorize_url = f"{self.domain}/oauth/authorize" - - self.authorize_params = { - "scope": "openid profile email", - "response_type": "code", - } - - async def get_token(self, code: str, url: str): - payload = { - "client_id": self.client_id, - "client_secret": self.client_secret, - "code": code, - "grant_type": "authorization_code", - "redirect_uri": url, - } - async with httpx.AsyncClient() as client: - response = await client.post( - f"{self.domain}/oauth/token", - data=payload, - ) - response.raise_for_status() - json_content = response.json() - token = json_content.get("access_token") - if not token: - raise HTTPException( - status_code=400, detail="Failed to get the access token" - ) - return token - - async def get_user_info(self, token: str): - async with httpx.AsyncClient() as client: - response = await client.get( - f"{self.domain}/oauth/userinfo", - headers={"Authorization": f"Bearer {token}"}, - ) - response.raise_for_status() - gitlab_user = response.json() - user = User( - identifier=gitlab_user.get("email"), - metadata={ - "image": gitlab_user.get("picture", ""), - "provider": "gitlab", - }, - ) - return (gitlab_user, user) diff --git a/backend/chainlit/oauth/google.py b/backend/chainlit/oauth/google.py deleted file mode 100644 index 0d4cc1cfea..0000000000 --- a/backend/chainlit/oauth/google.py +++ /dev/null @@ -1,58 +0,0 @@ -import os - -import httpx -from chainlit.oauth.oauth_provider import OAuthProvider -from chainlit.user import User - - -class GoogleOAuthProvider(OAuthProvider): - id = "google" - env = ["OAUTH_GOOGLE_CLIENT_ID", "OAUTH_GOOGLE_CLIENT_SECRET"] - authorize_url = "https://accounts.google.com/o/oauth2/v2/auth" - - def __init__(self): - self.client_id = os.environ.get("OAUTH_GOOGLE_CLIENT_ID") - self.client_secret = os.environ.get("OAUTH_GOOGLE_CLIENT_SECRET") - self.authorize_params = { - "scope": "https://www.googleapis.com/auth/userinfo.profile https://www.googleapis.com/auth/userinfo.email", - "response_type": "code", - "access_type": "offline", - } - - async def get_token(self, code: str, url: str): - payload = { - "client_id": self.client_id, - "client_secret": self.client_secret, - "code": code, - "grant_type": "authorization_code", - "redirect_uri": url, - } - async with httpx.AsyncClient() as client: - response = await client.post( - "https://oauth2.googleapis.com/token", - data=payload, - ) - response.raise_for_status() - json = response.json() - token = json.get("access_token") - if not token: - raise httpx.HTTPStatusError( - "Failed to get the access token", - request=response.request, - response=response, - ) - return token - - async def get_user_info(self, token: str): - async with httpx.AsyncClient() as client: - response = await client.get( - "https://www.googleapis.com/userinfo/v2/me", - headers={"Authorization": f"Bearer {token}"}, - ) - response.raise_for_status() - google_user = response.json() - user = User( - identifier=google_user["email"], - metadata={"image": google_user["picture"], "provider": "google"}, - ) - return (google_user, user) diff --git a/backend/chainlit/oauth/oauth_provider.py b/backend/chainlit/oauth/oauth_provider.py deleted file mode 100644 index 4db4d0d413..0000000000 --- a/backend/chainlit/oauth/oauth_provider.py +++ /dev/null @@ -1,22 +0,0 @@ -import os -from typing import Dict, List, Tuple - -from chainlit.user import User - - -class OAuthProvider: - id: str - env: List[str] - client_id: str - client_secret: str - authorize_url: str - authorize_params: Dict[str, str] - - def is_configured(self): - return all([os.environ.get(env) for env in self.env]) - - async def get_token(self, code: str, url: str) -> str: - raise NotImplementedError() - - async def get_user_info(self, token: str) -> Tuple[Dict[str, str], User]: - raise NotImplementedError() diff --git a/backend/chainlit/oauth/okta_oauth_provider.py b/backend/chainlit/oauth/okta_oauth_provider.py deleted file mode 100644 index a531ddadb0..0000000000 --- a/backend/chainlit/oauth/okta_oauth_provider.py +++ /dev/null @@ -1,78 +0,0 @@ -import os - -import httpx -from chainlit.oauth.oauth_provider import OAuthProvider -from chainlit.user import User - - -class OktaOAuthProvider(OAuthProvider): - id = "okta" - env = [ - "OAUTH_OKTA_CLIENT_ID", - "OAUTH_OKTA_CLIENT_SECRET", - "OAUTH_OKTA_DOMAIN", - ] - # Avoid trailing slash in domain if supplied - domain = f"https://{os.environ.get('OAUTH_OKTA_DOMAIN', '').rstrip('/')}" - - def __init__(self): - self.client_id = os.environ.get("OAUTH_OKTA_CLIENT_ID") - self.client_secret = os.environ.get("OAUTH_OKTA_CLIENT_SECRET") - self.authorization_server_id = os.environ.get( - "OAUTH_OKTA_AUTHORIZATION_SERVER_ID", "" - ) - self.authorize_url = ( - f"{self.domain}/oauth2{self.get_authorization_server_path()}/v1/authorize" - ) - self.authorize_params = { - "response_type": "code", - "scope": "openid profile email", - "response_mode": "query", - } - - def get_authorization_server_path(self): - if not self.authorization_server_id: - return "/default" - if self.authorization_server_id == "false": - return "" - return f"/{self.authorization_server_id}" - - async def get_token(self, code: str, url: str): - payload = { - "client_id": self.client_id, - "client_secret": self.client_secret, - "code": code, - "grant_type": "authorization_code", - "redirect_uri": url, - } - async with httpx.AsyncClient() as client: - response = await client.post( - f"{self.domain}/oauth2{self.get_authorization_server_path()}/v1/token", - data=payload, - ) - response.raise_for_status() - json_data = response.json() - - token = json_data.get("access_token") - if not token: - raise httpx.HTTPStatusError( - "Failed to get the access token", - request=response.request, - response=response, - ) - return token - - async def get_user_info(self, token: str): - async with httpx.AsyncClient() as client: - response = await client.get( - f"{self.domain}/oauth2{self.get_authorization_server_path()}/v1/userinfo", - headers={"Authorization": f"Bearer {token}"}, - ) - response.raise_for_status() - okta_user = response.json() - - user = User( - identifier=okta_user.get("email"), - metadata={"image": "", "provider": "okta"}, - ) - return (okta_user, user) diff --git a/backend/chainlit/oauth/providers.py b/backend/chainlit/oauth/providers.py deleted file mode 100644 index a9510e67a7..0000000000 --- a/backend/chainlit/oauth/providers.py +++ /dev/null @@ -1,42 +0,0 @@ -from typing import Optional - -from chainlit.config import config -from chainlit.oauth.auth0_oauth_provider import Auth0OAuthProvider -from chainlit.oauth.aws_cognito_oauth_provider import AWSCognitoOAuthProvider -from chainlit.oauth.azure_ad_hubrid_oauth_provider import AzureADHybridOAuthProvider -from chainlit.oauth.azure_ad_oauth_provider import AzureADOAuthProvider -from chainlit.oauth.descope_oauth_provider import DescopeOAuthProvider -from chainlit.oauth.github import GithubOAuthProvider -from chainlit.oauth.gitlab_oauth_provider import GitlabOAuthProvider -from chainlit.oauth.google import GoogleOAuthProvider -from chainlit.oauth.oauth_provider import OAuthProvider -from chainlit.oauth.okta_oauth_provider import OktaOAuthProvider - -custom_oauth = config.code.custom_oauth_provider -providers = ( - [ - GithubOAuthProvider(), - GoogleOAuthProvider(), - AzureADOAuthProvider(), - AzureADHybridOAuthProvider(), - OktaOAuthProvider(), - Auth0OAuthProvider(), - DescopeOAuthProvider(), - AWSCognitoOAuthProvider(), - GitlabOAuthProvider(), - ] - + [custom_oauth()] - if custom_oauth - else [] -) - - -def get_oauth_provider(provider: str) -> Optional[OAuthProvider]: - for p in providers: - if p.id == provider: - return p - return None - - -def get_configured_oauth_providers(): - return [p.id for p in providers if p.is_configured()] diff --git a/backend/chainlit/oauth_providers.py b/backend/chainlit/oauth_providers.py index c8d45bcc4c..fe019859b1 100644 --- a/backend/chainlit/oauth_providers.py +++ b/backend/chainlit/oauth_providers.py @@ -1,40 +1,645 @@ -import warnings - -warnings.warn( - "The 'oauth_providers' module is deprecated and will be removed in a future version. " - "Please use 'oauth' instead.", - DeprecationWarning, - stacklevel=2, -) - -from chainlit.oauth.auth0_oauth_provider import Auth0OAuthProvider -from chainlit.oauth.aws_cognito_oauth_provider import AWSCognitoOAuthProvider -from chainlit.oauth.azure_ad_hubrid_oauth_provider import AzureADHybridOAuthProvider -from chainlit.oauth.azure_ad_oauth_provider import AzureADOAuthProvider -from chainlit.oauth.descope_oauth_provider import DescopeOAuthProvider -from chainlit.oauth.github import GithubOAuthProvider -from chainlit.oauth.gitlab_oauth_provider import GitlabOAuthProvider -from chainlit.oauth.google import GoogleOAuthProvider -from chainlit.oauth.oauth_provider import OAuthProvider -from chainlit.oauth.okta_oauth_provider import OktaOAuthProvider -from chainlit.oauth.providers import ( - get_configured_oauth_providers, - get_oauth_provider, - providers, -) - -__all__ = [ - "providers", - "get_oauth_provider", - "get_configured_oauth_providers", - "OAuthProvider", - "GithubOAuthProvider", - "GoogleOAuthProvider", - "AzureADOAuthProvider", - "AzureADHybridOAuthProvider", - "OktaOAuthProvider", - "Auth0OAuthProvider", - "DescopeOAuthProvider", - "AWSCognitoOAuthProvider", - "GitlabOAuthProvider", +import base64 +import os +import urllib.parse +from typing import Dict, List, Optional, Tuple + +import httpx +from chainlit.secret import random_secret +from chainlit.user import User +from fastapi import HTTPException + + +class OAuthProvider: + id: str + env: List[str] + client_id: str + client_secret: str + authorize_url: str + authorize_params: Dict[str, str] + + def is_configured(self): + return all([os.environ.get(env) for env in self.env]) + + async def get_token(self, code: str, url: str) -> str: + raise NotImplementedError() + + async def get_user_info(self, token: str) -> Tuple[Dict[str, str], User]: + raise NotImplementedError() + + +class GithubOAuthProvider(OAuthProvider): + id = "github" + env = ["OAUTH_GITHUB_CLIENT_ID", "OAUTH_GITHUB_CLIENT_SECRET"] + authorize_url = "https://github.com/login/oauth/authorize" + + def __init__(self): + self.client_id = os.environ.get("OAUTH_GITHUB_CLIENT_ID") + self.client_secret = os.environ.get("OAUTH_GITHUB_CLIENT_SECRET") + self.authorize_params = { + "scope": "user:email", + } + + async def get_token(self, code: str, url: str): + payload = { + "client_id": self.client_id, + "client_secret": self.client_secret, + "code": code, + } + async with httpx.AsyncClient() as client: + response = await client.post( + "https://github.com/login/oauth/access_token", + data=payload, + ) + response.raise_for_status() + content = urllib.parse.parse_qs(response.text) + token = content.get("access_token", [""])[0] + if not token: + raise HTTPException( + status_code=400, detail="Failed to get the access token" + ) + return token + + async def get_user_info(self, token: str): + async with httpx.AsyncClient() as client: + user_response = await client.get( + "https://api.github.com/user", + headers={"Authorization": f"token {token}"}, + ) + user_response.raise_for_status() + github_user = user_response.json() + + emails_response = await client.get( + "https://api.github.com/user/emails", + headers={"Authorization": f"token {token}"}, + ) + emails_response.raise_for_status() + emails = emails_response.json() + + github_user.update({"emails": emails}) + user = User( + identifier=github_user["login"], + metadata={"image": github_user["avatar_url"], "provider": "github"}, + ) + return (github_user, user) + + +class GoogleOAuthProvider(OAuthProvider): + id = "google" + env = ["OAUTH_GOOGLE_CLIENT_ID", "OAUTH_GOOGLE_CLIENT_SECRET"] + authorize_url = "https://accounts.google.com/o/oauth2/v2/auth" + + def __init__(self): + self.client_id = os.environ.get("OAUTH_GOOGLE_CLIENT_ID") + self.client_secret = os.environ.get("OAUTH_GOOGLE_CLIENT_SECRET") + self.authorize_params = { + "scope": "https://www.googleapis.com/auth/userinfo.profile https://www.googleapis.com/auth/userinfo.email", + "response_type": "code", + "access_type": "offline", + } + + async def get_token(self, code: str, url: str): + payload = { + "client_id": self.client_id, + "client_secret": self.client_secret, + "code": code, + "grant_type": "authorization_code", + "redirect_uri": url, + } + async with httpx.AsyncClient() as client: + response = await client.post( + "https://oauth2.googleapis.com/token", + data=payload, + ) + response.raise_for_status() + json = response.json() + token = json.get("access_token") + if not token: + raise httpx.HTTPStatusError( + "Failed to get the access token", + request=response.request, + response=response, + ) + return token + + async def get_user_info(self, token: str): + async with httpx.AsyncClient() as client: + response = await client.get( + "https://www.googleapis.com/userinfo/v2/me", + headers={"Authorization": f"Bearer {token}"}, + ) + response.raise_for_status() + google_user = response.json() + user = User( + identifier=google_user["email"], + metadata={"image": google_user["picture"], "provider": "google"}, + ) + return (google_user, user) + + +class AzureADOAuthProvider(OAuthProvider): + id = "azure-ad" + env = [ + "OAUTH_AZURE_AD_CLIENT_ID", + "OAUTH_AZURE_AD_CLIENT_SECRET", + "OAUTH_AZURE_AD_TENANT_ID", + ] + authorize_url = ( + f"https://login.microsoftonline.com/{os.environ.get('OAUTH_AZURE_AD_TENANT_ID', '')}/oauth2/v2.0/authorize" + if os.environ.get("OAUTH_AZURE_AD_ENABLE_SINGLE_TENANT") + else "https://login.microsoftonline.com/common/oauth2/v2.0/authorize" + ) + token_url = ( + f"https://login.microsoftonline.com/{os.environ.get('OAUTH_AZURE_AD_TENANT_ID', '')}/oauth2/v2.0/token" + if os.environ.get("OAUTH_AZURE_AD_ENABLE_SINGLE_TENANT") + else "https://login.microsoftonline.com/common/oauth2/v2.0/token" + ) + + def __init__(self): + self.client_id = os.environ.get("OAUTH_AZURE_AD_CLIENT_ID") + self.client_secret = os.environ.get("OAUTH_AZURE_AD_CLIENT_SECRET") + self.authorize_params = { + "tenant": os.environ.get("OAUTH_AZURE_AD_TENANT_ID"), + "response_type": "code", + "scope": "https://graph.microsoft.com/User.Read", + "response_mode": "query", + } + + async def get_token(self, code: str, url: str): + payload = { + "client_id": self.client_id, + "client_secret": self.client_secret, + "code": code, + "grant_type": "authorization_code", + "redirect_uri": url, + } + async with httpx.AsyncClient() as client: + response = await client.post( + self.token_url, + data=payload, + ) + response.raise_for_status() + json = response.json() + + token = json["access_token"] + if not token: + raise HTTPException( + status_code=400, detail="Failed to get the access token" + ) + return token + + async def get_user_info(self, token: str): + async with httpx.AsyncClient() as client: + response = await client.get( + "https://graph.microsoft.com/v1.0/me", + headers={"Authorization": f"Bearer {token}"}, + ) + response.raise_for_status() + + azure_user = response.json() + + try: + photo_response = await client.get( + "https://graph.microsoft.com/v1.0/me/photos/48x48/$value", + headers={"Authorization": f"Bearer {token}"}, + ) + photo_data = await photo_response.aread() + base64_image = base64.b64encode(photo_data) + azure_user["image"] = ( + f"data:{photo_response.headers['Content-Type']};base64,{base64_image.decode('utf-8')}" + ) + except Exception as e: + # Ignore errors getting the photo + pass + + user = User( + identifier=azure_user["userPrincipalName"], + metadata={"image": azure_user.get("image"), "provider": "azure-ad"}, + ) + return (azure_user, user) + + +class AzureADHybridOAuthProvider(OAuthProvider): + id = "azure-ad-hybrid" + env = [ + "OAUTH_AZURE_AD_HYBRID_CLIENT_ID", + "OAUTH_AZURE_AD_HYBRID_CLIENT_SECRET", + "OAUTH_AZURE_AD_HYBRID_TENANT_ID", + ] + authorize_url = ( + f"https://login.microsoftonline.com/{os.environ.get('OAUTH_AZURE_AD_HYBRID_TENANT_ID', '')}/oauth2/v2.0/authorize" + if os.environ.get("OAUTH_AZURE_AD_HYBRID_ENABLE_SINGLE_TENANT") + else "https://login.microsoftonline.com/common/oauth2/v2.0/authorize" + ) + token_url = ( + f"https://login.microsoftonline.com/{os.environ.get('OAUTH_AZURE_AD_HYBRID_TENANT_ID', '')}/oauth2/v2.0/token" + if os.environ.get("OAUTH_AZURE_AD_HYBRID_ENABLE_SINGLE_TENANT") + else "https://login.microsoftonline.com/common/oauth2/v2.0/token" + ) + + def __init__(self): + self.client_id = os.environ.get("OAUTH_AZURE_AD_HYBRID_CLIENT_ID") + self.client_secret = os.environ.get("OAUTH_AZURE_AD_HYBRID_CLIENT_SECRET") + nonce = random_secret(16) + self.authorize_params = { + "tenant": os.environ.get("OAUTH_AZURE_AD_HYBRID_TENANT_ID"), + "response_type": "code id_token", + "scope": "https://graph.microsoft.com/User.Read https://graph.microsoft.com/openid", + "response_mode": "form_post", + "nonce": nonce, + } + + async def get_token(self, code: str, url: str): + payload = { + "client_id": self.client_id, + "client_secret": self.client_secret, + "code": code, + "grant_type": "authorization_code", + "redirect_uri": url, + } + async with httpx.AsyncClient() as client: + response = await client.post( + self.token_url, + data=payload, + ) + response.raise_for_status() + json = response.json() + + token = json["access_token"] + if not token: + raise HTTPException( + status_code=400, detail="Failed to get the access token" + ) + return token + + async def get_user_info(self, token: str): + async with httpx.AsyncClient() as client: + response = await client.get( + "https://graph.microsoft.com/v1.0/me", + headers={"Authorization": f"Bearer {token}"}, + ) + response.raise_for_status() + + azure_user = response.json() + + try: + photo_response = await client.get( + "https://graph.microsoft.com/v1.0/me/photos/48x48/$value", + headers={"Authorization": f"Bearer {token}"}, + ) + photo_data = await photo_response.aread() + base64_image = base64.b64encode(photo_data) + azure_user["image"] = ( + f"data:{photo_response.headers['Content-Type']};base64,{base64_image.decode('utf-8')}" + ) + except Exception as e: + # Ignore errors getting the photo + pass + + user = User( + identifier=azure_user["userPrincipalName"], + metadata={"image": azure_user.get("image"), "provider": "azure-ad"}, + ) + return (azure_user, user) + + +class OktaOAuthProvider(OAuthProvider): + id = "okta" + env = [ + "OAUTH_OKTA_CLIENT_ID", + "OAUTH_OKTA_CLIENT_SECRET", + "OAUTH_OKTA_DOMAIN", + ] + # Avoid trailing slash in domain if supplied + domain = f"https://{os.environ.get('OAUTH_OKTA_DOMAIN', '').rstrip('/')}" + + def __init__(self): + self.client_id = os.environ.get("OAUTH_OKTA_CLIENT_ID") + self.client_secret = os.environ.get("OAUTH_OKTA_CLIENT_SECRET") + self.authorization_server_id = os.environ.get( + "OAUTH_OKTA_AUTHORIZATION_SERVER_ID", "" + ) + self.authorize_url = ( + f"{self.domain}/oauth2{self.get_authorization_server_path()}/v1/authorize" + ) + self.authorize_params = { + "response_type": "code", + "scope": "openid profile email", + "response_mode": "query", + } + + def get_authorization_server_path(self): + if not self.authorization_server_id: + return "/default" + if self.authorization_server_id == "false": + return "" + return f"/{self.authorization_server_id}" + + async def get_token(self, code: str, url: str): + payload = { + "client_id": self.client_id, + "client_secret": self.client_secret, + "code": code, + "grant_type": "authorization_code", + "redirect_uri": url, + } + async with httpx.AsyncClient() as client: + response = await client.post( + f"{self.domain}/oauth2{self.get_authorization_server_path()}/v1/token", + data=payload, + ) + response.raise_for_status() + json_data = response.json() + + token = json_data.get("access_token") + if not token: + raise httpx.HTTPStatusError( + "Failed to get the access token", + request=response.request, + response=response, + ) + return token + + async def get_user_info(self, token: str): + async with httpx.AsyncClient() as client: + response = await client.get( + f"{self.domain}/oauth2{self.get_authorization_server_path()}/v1/userinfo", + headers={"Authorization": f"Bearer {token}"}, + ) + response.raise_for_status() + okta_user = response.json() + + user = User( + identifier=okta_user.get("email"), + metadata={"image": "", "provider": "okta"}, + ) + return (okta_user, user) + + +class Auth0OAuthProvider(OAuthProvider): + id = "auth0" + env = ["OAUTH_AUTH0_CLIENT_ID", "OAUTH_AUTH0_CLIENT_SECRET", "OAUTH_AUTH0_DOMAIN"] + + def __init__(self): + self.client_id = os.environ.get("OAUTH_AUTH0_CLIENT_ID") + self.client_secret = os.environ.get("OAUTH_AUTH0_CLIENT_SECRET") + # Ensure that the domain does not have a trailing slash + self.domain = f"https://{os.environ.get('OAUTH_AUTH0_DOMAIN', '').rstrip('/')}" + self.original_domain = ( + f"https://{os.environ.get('OAUTH_AUTH0_ORIGINAL_DOMAIN').rstrip('/')}" + if os.environ.get("OAUTH_AUTH0_ORIGINAL_DOMAIN") + else self.domain + ) + + self.authorize_url = f"{self.domain}/authorize" + + self.authorize_params = { + "response_type": "code", + "scope": "openid profile email", + "audience": f"{self.original_domain}/userinfo", + } + + async def get_token(self, code: str, url: str): + payload = { + "client_id": self.client_id, + "client_secret": self.client_secret, + "code": code, + "grant_type": "authorization_code", + "redirect_uri": url, + } + async with httpx.AsyncClient() as client: + response = await client.post( + f"{self.domain}/oauth/token", + data=payload, + ) + response.raise_for_status() + json_content = response.json() + token = json_content.get("access_token") + if not token: + raise HTTPException( + status_code=400, detail="Failed to get the access token" + ) + return token + + async def get_user_info(self, token: str): + async with httpx.AsyncClient() as client: + response = await client.get( + f"{self.original_domain}/userinfo", + headers={"Authorization": f"Bearer {token}"}, + ) + response.raise_for_status() + auth0_user = response.json() + user = User( + identifier=auth0_user.get("email"), + metadata={ + "image": auth0_user.get("picture", ""), + "provider": "auth0", + }, + ) + return (auth0_user, user) + + +class DescopeOAuthProvider(OAuthProvider): + id = "descope" + env = ["OAUTH_DESCOPE_CLIENT_ID", "OAUTH_DESCOPE_CLIENT_SECRET"] + # Ensure that the domain does not have a trailing slash + domain = f"https://api.descope.com/oauth2/v1" + + authorize_url = f"{domain}/authorize" + + def __init__(self): + self.client_id = os.environ.get("OAUTH_DESCOPE_CLIENT_ID") + self.client_secret = os.environ.get("OAUTH_DESCOPE_CLIENT_SECRET") + self.authorize_params = { + "response_type": "code", + "scope": "openid profile email", + "audience": f"{self.domain}/userinfo", + } + + async def get_token(self, code: str, url: str): + payload = { + "client_id": self.client_id, + "client_secret": self.client_secret, + "code": code, + "grant_type": "authorization_code", + "redirect_uri": url, + } + async with httpx.AsyncClient() as client: + response = await client.post( + f"{self.domain}/token", + data=payload, + ) + response.raise_for_status() + json_content = response.json() + token = json_content.get("access_token") + if not token: + raise httpx.HTTPStatusError( + "Failed to get the access token", + request=response.request, + response=response, + ) + return token + + async def get_user_info(self, token: str): + async with httpx.AsyncClient() as client: + response = await client.get( + f"{self.domain}/userinfo", headers={"Authorization": f"Bearer {token}"} + ) + response.raise_for_status() # This will raise an exception for 4xx/5xx responses + descope_user = response.json() + + user = User( + identifier=descope_user.get("email"), + metadata={"image": "", "provider": "descope"}, + ) + return (descope_user, user) + + +class AWSCognitoOAuthProvider(OAuthProvider): + id = "aws-cognito" + env = [ + "OAUTH_COGNITO_CLIENT_ID", + "OAUTH_COGNITO_CLIENT_SECRET", + "OAUTH_COGNITO_DOMAIN", + ] + authorize_url = f"https://{os.environ.get('OAUTH_COGNITO_DOMAIN')}/login" + token_url = f"https://{os.environ.get('OAUTH_COGNITO_DOMAIN')}/oauth2/token" + + def __init__(self): + self.client_id = os.environ.get("OAUTH_COGNITO_CLIENT_ID") + self.client_secret = os.environ.get("OAUTH_COGNITO_CLIENT_SECRET") + self.authorize_params = { + "response_type": "code", + "client_id": self.client_id, + "scope": "openid profile email", + } + + async def get_token(self, code: str, url: str): + payload = { + "client_id": self.client_id, + "client_secret": self.client_secret, + "code": code, + "grant_type": "authorization_code", + "redirect_uri": url, + } + async with httpx.AsyncClient() as client: + response = await client.post( + self.token_url, + data=payload, + ) + response.raise_for_status() + json = response.json() + + token = json.get("access_token") + if not token: + raise HTTPException( + status_code=400, detail="Failed to get the access token" + ) + return token + + async def get_user_info(self, token: str): + user_info_url = ( + f"https://{os.environ.get('OAUTH_COGNITO_DOMAIN')}/oauth2/userInfo" + ) + async with httpx.AsyncClient() as client: + response = await client.get( + user_info_url, + headers={"Authorization": f"Bearer {token}"}, + ) + response.raise_for_status() + + cognito_user = response.json() + + # Customize user metadata as needed + user = User( + identifier=cognito_user["email"], + metadata={ + "image": cognito_user.get("picture", ""), + "provider": "aws-cognito", + }, + ) + return (cognito_user, user) + + +class GitlabOAuthProvider(OAuthProvider): + id = "gitlab" + env = [ + "OAUTH_GITLAB_CLIENT_ID", + "OAUTH_GITLAB_CLIENT_SECRET", + "OAUTH_GITLAB_DOMAIN", + ] + + def __init__(self): + self.client_id = os.environ.get("OAUTH_GITLAB_CLIENT_ID") + self.client_secret = os.environ.get("OAUTH_GITLAB_CLIENT_SECRET") + # Ensure that the domain does not have a trailing slash + self.domain = f"https://{os.environ.get('OAUTH_GITLAB_DOMAIN', '').rstrip('/')}" + + self.authorize_url = f"{self.domain}/oauth/authorize" + + self.authorize_params = { + "scope": "openid profile email", + "response_type": "code", + } + + async def get_token(self, code: str, url: str): + payload = { + "client_id": self.client_id, + "client_secret": self.client_secret, + "code": code, + "grant_type": "authorization_code", + "redirect_uri": url, + } + async with httpx.AsyncClient() as client: + response = await client.post( + f"{self.domain}/oauth/token", + data=payload, + ) + response.raise_for_status() + json_content = response.json() + token = json_content.get("access_token") + if not token: + raise HTTPException( + status_code=400, detail="Failed to get the access token" + ) + return token + + async def get_user_info(self, token: str): + async with httpx.AsyncClient() as client: + response = await client.get( + f"{self.domain}/oauth/userinfo", + headers={"Authorization": f"Bearer {token}"}, + ) + response.raise_for_status() + gitlab_user = response.json() + user = User( + identifier=gitlab_user.get("email"), + metadata={ + "image": gitlab_user.get("picture", ""), + "provider": "gitlab", + }, + ) + return (gitlab_user, user) + + +providers = [ + GithubOAuthProvider(), + GoogleOAuthProvider(), + AzureADOAuthProvider(), + AzureADHybridOAuthProvider(), + OktaOAuthProvider(), + Auth0OAuthProvider(), + DescopeOAuthProvider(), + AWSCognitoOAuthProvider(), + GitlabOAuthProvider(), ] + + +def get_oauth_provider(provider: str) -> Optional[OAuthProvider]: + for p in providers: + if p.id == provider: + return p + return None + + +def get_configured_oauth_providers(): + return [p.id for p in providers if p.is_configured()] diff --git a/backend/chainlit/server.py b/backend/chainlit/server.py index f88ac68057..597830ee43 100644 --- a/backend/chainlit/server.py +++ b/backend/chainlit/server.py @@ -6,16 +6,6 @@ import re import shutil import urllib.parse -from typing import Any, Optional, Union - -from chainlit.oauth.providers import get_oauth_provider -from chainlit.secret import random_secret - -mimetypes.add_type("application/javascript", ".js") -mimetypes.add_type("text/css", ".css") - -import asyncio -import os import webbrowser from contextlib import asynccontextmanager from pathlib import Path From 8338774ca22806e90d54e829715416972eee3575 Mon Sep 17 00:00:00 2001 From: lwieczorek-dss <143825023+lwieczorek-dss@users.noreply.github.com> Date: Tue, 29 Oct 2024 17:11:07 +0100 Subject: [PATCH 9/9] Custom OAuth provider implementation covered with UTs --- backend/tests/test_callbacks.py | 59 +++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/backend/tests/test_callbacks.py b/backend/tests/test_callbacks.py index 3c41000137..137899b478 100644 --- a/backend/tests/test_callbacks.py +++ b/backend/tests/test_callbacks.py @@ -107,6 +107,65 @@ async def auth_func( assert result is None +async def test_custom_authenticate_user(test_config): + from unittest.mock import patch + + from chainlit.callbacks import custom_authenticate_user + from chainlit.user import User + + # Mock the get_configured_oauth_providers function + with patch( + "chainlit.callbacks.get_configured_oauth_providers", + return_value=["custom_provider"], + ): + + @custom_authenticate_user + async def auth_func( + provider_id: str, + token: str, + raw_user_data: dict, + default_app_user: User, + id_token: str | None = None, + ) -> User | None: + if ( + provider_id == "custom_provider" and token == "valid_token" + ): # nosec B105 + return User(identifier="oauth_user") + return None + + # Test that the callback is properly registered as custom one + assert test_config.code.custom_authenticate_user is not None + + # Test the wrapped function with valid data + result = await test_config.code.custom_authenticate_user( + "custom_provider", "valid_token", {}, User(identifier="default_user") + ) + assert isinstance(result, User) + assert result.identifier == "oauth_user" + + # Test with invalid data + result = await test_config.code.custom_authenticate_user( + "google", "invalid_token", {}, User(identifier="default_user") + ) + assert result is None + + +async def test_custom_oauth_provider(test_config): + from unittest.mock import Mock, patch + + from chainlit.callbacks import custom_oauth_provider + + custom_provider = Mock() + + with patch("chainlit.callbacks.providers") as providers: + + # Add custom provider to providers + custom_oauth_provider(custom_provider) + + # Custom provider should be added to providers + providers.append.assert_called_once_with(custom_provider()) + + async def test_on_message(mock_chainlit_context, test_config): from chainlit.callbacks import on_message from chainlit.message import Message