diff --git a/docarray/array/storage/redis/find.py b/docarray/array/storage/redis/find.py index 2459333c1ed..cd38fd98fb1 100644 --- a/docarray/array/storage/redis/find.py +++ b/docarray/array/storage/redis/find.py @@ -7,7 +7,18 @@ from docarray.math.ndarray import to_numpy_array from docarray.score import NamedScore -from redis.commands.search.query import NumericFilter, Query +from redis.commands.search.query import Query +from redis.commands.search.querystring import ( + DistjunctUnion, + IntersectNode, + equal, + ge, + gt, + intersect, + le, + lt, + union, +) if TYPE_CHECKING: import tensorflow @@ -28,14 +39,18 @@ def _find_similar_vectors( self, query: 'RedisArrayType', filter: Optional[Dict] = None, - limit: Optional[Union[int, float]] = 20, + limit: int = 20, **kwargs, ): - query_str = self._build_query_str(filter) if filter else "*" + if filter: + nodes = _build_query_nodes(filter) + query_str = intersect(*nodes).to_string() + else: + query_str = '*' q = ( - Query(f'{query_str}=>[KNN {limit} @embedding $vec AS vector_score]') + Query(f'({query_str})=>[KNN {limit} @embedding $vec AS vector_score]') .sort_by('vector_score') .paging(0, limit) .dialect(2) @@ -58,7 +73,7 @@ def _find_similar_vectors( def _find( self, query: 'RedisArrayType', - limit: Optional[Union[int, float]] = 20, + limit: int = 20, filter: Optional[Dict] = None, **kwargs, ) -> List['DocumentArray']: @@ -73,9 +88,10 @@ def _find( for q in query ] - def _find_with_filter(self, filter: Dict, limit: Optional[Union[int, float]] = 20): - s = self._build_query_str(filter) - q = Query(s) + def _find_with_filter(self, filter: Dict, limit: int = 20): + nodes = _build_query_nodes(filter) + query_str = intersect(*nodes).to_string() + q = Query(query_str) q.paging(0, limit) results = self._client.ft(index_name=self._config.index_name).search(q).docs @@ -86,42 +102,55 @@ def _find_with_filter(self, filter: Dict, limit: Optional[Union[int, float]] = 2 da.append(doc) return da - def _filter( - self, filter: 
Dict, limit: Optional[Union[int, float]] = 20 - ) -> 'DocumentArray': + def _filter(self, filter: Dict, limit: int = 20) -> 'DocumentArray': return self._find_with_filter(filter, limit=limit) - def _build_query_str(self, filter: Dict) -> str: - INF = "+inf" - NEG_INF = "-inf" - s = "(" - - for key in filter: - operator = list(filter[key].keys())[0] - value = filter[key][operator] - if operator == '$gt': - s += f"@{key}:[({value} {INF}] " - elif operator == '$gte': - s += f"@{key}:[{value} {INF}] " - elif operator == '$lt': - s += f"@{key}:[{NEG_INF} ({value}] " - elif operator == '$lte': - s += f"@{key}:[{NEG_INF} {value}] " - elif operator == '$eq': - if type(value) is int: - s += f"@{key}:[{value} {value}] " - elif type(value) is bool: - s += f"@{key}:[{int(value)} {int(value)}] " - else: - s += f"@{key}:{value} " - elif operator == '$ne': - if type(value) is int: - s += f"-@{key}:[{value} {value}] " - elif type(value) is bool: - s += f"-@{key}:[{int(value)} {int(value)}] " - else: - s += f"-@{key}:{value} " - s += ")" - - return s + +def _build_query_node(key, condition): + operator = list(condition.keys())[0] + value = condition[operator] + + query_dict = {} + + if operator in ['$ne', '$eq']: + if isinstance(value, bool): + query_dict[key] = equal(int(value)) + elif isinstance(value, (int, float)): + query_dict[key] = equal(value) + else: + query_dict[key] = value + elif operator == '$gt': + query_dict[key] = gt(value) + elif operator == '$gte': + query_dict[key] = ge(value) + elif operator == '$lt': + query_dict[key] = lt(value) + elif operator == '$lte': + query_dict[key] = le(value) + else: + raise ValueError( + f'Expecting filter operator one of $gt, $gte, $lt, $lte, $eq, $ne, $and OR $or, got {operator} instead' + ) + + if operator == '$ne': + return DistjunctUnion(**query_dict) + return IntersectNode(**query_dict) + + +def _build_query_nodes(filter): + nodes = [] + for k, v in filter.items(): + if k == '$and': + children = _build_query_nodes(v) + node = 
intersect(*children) + nodes.append(node) + elif k == '$or': + children = _build_query_nodes(v) + node = union(*children) + nodes.append(node) + else: + child = _build_query_node(k, v) + nodes.append(child) + + return nodes diff --git a/docs/advanced/document-store/redis.md b/docs/advanced/document-store/redis.md index 763d6378606..9a67404d734 100644 --- a/docs/advanced/document-store/redis.md +++ b/docs/advanced/document-store/redis.md @@ -41,7 +41,7 @@ da = DocumentArray( ) ``` -The usage would be the same as the ordinary DocumentArray, but the dimension of an embedding for a Document must be provided at creation time. +The usage will be the same as the ordinary DocumentArray, but the dimension of an embedding for a Document must be provided at creation time. ```{caution} Currently, one Redis server instance can only store a single DocumentArray. @@ -123,7 +123,7 @@ Other functions behave the same as in-memory DocumentArray. ### Vector search with filter query -You can perform Vector Similarity Search based on [FLAT or HNSW algorithm](vector-search-index) and pre-filter results using a filter query that is based on [MongoDB's Query](https://www.mongodb.com/docs/manual/reference/operator/query/). We currently support a subset of those selectors: +You can perform Vector Similarity Search based on [FLAT or HNSW algorithm](vector-search-index) and pre-filter results using a filter query that is based on [MongoDB's Query](https://www.mongodb.com/docs/manual/reference/operator/query/). The following tag filters can be combined with `$and` and `$or`: - `$eq` - Equal to (number, string) - `$ne` - Not equal to (number, string) @@ -134,7 +134,7 @@ You can perform Vector Similarity Search based on [FLAT or HNSW algorithm](vecto Consider Documents with embeddings `[0, 0, 0]` up to `[9, 9, 9]` where the Document with embedding `[i, i, i]` -has tag `price` with a number value, and tag `color` with a string value. 
You can create such example with the following code: +has tag `price` with a number value, tag `color` with a string value and tag `stock` with a boolean value. You can create such an example with the following code: ```python import numpy as np @@ -146,7 +146,7 @@ da = DocumentArray( storage='redis', config={ 'n_dim': n_dim, - 'columns': {'price': 'int', 'color': 'str'}, + 'columns': {'price': 'int', 'color': 'str', 'stock': 'bool'}, 'flush': True, 'distance': 'L2', }, @@ -155,7 +155,9 @@ da = DocumentArray( da.extend( [ Document( - id=f'{i}', embedding=i * np.ones(n_dim), tags={'price': i, 'color': 'red'} + id=f'{i}', + embedding=i * np.ones(n_dim), + tags={'price': i, 'color': 'blue', 'stock': i%2==0}, ) for i in range(10) ] @@ -165,58 +167,118 @@ da.extend( Document( id=f'{i+10}', embedding=i * np.ones(n_dim), - tags={'price': i, 'color': 'blue'}, + tags={'price': i, 'color': 'red', 'stock': i%2==0}, ) for i in range(10) ] ) -print('\nIndexed prices and colors:\n') -for embedding, price, color in zip( - da.embeddings, da[:, 'tags__price'], da[:, 'tags__color'] +print('\nIndexed price, color and stock:\n') +for embedding, price, color, stock in zip( + da.embeddings, da[:, 'tags__price'], da[:, 'tags__color'], da[:, 'tags__stock'] ): - print(f'\tembedding={embedding},\t price={price},\t color={color}') + print(f'\tembedding={embedding},\t price={price},\t color={color},\t stock={stock}') ``` -Consider the case where you want the nearest vectors to the embedding `[8., 8., 8.]`, with the restriction that -prices and colors must pass a filter. For example, let's consider that retrieved Documents must have a `price` value lower than or equal to `max_price` and have `color` equal to `color`. We can encode this information in Redis using `{'price': {'$lte': max_price}, 'color': {'$eq': color}}`. +Consider the case where you want the nearest vectors to the embedding `[8., 8., 8.]`, with the restriction that prices, colors and stock must pass a filter. 
For example, let's consider that retrieved Documents must have a `price` value lower than or equal to `max_price`, have `color` equal to `blue` and have `stock` equal to `True`. We can encode this information in Redis using + +```text +{ + "price": {"$lte": max_price}, + "color": {"$eq": color}, + "stock": {"$eq": True}, +} +``` +or + +```text +{ + "$and": { + "price": {"$lte": max_price}, + "color": {"$eq": color}, + "stock": {"$eq": True}, + } +} +``` Then the search with the proposed filter can be used as follows: ```python max_price = 7 -color = 'red' +color = "blue" n_limit = 5 np_query = np.ones(n_dim) * 8 print(f'\nQuery vector: \t{np_query}') -filter = {'price': {'$lte': max_price}, 'color': {'$eq': color}} +filter = { + "price": {"$lte": max_price}, + "color": {"$eq": color}, + "stock": {"$eq": True}, +} + results = da.find(np_query, filter=filter, limit=n_limit) print( - '\nEmbeddings Approximate Nearest Neighbours with "price" at most 7 and "color" red:\n' + '\nEmbeddings Approximate Nearest Neighbours with "price" at most 7, "color" blue and "stock" True:\n' ) -for embedding, price, color, score in zip( +for embedding, price, color, stock, score in zip( results.embeddings, results[:, 'tags__price'], results[:, 'tags__color'], + results[:, 'tags__stock'], results[:, 'scores'], ): print( - f' score={score["score"].value},\t embedding={embedding},\t price={price},\t color={color}' + f' score={score["score"].value},\t embedding={embedding},\t price={price},\t color={color},\t stock={stock}' ) ``` -This would print: +This will print: ```console -Embeddings Approximate Nearest Neighbours with "price" at most 7 and "color" red: +Embeddings Approximate Nearest Neighbours with "price" at most 7, "color" blue and "stock" True: + + score=12, embedding=[6. 6. 6.], price=6, color=blue, stock=True + score=48, embedding=[4. 4. 4.], price=4, color=blue, stock=True + score=108, embedding=[2. 2. 2.], price=2, color=blue, stock=True + score=192, embedding=[0. 0. 
0.], price=0, color=blue, stock=True +``` +More example filter expressions: +- Brand is Nike **or** price is less than `100` + +```JSON +{ + "$or": { + "brand": {"$eq": "Nike"}, + "price": {"$lt": 100} + } +} +``` - score=3, embedding=[7. 7. 7.], price=7, color=red - score=12, embedding=[6. 6. 6.], price=6, color=red - score=27, embedding=[5. 5. 5.], price=5, color=red - score=48, embedding=[4. 4. 4.], price=4, color=red - score=75, embedding=[3. 3. 3.], price=3, color=red + +- Brand is Nike **and** either price is less than `100` or color is `"blue"` + +```JSON +{ + "brand": {"$eq": "Nike"}, + "$or": { + "price": {"$lt": 100}, + "color": {"$eq": "blue"}, + }, +} +``` + +- Brand is Nike **or** both price is less than `100` and color is `"blue"` + +```JSON +{ + "$or": { + "brand": {"$eq": "Nike"}, + "$and": { + "price": {"$lt": 100}, + "color": {"$eq": "blue"}, + }, + } +} ``` (vector-search-index)= @@ -267,7 +329,7 @@ for embedding, score in zip( print(f' embedding={embedding},\t score={score["score"].value}') ``` -This would print: +This will print: ```console Embeddings Approximate Nearest Neighbours: diff --git a/tests/unit/array/mixins/test_find.py b/tests/unit/array/mixins/test_find.py index 6ce3d4d1fbc..38d4e500603 100644 --- a/tests/unit/array/mixins/test_find.py +++ b/tests/unit/array/mixins/test_find.py @@ -517,9 +517,59 @@ def test_weaviate_filter_query(start_storage, columns): @pytest.mark.parametrize( 'columns', - [[('color', 'str'), ('isfake', 'bool')], {'color': 'str', 'isfake': 'bool'}], + [ + [('price', 'int'), ('category', 'str'), ('size', 'int'), ('isfake', 'bool')], + {'price': 'int', 'category': 'str', 'size': 'int', 'isfake': 'bool'}, + ], ) -def test_redis_category_filter(start_storage, columns): +@pytest.mark.parametrize( + 'filter,checker', + [ + ( + { + "$or": { + "price": {"$gt": 8}, + "category": {"$eq": "Shoes"}, + }, + }, + lambda r: r.tags['price'] > 8 or r.tags['category'] == 'Shoes', + ), + ( + { + "$and": { + "price": {"$ne": 8}, + "isfake": 
{"$eq": True}, + }, + }, + lambda r: r.tags['price'] != 8 and r.tags['isfake'] == True, + ), + ( + { + "$or": { + "price": {"$lt": 8}, + "isfake": {"$ne": True}, + }, + "size": {"$lte": 3}, + }, + lambda r: (r.tags['price'] < 8 or r.tags['isfake'] != True) + and r.tags['size'] <= 3, + ), + ( + { + "$or": { + "$and": { + "price": {"$gte": 8}, + "category": {"$ne": "Shoes"}, + }, + "size": {"$eq": 3}, + }, + }, + lambda r: (r.tags['price'] >= 8 and r.tags['category'] != 'Shoes') + or r.tags['size'] == 3, + ), + ], +) +def test_redis_category_filter(filter, checker, start_storage, columns): n_dim = 128 da = DocumentArray( storage='redis', @@ -535,7 +585,7 @@ def test_redis_category_filter(start_storage, columns): Document( id=f'r{i}', embedding=np.random.rand(n_dim), - tags={'color': 'red', 'isfake': True}, + tags={'price': i, 'category': 'Shoes', 'size': i, 'isfake': True}, ) for i in range(10) ] @@ -544,40 +594,22 @@ def test_redis_category_filter(start_storage, columns): da.extend( [ Document( - id=f'r{i}', - embedding=np.random.rand(n_dim), - tags={'color': 'blue', 'isfake': False}, - ) - for i in range(10, 20) - ] - ) - - da.extend( - [ - Document( - id=f'r{i}', + id=f'r{i+10}', embedding=np.random.rand(n_dim), - tags={'color': 'green', 'isfake': False}, + tags={ + 'price': i, + 'category': 'Jeans', + 'size': i, + 'isfake': False, + }, ) - for i in range(20, 30) + for i in range(10) ] ) - results = da.find(np.random.rand(n_dim), filter={'color': {'$eq': 'red'}}) - assert len(results) > 0 - assert all([(r.tags['color'] == 'red') for r in results]) - - results = da.find(np.random.rand(n_dim), filter={'color': {'$ne': 'red'}}) - assert len(results) > 0 - assert all([(r.tags['color'] != 'red') for r in results]) - - results = da.find(np.random.rand(n_dim), filter={'isfake': {'$eq': True}}) - assert len(results) > 0 - assert all([(r.tags['isfake'] == True) for r in results]) - - results = da.find(np.random.rand(n_dim), filter={'isfake': {'$ne': True}}) + results = 
da.find(np.random.rand(n_dim), filter=filter) assert len(results) > 0 - assert all([(r.tags['isfake'] == False) for r in results]) + assert all([checker(r) for r in results]) @pytest.mark.parametrize('storage', ['memory'])