|
28 | 28 | if TYPE_CHECKING:
|
29 | 29 | import dask.array.Array as DaskArray
|
30 | 30 |
|
31 |
| - T_ExpectedGroups = Union[Sequence, np.ndarray, pd.Index] |
| 31 | + T_Expect = Union[Sequence, np.ndarray, pd.Index, None] |
| 32 | + T_ExpectTuple = tuple[T_Expect, ...] |
| 33 | + T_ExpectedGroups = Union[T_Expect, T_ExpectTuple] |
32 | 34 | T_ExpectedGroupsOpt = Union[T_ExpectedGroups, None]
|
33 | 35 | T_Func = Union[str, Callable]
|
34 | 36 | T_Funcs = Union[T_Func, Sequence[T_Func]]
|
@@ -1476,7 +1478,7 @@ def _assert_by_is_aligned(shape, by):
|
1476 | 1478 |
|
1477 | 1479 |
|
1478 | 1480 | def _convert_expected_groups_to_index(
|
1479 |
| - expected_groups: T_ExpectedGroups, isbin: Sequence[bool], sort: bool |
| 1481 | + expected_groups: T_ExpectTuple, isbin: Sequence[bool], sort: bool |
1480 | 1482 | ) -> tuple[pd.Index | None, ...]:
|
1481 | 1483 | out: list[pd.Index | None] = []
|
1482 | 1484 | for ex, isbin_ in zip(expected_groups, isbin):
|
@@ -1543,6 +1545,36 @@ def _factorize_multiple(by, expected_groups, any_by_dask, reindex):
|
1543 | 1545 | return (group_idx,), final_groups, grp_shape
|
1544 | 1546 |
|
1545 | 1547 |
|
| 1548 | +def _validate_expected_groups(nby: int, expected_groups: T_ExpectedGroupsOpt) -> T_ExpectTuple: |
| 1549 | + |
| 1550 | + if expected_groups is None: |
| 1551 | + return (None,) * nby |
| 1552 | + |
| 1553 | + if nby == 1 and not isinstance(expected_groups, tuple): |
| 1554 | + return (np.asarray(expected_groups),) |
| 1555 | + |
| 1556 | + if nby > 1 and not isinstance(expected_groups, tuple): # TODO: test for list |
| 1557 | + raise ValueError( |
| 1558 | + "When grouping by multiple variables, expected_groups must be a tuple " |
| 1559 | + "of either arrays or objects convertible to an array (like lists). " |
| 1560 | + "For example `expected_groups=(np.array([1, 2, 3]), ['a', 'b', 'c'])`." |
| 1561 | + f"Received a {type(expected_groups).__name__} instead. " |
| 1562 | + "When grouping by a single variable, you can pass an array or something " |
| 1563 | + "convertible to an array for convenience: `expected_groups=['a', 'b', 'c']`." |
| 1564 | + ) |
| 1565 | + |
| 1566 | + if TYPE_CHECKING: |
| 1567 | + assert isinstance(expected_groups, tuple) |
| 1568 | + |
| 1569 | + if len(expected_groups) != nby: |
| 1570 | + raise ValueError( |
| 1571 | + f"Must have same number of `expected_groups` (received {len(expected_groups)}) " |
| 1572 | + f" and variables to group by (received {nby})." |
| 1573 | + ) |
| 1574 | + |
| 1575 | + return expected_groups |
| 1576 | + |
| 1577 | + |
1546 | 1578 | def groupby_reduce(
|
1547 | 1579 | array: np.ndarray | DaskArray,
|
1548 | 1580 | *by: np.ndarray | DaskArray,
|
@@ -1679,24 +1711,17 @@ def groupby_reduce(
|
1679 | 1711 | isbins = isbin
|
1680 | 1712 | else:
|
1681 | 1713 | isbins = (isbin,) * nby
|
1682 |
| - if expected_groups is None: |
1683 |
| - expected_groups = (None,) * nby |
1684 | 1714 |
|
1685 | 1715 | _assert_by_is_aligned(array.shape, bys)
|
| 1716 | + |
| 1717 | + expected_groups = _validate_expected_groups(nby, expected_groups) |
| 1718 | + |
1686 | 1719 | for idx, (expect, is_dask) in enumerate(zip(expected_groups, by_is_dask)):
|
1687 | 1720 | if is_dask and (reindex or nby > 1) and expect is None:
|
1688 | 1721 | raise ValueError(
|
1689 | 1722 | f"`expected_groups` for array {idx} in `by` cannot be None since it is a dask.array."
|
1690 | 1723 | )
|
1691 | 1724 |
|
1692 |
| - if nby == 1 and not isinstance(expected_groups, tuple): |
1693 |
| - expected_groups = (np.asarray(expected_groups),) |
1694 |
| - elif len(expected_groups) != nby: |
1695 |
| - raise ValueError( |
1696 |
| - f"Must have same number of `expected_groups` (received {len(expected_groups)}) " |
1697 |
| - f" and variables to group by (received {nby})." |
1698 |
| - ) |
1699 |
| - |
1700 | 1725 | # We convert to pd.Index since that lets us know if we are binning or not
|
1701 | 1726 | # (pd.IntervalIndex or not)
|
1702 | 1727 | expected_groups = _convert_expected_groups_to_index(expected_groups, isbins, sort)
|
|
0 commit comments