Cleanup (#315)
* Cleanup

* seen_groups for numbagg only

* Use _atleast_1d in more places.
dcherian authored Jan 10, 2024
1 parent 0c4a7f9 commit 8ea0cd1
Showing 3 changed files with 15 additions and 37 deletions.
5 changes: 3 additions & 2 deletions flox/aggregations.py
@@ -133,9 +133,10 @@ def _get_fill_value(dtype, fill_value):
     return fill_value


-def _atleast_1d(inp):
+def _atleast_1d(inp, min_length: int = 1):
     if xrutils.is_scalar(inp):
-        inp = (inp,)
+        inp = (inp,) * min_length
+    assert len(inp) >= min_length
     return inp
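A quick illustration of the new min_length behaviour (a standalone sketch that substitutes a plain isinstance check for xrutils.is_scalar): scalars are replicated min_length times, while sequences pass through unchanged but must already be long enough.

def _atleast_1d(inp, min_length=1):
    # stand-in for xrutils.is_scalar: treat anything that is not a tuple/list as scalar
    if not isinstance(inp, (tuple, list)):
        inp = (inp,) * min_length
    assert len(inp) >= min_length
    return inp

print(_atleast_1d("sum", 3))    # ('sum', 'sum', 'sum')
print(_atleast_1d((0, -1), 2))  # (0, -1): sequences pass through unchanged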


43 changes: 10 additions & 33 deletions flox/core.py
@@ -803,29 +803,11 @@ def chunk_reduce(
     dict
     """

-    if not (isinstance(func, str) or callable(func)):
-        funcs = func
-    else:
-        funcs = (func,)
+    funcs = _atleast_1d(func)
     nfuncs = len(funcs)
-
-    if isinstance(dtype, Sequence):
-        dtypes = dtype
-    else:
-        dtypes = (dtype,) * nfuncs
-    assert len(dtypes) >= nfuncs
-
-    if isinstance(fill_value, Sequence):
-        fill_values = fill_value
-    else:
-        fill_values = (fill_value,) * nfuncs
-    assert len(fill_values) >= nfuncs
-
-    if isinstance(kwargs, Sequence):
-        kwargss = kwargs
-    else:
-        kwargss = ({},) * nfuncs
-    assert len(kwargss) >= nfuncs
+    dtypes = _atleast_1d(dtype, nfuncs)
+    fill_values = _atleast_1d(fill_value, nfuncs)
+    kwargss = _atleast_1d({}, nfuncs) if kwargs is None else kwargs

     if isinstance(axis, Sequence):
         axes: T_Axes = axis
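Using the sketch above, chunk_reduce now normalises its per-function arguments in one step: a lone dtype or fill_value is broadcast to one entry per reduction, and kwargs=None becomes one empty dict per function (values below are hypothetical, for illustration only).

funcs = _atleast_1d(("sum", "nanlen"))   # already a sequence: unchanged
nfuncs = len(funcs)                      # 2
dtypes = _atleast_1d("float64", nfuncs)  # ('float64', 'float64')
fill_values = _atleast_1d(0, nfuncs)     # (0, 0)
kwargss = _atleast_1d({}, nfuncs)        # ({}, {}): dicts now count as scalars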
@@ -862,7 +844,8 @@

     # do this *before* possible broadcasting below.
     # factorize_ has already taken care of offsetting
-    seen_groups = _unique(group_idx)
+    if engine == "numbagg":
+        seen_groups = _unique(group_idx)

     order = "C"
     if nax > 1:
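seen_groups is the set of group labels actually present in this chunk; only the numbagg engine consumes it, so the other engines can now skip computing it. Roughly, assuming _unique behaves like np.unique on the integer codes:

import numpy as np

group_idx = np.array([0, 2, 2, 5, 0])
engine = "numbagg"
if engine == "numbagg":
    seen_groups = np.unique(group_idx)  # array([0, 2, 5]): labels seen in this chunk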
@@ -1551,12 +1534,9 @@ def dask_groupby_agg(
             groups = _extract_unknown_groups(reduced, dtype=by.dtype)
             group_chunks = ((np.nan,),)
         else:
-            if expected_groups is None:
-                expected_groups_ = _get_expected_groups(by_input, sort=sort)
-            else:
-                expected_groups_ = expected_groups
-            groups = (expected_groups_.to_numpy(),)
-            group_chunks = ((len(expected_groups_),),)
+            assert expected_groups is not None
+            groups = (expected_groups.to_numpy(),)
+            group_chunks = ((len(expected_groups),),)

     elif method == "cohorts":
         chunks_cohorts = find_group_cohorts(
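With the fallback removed, dask_groupby_agg relies on expected_groups having been resolved earlier and builds a single known-size output chunk from it. A toy example with a hypothetical pandas Index:

import pandas as pd

expected_groups = pd.Index(["a", "b", "c"])
assert expected_groups is not None
groups = (expected_groups.to_numpy(),)     # (array(['a', 'b', 'c'], dtype=object),)
group_chunks = ((len(expected_groups),),)  # ((3,),): one output chunk of known size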
@@ -2063,10 +2043,7 @@ def groupby_reduce(
     is_bool_array = np.issubdtype(array.dtype, bool)
     array = array.astype(int) if is_bool_array else array

-    if isinstance(isbin, Sequence):
-        isbins = isbin
-    else:
-        isbins = (isbin,) * nby
+    isbins = _atleast_1d(isbin, nby)

     _assert_by_is_aligned(array.shape, bys)
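The same pattern applies to isbin in groupby_reduce: a single bool is expanded to one flag per grouping variable (hypothetical values, reusing the sketch above).

nby = 2
isbins = _atleast_1d(False, nby)  # (False, False): one flag per `by` variable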

4 changes: 2 additions & 2 deletions flox/xrutils.py
@@ -84,15 +84,15 @@ def __dask_tokenize__(self):
 def is_scalar(value: Any, include_0d: bool = True) -> bool:
     """Whether to treat a value as a scalar.

-    Any non-iterable, string, or 0-D array
+    Any non-iterable, string, dict, or 0-D array
     """
     NON_NUMPY_SUPPORTED_ARRAY_TYPES = (dask_array_type, pd.Index)

     if include_0d:
         include_0d = getattr(value, "ndim", None) == 0
     return (
         include_0d
-        or isinstance(value, (str, bytes))
+        or isinstance(value, (str, bytes, dict))
         or not (
             isinstance(value, (Iterable,) + NON_NUMPY_SUPPORTED_ARRAY_TYPES)
             or hasattr(value, "__array_function__")
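Treating dicts as scalars is what lets chunk_reduce pass a bare {} through _atleast_1d so it is replicated per function rather than iterated over. A quick check against the helper as defined in this repo:

from flox.xrutils import is_scalar

print(is_scalar({"ddof": 1}))  # True: dicts are now treated as scalars
print(is_scalar([1, 2, 3]))    # False: other iterables are not
print(is_scalar("mean"))       # True: strings were already scalar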
