From 71c95385beb50f2baf49cdc754b012844ebf220e Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 7 Feb 2024 11:48:42 -0700 Subject: [PATCH] Support multiple quantiles with xarray (#332) * Support multiple quantiles with xarray * Fix test * type: ignore * Bug fix * Another bug fix * Fix typing * Bugfix and cleanup * fix typing * Another bug fix * More xarray testing * comment * xfail test --- flox/aggregate_flox.py | 8 +++++-- flox/aggregations.py | 49 ++++++++++++++++++++++++++++++------------ flox/core.py | 39 +++++++++++++++++---------------- flox/xarray.py | 32 +++++++++++++++++++++++---- tests/test_core.py | 38 ++++++++++++++++++++++++-------- tests/test_xarray.py | 25 +++++++++++++++++++++ 6 files changed, 144 insertions(+), 47 deletions(-) diff --git a/flox/aggregate_flox.py b/flox/aggregate_flox.py index b90804c90..adb838f93 100644 --- a/flox/aggregate_flox.py +++ b/flox/aggregate_flox.py @@ -60,8 +60,12 @@ def quantile_(array, inv_idx, *, q, axis, skipna, dtype=None, out=None): if skipna: sizes = np.add.reduceat(notnull(array), inv_idx[:-1], axis=axis) else: - sizes = np.reshape(np.diff(inv_idx), (1,) * (array.ndim - 1) + (inv_idx.size - 1,)) - nanmask = isnull(np.take_along_axis(array, sizes - 1, axis=axis)) + newshape = (1,) * (array.ndim - 1) + (inv_idx.size - 1,) + sizes = np.reshape(np.diff(inv_idx), newshape) + # NaNs get sorted to the end, so look at the last element in the group to decide + # if there are NaNs + last_group_elem = np.broadcast_to(inv_idx[1:] - 1, newshape) + nanmask = isnull(np.take_along_axis(array, last_group_elem, axis=axis)) qin = q q = np.atleast_1d(qin) diff --git a/flox/aggregations.py b/flox/aggregations.py index 332a29999..86b0f5f2e 100644 --- a/flox/aggregations.py +++ b/flox/aggregations.py @@ -3,11 +3,12 @@ import copy import logging import warnings -from functools import partial +from dataclasses import dataclass +from functools import cached_property, partial from typing import TYPE_CHECKING, Any, Callable, Literal, TypedDict import numpy as np -from numpy.typing import DTypeLike +from numpy.typing import ArrayLike, DTypeLike from . import aggregate_flox, aggregate_npg, xrutils from . import xrdtypes as dtypes @@ -151,6 +152,20 @@ def returns_empty_tuple(*args, **kwargs): return () +@dataclass +class Dim: + values: ArrayLike + name: str | None + + @cached_property + def is_scalar(self) -> bool: + return xrutils.is_scalar(self.values) + + @cached_property + def size(self) -> int: + return 0 if self.is_scalar else len(self.values) # type: ignore[arg-type] + + class Aggregation: def __init__( self, @@ -166,7 +181,7 @@ def __init__( dtypes=None, final_dtype: DTypeLike | None = None, reduction_type: Literal["reduce", "argreduce"] = "reduce", - new_axes_func: Callable | None = None, + new_dims_func: Callable | None = None, ): """ Blueprint for computing grouped aggregations. @@ -209,7 +224,7 @@ def __init__( per reduction in ``chunk`` as a tuple. final_dtype : DType, optional DType for output. By default, uses dtype of array being reduced. - new_axes_func: Callable + new_dims_func: Callable Function that receives finalize_kwargs and returns a tupleof sizes of any new dimensions added by the reduction. For e.g. quantile for q=(0.5, 0.85) adds a new dimension of size 2, so returns (2,) @@ -246,12 +261,17 @@ def __init__( # The following are set by _initialize_aggregation self.finalize_kwargs: dict[Any, Any] = {} self.min_count: int = 0 - self.new_axes_func: Callable = ( - returns_empty_tuple if new_axes_func is None else new_axes_func + self.new_dims_func: Callable = ( + returns_empty_tuple if new_dims_func is None else new_dims_func ) - def get_new_axes(self): - return self.new_axes_func(**self.finalize_kwargs) + @cached_property + def new_dims(self) -> tuple[Dim]: + return self.new_dims_func(**self.finalize_kwargs) + + @cached_property + def num_new_vector_dims(self) -> int: + return len(tuple(dim for dim in self.new_dims if not dim.is_scalar)) def _normalize_dtype_fill_value(self, value, name): value = _atleast_1d(value) @@ -511,8 +531,8 @@ def _pick_second(*x): ) -def quantile_new_axes_func(q): - return tuple() if xrutils.is_scalar(q) else (len(q),) +def quantile_new_dims_func(q) -> tuple[Dim]: + return (Dim(name="quantile", values=q),) quantile = Aggregation( @@ -521,7 +541,7 @@ def quantile_new_axes_func(q): chunk=None, combine=None, final_dtype=np.float64, - new_axes_func=quantile_new_axes_func, + new_dims_func=quantile_new_dims_func, ) nanquantile = Aggregation( name="nanquantile", @@ -529,7 +549,7 @@ def quantile_new_axes_func(q): chunk=None, combine=None, final_dtype=np.float64, - new_axes_func=quantile_new_axes_func, + new_dims_func=quantile_new_dims_func, ) mode = Aggregation(name="mode", fill_value=dtypes.NA, chunk=None, combine=None) nanmode = Aggregation(name="nanmode", fill_value=dtypes.NA, chunk=None, combine=None) @@ -638,9 +658,10 @@ def _initialize_aggregation( # where the identity element is 0, 1 if min_count > 0: agg.min_count = min_count - agg.chunk += ("nanlen",) agg.numpy += ("nanlen",) - agg.combine += ("sum",) + if agg.chunk != (None,): + agg.chunk += ("nanlen",) + agg.combine += ("sum",) agg.fill_value["intermediate"] += (0,) agg.fill_value["numpy"] += (0,) agg.dtype["intermediate"] += (np.intp,) diff --git a/flox/core.py b/flox/core.py index 5ffa9055c..7a9b01cfe 100644 --- a/flox/core.py +++ b/flox/core.py @@ -34,7 +34,7 @@ _atleast_1d, _initialize_aggregation, generic_aggregate, - quantile_new_axes_func, + quantile_new_dims_func, ) from .cache import memoize from .xrutils import ( @@ -1006,7 +1006,9 @@ def chunk_reduce( result = result[..., :-1] # TODO: Figure out how to generalize this if reduction in ("quantile", "nanquantile"): - new_dims_shape = quantile_new_axes_func(**kw) + new_dims_shape = tuple( + dim.size for dim in quantile_new_dims_func(**kw) if not dim.is_scalar + ) else: new_dims_shape = tuple() result = result.reshape(new_dims_shape + final_array_shape[:-1] + found_groups_shape) @@ -1044,7 +1046,7 @@ def _finalize_results( 3. Mask using counts and fill with user-provided fill_value. 4. reindex to expected_groups """ - squeezed = _squeeze_results(results, axis) + squeezed = _squeeze_results(results, tuple(agg.num_new_vector_dims + ax for ax in axis)) min_count = agg.min_count if min_count > 0: @@ -1671,7 +1673,7 @@ def dask_groupby_agg( raise ValueError(f"Unknown method={method}.") # Adjust output for any new dimensions added, example for multiple quantiles - new_dims_shape = agg.get_new_axes() + new_dims_shape = tuple(dim.size for dim in agg.new_dims if not dim.is_scalar) new_inds = tuple(range(-len(new_dims_shape), 0)) out_inds = new_inds + inds[: -len(axis)] + (inds[-1],) output_chunks = new_dims_shape + reduced.chunks[: -len(axis)] + group_chunks @@ -2297,7 +2299,21 @@ def groupby_reduce( # TODO: How else to narrow that array.chunks is there? assert isinstance(array, DaskArray) - if agg.chunk[0] is None and method not in [None, "blockwise"]: + if (not any_by_dask and method is None) or method == "cohorts": + preferred_method, chunks_cohorts = find_group_cohorts( + by_, + [array.chunks[ax] for ax in range(-by_.ndim, 0)], + expected_groups=expected_, + # when provided with cohorts, we *always* 'merge' + merge=(method == "cohorts"), + ) + else: + preferred_method = "map-reduce" + chunks_cohorts = {} + + method = _choose_method(method, preferred_method, agg, by_, nax) + + if agg.chunk[0] is None and method != "blockwise": raise NotImplementedError( f"Aggregation {agg.name!r} is only implemented for dask arrays when method='blockwise'." f"Received method={method!r}" @@ -2318,19 +2334,6 @@ def groupby_reduce( f"Received method={method!r}" ) - if (not any_by_dask and method is None) or method == "cohorts": - preferred_method, chunks_cohorts = find_group_cohorts( - by_, - [array.chunks[ax] for ax in range(-by_.ndim, 0)], - expected_groups=expected_, - # when provided with cohorts, we *always* 'merge' - merge=(method == "cohorts"), - ) - else: - preferred_method = "map-reduce" - chunks_cohorts = {} - - method = _choose_method(method, preferred_method, agg, by_, nax) # TODO: clean this up reindex = _validate_reindex( reindex, func, method, expected_, any_by_dask, is_duck_dask_array(array) diff --git a/flox/xarray.py b/flox/xarray.py index b388deb09..3200d7f0a 100644 --- a/flox/xarray.py +++ b/flox/xarray.py @@ -9,7 +9,7 @@ from packaging.version import Version from xarray.core.duck_array_ops import _datetime_nanmin -from .aggregations import Aggregation, _atleast_1d +from .aggregations import Aggregation, Dim, _atleast_1d, quantile_new_dims_func from .core import ( _convert_expected_groups_to_index, _get_expected_groups, @@ -74,7 +74,7 @@ def xarray_reduce( dim: Dims | ellipsis = None, fill_value=None, dtype: np.typing.DTypeLike = None, - method: str = "map-reduce", + method: str | None = None, engine: str | None = None, keep_attrs: bool | None = True, skipna: bool | None = None, @@ -387,6 +387,17 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs): result, *groups = groupby_reduce(array, *by, func=func, **kwargs) + # Transpose the new quantile dimension to the end. This is ugly. + # but new core dimensions are expected at the end :/ + # but groupby_reduce inserts them at the beginning + if func in ["quantile", "nanquantile"]: + (newdim,) = quantile_new_dims_func(**finalize_kwargs) + if not newdim.is_scalar: + # NOTE: _restore_dim_order will move any new dims to the end anyway. + # This transpose is simply makes it easy to specify output_core_dims + # output dim order: (*broadcast_dims, *group_dims, quantile_dim) + result = np.moveaxis(result, 0, -1) + # Output of count has an int dtype. if requires_numeric and func != "count": if is_npdatetime: @@ -412,8 +423,18 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs): input_core_dims = [[d for d in grouper_dims if d not in dim_tuple] + list(dim_tuple)] input_core_dims += [list(b.dims) for b in by_da] + newdims: tuple[Dim, ...] = ( + quantile_new_dims_func(**finalize_kwargs) if func in ["quantile", "nanquantile"] else () + ) + output_core_dims = [d for d in input_core_dims[0] if d not in dim_tuple] output_core_dims.extend(group_names) + vector_dims = [dim.name for dim in newdims if not dim.is_scalar] + output_core_dims.extend(vector_dims) + + output_sizes = group_sizes + output_sizes.update({dim.name: dim.size for dim in newdims if dim.size != 0}) + actual = xr.apply_ufunc( wrapper, ds_broad.drop_vars(tuple(missing_dim)).transpose(..., *grouper_dims), @@ -424,7 +445,7 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs): output_core_dims=[output_core_dims], dask="allowed", dask_gufunc_kwargs=dict( - output_sizes=group_sizes, output_dtypes=[dtype] if dtype is not None else None + output_sizes=output_sizes, output_dtypes=[dtype] if dtype is not None else None ), keep_attrs=keep_attrs, kwargs={ @@ -451,6 +472,9 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs): if all(d not in ds_broad[var].dims for d in dim_tuple): actual[var] = ds_broad[var] + for newdim in newdims: + actual.coords[newdim.name] = newdim.values if newdim.is_scalar else np.array(newdim.values) + expect3: T_ExpectIndex | np.ndarray for name, expect2, by_ in zip(group_names, expected_groups_valid_list, by_da): # Can't remove this until xarray handles IntervalIndex: @@ -492,7 +516,7 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs): else: template = obj - if actual[var].ndim > 1: + if actual[var].ndim > 1 + len(vector_dims): no_groupby_reorder = isinstance( obj, xr.Dataset ) # do not re-order dataarrays inside datasets diff --git a/tests/test_core.py b/tests/test_core.py index bbe6098be..6837eb963 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -254,7 +254,7 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine): fill_value = np.nan tolerance = {"rtol": 1e-14, "atol": 1e-16} elif "quantile" in func: - finalize_kwargs = [{"q": DEFAULT_QUANTILE}] + finalize_kwargs = [{"q": DEFAULT_QUANTILE}, {"q": [DEFAULT_QUANTILE / 2, DEFAULT_QUANTILE]}] fill_value = None tolerance = None else: @@ -265,6 +265,8 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine): array_func = _get_array_func(func) for kwargs in finalize_kwargs: + if "quantile" in func and isinstance(kwargs["q"], list) and engine != "flox": + continue flox_kwargs = dict(func=func, engine=engine, finalize_kwargs=kwargs, fill_value=fill_value) with np.errstate(invalid="ignore", divide="ignore"): with warnings.catch_warnings(): @@ -289,10 +291,13 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine): if func in BLOCKWISE_FUNCS: assert chunks == -1 - flox_kwargs["method"] = "blockwise" actual, *groups = groupby_reduce(array, *by, **flox_kwargs) - assert actual.ndim == expected.ndim == (array.ndim + nby - 1) + if "quantile" in func and isinstance(kwargs["q"], list): + assert actual.ndim == expected.ndim == (array.ndim + nby) + else: + assert actual.ndim == expected.ndim == (array.ndim + nby - 1) + expected_groups = tuple(np.array([idx + 1.0]) for idx in range(nby)) for actual_group, expect in zip(groups, expected_groups): assert_equal(actual_group, expect) @@ -598,6 +603,15 @@ def test_nanfirst_nanlast_disallowed_dask(axis, func): @requires_dask +@pytest.mark.xfail +@pytest.mark.parametrize("func", ["first", "last"]) +def test_first_last_allowed_dask(func): + # blockwise should be fine... but doesn't work now. + groupby_reduce(dask.array.empty((2, 3, 2)), np.ones((2, 3, 2)), func=func, axis=-1) + + +@requires_dask +@pytest.mark.xfail @pytest.mark.parametrize("func", ["first", "last"]) def test_first_last_disallowed_dask(func): # blockwise is fine @@ -1678,19 +1692,25 @@ def test_xarray_fill_value_behaviour(): assert_equal(expected, actual) -@pytest.mark.parametrize("q", (0.5, (0.5,), (0.5, 0.85))) +@pytest.mark.parametrize("q", (0.5, (0.5,), (0.5, 0.67, 0.85))) @pytest.mark.parametrize("func", ["nanquantile", "quantile"]) @pytest.mark.parametrize("chunk", [pytest.param(True, marks=requires_dask), False]) -def test_multiple_quantiles(q, chunk, func): +@pytest.mark.parametrize("by_ndim", [1, 2]) +def test_multiple_quantiles(q, chunk, func, by_ndim): array = np.array([[1, -1, np.nan, 3, 4, 10, 5], [1, np.nan, np.nan, 3, 4, np.nan, np.nan]]) labels = np.array([0, 0, 0, 1, 0, 1, 1]) - axis = -1 + if by_ndim == 2: + labels = np.broadcast_to(labels, (5, *labels.shape)) + array = np.broadcast_to(np.expand_dims(array, -2), (2, 5, array.shape[-1])) + axis = tuple(range(-by_ndim, 0)) if chunk: - array = dask.array.from_array(array, chunks=(1, -1)) + array = dask.array.from_array(array, chunks=(1,) + (-1,) * by_ndim) actual, _ = groupby_reduce(array, labels, func=func, finalize_kwargs=dict(q=q), axis=axis) sorted_array = array[..., [0, 1, 2, 4, 3, 5, 6]] f = partial(getattr(np, func), q=q, axis=axis, keepdims=True) - expected = np.concatenate((f(sorted_array[..., :4]), f(sorted_array[..., 4:])), axis=axis) - assert_equal(expected, actual) + expected = np.concatenate((f(sorted_array[..., :4]), f(sorted_array[..., 4:])), axis=-1) + if by_ndim == 2: + expected = expected.squeeze(axis=-2) + assert_equal(expected, actual, tolerance=1e-14) diff --git a/tests/test_xarray.py b/tests/test_xarray.py index 5909db46b..95ab2eff3 100644 --- a/tests/test_xarray.py +++ b/tests/test_xarray.py @@ -685,3 +685,28 @@ def test_resampling_missing_groups(chunk): with xr.set_options(use_flox=True): actual = da.resample(time="1D").mean() xr.testing.assert_identical(expected, actual) + + +@pytest.mark.parametrize("q", (0.5, (0.5,), (0.5, 0.67, 0.85))) +@pytest.mark.parametrize("skipna", [False, True]) +@pytest.mark.parametrize("chunk", [pytest.param(True, marks=requires_dask), False]) +@pytest.mark.parametrize("by_ndim", [1, 2]) +def test_multiple_quantiles(q, chunk, by_ndim, skipna): + array = np.array([[1, -1, np.nan, 3, 4, 10, 5], [1, np.nan, np.nan, 3, 4, np.nan, np.nan]]) + labels = np.array([0, 0, 0, 1, 0, 1, 1]) + dims = ("y",) + if by_ndim == 2: + labels = np.broadcast_to(labels, (5, *labels.shape)) + array = np.broadcast_to(np.expand_dims(array, -2), (2, 5, array.shape[-1])) + dims += ("y0",) + + if chunk: + array = dask.array.from_array(array, chunks=(1,) + (-1,) * by_ndim) + + da = xr.DataArray(array, dims=("x", *dims)) + by = xr.DataArray(labels, dims=dims, name="by") + + actual = xarray_reduce(da, by, func="quantile", skipna=skipna, q=q) + with xr.set_options(use_flox=False): + expected = da.groupby(by).quantile(q, skipna=skipna) + xr.testing.assert_allclose(expected, actual)