@@ -4,7 +4,7 @@ use crate::npyffi::{
44 types:: { NPY_CASTING , NPY_ORDER } ,
55 * ,
66} ;
7- use crate :: types:: TypeNum ;
7+ use crate :: types:: Element ;
88use pyo3:: { prelude:: * , PyNativeType } ;
99
1010use std:: marker:: PhantomData ;
@@ -13,10 +13,11 @@ use std::ptr;
1313
1414#[ derive( Clone , Copy , Debug , Eq , PartialEq ) ]
1515pub enum NpyIterFlag {
16- CIndex ,
16+ /* CIndex,
1717 FIndex,
18- MultiIndex ,
19- ExternalLoop ,
18+ MultiIndex, */
19+ // ExternalLoop, // This flag greatly modifies the behaviour of accessing the data
20+ // so we don't support it.
2021 CommonDtype ,
2122 RefsOk ,
2223 ZerosizeOk ,
@@ -27,19 +28,19 @@ pub enum NpyIterFlag {
2728 DelayBufAlloc ,
2829 DontNegateStrides ,
2930 CopyIfOverlap ,
30- ReadWrite ,
31+ /* ReadWrite,
3132 ReadOnly,
32- WriteOnly ,
33+ WriteOnly, */
3334}
3435
3536impl NpyIterFlag {
3637 fn to_c_enum ( & self ) -> npy_uint32 {
3738 use NpyIterFlag :: * ;
3839 match self {
39- CIndex => NPY_ITER_C_INDEX ,
40+ /* CIndex => NPY_ITER_C_INDEX,
4041 FIndex => NPY_ITER_C_INDEX,
41- MultiIndex => NPY_ITER_MULTI_INDEX ,
42- ExternalLoop => NPY_ITER_EXTERNAL_LOOP ,
42+ MultiIndex => NPY_ITER_MULTI_INDEX, */
43+ /* ExternalLoop => NPY_ITER_EXTERNAL_LOOP, */
4344 CommonDtype => NPY_ITER_COMMON_DTYPE ,
4445 RefsOk => NPY_ITER_REFS_OK ,
4546 ZerosizeOk => NPY_ITER_ZEROSIZE_OK ,
@@ -50,9 +51,9 @@ impl NpyIterFlag {
5051 DelayBufAlloc => NPY_ITER_DELAY_BUFALLOC ,
5152 DontNegateStrides => NPY_ITER_DONT_NEGATE_STRIDES ,
5253 CopyIfOverlap => NPY_ITER_COPY_IF_OVERLAP ,
53- ReadWrite => NPY_ITER_READWRITE ,
54+ /* ReadWrite => NPY_ITER_READWRITE,
5455 ReadOnly => NPY_ITER_READONLY,
55- WriteOnly => NPY_ITER_WRITEONLY ,
56+ WriteOnly => NPY_ITER_WRITEONLY, */
5657 }
5758 }
5859}
@@ -62,20 +63,22 @@ pub struct NpyIterBuilder<'py, T> {
6263 array : & ' py PyArrayDyn < T > ,
6364}
6465
65- impl < ' py , T : TypeNum > NpyIterBuilder < ' py , T > {
66- pub fn new < D : ndarray:: Dimension > ( array : & ' py PyArray < T , D > ) -> NpyIterBuilder < ' py , T > {
66+ impl < ' py , T : Element > NpyIterBuilder < ' py , T > {
67+ pub fn readwrite < D : ndarray:: Dimension > ( array : & ' py PyArray < T , D > ) -> NpyIterBuilder < ' py , T > {
6768 NpyIterBuilder {
68- flags : 0 ,
69- array : array. into_dyn ( ) ,
69+ flags : NPY_ITER_READWRITE ,
70+ array : array. to_dyn ( ) ,
7071 }
7172 }
7273
73- pub fn set ( mut self , flag : NpyIterFlag ) -> Self {
74- if flag == NpyIterFlag :: ExternalLoop {
75- // TODO: I don't want to make set fallible, but also we don't want to
76- // support ExternalLoop yet (maybe ever?).
77- panic ! ( "rust-numpy does not currently support ExternalLoop access" ) ;
74+ pub fn readonly < D : ndarray:: Dimension > ( array : & ' py PyArray < T , D > ) -> NpyIterBuilder < ' py , T > {
75+ NpyIterBuilder {
76+ flags : NPY_ITER_READONLY ,
77+ array : array. to_dyn ( ) ,
7878 }
79+ }
80+
81+ pub fn set ( mut self , flag : NpyIterFlag ) -> Self {
7982 self . flags |= flag. to_c_enum ( ) ;
8083 self
8184 }
@@ -191,7 +194,7 @@ pub struct NpyMultiIterBuilder<'py, T, S: MultiIterMode> {
191194 structure : PhantomData < S > ,
192195}
193196
194- impl < ' py , T : TypeNum > NpyMultiIterBuilder < ' py , T , ( ) > {
197+ impl < ' py , T : Element > NpyMultiIterBuilder < ' py , T , ( ) > {
195198 pub fn new ( ) -> Self {
196199 Self {
197200 flags : 0 ,
@@ -202,11 +205,6 @@ impl<'py, T: TypeNum> NpyMultiIterBuilder<'py, T, ()> {
202205 }
203206
204207 pub fn set ( mut self , flag : NpyIterFlag ) -> Self {
205- if flag == NpyIterFlag :: ExternalLoop {
206- // TODO: I don't want to make set fallible, but also we don't want to
207- // support ExternalLoop yet (maybe ever?).
208- panic ! ( "rust-numpy does not currently support ExternalLoop access" ) ;
209- }
210208 self . flags |= flag. to_c_enum ( ) ;
211209 self
212210 }
@@ -217,12 +215,12 @@ impl<'py, T: TypeNum> NpyMultiIterBuilder<'py, T, ()> {
217215 }
218216}
219217
220- impl < ' py , T : TypeNum , S : MultiIterMode > NpyMultiIterBuilder < ' py , T , S > {
218+ impl < ' py , T : Element , S : MultiIterMode > NpyMultiIterBuilder < ' py , T , S > {
221219 pub fn add_readonly_array < D : ndarray:: Dimension > (
222220 mut self ,
223221 array : & ' py PyArray < T , D > ,
224222 ) -> NpyMultiIterBuilder < ' py , T , RO < S > > {
225- self . arrays . push ( array. into_dyn ( ) ) ;
223+ self . arrays . push ( array. to_dyn ( ) ) ;
226224 self . opflags . push ( NPY_ITER_READONLY ) ;
227225
228226 NpyMultiIterBuilder {
@@ -237,7 +235,7 @@ impl<'py, T: TypeNum, S: MultiIterMode> NpyMultiIterBuilder<'py, T, S> {
237235 mut self ,
238236 array : & ' py PyArray < T , D > ,
239237 ) -> NpyMultiIterBuilder < ' py , T , RW < S > > {
240- self . arrays . push ( array. into_dyn ( ) ) ;
238+ self . arrays . push ( array. to_dyn ( ) ) ;
241239 self . opflags . push ( NPY_ITER_READWRITE ) ;
242240
243241 NpyMultiIterBuilder {
@@ -249,7 +247,7 @@ impl<'py, T: TypeNum, S: MultiIterMode> NpyMultiIterBuilder<'py, T, S> {
249247 }
250248}
251249
252- impl < ' py , T : TypeNum , S : MultiIterModeHasManyArrays > NpyMultiIterBuilder < ' py , T , S > {
250+ impl < ' py , T : Element , S : MultiIterModeHasManyArrays > NpyMultiIterBuilder < ' py , T , S > {
253251 pub fn build ( mut self ) -> PyResult < NpyMultiIterArray < ' py , T , S > > {
254252 assert ! ( self . arrays. len( ) == self . opflags. len( ) ) ;
255253 assert ! ( self . arrays. len( ) <= i32 :: MAX as usize ) ;
@@ -279,6 +277,7 @@ pub struct NpyMultiIterArray<'py, T, S: MultiIterModeHasManyArrays> {
279277 iterator : ptr:: NonNull < objects:: NpyIter > ,
280278 iternext : unsafe extern "C" fn ( * mut objects:: NpyIter ) -> c_int ,
281279 empty : bool ,
280+ iter_size : npy_intp ,
282281 dataptr : * mut * mut c_char ,
283282
284283 return_type : PhantomData < T > ,
@@ -298,11 +297,14 @@ impl<'py, T, S: MultiIterModeHasManyArrays> NpyMultiIterArray<'py, T, S> {
298297 if dataptr. is_null ( ) {
299298 unsafe { PY_ARRAY_API . NpyIter_Deallocate ( iterator. as_mut ( ) ) } ;
300299 }
300+
301+ let iter_size = unsafe { PY_ARRAY_API . NpyIter_GetIterSize ( iterator. as_mut ( ) ) } ;
301302
302303 Some ( Self {
303304 iterator,
304305 iternext,
305- empty : false , // TODO: Handle empty iterators
306+ iter_size,
307+ empty : iter_size != 0 , // TODO: Handle empty iterators
306308 dataptr,
307309 return_type : PhantomData ,
308310 structure : PhantomData ,
@@ -339,6 +341,10 @@ impl<'py, T: 'py> std::iter::Iterator for NpyMultiIterArray<'py, T, $arg> {
339341 retval
340342 }
341343 }
344+
345+ fn size_hint( & self ) -> ( usize , Option <usize >) {
346+ ( self . iter_size as usize , Some ( self . iter_size as usize ) )
347+ }
342348}
343349 }
344350}
0 commit comments