diff --git a/docarray/array/array/array.py b/docarray/array/array/array.py index 73f1a25a17b..e3f56e74fda 100644 --- a/docarray/array/array/array.py +++ b/docarray/array/array/array.py @@ -144,6 +144,14 @@ def construct( da._data = docs if isinstance(docs, list) else list(docs) return da + def __eq__(self, other: Any) -> bool: + if self.__len__() != other.__len__(): + return False + for doc_self, doc_other in zip(self, other): + if doc_self != doc_other: + return False + return True + def _validate_docs(self, docs: Iterable[T_doc]) -> Iterable[T_doc]: """ Validate if an Iterable of Document are compatible with this DocArray diff --git a/docarray/base_doc/doc.py b/docarray/base_doc/doc.py index 50abc6722a7..48afbe6eddd 100644 --- a/docarray/base_doc/doc.py +++ b/docarray/base_doc/doc.py @@ -102,6 +102,39 @@ def __setattr__(self, field, value) -> None: dict_ref[key] = val object.__setattr__(self, '__dict__', dict_ref) + def __eq__(self, other) -> bool: + if self.dict().keys() != other.dict().keys(): + return False + + for field_name in self.__fields__: + value1 = getattr(self, field_name) + value2 = getattr(other, field_name) + + if field_name == 'id': + continue + + if isinstance(value1, AbstractTensor) and isinstance( + value2, AbstractTensor + ): + comp_be1 = value1.get_comp_backend() + comp_be2 = value2.get_comp_backend() + + if comp_be1.shape(value1) != comp_be2.shape(value2): + return False + if ( + not (comp_be1.to_numpy(value1) == comp_be2.to_numpy(value2)) + .all() + .item() + ): + return False + else: + if value1 != value2: + return False + return True + + def __ne__(self, other) -> bool: + return not (self == other) + def _docarray_to_json_compatible(self) -> Dict: """ Convert itself into a json compatible object diff --git a/tests/units/array/test_array.py b/tests/units/array/test_array.py index be4fa6fa505..d47089176bb 100644 --- a/tests/units/array/test_array.py +++ b/tests/units/array/test_array.py @@ -1,5 +1,4 @@ from typing import Optional, TypeVar, Union - import numpy as np import pytest import torch @@ -10,6 +9,8 @@ tf_available = is_tf_available() if tf_available: + import tensorflow as tf + from docarray.typing import TensorFlowTensor @@ -80,6 +81,74 @@ class Text(BaseDoc): assert len(da) == 10 +def test_ndarray_equality(): + class Text(BaseDoc): + tensor: NdArray + + arr1 = Text(tensor=np.zeros(5)) + arr2 = Text(tensor=np.zeros(5)) + arr3 = Text(tensor=np.ones(5)) + arr4 = Text(tensor=np.zeros(4)) + + assert arr1 == arr2 + assert arr1 != arr3 + assert arr1 != arr4 + + +def test_tensor_equality(): + class Text(BaseDoc): + tensor: TorchTensor + + torch1 = Text(tensor=torch.zeros(128)) + torch2 = Text(tensor=torch.zeros(128)) + torch3 = Text(tensor=torch.zeros(126)) + torch4 = Text(tensor=torch.ones(128)) + + assert torch1 == torch2 + assert torch1 != torch3 + assert torch1 != torch4 + + +def test_documentarray(): + class Text(BaseDoc): + text: str + + da1 = DocArray([Text(text='hello')]) + da2 = DocArray([Text(text='hello')]) + + assert da1 == da2 + assert da1 == [Text(text='hello') for _ in range(len(da1))] + assert da2 == [Text(text='hello') for _ in range(len(da2))] + + +@pytest.mark.tensorflow +def test_tensorflowtensor_equality(): + class Text(BaseDoc): + tensor: TensorFlowTensor + + tensor1 = Text(tensor=tf.constant([1, 2, 3, 4, 5, 6])) + tensor2 = Text(tensor=tf.constant([1, 2, 3, 4, 5, 6])) + tensor3 = Text(tensor=tf.constant([[1.0, 2.0], [3.0, 5.0]])) + tensor4 = Text(tensor=tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])) + + assert tensor1 == tensor2 + assert tensor1 != tensor3 + assert tensor1 != tensor4 + + +def test_text_tensor(): + class Text1(BaseDoc): + tensor: NdArray + + class Text2(BaseDoc): + tensor: TorchTensor + + arr_tensor1 = Text1(tensor=np.zeros(2)) + arr_tensor2 = Text2(tensor=torch.zeros(2)) + + assert arr_tensor1 == arr_tensor2 + + def test_get_bulk_attributes_function(): class Mmdoc(BaseDoc): text: str diff --git a/tests/units/document/test_docs_operators.py b/tests/units/document/test_docs_operators.py index 53e7bb71e56..3e0e48f1a05 100644 --- a/tests/units/document/test_docs_operators.py +++ b/tests/units/document/test_docs_operators.py @@ -11,7 +11,7 @@ def test_text_document_operators(): assert doc == doc2 doc3 = TextDoc(id='other-id', text='text', url='http://url.com') - assert doc != doc3 + assert doc == doc3 assert 't' in doc assert 'a' not in doc