Feature request
It would be really nice if we could JIT-compile user-defined functions passed to `reduce`, so that they run more efficiently. A toy example:
import numpy as np
from nested_pandas.datasets import generate_data
from numba import njit


@njit
def reduce(fn, offsets, t, flux):
    # Apply the JIT-compiled `fn` to each nested row; row i spans
    # the flat arrays from offsets[i] to offsets[i + 1]
    results = np.empty(offsets.size - 1, dtype=t.dtype)
    for i in range(results.size):
        start, end = offsets[i], offsets[i + 1]
        results[i] = fn(t[start:end], flux[start:end])
    return results


@njit
def max_slope(t, flux):
    # Assuming t is sorted
    slope = np.diff(flux) / np.diff(t)
    return np.max(slope)
nf = generate_data(10_000, 100)
ns = nf.nested

%timeit nf.reduce(max_slope, 'nested.t', 'nested.flux')
# 36.8 ms ± 286 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)

%%timeit
reduce(
    max_slope,
    np.asarray(ns.array.list_offsets),
    np.asarray(ns.nest['t']),
    np.asarray(ns.nest['flux']),
)
# 7.38 ms ± 126 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
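For this particular function the per-row loop can also be removed entirely with NumPy's `ufunc.reduceat`, which may serve as a correctness reference when testing a JIT path. This is only a sketch: `max_slope_reduceat` is a hypothetical helper (not part of nested-pandas), it takes the same flat offsets/`t`/`flux` arrays as the `reduce` call above, and it assumes `t` is sorted within each row and that every row has at least two points.

```python
import numpy as np


def max_slope_reduceat(offsets, t, flux):
    """Hypothetical helper: per-row maximum slope via np.maximum.reduceat.

    Assumes t is sorted within each row and every row has >= 2 points.
    """
    offsets = np.asarray(offsets)
    # Slope between every pair of neighbouring points in the flat arrays;
    # slope[i] joins point i and point i + 1.
    with np.errstate(divide="ignore", invalid="ignore"):
        slope = np.diff(flux) / np.diff(t)
    # slope[next_start - 1] joins the last point of one row with the first
    # point of the next row, so exclude it from every maximum.
    slope[offsets[1:-1] - 1] = -np.inf
    # Maximum of slope[start:next_start] for each row
    return np.maximum.reduceat(slope, offsets[:-1])
```

Because it takes the maximum over the same within-row slopes, its output should match `reduce(max_slope, ...)` for inputs that satisfy these assumptions.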
With the JIT-compiled loop the same reduction runs about 5x faster (7.38 ms vs. 36.8 ms). Something similar is possible with JAX.
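Under `jax.jit`, array shapes must be static, so a per-row Python loop over ragged slices is awkward. One possible sketch, assuming the same flat offsets/`t`/`flux` arrays as above, computes all neighbouring slopes at once, masks the pairs that cross a row boundary, and reduces each row with `jax.ops.segment_max`. `max_slope_segments` is a hypothetical name, and this is a swapped-in vectorized technique rather than a JIT-compiled version of the generic `reduce` loop. Note that JAX computes in float32 unless 64-bit mode is enabled.

```python
import jax
import jax.numpy as jnp


@jax.jit
def max_slope_segments(offsets, t, flux):
    """Hypothetical sketch: per-row max slope, assuming t is sorted per row."""
    # Number of rows; static under jit because it comes from the array shape
    n_seg = offsets.shape[0] - 1
    # Row id of every flat element, e.g. offsets [0, 2, 5] -> [0, 0, 1, 1, 1]
    seg_ids = jnp.repeat(
        jnp.arange(n_seg), jnp.diff(offsets), total_repeat_length=t.shape[0]
    )
    # Slope between each pair of neighbouring flat elements
    slope = jnp.diff(flux) / jnp.diff(t)
    # Pairs straddling a row boundary are meaningless: replace them with -inf
    same_row = seg_ids[:-1] == seg_ids[1:]
    slope = jnp.where(same_row, slope, -jnp.inf)
    # Maximum slope within each row (rows with < 2 points give -inf)
    return jax.ops.segment_max(slope, seg_ids[:-1], num_segments=n_seg)
```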
Before submitting
Please check the following:
- I have described the purpose of the suggested change, specifying what I need the enhancement to accomplish, i.e. what problem it solves.
- I have included any relevant links, screenshots, environment information, and data relevant to implementing the requested feature, as well as pseudocode for how I want to access the new functionality.
- If I have ideas for how the new feature could be implemented, I have provided explanations and/or pseudocode and/or task lists for the steps.