diff --git a/burr/integrations/pydantic.py b/burr/integrations/pydantic.py index a42deeda..3ce4b97c 100644 --- a/burr/integrations/pydantic.py +++ b/burr/integrations/pydantic.py @@ -113,14 +113,13 @@ def _validate_and_extract_signature_types( ) type_hints = typing.get_type_hints(fn) - if (state_model := type_hints["state"]) is inspect.Parameter.empty or not issubclass( - state_model, pydantic.BaseModel - ): + state_model = type_hints.get("state") + if state_model is None or state_model is inspect.Parameter.empty or not issubclass(state_model, pydantic.BaseModel): raise ValueError( f"Function fn: {fn.__qualname__} is not a valid pydantic action. " - "a type annotation of a type extending: pydantic.BaseModel. Got parameter " - "state: {state_model.__qualname__}." + "The 'state' parameter must be annotated with a type extending pydantic.BaseModel." ) + if (ret_hint := type_hints.get("return")) is None or not issubclass( ret_hint, pydantic.BaseModel ): diff --git a/tests/integrations/test_burr_pydantic.py b/tests/integrations/test_burr_pydantic.py index 0d330f0d..4f9523d0 100644 --- a/tests/integrations/test_burr_pydantic.py +++ b/tests/integrations/test_burr_pydantic.py @@ -154,27 +154,27 @@ def test_model_from_state(): def _fn_without_state_arg(foo: OriginalModel) -> OriginalModel: - ... + return foo; def _fn_with_incorrect_state_arg(state: int) -> OriginalModel: - ... + return OriginalModel(foo, bar) def _fn_with_incorrect_return_type(state: OriginalModel) -> int: - ... + return 42 def _fn_with_no_return_type(state: OriginalModel): - ... + pass def _fn_correct_same_itype_otype(state: OriginalModel, input_1: int) -> OriginalModel: - ... + return state def _fn_correct_diff_itype_otype(state: OriginalModel, input_1: int) -> NestedModel: - ... + return NestedModel(nested_field1=input_1) @pytest.mark.parametrize(