From 1a2475de3f16dc046cba430c1368f345e9637a20 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Mon, 16 Jan 2023 16:31:30 +0100 Subject: [PATCH 01/16] wip Signed-off-by: Sami Jaghouar --- docarray/base_document/mixins/proto.py | 2 +- docarray/proto/docarray.proto | 37 ++----------------- docarray/proto/pb2/docarray_pb2.py | 28 +++++++------- .../document/proto/test_proto_based_object.py | 2 +- 4 files changed, 20 insertions(+), 49 deletions(-) diff --git a/docarray/base_document/mixins/proto.py b/docarray/base_document/mixins/proto.py index 57c38be4141..e3cde6e7815 100644 --- a/docarray/base_document/mixins/proto.py +++ b/docarray/base_document/mixins/proto.py @@ -133,4 +133,4 @@ def _to_node_protobuf(self) -> 'NodeProto': :return: the nested item protobuf message """ - return NodeProto(nested=self.to_protobuf()) + return NodeProto(document=self.to_protobuf()) diff --git a/docarray/proto/docarray.proto b/docarray/proto/docarray.proto index 0646453294e..f63d345480c 100644 --- a/docarray/proto/docarray.proto +++ b/docarray/proto/docarray.proto @@ -30,50 +30,21 @@ message NdArrayProto { // message NodeProto { - - oneof content { bytes blob = 1; - // the ndarray of the image/audio/video document NdArrayProto ndarray = 2; - // a text string text = 3; - // a sub Document - DocumentProto nested = 4; - + DocumentProto document = 4; // a sub DocumentArray - DocumentArrayProto chunks = 5; - - NdArrayProto embedding = 6; - - string any_url = 7; - - string image_url = 8; - - string text_url = 9; - - string id = 10; - - NdArrayProto torch_tensor = 11; - - string mesh_url = 12; - - string point_cloud_url = 13; - - string audio_url = 14; - - NdArrayProto audio_ndarray = 15; - - NdArrayProto audio_torch_tensor = 16; - + DocumentArrayProto document_array = 5; } -} - + string docarray_type = 6; +} /** * Represents a Document diff --git a/docarray/proto/pb2/docarray_pb2.py b/docarray/proto/pb2/docarray_pb2.py index 1d5fb2d954b..fe3b3a693bb 100644 --- a/docarray/proto/pb2/docarray_pb2.py +++ b/docarray/proto/pb2/docarray_pb2.py @@ -15,7 +15,7 @@ from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x0e\x64ocarray.proto\x12\x08\x64ocarray\x1a\x1cgoogle/protobuf/struct.proto\"A\n\x11\x44\x65nseNdArrayProto\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\r\x12\r\n\x05\x64type\x18\x03 \x01(\t\"g\n\x0cNdArrayProto\x12*\n\x05\x64\x65nse\x18\x01 \x01(\x0b\x32\x1b.docarray.DenseNdArrayProto\x12+\n\nparameters\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\"\x8e\x04\n\tNodeProto\x12\x0e\n\x04\x62lob\x18\x01 \x01(\x0cH\x00\x12)\n\x07ndarray\x18\x02 \x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x12\x0e\n\x04text\x18\x03 \x01(\tH\x00\x12)\n\x06nested\x18\x04 \x01(\x0b\x32\x17.docarray.DocumentProtoH\x00\x12.\n\x06\x63hunks\x18\x05 \x01(\x0b\x32\x1c.docarray.DocumentArrayProtoH\x00\x12+\n\tembedding\x18\x06 \x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x12\x11\n\x07\x61ny_url\x18\x07 \x01(\tH\x00\x12\x13\n\timage_url\x18\x08 \x01(\tH\x00\x12\x12\n\x08text_url\x18\t \x01(\tH\x00\x12\x0c\n\x02id\x18\n \x01(\tH\x00\x12.\n\x0ctorch_tensor\x18\x0b \x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x12\x12\n\x08mesh_url\x18\x0c \x01(\tH\x00\x12\x19\n\x0fpoint_cloud_url\x18\r \x01(\tH\x00\x12\x13\n\taudio_url\x18\x0e \x01(\tH\x00\x12/\n\raudio_ndarray\x18\x0f \x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x12\x34\n\x12\x61udio_torch_tensor\x18\x10 \x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x42\t\n\x07\x63ontent\"\x82\x01\n\rDocumentProto\x12/\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32!.docarray.DocumentProto.DataEntry\x1a@\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.docarray.NodeProto:\x02\x38\x01\";\n\x12\x44ocumentArrayProto\x12%\n\x04\x64ocs\x18\x01 \x03(\x0b\x32\x17.docarray.DocumentProto\"\x86\x01\n\x0fUnionArrayProto\x12=\n\x0e\x64ocument_array\x18\x01 \x01(\x0b\x32#.docarray.DocumentArrayStackedProtoH\x00\x12)\n\x07ndarray\x18\x02 \x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x42\t\n\x07\x63ontent\"\xd6\x01\n\x19\x44ocumentArrayStackedProto\x12+\n\x05list_\x18\x01 \x01(\x0b\x32\x1c.docarray.DocumentArrayProto\x12\x41\n\x07\x63olumns\x18\x02 \x03(\x0b\x32\x30.docarray.DocumentArrayStackedProto.ColumnsEntry\x1aI\n\x0c\x43olumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x19.docarray.UnionArrayProto:\x02\x38\x01\x62\x06proto3' + b'\n\x0e\x64ocarray.proto\x12\x08\x64ocarray\x1a\x1cgoogle/protobuf/struct.proto\"A\n\x11\x44\x65nseNdArrayProto\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\r\x12\r\n\x05\x64type\x18\x03 \x01(\t\"g\n\x0cNdArrayProto\x12*\n\x05\x64\x65nse\x18\x01 \x01(\x0b\x32\x1b.docarray.DenseNdArrayProto\x12+\n\nparameters\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\"\xdd\x01\n\tNodeProto\x12\x0e\n\x04\x62lob\x18\x01 \x01(\x0cH\x00\x12)\n\x07ndarray\x18\x02 \x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x12\x0e\n\x04text\x18\x03 \x01(\tH\x00\x12+\n\x08\x64ocument\x18\x04 \x01(\x0b\x32\x17.docarray.DocumentProtoH\x00\x12\x36\n\x0e\x64ocument_array\x18\x05 \x01(\x0b\x32\x1c.docarray.DocumentArrayProtoH\x00\x12\x15\n\rdocarray_type\x18\x06 \x01(\tB\t\n\x07\x63ontent\"\x82\x01\n\rDocumentProto\x12/\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32!.docarray.DocumentProto.DataEntry\x1a@\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.docarray.NodeProto:\x02\x38\x01\";\n\x12\x44ocumentArrayProto\x12%\n\x04\x64ocs\x18\x01 \x03(\x0b\x32\x17.docarray.DocumentProto\"\x86\x01\n\x0fUnionArrayProto\x12=\n\x0e\x64ocument_array\x18\x01 \x01(\x0b\x32#.docarray.DocumentArrayStackedProtoH\x00\x12)\n\x07ndarray\x18\x02 \x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x42\t\n\x07\x63ontent\"\xd6\x01\n\x19\x44ocumentArrayStackedProto\x12+\n\x05list_\x18\x01 \x01(\x0b\x32\x1c.docarray.DocumentArrayProto\x12\x41\n\x07\x63olumns\x18\x02 \x03(\x0b\x32\x30.docarray.DocumentArrayStackedProto.ColumnsEntry\x1aI\n\x0c\x43olumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x19.docarray.UnionArrayProto:\x02\x38\x01\x62\x06proto3' ) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) @@ -32,17 +32,17 @@ _NDARRAYPROTO._serialized_start = 125 _NDARRAYPROTO._serialized_end = 228 _NODEPROTO._serialized_start = 231 - _NODEPROTO._serialized_end = 757 - _DOCUMENTPROTO._serialized_start = 760 - _DOCUMENTPROTO._serialized_end = 890 - _DOCUMENTPROTO_DATAENTRY._serialized_start = 826 - _DOCUMENTPROTO_DATAENTRY._serialized_end = 890 - _DOCUMENTARRAYPROTO._serialized_start = 892 - _DOCUMENTARRAYPROTO._serialized_end = 951 - _UNIONARRAYPROTO._serialized_start = 954 - _UNIONARRAYPROTO._serialized_end = 1088 - _DOCUMENTARRAYSTACKEDPROTO._serialized_start = 1091 - _DOCUMENTARRAYSTACKEDPROTO._serialized_end = 1305 - _DOCUMENTARRAYSTACKEDPROTO_COLUMNSENTRY._serialized_start = 1232 - _DOCUMENTARRAYSTACKEDPROTO_COLUMNSENTRY._serialized_end = 1305 + _NODEPROTO._serialized_end = 452 + _DOCUMENTPROTO._serialized_start = 455 + _DOCUMENTPROTO._serialized_end = 585 + _DOCUMENTPROTO_DATAENTRY._serialized_start = 521 + _DOCUMENTPROTO_DATAENTRY._serialized_end = 585 + _DOCUMENTARRAYPROTO._serialized_start = 587 + _DOCUMENTARRAYPROTO._serialized_end = 646 + _UNIONARRAYPROTO._serialized_start = 649 + _UNIONARRAYPROTO._serialized_end = 783 + _DOCUMENTARRAYSTACKEDPROTO._serialized_start = 786 + _DOCUMENTARRAYSTACKEDPROTO._serialized_end = 1000 + _DOCUMENTARRAYSTACKEDPROTO_COLUMNSENTRY._serialized_start = 927 + _DOCUMENTARRAYSTACKEDPROTO_COLUMNSENTRY._serialized_end = 1000 # @@protoc_insertion_point(module_scope) diff --git a/tests/units/document/proto/test_proto_based_object.py b/tests/units/document/proto/test_proto_based_object.py index 051485f91e1..f5af227c44d 100644 --- a/tests/units/document/proto/test_proto_based_object.py +++ b/tests/units/document/proto/test_proto_based_object.py @@ -6,7 +6,7 @@ def test_nested_item_proto(): NodeProto(text='hello') - NodeProto(nested=DocumentProto()) + NodeProto(document=DocumentProto()) def test_nested_optional_item_proto(): From 6970630c6772eebe7891bc4d99ec3f214ad75c8b Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Mon, 16 Jan 2023 16:44:50 +0100 Subject: [PATCH 02/16] wip2 Signed-off-by: Sami Jaghouar --- docarray/proto/docarray.proto | 4 ++- docarray/proto/pb2/docarray_pb2.py | 54 ++++++++++++++---------------- 2 files changed, 29 insertions(+), 29 deletions(-) diff --git a/docarray/proto/docarray.proto b/docarray/proto/docarray.proto index f63d345480c..64770bd6794 100644 --- a/docarray/proto/docarray.proto +++ b/docarray/proto/docarray.proto @@ -42,7 +42,9 @@ message NodeProto { DocumentArrayProto document_array = 5; } - string docarray_type = 6; + oneof docarray_type { + string type = 6; + } } diff --git a/docarray/proto/pb2/docarray_pb2.py b/docarray/proto/pb2/docarray_pb2.py index fe3b3a693bb..88b74779147 100644 --- a/docarray/proto/pb2/docarray_pb2.py +++ b/docarray/proto/pb2/docarray_pb2.py @@ -2,11 +2,10 @@ # Generated by the protocol buffer compiler. DO NOT EDIT! # source: docarray.proto """Generated protocol buffer code.""" +from google.protobuf.internal import builder as _builder from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import symbol_database as _symbol_database -from google.protobuf.internal import builder as _builder - # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() @@ -14,35 +13,34 @@ from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x0e\x64ocarray.proto\x12\x08\x64ocarray\x1a\x1cgoogle/protobuf/struct.proto\"A\n\x11\x44\x65nseNdArrayProto\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\r\x12\r\n\x05\x64type\x18\x03 \x01(\t\"g\n\x0cNdArrayProto\x12*\n\x05\x64\x65nse\x18\x01 \x01(\x0b\x32\x1b.docarray.DenseNdArrayProto\x12+\n\nparameters\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\"\xdd\x01\n\tNodeProto\x12\x0e\n\x04\x62lob\x18\x01 \x01(\x0cH\x00\x12)\n\x07ndarray\x18\x02 \x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x12\x0e\n\x04text\x18\x03 \x01(\tH\x00\x12+\n\x08\x64ocument\x18\x04 \x01(\x0b\x32\x17.docarray.DocumentProtoH\x00\x12\x36\n\x0e\x64ocument_array\x18\x05 \x01(\x0b\x32\x1c.docarray.DocumentArrayProtoH\x00\x12\x15\n\rdocarray_type\x18\x06 \x01(\tB\t\n\x07\x63ontent\"\x82\x01\n\rDocumentProto\x12/\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32!.docarray.DocumentProto.DataEntry\x1a@\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.docarray.NodeProto:\x02\x38\x01\";\n\x12\x44ocumentArrayProto\x12%\n\x04\x64ocs\x18\x01 \x03(\x0b\x32\x17.docarray.DocumentProto\"\x86\x01\n\x0fUnionArrayProto\x12=\n\x0e\x64ocument_array\x18\x01 \x01(\x0b\x32#.docarray.DocumentArrayStackedProtoH\x00\x12)\n\x07ndarray\x18\x02 \x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x42\t\n\x07\x63ontent\"\xd6\x01\n\x19\x44ocumentArrayStackedProto\x12+\n\x05list_\x18\x01 \x01(\x0b\x32\x1c.docarray.DocumentArrayProto\x12\x41\n\x07\x63olumns\x18\x02 \x03(\x0b\x32\x30.docarray.DocumentArrayStackedProto.ColumnsEntry\x1aI\n\x0c\x43olumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x19.docarray.UnionArrayProto:\x02\x38\x01\x62\x06proto3' -) + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0e\x64ocarray.proto\x12\x08\x64ocarray\x1a\x1cgoogle/protobuf/struct.proto\"A\n\x11\x44\x65nseNdArrayProto\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\r\x12\r\n\x05\x64type\x18\x03 \x01(\t\"g\n\x0cNdArrayProto\x12*\n\x05\x64\x65nse\x18\x01 \x01(\x0b\x32\x1b.docarray.DenseNdArrayProto\x12+\n\nparameters\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\"\xe7\x01\n\tNodeProto\x12\x0e\n\x04\x62lob\x18\x01 \x01(\x0cH\x00\x12)\n\x07ndarray\x18\x02 \x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x12\x0e\n\x04text\x18\x03 \x01(\tH\x00\x12+\n\x08\x64ocument\x18\x04 \x01(\x0b\x32\x17.docarray.DocumentProtoH\x00\x12\x36\n\x0e\x64ocument_array\x18\x05 \x01(\x0b\x32\x1c.docarray.DocumentArrayProtoH\x00\x12\x0e\n\x04type\x18\x06 \x01(\tH\x01\x42\t\n\x07\x63ontentB\x0f\n\rdocarray_type\"\x82\x01\n\rDocumentProto\x12/\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32!.docarray.DocumentProto.DataEntry\x1a@\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.docarray.NodeProto:\x02\x38\x01\";\n\x12\x44ocumentArrayProto\x12%\n\x04\x64ocs\x18\x01 \x03(\x0b\x32\x17.docarray.DocumentProto\"\x86\x01\n\x0fUnionArrayProto\x12=\n\x0e\x64ocument_array\x18\x01 \x01(\x0b\x32#.docarray.DocumentArrayStackedProtoH\x00\x12)\n\x07ndarray\x18\x02 \x01(\x0b\x32\x16.docarray.NdArrayProtoH\x00\x42\t\n\x07\x63ontent\"\xd6\x01\n\x19\x44ocumentArrayStackedProto\x12+\n\x05list_\x18\x01 \x01(\x0b\x32\x1c.docarray.DocumentArrayProto\x12\x41\n\x07\x63olumns\x18\x02 \x03(\x0b\x32\x30.docarray.DocumentArrayStackedProto.ColumnsEntry\x1aI\n\x0c\x43olumnsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x19.docarray.UnionArrayProto:\x02\x38\x01\x62\x06proto3') _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'docarray_pb2', globals()) if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - _DOCUMENTPROTO_DATAENTRY._options = None - _DOCUMENTPROTO_DATAENTRY._serialized_options = b'8\001' - _DOCUMENTARRAYSTACKEDPROTO_COLUMNSENTRY._options = None - _DOCUMENTARRAYSTACKEDPROTO_COLUMNSENTRY._serialized_options = b'8\001' - _DENSENDARRAYPROTO._serialized_start = 58 - _DENSENDARRAYPROTO._serialized_end = 123 - _NDARRAYPROTO._serialized_start = 125 - _NDARRAYPROTO._serialized_end = 228 - _NODEPROTO._serialized_start = 231 - _NODEPROTO._serialized_end = 452 - _DOCUMENTPROTO._serialized_start = 455 - _DOCUMENTPROTO._serialized_end = 585 - _DOCUMENTPROTO_DATAENTRY._serialized_start = 521 - _DOCUMENTPROTO_DATAENTRY._serialized_end = 585 - _DOCUMENTARRAYPROTO._serialized_start = 587 - _DOCUMENTARRAYPROTO._serialized_end = 646 - _UNIONARRAYPROTO._serialized_start = 649 - _UNIONARRAYPROTO._serialized_end = 783 - _DOCUMENTARRAYSTACKEDPROTO._serialized_start = 786 - _DOCUMENTARRAYSTACKEDPROTO._serialized_end = 1000 - _DOCUMENTARRAYSTACKEDPROTO_COLUMNSENTRY._serialized_start = 927 - _DOCUMENTARRAYSTACKEDPROTO_COLUMNSENTRY._serialized_end = 1000 + DESCRIPTOR._options = None + _DOCUMENTPROTO_DATAENTRY._options = None + _DOCUMENTPROTO_DATAENTRY._serialized_options = b'8\001' + _DOCUMENTARRAYSTACKEDPROTO_COLUMNSENTRY._options = None + _DOCUMENTARRAYSTACKEDPROTO_COLUMNSENTRY._serialized_options = b'8\001' + _DENSENDARRAYPROTO._serialized_start=58 + _DENSENDARRAYPROTO._serialized_end=123 + _NDARRAYPROTO._serialized_start=125 + _NDARRAYPROTO._serialized_end=228 + _NODEPROTO._serialized_start=231 + _NODEPROTO._serialized_end=462 + _DOCUMENTPROTO._serialized_start=465 + _DOCUMENTPROTO._serialized_end=595 + _DOCUMENTPROTO_DATAENTRY._serialized_start=531 + _DOCUMENTPROTO_DATAENTRY._serialized_end=595 + _DOCUMENTARRAYPROTO._serialized_start=597 + _DOCUMENTARRAYPROTO._serialized_end=656 + _UNIONARRAYPROTO._serialized_start=659 + _UNIONARRAYPROTO._serialized_end=793 + _DOCUMENTARRAYSTACKEDPROTO._serialized_start=796 + _DOCUMENTARRAYSTACKEDPROTO._serialized_end=1010 + _DOCUMENTARRAYSTACKEDPROTO_COLUMNSENTRY._serialized_start=937 + _DOCUMENTARRAYSTACKEDPROTO_COLUMNSENTRY._serialized_end=1010 # @@protoc_insertion_point(module_scope) From 41e007b4047dad49d7fea10c582775cc210ac9e0 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Mon, 16 Jan 2023 16:44:57 +0100 Subject: [PATCH 03/16] wip2 Signed-off-by: Sami Jaghouar --- docarray/array/abstract_array.py | 2 +- docarray/typing/abstract_type.py | 4 +++- docarray/typing/id.py | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/docarray/array/abstract_array.py b/docarray/array/abstract_array.py index 150e1ef89d8..5302d11ab13 100644 --- a/docarray/array/abstract_array.py +++ b/docarray/array/abstract_array.py @@ -92,7 +92,7 @@ def _to_node_protobuf(self) -> 'NodeProto': """ from docarray.proto import NodeProto - return NodeProto(chunks=self.to_protobuf()) + return NodeProto(documentchunks=self.to_protobuf()) @abstractmethod def traverse_flat( diff --git a/docarray/typing/abstract_type.py b/docarray/typing/abstract_type.py index c03d4336f6c..0065c700719 100644 --- a/docarray/typing/abstract_type.py +++ b/docarray/typing/abstract_type.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import TYPE_CHECKING, Any, Type, TypeVar +from typing import TYPE_CHECKING, Any, Type, TypeVar, Optional from pydantic import BaseConfig from pydantic.fields import ModelField @@ -13,6 +13,8 @@ class AbstractType(BaseNode): + _proto_type_name : str + @classmethod def __get_validators__(cls): yield cls.validate diff --git a/docarray/typing/id.py b/docarray/typing/id.py index bf39d771746..7ff7be04d4c 100644 --- a/docarray/typing/id.py +++ b/docarray/typing/id.py @@ -16,6 +16,7 @@ class ID(str, AbstractType): """ Represent an unique ID """ + _proto_type_name = 'id' @classmethod def __get_validators__(cls): @@ -44,7 +45,7 @@ def _to_node_protobuf(self) -> 'NodeProto': """ from docarray.proto import NodeProto - return NodeProto(id=self) + return NodeProto(text=self, type = self._proto_type_name) @classmethod def from_protobuf(cls: Type[T], pb_msg: 'str') -> T: From 9b35927e6f103cb079e483d2aafcceff9676bdfa Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Mon, 16 Jan 2023 16:58:34 +0100 Subject: [PATCH 04/16] wip3 Signed-off-by: Sami Jaghouar --- docarray/array/abstract_array.py | 2 +- docarray/typing/id.py | 3 ++- docarray/typing/tensor/abstract_tensor.py | 15 +++++++++++++-- docarray/typing/tensor/audio/audio_ndarray.py | 3 +-- .../typing/tensor/audio/audio_torch_tensor.py | 2 +- docarray/typing/tensor/embedding/ndarray.py | 2 ++ docarray/typing/tensor/embedding/torch.py | 2 ++ docarray/typing/tensor/ndarray.py | 14 +------------- docarray/typing/tensor/torch_tensor.py | 14 +------------- docarray/typing/url/any_url.py | 3 ++- docarray/typing/url/audio_url.py | 11 +---------- docarray/typing/url/image_url.py | 11 +---------- docarray/typing/url/text_url.py | 11 +---------- docarray/typing/url/url_3d/mesh_url.py | 12 +----------- docarray/typing/url/url_3d/point_cloud_url.py | 12 +----------- docarray/typing/url/url_3d/url_3d.py | 1 + tests/units/typing/tensor/test_audio_tensor.py | 4 ++-- 17 files changed, 34 insertions(+), 88 deletions(-) diff --git a/docarray/array/abstract_array.py b/docarray/array/abstract_array.py index 5302d11ab13..1e09b7e4417 100644 --- a/docarray/array/abstract_array.py +++ b/docarray/array/abstract_array.py @@ -92,7 +92,7 @@ def _to_node_protobuf(self) -> 'NodeProto': """ from docarray.proto import NodeProto - return NodeProto(documentchunks=self.to_protobuf()) + return NodeProto(document=self.to_protobuf()) @abstractmethod def traverse_flat( diff --git a/docarray/typing/id.py b/docarray/typing/id.py index 7ff7be04d4c..e64f5a60234 100644 --- a/docarray/typing/id.py +++ b/docarray/typing/id.py @@ -16,6 +16,7 @@ class ID(str, AbstractType): """ Represent an unique ID """ + _proto_type_name = 'id' @classmethod @@ -45,7 +46,7 @@ def _to_node_protobuf(self) -> 'NodeProto': """ from docarray.proto import NodeProto - return NodeProto(text=self, type = self._proto_type_name) + return NodeProto(text=self, type=self._proto_type_name) @classmethod def from_protobuf(cls: Type[T], pb_msg: 'str') -> T: diff --git a/docarray/typing/tensor/abstract_tensor.py b/docarray/typing/tensor/abstract_tensor.py index 609b9cff2d2..d0d70b601b7 100644 --- a/docarray/typing/tensor/abstract_tensor.py +++ b/docarray/typing/tensor/abstract_tensor.py @@ -9,7 +9,7 @@ from pydantic import BaseConfig from pydantic.fields import ModelField - from docarray.proto import NdArrayProto + from docarray.proto import NdArrayProto, NodeProto T = TypeVar('T', bound='AbstractTensor') ShapeT = TypeVar('ShapeT') @@ -57,7 +57,18 @@ def __instancecheck__(cls, instance): class AbstractTensor(Generic[ShapeT], AbstractType, ABC): __parametrized_meta__: type = _ParametrizedMeta - _PROTO_FIELD_NAME: str + _proto_type_name: str + def _to_node_protobuf(self: T) -> 'NodeProto': + """Convert itself into a NodeProto protobuf message. This function should + be called when the Document is nested into another Document that need to be + converted into a protobuf + :param field: field in which to store the content in the node proto + :return: the nested item protobuf message + """ + from docarray.proto import NodeProto + + nd_proto = self.to_protobuf() + return NodeProto(ndarray=nd_proto, type=self._proto_type_name) @classmethod @abc.abstractmethod diff --git a/docarray/typing/tensor/audio/audio_ndarray.py b/docarray/typing/tensor/audio/audio_ndarray.py index 1d619f0cdf8..ac399a6866e 100644 --- a/docarray/typing/tensor/audio/audio_ndarray.py +++ b/docarray/typing/tensor/audio/audio_ndarray.py @@ -51,8 +51,7 @@ class MyAudioDoc(Document): doc_2.audio_tensor.save_to_wav_file(file_path='path/to/file_2.wav') """ - - _PROTO_FIELD_NAME = 'audio_ndarray' + _proto_type_name = 'audio_ndarray' def to_audio_bytes(self): tensor = (self * MAX_INT_16).astype(' np.ndarray: """ return self.view(np.ndarray) - def _to_node_protobuf(self: T) -> 'NodeProto': - """Convert itself into a NodeProto protobuf message. This function should - be called when the Document is nested into another Document that need to be - converted into a protobuf - :param field: field in which to store the content in the node proto - :return: the nested item protobuf message - """ - from docarray.proto import NodeProto - - nd_proto = self.to_protobuf() - return NodeProto(**{self._PROTO_FIELD_NAME: nd_proto}) - @classmethod def from_protobuf(cls: Type[T], pb_msg: 'NdArrayProto') -> 'T': """ diff --git a/docarray/typing/tensor/torch_tensor.py b/docarray/typing/tensor/torch_tensor.py index 946e7dfd5a2..04bfc96be36 100644 --- a/docarray/typing/tensor/torch_tensor.py +++ b/docarray/typing/tensor/torch_tensor.py @@ -81,7 +81,7 @@ class MyDoc(BaseDocument): """ __parametrized_meta__ = metaTorchAndNode - _PROTO_FIELD_NAME = 'torch_tensor' + _proto_type_name = 'torch' @classmethod def __get_validators__(cls): @@ -186,18 +186,6 @@ def from_ndarray(cls: Type[T], value: np.ndarray) -> T: """ return cls._docarray_from_native(torch.from_numpy(value)) - def _to_node_protobuf(self: T) -> 'NodeProto': - """Convert Document into a NodeProto protobuf message. This function should - be called when the Document is nested into another Document that need to be - converted into a protobuf - :param field: field in which to store the content in the node proto - :return: the nested item protobuf message - """ - from docarray.proto import NodeProto - - nd_proto = self.to_protobuf() - return NodeProto(**{self._PROTO_FIELD_NAME: nd_proto}) - @classmethod def from_protobuf(cls: Type[T], pb_msg: 'NdArrayProto') -> 'T': """ diff --git a/docarray/typing/url/any_url.py b/docarray/typing/url/any_url.py index adea38e7b67..9eba9b9537c 100644 --- a/docarray/typing/url/any_url.py +++ b/docarray/typing/url/any_url.py @@ -17,6 +17,7 @@ class AnyUrl(BaseAnyUrl, AbstractType): host_required = ( False # turn off host requirement to allow passing of local paths as URL ) + _proto_type_name = 'url' def _to_node_protobuf(self) -> 'NodeProto': """Convert Document into a NodeProto protobuf message. This function should @@ -27,7 +28,7 @@ def _to_node_protobuf(self) -> 'NodeProto': """ from docarray.proto import NodeProto - return NodeProto(any_url=str(self)) + return NodeProto(text=str(self), type=self._proto_type_name) @classmethod def validate_parts(cls, parts: 'Parts', validate_port: bool = True) -> 'Parts': diff --git a/docarray/typing/url/audio_url.py b/docarray/typing/url/audio_url.py index 6e9e25a7e7e..378a7a2bc3d 100644 --- a/docarray/typing/url/audio_url.py +++ b/docarray/typing/url/audio_url.py @@ -24,16 +24,7 @@ class AudioUrl(AnyUrl): Can be remote (web) URL, or a local file path. """ - def _to_node_protobuf(self: T) -> 'NodeProto': - """Convert Document into a NodeProto protobuf message. This function should - be called when the Document is nested into another Document that needs to - be converted into a protobuf - - :return: the nested item protobuf message - """ - from docarray.proto import NodeProto - - return NodeProto(audio_url=str(self)) + _proto_type_name = 'audio_url' @classmethod def validate( diff --git a/docarray/typing/url/image_url.py b/docarray/typing/url/image_url.py index 29063c46dcc..bb9dc4c91ed 100644 --- a/docarray/typing/url/image_url.py +++ b/docarray/typing/url/image_url.py @@ -25,16 +25,7 @@ class ImageUrl(AnyUrl): Can be remote (web) URL, or a local file path. """ - def _to_node_protobuf(self) -> 'NodeProto': - """Convert Document into a NodeProto protobuf message. This function should - be called when the Document is nested into another Document that needs to - be converted into a protobuf - - :return: the nested item protobuf message - """ - from docarray.proto import NodeProto - - return NodeProto(image_url=str(self)) + _proto_type_name = 'image_url' @classmethod def validate( diff --git a/docarray/typing/url/text_url.py b/docarray/typing/url/text_url.py index 179022a4a03..6f0c5ebebda 100644 --- a/docarray/typing/url/text_url.py +++ b/docarray/typing/url/text_url.py @@ -13,16 +13,7 @@ class TextUrl(AnyUrl): Can be remote (web) URL, or a local file path. """ - def _to_node_protobuf(self) -> 'NodeProto': - """Convert Document into a NodeProto protobuf message. This function should - be called when the Document is nested into another Document that need to - be converted into a protobuf - - :return: the nested item protobuf message - """ - from docarray.proto import NodeProto - - return NodeProto(text_url=str(self)) + _proto_type_name = 'texturl' def load_to_bytes(self, timeout: Optional[float] = None) -> bytes: """ diff --git a/docarray/typing/url/url_3d/mesh_url.py b/docarray/typing/url/url_3d/mesh_url.py index a772c345f0b..040e7f3c28e 100644 --- a/docarray/typing/url/url_3d/mesh_url.py +++ b/docarray/typing/url/url_3d/mesh_url.py @@ -15,17 +15,7 @@ class Mesh3DUrl(Url3D): URL to a .obj, .glb, or .ply file containing 3D mesh information. Can be remote (web) URL, or a local file path. """ - - def _to_node_protobuf(self: T) -> 'NodeProto': - """Convert Document into a NodeProto protobuf message. This function should - be called when the Document is nested into another Document that needs to - be converted into a protobuf - - :return: the nested item protobuf message - """ - from docarray.proto import NodeProto - - return NodeProto(mesh_url=str(self)) + _proto_type_name = 'mesh3durl' def load(self: T) -> Tuple[np.ndarray, np.ndarray]: """ diff --git a/docarray/typing/url/url_3d/point_cloud_url.py b/docarray/typing/url/url_3d/point_cloud_url.py index e4ebae45ff0..f7ddf4e60d7 100644 --- a/docarray/typing/url/url_3d/point_cloud_url.py +++ b/docarray/typing/url/url_3d/point_cloud_url.py @@ -15,17 +15,7 @@ class PointCloud3DUrl(Url3D): URL to a .obj, .glb, or .ply file containing point cloud information. Can be remote (web) URL, or a local file path. """ - - def _to_node_protobuf(self: T) -> 'NodeProto': - """Convert Document into a NodeProto protobuf message. This function should - be called when the Document is nested into another Document that needs to - be converted into a protobuf - - :return: the nested item protobuf message - """ - from docarray.proto import NodeProto - - return NodeProto(point_cloud_url=str(self)) + _proto_type_name = 'point_cloud_url' def load(self: T, samples: int, multiple_geometries: bool = False) -> np.ndarray: """ diff --git a/docarray/typing/url/url_3d/url_3d.py b/docarray/typing/url/url_3d/url_3d.py index 68191efecfa..4a65b568b85 100644 --- a/docarray/typing/url/url_3d/url_3d.py +++ b/docarray/typing/url/url_3d/url_3d.py @@ -20,6 +20,7 @@ class Url3D(AnyUrl, ABC): URL to a .obj, .glb, or .ply file containing 3D mesh or point cloud information. Can be remote (web) URL, or a local file path. """ + _proto_type_name = 'url3d' @classmethod def validate( diff --git a/tests/units/typing/tensor/test_audio_tensor.py b/tests/units/typing/tensor/test_audio_tensor.py index caa016dbb50..e3dbf4579af 100644 --- a/tests/units/typing/tensor/test_audio_tensor.py +++ b/tests/units/typing/tensor/test_audio_tensor.py @@ -57,8 +57,8 @@ def test_illegal_validation(cls_tensor, tensor): @pytest.mark.parametrize( 'cls_tensor,tensor,proto_key', [ - (AudioTorchTensor, torch.zeros(1000, 2), AudioTorchTensor._PROTO_FIELD_NAME), - (AudioNdArray, np.zeros((1000, 2)), AudioNdArray._PROTO_FIELD_NAME), + (AudioTorchTensor, torch.zeros(1000, 2), AudioTorchTensor._proto_type_name), + (AudioNdArray, np.zeros((1000, 2)), AudioNdArray._proto_type_name), ], ) def test_proto_tensor(cls_tensor, tensor, proto_key): From f05da7720e006647ab5b59159114761854544628 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Mon, 16 Jan 2023 17:21:13 +0100 Subject: [PATCH 05/16] wip4 Signed-off-by: Sami Jaghouar --- docarray/base_document/mixins/proto.py | 33 +++++---- docarray/typing/url/any_url.py | 2 +- docarray/typing/url/text_url.py | 2 +- docarray/typing/url/url_3d/mesh_url.py | 2 +- tests/integrations/document/test_proto.py | 82 +++++++++++------------ 5 files changed, 63 insertions(+), 58 deletions(-) diff --git a/docarray/base_document/mixins/proto.py b/docarray/base_document/mixins/proto.py index e3cde6e7815..081d770319a 100644 --- a/docarray/base_document/mixins/proto.py +++ b/docarray/base_document/mixins/proto.py @@ -24,7 +24,7 @@ class ProtoMixin(AbstractDocument, BaseNode): @classmethod def from_protobuf(cls: Type[T], pb_msg: 'DocumentProto') -> T: """create a Document from a protobuf message""" - from docarray.typing import ( # TorchTensor, + from docarray.typing import ( ID, AnyEmbedding, AnyUrl, @@ -35,15 +35,12 @@ def from_protobuf(cls: Type[T], pb_msg: 'DocumentProto') -> T: TextUrl, ) + fields: Dict[str, Any] = {} for field in pb_msg.data: value = pb_msg.data[field] - content_type = value.WhichOneof('content') - - # this if else statement need to be refactored it is too long - # the check should be delegated to the type level content_type_dict = dict( ndarray=NdArray, embedding=AnyEmbedding, @@ -55,27 +52,35 @@ def from_protobuf(cls: Type[T], pb_msg: 'DocumentProto') -> T: id=ID, ) + content_key = value.WhichOneof('content') + content_type = ( + value.type if value.WhichOneof('docarray_type') is not None else None + ) + if torch_imported: - content_type_dict['torch_tensor'] = TorchTensor + content_type_dict['torch'] = TorchTensor if content_type in content_type_dict: fields[field] = content_type_dict[content_type].from_protobuf( - getattr(value, content_type) + getattr(value, content_key) ) - elif content_type == 'text': - fields[field] = value.text - elif content_type == 'nested': + elif content_key == 'document': fields[field] = cls._get_field_type(field).from_protobuf( - value.nested + value.document ) # we get to the parent class - elif content_type == 'chunks': + elif content_key == 'document_array': from docarray import DocumentArray fields[field] = DocumentArray.from_protobuf( - value.chunks + value.document_array ) # we get to the parent class - elif content_type is None: + elif content_key is None: fields[field] = None + elif content_type is None: + if content_key == 'text': + fields[field] = value.text + elif content_key == 'blob': + fields[field] = value.blob else: raise ValueError( f'type {content_type} is not supported for deserialization' diff --git a/docarray/typing/url/any_url.py b/docarray/typing/url/any_url.py index 9eba9b9537c..485797b165e 100644 --- a/docarray/typing/url/any_url.py +++ b/docarray/typing/url/any_url.py @@ -17,7 +17,7 @@ class AnyUrl(BaseAnyUrl, AbstractType): host_required = ( False # turn off host requirement to allow passing of local paths as URL ) - _proto_type_name = 'url' + _proto_type_name = 'any_url' def _to_node_protobuf(self) -> 'NodeProto': """Convert Document into a NodeProto protobuf message. This function should diff --git a/docarray/typing/url/text_url.py b/docarray/typing/url/text_url.py index 6f0c5ebebda..58498682ee2 100644 --- a/docarray/typing/url/text_url.py +++ b/docarray/typing/url/text_url.py @@ -13,7 +13,7 @@ class TextUrl(AnyUrl): Can be remote (web) URL, or a local file path. """ - _proto_type_name = 'texturl' + _proto_type_name = 'text_url' def load_to_bytes(self, timeout: Optional[float] = None) -> bytes: """ diff --git a/docarray/typing/url/url_3d/mesh_url.py b/docarray/typing/url/url_3d/mesh_url.py index 040e7f3c28e..157edd8a6ac 100644 --- a/docarray/typing/url/url_3d/mesh_url.py +++ b/docarray/typing/url/url_3d/mesh_url.py @@ -15,7 +15,7 @@ class Mesh3DUrl(Url3D): URL to a .obj, .glb, or .ply file containing 3D mesh information. Can be remote (web) URL, or a local file path. """ - _proto_type_name = 'mesh3durl' + _proto_type_name = 'mesh_url' def load(self: T) -> Tuple[np.ndarray, np.ndarray]: """ diff --git a/tests/integrations/document/test_proto.py b/tests/integrations/document/test_proto.py index e405ca4b1b2..7e1fda5624d 100644 --- a/tests/integrations/document/test_proto.py +++ b/tests/integrations/document/test_proto.py @@ -50,10 +50,10 @@ class MyDoc(BaseDocument): np_array_param: NdArray[224, 224, 3] generic_nd_array: AnyTensor generic_torch_tensor: AnyTensor - embedding: AnyEmbedding - torch_embedding: TorchEmbedding[128] - np_embedding: NdArrayEmbedding[128] - nested_docs: DocumentArray[NestedDoc] + # embedding: AnyEmbedding + # torch_embedding: TorchEmbedding[128] + # np_embedding: NdArrayEmbedding[128] + # nested_docs: DocumentArray[NestedDoc] doc = MyDoc( img_url='test.png', @@ -67,42 +67,42 @@ class MyDoc(BaseDocument): np_array_param=np.zeros((3, 224, 224)), generic_nd_array=np.zeros((3, 224, 224)), generic_torch_tensor=torch.zeros((3, 224, 224)), - embedding=np.zeros((3, 224, 224)), - torch_embedding=torch.zeros((128,)), - np_embedding=np.zeros((128,)), - nested_docs=DocumentArray[NestedDoc]([NestedDoc(tensor=np.zeros((128,)))]), + # embedding=np.zeros((3, 224, 224)), + # torch_embedding=torch.zeros((128,)), + # np_embedding=np.zeros((128,)), + # nested_docs=DocumentArray[NestedDoc]([NestedDoc(tensor=np.zeros((128,)))]), ) doc = MyDoc.from_protobuf(doc.to_protobuf()) - - assert doc.img_url == 'test.png' - assert doc.txt_url == 'test.txt' - assert doc.mesh_url == 'test.obj' - assert doc.point_cloud_url == 'test.obj' - assert doc.any_url == 'www.jina.ai' - - assert (doc.torch_tensor == torch.zeros((3, 224, 224))).all() - assert isinstance(doc.torch_tensor, torch.Tensor) - - assert (doc.torch_tensor_param == torch.zeros((224, 224, 3))).all() - assert isinstance(doc.torch_tensor_param, torch.Tensor) - - assert (doc.np_array == np.zeros((3, 224, 224))).all() - assert isinstance(doc.np_array, np.ndarray) - assert doc.np_array.flags.writeable - - assert (doc.np_array_param == np.zeros((224, 224, 3))).all() - assert isinstance(doc.np_array_param, np.ndarray) - - assert (doc.generic_nd_array == np.zeros((3, 224, 224))).all() - assert isinstance(doc.generic_nd_array, np.ndarray) - - assert (doc.generic_torch_tensor == torch.zeros((3, 224, 224))).all() - assert isinstance(doc.generic_torch_tensor, torch.Tensor) - - assert (doc.torch_embedding == torch.zeros((128,))).all() - assert isinstance(doc.torch_embedding, torch.Tensor) - - assert (doc.np_embedding == np.zeros((128,))).all() - assert isinstance(doc.np_embedding, np.ndarray) - - assert (doc.embedding == np.zeros((3, 224, 224))).all() + # + # assert doc.img_url == 'test.png' + # assert doc.txt_url == 'test.txt' + # assert doc.mesh_url == 'test.obj' + # assert doc.point_cloud_url == 'test.obj' + # assert doc.any_url == 'www.jina.ai' + # + # assert (doc.torch_tensor == torch.zeros((3, 224, 224))).all() + # assert isinstance(doc.torch_tensor, torch.Tensor) + # + # assert (doc.torch_tensor_param == torch.zeros((224, 224, 3))).all() + # assert isinstance(doc.torch_tensor_param, torch.Tensor) + # + # assert (doc.np_array == np.zeros((3, 224, 224))).all() + # assert isinstance(doc.np_array, np.ndarray) + # assert doc.np_array.flags.writeable + # + # assert (doc.np_array_param == np.zeros((224, 224, 3))).all() + # assert isinstance(doc.np_array_param, np.ndarray) + # + # assert (doc.generic_nd_array == np.zeros((3, 224, 224))).all() + # assert isinstance(doc.generic_nd_array, np.ndarray) + # + # assert (doc.generic_torch_tensor == torch.zeros((3, 224, 224))).all() + # assert isinstance(doc.generic_torch_tensor, torch.Tensor) + # + # assert (doc.torch_embedding == torch.zeros((128,))).all() + # assert isinstance(doc.torch_embedding, torch.Tensor) + # + # assert (doc.np_embedding == np.zeros((128,))).all() + # assert isinstance(doc.np_embedding, np.ndarray) + # + # assert (doc.embedding == np.zeros((3, 224, 224))).all() From 9754c7d5b9784520d475610fe58a4ec4f8f7ea87 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Mon, 16 Jan 2023 17:31:22 +0100 Subject: [PATCH 06/16] wip5 Signed-off-by: Sami Jaghouar --- docarray/base_document/mixins/proto.py | 9 ++- docarray/typing/__init__.py | 3 +- tests/integrations/document/test_proto.py | 78 +++++++++++------------ 3 files changed, 47 insertions(+), 43 deletions(-) diff --git a/docarray/base_document/mixins/proto.py b/docarray/base_document/mixins/proto.py index 081d770319a..4ece7c38e49 100644 --- a/docarray/base_document/mixins/proto.py +++ b/docarray/base_document/mixins/proto.py @@ -12,8 +12,6 @@ except ImportError: torch_imported = False else: - from docarray.typing.tensor.torch_tensor import TorchTensor - torch_imported = True @@ -33,9 +31,9 @@ def from_protobuf(cls: Type[T], pb_msg: 'DocumentProto') -> T: NdArray, PointCloud3DUrl, TextUrl, + NdArrayEmbedding, ) - fields: Dict[str, Any] = {} for field in pb_msg.data: @@ -44,6 +42,7 @@ def from_protobuf(cls: Type[T], pb_msg: 'DocumentProto') -> T: content_type_dict = dict( ndarray=NdArray, embedding=AnyEmbedding, + ndarray_embedding=NdArrayEmbedding, any_url=AnyUrl, text_url=TextUrl, image_url=ImageUrl, @@ -58,7 +57,11 @@ def from_protobuf(cls: Type[T], pb_msg: 'DocumentProto') -> T: ) if torch_imported: + from docarray.typing.tensor.torch_tensor import TorchTensor + from docarray.typing import TorchEmbedding + content_type_dict['torch'] = TorchTensor + content_type_dict['torch_embedding'] = TorchEmbedding if content_type in content_type_dict: fields[field] = content_type_dict[content_type].from_protobuf( diff --git a/docarray/typing/__init__.py b/docarray/typing/__init__.py index 61315b082b5..b22ec80b1e8 100644 --- a/docarray/typing/__init__.py +++ b/docarray/typing/__init__.py @@ -1,6 +1,6 @@ from docarray.typing.id import ID from docarray.typing.tensor.audio import AudioNdArray -from docarray.typing.tensor.embedding.embedding import AnyEmbedding +from docarray.typing.tensor.embedding.embedding import AnyEmbedding, NdArrayEmbedding from docarray.typing.tensor.ndarray import NdArray from docarray.typing.tensor.tensor import AnyTensor from docarray.typing.url import ( @@ -15,6 +15,7 @@ __all__ = [ 'AudioNdArray', 'NdArray', + 'NdArrayEmbedding', 'AnyEmbedding', 'ImageUrl', 'AudioUrl', diff --git a/tests/integrations/document/test_proto.py b/tests/integrations/document/test_proto.py index 7e1fda5624d..97b3b87c20e 100644 --- a/tests/integrations/document/test_proto.py +++ b/tests/integrations/document/test_proto.py @@ -50,9 +50,9 @@ class MyDoc(BaseDocument): np_array_param: NdArray[224, 224, 3] generic_nd_array: AnyTensor generic_torch_tensor: AnyTensor - # embedding: AnyEmbedding - # torch_embedding: TorchEmbedding[128] - # np_embedding: NdArrayEmbedding[128] + embedding: AnyEmbedding + torch_embedding: TorchEmbedding[128] + np_embedding: NdArrayEmbedding[128] # nested_docs: DocumentArray[NestedDoc] doc = MyDoc( @@ -67,42 +67,42 @@ class MyDoc(BaseDocument): np_array_param=np.zeros((3, 224, 224)), generic_nd_array=np.zeros((3, 224, 224)), generic_torch_tensor=torch.zeros((3, 224, 224)), - # embedding=np.zeros((3, 224, 224)), - # torch_embedding=torch.zeros((128,)), - # np_embedding=np.zeros((128,)), + embedding=np.zeros((3, 224, 224)), + torch_embedding=torch.zeros((128,)), + np_embedding=np.zeros((128,)), # nested_docs=DocumentArray[NestedDoc]([NestedDoc(tensor=np.zeros((128,)))]), ) doc = MyDoc.from_protobuf(doc.to_protobuf()) - # - # assert doc.img_url == 'test.png' - # assert doc.txt_url == 'test.txt' - # assert doc.mesh_url == 'test.obj' - # assert doc.point_cloud_url == 'test.obj' - # assert doc.any_url == 'www.jina.ai' - # - # assert (doc.torch_tensor == torch.zeros((3, 224, 224))).all() - # assert isinstance(doc.torch_tensor, torch.Tensor) - # - # assert (doc.torch_tensor_param == torch.zeros((224, 224, 3))).all() - # assert isinstance(doc.torch_tensor_param, torch.Tensor) - # - # assert (doc.np_array == np.zeros((3, 224, 224))).all() - # assert isinstance(doc.np_array, np.ndarray) - # assert doc.np_array.flags.writeable - # - # assert (doc.np_array_param == np.zeros((224, 224, 3))).all() - # assert isinstance(doc.np_array_param, np.ndarray) - # - # assert (doc.generic_nd_array == np.zeros((3, 224, 224))).all() - # assert isinstance(doc.generic_nd_array, np.ndarray) - # - # assert (doc.generic_torch_tensor == torch.zeros((3, 224, 224))).all() - # assert isinstance(doc.generic_torch_tensor, torch.Tensor) - # - # assert (doc.torch_embedding == torch.zeros((128,))).all() - # assert isinstance(doc.torch_embedding, torch.Tensor) - # - # assert (doc.np_embedding == np.zeros((128,))).all() - # assert isinstance(doc.np_embedding, np.ndarray) - # - # assert (doc.embedding == np.zeros((3, 224, 224))).all() + + assert doc.img_url == 'test.png' + assert doc.txt_url == 'test.txt' + assert doc.mesh_url == 'test.obj' + assert doc.point_cloud_url == 'test.obj' + assert doc.any_url == 'www.jina.ai' + + assert (doc.torch_tensor == torch.zeros((3, 224, 224))).all() + assert isinstance(doc.torch_tensor, torch.Tensor) + + assert (doc.torch_tensor_param == torch.zeros((224, 224, 3))).all() + assert isinstance(doc.torch_tensor_param, torch.Tensor) + + assert (doc.np_array == np.zeros((3, 224, 224))).all() + assert isinstance(doc.np_array, np.ndarray) + assert doc.np_array.flags.writeable + + assert (doc.np_array_param == np.zeros((224, 224, 3))).all() + assert isinstance(doc.np_array_param, np.ndarray) + + assert (doc.generic_nd_array == np.zeros((3, 224, 224))).all() + assert isinstance(doc.generic_nd_array, np.ndarray) + + assert (doc.generic_torch_tensor == torch.zeros((3, 224, 224))).all() + assert isinstance(doc.generic_torch_tensor, torch.Tensor) + + assert (doc.torch_embedding == torch.zeros((128,))).all() + assert isinstance(doc.torch_embedding, torch.Tensor) + + assert (doc.np_embedding == np.zeros((128,))).all() + assert isinstance(doc.np_embedding, np.ndarray) + + assert (doc.embedding == np.zeros((3, 224, 224))).all() From 4db211b00191ffdc84592c856c20ac12fd995ea3 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Mon, 16 Jan 2023 17:39:01 +0100 Subject: [PATCH 07/16] wip5 Signed-off-by: Sami Jaghouar --- docarray/array/abstract_array.py | 2 +- tests/integrations/document/test_proto.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/docarray/array/abstract_array.py b/docarray/array/abstract_array.py index 1e09b7e4417..4482a553989 100644 --- a/docarray/array/abstract_array.py +++ b/docarray/array/abstract_array.py @@ -92,7 +92,7 @@ def _to_node_protobuf(self) -> 'NodeProto': """ from docarray.proto import NodeProto - return NodeProto(document=self.to_protobuf()) + return NodeProto(document_array=self.to_protobuf()) @abstractmethod def traverse_flat( diff --git a/tests/integrations/document/test_proto.py b/tests/integrations/document/test_proto.py index 97b3b87c20e..adabe6deb84 100644 --- a/tests/integrations/document/test_proto.py +++ b/tests/integrations/document/test_proto.py @@ -53,7 +53,7 @@ class MyDoc(BaseDocument): embedding: AnyEmbedding torch_embedding: TorchEmbedding[128] np_embedding: NdArrayEmbedding[128] - # nested_docs: DocumentArray[NestedDoc] + nested_docs: DocumentArray[NestedDoc] doc = MyDoc( img_url='test.png', @@ -70,9 +70,10 @@ class MyDoc(BaseDocument): embedding=np.zeros((3, 224, 224)), torch_embedding=torch.zeros((128,)), np_embedding=np.zeros((128,)), - # nested_docs=DocumentArray[NestedDoc]([NestedDoc(tensor=np.zeros((128,)))]), + nested_docs=DocumentArray[NestedDoc]([NestedDoc(tensor=np.zeros((128,)))]), ) - doc = MyDoc.from_protobuf(doc.to_protobuf()) + doc = doc.to_protobuf() + doc = MyDoc.from_protobuf(doc) assert doc.img_url == 'test.png' assert doc.txt_url == 'test.txt' From f06e0f85a57e4e7309eb6c45f8bdd26aca9f33d5 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Tue, 17 Jan 2023 11:04:15 +0100 Subject: [PATCH 08/16] refactor: use register proto decorator Signed-off-by: Sami Jaghouar --- docarray/base_document/mixins/proto.py | 29 ++++--------------- docarray/typing/id.py | 5 ++-- docarray/typing/proto_register.py | 22 ++++++++++++++ docarray/typing/tensor/audio/audio_ndarray.py | 4 +-- .../typing/tensor/audio/audio_torch_tensor.py | 5 ++-- docarray/typing/tensor/embedding/ndarray.py | 4 +-- docarray/typing/tensor/embedding/torch.py | 4 +-- docarray/typing/tensor/ndarray.py | 4 +-- docarray/typing/tensor/torch_tensor.py | 4 +-- docarray/typing/url/any_url.py | 4 +-- docarray/typing/url/audio_url.py | 5 ++-- docarray/typing/url/image_url.py | 5 ++-- docarray/typing/url/text_url.py | 6 ++-- docarray/typing/url/url_3d/mesh_url.py | 10 ++----- docarray/typing/url/url_3d/point_cloud_url.py | 8 ++--- docarray/typing/url/url_3d/url_3d.py | 5 ++-- 16 files changed, 59 insertions(+), 65 deletions(-) create mode 100644 docarray/typing/proto_register.py diff --git a/docarray/base_document/mixins/proto.py b/docarray/base_document/mixins/proto.py index 4ece7c38e49..cb71d964c1f 100644 --- a/docarray/base_document/mixins/proto.py +++ b/docarray/base_document/mixins/proto.py @@ -2,6 +2,7 @@ from docarray.base_document.abstract_document import AbstractDocument from docarray.base_document.base_node import BaseNode +from docarray.typing.proto_register import _PROTO_TYPE_NAME_TO_CLASS if TYPE_CHECKING: from docarray.proto import DocumentProto, NodeProto @@ -22,34 +23,13 @@ class ProtoMixin(AbstractDocument, BaseNode): @classmethod def from_protobuf(cls: Type[T], pb_msg: 'DocumentProto') -> T: """create a Document from a protobuf message""" - from docarray.typing import ( - ID, - AnyEmbedding, - AnyUrl, - ImageUrl, - Mesh3DUrl, - NdArray, - PointCloud3DUrl, - TextUrl, - NdArrayEmbedding, - ) fields: Dict[str, Any] = {} for field in pb_msg.data: value = pb_msg.data[field] - content_type_dict = dict( - ndarray=NdArray, - embedding=AnyEmbedding, - ndarray_embedding=NdArrayEmbedding, - any_url=AnyUrl, - text_url=TextUrl, - image_url=ImageUrl, - mesh_url=Mesh3DUrl, - point_cloud_url=PointCloud3DUrl, - id=ID, - ) + content_type_dict = _PROTO_TYPE_NAME_TO_CLASS content_key = value.WhichOneof('content') content_type = ( @@ -57,8 +37,8 @@ def from_protobuf(cls: Type[T], pb_msg: 'DocumentProto') -> T: ) if torch_imported: - from docarray.typing.tensor.torch_tensor import TorchTensor from docarray.typing import TorchEmbedding + from docarray.typing.tensor.torch_tensor import TorchTensor content_type_dict['torch'] = TorchTensor content_type_dict['torch_embedding'] = TorchEmbedding @@ -86,7 +66,8 @@ def from_protobuf(cls: Type[T], pb_msg: 'DocumentProto') -> T: fields[field] = value.blob else: raise ValueError( - f'type {content_type} is not supported for deserialization' + f'type {content_type}, with key {content_key} is not supported for' + f' deserialization' ) return cls.construct(**fields) diff --git a/docarray/typing/id.py b/docarray/typing/id.py index e64f5a60234..a50b4bcf245 100644 --- a/docarray/typing/id.py +++ b/docarray/typing/id.py @@ -4,6 +4,8 @@ from pydantic import BaseConfig, parse_obj_as from pydantic.fields import ModelField +from docarray.typing.proto_register import register_proto + if TYPE_CHECKING: from docarray.proto import NodeProto @@ -12,13 +14,12 @@ T = TypeVar('T', bound='ID') +@register_proto(proto_type_name='id') class ID(str, AbstractType): """ Represent an unique ID """ - _proto_type_name = 'id' - @classmethod def __get_validators__(cls): yield cls.validate diff --git a/docarray/typing/proto_register.py b/docarray/typing/proto_register.py new file mode 100644 index 00000000000..a8e655809f5 --- /dev/null +++ b/docarray/typing/proto_register.py @@ -0,0 +1,22 @@ +from typing import Callable +from docarray.typing.abstract_type import AbstractType +from typing import Type + + +_PROTO_TYPE_NAME_TO_CLASS = {} + + +def register_proto( + proto_type_name: str, +) -> Callable[[Type[AbstractType]], Type[AbstractType]]: + """Register a new type to be used in the protobuf serialization. + :param cls: the class to register + :return: the class + """ + def _register(cls: Type['AbstractType']) -> Type['AbstractType']: + cls._proto_type_name = proto_type_name + + _PROTO_TYPE_NAME_TO_CLASS[proto_type_name] = cls + return cls + + return _register diff --git a/docarray/typing/tensor/audio/audio_ndarray.py b/docarray/typing/tensor/audio/audio_ndarray.py index ac399a6866e..43e1457b330 100644 --- a/docarray/typing/tensor/audio/audio_ndarray.py +++ b/docarray/typing/tensor/audio/audio_ndarray.py @@ -1,5 +1,6 @@ from typing import TypeVar +from docarray.typing.proto_register import register_proto from docarray.typing.tensor.audio.abstract_audio_tensor import AbstractAudioTensor from docarray.typing.tensor.ndarray import NdArray @@ -7,7 +8,7 @@ T = TypeVar('T', bound='AudioNdArray') - +@register_proto(proto_type_name='audio_ndarray') class AudioNdArray(AbstractAudioTensor, NdArray): """ Subclass of NdArray, to represent an audio tensor. @@ -51,7 +52,6 @@ class MyAudioDoc(Document): doc_2.audio_tensor.save_to_wav_file(file_path='path/to/file_2.wav') """ - _proto_type_name = 'audio_ndarray' def to_audio_bytes(self): tensor = (self * MAX_INT_16).astype(' 'NodeProto': """Convert Document into a NodeProto protobuf message. This function should diff --git a/docarray/typing/url/audio_url.py b/docarray/typing/url/audio_url.py index 378a7a2bc3d..3dbb7d9985a 100644 --- a/docarray/typing/url/audio_url.py +++ b/docarray/typing/url/audio_url.py @@ -4,6 +4,7 @@ import numpy as np from pydantic import parse_obj_as +from docarray.typing.proto_register import register_proto from docarray.typing.tensor.audio.audio_ndarray import MAX_INT_16, AudioNdArray from docarray.typing.url.any_url import AnyUrl @@ -17,15 +18,13 @@ AUDIO_FILE_FORMATS = ['wav'] - +@register_proto(proto_type_name='audio_url') class AudioUrl(AnyUrl): """ URL to a .wav file. Can be remote (web) URL, or a local file path. """ - _proto_type_name = 'audio_url' - @classmethod def validate( cls: Type[T], diff --git a/docarray/typing/url/image_url.py b/docarray/typing/url/image_url.py index bb9dc4c91ed..0a8afa59a8c 100644 --- a/docarray/typing/url/image_url.py +++ b/docarray/typing/url/image_url.py @@ -4,6 +4,7 @@ import numpy as np +from docarray.typing.proto_register import register_proto from docarray.typing.url.any_url import AnyUrl from docarray.typing.url.helper import _uri_to_blob @@ -18,15 +19,13 @@ IMAGE_FILE_FORMATS = ('png', 'jpeg', 'jpg') - +@register_proto(proto_type_name='image_url') class ImageUrl(AnyUrl): """ URL to a .png, .jpeg, or .jpg file. Can be remote (web) URL, or a local file path. """ - _proto_type_name = 'image_url' - @classmethod def validate( cls: Type[T], diff --git a/docarray/typing/url/text_url.py b/docarray/typing/url/text_url.py index 58498682ee2..81ebfdaa1dc 100644 --- a/docarray/typing/url/text_url.py +++ b/docarray/typing/url/text_url.py @@ -1,20 +1,20 @@ from typing import TYPE_CHECKING, Optional +from docarray.typing.proto_register import register_proto + if TYPE_CHECKING: from docarray.proto import NodeProto from docarray.typing.url.any_url import AnyUrl from docarray.typing.url.helper import _uri_to_blob - +@register_proto(proto_type_name='text_url') class TextUrl(AnyUrl): """ URL to a text file. Can be remote (web) URL, or a local file path. """ - _proto_type_name = 'text_url' - def load_to_bytes(self, timeout: Optional[float] = None) -> bytes: """ Load the text file into a bytes object. diff --git a/docarray/typing/url/url_3d/mesh_url.py b/docarray/typing/url/url_3d/mesh_url.py index 157edd8a6ac..d62db2f3026 100644 --- a/docarray/typing/url/url_3d/mesh_url.py +++ b/docarray/typing/url/url_3d/mesh_url.py @@ -1,22 +1,18 @@ -from typing import TYPE_CHECKING, Tuple, TypeVar +from typing import Tuple, TypeVar import numpy as np +from docarray.typing.proto_register import register_proto from docarray.typing.url.url_3d.url_3d import Url3D -if TYPE_CHECKING: - from docarray.proto import NodeProto - T = TypeVar('T', bound='Mesh3DUrl') - +@register_proto(proto_type_name='mesh_url') class Mesh3DUrl(Url3D): """ URL to a .obj, .glb, or .ply file containing 3D mesh information. Can be remote (web) URL, or a local file path. """ - _proto_type_name = 'mesh_url' - def load(self: T) -> Tuple[np.ndarray, np.ndarray]: """ Load the data from the url into a tuple of two numpy.ndarrays containing diff --git a/docarray/typing/url/url_3d/point_cloud_url.py b/docarray/typing/url/url_3d/point_cloud_url.py index f7ddf4e60d7..ddf3dd8b3ce 100644 --- a/docarray/typing/url/url_3d/point_cloud_url.py +++ b/docarray/typing/url/url_3d/point_cloud_url.py @@ -1,21 +1,19 @@ -from typing import TYPE_CHECKING, TypeVar +from typing import TypeVar import numpy as np +from docarray.typing.proto_register import register_proto from docarray.typing.url.url_3d.url_3d import Url3D -if TYPE_CHECKING: - from docarray.proto import NodeProto - T = TypeVar('T', bound='PointCloud3DUrl') +@register_proto(proto_type_name='point_cloud_url') class PointCloud3DUrl(Url3D): """ URL to a .obj, .glb, or .ply file containing point cloud information. Can be remote (web) URL, or a local file path. """ - _proto_type_name = 'point_cloud_url' def load(self: T, samples: int, multiple_geometries: bool = False) -> np.ndarray: """ diff --git a/docarray/typing/url/url_3d/url_3d.py b/docarray/typing/url/url_3d/url_3d.py index 4a65b568b85..5a8100a363d 100644 --- a/docarray/typing/url/url_3d/url_3d.py +++ b/docarray/typing/url/url_3d/url_3d.py @@ -3,6 +3,7 @@ import numpy as np +from docarray.typing.proto_register import register_proto from docarray.typing.url.any_url import AnyUrl if TYPE_CHECKING: @@ -14,14 +15,12 @@ T = TypeVar('T', bound='Url3D') - +@register_proto(proto_type_name='url3d') class Url3D(AnyUrl, ABC): """ URL to a .obj, .glb, or .ply file containing 3D mesh or point cloud information. Can be remote (web) URL, or a local file path. """ - _proto_type_name = 'url3d' - @classmethod def validate( cls: Type[T], From 0f94ee09717d6c2435631a8b7bea08af54867040 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Wed, 18 Jan 2023 13:37:19 +0100 Subject: [PATCH 09/16] fix: fix test Signed-off-by: Sami Jaghouar --- tests/units/typing/tensor/test_audio_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/units/typing/tensor/test_audio_tensor.py b/tests/units/typing/tensor/test_audio_tensor.py index e3dbf4579af..9a2bf64fa69 100644 --- a/tests/units/typing/tensor/test_audio_tensor.py +++ b/tests/units/typing/tensor/test_audio_tensor.py @@ -64,7 +64,7 @@ def test_illegal_validation(cls_tensor, tensor): def test_proto_tensor(cls_tensor, tensor, proto_key): tensor = parse_obj_as(cls_tensor, tensor) proto = tensor._to_node_protobuf() - assert str(proto).startswith(proto_key) + assert proto_key in str(proto) @pytest.mark.parametrize( From 496c6dc20d0568e63ed8e10444a6b0180cf93821 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Wed, 18 Jan 2023 13:56:35 +0100 Subject: [PATCH 10/16] fix: ruff and black f Signed-off-by: Sami Jaghouar --- docarray/typing/abstract_type.py | 4 ++-- docarray/typing/proto_register.py | 6 +++--- docarray/typing/tensor/audio/audio_ndarray.py | 1 + docarray/typing/tensor/audio/audio_torch_tensor.py | 1 + docarray/typing/tensor/embedding/ndarray.py | 2 +- docarray/typing/tensor/embedding/torch.py | 2 +- docarray/typing/url/any_url.py | 1 + docarray/typing/url/image_url.py | 3 +-- docarray/typing/url/text_url.py | 7 ++----- docarray/typing/url/url_3d/mesh_url.py | 2 ++ docarray/typing/url/url_3d/url_3d.py | 2 ++ 11 files changed, 17 insertions(+), 14 deletions(-) diff --git a/docarray/typing/abstract_type.py b/docarray/typing/abstract_type.py index 0065c700719..97bf9a2c1a2 100644 --- a/docarray/typing/abstract_type.py +++ b/docarray/typing/abstract_type.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import TYPE_CHECKING, Any, Type, TypeVar, Optional +from typing import TYPE_CHECKING, Any, Type, TypeVar from pydantic import BaseConfig from pydantic.fields import ModelField @@ -13,7 +13,7 @@ class AbstractType(BaseNode): - _proto_type_name : str + _proto_type_name: str @classmethod def __get_validators__(cls): diff --git a/docarray/typing/proto_register.py b/docarray/typing/proto_register.py index a8e655809f5..a4e91c2796b 100644 --- a/docarray/typing/proto_register.py +++ b/docarray/typing/proto_register.py @@ -1,7 +1,6 @@ -from typing import Callable -from docarray.typing.abstract_type import AbstractType -from typing import Type +from typing import Callable, Type +from docarray.typing.abstract_type import AbstractType _PROTO_TYPE_NAME_TO_CLASS = {} @@ -13,6 +12,7 @@ def register_proto( :param cls: the class to register :return: the class """ + def _register(cls: Type['AbstractType']) -> Type['AbstractType']: cls._proto_type_name = proto_type_name diff --git a/docarray/typing/tensor/audio/audio_ndarray.py b/docarray/typing/tensor/audio/audio_ndarray.py index 43e1457b330..65026080da5 100644 --- a/docarray/typing/tensor/audio/audio_ndarray.py +++ b/docarray/typing/tensor/audio/audio_ndarray.py @@ -8,6 +8,7 @@ T = TypeVar('T', bound='AudioNdArray') + @register_proto(proto_type_name='audio_ndarray') class AudioNdArray(AbstractAudioTensor, NdArray): """ diff --git a/docarray/typing/tensor/audio/audio_torch_tensor.py b/docarray/typing/tensor/audio/audio_torch_tensor.py index f58a41c9758..b3494c5dee0 100644 --- a/docarray/typing/tensor/audio/audio_torch_tensor.py +++ b/docarray/typing/tensor/audio/audio_torch_tensor.py @@ -7,6 +7,7 @@ T = TypeVar('T', bound='AudioTorchTensor') + @register_proto(proto_type_name='audio_torch_tensor') class AudioTorchTensor(AbstractAudioTensor, TorchTensor, metaclass=metaTorchAndNode): """ diff --git a/docarray/typing/tensor/embedding/ndarray.py b/docarray/typing/tensor/embedding/ndarray.py index 34a6958ce85..24e59579a39 100644 --- a/docarray/typing/tensor/embedding/ndarray.py +++ b/docarray/typing/tensor/embedding/ndarray.py @@ -2,7 +2,7 @@ from docarray.typing.tensor.embedding.embedding_mixin import EmbeddingMixin from docarray.typing.tensor.ndarray import NdArray + @register_proto(proto_type_name='ndarray_embedding') class NdArrayEmbedding(NdArray, EmbeddingMixin): alternative_type = NdArray - diff --git a/docarray/typing/tensor/embedding/torch.py b/docarray/typing/tensor/embedding/torch.py index 9cdba801827..3ff6ffb0559 100644 --- a/docarray/typing/tensor/embedding/torch.py +++ b/docarray/typing/tensor/embedding/torch.py @@ -11,7 +11,7 @@ class metaTorchAndEmbedding(torch_base, embedding_base): pass + @register_proto(proto_type_name='torch_embedding') class TorchEmbedding(TorchTensor, EmbeddingMixin, metaclass=metaTorchAndEmbedding): alternative_type = TorchTensor - diff --git a/docarray/typing/url/any_url.py b/docarray/typing/url/any_url.py index 00191dfd90e..11e416f2f95 100644 --- a/docarray/typing/url/any_url.py +++ b/docarray/typing/url/any_url.py @@ -13,6 +13,7 @@ T = TypeVar('T', bound='AnyUrl') + @register_proto(proto_type_name='any_url') class AnyUrl(BaseAnyUrl, AbstractType): host_required = ( diff --git a/docarray/typing/url/image_url.py b/docarray/typing/url/image_url.py index 0a8afa59a8c..36183908527 100644 --- a/docarray/typing/url/image_url.py +++ b/docarray/typing/url/image_url.py @@ -13,12 +13,11 @@ from pydantic import BaseConfig from pydantic.fields import ModelField - from docarray.proto import NodeProto - T = TypeVar('T', bound='ImageUrl') IMAGE_FILE_FORMATS = ('png', 'jpeg', 'jpg') + @register_proto(proto_type_name='image_url') class ImageUrl(AnyUrl): """ diff --git a/docarray/typing/url/text_url.py b/docarray/typing/url/text_url.py index 81ebfdaa1dc..3bad471618f 100644 --- a/docarray/typing/url/text_url.py +++ b/docarray/typing/url/text_url.py @@ -1,13 +1,10 @@ -from typing import TYPE_CHECKING, Optional +from typing import Optional from docarray.typing.proto_register import register_proto - -if TYPE_CHECKING: - from docarray.proto import NodeProto - from docarray.typing.url.any_url import AnyUrl from docarray.typing.url.helper import _uri_to_blob + @register_proto(proto_type_name='text_url') class TextUrl(AnyUrl): """ diff --git a/docarray/typing/url/url_3d/mesh_url.py b/docarray/typing/url/url_3d/mesh_url.py index d62db2f3026..a94d9808a70 100644 --- a/docarray/typing/url/url_3d/mesh_url.py +++ b/docarray/typing/url/url_3d/mesh_url.py @@ -7,12 +7,14 @@ T = TypeVar('T', bound='Mesh3DUrl') + @register_proto(proto_type_name='mesh_url') class Mesh3DUrl(Url3D): """ URL to a .obj, .glb, or .ply file containing 3D mesh information. Can be remote (web) URL, or a local file path. """ + def load(self: T) -> Tuple[np.ndarray, np.ndarray]: """ Load the data from the url into a tuple of two numpy.ndarrays containing diff --git a/docarray/typing/url/url_3d/url_3d.py b/docarray/typing/url/url_3d/url_3d.py index 5a8100a363d..9219a017750 100644 --- a/docarray/typing/url/url_3d/url_3d.py +++ b/docarray/typing/url/url_3d/url_3d.py @@ -15,12 +15,14 @@ T = TypeVar('T', bound='Url3D') + @register_proto(proto_type_name='url3d') class Url3D(AnyUrl, ABC): """ URL to a .obj, .glb, or .ply file containing 3D mesh or point cloud information. Can be remote (web) URL, or a local file path. """ + @classmethod def validate( cls: Type[T], From a1b3f6203af331718fb98e26fb8b7d257ec0a99e Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Thu, 19 Jan 2023 09:49:34 +0100 Subject: [PATCH 11/16] fix: fix tests --- docarray/typing/url/video_url.py | 14 ++------------ tests/units/typing/url/test_audio_url.py | 2 +- tests/units/typing/url/test_video_url.py | 2 +- 3 files changed, 4 insertions(+), 14 deletions(-) diff --git a/docarray/typing/url/video_url.py b/docarray/typing/url/video_url.py index fff2dda5d18..ad56a805c7d 100644 --- a/docarray/typing/url/video_url.py +++ b/docarray/typing/url/video_url.py @@ -4,6 +4,7 @@ from pydantic.tools import parse_obj_as from docarray.typing import AudioNdArray, NdArray +from docarray.typing.proto_register import register_proto from docarray.typing.tensor.video import VideoNdArray from docarray.typing.url.any_url import AnyUrl @@ -11,29 +12,18 @@ from pydantic import BaseConfig from pydantic.fields import ModelField - from docarray.proto import NodeProto - T = TypeVar('T', bound='VideoUrl') VIDEO_FILE_FORMATS = ['mp4'] +@register_proto(proto_type_name='video_url') class VideoUrl(AnyUrl): """ URL to a .wav file. Can be remote (web) URL, or a local file path. """ - def _to_node_protobuf(self: T) -> 'NodeProto': - """Convert Document into a NodeProto protobuf message. This function should - be called when the Document is nested into another Document that needs to - be converted into a protobuf - :return: the nested item protobuf message - """ - from docarray.proto import NodeProto - - return NodeProto(video_url=str(self)) - @classmethod def validate( cls: Type[T], diff --git a/tests/units/typing/url/test_audio_url.py b/tests/units/typing/url/test_audio_url.py index 5d0f042fd71..21280aa8eb9 100644 --- a/tests/units/typing/url/test_audio_url.py +++ b/tests/units/typing/url/test_audio_url.py @@ -105,4 +105,4 @@ def test_illegal_validation(path_to_file): def test_proto_audio_url(file_url): uri = parse_obj_as(AudioUrl, file_url) proto = uri._to_node_protobuf() - assert str(proto).startswith('audio_url') + assert 'audio_url' in str(proto) diff --git a/tests/units/typing/url/test_video_url.py b/tests/units/typing/url/test_video_url.py index 02ae5119a59..bfdd59b8de6 100644 --- a/tests/units/typing/url/test_video_url.py +++ b/tests/units/typing/url/test_video_url.py @@ -115,4 +115,4 @@ def test_illegal_validation(path_to_file): def test_proto_video_url(file_url): uri = parse_obj_as(VideoUrl, file_url) proto = uri._to_node_protobuf() - assert str(proto).startswith('video_url') + assert 'video_url' in str(proto) From 17a0008a17872f6d137efa30c7db38efe836b606 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Thu, 19 Jan 2023 09:59:39 +0100 Subject: [PATCH 12/16] docs: update docstring --- docarray/typing/proto_register.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/docarray/typing/proto_register.py b/docarray/typing/proto_register.py index a4e91c2796b..9945b64f39b 100644 --- a/docarray/typing/proto_register.py +++ b/docarray/typing/proto_register.py @@ -9,6 +9,22 @@ def register_proto( proto_type_name: str, ) -> Callable[[Type[AbstractType]], Type[AbstractType]]: """Register a new type to be used in the protobuf serialization. + + This will add the type key to the global registry of types key used in the proto + serialization and deserialization. This is for internal usage only. + + EXAMPLE USAGE + + .. code-block:: python + + from docarray.typing.proto_register import register_proto + from docarray.typing.abstract_type import AbstractType + + + @register_proto(proto_type_name='my_type') + class MyType(AbstractType): + ... + :param cls: the class to register :return: the class """ From 721ffc2a51e7877aeebcaae4dd4bac84d726152a Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Thu, 19 Jan 2023 15:54:24 +0100 Subject: [PATCH 13/16] feat: make key in proto unique and raise error if duplication --- docarray/typing/proto_register.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docarray/typing/proto_register.py b/docarray/typing/proto_register.py index 9945b64f39b..0e12669239d 100644 --- a/docarray/typing/proto_register.py +++ b/docarray/typing/proto_register.py @@ -29,6 +29,11 @@ class MyType(AbstractType): :return: the class """ + if proto_type_name in _PROTO_TYPE_NAME_TO_CLASS.keys(): + raise ValueError( + f'the key {proto_type_name} is already registered in the global registry' + ) + def _register(cls: Type['AbstractType']) -> Type['AbstractType']: cls._proto_type_name = proto_type_name From f8135dfa476a77b7e69cce5b6c711888e0f02cea Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Thu, 19 Jan 2023 15:58:56 +0100 Subject: [PATCH 14/16] refactor: make register function private --- docarray/typing/id.py | 4 ++-- docarray/typing/proto_register.py | 6 +++--- docarray/typing/tensor/audio/audio_ndarray.py | 4 ++-- docarray/typing/tensor/audio/audio_torch_tensor.py | 4 ++-- docarray/typing/tensor/embedding/ndarray.py | 4 ++-- docarray/typing/tensor/embedding/torch.py | 4 ++-- docarray/typing/tensor/ndarray.py | 4 ++-- docarray/typing/tensor/torch_tensor.py | 4 ++-- docarray/typing/tensor/video/video_ndarray.py | 4 ++-- docarray/typing/tensor/video/video_torch_tensor.py | 4 ++-- docarray/typing/url/any_url.py | 4 ++-- docarray/typing/url/audio_url.py | 4 ++-- docarray/typing/url/image_url.py | 4 ++-- docarray/typing/url/text_url.py | 4 ++-- docarray/typing/url/url_3d/mesh_url.py | 4 ++-- docarray/typing/url/url_3d/point_cloud_url.py | 4 ++-- docarray/typing/url/url_3d/url_3d.py | 4 ++-- docarray/typing/url/video_url.py | 4 ++-- 18 files changed, 37 insertions(+), 37 deletions(-) diff --git a/docarray/typing/id.py b/docarray/typing/id.py index a50b4bcf245..02ee39245b6 100644 --- a/docarray/typing/id.py +++ b/docarray/typing/id.py @@ -4,7 +4,7 @@ from pydantic import BaseConfig, parse_obj_as from pydantic.fields import ModelField -from docarray.typing.proto_register import register_proto +from docarray.typing.proto_register import _register_proto if TYPE_CHECKING: from docarray.proto import NodeProto @@ -14,7 +14,7 @@ T = TypeVar('T', bound='ID') -@register_proto(proto_type_name='id') +@_register_proto(proto_type_name='id') class ID(str, AbstractType): """ Represent an unique ID diff --git a/docarray/typing/proto_register.py b/docarray/typing/proto_register.py index 0e12669239d..4a1fe77dad9 100644 --- a/docarray/typing/proto_register.py +++ b/docarray/typing/proto_register.py @@ -1,11 +1,11 @@ -from typing import Callable, Type +from typing import Callable, Dict, Type from docarray.typing.abstract_type import AbstractType -_PROTO_TYPE_NAME_TO_CLASS = {} +_PROTO_TYPE_NAME_TO_CLASS: Dict[str, Type[AbstractType]] = {} -def register_proto( +def _register_proto( proto_type_name: str, ) -> Callable[[Type[AbstractType]], Type[AbstractType]]: """Register a new type to be used in the protobuf serialization. diff --git a/docarray/typing/tensor/audio/audio_ndarray.py b/docarray/typing/tensor/audio/audio_ndarray.py index 65026080da5..807cbd55fbb 100644 --- a/docarray/typing/tensor/audio/audio_ndarray.py +++ b/docarray/typing/tensor/audio/audio_ndarray.py @@ -1,6 +1,6 @@ from typing import TypeVar -from docarray.typing.proto_register import register_proto +from docarray.typing.proto_register import _register_proto from docarray.typing.tensor.audio.abstract_audio_tensor import AbstractAudioTensor from docarray.typing.tensor.ndarray import NdArray @@ -9,7 +9,7 @@ T = TypeVar('T', bound='AudioNdArray') -@register_proto(proto_type_name='audio_ndarray') +@_register_proto(proto_type_name='audio_ndarray') class AudioNdArray(AbstractAudioTensor, NdArray): """ Subclass of NdArray, to represent an audio tensor. diff --git a/docarray/typing/tensor/audio/audio_torch_tensor.py b/docarray/typing/tensor/audio/audio_torch_tensor.py index b3494c5dee0..d74ca407afd 100644 --- a/docarray/typing/tensor/audio/audio_torch_tensor.py +++ b/docarray/typing/tensor/audio/audio_torch_tensor.py @@ -1,6 +1,6 @@ from typing import TypeVar -from docarray.typing.proto_register import register_proto +from docarray.typing.proto_register import _register_proto from docarray.typing.tensor.audio.abstract_audio_tensor import AbstractAudioTensor from docarray.typing.tensor.audio.audio_ndarray import MAX_INT_16 from docarray.typing.tensor.torch_tensor import TorchTensor, metaTorchAndNode @@ -8,7 +8,7 @@ T = TypeVar('T', bound='AudioTorchTensor') -@register_proto(proto_type_name='audio_torch_tensor') +@_register_proto(proto_type_name='audio_torch_tensor') class AudioTorchTensor(AbstractAudioTensor, TorchTensor, metaclass=metaTorchAndNode): """ Subclass of TorchTensor, to represent an audio tensor. diff --git a/docarray/typing/tensor/embedding/ndarray.py b/docarray/typing/tensor/embedding/ndarray.py index 24e59579a39..631268e7c26 100644 --- a/docarray/typing/tensor/embedding/ndarray.py +++ b/docarray/typing/tensor/embedding/ndarray.py @@ -1,8 +1,8 @@ -from docarray.typing.proto_register import register_proto +from docarray.typing.proto_register import _register_proto from docarray.typing.tensor.embedding.embedding_mixin import EmbeddingMixin from docarray.typing.tensor.ndarray import NdArray -@register_proto(proto_type_name='ndarray_embedding') +@_register_proto(proto_type_name='ndarray_embedding') class NdArrayEmbedding(NdArray, EmbeddingMixin): alternative_type = NdArray diff --git a/docarray/typing/tensor/embedding/torch.py b/docarray/typing/tensor/embedding/torch.py index 3ff6ffb0559..6178144045f 100644 --- a/docarray/typing/tensor/embedding/torch.py +++ b/docarray/typing/tensor/embedding/torch.py @@ -1,6 +1,6 @@ from typing import Any # noqa: F401 -from docarray.typing.proto_register import register_proto +from docarray.typing.proto_register import _register_proto from docarray.typing.tensor.embedding.embedding_mixin import EmbeddingMixin from docarray.typing.tensor.torch_tensor import TorchTensor @@ -12,6 +12,6 @@ class metaTorchAndEmbedding(torch_base, embedding_base): pass -@register_proto(proto_type_name='torch_embedding') +@_register_proto(proto_type_name='torch_embedding') class TorchEmbedding(TorchTensor, EmbeddingMixin, metaclass=metaTorchAndEmbedding): alternative_type = TorchTensor diff --git a/docarray/typing/tensor/ndarray.py b/docarray/typing/tensor/ndarray.py index eea7e32c4cc..d5d3edf6ab3 100644 --- a/docarray/typing/tensor/ndarray.py +++ b/docarray/typing/tensor/ndarray.py @@ -13,7 +13,7 @@ import numpy as np -from docarray.typing.proto_register import register_proto +from docarray.typing.proto_register import _register_proto from docarray.typing.tensor.abstract_tensor import AbstractTensor if TYPE_CHECKING: @@ -37,7 +37,7 @@ class metaNumpy(AbstractTensor.__parametrized_meta__, tensor_base): # type: ign pass -@register_proto(proto_type_name='ndarray') +@_register_proto(proto_type_name='ndarray') class NdArray(np.ndarray, AbstractTensor, Generic[ShapeT]): """ Subclass of np.ndarray, intended for use in a Document. diff --git a/docarray/typing/tensor/torch_tensor.py b/docarray/typing/tensor/torch_tensor.py index 5d3a3a0dcef..3b26ac60b3e 100644 --- a/docarray/typing/tensor/torch_tensor.py +++ b/docarray/typing/tensor/torch_tensor.py @@ -4,7 +4,7 @@ import numpy as np import torch # type: ignore -from docarray.typing.proto_register import register_proto +from docarray.typing.proto_register import _register_proto from docarray.typing.tensor.abstract_tensor import AbstractTensor if TYPE_CHECKING: @@ -33,7 +33,7 @@ class metaTorchAndNode( pass -@register_proto(proto_type_name='torch') +@_register_proto(proto_type_name='torch') class TorchTensor( torch.Tensor, AbstractTensor, Generic[ShapeT], metaclass=metaTorchAndNode ): diff --git a/docarray/typing/tensor/video/video_ndarray.py b/docarray/typing/tensor/video/video_ndarray.py index 42df8265a52..345499634e5 100644 --- a/docarray/typing/tensor/video/video_ndarray.py +++ b/docarray/typing/tensor/video/video_ndarray.py @@ -2,7 +2,7 @@ import numpy as np -from docarray.typing.proto_register import register_proto +from docarray.typing.proto_register import _register_proto from docarray.typing.tensor.ndarray import NdArray from docarray.typing.tensor.video.video_tensor_mixin import VideoTensorMixin @@ -13,7 +13,7 @@ from pydantic.fields import ModelField -@register_proto(proto_type_name='video_ndarray') +@_register_proto(proto_type_name='video_ndarray') class VideoNdArray(NdArray, VideoTensorMixin): """ Subclass of NdArray, to represent a video tensor. diff --git a/docarray/typing/tensor/video/video_torch_tensor.py b/docarray/typing/tensor/video/video_torch_tensor.py index 0f81b33e531..92a477915e9 100644 --- a/docarray/typing/tensor/video/video_torch_tensor.py +++ b/docarray/typing/tensor/video/video_torch_tensor.py @@ -2,7 +2,7 @@ import numpy as np -from docarray.typing.proto_register import register_proto +from docarray.typing.proto_register import _register_proto from docarray.typing.tensor.torch_tensor import TorchTensor, metaTorchAndNode from docarray.typing.tensor.video.video_tensor_mixin import VideoTensorMixin @@ -13,7 +13,7 @@ from pydantic.fields import ModelField -@register_proto(proto_type_name='video_torch_tensor') +@_register_proto(proto_type_name='video_torch_tensor') class VideoTorchTensor(TorchTensor, VideoTensorMixin, metaclass=metaTorchAndNode): """ Subclass of TorchTensor, to represent a video tensor. diff --git a/docarray/typing/url/any_url.py b/docarray/typing/url/any_url.py index 11e416f2f95..700c1af2e3a 100644 --- a/docarray/typing/url/any_url.py +++ b/docarray/typing/url/any_url.py @@ -4,7 +4,7 @@ from pydantic import errors, parse_obj_as from docarray.typing.abstract_type import AbstractType -from docarray.typing.proto_register import register_proto +from docarray.typing.proto_register import _register_proto if TYPE_CHECKING: from pydantic.networks import Parts @@ -14,7 +14,7 @@ T = TypeVar('T', bound='AnyUrl') -@register_proto(proto_type_name='any_url') +@_register_proto(proto_type_name='any_url') class AnyUrl(BaseAnyUrl, AbstractType): host_required = ( False # turn off host requirement to allow passing of local paths as URL diff --git a/docarray/typing/url/audio_url.py b/docarray/typing/url/audio_url.py index 824f9818a55..f04fdb1c7fa 100644 --- a/docarray/typing/url/audio_url.py +++ b/docarray/typing/url/audio_url.py @@ -4,7 +4,7 @@ import numpy as np from pydantic import parse_obj_as -from docarray.typing.proto_register import register_proto +from docarray.typing.proto_register import _register_proto from docarray.typing.tensor.audio.audio_ndarray import MAX_INT_16, AudioNdArray from docarray.typing.url.any_url import AnyUrl @@ -17,7 +17,7 @@ AUDIO_FILE_FORMATS = ['wav'] -@register_proto(proto_type_name='audio_url') +@_register_proto(proto_type_name='audio_url') class AudioUrl(AnyUrl): """ URL to a .wav file. diff --git a/docarray/typing/url/image_url.py b/docarray/typing/url/image_url.py index 36183908527..4e7300a33d9 100644 --- a/docarray/typing/url/image_url.py +++ b/docarray/typing/url/image_url.py @@ -4,7 +4,7 @@ import numpy as np -from docarray.typing.proto_register import register_proto +from docarray.typing.proto_register import _register_proto from docarray.typing.url.any_url import AnyUrl from docarray.typing.url.helper import _uri_to_blob @@ -18,7 +18,7 @@ IMAGE_FILE_FORMATS = ('png', 'jpeg', 'jpg') -@register_proto(proto_type_name='image_url') +@_register_proto(proto_type_name='image_url') class ImageUrl(AnyUrl): """ URL to a .png, .jpeg, or .jpg file. diff --git a/docarray/typing/url/text_url.py b/docarray/typing/url/text_url.py index 3bad471618f..0b6ddaacdad 100644 --- a/docarray/typing/url/text_url.py +++ b/docarray/typing/url/text_url.py @@ -1,11 +1,11 @@ from typing import Optional -from docarray.typing.proto_register import register_proto +from docarray.typing.proto_register import _register_proto from docarray.typing.url.any_url import AnyUrl from docarray.typing.url.helper import _uri_to_blob -@register_proto(proto_type_name='text_url') +@_register_proto(proto_type_name='text_url') class TextUrl(AnyUrl): """ URL to a text file. diff --git a/docarray/typing/url/url_3d/mesh_url.py b/docarray/typing/url/url_3d/mesh_url.py index c2e39f31f3c..bd5a1de408e 100644 --- a/docarray/typing/url/url_3d/mesh_url.py +++ b/docarray/typing/url/url_3d/mesh_url.py @@ -4,7 +4,7 @@ from pydantic import parse_obj_as from docarray.typing import NdArray -from docarray.typing.proto_register import register_proto +from docarray.typing.proto_register import _register_proto from docarray.typing.url.url_3d.url_3d import Url3D T = TypeVar('T', bound='Mesh3DUrl') @@ -15,7 +15,7 @@ class Mesh3DLoadResult(NamedTuple): faces: NdArray -@register_proto(proto_type_name='mesh_url') +@_register_proto(proto_type_name='mesh_url') class Mesh3DUrl(Url3D): """ URL to a .obj, .glb, or .ply file containing 3D mesh information. diff --git a/docarray/typing/url/url_3d/point_cloud_url.py b/docarray/typing/url/url_3d/point_cloud_url.py index 0b42fbfa7b5..ad57ff85ace 100644 --- a/docarray/typing/url/url_3d/point_cloud_url.py +++ b/docarray/typing/url/url_3d/point_cloud_url.py @@ -4,13 +4,13 @@ from pydantic import parse_obj_as from docarray.typing import NdArray -from docarray.typing.proto_register import register_proto +from docarray.typing.proto_register import _register_proto from docarray.typing.url.url_3d.url_3d import Url3D T = TypeVar('T', bound='PointCloud3DUrl') -@register_proto(proto_type_name='point_cloud_url') +@_register_proto(proto_type_name='point_cloud_url') class PointCloud3DUrl(Url3D): """ URL to a .obj, .glb, or .ply file containing point cloud information. diff --git a/docarray/typing/url/url_3d/url_3d.py b/docarray/typing/url/url_3d/url_3d.py index 9219a017750..cebd7e94080 100644 --- a/docarray/typing/url/url_3d/url_3d.py +++ b/docarray/typing/url/url_3d/url_3d.py @@ -3,7 +3,7 @@ import numpy as np -from docarray.typing.proto_register import register_proto +from docarray.typing.proto_register import _register_proto from docarray.typing.url.any_url import AnyUrl if TYPE_CHECKING: @@ -16,7 +16,7 @@ T = TypeVar('T', bound='Url3D') -@register_proto(proto_type_name='url3d') +@_register_proto(proto_type_name='url3d') class Url3D(AnyUrl, ABC): """ URL to a .obj, .glb, or .ply file containing 3D mesh or point cloud information. diff --git a/docarray/typing/url/video_url.py b/docarray/typing/url/video_url.py index 074fcfd294b..7e171ec926a 100644 --- a/docarray/typing/url/video_url.py +++ b/docarray/typing/url/video_url.py @@ -4,7 +4,7 @@ from pydantic.tools import parse_obj_as from docarray.typing import AudioNdArray, NdArray -from docarray.typing.proto_register import register_proto +from docarray.typing.proto_register import _register_proto from docarray.typing.tensor.video import VideoNdArray from docarray.typing.url.any_url import AnyUrl @@ -23,7 +23,7 @@ class VideoLoadResult(NamedTuple): key_frame_indices: NdArray -@register_proto(proto_type_name='video_url') +@_register_proto(proto_type_name='video_url') class VideoUrl(AnyUrl): """ URL to a .wav file. From 3cae8a1862927a2d5ae10a84fe3d1242d0b29a84 Mon Sep 17 00:00:00 2001 From: samsja <55492238+samsja@users.noreply.github.com> Date: Thu, 19 Jan 2023 16:31:27 +0100 Subject: [PATCH 15/16] feat: apply charlotte suggestion Co-authored-by: Charlotte Gerhaher Signed-off-by: samsja <55492238+samsja@users.noreply.github.com> --- docarray/typing/tensor/torch_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docarray/typing/tensor/torch_tensor.py b/docarray/typing/tensor/torch_tensor.py index 3b26ac60b3e..cdbd75f2a71 100644 --- a/docarray/typing/tensor/torch_tensor.py +++ b/docarray/typing/tensor/torch_tensor.py @@ -33,7 +33,7 @@ class metaTorchAndNode( pass -@_register_proto(proto_type_name='torch') +@_register_proto(proto_type_name='torch_tensor') class TorchTensor( torch.Tensor, AbstractTensor, Generic[ShapeT], metaclass=metaTorchAndNode ): From 63988a50f968df75590b5be955ebf0834c68bc2c Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Thu, 19 Jan 2023 16:33:24 +0100 Subject: [PATCH 16/16] fix: remove useless test --- tests/units/document/proto/test_proto_based_object.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/tests/units/document/proto/test_proto_based_object.py b/tests/units/document/proto/test_proto_based_object.py index f5af227c44d..6d2b7a79b7b 100644 --- a/tests/units/document/proto/test_proto_based_object.py +++ b/tests/units/document/proto/test_proto_based_object.py @@ -4,15 +4,6 @@ from docarray.typing import NdArray -def test_nested_item_proto(): - NodeProto(text='hello') - NodeProto(document=DocumentProto()) - - -def test_nested_optional_item_proto(): - NodeProto() - - def test_ndarray(): original_ndarray = np.zeros((3, 224, 224))