diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 7c567f7..6e37bb0 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -11,12 +11,13 @@ on: jobs: tests: - name: "python ${{ matrix.python-version }} tests" + name: "python=${{ matrix.python-version }} zarr=${{ matrix.zarr-version }} tests" runs-on: ubuntu-latest strategy: fail-fast: false matrix: - python-version: ["3.10", "3.11", "3.12", "3.13"] + python-version: ["3.11", "3.12", "3.13"] + zarr-version: [">=2,<3", ">=3"] steps: - name: Cancel previous uses: styfle/cancel-workflow-action@0.7.0 @@ -40,7 +41,7 @@ jobs: key: ${{ runner.os }}-pip-${{ hashFiles('**/setup.py') }} - name: Install Xarray-Tensorstore run: | - pip install -e .[tests] + pip install -e .[tests] "zarr${{ matrix.zarr-version }}" - name: Run unit tests run: | pytest . diff --git a/setup.py b/setup.py index 0130c3f..4956772 100644 --- a/setup.py +++ b/setup.py @@ -18,7 +18,7 @@ setuptools.setup( name='xarray-tensorstore', - version='0.2.0', # keep in sync with xarray_tensorstore.py + version='0.3.0', # keep in sync with xarray_tensorstore.py license='Apache-2.0', author='Google LLC', author_email='noreply@google.com', diff --git a/xarray_tensorstore.py b/xarray_tensorstore.py index e9e5001..92fbe9c 100644 --- a/xarray_tensorstore.py +++ b/xarray_tensorstore.py @@ -28,7 +28,7 @@ import zarr -__version__ = '0.2.0' # keep in sync with setup.py +__version__ = '0.3.0' # keep in sync with setup.py Index = TypeVar('Index', int, slice, np.ndarray, None) @@ -217,12 +217,49 @@ def _get_zarr_format(path: str) -> int: return 2 +def _open_tensorstore_arrays( + path: str, + names: list[str], + group: zarr.Group | None, + zarr_format: int, + write: bool, + context: tensorstore.Context | None = None, +) -> dict[str, tensorstore.Future]: + """Open all arrays in a Zarr group using TensorStore.""" + specs = { + k: _zarr_spec_from_path(os.path.join(path, k), zarr_format) for k in names + } + + assume_metadata = False + if packaging.version.parse(zarr.__version__).major >= 3 and group is not None: + consolidated_metadata = group.metadata.consolidated_metadata + if consolidated_metadata is not None: + assume_metadata = True + for name in names: + metadata = consolidated_metadata.metadata[name].to_dict() + metadata.pop('attributes', None) # not supported by TensorStore + specs[name]['metadata'] = metadata + + array_futures = {} + for k, spec in specs.items(): + array_futures[k] = tensorstore.open( + spec, + read=True, + write=write, + open=True, + context=context, + assume_metadata=assume_metadata, + ) + return array_futures + + def open_zarr( path: str, *, context: tensorstore.Context | None = None, mask_and_scale: bool = True, write: bool = False, + consolidated: bool | None = None, ) -> xarray.Dataset: """Open an xarray.Dataset from Zarr using TensorStore. @@ -252,6 +289,9 @@ def open_zarr( xarray.open_zarr(). This is only supported for coordinate variables and otherwise will raise an error. write: Allow write access. Defaults to False. + consolidated: If True, read consolidated metadata. By default, an attempt to + use consolidated metadata is made with a fallback to non-consolidated + metadata, like in Xarray. Returns: Dataset with all data variables opened via TensorStore. @@ -272,8 +312,19 @@ def open_zarr( if context is None: context = tensorstore.Context() - # chunks=None means avoid using dask - ds = xarray.open_zarr(path, chunks=None, mask_and_scale=mask_and_scale) + # Open Xarray's backends.ZarrStore directly so we can get access to the + # underlying Zarr group's consolidated metadata. + store = xarray.backends.ZarrStore.open_group( + path, consolidated=consolidated + ) + group = store.zarr_group + ds = xarray.open_dataset( + filename_or_obj='', # ignored in favor of store= + chunks=None, # avoid using dask + mask_and_scale=mask_and_scale, + store=store, + engine='zarr', + ) if mask_and_scale: # Data variables get replaced below with _TensorStoreAdapter arrays, which @@ -282,13 +333,9 @@ def open_zarr( _raise_if_mask_and_scale_used_for_data_vars(ds) zarr_format = _get_zarr_format(path) - specs = { - k: _zarr_spec_from_path(os.path.join(path, k), zarr_format) for k in ds - } - array_futures = { - k: tensorstore.open(spec, read=True, write=write, context=context) - for k, spec in specs.items() - } + array_futures = _open_tensorstore_arrays( + path, list(ds), group, zarr_format, write=write, context=context + ) arrays = {k: v.result() for k, v in array_futures.items()} new_data = {k: _TensorStoreAdapter(v) for k, v in arrays.items()} @@ -304,20 +351,26 @@ def _tensorstore_open_concatenated_zarrs( """Open multiple zarrs with TensorStore. Args: - paths: List of paths to zarr stores. - data_vars: List of data variable names to open. - concat_axes: List of axes along which to concatenate the data variables. - context: TensorStore context. + paths: List of paths to zarr stores. + data_vars: List of data variable names to open. + concat_axes: List of axes along which to concatenate the data variables. + context: TensorStore context. + + Returns: + Dictionary of data variable names to concatenated TensorStore arrays. """ # Open all arrays in all datasets using tensorstore arrays_list = [] for path in paths: zarr_format = _get_zarr_format(path) - specs = {k: _zarr_spec_from_path(os.path.join(path, k), zarr_format) for k in data_vars} - array_futures = { - k: tensorstore.open(spec, read=True, write=False, context=context) - for k, spec in specs.items() - } + # TODO(shoyer): Figure out how to support opening concatenated Zarrs with + # consolidated metadata. xarray.open_mfdataset() doesn't support opening + # from an existing store, so we'd have to replicate that functionality for + # figuring out the structure of the concatenated dataset. + group = None + array_futures = _open_tensorstore_arrays( + path, data_vars, group, zarr_format, write=False, context=context + ) arrays_list.append(array_futures) # Concatenate the tensorstore arrays @@ -354,11 +407,11 @@ def open_concatenated_zarrs( context = tensorstore.Context() ds = xarray.open_mfdataset( - paths, - concat_dim=concat_dim, - combine="nested", - mask_and_scale=mask_and_scale, - engine="zarr" + paths, + concat_dim=concat_dim, + combine='nested', + mask_and_scale=mask_and_scale, + engine='zarr', ) if mask_and_scale: @@ -369,7 +422,9 @@ def open_concatenated_zarrs( data_vars = list(ds.data_vars) concat_axes = [ds[v].dims.index(concat_dim) for v in data_vars] - arrays = _tensorstore_open_concatenated_zarrs(paths, data_vars, concat_axes, context) + arrays = _tensorstore_open_concatenated_zarrs( + paths, data_vars, concat_axes, context + ) new_data = {k: _TensorStoreAdapter(v) for k, v in arrays.items()} return ds.copy(data=new_data) diff --git a/xarray_tensorstore_test.py b/xarray_tensorstore_test.py index 9102ab6..84b1c55 100644 --- a/xarray_tensorstore_test.py +++ b/xarray_tensorstore_test.py @@ -26,63 +26,63 @@ _USING_ZARR_PYTHON_3 = packaging.version.parse(zarr.__version__).major >= 3 test_cases = [ - { - 'testcase_name': 'base', - 'transform': lambda ds: ds, - }, - { - 'testcase_name': 'transposed', - 'transform': lambda ds: ds.transpose('z', 'x', 'y'), - }, - { - 'testcase_name': 'basic_int', - 'transform': lambda ds: ds.isel(y=1), - }, - { - 'testcase_name': 'negative_int', - 'transform': lambda ds: ds.isel(y=-1), - }, - { - 'testcase_name': 'basic_slice', - 'transform': lambda ds: ds.isel(z=slice(2)), - }, - { - 'testcase_name': 'full_slice', - 'transform': lambda ds: ds.isel(z=slice(0, 4)), - }, - { - 'testcase_name': 'out_of_bounds_slice', - 'transform': lambda ds: ds.isel(z=slice(0, 10)), - }, - { - 'testcase_name': 'strided_slice', - 'transform': lambda ds: ds.isel(z=slice(0, None, 2)), - }, - { - 'testcase_name': 'negative_stride_slice', - 'transform': lambda ds: ds.isel(z=slice(None, None, -1)), - }, - { - 'testcase_name': 'repeated_indexing', - 'transform': lambda ds: ds.isel(z=slice(1, None)).isel(z=0), - }, - { - 'testcase_name': 'oindex', - # includes repeated, negative and out of order indices - 'transform': lambda ds: ds.isel(x=[0], y=[1, 1], z=[1, -1, 0]), - }, - { - 'testcase_name': 'vindex', - 'transform': lambda ds: ds.isel(x=('w', [0, 1]), y=('w', [1, 2])), - }, - { - 'testcase_name': 'mixed_indexing_types', - 'transform': lambda ds: ds.isel(x=0, y=slice(2), z=[-1]), - }, - { - 'testcase_name': 'select_a_variable', - 'transform': lambda ds: ds['foo'], - }, + { + 'testcase_name': 'base', + 'transform': lambda ds: ds, + }, + { + 'testcase_name': 'transposed', + 'transform': lambda ds: ds.transpose('z', 'x', 'y'), + }, + { + 'testcase_name': 'basic_int', + 'transform': lambda ds: ds.isel(y=1), + }, + { + 'testcase_name': 'negative_int', + 'transform': lambda ds: ds.isel(y=-1), + }, + { + 'testcase_name': 'basic_slice', + 'transform': lambda ds: ds.isel(z=slice(2)), + }, + { + 'testcase_name': 'full_slice', + 'transform': lambda ds: ds.isel(z=slice(0, 4)), + }, + { + 'testcase_name': 'out_of_bounds_slice', + 'transform': lambda ds: ds.isel(z=slice(0, 10)), + }, + { + 'testcase_name': 'strided_slice', + 'transform': lambda ds: ds.isel(z=slice(0, None, 2)), + }, + { + 'testcase_name': 'negative_stride_slice', + 'transform': lambda ds: ds.isel(z=slice(None, None, -1)), + }, + { + 'testcase_name': 'repeated_indexing', + 'transform': lambda ds: ds.isel(z=slice(1, None)).isel(z=0), + }, + { + 'testcase_name': 'oindex', + # includes repeated, negative and out of order indices + 'transform': lambda ds: ds.isel(x=[0], y=[1, 1], z=[1, -1, 0]), + }, + { + 'testcase_name': 'vindex', + 'transform': lambda ds: ds.isel(x=('w', [0, 1]), y=('w', [1, 2])), + }, + { + 'testcase_name': 'mixed_indexing_types', + 'transform': lambda ds: ds.isel(x=0, y=slice(2), z=[-1]), + }, + { + 'testcase_name': 'select_a_variable', + 'transform': lambda ds: ds['foo'], + }, ] @@ -128,16 +128,18 @@ def test_open_concatenated_zarrs(self, transform): }, attrs={'global': 'global metadata'}, ) - for x in [range(0,2), range(3, 5)] + for x in [range(0, 2), range(3, 5)] ] zarr_dir = self.create_tempdir().full_path - paths = [f"{zarr_dir}/{i}" for i in range(len(sources))] + paths = [f'{zarr_dir}/{i}' for i in range(len(sources))] for source, path in zip(sources, paths, strict=True): source.chunk().to_zarr(path) - expected = transform(xarray.concat(sources, dim="x")) - actual = transform(xarray_tensorstore.open_concatenated_zarrs(paths, concat_dim="x")).compute() + expected = transform(xarray.concat(sources, dim='x')) + actual = transform( + xarray_tensorstore.open_concatenated_zarrs(paths, concat_dim='x') + ).compute() xarray.testing.assert_identical(actual, expected) @parameterized.parameters( @@ -172,7 +174,9 @@ def test_compute(self): self.assertNotIsInstance(computed_data, tensorstore.TensorStore) def test_open_zarr_from_uri(self): - source = xarray.Dataset({'baz': (('x', 'y', 'z'), np.arange(24).reshape(2, 3, 4))}) + source = xarray.Dataset( + {'baz': (('x', 'y', 'z'), np.arange(24).reshape(2, 3, 4))} + ) path = self.create_tempdir().full_path source.chunk().to_zarr(path) @@ -180,10 +184,12 @@ def test_open_zarr_from_uri(self): xarray.testing.assert_identical(source, opened) @parameterized.parameters( - {'zarr_format': 2}, - {'zarr_format': 3}, + {'zarr_format': 2, 'consolidated': True}, + {'zarr_format': 3, 'consolidated': True}, + {'zarr_format': 2, 'consolidated': False}, + {'zarr_format': 3, 'consolidated': False}, ) - def test_read_dataset(self, zarr_format): + def test_read_dataset(self, zarr_format: int, consolidated: bool): if not _USING_ZARR_PYTHON_3 and zarr_format == 3: self.skipTest('zarr format 3 is not supported in zarr < 3.0.0') source = xarray.Dataset( @@ -191,7 +197,9 @@ def test_read_dataset(self, zarr_format): coords={'x': np.arange(2)}, ) path = self.create_tempdir().full_path - source.chunk().to_zarr(path, zarr_format=zarr_format) + source.chunk().to_zarr( + path, zarr_format=zarr_format, consolidated=consolidated + ) opened = xarray_tensorstore.open_zarr(path) read = xarray_tensorstore.read(opened) @@ -204,8 +212,8 @@ def test_read_dataset(self, zarr_format): {'zarr_format': 2}, {'zarr_format': 3}, ) - def test_read_dataarray(self, zarr_format): - if not _USING_ZARR_PYTHON_3 and zarr_format == 3: + def test_read_dataarray(self, zarr_format: int): + if not _USING_ZARR_PYTHON_3 and zarr_format == 3: self.skipTest('zarr format 3 is not supported in zarr < 3.0.0') source = xarray.DataArray( np.arange(24).reshape(2, 3, 4),