Skip to content

Commit 8ec0617

Browse files
dcherianIllviljan
andauthored
Consolidate validation of expected_groups (#193)
* Consolidate validation of expected_groups * Add tests * Switch by to nby Co-authored-by: Illviljan <[email protected]> * Type expected_groups properly Co-authored-by: Illviljan <[email protected]>
1 parent b58aa5f commit 8ec0617

File tree

4 files changed

+71
-23
lines changed

4 files changed

+71
-23
lines changed

flox/core.py

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@
2828
if TYPE_CHECKING:
2929
import dask.array.Array as DaskArray
3030

31-
T_ExpectedGroups = Union[Sequence, np.ndarray, pd.Index]
31+
T_Expect = Union[Sequence, np.ndarray, pd.Index, None]
32+
T_ExpectTuple = tuple[T_Expect, ...]
33+
T_ExpectedGroups = Union[T_Expect, T_ExpectTuple]
3234
T_ExpectedGroupsOpt = Union[T_ExpectedGroups, None]
3335
T_Func = Union[str, Callable]
3436
T_Funcs = Union[T_Func, Sequence[T_Func]]
@@ -1476,7 +1478,7 @@ def _assert_by_is_aligned(shape, by):
14761478

14771479

14781480
def _convert_expected_groups_to_index(
1479-
expected_groups: T_ExpectedGroups, isbin: Sequence[bool], sort: bool
1481+
expected_groups: T_ExpectTuple, isbin: Sequence[bool], sort: bool
14801482
) -> tuple[pd.Index | None, ...]:
14811483
out: list[pd.Index | None] = []
14821484
for ex, isbin_ in zip(expected_groups, isbin):
@@ -1543,6 +1545,36 @@ def _factorize_multiple(by, expected_groups, any_by_dask, reindex):
15431545
return (group_idx,), final_groups, grp_shape
15441546

15451547

1548+
def _validate_expected_groups(nby: int, expected_groups: T_ExpectedGroupsOpt) -> T_ExpectTuple:
1549+
1550+
if expected_groups is None:
1551+
return (None,) * nby
1552+
1553+
if nby == 1 and not isinstance(expected_groups, tuple):
1554+
return (np.asarray(expected_groups),)
1555+
1556+
if nby > 1 and not isinstance(expected_groups, tuple): # TODO: test for list
1557+
raise ValueError(
1558+
"When grouping by multiple variables, expected_groups must be a tuple "
1559+
"of either arrays or objects convertible to an array (like lists). "
1560+
"For example `expected_groups=(np.array([1, 2, 3]), ['a', 'b', 'c'])`."
1561+
f"Received a {type(expected_groups).__name__} instead. "
1562+
"When grouping by a single variable, you can pass an array or something "
1563+
"convertible to an array for convenience: `expected_groups=['a', 'b', 'c']`."
1564+
)
1565+
1566+
if TYPE_CHECKING:
1567+
assert isinstance(expected_groups, tuple)
1568+
1569+
if len(expected_groups) != nby:
1570+
raise ValueError(
1571+
f"Must have same number of `expected_groups` (received {len(expected_groups)}) "
1572+
f" and variables to group by (received {nby})."
1573+
)
1574+
1575+
return expected_groups
1576+
1577+
15461578
def groupby_reduce(
15471579
array: np.ndarray | DaskArray,
15481580
*by: np.ndarray | DaskArray,
@@ -1679,24 +1711,17 @@ def groupby_reduce(
16791711
isbins = isbin
16801712
else:
16811713
isbins = (isbin,) * nby
1682-
if expected_groups is None:
1683-
expected_groups = (None,) * nby
16841714

16851715
_assert_by_is_aligned(array.shape, bys)
1716+
1717+
expected_groups = _validate_expected_groups(nby, expected_groups)
1718+
16861719
for idx, (expect, is_dask) in enumerate(zip(expected_groups, by_is_dask)):
16871720
if is_dask and (reindex or nby > 1) and expect is None:
16881721
raise ValueError(
16891722
f"`expected_groups` for array {idx} in `by` cannot be None since it is a dask.array."
16901723
)
16911724

1692-
if nby == 1 and not isinstance(expected_groups, tuple):
1693-
expected_groups = (np.asarray(expected_groups),)
1694-
elif len(expected_groups) != nby:
1695-
raise ValueError(
1696-
f"Must have same number of `expected_groups` (received {len(expected_groups)}) "
1697-
f" and variables to group by (received {nby})."
1698-
)
1699-
17001725
# We convert to pd.Index since that lets us know if we are binning or not
17011726
# (pd.IntervalIndex or not)
17021727
expected_groups = _convert_expected_groups_to_index(expected_groups, isbins, sort)

flox/xarray.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from .core import (
1414
_convert_expected_groups_to_index,
1515
_get_expected_groups,
16+
_validate_expected_groups,
1617
groupby_reduce,
1718
rechunk_for_blockwise as rechunk_array_for_blockwise,
1819
rechunk_for_cohorts as rechunk_array_for_cohorts,
@@ -216,16 +217,10 @@ def xarray_reduce(
216217
else:
217218
isbins = (isbin,) * nby
218219

219-
if expected_groups is None:
220-
expected_groups = (None,) * nby
221-
if isinstance(expected_groups, (np.ndarray, list)): # TODO: test for list
222-
if nby == 1:
223-
expected_groups = (expected_groups,)
224-
else:
225-
raise ValueError("Needs better message.")
220+
expected_groups = _validate_expected_groups(nby, expected_groups)
226221

227222
if not sort:
228-
raise NotImplementedError
223+
raise NotImplementedError("sort must be True for xarray_reduce")
229224

230225
# eventually drop the variables we are grouping by
231226
maybe_drop = [b for b in by if isinstance(b, Hashable)]

tests/test_core.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1019,8 +1019,26 @@ def test_multiple_groupers(chunk, by1, by2, expected_groups) -> None:
10191019
assert_equal(expected, actual)
10201020

10211021

1022+
@pytest.mark.parametrize(
1023+
"expected_groups",
1024+
(
1025+
[None, None, None],
1026+
(None,),
1027+
),
1028+
)
1029+
def test_validate_expected_groups(expected_groups):
1030+
with pytest.raises(ValueError):
1031+
groupby_reduce(
1032+
np.ones((10,)),
1033+
np.ones((10,)),
1034+
np.ones((10,)),
1035+
expected_groups=expected_groups,
1036+
func="mean",
1037+
)
1038+
1039+
10221040
@requires_dask
1023-
def test_multiple_groupers_errors() -> None:
1041+
def test_validate_expected_groups_not_none_dask() -> None:
10241042
with pytest.raises(ValueError):
10251043
groupby_reduce(
10261044
dask.array.ones((5, 2)),

tests/test_xarray.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,12 +168,22 @@ def test_xarray_reduce_multiple_groupers_2(pass_expected_groups, chunk, engine):
168168

169169

170170
@requires_dask
171-
def test_dask_groupers_error():
171+
@pytest.mark.parametrize(
172+
"expected_groups",
173+
(None, (None, None), [[1, 2], [1, 2]]),
174+
)
175+
def test_validate_expected_groups(expected_groups):
172176
da = xr.DataArray(
173177
[1.0, 2.0], dims="x", coords={"labels": ("x", [1, 2]), "labels2": ("x", [1, 2])}
174178
)
175179
with pytest.raises(ValueError):
176-
xarray_reduce(da.chunk({"x": 2, "z": 1}), "labels", "labels2", func="count")
180+
xarray_reduce(
181+
da.chunk({"x": 1}),
182+
"labels",
183+
"labels2",
184+
func="count",
185+
expected_groups=expected_groups,
186+
)
177187

178188

179189
@requires_dask

0 commit comments

Comments
 (0)