diff --git a/docarray/array/doc_list/doc_list.py b/docarray/array/doc_list/doc_list.py index a9deba3bec6..132cdcb9f66 100644 --- a/docarray/array/doc_list/doc_list.py +++ b/docarray/array/doc_list/doc_list.py @@ -292,3 +292,6 @@ def __class_getitem__(cls, item: Union[Type[BaseDoc], TypeVar, str]): return AnyDocArray.__class_getitem__.__func__(cls, item) # type: ignore else: return super().__class_getitem__(item) + + def __repr__(self): + return AnyDocArray.__repr__(self) # type: ignore diff --git a/docarray/base_doc/doc.py b/docarray/base_doc/doc.py index 1fa9bc3e376..13e2c6ea254 100644 --- a/docarray/base_doc/doc.py +++ b/docarray/base_doc/doc.py @@ -140,7 +140,7 @@ def __setattr__(self, field, value) -> None: object.__setattr__(self, '__dict__', dict_ref) def __eq__(self, other) -> bool: - if self.dict().keys() != other.dict().keys(): + if self.__fields__.keys() != other.__fields__.keys(): return False for field_name in self.__fields__: diff --git a/tests/units/document/test_base_document.py b/tests/units/document/test_base_document.py index e986ff0f1bb..0d62c069dd7 100644 --- a/tests/units/document/test_base_document.py +++ b/tests/units/document/test_base_document.py @@ -24,3 +24,22 @@ class MyDocument(BaseDoc): assert doc1.content == 'Core content updated' assert doc1.title == 'Title' assert doc1.tags_ == ['python', 'AI', 'docarray'] + + +def test_equal_nested_docs(): + import numpy as np + + from docarray import BaseDoc, DocList + from docarray.typing import NdArray + + class SimpleDoc(BaseDoc): + simple_tens: NdArray[10] + + class NestedDoc(BaseDoc): + docs: DocList[SimpleDoc] + + nested_docs = NestedDoc( + docs=DocList[SimpleDoc]([SimpleDoc(simple_tens=np.ones(10)) for j in range(2)]), + ) + + assert nested_docs == nested_docs