diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 6f7f42eca3794..f919a299eda4d 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -637,6 +637,9 @@ def __invert__(self) -> Self: # This is a bit wise op for integer types if pa.types.is_integer(self._pa_array.type): return type(self)(pc.bit_wise_not(self._pa_array)) + elif pa.types.is_string(self._pa_array.type): + # Raise TypeError instead of pa.ArrowNotImplementedError + raise TypeError("__invert__ is not supported for string dtypes") else: return type(self)(pc.invert(self._pa_array)) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index de9f872aca01d..4835cb11db042 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -753,7 +753,11 @@ def test_EA_types(self, engine, data, dtype_backend, request): def test_invert(self, data, request): pa_dtype = data.dtype.pyarrow_dtype - if not (pa.types.is_boolean(pa_dtype) or pa.types.is_integer(pa_dtype)): + if not ( + pa.types.is_boolean(pa_dtype) + or pa.types.is_integer(pa_dtype) + or pa.types.is_string(pa_dtype) + ): request.applymarker( pytest.mark.xfail( raises=pa.ArrowNotImplementedError, diff --git a/pandas/tests/extension/test_string.py b/pandas/tests/extension/test_string.py index bf6db6e9f16ec..2d5a134f8560a 100644 --- a/pandas/tests/extension/test_string.py +++ b/pandas/tests/extension/test_string.py @@ -13,7 +13,10 @@ be added to the array-specific tests in `pandas/tests/arrays/`. """ +from __future__ import annotations + import string +from typing import cast import numpy as np import pytest @@ -90,7 +93,7 @@ def data_for_grouping(dtype, chunked): return maybe_split_array(arr, chunked) -class TestDtype(base.BaseDtypeTests): +class TestStringArray(base.ExtensionTests): def test_eq_with_str(self, dtype): assert dtype == f"string[{dtype.storage}]" super().test_eq_with_str(dtype) @@ -100,43 +103,25 @@ def test_is_not_string_type(self, dtype): # because StringDtype is a string type assert is_string_dtype(dtype) - -class TestInterface(base.BaseInterfaceTests): def test_view(self, data, request, arrow_string_storage): if data.dtype.storage in arrow_string_storage: pytest.skip(reason="2D support not implemented for ArrowStringArray") super().test_view(data) - -class TestConstructors(base.BaseConstructorsTests): def test_from_dtype(self, data): # base test uses string representation of dtype pass - -class TestReshaping(base.BaseReshapingTests): def test_transpose(self, data, request, arrow_string_storage): if data.dtype.storage in arrow_string_storage: pytest.skip(reason="2D support not implemented for ArrowStringArray") super().test_transpose(data) - -class TestGetitem(base.BaseGetitemTests): - pass - - -class TestSetitem(base.BaseSetitemTests): def test_setitem_preserves_views(self, data, request, arrow_string_storage): if data.dtype.storage in arrow_string_storage: pytest.skip(reason="2D support not implemented for ArrowStringArray") super().test_setitem_preserves_views(data) - -class TestIndex(base.BaseIndexTests): - pass - - -class TestMissing(base.BaseMissingTests): def test_dropna_array(self, data_missing): result = data_missing.dropna() expected = data_missing[[1]] @@ -154,8 +139,57 @@ def test_fillna_no_op_returns_copy(self, data): assert result is not data tm.assert_extension_array_equal(result, data) + def _get_expected_exception( + self, op_name: str, obj, other + ) -> type[Exception] | None: + if op_name in ["__divmod__", "__rdivmod__"]: + if isinstance(obj, pd.Series) and cast( + StringDtype, tm.get_dtype(obj) + ).storage in [ + "pyarrow", + "pyarrow_numpy", + ]: + # TODO: re-raise as TypeError? + return NotImplementedError + elif isinstance(other, pd.Series) and cast( + StringDtype, tm.get_dtype(other) + ).storage in [ + "pyarrow", + "pyarrow_numpy", + ]: + # TODO: re-raise as TypeError? + return NotImplementedError + return TypeError + elif op_name in ["__mod__", "__rmod__", "__pow__", "__rpow__"]: + if cast(StringDtype, tm.get_dtype(obj)).storage in [ + "pyarrow", + "pyarrow_numpy", + ]: + return NotImplementedError + return TypeError + elif op_name in ["__mul__", "__rmul__"]: + # Can only multiply strings by integers + return TypeError + elif op_name in [ + "__truediv__", + "__rtruediv__", + "__floordiv__", + "__rfloordiv__", + "__sub__", + "__rsub__", + ]: + if cast(StringDtype, tm.get_dtype(obj)).storage in [ + "pyarrow", + "pyarrow_numpy", + ]: + import pyarrow as pa + + # TODO: better to re-raise as TypeError? + return pa.ArrowNotImplementedError + return TypeError + + return None -class TestReduce(base.BaseReduceTests): def _supports_reduction(self, ser: pd.Series, op_name: str) -> bool: return ( op_name in ["min", "max"] @@ -163,42 +197,22 @@ def _supports_reduction(self, ser: pd.Series, op_name: str) -> bool: and op_name in ("any", "all") ) - -class TestMethods(base.BaseMethodsTests): - pass - - -class TestCasting(base.BaseCastingTests): - pass - - -class TestComparisonOps(base.BaseComparisonOpsTests): def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result): - dtype = tm.get_dtype(obj) - # error: Item "dtype[Any]" of "dtype[Any] | ExtensionDtype" has no - # attribute "storage" - if dtype.storage == "pyarrow": # type: ignore[union-attr] - cast_to = "boolean[pyarrow]" - elif dtype.storage == "pyarrow_numpy": # type: ignore[union-attr] + dtype = cast(StringDtype, tm.get_dtype(obj)) + if op_name in ["__add__", "__radd__"]: + cast_to = dtype + elif dtype.storage == "pyarrow": + cast_to = "boolean[pyarrow]" # type: ignore[assignment] + elif dtype.storage == "pyarrow_numpy": cast_to = np.bool_ # type: ignore[assignment] else: - cast_to = "boolean" + cast_to = "boolean" # type: ignore[assignment] return pointwise_result.astype(cast_to) def test_compare_scalar(self, data, comparison_op): ser = pd.Series(data) self._compare_other(ser, data, comparison_op, "abc") - -class TestParsing(base.BaseParsingTests): - pass - - -class TestPrinting(base.BasePrintingTests): - pass - - -class TestGroupBy(base.BaseGroupbyTests): @pytest.mark.filterwarnings("ignore:Falling back:pandas.errors.PerformanceWarning") def test_groupby_extension_apply(self, data_for_grouping, groupby_apply_op): super().test_groupby_extension_apply(data_for_grouping, groupby_apply_op)