diff --git a/literalai/instrumentation/llamaindex.py b/literalai/instrumentation/llamaindex.py
index 09ce4b4..0b9fbb0 100644
--- a/literalai/instrumentation/llamaindex.py
+++ b/literalai/instrumentation/llamaindex.py
@@ -1,7 +1,7 @@
 import logging
 import uuid
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypedDict, Union, cast
-
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
+from typing_extensions import TypedDict
 from llama_index.core.base.llms.types import MessageRole
 from llama_index.core.base.response.schema import Response, StreamingResponse
 from llama_index.core.instrumentation import get_dispatcher
@@ -30,7 +30,7 @@ from llama_index.core.schema import NodeWithScore, QueryBundle, TextNode
 
 from openai.types import CompletionUsage
 from openai.types.chat import ChatCompletion
-from pydantic import Field
+from pydantic import PrivateAttr
 
 from literalai.context import active_thread_var
 from literalai.observability.generation import ChatGeneration, GenerationMessageRole
@@ -111,8 +111,8 @@
 class LiteralEventHandler(BaseEventHandler):
     """This class handles events coming from LlamaIndex."""
 
-    _client: "LiteralClient" = Field(...)
-    _span_handler: "LiteralSpanHandler" = Field(...)
+    _client: "LiteralClient" = PrivateAttr(...)
+    _span_handler: "LiteralSpanHandler" = PrivateAttr(...)
     runs: Dict[str, List[Step]] = {}
     streaming_run_ids: List[str] = []
 
diff --git a/tests/e2e/test_llamaindex.py b/tests/e2e/test_llamaindex.py
new file mode 100644
index 0000000..bee2652
--- /dev/null
+++ b/tests/e2e/test_llamaindex.py
@@ -0,0 +1,42 @@
+import os
+import urllib.parse
+
+import pytest
+
+from literalai import LiteralClient
+
+from dotenv import load_dotenv
+
+load_dotenv()
+
+
+@pytest.fixture
+def non_mocked_hosts() -> list:
+    non_mocked_hosts = []
+
+    # Always skip mocking API
+    url = os.getenv("LITERAL_API_URL", None)
+    if url is not None:
+        parsed = urllib.parse.urlparse(url)
+        non_mocked_hosts.append(parsed.hostname)
+
+    return non_mocked_hosts
+
+
+@pytest.mark.e2e
+class TestLlamaIndex:
+    @pytest.fixture(
+        scope="class"
+    )  # Feel free to move this fixture up for further testing
+    def client(self):
+        url = os.getenv("LITERAL_API_URL", None)
+        api_key = os.getenv("LITERAL_API_KEY", None)
+        assert url is not None and api_key is not None, "Missing environment variables"
+
+        client = LiteralClient(batch_size=5, url=url, api_key=api_key)
+        client.instrument_llamaindex()
+
+        return client
+
+    async def test_instrument_llamaindex(self, client: "LiteralClient"):
+        assert client is not None