From 7073137f38985e24c4291f6844775f664ef202db Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Mon, 16 Jan 2023 11:29:07 +0100 Subject: [PATCH 1/4] fix: check attr type to decide if chunkarray of len one return as list Signed-off-by: anna-charlotte --- docarray/document/mixins/multimodal.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/docarray/document/mixins/multimodal.py b/docarray/document/mixins/multimodal.py index ff1b040c7ce..f3d20596e68 100644 --- a/docarray/document/mixins/multimodal.py +++ b/docarray/document/mixins/multimodal.py @@ -2,7 +2,6 @@ import typing from docarray.dataclasses.types import ( - Field, is_multimodal, _is_field, AttributeTypeError, @@ -104,7 +103,6 @@ def _from_dataclass(cls, obj) -> 'Document': # TODO: may have to modify this? root.tags = tags root._metadata[DocumentMetadata.MULTI_MODAL_SCHEMA] = multi_modal_schema - return root def _get_mm_attr_postion(self, attr): @@ -208,7 +206,13 @@ def _has_multimodal_attr(self, attr): def __getattr__(self, attr): if self._has_multimodal_attr(attr): mm_attr_da = self.get_multi_modal_attribute(attr) - return mm_attr_da if len(mm_attr_da) > 1 else mm_attr_da[0] + attr_type = self._metadata[DocumentMetadata.MULTI_MODAL_SCHEMA][attr][ + 'attribute_type' + ] + if attr_type == AttributeType.ITERABLE_DOCUMENT: + return mm_attr_da + else: + return mm_attr_da if len(mm_attr_da) > 1 else mm_attr_da[0] else: raise AttributeError(f'{self.__class__.__name__} has no attribute `{attr}`') From 7aff87cadab4b2d47107d7bd0fbddb60c7fc5e36 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Mon, 16 Jan 2023 11:31:41 +0100 Subject: [PATCH 2/4] test: add test for chunk retrieval in doc from dataclass Signed-off-by: anna-charlotte --- tests/unit/document/test_multi_modal.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/unit/document/test_multi_modal.py b/tests/unit/document/test_multi_modal.py index 7f9443a2c6d..da2896a8273 100644 --- a/tests/unit/document/test_multi_modal.py +++ b/tests/unit/document/test_multi_modal.py @@ -7,6 +7,7 @@ import pytest from docarray import Document, DocumentArray +from docarray.array.chunk import ChunkArray from docarray.dataclasses import dataclass, field from docarray.typing import Image, Text, Audio, Video, Mesh, Tabular, Blob, JSON from docarray.dataclasses.getter import image_getter @@ -911,3 +912,24 @@ class A: text: List[Text] doc = Document(A(text=[])) + + +def test_doc_with_dataclass_with_list_of_length_one(): + @dataclass + class MyDoc: + title: Text + images: List[Image] + + doc = Document(MyDoc(title='doc 1', images=[IMAGE_URI])) + assert type(doc.images) == ChunkArray + assert len(doc.images) == 1 + + +def test_doc_with_dataclass_without_list(): + @dataclass + class MyDoc: + title: Text + image: Image + + doc = Document(MyDoc(title='doc 1', image=IMAGE_URI)) + assert type(doc.image) == Document From 1ae420ead5eb340beea808aba727303e77f135c6 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Mon, 16 Jan 2023 12:11:33 +0100 Subject: [PATCH 3/4] fix: remove redundant if statement Signed-off-by: anna-charlotte --- docarray/document/mixins/multimodal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docarray/document/mixins/multimodal.py b/docarray/document/mixins/multimodal.py index f3d20596e68..206fb481408 100644 --- a/docarray/document/mixins/multimodal.py +++ b/docarray/document/mixins/multimodal.py @@ -212,7 +212,7 @@ def __getattr__(self, attr): if attr_type == AttributeType.ITERABLE_DOCUMENT: return mm_attr_da else: - return mm_attr_da if len(mm_attr_da) > 1 else mm_attr_da[0] + return mm_attr_da[0] else: raise AttributeError(f'{self.__class__.__name__} has no attribute `{attr}`') From c27b5fac2f6bb6ccd9847cc169781b00287d54f8 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Mon, 16 Jan 2023 12:47:38 +0100 Subject: [PATCH 4/4] fix: check for all iterable attr types Signed-off-by: anna-charlotte --- docarray/document/mixins/multimodal.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/docarray/document/mixins/multimodal.py b/docarray/document/mixins/multimodal.py index 206fb481408..4cebec66d17 100644 --- a/docarray/document/mixins/multimodal.py +++ b/docarray/document/mixins/multimodal.py @@ -209,7 +209,11 @@ def __getattr__(self, attr): attr_type = self._metadata[DocumentMetadata.MULTI_MODAL_SCHEMA][attr][ 'attribute_type' ] - if attr_type == AttributeType.ITERABLE_DOCUMENT: + if attr_type in [ + AttributeType.ITERABLE_DOCUMENT, + AttributeType.ITERABLE_NESTED, + AttributeType.ITERABLE_PRIMITIVE, + ]: return mm_attr_da else: return mm_attr_da[0]