import inspect
from collections.abc import Mapping
from typing import Any, Optional, Sequence, Union

from haystack import Pipeline
from haystack.core.component import Component
from haystack.core.component.types import _empty as haystack_empty

from burr.core.action import Action
from burr.core.graph import Graph, GraphBuilder
from burr.core.state import State


# TODO show OpenTelemetry integration
class HaystackAction(Action):
    """Burr ``Action`` wrapping a Haystack ``Component``.

    Haystack ``Component`` is the basic block of a Haystack ``Pipeline``.
    A ``Component`` is instantiated, then it receives inputs for its ``.run()`` method
    and returns output values.

    Learn more about components here: https://docs.haystack.deepset.ai/docs/custom-components
    """

    def __init__(
        self,
        component: Component,
        reads: Union[list[str], dict[str, str]],
        writes: Union[list[str], dict[str, str]],
        name: Optional[str] = None,
        bound_params: Optional[dict] = None,
    ):
        """Create a Burr ``Action`` from a Haystack ``Component``.

        :param component: Haystack ``Component`` to wrap
        :param reads: State fields read and passed to ``Component.run()``. Either a list
            of names (the state field and the input socket share the same name) or a
            mapping ``{input_socket_name: state_field_name}``.
        :param writes: State fields where results of ``Component.run()`` are written.
            Either a list of names (the state field and the output socket share the same
            name) or a mapping ``{state_field_name: output_socket_name}``. Note that the
            key/value direction is the opposite of ``reads``.
        :param name: Name of the action. Can be set later via ``.with_name()`` or in the
            ``ApplicationBuilder``.
        :param bound_params: Parameters to bind to the ``Component.run()`` method, keyed by
            input socket name.
        :raises TypeError: If ``reads`` or ``writes`` is neither a sequence nor a mapping.
        :raises ValueError: If a socket name is not declared by the ``Component``.

        Basic example:

        .. code-block:: python

            from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever
            from haystack.document_stores.in_memory import InMemoryDocumentStore
            from burr.core import ApplicationBuilder
            from burr.integrations.haystack import HaystackAction

            retrieve_documents = HaystackAction(
                component=InMemoryEmbeddingRetriever(InMemoryDocumentStore()),
                name="retrieve_documents",
                reads=["query_embedding"],
                writes=["documents"],
            )

            app = (
                ApplicationBuilder()
                .with_actions(retrieve_documents)
                .with_transitions("retrieve_documents", "retrieve_documents")
                .with_entrypoint("retrieve_documents")
                .build()
            )
        """
        # let the base class set up its own attributes before we override ``_name``
        super().__init__()
        self._component = component
        self._name = name
        self._bound_params = bound_params if bound_params is not None else {}

        # NOTE input and output socket mappings are kept separately to avoid naming conflicts.
        # The containers are copied so later mutation by the caller can't bypass validation.
        if isinstance(reads, Mapping):
            self._input_socket_mapping = dict(reads)
            # ``dict.fromkeys`` de-duplicates while keeping a deterministic order
            self._reads = list(dict.fromkeys(reads.values()))
        elif isinstance(reads, Sequence):
            self._input_socket_mapping = {socket_name: socket_name for socket_name in reads}
            self._reads = list(reads)
        else:
            raise TypeError(f"`reads` must be a sequence or mapping. Received: {type(reads)}")

        self._validate_input_sockets()

        if isinstance(writes, Mapping):
            self._output_socket_mapping = dict(writes)
            self._writes = list(writes.keys())
        elif isinstance(writes, Sequence):
            self._output_socket_mapping = {socket_name: socket_name for socket_name in writes}
            self._writes = list(writes)
        else:
            raise TypeError(f"`writes` must be a sequence or mapping. Received: {type(writes)}")

        self._validate_output_sockets()

    def _validate_input_sockets(self) -> None:
        """Raise ``ValueError`` if a key of the input mapping isn't a ``Component`` input socket."""
        component_inputs = self._component.__haystack_input__._sockets_dict.keys()
        for socket_name in self._input_socket_mapping.keys():
            if socket_name not in component_inputs:
                raise ValueError(
                    f"Socket `{socket_name}` not found in `Component` inputs: {component_inputs}"
                )

    def _validate_output_sockets(self) -> None:
        """Raise ``ValueError`` if a value of the output mapping isn't a ``Component`` output socket."""
        component_outputs = self._component.__haystack_output__._sockets_dict.keys()
        for socket_name in self._output_socket_mapping.values():
            if socket_name not in component_outputs:
                raise ValueError(
                    f"Socket `{socket_name}` not found in `Component` outputs: {component_outputs}"
                )

    @property
    def component(self) -> Component:
        """Haystack `Component` used by this action."""
        return self._component

    @property
    def reads(self) -> list[str]:
        """State fields read and passed to `Component.run()`"""
        return self._reads

    @property
    def writes(self) -> list[str]:
        """State fields where results of `Component.run()` are written."""
        return self._writes

    @property
    def inputs(self) -> tuple[dict[str, Any], dict[str, Any]]:
        """Return dictionaries of required and optional inputs for `Component.run()`

        Both dictionaries map ``{input_socket_name: socket_type}`` and only contain the
        sockets that are neither fed from state (``reads``) nor bound at instantiation
        (``bound_params``).
        """
        required_inputs, optional_inputs = {}, {}
        for socket_name, input_socket in self._component.__haystack_input__._sockets_dict.items():
            # if we expect the value to come from state (previous actions) or it's a
            # bound parameter, then this socket isn't a user-provided input.
            # The check is done on socket names because both ``bound_params`` and the
            # keyword arguments received by ``run()`` are keyed by socket name.
            if socket_name in self._input_socket_mapping or socket_name in self._bound_params:
                continue

            # determine if input is required or optional based on the socket's default value;
            # identity check on the sentinel avoids calling ``__eq__`` of arbitrary defaults
            if input_socket.default_value is haystack_empty:
                required_inputs[socket_name] = input_socket.type
            else:
                optional_inputs[socket_name] = input_socket.type

        return required_inputs, optional_inputs

    def run(self, state: State, **run_kwargs) -> dict[str, Any]:
        """Call the Haystack `Component.run()` method.

        :param state: State object of the application. It contains some input values
            for ``Component.run()``.
        :param run_kwargs: User-provided inputs for ``Component.run()``.
        :return: Dictionary of results with mapping ``{socket_name: value}``.
        :raises ValueError: If a state field listed in ``reads`` is missing from ``state``.

        Note, values come from 3 sources:
        - state (from previous actions)
        - run_kwargs (inputs from ``Application.run()``)
        - bound parameters (from ``HaystackAction`` instantiation)
        """
        # here, precedence matters: bound parameters < state < run_kwargs.
        # Alternatively, we could unpack all dictionaries at once
        # which would throw an error for key collisions
        values = dict(self._bound_params)

        for input_socket_name, state_field_name in self._input_socket_mapping.items():
            try:
                values[input_socket_name] = state[state_field_name]
            except KeyError as e:
                raise ValueError(f"No value found in state for field: {state_field_name}") from e

        values.update(run_kwargs)

        return self._component.run(**values)

    def update(self, result: dict, state: State) -> State:
        """Update the state using the results of ``Component.run()``.
        The output socket name is mapped to the Burr state field name.

        Values returned by ``Component.run()`` that aren't in ``writes`` are ignored.

        :param result: Output of ``Component.run()`` with mapping ``{socket_name: value}``.
        :param state: State object of the application.
        :return: New state with one field set per entry of ``writes``.
        :raises ValueError: If an expected output socket is missing from ``result``.
        """
        # TODO we could want to handle ``.update()`` and ``.append()`` differently
        state_update = {}

        # every key of the output mapping is a declared ``writes`` field by construction
        for state_field_name, output_socket_name in self._output_socket_mapping.items():
            try:
                state_update[state_field_name] = result[output_socket_name]
            except KeyError as e:
                raise ValueError(
                    f"Socket `{output_socket_name}` missing from output of `Component.run()`"
                ) from e
        return state.update(**state_update)

    def get_source(self) -> str:
        """Return the source code of the Haystack ``Component``.

        NOTE. This doesn't include the initialization parameters of the ``Component``.
        They can be obtained using ``HaystackAction().component.to_dict()``, but this
        method might not be implemented for all components.
        """
        return inspect.getsource(self._component.__class__)
