diff --git a/docarray/array/mixins/find.py b/docarray/array/mixins/find.py index 8730d2a98b4..b467e945445 100644 --- a/docarray/array/mixins/find.py +++ b/docarray/array/mixins/find.py @@ -161,9 +161,9 @@ def find( elif isinstance(query, str) or ( isinstance(query, list) and isinstance(query[0], str) ): - if filter is not None: - raise ValueError('cannot use filter with text search') - result = self._find_by_text(query, index=index, limit=limit, **kwargs) + result = self._find_by_text( + query, index=index, filter=filter, limit=limit, **kwargs + ) if isinstance(query, str): return result[0] else: diff --git a/docarray/array/storage/elastic/find.py b/docarray/array/storage/elastic/find.py index d588251ad1e..43a680f8e11 100644 --- a/docarray/array/storage/elastic/find.py +++ b/docarray/array/storage/elastic/find.py @@ -79,7 +79,11 @@ def _find_similar_vectors( return da def _find_similar_documents_from_text( - self, query: str, index: str = 'text', limit: int = 10 + self, + query: str, + index: str = 'text', + filter: Union[dict, list] = None, + limit: int = 10, ): """ Return keyword matches for the input query @@ -89,9 +93,18 @@ def _find_similar_documents_from_text( the closest Document objects for each of the queries in `query`. """ + query = { + "bool": { + "must": [ + {"match": {index: query}}, + ], + "filter": filter, + } + } + resp = self._client.search( index=self._config.index_name, - query={'match': {index: query}}, + query=query, source=['id', 'blob', 'text'], size=limit, ) @@ -106,7 +119,11 @@ def _find_similar_documents_from_text( return da def _find_by_text( - self, query: Union[str, List[str]], index: str = 'text', limit: int = 10 + self, + query: Union[str, List[str]], + index: str = 'text', + filter: Union[dict, list] = None, + limit: int = 10, ): if isinstance(query, str): query = [query] @@ -115,6 +132,7 @@ def _find_by_text( self._find_similar_documents_from_text( q, index=index, + filter=filter, limit=limit, ) for q in query diff --git a/docarray/array/storage/redis/find.py b/docarray/array/storage/redis/find.py index 94473748e6a..bacb539d36a 100644 --- a/docarray/array/storage/redis/find.py +++ b/docarray/array/storage/redis/find.py @@ -117,6 +117,7 @@ def _find_by_text( self, query: Union[str, List[str]], index: str = 'text', + filter: Optional[Union[str, Dict]] = None, limit: Union[int, float] = 20, **kwargs, ): @@ -127,6 +128,7 @@ def _find_by_text( self._find_similar_documents_from_text( q, index=index, + filter=filter, limit=limit, **kwargs, ) @@ -137,10 +139,17 @@ def _find_similar_documents_from_text( self, query: str, index: str = 'text', + filter: Optional[Union[str, Dict]] = None, limit: Union[int, float] = 20, **kwargs, ): query_str = _build_query_str(query) + + if filter: + filter_str = _get_redis_filter_query(filter) + else: + filter_str = '' + scorer = kwargs.get('scorer', 'BM25') if scorer not in [ 'BM25', @@ -154,7 +163,7 @@ def _find_similar_documents_from_text( f'Expecting a valid text similarity ranking algorithm, got {scorer} instead' ) - q = Query(f'@{index}:{query_str}').scorer(scorer).paging(0, limit) + q = Query(f'@{index}:{query_str} {filter_str}').scorer(scorer).paging(0, limit) results = self._client.ft(index_name=self._config.index_name).search(q).docs diff --git a/tests/unit/array/mixins/test_find.py b/tests/unit/array/mixins/test_find.py index 975a13b1379..bf7fa00226c 100644 --- a/tests/unit/array/mixins/test_find.py +++ b/tests/unit/array/mixins/test_find.py @@ -145,6 +145,66 @@ def test_find_by_text(storage, config, start_storage): assert len(results[1]) == 0 # 'token' is not present in da vocabulary +@pytest.mark.parametrize( + 'storage, config, filter', + [ + ( + 'elasticsearch', + {'n_dim': 32, 'columns': {'i': 'int'}, 'index_text': True}, + None, + ), + ( + 'elasticsearch', + {'n_dim': 32, 'columns': {'i': 'int'}, 'index_text': True}, + { + 'range': { + 'i': { + 'lte': 5, + } + } + }, + ), + ( + 'elasticsearch', + {'n_dim': 32, 'columns': {'i': 'int'}, 'index_text': True}, + [ + { + 'range': { + 'i': { + 'lte': 5, + } + } + } + ], + ), + ('redis', {'n_dim': 32, 'columns': {'i': 'int'}, 'index_text': True}, None), + ( + 'redis', + {'n_dim': 32, 'columns': {'i': 'int'}, 'index_text': True}, + '@i:[-inf 5]', + ), + ], +) +def test_find_by_text_and_filter(storage, config, filter, start_storage): + da = DocumentArray(storage=storage, config=config) + with da: + da.extend( + [Document(id=f'{i}', tags={'i': i}, text=f'pizza {i}') for i in range(10)] + ) + da.extend( + [ + Document(id=f'{i+10}', tags={'i': i}, text=f'noodles {i}') + for i in range(10) + ] + ) + + results = da.find('pizza', filter=filter) + + assert len(results) > 0 + assert all([int(r.id) < 10 for r in results]) + assert all([r.tags['i'] < 10 for r in results]) + + @pytest.mark.parametrize( 'storage, config', [