Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docarray/array/array/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,14 @@ def construct(
da._data = docs if isinstance(docs, list) else list(docs)
return da

def __eq__(self, other: Any) -> bool:
if self.__len__() != other.__len__():
return False
for doc_self, doc_other in zip(self, other):
if doc_self != doc_other:
return False
return True

def _validate_docs(self, docs: Iterable[T_doc]) -> Iterable[T_doc]:
"""
Validate if an Iterable of Document are compatible with this DocArray
Expand Down
33 changes: 33 additions & 0 deletions docarray/base_doc/doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,39 @@ def __setattr__(self, field, value) -> None:
dict_ref[key] = val
object.__setattr__(self, '__dict__', dict_ref)

def __eq__(self, other) -> bool:
if self.dict().keys() != other.dict().keys():
return False

for field_name in self.__fields__:
value1 = getattr(self, field_name)
value2 = getattr(other, field_name)

if field_name == 'id':
continue

if isinstance(value1, AbstractTensor) and isinstance(
value2, AbstractTensor
):
comp_be1 = value1.get_comp_backend()
comp_be2 = value2.get_comp_backend()

if comp_be1.shape(value1) != comp_be2.shape(value2):
return False
if (
not (comp_be1.to_numpy(value1) == comp_be2.to_numpy(value2))
.all()
.item()
):
return False
else:
if value1 != value2:
return False
return True

def __ne__(self, other) -> bool:
return not (self == other)

def _docarray_to_json_compatible(self) -> Dict:
"""
Convert itself into a json compatible object
Expand Down
71 changes: 70 additions & 1 deletion tests/units/array/test_array.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from typing import Optional, TypeVar, Union

import numpy as np
import pytest
import torch
Expand All @@ -10,6 +9,8 @@

tf_available = is_tf_available()
if tf_available:
import tensorflow as tf

from docarray.typing import TensorFlowTensor


Expand Down Expand Up @@ -80,6 +81,74 @@ class Text(BaseDoc):
assert len(da) == 10


def test_ndarray_equality():
class Text(BaseDoc):
tensor: NdArray

arr1 = Text(tensor=np.zeros(5))
arr2 = Text(tensor=np.zeros(5))
arr3 = Text(tensor=np.ones(5))
arr4 = Text(tensor=np.zeros(4))

assert arr1 == arr2
assert arr1 != arr3
assert arr1 != arr4


def test_tensor_equality():
class Text(BaseDoc):
tensor: TorchTensor

torch1 = Text(tensor=torch.zeros(128))
torch2 = Text(tensor=torch.zeros(128))
torch3 = Text(tensor=torch.zeros(126))
torch4 = Text(tensor=torch.ones(128))

assert torch1 == torch2
assert torch1 != torch3
assert torch1 != torch4


def test_documentarray():
class Text(BaseDoc):
text: str

da1 = DocArray([Text(text='hello')])
da2 = DocArray([Text(text='hello')])

assert da1 == da2
assert da1 == [Text(text='hello') for _ in range(len(da1))]
assert da2 == [Text(text='hello') for _ in range(len(da2))]


@pytest.mark.tensorflow
def test_tensorflowtensor_equality():
class Text(BaseDoc):
tensor: TensorFlowTensor

tensor1 = Text(tensor=tf.constant([1, 2, 3, 4, 5, 6]))
tensor2 = Text(tensor=tf.constant([1, 2, 3, 4, 5, 6]))
tensor3 = Text(tensor=tf.constant([[1.0, 2.0], [3.0, 5.0]]))
tensor4 = Text(tensor=tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]))

assert tensor1 == tensor2
assert tensor1 != tensor3
assert tensor1 != tensor4


def test_text_tensor():
class Text1(BaseDoc):
tensor: NdArray

class Text2(BaseDoc):
tensor: TorchTensor

arr_tensor1 = Text1(tensor=np.zeros(2))
arr_tensor2 = Text2(tensor=torch.zeros(2))

assert arr_tensor1 == arr_tensor2


def test_get_bulk_attributes_function():
class Mmdoc(BaseDoc):
text: str
Expand Down
2 changes: 1 addition & 1 deletion tests/units/document/test_docs_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def test_text_document_operators():
assert doc == doc2

doc3 = TextDoc(id='other-id', text='text', url='http://url.com')
assert doc != doc3
assert doc == doc3

assert 't' in doc
assert 'a' not in doc
Expand Down