Skip to content

Commit

Permalink
Merge pull request #148 from NOAA-OWP/pairing_dict_hotfix
Browse files Browse the repository at this point in the history
Apply pairing dict hotfix and setup new test case
  • Loading branch information
fernando-aristizabal authored Aug 1, 2023
2 parents f121ad5 + ef7dbb5 commit c0ed2a1
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 20 deletions.
13 changes: 9 additions & 4 deletions src/gval/accessors/gval_xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def __handle_attribute_tracking(

return results

@Comparison.comparison_function_from_string
def categorical_compare(
self,
benchmark_map: Union[gpd.GeoDataFrame, xr.Dataset, xr.DataArray],
Expand Down Expand Up @@ -465,6 +466,7 @@ def compute_agreement_map(

return agreement_map

@Comparison.comparison_function_from_string
def compute_crosstab(
self,
benchmark_map: Union[xr.Dataset, xr.DataArray],
Expand All @@ -474,6 +476,7 @@ def compute_crosstab(
comparison_function: Optional[
Union[Callable, nb.np.ufunc.dufunc.DUFunc, np.ufunc, np.vectorize, str]
] = "szudzik",
pairing_dict: Optional[Dict[Tuple[Number, Number], Number]] = None,
) -> DataFrame[Crosstab_df]:
"""
Crosstab 2 or 3-dimensional xarray DataArray to produce Crosstab DataFrame.
Expand All @@ -490,6 +493,12 @@ def compute_crosstab(
Value to exclude from crosstab. This could be used to denote a no data value if masking wasn't used. By default, NaNs are not cross-tabulated.
comparison_function : Optional[Union[Callable, nb.np.ufunc.dufunc.DUFunc, np.ufunc, np.vectorize, str]], default = "szudzik"
Function to compute agreement values. If None, then no agreement values are computed.
pairing_dict: Optional[Dict[Tuple[Number, Number], Number]], default = None
When "pairing_dict" is used for the comparison_function argument, a pairing dictionary can be passed by user. A pairing dictionary is structured as `{(c, b) : a}` where `(c, b)` is a tuple of the candidate and benchmark value pairing, respectively, and `a` is the value for the agreement array to be used for this pairing.
If None is passed for pairing_dict, the allow_candidate_values and allow_benchmark_values arguments are required. For this case, the pairings in these two iterables will be paired in the order provided and an agreement value will be assigned to each pairing starting with 0 and ending with the number of possible pairings.
A pairing dictionary can be used to specify which values to allow and which to ignore in comparisons. It can also be used to decide how NaNs are handled in cases where the candidate map, the benchmark map, or both contain NaNs.
Returns
-------
Expand All @@ -498,10 +507,6 @@ def compute_crosstab(
"""
self.check_same_type(benchmark_map)

# NOTE: Temporary fix until better solution is found
if isinstance(comparison_function, str):
comparison_function = getattr(Comparison, comparison_function)

if isinstance(self._obj, xr.Dataset):
return _crosstab_Datasets(
self._obj,
Expand Down
67 changes: 59 additions & 8 deletions src/gval/comparison/compute_comparison.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union
from typing import Union, Callable
from functools import wraps
import inspect
from numbers import Number
Expand All @@ -11,6 +11,7 @@
cantor_pair_signed,
szudzik_pair_signed,
difference,
_make_pairing_dict_fn,
)
from gval.comparison.agreement import _compute_agreement_map

Expand Down Expand Up @@ -233,12 +234,62 @@ def process_agreement_map(self, **kwargs) -> Union[xr.DataArray, xr.Dataset]:
Agreement map.
"""

if isinstance(kwargs["comparison_function"], str):
if kwargs["comparison_function"] in self.registered_functions:
kwargs["comparison_function"] = getattr(
self, kwargs["comparison_function"]
)
return self.comparison_function_from_string(func=_compute_agreement_map)(
**kwargs
)

def comparison_function_from_string(
self, func: Callable
) -> Callable: # pragma: no cover
"""
Decorator function to compose a pairing dict comparison function from a string argument
Parameters
----------
func: Callable
Function requiring check for pairing_dict comparison function
Returns
-------
Callable
Function with appropriate comparison function
"""

@wraps(func)
def wrapper(*args, **kwargs):
# NOTE: Temporary fix until better solution is found
if "comparison_function" in kwargs and isinstance(
kwargs["comparison_function"], str
):
if kwargs["comparison_function"] in self.registered_functions:
kwargs["comparison_function"] = getattr(
self, kwargs["comparison_function"]
)
else:
raise KeyError("Pairing function not found in registered functions")

# In case the arguments do not exist
kwargs["pairing_dict"] = kwargs.get("pairing_dict")
kwargs["allow_candidate_values"] = kwargs.get("allow_candidate_values")
kwargs["allow_benchmark_values"] = kwargs.get("allow_benchmark_values")

if (
kwargs["comparison_function"] == "pairing_dict"
): # when pairing_dict is a dict
# this creates the pairing dictionary from the passed allowed values
kwargs["comparison_function"] = _make_pairing_dict_fn(
pairing_dict=kwargs["pairing_dict"],
unique_candidate_values=kwargs["allow_candidate_values"],
unique_benchmark_values=kwargs["allow_benchmark_values"],
)

else:
raise KeyError("Pairing function not found in registered functions")
kwargs["comparison_function"] = getattr(self, "szudzik")

# Call the decorated function
result = func(*args, **kwargs)

return result

return _compute_agreement_map(**kwargs)
return wrapper
50 changes: 43 additions & 7 deletions tests/cases_accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,23 +73,57 @@
]


positive_cat = np.array([2, 2, 2, 2, 2, 2])
negative_cat = np.array([[0, 1], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1]])
rasterize_attrs = [None, None, ["category"], None, None, None]
memory_strategy = ["normal", "normal", "normal", "normal", "moderate", "aggressive"]
positive_cat = np.array([2, 2, 2, 2, 2, 2, 2])
negative_cat = np.array([[0, 1], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1]])
rasterize_attrs = [None, None, ["category"], None, None, None, None]
memory_strategy = [
"normal",
"normal",
"normal",
"normal",
"moderate",
"aggressive",
"normal",
]
exception_list = [OSError, ValueError, TypeError]
comparison_funcs = [
"szudzik",
"szudzik",
"szudzik",
"szudzik",
"szudzik",
"szudzik",
"pairing_dict",
]


@parametrize(
"candidate_map, benchmark_map, positive_categories, negative_categories, rasterize_attributes, memory_strategies",
"candidate_map, benchmark_map, positive_categories, negative_categories, rasterize_attributes, memory_strategies, comparison_function",
list(
zip(
candidate_maps,
benchmark_maps,
[
candidate_maps[0],
candidate_maps[1],
candidate_maps[2],
candidate_maps[3],
candidate_maps[4],
candidate_maps[5],
candidate_maps[0],
],
[
benchmark_maps[0],
benchmark_maps[1],
benchmark_maps[2],
benchmark_maps[3],
benchmark_maps[4],
benchmark_maps[5],
benchmark_maps[0],
],
positive_cat,
negative_cat,
rasterize_attrs,
memory_strategy,
comparison_funcs,
)
),
)
Expand All @@ -100,6 +134,7 @@ def case_data_array_accessor_success(
negative_categories,
rasterize_attributes,
memory_strategies,
comparison_function,
):
return (
candidate_map,
Expand All @@ -108,6 +143,7 @@ def case_data_array_accessor_success(
negative_categories,
rasterize_attributes,
memory_strategies,
comparison_function,
)


Expand Down
6 changes: 5 additions & 1 deletion tests/test_accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


@parametrize_with_cases(
"candidate_map, benchmark_map, positive_categories, negative_categories, rasterize_attributes, memory_strategies",
"candidate_map, benchmark_map, positive_categories, negative_categories, rasterize_attributes, memory_strategies, comparison_function",
glob="data_array_accessor_success",
)
def test_data_array_accessor_success(
Expand All @@ -22,6 +22,7 @@ def test_data_array_accessor_success(
negative_categories,
rasterize_attributes,
memory_strategies,
comparison_function,
):
adjust_memory_strategy(memory_strategies)

Expand All @@ -34,6 +35,9 @@ def test_data_array_accessor_success(
positive_categories=positive_categories,
negative_categories=negative_categories,
rasterize_attributes=rasterize_attributes,
comparison_function=comparison_function,
allow_candidate_values=[1, 2, np.nan],
allow_benchmark_values=[0, 2, np.nan],
)

if rasterize_attributes is None:
Expand Down

0 comments on commit c0ed2a1

Please sign in to comment.