Skip to content

Commit

Permalink
Add min_date query filter in /chat_history (#21)
Browse files Browse the repository at this point in the history
* feat: add min date filter

* test: add filter test cases
  • Loading branch information
stefanorosanelli authored Nov 27, 2023
1 parent c5b9b6c commit 6898cb7
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 4 deletions.
11 changes: 7 additions & 4 deletions brevia/chat_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def is_valid_uuid(val) -> bool:


def get_history(
min_date: str | None = None,
max_date: str | None = None,
collection: str | None = None,
page: int = 1,
Expand All @@ -143,7 +144,7 @@ def get_history(
using pagination data in response
"""
max_date = datetime.now() if max_date is None else max_date
filter_date = ChatHistoryStore.created <= max_date
min_date = datetime.fromtimestamp(0) if min_date is None else min_date
filter_collection = CollectionStore.name == collection
if collection is None:
filter_collection = CollectionStore.name is not None
Expand All @@ -155,7 +156,8 @@ def get_history(
with Session(db_connection()) as session:
query = get_history_query(
session=session,
filter_date=filter_date,
filter_min_date=ChatHistoryStore.created >= min_date,
filter_max_date=ChatHistoryStore.created <= max_date,
filter_collection=filter_collection,
)
count = query.count()
Expand All @@ -180,7 +182,8 @@ def get_history(

def get_history_query(
session: Session,
filter_date: BinaryExpression,
filter_min_date: BinaryExpression,
filter_max_date: BinaryExpression,
filter_collection: BinaryExpression,
) -> Query:
"""Return get history query"""
Expand All @@ -197,6 +200,6 @@ def get_history_query(
CollectionStore,
CollectionStore.uuid == ChatHistoryStore.collection_id
)
.filter(filter_date, filter_collection)
.filter(filter_min_date, filter_max_date, filter_collection)
.order_by(sqlalchemy.desc(ChatHistoryStore.created))
)
2 changes: 2 additions & 0 deletions brevia/routers/chat_history_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@

@router.get('/chat_history', dependencies=get_dependencies(json_content_type=False))
def read_chat_history(
min_date: str | None = None,
max_date: str | None = None,
collection: str | None = None,
page: int = 1,
page_size: int = 50,
):
""" /chat_history endpoint, read stored chat history """
return chat_history.get_history(
min_date=min_date,
max_date=max_date,
collection=collection,
page=page,
Expand Down
22 changes: 22 additions & 0 deletions tests/test_chat_history.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""chat_history module tests"""
from datetime import datetime, timedelta
import uuid
import pytest
from brevia.chat_history import (
Expand Down Expand Up @@ -55,6 +56,27 @@ def test_get_history():
assert history_items == expected


def test_get_history_filters():
"""Test get_history filters"""
create_collection('test_collection', {})
session_id = uuid.uuid4()
add_history(session_id, 'test_collection', 'who?', 'me')
yesterday = datetime.strftime(datetime.now() - timedelta(1), '%Y-%m-%d')
tomorrow = datetime.strftime(datetime.now() + timedelta(1), '%Y-%m-%d')
result = get_history(min_date=yesterday)
assert result['meta']['pagination']['count'] == 1
result = get_history(min_date=tomorrow)
assert result['meta']['pagination']['count'] == 0
result = get_history(max_date=tomorrow)
assert result['meta']['pagination']['count'] == 1
result = get_history(min_date=yesterday, collection='test_collection')
assert result['meta']['pagination']['count'] == 1
result = get_history(min_date=tomorrow, max_date=tomorrow)
assert result['meta']['pagination']['count'] == 0
result = get_history(min_date=yesterday, collection='test2')
assert result['meta']['pagination']['count'] == 0


def test_history_from_db():
"""Test history_from_db function"""
result = history_from_db(uuid.uuid4())
Expand Down

0 comments on commit 6898cb7

Please sign in to comment.