integration: added HaystackAction #398
Open: zilto wants to merge 8 commits into main from integrations/haystack
8 commits
660a268 added HaystackAction
f88037e fixed dictionary iterator
8f731b9 fixed kwargs iterator
0b24192 remove autoreload cell
dc421d0 added tests and docs; applied reviews
f2bc122 added haystack optional dep
82ee637 fixed typo
f9e171b fixed dependency to haystack-ai; removed print()
@@ -0,0 +1,303 @@
import inspect
from collections.abc import Mapping
from typing import Any, Optional, Sequence, Union

from haystack import Pipeline
from haystack.core.component import Component
from haystack.core.component.types import _empty as haystack_empty

from burr.core.action import Action
from burr.core.graph import Graph, GraphBuilder
from burr.core.state import State


# TODO show OpenTelemetry integration
class HaystackAction(Action):
    """Burr ``Action`` wrapping a Haystack ``Component``.

    Haystack ``Component`` is the basic block of a Haystack ``Pipeline``.
    A ``Component`` is instantiated, then it receives inputs for its ``.run()`` method
    and returns output values.

    Learn more about components here: https://docs.haystack.deepset.ai/docs/custom-components
    """

    def __init__(
        self,
        component: Component,
        reads: Union[list[str], dict[str, str]],
        writes: Union[list[str], dict[str, str]],
        name: Optional[str] = None,
        bound_params: Optional[dict] = None,
    ):
        """Create a Burr ``Action`` from a Haystack ``Component``.

        :param component: Haystack ``Component`` to wrap
        :param reads: State fields read and passed to ``Component.run()``
        :param writes: State fields where results of ``Component.run()`` are written
        :param name: Name of the action. Can be set later via ``.with_name()`` or in the
            ``ApplicationBuilder``.
        :param bound_params: Parameters to bind to the ``Component.run()`` method.

        Basic example:

        .. code-block:: python

            from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever
            from haystack.document_stores.in_memory import InMemoryDocumentStore
            from burr.core import ApplicationBuilder
            from burr.integrations.haystack import HaystackAction

            retrieve_documents = HaystackAction(
                component=InMemoryEmbeddingRetriever(InMemoryDocumentStore()),
                name="retrieve_documents",
                reads=["query_embedding"],
                writes=["documents"],
            )

            app = (
                ApplicationBuilder()
                .with_actions(retrieve_documents)
                # each transition is a ``(from, to)`` tuple
                .with_transitions(("retrieve_documents", "retrieve_documents"))
                .with_entrypoint("retrieve_documents")
                .build()
            )
        """
        self._component = component
        self._name = name
        self._bound_params = bound_params if bound_params is not None else {}

        # NOTE input and output socket mappings are kept separately to avoid naming conflicts.
        # `reads` as a mapping is {input socket name: state field name}.
        if isinstance(reads, Mapping):
            self._input_socket_mapping = reads
            self._reads = list(set(reads.values()))
        elif isinstance(reads, Sequence):
            self._input_socket_mapping = {socket_name: socket_name for socket_name in reads}
            self._reads = reads
        else:
            raise TypeError(f"`reads` must be a sequence or mapping. Received: {type(reads)}")

        self._validate_input_sockets()

        # `writes` as a mapping is {state field name: output socket name}.
        if isinstance(writes, Mapping):
            self._output_socket_mapping = writes
            self._writes = list(writes.keys())
        elif isinstance(writes, Sequence):
            self._output_socket_mapping = {socket_name: socket_name for socket_name in writes}
            self._writes = writes
        else:
            raise TypeError(f"`writes` must be a sequence or mapping. Received: {type(writes)}")

        self._validate_output_sockets()

    def _validate_input_sockets(self) -> None:
        component_inputs = self._component.__haystack_input__._sockets_dict.keys()
        for socket_name in self._input_socket_mapping.keys():
            if socket_name not in component_inputs:
                raise ValueError(
                    f"Socket `{socket_name}` not found in `Component` inputs: {component_inputs}"
                )

    def _validate_output_sockets(self) -> None:
        component_outputs = self._component.__haystack_output__._sockets_dict.keys()
        for socket_name in self._output_socket_mapping.values():
            if socket_name not in component_outputs:
                raise ValueError(
                    f"Socket `{socket_name}` not found in `Component` outputs: {component_outputs}"
                )

    @property
    def component(self) -> Component:
        """Haystack `Component` used by this action."""
        return self._component

    @property
    def reads(self) -> list[str]:
        """State fields read and passed to `Component.run()`"""
        return self._reads

    @property
    def writes(self) -> list[str]:
        """State fields where results of `Component.run()` are written."""
        return self._writes

    @property
    def inputs(self) -> tuple[dict[str, str], dict[str, str]]:
        """Return dictionaries of required and optional inputs for `Component.run()`"""
        required_inputs, optional_inputs = {}, {}
        for socket_name, input_socket in self._component.__haystack_input__._sockets_dict.items():
            state_field_name = self._input_socket_mapping.get(socket_name, socket_name)

            # if we expect the value to come from state (previous actions) or it's a
            # bound parameter, then this socket isn't a user-provided input
            if state_field_name in self.reads or state_field_name in self._bound_params:
                continue

            # determine if input is required or optional based on the socket's default value
            if input_socket.default_value == haystack_empty:
                required_inputs[state_field_name] = input_socket.type
            else:
                optional_inputs[state_field_name] = input_socket.type

        return required_inputs, optional_inputs

    def run(self, state: State, **run_kwargs) -> dict[str, Any]:
        """Call the Haystack `Component.run()` method.

        :param state: State object of the application. It contains some input values
            for ``Component.run()``.
        :param run_kwargs: User-provided inputs for ``Component.run()``.
        :return: Dictionary of results with mapping ``{socket_name: value}``.

        Note, values come from 3 sources:
        - state (from previous actions)
        - run_kwargs (inputs from ``Application.run()``)
        - bound parameters (from ``HaystackAction`` instantiation)
        """
        values = {}

        # here, precedence matters. Alternatively, we could unpack all dictionaries at once
        # which would throw an error for key collisions
        for input_socket_name, value in self._bound_params.items():
            values[input_socket_name] = value

        for input_socket_name, state_field_name in self._input_socket_mapping.items():
            try:
                values[input_socket_name] = state[state_field_name]
            except KeyError as e:
                raise ValueError(f"No value found in state for field: {state_field_name}") from e

        for input_socket_name, value in run_kwargs.items():
            values[input_socket_name] = value

        return self._component.run(**values)

    def update(self, result: dict, state: State) -> State:
        """Update the state using the results of ``Component.run()``.
        The output socket name is mapped to the Burr state field name.

        Values returned by ``Component.run()`` that aren't in ``writes`` are ignored.
        """
        # TODO we could want to handle ``.update()`` and ``.append()`` differently
        state_update = {}

        for state_field_name, output_socket_name in self._output_socket_mapping.items():
            if state_field_name in self.writes:
                try:
                    state_update[state_field_name] = result[output_socket_name]
                except KeyError as e:
                    raise ValueError(
                        f"Socket `{output_socket_name}` missing from output of `Component.run()`"
                    ) from e
        return state.update(**state_update)

    def get_source(self) -> str:
        """Return the source code of the Haystack ``Component``.

        NOTE. This doesn't include the initialization parameters of the ``Component``.
        These can be obtained using ``HaystackAction().component.to_dict()``, but this
        method might not be implemented for all components.
        """
        return inspect.getsource(self._component.__class__)


def _socket_name_mapping(sockets_connections: list[tuple[str, str]]) -> dict[str, str]:
    """Map socket names to a single socket name.

    In Haystack, components communicate via sockets. A socket called
    "embedding" in one component can be renamed to "query_embedding" when
    passed to another component.

    In Burr, there is a single state object so we need a mapping to resolve
    that `embedding` and `query_embedding` point to the same value. This function
    creates a mapping {socket_name: state_field} to rename sockets when creating
    the Burr `Graph`.
    """
    all_connections: dict[str, set[str]] = {}
    for from_, to in sockets_connections:
        if from_ not in all_connections:
            all_connections[from_] = {from_}
        all_connections[from_].add(to)

        if to not in all_connections:
            all_connections[to] = {to}
        all_connections[to].add(from_)

    reduced_mapping: dict[str, str] = {}
    for key, values in all_connections.items():
        unique_name = min(values)
        reduced_mapping[key] = unique_name

    return reduced_mapping


def _connected_inputs(pipeline) -> dict[str, list[str]]:
    """Get all input sockets that are connected to other components."""
    return {
        name: [
            socket.name
            for socket in data.get("input_sockets", {}).values()
            if socket.is_variadic or socket.senders
        ]
        for name, data in pipeline.graph.nodes(data=True)
    }


def _connected_outputs(pipeline) -> dict[str, list[str]]:
    """Get all output sockets that are connected to other components."""
    return {
        name: [
            socket.name for socket in data.get("output_sockets", {}).values() if socket.receivers
        ]
        for name, data in pipeline.graph.nodes(data=True)
    }


def haystack_pipeline_to_burr_graph(pipeline: Pipeline) -> Graph:
    """Convert a Haystack `Pipeline` to a Burr `Graph`.

    NOTE. This currently doesn't support Haystack pipelines with
    parallel branches. Learn more https://docs.haystack.deepset.ai/docs/pipelines#branching

    From the Haystack `Pipeline`, we can easily retrieve transitions.

    For actions, we need to create a `HaystackAction` from each component
    and map its sockets to Burr state fields.

    EXPERIMENTAL: This feature is experimental and may change in the future.
    Changes to Haystack or Burr could impact this function. Please let us know if
    you encounter any issues.
    """

    # get all socket connections in the pipeline
    sockets_connections = [
        (edge_data["from_socket"].name, edge_data["to_socket"].name)
        for _, _, edge_data in pipeline.graph.edges.data()
    ]
    socket_mapping = _socket_name_mapping(sockets_connections)

    transitions = [(from_, to) for from_, to, _ in pipeline.graph.edges]

    # get all input and output sockets that are connected to other components
    connected_inputs = _connected_inputs(pipeline)
    connected_outputs = _connected_outputs(pipeline)

    actions = []
    for component_name, component in pipeline.walk():
        # reads: {input socket name: state field name}
        inputs_mapping = {
            socket_name: socket_mapping[socket_name]
            for socket_name in connected_inputs[component_name]
        }
        # writes: {state field name: output socket name}
        outputs_mapping = {
            socket_mapping[socket_name]: socket_name
            for socket_name in connected_outputs[component_name]
        }

        haystack_action = HaystackAction(
            name=component_name,
            component=component,
            reads=inputs_mapping,
            writes=outputs_mapping,
        )
        actions.append(haystack_action)

    return GraphBuilder().with_actions(*actions).with_transitions(*transitions).build()
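
To make the socket renaming concrete, here is a standalone sketch of the neighbour-based merge that `_socket_name_mapping` performs. It uses plain tuples instead of a real Haystack `Pipeline`; the function name `resolve_socket_names` and the sample connections are made up for illustration. One thing the sketch makes visible: each name maps to the alphabetically smallest name among itself and its direct neighbours, so a chain of renames is not collapsed transitively.

```python
def resolve_socket_names(connections: list[tuple[str, str]]) -> dict[str, str]:
    """Map each socket name to the smallest name among itself and its direct neighbours."""
    neighbours: dict[str, set[str]] = {}
    for sender, receiver in connections:
        neighbours.setdefault(sender, {sender}).add(receiver)
        neighbours.setdefault(receiver, {receiver}).add(sender)
    return {name: min(group) for name, group in neighbours.items()}


# Hypothetical connections: (output socket name, input socket name)
connections = [
    ("embedding", "query_embedding"),
    ("documents", "documents"),
]
mapping = resolve_socket_names(connections)
print(mapping)

# A chain a -> b -> c is only merged one hop at a time.
chain = resolve_socket_names([("a", "b"), ("b", "c")])
print(chain)
```

In the first call, both `embedding` and `query_embedding` resolve to the state field `embedding`. In the chain, `c` resolves to `b` while `b` resolves to `a`, which may or may not matter for real pipelines, since socket names are compared globally rather than per component.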
@@ -0,0 +1,9 @@
========
Haystack
========

The Haystack integration allows you to use a Haystack ``Component`` as a Burr ``Action`` through the ``HaystackAction`` construct. See the examples in ``burr/examples/haystack-integration`` for a notebook tutorial.

.. autoclass:: burr.integrations.haystack.HaystackAction

.. autofunction:: burr.integrations.haystack.haystack_pipeline_to_burr_graph
@@ -0,0 +1,5 @@
# Haystack + Burr integration

Haystack is a Python library for building AI pipelines. It assembles `Component` objects into a `Pipeline`, which is a graph of operations. One benefit of Haystack is that it provides many pre-built components to manage documents and interact with LLMs.

This notebook shows how to convert a Haystack `Component` into a Burr `Action` and a `Pipeline` into a `Graph`. This lets you integrate Haystack with Burr and use other Burr and Burr UI features!
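
As a rough mental model of what wrapping a component involves, the sketch below uses plain dictionaries and a toy class in place of Haystack and Burr. `ToyEmbedder` and `run_wrapped` are hypothetical names, not part of either library. It assumes the conventions used by `HaystackAction` in this PR: `reads` maps input sockets to state fields, `writes` maps state fields to output sockets, and later value sources override earlier ones (bound parameters, then state, then run-time keyword arguments).

```python
from typing import Any


class ToyEmbedder:
    """Stand-in for a Haystack component: keyword inputs in, dict of outputs out."""

    def run(self, text: str, prefix: str = "") -> dict[str, Any]:
        return {"embedding": [float(len(prefix + text))]}


def run_wrapped(
    component: Any,
    state: dict[str, Any],
    reads: dict[str, str],  # {input socket: state field}
    writes: dict[str, str],  # {state field: output socket}
    bound_params: dict[str, Any],
    **run_kwargs: Any,
) -> dict[str, Any]:
    # Later sources override earlier ones: bound params < state < run kwargs.
    values = dict(bound_params)
    values.update({socket: state[field] for socket, field in reads.items()})
    values.update(run_kwargs)

    result = component.run(**values)

    # Only outputs named in `writes` are copied back; the rest are ignored.
    new_state = dict(state)
    new_state.update({field: result[socket] for field, socket in writes.items()})
    return new_state


state = run_wrapped(
    ToyEmbedder(),
    state={"question": "hello"},
    reads={"text": "question"},
    writes={"query_embedding": "embedding"},
    bound_params={"prefix": ">> "},
)
print(state)
```

The real `HaystackAction` splits these two halves into `run()` and `update()` and works on a Burr `State` rather than a dict.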
Empty file.
@@ -0,0 +1,77 @@
import os

from haystack.components.builders import PromptBuilder
from haystack.components.embedders import SentenceTransformersTextEmbedder
from haystack.components.generators import OpenAIGenerator
from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever
from haystack.document_stores.in_memory import InMemoryDocumentStore

from burr.core import ApplicationBuilder, State, action
from burr.integrations.haystack import HaystackAction

# dummy OpenAI key to avoid raising an error
os.environ["OPENAI_API_KEY"] = "sk-..."


# NOTE mappings follow ``HaystackAction.__init__``:
#   reads  = {input socket name: state field name}
#   writes = {state field name: output socket name}
embed_text = HaystackAction(
    component=SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"),
    name="embed_text",
    reads=[],
    writes={"query_embedding": "embedding"},
)


retrieve_documents = HaystackAction(
    component=InMemoryEmbeddingRetriever(InMemoryDocumentStore()),
    name="retrieve_documents",
    reads=["query_embedding"],
    writes=["documents"],
)


build_prompt = HaystackAction(
    component=PromptBuilder(template="Document: {{documents}} Question: {{question}}"),
    name="build_prompt",
    reads=["documents"],
    writes={"question_prompt": "prompt"},
)


generate_answer = HaystackAction(
    component=OpenAIGenerator(model="gpt-4o-mini"),
    name="generate_answer",
    reads={"prompt": "question_prompt"},
    # ``OpenAIGenerator.run()`` returns its generations under the ``replies`` socket
    writes={"answer": "replies"},
)


@action(reads=["answer"], writes=[])
def display_answer(state: State) -> State:
    print(state["answer"])
    return state


def build_application():
    return (
        ApplicationBuilder()
        .with_actions(
            embed_text,
            retrieve_documents,
            build_prompt,
            generate_answer,
            display_answer,
        )
        .with_transitions(
            ("embed_text", "retrieve_documents"),
            ("retrieve_documents", "build_prompt"),
            ("build_prompt", "generate_answer"),
            ("generate_answer", "display_answer"),
        )
        .with_entrypoint("embed_text")
        .build()
    )


if __name__ == "__main__":
    app = build_application()
    app.visualize(include_state=True)
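
Because the direction of the `reads` and `writes` dictionaries is easy to mix up, here is a small standalone sketch of how the constructor in this PR normalizes them. The helper names `normalize_reads` and `normalize_writes` are hypothetical. Two details differ from the PR and are assumptions of this sketch: plain strings are rejected, and the read fields are sorted so the output is deterministic.

```python
from collections.abc import Mapping, Sequence
from typing import Union

Spec = Union[list[str], dict[str, str]]


def normalize_reads(reads: Spec) -> tuple[dict[str, str], list[str]]:
    """Return ({input socket: state field}, state fields to read)."""
    if isinstance(reads, Mapping):
        # sorted() only makes the output deterministic for this sketch
        return dict(reads), sorted(set(reads.values()))
    if isinstance(reads, Sequence) and not isinstance(reads, str):
        return {name: name for name in reads}, list(reads)
    raise TypeError(f"`reads` must be a sequence or mapping. Received: {type(reads)}")


def normalize_writes(writes: Spec) -> tuple[dict[str, str], list[str]]:
    """Return ({state field: output socket}, state fields to write)."""
    if isinstance(writes, Mapping):
        return dict(writes), list(writes.keys())
    if isinstance(writes, Sequence) and not isinstance(writes, str):
        return {name: name for name in writes}, list(writes)
    raise TypeError(f"`writes` must be a sequence or mapping. Received: {type(writes)}")


# The generator's `prompt` socket is fed from the `question_prompt` state field.
print(normalize_reads({"prompt": "question_prompt"}))
# The embedder's `embedding` socket is stored in the `query_embedding` state field.
print(normalize_writes({"query_embedding": "embedding"}))
# A list means socket name and state field name are identical.
print(normalize_reads(["documents"]))
```

In short, the keys of `reads` are socket names while the keys of `writes` are state field names, which matches what `_validate_input_sockets()` and `_validate_output_sockets()` check.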
We should probably assign this in the constructor? Break it into a helper function.

Regarding `.inputs()`, it is currently a property, but its return value may change if we allow a `.bind()` method: what was previously a required/optional input is now bound. If `.bound()` is removed, then yes, we could set the values in `__init__()`.

It seems that this logic belongs to `.inputs()` and wouldn't be of much use elsewhere.

Yeah, `.bound()` should return a copy of the object so the `.inputs()` are frozen IMO. That said, this should take in `bound_inputs` in the constructor, and we don't need a method? You don't really want properties dynamically computed like this, it's hard to debug and a bit iffy IMO.
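
One way to read the suggestion above, sketched with a simplified stand-in class rather than the real `HaystackAction`: compute the required and optional inputs once in the constructor, and have binding return a new object instead of mutating the existing one. `SketchAction`, its `sockets` argument, and the `bind()` method are all hypothetical and only illustrate the idea.

```python
from typing import Any, Optional


class SketchAction:
    """Hypothetical, simplified stand-in for ``HaystackAction``."""

    def __init__(self, sockets: dict[str, bool], bound_params: Optional[dict[str, Any]] = None):
        # sockets: {input socket name: True if required, False if optional}
        self._sockets = dict(sockets)
        self._bound_params = dict(bound_params or {})
        # Computed once here, so `inputs` never changes for this instance.
        self._inputs = self._compute_inputs()

    def _compute_inputs(self) -> tuple[list[str], list[str]]:
        free = [name for name in self._sockets if name not in self._bound_params]
        required = [name for name in free if self._sockets[name]]
        optional = [name for name in free if not self._sockets[name]]
        return required, optional

    @property
    def inputs(self) -> tuple[list[str], list[str]]:
        return self._inputs

    def bind(self, **params: Any) -> "SketchAction":
        # Return a new object rather than mutating this one.
        return SketchAction(self._sockets, {**self._bound_params, **params})


base = SketchAction({"query": True, "top_k": False})
bound = base.bind(query="what is Burr?")
print(base.inputs)   # (['query'], ['top_k'])
print(bound.inputs)  # ([], ['top_k'])
```

With this shape, `inputs` is a cheap lookup and each instance reports a stable set of inputs, which addresses the debugging concern about dynamically computed properties.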