diff --git a/docarray/array/mixins/io/binary.py b/docarray/array/mixins/io/binary.py index f4451808801..3b7b772c1c6 100644 --- a/docarray/array/mixins/io/binary.py +++ b/docarray/array/mixins/io/binary.py @@ -1,15 +1,16 @@ +import base64 import io import os.path -import base64 import pickle from contextlib import nullcontext -from typing import Union, BinaryIO, TYPE_CHECKING, Type, Optional +from typing import Union, BinaryIO, TYPE_CHECKING, Type, Optional, Generator from ....helper import random_uuid, __windows__, get_compress_ctx, decompress_bytes if TYPE_CHECKING: from ....types import T from ....proto.docarray_pb2 import DocumentArrayProto + from .... import Document, DocumentArray class BinaryIOMixin: @@ -22,17 +23,18 @@ def load_binary( protocol: str = 'pickle-array', compress: Optional[str] = None, _show_progress: bool = False, - ) -> 'T': - """Load array elements from a LZ4-compressed binary file. + streaming: bool = False, + ) -> Union['DocumentArray', Generator['Document', None, None]]: + """Load array elements from a compressed binary file. :param file: File or filename or serialized bytes where the data is stored. :param protocol: protocol to use :param compress: compress algorithm to use :param _show_progress: show progress bar, only works when protocol is `pickle` or `protobuf` - + :param streaming: if `True` returns a generator over `Document` objects. 
+ In case protocol is pickle the `Documents` are streamed from disk to save memory usage :return: a DocumentArray object """ - if isinstance(file, io.BufferedReader): file_ctx = nullcontext(file) elif isinstance(file, bytes): @@ -41,37 +43,121 @@ def load_binary( file_ctx = open(file, 'rb') else: raise ValueError(f'unsupported input {file!r}') + if streaming: + return cls._load_binary_stream( + file_ctx, + protocol=protocol, + compress=compress, + _show_progress=_show_progress, + ) + else: + return cls._load_binary_all(file_ctx, protocol, compress, _show_progress) + + @classmethod + def _load_binary_stream( + cls: Type['T'], + file_ctx: str, + protocol=None, + compress=None, + _show_progress=False, + ) -> Generator['Document', None, None]: + """Yield `Document` objects from a binary file + + :param protocol: protocol to use + :param compress: compress algorithm to use + :param _show_progress: show progress bar, only works when protocol is `pickle` or `protobuf` + :return: a generator of `Document` objects + """ from .... 
import Document + if _show_progress: + from rich.progress import track as _track + + track = lambda x: _track(x, description='Deserializing') + else: + track = lambda x: x + + with file_ctx as f: + version_numdocs_lendoc0 = f.read(9) + # 1 byte (uint8) + version = int.from_bytes(version_numdocs_lendoc0[0:1], 'big', signed=False) + # 8 bytes (uint64) + num_docs = int.from_bytes(version_numdocs_lendoc0[1:9], 'big', signed=False) + + for _ in track(range(num_docs)): + # 4 bytes (uint32) + len_current_doc_in_bytes = int.from_bytes( + f.read(4), 'big', signed=False + ) + yield Document.from_bytes( + f.read(len_current_doc_in_bytes), + protocol=protocol, + compress=compress, + ) + + @classmethod + def _load_binary_all(cls, file_ctx, protocol, compress, show_progress): + """Read a `DocumentArray` object from a binary file + + :param protocol: protocol to use + :param compress: compress algorithm to use + :param _show_progress: show progress bar, only works when protocol is `pickle` or `protobuf` + :return: a `DocumentArray` + """ + from .... 
import Document + with file_ctx as fp: d = fp.read() if hasattr(fp, 'read') else fp + + if protocol == 'pickle-array' or protocol == 'protobuf-array': if get_compress_ctx(algorithm=compress) is not None: d = decompress_bytes(d, algorithm=compress) compress = None - if protocol == 'protobuf-array': - from ....proto.docarray_pb2 import DocumentArrayProto + if protocol == 'protobuf-array': + from ....proto.docarray_pb2 import DocumentArrayProto - dap = DocumentArrayProto() - dap.ParseFromString(d) + dap = DocumentArrayProto() + dap.ParseFromString(d) - return cls.from_protobuf(dap) - elif protocol == 'pickle-array': - return pickle.loads(d) + return cls.from_protobuf(dap) + elif protocol == 'pickle-array': + return pickle.loads(d) + + # Binary format for streaming case + else: + # 1 byte (uint8) + version = int.from_bytes(d[0:1], 'big', signed=False) + # 8 bytes (uint64) + num_docs = int.from_bytes(d[1:9], 'big', signed=False) + if show_progress: + from rich.progress import track as _track + + track = lambda x: _track(x, description='Deserializing') else: - _len = len(random_uuid().bytes) - _binary_delimiter = d[:_len] # first get delimiter - if _show_progress: - from rich.progress import track as _track + track = lambda x: x - track = lambda x: _track(x, description='Deserializing') - else: - track = lambda x: x - return cls( - Document.from_bytes(od, protocol=protocol, compress=compress) - for od in track(d[_len:].split(_binary_delimiter)) + # this 9 is version + num_docs bytes used + start_pos = 9 + docs = [] + + for _ in track(range(num_docs)): + # 4 bytes (uint32) + len_current_doc_in_bytes = int.from_bytes( + d[start_pos : start_pos + 4], 'big', signed=False ) + start_doc_pos = start_pos + 4 + end_doc_pos = start_doc_pos + len_current_doc_in_bytes + start_pos = end_doc_pos + + # variable length bytes doc + doc = Document.from_bytes( + d[start_doc_pos:end_doc_pos], protocol=protocol, compress=compress + ) + docs.append(doc) + + return cls(docs) @classmethod def 
from_bytes( @@ -130,8 +216,11 @@ def to_bytes( :return: the binary serialization in bytes """ - _binary_delimiter = random_uuid().bytes - compress_ctx = get_compress_ctx(compress, mode='wb') + if protocol == 'protobuf-array' or protocol == 'pickle-array': + compress_ctx = get_compress_ctx(compress, mode='wb') + else: + compress_ctx = None + with (_file_ctx or io.BytesIO()) as bf: if compress_ctx is None: # if compress do not support streaming then postpone the compress @@ -141,12 +230,14 @@ def to_bytes( f = compress_ctx(bf) fc = f compress = None + with fc: if protocol == 'protobuf-array': f.write(self.to_protobuf().SerializePartialToString()) elif protocol == 'pickle-array': f.write(pickle.dumps(self)) - else: + elif protocol in ('pickle', 'protobuf'): + # Binary format for streaming case if _show_progress: from rich.progress import track as _track @@ -154,9 +245,29 @@ def to_bytes( else: track = lambda x: x + # V1 DocArray streaming serialization format + # | 1 byte | 8 bytes | 4 bytes | variable | 4 bytes | variable ... + + # 1 byte (uint8) + version_byte = b'\x01' + # 8 bytes (uint64) + num_docs_as_bytes = len(self).to_bytes(8, 'big', signed=False) + f.write(version_byte + num_docs_as_bytes) + for d in track(self): - f.write(_binary_delimiter) - f.write(d.to_bytes(protocol=protocol, compress=compress)) + # 4 bytes (uint32) + doc_as_bytes = d.to_bytes(protocol=protocol, compress=compress) + + # variable size bytes + len_doc_as_bytes = len(doc_as_bytes).to_bytes( + 4, 'big', signed=False + ) + f.write(len_doc_as_bytes + doc_as_bytes) + else: + raise ValueError( + f'protocol={protocol} is not supported. Can be only `protobuf`,`pickle`,`protobuf-array`,`pickle-array`.' 
+ ) + if not _file_ctx: return bf.getvalue() diff --git a/docs/fundamentals/documentarray/serialization.md b/docs/fundamentals/documentarray/serialization.md index 87ceb1b10af..be95fd452a6 100644 --- a/docs/fundamentals/documentarray/serialization.md +++ b/docs/fundamentals/documentarray/serialization.md @@ -143,21 +143,36 @@ Depending on how you want to interpret the results, the figures above can be an ### Wire format of `pickle` and `protobuf` -When set `protocol=pickle` or `protobuf`, the result binary string looks like the following: +When set `protocol=pickle` or `protobuf`, the resulting bytes look like the following: ```text ------------------------------------------------------------------------------------ -| Delimiter | doc1.to_bytes() | Delimiter | doc2.to_bytes() | Delimiter | ... ------------------------------------------------------------------------------------ - | | - | | - | | - Fixed-length | - | - Variable-length +-------------------------------------------------------------------------------------------------------- +| version | len(docs) | doc1_bytes | doc1.to_bytes() | doc2_bytes | doc2.to_bytes() ... +--------------------------------------------------------------------------------------------------------- +| Fixed-length | Fixed-length | Fixed-length | Variable-length | Fixed-length | Variable-length ... +-------------------------------------------------------------------------------------------------------- + | | | | | | + uint8 uint64 uint32 Variable-length ... ... + ``` -Here `Delimiter` is a 16-bytes separator such as `b'g\x81\xcc\x1c\x0f\x93L\xed\xa2\xb0s)\x9c\xf9\xf6\xf2'` used for setting the boundary of each Document's serialization. Given a `to_bytes(protocol='pickle/protobuf')` binary string, once we know the first 16 bytes, the boundary is clear. Consequently, one can leverage this format to stream Documents, drop, skip, or early-stop, etc. 
+Here `version` is a `uint8` that specifies the serialization version of the `DocumentArray` serialization format, followed by `len(docs)` which is a `uint64` that specifies the number of serialized documents. +Afterwards, `doc1_bytes` describes how many bytes are used to serialize `doc1`, followed by `doc1.to_bytes()` which is the bytes data of the document itself. +The pattern `docN_bytes` and `docN.to_bytes()` is repeated `len(docs)` times. + + +### Streaming + +A `DocumentArray` can be streamed from a serialized file as shown in the following example: + +```python +da_generator = DocumentArray.load_binary('documentarray.bin', protocol='pickle', compress='gzip', streaming=True) +for d in da_generator: + # work here with `d` as a Document object + print(d.text) +``` + + + ## From/to base64 @@ -305,4 +320,6 @@ da = DocumentArray.pull(token='myda123') Now you can continue the work at local, analyzing `da` or visualizing it. Your friends & colleagues who know the token `myda123` can also pull that DocumentArray. It's useful when you want to quickly share the results with your colleagues & friends. -The maximum size of an upload is 4GB under the `protocol='protobuf'` and `compress='gzip'` setting. The lifetime of an upload is one week after its creation. \ No newline at end of file +The maximum size of an upload is 4GB under the `protocol='protobuf'` and `compress='gzip'` setting. The lifetime of an upload is one week after its creation. 
+ + diff --git a/tests/unit/array/test_from_to_bytes.py b/tests/unit/array/test_from_to_bytes.py index ffb0f5f3550..c2f5a8a0665 100644 --- a/tests/unit/array/test_from_to_bytes.py +++ b/tests/unit/array/test_from_to_bytes.py @@ -1,10 +1,11 @@ +import types import numpy as np import pytest import tensorflow as tf import torch from scipy.sparse import csr_matrix, coo_matrix, bsr_matrix, csc_matrix -from docarray import DocumentArray +from docarray import DocumentArray, Document from docarray.math.ndarray import to_numpy_array from tests import random_docs @@ -92,3 +93,21 @@ def test_push_pull_show_progress(show_progress, protocol): r = da.to_bytes(_show_progress=show_progress, protocol=protocol) da_r = DocumentArray.from_bytes(r, _show_progress=show_progress, protocol=protocol) assert da == da_r + + +# Note protocol = ['protobuf-array', 'pickle-array'] not supported with Document.from_bytes +@pytest.mark.parametrize('protocol', ['protobuf', 'pickle']) +@pytest.mark.parametrize( + 'compress', ['lz4', 'bz2', 'lzma', 'gzip', 'zlib', 'gzib', None] +) +def test_save_bytes_stream(tmpfile, protocol, compress): + da = DocumentArray( + [Document(text='aaa'), Document(text='bbb'), Document(text='ccc')] + ) + da.save_binary(tmpfile, protocol=protocol, compress=compress) + da_reconstructed = DocumentArray.load_binary( + tmpfile, protocol=protocol, compress=compress, streaming=True + ) + assert isinstance(da_reconstructed, types.GeneratorType) + for d, d_rec in zip(da, da_reconstructed): + assert d == d_rec