Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 111 additions & 1 deletion src/zarr/core/codec_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from dataclasses import dataclass
from dataclasses import dataclass, field
from itertools import islice, pairwise
from typing import TYPE_CHECKING, Any, TypeVar
from warnings import warn
Expand All @@ -14,6 +14,7 @@
Codec,
CodecPipeline,
GetResult,
SupportsSyncCodec,
)
from zarr.core.common import concurrent_map
from zarr.core.config import config
Expand Down Expand Up @@ -69,6 +70,115 @@ def fill_value_or_default(chunk_spec: ArraySpec) -> Any:
return fill_value


@dataclass(slots=True, kw_only=True)
class ChunkTransform:
"""A synchronous codec chain bound to an ArraySpec.

Provides ``encode_chunk`` and ``decode_chunk`` for pure-compute
codec operations (no IO, no threading, no batching).

``shape`` and ``dtype`` reflect the representation **after** all
ArrayArrayCodec transforms — i.e. the spec that feeds the
ArrayBytesCodec.

All codecs must implement ``SupportsSyncCodec``. Construction will
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Idle thought without having yet looked at the implementation: make Codec generic over SupportsSyncCodec (via a protocol) so that this can be caught before runtime?

raise ``TypeError`` if any codec does not.
"""

codecs: tuple[Codec, ...]
array_spec: ArraySpec

# (ArrayArrayCodec, input_spec) pairs in pipeline order.
_aa_codecs: tuple[tuple[ArrayArrayCodec, ArraySpec], ...] = field(
init=False, repr=False, compare=False
)
_ab_codec: ArrayBytesCodec = field(init=False, repr=False, compare=False)
_ab_spec: ArraySpec = field(init=False, repr=False, compare=False)
_bb_codecs: tuple[BytesBytesCodec, ...] = field(init=False, repr=False, compare=False)

def __post_init__(self) -> None:
    """Validate codecs and pre-resolve the per-stage ArraySpecs.

    Splits ``self.codecs`` into array->array, array->bytes, and
    bytes->bytes stages (via ``codecs_from_list``) and records the
    ArraySpec seen at the *input* of each ArrayArrayCodec, so that
    ``encode``/``decode`` never need to re-resolve metadata per call.

    Raises:
        TypeError: if any codec does not implement ``SupportsSyncCodec``.
    """
    non_sync = [c for c in self.codecs if not isinstance(c, SupportsSyncCodec)]
    if non_sync:
        names = ", ".join(type(c).__name__ for c in non_sync)
        raise TypeError(
            f"All codecs must implement SupportsSyncCodec. The following do not: {names}"
        )

    aa, ab, bb = codecs_from_list(list(self.codecs))

    # Pair each ArrayArrayCodec with its input spec, threading the resolved
    # metadata forward. Accumulate in a list rather than rebuilding a tuple
    # each iteration (the original `(*t, x)` pattern is O(n^2)).
    aa_pairs: list[tuple[ArrayArrayCodec, ArraySpec]] = []
    spec = self.array_spec
    for aa_codec in aa:
        aa_pairs.append((aa_codec, spec))
        spec = aa_codec.resolve_metadata(spec)

    self._aa_codecs = tuple(aa_pairs)
    self._ab_codec = ab
    # `spec` is now the representation after all AA transforms, i.e. what
    # the ArrayBytesCodec consumes.
    self._ab_spec = spec
    self._bb_codecs = bb

@property
def shape(self) -> tuple[int, ...]:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How are shape and dtype ultimately used? They're a bit complicated to understand. Presumably you need them for your full perf PR, but I wanted to confirm that.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm on the fence about these attributes actually. I wanted to model the fact that the output of a ChunkTransform is an array, with a fixed shape and dtype, and the fact that the array -> array codecs can be thought of as layers of ChunkTransform objects. But I don't know if we actually need this. Maybe a richer return type annotation is a better way of conveying this information.

"""Shape after all ArrayArrayCodec transforms (input to the ArrayBytesCodec)."""
return self._ab_spec.shape

@property
def dtype(self) -> ZDType[TBaseDType, TBaseScalar]:
    """Dtype after all ArrayArrayCodec transforms (input to the ArrayBytesCodec).

    Read from the spec pre-resolved at construction time, so it reflects
    any dtype change applied by the array->array codec stage.
    """
    return self._ab_spec.dtype

