Skip to content

Commit

Permalink
Fix: model validation not being strict, thus strings were being coerc…
Browse files Browse the repository at this point in the history
…ed. Add: test to check if strings are not being coerced
  • Loading branch information
GeraldIr committed Jul 23, 2024
1 parent 2c17def commit 3396ae6
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 20 deletions.
34 changes: 17 additions & 17 deletions openeo_pg_parser_networkx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def _parse_datamodel(nested_graph: dict) -> ProcessGraph:
Parses a nested process graph into the Pydantic datamodel for ProcessGraph.
"""

return ProcessGraph.model_validate(nested_graph)
return ProcessGraph.model_validate(nested_graph, strict=True)

def _parse_process_graph(self, process_graph: ProcessGraph, arg_name: str = None):
"""
Expand Down Expand Up @@ -187,11 +187,11 @@ def _parse_argument(self, arg: any, arg_name: str, access_func: Callable):

# This access func business is necessary to let the program "remember" how to access and thus update this reference later
sub_access_func = partial(
lambda key, access_func, new_value=None, set_bool=False: access_func()[
key
]
if not set_bool
else access_func().__setitem__(key, new_value),
lambda key, access_func, new_value=None, set_bool=False: (
access_func()[key]
if not set_bool
else access_func().__setitem__(key, new_value)
),
key=k,
access_func=access_func,
)
Expand All @@ -205,11 +205,11 @@ def _parse_argument(self, arg: any, arg_name: str, access_func: Callable):
parsed_arg = parse_nested_parameter(element)

sub_access_func = partial(
lambda key, access_func, new_value=None, set_bool=False: access_func()[
key
]
if not set_bool
else access_func().__setitem__(key, new_value),
lambda key, access_func, new_value=None, set_bool=False: (
access_func()[key]
if not set_bool
else access_func().__setitem__(key, new_value)
),
key=i,
access_func=access_func,
)
Expand Down Expand Up @@ -246,12 +246,12 @@ def _walk_node(self):

# This just points to the resolved_kwarg itself!
access_func = partial(
lambda node_uid, arg_name, new_value=None, set_bool=False: self.G.nodes[
node_uid
]["resolved_kwargs"][arg_name]
if not set_bool
else self.G.nodes[node_uid]["resolved_kwargs"].__setitem__(
arg_name, new_value
lambda node_uid, arg_name, new_value=None, set_bool=False: (
self.G.nodes[node_uid]["resolved_kwargs"][arg_name]
if not set_bool
else self.G.nodes[node_uid]["resolved_kwargs"].__setitem__(
arg_name, new_value
)
),
node_uid=self._EVAL_ENV.node_uid,
arg_name=arg_name,
Expand Down
5 changes: 3 additions & 2 deletions openeo_pg_parser_networkx/pg_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
RootModel,
StringConstraints,
ValidationError,
conlist,
constr,
field_validator,
model_validator,
Expand Down Expand Up @@ -259,7 +258,9 @@ def __repr__(self):


class TemporalInterval(RootModel):
root: conlist(Union[Year, Date, DateTime, Time, None], min_length=2, max_length=2)
root: Annotated[
list[Union[Year, Date, DateTime, Time, None]], Field(min_length=2, max_length=2)
]

@field_validator("root")
def validate_temporal_interval(cls, value: Any) -> Any:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "openeo-pg-parser-networkx"
version = "2024.5.0"
version = "2024.7.0"

description = "Parse OpenEO process graphs from JSON to traversible Python objects."
authors = ["Lukas Weidenholzer <[email protected]>", "Sean Hoyal <[email protected]>", "Valentina Hutter <[email protected]>", "Gerald Irsiegler <[email protected]>"]
Expand Down
36 changes: 36 additions & 0 deletions tests/test_pg_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,42 @@ def test_bounding_box_with_faulty_crs(get_process_graph_with_args):
]


def test_string_validation(get_process_graph_with_args):
'''
During the pydantic 2 update, we found that some special strings get parsed
to non-string values ('t' to True, 'f' to False, etc.)
Check that every incoming string stays a string by default
'''

test_args = {
'arg_t': 't',
'arg_f': 'f',
'arg_str': 'arg_123_str',
'arg_int': '123',
'arg_float': '123.4',
}

pg = get_process_graph_with_args(test_args)

# Parse indirectly to check if model validation is strict and does not type coerce
parsed_graph = OpenEOProcessGraph(pg_data=pg)

# Parse directly to check if strict model validation works seperately
parsed_args = [
ProcessGraph.model_validate(pg, strict=True)
.process_graph[TEST_NODE_KEY]
.arguments[arg_name]
for arg_name in test_args.keys()
]

resolved_kwargs = parsed_graph.nodes[0][1]['resolved_kwargs'].items()

assert all([isinstance(resolved_kwarg, str) for _, resolved_kwarg in resolved_kwargs])

assert all([isinstance(parsed_arg, str) for parsed_arg in parsed_args])


def test_bounding_box_int_crs(get_process_graph_with_args):
pg = get_process_graph_with_args(
{'spatial_extent': {'west': 0, 'east': 10, 'south': 0, 'north': 10, 'crs': 4326}}
Expand Down

0 comments on commit 3396ae6

Please sign in to comment.