diff --git a/changes/3547.misc.md b/changes/3547.misc.md
new file mode 100644
index 0000000000..771bfe8861
--- /dev/null
+++ b/changes/3547.misc.md
@@ -0,0 +1 @@
+Moved concurrency limits to a global per-event-loop setting instead of a per-array-call setting.
\ No newline at end of file
diff --git a/src/zarr/abc/codec.py b/src/zarr/abc/codec.py
index d41c457b4e..69d6c3082e 100644
--- a/src/zarr/abc/codec.py
+++ b/src/zarr/abc/codec.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import asyncio
 from abc import abstractmethod
 from collections.abc import Mapping
 from typing import TYPE_CHECKING, Generic, TypeGuard, TypeVar
@@ -8,8 +9,7 @@
 from zarr.abc.metadata import Metadata
 from zarr.core.buffer import Buffer, NDBuffer
-from zarr.core.common import NamedConfig, concurrent_map
-from zarr.core.config import config
+from zarr.core.common import NamedConfig
 
 if TYPE_CHECKING:
     from collections.abc import Awaitable, Callable, Iterable
@@ -225,11 +225,8 @@ async def decode_partial(
         -------
         Iterable[NDBuffer | None]
         """
-        return await concurrent_map(
-            list(batch_info),
-            self._decode_partial_single,
-            config.get("async.concurrency"),
-        )
+        # Store handles concurrency limiting internally
+        return await asyncio.gather(*[self._decode_partial_single(*info) for info in batch_info])
 
 
 class ArrayBytesCodecPartialEncodeMixin:
@@ -262,11 +259,8 @@ async def encode_partial(
         The ByteSetter is used to write the necessary bytes and fetch bytes for existing chunk data.
         The chunk spec contains information about the chunk.
         """
-        await concurrent_map(
-            list(batch_info),
-            self._encode_partial_single,
-            config.get("async.concurrency"),
-        )
+        # Store handles concurrency limiting internally
+        await asyncio.gather(*[self._encode_partial_single(*info) for info in batch_info])
 
 
 class CodecPipeline:
@@ -464,11 +458,8 @@ async def _batching_helper(
     func: Callable[[CodecInput, ArraySpec], Awaitable[CodecOutput | None]],
     batch_info: Iterable[tuple[CodecInput | None, ArraySpec]],
 ) -> list[CodecOutput | None]:
-    return await concurrent_map(
-        list(batch_info),
-        _noop_for_none(func),
-        config.get("async.concurrency"),
-    )
+    # Store handles concurrency limiting internally
+    return await asyncio.gather(*[_noop_for_none(func)(chunk, spec) for chunk, spec in batch_info])
 
 
 def _noop_for_none(
diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py
index 4b3edf78d1..4ccab1877f 100644
--- a/src/zarr/abc/store.py
+++ b/src/zarr/abc/store.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import asyncio
 from abc import ABC, abstractmethod
 from asyncio import gather
 from dataclasses import dataclass
@@ -462,13 +463,8 @@ async def getsize_prefix(self, prefix: str) -> int:
         # improve tail latency and might reduce memory pressure (since not all keys
         # would be in memory at once).
-        # avoid circular import
-        from zarr.core.common import concurrent_map
-        from zarr.core.config import config
-
-        keys = [(x,) async for x in self.list_prefix(prefix)]
-        limit = config.get("async.concurrency")
-        sizes = await concurrent_map(keys, self.getsize, limit=limit)
+        keys = [x async for x in self.list_prefix(prefix)]
+        sizes = await asyncio.gather(*[self.getsize(key) for key in keys])
         return sum(sizes)
diff --git a/src/zarr/core/array.py b/src/zarr/core/array.py
index 6b20ee950d..01ff74f38f 100644
--- a/src/zarr/core/array.py
+++ b/src/zarr/core/array.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import asyncio
 import json
 import warnings
 from asyncio import gather
@@ -22,7 +23,6 @@
 import numpy as np
 from typing_extensions import deprecated
 
-import zarr
 from zarr.abc.codec import ArrayArrayCodec, ArrayBytesCodec, BytesBytesCodec, Codec
 from zarr.abc.numcodec import Numcodec, _is_numcodec
 from zarr.codecs._v2 import V2Codec
@@ -60,7 +60,6 @@
     _default_zarr_format,
     _warn_order_kwarg,
     ceildiv,
-    concurrent_map,
     parse_shapelike,
     product,
 )
@@ -1848,13 +1847,12 @@ async def resize(self, new_shape: ShapeLike, delete_outside_chunks: bool = True)
         async def _delete_key(key: str) -> None:
             await (self.store_path / key).delete()
 
-        await concurrent_map(
-            [
-                (self.metadata.encode_chunk_key(chunk_coords),)
+        # Store handles concurrency limiting internally
+        await asyncio.gather(
+            *[
+                _delete_key(self.metadata.encode_chunk_key(chunk_coords))
                 for chunk_coords in old_chunk_coords.difference(new_chunk_coords)
-            ],
-            _delete_key,
-            zarr_config.get("async.concurrency"),
+            ]
         )
 
         # Write new metadata
@@ -4535,10 +4533,9 @@ async def _copy_array_region(
             await result.setitem(chunk_coords, arr)
 
         # Stream data from the source array to the new array
-        await concurrent_map(
-            [(region, data) for region in result._iter_shard_regions()],
-            _copy_array_region,
-            zarr.core.config.config.get("async.concurrency"),
+        # Store handles concurrency limiting internally
+        await asyncio.gather(
+            *[_copy_array_region(region, data) for region in result._iter_shard_regions()]
         )
 
     else:
@@ -4546,10 +4543,9 @@ async def _copy_arraylike_region(chunk_coords: slice, _data: NDArrayLike) -> Non
             await result.setitem(chunk_coords, _data[chunk_coords])
 
         # Stream data from the source array to the new array
-        await concurrent_map(
-            [(region, data) for region in result._iter_shard_regions()],
-            _copy_arraylike_region,
-            zarr.core.config.config.get("async.concurrency"),
+        # Store handles concurrency limiting internally
+        await asyncio.gather(
+            *[_copy_arraylike_region(region, data) for region in result._iter_shard_regions()]
         )
 
     return result
diff --git a/src/zarr/core/codec_pipeline.py b/src/zarr/core/codec_pipeline.py
index fd557ac43e..0f8350f7ea 100644
--- a/src/zarr/core/codec_pipeline.py
+++ b/src/zarr/core/codec_pipeline.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import asyncio
 from dataclasses import dataclass
 from itertools import islice, pairwise
 from typing import TYPE_CHECKING, Any, TypeVar
@@ -14,7 +15,6 @@
     Codec,
     CodecPipeline,
 )
-from zarr.core.common import concurrent_map
 from zarr.core.config import config
 from zarr.core.indexing import SelectorTuple, is_scalar
 from zarr.errors import ZarrUserWarning
@@ -267,10 +267,12 @@ async def read_batch(
             else:
                 out[out_selection] = fill_value_or_default(chunk_spec)
         else:
-            chunk_bytes_batch = await concurrent_map(
-                [(byte_getter, array_spec.prototype) for byte_getter, array_spec, *_ in batch_info],
-                lambda byte_getter, prototype: byte_getter.get(prototype),
-                config.get("async.concurrency"),
+            # Store handles concurrency limiting internally
+            chunk_bytes_batch = await asyncio.gather(
+                *[
+                    byte_getter.get(array_spec.prototype)
+                    for byte_getter, array_spec, *_ in batch_info
+                ]
             )
             chunk_array_batch = await self.decode_batch(
                 [
@@ -368,16 +370,15 @@ async def _read_key(
             return await byte_setter.get(prototype=prototype)
 
         chunk_bytes_batch: Iterable[Buffer | None]
-        chunk_bytes_batch = await concurrent_map(
-            [
-                (
+        # Store handles concurrency limiting internally
+        chunk_bytes_batch = await asyncio.gather(
+            *[
+                _read_key(
                     None if is_complete_chunk else byte_setter,
                     chunk_spec.prototype,
                 )
                 for byte_setter, chunk_spec, chunk_selection, _, is_complete_chunk in batch_info
-            ],
-            _read_key,
-            config.get("async.concurrency"),
+            ]
         )
         chunk_array_decoded = await self.decode_batch(
             [
@@ -435,15 +436,14 @@ async def _write_key(byte_setter: ByteSetter, chunk_bytes: Buffer | None) -> Non
             else:
                 await byte_setter.set(chunk_bytes)
 
-        await concurrent_map(
-            [
-                (byte_setter, chunk_bytes)
+        # Store handles concurrency limiting internally
+        await asyncio.gather(
+            *[
+                _write_key(byte_setter, chunk_bytes)
                 for chunk_bytes, (byte_setter, *_) in zip(
                     chunk_bytes_batch, batch_info, strict=False
                 )
-            ],
-            _write_key,
-            config.get("async.concurrency"),
+            ]
         )
 
     async def decode(
@@ -470,13 +470,12 @@ async def read(
         out: NDBuffer,
         drop_axes: tuple[int, ...] = (),
     ) -> None:
-        await concurrent_map(
-            [
-                (single_batch_info, out, drop_axes)
+        # Process mini-batches concurrently - stores handle I/O concurrency internally
+        await asyncio.gather(
+            *[
+                self.read_batch(single_batch_info, out, drop_axes)
                 for single_batch_info in batched(batch_info, self.batch_size)
-            ],
-            self.read_batch,
-            config.get("async.concurrency"),
+            ]
         )
 
     async def write(
@@ -485,13 +484,12 @@ async def write(
         value: NDBuffer,
         drop_axes: tuple[int, ...] = (),
     ) -> None:
-        await concurrent_map(
-            [
-                (single_batch_info, value, drop_axes)
+        # Process mini-batches concurrently - stores handle I/O concurrency internally
+        await asyncio.gather(
+            *[
+                self.write_batch(single_batch_info, value, drop_axes)
                 for single_batch_info in batched(batch_info, self.batch_size)
-            ],
-            self.write_batch,
-            config.get("async.concurrency"),
+            ]
         )
diff --git a/src/zarr/core/common.py b/src/zarr/core/common.py
index 9b3d297298..9682bfd60e 100644
--- a/src/zarr/core/common.py
+++ b/src/zarr/core/common.py
@@ -4,7 +4,9 @@
 import functools
 import math
 import operator
+import threading
 import warnings
+import weakref
 from collections.abc import Iterable, Mapping, Sequence
 from enum import Enum
 from itertools import starmap
@@ -98,15 +100,126 @@ def ceildiv(a: float, b: float) -> int:
 V = TypeVar("V")
 
 
+# Global semaphore management for per-process concurrency limiting
+# Use WeakKeyDictionary to automatically clean up semaphores when event loops are garbage collected
+_global_semaphores: weakref.WeakKeyDictionary[asyncio.AbstractEventLoop, asyncio.Semaphore] = (
+    weakref.WeakKeyDictionary()
+)
+# Use threading.Lock instead of asyncio.Lock to coordinate across event loops
+_global_semaphore_lock = threading.Lock()
+
+
+def get_global_semaphore() -> asyncio.Semaphore:
+    """
+    Get the global semaphore for the current event loop.
+
+    This ensures that all concurrent operations across the process share the same
+    concurrency limit, preventing excessive concurrent task creation when multiple
+    arrays or operations are running simultaneously.
+
+    The semaphore is lazily created per event loop and uses the configured
+    `async.concurrency` value from zarr config. The semaphore is cached per event
+    loop, so subsequent calls return the same semaphore instance.
+
+    Note: Config changes after the first call will not affect the semaphore limit.
+    To apply new config values, use :func:`reset_global_semaphores` to clear the cache.
+
+    Returns
+    -------
+    asyncio.Semaphore
+        The global semaphore for this event loop.
+
+    Raises
+    ------
+    RuntimeError
+        If called outside of an async context (no running event loop).
+
+    See Also
+    --------
+    reset_global_semaphores : Clear the global semaphore cache
+    """
+    loop = asyncio.get_running_loop()
+
+    # Acquire lock FIRST to prevent TOCTOU race condition
+    with _global_semaphore_lock:
+        if loop not in _global_semaphores:
+            limit = zarr_config.get("async.concurrency")
+            _global_semaphores[loop] = asyncio.Semaphore(limit)
+        return _global_semaphores[loop]
+
+
+def reset_global_semaphores() -> None:
+    """
+    Clear all cached global semaphores.
+
+    This is useful when you want config changes to take effect, or for testing.
+    The next call to :func:`get_global_semaphore` will create a new semaphore
+    using the current configuration.
+
+    Warning: This should only be called when no async operations are in progress,
+    as it will invalidate all existing semaphore references.
+
+    Examples
+    --------
+    >>> import zarr
+    >>> zarr.config.set({"async.concurrency": 50})
+    >>> reset_global_semaphores()  # Apply new config
+    """
+    with _global_semaphore_lock:
+        _global_semaphores.clear()
+
+
 async def concurrent_map(
     items: Iterable[T],
     func: Callable[..., Awaitable[V]],
     limit: int | None = None,
+    *,
+    use_global_semaphore: bool = True,
 ) -> list[V]:
-    if limit is None:
+    """
+    Execute an async function concurrently over multiple items with concurrency limiting.
+
+    Parameters
+    ----------
+    items : Iterable[T]
+        Items to process, where each item is a tuple of arguments to pass to func.
+    func : Callable[..., Awaitable[V]]
+        Async function to execute for each item.
+    limit : int | None, optional
+        If provided and use_global_semaphore is False, creates a local semaphore
+        with this limit. If None, no concurrency limiting is applied.
+    use_global_semaphore : bool, default True
+        If True, uses the global per-process semaphore for concurrency limiting,
+        ensuring all concurrent operations share the same limit. If False, uses
+        the `limit` parameter for local limiting (legacy behavior).
+
+    Returns
+    -------
+    list[V]
+        Results from executing func on all items.
+    """
+    if use_global_semaphore:
+        if limit is not None:
+            raise ValueError(
+                "Cannot specify both use_global_semaphore=True and a limit value. "
+                "Either use the global semaphore (use_global_semaphore=True, limit=None) "
+                "or specify a local limit (use_global_semaphore=False, limit=<int>)."
+            )
+        # Use the global semaphore for process-wide concurrency limiting
+        sem = get_global_semaphore()
+
+        async def run(item: tuple[Any]) -> V:
+            async with sem:
+                return await func(*item)
+
+        return await asyncio.gather(*[asyncio.ensure_future(run(item)) for item in items])
+
+    elif limit is None:
+        # No concurrency limiting
         return await asyncio.gather(*list(starmap(func, items)))
     else:
+        # Legacy mode: create local semaphore with specified limit
         sem = asyncio.Semaphore(limit)
 
         async def run(item: tuple[Any]) -> V:
diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py
index 9b5fee275b..50b57a569f 100644
--- a/src/zarr/core/group.py
+++ b/src/zarr/core/group.py
@@ -44,6 +44,7 @@
     NodeType,
     ShapeLike,
     ZarrFormat,
+    get_global_semaphore,
     parse_shapelike,
 )
 from zarr.core.config import config
@@ -1440,8 +1441,8 @@ async def _members(
             )
             raise ValueError(msg)
 
-        # enforce a concurrency limit by passing a semaphore to all the recursive functions
-        semaphore = asyncio.Semaphore(config.get("async.concurrency"))
+        # Use global semaphore for process-wide concurrency limiting
+        semaphore = get_global_semaphore()
         async for member in _iter_members_deep(
             self,
             max_depth=max_depth,
@@ -3323,9 +3324,8 @@ async def create_nodes(
         The created nodes in the order they are created.
 
     """
-    # Note: the only way to alter this value is via the config. If that's undesirable for some reason,
-    # then we should consider adding a keyword argument this this function
-    semaphore = asyncio.Semaphore(config.get("async.concurrency"))
+    # Use global semaphore for process-wide concurrency limiting
+    semaphore = get_global_semaphore()
 
     create_tasks: list[Coroutine[None, None, str]] = []
 
     for key, value in nodes.items():
diff --git a/src/zarr/storage/_fsspec.py b/src/zarr/storage/_fsspec.py
index 7945fba467..e1ca718784 100644
--- a/src/zarr/storage/_fsspec.py
+++ b/src/zarr/storage/_fsspec.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import asyncio
 import json
 import warnings
 from contextlib import suppress
@@ -17,6 +18,7 @@
 from zarr.core.buffer import Buffer
 from zarr.errors import ZarrUserWarning
 from zarr.storage._common import _dereference_path
+from zarr.storage._utils import with_concurrency_limit
 
 if TYPE_CHECKING:
     from collections.abc import AsyncIterator, Iterable
@@ -82,6 +84,9 @@ class FsspecStore(Store):
         filesystem scheme.
     allowed_exceptions : tuple[type[Exception], ...]
         When fetching data, these cases will be deemed to correspond to missing keys.
+    concurrency_limit : int, optional
+        Maximum number of concurrent I/O operations. Default is 50.
+        Set to None for unlimited concurrency.
 
     Attributes
     ----------
@@ -117,18 +122,24 @@ class FsspecStore(Store):
     fs: AsyncFileSystem
     allowed_exceptions: tuple[type[Exception], ...]
     path: str
+    _semaphore: asyncio.Semaphore | None
 
     def __init__(
         self,
         fs: AsyncFileSystem,
+        *,
         read_only: bool = False,
         path: str = "/",
         allowed_exceptions: tuple[type[Exception], ...] = ALLOWED_EXCEPTIONS,
+        concurrency_limit: int | None = 50,
     ) -> None:
         super().__init__(read_only=read_only)
         self.fs = fs
         self.path = path
         self.allowed_exceptions = allowed_exceptions
+        self._semaphore = (
+            asyncio.Semaphore(concurrency_limit) if concurrency_limit is not None else None
+        )
 
         if not self.fs.async_impl:
             raise TypeError("Filesystem needs to support async operations.")
@@ -273,6 +284,7 @@ def __eq__(self, other: object) -> bool:
             and self.fs == other.fs
         )
 
+    @with_concurrency_limit()
     async def get(
         self,
         key: str,
@@ -315,6 +327,7 @@ async def get(
         else:
             return value
 
+    @with_concurrency_limit()
     async def set(
         self,
         key: str,
@@ -335,6 +348,27 @@ async def set(
             raise NotImplementedError
         await self.fs._pipe_file(path, value.to_bytes())
 
+    async def _set_many(self, values: Iterable[tuple[str, Buffer]]) -> None:
+        # Override to avoid deadlock from calling decorated set() method
+        if not self._is_open:
+            await self._open()
+        self._check_writable()
+
+        async def _set_with_limit(key: str, value: Buffer) -> None:
+            if not isinstance(value, Buffer):
+                raise TypeError(
+                    f"FsspecStore.set(): `value` must be a Buffer instance. Got an instance of {type(value)} instead."
+                )
+            path = _dereference_path(self.path, key)
+            if self._semaphore:
+                async with self._semaphore:
+                    await self.fs._pipe_file(path, value.to_bytes())
+            else:
+                await self.fs._pipe_file(path, value.to_bytes())
+
+        await asyncio.gather(*[_set_with_limit(key, value) for key, value in values])
+
+    @with_concurrency_limit()
     async def delete(self, key: str) -> None:
         # docstring inherited
         self._check_writable()
diff --git a/src/zarr/storage/_local.py b/src/zarr/storage/_local.py
index f64da71bb4..ea48c756d3 100644
--- a/src/zarr/storage/_local.py
+++ b/src/zarr/storage/_local.py
@@ -19,12 +19,13 @@
 )
 from zarr.core.buffer import Buffer
 from zarr.core.buffer.core import default_buffer_prototype
-from zarr.core.common import AccessModeLiteral, concurrent_map
+from zarr.storage._utils import with_concurrency_limit
 
 if TYPE_CHECKING:
     from collections.abc import AsyncIterator, Iterable, Iterator
 
     from zarr.core.buffer import BufferPrototype
+    from zarr.core.common import AccessModeLiteral
 
 
 def _get(path: Path, prototype: BufferPrototype, byte_range: ByteRequest | None) -> Buffer:
@@ -95,6 +96,9 @@ class LocalStore(Store):
         Directory to use as root of store.
     read_only : bool
         Whether the store is read-only
+    concurrency_limit : int, optional
+        Maximum number of concurrent I/O operations. Default is 100.
+        Set to None for unlimited concurrency.
 
     Attributes
     ----------
@@ -109,8 +113,15 @@ class LocalStore(Store):
     supports_listing: bool = True
 
     root: Path
+    _semaphore: asyncio.Semaphore | None
 
-    def __init__(self, root: Path | str, *, read_only: bool = False) -> None:
+    def __init__(
+        self,
+        root: Path | str,
+        *,
+        read_only: bool = False,
+        concurrency_limit: int | None = 100,
+    ) -> None:
         super().__init__(read_only=read_only)
         if isinstance(root, str):
             root = Path(root)
@@ -119,12 +130,17 @@ def __init__(self, root: Path | str, *, read_only: bool = False) -> None:
                 f"'root' must be a string or Path instance. Got an instance of {type(root)} instead."
             )
         self.root = root
+        self._semaphore = (
+            asyncio.Semaphore(concurrency_limit) if concurrency_limit is not None else None
+        )
 
     def with_read_only(self, read_only: bool = False) -> Self:
         # docstring inherited
+        concurrency_limit = self._semaphore._value if self._semaphore else None
         return type(self)(
             root=self.root,
             read_only=read_only,
+            concurrency_limit=concurrency_limit,
         )
 
     @classmethod
@@ -187,6 +203,7 @@ def __repr__(self) -> str:
     def __eq__(self, other: object) -> bool:
         return isinstance(other, type(self)) and self.root == other.root
 
+    @with_concurrency_limit()
     async def get(
         self,
         key: str,
@@ -212,12 +229,23 @@ async def get_partial_values(
         key_ranges: Iterable[tuple[str, ByteRequest | None]],
     ) -> list[Buffer | None]:
         # docstring inherited
-        args = []
-        for key, byte_range in key_ranges:
-            assert isinstance(key, str)
+        # Note: We directly call the I/O functions here, wrapped with semaphore
+        # to avoid deadlock from calling the decorated get() method
+
+        async def _get_with_limit(key: str, byte_range: ByteRequest | None) -> Buffer | None:
             path = self.root / key
-            args.append((_get, path, prototype, byte_range))
-        return await concurrent_map(args, asyncio.to_thread, limit=None)  # TODO: fix limit
+            try:
+                if self._semaphore:
+                    async with self._semaphore:
+                        return await asyncio.to_thread(_get, path, prototype, byte_range)
+                else:
+                    return await asyncio.to_thread(_get, path, prototype, byte_range)
+            except (FileNotFoundError, IsADirectoryError, NotADirectoryError):
+                return None
+
+        return await asyncio.gather(
+            *[_get_with_limit(key, byte_range) for key, byte_range in key_ranges]
+        )
 
     async def set(self, key: str, value: Buffer) -> None:
         # docstring inherited
@@ -230,6 +258,7 @@ async def set_if_not_exists(self, key: str, value: Buffer) -> None:
         except FileExistsError:
             pass
 
+    @with_concurrency_limit()
     async def _set(self, key: str, value: Buffer, exclusive: bool = False) -> None:
         if not self._is_open:
             await self._open()
@@ -242,6 +271,7 @@ async def _set(self, key: str, value: Buffer, exclusive: bool = False) -> None:
         path = self.root / key
         await asyncio.to_thread(_put, path, value, exclusive=exclusive)
 
+    @with_concurrency_limit()
     async def delete(self, key: str) -> None:
         """
         Remove a key from the store.
diff --git a/src/zarr/storage/_memory.py b/src/zarr/storage/_memory.py
index 904be922d7..be222c96b7 100644
--- a/src/zarr/storage/_memory.py
+++ b/src/zarr/storage/_memory.py
@@ -1,12 +1,12 @@
 from __future__ import annotations
 
+import asyncio
 from logging import getLogger
 from typing import TYPE_CHECKING, Self
 
 from zarr.abc.store import ByteRequest, Store
 from zarr.core.buffer import Buffer, gpu
 from zarr.core.buffer.core import default_buffer_prototype
-from zarr.core.common import concurrent_map
 from zarr.storage._utils import _normalize_byte_range_index
 
 if TYPE_CHECKING:
@@ -102,12 +102,10 @@ async def get_partial_values(
         key_ranges: Iterable[tuple[str, ByteRequest | None]],
     ) -> list[Buffer | None]:
         # docstring inherited
-
-        # All the key-ranges arguments goes with the same prototype
-        async def _get(key: str, byte_range: ByteRequest | None) -> Buffer | None:
-            return await self.get(key, prototype=prototype, byte_range=byte_range)
-
-        return await concurrent_map(key_ranges, _get, limit=None)
+        # In-memory operations are fast and don't need concurrency limiting
+        return await asyncio.gather(
+            *[self.get(key, prototype, byte_range) for key, byte_range in key_ranges]
+        )
 
     async def exists(self, key: str) -> bool:
         # docstring inherited
diff --git a/src/zarr/storage/_obstore.py b/src/zarr/storage/_obstore.py
index 5c2197ecf6..223142d371 100644
--- a/src/zarr/storage/_obstore.py
+++ b/src/zarr/storage/_obstore.py
@@ -13,8 +13,8 @@
     Store,
     SuffixByteRequest,
 )
-from zarr.core.common import concurrent_map
-from zarr.core.config import config
+from zarr.core.common import get_global_semaphore
+from zarr.storage._utils import with_concurrency_limit
 
 if TYPE_CHECKING:
     from collections.abc import AsyncGenerator, Coroutine, Iterable, Sequence
@@ -47,6 +47,9 @@ class ObjectStore(Store, Generic[T_Store]):
         An obstore store instance that is set up with the proper credentials.
     read_only : bool
         Whether to open the store in read-only mode.
+    concurrency_limit : int, optional
+        Maximum number of concurrent I/O operations. Default is 50.
+        Set to None for unlimited concurrency.
 
     Warnings
     --------
@@ -56,6 +59,7 @@ class ObjectStore(Store, Generic[T_Store]):
 
     store: T_Store
     """The underlying obstore instance."""
+    _semaphore: asyncio.Semaphore | None
 
     def __eq__(self, value: object) -> bool:
         if not isinstance(value, ObjectStore):
@@ -66,17 +70,28 @@ def __eq__(self, value: object) -> bool:
 
         return self.store == value.store  # type: ignore[no-any-return]
 
-    def __init__(self, store: T_Store, *, read_only: bool = False) -> None:
+    def __init__(
+        self,
+        store: T_Store,
+        *,
+        read_only: bool = False,
+        concurrency_limit: int | None = 50,
+    ) -> None:
         if not store.__class__.__module__.startswith("obstore"):
             raise TypeError(f"expected ObjectStore class, got {store!r}")
         super().__init__(read_only=read_only)
         self.store = store
+        self._semaphore = (
+            asyncio.Semaphore(concurrency_limit) if concurrency_limit is not None else None
+        )
 
     def with_read_only(self, read_only: bool = False) -> Self:
         # docstring inherited
+        concurrency_limit = self._semaphore._value if self._semaphore else None
        return type(self)(
             store=self.store,
             read_only=read_only,
+            concurrency_limit=concurrency_limit,
         )
 
     def __str__(self) -> str:
@@ -94,6 +109,7 @@ def __setstate__(self, state: dict[Any, Any]) -> None:
         state["store"] = pickle.loads(state["store"])
         self.__dict__.update(state)
 
+    @with_concurrency_limit()
     async def get(
         self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None = None
     ) -> Buffer | None:
@@ -101,41 +117,7 @@ async def get(
         import obstore as obs
 
         try:
-            if byte_range is None:
-                resp = await obs.get_async(self.store, key)
-                return prototype.buffer.from_bytes(await resp.bytes_async())  # type: ignore[arg-type]
-            elif isinstance(byte_range, RangeByteRequest):
-                bytes = await obs.get_range_async(
-                    self.store, key, start=byte_range.start, end=byte_range.end
-                )
-                return prototype.buffer.from_bytes(bytes)  # type: ignore[arg-type]
-            elif isinstance(byte_range, OffsetByteRequest):
-                resp = await obs.get_async(
-                    self.store, key, options={"range": {"offset": byte_range.offset}}
-                )
-                return prototype.buffer.from_bytes(await resp.bytes_async())  # type: ignore[arg-type]
-            elif isinstance(byte_range, SuffixByteRequest):
-                # some object stores (Azure) don't support suffix requests. In this
-                # case, our workaround is to first get the length of the object and then
-                # manually request the byte range at the end.
-                try:
-                    resp = await obs.get_async(
-                        self.store, key, options={"range": {"suffix": byte_range.suffix}}
-                    )
-                    return prototype.buffer.from_bytes(await resp.bytes_async())  # type: ignore[arg-type]
-                except obs.exceptions.NotSupportedError:
-                    head_resp = await obs.head_async(self.store, key)
-                    file_size = head_resp["size"]
-                    suffix_len = byte_range.suffix
-                    buffer = await obs.get_range_async(
-                        self.store,
-                        key,
-                        start=file_size - suffix_len,
-                        length=suffix_len,
-                    )
-                    return prototype.buffer.from_bytes(buffer)  # type: ignore[arg-type]
-            else:
-                raise ValueError(f"Unexpected byte_range, got {byte_range}")
+            return await self._get_impl(key, prototype, byte_range, obs)
         except _ALLOWED_EXCEPTIONS:
             return None
 
@@ -145,7 +127,60 @@ async def get_partial_values(
         key_ranges: Iterable[tuple[str, ByteRequest | None]],
     ) -> list[Buffer | None]:
         # docstring inherited
-        return await _get_partial_values(self.store, prototype=prototype, key_ranges=key_ranges)
+        # Note: We directly call obs operations here, wrapped with semaphore
+        # to avoid deadlock from calling the decorated get() method
+        import obstore as obs
+
+        async def _get_with_limit(key: str, byte_range: ByteRequest | None) -> Buffer | None:
+            try:
+                if self._semaphore:
+                    async with self._semaphore:
+                        return await self._get_impl(key, prototype, byte_range, obs)
+                else:
+                    return await self._get_impl(key, prototype, byte_range, obs)
+            except _ALLOWED_EXCEPTIONS:
+                return None
+
+        return await asyncio.gather(
+            *[_get_with_limit(key, byte_range) for key, byte_range in key_ranges]
+        )
+
+    async def _get_impl(
+        self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None, obs: Any
+    ) -> Buffer:
+        """Implementation of get without semaphore decoration."""
+        if byte_range is None:
+            resp = await obs.get_async(self.store, key)
+            return prototype.buffer.from_bytes(await resp.bytes_async())  # type: ignore[arg-type]
+        elif isinstance(byte_range, RangeByteRequest):
+            bytes = await obs.get_range_async(
+                self.store, key, start=byte_range.start, end=byte_range.end
+            )
+            return prototype.buffer.from_bytes(bytes)  # type: ignore[arg-type]
+        elif isinstance(byte_range, OffsetByteRequest):
+            resp = await obs.get_async(
+                self.store, key, options={"range": {"offset": byte_range.offset}}
+            )
+            return prototype.buffer.from_bytes(await resp.bytes_async())  # type: ignore[arg-type]
+        elif isinstance(byte_range, SuffixByteRequest):
+            # some object stores (Azure) don't support suffix requests. In this
+            # case, our workaround is to first get the length of the object and then
+            # manually request the byte range at the end.
+            try:
+                resp = await obs.get_async(
+                    self.store, key, options={"range": {"suffix": byte_range.suffix}}
+                )
+                return prototype.buffer.from_bytes(await resp.bytes_async())  # type: ignore[arg-type]
+            except obs.exceptions.NotSupportedError:
+                head_resp = await obs.head_async(self.store, key)
+                file_size = head_resp["size"]
+                suffix_len = byte_range.suffix
+                buffer = await obs.get_range_async(
+                    self.store,
+                    key,
+                    start=file_size - suffix_len,
+                    length=suffix_len,
+                )
+                return prototype.buffer.from_bytes(buffer)  # type: ignore[arg-type]
+        else:
+            raise ValueError(f"Unexpected byte_range, got {byte_range}")
 
     async def exists(self, key: str) -> bool:
         # docstring inherited
@@ -163,6 +198,7 @@ def supports_writes(self) -> bool:
         # docstring inherited
         return True
 
+    @with_concurrency_limit()
     async def set(self, key: str, value: Buffer) -> None:
         # docstring inherited
         import obstore as obs
@@ -172,20 +208,43 @@ async def set(self, key: str, value: Buffer) -> None:
         buf = value.as_buffer_like()
         await obs.put_async(self.store, key, buf)
 
+    async def _set_many(self, values: Iterable[tuple[str, Buffer]]) -> None:
+        # Override to avoid deadlock from calling decorated set() method
+        import obstore as obs
+
+        self._check_writable()
+
+        async def _set_with_limit(key: str, value: Buffer) -> None:
+            buf = value.as_buffer_like()
+            if self._semaphore:
+                async with self._semaphore:
+                    await obs.put_async(self.store, key, buf)
+            else:
+                await obs.put_async(self.store, key, buf)
+
+        await asyncio.gather(*[_set_with_limit(key, value) for key, value in values])
+
     async def set_if_not_exists(self, key: str, value: Buffer) -> None:
         # docstring inherited
+        # Note: Not decorated to avoid deadlock when called in batch via gather()
         import obstore as obs
 
         self._check_writable()
         buf = value.as_buffer_like()
-        with contextlib.suppress(obs.exceptions.AlreadyExistsError):
-            await obs.put_async(self.store, key, buf, mode="create")
+        if self._semaphore:
+            async with self._semaphore:
+                with contextlib.suppress(obs.exceptions.AlreadyExistsError):
+                    await obs.put_async(self.store, key, buf, mode="create")
+        else:
+            with contextlib.suppress(obs.exceptions.AlreadyExistsError):
+                await obs.put_async(self.store, key, buf, mode="create")
 
     @property
     def supports_deletes(self) -> bool:
         # docstring inherited
         return True
 
+    @with_concurrency_limit()
     async def delete(self, key: str) -> None:
         # docstring inherited
         import obstore as obs
@@ -208,8 +267,18 @@ async def delete_dir(self, prefix: str) -> None:
             prefix += "/"
         metas = await obs.list(self.store, prefix).collect_async()
-        keys = [(m["path"],) for m in metas]
-        await concurrent_map(keys, self.delete, limit=config.get("async.concurrency"))
+
+        # Delete with semaphore limiting to avoid deadlock
+        async def _delete_with_limit(path: str) -> None:
+            if self._semaphore:
+                async with self._semaphore:
+                    with contextlib.suppress(FileNotFoundError):
+                        await obs.delete_async(self.store, path)
+            else:
+                with contextlib.suppress(FileNotFoundError):
+                    await obs.delete_async(self.store, path)
+
+        await asyncio.gather(*[_delete_with_limit(m["path"]) for m in metas])
 
     @property
     def supports_listing(self) -> bool:
@@ -485,7 +554,8 @@ async def _get_partial_values(
         else:
             raise ValueError(f"Unsupported range input: {byte_range}")
 
-    semaphore = asyncio.Semaphore(config.get("async.concurrency"))
+    # Use global semaphore for process-wide concurrency limiting
+    semaphore = get_global_semaphore()
 
     futs: list[Coroutine[Any, Any, list[_Response]]] = []
     for path, bounded_ranges in per_file_bounded_requests.items():
diff --git a/src/zarr/storage/_utils.py b/src/zarr/storage/_utils.py
index 39c28d44c3..d156a06891 100644
--- a/src/zarr/storage/_utils.py
+++ b/src/zarr/storage/_utils.py
@@ -1,17 +1,85 @@
 from __future__ import annotations
 
+import functools
 import re
 from pathlib import Path
-from typing import TYPE_CHECKING, TypeVar
+from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar
 
 from zarr.abc.store import OffsetByteRequest, RangeByteRequest, SuffixByteRequest
 
 if TYPE_CHECKING:
-    from collections.abc import Iterable, Mapping
+    import asyncio
+    from collections.abc import Callable, Coroutine, Iterable, Mapping
 
     from zarr.abc.store import ByteRequest
     from zarr.core.buffer import Buffer
 
+P = ParamSpec("P")
+T_co = TypeVar("T_co", covariant=True)
+
+
+def with_concurrency_limit(
+    semaphore_attr: str = "_semaphore",
+) -> Callable[[Callable[P, Coroutine[Any, Any, T_co]]], Callable[P, Coroutine[Any, Any, T_co]]]:
+    """
+    Decorator that applies a semaphore-based concurrency limit to an async method.
+
+    This decorator is designed for Store methods that need to limit concurrent operations.
+    The store instance should have a `_semaphore` attribute (or custom attribute name)
+    that is either an asyncio.Semaphore or None (for unlimited concurrency).
+
+    Parameters
+    ----------
+    semaphore_attr : str, optional
+        Name of the semaphore attribute on the class instance. Default is "_semaphore".
+
+    Returns
+    -------
+    Callable
+        The decorated async function with concurrency limiting applied.
+
+    Examples
+    --------
+    ```python
+    class MyStore(Store):
+        def __init__(self, concurrency_limit: int = 100):
+            self._semaphore = asyncio.Semaphore(concurrency_limit) if concurrency_limit else None
+
+        @with_concurrency_limit()
+        async def get(self, key: str) -> Buffer | None:
+            # This will only run when semaphore permits
+            return await expensive_io_operation(key)
+    ```
+    """
+
+    def decorator(
+        func: Callable[P, Coroutine[Any, Any, T_co]],
+    ) -> Callable[P, Coroutine[Any, Any, T_co]]:
+        """
+        This decorator wraps the invocation of `func` in an `async with semaphore` context manager.
+        The semaphore object is resolved by getting the `semaphore_attr` attribute from the first
+        argument to func. When this decorator is used on a method of a class, that first argument
+        is a reference to the class instance (`self`).
+        """
+
+        @functools.wraps(func)
+        async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T_co:
+            # First arg should be 'self'
+            if not args:
+                raise TypeError(f"{func.__name__} requires at least one argument (self)")
+
+            self = args[0]
+
+            semaphore: asyncio.Semaphore | None = getattr(self, semaphore_attr)
+
+            # A semaphore of None means unlimited concurrency
+            if semaphore is None:
+                return await func(*args, **kwargs)
+
+            # Apply concurrency limit
+            async with semaphore:
+                return await func(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
 
 
 def normalize_path(path: str | bytes | Path | None) -> str:
     if path is None:
diff --git a/src/zarr/testing/store_concurrency.py b/src/zarr/testing/store_concurrency.py
new file mode 100644
index 0000000000..06cf23857d
--- /dev/null
+++ b/src/zarr/testing/store_concurrency.py
@@ -0,0 +1,247 @@
+"""Base test class for store concurrency limiting behavior."""
+
+from __future__ import annotations
+
+import asyncio
+from typing import TYPE_CHECKING, Generic, TypeVar
+
+import pytest
+
+from zarr.core.buffer import Buffer, default_buffer_prototype
+
+if TYPE_CHECKING:
+    from zarr.abc.store import Store
+
+__all__ = ["StoreConcurrencyTests"]
+
+
+S = TypeVar("S", bound="Store")
+B = TypeVar("B", bound="Buffer")
+
+
+class StoreConcurrencyTests(Generic[S, B]):
+    """Base class for testing store concurrency limiting behavior.
+
+    This mixin provides tests for verifying that stores correctly implement
+    concurrency limiting.
+
+    Subclasses should set:
+    - store_cls: The store class being tested
+    - buffer_cls: The buffer class to use (e.g., cpu.Buffer)
+    - expected_concurrency_limit: Expected default concurrency limit (or None for unlimited)
+    """
+
+    store_cls: type[S]
+    buffer_cls: type[B]
+    expected_concurrency_limit: int | None
+
+    @pytest.fixture
+    async def store(self, store_kwargs: dict) -> S:
+        """Create and open a store instance."""
+        return await self.store_cls.open(**store_kwargs)
+
+    def test_concurrency_limit_default(self, store: S) -> None:
+        """Test that store has the expected default concurrency limit."""
+        if hasattr(store, "_semaphore"):
+            if self.expected_concurrency_limit is None:
+                assert store._semaphore is None, "Expected no concurrency limit"
+            else:
+                assert store._semaphore is not None, "Expected concurrency limit to be set"
+                assert store._semaphore._value == self.expected_concurrency_limit, (
+                    f"Expected limit {self.expected_concurrency_limit}, got {store._semaphore._value}"
+                )
+
+    def test_concurrency_limit_custom(self, store_kwargs: dict) -> None:
+        """Test that custom concurrency limits can be set."""
+        if "concurrency_limit" not in self.store_cls.__init__.__code__.co_varnames:
+            pytest.skip("Store does not support custom concurrency limits")
+
+        # Test with custom limit
+        store = self.store_cls(**store_kwargs, concurrency_limit=42)
+        if hasattr(store, "_semaphore"):
+            assert store._semaphore is not None
+            assert store._semaphore._value == 42
+
+        # Test with None (unlimited)
+        store = self.store_cls(**store_kwargs, concurrency_limit=None)
+        if hasattr(store, "_semaphore"):
+            assert store._semaphore is None
+
+    async def test_concurrency_limit_enforced(self, store: S) -> None:
+        """Test that the concurrency limit is actually enforced during execution.
+
+        This test verifies that when many operations are submitted concurrently,
+        only up to the concurrency limit are actually executing at once.
+        """
+        if not hasattr(store, "_semaphore") or store._semaphore is None:
+            pytest.skip("Store has no concurrency limit")
+
+        limit = store._semaphore._value
+
+        # We'll monitor the semaphore's available count
+        # When it reaches 0, that means `limit` operations are running
+        min_available = limit
+
+        async def monitored_operation(key: str, value: B) -> None:
+            nonlocal min_available
+            # Check semaphore state right after we're scheduled
+            await asyncio.sleep(0)  # Yield to ensure we're in the queue
+            available = store._semaphore._value
+            min_available = min(min_available, available)
+
+            # Now do the actual operation (which will acquire the semaphore)
+            await store.set(key, value)
+
+        # Launch more operations than the limit to ensure contention
+        num_ops = limit * 2
+        items = [
+            (f"limit_test_key_{i}", self.buffer_cls.from_bytes(f"value_{i}".encode()))
+            for i in range(num_ops)
+        ]
+
+        await asyncio.gather(*[monitored_operation(k, v) for k, v in items])
+
+        # The semaphore should have been fully utilized (reached 0 or close to it)
+        # This indicates that `limit` operations were running concurrently
+        assert min_available < limit, (
+            f"Semaphore was never fully utilized. "
+            f"Min available: {min_available}, Limit: {limit}. "
+            f"This suggests operations aren't running concurrently."
+        )
+
+        # Ideally it should reach 0, but allow some slack for timing
+        assert min_available <= 5, (
+            f"Semaphore only reached {min_available} available slots. "
+            f"Expected close to 0 with limit {limit}."
+        )
+
+    async def test_batch_write_no_deadlock(self, store: S) -> None:
+        """Test that batch writes don't deadlock when exceeding concurrency limit."""
+        # Create more items than any reasonable concurrency limit
+        num_items = 200
+        items = [
+            (f"test_key_{i}", self.buffer_cls.from_bytes(f"test_value_{i}".encode()))
+            for i in range(num_items)
+        ]
+
+        # This should complete without deadlock, even if num_items > concurrency_limit
+        await asyncio.wait_for(store._set_many(items), timeout=30.0)
+
+        # Verify all items were written correctly
+        for key, expected_value in items:
+            result = await store.get(key, default_buffer_prototype())
+            assert result is not None
+            assert result.to_bytes() == expected_value.to_bytes()
+
+    async def test_batch_read_no_deadlock(self, store: S) -> None:
+        """Test that batch reads don't deadlock when exceeding concurrency limit."""
+        # Write test data
+        num_items = 200
+        test_data = {
+            f"test_key_{i}": self.buffer_cls.from_bytes(f"test_value_{i}".encode())
+            for i in range(num_items)
+        }
+
+        for key, value in test_data.items():
+            await store.set(key, value)
+
+        # Read all items concurrently - should not deadlock
+        keys_and_ranges = [(key, None) for key in test_data]
+        results = await asyncio.wait_for(
+            store.get_partial_values(default_buffer_prototype(), keys_and_ranges),
+            timeout=30.0,
+        )
+
+        # Verify results
+        assert len(results) == num_items
+        for result, (key, expected_value) in zip(results, test_data.items()):
+            assert result is not None
+            assert result.to_bytes() == expected_value.to_bytes()
+
+    async def test_batch_delete_no_deadlock(self, store: S) -> None:
+        """Test that batch deletes don't deadlock when exceeding concurrency limit."""
+        if not store.supports_deletes:
+            pytest.skip("Store does not support deletes")
+
+        # Write test data
+        num_items = 200
+        keys = [f"test_key_{i}" for i in range(num_items)]
+        for key in keys:
+            await store.set(key, self.buffer_cls.from_bytes(b"test_value"))
+
+        # Delete all items concurrently - should not deadlock
+        await asyncio.wait_for(asyncio.gather(*[store.delete(key) for key in keys]), timeout=30.0)
+
+        # Verify all items were deleted
+        for key in keys:
+            result = await store.get(key, default_buffer_prototype())
+            assert result is None
+
+    async def test_concurrent_operations_correctness(self, store: S) -> None:
+        """Test that concurrent operations produce correct results."""
+        num_operations = 100
+
+        # Mix of reads and writes
+        write_keys = [f"write_key_{i}" for i in range(num_operations)]
+        write_values = [
+            self.buffer_cls.from_bytes(f"value_{i}".encode()) for i in range(num_operations)
+        ]
+
+        # Write all concurrently
+        await asyncio.gather(*[store.set(k, v) for k, v in zip(write_keys, write_values)])
+
+        # Read all concurrently
+        results = await asyncio.gather(
+            *[store.get(k, default_buffer_prototype()) for k in write_keys]
+        )
+
+        # Verify correctness
+        for result, expected in zip(results, write_values):
+            assert result is not None
+            assert result.to_bytes() == expected.to_bytes()
+
+    @pytest.mark.parametrize("batch_size", [1, 10, 50, 100])
+    async def test_various_batch_sizes(self, store: S, batch_size: int) -> None:
+        """Test that various batch sizes work correctly."""
+        items = [
+            (f"batch_key_{i}", self.buffer_cls.from_bytes(f"batch_value_{i}".encode()))
+            for i in range(batch_size)
+        ]
+
+        # Should complete without issues for any batch size
+        await asyncio.wait_for(store._set_many(items), timeout=10.0)
+
+        # Verify
+        for key, expected_value in items:
+            result = await store.get(key, default_buffer_prototype())
+            assert result is not None
+            assert result.to_bytes() == expected_value.to_bytes()
+
+    async def test_empty_batch_operations(self, store: S) -> None:
+        """Test that empty batch operations don't cause issues."""
+        # Empty batch should not raise
+        await store._set_many([])
+
+        # Empty read batch
+        results = await store.get_partial_values(default_buffer_prototype(), [])
+        assert results == []
+
+    async def test_mixed_success_failure_batch(self, store: S) -> None:
+        """Test batch operations with mix of successful and failing items."""
+        # Write some initial data
+        await store.set("existing_key", self.buffer_cls.from_bytes(b"existing_value"))
+
+        # Try to read mix of existing and non-existing keys
+        key_ranges = [
+            ("existing_key", None),
+            ("non_existing_key_1", None),
+            ("non_existing_key_2", None),
+        ]
+
+        results = await store.get_partial_values(default_buffer_prototype(), key_ranges)
+
+        # First should exist, others should be None
+        assert results[0] is not None
+        assert results[0].to_bytes() == b"existing_value"
+        assert results[1] is None
+        assert results[2] is None
diff --git a/tests/test_global_concurrency.py b/tests/test_global_concurrency.py
new file mode 100644
index 0000000000..f6366e3c53
--- /dev/null
+++ b/tests/test_global_concurrency.py
@@ -0,0 +1,327 @@
+"""
+Tests for global per-process concurrency limiting.
+"""
+
+import asyncio
+from typing import Any
+
+import numpy as np
+import pytest
+
+import zarr
+from zarr.core.common import get_global_semaphore, reset_global_semaphores
+from zarr.core.config import config
+
+
+class TestGlobalSemaphore:
+    """Tests for the global semaphore management."""
+
+    async def test_get_global_semaphore_creates_per_loop(self) -> None:
+        """Test that each event loop gets its own semaphore."""
+        sem1 = get_global_semaphore()
+        assert sem1 is not None
+        assert isinstance(sem1, asyncio.Semaphore)
+
+        # Getting it again should return the same instance
+        sem2 = get_global_semaphore()
+        assert sem1 is sem2
+
+    async def test_global_semaphore_uses_config_limit(self) -> None:
+        """Test that the global semaphore respects the configured limit."""
+        # Set a custom concurrency limit
+        original_limit: Any = config.get("async.concurrency")
+        try:
+            config.set({"async.concurrency": 5})
+
+            # Clear existing semaphores to force recreation
+            reset_global_semaphores()
+
+            sem = get_global_semaphore()
+
+            # The semaphore should have the configured limit
+            # We can verify this by acquiring all tokens and checking the semaphore is locked
+            for i in range(5):
+                await sem.acquire()
+                if i < 4:
+                    assert not sem.locked()  # Should still have capacity
+                else:
+                    assert sem.locked()  # All tokens acquired, semaphore is now locked
+
+            # Release all tokens
+            for _ in range(5):
+                sem.release()
+
+        finally:
+            # Restore original config
+            config.set({"async.concurrency": original_limit})
+            # Clear semaphores again to reset state
+            reset_global_semaphores()
+
+    async def test_global_semaphore_shared_across_operations(self) -> None:
+        """Test that multiple concurrent operations share the same semaphore."""
+        # Track the maximum number of concurrent tasks
+        max_concurrent = 0
+        current_concurrent = 0
+        lock = asyncio.Lock()
+
+        async def tracked_operation() -> None:
+            """An operation that tracks concurrency."""
+            nonlocal max_concurrent, current_concurrent
+
+            async with lock:
+                current_concurrent += 1
+                max_concurrent = max(max_concurrent, current_concurrent)
+
+            # Small delay to ensure overlap
+            await asyncio.sleep(0.01)
+
+            async with lock:
+                current_concurrent -= 1
+
+        # Set a low concurrency limit to make the test observable
+        original_limit: Any = config.get("async.concurrency")
+        try:
+            config.set({"async.concurrency": 5})
+
+            # Clear existing semaphores
+            reset_global_semaphores()
+
+            # Get the global semaphore
+            sem = get_global_semaphore()
+
+            # Create many tasks that use the semaphore
+            async def task_with_semaphore() -> None:
+                async with sem:
+                    await tracked_operation()
+
+            # Launch 20 tasks (4x the limit)
+            tasks = [task_with_semaphore() for _ in range(20)]
+            await asyncio.gather(*tasks)
+
+            # Maximum concurrent should respect the limit
+            assert max_concurrent <= 5, f"Max concurrent was {max_concurrent}, expected <= 5"
+            assert max_concurrent >= 3, (
+                f"Max concurrent was {max_concurrent}, expected some concurrency"
+            )
+
+        finally:
+            config.set({"async.concurrency": original_limit})
+            reset_global_semaphores()
+
+    async def test_semaphore_reuse_across_calls(self) -> None:
+        """Test that repeated calls to get_global_semaphore return the same instance."""
+        reset_global_semaphores()
+
+        # Call multiple times and verify we get the same instance
+        sem1 = get_global_semaphore()
+        sem2 = get_global_semaphore()
+        sem3 = get_global_semaphore()
+
+        assert sem1 is sem2 is sem3, "Should return same semaphore instance on repeated calls"
+
+        # Verify it's still the same after using it
+        async with sem1:
+            sem4 = get_global_semaphore()
+            assert sem1 is sem4
+
+    def test_config_change_after_creation(self) -> None:
+        """Test and document that config changes don't affect existing semaphores."""
+        original_limit: Any = config.get("async.concurrency")
+        try:
+            # Set initial config
+            config.set({"async.concurrency": 5})
+
+            async def check_limit() -> None:
+                reset_global_semaphores()
+
+                # Create semaphore with limit=5
+                sem1 = get_global_semaphore()
+                initial_capacity: int = sem1._value
+
+                # Change config
+                config.set({"async.concurrency": 50})
+
+                # Get semaphore again - should be same instance with old limit
+                sem2 = get_global_semaphore()
+                assert sem1 is sem2, "Should return same semaphore instance"
+                assert sem2._value == initial_capacity, (
+                    f"Semaphore limit changed from {initial_capacity} to {sem2._value}. "
+                    "Config changes should not affect existing semaphores."
+                )
+
+                # Clean up
+                reset_global_semaphores()
+
+            asyncio.run(check_limit())
+
+        finally:
+            config.set({"async.concurrency": original_limit})
+
+
+class TestArrayConcurrency:
+    """Tests that array operations use global concurrency limiting."""
+
+    @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning")
+    async def test_multiple_arrays_share_concurrency_limit(self) -> None:
+        """Test that reading from multiple arrays shares the global concurrency limit."""
+        from zarr.core.common import concurrent_map
+
+        # Track concurrent task executions
+        max_concurrent_tasks = 0
+        current_concurrent_tasks = 0
+        task_lock = asyncio.Lock()
+
+        async def tracked_chunk_operation(chunk_id: int) -> int:
+            """Simulate a chunk operation with tracking."""
+            nonlocal max_concurrent_tasks, current_concurrent_tasks
+
+            async with task_lock:
+                current_concurrent_tasks += 1
+                max_concurrent_tasks = max(max_concurrent_tasks, current_concurrent_tasks)
+
+            # Small delay to simulate I/O
+            await asyncio.sleep(0.001)
+
+            async with task_lock:
+                current_concurrent_tasks -= 1
+
+            return chunk_id
+
+        # Set a low concurrency limit
+        original_limit: Any = config.get("async.concurrency")
+        try:
+            config.set({"async.concurrency": 10})
+
+            # Clear existing semaphores
+            reset_global_semaphores()
+
+            # Simulate reading many chunks using concurrent_map (which uses the global semaphore)
+            # This simulates what happens when reading from multiple arrays
+            chunk_ids = [(i,) for i in range(100)]
+            await concurrent_map(chunk_ids, tracked_chunk_operation)
+
+            # The maximum concurrent tasks should respect the global limit
+            assert max_concurrent_tasks <= 10, (
+                f"Max concurrent tasks was {max_concurrent_tasks}, expected <= 10"
+            )
+
+            assert max_concurrent_tasks >= 5, (
+                f"Max concurrent tasks was {max_concurrent_tasks}, "
+                f"expected at least some concurrency"
+            )
+
+        finally:
+            config.set({"async.concurrency": original_limit})
+            # Note: We don't reset_global_semaphores() here because doing so while
+            # many tasks are still cleaning up can trigger ResourceWarnings from
+            # asyncio internals. The semaphore will be reused by subsequent tests.
+
+    def test_sync_api_uses_global_concurrency(self) -> None:
+        """Test that synchronous API also benefits from global concurrency limiting."""
+        # This test verifies that the sync API (which wraps async) uses global limiting
+
+        # Set a low concurrency limit
+        original_limit: Any = config.get("async.concurrency")
+        try:
+            config.set({"async.concurrency": 8})
+
+            # Create a small array - the key is that zarr internally uses
+            # concurrent_map which now uses the global semaphore
+            store = zarr.storage.MemoryStore()
+            arr = zarr.create(
+                shape=(20, 20),
+                chunks=(10, 10),
+                dtype="i4",
+                store=store,
+                zarr_format=3,
+            )
+            arr[:] = 42
+
+            # Read data (synchronously)
+            data = arr[:]
+
+            # Verify we got the right data
+            assert np.all(data == 42)
+
+            # The test passes if no errors occurred
+            # The concurrency limiting is happening under the hood
+
+        finally:
+            config.set({"async.concurrency": original_limit})
+
+
+class TestConcurrentMapGlobal:
+    """Tests for concurrent_map using global semaphore."""
+
+    async def test_concurrent_map_uses_global_by_default(self) -> None:
+        """Test that concurrent_map uses global semaphore by default."""
+        from zarr.core.common import concurrent_map
+
+        # Track concurrent executions
+        max_concurrent = 0
+        current_concurrent = 0
+        lock = asyncio.Lock()
+
+        async def tracked_task(x: int) -> int:
+            nonlocal max_concurrent, current_concurrent
+
+            async with lock:
+                current_concurrent += 1
+                max_concurrent = max(max_concurrent, current_concurrent)
+
+            await asyncio.sleep(0.01)
+
+            async with lock:
+                current_concurrent -= 1
+
+            return x * 2
+
+        # Set a low limit
+        original_limit: Any = config.get("async.concurrency")
+        try:
+            config.set({"async.concurrency": 5})
+
+            # Clear existing semaphores
+            reset_global_semaphores()
+
+            # Use concurrent_map with default settings (use_global_semaphore=True)
+            items = [(i,) for i in range(20)]
+            results = await concurrent_map(items, tracked_task)
+
+            assert len(results) == 20
+            assert max_concurrent <= 5
+            assert max_concurrent >= 3  # Should have some concurrency
+
+        finally:
+            config.set({"async.concurrency": original_limit})
+            reset_global_semaphores()
+
+    async def test_concurrent_map_legacy_mode(self) -> None:
+        """Test that concurrent_map legacy mode still works."""
+        from zarr.core.common import concurrent_map
+
+        async def simple_task(x: int) -> int:
+            await asyncio.sleep(0.001)
+            return x * 2
+
+        # Use legacy mode with local limit
+        items = [(i,) for i in range(10)]
+        results = await concurrent_map(items, simple_task, limit=3, use_global_semaphore=False)
+
+        assert len(results) == 10
+        assert results == [i * 2 for i in range(10)]
+
+    async def test_concurrent_map_parameter_validation(self) -> None:
+        """Test that concurrent_map validates conflicting parameters."""
+        from zarr.core.common import concurrent_map
+
+        async def simple_task(x: int) -> int:
+            return x * 2
+
+        items = [(i,) for i in range(10)]
+
+        # Should raise ValueError when both limit and use_global_semaphore=True
+        with pytest.raises(
+            ValueError, match="Cannot specify both use_global_semaphore=True and a limit"
+        ):
+            await concurrent_map(items, simple_task, limit=5, use_global_semaphore=True)
diff --git a/tests/test_group.py b/tests/test_group.py
index 6f1f4e68fa..9f25036298 100644
--- a/tests/test_group.py
+++ b/tests/test_group.py
@@ -23,7 +23,6 @@
 from zarr.core import sync_group
 from zarr.core._info import GroupInfo
 from zarr.core.buffer import default_buffer_prototype
-from zarr.core.config import config as zarr_config
 from zarr.core.dtype.common import unpack_dtype_json
 from zarr.core.dtype.npy.int import UInt8
 from zarr.core.group import (
@@ -1738,29 +1737,6 @@ async def test_create_nodes(
     assert node_spec == {k: v.metadata for k, v in observed_nodes.items()}
 
 
-@pytest.mark.parametrize("store", ["memory"], indirect=True)
-def test_create_nodes_concurrency_limit(store: MemoryStore) -> None:
-    """
-    Test that the execution time of create_nodes can be constrained by the async concurrency
-    configuration setting.
-    """
-    set_latency = 0.02
-    num_groups = 10
-    groups = {str(idx): GroupMetadata() for idx in range(num_groups)}
-
-    latency_store = LatencyStore(store, set_latency=set_latency)
-
-    # check how long it takes to iterate over the groups
-    # if create_nodes is sensitive to IO latency,
-    # this should take (num_groups * get_latency) seconds
-    # otherwise, it should take only marginally more than get_latency seconds
-    with zarr_config.set({"async.concurrency": 1}):
-        start = time.time()
-        _ = tuple(sync_group.create_nodes(store=latency_store, nodes=groups))
-        elapsed = time.time() - start
-        assert elapsed > num_groups * set_latency
-
-
 @pytest.mark.parametrize(
     ("a_func", "b_func"),
     [
@@ -2250,38 +2226,6 @@ def test_group_members_performance(store: Store) -> None:
         assert elapsed < (num_groups * get_latency)
 
 
-@pytest.mark.parametrize("store", ["memory"], indirect=True)
-def test_group_members_concurrency_limit(store: MemoryStore) -> None:
-    """
-    Test that the execution time of Group.members can be constrained by the async concurrency
-    configuration setting.
-    """
-    get_latency = 0.02
-
-    # use the input store to create some groups
-    group_create = zarr.group(store=store)
-    num_groups = 10
-
-    # Create some groups
-    for i in range(num_groups):
-        group_create.create_group(f"group{i}")
-
-    latency_store = LatencyStore(store, get_latency=get_latency)
-    # create a group with some latency on get operations
-    group_read = zarr.group(store=latency_store)
-
-    # check how long it takes to iterate over the groups
-    # if .members is sensitive to IO latency,
-    # this should take (num_groups * get_latency) seconds
-    # otherwise, it should take only marginally more than get_latency seconds
-    with zarr_config.set({"async.concurrency": 1}):
-        start = time.time()
-        _ = group_read.members()
-        elapsed = time.time() - start
-
-    assert elapsed > num_groups * get_latency
-
-
 @pytest.mark.parametrize("option", ["array", "group", "invalid"])
 def test_build_metadata_v3(option: Literal["array", "group", "invalid"]) -> None:
     """
diff --git a/tests/test_store/test_local.py b/tests/test_store/test_local.py
index 6756bc83d9..73eec991f8 100644
--- a/tests/test_store/test_local.py
+++ b/tests/test_store/test_local.py
@@ -12,6 +12,7 @@
 from zarr.storage import LocalStore
 from zarr.storage._local import _atomic_write
 from zarr.testing.store import StoreTests
+from zarr.testing.store_concurrency import StoreConcurrencyTests
 from zarr.testing.utils import assert_bytes_equal
 
 
@@ -150,3 +151,15 @@ def test_atomic_write_exclusive_preexisting(tmp_path: pathlib.Path) -> None:
             f.write(b"abc")
     assert path.read_bytes() == b"xyz"
     assert list(path.parent.iterdir()) == [path]  # no temp files
+
+
+class TestLocalStoreConcurrency(StoreConcurrencyTests[LocalStore, cpu.Buffer]):
+    """Test LocalStore concurrency limiting behavior."""
+
+    store_cls = LocalStore
+    buffer_cls = cpu.Buffer
+    expected_concurrency_limit = 100  # LocalStore default
+
+    @pytest.fixture
+    def store_kwargs(self, tmpdir: str) -> dict[str, str]:
+        return {"root": str(tmpdir)}
diff --git a/tests/test_store/test_memory.py b/tests/test_store/test_memory.py
index 29fa9b2964..2222905745 100644
--- a/tests/test_store/test_memory.py
+++ b/tests/test_store/test_memory.py
@@ -12,6 +12,7 @@
 from zarr.errors import ZarrUserWarning
 from zarr.storage import GpuMemoryStore, MemoryStore
 from zarr.testing.store import StoreTests
+from zarr.testing.store_concurrency import StoreConcurrencyTests
 from zarr.testing.utils import gpu_test
 
 if TYPE_CHECKING:
@@ -130,3 +131,15 @@ def test_from_dict(self) -> None:
         result = GpuMemoryStore.from_dict(d)
         for v in result._store_dict.values():
             assert type(v) is gpu.Buffer
+
+
+class TestMemoryStoreConcurrency(StoreConcurrencyTests[MemoryStore, cpu.Buffer]):
+    """Test MemoryStore concurrency limiting behavior."""
+
+    store_cls = MemoryStore
+    buffer_cls = cpu.Buffer
+    expected_concurrency_limit = None  # MemoryStore has no limit (fast in-memory ops)
+
+    @pytest.fixture
+    def store_kwargs(self) -> dict[str, Any]:
+        return {"store_dict": None}
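
Reviewer note: below is a minimal usage sketch of the two concurrency knobs this diff introduces: the global per-event-loop semaphore driven by the `async.concurrency` config value, and the per-store `concurrency_limit` parameter. The store path is hypothetical; the config key, `reset_global_semaphores`, `LocalStore(concurrency_limit=...)`, and the `zarr.create` call all mirror code added or exercised above.

```python
import zarr
from zarr.core.common import reset_global_semaphores
from zarr.storage import LocalStore

# Process-wide limit: shared by every array/group operation on the current event loop.
zarr.config.set({"async.concurrency": 16})
reset_global_semaphores()  # clear the cached semaphore so the new limit takes effect

# Per-store I/O limit: enforced independently by the store's own semaphore.
# Pass concurrency_limit=None for unlimited store-level concurrency.
store = LocalStore("data/example.zarr", concurrency_limit=32)  # hypothetical path

arr = zarr.create(shape=(20, 20), chunks=(10, 10), dtype="i4", store=store, zarr_format=3)
arr[:] = 42  # chunk writes pass through both the global and the per-store limits
assert (arr[:] == 42).all()
```

The two layers compose: `concurrent_map` callers throttle task creation per event loop, while each store's semaphore caps in-flight I/O, which is what the deadlock-avoiding batch overrides (`_set_many`, `get_partial_values`) rely on.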