def decode(
self,
chunk_bytes: Buffer,
) -> NDBuffer:
"""Decode a single chunk through the full codec chain, synchronously.

Pure compute -- no IO.
"""
bb_out: Any = chunk_bytes
for bb_codec in reversed(self._bb_codecs):
bb_out = bb_codec._decode_sync(bb_out, self._ab_spec) # type: ignore[attr-defined]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a comment about why the type: ignore is needed? Presumably it relates to that isinstance(c, SupportsSyncCodec) check above, which mypy can't see here in decode?


ab_out: Any = self._ab_codec._decode_sync(bb_out, self._ab_spec) # type: ignore[attr-defined]

for aa_codec, spec in reversed(self._aa_codecs):
ab_out = aa_codec._decode_sync(ab_out, spec) # type: ignore[attr-defined]

return ab_out # type: ignore[no-any-return]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This type: ignore I don't understand. Is the Any up above accurate, or should this be NDBuffer | Buffer like I see here? And if it is supposed to be NDBuffer | Buffer, why do we declare that we return -> NDBuffer?


def encode(
self,
chunk_array: NDBuffer,
) -> Buffer | None:
"""Encode a single chunk through the full codec chain, synchronously.

Pure compute -- no IO.
"""
aa_out: Any = chunk_array

for aa_codec, spec in self._aa_codecs:
if aa_out is None:
return None
Comment on lines +160 to +161
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could be moved out of the loop? Can _encode_sync return None?

aa_out = aa_codec._encode_sync(aa_out, spec) # type: ignore[attr-defined]

if aa_out is None:
return None
bb_out: Any = self._ab_codec._encode_sync(aa_out, self._ab_spec) # type: ignore[attr-defined]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is quite hard to read/review: "is the output of aa after application of ab really bb?!" Doesn't help that output is Any despite all the fancy typing.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what do you suggest

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

aa_out -> asarray; bb_out -> asbytes?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok so it's an issue with the variable names, I will see if I can make them more clear

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the properties are an issue too but I don't have a suggestion for that.


for bb_codec in self._bb_codecs:
if bb_out is None:
return None
Comment on lines +169 to +170
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could be hoisted out of the loop

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or maybe not? I'm confused regardless.

bb_out = bb_codec._encode_sync(bb_out, self._ab_spec) # type: ignore[attr-defined]

return bb_out # type: ignore[no-any-return]

def compute_encoded_size(self, byte_length: int, array_spec: ArraySpec) -> int:
    """Return the encoded byte length after applying every codec in order.

    Each codec updates both the running size estimate and the spec handed
    to the next codec in the chain.
    """
    size = byte_length
    spec = array_spec
    for codec in self.codecs:
        size = codec.compute_encoded_size(size, spec)
        spec = codec.resolve_metadata(spec)
    return size


@dataclass(frozen=True)
class BatchedCodecPipeline(CodecPipeline):
"""Default codec pipeline.
Expand Down
222 changes: 222 additions & 0 deletions tests/test_sync_codec_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
from __future__ import annotations

from typing import Any

import numpy as np
import pytest

from zarr.abc.codec import ArrayBytesCodec
from zarr.codecs.bytes import BytesCodec
from zarr.codecs.crc32c_ import Crc32cCodec
from zarr.codecs.gzip import GzipCodec
from zarr.codecs.transpose import TransposeCodec
from zarr.codecs.zstd import ZstdCodec
from zarr.core.array_spec import ArrayConfig, ArraySpec
from zarr.core.buffer import Buffer, NDBuffer, default_buffer_prototype
from zarr.core.codec_pipeline import ChunkTransform
from zarr.core.dtype import get_data_type_from_native_dtype


def _make_array_spec(shape: tuple[int, ...], dtype: np.dtype[np.generic]) -> ArraySpec:
    """Build a test ArraySpec: C order, zero fill value, default buffer prototype."""
    zarr_dtype = get_data_type_from_native_dtype(dtype)
    array_config = ArrayConfig(order="C", write_empty_chunks=True)
    return ArraySpec(
        shape=shape,
        dtype=zarr_dtype,
        fill_value=zarr_dtype.cast_scalar(0),
        config=array_config,
        prototype=default_buffer_prototype(),
    )


def _make_nd_buffer(arr: np.ndarray[Any, np.dtype[Any]]) -> NDBuffer:
    """Wrap a numpy array in the default prototype's NDBuffer."""
    prototype = default_buffer_prototype()
    return prototype.nd_buffer.from_numpy_array(arr)


class TestChunkTransform:
def test_construction_bytes_only(self) -> None:
    """Constructing with only a (sync-capable) BytesCodec succeeds."""
    array_spec = _make_array_spec((100,), np.dtype("float64"))
    ChunkTransform(codecs=(BytesCodec(),), array_spec=array_spec)

