Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions docarray/base_doc/any_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ class AnyDoc(BaseDoc):
AnyDoc is a Document that is not tied to any schema
"""

class Config:
    # Allow `from_protobuf` to load fields beyond those declared in the
    # schema. AnyDoc is schema-less, so every field found in the protobuf
    # message should be kept rather than skipped.
    # NOTE(review): this flag is not yet publicly documented — the original
    # author left a note to document it once it proved to fix the issue.
    _load_extra_fields_from_protobuf = True

def __init__(self, **kwargs):
super().__init__()
self.__dict__.update(kwargs)
Expand All @@ -22,3 +26,9 @@ def _get_field_type(cls, field: str) -> Type['BaseDoc']:
:return:
"""
return AnyDoc

@classmethod
def _get_field_type_array(cls, field: str) -> Type:
    """Return the container type used to deserialize an array-valued field.

    AnyDoc has no schema to consult, so every array field is deserialized
    as a plain ``DocList`` (mirroring ``_get_field_type``, which always
    returns ``AnyDoc`` for scalar document fields).

    :param field: name of the field (unused — all array fields map to DocList)
    :return: the ``DocList`` class
    """
    from docarray import DocList  # local import, presumably to avoid a circular import — TODO confirm

    return DocList
24 changes: 23 additions & 1 deletion docarray/base_doc/doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,28 @@

class BaseDoc(BaseModel, IOMixin, UpdateMixin, BaseNode):
"""
The base class for Documents
BaseDoc is the base class for all Documents. This class should be subclassed
to create new Document types with a specific schema.

The schema of a Document is defined by the fields of the class.

Example:
```python
from docarray import BaseDoc
from docarray.typing import NdArray, ImageUrl
import numpy as np


class MyDoc(BaseDoc):
embedding: NdArray[512]
image: ImageUrl


doc = MyDoc(embedding=np.zeros(512), image='https://example.com/image.jpg')
```


BaseDoc is a subclass of [pydantic.BaseModel](https://docs.pydantic.dev/usage/models/) and can be used in a similar way.
"""

id: Optional[ID] = Field(default_factory=lambda: ID(os.urandom(16).hex()))
Expand All @@ -50,6 +71,7 @@ class Config:
json_encoders = {AbstractTensor: lambda x: x}

validate_assignment = True
_load_extra_fields_from_protobuf = False

@classmethod
def from_view(cls: Type[T], storage_view: 'ColumnStorageView') -> T:
Expand Down
31 changes: 25 additions & 6 deletions docarray/base_doc/mixins/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
if torch is not None:
from docarray.typing import TorchTensor


T = TypeVar('T', bound='IOMixin')


Expand Down Expand Up @@ -122,11 +121,18 @@ class IOMixin(Iterable[Tuple[str, Any]]):

__fields__: Dict[str, 'ModelField']

class Config:
    # Declaration only (no value): concrete subclasses decide whether
    # `from_protobuf` keeps protobuf fields absent from the document
    # schema. BaseDoc sets this to False (unknown fields are skipped);
    # AnyDoc sets it to True (all fields are loaded).
    _load_extra_fields_from_protobuf: bool

@classmethod
@abstractmethod
def _get_field_type(cls, field: str) -> Type:
    """Return the declared type of the field ``field`` (implemented by subclasses)."""
    ...

@classmethod
def _get_field_type_array(cls, field: str) -> Type:
    """Return the type used to deserialize the array field ``field``.

    Defaults to the scalar field type from ``_get_field_type``; schema-less
    documents (AnyDoc) override this to return ``DocList`` instead.
    """
    return cls._get_field_type(field)

def __bytes__(self) -> bytes:
    # Delegate to to_bytes() with its default protocol and compression.
    return self.to_bytes()

