Skip to content

Commit

Permalink
REF: implement TestExtension (#54432)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel authored Aug 7, 2023
1 parent 809f371 commit bbe11b2
Show file tree
Hide file tree
Showing 19 changed files with 79 additions and 103 deletions.
56 changes: 42 additions & 14 deletions pandas/tests/extension/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,33 +34,61 @@ class TestMyDtype(BaseDtypeTests):
wherever the test requires it. You're free to implement additional tests.
"""
from pandas.tests.extension.base.accumulate import BaseAccumulateTests # noqa: F401
from pandas.tests.extension.base.casting import BaseCastingTests # noqa: F401
from pandas.tests.extension.base.constructors import BaseConstructorsTests # noqa: F401
from pandas.tests.extension.base.accumulate import BaseAccumulateTests
from pandas.tests.extension.base.casting import BaseCastingTests
from pandas.tests.extension.base.constructors import BaseConstructorsTests
from pandas.tests.extension.base.dim2 import ( # noqa: F401
Dim2CompatTests,
NDArrayBacked2DTests,
)
from pandas.tests.extension.base.dtype import BaseDtypeTests # noqa: F401
from pandas.tests.extension.base.getitem import BaseGetitemTests # noqa: F401
from pandas.tests.extension.base.groupby import BaseGroupbyTests # noqa: F401
from pandas.tests.extension.base.index import BaseIndexTests # noqa: F401
from pandas.tests.extension.base.interface import BaseInterfaceTests # noqa: F401
from pandas.tests.extension.base.io import BaseParsingTests # noqa: F401
from pandas.tests.extension.base.methods import BaseMethodsTests # noqa: F401
from pandas.tests.extension.base.missing import BaseMissingTests # noqa: F401
from pandas.tests.extension.base.dtype import BaseDtypeTests
from pandas.tests.extension.base.getitem import BaseGetitemTests
from pandas.tests.extension.base.groupby import BaseGroupbyTests
from pandas.tests.extension.base.index import BaseIndexTests
from pandas.tests.extension.base.interface import BaseInterfaceTests
from pandas.tests.extension.base.io import BaseParsingTests
from pandas.tests.extension.base.methods import BaseMethodsTests
from pandas.tests.extension.base.missing import BaseMissingTests
from pandas.tests.extension.base.ops import ( # noqa: F401
BaseArithmeticOpsTests,
BaseComparisonOpsTests,
BaseOpsUtil,
BaseUnaryOpsTests,
)
from pandas.tests.extension.base.printing import BasePrintingTests # noqa: F401
from pandas.tests.extension.base.printing import BasePrintingTests
from pandas.tests.extension.base.reduce import ( # noqa: F401
BaseBooleanReduceTests,
BaseNoReduceTests,
BaseNumericReduceTests,
BaseReduceTests,
)
from pandas.tests.extension.base.reshaping import BaseReshapingTests # noqa: F401
from pandas.tests.extension.base.setitem import BaseSetitemTests # noqa: F401
from pandas.tests.extension.base.reshaping import BaseReshapingTests
from pandas.tests.extension.base.setitem import BaseSetitemTests


# One test class that you can inherit as an alternative to inheriting all the
# test classes above.
# Note 1) this excludes Dim2CompatTests and NDArrayBacked2DTests.
# Note 2) this uses BaseReduceTests and and _not_ BaseBooleanReduceTests,
# BaseNoReduceTests, or BaseNumericReduceTests
class ExtensionTests(
BaseAccumulateTests,
BaseCastingTests,
BaseConstructorsTests,
BaseDtypeTests,
BaseGetitemTests,
BaseGroupbyTests,
BaseIndexTests,
BaseInterfaceTests,
BaseParsingTests,
BaseMethodsTests,
BaseMissingTests,
BaseArithmeticOpsTests,
BaseComparisonOpsTests,
BaseUnaryOpsTests,
BasePrintingTests,
BaseReduceTests,
BaseReshapingTests,
BaseSetitemTests,
):
pass
3 changes: 1 addition & 2 deletions pandas/tests/extension/base/accumulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@

import pandas as pd
import pandas._testing as tm
from pandas.tests.extension.base.base import BaseExtensionTests


class BaseAccumulateTests(BaseExtensionTests):
class BaseAccumulateTests:
"""
Accumulation specific tests. Generally these only
make sense for numeric/boolean operations.
Expand Down
3 changes: 1 addition & 2 deletions pandas/tests/extension/base/casting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
import pandas as pd
import pandas._testing as tm
from pandas.core.internals.blocks import NumpyBlock
from pandas.tests.extension.base.base import BaseExtensionTests


class BaseCastingTests(BaseExtensionTests):
class BaseCastingTests:
"""Casting to and from ExtensionDtypes"""

def test_astype_object_series(self, all_data):
Expand Down
3 changes: 1 addition & 2 deletions pandas/tests/extension/base/constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@
import pandas._testing as tm
from pandas.api.extensions import ExtensionArray
from pandas.core.internals.blocks import EABackedBlock
from pandas.tests.extension.base.base import BaseExtensionTests


class BaseConstructorsTests(BaseExtensionTests):
class BaseConstructorsTests:
def test_from_sequence_from_cls(self, data):
result = type(data)._from_sequence(data, dtype=data.dtype)
tm.assert_extension_array_equal(result, data)
Expand Down
3 changes: 1 addition & 2 deletions pandas/tests/extension/base/dim2.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,9 @@
import pandas as pd
import pandas._testing as tm
from pandas.core.arrays.integer import NUMPY_INT_TO_DTYPE
from pandas.tests.extension.base.base import BaseExtensionTests


class Dim2CompatTests(BaseExtensionTests):
class Dim2CompatTests:
# Note: these are ONLY for ExtensionArray subclasses that support 2D arrays.
# i.e. not for pyarrow-backed EAs.

Expand Down
3 changes: 1 addition & 2 deletions pandas/tests/extension/base/dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@
is_object_dtype,
is_string_dtype,
)
from pandas.tests.extension.base.base import BaseExtensionTests


