diff --git a/docarray/array/mixins/find.py b/docarray/array/mixins/find.py index 71f6d0d60eb..b6f509edf5e 100644 --- a/docarray/array/mixins/find.py +++ b/docarray/array/mixins/find.py @@ -131,7 +131,6 @@ def find( :return: a list of DocumentArrays containing the closest Document objects for each of the queries in `query`. """ - index_da = self._get_index(subindex_name=on) if index_da is not self: return index_da.find( @@ -145,7 +144,6 @@ def find( index, on=None, ) - from docarray import Document, DocumentArray if isinstance(query, dict): diff --git a/docarray/array/storage/elastic/backend.py b/docarray/array/storage/elastic/backend.py index 27279eff94a..28bb90f07e6 100644 --- a/docarray/array/storage/elastic/backend.py +++ b/docarray/array/storage/elastic/backend.py @@ -1,6 +1,6 @@ import copy import uuid -from dataclasses import dataclass, field, asdict +from dataclasses import dataclass, field import warnings from typing import ( diff --git a/docarray/array/storage/elastic/find.py b/docarray/array/storage/elastic/find.py index 9e32591fdf6..db9d9f460fa 100644 --- a/docarray/array/storage/elastic/find.py +++ b/docarray/array/storage/elastic/find.py @@ -48,7 +48,6 @@ def _find_similar_vectors( :return: DocumentArray containing the closest documents to the query if it is a single query, otherwise a list of DocumentArrays containing the closest Document objects for each of the queries in `query`. """ - query = to_numpy_array(query) is_all_zero = np.all(query == 0) if is_all_zero: @@ -151,7 +150,6 @@ def _find_with_filter(self, query: Dict, limit: Optional[Union[int, float]] = 20 query=query, size=limit, ) - list_of_hits = resp['hits']['hits'] da = DocumentArray() diff --git a/tests/unit/array/storage/elastic/test_find.py b/tests/unit/array/storage/elastic/test_find.py index f5957972889..745de0c60cb 100644 --- a/tests/unit/array/storage/elastic/test_find.py +++ b/tests/unit/array/storage/elastic/test_find.py @@ -33,3 +33,74 @@ def _mock_knn_search(**kwargs): np_query = np.array([2, 1, 3]) elastic_doc.find(np_query, limit=10, num_candidates=num_candidates) + + +def test_filter(start_storage): + import random + import string + + elastic_da = DocumentArray( + storage='elasticsearch', + config={ + 'n_dim': 2, + 'columns': { + 'A': 'str', + 'B': 'str', + 'V': 'str', + 'D': 'str', + 'E': 'str', + 'F': 'str', + 'G': 'str', + }, + }, + ) + + def ran(): + return ''.join(random.choices(string.ascii_uppercase + string.digits, k=10)) + + def ran_size(): + sizes = ['S', 'M', 'L', 'XL'] + return sizes[random.randint(0, len(sizes) - 1)] + + def ran_type(): + types = ['A', 'B', 'C', 'D'] + return types[random.randint(0, len(types) - 1)] + + def ran_stype(): + stypes = ['SA', 'SB', 'SC', 'SD'] + return stypes[random.randint(0, len(stypes) - 1)] + + docs = DocumentArray( + [ + Document( + id=f'r{i}', + embedding=np.random.rand(2), + tags={ + 'A': ran(), + 'B': ran_stype(), + 'C': ran_size(), + 'D': ran_type(), + 'E': ran(), + 'F': ran_type(), + 'G': f'G{i}', + }, + ) + for i in range(50) + ] + ) + + with elastic_da: + elastic_da.extend(docs) + + res = elastic_da.find(query=Document(embedding=docs[0].embedding)) + assert len(res) > 0 + assert res[0][0].tags['G'] == 'G0' + filter_ = {'match': {'G': 'G3'}} + + res = elastic_da.find(filter=filter_) + assert len(res) > 0 + assert res[0].tags['G'] == 'G3' + + res = elastic_da.find(query=Document(embedding=docs[0].embedding), filter=filter_) + assert len(res) > 0 + assert res[0][0].tags['G'] == 'G3'