-
Notifications
You must be signed in to change notification settings - Fork 238
feat(qdrant): pass search_params in find #668
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
dc136ad
3efa311
6ed10a8
a631ab1
1f7182d
e74780c
239f54e
dd3a4c8
2c1bb62
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,23 +1,16 @@ | ||
| from abc import abstractmethod | ||
| from typing import ( | ||
| TYPE_CHECKING, | ||
| TypeVar, | ||
| Sequence, | ||
| List, | ||
| Union, | ||
| Optional, | ||
| Dict, | ||
| ) | ||
| from qdrant_client.http.models.models import Distance | ||
| from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, TypeVar, Union | ||
|
|
||
| from docarray import Document, DocumentArray | ||
| from docarray.math import ndarray | ||
| from docarray.score import NamedScore | ||
| from qdrant_client.http import models as rest | ||
| from qdrant_client.http.models.models import Distance | ||
|
|
||
| if TYPE_CHECKING: # pragma: no cover | ||
| import numpy as np | ||
| import tensorflow | ||
| import torch | ||
| import numpy as np | ||
| from qdrant_client import QdrantClient | ||
|
|
||
| QdrantArrayType = TypeVar( | ||
|
|
@@ -51,15 +44,20 @@ def distance(self) -> 'Distance': | |
| raise NotImplementedError() | ||
|
|
||
| def _find_similar_vectors( | ||
| self, q: 'QdrantArrayType', limit: int = 10, filter: Optional[Dict] = None | ||
| self, | ||
| q: 'QdrantArrayType', | ||
| limit: int = 10, | ||
| filter: Optional[Dict] = None, | ||
| search_params: Optional[Dict] = {}, | ||
AnneYang720 marked this conversation as resolved.
Show resolved
Hide resolved
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Set the default to `None` instead of `{}` and handle the `None` case properly before building the search params. |
||
| **kwargs, | ||
| ): | ||
| query_vector = self._map_embedding(q) | ||
|
|
||
| search_result = self.client.search( | ||
| self.collection_name, | ||
| query_vector=query_vector, | ||
| query_filter=filter, | ||
| search_params=None, | ||
| search_params=rest.SearchParams(**search_params), | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. You have said that |
||
| top=limit, | ||
| append_payload=['_serialized'], | ||
| ) | ||
|
|
@@ -82,12 +80,14 @@ def _find( | |
| query: 'QdrantArrayType', | ||
| limit: int = 10, | ||
| filter: Optional[Dict] = None, | ||
| search_params: Optional[Dict] = {}, | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same here. |
||
| **kwargs, | ||
| ) -> List['DocumentArray']: | ||
| """Returns approximate nearest neighbors given a batch of input queries. | ||
| :param query: input supported to be used in Qdrant. | ||
| :param limit: number of retrieved items | ||
| :param filter: filter query used for pre-filtering | ||
| :param search_params: additional parameters of the search | ||
|
|
||
|
|
||
| :return: a list of DocumentArrays containing the closest Document objects for each of the queries in `query`. | ||
|
|
@@ -96,11 +96,17 @@ def _find( | |
| num_rows, _ = ndarray.get_array_rows(query) | ||
|
|
||
| if num_rows == 1: | ||
| return [self._find_similar_vectors(query, limit=limit, filter=filter)] | ||
| return [ | ||
| self._find_similar_vectors( | ||
| query, limit=limit, filter=filter, search_params=search_params | ||
| ) | ||
| ] | ||
| else: | ||
| closest_docs = [] | ||
| for q in query: | ||
| da = self._find_similar_vectors(q, limit=limit, filter=filter) | ||
| da = self._find_similar_vectors( | ||
| q, limit=limit, filter=filter, search_params=search_params | ||
| ) | ||
| closest_docs.append(da) | ||
| return closest_docs | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,32 @@ | ||
| from docarray import Document, DocumentArray | ||
| import numpy as np | ||
|
|
||
|
|
||
| def test_success_find_with_added_kwargs(start_storage, monkeypatch): | ||
| nrof_docs = 10 | ||
| hnsw_ef = 64 | ||
|
|
||
| qdrant_doc = DocumentArray( | ||
| storage='qdrant', | ||
| config={ | ||
| 'n_dim': 3, | ||
| }, | ||
| ) | ||
|
|
||
| with qdrant_doc: | ||
| qdrant_doc.extend( | ||
| [ | ||
| Document(id=f'r{i}', embedding=np.ones((3,)) * i) | ||
| for i in range(nrof_docs) | ||
| ], | ||
| ) | ||
|
|
||
| def _mock_search(*args, **kwargs): | ||
| assert kwargs['search_params'].hnsw_ef == hnsw_ef | ||
| return [] | ||
|
|
||
| monkeypatch.setattr(qdrant_doc._client, 'search', _mock_search) | ||
|
|
||
| np_query = np.array([2, 1, 3]) | ||
|
|
||
| qdrant_doc.find(np_query, limit=10, search_params={"hnsw_ef": hnsw_ef}) |
Uh oh!
There was an error while loading. Please reload this page.