Skip to content

Commit

Permalink
Adds default reads/writes to burr actions
Browse files Browse the repository at this point in the history
This allows you to specify defaults if your action does not write. In
the majority of cases they will be none, but this allows simple (static)
arbitrary values. This specifically helps with the branching case --
e.g. where you have two options, and want to null out anything it
doesn't write. For instance, an error and a result -- you'll only ever
produce one or the other.

This works both in the function and class-based approaches -- in the
function-based it is part of the two decorators
(@action/@streaming_action). In the class-based it is part of the class,
overriding the default_reads and default_writes property function

We add a bunch of new tests for default (as the code to handle multiple
action types is fairly dispersed, for now), and also make the naming of
the other tests/content more consistent.
  • Loading branch information
elijahbenizzy committed Jul 16, 2024
1 parent bb2c446 commit ce6ca20
Show file tree
Hide file tree
Showing 3 changed files with 426 additions and 65 deletions.
79 changes: 72 additions & 7 deletions burr/core/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ def reads(self) -> list[str]:
"""
pass

@property
def default_reads(self) -> Dict[str, Any]:
return {}

@abc.abstractmethod
def run(self, state: State, **run_kwargs) -> dict:
"""Runs the function on the given state and returns the result.
Expand Down Expand Up @@ -122,6 +126,10 @@ def writes(self) -> list[str]:
"""
pass

@property
def default_writes(self) -> Dict[str, Any]:
return {}

@abc.abstractmethod
def update(self, result: dict, state: State) -> State:
pass
Expand Down Expand Up @@ -484,6 +492,8 @@ def __init__(
fn: Callable,
reads: List[str],
writes: List[str],
default_reads: Dict[str, Any] = None,
default_writes: Dict[str, Any] = None,
bound_params: dict = None,
):
"""Instantiates a function-based action with the given function, reads, and writes.
Expand All @@ -499,11 +509,21 @@ def __init__(
self._writes = writes
self._bound_params = bound_params if bound_params is not None else {}
self._inputs = _get_inputs(self._bound_params, self._fn)
self._default_reads = default_reads if default_reads is not None else {}
self._default_writes = default_writes if default_writes is not None else {}

@property
def fn(self) -> Callable:
return self._fn

@property
def default_reads(self) -> Dict[str, Any]:
return self._default_reads

@property
def default_writes(self) -> Dict[str, Any]:
return self._default_writes

@property
def reads(self) -> list[str]:
return self._reads
Expand All @@ -526,7 +546,12 @@ def with_params(self, **kwargs: Any) -> "FunctionBasedAction":
:return:
"""
return FunctionBasedAction(
self._fn, self._reads, self._writes, {**self._bound_params, **kwargs}
self._fn,
self._reads,
self._writes,
self.default_reads,
self._default_writes,
{**self._bound_params, **kwargs},
)

def run_and_update(self, state: State, **run_kwargs) -> tuple[dict, State]:
Expand Down Expand Up @@ -918,6 +943,8 @@ def __init__(
],
reads: List[str],
writes: List[str],
default_reads: Optional[Dict[str, Any]] = None,
default_writes: Optional[Dict[str, Any]] = None,
bound_params: dict = None,
):
"""Instantiates a function-based streaming action with the given function, reads, and writes.
Expand All @@ -931,6 +958,8 @@ def __init__(
self._fn = fn
self._reads = reads
self._writes = writes
self._default_reads = default_reads if default_reads is not None else {}
self._default_writes = default_writes if default_writes is not None else {}
self._bound_params = bound_params if bound_params is not None else {}

async def _a_stream_run_and_update(
Expand All @@ -957,6 +986,12 @@ def reads(self) -> list[str]:
def writes(self) -> list[str]:
return self._writes

def default_writes(self) -> Dict[str, Any]:
return self._default_writes

def default_reads(self) -> Dict[str, Any]:
return self._default_reads

@property
def streaming(self) -> bool:
return True
Expand All @@ -969,7 +1004,12 @@ def with_params(self, **kwargs: Any) -> "FunctionBasedStreamingAction":
:return:
"""
return FunctionBasedStreamingAction(
self._fn, self._reads, self._writes, {**self._bound_params, **kwargs}
self._fn,
self._reads,
self._writes,
self._default_reads,
self._default_writes,
{**self._bound_params, **kwargs},
)

@property
Expand Down Expand Up @@ -999,7 +1039,10 @@ def bind(self, **kwargs: Any) -> Self:
...


def copy_func(f: types.FunctionType) -> types.FunctionType:
T = TypeVar("T", bound=types.FunctionType)


def copy_func(f: T) -> T:
"""Copies a function. This is used internally to bind parameters to a function
so we don't accidentally overwrite them.
Expand Down Expand Up @@ -1033,7 +1076,12 @@ def my_action(state: State, z: int) -> tuple[dict, State]:
return self


def action(reads: List[str], writes: List[str]) -> Callable[[Callable], FunctionRepresentingAction]:
def action(
reads: List[str],
writes: List[str],
default_reads: Dict[str, Any] = None,
default_writes: Dict[str, Any] = None,
) -> Callable[[Callable], FunctionRepresentingAction]:
"""Decorator to create a function-based action. This is user-facing.
Note that, in the future, with typed state, we may not need this for
all cases.
Expand All @@ -1044,19 +1092,34 @@ def action(reads: List[str], writes: List[str]) -> Callable[[Callable], Function
:param reads: Items to read from the state
:param writes: Items to write to the state
:param default_reads: Default values for reads. If nothing upstream produces these, they will
be filled automatically with the default values. Leaving blank will have no default values.
:param default_writes: Default values for writes. If your action does not write these, they will be
filled automatically. Leaving blank this will have no default values.
:return: The decorator to assign the function as an action
"""
default_reads = default_reads if default_reads is not None else {}
default_writes = default_writes if default_writes is not None else {}

def decorator(fn) -> FunctionRepresentingAction:
setattr(fn, FunctionBasedAction.ACTION_FUNCTION, FunctionBasedAction(fn, reads, writes))
setattr(
fn,
FunctionBasedAction.ACTION_FUNCTION,
FunctionBasedAction(
fn, reads, writes, default_reads=default_reads, default_writes=default_writes
),
)
setattr(fn, "bind", types.MethodType(bind, fn))
return fn

return decorator


def streaming_action(
reads: List[str], writes: List[str]
reads: List[str],
writes: List[str],
default_reads: Optional[Dict[str, Any]] = None,
default_writes: Optional[Dict[str, Any]] = None,
) -> Callable[[Callable], FunctionRepresentingAction]:
"""Decorator to create a streaming function-based action. This is user-facing.
Expand Down Expand Up @@ -1091,13 +1154,15 @@ def streaming_response(state: State) -> Generator[dict, None, tuple[dict, State]
return {'response': full_response}, state.update(response=full_response)
"""
default_reads = default_reads if default_reads is not None else {}
default_writes = default_writes if default_writes is not None else {}

def wrapped(fn) -> FunctionRepresentingAction:
fn = copy_func(fn)
setattr(
fn,
FunctionBasedAction.ACTION_FUNCTION,
FunctionBasedStreamingAction(fn, reads, writes),
FunctionBasedStreamingAction(fn, reads, writes, default_reads, default_writes),
)
setattr(fn, "bind", types.MethodType(bind, fn))
return fn
Expand Down
60 changes: 44 additions & 16 deletions burr/core/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,19 @@ def _adjust_single_step_output(output: Union[State, Tuple[dict, State]], action_
_raise_fn_return_validation_error(output, action_name)


def _apply_defaults(state: State, defaults: Dict[str, Any]) -> State:
state_update = {}
state_to_use = state
# We really don't need to short-circuit but I want to avoid the update function
# So we might as well
if len(defaults) > 0:
for key, value in defaults.items():
if key not in state:
state_update[key] = value
state_to_use = state.update(**state_update)
return state_to_use


def _run_function(function: Function, state: State, inputs: Dict[str, Any], name: str) -> dict:
"""Runs a function, returning the result of running the function.
Note this restricts the keys in the state to only those that the
Expand All @@ -100,6 +113,7 @@ def _run_function(function: Function, state: State, inputs: Dict[str, Any], name
"instead...)"
)
state_to_use = state.subset(*function.reads)
state_to_use = _apply_defaults(state_to_use, function.default_reads)
function.validate_inputs(inputs)
result = function.run(state_to_use, **inputs)
_validate_result(result, name)
Expand All @@ -112,6 +126,7 @@ async def _arun_function(
"""Runs a function, returning the result of running the function.
Async version of the above."""
state_to_use = state.subset(*function.reads)
state_to_use = _apply_defaults(state_to_use, function.default_reads)
function.validate_inputs(inputs)
result = await function.run(state_to_use, **inputs)
_validate_result(result, name)
Expand Down Expand Up @@ -177,6 +192,7 @@ def _run_reducer(reducer: Reducer, state: State, result: dict, name: str) -> Sta
f"Action {name} attempted to write to keys {extra_keys} "
f"that it did not declare. It declared: ({reducer.writes})!"
)
new_state = _apply_defaults(new_state, reducer.default_writes)
_validate_reducer_writes(reducer, new_state, name)
return _state_update(state, new_state)

Expand Down Expand Up @@ -229,9 +245,11 @@ def _run_single_step_action(
"""
# TODO -- guard all reads/writes with a subset of the state
action.validate_inputs(inputs)
state = _apply_defaults(state, action.default_reads)
result, new_state = _adjust_single_step_output(
action.run_and_update(state, **inputs), action.name
)
new_state = _apply_defaults(new_state, action.default_writes)
_validate_result(result, action.name)
out = result, _state_update(state, new_state)
_validate_result(result, action.name)
Expand All @@ -245,6 +263,7 @@ def _run_single_step_streaming_action(
"""Runs a single step streaming action. This API is internal-facing.
This normalizes + validates the output."""
action.validate_inputs(inputs)
state = _apply_defaults(state, action.default_reads)
generator = action.stream_run_and_update(state, **inputs)
result = None
state_update = None
Expand All @@ -265,15 +284,33 @@ def _run_single_step_streaming_action(
f"statement must be a tuple of (result, state_update). For example, yield dict(foo='bar'), state.update(foo='bar')"
)
_validate_result(result, action.name)
state_update = _apply_defaults(state_update, action.default_writes)
_validate_reducer_writes(action, state_update, action.name)
yield result, state_update


async def _arun_single_step_action(
action: SingleStepAction, state: State, inputs: Optional[Dict[str, Any]]
) -> Tuple[dict, State]:
"""Runs a single step action in async. See the synchronous version for more details."""
state_to_use = state
state_to_use = _apply_defaults(state_to_use, action.default_reads)
action.validate_inputs(inputs)
result, new_state = _adjust_single_step_output(
await action.run_and_update(state_to_use, **inputs), action.name
)
new_state = _apply_defaults(new_state, action.default_writes)
_validate_result(result, action.name)
_validate_reducer_writes(action, new_state, action.name)
return result, _state_update(state, new_state)


async def _arun_single_step_streaming_action(
action: SingleStepStreamingAction, state: State, inputs: Optional[Dict[str, Any]]
) -> AsyncGenerator[Tuple[dict, Optional[State]], None]:
"""Runs a single step streaming action in async. See the synchronous version for more details."""
action.validate_inputs(inputs)
state = _apply_defaults(state, action.default_reads)
generator = action.stream_run_and_update(state, **inputs)
result = None
state_update = None
Expand All @@ -294,6 +331,7 @@ async def _arun_single_step_streaming_action(
f"statement must be a tuple of (result, state_update). For example, yield dict(foo='bar'), state.update(foo='bar')"
)
_validate_result(result, action.name)
state_update = _apply_defaults(state_update, action.default_writes)
_validate_reducer_writes(action, state_update, action.name)
# TODO -- add guard against zero-length stream
yield result, state_update
Expand All @@ -310,6 +348,7 @@ def _run_multi_step_streaming_action(
This peeks ahead by one so we know when this is done (and when to validate).
"""
action.validate_inputs(inputs)
state = _apply_defaults(state, action.default_reads)
generator = action.stream_run(state, **inputs)
result = None
for item in generator:
Expand All @@ -320,8 +359,9 @@ def _run_multi_step_streaming_action(
result = item
if next_result is not None:
yield next_result, None
state_update = _run_reducer(action, state, result, action.name)
_validate_result(result, action.name)
state_update = _run_reducer(action, state, result, action.name)
state_update = _apply_defaults(state_update, action.default_writes)
_validate_reducer_writes(action, state_update, action.name)
yield result, state_update

Expand All @@ -331,6 +371,7 @@ async def _arun_multi_step_streaming_action(
) -> AsyncGenerator[Tuple[dict, Optional[State]], None]:
"""Runs a multi-step streaming action in async. See the synchronous version for more details."""
action.validate_inputs(inputs)
state = _apply_defaults(state, action.default_reads)
generator = action.stream_run(state, **inputs)
result = None
async for item in generator:
Expand All @@ -341,26 +382,13 @@ async def _arun_multi_step_streaming_action(
result = item
if next_result is not None:
yield next_result, None
state_update = _run_reducer(action, state, result, action.name)
_validate_result(result, action.name)
state_update = _run_reducer(action, state, result, action.name)
state_update = _apply_defaults(state_update, action.default_writes)
_validate_reducer_writes(action, state_update, action.name)
yield result, state_update


async def _arun_single_step_action(
action: SingleStepAction, state: State, inputs: Optional[Dict[str, Any]]
) -> Tuple[dict, State]:
"""Runs a single step action in async. See the synchronous version for more details."""
state_to_use = state
action.validate_inputs(inputs)
result, new_state = _adjust_single_step_output(
await action.run_and_update(state_to_use, **inputs), action.name
)
_validate_result(result, action.name)
_validate_reducer_writes(action, new_state, action.name)
return result, _state_update(state, new_state)


@dataclasses.dataclass
class ApplicationGraph(Graph):
"""User-facing representation of the state machine. This has
Expand Down
Loading

0 comments on commit ce6ca20

Please sign in to comment.