Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docarray/array/abstract_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def _to_node_protobuf(self) -> 'NodeProto':
"""
from docarray.proto import NodeProto

return NodeProto(chunks=self.to_protobuf())
return NodeProto(document_array=self.to_protobuf())

@abstractmethod
def traverse_flat(
Expand Down
61 changes: 25 additions & 36 deletions docarray/base_document/mixins/proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from docarray.base_document.abstract_document import AbstractDocument
from docarray.base_document.base_node import BaseNode
from docarray.typing.proto_register import _PROTO_TYPE_NAME_TO_CLASS

if TYPE_CHECKING:
from docarray.proto import DocumentProto, NodeProto
Expand All @@ -12,8 +13,6 @@
except ImportError:
torch_imported = False
else:
from docarray.typing.tensor.torch_tensor import TorchTensor

torch_imported = True


Expand All @@ -24,61 +23,51 @@ class ProtoMixin(AbstractDocument, BaseNode):
@classmethod
def from_protobuf(cls: Type[T], pb_msg: 'DocumentProto') -> T:
"""create a Document from a protobuf message"""
from docarray.typing import ( # TorchTensor,
ID,
AnyEmbedding,
AnyUrl,
ImageUrl,
Mesh3DUrl,
NdArray,
PointCloud3DUrl,
TextUrl,
)

fields: Dict[str, Any] = {}

for field in pb_msg.data:
value = pb_msg.data[field]

content_type = value.WhichOneof('content')

# this if else statement need to be refactored it is too long
# the check should be delegated to the type level
content_type_dict = dict(
ndarray=NdArray,
embedding=AnyEmbedding,
any_url=AnyUrl,
text_url=TextUrl,
image_url=ImageUrl,
mesh_url=Mesh3DUrl,
point_cloud_url=PointCloud3DUrl,
id=ID,
content_type_dict = _PROTO_TYPE_NAME_TO_CLASS

content_key = value.WhichOneof('content')
content_type = (
value.type if value.WhichOneof('docarray_type') is not None else None
)

if torch_imported:
content_type_dict['torch_tensor'] = TorchTensor
from docarray.typing import TorchEmbedding
from docarray.typing.tensor.torch_tensor import TorchTensor

content_type_dict['torch'] = TorchTensor
content_type_dict['torch_embedding'] = TorchEmbedding

if content_type in content_type_dict:
fields[field] = content_type_dict[content_type].from_protobuf(
getattr(value, content_type)
getattr(value, content_key)
)
elif content_type == 'text':
fields[field] = value.text
elif content_type == 'nested':
elif content_key == 'document':
fields[field] = cls._get_field_type(field).from_protobuf(
value.nested
value.document
) # we get to the parent class
elif content_type == 'chunks':
elif content_key == 'document_array':
from docarray import DocumentArray

fields[field] = DocumentArray.from_protobuf(
value.chunks
value.document_array
) # we get to the parent class
elif content_type is None:
elif content_key is None:
fields[field] = None
elif content_type is None:
if content_key == 'text':
fields[field] = value.text
elif content_key == 'blob':
fields[field] = value.blob
else:
raise ValueError(
f'type {content_type} is not supported for deserialization'
f'type {content_type}, with key {content_key} is not supported for'
f' deserialization'
)

return cls.construct(**fields)
Expand Down Expand Up @@ -133,4 +122,4 @@ def _to_node_protobuf(self) -> 'NodeProto':

:return: the nested item protobuf message
"""
return NodeProto(nested=self.to_protobuf())
return NodeProto(document=self.to_protobuf())
43 changes: 5 additions & 38 deletions docarray/proto/docarray.proto
Original file line number Diff line number Diff line change
Expand Up @@ -30,57 +30,24 @@ message NdArrayProto {
//
message NodeProto {



oneof content {
bytes blob = 1;

// the ndarray of the image/audio/video document
NdArrayProto ndarray = 2;

// a text
string text = 3;

// a sub Document
DocumentProto nested = 4;

DocumentProto document = 4;
// a sub DocumentArray
DocumentArrayProto chunks = 5;

NdArrayProto embedding = 6;

string any_url = 7;

string image_url = 8;

string text_url = 9;

string id = 10;

NdArrayProto torch_tensor = 11;

string mesh_url = 12;

string point_cloud_url = 13;

string audio_url = 14;

NdArrayProto audio_ndarray = 15;

NdArrayProto audio_torch_tensor = 16;

string video_url = 17;

NdArrayProto video_ndarray = 18;

NdArrayProto video_torch_tensor = 19;
DocumentArrayProto document_array = 5;
}

oneof docarray_type {
string type = 6;
}

}



/**
* Represents a Document
*/
Expand Down
28 changes: 14 additions & 14 deletions docarray/proto/pb2/docarray_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions docarray/typing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

__all__ = [
'NdArray',
'NdArrayEmbedding',
'AudioNdArray',
'VideoNdArray',
'AnyEmbedding',
Expand Down
2 changes: 2 additions & 0 deletions docarray/typing/abstract_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@


class AbstractType(BaseNode):
_proto_type_name: str

@classmethod
def __get_validators__(cls):
yield cls.validate
Expand Down
5 changes: 4 additions & 1 deletion docarray/typing/id.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from pydantic import BaseConfig, parse_obj_as
from pydantic.fields import ModelField

from docarray.typing.proto_register import _register_proto

if TYPE_CHECKING:
from docarray.proto import NodeProto

Expand All @@ -12,6 +14,7 @@
T = TypeVar('T', bound='ID')


@_register_proto(proto_type_name='id')
class ID(str, AbstractType):
"""
Represent an unique ID
Expand Down Expand Up @@ -44,7 +47,7 @@ def _to_node_protobuf(self) -> 'NodeProto':
"""
from docarray.proto import NodeProto

return NodeProto(id=self)
return NodeProto(text=self, type=self._proto_type_name)

@classmethod
def from_protobuf(cls: Type[T], pb_msg: 'str') -> T:
Expand Down
43 changes: 43 additions & 0 deletions docarray/typing/proto_register.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from typing import Callable, Dict, Type

from docarray.typing.abstract_type import AbstractType

_PROTO_TYPE_NAME_TO_CLASS: Dict[str, Type[AbstractType]] = {}


def _register_proto(
proto_type_name: str,
) -> Callable[[Type[AbstractType]], Type[AbstractType]]:
"""Register a new type to be used in the protobuf serialization.

This will add the type key to the global registry of types key used in the proto
serialization and deserialization. This is for internal usage only.

EXAMPLE USAGE

.. code-block:: python

from docarray.typing.proto_register import register_proto
from docarray.typing.abstract_type import AbstractType


@register_proto(proto_type_name='my_type')
class MyType(AbstractType):
...

:param cls: the class to register
:return: the class
"""

if proto_type_name in _PROTO_TYPE_NAME_TO_CLASS.keys():
raise ValueError(
f'the key {proto_type_name} is already registered in the global registry'
)

def _register(cls: Type['AbstractType']) -> Type['AbstractType']:
cls._proto_type_name = proto_type_name

_PROTO_TYPE_NAME_TO_CLASS[proto_type_name] = cls
return cls

return _register
16 changes: 14 additions & 2 deletions docarray/typing/tensor/abstract_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from pydantic import BaseConfig
from pydantic.fields import ModelField

from docarray.proto import NdArrayProto
from docarray.proto import NdArrayProto, NodeProto

T = TypeVar('T', bound='AbstractTensor')
TTensor = TypeVar('TTensor')
Expand Down Expand Up @@ -70,7 +70,19 @@ def __instancecheck__(cls, instance):
class AbstractTensor(Generic[TTensor, T], AbstractType, ABC):

__parametrized_meta__: type = _ParametrizedMeta
_PROTO_FIELD_NAME: str
_proto_type_name: str

def _to_node_protobuf(self: T) -> 'NodeProto':
"""Convert itself into a NodeProto protobuf message. This function should
be called when the Document is nested into another Document that need to be
converted into a protobuf
:param field: field in which to store the content in the node proto
:return: the nested item protobuf message
"""
from docarray.proto import NodeProto

nd_proto = self.to_protobuf()
return NodeProto(ndarray=nd_proto, type=self._proto_type_name)

@classmethod
def __docarray_validate_shape__(cls, t: T, shape: Tuple[Union[int, str]]) -> T:
Expand Down
4 changes: 2 additions & 2 deletions docarray/typing/tensor/audio/audio_ndarray.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import TypeVar

from docarray.typing.proto_register import _register_proto
from docarray.typing.tensor.audio.abstract_audio_tensor import AbstractAudioTensor
from docarray.typing.tensor.ndarray import NdArray

Expand All @@ -8,6 +9,7 @@
T = TypeVar('T', bound='AudioNdArray')


@_register_proto(proto_type_name='audio_ndarray')
class AudioNdArray(AbstractAudioTensor, NdArray):
"""
Subclass of NdArray, to represent an audio tensor.
Expand Down Expand Up @@ -52,8 +54,6 @@ class MyAudioDoc(Document):

"""

_PROTO_FIELD_NAME = 'audio_ndarray'

def to_audio_bytes(self):
tensor = (self * MAX_INT_16).astype('<h')
return tensor.tobytes()
4 changes: 2 additions & 2 deletions docarray/typing/tensor/audio/audio_torch_tensor.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from typing import TypeVar

from docarray.typing.proto_register import _register_proto
from docarray.typing.tensor.audio.abstract_audio_tensor import AbstractAudioTensor
from docarray.typing.tensor.audio.audio_ndarray import MAX_INT_16
from docarray.typing.tensor.torch_tensor import TorchTensor, metaTorchAndNode

T = TypeVar('T', bound='AudioTorchTensor')


@_register_proto(proto_type_name='audio_torch_tensor')
class AudioTorchTensor(AbstractAudioTensor, TorchTensor, metaclass=metaTorchAndNode):
"""
Subclass of TorchTensor, to represent an audio tensor.
Expand Down Expand Up @@ -50,8 +52,6 @@ class MyAudioDoc(Document):

"""

_PROTO_FIELD_NAME = 'audio_torch_tensor'

def to_audio_bytes(self):
import torch

Expand Down
2 changes: 2 additions & 0 deletions docarray/typing/tensor/embedding/ndarray.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from docarray.typing.proto_register import _register_proto
from docarray.typing.tensor.embedding.embedding_mixin import EmbeddingMixin
from docarray.typing.tensor.ndarray import NdArray


@_register_proto(proto_type_name='ndarray_embedding')
class NdArrayEmbedding(NdArray, EmbeddingMixin):
alternative_type = NdArray
Loading