diff --git a/burr/core/action.py b/burr/core/action.py index 5f35cfe6..40b1f0bd 100644 --- a/burr/core/action.py +++ b/burr/core/action.py @@ -6,6 +6,7 @@ import sys import types import typing +from abc import ABC from typing import ( Any, AsyncGenerator, @@ -42,6 +43,16 @@ def reads(self) -> list[str]: """ pass + @property + def default_reads(self) -> Dict[str, Any]: + """Default values to read from state if they are not there already. + This just fills out the gaps in state. This must be a subset + of the ``reads`` value. + + :return: + """ + return {} + @abc.abstractmethod def run(self, state: State, **run_kwargs) -> dict: """Runs the function on the given state and returns the result. @@ -122,18 +133,68 @@ def writes(self) -> list[str]: """ pass + @property + def default_writes(self) -> Dict[str, Any]: + """Default state writes for the reducer. If nothing writes this field from within + the reducer, then this will be written. Note that this is not (currently) + intended to work with append/increment operations. + + This must be a subset of the ``writes`` value. + + :return: A key/value dictionary of default writes. + """ + return {} + @abc.abstractmethod def update(self, result: dict, state: State) -> State: pass -class Action(Function, Reducer, abc.ABC): +class _PostValidator(abc.ABCMeta): + """Metaclass to allow for __post_init__ to be called after __init__. + While this is general we're keeping it here for now as it is only used + by the Action class. This enables us to ensure that the default_reads are correct. + """ + + def __call__(cls, *args, **kwargs): + instance = super().__call__(*args, **kwargs) + if post := getattr(cls, "__post_init__", None): + post(instance) + return instance + + +class Action(Function, Reducer, ABC, metaclass=_PostValidator): def __init__(self): """Represents an action in a state machine. This is the base class from which actions extend. Note that this class needs to have a name set after the fact. """ self._name = None + def __post_init__(self): + self._validate_defaults() + + def _validate_defaults(self): + reads = set(self.reads) + missing_default_reads = {key for key in self.default_reads.keys() if key not in reads} + if missing_default_reads: + raise ValueError( + f"The following default state reads are not in the set of reads for action: {self}: {', '.join(missing_default_reads)}. " + f"Every read in default_reads must be in the reads list." + ) + writes = self.writes + missing_default_writes = {key for key in self.default_writes.keys() if key not in writes} + if missing_default_writes: + raise ValueError( + f"The following default state writes are not in the set of writes for action: {self}: {', '.join(missing_default_writes)}. " + f"Every write in default_writes must be in the writes list." + ) + default_writes_also_in_reads = {key for key in self.default_writes.keys() if key in reads} + if default_writes_also_in_reads: + raise ValueError( + f"The following default state writes are also in the reads for action: {self}: {', '.join(default_writes_also_in_reads)}. " + f"Every write in default_writes must not be in the reads list -- this leads to undefined behavior." + ) + def with_name(self, name: str) -> Self: """Returns a copy of the given action with the given name. Why do we need this? We instantiate actions without names, and then set them later. This is a way to @@ -484,6 +545,8 @@ def __init__( fn: Callable, reads: List[str], writes: List[str], + default_reads: Dict[str, Any] = None, + default_writes: Dict[str, Any] = None, bound_params: dict = None, ): """Instantiates a function-based action with the given function, reads, and writes. @@ -499,11 +562,21 @@ def __init__( self._writes = writes self._bound_params = bound_params if bound_params is not None else {} self._inputs = _get_inputs(self._bound_params, self._fn) + self._default_reads = default_reads if default_reads is not None else {} + self._default_writes = default_writes if default_writes is not None else {} @property def fn(self) -> Callable: return self._fn + @property + def default_reads(self) -> Dict[str, Any]: + return self._default_reads + + @property + def default_writes(self) -> Dict[str, Any]: + return self._default_writes + @property def reads(self) -> list[str]: return self._reads @@ -526,7 +599,12 @@ def with_params(self, **kwargs: Any) -> "FunctionBasedAction": :return: """ return FunctionBasedAction( - self._fn, self._reads, self._writes, {**self._bound_params, **kwargs} + self._fn, + self._reads, + self._writes, + self.default_reads, + self._default_writes, + {**self._bound_params, **kwargs}, ) def run_and_update(self, state: State, **run_kwargs) -> tuple[dict, State]: @@ -918,6 +996,8 @@ def __init__( ], reads: List[str], writes: List[str], + default_reads: Optional[Dict[str, Any]] = None, + default_writes: Optional[Dict[str, Any]] = None, bound_params: dict = None, ): """Instantiates a function-based streaming action with the given function, reads, and writes. @@ -931,6 +1011,8 @@ def __init__( self._fn = fn self._reads = reads self._writes = writes + self._default_reads = default_reads if default_reads is not None else {} + self._default_writes = default_writes if default_writes is not None else {} self._bound_params = bound_params if bound_params is not None else {} async def _a_stream_run_and_update( @@ -957,6 +1039,14 @@ def reads(self) -> list[str]: def writes(self) -> list[str]: return self._writes + @property + def default_writes(self) -> Dict[str, Any]: + return self._default_writes + + @property + def default_reads(self) -> Dict[str, Any]: + return self._default_reads + @property def streaming(self) -> bool: return True @@ -969,7 +1059,12 @@ def with_params(self, **kwargs: Any) -> "FunctionBasedStreamingAction": :return: """ return FunctionBasedStreamingAction( - self._fn, self._reads, self._writes, {**self._bound_params, **kwargs} + self._fn, + self._reads, + self._writes, + self._default_reads, + self._default_writes, + {**self._bound_params, **kwargs}, ) @property @@ -999,7 +1094,10 @@ def bind(self, **kwargs: Any) -> Self: ... -def copy_func(f: types.FunctionType) -> types.FunctionType: +T = TypeVar("T", bound=types.FunctionType) + + +def copy_func(f: T) -> T: """Copies a function. This is used internally to bind parameters to a function so we don't accidentally overwrite them. @@ -1033,7 +1131,12 @@ def my_action(state: State, z: int) -> tuple[dict, State]: return self -def action(reads: List[str], writes: List[str]) -> Callable[[Callable], FunctionRepresentingAction]: +def action( + reads: List[str], + writes: List[str], + default_reads: Dict[str, Any] = None, + default_writes: Dict[str, Any] = None, +) -> Callable[[Callable], FunctionRepresentingAction]: """Decorator to create a function-based action. This is user-facing. Note that, in the future, with typed state, we may not need this for all cases. @@ -1044,11 +1147,27 @@ def action(reads: List[str], writes: List[str]) -> Callable[[Callable], Function :param reads: Items to read from the state :param writes: Items to write to the state + :param default_reads: Default values for reads. If nothing upstream produces these, they will + be filled automatically. This is equivalent to adding + ``state = state.update(**{key: value for key, value in default_reads.items() if key not in state})`` + at the beginning of your function. + :param default_writes: Default values for writes. If the action's state update does not write to this, + they will be filled automatically with the default values. Leaving blank will have no default values. + This is equivalent to adding state = state.update(***deafult_writes) at the beginning of the function. + Note that this will not work as intended with append/increment operations, so be careful. :return: The decorator to assign the function as an action """ + default_reads = default_reads if default_reads is not None else {} + default_writes = default_writes if default_writes is not None else {} def decorator(fn) -> FunctionRepresentingAction: - setattr(fn, FunctionBasedAction.ACTION_FUNCTION, FunctionBasedAction(fn, reads, writes)) + setattr( + fn, + FunctionBasedAction.ACTION_FUNCTION, + FunctionBasedAction( + fn, reads, writes, default_reads=default_reads, default_writes=default_writes + ), + ) setattr(fn, "bind", types.MethodType(bind, fn)) return fn @@ -1056,7 +1175,10 @@ def decorator(fn) -> FunctionRepresentingAction: def streaming_action( - reads: List[str], writes: List[str] + reads: List[str], + writes: List[str], + default_reads: Optional[Dict[str, Any]] = None, + default_writes: Optional[Dict[str, Any]] = None, ) -> Callable[[Callable], FunctionRepresentingAction]: """Decorator to create a streaming function-based action. This is user-facing. @@ -1090,14 +1212,28 @@ def streaming_response(state: State) -> Generator[dict, None, tuple[dict, State] # return the final result return {'response': full_response}, state.update(response=full_response) + :param reads: Items to read from the state + :param writes: Items to write to the state + :param default_reads: Default values for reads. If nothing upstream produces these, they will + be filled automatically. This is equivalent to adding + ``state = state.update(**{key: value for key, value in default_reads.items() if key not in state})`` + at the beginning of your function. + :param default_writes: Default values for writes. If the action's state update does not write to this, + they will be filled automatically with the default values. Leaving blank will have no default values. + This is equivalent to adding state = state.update(***deafult_writes) at the beginning of the function. + Note that this will not work as intended with append/increment operations, so be careful. + :return: The decorator to assign the function as an action + """ + default_reads = default_reads if default_reads is not None else {} + default_writes = default_writes if default_writes is not None else {} def wrapped(fn) -> FunctionRepresentingAction: fn = copy_func(fn) setattr( fn, FunctionBasedAction.ACTION_FUNCTION, - FunctionBasedStreamingAction(fn, reads, writes), + FunctionBasedStreamingAction(fn, reads, writes, default_reads, default_writes), ) setattr(fn, "bind", types.MethodType(bind, fn)) return fn diff --git a/burr/core/application.py b/burr/core/application.py index 69591f00..949db96c 100644 --- a/burr/core/application.py +++ b/burr/core/application.py @@ -83,6 +83,33 @@ def _adjust_single_step_output(output: Union[State, Tuple[dict, State]], action_ _raise_fn_return_validation_error(output, action_name) +def _pre_apply_read_defaults( + state: State, + default_reads: Dict[str, Any], +): + """Applies default values to the state prior to execution. + This just applies them to the state so the action can overwrite them. + """ + state_update = {} + for key, value in default_reads.items(): + if key not in state: + state_update[key] = value + return state.update(**state_update) + + +def _pre_apply_write_defaults( + state: State, + default_writes: Dict[str, Any], +) -> State: + """Applies default values to the state prior to execution. + This just applies them to the state so the action can overwrite them. + """ + state_update = {} + for key, value in default_writes.items(): + state_update[key] = value + return state.update(**state_update) + + def _run_function(function: Function, state: State, inputs: Dict[str, Any], name: str) -> dict: """Runs a function, returning the result of running the function. Note this restricts the keys in the state to only those that the @@ -100,6 +127,7 @@ def _run_function(function: Function, state: State, inputs: Dict[str, Any], name "instead...)" ) state_to_use = state.subset(*function.reads) + state_to_use = _pre_apply_read_defaults(state_to_use, function.default_reads) function.validate_inputs(inputs) result = function.run(state_to_use, **inputs) _validate_result(result, name) @@ -112,6 +140,7 @@ async def _arun_function( """Runs a function, returning the result of running the function. Async version of the above.""" state_to_use = state.subset(*function.reads) + state_to_use = _pre_apply_read_defaults(state_to_use, function.default_reads) function.validate_inputs(inputs) result = await function.run(state_to_use, **inputs) _validate_result(result, name) @@ -168,7 +197,8 @@ def _run_reducer(reducer: Reducer, state: State, result: dict, name: str) -> Sta :return: """ # TODO -- better guarding on state reads/writes - new_state = reducer.update(result, state) + state_with_defaults = _pre_apply_write_defaults(state, reducer.default_writes) + new_state = reducer.update(result, state_with_defaults) keys_in_new_state = set(new_state.keys()) new_keys = keys_in_new_state - set(state.keys()) extra_keys = new_keys - set(reducer.writes) @@ -216,6 +246,22 @@ def _format_BASE_ERROR_MESSAGE(action: Action, input_state: State, inputs: dict) return "\n" + border + "\n" + message + "\n" + border +def _prep_state_single_step_action(state: State, action: SingleStepAction): + """Runs default application for single step action. + First applies read defaults, then applies write defaults. Note write defaults + will blogger read + + :param state: + :param action: + :return: + """ + # first apply read defaults + state = _pre_apply_read_defaults(state, action.default_reads) + # then apply write defaults so if the action doesn't write it will be in state + state = _pre_apply_write_defaults(state, action.default_writes) + return state + + def _run_single_step_action( action: SingleStepAction, state: State, inputs: Optional[Dict[str, Any]] ) -> Tuple[Dict[str, Any], State]: @@ -229,6 +275,7 @@ def _run_single_step_action( """ # TODO -- guard all reads/writes with a subset of the state action.validate_inputs(inputs) + state = _prep_state_single_step_action(state, action) result, new_state = _adjust_single_step_output( action.run_and_update(state, **inputs), action.name ) @@ -245,6 +292,7 @@ def _run_single_step_streaming_action( """Runs a single step streaming action. This API is internal-facing. This normalizes + validates the output.""" action.validate_inputs(inputs) + state = _prep_state_single_step_action(state, action) generator = action.stream_run_and_update(state, **inputs) result = None state_update = None @@ -269,11 +317,26 @@ def _run_single_step_streaming_action( yield result, state_update +async def _arun_single_step_action( + action: SingleStepAction, state: State, inputs: Optional[Dict[str, Any]] +) -> Tuple[dict, State]: + """Runs a single step action in async. See the synchronous version for more details.""" + state_to_use = _prep_state_single_step_action(state, action) + action.validate_inputs(inputs) + result, new_state = _adjust_single_step_output( + await action.run_and_update(state_to_use, **inputs), action.name + ) + _validate_result(result, action.name) + _validate_reducer_writes(action, new_state, action.name) + return result, _state_update(state, new_state) + + async def _arun_single_step_streaming_action( action: SingleStepStreamingAction, state: State, inputs: Optional[Dict[str, Any]] ) -> AsyncGenerator[Tuple[dict, Optional[State]], None]: """Runs a single step streaming action in async. See the synchronous version for more details.""" action.validate_inputs(inputs) + state = _prep_state_single_step_action(state, action) generator = action.stream_run_and_update(state, **inputs) result = None state_update = None @@ -310,6 +373,7 @@ def _run_multi_step_streaming_action( This peeks ahead by one so we know when this is done (and when to validate). """ action.validate_inputs(inputs) + state = _pre_apply_read_defaults(state, action.default_reads) generator = action.stream_run(state, **inputs) result = None for item in generator: @@ -331,6 +395,7 @@ async def _arun_multi_step_streaming_action( ) -> AsyncGenerator[Tuple[dict, Optional[State]], None]: """Runs a multi-step streaming action in async. See the synchronous version for more details.""" action.validate_inputs(inputs) + state = _pre_apply_read_defaults(state, action.default_reads) generator = action.stream_run(state, **inputs) result = None async for item in generator: @@ -347,20 +412,6 @@ async def _arun_multi_step_streaming_action( yield result, state_update -async def _arun_single_step_action( - action: SingleStepAction, state: State, inputs: Optional[Dict[str, Any]] -) -> Tuple[dict, State]: - """Runs a single step action in async. See the synchronous version for more details.""" - state_to_use = state - action.validate_inputs(inputs) - result, new_state = _adjust_single_step_output( - await action.run_and_update(state_to_use, **inputs), action.name - ) - _validate_result(result, action.name) - _validate_reducer_writes(action, new_state, action.name) - return result, _state_update(state, new_state) - - @dataclasses.dataclass class ApplicationGraph(Graph): """User-facing representation of the state machine. This has @@ -639,7 +690,6 @@ async def astep(self, inputs: Dict[str, Any] = None) -> Optional[Tuple[Action, d return out async def _astep(self, inputs: Optional[Dict[str, Any]], _run_hooks: bool = True): - # we want to increment regardless of failure with self.context: next_action = self.get_next_action() if next_action is None: diff --git a/burr/core/state.py b/burr/core/state.py index 62c9a583..49a6fb3a 100644 --- a/burr/core/state.py +++ b/burr/core/state.py @@ -95,6 +95,11 @@ def writes(self) -> list[str]: """Returns the keys that this state delta writes""" pass + @abc.abstractmethod + def deletes(self) -> list[str]: + """Returns the keys that this state delta deletes""" + pass + @abc.abstractmethod def apply_mutate(self, inputs: dict): """Applies the state delta to the inputs""" @@ -117,6 +122,9 @@ def reads(self) -> list[str]: def writes(self) -> list[str]: return list(self.values.keys()) + def deletes(self) -> list[str]: + return [] + def apply_mutate(self, inputs: dict): inputs.update(self.values) @@ -137,13 +145,21 @@ def reads(self) -> list[str]: def writes(self) -> list[str]: return list(self.values.keys()) + def deletes(self) -> list[str]: + return [] + def apply_mutate(self, inputs: dict): for key, value in self.values.items(): if key not in inputs: inputs[key] = [] if not isinstance(inputs[key], list): raise ValueError(f"Cannot append to non-list value {key}={inputs[self.key]}") - inputs[key].append(value) + inputs[key] = [ + *inputs[key], + value, + ] # Not as efficient but safer, so we don't mutate the original list + # we're doing this to avoid a copy.deepcopy() call, so it is already more efficient than it was before + # That said, if one modifies prior values in the list, it is on them, and undefined behavior def validate(self, input_state: Dict[str, Any]): incorrect_types = {} @@ -171,6 +187,9 @@ def reads(self) -> list[str]: def writes(self) -> list[str]: return list(self.values.keys()) + def deletes(self) -> list[str]: + return [] + def validate(self, input_state: Dict[str, Any]): incorrect_types = {} for write_key in self.writes(): @@ -201,11 +220,14 @@ def name(cls) -> str: return "delete" def reads(self) -> list[str]: - return list(self.keys) + return [] def writes(self) -> list[str]: return [] + def deletes(self) -> list[str]: + return list(self.keys) + def apply_mutate(self, inputs: dict): for key in self.keys: inputs.pop(key, None) @@ -221,11 +243,12 @@ def __init__(self, initial_values: Dict[str, Any] = None): def apply_operation(self, operation: StateDelta) -> "State": """Applies a given operation to the state, returning a new state""" - new_state = copy.deepcopy(self._state) # TODO -- restrict to just the read keys + new_state = copy.copy(self._state) # TODO -- restrict to just the read keys operation.validate(new_state) operation.apply_mutate( new_state ) # todo -- validate that the write keys are the only different ones + # we want to carry this on for now return State(new_state) def get_all(self) -> Dict[str, Any]: @@ -331,7 +354,9 @@ def merge(self, other: "State") -> "State": def subset(self, *keys: str, ignore_missing: bool = True) -> "State": """Returns a subset of the state, with only the given keys""" - return State({key: self[key] for key in keys if key in self or not ignore_missing}) + return State( + {key: self[key] for key in keys if key in self or not ignore_missing}, + ) def __getitem__(self, __k: str) -> Any: return self._state[__k] diff --git a/docs/reference/state.rst b/docs/reference/state.rst index 4e5c3e06..d999bd6f 100644 --- a/docs/reference/state.rst +++ b/docs/reference/state.rst @@ -1,6 +1,6 @@ -================= +===== State -================= +===== Use the state API to manipulate the state of the application. diff --git a/tests/core/test_action.py b/tests/core/test_action.py index b623b598..70a49a63 100644 --- a/tests/core/test_action.py +++ b/tests/core/test_action.py @@ -1,5 +1,5 @@ import asyncio -from typing import AsyncGenerator, Generator, Optional, Tuple, cast +from typing import Any, AsyncGenerator, Dict, Generator, Optional, Tuple, cast import pytest @@ -719,3 +719,81 @@ async def callback(r: Optional[dict], s: State, e: Exception): ((result, state, error),) = called assert state["foo"] == "bar" assert result is None + + +def test_action_subclass_validate_defaults_fails_incorrect_writes(): + """Tests that the initialization of subclass hook validates as intended""" + + class IncorrectReads(Action): + @property + def reads(self) -> list[str]: + return ["foo"] + + @property + def writes(self) -> list[str]: + return ["bar"] + + @property + def default_reads(self) -> Dict[str, Any]: + return {"qux": None} + + def run(self, state: State) -> dict: + raise ValueError("This should never be called") + + def update(self, result: dict, state: State) -> State: + raise ValueError("This should never be called") + + with pytest.raises(ValueError): + IncorrectReads() + + +def test_action_subclass_validate_defaults_fails_incorrect_reads(): + """Tests that the initialiation of subclass hook validates as intended""" + + class IncorrectWrites(Action): + @property + def reads(self) -> list[str]: + return ["foo"] + + @property + def writes(self) -> list[str]: + return ["bar"] + + @property + def default_writes(self) -> Dict[str, Any]: + return {"qux": None} + + def run(self, state: State) -> dict: + raise ValueError("This should never be called") + + def update(self, result: dict, state: State) -> State: + raise ValueError("This should never be called") + + with pytest.raises(ValueError): + IncorrectWrites() + + +def test_action_subclass_validate_defaults_fails_reads_in_default_writes(): + """Tests that the initialiation of subclass hook validates as intended""" + + class IncorrectWrites(Action): + @property + def reads(self) -> list[str]: + return ["foo"] + + @property + def writes(self) -> list[str]: + return ["bar"] + + @property + def default_writes(self) -> Dict[str, Any]: + return {"foo": None} + + def run(self, state: State) -> dict: + raise ValueError("This should never be called") + + def update(self, result: dict, state: State) -> State: + raise ValueError("This should never be called") + + with pytest.raises(ValueError): + IncorrectWrites() diff --git a/tests/core/test_application.py b/tests/core/test_application.py index d6c4257a..bf77a7b0 100644 --- a/tests/core/test_application.py +++ b/tests/core/test_application.py @@ -31,6 +31,8 @@ _arun_multi_step_streaming_action, _arun_single_step_action, _arun_single_step_streaming_action, + _pre_apply_read_defaults, + _pre_apply_write_defaults, _run_function, _run_multi_step_streaming_action, _run_reducer, @@ -60,6 +62,8 @@ def __init__( fn: Callable[..., dict], update_fn: Callable[[dict, State], State], inputs: list[str], + default_reads: Optional[Dict[str, Any]] = None, + default_writes: Optional[Dict[str, Any]] = None, ): super(PassedInAction, self).__init__() self._reads = reads @@ -67,6 +71,8 @@ def __init__( self._fn = fn self._update_fn = update_fn self._inputs = inputs + self._default_reads = default_reads if default_reads is not None else {} + self._default_writes = default_writes if default_writes is not None else {} def run(self, state: State, **run_kwargs) -> dict: return self._fn(state, **run_kwargs) @@ -75,6 +81,14 @@ def run(self, state: State, **run_kwargs) -> dict: def inputs(self) -> list[str]: return self._inputs + @property + def default_reads(self) -> Dict[str, Any]: + return self._default_reads + + @property + def default_writes(self) -> Dict[str, Any]: + return self._default_writes + def update(self, result: dict, state: State) -> State: return self._update_fn(result, state) @@ -95,8 +109,18 @@ def __init__( fn: Callable[..., Awaitable[dict]], update_fn: Callable[[dict, State], State], inputs: list[str], + default_reads: Optional[Dict[str, Any]] = None, + default_writes: Optional[Dict[str, Any]] = None, ): - super().__init__(reads=reads, writes=writes, fn=fn, update_fn=update_fn, inputs=inputs) # type: ignore + super().__init__( + reads=reads, + writes=writes, + fn=fn, + update_fn=update_fn, + inputs=inputs, + default_reads=default_reads, + default_writes=default_writes, + ) # type: ignore async def run(self, state: State, **run_kwargs) -> dict: return await self._fn(state, **run_kwargs) @@ -105,7 +129,7 @@ async def run(self, state: State, **run_kwargs) -> dict: base_counter_action = PassedInAction( reads=["count"], writes=["count"], - fn=lambda state: {"count": state.get("count", 0) + 1}, + fn=lambda state: {"count": state["count"] + 1}, update_fn=lambda result, state: state.update(**result), inputs=[], ) @@ -113,13 +137,21 @@ async def run(self, state: State, **run_kwargs) -> dict: base_counter_action_with_inputs = PassedInAction( reads=["count"], writes=["count"], - fn=lambda state, additional_increment: { - "count": state.get("count", 0) + 1 + additional_increment - }, + fn=lambda state, additional_increment: {"count": state["count"] + 1 + additional_increment}, update_fn=lambda result, state: state.update(**result), inputs=["additional_increment"], ) +base_counter_action_with_defaults = PassedInAction( + reads=["count"], + writes=["count", "error"], + fn=lambda state: {"count": state["count"] + 1}, + update_fn=lambda result, state: state.update(**result), + inputs=[], + default_reads={"count": 0}, + default_writes={"error": None}, +) + class ActionTracker(PreRunStepHook, PostRunStepHook): def __init__(self): @@ -172,7 +204,7 @@ async def post_run_step(self, *, action: Action, **future_kwargs): async def _counter_update_async(state: State, additional_increment: int = 0) -> dict: await asyncio.sleep(0.0001) # just so we can make this *truly* async # does not matter, but more accurately simulates an async function - return {"count": state.get("count", 0) + 1 + additional_increment} + return {"count": state["count"] + 1 + additional_increment} base_counter_action_async = PassedInActionAsync( @@ -193,6 +225,16 @@ async def _counter_update_async(state: State, additional_increment: int = 0) -> inputs=["additional_increment"], ) +base_counter_action_async_with_defaults = PassedInActionAsync( + reads=["count"], + writes=["count", "error"], + fn=_counter_update_async, + update_fn=lambda result, state: state.update(**result), + inputs=[], + default_reads={"count": 0}, + default_writes={"error": None}, +) + class BrokenStepException(Exception): pass @@ -236,18 +278,46 @@ async def incorrect(x): ) +def test__pre_read_apply_defaults(): + state = State({"in_state": 0}) + defaults = {"in_state": 1, "not_in_state": 2} + result = _pre_apply_read_defaults(state, defaults) + assert result["in_state"] == 0 + assert result["not_in_state"] == 2 + + +def test__pre_write_apply_defaults(): + # Write defaults should always be applied, + # they'll be overwritten by the reducer + state = State({"in_state": 0, "to_be_overwritten": 0}) + defaults = {"in_state": 1, "not_in_state": 2, "to_be_overwritten": 3} + result = _pre_apply_write_defaults(state, defaults) + assert result["in_state"] == 1 + assert result["not_in_state"] == 2 + assert result["to_be_overwritten"] == 3 + + def test__run_function(): """Tests that we can run a function""" action = base_counter_action - state = State({}) + state = State({"count": 0}) result = _run_function(action, state, inputs={}, name=action.name) assert result == {"count": 1} +def test__run_function_defaults(): + action = base_counter_action_with_defaults + state = State({}) + result = _run_function(action, state, inputs={}, name=action.name) + assert result == { + "count": 1 + } # default read is applied, write is not as it is a reducer capability, and is not part of the result + + def test__run_function_with_inputs(): """Tests that we can run a function""" action = base_counter_action_with_inputs - state = State({}) + state = State({"count": 0}) result = _run_function(action, state, inputs={"additional_increment": 1}, name=action.name) assert result == {"count": 2} @@ -255,7 +325,7 @@ def test__run_function_with_inputs(): def test__run_function_cant_run_async(): """Tests that we can't run an async function""" action = base_counter_action_async - state = State({}) + state = State({"count": 0}) with pytest.raises(ValueError, match="async"): _run_function(action, state, inputs={}, name=action.name) @@ -263,11 +333,19 @@ def test__run_function_cant_run_async(): def test__run_function_incorrect_result_type(): """Tests that we can run an async function""" action = base_action_incorrect_result_type - state = State({}) + state = State({"count": 0}) with pytest.raises(ValueError, match="returned a non-dict"): _run_function(action, state, inputs={}, name=action.name) +def test__run_reducer_applies_defaults(): + """Tests that we can run a reducer and it behaves as expected""" + reducer = base_counter_action_with_defaults + state = State({"count": 0}) + state = _run_reducer(reducer, state, {"count": 1}, "reducer") + assert state.get_all() == {"count": 1, "error": None} + + def test__run_reducer_modifies_state(): """Tests that we can run a reducer and it behaves as expected""" reducer = PassedInAction( @@ -299,6 +377,14 @@ def test__run_reducer_deletes_state(): async def test__arun_function(): """Tests that we can run an async function""" action = base_counter_action_async + state = State({"count": 0}) + result = await _arun_function(action, state, inputs={}, name=action.name) + assert result == {"count": 1} + + +async def test__arun_function_with_defaults(): + """Tests that we can run an async function""" + action = base_counter_action_async_with_defaults state = State({}) result = await _arun_function(action, state, inputs={}, name=action.name) assert result == {"count": 1} @@ -315,7 +401,7 @@ async def test__arun_function_incorrect_result_type(): async def test__arun_function_with_inputs(): """Tests that we can run an async function""" action = base_counter_action_with_inputs_async - state = State({}) + state = State({"count": 0}) result = await _arun_function( action, state, inputs={"additional_increment": 1}, name=action.name ) @@ -447,6 +533,9 @@ def writes(self) -> list[str]: class SingleStepCounter(SingleStepAction): + def __init__(self): + super(SingleStepCounter, self).__init__() + def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]: result = {"count": state["count"] + 1 + sum([0] + list(run_kwargs.values()))} return result, state.update(**result).append(tracker=result["count"]) @@ -466,6 +555,23 @@ def inputs(self) -> list[str]: return ["additional_increment"] +class SingleStepCounterWithDefaults(SingleStepCounter): + def __init__(self): + super(SingleStepCounterWithDefaults, self).__init__() + + @property + def default_reads(self) -> dict[str, Any]: + return {"count": 0} + + @property + def default_writes(self) -> dict[str, Any]: + return {"error": None} + + @property + def writes(self) -> list[str]: + return super().writes + ["error"] + + class SingleStepActionIncorrectResultType(SingleStepAction): def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]: return "not a dict", state @@ -498,6 +604,20 @@ def writes(self) -> list[str]: return ["count", "tracker"] +class SingleStepCounterWithDefaultsAsync(SingleStepCounterAsync): + @property + def default_reads(self) -> dict[str, Any]: + return {"count": 0} + + @property + def default_writes(self) -> dict[str, Any]: + return {"error": None} + + @property + def writes(self) -> list[str]: + return super().writes + ["error"] + + class SingleStepCounterWithInputsAsync(SingleStepCounterAsync): @property def inputs(self) -> list[str]: @@ -527,6 +647,20 @@ def update(self, result: dict, state: State) -> State: return state.update(**result).append(tracker=result["count"]) +class StreamingCounterWithDefaults(StreamingCounter): + @property + def default_reads(self) -> dict[str, Any]: + return {"count": 0} + + @property + def default_writes(self) -> dict[str, Any]: + return {"error": None} + + @property + def writes(self) -> list[str]: + return super().writes + ["error"] + + class AsyncStreamingCounter(AsyncStreamingAction): async def stream_run(self, state: State, **run_kwargs) -> AsyncGenerator[dict, None]: if "steps_per_count" in run_kwargs: @@ -552,6 +686,20 @@ def update(self, result: dict, state: State) -> State: return state.update(**result).append(tracker=result["count"]) +class StreamingCounterWithDefaultsAsync(AsyncStreamingCounter): + @property + def default_reads(self) -> dict[str, Any]: + return {"count": 0} + + @property + def default_writes(self) -> dict[str, Any]: + return {"error": None} + + @property + def writes(self) -> list[str]: + return super().writes + ["error"] + + class SingleStepStreamingCounter(SingleStepStreamingAction): def stream_run_and_update( self, state: State, **run_kwargs @@ -571,6 +719,20 @@ def writes(self) -> list[str]: return ["count", "tracker"] +class SingleStepStreamingCounterWithDefaults(SingleStepStreamingCounter): + @property + def default_reads(self) -> dict[str, Any]: + return {"count": 0} + + @property + def default_writes(self) -> dict[str, Any]: + return {"error": None} + + @property + def writes(self) -> list[str]: + return super().writes + ["error"] + + class SingleStepStreamingCounterAsync(SingleStepStreamingAction): async def stream_run_and_update( self, state: State, **run_kwargs @@ -592,6 +754,20 @@ def writes(self) -> list[str]: return ["count", "tracker"] +class SingleStepStreamingCounterWithDefaultsAsync(SingleStepStreamingCounterAsync): + @property + def default_reads(self) -> dict[str, Any]: + return {"count": 0} + + @property + def default_writes(self) -> dict[str, Any]: + return {"error": None} + + @property + def writes(self) -> list[str]: + return super().writes + ["error"] + + class StreamingActionIncorrectResultType(StreamingAction): def stream_run(self, state: State, **run_kwargs) -> Generator[dict, None, dict]: yield {} @@ -662,12 +838,20 @@ def writes(self) -> list[str]: base_single_step_counter_async = SingleStepCounterAsync() base_single_step_counter_with_inputs = SingleStepCounterWithInputs() base_single_step_counter_with_inputs_async = SingleStepCounterWithInputsAsync() +base_single_step_counter_with_defaults = SingleStepCounterWithDefaults() +base_single_step_counter_with_defaults_async = SingleStepCounterWithDefaultsAsync() base_streaming_counter = StreamingCounter() +base_streaming_counter_with_defaults = StreamingCounterWithDefaults() base_streaming_single_step_counter = SingleStepStreamingCounter() +base_streaming_single_step_counter_with_defaults = SingleStepStreamingCounterWithDefaults() base_streaming_counter_async = AsyncStreamingCounter() +base_streaming_counter_with_defaults_async = StreamingCounterWithDefaultsAsync() base_streaming_single_step_counter_async = SingleStepStreamingCounterAsync() +base_streaming_single_step_counter_with_defaults_async = ( + SingleStepStreamingCounterWithDefaultsAsync() +) base_single_step_action_incorrect_result_type = SingleStepActionIncorrectResultType() base_single_step_action_incorrect_result_type_async = SingleStepActionIncorrectResultTypeAsync() @@ -709,6 +893,18 @@ def test__run_single_step_action_with_inputs(): assert state.subset("count", "tracker").get_all() == {"count": 4, "tracker": [2, 4]} +def test__run_single_step_action_with_defaults(): + action = base_single_step_counter_with_defaults.with_name("counter") + state = State({}) + result, state = _run_single_step_action(action, state, {}) + assert result == {"count": 1} + assert state.subset("count", "tracker", "error").get_all() == { + "count": 1, + "tracker": [1], + "error": None, + } + + async def test__arun_single_step_action(): action = base_single_step_counter_async.with_name("counter") state = State({"count": 0, "tracker": []}) @@ -735,6 +931,18 @@ async def test__arun_single_step_action_with_inputs(): assert state.subset("count", "tracker").get_all() == {"count": 4, "tracker": [2, 4]} +async def test__arun_single_step_action_with_defaults(): + action = base_single_step_counter_with_defaults_async.with_name("counter") + state = State({}) + result, state = await _arun_single_step_action(action, state, {}) + assert result == {"count": 1} + assert state.subset("count", "tracker", "error").get_all() == { + "count": 1, + "tracker": [1], + "error": None, + } + + class SingleStepActionWithDeletion(SingleStepAction): def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]: return {}, state.wipe(delete=["to_delete"]) @@ -770,7 +978,26 @@ def test__run_multistep_streaming_action(): assert state.subset("count", "tracker").get_all() == {"count": 1, "tracker": [1]} -async def test__run_multistep_streaming_action_async(): +def test__run_multistep_streaming_action_default(): + action = base_streaming_counter_with_defaults.with_name("counter") + state = State({}) + generator = _run_multi_step_streaming_action(action, state, inputs={}) + last_result = -1 + result = None + for result, state in generator: + if last_result < 1: + # Otherwise you hit floating poit comparison problems + assert result["count"] > last_result + last_result = result["count"] + assert result == {"count": 1} + assert state.subset("count", "tracker", "error").get_all() == { + "count": 1, + "tracker": [1], + "error": None, + } + + +async def test__arun_multistep_streaming_action(): action = base_streaming_counter_async.with_name("counter") state = State({"count": 0, "tracker": []}) generator = _arun_multi_step_streaming_action(action, state, inputs={}) @@ -785,6 +1012,25 @@ async def test__run_multistep_streaming_action_async(): assert state.subset("count", "tracker").get_all() == {"count": 1, "tracker": [1]} +async def test__arun_multistep_streaming_action_with_defaults(): + action = base_streaming_counter_with_defaults_async.with_name("counter") + state = State({}) + generator = _arun_multi_step_streaming_action(action, state, inputs={}) + last_result = -1 + result = None + async for result, state in generator: + if last_result < 1: + # Otherwise you hit floating poit comparison problems + assert result["count"] > last_result + last_result = result["count"] + assert result == {"count": 1} + assert state.subset("count", "tracker", "error").get_all() == { + "count": 1, + "tracker": [1], + "error": None, + } + + def test__run_streaming_action_incorrect_result_type(): action = StreamingActionIncorrectResultType() state = State() @@ -793,7 +1039,7 @@ def test__run_streaming_action_incorrect_result_type(): collections.deque(gen, maxlen=0) # exhaust the generator -async def test__run_streaming_action_incorrect_result_type_async(): +async def test__arun_streaming_action_incorrect_result_type(): action = StreamingActionIncorrectResultTypeAsync() state = State() with pytest.raises(ValueError, match="returned a non-dict"): @@ -810,7 +1056,7 @@ def test__run_single_step_streaming_action_incorrect_result_type(): collections.deque(gen, maxlen=0) # exhaust the generator -async def test__run_single_step_streaming_action_incorrect_result_type_async(): +async def test__arun_single_step_streaming_action_incorrect_result_type(): action = StreamingSingleStepActionIncorrectResultTypeAsync() state = State() with pytest.raises(ValueError, match="returned a non-dict"): @@ -834,7 +1080,27 @@ def test__run_single_step_streaming_action(): assert state.subset("count", "tracker").get_all() == {"count": 1, "tracker": [1]} -async def test__run_single_step_streaming_action_async(): +def test__run_single_step_streaming_with_defaults(): + action = base_streaming_single_step_counter_with_defaults.with_name("counter") + state = State() + generator = _run_single_step_streaming_action(action, state, inputs={}) + last_result = -1 + result, state = None, None + for result, state in generator: + if last_result < 1: + # Otherwise you hit comparison issues + # This is because we get to the last one, which is the final result + assert result["count"] > last_result + last_result = result["count"] + assert result == {"count": 1} + assert state.subset("count", "tracker", "error").get_all() == { + "count": 1, + "tracker": [1], + "error": None, + } + + +async def test__arun_single_step_streaming_action(): async_action = base_streaming_single_step_counter_async.with_name("counter") state = State({"count": 0, "tracker": []}) generator = _arun_single_step_streaming_action(async_action, state, inputs={}) @@ -850,6 +1116,26 @@ async def test__run_single_step_streaming_action_async(): assert state.subset("count", "tracker").get_all() == {"count": 1, "tracker": [1]} +async def test__arun_single_step_streaming_action_with_defaults(): + async_action = base_streaming_single_step_counter_with_defaults_async.with_name("counter") + state = State({}) + generator = _arun_single_step_streaming_action(async_action, state, inputs={}) + last_result = -1 + result, state = None, None + async for result, state in generator: + if last_result < 1: + # Otherwise you hit comparison issues + # This is because we get to the last one, which is the final result + assert result["count"] > last_result + last_result = result["count"] + assert result == {"count": 1} + assert state.subset("count", "tracker", "error").get_all() == { + "count": 1, + "tracker": [1], + "error": None, + } + + class SingleStepActionWithDeletionAsync(SingleStepActionWithDeletion): async def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]: return {}, state.wipe(delete=["to_delete"]) @@ -866,7 +1152,7 @@ def test_app_step(): """Tests that we can run a step in an app""" counter_action = base_counter_action.with_name("counter") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter", partition_key="test", uid="test-123", @@ -945,7 +1231,7 @@ def test_app_step_done(): """Tests that when we cannot run a step, we return None""" counter_action = base_counter_action.with_name("counter") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter", partition_key="test", uid="test-123", @@ -963,7 +1249,7 @@ async def test_app_astep(): """Tests that we can run an async step in an app""" counter_action = base_counter_action_async.with_name("counter_async") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter_async", partition_key="test", uid="test-123", @@ -1022,7 +1308,7 @@ async def test_app_astep_broken(caplog): """Tests that we can run a step in an app""" broken_action = base_broken_action_async.with_name("broken_action_unique_name") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="broken_action_unique_name", partition_key="test", uid="test-123", @@ -1042,7 +1328,7 @@ async def test_app_astep_done(): """Tests that when we cannot run a step, we return None""" counter_action = base_counter_action_async.with_name("counter_async") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter_async", partition_key="test", uid="test-123", @@ -1060,7 +1346,7 @@ async def test_app_astep_done(): def test_app_many_steps(): counter_action = base_counter_action.with_name("counter") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter", partition_key="test", uid="test-123", @@ -1080,7 +1366,7 @@ def test_app_many_steps(): async def test_app_many_a_steps(): counter_action = base_counter_action_async.with_name("counter_async") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter_async", partition_key="test", uid="test-123", @@ -1101,7 +1387,7 @@ def test_iterate(): result_action = Result("count").with_name("result") counter_action = base_counter_action.with_name("counter") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter", partition_key="test", uid="test-123", @@ -1141,7 +1427,7 @@ def test_iterate_with_inputs(): result_action = Result("count").with_name("result") counter_action = base_counter_action_with_inputs.with_name("counter") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter", partition_key="test", uid="test-123", @@ -1170,7 +1456,7 @@ async def test_aiterate(): result_action = Result("count").with_name("result") counter_action = base_counter_action_async.with_name("counter") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter", partition_key="test", uid="test-123", @@ -1202,7 +1488,7 @@ async def test_aiterate_halt_before(): result_action = Result("count").with_name("result") counter_action = base_counter_action_async.with_name("counter") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter", partition_key="test", uid="test-123", @@ -1232,7 +1518,7 @@ async def test_app_aiterate_with_inputs(): result_action = Result("count").with_name("result") counter_action = base_counter_action_with_inputs_async.with_name("counter") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter", partition_key="test", uid="test-123", @@ -1257,7 +1543,7 @@ def test_run(): result_action = Result("count").with_name("result") counter_action = base_counter_action.with_name("counter") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter", partition_key="test", uid="test-123", @@ -1279,7 +1565,7 @@ def test_run_halt_before(): result_action = Result("count").with_name("result") counter_action = base_counter_action.with_name("counter") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter", partition_key="test", uid="test-123", @@ -1302,7 +1588,7 @@ def test_run_with_inputs(): result_action = Result("count").with_name("result") counter_action = base_counter_action_with_inputs.with_name("counter") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter", partition_key="test", uid="test-123", @@ -1326,7 +1612,7 @@ def test_run_with_inputs_multiple_actions(): counter_action1 = base_counter_action_with_inputs.with_name("counter1") counter_action2 = base_counter_action_with_inputs.with_name("counter2") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter1", partition_key="test", uid="test-123", @@ -1351,7 +1637,7 @@ async def test_arun(): result_action = Result("count").with_name("result") counter_action = base_counter_action_async.with_name("counter") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter", partition_key="test", uid="test-123", @@ -1373,7 +1659,7 @@ async def test_arun_halt_before(): result_action = Result("count").with_name("result") counter_action = base_counter_action_async.with_name("counter") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter", partition_key="test", uid="test-123", @@ -1396,7 +1682,7 @@ async def test_arun_with_inputs(): result_action = Result("count").with_name("result") counter_action = base_counter_action_with_inputs_async.with_name("counter") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter", partition_key="test", uid="test-123", @@ -1421,7 +1707,7 @@ async def test_arun_with_inputs_multiple_actions(): counter_action1 = base_counter_action_with_inputs_async.with_name("counter1") counter_action2 = base_counter_action_with_inputs_async.with_name("counter2") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter1", partition_key="test", uid="test-123", @@ -1449,7 +1735,7 @@ async def test_app_a_run_async_and_sync(): counter_action_sync = base_counter_action_async.with_name("counter_sync") counter_action_async = base_counter_action_async.with_name("counter_async") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter_sync", partition_key="test", uid="test-123", @@ -1878,7 +2164,7 @@ async def test_astream_result_halt_before(): def test_app_set_state(): counter_action = base_counter_action.with_name("counter") app = Application( - state=State(), + state=State({"count": 0}), entrypoint="counter", partition_key="test", uid="test-123", @@ -1901,7 +2187,7 @@ def test_app_get_next_step(): counter_action_2 = base_counter_action.with_name("counter_2") counter_action_3 = base_counter_action.with_name("counter_3") app = Application( - state=State(), + state=State({"count": 0}), entrypoint="counter_1", partition_key="test", uid="test-123", @@ -1988,7 +2274,7 @@ def test_application_run_step_hooks_sync(): counter_action = base_counter_action.with_name("counter") result_action = Result("count").with_name("result") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter", adapter_set=internal.LifecycleAdapterSet(tracker), partition_key="test", @@ -2035,7 +2321,7 @@ async def test_application_run_step_hooks_async(): counter_action = base_counter_action.with_name("counter") result_action = Result("count").with_name("result") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter", adapter_set=internal.LifecycleAdapterSet(tracker), partition_key="test", @@ -2079,7 +2365,7 @@ async def test_application_run_step_runs_hooks(): counter_action = base_counter_action.with_name("counter") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter", adapter_set=internal.LifecycleAdapterSet(*hooks), partition_key="test", @@ -2147,7 +2433,7 @@ def post_application_create(self, **kwargs): counter_action = base_counter_action.with_name("counter") result_action = Result("count").with_name("result") Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter", adapter_set=internal.LifecycleAdapterSet(tracker), partition_key="test", diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py index 4d15730f..9f94d89b 100644 --- a/tests/test_end_to_end.py +++ b/tests/test_end_to_end.py @@ -3,11 +3,12 @@ see failures in these tests, you should make a unit test, demonstrate the failure there, then fix both in that test and the end-to-end test.""" from io import StringIO -from typing import Any, Tuple +from typing import Any, AsyncGenerator, Generator, Tuple from unittest.mock import patch from burr.core import Action, ApplicationBuilder, State, action -from burr.core.action import Input, Result, expr +from burr.core.action import Input, Result, expr, streaming_action +from burr.core.graph import GraphBuilder from burr.lifecycle import base @@ -89,3 +90,105 @@ def echo(state: State) -> Tuple[dict, State]: format="png", ) assert result["response"] == prompt + + +def test_action_end_to_end_streaming_with_defaults(): + @streaming_action( + reads=["count"], + writes=["done", "error"], + default_reads={"count": 10}, + default_writes={"done": False, "error": None}, + ) + def echo( + state: State, should_error: bool, letter_to_repeat: str + ) -> Generator[Tuple[dict, State], None, None]: + for i in range(state["count"]): + yield {"letter_to_repeat": letter_to_repeat}, None + if should_error: + yield {"error": "Error"}, state.update(error="Error") + else: + yield {"done": True}, state.update(done=True) + + graph = ( + GraphBuilder() + .with_actions( + echo_success=echo.bind(should_error=False), + echo_failure=echo.bind(should_error=True), + ) + .with_transitions( + ("echo_success", "echo_failure"), + ("echo_failure", "echo_success"), + ) + .build() + ) + app = ApplicationBuilder().with_graph(graph).with_entrypoint("echo_success").build() + action_completed, streaming_container = app.stream_result( + halt_after=["echo_success"], inputs={"letter_to_repeat": "a"} + ) + for item in streaming_container: + assert item == {"letter_to_repeat": "a"} + result, state = streaming_container.get() + assert result == {"done": True} + assert state["done"] is True + assert state["error"] is None # default + + action_completed, streaming_container = app.stream_result( + halt_after=["echo_failure"], inputs={"letter_to_repeat": "a"} + ) + for item in streaming_container: + assert item == {"letter_to_repeat": "a"} + result, state = streaming_container.get() + assert result == {"error": "Error"} + assert state["done"] is False + assert state["error"] == "Error" + + +async def test_action_end_to_end_streaming_with_defaults_async(): + @streaming_action( + reads=["count"], + writes=["done", "error"], + default_reads={"count": 10}, + default_writes={"done": False, "error": None}, + ) + async def echo( + state: State, should_error: bool, letter_to_repeat: str + ) -> AsyncGenerator[Tuple[dict, State], None]: + for i in range(state["count"]): + yield {"letter_to_repeat": letter_to_repeat}, None + if should_error: + yield {"error": "Error"}, state.update(error="Error") + else: + yield {"done": True}, state.update(done=True) + + graph = ( + GraphBuilder() + .with_actions( + echo_success=echo.bind(should_error=False), + echo_failure=echo.bind(should_error=True), + ) + .with_transitions( + ("echo_success", "echo_failure"), + ("echo_failure", "echo_success"), + ) + .build() + ) + app = ApplicationBuilder().with_graph(graph).with_entrypoint("echo_success").build() + action_completed, streaming_container = await app.astream_result( + halt_after=["echo_success"], inputs={"letter_to_repeat": "a"} + ) + async for item in streaming_container: + assert item == {"letter_to_repeat": "a"} + result, state = await streaming_container.get() + assert result == {"done": True} + assert state["done"] is True + assert state["error"] is None # default + + action_completed, streaming_container = await app.astream_result( + halt_after=["echo_failure"], inputs={"letter_to_repeat": "a"} + ) + async for item in streaming_container: + assert item == {"letter_to_repeat": "a"} + result, state = await streaming_container.get() + assert result == {"error": "Error"} + assert state["done"] is False + assert state["error"] == "Error"