diff --git a/tests/units/array/test_array.py b/tests/units/array/test_array.py index 8a5557e4858..af95e05bfb1 100644 --- a/tests/units/array/test_array.py +++ b/tests/units/array/test_array.py @@ -5,7 +5,7 @@ import torch from docarray import BaseDocument, DocumentArray -from docarray.typing import NdArray, TorchTensor +from docarray.typing import ImageUrl, NdArray, TorchTensor from docarray.utils.misc import is_tf_available tf_available = is_tf_available() @@ -319,3 +319,50 @@ class Text(BaseDocument): da = DocumentArray[Text].construct(docs) assert da._data is docs + + +def test_reverse(): + class Text(BaseDocument): + text: str + + docs = [Text(text=f'hello {i}') for i in range(10)] + + da = DocumentArray[Text](docs) + da.reverse() + assert da[-1].text == 'hello 0' + assert da[0].text == 'hello 9' + + +class Image(BaseDocument): + tensor: Optional[NdArray] + url: ImageUrl + + +def test_remove(): + images = [Image(url=f'http://url.com/foo_{i}.png') for i in range(3)] + da = DocumentArray[Image](images) + da.remove(images[1]) + assert len(da) == 2 + assert da[0] == images[0] + assert da[1] == images[2] + + +def test_pop(): + images = [Image(url=f'http://url.com/foo_{i}.png') for i in range(3)] + da = DocumentArray[Image](images) + popped = da.pop(1) + assert len(da) == 2 + assert popped == images[1] + assert da[0] == images[0] + assert da[1] == images[2] + + +def test_sort(): + images = [ + Image(url=f'http://url.com/foo_{i}.png', tensor=NdArray(i)) for i in [2, 0, 1] + ] + da = DocumentArray[Image](images) + da.sort(key=lambda img: len(img.tensor)) + assert len(da) == 3 + assert da[0].url == 'http://url.com/foo_0.png' + assert da[1].url == 'http://url.com/foo_1.png'