-
-
Notifications
You must be signed in to change notification settings - Fork 18.5k
Cythonized GroupBy Quantile #20405
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Cythonized GroupBy Quantile #20405
Changes from all commits
618ec99
74871d8
7b6ca68
31aff03
4a43815
eb18823
813da81
e152dd5
7a8fefb
b4938ba
3f7d0a9
d7aec3f
e712946
72cd30e
a3c4b11
ac96526
7d439d8
3047eed
70bf89a
02eb336
7c3c349
3b9c7c4
ad8b184
b846bc2
93b122c
09308d4
1a718f2
ff062bd
bdb5089
9b55fb5
31e66fc
41a734f
67e0f00
07b0c00
86aeb4a
86b9d8d
cfa1b45
00085d0
1f02532
3c64c1f
09695f5
68cfed9
4ce1448
5e840da
7969fb6
f9a8317
464a831
4b3f9be
b996e1d
cdd8985
64f46a3
4d88e8a
1cd93dd
9ae23c1
eb99f07
94d4892
0512f37
2370129
a018570
f41cd05
21691bb
082aea3
dc5877a
7496a9b
ec013bf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
cdef enum InterpolationEnumType: | ||
INTERPOLATION_LINEAR, | ||
INTERPOLATION_LOWER, | ||
INTERPOLATION_HIGHER, | ||
INTERPOLATION_NEAREST, | ||
INTERPOLATION_MIDPOINT |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,6 +29,8 @@ class providing the base-class of operations. | |
ensure_float, is_extension_array_dtype, is_numeric_dtype, is_scalar) | ||
from pandas.core.dtypes.missing import isna, notna | ||
|
||
from pandas.api.types import ( | ||
is_datetime64_dtype, is_integer_dtype, is_object_dtype) | ||
import pandas.core.algorithms as algorithms | ||
from pandas.core.base import ( | ||
DataError, GroupByError, PandasObject, SelectionMixin, SpecificationError) | ||
|
@@ -1024,15 +1026,17 @@ def _bool_agg(self, val_test, skipna): | |
""" | ||
|
||
def objs_to_bool(vals): | ||
try: | ||
vals = vals.astype(np.bool) | ||
except ValueError: # for objects | ||
# type: np.ndarray -> (np.ndarray, typing.Type) | ||
if is_object_dtype(vals): | ||
vals = np.array([bool(x) for x in vals]) | ||
else: | ||
vals = vals.astype(np.bool) | ||
|
||
return vals.view(np.uint8) | ||
return vals.view(np.uint8), np.bool | ||
|
||
def result_to_bool(result): | ||
return result.astype(np.bool, copy=False) | ||
def result_to_bool(result, inference): | ||
# type: (np.ndarray, typing.Type) -> np.ndarray | ||
return result.astype(inference, copy=False) | ||
|
||
return self._get_cythonized_result('group_any_all', self.grouper, | ||
aggregate=True, | ||
|
@@ -1688,6 +1692,75 @@ def nth(self, n, dropna=None): | |
|
||
return result | ||
|
||
def quantile(self, q=0.5, interpolation='linear'): | ||
""" | ||
Return group values at the given quantile, a la numpy.percentile. | ||
|
||
Parameters | ||
---------- | ||
q : float or array-like, default 0.5 (50% quantile) | ||
Value(s) between 0 and 1 providing the quantile(s) to compute. | ||
interpolation : {'linear', 'lower', 'higher', 'midpoint', 'nearest'} | ||
Method to use when the desired quantile falls between two points. | ||
|
||
Returns | ||
------- | ||
Series or DataFrame | ||
Return type determined by caller of GroupBy object. | ||
|
||
See Also | ||
-------- | ||
Series.quantile : Similar method for Series. | ||
DataFrame.quantile : Similar method for DataFrame. | ||
numpy.percentile : NumPy method to compute qth percentile. | ||
|
||
Examples | ||
-------- | ||
>>> df = pd.DataFrame([ | ||
... ['a', 1], ['a', 2], ['a', 3], | ||
... ['b', 1], ['b', 3], ['b', 5] | ||
... ], columns=['key', 'val']) | ||
>>> df.groupby('key').quantile() | ||
val | ||
key | ||
a 2.0 | ||
b 3.0 | ||
""" | ||
|
||
def pre_processor(vals): | ||
# type: np.ndarray -> (np.ndarray, Optional[typing.Type]) | ||
if is_object_dtype(vals): | ||
raise TypeError("'quantile' cannot be performed against " | ||
"'object' dtypes!") | ||
|
||
inference = None | ||
if is_integer_dtype(vals): | ||
inference = np.int64 | ||
elif is_datetime64_dtype(vals): | ||
inference = 'datetime64[ns]' | ||
vals = vals.astype(np.float) | ||
|
||
return vals, inference | ||
|
||
def post_processor(vals, inference): | ||
# type: (np.ndarray, Optional[typing.Type]) -> np.ndarray | ||
if inference: | ||
# Check for edge case | ||
if not (is_integer_dtype(inference) and | ||
interpolation in {'linear', 'midpoint'}): | ||
vals = vals.astype(inference) | ||
|
||
return vals | ||
|
||
return self._get_cythonized_result('group_quantile', self.grouper, | ||
aggregate=True, | ||
needs_values=True, | ||
needs_mask=True, | ||
cython_dtype=np.float64, | ||
pre_processing=pre_processor, | ||
post_processing=post_processor, | ||
q=q, interpolation=interpolation) | ||
|
||
@Substitution(name='groupby') | ||
def ngroup(self, ascending=True): | ||
""" | ||
|
@@ -1924,10 +1997,16 @@ def _get_cythonized_result(self, how, grouper, aggregate=False, | |
Whether the result of the Cython operation is an index of | ||
values to be retrieved, instead of the actual values themselves | ||
pre_processing : function, default None | ||
Function to be applied to `values` prior to passing to Cython | ||
Raises if `needs_values` is False | ||
Function to be applied to `values` prior to passing to Cython. | ||
Function should return a tuple where the first element is the | ||
values to be passed to Cython and the second element is an optional | ||
type which the values should be converted to after being returned | ||
by the Cython operation. Raises if `needs_values` is False. | ||
post_processing : function, default None | ||
Function to be applied to result of Cython function | ||
Function to be applied to result of Cython function. Should accept | ||
WillAyd marked this conversation as resolved.
Show resolved
Hide resolved
|
||
an array of values as the first argument and type inferences as its | ||
second argument, i.e. the signature should be | ||
(ndarray, typing.Type). | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. not really what I meant, this is a (ndarray, dict) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Well on a second pass the dict was unnecessary so I ended up just returning the type. I suppose it could either be an actual type object or a string so this is an approximation of that for simplicity, though could update to include a string as well |
||
**kwargs : dict | ||
Extra arguments to be passed back to Cython funcs | ||
|
||
|
@@ -1963,10 +2042,12 @@ def _get_cythonized_result(self, how, grouper, aggregate=False, | |
|
||
result = np.zeros(result_sz, dtype=cython_dtype) | ||
func = partial(base_func, result, labels) | ||
inferences = None | ||
|
||
if needs_values: | ||
WillAyd marked this conversation as resolved.
Show resolved
Hide resolved
|
||
vals = obj.values | ||
if pre_processing: | ||
vals = pre_processing(vals) | ||
vals, inferences = pre_processing(vals) | ||
func = partial(func, vals) | ||
|
||
if needs_mask: | ||
|
@@ -1982,7 +2063,7 @@ def _get_cythonized_result(self, how, grouper, aggregate=False, | |
result = algorithms.take_nd(obj.values, result) | ||
|
||
if post_processing: | ||
result = post_processing(result) | ||
result = post_processing(result, inferences) | ||
WillAyd marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
output[name] = result | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1060,6 +1060,55 @@ def test_size(df): | |
tm.assert_series_equal(df.groupby('A').size(), out) | ||
|
||
|
||
# quantile | ||
# -------------------------------- | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. side note this file is getting pretty big, maybe should split it up a bit (later) |
||
@pytest.mark.parametrize("interpolation", [ | ||
"linear", "lower", "higher", "nearest", "midpoint"]) | ||
@pytest.mark.parametrize("a_vals,b_vals", [ | ||
# Ints | ||
([1, 2, 3, 4, 5], [5, 4, 3, 2, 1]), | ||
([1, 2, 3, 4], [4, 3, 2, 1]), | ||
([1, 2, 3, 4, 5], [4, 3, 2, 1]), | ||
# Floats | ||
([1., 2., 3., 4., 5.], [5., 4., 3., 2., 1.]), | ||
# Missing data | ||
([1., np.nan, 3., np.nan, 5.], [5., np.nan, 3., np.nan, 1.]), | ||
([np.nan, 4., np.nan, 2., np.nan], [np.nan, 4., np.nan, 2., np.nan]), | ||
# Timestamps | ||
([x for x in pd.date_range('1/1/18', freq='D', periods=5)], | ||
[x for x in pd.date_range('1/1/18', freq='D', periods=5)][::-1]), | ||
# All NA | ||
([np.nan] * 5, [np.nan] * 5), | ||
]) | ||
@pytest.mark.parametrize('q', [0, .25, .5, .75, 1]) | ||
def test_quantile(interpolation, a_vals, b_vals, q): | ||
if interpolation == 'nearest' and q == 0.5 and b_vals == [4, 3, 2, 1]: | ||
pytest.skip("Unclear numpy expectation for nearest result with " | ||
"equidistant data") | ||
|
||
a_expected = pd.Series(a_vals).quantile(q, interpolation=interpolation) | ||
b_expected = pd.Series(b_vals).quantile(q, interpolation=interpolation) | ||
|
||
df = DataFrame({ | ||
'key': ['a'] * len(a_vals) + ['b'] * len(b_vals), | ||
'val': a_vals + b_vals}) | ||
|
||
expected = DataFrame([a_expected, b_expected], columns=['val'], | ||
index=Index(['a', 'b'], name='key')) | ||
result = df.groupby('key').quantile(q, interpolation=interpolation) | ||
|
||
tm.assert_frame_equal(result, expected) | ||
|
||
|
||
def test_quantile_raises(): | ||
df = pd.DataFrame([ | ||
['foo', 'a'], ['foo', 'b'], ['foo', 'c']], columns=['key', 'val']) | ||
|
||
with pytest.raises(TypeError, match="cannot be performed against " | ||
"'object' dtypes"): | ||
df.groupby('key').quantile() | ||
|
||
|
||
# pipe | ||
# -------------------------------- | ||
|
||
|
Uh oh!
There was an error while loading. Please reload this page.