Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions docarray/index/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,13 +384,14 @@ def index(self, docs: Union[BaseDoc, Sequence[BaseDoc]], **kwargs):
"""index Documents into the index.

:param docs: Documents to index.

!!! note
Passing a sequence of Documents that is not a DocList
(such as a List of Docs) comes at a performance penalty.
This is because the Index needs to check compatibility between itself and
the data. With a DocList as input this is a single check; for other inputs
compatibility needs to be checked for every Document individually.
"""
if not isinstance(docs, (BaseDoc, DocList)):
self._logger.warning(
'Passing a sequence of Documents that is not a DocList comes at '
'a performance penalty, since compatibility with the schema of Index '
'needs to be checked for every Document individually.'
)
self._logger.debug(f'Indexing {len(docs)} documents')
docs_validated = self._validate_docs(docs)
data_by_columns = self._get_col_value_dict(docs_validated)
Expand All @@ -399,7 +400,7 @@ def index(self, docs: Union[BaseDoc, Sequence[BaseDoc]], **kwargs):
def find(
self,
query: Union[AnyTensor, BaseDoc],
search_field: str = 'embedding',
search_field: str = '',
limit: int = 10,
**kwargs,
) -> FindResult:
Expand Down Expand Up @@ -432,7 +433,7 @@ def find(
def find_batched(
self,
queries: Union[AnyTensor, DocList],
search_field: str = 'embedding',
search_field: str = '',
limit: int = 10,
**kwargs,
) -> FindResultBatched:
Expand Down Expand Up @@ -511,7 +512,7 @@ def filter_batched(
def text_search(
self,
query: Union[str, BaseDoc],
search_field: str = 'text',
search_field: str = '',
limit: int = 10,
**kwargs,
) -> FindResult:
Expand Down Expand Up @@ -539,7 +540,7 @@ def text_search(
def text_search_batched(
self,
queries: Union[Sequence[str], Sequence[BaseDoc]],
search_field: str = 'text',
search_field: str = '',
limit: int = 10,
**kwargs,
) -> FindResultBatched:
Expand Down
40 changes: 20 additions & 20 deletions docarray/utils/find.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class _FindResult(NamedTuple):
def find(
index: AnyDocArray,
query: Union[AnyTensor, BaseDoc],
embedding_field: str = 'embedding',
search_field: str = '',
metric: str = 'cosine_sim',
limit: int = 10,
device: Optional[str] = None,
Expand Down Expand Up @@ -61,7 +61,7 @@ class MyDocument(BaseDoc):
top_matches, scores = find(
index=index,
query=query,
embedding_field='embedding',
search_field='embedding',
metric='cosine_sim',
)

Expand All @@ -70,7 +70,7 @@ class MyDocument(BaseDoc):
top_matches, scores = find(
index=index,
query=query,
embedding_field='embedding',
search_field='embedding',
metric='cosine_sim',
)
```
Expand All @@ -79,7 +79,7 @@ class MyDocument(BaseDoc):

:param index: the index of Documents to search in
:param query: the query to search for
:param embedding_field: the tensor-like field in the index to use
:param search_field: the tensor-like field in the index to use
for the similarity computation
:param metric: the distance metric to use for the similarity computation.
Can be one of the following strings:
Expand All @@ -94,11 +94,11 @@ class MyDocument(BaseDoc):
where the first element contains the closest matches for the query,
and the second element contains the corresponding scores.
"""
query = _extract_embedding_single(query, embedding_field)
query = _extract_embedding_single(query, search_field)
return find_batched(
index=index,
query=query,
embedding_field=embedding_field,
search_field=search_field,
metric=metric,
limit=limit,
device=device,
Expand All @@ -109,7 +109,7 @@ class MyDocument(BaseDoc):
def find_batched(
index: AnyDocArray,
query: Union[AnyTensor, DocList],
embedding_field: str = 'embedding',
search_field: str = '',
metric: str = 'cosine_sim',
limit: int = 10,
device: Optional[str] = None,
Expand Down Expand Up @@ -145,7 +145,7 @@ class MyDocument(BaseDoc):
results = find_batched(
index=index,
query=query,
embedding_field='embedding',
search_field='embedding',
metric='cosine_sim',
)
top_matches, scores = results[0]
Expand All @@ -155,7 +155,7 @@ class MyDocument(BaseDoc):
results = find_batched(
index=index,
query=query,
embedding_field='embedding',
search_field='embedding',
metric='cosine_sim',
)
top_matches, scores = results[0]
Expand All @@ -165,7 +165,7 @@ class MyDocument(BaseDoc):

:param index: the index of Documents to search in
:param query: the query to search for
:param embedding_field: the tensor-like field in the index to use
:param search_field: the tensor-like field in the index to use
for the similarity computation
:param metric: the distance metric to use for the similarity computation.
Can be one of the following strings:
Expand All @@ -183,12 +183,12 @@ class MyDocument(BaseDoc):
if descending is None:
descending = metric.endswith('_sim') # similarity metrics are descending

embedding_type = _da_attr_type(index, embedding_field)
embedding_type = _da_attr_type(index, search_field)
comp_backend = embedding_type.get_comp_backend()

# extract embeddings from query and index
index_embeddings = _extract_embeddings(index, embedding_field, embedding_type)
query_embeddings = _extract_embeddings(query, embedding_field, embedding_type)
index_embeddings = _extract_embeddings(index, search_field, embedding_type)
query_embeddings = _extract_embeddings(query, search_field, embedding_type)

# compute distances and return top results
metric_fn = getattr(comp_backend.Metrics, metric)
Expand All @@ -209,18 +209,18 @@ class MyDocument(BaseDoc):

def _extract_embedding_single(
data: Union[DocList, BaseDoc, AnyTensor],
embedding_field: str,
search_field: str,
) -> AnyTensor:
"""Extract the embeddings from a single query,
and return it in a batched representation.

:param data: the data
:param embedding_field: the embedding field
:param search_field: the embedding field
:param embedding_type: type of the embedding: torch.Tensor, numpy.ndarray etc.
:return: the embeddings
"""
if isinstance(data, BaseDoc):
emb = next(AnyDocArray._traverse(data, embedding_field))
emb = next(AnyDocArray._traverse(data, search_field))
else: # treat data as tensor
emb = data
if len(emb.shape) == 1:
Expand All @@ -232,22 +232,22 @@ def _extract_embedding_single(

def _extract_embeddings(
data: Union[AnyDocArray, BaseDoc, AnyTensor],
embedding_field: str,
search_field: str,
embedding_type: Type,
) -> AnyTensor:
"""Extract the embeddings from the data.

:param data: the data
:param embedding_field: the embedding field
:param search_field: the embedding field
:param embedding_type: type of the embedding: torch.Tensor, numpy.ndarray etc.
:return: the embeddings
"""
emb: AnyTensor
if isinstance(data, DocList):
emb_list = list(AnyDocArray._traverse(data, embedding_field))
emb_list = list(AnyDocArray._traverse(data, search_field))
emb = embedding_type._docarray_stack(emb_list)
elif isinstance(data, (DocVec, BaseDoc)):
emb = next(AnyDocArray._traverse(data, embedding_field))
emb = next(AnyDocArray._traverse(data, search_field))
else: # treat data as tensor
emb = cast(AnyTensor, data)

Expand Down
27 changes: 27 additions & 0 deletions tests/index/hnswlib/test_configurations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import numpy as np
import pytest
from pydantic import Field

from docarray import BaseDoc
from docarray.index import HnswDocumentIndex
from docarray.typing import NdArray

pytestmark = [pytest.mark.slow, pytest.mark.index]


class MyDoc(BaseDoc):
    """Minimal document schema with a single tensor field.

    NOTE(review): not referenced by the visible test in this file — the test
    defines its own local ``Schema``; confirm intended use before removing.
    """

    # ndarray payload; no dimensionality constraint declared here
    tens: NdArray


def test_configure_dim(tmp_path):
    """A ``Field(dim=...)`` constraint on the schema must be picked up by the
    backing HNSW index, and indexing documents of that shape must succeed."""
    dim = 10
    n_docs = 10

    class Schema(BaseDoc):
        tens: NdArray = Field(dim=dim)

    idx = HnswDocumentIndex[Schema](work_dir=str(tmp_path))

    # the configured dimensionality propagates to the per-column hnswlib index
    assert idx._hnsw_indices['tens'].dim == dim

    payload = []
    for _ in range(n_docs):
        payload.append(Schema(tens=np.random.random((dim,))))
    idx.index(payload)

    assert idx.num_docs() == n_docs
Loading