Expand All @@ -149,7 +155,8 @@ def to_bytes(
bstr = self.to_protobuf().SerializePartialToString()
else:
raise ValueError(
f'protocol={protocol} is not supported. Can be only `protobuf` or pickle protocols 0-5.'
f'protocol={protocol} is not supported. Can be only `protobuf` or '
f'pickle protocols 0-5.'
)
return _compress_bytes(bstr, algorithm=compress)

Expand Down Expand Up @@ -178,7 +185,8 @@ def from_bytes(
return cls.from_protobuf(pb_msg)
else:
raise ValueError(
f'protocol={protocol} is not supported. Can be only `protobuf` or pickle protocols 0-5.'
f'protocol={protocol} is not supported. Can be only `protobuf` or '
f'pickle protocols 0-5.'
)

def to_base64(
Expand Down Expand Up @@ -219,7 +227,10 @@ def from_protobuf(cls: Type[T], pb_msg: 'DocProto') -> T:
fields: Dict[str, Any] = {}

for field_name in pb_msg.data:
if field_name not in cls.__fields__.keys():
if (
not (cls.Config._load_extra_fields_from_protobuf)
and field_name not in cls.__fields__.keys()
):
continue # optimization we don't even load the data if the key does not
# match any field in the cls or in the mapping

Expand Down Expand Up @@ -253,14 +264,22 @@ def _get_content_from_node_proto(
return_field = content_type_dict[docarray_type].from_protobuf(
getattr(value, content_key)
)
elif content_key in ['doc', 'doc_array']:
elif content_key == 'doc':
if field_name is None:
raise ValueError(
'field_name cannot be None when trying to deseriliaze a Document or a DocList'
'field_name cannot be None when trying to deseriliaze a BaseDoc'
)
return_field = cls._get_field_type(field_name).from_protobuf(
getattr(value, content_key)
) # we get to the parent class
elif content_key == 'doc_array':
if field_name is None:
raise ValueError(
'field_name cannot be None when trying to deseriliaze a BaseDoc'
)
return_field = cls._get_field_type_array(field_name).from_protobuf(
getattr(value, content_key)
) # we get to the parent class
elif content_key is None:
return_field = None
elif docarray_type is None:
Expand Down
32 changes: 32 additions & 0 deletions tests/units/array/test_array_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pytest

from docarray import BaseDoc, DocList
from docarray.base_doc import AnyDoc
from docarray.documents import ImageDoc, TextDoc
from docarray.typing import NdArray

Expand Down Expand Up @@ -59,3 +60,34 @@ class CustomDocument(BaseDoc):
)

DocList.from_protobuf(da.to_protobuf())


@pytest.mark.proto
def test_any_doc_list_proto():
    """A schema-less AnyDoc survives a DocList protobuf round trip."""
    original = DocList([AnyDoc(hello='world')])
    restored = DocList.from_protobuf(original.to_protobuf())
    assert restored[0].dict()['hello'] == 'world'


@pytest.mark.proto
def test_any_nested_doc_list_proto():
    """A DocList nested inside a document schema survives a protobuf round trip."""
    from docarray import BaseDoc, DocList

    class TextDocWithId(BaseDoc):
        id: str
        text: str

    class ResultTestDoc(BaseDoc):
        matches: DocList[TextDocWithId]

    # Build a small corpus and wrap its first two entries as matches.
    corpus = DocList[TextDocWithId](
        [TextDocWithId(id=f'{i}', text=f'ID {i}') for i in range(10)]
    )
    results = DocList[ResultTestDoc]([ResultTestDoc(matches=corpus[0:2])])

    round_tripped = DocList.from_protobuf(results.to_protobuf())

    assert len(round_tripped) == 1
    assert len(round_tripped[0].matches) == 2
    assert round_tripped[0].matches[0].id == '0'
10 changes: 9 additions & 1 deletion tests/units/document/proto/test_document_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch

from docarray import DocList
from docarray.base_doc import BaseDoc
from docarray.base_doc import AnyDoc, BaseDoc
from docarray.typing import NdArray, TorchTensor
from docarray.utils._internal.misc import is_tf_available

Expand Down Expand Up @@ -296,3 +296,11 @@ class MyDoc(BaseDoc):
doc = MyDoc(data=data)

MyDoc.from_protobuf(doc.to_protobuf())


@pytest.mark.proto
def test_any_doc_proto():
    """AnyDoc keeps extra (non-schema) fields across a protobuf round trip."""
    original = AnyDoc(hello='world')
    restored = AnyDoc.from_protobuf(original.to_protobuf())
    assert restored.dict()['hello'] == 'world'