
Add custom default Retriever/TextSplitter, update deprecated OpenAI model #67

Merged · 17 commits · Sep 2, 2024
10 changes: 7 additions & 3 deletions .env.sample
@@ -10,7 +10,7 @@ BREVIA_ENV_SECRETS='{
# QA completion LLM
QA_COMPLETION_LLM='{
  "_type": "openai-chat",
-  "model_name": "gpt-3.5-turbo-16k",
+  "model_name": "gpt-4o-mini",
  "temperature": 0,
  "max_tokens": 200,
  "verbose": true
@@ -19,7 +19,7 @@ QA_COMPLETION_LLM='{
# QA followup LLM
QA_FOLLOWUP_LLM='{
  "_type": "openai-chat",
-  "model_name": "gpt-3.5-turbo-16k",
+  "model_name": "gpt-4o-mini",
  "temperature": 0,
  "max_tokens": 200,
  "verbose": true
@@ -28,6 +28,8 @@ QA_FOLLOWUP_LLM='{
QA_FOLLOWUP_SIM_THRESHOLD=0.735
# Chat history - uncomment to disable session conversation, avoiding chat history loading
# QA_NO_CHAT_HISTORY=True
+# Uncomment to use a custom QA retriever with custom arguments
+# QA_RETRIEVER='{"retriever": "my_project.CustomRetriever", "some_var": "some_value"}'

# Access tokens secret - if missing no access token validity is checked
# Generate it with: openssl rand -hex 32
@@ -49,12 +51,14 @@ EMBEDDINGS='{"_type": "openai-embeddings"}'
## Index creation
TEXT_CHUNK_SIZE=2000
TEXT_CHUNK_OVERLAP=100
+# Uncomment to use a custom splitter with custom arguments
+# TEXT_SPLITTER='{"splitter": "my_project.CustomSplitter", "some_var": "some_value"}'

## Summarize
# Summarization LLM
SUMMARIZE_LLM='{
  "_type": "openai-chat",
-  "model_name": "gpt-3.5-turbo-16k",
+  "model_name": "gpt-4o-mini",
  "temperature": 0,
  "max_tokens": 2000
}'
42 changes: 26 additions & 16 deletions brevia/index.py
@@ -71,9 +71,10 @@ def add_document(
    """ Add document to index and return number of split text chunks"""
    collection = single_collection_by_name(collection_name)
    embed_conf = collection.cmetadata.get('embeddings', None) if collection else None
-    split_conf = collection.cmetadata.get('text_splitter', None) if collection else None

-    texts = split_document(document, split_conf)
+    texts = split_document(
+        document=document,
+        collection_meta=collection.cmetadata if collection else {},
+    )
    PGVector.from_documents(
        embedding=load_embeddings(embed_conf),
        documents=texts,
@@ -86,13 +87,11 @@
    return len(texts)


-def split_document(document: Document, split_conf: dict | None = None):
+def split_document(
+    document: Document, collection_meta: dict = {}
+) -> list[Document]:
    """ Split document into text chunks and return a list of documents"""
-    if not split_conf:
-        text_splitter = create_default_splitter()
-    else:
-        text_splitter = create_custom_splitter(split_conf)
-
+    text_splitter = create_splitter(collection_meta)
    texts = text_splitter.split_documents([document])
    counter = 1
    for text in texts:
@@ -101,15 +100,26 @@ def split_document(
    return texts


-def create_default_splitter() -> TextSplitter:
-    """ Create default text splitter"""
+def create_splitter(collection_meta: dict) -> TextSplitter:
+    """ Create text splitter"""
    settings = get_settings()

-    return NLTKTextSplitter(
-        separator="\n",
-        chunk_size=settings.text_chunk_size,
-        chunk_overlap=settings.text_chunk_overlap
+    custom_splitter = collection_meta.get(
+        'text_splitter',
+        settings.text_splitter.copy()
    )
+    chunk_size = collection_meta.get('chunk_size', settings.text_chunk_size)
+    chunk_overlap = collection_meta.get('chunk_overlap', settings.text_chunk_overlap)
+
+    if not custom_splitter:
+        return NLTKTextSplitter(
+            separator="\n",
+            chunk_size=chunk_size,
+            chunk_overlap=chunk_overlap
+        )
+
+    chunk_conf = {'chunk_size': chunk_size, 'chunk_overlap': chunk_overlap}
+
+    return create_custom_splitter({**chunk_conf, **custom_splitter})


def create_custom_splitter(split_conf: dict) -> TextSplitter:
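Taken together, these changes let a single collection override the global chunking settings through its metadata. A minimal sketch of that flow, assuming a running Brevia setup (the collection name and sample document are illustrative; the metadata keys mirror the tests further below):

```python
# Sketch: per-collection splitter override via collection metadata.
# The keys 'chunk_size', 'chunk_overlap' and 'text_splitter' are the ones
# read by the new create_splitter() above.
from langchain.docstore.document import Document

from brevia.collections import create_collection
from brevia.index import add_document

create_collection('my-docs', {
    'chunk_size': 500,
    'chunk_overlap': 50,
    'text_splitter': {
        'splitter': 'langchain_text_splitters.character.RecursiveCharacterTextSplitter',
    },
})
# add_document() picks up the collection metadata and builds the splitter
num_chunks = add_document(
    document=Document(page_content='some text to index'),
    collection_name='my-docs',
)
```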
5 changes: 4 additions & 1 deletion brevia/query.py
@@ -271,7 +271,10 @@ def conversation_chain(

    # main chain, do all the jobs
    search_kwargs = {'k': chat_params.docs_num, 'filter': chat_params.filter}
-    retriever_conf = collection.cmetadata.get('qa_retriever')
+    retriever_conf = collection.cmetadata.get(
+        'qa_retriever',
+        settings.qa_retriever.copy()
+    )
    if not retriever_conf:
        retriever = create_default_retriever(
            store=docsearch,
10 changes: 6 additions & 4 deletions brevia/settings.py
@@ -40,31 +40,32 @@ class Settings(BaseSettings):
    # Test models - only in unit tests
    use_test_models: bool = False

-    # Index - textsplitter settings
+    # Index - text splitter settings
    text_chunk_size: int = 2000
    text_chunk_overlap: int = 200
+    text_splitter: Json[dict[str, Any]] = '{}'  # custom splitter settings

    # Search
    search_docs_num: int = 4

    # LLM settings
    qa_completion_llm: Json[dict[str, Any]] = """{
        "_type": "openai-chat",
-        "model_name": "gpt-3.5-turbo-16k",
+        "model_name": "gpt-4o-mini",
        "temperature": 0,
        "max_tokens": 1000,
        "verbose": true
    }"""
    qa_followup_llm: Json[dict[str, Any]] = """{
        "_type": "openai-chat",
-        "model_name": "gpt-3.5-turbo-16k",
+        "model_name": "gpt-4o-mini",
        "temperature": 0,
        "max_tokens": 200,
        "verbose": true
    }"""
    summarize_llm: Json[dict[str, Any]] = """{
        "_type": "openai-chat",
-        "model_name": "gpt-3.5-turbo-16k",
+        "model_name": "gpt-4o",
        "temperature": 0,
        "max_tokens": 2000
    }"""

@@ -83,6 +84,7 @@ class Settings(BaseSettings):
    # QA
    qa_no_chat_history: bool = False  # don't load chat history
    qa_followup_sim_threshold: float = 0.735  # similitude threshold in followup
+    qa_retriever: Json[dict[str, Any]] = '{}'  # custom retriever settings
    feature_qa_lang_detect: bool = False

    # Summarization
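Worth noting: the two new `Json[dict[str, Any]]` fields mean the `TEXT_SPLITTER` and `QA_RETRIEVER` environment variables are decoded from JSON strings into plain dicts, ready for the `.copy()` calls seen above. A minimal standalone sketch of that behavior (`DemoSettings` is illustrative, not a Brevia class; assumes pydantic v2 with the pydantic-settings package):

```python
# Standalone sketch of pydantic JSON-valued settings parsing (assumption:
# not part of this PR, just demonstrating the Json field type).
import os
from typing import Any

from pydantic import Json
from pydantic_settings import BaseSettings


class DemoSettings(BaseSettings):
    text_splitter: Json[dict[str, Any]] = '{}'  # decoded from env as JSON


os.environ['TEXT_SPLITTER'] = '{"splitter": "my_project.CustomSplitter"}'
settings = DemoSettings()
print(settings.text_splitter)  # {'splitter': 'my_project.CustomSplitter'}
```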
21 changes: 18 additions & 3 deletions docs/config.md
@@ -100,21 +100,36 @@
TEXT_CHUNK_SIZE=2000
TEXT_CHUNK_OVERLAP=100
```
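For reference, the default splitter Brevia builds from these two values (mirroring `create_splitter` in `brevia/index.py` from this PR) behaves like this minimal standalone sketch; the sample text is illustrative, and the `nltk` package with its "punkt" tokenizer data must be installed:

```python
# Minimal sketch of the default splitter configured by TEXT_CHUNK_SIZE
# and TEXT_CHUNK_OVERLAP.
from langchain_text_splitters import NLTKTextSplitter

splitter = NLTKTextSplitter(
    separator="\n",
    chunk_size=2000,    # TEXT_CHUNK_SIZE
    chunk_overlap=100,  # TEXT_CHUNK_OVERLAP
)
chunks = splitter.split_text("First sentence. Second sentence. Third one.")
print(len(chunks))  # 1 — such a short text fits in a single chunk
```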

+`TEXT_SPLITTER`
+This variable is an optional JSON string configuration, used to override the default text splitter.
+
+It can be something like:
+
+`'{"splitter": "my_project.CustomSplitter", "some_var": "some_value"}'`
+
+Where:
+
+* the `splitter` key must be present and point to the module path of a valid
+splitter class extending LangChain's `TextSplitter`
+* other splitter constructor attributes can be specified in the configuration,
+like `some_var` in the above example (see the sketch after this list)
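A minimal sketch of such a class; the module name `my_project`, the class name and `some_var` are illustrative assumptions, not part of Brevia:

```python
# my_project.py — hypothetical module providing a custom splitter.
from langchain_text_splitters import TextSplitter


class CustomSplitter(TextSplitter):
    """Toy splitter that cuts text on blank lines."""

    def __init__(self, some_var: str = '', **kwargs):
        # chunk_size / chunk_overlap from Brevia's settings arrive via kwargs
        super().__init__(**kwargs)
        self.some_var = some_var

    def split_text(self, text: str) -> list[str]:
        # A real implementation should also honor self._chunk_size and
        # self._chunk_overlap; this one only splits on paragraph breaks.
        return [part for part in text.split('\n\n') if part.strip()]
```

With this module importable as `my_project`, setting `TEXT_SPLITTER='{"splitter": "my_project.CustomSplitter", "some_var": "some_value"}'` selects it, and the `chunk_size`/`chunk_overlap` values are merged into the constructor arguments by `create_splitter` above.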

## Q&A and Chat

Under the hood of the Q&A and Chat actions (see the [Chat and Search](chat_search.md) section) you can configure models and behaviors via these variables:

-* `QA_COMPLETION_LLM`: configuration for the main conversational model, used by `/chat` and `/completion` endpoints; a JSON string is used to configure the corresponding LangChain chat model class; an OpenAI instance is used as default: `'{"_type": "openai-chat", "model_name": "gpt-3.5-turbo-16k", "temperature": 0, "max_tokens": 1000, "verbose": true}'` where for instance `model_name` and other attributes can be adjusted to meet your needs
-* `QA_FOLLOWUP_LLM`: configuration for the follow-up question model, used by `/chat` endpoint defining a follow up question for a conversation usgin chat history; a JSON string; an OpenAI instance used as default `'{"_type": "openai-chat", "model_name": "gpt-3.5-turbo-16k", "temperature": 0, "max_tokens": 200, "verbose": true}'`
+* `QA_COMPLETION_LLM`: configuration for the main conversational model, used by the `/chat` and `/completion` endpoints; a JSON string configuring the corresponding LangChain chat model class; an OpenAI instance is used as default: `'{"_type": "openai-chat", "model_name": "gpt-4o-mini", "temperature": 0, "max_tokens": 1000, "verbose": true}'`, where `model_name` and other attributes can be adjusted to meet your needs
+* `QA_FOLLOWUP_LLM`: configuration for the follow-up question model, used by the `/chat` endpoint to build a follow-up question from the conversation's chat history; a JSON string; an OpenAI instance is used as default: `'{"_type": "openai-chat", "model_name": "gpt-4o-mini", "temperature": 0, "max_tokens": 200, "verbose": true}'`
* `QA_FOLLOWUP_SIM_THRESHOLD`: a numeric value between 0 and 1 indicating similarity threshold between questions to determine if chat history should be used, defaults to `0.735`
* `QA_NO_CHAT_HISTORY`: disables chat history entirely if set to `True` or any other value
* `SEARCH_DOCS_NUM`: default number of documents used to search for answers, defaults to `4`
+* `QA_RETRIEVER`: optional configuration for a custom retriever class, used by the `/chat` endpoint; a JSON string defining a custom class and optional attributes, for example `'{"retriever": "my_project.CustomRetriever", "some_var": "some_value"}'`; the `retriever` key must be present with a module path pointing to a valid retriever class extending LangChain's `BaseRetriever`, while other constructor attributes can be specified in the configuration, like `some_var` in the above example (see the sketch below)
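A minimal sketch of a conforming class; again, the module name `my_project`, the class name and `some_var` are illustrative assumptions:

```python
# my_project.py — hypothetical module providing a custom retriever.
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever


class CustomRetriever(BaseRetriever):
    """Toy retriever returning a single static document."""

    some_var: str = ''  # extra attributes from the JSON config land here

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> list[Document]:
        # A real implementation would query a vector store here.
        return [Document(page_content=f'stub result for: {query}')]
```

Since LangChain retrievers are pydantic models, constructor attributes from the JSON config (like `some_var`) are declared as model fields.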

## Summarization

To configure summarization-related actions for the `/summarize` and `/upload_summarize` endpoints, the relevant environment variables are:

-* `SUMMARIZE_LLM`: the LLM to be used, a JSON string using the same format of `QA_COMPLETION_LLM` in the above paragraph; defatults to an OpenAI instance `'{"_type": "openai-chat", "model_name": "gpt-3.5-turbo-16k", "temperature": 0, "max_tokens": 2000}'`
+* `SUMMARIZE_LLM`: the LLM to be used, a JSON string using the same format as `QA_COMPLETION_LLM` in the paragraph above; defaults to an OpenAI instance `'{"_type": "openai-chat", "model_name": "gpt-4o", "temperature": 0, "max_tokens": 2000}'`
* `SUMM_TOKEN_SPLITTER`: the maximum size of individual text chunks processed during summarization, defaults to `4000` - see `TEXT_CHUNK_SIZE` in [Text Segmentation](#text-segmentation) paragraph
* `SUMM_TOKEN_OVERLAP`: the amount of overlap between consecutive text chunks, defaults to `500` - see `TEXT_CHUNK_OVERLAP` in [Text Segmentation](#text-segmentation) paragraph
* `SUMM_DEFAULT_CHAIN`: chain type to be used if not specified, defaults to `stuff`
2 changes: 1 addition & 1 deletion docs/tutorials/create_collection.md
@@ -40,7 +40,7 @@ QA_FOLLOWUP_LLM='{
}'
```

-Replace `your_llm_type` and `your_llm_model` with your chosen LLM provider and specific model (e.g., "openai-chat", "gpt-3.5-turbo-16k").
+Replace `your_llm_type` and `your_llm_model` with your chosen LLM provider and specific model (e.g., "openai-chat", "gpt-4o-mini").
Adjust `temperature` and `max_tokens` parameters as needed.

## Database
2 changes: 1 addition & 1 deletion docs/tutorials/create_summarization.md
@@ -12,7 +12,7 @@ Define the summarization model and parameters using the following JSON snippet
```JSON
SUMMARIZE_LLM='{
  "_type": "openai-chat",
-  "model_name": "gpt-3.5-turbo-16k",
+  "model_name": "gpt-4o-mini",
  "temperature": 0,
  "max_tokens": 2000
}'
51 changes: 49 additions & 2 deletions tests/test_index.py
@@ -2,15 +2,18 @@
from pathlib import Path
from unittest.mock import patch
from h11 import Response
+from langchain_text_splitters import NLTKTextSplitter
+from langchain_text_splitters.character import RecursiveCharacterTextSplitter
import pytest
from requests import HTTPError
from langchain.docstore.document import Document
from brevia.index import (
    load_pdf_file, split_document, update_links_documents,
    add_document, document_has_changed, select_load_link_options,
-    documents_metadata,
+    documents_metadata, create_splitter,
)
from brevia.collections import create_collection
+from brevia.settings import get_settings


def test_load_pdf_file():
@@ -109,9 +112,53 @@ def test_select_load_link_options():


def test_custom_split():
-    """Test split_documents method with cuseom splitter class"""
+    """Test split_document method with custom splitter class"""
    doc1 = Document(page_content='some content? no', metadata={'type': 'questions'})
    cls = 'langchain_text_splitters.character.RecursiveCharacterTextSplitter'
    texts = split_document(doc1, {'splitter': cls})

    assert len(texts) == 1


+def test_add_document_custom_split():
+    """Test add_document method with custom splitter in settings"""
+    settings = get_settings()
+    current_splitter = settings.text_splitter
+    settings.text_splitter = {
+        'splitter': 'langchain_text_splitters.character.RecursiveCharacterTextSplitter'
+    }
+    doc1 = Document(page_content='some content? no', metadata={'type': 'questions'})
+    num = add_document(document=doc1, collection_name='test')
+    assert num == 1
+
+    settings.text_splitter = current_splitter
+
+
+def test_create_splitter_chunk_params():
+    """Test create_splitter method"""
+    splitter = create_splitter({'chunk_size': 2222, 'chunk_overlap': 333})
+
+    assert isinstance(splitter, NLTKTextSplitter)
+    assert splitter._chunk_size == 2222
+    assert splitter._chunk_overlap == 333
+
+    splitter = create_splitter({})
+
+    assert isinstance(splitter, NLTKTextSplitter)
+    assert splitter._chunk_size == get_settings().text_chunk_size
+    assert splitter._chunk_overlap == get_settings().text_chunk_overlap
+
+    custom_splitter = {
+        'splitter': 'langchain_text_splitters.character.RecursiveCharacterTextSplitter',
+        'chunk_size': 1111,
+        'chunk_overlap': 555,
+    }
+    splitter = create_splitter({
+        'chunk_size': 3333,
+        'chunk_overlap': 444,
+        'text_splitter': custom_splitter,
+    })
+
+    assert isinstance(splitter, RecursiveCharacterTextSplitter)
+    assert splitter._chunk_size == 1111
+    assert splitter._chunk_overlap == 555
45 changes: 31 additions & 14 deletions tests/test_query.py
@@ -4,20 +4,18 @@
from langchain.chains.conversational_retrieval.base import ConversationalRetrievalChain
from langchain.docstore.document import Document
from langchain_core.vectorstores import VectorStoreRetriever
-from langchain.vectorstores.pgvector import PGVector
+from langchain.retrievers.multi_query import MultiQueryRetriever
from brevia.query import (
    conversation_chain,
    load_qa_prompt,
    load_condense_prompt,
    search_vector_qa,
    ChatParams,
    SearchQuery,
-    create_custom_retriever,
)
from brevia.collections import create_collection
-from brevia.connection import connection_string
from brevia.index import add_document
-from brevia.models import load_embeddings
+from brevia.settings import get_settings

FAKE_PROMPT = {
    '_type': 'prompt',
@@ -128,15 +126,34 @@ def test_conversation_chain():
    assert isinstance(chain, ConversationalRetrievalChain)


-def test_create_custom_retriever():
-    """Test create_custom_retriever function"""
-    conf = {'retriever': 'langchain_core.vectorstores.VectorStoreRetriever'}
-    store = PGVector(
-        connection_string=connection_string(),
-        embedding_function=load_embeddings(),
-        use_jsonb=True,
+def test_conversation_chain_multiquery():
+    """Test conversation_chain function with multiquery"""
+    collection = create_collection('test', {})
+    chain = conversation_chain(
+        collection=collection,
+        chat_params=ChatParams(multiquery=True)
    )
-    retriever = create_custom_retriever(store, {}, conf)

-    assert retriever is not None
-    assert isinstance(retriever, VectorStoreRetriever)
+    assert chain is not None
+    assert isinstance(chain, ConversationalRetrievalChain)
+    assert isinstance(chain.retriever, MultiQueryRetriever)
+
+
+def test_conversation_chain_custom_retriever():
+    """Test conversation_chain with custom retriever"""
+    collection = create_collection('test', {})
+    settings = get_settings()
+    current_retriever = settings.qa_retriever
+    conf = {'retriever': 'langchain_core.vectorstores.VectorStoreRetriever'}
+    settings.qa_retriever = conf
+    chain = conversation_chain(collection=collection, chat_params=ChatParams())
+
+    assert chain is not None
+    assert isinstance(chain.retriever, VectorStoreRetriever)
+    settings.qa_retriever = current_retriever
+
+    collection = create_collection('test', {'qa_retriever': conf})
+    chain = conversation_chain(collection=collection, chat_params=ChatParams())
+
+    assert chain is not None
+    assert isinstance(chain.retriever, VectorStoreRetriever)