Skip to content

Commit

Permalink
BUG: dataframe.update coercing dtype (#57637)
Browse files Browse the repository at this point in the history
  • Loading branch information
aureliobarbosa authored Mar 2, 2024
1 parent ddc3144 commit 8fde168
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 4 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v3.0.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ Bug fixes
~~~~~~~~~
- Fixed bug in :meth:`DataFrame.join` inconsistently setting result index name (:issue:`55815`)
- Fixed bug in :meth:`DataFrame.to_string` that raised ``StopIteration`` with nested DataFrames. (:issue:`16098`)
- Fixed bug in :meth:`DataFrame.update` bool dtype being converted to object (:issue:`55509`)
- Fixed bug in :meth:`Series.diff` allowing non-integer values for the ``periods`` argument. (:issue:`56607`)

Categorical
Expand Down
23 changes: 19 additions & 4 deletions pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -8706,6 +8706,10 @@ def update(
dict.update : Similar method for dictionaries.
DataFrame.merge : For column(s)-on-column(s) operations.
Notes
-----
1. Duplicate indices on `other` are not supported and raises `ValueError`.
Examples
--------
>>> df = pd.DataFrame({"A": [1, 2, 3], "B": [400, 500, 600]})
Expand Down Expand Up @@ -8778,11 +8782,22 @@ def update(
if not isinstance(other, DataFrame):
other = DataFrame(other)

other = other.reindex(self.index)
if other.index.has_duplicates:
raise ValueError("Update not allowed with duplicate indexes on other.")

index_intersection = other.index.intersection(self.index)
if index_intersection.empty:
raise ValueError(
"Update not allowed when the index on `other` has no intersection "
"with this dataframe."
)

other = other.reindex(index_intersection)
this_data = self.loc[index_intersection]

for col in self.columns.intersection(other.columns):
this = self[col]._values
that = other[col]._values
this = this_data[col]
that = other[col]

if filter_func is not None:
mask = ~filter_func(this) | isna(that)
Expand All @@ -8802,7 +8817,7 @@ def update(
if mask.all():
continue

self.loc[:, col] = self[col].where(mask, that)
self.loc[index_intersection, col] = this.where(mask, that)

# ----------------------------------------------------------------------
# Data reshaping
Expand Down
52 changes: 52 additions & 0 deletions pandas/tests/frame/methods/test_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,3 +184,55 @@ def test_update_dt_column_with_NaT_create_column(self):
{"A": [1.0, 3.0], "B": [pd.NaT, pd.to_datetime("2016-01-01")]}
)
tm.assert_frame_equal(df, expected)

@pytest.mark.parametrize(
"value_df, value_other, dtype",
[
(True, False, bool),
(1, 2, int),
(1.0, 2.0, float),
(1.0 + 1j, 2.0 + 2j, complex),
(np.uint64(1), np.uint(2), np.dtype("ubyte")),
(np.uint64(1), np.uint(2), np.dtype("intc")),
("a", "b", pd.StringDtype()),
(
pd.to_timedelta("1 ms"),
pd.to_timedelta("2 ms"),
np.dtype("timedelta64[ns]"),
),
(
np.datetime64("2000-01-01T00:00:00"),
np.datetime64("2000-01-02T00:00:00"),
np.dtype("datetime64[ns]"),
),
],
)
def test_update_preserve_dtype(self, value_df, value_other, dtype):
# GH#55509
df = DataFrame({"a": [value_df] * 2}, index=[1, 2], dtype=dtype)
other = DataFrame({"a": [value_other]}, index=[1], dtype=dtype)
expected = DataFrame({"a": [value_other, value_df]}, index=[1, 2], dtype=dtype)
df.update(other)
tm.assert_frame_equal(df, expected)

def test_update_raises_on_duplicate_argument_index(self):
# GH#55509
df = DataFrame({"a": [1, 1]}, index=[1, 2])
other = DataFrame({"a": [2, 3]}, index=[1, 1])
with pytest.raises(ValueError, match="duplicate index"):
df.update(other)

def test_update_raises_without_intersection(self):
# GH#55509
df = DataFrame({"a": [1]}, index=[1])
other = DataFrame({"a": [2]}, index=[2])
with pytest.raises(ValueError, match="no intersection"):
df.update(other)

def test_update_on_duplicate_frame_unique_argument_index(self):
# GH#55509
df = DataFrame({"a": [1, 1, 1]}, index=[1, 1, 2], dtype=np.dtype("intc"))
other = DataFrame({"a": [2, 3]}, index=[1, 2], dtype=np.dtype("intc"))
expected = DataFrame({"a": [2, 2, 3]}, index=[1, 1, 2], dtype=np.dtype("intc"))
df.update(other)
tm.assert_frame_equal(df, expected)

0 comments on commit 8fde168

Please sign in to comment.