From 77084f5e3286b2106818815b7627f7f752ea0cc5 Mon Sep 17 00:00:00 2001 From: samsja Date: Tue, 2 May 2023 13:26:50 +0200 Subject: [PATCH 1/6] fix: fix to dict exclude Signed-off-by: samsja --- docarray/base_doc/doc.py | 65 ++++++++++++++++++++-- tests/units/document/test_base_document.py | 43 +++++++++++++- 2 files changed, 102 insertions(+), 6 deletions(-) diff --git a/docarray/base_doc/doc.py b/docarray/base_doc/doc.py index 13e2c6ea254..33835254b93 100644 --- a/docarray/base_doc/doc.py +++ b/docarray/base_doc/doc.py @@ -1,9 +1,11 @@ import os from typing import ( TYPE_CHECKING, + AbstractSet, Any, Callable, Dict, + Mapping, Optional, Type, TypeVar, @@ -24,7 +26,7 @@ if TYPE_CHECKING: from pydantic import Protocol from pydantic.types import StrBytes - from pydantic.typing import AbstractSetIntStr, MappingIntStrAny + from pydantic.typing import AbstractSetIntStr, DictStrAny, MappingIntStrAny from docarray.array.doc_vec.column_storage import ColumnStorageView @@ -57,7 +59,8 @@ class MyDoc(BaseDoc): ``` - BaseDoc is a subclass of [pydantic.BaseModel](https://docs.pydantic.dev/usage/models/) and can be used in a similar way. + BaseDoc is a subclass of [pydantic.BaseModel]( + https://docs.pydantic.dev/usage/models/) and can be used in a similar way. 
""" id: Optional[ID] = Field(default_factory=lambda: ID(os.urandom(16).hex())) @@ -180,7 +183,8 @@ def _docarray_to_json_compatible(self) -> Dict: return self.dict() ######################################################################################################################################################## - ### this section is just for documentation purposes will be removed later once https://github.com/mkdocstrings/griffe/issues/138 is fixed ############## + ### this section is just for documentation purposes will be removed later once + # https://github.com/mkdocstrings/griffe/issues/138 is fixed ############## ######################################################################################################################################################## def json( @@ -198,9 +202,11 @@ def json( **dumps_kwargs: Any, ) -> str: """ - Generate a JSON representation of the model, `include` and `exclude` arguments as per `dict()`. + Generate a JSON representation of the model, `include` and `exclude` + arguments as per `dict()`. - `encoder` is an optional function to supply as `default` to json.dumps(), other arguments as per `json.dumps()`. + `encoder` is an optional function to supply as `default` to json.dumps(), + other arguments as per `json.dumps()`. """ return super().json( include=include, @@ -242,3 +248,52 @@ def parse_raw( proto=proto, allow_pickle=allow_pickle, ) + + def dict( + self, + *, + include: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None, + exclude: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None, + by_alias: bool = False, + skip_defaults: Optional[bool] = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + ) -> 'DictStrAny': + """ + Generate a dictionary representation of the model, optionally specifying + which fields to include or exclude. 
+ + """ + + doclist_exclude_fields = [] + for field in self.__fields__.keys(): + from docarray import DocList + + if issubclass(self._get_field_type(field), DocList): + doclist_exclude_fields.append(field) + + original_exclude = exclude + if exclude is None: + exclude = set(doclist_exclude_fields) + elif isinstance(exclude, AbstractSet): + exclude.update(doclist_exclude_fields) + elif isinstance(exclude, Mapping): + exclude.update({field: ... for field in doclist_exclude_fields}) + + data = super().dict( + include=include, + exclude=exclude, + by_alias=by_alias, + skip_defaults=skip_defaults, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + + for field in doclist_exclude_fields: + # we need to do this because pydantic will not recognize DocList correctly + if field not in original_exclude: + data[field] = [doc.dict() for doc in getattr(self, field)] + + return data diff --git a/tests/units/document/test_base_document.py b/tests/units/document/test_base_document.py index 0d62c069dd7..8c0f594a4cb 100644 --- a/tests/units/document/test_base_document.py +++ b/tests/units/document/test_base_document.py @@ -1,6 +1,10 @@ from typing import List, Optional -from docarray.base_doc.doc import BaseDoc +import numpy as np +import pytest + +from docarray import BaseDoc, DocList +from docarray.typing import NdArray def test_base_document_init(): @@ -43,3 +47,40 @@ class NestedDoc(BaseDoc): ) assert nested_docs == nested_docs + + +@pytest.fixture +def nested_docs(): + class SimpleDoc(BaseDoc): + simple_tens: NdArray[10] + + class NestedDoc(BaseDoc): + docs: DocList[SimpleDoc] + hello: str = 'world' + + nested_docs = NestedDoc( + docs=DocList[SimpleDoc]([SimpleDoc(simple_tens=np.ones(10)) for j in range(2)]), + ) + + return nested_docs + + +def test_nested_to_dict(nested_docs): + d = nested_docs.dict() + assert (d['docs'][0]['simple_tens'] == np.ones(10)).all() + + +def test_nested_to_dict_exclude_1(nested_docs): + d = 
nested_docs.dict(exclude={'docs'}) + assert 'docs' not in d.keys() + + +def test_nested_to_dict_exclude_2(nested_docs): + d = nested_docs.dict(exclude={'hello'}) + assert 'hello' not in d.keys() + + +def test_nested_to_dict_exclude_3(nested_docs): # doto change + d = nested_docs.dict(exclude={'hello': True}) + assert 'docs' not in d.keys() + assert 'hello' not in d.keys() From 4932157029eafb2fb89da9f717b556d349ce63fb Mon Sep 17 00:00:00 2001 From: samsja Date: Tue, 2 May 2023 13:38:25 +0200 Subject: [PATCH 2/6] fix: fix mypy Signed-off-by: samsja --- docarray/base_doc/doc.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/docarray/base_doc/doc.py b/docarray/base_doc/doc.py index 33835254b93..d3a666a53c6 100644 --- a/docarray/base_doc/doc.py +++ b/docarray/base_doc/doc.py @@ -277,8 +277,9 @@ def dict( if exclude is None: exclude = set(doclist_exclude_fields) elif isinstance(exclude, AbstractSet): - exclude.update(doclist_exclude_fields) + exclude = set(*exclude, *doclist_exclude_fields) elif isinstance(exclude, Mapping): + exclude = dict(**exclude) exclude.update({field: ... 
for field in doclist_exclude_fields}) data = super().dict( @@ -293,7 +294,8 @@ def dict( for field in doclist_exclude_fields: # we need to do this because pydantic will not recognize DocList correctly - if field not in original_exclude: - data[field] = [doc.dict() for doc in getattr(self, field)] + if original_exclude: + if field not in original_exclude: + data[field] = [doc.dict() for doc in getattr(self, field)] return data From d51a253a8b04c24678e5cd4b3d84d92561386ec0 Mon Sep 17 00:00:00 2001 From: samsja Date: Tue, 2 May 2023 15:48:04 +0200 Subject: [PATCH 3/6] fix: rename test Signed-off-by: samsja --- tests/units/document/test_base_document.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/units/document/test_base_document.py b/tests/units/document/test_base_document.py index 8c0f594a4cb..256d7a60131 100644 --- a/tests/units/document/test_base_document.py +++ b/tests/units/document/test_base_document.py @@ -70,17 +70,17 @@ def test_nested_to_dict(nested_docs): assert (d['docs'][0]['simple_tens'] == np.ones(10)).all() -def test_nested_to_dict_exclude_1(nested_docs): +def test_nested_to_dict_exclude(nested_docs): d = nested_docs.dict(exclude={'docs'}) assert 'docs' not in d.keys() -def test_nested_to_dict_exclude_2(nested_docs): +def test_nested_to_dict_exclude_set(nested_docs): d = nested_docs.dict(exclude={'hello'}) assert 'hello' not in d.keys() -def test_nested_to_dict_exclude_3(nested_docs): # doto change +def test_nested_to_dict_exclude_dict(nested_docs): # doto change d = nested_docs.dict(exclude={'hello': True}) assert 'docs' not in d.keys() assert 'hello' not in d.keys() From df2489f002df897e1c0225405447c9aec050257f Mon Sep 17 00:00:00 2001 From: samsja Date: Tue, 2 May 2023 15:57:50 +0200 Subject: [PATCH 4/6] fix: fix isinstance check Signed-off-by: samsja --- docarray/base_doc/doc.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docarray/base_doc/doc.py b/docarray/base_doc/doc.py index 
d3a666a53c6..8f150257ad6 100644 --- a/docarray/base_doc/doc.py +++ b/docarray/base_doc/doc.py @@ -109,7 +109,7 @@ def summary(self) -> None: @classmethod def schema_summary(cls) -> None: - """Print a summary of the Documents schema.""" + """Print a summary Fof the Documents schema.""" from docarray.display.document_summary import DocumentSummary DocumentSummary.schema_summary(cls) @@ -270,7 +270,8 @@ def dict( for field in self.__fields__.keys(): from docarray import DocList - if issubclass(self._get_field_type(field), DocList): + type_ = self._get_field_type(field) + if isinstance(type_, type) and issubclass(type_, DocList): doclist_exclude_fields.append(field) original_exclude = exclude From 29f469dcf843ddb3bb9ff611213a09bffd1d8ecd Mon Sep 17 00:00:00 2001 From: samsja Date: Tue, 2 May 2023 16:01:58 +0200 Subject: [PATCH 5/6] fix: fix exclude bug Signed-off-by: samsja --- docarray/base_doc/doc.py | 8 ++++---- tests/units/document/test_base_document.py | 1 - 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/docarray/base_doc/doc.py b/docarray/base_doc/doc.py index 8f150257ad6..6f55bbdc5db 100644 --- a/docarray/base_doc/doc.py +++ b/docarray/base_doc/doc.py @@ -278,7 +278,7 @@ def dict( if exclude is None: exclude = set(doclist_exclude_fields) elif isinstance(exclude, AbstractSet): - exclude = set(*exclude, *doclist_exclude_fields) + exclude = set([*exclude, *doclist_exclude_fields]) elif isinstance(exclude, Mapping): exclude = dict(**exclude) exclude.update({field: ... 
for field in doclist_exclude_fields}) @@ -295,8 +295,8 @@ def dict( for field in doclist_exclude_fields: # we need to do this because pydantic will not recognize DocList correctly - if original_exclude: - if field not in original_exclude: - data[field] = [doc.dict() for doc in getattr(self, field)] + original_exclude = original_exclude or {} + if field not in original_exclude: + data[field] = [doc.dict() for doc in getattr(self, field)] return data diff --git a/tests/units/document/test_base_document.py b/tests/units/document/test_base_document.py index 256d7a60131..74b9e2d5332 100644 --- a/tests/units/document/test_base_document.py +++ b/tests/units/document/test_base_document.py @@ -82,5 +82,4 @@ def test_nested_to_dict_exclude_set(nested_docs): def test_nested_to_dict_exclude_dict(nested_docs): # doto change d = nested_docs.dict(exclude={'hello': True}) - assert 'docs' not in d.keys() assert 'hello' not in d.keys() From d09e33db99403f8592cb75e6abd6ee3c7fcca305 Mon Sep 17 00:00:00 2001 From: Joan Fontanals Date: Tue, 2 May 2023 18:49:41 +0200 Subject: [PATCH 6/6] docs: update docarray/base_doc/doc.py Signed-off-by: Joan Fontanals --- docarray/base_doc/doc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docarray/base_doc/doc.py b/docarray/base_doc/doc.py index 6f55bbdc5db..4457859360d 100644 --- a/docarray/base_doc/doc.py +++ b/docarray/base_doc/doc.py @@ -109,7 +109,7 @@ def summary(self) -> None: @classmethod def schema_summary(cls) -> None: - """Print a summary Fof the Documents schema.""" + """Print a summary of the Documents schema.""" from docarray.display.document_summary import DocumentSummary DocumentSummary.schema_summary(cls)