integration: added HaystackAction #398

Open
wants to merge 8 commits into
base: main
Changes from 4 commits
197 changes: 197 additions & 0 deletions burr/integrations/haystack.py
@@ -0,0 +1,197 @@
from typing import Any, Optional, Union

from haystack import Pipeline
from haystack.core.component import Component
from haystack.core.component.types import _empty as haystack_empty

from burr.core.action import Action
from burr.core.graph import Graph, GraphBuilder
from burr.core.state import State


class HaystackAction(Action):
    """Create a Burr `Action` from a Haystack `Component`."""

    def __init__(
        self,
        component: Component,
        reads: Union[list[str], dict[str, str]],
        writes: Union[list[str], dict[str, str]],
        name: Optional[str] = None,
        bound_params: Optional[dict] = None,
    ):
        """
        Notes
        - need to figure out how to use bind
        - you can use `action.bind()` to set values of `Component.run()`.
        """
        self._component = component
        self._name = name
        self._reads = list(reads.keys()) if isinstance(reads, dict) else reads
        self._writes = list(writes.values()) if isinstance(writes, dict) else writes
Contributor

Should probably validate that writes doesn't have a duplicate mapping, given that it's the values.
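
A minimal sketch of the validation suggested here, assuming dict-style `writes` uses the `{socket_name: state_field}` orientation from this commit. `validate_writes` is a hypothetical helper name, not part of the PR.

```python
from collections import Counter
from typing import Union


def validate_writes(writes: Union[list[str], dict[str, str]]) -> None:
    """Raise if two component outputs would be written to the same state field.

    Hypothetical helper; assumes dict-style `writes` is {socket_name: state_field}.
    """
    if not isinstance(writes, dict):
        return  # a plain list carries no renaming, so there is nothing to check

    counts = Counter(writes.values())
    duplicates = sorted(field for field, n in counts.items() if n > 1)
    if duplicates:
        raise ValueError(f"Multiple sockets map to the same state field(s): {duplicates}")
```

Calling this at the top of `__init__()` would reject a mapping such as `{"bar": "foo", "baz": "foo"}` before any state update is attempted.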

Collaborator Author

I changed some mappings to ensure reads={socket: state_field} and writes={state_field: socket}.

This seems to make the most sense because:

  • on reads, each key should be unique and map to a unique kwarg of Component.run(). The same state value could be passed to different kwargs
  • on writes, each key should be a unique state field. Even though Component.run() returns a dictionary where keys are unique, we want to prevent calling .update() on the same field several times.

In other words, this would be invalid:

class CustomComponent:
    def run(self) -> dict:
        return {"bar": 1, "baz": 2}

# mapping both `bar` and `baz` to `State["foo"]` is invalid
HaystackAction(
    component=CustomComponent(),
    writes={"bar": "foo", "baz": "foo"},
)

# reversing the mapping ensures that two outputs can't be mapped to `foo`
HaystackAction(
    component=CustomComponent(),
    writes={"foo1": "bar", "foo2": "baz"},
)
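
To make the reversed orientation concrete, here is a rough sketch of how a state update could be built when `writes` is `{state_field: socket_name}`. It uses plain dictionaries instead of Burr's `State`, and `apply_writes` is a hypothetical name used only for illustration.

```python
def apply_writes(result: dict, writes: dict[str, str]) -> dict:
    """Build a state update from a component result.

    Assumes `writes` is {state_field: socket_name}. Because dict keys are
    unique, each state field can receive at most one value.
    """
    missing = [socket for socket in writes.values() if socket not in result]
    if missing:
        raise KeyError(f"Component result is missing sockets: {missing}")
    return {state_field: result[socket] for state_field, socket in writes.items()}


# Stand-in for the output of `Component.run()`
result = {"bar": 1, "baz": 2}
state_update = apply_writes(result, {"foo1": "bar", "foo2": "baz"})
```

With this orientation the same socket may still feed two different state fields, but no state field can be written twice.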

        self._bound_params = bound_params if bound_params is not None else {}

        self._socket_mapping = {}
        if isinstance(reads, dict):
            for state_field, socket_name in reads.items():
                self._socket_mapping[socket_name] = state_field

        if isinstance(writes, dict):
            for socket_name, state_field in writes.items():
                self._socket_mapping[socket_name] = state_field

    @property
    def reads(self) -> list[str]:
        return self._reads

    @property
    def writes(self) -> list[str]:
        return self._writes

    @property
    def inputs(self) -> tuple[dict[str, str], dict[str, str]]:
Contributor

We should probably assign this in the constructor? Break it into a helper function.

Collaborator Author @zilto, Oct 18, 2024

Regarding .inputs(), it is currently a property, but its return value may change if we allow a .bind() method. What was previously a required/optional input is now bound.

If .bind() is removed, then yes we could set values in __init__()

It seems that this logic belongs to .inputs() and wouldn't be of much use elsewhere.
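
For comparison, the constructor-time helper raised in this thread could look roughly like the sketch below. It uses a stand-in socket class rather than Haystack's real input sockets, so `FakeSocket`, `_EMPTY`, `split_inputs`, and `FrozenInputsAction` are all hypothetical names, and the real integration would need to read sockets from the component instead.

```python
from dataclasses import dataclass
from typing import Any

_EMPTY = object()  # stand-in for Haystack's "no default value" sentinel


@dataclass
class FakeSocket:
    """Stand-in for a Haystack input socket (illustration only)."""

    type: Any
    default_value: Any = _EMPTY


def split_inputs(sockets, socket_mapping, reads, bound_params):
    """Return (required, optional) inputs not already covered by state or bound values."""
    required, optional = {}, {}
    for socket_name, socket in sockets.items():
        field = socket_mapping.get(socket_name, socket_name)
        if field in reads or field in bound_params:
            continue
        target = required if socket.default_value is _EMPTY else optional
        target[field] = socket.type
    return required, optional


class FrozenInputsAction:
    """Toy action whose inputs are computed once, in the constructor."""

    def __init__(self, sockets, reads, bound_params=None):
        self._bound_params = dict(bound_params or {})
        # Computed once here; the property below never recomputes it.
        self._inputs = split_inputs(sockets, {}, reads, self._bound_params)

    @property
    def inputs(self):
        return self._inputs


sockets = {
    "query_embedding": FakeSocket(list),
    "top_k": FakeSocket(int, 10),
    "filters": FakeSocket(dict),
}
action = FrozenInputsAction(sockets, reads=["query_embedding"], bound_params={"filters": {}})
```

The trade-off is that any later change to the bound parameters would not be reflected in `inputs`, which is only acceptable if binding after construction is disallowed or produces a new object.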

Contributor

Yeah, .bind() should return a copy of the object so the .inputs() are frozen IMO. That said, this should take in bound_inputs in the constructor, and we don't need a method? You don't really want properties dynamically computed like this; it's hard to debug and a bit iffy IMO.

        """Return dictionaries of required and optional inputs."""
        required_inputs, optional_inputs = {}, {}
        for socket_name, input_socket in self._component.__haystack_input__._sockets_dict.items():
            state_field_name = self._socket_mapping.get(socket_name, socket_name)

            if state_field_name in self.reads:
                continue
            elif state_field_name in self._bound_params:
                continue

            if input_socket.default_value == haystack_empty:
                required_inputs[state_field_name] = input_socket.type
            else:
                optional_inputs[state_field_name] = input_socket.type

        return required_inputs, optional_inputs

    def run(self, state: State, **run_kwargs) -> dict[str, Any]:
        """Call the Haystack `Component.run()` method. It returns a dictionary
        of results with mapping {socket_name: value}.

        Values come from 3 sources:
        - bound parameters (from HaystackAction instantiation, or by using `.bind()`)
        - state (from previous actions)
        - run_kwargs (inputs from `Application.run()`)
        """
        values = {}

        # here, precedence matters. Alternatively, we could unpack all dictionaries at once
        # which would throw an error for key collisions
        for param, value in self._bound_params.items():
            values[param] = value

        for param in self.reads:
            values[param] = state[param]

        for param, value in run_kwargs.items():
            values[param] = value

        return self._component.run(**values)

    def update(self, result: dict, state: State) -> State:
        """Update the state using the results of `Component.run()`."""
        state_update = {}
        for socket_name, value in result.items():
            state_field_name = self._socket_mapping.get(socket_name, socket_name)
            if state_field_name in self.writes:
                state_update[state_field_name] = value

        return state.update(**state_update)

    def bind(self, **kwargs):
Contributor

So bind is only (currently) for functions, actions don't have it. Might be a bit confusing here, but if we like it, it should really be at the Action level, as that makes a lot of sense IMO. Then it could just do this.

Contributor

Otherwise we can make this a more specific name to not conflict
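
If `bind` were moved to the action level, one possible shape is a non-mutating method that returns a copy, in line with the "return a copy" suggestion made earlier in this review. This is only a sketch: `BindableAction` is a hypothetical toy class, not Burr's `Action`, and Burr may settle on a different design.

```python
import copy
from typing import Any, Optional


class BindableAction:
    """Toy action showing a `bind()` that leaves the original untouched."""

    def __init__(self, name: str, bound_params: Optional[dict[str, Any]] = None):
        self.name = name
        self._bound_params = dict(bound_params or {})

    def bind(self, **kwargs: Any) -> "BindableAction":
        """Return a new action with extra bound parameters."""
        new_action = copy.copy(self)
        # Build a fresh dict so the copy does not share state with `self`.
        new_action._bound_params = {**self._bound_params, **kwargs}
        return new_action


base = BindableAction("retrieve_documents")
bound = base.bind(foo="bar")
```

Here `base` keeps its empty bound parameters while `bound` carries `foo`, so anything derived from the bound parameters can be treated as fixed per instance.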

Collaborator Author @zilto, Oct 17, 2024

If I understand, the pattern is to only allow bound_params at init? This affects the Component.run() method

This would be ok:

haystack_retrieve_documents = HaystackAction(
    component=InMemoryEmbeddingRetriever(InMemoryDocumentStore()),
    name="retrieve_documents",
    reads=["query_embedding"],
    writes=["documents"],
    bound_params={"foo": "bar"},
)

This would not be ok?

haystack_retrieve_documents = HaystackAction(
    component=InMemoryEmbeddingRetriever(InMemoryDocumentStore()),
    name="retrieve_documents",
    reads=["query_embedding"],
    writes=["documents"],
)
haystack_retrieve_documents.bind(foo="bar")

I should simply remove the .bind() method?

Contributor

I think so, or call it something else? Alternatively, we can make bind work at the action level. Currently it's only for functions, and class-based actions don't have it.

        """Bind a parameter for the `Component.run()` call."""
        self._bound_params.update(**kwargs)
        return self


def _socket_name_mapping(pipeline) -> dict[str, str]:
    """Map socket names to a single state field name.

    In Haystack, components communicate via sockets. A socket called
    "embedding" in one component can be renamed to "query_embedding" when
    passed to another component.

    In Burr, there is a single state object so we need a mapping to resolve
    that `embedding` and `query_embedding` point to the same value. This function
    creates a mapping {socket_name: state_field} to rename sockets when creating
    the Burr `Graph`.
    """
    sockets_connections = [
        (edge_data["from_socket"].name, edge_data["to_socket"].name)
        for _, _, edge_data in pipeline.graph.edges.data()
    ]
    mapping = {}

    for from_, to in sockets_connections:
        if from_ not in mapping:
            mapping[from_] = {from_}
        mapping[from_].add(to)

        if to not in mapping:
            mapping[to] = {to}
        mapping[to].add(from_)

    result = {}
    for key, values in mapping.items():
        unique_name = min(values)
        result[key] = unique_name

    return result


def _connected_inputs(pipeline) -> dict[str, list[str]]:
    """Get all input sockets that are connected to other components."""
    return {
        name: [
            socket.name
            for socket in data.get("input_sockets", {}).values()
            if socket.is_variadic or socket.senders
        ]
        for name, data in pipeline.graph.nodes(data=True)
    }


def _connected_outputs(pipeline) -> dict[str, list[str]]:
    """Get all output sockets that are connected to other components."""
    return {
        name: [
            socket.name for socket in data.get("output_sockets", {}).values() if socket.receivers
        ]
        for name, data in pipeline.graph.nodes(data=True)
    }


def haystack_pipeline_to_burr_graph(pipeline: Pipeline) -> Graph:
    """Convert a Haystack `Pipeline` to a Burr `Graph`.

    From the Haystack `Pipeline`, we can easily retrieve transitions.
    For actions, we need to create `HaystackAction` from components
    and map their sockets to Burr state fields
    """
    socket_mapping = _socket_name_mapping(pipeline)
    connected_inputs = _connected_inputs(pipeline)
    connected_outputs = _connected_outputs(pipeline)

    transitions = [(from_, to) for from_, to, _ in pipeline.graph.edges]

    actions = []
    for component_name, component in pipeline.walk():
        inputs_from_state = [
            socket_mapping[socket_name] for socket_name in connected_inputs[component_name]
        ]
        outputs_to_state = [
            socket_mapping[socket_name] for socket_name in connected_outputs[component_name]
        ]

        haystack_action = HaystackAction(
            name=component_name,
            component=component,
            reads=inputs_from_state,
            writes=outputs_to_state,
        )
        actions.append(haystack_action)

    return GraphBuilder().with_actions(*actions).with_transitions(*transitions).build()
5 changes: 5 additions & 0 deletions examples/haystack-integration/README.md
@@ -0,0 +1,5 @@
# Haystack + Burr integration

Haystack is a Python library to build AI pipelines. It assembles `Component` objects into a `Pipeline`, which is a graph of operations. One benefit of Haystack is that it provides many pre-built components to manage documents and interact with LLMs.

This notebook shows how to convert a Haystack `Component` into a Burr `Action` and a `Pipeline` into a `Graph`. This allows you to integrate Haystack with Burr and leverage other Burr and Burr UI features!
824 changes: 824 additions & 0 deletions examples/haystack-integration/notebook.ipynb

Large diffs are not rendered by default.

Binary file added examples/haystack-integration/statemachine.png