Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 63 additions & 5 deletions docarray/base_doc/doc.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import os
from typing import (
TYPE_CHECKING,
AbstractSet,
Any,
Callable,
Dict,
Mapping,
Optional,
Type,
TypeVar,
Expand All @@ -24,7 +26,7 @@
if TYPE_CHECKING:
from pydantic import Protocol
from pydantic.types import StrBytes
from pydantic.typing import AbstractSetIntStr, MappingIntStrAny
from pydantic.typing import AbstractSetIntStr, DictStrAny, MappingIntStrAny

from docarray.array.doc_vec.column_storage import ColumnStorageView

Expand Down Expand Up @@ -57,7 +59,8 @@ class MyDoc(BaseDoc):
```


BaseDoc is a subclass of [pydantic.BaseModel](https://docs.pydantic.dev/usage/models/) and can be used in a similar way.
BaseDoc is a subclass of [pydantic.BaseModel](
https://docs.pydantic.dev/usage/models/) and can be used in a similar way.
"""

id: Optional[ID] = Field(default_factory=lambda: ID(os.urandom(16).hex()))
Expand Down Expand Up @@ -180,7 +183,8 @@ def _docarray_to_json_compatible(self) -> Dict:
return self.dict()

########################################################################################################################################################
### this section is just for documentation purposes will be removed later once https://github.com/mkdocstrings/griffe/issues/138 is fixed ##############
### this section is just for documentation purposes will be removed later once
# https://github.com/mkdocstrings/griffe/issues/138 is fixed ##############
########################################################################################################################################################

def json(
Expand All @@ -198,9 +202,11 @@ def json(
**dumps_kwargs: Any,
) -> str:
"""
Generate a JSON representation of the model, `include` and `exclude` arguments as per `dict()`.
Generate a JSON representation of the model, `include` and `exclude`
arguments as per `dict()`.

`encoder` is an optional function to supply as `default` to json.dumps(), other arguments as per `json.dumps()`.
`encoder` is an optional function to supply as `default` to json.dumps(),
other arguments as per `json.dumps()`.
"""
return super().json(
include=include,
Expand Down Expand Up @@ -242,3 +248,55 @@ def parse_raw(
proto=proto,
allow_pickle=allow_pickle,
)

def dict(
self,
*,
include: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None,
exclude: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None,
by_alias: bool = False,
skip_defaults: Optional[bool] = None,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
) -> 'DictStrAny':
"""
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we add description of each argument?

Generate a dictionary representation of the model, optionally specifying
which fields to include or exclude.

"""

doclist_exclude_fields = []
for field in self.__fields__.keys():
from docarray import DocList

type_ = self._get_field_type(field)
if isinstance(type_, type) and issubclass(type_, DocList):
doclist_exclude_fields.append(field)

original_exclude = exclude
if exclude is None:
exclude = set(doclist_exclude_fields)
elif isinstance(exclude, AbstractSet):
exclude = set([*exclude, *doclist_exclude_fields])
elif isinstance(exclude, Mapping):
exclude = dict(**exclude)
exclude.update({field: ... for field in doclist_exclude_fields})

data = super().dict(
include=include,
exclude=exclude,
by_alias=by_alias,
skip_defaults=skip_defaults,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)

for field in doclist_exclude_fields:
# we need to do this because pydantic will not recognize DocList correctly
original_exclude = original_exclude or {}
if field not in original_exclude:
data[field] = [doc.dict() for doc in getattr(self, field)]

return data
42 changes: 41 additions & 1 deletion tests/units/document/test_base_document.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from typing import List, Optional

from docarray.base_doc.doc import BaseDoc
import numpy as np
import pytest

from docarray import BaseDoc, DocList
from docarray.typing import NdArray


def test_base_document_init():
Expand Down Expand Up @@ -43,3 +47,39 @@ class NestedDoc(BaseDoc):
)

assert nested_docs == nested_docs


@pytest.fixture
def nested_docs():
class SimpleDoc(BaseDoc):
simple_tens: NdArray[10]

class NestedDoc(BaseDoc):
docs: DocList[SimpleDoc]
hello: str = 'world'

nested_docs = NestedDoc(
docs=DocList[SimpleDoc]([SimpleDoc(simple_tens=np.ones(10)) for j in range(2)]),
)

return nested_docs


def test_nested_to_dict(nested_docs):
d = nested_docs.dict()
assert (d['docs'][0]['simple_tens'] == np.ones(10)).all()


def test_nested_to_dict_exclude(nested_docs):
d = nested_docs.dict(exclude={'docs'})
assert 'docs' not in d.keys()


def test_nested_to_dict_exclude_set(nested_docs):
d = nested_docs.dict(exclude={'hello'})
assert 'hello' not in d.keys()


def test_nested_to_dict_exclude_dict(nested_docs): # doto change
d = nested_docs.dict(exclude={'hello': True})
assert 'hello' not in d.keys()