Skip to content

Commit 734444c

Browse files
dcherian and fjetter committed
Add quantile_tdigest
Co-authored-by: Florian Jetter <[email protected]>
1 parent 19db5b3 commit 734444c

File tree

2 files changed

+47
-1
lines changed

2 files changed

+47
-1
lines changed

flox/aggregations.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import numpy as np
99
from numpy.typing import DTypeLike
1010

11-
from . import aggregate_flox, aggregate_npg, xrutils
11+
from . import aggregate_flox, aggregate_npg, sketches, xrutils
1212
from . import xrdtypes as dtypes
1313

1414
if TYPE_CHECKING:
@@ -495,6 +495,16 @@ def _pick_second(*x):
495495
mode = Aggregation(name="mode", fill_value=dtypes.NA, chunk=None, combine=None)
496496
nanmode = Aggregation(name="nanmode", fill_value=dtypes.NA, chunk=None, combine=None)
497497

498+
499+
# Quantile aggregation backed by the t-digest helpers in ``sketches``:
# ``tdigest_chunk`` is registered as the chunk step, ``tdigest_combine`` as the
# combine step and ``tdigest_aggregate`` as the finalize step.
# NOTE(review): ``numpy=`` also points at ``tdigest_aggregate``, whose signature is
# ``(digests, q, ...)`` rather than ``(group_idx, array, ...)`` like ``tdigest_chunk``
# -- confirm this is what the eager code path expects.
quantile_tdigest = Aggregation(
500+
"quantile_tdigest",
501+
numpy=(sketches.tdigest_aggregate,),
502+
chunk=(sketches.tdigest_chunk,),
503+
combine=(sketches.tdigest_combine,),
504+
finalize=sketches.tdigest_aggregate,
505+
)
506+
507+
498508
aggregations = {
499509
"any": any_,
500510
"all": all_,
@@ -527,6 +537,7 @@ def _pick_second(*x):
527537
"nanquantile": nanquantile,
528538
"mode": mode,
529539
"nanmode": nanmode,
540+
"quantile_tdigest": quantile_tdigest,
530541
}
531542

532543

flox/sketches.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import numpy as np
2+
import numpy_groupies as npg
3+
4+
5+
def tdigest_chunk(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None, **kwargs):
    """Build one t-digest per group from the values in ``array``.

    Parameters
    ----------
    group_idx
        Group labels for ``array``; passed straight to ``numpy_groupies``.
    array
        Values to summarise. Its dtype is used to cast each group's values
        before they are fed to the digest.
    axis
        Axis of ``array`` along which to group.
    size
        Number of groups in the output. Forwarded to ``numpy_groupies`` so the
        output length does not depend on which labels happen to occur in
        ``group_idx``. ``None`` lets ``numpy_groupies`` infer it.
    fill_value, dtype, **kwargs
        Accepted for signature compatibility; not used.

    Returns
    -------
    numpy.ndarray
        Object array produced by ``numpy_groupies`` with one ``crick.TDigest``
        per group that received values.
    """
    # Imported inside the function, so crick is only required when this is called.
    from crick import TDigest

    values_dtype = array.dtype

    def _digest_of(values):
        sketch = TDigest()
        # we receive object arrays from numpy_groupies; cast back to the input dtype
        sketch.update(values.astype(values_dtype, copy=False))
        return sketch

    # NOTE(review): groups with no values get numpy_groupies' default fill rather
    # than a TDigest -- confirm tdigest_combine tolerates that.
    return npg.aggregate_numpy.aggregate(
        group_idx, array, func=_digest_of, size=size, axis=axis, dtype=object
    )
16+
17+
18+
def tdigest_combine(digests, axis=-1, keepdims=True):
    """Merge t-digests along one axis of an object array.

    Parameters
    ----------
    digests
        Array of ``crick.TDigest`` objects.
    axis
        Axis to merge along, given either as an integer or as a one-element
        sequence such as ``(0,)``.
    keepdims
        Accepted for signature compatibility but not consulted: the merged axis
        is always kept with length one.

    Returns
    -------
    numpy.ndarray
        Object array shaped like ``digests`` except that ``axis`` has length
        one, each entry being the merge of the digests along that axis.

    Raises
    ------
    ValueError
        If ``axis`` is a sequence that does not contain exactly one axis.
    """
    # The default is a bare int, which cannot be tuple-unpacked; only unpack
    # when a sequence was actually passed.
    if np.ndim(axis) != 0:
        axes = tuple(axis)
        if len(axes) != 1:
            raise ValueError(
                f"tdigest_combine merges along exactly one axis; got axis={axis!r}"
            )
        (axis,) = axes

    from crick import TDigest

    def _merge(parts):
        merged = TDigest()
        merged.merge(*parts)
        # A length-1 result makes apply_along_axis keep the merged axis.
        return np.array([merged], dtype=object)

    return np.apply_along_axis(_merge, axis, digests)
30+
31+
32+
def tdigest_aggregate(digests, q, axis=-1, keepdims=True):
    """Evaluate the quantile(s) ``q`` of every digest in ``digests``.

    Parameters
    ----------
    digests
        Array whose elements expose a ``quantile(q)`` method.
    q
        Quantile argument handed unchanged to each element's ``quantile``.
    axis, keepdims
        Accepted for signature compatibility; not used.

    Returns
    -------
    numpy.ndarray
        New object array with the same shape as ``digests`` holding
        ``digest.quantile(q)`` for each element. ``digests`` itself is left
        untouched.
    """
    # Write into a fresh array: overwriting the input would destroy the digests
    # for any other consumer of the same array.
    quantiles = np.empty(digests.shape, dtype=object)
    for idx in np.ndindex(digests.shape):
        quantiles[idx] = digests[idx].quantile(q)
    return quantiles

0 commit comments

Comments
 (0)