Skip to content
36 changes: 21 additions & 15 deletions docarray/array/storage/qdrant/find.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,16 @@
from abc import abstractmethod
from typing import (
TYPE_CHECKING,
TypeVar,
Sequence,
List,
Union,
Optional,
Dict,
)
from qdrant_client.http.models.models import Distance
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, TypeVar, Union

from docarray import Document, DocumentArray
from docarray.math import ndarray
from docarray.score import NamedScore
from qdrant_client.http import models as rest
from qdrant_client.http.models.models import Distance

if TYPE_CHECKING: # pragma: no cover
import numpy as np
import tensorflow
import torch
import numpy as np
from qdrant_client import QdrantClient

QdrantArrayType = TypeVar(
Expand Down Expand Up @@ -51,15 +44,20 @@ def distance(self) -> 'Distance':
raise NotImplementedError()

def _find_similar_vectors(
self, q: 'QdrantArrayType', limit: int = 10, filter: Optional[Dict] = None
self,
q: 'QdrantArrayType',
limit: int = 10,
filter: Optional[Dict] = None,
search_params: Optional[Dict] = {},
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

set it as None as default and handle it properly

**kwargs,
):
query_vector = self._map_embedding(q)

search_result = self.client.search(
self.collection_name,
query_vector=query_vector,
query_filter=filter,
search_params=None,
search_params=rest.SearchParams(**search_params),
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

u have said that search_params is an Optional[Dict], but this is not true, if I pass None in search_params this will not work.

top=limit,
append_payload=['_serialized'],
)
Expand All @@ -82,12 +80,14 @@ def _find(
query: 'QdrantArrayType',
limit: int = 10,
filter: Optional[Dict] = None,
search_params: Optional[Dict] = {},
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here

**kwargs,
) -> List['DocumentArray']:
"""Returns approximate nearest neighbors given a batch of input queries.
:param query: input supported to be used in Qdrant.
:param limit: number of retrieved items
:param filter: filter query used for pre-filtering
:param search_params: additional parameters of the search


:return: a list of DocumentArrays containing the closest Document objects for each of the queries in `query`.
Expand All @@ -96,11 +96,17 @@ def _find(
num_rows, _ = ndarray.get_array_rows(query)

if num_rows == 1:
return [self._find_similar_vectors(query, limit=limit, filter=filter)]
return [
self._find_similar_vectors(
query, limit=limit, filter=filter, search_params=search_params
)
]
else:
closest_docs = []
for q in query:
da = self._find_similar_vectors(q, limit=limit, filter=filter)
da = self._find_similar_vectors(
q, limit=limit, filter=filter, search_params=search_params
)
closest_docs.append(da)
return closest_docs

Expand Down
6 changes: 2 additions & 4 deletions docs/advanced/document-store/qdrant.md
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,7 @@ for embedding, price in zip(da.embeddings, da[:, 'tags__price']):
print(f'\tembedding={embedding},\t price={price}')
```

Consider we want the nearest vectors to the embedding `[8. 8. 8.]`, with the restriction that
prices must follow a filter. As an example, retrieved Documents must have `price` value lower than
or equal to `max_price`. We can encode this information in Qdrant using `filter = {'price': {'$lte': max_price}}`.
Consider we want the nearest vectors to the embedding `[8. 8. 8.]`, with the restriction that prices must follow a filter. As an example, retrieved Documents must have `price` value lower than or equal to `max_price`. We can encode this information in Qdrant using `filter = {'must': [{'key': 'price', 'range': {'lte': max_price}}]}`. You can also pass additional `search_params` following [Qdrant's Search API](https://qdrant.tech/documentation/search/#search-api).

Then you can implement and use the search with the proposed filter:

Expand All @@ -189,7 +187,7 @@ np_query = np.ones(n_dim) * 8
print(f'\nQuery vector: \t{np_query}')

filter = {'must': [{'key': 'price', 'range': {'lte': max_price}}]}
results = da.find(np_query, filter=filter, limit=n_limit)
results = da.find(np_query, filter=filter, limit=n_limit, search_params={"hnsw_ef": 64})

print('\nEmbeddings Nearest Neighbours with "price" at most 7:\n')
for embedding, price in zip(results.embeddings, results[:, 'tags__price']):
Expand Down
32 changes: 32 additions & 0 deletions tests/unit/array/storage/qdrant/test_find.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from docarray import Document, DocumentArray
import numpy as np


def test_success_find_with_added_kwargs(start_storage, monkeypatch):
nrof_docs = 10
hnsw_ef = 64

qdrant_doc = DocumentArray(
storage='qdrant',
config={
'n_dim': 3,
},
)

with qdrant_doc:
qdrant_doc.extend(
[
Document(id=f'r{i}', embedding=np.ones((3,)) * i)
for i in range(nrof_docs)
],
)

def _mock_search(*args, **kwargs):
assert kwargs['search_params'].hnsw_ef == hnsw_ef
return []

monkeypatch.setattr(qdrant_doc._client, 'search', _mock_search)

np_query = np.array([2, 1, 3])

qdrant_doc.find(np_query, limit=10, search_params={"hnsw_ef": hnsw_ef})