@@ -158,3 +158,254 @@ impl<'py, T: 'py> std::iter::Iterator for NpyIterSingleArray<'py, T> {
158158 }
159159 }
160160}
161+
162+ pub trait MultiIterMode { }
163+
164+ impl MultiIterMode for ( ) { }
165+
166+ pub struct RO < S > {
167+ structure : PhantomData < S > ,
168+ }
169+
170+ impl < S : MultiIterMode > MultiIterMode for RO < S > { }
171+
172+ pub struct RW < S > {
173+ structure : PhantomData < S > ,
174+ }
175+
176+ impl < S : MultiIterMode > MultiIterMode for RW < S > { }
177+
178+ pub trait MultiIterModeHasManyArrays : MultiIterMode { }
179+ impl MultiIterModeHasManyArrays for RO < RO < ( ) > > { }
180+ impl MultiIterModeHasManyArrays for RO < RW < ( ) > > { }
181+ impl MultiIterModeHasManyArrays for RW < RO < ( ) > > { }
182+ impl MultiIterModeHasManyArrays for RW < RW < ( ) > > { }
183+
184+ impl < S : MultiIterModeHasManyArrays > MultiIterModeHasManyArrays for RO < S > { }
185+ impl < S : MultiIterModeHasManyArrays > MultiIterModeHasManyArrays for RW < S > { }
186+
187+ pub struct NpyMultiIterBuilder < ' py , T , S : MultiIterMode > {
188+ flags : npy_uint32 ,
189+ opflags : Vec < npy_uint32 > ,
190+ arrays : Vec < & ' py PyArrayDyn < T > > ,
191+ structure : PhantomData < S > ,
192+ }
193+
194+ impl < ' py , T : TypeNum > NpyMultiIterBuilder < ' py , T , ( ) > {
195+ pub fn new ( ) -> Self {
196+ Self {
197+ flags : 0 ,
198+ opflags : Vec :: new ( ) ,
199+ arrays : Vec :: new ( ) ,
200+ structure : PhantomData ,
201+ }
202+ }
203+
204+ pub fn set ( mut self , flag : NpyIterFlag ) -> Self {
205+ if flag == NpyIterFlag :: ExternalLoop {
206+ // TODO: I don't want to make set fallible, but also we don't want to
207+ // support ExternalLoop yet (maybe ever?).
208+ panic ! ( "rust-numpy does not currently support ExternalLoop access" ) ;
209+ }
210+ self . flags |= flag. to_c_enum ( ) ;
211+ self
212+ }
213+
214+ pub fn unset ( mut self , flag : NpyIterFlag ) -> Self {
215+ self . flags &= !flag. to_c_enum ( ) ;
216+ self
217+ }
218+ }
219+
220+ impl < ' py , T : TypeNum , S : MultiIterMode > NpyMultiIterBuilder < ' py , T , S > {
221+ pub fn add_readonly_array < D : ndarray:: Dimension > (
222+ mut self ,
223+ array : & ' py PyArray < T , D > ,
224+ ) -> NpyMultiIterBuilder < ' py , T , RO < S > > {
225+ self . arrays . push ( array. into_dyn ( ) ) ;
226+ self . opflags . push ( NPY_ITER_READONLY ) ;
227+
228+ NpyMultiIterBuilder {
229+ flags : self . flags ,
230+ opflags : self . opflags ,
231+ arrays : self . arrays ,
232+ structure : PhantomData ,
233+ }
234+ }
235+
236+ pub fn add_readwrite_array < D : ndarray:: Dimension > (
237+ mut self ,
238+ array : & ' py PyArray < T , D > ,
239+ ) -> NpyMultiIterBuilder < ' py , T , RW < S > > {
240+ self . arrays . push ( array. into_dyn ( ) ) ;
241+ self . opflags . push ( NPY_ITER_READWRITE ) ;
242+
243+ NpyMultiIterBuilder {
244+ flags : self . flags ,
245+ opflags : self . opflags ,
246+ arrays : self . arrays ,
247+ structure : PhantomData ,
248+ }
249+ }
250+ }
251+
252+ impl < ' py , T : TypeNum , S : MultiIterModeHasManyArrays > NpyMultiIterBuilder < ' py , T , S > {
253+ pub fn build ( mut self ) -> PyResult < NpyMultiIterArray < ' py , T , S > > {
254+ assert ! ( self . arrays. len( ) == self . opflags. len( ) ) ;
255+ assert ! ( self . arrays. len( ) <= i32 :: MAX as usize ) ;
256+ assert ! ( 2 <= self . arrays. len( ) ) ;
257+
258+ let iter_ptr = unsafe {
259+ PY_ARRAY_API . NpyIter_MultiNew (
260+ self . arrays . len ( ) as i32 ,
261+ self . arrays
262+ . iter_mut ( )
263+ . map ( |x| x. as_array_ptr ( ) )
264+ . collect :: < Vec < _ > > ( )
265+ . as_mut_ptr ( ) ,
266+ self . flags ,
267+ NPY_ORDER :: NPY_ANYORDER ,
268+ NPY_CASTING :: NPY_SAFE_CASTING ,
269+ self . opflags . as_mut_ptr ( ) ,
270+ ptr:: null_mut ( ) ,
271+ )
272+ } ;
273+ let py = self . arrays [ 0 ] . py ( ) ;
274+ NpyMultiIterArray :: new ( iter_ptr, py) . ok_or_else ( || PyErr :: fetch ( py) )
275+ }
276+ }
277+
278+ pub struct NpyMultiIterArray < ' py , T , S : MultiIterModeHasManyArrays > {
279+ iterator : ptr:: NonNull < objects:: NpyIter > ,
280+ iternext : unsafe extern "C" fn ( * mut objects:: NpyIter ) -> c_int ,
281+ empty : bool ,
282+ dataptr : * mut * mut c_char ,
283+
284+ return_type : PhantomData < T > ,
285+ structure : PhantomData < S > ,
286+ _py : Python < ' py > ,
287+ }
288+
289+ impl < ' py , T , S : MultiIterModeHasManyArrays > NpyMultiIterArray < ' py , T , S > {
290+ fn new ( iterator : * mut objects:: NpyIter , py : Python < ' py > ) -> Option < Self > {
291+ let mut iterator = ptr:: NonNull :: new ( iterator) ?;
292+
293+ // TODO replace the null second arg with something correct.
294+ let iternext =
295+ unsafe { PY_ARRAY_API . NpyIter_GetIterNext ( iterator. as_mut ( ) , ptr:: null_mut ( ) ) ? } ;
296+ let dataptr = unsafe { PY_ARRAY_API . NpyIter_GetDataPtrArray ( iterator. as_mut ( ) ) } ;
297+
298+ if dataptr. is_null ( ) {
299+ unsafe { PY_ARRAY_API . NpyIter_Deallocate ( iterator. as_mut ( ) ) } ;
300+ }
301+
302+ Some ( Self {
303+ iterator,
304+ iternext,
305+ empty : false , // TODO: Handle empty iterators
306+ dataptr,
307+ return_type : PhantomData ,
308+ structure : PhantomData ,
309+ _py : py,
310+ } )
311+ }
312+ }
313+
314+ impl < ' py , T , S : MultiIterModeHasManyArrays > Drop for NpyMultiIterArray < ' py , T , S > {
315+ fn drop ( & mut self ) {
316+ let _success = unsafe { PY_ARRAY_API . NpyIter_Deallocate ( self . iterator . as_mut ( ) ) } ;
317+ // TODO: Handle _success somehow?
318+ }
319+ }
320+
321+ impl < ' py , T : ' py > std:: iter:: Iterator for NpyMultiIterArray < ' py , T , RO < RO < ( ) > > > {
322+ type Item = ( & ' py T , & ' py T ) ;
323+
324+ fn next ( & mut self ) -> Option < Self :: Item > {
325+ if self . empty {
326+ None
327+ } else {
328+ // Note: This pointer is correct and doesn't need to be updated,
329+ // note that we're derefencing a **char into a *char casting to a *T
330+ // and then transforming that into a reference, the value that dataptr
331+ // points to is being updated by iternext to point to the next value.
332+ let retval = Some ( unsafe {
333+ (
334+ & * ( * self . dataptr as * mut T ) ,
335+ & * ( * self . dataptr . offset ( 1 ) as * mut T ) ,
336+ )
337+ } ) ;
338+ self . empty = unsafe { ( self . iternext ) ( self . iterator . as_mut ( ) ) } == 0 ;
339+ retval
340+ }
341+ }
342+ }
343+
344+ impl < ' py , T : ' py > std:: iter:: Iterator for NpyMultiIterArray < ' py , T , RO < RW < ( ) > > > {
345+ type Item = ( & ' py mut T , & ' py T ) ;
346+
347+ fn next ( & mut self ) -> Option < Self :: Item > {
348+ if self . empty {
349+ None
350+ } else {
351+ // Note: This pointer is correct and doesn't need to be updated,
352+ // note that we're derefencing a **char into a *char casting to a *T
353+ // and then transforming that into a reference, the value that dataptr
354+ // points to is being updated by iternext to point to the next value.
355+ let retval = Some ( unsafe {
356+ (
357+ & mut * ( * self . dataptr as * mut T ) ,
358+ & * ( * self . dataptr . offset ( 1 ) as * mut T ) ,
359+ )
360+ } ) ;
361+ self . empty = unsafe { ( self . iternext ) ( self . iterator . as_mut ( ) ) } == 0 ;
362+ retval
363+ }
364+ }
365+ }
366+
367+ impl < ' py , T : ' py > std:: iter:: Iterator for NpyMultiIterArray < ' py , T , RW < RO < ( ) > > > {
368+ type Item = ( & ' py T , & ' py mut T ) ;
369+
370+ fn next ( & mut self ) -> Option < Self :: Item > {
371+ if self . empty {
372+ None
373+ } else {
374+ // Note: This pointer is correct and doesn't need to be updated,
375+ // note that we're derefencing a **char into a *char casting to a *T
376+ // and then transforming that into a reference, the value that dataptr
377+ // points to is being updated by iternext to point to the next value.
378+ let retval = Some ( unsafe {
379+ (
380+ & * ( * self . dataptr as * mut T ) ,
381+ & mut * ( * self . dataptr . offset ( 1 ) as * mut T ) ,
382+ )
383+ } ) ;
384+ self . empty = unsafe { ( self . iternext ) ( self . iterator . as_mut ( ) ) } == 0 ;
385+ retval
386+ }
387+ }
388+ }
389+
390+ impl < ' py , T : ' py > std:: iter:: Iterator for NpyMultiIterArray < ' py , T , RW < RW < ( ) > > > {
391+ type Item = ( & ' py mut T , & ' py mut T ) ;
392+
393+ fn next ( & mut self ) -> Option < Self :: Item > {
394+ if self . empty {
395+ None
396+ } else {
397+ // Note: This pointer is correct and doesn't need to be updated,
398+ // note that we're derefencing a **char into a *char casting to a *T
399+ // and then transforming that into a reference, the value that dataptr
400+ // points to is being updated by iternext to point to the next value.
401+ let retval = Some ( unsafe {
402+ (
403+ & mut * ( * self . dataptr as * mut T ) ,
404+ & mut * ( * self . dataptr . offset ( 1 ) as * mut T ) ,
405+ )
406+ } ) ;
407+ self . empty = unsafe { ( self . iternext ) ( self . iterator . as_mut ( ) ) } == 0 ;
408+ retval
409+ }
410+ }
411+ }
0 commit comments