Skip to content

Commit ba7a0fa

Browse files
oleksandr-pavlykpre-commit-ci[bot]Skylion007
authored
Expand dtype accessors (#3868)
* Added constructor based on typenum, based on PyArray_DescrFromType Added accessors for typenum, alignment, byteorder and flags fields of PyArray_Descr struct. * Added tests for new py::dtype constructor, and for accessors * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixed the comment for alignment method * Update include/pybind11/numpy.h Co-authored-by: Aaron Gokaslan <skylion.aaron@gmail.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Aaron Gokaslan <skylion.aaron@gmail.com>
1 parent fa98804 commit ba7a0fa

File tree

3 files changed

+60
-1
lines changed

3 files changed

+60
-1
lines changed

include/pybind11/numpy.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -562,6 +562,13 @@ class dtype : public object {
562562
m_ptr = from_args(std::move(args)).release().ptr();
563563
}
564564

565+
explicit dtype(int typenum)
566+
: object(detail::npy_api::get().PyArray_DescrFromType_(typenum), stolen_t{}) {
567+
if (m_ptr == nullptr) {
568+
throw error_already_set();
569+
}
570+
}
571+
565572
/// This is essentially the same as calling numpy.dtype(args) in Python.
566573
static dtype from_args(object args) {
567574
PyObject *ptr = nullptr;
@@ -596,6 +603,23 @@ class dtype : public object {
596603
return detail::array_descriptor_proxy(m_ptr)->type;
597604
}
598605

606+
/// type number of dtype.
607+
ssize_t num() const {
608+
// Note: The signature, `dtype::num` follows the naming of NumPy's public
609+
// Python API (i.e., ``dtype.num``), rather than its internal
610+
// C API (``PyArray_Descr::type_num``).
611+
return detail::array_descriptor_proxy(m_ptr)->type_num;
612+
}
613+
614+
/// Single character for byteorder
615+
char byteorder() const { return detail::array_descriptor_proxy(m_ptr)->byteorder; }
616+
617+
/// Alignment of the data type
618+
int alignment() const { return detail::array_descriptor_proxy(m_ptr)->alignment; }
619+
620+
/// Flags for the array descriptor
621+
char flags() const { return detail::array_descriptor_proxy(m_ptr)->flags; }
622+
599623
private:
600624
static object _dtype_from_pep3118() {
601625
static PyObject *obj = module_::import("numpy.core._internal")

tests/test_numpy_dtypes.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,7 @@ py::list test_dtype_ctors() {
291291
list.append(py::dtype(names, formats, offsets, 20));
292292
list.append(py::dtype(py::buffer_info((void *) 0, sizeof(unsigned int), "I", 1)));
293293
list.append(py::dtype(py::buffer_info((void *) 0, 0, "T{i:a:f:b:}", 1)));
294+
list.append(py::dtype(py::detail::npy_api::NPY_DOUBLE_));
294295
return list;
295296
}
296297

@@ -440,6 +441,34 @@ TEST_SUBMODULE(numpy_dtypes, m) {
440441
}
441442
return list;
442443
});
444+
m.def("test_dtype_num", [dtype_names]() {
445+
py::list list;
446+
for (const auto &dt_name : dtype_names) {
447+
list.append(py::dtype(dt_name).num());
448+
}
449+
return list;
450+
});
451+
m.def("test_dtype_byteorder", [dtype_names]() {
452+
py::list list;
453+
for (const auto &dt_name : dtype_names) {
454+
list.append(py::dtype(dt_name).byteorder());
455+
}
456+
return list;
457+
});
458+
m.def("test_dtype_alignment", [dtype_names]() {
459+
py::list list;
460+
for (const auto &dt_name : dtype_names) {
461+
list.append(py::dtype(dt_name).alignment());
462+
}
463+
return list;
464+
});
465+
m.def("test_dtype_flags", [dtype_names]() {
466+
py::list list;
467+
for (const auto &dt_name : dtype_names) {
468+
list.append(py::dtype(dt_name).flags());
469+
}
470+
return list;
471+
});
443472
m.def("test_dtype_methods", []() {
444473
py::list list;
445474
auto dt1 = py::dtype::of<int32_t>();

tests/test_numpy_dtypes.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ def test_dtype(simple_dtype):
160160
d1,
161161
np.dtype("uint32"),
162162
d2,
163+
np.dtype("d"),
163164
]
164165

165166
assert m.test_dtype_methods() == [
@@ -175,8 +176,13 @@ def test_dtype(simple_dtype):
175176
np.zeros(1, m.trailing_padding_dtype())
176177
)
177178

179+
expected_chars = "bhilqBHILQefdgFDG?MmO"
178180
assert m.test_dtype_kind() == list("iiiiiuuuuuffffcccbMmO")
179-
assert m.test_dtype_char_() == list("bhilqBHILQefdgFDG?MmO")
181+
assert m.test_dtype_char_() == list(expected_chars)
182+
assert m.test_dtype_num() == [np.dtype(ch).num for ch in expected_chars]
183+
assert m.test_dtype_byteorder() == [np.dtype(ch).byteorder for ch in expected_chars]
184+
assert m.test_dtype_alignment() == [np.dtype(ch).alignment for ch in expected_chars]
185+
assert m.test_dtype_flags() == [chr(np.dtype(ch).flags) for ch in expected_chars]
180186

181187

182188
def test_recarray(simple_dtype, packed_dtype):

0 commit comments

Comments
 (0)