Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions docarray/array/mixins/io/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,12 +311,18 @@ def to_bytes(
if not _file_ctx:
return bf.getvalue()

def to_protobuf(self) -> 'DocumentArrayProto':
def to_protobuf(self, ndarray_type: Optional[str] = None) -> 'DocumentArrayProto':
"""Convert DocumentArray into a Protobuf message.

:param ndarray_type: can be ``list`` or ``numpy``, if set it will force all ndarray-like object from all
Documents to ``List`` or ``numpy.ndarray``.
:return: the protobuf message
"""
from ....proto.docarray_pb2 import DocumentArrayProto

dap = DocumentArrayProto()
for d in self:
dap.docs.append(d.to_protobuf())
dap.docs.append(d.to_protobuf(ndarray_type))
return dap

@classmethod
Expand Down
11 changes: 8 additions & 3 deletions docarray/document/mixins/protobuf.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Type
from typing import TYPE_CHECKING, Type, Optional

if TYPE_CHECKING:
from ...types import T
Expand All @@ -12,7 +12,12 @@ def from_protobuf(cls: Type['T'], pb_msg: 'DocumentProto') -> 'T':

return parse_proto(pb_msg)

def to_protobuf(self) -> 'DocumentProto':
def to_protobuf(self, ndarray_type: Optional[str] = None) -> 'DocumentProto':
"""Convert Document into a Protobuf message.

:param ndarray_type: can be ``list`` or ``numpy``, if set it will force all ndarray-like object to be ``List`` or ``numpy.ndarray``.
:return: the protobuf message
"""
from ...proto.io import flush_proto

return flush_proto(self)
return flush_proto(self, ndarray_type)
6 changes: 3 additions & 3 deletions docarray/proto/io/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections import defaultdict
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional

from google.protobuf.json_format import MessageToDict
from google.protobuf.struct_pb2 import Struct
Expand Down Expand Up @@ -37,13 +37,13 @@ def parse_proto(pb_msg: 'DocumentProto') -> 'Document':
return Document(**fields)


def flush_proto(doc: 'Document') -> 'DocumentProto':
def flush_proto(doc: 'Document', ndarray_type: Optional[str] = None) -> 'DocumentProto':
pb_msg = DocumentProto()
for key in doc.non_empty_fields:
try:
value = getattr(doc, key)
if key in ('tensor', 'embedding'):
flush_ndarray(getattr(pb_msg, key), value)
flush_ndarray(getattr(pb_msg, key), value, ndarray_type=ndarray_type)
elif key in ('chunks', 'matches'):
for d in value:
d: Document
Expand Down
13 changes: 10 additions & 3 deletions docarray/proto/io/ndarray.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional

import numpy as np

from ...math.ndarray import get_array_type
from ...math.ndarray import get_array_type, to_numpy_array

if TYPE_CHECKING:
from ...types import ArrayType
Expand Down Expand Up @@ -44,7 +44,14 @@ def read_ndarray(pb_msg: 'NdArrayProto') -> 'ArrayType':
return _to_framework_array(x, framework)


def flush_ndarray(pb_msg: 'NdArrayProto', value: 'ArrayType'):
def flush_ndarray(
pb_msg: 'NdArrayProto', value: 'ArrayType', ndarray_type: Optional[str] = None
):
if ndarray_type == 'list':
value = to_numpy_array(value).tolist()
elif ndarray_type == 'numpy':
value = to_numpy_array(value)

framework, is_sparse = get_array_type(value)

if framework == 'docarray':
Expand Down
1 change: 1 addition & 0 deletions docs/fundamentals/document/serialization.md
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ mime_type: "image/jpeg"

One can refer to the [Protobuf specification of `Document`](../../proto/index.md) for details.

When `.tensor` or `.embedding` contains frameworks-specific ndarray-like object, you can use `.to_protobuf(..., ndarray_type='numpy')` or `.to_protobuf(..., ndarray_type='list')` to cast them into `list` or `numpy.ndarray` automatically. This will help to ensure the maximum compatability between different microservices.

## What's next?

Expand Down
14 changes: 13 additions & 1 deletion tests/unit/math/test_ndarray.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import numpy as np
import paddle
import pytest
import tensorflow as tf
import torch
from scipy.sparse import csr_matrix, coo_matrix, bsr_matrix, csc_matrix, issparse

from docarray.math.ndarray import get_array_rows
from docarray.proto.docarray_pb2 import NdArrayProto
from docarray.proto.io import flush_ndarray, read_ndarray


@pytest.mark.parametrize(
Expand All @@ -30,7 +33,8 @@
csc_matrix,
],
)
def test_get_array_rows(data, expected_result, arraytype):
@pytest.mark.parametrize('ndarray_type', ['list', 'numpy'])
def test_get_array_rows(data, expected_result, arraytype, ndarray_type):
data_array = arraytype(data)

num_rows, ndim = get_array_rows(data_array)
Expand All @@ -39,3 +43,11 @@ def test_get_array_rows(data, expected_result, arraytype):
assert expected_result[0] == num_rows
else:
assert expected_result == (num_rows, ndim)

na_proto = NdArrayProto()
flush_ndarray(na_proto, value=data_array, ndarray_type=ndarray_type)
r_data_array = read_ndarray(na_proto)
if ndarray_type == 'list':
assert isinstance(r_data_array, list)
elif ndarray_type == 'numpy':
assert isinstance(r_data_array, np.ndarray)