-
Notifications
You must be signed in to change notification settings - Fork 238
feat: make DocList an actual Python List #1457
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
0f12649
c346709
ba6f5f9
ff5eabe
65b89d5
58dfb04
f42b660
226bfd6
4075136
638d06c
5b13d06
b2f302c
cc2460f
44020b2
1180bf7
1b060e7
c9efc92
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,9 +1,7 @@ | ||
| import io | ||
| from functools import wraps | ||
| from typing import ( | ||
| TYPE_CHECKING, | ||
| Any, | ||
| Callable, | ||
| Iterable, | ||
| List, | ||
| MutableSequence, | ||
|
|
@@ -15,15 +13,13 @@ | |
| overload, | ||
| ) | ||
|
|
||
| from typing_extensions import SupportsIndex | ||
| from typing_inspect import is_union_type | ||
|
|
||
| from docarray.array.any_array import AnyDocArray | ||
| from docarray.array.doc_list.io import IOMixinArray | ||
| from docarray.array.doc_list.pushpull import PushPullMixin | ||
| from docarray.array.doc_list.sequence_indexing_mixin import ( | ||
| IndexingSequenceMixin, | ||
| IndexIterType, | ||
| ) | ||
| from docarray.array.list_advance_indexing import IndexIterType, ListAdvancedIndexing | ||
| from docarray.base_doc import AnyDoc, BaseDoc | ||
| from docarray.typing import NdArray | ||
|
|
||
|
|
@@ -40,25 +36,11 @@ | |
| T_doc = TypeVar('T_doc', bound=BaseDoc) | ||
|
|
||
|
|
||
| def _delegate_meth_to_data(meth_name: str) -> Callable: | ||
| """ | ||
| create a function that mimics a function call to the data attribute of the | ||
| DocList | ||
|
|
||
| :param meth_name: name of the method | ||
| :return: a method that mimics the list method named meth_name | ||
| """ | ||
| func = getattr(list, meth_name) | ||
|
|
||
| @wraps(func) | ||
| def _delegate_meth(self, *args, **kwargs): | ||
| return getattr(self._data, meth_name)(*args, **kwargs) | ||
|
|
||
| return _delegate_meth | ||
|
|
||
|
|
||
| class DocList( | ||
| IndexingSequenceMixin[T_doc], PushPullMixin, IOMixinArray, AnyDocArray[T_doc] | ||
| ListAdvancedIndexing[T_doc], | ||
| PushPullMixin, | ||
| IOMixinArray, | ||
| AnyDocArray[T_doc], | ||
| ): | ||
| """ | ||
| DocList is a container of Documents. | ||
|
|
@@ -129,8 +111,13 @@ class Image(BaseDoc): | |
| def __init__( | ||
| self, | ||
| docs: Optional[Iterable[T_doc]] = None, | ||
| validate_input_docs: bool = True, | ||
| ): | ||
| self._data: List[T_doc] = list(self._validate_docs(docs)) if docs else [] | ||
| if validate_input_docs: | ||
| docs = self._validate_docs(docs) if docs else [] | ||
| else: | ||
| docs = docs if docs else [] | ||
| super().__init__(docs) | ||
|
|
||
| @classmethod | ||
| def construct( | ||
|
|
@@ -143,9 +130,7 @@ def construct( | |
| :param docs: a Sequence (list) of Document with the same schema | ||
| :return: a `DocList` object | ||
| """ | ||
| new_docs = cls.__new__(cls) | ||
| new_docs._data = docs if isinstance(docs, list) else list(docs) | ||
| return new_docs | ||
| return cls(docs, False) | ||
|
|
||
| def __eq__(self, other: Any) -> bool: | ||
| if self.__len__() != other.__len__(): | ||
|
|
@@ -168,12 +153,6 @@ def _validate_one_doc(self, doc: T_doc) -> T_doc: | |
| raise ValueError(f'{doc} is not a {self.doc_type}') | ||
| return doc | ||
|
|
||
| def __len__(self): | ||
| return len(self._data) | ||
|
|
||
| def __iter__(self): | ||
| return iter(self._data) | ||
|
|
||
| def __bytes__(self) -> bytes: | ||
| with io.BytesIO() as bf: | ||
| self._write_bytes(bf=bf) | ||
|
|
@@ -185,7 +164,7 @@ def append(self, doc: T_doc): | |
| as the `.doc_type` of this `DocList` otherwise it will fail. | ||
| :param doc: A Document | ||
| """ | ||
| self._data.append(self._validate_one_doc(doc)) | ||
| super().append(self._validate_one_doc(doc)) | ||
|
|
||
| def extend(self, docs: Iterable[T_doc]): | ||
| """ | ||
|
|
@@ -194,31 +173,28 @@ def extend(self, docs: Iterable[T_doc]): | |
| fail. | ||
| :param docs: Iterable of Documents | ||
| """ | ||
| self._data.extend(self._validate_docs(docs)) | ||
| super().extend(self._validate_docs(docs)) | ||
|
|
||
| def insert(self, i: int, doc: T_doc): | ||
| def insert(self, i: SupportsIndex, doc: T_doc): | ||
| """ | ||
| Insert a Document to the `DocList`. The Document must be from the same | ||
| class as the doc_type of this `DocList` otherwise it will fail. | ||
| :param i: index to insert | ||
| :param doc: A Document | ||
| """ | ||
| self._data.insert(i, self._validate_one_doc(doc)) | ||
|
|
||
| pop = _delegate_meth_to_data('pop') | ||
| remove = _delegate_meth_to_data('remove') | ||
| reverse = _delegate_meth_to_data('reverse') | ||
| sort = _delegate_meth_to_data('sort') | ||
| super().insert(i, self._validate_one_doc(doc)) | ||
|
|
||
| def _get_data_column( | ||
| self: T, | ||
| field: str, | ||
| ) -> Union[MutableSequence, T, 'TorchTensor', 'NdArray']: | ||
| """Return all values of the fields from all docs this doc_list contains | ||
|
|
||
| :param field: name of the fields to extract | ||
| :return: Returns a list of the field value for each document | ||
| in the doc_list like container | ||
| """Return all values of the fields from all docs this doc_list contains | ||
| :param field: name of the fields to extract | ||
| :return: Returns a list of the field value for each document | ||
| in the doc_list like container | ||
| """ | ||
| field_type = self.__class__.doc_type._get_field_type(field) | ||
|
|
||
|
|
@@ -299,7 +275,7 @@ def from_protobuf(cls: Type[T], pb_msg: 'DocListProto') -> T: | |
| return super().from_protobuf(pb_msg) | ||
|
|
||
| @overload | ||
| def __getitem__(self, item: int) -> T_doc: | ||
| def __getitem__(self, item: SupportsIndex) -> T_doc: | ||
| ... | ||
|
|
||
| @overload | ||
|
|
@@ -308,3 +284,11 @@ def __getitem__(self: T, item: IndexIterType) -> T: | |
|
|
||
| def __getitem__(self, item): | ||
| return super().__getitem__(item) | ||
|
|
||
| @classmethod | ||
| def __class_getitem__(cls, item: Union[Type[BaseDoc], TypeVar, str]): | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'd rather have this redefined at different inheritance levels.
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm not sure I understand what that would look like.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I do not quite understand this need? what |
||
|
|
||
| if isinstance(item, type) and issubclass(item, BaseDoc): | ||
| return AnyDocArray.__class_getitem__.__func__(cls, item) # type: ignore | ||
| else: | ||
| return super().__class_getitem__(item) | ||
This file was deleted.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
note for myself to correct this docstring in my fastapi PR