Skip to content

Commit 9b5e7fc

Browse files
committed
Optimized codec pipeline that avoids a memory copy in common cases when reading
Improve test; slight generalization; add test for Blosc; add typing; only support particular configurations; don't test array with a single chunk; remove unused checksum test for Blosc.
1 parent cf879eb commit 9b5e7fc

3 files changed

Lines changed: 170 additions & 0 deletions

File tree

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
from dataclasses import dataclass
5+
from typing import TYPE_CHECKING, Any
6+
7+
import numpy.typing as npt
8+
9+
from zarr.core.buffer import core
10+
11+
from zarr.abc.codec import BytesBytesCodec
12+
from zarr.codecs.bytes import BytesCodec
13+
from zarr.codecs.blosc import BloscCodec
14+
from zarr.codecs.zstd import ZstdCodec
15+
from zarr.core.codec_pipeline import BatchedCodecPipeline
16+
from zarr.core.indexing import SelectorTuple, is_contiguous_selection
17+
from zarr.registry import register_pipeline
18+
19+
if TYPE_CHECKING:
20+
from collections.abc import Callable, Iterable
21+
22+
from zarr.abc.store import ByteGetter
23+
from zarr.core.array_spec import ArraySpec
24+
from zarr.core.buffer import Buffer, NDBuffer
25+
from zarr.core.buffer.core import NDArrayLike
26+
from zarr.core.common import ChunkCoords
27+
from zarr.core.indexing import Selection
28+
29+
30+
@dataclass(frozen=True)
class OptimizedCodecPipeline(BatchedCodecPipeline):
    """Codec pipeline that decompresses chunk bytes directly into the output buffer.

    For the common configuration — no array->array codecs, a plain
    ``BytesCodec`` array->bytes codec, and exactly one Blosc or Zstd
    bytes->bytes codec — each chunk's compressed bytes are decoded straight
    into ``out``, avoiding the intermediate chunk buffer (and the memory
    copy) that ``BatchedCodecPipeline.read`` would otherwise allocate.
    Every other configuration falls back to the batched pipeline.
    """

    async def read(
        self,
        batch_info: Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple, bool]],
        out: NDBuffer,
        drop_axes: tuple[int, ...] = (),
    ) -> None:
        """Read the chunks described by ``batch_info`` into ``out``.

        Parameters
        ----------
        batch_info : iterable of (byte_getter, chunk_spec, chunk_selection, out_selection, is_complete_chunk)
            One entry per chunk to read. May be a one-shot iterator.
        out : NDBuffer
            Destination buffer; each chunk is written at its ``out_selection``.
        drop_axes : tuple of int
            Axes to drop; only used by the fallback path.
        """
        supported_codecs = (
            len(self.array_array_codecs) == 0
            and isinstance(self.array_bytes_codec, BytesCodec)
            and len(self.bytes_bytes_codecs) == 1
            and isinstance(self.bytes_bytes_codecs[0], (BloscCodec, ZstdCodec))
        )
        if not supported_codecs:
            await super().read(batch_info, out, drop_axes)
            return

        # Materialize once: ``batch_info`` may be a one-shot iterator, and we
        # must be able to hand the *complete* batch to the fallback path.
        batch = list(batch_info)

        # Decide eligibility up front so we never decode part of the batch and
        # then fall back. (The previous per-item fallback passed a partially
        # consumed iterator to ``super().read``, silently dropping the entries
        # already taken by the loop when ``batch_info`` was a generator.)
        fast_path_ok = all(
            is_contiguous_selection(out_selection)
            and is_total_slice(out_selection, chunk_spec.shape)
            for _, chunk_spec, _, out_selection, _ in batch
        )
        if not fast_path_ok:
            await super().read(batch, out, drop_axes)
            return

        codec = self.bytes_bytes_codecs[0]
        for byte_getter, chunk_spec, _chunk_selection, out_selection, _is_complete_chunk in batch:
            # Read compressed bytes using the ByteGetter, decompress into out.
            buffer = await byte_getter.get(chunk_spec.prototype)
            if buffer is None:
                # Missing chunk: there is nothing to decode, so write the
                # fill value instead of crashing on a None buffer.
                out[out_selection] = chunk_spec.fill_value
                continue
            await _decode_single_out(codec, buffer, out[out_selection])
