Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/array_api_extra/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
default_dtype,
kron,
nunique,
searchsorted,
)
from ._lib._lazy import lazy_apply

Expand All @@ -48,6 +49,7 @@
"one_hot",
"pad",
"partition",
"searchsorted",
"setdiff1d",
"sinc",
]
97 changes: 96 additions & 1 deletion src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,12 @@

from ._at import at
from ._utils import _compat, _helpers
from ._utils._compat import array_namespace, is_dask_namespace, is_jax_array
from ._utils._compat import (
array_namespace,
is_dask_namespace,
is_jax_array,
is_torch_namespace,
)
from ._utils._helpers import (
asarrays,
capabilities,
Expand All @@ -28,6 +33,7 @@
"kron",
"nunique",
"pad",
"searchsorted",
"setdiff1d",
"sinc",
]
Expand Down Expand Up @@ -665,6 +671,95 @@
return at(padded, tuple(slices)).set(x)


def searchsorted(
    x1: Array,
    x2: Array,
    /,
    *,
    side: Literal["left", "right"] = "left",
    xp: ModuleType | None = None,
) -> Array:
    """
    Find indices where elements should be inserted to maintain order.

    Find the indices into a sorted array ``x1`` such that if the elements in ``x2``
    were inserted before the indices, the resulting array would remain sorted.

    Parameters
    ----------
    x1 : Array
        Input array. Should have a real-valued data type. Must be sorted in ascending
        order along the last axis.
    x2 : Array
        Array containing search values. Should have a real-valued data type. Must have
        the same shape as ``x1`` except along the last axis.
    side : {'left', 'right'}, optional
        Argument controlling which index is returned if an element of ``x2`` is equal to
        one or more elements of ``x1``: ``'left'`` returns the index of the first of
        these elements; ``'right'`` returns the next index after the last of these
        elements. Default: ``'left'``.
    xp : array_namespace, optional
        The standard-compatible namespace for the array arguments. Default: infer.

    Returns
    -------
    Array
        An integer array of indices with the same shape as ``x2``.

    Raises
    ------
    ValueError
        If ``x1`` has more than one dimension, the backend is not PyTorch, and
        the size of the last axis of ``x1`` is unknown (``None``).

    Examples
    --------
    >>> import array_api_strict as xp
    >>> import array_api_extra as xpx
    >>> x = xp.asarray([11, 12, 13, 13, 14, 15])
    >>> xpx.searchsorted(x, xp.asarray([10, 11.5, 14.5, 16]), xp=xp)
    Array([0, 1, 5, 6], dtype=array_api_strict.int64)
    >>> xpx.searchsorted(x, xp.asarray(13), xp=xp)
    Array(2, dtype=array_api_strict.int64)
    >>> xpx.searchsorted(x, xp.asarray(13), side='right', xp=xp)
    Array(4, dtype=array_api_strict.int64)

    `searchsorted` is vectorized along the last axis.

    >>> x1 = xp.asarray([[1., 2., 3., 4.], [5., 6., 7., 8.]])
    >>> x2 = xp.asarray([[1.1, 3.3], [6.6, 8.8]])
    >>> xpx.searchsorted(x1, x2, xp=xp)
    Array([[1, 3],
           [2, 4]], dtype=array_api_strict.int64)
    """
    if xp is None:
        xp = array_namespace(x1, x2)

    # Every return path below is cast to the namespace's default integer dtype.
    xp_default_int = xp.asarray(1).dtype
    y_0d = xp.asarray(x2).ndim == 0
    x_1d = x1.ndim <= 1

    # Delegate to the backend for 1-D `x1`, and always for torch.
    if x_1d or is_torch_namespace(xp):
        x2 = xp.reshape(x2, ()) if (y_0d and x_1d) else x2
        out = xp.searchsorted(x1, x2, side=side)
        return xp.astype(out, xp_default_int, copy=False)

    # Batched bisection: `a` is the lower and `b` the upper bracket index for
    # every element of `x2`.
    a = xp.full(x2.shape, 0, device=_compat.device(x1))

    n_last = x1.shape[-1]
    if n_last is None:
        msg = "The size of the last axis of `x1` must be known."
        raise ValueError(msg)
    if n_last == 0:
        return a

    # Upper bracket starts at the number of non-NaN entries in each row.
    n = xp.count_nonzero(~xp.isnan(x1), axis=-1, keepdims=True)
    b = xp.broadcast_to(n, x2.shape)

    compare = xp.less_equal if side == "left" else xp.less

    # while xp.any(b - a > 1):
    # refactored to for loop with ~log2(n) iterations for JAX JIT
    for _ in range(int(math.log2(n_last)) + 1):
        c = (a + b) // 2
        x0 = xp.take_along_axis(x1, c, axis=-1)
        j = compare(x2, x0)
        # Where `compare` holds at the midpoint, shrink the bracket from above;
        # elsewhere, shrink it from below.
        b = xp.where(j, c, b)
        a = xp.where(j, a, c)

    # Values for which `compare` holds against the row minimum map to index 0.
    out = xp.where(compare(x2, xp.min(x1, axis=-1, keepdims=True)), 0, b)
    # For side="right", NaN search values map to the end of the row.
    out = xp.where(xp.isnan(x2), n_last, out) if side == "right" else out
    return xp.astype(out, xp_default_int, copy=False)


