Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 23 additions & 15 deletions docarray/array/storage/qdrant/find.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,16 @@
from abc import abstractmethod
from typing import (
TYPE_CHECKING,
TypeVar,
Sequence,
List,
Union,
Optional,
Dict,
)
from qdrant_client.http.models.models import Distance
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, TypeVar, Union

from docarray import Document, DocumentArray
from docarray.math import ndarray
from docarray.score import NamedScore
from qdrant_client.http import models as rest
from qdrant_client.http.models.models import Distance

if TYPE_CHECKING: # pragma: no cover
import numpy as np
import tensorflow
import torch
import numpy as np
from qdrant_client import QdrantClient

QdrantArrayType = TypeVar(
Expand Down Expand Up @@ -51,15 +44,22 @@ def distance(self) -> 'Distance':
raise NotImplementedError()

def _find_similar_vectors(
self, q: 'QdrantArrayType', limit: int = 10, filter: Optional[Dict] = None
self,
q: 'QdrantArrayType',
limit: int = 10,
filter: Optional[Dict] = None,
search_params: Optional[Dict] = None,
**kwargs,
):
query_vector = self._map_embedding(q)

search_result = self.client.search(
self.collection_name,
query_vector=query_vector,
query_filter=filter,
search_params=None,
search_params=None
if not search_params
else rest.SearchParams(**search_params),
top=limit,
append_payload=['_serialized'],
)
Expand All @@ -82,12 +82,14 @@ def _find(
query: 'QdrantArrayType',
limit: int = 10,
filter: Optional[Dict] = None,
search_params: Optional[Dict] = None,
**kwargs,
) -> List['DocumentArray']:
"""Returns approximate nearest neighbors given a batch of input queries.
:param query: input supported to be used in Qdrant.
:param limit: number of retrieved items
:param filter: filter query used for pre-filtering
:param search_params: additional parameters of the search


:return: a list of DocumentArrays containing the closest Document objects for each of the queries in `query`.
Expand All @@ -96,11 +98,17 @@ def _find(
num_rows, _ = ndarray.get_array_rows(query)

if num_rows == 1:
return [self._find_similar_vectors(query, limit=limit, filter=filter)]
return [
self._find_similar_vectors(
query, limit=limit, filter=filter, search_params=search_params
)
]
else:
closest_docs = []
for q in query:
da = self._find_similar_vectors(q, limit=limit, filter=filter)
da = self._find_similar_vectors(
q, limit=limit, filter=filter, search_params=search_params
)
closest_docs.append(da)
return closest_docs

Expand Down
6 changes: 2 additions & 4 deletions docs/advanced/document-store/qdrant.md
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,7 @@ for embedding, price in zip(da.embeddings, da[:, 'tags__price']):
print(f'\tembedding={embedding},\t price={price}')
```

Consider we want the nearest vectors to the embedding `[8. 8. 8.]`, with the restriction that
prices must follow a filter. As an example, retrieved Documents must have `price` value lower than
or equal to `max_price`. We can encode this information in Qdrant using `filter = {'price': {'$lte': max_price}}`.
Suppose we want the nearest vectors to the embedding `[8. 8. 8.]`, with the restriction that prices must satisfy a filter. For example, retrieved Documents must have a `price` value lower than or equal to `max_price`. We can encode this condition in Qdrant using `filter = {'must': [{'key': 'price', 'range': {'lte': max_price}}]}`. You can also pass additional `search_params`, following [Qdrant's Search API](https://qdrant.tech/documentation/search/#search-api).

Then you can implement and use the search with the proposed filter:

Expand All @@ -189,7 +187,7 @@ np_query = np.ones(n_dim) * 8
print(f'\nQuery vector: \t{np_query}')

filter = {'must': [{'key': 'price', 'range': {'lte': max_price}}]}
results = da.find(np_query, filter=filter, limit=n_limit)
results = da.find(np_query, filter=filter, limit=n_limit, search_params={"hnsw_ef": 64})

print('\nEmbeddings Nearest Neighbours with "price" at most 7:\n')
for embedding, price in zip(results.embeddings, results[:, 'tags__price']):
Expand Down
44 changes: 44 additions & 0 deletions tests/unit/array/storage/qdrant/test_find.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import numpy as np
import pytest
from docarray import Document, DocumentArray


@pytest.mark.parametrize(
    'search_params,expected',
    [
        ({'hnsw_ef': 64}, 64),
        (None, None),
    ],
)
def test_success_find_with_added_kwargs(
    search_params, expected, start_storage, monkeypatch
):
    """Check that ``find`` forwards ``search_params`` to the Qdrant client.

    The client's ``search`` method is replaced with a mock that inspects the
    ``search_params`` keyword argument it receives.

    :param search_params: value passed to ``DocumentArray.find``; either a
        dict of search options or ``None``.
    :param expected: the ``hnsw_ef`` value the client should receive, or
        ``None`` when the client should receive ``search_params=None``.
    :param start_storage: fixture defined outside this file; presumably
        brings up the Qdrant backend -- TODO confirm.
    :param monkeypatch: pytest fixture used to swap out the client's
        ``search`` method.
    """
    nrof_docs = 10

    qdrant_doc = DocumentArray(
        storage='qdrant',
        config={
            'n_dim': 3,
        },
    )

    with qdrant_doc:
        qdrant_doc.extend(
            [
                Document(id=f'r{i}', embedding=np.ones((3,)) * i)
                for i in range(nrof_docs)
            ],
        )

    # Record every call so the test can prove the mock was actually reached;
    # otherwise the assertions inside it could be skipped silently.
    search_calls = []

    def _mock_search(*args, **kwargs):
        search_calls.append(kwargs)
        # Compare against None explicitly so that a falsy-but-valid expected
        # value would not be routed to the "no params" branch.
        if expected is not None:
            assert kwargs['search_params'].hnsw_ef == expected
        else:
            assert kwargs['search_params'] is None
        return []

    monkeypatch.setattr(qdrant_doc._client, 'search', _mock_search)

    np_query = np.array([2, 1, 3])

    qdrant_doc.find(np_query, limit=10, search_params=search_params)

    assert search_calls, 'find() did not call the Qdrant client search method'