Skip to content

Commit

Permalink
Adds default reads/writes to burr actions
Browse files Browse the repository at this point in the history
This allows you to specify defaults if your action does not write. In
the majority of cases they will be none, but this allows simple (static)
arbitrary values. This specifically helps with the branching case --
e.g. where you have two options, and want to null out anything it
doesn't write. For instance, an error and a result -- you'll only ever
produce one or the other.

This works both in the function and class-based approaches -- in the
function-based it is part of the two decorators
(@action/@streaming_action). In the class-based it is part of the class,
overriding the default_reads and default_writes property function

We add a bunch of new tests for default (as the code to handle multiple
action types is fairly dispersed, for now), and also make the naming of
the other tests/content more consistent.

Note that this does not currently work with settings defaults to
append/increment operations -- it will produce strange behavior.
This is documented in all appropriate signatures.

This also does not work (or even make sense) in the case that the
function writes a default that it also reads. In that case, it will
clobber the current value with the write value. To avoid this,
we just error out if that is the case beforehand.
  • Loading branch information
elijahbenizzy committed Jul 17, 2024
1 parent bb2c446 commit 308be71
Show file tree
Hide file tree
Showing 7 changed files with 753 additions and 75 deletions.
152 changes: 144 additions & 8 deletions burr/core/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import sys
import types
import typing
from abc import ABC
from typing import (
Any,
AsyncGenerator,
Expand Down Expand Up @@ -42,6 +43,16 @@ def reads(self) -> list[str]:
"""
pass

@property
def default_reads(self) -> Dict[str, Any]:
"""Default values to read from state if they are not there already.
This just fills out the gaps in state. This must be a subset
of the ``reads`` value.
:return:
"""
return {}

@abc.abstractmethod
def run(self, state: State, **run_kwargs) -> dict:
"""Runs the function on the given state and returns the result.
Expand Down Expand Up @@ -122,18 +133,68 @@ def writes(self) -> list[str]:
"""
pass

@property
def default_writes(self) -> Dict[str, Any]:
"""Default state writes for the reducer. If nothing writes this field from within
the reducer, then this will be written. Note that this is not (currently)
intended to work with append/increment operations.
This must be a subset of the ``writes`` value.
:return: A key/value dictionary of default writes.
"""
return {}

@abc.abstractmethod
def update(self, result: dict, state: State) -> State:
pass


class Action(Function, Reducer, abc.ABC):
class _PostValidator(abc.ABCMeta):
"""Metaclass to allow for __post_init__ to be called after __init__.
While this is general we're keeping it here for now as it is only used
by the Action class. This enables us to ensure that the default_reads are correct.
"""

def __call__(cls, *args, **kwargs):
instance = super().__call__(*args, **kwargs)
if post := getattr(cls, "__post_init__", None):
post(instance)
return instance


class Action(Function, Reducer, ABC, metaclass=_PostValidator):
def __init__(self):
"""Represents an action in a state machine. This is the base class from which
actions extend. Note that this class needs to have a name set after the fact.
"""
self._name = None

def __post_init__(self):
self._validate_defaults()

def _validate_defaults(self):
reads = set(self.reads)
missing_default_reads = {key for key in self.default_reads.keys() if key not in reads}
if missing_default_reads:
raise ValueError(
f"The following default state reads are not in the set of reads for action: {self}: {', '.join(missing_default_reads)}. "
f"Every read in default_reads must be in the reads list."
)
writes = self.writes
missing_default_writes = {key for key in self.default_writes.keys() if key not in writes}
if missing_default_writes:
raise ValueError(
f"The following default state writes are not in the set of writes for action: {self}: {', '.join(missing_default_writes)}. "
f"Every write in default_writes must be in the writes list."
)
default_writes_also_in_reads = {key for key in self.default_writes.keys() if key in reads}
if default_writes_also_in_reads:
raise ValueError(
f"The following default state writes are also in the reads for action: {self}: {', '.join(default_writes_also_in_reads)}. "
f"Every write in default_writes must not be in the reads list -- this leads to undefined behavior."
)

def with_name(self, name: str) -> Self:
"""Returns a copy of the given action with the given name. Why do we need this?
We instantiate actions without names, and then set them later. This is a way to
Expand Down Expand Up @@ -484,6 +545,8 @@ def __init__(
fn: Callable,
reads: List[str],
writes: List[str],
default_reads: Dict[str, Any] = None,
default_writes: Dict[str, Any] = None,
bound_params: dict = None,
):
"""Instantiates a function-based action with the given function, reads, and writes.
Expand All @@ -499,11 +562,21 @@ def __init__(
self._writes = writes
self._bound_params = bound_params if bound_params is not None else {}
self._inputs = _get_inputs(self._bound_params, self._fn)
self._default_reads = default_reads if default_reads is not None else {}
self._default_writes = default_writes if default_writes is not None else {}

@property
def fn(self) -> Callable:
return self._fn

@property
def default_reads(self) -> Dict[str, Any]:
return self._default_reads

@property
def default_writes(self) -> Dict[str, Any]:
return self._default_writes

@property
def reads(self) -> list[str]:
return self._reads
Expand All @@ -526,7 +599,12 @@ def with_params(self, **kwargs: Any) -> "FunctionBasedAction":
:return:
"""
return FunctionBasedAction(
self._fn, self._reads, self._writes, {**self._bound_params, **kwargs}
self._fn,
self._reads,
self._writes,
self.default_reads,
self._default_writes,
{**self._bound_params, **kwargs},
)

def run_and_update(self, state: State, **run_kwargs) -> tuple[dict, State]:
Expand Down Expand Up @@ -918,6 +996,8 @@ def __init__(
],
reads: List[str],
writes: List[str],
default_reads: Optional[Dict[str, Any]] = None,
default_writes: Optional[Dict[str, Any]] = None,
bound_params: dict = None,
):
"""Instantiates a function-based streaming action with the given function, reads, and writes.
Expand All @@ -931,6 +1011,8 @@ def __init__(
self._fn = fn
self._reads = reads
self._writes = writes
self._default_reads = default_reads if default_reads is not None else {}
self._default_writes = default_writes if default_writes is not None else {}
self._bound_params = bound_params if bound_params is not None else {}

async def _a_stream_run_and_update(
Expand All @@ -957,6 +1039,14 @@ def reads(self) -> list[str]:
def writes(self) -> list[str]:
return self._writes

@property
def default_writes(self) -> Dict[str, Any]:
return self._default_writes

@property
def default_reads(self) -> Dict[str, Any]:
return self._default_reads

@property
def streaming(self) -> bool:
return True
Expand All @@ -969,7 +1059,12 @@ def with_params(self, **kwargs: Any) -> "FunctionBasedStreamingAction":
:return:
"""
return FunctionBasedStreamingAction(
self._fn, self._reads, self._writes, {**self._bound_params, **kwargs}
self._fn,
self._reads,
self._writes,
self._default_reads,
self._default_writes,
{**self._bound_params, **kwargs},
)

