From 1b6bdb48864c6a06cad274ac88459ad07d25dda0 Mon Sep 17 00:00:00 2001 From: elijahbenizzy Date: Sun, 25 Feb 2024 12:34:49 -0800 Subject: [PATCH] Solves #31 -- enables deletion operation We weren't respecting it before. This handles that. This just measures the delete operations and reapplies it. That said, this is not the best way of doing things -- see #33 for a more involved approach. Note we've also relaxed the restrictions on modifications for delete/write -- this is necessary for the current workaround. --- burr/core/application.py | 54 +++++++++++++++++++++--------- tests/core/test_application.py | 61 ++++++++++++++++++++++++++++++++++ 2 files changed, 100 insertions(+), 15 deletions(-) diff --git a/burr/core/application.py b/burr/core/application.py index 89bb529a..2dd6a906 100644 --- a/burr/core/application.py +++ b/burr/core/application.py @@ -75,6 +75,33 @@ async def _arun_function(function: Function, state: State, inputs: Dict[str, Any return await function.run(state_to_use, **inputs) +def _state_update(state_to_modify: State, modified_state: State) -> State: + """This is a hack to apply state updates and ensure that we are respecting deletions. Specifically, the process is: + + 1. We subset the state to what we want to read + 2. We perform a set of state-specific writes to it + 3. We measure which ones were deleted + 4. We then merge the whole state back in + 5. We then delete the keys that were deleted + + This is suboptimal -- we should not be observing the state, we should be using the state commands and layering in deltas. + That said, we currently eagerly evaluate the state at all operations, which means we have to do it this way. See + https://github.com/DAGWorks-Inc/burr/issues/33 for a more details plan. + + This function was written to solve this issue: https://github.com/DAGWorks-Inc/burr/issues/28. + + + :param state_subset_pre_update: The subset of state passed to the update() function + :param modified_state: The subset of state realized after the update() function + :param state_to_modify: The state to modify-- this is the original + :return: + """ + old_state_keys = set(state_to_modify.keys()) + new_state_keys = set(modified_state.keys()) + deleted_keys = list(old_state_keys - new_state_keys) + return state_to_modify.merge(modified_state).wipe(delete=deleted_keys) + + def _run_reducer(reducer: Reducer, state: State, result: dict, name: str) -> State: """Runs the reducer, returning the new state. Note this restricts the keys in the state to only those that the function writes. @@ -84,17 +111,17 @@ def _run_reducer(reducer: Reducer, state: State, result: dict, name: str) -> Sta :param result: :return: """ - - state_to_use = state.subset(*reducer.writes) - new_state = reducer.update(result, state_to_use).subset(*reducer.writes) + # TODO -- better guarding on state reads/writes + new_state = reducer.update(result, state) keys_in_new_state = set(new_state.keys()) - extra_keys = keys_in_new_state - set(reducer.writes) - if extra_keys: + new_keys = keys_in_new_state - set(state.keys()) + extra_keys = new_keys - set(reducer.writes) + if len(extra_keys) > 0: raise ValueError( f"Action {name} attempted to write to keys {extra_keys} " f"that it did not declare. It declared: ({reducer.writes})!" ) - return state.merge(new_state) + return _state_update(state, new_state) def _create_dict_string(kwargs: dict) -> str: @@ -142,24 +169,21 @@ def _run_single_step_action( :param inputs: Inputs to pass directly to the action :return: The result of running the action, and the new state """ - state_to_use = state.subset( - *action.reads, *action.writes - ) # TODO -- specify some as required and some as not + # TODO -- guard all reads/writes with a subset of the state action.validate_inputs(inputs) - result, new_state = action.run_and_update(state_to_use, **inputs) - return result, state.merge(new_state.subset(*action.writes)) # we just want the writes action + result, new_state = action.run_and_update(state, **inputs) + out = result, _state_update(state, new_state) + return out async def _arun_single_step_action( action: SingleStepAction, state: State, inputs: Optional[Dict[str, Any]] ) -> Tuple[dict, State]: """Runs a single step action in async. See the synchronous version for more details.""" - state_to_use = state.subset( - *action.reads, *action.writes - ) # TODO -- specify some as required and some as not + state_to_use = state action.validate_inputs(inputs) result, new_state = await action.run_and_update(state_to_use, **inputs) - return result, state.merge(new_state.subset(*action.writes)) # we just want the writes action + return result, _state_update(state, new_state) @dataclasses.dataclass diff --git a/tests/core/test_application.py b/tests/core/test_application.py index 32b764c2..3814978c 100644 --- a/tests/core/test_application.py +++ b/tests/core/test_application.py @@ -15,6 +15,7 @@ _arun_single_step_action, _assert_set, _run_function, + _run_reducer, _run_single_step_action, _validate_actions, _validate_start, @@ -169,6 +170,34 @@ def test__run_function_cant_run_async(): _run_function(action, state, inputs={}) +def test__run_reducer_modifies_state(): + """Tests that we can run a reducer and it behaves as expected""" + reducer = PassedInAction( + reads=["counter"], + writes=["counter"], + fn=..., + update_fn=lambda result, state: state.update(**result), + inputs=[], + ) + state = State({"counter": 0}) + state = _run_reducer(reducer, state, {"counter": 1}, "reducer") + assert state["counter"] == 1 + + +def test__run_reducer_deletes_state(): + """Tests that we can run a reducer that deletes an item from state""" + reducer = PassedInAction( + reads=["counter"], + writes=[], # TODO -- figure out how we can better know that it deletes items...ß + fn=..., + update_fn=lambda result, state: state.wipe(delete=["counter"]), + inputs=[], + ) + state = State({"counter": 0}) + state = _run_reducer(reducer, state, {}, "deletion_reducer") + assert "counter" not in state + + async def test__arun_function(): """Tests that we can run an async function""" action = base_counter_action_async @@ -279,6 +308,38 @@ async def test__arun_single_step_action_with_inputs(): assert state.subset("count", "tracker").get_all() == {"count": 4, "tracker": [2, 4]} +class SingleStepActionWithDeletion(SingleStepAction): + def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]: + return {}, state.wipe(delete=["to_delete"]) + + @property + def reads(self) -> list[str]: + return ["to_delete"] + + @property + def writes(self) -> list[str]: + return ["to_delete"] + + +def test__run_single_step_action_deletes_state(): + action = SingleStepActionWithDeletion() + state = State({"to_delete": 0}) + result, state = _run_single_step_action(action, state, inputs={}) + assert "to_delete" not in state + + +class SingleStepActionWithDeletionAsync(SingleStepActionWithDeletion): + async def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]: + return {}, state.wipe(delete=["to_delete"]) + + +async def test__arun_single_step_action_deletes_state(): + action = SingleStepActionWithDeletionAsync() + state = State({"to_delete": 0}) + result, state = await _arun_single_step_action(action, state, inputs={}) + assert "to_delete" not in state + + def test_app_step(): """Tests that we can run a step in an app""" counter_action = base_counter_action.with_name("counter")