diff --git a/docarray/base_document/mixins/io.py b/docarray/base_document/mixins/io.py index 07f73cf2735..4c79a20fc98 100644 --- a/docarray/base_document/mixins/io.py +++ b/docarray/base_document/mixins/io.py @@ -14,11 +14,26 @@ TypeVar, ) +import numpy as np from typing_inspect import is_union_type from docarray.base_document.base_node import BaseNode +from docarray.typing import NdArray from docarray.typing.proto_register import _PROTO_TYPE_NAME_TO_CLASS from docarray.utils.compress import _compress_bytes, _decompress_bytes +from docarray.utils.misc import is_tf_available, is_torch_available + +tf_available = is_tf_available() +if tf_available: + import tensorflow as tf # type: ignore + + from docarray.typing import TensorFlowTensor + +torch_available = is_torch_available() +if torch_available: + import torch + + from docarray.typing import TorchTensor if TYPE_CHECKING: from pydantic.fields import ModelField @@ -36,60 +51,69 @@ def _type_to_protobuf(value: Any) -> 'NodeProto': """ from docarray.proto import NodeProto - nested_item: 'NodeProto' - if isinstance(value, BaseNode): - nested_item = value._to_node_protobuf() + basic_type_to_key = { + str: 'text', + bool: 'boolean', + int: 'integer', + float: 'float', + bytes: 'blob', + } - elif isinstance(value, str): - nested_item = NodeProto(text=value) + container_type_to_key = {list: 'list', set: 'set', tuple: 'tuple'} - elif isinstance(value, bool): - nested_item = NodeProto(boolean=value) + nested_item: 'NodeProto' - elif isinstance(value, int): - nested_item = NodeProto(integer=value) + if isinstance(value, BaseNode): + nested_item = value._to_node_protobuf() + return nested_item - elif isinstance(value, float): - nested_item = NodeProto(float=value) + base_node_wrap: BaseNode + if torch_available: + if isinstance(value, torch.Tensor): + base_node_wrap = TorchTensor._docarray_from_native(value) + return base_node_wrap._to_node_protobuf() - elif isinstance(value, bytes): - nested_item = NodeProto(blob=value) + if tf_available: + if isinstance(value, tf.Tensor): + base_node_wrap = TensorFlowTensor._docarray_from_native(value) + return base_node_wrap._to_node_protobuf() - elif isinstance(value, list): - from google.protobuf.struct_pb2 import ListValue + if isinstance(value, np.ndarray): + base_node_wrap = NdArray._docarray_from_native(value) + return base_node_wrap._to_node_protobuf() - lvalue = ListValue() - for item in value: - lvalue.append(item) - nested_item = NodeProto(list=lvalue) + for basic_type, key_name in basic_type_to_key.items(): + if isinstance(value, basic_type): + nested_item = NodeProto(**{key_name: value}) + return nested_item - elif isinstance(value, set): - from google.protobuf.struct_pb2 import ListValue + for container_type, key_name in container_type_to_key.items(): + if isinstance(value, container_type): + from docarray.proto import ListOfAnyProto - lvalue = ListValue() - for item in value: - lvalue.append(item) - nested_item = NodeProto(set=lvalue) + lvalue = ListOfAnyProto() + for item in value: + lvalue.data.append(_type_to_protobuf(item)) + nested_item = NodeProto(**{key_name: lvalue}) + return nested_item - elif isinstance(value, tuple): - from google.protobuf.struct_pb2 import ListValue + if isinstance(value, dict): + from docarray.proto import DictOfAnyProto - lvalue = ListValue() - for item in value: - lvalue.append(item) - nested_item = NodeProto(tuple=lvalue) + data = {} - elif isinstance(value, dict): - from google.protobuf.struct_pb2 import Struct + for key, content in value.items(): + data[key] = _type_to_protobuf(content) - struct = Struct() - struct.update(value) + struct = DictOfAnyProto(data=data) nested_item = NodeProto(dict=struct) + return nested_item + elif value is None: nested_item = NodeProto() + return nested_item else: raise ValueError(f'{type(value)} is not supported with protobuf') - return nested_item class IOMixin(Iterable[Tuple[str, Any]]): @@ -208,7 +232,9 @@ def from_protobuf(cls: Type[T], pb_msg: 'DocumentProto') -> T: return cls(**fields) @classmethod - def _get_content_from_node_proto(cls, value: 'NodeProto', field_name: str) -> Any: + def _get_content_from_node_proto( + cls, value: 'NodeProto', field_name: Optional[str] = None + ) -> Any: """ load the proto data from a node proto @@ -217,12 +243,6 @@ def _get_content_from_node_proto(cls, value: 'NodeProto', field_name: str) -> An :return: the loaded field """ content_type_dict = _PROTO_TYPE_NAME_TO_CLASS - arg_to_container: Dict[str, Callable] = { - 'list': list, - 'set': set, - 'tuple': tuple, - 'dict': dict, - } content_key = value.WhichOneof('content') docarray_type = ( @@ -236,6 +256,10 @@ def _get_content_from_node_proto(cls, value: 'NodeProto', field_name: str) -> An getattr(value, content_key) ) elif content_key in ['document', 'document_array']: + if field_name is None: + raise ValueError( + 'field_name cannot be None when trying to deseriliaze a Document or a DocumentArray' + ) return_field = cls._get_field_type(field_name).from_protobuf( getattr(value, content_key) ) # we get to the parent class @@ -243,16 +267,26 @@ def _get_content_from_node_proto(cls, value: 'NodeProto', field_name: str) -> An return_field = None elif docarray_type is None: + arg_to_container: Dict[str, Callable] = { + 'list': list, + 'set': set, + 'tuple': tuple, + } + if content_key in ['text', 'blob', 'integer', 'float', 'boolean']: return_field = getattr(value, content_key) elif content_key in arg_to_container.keys(): - from google.protobuf.json_format import MessageToDict - return_field = arg_to_container[content_key]( - MessageToDict(getattr(value, content_key)) + cls._get_content_from_node_proto(node) + for node in getattr(value, content_key).data ) + elif content_key == 'dict': + deser_dict: Dict[str, Any] = dict() + for key_name, node in value.dict.data.items(): + deser_dict[key_name] = cls._get_content_from_node_proto(node) + return_field = deser_dict else: raise ValueError( f'key {content_key} is not supported for deserialization' diff --git a/docarray/proto/__init__.py b/docarray/proto/__init__.py index 61bcec75322..5bc8c078c51 100644 --- a/docarray/proto/__init__.py +++ b/docarray/proto/__init__.py @@ -2,6 +2,7 @@ if __pb__version__.startswith('4'): from docarray.proto.pb.docarray_pb2 import ( + DictOfAnyProto, DocumentArrayProto, DocumentArrayStackedProto, DocumentProto, @@ -12,6 +13,7 @@ ) else: from docarray.proto.pb2.docarray_pb2 import ( + DictOfAnyProto, DocumentArrayProto, DocumentArrayStackedProto, DocumentProto, @@ -30,4 +32,5 @@ 'DocumentArrayProto', 'ListOfDocumentArrayProto', 'ListOfAnyProto', + 'DictOfAnyProto', ] diff --git a/docarray/proto/docarray.proto b/docarray/proto/docarray.proto index 82a6304fff8..ae9c86a2fc1 100644 --- a/docarray/proto/docarray.proto +++ b/docarray/proto/docarray.proto @@ -35,7 +35,8 @@ message GenericDictValue { repeated KeyValuePair entries = 1; } -// + + message NodeProto { oneof content { @@ -56,13 +57,13 @@ message NodeProto { // a sub DocumentArray DocumentArrayProto document_array = 8; //any list - google.protobuf.ListValue list = 9; + ListOfAnyProto list = 9; //any set - google.protobuf.ListValue set = 10; + ListOfAnyProto set = 10; //any tuple - google.protobuf.ListValue tuple = 11; + ListOfAnyProto tuple = 11; // dictionary with string as keys - google.protobuf.Struct dict = 12; + DictOfAnyProto dict = 12; } oneof docarray_type { @@ -80,18 +81,25 @@ message DocumentProto { } +message DictOfAnyProto { + + map data = 1; + +} + +message ListOfAnyProto { + repeated NodeProto data = 1; +} + message DocumentArrayProto { repeated DocumentProto docs = 1; // a list of Documents } + message ListOfDocumentArrayProto { repeated DocumentArrayProto data = 1; } -message ListOfAnyProto { - repeated NodeProto data = 1; -} - message DocumentArrayStackedProto{ map tensor_columns = 1; // a dict of document columns map doc_columns = 2; // a dict of tensor columns diff --git a/docarray/proto/pb/docarray_pb2.py b/docarray/proto/pb/docarray_pb2.py index 49f68231896..b66c36a7e1e 100644 --- a/docarray/proto/pb/docarray_pb2.py +++ b/docarray/proto/pb/docarray_pb2.py @@ -14,7 +14,7 @@ from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0e\x64ocarray.proto\x12\x08\x64ocarray\x1a\x1cgoogle/protobuf/struct.proto\"A\n\x11\x44\x65nseNdArrayProto\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\r\x12\r\n\x05\x64type\x18\x03 \x01(\t\"g\n\x0cNdArrayProto\x12*\n\x05\x64\x65nse\x18\x01 \x01(\x0b\x32\x1b.docarray.DenseNdArrayProto\x12+\n\nparameters\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\"Z\n\x0cKeyValuePair\x12#\n\x03key\x18\x01 \x01(\x0b\x32\x16.google.protobuf.Value\x12%\n\x05value\x18\x02 \x01(\x0b\x32\x16.google.protobuf.Value\";\n\x10GenericDictValue\x12\'\n\x07\x65ntries\x18\x01 \x03(\x0b\x32\x16.docarray.KeyValuePair\"\xcb\x03\n\tNodeProto\x12\x0e\n\x04text\x18\x01 \x01(\tH\x00\x12\x11\n\x07integer\x18\x02 \x01(\x05H\x00\x12\x0f\n\x05\x66loat\x18\x03 \x01(\x01H\x00\x12\x11\n\x07\x62oolean\x18\x04 \x01(\x08H\x00\x12\x0e\n\x04\x62lob\x18\x05 \x01(\x0cH\x00\x12)\n\x07ndarray\x18\x06 \x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x12+\n\x08\x64ocument\x18\x07 \x01(\x0b\x32\x17.docarray.DocumentProtoH\x00\x12\x36\n\x0e\x64ocument_array\x18\x08 \x01(\x0b\x32\x1c.docarray.DocumentArrayProtoH\x00\x12*\n\x04list\x18\t \x01(\x0b\x32\x1a.google.protobuf.ListValueH\x00\x12)\n\x03set\x18\n \x01(\x0b\x32\x1a.google.protobuf.ListValueH\x00\x12+\n\x05tuple\x18\x0b \x01(\x0b\x32\x1a.google.protobuf.ListValueH\x00\x12\'\n\x04\x64ict\x18\x0c \x01(\x0b\x32\x17.google.protobuf.StructH\x00\x12\x0e\n\x04type\x18\r \x01(\tH\x01\x42\t\n\x07\x63ontentB\x0f\n\rdocarray_type\"\x82\x01\n\rDocumentProto\x12/\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32!.docarray.DocumentProto.DataEntry\x1a@\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.docarray.NodeProto:\x02\x38\x01\";\n\x12\x44ocumentArrayProto\x12%\n\x04\x64ocs\x18\x01 \x03(\x0b\x32\x17.docarray.DocumentProto\"F\n\x18ListOfDocumentArrayProto\x12*\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\x1c.docarray.DocumentArrayProto\"3\n\x0eListOfAnyProto\x12!\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\x13.docarray.NodeProto\"\x90\x05\n\x19\x44ocumentArrayStackedProto\x12N\n\x0etensor_columns\x18\x01 \x03(\x0b\x32\x36.docarray.DocumentArrayStackedProto.TensorColumnsEntry\x12H\n\x0b\x64oc_columns\x18\x02 \x03(\x0b\x32\x33.docarray.DocumentArrayStackedProto.DocColumnsEntry\x12\x46\n\nda_columns\x18\x03 \x03(\x0b\x32\x32.docarray.DocumentArrayStackedProto.DaColumnsEntry\x12H\n\x0b\x61ny_columns\x18\x04 \x03(\x0b\x32\x33.docarray.DocumentArrayStackedProto.AnyColumnsEntry\x1aL\n\x12TensorColumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12%\n\x05value\x18\x02 \x01(\x0b\x32\x16.docarray.NdArrayProto:\x02\x38\x01\x1aV\n\x0f\x44ocColumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x32\n\x05value\x18\x02 \x01(\x0b\x32#.docarray.DocumentArrayStackedProto:\x02\x38\x01\x1aT\n\x0e\x44\x61\x43olumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x31\n\x05value\x18\x02 \x01(\x0b\x32\".docarray.ListOfDocumentArrayProto:\x02\x38\x01\x1aK\n\x0f\x41nyColumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\'\n\x05value\x18\x02 \x01(\x0b\x32\x18.docarray.ListOfAnyProto:\x02\x38\x01\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0e\x64ocarray.proto\x12\x08\x64ocarray\x1a\x1cgoogle/protobuf/struct.proto\"A\n\x11\x44\x65nseNdArrayProto\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\r\x12\r\n\x05\x64type\x18\x03 \x01(\t\"g\n\x0cNdArrayProto\x12*\n\x05\x64\x65nse\x18\x01 \x01(\x0b\x32\x1b.docarray.DenseNdArrayProto\x12+\n\nparameters\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\"Z\n\x0cKeyValuePair\x12#\n\x03key\x18\x01 \x01(\x0b\x32\x16.google.protobuf.Value\x12%\n\x05value\x18\x02 \x01(\x0b\x32\x16.google.protobuf.Value\";\n\x10GenericDictValue\x12\'\n\x07\x65ntries\x18\x01 \x03(\x0b\x32\x16.docarray.KeyValuePair\"\xc6\x03\n\tNodeProto\x12\x0e\n\x04text\x18\x01 \x01(\tH\x00\x12\x11\n\x07integer\x18\x02 \x01(\x05H\x00\x12\x0f\n\x05\x66loat\x18\x03 \x01(\x01H\x00\x12\x11\n\x07\x62oolean\x18\x04 \x01(\x08H\x00\x12\x0e\n\x04\x62lob\x18\x05 \x01(\x0cH\x00\x12)\n\x07ndarray\x18\x06 \x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x12+\n\x08\x64ocument\x18\x07 \x01(\x0b\x32\x17.docarray.DocumentProtoH\x00\x12\x36\n\x0e\x64ocument_array\x18\x08 \x01(\x0b\x32\x1c.docarray.DocumentArrayProtoH\x00\x12(\n\x04list\x18\t \x01(\x0b\x32\x18.docarray.ListOfAnyProtoH\x00\x12\'\n\x03set\x18\n \x01(\x0b\x32\x18.docarray.ListOfAnyProtoH\x00\x12)\n\x05tuple\x18\x0b \x01(\x0b\x32\x18.docarray.ListOfAnyProtoH\x00\x12(\n\x04\x64ict\x18\x0c \x01(\x0b\x32\x18.docarray.DictOfAnyProtoH\x00\x12\x0e\n\x04type\x18\r \x01(\tH\x01\x42\t\n\x07\x63ontentB\x0f\n\rdocarray_type\"\x82\x01\n\rDocumentProto\x12/\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32!.docarray.DocumentProto.DataEntry\x1a@\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.docarray.NodeProto:\x02\x38\x01\"\x84\x01\n\x0e\x44ictOfAnyProto\x12\x30\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\".docarray.DictOfAnyProto.DataEntry\x1a@\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.docarray.NodeProto:\x02\x38\x01\"3\n\x0eListOfAnyProto\x12!\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\x13.docarray.NodeProto\";\n\x12\x44ocumentArrayProto\x12%\n\x04\x64ocs\x18\x01 \x03(\x0b\x32\x17.docarray.DocumentProto\"F\n\x18ListOfDocumentArrayProto\x12*\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\x1c.docarray.DocumentArrayProto\"\x90\x05\n\x19\x44ocumentArrayStackedProto\x12N\n\x0etensor_columns\x18\x01 \x03(\x0b\x32\x36.docarray.DocumentArrayStackedProto.TensorColumnsEntry\x12H\n\x0b\x64oc_columns\x18\x02 \x03(\x0b\x32\x33.docarray.DocumentArrayStackedProto.DocColumnsEntry\x12\x46\n\nda_columns\x18\x03 \x03(\x0b\x32\x32.docarray.DocumentArrayStackedProto.DaColumnsEntry\x12H\n\x0b\x61ny_columns\x18\x04 \x03(\x0b\x32\x33.docarray.DocumentArrayStackedProto.AnyColumnsEntry\x1aL\n\x12TensorColumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12%\n\x05value\x18\x02 \x01(\x0b\x32\x16.docarray.NdArrayProto:\x02\x38\x01\x1aV\n\x0f\x44ocColumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x32\n\x05value\x18\x02 \x01(\x0b\x32#.docarray.DocumentArrayStackedProto:\x02\x38\x01\x1aT\n\x0e\x44\x61\x43olumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x31\n\x05value\x18\x02 \x01(\x0b\x32\".docarray.ListOfDocumentArrayProto:\x02\x38\x01\x1aK\n\x0f\x41nyColumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\'\n\x05value\x18\x02 \x01(\x0b\x32\x18.docarray.ListOfAnyProto:\x02\x38\x01\x62\x06proto3') _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'docarray_pb2', globals()) @@ -23,6 +23,8 @@ DESCRIPTOR._options = None _DOCUMENTPROTO_DATAENTRY._options = None _DOCUMENTPROTO_DATAENTRY._serialized_options = b'8\001' + _DICTOFANYPROTO_DATAENTRY._options = None + _DICTOFANYPROTO_DATAENTRY._serialized_options = b'8\001' _DOCUMENTARRAYSTACKEDPROTO_TENSORCOLUMNSENTRY._options = None _DOCUMENTARRAYSTACKEDPROTO_TENSORCOLUMNSENTRY._serialized_options = b'8\001' _DOCUMENTARRAYSTACKEDPROTO_DOCCOLUMNSENTRY._options = None @@ -40,25 +42,29 @@ _GENERICDICTVALUE._serialized_start=322 _GENERICDICTVALUE._serialized_end=381 _NODEPROTO._serialized_start=384 - _NODEPROTO._serialized_end=843 - _DOCUMENTPROTO._serialized_start=846 - _DOCUMENTPROTO._serialized_end=976 - _DOCUMENTPROTO_DATAENTRY._serialized_start=912 - _DOCUMENTPROTO_DATAENTRY._serialized_end=976 - _DOCUMENTARRAYPROTO._serialized_start=978 - _DOCUMENTARRAYPROTO._serialized_end=1037 - _LISTOFDOCUMENTARRAYPROTO._serialized_start=1039 - _LISTOFDOCUMENTARRAYPROTO._serialized_end=1109 - _LISTOFANYPROTO._serialized_start=1111 - _LISTOFANYPROTO._serialized_end=1162 - _DOCUMENTARRAYSTACKEDPROTO._serialized_start=1165 - _DOCUMENTARRAYSTACKEDPROTO._serialized_end=1821 - _DOCUMENTARRAYSTACKEDPROTO_TENSORCOLUMNSENTRY._serialized_start=1494 - _DOCUMENTARRAYSTACKEDPROTO_TENSORCOLUMNSENTRY._serialized_end=1570 - _DOCUMENTARRAYSTACKEDPROTO_DOCCOLUMNSENTRY._serialized_start=1572 - _DOCUMENTARRAYSTACKEDPROTO_DOCCOLUMNSENTRY._serialized_end=1658 - _DOCUMENTARRAYSTACKEDPROTO_DACOLUMNSENTRY._serialized_start=1660 - _DOCUMENTARRAYSTACKEDPROTO_DACOLUMNSENTRY._serialized_end=1744 - _DOCUMENTARRAYSTACKEDPROTO_ANYCOLUMNSENTRY._serialized_start=1746 - _DOCUMENTARRAYSTACKEDPROTO_ANYCOLUMNSENTRY._serialized_end=1821 + _NODEPROTO._serialized_end=838 + _DOCUMENTPROTO._serialized_start=841 + _DOCUMENTPROTO._serialized_end=971 + _DOCUMENTPROTO_DATAENTRY._serialized_start=907 + _DOCUMENTPROTO_DATAENTRY._serialized_end=971 + _DICTOFANYPROTO._serialized_start=974 + _DICTOFANYPROTO._serialized_end=1106 + _DICTOFANYPROTO_DATAENTRY._serialized_start=907 + _DICTOFANYPROTO_DATAENTRY._serialized_end=971 + _LISTOFANYPROTO._serialized_start=1108 + _LISTOFANYPROTO._serialized_end=1159 + _DOCUMENTARRAYPROTO._serialized_start=1161 + _DOCUMENTARRAYPROTO._serialized_end=1220 + _LISTOFDOCUMENTARRAYPROTO._serialized_start=1222 + _LISTOFDOCUMENTARRAYPROTO._serialized_end=1292 + _DOCUMENTARRAYSTACKEDPROTO._serialized_start=1295 + _DOCUMENTARRAYSTACKEDPROTO._serialized_end=1951 + _DOCUMENTARRAYSTACKEDPROTO_TENSORCOLUMNSENTRY._serialized_start=1624 + _DOCUMENTARRAYSTACKEDPROTO_TENSORCOLUMNSENTRY._serialized_end=1700 + _DOCUMENTARRAYSTACKEDPROTO_DOCCOLUMNSENTRY._serialized_start=1702 + _DOCUMENTARRAYSTACKEDPROTO_DOCCOLUMNSENTRY._serialized_end=1788 + _DOCUMENTARRAYSTACKEDPROTO_DACOLUMNSENTRY._serialized_start=1790 + _DOCUMENTARRAYSTACKEDPROTO_DACOLUMNSENTRY._serialized_end=1874 + _DOCUMENTARRAYSTACKEDPROTO_ANYCOLUMNSENTRY._serialized_start=1876 + _DOCUMENTARRAYSTACKEDPROTO_ANYCOLUMNSENTRY._serialized_end=1951 # @@protoc_insertion_point(module_scope) diff --git a/docarray/proto/pb2/docarray_pb2.py b/docarray/proto/pb2/docarray_pb2.py index cac439eec07..cc71cb81420 100644 --- a/docarray/proto/pb2/docarray_pb2.py +++ b/docarray/proto/pb2/docarray_pb2.py @@ -16,7 +16,7 @@ from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x0e\x64ocarray.proto\x12\x08\x64ocarray\x1a\x1cgoogle/protobuf/struct.proto\"A\n\x11\x44\x65nseNdArrayProto\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\r\x12\r\n\x05\x64type\x18\x03 \x01(\t\"g\n\x0cNdArrayProto\x12*\n\x05\x64\x65nse\x18\x01 \x01(\x0b\x32\x1b.docarray.DenseNdArrayProto\x12+\n\nparameters\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\"Z\n\x0cKeyValuePair\x12#\n\x03key\x18\x01 \x01(\x0b\x32\x16.google.protobuf.Value\x12%\n\x05value\x18\x02 \x01(\x0b\x32\x16.google.protobuf.Value\";\n\x10GenericDictValue\x12\'\n\x07\x65ntries\x18\x01 \x03(\x0b\x32\x16.docarray.KeyValuePair\"\xcb\x03\n\tNodeProto\x12\x0e\n\x04text\x18\x01 \x01(\tH\x00\x12\x11\n\x07integer\x18\x02 \x01(\x05H\x00\x12\x0f\n\x05\x66loat\x18\x03 \x01(\x01H\x00\x12\x11\n\x07\x62oolean\x18\x04 \x01(\x08H\x00\x12\x0e\n\x04\x62lob\x18\x05 \x01(\x0cH\x00\x12)\n\x07ndarray\x18\x06 \x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x12+\n\x08\x64ocument\x18\x07 \x01(\x0b\x32\x17.docarray.DocumentProtoH\x00\x12\x36\n\x0e\x64ocument_array\x18\x08 \x01(\x0b\x32\x1c.docarray.DocumentArrayProtoH\x00\x12*\n\x04list\x18\t \x01(\x0b\x32\x1a.google.protobuf.ListValueH\x00\x12)\n\x03set\x18\n \x01(\x0b\x32\x1a.google.protobuf.ListValueH\x00\x12+\n\x05tuple\x18\x0b \x01(\x0b\x32\x1a.google.protobuf.ListValueH\x00\x12\'\n\x04\x64ict\x18\x0c \x01(\x0b\x32\x17.google.protobuf.StructH\x00\x12\x0e\n\x04type\x18\r \x01(\tH\x01\x42\t\n\x07\x63ontentB\x0f\n\rdocarray_type\"\x82\x01\n\rDocumentProto\x12/\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32!.docarray.DocumentProto.DataEntry\x1a@\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.docarray.NodeProto:\x02\x38\x01\";\n\x12\x44ocumentArrayProto\x12%\n\x04\x64ocs\x18\x01 \x03(\x0b\x32\x17.docarray.DocumentProto\"F\n\x18ListOfDocumentArrayProto\x12*\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\x1c.docarray.DocumentArrayProto\"3\n\x0eListOfAnyProto\x12!\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\x13.docarray.NodeProto\"\x90\x05\n\x19\x44ocumentArrayStackedProto\x12N\n\x0etensor_columns\x18\x01 \x03(\x0b\x32\x36.docarray.DocumentArrayStackedProto.TensorColumnsEntry\x12H\n\x0b\x64oc_columns\x18\x02 \x03(\x0b\x32\x33.docarray.DocumentArrayStackedProto.DocColumnsEntry\x12\x46\n\nda_columns\x18\x03 \x03(\x0b\x32\x32.docarray.DocumentArrayStackedProto.DaColumnsEntry\x12H\n\x0b\x61ny_columns\x18\x04 \x03(\x0b\x32\x33.docarray.DocumentArrayStackedProto.AnyColumnsEntry\x1aL\n\x12TensorColumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12%\n\x05value\x18\x02 \x01(\x0b\x32\x16.docarray.NdArrayProto:\x02\x38\x01\x1aV\n\x0f\x44ocColumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x32\n\x05value\x18\x02 \x01(\x0b\x32#.docarray.DocumentArrayStackedProto:\x02\x38\x01\x1aT\n\x0e\x44\x61\x43olumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x31\n\x05value\x18\x02 \x01(\x0b\x32\".docarray.ListOfDocumentArrayProto:\x02\x38\x01\x1aK\n\x0f\x41nyColumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\'\n\x05value\x18\x02 \x01(\x0b\x32\x18.docarray.ListOfAnyProto:\x02\x38\x01\x62\x06proto3' + b'\n\x0e\x64ocarray.proto\x12\x08\x64ocarray\x1a\x1cgoogle/protobuf/struct.proto\"A\n\x11\x44\x65nseNdArrayProto\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\r\x12\r\n\x05\x64type\x18\x03 \x01(\t\"g\n\x0cNdArrayProto\x12*\n\x05\x64\x65nse\x18\x01 \x01(\x0b\x32\x1b.docarray.DenseNdArrayProto\x12+\n\nparameters\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\"Z\n\x0cKeyValuePair\x12#\n\x03key\x18\x01 \x01(\x0b\x32\x16.google.protobuf.Value\x12%\n\x05value\x18\x02 \x01(\x0b\x32\x16.google.protobuf.Value\";\n\x10GenericDictValue\x12\'\n\x07\x65ntries\x18\x01 \x03(\x0b\x32\x16.docarray.KeyValuePair\"\xc6\x03\n\tNodeProto\x12\x0e\n\x04text\x18\x01 \x01(\tH\x00\x12\x11\n\x07integer\x18\x02 \x01(\x05H\x00\x12\x0f\n\x05\x66loat\x18\x03 \x01(\x01H\x00\x12\x11\n\x07\x62oolean\x18\x04 \x01(\x08H\x00\x12\x0e\n\x04\x62lob\x18\x05 \x01(\x0cH\x00\x12)\n\x07ndarray\x18\x06 \x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x12+\n\x08\x64ocument\x18\x07 \x01(\x0b\x32\x17.docarray.DocumentProtoH\x00\x12\x36\n\x0e\x64ocument_array\x18\x08 \x01(\x0b\x32\x1c.docarray.DocumentArrayProtoH\x00\x12(\n\x04list\x18\t \x01(\x0b\x32\x18.docarray.ListOfAnyProtoH\x00\x12\'\n\x03set\x18\n \x01(\x0b\x32\x18.docarray.ListOfAnyProtoH\x00\x12)\n\x05tuple\x18\x0b \x01(\x0b\x32\x18.docarray.ListOfAnyProtoH\x00\x12(\n\x04\x64ict\x18\x0c \x01(\x0b\x32\x18.docarray.DictOfAnyProtoH\x00\x12\x0e\n\x04type\x18\r \x01(\tH\x01\x42\t\n\x07\x63ontentB\x0f\n\rdocarray_type\"\x82\x01\n\rDocumentProto\x12/\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32!.docarray.DocumentProto.DataEntry\x1a@\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.docarray.NodeProto:\x02\x38\x01\"\x84\x01\n\x0e\x44ictOfAnyProto\x12\x30\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\".docarray.DictOfAnyProto.DataEntry\x1a@\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.docarray.NodeProto:\x02\x38\x01\"3\n\x0eListOfAnyProto\x12!\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\x13.docarray.NodeProto\";\n\x12\x44ocumentArrayProto\x12%\n\x04\x64ocs\x18\x01 \x03(\x0b\x32\x17.docarray.DocumentProto\"F\n\x18ListOfDocumentArrayProto\x12*\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\x1c.docarray.DocumentArrayProto\"\x90\x05\n\x19\x44ocumentArrayStackedProto\x12N\n\x0etensor_columns\x18\x01 \x03(\x0b\x32\x36.docarray.DocumentArrayStackedProto.TensorColumnsEntry\x12H\n\x0b\x64oc_columns\x18\x02 \x03(\x0b\x32\x33.docarray.DocumentArrayStackedProto.DocColumnsEntry\x12\x46\n\nda_columns\x18\x03 \x03(\x0b\x32\x32.docarray.DocumentArrayStackedProto.DaColumnsEntry\x12H\n\x0b\x61ny_columns\x18\x04 \x03(\x0b\x32\x33.docarray.DocumentArrayStackedProto.AnyColumnsEntry\x1aL\n\x12TensorColumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12%\n\x05value\x18\x02 \x01(\x0b\x32\x16.docarray.NdArrayProto:\x02\x38\x01\x1aV\n\x0f\x44ocColumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x32\n\x05value\x18\x02 \x01(\x0b\x32#.docarray.DocumentArrayStackedProto:\x02\x38\x01\x1aT\n\x0e\x44\x61\x43olumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x31\n\x05value\x18\x02 \x01(\x0b\x32\".docarray.ListOfDocumentArrayProto:\x02\x38\x01\x1aK\n\x0f\x41nyColumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\'\n\x05value\x18\x02 \x01(\x0b\x32\x18.docarray.ListOfAnyProto:\x02\x38\x01\x62\x06proto3' ) @@ -27,9 +27,11 @@ _NODEPROTO = DESCRIPTOR.message_types_by_name['NodeProto'] _DOCUMENTPROTO = DESCRIPTOR.message_types_by_name['DocumentProto'] _DOCUMENTPROTO_DATAENTRY = _DOCUMENTPROTO.nested_types_by_name['DataEntry'] +_DICTOFANYPROTO = DESCRIPTOR.message_types_by_name['DictOfAnyProto'] +_DICTOFANYPROTO_DATAENTRY = _DICTOFANYPROTO.nested_types_by_name['DataEntry'] +_LISTOFANYPROTO = DESCRIPTOR.message_types_by_name['ListOfAnyProto'] _DOCUMENTARRAYPROTO = DESCRIPTOR.message_types_by_name['DocumentArrayProto'] _LISTOFDOCUMENTARRAYPROTO = DESCRIPTOR.message_types_by_name['ListOfDocumentArrayProto'] -_LISTOFANYPROTO = DESCRIPTOR.message_types_by_name['ListOfAnyProto'] _DOCUMENTARRAYSTACKEDPROTO = DESCRIPTOR.message_types_by_name[ 'DocumentArrayStackedProto' ] @@ -121,6 +123,38 @@ _sym_db.RegisterMessage(DocumentProto) _sym_db.RegisterMessage(DocumentProto.DataEntry) +DictOfAnyProto = _reflection.GeneratedProtocolMessageType( + 'DictOfAnyProto', + (_message.Message,), + { + 'DataEntry': _reflection.GeneratedProtocolMessageType( + 'DataEntry', + (_message.Message,), + { + 'DESCRIPTOR': _DICTOFANYPROTO_DATAENTRY, + '__module__': 'docarray_pb2' + # @@protoc_insertion_point(class_scope:docarray.DictOfAnyProto.DataEntry) + }, + ), + 'DESCRIPTOR': _DICTOFANYPROTO, + '__module__': 'docarray_pb2' + # @@protoc_insertion_point(class_scope:docarray.DictOfAnyProto) + }, +) +_sym_db.RegisterMessage(DictOfAnyProto) +_sym_db.RegisterMessage(DictOfAnyProto.DataEntry) + +ListOfAnyProto = _reflection.GeneratedProtocolMessageType( + 'ListOfAnyProto', + (_message.Message,), + { + 'DESCRIPTOR': _LISTOFANYPROTO, + '__module__': 'docarray_pb2' + # @@protoc_insertion_point(class_scope:docarray.ListOfAnyProto) + }, +) +_sym_db.RegisterMessage(ListOfAnyProto) + DocumentArrayProto = _reflection.GeneratedProtocolMessageType( 'DocumentArrayProto', (_message.Message,), @@ -143,17 +177,6 @@ ) _sym_db.RegisterMessage(ListOfDocumentArrayProto) -ListOfAnyProto = _reflection.GeneratedProtocolMessageType( - 'ListOfAnyProto', - (_message.Message,), - { - 'DESCRIPTOR': _LISTOFANYPROTO, - '__module__': 'docarray_pb2' - # @@protoc_insertion_point(class_scope:docarray.ListOfAnyProto) - }, -) -_sym_db.RegisterMessage(ListOfAnyProto) - DocumentArrayStackedProto = _reflection.GeneratedProtocolMessageType( 'DocumentArrayStackedProto', (_message.Message,), @@ -210,6 +233,8 @@ DESCRIPTOR._options = None _DOCUMENTPROTO_DATAENTRY._options = None _DOCUMENTPROTO_DATAENTRY._serialized_options = b'8\001' + _DICTOFANYPROTO_DATAENTRY._options = None + _DICTOFANYPROTO_DATAENTRY._serialized_options = b'8\001' _DOCUMENTARRAYSTACKEDPROTO_TENSORCOLUMNSENTRY._options = None _DOCUMENTARRAYSTACKEDPROTO_TENSORCOLUMNSENTRY._serialized_options = b'8\001' _DOCUMENTARRAYSTACKEDPROTO_DOCCOLUMNSENTRY._options = None @@ -227,25 +252,29 @@ _GENERICDICTVALUE._serialized_start = 322 _GENERICDICTVALUE._serialized_end = 381 _NODEPROTO._serialized_start = 384 - _NODEPROTO._serialized_end = 843 - _DOCUMENTPROTO._serialized_start = 846 - _DOCUMENTPROTO._serialized_end = 976 - _DOCUMENTPROTO_DATAENTRY._serialized_start = 912 - _DOCUMENTPROTO_DATAENTRY._serialized_end = 976 - _DOCUMENTARRAYPROTO._serialized_start = 978 - _DOCUMENTARRAYPROTO._serialized_end = 1037 - _LISTOFDOCUMENTARRAYPROTO._serialized_start = 1039 - _LISTOFDOCUMENTARRAYPROTO._serialized_end = 1109 - _LISTOFANYPROTO._serialized_start = 1111 - _LISTOFANYPROTO._serialized_end = 1162 - _DOCUMENTARRAYSTACKEDPROTO._serialized_start = 1165 - _DOCUMENTARRAYSTACKEDPROTO._serialized_end = 1821 - _DOCUMENTARRAYSTACKEDPROTO_TENSORCOLUMNSENTRY._serialized_start = 1494 - _DOCUMENTARRAYSTACKEDPROTO_TENSORCOLUMNSENTRY._serialized_end = 1570 - _DOCUMENTARRAYSTACKEDPROTO_DOCCOLUMNSENTRY._serialized_start = 1572 - _DOCUMENTARRAYSTACKEDPROTO_DOCCOLUMNSENTRY._serialized_end = 1658 - _DOCUMENTARRAYSTACKEDPROTO_DACOLUMNSENTRY._serialized_start = 1660 - _DOCUMENTARRAYSTACKEDPROTO_DACOLUMNSENTRY._serialized_end = 1744 - _DOCUMENTARRAYSTACKEDPROTO_ANYCOLUMNSENTRY._serialized_start = 1746 - _DOCUMENTARRAYSTACKEDPROTO_ANYCOLUMNSENTRY._serialized_end = 1821 + _NODEPROTO._serialized_end = 838 + _DOCUMENTPROTO._serialized_start = 841 + _DOCUMENTPROTO._serialized_end = 971 + _DOCUMENTPROTO_DATAENTRY._serialized_start = 907 + _DOCUMENTPROTO_DATAENTRY._serialized_end = 971 + _DICTOFANYPROTO._serialized_start = 974 + _DICTOFANYPROTO._serialized_end = 1106 + _DICTOFANYPROTO_DATAENTRY._serialized_start = 907 + _DICTOFANYPROTO_DATAENTRY._serialized_end = 971 + _LISTOFANYPROTO._serialized_start = 1108 + _LISTOFANYPROTO._serialized_end = 1159 + _DOCUMENTARRAYPROTO._serialized_start = 1161 + _DOCUMENTARRAYPROTO._serialized_end = 1220 + _LISTOFDOCUMENTARRAYPROTO._serialized_start = 1222 + _LISTOFDOCUMENTARRAYPROTO._serialized_end = 1292 + _DOCUMENTARRAYSTACKEDPROTO._serialized_start = 1295 + _DOCUMENTARRAYSTACKEDPROTO._serialized_end = 1951 + _DOCUMENTARRAYSTACKEDPROTO_TENSORCOLUMNSENTRY._serialized_start = 1624 + _DOCUMENTARRAYSTACKEDPROTO_TENSORCOLUMNSENTRY._serialized_end = 1700 + _DOCUMENTARRAYSTACKEDPROTO_DOCCOLUMNSENTRY._serialized_start = 1702 + _DOCUMENTARRAYSTACKEDPROTO_DOCCOLUMNSENTRY._serialized_end = 1788 + _DOCUMENTARRAYSTACKEDPROTO_DACOLUMNSENTRY._serialized_start = 1790 + _DOCUMENTARRAYSTACKEDPROTO_DACOLUMNSENTRY._serialized_end = 1874 + _DOCUMENTARRAYSTACKEDPROTO_ANYCOLUMNSENTRY._serialized_start = 1876 + _DOCUMENTARRAYSTACKEDPROTO_ANYCOLUMNSENTRY._serialized_end = 1951 # @@protoc_insertion_point(module_scope) diff --git a/tests/units/document/proto/test_document_proto.py b/tests/units/document/proto/test_document_proto.py index d14ece7a418..73df031be07 100644 --- a/tests/units/document/proto/test_document_proto.py +++ b/tests/units/document/proto/test_document_proto.py @@ -7,6 +7,10 @@ from docarray import DocumentArray from docarray.base_document import BaseDocument from docarray.typing import NdArray, TorchTensor +from docarray.utils.misc import is_tf_available + +if is_tf_available(): + import tensorflow as tf @pytest.mark.proto @@ -204,3 +208,91 @@ class MyDoc(BaseDocument): assert doc.tensor.dtype == dtype assert MyDoc.from_protobuf(doc.to_protobuf()).tensor.dtype == dtype assert MyDoc.parse_obj(doc.dict()).tensor.dtype == dtype + + +@pytest.mark.proto +def test_nested_dict(): + class MyDoc(BaseDocument): + data: Dict + + doc = MyDoc(data={'data': (1, 2)}) + + MyDoc.from_protobuf(doc.to_protobuf()) + + +@pytest.mark.proto +def test_tuple_complex(): + class MyDoc(BaseDocument): + data: Tuple + + doc = MyDoc(data=(1, 2)) + + doc2 = MyDoc.from_protobuf(doc.to_protobuf()) + + assert doc2.data == (1, 2) + + +@pytest.mark.proto +def test_list_complex(): + class MyDoc(BaseDocument): + data: List + + doc = MyDoc(data=[(1, 2)]) + + doc2 = MyDoc.from_protobuf(doc.to_protobuf()) + + assert doc2.data == [(1, 2)] + + +@pytest.mark.proto +def test_nested_tensor_list(): + class MyDoc(BaseDocument): + data: List + + doc = MyDoc(data=[np.zeros(10)]) + + doc2 = MyDoc.from_protobuf(doc.to_protobuf()) + + assert isinstance(doc2.data[0], np.ndarray) + assert isinstance(doc2.data[0], NdArray) + + assert (doc2.data[0] == np.zeros(10)).all() + + +@pytest.mark.proto +def test_nested_tensor_dict(): + class MyDoc(BaseDocument): + data: Dict + + doc = MyDoc(data={'hello': np.zeros(10)}) + + doc2 = MyDoc.from_protobuf(doc.to_protobuf()) + + assert isinstance(doc2.data['hello'], np.ndarray) + assert isinstance(doc2.data['hello'], NdArray) + + assert (doc2.data['hello'] == np.zeros(10)).all() + + +@pytest.mark.proto +def test_super_complex_nested(): + class MyDoc(BaseDocument): + data: Dict + + data = {'hello': (torch.zeros(55), 1, 'hi', [torch.ones(55), np.zeros(10), (1, 2)])} + doc = MyDoc(data=data) + + doc2 = MyDoc.from_protobuf(doc.to_protobuf()) + + (doc2.data['hello'][3][0] == torch.ones(55)).all() + + +@pytest.mark.tensorflow +def test_super_complex_nested_tensorflow(): + class MyDoc(BaseDocument): + data: Dict + + data = {'hello': (torch.zeros(55), 1, 'hi', [tf.ones(55), np.zeros(10), (1, 2)])} + doc = MyDoc(data=data) + + MyDoc.from_protobuf(doc.to_protobuf())