diff --git a/docarray/array/array.py b/docarray/array/array.py index 4a6b632c5a5..3bd485ad7d3 100644 --- a/docarray/array/array.py +++ b/docarray/array/array.py @@ -123,7 +123,7 @@ def _get_array_attribute( # most likely a bug in mypy though # bug reported here https://github.com/python/mypy/issues/14111 return self.__class__.__class_getitem__(field_type)( - (getattr(doc, field) for doc in self) + (getattr(doc, field) for doc in self), tensor_type=self.tensor_type ) else: return [getattr(doc, field) for doc in self] diff --git a/docarray/array/array_stacked.py b/docarray/array/array_stacked.py index 71d5b3547a8..9758fc6e271 100644 --- a/docarray/array/array_stacked.py +++ b/docarray/array/array_stacked.py @@ -1,12 +1,11 @@ -from collections import defaultdict from contextlib import contextmanager from typing import ( TYPE_CHECKING, Any, - DefaultDict, Dict, Iterable, List, + Mapping, Type, TypeVar, Union, @@ -28,14 +27,11 @@ from docarray.typing import TorchTensor from docarray.typing.tensor.abstract_tensor import AbstractTensor + try: - import torch -except ImportError: - torch_imported = False -else: from docarray.typing import TorchTensor - - torch_imported = True +except ImportError: + TorchTensor = None # type: ignore T = TypeVar('T', bound='DocumentArrayStacked') @@ -65,7 +61,7 @@ def __init__( self: T, docs: DocumentArray, ): - self._columns: Dict[str, Union[T, AbstractTensor]] = {} + self._columns: Dict[str, Union['DocumentArrayStacked', AbstractTensor]] = {} self.from_document_array(docs) @@ -78,7 +74,7 @@ def from_document_array(self: T, docs: DocumentArray): def _from_columns( cls: Type[T], docs: DocumentArray, - columns: Dict[str, Union[T, AbstractTensor]], + columns: Mapping[str, Union['DocumentArrayStacked', AbstractTensor]], ) -> T: # below __class_getitem__ is called explicitly instead # of doing DocumentArrayStacked[docs.document_type] @@ -112,71 +108,70 @@ def to(self: T, device: str): col_docarray.to(device) @classmethod - def _create_columns( - 
cls: Type[T], docs: DocumentArray, tensor_type: Type['AbstractTensor'] - ) -> Dict[str, Union[T, AbstractTensor]]: - columns_fields = list() + def _get_columns_schema( + cls: Type[T], + tensor_type: Type[AbstractTensor], + ) -> Mapping[str, Union[Type[AbstractTensor], Type[BaseDocument]]]: + """ + Return a mapping from field name to column type, for every field of the + document schema that should be stacked into a column (tensor or document). + :param tensor_type: the default tensor type to fall back to for a union of tensor types + :return: a mapping from each stackable field name to the type of its column, + which is either a tensor type or a document type + """ + + column_schema: Dict[str, Union[Type[AbstractTensor], Type[BaseDocument]]] = {} + for field_name, field in cls.document_type.__fields__.items(): field_type = field.outer_type_ if is_tensor_union(field_type): - columns_fields.append(field_name) + column_schema[field_name] = tensor_type elif isinstance(field_type, type): - is_torch_subclass = ( - issubclass(field_type, torch.Tensor) if torch_imported else False - ) + if issubclass(field_type, (BaseDocument, AbstractTensor)): + column_schema[field_name] = field_type - if ( - is_torch_subclass - or issubclass(field_type, BaseDocument) - or issubclass(field_type, NdArray) - ): - columns_fields.append(field_name) + return column_schema - if not columns_fields: - # nothing to stack - return {} + @classmethod + def _create_columns( + cls: Type[T], docs: DocumentArray, tensor_type: Type[AbstractTensor] + ) -> Dict[str, Union['DocumentArrayStacked', AbstractTensor]]: - columns: Dict[str, Union[T, AbstractTensor]] = dict() + if len(docs) == 0: + return {} - columns_to_stack: DefaultDict[ - str, Union[List[AbstractTensor], List[BaseDocument]] - ] = defaultdict( # type: ignore - list # type: ignore - ) # type: ignore + column_schema = cls._get_columns_schema(tensor_type) - for doc in docs: - for field_to_stack in columns_fields: - val = getattr(doc, field_to_stack) - if val is None: - type_ = 
cls.document_type._get_field_type(field_to_stack) - if is_tensor_union(type_): - val = tensor_type.get_comp_backend().none_value() - columns_to_stack[field_to_stack].append(val) + columns: Dict[str, Union[DocumentArrayStacked, AbstractTensor]] = dict() - for field_to_stack, to_stack in columns_to_stack.items(): + for field, type_ in column_schema.items(): + if issubclass(type_, AbstractTensor): + tensor = getattr(docs[0], field) + column_shape = ( + (len(docs), *tensor.shape) if tensor is not None else (len(docs),) + ) + columns[field] = type_.__docarray_from_native__( + type_.get_comp_backend().empty(column_shape) + ) - type_ = cls.document_type._get_field_type(field_to_stack) - if is_tensor_union(type_): - columns[field_to_stack] = tensor_type.__docarray_stack__(to_stack) # type: ignore # noqa: E501 - elif isinstance(type_, type): - if issubclass(type_, BaseDocument): - columns[field_to_stack] = DocumentArray.__class_getitem__(type_)( - to_stack, tensor_type=tensor_type - ).stack() + for i, doc in enumerate(docs): + val = getattr(doc, field) + if val is None: + val = tensor_type.get_comp_backend().none_value() - elif issubclass(type_, AbstractTensor): - columns[field_to_stack] = type_.__docarray_stack__(to_stack) # type: ignore # noqa: E501 + cast(AbstractTensor, columns[field])[i] = val + setattr(doc, field, columns[field][i]) + del val - for field_name, column in columns.items(): - for doc, val in zip(docs, column): - setattr(doc, field_name, val) + elif issubclass(type_, BaseDocument): + columns[field] = getattr(docs, field).stack() return columns def _get_array_attribute( self: T, field: str, - ) -> Union[List, T, AbstractTensor]: + ) -> Union[List, 'DocumentArrayStacked', AbstractTensor]: """Return all values of the fields from all docs this array contains :param field: name of the fields to extract @@ -327,7 +322,9 @@ def traverse_flat( nodes = list(AnyDocumentArray._traverse(node=self, access_path=access_path)) flattened = 
AnyDocumentArray._flatten_one_level(nodes) - if len(flattened) == 1 and isinstance(flattened[0], (NdArray, TorchTensor)): + cls_to_check = (NdArray, TorchTensor) if TorchTensor is not None else (NdArray,) + + if len(flattened) == 1 and isinstance(flattened[0], cls_to_check): return flattened[0] else: return flattened diff --git a/docarray/computation/abstract_comp_backend.py b/docarray/computation/abstract_comp_backend.py index 691f437d6d9..b2db595cc7f 100644 --- a/docarray/computation/abstract_comp_backend.py +++ b/docarray/computation/abstract_comp_backend.py @@ -31,6 +31,11 @@ def stack( def n_dim(array: 'TTensor') -> int: ... + @staticmethod + @abstractmethod + def empty(shape: Tuple[int, ...]) -> 'TTensor': + ... + @staticmethod @abstractmethod def none_value() -> typing.Any: diff --git a/docarray/computation/numpy_backend.py b/docarray/computation/numpy_backend.py index f1eff8551a6..15549a01e6c 100644 --- a/docarray/computation/numpy_backend.py +++ b/docarray/computation/numpy_backend.py @@ -64,6 +64,10 @@ def to_device( def n_dim(array: 'np.ndarray') -> int: return array.ndim + @staticmethod + def empty(shape: Tuple[int, ...]) -> 'np.ndarray': + return np.empty(shape) + @staticmethod def none_value() -> Any: """Provide a compatible value that represents None in numpy.""" diff --git a/docarray/computation/torch_backend.py b/docarray/computation/torch_backend.py index 4bf02295f3d..1bfb2ae6157 100644 --- a/docarray/computation/torch_backend.py +++ b/docarray/computation/torch_backend.py @@ -60,6 +60,10 @@ def to_device( ) -> Union['torch.Tensor', 'TorchTensor']: return tensor.to(device) + @staticmethod + def empty(shape: Tuple[int, ...]) -> torch.Tensor: + return torch.empty(shape) + @staticmethod def n_dim(array: 'torch.Tensor') -> int: return array.ndim diff --git a/docarray/typing/tensor/abstract_tensor.py b/docarray/typing/tensor/abstract_tensor.py index a877861aa4e..44385a850de 100644 --- a/docarray/typing/tensor/abstract_tensor.py +++ 
b/docarray/typing/tensor/abstract_tensor.py @@ -124,6 +124,10 @@ def __getitem__(self, item): """Get a slice of this tensor.""" ... + def __setitem__(self, index, value): + """Set a slice of this tensor.""" + ... + def __iter__(self): """Iterate over the elements of this tensor.""" ... diff --git a/tests/units/array/test_array_stacked.py b/tests/units/array/test_array_stacked.py index 474bd7809a3..8fdfbe97abc 100644 --- a/tests/units/array/test_array_stacked.py +++ b/tests/units/array/test_array_stacked.py @@ -267,7 +267,7 @@ class Doc(BaseDocument): N = 10 da = DocumentArray[Doc]( - (Doc(text=f'hello{i}', tensor=np.zeros((3, 224, 224))) for i in range(N)) + [Doc(text=f'hello{i}', tensor=np.zeros((3, 224, 224))) for i in range(N)] ).stack() da_sliced = da[0:10:2] diff --git a/tests/units/computation_backends/numpy_backend/test_basics.py b/tests/units/computation_backends/numpy_backend/test_basics.py index 1873889f3a5..ea70539b3dc 100644 --- a/tests/units/computation_backends/numpy_backend/test_basics.py +++ b/tests/units/computation_backends/numpy_backend/test_basics.py @@ -7,3 +7,8 @@ def test_to_device(): with pytest.raises(NotImplementedError): NumpyCompBackend.to_device(np.random.rand(10, 3), 'meta') + + +def test_empty(): + array = NumpyCompBackend.empty((10, 3)) + assert array.shape == (10, 3) diff --git a/tests/units/computation_backends/torch_backend/test_basics.py b/tests/units/computation_backends/torch_backend/test_basics.py index 14f337df429..0005135f99b 100644 --- a/tests/units/computation_backends/torch_backend/test_basics.py +++ b/tests/units/computation_backends/torch_backend/test_basics.py @@ -8,3 +8,8 @@ def test_to_device(): assert t.device == torch.device('cpu') t = TorchCompBackend.to_device(t, 'meta') assert t.device == torch.device('meta') + + +def test_empty(): + tensor = TorchCompBackend.empty((10, 3)) + assert tensor.shape == (10, 3) diff --git a/tests/units/util/test_typing.py b/tests/units/util/test_typing.py index 
fb263c3aaed..bcbed3fd9b1 100644 --- a/tests/units/util/test_typing.py +++ b/tests/units/util/test_typing.py @@ -29,6 +29,7 @@ def test_is_type_tensor(type_, is_tensor): [ (int, False), (TorchTensor, False), + (NdArray, False), (Optional[TorchTensor], True), (Optional[NdArray], True), (Union[NdArray, TorchTensor], True),