diff --git a/docarray/array/abstract_array.py b/docarray/array/abstract_array.py index fcfc3194764..ab84662b81a 100644 --- a/docarray/array/abstract_array.py +++ b/docarray/array/abstract_array.py @@ -21,7 +21,6 @@ from docarray.base_document import BaseDocument from docarray.display.document_array_summary import DocumentArraySummary -from docarray.typing import NdArray from docarray.typing.abstract_type import AbstractType from docarray.utils._typing import change_cls_name @@ -36,7 +35,6 @@ class AnyDocumentArray(Sequence[T_doc], Generic[T_doc], AbstractType): document_type: Type[BaseDocument] - tensor_type: Type['AbstractTensor'] = NdArray __typed_da__: Dict[Type['AnyDocumentArray'], Dict[Type[BaseDocument], Type]] = {} def __repr__(self): diff --git a/docarray/array/array/array.py b/docarray/array/array/array.py index bf43586a1ac..4d5a720421e 100644 --- a/docarray/array/array/array.py +++ b/docarray/array/array/array.py @@ -117,7 +117,6 @@ class Image(BaseDocument): del da[0:5] # remove elements for 0 to 5 from DocumentArray :param docs: iterable of Document - :param tensor_type: Class used to wrap the tensors of the Documents when stacked """ @@ -126,27 +125,22 @@ class Image(BaseDocument): def __init__( self, docs: Optional[Iterable[T_doc]] = None, - tensor_type: Type['AbstractTensor'] = NdArray, ): self._data: List[T_doc] = list(self._validate_docs(docs)) if docs else [] - self.tensor_type = tensor_type @classmethod def construct( cls: Type[T], docs: Sequence[T_doc], - tensor_type: Type['AbstractTensor'] = NdArray, ) -> T: """ Create a DocumentArray without validation any data. The data must come from a trusted source :param docs: a Sequence (list) of Document with the same schema - :param tensor_type: Class used to wrap the tensors of the Documents when stacked :return: """ da = cls.__new__(cls) da._data = docs if isinstance(docs, list) else list(docs) - da.tensor_type = tensor_type return da def _validate_docs(self, docs: Iterable[T_doc]) -> Iterable[T_doc]: @@ -227,7 +221,7 @@ def _get_data_column( # most likely a bug in mypy though # bug reported here https://github.com/python/mypy/issues/14111 return DocumentArray.__class_getitem__(field_type)( - (getattr(doc, field) for doc in self), tensor_type=self.tensor_type + (getattr(doc, field) for doc in self), ) else: return [getattr(doc, field) for doc in self] @@ -247,15 +241,21 @@ def _set_data_column( for doc, value in zip(self, values): setattr(doc, field, value) - def stack(self) -> 'DocumentArrayStacked': + def stack( + self, + tensor_type: Type['AbstractTensor'] = NdArray, + ) -> 'DocumentArrayStacked': """ Convert the DocumentArray into a DocumentArrayStacked. `Self` cannot be used afterwards + :param tensor_type: Tensor Class used to wrap the stacked tensors. This is useful + if the BaseDocument has some undefined tensor type like AnyTensor or Union of NdArray and TorchTensor + :return: A DocumentArrayStacked of the same document type as self """ from docarray.array.stacked.array_stacked import DocumentArrayStacked return DocumentArrayStacked.__class_getitem__(self.document_type)( - self, tensor_type=self.tensor_type + self, tensor_type=tensor_type ) @classmethod diff --git a/docarray/array/stacked/array_stacked.py b/docarray/array/stacked/array_stacked.py index 318401b3ea5..c609a2acc31 100644 --- a/docarray/array/stacked/array_stacked.py +++ b/docarray/array/stacked/array_stacked.py @@ -82,8 +82,9 @@ class DocumentArrayStacked(AnyDocumentArray[T_doc]): numpy/PyTorch. :param docs: a DocumentArray - :param tensor_type: Class used to wrap the stacked tensors - + :param tensor_type: Tensor Class used to wrap the stacked tensors. This is useful + if the BaseDocument of this DocumentArrayStacked has some undefined tensor type like + AnyTensor or Union of NdArray and TorchTensor """ document_type: Type[T_doc] @@ -158,12 +159,17 @@ def __init__( cast(AbstractTensor, tensor_columns[field_name])[i] = val elif issubclass(field_type, BaseDocument): - doc_columns[field_name] = getattr(docs, field_name).stack() + doc_columns[field_name] = getattr(docs, field_name).stack( + tensor_type=self.tensor_type + ) - elif issubclass(field_type, DocumentArray): + elif issubclass(field_type, AnyDocumentArray): docs_list = list() for doc in docs: - docs_list.append(getattr(doc, field_name).stack()) + da = getattr(doc, field_name) + if isinstance(da, DocumentArray): + da = da.stack(tensor_type=self.tensor_type) + docs_list.append(da) da_columns[field_name] = ListAdvancedIndexing(docs_list) else: any_columns[field_name] = ListAdvancedIndexing( @@ -318,7 +324,9 @@ def _set_data_and_columns( f'{value} schema : {value.document_type} is not compatible with ' f'this DocumentArrayStacked schema : {self.document_type}' ) - processed_value = cast(T, value.stack()) # we need to copy data here + processed_value = cast( + T, value.stack(tensor_type=self.tensor_type) + ) # we need to copy data here elif isinstance(value, DocumentArrayStacked): if not issubclass(value.document_type, self.document_type): @@ -507,9 +515,7 @@ def unstack(self: T) -> DocumentArray[T_doc]: del self._storage - return DocumentArray.__class_getitem__(self.document_type).construct( - docs, tensor_type=self.tensor_type - ) + return DocumentArray.__class_getitem__(self.document_type).construct(docs) def traverse_flat( self, diff --git a/docarray/data/torch_dataset.py b/docarray/data/torch_dataset.py index d25de6f0732..3032826b4a4 100644 --- a/docarray/data/torch_dataset.py +++ b/docarray/data/torch_dataset.py @@ -2,7 +2,7 @@ from torch.utils.data import Dataset -from docarray import BaseDocument, DocumentArray +from docarray import BaseDocument, DocumentArray, DocumentArrayStacked from docarray.typing import TorchTensor from docarray.utils._typing import change_cls_name @@ -123,13 +123,13 @@ def __getitem__(self, item: int): def collate_fn(cls, batch: List[T_doc]): doc_type = cls.document_type if doc_type: - batch_da = DocumentArray[doc_type]( # type: ignore + batch_da = DocumentArrayStacked[doc_type]( # type: ignore batch, tensor_type=TorchTensor, ) else: - batch_da = DocumentArray(batch, tensor_type=TorchTensor) - return batch_da.stack() + batch_da = DocumentArrayStacked(batch, tensor_type=TorchTensor) + return batch_da @classmethod def __class_getitem__(cls, item: Type[BaseDocument]) -> Type['MultiModalDataset']: diff --git a/tests/units/array/stack/test_array_stacked.py b/tests/units/array/stack/test_array_stacked.py index 8dddf61e47c..dcaa3e89a0a 100644 --- a/tests/units/array/stack/test_array_stacked.py +++ b/tests/units/array/stack/test_array_stacked.py @@ -263,10 +263,10 @@ def test_any_tensor_with_torch(tensor_type, tensor): class ImageDoc(BaseDocument): tensor: AnyTensor - da = DocumentArray[ImageDoc]( + da = DocumentArrayStacked[ImageDoc]( [ImageDoc(tensor=tensor) for _ in range(10)], tensor_type=tensor_type, - ).stack() + ) for i in range(len(da)): assert (da[i].tensor == tensor).all() @@ -284,10 +284,10 @@ class ImageDoc(BaseDocument): class TopDoc(BaseDocument): img: ImageDoc - da = DocumentArray[TopDoc]( + da = DocumentArrayStacked[TopDoc]( [TopDoc(img=ImageDoc(tensor=tensor)) for _ in range(10)], tensor_type=TorchTensor, - ).stack() + ) for i in range(len(da)): assert (da.img[i].tensor == tensor).all() @@ -300,9 +300,9 @@ def test_dict_stack(): class MyDoc(BaseDocument): my_dict: Dict[str, int] - da = DocumentArray[MyDoc]( + da = DocumentArrayStacked[MyDoc]( [MyDoc(my_dict={'a': 1, 'b': 2}) for _ in range(10)] - ).stack() + ) da.my_dict @@ -314,9 +314,9 @@ class Doc(BaseDocument): N = 10 - da = DocumentArray[Doc]( + da = DocumentArrayStacked[Doc]( [Doc(text=f'hello{i}', tensor=np.zeros((3, 224, 224))) for i in range(N)] - ).stack() + ) da_sliced = da[0:10:2] assert isinstance(da_sliced, DocumentArrayStacked) @@ -334,9 +334,7 @@ def test_stack_embedding(): class MyDoc(BaseDocument): embedding: AnyEmbedding - da = DocumentArray[MyDoc]( - [MyDoc(embedding=np.zeros(10)) for _ in range(10)] - ).stack() + da = DocumentArrayStacked[MyDoc]([MyDoc(embedding=np.zeros(10)) for _ in range(10)]) assert 'embedding' in da._storage.tensor_columns.keys() assert (da.embedding == np.zeros((10, 10))).all() @@ -347,18 +345,17 @@ def test_stack_none(tensor_backend): class MyDoc(BaseDocument): tensor: Optional[AnyTensor] - da = DocumentArray[MyDoc]( + da = DocumentArrayStacked[MyDoc]( [MyDoc(tensor=None) for _ in range(10)], tensor_type=tensor_backend - ).stack() + ) assert 'tensor' in da._storage.tensor_columns.keys() def test_to_device(): - da = DocumentArray[ImageDoc]( + da = DocumentArrayStacked[ImageDoc]( [ImageDoc(tensor=torch.zeros(3, 5))], tensor_type=TorchTensor ) - da = da.stack() assert da.tensor.device == torch.device('cpu') da.to('meta') assert da.tensor.device == torch.device('meta') @@ -368,12 +365,11 @@ def test_to_device_with_nested_da(): class Video(BaseDocument): images: DocumentArray[ImageDoc] - da_image = DocumentArray[ImageDoc]( + da_image = DocumentArrayStacked[ImageDoc]( [ImageDoc(tensor=torch.zeros(3, 5))], tensor_type=TorchTensor ) - da = DocumentArray[Video]([Video(images=da_image)]) - da = da.stack() + da = DocumentArrayStacked[Video]([Video(images=da_image)]) assert da.images[0].tensor.device == torch.device('cpu') da.to('meta') assert da.images[0].tensor.device == torch.device('meta') @@ -384,11 +380,10 @@ class MyDoc(BaseDocument): tensor: TorchTensor docs: ImageDoc - da = DocumentArray[MyDoc]( + da = DocumentArrayStacked[MyDoc]( [MyDoc(tensor=torch.zeros(3, 5), docs=ImageDoc(tensor=torch.zeros(3, 5)))], tensor_type=TorchTensor, ) - da = da.stack() assert da.tensor.device == torch.device('cpu') assert da.docs.tensor.device == torch.device('cpu') da.to('meta') @@ -397,10 +392,9 @@ class MyDoc(BaseDocument): def test_to_device_numpy(): - da = DocumentArray[ImageDoc]( + da = DocumentArrayStacked[ImageDoc]( [ImageDoc(tensor=np.zeros((3, 5)))], tensor_type=NdArray ) - da = da.stack() with pytest.raises(NotImplementedError): da.to('meta') @@ -444,9 +438,7 @@ def test_np_scalar(): class MyDoc(BaseDocument): scalar: NdArray - da = DocumentArray[MyDoc]( - [MyDoc(scalar=np.array(2.0)) for _ in range(3)], tensor_type=NdArray - ) + da = DocumentArray[MyDoc]([MyDoc(scalar=np.array(2.0)) for _ in range(3)]) assert all(doc.scalar.ndim == 0 for doc in da) assert all(doc.scalar == 2.0 for doc in da) @@ -467,11 +459,11 @@ class MyDoc(BaseDocument): scalar: TorchTensor da = DocumentArray[MyDoc]( - [MyDoc(scalar=torch.tensor(2.0)) for _ in range(3)], tensor_type=TorchTensor + [MyDoc(scalar=torch.tensor(2.0)) for _ in range(3)], ) assert all(doc.scalar.ndim == 0 for doc in da) assert all(doc.scalar == 2.0 for doc in da) - stacked_da = da.stack() + stacked_da = da.stack(tensor_type=TorchTensor) assert type(stacked_da.scalar) == TorchTensor assert all(type(doc.scalar) == TorchTensor for doc in stacked_da) @@ -486,7 +478,7 @@ def test_np_nan(): class MyDoc(BaseDocument): scalar: Optional[NdArray] - da = DocumentArray[MyDoc]([MyDoc() for _ in range(3)], tensor_type=NdArray) + da = DocumentArray[MyDoc]([MyDoc() for _ in range(3)]) assert all(doc.scalar is None for doc in da) assert all(doc.scalar == doc.scalar for doc in da) stacked_da = da.stack() @@ -505,10 +497,10 @@ def test_torch_nan(): class MyDoc(BaseDocument): scalar: Optional[TorchTensor] - da = DocumentArray[MyDoc]([MyDoc() for _ in range(3)], tensor_type=TorchTensor) + da = DocumentArray[MyDoc]([MyDoc() for _ in range(3)]) assert all(doc.scalar is None for doc in da) assert all(doc.scalar == doc.scalar for doc in da) - stacked_da = da.stack() + stacked_da = da.stack(tensor_type=TorchTensor) assert type(stacked_da.scalar) == TorchTensor assert all(type(doc.scalar) == TorchTensor for doc in stacked_da) diff --git a/tests/units/array/stack/test_array_stacked_tf.py b/tests/units/array/stack/test_array_stacked_tf.py index 2603c17e42c..f18a4604fde 100644 --- a/tests/units/array/stack/test_array_stacked_tf.py +++ b/tests/units/array/stack/test_array_stacked_tf.py @@ -39,7 +39,7 @@ class MMdoc(BaseDocument): import tensorflow as tf - batch = DocumentArray[MMdoc]( + batch = DocumentArrayStacked[MMdoc]( [ MMdoc( img=DocumentArray[Image]( @@ -50,7 +50,7 @@ class MMdoc(BaseDocument): ] ) - return batch.stack() + return batch @pytest.mark.tensorflow @@ -84,11 +84,10 @@ def test_set_after_stacking(): class Image(BaseDocument): tensor: TensorFlowTensor[3, 224, 224] - batch = DocumentArray[Image]( + batch = DocumentArrayStacked[Image]( [Image(tensor=tf.zeros((3, 224, 224))) for _ in range(10)] ) - batch = batch.stack() batch.tensor = tf.ones((10, 3, 224, 224)) assert tnp.allclose(batch.tensor.tensor, tf.ones((10, 3, 224, 224))) for i, doc in enumerate(batch): @@ -114,9 +113,7 @@ class MMdoc(BaseDocument): batch = DocumentArray[MMdoc]( [MMdoc(img=Image(tensor=tf.zeros((3, 224, 224)))) for _ in range(10)] - ) - - batch = batch.stack() + ).stack() assert tnp.allclose( batch._storage.doc_columns['img']._storage.tensor_columns['tensor'].tensor, @@ -222,10 +219,10 @@ def test_any_tensor_with_tf(): class Image(BaseDocument): tensor: AnyTensor - da = DocumentArray[Image]( + da = DocumentArrayStacked[Image]( [Image(tensor=tensor) for _ in range(10)], tensor_type=TensorFlowTensor, - ).stack() + ) for i in range(len(da)): assert tnp.allclose(da[i].tensor.tensor, tensor) @@ -244,10 +241,10 @@ class Image(BaseDocument): class TopDoc(BaseDocument): img: Image - da = DocumentArray[TopDoc]( + da = DocumentArrayStacked[TopDoc]( [TopDoc(img=Image(tensor=tensor)) for _ in range(10)], tensor_type=TensorFlowTensor, - ).stack() + ) for i in range(len(da)): assert tnp.allclose(da.img[i].tensor.tensor, tensor) @@ -263,9 +260,9 @@ class Doc(BaseDocument): text: str tensor: TensorFlowTensor - da = DocumentArray[Doc]( + da = DocumentArrayStacked[Doc]( [Doc(text=f'hello{i}', tensor=tf.zeros((3, 224, 224))) for i in range(10)] - ).stack() + ) da_sliced = da[0:10:2] assert isinstance(da_sliced, DocumentArrayStacked) @@ -279,10 +276,9 @@ def test_stack_none(): class MyDoc(BaseDocument): tensor: Optional[AnyTensor] - da = DocumentArray[MyDoc]( + da = DocumentArrayStacked[MyDoc]( [MyDoc(tensor=None) for _ in range(10)], tensor_type=TensorFlowTensor - ).stack() - + ) assert 'tensor' in da._storage.tensor_columns.keys() diff --git a/tests/units/array/test_indexing.py b/tests/units/array/test_indexing.py index a16d78424b6..9d875b1b6bf 100644 --- a/tests/units/array/test_indexing.py +++ b/tests/units/array/test_indexing.py @@ -2,7 +2,7 @@ import pytest import torch -from docarray import DocumentArray +from docarray import DocumentArray, DocumentArrayStacked from docarray.documents import TextDoc from docarray.typing import TorchTensor @@ -13,7 +13,6 @@ def da(): tensors = [torch.ones((4,)) * i for i in range(10)] return DocumentArray[TextDoc]( [TextDoc(text=text, embedding=tens) for text, tens in zip(texts, tensors)], - tensor_type=TorchTensor, ) @@ -23,7 +22,6 @@ def da_to_set(): tensors = [torch.ones((4,)) * i * 2 for i in range(5)] return DocumentArray[TextDoc]( [TextDoc(text=text, embedding=tens) for text, tens in zip(texts, tensors)], - tensor_type=TorchTensor, ) @@ -35,7 +33,7 @@ def da_to_set(): @pytest.mark.parametrize('stack', [True, False]) def test_simple_getitem(stack, da): if stack: - da = da.stack() + da = da.stack(tensor_type=TorchTensor) assert torch.all(da[0].embedding == torch.zeros((4,))) assert da[0].text == 'hello 0' @@ -44,7 +42,7 @@ def test_simple_getitem(stack, da): @pytest.mark.parametrize('stack', [True, False]) def test_get_none(stack, da): if stack: - da = da.stack() + da = da.stack(tensor_type=TorchTensor) assert da[None] is da @@ -53,7 +51,7 @@ def test_get_none(stack, da): @pytest.mark.parametrize('index', [(1, 2, 3, 4, 6), [1, 2, 3, 4, 6]]) def test_iterable_getitem(stack, da, index): if stack: - da = da.stack() + da = da.stack(tensor_type=TorchTensor) indexed_da = da[index] @@ -66,7 +64,7 @@ def test_iterable_getitem(stack, da, index): @pytest.mark.parametrize('index_dtype', [torch.int64]) def test_torchtensor_getitem(stack, da, index_dtype): if stack: - da = da.stack() + da = da.stack(tensor_type=TorchTensor) index = torch.tensor([1, 2, 3, 4, 6], dtype=index_dtype) @@ -81,7 +79,7 @@ def test_torchtensor_getitem(stack, da, index_dtype): @pytest.mark.parametrize('index_dtype', [int, np.int_, np.int32, np.int64]) def test_nparray_getitem(stack, da, index_dtype): if stack: - da = da.stack() + da = da.stack(tensor_type=TorchTensor) index = np.array([1, 2, 3, 4, 6], dtype=index_dtype) @@ -103,7 +101,7 @@ def test_nparray_getitem(stack, da, index_dtype): ) def test_boolmask_getitem(stack, da, index): if stack: - da = da.stack() + da = da.stack(tensor_type=TorchTensor) indexed_da = da[index] @@ -122,7 +120,7 @@ def test_boolmask_getitem(stack, da, index): @pytest.mark.parametrize('stack_left', [True, False]) def test_simple_setitem(stack_left, da, da_to_set): if stack_left: - da = da.stack() + da = da.stack(tensor_type=TorchTensor) da[0] = da_to_set[0] @@ -135,9 +133,9 @@ def test_simple_setitem(stack_left, da, da_to_set): @pytest.mark.parametrize('index', [(1, 2, 3, 4, 6), [1, 2, 3, 4, 6]]) def test_iterable_setitem(stack_left, stack_right, da, da_to_set, index): if stack_left: - da = da.stack() + da = da.stack(tensor_type=TorchTensor) if stack_right: - da_to_set = da_to_set.stack() + da_to_set = da_to_set.stack(tensor_type=TorchTensor) da[index] = da_to_set @@ -158,9 +156,9 @@ def test_iterable_setitem(stack_left, stack_right, da, da_to_set, index): @pytest.mark.parametrize('index_dtype', [torch.int64]) def test_torchtensor_setitem(stack_left, stack_right, da, da_to_set, index_dtype): if stack_left: - da = da.stack() + da = da.stack(tensor_type=TorchTensor) if stack_right: - da_to_set = da_to_set.stack() + da_to_set = da_to_set.stack(tensor_type=TorchTensor) index = torch.tensor([1, 2, 3, 4, 6], dtype=index_dtype) @@ -183,9 +181,9 @@ def test_torchtensor_setitem(stack_left, stack_right, da, da_to_set, index_dtype @pytest.mark.parametrize('index_dtype', [int, np.int_, np.int32, np.int64]) def test_nparray_setitem(stack_left, stack_right, da, da_to_set, index_dtype): if stack_left: - da = da.stack() + da = da.stack(tensor_type=TorchTensor) if stack_right: - da_to_set = da_to_set.stack() + da_to_set = da_to_set.stack(tensor_type=TorchTensor) index = np.array([1, 2, 3, 4, 6], dtype=index_dtype) @@ -216,9 +214,9 @@ def test_nparray_setitem(stack_left, stack_right, da, da_to_set, index_dtype): ) def test_boolmask_setitem(stack_left, stack_right, da, da_to_set, index): if stack_left: - da = da.stack() + da = da.stack(tensor_type=TorchTensor) if stack_right: - da_to_set = da_to_set.stack() + da_to_set = da_to_set.stack(tensor_type=TorchTensor) da[index] = da_to_set @@ -238,10 +236,10 @@ def test_boolmask_setitem(stack_left, stack_right, da, da_to_set, index): def test_setitem_update_column(): texts = [f'hello {i}' for i in range(10)] tensors = [torch.ones((4,)) * (i + 1) for i in range(10)] - da = DocumentArray[TextDoc]( + da = DocumentArrayStacked[TextDoc]( [TextDoc(text=text, embedding=tens) for text, tens in zip(texts, tensors)], tensor_type=TorchTensor, - ).stack() + ) da[0] = TextDoc(text='hello', embedding=torch.zeros((4,)))