diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 2f1d76d7abc..075c10f5975 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -14,6 +14,10 @@ v2026.03.0 (unreleased) New Features ~~~~~~~~~~~~ +- Adds a new option ``chunks="preserve"`` when opening a dataset. This option + guarantees that chunks in xarray match on-disk chunks or multiples of them. + No chunk splitting allowed. (:pull:`11060`). + By `Julia Signell `_ Breaking Changes ~~~~~~~~~~~~~~~~ diff --git a/properties/test_parallelcompat.py b/properties/test_parallelcompat.py new file mode 100644 index 00000000000..91325578619 --- /dev/null +++ b/properties/test_parallelcompat.py @@ -0,0 +1,71 @@ +import numpy as np +import pytest + +pytest.importorskip("hypothesis") +# isort: split + +from hypothesis import given + +import xarray.testing.strategies as xrst +from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint + + +class TestPreserveChunks: + @given(xrst.shape_and_chunks()) + def test_preserve_all_chunks( + self, shape_and_chunks: tuple[tuple[int, ...], tuple[int, ...]] + ) -> None: + shape, previous_chunks = shape_and_chunks + typesize = 8 + target = 1024 * 1024 + + actual = ChunkManagerEntrypoint.preserve_chunks( + chunks=("preserve",) * len(shape), + shape=shape, + target=target, + typesize=typesize, + previous_chunks=previous_chunks, + ) + for i, chunk in enumerate(actual): + if chunk != shape[i]: + assert chunk >= previous_chunks[i] + assert chunk % previous_chunks[i] == 0 + assert chunk <= shape[i] + + if actual != shape: + assert np.prod(actual) * typesize >= 0.5 * target + + @pytest.mark.parametrize("first_chunk", [-1, (), 1]) + @given(xrst.shape_and_chunks(min_dims=2)) + def test_preserve_some_chunks( + self, + first_chunk: int | tuple[int, ...], + shape_and_chunks: tuple[tuple[int, ...], tuple[int, ...]], + ) -> None: + shape, previous_chunks = shape_and_chunks + typesize = 4 + target = 2 * 1024 * 1024 + + actual = ChunkManagerEntrypoint.preserve_chunks( + 
chunks=(first_chunk, *["preserve" for _ in range(len(shape) - 1)]), + shape=shape, + target=target, + typesize=typesize, + previous_chunks=previous_chunks, + ) + for i, chunk in enumerate(actual): + if i == 0: + if first_chunk == 1: + assert chunk == 1 + elif first_chunk == -1: + assert chunk == shape[i] + elif first_chunk == (): + assert chunk == previous_chunks[i] + elif chunk != shape[i]: + assert chunk >= previous_chunks[i] + assert chunk % previous_chunks[i] == 0 + assert chunk <= shape[i] + + # if we have more than one chunk, make sure the chunks are big enough + if actual[1:] != shape[1:]: + assert np.prod(actual) * typesize >= 0.5 * target diff --git a/xarray/backends/api.py b/xarray/backends/api.py index fd992f3e5d8..89c1a6efc5d 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -292,9 +292,9 @@ def _dataset_from_backend_dataset( create_default_indexes, **extra_tokens, ): - if not isinstance(chunks, int | dict) and chunks not in {None, "auto"}: + if not isinstance(chunks, int | dict) and chunks not in {None, "auto", "preserve"}: raise ValueError( - f"chunks must be an int, dict, 'auto', or None. Instead found {chunks}." + f"chunks must be an int, dict, 'auto', 'preserve', or None. Instead found {chunks}." ) _protect_dataset_variables_inplace(backend_ds, cache) @@ -430,11 +430,14 @@ def open_dataset( "netcdf4" over "h5netcdf" over "scipy" (customizable via ``netcdf_engine_order`` in ``xarray.set_options()``). A custom backend class (a subclass of ``BackendEntrypoint``) can also be used. - chunks : int, dict, 'auto' or None, default: None + chunks : int, dict, 'auto', 'preserve' or None, default: None If provided, used to load the data into dask arrays. - ``chunks="auto"`` will use dask ``auto`` chunking taking into account the engine preferred chunks. + - ``chunks="preserve"`` will use a chunking scheme that never splits encoded + chunks. If encoded chunks are small then "preserve" takes multiples of them + over the largest dimension. 
    - ``chunks=None`` skips using dask. This uses xarray's internally
      private :ref:`lazy indexing classes `, but data
      is eagerly loaded into memory as numpy arrays when accessed.
