From ac0800f94e3bb4b3fff7c8323d5d5b59886ee011 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Fri, 13 Jan 2023 17:03:49 +0100 Subject: [PATCH 1/6] refactor: better column creation Signed-off-by: Sami Jaghouar --- docarray/array/array_stacked.py | 105 +++++++----------- docarray/computation/abstract_comp_backend.py | 5 + docarray/computation/numpy_backend.py | 4 + docarray/computation/torch_backend.py | 4 + tests/units/array/test_array_stacked.py | 2 +- .../numpy_backend/test_basics.py | 5 + .../torch_backend/test_basics.py | 5 + tests/units/util/test_typing.py | 1 + 8 files changed, 68 insertions(+), 63 deletions(-) diff --git a/docarray/array/array_stacked.py b/docarray/array/array_stacked.py index 71d5b3547a8..4ed81b5f4c4 100644 --- a/docarray/array/array_stacked.py +++ b/docarray/array/array_stacked.py @@ -1,17 +1,5 @@ -from collections import defaultdict from contextlib import contextmanager -from typing import ( - TYPE_CHECKING, - Any, - DefaultDict, - Dict, - Iterable, - List, - Type, - TypeVar, - Union, - cast, -) +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Type, TypeVar, Union, cast from docarray.array.abstract_array import AnyDocumentArray from docarray.array.array import DocumentArray @@ -28,10 +16,6 @@ from docarray.typing import TorchTensor from docarray.typing.tensor.abstract_tensor import AbstractTensor -try: - import torch -except ImportError: - torch_imported = False else: from docarray.typing import TorchTensor @@ -112,64 +96,61 @@ def to(self: T, device: str): col_docarray.to(device) @classmethod - def _create_columns( - cls: Type[T], docs: DocumentArray, tensor_type: Type['AbstractTensor'] - ) -> Dict[str, Union[T, AbstractTensor]]: - columns_fields = list() + def _get_columns_schema( + cls: Type[T], + tensor_type: Type[AbstractTensor], + ) -> Dict[str, Union[Type[AbstractTensor], Type[BaseDocument]]]: + """ + Return the list of fields that are tensors and the list of fields that are + documents + :param tensor_type: the default tensor type fallback in case of union of tensor + :return: a tuple of two lists, the first one is the list of fields that are + tensors, the second one is the list of fields that are documents + """ + + column_schema: Dict[str, Union[Type[AbstractTensor], Type[BaseDocument]]] = {} + for field_name, field in cls.document_type.__fields__.items(): field_type = field.outer_type_ if is_tensor_union(field_type): - columns_fields.append(field_name) + column_schema[field_name] = tensor_type elif isinstance(field_type, type): - is_torch_subclass = ( - issubclass(field_type, torch.Tensor) if torch_imported else False - ) + if issubclass(field_type, (BaseDocument, AbstractTensor)): + column_schema[field_name] = field_type - if ( - is_torch_subclass - or issubclass(field_type, BaseDocument) - or issubclass(field_type, NdArray) - ): - columns_fields.append(field_name) + return column_schema - if not columns_fields: - # nothing to stack - return {} + @classmethod + def _create_columns( + cls: Type[T], docs: DocumentArray, tensor_type: Type[AbstractTensor] + ) -> Dict[str, Union[T, AbstractTensor]]: - columns: Dict[str, Union[T, AbstractTensor]] = dict() + if len(docs) == 0: + return {} - columns_to_stack: DefaultDict[ - str, Union[List[AbstractTensor], List[BaseDocument]] - ] = defaultdict( # type: ignore - list # type: ignore - ) # type: ignore + column_schema = cls._get_columns_schema(tensor_type) - for doc in docs: - for field_to_stack in columns_fields: - val = getattr(doc, field_to_stack) - if val is None: - type_ = cls.document_type._get_field_type(field_to_stack) - if is_tensor_union(type_): - val = tensor_type.get_comp_backend().none_value() - columns_to_stack[field_to_stack].append(val) + columns: Dict[str, Union[DocumentArrayStacked, AbstractTensor]] = dict() - for field_to_stack, to_stack in columns_to_stack.items(): + for field, type_ in column_schema.items(): + if issubclass(type_, AbstractTensor): + tensor = getattr(docs[0], field) + column_shape = ( + (len(docs), *tensor.shape) if tensor is not None else (len(docs),) + ) + columns[field] = type_.get_comp_backend().empty(column_shape) - type_ = cls.document_type._get_field_type(field_to_stack) - if is_tensor_union(type_): - columns[field_to_stack] = tensor_type.__docarray_stack__(to_stack) # type: ignore # noqa: E501 - elif isinstance(type_, type): - if issubclass(type_, BaseDocument): - columns[field_to_stack] = DocumentArray.__class_getitem__(type_)( - to_stack, tensor_type=tensor_type - ).stack() + for i, doc in enumerate(docs): + val = getattr(doc, field) + if val is None: + val = tensor_type.get_comp_backend().none_value() - elif issubclass(type_, AbstractTensor): - columns[field_to_stack] = type_.__docarray_stack__(to_stack) # type: ignore # noqa: E501 + columns[field][i] = val + setattr(doc, field, columns[field][i]) + del val - for field_name, column in columns.items(): - for doc, val in zip(docs, column): - setattr(doc, field_name, val) + elif issubclass(type_, BaseDocument): + columns[field] = getattr(docs, field).stack() return columns diff --git a/docarray/computation/abstract_comp_backend.py b/docarray/computation/abstract_comp_backend.py index 691f437d6d9..b2db595cc7f 100644 --- a/docarray/computation/abstract_comp_backend.py +++ b/docarray/computation/abstract_comp_backend.py @@ -31,6 +31,11 @@ def stack( def n_dim(array: 'TTensor') -> int: ... + @staticmethod + @abstractmethod + def empty(shape: Tuple[int, ...]) -> 'TTensor': + ... + @staticmethod @abstractmethod def none_value() -> typing.Any: diff --git a/docarray/computation/numpy_backend.py b/docarray/computation/numpy_backend.py index f1eff8551a6..15549a01e6c 100644 --- a/docarray/computation/numpy_backend.py +++ b/docarray/computation/numpy_backend.py @@ -64,6 +64,10 @@ def to_device( def n_dim(array: 'np.ndarray') -> int: return array.ndim + @staticmethod + def empty(shape: Tuple[int, ...]) -> 'np.ndarray': + return np.empty(shape) + @staticmethod def none_value() -> Any: """Provide a compatible value that represents None in numpy.""" diff --git a/docarray/computation/torch_backend.py b/docarray/computation/torch_backend.py index 4bf02295f3d..1bfb2ae6157 100644 --- a/docarray/computation/torch_backend.py +++ b/docarray/computation/torch_backend.py @@ -60,6 +60,10 @@ def to_device( ) -> Union['torch.Tensor', 'TorchTensor']: return tensor.to(device) + @staticmethod + def empty(shape: Tuple[int, ...]) -> torch.Tensor: + return torch.empty(shape) + @staticmethod def n_dim(array: 'torch.Tensor') -> int: return array.ndim diff --git a/tests/units/array/test_array_stacked.py b/tests/units/array/test_array_stacked.py index 474bd7809a3..8fdfbe97abc 100644 --- a/tests/units/array/test_array_stacked.py +++ b/tests/units/array/test_array_stacked.py @@ -267,7 +267,7 @@ class Doc(BaseDocument): N = 10 da = DocumentArray[Doc]( - (Doc(text=f'hello{i}', tensor=np.zeros((3, 224, 224))) for i in range(N)) + [Doc(text=f'hello{i}', tensor=np.zeros((3, 224, 224))) for i in range(N)] ).stack() da_sliced = da[0:10:2] diff --git a/tests/units/computation_backends/numpy_backend/test_basics.py b/tests/units/computation_backends/numpy_backend/test_basics.py index 1873889f3a5..ea70539b3dc 100644 --- a/tests/units/computation_backends/numpy_backend/test_basics.py +++ b/tests/units/computation_backends/numpy_backend/test_basics.py @@ -7,3 +7,8 @@ def test_to_device(): with pytest.raises(NotImplementedError): NumpyCompBackend.to_device(np.random.rand(10, 3), 'meta') + + +def test_empty(): + array = NumpyCompBackend.empty((10, 3)) + assert array.shape == (10, 3) diff --git a/tests/units/computation_backends/torch_backend/test_basics.py b/tests/units/computation_backends/torch_backend/test_basics.py index 14f337df429..0005135f99b 100644 --- a/tests/units/computation_backends/torch_backend/test_basics.py +++ b/tests/units/computation_backends/torch_backend/test_basics.py @@ -8,3 +8,8 @@ def test_to_device(): assert t.device == torch.device('cpu') t = TorchCompBackend.to_device(t, 'meta') assert t.device == torch.device('meta') + + +def test_empty(): + tensor = TorchCompBackend.empty((10, 3)) + assert tensor.shape == (10, 3) diff --git a/tests/units/util/test_typing.py b/tests/units/util/test_typing.py index fb263c3aaed..bcbed3fd9b1 100644 --- a/tests/units/util/test_typing.py +++ b/tests/units/util/test_typing.py @@ -29,6 +29,7 @@ def test_is_type_tensor(type_, is_tensor): [ (int, False), (TorchTensor, False), + (NdArray, False), (Optional[TorchTensor], True), (Optional[NdArray], True), (Union[NdArray, TorchTensor], True), From ad6d4a0b9baa1b4936ce573d72b7599038310cdb Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Fri, 13 Jan 2023 17:19:06 +0100 Subject: [PATCH 2/6] refactor: better column creation fix tests Signed-off-by: Sami Jaghouar --- docarray/array/array.py | 2 +- docarray/array/array_stacked.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/docarray/array/array.py b/docarray/array/array.py index 4a6b632c5a5..3bd485ad7d3 100644 --- a/docarray/array/array.py +++ b/docarray/array/array.py @@ -123,7 +123,7 @@ def _get_array_attribute( # most likely a bug in mypy though # bug reported here https://github.com/python/mypy/issues/14111 return self.__class__.__class_getitem__(field_type)( - (getattr(doc, field) for doc in self) + (getattr(doc, field) for doc in self), tensor_type=self.tensor_type ) else: return [getattr(doc, field) for doc in self] diff --git a/docarray/array/array_stacked.py b/docarray/array/array_stacked.py index 4ed81b5f4c4..0eee1f65356 100644 --- a/docarray/array/array_stacked.py +++ b/docarray/array/array_stacked.py @@ -138,7 +138,9 @@ def _create_columns( column_shape = ( (len(docs), *tensor.shape) if tensor is not None else (len(docs),) ) - columns[field] = type_.get_comp_backend().empty(column_shape) + columns[field] = type_.__docarray_from_native__( + type_.get_comp_backend().empty(column_shape) + ) for i, doc in enumerate(docs): val = getattr(doc, field) From d54842b207fb55960dab299b9e5360f35f4dbd8b Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Mon, 16 Jan 2023 11:22:11 +0100 Subject: [PATCH 3/6] fix: fix mypy Signed-off-by: Sami Jaghouar --- docarray/array/array_stacked.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/docarray/array/array_stacked.py b/docarray/array/array_stacked.py index 0eee1f65356..426b1b52dd5 100644 --- a/docarray/array/array_stacked.py +++ b/docarray/array/array_stacked.py @@ -1,5 +1,16 @@ from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Type, TypeVar, Union, cast +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Iterable, + List, + Mapping, + Type, + TypeVar, + Union, + cast, +) from docarray.array.abstract_array import AnyDocumentArray from docarray.array.array import DocumentArray @@ -49,7 +60,7 @@ def __init__( self: T, docs: DocumentArray, ): - self._columns: Dict[str, Union[T, AbstractTensor]] = {} + self._columns: Dict[str, Union['DocumentArrayStacked', AbstractTensor]] = {} self.from_document_array(docs) @@ -62,7 +73,7 @@ def from_document_array(self: T, docs: DocumentArray): def _from_columns( cls: Type[T], docs: DocumentArray, - columns: Dict[str, Union[T, AbstractTensor]], + columns: Mapping[str, Union['DocumentArrayStacked', AbstractTensor]], ) -> T: # below __class_getitem__ is called explicitly instead # of doing DocumentArrayStacked[docs.document_type] @@ -99,7 +110,7 @@ def to(self: T, device: str): def _get_columns_schema( cls: Type[T], tensor_type: Type[AbstractTensor], - ) -> Dict[str, Union[Type[AbstractTensor], Type[BaseDocument]]]: + ) -> Mapping[str, Union[Type[AbstractTensor], Type[BaseDocument]]]: """ Return the list of fields that are tensors and the list of fields that are documents @@ -123,7 +134,7 @@ def _get_columns_schema( @classmethod def _create_columns( cls: Type[T], docs: DocumentArray, tensor_type: Type[AbstractTensor] - ) -> Dict[str, Union[T, AbstractTensor]]: + ) -> Dict[str, Union['DocumentArrayStacked', AbstractTensor]]: if len(docs) == 0: return {} @@ -159,7 +170,7 @@ def _create_columns( def _get_array_attribute( self: T, field: str, - ) -> Union[List, T, AbstractTensor]: + ) -> Union[List, 'DocumentArrayStacked', AbstractTensor]: """Return all values of the fields from all docs this array contains :param field: name of the fields to extract From 86f9ea3ce7881cd303725120c417908bcdc3dee4 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Mon, 16 Jan 2023 11:39:16 +0100 Subject: [PATCH 4/6] fix: fix mypy Signed-off-by: Sami Jaghouar --- docarray/array/array_stacked.py | 2 +- docarray/typing/tensor/abstract_tensor.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/docarray/array/array_stacked.py b/docarray/array/array_stacked.py index 426b1b52dd5..fc9bd940bad 100644 --- a/docarray/array/array_stacked.py +++ b/docarray/array/array_stacked.py @@ -158,7 +158,7 @@ def _create_columns( if val is None: val = tensor_type.get_comp_backend().none_value() - columns[field][i] = val + cast(AbstractTensor, columns[field])[i] = val setattr(doc, field, columns[field][i]) del val diff --git a/docarray/typing/tensor/abstract_tensor.py b/docarray/typing/tensor/abstract_tensor.py index a877861aa4e..44385a850de 100644 --- a/docarray/typing/tensor/abstract_tensor.py +++ b/docarray/typing/tensor/abstract_tensor.py @@ -124,6 +124,10 @@ def __getitem__(self, item): """Get a slice of this tensor.""" ... + def __setitem__(self, index, value): + """Set a slice of this tensor.""" + ... + def __iter__(self): """Iterate over the elements of this tensor.""" ... From 9d4c0b2f6d13c38dfb68012e149c926a2a155bd3 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Mon, 16 Jan 2023 11:43:54 +0100 Subject: [PATCH 5/6] fix: fix import Signed-off-by: Sami Jaghouar --- docarray/array/array_stacked.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/docarray/array/array_stacked.py b/docarray/array/array_stacked.py index fc9bd940bad..75235d840ac 100644 --- a/docarray/array/array_stacked.py +++ b/docarray/array/array_stacked.py @@ -27,10 +27,11 @@ from docarray.typing import TorchTensor from docarray.typing.tensor.abstract_tensor import AbstractTensor -else: - from docarray.typing import TorchTensor - torch_imported = True +try: + from docarray.typing import TorchTensor +except ImportError: + TorchTensor = None T = TypeVar('T', bound='DocumentArrayStacked') @@ -321,7 +322,9 @@ def traverse_flat( nodes = list(AnyDocumentArray._traverse(node=self, access_path=access_path)) flattened = AnyDocumentArray._flatten_one_level(nodes) - if len(flattened) == 1 and isinstance(flattened[0], (NdArray, TorchTensor)): + cls_to_check = (NdArray, TorchTensor) if TorchTensor else (NdArray,) + + if len(flattened) == 1 and isinstance(flattened[0], cls_to_check): return flattened[0] else: return flattened From baede3e330826a9382e657d056edde4c28c92b66 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Mon, 16 Jan 2023 11:56:56 +0100 Subject: [PATCH 6/6] fix: fix import Signed-off-by: Sami Jaghouar --- docarray/array/array_stacked.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docarray/array/array_stacked.py b/docarray/array/array_stacked.py index 75235d840ac..9758fc6e271 100644 --- a/docarray/array/array_stacked.py +++ b/docarray/array/array_stacked.py @@ -31,7 +31,7 @@ try: from docarray.typing import TorchTensor except ImportError: - TorchTensor = None + TorchTensor = None # type: ignore T = TypeVar('T', bound='DocumentArrayStacked') @@ -322,7 +322,7 @@ def traverse_flat( nodes = list(AnyDocumentArray._traverse(node=self, access_path=access_path)) flattened = AnyDocumentArray._flatten_one_level(nodes) - cls_to_check = (NdArray, TorchTensor) if TorchTensor else (NdArray,) + cls_to_check = (NdArray, TorchTensor) if TorchTensor is not None else (NdArray,) if len(flattened) == 1 and isinstance(flattened[0], cls_to_check): return flattened[0]