diff --git a/docarray/array/doc_vec/doc_vec.py b/docarray/array/doc_vec/doc_vec.py index 995b2fc4ec5..b65a17b15e5 100644 --- a/docarray/array/doc_vec/doc_vec.py +++ b/docarray/array/doc_vec/doc_vec.py @@ -17,6 +17,7 @@ overload, ) +import numpy as np from pydantic import BaseConfig, parse_obj_as from typing_inspect import typingGenericAlias @@ -34,7 +35,12 @@ if TYPE_CHECKING: from pydantic.fields import ModelField - from docarray.proto import DocVecProto + from docarray.proto import ( + DocVecProto, + ListOfDocArrayProto, + ListOfDocVecProto, + NdArrayProto, + ) torch_available = is_torch_available() if torch_available: @@ -54,6 +60,56 @@ T = TypeVar('T', bound='DocVec') IndexIterType = Union[slice, Iterable[int], Iterable[bool], None] +NONE_NDARRAY_PROTO_SHAPE = (0,) +NONE_NDARRAY_PROTO_DTYPE = 'None' + + +def _none_ndarray_proto() -> 'NdArrayProto': + from docarray.proto import NdArrayProto + + zeros_arr = parse_obj_as(NdArray, np.zeros(NONE_NDARRAY_PROTO_SHAPE)) + nd_proto = NdArrayProto() + nd_proto.dense.buffer = zeros_arr.tobytes() + nd_proto.dense.ClearField('shape') + nd_proto.dense.shape.extend(list(zeros_arr.shape)) + nd_proto.dense.dtype = NONE_NDARRAY_PROTO_DTYPE + + return nd_proto + + +def _none_docvec_proto() -> 'DocVecProto': + from docarray.proto import DocVecProto + + return DocVecProto() + + +def _none_list_of_docvec_proto() -> 'ListOfDocArrayProto': + from docarray.proto import ListOfDocVecProto + + return ListOfDocVecProto() + + +def _is_none_ndarray_proto(proto: 'NdArrayProto') -> bool: + return ( + proto.dense.shape == list(NONE_NDARRAY_PROTO_SHAPE) + and proto.dense.dtype == NONE_NDARRAY_PROTO_DTYPE + ) + + +def _is_none_docvec_proto(proto: 'DocVecProto') -> bool: + return ( + proto.tensor_columns == {} + and proto.doc_columns == {} + and proto.docs_vec_columns == {} + and proto.any_columns == {} + ) + + +def _is_none_list_of_docvec_proto(proto: 'ListOfDocVecProto') -> bool: + from docarray.proto import ListOfDocVecProto + + return isinstance(proto, ListOfDocVecProto) and len(proto.data) == 0 + class DocVec(AnyDocArray[T_doc]): """ @@ -243,7 +299,7 @@ def _check_doc_field_not_none(field_name, doc): elif issubclass(field_type, AnyDocArray): if first_doc_is_none: _verify_optional_field_of_docs(docs) - doc_columns[field_name] = None + docs_vec_columns[field_name] = None else: docs_list = list() for doc in docs: @@ -534,12 +590,63 @@ def __len__(self): @classmethod def from_protobuf(cls: Type[T], pb_msg: 'DocVecProto') -> T: - """create a Document from a protobuf message""" + """create a DocVec from a protobuf message""" + + tensor_columns: Dict[str, Optional[AbstractTensor]] = {} + doc_columns: Dict[str, Optional['DocVec']] = {} + docs_vec_columns: Dict[str, Optional[ListAdvancedIndexing['DocVec']]] = {} + any_columns: Dict[str, ListAdvancedIndexing] = {} + + for tens_col_name, tens_col_proto in pb_msg.tensor_columns.items(): + if _is_none_ndarray_proto(tens_col_proto): + # handle values that were None before serialization + tensor_columns[tens_col_name] = None + else: + # TODO(johannes): handle torch, tf, numpy + tensor_columns[tens_col_name] = NdArray.from_protobuf(tens_col_proto) + + for doc_col_name, doc_col_proto in pb_msg.doc_columns.items(): + if _is_none_docvec_proto(doc_col_proto): + # handle values that were None before serialization + doc_columns[doc_col_name] = None + else: + col_doc_type: Type = cls.doc_type._get_field_type(doc_col_name) + doc_columns[doc_col_name] = DocVec.__class_getitem__( + col_doc_type + ).from_protobuf(doc_col_proto) + + for docs_vec_col_name, docs_vec_col_proto in pb_msg.docs_vec_columns.items(): + vec_list: Optional[ListAdvancedIndexing] + if _is_none_list_of_docvec_proto(docs_vec_col_proto): + # handle values that were None before serialization + vec_list = None + else: + vec_list = ListAdvancedIndexing() + for doc_list_proto in docs_vec_col_proto.data: + col_doc_type = cls.doc_type._get_field_type( + docs_vec_col_name + ).doc_type + vec_list.append( + DocVec.__class_getitem__(col_doc_type).from_protobuf( + doc_list_proto + ) + ) + docs_vec_columns[docs_vec_col_name] = vec_list + + for any_col_name, any_col_proto in pb_msg.any_columns.items(): + any_column: ListAdvancedIndexing = ListAdvancedIndexing() + for node_proto in any_col_proto.data: + content = cls.doc_type._get_content_from_node_proto( + node_proto, any_col_name + ) + any_column.append(content) + any_columns[any_col_name] = any_column + storage = ColumnStorage( - pb_msg.tensor_columns, - pb_msg.doc_columns, - pb_msg.docs_vec_columns, - pb_msg.any_columns, + tensor_columns=tensor_columns, + doc_columns=doc_columns, + docs_vec_columns=docs_vec_columns, + any_columns=any_columns, ) return cls.from_columns_storage(storage) @@ -547,35 +654,40 @@ def from_protobuf(cls: Type[T], pb_msg: 'DocVecProto') -> T: def to_protobuf(self) -> 'DocVecProto': """Convert DocVec into a Protobuf message""" from docarray.proto import ( - DocListProto, DocVecProto, ListOfAnyProto, ListOfDocArrayProto, + ListOfDocVecProto, NdArrayProto, ) - da_proto = DocListProto() - for doc in self: - da_proto.docs.append(doc.to_protobuf()) - doc_columns_proto: Dict[str, DocVecProto] = dict() tensor_columns_proto: Dict[str, NdArrayProto] = dict() da_columns_proto: Dict[str, ListOfDocArrayProto] = dict() any_columns_proto: Dict[str, ListOfAnyProto] = dict() for field, col_doc in self._storage.doc_columns.items(): - doc_columns_proto[field] = ( - col_doc.to_protobuf() if col_doc is not None else None - ) + if col_doc is None: + # put dummy empty DocVecProto for serialization + doc_columns_proto[field] = _none_docvec_proto() + else: + doc_columns_proto[field] = col_doc.to_protobuf() for field, col_tens in self._storage.tensor_columns.items(): - tensor_columns_proto[field] = ( - col_tens.to_protobuf() if col_tens is not None else None - ) + if col_tens is None: + # put dummy empty NdArrayProto for serialization + tensor_columns_proto[field] = _none_ndarray_proto() + else: + tensor_columns_proto[field] = ( + col_tens.to_protobuf() if col_tens is not None else None + ) for field, col_da in self._storage.docs_vec_columns.items(): - list_proto = ListOfDocArrayProto() + list_proto = ListOfDocVecProto() if col_da: for docs in col_da: list_proto.data.append(docs.to_protobuf()) + else: + # put dummy empty ListOfDocVecProto for serialization + list_proto = _none_list_of_docvec_proto() da_columns_proto[field] = list_proto for field, col_any in self._storage.any_columns.items(): list_proto = ListOfAnyProto() diff --git a/docarray/proto/__init__.py b/docarray/proto/__init__.py index b1a201b6e2f..b7cff253d8b 100644 --- a/docarray/proto/__init__.py +++ b/docarray/proto/__init__.py @@ -17,6 +17,7 @@ DocVecProto, ListOfAnyProto, ListOfDocArrayProto, + ListOfDocVecProto, NdArrayProto, NodeProto, ) @@ -28,6 +29,7 @@ DocVecProto, ListOfAnyProto, ListOfDocArrayProto, + ListOfDocVecProto, NdArrayProto, NodeProto, ) @@ -40,6 +42,7 @@ 'DocVecProto', 'DocListProto', 'ListOfDocArrayProto', + 'ListOfDocVecProto', 'ListOfAnyProto', 'DictOfAnyProto', ] diff --git a/docarray/proto/docarray.proto b/docarray/proto/docarray.proto index 19a33ccbc22..a73451bac1b 100644 --- a/docarray/proto/docarray.proto +++ b/docarray/proto/docarray.proto @@ -100,9 +100,13 @@ message ListOfDocArrayProto { repeated DocListProto data = 1; } +message ListOfDocVecProto { + repeated DocVecProto data = 1; +} + message DocVecProto{ map tensor_columns = 1; // a dict of document columns map doc_columns = 2; // a dict of tensor columns - map docs_vec_columns = 3; // a dict of document array columns + map docs_vec_columns = 3; // a dict of document array columns map any_columns = 4; // a dict of any columns. Used for the rest of the data } \ No newline at end of file diff --git a/docarray/proto/pb/docarray_pb2.py b/docarray/proto/pb/docarray_pb2.py index 8ff91a9f5e8..0cd5b334a18 100644 --- a/docarray/proto/pb/docarray_pb2.py +++ b/docarray/proto/pb/docarray_pb2.py @@ -14,7 +14,7 @@ from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0e\x64ocarray.proto\x12\x08\x64ocarray\x1a\x1cgoogle/protobuf/struct.proto\"A\n\x11\x44\x65nseNdArrayProto\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\r\x12\r\n\x05\x64type\x18\x03 \x01(\t\"g\n\x0cNdArrayProto\x12*\n\x05\x64\x65nse\x18\x01 \x01(\x0b\x32\x1b.docarray.DenseNdArrayProto\x12+\n\nparameters\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\"Z\n\x0cKeyValuePair\x12#\n\x03key\x18\x01 \x01(\x0b\x32\x16.google.protobuf.Value\x12%\n\x05value\x18\x02 \x01(\x0b\x32\x16.google.protobuf.Value\";\n\x10GenericDictValue\x12\'\n\x07\x65ntries\x18\x01 \x03(\x0b\x32\x16.docarray.KeyValuePair\"\xb1\x03\n\tNodeProto\x12\x0e\n\x04text\x18\x01 \x01(\tH\x00\x12\x11\n\x07integer\x18\x02 \x01(\x05H\x00\x12\x0f\n\x05\x66loat\x18\x03 \x01(\x01H\x00\x12\x11\n\x07\x62oolean\x18\x04 \x01(\x08H\x00\x12\x0e\n\x04\x62lob\x18\x05 \x01(\x0cH\x00\x12)\n\x07ndarray\x18\x06 \x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x12!\n\x03\x64oc\x18\x07 \x01(\x0b\x32\x12.docarray.DocProtoH\x00\x12+\n\tdoc_array\x18\x08 \x01(\x0b\x32\x16.docarray.DocListProtoH\x00\x12(\n\x04list\x18\t \x01(\x0b\x32\x18.docarray.ListOfAnyProtoH\x00\x12\'\n\x03set\x18\n \x01(\x0b\x32\x18.docarray.ListOfAnyProtoH\x00\x12)\n\x05tuple\x18\x0b \x01(\x0b\x32\x18.docarray.ListOfAnyProtoH\x00\x12(\n\x04\x64ict\x18\x0c \x01(\x0b\x32\x18.docarray.DictOfAnyProtoH\x00\x12\x0e\n\x04type\x18\r \x01(\tH\x01\x42\t\n\x07\x63ontentB\x0f\n\rdocarray_type\"x\n\x08\x44ocProto\x12*\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\x1c.docarray.DocProto.DataEntry\x1a@\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.docarray.NodeProto:\x02\x38\x01\"\x84\x01\n\x0e\x44ictOfAnyProto\x12\x30\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\".docarray.DictOfAnyProto.DataEntry\x1a@\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.docarray.NodeProto:\x02\x38\x01\"3\n\x0eListOfAnyProto\x12!\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\x13.docarray.NodeProto\"0\n\x0c\x44ocListProto\x12 \n\x04\x64ocs\x18\x01 \x03(\x0b\x32\x12.docarray.DocProto\";\n\x13ListOfDocArrayProto\x12$\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\x16.docarray.DocListProto\"\xc7\x04\n\x0b\x44ocVecProto\x12@\n\x0etensor_columns\x18\x01 \x03(\x0b\x32(.docarray.DocVecProto.TensorColumnsEntry\x12:\n\x0b\x64oc_columns\x18\x02 \x03(\x0b\x32%.docarray.DocVecProto.DocColumnsEntry\x12\x43\n\x10\x64ocs_vec_columns\x18\x03 \x03(\x0b\x32).docarray.DocVecProto.DocsVecColumnsEntry\x12:\n\x0b\x61ny_columns\x18\x04 \x03(\x0b\x32%.docarray.DocVecProto.AnyColumnsEntry\x1aL\n\x12TensorColumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12%\n\x05value\x18\x02 \x01(\x0b\x32\x16.docarray.NdArrayProto:\x02\x38\x01\x1aH\n\x0f\x44ocColumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.docarray.DocVecProto:\x02\x38\x01\x1aT\n\x13\x44ocsVecColumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12,\n\x05value\x18\x02 \x01(\x0b\x32\x1d.docarray.ListOfDocArrayProto:\x02\x38\x01\x1aK\n\x0f\x41nyColumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\'\n\x05value\x18\x02 \x01(\x0b\x32\x18.docarray.ListOfAnyProto:\x02\x38\x01\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0e\x64ocarray.proto\x12\x08\x64ocarray\x1a\x1cgoogle/protobuf/struct.proto\"A\n\x11\x44\x65nseNdArrayProto\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\r\x12\r\n\x05\x64type\x18\x03 \x01(\t\"g\n\x0cNdArrayProto\x12*\n\x05\x64\x65nse\x18\x01 \x01(\x0b\x32\x1b.docarray.DenseNdArrayProto\x12+\n\nparameters\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\"Z\n\x0cKeyValuePair\x12#\n\x03key\x18\x01 \x01(\x0b\x32\x16.google.protobuf.Value\x12%\n\x05value\x18\x02 \x01(\x0b\x32\x16.google.protobuf.Value\";\n\x10GenericDictValue\x12\'\n\x07\x65ntries\x18\x01 \x03(\x0b\x32\x16.docarray.KeyValuePair\"\xb1\x03\n\tNodeProto\x12\x0e\n\x04text\x18\x01 \x01(\tH\x00\x12\x11\n\x07integer\x18\x02 \x01(\x05H\x00\x12\x0f\n\x05\x66loat\x18\x03 \x01(\x01H\x00\x12\x11\n\x07\x62oolean\x18\x04 \x01(\x08H\x00\x12\x0e\n\x04\x62lob\x18\x05 \x01(\x0cH\x00\x12)\n\x07ndarray\x18\x06 \x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x12!\n\x03\x64oc\x18\x07 \x01(\x0b\x32\x12.docarray.DocProtoH\x00\x12+\n\tdoc_array\x18\x08 \x01(\x0b\x32\x16.docarray.DocListProtoH\x00\x12(\n\x04list\x18\t \x01(\x0b\x32\x18.docarray.ListOfAnyProtoH\x00\x12\'\n\x03set\x18\n \x01(\x0b\x32\x18.docarray.ListOfAnyProtoH\x00\x12)\n\x05tuple\x18\x0b \x01(\x0b\x32\x18.docarray.ListOfAnyProtoH\x00\x12(\n\x04\x64ict\x18\x0c \x01(\x0b\x32\x18.docarray.DictOfAnyProtoH\x00\x12\x0e\n\x04type\x18\r \x01(\tH\x01\x42\t\n\x07\x63ontentB\x0f\n\rdocarray_type\"x\n\x08\x44ocProto\x12*\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\x1c.docarray.DocProto.DataEntry\x1a@\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.docarray.NodeProto:\x02\x38\x01\"\x84\x01\n\x0e\x44ictOfAnyProto\x12\x30\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\".docarray.DictOfAnyProto.DataEntry\x1a@\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.docarray.NodeProto:\x02\x38\x01\"3\n\x0eListOfAnyProto\x12!\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\x13.docarray.NodeProto\"0\n\x0c\x44ocListProto\x12 \n\x04\x64ocs\x18\x01 \x03(\x0b\x32\x12.docarray.DocProto\";\n\x13ListOfDocArrayProto\x12$\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\x16.docarray.DocListProto\"8\n\x11ListOfDocVecProto\x12#\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\x15.docarray.DocVecProto\"\xc5\x04\n\x0b\x44ocVecProto\x12@\n\x0etensor_columns\x18\x01 \x03(\x0b\x32(.docarray.DocVecProto.TensorColumnsEntry\x12:\n\x0b\x64oc_columns\x18\x02 \x03(\x0b\x32%.docarray.DocVecProto.DocColumnsEntry\x12\x43\n\x10\x64ocs_vec_columns\x18\x03 \x03(\x0b\x32).docarray.DocVecProto.DocsVecColumnsEntry\x12:\n\x0b\x61ny_columns\x18\x04 \x03(\x0b\x32%.docarray.DocVecProto.AnyColumnsEntry\x1aL\n\x12TensorColumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12%\n\x05value\x18\x02 \x01(\x0b\x32\x16.docarray.NdArrayProto:\x02\x38\x01\x1aH\n\x0f\x44ocColumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.docarray.DocVecProto:\x02\x38\x01\x1aR\n\x13\x44ocsVecColumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12*\n\x05value\x18\x02 \x01(\x0b\x32\x1b.docarray.ListOfDocVecProto:\x02\x38\x01\x1aK\n\x0f\x41nyColumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\'\n\x05value\x18\x02 \x01(\x0b\x32\x18.docarray.ListOfAnyProto:\x02\x38\x01\x62\x06proto3') _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'docarray_pb2', globals()) @@ -57,14 +57,16 @@ _DOCLISTPROTO._serialized_end=1177 _LISTOFDOCARRAYPROTO._serialized_start=1179 _LISTOFDOCARRAYPROTO._serialized_end=1238 - _DOCVECPROTO._serialized_start=1241 - _DOCVECPROTO._serialized_end=1824 - _DOCVECPROTO_TENSORCOLUMNSENTRY._serialized_start=1511 - _DOCVECPROTO_TENSORCOLUMNSENTRY._serialized_end=1587 - _DOCVECPROTO_DOCCOLUMNSENTRY._serialized_start=1589 - _DOCVECPROTO_DOCCOLUMNSENTRY._serialized_end=1661 - _DOCVECPROTO_DOCSVECCOLUMNSENTRY._serialized_start=1663 - _DOCVECPROTO_DOCSVECCOLUMNSENTRY._serialized_end=1747 - _DOCVECPROTO_ANYCOLUMNSENTRY._serialized_start=1749 - _DOCVECPROTO_ANYCOLUMNSENTRY._serialized_end=1824 + _LISTOFDOCVECPROTO._serialized_start=1240 + _LISTOFDOCVECPROTO._serialized_end=1296 + _DOCVECPROTO._serialized_start=1299 + _DOCVECPROTO._serialized_end=1880 + _DOCVECPROTO_TENSORCOLUMNSENTRY._serialized_start=1569 + _DOCVECPROTO_TENSORCOLUMNSENTRY._serialized_end=1645 + _DOCVECPROTO_DOCCOLUMNSENTRY._serialized_start=1647 + _DOCVECPROTO_DOCCOLUMNSENTRY._serialized_end=1719 + _DOCVECPROTO_DOCSVECCOLUMNSENTRY._serialized_start=1721 + _DOCVECPROTO_DOCSVECCOLUMNSENTRY._serialized_end=1803 + _DOCVECPROTO_ANYCOLUMNSENTRY._serialized_start=1805 + _DOCVECPROTO_ANYCOLUMNSENTRY._serialized_end=1880 # @@protoc_insertion_point(module_scope) diff --git a/docarray/proto/pb2/docarray_pb2.py b/docarray/proto/pb2/docarray_pb2.py index 9fbbbadf342..e178c8c3f9d 100644 --- a/docarray/proto/pb2/docarray_pb2.py +++ b/docarray/proto/pb2/docarray_pb2.py @@ -16,7 +16,7 @@ from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x0e\x64ocarray.proto\x12\x08\x64ocarray\x1a\x1cgoogle/protobuf/struct.proto\"A\n\x11\x44\x65nseNdArrayProto\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\r\x12\r\n\x05\x64type\x18\x03 \x01(\t\"g\n\x0cNdArrayProto\x12*\n\x05\x64\x65nse\x18\x01 \x01(\x0b\x32\x1b.docarray.DenseNdArrayProto\x12+\n\nparameters\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\"Z\n\x0cKeyValuePair\x12#\n\x03key\x18\x01 \x01(\x0b\x32\x16.google.protobuf.Value\x12%\n\x05value\x18\x02 \x01(\x0b\x32\x16.google.protobuf.Value\";\n\x10GenericDictValue\x12\'\n\x07\x65ntries\x18\x01 \x03(\x0b\x32\x16.docarray.KeyValuePair\"\xb1\x03\n\tNodeProto\x12\x0e\n\x04text\x18\x01 \x01(\tH\x00\x12\x11\n\x07integer\x18\x02 \x01(\x05H\x00\x12\x0f\n\x05\x66loat\x18\x03 \x01(\x01H\x00\x12\x11\n\x07\x62oolean\x18\x04 \x01(\x08H\x00\x12\x0e\n\x04\x62lob\x18\x05 \x01(\x0cH\x00\x12)\n\x07ndarray\x18\x06 \x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x12!\n\x03\x64oc\x18\x07 \x01(\x0b\x32\x12.docarray.DocProtoH\x00\x12+\n\tdoc_array\x18\x08 \x01(\x0b\x32\x16.docarray.DocListProtoH\x00\x12(\n\x04list\x18\t \x01(\x0b\x32\x18.docarray.ListOfAnyProtoH\x00\x12\'\n\x03set\x18\n \x01(\x0b\x32\x18.docarray.ListOfAnyProtoH\x00\x12)\n\x05tuple\x18\x0b \x01(\x0b\x32\x18.docarray.ListOfAnyProtoH\x00\x12(\n\x04\x64ict\x18\x0c \x01(\x0b\x32\x18.docarray.DictOfAnyProtoH\x00\x12\x0e\n\x04type\x18\r \x01(\tH\x01\x42\t\n\x07\x63ontentB\x0f\n\rdocarray_type\"x\n\x08\x44ocProto\x12*\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\x1c.docarray.DocProto.DataEntry\x1a@\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.docarray.NodeProto:\x02\x38\x01\"\x84\x01\n\x0e\x44ictOfAnyProto\x12\x30\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\".docarray.DictOfAnyProto.DataEntry\x1a@\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.docarray.NodeProto:\x02\x38\x01\"3\n\x0eListOfAnyProto\x12!\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\x13.docarray.NodeProto\"0\n\x0c\x44ocListProto\x12 \n\x04\x64ocs\x18\x01 \x03(\x0b\x32\x12.docarray.DocProto\";\n\x13ListOfDocArrayProto\x12$\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\x16.docarray.DocListProto\"\xc7\x04\n\x0b\x44ocVecProto\x12@\n\x0etensor_columns\x18\x01 \x03(\x0b\x32(.docarray.DocVecProto.TensorColumnsEntry\x12:\n\x0b\x64oc_columns\x18\x02 \x03(\x0b\x32%.docarray.DocVecProto.DocColumnsEntry\x12\x43\n\x10\x64ocs_vec_columns\x18\x03 \x03(\x0b\x32).docarray.DocVecProto.DocsVecColumnsEntry\x12:\n\x0b\x61ny_columns\x18\x04 \x03(\x0b\x32%.docarray.DocVecProto.AnyColumnsEntry\x1aL\n\x12TensorColumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12%\n\x05value\x18\x02 \x01(\x0b\x32\x16.docarray.NdArrayProto:\x02\x38\x01\x1aH\n\x0f\x44ocColumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.docarray.DocVecProto:\x02\x38\x01\x1aT\n\x13\x44ocsVecColumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12,\n\x05value\x18\x02 \x01(\x0b\x32\x1d.docarray.ListOfDocArrayProto:\x02\x38\x01\x1aK\n\x0f\x41nyColumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\'\n\x05value\x18\x02 \x01(\x0b\x32\x18.docarray.ListOfAnyProto:\x02\x38\x01\x62\x06proto3' + b'\n\x0e\x64ocarray.proto\x12\x08\x64ocarray\x1a\x1cgoogle/protobuf/struct.proto\"A\n\x11\x44\x65nseNdArrayProto\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\r\x12\r\n\x05\x64type\x18\x03 \x01(\t\"g\n\x0cNdArrayProto\x12*\n\x05\x64\x65nse\x18\x01 \x01(\x0b\x32\x1b.docarray.DenseNdArrayProto\x12+\n\nparameters\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\"Z\n\x0cKeyValuePair\x12#\n\x03key\x18\x01 \x01(\x0b\x32\x16.google.protobuf.Value\x12%\n\x05value\x18\x02 \x01(\x0b\x32\x16.google.protobuf.Value\";\n\x10GenericDictValue\x12\'\n\x07\x65ntries\x18\x01 \x03(\x0b\x32\x16.docarray.KeyValuePair\"\xb1\x03\n\tNodeProto\x12\x0e\n\x04text\x18\x01 \x01(\tH\x00\x12\x11\n\x07integer\x18\x02 \x01(\x05H\x00\x12\x0f\n\x05\x66loat\x18\x03 \x01(\x01H\x00\x12\x11\n\x07\x62oolean\x18\x04 \x01(\x08H\x00\x12\x0e\n\x04\x62lob\x18\x05 \x01(\x0cH\x00\x12)\n\x07ndarray\x18\x06 \x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x12!\n\x03\x64oc\x18\x07 \x01(\x0b\x32\x12.docarray.DocProtoH\x00\x12+\n\tdoc_array\x18\x08 \x01(\x0b\x32\x16.docarray.DocListProtoH\x00\x12(\n\x04list\x18\t \x01(\x0b\x32\x18.docarray.ListOfAnyProtoH\x00\x12\'\n\x03set\x18\n \x01(\x0b\x32\x18.docarray.ListOfAnyProtoH\x00\x12)\n\x05tuple\x18\x0b \x01(\x0b\x32\x18.docarray.ListOfAnyProtoH\x00\x12(\n\x04\x64ict\x18\x0c \x01(\x0b\x32\x18.docarray.DictOfAnyProtoH\x00\x12\x0e\n\x04type\x18\r \x01(\tH\x01\x42\t\n\x07\x63ontentB\x0f\n\rdocarray_type\"x\n\x08\x44ocProto\x12*\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\x1c.docarray.DocProto.DataEntry\x1a@\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.docarray.NodeProto:\x02\x38\x01\"\x84\x01\n\x0e\x44ictOfAnyProto\x12\x30\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\".docarray.DictOfAnyProto.DataEntry\x1a@\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.docarray.NodeProto:\x02\x38\x01\"3\n\x0eListOfAnyProto\x12!\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\x13.docarray.NodeProto\"0\n\x0c\x44ocListProto\x12 \n\x04\x64ocs\x18\x01 \x03(\x0b\x32\x12.docarray.DocProto\";\n\x13ListOfDocArrayProto\x12$\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\x16.docarray.DocListProto\"8\n\x11ListOfDocVecProto\x12#\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\x15.docarray.DocVecProto\"\xc5\x04\n\x0b\x44ocVecProto\x12@\n\x0etensor_columns\x18\x01 \x03(\x0b\x32(.docarray.DocVecProto.TensorColumnsEntry\x12:\n\x0b\x64oc_columns\x18\x02 \x03(\x0b\x32%.docarray.DocVecProto.DocColumnsEntry\x12\x43\n\x10\x64ocs_vec_columns\x18\x03 \x03(\x0b\x32).docarray.DocVecProto.DocsVecColumnsEntry\x12:\n\x0b\x61ny_columns\x18\x04 \x03(\x0b\x32%.docarray.DocVecProto.AnyColumnsEntry\x1aL\n\x12TensorColumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12%\n\x05value\x18\x02 \x01(\x0b\x32\x16.docarray.NdArrayProto:\x02\x38\x01\x1aH\n\x0f\x44ocColumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.docarray.DocVecProto:\x02\x38\x01\x1aR\n\x13\x44ocsVecColumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12*\n\x05value\x18\x02 \x01(\x0b\x32\x1b.docarray.ListOfDocVecProto:\x02\x38\x01\x1aK\n\x0f\x41nyColumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\'\n\x05value\x18\x02 \x01(\x0b\x32\x18.docarray.ListOfAnyProto:\x02\x38\x01\x62\x06proto3' ) @@ -32,6 +32,7 @@ _LISTOFANYPROTO = DESCRIPTOR.message_types_by_name['ListOfAnyProto'] _DOCLISTPROTO = DESCRIPTOR.message_types_by_name['DocListProto'] _LISTOFDOCARRAYPROTO = DESCRIPTOR.message_types_by_name['ListOfDocArrayProto'] +_LISTOFDOCVECPROTO = DESCRIPTOR.message_types_by_name['ListOfDocVecProto'] _DOCVECPROTO = DESCRIPTOR.message_types_by_name['DocVecProto'] _DOCVECPROTO_TENSORCOLUMNSENTRY = _DOCVECPROTO.nested_types_by_name[ 'TensorColumnsEntry' @@ -171,6 +172,17 @@ ) _sym_db.RegisterMessage(ListOfDocArrayProto) +ListOfDocVecProto = _reflection.GeneratedProtocolMessageType( + 'ListOfDocVecProto', + (_message.Message,), + { + 'DESCRIPTOR': _LISTOFDOCVECPROTO, + '__module__': 'docarray_pb2' + # @@protoc_insertion_point(class_scope:docarray.ListOfDocVecProto) + }, +) +_sym_db.RegisterMessage(ListOfDocVecProto) + DocVecProto = _reflection.GeneratedProtocolMessageType( 'DocVecProto', (_message.Message,), @@ -261,14 +273,16 @@ _DOCLISTPROTO._serialized_end = 1177 _LISTOFDOCARRAYPROTO._serialized_start = 1179 _LISTOFDOCARRAYPROTO._serialized_end = 1238 - _DOCVECPROTO._serialized_start = 1241 - _DOCVECPROTO._serialized_end = 1824 - _DOCVECPROTO_TENSORCOLUMNSENTRY._serialized_start = 1511 - _DOCVECPROTO_TENSORCOLUMNSENTRY._serialized_end = 1587 - _DOCVECPROTO_DOCCOLUMNSENTRY._serialized_start = 1589 - _DOCVECPROTO_DOCCOLUMNSENTRY._serialized_end = 1661 - _DOCVECPROTO_DOCSVECCOLUMNSENTRY._serialized_start = 1663 - _DOCVECPROTO_DOCSVECCOLUMNSENTRY._serialized_end = 1747 - _DOCVECPROTO_ANYCOLUMNSENTRY._serialized_start = 1749 - _DOCVECPROTO_ANYCOLUMNSENTRY._serialized_end = 1824 + _LISTOFDOCVECPROTO._serialized_start = 1240 + _LISTOFDOCVECPROTO._serialized_end = 1296 + _DOCVECPROTO._serialized_start = 1299 + _DOCVECPROTO._serialized_end = 1880 + _DOCVECPROTO_TENSORCOLUMNSENTRY._serialized_start = 1569 + _DOCVECPROTO_TENSORCOLUMNSENTRY._serialized_end = 1645 + _DOCVECPROTO_DOCCOLUMNSENTRY._serialized_start = 1647 + _DOCVECPROTO_DOCCOLUMNSENTRY._serialized_end = 1719 + _DOCVECPROTO_DOCSVECCOLUMNSENTRY._serialized_start = 1721 + _DOCVECPROTO_DOCSVECCOLUMNSENTRY._serialized_end = 1803 + _DOCVECPROTO_ANYCOLUMNSENTRY._serialized_start = 1805 + _DOCVECPROTO_ANYCOLUMNSENTRY._serialized_end = 1880 # @@protoc_insertion_point(module_scope) diff --git a/tests/units/array/stack/test_proto.py b/tests/units/array/stack/test_proto.py index 0cda39db730..15be0d496d5 100644 --- a/tests/units/array/stack/test_proto.py +++ b/tests/units/array/stack/test_proto.py @@ -1,3 +1,5 @@ +from typing import Dict, Optional, Union + import numpy as np import pytest import torch @@ -43,6 +45,186 @@ class CustomDocument(BaseDoc): [CustomDocument(image=np.zeros((3, 224, 224))) for _ in range(10)] ).to_doc_vec() - da2 = DocVec.from_protobuf(da.to_protobuf()) + da2 = DocVec[CustomDocument].from_protobuf(da.to_protobuf()) assert isinstance(da2, DocVec) + assert da.doc_type == da2.doc_type + assert (da2.image == da.image).all() + + +@pytest.mark.proto +def test_proto_none_tensor_column(): + class MyOtherDoc(BaseDoc): + embedding: Union[NdArray, None] + other_embedding: NdArray + third_embedding: Union[NdArray, None] + + da = DocVec[MyOtherDoc]( + [ + MyOtherDoc( + other_embedding=np.random.random(512), + ), + MyOtherDoc(other_embedding=np.random.random(512)), + ] + ) + assert da._storage.tensor_columns['embedding'] is None + assert da._storage.tensor_columns['other_embedding'] is not None + assert da._storage.tensor_columns['third_embedding'] is None + + proto = da.to_protobuf() + da_after = DocVec[MyOtherDoc].from_protobuf(proto) + + assert da_after._storage.tensor_columns['embedding'] is None + assert da_after._storage.tensor_columns['other_embedding'] is not None + assert ( + da_after._storage.tensor_columns['other_embedding'] + == da._storage.tensor_columns['other_embedding'] + ).all() + assert da_after._storage.tensor_columns['third_embedding'] is None + + +@pytest.mark.proto +def test_proto_none_doc_column(): + class InnerDoc(BaseDoc): + embedding: NdArray + + class MyDoc(BaseDoc): + inner: Union[InnerDoc, None] + other_inner: Union[InnerDoc, None] + + da = DocVec[MyDoc]( + [ + MyDoc(other_inner=InnerDoc(embedding=np.random.random(512))), + MyDoc(other_inner=InnerDoc(embedding=np.random.random(512))), + ] + ) + assert da._storage.doc_columns['inner'] is None + assert len(da._storage.doc_columns['other_inner']) == 2 + + proto = da.to_protobuf() + da_after = DocVec[MyDoc].from_protobuf(proto) + + assert da_after._storage.doc_columns['inner'] is None + assert len(da._storage.doc_columns['other_inner']) == 2 + assert (da.other_inner.embedding == da_after.other_inner.embedding).all() + + +@pytest.mark.proto +def test_proto_none_docvec_column(): + class InnerDoc(BaseDoc): + embedding: NdArray + + class MyDoc(BaseDoc): + inner_l: Union[DocList[InnerDoc], None] + inner_v: Union[DocVec[InnerDoc], None] + inner_exists_v: Union[DocVec[InnerDoc], None] + inner_exists_l: Union[DocList[InnerDoc], None] + + def _make_inner_list(): + return DocList[InnerDoc]( + [ + InnerDoc(embedding=np.random.random(512)), + InnerDoc(embedding=np.random.random(512)), + ] + ) + + da = DocVec[MyDoc]( + [ + MyDoc( + inner_exists_l=_make_inner_list(), + inner_exists_v=_make_inner_list().to_doc_vec(), + ), + MyDoc( + inner_exists_l=_make_inner_list(), + inner_exists_v=_make_inner_list().to_doc_vec(), + ), + ] + ) + assert da._storage.docs_vec_columns['inner_l'] is None + assert da._storage.docs_vec_columns['inner_v'] is None + assert len(da._storage.docs_vec_columns['inner_exists_l']) == 2 + assert len(da._storage.docs_vec_columns['inner_exists_v']) == 2 + assert da.inner_exists_l[0].embedding.shape == (2, 512) + assert da.inner_exists_l[1].embedding.shape == (2, 512) + assert da.inner_exists_v[0].embedding.shape == (2, 512) + assert da.inner_exists_v[1].embedding.shape == (2, 512) + + proto = da.to_protobuf() + da_after = DocVec[MyDoc].from_protobuf(proto) + + assert da_after._storage.docs_vec_columns['inner_l'] is None + assert da_after._storage.docs_vec_columns['inner_v'] is None + assert len(da._storage.docs_vec_columns['inner_exists_l']) == 2 + assert len(da._storage.docs_vec_columns['inner_exists_v']) == 2 + assert ( + da.inner_exists_l[0].embedding == da_after.inner_exists_l[0].embedding + ).all() + assert ( + da.inner_exists_l[1].embedding == da_after.inner_exists_l[1].embedding + ).all() + assert ( + da.inner_exists_v[0].embedding == da_after.inner_exists_v[0].embedding + ).all() + assert ( + da.inner_exists_v[1].embedding == da_after.inner_exists_v[1].embedding + ).all() + + +@pytest.mark.proto +def test_proto_any_column(): + class MyDoc(BaseDoc): + embedding: NdArray + text: str + d: Dict + + da = DocVec[MyDoc]( + [ + MyDoc( + embedding=np.random.random(512), + text='hi', + d={'a': 1}, + ), + MyDoc(embedding=np.random.random(512), text='there', d={'b': 2}), + ] + ) + assert da._storage.tensor_columns['embedding'].shape == (2, 512) + assert da._storage.any_columns['text'] == ['hi', 'there'] + assert da._storage.any_columns['d'] == [{'a': 1}, {'b': 2}] + + proto = da.to_protobuf() + da_after = DocVec[MyDoc].from_protobuf(proto) + + assert da_after.doc_type == da.doc_type + assert da._storage.tensor_columns['embedding'].shape == (2, 512) + assert ( + da_after._storage.tensor_columns['embedding'] + == da._storage.tensor_columns['embedding'] + ).all() + assert da._storage.any_columns['text'] == ['hi', 'there'] + assert da._storage.any_columns['d'] == [{'a': 1}, {'b': 2}] + + assert (da_after.embedding == da.embedding).all() + assert da_after.text == da.text + assert da_after.d == da.d + + +@pytest.mark.proto +def test_proto_none_any_column(): + class MyDoc(BaseDoc): + text: Optional[str] + d: Optional[Dict] + + da = DocVec[MyDoc]( + [ + MyDoc(), + MyDoc(), + ] + ) + assert da._storage.any_columns['text'] == [None, None] + assert da._storage.any_columns['d'] == [None, None] + + proto = da.to_protobuf() + da_after = DocVec[MyDoc].from_protobuf(proto) + + assert da_after._storage.any_columns['text'] == [None, None] + assert da_after._storage.any_columns['d'] == [None, None]