Skip to content

Commit

Permalink
Improve MultiIndex label rename checks
Browse files Browse the repository at this point in the history
  • Loading branch information
TabLand committed Feb 19, 2024
1 parent 1f622e2 commit 26aa99c
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 12 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v3.0.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,7 @@ Other
^^^^^
- Bug in :meth:`DataFrame.sort_index` when passing ``axis="columns"`` and ``ignore_index=True`` and ``ascending=False`` not returning a :class:`RangeIndex` columns (:issue:`57293`)
- Bug in :meth:`DataFrame.where` where using a non-bool type array in the function would return a ``ValueError`` instead of a ``TypeError`` (:issue:`56330`)
- Fixed bug in :meth:`DataFrame.rename` where checks on argument errors="raise" are not consistent with the actual transformation applied (:issue:`55169`). Logic change is accompanied with improvement to docs, a new test and a more descriptive ``KeyError`` message when a tuple label rename is attempted across :class:`MultiIndex` levels

.. ***DO NOT USE THIS SECTION***
Expand Down
5 changes: 5 additions & 0 deletions pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -5400,6 +5400,11 @@ def rename(
level : int or level name, default None
In case of a MultiIndex, only rename labels in the specified
level.
.. note::
Labels are renamed individually, and not via tuples across
MultiIndex levels
errors : {'ignore', 'raise'}, default 'ignore'
If 'raise', raise a `KeyError` when a dict-like `mapper`, `index`,
or `columns` contains labels that are not present in the Index
Expand Down
42 changes: 30 additions & 12 deletions pandas/core/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1035,18 +1035,36 @@ def _rename(

# GH 13473
if not callable(replacements):
if ax._is_multi and level is not None:
indexer = ax.get_level_values(level).get_indexer_for(replacements)
else:
indexer = ax.get_indexer_for(replacements)

if errors == "raise" and len(indexer[indexer == -1]):
missing_labels = [
label
for index, label in enumerate(replacements)
if indexer[index] == -1
]
raise KeyError(f"{missing_labels} not found in axis")
if errors == "raise":
missing_labels = []
for replacement in replacements:
if ax._is_multi:
indexers = [
ax.get_level_values(i).get_indexer_for([replacement])
for i in range(ax.nlevels)
if i == level or level is None
]
else:
indexers = [ax.get_indexer_for([replacement])]

found_anywhere = any(any(indexer != -1) for indexer in indexers)
if not found_anywhere:
missing_labels.append(replacement)

if len(missing_labels) > 0:
error = f"{missing_labels} not found in axis"
if ax._is_multi:
tuple_rename_tried = any(
type(label) is tuple and label in ax
for label in missing_labels
)
if tuple_rename_tried:
error += (
". Please provide individual labels for "
"replacement, and not tuples across "
"MultiIndex levels"
)
raise KeyError(error)

new_index = ax._transform_index(f, level=level)
result._set_axis_nocheck(new_index, axis=axis_no, inplace=True)
Expand Down
24 changes: 24 additions & 0 deletions pandas/tests/frame/methods/test_rename.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,16 @@ def test_rename_multiindex(self):
renamed = df.rename(index={"foo1": "foo3", "bar2": "bar3"}, level=0)
tm.assert_index_equal(renamed.index, new_index)

def test_rename_multiindex_with_checks(self):
df = DataFrame({("a", "count"): [1, 2], ("a", "sum"): [3, 4]})
renamed = df.rename(
columns={"a": "b", "count": "number_of", "sum": "total"}, errors="raise"
)

new_columns = MultiIndex.from_tuples([("b", "number_of"), ("b", "total")])

tm.assert_index_equal(renamed.columns, new_columns)

def test_rename_nocopy(self, float_frame):
renamed = float_frame.rename(columns={"C": "foo"}, copy=False)

Expand Down Expand Up @@ -221,6 +231,20 @@ def test_rename_errors_raises(self):
with pytest.raises(KeyError, match="'E'] not found in axis"):
df.rename(columns={"A": "a", "E": "e"}, errors="raise")

def test_rename_error_raised_for_label_across_multiindex_levels(self):
df = DataFrame([{"a": 1, "b": 2}, {"a": 3, "b": 4}])
df = df.groupby("a").agg({"b": ("count", "sum")})
with pytest.raises(
KeyError,
match=(
"\\[\\('b', 'count'\\)\\] not found "
"in axis\\. Please provide individual "
"labels for replacement, and not "
"tuples across MultiIndex levels"
),
):
df.rename(columns={("b", "count"): "new"}, errors="raise")

@pytest.mark.parametrize(
"mapper, errors, expected_columns",
[
Expand Down

0 comments on commit 26aa99c

Please sign in to comment.