diff --git a/flox/xarray.py b/flox/xarray.py index 3200d7f0a..7f0f95d0f 100644 --- a/flox/xarray.py +++ b/flox/xarray.py @@ -201,10 +201,10 @@ def xarray_reduce( >>> da = da = xr.ones_like(labels) >>> # Sum all values in da that matches the elements in the group index: >>> xarray_reduce(da, labels, func="sum") - + Size: 32B array([3, 2, 2, 2]) Coordinates: - * label (label) int64 0 1 2 3 + * label (label) int64 32B 0 1 2 3 """ if skipna is not None and isinstance(func, Aggregation): @@ -303,14 +303,16 @@ def xarray_reduce( # reducing along a dimension along which groups do not vary # This is really just a normal reduction. # This is not right when binning so we exclude. - if isinstance(func, str): - dsfunc = func[3:] if skipna else func - else: + if isinstance(func, str) and func.startswith("nan"): + raise ValueError(f"Specify func={func[3:]}, skipna=True instead of func={func}") + elif isinstance(func, Aggregation): raise NotImplementedError( "func must be a string when reducing along a dimension not present in `by`" ) - # TODO: skipna needs test - result = getattr(ds_broad, dsfunc)(dim=dim_tuple, skipna=skipna) + # skipna is not supported for all reductions + # https://github.com/pydata/xarray/issues/8819 + kwargs = {"skipna": skipna} if skipna is not None else {} + result = getattr(ds_broad, func)(dim=dim_tuple, **kwargs) if isinstance(obj, xr.DataArray): return obj._from_temp_dataset(result) else: diff --git a/tests/__init__.py b/tests/__init__.py index fbba74c73..2ebeb5ad2 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -124,3 +124,35 @@ def assert_equal_tuple(a, b): np.testing.assert_array_equal(a_, b_) else: assert a_ == b_ + + +SCIPY_STATS_FUNCS = ("mode", "nanmode") +BLOCKWISE_FUNCS = ("median", "nanmedian", "quantile", "nanquantile") + SCIPY_STATS_FUNCS +ALL_FUNCS = ( + "sum", + "nansum", + "argmax", + "nanfirst", + "nanargmax", + "prod", + "nanprod", + "mean", + "nanmean", + "var", + "nanvar", + "std", + "nanstd", + "max", + "nanmax", + "min", + "nanmin", + "argmin", + "nanargmin", + "any", + "all", + "nanlast", + "median", + "nanmedian", + "quantile", + "nanquantile", +) + tuple(pytest.param(func, marks=requires_scipy) for func in SCIPY_STATS_FUNCS) diff --git a/tests/test_core.py b/tests/test_core.py index 6837eb963..11cd19e3e 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -31,12 +31,14 @@ ) from . import ( + ALL_FUNCS, + BLOCKWISE_FUNCS, + SCIPY_STATS_FUNCS, assert_equal, assert_equal_tuple, has_dask, raise_if_dask_computes, requires_dask, - requires_scipy, ) logger = logging.getLogger("flox") @@ -60,36 +62,6 @@ def dask_array_ones(*args): DEFAULT_QUANTILE = 0.9 -SCIPY_STATS_FUNCS = ("mode", "nanmode") -BLOCKWISE_FUNCS = ("median", "nanmedian", "quantile", "nanquantile") + SCIPY_STATS_FUNCS -ALL_FUNCS = ( - "sum", - "nansum", - "argmax", - "nanfirst", - "nanargmax", - "prod", - "nanprod", - "mean", - "nanmean", - "var", - "nanvar", - "std", - "nanstd", - "max", - "nanmax", - "min", - "nanmin", - "argmin", - "nanargmin", - "any", - "all", - "nanlast", - "median", - "nanmedian", - "quantile", - "nanquantile", -) + tuple(pytest.param(func, marks=requires_scipy) for func in SCIPY_STATS_FUNCS) if TYPE_CHECKING: from flox.core import T_Agg, T_Engine, T_ExpectedGroupsOpt, T_Method diff --git a/tests/test_xarray.py b/tests/test_xarray.py index 95ab2eff3..97b5674ba 100644 --- a/tests/test_xarray.py +++ b/tests/test_xarray.py @@ -9,6 +9,7 @@ from flox.xarray import rechunk_for_blockwise, xarray_reduce from . import ( + ALL_FUNCS, assert_equal, has_dask, raise_if_dask_computes, @@ -710,3 +711,32 @@ def test_multiple_quantiles(q, chunk, by_ndim, skipna): with xr.set_options(use_flox=False): expected = da.groupby(by).quantile(q, skipna=skipna) xr.testing.assert_allclose(expected, actual) + + +@pytest.mark.parametrize("func", ALL_FUNCS) +def test_direct_reduction(func): + if "arg" in func or "mode" in func: + pytest.skip() + # regression test for https://github.com/pydata/xarray/issues/8819 + rand = np.random.choice([True, False], size=(2, 3)) + if func not in ["any", "all"]: + rand = rand.astype(float) + + if "nan" in func: + func = func[3:] + kwargs = {"skipna": True} + else: + kwargs = {} + + if "first" not in func and "last" not in func: + kwargs["dim"] = "y" + + if "quantile" in func: + kwargs["q"] = 0.9 + + data = xr.DataArray(rand, dims=("x", "y"), coords={"x": [10, 20], "y": [0, 1, 2]}) + with xr.set_options(use_flox=True): + actual = getattr(data.groupby("x", squeeze=False), func)(**kwargs) + with xr.set_options(use_flox=False): + expected = getattr(data.groupby("x", squeeze=False), func)(**kwargs) + xr.testing.assert_identical(expected, actual)