@property
Expand Down Expand Up @@ -999,7 +1094,10 @@ def bind(self, **kwargs: Any) -> Self:
...


def copy_func(f: types.FunctionType) -> types.FunctionType:
T = TypeVar("T", bound=types.FunctionType)


def copy_func(f: T) -> T:
"""Copies a function. This is used internally to bind parameters to a function
so we don't accidentally overwrite them.
Expand Down Expand Up @@ -1033,7 +1131,12 @@ def my_action(state: State, z: int) -> tuple[dict, State]:
return self


def action(reads: List[str], writes: List[str]) -> Callable[[Callable], FunctionRepresentingAction]:
def action(
reads: List[str],
writes: List[str],
default_reads: Dict[str, Any] = None,
default_writes: Dict[str, Any] = None,
) -> Callable[[Callable], FunctionRepresentingAction]:
"""Decorator to create a function-based action. This is user-facing.
Note that, in the future, with typed state, we may not need this for
all cases.
Expand All @@ -1044,19 +1147,38 @@ def action(reads: List[str], writes: List[str]) -> Callable[[Callable], Function
:param reads: Items to read from the state
:param writes: Items to write to the state
:param default_reads: Default values for reads. If nothing upstream produces these, they will
be filled automatically. This is equivalent to adding
``state = state.update(**{key: value for key, value in default_reads.items() if key not in state})``
at the beginning of your function.
:param default_writes: Default values for writes. If the action's state update does not write to this,
they will be filled automatically with the default values. Leaving blank will have no default values.
This is equivalent to adding state = state.update(***deafult_writes) at the beginning of the function.
Note that this will not work as intended with append/increment operations, so be careful.
:return: The decorator to assign the function as an action
"""
default_reads = default_reads if default_reads is not None else {}
default_writes = default_writes if default_writes is not None else {}

def decorator(fn) -> FunctionRepresentingAction:
setattr(fn, FunctionBasedAction.ACTION_FUNCTION, FunctionBasedAction(fn, reads, writes))
setattr(
fn,
FunctionBasedAction.ACTION_FUNCTION,
FunctionBasedAction(
fn, reads, writes, default_reads=default_reads, default_writes=default_writes
),
)
setattr(fn, "bind", types.MethodType(bind, fn))
return fn

return decorator


def streaming_action(
reads: List[str], writes: List[str]
reads: List[str],
writes: List[str],
default_reads: Optional[Dict[str, Any]] = None,
default_writes: Optional[Dict[str, Any]] = None,
) -> Callable[[Callable], FunctionRepresentingAction]:
"""Decorator to create a streaming function-based action. This is user-facing.
Expand Down Expand Up @@ -1090,14 +1212,28 @@ def streaming_response(state: State) -> Generator[dict, None, tuple[dict, State]
# return the final result
return {'response': full_response}, state.update(response=full_response)
:param reads: Items to read from the state
:param writes: Items to write to the state
:param default_reads: Default values for reads. If nothing upstream produces these, they will
be filled automatically. This is equivalent to adding
``state = state.update(**{key: value for key, value in default_reads.items() if key not in state})``
at the beginning of your function.
:param default_writes: Default values for writes. If the action's state update does not write to this,
they will be filled automatically with the default values. Leaving blank will have no default values.
This is equivalent to adding state = state.update(***deafult_writes) at the beginning of the function.
Note that this will not work as intended with append/increment operations, so be careful.
:return: The decorator to assign the function as an action
"""
default_reads = default_reads if default_reads is not None else {}
default_writes = default_writes if default_writes is not None else {}

def wrapped(fn) -> FunctionRepresentingAction:
fn = copy_func(fn)
setattr(
fn,
FunctionBasedAction.ACTION_FUNCTION,
FunctionBasedStreamingAction(fn, reads, writes),
FunctionBasedStreamingAction(fn, reads, writes, default_reads, default_writes),
)
setattr(fn, "bind", types.MethodType(bind, fn))
return fn
Expand Down
Loading

0 comments on commit 308be71

Please sign in to comment.