diff --git a/docarray/index/abstract.py b/docarray/index/abstract.py index 03ab7361f62..923a95de6d1 100644 --- a/docarray/index/abstract.py +++ b/docarray/index/abstract.py @@ -384,13 +384,14 @@ def index(self, docs: Union[BaseDoc, Sequence[BaseDoc]], **kwargs): """index Documents into the index. :param docs: Documents to index. + + !!! note + Passing a sequence of Documents that is not a DocList + (such as a List of Docs) comes at a performance penalty. + This is because the Index needs to check compatibility between itself and + the data. With a DocList as input this is a single check; for other inputs + compatibility needs to be checked for every Document individually. """ - if not isinstance(docs, (BaseDoc, DocList)): - self._logger.warning( - 'Passing a sequence of Documents that is not a DocList comes at ' - 'a performance penalty, since compatibility with the schema of Index ' - 'needs to be checked for every Document individually.' - ) self._logger.debug(f'Indexing {len(docs)} documents') docs_validated = self._validate_docs(docs) data_by_columns = self._get_col_value_dict(docs_validated) @@ -399,7 +400,7 @@ def index(self, docs: Union[BaseDoc, Sequence[BaseDoc]], **kwargs): def find( self, query: Union[AnyTensor, BaseDoc], - search_field: str = 'embedding', + search_field: str = '', limit: int = 10, **kwargs, ) -> FindResult: @@ -432,7 +433,7 @@ def find( def find_batched( self, queries: Union[AnyTensor, DocList], - search_field: str = 'embedding', + search_field: str = '', limit: int = 10, **kwargs, ) -> FindResultBatched: @@ -511,7 +512,7 @@ def filter_batched( def text_search( self, query: Union[str, BaseDoc], - search_field: str = 'text', + search_field: str = '', limit: int = 10, **kwargs, ) -> FindResult: @@ -539,7 +540,7 @@ def text_search( def text_search_batched( self, queries: Union[Sequence[str], Sequence[BaseDoc]], - search_field: str = 'text', + search_field: str = '', limit: int = 10, **kwargs, ) -> FindResultBatched: diff --git a/docarray/utils/find.py b/docarray/utils/find.py index 405f3e75f15..8fc32a1226f 100644 --- a/docarray/utils/find.py +++ b/docarray/utils/find.py @@ -26,7 +26,7 @@ class _FindResult(NamedTuple): def find( index: AnyDocArray, query: Union[AnyTensor, BaseDoc], - embedding_field: str = 'embedding', + search_field: str = '', metric: str = 'cosine_sim', limit: int = 10, device: Optional[str] = None, @@ -61,7 +61,7 @@ class MyDocument(BaseDoc): top_matches, scores = find( index=index, query=query, - embedding_field='embedding', + search_field='embedding', metric='cosine_sim', ) @@ -70,7 +70,7 @@ class MyDocument(BaseDoc): top_matches, scores = find( index=index, query=query, - embedding_field='embedding', + search_field='embedding', metric='cosine_sim', ) ``` @@ -79,7 +79,7 @@ class MyDocument(BaseDoc): :param index: the index of Documents to search in :param query: the query to search for - :param embedding_field: the tensor-like field in the index to use + :param search_field: the tensor-like field in the index to use for the similarity computation :param metric: the distance metric to use for the similarity computation. Can be one of the following strings: @@ -94,11 +94,11 @@ class MyDocument(BaseDoc): where the first element contains the closes matches for the query, and the second element contains the corresponding scores. """ - query = _extract_embedding_single(query, embedding_field) + query = _extract_embedding_single(query, search_field) return find_batched( index=index, query=query, - embedding_field=embedding_field, + search_field=search_field, metric=metric, limit=limit, device=device, @@ -109,7 +109,7 @@ class MyDocument(BaseDoc): def find_batched( index: AnyDocArray, query: Union[AnyTensor, DocList], - embedding_field: str = 'embedding', + search_field: str = '', metric: str = 'cosine_sim', limit: int = 10, device: Optional[str] = None, @@ -145,7 +145,7 @@ class MyDocument(BaseDoc): results = find_batched( index=index, query=query, - embedding_field='embedding', + search_field='embedding', metric='cosine_sim', ) top_matches, scores = results[0] @@ -155,7 +155,7 @@ class MyDocument(BaseDoc): results = find_batched( index=index, query=query, - embedding_field='embedding', + search_field='embedding', metric='cosine_sim', ) top_matches, scores = results[0] @@ -165,7 +165,7 @@ class MyDocument(BaseDoc): :param index: the index of Documents to search in :param query: the query to search for - :param embedding_field: the tensor-like field in the index to use + :param search_field: the tensor-like field in the index to use for the similarity computation :param metric: the distance metric to use for the similarity computation. Can be one of the following strings: @@ -183,12 +183,12 @@ class MyDocument(BaseDoc): if descending is None: descending = metric.endswith('_sim') # similarity metrics are descending - embedding_type = _da_attr_type(index, embedding_field) + embedding_type = _da_attr_type(index, search_field) comp_backend = embedding_type.get_comp_backend() # extract embeddings from query and index - index_embeddings = _extract_embeddings(index, embedding_field, embedding_type) - query_embeddings = _extract_embeddings(query, embedding_field, embedding_type) + index_embeddings = _extract_embeddings(index, search_field, embedding_type) + query_embeddings = _extract_embeddings(query, search_field, embedding_type) # compute distances and return top results metric_fn = getattr(comp_backend.Metrics, metric) @@ -209,18 +209,18 @@ class MyDocument(BaseDoc): def _extract_embedding_single( data: Union[DocList, BaseDoc, AnyTensor], - embedding_field: str, + search_field: str, ) -> AnyTensor: """Extract the embeddings from a single query, and return it in a batched representation. :param data: the data - :param embedding_field: the embedding field + :param search_field: the embedding field :param embedding_type: type of the embedding: torch.Tensor, numpy.ndarray etc. :return: the embeddings """ if isinstance(data, BaseDoc): - emb = next(AnyDocArray._traverse(data, embedding_field)) + emb = next(AnyDocArray._traverse(data, search_field)) else: # treat data as tensor emb = data if len(emb.shape) == 1: @@ -232,22 +232,22 @@ def _extract_embedding_single( def _extract_embeddings( data: Union[AnyDocArray, BaseDoc, AnyTensor], - embedding_field: str, + search_field: str, embedding_type: Type, ) -> AnyTensor: """Extract the embeddings from the data. :param data: the data - :param embedding_field: the embedding field + :param search_field: the embedding field :param embedding_type: type of the embedding: torch.Tensor, numpy.ndarray etc. :return: the embeddings """ emb: AnyTensor if isinstance(data, DocList): - emb_list = list(AnyDocArray._traverse(data, embedding_field)) + emb_list = list(AnyDocArray._traverse(data, search_field)) emb = embedding_type._docarray_stack(emb_list) elif isinstance(data, (DocVec, BaseDoc)): - emb = next(AnyDocArray._traverse(data, embedding_field)) + emb = next(AnyDocArray._traverse(data, search_field)) else: # treat data as tensor emb = cast(AnyTensor, data) diff --git a/tests/index/hnswlib/test_configurations.py b/tests/index/hnswlib/test_configurations.py new file mode 100644 index 00000000000..dff64fdcc19 --- /dev/null +++ b/tests/index/hnswlib/test_configurations.py @@ -0,0 +1,27 @@ +import numpy as np +import pytest +from pydantic import Field + +from docarray import BaseDoc +from docarray.index import HnswDocumentIndex +from docarray.typing import NdArray + +pytestmark = [pytest.mark.slow, pytest.mark.index] + + +class MyDoc(BaseDoc): + tens: NdArray + + +def test_configure_dim(tmp_path): + class Schema(BaseDoc): + tens: NdArray = Field(dim=10) + + index = HnswDocumentIndex[Schema](work_dir=str(tmp_path)) + + assert index._hnsw_indices['tens'].dim == 10 + + docs = [Schema(tens=np.random.random((10,))) for _ in range(10)] + index.index(docs) + + assert index.num_docs() == 10 diff --git a/tests/units/util/test_find.py b/tests/units/util/test_find.py index 90b3c7005d8..96837cddc7f 100644 --- a/tests/units/util/test_find.py +++ b/tests/units/util/test_find.py @@ -52,7 +52,7 @@ def test_find_torch(random_torch_query, random_torch_index, metric): top_k, scores = find( random_torch_index, random_torch_query, - embedding_field='tensor', + search_field='tensor', limit=7, metric=metric, ) @@ -69,7 +69,7 @@ def test_find_torch_tensor_query(random_torch_query, random_torch_index): top_k, scores = find( random_torch_index, query, - embedding_field='tensor', + search_field='tensor', limit=7, metric='cosine_sim', ) @@ -83,7 +83,7 @@ def test_find_torch_stacked(random_torch_query, random_torch_index): top_k, scores = find( random_torch_index, random_torch_query, - embedding_field='tensor', + search_field='tensor', limit=7, metric='cosine_sim', ) @@ -97,7 +97,7 @@ def test_find_np(random_nd_query, random_nd_index, metric): top_k, scores = find( random_nd_index, random_nd_query, - embedding_field='tensor', + search_field='tensor', limit=7, metric=metric, ) @@ -114,7 +114,7 @@ def test_find_np_tensor_query(random_nd_query, random_nd_index): top_k, scores = find( random_nd_index, query, - embedding_field='tensor', + search_field='tensor', limit=7, metric='cosine_sim', ) @@ -128,7 +128,7 @@ def test_find_np_stacked(random_nd_query, random_nd_index): top_k, scores = find( random_nd_index, random_nd_query, - embedding_field='tensor', + search_field='tensor', limit=7, metric='cosine_sim', ) @@ -142,7 +142,7 @@ def test_find_batched_torch(random_torch_batch_query, random_torch_index, metric results = find_batched( random_torch_index, random_torch_batch_query, - embedding_field='tensor', + search_field='tensor', limit=7, metric=metric, ) @@ -162,7 +162,7 @@ def test_find_batched_torch_tensor_query(random_torch_batch_query, random_torch_ results = find_batched( random_torch_index, query, - embedding_field='tensor', + search_field='tensor', limit=7, metric='cosine_sim', ) @@ -186,7 +186,7 @@ def test_find_batched_torch_stacked( results = find_batched( random_torch_index, random_torch_batch_query, - embedding_field='tensor', + search_field='tensor', limit=7, metric='cosine_sim', ) @@ -203,7 +203,7 @@ def test_find_batched_np(random_nd_batch_query, random_nd_index, metric): results = find_batched( random_nd_index, random_nd_batch_query, - embedding_field='tensor', + search_field='tensor', limit=7, metric=metric, ) @@ -223,7 +223,7 @@ def test_find_batched_np_tensor_query(random_nd_batch_query, random_nd_index): results = find_batched( random_nd_index, query, - embedding_field='tensor', + search_field='tensor', limit=7, metric='cosine_sim', ) @@ -244,7 +244,7 @@ def test_find_batched_np_stacked(random_nd_batch_query, random_nd_index, stack_w results = find_batched( random_nd_index, random_nd_batch_query, - embedding_field='tensor', + search_field='tensor', limit=7, metric='cosine_sim', ) @@ -266,7 +266,7 @@ class MyDoc(BaseDoc): top_k, scores = find( index, query, - embedding_field='embedding', + search_field='embedding', limit=7, ) assert len(top_k) == 7 @@ -284,7 +284,7 @@ class MyDoc(BaseDoc): top_k, scores = find( index, query, - embedding_field='embedding', + search_field='embedding', limit=7, ) assert len(top_k) == 7 @@ -314,7 +314,7 @@ class MyDoc(BaseDoc): top_k, scores = find( index, query, - embedding_field='inner__embedding', + search_field='inner__embedding', limit=7, ) assert len(top_k) == 7 @@ -350,7 +350,7 @@ class MyDoc(BaseDoc): top_k, scores = find( index, query, - embedding_field='embedding', + search_field='embedding', limit=7, ) assert len(top_k) == 7 @@ -360,7 +360,7 @@ class MyDoc(BaseDoc): top_k, scores = find( index, query, - embedding_field='embedding2', + search_field='embedding2', limit=7, ) assert len(top_k) == 7 @@ -370,7 +370,7 @@ class MyDoc(BaseDoc): top_k, scores = find( index, query, - embedding_field='embedding3', + search_field='embedding3', limit=7, ) assert len(top_k) == 7 @@ -380,7 +380,7 @@ class MyDoc(BaseDoc): top_k, scores = find( index, query, - embedding_field='embedding4', + search_field='embedding4', limit=7, ) assert len(top_k) == 7