diff --git a/docarray/array/storage/memory/backend.py b/docarray/array/storage/memory/backend.py index ea596702c85..c38e1acaab0 100644 --- a/docarray/array/storage/memory/backend.py +++ b/docarray/array/storage/memory/backend.py @@ -6,7 +6,9 @@ Sequence, Optional, TYPE_CHECKING, + Callable, ) +import functools from ..base.backend import BaseBackendMixin from .... import Document @@ -17,6 +19,17 @@ ) +def needs_id2offset_rebuild(func) -> Callable: + # self._id2offset needs to be rebuilt after every insert or delete + # this flag allows to do it lazily and cache the result + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + self._needs_id2offset_rebuild = True + return func(self, *args, **kwargs) + + return wrapper + + class BackendMixin(BaseBackendMixin): """Provide necessary functions to enable this storage backend.""" @@ -26,8 +39,9 @@ def _id2offset(self) -> Dict[str, int]: :return: a Python dict. """ - if not hasattr(self, '_id_to_index'): + if self._needs_id2offset_rebuild: self._rebuild_id2offset() + return self._id_to_index def _rebuild_id2offset(self) -> None: @@ -40,12 +54,17 @@ def _rebuild_id2offset(self) -> None: d.id: i for i, d in enumerate(self._data) } # type: Dict[str, int] + self._needs_id2offset_rebuild = False + + @needs_id2offset_rebuild def _init_storage( self, _docs: Optional['DocumentArraySourceType'] = None, copy: bool = False ): from ... import DocumentArray + from ...memory import DocumentArrayInMemory self._data = [] + self._id_to_index = {} if _docs is None: return elif isinstance( @@ -53,13 +72,14 @@ def _init_storage( ): if copy: self._data = [Document(d, copy=True) for d in _docs] - self._rebuild_id2offset() elif isinstance(_docs, DocumentArray): self._data = _docs._data - self._id_to_index = _docs._id2offset else: self._data = list(_docs) - self._rebuild_id2offset() + + if isinstance(_docs, DocumentArrayInMemory): + self._id_to_index = _docs._id2offset + self._needs_id2offset_rebuild = _docs._needs_id2offset_rebuild else: if isinstance(_docs, Document): if copy: diff --git a/docarray/array/storage/memory/getsetdel.py b/docarray/array/storage/memory/getsetdel.py index 8ea62471cce..1a33b28e750 100644 --- a/docarray/array/storage/memory/getsetdel.py +++ b/docarray/array/storage/memory/getsetdel.py @@ -4,7 +4,7 @@ Iterable, Any, ) - +from ..memory.backend import needs_id2offset_rebuild from ..base.getsetdel import BaseGetSetDelMixin from .... import Document @@ -12,24 +12,23 @@ class GetSetDelMixin(BaseGetSetDelMixin): """Implement required and derived functions that power `getitem`, `setitem`, `delitem`""" + @needs_id2offset_rebuild def _del_docs_by_mask(self, mask: Sequence[bool]): self._data = list(itertools.compress(self._data, (not _i for _i in mask))) - self._rebuild_id2offset() def _del_all_docs(self): self._data.clear() self._id2offset.clear() + @needs_id2offset_rebuild def _del_docs_by_slice(self, _slice: slice): del self._data[_slice] - self._rebuild_id2offset() def _del_doc_by_id(self, _id: str): - del self._data[self._id2offset[_id]] - self._id2offset.pop(_id) + self._del_doc_by_offset(self._id2offset[_id]) + @needs_id2offset_rebuild def _del_doc_by_offset(self, offset: int): - self._id2offset.pop(self._data[offset].id) del self._data[offset] def _set_doc_by_offset(self, offset: int, value: 'Document'): @@ -41,16 +40,15 @@ def _set_doc_by_id(self, _id: str, value: 'Document'): self._data[old_idx] = value self._id2offset[value.id] = old_idx + @needs_id2offset_rebuild def _set_docs_by_slice(self, _slice: slice, value: Sequence['Document']): self._data[_slice] = value - self._rebuild_id2offset() def _set_doc_value_pairs( self, docs: Iterable['Document'], values: Iterable['Document'] ): for _d, _v in zip(docs, values): _d._data = _v._data - self._rebuild_id2offset() def _set_doc_attr_by_offset(self, offset: int, attr: str, value: Any): setattr(self._data[offset], attr, value) diff --git a/docarray/array/storage/memory/seqlike.py b/docarray/array/storage/memory/seqlike.py index 0121780fadf..633d7c1fd66 100644 --- a/docarray/array/storage/memory/seqlike.py +++ b/docarray/array/storage/memory/seqlike.py @@ -2,10 +2,13 @@ from .... import Document +from ..memory.backend import needs_id2offset_rebuild + class SequenceLikeMixin(MutableSequence[Document]): """Implement sequence-like methods""" + @needs_id2offset_rebuild def insert(self, index: int, value: 'Document'): """Insert `doc` at `index`. @@ -13,7 +16,6 @@ def insert(self, index: int, value: 'Document'): :param value: The doc needs to be inserted. """ self._data.insert(index, value) - self._id2offset[value.id] = index def __eq__(self, other): return ( @@ -56,5 +58,7 @@ def __add__(self, other: Union['Document', Sequence['Document']]): return v def extend(self, values: Iterable['Document']) -> None: + values = list(values) # consume the iterator only once + last_idx = len(self._id2offset) self._data.extend(values) - self._rebuild_id2offset() + self._id_to_index.update({d.id: i + last_idx for i, d in enumerate(values)}) diff --git a/test.py b/test.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/unit/array/mixins/test_del.py b/tests/unit/array/mixins/test_del.py new file mode 100644 index 00000000000..45d2b05402a --- /dev/null +++ b/tests/unit/array/mixins/test_del.py @@ -0,0 +1,42 @@ +import pytest + +from docarray import DocumentArray, Document + + +@pytest.fixture() +def docs(): + return DocumentArray([Document(id=f'{i}') for i in range(1, 10)]) + + +@pytest.mark.parametrize( + 'to_delete', + [ + 0, + 1, + 4, + -1, + list(range(1, 4)), + [2, 4, 7, 1, 1], + slice(0, 2), + slice(2, 4), + slice(4, -1), + [True, True, False], + ..., + ], +) +def test_del_all(docs, to_delete): + doc_to_delete = docs[to_delete] + del docs[to_delete] + assert doc_to_delete not in docs + + +@pytest.mark.parametrize( + ['deleted_ids', 'expected_ids'], + [ + (['1', '2', '3', '4'], ['5', '6', '7', '8', '9']), + (['2', '4', '7', '1'], ['3', '5', '6', '8', '9']), + ], +) +def test_del_by_multiple_idx(docs, deleted_ids, expected_ids): + del docs[deleted_ids] + assert docs[:, 'id'] == expected_ids diff --git a/tests/unit/array/test_sequence.py b/tests/unit/array/test_sequence.py index 6ba6679f936..d846161c2e8 100644 --- a/tests/unit/array/test_sequence.py +++ b/tests/unit/array/test_sequence.py @@ -7,11 +7,13 @@ def test_insert(da_cls): da = da_cls() assert not len(da) - da.insert(0, Document(text='hello')) - da.insert(0, Document(text='world')) + da.insert(0, Document(text='hello', id="0")) + da.insert(0, Document(text='world', id="1")) assert len(da) == 2 assert da[0].text == 'world' assert da[1].text == 'hello' + assert da["1"].text == 'world' + assert da["0"].text == 'hello' @pytest.mark.parametrize('da_cls', [DocumentArray])