Skip to content

Commit

Permalink
CLN: BaseGrouper (#59034)
Browse files Browse the repository at this point in the history
  • Loading branch information
mroeschke authored Jun 17, 2024
1 parent 8395f98 commit ee05885
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 52 deletions.
3 changes: 2 additions & 1 deletion pandas/core/groupby/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,7 +686,8 @@ def nunique(self, dropna: bool = True) -> Series | DataFrame:
b 1
dtype: int64
"""
ids, ngroups = self._grouper.group_info
ids = self._grouper.ids
ngroups = self._grouper.ngroups
val = self.obj._values
codes, uniques = algorithms.factorize(val, use_na_sentinel=dropna, sort=False)

Expand Down
17 changes: 11 additions & 6 deletions pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -1360,7 +1360,7 @@ def _wrap_applied_output(

@final
def _numba_prep(self, data: DataFrame):
ids, ngroups = self._grouper.group_info
ngroups = self._grouper.ngroups
sorted_index = self._grouper.result_ilocs
sorted_ids = self._grouper._sorted_ids

Expand Down Expand Up @@ -1969,7 +1969,8 @@ def _cumcount_array(self, ascending: bool = True) -> np.ndarray:
this is currently implementing sort=False
(though the default is sort=True) for groupby in general
"""
ids, ngroups = self._grouper.group_info
ids = self._grouper.ids
ngroups = self._grouper.ngroups
sorter = get_group_index_sorter(ids, ngroups)
ids, count = ids[sorter], len(ids)

Expand Down Expand Up @@ -2185,7 +2186,8 @@ def count(self) -> NDFrameT:
Freq: MS, dtype: int64
"""
data = self._get_data_to_aggregate()
ids, ngroups = self._grouper.group_info
ids = self._grouper.ids
ngroups = self._grouper.ngroups
mask = ids != -1

is_series = data.ndim == 1
Expand Down Expand Up @@ -3840,7 +3842,8 @@ def _fill(self, direction: Literal["ffill", "bfill"], limit: int | None = None):
if limit is None:
limit = -1

ids, ngroups = self._grouper.group_info
ids = self._grouper.ids
ngroups = self._grouper.ngroups

col_func = partial(
libgroupby.group_fillna_indexer,
Expand Down Expand Up @@ -4361,7 +4364,8 @@ def post_processor(
qs = np.array([q], dtype=np.float64)
pass_qs = None

ids, ngroups = self._grouper.group_info
ids = self._grouper.ids
ngroups = self._grouper.ngroups
if self.dropna:
# splitter drops NA groups, we need to do the same
ids = ids[ids >= 0]
Expand Down Expand Up @@ -5038,7 +5042,8 @@ def shift(
else:
if fill_value is lib.no_default:
fill_value = None
ids, ngroups = self._grouper.group_info
ids = self._grouper.ids
ngroups = self._grouper.ngroups
res_indexer = np.zeros(len(ids), dtype=np.int64)

libgroupby.group_shift_indexer(res_indexer, ids, ngroups, period)
Expand Down
57 changes: 13 additions & 44 deletions pandas/core/groupby/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@
Generator,
Hashable,
Iterator,
Sequence,
)

from pandas.core.generic import NDFrame
Expand Down Expand Up @@ -581,25 +580,21 @@ class BaseGrouper:
def __init__(
self,
axis: Index,
groupings: Sequence[grouper.Grouping],
groupings: list[grouper.Grouping],
sort: bool = True,
dropna: bool = True,
) -> None:
assert isinstance(axis, Index), axis

self.axis = axis
self._groupings: list[grouper.Grouping] = list(groupings)
self._groupings = groupings
self._sort = sort
self.dropna = dropna

@property
def groupings(self) -> list[grouper.Grouping]:
return self._groupings

@property
def shape(self) -> Shape:
return tuple(ping.ngroups for ping in self.groupings)

def __iter__(self) -> Iterator[Hashable]:
return iter(self.indices)

Expand Down Expand Up @@ -628,11 +623,15 @@ def _get_splitter(self, data: NDFrame) -> DataSplitter:
-------
Generator yielding subsetted objects
"""
ids, ngroups = self.group_info
return _get_splitter(
if isinstance(data, Series):
klass: type[DataSplitter] = SeriesSplitter
else:
# i.e. DataFrame
klass = FrameSplitter

return klass(
data,
ids,
ngroups,
self.ngroups,
sorted_ids=self._sorted_ids,
sort_idx=self.result_ilocs,
)
Expand Down Expand Up @@ -692,7 +691,8 @@ def size(self) -> Series:
"""
Compute group sizes.
"""
ids, ngroups = self.group_info
ids = self.ids
ngroups = self.ngroups
out: np.ndarray | list
if ngroups:
out = np.bincount(ids[ids != -1], minlength=ngroups)
Expand Down Expand Up @@ -729,12 +729,6 @@ def has_dropped_na(self) -> bool:
"""
return bool((self.ids < 0).any())

@cache_readonly
def group_info(self) -> tuple[npt.NDArray[np.intp], int]:
result_index, ids = self.result_index_and_ids
ngroups = len(result_index)
return ids, ngroups

@cache_readonly
def codes_info(self) -> npt.NDArray[np.intp]:
# return the codes of items in original grouped axis
Expand Down Expand Up @@ -1123,10 +1117,6 @@ def indices(self):
i = bin
return indices

@cache_readonly
def group_info(self) -> tuple[npt.NDArray[np.intp], int]:
return self.ids, self.ngroups

@cache_readonly
def codes(self) -> list[npt.NDArray[np.intp]]:
return [self.ids]
Expand Down Expand Up @@ -1191,29 +1181,25 @@ class DataSplitter(Generic[NDFrameT]):
def __init__(
self,
data: NDFrameT,
labels: npt.NDArray[np.intp],
ngroups: int,
*,
sort_idx: npt.NDArray[np.intp],
sorted_ids: npt.NDArray[np.intp],
) -> None:
self.data = data
self.labels = ensure_platform_int(labels) # _should_ already be np.intp
self.ngroups = ngroups

self._slabels = sorted_ids
self._sort_idx = sort_idx

def __iter__(self) -> Iterator:
sdata = self._sorted_data

if self.ngroups == 0:
# we are inside a generator, rather than raise StopIteration
# we merely return to signal the end
return

starts, ends = lib.generate_slices(self._slabels, self.ngroups)

sdata = self._sorted_data
for start, end in zip(starts, ends):
yield self._chop(sdata, slice(start, end))

Expand Down Expand Up @@ -1241,20 +1227,3 @@ def _chop(self, sdata: DataFrame, slice_obj: slice) -> DataFrame:
mgr = sdata._mgr.get_slice(slice_obj, axis=1)
df = sdata._constructor_from_mgr(mgr, axes=mgr.axes)
return df.__finalize__(sdata, method="groupby")


def _get_splitter(
data: NDFrame,
labels: npt.NDArray[np.intp],
ngroups: int,
*,
sort_idx: npt.NDArray[np.intp],
sorted_ids: npt.NDArray[np.intp],
) -> DataSplitter:
if isinstance(data, Series):
klass: type[DataSplitter] = SeriesSplitter
else:
# i.e. DataFrame
klass = FrameSplitter

return klass(data, labels, ngroups, sort_idx=sort_idx, sorted_ids=sorted_ids)
4 changes: 3 additions & 1 deletion pandas/tests/test_sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,9 @@ def test_int64_overflow_groupby_large_df_shuffled(self, agg):
gr = df.groupby(list("abcde"))

# verify this is testing what it is supposed to test!
assert is_int64_overflow_possible(gr._grouper.shape)
assert is_int64_overflow_possible(
tuple(ping.ngroups for ping in gr._grouper.groupings)
)

mi = MultiIndex.from_arrays(
[ar.ravel() for ar in np.array_split(np.unique(arr, axis=0), 5, axis=1)],
Expand Down

0 comments on commit ee05885

Please sign in to comment.