Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 44 additions & 3 deletions docarray/typing/tensor/abstract_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,48 @@
ShapeT = TypeVar('ShapeT')


class _ParametrizedMeta(type):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure that type is needed here

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if it is strictly needed bc in the end the "final" metaclass will only inherit from this one, but I think if we wanted to use _ParametrizedMeta directly as a metaclass we would need it. So I think it is good practice for all metclasses to do it

"""
This metaclass ensures that instance and subclass checks on parametrized Tensors
are handled as expected:

assert issubclass(TorchTensor[128], TorchTensor[128])
t = parse_obj_as(TorchTensor[128], torch.zeros(128))
assert isinstance(t, TorchTensor[128])
etc.

This special handling is needed because every call to `AbstractTensor.__getitem__`
creates a new class on the fly.
We want technically distinct but identical classes to be considered equal.
"""

def __subclasscheck__(cls, subclass):
is_tensor = AbstractTensor in subclass.mro()
same_parents = is_tensor and cls.mro()[1:] == subclass.mro()[1:]

subclass_target_shape = getattr(subclass, '__docarray_target_shape__', False)
self_target_shape = getattr(cls, '__docarray_target_shape__', False)
same_shape = (
same_parents
and subclass_target_shape
and self_target_shape
and subclass_target_shape == self_target_shape
)

if same_shape:
return True
return super().__subclasscheck__(subclass)

def __instancecheck__(cls, instance):
is_tensor = isinstance(instance, AbstractTensor)
if is_tensor: # custom handling
return any(issubclass(candidate, cls) for candidate in type(instance).mro())
return super().__instancecheck__(instance)


class AbstractTensor(Generic[ShapeT], AbstractType, ABC):

__parametrized_meta__ = type
__parametrized_meta__: type = _ParametrizedMeta
_PROTO_FIELD_NAME: str

@classmethod
Expand Down Expand Up @@ -76,7 +115,7 @@ class _ParametrizedTensor(
cls, # type: ignore
metaclass=cls.__parametrized_meta__, # type: ignore
):
_docarray_target_shape = shape
__docarray_target_shape__ = shape

@classmethod
def validate(
Expand All @@ -86,7 +125,9 @@ def validate(
config: 'BaseConfig',
):
t = super().validate(value, field, config)
return _cls.__docarray_validate_shape__(t, _cls._docarray_target_shape)
return _cls.__docarray_validate_shape__(
t, _cls.__docarray_target_shape__
)

_ParametrizedTensor.__name__ = f'{cls.__name__}[{shape_str}]'
_ParametrizedTensor.__qualname__ = f'{cls.__qualname__}[{shape_str}]'
Expand Down
11 changes: 11 additions & 0 deletions docarray/typing/tensor/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,19 @@
from docarray.computation.numpy_backend import NumpyCompBackend
from docarray.proto import NdArrayProto, NodeProto

from docarray.base_document.base_node import BaseNode

T = TypeVar('T', bound='NdArray')
ShapeT = TypeVar('ShapeT')

tensor_base: type = type(BaseNode)


# the mypy error suppression below should not be necessary anymore once the following
# is released in mypy: https://github.com/python/mypy/pull/14135
class metaNumpy(AbstractTensor.__parametrized_meta__, tensor_base): # type: ignore
pass


class NdArray(np.ndarray, AbstractTensor, Generic[ShapeT]):
"""
Expand Down Expand Up @@ -71,6 +81,7 @@ class MyDoc(BaseDocument):
"""

_PROTO_FIELD_NAME = 'ndarray'
__parametrized_meta__ = metaNumpy

@classmethod
def __get_validators__(cls):
Expand Down
12 changes: 9 additions & 3 deletions docarray/typing/tensor/torch_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,17 @@
T = TypeVar('T', bound='TorchTensor')
ShapeT = TypeVar('ShapeT')

torch_base = type(torch.Tensor) # type: Any
node_base = type(BaseNode) # type: Any
torch_base: type = type(torch.Tensor)
node_base: type = type(BaseNode)


class metaTorchAndNode(torch_base, node_base):
# the mypy error suppression below should not be necessary anymore once the following
# is released in mypy: https://github.com/python/mypy/pull/14135
class metaTorchAndNode(
AbstractTensor.__parametrized_meta__, # type: ignore
torch_base, # type: ignore
node_base, # type: ignore
): # type: ignore
pass


Expand Down
36 changes: 36 additions & 0 deletions tests/units/typing/tensor/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,39 @@ def test_np_embedding():
# illegal shape at class creation time
with pytest.raises(ValueError):
parse_obj_as(NdArrayEmbedding[128, 128], np.zeros((128, 128)))


def test_parametrized_subclass():
c1 = NdArray[128]
c2 = NdArray[128]
assert issubclass(c1, c2)
assert issubclass(c1, NdArray)
assert issubclass(c1, np.ndarray)

assert not issubclass(c1, NdArray[256])


def test_parametrized_instance():
t = parse_obj_as(NdArray[128], np.zeros(128))
assert isinstance(t, NdArray[128])
assert isinstance(t, NdArray)
assert isinstance(t, np.ndarray)

assert not isinstance(t, NdArray[256])


def test_parametrized_equality():
t1 = parse_obj_as(NdArray[128], np.zeros(128))
t2 = parse_obj_as(NdArray[128], np.zeros(128))
t3 = parse_obj_as(NdArray[256], np.zeros(256))
assert (t1 == t2).all()
assert not t1 == t3


def test_parametrized_operations():
t1 = parse_obj_as(NdArray[128], np.zeros(128))
t2 = parse_obj_as(NdArray[128], np.zeros(128))
t_result = t1 + t2
assert isinstance(t_result, np.ndarray)
assert isinstance(t_result, NdArray)
assert isinstance(t_result, NdArray[128])
36 changes: 36 additions & 0 deletions tests/units/typing/tensor/test_torch_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,39 @@ def test_torch_embedding():
# illegal shape at class creation time
with pytest.raises(ValueError):
parse_obj_as(TorchEmbedding[128, 128], torch.zeros(128, 128))


def test_parametrized_subclass():
c1 = TorchTensor[128]
c2 = TorchTensor[128]
assert issubclass(c1, c2)
assert issubclass(c1, TorchTensor)
assert issubclass(c1, torch.Tensor)

assert not issubclass(c1, TorchTensor[256])


def test_parametrized_instance():
t = parse_obj_as(TorchTensor[128], torch.zeros(128))
assert isinstance(t, TorchTensor[128])
assert isinstance(t, TorchTensor)
assert isinstance(t, torch.Tensor)

assert not isinstance(t, TorchTensor[256])


def test_parametrized_equality():
t1 = parse_obj_as(TorchTensor[128], torch.zeros(128))
t2 = parse_obj_as(TorchTensor[128], torch.zeros(128))
t3 = parse_obj_as(TorchTensor[256], torch.zeros(256))
assert (t1 == t2).all()
assert not t1 == t3


def test_parametrized_operations():
t1 = parse_obj_as(TorchTensor[128], torch.zeros(128))
t2 = parse_obj_as(TorchTensor[128], torch.zeros(128))
t_result = t1 + t2
assert isinstance(t_result, torch.Tensor)
assert isinstance(t_result, TorchTensor)
assert isinstance(t_result, TorchTensor[128])