diff --git a/.env.sample b/.env.sample
index 0f4d2bc..81e3122 100644
--- a/.env.sample
+++ b/.env.sample
@@ -10,7 +10,7 @@ BREVIA_ENV_SECRETS='{
 # QA completion LLM
 QA_COMPLETION_LLM='{
     "_type": "openai-chat",
-    "model_name": "gpt-3.5-turbo-16k",
+    "model_name": "gpt-4o-mini",
     "temperature": 0,
     "max_tokens": 200,
     "verbose": true
@@ -19,7 +19,7 @@ QA_COMPLETION_LLM='{
 # QA followup LLM
 QA_FOLLOWUP_LLM='{
     "_type": "openai-chat",
-    "model_name": "gpt-3.5-turbo-16k",
+    "model_name": "gpt-4o-mini",
     "temperature": 0,
     "max_tokens": 200,
     "verbose": true
@@ -28,6 +28,8 @@ QA_FOLLOWUP_LLM='{
 QA_FOLLOWUP_SIM_THRESHOLD=0.735
 # Chat history - uncomment to disable session conversation, avoiding chat history loading
 # QA_NO_CHAT_HISTORY=True
+# Uncomment to use a custom QA retriever with custom arguments
+# QA_RETRIEVER='{"retriever": "my_project.CustomRetriever", "some_var": "some_value"}'
 
 # Access tokens secret - if missing no access token validity is checked
 # Generate it with: openssl rand -hex 32
@@ -49,12 +51,14 @@ EMBEDDINGS='{"_type": "openai-embeddings"}'
 ## Index creation
 TEXT_CHUNK_SIZE=2000
 TEXT_CHUNK_OVERLAP=100
+# Uncomment to use a custom splitter with custom arguments
+# TEXT_SPLITTER='{"splitter": "my_project.CustomSplitter", "some_var": "some_value"}'
 
 ## Summarize
 # Summarization LLM
 SUMMARIZE_LLM='{
     "_type": "openai-chat",
-    "model_name": "gpt-3.5-turbo-16k",
+    "model_name": "gpt-4o-mini",
     "temperature": 0,
     "max_tokens": 2000
 }'
diff --git a/brevia/index.py b/brevia/index.py
index 1d72b7f..9949af7 100644
--- a/brevia/index.py
+++ b/brevia/index.py
@@ -71,9 +71,10 @@ def add_document(
     """ Add document to index and return number of splitted text chunks"""
     collection = single_collection_by_name(collection_name)
     embed_conf = collection.cmetadata.get('embeddings', None) if collection else None
-    split_conf = collection.cmetadata.get('text_splitter', None) if collection else None
-
-    texts = split_document(document, split_conf)
+    texts = split_document(
+        document=document,
+        collection_meta=collection.cmetadata if collection else {},
+    )
     PGVector.from_documents(
         embedding=load_embeddings(embed_conf),
         documents=texts,
@@ -86,13 +87,11 @@
     return len(texts)
 
 
-def split_document(document: Document, split_conf: dict | None = None):
+def split_document(
+    document: Document, collection_meta: dict = {}
+) -> list[Document]:
     """ Split document into text chunks and return a list of documents"""
-    if not split_conf:
-        text_splitter = create_default_splitter()
-    else:
-        text_splitter = create_custom_splitter(split_conf)
-
+    text_splitter = create_splitter(collection_meta)
     texts = text_splitter.split_documents([document])
     counter = 1
     for text in texts:
@@ -101,15 +100,26 @@
     return texts
 
 
-def create_default_splitter() -> TextSplitter:
-    """ Create default text splitter"""
+def create_splitter(collection_meta: dict) -> TextSplitter:
+    """ Create text splitter"""
     settings = get_settings()
-
-    return NLTKTextSplitter(
-        separator="\n",
-        chunk_size=settings.text_chunk_size,
-        chunk_overlap=settings.text_chunk_overlap
+    custom_splitter = collection_meta.get(
+        'text_splitter',
+        settings.text_splitter.copy()
     )
+    chunk_size = collection_meta.get('chunk_size', settings.text_chunk_size)
+    chunk_overlap = collection_meta.get('chunk_overlap', settings.text_chunk_overlap)
+
+    if not custom_splitter:
+        return NLTKTextSplitter(
+            separator="\n",
+            chunk_size=chunk_size,
+            chunk_overlap=chunk_overlap
+        )
+
+    chunk_conf = {'chunk_size': chunk_size, 'chunk_overlap': chunk_overlap}
+
+    return create_custom_splitter({**chunk_conf, **custom_splitter})
 
 
 def create_custom_splitter(split_conf: dict) -> TextSplitter:
diff --git a/brevia/query.py b/brevia/query.py
index f24d8c7..b841cb4 100644
--- a/brevia/query.py
+++ b/brevia/query.py
@@ -271,7 +271,10 @@ def conversation_chain(
     # main chain, do all the jobs
     search_kwargs = {'k': chat_params.docs_num, 'filter': chat_params.filter}
-    retriever_conf = collection.cmetadata.get('qa_retriever')
+    retriever_conf = collection.cmetadata.get(
+        'qa_retriever',
+        settings.qa_retriever.copy()
+    )
     if not retriever_conf:
         retriever = create_default_retriever(
             store=docsearch,
diff --git a/brevia/settings.py b/brevia/settings.py
index dde9b6d..d18ed62 100644
--- a/brevia/settings.py
+++ b/brevia/settings.py
@@ -40,9 +40,10 @@ class Settings(BaseSettings):
     # Test models - only in unit tests
     use_test_models: bool = False
 
-    # Index - textsplitter settings
+    # Index - text splitter settings
     text_chunk_size: int = 2000
    text_chunk_overlap: int = 200
+    text_splitter: Json[dict[str, Any]] = '{}'  # custom splitter settings
 
     # Search
     search_docs_num: int = 4
@@ -50,21 +51,21 @@
     # LLM settings
     qa_completion_llm: Json[dict[str, Any]] = """{
         "_type": "openai-chat",
-        "model_name": "gpt-3.5-turbo-16k",
+        "model_name": "gpt-4o-mini",
         "temperature": 0,
         "max_tokens": 1000,
         "verbose": true
     }"""
     qa_followup_llm: Json[dict[str, Any]] = """{
         "_type": "openai-chat",
-        "model_name": "gpt-3.5-turbo-16k",
+        "model_name": "gpt-4o-mini",
         "temperature": 0,
         "max_tokens": 200,
         "verbose": true
     }"""
     summarize_llm: Json[dict[str, Any]] = """{
         "_type": "openai-chat",
-        "model_name": "gpt-3.5-turbo-16k",
+        "model_name": "gpt-4o",
         "temperature": 0,
         "max_tokens": 2000
     }"""
@@ -83,6 +84,7 @@
     # QA
     qa_no_chat_history: bool = False  # don't load chat history
     qa_followup_sim_threshold: float = 0.735  # similitude threshold in followup
+    qa_retriever: Json[dict[str, Any]] = '{}'  # custom retriever settings
     feature_qa_lang_detect: bool = False
 
     # Summarization
diff --git a/docs/config.md b/docs/config.md
index adffe99..9dce455 100644
--- a/docs/config.md
+++ b/docs/config.md
@@ -100,21 +100,36 @@
 TEXT_CHUNK_SIZE=2000
 TEXT_CHUNK_OVERLAP=100
 ```
 
+`TEXT_SPLITTER`
+This variable is an optional JSON string configuration, used to override the default text splitter.
+
+Can be something like:
+
+`'{"splitter": "my_project.CustomSplitter", "some_var": "some_value"}'`
+
+Where:
+
+* the `splitter` key must be present and point to the module path of a valid
+  splitter class extending langchain `TextSplitter`
+* other splitter constructor attributes can be specified in the configuration,
+  like `some_var` in the above example
+
 ## Q&A and Chat
 
 Under the hood of Q&A and Chat actions (see [Chat and Search](chat_search.md) section) you can configure models and behaviors via these variables:
 
-* `QA_COMPLETION_LLM`: configuration for the main conversational model, used by `/chat` and `/completion` endpoints; a JSON string is used to configure the corresponding LangChain chat model class; an OpenAI instance is used as default: `'{"_type": "openai-chat", "model_name": "gpt-3.5-turbo-16k", "temperature": 0, "max_tokens": 1000, "verbose": true}'` where for instance `model_name` and other attributes can be adjusted to meet your needs
-* `QA_FOLLOWUP_LLM`: configuration for the follow-up question model, used by `/chat` endpoint defining a follow up question for a conversation usgin chat history; a JSON string; an OpenAI instance used as default `'{"_type": "openai-chat", "model_name": "gpt-3.5-turbo-16k", "temperature": 0, "max_tokens": 200, "verbose": true}'`
+* `QA_COMPLETION_LLM`: configuration for the main conversational model, used by `/chat` and `/completion` endpoints; a JSON string is used to configure the corresponding LangChain chat model class; an OpenAI instance is used as default: `'{"_type": "openai-chat", "model_name": "gpt-4o-mini", "temperature": 0, "max_tokens": 1000, "verbose": true}'` where for instance `model_name` and other attributes can be adjusted to meet your needs
+* `QA_FOLLOWUP_LLM`: configuration for the follow-up question model, used by the `/chat` endpoint to define a follow-up question for a conversation using chat history; a JSON string; an OpenAI instance is used as default: `'{"_type": "openai-chat", "model_name": "gpt-4o-mini", "temperature": 0, "max_tokens": 200, "verbose": true}'`
 * `QA_FOLLOWUP_SIM_THRESHOLD`: a numeric value between 0 and 1 indicating similarity threshold between questions to determine if chat history should be used, defaults to `0.735`
 * `QA_NO_CHAT_HISTORY`: disables chat history entirely if set to `True` or any other value
 * `SEARCH_DOCS_NUM`: default number of documents used to search for answers, defaults to `4`
+* `QA_RETRIEVER`: optional configuration for a custom retriever class, used by the `/chat` endpoint; a JSON string defining a custom class and optional attributes; an example configuration can be `'{"retriever": "my_project.CustomRetriever", "some_var": "some_value"}'` where the `retriever` key must be present and point to the module path of a valid retriever class extending langchain `BaseRetriever`, while other constructor attributes can be specified in the configuration, like `some_var` in the above example
 
 ## Summarization
 
 To configure summarize related actions in `/summarize` or `/upload_summarize` endpoints the related environment variables are:
 
-* `SUMMARIZE_LLM`: the LLM to be used, a JSON string using the same format of `QA_COMPLETION_LLM` in the above paragraph; defatults to an OpenAI instance `'{"_type": "openai-chat", "model_name": "gpt-3.5-turbo-16k", "temperature": 0, "max_tokens": 2000}'`
+* `SUMMARIZE_LLM`: the LLM to be used, a JSON string using the same format as `QA_COMPLETION_LLM` in the above paragraph; defaults to an OpenAI instance `'{"_type": "openai-chat", "model_name": "gpt-4o", "temperature": 0, "max_tokens": 2000}'`
 * `SUMM_TOKEN_SPLITTER`: the maximum size of individual text chunks processed during summarization, defaults to `4000` - see `TEXT_CHUNK_SIZE` in [Text Segmentation](#text-segmentation) paragraph
 * `SUMM_TOKEN_OVERLAP`: the amount of overlap between consecutive text chunks, defaults to `500` - see `TEXT_CHUNK_OVERLAP` in [Text Segmentation](#text-segmentation) paragraph
 * `SUMM_DEFAULT_CHAIN`: chain type to be used if not specified, defaults to `stuff`
diff --git a/docs/tutorials/create_collection.md b/docs/tutorials/create_collection.md
index d23c612..c6e2530 100644
--- a/docs/tutorials/create_collection.md
+++ b/docs/tutorials/create_collection.md
@@ -40,7 +40,7 @@ QA_FOLLOWUP_LLM='{
 }'
 ```
 
-Replace `your_llm_type` and `your_llm_model` with your chosen LLM provider and specific model (e.g., "openai-chat", "gpt-3.5-turbo-16k").
+Replace `your_llm_type` and `your_llm_model` with your chosen LLM provider and specific model (e.g., "openai-chat", "gpt-4o-mini").
 Adjust `temperature` and `max_tokens` parameters as needed.
 
 ## Database
diff --git a/docs/tutorials/create_summarization.md b/docs/tutorials/create_summarization.md
index 375fc04..ca156f8 100644
--- a/docs/tutorials/create_summarization.md
+++ b/docs/tutorials/create_summarization.md
@@ -12,7 +12,7 @@ Define the summarization model and parameters using the following JSON snippet a
 
 ```JSON
 SUMMARIZE_LLM='{
     "_type": "openai-chat",
-    "model_name": "gpt-3.5-turbo-16k",
+    "model_name": "gpt-4o-mini",
     "temperature": 0,
     "max_tokens": 2000
 }'
diff --git a/tests/test_index.py b/tests/test_index.py
index 4f3b1f3..04c3ea2 100644
--- a/tests/test_index.py
+++ b/tests/test_index.py
@@ -2,15 +2,18 @@
 from pathlib import Path
 from unittest.mock import patch
 from h11 import Response
+from langchain_text_splitters import NLTKTextSplitter
+from langchain_text_splitters.character import RecursiveCharacterTextSplitter
 import pytest
 from requests import HTTPError
 from langchain.docstore.document import Document
 from brevia.index import (
     load_pdf_file, split_document, update_links_documents, add_document,
     document_has_changed, select_load_link_options,
-    documents_metadata,
+    documents_metadata, create_splitter,
 )
 from brevia.collections import create_collection
+from brevia.settings import get_settings
 
 
 def test_load_pdf_file():
@@ -109,9 +112,53 @@ def test_select_load_link_options():
 
 
 def test_custom_split():
-    """Test split_documents method with cuseom splitter class"""
+    """Test split_documents method with custom splitter class"""
     doc1 = Document(page_content='some content? no', metadata={'type': 'questions'})
     cls = 'langchain_text_splitters.character.RecursiveCharacterTextSplitter'
     texts = split_document(doc1, {'splitter': cls})
 
     assert len(texts) == 1
+
+
+def test_add_document_custom_split():
+    """Test add_document method with custom splitter in settings"""
+    settings = get_settings()
+    current_splitter = settings.text_splitter
+    settings.text_splitter = {
+        'splitter': 'langchain_text_splitters.character.RecursiveCharacterTextSplitter'
+    }
+    doc1 = Document(page_content='some content? no', metadata={'type': 'questions'})
+    num = add_document(document=doc1, collection_name='test')
+    assert num == 1
+
+    settings.text_splitter = current_splitter
+
+
+def test_create_splitter_chunk_params():
+    """Test create_splitter method"""
+    splitter = create_splitter({'chunk_size': 2222, 'chunk_overlap': 333})
+
+    assert isinstance(splitter, NLTKTextSplitter)
+    assert splitter._chunk_size == 2222
+    assert splitter._chunk_overlap == 333
+
+    splitter = create_splitter({})
+
+    assert isinstance(splitter, NLTKTextSplitter)
+    assert splitter._chunk_size == get_settings().text_chunk_size
+    assert splitter._chunk_overlap == get_settings().text_chunk_overlap
+
+    custom_splitter = {
+        'splitter': 'langchain_text_splitters.character.RecursiveCharacterTextSplitter',
+        'chunk_size': 1111,
+        'chunk_overlap': 555,
+    }
+    splitter = create_splitter({
+        'chunk_size': 3333,
+        'chunk_overlap': 444,
+        'text_splitter': custom_splitter,
+    })
+
+    assert isinstance(splitter, RecursiveCharacterTextSplitter)
+    assert splitter._chunk_size == 1111
+    assert splitter._chunk_overlap == 555
diff --git a/tests/test_query.py b/tests/test_query.py
index ec284d0..4b57fd6 100644
--- a/tests/test_query.py
+++ b/tests/test_query.py
@@ -4,7 +4,7 @@
 from langchain.chains.conversational_retrieval.base import ConversationalRetrievalChain
 from langchain.docstore.document import Document
 from langchain_core.vectorstores import VectorStoreRetriever
-from langchain.vectorstores.pgvector import PGVector
+from langchain.retrievers.multi_query import MultiQueryRetriever
 from brevia.query import (
     conversation_chain, load_qa_prompt,
@@ -12,12 +12,10 @@
     search_vector_qa, ChatParams, SearchQuery,
-    create_custom_retriever,
 )
 from brevia.collections import create_collection
-from brevia.connection import connection_string
 from brevia.index import add_document
-from brevia.models import load_embeddings
+from brevia.settings import get_settings
 
 FAKE_PROMPT = {
     '_type': 'prompt',
@@ -128,15 +126,34 @@ def test_conversation_chain():
     assert isinstance(chain, ConversationalRetrievalChain)
 
 
-def test_create_custom_retriever():
-    """Test create_custom_retriever function"""
-    conf = {'retriever': 'langchain_core.vectorstores.VectorStoreRetriever'}
-    store = PGVector(
-        connection_string=connection_string(),
-        embedding_function=load_embeddings(),
-        use_jsonb=True,
+def test_conversation_chain_multiquery():
+    """Test conversation_chain function with multiquery"""
+    collection = create_collection('test', {})
+    chain = conversation_chain(
+        collection=collection,
+        chat_params=ChatParams(multiquery=True)
     )
-    retriever = create_custom_retriever(store, {}, conf)
-
-    assert retriever is not None
-    assert isinstance(retriever, VectorStoreRetriever)
+
+    assert chain is not None
+    assert isinstance(chain, ConversationalRetrievalChain)
+    assert isinstance(chain.retriever, MultiQueryRetriever)
+
+
+def test_conversation_chain_custom_retriever():
+    """Test conversation_chain with custom retriever"""
+    collection = create_collection('test', {})
+    settings = get_settings()
+    current_retriever = settings.qa_retriever
+    conf = {'retriever': 'langchain_core.vectorstores.VectorStoreRetriever'}
+    settings.qa_retriever = conf
+    chain = conversation_chain(collection=collection, chat_params=ChatParams())
+
+    assert chain is not None
+    assert isinstance(chain.retriever, VectorStoreRetriever)
+    settings.qa_retriever = current_retriever
+
+    collection = create_collection('test', {'qa_retriever': conf})
+    chain = conversation_chain(collection=collection, chat_params=ChatParams())
+
+    assert chain is not None
+    assert isinstance(chain.retriever, VectorStoreRetriever)
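
For reference, here is a minimal sketch of a class that the new `TEXT_SPLITTER` setting (or a collection's `text_splitter` metadata) could point to. The module path, class name and `some_var` argument are hypothetical, mirroring the placeholders used in the docs above; the only stated contract is extending langchain's `TextSplitter`, whose base constructor already consumes the `chunk_size` and `chunk_overlap` values that `create_splitter` merges into the configuration:

```python
# my_project/splitters.py - hypothetical module, referenced as e.g.
# TEXT_SPLITTER='{"splitter": "my_project.splitters.QuestionSplitter", "some_var": "?"}'
from langchain_text_splitters import TextSplitter


class QuestionSplitter(TextSplitter):
    """Toy splitter cutting text at a configurable separator."""

    def __init__(self, some_var: str = '?', **kwargs):
        # chunk_size / chunk_overlap injected by create_splitter()
        # are handled by the TextSplitter base constructor
        super().__init__(**kwargs)
        self.some_var = some_var

    def split_text(self, text: str) -> list[str]:
        # Minimal example: split on the separator and keep non-empty parts;
        # a real splitter should also honor self._chunk_size
        # and self._chunk_overlap when building chunks
        parts = [p.strip() for p in text.split(self.some_var) if p.strip()]
        return parts or [text]
```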
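Similarly, a class referenced by `QA_RETRIEVER` (or a collection's `qa_retriever` metadata) must extend langchain's `BaseRetriever`. The sketch below assumes, based on the deleted `create_custom_retriever(store, {}, conf)` test call, that Brevia hands the vector store and search arguments to the retriever; the module path, `some_var` and the exact wiring are hypothetical:

```python
# my_project/retrievers.py - hypothetical module, referenced as e.g.
# QA_RETRIEVER='{"retriever": "my_project.retrievers.CustomRetriever", "some_var": "x"}'
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores import VectorStore


class CustomRetriever(BaseRetriever):
    """Toy retriever delegating to a vector store similarity search."""

    # BaseRetriever is a pydantic model: fields are declared as attributes
    vectorstore: VectorStore
    search_kwargs: dict = {}
    some_var: str = ''

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> list[Document]:
        # Delegate to the underlying store; a real implementation could
        # re-rank, filter or expand results using self.some_var
        return self.vectorstore.similarity_search(query, **self.search_kwargs)
```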