@@ -243,7 +243,6 @@ def test_where(shapes, dtypes, data):
243243@pytest .mark .min_version ("2023.12" )
244244@given (data = st .data ())
245245def test_searchsorted (data ):
246- # TODO: test side="right"
247246 # TODO: Allow different dtypes for x1 and x2
248247 _x1 = data .draw (
249248 st .lists (xps .from_dtype (dh .default_float ), min_size = 1 , unique = True ),
@@ -262,18 +261,20 @@ def test_searchsorted(data):
262261 ),
263262 label = "x2" ,
264263 )
264+ kw = data .draw (hh .kwargs (side = st .sampled_from (["left" , "right" ])))
265265
266266 repro_snippet = ph .format_snippet (f"xp.searchsorted({ x1 !r} , { x2 !r} , sorter={ sorter !r} )" )
267267 try :
268- out = xp .searchsorted (x1 , x2 , sorter = sorter )
268+ out = xp .searchsorted (x1 , x2 , sorter = sorter , ** kw )
269269
270270 ph .assert_dtype (
271271 "searchsorted" ,
272272 in_dtype = [x1 .dtype , x2 .dtype ],
273273 out_dtype = out .dtype ,
274274 expected = xp .__array_namespace_info__ ().default_dtypes ()["indexing" ],
275275 )
276- # TODO: shapes and values testing
276+ # TODO: x2.ndim > 1, values testing
277+ ph .assert_shape ("searchsorted" , out_shape = out .shape , expected = x2 .shape )
277278 except Exception as exc :
278279 exc .add_note (repro_snippet )
279280 raise
0 commit comments