diff --git a/asv_bench/benchmarks/cohorts.py b/asv_bench/benchmarks/cohorts.py
index dbfbe8cd5..21707d448 100644
--- a/asv_bench/benchmarks/cohorts.py
+++ b/asv_bench/benchmarks/cohorts.py
@@ -1,8 +1,10 @@
 import dask
 import numpy as np
 import pandas as pd
+import xarray as xr
 
 import flox
+from flox.xarray import xarray_reduce
 
 
 class Cohorts:
@@ -12,7 +14,7 @@ def setup(self, *args, **kwargs):
         raise NotImplementedError
 
     def time_find_group_cohorts(self):
-        flox.core.find_group_cohorts(self.by, self.array.chunks)
+        flox.core.find_group_cohorts(self.by, [self.array.chunks[ax] for ax in self.axis])
         # The cache clear fails dependably in CI
         # Not sure why
         try:
@@ -125,3 +127,13 @@ class PerfectMonthlyRechunked(PerfectMonthly):
     def setup(self, *args, **kwargs):
         super().setup()
         super().rechunk()
+
+
+def time_cohorts_era5_single():
+    TIME = 900  # 92044 in Google ARCO ERA5
+    da = xr.DataArray(
+        dask.array.ones((TIME, 721, 1440), chunks=(1, -1, -1)),
+        dims=("time", "lat", "lon"),
+        coords=dict(time=pd.date_range("1959-01-01", freq="6H", periods=TIME)),
+    )
+    xarray_reduce(da, da.time.dt.day, method="cohorts", func="any")
diff --git a/flox/core.py b/flox/core.py
index e5518b551..e2784ae99 100644
--- a/flox/core.py
+++ b/flox/core.py
@@ -208,15 +208,20 @@ def find_group_cohorts(labels, chunks, merge: bool = True) -> dict:
     # 1. First subset the array appropriately
     axis = range(-labels.ndim, 0)
     # Easier to create a dask array and use the .blocks property
-    array = dask.array.ones(tuple(sum(c) for c in chunks), chunks=chunks)
+    array = dask.array.empty(tuple(sum(c) for c in chunks), chunks=chunks)
     labels = np.broadcast_to(labels, array.shape[-labels.ndim :])
 
     # Iterate over each block and create a new block of same shape with "chunk number"
     shape = tuple(array.blocks.shape[ax] for ax in axis)
-    blocks = np.empty(math.prod(shape), dtype=object)
-    for idx, block in enumerate(array.blocks.ravel()):
-        blocks[idx] = np.full(tuple(block.shape[ax] for ax in axis), idx)
-    which_chunk = np.block(blocks.reshape(shape).tolist()).reshape(-1)
+    # Use a numpy object array to enable assignment in the loop
+    # TODO: is it possible to just use a nested list?
+    #       That is what we need for `np.block`
+    blocks = np.empty(shape, dtype=object)
+    array_chunks = tuple(np.array(c) for c in array.chunks)
+    for idx, blockindex in enumerate(np.ndindex(array.numblocks)):
+        chunkshape = tuple(c[i] for c, i in zip(array_chunks, blockindex))
+        blocks[blockindex] = np.full(chunkshape, idx)
+    which_chunk = np.block(blocks.tolist()).reshape(-1)
 
     raveled = labels.reshape(-1)
     # these are chunks where a label is present
@@ -229,7 +234,11 @@ def invert(x) -> tuple[np.ndarray, ...]:
 
     chunks_cohorts = tlz.groupby(invert, label_chunks.keys())
 
-    if merge:
+    # If our dataset has chunksize one along the axis,
+    # then no merging is possible.
+    single_chunks = all((ac == 1).all() for ac in array_chunks)
+
+    if merge and not single_chunks:
         # First sort by number of chunks occupied by cohort
         sorted_chunks_cohorts = dict(
             sorted(chunks_cohorts.items(), key=lambda kv: len(kv[0]), reverse=True)
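
Below is a minimal standalone sketch (not part of the patch above) of the chunk-numbering approach the flox/core.py hunk switches to: iterating over block indices with np.ndindex and building each block's shape from the chunk tuples, instead of materializing per-block arrays via .blocks.ravel(). The variable names mirror the patch (blocks, array_chunks, which_chunk); the toy labels/chunks and the label_chunks grouping at the end are illustrative assumptions, a simplified stand-in for what flox computes next, not code from the diff.

import dask.array
import numpy as np

# Toy inputs (assumed for illustration): 12 labels along one axis,
# split into 4 chunks of size 3.
labels = np.repeat(np.arange(4), 3)
chunks = ((3, 3, 3, 3),)

# Mirror the patch: dask.array.empty avoids the cost of filling the array with ones.
array = dask.array.empty(tuple(sum(c) for c in chunks), chunks=chunks)

# One object-array slot per block; each slot holds a block-shaped array
# filled with that block's linear index.
blocks = np.empty(array.numblocks, dtype=object)
array_chunks = tuple(np.array(c) for c in array.chunks)
for idx, blockindex in enumerate(np.ndindex(array.numblocks)):
    chunkshape = tuple(c[i] for c, i in zip(array_chunks, blockindex))
    blocks[blockindex] = np.full(chunkshape, idx)

# Reassemble into a single array mapping every element to its chunk number.
which_chunk = np.block(blocks.tolist()).reshape(-1)

# Simplified grouping: which chunks does each label occupy?
label_chunks = {
    lab: tuple(np.unique(which_chunk[labels == lab]).tolist())
    for lab in np.unique(labels).tolist()
}
print(label_chunks)  # each label touches exactly one chunk in this toy setup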