Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion docarray/array/doc_list/doc_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
overload,
)

from pydantic import parse_obj_as
from typing_extensions import SupportsIndex
from typing_inspect import is_union_type

Expand Down Expand Up @@ -261,8 +262,13 @@ def validate(

if isinstance(value, (cls, DocVec)):
return value
elif isinstance(value, Iterable):
elif isinstance(value, cls):
return cls(value)
elif isinstance(value, Iterable):
docs = []
for doc in value:
docs.append(parse_obj_as(cls.doc_type, doc))
return cls(docs)
else:
raise TypeError(f'Expecting an Iterable of {cls.doc_type}')

Expand Down
106 changes: 77 additions & 29 deletions docarray/base_doc/doc.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,25 @@
import os
import warnings
from typing import (
TYPE_CHECKING,
AbstractSet,
Any,
Callable,
Dict,
List,
Mapping,
Optional,
Tuple,
Type,
TypeVar,
Union,
cast,
no_type_check,
)

import orjson
from pydantic import BaseModel, Field
from pydantic.main import ROOT_KEY
from rich.console import Console

from docarray.base_doc.base_node import BaseNode
Expand All @@ -36,6 +41,9 @@
T_update = TypeVar('T_update', bound='UpdateMixin')


ExcludeType = Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']]


class BaseDoc(BaseModel, IOMixin, UpdateMixin, BaseNode):
"""
BaseDoc is the base class for all Documents. This class should be subclassed
Expand Down Expand Up @@ -191,7 +199,7 @@ def json(
self,
*,
include: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None,
exclude: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None,
exclude: ExcludeType = None,
by_alias: bool = False,
skip_defaults: Optional[bool] = None,
exclude_unset: bool = False,
Expand All @@ -208,19 +216,46 @@ def json(
`encoder` is an optional function to supply as `default` to json.dumps(),
other arguments as per `json.dumps()`.
"""
return super().json(
include=include,
exclude=exclude,
by_alias=by_alias,
skip_defaults=skip_defaults,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
encoder=encoder,
models_as_dict=models_as_dict,
**dumps_kwargs,
exclude, original_exclude, doclist_exclude_fields = self._exclude_doclist(
exclude=exclude
)

# this is copy from pydantic code
if skip_defaults is not None:
warnings.warn(
f'{self.__class__.__name__}.json(): "skip_defaults" is deprecated and replaced by "exclude_unset"',
DeprecationWarning,
)
exclude_unset = skip_defaults
encoder = cast(Callable[[Any], Any], encoder or self.__json_encoder__)

# We don't directly call `self.dict()`, which does exactly this with `to_dict=True`
# because we want to be able to keep raw `BaseModel` instances and not as `dict`.
# This allows users to write custom JSON encoders for given `BaseModel` classes.
data = dict(
self._iter(
to_dict=models_as_dict,
by_alias=by_alias,
include=include,
exclude=exclude,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
)

# this is the custom part to deal with DocList
for field in doclist_exclude_fields:
# we need to do this because pydantic will not recognize DocList correctly
original_exclude = original_exclude or {}
if field not in original_exclude:
data[field] = [doc.dict() for doc in getattr(self, field)]

# this is copy from pydantic code
if self.__custom_root_type__:
data = data[ROOT_KEY]
return self.__config__.json_dumps(data, default=encoder, **dumps_kwargs)

@no_type_check
@classmethod
def parse_raw(
Expand Down Expand Up @@ -253,7 +288,7 @@ def dict(
self,
*,
include: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None,
exclude: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None,
exclude: ExcludeType = None,
by_alias: bool = False,
skip_defaults: Optional[bool] = None,
exclude_unset: bool = False,
Expand All @@ -266,22 +301,9 @@ def dict(

"""

doclist_exclude_fields = []
for field in self.__fields__.keys():
from docarray import DocList

type_ = self._get_field_type(field)
if isinstance(type_, type) and issubclass(type_, DocList):
doclist_exclude_fields.append(field)

original_exclude = exclude
if exclude is None:
exclude = set(doclist_exclude_fields)
elif isinstance(exclude, AbstractSet):
exclude = set([*exclude, *doclist_exclude_fields])
elif isinstance(exclude, Mapping):
exclude = dict(**exclude)
exclude.update({field: ... for field in doclist_exclude_fields})
exclude, original_exclude, doclist_exclude_fields = self._exclude_doclist(
exclude=exclude
)

data = super().dict(
include=include,
Expand All @@ -301,4 +323,30 @@ def dict(

return data

def _exclude_doclist(
    self, exclude: ExcludeType
) -> Tuple[ExcludeType, ExcludeType, List[str]]:
    """Augment ``exclude`` so that every ``DocList``-typed field is excluded.

    pydantic does not serialize ``DocList`` fields correctly, so ``dict()`` and
    ``json()`` exclude them from the pydantic pass and serialize them manually
    afterwards.

    :param exclude: the caller-supplied exclude (``None``, a set, or a mapping).
    :return: a 3-tuple of (augmented exclude, the caller's original exclude
        untouched, the names of the DocList-typed fields).
    """
    doclist_exclude_fields = []
    for field in self.__fields__.keys():
        # local import — presumably to avoid a circular import at module
        # load time; TODO confirm
        from docarray import DocList

        type_ = self._get_field_type(field)
        if isinstance(type_, type) and issubclass(type_, DocList):
            doclist_exclude_fields.append(field)

    # Keep the original so callers can tell user-requested exclusions apart
    # from the DocList ones added here.
    original_exclude = exclude
    if exclude is None:
        exclude = set(doclist_exclude_fields)
    elif isinstance(exclude, AbstractSet):
        exclude = set([*exclude, *doclist_exclude_fields])
    elif isinstance(exclude, Mapping):
        # copy before mutating so the caller's mapping is not modified
        exclude = dict(**exclude)
        exclude.update({field: ... for field in doclist_exclude_fields})

    return (
        exclude,
        original_exclude,
        doclist_exclude_fields,
    )

to_json = json
16 changes: 16 additions & 0 deletions tests/units/array/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
import pytest
import torch
from pydantic import parse_obj_as

from docarray import BaseDoc, DocList
from docarray.typing import ImageUrl, NdArray, TorchTensor
Expand Down Expand Up @@ -452,3 +453,18 @@ class Image(BaseDoc):
assert docs.features == [None for _ in range(10)]
assert isinstance(docs.features, list)
assert not isinstance(docs.features, DocList)


def test_validate_list_dict():
    # Plain dicts must be coerced into documents by DocList validation.
    indices = [2, 0, 1]
    payload = []
    for i in indices:
        payload.append({'url': f'http://url.com/foo_{i}.png', 'tensor': NdArray(i)})

    docs = parse_obj_as(DocList[Image], payload)

    expected_urls = [f'http://url.com/foo_{i}.png' for i in indices]
    assert docs.url == expected_urls
10 changes: 8 additions & 2 deletions tests/units/document/test_base_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import numpy as np
import pytest

from docarray import BaseDoc, DocList
from docarray import DocList
from docarray.base_doc.doc import BaseDoc
from docarray.typing import NdArray


Expand Down Expand Up @@ -80,6 +81,11 @@ def test_nested_to_dict_exclude_set(nested_docs):
assert 'hello' not in d.keys()


def test_nested_to_dict_exclude_dict(nested_docs): # todo change
def test_nested_to_dict_exclude_dict(nested_docs):
    # `exclude` given as a mapping must drop the excluded field from dict().
    d = nested_docs.dict(exclude={'hello': True})
    assert 'hello' not in d.keys()


def test_nested_to_json(nested_docs):
    # Round-trip: json() output must parse back via parse_raw without error.
    serialized = nested_docs.json()
    type(nested_docs).parse_raw(serialized)