@@ -109,21 +109,21 @@ pub struct NpySingleIter<'py, T> {
109109 iterator : ptr:: NonNull < objects:: NpyIter > ,
110110 iternext : unsafe extern "C" fn ( * mut objects:: NpyIter ) -> c_int ,
111111 empty : bool ,
112+ iter_size : npy_intp ,
112113 dataptr : * mut * mut c_char ,
113114 return_type : PhantomData < T > ,
114115 _py : Python < ' py > ,
115116}
116117
117118impl < ' py , T > NpySingleIter < ' py , T > {
118- fn new ( iterator : * mut objects:: NpyIter , py : Python < ' py > ) -> PyResult < NpySingleIter < ' py , T > > {
119+ fn new ( iterator : * mut objects:: NpyIter , py : Python < ' py > ) -> PyResult < Self > {
119120 let mut iterator = match ptr:: NonNull :: new ( iterator) {
120121 Some ( iter) => iter,
121122 None => {
122123 return Err ( NpyIterInstantiationError . into ( ) ) ;
123124 }
124125 } ;
125126
126- // TODO replace the null second arg with something correct.
127127 let iternext = match unsafe { PY_ARRAY_API . NpyIter_GetIterNext ( iterator. as_mut ( ) , ptr:: null_mut ( ) ) } {
128128 Some ( ptr) => ptr,
129129 None => {
@@ -137,10 +137,13 @@ impl<'py, T> NpySingleIter<'py, T> {
137137 return Err ( NpyIterInstantiationError . into ( ) ) ;
138138 }
139139
140- Ok ( NpySingleIter {
140+ let iter_size = unsafe { PY_ARRAY_API . NpyIter_GetIterSize ( iterator. as_mut ( ) ) } ;
141+
142+ Ok ( Self {
141143 iterator,
142144 iternext,
143- empty : false , // TODO: Handle empty iterators
145+ iter_size,
146+ empty : iter_size == 0 ,
144147 dataptr,
145148 return_type : PhantomData ,
146149 _py : py,
@@ -171,6 +174,10 @@ impl<'py, T: 'py> std::iter::Iterator for NpySingleIter<'py, T> {
171174 retval
172175 }
173176 }
177+
178+ fn size_hint ( & self ) -> ( usize , Option < usize > ) {
179+ ( self . iter_size as usize , Some ( self . iter_size as usize ) )
180+ }
174181}
175182
176183mod private {
@@ -189,7 +196,7 @@ macro_rules! private_impl {
189196 } ;
190197}
191198
192- /// A combinator type that represents an terator mode (e.g., ReadOnly + ReadWrite + ReadOnly ).
199+ /// A combinator type that represents an terator mode (e.g., ReadOnly + ReadWrite).
193200pub trait MultiIterMode {
194201 private_decl ! ( ) ;
195202 type Pre : MultiIterMode ;
@@ -316,7 +323,7 @@ impl<'py, T: Element, S: MultiIterModeWithManyArrays> NpyMultiIterBuilder<'py, T
316323 )
317324 } ;
318325 let py = self . arrays [ 0 ] . py ( ) ;
319- NpyMultiIter :: new ( iter_ptr, py) . ok_or_else ( || PyErr :: fetch ( py ) )
326+ NpyMultiIter :: new ( iter_ptr, py)
320327 }
321328}
322329
@@ -332,25 +339,34 @@ pub struct NpyMultiIter<'py, T, S: MultiIterModeWithManyArrays> {
332339}
333340
334341impl < ' py , T , S : MultiIterModeWithManyArrays > NpyMultiIter < ' py , T , S > {
335- fn new ( iterator : * mut objects:: NpyIter , py : Python < ' py > ) -> Option < Self > {
336- let mut iterator = ptr:: NonNull :: new ( iterator) ?;
342+ fn new ( iterator : * mut objects:: NpyIter , py : Python < ' py > ) -> PyResult < Self > {
343+ let mut iterator = match ptr:: NonNull :: new ( iterator) {
344+ Some ( ptr) => ptr,
345+ None => {
346+ return Err ( NpyIterInstantiationError . into ( ) ) ;
347+ }
348+ } ;
337349
338- // TODO replace the null second arg with something correct.
339- let iternext =
340- unsafe { PY_ARRAY_API . NpyIter_GetIterNext ( iterator. as_mut ( ) , ptr:: null_mut ( ) ) ? } ;
350+ let iternext = match unsafe { PY_ARRAY_API . NpyIter_GetIterNext ( iterator. as_mut ( ) , ptr:: null_mut ( ) ) } {
351+ Some ( ptr) => ptr,
352+ None => {
353+ return Err ( PyErr :: fetch ( py) ) ;
354+ }
355+ } ;
341356 let dataptr = unsafe { PY_ARRAY_API . NpyIter_GetDataPtrArray ( iterator. as_mut ( ) ) } ;
342357
343358 if dataptr. is_null ( ) {
344359 unsafe { PY_ARRAY_API . NpyIter_Deallocate ( iterator. as_mut ( ) ) } ;
360+ return Err ( NpyIterInstantiationError . into ( ) ) ;
345361 }
346362
347363 let iter_size = unsafe { PY_ARRAY_API . NpyIter_GetIterSize ( iterator. as_mut ( ) ) } ;
348364
349- Some ( Self {
365+ Ok ( Self {
350366 iterator,
351367 iternext,
352368 iter_size,
353- empty : iter_size == 0 , // TODO: Handle empty iterators
369+ empty : iter_size == 0 ,
354370 dataptr,
355371 marker : PhantomData ,
356372 _py : py,
0 commit comments