1+ from __future__ import annotations
2+
3+ import asyncio
4+ from dataclasses import dataclass
5+ from typing import TYPE_CHECKING , Any
6+
7+ import numpy .typing as npt
8+
9+ from zarr .core .buffer import core
10+
11+ from zarr .abc .codec import BytesBytesCodec
12+ from zarr .codecs .bytes import BytesCodec
13+ from zarr .codecs .blosc import BloscCodec
14+ from zarr .codecs .zstd import ZstdCodec
15+ from zarr .core .codec_pipeline import BatchedCodecPipeline
16+ from zarr .core .indexing import SelectorTuple , is_contiguous_selection
17+ from zarr .registry import register_pipeline
18+
19+ if TYPE_CHECKING :
20+ from collections .abc import Callable , Iterable
21+
22+ from zarr .abc .store import ByteGetter
23+ from zarr .core .array_spec import ArraySpec
24+ from zarr .core .buffer import Buffer , NDBuffer
25+ from zarr .core .buffer .core import NDArrayLike
26+ from zarr .core .common import ChunkCoords
27+ from zarr .core .indexing import Selection
28+
29+
@dataclass(frozen=True)
class OptimizedCodecPipeline(BatchedCodecPipeline):
    """Codec pipeline with a fast path that decodes chunks directly into ``out``.

    The fast path applies only when the pipeline has no array-to-array codecs,
    uses a ``BytesCodec`` as its array-to-bytes codec, and has exactly one
    bytes-to-bytes codec that is a ``BloscCodec`` or a ``ZstdCodec``. Anything
    else is delegated unchanged to ``BatchedCodecPipeline.read``.
    """

    async def read(
        self,
        batch_info: Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple, bool]],
        out: NDBuffer,
        drop_axes: tuple[int, ...] = (),
    ) -> None:
        """Read a batch of chunks into ``out``.

        Parameters
        ----------
        batch_info
            One ``(byte_getter, chunk_spec, chunk_selection, out_selection,
            is_complete_chunk)`` tuple per chunk to read.
        out
            Destination buffer; each chunk is written to ``out[out_selection]``.
        drop_axes
            Forwarded to the parent implementation. The fast path does not
            apply it, so any request with dropped axes is delegated.

        Notes
        -----
        Eligibility is decided for the whole batch *before* any chunk is read,
        so a batch is never processed half by the fast path and half by the
        parent implementation.
        """
        # ``batch_info`` is only promised to be an Iterable. Materialize it so it
        # can be inspected up front and still be handed, complete, to the parent.
        entries = list(batch_info)

        pipeline_supported = (
            len(self.array_array_codecs) == 0
            and isinstance(self.array_bytes_codec, BytesCodec)
            and len(self.bytes_bytes_codecs) == 1
            and isinstance(self.bytes_bytes_codecs[0], (BloscCodec, ZstdCodec))
        )

        # The fast path writes one whole decoded chunk into ``out[out_selection]``
        # and ignores ``chunk_selection``/``drop_axes``. That is only valid when
        # no axes are dropped and every out selection is a contiguous slice with
        # the same extent as its chunk.
        use_fast_path = (
            pipeline_supported
            and not drop_axes
            and all(
                is_contiguous_selection(out_selection)
                and is_total_slice(out_selection, chunk_spec.shape)
                for _, chunk_spec, _, out_selection, _ in entries
            )
        )

        if not use_fast_path:
            await super().read(entries, out, drop_axes)
            return

        codec = self.bytes_bytes_codecs[0]
        for entry in entries:
            byte_getter, chunk_spec, _chunk_selection, out_selection, _is_complete = entry
            # Read the compressed bytes for this chunk.
            chunk_bytes = await byte_getter.get(chunk_spec.prototype)
            if chunk_bytes is None:
                # Nothing stored for this chunk, so there is nothing to decompress.
                # Delegate this single entry to the parent, which is expected to
                # handle missing chunks (presumably by writing the fill value).
                await super().read([entry], out, drop_axes)
                continue
            # Decompress straight into the destination view.
            await _decode_single_out(codec, chunk_bytes, out[out_selection])
58+
# NOTE: `is_total_slice` was removed from zarr-python in
# https://github.com/zarr-developers/zarr-python/pull/2784 and replaced with
# `is_complete_chunk`, which returns True for end chunks.
# However, when decoding directly into an out buffer, end chunks must be treated
# differently, because the out selection is smaller than the buffer being read from.
# That is why the helper is still defined here.
63+ def is_total_slice (item : Selection , shape : ChunkCoords ) -> bool :
64+ """Determine whether `item` specifies a complete slice of array with the
65+ given `shape`. Used to optimize __setitem__ operations on the Chunk
66+ class."""
67+
68+ # N.B., assume shape is normalized
69+ if item == slice (None ):
70+ return True
71+ if isinstance (item , slice ):
72+ item = (item ,)
73+ if isinstance (item , tuple ):
74+ return all (
75+ (isinstance (dim_sel , int ) and dim_len == 1 )
76+ or (
77+ isinstance (dim_sel , slice )
78+ and (
79+ (dim_sel == slice (None ))
80+ or ((dim_sel .stop - dim_sel .start == dim_len ) and (dim_sel .step in [1 , None ]))
81+ )
82+ )
83+ for dim_sel , dim_len in zip (item , shape , strict = False )
84+ )
85+ else :
86+ raise TypeError (f"expected slice or tuple of slices, found { item !r} " )
87+
88+
async def _decode_single_out(
    codec: BytesBytesCodec,
    chunk_bytes: Buffer,
    out: NDBuffer,
) -> None:
    """Run the decoder wrapped by ``codec`` on ``chunk_bytes`` with ``out`` as target.

    The decoder is looked up on the codec's private wrapped object
    (``_blosc_codec`` or ``_zstd_codec``) and invoked in a worker thread via
    ``asyncio.to_thread`` so the event loop is not blocked.

    Parameters
    ----------
    codec
        Must be a ``BloscCodec`` or a ``ZstdCodec``.
    chunk_bytes
        The encoded chunk.
    out
        Buffer handed to the decoder as its second argument; presumably the
        decoder writes the decoded data into it.

    Raises
    ------
    ValueError
        If ``codec`` is of any other type.
    """
    if isinstance(codec, BloscCodec):
        backend = codec._blosc_codec
    elif isinstance(codec, ZstdCodec):
        backend = codec._zstd_codec
    else:
        raise ValueError(f"Unsupported codec: {codec}")

    decode_into = backend.decode
    await asyncio.to_thread(as_numpy_array_wrapper_out, decode_into, chunk_bytes, out)
103+
def as_numpy_array_wrapper_out(
    func: Callable[[npt.NDArray[Any], NDArrayLike], None],
    buf: core.Buffer,
    out: core.NDBuffer,
) -> None:
    """Call ``func`` with array views of ``buf`` and ``out``.

    ``buf`` is converted with ``as_numpy_array()`` and ``out`` with
    ``as_ndarray_like()``; the two results are passed to ``func`` positionally,
    in that order. Whatever ``func`` returns is discarded.
    """
    source = buf.as_numpy_array()
    target = out.as_ndarray_like()
    func(source, target)
108+
109+ register_pipeline (OptimizedCodecPipeline )
0 commit comments