diff --git a/docarray/document/mixins/multimodal.py b/docarray/document/mixins/multimodal.py index ff1b040c7ce..4cebec66d17 100644 --- a/docarray/document/mixins/multimodal.py +++ b/docarray/document/mixins/multimodal.py @@ -2,7 +2,6 @@ import typing from docarray.dataclasses.types import ( - Field, is_multimodal, _is_field, AttributeTypeError, @@ -104,7 +103,6 @@ def _from_dataclass(cls, obj) -> 'Document': # TODO: may have to modify this? root.tags = tags root._metadata[DocumentMetadata.MULTI_MODAL_SCHEMA] = multi_modal_schema - return root def _get_mm_attr_postion(self, attr): @@ -208,7 +206,17 @@ def _has_multimodal_attr(self, attr): def __getattr__(self, attr): if self._has_multimodal_attr(attr): mm_attr_da = self.get_multi_modal_attribute(attr) - return mm_attr_da if len(mm_attr_da) > 1 else mm_attr_da[0] + attr_type = self._metadata[DocumentMetadata.MULTI_MODAL_SCHEMA][attr][ + 'attribute_type' + ] + if attr_type in [ + AttributeType.ITERABLE_DOCUMENT, + AttributeType.ITERABLE_NESTED, + AttributeType.ITERABLE_PRIMITIVE, + ]: + return mm_attr_da + else: + return mm_attr_da[0] else: raise AttributeError(f'{self.__class__.__name__} has no attribute `{attr}`') diff --git a/tests/unit/document/test_multi_modal.py b/tests/unit/document/test_multi_modal.py index 7f9443a2c6d..da2896a8273 100644 --- a/tests/unit/document/test_multi_modal.py +++ b/tests/unit/document/test_multi_modal.py @@ -7,6 +7,7 @@ import pytest from docarray import Document, DocumentArray +from docarray.array.chunk import ChunkArray from docarray.dataclasses import dataclass, field from docarray.typing import Image, Text, Audio, Video, Mesh, Tabular, Blob, JSON from docarray.dataclasses.getter import image_getter @@ -911,3 +912,24 @@ class A: text: List[Text] doc = Document(A(text=[])) + + +def test_doc_with_dataclass_with_list_of_length_one(): + @dataclass + class MyDoc: + title: Text + images: List[Image] + + doc = Document(MyDoc(title='doc 1', images=[IMAGE_URI])) + assert type(doc.images) == ChunkArray + assert len(doc.images) == 1 + + +def test_doc_with_dataclass_without_list(): + @dataclass + class MyDoc: + title: Text + image: Image + + doc = Document(MyDoc(title='doc 1', image=IMAGE_URI)) + assert type(doc.image) == Document