Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

REF: implement BaseOpsUtil._cast_pointwise_result #54366

Merged
merged 2 commits into from
Aug 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions pandas/tests/extension/base/ops.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from typing import final

import numpy as np
import pytest

Expand All @@ -10,6 +12,15 @@


class BaseOpsUtil(BaseExtensionTests):
def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
    """
    Hook for casting the pointwise result before it is compared.

    ``_check_op`` verifies that the result of a pointwise operation (built
    via ``_combine``) matches the vectorized ``obj.__op_name__(other)``.
    pandas' dtype inference on the scalar results may produce a dtype that
    differs from the vectorized one even when both operations behave
    "correctly"; subclasses override this method to do the extra casting
    needed in those cases.

    The base implementation performs no casting and hands back
    ``pointwise_result`` as received.
    """
    uncast_result = pointwise_result
    return uncast_result

def get_op_from_name(self, op_name: str):
    """Return whatever ``tm.get_op_from_name`` resolves for ``op_name``."""
    resolved_op = tm.get_op_from_name(op_name)
    return resolved_op

Expand All @@ -18,6 +29,12 @@ def check_opname(self, ser: pd.Series, op_name: str, other, exc=Exception):

self._check_op(ser, op, other, op_name, exc)

# Subclasses are not expected to need to override _check_op or _combine.
# Ideally any relevant overriding can be done in _cast_pointwise_result,
# get_op_from_name, and the specification of `exc`. If you find a use
# case that still requires overriding _check_op or _combine, please let
# us know at github.com/pandas-dev/pandas/issues
@final
def _combine(self, obj, other, op):
if isinstance(obj, pd.DataFrame):
if len(obj.columns) != 1:
Expand All @@ -27,12 +44,15 @@ def _combine(self, obj, other, op):
expected = obj.combine(other, op)
return expected

# see comment on _combine
@final
def _check_op(
self, ser: pd.Series, op, other, op_name: str, exc=NotImplementedError
):
if exc is None:
result = op(ser, other)
expected = self._combine(ser, other, op)
expected = self._cast_pointwise_result(op_name, ser, other, expected)
assert isinstance(result, type(ser))
tm.assert_equal(result, expected)
else:
Expand Down
14 changes: 7 additions & 7 deletions pandas/tests/extension/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,11 +873,11 @@ def rtruediv(x, y):

return tm.get_op_from_name(op_name)

def _combine(self, obj, other, op):
def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
# BaseOpsUtil._combine can upcast expected dtype
# (because it generates expected on python scalars)
# while ArrowExtensionArray maintains original type
expected = base.BaseArithmeticOpsTests._combine(self, obj, other, op)
expected = pointwise_result

was_frame = False
if isinstance(expected, pd.DataFrame):
Expand All @@ -895,7 +895,7 @@ def _combine(self, obj, other, op):
pa.types.is_floating(orig_pa_type)
or (
pa.types.is_integer(orig_pa_type)
and op.__name__ not in ["truediv", "rtruediv"]
and op_name not in ["__truediv__", "__rtruediv__"]
)
or pa.types.is_duration(orig_pa_type)
or pa.types.is_timestamp(orig_pa_type)
Expand All @@ -906,7 +906,7 @@ def _combine(self, obj, other, op):
# ArrowExtensionArray does not upcast
return expected
elif not (
(op is operator.floordiv and pa.types.is_integer(orig_pa_type))
(op_name == "__floordiv__" and pa.types.is_integer(orig_pa_type))
or pa.types.is_duration(orig_pa_type)
or pa.types.is_timestamp(orig_pa_type)
or pa.types.is_date(orig_pa_type)
Expand Down Expand Up @@ -943,14 +943,14 @@ def _combine(self, obj, other, op):
):
# decimal precision can resize in the result type depending on data
# just compare the float values
alt = op(obj, other)
alt = getattr(obj, op_name)(other)
alt_dtype = tm.get_dtype(alt)
assert isinstance(alt_dtype, ArrowDtype)
if op is operator.pow and isinstance(other, Decimal):
if op_name == "__pow__" and isinstance(other, Decimal):
# TODO: would it make more sense to retain Decimal here?
alt_dtype = ArrowDtype(pa.float64())
elif (
op is operator.pow
op_name == "__pow__"
and isinstance(other, pd.Series)
and other.dtype == original_dtype
):
Expand Down
68 changes: 35 additions & 33 deletions pandas/tests/extension/test_boolean.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
be added to the array-specific tests in `pandas/tests/arrays/`.

