From 084c521b7eacfea833591699df84df13f1c2f1ab Mon Sep 17 00:00:00 2001 From: samsja Date: Wed, 22 Mar 2023 10:11:08 +0100 Subject: [PATCH 1/6] fix: remove tensor type from DocumentArray Signed-off-by: samsja --- docarray/array/abstract_array.py | 2 - docarray/array/array/array.py | 18 +++---- docarray/array/stacked/array_stacked.py | 15 +++--- docarray/data/torch_dataset.py | 8 +-- tests/units/array/stack/test_array_stacked.py | 52 ++++++++----------- tests/units/array/test_indexing.py | 38 +++++++------- 6 files changed, 62 insertions(+), 71 deletions(-) diff --git a/docarray/array/abstract_array.py b/docarray/array/abstract_array.py index fcfc3194764..ab84662b81a 100644 --- a/docarray/array/abstract_array.py +++ b/docarray/array/abstract_array.py @@ -21,7 +21,6 @@ from docarray.base_document import BaseDocument from docarray.display.document_array_summary import DocumentArraySummary -from docarray.typing import NdArray from docarray.typing.abstract_type import AbstractType from docarray.utils._typing import change_cls_name @@ -36,7 +35,6 @@ class AnyDocumentArray(Sequence[T_doc], Generic[T_doc], AbstractType): document_type: Type[BaseDocument] - tensor_type: Type['AbstractTensor'] = NdArray __typed_da__: Dict[Type['AnyDocumentArray'], Dict[Type[BaseDocument], Type]] = {} def __repr__(self): diff --git a/docarray/array/array/array.py b/docarray/array/array/array.py index bf43586a1ac..700d97980ae 100644 --- a/docarray/array/array/array.py +++ b/docarray/array/array/array.py @@ -117,7 +117,6 @@ class Image(BaseDocument): del da[0:5] # remove elements for 0 to 5 from DocumentArray :param docs: iterable of Document - :param tensor_type: Class used to wrap the tensors of the Documents when stacked """ @@ -126,27 +125,22 @@ class Image(BaseDocument): def __init__( self, docs: Optional[Iterable[T_doc]] = None, - tensor_type: Type['AbstractTensor'] = NdArray, ): self._data: List[T_doc] = list(self._validate_docs(docs)) if docs else [] - self.tensor_type = tensor_type @classmethod def construct( cls: Type[T], docs: Sequence[T_doc], - tensor_type: Type['AbstractTensor'] = NdArray, ) -> T: """ Create a DocumentArray without validation any data. The data must come from a trusted source :param docs: a Sequence (list) of Document with the same schema - :param tensor_type: Class used to wrap the tensors of the Documents when stacked :return: """ da = cls.__new__(cls) da._data = docs if isinstance(docs, list) else list(docs) - da.tensor_type = tensor_type return da def _validate_docs(self, docs: Iterable[T_doc]) -> Iterable[T_doc]: @@ -227,7 +221,7 @@ def _get_data_column( # most likely a bug in mypy though # bug reported here https://github.com/python/mypy/issues/14111 return DocumentArray.__class_getitem__(field_type)( - (getattr(doc, field) for doc in self), tensor_type=self.tensor_type + (getattr(doc, field) for doc in self), ) else: return [getattr(doc, field) for doc in self] @@ -247,15 +241,21 @@ def _set_data_column( for doc, value in zip(self, values): setattr(doc, field, value) - def stack(self) -> 'DocumentArrayStacked': + def stack( + self, + tensor_type: Type['AbstractTensor'] = NdArray, + ) -> 'DocumentArrayStacked': """ Convert the DocumentArray into a DocumentArrayStacked. `Self` cannot be used afterwards + :param tensor_type: TensorClass used to wrap the tensors of the Documents when + stacked + :return: A DocumentArrayStacked of the same document type as self """ from docarray.array.stacked.array_stacked import DocumentArrayStacked return DocumentArrayStacked.__class_getitem__(self.document_type)( - self, tensor_type=self.tensor_type + self, tensor_type=tensor_type ) @classmethod diff --git a/docarray/array/stacked/array_stacked.py b/docarray/array/stacked/array_stacked.py index 318401b3ea5..5a8764ef8a9 100644 --- a/docarray/array/stacked/array_stacked.py +++ b/docarray/array/stacked/array_stacked.py @@ -158,12 +158,17 @@ def __init__( cast(AbstractTensor, tensor_columns[field_name])[i] = val elif issubclass(field_type, BaseDocument): - doc_columns[field_name] = getattr(docs, field_name).stack() + doc_columns[field_name] = getattr(docs, field_name).stack( + tensor_type=self.tensor_type + ) - elif issubclass(field_type, DocumentArray): + elif issubclass(field_type, AnyDocumentArray): docs_list = list() for doc in docs: - docs_list.append(getattr(doc, field_name).stack()) + da = getattr(doc, field_name) + if isinstance(da, DocumentArray): + da = da.stack(tensor_type=self.tensor_type) + docs_list.append(da) da_columns[field_name] = ListAdvancedIndexing(docs_list) else: any_columns[field_name] = ListAdvancedIndexing( @@ -507,9 +512,7 @@ def unstack(self: T) -> DocumentArray[T_doc]: del self._storage - return DocumentArray.__class_getitem__(self.document_type).construct( - docs, tensor_type=self.tensor_type - ) + return DocumentArray.__class_getitem__(self.document_type).construct(docs) def traverse_flat( self, diff --git a/docarray/data/torch_dataset.py b/docarray/data/torch_dataset.py index d25de6f0732..3032826b4a4 100644 --- a/docarray/data/torch_dataset.py +++ b/docarray/data/torch_dataset.py @@ -2,7 +2,7 @@ from torch.utils.data import Dataset -from docarray import BaseDocument, DocumentArray +from docarray import BaseDocument, DocumentArray, DocumentArrayStacked from docarray.typing import TorchTensor from docarray.utils._typing import change_cls_name @@ -123,13 +123,13 @@ def __getitem__(self, item: int): def collate_fn(cls, batch: List[T_doc]): doc_type = cls.document_type if doc_type: - batch_da = DocumentArray[doc_type]( # type: ignore + batch_da = DocumentArrayStacked[doc_type]( # type: ignore batch, tensor_type=TorchTensor, ) else: - batch_da = DocumentArray(batch, tensor_type=TorchTensor) - return batch_da.stack() + batch_da = DocumentArrayStacked(batch, tensor_type=TorchTensor) + return batch_da @classmethod def __class_getitem__(cls, item: Type[BaseDocument]) -> Type['MultiModalDataset']: diff --git a/tests/units/array/stack/test_array_stacked.py b/tests/units/array/stack/test_array_stacked.py index 8dddf61e47c..dcaa3e89a0a 100644 --- a/tests/units/array/stack/test_array_stacked.py +++ b/tests/units/array/stack/test_array_stacked.py @@ -263,10 +263,10 @@ def test_any_tensor_with_torch(tensor_type, tensor): class ImageDoc(BaseDocument): tensor: AnyTensor - da = DocumentArray[ImageDoc]( + da = DocumentArrayStacked[ImageDoc]( [ImageDoc(tensor=tensor) for _ in range(10)], tensor_type=tensor_type, - ).stack() + ) for i in range(len(da)): assert (da[i].tensor == tensor).all() @@ -284,10 +284,10 @@ class ImageDoc(BaseDocument): class TopDoc(BaseDocument): img: ImageDoc - da = DocumentArray[TopDoc]( + da = DocumentArrayStacked[TopDoc]( [TopDoc(img=ImageDoc(tensor=tensor)) for _ in range(10)], tensor_type=TorchTensor, - ).stack() + ) for i in range(len(da)): assert (da.img[i].tensor == tensor).all() @@ -300,9 +300,9 @@ def test_dict_stack(): class MyDoc(BaseDocument): my_dict: Dict[str, int] - da = DocumentArray[MyDoc]( + da = DocumentArrayStacked[MyDoc]( [MyDoc(my_dict={'a': 1, 'b': 2}) for _ in range(10)] - ).stack() + ) da.my_dict @@ -314,9 +314,9 @@ class Doc(BaseDocument): N = 10 - da = DocumentArray[Doc]( + da = DocumentArrayStacked[Doc]( [Doc(text=f'hello{i}', tensor=np.zeros((3, 224, 224))) for i in range(N)] - ).stack() + ) da_sliced = da[0:10:2] assert isinstance(da_sliced, DocumentArrayStacked) @@ -334,9 +334,7 @@ def test_stack_embedding(): class MyDoc(BaseDocument): embedding: AnyEmbedding - da = DocumentArray[MyDoc]( - [MyDoc(embedding=np.zeros(10)) for _ in range(10)] - ).stack() + da = DocumentArrayStacked[MyDoc]([MyDoc(embedding=np.zeros(10)) for _ in range(10)]) assert 'embedding' in da._storage.tensor_columns.keys() assert (da.embedding == np.zeros((10, 10))).all() @@ -347,18 +345,17 @@ def test_stack_none(tensor_backend): class MyDoc(BaseDocument): tensor: Optional[AnyTensor] - da = DocumentArray[MyDoc]( + da = DocumentArrayStacked[MyDoc]( [MyDoc(tensor=None) for _ in range(10)], tensor_type=tensor_backend - ).stack() + ) assert 'tensor' in da._storage.tensor_columns.keys() def test_to_device(): - da = DocumentArray[ImageDoc]( + da = DocumentArrayStacked[ImageDoc]( [ImageDoc(tensor=torch.zeros(3, 5))], tensor_type=TorchTensor ) - da = da.stack() assert da.tensor.device == torch.device('cpu') da.to('meta') assert da.tensor.device == torch.device('meta') @@ -368,12 +365,11 @@ def test_to_device_with_nested_da(): class Video(BaseDocument): images: DocumentArray[ImageDoc] - da_image = DocumentArray[ImageDoc]( + da_image = DocumentArrayStacked[ImageDoc]( [ImageDoc(tensor=torch.zeros(3, 5))], tensor_type=TorchTensor ) - da = DocumentArray[Video]([Video(images=da_image)]) - da = da.stack() + da = DocumentArrayStacked[Video]([Video(images=da_image)]) assert da.images[0].tensor.device == torch.device('cpu') da.to('meta') assert da.images[0].tensor.device == torch.device('meta') @@ -384,11 +380,10 @@ class MyDoc(BaseDocument): tensor: TorchTensor docs: ImageDoc - da = DocumentArray[MyDoc]( + da = DocumentArrayStacked[MyDoc]( [MyDoc(tensor=torch.zeros(3, 5), docs=ImageDoc(tensor=torch.zeros(3, 5)))], tensor_type=TorchTensor, ) - da = da.stack() assert da.tensor.device == torch.device('cpu') assert da.docs.tensor.device == torch.device('cpu') da.to('meta') @@ -397,10 +392,9 @@ class MyDoc(BaseDocument): def test_to_device_numpy(): - da = DocumentArray[ImageDoc]( + da = DocumentArrayStacked[ImageDoc]( [ImageDoc(tensor=np.zeros((3, 5)))], tensor_type=NdArray ) - da = da.stack() with pytest.raises(NotImplementedError): da.to('meta') @@ -444,9 +438,7 @@ def test_np_scalar(): class MyDoc(BaseDocument): scalar: NdArray - da = DocumentArray[MyDoc]( - [MyDoc(scalar=np.array(2.0)) for _ in range(3)], tensor_type=NdArray - ) + da = DocumentArray[MyDoc]([MyDoc(scalar=np.array(2.0)) for _ in range(3)]) assert all(doc.scalar.ndim == 0 for doc in da) assert all(doc.scalar == 2.0 for doc in da) @@ -467,11 +459,11 @@ class MyDoc(BaseDocument): scalar: TorchTensor da = DocumentArray[MyDoc]( - [MyDoc(scalar=torch.tensor(2.0)) for _ in range(3)], tensor_type=TorchTensor + [MyDoc(scalar=torch.tensor(2.0)) for _ in range(3)], ) assert all(doc.scalar.ndim == 0 for doc in da) assert all(doc.scalar == 2.0 for doc in da) - stacked_da = da.stack() + stacked_da = da.stack(tensor_type=TorchTensor) assert type(stacked_da.scalar) == TorchTensor assert all(type(doc.scalar) == TorchTensor for doc in stacked_da) @@ -486,7 +478,7 @@ def test_np_nan(): class MyDoc(BaseDocument): scalar: Optional[NdArray] - da = DocumentArray[MyDoc]([MyDoc() for _ in range(3)], tensor_type=NdArray) + da = DocumentArray[MyDoc]([MyDoc() for _ in range(3)]) assert all(doc.scalar is None for doc in da) assert all(doc.scalar == doc.scalar for doc in da) stacked_da = da.stack() @@ -505,10 +497,10 @@ def test_torch_nan(): class MyDoc(BaseDocument): scalar: Optional[TorchTensor] - da = DocumentArray[MyDoc]([MyDoc() for _ in range(3)], tensor_type=TorchTensor) + da = DocumentArray[MyDoc]([MyDoc() for _ in range(3)]) assert all(doc.scalar is None for doc in da) assert all(doc.scalar == doc.scalar for doc in da) - stacked_da = da.stack() + stacked_da = da.stack(tensor_type=TorchTensor) assert type(stacked_da.scalar) == TorchTensor assert all(type(doc.scalar) == TorchTensor for doc in stacked_da) diff --git a/tests/units/array/test_indexing.py b/tests/units/array/test_indexing.py index a16d78424b6..9d875b1b6bf 100644 --- a/tests/units/array/test_indexing.py +++ b/tests/units/array/test_indexing.py @@ -2,7 +2,7 @@ import pytest import torch -from docarray import DocumentArray +from docarray import DocumentArray, DocumentArrayStacked from docarray.documents import TextDoc from docarray.typing import TorchTensor @@ -13,7 +13,6 @@ def da(): tensors = [torch.ones((4,)) * i for i in range(10)] return DocumentArray[TextDoc]( [TextDoc(text=text, embedding=tens) for text, tens in zip(texts, tensors)], - tensor_type=TorchTensor, ) @@ -23,7 +22,6 @@ def da_to_set(): tensors = [torch.ones((4,)) * i * 2 for i in range(5)] return DocumentArray[TextDoc]( [TextDoc(text=text, embedding=tens) for text, tens in zip(texts, tensors)], - tensor_type=TorchTensor, ) @@ -35,7 +33,7 @@ def da_to_set(): @pytest.mark.parametrize('stack', [True, False]) def test_simple_getitem(stack, da): if stack: - da = da.stack() + da = da.stack(tensor_type=TorchTensor) assert torch.all(da[0].embedding == torch.zeros((4,))) assert da[0].text == 'hello 0' @@ -44,7 +42,7 @@ def test_simple_getitem(stack, da): @pytest.mark.parametrize('stack', [True, False]) def test_get_none(stack, da): if stack: - da = da.stack() + da = da.stack(tensor_type=TorchTensor) assert da[None] is da @@ -53,7 +51,7 @@ def test_get_none(stack, da): @pytest.mark.parametrize('index', [(1, 2, 3, 4, 6), [1, 2, 3, 4, 6]]) def test_iterable_getitem(stack, da, index): if stack: - da = da.stack() + da = da.stack(tensor_type=TorchTensor) indexed_da = da[index] @@ -66,7 +64,7 @@ def test_iterable_getitem(stack, da, index): @pytest.mark.parametrize('index_dtype', [torch.int64]) def test_torchtensor_getitem(stack, da, index_dtype): if stack: - da = da.stack() + da = da.stack(tensor_type=TorchTensor) index = torch.tensor([1, 2, 3, 4, 6], dtype=index_dtype) @@ -81,7 +79,7 @@ def test_torchtensor_getitem(stack, da, index_dtype): @pytest.mark.parametrize('index_dtype', [int, np.int_, np.int32, np.int64]) def test_nparray_getitem(stack, da, index_dtype): if stack: - da = da.stack() + da = da.stack(tensor_type=TorchTensor) index = np.array([1, 2, 3, 4, 6], dtype=index_dtype) @@ -103,7 +101,7 @@ def test_nparray_getitem(stack, da, index_dtype): ) def test_boolmask_getitem(stack, da, index): if stack: - da = da.stack() + da = da.stack(tensor_type=TorchTensor) indexed_da = da[index] @@ -122,7 +120,7 @@ def test_boolmask_getitem(stack, da, index): @pytest.mark.parametrize('stack_left', [True, False]) def test_simple_setitem(stack_left, da, da_to_set): if stack_left: - da = da.stack() + da = da.stack(tensor_type=TorchTensor) da[0] = da_to_set[0] @@ -135,9 +133,9 @@ def test_simple_setitem(stack_left, da, da_to_set): @pytest.mark.parametrize('index', [(1, 2, 3, 4, 6), [1, 2, 3, 4, 6]]) def test_iterable_setitem(stack_left, stack_right, da, da_to_set, index): if stack_left: - da = da.stack() + da = da.stack(tensor_type=TorchTensor) if stack_right: - da_to_set = da_to_set.stack() + da_to_set = da_to_set.stack(tensor_type=TorchTensor) da[index] = da_to_set @@ -158,9 +156,9 @@ def test_iterable_setitem(stack_left, stack_right, da, da_to_set, index): @pytest.mark.parametrize('index_dtype', [torch.int64]) def test_torchtensor_setitem(stack_left, stack_right, da, da_to_set, index_dtype): if stack_left: - da = da.stack() + da = da.stack(tensor_type=TorchTensor) if stack_right: - da_to_set = da_to_set.stack() + da_to_set = da_to_set.stack(tensor_type=TorchTensor) index = torch.tensor([1, 2, 3, 4, 6], dtype=index_dtype) @@ -183,9 +181,9 @@ def test_torchtensor_setitem(stack_left, stack_right, da, da_to_set, index_dtype @pytest.mark.parametrize('index_dtype', [int, np.int_, np.int32, np.int64]) def test_nparray_setitem(stack_left, stack_right, da, da_to_set, index_dtype): if stack_left: - da = da.stack() + da = da.stack(tensor_type=TorchTensor) if stack_right: - da_to_set = da_to_set.stack() + da_to_set = da_to_set.stack(tensor_type=TorchTensor) index = np.array([1, 2, 3, 4, 6], dtype=index_dtype) @@ -216,9 +214,9 @@ def test_nparray_setitem(stack_left, stack_right, da, da_to_set, index_dtype): ) def test_boolmask_setitem(stack_left, stack_right, da, da_to_set, index): if stack_left: - da = da.stack() + da = da.stack(tensor_type=TorchTensor) if stack_right: - da_to_set = da_to_set.stack() + da_to_set = da_to_set.stack(tensor_type=TorchTensor) da[index] = da_to_set @@ -238,10 +236,10 @@ def test_boolmask_setitem(stack_left, stack_right, da, da_to_set, index): def test_setitem_update_column(): texts = [f'hello {i}' for i in range(10)] tensors = [torch.ones((4,)) * (i + 1) for i in range(10)] - da = DocumentArray[TextDoc]( + da = DocumentArrayStacked[TextDoc]( [TextDoc(text=text, embedding=tens) for text, tens in zip(texts, tensors)], tensor_type=TorchTensor, - ).stack() + ) da[0] = TextDoc(text='hello', embedding=torch.zeros((4,))) From 4f71a8451df81abb62a63830c48757ae88dd2131 Mon Sep 17 00:00:00 2001 From: samsja Date: Wed, 22 Mar 2023 10:23:32 +0100 Subject: [PATCH 2/6] fix: fix test Signed-off-by: samsja --- docarray/array/stacked/array_stacked.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docarray/array/stacked/array_stacked.py b/docarray/array/stacked/array_stacked.py index 5a8764ef8a9..1780e755685 100644 --- a/docarray/array/stacked/array_stacked.py +++ b/docarray/array/stacked/array_stacked.py @@ -323,7 +323,9 @@ def _set_data_and_columns( f'{value} schema : {value.document_type} is not compatible with ' f'this DocumentArrayStacked schema : {self.document_type}' ) - processed_value = cast(T, value.stack()) # we need to copy data here + processed_value = cast( + T, value.stack(tensor_type=self.tensor_type) + ) # we need to copy data here elif isinstance(value, DocumentArrayStacked): if not issubclass(value.document_type, self.document_type): From 6f58315b43f4265061eeb79072a9cdea6c72137b Mon Sep 17 00:00:00 2001 From: samsja Date: Wed, 22 Mar 2023 10:26:57 +0100 Subject: [PATCH 3/6] fix: fix tensorflow test Signed-off-by: samsja --- .../array/stack/test_array_stacked_tf.py | 28 ++++++++----------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/tests/units/array/stack/test_array_stacked_tf.py b/tests/units/array/stack/test_array_stacked_tf.py index 2603c17e42c..f18a4604fde 100644 --- a/tests/units/array/stack/test_array_stacked_tf.py +++ b/tests/units/array/stack/test_array_stacked_tf.py @@ -39,7 +39,7 @@ class MMdoc(BaseDocument): import tensorflow as tf - batch = DocumentArray[MMdoc]( + batch = DocumentArrayStacked[MMdoc]( [ MMdoc( img=DocumentArray[Image]( @@ -50,7 +50,7 @@ class MMdoc(BaseDocument): ] ) - return batch.stack() + return batch @pytest.mark.tensorflow @@ -84,11 +84,10 @@ def test_set_after_stacking(): class Image(BaseDocument): tensor: TensorFlowTensor[3, 224, 224] - batch = DocumentArray[Image]( + batch = DocumentArrayStacked[Image]( [Image(tensor=tf.zeros((3, 224, 224))) for _ in range(10)] ) - batch = batch.stack() batch.tensor = tf.ones((10, 3, 224, 224)) assert tnp.allclose(batch.tensor.tensor, tf.ones((10, 3, 224, 224))) for i, doc in enumerate(batch): @@ -114,9 +113,7 @@ class MMdoc(BaseDocument): batch = DocumentArray[MMdoc]( [MMdoc(img=Image(tensor=tf.zeros((3, 224, 224)))) for _ in range(10)] - ) - - batch = batch.stack() + ).stack() assert tnp.allclose( batch._storage.doc_columns['img']._storage.tensor_columns['tensor'].tensor, @@ -222,10 +219,10 @@ def test_any_tensor_with_tf(): class Image(BaseDocument): tensor: AnyTensor - da = DocumentArray[Image]( + da = DocumentArrayStacked[Image]( [Image(tensor=tensor) for _ in range(10)], tensor_type=TensorFlowTensor, - ).stack() + ) for i in range(len(da)): assert tnp.allclose(da[i].tensor.tensor, tensor) @@ -244,10 +241,10 @@ class Image(BaseDocument): class TopDoc(BaseDocument): img: Image - da = DocumentArray[TopDoc]( + da = DocumentArrayStacked[TopDoc]( [TopDoc(img=Image(tensor=tensor)) for _ in range(10)], tensor_type=TensorFlowTensor, - ).stack() + ) for i in range(len(da)): assert tnp.allclose(da.img[i].tensor.tensor, tensor) @@ -263,9 +260,9 @@ class Doc(BaseDocument): text: str tensor: TensorFlowTensor - da = DocumentArray[Doc]( + da = DocumentArrayStacked[Doc]( [Doc(text=f'hello{i}', tensor=tf.zeros((3, 224, 224))) for i in range(10)] - ).stack() + ) da_sliced = da[0:10:2] assert isinstance(da_sliced, DocumentArrayStacked) @@ -279,10 +276,9 @@ def test_stack_none(): class MyDoc(BaseDocument): tensor: Optional[AnyTensor] - da = DocumentArray[MyDoc]( + da = DocumentArrayStacked[MyDoc]( [MyDoc(tensor=None) for _ in range(10)], tensor_type=TensorFlowTensor - ).stack() - + ) assert 'tensor' in da._storage.tensor_columns.keys() From ab9bfc4ed5a08003d3a921d7d16869f95d353859 Mon Sep 17 00:00:00 2001 From: samsja Date: Wed, 22 Mar 2023 10:31:30 +0100 Subject: [PATCH 4/6] fix: docstrng Signed-off-by: samsja --- docarray/array/array/array.py | 4 ++-- docarray/array/stacked/array_stacked.py | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/docarray/array/array/array.py b/docarray/array/array/array.py index 700d97980ae..d2d8fad19c1 100644 --- a/docarray/array/array/array.py +++ b/docarray/array/array/array.py @@ -248,8 +248,8 @@ def stack( """ Convert the DocumentArray into a DocumentArrayStacked. `Self` cannot be used afterwards - :param tensor_type: TensorClass used to wrap the tensors of the Documents when - stacked + :param tensor_type: Tensor Class used to wrap the stacked tensors. This is usefull + if the BaseDocument has some undefined tensor type like AnyTensor or Union of NdArray and TorchTensor :return: A DocumentArrayStacked of the same document type as self """ from docarray.array.stacked.array_stacked import DocumentArrayStacked diff --git a/docarray/array/stacked/array_stacked.py b/docarray/array/stacked/array_stacked.py index 1780e755685..3b786502319 100644 --- a/docarray/array/stacked/array_stacked.py +++ b/docarray/array/stacked/array_stacked.py @@ -82,8 +82,9 @@ class DocumentArrayStacked(AnyDocumentArray[T_doc]): numpy/PyTorch. :param docs: a DocumentArray - :param tensor_type: Class used to wrap the stacked tensors - + :param tensor_type: Tensor Class used to wrap the stacked tensors. This is usefull + if the BaseDocument of this DocumentArrayStacked has some undefined tensor type like + AnyTensor or Union of NdArray and TorchTensor """ document_type: Type[T_doc] From 475c14c494738e488c5feaa2fcbd890c6469d167 Mon Sep 17 00:00:00 2001 From: samsja <55492238+samsja@users.noreply.github.com> Date: Wed, 22 Mar 2023 10:53:22 +0100 Subject: [PATCH 5/6] feat: apply charllote suggestion Co-authored-by: Charlotte Gerhaher Signed-off-by: samsja <55492238+samsja@users.noreply.github.com> --- docarray/array/array/array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docarray/array/array/array.py b/docarray/array/array/array.py index d2d8fad19c1..4d5a720421e 100644 --- a/docarray/array/array/array.py +++ b/docarray/array/array/array.py @@ -248,7 +248,7 @@ def stack( """ Convert the DocumentArray into a DocumentArrayStacked. `Self` cannot be used afterwards - :param tensor_type: Tensor Class used to wrap the stacked tensors. This is usefull + :param tensor_type: Tensor Class used to wrap the stacked tensors. This is useful if the BaseDocument has some undefined tensor type like AnyTensor or Union of NdArray and TorchTensor :return: A DocumentArrayStacked of the same document type as self """ From d2d4434f890f4c2ef1c18ea15cec87a07695ec9a Mon Sep 17 00:00:00 2001 From: samsja <55492238+samsja@users.noreply.github.com> Date: Wed, 22 Mar 2023 13:19:29 +0100 Subject: [PATCH 6/6] feat: apply saba suggestion Co-authored-by: Saba Sturua <45267439+jupyterjazz@users.noreply.github.com> Signed-off-by: samsja <55492238+samsja@users.noreply.github.com> --- docarray/array/stacked/array_stacked.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docarray/array/stacked/array_stacked.py b/docarray/array/stacked/array_stacked.py index 3b786502319..c609a2acc31 100644 --- a/docarray/array/stacked/array_stacked.py +++ b/docarray/array/stacked/array_stacked.py @@ -82,7 +82,7 @@ class DocumentArrayStacked(AnyDocumentArray[T_doc]): numpy/PyTorch. :param docs: a DocumentArray - :param tensor_type: Tensor Class used to wrap the stacked tensors. This is usefull + :param tensor_type: Tensor Class used to wrap the stacked tensors. This is useful if the BaseDocument of this DocumentArrayStacked has some undefined tensor type like AnyTensor or Union of NdArray and TorchTensor """