diff --git a/docarray/array/abstract_array.py b/docarray/array/abstract_array.py index 150e1ef89d8..4482a553989 100644 --- a/docarray/array/abstract_array.py +++ b/docarray/array/abstract_array.py @@ -92,7 +92,7 @@ def _to_node_protobuf(self) -> 'NodeProto': """ from docarray.proto import NodeProto - return NodeProto(chunks=self.to_protobuf()) + return NodeProto(document_array=self.to_protobuf()) @abstractmethod def traverse_flat( diff --git a/docarray/base_document/mixins/proto.py b/docarray/base_document/mixins/proto.py index 57c38be4141..cb71d964c1f 100644 --- a/docarray/base_document/mixins/proto.py +++ b/docarray/base_document/mixins/proto.py @@ -2,6 +2,7 @@ from docarray.base_document.abstract_document import AbstractDocument from docarray.base_document.base_node import BaseNode +from docarray.typing.proto_register import _PROTO_TYPE_NAME_TO_CLASS if TYPE_CHECKING: from docarray.proto import DocumentProto, NodeProto @@ -12,8 +13,6 @@ except ImportError: torch_imported = False else: - from docarray.typing.tensor.torch_tensor import TorchTensor - torch_imported = True @@ -24,61 +23,51 @@ class ProtoMixin(AbstractDocument, BaseNode): @classmethod def from_protobuf(cls: Type[T], pb_msg: 'DocumentProto') -> T: """create a Document from a protobuf message""" - from docarray.typing import ( # TorchTensor, - ID, - AnyEmbedding, - AnyUrl, - ImageUrl, - Mesh3DUrl, - NdArray, - PointCloud3DUrl, - TextUrl, - ) fields: Dict[str, Any] = {} for field in pb_msg.data: value = pb_msg.data[field] - content_type = value.WhichOneof('content') - - # this if else statement need to be refactored it is too long - # the check should be delegated to the type level - content_type_dict = dict( - ndarray=NdArray, - embedding=AnyEmbedding, - any_url=AnyUrl, - text_url=TextUrl, - image_url=ImageUrl, - mesh_url=Mesh3DUrl, - point_cloud_url=PointCloud3DUrl, - id=ID, + content_type_dict = _PROTO_TYPE_NAME_TO_CLASS + + content_key = value.WhichOneof('content') + content_type = ( + value.type if 
value.WhichOneof('docarray_type') is not None else None ) if torch_imported: - content_type_dict['torch_tensor'] = TorchTensor + from docarray.typing import TorchEmbedding + from docarray.typing.tensor.torch_tensor import TorchTensor + + content_type_dict['torch'] = TorchTensor + content_type_dict['torch_embedding'] = TorchEmbedding if content_type in content_type_dict: fields[field] = content_type_dict[content_type].from_protobuf( - getattr(value, content_type) + getattr(value, content_key) ) - elif content_type == 'text': - fields[field] = value.text - elif content_type == 'nested': + elif content_key == 'document': fields[field] = cls._get_field_type(field).from_protobuf( - value.nested + value.document ) # we get to the parent class - elif content_type == 'chunks': + elif content_key == 'document_array': from docarray import DocumentArray fields[field] = DocumentArray.from_protobuf( - value.chunks + value.document_array ) # we get to the parent class - elif content_type is None: + elif content_key is None: fields[field] = None + elif content_type is None: + if content_key == 'text': + fields[field] = value.text + elif content_key == 'blob': + fields[field] = value.blob else: raise ValueError( - f'type {content_type} is not supported for deserialization' + f'type {content_type}, with key {content_key} is not supported for' + f' deserialization' ) return cls.construct(**fields) @@ -133,4 +122,4 @@ def _to_node_protobuf(self) -> 'NodeProto': :return: the nested item protobuf message """ - return NodeProto(nested=self.to_protobuf()) + return NodeProto(document=self.to_protobuf()) diff --git a/docarray/proto/docarray.proto b/docarray/proto/docarray.proto index 39f8354b223..64770bd6794 100644 --- a/docarray/proto/docarray.proto +++ b/docarray/proto/docarray.proto @@ -30,57 +30,24 @@ message NdArrayProto { // message NodeProto { - - oneof content { bytes blob = 1; - // the ndarray of the image/audio/video document NdArrayProto ndarray = 2; - // a text string text = 3; 
- // a sub Document - DocumentProto nested = 4; - + DocumentProto document = 4; // a sub DocumentArray - DocumentArrayProto chunks = 5; - - NdArrayProto embedding = 6; - - string any_url = 7; - - string image_url = 8; - - string text_url = 9; - - string id = 10; - - NdArrayProto torch_tensor = 11; - - string mesh_url = 12; - - string point_cloud_url = 13; - - string audio_url = 14; - - NdArrayProto audio_ndarray = 15; - - NdArrayProto audio_torch_tensor = 16; - - string video_url = 17; - - NdArrayProto video_ndarray = 18; - - NdArrayProto video_torch_tensor = 19; + DocumentArrayProto document_array = 5; + } + oneof docarray_type { + string type = 6; } } - - /** * Represents a Document */ diff --git a/docarray/proto/pb2/docarray_pb2.py b/docarray/proto/pb2/docarray_pb2.py index da5d3df5a46..d0729658ff7 100644 --- a/docarray/proto/pb2/docarray_pb2.py +++ b/docarray/proto/pb2/docarray_pb2.py @@ -15,7 +15,7 @@ from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x0e\x64ocarray.proto\x12\x08\x64ocarray\x1a\x1cgoogle/protobuf/struct.proto\"A\n\x11\x44\x65nseNdArrayProto\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\r\x12\r\n\x05\x64type\x18\x03 \x01(\t\"g\n\x0cNdArrayProto\x12*\n\x05\x64\x65nse\x18\x01 \x01(\x0b\x32\x1b.docarray.DenseNdArrayProto\x12+\n\nparameters\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\"\x8a\x05\n\tNodeProto\x12\x0e\n\x04\x62lob\x18\x01 \x01(\x0cH\x00\x12)\n\x07ndarray\x18\x02 \x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x12\x0e\n\x04text\x18\x03 \x01(\tH\x00\x12)\n\x06nested\x18\x04 \x01(\x0b\x32\x17.docarray.DocumentProtoH\x00\x12.\n\x06\x63hunks\x18\x05 \x01(\x0b\x32\x1c.docarray.DocumentArrayProtoH\x00\x12+\n\tembedding\x18\x06 \x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x12\x11\n\x07\x61ny_url\x18\x07 \x01(\tH\x00\x12\x13\n\timage_url\x18\x08 \x01(\tH\x00\x12\x12\n\x08text_url\x18\t \x01(\tH\x00\x12\x0c\n\x02id\x18\n 
\x01(\tH\x00\x12.\n\x0ctorch_tensor\x18\x0b \x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x12\x12\n\x08mesh_url\x18\x0c \x01(\tH\x00\x12\x19\n\x0fpoint_cloud_url\x18\r \x01(\tH\x00\x12\x13\n\taudio_url\x18\x0e \x01(\tH\x00\x12/\n\raudio_ndarray\x18\x0f \x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x12\x34\n\x12\x61udio_torch_tensor\x18\x10 \x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x12\x13\n\tvideo_url\x18\x11 \x01(\tH\x00\x12/\n\rvideo_ndarray\x18\x12 \x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x12\x34\n\x12video_torch_tensor\x18\x13 \x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x42\t\n\x07\x63ontent\"\x82\x01\n\rDocumentProto\x12/\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32!.docarray.DocumentProto.DataEntry\x1a@\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.docarray.NodeProto:\x02\x38\x01\";\n\x12\x44ocumentArrayProto\x12%\n\x04\x64ocs\x18\x01 \x03(\x0b\x32\x17.docarray.DocumentProto\"\x86\x01\n\x0fUnionArrayProto\x12=\n\x0e\x64ocument_array\x18\x01 \x01(\x0b\x32#.docarray.DocumentArrayStackedProtoH\x00\x12)\n\x07ndarray\x18\x02 \x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x42\t\n\x07\x63ontent\"\xd6\x01\n\x19\x44ocumentArrayStackedProto\x12+\n\x05list_\x18\x01 \x01(\x0b\x32\x1c.docarray.DocumentArrayProto\x12\x41\n\x07\x63olumns\x18\x02 \x03(\x0b\x32\x30.docarray.DocumentArrayStackedProto.ColumnsEntry\x1aI\n\x0c\x43olumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x19.docarray.UnionArrayProto:\x02\x38\x01\x62\x06proto3' + b'\n\x0e\x64ocarray.proto\x12\x08\x64ocarray\x1a\x1cgoogle/protobuf/struct.proto\"A\n\x11\x44\x65nseNdArrayProto\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\r\x12\r\n\x05\x64type\x18\x03 \x01(\t\"g\n\x0cNdArrayProto\x12*\n\x05\x64\x65nse\x18\x01 \x01(\x0b\x32\x1b.docarray.DenseNdArrayProto\x12+\n\nparameters\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\"\xe7\x01\n\tNodeProto\x12\x0e\n\x04\x62lob\x18\x01 \x01(\x0cH\x00\x12)\n\x07ndarray\x18\x02 
\x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x12\x0e\n\x04text\x18\x03 \x01(\tH\x00\x12+\n\x08\x64ocument\x18\x04 \x01(\x0b\x32\x17.docarray.DocumentProtoH\x00\x12\x36\n\x0e\x64ocument_array\x18\x05 \x01(\x0b\x32\x1c.docarray.DocumentArrayProtoH\x00\x12\x0e\n\x04type\x18\x06 \x01(\tH\x01\x42\t\n\x07\x63ontentB\x0f\n\rdocarray_type\"\x82\x01\n\rDocumentProto\x12/\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32!.docarray.DocumentProto.DataEntry\x1a@\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.docarray.NodeProto:\x02\x38\x01\";\n\x12\x44ocumentArrayProto\x12%\n\x04\x64ocs\x18\x01 \x03(\x0b\x32\x17.docarray.DocumentProto\"\x86\x01\n\x0fUnionArrayProto\x12=\n\x0e\x64ocument_array\x18\x01 \x01(\x0b\x32#.docarray.DocumentArrayStackedProtoH\x00\x12)\n\x07ndarray\x18\x02 \x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x42\t\n\x07\x63ontent\"\xd6\x01\n\x19\x44ocumentArrayStackedProto\x12+\n\x05list_\x18\x01 \x01(\x0b\x32\x1c.docarray.DocumentArrayProto\x12\x41\n\x07\x63olumns\x18\x02 \x03(\x0b\x32\x30.docarray.DocumentArrayStackedProto.ColumnsEntry\x1aI\n\x0c\x43olumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x19.docarray.UnionArrayProto:\x02\x38\x01\x62\x06proto3' ) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) @@ -32,17 +32,17 @@ _NDARRAYPROTO._serialized_start = 125 _NDARRAYPROTO._serialized_end = 228 _NODEPROTO._serialized_start = 231 - _NODEPROTO._serialized_end = 881 - _DOCUMENTPROTO._serialized_start = 884 - _DOCUMENTPROTO._serialized_end = 1014 - _DOCUMENTPROTO_DATAENTRY._serialized_start = 950 - _DOCUMENTPROTO_DATAENTRY._serialized_end = 1014 - _DOCUMENTARRAYPROTO._serialized_start = 1016 - _DOCUMENTARRAYPROTO._serialized_end = 1075 - _UNIONARRAYPROTO._serialized_start = 1078 - _UNIONARRAYPROTO._serialized_end = 1212 - _DOCUMENTARRAYSTACKEDPROTO._serialized_start = 1215 - _DOCUMENTARRAYSTACKEDPROTO._serialized_end = 1429 - _DOCUMENTARRAYSTACKEDPROTO_COLUMNSENTRY._serialized_start = 
1356 - _DOCUMENTARRAYSTACKEDPROTO_COLUMNSENTRY._serialized_end = 1429 + _NODEPROTO._serialized_end = 462 + _DOCUMENTPROTO._serialized_start = 465 + _DOCUMENTPROTO._serialized_end = 595 + _DOCUMENTPROTO_DATAENTRY._serialized_start = 531 + _DOCUMENTPROTO_DATAENTRY._serialized_end = 595 + _DOCUMENTARRAYPROTO._serialized_start = 597 + _DOCUMENTARRAYPROTO._serialized_end = 656 + _UNIONARRAYPROTO._serialized_start = 659 + _UNIONARRAYPROTO._serialized_end = 793 + _DOCUMENTARRAYSTACKEDPROTO._serialized_start = 796 + _DOCUMENTARRAYSTACKEDPROTO._serialized_end = 1010 + _DOCUMENTARRAYSTACKEDPROTO_COLUMNSENTRY._serialized_start = 937 + _DOCUMENTARRAYSTACKEDPROTO_COLUMNSENTRY._serialized_end = 1010 # @@protoc_insertion_point(module_scope) diff --git a/docarray/typing/__init__.py b/docarray/typing/__init__.py index 864ad7b2f67..9d23f0b7730 100644 --- a/docarray/typing/__init__.py +++ b/docarray/typing/__init__.py @@ -16,6 +16,7 @@ __all__ = [ 'NdArray', + 'NdArrayEmbedding', 'AudioNdArray', 'VideoNdArray', 'AnyEmbedding', diff --git a/docarray/typing/abstract_type.py b/docarray/typing/abstract_type.py index c03d4336f6c..97bf9a2c1a2 100644 --- a/docarray/typing/abstract_type.py +++ b/docarray/typing/abstract_type.py @@ -13,6 +13,8 @@ class AbstractType(BaseNode): + _proto_type_name: str + @classmethod def __get_validators__(cls): yield cls.validate diff --git a/docarray/typing/id.py b/docarray/typing/id.py index bf39d771746..02ee39245b6 100644 --- a/docarray/typing/id.py +++ b/docarray/typing/id.py @@ -4,6 +4,8 @@ from pydantic import BaseConfig, parse_obj_as from pydantic.fields import ModelField +from docarray.typing.proto_register import _register_proto + if TYPE_CHECKING: from docarray.proto import NodeProto @@ -12,6 +14,7 @@ T = TypeVar('T', bound='ID') +@_register_proto(proto_type_name='id') class ID(str, AbstractType): """ Represent an unique ID @@ -44,7 +47,7 @@ def _to_node_protobuf(self) -> 'NodeProto': """ from docarray.proto import NodeProto - return 
NodeProto(id=self) + return NodeProto(text=self, type=self._proto_type_name) @classmethod def from_protobuf(cls: Type[T], pb_msg: 'str') -> T: diff --git a/docarray/typing/proto_register.py b/docarray/typing/proto_register.py new file mode 100644 index 00000000000..4a1fe77dad9 --- /dev/null +++ b/docarray/typing/proto_register.py @@ -0,0 +1,43 @@ +from typing import Callable, Dict, Type + +from docarray.typing.abstract_type import AbstractType + +_PROTO_TYPE_NAME_TO_CLASS: Dict[str, Type[AbstractType]] = {} + + +def _register_proto( + proto_type_name: str, +) -> Callable[[Type[AbstractType]], Type[AbstractType]]: + """Register a new type to be used in the protobuf serialization. + + This will add the type key to the global registry of type keys used in the proto + serialization and deserialization. This is for internal usage only. + + EXAMPLE USAGE + + .. code-block:: python + + from docarray.typing.proto_register import _register_proto + from docarray.typing.abstract_type import AbstractType + + + @_register_proto(proto_type_name='my_type') + class MyType(AbstractType): + ...
+ + :param proto_type_name: the key under which the decorated class is registered + :return: a decorator that registers the class it is applied to and returns it + """ + + if proto_type_name in _PROTO_TYPE_NAME_TO_CLASS.keys(): + raise ValueError( + f'the key {proto_type_name} is already registered in the global registry' + ) + + def _register(cls: Type['AbstractType']) -> Type['AbstractType']: + cls._proto_type_name = proto_type_name + + _PROTO_TYPE_NAME_TO_CLASS[proto_type_name] = cls + return cls + + return _register diff --git a/docarray/typing/tensor/abstract_tensor.py b/docarray/typing/tensor/abstract_tensor.py index d1342bd1b1c..b179665f9ff 100644 --- a/docarray/typing/tensor/abstract_tensor.py +++ b/docarray/typing/tensor/abstract_tensor.py @@ -21,7 +21,7 @@ from pydantic import BaseConfig from pydantic.fields import ModelField - from docarray.proto import NdArrayProto + from docarray.proto import NdArrayProto, NodeProto T = TypeVar('T', bound='AbstractTensor') TTensor = TypeVar('TTensor') @@ -70,7 +70,19 @@ def __instancecheck__(cls, instance): class AbstractTensor(Generic[TTensor, T], AbstractType, ABC): __parametrized_meta__: type = _ParametrizedMeta - _PROTO_FIELD_NAME: str + _proto_type_name: str + + def _to_node_protobuf(self: T) -> 'NodeProto': + """Convert itself into a NodeProto protobuf message.
This function should + be called when the Document is nested into another Document that need to be + converted into a protobuf + :param field: field in which to store the content in the node proto + :return: the nested item protobuf message + """ + from docarray.proto import NodeProto + + nd_proto = self.to_protobuf() + return NodeProto(ndarray=nd_proto, type=self._proto_type_name) @classmethod def __docarray_validate_shape__(cls, t: T, shape: Tuple[Union[int, str]]) -> T: diff --git a/docarray/typing/tensor/audio/audio_ndarray.py b/docarray/typing/tensor/audio/audio_ndarray.py index 1d619f0cdf8..807cbd55fbb 100644 --- a/docarray/typing/tensor/audio/audio_ndarray.py +++ b/docarray/typing/tensor/audio/audio_ndarray.py @@ -1,5 +1,6 @@ from typing import TypeVar +from docarray.typing.proto_register import _register_proto from docarray.typing.tensor.audio.abstract_audio_tensor import AbstractAudioTensor from docarray.typing.tensor.ndarray import NdArray @@ -8,6 +9,7 @@ T = TypeVar('T', bound='AudioNdArray') +@_register_proto(proto_type_name='audio_ndarray') class AudioNdArray(AbstractAudioTensor, NdArray): """ Subclass of NdArray, to represent an audio tensor. @@ -52,8 +54,6 @@ class MyAudioDoc(Document): """ - _PROTO_FIELD_NAME = 'audio_ndarray' - def to_audio_bytes(self): tensor = (self * MAX_INT_16).astype(' np.ndarray: """ return self.view(np.ndarray) - def _to_node_protobuf(self: T) -> 'NodeProto': - """Convert itself into a NodeProto protobuf message. 
This function should - be called when the Document is nested into another Document that need to be - converted into a protobuf - :param field: field in which to store the content in the node proto - :return: the nested item protobuf message - """ - from docarray.proto import NodeProto - - nd_proto = self.to_protobuf() - return NodeProto(**{self._PROTO_FIELD_NAME: nd_proto}) - @classmethod def from_protobuf(cls: Type[T], pb_msg: 'NdArrayProto') -> 'T': """ diff --git a/docarray/typing/tensor/torch_tensor.py b/docarray/typing/tensor/torch_tensor.py index 808fb7a23bd..cdbd75f2a71 100644 --- a/docarray/typing/tensor/torch_tensor.py +++ b/docarray/typing/tensor/torch_tensor.py @@ -4,13 +4,14 @@ import numpy as np import torch # type: ignore +from docarray.typing.proto_register import _register_proto from docarray.typing.tensor.abstract_tensor import AbstractTensor if TYPE_CHECKING: from pydantic.fields import ModelField from pydantic import BaseConfig import numpy as np - from docarray.proto import NdArrayProto, NodeProto + from docarray.proto import NdArrayProto from docarray.computation.torch_backend import TorchCompBackend from docarray.base_document.base_node import BaseNode @@ -32,6 +33,7 @@ class metaTorchAndNode( pass +@_register_proto(proto_type_name='torch_tensor') class TorchTensor( torch.Tensor, AbstractTensor, Generic[ShapeT], metaclass=metaTorchAndNode ): @@ -83,7 +85,6 @@ class MyDoc(BaseDocument): """ __parametrized_meta__ = metaTorchAndNode - _PROTO_FIELD_NAME = 'torch_tensor' @classmethod def __get_validators__(cls): @@ -171,18 +172,6 @@ def from_ndarray(cls: Type[T], value: np.ndarray) -> T: """ return cls._docarray_from_native(torch.from_numpy(value)) - def _to_node_protobuf(self: T) -> 'NodeProto': - """Convert Document into a NodeProto protobuf message. 
This function should - be called when the Document is nested into another Document that need to be - converted into a protobuf - :param field: field in which to store the content in the node proto - :return: the nested item protobuf message - """ - from docarray.proto import NodeProto - - nd_proto = self.to_protobuf() - return NodeProto(**{self._PROTO_FIELD_NAME: nd_proto}) - @classmethod def from_protobuf(cls: Type[T], pb_msg: 'NdArrayProto') -> 'T': """ diff --git a/docarray/typing/tensor/video/video_ndarray.py b/docarray/typing/tensor/video/video_ndarray.py index 5cf6efc0057..345499634e5 100644 --- a/docarray/typing/tensor/video/video_ndarray.py +++ b/docarray/typing/tensor/video/video_ndarray.py @@ -2,6 +2,7 @@ import numpy as np +from docarray.typing.proto_register import _register_proto from docarray.typing.tensor.ndarray import NdArray from docarray.typing.tensor.video.video_tensor_mixin import VideoTensorMixin @@ -12,6 +13,7 @@ from pydantic.fields import ModelField +@_register_proto(proto_type_name='video_ndarray') class VideoNdArray(NdArray, VideoTensorMixin): """ Subclass of NdArray, to represent a video tensor. 
@@ -21,8 +23,6 @@ class VideoNdArray(NdArray, VideoTensorMixin): """ - _PROTO_FIELD_NAME = 'video_ndarray' - @classmethod def validate( cls: Type[T], diff --git a/docarray/typing/tensor/video/video_torch_tensor.py b/docarray/typing/tensor/video/video_torch_tensor.py index 60dce18da3f..92a477915e9 100644 --- a/docarray/typing/tensor/video/video_torch_tensor.py +++ b/docarray/typing/tensor/video/video_torch_tensor.py @@ -2,6 +2,7 @@ import numpy as np +from docarray.typing.proto_register import _register_proto from docarray.typing.tensor.torch_tensor import TorchTensor, metaTorchAndNode from docarray.typing.tensor.video.video_tensor_mixin import VideoTensorMixin @@ -12,6 +13,7 @@ from pydantic.fields import ModelField +@_register_proto(proto_type_name='video_torch_tensor') class VideoTorchTensor(TorchTensor, VideoTensorMixin, metaclass=metaTorchAndNode): """ Subclass of TorchTensor, to represent a video tensor. @@ -21,8 +23,6 @@ class VideoTorchTensor(TorchTensor, VideoTensorMixin, metaclass=metaTorchAndNode """ - _PROTO_FIELD_NAME = 'video_torch_tensor' - @classmethod def validate( cls: Type[T], diff --git a/docarray/typing/url/any_url.py b/docarray/typing/url/any_url.py index adea38e7b67..700c1af2e3a 100644 --- a/docarray/typing/url/any_url.py +++ b/docarray/typing/url/any_url.py @@ -4,6 +4,7 @@ from pydantic import errors, parse_obj_as from docarray.typing.abstract_type import AbstractType +from docarray.typing.proto_register import _register_proto if TYPE_CHECKING: from pydantic.networks import Parts @@ -13,6 +14,7 @@ T = TypeVar('T', bound='AnyUrl') +@_register_proto(proto_type_name='any_url') class AnyUrl(BaseAnyUrl, AbstractType): host_required = ( False # turn off host requirement to allow passing of local paths as URL @@ -27,7 +29,7 @@ def _to_node_protobuf(self) -> 'NodeProto': """ from docarray.proto import NodeProto - return NodeProto(any_url=str(self)) + return NodeProto(text=str(self), type=self._proto_type_name) @classmethod def validate_parts(cls, 
parts: 'Parts', validate_port: bool = True) -> 'Parts': diff --git a/docarray/typing/url/audio_url.py b/docarray/typing/url/audio_url.py index 1646b4eb0e0..f04fdb1c7fa 100644 --- a/docarray/typing/url/audio_url.py +++ b/docarray/typing/url/audio_url.py @@ -4,6 +4,7 @@ import numpy as np from pydantic import parse_obj_as +from docarray.typing.proto_register import _register_proto from docarray.typing.tensor.audio.audio_ndarray import MAX_INT_16, AudioNdArray from docarray.typing.url.any_url import AnyUrl @@ -11,30 +12,18 @@ from pydantic import BaseConfig from pydantic.fields import ModelField - from docarray.proto import NodeProto - T = TypeVar('T', bound='AudioUrl') AUDIO_FILE_FORMATS = ['wav'] +@_register_proto(proto_type_name='audio_url') class AudioUrl(AnyUrl): """ URL to a .wav file. Can be remote (web) URL, or a local file path. """ - def _to_node_protobuf(self: T) -> 'NodeProto': - """Convert Document into a NodeProto protobuf message. This function should - be called when the Document is nested into another Document that needs to - be converted into a protobuf - - :return: the nested item protobuf message - """ - from docarray.proto import NodeProto - - return NodeProto(audio_url=str(self)) - @classmethod def validate( cls: Type[T], diff --git a/docarray/typing/url/image_url.py b/docarray/typing/url/image_url.py index 29063c46dcc..4e7300a33d9 100644 --- a/docarray/typing/url/image_url.py +++ b/docarray/typing/url/image_url.py @@ -4,6 +4,7 @@ import numpy as np +from docarray.typing.proto_register import _register_proto from docarray.typing.url.any_url import AnyUrl from docarray.typing.url.helper import _uri_to_blob @@ -12,30 +13,18 @@ from pydantic import BaseConfig from pydantic.fields import ModelField - from docarray.proto import NodeProto - T = TypeVar('T', bound='ImageUrl') IMAGE_FILE_FORMATS = ('png', 'jpeg', 'jpg') +@_register_proto(proto_type_name='image_url') class ImageUrl(AnyUrl): """ URL to a .png, .jpeg, or .jpg file. 
Can be remote (web) URL, or a local file path. """ - def _to_node_protobuf(self) -> 'NodeProto': - """Convert Document into a NodeProto protobuf message. This function should - be called when the Document is nested into another Document that needs to - be converted into a protobuf - - :return: the nested item protobuf message - """ - from docarray.proto import NodeProto - - return NodeProto(image_url=str(self)) - @classmethod def validate( cls: Type[T], diff --git a/docarray/typing/url/text_url.py b/docarray/typing/url/text_url.py index 179022a4a03..0b6ddaacdad 100644 --- a/docarray/typing/url/text_url.py +++ b/docarray/typing/url/text_url.py @@ -1,29 +1,17 @@ -from typing import TYPE_CHECKING, Optional - -if TYPE_CHECKING: - from docarray.proto import NodeProto +from typing import Optional +from docarray.typing.proto_register import _register_proto from docarray.typing.url.any_url import AnyUrl from docarray.typing.url.helper import _uri_to_blob +@_register_proto(proto_type_name='text_url') class TextUrl(AnyUrl): """ URL to a text file. Can be remote (web) URL, or a local file path. """ - def _to_node_protobuf(self) -> 'NodeProto': - """Convert Document into a NodeProto protobuf message. This function should - be called when the Document is nested into another Document that need to - be converted into a protobuf - - :return: the nested item protobuf message - """ - from docarray.proto import NodeProto - - return NodeProto(text_url=str(self)) - def load_to_bytes(self, timeout: Optional[float] = None) -> bytes: """ Load the text file into a bytes object. 
diff --git a/docarray/typing/url/url_3d/mesh_url.py b/docarray/typing/url/url_3d/mesh_url.py index 5b18f5d847b..bd5a1de408e 100644 --- a/docarray/typing/url/url_3d/mesh_url.py +++ b/docarray/typing/url/url_3d/mesh_url.py @@ -1,14 +1,12 @@ -from typing import TYPE_CHECKING, NamedTuple, TypeVar +from typing import NamedTuple, TypeVar import numpy as np from pydantic import parse_obj_as from docarray.typing import NdArray +from docarray.typing.proto_register import _register_proto from docarray.typing.url.url_3d.url_3d import Url3D -if TYPE_CHECKING: - from docarray.proto import NodeProto - T = TypeVar('T', bound='Mesh3DUrl') @@ -17,23 +15,13 @@ class Mesh3DLoadResult(NamedTuple): faces: NdArray +@_register_proto(proto_type_name='mesh_url') class Mesh3DUrl(Url3D): """ URL to a .obj, .glb, or .ply file containing 3D mesh information. Can be remote (web) URL, or a local file path. """ - def _to_node_protobuf(self: T) -> 'NodeProto': - """Convert Document into a NodeProto protobuf message. This function should - be called when the Document is nested into another Document that needs to - be converted into a protobuf - - :return: the nested item protobuf message - """ - from docarray.proto import NodeProto - - return NodeProto(mesh_url=str(self)) - def load(self: T) -> Mesh3DLoadResult: """ Load the data from the url into a named tuple of two NdArrays containing diff --git a/docarray/typing/url/url_3d/point_cloud_url.py b/docarray/typing/url/url_3d/point_cloud_url.py index 256b80e5ee3..ad57ff85ace 100644 --- a/docarray/typing/url/url_3d/point_cloud_url.py +++ b/docarray/typing/url/url_3d/point_cloud_url.py @@ -1,34 +1,22 @@ -from typing import TYPE_CHECKING, TypeVar +from typing import TypeVar import numpy as np from pydantic import parse_obj_as from docarray.typing import NdArray +from docarray.typing.proto_register import _register_proto from docarray.typing.url.url_3d.url_3d import Url3D -if TYPE_CHECKING: - from docarray.proto import NodeProto - T = TypeVar('T', 
bound='PointCloud3DUrl') +@_register_proto(proto_type_name='point_cloud_url') class PointCloud3DUrl(Url3D): """ URL to a .obj, .glb, or .ply file containing point cloud information. Can be remote (web) URL, or a local file path. """ - def _to_node_protobuf(self: T) -> 'NodeProto': - """Convert Document into a NodeProto protobuf message. This function should - be called when the Document is nested into another Document that needs to - be converted into a protobuf - - :return: the nested item protobuf message - """ - from docarray.proto import NodeProto - - return NodeProto(point_cloud_url=str(self)) - def load(self: T, samples: int, multiple_geometries: bool = False) -> NdArray: """ Load the data from the url into an NdArray containing point cloud information. diff --git a/docarray/typing/url/url_3d/url_3d.py b/docarray/typing/url/url_3d/url_3d.py index 68191efecfa..cebd7e94080 100644 --- a/docarray/typing/url/url_3d/url_3d.py +++ b/docarray/typing/url/url_3d/url_3d.py @@ -3,6 +3,7 @@ import numpy as np +from docarray.typing.proto_register import _register_proto from docarray.typing.url.any_url import AnyUrl if TYPE_CHECKING: @@ -15,6 +16,7 @@ T = TypeVar('T', bound='Url3D') +@_register_proto(proto_type_name='url3d') class Url3D(AnyUrl, ABC): """ URL to a .obj, .glb, or .ply file containing 3D mesh or point cloud information. 
diff --git a/docarray/typing/url/video_url.py b/docarray/typing/url/video_url.py index ca967a90c2a..7e171ec926a 100644 --- a/docarray/typing/url/video_url.py +++ b/docarray/typing/url/video_url.py @@ -4,6 +4,7 @@ from pydantic.tools import parse_obj_as from docarray.typing import AudioNdArray, NdArray +from docarray.typing.proto_register import _register_proto from docarray.typing.tensor.video import VideoNdArray from docarray.typing.url.any_url import AnyUrl @@ -11,8 +12,6 @@ from pydantic import BaseConfig from pydantic.fields import ModelField - from docarray.proto import NodeProto - T = TypeVar('T', bound='VideoUrl') VIDEO_FILE_FORMATS = ['mp4'] @@ -24,22 +23,13 @@ class VideoLoadResult(NamedTuple): key_frame_indices: NdArray +@_register_proto(proto_type_name='video_url') class VideoUrl(AnyUrl): """ URL to a .wav file. Can be remote (web) URL, or a local file path. """ - def _to_node_protobuf(self: T) -> 'NodeProto': - """Convert Document into a NodeProto protobuf message. This function should - be called when the Document is nested into another Document that needs to - be converted into a protobuf - :return: the nested item protobuf message - """ - from docarray.proto import NodeProto - - return NodeProto(video_url=str(self)) - @classmethod def validate( cls: Type[T], diff --git a/tests/integrations/document/test_proto.py b/tests/integrations/document/test_proto.py index e405ca4b1b2..adabe6deb84 100644 --- a/tests/integrations/document/test_proto.py +++ b/tests/integrations/document/test_proto.py @@ -72,7 +72,8 @@ class MyDoc(BaseDocument): np_embedding=np.zeros((128,)), nested_docs=DocumentArray[NestedDoc]([NestedDoc(tensor=np.zeros((128,)))]), ) - doc = MyDoc.from_protobuf(doc.to_protobuf()) + doc = doc.to_protobuf() + doc = MyDoc.from_protobuf(doc) assert doc.img_url == 'test.png' assert doc.txt_url == 'test.txt' diff --git a/tests/units/document/proto/test_proto_based_object.py b/tests/units/document/proto/test_proto_based_object.py index 
051485f91e1..6d2b7a79b7b 100644 --- a/tests/units/document/proto/test_proto_based_object.py +++ b/tests/units/document/proto/test_proto_based_object.py @@ -4,15 +4,6 @@ from docarray.typing import NdArray -def test_nested_item_proto(): - NodeProto(text='hello') - NodeProto(nested=DocumentProto()) - - -def test_nested_optional_item_proto(): - NodeProto() - - def test_ndarray(): original_ndarray = np.zeros((3, 224, 224)) diff --git a/tests/units/typing/tensor/test_audio_tensor.py b/tests/units/typing/tensor/test_audio_tensor.py index caa016dbb50..9a2bf64fa69 100644 --- a/tests/units/typing/tensor/test_audio_tensor.py +++ b/tests/units/typing/tensor/test_audio_tensor.py @@ -57,14 +57,14 @@ def test_illegal_validation(cls_tensor, tensor): @pytest.mark.parametrize( 'cls_tensor,tensor,proto_key', [ - (AudioTorchTensor, torch.zeros(1000, 2), AudioTorchTensor._PROTO_FIELD_NAME), - (AudioNdArray, np.zeros((1000, 2)), AudioNdArray._PROTO_FIELD_NAME), + (AudioTorchTensor, torch.zeros(1000, 2), AudioTorchTensor._proto_type_name), + (AudioNdArray, np.zeros((1000, 2)), AudioNdArray._proto_type_name), ], ) def test_proto_tensor(cls_tensor, tensor, proto_key): tensor = parse_obj_as(cls_tensor, tensor) proto = tensor._to_node_protobuf() - assert str(proto).startswith(proto_key) + assert proto_key in str(proto) @pytest.mark.parametrize( diff --git a/tests/units/typing/tensor/test_video_tensor.py b/tests/units/typing/tensor/test_video_tensor.py index 214fcdf6e12..d7baf131a39 100644 --- a/tests/units/typing/tensor/test_video_tensor.py +++ b/tests/units/typing/tensor/test_video_tensor.py @@ -67,15 +67,15 @@ def test_illegal_validation(cls_tensor, tensor): ( VideoTorchTensor, torch.zeros(1, 224, 224, 3), - VideoTorchTensor._PROTO_FIELD_NAME, + VideoTorchTensor._proto_type_name, ), - (VideoNdArray, np.zeros((1, 224, 224, 3)), VideoNdArray._PROTO_FIELD_NAME), + (VideoNdArray, np.zeros((1, 224, 224, 3)), VideoNdArray._proto_type_name), ], ) def test_proto_tensor(cls_tensor, tensor, 
proto_key): tensor = parse_obj_as(cls_tensor, tensor) proto = tensor._to_node_protobuf() - assert str(proto).startswith(proto_key) + assert proto_key in str(proto) @pytest.mark.parametrize( diff --git a/tests/units/typing/url/test_audio_url.py b/tests/units/typing/url/test_audio_url.py index 5d0f042fd71..21280aa8eb9 100644 --- a/tests/units/typing/url/test_audio_url.py +++ b/tests/units/typing/url/test_audio_url.py @@ -105,4 +105,4 @@ def test_illegal_validation(path_to_file): def test_proto_audio_url(file_url): uri = parse_obj_as(AudioUrl, file_url) proto = uri._to_node_protobuf() - assert str(proto).startswith('audio_url') + assert 'audio_url' in str(proto) diff --git a/tests/units/typing/url/test_video_url.py b/tests/units/typing/url/test_video_url.py index 46bd60ea648..ef395092a93 100644 --- a/tests/units/typing/url/test_video_url.py +++ b/tests/units/typing/url/test_video_url.py @@ -123,4 +123,4 @@ def test_illegal_validation(path_to_file): def test_proto_video_url(file_url): uri = parse_obj_as(VideoUrl, file_url) proto = uri._to_node_protobuf() - assert str(proto).startswith('video_url') + assert 'video_url' in str(proto)