def test_construction_with_compression(self) -> None:
    """An AB + BB chain (BytesCodec + GzipCodec) constructs cleanly."""
    array_spec = _make_array_spec((100,), np.dtype("float64"))
    ChunkTransform(codecs=(BytesCodec(), GzipCodec()), array_spec=array_spec)

def test_construction_full_chain(self) -> None:
    """All three codec stages (AA + AB + BB), each sync-capable, construct cleanly."""
    array_spec = _make_array_spec((3, 4), np.dtype("float64"))
    full_chain = (TransposeCodec(order=(1, 0)), BytesCodec(), ZstdCodec())
    ChunkTransform(codecs=full_chain, array_spec=array_spec)

def test_encode_decode_roundtrip_bytes_only(self) -> None:
# Minimal round-trip: BytesCodec serializes the array to bytes and back.
# No compression, no AA transform.
arr = np.arange(100, dtype="float64")
spec = _make_array_spec(arr.shape, arr.dtype)
chain = ChunkTransform(codecs=(BytesCodec(),), array_spec=spec)
nd_buf = _make_nd_buffer(arr)
Comment on lines +53 to +59
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW, I'd be fine with consolidating these tests with the construction tests. I do like having focused construction tests when the constructors are complicated, but these seem simple enough that just seeing the traceback pointing to __post_init__ should be enough. Either works for me though.


encoded = chain.encode(nd_buf)
assert encoded is not None
decoded = chain.decode(encoded)
np.testing.assert_array_equal(arr, decoded.as_numpy_array())
Comment on lines +61 to +64
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be worth refactoring this to a test helper, assert_encode_decode_equal(...).

Also, not worth worrying about currently, arr is possibly not a NumPy array, depending on what default_buffer_prototype() returns, in which case np.testing.assert_array_equal might not work. But to solve that more generally is out of scope.


def test_shape_dtype_no_aa_codecs(self) -> None:
    """With no AA codecs, shape and dtype are those of the input ArraySpec."""
    array_spec = _make_array_spec((100,), np.dtype("float64"))
    transform = ChunkTransform(codecs=(BytesCodec(),), array_spec=array_spec)
    assert transform.shape == (100,)
    assert transform.dtype == array_spec.dtype

def test_shape_dtype_with_transpose(self) -> None:
    """Transposing (3, 4) with order=(1, 0) means the AB codec sees (4, 3)."""
    array_spec = _make_array_spec((3, 4), np.dtype("float64"))
    transform = ChunkTransform(
        codecs=(TransposeCodec(order=(1, 0)), BytesCodec()), array_spec=array_spec
    )
    assert transform.shape == (4, 3)
    assert transform.dtype == array_spec.dtype

def test_encode_decode_roundtrip_with_compression(self) -> None:
    """Round-trip through a BB codec (gzip) to verify compression wiring."""
    data = np.arange(100, dtype="float64")
    array_spec = _make_array_spec(data.shape, data.dtype)
    transform = ChunkTransform(
        codecs=(BytesCodec(), GzipCodec(level=1)), array_spec=array_spec
    )

    encoded = transform.encode(_make_nd_buffer(data))
    assert encoded is not None
    decoded = transform.decode(encoded)
    np.testing.assert_array_equal(data, decoded.as_numpy_array())

def test_encode_decode_roundtrip_with_transpose(self) -> None:
    """Full AA + AB + BB round-trip: transpose, serialize, compress, then invert all three."""
    data = np.arange(12, dtype="float64").reshape(3, 4)
    array_spec = _make_array_spec(data.shape, data.dtype)
    full_chain = (TransposeCodec(order=(1, 0)), BytesCodec(), ZstdCodec(level=1))
    transform = ChunkTransform(codecs=full_chain, array_spec=array_spec)

    encoded = transform.encode(_make_nd_buffer(data))
    assert encoded is not None
    decoded = transform.decode(encoded)
    np.testing.assert_array_equal(data, decoded.as_numpy_array())

def test_rejects_non_sync_codec(self) -> None:
# Construction must raise TypeError when a codec lacks SupportsSyncCodec.

class AsyncOnlyCodec(ArrayBytesCodec):
is_fixed_size = True

async def _decode_single(self, chunk_bytes: Buffer, chunk_spec: ArraySpec) -> NDBuffer:
raise NotImplementedError # pragma: no cover

async def _encode_single(
self, chunk_array: NDBuffer, chunk_spec: ArraySpec
) -> Buffer | None:
raise NotImplementedError # pragma: no cover

def compute_encoded_size(self, input_byte_length: int, _chunk_spec: ArraySpec) -> int:
return input_byte_length # pragma: no cover
Comment on lines +115 to +127
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move to the top of this module, so you can reuse it in mixed_sync_and_non_sync?


