From 083218abb86b9c8ff81e2aea04be1bc92573f78d Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Mon, 28 Feb 2022 16:31:36 +0100 Subject: [PATCH 1/4] fix: implement private _find method --- docarray/array/storage/qdrant/find.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docarray/array/storage/qdrant/find.py b/docarray/array/storage/qdrant/find.py index 89da6fe4613..736e820442a 100644 --- a/docarray/array/storage/qdrant/find.py +++ b/docarray/array/storage/qdrant/find.py @@ -74,8 +74,8 @@ def _find_similar_vectors(self, q: 'QdrantArrayType', limit=10): return DocumentArray(docs) - def find( - self, query: 'QdrantArrayType', limit: int = 10 + def _find( + self, query: 'QdrantArrayType', limit: int = 10, **kwargs ) -> Union['DocumentArray', List['DocumentArray']]: """Returns approximate nearest neighbors given a batch of input queries. :param query: input supported to be used in Qdrant. From ca25d299926a54fed508459ba5a71b7412bfa8d2 Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Mon, 28 Feb 2022 16:32:14 +0100 Subject: [PATCH 2/4] fix: support torch.Tensor in find --- docarray/array/storage/weaviate/find.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docarray/array/storage/weaviate/find.py b/docarray/array/storage/weaviate/find.py index 58a5d008695..c7d4440cb45 100644 --- a/docarray/array/storage/weaviate/find.py +++ b/docarray/array/storage/weaviate/find.py @@ -10,6 +10,7 @@ from .... import Document, DocumentArray from ....math import ndarray from ....math.helper import EPSILON +from ....math.ndarray import to_numpy_array from ....score import NamedScore if TYPE_CHECKING: @@ -27,7 +28,7 @@ class FindMixin: def _find_similar_vectors(self, query: 'WeaviateArrayType', limit=10): - + query = to_numpy_array(query) is_all_zero = np.all(query == 0) if is_all_zero: query = query + EPSILON From 2b35be1d79e987650ab72a9a30d77978a9138eeb Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Mon, 28 Feb 2022 16:32:44 +0100 Subject: [PATCH 3/4] test: cover torch.Tensor and qdrant in test_match --- tests/unit/array/mixins/test_match.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/tests/unit/array/mixins/test_match.py b/tests/unit/array/mixins/test_match.py index 6ed626d8a11..95fbebf95a5 100644 --- a/tests/unit/array/mixins/test_match.py +++ b/tests/unit/array/mixins/test_match.py @@ -71,12 +71,21 @@ def doc_lists_to_doc_arrays(doc_lists, *args, **kwargs): @pytest.mark.parametrize( 'storage, config', - [('weaviate', WeaviateConfig(3)), ('pqlite', {'n_dim': 3})], + [('weaviate', {'n_dim': 3}), ('pqlite', {'n_dim': 3}), ('qdrant', {'n_dim': 3})], ) @pytest.mark.parametrize('limit', [1, 2, 3]) @pytest.mark.parametrize('exclude_self', [True, False]) -def test_match(storage, config, doc_lists, limit, exclude_self, start_storage): +@pytest.mark.parametrize('as_tensor', [True, False]) +def test_match( + storage, config, doc_lists, limit, exclude_self, start_storage, as_tensor +): D1, D2 = doc_lists_to_doc_arrays(doc_lists) + if as_tensor: + for d in D1: + d.embedding = torch.from_numpy(d.embedding) + + for d in D2: + d.embedding = torch.from_numpy(d.embedding) if config: da = DocumentArray(D2, storage=storage, config=config) From 47638f99230aa620e61c462707c9ed94d8853c9f Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Mon, 28 Feb 2022 17:01:36 +0100 Subject: [PATCH 4/4] fix: align qdrant _find interface with other storage backends --- docarray/array/storage/qdrant/find.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docarray/array/storage/qdrant/find.py b/docarray/array/storage/qdrant/find.py index 736e820442a..b3e08c2bf31 100644 --- a/docarray/array/storage/qdrant/find.py +++ b/docarray/array/storage/qdrant/find.py @@ -76,19 +76,19 @@ def _find_similar_vectors(self, q: 'QdrantArrayType', limit=10): def _find( self, query: 'QdrantArrayType', limit: int = 10, **kwargs - ) -> Union['DocumentArray', List['DocumentArray']]: + ) -> List['DocumentArray']: """Returns approximate nearest neighbors given a batch of input queries. :param query: input supported to be used in Qdrant. :param limit: number of retrieved items - :return: DocumentArray containing the closest documents to the query if it is a single query, otherwise a list of DocumentArrays containing - the closest Document objects for each of the queries in `query`. + + :return: a list of DocumentArrays containing the closest Document objects for each of the queries in `query`. """ num_rows, _ = ndarray.get_array_rows(query) if num_rows == 1: - return self._find_similar_vectors(query, limit=limit) + return [self._find_similar_vectors(query, limit=limit)] else: closest_docs = [] for q in query: