Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docarray/array/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def _get_array_attribute(
# most likely a bug in mypy though
# bug reported here https://github.com/python/mypy/issues/14111
return self.__class__.__class_getitem__(field_type)(
(getattr(doc, field) for doc in self)
(getattr(doc, field) for doc in self), tensor_type=self.tensor_type
)
else:
return [getattr(doc, field) for doc in self]
Expand Down
111 changes: 54 additions & 57 deletions docarray/array/array_stacked.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from collections import defaultdict
from contextlib import contextmanager
from typing import (
TYPE_CHECKING,
Any,
DefaultDict,
Dict,
Iterable,
List,
Mapping,
Type,
TypeVar,
Union,
Expand All @@ -28,14 +27,11 @@
from docarray.typing import TorchTensor
from docarray.typing.tensor.abstract_tensor import AbstractTensor


try:
import torch
except ImportError:
torch_imported = False
else:
from docarray.typing import TorchTensor

torch_imported = True
except ImportError:
TorchTensor = None # type: ignore

T = TypeVar('T', bound='DocumentArrayStacked')

Expand Down Expand Up @@ -65,7 +61,7 @@ def __init__(
self: T,
docs: DocumentArray,
):
self._columns: Dict[str, Union[T, AbstractTensor]] = {}
self._columns: Dict[str, Union['DocumentArrayStacked', AbstractTensor]] = {}

self.from_document_array(docs)

Expand All @@ -78,7 +74,7 @@ def from_document_array(self: T, docs: DocumentArray):
def _from_columns(
cls: Type[T],
docs: DocumentArray,
columns: Dict[str, Union[T, AbstractTensor]],
columns: Mapping[str, Union['DocumentArrayStacked', AbstractTensor]],
) -> T:
# below __class_getitem__ is called explicitly instead
# of doing DocumentArrayStacked[docs.document_type]
Expand Down Expand Up @@ -112,71 +108,70 @@ def to(self: T, device: str):
col_docarray.to(device)

@classmethod
def _create_columns(
cls: Type[T], docs: DocumentArray, tensor_type: Type['AbstractTensor']
) -> Dict[str, Union[T, AbstractTensor]]:
columns_fields = list()
def _get_columns_schema(
cls: Type[T],
tensor_type: Type[AbstractTensor],
) -> Mapping[str, Union[Type[AbstractTensor], Type[BaseDocument]]]:
"""
Return the schema of the columns to stack: a mapping from the name of each
field of the document type that is a tensor or a document to the type of
its column
:param tensor_type: the default tensor type fallback in case of union of tensor
:return: a mapping from field name to column type, where the type is either a
tensor type (subclass of AbstractTensor) or a document type (subclass of
BaseDocument)
"""

column_schema: Dict[str, Union[Type[AbstractTensor], Type[BaseDocument]]] = {}

for field_name, field in cls.document_type.__fields__.items():
field_type = field.outer_type_
if is_tensor_union(field_type):
columns_fields.append(field_name)
column_schema[field_name] = tensor_type
elif isinstance(field_type, type):
is_torch_subclass = (
issubclass(field_type, torch.Tensor) if torch_imported else False
)
if issubclass(field_type, (BaseDocument, AbstractTensor)):
column_schema[field_name] = field_type

if (
is_torch_subclass
or issubclass(field_type, BaseDocument)
or issubclass(field_type, NdArray)
):
columns_fields.append(field_name)
return column_schema

if not columns_fields:
# nothing to stack
return {}
@classmethod
def _create_columns(
cls: Type[T], docs: DocumentArray, tensor_type: Type[AbstractTensor]
) -> Dict[str, Union['DocumentArrayStacked', AbstractTensor]]:

columns: Dict[str, Union[T, AbstractTensor]] = dict()
if len(docs) == 0:
return {}

columns_to_stack: DefaultDict[
str, Union[List[AbstractTensor], List[BaseDocument]]
] = defaultdict( # type: ignore
list # type: ignore
) # type: ignore
column_schema = cls._get_columns_schema(tensor_type)

for doc in docs:
for field_to_stack in columns_fields:
val = getattr(doc, field_to_stack)
if val is None:
type_ = cls.document_type._get_field_type(field_to_stack)
if is_tensor_union(type_):
val = tensor_type.get_comp_backend().none_value()
columns_to_stack[field_to_stack].append(val)
columns: Dict[str, Union[DocumentArrayStacked, AbstractTensor]] = dict()

for field_to_stack, to_stack in columns_to_stack.items():
for field, type_ in column_schema.items():
if issubclass(type_, AbstractTensor):
tensor = getattr(docs[0], field)
column_shape = (
(len(docs), *tensor.shape) if tensor is not None else (len(docs),)
)
columns[field] = type_.__docarray_from_native__(
type_.get_comp_backend().empty(column_shape)
)

type_ = cls.document_type._get_field_type(field_to_stack)
if is_tensor_union(type_):
columns[field_to_stack] = tensor_type.__docarray_stack__(to_stack) # type: ignore # noqa: E501
elif isinstance(type_, type):
if issubclass(type_, BaseDocument):
columns[field_to_stack] = DocumentArray.__class_getitem__(type_)(
to_stack, tensor_type=tensor_type
).stack()
for i, doc in enumerate(docs):
val = getattr(doc, field)
if val is None:
val = tensor_type.get_comp_backend().none_value()

elif issubclass(type_, AbstractTensor):
columns[field_to_stack] = type_.__docarray_stack__(to_stack) # type: ignore # noqa: E501
cast(AbstractTensor, columns[field])[i] = val
setattr(doc, field, columns[field][i])
del val

for field_name, column in columns.items():
for doc, val in zip(docs, column):
setattr(doc, field_name, val)
elif issubclass(type_, BaseDocument):
columns[field] = getattr(docs, field).stack()

return columns

def _get_array_attribute(
self: T,
field: str,
) -> Union[List, T, AbstractTensor]:
) -> Union[List, 'DocumentArrayStacked', AbstractTensor]:
"""Return all values of the fields from all docs this array contains

:param field: name of the fields to extract
Expand Down Expand Up @@ -327,7 +322,9 @@ def traverse_flat(
nodes = list(AnyDocumentArray._traverse(node=self, access_path=access_path))
flattened = AnyDocumentArray._flatten_one_level(nodes)

if len(flattened) == 1 and isinstance(flattened[0], (NdArray, TorchTensor)):
cls_to_check = (NdArray, TorchTensor) if TorchTensor is not None else (NdArray,)

if len(flattened) == 1 and isinstance(flattened[0], cls_to_check):
return flattened[0]
else:
return flattened
5 changes: 5 additions & 0 deletions docarray/computation/abstract_comp_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ def stack(
def n_dim(array: 'TTensor') -> int:
...

@staticmethod
@abstractmethod
def empty(shape: Tuple[int, ...]) -> 'TTensor':
    """Create a new tensor of the given shape.

    Abstract: each computational backend supplies its own implementation.

    :param shape: the shape of the tensor to create
    :return: a tensor of the backend's native type with the requested shape
    """
    ...

@staticmethod
@abstractmethod
def none_value() -> typing.Any:
Expand Down
4 changes: 4 additions & 0 deletions docarray/computation/numpy_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ def to_device(
def n_dim(array: 'np.ndarray') -> int:
return array.ndim

@staticmethod
def empty(shape: Tuple[int, ...]) -> 'np.ndarray':
    """Allocate a new numpy array of the given shape.

    :param shape: the shape of the array to allocate
    :return: an array of that shape created with ``np.empty``; its entries
        are not initialized, so they must be written before being read
    """
    array = np.empty(shape)
    return array

@staticmethod
def none_value() -> Any:
"""Provide a compatible value that represents None in numpy."""
Expand Down
4 changes: 4 additions & 0 deletions docarray/computation/torch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ def to_device(
) -> Union['torch.Tensor', 'TorchTensor']:
return tensor.to(device)

@staticmethod
def empty(shape: Tuple[int, ...]) -> torch.Tensor:
    """Allocate a new torch tensor of the given shape.

    :param shape: the shape of the tensor to allocate
    :return: a tensor of that shape created with ``torch.empty``; its entries
        are not initialized, so they must be written before being read
    """
    tensor = torch.empty(shape)
    return tensor

@staticmethod
def n_dim(array: 'torch.Tensor') -> int:
    """Return the number of dimensions of a torch tensor.

    :param array: the tensor to inspect
    :return: the value of the tensor's ``ndim`` attribute
    """
    dims = array.ndim
    return dims
Expand Down
4 changes: 4 additions & 0 deletions docarray/typing/tensor/abstract_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,10 @@ def __getitem__(self, item):
"""Get a slice of this tensor."""
...

def __setitem__(self, index, value):
    """Set a slice of this tensor.

    This base definition is a stub with no behavior of its own.

    :param index: the position or slice to assign to
    :param value: the value to store at that position
    """
    # NOTE(review): presumably overridden by the concrete tensor classes -- confirm
    ...

def __iter__(self):
    """Iterate over the elements of this tensor.

    This base definition is a stub with no behavior of its own.
    """
    # NOTE(review): presumably overridden by the concrete tensor classes -- confirm
    ...
Expand Down
2 changes: 1 addition & 1 deletion tests/units/array/test_array_stacked.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ class Doc(BaseDocument):
N = 10

da = DocumentArray[Doc](
(Doc(text=f'hello{i}', tensor=np.zeros((3, 224, 224))) for i in range(N))
[Doc(text=f'hello{i}', tensor=np.zeros((3, 224, 224))) for i in range(N)]
).stack()

da_sliced = da[0:10:2]
Expand Down
5 changes: 5 additions & 0 deletions tests/units/computation_backends/numpy_backend/test_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,8 @@
def test_to_device():
    """Moving a numpy array to a device must raise NotImplementedError."""
    array = np.random.rand(10, 3)
    with pytest.raises(NotImplementedError):
        NumpyCompBackend.to_device(array, 'meta')


def test_empty():
    """NumpyCompBackend.empty returns an array with the requested shape."""
    shape = (10, 3)
    result = NumpyCompBackend.empty(shape)
    assert result.shape == shape
5 changes: 5 additions & 0 deletions tests/units/computation_backends/torch_backend/test_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,8 @@ def test_to_device():
assert t.device == torch.device('cpu')
t = TorchCompBackend.to_device(t, 'meta')
assert t.device == torch.device('meta')


def test_empty():
    """TorchCompBackend.empty returns a tensor with the requested shape."""
    shape = (10, 3)
    result = TorchCompBackend.empty(shape)
    assert result.shape == shape
1 change: 1 addition & 0 deletions tests/units/util/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def test_is_type_tensor(type_, is_tensor):
[
(int, False),
(TorchTensor, False),
(NdArray, False),
(Optional[TorchTensor], True),
(Optional[NdArray], True),
(Union[NdArray, TorchTensor], True),
Expand Down