24
24
TypedDict ,
25
25
TypeVar ,
26
26
Union ,
27
+ cast ,
27
28
overload ,
28
29
)
29
30
@@ -843,7 +844,7 @@ def offset_labels(labels: np.ndarray, ngroups: int) -> tuple[np.ndarray, int]:
843
844
return offset , size
844
845
845
846
846
- def _factorize_single (by , expect , * , sort : bool , reindex : bool ):
847
+ def _factorize_single (by , expect , * , sort : bool , reindex : bool ) -> tuple [ pd . Index , np . ndarray ] :
847
848
flat = by .reshape (- 1 )
848
849
if isinstance (expect , pd .RangeIndex ):
849
850
# idx is a view of the original `by` array
@@ -852,7 +853,7 @@ def _factorize_single(by, expect, *, sort: bool, reindex: bool):
852
853
# this is important in shared-memory parallelism with dask
853
854
# TODO: figure out how to avoid this
854
855
idx = flat .copy ()
855
- found_groups = np . array ( expect )
856
+ found_groups = cast ( pd . Index , expect )
856
857
# TODO: fix by using masked integers
857
858
idx [idx > expect [- 1 ]] = - 1
858
859
@@ -875,7 +876,7 @@ def _factorize_single(by, expect, *, sort: bool, reindex: bool):
875
876
idx [~ within_bins ] = - 1
876
877
else :
877
878
idx = np .zeros_like (flat , dtype = np .intp ) - 1
878
- found_groups = np . array ( expect )
879
+ found_groups = cast ( pd . Index , expect )
879
880
else :
880
881
if expect is not None and reindex :
881
882
sorter = np .argsort (expect )
@@ -890,7 +891,7 @@ def _factorize_single(by, expect, *, sort: bool, reindex: bool):
890
891
idx [mask ] = - 1
891
892
else :
892
893
idx , groups = pd .factorize (flat , sort = sort )
893
- found_groups = np . array ( groups )
894
+ found_groups = cast ( pd . Index , groups )
894
895
895
896
return (found_groups , idx .reshape (by .shape ))
896
897
@@ -913,7 +914,7 @@ def factorize_(
913
914
expected_groups : T_ExpectIndexOptTuple | None = None ,
914
915
reindex : bool = False ,
915
916
sort : bool = True ,
916
- ) -> tuple [np .ndarray , tuple [np . ndarray , ...], tuple [int , ...], int , int , None ]: ...
917
+ ) -> tuple [np .ndarray , tuple [pd . Index , ...], tuple [int , ...], int , int , None ]: ...
917
918
918
919
919
920
@overload
@@ -925,7 +926,7 @@ def factorize_(
925
926
reindex : bool = False ,
926
927
sort : bool = True ,
927
928
fastpath : Literal [False ] = False ,
928
- ) -> tuple [np .ndarray , tuple [np . ndarray , ...], tuple [int , ...], int , int , FactorProps ]: ...
929
+ ) -> tuple [np .ndarray , tuple [pd . Index , ...], tuple [int , ...], int , int , FactorProps ]: ...
929
930
930
931
931
932
@overload
@@ -937,7 +938,7 @@ def factorize_(
937
938
reindex : bool = False ,
938
939
sort : bool = True ,
939
940
fastpath : bool = False ,
940
- ) -> tuple [np .ndarray , tuple [np . ndarray , ...], tuple [int , ...], int , int , FactorProps | None ]: ...
941
+ ) -> tuple [np .ndarray , tuple [pd . Index , ...], tuple [int , ...], int , int , FactorProps | None ]: ...
941
942
942
943
943
944
def factorize_ (
@@ -948,7 +949,7 @@ def factorize_(
948
949
reindex : bool = False ,
949
950
sort : bool = True ,
950
951
fastpath : bool = False ,
951
- ) -> tuple [np .ndarray , tuple [np . ndarray , ...], tuple [int , ...], int , int , FactorProps | None ]:
952
+ ) -> tuple [np .ndarray , tuple [pd . Index , ...], tuple [int , ...], int , int , FactorProps | None ]:
952
953
"""
953
954
Returns an array of integer codes for groups (and associated data)
954
955
by wrapping pd.cut and pd.factorize (depending on isbin).
@@ -971,7 +972,7 @@ def factorize_(
971
972
_factorize_single (groupvar , expect , sort = sort , reindex = reindex )
972
973
for groupvar , expect in zip (by , expected_groups )
973
974
)
974
- found_groups = [ r [0 ] for r in results ]
975
+ found_groups = tuple ( r [0 ] for r in results )
975
976
factorized = [r [1 ] for r in results ]
976
977
977
978
grp_shape = tuple (len (grp ) for grp in found_groups )
@@ -982,7 +983,7 @@ def factorize_(
982
983
(group_idx ,) = factorized
983
984
984
985
if fastpath :
985
- return group_idx , tuple ( found_groups ) , grp_shape , ngroups , ngroups , None
986
+ return group_idx , found_groups , grp_shape , ngroups , ngroups , None
986
987
987
988
if len (axes ) == 1 and by [0 ].ndim > 1 :
988
989
# Not reducing along all dimensions of by
@@ -1178,7 +1179,7 @@ def chunk_reduce(
1178
1179
results : IntermediateDict = {"groups" : [], "intermediates" : []}
1179
1180
if reindex and expected_groups is not None :
1180
1181
# TODO: what happens with binning here?
1181
- results ["groups" ] = expected_groups . to_numpy ()
1182
+ results ["groups" ] = expected_groups
1182
1183
else :
1183
1184
if empty :
1184
1185
results ["groups" ] = np .array ([np .nan ])
@@ -1307,7 +1308,7 @@ def _finalize_results(
1307
1308
fill_value = fill_value ,
1308
1309
array_type = reindex .array_type ,
1309
1310
)
1310
- finalized ["groups" ] = expected_groups . to_numpy ()
1311
+ finalized ["groups" ] = expected_groups
1311
1312
else :
1312
1313
finalized ["groups" ] = squeezed ["groups" ]
1313
1314
@@ -2272,7 +2273,7 @@ def _factorize_multiple(
2272
2273
expected_groups : T_ExpectIndexOptTuple ,
2273
2274
any_by_dask : bool ,
2274
2275
sort : bool = True ,
2275
- ) -> tuple [tuple [np .ndarray ], tuple [np . ndarray , ...], tuple [int , ...]]:
2276
+ ) -> tuple [tuple [np .ndarray ], tuple [pd . Index , ...], tuple [int , ...]]:
2276
2277
kwargs : FactorizeKwargs = dict (
2277
2278
axes = (), # always (), we offset later if necessary.
2278
2279
fastpath = True ,
@@ -2293,7 +2294,7 @@ def _factorize_multiple(
2293
2294
raise ValueError ("Please provide expected_groups when grouping by a dask array." )
2294
2295
2295
2296
found_groups = tuple (
2296
- pd .unique (by_ .reshape (- 1 )) if expect is None else expect . to_numpy ()
2297
+ pd .Index ( pd . unique (by_ .reshape (- 1 ))) if expect is None else expect
2297
2298
for by_ , expect in zip (by , expected_groups )
2298
2299
)
2299
2300
grp_shape = tuple (map (len , found_groups ))
@@ -2883,6 +2884,9 @@ def groupby_reduce(
2883
2884
result = asdelta + offset
2884
2885
result [nanmask ] = np .timedelta64 ("NaT" )
2885
2886
2887
+ groups = map (
2888
+ lambda g : g .to_numpy () if isinstance (g , pd .Index ) and not isinstance (g , pd .RangeIndex ) else g , groups
2889
+ )
2886
2890
return (result , * groups )
2887
2891
2888
2892
0 commit comments