Skip to content

Commit cc88278

Browse files
committed
Add drop-in array/proxy API test
1 parent 71b199b commit cc88278

File tree

2 files changed

+26
-27
lines changed

2 files changed

+26
-27
lines changed

tests/test_numpy_array.cpp

Lines changed: 21 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,21 @@ template<typename... Ix> arr_t& mutate_at_t(arr_t& a, Ix... idx) { a.mutable_at(
6868
sm.def(#name, [](type a, int i, int j) { return name(a, i, j); }); \
6969
sm.def(#name, [](type a, int i, int j, int k) { return name(a, i, j, k); });
7070

71+
template <typename T, typename T2> py::handle auxiliaries(T &&r, T2 &&r2) {
72+
if (r.ndim() != 2) throw std::domain_error("error: ndim != 2");
73+
py::list l;
74+
l.append(*r.data(0, 0));
75+
l.append(*r2.mutable_data(0, 0));
76+
l.append(r.data(0, 1) == r2.mutable_data(0, 1));
77+
l.append(r.ndim());
78+
l.append(r.itemsize());
79+
l.append(r.shape(0));
80+
l.append(r.shape(1));
81+
l.append(r.size());
82+
l.append(r.nbytes());
83+
return l.release();
84+
}
85+
7186
test_initializer numpy_array([](py::module &m) {
7287
auto sm = m.def_submodule("array");
7388

@@ -221,17 +236,7 @@ test_initializer numpy_array([](py::module &m) {
221236
sm.def("proxy_auxiliaries2", [](py::array_t<double> a) {
222237
auto r = a.unchecked<2>();
223238
auto r2 = a.mutable_unchecked<2>();
224-
py::list l;
225-
l.append(*r.data(0, 0));
226-
l.append(*r2.mutable_data(0, 0));
227-
l.append(r.data(0, 1) == r2.mutable_data(0, 1));
228-
l.append(r.ndim());
229-
l.append(r.itemsize());
230-
l.append(r.shape(0));
231-
l.append(r.shape(1));
232-
l.append(r.size());
233-
l.append(r.nbytes());
234-
return l.release();
239+
return auxiliaries(r, r2);
235240
});
236241

237242
// Same as the above, but without a compile-time dimensions specification:
@@ -253,19 +258,10 @@ test_initializer numpy_array([](py::module &m) {
253258
return a;
254259
});
255260
sm.def("proxy_auxiliaries2_dyn", [](py::array_t<double> a) {
256-
auto r = a.unchecked();
257-
if (r.ndim() != 2) throw std::domain_error("error: ndim != 2");
258-
auto r2 = a.mutable_unchecked();
259-
py::list l;
260-
l.append(*r.data(0, 0));
261-
l.append(*r2.mutable_data(0, 0));
262-
l.append(r.data(0, 1) == r2.mutable_data(0, 1));
263-
l.append(r.ndim());
264-
l.append(r.itemsize());
265-
l.append(r.shape(0));
266-
l.append(r.shape(1));
267-
l.append(r.size());
268-
l.append(r.nbytes());
269-
return l.release();
261+
return auxiliaries(a.unchecked(), a.mutable_unchecked());
262+
});
263+
264+
sm.def("array_auxiliaries2", [](py::array_t<double> a) {
265+
return auxiliaries(a, a);
270266
});
271267
});

tests/test_numpy_array.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ def test_greedy_string_overload(): # issue 685
343343

344344
def test_array_unchecked_fixed_dims(msg):
345345
from pybind11_tests.array import (proxy_add2, proxy_init3F, proxy_init3, proxy_squared_L2_norm,
346-
proxy_auxiliaries2)
346+
proxy_auxiliaries2, array_auxiliaries2)
347347

348348
z1 = np.array([[1, 2], [3, 4]], dtype='float64')
349349
proxy_add2(z1, 10)
@@ -362,10 +362,12 @@ def test_array_unchecked_fixed_dims(msg):
362362
assert proxy_squared_L2_norm(np.array(range(6), dtype="float64")) == 55
363363

364364
assert proxy_auxiliaries2(z1) == [11, 11, True, 2, 8, 2, 2, 4, 32]
365+
assert proxy_auxiliaries2(z1) == array_auxiliaries2(z1)
365366

366367

367368
def test_array_unchecked_dyn_dims(msg):
368-
from pybind11_tests.array import proxy_add2_dyn, proxy_init3_dyn, proxy_auxiliaries2_dyn
369+
from pybind11_tests.array import (proxy_add2_dyn, proxy_init3_dyn, proxy_auxiliaries2_dyn,
370+
array_auxiliaries2)
369371
z1 = np.array([[1, 2], [3, 4]], dtype='float64')
370372
proxy_add2_dyn(z1, 10)
371373
assert np.all(z1 == [[11, 12], [13, 14]])
@@ -374,3 +376,4 @@ def test_array_unchecked_dyn_dims(msg):
374376
assert np.all(proxy_init3_dyn(3.0) == expect_c)
375377

376378
assert proxy_auxiliaries2_dyn(z1) == [11, 11, True, 2, 8, 2, 2, 4, 32]
379+
assert proxy_auxiliaries2_dyn(z1) == array_auxiliaries2(z1)

0 commit comments

Comments
 (0)