From 9c811a21682d5df9917301dafcb4e5cbb0dbe043 Mon Sep 17 00:00:00 2001 From: Matt Haberland Date: Sat, 22 Nov 2025 23:56:06 -0800 Subject: [PATCH 1/2] Add vectorized searchsorted --- src/array_api_extra/__init__.py | 2 + src/array_api_extra/_lib/_funcs.py | 97 ++++++++++++++++- tests/test_funcs.py | 162 ++++++++++++++++++++++++++++- 3 files changed, 259 insertions(+), 2 deletions(-) diff --git a/src/array_api_extra/__init__.py b/src/array_api_extra/__init__.py index 14a3803b..bb165602 100644 --- a/src/array_api_extra/__init__.py +++ b/src/array_api_extra/__init__.py @@ -22,6 +22,7 @@ default_dtype, kron, nunique, + searchsorted, ) from ._lib._lazy import lazy_apply @@ -48,6 +49,7 @@ "one_hot", "pad", "partition", + "searchsorted", "setdiff1d", "sinc", ] diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index 6e50ce95..f4ac0df6 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -8,7 +8,12 @@ from ._at import at from ._utils import _compat, _helpers -from ._utils._compat import array_namespace, is_dask_namespace, is_jax_array +from ._utils._compat import ( + array_namespace, + is_dask_namespace, + is_jax_array, + is_torch_namespace, +) from ._utils._helpers import ( asarrays, capabilities, @@ -28,6 +33,7 @@ "kron", "nunique", "pad", + "searchsorted", "setdiff1d", "sinc", ] @@ -665,6 +671,95 @@ def pad( return at(padded, tuple(slices)).set(x) +def searchsorted( + x1: Array, + x2: Array, + /, + *, + side: Literal["left", "right"] = "left", + xp: ModuleType, +) -> Array: + """ + Find indices where elements should be inserted to maintain order. + + Find the indices into a sorted array ``x1`` such that if the elements in ``x2`` + were inserted before the indices, the resulting array would remain sorted. + + Parameters + ---------- + x1 : Array + Input array. Should have a real-valued data type. Must be sorted in ascending + order along the last axis. + x2 : Array + Array containing search values. Should have a real-valued data type. Must have + the same shape as ``x1`` except along the last axis. + side : {'left', 'right'}, optional + Argument controlling which index is returned if an element of ``x2`` is equal to + one or more elements of ``x1``: ``'left'`` returns the index of the first of + these elements; ``'right'`` returns the next index after the last of these + elements. Default: ``'left'``. + xp : array_namespace, optional + The standard-compatible namespace for the array arguments. Default: infer. + + Returns + ------- + Array: integer array + An array of indices with the same shape as ``x2``. + + Examples + -------- + >>> import array_api_strict as xp + >>> import array_api_extra as xpx + >>> x = xp.asarray([11, 12, 13, 13, 14, 15]) + >>> xpx.searchsorted(x, xp.asarray([10, 11.5, 14.5, 16]), xp=xp) + Array([0, 1, 5, 6], dtype=array_api_strict.int64) + >>> xpx.searchsorted(x, xp.asarray(13), xp=xp) + Array(2, dtype=array_api_strict.int64) + >>> xpx.searchsorted(x, xp.asarray(13), side='right', xp=xp) + Array(4, dtype=array_api_strict.int64) + + `searchsorted` is vectorized along the last axis. + + >>> x1 = xp.asarray([[1., 2., 3., 4.], [5., 6., 7., 8.]]) + >>> x2 = xp.asarray([[1.1, 3.3], [6.6, 8.8]]) + >>> xpx.searchsorted(x1, x2, xp=xp) + Array([[1, 3], + [2, 4]], dtype=array_api_strict.int64) + """ + xp = array_namespace(x1, x2) if xp is None else xp + xp_default_int = xp.asarray(1).dtype + y_0d = xp.asarray(x2).ndim == 0 + x_1d = x1.ndim <= 1 + + if x_1d or is_torch_namespace(xp): + x2 = xp.reshape(x2, ()) if (y_0d and x_1d) else x2 + out = xp.searchsorted(x1, x2, side=side) + return xp.astype(out, xp_default_int, copy=False) + + a = xp.full(x2.shape, 0, device=_compat.device(x1)) + + if x1.shape[-1] == 0: + return a + + n = xp.count_nonzero(~xp.isnan(x1), axis=-1, keepdims=True) + b = xp.broadcast_to(n, x2.shape) + + compare = xp.less_equal if side == "left" else xp.less + + # while xp.any(b - a > 1): + # refactored to for loop with ~log2(n) iterations for JAX JIT + for _ in range(int(math.log2(x1.shape[-1])) + 1): # type: ignore[arg-type] + c = (a + b) // 2 + x0 = xp.take_along_axis(x1, c, axis=-1) + j = compare(x2, x0) + b = xp.where(j, c, b) + a = xp.where(j, a, c) + + out = xp.where(compare(x2, xp.min(x1, axis=-1, keepdims=True)), 0, b) + out = xp.where(xp.isnan(x2), x1.shape[-1], out) if side == "right" else out + return xp.astype(out, xp_default_int, copy=False) + + def setdiff1d( x1: Array | complex, x2: Array | complex, diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 6b10757f..12a71872 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -29,13 +29,18 @@ one_hot, pad, partition, + searchsorted, setdiff1d, sinc, ) from array_api_extra._lib._backends import NUMPY_VERSION, Backend from array_api_extra._lib._testing import xfail, xp_assert_close, xp_assert_equal +from array_api_extra._lib._utils._compat import ( + array_namespace, + is_jax_namespace, + is_torch_namespace, +) from array_api_extra._lib._utils._compat import device as get_device -from array_api_extra._lib._utils._compat import is_jax_namespace from array_api_extra._lib._utils._helpers import eager_shape, ndindex from array_api_extra._lib._utils._typing import Array, Device from array_api_extra.testing import lazy_xp_function @@ -52,6 +57,7 @@ lazy_xp_function(pad) # FIXME calls in1d which calls xp.unique_values without size lazy_xp_function(setdiff1d, jax_jit=False) +lazy_xp_function(searchsorted) lazy_xp_function(sinc) NestedFloatList = list[float] | list["NestedFloatList"] @@ -1637,3 +1643,157 @@ def test_kind(self, xp: ModuleType, library: Backend): expected = xp.asarray([False, True, False, True]) res = isin(a, b, kind="sort") xp_assert_equal(res, expected) + + +def _apply_over_batch(*argdefs: tuple[str, int]): + """ + Factory for decorator that applies a function over batched arguments. + + Array arguments may have any number of core dimensions (typically 0, + 1, or 2) and any broadcastable batch shapes. There may be any + number of array outputs of any number of dimensions. Assumptions + right now - which are satisfied by all functions of interest in `linalg` - + are that all array inputs are consecutive keyword or positional arguments, + and that the wrapped function returns either a single array or a tuple of + arrays. It's only as general as it needs to be right now - it can be extended. + + Parameters + ---------- + *argdefs : tuple of (str, int) + Definitions of array arguments: the keyword name of the argument, and + the number of core dimensions. + + Example: + -------- + `linalg.eig` accepts two matrices as the first two arguments `a` and `b`, where + `b` is optional, and returns one array or a tuple of arrays, depending on the + values of other positional or keyword arguments. To generate a wrapper that applies + the function over batches of `a` and optionally `b` : + + >>> _apply_over_batch(('a', 2), ('b', 2)) + """ + names, ndims = list(zip(*argdefs, strict=True)) + n_arrays = len(names) + + def decorator(f): + def wrapper(*args_tuple, **kwargs): + args = list(args_tuple) + + # Ensure all arrays in `arrays`, other arguments in `other_args`/`kwargs` + arrays, other_args = args[:n_arrays], args[n_arrays:] + for i, name in enumerate(names): + if name in kwargs: + if i + 1 <= len(args): + message = ( + f"{f.__name__}() got multiple values for argument `{name}`." + ) + raise ValueError(message) + arrays.append(kwargs.pop(name)) + + xp = array_namespace(*arrays) + + # Determine core and batch shapes + batch_shapes = [] + core_shapes = [] + for i, (array, ndim) in enumerate(zip(arrays, ndims, strict=True)): + array = None if array is None else xp.asarray(array) # noqa: PLW2901 + shape = () if array is None else array.shape + arrays[i] = array + batch_shapes.append(shape[:-ndim] if ndim > 0 else shape) + core_shapes.append(shape[-ndim:] if ndim > 0 else ()) + + # Early exit if call is not batched + if not any(batch_shapes): + return f(*arrays, *other_args, **kwargs) + + # Determine broadcasted batch shape + batch_shape = np.broadcast_shapes(*batch_shapes) # Gives OK error message + + # Broadcast arrays to appropriate shape + for i, (array, core_shape) in enumerate( + zip(arrays, core_shapes, strict=True) + ): + if array is None: + continue + arrays[i] = xp.broadcast_to(array, batch_shape + core_shape) + + # Main loop + results = [] + for index in np.ndindex(batch_shape): + result = f( + *( + (array[index] if array is not None else None) + for array in arrays + ), + *other_args, + **kwargs, + ) + # Assume `result` is either a tuple or single array. This is easily + # generalized by allowing the contributor to pass an `unpack_result` + # callable to the decorator factory. + result = (result,) if not isinstance(result, tuple) else result + results.append(result) + results = list(zip(*results, strict=True)) + + # Reshape results + for i, result in enumerate(results): + result = xp.stack(result) # noqa: PLW2901 + core_shape = result.shape[1:] + results[i] = xp.reshape(result, batch_shape + core_shape) + + # Assume `result` should be a single array if there is only one element or + # a `tuple` otherwise. This is easily generalized by allowing the + # contributor to pass an `pack_result` callable to the decorator factory. + return results[0] if len(results) == 1 else results + + return wrapper + + return decorator + + +@_apply_over_batch(("a", 1), ("v", 1)) +def xp_searchsorted(a, v, side, xp): + return xp.searchsorted(xp.asarray(a), xp.asarray(v), side=side) + + +@pytest.mark.skip_xp_backend(Backend.DASK, reason="no take_along_axis") +@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no searchsorted") +class TestSearchsorted: + @pytest.mark.parametrize("side", ["left", "right"]) + @pytest.mark.parametrize("ties", [False, True]) + @pytest.mark.parametrize( + "shape", [0, 1, 2, 10, 11, 1000, 10001, (2, 0), (0, 2), (2, 10), (2, 3, 11)] + ) + @pytest.mark.parametrize("nans_x", [False, True]) + @pytest.mark.parametrize("infs_x", [False, True]) + def test_nd(self, side, ties, shape, nans_x, infs_x, xp): + if nans_x and is_torch_namespace(xp): + pytest.skip("torch sorts NaNs differently") + rng = np.random.default_rng(945298725498274853) + x = rng.integers(5, size=shape) if ties else rng.random(shape) + # float32 is to accommodate JAX - nextafter with `float64` is too small? + x = np.asarray(x, dtype=np.float32) + xr = np.nextafter(x, np.inf) + xl = np.nextafter(x, -np.inf) + x_ = np.asarray([-np.inf, np.inf, np.nan]) + x_ = np.broadcast_to(x_, (*x.shape[:-1], 3)) + y = rng.permuted(np.concatenate((xl, x, xr, x_), axis=-1), axis=-1) + if nans_x: + mask = rng.random(shape) < 0.1 + x[mask] = np.nan + if infs_x: + mask = rng.random(shape) < 0.1 + x[mask] = -np.inf + mask = rng.random(shape) > 0.9 + x[mask] = np.inf + x = np.sort(x, stable=True, axis=-1) + x, y = np.asarray(x, dtype=np.float64), np.asarray(y, dtype=np.float64) + xp_default_int = xp.asarray(1).dtype + if x.size == 0 and x.ndim > 0 and x.shape[-1] != 0: + ref = xp.empty((*x.shape[:-1], y.shape[-1]), dtype=xp_default_int) + else: + ref = xp_searchsorted(x, y, side=side, xp=np) + ref = xp.asarray(ref, dtype=xp_default_int) + x, y = xp.asarray(x.copy()), xp.asarray(y.copy()) + res = searchsorted(x, y, side=side, xp=xp) + xp_assert_equal(res, ref) From a094ffe62857a14588acc69119cd2cc7d5a713d0 Mon Sep 17 00:00:00 2001 From: Matt Haberland Date: Sun, 23 Nov 2025 14:40:12 -0800 Subject: [PATCH 2/2] STY: searchsorted: fix typing issues --- src/array_api_extra/_lib/_funcs.py | 4 +-- tests/test_funcs.py | 40 ++++++++++++++++++++++-------- 2 files changed, 31 insertions(+), 13 deletions(-) diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index f4ac0df6..3895014d 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -677,7 +677,7 @@ def searchsorted( /, *, side: Literal["left", "right"] = "left", - xp: ModuleType, + xp: ModuleType | None = None, ) -> Array: """ Find indices where elements should be inserted to maintain order. @@ -748,7 +748,7 @@ def searchsorted( # while xp.any(b - a > 1): # refactored to for loop with ~log2(n) iterations for JAX JIT - for _ in range(int(math.log2(x1.shape[-1])) + 1): # type: ignore[arg-type] + for _ in range(int(math.log2(x1.shape[-1])) + 1): # type: ignore[arg-type] # pyright: ignore[reportArgumentType] c = (a + b) // 2 x0 = xp.take_along_axis(x1, c, axis=-1) j = compare(x2, x0) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 12a71872..f455541c 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -1,7 +1,7 @@ import math import warnings from types import ModuleType -from typing import Any, cast +from typing import Any, Literal, cast import hypothesis import hypothesis.extra.numpy as npst @@ -1645,10 +1645,12 @@ def test_kind(self, xp: ModuleType, library: Backend): xp_assert_equal(res, expected) -def _apply_over_batch(*argdefs: tuple[str, int]): +def _apply_over_batch(*argdefs: tuple[str, int]) -> Any: """ Factory for decorator that applies a function over batched arguments. + Copied (with light simplifications) from `scipy._lib._util`. + Array arguments may have any number of core dimensions (typically 0, 1, or 2) and any broadcastable batch shapes. There may be any number of array outputs of any number of dimensions. Assumptions @@ -1675,8 +1677,11 @@ def _apply_over_batch(*argdefs: tuple[str, int]): names, ndims = list(zip(*argdefs, strict=True)) n_arrays = len(names) - def decorator(f): - def wrapper(*args_tuple, **kwargs): + def decorator(f: Any) -> Any: + def wrapper( + *args_tuple: tuple[Any] | None, + **kwargs: dict[str, Any] | None, + ) -> Any: args = list(args_tuple) # Ensure all arrays in `arrays`, other arguments in `other_args`/`kwargs` @@ -1688,9 +1693,9 @@ def wrapper(*args_tuple, **kwargs): f"{f.__name__}() got multiple values for argument `{name}`." ) raise ValueError(message) - arrays.append(kwargs.pop(name)) + arrays.append(kwargs.pop(name)) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] - xp = array_namespace(*arrays) + xp = array_namespace(*arrays) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] # Determine core and batch shapes batch_shapes = [] @@ -1751,8 +1756,13 @@ def wrapper(*args_tuple, **kwargs): return decorator -@_apply_over_batch(("a", 1), ("v", 1)) -def xp_searchsorted(a, v, side, xp): +@_apply_over_batch(("a", 1), ("v", 1)) # type: ignore[misc] +def xp_searchsorted( + a: Array, + v: Array, + side: Literal["left", "right"], + xp: ModuleType, +) -> Array: return xp.searchsorted(xp.asarray(a), xp.asarray(v), side=side) @@ -1766,13 +1776,21 @@ class TestSearchsorted: ) @pytest.mark.parametrize("nans_x", [False, True]) @pytest.mark.parametrize("infs_x", [False, True]) - def test_nd(self, side, ties, shape, nans_x, infs_x, xp): + def test_nd( + self, + side: Literal["left", "right"], + ties: bool, + shape: int | tuple[int], + nans_x: bool, + infs_x: bool, + xp: ModuleType, + ): if nans_x and is_torch_namespace(xp): pytest.skip("torch sorts NaNs differently") rng = np.random.default_rng(945298725498274853) x = rng.integers(5, size=shape) if ties else rng.random(shape) # float32 is to accommodate JAX - nextafter with `float64` is too small? - x = np.asarray(x, dtype=np.float32) + x = np.asarray(x, dtype=np.float32) # type:ignore[assignment] xr = np.nextafter(x, np.inf) xl = np.nextafter(x, -np.inf) x_ = np.asarray([-np.inf, np.inf, np.nan]) @@ -1786,7 +1804,7 @@ def test_nd(self, side, ties, shape, nans_x, infs_x, xp): x[mask] = -np.inf mask = rng.random(shape) > 0.9 x[mask] = np.inf - x = np.sort(x, stable=True, axis=-1) + x = np.sort(x, axis=-1) # type:ignore[assignment] x, y = np.asarray(x, dtype=np.float64), np.asarray(y, dtype=np.float64) xp_default_int = xp.asarray(1).dtype if x.size == 0 and x.ndim > 0 and x.shape[-1] != 0: