From b5f98f45738b424ef07314593e52eded18dca658 Mon Sep 17 00:00:00 2001 From: Jackmin801 <56836461+Jackmin801@users.noreply.github.com> Date: Wed, 11 Jan 2023 18:27:20 +0800 Subject: [PATCH 01/16] test: add torch tensor tests Signed-off-by: Jackmin801 <56836461+Jackmin801@users.noreply.github.com> --- .../units/typing/tensor/test_torch_tensor.py | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/units/typing/tensor/test_torch_tensor.py b/tests/units/typing/tensor/test_torch_tensor.py index b859cbd28bf..a5a67af7889 100644 --- a/tests/units/typing/tensor/test_torch_tensor.py +++ b/tests/units/typing/tensor/test_torch_tensor.py @@ -58,6 +58,32 @@ def test_parametrized(): with pytest.raises(ValueError): parse_obj_as(TorchTensor[3, 224, 224], torch.zeros(224, 224)) + # test independent variable dimensions + tensor = parse_obj_as(TorchTensor[3, 'x', 'y'], torch.zeros(3, 224, 224)) + assert isinstance(tensor, TorchTensor) + assert isinstance(tensor, torch.Tensor) + assert tensor.shape == (3, 224, 224) + + tensor = parse_obj_as(TorchTensor[3, 'x', 'y'], torch.zeros(3, 60, 128)) + assert isinstance(tensor, TorchTensor) + assert isinstance(tensor, torch.Tensor) + assert tensor.shape == (3, 60, 128) + + with pytest.raises(ValueError): + parse_obj_as(TorchTensor[3, 'x', 'y'], torch.zeros(4, 224, 224)) + + with pytest.raises(ValueError): + parse_obj_as(TorchTensor[3, 'x', 'y'], torch.zeros(100, 1)) + + # test dependent variable dimensions + tensor = parse_obj_as(TorchTensor[3, 'x', 'x'], torch.zeros(3, 224, 224)) + assert isinstance(tensor, TorchTensor) + assert isinstance(tensor, torch.Tensor) + assert tensor.shape == (3, 224, 224) + + with pytest.raises(ValueError): + tensor = parse_obj_as(TorchTensor[3, 'x', 'x'], torch.zeros(3, 60, 128)) + @pytest.mark.parametrize('shape', [(3, 224, 224), (224, 224, 3)]) def test_parameterized_tensor_class_name(shape): From f7be0c002eec43a28e8ceb5283b866f726fbafac Mon Sep 17 00:00:00 2001 From: Jackmin801 <56836461+Jackmin801@users.noreply.github.com> Date: Wed, 11 Jan 2023 18:27:42 +0800 Subject: [PATCH 02/16] feat: implement for torch tensor Signed-off-by: Jackmin801 <56836461+Jackmin801@users.noreply.github.com> --- docarray/typing/tensor/torch_tensor.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/docarray/typing/tensor/torch_tensor.py b/docarray/typing/tensor/torch_tensor.py index 7feba31858c..eab98d011ac 100644 --- a/docarray/typing/tensor/torch_tensor.py +++ b/docarray/typing/tensor/torch_tensor.py @@ -94,6 +94,22 @@ def __get_validators__(cls): def __docarray_validate_shape__(cls, t: T, shape: Tuple[int]) -> T: # type: ignore if t.shape == shape: return t + elif any(isinstance(dim, str) for dim in shape): + known_dims: Dict[str, int] = {} + for tdim, dim in zip(t.shape, shape): + if isinstance(dim, int) and tdim != dim: + raise ValueError( + f"Tensor shape mismatch. Expected {shape}, got {t.shape}" + ) + elif isinstance(dim, str): + if dim in known_dims and known_dims[dim] != tdim: + raise ValueError( + f"Tensor shape mismatch. Expected {shape}, got {t.shape}" + ) + else: + known_dims[dim] = tdim + else: + return t else: warnings.warn( f'Tensor shape mismatch. Reshaping tensor ' From 429b6bfa3ad8e7f53168ac0ec5267b3ac4d8f4cc Mon Sep 17 00:00:00 2001 From: Jackmin801 <56836461+Jackmin801@users.noreply.github.com> Date: Wed, 11 Jan 2023 18:45:21 +0800 Subject: [PATCH 03/16] test: add numpy array tests Signed-off-by: Jackmin801 <56836461+Jackmin801@users.noreply.github.com> --- tests/units/typing/tensor/test_tensor.py | 26 ++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/units/typing/tensor/test_tensor.py b/tests/units/typing/tensor/test_tensor.py index 76050d1b643..f19f6ce337f 100644 --- a/tests/units/typing/tensor/test_tensor.py +++ b/tests/units/typing/tensor/test_tensor.py @@ -74,6 +74,32 @@ def test_parametrized(): with pytest.raises(ValueError): parse_obj_as(NdArray[3, 224, 224], np.zeros((224, 224))) + # test independent variable dimensions + tensor = parse_obj_as(NdArray[3, 'x', 'y'], np.zeros((3, 224, 224))) + assert isinstance(tensor, NdArray) + assert isinstance(tensor, np.ndarray) + assert tensor.shape == (3, 224, 224) + + tensor = parse_obj_as(NdArray[3, 'x', 'y'], np.zeros((3, 60, 128))) + assert isinstance(tensor, NdArray) + assert isinstance(tensor, np.ndarray) + assert tensor.shape == (3, 60, 128) + + with pytest.raises(ValueError): + parse_obj_as(NdArray[3, 'x', 'y'], np.zeros((4, 224, 224))) + + with pytest.raises(ValueError): + parse_obj_as(NdArray[3, 'x', 'y'], np.zeros((100, 1))) + + # test dependent variable dimensions + tensor = parse_obj_as(NdArray[3, 'x', 'x'], np.zeros((3, 224, 224))) + assert isinstance(tensor, NdArray) + assert isinstance(tensor, np.ndarray) + assert tensor.shape == (3, 224, 224) + + with pytest.raises(ValueError): + tensor = parse_obj_as(NdArray[3, 'x', 'x'], np.zeros((3, 60, 128))) + def test_np_embedding(): # correct shape From 4fdf3d509246a775c6f3ad3d0d6937cbb6379935 Mon Sep 17 00:00:00 2001 From: Jackmin801 <56836461+Jackmin801@users.noreply.github.com> Date: Wed, 11 Jan 2023 18:45:39 +0800 Subject: [PATCH 04/16] feat: implement for numpy array Signed-off-by: Jackmin801 <56836461+Jackmin801@users.noreply.github.com> --- docarray/typing/tensor/ndarray.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/docarray/typing/tensor/ndarray.py b/docarray/typing/tensor/ndarray.py index 991da6fad9b..f54cfaed3dc 100644 --- a/docarray/typing/tensor/ndarray.py +++ b/docarray/typing/tensor/ndarray.py @@ -94,9 +94,25 @@ def __get_validators__(cls): def __docarray_validate_shape__(cls, t: T, shape: Tuple[int]) -> T: # type: ignore if t.shape == shape: return t + elif any(isinstance(dim, str) for dim in shape): + known_dims: Dict[str, int] = {} + for tdim, dim in zip(t.shape, shape): + if isinstance(dim, int) and tdim != dim: + raise ValueError( + f"Array shape mismatch. Expected {shape}, got {t.shape}" + ) + elif isinstance(dim, str): + if dim in known_dims and known_dims[dim] != tdim: + raise ValueError( + f"Array shape mismatch. Expected {shape}, got {t.shape}" + ) + else: + known_dims[dim] = tdim + else: + return t else: warnings.warn( - f'Tensor shape mismatch. Reshaping array ' + f'Array shape mismatch. Reshaping array ' f'of shape {t.shape} to shape {shape}' ) try: From c70376ed71da12a45c60fc5803534f88a3ab803a Mon Sep 17 00:00:00 2001 From: Jackmin801 <56836461+Jackmin801@users.noreply.github.com> Date: Wed, 11 Jan 2023 19:32:36 +0800 Subject: [PATCH 05/16] test: add uncovered edge case Signed-off-by: Jackmin801 <56836461+Jackmin801@users.noreply.github.com> --- tests/units/typing/tensor/test_tensor.py | 3 +++ tests/units/typing/tensor/test_torch_tensor.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/tests/units/typing/tensor/test_tensor.py b/tests/units/typing/tensor/test_tensor.py index f19f6ce337f..e8f1bab1b95 100644 --- a/tests/units/typing/tensor/test_tensor.py +++ b/tests/units/typing/tensor/test_tensor.py @@ -100,6 +100,9 @@ def test_parametrized(): with pytest.raises(ValueError): tensor = parse_obj_as(NdArray[3, 'x', 'x'], np.zeros((3, 60, 128))) + with pytest.raises(ValueError): + tensor = parse_obj_as(NdArray[3, 'x', 'x'], np.zeros((3, 60))) + def test_np_embedding(): # correct shape diff --git a/tests/units/typing/tensor/test_torch_tensor.py b/tests/units/typing/tensor/test_torch_tensor.py index a5a67af7889..3ea1bc720b4 100644 --- a/tests/units/typing/tensor/test_torch_tensor.py +++ b/tests/units/typing/tensor/test_torch_tensor.py @@ -84,6 +84,9 @@ def test_parametrized(): with pytest.raises(ValueError): tensor = parse_obj_as(TorchTensor[3, 'x', 'x'], torch.zeros(3, 60, 128)) + with pytest.raises(ValueError): + tensor = parse_obj_as(TorchTensor[3, 'x', 'x'], torch.zeros(3, 60)) + @pytest.mark.parametrize('shape', [(3, 224, 224), (224, 224, 3)]) def test_parameterized_tensor_class_name(shape): From f3635e5d0f9661ea9219b33c126f357301f55764 Mon Sep 17 00:00:00 2001 From: Jackmin801 <56836461+Jackmin801@users.noreply.github.com> Date: Wed, 11 Jan 2023 19:33:01 +0800 Subject: [PATCH 06/16] fix: fix uncovered edge case Signed-off-by: Jackmin801 <56836461+Jackmin801@users.noreply.github.com> --- docarray/typing/tensor/ndarray.py | 4 ++++ docarray/typing/tensor/torch_tensor.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/docarray/typing/tensor/ndarray.py b/docarray/typing/tensor/ndarray.py index f54cfaed3dc..e75eb924a63 100644 --- a/docarray/typing/tensor/ndarray.py +++ b/docarray/typing/tensor/ndarray.py @@ -95,6 +95,10 @@ def __docarray_validate_shape__(cls, t: T, shape: Tuple[int]) -> T: # type: ign if t.shape == shape: return t elif any(isinstance(dim, str) for dim in shape): + if len(t.shape) != len(shape): + raise ValueError( + f"Tensor shape mismatch. Expected {shape}, got {t.shape}" + ) known_dims: Dict[str, int] = {} for tdim, dim in zip(t.shape, shape): if isinstance(dim, int) and tdim != dim: diff --git a/docarray/typing/tensor/torch_tensor.py b/docarray/typing/tensor/torch_tensor.py index eab98d011ac..566567bd074 100644 --- a/docarray/typing/tensor/torch_tensor.py +++ b/docarray/typing/tensor/torch_tensor.py @@ -95,6 +95,10 @@ def __docarray_validate_shape__(cls, t: T, shape: Tuple[int]) -> T: # type: ign if t.shape == shape: return t elif any(isinstance(dim, str) for dim in shape): + if len(t.shape) != len(shape): + raise ValueError( + f"Tensor shape mismatch. Expected {shape}, got {t.shape}" + ) known_dims: Dict[str, int] = {} for tdim, dim in zip(t.shape, shape): if isinstance(dim, int) and tdim != dim: From 26970e6dfaace52d557deaf1e50a53698e4de10b Mon Sep 17 00:00:00 2001 From: Jackmin801 <56836461+Jackmin801@users.noreply.github.com> Date: Wed, 11 Jan 2023 19:39:28 +0800 Subject: [PATCH 07/16] docs: showcase variable dimension usage Signed-off-by: Jackmin801 <56836461+Jackmin801@users.noreply.github.com> --- README.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/README.md b/README.md index db1e977b63c..1749be81838 100644 --- a/README.md +++ b/README.md @@ -211,6 +211,14 @@ class MyDoc(BaseDocument): doc = MyDoc(tensor=torch.zeros(3, 224, 224)) # works doc = MyDoc(tensor=torch.zeros(224, 224, 3)) # works by reshaping doc = MyDoc(tensor=torch.zeros(224)) # fails validation + +class Image(BaseDocument): + tensor: TorchTensor[3, 'x', 'x'] + +Image(tensor = torch.zeros(3, 224, 224)) # works +Image(tensor = torch.zeros(3, 64, 128)) # fails validation because second dimension does not match third +Image(tensor = torch.zeros(4, 224 ,224 )) # fails validation because of the first dimension +Image(tensor = torch.zeros(3 ,64)) # fails validation because it does not have enough dimensions ``` ## Coming from a vector database From 8b7dda3d68a31844ad45426d9a1cae7b2de661a2 Mon Sep 17 00:00:00 2001 From: Jackmin801 <56836461+Jackmin801@users.noreply.github.com> Date: Wed, 11 Jan 2023 21:01:46 +0800 Subject: [PATCH 08/16] fix: small typo Signed-off-by: Jackmin801 <56836461+Jackmin801@users.noreply.github.com> --- docarray/typing/tensor/ndarray.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docarray/typing/tensor/ndarray.py b/docarray/typing/tensor/ndarray.py index e75eb924a63..fdaf4ba6318 100644 --- a/docarray/typing/tensor/ndarray.py +++ b/docarray/typing/tensor/ndarray.py @@ -97,7 +97,7 @@ def __docarray_validate_shape__(cls, t: T, shape: Tuple[int]) -> T: # type: ign elif any(isinstance(dim, str) for dim in shape): if len(t.shape) != len(shape): raise ValueError( - f"Tensor shape mismatch. Expected {shape}, got {t.shape}" + f"Array shape mismatch. Expected {shape}, got {t.shape}" ) known_dims: Dict[str, int] = {} for tdim, dim in zip(t.shape, shape): From c745629b4dfd5d5d375cd55a8ae2297ed733e918 Mon Sep 17 00:00:00 2001 From: Jackmin801 <56836461+Jackmin801@users.noreply.github.com> Date: Wed, 11 Jan 2023 21:10:02 +0800 Subject: [PATCH 09/16] docs: add variable dim example to docstrings Signed-off-by: Jackmin801 <56836461+Jackmin801@users.noreply.github.com> --- docarray/typing/tensor/ndarray.py | 4 ++++ docarray/typing/tensor/torch_tensor.py | 5 ++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/docarray/typing/tensor/ndarray.py b/docarray/typing/tensor/ndarray.py index fdaf4ba6318..b5dfcc87b14 100644 --- a/docarray/typing/tensor/ndarray.py +++ b/docarray/typing/tensor/ndarray.py @@ -57,12 +57,14 @@ class NdArray(np.ndarray, AbstractTensor, Generic[ShapeT]): class MyDoc(BaseDocument): arr: NdArray image_arr: NdArray[3, 224, 224] + square_crop: NdArray[3, 'x', 'x'] # create a document with tensors doc = MyDoc( arr=np.zeros((128,)), image_arr=np.zeros((3, 224, 224)), + square_crop=np.zeros((3, 64, 64)), ) assert doc.image_arr.shape == (3, 224, 224) @@ -70,6 +72,7 @@ class MyDoc(BaseDocument): doc = MyDoc( arr=np.zeros((128,)), image_arr=np.zeros((224, 224, 3)), # will reshape to (3, 224, 224) + square_crop=np.zeros((3, 128, 128)), ) assert doc.image_arr.shape == (3, 224, 224) @@ -77,6 +80,7 @@ class MyDoc(BaseDocument): doc = MyDoc( arr=np.zeros((128,)), image_arr=np.zeros((224, 224)), # this will fail validation + square_crop=np.zeros((3, 128, 64)), # this will also fail validation ) """ diff --git a/docarray/typing/tensor/torch_tensor.py b/docarray/typing/tensor/torch_tensor.py index 566567bd074..96bb218ab95 100644 --- a/docarray/typing/tensor/torch_tensor.py +++ b/docarray/typing/tensor/torch_tensor.py @@ -58,26 +58,29 @@ class TorchTensor( class MyDoc(BaseDocument): tensor: TorchTensor image_tensor: TorchTensor[3, 224, 224] + square_crop: TorchTensor[3, 'x', 'x'] # create a document with tensors doc = MyDoc( tensor=torch.zeros(128), image_tensor=torch.zeros(3, 224, 224), + square_crop=torch.zeros(3, 64, 64), ) # automatic shape conversion doc = MyDoc( tensor=torch.zeros(128), image_tensor=torch.zeros(224, 224, 3), # will reshape to (3, 224, 224) + square_crop=torch.zeros(3, 128, 128), ) # !! The following will raise an error due to shape mismatch !! doc = MyDoc( tensor=torch.zeros(128), image_tensor=torch.zeros(224, 224), # this will fail validation + square_crop=torch.zeros(3, 128, 64), # this will also fail validation ) - """ __parametrized_meta__ = metaTorchAndNode From 6f89780fbfca490f65e7936c65cb2c3202660cee Mon Sep 17 00:00:00 2001 From: Jackmin801 <56836461+Jackmin801@users.noreply.github.com> Date: Wed, 11 Jan 2023 22:51:28 +0800 Subject: [PATCH 10/16] refactor: single quote string normalization Signed-off-by: Jackmin801 <56836461+Jackmin801@users.noreply.github.com> --- docarray/typing/tensor/ndarray.py | 6 +++--- docarray/typing/tensor/torch_tensor.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/docarray/typing/tensor/ndarray.py b/docarray/typing/tensor/ndarray.py index b5dfcc87b14..4a4f4e8e8bf 100644 --- a/docarray/typing/tensor/ndarray.py +++ b/docarray/typing/tensor/ndarray.py @@ -101,18 +101,18 @@ def __docarray_validate_shape__(cls, t: T, shape: Tuple[int]) -> T: # type: ign elif any(isinstance(dim, str) for dim in shape): if len(t.shape) != len(shape): raise ValueError( - f"Array shape mismatch. Expected {shape}, got {t.shape}" + f'Array shape mismatch. Expected {shape}, got {t.shape}' ) known_dims: Dict[str, int] = {} for tdim, dim in zip(t.shape, shape): if isinstance(dim, int) and tdim != dim: raise ValueError( - f"Array shape mismatch. Expected {shape}, got {t.shape}" + f'Array shape mismatch. Expected {shape}, got {t.shape}' ) elif isinstance(dim, str): if dim in known_dims and known_dims[dim] != tdim: raise ValueError( - f"Array shape mismatch. Expected {shape}, got {t.shape}" + f'Array shape mismatch. Expected {shape}, got {t.shape}' ) else: known_dims[dim] = tdim diff --git a/docarray/typing/tensor/torch_tensor.py b/docarray/typing/tensor/torch_tensor.py index 96bb218ab95..e2278de883a 100644 --- a/docarray/typing/tensor/torch_tensor.py +++ b/docarray/typing/tensor/torch_tensor.py @@ -100,18 +100,18 @@ def __docarray_validate_shape__(cls, t: T, shape: Tuple[int]) -> T: # type: ign elif any(isinstance(dim, str) for dim in shape): if len(t.shape) != len(shape): raise ValueError( - f"Tensor shape mismatch. Expected {shape}, got {t.shape}" + f'Tensor shape mismatch. Expected {shape}, got {t.shape}' ) known_dims: Dict[str, int] = {} for tdim, dim in zip(t.shape, shape): if isinstance(dim, int) and tdim != dim: raise ValueError( - f"Tensor shape mismatch. Expected {shape}, got {t.shape}" + f'Tensor shape mismatch. Expected {shape}, got {t.shape}' ) elif isinstance(dim, str): if dim in known_dims and known_dims[dim] != tdim: raise ValueError( - f"Tensor shape mismatch. Expected {shape}, got {t.shape}" + f'Tensor shape mismatch. Expected {shape}, got {t.shape}' ) else: known_dims[dim] = tdim From 587eede71ab12b27eb3697014c0fdbf9ee76648b Mon Sep 17 00:00:00 2001 From: Jackmin801 <56836461+Jackmin801@users.noreply.github.com> Date: Mon, 16 Jan 2023 17:15:03 +0800 Subject: [PATCH 11/16] refactor: consolidated validate shape in abstract tensor Signed-off-by: Jackmin801 <56836461+Jackmin801@users.noreply.github.com> --- docarray/typing/tensor/abstract_tensor.py | 71 ++++++++++++++++------- docarray/typing/tensor/ndarray.py | 38 ------------ docarray/typing/tensor/torch_tensor.py | 40 +------------ 3 files changed, 51 insertions(+), 98 deletions(-) diff --git a/docarray/typing/tensor/abstract_tensor.py b/docarray/typing/tensor/abstract_tensor.py index 1d57110b1b8..b1fcb6aa4cb 100644 --- a/docarray/typing/tensor/abstract_tensor.py +++ b/docarray/typing/tensor/abstract_tensor.py @@ -1,6 +1,18 @@ import abc +import warnings from abc import ABC -from typing import TYPE_CHECKING, Any, Generic, List, Tuple, Type, TypeVar, Union +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Generic, + List, + Tuple, + Type, + TypeVar, + Union, + cast, +) from docarray.computation import AbstractComputationalBackend from docarray.typing.abstract_type import AbstractType @@ -60,26 +72,43 @@ class AbstractTensor(Generic[ShapeT], AbstractType, ABC): _PROTO_FIELD_NAME: str @classmethod - @abc.abstractmethod - def __docarray_validate_shape__(cls, t: T, shape: Tuple[int]) -> T: - """Every tensor has to implement this method in order to - enable syntax of the form AnyTensor[shape]. - - It is called when a tensor is assigned to a field of this type. - i.e. when a tensor is passed to a Document field of type AnyTensor[shape]. - - The intended behaviour is as follows: - - If the shape of `t` is equal to `shape`, return `t`. - - If the shape of `t` is not equal to `shape`, - but can be reshaped to `shape`, return `t` reshaped to `shape`. - - If the shape of `t` is not equal to `shape` - and cannot be reshaped to `shape`, raise a ValueError. - - :param t: The tensor to validate. - :param shape: The shape to validate against. - :return: The validated tensor. - """ - ... + def __docarray_validate_shape__( + cls, t: T, shape: Tuple[Union[int, str]] + ) -> T: # type: ignore + if t.shape == shape: + return t + elif any(isinstance(dim, str) for dim in shape): + if len(t.shape) != len(shape): + raise ValueError( + f'Tensor shape mismatch. Expected {shape}, got {t.shape}' + ) + known_dims: Dict[str, int] = {} + for tdim, dim in zip(t.shape, shape): + if isinstance(dim, int) and tdim != dim: + raise ValueError( + f'Tensor shape mismatch. Expected {shape}, got {t.shape}' + ) + elif isinstance(dim, str): + if dim in known_dims and known_dims[dim] != tdim: + raise ValueError( + f'Tensor shape mismatch. Expected {shape}, got {t.shape}' + ) + else: + known_dims[dim] = tdim + else: + return t + else: + warnings.warn( + f'Tensor shape mismatch. Reshaping tensor ' + f'of shape {t.shape} to shape {shape}' + ) + try: + value = cls.__docarray_from_native__(t.reshape(shape)) + return cast(T, value) + except RuntimeError: + raise ValueError( + f'Cannot reshape tensor of shape {t.shape} to shape {shape}' + ) @classmethod def __docarray_validate_getitem__(cls, item: Any) -> Tuple[int]: diff --git a/docarray/typing/tensor/ndarray.py b/docarray/typing/tensor/ndarray.py index 4a4f4e8e8bf..a2eb2327531 100644 --- a/docarray/typing/tensor/ndarray.py +++ b/docarray/typing/tensor/ndarray.py @@ -1,4 +1,3 @@ -import warnings from typing import ( TYPE_CHECKING, Any, @@ -94,43 +93,6 @@ def __get_validators__(cls): # the value returned from the previous validator yield cls.validate - @classmethod - def __docarray_validate_shape__(cls, t: T, shape: Tuple[int]) -> T: # type: ignore - if t.shape == shape: - return t - elif any(isinstance(dim, str) for dim in shape): - if len(t.shape) != len(shape): - raise ValueError( - f'Array shape mismatch. Expected {shape}, got {t.shape}' - ) - known_dims: Dict[str, int] = {} - for tdim, dim in zip(t.shape, shape): - if isinstance(dim, int) and tdim != dim: - raise ValueError( - f'Array shape mismatch. Expected {shape}, got {t.shape}' - ) - elif isinstance(dim, str): - if dim in known_dims and known_dims[dim] != tdim: - raise ValueError( - f'Array shape mismatch. Expected {shape}, got {t.shape}' - ) - else: - known_dims[dim] = tdim - else: - return t - else: - warnings.warn( - f'Array shape mismatch. Reshaping array ' - f'of shape {t.shape} to shape {shape}' - ) - try: - value = cls.__docarray_from_native__(np.reshape(t, shape)) - return cast(T, value) - except RuntimeError: - raise ValueError( - f'Cannot reshape array of shape {t.shape} to shape {shape}' - ) - @classmethod def validate( cls: Type[T], diff --git a/docarray/typing/tensor/torch_tensor.py b/docarray/typing/tensor/torch_tensor.py index e2278de883a..ae265b833c5 100644 --- a/docarray/typing/tensor/torch_tensor.py +++ b/docarray/typing/tensor/torch_tensor.py @@ -1,6 +1,5 @@ -import warnings from copy import copy -from typing import TYPE_CHECKING, Any, Dict, Generic, Tuple, Type, TypeVar, Union, cast +from typing import TYPE_CHECKING, Any, Dict, Generic, Type, TypeVar, Union, cast import numpy as np import torch # type: ignore @@ -93,43 +92,6 @@ def __get_validators__(cls): # the value returned from the previous validator yield cls.validate - @classmethod - def __docarray_validate_shape__(cls, t: T, shape: Tuple[int]) -> T: # type: ignore - if t.shape == shape: - return t - elif any(isinstance(dim, str) for dim in shape): - if len(t.shape) != len(shape): - raise ValueError( - f'Tensor shape mismatch. Expected {shape}, got {t.shape}' - ) - known_dims: Dict[str, int] = {} - for tdim, dim in zip(t.shape, shape): - if isinstance(dim, int) and tdim != dim: - raise ValueError( - f'Tensor shape mismatch. Expected {shape}, got {t.shape}' - ) - elif isinstance(dim, str): - if dim in known_dims and known_dims[dim] != tdim: - raise ValueError( - f'Tensor shape mismatch. Expected {shape}, got {t.shape}' - ) - else: - known_dims[dim] = tdim - else: - return t - else: - warnings.warn( - f'Tensor shape mismatch. Reshaping tensor ' - f'of shape {t.shape} to shape {shape}' - ) - try: - value = cls.__docarray_from_native__(t.view(shape)) - return cast(T, value) - except RuntimeError: - raise ValueError( - f'Cannot reshape tensor of shape {t.shape} to shape {shape}' - ) - @classmethod def validate( cls: Type[T], From 879d89e8939ab7321d1b1903144edefad241e85a Mon Sep 17 00:00:00 2001 From: Jackmin801 <56836461+Jackmin801@users.noreply.github.com> Date: Mon, 16 Jan 2023 17:49:38 +0800 Subject: [PATCH 12/16] fix: fix mypy being upset about reshape Signed-off-by: Jackmin801 <56836461+Jackmin801@users.noreply.github.com> --- docarray/typing/tensor/abstract_tensor.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/docarray/typing/tensor/abstract_tensor.py b/docarray/typing/tensor/abstract_tensor.py index b1fcb6aa4cb..c755f194fe0 100644 --- a/docarray/typing/tensor/abstract_tensor.py +++ b/docarray/typing/tensor/abstract_tensor.py @@ -98,6 +98,7 @@ def __docarray_validate_shape__( else: return t else: + shape = cast(Tuple[int], shape) warnings.warn( f'Tensor shape mismatch. Reshaping tensor ' f'of shape {t.shape} to shape {shape}' @@ -213,3 +214,18 @@ def _docarray_to_json_compatible(self): :return: a representation of the tensor compatible with orjson """ ... + + @abc.abstractmethod + def reshape(self, shape: Tuple[int, ...]): + """ + Gives a new shape to tensor without changing its data. + :return: a tensor with the same data and number of elements as self + but with the specified shape. + """ + ... + + @property + @abc.abstractmethod + def shape(self) -> Tuple[int, ...]: + """The shape of this tensor.""" + ... From f9cea6cbc76a2af575655cc6e0466dd65c50c938 Mon Sep 17 00:00:00 2001 From: Jackmin801 <56836461+Jackmin801@users.noreply.github.com> Date: Mon, 16 Jan 2023 20:26:20 +0800 Subject: [PATCH 13/16] refactor: add TAbstractTensor to computational backends Signed-off-by: Jackmin801 <56836461+Jackmin801@users.noreply.github.com> --- docarray/computation/abstract_comp_backend.py | 64 ++++++++++++++++++- docarray/computation/numpy_backend.py | 47 +++++++++++++- docarray/computation/torch_backend.py | 46 ++++++++++++- 3 files changed, 153 insertions(+), 4 deletions(-) diff --git a/docarray/computation/abstract_comp_backend.py b/docarray/computation/abstract_comp_backend.py index 691f437d6d9..1905e0a34b6 100644 --- a/docarray/computation/abstract_comp_backend.py +++ b/docarray/computation/abstract_comp_backend.py @@ -1,14 +1,15 @@ import typing from abc import ABC, abstractmethod -from typing import List, Optional, Tuple, TypeVar, Union +from typing import List, Optional, Tuple, TypeVar, Union, overload # In practice all of the below will be the same type TTensor = TypeVar('TTensor') +TAbstractTensor = TypeVar('TAbstractTensor') TTensorRetrieval = TypeVar('TTensorRetrieval') TTensorMetrics = TypeVar('TTensorMetrics') -class AbstractComputationalBackend(ABC, typing.Generic[TTensor]): +class AbstractComputationalBackend(ABC, typing.Generic[TTensor, TAbstractTensor]): """ Abstract base class for computational backends. Every supported tensor/ML framework (numpy, torch etc.) should define its own @@ -43,6 +44,65 @@ def to_device(tensor: 'TTensor', device: str) -> 'TTensor': """Move the tensor to the specified device.""" ... + @overload + @staticmethod + def shape(tensor: 'TAbstractTensor') -> Tuple[int, ...]: + """Get shape of tensor""" + ... + + @overload + @staticmethod + def shape(tensor: 'TTensor') -> Tuple[int, ...]: + """Get shape of tensor""" + ... + + @staticmethod + @abstractmethod + def shape(tensor: Union['TTensor', 'TAbstractTensor']) -> Tuple[int, ...]: + """Get shape of tensor""" + ... + + @overload + @staticmethod + def reshape(tensor: 'TAbstractTensor', shape: Tuple[int, ...]) -> 'TAbstractTensor': + """ + Gives a new shape to tensor without changing its data. + + :param tensor: tensor to be reshaped + :param shape: the new shape + :return: a tensor with the same data and number of elements as tensor + but with the specified shape. + """ + ... + + @overload + @staticmethod + def reshape(tensor: 'TTensor', shape: Tuple[int, ...]) -> 'TTensor': + """ + Gives a new shape to tensor without changing its data. + + :param tensor: tensor to be reshaped + :param shape: the new shape + :return: a tensor with the same data and number of elements as tensor + but with the specified shape. + """ + ... + + @staticmethod + @abstractmethod + def reshape( + tensor: Union['TTensor', 'TAbstractTensor'], shape: Tuple[int, ...] + ) -> Union['TTensor', 'TAbstractTensor']: + """ + Gives a new shape to tensor without changing its data. + + :param tensor: tensor to be reshaped + :param shape: the new shape + :return: a tensor with the same data and number of elements as tensor + but with the specified shape. + """ + ... + class Retrieval(ABC, typing.Generic[TTensorRetrieval]): """ Abstract class for retrieval and ranking functionalities diff --git a/docarray/computation/numpy_backend.py b/docarray/computation/numpy_backend.py index f1eff8551a6..8a170e4ada1 100644 --- a/docarray/computation/numpy_backend.py +++ b/docarray/computation/numpy_backend.py @@ -30,7 +30,7 @@ def _expand_if_scalar(arr: np.ndarray) -> np.ndarray: return arr -class NumpyCompBackend(AbstractComputationalBackend[np.ndarray]): +class NumpyCompBackend(AbstractComputationalBackend[np.ndarray, NdArray]): """ Computational backend for Numpy. """ @@ -69,6 +69,51 @@ def none_value() -> Any: """Provide a compatible value that represents None in numpy.""" return None + @staticmethod + def shape(array: 'np.ndarray') -> Tuple[int, ...]: + """Get shape of array""" + return array.shape + + @overload + @staticmethod + def reshape(array: 'NdArray', shape: Tuple[int, ...]) -> 'NdArray': + """ + Gives a new shape to array without changing its data. + + :param array: array to be reshaped + :param shape: the new shape + :return: a array with the same data and number of elements as array + but with the specified shape. + """ + ... + + @overload + @staticmethod + def reshape(array: 'np.ndarray', shape: Tuple[int, ...]) -> 'np.ndarray': + """ + Gives a new shape to array without changing its data. + + :param array: array to be reshaped + :param shape: the new shape + :return: a array with the same data and number of elements as array + but with the specified shape. + """ + ... + + @staticmethod + def reshape( + array: Union['np.ndarray', 'NdArray'], shape: Tuple[int, ...] + ) -> Union['np.ndarray', 'NdArray']: + """ + Gives a new shape to array without changing its data. + + :param array: array to be reshaped + :param shape: the new shape + :return: a array with the same data and number of elements as array + but with the specified shape. + """ + return array.reshape(shape) + class Retrieval(AbstractComputationalBackend.Retrieval[np.ndarray]): """ Abstract class for retrieval and ranking functionalities diff --git a/docarray/computation/torch_backend.py b/docarray/computation/torch_backend.py index 4bf02295f3d..d56040e33c4 100644 --- a/docarray/computation/torch_backend.py +++ b/docarray/computation/torch_backend.py @@ -31,7 +31,7 @@ def _usqueeze_if_scalar(t: torch.Tensor): return t -class TorchCompBackend(AbstractComputationalBackend[torch.Tensor]): +class TorchCompBackend(AbstractComputationalBackend[torch.Tensor, 'TorchTensor']): """ Computational backend for PyTorch. """ @@ -69,6 +69,50 @@ def none_value() -> Any: """Provide a compatible value that represents None in torch.""" return torch.tensor(float('nan')) + @staticmethod + def shape(tensor: Union['torch.Tensor', 'TorchTensor']) -> Tuple[int, ...]: + return tuple(tensor.shape) + + @overload + @staticmethod + def reshape(tensor: 'TorchTensor', shape: Tuple[int, ...]) -> 'TorchTensor': + """ + Gives a new shape to tensor without changing its data. + + :param tensor: tensor to be reshaped + :param shape: the new shape + :return: a tensor with the same data and number of elements as tensor + but with the specified shape. + """ + ... + + @overload + @staticmethod + def reshape(tensor: 'torch.Tensor', shape: Tuple[int, ...]) -> 'torch.Tensor': + """ + Gives a new shape to tensor without changing its data. + + :param tensor: tensor to be reshaped + :param shape: the new shape + :return: a tensor with the same data and number of elements as tensor + but with the specified shape. + """ + ... + + @staticmethod + def reshape( + tensor: Union['torch.Tensor', 'TorchTensor'], shape: Tuple[int, ...] + ) -> Union['torch.Tensor', 'TorchTensor']: + """ + Gives a new shape to tensor without changing its data. + + :param tensor: tensor to be reshaped + :param shape: the new shape + :return: a tensor with the same data and number of elements as tensor + but with the specified shape. + """ + return tensor.reshape(shape) + class Retrieval(AbstractComputationalBackend.Retrieval[torch.Tensor]): """ Abstract class for retrieval and ranking functionalities From 5c87da153108ba62ba18c42c7b21b8775336167a Mon Sep 17 00:00:00 2001 From: Jackmin801 <56836461+Jackmin801@users.noreply.github.com> Date: Mon, 16 Jan 2023 20:34:41 +0800 Subject: [PATCH 14/16] refactor: use computational backends in validate shape Signed-off-by: Jackmin801 <56836461+Jackmin801@users.noreply.github.com> --- docarray/typing/tensor/abstract_tensor.py | 58 +++++++++++------------ 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/docarray/typing/tensor/abstract_tensor.py b/docarray/typing/tensor/abstract_tensor.py index c755f194fe0..07e7874b59b 100644 --- a/docarray/typing/tensor/abstract_tensor.py +++ b/docarray/typing/tensor/abstract_tensor.py @@ -24,6 +24,7 @@ from docarray.proto import NdArrayProto T = TypeVar('T', bound='AbstractTensor') +TTensor = TypeVar('TTensor') ShapeT = TypeVar('ShapeT') @@ -66,32 +67,46 @@ def __instancecheck__(cls, instance): return super().__instancecheck__(instance) -class AbstractTensor(Generic[ShapeT], AbstractType, ABC): +class AbstractTensor(Generic[TTensor, T], AbstractType, ABC): __parametrized_meta__: type = _ParametrizedMeta _PROTO_FIELD_NAME: str @classmethod - def __docarray_validate_shape__( - cls, t: T, shape: Tuple[Union[int, str]] - ) -> T: # type: ignore - if t.shape == shape: + def __docarray_validate_shape__(cls, t: T, shape: Tuple[Union[int, str]]) -> T: + """Every tensor has to implement this method in order to + enable syntax of the form AnyTensor[shape]. + It is called when a tensor is assigned to a field of this type. + i.e. when a tensor is passed to a Document field of type AnyTensor[shape]. + The intended behaviour is as follows: + - If the shape of `t` is equal to `shape`, return `t`. + - If the shape of `t` is not equal to `shape`, + but can be reshaped to `shape`, return `t` reshaped to `shape`. + - If the shape of `t` is not equal to `shape` + and cannot be reshaped to `shape`, raise a ValueError. + :param t: The tensor to validate. + :param shape: The shape to validate against. + :return: The validated tensor. + """ + comp_be = t.get_comp_backend()() # mypy Generics require instantiation + tshape = comp_be.shape(t) + if tshape == shape: return t elif any(isinstance(dim, str) for dim in shape): - if len(t.shape) != len(shape): + if len(tshape) != len(shape): raise ValueError( - f'Tensor shape mismatch. Expected {shape}, got {t.shape}' + f'Tensor shape mismatch. Expected {shape}, got {tshape}' ) known_dims: Dict[str, int] = {} - for tdim, dim in zip(t.shape, shape): + for tdim, dim in zip(tshape, shape): if isinstance(dim, int) and tdim != dim: raise ValueError( - f'Tensor shape mismatch. Expected {shape}, got {t.shape}' + f'Tensor shape mismatch. Expected {shape}, got {tshape}' ) elif isinstance(dim, str): if dim in known_dims and known_dims[dim] != tdim: raise ValueError( - f'Tensor shape mismatch. Expected {shape}, got {t.shape}' + f'Tensor shape mismatch. Expected {shape}, got {tshape}' ) else: known_dims[dim] = tdim @@ -101,14 +116,14 @@ def __docarray_validate_shape__( shape = cast(Tuple[int], shape) warnings.warn( f'Tensor shape mismatch. Reshaping tensor ' - f'of shape {t.shape} to shape {shape}' + f'of shape {tshape} to shape {shape}' ) try: - value = cls.__docarray_from_native__(t.reshape(shape)) + value = cls.__docarray_from_native__(comp_be.reshape(t, shape)) return cast(T, value) except RuntimeError: raise ValueError( - f'Cannot reshape tensor of shape {t.shape} to shape {shape}' + f'Cannot reshape tensor of shape {tshape} to shape {shape}' ) @classmethod @@ -187,7 +202,7 @@ def __docarray_from_native__(cls: Type[T], value: Any) -> T: @staticmethod @abc.abstractmethod - def get_comp_backend() -> Type[AbstractComputationalBackend]: + def get_comp_backend() -> Type[AbstractComputationalBackend[TTensor, T]]: """The computational backend compatible with this tensor type.""" ... @@ -214,18 +229,3 @@ def _docarray_to_json_compatible(self): :return: a representation of the tensor compatible with orjson """ ... - - @abc.abstractmethod - def reshape(self, shape: Tuple[int, ...]): - """ - Gives a new shape to tensor without changing its data. - :return: a tensor with the same data and number of elements as self - but with the specified shape. - """ - ... - - @property - @abc.abstractmethod - def shape(self) -> Tuple[int, ...]: - """The shape of this tensor.""" - ... From 27b0241b9db5cb6fb482b7fdd12eef9da9f11374 Mon Sep 17 00:00:00 2001 From: Jackmin801 <56836461+Jackmin801@users.noreply.github.com> Date: Mon, 16 Jan 2023 21:15:03 +0800 Subject: [PATCH 15/16] refactor: remove redundant code Signed-off-by: Jackmin801 <56836461+Jackmin801@users.noreply.github.com> --- docarray/computation/abstract_comp_backend.py | 14 +------------- docarray/computation/torch_backend.py | 2 +- 2 files changed, 2 insertions(+), 14 deletions(-) diff --git a/docarray/computation/abstract_comp_backend.py b/docarray/computation/abstract_comp_backend.py index 1905e0a34b6..9ec28a5d72b 100644 --- a/docarray/computation/abstract_comp_backend.py +++ b/docarray/computation/abstract_comp_backend.py @@ -44,21 +44,9 @@ def to_device(tensor: 'TTensor', device: str) -> 'TTensor': """Move the tensor to the specified device.""" ... - @overload - @staticmethod - def shape(tensor: 'TAbstractTensor') -> Tuple[int, ...]: - """Get shape of tensor""" - ... - - @overload - @staticmethod - def shape(tensor: 'TTensor') -> Tuple[int, ...]: - """Get shape of tensor""" - ... - @staticmethod @abstractmethod - def shape(tensor: Union['TTensor', 'TAbstractTensor']) -> Tuple[int, ...]: + def shape(tensor: 'TTensor') -> Tuple[int, ...]: """Get shape of tensor""" ... diff --git a/docarray/computation/torch_backend.py b/docarray/computation/torch_backend.py index d56040e33c4..eb316c07251 100644 --- a/docarray/computation/torch_backend.py +++ b/docarray/computation/torch_backend.py @@ -70,7 +70,7 @@ def none_value() -> Any: return torch.tensor(float('nan')) @staticmethod - def shape(tensor: Union['torch.Tensor', 'TorchTensor']) -> Tuple[int, ...]: + def shape(tensor: 'torch.Tensor') -> Tuple[int, ...]: return tuple(tensor.shape) @overload From ecdf48a26bb07207c773ccf65c44ca6933f08cd9 Mon Sep 17 00:00:00 2001 From: Jackmin801 <56836461+Jackmin801@users.noreply.github.com> Date: Mon, 16 Jan 2023 22:20:06 +0800 Subject: [PATCH 16/16] fix: fixed missed rename from merge Signed-off-by: Jackmin801 <56836461+Jackmin801@users.noreply.github.com> --- docarray/typing/tensor/abstract_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docarray/typing/tensor/abstract_tensor.py b/docarray/typing/tensor/abstract_tensor.py index 301817ac082..ba8c185fb98 100644 --- a/docarray/typing/tensor/abstract_tensor.py +++ b/docarray/typing/tensor/abstract_tensor.py @@ -119,7 +119,7 @@ def __docarray_validate_shape__(cls, t: T, shape: Tuple[Union[int, str]]) -> T: f'of shape {tshape} to shape {shape}' ) try: - value = cls.__docarray_from_native__(comp_be.reshape(t, shape)) + value = cls._docarray_from_native(comp_be.reshape(t, shape)) return cast(T, value) except RuntimeError: raise ValueError(