@@ -674,11 +677,14 @@ def open_dataarray(
        "netcdf4" over "h5netcdf" over "scipy" (customizable via ``netcdf_engine_order``
        in ``xarray.set_options()``). A custom backend class (a subclass of
        ``BackendEntrypoint``) can also be used.
-    chunks : int, dict, 'auto' or None, default: None
+    chunks : int, dict, 'auto', 'preserve', or None, default: None
        If provided, used to load the data into dask arrays.
        - ``chunks='auto'`` will use dask ``auto`` chunking taking into account the
          engine preferred chunks.
+        - ``chunks="preserve"`` will use a chunking scheme that never splits encoded
+          chunks. If encoded chunks are small then "preserve" takes multiples of them
+          over the largest dimension.
        - ``chunks=None`` skips using dask. This uses xarray's internally
          private :ref:`lazy indexing classes `, but data
          is eagerly loaded into memory as numpy arrays when accessed.
@@ -900,11 +906,14 @@ def open_datatree(
        "h5netcdf" over "netcdf4" (customizable via ``netcdf_engine_order`` in
        ``xarray.set_options()``). A custom backend class (a subclass of
        ``BackendEntrypoint``) can also be used.
-    chunks : int, dict, 'auto' or None, default: None
+    chunks : int, dict, 'auto', 'preserve', or None, default: None
        If provided, used to load the data into dask arrays.
        - ``chunks="auto"`` will use dask ``auto`` chunking taking into account the
          engine preferred chunks.
+        - ``chunks="preserve"`` will use a chunking scheme that never splits encoded
+          chunks. If encoded chunks are small then "preserve" takes multiples of them
+          over the largest dimension.
        - ``chunks=None`` skips using dask. This uses xarray's internally
          private :ref:`lazy indexing classes `, but data
          is eagerly loaded into memory as numpy arrays when accessed.
@@ -1146,11 +1155,14 @@ def open_groups(
        ``xarray.set_options()``).
A custom backend class (a subclass of ``BackendEntrypoint``) can also be used. can also be used. - chunks : int, dict, 'auto' or None, default: None + chunks : int, dict, 'auto', 'preserve', or None, default: None If provided, used to load the data into dask arrays. - ``chunks="auto"`` will use dask ``auto`` chunking taking into account the engine preferred chunks. + - ``chunks="preserve"`` will use a chunking scheme that never splits encoded + chunks. If encoded chunks are small then "preserve" takes multiples of them + over the largest dimension. - ``chunks=None`` skips using dask. This uses xarray's internally private :ref:`lazy indexing classes `, but data is eagerly loaded into memory as numpy arrays when accessed. @@ -1418,7 +1430,7 @@ def open_mfdataset( concatenation along more than one dimension is desired, then ``paths`` must be a nested list-of-lists (see ``combine_nested`` for details). (A string glob will be expanded to a 1-dimensional list.) - chunks : int, dict, 'auto' or None, optional + chunks : int, dict, 'auto', 'preserve', or None, optional Dictionary with keys given by dimension names and values given by chunk sizes. In general, these should divide the dimensions of each dataset. If int, chunk each dimension by ``chunks``. By default, chunks will be chosen to match the diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index 6681673025c..9851ae6671c 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -1450,12 +1450,15 @@ def open_zarr( Array synchronizer provided to zarr group : str, optional Group path. (a.k.a. `path` in zarr terminology.) - chunks : int, dict, "auto" or None, optional + chunks : int, dict, "auto", "preserve", or None, optional Used to load the data into dask arrays. Default behavior is to use ``chunks={}`` if dask is available, otherwise ``chunks=None``. - ``chunks='auto'`` will use dask ``auto`` chunking taking into account the engine preferred chunks. 
+ - ``chunks="preserve"`` will use a chunking scheme that never splits encoded + chunks. If encoded chunks are small then "preserve" takes multiples of them + over the largest dimension. - ``chunks=None`` skips using dask. This uses xarray's internally private :ref:`lazy indexing classes `, but data is eagerly loaded into memory as numpy arrays when accessed. diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py index 9610b96d4f9..f9227f7796a 100644 --- a/xarray/namedarray/_typing.py +++ b/xarray/namedarray/_typing.py @@ -74,7 +74,7 @@ def dtype(self) -> _DType_co: ... _NormalizedChunks = tuple[tuple[int, ...], ...] # FYI in some cases we don't allow `None`, which this doesn't take account of. # # FYI the `str` is for a size string, e.g. "16MB", supported by dask. -T_ChunkDim: TypeAlias = str | int | Literal["auto"] | tuple[int, ...] | None # noqa: PYI051 +T_ChunkDim: TypeAlias = str | int | Literal["auto", "preserve"] | tuple[int, ...] | None # noqa: PYI051 # We allow the tuple form of this (though arguably we could transition to named dims only) T_Chunks: TypeAlias = T_ChunkDim | Mapping[Any, T_ChunkDim] diff --git a/xarray/namedarray/daskmanager.py b/xarray/namedarray/daskmanager.py index eb01a150c18..323beb6a37e 100644 --- a/xarray/namedarray/daskmanager.py +++ b/xarray/namedarray/daskmanager.py @@ -11,6 +11,7 @@ if TYPE_CHECKING: from xarray.namedarray._typing import ( + T_ChunkDim, T_Chunks, _DType_co, _NormalizedChunks, @@ -45,11 +46,11 @@ def chunks(self, data: Any) -> _NormalizedChunks: def normalize_chunks( self, - chunks: T_Chunks | _NormalizedChunks, + chunks: tuple[T_ChunkDim, ...] | _NormalizedChunks, shape: tuple[int, ...] | None = None, limit: int | None = None, dtype: _DType_co | None = None, - previous_chunks: _NormalizedChunks | None = None, + previous_chunks: tuple[int, ...] 
| _NormalizedChunks | None = None, ) -> Any: """Called by open_dataset""" from dask.array.core import normalize_chunks diff --git a/xarray/namedarray/parallelcompat.py b/xarray/namedarray/parallelcompat.py index 8a68f5e9562..5c86e8c579d 100644 --- a/xarray/namedarray/parallelcompat.py +++ b/xarray/namedarray/parallelcompat.py @@ -20,6 +20,7 @@ if TYPE_CHECKING: from xarray.namedarray._typing import ( + T_ChunkDim, T_Chunks, _Chunks, _DType, @@ -784,3 +785,120 @@ def get_auto_chunk_size( raise NotImplementedError( "For 'auto' rechunking of cftime arrays, get_auto_chunk_size must be implemented by the chunk manager" ) + + @staticmethod + def preserve_chunks( + chunks: tuple[T_ChunkDim, ...], + shape: tuple[int, ...], + target: int, + typesize: int, + previous_chunks: tuple[int, ...] | _NormalizedChunks, + ) -> tuple[T_ChunkDim, ...]: + """Quickly determine optimal chunks close to target size but never splitting + previous_chunks. + + This takes in a chunks argument potentially containing ``"preserve"`` for several + dimensions. This function replaces ``"preserve"`` with concrete dimension sizes that + try to get chunks to be close to certain size in bytes, provided by the ``target=`` + keyword. Any dimensions marked as ``"preserve"`` will potentially be multiplied + by some factor to get close to the byte target, while never splitting + ``previous_chunks``. If chunks are non-uniform along a particular dimension + then that dimension will always use exactly ``previous_chunks``. + + Examples + -------- + >>> ChunkManagerEntrypoint.preserve_chunks( + ... chunks=("preserve", "preserve", "preserve"), + ... shape=(1280, 1280, 20), + ... target=500 * 1024, + ... typesize=8, + ... previous_chunks=(128, 128, 1), + ... ) + (128, 128, 2) + + >>> ChunkManagerEntrypoint.preserve_chunks( + ... chunks=("preserve", "preserve", 1), + ... shape=(1280, 1280, 20), + ... target=1 * 1024 * 1024, + ... typesize=8, + ... previous_chunks=(128, 128, 1), + ... 
)
+        (128, 1024, 1)
+
+        >>> ChunkManagerEntrypoint.preserve_chunks(
+        ...     chunks=("preserve", "preserve", 1),
+        ...     shape=(1280, 1280, 20),
+        ...     target=1 * 1024 * 1024,
+        ...     typesize=8,
+        ...     previous_chunks=((128,) * 10, (128, 256, 256, 512), (1,) * 20),
+        ... )
+        (256, (128, 256, 256, 512), 1)
+
+        Parameters
+        ----------
+        chunks: tuple[int | str | tuple[int], ...]
+            A tuple of either dimensions or tuples of explicit chunk dimensions
+            Some entries should be "preserve".
+        shape: tuple[int]
+            The shape of the array
+        target: int
+            The target size of the chunk in bytes.
+        typesize: int
+            The size, in bytes, of each element of the chunk.
+        previous_chunks: tuple[int | tuple[int], ...]
+            Size of chunks being preserved. Expressed as a tuple of ints or tuple
+            of tuple of ints.
+        """
+        new_chunks = [*previous_chunks]
+        auto_dims = [c == "preserve" for c in chunks]
+        max_chunks = np.array(shape)
+        for i, previous_chunk in enumerate(previous_chunks):
+            chunk = chunks[i]
+            if chunk == -1:
+                # -1 means whole dim is in one chunk
+                new_chunks[i] = shape[i]
+            else:
+                if isinstance(previous_chunk, tuple):
+                    # For uniform chunks just take the first item
+                    if previous_chunk[1:-1] == previous_chunk[:-2]:
+                        new_chunks[i] = previous_chunk[0]
+                        previous_chunk = previous_chunk[0]
+                    # For non-uniform chunks, leave them alone
+                    else:
+                        auto_dims[i] = False
+                        max_chunks[i] = max(previous_chunk)
+
+                if isinstance(previous_chunk, int):
+                    # preserve, None or () means we want to track previous chunk
+                    if chunk == "preserve" or not chunk:
+                        max_chunks[i] = previous_chunk
+                    # otherwise use the explicitly provided chunk
+                    else:
+                        new_chunks[i] = chunk
+                        max_chunks[i] = chunk if isinstance(chunk, int) else max(chunk)
+
+        if not any(auto_dims):
+            # No "preserve" dimensions are left to grow. Return the resolved
+            # sizes rather than the raw ``chunks`` tuple: when every "preserve"
+            # dimension had non-uniform previous chunks, ``chunks`` still
+            # contains the literal string "preserve", which downstream
+            # normalize_chunks cannot interpret.
+            return tuple(new_chunks)
+
+        while True:
+            # Repeatedly look for the last dim with more than one chunk and multiply it by 2.
+            # Stop when:
+            # 1a. we are larger than the target chunk size OR
+            # 1b. we are within 50% of the target chunk size OR
+            # 2.
the chunk covers the entire array + + num_chunks = np.array(shape) / max_chunks * auto_dims + chunk_bytes = np.prod(max_chunks) * typesize + + if chunk_bytes > target or abs(chunk_bytes - target) / target < 0.5: + break + + if (num_chunks <= 1).all(): + break + + idx = int(np.nonzero(num_chunks > 1)[0][-1]) + + new_chunks[idx] = min(new_chunks[idx] * 2, shape[idx]) + max_chunks[idx] = new_chunks[idx] + + return tuple(new_chunks) diff --git a/xarray/namedarray/utils.py b/xarray/namedarray/utils.py index 3490a76aa8d..2a997b4a831 100644 --- a/xarray/namedarray/utils.py +++ b/xarray/namedarray/utils.py @@ -222,7 +222,7 @@ def _get_chunk( # type: ignore[no-untyped-def] preferred_chunk_shape = tuple( itertools.starmap(preferred_chunks.get, zip(dims, shape, strict=True)) ) - if isinstance(chunks, Number) or (chunks == "auto"): + if isinstance(chunks, (Number, str)): chunks = dict.fromkeys(dims, chunks) chunk_shape = tuple( chunks.get(dim, None) or preferred_chunk_sizes @@ -236,6 +236,20 @@ def _get_chunk( # type: ignore[no-untyped-def] limit = None dtype = data.dtype + if any(c == "preserve" for c in chunk_shape) and any( + c == "auto" for c in chunk_shape + ): + raise ValueError('chunks cannot use a combination of "auto" and "preserve"') + + if shape and preferred_chunk_shape and any(c == "preserve" for c in chunk_shape): + chunk_shape = chunkmanager.preserve_chunks( + chunk_shape, + shape=shape, + target=chunkmanager.get_auto_chunk_size(), + typesize=getattr(dtype, "itemsize", 8), + previous_chunks=preferred_chunk_shape, + ) + chunk_shape = chunkmanager.normalize_chunks( chunk_shape, shape=shape, diff --git a/xarray/testing/strategies.py b/xarray/testing/strategies.py index 9f6bb8110e8..9f0a7080936 100644 --- a/xarray/testing/strategies.py +++ b/xarray/testing/strategies.py @@ -31,6 +31,7 @@ "names", "outer_array_indexers", "pandas_index_dtypes", + "shape_and_chunks", "supported_dtypes", "unique_subset_of", "variables", @@ -210,6 +211,68 @@ def dimension_sizes( ) 
+@st.composite
+def shape_and_chunks(
+    draw: st.DrawFn,
+    *,
+    min_dims: int = 1,
+    max_dims: int = 4,
+    min_size: int = 1,
+    max_size: int = 900,
+) -> tuple[tuple[int, ...], tuple[int, ...]]:
+    """
+    Generate a shape tuple and corresponding chunks tuple.
+
+    Each element in the chunks tuple is smaller than or equal to the
+    corresponding element in the shape tuple.
+
+    Requires the hypothesis package to be installed.
+
+    Parameters
+    ----------
+    min_dims : int, optional
+        Minimum number of dimensions. Default is 1.
+    max_dims : int, optional
+        Maximum number of dimensions. Default is 4.
+    min_size : int, optional
+        Minimum size for each dimension. Default is 1.
+    max_size : int, optional
+        Maximum size for each dimension. Default is 900.
+
+    Returns
+    -------
+    tuple[tuple[int, ...], tuple[int, ...]]
+        A tuple containing (shape, chunks) where:
+        - shape is a tuple of positive integers
+        - chunks is a tuple where each element is an integer <= corresponding shape element
+
+    Examples
+    --------
+    >>> shape_and_chunks().example()  # doctest: +SKIP
+    ((5, 3, 8), (2, 3, 4))
+    >>> shape_and_chunks().example()  # doctest: +SKIP
+    ((10, 7), (10, 3))
+    >>> shape_and_chunks(min_dims=2, max_dims=3).example()  # doctest: +SKIP
+    ((4, 6, 2), (2, 3, 1))
+
+    See Also
+    --------
+    :ref:`testing.hypothesis`_
+    """
+    # Generate the shape tuple
+    ndim = draw(st.integers(min_value=min_dims, max_value=max_dims))
+    shape = draw(
+        st.tuples(
+            *[st.integers(min_value=min_size, max_value=max_size) for _ in range(ndim)]
+        )
+    )
+    # Generate chunks tuple with each element <= corresponding shape element
+    chunks = draw(
+        st.tuples(*[st.integers(min_value=1, max_value=size) for size in shape])
+    )
+    return shape, chunks
+
+
 _readable_strings = st.text(
     _readable_characters,
     max_size=5,
diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py
index b11c2b6f4f1..c74a361c70a 100644
--- a/xarray/tests/test_backends.py
+++ b/xarray/tests/test_backends.py
@@ -7336,6 +7336,44 @@
def test_chunking_consistency(chunks, tmp_path: Path) -> None: xr.testing.assert_chunks_equal(actual, expected) +@requires_zarr +@requires_dask +@pytest.mark.parametrize( + "chunks,expected", + [ + ("preserve", (160, 500)), + (-1, (500, 500)), + ({}, (10, 10)), + ({"x": "preserve"}, (500, 10)), + ({"x": -1}, (500, 10)), + ({"x": "preserve", "y": -1}, (160, 500)), + ], +) +def test_open_dataset_chunking_zarr_with_preserve( + chunks, expected, tmp_path: Path +) -> None: + encoded_chunks = 10 + dask_arr = da.from_array( + np.ones((500, 500), dtype="float64"), chunks=encoded_chunks + ) + ds = xr.Dataset( + { + "test": xr.DataArray( + dask_arr, + dims=("x", "y"), + ) + } + ) + ds["test"].encoding["chunks"] = encoded_chunks + ds.to_zarr(tmp_path / "test.zarr") + + with dask.config.set({"array.chunk-size": "1MiB"}): + with open_dataset( + tmp_path / "test.zarr", engine="zarr", chunks=chunks + ) as actual: + assert (actual.chunks["x"][0], actual.chunks["y"][0]) == expected + + def _check_guess_can_open_and_open(entrypoint, obj, engine, expected): assert entrypoint.guess_can_open(obj) with open_dataset(obj, engine=engine) as actual: