diff --git a/README.md b/README.md index db1e977b63c..1749be81838 100644 --- a/README.md +++ b/README.md @@ -211,6 +211,14 @@ class MyDoc(BaseDocument): doc = MyDoc(tensor=torch.zeros(3, 224, 224)) # works doc = MyDoc(tensor=torch.zeros(224, 224, 3)) # works by reshaping doc = MyDoc(tensor=torch.zeros(224)) # fails validation + +class Image(BaseDocument): + tensor: TorchTensor[3, 'x', 'x'] + +Image(tensor = torch.zeros(3, 224, 224)) # works +Image(tensor = torch.zeros(3, 64, 128)) # fails validation because second dimension does not match third +Image(tensor = torch.zeros(4, 224 ,224 )) # fails validation because of the first dimension +Image(tensor = torch.zeros(3 ,64)) # fails validation because it does not have enough dimensions ``` ## Coming from a vector database diff --git a/docarray/computation/abstract_comp_backend.py b/docarray/computation/abstract_comp_backend.py index b2db595cc7f..ca37e58459b 100644 --- a/docarray/computation/abstract_comp_backend.py +++ b/docarray/computation/abstract_comp_backend.py @@ -1,14 +1,15 @@ import typing from abc import ABC, abstractmethod -from typing import List, Optional, Tuple, TypeVar, Union +from typing import List, Optional, Tuple, TypeVar, Union, overload # In practice all of the below will be the same type TTensor = TypeVar('TTensor') +TAbstractTensor = TypeVar('TAbstractTensor') TTensorRetrieval = TypeVar('TTensorRetrieval') TTensorMetrics = TypeVar('TTensorMetrics') -class AbstractComputationalBackend(ABC, typing.Generic[TTensor]): +class AbstractComputationalBackend(ABC, typing.Generic[TTensor, TAbstractTensor]): """ Abstract base class for computational backends. Every supported tensor/ML framework (numpy, torch etc.) should define its own @@ -48,6 +49,53 @@ def to_device(tensor: 'TTensor', device: str) -> 'TTensor': """Move the tensor to the specified device.""" ... + @staticmethod + @abstractmethod + def shape(tensor: 'TTensor') -> Tuple[int, ...]: + """Get shape of tensor""" + ... + + @overload + @staticmethod + def reshape(tensor: 'TAbstractTensor', shape: Tuple[int, ...]) -> 'TAbstractTensor': + """ + Gives a new shape to tensor without changing its data. + + :param tensor: tensor to be reshaped + :param shape: the new shape + :return: a tensor with the same data and number of elements as tensor + but with the specified shape. + """ + ... + + @overload + @staticmethod + def reshape(tensor: 'TTensor', shape: Tuple[int, ...]) -> 'TTensor': + """ + Gives a new shape to tensor without changing its data. + + :param tensor: tensor to be reshaped + :param shape: the new shape + :return: a tensor with the same data and number of elements as tensor + but with the specified shape. + """ + ... + + @staticmethod + @abstractmethod + def reshape( + tensor: Union['TTensor', 'TAbstractTensor'], shape: Tuple[int, ...] + ) -> Union['TTensor', 'TAbstractTensor']: + """ + Gives a new shape to tensor without changing its data. + + :param tensor: tensor to be reshaped + :param shape: the new shape + :return: a tensor with the same data and number of elements as tensor + but with the specified shape. + """ + ... + class Retrieval(ABC, typing.Generic[TTensorRetrieval]): """ Abstract class for retrieval and ranking functionalities diff --git a/docarray/computation/numpy_backend.py b/docarray/computation/numpy_backend.py index 15549a01e6c..2bd0b900d92 100644 --- a/docarray/computation/numpy_backend.py +++ b/docarray/computation/numpy_backend.py @@ -30,7 +30,7 @@ def _expand_if_scalar(arr: np.ndarray) -> np.ndarray: return arr -class NumpyCompBackend(AbstractComputationalBackend[np.ndarray]): +class NumpyCompBackend(AbstractComputationalBackend[np.ndarray, NdArray]): """ Computational backend for Numpy. """ @@ -73,6 +73,51 @@ def none_value() -> Any: """Provide a compatible value that represents None in numpy.""" return None + @staticmethod + def shape(array: 'np.ndarray') -> Tuple[int, ...]: + """Get shape of array""" + return array.shape + + @overload + @staticmethod + def reshape(array: 'NdArray', shape: Tuple[int, ...]) -> 'NdArray': + """ + Gives a new shape to array without changing its data. + + :param array: array to be reshaped + :param shape: the new shape + :return: a array with the same data and number of elements as array + but with the specified shape. + """ + ... + + @overload + @staticmethod + def reshape(array: 'np.ndarray', shape: Tuple[int, ...]) -> 'np.ndarray': + """ + Gives a new shape to array without changing its data. + + :param array: array to be reshaped + :param shape: the new shape + :return: a array with the same data and number of elements as array + but with the specified shape. + """ + ... + + @staticmethod + def reshape( + array: Union['np.ndarray', 'NdArray'], shape: Tuple[int, ...] + ) -> Union['np.ndarray', 'NdArray']: + """ + Gives a new shape to array without changing its data. + + :param array: array to be reshaped + :param shape: the new shape + :return: a array with the same data and number of elements as array + but with the specified shape. + """ + return array.reshape(shape) + class Retrieval(AbstractComputationalBackend.Retrieval[np.ndarray]): """ Abstract class for retrieval and ranking functionalities diff --git a/docarray/computation/torch_backend.py b/docarray/computation/torch_backend.py index 1bfb2ae6157..adadbd64cc2 100644 --- a/docarray/computation/torch_backend.py +++ b/docarray/computation/torch_backend.py @@ -31,7 +31,7 @@ def _usqueeze_if_scalar(t: torch.Tensor): return t -class TorchCompBackend(AbstractComputationalBackend[torch.Tensor]): +class TorchCompBackend(AbstractComputationalBackend[torch.Tensor, 'TorchTensor']): """ Computational backend for PyTorch. """ @@ -73,6 +73,50 @@ def none_value() -> Any: """Provide a compatible value that represents None in torch.""" return torch.tensor(float('nan')) + @staticmethod + def shape(tensor: 'torch.Tensor') -> Tuple[int, ...]: + return tuple(tensor.shape) + + @overload + @staticmethod + def reshape(tensor: 'TorchTensor', shape: Tuple[int, ...]) -> 'TorchTensor': + """ + Gives a new shape to tensor without changing its data. + + :param tensor: tensor to be reshaped + :param shape: the new shape + :return: a tensor with the same data and number of elements as tensor + but with the specified shape. + """ + ... + + @overload + @staticmethod + def reshape(tensor: 'torch.Tensor', shape: Tuple[int, ...]) -> 'torch.Tensor': + """ + Gives a new shape to tensor without changing its data. + + :param tensor: tensor to be reshaped + :param shape: the new shape + :return: a tensor with the same data and number of elements as tensor + but with the specified shape. + """ + ... + + @staticmethod + def reshape( + tensor: Union['torch.Tensor', 'TorchTensor'], shape: Tuple[int, ...] + ) -> Union['torch.Tensor', 'TorchTensor']: + """ + Gives a new shape to tensor without changing its data. + + :param tensor: tensor to be reshaped + :param shape: the new shape + :return: a tensor with the same data and number of elements as tensor + but with the specified shape. + """ + return tensor.reshape(shape) + class Retrieval(AbstractComputationalBackend.Retrieval[torch.Tensor]): """ Abstract class for retrieval and ranking functionalities diff --git a/docarray/typing/tensor/abstract_tensor.py b/docarray/typing/tensor/abstract_tensor.py index 609b9cff2d2..ba8c185fb98 100644 --- a/docarray/typing/tensor/abstract_tensor.py +++ b/docarray/typing/tensor/abstract_tensor.py @@ -1,6 +1,18 @@ import abc +import warnings from abc import ABC -from typing import TYPE_CHECKING, Any, Generic, List, Tuple, Type, TypeVar, Union +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Generic, + List, + Tuple, + Type, + TypeVar, + Union, + cast, +) from docarray.computation import AbstractComputationalBackend from docarray.typing.abstract_type import AbstractType @@ -12,6 +24,7 @@ from docarray.proto import NdArrayProto T = TypeVar('T', bound='AbstractTensor') +TTensor = TypeVar('TTensor') ShapeT = TypeVar('ShapeT') @@ -54,32 +67,64 @@ def __instancecheck__(cls, instance): return super().__instancecheck__(instance) -class AbstractTensor(Generic[ShapeT], AbstractType, ABC): +class AbstractTensor(Generic[TTensor, T], AbstractType, ABC): __parametrized_meta__: type = _ParametrizedMeta _PROTO_FIELD_NAME: str @classmethod - @abc.abstractmethod - def __docarray_validate_shape__(cls, t: T, shape: Tuple[int]) -> T: + def __docarray_validate_shape__(cls, t: T, shape: Tuple[Union[int, str]]) -> T: """Every tensor has to implement this method in order to enable syntax of the form AnyTensor[shape]. - It is called when a tensor is assigned to a field of this type. i.e. when a tensor is passed to a Document field of type AnyTensor[shape]. - The intended behaviour is as follows: - If the shape of `t` is equal to `shape`, return `t`. - If the shape of `t` is not equal to `shape`, but can be reshaped to `shape`, return `t` reshaped to `shape`. - If the shape of `t` is not equal to `shape` and cannot be reshaped to `shape`, raise a ValueError. - :param t: The tensor to validate. :param shape: The shape to validate against. :return: The validated tensor. """ - ... + comp_be = t.get_comp_backend()() # mypy Generics require instantiation + tshape = comp_be.shape(t) + if tshape == shape: + return t + elif any(isinstance(dim, str) for dim in shape): + if len(tshape) != len(shape): + raise ValueError( + f'Tensor shape mismatch. Expected {shape}, got {tshape}' + ) + known_dims: Dict[str, int] = {} + for tdim, dim in zip(tshape, shape): + if isinstance(dim, int) and tdim != dim: + raise ValueError( + f'Tensor shape mismatch. Expected {shape}, got {tshape}' + ) + elif isinstance(dim, str): + if dim in known_dims and known_dims[dim] != tdim: + raise ValueError( + f'Tensor shape mismatch. Expected {shape}, got {tshape}' + ) + else: + known_dims[dim] = tdim + else: + return t + else: + shape = cast(Tuple[int], shape) + warnings.warn( + f'Tensor shape mismatch. Reshaping tensor ' + f'of shape {tshape} to shape {shape}' + ) + try: + value = cls._docarray_from_native(comp_be.reshape(t, shape)) + return cast(T, value) + except RuntimeError: + raise ValueError( + f'Cannot reshape tensor of shape {tshape} to shape {shape}' + ) @classmethod def __docarray_validate_getitem__(cls, item: Any) -> Tuple[int]: @@ -157,7 +202,7 @@ def _docarray_from_native(cls: Type[T], value: Any) -> T: @staticmethod @abc.abstractmethod - def get_comp_backend() -> Type[AbstractComputationalBackend]: + def get_comp_backend() -> Type[AbstractComputationalBackend[TTensor, T]]: """The computational backend compatible with this tensor type.""" ... diff --git a/docarray/typing/tensor/ndarray.py b/docarray/typing/tensor/ndarray.py index 83ff0cefed5..3e96c50a706 100644 --- a/docarray/typing/tensor/ndarray.py +++ b/docarray/typing/tensor/ndarray.py @@ -1,4 +1,3 @@ -import warnings from typing import ( TYPE_CHECKING, Any, @@ -57,12 +56,14 @@ class NdArray(np.ndarray, AbstractTensor, Generic[ShapeT]): class MyDoc(BaseDocument): arr: NdArray image_arr: NdArray[3, 224, 224] + square_crop: NdArray[3, 'x', 'x'] # create a document with tensors doc = MyDoc( arr=np.zeros((128,)), image_arr=np.zeros((3, 224, 224)), + square_crop=np.zeros((3, 64, 64)), ) assert doc.image_arr.shape == (3, 224, 224) @@ -70,6 +71,7 @@ class MyDoc(BaseDocument): doc = MyDoc( arr=np.zeros((128,)), image_arr=np.zeros((224, 224, 3)), # will reshape to (3, 224, 224) + square_crop=np.zeros((3, 128, 128)), ) assert doc.image_arr.shape == (3, 224, 224) @@ -77,6 +79,7 @@ class MyDoc(BaseDocument): doc = MyDoc( arr=np.zeros((128,)), image_arr=np.zeros((224, 224)), # this will fail validation + square_crop=np.zeros((3, 128, 64)), # this will also fail validation ) """ @@ -90,23 +93,6 @@ def __get_validators__(cls): # the value returned from the previous validator yield cls.validate - @classmethod - def __docarray_validate_shape__(cls, t: T, shape: Tuple[int]) -> T: # type: ignore - if t.shape == shape: - return t - else: - warnings.warn( - f'Tensor shape mismatch. Reshaping array ' - f'of shape {t.shape} to shape {shape}' - ) - try: - value = cls._docarray_from_native(np.reshape(t, shape)) - return cast(T, value) - except RuntimeError: - raise ValueError( - f'Cannot reshape array of shape {t.shape} to shape {shape}' - ) - @classmethod def validate( cls: Type[T], diff --git a/docarray/typing/tensor/torch_tensor.py b/docarray/typing/tensor/torch_tensor.py index 946e7dfd5a2..d24a6fee84b 100644 --- a/docarray/typing/tensor/torch_tensor.py +++ b/docarray/typing/tensor/torch_tensor.py @@ -1,6 +1,5 @@ -import warnings from copy import copy -from typing import TYPE_CHECKING, Any, Dict, Generic, Tuple, Type, TypeVar, Union, cast +from typing import TYPE_CHECKING, Any, Dict, Generic, Type, TypeVar, Union, cast import numpy as np import torch # type: ignore @@ -58,26 +57,29 @@ class TorchTensor( class MyDoc(BaseDocument): tensor: TorchTensor image_tensor: TorchTensor[3, 224, 224] + square_crop: TorchTensor[3, 'x', 'x'] # create a document with tensors doc = MyDoc( tensor=torch.zeros(128), image_tensor=torch.zeros(3, 224, 224), + square_crop=torch.zeros(3, 64, 64), ) # automatic shape conversion doc = MyDoc( tensor=torch.zeros(128), image_tensor=torch.zeros(224, 224, 3), # will reshape to (3, 224, 224) + square_crop=torch.zeros(3, 128, 128), ) # !! The following will raise an error due to shape mismatch !! doc = MyDoc( tensor=torch.zeros(128), image_tensor=torch.zeros(224, 224), # this will fail validation + square_crop=torch.zeros(3, 128, 64), # this will also fail validation ) - """ __parametrized_meta__ = metaTorchAndNode @@ -90,23 +92,6 @@ def __get_validators__(cls): # the value returned from the previous validator yield cls.validate - @classmethod - def __docarray_validate_shape__(cls, t: T, shape: Tuple[int]) -> T: # type: ignore - if t.shape == shape: - return t - else: - warnings.warn( - f'Tensor shape mismatch. Reshaping tensor ' - f'of shape {t.shape} to shape {shape}' - ) - try: - value = cls._docarray_from_native(t.view(shape)) - return cast(T, value) - except RuntimeError: - raise ValueError( - f'Cannot reshape tensor of shape {t.shape} to shape {shape}' - ) - @classmethod def validate( cls: Type[T], diff --git a/tests/units/typing/tensor/test_tensor.py b/tests/units/typing/tensor/test_tensor.py index 76050d1b643..e8f1bab1b95 100644 --- a/tests/units/typing/tensor/test_tensor.py +++ b/tests/units/typing/tensor/test_tensor.py @@ -74,6 +74,35 @@ def test_parametrized(): with pytest.raises(ValueError): parse_obj_as(NdArray[3, 224, 224], np.zeros((224, 224))) + # test independent variable dimensions + tensor = parse_obj_as(NdArray[3, 'x', 'y'], np.zeros((3, 224, 224))) + assert isinstance(tensor, NdArray) + assert isinstance(tensor, np.ndarray) + assert tensor.shape == (3, 224, 224) + + tensor = parse_obj_as(NdArray[3, 'x', 'y'], np.zeros((3, 60, 128))) + assert isinstance(tensor, NdArray) + assert isinstance(tensor, np.ndarray) + assert tensor.shape == (3, 60, 128) + + with pytest.raises(ValueError): + parse_obj_as(NdArray[3, 'x', 'y'], np.zeros((4, 224, 224))) + + with pytest.raises(ValueError): + parse_obj_as(NdArray[3, 'x', 'y'], np.zeros((100, 1))) + + # test dependent variable dimensions + tensor = parse_obj_as(NdArray[3, 'x', 'x'], np.zeros((3, 224, 224))) + assert isinstance(tensor, NdArray) + assert isinstance(tensor, np.ndarray) + assert tensor.shape == (3, 224, 224) + + with pytest.raises(ValueError): + tensor = parse_obj_as(NdArray[3, 'x', 'x'], np.zeros((3, 60, 128))) + + with pytest.raises(ValueError): + tensor = parse_obj_as(NdArray[3, 'x', 'x'], np.zeros((3, 60))) + def test_np_embedding(): # correct shape diff --git a/tests/units/typing/tensor/test_torch_tensor.py b/tests/units/typing/tensor/test_torch_tensor.py index b859cbd28bf..3ea1bc720b4 100644 --- a/tests/units/typing/tensor/test_torch_tensor.py +++ b/tests/units/typing/tensor/test_torch_tensor.py @@ -58,6 +58,35 @@ def test_parametrized(): with pytest.raises(ValueError): parse_obj_as(TorchTensor[3, 224, 224], torch.zeros(224, 224)) + # test independent variable dimensions + tensor = parse_obj_as(TorchTensor[3, 'x', 'y'], torch.zeros(3, 224, 224)) + assert isinstance(tensor, TorchTensor) + assert isinstance(tensor, torch.Tensor) + assert tensor.shape == (3, 224, 224) + + tensor = parse_obj_as(TorchTensor[3, 'x', 'y'], torch.zeros(3, 60, 128)) + assert isinstance(tensor, TorchTensor) + assert isinstance(tensor, torch.Tensor) + assert tensor.shape == (3, 60, 128) + + with pytest.raises(ValueError): + parse_obj_as(TorchTensor[3, 'x', 'y'], torch.zeros(4, 224, 224)) + + with pytest.raises(ValueError): + parse_obj_as(TorchTensor[3, 'x', 'y'], torch.zeros(100, 1)) + + # test dependent variable dimensions + tensor = parse_obj_as(TorchTensor[3, 'x', 'x'], torch.zeros(3, 224, 224)) + assert isinstance(tensor, TorchTensor) + assert isinstance(tensor, torch.Tensor) + assert tensor.shape == (3, 224, 224) + + with pytest.raises(ValueError): + tensor = parse_obj_as(TorchTensor[3, 'x', 'x'], torch.zeros(3, 60, 128)) + + with pytest.raises(ValueError): + tensor = parse_obj_as(TorchTensor[3, 'x', 'x'], torch.zeros(3, 60)) + @pytest.mark.parametrize('shape', [(3, 224, 224), (224, 224, 3)]) def test_parameterized_tensor_class_name(shape):