diff --git a/docarray/base_doc/doc.py b/docarray/base_doc/doc.py index 13e2c6ea254..4457859360d 100644 --- a/docarray/base_doc/doc.py +++ b/docarray/base_doc/doc.py @@ -1,9 +1,11 @@ import os from typing import ( TYPE_CHECKING, + AbstractSet, Any, Callable, Dict, + Mapping, Optional, Type, TypeVar, @@ -24,7 +26,7 @@ if TYPE_CHECKING: from pydantic import Protocol from pydantic.types import StrBytes - from pydantic.typing import AbstractSetIntStr, MappingIntStrAny + from pydantic.typing import AbstractSetIntStr, DictStrAny, MappingIntStrAny from docarray.array.doc_vec.column_storage import ColumnStorageView @@ -57,7 +59,8 @@ class MyDoc(BaseDoc): ``` - BaseDoc is a subclass of [pydantic.BaseModel](https://docs.pydantic.dev/usage/models/) and can be used in a similar way. + BaseDoc is a subclass of [pydantic.BaseModel]( + https://docs.pydantic.dev/usage/models/) and can be used in a similar way. """ id: Optional[ID] = Field(default_factory=lambda: ID(os.urandom(16).hex())) @@ -180,7 +183,8 @@ def _docarray_to_json_compatible(self) -> Dict: return self.dict() ######################################################################################################################################################## - ### this section is just for documentation purposes will be removed later once https://github.com/mkdocstrings/griffe/issues/138 is fixed ############## + ### this section is just for documentation purposes will be removed later once + # https://github.com/mkdocstrings/griffe/issues/138 is fixed ############## ######################################################################################################################################################## def json( @@ -198,9 +202,11 @@ def json( **dumps_kwargs: Any, ) -> str: """ - Generate a JSON representation of the model, `include` and `exclude` arguments as per `dict()`. + Generate a JSON representation of the model, `include` and `exclude` + arguments as per `dict()`. 
- `encoder` is an optional function to supply as `default` to json.dumps(), other arguments as per `json.dumps()`. + `encoder` is an optional function to supply as `default` to json.dumps(), + other arguments as per `json.dumps()`. """ return super().json( include=include, @@ -242,3 +248,55 @@ def parse_raw( proto=proto, allow_pickle=allow_pickle, ) + + def dict( + self, + *, + include: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None, + exclude: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None, + by_alias: bool = False, + skip_defaults: Optional[bool] = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + ) -> 'DictStrAny': + """ + Generate a dictionary representation of the model, optionally specifying + which fields to include or exclude. + + """ + + doclist_exclude_fields = [] + for field in self.__fields__.keys(): + from docarray import DocList + + type_ = self._get_field_type(field) + if isinstance(type_, type) and issubclass(type_, DocList): + doclist_exclude_fields.append(field) + + original_exclude = exclude + if exclude is None: + exclude = set(doclist_exclude_fields) + elif isinstance(exclude, AbstractSet): + exclude = set([*exclude, *doclist_exclude_fields]) + elif isinstance(exclude, Mapping): + exclude = dict(**exclude) + exclude.update({field: ... 
@pytest.fixture
def nested_docs():
    """A ``NestedDoc`` whose ``docs`` field holds two ``SimpleDoc`` instances."""

    class SimpleDoc(BaseDoc):
        simple_tens: NdArray[10]

    class NestedDoc(BaseDoc):
        docs: DocList[SimpleDoc]
        hello: str = 'world'

    return NestedDoc(
        docs=DocList[SimpleDoc](
            [SimpleDoc(simple_tens=np.ones(10)) for _ in range(2)]
        ),
    )


def test_nested_to_dict(nested_docs):
    """DocList fields survive ``dict()`` as lists of per-document dicts."""
    d = nested_docs.dict()
    assert (d['docs'][0]['simple_tens'] == np.ones(10)).all()


def test_nested_to_dict_exclude(nested_docs):
    """A DocList field named in a set-style ``exclude`` is dropped."""
    d = nested_docs.dict(exclude={'docs'})
    assert 'docs' not in d.keys()


def test_nested_to_dict_exclude_set(nested_docs):
    """Set-style ``exclude`` also drops plain (non-DocList) fields."""
    d = nested_docs.dict(exclude={'hello'})
    assert 'hello' not in d.keys()


def test_nested_to_dict_exclude_dict(nested_docs):
    # TODO: extend with partial (nested) mapping-style exclusion.
    """Mapping-style ``exclude`` is honored for plain fields too."""
    d = nested_docs.dict(exclude={'hello': True})
    assert 'hello' not in d.keys()