From 2c754b1cd1ab414d195e5f3d867ebf50f2bfd762 Mon Sep 17 00:00:00 2001 From: Gregory Petrochenkov Date: Wed, 2 Aug 2023 21:45:58 -0400 Subject: [PATCH 1/2] Public rasterize and vectorize method --- pyproject.toml | 2 +- src/gval/accessors/gval_dataframe.py | 37 +++++++++++++++++++++++ src/gval/accessors/gval_xarray.py | 14 ++++++++- src/gval/comparison/compute_comparison.py | 2 +- tests/cases_accessors.py | 10 +++++- tests/test_accessors.py | 13 ++++++++ 6 files changed, 74 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a229113a..744555f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ authors = [ requires-python = ">=3.8" keywords = ["geospatial", "evaluations"] license = {text = "MIT"} -version = "0.1.1" +version = "0.1.2" dynamic = ["readme", "dependencies"] diff --git a/src/gval/accessors/gval_dataframe.py b/src/gval/accessors/gval_dataframe.py index 8550de96..b2c3ceff 100644 --- a/src/gval/accessors/gval_dataframe.py +++ b/src/gval/accessors/gval_dataframe.py @@ -3,9 +3,11 @@ import pandas as pd from pandera.typing import DataFrame +import xarray as xr from gval.comparison.compute_categorical_metrics import _compute_categorical_metrics from gval.utils.schemas import Metrics_df +from gval.homogenize.rasterize import _rasterize_data @pd.api.extensions.register_dataframe_accessor("gval") @@ -88,3 +90,38 @@ def compute_categorical_metrics( average=average, weights=weights, ) + + def rasterize_data( + self, reference_map: Union[xr.Dataset, xr.DataArray], rasterize_attributes: list + ) -> Union[xr.Dataset, xr.DataArray]: + """ + Convenience function for rasterizing vector data using a reference raster. For more control use `make_geocube` + from the geocube package. + + Parameters + ---------- + reference_map: Union[xr.Dataset, xr.DataArray] + Map to reference in creation of rasterized vector map + rasterize_attributes: list + Attributes to rasterize + + Returns + ------- + Union[xr.Dataset, xr.DataArray] + Rasterized Data + + Raises + ------ + KeyError + + References + ---------- + .. [1] [geocube `make_geocube`](https://corteva.github.io/geocube/html/geocube.html) + + """ + + return _rasterize_data( + candidate_map=reference_map, + benchmark_map=self._obj, + rasterize_attributes=rasterize_attributes, + ) diff --git a/src/gval/accessors/gval_xarray.py b/src/gval/accessors/gval_xarray.py index dc39a9cb..c604c7df 100644 --- a/src/gval/accessors/gval_xarray.py +++ b/src/gval/accessors/gval_xarray.py @@ -462,7 +462,7 @@ def compute_agreement_map( ) if self.agreement_map_format == "vector": - agreement_map = _vectorize_data(agreement_map) + agreement_map = agreement_map.gval.vectorize_data() return agreement_map @@ -677,3 +677,15 @@ def cont_plot( basemap=basemap, colorbar_label=colorbar_label, ) + + def vectorize_data(self) -> gpd.GeoDataFrame: + """ + Vectorize an xarray DataArray or Dataset + + Returns + ------- + gpd.GeoDataFrame + Vectorized data + """ + + return _vectorize_data(self._obj) diff --git a/src/gval/comparison/compute_comparison.py b/src/gval/comparison/compute_comparison.py index caee3f2f..6d30672b 100644 --- a/src/gval/comparison/compute_comparison.py +++ b/src/gval/comparison/compute_comparison.py @@ -284,7 +284,7 @@ def wrapper(*args, **kwargs): unique_benchmark_values=kwargs["allow_benchmark_values"], ) - else: + if "comparison_function" not in kwargs: kwargs["comparison_function"] = getattr(self, "szudzik") # Call the decorated function diff --git a/tests/cases_accessors.py b/tests/cases_accessors.py index 902042cd..46be359c 100644 --- a/tests/cases_accessors.py +++ b/tests/cases_accessors.py @@ -391,6 +391,14 @@ def case_continuous_plot_fail(candidate_map): return candidate_map +@parametrize( + "vector_map, reference_map, attributes", + list(zip(benchmark_maps[2:3], candidate_maps[0:1], [["category"]])), +) +def case_dataframe_accessor_rasterize(vector_map, reference_map, attributes): + return vector_map, reference_map, attributes + + candidate_maps = ["candidate_continuous_0.tif", "candidate_continuous_1.tif"] benchmark_maps = ["benchmark_continuous_0.tif", "benchmark_continuous_1.tif"] @@ -424,4 +432,4 @@ def case_data_set_accessor_continuous(candidate_map, benchmark_map): list(zip(candidate_maps, benchmark_maps, agreement_maps)), ) def case_accessor_attributes(candidate_map, benchmark_map, agreement_map): - return (_load_xarray(candidate_map), _load_xarray(benchmark_map), agreement_map) + return _load_xarray(candidate_map), _load_xarray(benchmark_map), agreement_map diff --git a/tests/test_accessors.py b/tests/test_accessors.py index 14140a15..2a1fd779 100644 --- a/tests/test_accessors.py +++ b/tests/test_accessors.py @@ -317,3 +317,16 @@ def test_accessor_attributes(candidate_map, benchmark_map, agreement_map): assert isinstance(agreement_map, xr.DataArray) assert isinstance(attrs_df, DataFrame) + + +@parametrize_with_cases( + "vector_map, reference_map, attributes", + glob="dataframe_accessor_rasterize", +) +def test_dataframe_accessor_rasterize(vector_map, reference_map, attributes): + raster_map = vector_map.gval.rasterize_data( + reference_map=reference_map, rasterize_attributes=attributes + ) + + assert isinstance(raster_map, type(reference_map)) + assert raster_map.shape == reference_map.shape From 63154c6e457ffc76943fa9ce1f2f06e152ed0587 Mon Sep 17 00:00:00 2001 From: Gregory Petrochenkov Date: Thu, 3 Aug 2023 00:18:05 -0400 Subject: [PATCH 2/2] Test update --- src/gval/statistics/categorical_stat_funcs.py | 2 +- tests/cases_compute_categorical_metrics.py | 16 ++++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/gval/statistics/categorical_stat_funcs.py b/src/gval/statistics/categorical_stat_funcs.py index 56a119ac..1707ce6f 100644 --- a/src/gval/statistics/categorical_stat_funcs.py +++ b/src/gval/statistics/categorical_stat_funcs.py @@ -324,7 +324,7 @@ def prevalence(tp: Number, tn: Number, fp: Number, fn: Number) -> float: ---------- .. [1] [Prevalence](https://en.wikipedia.org/wiki/Prevalence) """ - return (tp + fp) / (tp + fp + tn + fn) + return (tp + fn) / (tp + fp + tn + fn) def accuracy(tp: Number, tn: Number, fp: Number, fn: Number) -> float: diff --git a/tests/cases_compute_categorical_metrics.py b/tests/cases_compute_categorical_metrics.py index 6cb18776..3f3e0274 100644 --- a/tests/cases_compute_categorical_metrics.py +++ b/tests/cases_compute_categorical_metrics.py @@ -35,7 +35,7 @@ "overall_bias": {0: 0.9535906369448297}, "positive_likelihood_ratio": {0: 1.3660848657244327}, "positive_predictive_value": {0: 0.7891880387357667}, - "prevalence": {0: 0.6986443954301008}, + "prevalence": {0: 0.7326460310772968}, "prevalence_threshold": {0: 0.46108525048654536}, "true_negative_rate": {0: 0.44911012235817577}, "true_positive_rate": {0: 0.7525623245272807}, @@ -85,7 +85,7 @@ 0: 0.31210513014910285, 1: 0.6623376623376623, }, - "prevalence": {0: 0.42776066158586024, 1: 0.4514317452480869}, + "prevalence": {0: 0.1928544403005243, 1: 0.5266600839298938}, "prevalence_threshold": {0: 0.4205207112769604, 1: 0.42959744253992793}, "true_negative_rate": {0: 0.635438291033282, 1: 0.6779661016949152}, "true_positive_rate": {0: 0.6922645739910314, 1: 0.5677290836653387}, @@ -476,12 +476,12 @@ def case_compute_categorical_metrics_fail( "tn": [28, 28, 20, 20, 12, 12], "tp": [1, 1, 5, 5, 9, 9], "prevalence": [ - 0.133333, - 0.133333, - 0.333333, - 0.333333, - 0.533333, - 0.533333, + 0.26666666666666666, + 0.26666666666666666, + 0.3333333333333333, + 0.3333333333333333, + 0.4, + 0.4, ], "true_negative_rate": [ 0.848485,