class BaseDtypeTests(BaseExtensionTests):
class BaseDtypeTests:
"""Base class for ExtensionDtype classes"""

def test_name(self, dtype):
Expand Down
3 changes: 1 addition & 2 deletions pandas/tests/extension/base/getitem.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@

import pandas as pd
import pandas._testing as tm
from pandas.tests.extension.base.base import BaseExtensionTests


class BaseGetitemTests(BaseExtensionTests):
class BaseGetitemTests:
"""Tests for ExtensionArray.__getitem__."""

def test_iloc_series(self, data):
Expand Down
3 changes: 1 addition & 2 deletions pandas/tests/extension/base/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,9 @@

import pandas as pd
import pandas._testing as tm
from pandas.tests.extension.base.base import BaseExtensionTests


class BaseGroupbyTests(BaseExtensionTests):
class BaseGroupbyTests:
"""Groupby-specific tests."""

def test_grouping_grouper(self, data_for_grouping):
Expand Down
3 changes: 1 addition & 2 deletions pandas/tests/extension/base/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@
Tests for Indexes backed by arbitrary ExtensionArrays.
"""
import pandas as pd
from pandas.tests.extension.base.base import BaseExtensionTests


class BaseIndexTests(BaseExtensionTests):
class BaseIndexTests:
"""Tests for Index object backed by an ExtensionArray"""

def test_index_from_array(self, data):
Expand Down
3 changes: 1 addition & 2 deletions pandas/tests/extension/base/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@

import pandas as pd
import pandas._testing as tm
from pandas.tests.extension.base.base import BaseExtensionTests


class BaseInterfaceTests(BaseExtensionTests):
class BaseInterfaceTests:
"""Tests that the basic interface is satisfied."""

# ------------------------------------------------------------------------
Expand Down
3 changes: 1 addition & 2 deletions pandas/tests/extension/base/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@

import pandas as pd
import pandas._testing as tm
from pandas.tests.extension.base.base import BaseExtensionTests


class BaseParsingTests(BaseExtensionTests):
class BaseParsingTests:
@pytest.mark.parametrize("engine", ["c", "python"])
def test_EA_types(self, engine, data):
df = pd.DataFrame({"with_dtype": pd.Series(data, dtype=str(data.dtype))})
Expand Down
3 changes: 1 addition & 2 deletions pandas/tests/extension/base/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,9 @@
import pandas as pd
import pandas._testing as tm
from pandas.core.sorting import nargsort
from pandas.tests.extension.base.base import BaseExtensionTests


class BaseMethodsTests(BaseExtensionTests):
class BaseMethodsTests:
"""Various Series and DataFrame methods."""

def test_hash_pandas_object(self, data):
Expand Down
3 changes: 1 addition & 2 deletions pandas/tests/extension/base/missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@

import pandas as pd
import pandas._testing as tm
from pandas.tests.extension.base.base import BaseExtensionTests


class BaseMissingTests(BaseExtensionTests):
class BaseMissingTests:
def test_isna(self, data_missing):
expected = np.array([True, False])

Expand Down
3 changes: 1 addition & 2 deletions pandas/tests/extension/base/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@
import pandas as pd
import pandas._testing as tm
from pandas.core import ops
from pandas.tests.extension.base.base import BaseExtensionTests


class BaseOpsUtil(BaseExtensionTests):
class BaseOpsUtil:
series_scalar_exc: type[Exception] | None = TypeError
frame_scalar_exc: type[Exception] | None = TypeError
series_array_exc: type[Exception] | None = TypeError
Expand Down
3 changes: 1 addition & 2 deletions pandas/tests/extension/base/printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
import pytest

import pandas as pd
from pandas.tests.extension.base.base import BaseExtensionTests


class BasePrintingTests(BaseExtensionTests):
class BasePrintingTests:
"""Tests checking the formatting of your EA when printed."""

@pytest.mark.parametrize("size", ["big", "small"])
Expand Down
3 changes: 1 addition & 2 deletions pandas/tests/extension/base/reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@
import pandas as pd
import pandas._testing as tm
from pandas.api.types import is_numeric_dtype
from pandas.tests.extension.base.base import BaseExtensionTests


class BaseReduceTests(BaseExtensionTests):
class BaseReduceTests:
"""
Reduction specific tests. Generally these only
make sense for numeric/boolean operations.
Expand Down
3 changes: 1 addition & 2 deletions pandas/tests/extension/base/reshaping.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@
import pandas._testing as tm
from pandas.api.extensions import ExtensionArray
from pandas.core.internals.blocks import EABackedBlock
from pandas.tests.extension.base.base import BaseExtensionTests


