1- use crate :: npyffi:: { NpyTypes , PyArray_Descr , PY_ARRAY_API } ;
1+ //! Implements conversion utitlities.
2+ use crate :: npyffi:: { NpyTypes , PyArray_Descr , NPY_TYPES , PY_ARRAY_API } ;
3+ pub use num_complex:: Complex32 as c32;
4+ pub use num_complex:: Complex64 as c64;
25use pyo3:: ffi;
36use pyo3:: prelude:: * ;
7+ use pyo3:: types:: PyType ;
8+ use pyo3:: { AsPyPointer , PyNativeType } ;
49use std:: os:: raw:: c_int;
510
611pub struct PyArrayDescr ( PyAny ) ;
@@ -21,3 +26,179 @@ unsafe fn arraydescr_check(op: *mut ffi::PyObject) -> c_int {
2126 PY_ARRAY_API . get_type_object ( NpyTypes :: PyArrayDescr_Type ) ,
2227 )
2328}
29+
30+ impl PyArrayDescr {
31+ pub fn as_dtype_ptr ( & self ) -> * mut PyArray_Descr {
32+ self . as_ptr ( ) as _
33+ }
34+
35+ pub fn get_type ( & self ) -> & PyType {
36+ let dtype_type_ptr = unsafe { * self . as_dtype_ptr ( ) } . typeobj ;
37+ unsafe { PyType :: from_type_ptr ( self . py ( ) , dtype_type_ptr) }
38+ }
39+
40+ pub fn get_typenum ( & self ) -> std:: os:: raw:: c_int {
41+ unsafe { * self . as_dtype_ptr ( ) } . type_num
42+ }
43+
44+ pub fn get_datatype ( & self ) -> Option < DataType > {
45+ DataType :: from_typenum ( self . get_typenum ( ) )
46+ }
47+
48+ pub fn from_npy_type ( py : Python , npy_type : NPY_TYPES ) -> & Self {
49+ unsafe {
50+ let descr = PY_ARRAY_API . PyArray_DescrFromType ( npy_type as i32 ) ;
51+ py. from_owned_ptr ( descr as _ )
52+ }
53+ }
54+ }
55+
56+ /// An enum type represents numpy data type.
57+ ///
58+ /// This type is mainly for displaying error, and user don't have to use it directly.
59+ #[ derive( Clone , Debug , Eq , PartialEq ) ]
60+ pub enum DataType {
61+ Bool ,
62+ Int8 ,
63+ Int16 ,
64+ Int32 ,
65+ Int64 ,
66+ Uint8 ,
67+ Uint16 ,
68+ Uint32 ,
69+ Uint64 ,
70+ Float32 ,
71+ Float64 ,
72+ Complex32 ,
73+ Complex64 ,
74+ Object ,
75+ }
76+
77+ impl DataType {
78+ pub fn from_typenum ( typenum : c_int ) -> Option < Self > {
79+ Some ( match typenum {
80+ x if x == NPY_TYPES :: NPY_BOOL as i32 => DataType :: Bool ,
81+ x if x == NPY_TYPES :: NPY_BYTE as i32 => DataType :: Int8 ,
82+ x if x == NPY_TYPES :: NPY_SHORT as i32 => DataType :: Int16 ,
83+ x if x == NPY_TYPES :: NPY_INT as i32 => DataType :: Int32 ,
84+ x if x == NPY_TYPES :: NPY_LONG as i32 => return DataType :: from_clong ( false ) ,
85+ x if x == NPY_TYPES :: NPY_LONGLONG as i32 => DataType :: Int64 ,
86+ x if x == NPY_TYPES :: NPY_UBYTE as i32 => DataType :: Uint8 ,
87+ x if x == NPY_TYPES :: NPY_USHORT as i32 => DataType :: Uint16 ,
88+ x if x == NPY_TYPES :: NPY_UINT as i32 => DataType :: Uint32 ,
89+ x if x == NPY_TYPES :: NPY_ULONG as i32 => return DataType :: from_clong ( true ) ,
90+ x if x == NPY_TYPES :: NPY_ULONGLONG as i32 => DataType :: Uint64 ,
91+ x if x == NPY_TYPES :: NPY_FLOAT as i32 => DataType :: Float32 ,
92+ x if x == NPY_TYPES :: NPY_DOUBLE as i32 => DataType :: Float64 ,
93+ x if x == NPY_TYPES :: NPY_CFLOAT as i32 => DataType :: Complex32 ,
94+ x if x == NPY_TYPES :: NPY_CDOUBLE as i32 => DataType :: Complex64 ,
95+ x if x == NPY_TYPES :: NPY_OBJECT as i32 => DataType :: Object ,
96+ _ => return None ,
97+ } )
98+ }
99+
100+ pub fn from_dtype ( dtype : & crate :: PyArrayDescr ) -> Option < Self > {
101+ Self :: from_typenum ( dtype. get_typenum ( ) )
102+ }
103+
104+ #[ inline]
105+ pub fn into_ctype ( self ) -> NPY_TYPES {
106+ match self {
107+ DataType :: Bool => NPY_TYPES :: NPY_BOOL ,
108+ DataType :: Int8 => NPY_TYPES :: NPY_BYTE ,
109+ DataType :: Int16 => NPY_TYPES :: NPY_SHORT ,
110+ DataType :: Int32 => NPY_TYPES :: NPY_INT ,
111+ DataType :: Int64 => NPY_TYPES :: NPY_LONGLONG ,
112+ DataType :: Uint8 => NPY_TYPES :: NPY_UBYTE ,
113+ DataType :: Uint16 => NPY_TYPES :: NPY_USHORT ,
114+ DataType :: Uint32 => NPY_TYPES :: NPY_UINT ,
115+ DataType :: Uint64 => NPY_TYPES :: NPY_ULONGLONG ,
116+ DataType :: Float32 => NPY_TYPES :: NPY_FLOAT ,
117+ DataType :: Float64 => NPY_TYPES :: NPY_DOUBLE ,
118+ DataType :: Complex32 => NPY_TYPES :: NPY_CFLOAT ,
119+ DataType :: Complex64 => NPY_TYPES :: NPY_CDOUBLE ,
120+ DataType :: Object => NPY_TYPES :: NPY_OBJECT ,
121+ }
122+ }
123+
124+ #[ inline( always) ]
125+ fn from_clong ( is_usize : bool ) -> Option < Self > {
126+ if cfg ! ( any( target_pointer_width = "32" , windows) ) {
127+ Some ( if is_usize {
128+ DataType :: Uint32
129+ } else {
130+ DataType :: Int32
131+ } )
132+ } else if cfg ! ( all( target_pointer_width = "64" , not( windows) ) ) {
133+ Some ( if is_usize {
134+ DataType :: Uint64
135+ } else {
136+ DataType :: Int64
137+ } )
138+ } else {
139+ None
140+ }
141+ }
142+ }
143+
144+ /// Represents that a type can be an element of `PyArray`.
145+ pub trait Element : Clone {
146+ const DATA_TYPE : DataType ;
147+
148+ fn is_same_type ( dtype : & PyArrayDescr ) -> bool ;
149+
150+ #[ inline]
151+ fn npy_type ( ) -> NPY_TYPES {
152+ Self :: DATA_TYPE . into_ctype ( )
153+ }
154+
155+ fn get_dtype ( py : Python ) -> & PyArrayDescr {
156+ PyArrayDescr :: from_npy_type ( py, Self :: npy_type ( ) )
157+ }
158+ }
159+
160+ macro_rules! impl_num_element {
161+ ( $t: ty, $npy_dat_t: ident $( , $npy_types: ident) +) => {
162+ impl Element for $t {
163+ const DATA_TYPE : DataType = DataType :: $npy_dat_t;
164+ fn is_same_type( dtype: & PyArrayDescr ) -> bool {
165+ $( dtype. get_typenum( ) == NPY_TYPES :: $npy_types as i32 ||) + false
166+ }
167+ }
168+ } ;
169+ }
170+
171+ impl_num_element ! ( bool , Bool , NPY_BOOL ) ;
172+ impl_num_element ! ( i8 , Int8 , NPY_BYTE ) ;
173+ impl_num_element ! ( i16 , Int16 , NPY_SHORT ) ;
174+ impl_num_element ! ( u8 , Uint8 , NPY_UBYTE ) ;
175+ impl_num_element ! ( u16 , Uint16 , NPY_USHORT ) ;
176+ impl_num_element ! ( f32 , Float32 , NPY_FLOAT ) ;
177+ impl_num_element ! ( f64 , Float64 , NPY_DOUBLE ) ;
178+ impl_num_element ! ( c32, Complex32 , NPY_CFLOAT ) ;
179+ impl_num_element ! ( c64, Complex64 , NPY_CDOUBLE ) ;
180+
181+ cfg_if ! {
182+ if #[ cfg( all( target_pointer_width = "64" , windows) ) ] {
183+ impl_num_element!( usize , Uint64 , NPY_ULONGLONG ) ;
184+ } else if #[ cfg( all( target_pointer_width = "64" , not( windows) ) ) ] {
185+ impl_num_element!( usize , Uint64 , NPY_ULONG , NPY_ULONGLONG ) ;
186+ } else if #[ cfg( all( target_pointer_width = "32" , windows) ) ] {
187+ impl_num_element!( usize , Uint32 , NPY_UINT , NPY_ULONG ) ;
188+ } else if #[ cfg( all( target_pointer_width = "32" , not( windows) ) ) ] {
189+ impl_num_element!( usize , Uint32 , NPY_UINT ) ;
190+ }
191+ }
192+ cfg_if ! {
193+ if #[ cfg( any( target_pointer_width = "32" , windows) ) ] {
194+ impl_num_element!( i32 , Int32 , NPY_INT , NPY_LONG ) ;
195+ impl_num_element!( u32 , Uint32 , NPY_UINT , NPY_ULONG ) ;
196+ impl_num_element!( i64 , Int64 , NPY_LONGLONG ) ;
197+ impl_num_element!( u64 , Uint64 , NPY_ULONGLONG ) ;
198+ } else if #[ cfg( all( target_pointer_width = "64" , not( windows) ) ) ] {
199+ impl_num_element!( i32 , Int32 , NPY_INT ) ;
200+ impl_num_element!( u32 , Uint32 , NPY_UINT ) ;
201+ impl_num_element!( i64 , Int64 , NPY_LONG , NPY_LONGLONG ) ;
202+ impl_num_element!( u64 , Uint64 , NPY_ULONG , NPY_ULONGLONG ) ;
203+ }
204+ }
0 commit comments