From 8ea0cd10e18847ff6f53fb9ef7c1845a6e9aa882 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 9 Jan 2024 19:20:38 -0700 Subject: [PATCH] Cleanup (#315) * Cleanup * seen_groups for numbagg only * Use _atleast_1d in more places. --- flox/aggregations.py | 5 +++-- flox/core.py | 43 ++++++++++--------------------------------- flox/xrutils.py | 4 ++-- 3 files changed, 15 insertions(+), 37 deletions(-) diff --git a/flox/aggregations.py b/flox/aggregations.py index b91d191b2..b5b5578fa 100644 --- a/flox/aggregations.py +++ b/flox/aggregations.py @@ -133,9 +133,10 @@ def _get_fill_value(dtype, fill_value): return fill_value -def _atleast_1d(inp): +def _atleast_1d(inp, min_length: int = 1): if xrutils.is_scalar(inp): - inp = (inp,) + inp = (inp,) * min_length + assert len(inp) >= min_length return inp diff --git a/flox/core.py b/flox/core.py index 4860593c6..df1b9292a 100644 --- a/flox/core.py +++ b/flox/core.py @@ -803,29 +803,11 @@ def chunk_reduce( dict """ - if not (isinstance(func, str) or callable(func)): - funcs = func - else: - funcs = (func,) + funcs = _atleast_1d(func) nfuncs = len(funcs) - - if isinstance(dtype, Sequence): - dtypes = dtype - else: - dtypes = (dtype,) * nfuncs - assert len(dtypes) >= nfuncs - - if isinstance(fill_value, Sequence): - fill_values = fill_value - else: - fill_values = (fill_value,) * nfuncs - assert len(fill_values) >= nfuncs - - if isinstance(kwargs, Sequence): - kwargss = kwargs - else: - kwargss = ({},) * nfuncs - assert len(kwargss) >= nfuncs + dtypes = _atleast_1d(dtype, nfuncs) + fill_values = _atleast_1d(fill_value, nfuncs) + kwargss = _atleast_1d({}, nfuncs) if kwargs is None else kwargs if isinstance(axis, Sequence): axes: T_Axes = axis @@ -862,7 +844,8 @@ def chunk_reduce( # do this *before* possible broadcasting below. # factorize_ has already taken care of offsetting - seen_groups = _unique(group_idx) + if engine == "numbagg": + seen_groups = _unique(group_idx) order = "C" if nax > 1: @@ -1551,12 +1534,9 @@ def dask_groupby_agg( groups = _extract_unknown_groups(reduced, dtype=by.dtype) group_chunks = ((np.nan,),) else: - if expected_groups is None: - expected_groups_ = _get_expected_groups(by_input, sort=sort) - else: - expected_groups_ = expected_groups - groups = (expected_groups_.to_numpy(),) - group_chunks = ((len(expected_groups_),),) + assert expected_groups is not None + groups = (expected_groups.to_numpy(),) + group_chunks = ((len(expected_groups),),) elif method == "cohorts": chunks_cohorts = find_group_cohorts( @@ -2063,10 +2043,7 @@ def groupby_reduce( is_bool_array = np.issubdtype(array.dtype, bool) array = array.astype(int) if is_bool_array else array - if isinstance(isbin, Sequence): - isbins = isbin - else: - isbins = (isbin,) * nby + isbins = _atleast_1d(isbin, nby) _assert_by_is_aligned(array.shape, bys) diff --git a/flox/xrutils.py b/flox/xrutils.py index 0ced6fbb6..515814da3 100644 --- a/flox/xrutils.py +++ b/flox/xrutils.py @@ -84,7 +84,7 @@ def __dask_tokenize__(self): def is_scalar(value: Any, include_0d: bool = True) -> bool: """Whether to treat a value as a scalar. 
- Any non-iterable, string, or 0-D array + Any non-iterable, string, dict, or 0-D array """ NON_NUMPY_SUPPORTED_ARRAY_TYPES = (dask_array_type, pd.Index) @@ -92,7 +92,7 @@ def is_scalar(value: Any, include_0d: bool = True) -> bool: include_0d = getattr(value, "ndim", None) == 0 return ( include_0d - or isinstance(value, (str, bytes)) + or isinstance(value, (str, bytes, dict)) or not ( isinstance(value, (Iterable,) + NON_NUMPY_SUPPORTED_ARRAY_TYPES) or hasattr(value, "__array_function__")
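
Note (not part of the patch): a minimal sketch of how the generalized _atleast_1d helper
introduced in flox/aggregations.py behaves after this change, assuming it is imported
directly from that module for illustration.

    from flox.aggregations import _atleast_1d

    # Scalars are repeated up to the requested minimum length.
    assert _atleast_1d("sum", 3) == ("sum", "sum", "sum")

    # A plain dict now counts as a scalar too (see the is_scalar change above),
    # so per-function kwargs can be broadcast the same way in chunk_reduce.
    assert _atleast_1d({}, 2) == ({}, {})

    # Sequences pass through unchanged, provided they are long enough.
    assert _atleast_1d(["sum", "max"], 2) == ["sum", "max"]

This is what lets chunk_reduce and groupby_reduce replace their repeated
isinstance(..., Sequence) branches with single _atleast_1d calls, as shown in the
flox/core.py hunks above.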