Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option for custom auth #1280

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions backend/chainlit/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,4 +94,6 @@ async def get_current_user(token: str = Depends(reuseable_oauth)):
if not require_login():
return None

if config.code.custom_authenticate_user:
return await config.code.custom_authenticate_user(token)
return await authenticate_user(token)
38 changes: 37 additions & 1 deletion backend/chainlit/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
from chainlit.action import Action
from chainlit.config import config
from chainlit.message import Message
from chainlit.oauth_providers import get_configured_oauth_providers
from chainlit.oauth_providers import (
OAuthProvider,
get_configured_oauth_providers,
providers,
)
from chainlit.step import Step, step
from chainlit.telemetry import trace
from chainlit.types import ChatProfile, Starter, ThreadDict
Expand Down Expand Up @@ -87,6 +91,38 @@ async def oauth_callback(provider_id: str, token: str, raw_user_data: Dict[str,
return func


@trace
def custom_authenticate_user(func: Callable[[str], Awaitable[User]]) -> Callable:
"""
A decorator to authenticate the user via custom token validation.

Args:
func (Callable[[str], Awaitable[User]]): The authentication callback to execute.

Returns:
Callable[[str], Awaitable[User]]: The decorated authentication callback.
"""

if len(get_configured_oauth_providers()) == 0:
raise ValueError(
"You must set the environment variable for at least one oauth provider to use oauth authentication."
)

config.code.custom_authenticate_user = wrap_user_function(func)
return func


def custom_oauth_provider(func: Callable[[], OAuthProvider]) -> None:
"""
A decorator to integrate custom OAuth provider logic for user authentication.

Args:
func (Callable[[], OAuthProvider): A function that returns an instance of the OAuthProvider class, encapsulating the logic and details for the custom OAuth provider.
"""

providers.append(func())


@trace
def on_logout(func: Callable[[Request, Response], Any]) -> Callable:
"""
Expand Down
4 changes: 2 additions & 2 deletions backend/chainlit/chat_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ def add(self, message: "Message"):

if context.session.id not in chat_contexts:
chat_contexts[context.session.id] = []

if message not in chat_contexts[context.session.id]:
chat_contexts[context.session.id].append(message)

return message

def remove(self, message: "Message") -> bool:
Expand Down
5 changes: 5 additions & 0 deletions backend/chainlit/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@
List,
Literal,
Optional,
Type,
Union,
)

import tomli
from chainlit.logger import logger
from chainlit.oauth_providers import OAuthProvider
from chainlit.translations import lint_translation_json
from chainlit.version import __version__
from dataclasses_json import DataClassJsonMixin
Expand Down Expand Up @@ -278,6 +280,9 @@ class CodeSettings:
oauth_callback: Optional[
Callable[[str, str, Dict[str, str], "User"], Awaitable[Optional["User"]]]
] = None
# Callbacks for authenticate mechanism
custom_authenticate_user: Optional[Callable[[str], Awaitable["User"]]] = None
custom_oauth_provider: Optional[Callable[[], Type[OAuthProvider]]] = None
on_logout: Optional[Callable[["Request", "Response"], Any]] = None
on_stop: Optional[Callable[[], Any]] = None
on_chat_start: Optional[Callable[[], Any]] = None
Expand Down
2 changes: 1 addition & 1 deletion backend/chainlit/data/acl.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ async def is_thread_author(username: str, thread_id: str):
raise HTTPException(status_code=400, detail="Data layer not initialized")

thread_author = await data_layer.get_thread_author(thread_id)

if not thread_author:
raise HTTPException(status_code=404, detail="Thread not found")

Expand Down
59 changes: 59 additions & 0 deletions backend/tests/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,65 @@ async def auth_func(
assert result is None


async def test_custom_authenticate_user(test_config):
from unittest.mock import patch

from chainlit.callbacks import custom_authenticate_user
from chainlit.user import User

# Mock the get_configured_oauth_providers function
with patch(
"chainlit.callbacks.get_configured_oauth_providers",
return_value=["custom_provider"],
):

@custom_authenticate_user
async def auth_func(
provider_id: str,
token: str,
raw_user_data: dict,
default_app_user: User,
id_token: str | None = None,
) -> User | None:
if (
provider_id == "custom_provider" and token == "valid_token"
): # nosec B105
return User(identifier="oauth_user")
return None

# Test that the callback is properly registered as custom one
assert test_config.code.custom_authenticate_user is not None

# Test the wrapped function with valid data
result = await test_config.code.custom_authenticate_user(
"custom_provider", "valid_token", {}, User(identifier="default_user")
)
assert isinstance(result, User)
assert result.identifier == "oauth_user"

# Test with invalid data
result = await test_config.code.custom_authenticate_user(
"google", "invalid_token", {}, User(identifier="default_user")
)
assert result is None


async def test_custom_oauth_provider(test_config):
from unittest.mock import Mock, patch

from chainlit.callbacks import custom_oauth_provider

custom_provider = Mock()

with patch("chainlit.callbacks.providers") as providers:

# Add custom provider to providers
custom_oauth_provider(custom_provider)

# Custom provider should be added to providers
providers.append.assert_called_once_with(custom_provider())


async def test_on_message(mock_chainlit_context, test_config):
from chainlit.callbacks import on_message
from chainlit.message import Message
Expand Down