From 79ab3116f5a014635a568cbe7eed6b674cbaedb1 Mon Sep 17 00:00:00 2001 From: Johannes Messner Date: Thu, 5 Jan 2023 14:28:58 +0100 Subject: [PATCH 1/7] fix: some hacky stuff that i don't understand Signed-off-by: Johannes Messner --- docarray/typing/tensor/torch_tensor.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docarray/typing/tensor/torch_tensor.py b/docarray/typing/tensor/torch_tensor.py index c94cee91ced..618f4ce1145 100644 --- a/docarray/typing/tensor/torch_tensor.py +++ b/docarray/typing/tensor/torch_tensor.py @@ -230,3 +230,7 @@ def get_comp_backend() -> Type['TorchCompBackend']: from docarray.computation.torch_backend import TorchCompBackend return TorchCompBackend + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + return super().__torch_function__(func, (torch.Tensor,), args, kwargs) From c37e0a64c3ba2f68064d02d15b007c585734a864 Mon Sep 17 00:00:00 2001 From: Johannes Messner Date: Thu, 12 Jan 2023 13:06:35 +0100 Subject: [PATCH 2/7] fix: subclass check for parametrized tensors Signed-off-by: Johannes Messner --- docarray/typing/tensor/abstract_tensor.py | 27 ++++++++++++++++++++--- docarray/typing/tensor/torch_tensor.py | 8 +++---- 2 files changed, 28 insertions(+), 7 deletions(-) diff --git a/docarray/typing/tensor/abstract_tensor.py b/docarray/typing/tensor/abstract_tensor.py index 5e098b0e1e5..06cecdb8ca9 100644 --- a/docarray/typing/tensor/abstract_tensor.py +++ b/docarray/typing/tensor/abstract_tensor.py @@ -13,9 +13,28 @@ ShapeT = TypeVar('ShapeT') +class _ParametrizedMeta(type): + def __subclasscheck__(cls, subclass): + is_tensor = AbstractTensor in subclass.mro() + same_parents = is_tensor and cls.mro()[1:] == subclass.mro()[1:] + + subclass_target_shape = getattr(subclass, '__docarray_target_shape__', False) + self_target_shape = getattr(cls, '__docarray_target_shape__', False) + same_shape = ( + same_parents + and subclass_target_shape + and self_target_shape + and subclass_target_shape == self_target_shape + ) + + if same_shape: + return True + return super().__subclasscheck__(subclass) + + class AbstractTensor(AbstractType, Generic[ShapeT], ABC): - __parametrized_meta__ = type + __parametrized_meta__ = _ParametrizedMeta _PROTO_FIELD_NAME: str @classmethod @@ -74,7 +93,7 @@ class _ParametrizedTensor( cls, # type: ignore metaclass=cls.__parametrized_meta__, # type: ignore ): - _docarray_target_shape = shape + __docarray_target_shape__ = shape @classmethod def validate( @@ -84,7 +103,9 @@ def validate( config: 'BaseConfig', ): t = super().validate(value, field, config) - return _cls.__docarray_validate_shape__(t, _cls._docarray_target_shape) + return _cls.__docarray_validate_shape__( + t, _cls.__docarray_target_shape__ + ) _ParametrizedTensor.__name__ = f'{cls.__name__}[{shape_str}]' _ParametrizedTensor.__qualname__ = f'{cls.__qualname__}[{shape_str}]' diff --git a/docarray/typing/tensor/torch_tensor.py b/docarray/typing/tensor/torch_tensor.py index 618f4ce1145..159b497e2a4 100644 --- a/docarray/typing/tensor/torch_tensor.py +++ b/docarray/typing/tensor/torch_tensor.py @@ -23,7 +23,7 @@ node_base = type(BaseNode) # type: Any -class metaTorchAndNode(torch_base, node_base): +class metaTorchAndNode(AbstractTensor.__parametrized_meta__, torch_base, node_base): pass @@ -231,6 +231,6 @@ def get_comp_backend() -> Type['TorchCompBackend']: return TorchCompBackend - @classmethod - def __torch_function__(cls, func, types, args=(), kwargs=None): - return super().__torch_function__(func, (torch.Tensor,), args, kwargs) + # @classmethod + # def __torch_function__(cls, func, types, args=(), kwargs=None): + # return super().__torch_function__(func, (torch.Tensor,), args, kwargs) From ddb5c994578179a3782a29191fb0254d14f3d3b6 Mon Sep 17 00:00:00 2001 From: Johannes Messner Date: Thu, 12 Jan 2023 13:19:36 +0100 Subject: [PATCH 3/7] fix: add instance check Signed-off-by: Johannes Messner --- docarray/typing/tensor/abstract_tensor.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/docarray/typing/tensor/abstract_tensor.py b/docarray/typing/tensor/abstract_tensor.py index 06cecdb8ca9..8a52a475feb 100644 --- a/docarray/typing/tensor/abstract_tensor.py +++ b/docarray/typing/tensor/abstract_tensor.py @@ -14,6 +14,20 @@ class _ParametrizedMeta(type): + """ + This metaclass ensures that instance and subclass checks on parametrized Tensors + are handled as expected: + + assert issubclass(TorchTensor[128], TorchTensor[128]) + t = parse_obj_as(TorchTensor[128], torch.zeros(128)) + assert isinstance(t, TorchTensor[128]) + etc. + + This special handling is needed because every call to `AbstractTensor.__getitem__` + creates a new class on the fly. + We want technically distinct but identical classes to be considered equal. + """ + def __subclasscheck__(cls, subclass): is_tensor = AbstractTensor in subclass.mro() same_parents = is_tensor and cls.mro()[1:] == subclass.mro()[1:] @@ -31,6 +45,12 @@ def __subclasscheck__(cls, subclass): return True return super().__subclasscheck__(subclass) + def __instancecheck__(cls, instance): + is_tensor = isinstance(instance, AbstractTensor) + if is_tensor: # custom handling + return any(issubclass(type(instance), subclass) for subclass in cls.mro()) + return super().__instancecheck__(instance) + class AbstractTensor(AbstractType, Generic[ShapeT], ABC): From c6e0e83e78feeade193dd429648dc44a13b1334c Mon Sep 17 00:00:00 2001 From: Johannes Messner Date: Thu, 12 Jan 2023 13:41:55 +0100 Subject: [PATCH 4/7] fix: ndarray metaclass and instance check Signed-off-by: Johannes Messner --- docarray/typing/tensor/abstract_tensor.py | 2 +- docarray/typing/tensor/ndarray.py | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/docarray/typing/tensor/abstract_tensor.py b/docarray/typing/tensor/abstract_tensor.py index 8a52a475feb..9495fe92c1d 100644 --- a/docarray/typing/tensor/abstract_tensor.py +++ b/docarray/typing/tensor/abstract_tensor.py @@ -48,7 +48,7 @@ def __subclasscheck__(cls, subclass): def __instancecheck__(cls, instance): is_tensor = isinstance(instance, AbstractTensor) if is_tensor: # custom handling - return any(issubclass(type(instance), subclass) for subclass in cls.mro()) + return any(issubclass(candidate, cls) for candidate in type(instance).mro()) return super().__instancecheck__(instance) diff --git a/docarray/typing/tensor/ndarray.py b/docarray/typing/tensor/ndarray.py index b8149883bc7..8f33506e071 100644 --- a/docarray/typing/tensor/ndarray.py +++ b/docarray/typing/tensor/ndarray.py @@ -23,9 +23,17 @@ from docarray.computation.numpy_backend import NumpyCompBackend from docarray.proto import NdArrayProto, NodeProto +from docarray.document.base_node import BaseNode + T = TypeVar('T', bound='NdArray') ShapeT = TypeVar('ShapeT') +tensor_base = type(BaseNode) # type: Any + + +class metaNumpy(AbstractTensor.__parametrized_meta__, tensor_base): + pass + class NdArray(AbstractTensor, np.ndarray, Generic[ShapeT]): """ @@ -71,6 +79,7 @@ class MyDoc(BaseDocument): """ _PROTO_FIELD_NAME = 'ndarray' + __parametrized_meta__ = metaNumpy @classmethod def __get_validators__(cls): From 6a555d0fb5bab651f79439d24dff991f42f48e85 Mon Sep 17 00:00:00 2001 From: Johannes Messner Date: Thu, 12 Jan 2023 13:51:15 +0100 Subject: [PATCH 5/7] test: add tests Signed-off-by: Johannes Messner --- tests/units/typing/tensor/test_tensor.py | 36 +++++++++++++++++++ .../units/typing/tensor/test_torch_tensor.py | 36 +++++++++++++++++++ 2 files changed, 72 insertions(+) diff --git a/tests/units/typing/tensor/test_tensor.py b/tests/units/typing/tensor/test_tensor.py index dd8356ed84d..103b0f3367a 100644 --- a/tests/units/typing/tensor/test_tensor.py +++ b/tests/units/typing/tensor/test_tensor.py @@ -90,3 +90,39 @@ def test_np_embedding(): # illegal shape at class creation time with pytest.raises(ValueError): parse_obj_as(NdArrayEmbedding[128, 128], np.zeros((128, 128))) + + +def test_parametrized_subclass(): + c1 = NdArray[128] + c2 = NdArray[128] + assert issubclass(c1, c2) + assert issubclass(c1, NdArray) + assert issubclass(c1, np.ndarray) + + assert not issubclass(c1, NdArray[256]) + + +def test_parametrized_instance(): + t = parse_obj_as(NdArray[128], np.zeros(128)) + assert isinstance(t, NdArray[128]) + assert isinstance(t, NdArray) + assert isinstance(t, np.ndarray) + + assert not isinstance(t, NdArray[256]) + + +def test_parametrized_equality(): + t1 = parse_obj_as(NdArray[128], np.zeros(128)) + t2 = parse_obj_as(NdArray[128], np.zeros(128)) + t3 = parse_obj_as(NdArray[256], np.zeros(256)) + assert (t1 == t2).all() + assert not t1 == t3 + + +def test_parametrized_operations(): + t1 = parse_obj_as(NdArray[128], np.zeros(128)) + t2 = parse_obj_as(NdArray[128], np.zeros(128)) + t_result = t1 + t2 + assert isinstance(t_result, np.ndarray) + assert isinstance(t_result, NdArray) + assert isinstance(t_result, NdArray[128]) diff --git a/tests/units/typing/tensor/test_torch_tensor.py b/tests/units/typing/tensor/test_torch_tensor.py index cd2fa257c37..76ad4628da0 100644 --- a/tests/units/typing/tensor/test_torch_tensor.py +++ b/tests/units/typing/tensor/test_torch_tensor.py @@ -82,3 +82,39 @@ def test_torch_embedding(): # illegal shape at class creation time with pytest.raises(ValueError): parse_obj_as(TorchEmbedding[128, 128], torch.zeros(128, 128)) + + +def test_parametrized_subclass(): + c1 = TorchTensor[128] + c2 = TorchTensor[128] + assert issubclass(c1, c2) + assert issubclass(c1, TorchTensor) + assert issubclass(c1, torch.Tensor) + + assert not issubclass(c1, TorchTensor[256]) + + +def test_parametrized_instance(): + t = parse_obj_as(TorchTensor[128], torch.zeros(128)) + assert isinstance(t, TorchTensor[128]) + assert isinstance(t, TorchTensor) + assert isinstance(t, torch.Tensor) + + assert not isinstance(t, TorchTensor[256]) + + +def test_parametrized_equality(): + t1 = parse_obj_as(TorchTensor[128], torch.zeros(128)) + t2 = parse_obj_as(TorchTensor[128], torch.zeros(128)) + t3 = parse_obj_as(TorchTensor[256], torch.zeros(256)) + assert (t1 == t2).all() + assert not t1 == t3 + + +def test_parametrized_operations(): + t1 = parse_obj_as(TorchTensor[128], torch.zeros(128)) + t2 = parse_obj_as(TorchTensor[128], torch.zeros(128)) + t_result = t1 + t2 + assert isinstance(t_result, torch.Tensor) + assert isinstance(t_result, TorchTensor) + assert isinstance(t_result, TorchTensor[128]) From 0a3d6b5be46d54d6c8d152a43e00fa4e2fa10640 Mon Sep 17 00:00:00 2001 From: Johannes Messner Date: Thu, 12 Jan 2023 13:51:55 +0100 Subject: [PATCH 6/7] refactor: remove comment Signed-off-by: Johannes Messner --- docarray/typing/tensor/torch_tensor.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/docarray/typing/tensor/torch_tensor.py b/docarray/typing/tensor/torch_tensor.py index 159b497e2a4..ac92173f54a 100644 --- a/docarray/typing/tensor/torch_tensor.py +++ b/docarray/typing/tensor/torch_tensor.py @@ -230,7 +230,3 @@ def get_comp_backend() -> Type['TorchCompBackend']: from docarray.computation.torch_backend import TorchCompBackend return TorchCompBackend - - # @classmethod - # def __torch_function__(cls, func, types, args=(), kwargs=None): - # return super().__torch_function__(func, (torch.Tensor,), args, kwargs) From 22c6f6ce75039460c03d99f7b9acd61f2a9aa660 Mon Sep 17 00:00:00 2001 From: Johannes Messner Date: Thu, 12 Jan 2023 14:21:42 +0100 Subject: [PATCH 7/7] fix: handle mypy Signed-off-by: Johannes Messner --- docarray/typing/tensor/abstract_tensor.py | 3 ++- docarray/typing/tensor/ndarray.py | 8 +++++--- docarray/typing/tensor/torch_tensor.py | 12 +++++++++--- 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/docarray/typing/tensor/abstract_tensor.py b/docarray/typing/tensor/abstract_tensor.py index 2278d2858f8..1d57110b1b8 100644 --- a/docarray/typing/tensor/abstract_tensor.py +++ b/docarray/typing/tensor/abstract_tensor.py @@ -53,9 +53,10 @@ def __instancecheck__(cls, instance): return any(issubclass(candidate, cls) for candidate in type(instance).mro()) return super().__instancecheck__(instance) + class AbstractTensor(Generic[ShapeT], AbstractType, ABC): - __parametrized_meta__ = _ParametrizedMeta + __parametrized_meta__: type = _ParametrizedMeta _PROTO_FIELD_NAME: str @classmethod diff --git a/docarray/typing/tensor/ndarray.py b/docarray/typing/tensor/ndarray.py index e0d8e06c631..991da6fad9b 100644 --- a/docarray/typing/tensor/ndarray.py +++ b/docarray/typing/tensor/ndarray.py @@ -23,15 +23,17 @@ from docarray.computation.numpy_backend import NumpyCompBackend from docarray.proto import NdArrayProto, NodeProto -from docarray.document.base_node import BaseNode +from docarray.base_document.base_node import BaseNode T = TypeVar('T', bound='NdArray') ShapeT = TypeVar('ShapeT') -tensor_base = type(BaseNode) # type: Any +tensor_base: type = type(BaseNode) -class metaNumpy(AbstractTensor.__parametrized_meta__, tensor_base): +# the mypy error suppression below should not be necessary anymore once the following +# is released in mypy: https://github.com/python/mypy/pull/14135 +class metaNumpy(AbstractTensor.__parametrized_meta__, tensor_base): # type: ignore pass diff --git a/docarray/typing/tensor/torch_tensor.py b/docarray/typing/tensor/torch_tensor.py index e81d13ad40d..7feba31858c 100644 --- a/docarray/typing/tensor/torch_tensor.py +++ b/docarray/typing/tensor/torch_tensor.py @@ -19,11 +19,17 @@ T = TypeVar('T', bound='TorchTensor') ShapeT = TypeVar('ShapeT') -torch_base = type(torch.Tensor) # type: Any -node_base = type(BaseNode) # type: Any +torch_base: type = type(torch.Tensor) +node_base: type = type(BaseNode) -class metaTorchAndNode(AbstractTensor.__parametrized_meta__, torch_base, node_base): +# the mypy error suppression below should not be necessary anymore once the following +# is released in mypy: https://github.com/python/mypy/pull/14135 +class metaTorchAndNode( + AbstractTensor.__parametrized_meta__, # type: ignore + torch_base, # type: ignore + node_base, # type: ignore +): # type: ignore pass