diff --git a/docarray/base_doc/any_doc.py b/docarray/base_doc/any_doc.py index f29b6f6e01e..e04c256f8bb 100644 --- a/docarray/base_doc/any_doc.py +++ b/docarray/base_doc/any_doc.py @@ -8,6 +8,10 @@ class AnyDoc(BaseDoc): AnyDoc is a Document that is not tied to any schema """ + class Config: + _load_extra_fields_from_protobuf = True # AnyDoc has no schema: from_protobuf must + # also load the fields that are not declared on the class (BaseDoc skips them) + + def __init__(self, **kwargs): super().__init__() self.__dict__.update(kwargs) @@ -22,3 +26,9 @@ def _get_field_type(cls, field: str) -> Type['BaseDoc']: :return: """ return AnyDoc + + @classmethod + def _get_field_type_array(cls, field: str) -> Type: + from docarray import DocList + + return DocList diff --git a/docarray/base_doc/doc.py b/docarray/base_doc/doc.py index 0ed39bd0d49..1fa9bc3e376 100644 --- a/docarray/base_doc/doc.py +++ b/docarray/base_doc/doc.py @@ -36,7 +36,28 @@ class BaseDoc(BaseModel, IOMixin, UpdateMixin, BaseNode): """ - The base class for Documents + BaseDoc is the base class for all Documents. This class should be subclassed + to create new Document types with a specific schema. + + The schema of a Document is defined by the fields of the class. + + Example: + ```python + from docarray import BaseDoc + from docarray.typing import NdArray, ImageUrl + import numpy as np + + + class MyDoc(BaseDoc): + embedding: NdArray[512] + image: ImageUrl + + + doc = MyDoc(embedding=np.zeros(512), image='https://example.com/image.jpg') + ``` + + + BaseDoc is a subclass of [pydantic.BaseModel](https://docs.pydantic.dev/usage/models/) and can be used in a similar way. 
""" id: Optional[ID] = Field(default_factory=lambda: ID(os.urandom(16).hex())) @@ -50,6 +71,7 @@ class Config: json_encoders = {AbstractTensor: lambda x: x} validate_assignment = True + _load_extra_fields_from_protobuf = False # ignore undeclared protobuf fields @classmethod def from_view(cls: Type[T], storage_view: 'ColumnStorageView') -> T: diff --git a/docarray/base_doc/mixins/io.py b/docarray/base_doc/mixins/io.py index 2afd8c46579..9d19e4337bc 100644 --- a/docarray/base_doc/mixins/io.py +++ b/docarray/base_doc/mixins/io.py @@ -39,7 +39,6 @@ if torch is not None: from docarray.typing import TorchTensor - T = TypeVar('T', bound='IOMixin') @@ -122,11 +121,18 @@ class IOMixin(Iterable[Tuple[str, Any]]): __fields__: Dict[str, 'ModelField'] + class Config: + _load_extra_fields_from_protobuf: bool + @classmethod @abstractmethod def _get_field_type(cls, field: str) -> Type: ... + @classmethod + def _get_field_type_array(cls, field: str) -> Type: + return cls._get_field_type(field) + def __bytes__(self) -> bytes: return self.to_bytes() @@ -149,7 +155,8 @@ def to_bytes( bstr = self.to_protobuf().SerializePartialToString() else: raise ValueError( - f'protocol={protocol} is not supported. Can be only `protobuf` or pickle protocols 0-5.' + f'protocol={protocol} is not supported. Can be only `protobuf` or ' + f'pickle protocols 0-5.' ) return _compress_bytes(bstr, algorithm=compress) @@ -178,7 +185,8 @@ def from_bytes( return cls.from_protobuf(pb_msg) else: raise ValueError( - f'protocol={protocol} is not supported. Can be only `protobuf` or pickle protocols 0-5.' + f'protocol={protocol} is not supported. Can be only `protobuf` or ' + f'pickle protocols 0-5.' 
) def to_base64( @@ -219,7 +227,10 @@ def from_protobuf(cls: Type[T], pb_msg: 'DocProto') -> T: fields: Dict[str, Any] = {} for field_name in pb_msg.data: - if field_name not in cls.__fields__.keys(): + if ( + not getattr(cls.Config, '_load_extra_fields_from_protobuf', False) + and field_name not in cls.__fields__.keys() + ): continue # optimization we don't even load the data if the key does not # match any field in the cls or in the mapping @@ -253,14 +264,22 @@ def _get_content_from_node_proto( return_field = content_type_dict[docarray_type].from_protobuf( getattr(value, content_key) ) - elif content_key in ['doc', 'doc_array']: + elif content_key == 'doc': if field_name is None: raise ValueError( - 'field_name cannot be None when trying to deseriliaze a Document or a DocList' + 'field_name cannot be None when trying to deseriliaze a BaseDoc' ) return_field = cls._get_field_type(field_name).from_protobuf( getattr(value, content_key) ) # we get to the parent class + elif content_key == 'doc_array': + if field_name is None: + raise ValueError( + 'field_name cannot be None when trying to deseriliaze a DocList' + ) + return_field = cls._get_field_type_array(field_name).from_protobuf( + getattr(value, content_key) + ) # array type comes from _get_field_type_array elif content_key is None: return_field = None elif docarray_type is None: diff --git a/tests/units/array/test_array_proto.py b/tests/units/array/test_array_proto.py index ebdf0d9a3f9..e57cc3313f5 100644 --- a/tests/units/array/test_array_proto.py +++ b/tests/units/array/test_array_proto.py @@ -2,6 +2,7 @@ import pytest from docarray import BaseDoc, DocList +from docarray.base_doc import AnyDoc from docarray.documents import ImageDoc, TextDoc from docarray.typing import NdArray @@ -59,3 +60,34 @@ class CustomDocument(BaseDoc): ) DocList.from_protobuf(da.to_protobuf()) + + +@pytest.mark.proto +def test_any_doc_list_proto(): + doc = AnyDoc(hello='world') + pt = DocList([doc]).to_protobuf() + docs = DocList.from_protobuf(pt) + assert 
docs[0].dict()['hello'] == 'world' + + +@pytest.mark.proto +def test_any_nested_doc_list_proto(): + from docarray import BaseDoc, DocList + + class TextDocWithId(BaseDoc): + id: str + text: str + + class ResultTestDoc(BaseDoc): + matches: DocList[TextDocWithId] + + index_da = DocList[TextDocWithId]( + [TextDocWithId(id=f'{i}', text=f'ID {i}') for i in range(10)] + ) + + out_da = DocList[ResultTestDoc]([ResultTestDoc(matches=index_da[0:2])]) + pb = out_da.to_protobuf() + docs = DocList.from_protobuf(pb) + assert docs[0].matches[0].id == '0' + assert len(docs[0].matches) == 2 + assert len(docs) == 1 diff --git a/tests/units/document/proto/test_document_proto.py b/tests/units/document/proto/test_document_proto.py index cb5442f7700..a95a0edec62 100644 --- a/tests/units/document/proto/test_document_proto.py +++ b/tests/units/document/proto/test_document_proto.py @@ -5,7 +5,7 @@ import torch from docarray import DocList -from docarray.base_doc import BaseDoc +from docarray.base_doc import AnyDoc, BaseDoc from docarray.typing import NdArray, TorchTensor from docarray.utils._internal.misc import is_tf_available @@ -296,3 +296,11 @@ class MyDoc(BaseDoc): doc = MyDoc(data=data) MyDoc.from_protobuf(doc.to_protobuf()) + + +@pytest.mark.proto +def test_any_doc_proto(): + doc = AnyDoc(hello='world') + pt = doc.to_protobuf() + doc2 = AnyDoc.from_protobuf(pt) + assert doc2.dict()['hello'] == 'world'