+ """ + all_connections: dict[str, set[str]] = {} + for from_, to in sockets_connections: + if from_ not in all_connections: + all_connections[from_] = {from_} + all_connections[from_].add(to) + + if to not in all_connections: + all_connections[to] = {to} + all_connections[to].add(from_) + + reduced_mapping: dict[str, str] = {} + for key, values in all_connections.items(): + unique_name = min(values) + reduced_mapping[key] = unique_name + + return reduced_mapping + + +def _connected_inputs(pipeline) -> dict[str, list[str]]: + """Get all input sockets that are connected to other components.""" + return { + name: [ + socket.name + for socket in data.get("input_sockets", {}).values() + if socket.is_variadic or socket.senders + ] + for name, data in pipeline.graph.nodes(data=True) + } + + +def _connected_outputs(pipeline) -> dict[str, list[str]]: + """Get all output sockets that are connected to other components.""" + return { + name: [ + socket.name for socket in data.get("output_sockets", {}).values() if socket.receivers + ] + for name, data in pipeline.graph.nodes(data=True) + } + + +def haystack_pipeline_to_burr_graph(pipeline: Pipeline) -> Graph: + """Convert a Haystack `Pipeline` to a Burr `Graph`. + + NOTE. This currently doesn't support Haystack pipelines with + parallel branches. Learn more https://docs.haystack.deepset.ai/docs/pipelines#branching + + From the Haystack `Pipeline`, we can easily retrieve transitions. + For actions, we need to create `HaystackAction` from components + and map their sockets to Burr state fields + + EXPERIMENTAL: This feature is experimental and may change in the future. + Changes to Haystack or Burr could impact this function. Please let us know if + you encounter any issues. 
def haystack_pipeline_to_burr_graph(pipeline: Pipeline) -> Graph:
    """Convert a Haystack `Pipeline` to a Burr `Graph`.

    NOTE. This currently doesn't support Haystack pipelines with
    parallel branches. Learn more https://docs.haystack.deepset.ai/docs/pipelines#branching

    From the Haystack `Pipeline`, we can easily retrieve transitions.
    For actions, we need to create `HaystackAction` from components
    and map their sockets to Burr state fields

    EXPERIMENTAL: This feature is experimental and may change in the future.
    Changes to Haystack or Burr could impact this function. Please let us know if
    you encounter any issues.

    :param pipeline: Haystack ``Pipeline`` to convert.
    :return: Burr ``Graph`` with one ``HaystackAction`` per component, named after
        the component, and one transition per pair of connected components.
    """

    # get all socket connections in the pipeline
    sockets_connections = [
        (edge_data["from_socket"].name, edge_data["to_socket"].name)
        for _, _, edge_data in pipeline.graph.edges.data()
    ]
    socket_mapping = _socket_name_mapping(sockets_connections)

    # two components can be connected through several sockets, which yields several
    # edges; ``dict.fromkeys`` keeps a single transition per pair, in the original order
    transitions = list(dict.fromkeys((from_, to) for from_, to, _ in pipeline.graph.edges))

    # get all input and output sockets that are connected to other components
    connected_inputs = _connected_inputs(pipeline)
    connected_outputs = _connected_outputs(pipeline)

    actions = []
    for component_name, component in pipeline.walk():
        # a socket can be listed without taking part in any connection (e.g., a
        # variadic input without sender); it then keeps its own name as state field
        inputs_mapping = {
            socket_name: socket_mapping.get(socket_name, socket_name)
            for socket_name in connected_inputs[component_name]
        }
        outputs_mapping = {
            socket_mapping.get(socket_name, socket_name): socket_name
            for socket_name in connected_outputs[component_name]
        }

        haystack_action = HaystackAction(
            name=component_name,
            component=component,
            reads=inputs_mapping,
            writes=outputs_mapping,
        )
        actions.append(haystack_action)

    return GraphBuilder().with_actions(*actions).with_transitions(*transitions).build()
