@@ -60,6 +60,7 @@ def assert_equal(x, y, msg_extra=None):
6060
6161def _test_stacks (f , * args , res = None , dims = 2 , true_val = None ,
6262 matrix_axes = (- 2 , - 1 ),
63+ res_axes = None ,
6364 assert_equal = assert_equal , ** kw ):
6465 """
6566 Test that f(*args, **kw) maps across stacks of matrices
@@ -84,7 +85,10 @@ def _test_stacks(f, *args, res=None, dims=2, true_val=None,
8485
8586 # Assume the result is stacked along the last 'dims' axes of matrix_axes.
8687 # This holds for all the functions tested in this file
87- res_axes = matrix_axes [::- 1 ][:dims ]
88+ if res_axes is None :
89+ if not isinstance (matrix_axes , tuple ) and all (isinstance (x , int ) for x in matrix_axes ):
90+ raise ValueError ("res_axes must be specified if matrix_axes is not a tuple of integers" )
91+ res_axes = matrix_axes [::- 1 ][:dims ]
8892
8993 for (x_idxes , (res_idx ,)) in zip (
9094 iter_indices (* shapes , skip_axes = matrix_axes ),
@@ -330,10 +334,12 @@ def test_matmul(x1, x2):
330334 assert res .shape == ()
331335 elif len (x1 .shape ) == 1 :
332336 assert res .shape == x2 .shape [:- 2 ] + x2 .shape [- 1 :]
333- _test_stacks (_array_module .matmul , x1 , x2 , res = res , dims = 1 )
337+ _test_stacks (_array_module .matmul , x1 , x2 , res = res , dims = 1 ,
338+ matrix_axes = [(0 ,), (- 2 , - 1 )], res_axes = [- 1 ])
334339 elif len (x2 .shape ) == 1 :
335340 assert res .shape == x1 .shape [:- 1 ]
336- _test_stacks (_array_module .matmul , x1 , x2 , res = res , dims = 1 )
341+ _test_stacks (_array_module .matmul , x1 , x2 , res = res , dims = 1 ,
342+ matrix_axes = [(- 2 , - 1 ), (0 ,)], res_axes = [- 1 ])
337343 else :
338344 stack_shape = sh .broadcast_shapes (x1 .shape [:- 2 ], x2 .shape [:- 2 ])
339345 assert res .shape == stack_shape + (x1 .shape [- 2 ], x2 .shape [- 1 ])
@@ -546,10 +552,11 @@ def test_solve(x1, x2):
546552 # TODO: This requires an upstream fix to ndindex
547553 # (https://github.com/Quansight-Labs/ndindex/pull/131)
548554
549- # if x2.ndim == 1:
550- # _test_stacks(linalg.solve, x1, x2, res=res, dims=1)
551- # else:
552- # _test_stacks(linalg.solve, x1, x2, res=res, dims=2)
555+ if x2 .ndim == 1 :
556+ _test_stacks (linalg .solve , x1 , x2 , res = res , dims = 1 ,
557+ matrix_axes = [(- 2 , - 1 ), (0 ,)], res_axes = [- 1 ])
558+ else :
559+ _test_stacks (linalg .solve , x1 , x2 , res = res , dims = 2 )
553560
554561@pytest .mark .xp_extension ('linalg' )
555562@given (
0 commit comments