diff --git a/docarray/array/doc_list/doc_list.py b/docarray/array/doc_list/doc_list.py index 9d1ca90a916..a9deba3bec6 100644 --- a/docarray/array/doc_list/doc_list.py +++ b/docarray/array/doc_list/doc_list.py @@ -1,9 +1,7 @@ import io -from functools import wraps from typing import ( TYPE_CHECKING, Any, - Callable, Iterable, List, MutableSequence, @@ -15,15 +13,13 @@ overload, ) +from typing_extensions import SupportsIndex from typing_inspect import is_union_type from docarray.array.any_array import AnyDocArray from docarray.array.doc_list.io import IOMixinArray from docarray.array.doc_list.pushpull import PushPullMixin -from docarray.array.doc_list.sequence_indexing_mixin import ( - IndexingSequenceMixin, - IndexIterType, -) +from docarray.array.list_advance_indexing import IndexIterType, ListAdvancedIndexing from docarray.base_doc import AnyDoc, BaseDoc from docarray.typing import NdArray @@ -40,25 +36,11 @@ T_doc = TypeVar('T_doc', bound=BaseDoc) -def _delegate_meth_to_data(meth_name: str) -> Callable: - """ - create a function that mimic a function call to the data attribute of the - DocList - - :param meth_name: name of the method - :return: a method that mimic the meth_name - """ - func = getattr(list, meth_name) - - @wraps(func) - def _delegate_meth(self, *args, **kwargs): - return getattr(self._data, meth_name)(*args, **kwargs) - - return _delegate_meth - - class DocList( - IndexingSequenceMixin[T_doc], PushPullMixin, IOMixinArray, AnyDocArray[T_doc] + ListAdvancedIndexing[T_doc], + PushPullMixin, + IOMixinArray, + AnyDocArray[T_doc], ): """ DocList is a container of Documents. 
@@ -129,8 +111,13 @@ class Image(BaseDoc): def __init__( self, docs: Optional[Iterable[T_doc]] = None, + validate_input_docs: bool = True, ): - self._data: List[T_doc] = list(self._validate_docs(docs)) if docs else [] + if validate_input_docs: + docs = self._validate_docs(docs) if docs else [] + else: + docs = docs if docs else [] + super().__init__(docs) @classmethod def construct( @@ -143,9 +130,7 @@ def construct( :param docs: a Sequence (list) of Document with the same schema :return: a `DocList` object """ - new_docs = cls.__new__(cls) - new_docs._data = docs if isinstance(docs, list) else list(docs) - return new_docs + return cls(docs, False) def __eq__(self, other: Any) -> bool: if self.__len__() != other.__len__(): @@ -168,12 +153,6 @@ def _validate_one_doc(self, doc: T_doc) -> T_doc: raise ValueError(f'{doc} is not a {self.doc_type}') return doc - def __len__(self): - return len(self._data) - - def __iter__(self): - return iter(self._data) - def __bytes__(self) -> bytes: with io.BytesIO() as bf: self._write_bytes(bf=bf) @@ -185,7 +164,7 @@ def append(self, doc: T_doc): as the `.doc_type` of this `DocList` otherwise it will fail. :param doc: A Document """ - self._data.append(self._validate_one_doc(doc)) + super().append(self._validate_one_doc(doc)) def extend(self, docs: Iterable[T_doc]): """ @@ -194,31 +173,28 @@ def extend(self, docs: Iterable[T_doc]): fail. :param docs: Iterable of Documents """ - self._data.extend(self._validate_docs(docs)) + super().extend(self._validate_docs(docs)) - def insert(self, i: int, doc: T_doc): + def insert(self, i: SupportsIndex, doc: T_doc): """ Insert a Document to the `DocList`. The Document must be from the same class as the doc_type of this `DocList` otherwise it will fail. 
:param i: index to insert :param doc: A Document """ - self._data.insert(i, self._validate_one_doc(doc)) - - pop = _delegate_meth_to_data('pop') - remove = _delegate_meth_to_data('remove') - reverse = _delegate_meth_to_data('reverse') - sort = _delegate_meth_to_data('sort') + super().insert(i, self._validate_one_doc(doc)) def _get_data_column( self: T, field: str, ) -> Union[MutableSequence, T, 'TorchTensor', 'NdArray']: """Return all values of the fields from all docs this doc_list contains :param field: name of the fields to extract :return: Returns a list of the field value for each document in the doc_list like container """ field_type = self.__class__.doc_type._get_field_type(field) @@ -299,7 +275,7 @@ def from_protobuf(cls: Type[T], pb_msg: 'DocListProto') -> T: return super().from_protobuf(pb_msg) @overload - def __getitem__(self, item: int) -> T_doc: + def __getitem__(self, item: SupportsIndex) -> T_doc: ... 
@overload @@ -308,3 +284,11 @@ def __getitem__(self: T, item: IndexIterType) -> T: def __getitem__(self, item): return super().__getitem__(item) + + @classmethod + def __class_getitem__(cls, item: Union[Type[BaseDoc], TypeVar, str]): + + if isinstance(item, type) and issubclass(item, BaseDoc): + return AnyDocArray.__class_getitem__.__func__(cls, item) # type: ignore + else: + return super().__class_getitem__(item) diff --git a/docarray/array/doc_list/io.py b/docarray/array/doc_list/io.py index fdad272b94c..5c2f1c9190a 100644 --- a/docarray/array/doc_list/io.py +++ b/docarray/array/doc_list/io.py @@ -99,7 +99,6 @@ def __getitem__(self, item: slice): class IOMixinArray(Iterable[T_doc]): doc_type: Type[T_doc] - _data: List[T_doc] @abstractmethod def __len__(self): @@ -327,14 +326,7 @@ def to_json(self) -> bytes: """Convert the object into JSON bytes. Can be loaded via `.from_json`. :return: JSON serialization of `DocList` """ - return orjson_dumps(self._data) - - def _docarray_to_json_compatible(self) -> List[T_doc]: - """ - Convert itself into a json compatible object - :return: A list of documents - """ - return self._data + return orjson_dumps(self) @classmethod def from_csv( diff --git a/docarray/array/doc_vec/column_storage.py b/docarray/array/doc_vec/column_storage.py index 42c67c96b3b..736b4114b16 100644 --- a/docarray/array/doc_vec/column_storage.py +++ b/docarray/array/doc_vec/column_storage.py @@ -10,7 +10,7 @@ Union, ) -from docarray.array.doc_vec.list_advance_indexing import ListAdvancedIndexing +from docarray.array.list_advance_indexing import ListAdvancedIndexing from docarray.typing import NdArray from docarray.typing.tensor.abstract_tensor import AbstractTensor diff --git a/docarray/array/doc_vec/doc_vec.py b/docarray/array/doc_vec/doc_vec.py index c7c94b393dd..101fe5b93e3 100644 --- a/docarray/array/doc_vec/doc_vec.py +++ b/docarray/array/doc_vec/doc_vec.py @@ -21,7 +21,7 @@ from docarray.array.any_array import AnyDocArray from 
docarray.array.doc_list.doc_list import DocList from docarray.array.doc_vec.column_storage import ColumnStorage, ColumnStorageView -from docarray.array.doc_vec.list_advance_indexing import ListAdvancedIndexing +from docarray.array.list_advance_indexing import ListAdvancedIndexing from docarray.base_doc import BaseDoc from docarray.base_doc.mixins.io import _type_to_protobuf from docarray.typing import NdArray @@ -271,9 +271,9 @@ def _get_data_column( in the array like container """ if field in self._storage.any_columns.keys(): - return self._storage.any_columns[field].data + return self._storage.any_columns[field] elif field in self._storage.docs_vec_columns.keys(): - return self._storage.docs_vec_columns[field].data + return self._storage.docs_vec_columns[field] elif field in self._storage.columns.keys(): return self._storage.columns[field] else: diff --git a/docarray/array/doc_vec/list_advance_indexing.py b/docarray/array/doc_vec/list_advance_indexing.py deleted file mode 100644 index bc5c07d9c83..00000000000 --- a/docarray/array/doc_vec/list_advance_indexing.py +++ /dev/null @@ -1,41 +0,0 @@ -from typing import Iterator, MutableSequence, TypeVar - -from docarray.array.doc_list.sequence_indexing_mixin import IndexingSequenceMixin - -T_item = TypeVar('T_item') - - -class ListAdvancedIndexing(IndexingSequenceMixin[T_item]): - """ - A list wrapper that implements custom indexing - - You can index into a ListAdvanceIndex like a numpy array or torch tensor: - - --- - - ```python - docs[0] # index by position - docs[0:5:2] # index by slice - docs[[0, 2, 3]] # index by list of indices - docs[True, False, True, True, ...] 
# index by boolean mask - ``` - - --- - - """ - - _data: MutableSequence[T_item] - - def __init__(self, data: MutableSequence[T_item]): - self._data = data - - @property - def data(self) -> MutableSequence[T_item]: - return self._data - - def __len__(self) -> int: - return len(self._data) - - def __iter__(self) -> Iterator[T_item]: - for item in self._data: - yield item diff --git a/docarray/array/doc_list/sequence_indexing_mixin.py b/docarray/array/list_advance_indexing.py similarity index 81% rename from docarray/array/doc_list/sequence_indexing_mixin.py rename to docarray/array/list_advance_indexing.py index 8513c82bee0..bcf966e6454 100644 --- a/docarray/array/doc_list/sequence_indexing_mixin.py +++ b/docarray/array/list_advance_indexing.py @@ -1,10 +1,8 @@ -import abc from typing import ( TYPE_CHECKING, Any, Iterable, - MutableSequence, - Optional, + List, Sequence, TypeVar, Union, @@ -14,11 +12,12 @@ ) import numpy as np +from typing_extensions import SupportsIndex from docarray.utils._internal.misc import import_library T_item = TypeVar('T_item') -T = TypeVar('T', bound='IndexingSequenceMixin') +T = TypeVar('T', bound='ListAdvancedIndexing') IndexIterType = Union[slice, Iterable[int], Iterable[bool], None] @@ -34,12 +33,11 @@ def _is_np_int(item: Any) -> bool: return False # this is unreachable, but mypy wants it -class IndexingSequenceMixin(Iterable[T_item]): +class ListAdvancedIndexing(List[T_item]): """ - This mixin allow sto extend a list into an object that can be indexed - a la numpy/pytorch. + A list wrapper that implements custom indexing - You can index into, delete from, and set items in a IndexingSequenceMixin like a numpy doc_list or torch tensor: + You can index into a ListAdvanceIndex like a numpy array or torch tensor: --- @@ -54,19 +52,6 @@ class IndexingSequenceMixin(Iterable[T_item]): """ - _data: MutableSequence[T_item] - - @abc.abstractmethod - def __init__( - self, - docs: Optional[Iterable[T_item]] = None, - ): - ... 
- - @abc.abstractmethod - def __len__(self) -> int: - ... - @staticmethod def _normalize_index_item( item: Any, @@ -107,13 +92,13 @@ def _normalize_index_item( def _get_from_indices(self: T, item: Iterable[int]) -> T: results = [] for ix in item: - results.append(self._data[ix]) + results.append(self[ix]) return self.__class__(results) def _set_by_indices(self: T, item: Iterable[int], value: Iterable[T_item]): for ix, doc_to_set in zip(item, value): try: - self._data[ix] = doc_to_set + self[ix] = doc_to_set except KeyError: raise IndexError(f'Index {ix} is out of range') @@ -126,7 +111,7 @@ def _set_by_mask(self: T, item: Iterable[bool], value: Sequence[T_item]): i_value = 0 for i, mask_value in zip(range(len(self)), item): if mask_value: - self._data[i] = value[i_value] + self[i] = value[i_value] i_value += 1 def _del_from_mask(self: T, item: Iterable[bool]) -> None: @@ -137,15 +122,15 @@ def _del_from_indices(self: T, item: Iterable[int]) -> None: for ix in sorted(item, reverse=True): # reversed is needed here otherwise some the indices are not up to date after # each delete - del self._data[ix] + del self[ix] - def __delitem__(self, key: Union[int, IndexIterType]) -> None: + def __delitem__(self, key: Union[SupportsIndex, IndexIterType]) -> None: item = self._normalize_index_item(key) if item is None: return elif isinstance(item, (int, slice)): - del self._data[item] + super().__delitem__(item) else: head = item[0] # type: ignore if isinstance(head, bool): @@ -157,7 +142,7 @@ def __delitem__(self, key: Union[int, IndexIterType]) -> None: raise TypeError(f'Invalid type {type(head)} for indexing') @overload - def __getitem__(self: T, item: int) -> T_item: + def __getitem__(self: T, item: SupportsIndex) -> T_item: ... 
@overload @@ -168,10 +153,10 @@ def __getitem__(self, item): item = self._normalize_index_item(item) if type(item) == slice: - return self.__class__(self._data[item]) + return self.__class__(super().__getitem__(item)) if isinstance(item, int): - return self._data[item] + return super().__getitem__(item) if item is None: return self @@ -186,11 +171,11 @@ def __getitem__(self, item): raise TypeError(f'Invalid type {type(head)} for indexing') @overload - def __setitem__(self: T, key: IndexIterType, value: Sequence[T_item]): + def __setitem__(self: T, key: SupportsIndex, value: T_item) -> None: ... @overload - def __setitem__(self: T, key: int, value: T_item): + def __setitem__(self: T, key: IndexIterType, value: Iterable[T_item]): ... @no_type_check @@ -198,9 +183,9 @@ def __setitem__(self: T, key, value): key_norm = self._normalize_index_item(key) if isinstance(key_norm, int): - self._data[key_norm] = value + super().__setitem__(key_norm, value) elif isinstance(key_norm, slice): - self._data[key_norm] = value + super().__setitem__(key_norm, value) else: # _normalize_index_item() guarantees the line below is correct head = key_norm[0] diff --git a/docarray/index/abstract.py b/docarray/index/abstract.py index 13f4837cd61..7613b393a10 100644 --- a/docarray/index/abstract.py +++ b/docarray/index/abstract.py @@ -421,7 +421,7 @@ def find( query_vec_np, search_field=search_field, limit=limit, **kwargs ) - if isinstance(docs, List): + if isinstance(docs, List) and not isinstance(docs, DocList): docs = self._dict_list_to_docarray(docs) return FindResult(documents=docs, scores=scores) diff --git a/tests/units/array/stack/test_init.py b/tests/units/array/stack/test_init.py index 663eebadf89..6e23835b560 100644 --- a/tests/units/array/stack/test_init.py +++ b/tests/units/array/stack/test_init.py @@ -15,7 +15,7 @@ class MyDoc(BaseDoc): da = DocVec[MyDoc](docs, tensor_type=NdArray) assert (da._storage.tensor_columns['tensor'] == np.zeros((4, 10))).all() - assert 
da._storage.any_columns['name']._data == ['hello' for _ in range(4)] + assert da._storage.any_columns['name'] == ['hello' for _ in range(4)] def test_da_iter(): diff --git a/tests/units/array/test_array.py b/tests/units/array/test_array.py index 79d50b64e82..316baa26d38 100644 --- a/tests/units/array/test_array.py +++ b/tests/units/array/test_array.py @@ -24,7 +24,7 @@ class Text(BaseDoc): def test_iterate(da): - for doc, doc2 in zip(da, da._data): + for doc, doc2 in zip(da, da): assert doc.id == doc2.id @@ -380,11 +380,11 @@ def test_construct(): class Text(BaseDoc): text: str - docs = [Text(text=f'hello {i}') for i in range(10)] + docs = [Text(text=f'hello {i}') for i in range(10)] + [BaseDoc()] da = DocList[Text].construct(docs) - assert da._data is docs + assert type(da[-1]) == BaseDoc def test_reverse():