Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docarray/array/mixins/find.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,9 @@ def find(
elif isinstance(query, str) or (
isinstance(query, list) and isinstance(query[0], str)
):
if filter is not None:
raise ValueError('cannot use filter with text search')
result = self._find_by_text(query, index=index, limit=limit, **kwargs)
result = self._find_by_text(
query, index=index, filter=filter, limit=limit, **kwargs
)
if isinstance(query, str):
return result[0]
else:
Expand Down
24 changes: 21 additions & 3 deletions docarray/array/storage/elastic/find.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,11 @@ def _find_similar_vectors(
return da

def _find_similar_documents_from_text(
self, query: str, index: str = 'text', limit: int = 10
self,
query: str,
index: str = 'text',
filter: Union[dict, list] = None,
limit: int = 10,
):
"""
Return keyword matches for the input query
Expand All @@ -89,9 +93,18 @@ def _find_similar_documents_from_text(
the closest Document objects for each of the queries in `query`.
"""

query = {
"bool": {
"must": [
{"match": {index: query}},
],
"filter": filter,
}
}

resp = self._client.search(
index=self._config.index_name,
query={'match': {index: query}},
query=query,
source=['id', 'blob', 'text'],
size=limit,
)
Expand All @@ -106,7 +119,11 @@ def _find_similar_documents_from_text(
return da

def _find_by_text(
self, query: Union[str, List[str]], index: str = 'text', limit: int = 10
self,
query: Union[str, List[str]],
index: str = 'text',
filter: Union[dict, list] = None,
limit: int = 10,
):
if isinstance(query, str):
query = [query]
Expand All @@ -115,6 +132,7 @@ def _find_by_text(
self._find_similar_documents_from_text(
q,
index=index,
filter=filter,
limit=limit,
)
for q in query
Expand Down
11 changes: 10 additions & 1 deletion docarray/array/storage/redis/find.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def _find_by_text(
self,
query: Union[str, List[str]],
index: str = 'text',
filter: Optional[Union[str, Dict]] = None,
limit: Union[int, float] = 20,
**kwargs,
):
Expand All @@ -127,6 +128,7 @@ def _find_by_text(
self._find_similar_documents_from_text(
q,
index=index,
filter=filter,
limit=limit,
**kwargs,
)
Expand All @@ -137,10 +139,17 @@ def _find_similar_documents_from_text(
self,
query: str,
index: str = 'text',
filter: Optional[Union[str, Dict]] = None,
limit: Union[int, float] = 20,
**kwargs,
):
query_str = _build_query_str(query)

if filter:
filter_str = _get_redis_filter_query(filter)
else:
filter_str = ''

scorer = kwargs.get('scorer', 'BM25')
if scorer not in [
'BM25',
Expand All @@ -154,7 +163,7 @@ def _find_similar_documents_from_text(
f'Expecting a valid text similarity ranking algorithm, got {scorer} instead'
)

q = Query(f'@{index}:{query_str}').scorer(scorer).paging(0, limit)
q = Query(f'@{index}:{query_str} {filter_str}').scorer(scorer).paging(0, limit)

results = self._client.ft(index_name=self._config.index_name).search(q).docs

Expand Down
60 changes: 60 additions & 0 deletions tests/unit/array/mixins/test_find.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,66 @@ def test_find_by_text(storage, config, start_storage):
assert len(results[1]) == 0 # 'token' is not present in da vocabulary


@pytest.mark.parametrize(
'storage, config, filter',
[
(
'elasticsearch',
{'n_dim': 32, 'columns': {'i': 'int'}, 'index_text': True},
None,
),
(
'elasticsearch',
{'n_dim': 32, 'columns': {'i': 'int'}, 'index_text': True},
{
'range': {
'i': {
'lte': 5,
}
}
},
),
(
'elasticsearch',
{'n_dim': 32, 'columns': {'i': 'int'}, 'index_text': True},
[
{
'range': {
'i': {
'lte': 5,
}
}
}
],
),
('redis', {'n_dim': 32, 'columns': {'i': 'int'}, 'index_text': True}, None),
(
'redis',
{'n_dim': 32, 'columns': {'i': 'int'}, 'index_text': True},
'@i:[-inf 5]',
),
],
)
def test_find_by_text_and_filter(storage, config, filter, start_storage):
da = DocumentArray(storage=storage, config=config)
with da:
da.extend(
[Document(id=f'{i}', tags={'i': i}, text=f'pizza {i}') for i in range(10)]
)
da.extend(
[
Document(id=f'{i+10}', tags={'i': i}, text=f'noodles {i}')
for i in range(10)
]
)

results = da.find('pizza', filter=filter)

assert len(results) > 0
assert all([int(r.id) < 10 for r in results])
assert all([r.tags['i'] < 10 for r in results])


@pytest.mark.parametrize(
'storage, config',
[
Expand Down