-
-
Notifications
You must be signed in to change notification settings - Fork 18.5k
ENH: Add numba engine for rolling apply #30151
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
3b9bff8
9a302bf
0e9a600
36a77ed
dbb2a9b
f0e9a4d
1250aee
4e7fd1a
cb976cf
45420bb
17851cf
20767ca
9619f8d
66fa69c
b8908ea
135f2ad
34a5687
6da8199
123f77e
54e74d1
04d3530
4bbf587
f849bc7
0c30e48
c4c952e
8645976
987c916
b775684
2e04e60
9b20ff5
0c14033
c7106dc
1640085
2846faf
5a645c0
6bac000
6f1c73f
a890337
0a9071c
9d8d40b
84c3491
a429206
5826ad9
cf7571b
4bc9787
18eed60
f715b55
6a765bf
af3fe50
eb7b5e1
f7dfcf4
a42a960
d019830
29d145f
248149c
a3da51e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,6 +27,7 @@ | |
"xlrd": "1.1.0", | ||
"xlwt": "1.2.0", | ||
"xlsxwriter": "0.9.8", | ||
"numba": "0.46.0", | ||
} | ||
|
||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
import types | ||
from typing import Any, Callable, Dict, Optional, Tuple | ||
|
||
import numpy as np | ||
|
||
from pandas._typing import Scalar | ||
from pandas.compat._optional import import_optional_dependency | ||
|
||
|
||
def make_rolling_apply( | ||
func: Callable[..., Scalar], | ||
args: Tuple, | ||
nogil: bool, | ||
parallel: bool, | ||
nopython: bool, | ||
): | ||
""" | ||
Creates a JITted rolling apply function with a JITted version of | ||
the user's function. | ||
|
||
Parameters | ||
---------- | ||
func : function | ||
function to be applied to each window and will be JITed | ||
args : tuple | ||
*args to be passed into the function | ||
nogil : bool | ||
nogil parameter from engine_kwargs for numba.jit | ||
parallel : bool | ||
parallel parameter from engine_kwargs for numba.jit | ||
nopython : bool | ||
nopython parameter from engine_kwargs for numba.jit | ||
|
||
Returns | ||
------- | ||
Numba function | ||
""" | ||
numba = import_optional_dependency("numba") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can you add a doc-string that says what this function does (the parameters are already documented elsewhere, maybe just mention that) |
||
|
||
if parallel: | ||
loop_range = numba.prange | ||
else: | ||
loop_range = range | ||
|
||
if isinstance(func, numba.targets.registry.CPUDispatcher): | ||
# Don't jit a user passed jitted function | ||
numba_func = func | ||
else: | ||
|
||
@numba.generated_jit(nopython=nopython, nogil=nogil, parallel=parallel) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @stuartarchibald sorry for the ping, but I see that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @mroeschke no problem, I can try and help with this. I think it needs to look a bit like this (for reference, this is untested, I am just guessing from the context! Also, the import types
import numpy as np
from numba.extending import overload, is_jitted
from numba import njit
import numba
# this provides a local definition to overload
def overload_target(window, *_args):
# If JIT is disabled, this function will run, so write the implementation here!
pass
nopython = True
nogil = True
parallel = False
# pretend this is an arg to `make_rolling_apply`
def func(window, *args):
return window * 2 + args[0]
@overload(overload_target, jit_options={'nopython':nopython, 'nogil':nogil,
'parallel':parallel})
def ol_overload_target(window, *_args):
# This function "overloads" `overload_target`, whenever the Numba compiler
# "sees" `overload_target` it will use this function.
# Using `is_jitted` to avoid `isinstance` on
# `numba.targets.registry.CPUDispatcher` as that may be considered an
# internal Numba detail.
if is_jitted(func):
# it's already JIT compiled so just reference it
overload_target_impl = func
elif getattr(np, func.__name__, False) is func or isinstance(
func, types.BuiltinFunctionType
):
# it's a NumPy function or builtin so just reference it
overload_target_impl = func
else:
# it's a Python function, so register it as JIT compilable and reference
# that
overload_target_impl = numba.jit(func, nopython=nopython, nogil=nogil)
# This is the Numba implementation of the overload, it will just be JIT
# compiled whenever the compiler "sees" a reference to "overload_target" in
# code it is compiling.
def impl(window, *_args):
return overload_target_impl(window, *_args)
return impl
# demo
@njit
def roll_apply(window, *_args):
return overload_target(window, *_args)
print(roll_apply(np.arange(10.), 1.23))
Hope this helps? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for the reply! We had a PR recently that refactored this to use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No problem! I just took a look at the patch above, I think it'd work but think it might lose some of the dispatch ability offered by There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Great thanks for the context! Yeah this function should expect a custom UDF so thanks for the confirmation There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Glad to get this resolved, thanks for confirming too! It sounds like the replacement above is appropriate. If there are any more issues/queries feel free to open issues on the Numba issue tracker (or ping here!). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @stuartarchibald I'm running into a rolling apply issue with pandas 2.1.1 and numba 0.58 that might be related. Discussion is here: |
||
def numba_func(window, *_args): | ||
if getattr(np, func.__name__, False) is func or isinstance( | ||
func, types.BuiltinFunctionType | ||
): | ||
jreback marked this conversation as resolved.
Show resolved
Hide resolved
|
||
jf = func | ||
else: | ||
jf = numba.jit(func, nopython=nopython, nogil=nogil) | ||
|
||
def impl(window, *_args): | ||
return jf(window, *_args) | ||
|
||
return impl | ||
|
||
@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) | ||
def roll_apply( | ||
values: np.ndarray, begin: np.ndarray, end: np.ndarray, minimum_periods: int, | ||
) -> np.ndarray: | ||
result = np.empty(len(begin)) | ||
for i in loop_range(len(result)): | ||
start = begin[i] | ||
stop = end[i] | ||
window = values[start:stop] | ||
count_nan = np.sum(np.isnan(window)) | ||
if len(window) - count_nan >= minimum_periods: | ||
result[i] = numba_func(window, *args) | ||
else: | ||
result[i] = np.nan | ||
return result | ||
|
||
return roll_apply | ||
|
||
|
||
def generate_numba_apply_func( | ||
args: Tuple, | ||
kwargs: Dict[str, Any], | ||
func: Callable[..., Scalar], | ||
engine_kwargs: Optional[Dict[str, bool]], | ||
): | ||
""" | ||
Generate a numba jitted apply function specified by values from engine_kwargs. | ||
|
||
1. jit the user's function | ||
2. Return a rolling apply function with the jitted function inline | ||
|
||
Configurations specified in engine_kwargs apply to both the user's | ||
function _AND_ the rolling apply function. | ||
jreback marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
Parameters | ||
---------- | ||
args : tuple | ||
*args to be passed into the function | ||
kwargs : dict | ||
**kwargs to be passed into the function | ||
func : function | ||
function to be applied to each window and will be JITed | ||
engine_kwargs : dict | ||
dictionary of arguments to be passed into numba.jit | ||
|
||
Returns | ||
------- | ||
Numba function | ||
""" | ||
|
||
if engine_kwargs is None: | ||
engine_kwargs = {} | ||
|
||
nopython = engine_kwargs.get("nopython", True) | ||
nogil = engine_kwargs.get("nogil", False) | ||
parallel = engine_kwargs.get("parallel", False) | ||
|
||
if kwargs and nopython: | ||
raise ValueError( | ||
"numba does not support kwargs with nopython=True: " | ||
"https://github.com/numba/numba/issues/2916" | ||
) | ||
|
||
return make_rolling_apply(func, args, nogil, parallel, nopython) |
Uh oh!
There was an error while loading. Please reload this page.