def setdiff1d(
x1: Array | complex,
x2: Array | complex,
Expand Down
162 changes: 161 additions & 1 deletion tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,18 @@
one_hot,
pad,
partition,
searchsorted,
setdiff1d,
sinc,
)
from array_api_extra._lib._backends import NUMPY_VERSION, Backend
from array_api_extra._lib._testing import xfail, xp_assert_close, xp_assert_equal
from array_api_extra._lib._utils._compat import (
array_namespace,
is_jax_namespace,
is_torch_namespace,
)
from array_api_extra._lib._utils._compat import device as get_device
from array_api_extra._lib._utils._compat import is_jax_namespace
from array_api_extra._lib._utils._helpers import eager_shape, ndindex
from array_api_extra._lib._utils._typing import Array, Device
from array_api_extra.testing import lazy_xp_function
Expand All @@ -52,6 +57,7 @@
lazy_xp_function(pad)
# FIXME calls in1d which calls xp.unique_values without size
lazy_xp_function(setdiff1d, jax_jit=False)
lazy_xp_function(searchsorted)
lazy_xp_function(sinc)

NestedFloatList = list[float] | list["NestedFloatList"]
Expand Down Expand Up @@ -1637,3 +1643,157 @@
expected = xp.asarray([False, True, False, True])
res = isin(a, b, kind="sort")
xp_assert_equal(res, expected)