spec = _make_array_spec((100,), np.dtype("float64"))
with pytest.raises(TypeError, match="AsyncOnlyCodec"):
ChunkTransform(codecs=(AsyncOnlyCodec(),), array_spec=spec)

def test_rejects_mixed_sync_and_non_sync(self) -> None:
    """A single non-sync codec fails construction even alongside sync codecs."""

    class AsyncOnlyCodec(ArrayBytesCodec):
        # Implements only the async codec interface, not SupportsSyncCodec.
        is_fixed_size = True

        async def _decode_single(self, chunk_bytes: Buffer, chunk_spec: ArraySpec) -> NDBuffer:
            raise NotImplementedError  # pragma: no cover

        async def _encode_single(
            self, chunk_array: NDBuffer, chunk_spec: ArraySpec
        ) -> Buffer | None:
            raise NotImplementedError  # pragma: no cover

        def compute_encoded_size(self, input_byte_length: int, _chunk_spec: ArraySpec) -> int:
            return input_byte_length  # pragma: no cover

    array_spec = _make_array_spec((3, 4), np.dtype("float64"))
    mixed_chain = (TransposeCodec(order=(1, 0)), AsyncOnlyCodec())
    with pytest.raises(TypeError, match="AsyncOnlyCodec"):
        ChunkTransform(codecs=mixed_chain, array_spec=array_spec)

def test_compute_encoded_size_bytes_only(self) -> None:
    """BytesCodec preserves size: encoded length equals input length."""
    array_spec = _make_array_spec((100,), np.dtype("float64"))
    transform = ChunkTransform(codecs=(BytesCodec(),), array_spec=array_spec)
    assert transform.compute_encoded_size(800, array_spec) == 800

def test_compute_encoded_size_with_crc32c(self) -> None:
    """Crc32cCodec adds a 4-byte checksum, so encoded size is input + 4."""
    array_spec = _make_array_spec((100,), np.dtype("float64"))
    transform = ChunkTransform(codecs=(BytesCodec(), Crc32cCodec()), array_spec=array_spec)
    assert transform.compute_encoded_size(800, array_spec) == 804

def test_compute_encoded_size_with_transpose(self) -> None:
    """Transpose reorders axes without changing byte count; size calc must walk AA codecs."""
    array_spec = _make_array_spec((3, 4), np.dtype("float64"))
    transform = ChunkTransform(
        codecs=(TransposeCodec(order=(1, 0)), BytesCodec()), array_spec=array_spec
    )
    assert transform.compute_encoded_size(96, array_spec) == 96

def test_encode_returns_none_propagation(self) -> None:
    """An AA codec returning None (chunk equals fill value) short-circuits encode to None."""

    class NoneReturningAACodec(TransposeCodec):
        """TransposeCodec subclass whose sync encode always yields None."""

        def _encode_sync(self, chunk_array: NDBuffer, chunk_spec: ArraySpec) -> NDBuffer | None:
            return None

    array_spec = _make_array_spec((3, 4), np.dtype("float64"))
    transform = ChunkTransform(
        codecs=(NoneReturningAACodec(order=(1, 0)), BytesCodec()),
        array_spec=array_spec,
    )
    data = np.arange(12, dtype="float64").reshape(3, 4)
    # None must not be fed into the next codec; it propagates straight out.
    assert transform.encode(_make_nd_buffer(data)) is None

def test_encode_decode_roundtrip_with_crc32c(self) -> None:
    """Crc32c appends a checksum on encode and verifies it on decode; both directions must run."""
    data = np.arange(100, dtype="float64")
    array_spec = _make_array_spec(data.shape, data.dtype)
    transform = ChunkTransform(codecs=(BytesCodec(), Crc32cCodec()), array_spec=array_spec)

    encoded = transform.encode(_make_nd_buffer(data))
    assert encoded is not None
    decoded = transform.decode(encoded)
    np.testing.assert_array_equal(data, decoded.as_numpy_array())

def test_encode_decode_roundtrip_int32(self) -> None:
    """Round-trip int32 data to confirm the chain is not float-specific."""
    data = np.arange(50, dtype="int32")
    array_spec = _make_array_spec(data.shape, data.dtype)
    transform = ChunkTransform(
        codecs=(BytesCodec(), ZstdCodec(level=1)), array_spec=array_spec
    )

    encoded = transform.encode(_make_nd_buffer(data))
    assert encoded is not None
    decoded = transform.decode(encoded)
    np.testing.assert_array_equal(data, decoded.as_numpy_array())
Loading