From a7330d7e481f63dca7e24711c4407813c14120d9 Mon Sep 17 00:00:00 2001 From: samsja Date: Mon, 8 May 2023 10:07:20 +0200 Subject: [PATCH 1/2] fix: fix nested doc to json Signed-off-by: samsja --- docarray/base_doc/doc.py | 106 +++++++++++++++------ tests/units/document/test_base_document.py | 5 + 2 files changed, 82 insertions(+), 29 deletions(-) diff --git a/docarray/base_doc/doc.py b/docarray/base_doc/doc.py index 86705302108..04a74c8c312 100644 --- a/docarray/base_doc/doc.py +++ b/docarray/base_doc/doc.py @@ -1,20 +1,25 @@ import os +import warnings from typing import ( TYPE_CHECKING, AbstractSet, Any, Callable, Dict, + List, Mapping, Optional, + Tuple, Type, TypeVar, Union, + cast, no_type_check, ) import orjson from pydantic import BaseModel, Field +from pydantic.main import ROOT_KEY from rich.console import Console from docarray.base_doc.base_node import BaseNode @@ -36,6 +41,9 @@ T_update = TypeVar('T_update', bound='UpdateMixin') +ExcludeType = Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] + + class BaseDoc(BaseModel, IOMixin, UpdateMixin, BaseNode): """ BaseDoc is the base class for all Documents. This class should be subclassed @@ -191,7 +199,7 @@ def json( self, *, include: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None, - exclude: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None, + exclude: ExcludeType = None, by_alias: bool = False, skip_defaults: Optional[bool] = None, exclude_unset: bool = False, @@ -208,19 +216,46 @@ def json( `encoder` is an optional function to supply as `default` to json.dumps(), other arguments as per `json.dumps()`. """ - return super().json( - include=include, - exclude=exclude, - by_alias=by_alias, - skip_defaults=skip_defaults, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none, - encoder=encoder, - models_as_dict=models_as_dict, - **dumps_kwargs, + exclude, original_exclude, doclist_exclude_fields = self._exclude_doclist( + exclude=exclude + ) + + # this is copy from pydantic code + if skip_defaults is not None: + warnings.warn( + f'{self.__class__.__name__}.json(): "skip_defaults" is deprecated and replaced by "exclude_unset"', + DeprecationWarning, + ) + exclude_unset = skip_defaults + encoder = cast(Callable[[Any], Any], encoder or self.__json_encoder__) + + # We don't directly call `self.dict()`, which does exactly this with `to_dict=True` + # because we want to be able to keep raw `BaseModel` instances and not as `dict`. + # This allows users to write custom JSON encoders for given `BaseModel` classes. + data = dict( + self._iter( + to_dict=models_as_dict, + by_alias=by_alias, + include=include, + exclude=exclude, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) ) + # this is the custom part to deal with DocList + for field in doclist_exclude_fields: + # we need to do this because pydantic will not recognize DocList correctly + original_exclude = original_exclude or {} + if field not in original_exclude: + data[field] = [doc.dict() for doc in getattr(self, field)] + + # this is copy from pydantic code + if self.__custom_root_type__: + data = data[ROOT_KEY] + return self.__config__.json_dumps(data, default=encoder, **dumps_kwargs) + @no_type_check @classmethod def parse_raw( @@ -253,7 +288,7 @@ def dict( self, *, include: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None, - exclude: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None, + exclude: ExcludeType = None, by_alias: bool = False, skip_defaults: Optional[bool] = None, exclude_unset: bool = False, @@ -266,22 +301,9 @@ def dict( """ - doclist_exclude_fields = [] - for field in self.__fields__.keys(): - from docarray import DocList - - type_ = self._get_field_type(field) - if isinstance(type_, type) and issubclass(type_, DocList): - doclist_exclude_fields.append(field) - - original_exclude = exclude - if exclude is None: - exclude = set(doclist_exclude_fields) - elif isinstance(exclude, AbstractSet): - exclude = set([*exclude, *doclist_exclude_fields]) - elif isinstance(exclude, Mapping): - exclude = dict(**exclude) - exclude.update({field: ... for field in doclist_exclude_fields}) + exclude, original_exclude, doclist_exclude_fields = self._exclude_doclist( + exclude=exclude + ) data = super().dict( include=include, @@ -301,4 +323,30 @@ def dict( return data + def _exclude_doclist( + self, exclude: ExcludeType + ) -> Tuple[ExcludeType, ExcludeType, List[str]]: + doclist_exclude_fields = [] + for field in self.__fields__.keys(): + from docarray import DocList + + type_ = self._get_field_type(field) + if isinstance(type_, type) and issubclass(type_, DocList): + doclist_exclude_fields.append(field) + + original_exclude = exclude + if exclude is None: + exclude = set(doclist_exclude_fields) + elif isinstance(exclude, AbstractSet): + exclude = set([*exclude, *doclist_exclude_fields]) + elif isinstance(exclude, Mapping): + exclude = dict(**exclude) + exclude.update({field: ... for field in doclist_exclude_fields}) + + return ( + exclude, + original_exclude, + doclist_exclude_fields, + ) + to_json = json diff --git a/tests/units/document/test_base_document.py b/tests/units/document/test_base_document.py index 74b9e2d5332..2636d1b3d94 100644 --- a/tests/units/document/test_base_document.py +++ b/tests/units/document/test_base_document.py @@ -83,3 +83,8 @@ def test_nested_to_dict_exclude_set(nested_docs): def test_nested_to_dict_exclude_dict(nested_docs): # doto change d = nested_docs.dict(exclude={'hello': True}) assert 'hello' not in d.keys() + + +def test_nested_to_json(nested_docs): + nested_docs.json() + # nested_docs.__class__.parse_raw(d) From f3c7f1a720a60fc6be3c4265951034c4ce2ed9c7 Mon Sep 17 00:00:00 2001 From: samsja Date: Mon, 8 May 2023 10:49:53 +0200 Subject: [PATCH 2/2] fix: to json Signed-off-by: samsja --- docarray/array/doc_list/doc_list.py | 8 +++++++- tests/units/array/test_array.py | 16 ++++++++++++++++ tests/units/document/test_base_document.py | 9 +++++---- 3 files changed, 28 insertions(+), 5 deletions(-) diff --git a/docarray/array/doc_list/doc_list.py b/docarray/array/doc_list/doc_list.py index 07d3d37fedd..c683feca637 100644 --- a/docarray/array/doc_list/doc_list.py +++ b/docarray/array/doc_list/doc_list.py @@ -13,6 +13,7 @@ overload, ) +from pydantic import parse_obj_as from typing_extensions import SupportsIndex from typing_inspect import is_union_type @@ -261,8 +262,13 @@ def validate( if isinstance(value, (cls, DocVec)): return value - elif isinstance(value, Iterable): + elif isinstance(value, cls): return cls(value) + elif isinstance(value, Iterable): + docs = [] + for doc in value: + docs.append(parse_obj_as(cls.doc_type, doc)) + return cls(docs) else: raise TypeError(f'Expecting an Iterable of {cls.doc_type}') diff --git a/tests/units/array/test_array.py b/tests/units/array/test_array.py index 33344bd2aaa..020701d4f2e 100644 --- a/tests/units/array/test_array.py +++ b/tests/units/array/test_array.py @@ -3,6 +3,7 @@ import numpy as np import pytest import torch +from pydantic import parse_obj_as from docarray import BaseDoc, DocList from docarray.typing import ImageUrl, NdArray, TorchTensor @@ -452,3 +453,18 @@ class Image(BaseDoc): assert docs.features == [None for _ in range(10)] assert isinstance(docs.features, list) assert not isinstance(docs.features, DocList) + + +def test_validate_list_dict(): + + images = [ + dict(url=f'http://url.com/foo_{i}.png', tensor=NdArray(i)) for i in [2, 0, 1] + ] + + docs = parse_obj_as(DocList[Image], images) + + assert docs.url == [ + 'http://url.com/foo_2.png', + 'http://url.com/foo_0.png', + 'http://url.com/foo_1.png', + ] diff --git a/tests/units/document/test_base_document.py b/tests/units/document/test_base_document.py index 2636d1b3d94..1f68d751eaf 100644 --- a/tests/units/document/test_base_document.py +++ b/tests/units/document/test_base_document.py @@ -3,7 +3,8 @@ import numpy as np import pytest -from docarray import BaseDoc, DocList +from docarray import DocList +from docarray.base_doc.doc import BaseDoc from docarray.typing import NdArray @@ -80,11 +81,11 @@ def test_nested_to_dict_exclude_set(nested_docs): assert 'hello' not in d.keys() -def test_nested_to_dict_exclude_dict(nested_docs): # doto change +def test_nested_to_dict_exclude_dict(nested_docs): d = nested_docs.dict(exclude={'hello': True}) assert 'hello' not in d.keys() def test_nested_to_json(nested_docs): - nested_docs.json() - # nested_docs.__class__.parse_raw(d) + d = nested_docs.json() + nested_docs.__class__.parse_raw(d)