From d0622f3473b9f11973e7ffea6098833b69c7ecb6 Mon Sep 17 00:00:00 2001 From: samsja Date: Mon, 24 Apr 2023 12:14:45 +0200 Subject: [PATCH 1/7] feat: add config to load more fields than the schema Signed-off-by: samsja --- docarray/base_doc/any_doc.py | 4 ++++ docarray/base_doc/doc.py | 1 + docarray/base_doc/mixins/io.py | 5 ++++- 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/docarray/base_doc/any_doc.py b/docarray/base_doc/any_doc.py index f29b6f6e01e..dee118468de 100644 --- a/docarray/base_doc/any_doc.py +++ b/docarray/base_doc/any_doc.py @@ -8,6 +8,10 @@ class AnyDoc(BaseDoc): AnyDoc is a Document that is not tied to any schema """ + class Config: + load_extra_fields_from_protobuf = True # I introduce this variable to allow to load more that the fields defined in the schema + # will documented this behavior later if this fix our problem + def __init__(self, **kwargs): super().__init__() self.__dict__.update(kwargs) diff --git a/docarray/base_doc/doc.py b/docarray/base_doc/doc.py index 0ed39bd0d49..ee0bdc54edc 100644 --- a/docarray/base_doc/doc.py +++ b/docarray/base_doc/doc.py @@ -50,6 +50,7 @@ class Config: json_encoders = {AbstractTensor: lambda x: x} validate_assignment = True + load_extra_fields_from_protobuf = False @classmethod def from_view(cls: Type[T], storage_view: 'ColumnStorageView') -> T: diff --git a/docarray/base_doc/mixins/io.py b/docarray/base_doc/mixins/io.py index 2afd8c46579..7333084f7ba 100644 --- a/docarray/base_doc/mixins/io.py +++ b/docarray/base_doc/mixins/io.py @@ -219,7 +219,10 @@ def from_protobuf(cls: Type[T], pb_msg: 'DocProto') -> T: fields: Dict[str, Any] = {} for field_name in pb_msg.data: - if field_name not in cls.__fields__.keys(): + if ( + not (cls.Config.load_extra_fields_from_protobuf) + and field_name not in cls.__fields__.keys() + ): continue # optimization we don't even load the data if the key does not # match any field in the cls or in the mapping From 49f20ca1632135d8e1587807594d569b61b18df7 Mon Sep 17 
00:00:00 2001 From: samsja Date: Mon, 24 Apr 2023 18:07:15 +0200 Subject: [PATCH 2/7] feat: add test Signed-off-by: samsja --- tests/units/array/test_array_proto.py | 9 +++++++++ tests/units/document/proto/test_document_proto.py | 11 ++++++++++- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/tests/units/array/test_array_proto.py b/tests/units/array/test_array_proto.py index ebdf0d9a3f9..d8c408bf1f4 100644 --- a/tests/units/array/test_array_proto.py +++ b/tests/units/array/test_array_proto.py @@ -2,6 +2,7 @@ import pytest from docarray import BaseDoc, DocList +from docarray.base_doc import AnyDoc from docarray.documents import ImageDoc, TextDoc from docarray.typing import NdArray @@ -59,3 +60,11 @@ class CustomDocument(BaseDoc): ) DocList.from_protobuf(da.to_protobuf()) + + +@pytest.mark.proto +def test_any_doc_list_proto(): + doc = AnyDoc(hello='world') + pt = DocList([doc]).to_protobuf() + docs = DocList.from_protobuf(pt) + assert docs[0].dict()['hello'] == 'world' diff --git a/tests/units/document/proto/test_document_proto.py b/tests/units/document/proto/test_document_proto.py index cb5442f7700..8820d28ab34 100644 --- a/tests/units/document/proto/test_document_proto.py +++ b/tests/units/document/proto/test_document_proto.py @@ -5,7 +5,7 @@ import torch from docarray import DocList -from docarray.base_doc import BaseDoc +from docarray.base_doc import AnyDoc, BaseDoc from docarray.typing import NdArray, TorchTensor from docarray.utils._internal.misc import is_tf_available @@ -287,6 +287,7 @@ class MyDoc(BaseDoc): (doc2.data['hello'][3][0] == torch.ones(55)).all() +@pytest.mark.proto @pytest.mark.tensorflow def test_super_complex_nested_tensorflow(): class MyDoc(BaseDoc): @@ -296,3 +297,11 @@ class MyDoc(BaseDoc): doc = MyDoc(data=data) MyDoc.from_protobuf(doc.to_protobuf()) + + +@pytest.mark.proto +def test_any_doc_proto(): + doc = AnyDoc(hello='world') + pt = doc.to_protobuf() + doc2 = AnyDoc.from_protobuf(pt) + assert doc2.dict()['hello'] == 
'world' From 1e3515b3775a392275a5658505a17cfdb945fc61 Mon Sep 17 00:00:00 2001 From: samsja Date: Mon, 24 Apr 2023 18:23:05 +0200 Subject: [PATCH 3/7] fix: fix nested doclist Signed-off-by: samsja --- docarray/base_doc/any_doc.py | 6 ++++++ docarray/base_doc/mixins/io.py | 16 ++++++++++++++-- tests/units/array/test_array_proto.py | 23 +++++++++++++++++++++++ 3 files changed, 43 insertions(+), 2 deletions(-) diff --git a/docarray/base_doc/any_doc.py b/docarray/base_doc/any_doc.py index dee118468de..8a45e2a36d0 100644 --- a/docarray/base_doc/any_doc.py +++ b/docarray/base_doc/any_doc.py @@ -26,3 +26,9 @@ def _get_field_type(cls, field: str) -> Type['BaseDoc']: :return: """ return AnyDoc + + @classmethod + def _get_field_type_array(cls, field: str) -> Type: + from docarray import DocList + + return DocList diff --git a/docarray/base_doc/mixins/io.py b/docarray/base_doc/mixins/io.py index 7333084f7ba..ed78cb4a7d6 100644 --- a/docarray/base_doc/mixins/io.py +++ b/docarray/base_doc/mixins/io.py @@ -127,6 +127,10 @@ class IOMixin(Iterable[Tuple[str, Any]]): def _get_field_type(cls, field: str) -> Type: ... 
+ @classmethod + def _get_field_type_array(cls, field: str) -> Type: + return cls._get_field_type(field) + def __bytes__(self) -> bytes: return self.to_bytes() @@ -256,14 +260,22 @@ def _get_content_from_node_proto( return_field = content_type_dict[docarray_type].from_protobuf( getattr(value, content_key) ) - elif content_key in ['doc', 'doc_array']: + elif content_key == 'doc': if field_name is None: raise ValueError( - 'field_name cannot be None when trying to deseriliaze a Document or a DocList' + 'field_name cannot be None when trying to deseriliaze a BaseDoc' ) return_field = cls._get_field_type(field_name).from_protobuf( getattr(value, content_key) ) # we get to the parent class + elif content_key == 'doc_array': + if field_name is None: + raise ValueError( + 'field_name cannot be None when trying to deseriliaze a BaseDoc' + ) + return_field = cls._get_field_type_array(field_name).from_protobuf( + getattr(value, content_key) + ) # we get to the parent class elif content_key is None: return_field = None elif docarray_type is None: diff --git a/tests/units/array/test_array_proto.py b/tests/units/array/test_array_proto.py index d8c408bf1f4..e57cc3313f5 100644 --- a/tests/units/array/test_array_proto.py +++ b/tests/units/array/test_array_proto.py @@ -68,3 +68,26 @@ def test_any_doc_list_proto(): pt = DocList([doc]).to_protobuf() docs = DocList.from_protobuf(pt) assert docs[0].dict()['hello'] == 'world' + + +@pytest.mark.proto +def test_any_nested_doc_list_proto(): + from docarray import BaseDoc, DocList + + class TextDocWithId(BaseDoc): + id: str + text: str + + class ResultTestDoc(BaseDoc): + matches: DocList[TextDocWithId] + + index_da = DocList[TextDocWithId]( + [TextDocWithId(id=f'{i}', text=f'ID {i}') for i in range(10)] + ) + + out_da = DocList[ResultTestDoc]([ResultTestDoc(matches=index_da[0:2])]) + pb = out_da.to_protobuf() + docs = DocList.from_protobuf(pb) + assert docs[0].matches[0].id == '0' + assert len(docs[0].matches) == 2 + assert len(docs) == 1 
From ce04d431af3efbd3c510a583fc3c8881cef5d69b Mon Sep 17 00:00:00 2001 From: samsja Date: Wed, 26 Apr 2023 09:10:14 +0200 Subject: [PATCH 4/7] fix: fix mypy Signed-off-by: samsja --- docarray/base_doc/mixins/io.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/docarray/base_doc/mixins/io.py b/docarray/base_doc/mixins/io.py index ed78cb4a7d6..f462922dbad 100644 --- a/docarray/base_doc/mixins/io.py +++ b/docarray/base_doc/mixins/io.py @@ -39,7 +39,6 @@ if torch is not None: from docarray.typing import TorchTensor - T = TypeVar('T', bound='IOMixin') @@ -122,6 +121,9 @@ class IOMixin(Iterable[Tuple[str, Any]]): __fields__: Dict[str, 'ModelField'] + class Config: + load_extra_fields_from_protobuf: bool + @classmethod @abstractmethod def _get_field_type(cls, field: str) -> Type: @@ -153,7 +155,8 @@ def to_bytes( bstr = self.to_protobuf().SerializePartialToString() else: raise ValueError( - f'protocol={protocol} is not supported. Can be only `protobuf` or pickle protocols 0-5.' + f'protocol={protocol} is not supported. Can be only `protobuf` or ' + f'pickle protocols 0-5.' ) return _compress_bytes(bstr, algorithm=compress) @@ -182,7 +185,8 @@ def from_bytes( return cls.from_protobuf(pb_msg) else: raise ValueError( - f'protocol={protocol} is not supported. Can be only `protobuf` or pickle protocols 0-5.' + f'protocol={protocol} is not supported. Can be only `protobuf` or ' + f'pickle protocols 0-5.' 
) def to_base64( From 76caae686733f5e2cde637e8e481639b026c18cc Mon Sep 17 00:00:00 2001 From: samsja Date: Wed, 26 Apr 2023 09:37:49 +0200 Subject: [PATCH 5/7] docs: add BaseDoc docstring Signed-off-by: samsja --- docarray/base_doc/doc.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/docarray/base_doc/doc.py b/docarray/base_doc/doc.py index ee0bdc54edc..9748c3341bf 100644 --- a/docarray/base_doc/doc.py +++ b/docarray/base_doc/doc.py @@ -36,7 +36,28 @@ class BaseDoc(BaseModel, IOMixin, UpdateMixin, BaseNode): """ - The base class for Documents + BaseDoc is the base class for all Documents. This class should be subclassed + to create new Document types with a specific schema. + + The schema of a Document is defined by the fields of the class. + + Example: + ```python + from docarray import BaseDoc + from docarray.typing import NdArray, ImageUrl + import numpy as np + + + class MyDoc(BaseDoc): + embedding: NdArray[512] + image: ImageUrl + + + doc = MyDoc(embedding=np.zeros(512), image='https://example.com/image.jpg') + ``` + + + BaseDoc is a subclass of [pydantic.BaseModel](https://docs.pydantic.dev/usage/models/) and can be used in a similar way. 
""" id: Optional[ID] = Field(default_factory=lambda: ID(os.urandom(16).hex())) From 969ed1decbb351a1a8564dae5ea3e009041f691c Mon Sep 17 00:00:00 2001 From: samsja Date: Wed, 26 Apr 2023 09:38:42 +0200 Subject: [PATCH 6/7] refactor: make config private Signed-off-by: samsja --- docarray/base_doc/any_doc.py | 2 +- docarray/base_doc/doc.py | 2 +- docarray/base_doc/mixins/io.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docarray/base_doc/any_doc.py b/docarray/base_doc/any_doc.py index 8a45e2a36d0..e04c256f8bb 100644 --- a/docarray/base_doc/any_doc.py +++ b/docarray/base_doc/any_doc.py @@ -9,7 +9,7 @@ class AnyDoc(BaseDoc): """ class Config: - load_extra_fields_from_protobuf = True # I introduce this variable to allow to load more that the fields defined in the schema + _load_extra_fields_from_protobuf = True # I introduce this variable to allow to load more that the fields defined in the schema # will documented this behavior later if this fix our problem def __init__(self, **kwargs): diff --git a/docarray/base_doc/doc.py b/docarray/base_doc/doc.py index 9748c3341bf..1fa9bc3e376 100644 --- a/docarray/base_doc/doc.py +++ b/docarray/base_doc/doc.py @@ -71,7 +71,7 @@ class Config: json_encoders = {AbstractTensor: lambda x: x} validate_assignment = True - load_extra_fields_from_protobuf = False + _load_extra_fields_from_protobuf = False @classmethod def from_view(cls: Type[T], storage_view: 'ColumnStorageView') -> T: diff --git a/docarray/base_doc/mixins/io.py b/docarray/base_doc/mixins/io.py index f462922dbad..9d19e4337bc 100644 --- a/docarray/base_doc/mixins/io.py +++ b/docarray/base_doc/mixins/io.py @@ -122,7 +122,7 @@ class IOMixin(Iterable[Tuple[str, Any]]): __fields__: Dict[str, 'ModelField'] class Config: - load_extra_fields_from_protobuf: bool + _load_extra_fields_from_protobuf: bool @classmethod @abstractmethod @@ -228,7 +228,7 @@ def from_protobuf(cls: Type[T], pb_msg: 'DocProto') -> T: for field_name in pb_msg.data: if ( - not 
(cls.Config.load_extra_fields_from_protobuf) + not (cls.Config._load_extra_fields_from_protobuf) and field_name not in cls.__fields__.keys() ): continue # optimization we don't even load the data if the key does not From 1b35ae38e78dc55723f5c2415ebafb4c93f108b9 Mon Sep 17 00:00:00 2001 From: samsja Date: Wed, 26 Apr 2023 10:00:18 +0200 Subject: [PATCH 7/7] fix: fix test Signed-off-by: samsja --- tests/units/document/proto/test_document_proto.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/units/document/proto/test_document_proto.py b/tests/units/document/proto/test_document_proto.py index 8820d28ab34..a95a0edec62 100644 --- a/tests/units/document/proto/test_document_proto.py +++ b/tests/units/document/proto/test_document_proto.py @@ -287,7 +287,6 @@ class MyDoc(BaseDoc): (doc2.data['hello'][3][0] == torch.ones(55)).all() -@pytest.mark.proto @pytest.mark.tensorflow def test_super_complex_nested_tensorflow(): class MyDoc(BaseDoc):