diff --git a/pandas/core/groupby/base.py b/pandas/core/groupby/base.py
index a443597347283..a5d7aced24c75 100644
--- a/pandas/core/groupby/base.py
+++ b/pandas/core/groupby/base.py
@@ -83,6 +83,7 @@ class OutputKey:
 groupby_other_methods = frozenset(
     [
         "agg",
+        "agg_index",
         "aggregate",
         "apply",
         "boxplot",
diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py
index d86062e35a20c..74f934cb89330 100644
--- a/pandas/core/groupby/groupby.py
+++ b/pandas/core/groupby/groupby.py
@@ -1799,6 +1799,37 @@ def _aggregate_with_numba(self, func, *args, engine_kwargs=None, **kwargs):
             res.index = default_index(len(res))
         return res
 
+    @final
+    @property
+    def agg_index(self) -> Index:
+        """Index of an aggregation result.
+
+        Produces the index that will be on the result of an aggregation. Always
+        returns the index as if ``as_index=True``.
+
+        Returns
+        -------
+        Index
+            The grouper's result index, independent of ``as_index``.
+
+        See Also
+        --------
+        DataFrameGroupBy.aggregate : Aggregate using one or more operations.
+        SeriesGroupBy.aggregate : Aggregate using one or more operations.
+
+        Examples
+        --------
+        >>> df = pd.DataFrame({"a": [1, 1, 2], "b": [3, 4, 5]})
+        >>> df.groupby("a").agg_index
+        Index([1, 2], dtype='int64', name='a')
+
+        The result is the same with ``as_index=False``:
+
+        >>> df.groupby("a", as_index=False).agg_index
+        Index([1, 2], dtype='int64', name='a')
+        """
+        return self._grouper.result_index
+
     # -----------------------------------------------------------------
     # apply/agg/transform
 
diff --git a/pandas/tests/groupby/test_api.py b/pandas/tests/groupby/test_api.py
index 5c5982954de2f..db9f8fd0e646c 100644
--- a/pandas/tests/groupby/test_api.py
+++ b/pandas/tests/groupby/test_api.py
@@ -32,6 +32,7 @@ def test_tab_completion(multiindex_dataframe_random_data):
         "B",
         "C",
         "agg",
+        "agg_index",
         "aggregate",
         "apply",
         "boxplot",
diff --git a/pandas/tests/groupby/test_groupby.py b/pandas/tests/groupby/test_groupby.py
index 14866ce2065ed..4281158f09110 100644
--- a/pandas/tests/groupby/test_groupby.py
+++ b/pandas/tests/groupby/test_groupby.py
@@ -3055,3 +3055,61 @@ def test_decimal_na_sort(test_series):
     result = gb._grouper.result_index
     expected = Index([Decimal(1), None], name="key")
     tm.assert_index_equal(result, expected)
+
+
+@pytest.mark.parametrize("dtype", ["float32", "Int64", "int16[pyarrow]"])
+@pytest.mark.parametrize("keys", [["a"], ["a", "b"]])
+def test_agg_index(observed, sort, dropna, dtype, keys):
+    # GH#???
+    na_value = np.nan if dtype == "float32" else pd.NA
+    df = DataFrame(
+        {
+            "a": Series([2, na_value, 2, 1, na_value], dtype=dtype),
+            "b": Series([3, na_value, 3, na_value, 5], dtype=dtype),
+        }
+    )
+    gb = df.groupby(keys, observed=observed, sort=sort, dropna=dropna)
+    result = gb.agg_index
+    data = df[keys].drop_duplicates()
+    if sort:
+        data = data.sort_values(keys)
+    if dropna:
+        data = data[~data.isna().any(axis=1)]
+    expected = data.set_index(keys).index
+    tm.assert_index_equal(result, expected)
+
+
+@pytest.mark.parametrize(
+    "keys",
+    [
+        ["a"],
+        pytest.param(
+            ["a", "b"], marks=pytest.mark.xfail(reason="Does not include unobserved")
+        ),
+    ],
+)
+def test_agg_index_categorical(sort, dropna, keys):
+    # GH#???
+    df = DataFrame(
+        {
+            "a": Categorical([2, 0, 2, 1, 0], categories=[1, 2, 3]),
+            "b": Categorical([3, 0, 3, 0, 2], categories=[1, 2, 3]),
+        }
+    )
+    gb = df.groupby(keys, observed=False, sort=sort, dropna=dropna)
+    result = gb.agg_index
+    if keys == ["a"]:
+        data = DataFrame({"a": Categorical([2, 0, 1, 3], categories=[1, 2, 3])})
+    else:
+        data = DataFrame(
+            {
+                "a": Categorical(np.repeat([2, 0, 1, 3], 4), categories=[1, 2, 3]),
+                "b": Categorical(4 * [3, 0, 2, 1], categories=[1, 2, 3]),
+            }
+        )
+    if sort:
+        data = data.sort_values(keys)
+    if dropna:
+        data = data[~data.isna().any(axis=1)]
+    expected = data.set_index(keys).index
+    tm.assert_index_equal(result, expected)