58+
59+
# Note that this was removed in https://github.com/zarr-developers/zarr-python/pull/2784
60+
# and replaced with `is_complete_chunk`, which returns True for end chunks.
61+
# However, for the purposes of decoding into an out buffer, we need end chunks to be
62+
# treated differently since the out selection is smaller than the buffer being read from.
63+
def is_total_slice(item: Selection, shape: ChunkCoords) -> bool:
64+
"""Determine whether `item` specifies a complete slice of array with the
65+
given `shape`. Used to optimize __setitem__ operations on the Chunk
66+
class."""
67+
68+
# N.B., assume shape is normalized
69+
if item == slice(None):
70+
return True
71+
if isinstance(item, slice):
72+
item = (item,)
73+
if isinstance(item, tuple):
74+
return all(
75+
(isinstance(dim_sel, int) and dim_len == 1)
76+
or (
77+
isinstance(dim_sel, slice)
78+
and (
79+
(dim_sel == slice(None))
80+
or ((dim_sel.stop - dim_sel.start == dim_len) and (dim_sel.step in [1, None]))
81+
)
82+
)
83+
for dim_sel, dim_len in zip(item, shape, strict=False)
84+
)
85+
else:
86+
raise TypeError(f"expected slice or tuple of slices, found {item!r}")
87+
88+
89+
async def _decode_single_out(
    codec: BytesBytesCodec,
    chunk_bytes: Buffer,
    out: NDBuffer,
) -> None:
    """Decompress `chunk_bytes` with `codec` directly into `out`.

    The blocking decode runs in a worker thread via ``asyncio.to_thread`` so
    the event loop is not stalled. Raises ``ValueError`` for any codec other
    than Blosc or Zstd.
    """
    # Map supported codec types to the attribute holding their low-level codec.
    for codec_type, attr_name in ((BloscCodec, "_blosc_codec"), (ZstdCodec, "_zstd_codec")):
        if isinstance(codec, codec_type):
            decode_method = getattr(codec, attr_name).decode
            break
    else:
        raise ValueError(f"Unsupported codec: {codec}")
    await asyncio.to_thread(as_numpy_array_wrapper_out, decode_method, chunk_bytes, out)
103+
104+
def as_numpy_array_wrapper_out(
    func: Callable[[npt.NDArray[Any], NDArrayLike], None], buf: core.Buffer, out: core.NDBuffer
) -> None:
    """Adapt buffer types for a decode function that writes into an array.

    Converts `buf` to a numpy array and `out` to its ndarray-like backing,
    then invokes `func(source, destination)`; `func` is expected to write its
    result into the destination in place.
    """
    source = buf.as_numpy_array()
    destination = out.as_ndarray_like()
    func(source, destination)
108+
109+
# Register at import time so the pipeline can be selected by name via the
# `codec_pipeline.path` config option (importing this module is sufficient).
register_pipeline(OptimizedCodecPipeline)

tests/conftest.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,13 @@ async def parse_store(
4242
) -> LocalStore | MemoryStore | FsspecStore | ZipStore:
4343
if store == "local":
4444
return await LocalStore.open(path)
45+
if store == "obstore":
46+
import obstore
47+
from zarr.storage import ObjectStore
48+
49+
local_store = obstore.store.LocalStore(prefix=path, mkdir=True)
50+
return ObjectStore(store=local_store)
51+
return await LocalStore.open(path)
4552
if store == "memory":
4653
return await MemoryStore.open()
4754
if store == "fsspec":
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import numpy as np
2+
import pytest
3+
4+
import zarr
5+
from zarr.abc.store import Store
6+
from zarr.codecs import BloscCodec, ZstdCodec
7+
from zarr.core.config import config
8+
from zarr.storage import StorePath
9+
10+
# TODO: this is to register the pipeline
11+
import zarr.core.optimized_codec_pipeline
12+
13+
@pytest.mark.parametrize("store", ["local", "memory", "obstore"], indirect=["store"])
@pytest.mark.parametrize("checksum", [True, False])
def test_optimized_codec_pipeline_zstd(store: Store, checksum: bool) -> None:
    """Round-trip a Zstd-compressed array through the optimized pipeline."""
    expected = np.arange(0, 256, dtype="uint16").reshape((16, 16))

    with config.set(
        {"codec_pipeline.path": "zarr.core.optimized_codec_pipeline.OptimizedCodecPipeline"}
    ):
        created = zarr.create_array(
            StorePath(store, path="zstd"),
            shape=expected.shape,
            chunks=(10, 10),
            dtype=expected.dtype,
            fill_value=0,
            compressors=ZstdCodec(level=0, checksum=checksum),
        )

        created[:, :] = expected

        reopened = zarr.open(StorePath(store, path="zstd"))
        # Interior chunk.
        assert np.array_equal(reopened[0:10, 0:10], expected[0:10, 0:10])
        # End chunk: out selection is smaller than the decoded chunk buffer.
        assert np.array_equal(reopened[0:10, 10:16], expected[0:10, 10:16])
33+
34+
35+
@pytest.mark.parametrize("store", ["local", "memory", "obstore"], indirect=["store"])
def test_optimized_codec_pipeline_blosc(store: Store) -> None:
    """Round-trip a Blosc-compressed array through the optimized pipeline.

    Covers an interior chunk and an end chunk (where the out selection is
    smaller than the decoded chunk buffer).
    """
    data = np.arange(0, 256, dtype="uint16").reshape((16, 16))

    with config.set(
        {"codec_pipeline.path": "zarr.core.optimized_codec_pipeline.OptimizedCodecPipeline"}
    ):
        # Use a distinct "blosc" path: the original used path="zstd"
        # (copy-paste from the zstd test), which mislabels the stored array
        # and can collide with the zstd test's data on a shared store.
        a = zarr.create_array(
            StorePath(store, path="blosc"),
            shape=data.shape,
            chunks=(10, 10),
            dtype=data.dtype,
            fill_value=0,
            compressors=BloscCodec(),
        )

        a[:, :] = data

        a = zarr.open(StorePath(store, path="blosc"))
        assert np.array_equal(a[0:10, 0:10], data[0:10, 0:10])
        assert np.array_equal(a[0:10, 10:16], data[0:10, 10:16])  # end chunk
54+

0 commit comments

Comments
 (0)