Skip to content

Commit

Permalink
Switch back to old structure
Browse files Browse the repository at this point in the history
  • Loading branch information
patrykkotlowski-dsstream committed Sep 20, 2024
1 parent 1c760fe commit fa98dd9
Show file tree
Hide file tree
Showing 21 changed files with 699 additions and 971 deletions.
199 changes: 1 addition & 198 deletions backend/chainlit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
logger.info("Loaded .env file")

import asyncio
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict

import chainlit.input_widget as input_widget
from chainlit.action import Action
Expand Down Expand Up @@ -41,7 +41,6 @@
ErrorMessage,
Message,
)
from chainlit.oauth.providers import get_configured_oauth_providers
from chainlit.step import Step, step
from chainlit.sync import make_async, run_sync
from chainlit.types import AudioChunk, ChatProfile, Starter
Expand Down Expand Up @@ -76,202 +75,6 @@
from chainlit.langchain.callbacks import (
AsyncLangchainCallbackHandler,
LangchainCallbackHandler,
config.code.oauth_callback = wrap_user_function(func)
return func


@trace
def custom_authenticate_user(func: Callable[[str], Awaitable[User]]) -> Callable:
"""
A decorator to authenticate the user via custom token validation.
Args:
func (Callable[[str, str, Dict[str, str], User], Optional[User]]): The authentication callback to execute.
Returns:
Callable[[str, str, Dict[str, str], User], Optional[User]]: The decorated authentication callback.
"""

if len(get_configured_oauth_providers()) == 0:
raise ValueError(
"You must set the environment variable for at least one oauth provider to use oauth authentication."
)

config.code.custom_authenticate_user = wrap_user_function(func)
return func


@trace
def custom_oauth_provider(func: Callable[[str], Awaitable[User]]) -> Callable:
"""
A decorator to integrate custom OAuth provider logic for user authentication.
Args:
func (Callable[[str, str, Dict[str, str], User], Optional[User]]): A function that returns an instance of the OAuthProvider class, encapsulating the logic and details for the custom OAuth provider.
Returns:
Callable[[str, str, Dict[str, str], User], Optional[User]]: The decorated callback function that handles authentication via the custom OAuth provider.
"""

if len(get_configured_oauth_providers()) == 0:
raise ValueError(
"You must set the environment variable for at least one oauth provider to use oauth authentication."
)

config.code.custom_oauth_provider = wrap_user_function(func)
return func


@trace
def on_logout(func: Callable[[Request, Response], Any]) -> Callable:
"""
Function called when the user logs out.
Takes the FastAPI request and response as parameters.
"""

config.code.on_logout = wrap_user_function(func)
return func


@trace
def on_message(func: Callable) -> Callable:
"""
Framework agnostic decorator to react to messages coming from the UI.
The decorated function is called every time a new message is received.
Args:
func (Callable[[Message], Any]): The function to be called when a new message is received. Takes a cl.Message.
Returns:
Callable[[str], Any]: The decorated on_message function.
"""

async def with_parent_id(message: Message):
async with Step(name="on_message", type="run", parent_id=message.id) as s:
s.input = message.content
if len(inspect.signature(func).parameters) > 0:
await func(message)
else:
await func()

config.code.on_message = wrap_user_function(with_parent_id)
return func


@trace
def on_chat_start(func: Callable) -> Callable:
"""
Hook to react to the user websocket connection event.
Args:
func (Callable[], Any]): The connection hook to execute.
Returns:
Callable[], Any]: The decorated hook.
"""

config.code.on_chat_start = wrap_user_function(
step(func, name="on_chat_start", type="run"), with_task=True
)
return func


@trace
def on_chat_resume(func: Callable[[ThreadDict], Any]) -> Callable:
"""
Hook to react to resume websocket connection event.
Args:
func (Callable[], Any]): The connection hook to execute.
Returns:
Callable[], Any]: The decorated hook.
"""

config.code.on_chat_resume = wrap_user_function(func, with_task=True)
return func


@trace
def set_chat_profiles(
func: Callable[[Optional["User"]], List["ChatProfile"]]
) -> Callable:
"""
Programmatic declaration of the available chat profiles (can depend on the User from the session if authentication is setup).
Args:
func (Callable[[Optional["User"]], List["ChatProfile"]]): The function declaring the chat profiles.
Returns:
Callable[[Optional["User"]], List["ChatProfile"]]: The decorated function.
"""

config.code.set_chat_profiles = wrap_user_function(func)
return func


@trace
def set_starters(func: Callable[[Optional["User"]], List["Starter"]]) -> Callable:
"""
Programmatic declaration of the available starter (can depend on the User from the session if authentication is setup).
Args:
func (Callable[[Optional["User"]], List["Starter"]]): The function declaring the starters.
Returns:
Callable[[Optional["User"]], List["Starter"]]: The decorated function.
"""

config.code.set_starters = wrap_user_function(func)
return func


@trace
def on_chat_end(func: Callable) -> Callable:
"""
Hook to react to the user websocket disconnect event.
Args:
func (Callable[], Any]): The disconnect hook to execute.
Returns:
Callable[], Any]: The decorated hook.
"""

config.code.on_chat_end = wrap_user_function(func, with_task=True)
return func


@trace
def on_audio_chunk(func: Callable) -> Callable:
"""
Hook to react to the audio chunks being sent.
Args:
chunk (AudioChunk): The audio chunk being sent.
Returns:
Callable[], Any]: The decorated hook.
"""

config.code.on_audio_chunk = wrap_user_function(func, with_task=False)
return func


@trace
def on_audio_end(func: Callable) -> Callable:
"""
Hook to react to the audio stream ending. This is called after the last audio chunk is sent.
Args:
elements ([List[Element]): The files that were uploaded before starting the audio stream (if any).
Returns:
Callable[], Any]: The decorated hook.
"""

config.code.on_audio_end = wrap_user_function(
step(func, name="on_audio_end", type="run"), with_task=True
)
from chainlit.llama_index.callbacks import LlamaIndexCallbackHandler
from chainlit.mistralai import instrument_mistralai
Expand Down
2 changes: 1 addition & 1 deletion backend/chainlit/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import jwt
from chainlit.config import config
from chainlit.data import get_data_layer
from chainlit.oauth.providers import get_configured_oauth_providers
from chainlit.oauth_providers import get_configured_oauth_providers
from chainlit.user import User
from fastapi import Depends, HTTPException
from fastapi.security import OAuth2PasswordBearer
Expand Down
38 changes: 37 additions & 1 deletion backend/chainlit/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
from chainlit.action import Action
from chainlit.config import config
from chainlit.message import Message
from chainlit.oauth_providers import get_configured_oauth_providers
from chainlit.oauth_providers import (
OAuthProvider,
get_configured_oauth_providers,
providers,
)
from chainlit.step import Step, step
from chainlit.telemetry import trace
from chainlit.types import ChatProfile, Starter, ThreadDict
Expand Down Expand Up @@ -87,6 +91,38 @@ async def oauth_callback(provider_id: str, token: str, raw_user_data: Dict[str,
return func


@trace
def custom_authenticate_user(func: Callable[[str], Awaitable[User]]) -> Callable:
"""
A decorator to authenticate the user via custom token validation.
Args:
func (Callable[[str], Awaitable[User]]): The authentication callback to execute.
Returns:
Callable[[str], Awaitable[User]]: The decorated authentication callback.
"""

if len(get_configured_oauth_providers()) == 0:
raise ValueError(
"You must set the environment variable for at least one oauth provider to use oauth authentication."
)

config.code.custom_authenticate_user = wrap_user_function(func)
return func


def custom_oauth_provider(func: Callable[[], OAuthProvider]) -> None:
"""
A decorator to integrate custom OAuth provider logic for user authentication.
Args:
func (Callable[[], OAuthProvider): A function that returns an instance of the OAuthProvider class, encapsulating the logic and details for the custom OAuth provider.
"""

providers.append(func())


@trace
def on_logout(func: Callable[[Request, Response], Any]) -> Callable:
"""
Expand Down
4 changes: 2 additions & 2 deletions backend/chainlit/chat_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ def add(self, message: "Message"):

if context.session.id not in chat_contexts:
chat_contexts[context.session.id] = []

if message not in chat_contexts[context.session.id]:
chat_contexts[context.session.id].append(message)

return message

def remove(self, message: "Message") -> bool:
Expand Down
2 changes: 1 addition & 1 deletion backend/chainlit/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import tomli
from chainlit.logger import logger
from chainlit.oauth.oauth_provider import OAuthProvider
from chainlit.oauth_providers import OAuthProvider
from chainlit.translations import lint_translation_json
from chainlit.version import __version__
from dataclasses_json import DataClassJsonMixin
Expand Down
2 changes: 1 addition & 1 deletion backend/chainlit/data/acl.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ async def is_thread_author(username: str, thread_id: str):
raise HTTPException(status_code=400, detail="Data layer not initialized")

thread_author = await data_layer.get_thread_author(thread_id)

if not thread_author:
raise HTTPException(status_code=404, detail="Thread not found")

Expand Down
17 changes: 12 additions & 5 deletions backend/chainlit/data/sql_alchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,12 +373,18 @@ async def delete_feedback(self, feedback_id: str) -> bool:
return True

###### Elements ######
async def get_element(self, thread_id: str, element_id: str) -> Optional["ElementDict"]:
async def get_element(
self, thread_id: str, element_id: str
) -> Optional["ElementDict"]:
if self.show_logger:
logger.info(f"SQLAlchemy: get_element, thread_id={thread_id}, element_id={element_id}")
logger.info(
f"SQLAlchemy: get_element, thread_id={thread_id}, element_id={element_id}"
)
query = """SELECT * FROM elements WHERE "threadId" = :thread_id AND "id" = :element_id"""
parameters = {"thread_id": thread_id, "element_id": element_id}
element: Union[List[Dict[str, Any]], int, None] = await self.execute_sql(query=query, parameters=parameters)
element: Union[List[Dict[str, Any]], int, None] = await self.execute_sql(
query=query, parameters=parameters
)
if isinstance(element, list) and element:
element_dict: Dict[str, Any] = element[0]
return ElementDict(
Expand All @@ -396,7 +402,7 @@ async def get_element(self, thread_id: str, element_id: str) -> Optional["Elemen
autoPlay=element_dict.get("autoPlay"),
playerConfig=element_dict.get("playerConfig"),
forId=element_dict.get("forId"),
mime=element_dict.get("mime")
mime=element_dict.get("mime"),
)
else:
return None
Expand Down Expand Up @@ -607,7 +613,8 @@ async def get_all_user_threads(
tags=step_feedback.get("step_tags"),
input=(
step_feedback.get("step_input", "")
if step_feedback.get("step_showinput") not in [None, "false"]
if step_feedback.get("step_showinput")
not in [None, "false"]
else None
),
output=step_feedback.get("step_output", ""),
Expand Down
Empty file removed backend/chainlit/oauth/__init__.py
Empty file.
Loading

0 comments on commit fa98dd9

Please sign in to comment.