Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added Select to define transitions #361

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion burr/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from burr.core.action import Action, Condition, Result, action, default, expr, when
from burr.core.action import Action, Condition, Result, Select, action, default, expr, when
from burr.core.application import (
Application,
ApplicationBuilder,
Expand All @@ -18,6 +18,7 @@
"default",
"expr",
"Result",
"Select",
"State",
"when",
]
45 changes: 45 additions & 0 deletions burr/core/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
List,
Optional,
Protocol,
Sequence,
Tuple,
TypeVar,
Union,
Expand Down Expand Up @@ -378,6 +379,50 @@ def __invert__(self):
# exists = Condition.exists


# TODO type `resolver` to prevent user-facing type-mismatch
# e.g., a user provided `def foo(state: State, actions: list)`
# would be too restrictive for a `Sequence` type
class Select(Function):
def __init__(
self,
keys: List[str],
resolver: Callable[[State, Sequence[Action]], str],
name: str = None,
):
self._keys = keys
self._resolver = resolver
self._name = name
# TODO add a `default` kwarg;
# could an Action, action_name: str, or action_idx: int
# `default` value could be returned if `_resolver` returns None

@property
def name(self) -> str:
return self._name

@property
def reads(self) -> list[str]:
return self._keys

@property
def resolver(self) -> Callable[[State, Sequence[Action]], str]:
return self._resolver

def __repr__(self) -> str:
return f"select: {self._name}"

def _validate(self, state: State):
missing_keys = set(self._keys) - set(state.keys())
if missing_keys:
raise ValueError(
f"Missing keys in state required by condition: {self} {', '.join(missing_keys)}"
)

def run(self, state: State, possible_actions: Sequence[Action]) -> str:
self._validate(state)
return self._resolver(state, possible_actions)


class Result(Action):
def __init__(self, *fields: str):
"""Represents a result action. This is purely a convenience class to
Expand Down
5 changes: 4 additions & 1 deletion burr/core/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
List,
Literal,
Optional,
Sequence,
Set,
Tuple,
TypeVar,
Expand All @@ -33,6 +34,7 @@
Condition,
Function,
Reducer,
Select,
SingleStepAction,
SingleStepStreamingAction,
StreamingAction,
Expand Down Expand Up @@ -2005,7 +2007,8 @@ def with_actions(
def with_transitions(
self,
*transitions: Union[
Tuple[Union[str, list[str]], str], Tuple[Union[str, list[str]], str, Condition]
Tuple[Union[str, Sequence[str]], Union[str, Sequence[str]]],
Tuple[Union[str, Sequence[str]], Union[str, Sequence[str]], Union[Condition, Select]],
],
) -> "ApplicationBuilder":
"""Adds transitions to the application. Transitions are specified as tuples of either:
Expand Down
41 changes: 30 additions & 11 deletions burr/core/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import inspect
import logging
import pathlib
from typing import Any, Callable, List, Literal, Optional, Set, Tuple, Union
from typing import Any, Callable, List, Literal, Optional, Sequence, Set, Tuple, Union

from burr import telemetry
from burr.core.action import Action, Condition, create_action, default
from burr.core.action import Action, Condition, Select, create_action, default
from burr.core.state import State
from burr.core.validation import BASE_ERROR_MESSAGE, assert_set

Expand Down Expand Up @@ -118,6 +118,13 @@ def get_next_node(
return self._action_map[entrypoint]
possibilities = self._adjacency_map[prior_step]
for next_action, condition in possibilities:
# When `Select` is used, all possibilities have the same `condition` attached.
# Hitting a `Select` will necessarily exit the for loop
if isinstance(condition, Select):
possible_actions = [self._action_map[p[0]] for p in possibilities]
selected_action = condition.run(state, possible_actions)
return self._action_map[selected_action]

if condition.run(state)[Condition.KEY]:
return self._action_map[next_action]
return None
Expand Down Expand Up @@ -235,7 +242,7 @@ class GraphBuilder:

def __init__(self):
"""Initializes the graph builder."""
self.transitions: Optional[List[Tuple[str, str, Condition]]] = None
self.transitions: Optional[List[Tuple[str, str, Union[Condition, Select]]]] = None
self.actions: Optional[List[Action]] = None

def with_actions(
Expand Down Expand Up @@ -269,7 +276,8 @@ def with_actions(
def with_transitions(
self,
*transitions: Union[
Tuple[Union[str, list[str]], str], Tuple[Union[str, list[str]], str, Condition]
Tuple[Union[str, Sequence[str]], Union[str, Sequence[str]]],
Tuple[Union[str, Sequence[str]], Union[str, Sequence[str]], Union[Condition, Select]],
],
) -> "GraphBuilder":
"""Adds transitions to the graph. Transitions are specified as tuples of either:
Expand All @@ -291,14 +299,25 @@ def with_transitions(
condition = conditions[0]
else:
condition = default
if not isinstance(from_, list):
# check required because issubclass(str, Sequence) == True
if isinstance(from_, Sequence) and not isinstance(from_, str):
from_ = [*from_]
else:
from_ = [from_]
for action in from_:
if not isinstance(action, str):
raise ValueError(f"Transition source must be a string, not {action}")
if not isinstance(to_, str):
raise ValueError(f"Transition target must be a string, not {to_}")
self.transitions.append((action, to_, condition))
if isinstance(to_, Sequence) and not isinstance(to_, str):
if not isinstance(condition, Select):
raise ValueError(
"Transition with multiple targets require a `Select` condition."
)
else:
to_ = [to_]
for source in from_:
for target in to_:
if not isinstance(source, str):
raise ValueError(f"Transition source must be a string, not {source}")
if not isinstance(target, str):
raise ValueError(f"Transition target must be a string, not {to_}")
self.transitions.append((source, target, condition))
return self

def build(self) -> Graph:
Expand Down
34 changes: 34 additions & 0 deletions tests/core/test_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
Function,
Input,
Result,
Select,
SingleStepAction,
SingleStepStreamingAction,
StreamingAction,
Expand Down Expand Up @@ -200,6 +201,39 @@ def test_condition_lmda():
# assert cond.run(State({"foo" : "bar"})) == {Condition.KEY: False}


def test_select_constant():
select = Select([], resolver=lambda *args: "foo")
selected_action = select.run(State(), [])

assert selected_action == "foo"


def test_select_determistic():
@action(reads=[], writes=[])
def bar(state):
return state

@action(reads=[], writes=[])
def baz(state):
return state

def length_resolver(state: State, actions: list[Action]) -> str:
foo = state["foo"]
action_idx = len(foo) % len(actions)
return actions[action_idx].name

foo1 = "len=3" # % 2 = 1
foo2 = "len_is_8" # % 2 = 0
actions = [create_action(bar, "bar"), create_action(baz, "baz")]
select = Select(["foo"], resolver=length_resolver)

selected_1 = select.run(State({"foo": foo1}), possible_actions=actions)
assert selected_1 == actions[len(foo1) % len(actions)].name

selected_2 = select.run(State({"foo": foo2}), possible_actions=actions)
assert selected_2 == actions[len(foo2) % len(actions)].name


def test_result():
result = Result("foo", "bar")
assert result.run(State({"foo": "baz", "bar": "qux", "baz": "quux"})) == {
Expand Down
Loading