Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Saisakul Chernbumroong authored and Saisakul Chernbumroong committed Nov 15, 2024
1 parent 86973c8 commit 2512c00
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 5 deletions.
10 changes: 6 additions & 4 deletions redbox-core/redbox/graph/nodes/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def _search_documents(query: str, state: Annotated[RedboxState, InjectedState])
return _search_documents


def build_govuk_search_tool(num_results: int = 1) -> Tool:
def build_govuk_search_tool(num_results: int = 1, filter=False) -> Tool:
"""Constructs a tool that searches gov.uk and sets state["documents"]."""

tokeniser = tiktoken.encoding_for_model("gpt-4o")
Expand All @@ -157,7 +157,7 @@ def recalculate_similarity(response, query, num_results):
return response

@tool
def _search_govuk(query: str, state: Annotated[dict, InjectedState]) -> dict[str, Any]:
def _search_govuk(query: str, state: Annotated[dict, InjectedState], filter=filter) -> dict[str, Any]:
"""
Search for documents on gov.uk based on a query string.
This endpoint is used to search for documents on gov.uk. There are many types of documents on gov.uk.
Expand Down Expand Up @@ -186,15 +186,17 @@ def _search_govuk(query: str, state: Annotated[dict, InjectedState]) -> dict[str
f"{url_base}/api/search.json",
params={
"q": query,
"count": 10,
"count": 10 if filter else num_results,
"fields": required_fields,
},
headers={"Accept": "application/json"},
)
response.raise_for_status()
response = response.json()

response = recalculate_similarity(response, query, num_results)
if filter:
response = recalculate_similarity(response, query, num_results)

mapped_documents = []
for i, doc in enumerate(response["results"]):
if any(field not in doc for field in required_fields):
Expand Down
35 changes: 34 additions & 1 deletion redbox-core/tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
has_injected_state,
is_valid_tool,
)
from redbox.models.settings import Settings
from redbox.models.chain import AISettings, RedboxQuery, RedboxState
from redbox.models.file import ChunkCreatorType, ChunkMetadata, ChunkResolution
from redbox.models.settings import Settings
from redbox.test.data import RedboxChatTestCase
from redbox.transform import flatten_document_state
from tests.retriever.test_retriever import TEST_CHAIN_PARAMETERS
Expand Down Expand Up @@ -210,3 +210,36 @@ def test_wikipedia_tool():
metadata = ChunkMetadata.model_validate(document.metadata)
assert urlparse(metadata.uri).hostname == "en.wikipedia.org"
assert metadata.creator_type == ChunkCreatorType.wikipedia


@pytest.mark.parametrize(
    "is_filter, relevant_return, query, keyword",
    [
        (False, False, "UK government use of AI", "artificial intelligence"),
        (True, True, "UK government use of AI", "artificial intelligence"),
    ],
)
def test_gov_filter_AI(is_filter, relevant_return, query, keyword):
    """Check whether the gov.uk search tool's documents mention ``keyword``.

    The tool is built with ``num_results=1`` and ``filter=is_filter``, then
    invoked with ``query`` and a minimal ``RedboxState``. The returned
    ``documents`` state is flattened and the test asserts that "at least one
    document's page content contains ``keyword``" equals ``relevant_return``.

    Parametrised cases:
        * ``is_filter=False`` -> the keyword is expected to be absent.
        * ``is_filter=True``  -> the keyword is expected to be present.

    NOTE(review): presumably this exercises the live gov.uk search endpoint,
    so the outcome depends on what the service returns -- confirm whether the
    test should be marked or isolated accordingly.
    """
    search_tool = build_govuk_search_tool(num_results=1, filter=is_filter)

    # Minimal state: only the question is meaningful, everything else is empty.
    tool_input = {
        "query": query,
        "state": RedboxState(
            request=RedboxQuery(
                question=query,
                s3_keys=[],
                user_uuid=uuid4(),
                chat_history=[],
                ai_settings=AISettings(),
                permitted_s3_keys=[],
            )
        ),
    }
    state_update = search_tool.invoke(tool_input)
    documents = flatten_document_state(state_update["documents"])

    keyword_found = any(keyword in doc.page_content for doc in documents)
    assert keyword_found == relevant_return

0 comments on commit 2512c00

Please sign in to comment.