diff --git a/docarray/array/mixins/find.py b/docarray/array/mixins/find.py index 1e0a58f765d..e028f7c5550 100644 --- a/docarray/array/mixins/find.py +++ b/docarray/array/mixins/find.py @@ -1,15 +1,13 @@ import abc -from typing import overload, Optional, Union, Dict, List, Tuple, Callable, TYPE_CHECKING +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union, overload import numpy as np - from docarray.math import ndarray from docarray.score import NamedScore if TYPE_CHECKING: # pragma: no cover - from docarray.typing import T, ArrayType - from docarray import Document, DocumentArray + from docarray.typing import ArrayType, T class FindMixin: @@ -99,6 +97,7 @@ def find( filter: Union[Dict, str, None] = None, only_id: bool = False, index: str = 'text', + return_root: Optional[bool] = False, on: Optional[str] = None, **kwargs, ) -> Union['DocumentArray', List['DocumentArray']]: @@ -126,14 +125,17 @@ def find( parameter is ignored. By default, the Document `text` attribute will be used for search, otherwise the tag field specified by `index` will be used. You can only use this parameter if the storage backend supports searching by text. + :param return_root: if set, then the root-level DocumentArray will be returned :param on: specifies a subindex to search on. If set, the returned DocumentArray will be retrieved from the given subindex. :param kwargs: other kwargs. :return: a list of DocumentArrays containing the closest Document objects for each of the queries in `query`. """ + from docarray import Document, DocumentArray + index_da = self._get_index(subindex_name=on) if index_da is not self: - return index_da.find( + results = index_da.find( query, metric, limit, @@ -144,7 +146,15 @@ def find( index, on=None, ) - from docarray import Document, DocumentArray + + if return_root: + da = self._get_root_docs(results) + for d, s in zip(da, results[:, 'scores']): + d.scores = s + + return da + + return results if isinstance(query, dict): if filter is None: @@ -301,3 +311,15 @@ def _find_by_text(self, *args, **kwargs): raise NotImplementedError( f'Search by text is not supported with this backend {self.__class__.__name__}' ) + + def _get_root_docs(self, docs: 'DocumentArray') -> 'DocumentArray': + """Get the root documents of the current DocumentArray. + + :return: a `DocumentArray` containing the root documents. + """ + + if not all(docs[:, 'tags___root_id_']): + raise ValueError( + f'Not all Documents in this subindex have the "_root_id_" attribute set in all `tags`.' + ) + return self[docs[:, 'tags___root_id_']] diff --git a/docarray/array/mixins/setitem.py b/docarray/array/mixins/setitem.py index 33968402c36..18832893567 100644 --- a/docarray/array/mixins/setitem.py +++ b/docarray/array/mixins/setitem.py @@ -63,6 +63,10 @@ def __setitem__( index: 'DocumentArrayIndexType', value: Union['Document', Sequence['Document']], ): + from docarray.helper import check_root_id + + if self._is_subindex: + check_root_id(self, value) self._update_subindices_set(index, value) # set by offset diff --git a/docarray/array/storage/annlite/backend.py b/docarray/array/storage/annlite/backend.py index 477741e49ab..49fc493be9f 100644 --- a/docarray/array/storage/annlite/backend.py +++ b/docarray/array/storage/annlite/backend.py @@ -31,6 +31,7 @@ class AnnliteConfig: max_connection: Optional[int] = None n_components: Optional[int] = None columns: Optional[Union[List[Tuple[str, str]], Dict[str, str]]] = None + root_id: bool = True class BackendMixin(BaseBackendMixin): @@ -104,7 +105,7 @@ def _init_storage( self._annlite = AnnLite(self.n_dim, lock=False, **filter_dict(config)) - super()._init_storage() + super()._init_storage(**kwargs) if _docs is None: return diff --git a/docarray/array/storage/annlite/seqlike.py b/docarray/array/storage/annlite/seqlike.py index 93628ab3cbd..4c5a7fff6f0 100644 --- a/docarray/array/storage/annlite/seqlike.py +++ b/docarray/array/storage/annlite/seqlike.py @@ -20,7 +20,7 @@ def _extend(self, values: Iterable['Document']) -> None: self._offset2ids.extend([doc.id for doc in docs]) def _append(self, value: 'Document'): - self.extend([value]) + self._extend([value]) def __eq__(self, other): """In annlite backend, data are considered as identical if configs point to the same database source""" diff --git a/docarray/array/storage/base/backend.py b/docarray/array/storage/base/backend.py index 7105b568b3a..effc784b9ec 100644 --- a/docarray/array/storage/base/backend.py +++ b/docarray/array/storage/base/backend.py @@ -17,9 +17,11 @@ def _init_storage( self, _docs: Optional['DocumentArraySourceType'] = None, copy: bool = False, + _is_subindex: bool = False, *args, **kwargs, ): + self._is_subindex = _is_subindex self._load_offset2ids() def _init_subindices( @@ -40,7 +42,9 @@ def _init_subindices( config_joined = self._ensure_unique_config( config, config_subindex, config_joined, name ) - self._subindices[name] = self.__class__(config=config_joined) + self._subindices[name] = self.__class__( + config=config_joined, _is_subindex=True + ) if _docs: from docarray import DocumentArray diff --git a/docarray/array/storage/base/getsetdel.py b/docarray/array/storage/base/getsetdel.py index 0c672fde963..97ce50f9a71 100644 --- a/docarray/array/storage/base/getsetdel.py +++ b/docarray/array/storage/base/getsetdel.py @@ -200,13 +200,25 @@ def _update_subindices_set(self, set_index, docs): _check_valid_values_nested_set(self[set_index], docs) if set_index in subindices: subindex_da = subindices[set_index] + subindex_da.clear() subindex_da.extend(docs) else: # root level set, update subindices iteratively for subindex_selector, subindex_da in subindices.items(): old_ids = DocumentArray(self[set_index])[subindex_selector, 'id'] del subindex_da[old_ids] - subindex_da.extend(DocumentArray(docs)[subindex_selector]) + + value = DocumentArray(docs) + + if ( + getattr(subindex_da, '_config', None) # checks if in-memory da + and subindex_da._config.root_id + ): + for v in value: + for doc in DocumentArray(v)[subindex_selector]: + doc.tags['_root_id_'] = v.id + + subindex_da.extend(value[subindex_selector]) def _set_docs(self, ids, docs: Iterable['Document']): docs = list(docs) diff --git a/docarray/array/storage/base/seqlike.py b/docarray/array/storage/base/seqlike.py index d5ef0ebc50c..5e46cafe607 100644 --- a/docarray/array/storage/base/seqlike.py +++ b/docarray/array/storage/base/seqlike.py @@ -1,5 +1,6 @@ +import warnings from abc import abstractmethod -from typing import Iterator, Iterable, MutableSequence +from typing import Iterable, Iterator, MutableSequence from docarray import Document, DocumentArray @@ -10,7 +11,15 @@ class BaseSequenceLikeMixin(MutableSequence[Document]): def _update_subindices_append_extend(self, value): if getattr(self, '_subindices', None): for selector, da in self._subindices.items(): - docs_selector = DocumentArray(value)[selector] + + value = DocumentArray(value) + + if getattr(da, '_config', None) and da._config.root_id: + for v in value: + for doc in DocumentArray(v)[selector]: + doc.tags['_root_id_'] = v.id + + docs_selector = value[selector] if len(docs_selector) > 0: da.extend(docs_selector) @@ -63,6 +72,12 @@ def __bool__(self): return len(self) > 0 def extend(self, values: Iterable['Document'], **kwargs) -> None: + + from docarray.helper import check_root_id + + if self._is_subindex: + check_root_id(self, values) + self._extend(values, **kwargs) self._update_subindices_append_extend(values) diff --git a/docarray/array/storage/elastic/backend.py b/docarray/array/storage/elastic/backend.py index d71d5e3bde2..1513987a4a5 100644 --- a/docarray/array/storage/elastic/backend.py +++ b/docarray/array/storage/elastic/backend.py @@ -46,6 +46,7 @@ class ElasticConfig: ef_construction: Optional[int] = None m: Optional[int] = None columns: Optional[Union[List[Tuple[str, str]], Dict[str, str]]] = None + root_id: bool = True _banned_indexname_chars = ['[', ' ', '"', '*', '\\', '<', '|', ',', '>', '/', '?', ']'] @@ -100,7 +101,7 @@ def _init_storage( self._build_offset2id_index() # Note super()._init_storage() calls _load_offset2ids which calls _get_offset2ids_meta - super()._init_storage() + super()._init_storage(**kwargs) if _docs is None: return diff --git a/docarray/array/storage/memory/find.py b/docarray/array/storage/memory/find.py index 9cbd07dfc32..ccf1855fa90 100644 --- a/docarray/array/storage/memory/find.py +++ b/docarray/array/storage/memory/find.py @@ -180,3 +180,19 @@ def _get_dist(da: 'DocumentArray'): idx = np.take_along_axis(top_inds, permutation, axis=1) return dist, idx + + def _get_root_docs(self, docs: 'DocumentArray') -> 'DocumentArray': + """Get the root documents of the current DocumentArray. + + :return: a `DocumentArray` containing the root documents. + """ + from docarray import DocumentArray + + root_da_flat = self[...] + da = DocumentArray() + for doc in docs: + result = doc + while getattr(result, 'parent_id', None): + result = root_da_flat[result.parent_id] + da.append(result) + return da diff --git a/docarray/array/storage/milvus/backend.py b/docarray/array/storage/milvus/backend.py index 35154ca3c66..e09a24c093b 100644 --- a/docarray/array/storage/milvus/backend.py +++ b/docarray/array/storage/milvus/backend.py @@ -93,6 +93,7 @@ class MilvusConfig: batch_size: int = -1 columns: Optional[Union[List[Tuple[str, str]], Dict[str, str]]] = None list_like: bool = True + root_id: bool = True class BackendMixin(BaseBackendMixin): @@ -134,7 +135,7 @@ def _init_storage( self._collection = self._create_or_reuse_collection() self._offset2id_collection = self._create_or_reuse_offset2id_collection() self._build_index() - super()._init_storage() + super()._init_storage(**kwargs) # To align with Sqlite behavior; if `docs` is not `None` and table name # is provided, :class:`DocumentArraySqlite` will clear the existing diff --git a/docarray/array/storage/qdrant/backend.py b/docarray/array/storage/qdrant/backend.py index c67847c57b8..13d27bb72ee 100644 --- a/docarray/array/storage/qdrant/backend.py +++ b/docarray/array/storage/qdrant/backend.py @@ -51,6 +51,7 @@ class QdrantConfig: full_scan_threshold: Optional[int] = None m: Optional[int] = None columns: Optional[Union[List[Tuple[str, str]], Dict[str, str]]] = None + root_id: bool = True class BackendMixin(BaseBackendMixin): @@ -128,7 +129,7 @@ def _init_storage( self._initialize_qdrant_schema() - super()._init_storage() + super()._init_storage(**kwargs) if docs is None and config.collection_name: return diff --git a/docarray/array/storage/redis/backend.py b/docarray/array/storage/redis/backend.py index 1fbcc0dd9c1..bb72fdab405 100644 --- a/docarray/array/storage/redis/backend.py +++ b/docarray/array/storage/redis/backend.py @@ -35,6 +35,7 @@ class RedisConfig: block_size: Optional[int] = None initial_cap: Optional[int] = None columns: Optional[Union[List[Tuple[str, str]], Dict[str, str]]] = None + root_id: bool = True class BackendMixin(BaseBackendMixin): @@ -87,7 +88,7 @@ def _init_storage( self._client = self._build_client() self._build_index() - super()._init_storage() + super()._init_storage(**kwargs) if _docs is None: return diff --git a/docarray/array/storage/sqlite/backend.py b/docarray/array/storage/sqlite/backend.py index 688158f0f91..5c98ef1e93b 100644 --- a/docarray/array/storage/sqlite/backend.py +++ b/docarray/array/storage/sqlite/backend.py @@ -29,6 +29,7 @@ class SqliteConfig: conn_config: Dict = field(default_factory=dict) journal_mode: str = 'WAL' synchronous: str = 'OFF' + root_id: bool = True class BackendMixin(BaseBackendMixin): @@ -101,7 +102,7 @@ def _init_storage( self._connection.commit() self._config = config self._list_like = config.list_like - super()._init_storage() + super()._init_storage(**kwargs) if _docs is None: return diff --git a/docarray/array/storage/weaviate/backend.py b/docarray/array/storage/weaviate/backend.py index c16eee41b38..57c5c811740 100644 --- a/docarray/array/storage/weaviate/backend.py +++ b/docarray/array/storage/weaviate/backend.py @@ -52,6 +52,7 @@ class WeaviateConfig: # weaviate python client parameters batch_size: Optional[int] = field(default=50) dynamic_batching: Optional[bool] = field(default=False) + root_id: bool = True def __post_init__(self): if isinstance(self.timeout_config, list): diff --git a/docarray/helper.py b/docarray/helper.py index b5858724517..6d5b676ada6 100644 --- a/docarray/helper.py +++ b/docarray/helper.py @@ -12,7 +12,7 @@ import hubble if TYPE_CHECKING: # pragma: no cover - from docarray import DocumentArray + from docarray import Document, DocumentArray __resources_path__ = os.path.join( os.path.dirname( @@ -491,6 +491,35 @@ def _get_array_info(da: 'DocumentArray'): return is_homo, _nested_in, _nested_items, attr_counter, all_attrs_names +def check_root_id(da: 'DocumentArray', value: Union['Document', Sequence['Document']]): + + from docarray import Document + from docarray.array.memory import DocumentArrayInMemory + + if not ( + isinstance(value, Document) + or (isinstance(value, Sequence) and isinstance(value[0], Document)) + ): + return + + if isinstance(value, Document): + value = [value] + + if isinstance(da, DocumentArrayInMemory): + if not all([getattr(doc, 'parent_id', None) for doc in value]): + warnings.warn( + "Not all documents have parent_id set. This may cause unexpected behavior.", + UserWarning, + ) + elif da._config.root_id and not all( + [doc.tags.get('_root_id_', None) for doc in value] + ): + warnings.warn( + "root_id is enabled but not all documents have _root_id_ set. This may cause unexpected behavior.", + UserWarning, + ) + + def login(interactive: Optional[bool] = None, force: bool = False, **kwargs): """Login to Jina AI Cloud account. :param interactive: If set to true, login will support notebook environments, otherwise the enviroment will be inferred. diff --git a/docs/advanced/document-store/annlite.md b/docs/advanced/document-store/annlite.md index 922b4938306..c764a157b93 100644 --- a/docs/advanced/document-store/annlite.md +++ b/docs/advanced/document-store/annlite.md @@ -48,6 +48,7 @@ The following configs can be set: | `max_connection` | The number of bi-directional links created for every new element during construction. | `None`, defaults to the default value in the AnnLite package* | | `n_components` | The output dimension of PCA model. Should be a positive number and less than `n_dim` if it's not `None` | `None`, defaults to the default value in the AnnLite package* | | `list_like` | Controls if ordering of Documents is persisted in the Database. Disabling this breaks list-like features, but can improve performance. | True | +| `root_id` | Boolean flag indicating whether to store `root_id` in the tags of chunk level Documents | True | *You can check the default values in [the AnnLite source code](https://github.com/jina-ai/annlite/blob/main/annlite/core/index/hnsw/index.py) diff --git a/docs/advanced/document-store/elasticsearch.md b/docs/advanced/document-store/elasticsearch.md index b55e3ba3172..20af11fd6a7 100644 --- a/docs/advanced/document-store/elasticsearch.md +++ b/docs/advanced/document-store/elasticsearch.md @@ -404,6 +404,7 @@ The following configs can be set: | `tag_indices` | List of tags to index | False | | `batch_size` | Batch size used to handle storage refreshes/updates | 64 | | `list_like` | Controls if ordering of Documents is persisted in the Database. Disabling this breaks list-like features, but can improve performance. | True | +| `root_id` | Boolean flag indicating whether to store `root_id` in the tags of chunk level Documents | True | ```{tip} You can read more about HNSW parameters and their default values [here](https://www.elastic.co/guide/en/elasticsearch/reference/current/dense-vector.html#dense-vector-params) diff --git a/docs/advanced/document-store/milvus.md b/docs/advanced/document-store/milvus.md index a93c9ccaed5..6fdff41eb0f 100644 --- a/docs/advanced/document-store/milvus.md +++ b/docs/advanced/document-store/milvus.md @@ -128,10 +128,11 @@ The following configs can be set: | `index_params` | A dictionary of parameters used for index building. The [allowed parameters](https://milvus.io/docs/v2.1.x/index.md) depend on the index type. | {'M': 4, 'efConstruction': 200} (assumes HNSW index) | | `collection_config` | Configuration for the Milvus collection. Passed as **kwargs during collection creation (`Collection(...)`). | {} | | `serialize_config` | [Serialization config of each Document](../../../fundamentals/document/serialization.md) | {} | - | `consistency_level` | [Consistency level](https://milvus.io/docs/v2.1.x/consistency.md#Consistency-levels) for Milvus database operations. Can be 'Session', 'Strong', 'Bounded' or 'Eventually'. | 'Session' | +| `consistency_level` | [Consistency level](https://milvus.io/docs/v2.1.x/consistency.md#Consistency-levels) for Milvus database operations. Can be 'Session', 'Strong', 'Bounded' or 'Eventually'. | 'Session' | | `batch_size` | Default batch size for CRUD operations. | -1 (no batching) | | `columns` | Additional columns to be stored in the datbase, taken from Document `tags`. | None | | `list_like` | Controls if ordering of Documents is persisted in the Database. Disabling this breaks list-like features, but can improve performance. | True | +| `root_id` | Boolean flag indicating whether to store `root_id` in the tags of chunk level Documents | True | ## Minimal example diff --git a/docs/advanced/document-store/qdrant.md b/docs/advanced/document-store/qdrant.md index b7438f5499d..179fde41596 100644 --- a/docs/advanced/document-store/qdrant.md +++ b/docs/advanced/document-store/qdrant.md @@ -94,6 +94,8 @@ The following configs can be set: | `m` | Number of edges per node in the index graph. Larger = more accurate search, more space required | `None`, defaults to the default value in Qdrant* | | `columns` | Other fields to store in Document | `None` | | `list_like` | Controls if ordering of Documents is persisted in the Database. Disabling this breaks list-like features, but can improve performance. | True | +| `root_id` | Boolean flag indicating whether to store `root_id` in the tags of chunk level Documents | True | + *You can read more about the HNSW parameters and their default values [here](https://qdrant.tech/documentation/indexing/#vector-index) diff --git a/docs/advanced/document-store/redis.md b/docs/advanced/document-store/redis.md index aee5d43560b..c5267b4e61c 100644 --- a/docs/advanced/document-store/redis.md +++ b/docs/advanced/document-store/redis.md @@ -136,7 +136,8 @@ The following configs can be set: | `block_size` | Optional parameter for Redis FLAT algorithm | `1048576` | | `initial_cap` | Optional parameter for Redis HNSW and FLAT algorithm | `None`, defaults to the default value in Redis | | `columns` | Other fields to store in Document and build schema | `None` | -| `list_like` | Controls if ordering of Documents is persisted in the Database. Disabling this breaks list-like features, but can improve performance. | True | +| `list_like` | Controls if ordering of Documents is persisted in the Database. Disabling this breaks list-like features, but can improve performance. | `True` | +| `root_id` | Boolean flag indicating whether to store `root_id` in the tags of chunk level Documents | `True` | You can check the default values in [the docarray source code](https://github.com/jina-ai/docarray/blob/main/docarray/array/storage/redis/backend.py). For vector search configurations, default values are those of the database backend, which you can find in the [Redis documentation](https://redis.io/docs/stack/search/reference/vectors/). diff --git a/docs/advanced/document-store/sqlite.md b/docs/advanced/document-store/sqlite.md index 6184588d1de..980af482527 100644 --- a/docs/advanced/document-store/sqlite.md +++ b/docs/advanced/document-store/sqlite.md @@ -33,12 +33,13 @@ Other functions behave the same as in-memory DocumentArray. The following configs can be set: -| Name | Description | Default | -|--------------------|------------------------------------------------------------------------------------------------------------------|--| -| `connection` | SQLite database filename | a random temp file | -| `table_name` | SQLite table name | a random name | -| `serialize_config` | [Serialization config of each Document](../../../fundamentals/document/serialization.md) | None | -| `conn_config` | [Connection config pass to `sqlite3.connect`](https://docs.python.org/3/library/sqlite3.html#sqlite3.Connection) | None | -| `journal_mode` | [SQLite Pragma: journal mode](https://www.sqlite.org/pragma.html#pragma_journal_mode) | `'DELETE'` | -| `synchronous` | [SQLite Pragma: synchronous](https://www.sqlite.org/pragma.html#pragma_synchronous) | `'OFF'` | -| `list_like` | Controls if ordering of Documents is persisted in the Database. Disabling this breaks list-like features, but can improve performance. | True | +| Name | Description | Default | +|--------------------|----------------------------------------------------------------------------------------------------------------------------------------|--------------------| +| `connection` | SQLite database filename | a random temp file | +| `table_name` | SQLite table name | a random name | +| `serialize_config` | [Serialization config of each Document](../../../fundamentals/document/serialization.md) | None | +| `conn_config` | [Connection config pass to `sqlite3.connect`](https://docs.python.org/3/library/sqlite3.html#sqlite3.Connection) | None | +| `journal_mode` | [SQLite Pragma: journal mode](https://www.sqlite.org/pragma.html#pragma_journal_mode) | `'DELETE'` | +| `synchronous` | [SQLite Pragma: synchronous](https://www.sqlite.org/pragma.html#pragma_synchronous) | `'OFF'` | +| `list_like` | Controls if ordering of Documents is persisted in the Database. Disabling this breaks list-like features, but can improve performance. | True | +| `root_id` | Boolean flag indicating whether to store `root_id` in the tags of chunk level Documents | True | diff --git a/docs/advanced/document-store/weaviate.md b/docs/advanced/document-store/weaviate.md index 3563a9232cd..a60c2268ed2 100644 --- a/docs/advanced/document-store/weaviate.md +++ b/docs/advanced/document-store/weaviate.md @@ -103,7 +103,9 @@ The following configs can be set: | `flat_search_cutoff` | Absolute number of objects configured as the threshold for a flat-search cutoff. If a filter on a filtered vector search matches fewer than the specified elements, the HNSW index is bypassed entirely and a flat (brute-force) search is performed instead. This can speed up queries with very restrictive filters considerably. Optional, defaults to 40000. Set to 0 to turn off flat-search cutoff entirely. | `None`, defaults to the default value in Weaviate* | | `cleanup_interval_seconds` | How often the async process runs that “repairs” the HNSW graph after deletes and updates. (Prior to the repair/cleanup process, deleted objects are simply marked as deleted, but still a fully connected member of the HNSW graph. After the repair has run, the edges are reassigned and the datapoints deleted for good). Typically this value does not need to be adjusted, but if deletes or updates are very frequent it might make sense to adjust the value up or down. (Higher value means it runs less frequently, but cleans up more in a single batch. Lower value means it runs more frequently, but might not be as efficient with each run). | `None`, defaults to the default value in Weaviate* | | `skip` | There are situations where it doesn’t make sense to vectorize a class. For example if the class is just meant as glue between two other class (consisting only of references) or if the class contains mostly duplicate elements (Note that importing duplicate vectors into HNSW is very expensive as the algorithm uses a check whether a candidate’s distance is higher than the worst candidate’s distance for an early exit condition. With (mostly) identical vectors, this early exit condition is never met leading to an exhaustive search on each import or query). In this case, you can skip indexing a vector all-together. To do so, set "skip" to "true". skip defaults to false; if not set to true, classes will be indexed normally. This setting is immutable after class initialization. | `None`, defaults to the default value in Weaviate* | -| `list_like` | Controls if ordering of Documents is persisted in the Database. Disabling this breaks list-like features, but can improve performance. | True | +| `list_like` | Controls if ordering of Documents is persisted in the Database. Disabling this breaks list-like features, but can improve performance. | True | +| `root_id` | Boolean flag indicating whether to store `root_id` in the tags of chunk level Documents | True | + *You can read more about the HNSW parameters and their default values [here](https://weaviate.io/developers/weaviate/current/vector-index-plugins/hnsw.html#how-to-use-hnsw-and-parameters) diff --git a/docs/fundamentals/documentarray/subindex.md b/docs/fundamentals/documentarray/subindex.md index 405f7b0f19b..1091c7d6749 100644 --- a/docs/fundamentals/documentarray/subindex.md +++ b/docs/fundamentals/documentarray/subindex.md @@ -217,18 +217,26 @@ Document(embedding=np.random.rand(512)).match(da, on='@c') ``` ```` -Such a search will return Documents from the subindex. If you are interested in the top-level Documents associated with -a match, you can retrieve them using `parent_id`: +Such a search will return Documents from the subindex. If you are interested in the top-level Documents associated with a match, you can retrieve them by setting `return_root=True` in `find`: ````{tab} Subindex with dataclass modalities ```python -top_image_matches = da.find(query=np.random.rand(512), on='@.[image]') -top_level_matches = da[top_image_matches[:, 'parent_id']] +top_level_matches = da.find(query=np.random.rand(512), on='@.[image]', return_root=True) ``` ```` ````{tab} Subindex with chunks ```python -top_image_matches = da.find(query=np.random.rand(512), on='@c') -top_level_matches = da[top_image_matches[:, 'parent_id']] +top_level_matches = da.find(query=np.random.rand(512), on='@c', return_root=True) +``` +```` + +````{admonition} Note +:class: note +When you add or change Documents directly on a subindex, the `_root_id_` (or `parent_id` for DocumentArrayInMemory) of new Documents should be set manually for `return_root=True` to work: + +```python +da['@c'].extend( + Document(embedding=np.random.random(512), tags={'_root_id_': 'your_root_id'}) +) ``` ```` diff --git a/tests/unit/array/mixins/test_find.py b/tests/unit/array/mixins/test_find.py index 47a66b44d76..0c87ae34858 100644 --- a/tests/unit/array/mixins/test_find.py +++ b/tests/unit/array/mixins/test_find.py @@ -987,3 +987,121 @@ class MMDoc: assert (closest_docs[0].embedding == np.array([3, 3])).all() for d in closest_docs: assert d.id.endswith('_2') + + +@pytest.mark.parametrize( + 'storage, config, subindex_configs', + [ + ('memory', None, {'@c': None}), + ( + 'weaviate', + { + 'n_dim': 3, + }, + {'@c': {'n_dim': 3}}, + ), + ('annlite', {'n_dim': 3}, {'@c': {'n_dim': 3}}), + ('sqlite', dict(), {'@c': dict()}), + ('qdrant', {'n_dim': 3}, {'@c': {'n_dim': 3}}), + ('elasticsearch', {'n_dim': 3}, {'@c': {'n_dim': 3}}), + ('redis', {'n_dim': 3}, {'@c': {'n_dim': 3}}), + ('milvus', {'n_dim': 3}, {'@c': {'n_dim': 3}}), + ], +) +def test_find_return_root(storage, config, subindex_configs, start_storage): + da = DocumentArray( + storage=storage, + config=config, + subindex_configs=subindex_configs, + ) + + with da: + da.extend( + [ + Document( + id=f'{i}', + chunks=[ + Document(id=f'sub{i}', embedding=np.random.random(3)), + ], + ) + for i in range(9) + ] + ) + + da[0] = Document( + id='9', + embedding=np.random.random(3), + chunks=[ + Document(id=f'sub9', embedding=np.random.random(3)), + ], + ) + + if storage != 'memory': + assert all( + d.tags['_root_id_'] in [f'{i}' for i in range(1, 10)] for d in da['@c'] + ) + + query = np.random.random(3) + res = da.find(query, on='@c') + root_level_res = da.find(query, on='@c', return_root=True) + + res_root_id = [i.id[3] for i in res] + assert res_root_id == root_level_res[:, 'id'] + assert res[:, 'scores'] == root_level_res[:, 'scores'] + + assert len(root_level_res) > 0 + assert all(d.id in [f'{i}' for i in range(1, 10)] for d in root_level_res) + + +@pytest.mark.parametrize( + 'storage, config, subindex_configs', + [ + ('memory', None, {'@c': None}), + ( + 'weaviate', + { + 'n_dim': 3, + }, + {'@c': {'n_dim': 3}}, + ), + ('annlite', {'n_dim': 3}, {'@c': {'n_dim': 3}}), + ('sqlite', dict(), {'@c': dict()}), + ('qdrant', {'n_dim': 3}, {'@c': {'n_dim': 3}}), + ('elasticsearch', {'n_dim': 3}, {'@c': {'n_dim': 3}}), + ('redis', {'n_dim': 3}, {'@c': {'n_dim': 3}}), + ('milvus', {'n_dim': 3}, {'@c': {'n_dim': 3}}), + ], +) +def test_subindex_root_id(storage, config, subindex_configs, start_storage): + da = DocumentArray( + storage=storage, + config=config, + subindex_configs=subindex_configs, + ) + + with da: + da.extend( + [ + Document( + id=f'{i}', + chunks=[ + Document(id=f'sub{i}_0'), + Document(id=f'sub{i}_1'), + ], + ) + for i in range(5) + ] + ) + + with pytest.warns(UserWarning): + new_da = DocumentArray([Document(id=f'temp{i}') for i in range(10)]) + new_da[:, 'id'] = da['@c'][:, 'id'] + da['@c'] = new_da + with pytest.warns(UserWarning): + da['@c'].extend([Document(id='sub_extra')]) + with pytest.warns(UserWarning): + da['@c']['sub0_0'] = Document(id='sub0_new') + with pytest.warns(UserWarning): + da['@c'][0] = Document(id='sub0_new') + with pytest.warns(UserWarning): + da['@c'][0:] = [Document(id='sub0_new'), Document(id='sub0_new2')]