1- use crate :: array:: PyArray ;
2- use crate :: npyffi;
3- use crate :: npyffi:: array:: PY_ARRAY_API ;
4- use crate :: npyffi:: objects;
5- use crate :: npyffi:: types:: { npy_uint32, NPY_CASTING , NPY_ORDER } ;
6- use pyo3:: prelude:: * ;
1+ use crate :: array:: { PyArray , PyArrayDyn } ;
2+ use crate :: npyffi:: {
3+ array:: PY_ARRAY_API ,
4+ types:: { NPY_CASTING , NPY_ORDER } ,
5+ * ,
6+ } ;
7+ use crate :: types:: TypeNum ;
8+ use pyo3:: { prelude:: * , PyNativeType } ;
79
810use std:: marker:: PhantomData ;
911use std:: os:: raw:: * ;
1012use std:: ptr;
1113
12- pub enum NPyIterFlag {
14+ #[ derive( Clone , Copy , Debug , Eq , PartialEq ) ]
15+ pub enum NpyIterFlag {
1316 CIndex ,
1417 FIndex ,
1518 MultiIndex ,
@@ -24,105 +27,71 @@ pub enum NPyIterFlag {
2427 DelayBufAlloc ,
2528 DontNegateStrides ,
2629 CopyIfOverlap ,
30+ ReadWrite ,
31+ ReadOnly ,
32+ WriteOnly ,
2733}
2834
29- /*
30-
31- #define NPY_ITER_C_INDEX 0x00000001
32- #define NPY_ITER_F_INDEX 0x00000002
33- #define NPY_ITER_MULTI_INDEX 0x00000004
34- #define NPY_ITER_EXTERNAL_LOOP 0x00000008
35- #define NPY_ITER_COMMON_DTYPE 0x00000010
36- #define NPY_ITER_REFS_OK 0x00000020
37- #define NPY_ITER_ZEROSIZE_OK 0x00000040
38- #define NPY_ITER_REDUCE_OK 0x00000080
39- #define NPY_ITER_RANGED 0x00000100
40- #define NPY_ITER_BUFFERED 0x00000200
41- #define NPY_ITER_GROWINNER 0x00000400
42- #define NPY_ITER_DELAY_BUFALLOC 0x00000800
43- #define NPY_ITER_DONT_NEGATE_STRIDES 0x00001000
44- #define NPY_ITER_COPY_IF_OVERLAP 0x00002000
45- #define NPY_ITER_READWRITE 0x00010000
46- #define NPY_ITER_READONLY 0x00020000
47- #define NPY_ITER_WRITEONLY 0x00040000
48- #define NPY_ITER_NBO 0x00080000
49- #define NPY_ITER_ALIGNED 0x00100000
50- #define NPY_ITER_CONTIG 0x00200000
51- #define NPY_ITER_COPY 0x00400000
52- #define NPY_ITER_UPDATEIFCOPY 0x00800000
53- #define NPY_ITER_ALLOCATE 0x01000000
54- #define NPY_ITER_NO_SUBTYPE 0x02000000
55- #define NPY_ITER_VIRTUAL 0x04000000
56- #define NPY_ITER_NO_BROADCAST 0x08000000
57- #define NPY_ITER_WRITEMASKED 0x10000000
58- #define NPY_ITER_ARRAYMASK 0x20000000
59- #define NPY_ITER_OVERLAP_ASSUME_ELEMENTWISE 0x40000000
60-
61- #define NPY_ITER_GLOBAL_FLAGS 0x0000ffff
62- #define NPY_ITER_PER_OP_FLAGS 0xffff0000
63-
64- */
65-
66- impl NPyIterFlag {
35+ impl NpyIterFlag {
6736 fn to_c_enum ( & self ) -> npy_uint32 {
68- use NPyIterFlag :: * ;
37+ use NpyIterFlag :: * ;
6938 match self {
70- CIndex => 0x00000001 ,
71- FIndex => 0x00000002 ,
72- MultiIndex => 0x00000004 ,
73- ExternalLoop => 0x00000008 ,
74- CommonDtype => 0x00000010 ,
75- RefsOk => 0x00000020 ,
76- ZerosizeOk => 0x00000040 ,
77- ReduceOk => 0x00000080 ,
78- Ranged => 0x00000100 ,
79- Buffered => 0x00000200 ,
80- GrowInner => 0x00000400 ,
81- DelayBufAlloc => 0x00000800 ,
82- DontNegateStrides => 0x00001000 ,
83- CopyIfOverlap => 0x00002000 ,
39+ CIndex => NPY_ITER_C_INDEX ,
40+ FIndex => NPY_ITER_C_INDEX ,
41+ MultiIndex => NPY_ITER_MULTI_INDEX ,
42+ ExternalLoop => NPY_ITER_EXTERNAL_LOOP ,
43+ CommonDtype => NPY_ITER_COMMON_DTYPE ,
44+ RefsOk => NPY_ITER_REFS_OK ,
45+ ZerosizeOk => NPY_ITER_ZEROSIZE_OK ,
46+ ReduceOk => NPY_ITER_REDUCE_OK ,
47+ Ranged => NPY_ITER_RANGED ,
48+ Buffered => NPY_ITER_BUFFERED ,
49+ GrowInner => NPY_ITER_GROWINNER ,
50+ DelayBufAlloc => NPY_ITER_DELAY_BUFALLOC ,
51+ DontNegateStrides => NPY_ITER_DONT_NEGATE_STRIDES ,
52+ CopyIfOverlap => NPY_ITER_COPY_IF_OVERLAP ,
53+ ReadWrite => NPY_ITER_READWRITE ,
54+ ReadOnly => NPY_ITER_READONLY ,
55+ WriteOnly => NPY_ITER_WRITEONLY ,
8456 }
8557 }
8658}
8759
8860pub struct NpyIterBuilder < ' py , T > {
8961 flags : npy_uint32 ,
90- array : * mut npyffi:: PyArrayObject ,
91- py : Python < ' py > ,
92- return_type : PhantomData < T > ,
62+ array : & ' py PyArrayDyn < T > ,
9363}
9464
95- impl < ' py , T > NpyIterBuilder < ' py , T > {
96- pub fn new < D > ( array : PyArray < T , D > , py : Python < ' py > ) -> NpyIterBuilder < ' py , T > {
65+ impl < ' py , T : TypeNum > NpyIterBuilder < ' py , T > {
66+ pub fn new < D : ndarray :: Dimension > ( array : & ' py PyArray < T , D > ) -> NpyIterBuilder < ' py , T > {
9767 NpyIterBuilder {
98- array : array. as_array_ptr ( ) ,
99- py,
10068 flags : 0 ,
101- return_type : PhantomData ,
69+ array : array . into_dyn ( ) ,
10270 }
10371 }
10472
105- pub fn set_iter_flags ( & mut self , flag : NPyIterFlag , value : bool ) -> & mut Self {
106- if value {
107- self . flags |= flag. to_c_enum ( ) ;
108- } else {
109- self . flags &= !flag. to_c_enum ( ) ;
110- }
73+ pub fn add ( mut self , flag : NpyIterFlag ) -> Self {
74+ self . flags |= flag. to_c_enum ( ) ;
11175 self
11276 }
11377
114- pub fn finish ( self ) -> Option < NpyIterSingleArray < ' py , T > > {
78+ pub fn remove ( mut self , flag : NpyIterFlag ) -> Self {
79+ self . flags &= !flag. to_c_enum ( ) ;
80+ self
81+ }
82+
83+ pub fn build ( self ) -> PyResult < NpyIterSingleArray < ' py , T > > {
11584 let iter_ptr = unsafe {
11685 PY_ARRAY_API . NpyIter_New (
117- self . array ,
86+ self . array . as_array_ptr ( ) ,
11887 self . flags ,
11988 NPY_ORDER :: NPY_ANYORDER ,
12089 NPY_CASTING :: NPY_SAFE_CASTING ,
12190 ptr:: null_mut ( ) ,
12291 )
12392 } ;
124-
125- NpyIterSingleArray :: new ( iter_ptr, self . py )
93+ let py = self . array . py ( ) ;
94+ NpyIterSingleArray :: new ( iter_ptr, py ) . ok_or_else ( || PyErr :: fetch ( py ) )
12695 }
12796}
12897
0 commit comments