Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 33 additions & 49 deletions docarray/array/doc_list/doc_list.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import io
from functools import wraps
from typing import (
TYPE_CHECKING,
Any,
Callable,
Iterable,
List,
MutableSequence,
Expand All @@ -15,15 +13,13 @@
overload,
)

from typing_extensions import SupportsIndex
from typing_inspect import is_union_type

from docarray.array.any_array import AnyDocArray
from docarray.array.doc_list.io import IOMixinArray
from docarray.array.doc_list.pushpull import PushPullMixin
from docarray.array.doc_list.sequence_indexing_mixin import (
IndexingSequenceMixin,
IndexIterType,
)
from docarray.array.list_advance_indexing import IndexIterType, ListAdvancedIndexing
from docarray.base_doc import AnyDoc, BaseDoc
from docarray.typing import NdArray

Expand All @@ -40,25 +36,11 @@
T_doc = TypeVar('T_doc', bound=BaseDoc)


def _delegate_meth_to_data(meth_name: str) -> Callable:
"""
Create a method that forwards a call to the same-named method of the
`_data` attribute of the DocList.

:param meth_name: name of the `list` method to delegate to
:return: a method that calls `meth_name` on `self._data` with the given
arguments and returns its result
"""
# Looked up on `list` so that `wraps` can copy the builtin method's metadata.
func = getattr(list, meth_name)

@wraps(func)
def _delegate_meth(self, *args, **kwargs):
return getattr(self._data, meth_name)(*args, **kwargs)

return _delegate_meth


class DocList(
IndexingSequenceMixin[T_doc], PushPullMixin, IOMixinArray, AnyDocArray[T_doc]
ListAdvancedIndexing[T_doc],
PushPullMixin,
IOMixinArray,
AnyDocArray[T_doc],
):
"""
DocList is a container of Documents.
Expand Down Expand Up @@ -129,8 +111,13 @@ class Image(BaseDoc):
def __init__(
self,
docs: Optional[Iterable[T_doc]] = None,
validate_input_docs: bool = True,
):
self._data: List[T_doc] = list(self._validate_docs(docs)) if docs else []
if validate_input_docs:
docs = self._validate_docs(docs) if docs else []
else:
docs = docs if docs else []
super().__init__(docs)

@classmethod
def construct(
Expand All @@ -143,9 +130,7 @@ def construct(
:param docs: a Sequence (list) of Document with the same schema
:return: a `DocList` object
"""
new_docs = cls.__new__(cls)
new_docs._data = docs if isinstance(docs, list) else list(docs)
return new_docs
return cls(docs, False)

def __eq__(self, other: Any) -> bool:
if self.__len__() != other.__len__():
Expand All @@ -168,12 +153,6 @@ def _validate_one_doc(self, doc: T_doc) -> T_doc:
raise ValueError(f'{doc} is not a {self.doc_type}')
return doc

def __len__(self):
return len(self._data)

def __iter__(self):
return iter(self._data)

def __bytes__(self) -> bytes:
with io.BytesIO() as bf:
self._write_bytes(bf=bf)
Expand All @@ -185,7 +164,7 @@ def append(self, doc: T_doc):
as the `.doc_type` of this `DocList` otherwise it will fail.
:param doc: A Document
"""
self._data.append(self._validate_one_doc(doc))
super().append(self._validate_one_doc(doc))

def extend(self, docs: Iterable[T_doc]):
"""
Expand All @@ -194,31 +173,28 @@ def extend(self, docs: Iterable[T_doc]):
fail.
:param docs: Iterable of Documents
"""
self._data.extend(self._validate_docs(docs))
super().extend(self._validate_docs(docs))

def insert(self, i: int, doc: T_doc):
def insert(self, i: SupportsIndex, doc: T_doc):
"""
Insert a Document to the `DocList`. The Document must be from the same
class as the doc_type of this `DocList` otherwise it will fail.
:param i: index to insert
:param doc: A Document
"""
self._data.insert(i, self._validate_one_doc(doc))

pop = _delegate_meth_to_data('pop')
remove = _delegate_meth_to_data('remove')
reverse = _delegate_meth_to_data('reverse')
sort = _delegate_meth_to_data('sort')
super().insert(i, self._validate_one_doc(doc))

def _get_data_column(
self: T,
field: str,
) -> Union[MutableSequence, T, 'TorchTensor', 'NdArray']:
"""Return all values of the fields from all docs this doc_list contains

:param field: name of the fields to extract
:return: Returns a list of the field value for each document
in the doc_list like container
"""Return all values of the fields from all docs this doc_list contains

:param field: name of the fields to extract
:return: Returns a list of the field value for each document
in the doc_list like container
Comment on lines +191 to +197
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note for myself to correct this docstring in my fastapi PR

"""
field_type = self.__class__.doc_type._get_field_type(field)

Expand Down Expand Up @@ -299,7 +275,7 @@ def from_protobuf(cls: Type[T], pb_msg: 'DocListProto') -> T:
return super().from_protobuf(pb_msg)

@overload
def __getitem__(self, item: int) -> T_doc:
def __getitem__(self, item: SupportsIndex) -> T_doc:
...

@overload
Expand All @@ -308,3 +284,11 @@ def __getitem__(self: T, item: IndexIterType) -> T:

def __getitem__(self, item):
return super().__getitem__(item)

@classmethod
def __class_getitem__(cls, item: Union[Type[BaseDoc], TypeVar, str]):
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd rather have this redefined at different inheritance levels

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure I understand what it would look like.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't quite understand the need for this. Which superclass has another `__class_getitem__`?


if isinstance(item, type) and issubclass(item, BaseDoc):
return AnyDocArray.__class_getitem__.__func__(cls, item) # type: ignore
else:
return super().__class_getitem__(item)
10 changes: 1 addition & 9 deletions docarray/array/doc_list/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ def __getitem__(self, item: slice):

class IOMixinArray(Iterable[T_doc]):
doc_type: Type[T_doc]
_data: List[T_doc]

@abstractmethod
def __len__(self):
Expand Down Expand Up @@ -327,14 +326,7 @@ def to_json(self) -> bytes:
"""Convert the object into JSON bytes. Can be loaded via `.from_json`.
:return: JSON serialization of `DocList`
"""
return orjson_dumps(self._data)

def _docarray_to_json_compatible(self) -> List[T_doc]:
"""
Convert itself into a json compatible object
:return: A list of documents
"""
return self._data
return orjson_dumps(self)

@classmethod
def from_csv(
Expand Down
2 changes: 1 addition & 1 deletion docarray/array/doc_vec/column_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
Union,
)

from docarray.array.doc_vec.list_advance_indexing import ListAdvancedIndexing
from docarray.array.list_advance_indexing import ListAdvancedIndexing
from docarray.typing import NdArray
from docarray.typing.tensor.abstract_tensor import AbstractTensor

Expand Down
6 changes: 3 additions & 3 deletions docarray/array/doc_vec/doc_vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from docarray.array.any_array import AnyDocArray
from docarray.array.doc_list.doc_list import DocList
from docarray.array.doc_vec.column_storage import ColumnStorage, ColumnStorageView
from docarray.array.doc_vec.list_advance_indexing import ListAdvancedIndexing
from docarray.array.list_advance_indexing import ListAdvancedIndexing
from docarray.base_doc import BaseDoc
from docarray.base_doc.mixins.io import _type_to_protobuf
from docarray.typing import NdArray
Expand Down Expand Up @@ -271,9 +271,9 @@ def _get_data_column(
in the array like container
"""
if field in self._storage.any_columns.keys():
return self._storage.any_columns[field].data
return self._storage.any_columns[field]
elif field in self._storage.docs_vec_columns.keys():
return self._storage.docs_vec_columns[field].data
return self._storage.docs_vec_columns[field]
elif field in self._storage.columns.keys():
return self._storage.columns[field]
else:
Expand Down
41 changes: 0 additions & 41 deletions docarray/array/doc_vec/list_advance_indexing.py

This file was deleted.

Loading