Skip to content

Commit 18e1bd2

Browse files
committed
Use py::detail::compare_buffer_info<T>::compare() to validate the format_descriptor<T>::format() strings.
1 parent d432ce7 commit 18e1bd2

File tree

2 files changed

+79
-66
lines changed

2 files changed

+79
-66
lines changed

tests/test_buffers.cpp

Lines changed: 31 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -14,32 +14,39 @@
1414
#include "pybind11_tests.h"
1515

1616
TEST_SUBMODULE(buffers, m) {
17-
m.def("format_descriptor_format", [](const std::string &cpp_name) {
18-
// https://google.github.io/styleguide/cppguide.html#Static_and_Global_Variables
19-
static auto *table = new std::map<std::string, std::string>;
20-
if (table->empty()) {
17+
m.attr("std_is_same_double_long_double") = std::is_same<double, long double>::value;
18+
19+
m.def("format_descriptor_format_compare",
20+
[](const std::string &cpp_name, const py::buffer &buffer) {
21+
// https://google.github.io/styleguide/cppguide.html#Static_and_Global_Variables
22+
static auto *format_table = new std::map<std::string, std::string>;
23+
static auto *compare_table
24+
= new std::map<std::string, bool (*)(const py::buffer_info &)>;
25+
if (format_table->empty()) {
2126
#define PYBIND11_ASSIGN_HELPER(...) \
22-
(*table)[#__VA_ARGS__] = py::format_descriptor<__VA_ARGS__>::format();
23-
PYBIND11_ASSIGN_HELPER(PyObject *)
24-
PYBIND11_ASSIGN_HELPER(bool)
25-
PYBIND11_ASSIGN_HELPER(std::int8_t)
26-
PYBIND11_ASSIGN_HELPER(std::uint8_t)
27-
PYBIND11_ASSIGN_HELPER(std::int16_t)
28-
PYBIND11_ASSIGN_HELPER(std::uint16_t)
29-
PYBIND11_ASSIGN_HELPER(std::int32_t)
30-
PYBIND11_ASSIGN_HELPER(std::uint32_t)
31-
PYBIND11_ASSIGN_HELPER(std::int64_t)
32-
PYBIND11_ASSIGN_HELPER(std::uint64_t)
33-
PYBIND11_ASSIGN_HELPER(float)
34-
PYBIND11_ASSIGN_HELPER(double)
35-
PYBIND11_ASSIGN_HELPER(long double)
36-
PYBIND11_ASSIGN_HELPER(std::complex<float>)
37-
PYBIND11_ASSIGN_HELPER(std::complex<double>)
38-
PYBIND11_ASSIGN_HELPER(std::complex<long double>)
27+
(*format_table)[#__VA_ARGS__] = py::format_descriptor<__VA_ARGS__>::format(); \
28+
(*compare_table)[#__VA_ARGS__] = py::detail::compare_buffer_info<__VA_ARGS__>::compare;
29+
PYBIND11_ASSIGN_HELPER(PyObject *)
30+
PYBIND11_ASSIGN_HELPER(bool)
31+
PYBIND11_ASSIGN_HELPER(std::int8_t)
32+
PYBIND11_ASSIGN_HELPER(std::uint8_t)
33+
PYBIND11_ASSIGN_HELPER(std::int16_t)
34+
PYBIND11_ASSIGN_HELPER(std::uint16_t)
35+
PYBIND11_ASSIGN_HELPER(std::int32_t)
36+
PYBIND11_ASSIGN_HELPER(std::uint32_t)
37+
PYBIND11_ASSIGN_HELPER(std::int64_t)
38+
PYBIND11_ASSIGN_HELPER(std::uint64_t)
39+
PYBIND11_ASSIGN_HELPER(float)
40+
PYBIND11_ASSIGN_HELPER(double)
41+
PYBIND11_ASSIGN_HELPER(long double)
42+
PYBIND11_ASSIGN_HELPER(std::complex<float>)
43+
PYBIND11_ASSIGN_HELPER(std::complex<double>)
44+
PYBIND11_ASSIGN_HELPER(std::complex<long double>)
3945
#undef PYBIND11_ASSIGN_HELPER
40-
}
41-
return (*table)[cpp_name];
42-
});
46+
}
47+
return std::pair<std::string, bool>((*format_table)[cpp_name],
48+
(*compare_table)[cpp_name](buffer.request()));
49+
});
4350

4451
// test_from_python / test_to_python:
4552
class Matrix {

tests/test_buffers.py

Lines changed: 48 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -10,49 +10,55 @@
1010

1111
np = pytest.importorskip("numpy")
1212

13-
if env.WIN:
14-
# Windows does not have these (see e.g. #1908). But who knows, maybe later?
15-
np_float128_or_none = getattr(np, "float128", None)
16-
np_complex256_or_none = getattr(np, "complex256", None)
13+
if m.std_is_same_double_long_double: # Windows.
14+
np_float128 = None
15+
np_complex256 = None
1716
else:
18-
np_float128_or_none = np.float128
19-
np_complex256_or_none = np.complex256
20-
21-
22-
@pytest.mark.parametrize(
23-
("cpp_name", "expected_fmts", "np_array_dtype"),
24-
[
25-
("PyObject *", ["O"], object),
26-
("bool", ["?"], np.bool_),
27-
("std::int8_t", ["b"], np.int8),
28-
("std::uint8_t", ["B"], np.uint8),
29-
("std::int16_t", ["h"], np.int16),
30-
("std::uint16_t", ["H"], np.uint16),
31-
("std::int32_t", ["i"], np.int32),
32-
("std::uint32_t", ["I"], np.uint32),
33-
("std::int64_t", ["q"], np.int64),
34-
("std::uint64_t", ["Q"], np.uint64),
35-
("float", ["f"], np.float32),
36-
("double", ["d"], np.float64),
37-
("long double", ["g", "d"], np_float128_or_none),
38-
("std::complex<float>", ["Zf"], np.complex64),
39-
("std::complex<double>", ["Zd"], np.complex128),
40-
("std::complex<long double>", ["Zg", "Zd"], np_complex256_or_none),
41-
],
42-
)
43-
def test_format_descriptor_format(cpp_name, expected_fmts, np_array_dtype):
44-
fmt = m.format_descriptor_format(cpp_name)
45-
assert fmt in expected_fmts
46-
47-
if np_array_dtype is not None:
48-
na = np.array([], dtype=np_array_dtype)
49-
bi = m.get_buffer_info(na)
50-
bif = bi.format
51-
if bif == "l":
52-
bif = "i" if bi.itemsize == 4 else "q"
53-
elif bif == "L":
54-
bif = "I" if bi.itemsize == 4 else "Q"
55-
assert bif == fmt
17+
np_float128 = np.float128
18+
np_complex256 = np.complex256
19+
20+
CPP_NAME_FORMAT_NP_DTYPE_TABLE = [
21+
item
22+
for item in [
23+
("PyObject *", "O", object),
24+
("bool", "?", np.bool_),
25+
("std::int8_t", "b", np.int8),
26+
("std::uint8_t", "B", np.uint8),
27+
("std::int16_t", "h", np.int16),
28+
("std::uint16_t", "H", np.uint16),
29+
("std::int32_t", "i", np.int32),
30+
("std::uint32_t", "I", np.uint32),
31+
("std::int64_t", "q", np.int64),
32+
("std::uint64_t", "Q", np.uint64),
33+
("float", "f", np.float32),
34+
("double", "d", np.float64),
35+
("long double", "g", np_float128),
36+
("std::complex<float>", "Zf", np.complex64),
37+
("std::complex<double>", "Zd", np.complex128),
38+
("std::complex<long double>", "Zg", np_complex256),
39+
]
40+
if item[-1] is not None
41+
]
42+
CPP_NAME_FORMAT_TABLE = [
43+
(cpp_name, format) for cpp_name, format, _ in CPP_NAME_FORMAT_NP_DTYPE_TABLE
44+
]
45+
CPP_NAME_NP_DTYPE_TABLE = [
46+
(cpp_name, np_dtype) for cpp_name, _, np_dtype in CPP_NAME_FORMAT_NP_DTYPE_TABLE
47+
]
48+
49+
50+
@pytest.mark.parametrize(("cpp_name", "np_dtype"), CPP_NAME_NP_DTYPE_TABLE)
51+
def test_format_descriptor_format_compare(cpp_name, np_dtype):
52+
np_array = np.array([], dtype=np_dtype)
53+
for other_cpp_name, expected_format in CPP_NAME_FORMAT_TABLE:
54+
format, np_array_is_matching = m.format_descriptor_format_compare(
55+
other_cpp_name, np_array
56+
)
57+
assert format == expected_format
58+
if other_cpp_name == cpp_name:
59+
assert np_array_is_matching
60+
else:
61+
assert not np_array_is_matching
5662

5763

5864
def test_from_python():

0 commit comments

Comments
 (0)