From ce6ca209894a50a523ed97aec8bc3ea6a734c021 Mon Sep 17 00:00:00 2001 From: elijahbenizzy Date: Tue, 16 Jul 2024 08:31:31 -0700 Subject: [PATCH] Adds default reads/writes to burr actions This allows you to specify defaults if your action does not write. In the majority of cases they will be none, but this allows simple (static) arbitrary values. This specifically helps with the branching case -- e.g. where you have two options, and want to null out anything it doesn't write. For instance, an error and a result -- you'll only ever produce one or the other. This works both in the function and class-based approaches -- in the function-based it is part of the two decorators (@action/@streaming_action). In the class-based it is part of the class, overriding the default_reads and default_writes property function We add a bunch of new tests for default (as the code to handle multiple action types is fairly dispersed, for now), and also make the naming of the other tests/content more consistent. --- burr/core/action.py | 79 +++++++- burr/core/application.py | 60 ++++-- tests/core/test_application.py | 352 +++++++++++++++++++++++++++++---- 3 files changed, 426 insertions(+), 65 deletions(-) diff --git a/burr/core/action.py b/burr/core/action.py index 5f35cfe6..7c4933a9 100644 --- a/burr/core/action.py +++ b/burr/core/action.py @@ -42,6 +42,10 @@ def reads(self) -> list[str]: """ pass + @property + def default_reads(self) -> Dict[str, Any]: + return {} + @abc.abstractmethod def run(self, state: State, **run_kwargs) -> dict: """Runs the function on the given state and returns the result. @@ -122,6 +126,10 @@ def writes(self) -> list[str]: """ pass + @property + def default_writes(self) -> Dict[str, Any]: + return {} + @abc.abstractmethod def update(self, result: dict, state: State) -> State: pass @@ -484,6 +492,8 @@ def __init__( fn: Callable, reads: List[str], writes: List[str], + default_reads: Dict[str, Any] = None, + default_writes: Dict[str, Any] = None, bound_params: dict = None, ): """Instantiates a function-based action with the given function, reads, and writes. @@ -499,11 +509,21 @@ def __init__( self._writes = writes self._bound_params = bound_params if bound_params is not None else {} self._inputs = _get_inputs(self._bound_params, self._fn) + self._default_reads = default_reads if default_reads is not None else {} + self._default_writes = default_writes if default_writes is not None else {} @property def fn(self) -> Callable: return self._fn + @property + def default_reads(self) -> Dict[str, Any]: + return self._default_reads + + @property + def default_writes(self) -> Dict[str, Any]: + return self._default_writes + @property def reads(self) -> list[str]: return self._reads @@ -526,7 +546,12 @@ def with_params(self, **kwargs: Any) -> "FunctionBasedAction": :return: """ return FunctionBasedAction( - self._fn, self._reads, self._writes, {**self._bound_params, **kwargs} + self._fn, + self._reads, + self._writes, + self.default_reads, + self._default_writes, + {**self._bound_params, **kwargs}, ) def run_and_update(self, state: State, **run_kwargs) -> tuple[dict, State]: @@ -918,6 +943,8 @@ def __init__( ], reads: List[str], writes: List[str], + default_reads: Optional[Dict[str, Any]] = None, + default_writes: Optional[Dict[str, Any]] = None, bound_params: dict = None, ): """Instantiates a function-based streaming action with the given function, reads, and writes. @@ -931,6 +958,8 @@ def __init__( self._fn = fn self._reads = reads self._writes = writes + self._default_reads = default_reads if default_reads is not None else {} + self._default_writes = default_writes if default_writes is not None else {} self._bound_params = bound_params if bound_params is not None else {} async def _a_stream_run_and_update( @@ -957,6 +986,12 @@ def reads(self) -> list[str]: def writes(self) -> list[str]: return self._writes + def default_writes(self) -> Dict[str, Any]: + return self._default_writes + + def default_reads(self) -> Dict[str, Any]: + return self._default_reads + @property def streaming(self) -> bool: return True @@ -969,7 +1004,12 @@ def with_params(self, **kwargs: Any) -> "FunctionBasedStreamingAction": :return: """ return FunctionBasedStreamingAction( - self._fn, self._reads, self._writes, {**self._bound_params, **kwargs} + self._fn, + self._reads, + self._writes, + self._default_reads, + self._default_writes, + {**self._bound_params, **kwargs}, ) @property @@ -999,7 +1039,10 @@ def bind(self, **kwargs: Any) -> Self: ... -def copy_func(f: types.FunctionType) -> types.FunctionType: +T = TypeVar("T", bound=types.FunctionType) + + +def copy_func(f: T) -> T: """Copies a function. This is used internally to bind parameters to a function so we don't accidentally overwrite them. @@ -1033,7 +1076,12 @@ def my_action(state: State, z: int) -> tuple[dict, State]: return self -def action(reads: List[str], writes: List[str]) -> Callable[[Callable], FunctionRepresentingAction]: +def action( + reads: List[str], + writes: List[str], + default_reads: Dict[str, Any] = None, + default_writes: Dict[str, Any] = None, +) -> Callable[[Callable], FunctionRepresentingAction]: """Decorator to create a function-based action. This is user-facing. Note that, in the future, with typed state, we may not need this for all cases. @@ -1044,11 +1092,23 @@ def action(reads: List[str], writes: List[str]) -> Callable[[Callable], Function :param reads: Items to read from the state :param writes: Items to write to the state + :param default_reads: Default values for reads. If nothing upstream produces these, they will + be filled automatically with the default values. Leaving blank will have no default values. + :param default_writes: Default values for writes. If your action does not write these, they will be + filled automatically. Leaving blank this will have no default values. :return: The decorator to assign the function as an action """ + default_reads = default_reads if default_reads is not None else {} + default_writes = default_writes if default_writes is not None else {} def decorator(fn) -> FunctionRepresentingAction: - setattr(fn, FunctionBasedAction.ACTION_FUNCTION, FunctionBasedAction(fn, reads, writes)) + setattr( + fn, + FunctionBasedAction.ACTION_FUNCTION, + FunctionBasedAction( + fn, reads, writes, default_reads=default_reads, default_writes=default_writes + ), + ) setattr(fn, "bind", types.MethodType(bind, fn)) return fn @@ -1056,7 +1116,10 @@ def decorator(fn) -> FunctionRepresentingAction: def streaming_action( - reads: List[str], writes: List[str] + reads: List[str], + writes: List[str], + default_reads: Optional[Dict[str, Any]] = None, + default_writes: Optional[Dict[str, Any]] = None, ) -> Callable[[Callable], FunctionRepresentingAction]: """Decorator to create a streaming function-based action. This is user-facing. @@ -1091,13 +1154,15 @@ def streaming_response(state: State) -> Generator[dict, None, tuple[dict, State] return {'response': full_response}, state.update(response=full_response) """ + default_reads = default_reads if default_reads is not None else {} + default_writes = default_writes if default_writes is not None else {} def wrapped(fn) -> FunctionRepresentingAction: fn = copy_func(fn) setattr( fn, FunctionBasedAction.ACTION_FUNCTION, - FunctionBasedStreamingAction(fn, reads, writes), + FunctionBasedStreamingAction(fn, reads, writes, default_reads, default_writes), ) setattr(fn, "bind", types.MethodType(bind, fn)) return fn diff --git a/burr/core/application.py b/burr/core/application.py index 69591f00..93bace7c 100644 --- a/burr/core/application.py +++ b/burr/core/application.py @@ -83,6 +83,19 @@ def _adjust_single_step_output(output: Union[State, Tuple[dict, State]], action_ _raise_fn_return_validation_error(output, action_name) +def _apply_defaults(state: State, defaults: Dict[str, Any]) -> State: + state_update = {} + state_to_use = state + # We really don't need to short-circuit but I want to avoid the update function + # So we might as well + if len(defaults) > 0: + for key, value in defaults.items(): + if key not in state: + state_update[key] = value + state_to_use = state.update(**state_update) + return state_to_use + + def _run_function(function: Function, state: State, inputs: Dict[str, Any], name: str) -> dict: """Runs a function, returning the result of running the function. Note this restricts the keys in the state to only those that the @@ -100,6 +113,7 @@ def _run_function(function: Function, state: State, inputs: Dict[str, Any], name "instead...)" ) state_to_use = state.subset(*function.reads) + state_to_use = _apply_defaults(state_to_use, function.default_reads) function.validate_inputs(inputs) result = function.run(state_to_use, **inputs) _validate_result(result, name) @@ -112,6 +126,7 @@ async def _arun_function( """Runs a function, returning the result of running the function. Async version of the above.""" state_to_use = state.subset(*function.reads) + state_to_use = _apply_defaults(state_to_use, function.default_reads) function.validate_inputs(inputs) result = await function.run(state_to_use, **inputs) _validate_result(result, name) @@ -177,6 +192,7 @@ def _run_reducer(reducer: Reducer, state: State, result: dict, name: str) -> Sta f"Action {name} attempted to write to keys {extra_keys} " f"that it did not declare. It declared: ({reducer.writes})!" ) + new_state = _apply_defaults(new_state, reducer.default_writes) _validate_reducer_writes(reducer, new_state, name) return _state_update(state, new_state) @@ -229,9 +245,11 @@ def _run_single_step_action( """ # TODO -- guard all reads/writes with a subset of the state action.validate_inputs(inputs) + state = _apply_defaults(state, action.default_reads) result, new_state = _adjust_single_step_output( action.run_and_update(state, **inputs), action.name ) + new_state = _apply_defaults(new_state, action.default_writes) _validate_result(result, action.name) out = result, _state_update(state, new_state) _validate_result(result, action.name) @@ -245,6 +263,7 @@ def _run_single_step_streaming_action( """Runs a single step streaming action. This API is internal-facing. This normalizes + validates the output.""" action.validate_inputs(inputs) + state = _apply_defaults(state, action.default_reads) generator = action.stream_run_and_update(state, **inputs) result = None state_update = None @@ -265,15 +284,33 @@ def _run_single_step_streaming_action( f"statement must be a tuple of (result, state_update). For example, yield dict(foo='bar'), state.update(foo='bar')" ) _validate_result(result, action.name) + state_update = _apply_defaults(state_update, action.default_writes) _validate_reducer_writes(action, state_update, action.name) yield result, state_update +async def _arun_single_step_action( + action: SingleStepAction, state: State, inputs: Optional[Dict[str, Any]] +) -> Tuple[dict, State]: + """Runs a single step action in async. See the synchronous version for more details.""" + state_to_use = state + state_to_use = _apply_defaults(state_to_use, action.default_reads) + action.validate_inputs(inputs) + result, new_state = _adjust_single_step_output( + await action.run_and_update(state_to_use, **inputs), action.name + ) + new_state = _apply_defaults(new_state, action.default_writes) + _validate_result(result, action.name) + _validate_reducer_writes(action, new_state, action.name) + return result, _state_update(state, new_state) + + async def _arun_single_step_streaming_action( action: SingleStepStreamingAction, state: State, inputs: Optional[Dict[str, Any]] ) -> AsyncGenerator[Tuple[dict, Optional[State]], None]: """Runs a single step streaming action in async. See the synchronous version for more details.""" action.validate_inputs(inputs) + state = _apply_defaults(state, action.default_reads) generator = action.stream_run_and_update(state, **inputs) result = None state_update = None @@ -294,6 +331,7 @@ async def _arun_single_step_streaming_action( f"statement must be a tuple of (result, state_update). For example, yield dict(foo='bar'), state.update(foo='bar')" ) _validate_result(result, action.name) + state_update = _apply_defaults(state_update, action.default_writes) _validate_reducer_writes(action, state_update, action.name) # TODO -- add guard against zero-length stream yield result, state_update @@ -310,6 +348,7 @@ def _run_multi_step_streaming_action( This peeks ahead by one so we know when this is done (and when to validate). """ action.validate_inputs(inputs) + state = _apply_defaults(state, action.default_reads) generator = action.stream_run(state, **inputs) result = None for item in generator: @@ -320,8 +359,9 @@ def _run_multi_step_streaming_action( result = item if next_result is not None: yield next_result, None - state_update = _run_reducer(action, state, result, action.name) _validate_result(result, action.name) + state_update = _run_reducer(action, state, result, action.name) + state_update = _apply_defaults(state_update, action.default_writes) _validate_reducer_writes(action, state_update, action.name) yield result, state_update @@ -331,6 +371,7 @@ async def _arun_multi_step_streaming_action( ) -> AsyncGenerator[Tuple[dict, Optional[State]], None]: """Runs a multi-step streaming action in async. See the synchronous version for more details.""" action.validate_inputs(inputs) + state = _apply_defaults(state, action.default_reads) generator = action.stream_run(state, **inputs) result = None async for item in generator: @@ -341,26 +382,13 @@ async def _arun_multi_step_streaming_action( result = item if next_result is not None: yield next_result, None - state_update = _run_reducer(action, state, result, action.name) _validate_result(result, action.name) + state_update = _run_reducer(action, state, result, action.name) + state_update = _apply_defaults(state_update, action.default_writes) _validate_reducer_writes(action, state_update, action.name) yield result, state_update -async def _arun_single_step_action( - action: SingleStepAction, state: State, inputs: Optional[Dict[str, Any]] -) -> Tuple[dict, State]: - """Runs a single step action in async. See the synchronous version for more details.""" - state_to_use = state - action.validate_inputs(inputs) - result, new_state = _adjust_single_step_output( - await action.run_and_update(state_to_use, **inputs), action.name - ) - _validate_result(result, action.name) - _validate_reducer_writes(action, new_state, action.name) - return result, _state_update(state, new_state) - - @dataclasses.dataclass class ApplicationGraph(Graph): """User-facing representation of the state machine. This has diff --git a/tests/core/test_application.py b/tests/core/test_application.py index d6c4257a..b287dafd 100644 --- a/tests/core/test_application.py +++ b/tests/core/test_application.py @@ -27,6 +27,7 @@ ApplicationBuilder, ApplicationContext, _adjust_single_step_output, + _apply_defaults, _arun_function, _arun_multi_step_streaming_action, _arun_single_step_action, @@ -60,6 +61,8 @@ def __init__( fn: Callable[..., dict], update_fn: Callable[[dict, State], State], inputs: list[str], + default_reads: Optional[Dict[str, Any]] = None, + default_writes: Optional[Dict[str, Any]] = None, ): super(PassedInAction, self).__init__() self._reads = reads @@ -67,6 +70,8 @@ def __init__( self._fn = fn self._update_fn = update_fn self._inputs = inputs + self._default_reads = default_reads if default_reads is not None else {} + self._default_writes = default_writes if default_writes is not None else {} def run(self, state: State, **run_kwargs) -> dict: return self._fn(state, **run_kwargs) @@ -75,6 +80,14 @@ def run(self, state: State, **run_kwargs) -> dict: def inputs(self) -> list[str]: return self._inputs + @property + def default_reads(self) -> Dict[str, Any]: + return self._default_reads + + @property + def default_writes(self) -> Dict[str, Any]: + return self._default_writes + def update(self, result: dict, state: State) -> State: return self._update_fn(result, state) @@ -95,8 +108,18 @@ def __init__( fn: Callable[..., Awaitable[dict]], update_fn: Callable[[dict, State], State], inputs: list[str], + default_reads: Optional[Dict[str, Any]] = None, + default_writes: Optional[Dict[str, Any]] = None, ): - super().__init__(reads=reads, writes=writes, fn=fn, update_fn=update_fn, inputs=inputs) # type: ignore + super().__init__( + reads=reads, + writes=writes, + fn=fn, + update_fn=update_fn, + inputs=inputs, + default_reads=default_reads, + default_writes=default_writes, + ) # type: ignore async def run(self, state: State, **run_kwargs) -> dict: return await self._fn(state, **run_kwargs) @@ -105,7 +128,7 @@ async def run(self, state: State, **run_kwargs) -> dict: base_counter_action = PassedInAction( reads=["count"], writes=["count"], - fn=lambda state: {"count": state.get("count", 0) + 1}, + fn=lambda state: {"count": state["count"] + 1}, update_fn=lambda result, state: state.update(**result), inputs=[], ) @@ -113,13 +136,21 @@ async def run(self, state: State, **run_kwargs) -> dict: base_counter_action_with_inputs = PassedInAction( reads=["count"], writes=["count"], - fn=lambda state, additional_increment: { - "count": state.get("count", 0) + 1 + additional_increment - }, + fn=lambda state, additional_increment: {"count": state["count"] + 1 + additional_increment}, update_fn=lambda result, state: state.update(**result), inputs=["additional_increment"], ) +base_counter_action_with_defaults = PassedInAction( + reads=["count"], + writes=["count", "error"], + fn=lambda state: {"count": state["count"] + 1}, + update_fn=lambda result, state: state.update(**result), + inputs=[], + default_reads={"count": 0}, + default_writes={"error": None}, +) + class ActionTracker(PreRunStepHook, PostRunStepHook): def __init__(self): @@ -172,7 +203,7 @@ async def post_run_step(self, *, action: Action, **future_kwargs): async def _counter_update_async(state: State, additional_increment: int = 0) -> dict: await asyncio.sleep(0.0001) # just so we can make this *truly* async # does not matter, but more accurately simulates an async function - return {"count": state.get("count", 0) + 1 + additional_increment} + return {"count": state["count"] + 1 + additional_increment} base_counter_action_async = PassedInActionAsync( @@ -193,6 +224,16 @@ async def _counter_update_async(state: State, additional_increment: int = 0) -> inputs=["additional_increment"], ) +base_counter_action_async_with_defaults = PassedInActionAsync( + reads=["count"], + writes=["count", "error"], + fn=_counter_update_async, + update_fn=lambda result, state: state.update(**result), + inputs=[], + default_reads={"count": 0}, + default_writes={"error": None}, +) + class BrokenStepException(Exception): pass @@ -236,18 +277,35 @@ async def incorrect(x): ) +def test__apply_defaults(): + state = State({"in_state": 0}) + defaults = {"in_state": 1, "not_in_state": 2} + result = _apply_defaults(state, defaults) + assert result.get("in_state") == 0 + assert result.get("not_in_state") == 2 + + def test__run_function(): """Tests that we can run a function""" action = base_counter_action - state = State({}) + state = State({"count": 0}) result = _run_function(action, state, inputs={}, name=action.name) assert result == {"count": 1} +def test__run_function_defaults(): + action = base_counter_action_with_defaults + state = State({}) + result = _run_function(action, state, inputs={}, name=action.name) + assert result == { + "count": 1 + } # default read is applied, write is not as it is a reducer capability, and is not part of the result + + def test__run_function_with_inputs(): """Tests that we can run a function""" action = base_counter_action_with_inputs - state = State({}) + state = State({"count": 0}) result = _run_function(action, state, inputs={"additional_increment": 1}, name=action.name) assert result == {"count": 2} @@ -255,7 +313,7 @@ def test__run_function_with_inputs(): def test__run_function_cant_run_async(): """Tests that we can't run an async function""" action = base_counter_action_async - state = State({}) + state = State({"count": 0}) with pytest.raises(ValueError, match="async"): _run_function(action, state, inputs={}, name=action.name) @@ -263,11 +321,19 @@ def test__run_function_cant_run_async(): def test__run_function_incorrect_result_type(): """Tests that we can run an async function""" action = base_action_incorrect_result_type - state = State({}) + state = State({"count": 0}) with pytest.raises(ValueError, match="returned a non-dict"): _run_function(action, state, inputs={}, name=action.name) +def test__run_reducer_applies_defaults(): + """Tests that we can run a reducer and it behaves as expected""" + reducer = base_counter_action_with_defaults + state = State({"count": 0}) + state = _run_reducer(reducer, state, {"count": 1}, "reducer") + assert state.get_all() == {"count": 1, "error": None} + + def test__run_reducer_modifies_state(): """Tests that we can run a reducer and it behaves as expected""" reducer = PassedInAction( @@ -299,6 +365,14 @@ def test__run_reducer_deletes_state(): async def test__arun_function(): """Tests that we can run an async function""" action = base_counter_action_async + state = State({"count": 0}) + result = await _arun_function(action, state, inputs={}, name=action.name) + assert result == {"count": 1} + + +async def test__arun_function_with_defaults(): + """Tests that we can run an async function""" + action = base_counter_action_async_with_defaults state = State({}) result = await _arun_function(action, state, inputs={}, name=action.name) assert result == {"count": 1} @@ -315,7 +389,7 @@ async def test__arun_function_incorrect_result_type(): async def test__arun_function_with_inputs(): """Tests that we can run an async function""" action = base_counter_action_with_inputs_async - state = State({}) + state = State({"count": 0}) result = await _arun_function( action, state, inputs={"additional_increment": 1}, name=action.name ) @@ -466,6 +540,20 @@ def inputs(self) -> list[str]: return ["additional_increment"] +class SingleStepCounterWithDefaults(SingleStepCounter): + @property + def default_reads(self) -> dict[str, Any]: + return {"count": 0, "tracker": []} + + @property + def default_writes(self) -> dict[str, Any]: + return {"error": None} + + @property + def writes(self) -> list[str]: + return super().writes + ["error"] + + class SingleStepActionIncorrectResultType(SingleStepAction): def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]: return "not a dict", state @@ -498,6 +586,20 @@ def writes(self) -> list[str]: return ["count", "tracker"] +class SingleStepCounterWithDefaultsAsync(SingleStepCounterAsync): + @property + def default_reads(self) -> dict[str, Any]: + return {"count": 0, "tracker": []} + + @property + def default_writes(self) -> dict[str, Any]: + return {"error": None} + + @property + def writes(self) -> list[str]: + return super().writes + ["error"] + + class SingleStepCounterWithInputsAsync(SingleStepCounterAsync): @property def inputs(self) -> list[str]: @@ -527,6 +629,20 @@ def update(self, result: dict, state: State) -> State: return state.update(**result).append(tracker=result["count"]) +class StreamingCounterWithDefaults(StreamingCounter): + @property + def default_reads(self) -> dict[str, Any]: + return {"count": 0, "tracker": []} + + @property + def default_writes(self) -> dict[str, Any]: + return {"error": None} + + @property + def writes(self) -> list[str]: + return super().writes + ["error"] + + class AsyncStreamingCounter(AsyncStreamingAction): async def stream_run(self, state: State, **run_kwargs) -> AsyncGenerator[dict, None]: if "steps_per_count" in run_kwargs: @@ -552,6 +668,20 @@ def update(self, result: dict, state: State) -> State: return state.update(**result).append(tracker=result["count"]) +class StreamingCounterWithDefaultsAsync(AsyncStreamingCounter): + @property + def default_reads(self) -> dict[str, Any]: + return {"count": 0, "tracker": []} + + @property + def default_writes(self) -> dict[str, Any]: + return {"error": None} + + @property + def writes(self) -> list[str]: + return super().writes + ["error"] + + class SingleStepStreamingCounter(SingleStepStreamingAction): def stream_run_and_update( self, state: State, **run_kwargs @@ -571,6 +701,20 @@ def writes(self) -> list[str]: return ["count", "tracker"] +class SingleStepStreamingCounterWithDefaults(SingleStepStreamingCounter): + @property + def default_reads(self) -> dict[str, Any]: + return {"count": 0, "tracker": []} + + @property + def default_writes(self) -> dict[str, Any]: + return {"error": None} + + @property + def writes(self) -> list[str]: + return super().writes + ["error"] + + class SingleStepStreamingCounterAsync(SingleStepStreamingAction): async def stream_run_and_update( self, state: State, **run_kwargs @@ -592,6 +736,20 @@ def writes(self) -> list[str]: return ["count", "tracker"] +class SingleStepStreamingCounterWithDefaultsAsync(SingleStepStreamingCounterAsync): + @property + def default_reads(self) -> dict[str, Any]: + return {"count": 0, "tracker": []} + + @property + def default_writes(self) -> dict[str, Any]: + return {"error": None} + + @property + def writes(self) -> list[str]: + return super().writes + ["error"] + + class StreamingActionIncorrectResultType(StreamingAction): def stream_run(self, state: State, **run_kwargs) -> Generator[dict, None, dict]: yield {} @@ -662,12 +820,20 @@ def writes(self) -> list[str]: base_single_step_counter_async = SingleStepCounterAsync() base_single_step_counter_with_inputs = SingleStepCounterWithInputs() base_single_step_counter_with_inputs_async = SingleStepCounterWithInputsAsync() +base_single_step_counter_with_defaults = SingleStepCounterWithDefaults() +base_single_step_counter_with_defaults_async = SingleStepCounterWithDefaultsAsync() base_streaming_counter = StreamingCounter() +base_streaming_counter_with_defaults = StreamingCounterWithDefaults() base_streaming_single_step_counter = SingleStepStreamingCounter() +base_streaming_single_step_counter_with_defaults = SingleStepStreamingCounterWithDefaults() base_streaming_counter_async = AsyncStreamingCounter() +base_streaming_counter_with_defaults_async = StreamingCounterWithDefaultsAsync() base_streaming_single_step_counter_async = SingleStepStreamingCounterAsync() +base_streaming_single_step_counter_with_defaults_async = ( + SingleStepStreamingCounterWithDefaultsAsync() +) base_single_step_action_incorrect_result_type = SingleStepActionIncorrectResultType() base_single_step_action_incorrect_result_type_async = SingleStepActionIncorrectResultTypeAsync() @@ -709,6 +875,18 @@ def test__run_single_step_action_with_inputs(): assert state.subset("count", "tracker").get_all() == {"count": 4, "tracker": [2, 4]} +def test__run_single_step_action_with_defaults(): + action = base_single_step_counter_with_defaults.with_name("counter") + state = State({}) + result, state = _run_single_step_action(action, state, {}) + assert result == {"count": 1} + assert state.subset("count", "tracker", "error").get_all() == { + "count": 1, + "tracker": [1], + "error": None, + } + + async def test__arun_single_step_action(): action = base_single_step_counter_async.with_name("counter") state = State({"count": 0, "tracker": []}) @@ -735,6 +913,18 @@ async def test__arun_single_step_action_with_inputs(): assert state.subset("count", "tracker").get_all() == {"count": 4, "tracker": [2, 4]} +async def test__arun_single_step_action_with_defaults(): + action = base_single_step_counter_with_defaults_async.with_name("counter") + state = State({}) + result, state = await _arun_single_step_action(action, state, {}) + assert result == {"count": 1} + assert state.subset("count", "tracker", "error").get_all() == { + "count": 1, + "tracker": [1], + "error": None, + } + + class SingleStepActionWithDeletion(SingleStepAction): def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]: return {}, state.wipe(delete=["to_delete"]) @@ -770,7 +960,26 @@ def test__run_multistep_streaming_action(): assert state.subset("count", "tracker").get_all() == {"count": 1, "tracker": [1]} -async def test__run_multistep_streaming_action_async(): +def test__run_multistep_streaming_action_default(): + action = base_streaming_counter_with_defaults.with_name("counter") + state = State({}) + generator = _run_multi_step_streaming_action(action, state, inputs={}) + last_result = -1 + result = None + for result, state in generator: + if last_result < 1: + # Otherwise you hit floating poit comparison problems + assert result["count"] > last_result + last_result = result["count"] + assert result == {"count": 1} + assert state.subset("count", "tracker", "error").get_all() == { + "count": 1, + "tracker": [1], + "error": None, + } + + +async def test__arun_multistep_streaming_action(): action = base_streaming_counter_async.with_name("counter") state = State({"count": 0, "tracker": []}) generator = _arun_multi_step_streaming_action(action, state, inputs={}) @@ -785,6 +994,25 @@ async def test__run_multistep_streaming_action_async(): assert state.subset("count", "tracker").get_all() == {"count": 1, "tracker": [1]} +async def test__arun_multistep_streaming_action_with_defaults(): + action = base_streaming_counter_with_defaults_async.with_name("counter") + state = State({}) + generator = _arun_multi_step_streaming_action(action, state, inputs={}) + last_result = -1 + result = None + async for result, state in generator: + if last_result < 1: + # Otherwise you hit floating poit comparison problems + assert result["count"] > last_result + last_result = result["count"] + assert result == {"count": 1} + assert state.subset("count", "tracker", "error").get_all() == { + "count": 1, + "tracker": [1], + "error": None, + } + + def test__run_streaming_action_incorrect_result_type(): action = StreamingActionIncorrectResultType() state = State() @@ -793,7 +1021,7 @@ def test__run_streaming_action_incorrect_result_type(): collections.deque(gen, maxlen=0) # exhaust the generator -async def test__run_streaming_action_incorrect_result_type_async(): +async def test__arun_streaming_action_incorrect_result_type(): action = StreamingActionIncorrectResultTypeAsync() state = State() with pytest.raises(ValueError, match="returned a non-dict"): @@ -810,7 +1038,7 @@ def test__run_single_step_streaming_action_incorrect_result_type(): collections.deque(gen, maxlen=0) # exhaust the generator -async def test__run_single_step_streaming_action_incorrect_result_type_async(): +async def test__arun_single_step_streaming_action_incorrect_result_type(): action = StreamingSingleStepActionIncorrectResultTypeAsync() state = State() with pytest.raises(ValueError, match="returned a non-dict"): @@ -834,7 +1062,27 @@ def test__run_single_step_streaming_action(): assert state.subset("count", "tracker").get_all() == {"count": 1, "tracker": [1]} -async def test__run_single_step_streaming_action_async(): +def test__run_single_step_streaming_with_defaults(): + action = base_streaming_single_step_counter_with_defaults.with_name("counter") + state = State() + generator = _run_single_step_streaming_action(action, state, inputs={}) + last_result = -1 + result, state = None, None + for result, state in generator: + if last_result < 1: + # Otherwise you hit comparison issues + # This is because we get to the last one, which is the final result + assert result["count"] > last_result + last_result = result["count"] + assert result == {"count": 1} + assert state.subset("count", "tracker", "error").get_all() == { + "count": 1, + "tracker": [1], + "error": None, + } + + +async def test__arun_single_step_streaming_action(): async_action = base_streaming_single_step_counter_async.with_name("counter") state = State({"count": 0, "tracker": []}) generator = _arun_single_step_streaming_action(async_action, state, inputs={}) @@ -850,6 +1098,26 @@ async def test__run_single_step_streaming_action_async(): assert state.subset("count", "tracker").get_all() == {"count": 1, "tracker": [1]} +async def test__arun_single_step_streaming_action_with_defaults(): + async_action = base_streaming_single_step_counter_with_defaults_async.with_name("counter") + state = State({}) + generator = _arun_single_step_streaming_action(async_action, state, inputs={}) + last_result = -1 + result, state = None, None + async for result, state in generator: + if last_result < 1: + # Otherwise you hit comparison issues + # This is because we get to the last one, which is the final result + assert result["count"] > last_result + last_result = result["count"] + assert result == {"count": 1} + assert state.subset("count", "tracker", "error").get_all() == { + "count": 1, + "tracker": [1], + "error": None, + } + + class SingleStepActionWithDeletionAsync(SingleStepActionWithDeletion): async def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]: return {}, state.wipe(delete=["to_delete"]) @@ -866,7 +1134,7 @@ def test_app_step(): """Tests that we can run a step in an app""" counter_action = base_counter_action.with_name("counter") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter", partition_key="test", uid="test-123", @@ -945,7 +1213,7 @@ def test_app_step_done(): """Tests that when we cannot run a step, we return None""" counter_action = base_counter_action.with_name("counter") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter", partition_key="test", uid="test-123", @@ -963,7 +1231,7 @@ async def test_app_astep(): """Tests that we can run an async step in an app""" counter_action = base_counter_action_async.with_name("counter_async") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter_async", partition_key="test", uid="test-123", @@ -1022,7 +1290,7 @@ async def test_app_astep_broken(caplog): """Tests that we can run a step in an app""" broken_action = base_broken_action_async.with_name("broken_action_unique_name") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="broken_action_unique_name", partition_key="test", uid="test-123", @@ -1042,7 +1310,7 @@ async def test_app_astep_done(): """Tests that when we cannot run a step, we return None""" counter_action = base_counter_action_async.with_name("counter_async") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter_async", partition_key="test", uid="test-123", @@ -1060,7 +1328,7 @@ async def test_app_astep_done(): def test_app_many_steps(): counter_action = base_counter_action.with_name("counter") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter", partition_key="test", uid="test-123", @@ -1080,7 +1348,7 @@ def test_app_many_steps(): async def test_app_many_a_steps(): counter_action = base_counter_action_async.with_name("counter_async") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter_async", partition_key="test", uid="test-123", @@ -1101,7 +1369,7 @@ def test_iterate(): result_action = Result("count").with_name("result") counter_action = base_counter_action.with_name("counter") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter", partition_key="test", uid="test-123", @@ -1141,7 +1409,7 @@ def test_iterate_with_inputs(): result_action = Result("count").with_name("result") counter_action = base_counter_action_with_inputs.with_name("counter") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter", partition_key="test", uid="test-123", @@ -1170,7 +1438,7 @@ async def test_aiterate(): result_action = Result("count").with_name("result") counter_action = base_counter_action_async.with_name("counter") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter", partition_key="test", uid="test-123", @@ -1202,7 +1470,7 @@ async def test_aiterate_halt_before(): result_action = Result("count").with_name("result") counter_action = base_counter_action_async.with_name("counter") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter", partition_key="test", uid="test-123", @@ -1232,7 +1500,7 @@ async def test_app_aiterate_with_inputs(): result_action = Result("count").with_name("result") counter_action = base_counter_action_with_inputs_async.with_name("counter") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter", partition_key="test", uid="test-123", @@ -1257,7 +1525,7 @@ def test_run(): result_action = Result("count").with_name("result") counter_action = base_counter_action.with_name("counter") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter", partition_key="test", uid="test-123", @@ -1279,7 +1547,7 @@ def test_run_halt_before(): result_action = Result("count").with_name("result") counter_action = base_counter_action.with_name("counter") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter", partition_key="test", uid="test-123", @@ -1302,7 +1570,7 @@ def test_run_with_inputs(): result_action = Result("count").with_name("result") counter_action = base_counter_action_with_inputs.with_name("counter") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter", partition_key="test", uid="test-123", @@ -1326,7 +1594,7 @@ def test_run_with_inputs_multiple_actions(): counter_action1 = base_counter_action_with_inputs.with_name("counter1") counter_action2 = base_counter_action_with_inputs.with_name("counter2") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter1", partition_key="test", uid="test-123", @@ -1351,7 +1619,7 @@ async def test_arun(): result_action = Result("count").with_name("result") counter_action = base_counter_action_async.with_name("counter") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter", partition_key="test", uid="test-123", @@ -1373,7 +1641,7 @@ async def test_arun_halt_before(): result_action = Result("count").with_name("result") counter_action = base_counter_action_async.with_name("counter") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter", partition_key="test", uid="test-123", @@ -1396,7 +1664,7 @@ async def test_arun_with_inputs(): result_action = Result("count").with_name("result") counter_action = base_counter_action_with_inputs_async.with_name("counter") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter", partition_key="test", uid="test-123", @@ -1421,7 +1689,7 @@ async def test_arun_with_inputs_multiple_actions(): counter_action1 = base_counter_action_with_inputs_async.with_name("counter1") counter_action2 = base_counter_action_with_inputs_async.with_name("counter2") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter1", partition_key="test", uid="test-123", @@ -1449,7 +1717,7 @@ async def test_app_a_run_async_and_sync(): counter_action_sync = base_counter_action_async.with_name("counter_sync") counter_action_async = base_counter_action_async.with_name("counter_async") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter_sync", partition_key="test", uid="test-123", @@ -1878,7 +2146,7 @@ async def test_astream_result_halt_before(): def test_app_set_state(): counter_action = base_counter_action.with_name("counter") app = Application( - state=State(), + state=State({"count": 0}), entrypoint="counter", partition_key="test", uid="test-123", @@ -1901,7 +2169,7 @@ def test_app_get_next_step(): counter_action_2 = base_counter_action.with_name("counter_2") counter_action_3 = base_counter_action.with_name("counter_3") app = Application( - state=State(), + state=State({"count": 0}), entrypoint="counter_1", partition_key="test", uid="test-123", @@ -1988,7 +2256,7 @@ def test_application_run_step_hooks_sync(): counter_action = base_counter_action.with_name("counter") result_action = Result("count").with_name("result") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter", adapter_set=internal.LifecycleAdapterSet(tracker), partition_key="test", @@ -2035,7 +2303,7 @@ async def test_application_run_step_hooks_async(): counter_action = base_counter_action.with_name("counter") result_action = Result("count").with_name("result") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter", adapter_set=internal.LifecycleAdapterSet(tracker), partition_key="test", @@ -2079,7 +2347,7 @@ async def test_application_run_step_runs_hooks(): counter_action = base_counter_action.with_name("counter") app = Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter", adapter_set=internal.LifecycleAdapterSet(*hooks), partition_key="test", @@ -2147,7 +2415,7 @@ def post_application_create(self, **kwargs): counter_action = base_counter_action.with_name("counter") result_action = Result("count").with_name("result") Application( - state=State({}), + state=State({"count": 0}), entrypoint="counter", adapter_set=internal.LifecycleAdapterSet(tracker), partition_key="test",