diff --git a/docarray/array/array_stacked.py b/docarray/array/array_stacked.py index 1d947e70b1b..5a443a427e5 100644 --- a/docarray/array/array_stacked.py +++ b/docarray/array/array_stacked.py @@ -95,14 +95,9 @@ def to(self: T, device: str) -> T: for field in self._columns.keys(): col = self._columns[field] if isinstance(col, AbstractTensor): - # the casting below is arbitrary, in reality `col` could be of any - # subclass of AbstractTensor. But to make mypy happy we have to cast - # it to a concrete subclass thereof - # see mypy issue: https://github.com/python/mypy/issues/14421 - col_ = cast('TorchTensor', col) - self._columns[field] = col_.get_comp_backend().to_device(col_, device) - elif isinstance(col, NdArray): - self._columns[field] = col.get_comp_backend().to_device(col, device) + self._columns[field] = col.__class__._docarray_from_native( + col.get_comp_backend().to_device(col, device) + ) else: # recursive call col_docarray = cast(T, col) col_docarray.to(device) diff --git a/docarray/computation/abstract_comp_backend.py b/docarray/computation/abstract_comp_backend.py index f47c6883cde..7ea6a73e0c1 100644 --- a/docarray/computation/abstract_comp_backend.py +++ b/docarray/computation/abstract_comp_backend.py @@ -1,18 +1,17 @@ import typing from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, List, Optional, Tuple, TypeVar, Union, overload +from typing import TYPE_CHECKING, Any, List, Optional, Tuple, TypeVar, Union if TYPE_CHECKING: import numpy as np # In practice all of the below will be the same type TTensor = TypeVar('TTensor') -TAbstractTensor = TypeVar('TAbstractTensor') TTensorRetrieval = TypeVar('TTensorRetrieval') TTensorMetrics = TypeVar('TTensorMetrics') -class AbstractComputationalBackend(ABC, typing.Generic[TTensor, TAbstractTensor]): +class AbstractComputationalBackend(ABC, typing.Generic[TTensor]): """ Abstract base class for computational backends. Every supported tensor/ML framework (numpy, torch etc.) should define its own @@ -73,37 +72,9 @@ def shape(tensor: 'TTensor') -> Tuple[int, ...]: """Get shape of tensor""" ... - @overload - @staticmethod - def reshape(tensor: 'TAbstractTensor', shape: Tuple[int, ...]) -> 'TAbstractTensor': - """ - Gives a new shape to tensor without changing its data. - - :param tensor: tensor to be reshaped - :param shape: the new shape - :return: a tensor with the same data and number of elements as tensor - but with the specified shape. - """ - ... - - @overload - @staticmethod - def reshape(tensor: 'TTensor', shape: Tuple[int, ...]) -> 'TTensor': - """ - Gives a new shape to tensor without changing its data. - - :param tensor: tensor to be reshaped - :param shape: the new shape - :return: a tensor with the same data and number of elements as tensor - but with the specified shape. - """ - ... - @staticmethod @abstractmethod - def reshape( - tensor: Union['TTensor', 'TAbstractTensor'], shape: Tuple[int, ...] - ) -> Union['TTensor', 'TAbstractTensor']: + def reshape(tensor: 'TTensor', shape: Tuple[int, ...]) -> 'TTensor': """ Gives a new shape to tensor without changing its data. diff --git a/docarray/computation/numpy_backend.py b/docarray/computation/numpy_backend.py index 63de0736f35..b84bda79361 100644 --- a/docarray/computation/numpy_backend.py +++ b/docarray/computation/numpy_backend.py @@ -1,10 +1,9 @@ import warnings -from typing import Any, List, Optional, Tuple, Union, overload +from typing import Any, List, Optional, Tuple, Union import numpy as np from docarray.computation import AbstractComputationalBackend -from docarray.typing import NdArray def _expand_if_single_axis(*matrices: np.ndarray) -> List[np.ndarray]: @@ -30,7 +29,7 @@ def _expand_if_scalar(arr: np.ndarray) -> np.ndarray: return arr -class NumpyCompBackend(AbstractComputationalBackend[np.ndarray, NdArray]): +class NumpyCompBackend(AbstractComputationalBackend[np.ndarray]): """ Computational backend for Numpy. """ @@ -41,22 +40,8 @@ def stack( ) -> 'np.ndarray': return np.stack(tensors, axis=dim) - @overload - @staticmethod - def to_device(tensor: 'NdArray', device: str) -> 'NdArray': - """Move the tensor to the specified device.""" - ... - - @overload @staticmethod def to_device(tensor: 'np.ndarray', device: str) -> 'np.ndarray': - """Move the tensor to the specified device.""" - ... - - @staticmethod - def to_device( - tensor: Union['np.ndarray', 'NdArray'], device: str - ) -> Union['np.ndarray', 'NdArray']: """Move the tensor to the specified device.""" raise NotImplementedError('Numpy does not support devices (GPU).') @@ -88,39 +73,11 @@ def shape(array: 'np.ndarray') -> Tuple[int, ...]: """Get shape of array""" return array.shape - @overload - @staticmethod - def reshape(array: 'NdArray', shape: Tuple[int, ...]) -> 'NdArray': - """ - Gives a new shape to array without changing its data. - - :param array: array to be reshaped - :param shape: the new shape - :return: a array with the same data and number of elements as array - but with the specified shape. - """ - ... - - @overload @staticmethod def reshape(array: 'np.ndarray', shape: Tuple[int, ...]) -> 'np.ndarray': """ Gives a new shape to array without changing its data. - :param array: array to be reshaped - :param shape: the new shape - :return: a array with the same data and number of elements as array - but with the specified shape. - """ - ... - - @staticmethod - def reshape( - array: Union['np.ndarray', 'NdArray'], shape: Tuple[int, ...] - ) -> Union['np.ndarray', 'NdArray']: - """ - Gives a new shape to array without changing its data. - :param array: array to be reshaped :param shape: the new shape :return: a array with the same data and number of elements as array diff --git a/docarray/computation/torch_backend.py b/docarray/computation/torch_backend.py index 7b144793eac..93309029898 100644 --- a/docarray/computation/torch_backend.py +++ b/docarray/computation/torch_backend.py @@ -1,13 +1,10 @@ -from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union, overload +from typing import Any, List, Optional, Tuple, Union import numpy as np import torch from docarray.computation.abstract_comp_backend import AbstractComputationalBackend -if TYPE_CHECKING: - from docarray.typing import TorchTensor - def _unsqueeze_if_single_axis(*matrices: torch.Tensor) -> List[torch.Tensor]: """Unsqueezes tensors that only have one axis, at dim 0. @@ -32,7 +29,7 @@ def _usqueeze_if_scalar(t: torch.Tensor): return t -class TorchCompBackend(AbstractComputationalBackend[torch.Tensor, 'TorchTensor']): +class TorchCompBackend(AbstractComputationalBackend[torch.Tensor]): """ Computational backend for PyTorch. """ @@ -43,22 +40,9 @@ def stack( ) -> 'torch.Tensor': return torch.stack(tensors, dim=dim) - @overload - @staticmethod - def to_device(tensor: 'TorchTensor', device: str) -> 'TorchTensor': - """Move the tensor to the specified device.""" - ... - - @overload @staticmethod def to_device(tensor: 'torch.Tensor', device: str) -> 'torch.Tensor': """Move the tensor to the specified device.""" - ... - - @staticmethod - def to_device( - tensor: Union['torch.Tensor', 'TorchTensor'], device: str - ) -> Union['torch.Tensor', 'TorchTensor']: return tensor.to(device) @staticmethod @@ -92,36 +76,9 @@ def none_value() -> Any: def shape(tensor: 'torch.Tensor') -> Tuple[int, ...]: return tuple(tensor.shape) - @overload - @staticmethod - def reshape(tensor: 'TorchTensor', shape: Tuple[int, ...]) -> 'TorchTensor': - """ - Gives a new shape to tensor without changing its data. - - :param tensor: tensor to be reshaped - :param shape: the new shape - :return: a tensor with the same data and number of elements as tensor - but with the specified shape. - """ - ... - - @overload @staticmethod def reshape(tensor: 'torch.Tensor', shape: Tuple[int, ...]) -> 'torch.Tensor': - """ - Gives a new shape to tensor without changing its data. - - :param tensor: tensor to be reshaped - :param shape: the new shape - :return: a tensor with the same data and number of elements as tensor - but with the specified shape. - """ - ... - @staticmethod - def reshape( - tensor: Union['torch.Tensor', 'TorchTensor'], shape: Tuple[int, ...] - ) -> Union['torch.Tensor', 'TorchTensor']: """ Gives a new shape to tensor without changing its data. diff --git a/docarray/typing/tensor/abstract_tensor.py b/docarray/typing/tensor/abstract_tensor.py index ba8c185fb98..d1342bd1b1c 100644 --- a/docarray/typing/tensor/abstract_tensor.py +++ b/docarray/typing/tensor/abstract_tensor.py @@ -88,7 +88,7 @@ def __docarray_validate_shape__(cls, t: T, shape: Tuple[Union[int, str]]) -> T: :param shape: The shape to validate against. :return: The validated tensor. """ - comp_be = t.get_comp_backend()() # mypy Generics require instantiation + comp_be = t.get_comp_backend() tshape = comp_be.shape(t) if tshape == shape: return t @@ -202,7 +202,7 @@ def _docarray_from_native(cls: Type[T], value: Any) -> T: @staticmethod @abc.abstractmethod - def get_comp_backend() -> Type[AbstractComputationalBackend[TTensor, T]]: + def get_comp_backend() -> AbstractComputationalBackend: """The computational backend compatible with this tensor type.""" ... diff --git a/docarray/typing/tensor/ndarray.py b/docarray/typing/tensor/ndarray.py index 3e96c50a706..8286e7fba78 100644 --- a/docarray/typing/tensor/ndarray.py +++ b/docarray/typing/tensor/ndarray.py @@ -202,11 +202,11 @@ def to_protobuf(self) -> 'NdArrayProto': return nd_proto @staticmethod - def get_comp_backend() -> Type['NumpyCompBackend']: + def get_comp_backend() -> 'NumpyCompBackend': """Return the computational backend of the tensor""" from docarray.computation.numpy_backend import NumpyCompBackend - return NumpyCompBackend + return NumpyCompBackend() def __class_getitem__(cls, item: Any, *args, **kwargs): # see here for mypy bug: https://github.com/python/mypy/issues/14123 diff --git a/docarray/typing/tensor/torch_tensor.py b/docarray/typing/tensor/torch_tensor.py index 20fc5a4c8ca..808fb7a23bd 100644 --- a/docarray/typing/tensor/torch_tensor.py +++ b/docarray/typing/tensor/torch_tensor.py @@ -216,11 +216,11 @@ def to_protobuf(self) -> 'NdArrayProto': return nd_proto @staticmethod - def get_comp_backend() -> Type['TorchCompBackend']: + def get_comp_backend() -> 'TorchCompBackend': """Return the computational backend of the tensor""" from docarray.computation.torch_backend import TorchCompBackend - return TorchCompBackend + return TorchCompBackend() @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None):