class BaseReshapingTests(BaseExtensionTests):
class BaseReshapingTests:
"""Tests for reshaping and concatenation."""

@pytest.mark.parametrize("in_frame", [True, False])
Expand Down
3 changes: 1 addition & 2 deletions pandas/tests/extension/base/setitem.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@

import pandas as pd
import pandas._testing as tm
from pandas.tests.extension.base.base import BaseExtensionTests


class BaseSetitemTests(BaseExtensionTests):
class BaseSetitemTests:
@pytest.fixture(
params=[
lambda x: x.index,
Expand Down
75 changes: 20 additions & 55 deletions pandas/tests/extension/test_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ def data_missing():
return IntervalArray.from_tuples([None, (0, 1)])


@pytest.fixture
def data_for_twos():
pytest.skip("Not a numeric dtype")


@pytest.fixture
def data_for_sorting():
return IntervalArray.from_tuples([(1, 2), (2, 3), (0, 1)])
Expand All @@ -65,74 +70,34 @@ def data_for_grouping():
return IntervalArray.from_tuples([b, b, None, None, a, a, b, c])


class BaseInterval:
pass


class TestDtype(BaseInterval, base.BaseDtypeTests):
pass


class TestCasting(BaseInterval, base.BaseCastingTests):
pass


class TestConstructors(BaseInterval, base.BaseConstructorsTests):
pass


class TestGetitem(BaseInterval, base.BaseGetitemTests):
pass


class TestIndex(base.BaseIndexTests):
pass


class TestGrouping(BaseInterval, base.BaseGroupbyTests):
pass


class TestInterface(BaseInterval, base.BaseInterfaceTests):
pass
class TestIntervalArray(base.ExtensionTests):
divmod_exc = TypeError


class TestReduce(base.BaseReduceTests):
def _supports_reduction(self, obj, op_name: str) -> bool:
return op_name in ["min", "max"]


class TestMethods(BaseInterval, base.BaseMethodsTests):
@pytest.mark.xfail(
reason="Raises with incorrect message bc it disallows *all* listlikes "
"instead of just wrong-length listlikes"
)
def test_fillna_length_mismatch(self, data_missing):
super().test_fillna_length_mismatch(data_missing)


class TestMissing(BaseInterval, base.BaseMissingTests):
def test_fillna_non_scalar_raises(self, data_missing):
msg = "can only insert Interval objects and NA into an IntervalArray"
with pytest.raises(TypeError, match=msg):
data_missing.fillna([1, 1])


class TestReshaping(BaseInterval, base.BaseReshapingTests):
pass


class TestSetitem(BaseInterval, base.BaseSetitemTests):
pass


class TestPrinting(BaseInterval, base.BasePrintingTests):
pass


class TestParsing(BaseInterval, base.BaseParsingTests):
@pytest.mark.parametrize("engine", ["c", "python"])
def test_EA_types(self, engine, data):
expected_msg = r".*must implement _from_sequence_of_strings.*"
with pytest.raises(NotImplementedError, match=expected_msg):
super().test_EA_types(engine, data)

@pytest.mark.xfail(
reason="Looks like the test (incorrectly) implicitly assumes int/bool dtype"
)
def test_invert(self, data):
super().test_invert(data)


# TODO: either belongs in tests.arrays.interval or move into base tests.
def test_fillna_non_scalar_raises(data_missing):
msg = "can only insert Interval objects and NA into an IntervalArray"
with pytest.raises(TypeError, match=msg):
data_missing.fillna([1, 1])

0 comments on commit bbe11b2

Please sign in to comment.