7 changes: 4 additions & 3 deletions .github/workflows/tests.yml
@@ -11,12 +11,13 @@ on:
 
 jobs:
   tests:
-    name: "python ${{ matrix.python-version }} tests"
+    name: "python=${{ matrix.python-version }} zarr=${{ matrix.zarr-version }} tests"
     runs-on: ubuntu-latest
     strategy:
       fail-fast: false
       matrix:
-        python-version: ["3.10", "3.11", "3.12", "3.13"]
+        python-version: ["3.11", "3.12", "3.13"]
+        zarr-version: [">=2,<3", ">=3"]
     steps:
       - name: Cancel previous
         uses: styfle/cancel-workflow-action@0.7.0
@@ -40,7 +41,7 @@ jobs:
           key: ${{ runner.os }}-pip-${{ hashFiles('**/setup.py') }}
       - name: Install Xarray-Tensorstore
         run: |
-          pip install -e .[tests]
+          pip install -e .[tests] "zarr${{ matrix.zarr-version }}"
       - name: Run unit tests
         run: |
           pytest .
2 changes: 1 addition & 1 deletion setup.py
@@ -18,7 +18,7 @@
 
 setuptools.setup(
     name='xarray-tensorstore',
-    version='0.2.0',  # keep in sync with xarray_tensorstore.py
+    version='0.3.0',  # keep in sync with xarray_tensorstore.py
     license='Apache-2.0',
     author='Google LLC',
     author_email='noreply@google.com',
105 changes: 80 additions & 25 deletions xarray_tensorstore.py
@@ -28,7 +28,7 @@
 import zarr
 
 
-__version__ = '0.2.0'  # keep in sync with setup.py
+__version__ = '0.3.0'  # keep in sync with setup.py
 
 
 Index = TypeVar('Index', int, slice, np.ndarray, None)
@@ -217,12 +217,49 @@ def _get_zarr_format(path: str) -> int:
   return 2
 
 
+def _open_tensorstore_arrays(
+    path: str,
+    names: list[str],
+    group: zarr.Group | None,
+    zarr_format: int,
+    write: bool,
+    context: tensorstore.Context | None = None,
+) -> dict[str, tensorstore.Future]:
+  """Open all arrays in a Zarr group using TensorStore."""
+  specs = {
+      k: _zarr_spec_from_path(os.path.join(path, k), zarr_format) for k in names
+  }
+
+  assume_metadata = False
+  if packaging.version.parse(zarr.__version__).major >= 3 and group is not None:
+    consolidated_metadata = group.metadata.consolidated_metadata
+    if consolidated_metadata is not None:
+      assume_metadata = True
+      for name in names:
+        metadata = consolidated_metadata.metadata[name].to_dict()
+        metadata.pop('attributes', None)  # not supported by TensorStore
+        specs[name]['metadata'] = metadata
+
+  array_futures = {}
+  for k, spec in specs.items():
+    array_futures[k] = tensorstore.open(
+        spec,
+        read=True,
+        write=write,
+        open=True,
+        context=context,
+        assume_metadata=assume_metadata,
+    )
+  return array_futures
+
+
 def open_zarr(
     path: str,
     *,
     context: tensorstore.Context | None = None,
     mask_and_scale: bool = True,
     write: bool = False,
+    consolidated: bool | None = None,
 ) -> xarray.Dataset:
   """Open an xarray.Dataset from Zarr using TensorStore.
 
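The consolidated-metadata fast path above relies on TensorStore's assume_metadata option: when array metadata is already known from the group's consolidated metadata, embedding it in each spec lets TensorStore skip a per-array metadata read. A minimal sketch of the idea, assuming a hypothetical local Zarr v3 array at /tmp/store/foo and an illustrative partial metadata dict (unspecified fields fall back to TensorStore's defaults):

    import tensorstore

    spec = {
        'driver': 'zarr3',
        'kvstore': {'driver': 'file', 'path': '/tmp/store/foo'},
        # Hypothetical metadata, e.g. lifted from consolidated metadata.
        'metadata': {'shape': [2, 3], 'data_type': 'int64'},
    }
    # With assume_metadata=True, TensorStore trusts the spec instead of
    # reading zarr.json from storage, so no metadata I/O happens on open.
    future = tensorstore.open(spec, read=True, open=True, assume_metadata=True)
    array = future.result()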
@@ -252,6 +289,9 @@ def open_zarr(
       xarray.open_zarr(). This is only supported for coordinate variables and
       otherwise will raise an error.
     write: Allow write access. Defaults to False.
+    consolidated: If True, read consolidated metadata. By default, an attempt
+      to use consolidated metadata is made with a fallback to non-consolidated
+      metadata, like in Xarray.
 
   Returns:
     Dataset with all data variables opened via TensorStore.
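A brief usage sketch for the new flag (the store path is hypothetical):

    import xarray_tensorstore

    # Default (consolidated=None): try consolidated metadata, then fall back.
    ds = xarray_tensorstore.open_zarr('/tmp/data.zarr')

    # Require consolidated metadata to be present in the store.
    ds = xarray_tensorstore.open_zarr('/tmp/data.zarr', consolidated=True)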
@@ -272,8 +312,19 @@
   if context is None:
     context = tensorstore.Context()
 
-  # chunks=None means avoid using dask
-  ds = xarray.open_zarr(path, chunks=None, mask_and_scale=mask_and_scale)
+  # Open Xarray's backends.ZarrStore directly so we can get access to the
+  # underlying Zarr group's consolidated metadata.
+  store = xarray.backends.ZarrStore.open_group(
+      path, consolidated=consolidated
+  )
+  group = store.zarr_group
+  ds = xarray.open_dataset(
+      filename_or_obj='',  # ignored in favor of store=
+      chunks=None,  # avoid using dask
+      mask_and_scale=mask_and_scale,
+      store=store,
+      engine='zarr',
+  )
 
   if mask_and_scale:
     # Data variables get replaced below with _TensorStoreAdapter arrays, which
@@ -282,13 +333,9 @@
     _raise_if_mask_and_scale_used_for_data_vars(ds)
 
   zarr_format = _get_zarr_format(path)
-  specs = {
-      k: _zarr_spec_from_path(os.path.join(path, k), zarr_format) for k in ds
-  }
-  array_futures = {
-      k: tensorstore.open(spec, read=True, write=write, context=context)
-      for k, spec in specs.items()
-  }
+  array_futures = _open_tensorstore_arrays(
+      path, list(ds), group, zarr_format, write=write, context=context
+  )
   arrays = {k: v.result() for k, v in array_futures.items()}
   new_data = {k: _TensorStoreAdapter(v) for k, v in arrays.items()}
@@ -304,20 +351,26 @@ def _tensorstore_open_concatenated_zarrs(
   """Open multiple zarrs with TensorStore.
 
   Args:
-  paths: List of paths to zarr stores.
-  data_vars: List of data variable names to open.
-  concat_axes: List of axes along which to concatenate the data variables.
-  context: TensorStore context.
+    paths: List of paths to zarr stores.
+    data_vars: List of data variable names to open.
+    concat_axes: List of axes along which to concatenate the data variables.
+    context: TensorStore context.
 
   Returns:
     Dictionary of data variable names to concatenated TensorStore arrays.
   """
   # Open all arrays in all datasets using tensorstore
   arrays_list = []
   for path in paths:
     zarr_format = _get_zarr_format(path)
-    specs = {k: _zarr_spec_from_path(os.path.join(path, k), zarr_format) for k in data_vars}
-    array_futures = {
-        k: tensorstore.open(spec, read=True, write=False, context=context)
-        for k, spec in specs.items()
-    }
+    # TODO(shoyer): Figure out how to support opening concatenated Zarrs with
+    # consolidated metadata. xarray.open_mfdataset() doesn't support opening
+    # from an existing store, so we'd have to replicate that functionality for
+    # figuring out the structure of the concatenated dataset.
+    group = None
+    array_futures = _open_tensorstore_arrays(
+        path, data_vars, group, zarr_format, write=False, context=context
+    )
     arrays_list.append(array_futures)
 
   # Concatenate the tensorstore arrays
@@ -354,11 +407,11 @@ def open_concatenated_zarrs(
     context = tensorstore.Context()
 
   ds = xarray.open_mfdataset(
-    paths,
-    concat_dim=concat_dim,
-    combine="nested",
-    mask_and_scale=mask_and_scale,
-    engine="zarr"
+      paths,
+      concat_dim=concat_dim,
+      combine='nested',
+      mask_and_scale=mask_and_scale,
+      engine='zarr',
   )
 
   if mask_and_scale:
@@ -369,7 +422,9 @@
 
   data_vars = list(ds.data_vars)
   concat_axes = [ds[v].dims.index(concat_dim) for v in data_vars]
-  arrays = _tensorstore_open_concatenated_zarrs(paths, data_vars, concat_axes, context)
+  arrays = _tensorstore_open_concatenated_zarrs(
+      paths, data_vars, concat_axes, context
+  )
   new_data = {k: _TensorStoreAdapter(v) for k, v in arrays.items()}
 
   return ds.copy(data=new_data)
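For reference, a usage sketch of this entry point (the paths and the concatenation dimension are hypothetical):

    import xarray_tensorstore

    # Two stores that differ only along 'time', concatenated lazily as
    # TensorStore-backed arrays.
    ds = xarray_tensorstore.open_concatenated_zarrs(
        ['/data/2023.zarr', '/data/2024.zarr'], concat_dim='time'
    )
    subset = ds.isel(time=slice(0, 10)).compute()  # reads only this region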
144 changes: 76 additions & 68 deletions xarray_tensorstore_test.py
@@ -26,63 +26,63 @@
 _USING_ZARR_PYTHON_3 = packaging.version.parse(zarr.__version__).major >= 3
 
 test_cases = [
-  {
-      'testcase_name': 'base',
-      'transform': lambda ds: ds,
-  },
-  {
-      'testcase_name': 'transposed',
-      'transform': lambda ds: ds.transpose('z', 'x', 'y'),
-  },
-  {
-      'testcase_name': 'basic_int',
-      'transform': lambda ds: ds.isel(y=1),
-  },
-  {
-      'testcase_name': 'negative_int',
-      'transform': lambda ds: ds.isel(y=-1),
-  },
-  {
-      'testcase_name': 'basic_slice',
-      'transform': lambda ds: ds.isel(z=slice(2)),
-  },
-  {
-      'testcase_name': 'full_slice',
-      'transform': lambda ds: ds.isel(z=slice(0, 4)),
-  },
-  {
-      'testcase_name': 'out_of_bounds_slice',
-      'transform': lambda ds: ds.isel(z=slice(0, 10)),
-  },
-  {
-      'testcase_name': 'strided_slice',
-      'transform': lambda ds: ds.isel(z=slice(0, None, 2)),
-  },
-  {
-      'testcase_name': 'negative_stride_slice',
-      'transform': lambda ds: ds.isel(z=slice(None, None, -1)),
-  },
-  {
-      'testcase_name': 'repeated_indexing',
-      'transform': lambda ds: ds.isel(z=slice(1, None)).isel(z=0),
-  },
-  {
-      'testcase_name': 'oindex',
-      # includes repeated, negative and out of order indices
-      'transform': lambda ds: ds.isel(x=[0], y=[1, 1], z=[1, -1, 0]),
-  },
-  {
-      'testcase_name': 'vindex',
-      'transform': lambda ds: ds.isel(x=('w', [0, 1]), y=('w', [1, 2])),
-  },
-  {
-      'testcase_name': 'mixed_indexing_types',
-      'transform': lambda ds: ds.isel(x=0, y=slice(2), z=[-1]),
-  },
-  {
-      'testcase_name': 'select_a_variable',
-      'transform': lambda ds: ds['foo'],
-  },
+    {
+        'testcase_name': 'base',
+        'transform': lambda ds: ds,
+    },
+    {
+        'testcase_name': 'transposed',
+        'transform': lambda ds: ds.transpose('z', 'x', 'y'),
+    },
+    {
+        'testcase_name': 'basic_int',
+        'transform': lambda ds: ds.isel(y=1),
+    },
+    {
+        'testcase_name': 'negative_int',
+        'transform': lambda ds: ds.isel(y=-1),
+    },
+    {
+        'testcase_name': 'basic_slice',
+        'transform': lambda ds: ds.isel(z=slice(2)),
+    },
+    {
+        'testcase_name': 'full_slice',
+        'transform': lambda ds: ds.isel(z=slice(0, 4)),
+    },
+    {
+        'testcase_name': 'out_of_bounds_slice',
+        'transform': lambda ds: ds.isel(z=slice(0, 10)),
+    },
+    {
+        'testcase_name': 'strided_slice',
+        'transform': lambda ds: ds.isel(z=slice(0, None, 2)),
+    },
+    {
+        'testcase_name': 'negative_stride_slice',
+        'transform': lambda ds: ds.isel(z=slice(None, None, -1)),
+    },
+    {
+        'testcase_name': 'repeated_indexing',
+        'transform': lambda ds: ds.isel(z=slice(1, None)).isel(z=0),
+    },
+    {
+        'testcase_name': 'oindex',
+        # includes repeated, negative and out of order indices
+        'transform': lambda ds: ds.isel(x=[0], y=[1, 1], z=[1, -1, 0]),
+    },
+    {
+        'testcase_name': 'vindex',
+        'transform': lambda ds: ds.isel(x=('w', [0, 1]), y=('w', [1, 2])),
+    },
+    {
+        'testcase_name': 'mixed_indexing_types',
+        'transform': lambda ds: ds.isel(x=0, y=slice(2), z=[-1]),
+    },
+    {
+        'testcase_name': 'select_a_variable',
+        'transform': lambda ds: ds['foo'],
+    },
 ]


@@ -128,16 +128,18 @@ def test_open_concatenated_zarrs(self, transform):
             },
             attrs={'global': 'global metadata'},
         )
-        for x in [range(0,2), range(3, 5)]
+        for x in [range(0, 2), range(3, 5)]
     ]
 
     zarr_dir = self.create_tempdir().full_path
-    paths = [f"{zarr_dir}/{i}" for i in range(len(sources))]
+    paths = [f'{zarr_dir}/{i}' for i in range(len(sources))]
     for source, path in zip(sources, paths, strict=True):
       source.chunk().to_zarr(path)
 
-    expected = transform(xarray.concat(sources, dim="x"))
-    actual = transform(xarray_tensorstore.open_concatenated_zarrs(paths, concat_dim="x")).compute()
+    expected = transform(xarray.concat(sources, dim='x'))
+    actual = transform(
+        xarray_tensorstore.open_concatenated_zarrs(paths, concat_dim='x')
+    ).compute()
     xarray.testing.assert_identical(actual, expected)
 
   @parameterized.parameters(
@@ -172,26 +174,32 @@ def test_compute(self):
     self.assertNotIsInstance(computed_data, tensorstore.TensorStore)
 
   def test_open_zarr_from_uri(self):
-    source = xarray.Dataset({'baz': (('x', 'y', 'z'), np.arange(24).reshape(2, 3, 4))})
+    source = xarray.Dataset(
+        {'baz': (('x', 'y', 'z'), np.arange(24).reshape(2, 3, 4))}
+    )
     path = self.create_tempdir().full_path
     source.chunk().to_zarr(path)
 
     opened = xarray_tensorstore.open_zarr('file://' + path)
     xarray.testing.assert_identical(source, opened)
 
   @parameterized.parameters(
-      {'zarr_format': 2},
-      {'zarr_format': 3},
+      {'zarr_format': 2, 'consolidated': True},
+      {'zarr_format': 3, 'consolidated': True},
+      {'zarr_format': 2, 'consolidated': False},
+      {'zarr_format': 3, 'consolidated': False},
   )
-  def test_read_dataset(self, zarr_format):
+  def test_read_dataset(self, zarr_format: int, consolidated: bool):
     if not _USING_ZARR_PYTHON_3 and zarr_format == 3:
       self.skipTest('zarr format 3 is not supported in zarr < 3.0.0')
     source = xarray.Dataset(
         {'baz': (('x', 'y', 'z'), np.arange(24).reshape(2, 3, 4))},
         coords={'x': np.arange(2)},
     )
     path = self.create_tempdir().full_path
-    source.chunk().to_zarr(path, zarr_format=zarr_format)
+    source.chunk().to_zarr(
+        path, zarr_format=zarr_format, consolidated=consolidated
+    )
 
     opened = xarray_tensorstore.open_zarr(path)
     read = xarray_tensorstore.read(opened)
@@ -204,8 +212,8 @@ def test_read_dataarray(self, zarr_format):
       {'zarr_format': 2},
       {'zarr_format': 3},
   )
-  def test_read_dataarray(self, zarr_format):
-    if not _USING_ZARR_PYTHON_3 and zarr_format == 3:
+  def test_read_dataarray(self, zarr_format: int):
+    if not _USING_ZARR_PYTHON_3 and zarr_format == 3:
       self.skipTest('zarr format 3 is not supported in zarr < 3.0.0')
     source = xarray.DataArray(
         np.arange(24).reshape(2, 3, 4),