It assembles `Component` objects into a `Pipeline`, which is a graph of operations. One benefit of Haystack is that it provides many pre-built components to manage documents and interact with LLMs. + +This notebook shows how to convert a Haystack `Component` into a Burr `Action` and a `Pipeline` into a `Graph`. This allows you to integrate Haystack with Burr and leverage other Burr and Burr UI features! diff --git a/examples/haystack-integration/__init__.py b/examples/haystack-integration/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/haystack-integration/application.py b/examples/haystack-integration/application.py new file mode 100644 index 00000000..584ad93f --- /dev/null +++ b/examples/haystack-integration/application.py @@ -0,0 +1,77 @@ +import os + +from haystack.components.builders import PromptBuilder +from haystack.components.embedders import SentenceTransformersTextEmbedder +from haystack.components.generators import OpenAIGenerator +from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever +from haystack.document_stores.in_memory import InMemoryDocumentStore + +from burr.core import ApplicationBuilder, State, action +from burr.integrations.haystack import HaystackAction + +# dummy OpenAI key to avoid raising an error +os.environ["OPENAI_API_KEY"] = "sk-..." 
# NOTE on dictionary arguments of ``HaystackAction``:
# - ``reads`` maps {input_socket_name: state_field_name}
# - ``writes`` maps {state_field_name: output_socket_name}

embed_text = HaystackAction(
    component=SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"),
    name="embed_text",
    reads=[],
    # store the `embedding` output socket in the `query_embedding` state field
    writes={"query_embedding": "embedding"},
)


retrieve_documents = HaystackAction(
    component=InMemoryEmbeddingRetriever(InMemoryDocumentStore()),
    name="retrieve_documents",
    reads=["query_embedding"],
    writes=["documents"],
)


build_prompt = HaystackAction(
    component=PromptBuilder(template="Document: {{documents}} Question: {{question}}"),
    name="build_prompt",
    reads=["documents"],
    # store the `prompt` output socket in the `question_prompt` state field
    writes={"question_prompt": "prompt"},
)


generate_answer = HaystackAction(
    component=OpenAIGenerator(model="gpt-4o-mini"),
    name="generate_answer",
    # pass the `question_prompt` state field to the `prompt` input socket
    reads={"prompt": "question_prompt"},
    # store the `replies` output socket in the `answer` state field
    writes={"answer": "replies"},
)


@action(reads=["answer"], writes=[])
def display_answer(state: State) -> State:
    """Print the generated answer and return the state unchanged."""
    print(state["answer"])
    return state


def build_application():
    """Assemble the RAG application: embed, retrieve, build prompt, generate, display."""
    return (
        ApplicationBuilder()
        .with_actions(
            embed_text,
            retrieve_documents,
            build_prompt,
            generate_answer,
            display_answer,
        )
        .with_transitions(
            ("embed_text", "retrieve_documents"),
            ("retrieve_documents", "build_prompt"),
            ("build_prompt", "generate_answer"),
            ("generate_answer", "display_answer"),
        )
        .with_entrypoint("embed_text")
        .build()
    )


if __name__ == "__main__":
    app = build_application()
    app.visualize(include_state=True)
