diff --git a/docarray/array/storage/base/getsetdel.py b/docarray/array/storage/base/getsetdel.py index be10730ce74..3b38658e107 100644 --- a/docarray/array/storage/base/getsetdel.py +++ b/docarray/array/storage/base/getsetdel.py @@ -235,6 +235,10 @@ def _set_doc_attr_by_offset(self, offset: int, attr: str, value: Any): :param attr: the attribute of document to update :param value: the value doc's attr will be updated to """ + if attr == 'id' and value is None: + raise ValueError( + 'setting the ID of a Document stored in a DocumentArray to None is not allowed' + ) _id = self._offset2ids.get_id(offset) d = self._get_doc_by_id(_id) if hasattr(d, attr): diff --git a/docarray/array/storage/memory/backend.py b/docarray/array/storage/memory/backend.py index 804320343f9..773289ec621 100644 --- a/docarray/array/storage/memory/backend.py +++ b/docarray/array/storage/memory/backend.py @@ -1,7 +1,10 @@ +import functools from typing import ( Optional, TYPE_CHECKING, Iterable, + Callable, + Dict, ) from ..base.backend import BaseBackendMixin @@ -13,9 +16,44 @@ ) +def needs_id2offset_rebuild(func) -> Callable: + # self._id2offset needs to be rebuilt after every insert or delete + # this flag allows to do it lazily and cache the result + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + self._needs_id2offset_rebuild = True + return func(self, *args, **kwargs) + + return wrapper + + class BackendMixin(BaseBackendMixin): """Provide necessary functions to enable this storage backend.""" + @property + def _id2offset(self) -> Dict[str, int]: + """Return the `_id_to_index` map + + :return: a Python dict. + """ + if self._needs_id2offset_rebuild: + self._rebuild_id2offset() + + return self._id_to_index + + def _rebuild_id2offset(self) -> None: + """Update the id_to_index map by enumerating all Documents in self._data. + + Very costly! Only use this function when self._data is dramatically changed. 
+ """ + + self._id_to_index = { + d.id: i for i, d in enumerate(self._data) + } # type: Dict[str, int] + + self._needs_id2offset_rebuild = False + + @needs_id2offset_rebuild def _init_storage( self, _docs: Optional['DocumentArraySourceType'] = None, @@ -23,9 +61,12 @@ def _init_storage( *args, **kwargs ): + from docarray.array.memory import DocumentArrayInMemory + super()._init_storage(_docs, copy=copy, *args, **kwargs) - self._data = {} + self._data = [] + self._id_to_index = {} if _docs is None: return elif isinstance( @@ -33,8 +74,11 @@ def _init_storage( Iterable, ): if copy: - for doc in _docs: - self.append(Document(doc, copy=True)) + self._data = [Document(d, copy=True) for d in _docs] + elif isinstance(_docs, DocumentArrayInMemory): + self._data = _docs._data + self._id_to_index = _docs._id2offset + self._needs_id2offset_rebuild = _docs._needs_id2offset_rebuild else: self.extend(_docs) else: diff --git a/docarray/array/storage/memory/getsetdel.py b/docarray/array/storage/memory/getsetdel.py index 55f7865499d..993a5e1dce5 100644 --- a/docarray/array/storage/memory/getsetdel.py +++ b/docarray/array/storage/memory/getsetdel.py @@ -1,43 +1,75 @@ +import itertools from typing import ( Sequence, Iterable, + Any, ) from ..base.getsetdel import BaseGetSetDelMixin -from ..base.helper import Offset2ID +from ..memory.backend import needs_id2offset_rebuild from .... 
import Document class GetSetDelMixin(BaseGetSetDelMixin): """Implement required and derived functions that power `getitem`, `setitem`, `delitem`""" + @needs_id2offset_rebuild + def _del_docs_by_mask(self, mask: Sequence[bool]): + self._data = list(itertools.compress(self._data, (not _i for _i in mask))) + + @needs_id2offset_rebuild + def _del_docs_by_slice(self, _slice: slice): + del self._data[_slice] + def _del_doc_by_id(self, _id: str): - del self._data[_id] + self._del_doc_by_offset(self._id2offset[_id]) + + @needs_id2offset_rebuild + def _del_doc_by_offset(self, offset: int): + del self._data[offset] + + def _set_doc_by_offset(self, offset: int, value: 'Document'): + old_id = self._data[offset].id + self._id2offset[value.id] = offset + self._data[offset] = value + self._id2offset.pop(old_id) def _set_doc_by_id(self, _id: str, value: 'Document'): - if _id != value.id: - del self._data[_id] - self._data[value.id] = value + old_idx = self._id2offset.pop(_id) + self._data[old_idx] = value + self._id2offset[value.id] = old_idx - def _set_doc_value_pairs( - self, docs: Iterable['Document'], values: Sequence['Document'] - ): - docs = list(docs) + @needs_id2offset_rebuild + def _set_docs_by_slice(self, _slice: slice, value: Sequence['Document']): + self._data[_slice] = value - for _d, _v in zip(docs, values): - _d._data = _v._data + def _set_doc_attr_by_offset(self, offset: int, attr: str, value: Any): + if attr == 'id' and value is None: + raise ValueError( + 'setting the ID of a Document stored in a DocumentArray to None is not allowed' + ) + + setattr(self._data[offset], attr, value) + + def _get_doc_by_offset(self, offset: int) -> 'Document': + return self._data[offset] def _get_doc_by_id(self, _id: str) -> 'Document': - return self._data[_id] + return self._data[self._id2offset[_id]] - def _get_docs_by_ids(self, ids: Sequence[str]) -> Iterable['Document']: - return (self._data[_id] for _id in ids) + def _get_docs_by_slice(self, _slice: slice) -> 
Iterable['Document']: + return self._data[_slice] def _clear_storage(self): self._data.clear() + self._id2offset.clear() def _load_offset2ids(self): - self._offset2ids = Offset2ID() + ... def _save_offset2ids(self): ... + + _set_doc = _set_doc_by_id + _del_doc = _del_doc_by_id + _del_all_docs = _clear_storage diff --git a/docarray/array/storage/memory/seqlike.py b/docarray/array/storage/memory/seqlike.py index e210f521e1a..061382ab910 100644 --- a/docarray/array/storage/memory/seqlike.py +++ b/docarray/array/storage/memory/seqlike.py @@ -1,4 +1,6 @@ -from typing import Union, Iterable +from typing import Union, Iterable, MutableSequence, Iterator + +from ..memory.backend import needs_id2offset_rebuild from ..base.seqlike import BaseSequenceLikeMixin from .... import Document @@ -7,19 +9,42 @@ class SequenceLikeMixin(BaseSequenceLikeMixin): """Implement sequence-like methods""" + @needs_id2offset_rebuild + def insert(self, index: int, value: 'Document'): + """Insert `doc` at `index`. + + :param index: Position of the insertion. + :param value: The doc needs to be inserted. + """ + self._data.insert(index, value) + + def append(self, value: 'Document'): + """Append `doc` to the end of the array. + + :param value: The doc needs to be appended. 
+ """ + self._data.append(value) + if not self._needs_id2offset_rebuild: + self._id_to_index[value.id] = len(self) - 1 + def __eq__(self, other): return ( type(self) is type(other) and type(self._data) is type(other._data) and self._data == other._data - and self._offset2ids == other._offset2ids ) + def __len__(self): + return len(self._data) + + def __iter__(self) -> Iterator['Document']: + yield from self._data + def __contains__(self, x: Union[str, 'Document']): if isinstance(x, str): - return x in self._data + return x in self._id2offset elif isinstance(x, Document): - return x.id in self._data + return x.id in self._id2offset else: return False @@ -30,3 +55,9 @@ def __add__(self, other: Union['Document', Iterable['Document']]): v = type(self)(self) v.extend(other) return v + + def extend(self, values: Iterable['Document']) -> None: + values = list(values) # consume the iterator only once + last_idx = len(self._id2offset) + self._data.extend(values) + self._id_to_index.update({d.id: i + last_idx for i, d in enumerate(values)}) diff --git a/tests/unit/array/mixins/test_traverse.py b/tests/unit/array/mixins/test_traverse.py index c983453fa74..0a128551de2 100644 --- a/tests/unit/array/mixins/test_traverse.py +++ b/tests/unit/array/mixins/test_traverse.py @@ -736,3 +736,27 @@ def test_traverse_flat_offset(): assert len(flat_docs) == 2 assert flat_docs[0].id == 'r2c1' assert flat_docs[1].id == 'r2c2' + + +def test_traverse_flat_conflicting_ids(): + da = DocumentArray( + [ + Document( + id=f'r{i}', + chunks=[Document(id=f'rc{j}') for j in range(3)], + matches=[Document(id=f'rm{j}') for j in range(3)], + ) + for i in range(3) + ] + ) + + for traversal_path in ['rc', 'rm']: + + flattened = da.traverse_flat(traversal_path) + assert len(flattened) == 9 + child_ids = set() + + for child in flattened: + child_ids.add(id(child)) + + assert len(child_ids) == 9 diff --git a/tests/unit/array/test_advance_indexing.py b/tests/unit/array/test_advance_indexing.py index 
bad7725606e..87b37a7ee8d 100644 --- a/tests/unit/array/test_advance_indexing.py +++ b/tests/unit/array/test_advance_indexing.py @@ -635,3 +635,26 @@ def test_offset2ids_persistence(storage, config, start_storage): da = DocumentArray(storage=storage, config=config) assert da[:, 'id'] == da_ids + + +def test_dam_conflicting_ids(): + docs = [ + Document(id='1'), + Document(id='2'), + Document(id='3'), + ] + + d = Document(id='1') + da = DocumentArray() + da.extend(docs) + da.append(d) + + assert len(da) == 4 + assert id(da[0]) == id(docs[0]) + assert id(da[3]) == id(d) + + da[0].text = 'd1' + da[3].text = 'd2' + + assert docs[0].text == 'd1' + assert d.text == 'd2' diff --git a/tests/unit/array/test_base_getsetdel.py b/tests/unit/array/test_base_getsetdel.py index 146296fad76..c0ef9f2a4f6 100644 --- a/tests/unit/array/test_base_getsetdel.py +++ b/tests/unit/array/test_base_getsetdel.py @@ -1,10 +1,10 @@ from abc import ABC +from typing import Iterable, Sequence import pytest from docarray import DocumentArray, Document from docarray.array.storage.base.getsetdel import BaseGetSetDelMixin -from docarray.array.storage.base.helper import Offset2ID from docarray.array.storage.memory import BackendMixin, SequenceLikeMixin @@ -14,18 +14,33 @@ class DummyGetSetDelMixin(BaseGetSetDelMixin): # essentials def _del_doc_by_id(self, _id: str): - del self._data[_id] + del self._data[self._id2offset[_id]] + self._id2offset.pop(_id) + + def _del_doc_by_offset(self, offset: int): + self._id2offset.pop(self._data[offset].id) + del self._data[offset] def _set_doc_by_id(self, _id: str, value: 'Document'): - if _id != value.id: - del self._data[_id] - self._data[value.id] = value + old_idx = self._id2offset.pop(_id) + self._data[old_idx] = value + self._id2offset[value.id] = old_idx + + def _get_doc_by_offset(self, offset: int) -> 'Document': + return self._data[offset] def _get_doc_by_id(self, _id: str) -> 'Document': - return self._data[_id] + return self._data[self._id2offset[_id]] + + 
def _set_doc_by_offset(self, offset: int, value: 'Document'): + self._data[offset] = value + self._id2offset[value.id] = offset - def _clear_storage(self): - self._data.clear() + def _get_docs_by_slice(self, _slice: slice) -> Iterable['Document']: + return self._data[_slice] + + def _set_docs_by_slice(self, _slice: slice, value: Sequence['Document']): + self._data[_slice] = value class StorageMixins(BackendMixin, DummyGetSetDelMixin, SequenceLikeMixin, ABC): @@ -37,7 +52,7 @@ def __new__(cls, *args, **kwargs): return super().__new__(cls) def _load_offset2ids(self): - self._offset2ids = Offset2ID() + pass def _save_offset2ids(self): pass diff --git a/tests/unit/document/test_feature_hashing.py b/tests/unit/document/test_feature_hashing.py index ec112c5c5fd..5702678ac8f 100644 --- a/tests/unit/document/test_feature_hashing.py +++ b/tests/unit/document/test_feature_hashing.py @@ -21,6 +21,7 @@ def test_feature_hashing(n_dim, sparse, metric): assert da.embeddings.shape == (6, n_dim) da.embeddings = to_numpy_array(da.embeddings) da.match(da, metric=metric, use_scipy=True) - for doc in da: - assert doc.matches[0].scores[metric].value == pytest.approx(0.0) - assert doc.matches[1].scores[metric].value > 0.0 + result = da['@m', ('id', f'scores__{metric}__value')] + assert len(result) == 2 + assert result[1][0] == 0.0 + assert result[1][1] > 0.0