Skip to content

Commit 64f206f

Browse files
Backport PR #33089 on branch 1.0.x (BUG: Don't cast nullable Boolean to float in groupby) (#34023)
Co-authored-by: Daniel Saxton <[email protected]>
1 parent 84f5b4e commit 64f206f

File tree

4 files changed

+39
-12
lines changed

4 files changed

+39
-12
lines changed

doc/source/whatsnew/v1.0.4.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,15 @@ Fixed regressions
2121
- Fix performance regression in ``memory_usage(deep=True)`` for object dtype (:issue:`33012`)
2222
- Bug where :meth:`Categorical.replace` would replace with ``NaN`` whenever the value to be replaced and the replacement value were equal (:issue:`33288`)
2323
- Bug where an ordered :class:`Categorical` containing only ``NaN`` values would raise rather than returning ``NaN`` when taking the minimum or maximum (:issue:`33450`)
24+
- Bug in :meth:`DataFrameGroupBy.agg` with dictionary input losing ``ExtensionArray`` dtypes (:issue:`32194`)
2425
- Fix to preserve the ability to index with the "nearest" method with xarray's CFTimeIndex, an :class:`Index` subclass (`pydata/xarray#3751 <https://github.com/pydata/xarray/issues/3751>`_, :issue:`32905`).
2526
-
2627

2728
.. _whatsnew_104.bug_fixes:
2829

2930
Bug fixes
3031
~~~~~~~~~
32+
- Bug in :meth:`SeriesGroupBy.first`, :meth:`SeriesGroupBy.last`, :meth:`SeriesGroupBy.min`, and :meth:`SeriesGroupBy.max` returning floats when applied to nullable Booleans (:issue:`33071`)
3133
- Bug in :meth:`Rolling.min` and :meth:`Rolling.max` where memory usage kept growing after multiple calls when using a fixed window (:issue:`30726`)
3234
-
3335

pandas/core/groupby/groupby.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ class providing the base-class of operations.
4242
from pandas.core.dtypes.cast import maybe_downcast_to_dtype
4343
from pandas.core.dtypes.common import (
4444
ensure_float,
45+
is_categorical_dtype,
4546
is_datetime64_dtype,
4647
is_extension_array_dtype,
4748
is_integer_dtype,
@@ -807,15 +808,15 @@ def _try_cast(self, result, obj, numeric_only: bool = False):
807808
dtype = obj.dtype
808809

809810
if not is_scalar(result):
810-
if is_extension_array_dtype(dtype) and dtype.kind != "M":
811-
# The function can return something of any type, so check
812-
# if the type is compatible with the calling EA.
813-
# datetime64tz is handled correctly in agg_series,
814-
# so is excluded here.
815-
816-
if len(result) and isinstance(result[0], dtype.type):
817-
cls = dtype.construct_array_type()
818-
result = try_cast_to_ea(cls, result, dtype=dtype)
811+
if (
812+
is_extension_array_dtype(dtype)
813+
and not is_categorical_dtype(dtype)
814+
and dtype.kind != "M"
815+
):
816+
# We have to special case categorical so as not to upcast
817+
# things like counts back to categorical
818+
cls = dtype.construct_array_type()
819+
result = try_cast_to_ea(cls, result, dtype=dtype)
819820

820821
elif numeric_only and is_numeric_dtype(dtype) or not numeric_only:
821822
result = maybe_downcast_to_dtype(result, dtype)

pandas/tests/groupby/test_nth.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,32 @@ def test_first_last_tz_multi_column(method, ts, alpha):
395395
tm.assert_frame_equal(result, expected)
396396

397397

398+
@pytest.mark.parametrize(
399+
"values",
400+
[
401+
pd.array([True, False], dtype="boolean"),
402+
pd.array([1, 2], dtype="Int64"),
403+
pd.to_datetime(["2020-01-01", "2020-02-01"]),
404+
pd.to_timedelta([1, 2], unit="D"),
405+
],
406+
)
407+
@pytest.mark.parametrize("function", ["first", "last", "min", "max"])
408+
def test_first_last_extension_array_keeps_dtype(values, function):
409+
# https://github.com/pandas-dev/pandas/issues/33071
410+
# https://github.com/pandas-dev/pandas/issues/32194
411+
df = DataFrame({"a": [1, 2], "b": values})
412+
grouped = df.groupby("a")
413+
idx = Index([1, 2], name="a")
414+
expected_series = Series(values, name="b", index=idx)
415+
expected_frame = DataFrame({"b": values}, index=idx)
416+
417+
result_series = getattr(grouped["b"], function)()
418+
tm.assert_series_equal(result_series, expected_series)
419+
420+
result_frame = grouped.agg({"b": function})
421+
tm.assert_frame_equal(result_frame, expected_frame)
422+
423+
398424
def test_nth_multi_index_as_expected():
399425
# PR 9090, related to issue 8979
400426
# test nth on MultiIndex

pandas/tests/resample/test_datetime_index.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,7 @@ def test_resample_integerarray():
122122

123123
result = ts.resample("3T").mean()
124124
expected = Series(
125-
[1, 4, 7],
126-
index=pd.date_range("1/1/2000", periods=3, freq="3T"),
127-
dtype="float64",
125+
[1, 4, 7], index=pd.date_range("1/1/2000", periods=3, freq="3T"), dtype="Int64",
128126
)
129127
tm.assert_series_equal(result, expected)
130128

0 commit comments

Comments
 (0)