"""
import operator

import numpy as np
import pytest

Expand All @@ -23,6 +25,7 @@

import pandas as pd
import pandas._testing as tm
from pandas.core import roperator
from pandas.core.arrays.boolean import BooleanDtype
from pandas.tests.extension import base

Expand Down Expand Up @@ -125,41 +128,40 @@ def check_opname(self, s, op_name, other, exc=None):
if op_name.strip("_").lstrip("r") in ["pow", "truediv", "floordiv"]:
# match behavior with non-masked bool dtype
exc = NotImplementedError
elif op_name in self.implements:
# exception message would include "numpy boolean subtract"
exc = TypeError

super().check_opname(s, op_name, other, exc=exc)

def _check_op(self, obj, op, other, op_name, exc=NotImplementedError):
if exc is None:
if op_name in self.implements:
msg = r"numpy boolean subtract"
with pytest.raises(TypeError, match=msg):
op(obj, other)
return

result = op(obj, other)
expected = self._combine(obj, other, op)

if op_name in (
"__floordiv__",
"__rfloordiv__",
"__pow__",
"__rpow__",
"__mod__",
"__rmod__",
):
# combine keeps boolean type
expected = expected.astype("Int8")
elif op_name in ("__truediv__", "__rtruediv__"):
# combine with bools does not generate the correct result
# (numpy behaviour for div is to regard the bools as numeric)
expected = self._combine(obj.astype(float), other, op)
expected = expected.astype("Float64")
if op_name == "__rpow__":
# for rpow, combine does not propagate NaN
expected[result.isna()] = np.nan
tm.assert_equal(result, expected)
else:
with pytest.raises(exc):
op(obj, other)
def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
    """
    Cast the pointwise result for a boolean-array arithmetic op.

    - floordiv / pow / mod (and their reflected forms): cast to "Int8".
    - truediv / rtruediv: discard the incoming pointwise result, recompute
      it with ``self._combine`` on ``obj`` cast to float, then cast to
      "Float64".
    - any other op: returned unchanged.

    For ``__rpow__`` the positions where the vectorized result is NA are
    additionally set to NaN in the (already cast) pointwise result.
    """
    int8_ops = (
        "__floordiv__",
        "__rfloordiv__",
        "__pow__",
        "__rpow__",
        "__mod__",
        "__rmod__",
    )

    if op_name in int8_ops:
        # combine keeps boolean type
        pointwise_result = pointwise_result.astype("Int8")
    elif op_name in ("__truediv__", "__rtruediv__"):
        # combine with bools does not generate the correct result
        # (numpy behaviour for div is to regard the bools as numeric)
        div_op = (
            operator.truediv if op_name == "__truediv__" else roperator.rtruediv
        )
        recomputed = self._combine(obj.astype(float), other, div_op)
        pointwise_result = recomputed.astype("Float64")

    if op_name == "__rpow__":
        # for rpow, combine does not propagate NaN
        vectorized = getattr(obj, op_name)(other)
        pointwise_result[vectorized.isna()] = np.nan

    return pointwise_result

@pytest.mark.xfail(
reason="Inconsistency between floordiv and divmod; we raise for floordiv "
Expand Down
60 changes: 13 additions & 47 deletions pandas/tests/extension/test_masked_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,45 +146,20 @@ class TestDtype(base.BaseDtypeTests):


class TestArithmeticOps(base.BaseArithmeticOpsTests):
def _check_op(self, s, op, other, op_name, exc=NotImplementedError):
if exc is None:
sdtype = tm.get_dtype(s)

if hasattr(other, "dtype") and isinstance(other.dtype, np.dtype):
if sdtype.kind == "f":
if other.dtype.kind == "f":
# other is np.float64 and would therefore always result
# in upcasting, so keeping other as same numpy_dtype
other = other.astype(sdtype.numpy_dtype)

else:
# i.e. sdtype.kind in "iu"
if other.dtype.kind in "iu" and sdtype.is_unsigned_integer:
# TODO: comment below is inaccurate; other can be int8
# int16, ...
# and the trouble is that e.g. if s is UInt8 and other
# is int8, then result is UInt16
# other is np.int64 and would therefore always result in
# upcasting, so keeping other as same numpy_dtype
other = other.astype(sdtype.numpy_dtype)

result = op(s, other)
expected = self._combine(s, other, op)

if sdtype.kind in "iu":
if op_name in ("__rtruediv__", "__truediv__", "__div__"):
expected = expected.fillna(np.nan).astype("Float64")
else:
# combine method result in 'biggest' (int64) dtype
expected = expected.astype(sdtype)
def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
sdtype = tm.get_dtype(obj)
expected = pointwise_result

if sdtype.kind in "iu":
if op_name in ("__rtruediv__", "__truediv__", "__div__"):
expected = expected.fillna(np.nan).astype("Float64")
else:
# combine method result in 'biggest' (float64) dtype
# combine method result in 'biggest' (int64) dtype
expected = expected.astype(sdtype)

tm.assert_equal(result, expected)
else:
with pytest.raises(exc):
op(s, other)
# combine method result in 'biggest' (float64) dtype
expected = expected.astype(sdtype)
return expected

def check_opname(self, ser: pd.Series, op_name: str, other, exc=None):
# overwriting to indicate ops don't raise an error
Expand All @@ -195,17 +170,8 @@ def _check_divmod_op(self, ser: pd.Series, op, other, exc=None):


class TestComparisonOps(base.BaseComparisonOpsTests):
def _check_op(
self, ser: pd.Series, op, other, op_name: str, exc=NotImplementedError
):
if exc is None:
result = op(ser, other)
# Override to do the astype to boolean
expected = ser.combine(other, op).astype("boolean")
tm.assert_series_equal(result, expected)
else:
with pytest.raises(exc):
op(ser, other)
def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
    """Cast the pointwise comparison result to the "boolean" dtype."""
    as_boolean = pointwise_result.astype("boolean")
    return as_boolean

def check_opname(self, ser: pd.Series, op_name: str, other, exc=None):
super().check_opname(ser, op_name, other, exc=None)
Expand Down