Skip to content
3 changes: 2 additions & 1 deletion docarray/array/storage/redis/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from docarray.helper import dataclass_from_dict, random_identity, filter_dict

from redis import Redis
from redis.commands.search.field import NumericField, TextField, VectorField
from redis.commands.search.field import NumericField, TextField, VectorField, GeoField
from redis.commands.search.indexDefinition import IndexDefinition

if TYPE_CHECKING:
Expand Down Expand Up @@ -46,6 +46,7 @@ class BackendMixin(BaseBackendMixin):
'float': TypeMap(type='float', converter=NumericField),
'double': TypeMap(type='double', converter=NumericField),
'long': TypeMap(type='long', converter=NumericField),
'geo': TypeMap(type='geo', converter=GeoField),
}

def _init_storage(
Expand Down
4 changes: 2 additions & 2 deletions docarray/array/storage/redis/find.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def _find(
self,
query: 'RedisArrayType',
limit: Union[int, float] = 20,
filter: Optional[Dict] = None,
filter: Optional[Union[str, Dict]] = None,
**kwargs,
) -> List['DocumentArray']:

Expand Down Expand Up @@ -107,7 +107,7 @@ def _find_with_filter(

def _filter(
self,
filter: Dict,
filter: Union[str, Dict],
limit: Union[int, float] = 20,
) -> 'DocumentArray':

Expand Down
35 changes: 30 additions & 5 deletions docs/advanced/document-store/redis.md
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,6 @@ You can check the default values in [the docarray source code](https://github.co
For vector search configurations, default values are those of the database backend, which you can find in the [Redis documentation](https://redis.io/docs/stack/search/reference/vectors/).

```{note}
We will support geo-filtering soon.
The benchmark test is on the way.
```

Expand Down Expand Up @@ -247,8 +246,7 @@ integer in `columns` configuration (`'field': 'int'`) and use a filter query tha

One can search with user-defined query filters using the `.find` method. Such queries follow the [Redis Search Query Syntax](https://redis.io/docs/stack/search/reference/query_syntax/).

Consider a case where you store Documents with a tag of `price` into Redis and you want to retrieve all Documents
with `price` less than or equal to some `max_price` value.
Consider a case where you store Documents with a tag of `price` into Redis and you want to retrieve all Documents with `price` less than or equal to some `max_price` value.

You can index such Documents as follows:

Expand All @@ -272,8 +270,7 @@ for price in da[:, 'tags__price']:
print(f'\t price={price}')
```

Then you can retrieve all documents whose price is less than or equal to `max_price` by applying the following
filter:
Then you can retrieve all documents whose price is less than or equal to `max_price` by applying the following filter:

```python
max_price = 3
Expand All @@ -298,6 +295,34 @@ This would print
price=3
```

With Redis as the storage backend, you can also perform geospatial searches. You can index Documents with a tag of `geo` type, whose value is a `"longitude,latitude"` string, and retrieve all Documents that lie within some `max_distance` of a given Earth coordinate as follows:

```python
from docarray import Document, DocumentArray

n_dim = 3
da = DocumentArray(
    storage='redis',
    config={
        'n_dim': n_dim,
        # Declare `location` as a geo column so it can be used in geo filters.
        'columns': {'location': 'geo'},
    },
)

# Geo tag values are "longitude,latitude" strings.
with da:
    da.extend(
        [
            Document(id=f'r{i}', tags={'location': f"{-98.17+i},{38.71+i}"})
            for i in range(10)
        ]
    )

max_distance = 1000
n_limit = 10  # maximum number of Documents to return

# Geo filter syntax: @field:[longitude latitude radius unit]
filter = f'@location:[-98.71 38.71 {max_distance} km] '
results = da.find(filter=filter, limit=n_limit)
```


(vector-search-index)=
### Update Vector Search Indexing Schema

Expand Down
47 changes: 44 additions & 3 deletions tests/unit/array/mixins/test_find.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import operator
from math import radians

import numpy as np
import pytest

from docarray import DocumentArray, Document
from docarray import Document, DocumentArray
from docarray.math import ndarray
import operator
from sklearn.metrics.pairwise import haversine_distances


def test_customize_metric_fn():
Expand Down Expand Up @@ -632,6 +634,45 @@ def test_redis_category_filter(filter, checker, columns, start_storage):
assert all([checker(r) for r in results])


def test_redis_geo_filter(start_storage):
    """Check that a Redis geo filter only returns Documents inside the radius.

    Ten Documents are indexed with a ``location`` tag stored as a
    ``"longitude,latitude"`` string.  A geo filter centred on
    (lon=-98.71, lat=38.71) with an 800 km radius is applied, and every
    returned Document is verified to lie within that radius using the
    haversine (great-circle) distance.
    """
    n_dim = 128
    center_lon, center_lat = -98.71, 38.71
    radius_km = 800
    earth_radius_km = 6371

    da = DocumentArray(
        storage='redis',
        config={
            'n_dim': n_dim,
            'columns': {'location': 'geo'},
        },
    )

    da.extend(
        [
            Document(
                embedding=np.random.rand(n_dim),
                tags={'location': f"{-98.17+i},{38.71+i}"},
            )
            for i in range(10)
        ]
    )

    # Redis geo filter syntax: @field:[longitude latitude radius unit]
    geo_filter = '@location:[-98.71 38.71 800 km] '

    results = da.find(np.random.rand(n_dim), filter=geo_filter)
    assert len(results) > 0

    # haversine_distances expects each point as [latitude, longitude] in
    # radians, i.e. the reverse of the "lon,lat" order used by the geo tag.
    center = [radians(center_lat), radians(center_lon)]
    for r in results:
        lon, lat = (float(value) for value in r.tags['location'].split(','))
        point = [radians(lat), radians(lon)]
        # The result is an angular distance matrix; scale by the Earth radius
        # to obtain kilometres.
        distance = haversine_distances([center, point]) * earth_radius_km
        assert distance[0][1] < radius_km


@pytest.mark.parametrize('storage', ['memory'])
@pytest.mark.parametrize('columns', [[('price', 'int')], {'price': 'int'}])
def test_unsupported_pre_filtering(storage, start_storage, columns):
Expand Down
16 changes: 1 addition & 15 deletions tests/unit/array/storage/redis/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,6 @@ def _save_offset2ids(self):
pass


# Lookup from a column type name to a field-type marker (bytes).
# NOTE(review): presumably the values are the field types reported by the
# Redis index schema for each column type — confirm against the test using it.
type_convert = dict(
    int=b'NUMERIC',
    float=b'NUMERIC',
    double=b'NUMERIC',
    long=b'NUMERIC',
    str=b'TEXT',
    bytes=b'TEXT',
    bool=b'NUMERIC',
)


@pytest.mark.parametrize('distance', ['L2', 'IP', 'COSINE'])
@pytest.mark.parametrize(
'method,initial_cap,ef_construction,block_size',
Expand All @@ -43,12 +32,9 @@ def _save_offset2ids(self):
@pytest.mark.parametrize(
'columns',
[
[('attr1', 'str'), ('attr2', 'bytes')],
[('attr1', 'int'), ('attr2', 'float')],
[('attr1', 'double'), ('attr2', 'long'), ('attr3', 'int')],
{'attr1': 'str', 'attr2': 'bytes'},
{'attr1': 'int', 'attr2': 'float'},
{'attr1': 'double', 'attr2': 'long', 'attr3': 'int'},
{'attr1': 'double', 'attr2': 'long', 'attr3': 'geo'},
],
)
@pytest.mark.parametrize(
Expand Down