Skip to content

Commit 9b8e27a

Browse files
committed
Better alignment check + error
xref #191
1 parent 27a4e9a commit 9b8e27a

File tree

3 files changed

+19
-3
lines changed

3 files changed

+19
-3
lines changed

flox/core.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1465,9 +1465,11 @@ def _assert_by_is_aligned(shape, by):
14651465
for idx, b in enumerate(by):
14661466
if not all(j in [i, 1] for i, j in zip(shape[-b.ndim :], b.shape)):
14671467
raise ValueError(
1468-
"`array` and `by` arrays must be aligned "
1469-
"i.e. array.shape[-by.ndim :] == by.shape. "
1470-
"for every array in `by`."
1468+
"`array` and `by` arrays must be 'aligned' "
1469+
"so that such that by_ is broadcastable to array.shape[-by.ndim:] "
1470+
"for every array `by_` in `by`. "
1471+
"Either array.shape[-by_.ndim :] == by_.shape or the only differences "
1472+
"should be size-1 dimensions in by_."
14711473
f"Received array of shape {shape} but "
14721474
f"array {idx} in `by` has shape {b.shape}."
14731475
)

flox/xarray.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,14 @@ def xarray_reduce(
265265

266266
# broadcast to make sure grouper dimensions are present in the array.
267267
exclude_dims = tuple(d for d in ds.dims if d not in grouper_dims and d not in dim_tuple)
268+
269+
try:
270+
xr.align(ds, *by_da, join="exact")
271+
except ValueError as e:
272+
raise ValueError(
273+
"Object being grouped must be exactly aligned with every array in `by`."
274+
) from e
275+
268276
ds_broad = xr.broadcast(ds, *by_da, exclude=exclude_dims)[0]
269277

270278
if any(d not in grouper_dims and d not in obj.dims for d in dim_tuple):

tests/test_xarray.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -499,6 +499,12 @@ def test_mixed_grouping(chunk):
499499
assert (r.sel(v1=[3, 4, 5]) == 0).all().data
500500

501501

502+
def test_alignment_error():
503+
da = xr.DataArray(np.arange(10), dims="x", coords={"x": np.arange(10)})
504+
with pytest.raises(ValueError):
505+
xarray_reduce(da, da.x.sel(x=slice(5)), func="count")
506+
507+
502508
@pytest.mark.parametrize("add_nan", [True, False])
503509
@pytest.mark.parametrize("dtype_out", [np.float64, "float64", np.dtype("float64")])
504510
@pytest.mark.parametrize("dtype", [np.float32, np.float64])

0 commit comments

Comments
 (0)