From df64bcd57e014ea0bda5be18278ad887c8763f2e Mon Sep 17 00:00:00 2001 From: Brock Date: Tue, 15 Dec 2020 18:33:23 -0800 Subject: [PATCH 1/2] API: CategoricalDtype.__eq__ with categories=None stricter --- pandas/core/dtypes/dtypes.py | 10 ++++------ pandas/tests/dtypes/cast/test_infer_dtype.py | 5 +++-- pandas/tests/dtypes/test_common.py | 1 - pandas/tests/dtypes/test_dtypes.py | 21 ++++++++++++++++++-- pandas/tests/reshape/merge/test_merge.py | 14 ++++++++++--- 5 files changed, 37 insertions(+), 14 deletions(-) diff --git a/pandas/core/dtypes/dtypes.py b/pandas/core/dtypes/dtypes.py index 0de8a07abbec3..1a4a836e96daa 100644 --- a/pandas/core/dtypes/dtypes.py +++ b/pandas/core/dtypes/dtypes.py @@ -354,12 +354,10 @@ def __eq__(self, other: Any) -> bool: elif not (hasattr(other, "ordered") and hasattr(other, "categories")): return False elif self.categories is None or other.categories is None: - # We're forced into a suboptimal corner thanks to math and - # backwards compatibility. We require that `CDT(...) == 'category'` - # for all CDTs **including** `CDT(None, ...)`. Therefore, *all* - # CDT(., .) = CDT(None, False) and *all* - # CDT(., .) = CDT(None, True). - return True + # For non-fully-initialized dtypes, these are only equal to + # - the string "categorical" (handled above) + # - other CategoricalDtype with categories=None + return self.categories is other.categories elif self.ordered or other.ordered: # At least one has ordered=True; equal if both have ordered=True # and the same values for categories in the same order. diff --git a/pandas/tests/dtypes/cast/test_infer_dtype.py b/pandas/tests/dtypes/cast/test_infer_dtype.py index 65da8985843f9..c21dd90f7c72b 100644 --- a/pandas/tests/dtypes/cast/test_infer_dtype.py +++ b/pandas/tests/dtypes/cast/test_infer_dtype.py @@ -8,6 +8,7 @@ from pandas import ( Categorical, + CategoricalDtype, Interval, Period, Series, @@ -149,8 +150,8 @@ def test_infer_dtype_from_scalar_errors(): (np.array([[1.0, 2.0]]), np.float_, False), (Categorical(list("aabc")), np.object_, False), (Categorical([1, 2, 3]), np.int64, False), - (Categorical(list("aabc")), "category", True), - (Categorical([1, 2, 3]), "category", True), + (Categorical(list("aabc")), CategoricalDtype(categories=["a", "b", "c"]), True), + (Categorical([1, 2, 3]), CategoricalDtype(categories=[1, 2, 3]), True), (Timestamp("20160101"), np.object_, False), (np.datetime64("2016-01-01"), np.dtype("=M8[D]"), False), (date_range("20160101", periods=3), np.dtype("=M8[ns]"), False), diff --git a/pandas/tests/dtypes/test_common.py b/pandas/tests/dtypes/test_common.py index 0d0601aa542b4..9e75ba0864e76 100644 --- a/pandas/tests/dtypes/test_common.py +++ b/pandas/tests/dtypes/test_common.py @@ -641,7 +641,6 @@ def test_is_complex_dtype(): (pd.CategoricalIndex(["a", "b"]).dtype, CategoricalDtype(["a", "b"])), (pd.CategoricalIndex(["a", "b"]), CategoricalDtype(["a", "b"])), (CategoricalDtype(), CategoricalDtype()), - (CategoricalDtype(["a", "b"]), CategoricalDtype()), (pd.DatetimeIndex([1, 2]), np.dtype("=M8[ns]")), (pd.DatetimeIndex([1, 2]).dtype, np.dtype("=M8[ns]")), (" Date: Wed, 16 Dec 2020 17:24:55 -0800 Subject: [PATCH 2/2] Test for is_dtype_equal matching dtype.__eq__ --- pandas/core/arrays/categorical.py | 2 ++ pandas/core/dtypes/common.py | 13 +++++++++++++ pandas/core/dtypes/dtypes.py | 2 +- pandas/core/indexes/category.py | 7 ++++++- pandas/tests/arrays/categorical/test_dtypes.py | 2 +- pandas/tests/dtypes/test_dtypes.py | 11 +++++++++++ 6 files changed, 34 insertions(+), 3 deletions(-) diff --git a/pandas/core/arrays/categorical.py b/pandas/core/arrays/categorical.py index 27110fe1f8439..5a418ec908d12 100644 --- a/pandas/core/arrays/categorical.py +++ b/pandas/core/arrays/categorical.py @@ -36,6 +36,7 @@ is_scalar, is_timedelta64_dtype, needs_i8_conversion, + pandas_dtype, ) from pandas.core.dtypes.dtypes import CategoricalDtype from pandas.core.dtypes.generic import ABCIndex, ABCSeries @@ -403,6 +404,7 @@ def astype(self, dtype: Dtype, copy: bool = True) -> ArrayLike: If copy is set to False and dtype is categorical, the original object is returned. """ + dtype = pandas_dtype(dtype) if self.dtype is dtype: result = self.copy() if copy else self diff --git a/pandas/core/dtypes/common.py b/pandas/core/dtypes/common.py index 081339583e3fd..5869b2cf22516 100644 --- a/pandas/core/dtypes/common.py +++ b/pandas/core/dtypes/common.py @@ -639,6 +639,19 @@ def is_dtype_equal(source, target) -> bool: >>> is_dtype_equal(DatetimeTZDtype(tz="UTC"), "datetime64") False """ + if isinstance(target, str): + if not isinstance(source, str): + # GH#38516 ensure we get the same behavior from + # is_dtype_equal(CDT, "category") and CDT == "category" + try: + src = get_dtype(source) + if isinstance(src, ExtensionDtype): + return src == target + except (TypeError, AttributeError): + return False + elif isinstance(source, str): + return is_dtype_equal(target, source) + try: source = get_dtype(source) target = get_dtype(target) diff --git a/pandas/core/dtypes/dtypes.py b/pandas/core/dtypes/dtypes.py index 1a4a836e96daa..75f3b511bc57d 100644 --- a/pandas/core/dtypes/dtypes.py +++ b/pandas/core/dtypes/dtypes.py @@ -355,7 +355,7 @@ def __eq__(self, other: Any) -> bool: return False elif self.categories is None or other.categories is None: # For non-fully-initialized dtypes, these are only equal to - # - the string "categorical" (handled above) + # - the string "category" (handled above) # - other CategoricalDtype with categories=None return self.categories is other.categories elif self.ordered or other.ordered: diff --git a/pandas/core/indexes/category.py b/pandas/core/indexes/category.py index e2a7752cf3f0d..35fd8af9cd36e 100644 --- a/pandas/core/indexes/category.py +++ b/pandas/core/indexes/category.py @@ -201,8 +201,13 @@ def __new__( if not isinstance(data, Categorical): data = Categorical(data, dtype=dtype) - elif isinstance(dtype, CategoricalDtype) and dtype != data.dtype: + elif ( + isinstance(dtype, CategoricalDtype) + and dtype != data.dtype + and dtype.categories is not None + ): # we want to silently ignore dtype='category' + # TODO: what if dtype.ordered is not None but dtype.categories is? data = data._set_dtype(dtype) data = data.copy() if copy else data diff --git a/pandas/tests/arrays/categorical/test_dtypes.py b/pandas/tests/arrays/categorical/test_dtypes.py index 12654388de904..a2192b2810596 100644 --- a/pandas/tests/arrays/categorical/test_dtypes.py +++ b/pandas/tests/arrays/categorical/test_dtypes.py @@ -127,7 +127,7 @@ def test_astype(self, ordered): expected = np.array(cat) tm.assert_numpy_array_equal(result, expected) - msg = r"Cannot cast object dtype to " + msg = r"Cannot cast object dtype to float64" with pytest.raises(ValueError, match=msg): cat.astype(float) diff --git a/pandas/tests/dtypes/test_dtypes.py b/pandas/tests/dtypes/test_dtypes.py index 1afe389e86668..8ba8562affb67 100644 --- a/pandas/tests/dtypes/test_dtypes.py +++ b/pandas/tests/dtypes/test_dtypes.py @@ -90,9 +90,20 @@ def test_hash_vs_equality(self, dtype): assert hash(dtype) == hash(dtype2) def test_equality(self, dtype): + assert dtype == "category" assert is_dtype_equal(dtype, "category") + assert "category" == dtype + assert is_dtype_equal("category", dtype) + + assert dtype == CategoricalDtype() assert is_dtype_equal(dtype, CategoricalDtype()) + assert CategoricalDtype() == dtype + assert is_dtype_equal(CategoricalDtype(), dtype) + + assert dtype != "foo" assert not is_dtype_equal(dtype, "foo") + assert "foo" != dtype + assert not is_dtype_equal("foo", dtype) def test_construction_from_string(self, dtype): result = CategoricalDtype.construct_from_string("category")