diff --git a/docarray/typing/tensor/abstract_tensor.py b/docarray/typing/tensor/abstract_tensor.py
index a877861aa4e..1d57110b1b8 100644
--- a/docarray/typing/tensor/abstract_tensor.py
+++ b/docarray/typing/tensor/abstract_tensor.py
@@ -15,9 +15,54 @@
 ShapeT = TypeVar('ShapeT')
 
 
+class _ParametrizedMeta(type):
+    """
+    This metaclass ensures that instance and subclass checks on parametrized Tensors
+    are handled as expected:
+
+    assert issubclass(TorchTensor[128], TorchTensor[128])
+    t = parse_obj_as(TorchTensor[128], torch.zeros(128))
+    assert isinstance(t, TorchTensor[128])
+    etc.
+
+    This special handling is needed because every call to `AbstractTensor.__getitem__`
+    creates a new class on the fly.
+    We want technically distinct but identical classes to be considered equal.
+    """
+
+    def __subclasscheck__(cls, subclass):
+        """Treat parametrized tensor classes with equal parents and shape as equal."""
+        # Use the stored `__mro__` instead of calling `mro()`: `type.mro` is unbound
+        # when `subclass` is itself a metaclass, and non-classes have no `mro`.
+        subclass_mro = getattr(subclass, '__mro__', ())
+        is_tensor = AbstractTensor in subclass_mro
+        same_parents = is_tensor and cls.__mro__[1:] == subclass_mro[1:]
+
+        subclass_target_shape = getattr(subclass, '__docarray_target_shape__', False)
+        self_target_shape = getattr(cls, '__docarray_target_shape__', False)
+        same_shape = (
+            same_parents
+            and subclass_target_shape
+            and self_target_shape
+            and subclass_target_shape == self_target_shape
+        )
+
+        if same_shape:
+            return True
+        # everything else (including invalid arguments) gets the default handling
+        return super().__subclasscheck__(subclass)
+
+    def __instancecheck__(cls, instance):
+        """Match tensor instances if any class in their MRO is a subclass of `cls`."""
+        is_tensor = isinstance(instance, AbstractTensor)
+        if is_tensor:  # custom handling
+            return any(issubclass(candidate, cls) for candidate in type(instance).mro())
+        return super().__instancecheck__(instance)
+
+
 class AbstractTensor(Generic[ShapeT], AbstractType, ABC):
 
-    __parametrized_meta__ = type
+    __parametrized_meta__: type = _ParametrizedMeta
     _PROTO_FIELD_NAME: str
 
     @classmethod
@@ -76,7 +121,7 @@ class _ParametrizedTensor(
             cls,  # type: ignore
             metaclass=cls.__parametrized_meta__,  # type: ignore
         ):
-            _docarray_target_shape = shape
+            __docarray_target_shape__ = shape
 
             @classmethod
             def validate(
@@ -86,7 +131,9 @@ def validate(
                 config: 'BaseConfig',
             ):
                 t = super().validate(value, field, config)
-                return _cls.__docarray_validate_shape__(t, _cls._docarray_target_shape)
+                return _cls.__docarray_validate_shape__(
+                    t, _cls.__docarray_target_shape__
+                )
 
         _ParametrizedTensor.__name__ = f'{cls.__name__}[{shape_str}]'
         _ParametrizedTensor.__qualname__ = f'{cls.__qualname__}[{shape_str}]'
diff --git a/docarray/typing/tensor/ndarray.py b/docarray/typing/tensor/ndarray.py
index 00f7109f966..991da6fad9b 100644
--- a/docarray/typing/tensor/ndarray.py
+++ b/docarray/typing/tensor/ndarray.py
@@ -23,9 +23,19 @@
     from docarray.computation.numpy_backend import NumpyCompBackend
     from docarray.proto import NdArrayProto, NodeProto
 
+from docarray.base_document.base_node import BaseNode
+
 T = TypeVar('T', bound='NdArray')
 ShapeT = TypeVar('ShapeT')
 
+tensor_base: type = type(BaseNode)
+
+
+# the mypy error suppression below should not be necessary anymore once the following
+# is released in mypy: https://github.com/python/mypy/pull/14135
+class metaNumpy(AbstractTensor.__parametrized_meta__, tensor_base):  # type: ignore
+    pass
+
 
 class NdArray(np.ndarray, AbstractTensor, Generic[ShapeT]):
     """
@@ -71,6 +81,7 @@ class MyDoc(BaseDocument):
     """
 
     _PROTO_FIELD_NAME = 'ndarray'
+    __parametrized_meta__ = metaNumpy
 
     @classmethod
     def __get_validators__(cls):
diff --git a/docarray/typing/tensor/torch_tensor.py b/docarray/typing/tensor/torch_tensor.py
index 07cefc25947..7feba31858c 100644
--- a/docarray/typing/tensor/torch_tensor.py
+++ b/docarray/typing/tensor/torch_tensor.py
@@ -19,11 +19,17 @@
 T = TypeVar('T', bound='TorchTensor')
 ShapeT = TypeVar('ShapeT')
 
-torch_base = type(torch.Tensor)  # type: Any
-node_base = type(BaseNode)  # type: Any
+torch_base: type = type(torch.Tensor)
+node_base: type = type(BaseNode)
 
 
-class metaTorchAndNode(torch_base, node_base):
+# the mypy error suppression below should not be necessary anymore once the following
+# is released in mypy: https://github.com/python/mypy/pull/14135
+class metaTorchAndNode(
+    AbstractTensor.__parametrized_meta__,  # type: ignore
+    torch_base,  # type: ignore
+    node_base,  # type: ignore
+):  # type: ignore
     pass
 
 
diff --git a/tests/units/typing/tensor/test_tensor.py b/tests/units/typing/tensor/test_tensor.py
index e46b8485231..76050d1b643 100644
--- a/tests/units/typing/tensor/test_tensor.py
+++ b/tests/units/typing/tensor/test_tensor.py
@@ -90,3 +90,39 @@ def test_np_embedding():
     # illegal shape at class creation time
     with pytest.raises(ValueError):
         parse_obj_as(NdArrayEmbedding[128, 128], np.zeros((128, 128)))
+
+
+def test_parametrized_subclass():
+    c1 = NdArray[128]
+    c2 = NdArray[128]
+    assert issubclass(c1, c2)
+    assert issubclass(c1, NdArray)
+    assert issubclass(c1, np.ndarray)
+
+    assert not issubclass(c1, NdArray[256])
+
+
+def test_parametrized_instance():
+    t = parse_obj_as(NdArray[128], np.zeros(128))
+    assert isinstance(t, NdArray[128])
+    assert isinstance(t, NdArray)
+    assert isinstance(t, np.ndarray)
+
+    assert not isinstance(t, NdArray[256])
+
+
+def test_parametrized_equality():
+    t1 = parse_obj_as(NdArray[128], np.zeros(128))
+    t2 = parse_obj_as(NdArray[128], np.zeros(128))
+    t3 = parse_obj_as(NdArray[256], np.zeros(256))
+    assert (t1 == t2).all()
+    assert not t1 == t3
+
+
+def test_parametrized_operations():
+    t1 = parse_obj_as(NdArray[128], np.zeros(128))
+    t2 = parse_obj_as(NdArray[128], np.zeros(128))
+    t_result = t1 + t2
+    assert isinstance(t_result, np.ndarray)
+    assert isinstance(t_result, NdArray)
+    assert isinstance(t_result, NdArray[128])
diff --git a/tests/units/typing/tensor/test_torch_tensor.py b/tests/units/typing/tensor/test_torch_tensor.py
index 0d1ad043b50..b859cbd28bf 100644
--- a/tests/units/typing/tensor/test_torch_tensor.py
+++ b/tests/units/typing/tensor/test_torch_tensor.py
@@ -82,3 +82,39 @@ def test_torch_embedding():
     # illegal shape at class creation time
     with pytest.raises(ValueError):
         parse_obj_as(TorchEmbedding[128, 128], torch.zeros(128, 128))
+
+
+def test_parametrized_subclass():
+    c1 = TorchTensor[128]
+    c2 = TorchTensor[128]
+    assert issubclass(c1, c2)
+    assert issubclass(c1, TorchTensor)
+    assert issubclass(c1, torch.Tensor)
+
+    assert not issubclass(c1, TorchTensor[256])
+
+
+def test_parametrized_instance():
+    t = parse_obj_as(TorchTensor[128], torch.zeros(128))
+    assert isinstance(t, TorchTensor[128])
+    assert isinstance(t, TorchTensor)
+    assert isinstance(t, torch.Tensor)
+
+    assert not isinstance(t, TorchTensor[256])
+
+
+def test_parametrized_equality():
+    t1 = parse_obj_as(TorchTensor[128], torch.zeros(128))
+    t2 = parse_obj_as(TorchTensor[128], torch.zeros(128))
+    t3 = parse_obj_as(TorchTensor[256], torch.zeros(256))
+    assert (t1 == t2).all()
+    assert not t1 == t3
+
+
+def test_parametrized_operations():
+    t1 = parse_obj_as(TorchTensor[128], torch.zeros(128))
+    t2 = parse_obj_as(TorchTensor[128], torch.zeros(128))
+    t_result = t1 + t2
+    assert isinstance(t_result, torch.Tensor)
+    assert isinstance(t_result, TorchTensor)
+    assert isinstance(t_result, TorchTensor[128])