Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,14 @@ class MyDoc(BaseDocument):
doc = MyDoc(tensor=torch.zeros(3, 224, 224)) # works
doc = MyDoc(tensor=torch.zeros(224, 224, 3)) # works by reshaping
doc = MyDoc(tensor=torch.zeros(224)) # fails validation

class Image(BaseDocument):
    # dimensions that share the same name ('x') must have the same size
    tensor: TorchTensor[3, 'x', 'x']

Image(tensor=torch.zeros(3, 224, 224)) # works
Image(tensor=torch.zeros(3, 64, 128)) # fails validation because second dimension does not match third
Image(tensor=torch.zeros(4, 224, 224)) # fails validation because of the first dimension
Image(tensor=torch.zeros(3, 64)) # fails validation because it does not have enough dimensions
```

## Coming from a vector database
Expand Down
52 changes: 50 additions & 2 deletions docarray/computation/abstract_comp_backend.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import typing
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, TypeVar, Union
from typing import List, Optional, Tuple, TypeVar, Union, overload

# In practice all of the below will be the same type
TTensor = TypeVar('TTensor')
TAbstractTensor = TypeVar('TAbstractTensor')
TTensorRetrieval = TypeVar('TTensorRetrieval')
TTensorMetrics = TypeVar('TTensorMetrics')


class AbstractComputationalBackend(ABC, typing.Generic[TTensor]):
class AbstractComputationalBackend(ABC, typing.Generic[TTensor, TAbstractTensor]):
"""
Abstract base class for computational backends.
Every supported tensor/ML framework (numpy, torch etc.) should define its own
Expand Down Expand Up @@ -48,6 +49,53 @@ def to_device(tensor: 'TTensor', device: str) -> 'TTensor':
"""Move the tensor to the specified device."""
...

@staticmethod
@abstractmethod
def shape(tensor: 'TTensor') -> Tuple[int, ...]:
"""
Get the shape of a tensor.

:param tensor: tensor whose shape is requested
:return: the size of each dimension of the tensor, as a tuple of ints
"""
...

# Typing-only overload: a 'TAbstractTensor' input yields a 'TAbstractTensor'.
@overload
@staticmethod
def reshape(tensor: 'TAbstractTensor', shape: Tuple[int, ...]) -> 'TAbstractTensor':
"""
Gives a new shape to tensor without changing its data.

:param tensor: tensor to be reshaped
:param shape: the new shape
:return: a tensor with the same data and number of elements as tensor
but with the specified shape.
"""
...

# Typing-only overload: a 'TTensor' input yields a 'TTensor'.
@overload
@staticmethod
def reshape(tensor: 'TTensor', shape: Tuple[int, ...]) -> 'TTensor':
"""
Gives a new shape to tensor without changing its data.

:param tensor: tensor to be reshaped
:param shape: the new shape
:return: a tensor with the same data and number of elements as tensor
but with the specified shape.
"""
...

# Abstract declaration covering both overloads; it has no body here and is
# marked @abstractmethod, so each backend supplies the actual reshape.
@staticmethod
@abstractmethod
def reshape(
tensor: Union['TTensor', 'TAbstractTensor'], shape: Tuple[int, ...]
) -> Union['TTensor', 'TAbstractTensor']:
"""
Gives a new shape to tensor without changing its data.

:param tensor: tensor to be reshaped
:param shape: the new shape
:return: a tensor with the same data and number of elements as tensor
but with the specified shape.
"""
...

class Retrieval(ABC, typing.Generic[TTensorRetrieval]):
"""
Abstract class for retrieval and ranking functionalities
Expand Down
47 changes: 46 additions & 1 deletion docarray/computation/numpy_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def _expand_if_scalar(arr: np.ndarray) -> np.ndarray:
return arr


class NumpyCompBackend(AbstractComputationalBackend[np.ndarray]):
class NumpyCompBackend(AbstractComputationalBackend[np.ndarray, NdArray]):
"""
Computational backend for Numpy.
"""
Expand Down Expand Up @@ -73,6 +73,51 @@ def none_value() -> Any:
"""Provide a compatible value that represents None in numpy."""
return None

@staticmethod
def shape(array: 'np.ndarray') -> Tuple[int, ...]:
"""
Get the shape of an array.

:param array: array whose shape is requested
:return: the `shape` attribute of the array
"""
return array.shape

# Typing-only overload: an 'NdArray' input yields an 'NdArray'.
@overload
@staticmethod
def reshape(array: 'NdArray', shape: Tuple[int, ...]) -> 'NdArray':
"""
Gives a new shape to array without changing its data.

:param array: array to be reshaped
:param shape: the new shape
:return: an array with the same data and number of elements as array
but with the specified shape.
"""
...

# Typing-only overload: an 'np.ndarray' input yields an 'np.ndarray'.
@overload
@staticmethod
def reshape(array: 'np.ndarray', shape: Tuple[int, ...]) -> 'np.ndarray':
"""
Gives a new shape to array without changing its data.

:param array: array to be reshaped
:param shape: the new shape
:return: an array with the same data and number of elements as array
but with the specified shape.
"""
...

# Runtime implementation shared by both overloads.
@staticmethod
def reshape(
array: Union['np.ndarray', 'NdArray'], shape: Tuple[int, ...]
) -> Union['np.ndarray', 'NdArray']:
"""
Gives a new shape to array without changing its data.

:param array: array to be reshaped
:param shape: the new shape
:return: an array with the same data and number of elements as array
but with the specified shape.
"""
# delegates to the array's own `reshape` method
return array.reshape(shape)

class Retrieval(AbstractComputationalBackend.Retrieval[np.ndarray]):
"""
Abstract class for retrieval and ranking functionalities
Expand Down
46 changes: 45 additions & 1 deletion docarray/computation/torch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def _usqueeze_if_scalar(t: torch.Tensor):
return t


class TorchCompBackend(AbstractComputationalBackend[torch.Tensor]):
class TorchCompBackend(AbstractComputationalBackend[torch.Tensor, 'TorchTensor']):
"""
Computational backend for PyTorch.
"""
Expand Down Expand Up @@ -73,6 +73,50 @@ def none_value() -> Any:
"""Provide a compatible value that represents None in torch."""
return torch.tensor(float('nan'))

@staticmethod
def shape(tensor: 'torch.Tensor') -> Tuple[int, ...]:
"""
Get the shape of a tensor.

:param tensor: tensor whose shape is requested
:return: the tensor's `shape` attribute converted to a plain tuple
"""
# `tuple(...)` makes the result a plain tuple regardless of the type of `tensor.shape`
return tuple(tensor.shape)

# Typing-only overload: a 'TorchTensor' input yields a 'TorchTensor'.
@overload
@staticmethod
def reshape(tensor: 'TorchTensor', shape: Tuple[int, ...]) -> 'TorchTensor':
"""
Gives a new shape to tensor without changing its data.

:param tensor: tensor to be reshaped
:param shape: the new shape
:return: a tensor with the same data and number of elements as tensor
but with the specified shape.
"""
...

@overload
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be nice to automatically generate the overload. But let's not do it in this PR unless you find a brilliant solution haha

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it works :( python/mypy#11488

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but we could generate the code though

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but yeah let's see later

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nah not code generation via our friend the bot.

But we could have a helper function that generates 3 methods with the correct overload

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could it be done via decorator somehow?

# Signature variant: a 'torch.Tensor' input yields a 'torch.Tensor'; the body
# is only `...`, so no reshaping happens in this declaration.
@staticmethod
def reshape(tensor: 'torch.Tensor', shape: Tuple[int, ...]) -> 'torch.Tensor':
"""
Gives a new shape to tensor without changing its data.

:param tensor: tensor to be reshaped
:param shape: the new shape
:return: a tensor with the same data and number of elements as tensor
but with the specified shape.
"""
...

# Runtime implementation that accepts either tensor flavour.
@staticmethod
def reshape(
tensor: Union['torch.Tensor', 'TorchTensor'], shape: Tuple[int, ...]
) -> Union['torch.Tensor', 'TorchTensor']:
"""
Gives a new shape to tensor without changing its data.

:param tensor: tensor to be reshaped
:param shape: the new shape
:return: a tensor with the same data and number of elements as tensor
but with the specified shape.
"""
# delegates to the tensor's own `reshape` method
# NOTE(review): presumably raises RuntimeError when `shape` is incompatible
# with the number of elements -- confirm against the torch documentation.
return tensor.reshape(shape)

class Retrieval(AbstractComputationalBackend.Retrieval[torch.Tensor]):
"""
Abstract class for retrieval and ranking functionalities
Expand Down
63 changes: 54 additions & 9 deletions docarray/typing/tensor/abstract_tensor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,18 @@
import abc
import warnings
from abc import ABC
from typing import TYPE_CHECKING, Any, Generic, List, Tuple, Type, TypeVar, Union
from typing import (
TYPE_CHECKING,
Any,
Dict,
Generic,
List,
Tuple,
Type,
TypeVar,
Union,
cast,
)

from docarray.computation import AbstractComputationalBackend
from docarray.typing.abstract_type import AbstractType
Expand All @@ -12,6 +24,7 @@
from docarray.proto import NdArrayProto

T = TypeVar('T', bound='AbstractTensor')
TTensor = TypeVar('TTensor')
ShapeT = TypeVar('ShapeT')


Expand Down Expand Up @@ -54,32 +67,64 @@ def __instancecheck__(cls, instance):
return super().__instancecheck__(instance)


class AbstractTensor(Generic[ShapeT], AbstractType, ABC):
class AbstractTensor(Generic[TTensor, T], AbstractType, ABC):

__parametrized_meta__: type = _ParametrizedMeta
_PROTO_FIELD_NAME: str

@classmethod
@abc.abstractmethod
def __docarray_validate_shape__(cls, t: T, shape: Tuple[Union[int, str]]) -> T:
    """Validate the shape of `t` against `shape`, reshaping it if possible.

    Every tensor has to provide this method in order to
    enable syntax of the form AnyTensor[shape].

    It is called when a tensor is assigned to a field of this type.
    i.e. when a tensor is passed to a Document field of type AnyTensor[shape].

    The behaviour is as follows:
    - If the shape of `t` is equal to `shape`, return `t`.
    - If `shape` contains named dimensions (strings, e.g. `(3, 'x', 'x')`),
      `t` is never reshaped: it must have the same number of dimensions,
      every integer dimension must match exactly, and all dimensions that
      share a name must have the same size. `t` is returned unchanged if
      these checks pass.
    - Otherwise, a warning is issued and `t` is reshaped to `shape` using
      the computational backend of `t`.

    :param t: The tensor to validate.
    :param shape: The shape to validate against. Entries are either ints
        (exact sizes) or strings (named dimensions).
    :return: The validated tensor.
    :raises ValueError: if `t` does not satisfy a shape with named
        dimensions, or if it cannot be reshaped to `shape`.
    """
    comp_be = t.get_comp_backend()()  # mypy Generics require instantiation
    tshape = comp_be.shape(t)
    if tshape == shape:
        return t

    if any(isinstance(dim, str) for dim in shape):
        mismatch_msg = f'Tensor shape mismatch. Expected {shape}, got {tshape}'
        if len(tshape) != len(shape):
            raise ValueError(mismatch_msg)
        # size first seen for each named dimension
        known_dims: Dict[str, int] = {}
        for tdim, dim in zip(tshape, shape):
            if isinstance(dim, int) and tdim != dim:
                raise ValueError(mismatch_msg)
            elif isinstance(dim, str):
                # setdefault records the first size and returns it afterwards,
                # so a later, different size for the same name is detected
                if known_dims.setdefault(dim, tdim) != tdim:
                    raise ValueError(mismatch_msg)
        return t

    shape = cast(Tuple[int], shape)
    warnings.warn(
        f'Tensor shape mismatch. Reshaping tensor '
        f'of shape {tshape} to shape {shape}'
    )
    try:
        reshaped = comp_be.reshape(t, shape)
    except (RuntimeError, ValueError) as err:
        # Backends signal an impossible reshape differently (torch raises
        # RuntimeError, numpy raises ValueError); normalize both to one error.
        raise ValueError(
            f'Cannot reshape tensor of shape {tshape} to shape {shape}'
        ) from err
    return cast(T, cls._docarray_from_native(reshaped))

@classmethod
def __docarray_validate_getitem__(cls, item: Any) -> Tuple[int]:
Expand Down Expand Up @@ -157,7 +202,7 @@ def _docarray_from_native(cls: Type[T], value: Any) -> T:

@staticmethod
@abc.abstractmethod
def get_comp_backend() -> Type[AbstractComputationalBackend]:
def get_comp_backend() -> Type[AbstractComputationalBackend[TTensor, T]]:
"""
The computational backend compatible with this tensor type.

:return: the backend class itself (a `Type[...]`, not an instance)
"""
...

Expand Down
22 changes: 4 additions & 18 deletions docarray/typing/tensor/ndarray.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import warnings
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -57,26 +56,30 @@ class NdArray(np.ndarray, AbstractTensor, Generic[ShapeT]):
class MyDoc(BaseDocument):
arr: NdArray
image_arr: NdArray[3, 224, 224]
square_crop: NdArray[3, 'x', 'x']


# create a document with tensors
doc = MyDoc(
arr=np.zeros((128,)),
image_arr=np.zeros((3, 224, 224)),
square_crop=np.zeros((3, 64, 64)),
)
assert doc.image_arr.shape == (3, 224, 224)

# automatic shape conversion
doc = MyDoc(
arr=np.zeros((128,)),
image_arr=np.zeros((224, 224, 3)), # will reshape to (3, 224, 224)
square_crop=np.zeros((3, 128, 128)),
)
assert doc.image_arr.shape == (3, 224, 224)

# !! The following will raise an error due to shape mismatch !!
doc = MyDoc(
arr=np.zeros((128,)),
image_arr=np.zeros((224, 224)), # this will fail validation
square_crop=np.zeros((3, 128, 64)), # this will also fail validation
)
"""

Expand All @@ -90,23 +93,6 @@ def __get_validators__(cls):
# the value returned from the previous validator
yield cls.validate

@classmethod
def __docarray_validate_shape__(cls, t: T, shape: Tuple[int]) -> T: # type: ignore
"""
Validate the shape of `t` against `shape`.

Returns `t` unchanged when `t.shape == shape`. Otherwise a warning is
issued and `t` is reshaped with `np.reshape`; the result is wrapped with
`cls._docarray_from_native`.

:param t: the array to validate
:param shape: the shape to validate against
:return: `t` itself, or the reshaped array
:raises ValueError: from the `except RuntimeError` handler below
"""
if t.shape == shape:
return t
else:
warnings.warn(
f'Tensor shape mismatch. Reshaping array '
f'of shape {t.shape} to shape {shape}'
)
# NOTE(review): numpy documents ValueError (not RuntimeError) for an
# incompatible reshape, so this handler presumably never fires and numpy's
# own ValueError propagates instead -- confirm.
try:
value = cls._docarray_from_native(np.reshape(t, shape))
return cast(T, value)
except RuntimeError:
raise ValueError(
f'Cannot reshape array of shape {t.shape} to shape {shape}'
)

@classmethod
def validate(
cls: Type[T],
Expand Down
Loading