Skip to content

Commit

Permalink
Support multiple quantiles with xarray (#332)
Browse files Browse the repository at this point in the history
* Support multiple quantiles with xarray

* Fix test

* type: ignore

* Bug fix

* Another bug fix

* Fix typing

* Bugfix and cleanup

* fix typing

* Another bug fix

* More xarray testing

* comment

* xfail test
  • Loading branch information
dcherian authored Feb 7, 2024
1 parent b12bcfa commit 71c9538
Show file tree
Hide file tree
Showing 6 changed files with 144 additions and 47 deletions.
8 changes: 6 additions & 2 deletions flox/aggregate_flox.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,12 @@ def quantile_(array, inv_idx, *, q, axis, skipna, dtype=None, out=None):
if skipna:
sizes = np.add.reduceat(notnull(array), inv_idx[:-1], axis=axis)
else:
sizes = np.reshape(np.diff(inv_idx), (1,) * (array.ndim - 1) + (inv_idx.size - 1,))
nanmask = isnull(np.take_along_axis(array, sizes - 1, axis=axis))
newshape = (1,) * (array.ndim - 1) + (inv_idx.size - 1,)
sizes = np.reshape(np.diff(inv_idx), newshape)
# NaNs get sorted to the end, so look at the last element in the group to decide
# if there are NaNs
last_group_elem = np.broadcast_to(inv_idx[1:] - 1, newshape)
nanmask = isnull(np.take_along_axis(array, last_group_elem, axis=axis))

qin = q
q = np.atleast_1d(qin)
Expand Down
49 changes: 35 additions & 14 deletions flox/aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
import copy
import logging
import warnings
from functools import partial
from dataclasses import dataclass
from functools import cached_property, partial
from typing import TYPE_CHECKING, Any, Callable, Literal, TypedDict

import numpy as np
from numpy.typing import DTypeLike
from numpy.typing import ArrayLike, DTypeLike

from . import aggregate_flox, aggregate_npg, xrutils
from . import xrdtypes as dtypes
Expand Down Expand Up @@ -151,6 +152,20 @@ def returns_empty_tuple(*args, **kwargs):
return ()


@dataclass
class Dim:
    """Description of a new dimension that a reduction adds to its output.

    ``values`` holds the values associated with the dimension (a scalar or an
    array-like) and ``name`` is the dimension's name, if any.
    """

    values: ArrayLike
    name: str | None

    @cached_property
    def is_scalar(self) -> bool:
        """Whether ``values`` is a scalar according to ``xrutils.is_scalar``."""
        return xrutils.is_scalar(self.values)

    @cached_property
    def size(self) -> int:
        """Length of ``values``, or ``0`` when ``values`` is a scalar."""
        if self.is_scalar:
            return 0
        return len(self.values)  # type: ignore[arg-type]


class Aggregation:
def __init__(
self,
Expand All @@ -166,7 +181,7 @@ def __init__(
dtypes=None,
final_dtype: DTypeLike | None = None,
reduction_type: Literal["reduce", "argreduce"] = "reduce",
new_axes_func: Callable | None = None,
new_dims_func: Callable | None = None,
):
"""
Blueprint for computing grouped aggregations.
Expand Down Expand Up @@ -209,7 +224,7 @@ def __init__(
per reduction in ``chunk`` as a tuple.
final_dtype : DType, optional
DType for output. By default, uses dtype of array being reduced.
new_axes_func: Callable
new_dims_func: Callable
Function that receives finalize_kwargs and returns a tuple of ``Dim`` objects describing any new dimensions
added by the reduction. For e.g. quantile for q=(0.5, 0.85) adds a new dimension of size 2,
so returns a single ``Dim`` named "quantile" whose values are q.
Expand Down Expand Up @@ -246,12 +261,17 @@ def __init__(
# The following are set by _initialize_aggregation
self.finalize_kwargs: dict[Any, Any] = {}
self.min_count: int = 0
self.new_axes_func: Callable = (
returns_empty_tuple if new_axes_func is None else new_axes_func
self.new_dims_func: Callable = (
returns_empty_tuple if new_dims_func is None else new_dims_func
)

def get_new_axes(self):
return self.new_axes_func(**self.finalize_kwargs)
@cached_property
def new_dims(self) -> tuple[Dim]:
return self.new_dims_func(**self.finalize_kwargs)

@cached_property
def num_new_vector_dims(self) -> int:
return len(tuple(dim for dim in self.new_dims if not dim.is_scalar))

def _normalize_dtype_fill_value(self, value, name):
value = _atleast_1d(value)
Expand Down Expand Up @@ -511,8 +531,8 @@ def _pick_second(*x):
)


def quantile_new_axes_func(q):
return tuple() if xrutils.is_scalar(q) else (len(q),)
def quantile_new_dims_func(q) -> tuple[Dim]:
    """Return a 1-tuple holding the ``"quantile"`` dimension whose values are ``q``."""
    quantile_dim = Dim(name="quantile", values=q)
    return (quantile_dim,)


quantile = Aggregation(
Expand All @@ -521,15 +541,15 @@ def quantile_new_axes_func(q):
chunk=None,
combine=None,
final_dtype=np.float64,
new_axes_func=quantile_new_axes_func,
new_dims_func=quantile_new_dims_func,
)
nanquantile = Aggregation(
name="nanquantile",
fill_value=dtypes.NA,
chunk=None,
combine=None,
final_dtype=np.float64,
new_axes_func=quantile_new_axes_func,
new_dims_func=quantile_new_dims_func,
)
mode = Aggregation(name="mode", fill_value=dtypes.NA, chunk=None, combine=None)
nanmode = Aggregation(name="nanmode", fill_value=dtypes.NA, chunk=None, combine=None)
Expand Down Expand Up @@ -638,9 +658,10 @@ def _initialize_aggregation(
# where the identity element is 0, 1
if min_count > 0:
agg.min_count = min_count
agg.chunk += ("nanlen",)
agg.numpy += ("nanlen",)
agg.combine += ("sum",)
if agg.chunk != (None,):
agg.chunk += ("nanlen",)
agg.combine += ("sum",)
agg.fill_value["intermediate"] += (0,)
agg.fill_value["numpy"] += (0,)
agg.dtype["intermediate"] += (np.intp,)
Expand Down
39 changes: 21 additions & 18 deletions flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
_atleast_1d,
_initialize_aggregation,
generic_aggregate,
quantile_new_axes_func,
quantile_new_dims_func,
)
from .cache import memoize
from .xrutils import (
Expand Down Expand Up @@ -1006,7 +1006,9 @@ def chunk_reduce(
result = result[..., :-1]
# TODO: Figure out how to generalize this
if reduction in ("quantile", "nanquantile"):
new_dims_shape = quantile_new_axes_func(**kw)
new_dims_shape = tuple(
dim.size for dim in quantile_new_dims_func(**kw) if not dim.is_scalar
)
else:
new_dims_shape = tuple()
result = result.reshape(new_dims_shape + final_array_shape[:-1] + found_groups_shape)
Expand Down Expand Up @@ -1044,7 +1046,7 @@ def _finalize_results(
3. Mask using counts and fill with user-provided fill_value.
4. reindex to expected_groups
"""
squeezed = _squeeze_results(results, axis)
squeezed = _squeeze_results(results, tuple(agg.num_new_vector_dims + ax for ax in axis))

min_count = agg.min_count
if min_count > 0:
Expand Down Expand Up @@ -1671,7 +1673,7 @@ def dask_groupby_agg(
raise ValueError(f"Unknown method={method}.")

# Adjust output for any new dimensions added, example for multiple quantiles
new_dims_shape = agg.get_new_axes()
new_dims_shape = tuple(dim.size for dim in agg.new_dims if not dim.is_scalar)
new_inds = tuple(range(-len(new_dims_shape), 0))
out_inds = new_inds + inds[: -len(axis)] + (inds[-1],)
output_chunks = new_dims_shape + reduced.chunks[: -len(axis)] + group_chunks
Expand Down Expand Up @@ -2297,7 +2299,21 @@ def groupby_reduce(
# TODO: How else to narrow that array.chunks is there?
assert isinstance(array, DaskArray)

if agg.chunk[0] is None and method not in [None, "blockwise"]:
if (not any_by_dask and method is None) or method == "cohorts":
preferred_method, chunks_cohorts = find_group_cohorts(
by_,
[array.chunks[ax] for ax in range(-by_.ndim, 0)],
expected_groups=expected_,
# when provided with cohorts, we *always* 'merge'
merge=(method == "cohorts"),
)
else:
preferred_method = "map-reduce"
chunks_cohorts = {}

method = _choose_method(method, preferred_method, agg, by_, nax)

if agg.chunk[0] is None and method != "blockwise":
raise NotImplementedError(
f"Aggregation {agg.name!r} is only implemented for dask arrays when method='blockwise'."
f"Received method={method!r}"
Expand All @@ -2318,19 +2334,6 @@ def groupby_reduce(
f"Received method={method!r}"
)

if (not any_by_dask and method is None) or method == "cohorts":
preferred_method, chunks_cohorts = find_group_cohorts(
by_,
[array.chunks[ax] for ax in range(-by_.ndim, 0)],
expected_groups=expected_,
# when provided with cohorts, we *always* 'merge'
merge=(method == "cohorts"),
)
else:
preferred_method = "map-reduce"
chunks_cohorts = {}

method = _choose_method(method, preferred_method, agg, by_, nax)
# TODO: clean this up
reindex = _validate_reindex(
reindex, func, method, expected_, any_by_dask, is_duck_dask_array(array)
Expand Down
32 changes: 28 additions & 4 deletions flox/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from packaging.version import Version
from xarray.core.duck_array_ops import _datetime_nanmin

from .aggregations import Aggregation, _atleast_1d
from .aggregations import Aggregation, Dim, _atleast_1d, quantile_new_dims_func
from .core import (
_convert_expected_groups_to_index,
_get_expected_groups,
Expand Down Expand Up @@ -74,7 +74,7 @@ def xarray_reduce(
dim: Dims | ellipsis = None,
fill_value=None,
dtype: np.typing.DTypeLike = None,
method: str = "map-reduce",
method: str | None = None,
engine: str | None = None,
keep_attrs: bool | None = True,
skipna: bool | None = None,
Expand Down Expand Up @@ -387,6 +387,17 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs):

result, *groups = groupby_reduce(array, *by, func=func, **kwargs)

# Transpose the new quantile dimension to the end. This is ugly,
# but new core dimensions are expected at the end :/
# while groupby_reduce inserts them at the beginning.
if func in ["quantile", "nanquantile"]:
(newdim,) = quantile_new_dims_func(**finalize_kwargs)
if not newdim.is_scalar:
# NOTE: _restore_dim_order will move any new dims to the end anyway.
# This transpose simply makes it easy to specify output_core_dims
# output dim order: (*broadcast_dims, *group_dims, quantile_dim)
result = np.moveaxis(result, 0, -1)

# Output of count has an int dtype.
if requires_numeric and func != "count":
if is_npdatetime:
Expand All @@ -412,8 +423,18 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs):
input_core_dims = [[d for d in grouper_dims if d not in dim_tuple] + list(dim_tuple)]
input_core_dims += [list(b.dims) for b in by_da]

newdims: tuple[Dim, ...] = (
quantile_new_dims_func(**finalize_kwargs) if func in ["quantile", "nanquantile"] else ()
)

output_core_dims = [d for d in input_core_dims[0] if d not in dim_tuple]
output_core_dims.extend(group_names)
vector_dims = [dim.name for dim in newdims if not dim.is_scalar]
output_core_dims.extend(vector_dims)

output_sizes = group_sizes
output_sizes.update({dim.name: dim.size for dim in newdims if dim.size != 0})

actual = xr.apply_ufunc(
wrapper,
ds_broad.drop_vars(tuple(missing_dim)).transpose(..., *grouper_dims),
Expand All @@ -424,7 +445,7 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs):
output_core_dims=[output_core_dims],
dask="allowed",
dask_gufunc_kwargs=dict(
output_sizes=group_sizes, output_dtypes=[dtype] if dtype is not None else None
output_sizes=output_sizes, output_dtypes=[dtype] if dtype is not None else None
),
keep_attrs=keep_attrs,
kwargs={
Expand All @@ -451,6 +472,9 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs):
if all(d not in ds_broad[var].dims for d in dim_tuple):
actual[var] = ds_broad[var]

for newdim in newdims:
actual.coords[newdim.name] = newdim.values if newdim.is_scalar else np.array(newdim.values)

expect3: T_ExpectIndex | np.ndarray
for name, expect2, by_ in zip(group_names, expected_groups_valid_list, by_da):
# Can't remove this until xarray handles IntervalIndex:
Expand Down Expand Up @@ -492,7 +516,7 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs):
else:
template = obj

if actual[var].ndim > 1:
if actual[var].ndim > 1 + len(vector_dims):
no_groupby_reorder = isinstance(
obj, xr.Dataset
) # do not re-order dataarrays inside datasets
Expand Down
38 changes: 29 additions & 9 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
fill_value = np.nan
tolerance = {"rtol": 1e-14, "atol": 1e-16}
elif "quantile" in func:
finalize_kwargs = [{"q": DEFAULT_QUANTILE}]
finalize_kwargs = [{"q": DEFAULT_QUANTILE}, {"q": [DEFAULT_QUANTILE / 2, DEFAULT_QUANTILE]}]
fill_value = None
tolerance = None
else:
Expand All @@ -265,6 +265,8 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
array_func = _get_array_func(func)

for kwargs in finalize_kwargs:
if "quantile" in func and isinstance(kwargs["q"], list) and engine != "flox":
continue
flox_kwargs = dict(func=func, engine=engine, finalize_kwargs=kwargs, fill_value=fill_value)
with np.errstate(invalid="ignore", divide="ignore"):
with warnings.catch_warnings():
Expand All @@ -289,10 +291,13 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):

if func in BLOCKWISE_FUNCS:
assert chunks == -1
flox_kwargs["method"] = "blockwise"

actual, *groups = groupby_reduce(array, *by, **flox_kwargs)
assert actual.ndim == expected.ndim == (array.ndim + nby - 1)
if "quantile" in func and isinstance(kwargs["q"], list):
assert actual.ndim == expected.ndim == (array.ndim + nby)
else:
assert actual.ndim == expected.ndim == (array.ndim + nby - 1)

expected_groups = tuple(np.array([idx + 1.0]) for idx in range(nby))
for actual_group, expect in zip(groups, expected_groups):
assert_equal(actual_group, expect)
Expand Down Expand Up @@ -598,6 +603,15 @@ def test_nanfirst_nanlast_disallowed_dask(axis, func):


@requires_dask
@pytest.mark.xfail
@pytest.mark.parametrize("func", ["first", "last"])
def test_first_last_allowed_dask(func):
    # blockwise should be fine... but doesn't work now.
    shape = (2, 3, 2)
    array = dask.array.empty(shape)
    labels = np.ones(shape)
    groupby_reduce(array, labels, func=func, axis=-1)


@requires_dask
@pytest.mark.xfail
@pytest.mark.parametrize("func", ["first", "last"])
def test_first_last_disallowed_dask(func):
# blockwise is fine
Expand Down Expand Up @@ -1678,19 +1692,25 @@ def test_xarray_fill_value_behaviour():
assert_equal(expected, actual)


@pytest.mark.parametrize("q", (0.5, (0.5,), (0.5, 0.85)))
@pytest.mark.parametrize("q", (0.5, (0.5,), (0.5, 0.67, 0.85)))
@pytest.mark.parametrize("func", ["nanquantile", "quantile"])
@pytest.mark.parametrize("chunk", [pytest.param(True, marks=requires_dask), False])
def test_multiple_quantiles(q, chunk, func):
@pytest.mark.parametrize("by_ndim", [1, 2])
def test_multiple_quantiles(q, chunk, func, by_ndim):
array = np.array([[1, -1, np.nan, 3, 4, 10, 5], [1, np.nan, np.nan, 3, 4, np.nan, np.nan]])
labels = np.array([0, 0, 0, 1, 0, 1, 1])
axis = -1
if by_ndim == 2:
labels = np.broadcast_to(labels, (5, *labels.shape))
array = np.broadcast_to(np.expand_dims(array, -2), (2, 5, array.shape[-1]))
axis = tuple(range(-by_ndim, 0))

if chunk:
array = dask.array.from_array(array, chunks=(1, -1))
array = dask.array.from_array(array, chunks=(1,) + (-1,) * by_ndim)

actual, _ = groupby_reduce(array, labels, func=func, finalize_kwargs=dict(q=q), axis=axis)
sorted_array = array[..., [0, 1, 2, 4, 3, 5, 6]]
f = partial(getattr(np, func), q=q, axis=axis, keepdims=True)
expected = np.concatenate((f(sorted_array[..., :4]), f(sorted_array[..., 4:])), axis=axis)
assert_equal(expected, actual)
expected = np.concatenate((f(sorted_array[..., :4]), f(sorted_array[..., 4:])), axis=-1)
if by_ndim == 2:
expected = expected.squeeze(axis=-2)
assert_equal(expected, actual, tolerance=1e-14)
Loading

0 comments on commit 71c9538

Please sign in to comment.