diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 20c8fcdc750..e7c5cd3687c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,23 +1,4 @@ repos: -- repo: https://github.com/terrencepreilly/darglint - rev: v1.5.8 - hooks: - - id: darglint - files: docarray/ - exclude: ^(docarray/proto/docarray_pb2.py|docs/|docarray/resources/) - args: - - --message-template={path}:{line} {msg_id} {msg} - - -s=sphinx - - -z=full - - -v=2 -- repo: https://github.com/pycqa/pydocstyle - rev: 5.1.1 # pick a git hash / tag to point to - hooks: - - id: pydocstyle - files: docarray/ - exclude: ^(docarray/proto/docarray_pb2.py|docs/|docarray/resources/) - args: - - --select=D101,D102,D103 - repo: https://github.com/ambv/black rev: 20.8b1 hooks: diff --git a/docarray/__init__.py b/docarray/__init__.py index 48de1627b3b..ba5f8fa1b7b 100644 --- a/docarray/__init__.py +++ b/docarray/__init__.py @@ -1,4 +1,4 @@ -__version__ = '0.4.5' +__version__ = '0.5.0' from .document import Document from .array import DocumentArray diff --git a/docarray/array/base.py b/docarray/array/base.py new file mode 100644 index 00000000000..dd1721c362b --- /dev/null +++ b/docarray/array/base.py @@ -0,0 +1,9 @@ +from typing import MutableSequence + +from .. import Document + + +class BaseDocumentArray(MutableSequence[Document]): + def __init__(self, *args, storage: str = 'memory', **kwargs): + super().__init__() + self._init_storage(*args, **kwargs) diff --git a/docarray/array/chunk.py b/docarray/array/chunk.py index ea445702f62..246b162152d 100644 --- a/docarray/array/chunk.py +++ b/docarray/array/chunk.py @@ -7,12 +7,13 @@ ) from .document import DocumentArray +from .memory import DocumentArrayInMemory if TYPE_CHECKING: from ..document import Document -class ChunkArray(DocumentArray): +class ChunkArray(DocumentArrayInMemory): """ :class:`ChunkArray` inherits from :class:`DocumentArray`. It's a subset of Documents. 
diff --git a/docarray/array/document.py b/docarray/array/document.py index 951a86c9c37..f21bab69a33 100644 --- a/docarray/array/document.py +++ b/docarray/array/document.py @@ -1,399 +1,48 @@ -import itertools -from typing import ( - Optional, - TYPE_CHECKING, - Generator, - Iterator, - Dict, - Union, - MutableSequence, - Sequence, - Iterable, - overload, - Any, - List, -) - -import numpy as np +from typing import Optional, overload, TYPE_CHECKING, Dict, Union +from .base import BaseDocumentArray from .mixins import AllMixins -from .. import Document -from ..helper import typename if TYPE_CHECKING: from ..types import ( DocumentArraySourceType, - DocumentArrayIndexType, - DocumentArraySingletonIndexType, - DocumentArrayMultipleIndexType, - DocumentArrayMultipleAttributeType, - DocumentArraySingleAttributeType, + DocumentArrayLike, + DocumentArraySqlite, + DocumentArrayInMemory, ) + from .storage.sqlite import SqliteConfig -class DocumentArray(AllMixins, MutableSequence[Document]): - def __init__( - self, docs: Optional['DocumentArraySourceType'] = None, copy: bool = False - ): - super().__init__() - self._data = [] - if docs is None: - return - elif isinstance( - docs, (DocumentArray, Sequence, Generator, Iterator, itertools.chain) - ): - if copy: - self._data = [Document(d, copy=True) for d in docs] - self._rebuild_id2offset() - elif isinstance(docs, DocumentArray): - self._data = docs._data - self._id_to_index = docs._id2offset - else: - self._data = list(docs) - self._rebuild_id2offset() - else: - if isinstance(docs, Document): - if copy: - self.append(Document(docs, copy=True)) - else: - self.append(docs) - - @property - def _id2offset(self) -> Dict[str, int]: - """Return the `_id_to_index` map - - :return: a Python dict. - """ - if not hasattr(self, '_id_to_index'): - self._rebuild_id2offset() - return self._id_to_index - - def _rebuild_id2offset(self) -> None: - """Update the id_to_index map by enumerating all Documents in self._data. - - Very costy! 
Only use this function when self._data is dramtically changed. - """ - - self._id_to_index = { - d.id: i for i, d in enumerate(self._data) - } # type: Dict[str, int] - - def insert(self, index: int, value: 'Document'): - """Insert `doc` at `index`. - - :param index: Position of the insertion. - :param value: The doc needs to be inserted. - """ - self._data.insert(index, value) - self._id2offset[value.id] = index - - def __eq__(self, other): - return ( - type(self) is type(other) - and type(self._data) is type(other._data) - and self._data == other._data - ) - - def __len__(self): - return len(self._data) - - def __iter__(self) -> Iterator['Document']: - yield from self._data - - def __contains__(self, x: Union[str, 'Document']): - if isinstance(x, str): - return x in self._id2offset - elif isinstance(x, Document): - return x.id in self._id2offset - else: - return False - +class DocumentArray(AllMixins, BaseDocumentArray): @overload - def __getitem__(self, index: 'DocumentArraySingletonIndexType') -> 'Document': + def __new__( + cls, _docs: Optional['DocumentArraySourceType'] = None, copy: bool = False + ) -> 'DocumentArrayInMemory': + """Create an in-memory DocumentArray object.""" ... @overload - def __getitem__(self, index: 'DocumentArrayMultipleIndexType') -> 'DocumentArray': + def __new__( + cls, + _docs: Optional['DocumentArraySourceType'] = None, + storage: str = 'sqlite', + config: Optional[Union['SqliteConfig', Dict]] = None, + ) -> 'DocumentArraySqlite': + """Create a SQLite-powered DocumentArray object.""" ... - @overload - def __getitem__(self, index: 'DocumentArraySingleAttributeType') -> List[Any]: - ... + def __new__(cls, *args, storage: str = 'memory', **kwargs) -> 'DocumentArrayLike': + if cls is DocumentArray: + if storage == 'memory': + from .memory import DocumentArrayInMemory - @overload - def __getitem__( - self, index: 'DocumentArrayMultipleAttributeType' - ) -> List[List[Any]]: - ... 
+ instance = super().__new__(DocumentArrayInMemory) + elif storage == 'sqlite': + from .sqlite import DocumentArraySqlite - def __getitem__( - self, index: 'DocumentArrayIndexType' - ) -> Union['Document', 'DocumentArray']: - if isinstance(index, (int, np.generic)) and not isinstance(index, bool): - return self._data[int(index)] - elif isinstance(index, str): - if index.startswith('@'): - return self.traverse_flat(index[1:]) - else: - return self._data[self._id2offset[index]] - elif isinstance(index, slice): - return DocumentArray(self._data[index]) - elif index is Ellipsis: - return self.flatten() - elif isinstance(index, Sequence): - if ( - isinstance(index, tuple) - and len(index) == 2 - and isinstance(index[0], (slice, Sequence)) - ): - if isinstance(index[0], str) and isinstance(index[1], str): - # ambiguity only comes from the second string - if index[1] in self._id2offset: - return DocumentArray([self[index[0]], self[index[1]]]) - else: - return getattr(self[index[0]], index[1]) - elif isinstance(index[0], (slice, Sequence)): - _docs = self[index[0]] - _attrs = index[1] - if isinstance(_attrs, str): - _attrs = (index[1],) - return _docs._get_attributes(*_attrs) - elif isinstance(index[0], bool): - return DocumentArray(itertools.compress(self._data, index)) - elif isinstance(index[0], int): - return DocumentArray(self._data[t] for t in index) - elif isinstance(index[0], str): - return DocumentArray(self._data[self._id2offset[t]] for t in index) - elif isinstance(index, np.ndarray): - index = index.squeeze() - if index.ndim == 1: - return self[index.tolist()] + instance = super().__new__(DocumentArraySqlite) else: - raise IndexError( - f'When using np.ndarray as index, its `ndim` must =1. However, receiving ndim={index.ndim}' - ) - raise IndexError(f'Unsupported index type {typename(index)}: {index}') - - @overload - def __setitem__( - self, - index: 'DocumentArrayMultipleAttributeType', - value: List[List['Any']], - ): - ... 
- - @overload - def __setitem__( - self, - index: 'DocumentArraySingleAttributeType', - value: List['Any'], - ): - ... - - @overload - def __setitem__( - self, - index: 'DocumentArraySingletonIndexType', - value: 'Document', - ): - ... - - @overload - def __setitem__( - self, - index: 'DocumentArrayMultipleIndexType', - value: Sequence['Document'], - ): - ... - - def __setitem__( - self, - index: 'DocumentArrayIndexType', - value: Union['Document', Sequence['Document']], - ): - - if isinstance(index, (int, np.generic)) and not isinstance(index, bool): - index = int(index) - self._data[index] = value - self._id2offset[value.id] = index - elif isinstance(index, str): - if index.startswith('@'): - for _d, _v in zip(self.traverse_flat(index[1:]), value): - _d._data = _v._data - self._rebuild_id2offset() - else: - old_idx = self._id2offset.pop(index) - self._data[old_idx] = value - self._id2offset[value.id] = old_idx - elif isinstance(index, slice): - self._data[index] = value - self._rebuild_id2offset() - elif index is Ellipsis: - for _d, _v in zip(self.flatten(), value): - _d._data = _v._data - self._rebuild_id2offset() - elif isinstance(index, Sequence): - if ( - isinstance(index, tuple) - and len(index) == 2 - and isinstance(index[0], (slice, Sequence)) - ): - if isinstance(index[0], str) and isinstance(index[1], str): - # ambiguity only comes from the second string - if index[1] in self._id2offset: - for _d, _v in zip((self[index[0]], self[index[1]]), value): - _d._data = _v._data - self._rebuild_id2offset() - elif hasattr(self[index[0]], index[1]): - setattr(self[index[0]], index[1], value) - else: - # to avoid accidentally add new unsupport attribute - raise ValueError( - f'`{index[1]}` is neither a valid id nor attribute name' - ) - elif isinstance(index[0], (slice, Sequence)): - _docs = self[index[0]] - _attrs = index[1] - - if isinstance(_attrs, str): - # a -> [a] - # [a, a] -> [a, a] - _attrs = (index[1],) - if isinstance(value, (list, tuple)) and not any( - 
isinstance(el, (tuple, list)) for el in value - ): - # [x] -> [[x]] - # [[x], [y]] -> [[x], [y]] - value = (value,) - if not isinstance(value, (list, tuple)): - # x -> [x] - value = (value,) - - for _a, _v in zip(_attrs, value): - if _a == 'tensor': - _docs.tensors = _v - elif _a == 'embedding': - _docs.embeddings = _v - else: - if len(_docs) == 1: - setattr(_docs[0], _a, _v) - else: - for _d, _vv in zip(_docs, _v): - setattr(_d, _a, _vv) - elif isinstance(index[0], bool): - if len(index) != len(self._data): - raise IndexError( - f'Boolean mask index is required to have the same length as {len(self._data)}, ' - f'but receiving {len(index)}' - ) - _selected = itertools.compress(self._data, index) - for _idx, _val in zip(_selected, value): - self[_idx.id] = _val - elif isinstance(index[0], (int, str)): - if not isinstance(value, Sequence) or len(index) != len(value): - raise ValueError( - f'Number of elements for assigning must be ' - f'the same as the index length: {len(index)}' - ) - if isinstance(value, Document): - for si in index: - self[si] = value - else: - for si, _val in zip(index, value): - self[si] = _val - elif isinstance(index, np.ndarray): - index = index.squeeze() - if index.ndim == 1: - self[index.tolist()] = value - else: - raise IndexError( - f'When using np.ndarray as index, its `ndim` must =1. 
However, receiving ndim={index.ndim}' - ) - else: - raise IndexError(f'Unsupported index type {typename(index)}: {index}') - - def __delitem__(self, index: 'DocumentArrayIndexType'): - if isinstance(index, (int, np.generic)) and not isinstance(index, bool): - index = int(index) - self._id2offset.pop(self._data[index].id) - del self._data[index] - elif isinstance(index, str): - if index.startswith('@'): - raise NotImplementedError( - 'Delete elements along traversal paths is not implemented' - ) - else: - del self._data[self._id2offset[index]] - self._id2offset.pop(index) - elif isinstance(index, slice): - del self._data[index] - self._rebuild_id2offset() - elif index is Ellipsis: - self._data.clear() - self._id2offset.clear() - elif isinstance(index, Sequence): - if ( - isinstance(index, tuple) - and len(index) == 2 - and isinstance(index[0], (slice, Sequence)) - ): - if isinstance(index[0], str) and isinstance(index[1], str): - # ambiguity only comes from the second string - if index[1] in self._id2offset: - del self[index[0]] - del self[index[1]] - else: - self[index[0]].pop(index[1]) - elif isinstance(index[0], (slice, Sequence)): - _docs = self[index[0]] - _attrs = index[1] - if isinstance(_attrs, str): - _attrs = (index[1],) - for _d in _docs: - _d.pop(*_attrs) - elif isinstance(index[0], bool): - self._data = list( - itertools.compress(self._data, (not _i for _i in index)) - ) - self._rebuild_id2offset() - elif isinstance(index[0], int): - for t in sorted(index, reverse=True): - del self[t] - elif isinstance(index[0], str): - for t in index: - del self[t] - elif isinstance(index, np.ndarray): - index = index.squeeze() - if index.ndim == 1: - del self[index.tolist()] - else: - raise IndexError( - f'When using np.ndarray as index, its `ndim` must =1. 
However, receiving ndim={index.ndim}' - ) + raise ValueError(f'storage=`{storage}` is not supported.') else: - raise IndexError(f'Unsupported index type {typename(index)}: {index}') - - def clear(self): - """Clear the data of :class:`DocumentArray`""" - self._data.clear() - self._id2offset.clear() - - def __bool__(self): - """To simulate ```l = []; if l: ...``` - - :return: returns true if the length of the array is larger than 0 - """ - return len(self) > 0 - - def __repr__(self): - return f'<{self.__class__.__name__} (length={len(self)}) at {id(self)}>' - - def __add__(self, other: Union['Document', Sequence['Document']]): - v = type(self)() - v.extend(self) - v.extend(other) - return v - - def extend(self, values: Iterable['Document']) -> None: - self._data.extend(values) - self._rebuild_id2offset() + instance = super().__new__(cls) + return instance diff --git a/docarray/array/match.py b/docarray/array/match.py index 2b33828f3fb..07d2e8b2ced 100644 --- a/docarray/array/match.py +++ b/docarray/array/match.py @@ -6,13 +6,14 @@ Sequence, ) -from .. import DocumentArray +from .document import DocumentArray +from .memory import DocumentArrayInMemory if TYPE_CHECKING: from ..document import Document -class MatchArray(DocumentArray): +class MatchArray(DocumentArrayInMemory): """ :class:`MatchArray` inherits from :class:`DocumentArray`. 
It's a subset of Documents that represents the matches diff --git a/docarray/array/memory.py b/docarray/array/memory.py new file mode 100644 index 00000000000..d97ff54656f --- /dev/null +++ b/docarray/array/memory.py @@ -0,0 +1,7 @@ +from .document import DocumentArray +from .storage.memory import StorageMixins + + +class DocumentArrayInMemory(StorageMixins, DocumentArray): + def __new__(cls, *args, **kwargs): + return super().__new__(cls) diff --git a/docarray/array/mixins/__init__.py b/docarray/array/mixins/__init__.py index 61b24a546a0..f5dbd8df4d6 100644 --- a/docarray/array/mixins/__init__.py +++ b/docarray/array/mixins/__init__.py @@ -1,10 +1,12 @@ from abc import ABC from .content import ContentPropertyMixin +from .delitem import DelItemMixin from .embed import EmbedMixin from .empty import EmptyMixin from .evaluation import EvaluationMixin from .getattr import GetAttributeMixin +from .getitem import GetItemMixin from .group import GroupMixin from .io.binary import BinaryIOMixin from .io.common import CommonIOMixin @@ -16,15 +18,19 @@ from .match import MatchMixin from .parallel import ParallelMixin from .plot import PlotMixin +from .pydantic import PydanticMixin from .reduce import ReduceMixin from .sample import SampleMixin +from .setitem import SetItemMixin from .text import TextToolsMixin from .traverse import TraverseMixin -from .pydantic import PydanticMixin class AllMixins( GetAttributeMixin, + GetItemMixin, + SetItemMixin, + DelItemMixin, ContentPropertyMixin, PydanticMixin, GroupMixin, @@ -47,6 +53,6 @@ class AllMixins( DataframeIOMixin, ABC, ): - """All plugins that can be used in :class:`DocumentArray`. """ + """All plugins that can be used in :class:`DocumentArray`.""" ... 
diff --git a/docarray/array/mixins/delitem.py b/docarray/array/mixins/delitem.py new file mode 100644 index 00000000000..b7f796aff98 --- /dev/null +++ b/docarray/array/mixins/delitem.py @@ -0,0 +1,71 @@ +from typing import ( + TYPE_CHECKING, + Sequence, +) + +import numpy as np + +from ...helper import typename + +if TYPE_CHECKING: + from ...types import ( + DocumentArrayIndexType, + ) + + +class DelItemMixin: + """Provide help function to enable advanced indexing in `__delitem__`""" + + def __delitem__(self, index: 'DocumentArrayIndexType'): + if isinstance(index, (int, np.generic)) and not isinstance(index, bool): + self._del_doc_by_offset(int(index)) + + elif isinstance(index, str): + if index.startswith('@'): + raise NotImplementedError( + 'Delete elements along traversal paths is not implemented' + ) + else: + self._del_doc_by_id(index) + elif isinstance(index, slice): + self._del_docs_by_slice(index) + elif index is Ellipsis: + self._del_all_docs() + elif isinstance(index, Sequence): + if ( + isinstance(index, tuple) + and len(index) == 2 + and isinstance(index[0], (slice, Sequence)) + ): + if isinstance(index[0], str) and isinstance(index[1], str): + # ambiguity only comes from the second string + if index[1] in self: + del self[index[0]] + del self[index[1]] + else: + self._set_doc_attr_by_id(index[0], index[1], None) + elif isinstance(index[0], (slice, Sequence)): + _attrs = index[1] + if isinstance(_attrs, str): + _attrs = (index[1],) + for _d in self[index[0]]: + for _aa in _attrs: + self._set_doc_attr_by_id(_d.id, _aa, None) + elif isinstance(index[0], bool): + self._del_docs_by_mask(index) + elif isinstance(index[0], int): + for t in sorted(index, reverse=True): + del self[t] + elif isinstance(index[0], str): + for t in index: + del self[t] + elif isinstance(index, np.ndarray): + index = index.squeeze() + if index.ndim == 1: + del self[index.tolist()] + else: + raise IndexError( + f'When using np.ndarray as index, its `ndim` must =1. 
However, receiving ndim={index.ndim}' + ) + else: + raise IndexError(f'Unsupported index type {typename(index)}: {index}') diff --git a/docarray/array/mixins/getitem.py b/docarray/array/mixins/getitem.py new file mode 100644 index 00000000000..308d1731f48 --- /dev/null +++ b/docarray/array/mixins/getitem.py @@ -0,0 +1,98 @@ +import itertools +from typing import ( + TYPE_CHECKING, + Union, + Sequence, + overload, + Any, + List, +) + +import numpy as np + +from ... import Document +from ...helper import typename + +if TYPE_CHECKING: + from ...types import ( + DocumentArrayIndexType, + DocumentArraySingletonIndexType, + DocumentArrayMultipleIndexType, + DocumentArrayMultipleAttributeType, + DocumentArraySingleAttributeType, + ) + from ... import DocumentArray + + +class GetItemMixin: + """Provide helper functions to enable advance indexing in `__getitem__`""" + + @overload + def __getitem__(self, index: 'DocumentArraySingletonIndexType') -> 'Document': + ... + + @overload + def __getitem__(self, index: 'DocumentArrayMultipleIndexType') -> 'DocumentArray': + ... + + @overload + def __getitem__(self, index: 'DocumentArraySingleAttributeType') -> List[Any]: + ... + + @overload + def __getitem__( + self, index: 'DocumentArrayMultipleAttributeType' + ) -> List[List[Any]]: + ... + + def __getitem__( + self, index: 'DocumentArrayIndexType' + ) -> Union['Document', 'DocumentArray']: + if isinstance(index, (int, np.generic)) and not isinstance(index, bool): + return self._get_doc_by_offset(int(index)) + elif isinstance(index, str): + if index.startswith('@'): + return self.traverse_flat(index[1:]) + else: + return self._get_doc_by_id(index) + elif isinstance(index, slice): + from ... import DocumentArray + + return DocumentArray(self._get_docs_by_slice(index)) + elif index is Ellipsis: + return self.flatten() + elif isinstance(index, Sequence): + from ... 
import DocumentArray + + if ( + isinstance(index, tuple) + and len(index) == 2 + and isinstance(index[0], (slice, Sequence)) + ): + if isinstance(index[0], str) and isinstance(index[1], str): + # ambiguity only comes from the second string + if index[1] in self: + return DocumentArray([self[index[0]], self[index[1]]]) + else: + return getattr(self[index[0]], index[1]) + elif isinstance(index[0], (slice, Sequence)): + _docs = self[index[0]] + _attrs = index[1] + if isinstance(_attrs, str): + _attrs = (index[1],) + return _docs._get_attributes(*_attrs) + elif isinstance(index[0], bool): + return DocumentArray(itertools.compress(self, index)) + elif isinstance(index[0], int): + return DocumentArray(self._get_docs_by_offsets(index)) + elif isinstance(index[0], str): + return DocumentArray(self._get_docs_by_ids(index)) + elif isinstance(index, np.ndarray): + index = index.squeeze() + if index.ndim == 1: + return self[index.tolist()] + else: + raise IndexError( + f'When using np.ndarray as index, its `ndim` must =1. However, receiving ndim={index.ndim}' + ) + raise IndexError(f'Unsupported index type {typename(index)}: {index}') diff --git a/docarray/array/mixins/setitem.py b/docarray/array/mixins/setitem.py new file mode 100644 index 00000000000..245631a1ca3 --- /dev/null +++ b/docarray/array/mixins/setitem.py @@ -0,0 +1,158 @@ +import itertools +from typing import ( + TYPE_CHECKING, + Union, + Sequence, + overload, + Any, + List, +) + +import numpy as np + +from ... import Document +from ...helper import typename + +if TYPE_CHECKING: + from ...types import ( + DocumentArrayIndexType, + DocumentArraySingletonIndexType, + DocumentArrayMultipleIndexType, + DocumentArrayMultipleAttributeType, + DocumentArraySingleAttributeType, + ) + + +class SetItemMixin: + """Provides helper function to allow advanced indexing for `__setitem__`""" + + @overload + def __setitem__( + self, + index: 'DocumentArrayMultipleAttributeType', + value: List[List['Any']], + ): + ... 
+ + @overload + def __setitem__( + self, + index: 'DocumentArraySingleAttributeType', + value: List['Any'], + ): + ... + + @overload + def __setitem__( + self, + index: 'DocumentArraySingletonIndexType', + value: 'Document', + ): + ... + + @overload + def __setitem__( + self, + index: 'DocumentArrayMultipleIndexType', + value: Sequence['Document'], + ): + ... + + def __setitem__( + self, + index: 'DocumentArrayIndexType', + value: Union['Document', Sequence['Document']], + ): + + if isinstance(index, (int, np.generic)) and not isinstance(index, bool): + self._set_doc_by_offset(int(index), value) + elif isinstance(index, str): + if index.startswith('@'): + self._set_doc_value_pairs(self.traverse_flat(index[1:]), value) + else: + self._set_doc_by_id(index, value) + elif isinstance(index, slice): + self._set_docs_by_slice(index, value) + elif index is Ellipsis: + self._set_doc_value_pairs(self.flatten(), value) + elif isinstance(index, Sequence): + if ( + isinstance(index, tuple) + and len(index) == 2 + and isinstance(index[0], (slice, Sequence)) + ): + if isinstance(index[0], str) and isinstance(index[1], str): + # ambiguity only comes from the second string + if index[1] in self: + self._set_doc_value_pairs( + (self[index[0]], self[index[1]]), value + ) + elif hasattr(self[index[0]], index[1]): + self._set_doc_attr_by_id(index[0], index[1], value) + else: + # to avoid accidentally add new unsupport attribute + raise ValueError( + f'`{index[1]}` is neither a valid id nor attribute name' + ) + elif isinstance(index[0], (slice, Sequence)): + _attrs = index[1] + + if isinstance(_attrs, str): + # a -> [a] + # [a, a] -> [a, a] + _attrs = (index[1],) + if isinstance(value, (list, tuple)) and not any( + isinstance(el, (tuple, list)) for el in value + ): + # [x] -> [[x]] + # [[x], [y]] -> [[x], [y]] + value = (value,) + if not isinstance(value, (list, tuple)): + # x -> [x] + value = (value,) + + _docs = self[index[0]] + for _a, _v in zip(_attrs, value): + if _a in ('tensor', 
'embedding'): + if _a == 'tensor': + _docs.tensors = _v + elif _a == 'embedding': + _docs.embeddings = _v + for _d in _docs: + self._set_doc_by_id(_d.id, _d) + else: + if len(_docs) == 1: + self._set_doc_attr_by_id(_docs[0].id, _a, _v) + else: + for _d, _vv in zip(_docs, _v): + self._set_doc_attr_by_id(_d.id, _a, _vv) + elif isinstance(index[0], bool): + if len(index) != len(self): + raise IndexError( + f'Boolean mask index is required to have the same length as {len(self._data)}, ' + f'but receiving {len(index)}' + ) + _selected = itertools.compress(self, index) + self._set_doc_value_pairs(_selected, value) + elif isinstance(index[0], (int, str)): + if not isinstance(value, Sequence) or len(index) != len(value): + raise ValueError( + f'Number of elements for assigning must be ' + f'the same as the index length: {len(index)}' + ) + if isinstance(value, Document): + for si in index: + self[si] = value # leverage existing setter + else: + for si, _val in zip(index, value): + self[si] = _val # leverage existing setter + elif isinstance(index, np.ndarray): + index = index.squeeze() + if index.ndim == 1: + self[index.tolist()] = value # leverage existing setter + else: + raise IndexError( + f'When using np.ndarray as index, its `ndim` must =1. 
However, receiving ndim={index.ndim}' + ) + else: + raise IndexError(f'Unsupported index type {typename(index)}: {index}') diff --git a/docarray/array/sqlite.py b/docarray/array/sqlite.py new file mode 100644 index 00000000000..c203ef5dc53 --- /dev/null +++ b/docarray/array/sqlite.py @@ -0,0 +1,7 @@ +from .document import DocumentArray +from .storage.sqlite import StorageMixins + + +class DocumentArraySqlite(StorageMixins, DocumentArray): + def __new__(cls, *args, **kwargs): + return super().__new__(cls) diff --git a/docarray/array/storage/__init__.py b/docarray/array/storage/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/docarray/array/storage/base/__init__.py b/docarray/array/storage/base/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/docarray/array/storage/base/backend.py b/docarray/array/storage/base/backend.py new file mode 100644 index 00000000000..06ece2b12e6 --- /dev/null +++ b/docarray/array/storage/base/backend.py @@ -0,0 +1,7 @@ +from abc import ABC, abstractmethod + + +class BaseBackendMixin(ABC): + @abstractmethod + def _init_storage(self, *args, **kwargs): + ... diff --git a/docarray/array/storage/base/getsetdel.py b/docarray/array/storage/base/getsetdel.py new file mode 100644 index 00000000000..72ee83e200c --- /dev/null +++ b/docarray/array/storage/base/getsetdel.py @@ -0,0 +1,156 @@ +from abc import abstractmethod, ABC +from typing import ( + Sequence, + Any, + Iterable, +) + +from .... import Document + + +class BaseGetSetDelMixin(ABC): + """Provide abstract methods and derived methods for ``__getitem__``, ``__setitem__`` and ``__delitem__`` + + .. note:: + The following methods must be implemented: + - :meth:`._get_doc_by_offset` + - :meth:`._get_doc_by_id` + - :meth:`._set_doc_by_offset` + - :meth:`._set_doc_by_id` + - :meth:`._del_doc_by_offset` + - :meth:`._del_doc_by_id` + + Other methods implemented a generic-but-slow version that leverage the methods above. 
+ Please override those methods in the subclass whenever a more efficient implementation is available. + """ + + # Getitem APIs + + @abstractmethod + def _get_doc_by_offset(self, offset: int) -> 'Document': + ... + + @abstractmethod + def _get_doc_by_id(self, _id: str) -> 'Document': + ... + + def _get_docs_by_slice(self, _slice: slice) -> Iterable['Document']: + """This function is derived from :meth:`_get_doc_by_offset` + + Override this function if there is a more efficient logic""" + return (self._get_doc_by_offset(o) for o in range(len(self))[_slice]) + + def _get_docs_by_offsets(self, offsets: Sequence[int]) -> Iterable['Document']: + """This function is derived from :meth:`_get_doc_by_offset` + + Override this function if there is a more efficient logic""" + return (self._get_doc_by_offset(o) for o in offsets) + + def _get_docs_by_ids(self, ids: Sequence[str]) -> Iterable['Document']: + """This function is derived from :meth:`_get_doc_by_id` + + Override this function if there is a more efficient logic""" + return (self._get_doc_by_id(_id) for _id in ids) + + # Delitem APIs + + @abstractmethod + def _del_doc_by_offset(self, offset: int): + ... + + @abstractmethod + def _del_doc_by_id(self, _id: str): + ... + + def _del_docs_by_slice(self, _slice: slice): + """This function is derived and may not have the most efficient implementation. + + Override this function if there is a more efficient logic""" + for j in range(len(self))[_slice]: + self._del_doc_by_offset(j) + + def _del_docs_by_mask(self, mask: Sequence[bool]): + """This function is derived and may not have the most efficient implementation. + + Override this function if there is a more efficient logic""" + for idx, m in enumerate(mask): + if not m: + self._del_doc_by_offset(idx) + + def _del_all_docs(self): + """This function is derived and may not have the most efficient implementation. 
+ + Override this function if there is a more efficient logic""" + for j in range(len(self)): + self._del_doc_by_offset(j) + + # Setitem API + + @abstractmethod + def _set_doc_by_offset(self, offset: int, value: 'Document'): + ... + + @abstractmethod + def _set_doc_by_id(self, _id: str, value: 'Document'): + ... + + def _set_docs_by_slice(self, _slice: slice, value: Sequence['Document']): + """This function is derived and may not have the most efficient implementation. + + Override this function if there is a more efficient logic + """ + if not isinstance(value, Iterable): + raise TypeError( + f'You right-hand assignment must be an iterable, receiving {type(value)}' + ) + for _offset, val in zip(range(len(self))[_slice], value): + self._set_doc_by_offset(_offset, val) + + def _set_doc_value_pairs( + self, docs: Iterable['Document'], values: Iterable['Document'] + ): + """This function is derived and may not have the most efficient implementation. + + Override this function if there is a more efficient logic + """ + for _d, _v in zip(docs, values): + _d._data = _v._data + + for _d in docs: + if _d not in docs: + root_d = self._find_root_doc(_d) + else: + # _d is already on the root-level + root_d = _d + + if root_d: + self._set_doc_by_id(root_d.id, root_d) + + def _set_doc_attr_by_offset(self, offset: int, attr: str, value: Any): + """This function is derived and may not have the most efficient implementation. + + Override this function if there is a more efficient logic + """ + d = self._get_doc_by_offset(offset) + if hasattr(d, attr): + setattr(d, attr, value) + self._set_doc_by_offset(offset, d) + + def _set_doc_attr_by_id(self, _id: str, attr: str, value: Any): + """This function is derived and may not have the most efficient implementation. 
+ + Override this function if there is a more efficient logic + """ + d = self._get_doc_by_id(_id) + if hasattr(d, attr): + setattr(d, attr, value) + self._set_doc_by_id(d.id, d) + + def _find_root_doc(self, d: Document): + """Find `d`'s root Document in an exhaustive manner""" + from docarray import DocumentArray + + for _d in self: + _all_ids = set(DocumentArray(d)[...][:, 'id']) + if d.id in _all_ids: + return _d diff --git a/docarray/array/storage/memory/__init__.py b/docarray/array/storage/memory/__init__.py new file mode 100644 index 00000000000..cf9d9e7f97d --- /dev/null +++ b/docarray/array/storage/memory/__init__.py @@ -0,0 +1,11 @@ +from abc import ABC + +from .backend import BackendMixin +from .getsetdel import GetSetDelMixin +from .seqlike import SequenceLikeMixin + +__all__ = ['StorageMixins'] + + +class StorageMixins(BackendMixin, GetSetDelMixin, SequenceLikeMixin, ABC): + ... diff --git a/docarray/array/storage/memory/backend.py b/docarray/array/storage/memory/backend.py new file mode 100644 index 00000000000..ea596702c85 --- /dev/null +++ b/docarray/array/storage/memory/backend.py @@ -0,0 +1,68 @@ +import itertools +from typing import ( + Generator, + Iterator, + Dict, + Sequence, + Optional, + TYPE_CHECKING, +) + +from ..base.backend import BaseBackendMixin +from .... import Document + +if TYPE_CHECKING: + from ....types import ( + DocumentArraySourceType, + ) + + +class BackendMixin(BaseBackendMixin): + """Provide necessary functions to enable this storage backend.""" + + @property + def _id2offset(self) -> Dict[str, int]: + """Return the `_id_to_index` map + + :return: a Python dict. + """ + if not hasattr(self, '_id_to_index'): + self._rebuild_id2offset() + return self._id_to_index + + def _rebuild_id2offset(self) -> None: + """Update the id_to_index map by enumerating all Documents in self._data. + + Very costy! Only use this function when self._data is dramtically changed. 
+ """ + + self._id_to_index = { + d.id: i for i, d in enumerate(self._data) + } # type: Dict[str, int] + + def _init_storage( + self, _docs: Optional['DocumentArraySourceType'] = None, copy: bool = False + ): + from ... import DocumentArray + + self._data = [] + if _docs is None: + return + elif isinstance( + _docs, (DocumentArray, Sequence, Generator, Iterator, itertools.chain) + ): + if copy: + self._data = [Document(d, copy=True) for d in _docs] + self._rebuild_id2offset() + elif isinstance(_docs, DocumentArray): + self._data = _docs._data + self._id_to_index = _docs._id2offset + else: + self._data = list(_docs) + self._rebuild_id2offset() + else: + if isinstance(_docs, Document): + if copy: + self.append(Document(_docs, copy=True)) + else: + self.append(_docs) diff --git a/docarray/array/storage/memory/getsetdel.py b/docarray/array/storage/memory/getsetdel.py new file mode 100644 index 00000000000..8ea62471cce --- /dev/null +++ b/docarray/array/storage/memory/getsetdel.py @@ -0,0 +1,74 @@ +import itertools +from typing import ( + Sequence, + Iterable, + Any, +) + +from ..base.getsetdel import BaseGetSetDelMixin +from .... 
import Document + + +class GetSetDelMixin(BaseGetSetDelMixin): + """Implement required and derived functions that power `getitem`, `setitem`, `delitem`""" + + def _del_docs_by_mask(self, mask: Sequence[bool]): + self._data = list(itertools.compress(self._data, (not _i for _i in mask))) + self._rebuild_id2offset() + + def _del_all_docs(self): + self._data.clear() + self._id2offset.clear() + + def _del_docs_by_slice(self, _slice: slice): + del self._data[_slice] + self._rebuild_id2offset() + + def _del_doc_by_id(self, _id: str): + del self._data[self._id2offset[_id]] + self._id2offset.pop(_id) + + def _del_doc_by_offset(self, offset: int): + self._id2offset.pop(self._data[offset].id) + del self._data[offset] + + def _set_doc_by_offset(self, offset: int, value: 'Document'): + self._data[offset] = value + self._id2offset[value.id] = offset + + def _set_doc_by_id(self, _id: str, value: 'Document'): + old_idx = self._id2offset.pop(_id) + self._data[old_idx] = value + self._id2offset[value.id] = old_idx + + def _set_docs_by_slice(self, _slice: slice, value: Sequence['Document']): + self._data[_slice] = value + self._rebuild_id2offset() + + def _set_doc_value_pairs( + self, docs: Iterable['Document'], values: Iterable['Document'] + ): + for _d, _v in zip(docs, values): + _d._data = _v._data + self._rebuild_id2offset() + + def _set_doc_attr_by_offset(self, offset: int, attr: str, value: Any): + setattr(self._data[offset], attr, value) + + def _set_doc_attr_by_id(self, _id: str, attr: str, value: Any): + setattr(self._data[self._id2offset[_id]], attr, value) + + def _get_doc_by_offset(self, offset: int) -> 'Document': + return self._data[offset] + + def _get_doc_by_id(self, _id: str) -> 'Document': + return self._data[self._id2offset[_id]] + + def _get_docs_by_slice(self, _slice: slice) -> Iterable['Document']: + return self._data[_slice] + + def _get_docs_by_offsets(self, offsets: Sequence[int]) -> Iterable['Document']: + return (self._data[t] for t in offsets) + + def 
_get_docs_by_ids(self, ids: Sequence[str]) -> Iterable['Document']:
+        return (self._data[self._id2offset[t]] for t in ids)
diff --git a/docarray/array/storage/memory/seqlike.py b/docarray/array/storage/memory/seqlike.py
new file mode 100644
index 00000000000..0121780fadf
--- /dev/null
+++ b/docarray/array/storage/memory/seqlike.py
@@ -0,0 +1,60 @@
+from typing import Iterator, Union, Sequence, Iterable, MutableSequence
+
+from .... import Document
+
+
+class SequenceLikeMixin(MutableSequence[Document]):
+    """Implement sequence-like methods"""
+
+    def insert(self, index: int, value: 'Document'):
+        """Insert `doc` at `index`.
+
+        :param index: Position of the insertion.
+        :param value: The doc needs to be inserted.
+        """
+        self._data.insert(index, value)
+        self._id2offset[value.id] = index
+
+    def __eq__(self, other):
+        return (
+            type(self) is type(other)
+            and type(self._data) is type(other._data)
+            and self._data == other._data
+        )
+
+    def __len__(self):
+        return len(self._data)
+
+    def __iter__(self) -> Iterator['Document']:
+        yield from self._data
+
+    def __contains__(self, x: Union[str, 'Document']):
+        if isinstance(x, str):
+            return x in self._id2offset
+        elif isinstance(x, Document):
+            return x.id in self._id2offset
+        else:
+            return False
+
+    def clear(self):
+        """Clear the data of :class:`DocumentArray`"""
+        self._del_all_docs()
+
+    def __bool__(self):
+        """To simulate ```l = []; if l: ...```
+
+        :return: returns true if the length of the array is larger than 0
+        """
+        return len(self) > 0
+
+    def __repr__(self):
+        return f'<{self.__class__.__name__} (length={len(self)}) at {id(self)}>'
+
+    def __add__(self, other: Union['Document', Sequence['Document']]):
+        v = type(self)(self)
+        v.extend(other)
+        return v
+
+    def extend(self, values: Iterable['Document']) -> None:
+        self._data.extend(values)
+        self._rebuild_id2offset()
diff --git a/docarray/array/storage/sqlite/__init__.py b/docarray/array/storage/sqlite/__init__.py
new file mode 100644
index 00000000000..4d9bf2b291b
--- /dev/null
+++ 
b/docarray/array/storage/sqlite/__init__.py @@ -0,0 +1,11 @@ +from abc import ABC + +from .backend import BackendMixin, SqliteConfig +from .getsetdel import GetSetDelMixin +from .seqlike import SequenceLikeMixin + +__all__ = ['StorageMixins', 'SqliteConfig'] + + +class StorageMixins(BackendMixin, GetSetDelMixin, SequenceLikeMixin, ABC): + ... diff --git a/docarray/array/storage/sqlite/backend.py b/docarray/array/storage/sqlite/backend.py new file mode 100644 index 00000000000..d1605292143 --- /dev/null +++ b/docarray/array/storage/sqlite/backend.py @@ -0,0 +1,100 @@ +import sqlite3 +import warnings +from dataclasses import dataclass, field +from tempfile import NamedTemporaryFile +from typing import ( + Optional, + TYPE_CHECKING, + Union, + Dict, +) + +from .helper import initialize_table +from ..base.backend import BaseBackendMixin +from ....helper import random_identity, dataclass_from_dict + +if TYPE_CHECKING: + from ....types import ( + DocumentArraySourceType, + ) + + +def _sanitize_table_name(table_name: str) -> str: + ret = ''.join(c for c in table_name if c.isalnum() or c == '_') + if ret != table_name: + warnings.warn(f'The table name is changed to {ret} due to illegal characters') + return ret + + +@dataclass +class SqliteConfig: + connection: Optional[Union[str, 'sqlite3.Connection']] = None + table_name: Optional[str] = None + serialize_config: Dict = field(default_factory=dict) + conn_config: Dict = field(default_factory=dict) + + +class BackendMixin(BaseBackendMixin): + """Provide necessary functions to enable this storage backend.""" + + schema_version = '0' + + def _sql(self, *args, **kwargs) -> 'sqlite3.Cursor': + return self._cursor.execute(*args, **kwargs) + + def _commit(self): + self._connection.commit() + + @property + def _cursor(self) -> 'sqlite3.Cursor': + return self._connection.cursor() + + def _init_storage( + self, + _docs: Optional['DocumentArraySourceType'] = None, + config: Optional[Union[SqliteConfig, Dict]] = None, + ): + if not 
config: + config = SqliteConfig() + + if isinstance(config, dict): + config = dataclass_from_dict(SqliteConfig, config) + + from docarray import Document + + sqlite3.register_adapter( + Document, lambda d: d.to_bytes(**config.serialize_config) + ) + sqlite3.register_converter( + 'Document', lambda x: Document.from_bytes(x, **config.serialize_config) + ) + + _conn_kwargs = dict(detect_types=sqlite3.PARSE_DECLTYPES) + _conn_kwargs.update(config.conn_config) + if config.connection is None: + self._connection = sqlite3.connect( + NamedTemporaryFile().name, **_conn_kwargs + ) + elif isinstance(config.connection, str): + self._connection = sqlite3.connect(config.connection, **_conn_kwargs) + elif isinstance(config.connection, sqlite3.Connection): + self._connection = config.connection + else: + raise TypeError( + f'connection argument must be None or a string or a sqlite3.Connection, not `{type(config.connection)}`' + ) + + self._table_name = ( + _sanitize_table_name(self.__class__.__name__ + random_identity()) + if config.table_name is None + else _sanitize_table_name(config.table_name) + ) + self._persist = bool(config.table_name) + initialize_table( + self._table_name, self.__class__.__name__, self.schema_version, self._cursor + ) + self._connection.commit() + self._config = config + if _docs is not None: + self.clear() + self.extend(_docs) diff --git a/docarray/array/storage/sqlite/getsetdel.py b/docarray/array/storage/sqlite/getsetdel.py new file mode 100644 index 00000000000..89318b114df --- /dev/null +++ b/docarray/array/storage/sqlite/getsetdel.py @@ -0,0 +1,124 @@ +from typing import Sequence, Iterable + +from ..base.getsetdel import BaseGetSetDelMixin +from .... 
import Document + + +class GetSetDelMixin(BaseGetSetDelMixin): + """Implement required and derived functions that power `getitem`, `setitem`, `delitem`""" + + # essential methods start + + def _del_doc_by_id(self, _id: str): + self._sql(f'DELETE FROM {self._table_name} WHERE doc_id=?', (_id,)) + self._commit() + + def _del_doc_by_offset(self, offset: int): + + # if offset = -2 and len(self)= 100 use offset = 98 + offset = len(self) + offset if offset < 0 else offset + + self._sql(f'DELETE FROM {self._table_name} WHERE item_order=?', (offset,)) + + # shift the offset of every value on the right position of the deleted item + self._sql( + f'UPDATE {self._table_name} SET item_order=item_order-1 WHERE item_order>?', + (offset,), + ) + + # Code above line is equivalent to + """ + for i in range(offset, len(self) + 1): + self._sql( f'UPDATE {self._table_name} SET item_order=? WHERE item_order=?', (i - 1, i), ) + """ + + self._commit() + + def _set_doc_by_offset(self, offset: int, value: 'Document'): + + # if offset = -2 and len(self)= 100 use offset = 98 + offset = len(self) + offset if offset < 0 else offset + + self._sql( + f'UPDATE {self._table_name} SET serialized_value=?, doc_id=? WHERE item_order=?', + (value, value.id, offset), + ) + + self._commit() + + def _set_doc_by_id(self, _id: str, value: 'Document'): + self._sql( + f'UPDATE {self._table_name} SET serialized_value=?, doc_id=? 
WHERE doc_id=?', + (value, value.id, _id), + ) + self._commit() + + def _get_doc_by_offset(self, index: int) -> 'Document': + r = self._sql( + f'SELECT serialized_value FROM {self._table_name} WHERE item_order = ?', + (index + (len(self) if index < 0 else 0),), + ) + res = r.fetchone() + if res is None: + raise IndexError('index out of range') + return res[0] + + def _get_doc_by_id(self, id: str) -> 'Document': + r = self._sql( + f'SELECT serialized_value FROM {self._table_name} WHERE doc_id = ?', (id,) + ) + res = r.fetchone() + if res is None: + raise KeyError(f'Can not find Document with id=`{id}`') + return res[0] + + # essentials end here + + # now start the optimized bulk methods + def _get_docs_by_offsets(self, offsets: Sequence[int]) -> Iterable['Document']: + l = len(self) + offsets = [o + (l if o < 0 else 0) for o in offsets] + r = self._sql( + f"SELECT serialized_value FROM {self._table_name} WHERE item_order in ({','.join(['?'] * len(offsets))})", + offsets, + ) + for rr in r: + yield rr[0] + + def _get_docs_by_slice(self, _slice: slice) -> Iterable['Document']: + return self._get_docs_by_offsets(range(len(self))[_slice]) + + def _get_docs_by_ids(self, ids: str) -> Iterable['Document']: + r = self._sql( + f"SELECT serialized_value FROM {self._table_name} WHERE doc_id in ({','.join(['?'] * len(ids))})", + ids, + ) + for rr in r: + yield rr[0] + + def _del_all_docs(self): + self._sql(f'DELETE FROM {self._table_name}') + self._commit() + + def _del_docs_by_slice(self, _slice: slice): + offsets = range(len(self))[_slice] + self._sql( + f"DELETE FROM {self._table_name} WHERE item_order in ({','.join(['?'] * len(offsets))})", + offsets, + ) + self._commit() + + def _del_docs_by_mask(self, mask: Sequence[bool]): + + offsets = [i for i, m in enumerate(mask) if m == True] + self._sql( + f"DELETE FROM {self._table_name} WHERE item_order in ({','.join(['?'] * len(offsets))})", + offsets, + ) + self._commit() + + def _set_doc_value_pairs( + self, docs: 
Iterable['Document'], values: Iterable['Document'] + ): + for _d, _v in zip(docs, values): + self._set_doc_by_id(_d.id, _v) diff --git a/docarray/array/storage/sqlite/helper.py b/docarray/array/storage/sqlite/helper.py new file mode 100644 index 00000000000..5eb05a09c47 --- /dev/null +++ b/docarray/array/storage/sqlite/helper.py @@ -0,0 +1,80 @@ +import sqlite3 + + +def initialize_table( + table_name: str, container_type_name: str, schema_version: str, cur: sqlite3.Cursor +) -> None: + if not _is_metadata_table_initialized(cur): + _do_initialize_metadata_table(cur) + + if not _is_table_initialized(table_name, container_type_name, schema_version, cur): + _do_create_table(table_name, cur) + _do_tidy_table_metadata(table_name, container_type_name, schema_version, cur) + + +def _is_metadata_table_initialized(cur: sqlite3.Cursor) -> bool: + try: + cur.execute('SELECT 1 FROM metadata LIMIT 1') + _ = list(cur) + return True + except sqlite3.OperationalError as _: + pass + return False + + +def _do_initialize_metadata_table(cur: sqlite3.Cursor) -> None: + cur.execute( + ''' + CREATE TABLE metadata ( + table_name TEXT PRIMARY KEY, + schema_version TEXT NOT NULL, + container_type TEXT NOT NULL, + UNIQUE (table_name, container_type) + ) + ''' + ) + + +def _do_create_table( + table_name: str, + cur: 'sqlite3.Cursor', +) -> None: + cur.execute( + f''' + CREATE TABLE {table_name} ( + doc_id TEXT NOT NULL UNIQUE, + serialized_value Document NOT NULL, + item_order INTEGER PRIMARY KEY) + ''' + ) + + +def _is_table_initialized( + table_name: str, container_type_name: str, schema_version: str, cur: sqlite3.Cursor +) -> bool: + try: + cur.execute( + 'SELECT schema_version FROM metadata WHERE table_name=? 
AND container_type=?', + (table_name, container_type_name), + ) + buf = cur.fetchone() + if buf is None: + return False + version = buf[0] + if version != schema_version: + return False + cur.execute(f'SELECT 1 FROM {table_name} LIMIT 1') + _ = list(cur) + return True + except sqlite3.OperationalError as _: + pass + return False + + +def _do_tidy_table_metadata( + table_name: str, container_type_name: str, schema_version: str, cur: sqlite3.Cursor +) -> None: + cur.execute( + 'INSERT INTO metadata (table_name, schema_version, container_type) VALUES (?, ?, ?)', + (table_name, schema_version, container_type_name), + ) diff --git a/docarray/array/storage/sqlite/seqlike.py b/docarray/array/storage/sqlite/seqlike.py new file mode 100644 index 00000000000..86647531fc8 --- /dev/null +++ b/docarray/array/storage/sqlite/seqlike.py @@ -0,0 +1,104 @@ +from typing import Iterator, Union, Iterable, MutableSequence, Optional, Sequence + +from .... import Document + + +class SequenceLikeMixin(MutableSequence[Document]): + """Implement sequence-like methods""" + + def _insert_doc_at_idx(self, doc, idx: Optional[int] = None): + if idx is None: + idx = len(self) + self._sql( + f'INSERT INTO {self._table_name} (doc_id, serialized_value, item_order) VALUES (?, ?, ?)', + (doc.id, doc, idx), + ) + + def _shift_index_right_backward(self, start: int): + idx = len(self) - 1 + while idx >= start: + self._sql( + f'UPDATE {self._table_name} SET item_order = ? WHERE item_order = ?', + (idx + 1, idx), + ) + idx -= 1 + + def insert(self, index: int, value: 'Document'): + """Insert `doc` at `index`. + + :param index: Position of the insertion. + :param value: The doc needs to be inserted. 
+        """
+        length = len(self)
+        if index < 0:
+            index = length + index
+        index = max(0, min(length, index))
+        self._shift_index_right_backward(index)
+        self._insert_doc_at_idx(doc=value, idx=index)
+        self._commit()
+
+    def append(self, value: 'Document') -> None:
+        self._insert_doc_at_idx(value)
+        self._commit()
+
+    def extend(self, values: Iterable['Document']) -> None:
+        idx = len(self)
+        for v in values:
+            self._insert_doc_at_idx(v, idx)
+            idx += 1
+        self._commit()
+
+    def clear(self) -> None:
+        self._del_all_docs()
+
+    def __del__(self) -> None:
+        if not self._persist:
+            self._sql(
+                'DELETE FROM metadata WHERE table_name=? AND container_type=?',
+                (self._table_name, self.__class__.__name__),
+            )
+            self._sql(f'DROP TABLE {self._table_name}')
+            self._commit()
+
+    def __contains__(self, item: Union[str, 'Document']):
+        if isinstance(item, str):
+            r = self._sql(f'SELECT 1 FROM {self._table_name} WHERE doc_id=?', (item,))
+            return len(list(r)) > 0
+        elif isinstance(item, Document):
+            return item.id in self  # fall back to str check
+        else:
+            return False
+
+    def __len__(self) -> int:
+        r = self._sql(f'SELECT COUNT(*) FROM {self._table_name}')
+        return r.fetchone()[0]
+
+    def __iter__(self) -> Iterator['Document']:
+        r = self._sql(
+            f'SELECT serialized_value FROM {self._table_name} ORDER BY item_order'
+        )
+        for res in r:
+            yield res[0]
+
+    def __repr__(self):
+        return f'<{self.__class__.__name__} (length={len(self)}) at {id(self)}>'
+
+    def __bool__(self):
+        """To simulate ```l = []; if l: ...```
+
+        :return: returns true if the length of the array is larger than 0
+        """
+        return len(self) > 0
+
+    def __eq__(self, other):
+        """In sqlite backend, data are considered as identical if configs point to the same database source"""
+        return (
+            type(self) is type(other)
+            and type(self._config) is type(other._config)
+            and self._config == other._config
+        )
+
+    def __add__(self, other: Union['Document', Sequence['Document']]):
+        v = type(self)(self, storage='sqlite')
+        v.extend(other)
+        return v
diff --git a/docarray/helper.py 
b/docarray/helper.py index 507c1d5f118..811241e2da2 100644 --- a/docarray/helper.py +++ b/docarray/helper.py @@ -311,3 +311,13 @@ def get_compress_ctx(algorithm: Optional[str] = None, mode: str = 'wb'): else: compress_ctx = None return compress_ctx + + +def dataclass_from_dict(klass, dikt): + try: + fieldtypes = klass.__annotations__ + return klass(**{f: dataclass_from_dict(fieldtypes[f], dikt[f]) for f in dikt}) + except AttributeError: + if isinstance(dikt, (tuple, list)): + return [dataclass_from_dict(klass.__args__[0], f) for f in dikt] + return dikt diff --git a/docarray/math/ndarray.py b/docarray/math/ndarray.py index 1dc354f61fc..a4ef24a9f03 100644 --- a/docarray/math/ndarray.py +++ b/docarray/math/ndarray.py @@ -55,6 +55,7 @@ def ravel(value: 'ArrayType', docs: Sequence['Document'], field: str) -> None: :param field: the field of the doc to set :param value: the value to be set on ``doc.field`` """ + from .. import DocumentArray use_get_row = False if hasattr(value, 'getformat'): @@ -76,9 +77,12 @@ def ravel(value: 'ArrayType', docs: Sequence['Document'], field: str) -> None: for d, j in zip(docs, value): setattr(d, field, j) else: + emb_shape0 = value.shape[0] - for d, j in zip(docs, range(emb_shape0)): + for i, (d, j) in enumerate(zip(docs, range(emb_shape0))): setattr(d, field, value[j, ...]) + if isinstance(docs, DocumentArray): + docs._set_doc_by_id(d.id, d) def get_array_type(array: 'ArrayType') -> Tuple[str, bool]: diff --git a/docarray/types.py b/docarray/types.py index 7a4253bf7be..18ee8361c08 100644 --- a/docarray/types.py +++ b/docarray/types.py @@ -61,3 +61,8 @@ DocumentArraySingleAttributeType, DocumentArrayMultipleAttributeType, ] + + from .array.sqlite import DocumentArraySqlite + from .array.memory import DocumentArrayInMemory + + DocumentArrayLike = Union[DocumentArrayInMemory, DocumentArraySqlite] diff --git a/tests/unit/array/mixins/test_magic.py b/tests/unit/array/mixins/test_magic.py index 76ade193d7c..5ceb186ffcd 100644 --- 
a/tests/unit/array/mixins/test_magic.py +++ b/tests/unit/array/mixins/test_magic.py @@ -1,6 +1,6 @@ import pytest -from docarray import DocumentArray +from docarray import DocumentArray, Document N = 100 @@ -10,6 +10,11 @@ def da_and_dam(): return (da,) +@pytest.fixture +def docs(): + yield (Document(text=str(j)) for j in range(100)) + + @pytest.mark.parametrize('da', da_and_dam()) def test_iter_len_bool(da): j = 0 @@ -27,6 +32,17 @@ def test_repr(da): assert f'length={N}' in repr(da) +@pytest.mark.parametrize('storage', ['memory', 'sqlite']) +def test_repr_str(docs, storage): + da = DocumentArray(docs, storage=storage) + print(da) + da.summary() + assert da + da.clear() + assert not da + print(da) + + @pytest.mark.parametrize('da', da_and_dam()) def test_iadd(da): oid = id(da) diff --git a/tests/unit/array/test_advance_indexing.py b/tests/unit/array/test_advance_indexing.py index 77c72a8a9fe..741d3da3ab7 100644 --- a/tests/unit/array/test_advance_indexing.py +++ b/tests/unit/array/test_advance_indexing.py @@ -5,90 +5,111 @@ @pytest.fixture -def docarray100(): - yield DocumentArray(Document(text=j) for j in range(100)) +def docs(): + yield (Document(text=j) for j in range(100)) -def test_getter_int_str(docarray100): +@pytest.fixture +def indices(): + yield (i for i in [-2, 0, 2]) + + +@pytest.mark.parametrize('storage', ['memory', 'sqlite']) +def test_getter_int_str(docs, storage): + docs = DocumentArray(docs, storage=storage) # getter - assert docarray100[99].text == 99 - assert docarray100[np.int(99)].text == 99 - assert docarray100[-1].text == 99 - assert docarray100[0].text == 0 + assert docs[99].text == 99 + assert docs[np.int(99)].text == 99 + assert docs[-1].text == 99 + assert docs[0].text == 0 # string index - assert docarray100[docarray100[0].id].text == 0 - assert docarray100[docarray100[99].id].text == 99 - assert docarray100[docarray100[-1].id].text == 99 + assert docs[docs[0].id].text == 0 + assert docs[docs[99].id].text == 99 + assert 
docs[docs[-1].id].text == 99 with pytest.raises(IndexError): - docarray100[100] + docs[100] with pytest.raises(KeyError): - docarray100['adsad'] + docs['adsad'] -def test_setter_int_str(docarray100): +@pytest.mark.parametrize('storage', ['memory', 'sqlite']) +def test_setter_int_str(docs, storage): + docs = DocumentArray(docs, storage=storage) # setter - docarray100[99] = Document(text='hello') - docarray100[0] = Document(text='world') + docs[99] = Document(text='hello') + docs[0] = Document(text='world') - assert docarray100[99].text == 'hello' - assert docarray100[-1].text == 'hello' - assert docarray100[0].text == 'world' + assert docs[99].text == 'hello' + assert docs[-1].text == 'hello' + assert docs[0].text == 'world' - docarray100[docarray100[2].id] = Document(text='doc2') + docs[docs[2].id] = Document(text='doc2') # string index - assert docarray100[docarray100[2].id].text == 'doc2' - - -def test_del_int_str(docarray100): - zero_id = docarray100[0].id - del docarray100[0] - assert len(docarray100) == 99 - assert zero_id not in docarray100 - - new_zero_id = docarray100[0].id - new_doc_zero = docarray100[0] - del docarray100[new_zero_id] - assert len(docarray100) == 98 - assert zero_id not in docarray100 - assert new_doc_zero not in docarray100 - - -def test_slice(docarray100): + assert docs[docs[2].id].text == 'doc2' + + +@pytest.mark.parametrize('storage', ['memory', 'sqlite']) +def test_del_int_str(docs, storage, indices): + docs = DocumentArray(docs, storage=storage) + initial_len = len(docs) + deleted_elements = 0 + for pos in indices: + pos_id = docs[pos].id + del docs[pos] + deleted_elements += 1 + assert pos_id not in docs + assert len(docs) == initial_len - deleted_elements + + new_pos_id = docs[pos].id + new_doc_zero = docs[pos] + del docs[new_pos_id] + deleted_elements += 1 + assert len(docs) == initial_len - deleted_elements + assert pos_id not in docs + assert new_doc_zero not in docs + + +@pytest.mark.parametrize('storage', ['memory', 'sqlite']) 
+def test_slice(docs, storage): + docs = DocumentArray(docs, storage=storage) # getter - assert len(docarray100[1:5]) == 4 - assert len(docarray100[1:100:5]) == 20 # 1 to 100, sep with 5 + assert len(docs[1:5]) == 4 + assert len(docs[1:100:5]) == 20 # 1 to 100, sep with 5 # setter - with pytest.raises(TypeError, match='can only assign an iterable'): - docarray100[1:5] = Document(text='repl') + with pytest.raises(TypeError, match='an iterable'): + docs[1:5] = Document(text='repl') - docarray100[1:5] = [Document(text=f'repl{j}') for j in range(4)] - for d in docarray100[1:5]: + docs[1:5] = [Document(text=f'repl{j}') for j in range(4)] + for d in docs[1:5]: assert d.text.startswith('repl') - assert len(docarray100) == 100 + assert len(docs) == 100 # del - zero_doc = docarray100[0] - twenty_doc = docarray100[20] - del docarray100[0:20] - assert len(docarray100) == 80 - assert zero_doc not in docarray100 - assert twenty_doc in docarray100 + zero_doc = docs[0] + twenty_doc = docs[20] + del docs[0:20] + assert len(docs) == 80 + assert zero_doc not in docs + assert twenty_doc in docs -def test_sequence_bool_index(docarray100): +@pytest.mark.parametrize('storage', ['memory', 'sqlite']) +def test_sequence_bool_index(docs, storage): + docs = DocumentArray(docs, storage=storage) # getter mask = [True, False] * 50 - assert len(docarray100[mask]) == 50 - assert len(docarray100[[True, False]]) == 1 + assert len(docs[mask]) == 50 + assert len(docs[[True, False]]) == 1 # setter mask = [True, False] * 50 - docarray100[mask] = [Document(text=f'repl{j}') for j in range(50)] + # docs[mask] = [Document(text=f'repl{j}') for j in range(50)] + docs[mask, 'text'] = [f'repl{j}' for j in range(50)] - for idx, d in enumerate(docarray100): + for idx, d in enumerate(docs): if idx % 2 == 0: # got replaced assert d.text.startswith('repl') @@ -96,61 +117,68 @@ def test_sequence_bool_index(docarray100): assert isinstance(d.text, int) # del - del docarray100[mask] - assert len(docarray100) == 50 - - 
del docarray100[mask] - assert len(docarray100) == 25 + del docs[mask] + assert len(docs) == 50 @pytest.mark.parametrize('nparray', [lambda x: x, np.array, tuple]) -def test_sequence_int(docarray100, nparray): +@pytest.mark.parametrize('storage', ['memory', 'sqlite']) +def test_sequence_int(docs, nparray, storage): + docs = DocumentArray(docs, storage=storage) # getter idx = nparray([1, 3, 5, 7, -1, -2]) - assert len(docarray100[idx]) == len(idx) + assert len(docs[idx]) == len(idx) # setter - docarray100[idx] = [Document(text='repl') for _ in range(len(idx))] + docs[idx] = [Document(text='repl') for _ in range(len(idx))] for _id in idx: - assert docarray100[_id].text == 'repl' + assert docs[_id].text == 'repl' # del idx = [-3, -4, -5, 9, 10, 11] - del docarray100[idx] - assert len(docarray100) == 100 - len(idx) + del docs[idx] + assert len(docs) == 100 - len(idx) -def test_sequence_str(docarray100): +@pytest.mark.parametrize('storage', ['memory', 'sqlite']) +def test_sequence_str(docs, storage): + docs = DocumentArray(docs, storage=storage) # getter - idx = [d.id for d in docarray100[1, 3, 5, 7, -1, -2]] + idx = [d.id for d in docs[1, 3, 5, 7, -1, -2]] - assert len(docarray100[idx]) == len(idx) - assert len(docarray100[tuple(idx)]) == len(idx) + assert len(docs[idx]) == len(idx) + assert len(docs[tuple(idx)]) == len(idx) # setter - docarray100[idx] = [Document(text='repl') for _ in range(len(idx))] - idx = [d.id for d in docarray100[1, 3, 5, 7, -1, -2]] + docs[idx] = [Document(text='repl') for _ in range(len(idx))] + idx = [d.id for d in docs[1, 3, 5, 7, -1, -2]] for _id in idx: - assert docarray100[_id].text == 'repl' + assert docs[_id].text == 'repl' # del - idx = [d.id for d in docarray100[-3, -4, -5, 9, 10, 11]] - del docarray100[idx] - assert len(docarray100) == 100 - len(idx) + idx = [d.id for d in docs[-3, -4, -5, 9, 10, 11]] + del docs[idx] + assert len(docs) == 100 - len(idx) -def test_docarray_list_tuple(docarray100): - assert isinstance(docarray100[99, 
98], DocumentArray) - assert len(docarray100[99, 98]) == 2 +@pytest.mark.parametrize('storage', ['memory', 'sqlite']) +def test_docarray_list_tuple(docs, storage): + docs = DocumentArray(docs, storage=storage) + assert isinstance(docs[99, 98], DocumentArray) + assert len(docs[99, 98]) == 2 -def test_path_syntax_indexing(): - da = DocumentArray().empty(3) +@pytest.mark.parametrize('storage', ['memory', 'sqlite']) +def test_path_syntax_indexing(storage): + da = DocumentArray.empty(3) for d in da: d.chunks = DocumentArray.empty(5) d.matches = DocumentArray.empty(7) for c in d.chunks: c.chunks = DocumentArray.empty(3) + + if storage == 'sqlite': + da = DocumentArray(da, storage=storage) assert len(da['@c']) == 3 * 5 assert len(da['@c:1']) == 3 assert len(da['@c-1:']) == 3 @@ -165,8 +193,11 @@ def test_path_syntax_indexing(): assert len(da['@r:1cc,m']) == 1 * 5 * 3 + 3 * 7 -def test_attribute_indexing(): - da = DocumentArray.empty(10) +@pytest.mark.parametrize('storage', ['memory', 'sqlite']) +def test_attribute_indexing(storage): + da = DocumentArray(storage=storage) + da.extend(DocumentArray.empty(10)) + for v in da[:, 'id']: assert v da[:, 'mime_type'] = [f'type {j}' for j in range(10)] @@ -187,14 +218,17 @@ def test_attribute_indexing(): assert vv -def test_tensor_attribute_selector(): +# TODO: enable weaviate storage test +@pytest.mark.parametrize('storage', ['memory', 'sqlite']) +def test_tensor_attribute_selector(storage): import scipy.sparse sp_embed = np.random.random([3, 10]) sp_embed[sp_embed > 0.1] = 0 sp_embed = scipy.sparse.coo_matrix(sp_embed) - da = DocumentArray.empty(3) + da = DocumentArray(storage=storage) + da.extend(DocumentArray.empty(3)) da[:, 'embedding'] = sp_embed @@ -212,17 +246,22 @@ def test_tensor_attribute_selector(): assert isinstance(v1, list) -def test_advance_selector_mixed(): - da = DocumentArray.empty(10) +@pytest.mark.parametrize('storage', ['memory', 'sqlite']) +def test_advance_selector_mixed(storage): + da = 
DocumentArray(storage=storage) + da.extend(DocumentArray.empty(10)) da.embeddings = np.random.random([10, 3]) + da.match(da, exclude_self=True) assert len(da[:, ('id', 'embedding', 'matches')]) == 3 assert len(da[:, ('id', 'embedding', 'matches')][0]) == 10 -def test_single_boolean_and_padding(): - da = DocumentArray.empty(3) +@pytest.mark.parametrize('storage', ['memory', 'sqlite']) +def test_single_boolean_and_padding(storage): + da = DocumentArray(storage=storage) + da.extend(DocumentArray.empty(3)) with pytest.raises(IndexError): da[True] @@ -237,9 +276,12 @@ def test_single_boolean_and_padding(): assert len(da[False, False]) == 0 -def test_edge_case_two_strings(): +@pytest.mark.parametrize('storage', ['memory', 'sqlite']) +def test_edge_case_two_strings(storage): # getitem - da = DocumentArray([Document(id='1'), Document(id='2'), Document(id='3')]) + da = DocumentArray( + [Document(id='1'), Document(id='2'), Document(id='3')], storage=storage + ) assert da['1', 'id'] == '1' assert len(da['1', '2']) == 2 assert isinstance(da['1', '2'], DocumentArray) @@ -254,20 +296,26 @@ def test_edge_case_two_strings(): del da['1', '2'] assert len(da) == 1 - da = DocumentArray([Document(id='1'), Document(id='2'), Document(id='3')]) - del da['1', 'id'] + da = DocumentArray( + [Document(id=str(i), text='hey') for i in range(3)], storage=storage + ) + del da['1', 'text'] assert len(da) == 3 - assert not da[0].id + assert not da[1].text del da['2', 'hello'] # setitem - da = DocumentArray([Document(id='1'), Document(id='2'), Document(id='3')]) + da = DocumentArray( + [Document(id='1'), Document(id='2'), Document(id='3')], storage=storage + ) da['1', '2'] = DocumentArray.empty(2) assert da[0].id != '1' assert da[1].id != '2' - da = DocumentArray([Document(id='1'), Document(id='2'), Document(id='3')]) + da = DocumentArray( + [Document(id='1'), Document(id='2'), Document(id='3')], storage=storage + ) da['1', 'text'] = 'hello' assert da['1'].text == 'hello' diff --git 
a/tests/unit/array/test_base_getsetdel.py b/tests/unit/array/test_base_getsetdel.py new file mode 100644 index 00000000000..a772c8e44e9 --- /dev/null +++ b/tests/unit/array/test_base_getsetdel.py @@ -0,0 +1,86 @@ +from abc import ABC + +import numpy as np +import pytest + +from docarray import DocumentArray, Document +from docarray.array.storage.base.getsetdel import BaseGetSetDelMixin +from docarray.array.storage.memory import BackendMixin, SequenceLikeMixin + + +class DummyGetSetDelMixin(BaseGetSetDelMixin): + """Implement required and derived functions that power `getitem`, `setitem`, `delitem`""" + + # essentials + + def _del_doc_by_id(self, _id: str): + del self._data[self._id2offset[_id]] + self._id2offset.pop(_id) + + def _del_doc_by_offset(self, offset: int): + self._id2offset.pop(self._data[offset].id) + del self._data[offset] + + def _set_doc_by_id(self, _id: str, value: 'Document'): + old_idx = self._id2offset.pop(_id) + self._data[old_idx] = value + self._id2offset[value.id] = old_idx + + def _get_doc_by_offset(self, offset: int) -> 'Document': + return self._data[offset] + + def _get_doc_by_id(self, _id: str) -> 'Document': + return self._data[self._id2offset[_id]] + + def _set_doc_by_offset(self, offset: int, value: 'Document'): + self._data[offset] = value + self._id2offset[value.id] = offset + + +class StorageMixins(BackendMixin, DummyGetSetDelMixin, SequenceLikeMixin, ABC): + ... 
+ + +class DocumentArrayDummy(StorageMixins, DocumentArray): + def __new__(cls, *args, **kwargs): + return super().__new__(cls) + + +@pytest.fixture(scope='function') +def docs(): + return DocumentArrayDummy([Document(id=str(j), text=j) for j in range(100)]) + + +def test_index_by_int_str(docs): + # getter + assert len(docs[[1]]) == 1 + assert len(docs[1, 2]) == 2 + assert len(docs[1, 2, 3]) == 3 + assert len(docs[1:5]) == 4 + assert len(docs[1:100:5]) == 20 # 1 to 100, sep with 5 + + # setter + with pytest.raises(TypeError, match='an iterable'): + docs[1:5] = Document(text='repl') + + docs[1:5] = [Document(text=f'repl{j}') for j in range(4)] + for d in docs[1:5]: + assert d.text.startswith('repl') + assert len(docs) == 100 + + +def test_getter_int_str(docs): + # getter + assert docs[99].text == 99 + assert docs[-1].text == 99 + assert docs[0].text == 0 + + # string index + assert docs['0'].text == 0 + assert docs['99'].text == 99 + + with pytest.raises(IndexError): + docs[100] + + with pytest.raises(KeyError): + docs['adsad']