diff --git a/docarray/array/mixins/io/binary.py b/docarray/array/mixins/io/binary.py index f4451808801..544d0b5266c 100644 --- a/docarray/array/mixins/io/binary.py +++ b/docarray/array/mixins/io/binary.py @@ -1,6 +1,6 @@ +import base64 import io import os.path -import base64 import pickle from contextlib import nullcontext from typing import Union, BinaryIO, TYPE_CHECKING, Type, Optional @@ -22,6 +22,7 @@ def load_binary( protocol: str = 'pickle-array', compress: Optional[str] = None, _show_progress: bool = False, + return_iterator: bool = False, ) -> 'T': """Load array elements from a LZ4-compressed binary file. @@ -29,10 +30,10 @@ def load_binary( :param protocol: protocol to use :param compress: compress algorithm to use :param _show_progress: show progress bar, only works when protocol is `pickle` or `protobuf` - + :param return_iterator: returns an iterator over the DocumentArray. + In case protocol is pickle the `Documents` are streamed from disk to save memory usage :return: a DocumentArray object """ - if isinstance(file, io.BufferedReader): file_ctx = nullcontext(file) elif isinstance(file, bytes): @@ -41,37 +42,100 @@ def load_binary( file_ctx = open(file, 'rb') else: raise ValueError(f'unsupported input {file!r}') + if return_iterator: + return cls._load_binary_stream( + file_ctx, protocol=protocol, compress=compress + ) + else: + return cls._load_binary_all(file_ctx, protocol, compress, _show_progress) + + @classmethod + def _load_binary_stream( + cls: Type['T'], file_ctx: str, protocol=None, compress=None, show_progress=False + ) -> 'T': from .... 
import Document + if show_progress: + from rich.progress import track as _track + + track = lambda x: _track(x, description='Deserializing') + else: + track = lambda x: x + + with file_ctx as f: + version_numdocs_lendoc0 = f.read(9) + # 1 byte (uint8) + version = int.from_bytes(version_numdocs_lendoc0[0:1], 'big', signed=False) + # 8 bytes (uint64) + num_docs = int.from_bytes(version_numdocs_lendoc0[1:9], 'big', signed=False) + + for _ in track(range(num_docs)): + # 4 bytes (uint32) + len_current_doc_in_bytes = int.from_bytes( + f.read(4), 'big', signed=False + ) + yield Document.from_bytes( + f.read(len_current_doc_in_bytes), + protocol=protocol, + compress=compress, + ) + + @classmethod + def _load_binary_all(cls, file_ctx, protocol, compress, show_progress): + from .... import Document + with file_ctx as fp: d = fp.read() if hasattr(fp, 'read') else fp + + if protocol == 'pickle-array' or protocol == 'protobuf-array': if get_compress_ctx(algorithm=compress) is not None: d = decompress_bytes(d, algorithm=compress) compress = None - if protocol == 'protobuf-array': - from ....proto.docarray_pb2 import DocumentArrayProto + if protocol == 'protobuf-array': + from ....proto.docarray_pb2 import DocumentArrayProto + + dap = DocumentArrayProto() + dap.ParseFromString(d) - dap = DocumentArrayProto() - dap.ParseFromString(d) + return cls.from_protobuf(dap) + elif protocol == 'pickle-array': + return pickle.loads(d) - return cls.from_protobuf(dap) - elif protocol == 'pickle-array': - return pickle.loads(d) + # Binary format for streaming case + else: + # 1 byte (uint8) + version = int.from_bytes(d[0:1], 'big', signed=False) + # 8 bytes (uint64) + num_docs = int.from_bytes(d[1:9], 'big', signed=False) + if show_progress: + from rich.progress import track as _track + + track = lambda x: _track(x, description='Deserializing') else: - _len = len(random_uuid().bytes) - _binary_delimiter = d[:_len] # first get delimiter - if _show_progress: - from rich.progress import track as 
_track + track = lambda x: x - track = lambda x: _track(x, description='Deserializing') - else: - track = lambda x: x - return cls( - Document.from_bytes(od, protocol=protocol, compress=compress) - for od in track(d[_len:].split(_binary_delimiter)) + # this 9 is version + num_docs bytes used + start_pos = 9 + docs = [] + + for _ in track(range(num_docs)): + # 4 bytes (uint32) + len_current_doc_in_bytes = int.from_bytes( + d[start_pos : start_pos + 4], 'big', signed=False + ) + start_doc_pos = start_pos + 4 + end_doc_pos = start_doc_pos + len_current_doc_in_bytes + start_pos = end_doc_pos + + # variable length bytes doc + doc = Document.from_bytes( + d[start_doc_pos:end_doc_pos], protocol=protocol, compress=compress ) + docs.append(doc) + + return cls(docs) @classmethod def from_bytes( @@ -130,8 +194,11 @@ def to_bytes( :return: the binary serialization in bytes """ - _binary_delimiter = random_uuid().bytes - compress_ctx = get_compress_ctx(compress, mode='wb') + if protocol == 'protobuf-array' or protocol == 'pickle-array': + compress_ctx = get_compress_ctx(compress, mode='wb') + else: + compress_ctx = None + with (_file_ctx or io.BytesIO()) as bf: if compress_ctx is None: # if compress do not support streaming then postpone the compress @@ -141,12 +208,14 @@ def to_bytes( f = compress_ctx(bf) fc = f compress = None + with fc: if protocol == 'protobuf-array': f.write(self.to_protobuf().SerializePartialToString()) elif protocol == 'pickle-array': f.write(pickle.dumps(self)) else: + # Binary format for streaming case if _show_progress: from rich.progress import track as _track @@ -154,9 +223,25 @@ def to_bytes( else: track = lambda x: x + # V1 DocArray streaming serialization format + # | 1 byte | 8 bytes | 4 bytes | variable | 4 bytes | variable ... 
+ + # 1 byte (uint8) + version_byte = b'\x01' + # 8 bytes (uint64) + num_docs_as_bytes = len(self).to_bytes(8, 'big', signed=False) + f.write(version_byte + num_docs_as_bytes) + for d in track(self): - f.write(_binary_delimiter) - f.write(d.to_bytes(protocol=protocol, compress=compress)) + # 4 bytes (uint32) + doc_as_bytes = d.to_bytes(protocol=protocol, compress=compress) + + # variable size bytes + len_doc_as_bytes = len(doc_as_bytes).to_bytes( + 4, 'big', signed=False + ) + f.write(len_doc_as_bytes + doc_as_bytes) + if not _file_ctx: return bf.getvalue() diff --git a/docarray/document/mixins/porting.py b/docarray/document/mixins/porting.py index 6cbf7c1b357..f6077463c34 100644 --- a/docarray/document/mixins/porting.py +++ b/docarray/document/mixins/porting.py @@ -50,6 +50,7 @@ def to_bytes( raise ValueError( f'protocol={protocol} is not supported. Can be only `protobuf` or pickle protocols 0-5.' ) + return compress_bytes(bstr, algorithm=compress) @classmethod diff --git a/docs/fundamentals/documentarray/serialization.md b/docs/fundamentals/documentarray/serialization.md index fc29e6de89b..b04aa3dd179 100644 --- a/docs/fundamentals/documentarray/serialization.md +++ b/docs/fundamentals/documentarray/serialization.md @@ -126,21 +126,23 @@ Depending on how you want to interpret the results, the figures above can be an ### Wire format of `pickle` and `protobuf` -When set `protocol=pickle` or `protobuf`, the result binary string looks like the following: +When set `protocol=pickle` or `protobuf`, the resulting bytes look like the following: ```text ------------------------------------------------------------------------------------ -| Delimiter | doc1.to_bytes() | Delimiter | doc2.to_bytes() | Delimiter | ... 
------------------------------------------------------------------------------------
-     |                 |
-     |                 |
-     |                 |
- Fixed-length          |
-                       |
-                  Variable-length
+--------------------------------------------------------------------------------------------------------
+| version      | len(docs)    | doc1_bytes   | doc1.to_bytes() | doc2_bytes   | doc2.to_bytes() ...
+---------------------------------------------------------------------------------------------------------
+| Fixed-length | Fixed-length | Fixed-length | Variable-length | Fixed-length | Variable-length ...
+--------------------------------------------------------------------------------------------------------
+       |              |              |                 |                |               |
+     uint8          uint64         uint32      Variable-length        ...             ...
+
 ```
-Here `Delimiter` is a 16-bytes separator such as `b'g\x81\xcc\x1c\x0f\x93L\xed\xa2\xb0s)\x9c\xf9\xf6\xf2'` used for setting the boundary of each Document's serialization. Given a `to_bytes(protocol='pickle/protobuf')` binary string, once we know the first 16 bytes, the boundary is clear. Consequently, one can leverage this format to stream Documents, drop, skip, or early-stop, etc.
+Here `version` is a `uint8` that specifies the serialization version of the `DocumentArray` serialization format, followed by `len(docs)` which is a `uint64` that specifies the number of serialized documents.
+Afterwards, `doc1_bytes` describes how many bytes are used to serialize `doc1`, followed by `doc1.to_bytes()` which is the bytes data of the document itself.
+The pattern `docN_bytes` followed by `docN.to_bytes()` is repeated `len(docs)` times.
+ ## From/to base64 diff --git a/tests/unit/array/test_from_to_bytes.py b/tests/unit/array/test_from_to_bytes.py index 2ac85ca9364..042573dafe0 100644 --- a/tests/unit/array/test_from_to_bytes.py +++ b/tests/unit/array/test_from_to_bytes.py @@ -1,10 +1,12 @@ +import types + import numpy as np import pytest import tensorflow as tf import torch from scipy.sparse import csr_matrix, coo_matrix, bsr_matrix, csc_matrix -from docarray import DocumentArray +from docarray import DocumentArray, Document from docarray.math.ndarray import to_numpy_array from tests import random_docs @@ -70,6 +72,24 @@ def test_save_bytes(target_da, protocol, compress, tmpfile): DocumentArray.load_binary(fp, protocol=protocol, compress=compress) +# Note protocol = ['protobuf-array', 'pickle-array'] not supported with Document.from_bytes +@pytest.mark.parametrize('protocol', ['protobuf', 'pickle']) +@pytest.mark.parametrize( + 'compress', ['lz4', 'bz2', 'lzma', 'gzip', 'zlib', 'gzib', None] +) +def test_save_bytes_stream(tmpfile, protocol, compress): + da = DocumentArray( + [Document(text='aaa'), Document(buffer=b'buffer'), Document(tags={'a': 'b'})] + ) + da.save_binary(tmpfile, protocol=protocol, compress=compress) + da_reconstructed = DocumentArray.load_binary( + tmpfile, protocol=protocol, compress=compress, return_iterator=True + ) + assert isinstance(da_reconstructed, types.GeneratorType) + for d, d_rec in zip(da, da_reconstructed): + assert d == d_rec + + @pytest.mark.parametrize('target_da', [DocumentArray.empty(100), random_docs(100)]) def test_from_to_protobuf(target_da): DocumentArray.from_protobuf(target_da.to_protobuf())