Skip to content

Commit

Permalink
Enable a custom text splitter for a collection (#66)
Browse files Browse the repository at this point in the history
* feat: new load_type + create_retriever functions

* test: `create_retriever` test case

* test: fix test case

* feat: custom splitter from conf

* refactor: use `text_splitter` as key

* test: add custom splitter test case
  • Loading branch information
stefanorosanelli authored Aug 30, 2024
1 parent a971e2b commit bf0e8f4
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 18 deletions.
49 changes: 32 additions & 17 deletions brevia/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
from langchain_community.vectorstores.pgvector import PGVector
from langchain_core.documents import Document
from langchain_text_splitters import NLTKTextSplitter
from langchain_text_splitters.base import TextSplitter
from requests import HTTPError
from sqlalchemy.orm import Session
from brevia import connection, load_file
from brevia.collections import single_collection_by_name
from brevia.models import load_embeddings
from brevia.settings import get_settings
from brevia.utilities.json_api import query_data_pagination
from brevia.utilities.types import load_type


def init_index():
Expand Down Expand Up @@ -67,9 +69,13 @@ def add_document(
document_id: str = None,
) -> int:
""" Add document to index and return number of splitted text chunks"""
texts = split_document(document)
collection = single_collection_by_name(collection_name)
embed_conf = collection.cmetadata.get('embeddings', None) if collection else None
split_conf = collection.cmetadata.get('text_splitter', None) if collection else None

texts = split_document(document, split_conf)
PGVector.from_documents(
embedding=load_embeddings(collection_embeddings(collection_name)),
embedding=load_embeddings(embed_conf),
documents=texts,
collection_name=collection_name,
connection_string=connection.connection_string(),
Expand All @@ -80,29 +86,38 @@ def add_document(
return len(texts)


def collection_embeddings(collection_name: str) -> dict | None:
""" Return custom embeddings of a collection"""
collection = single_collection_by_name(collection_name)
if collection is None:
return None
def split_document(document: Document, split_conf: dict | None = None) -> list[Document]:
    """ Split a document into text chunks and return the list of chunk documents.

    :param document: the source Document to split
    :param split_conf: optional custom splitter configuration; when falsy the
        default NLTK splitter from application settings is used, otherwise a
        custom splitter class is instantiated from the configuration
    :return: list of chunk Documents, each with a 1-based ``part`` number
        added to its metadata
    """
    if not split_conf:
        text_splitter = create_default_splitter()
    else:
        text_splitter = create_custom_splitter(split_conf)

    texts = text_splitter.split_documents([document])
    # enumerate from 1 so 'part' is a human-friendly chunk ordinal
    for counter, text in enumerate(texts, start=1):
        text.metadata['part'] = counter
    return texts


def split_document(document: Document):
""" Split document into text chunks and return a list of documents"""
def create_default_splitter() -> TextSplitter:
    """ Create the default text splitter.

    Uses NLTK sentence-aware splitting with chunk size and overlap
    taken from application settings.
    """
    settings = get_settings()

    return NLTKTextSplitter(
        separator="\n",
        chunk_size=settings.text_chunk_size,
        chunk_overlap=settings.text_chunk_overlap
    )
texts = text_splitter.split_documents([document])
counter = 1
for text in texts:
text.metadata['part'] = counter
counter += 1
return texts


def create_custom_splitter(split_conf: dict) -> TextSplitter:
    """ Create a custom text splitter from a configuration dict.

    :param split_conf: configuration with a ``splitter`` key holding the
        dotted path of a ``TextSplitter`` subclass; all remaining keys are
        passed to the class constructor as keyword arguments.
        The input dict is NOT mutated.
    :raises ValueError: if the named class cannot be loaded as a TextSplitter
        (behavior of ``load_type`` — TODO confirm exact exception type)
    """
    # Work on a copy: popping from the caller's dict would strip 'splitter'
    # from the collection metadata it usually comes from.
    conf = dict(split_conf)
    splitter_name = conf.pop('splitter', '')
    splitter_class = load_type(splitter_name, TextSplitter)

    return splitter_class(**conf)


def remove_document(
Expand Down
11 changes: 10 additions & 1 deletion tests/test_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from requests import HTTPError
from langchain.docstore.document import Document
from brevia.index import (
load_pdf_file, update_links_documents,
load_pdf_file, split_document, update_links_documents,
add_document, document_has_changed, select_load_link_options,
documents_metadata,
)
Expand Down Expand Up @@ -106,3 +106,12 @@ def test_select_load_link_options():

result = select_load_link_options(url='someurl.org', options=options)
assert result == {}


def test_custom_split():
    """Test split_document function with a custom splitter class"""
    doc1 = Document(page_content='some content? no', metadata={'type': 'questions'})
    cls = 'langchain_text_splitters.character.RecursiveCharacterTextSplitter'
    texts = split_document(doc1, {'splitter': cls})

    assert len(texts) == 1
    # chunk numbering starts at 1
    assert texts[0].metadata['part'] == 1

0 comments on commit bf0e8f4

Please sign in to comment.