11import math
22from collections import deque
3- from typing import Iterable , Union
3+ from typing import Iterable , Iterator , Tuple , Union
44
55import pytest
66from hypothesis import assume , given
@@ -33,8 +33,10 @@ def assert_array_ndindex(
3333 x_indices : Iterable [Union [int , Shape ]],
3434 out : Array ,
3535 out_indices : Iterable [Union [int , Shape ]],
36+ / ,
37+ ** kw ,
3638):
37- msg_suffix = f" [{ func_name } ()]\n { x = } \n { out = } "
39+ msg_suffix = f" [{ func_name } ({ ph . fmt_kw ( kw ) } )]\n { x = } \n { out = } "
3840 for x_idx , out_idx in zip (x_indices , out_indices ):
3941 msg = f"out[{ out_idx } ]={ out [out_idx ]} , should be x[{ x_idx } ]={ x [x_idx ]} "
4042 msg += msg_suffix
@@ -266,7 +268,15 @@ def test_reshape(x, data):
266268 assert_array_ndindex ("reshape" , x , sh .ndindex (x .shape ), out , sh .ndindex (out .shape ))
267269
268270
269- @pytest .mark .skip (reason = "faulty test logic" ) # TODO
271+ def roll_ndindex (shape : Shape , shifts : Tuple [int ], axes : Tuple [int ]) -> Iterator [Shape ]:
272+ assert len (shifts ) == len (axes ) # sanity check
273+ all_shifts = [0 for _ in shape ]
274+ for s , a in zip (shifts , axes ):
275+ all_shifts [a ] = s
276+ for idx in sh .ndindex (shape ):
277+ yield tuple ((i + sh ) % si for i , sh , si in zip (idx , all_shifts , shape ))
278+
279+
270280@given (xps .arrays (dtype = xps .scalar_dtypes (), shape = shared_shapes ()), st .data ())
271281def test_roll (x , data ):
272282 shift_strat = st .integers (- hh .MAX_ARRAY_SIZE , hh .MAX_ARRAY_SIZE )
@@ -287,6 +297,8 @@ def test_roll(x, data):
287297
288298 out = xp .roll (x , shift , ** kw )
289299
300+ kw = {"shift" : shift , ** kw } # for error messages
301+
290302 ph .assert_dtype ("roll" , x .dtype , out .dtype )
291303
292304 ph .assert_result_shape ("roll" , (x .shape ,), out .shape )
@@ -296,18 +308,12 @@ def test_roll(x, data):
296308 indices = list (sh .ndindex (x .shape ))
297309 shifted_indices = deque (indices )
298310 shifted_indices .rotate (- shift )
299- assert_array_ndindex ("roll" , x , indices , out , shifted_indices )
311+ assert_array_ndindex ("roll" , x , indices , out , shifted_indices , ** kw )
300312 else :
301- _shift = (shift ,) if isinstance (shift , int ) else shift
313+ shifts = (shift ,) if isinstance (shift , int ) else shift
302314 axes = sh .normalise_axis (kw ["axis" ], x .ndim )
303- all_indices = list (sh .ndindex (x .shape ))
304- for s , a in zip (_shift , axes ):
305- side = x .shape [a ]
306- for i in range (side ):
307- indices = [idx for idx in all_indices if idx [a ] == i ]
308- shifted_indices = deque (indices )
309- shifted_indices .rotate (- s )
310- assert_array_ndindex ("roll" , x , indices , out , shifted_indices )
315+ shifted_indices = roll_ndindex (x .shape , shifts , axes )
316+ assert_array_ndindex ("roll" , x , sh .ndindex (x .shape ), out , shifted_indices , ** kw )
311317
312318
313319@given (
0 commit comments