Skip to content

Commit 0eff0a7

Browse files
committed
Begin testing
1 parent 9d35bb2 commit 0eff0a7

File tree

6 files changed

+64
-12
lines changed

6 files changed

+64
-12
lines changed

flox/aggregate_flox.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@ def _prepare_for_flox(group_idx, array):
1414
if issorted:
1515
ordered_array = array
1616
else:
17-
perm = group_idx.argsort(kind="stable")
17+
kind = "stable" if isinstance(group_idx, np.ndarray) else None
18+
19+
perm = np.argsort(group_idx, kind=kind)
1820
group_idx = group_idx[..., perm]
1921
ordered_array = array[..., perm]
2022
return group_idx, ordered_array

flox/core.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,8 @@ def factorize_(
570570
else:
571571
assert sort
572572
groups, idx = np.unique(flat, return_inverse=True)
573+
idx[np.isnan(flat)] = -1
574+
groups = groups[~np.isnan(groups)]
573575

574576
found_groups.append(groups)
575577
factorized.append(idx.reshape(groupvar.shape))

flox/xrutils.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,6 @@ def _select_along_axis(values, idx, axis):
294294
def nanfirst(values, axis, keepdims=False):
295295
if isinstance(axis, tuple):
296296
(axis,) = axis
297-
values = np.asarray(values)
298297
axis = normalize_axis_index(axis, values.ndim)
299298
idx_first = np.argmax(~pd.isnull(values), axis=axis)
300299
result = _select_along_axis(values, idx_first, axis)
@@ -307,7 +306,6 @@ def nanfirst(values, axis, keepdims=False):
307306
def nanlast(values, axis, keepdims=False):
308307
if isinstance(axis, tuple):
309308
(axis,) = axis
310-
values = np.asarray(values)
311309
axis = normalize_axis_index(axis, values.ndim)
312310
rev = (slice(None),) * axis + (slice(None, None, -1),)
313311
idx_last = -1 - np.argmax(~pd.isnull(values)[rev], axis=axis)

tests/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,13 @@
2424
except ImportError:
2525
xr_types = () # type: ignore
2626

27+
try:
28+
import cupy as cp
29+
30+
cp_types = (cp.ndarray,)
31+
except ImportError:
32+
cp_types = () # type: ignore
33+
2734

2835
def _importorskip(modname, minversion=None):
2936
try:
@@ -88,6 +95,12 @@ def assert_equal(a, b, tolerance=None):
8895
if isinstance(b, list):
8996
b = np.array(b)
9097

98+
if isinstance(a, cp_types):
99+
a = a.get()
100+
101+
if isinstance(b, cp_types):
102+
b = b.get()
103+
91104
if isinstance(a, pd_types) or isinstance(b, pd_types):
92105
pd.testing.assert_index_equal(a, b)
93106
return

tests/conftest.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,18 @@ def engine(request):
99
except ImportError:
1010
pytest.xfail()
1111
return request.param
12+
13+
14+
@pytest.fixture(scope="module", params=["numpy", "cupy"])
15+
def array_module(request):
16+
if request.param == "cupy":
17+
try:
18+
import cupy # noqa
19+
20+
return cupy
21+
except ImportError:
22+
pytest.xfail()
23+
elif request.param == "numpy":
24+
import numpy
25+
26+
return numpy

tests/test_core.py

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -178,31 +178,53 @@ def test_groupby_reduce(
178178
assert_equal(expected_result, result)
179179

180180

181-
def gen_array_by(size, func):
182-
by = np.ones(size[-1])
183-
rng = np.random.default_rng(12345)
181+
def maybe_skip_cupy(array_module, func, engine):
182+
if array_module is np:
183+
return
184+
185+
import cupy
186+
187+
assert array_module is cupy
188+
189+
if engine == "numba":
190+
pytest.skip()
191+
192+
if engine == "numpy" and ("prod" in func or "first" in func or "last" in func):
193+
pytest.xfail()
194+
elif engine == "flox" and not (
195+
"sum" in func or "mean" in func or "std" in func or "var" in func
196+
):
197+
pytest.xfail()
198+
199+
200+
def gen_array_by(size, func, array_module):
201+
xp = array_module
202+
by = xp.ones(size[-1])
203+
rng = xp.random.default_rng(12345)
184204
array = rng.random(size)
185205
if "nan" in func and "nanarg" not in func:
186-
array[[1, 4, 5], ...] = np.nan
206+
array[[1, 4, 5], ...] = xp.nan
187207
elif "nanarg" in func and len(size) > 1:
188-
array[[1, 4, 5], 1] = np.nan
208+
array[[1, 4, 5], 1] = xp.nan
189209
if func in ["any", "all"]:
190210
array = array > 0.5
191211
return array, by
192212

193213

194-
@pytest.mark.parametrize("chunks", [None, -1, 3, 4])
195214
@pytest.mark.parametrize("nby", [1, 2, 3])
196215
@pytest.mark.parametrize("size", ((12,), (12, 9)))
197-
@pytest.mark.parametrize("add_nan_by", [True, False])
216+
@pytest.mark.parametrize("chunks", [None, -1, 3, 4])
198217
@pytest.mark.parametrize("func", ALL_FUNCS)
199-
def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
218+
@pytest.mark.parametrize("add_nan_by", [True, False])
219+
def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine, array_module):
200220
if chunks is not None and not has_dask:
201221
pytest.skip()
202222
if "arg" in func and engine == "flox":
203223
pytest.skip()
204224

205-
array, by = gen_array_by(size, func)
225+
maybe_skip_cupy(array_module, func, engine)
226+
227+
array, by = gen_array_by(size, func, array_module)
206228
if chunks:
207229
array = dask.array.from_array(array, chunks=chunks)
208230
by = (by,) * nby

0 commit comments

Comments
 (0)