Skip to content

Commit

Permalink
Add option for custom auth
Browse files Browse the repository at this point in the history
  • Loading branch information
patrykkotlowski-dsstream committed Aug 30, 2024
1 parent 86798bc commit f602bb7
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 11 deletions.
2 changes: 2 additions & 0 deletions backend/chainlit/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,4 +88,6 @@ async def get_current_user(token: str = Depends(reuseable_oauth)):
if not require_login():
return None

if config.code.custom_authenticate_user:
return await config.code.custom_authenticate_user(token)
return await authenticate_user(token)
4 changes: 4 additions & 0 deletions backend/chainlit/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import tomli
from chainlit.logger import logger
from chainlit.oauth_providers import OAuthProvider
from chainlit.translations import lint_translation_json
from chainlit.version import __version__
from dataclasses_json import DataClassJsonMixin
Expand Down Expand Up @@ -275,6 +276,9 @@ class CodeSettings:
oauth_callback: Optional[
Callable[[str, str, Dict[str, str], "User"], Optional["User"]]
] = None
# Callbacks for authenticate mechanism
custom_authenticate_user: Optional[Callable[[str], "User"]]
custom_oauth_provider: Optional[OAuthProvider]
on_logout: Optional[Callable[["Request", "Response"], Any]] = None
on_stop: Optional[Callable[[], Any]] = None
on_chat_start: Optional[Callable[[], Any]] = None
Expand Down
29 changes: 18 additions & 11 deletions backend/chainlit/oauth_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from chainlit.user import User
from fastapi import HTTPException

from chainlit import config


class OAuthProvider:
id: str
Expand Down Expand Up @@ -621,17 +623,22 @@ async def get_user_info(self, token: str):
return (gitlab_user, user)


providers = [
GithubOAuthProvider(),
GoogleOAuthProvider(),
AzureADOAuthProvider(),
AzureADHybridOAuthProvider(),
OktaOAuthProvider(),
Auth0OAuthProvider(),
DescopeOAuthProvider(),
AWSCognitoOAuthProvider(),
GitlabOAuthProvider(),
]
providers = (
[
GithubOAuthProvider(),
GoogleOAuthProvider(),
AzureADOAuthProvider(),
AzureADHybridOAuthProvider(),
OktaOAuthProvider(),
Auth0OAuthProvider(),
DescopeOAuthProvider(),
AWSCognitoOAuthProvider(),
GitlabOAuthProvider(),
]
+ [config.code.custom_oauth_provider()]
if config.code.custom_oauth_provider
else []
)


def get_oauth_provider(provider: str) -> Optional[OAuthProvider]:
Expand Down

0 comments on commit f602bb7

Please sign in to comment.