One benefit of Haystack is that it provides many pre-built components to manage documents and interact with LLMs.\n", + "\n", + "This notebook shows how to convert a Haystack `Component` into a Burr `Action` and a `Pipeline` into a `Graph`. This allows you to integrate Haystack with Burr and leverage other Burr and Burr UI features!" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Native Haystack\n", + "The next cells show how to build a simple RAG pipeline using Haystack. You create the components and add them to the pipeline using `.add_component()`. Then, you need to specify connections between components using `.connect()`." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/tjean/projects/dagworks/burr/.venv/lib/python3.11/site-packages/haystack/core/errors.py:34: DeprecationWarning: PipelineMaxLoops is deprecated and will be remove in version '2.7.0'; use PipelineMaxComponentRuns instead.\n", + " warnings.warn(\n", + "/home/tjean/projects/dagworks/burr/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import os\n", + "\n", + "from haystack.components.embedders import SentenceTransformersTextEmbedder\n", + "from haystack.components.builders import PromptBuilder\n", + "from haystack.components.generators import OpenAIGenerator\n", + "from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever\n", + "from haystack.document_stores.in_memory import InMemoryDocumentStore\n", + "from haystack import Pipeline\n", + "\n", + "# dummy OpenAI key to avoid raising an error\n", + "os.environ[\"OPENAI_API_KEY\"] = \"sk-...\"\n", + "\n", + "# 1. create components\n", + "document_store = InMemoryDocumentStore()\n", + "text_embedder = SentenceTransformersTextEmbedder(model=\"sentence-transformers/all-MiniLM-L6-v2\")\n", + "prompt_builder = PromptBuilder(template=\"Document: {{documents}} Question: {{question}}\")\n", + "retriever = InMemoryEmbeddingRetriever(document_store)\n", + "generator = OpenAIGenerator(model=\"gpt-4o-mini\")\n", + "\n", + "# 2. create pipeline\n", + "basic_rag_pipeline = Pipeline()\n", + "\n", + "# 3. add components to the pipeline\n", + "basic_rag_pipeline.add_component(\"text_embedder\", text_embedder)\n", + "basic_rag_pipeline.add_component(\"retriever\", retriever)\n", + "basic_rag_pipeline.add_component(\"prompt_builder\", prompt_builder)\n", + "basic_rag_pipeline.add_component(\"llm\", generator)\n", + "\n", + "# 4. 
connect components\n", + "basic_rag_pipeline.connect(\"text_embedder.embedding\", \"retriever.query_embedding\")\n", + "basic_rag_pipeline.connect(\"retriever\", \"prompt_builder.documents\")\n", + "basic_rag_pipeline.connect(\"prompt_builder\", \"llm\")\n", + "\n", + "basic_rag_pipeline.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Without using any integration, you could use Haystack within Burr's `actions`. The next is illustrative of how it can work." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "\n", + "embed_text\n", + "\n", + "embed_text(): query_embedding\n", + "\n", + "\n", + "\n", + "retrieve_documents\n", + "\n", + "retrieve_documents(query_embedding): documents\n", + "\n", + "\n", + "\n", + "embed_text->retrieve_documents\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input__user_question\n", + "\n", + "input: user_question\n", + "\n", + "\n", + "\n", + "input__user_question->embed_text\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "build_prompt\n", + "\n", + "build_prompt(documents): question_prompt\n", + "\n", + "\n", + "\n", + "input__user_question->build_prompt\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "retrieve_documents->build_prompt\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "generate_answer\n", + "\n", + "generate_answer(question_prompt): answer\n", + "\n", + "\n", + "\n", + "build_prompt->generate_answer\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from burr.core import action, State, ApplicationBuilder\n", + "\n", + "\n", + "@action(reads=[], writes=[\"query_embedding\"])\n", + "def embed_text(state: State, user_question: str) -> State:\n", + " text_embedder = 
SentenceTransformersTextEmbedder(model=\"sentence-transformers/all-MiniLM-L6-v2\")\n", + "\n", + " results = text_embedder.run(text=user_question)\n", + " return state.update(query_embedding=results[\"embedding\"])\n", + "\n", + "\n", + "@action(reads=[\"query_embedding\"], writes=[\"documents\"])\n", + "def retrieve_documents(state: State) -> State:\n", + " query_embedding = state[\"query_embedding\"]\n", + "\n", + " document_store = InMemoryDocumentStore()\n", + " retriever = InMemoryEmbeddingRetriever(document_store)\n", + "\n", + " results = retriever.run(query_embedding=query_embedding)\n", + " return state.update(documents=results[\"documents\"])\n", + "\n", + "\n", + "@action(reads=[\"documents\"], writes=[\"question_prompt\"])\n", + "def build_prompt(state: State, user_question: str) -> State:\n", + " documents = state[\"documents\"]\n", + "\n", + " prompt_builder = PromptBuilder(template=\"Document: {{documents}} Question: {{question}}\")\n", + "\n", + " results = prompt_builder.run(documents=documents, question=user_question) \n", + " return state.update(question_prompt=results[\"prompt\"])\n", + "\n", + "\n", + "@action(reads=[\"question_prompt\"], writes=[\"answer\"])\n", + "def generate_answer(state: State) -> State:\n", + " question_prompt = state[\"question_prompt\"]\n", + "\n", + " generator = OpenAIGenerator(model=\"gpt-4o-mini\")\n", + "\n", + " results = generator.run(prompt=question_prompt)\n", + " return state.update(answer=results[\"text\"])\n", + "\n", + "\n", + "app = (\n", + " ApplicationBuilder()\n", + " .with_actions(\n", + " embed_text,\n", + " retrieve_documents,\n", + " build_prompt,\n", + " generate_answer\n", + " )\n", + " .with_transitions(\n", + " (\"embed_text\", \"retrieve_documents\"),\n", + " (\"retrieve_documents\", \"build_prompt\"),\n", + " (\"build_prompt\", \"generate_answer\"))\n", + " .with_entrypoint(\"embed_text\")\n", + " .build()\n", + ")\n", + "app.visualize(include_state=True)" + ] + }, + { + "cell_type": 
"markdown", + "metadata": {}, + "source": [ + "Notes:\n", + "- Instead of using `Component` objects, we wrap them into `@action` decorated functions.\n", + "- While Haystack pipelines allow components to communicate via sockets, Burr relies on a centralized state.\n", + "- Burr requires building the `Graph` \"all at once\" via the `ApplicationBuilder` or `GraphBuilder` while Haystack allows to incrementally add `.add_component()` and `.connect()` statements to the pipeline.\n", + "- Haystack allows the parameters of `Component.run()` to be provided by other components via sockets or from the user inputs. Burr separates the two via the `State` object or the function arguments given through `.run(inputs=...)`.\n", + "- Haystack `Component` are objects, meaning they need to be instantiated and are stateful. Burr `Action` are stateless, which allows to resume runs from any `State` and enable \"time-travel debugging\".\n", + "- Haystack uses a `Router` component to [expression conditional edges](https://docs.haystack.deepset.ai/reference/routers-api#conditionalrouter). Burr allows to add condition directly via the `.with_transitions()` method by specifying in the tuple `(from_action, to_action, condition)`." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Burr's `HaystackAction`\n", + "\n", + "To avoid having to wrap each component into an `@action` function, the `HaystackAction` was added to Burr. 
It takes an instantiated `Component`, a `name`, and the `reads/writes` of the action.\n", + "\n", + "The next cell shows two identical actions, one without the integration (taken from the previous section) and one using `HaystackAction`.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "from burr.integrations.haystack import HaystackAction\n", + "\n", + "@action(reads=[\"query_embedding\"], writes=[\"documents\"])\n", + "def retrieve_documents(state: State) -> State:\n", + " query_embedding = state[\"query_embedding\"]\n", + "\n", + " document_store = InMemoryDocumentStore()\n", + " retriever = InMemoryEmbeddingRetriever(document_store)\n", + " \n", + " results = retriever.run(query_embedding=query_embedding)\n", + " return state.update(documents=results[\"documents\"])\n", + "\n", + "\n", + "haystack_retrieve_documents = HaystackAction(\n", + " component=InMemoryEmbeddingRetriever(InMemoryDocumentStore()),\n", + " name=\"retrieve_documents\",\n", + " reads=[\"query_embedding\"],\n", + " writes=[\"documents\"],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The next cell shows the entire application using the `HaystackAction` integration. The action `display_answer` defined using `@action` was added to show you can combine both approaches.\n", + "\n", + "Note that for some `HaystackAction`, `reads` and `writes` are dictionaries instead of the usual lists. This helps map the values from the Burr `State` to the Haystack `Component.run()` parameters and outputs. 
\n", + "\n", + "For example, in `generate_answer`:\n", + " - `reads={\"prompt\": \"question_prompt\"}` takes the value `State[\"question_prompt\"]` and assigns it to `Component.run(prompt=...)`\n", + " - `writes={\"answer\": \"replies\"}` takes the value `Component.run(...)[\"replies\"]` and assigns it to `state.update(answer=...)`" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "\n", + "embed_text\n", + "\n", + "embed_text(): query_embedding\n", + "\n", + "\n", + "\n", + "retrieve_documents\n", + "\n", + "retrieve_documents(query_embedding): documents\n", + "\n", + "\n", + "\n", + "embed_text->retrieve_documents\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input__text\n", + "\n", + "input: text\n", + "\n", + "\n", + "\n", + "input__text->embed_text\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "build_prompt\n", + "\n", + "build_prompt(documents): question_prompt\n", + "\n", + "\n", + "\n", + "retrieve_documents->build_prompt\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input__return_embedding\n", + "\n", + "input: return_embedding\n", + "\n", + "\n", + "\n", + "input__return_embedding->retrieve_documents\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input__scale_score\n", + "\n", + "input: scale_score\n", + "\n", + "\n", + "\n", + "input__scale_score->retrieve_documents\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input__filters\n", + "\n", + "input: filters\n", + "\n", + "\n", + "\n", + "input__filters->retrieve_documents\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input__top_k\n", + "\n", + "input: top_k\n", + "\n", + "\n", + "\n", + "input__top_k->retrieve_documents\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "generate_answer\n", + "\n", + "generate_answer(question_prompt): answer\n", + "\n", + "\n", + "\n", + "build_prompt->generate_answer\n", + "\n", + "\n", + "\n", + "\n", + "\n", + 
"input__template_variables\n", + "\n", + "input: template_variables\n", + "\n", + "\n", + "\n", + "input__template_variables->build_prompt\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input__question\n", + "\n", + "input: question\n", + "\n", + "\n", + "\n", + "input__question->build_prompt\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input__template\n", + "\n", + "input: template\n", + "\n", + "\n", + "\n", + "input__template->build_prompt\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "display_answer\n", + "\n", + "display_answer(answer): \n", + "\n", + "\n", + "\n", + "generate_answer->display_answer\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input__streaming_callback\n", + "\n", + "input: streaming_callback\n", + "\n", + "\n", + "\n", + "input__streaming_callback->generate_answer\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input__generation_kwargs\n", + "\n", + "input: generation_kwargs\n", + "\n", + "\n", + "\n", + "input__generation_kwargs->generate_answer\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from burr.core import action, State, ApplicationBuilder\n", + "\n", + "embed_text = HaystackAction(\n", + " component=SentenceTransformersTextEmbedder(model=\"sentence-transformers/all-MiniLM-L6-v2\"),\n", + " name=\"embed_text\",\n", + " reads=[],\n", + " writes={\"query_embedding\": \"embedding\"},\n", + ")\n", + "\n", + "retrieve_documents = HaystackAction(\n", + " component=InMemoryEmbeddingRetriever(InMemoryDocumentStore()),\n", + " name=\"retrieve_documents\",\n", + " reads=[\"query_embedding\"],\n", + " writes=[\"documents\"],\n", + ")\n", + "\n", + "build_prompt = HaystackAction(\n", + " component=PromptBuilder(template=\"Document: {{documents}} Question: {{question}}\"),\n", + " name=\"build_prompt\",\n", + " reads=[\"documents\"],\n", + " writes={\"question_prompt\": \"prompt\"},\n", + ")\n", + "\n", + 
"generate_answer = HaystackAction(\n", + " component=OpenAIGenerator(model=\"gpt-4o-mini\"),\n", + " name=\"generate_answer\",\n", + " reads={\"prompt\": \"question_prompt\"},\n", + " writes={\"answer\": \"replies\"}\n", + ")\n", + "\n", + "@action(reads=[\"answer\"], writes=[])\n", + "def display_answer(state: State) -> State:\n", + " print(state[\"answer\"])\n", + " return state\n", + "\n", + "\n", + "app = (\n", + " ApplicationBuilder()\n", + " .with_actions(\n", + " embed_text,\n", + " retrieve_documents,\n", + " build_prompt,\n", + " generate_answer,\n", + " display_answer,\n", + " )\n", + " .with_transitions(\n", + " (\"embed_text\", \"retrieve_documents\"),\n", + " (\"retrieve_documents\", \"build_prompt\"),\n", + " (\"build_prompt\", \"generate_answer\"),\n", + " (\"generate_answer\", \"display_answer\"),\n", + " )\n", + " .with_entrypoint(\"embed_text\")\n", + " .build()\n", + ")\n", + "app.visualize(include_state=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Converting a Haystack `Pipeline`\n", + "\n", + "If you have an existing Haystack `Pipeline`, you can convert it into a Burr `Graph` using in a single line of code using `haystack_pipeline_to_burr_graph()`.\n", + "\n", + "Next, we convert the `basic_rag_pipeline` defined at the beginning of the notebook. 
The resulting `Graph` can be passed to the `ApplicationBuilder.with_graph()` clause.\n", + "\n", + "The visualization should match the previous ones, but with different names (e.g., `generate_answer` is `llm`)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "\n", + "text_embedder\n", + "\n", + "text_embedder(): embedding\n", + "\n", + "\n", + "\n", + "retriever\n", + "\n", + "retriever(embedding): documents\n", + "\n", + "\n", + "\n", + "text_embedder->retriever\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input__text\n", + "\n", + "input: text\n", + "\n", + "\n", + "\n", + "input__text->text_embedder\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "prompt_builder\n", + "\n", + "prompt_builder(documents): prompt\n", + "\n", + "\n", + "\n", + "retriever->prompt_builder\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input__return_embedding\n", + "\n", + "input: return_embedding\n", + "\n", + "\n", + "\n", + "input__return_embedding->retriever\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input__scale_score\n", + "\n", + "input: scale_score\n", + "\n", + "\n", + "\n", + "input__scale_score->retriever\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input__filters\n", + "\n", + "input: filters\n", + "\n", + "\n", + "\n", + "input__filters->retriever\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input__top_k\n", + "\n", + "input: top_k\n", + "\n", + "\n", + "\n", + "input__top_k->retriever\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "llm\n", + "\n", + "llm(prompt): \n", + "\n", + "\n", + "\n", + "prompt_builder->llm\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input__template_variables\n", + "\n", + "input: template_variables\n", + "\n", + "\n", + "\n", + "input__template_variables->prompt_builder\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input__question\n", + "\n", + "input: question\n", + "\n", + "\n", 
+ "\n", + "input__question->prompt_builder\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input__template\n", + "\n", + "input: template\n", + "\n", + "\n", + "\n", + "input__template->prompt_builder\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input__streaming_callback\n", + "\n", + "input: streaming_callback\n", + "\n", + "\n", + "\n", + "input__streaming_callback->llm\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input__generation_kwargs\n", + "\n", + "input: generation_kwargs\n", + "\n", + "\n", + "\n", + "input__generation_kwargs->llm\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from burr.integrations.haystack import haystack_pipeline_to_burr_graph\n", + "\n", + "haystack_graph = haystack_pipeline_to_burr_graph(basic_rag_pipeline)\n", + "app = (\n", + " ApplicationBuilder()\n", + " .with_graph(haystack_graph)\n", + " .with_entrypoint(\"prompt_builder\")\n", + " .build()\n", + ")\n", + "app.visualize(include_state=True)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/haystack-integration/statemachine.png b/examples/haystack-integration/statemachine.png new file mode 100644 index 00000000..907f4d28 Binary files /dev/null and b/examples/haystack-integration/statemachine.png differ diff --git a/pyproject.toml b/pyproject.toml index 8bbc631d..f32c279d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,8 @@ tests = [ "pydantic[email]", "pyarrow", "redis", - "burr[opentelemetry]" + "burr[opentelemetry]", + "burr[haystack]" ] 
# NOTE: this module was recovered from whitespace-collapsed patch text. The
# original line also carried the tail of a pyproject.toml hunk and this file's
# patch header; both are preserved verbatim here so no patch content is lost:
#   documentation = [ @@ -113,6 +114,10 @@ pydantic = [ "pydantic" ] +haystack = [ + "haystack-ai" +] + cli = [ "loguru", "click",
#   diff --git a/tests/integrations/test_burr_haystack.py b/tests/integrations/test_burr_haystack.py new file mode 100644 index 00000000..3ddb124b --- /dev/null +++ b/tests/integrations/test_burr_haystack.py @@ -0,0 +1,253 @@
"""Tests for ``HaystackAction`` and ``haystack_pipeline_to_burr_graph``."""
from haystack import Pipeline, component
from haystack.components.embedders import SentenceTransformersTextEmbedder
from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever
from haystack.document_stores.in_memory import InMemoryDocumentStore

from burr.core import State, action
from burr.core.graph import GraphBuilder
from burr.integrations.haystack import HaystackAction, haystack_pipeline_to_burr_graph


# NOTE: keep this class byte-for-byte in sync with `expected_source` in
# `test_get_component_source`, which compares against its literal source text.
@component
class MockComponent:
    def __init__(self, required_init: str, optional_init: str = "default"):
        self.required_init = required_init
        self.optional_init = optional_init

    @component.output_types(output_1=str, output_2=str)
    def run(self, required_input: str, optional_input: str = "default") -> dict:
        return {
            "output_1": required_input,
            "output_2": optional_input,
        }


# Function-based Burr action that builds a fresh in-memory retriever on each call
# and stores the retrieved documents in state.
@action(reads=["query_embedding"], writes=["documents"])
def retrieve_documents(state: State) -> State:
    retriever = InMemoryEmbeddingRetriever(InMemoryDocumentStore())
    results = retriever.run(query_embedding=state["query_embedding"])
    return state.update(documents=results["documents"])


# The same retrieval step, declared by wrapping the component in a `HaystackAction`.
haystack_retrieve_documents = HaystackAction(
    component=InMemoryEmbeddingRetriever(InMemoryDocumentStore()),
    name="retrieve_documents",
    reads=["query_embedding"],
    writes=["documents"],
)


def _mock_action(reads, writes, bound_params=None) -> HaystackAction:
    """Wrap a fresh ``MockComponent`` in a ``HaystackAction`` named ``"mock"``."""
    return HaystackAction(
        component=MockComponent(required_init="init"),
        name="mock",
        reads=reads,
        writes=writes,
        bound_params=bound_params,
    )


def test_input_socket_mapping():
    """A ``reads`` mapping is ``{input_socket_name: state_field}``."""
    reads = {"required_input": "foo"}

    wrapped = _mock_action(reads=reads, writes=[])

    state_fields = list(set(reads.values()))
    assert wrapped.reads == state_fields
    assert state_fields == ["foo"]


def test_input_socket_sequence():
    """A ``reads`` sequence is equivalent to ``{input_socket_name: input_socket_name}``."""
    reads = ["required_input"]

    wrapped = _mock_action(reads=reads, writes=[])

    socket_names = list(reads)
    assert wrapped.reads == socket_names
    assert socket_names == ["required_input"]


def test_output_socket_mapping():
    """A ``writes`` mapping is ``{state_field: output_socket_name}``."""
    writes = {"bar": "output_1"}

    wrapped = _mock_action(reads=[], writes=writes)

    state_fields = list(writes)
    assert wrapped.writes == state_fields
    assert state_fields == ["bar"]


def test_output_socket_sequence():
    """A ``writes`` sequence is equivalent to ``{output_socket_name: output_socket_name}``."""
    writes = ["output_1"]

    wrapped = _mock_action(reads=[], writes=writes)

    assert wrapped.writes == writes
    assert writes == ["output_1"]


def test_get_component_source():
    """``get_source()`` returns the source text of the wrapped component's class."""
    wrapped = _mock_action(reads=[], writes=[])

    expected_source = """\
@component
class MockComponent:
    def __init__(self, required_init: str, optional_init: str = "default"):
        self.required_init = required_init
        self.optional_init = optional_init

    @component.output_types(output_1=str, output_2=str)
    def run(self, required_input: str, optional_input: str = "default") -> dict:
        return {
            "output_1": required_input,
            "output_2": optional_input,
        }
"""

    assert wrapped.get_source() == expected_source


def test_run_with_external_inputs():
    """Inputs passed as keyword arguments to ``run()`` reach the component."""
    empty_state = State(initial_values={})
    wrapped = _mock_action(reads=[], writes=[])

    outputs = wrapped.run(state=empty_state, required_input="as_input")

    assert outputs == {"output_1": "as_input", "output_2": "default"}


def test_run_with_state_inputs():
    """Inputs listed in a ``reads`` mapping are pulled from state."""
    seeded_state = State(initial_values={"foo": "bar"})
    wrapped = _mock_action(reads={"required_input": "foo"}, writes=[])

    outputs = wrapped.run(state=seeded_state)

    assert outputs == {"output_1": "bar", "output_2": "default"}


def test_run_with_bound_params():
    """Inputs supplied through ``bound_params`` reach the component."""
    empty_state = State(initial_values={})
    wrapped = _mock_action(reads=[], writes=[], bound_params={"required_input": "baz"})

    outputs = wrapped.run(state=empty_state)

    assert outputs == {"output_1": "baz", "output_2": "default"}


def test_run_mixed_params():
    """State-mapped inputs and bound parameters can be combined."""
    seeded_state = State(initial_values={"foo": "bar"})
    wrapped = _mock_action(
        reads={"required_input": "foo"},
        writes=[],
        bound_params={"optional_input": "baz"},
    )

    outputs = wrapped.run(state=seeded_state)

    assert outputs == {"output_1": "bar", "output_2": "baz"}


def test_run_with_sequence():
    """Inputs listed in a ``reads`` sequence are pulled from same-named state fields."""
    seeded_state = State(initial_values={"required_input": "bar"})
    wrapped = _mock_action(reads=["required_input"], writes=[])

    outputs = wrapped.run(state=seeded_state)

    assert outputs == {"output_1": "bar", "output_2": "default"}


def test_update_with_writes_mapping():
    """With a ``writes`` mapping, the named output lands in the mapped state field."""
    empty_state = State(initial_values={})
    component_outputs = {"output_1": 1, "output_2": 2}
    wrapped = _mock_action(reads=[], writes={"foo": "output_1"})

    updated = wrapped.update(result=component_outputs, state=empty_state)

    assert updated["foo"] == 1


def test_update_with_writes_sequence():
    """With a ``writes`` sequence, the output lands in the same-named state field."""
    empty_state = State(initial_values={})
    component_outputs = {"output_1": 1, "output_2": 2}
    wrapped = _mock_action(reads=[], writes=["output_1"])

    updated = wrapped.update(result=component_outputs, state=empty_state)

    assert updated["output_1"] == 1


def test_pipeline_converter():
    """A converted ``Pipeline`` contains the actions and transitions of a hand-built graph.

    The components are only instantiated and wired together here; neither the
    pipeline nor any action is executed.
    """
    # Haystack side: a two-component pipeline.
    retriever = InMemoryEmbeddingRetriever(InMemoryDocumentStore())
    text_embedder = SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")

    basic_rag_pipeline = Pipeline()
    basic_rag_pipeline.add_component("text_embedder", text_embedder)
    basic_rag_pipeline.add_component("retriever", retriever)
    basic_rag_pipeline.connect("text_embedder.embedding", "retriever.query_embedding")

    # Burr side: the same two steps declared by hand.
    embed_step = HaystackAction(
        component=text_embedder,
        name="text_embedder",
        reads=[],
        writes={"query_embedding": "embedding"},
    )
    retrieve_step = HaystackAction(
        component=retriever,
        name="retriever",
        reads=["query_embedding"],
        writes=["documents"],
    )
    builder = GraphBuilder().with_actions(embed_step, retrieve_step)
    builder = builder.with_transitions(("text_embedder", "retriever"))
    expected_graph = builder.build()

    # Convert the Haystack pipeline to a Burr graph.
    converted_graph = haystack_pipeline_to_burr_graph(basic_rag_pipeline)

    # Every hand-built action must exist, by name, in the converted graph.
    converted_names = {converted.name for converted in converted_graph.actions}
    for expected_action in expected_graph.actions:
        assert expected_action.name in converted_names

    # Every hand-built transition must have a converted one with the same endpoints.
    converted_edges = {(edge.from_.name, edge.to.name) for edge in converted_graph.transitions}
    for expected_edge in expected_graph.transitions:
        assert (expected_edge.from_.name, expected_edge.to.name) in converted_edges