Add custom default Retriever/TextSplitter, update deprecated OpenAI model (#67)

* feat: new load_type + create_retriever functions

* test: `create_retriever` test case

* test: fix test case

* feat: custom splitter from conf

* refactor: use `text_splitter` as key

* test: add custom splitter test case

* refactor: drop `gtp-3.5turbo` from defaults

* refactor: drop gpt-3.5-turbo from defaults

* feat: default custom retriever + splitter via settings

* test: custom retriever + splitter via settings test case

* doc: custom splitter and retriever in settings

* fix: handle chunk_size/overlap from collection conf

* chore: update settings sample

* test: new create splitter test case
stefanorosanelli authored Sep 2, 2024
1 parent 50cd4c2 commit 2635cf1
Showing 9 changed files with 143 additions and 45 deletions.
10 changes: 7 additions & 3 deletions .env.sample
@@ -10,7 +10,7 @@ BREVIA_ENV_SECRETS='{
 # QA completion LLM
 QA_COMPLETION_LLM='{
     "_type": "openai-chat",
-    "model_name": "gpt-3.5-turbo-16k",
+    "model_name": "gpt-4o-mini",
     "temperature": 0,
     "max_tokens": 200,
     "verbose": true
@@ -19,7 +19,7 @@ QA_COMPLETION_LLM='{
 # QA followup LLM
 QA_FOLLOWUP_LLM='{
     "_type": "openai-chat",
-    "model_name": "gpt-3.5-turbo-16k",
+    "model_name": "gpt-4o-mini",
     "temperature": 0,
     "max_tokens": 200,
     "verbose": true
@@ -28,6 +28,8 @@ QA_FOLLOWUP_LLM='{
 QA_FOLLOWUP_SIM_THRESHOLD=0.735
 # Chat history - uncomment to disable session conversation, avoiding chat history loading
 # QA_NO_CHAT_HISTORY=True
+# Uncomment to use a custom QA retriever with custom arguments
+# QA_RETRIEVER='{"retriever": "my_project.CustomRetriever", "some_var": "some_value"}'
 
 # Access tokens secret - if missing no access token validity is checked
 # Generate it with: openssl rand -hex 32
@@ -49,12 +51,14 @@ EMBEDDINGS='{"_type": "openai-embeddings"}'
 ## Index creation
 TEXT_CHUNK_SIZE=2000
 TEXT_CHUNK_OVERLAP=100
+# Uncomment to use a custom splitter with custom arguments
+# TEXT_SPLITTER='{"splitter": "my_project.CustomSplitter", "some_var": "some_value"}'
 
 ## Summarize
 # Summarization LLM
 SUMMARIZE_LLM='{
     "_type": "openai-chat",
-    "model_name": "gpt-3.5-turbo-16k",
+    "model_name": "gpt-4o-mini",
     "temperature": 0,
     "max_tokens": 2000
 }'
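
Both new sample variables hold JSON strings. As a hint at how such strings become dictionaries, here is a minimal sketch assuming the same pydantic `Json` field pattern this commit adds to `brevia/settings.py` below; `ExampleSettings` is illustrative, not a Brevia class:

```python
from typing import Any

from pydantic import Json
from pydantic_settings import BaseSettings


class ExampleSettings(BaseSettings):
    """Hypothetical mirror of the Json fields declared in brevia.settings."""
    text_splitter: Json[dict[str, Any]] = '{}'
    qa_retriever: Json[dict[str, Any]] = '{}'


# With TEXT_SPLITTER='{"splitter": "my_project.CustomSplitter"}' exported in
# the environment, the JSON string is validated and parsed into a dict:
# ExampleSettings().text_splitter == {'splitter': 'my_project.CustomSplitter'}
```
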
42 changes: 26 additions & 16 deletions brevia/index.py
@@ -71,9 +71,10 @@ def add_document(
     """ Add document to index and return number of splitted text chunks"""
     collection = single_collection_by_name(collection_name)
     embed_conf = collection.cmetadata.get('embeddings', None) if collection else None
-    split_conf = collection.cmetadata.get('text_splitter', None) if collection else None
-
-    texts = split_document(document, split_conf)
+    texts = split_document(
+        document=document,
+        collection_meta=collection.cmetadata if collection else {},
+    )
     PGVector.from_documents(
         embedding=load_embeddings(embed_conf),
         documents=texts,
@@ -86,13 +87,11 @@
     return len(texts)


-def split_document(document: Document, split_conf: dict | None = None):
+def split_document(
+    document: Document, collection_meta: dict = {}
+) -> list[Document]:
     """ Split document into text chunks and return a list of documents"""
-    if not split_conf:
-        text_splitter = create_default_splitter()
-    else:
-        text_splitter = create_custom_splitter(split_conf)
-
+    text_splitter = create_splitter(collection_meta)
     texts = text_splitter.split_documents([document])
     counter = 1
     for text in texts:
@@ -101,15 +100,26 @@ def split_document(
     return texts


-def create_default_splitter() -> TextSplitter:
-    """ Create default text splitter"""
+def create_splitter(collection_meta: dict) -> TextSplitter:
+    """ Create text splitter"""
     settings = get_settings()

-    return NLTKTextSplitter(
-        separator="\n",
-        chunk_size=settings.text_chunk_size,
-        chunk_overlap=settings.text_chunk_overlap
+    custom_splitter = collection_meta.get(
+        'text_splitter',
+        settings.text_splitter.copy()
     )
+    chunk_size = collection_meta.get('chunk_size', settings.text_chunk_size)
+    chunk_overlap = collection_meta.get('chunk_overlap', settings.text_chunk_overlap)
+
+    if not custom_splitter:
+        return NLTKTextSplitter(
+            separator="\n",
+            chunk_size=chunk_size,
+            chunk_overlap=chunk_overlap
+        )
+
+    chunk_conf = {'chunk_size': chunk_size, 'chunk_overlap': chunk_overlap}
+
+    return create_custom_splitter({**chunk_conf, **custom_splitter})


 def create_custom_splitter(split_conf: dict) -> TextSplitter:
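
In short, the new `create_splitter` resolves its configuration with a clear precedence: collection metadata wins over global settings, and chunk values inside the custom splitter config win over collection-level ones. A usage sketch mirroring the test cases further down this page:

```python
from brevia.index import create_splitter

# Empty metadata: default NLTKTextSplitter built from the global
# TEXT_CHUNK_SIZE / TEXT_CHUNK_OVERLAP settings (unless TEXT_SPLITTER is set).
splitter = create_splitter({})

# Collection-level chunk parameters, still the default NLTK splitter.
splitter = create_splitter({'chunk_size': 1000, 'chunk_overlap': 50})

# Custom splitter class from collection metadata; the chunk_size inside the
# 'text_splitter' dict (1111) wins over the collection-level value (3000).
splitter = create_splitter({
    'chunk_size': 3000,
    'text_splitter': {
        'splitter':
            'langchain_text_splitters.character.RecursiveCharacterTextSplitter',
        'chunk_size': 1111,
    },
})
```
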
5 changes: 4 additions & 1 deletion brevia/query.py
@@ -271,7 +271,10 @@ def conversation_chain(

     # main chain, do all the jobs
     search_kwargs = {'k': chat_params.docs_num, 'filter': chat_params.filter}
-    retriever_conf = collection.cmetadata.get('qa_retriever')
+    retriever_conf = collection.cmetadata.get(
+        'qa_retriever',
+        settings.qa_retriever.copy()
+    )
     if not retriever_conf:
         retriever = create_default_retriever(
             store=docsearch,
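
The retriever configuration now follows the same pattern: collection metadata first, then the `QA_RETRIEVER` setting, with the default vector-store retriever as fallback. Below is a hypothetical per-collection override, mirroring the test case at the bottom of this page (the collection name is arbitrary):

```python
from brevia.collections import create_collection
from brevia.query import ChatParams, conversation_chain

# Collection metadata carrying its own retriever configuration; this takes
# precedence over the global QA_RETRIEVER setting.
collection = create_collection('docs', {
    'qa_retriever': {
        'retriever': 'langchain_core.vectorstores.VectorStoreRetriever',
    },
})
chain = conversation_chain(collection=collection, chat_params=ChatParams())
```
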
10 changes: 6 additions & 4 deletions brevia/settings.py
@@ -40,31 +40,32 @@ class Settings(BaseSettings):
     # Test models - only in unit tests
     use_test_models: bool = False

-    # Index - textsplitter settings
+    # Index - text splitter settings
     text_chunk_size: int = 2000
     text_chunk_overlap: int = 200
+    text_splitter: Json[dict[str, Any]] = '{}'  # custom splitter settings

     # Search
     search_docs_num: int = 4

     # LLM settings
     qa_completion_llm: Json[dict[str, Any]] = """{
         "_type": "openai-chat",
-        "model_name": "gpt-3.5-turbo-16k",
+        "model_name": "gpt-4o-mini",
         "temperature": 0,
         "max_tokens": 1000,
         "verbose": true
     }"""
     qa_followup_llm: Json[dict[str, Any]] = """{
         "_type": "openai-chat",
-        "model_name": "gpt-3.5-turbo-16k",
+        "model_name": "gpt-4o-mini",
         "temperature": 0,
         "max_tokens": 200,
         "verbose": true
     }"""
     summarize_llm: Json[dict[str, Any]] = """{
         "_type": "openai-chat",
-        "model_name": "gpt-3.5-turbo-16k",
+        "model_name": "gpt-4o",
         "temperature": 0,
         "max_tokens": 2000
     }"""
@@ -83,6 +84,7 @@ class Settings(BaseSettings):
     # QA
     qa_no_chat_history: bool = False  # don't load chat history
     qa_followup_sim_threshold: float = 0.735  # similitude threshold in followup
+    qa_retriever: Json[dict[str, Any]] = '{}'  # custom retriever settings
     feature_qa_lang_detect: bool = False

     # Summarization
21 changes: 18 additions & 3 deletions docs/config.md
@@ -100,21 +100,36 @@ TEXT_CHUNK_SIZE=2000
 TEXT_CHUNK_OVERLAP=100
 ```

+`TEXT_SPLITTER`
+This variable is an optional JSON string configuration used to override the default text splitter.
+
+Can be something like:
+
+`'{"splitter": "my_project.CustomSplitter", "some_var": "some_value"}'`
+
+Where:
+
+* `splitter` key must be present and point to the module path of a valid
+splitter class extending langchain `TextSplitter`
+* other splitter constructor attributes can be specified in the configuration,
+like `some_var` in the above example, as sketched below
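
For context, a minimal sketch of what such a class could look like; the `CustomSplitter` name and `some_var` argument are placeholders, only the langchain `TextSplitter` base class and its `split_text` hook are real:

```python
from langchain_text_splitters import TextSplitter


class CustomSplitter(TextSplitter):
    """Hypothetical splitter behind the my_project.CustomSplitter example."""

    def __init__(self, some_var: str = 'some_value', **kwargs):
        super().__init__(**kwargs)  # handles chunk_size / chunk_overlap
        self.some_var = some_var

    def split_text(self, text: str) -> list[str]:
        # Naive example: split on blank lines; a real implementation
        # would honor self._chunk_size and self._chunk_overlap.
        return [chunk for chunk in text.split('\n\n') if chunk.strip()]
```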

 ## Q&A and Chat

 Under the hood of Q&A and Chat actions (see [Chat and Search](chat_search.md) section) you can configure models and behaviors via these variables:

-* `QA_COMPLETION_LLM`: configuration for the main conversational model, used by `/chat` and `/completion` endpoints; a JSON string is used to configure the corresponding LangChain chat model class; an OpenAI instance is used as default: `'{"_type": "openai-chat", "model_name": "gpt-3.5-turbo-16k", "temperature": 0, "max_tokens": 1000, "verbose": true}'` where for instance `model_name` and other attributes can be adjusted to meet your needs
-* `QA_FOLLOWUP_LLM`: configuration for the follow-up question model, used by `/chat` endpoint defining a follow up question for a conversation usgin chat history; a JSON string; an OpenAI instance used as default `'{"_type": "openai-chat", "model_name": "gpt-3.5-turbo-16k", "temperature": 0, "max_tokens": 200, "verbose": true}'`
+* `QA_COMPLETION_LLM`: configuration for the main conversational model, used by `/chat` and `/completion` endpoints; a JSON string is used to configure the corresponding LangChain chat model class; an OpenAI instance is used as default: `'{"_type": "openai-chat", "model_name": "gpt-4o-mini", "temperature": 0, "max_tokens": 1000, "verbose": true}'` where, for instance, `model_name` and other attributes can be adjusted to meet your needs
+* `QA_FOLLOWUP_LLM`: configuration for the follow-up question model, used by the `/chat` endpoint to define a follow-up question for a conversation using chat history; a JSON string; an OpenAI instance is used as default: `'{"_type": "openai-chat", "model_name": "gpt-4o-mini", "temperature": 0, "max_tokens": 200, "verbose": true}'`
 * `QA_FOLLOWUP_SIM_THRESHOLD`: a numeric value between 0 and 1 indicating similarity threshold between questions to determine if chat history should be used, defaults to `0.735`
 * `QA_NO_CHAT_HISTORY`: disables chat history entirely if set to `True` or any other value
 * `SEARCH_DOCS_NUM`: default number of documents used to search for answers, defaults to `4`
+* `QA_RETRIEVER`: optional configuration for a custom retriever class, used by the `/chat` endpoint; a JSON string defining a custom class and optional attributes; an example configuration can be `'{"retriever": "my_project.CustomRetriever", "some_var": "some_value"}'` where the `retriever` key must be present with a module path pointing to a valid retriever class extending langchain `BaseRetriever`, whereas other constructor attributes can be specified in the configuration, like `some_var` in the above example (see the sketch below)
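
Likewise, a minimal sketch of a custom retriever class; `CustomRetriever` and `some_var` are placeholders, only the langchain `BaseRetriever` interface is real:

```python
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever


class CustomRetriever(BaseRetriever):
    """Hypothetical retriever behind the my_project.CustomRetriever example."""

    some_var: str = 'some_value'

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> list[Document]:
        # Stand-in logic: a real retriever would query a vector store.
        return [Document(page_content=f'results for: {query}')]
```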

 ## Summarization

 To configure summarize related actions in `/summarize` or `/upload_summarize` endpoints the related environment variables are:

-* `SUMMARIZE_LLM`: the LLM to be used, a JSON string using the same format of `QA_COMPLETION_LLM` in the above paragraph; defatults to an OpenAI instance `'{"_type": "openai-chat", "model_name": "gpt-3.5-turbo-16k", "temperature": 0, "max_tokens": 2000}'`
+* `SUMMARIZE_LLM`: the LLM to be used, a JSON string using the same format as `QA_COMPLETION_LLM` in the above paragraph; defaults to an OpenAI instance `'{"_type": "openai-chat", "model_name": "gpt-4o", "temperature": 0, "max_tokens": 2000}'`
 * `SUMM_TOKEN_SPLITTER`: the maximum size of individual text chunks processed during summarization, defaults to `4000` - see `TEXT_CHUNK_SIZE` in [Text Segmentation](#text-segmentation) paragraph
 * `SUMM_TOKEN_OVERLAP`: the amount of overlap between consecutive text chunks, defaults to `500` - see `TEXT_CHUNK_OVERLAP` in [Text Segmentation](#text-segmentation) paragraph
 * `SUMM_DEFAULT_CHAIN`: chain type to be used if not specified, defaults to `stuff`
2 changes: 1 addition & 1 deletion docs/tutorials/create_collection.md
@@ -40,7 +40,7 @@ QA_FOLLOWUP_LLM='{
 }'
 ```

-Replace `your_llm_type` and `your_llm_model` with your chosen LLM provider and specific model (e.g., "openai-chat", "gpt-3.5-turbo-16k").
+Replace `your_llm_type` and `your_llm_model` with your chosen LLM provider and specific model (e.g., "openai-chat", "gpt-4o-mini").
 Adjust `temperature` and `max_tokens` parameters as needed.

 ## Database
2 changes: 1 addition & 1 deletion docs/tutorials/create_summarization.md
@@ -12,7 +12,7 @@ Define the summarization model and parameters using the following JSON snippet a
 ```JSON
 SUMMARIZE_LLM='{
     "_type": "openai-chat",
-    "model_name": "gpt-3.5-turbo-16k",
+    "model_name": "gpt-4o-mini",
     "temperature": 0,
     "max_tokens": 2000
 }'
51 changes: 49 additions & 2 deletions tests/test_index.py
@@ -2,15 +2,18 @@
 from pathlib import Path
 from unittest.mock import patch
 from h11 import Response
+from langchain_text_splitters import NLTKTextSplitter
+from langchain_text_splitters.character import RecursiveCharacterTextSplitter
 import pytest
 from requests import HTTPError
 from langchain.docstore.document import Document
 from brevia.index import (
     load_pdf_file, split_document, update_links_documents,
     add_document, document_has_changed, select_load_link_options,
-    documents_metadata,
+    documents_metadata, create_splitter,
 )
 from brevia.collections import create_collection
+from brevia.settings import get_settings


 def test_load_pdf_file():
@@ -109,9 +112,53 @@ def test_select_load_link_options():


 def test_custom_split():
-    """Test split_documents method with cuseom splitter class"""
+    """Test split_documents method with custom splitter class"""
     doc1 = Document(page_content='some content? no', metadata={'type': 'questions'})
     cls = 'langchain_text_splitters.character.RecursiveCharacterTextSplitter'
     texts = split_document(doc1, {'splitter': cls})

     assert len(texts) == 1
+
+
+def test_add_document_custom_split():
+    """Test add_document method with custom splitter in settings"""
+    settings = get_settings()
+    current_splitter = settings.text_splitter
+    settings.text_splitter = {
+        'splitter': 'langchain_text_splitters.character.RecursiveCharacterTextSplitter'
+    }
+    doc1 = Document(page_content='some content? no', metadata={'type': 'questions'})
+    num = add_document(document=doc1, collection_name='test')
+    assert num == 1
+
+    settings.text_splitter = current_splitter
+
+
+def test_create_splitter_chunk_params():
+    """Test create_splitter method"""
+    splitter = create_splitter({'chunk_size': 2222, 'chunk_overlap': 333})
+
+    assert isinstance(splitter, NLTKTextSplitter)
+    assert splitter._chunk_size == 2222
+    assert splitter._chunk_overlap == 333
+
+    splitter = create_splitter({})
+
+    assert isinstance(splitter, NLTKTextSplitter)
+    assert splitter._chunk_size == get_settings().text_chunk_size
+    assert splitter._chunk_overlap == get_settings().text_chunk_overlap
+
+    custom_splitter = {
+        'splitter': 'langchain_text_splitters.character.RecursiveCharacterTextSplitter',
+        'chunk_size': 1111,
+        'chunk_overlap': 555,
+    }
+    splitter = create_splitter({
+        'chunk_size': 3333,
+        'chunk_overlap': 444,
+        'text_splitter': custom_splitter,
+    })
+
+    assert isinstance(splitter, RecursiveCharacterTextSplitter)
+    assert splitter._chunk_size == 1111
+    assert splitter._chunk_overlap == 555
45 changes: 31 additions & 14 deletions tests/test_query.py
@@ -4,20 +4,18 @@
 from langchain.chains.conversational_retrieval.base import ConversationalRetrievalChain
 from langchain.docstore.document import Document
 from langchain_core.vectorstores import VectorStoreRetriever
-from langchain.vectorstores.pgvector import PGVector
+from langchain.retrievers.multi_query import MultiQueryRetriever
 from brevia.query import (
     conversation_chain,
     load_qa_prompt,
     load_condense_prompt,
     search_vector_qa,
     ChatParams,
     SearchQuery,
-    create_custom_retriever,
 )
 from brevia.collections import create_collection
-from brevia.connection import connection_string
 from brevia.index import add_document
-from brevia.models import load_embeddings
+from brevia.settings import get_settings

 FAKE_PROMPT = {
     '_type': 'prompt',
@@ -128,15 +126,34 @@ def test_conversation_chain():
     assert isinstance(chain, ConversationalRetrievalChain)


-def test_create_custom_retriever():
-    """Test create_custom_retriever function"""
-    conf = {'retriever': 'langchain_core.vectorstores.VectorStoreRetriever'}
-    store = PGVector(
-        connection_string=connection_string(),
-        embedding_function=load_embeddings(),
-        use_jsonb=True,
+def test_conversation_chain_multiquery():
+    """Test conversation_chain function with multiquery"""
+    collection = create_collection('test', {})
+    chain = conversation_chain(
+        collection=collection,
+        chat_params=ChatParams(multiquery=True)
     )
-    retriever = create_custom_retriever(store, {}, conf)

-    assert retriever is not None
-    assert isinstance(retriever, VectorStoreRetriever)
+    assert chain is not None
+    assert isinstance(chain, ConversationalRetrievalChain)
+    assert isinstance(chain.retriever, MultiQueryRetriever)


+def test_conversation_chain_custom_retriever():
+    """Test conversation_chain with custom retriever"""
+    collection = create_collection('test', {})
+    settings = get_settings()
+    current_retriever = settings.qa_retriever
+    conf = {'retriever': 'langchain_core.vectorstores.VectorStoreRetriever'}
+    settings.qa_retriever = conf
+    chain = conversation_chain(collection=collection, chat_params=ChatParams())
+
+    assert chain is not None
+    assert isinstance(chain.retriever, VectorStoreRetriever)
+    settings.qa_retriever = current_retriever
+
+    collection = create_collection('test', {'qa_retriever': conf})
+    chain = conversation_chain(collection=collection, chat_params=ChatParams())
+
+    assert chain is not None
+    assert isinstance(chain.retriever, VectorStoreRetriever)
