diff --git a/python/cudf/cudf/core/indexed_frame.py b/python/cudf/cudf/core/indexed_frame.py index 8be9f0ad78e..ae7369c80d1 100644 --- a/python/cudf/cudf/core/indexed_frame.py +++ b/python/cudf/cudf/core/indexed_frame.py @@ -6469,7 +6469,7 @@ def _get_replacement_values_for_columns( to_replace_columns = {col: [to_replace] for col in columns_dtype_map} values_columns = {col: [value] for col in columns_dtype_map} elif cudf.api.types.is_list_like(to_replace) or isinstance( - to_replace, ColumnBase + to_replace, (ColumnBase, BaseIndex) ): if is_scalar(value): to_replace_columns = {col: to_replace for col in columns_dtype_map} @@ -6483,7 +6483,9 @@ def _get_replacement_values_for_columns( ) for col in columns_dtype_map } - elif cudf.api.types.is_list_like(value): + elif cudf.api.types.is_list_like( + value + ) or cudf.utils.dtypes.is_column_like(value): if len(to_replace) != len(value): raise ValueError( f"Replacement lists must be " @@ -6495,9 +6497,6 @@ def _get_replacement_values_for_columns( col: to_replace for col in columns_dtype_map } values_columns = {col: value for col in columns_dtype_map} - elif cudf.utils.dtypes.is_column_like(value): - to_replace_columns = {col: to_replace for col in columns_dtype_map} - values_columns = {col: value for col in columns_dtype_map} else: raise TypeError( "value argument must be scalar, list-like or Series" @@ -6592,12 +6591,13 @@ def _get_replacement_values_for_columns( return all_na_columns, to_replace_columns, values_columns -def _is_series(obj): +def _is_series(obj: Any) -> bool: """ Checks if the `obj` is of type `cudf.Series` instead of checking for isinstance(obj, cudf.Series) + to avoid circular imports. """ - return isinstance(obj, Frame) and obj.ndim == 1 and obj.index is not None + return isinstance(obj, IndexedFrame) and obj.ndim == 1 @_performance_tracking diff --git a/python/cudf/cudf/tests/test_replace.py b/python/cudf/cudf/tests/test_replace.py index 1973fe6fb41..e5ee0127a74 100644 --- a/python/cudf/cudf/tests/test_replace.py +++ b/python/cudf/cudf/tests/test_replace.py @@ -1378,3 +1378,9 @@ def test_fillna_nan_and_null(): result = ser.fillna(2.2) expected = cudf.Series([2.2, 2.2, 1.1]) assert_eq(result, expected) + + +def test_replace_with_index_objects(): + result = cudf.Series([1, 2]).replace(cudf.Index([1]), cudf.Index([2])) + expected = pd.Series([1, 2]).replace(pd.Index([1]), pd.Index([2])) + assert_eq(result, expected)