Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 48 additions & 1 deletion tests/units/array/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch

from docarray import BaseDocument, DocumentArray
from docarray.typing import NdArray, TorchTensor
from docarray.typing import ImageUrl, NdArray, TorchTensor
from docarray.utils.misc import is_tf_available

tf_available = is_tf_available()
Expand Down Expand Up @@ -319,3 +319,50 @@ class Text(BaseDocument):
da = DocumentArray[Text].construct(docs)

assert da._data is docs


def test_reverse():
class Text(BaseDocument):
text: str

docs = [Text(text=f'hello {i}') for i in range(10)]

da = DocumentArray[Text](docs)
da.reverse()
assert da[-1].text == 'hello 0'
assert da[0].text == 'hello 9'


class Image(BaseDocument):
tensor: Optional[NdArray]
url: ImageUrl


def test_remove():
images = [Image(url=f'http://url.com/foo_{i}.png') for i in range(3)]
da = DocumentArray[Image](images)
da.remove(images[1])
assert len(da) == 2
assert da[0] == images[0]
assert da[1] == images[2]


def test_pop():
images = [Image(url=f'http://url.com/foo_{i}.png') for i in range(3)]
da = DocumentArray[Image](images)
popped = da.pop(1)
assert len(da) == 2
assert popped == images[1]
assert da[0] == images[0]
assert da[1] == images[2]


def test_sort():
images = [
Image(url=f'http://url.com/foo_{i}.png', tensor=NdArray(i)) for i in [2, 0, 1]
]
da = DocumentArray[Image](images)
da.sort(key=lambda img: len(img.tensor))
assert len(da) == 3
assert da[0].url == 'http://url.com/foo_0.png'
assert da[1].url == 'http://url.com/foo_1.png'