def _apply_over_batch(*argdefs: tuple[str, int]):
    """
    Factory for decorator that applies a function over batched arguments.

    Array arguments may have any number of core dimensions (typically 0,
    1, or 2) and any broadcastable batch shapes. There may be any
    number of array outputs of any number of dimensions. Assumptions
    right now are that all array inputs are consecutive keyword or positional
    arguments, and that the wrapped function returns either a single array or a
    tuple of arrays. It's only as general as it needs to be right now - it can be
    extended.

    Parameters
    ----------
    *argdefs : tuple of (str, int)
        Definitions of array arguments: the keyword name of the argument, and
        the number of core dimensions.

    Returns
    -------
    callable
        A decorator. The decorated function calls the original function once
        per batch index and stacks the outputs. It returns a single array when
        the original function has one output, and a list of arrays otherwise.
        If no array argument has batch dimensions, the original function is
        called once and its result is returned unchanged.

    Examples
    --------
    `linalg.eig` accepts two matrices as the first two arguments `a` and `b`, where
    `b` is optional, and returns one array or a tuple of arrays, depending on the
    values of other positional or keyword arguments. To generate a wrapper that
    applies the function over batches of `a` and optionally `b`:

    >>> _apply_over_batch(('a', 2), ('b', 2))
    """
    # Imported locally so the helper does not rely on module-level imports.
    import functools
    from collections.abc import Callable
    from typing import Any

    names, ndims = list(zip(*argdefs, strict=True))
    n_arrays = len(names)

    def decorator(f: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(f)
        def wrapper(*args_tuple: Any, **kwargs: Any) -> Any:
            args = list(args_tuple)

            # Ensure all arrays in `arrays`, other arguments in `other_args`/`kwargs`
            arrays, other_args = args[:n_arrays], args[n_arrays:]
            for i, name in enumerate(names):
                if name in kwargs:
                    if i + 1 <= len(args):
                        message = (
                            f"{f.__name__}() got multiple values for argument `{name}`."
                        )
                        raise ValueError(message)
                    arrays.append(kwargs.pop(name))

            xp = array_namespace(*arrays)

            # Determine core and batch shapes
            batch_shapes = []
            core_shapes = []
            for i, (array, ndim) in enumerate(zip(arrays, ndims, strict=True)):
                array = None if array is None else xp.asarray(array)  # noqa: PLW2901
                shape = () if array is None else array.shape
                arrays[i] = array
                batch_shapes.append(shape[:-ndim] if ndim > 0 else shape)
                core_shapes.append(shape[-ndim:] if ndim > 0 else ())

            # Early exit if call is not batched
            if not any(batch_shapes):
                return f(*arrays, *other_args, **kwargs)

            # Determine broadcasted batch shape
            batch_shape = np.broadcast_shapes(*batch_shapes)  # Gives OK error message

            # Broadcast arrays to appropriate shape
            for i, (array, core_shape) in enumerate(
                zip(arrays, core_shapes, strict=True)
            ):
                if array is None:
                    continue
                arrays[i] = xp.broadcast_to(array, batch_shape + core_shape)

            # Main loop
            results = []
            for index in np.ndindex(batch_shape):
                result = f(
                    *(
                        (array[index] if array is not None else None)
                        for array in arrays
                    ),
                    *other_args,
                    **kwargs,
                )
                # Assume `result` is either a tuple or single array. This is easily
                # generalized by allowing the contributor to pass an `unpack_result`
                # callable to the decorator factory.
                result = (result,) if not isinstance(result, tuple) else result
                results.append(result)
            # Transpose: one entry per output, each holding that output's value
            # for every batch index.
            results = list(zip(*results, strict=True))

            # Reshape results
            for i, result in enumerate(results):
                result = xp.stack(result)  # noqa: PLW2901
                core_shape = result.shape[1:]
                results[i] = xp.reshape(result, batch_shape + core_shape)

            # Return a single array if there is only one output, or the list of
            # output arrays otherwise. This is easily generalized by allowing the
            # contributor to pass a `pack_result` callable to the decorator factory.
            return results[0] if len(results) == 1 else results

        return wrapper

    return decorator


@_apply_over_batch(("a", 1), ("v", 1))
def xp_searchsorted(a, v, side, xp):
    """Call ``xp.searchsorted`` on `a` and `v`, each declared with one core dimension."""
    haystack = xp.asarray(a)
    needles = xp.asarray(v)
    return xp.searchsorted(haystack, needles, side=side)


@pytest.mark.skip_xp_backend(Backend.DASK, reason="no take_along_axis")
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no searchsorted")
class TestSearchsorted:
    """Check `searchsorted` against a batched NumPy reference."""

    @pytest.mark.parametrize("side", ["left", "right"])
    @pytest.mark.parametrize("ties", [False, True])
    @pytest.mark.parametrize(
        "shape", [0, 1, 2, 10, 11, 1000, 10001, (2, 0), (0, 2), (2, 10), (2, 3, 11)]
    )
    @pytest.mark.parametrize("nans_x", [False, True])
    @pytest.mark.parametrize("infs_x", [False, True])
    def test_nd(self, side, ties, shape, nans_x, infs_x, xp):
        """Compare against ``np.searchsorted`` applied over the batch dimensions."""
        if nans_x and is_torch_namespace(xp):
            pytest.skip("torch sorts NaNs differently")
        rng = np.random.default_rng(945298725498274853)
        # Small integers produce repeated values (ties); uniform floats do not.
        x = rng.integers(5, size=shape) if ties else rng.random(shape)
        # float32 is to accommodate JAX - nextafter with `float64` is too small?
        x = np.asarray(x, dtype=np.float32)
        # Search values: each element of `x`, its immediate float32 neighbours
        # on either side, and -inf / inf / nan, shuffled along the last axis.
        xr = np.nextafter(x, np.inf)
        xl = np.nextafter(x, -np.inf)
        x_ = np.asarray([-np.inf, np.inf, np.nan])
        x_ = np.broadcast_to(x_, (*x.shape[:-1], 3))
        y = rng.permuted(np.concatenate((xl, x, xr, x_), axis=-1), axis=-1)
        if nans_x:
            mask = rng.random(shape) < 0.1
            x[mask] = np.nan
        if infs_x:
            mask = rng.random(shape) < 0.1
            x[mask] = -np.inf
            mask = rng.random(shape) > 0.9
            x[mask] = np.inf
        # No `stable=` keyword: it is only accepted by NumPy >= 2.0, and
        # stability is not needed to sort plain values.
        x = np.sort(x, axis=-1)
        x, y = np.asarray(x, dtype=np.float64), np.asarray(y, dtype=np.float64)
        xp_default_int = xp.asarray(1).dtype
        if x.size == 0 and x.ndim > 0 and x.shape[-1] != 0:
            # `x` is empty but its last axis is not, so the reference helper is
            # skipped; the expected result has no elements.
            ref = xp.empty((*x.shape[:-1], y.shape[-1]), dtype=xp_default_int)
        else:
            ref = xp_searchsorted(x, y, side=side, xp=np)
            ref = xp.asarray(ref, dtype=xp_default_int)
        x, y = xp.asarray(x.copy()), xp.asarray(y.copy())
        res = searchsorted(x, y, side=side, xp=xp)
        xp_assert_equal(res, ref)
Loading