From 883e2e5eb24f0dab4c7e4ea6c4fcae04cfea0139 Mon Sep 17 00:00:00 2001 From: AnneY Date: Tue, 30 Aug 2022 12:22:45 +0800 Subject: [PATCH 01/11] feat: add $and and $or in redis --- docarray/array/storage/redis/find.py | 115 ++++++++++++++++++--------- tests/unit/array/mixins/test_find.py | 90 +++++++++++++++------ 2 files changed, 144 insertions(+), 61 deletions(-) diff --git a/docarray/array/storage/redis/find.py b/docarray/array/storage/redis/find.py index 2459333c1ed..6258c426d14 100644 --- a/docarray/array/storage/redis/find.py +++ b/docarray/array/storage/redis/find.py @@ -7,7 +7,14 @@ from docarray.math.ndarray import to_numpy_array from docarray.score import NamedScore -from redis.commands.search.query import NumericFilter, Query +from redis.commands.search.query import Query +from redis.commands.search.querystring import lt, le, gt, ge, equal +from redis.commands.search.querystring import ( + intersect, + union, + DistjunctUnion, + IntersectNode, +) if TYPE_CHECKING: import tensorflow @@ -32,10 +39,16 @@ def _find_similar_vectors( **kwargs, ): - query_str = self._build_query_str(filter) if filter else "*" + if filter: + nodes = _build_query_nodes(filter) + query_str = intersect(*nodes).to_string() + else: + query_str = "*" + + print(f'query_str: {query_str}') q = ( - Query(f'{query_str}=>[KNN {limit} @embedding $vec AS vector_score]') + Query(f'({query_str})=>[KNN {limit} @embedding $vec AS vector_score]') .sort_by('vector_score') .paging(0, limit) .dialect(2) @@ -74,8 +87,9 @@ def _find( ] def _find_with_filter(self, filter: Dict, limit: Optional[Union[int, float]] = 20): - s = self._build_query_str(filter) - q = Query(s) + nodes = _build_query_nodes(filter) + query_str = intersect(*nodes).to_string() + q = Query(query_str) q.paging(0, limit) results = self._client.ft(index_name=self._config.index_name).search(q).docs @@ -92,36 +106,61 @@ def _filter( return self._find_with_filter(filter, limit=limit) - def _build_query_str(self, filter: Dict) -> str: - INF = "+inf" - NEG_INF = "-inf" - s = "(" - - for key in filter: - operator = list(filter[key].keys())[0] - value = filter[key][operator] - if operator == '$gt': - s += f"@{key}:[({value} {INF}] " - elif operator == '$gte': - s += f"@{key}:[{value} {INF}] " - elif operator == '$lt': - s += f"@{key}:[{NEG_INF} ({value}] " - elif operator == '$lte': - s += f"@{key}:[{NEG_INF} {value}] " - elif operator == '$eq': - if type(value) is int: - s += f"@{key}:[{value} {value}] " - elif type(value) is bool: - s += f"@{key}:[{int(value)} {int(value)}] " - else: - s += f"@{key}:{value} " - elif operator == '$ne': - if type(value) is int: - s += f"-@{key}:[{value} {value}] " - elif type(value) is bool: - s += f"-@{key}:[{int(value)} {int(value)}] " - else: - s += f"-@{key}:{value} " - s += ")" - - return s + +def _build_query_node(filter): + key = list(filter.keys())[0] + operator = list(filter[key].keys())[0] + value = filter[key][operator] + + query_dict = {} + + if operator == '$ne': + print(f'value: {value}, type: {isinstance(value, (int, float))}') + if isinstance(value, (bool)): + query_dict[key] = equal(int(value)) + elif isinstance(value, (int, float)): + query_dict[key] = equal(value) + else: + query_dict[key] = value + node = DistjunctUnion(**query_dict) + else: + if operator == '$gt': + query_dict[key] = gt(value) + elif operator == '$gte': + query_dict[key] = ge(value) + elif operator == '$lt': + query_dict[key] = lt(value) + elif operator == '$lte': + query_dict[key] = le(value) + elif operator == '$eq': + if isinstance(value, (bool)): + query_dict[key] = equal(int(value)) + elif isinstance(value, (int, float)): + query_dict[key] = equal(value) + else: + query_dict[key] = value + else: + raise ValueError( + f'Expecting filter operator one of $gt, $gte, $lt, $lte, $eq, $ne, $and OR $or, got {operator} instead' + ) + node = IntersectNode(**query_dict) + + return node + + +def _build_query_nodes(filter): + nodes = [] + for k, v in filter.items(): + if k == "$and": + children = _build_query_nodes(v) + node = intersect(*children) + nodes.append(node) + elif k == "$or": + children = _build_query_nodes(v) + node = union(*children) + nodes.append(node) + else: + child = _build_query_node({k: v}) + nodes.append(child) + + return nodes diff --git a/tests/unit/array/mixins/test_find.py b/tests/unit/array/mixins/test_find.py index 6ad66f22a5b..04964c80613 100644 --- a/tests/unit/array/mixins/test_find.py +++ b/tests/unit/array/mixins/test_find.py @@ -524,7 +524,12 @@ def test_redis_category_filter(start_storage): storage='redis', config={ 'n_dim': n_dim, - 'columns': [('color', 'str'), ('isfake', 'bool')], + 'columns': [ + ('price', 'int'), + ('category', 'str'), + ('size', 'int'), + ('isfake', 'bool'), + ], 'flush': True, }, ) @@ -534,7 +539,7 @@ def test_redis_category_filter(start_storage): Document( id=f'r{i}', embedding=np.random.rand(n_dim), - tags={'color': 'red', 'isfake': True}, + tags={'price': i, 'category': "Shoes", 'size': i, 'isfake': True}, ) for i in range(10) ] @@ -543,40 +548,79 @@ def test_redis_category_filter(start_storage): da.extend( [ Document( - id=f'r{i}', + id=f'r{i+10}', embedding=np.random.rand(n_dim), - tags={'color': 'blue', 'isfake': False}, + tags={ + 'price': i, + 'category': "Jeans", + 'size': i, + 'isfake': False, + }, ) - for i in range(10, 20) + for i in range(10) ] ) - da.extend( - [ - Document( - id=f'r{i}', - embedding=np.random.rand(n_dim), - tags={'color': 'green', 'isfake': False}, - ) - for i in range(20, 30) - ] + filter1 = { + "$or": { + "price": {"$gt": 8}, + "category": {"$eq": "Shoes"}, + }, + } + results = da.find(np.random.rand(n_dim), filter=filter1) + assert len(results) > 0 + assert all( + [(r.tags['price'] > 8 or r.tags['category'] == "Shoes") for r in results] ) - results = da.find(np.random.rand(n_dim), filter={'color': {'$eq': 'red'}}) + filter2 = { + "$and": { + "price": {"$ne": 8}, + "isfake": {"$eq": True}, + }, + } + results = da.find(np.random.rand(n_dim), filter=filter2) assert len(results) > 0 - assert all([(r.tags['color'] == 'red') for r in results]) + assert all([(r.tags['price'] != 8 and r.tags['isfake'] == True) for r in results]) - results = da.find(np.random.rand(n_dim), filter={'color': {'$ne': 'red'}}) - assert len(results) > 0 - assert all([(r.tags['color'] != 'red') for r in results]) + filter3 = { + "$or": { + "price": {"$lt": 8}, + "isfake": {"$ne": True}, + }, + "size": {"$lte": 3}, + } - results = da.find(np.random.rand(n_dim), filter={'isfake': {'$eq': True}}) + results = da.find(np.random.rand(n_dim), filter=filter3) assert len(results) > 0 - assert all([(r.tags['isfake'] == True) for r in results]) + assert all( + [ + ((r.tags['price'] < 8 or r.tags['isfake'] != True) and r.tags['size'] <= 3) + for r in results + ] + ) - results = da.find(np.random.rand(n_dim), filter={'isfake': {'$ne': True}}) + filter4 = { + "$or": { + "$and": { + "price": {"$gte": 8}, + "category": {"$ne": "Shoes"}, + }, + "size": {"$eq": 3}, + }, + } + + results = da.find(np.random.rand(n_dim), filter=filter4) assert len(results) > 0 - assert all([(r.tags['isfake'] == False) for r in results]) + assert all( + [ + ( + (r.tags['price'] >= 8 and r.tags['category'] != "Shoes") + or r.tags['size'] == 3 + ) + for r in results + ] + ) @pytest.mark.parametrize('storage', ['memory']) From 2c4f523bfde5faa534597de0fb6d34bd4714aa88 Mon Sep 17 00:00:00 2001 From: AnneY Date: Tue, 30 Aug 2022 17:32:00 +0800 Subject: [PATCH 02/11] docs: add redis $and and $or examples --- docs/advanced/document-store/redis.md | 109 +++++++++++++++++++++----- 1 file changed, 88 insertions(+), 21 deletions(-) diff --git a/docs/advanced/document-store/redis.md b/docs/advanced/document-store/redis.md index 7aad6e82956..8b3aea2de0f 100644 --- a/docs/advanced/document-store/redis.md +++ b/docs/advanced/document-store/redis.md @@ -123,7 +123,7 @@ Other functions behave the same as in-memory DocumentArray. ### Vector search with filter query -You can perform Vector Similarity Search based on [FLAT or HNSW algorithm](vector-search-index) and pre-filter results using a filter query that is based on [MongoDB's Query](https://www.mongodb.com/docs/manual/reference/operator/query/). We currently support a subset of those selectors: +You can perform Vector Similarity Search based on [FLAT or HNSW algorithm](vector-search-index) and pre-filter results using a filter query that is based on [MongoDB's Query](https://www.mongodb.com/docs/manual/reference/operator/query/). The following tags filters can be combine with `$and` and `$or`: - `$eq` - Equal to (number, string) - `$ne` - Not equal to (number, string) @@ -134,7 +134,7 @@ You can perform Vector Similarity Search based on [FLAT or HNSW algorithm](vecto Consider Documents with embeddings `[0, 0, 0]` up to `[9, 9, 9]` where the Document with embedding `[i, i, i]` -has tag `price` with a number value, and tag `color` with a string value. You can create such example with the following code: +has tag `price` with a number value, tag `color` with a string value and tag `stock` with a boolean value. You can create such example with the following code: ```python import numpy as np @@ -146,7 +146,11 @@ da = DocumentArray( storage='redis', config={ 'n_dim': n_dim, - 'columns': [('price', 'int'), ('color', 'str')], + 'columns': [ + ('price', 'int'), + ('color', 'str'), + ('stock', 'bool'), + ], 'flush': True, 'distance': 'L2', }, @@ -155,7 +159,9 @@ da = DocumentArray( da.extend( [ Document( - id=f'{i}', embedding=i * np.ones(n_dim), tags={'price': i, 'color': 'red'} + id=f'{i}', + embedding=i * np.ones(n_dim), + tags={'price': i, 'color': 'blue', 'stock': i%2==0}, ) for i in range(10) ] @@ -165,58 +171,119 @@ da.extend( Document( id=f'{i+10}', embedding=i * np.ones(n_dim), - tags={'price': i, 'color': 'blue'}, + tags={'price': i, 'color': 'red', 'stock': i%2==0}, ) for i in range(10) ] ) -print('\nIndexed prices and colors:\n') -for embedding, price, color in zip( - da.embeddings, da[:, 'tags__price'], da[:, 'tags__color'] +print('\nIndexed price, color and stock:\n') +for embedding, price, color, stock in zip( + da.embeddings, da[:, 'tags__price'], da[:, 'tags__color'], da[:, 'tags__stock'] ): - print(f'\tembedding={embedding},\t price={price},\t color={color}') + print(f'\tembedding={embedding},\t color={color},\t stock={stock}') ``` Consider the case where you want the nearest vectors to the embedding `[8., 8., 8.]`, with the restriction that -prices and colors must pass a filter. For example, let's consider that retrieved Documents must have a `price` value lower than or equal to `max_price` and have `color` equal to `color`. We can encode this information in Redis using `{'price': {'$lte': max_price}, 'color': {'$eq': color}}`. +prices, colors and stock must pass a filter. For example, let's consider that retrieved Documents must have a `price` value lower than or equal to `max_price`, have `color` equal to `color` and have `stock` equal to `True`. We can encode this information in Redis using + +```JSON +{ + "price": {"$lte": max_price}, + "color": {"$gt": color}, + "stock": {"$eq": True}, +} +``` +or + +```JSON +{ + "$and": { + "price": {"$lte": max_price}, + "color": {"$gt": color}, + "stock": {"$eq": True}, + } +} +``` Then the search with the proposed filter can be used as follows: ```python max_price = 7 -color = 'red' +color = "blue" n_limit = 5 np_query = np.ones(n_dim) * 8 print(f'\nQuery vector: \t{np_query}') -filter = {'price': {'$lte': max_price}, 'color': {'$eq': color}} +filter = { + "price": {"$lte": max_price}, + "color": {"$eq": color}, + "stock": {"$eq": True}, +} + results = da.find(np_query, filter=filter, limit=n_limit) print( - '\nEmbeddings Approximate Nearest Neighbours with "price" at most 7 and "color" red:\n' + '\nEmbeddings Approximate Nearest Neighbours with "price" at most 7, "color" blue and "stock" False:\n' ) -for embedding, price, color, score in zip( +for embedding, price, color, stock, score in zip( results.embeddings, results[:, 'tags__price'], results[:, 'tags__color'], + results[:, 'tags__stock'], results[:, 'scores'], ): print( - f' score={score["score"].value},\t embedding={embedding},\t price={price},\t color={color}' + f' score={score["score"].value},\t embedding={embedding},\t price={price},\t color={color},\t stock={stock}' ) ``` This would print: ```console -Embeddings Approximate Nearest Neighbours with "price" at most 7 and "color" red: +Embeddings Approximate Nearest Neighbours with "price" at most 7, "color" blue and "stock" False: + + score=12, embedding=[6. 6. 6.], price=6, color=blue, stock=True + score=48, embedding=[4. 4. 4.], price=4, color=blue, stock=True + score=108, embedding=[2. 2. 2.], price=2, color=blue, stock=True + score=192, embedding=[0. 0. 0.], price=0, color=blue, stock=True +``` +More example filter expresses +- A Nike shoes or price less than `100` + +```JSON +{ + "$or": { + "brand": {"$eq": "Nike"}, + "price": {"$lt": 100} + } +} +``` + +- A Nike shoes **and** either price is less than `100` or color is `"blue"` + +```JSON +{ + "brand": {"$eq": "Nike"}, + "$or": { + "price": {"$lt": 100}, + "color": {"$eq": "blue"}, + }, +} +``` - score=3, embedding=[7. 7. 7.], price=7, color=red - score=12, embedding=[6. 6. 6.], price=6, color=red - score=27, embedding=[5. 5. 5.], price=5, color=red - score=48, embedding=[4. 4. 4.], price=4, color=red - score=75, embedding=[3. 3. 3.], price=3, color=red +- A Nike shoes **or** both price is less than `100` and color is `"blue"` + +```JSON +{ + "$or": { + "brand": {"$eq": "Nike"}, + "$and": { + "price": {"$lt": 100}, + "color": {"$eq": "blue"}, + }, + } +} ``` (vector-search-index)= From 76a35f4c290c59b7a8b8f2a538739ae7e57ccb1e Mon Sep 17 00:00:00 2001 From: AnneY Date: Tue, 30 Aug 2022 17:44:54 +0800 Subject: [PATCH 03/11] fix: remove useless print and standardize code --- docarray/array/storage/redis/find.py | 14 +++++--------- tests/unit/array/mixins/test_find.py | 8 ++++---- 2 files changed, 9 insertions(+), 13 deletions(-) diff --git a/docarray/array/storage/redis/find.py b/docarray/array/storage/redis/find.py index 6258c426d14..67b4b2d0415 100644 --- a/docarray/array/storage/redis/find.py +++ b/docarray/array/storage/redis/find.py @@ -45,8 +45,6 @@ def _find_similar_vectors( else: query_str = "*" - print(f'query_str: {query_str}') - q = ( Query(f'({query_str})=>[KNN {limit} @embedding $vec AS vector_score]') .sort_by('vector_score') @@ -107,16 +105,14 @@ def _filter( return self._find_with_filter(filter, limit=limit) -def _build_query_node(filter): - key = list(filter.keys())[0] - operator = list(filter[key].keys())[0] - value = filter[key][operator] +def _build_query_node(key, condition): + operator = list(condition.keys())[0] + value = condition[operator] query_dict = {} if operator == '$ne': - print(f'value: {value}, type: {isinstance(value, (int, float))}') - if isinstance(value, (bool)): + if isinstance(value, bool): query_dict[key] = equal(int(value)) elif isinstance(value, (int, float)): query_dict[key] = equal(value) @@ -160,7 +156,7 @@ def _build_query_nodes(filter): node = union(*children) nodes.append(node) else: - child = _build_query_node({k: v}) + child = _build_query_node(k, v) nodes.append(child) return nodes diff --git a/tests/unit/array/mixins/test_find.py b/tests/unit/array/mixins/test_find.py index 04964c80613..a7f029928f4 100644 --- a/tests/unit/array/mixins/test_find.py +++ b/tests/unit/array/mixins/test_find.py @@ -539,7 +539,7 @@ def test_redis_category_filter(start_storage): Document( id=f'r{i}', embedding=np.random.rand(n_dim), - tags={'price': i, 'category': "Shoes", 'size': i, 'isfake': True}, + tags={'price': i, 'category': 'Shoes', 'size': i, 'isfake': True}, ) for i in range(10) ] @@ -552,7 +552,7 @@ def test_redis_category_filter(start_storage): embedding=np.random.rand(n_dim), tags={ 'price': i, - 'category': "Jeans", + 'category': 'Jeans', 'size': i, 'isfake': False, }, @@ -570,7 +570,7 @@ def test_redis_category_filter(start_storage): results = da.find(np.random.rand(n_dim), filter=filter1) assert len(results) > 0 assert all( - [(r.tags['price'] > 8 or r.tags['category'] == "Shoes") for r in results] + [(r.tags['price'] > 8 or r.tags['category'] == 'Shoes') for r in results] ) filter2 = { @@ -615,7 +615,7 @@ def test_redis_category_filter(start_storage): assert all( [ ( - (r.tags['price'] >= 8 and r.tags['category'] != "Shoes") + (r.tags['price'] >= 8 and r.tags['category'] != 'Shoes') or r.tags['size'] == 3 ) for r in results From 47cadceece437bd97f9f5807515eb91ef969dbbf Mon Sep 17 00:00:00 2001 From: AnneY Date: Wed, 31 Aug 2022 09:26:31 +0800 Subject: [PATCH 04/11] refactor: organize import modules --- docarray/array/storage/redis/find.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/docarray/array/storage/redis/find.py b/docarray/array/storage/redis/find.py index 67b4b2d0415..1689fa57c5b 100644 --- a/docarray/array/storage/redis/find.py +++ b/docarray/array/storage/redis/find.py @@ -1,4 +1,5 @@ -from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, TypeVar, Union +from typing import (TYPE_CHECKING, Dict, List, Optional, Sequence, TypeVar, + Union) import numpy as np from docarray import Document, DocumentArray @@ -8,13 +9,9 @@ from docarray.score import NamedScore from redis.commands.search.query import Query -from redis.commands.search.querystring import lt, le, gt, ge, equal -from redis.commands.search.querystring import ( - intersect, - union, - DistjunctUnion, - IntersectNode, -) +from redis.commands.search.querystring import (DistjunctUnion, IntersectNode, + equal, ge, gt, intersect, le, + lt, union) if TYPE_CHECKING: import tensorflow From 2aab14fe5ad8f950a5c30c216ef228b37fa431dc Mon Sep 17 00:00:00 2001 From: AnneY Date: Wed, 31 Aug 2022 13:59:11 +0800 Subject: [PATCH 05/11] refactor: black find.py --- docarray/array/storage/redis/find.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/docarray/array/storage/redis/find.py b/docarray/array/storage/redis/find.py index 1689fa57c5b..a9ce4197ed3 100644 --- a/docarray/array/storage/redis/find.py +++ b/docarray/array/storage/redis/find.py @@ -1,5 +1,4 @@ -from typing import (TYPE_CHECKING, Dict, List, Optional, Sequence, TypeVar, - Union) +from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, TypeVar, Union import numpy as np from docarray import Document, DocumentArray @@ -9,9 +8,17 @@ from docarray.score import NamedScore from redis.commands.search.query import Query -from redis.commands.search.querystring import (DistjunctUnion, IntersectNode, - equal, ge, gt, intersect, le, - lt, union) +from redis.commands.search.querystring import ( + DistjunctUnion, + IntersectNode, + equal, + ge, + gt, + intersect, + le, + lt, + union, +) if TYPE_CHECKING: import tensorflow From 5636115fa6aa3c8ecc23af6cb6c694db53670dbc Mon Sep 17 00:00:00 2001 From: AnneY Date: Sat, 3 Sep 2022 16:07:47 +0800 Subject: [PATCH 06/11] refactor: code style --- docarray/array/storage/redis/find.py | 8 +- tests/unit/array/mixins/test_find.py | 110 ++++++++++++--------------- 2 files changed, 54 insertions(+), 64 deletions(-) diff --git a/docarray/array/storage/redis/find.py b/docarray/array/storage/redis/find.py index a9ce4197ed3..af94dded057 100644 --- a/docarray/array/storage/redis/find.py +++ b/docarray/array/storage/redis/find.py @@ -47,7 +47,7 @@ def _find_similar_vectors( nodes = _build_query_nodes(filter) query_str = intersect(*nodes).to_string() else: - query_str = "*" + query_str = '*' q = ( Query(f'({query_str})=>[KNN {limit} @embedding $vec AS vector_score]') @@ -133,7 +133,7 @@ def _build_query_node(key, condition): elif operator == '$lte': query_dict[key] = le(value) elif operator == '$eq': - if isinstance(value, (bool)): + if isinstance(value, bool): query_dict[key] = equal(int(value)) elif isinstance(value, (int, float)): query_dict[key] = equal(value) @@ -151,11 +151,11 @@ def _build_query_node(key, condition): def _build_query_nodes(filter): nodes = [] for k, v in filter.items(): - if k == "$and": + if k == '$and': children = _build_query_nodes(v) node = intersect(*children) nodes.append(node) - elif k == "$or": + elif k == '$or': children = _build_query_nodes(v) node = union(*children) nodes.append(node) diff --git a/tests/unit/array/mixins/test_find.py b/tests/unit/array/mixins/test_find.py index a7f029928f4..65fa50580ce 100644 --- a/tests/unit/array/mixins/test_find.py +++ b/tests/unit/array/mixins/test_find.py @@ -518,7 +518,54 @@ def test_weaviate_filter_query(start_storage): assert isinstance(da._filter(filter={}), type(da)) -def test_redis_category_filter(start_storage): +@pytest.mark.parametrize( + 'filter,checker', + [ + ( + { + "$or": { + "price": {"$gt": 8}, + "category": {"$eq": "Shoes"}, + }, + }, + lambda r: r.tags['price'] > 8 or r.tags['category'] == 'Shoes', + ), + ( + { + "$and": { + "price": {"$ne": 8}, + "isfake": {"$eq": True}, + }, + }, + lambda r: r.tags['price'] != 8 and r.tags['isfake'] == True, + ), + ( + { + "$or": { + "price": {"$lt": 8}, + "isfake": {"$ne": True}, + }, + "size": {"$lte": 3}, + }, + lambda r: (r.tags['price'] < 8 or r.tags['isfake'] != True) + and r.tags['size'] <= 3, + ), + ( + { + "$or": { + "$and": { + "price": {"$gte": 8}, + "category": {"$ne": "Shoes"}, + }, + "size": {"$eq": 3}, + }, + }, + lambda r: (r.tags['price'] >= 8 and r.tags['category'] != 'Shoes') + or r.tags['size'] == 3, + ), + ], +) +def test_redis_category_filter(filter, checker, start_storage): n_dim = 128 da = DocumentArray( storage='redis', @@ -561,66 +608,9 @@ def test_redis_category_filter(start_storage): ] ) - filter1 = { - "$or": { - "price": {"$gt": 8}, - "category": {"$eq": "Shoes"}, - }, - } - results = da.find(np.random.rand(n_dim), filter=filter1) - assert len(results) > 0 - assert all( - [(r.tags['price'] > 8 or r.tags['category'] == 'Shoes') for r in results] - ) - - filter2 = { - "$and": { - "price": {"$ne": 8}, - "isfake": {"$eq": True}, - }, - } - results = da.find(np.random.rand(n_dim), filter=filter2) - assert len(results) > 0 - assert all([(r.tags['price'] != 8 and r.tags['isfake'] == True) for r in results]) - - filter3 = { - "$or": { - "price": {"$lt": 8}, - "isfake": {"$ne": True}, - }, - "size": {"$lte": 3}, - } - - results = da.find(np.random.rand(n_dim), filter=filter3) - assert len(results) > 0 - assert all( - [ - ((r.tags['price'] < 8 or r.tags['isfake'] != True) and r.tags['size'] <= 3) - for r in results - ] - ) - - filter4 = { - "$or": { - "$and": { - "price": {"$gte": 8}, - "category": {"$ne": "Shoes"}, - }, - "size": {"$eq": 3}, - }, - } - - results = da.find(np.random.rand(n_dim), filter=filter4) + results = da.find(np.random.rand(n_dim), filter=filter) assert len(results) > 0 - assert all( - [ - ( - (r.tags['price'] >= 8 and r.tags['category'] != 'Shoes') - or r.tags['size'] == 3 - ) - for r in results - ] - ) + assert all([checker(r) for r in results]) @pytest.mark.parametrize('storage', ['memory']) From 70afce79406e32f9439a49a041d6a147b36cb7ee Mon Sep 17 00:00:00 2001 From: AnneY Date: Mon, 5 Sep 2022 20:44:06 +0800 Subject: [PATCH 07/11] refactor: simplify code --- docarray/array/storage/redis/find.py | 53 +++++++++++++--------------- 1 file changed, 24 insertions(+), 29 deletions(-) diff --git a/docarray/array/storage/redis/find.py b/docarray/array/storage/redis/find.py index af94dded057..c031c4917aa 100644 --- a/docarray/array/storage/redis/find.py +++ b/docarray/array/storage/redis/find.py @@ -39,7 +39,7 @@ def _find_similar_vectors( self, query: 'RedisArrayType', filter: Optional[Dict] = None, - limit: Optional[Union[int, float]] = 20, + limit: int = 20, **kwargs, ): @@ -73,7 +73,7 @@ def _find_similar_vectors( def _find( self, query: 'RedisArrayType', - limit: Optional[Union[int, float]] = 20, + limit: int = 20, filter: Optional[Dict] = None, **kwargs, ) -> List['DocumentArray']: @@ -88,7 +88,7 @@ def _find( for q in query ] - def _find_with_filter(self, filter: Dict, limit: Optional[Union[int, float]] = 20): + def _find_with_filter(self, filter: Dict, limit: int = 20): nodes = _build_query_nodes(filter) query_str = intersect(*nodes).to_string() q = Query(query_str) @@ -102,9 +102,7 @@ def _find_with_filter(self, filter: Dict, limit: Optional[Union[int, float]] = 2 da.append(doc) return da - def _filter( - self, filter: Dict, limit: Optional[Union[int, float]] = 20 - ) -> 'DocumentArray': + def _filter(self, filter: Dict, limit: int = 20) -> 'DocumentArray': return self._find_with_filter(filter, limit=limit) @@ -115,37 +113,29 @@ def _build_query_node(key, condition): query_dict = {} - if operator == '$ne': + if operator in ['$ne', '$eq']: if isinstance(value, bool): query_dict[key] = equal(int(value)) elif isinstance(value, (int, float)): query_dict[key] = equal(value) else: query_dict[key] = value - node = DistjunctUnion(**query_dict) + elif operator == '$gt': + query_dict[key] = gt(value) + elif operator == '$gte': + query_dict[key] = ge(value) + elif operator == '$lt': + query_dict[key] = lt(value) + elif operator == '$lte': + query_dict[key] = le(value) else: - if operator == '$gt': - query_dict[key] = gt(value) - elif operator == '$gte': - query_dict[key] = ge(value) - elif operator == '$lt': - query_dict[key] = lt(value) - elif operator == '$lte': - query_dict[key] = le(value) - elif operator == '$eq': - if isinstance(value, bool): - query_dict[key] = equal(int(value)) - elif isinstance(value, (int, float)): - query_dict[key] = equal(value) - else: - query_dict[key] = value - else: - raise ValueError( - f'Expecting filter operator one of $gt, $gte, $lt, $lte, $eq, $ne, $and OR $or, got {operator} instead' - ) - node = IntersectNode(**query_dict) + raise ValueError( + f'Expecting filter operator one of $gt, $gte, $lt, $lte, $eq, $ne, $and OR $or, got {operator} instead' + ) - return node + if operator == '$ne': + return DistjunctUnion(**query_dict) + return IntersectNode(**query_dict) def _build_query_nodes(filter): @@ -164,3 +154,8 @@ def _build_query_nodes(filter): nodes.append(child) return nodes + + +def _build_query_str(query): + query_str = "|".join(query.split(" ")) + return query_str From f56114b671933e385856acb6c022d8cb504e61ba Mon Sep 17 00:00:00 2001 From: AnneY Date: Tue, 6 Sep 2022 21:59:46 +0800 Subject: [PATCH 08/11] fix: remove useless function --- docarray/array/storage/redis/find.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/docarray/array/storage/redis/find.py b/docarray/array/storage/redis/find.py index c031c4917aa..a9d5109ff52 100644 --- a/docarray/array/storage/redis/find.py +++ b/docarray/array/storage/redis/find.py @@ -155,7 +155,3 @@ def _build_query_nodes(filter): return nodes - -def _build_query_str(query): - query_str = "|".join(query.split(" ")) - return query_str From ed038f6f473caf1929580bb911a29e4a62bc0967 Mon Sep 17 00:00:00 2001 From: AnneY Date: Wed, 7 Sep 2022 16:06:24 +0800 Subject: [PATCH 09/11] docs: clearify formulation --- docs/advanced/document-store/redis.md | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/docs/advanced/document-store/redis.md b/docs/advanced/document-store/redis.md index 8b3aea2de0f..07bfe12a5b1 100644 --- a/docs/advanced/document-store/redis.md +++ b/docs/advanced/document-store/redis.md @@ -41,7 +41,7 @@ da = DocumentArray( ) ``` -The usage would be the same as the ordinary DocumentArray, but the dimension of an embedding for a Document must be provided at creation time. +The usage will be the same as the ordinary DocumentArray, but the dimension of an embedding for a Document must be provided at creation time. ```{caution} Currently, one Redis server instance can only store a single DocumentArray. @@ -184,10 +184,9 @@ for embedding, price, color, stock in zip( print(f'\tembedding={embedding},\t color={color},\t stock={stock}') ``` -Consider the case where you want the nearest vectors to the embedding `[8., 8., 8.]`, with the restriction that -prices, colors and stock must pass a filter. For example, let's consider that retrieved Documents must have a `price` value lower than or equal to `max_price`, have `color` equal to `color` and have `stock` equal to `True`. We can encode this information in Redis using +Consider the case where you want the nearest vectors to the embedding `[8., 8., 8.]`, with the restriction that prices, colors and stock must pass a filter. For example, let's consider that retrieved Documents must have a `price` value lower than or equal to `max_price`, have `color` equal to `blue` and have `stock` equal to `True`. We can encode this information in Redis using -```JSON +```text { "price": {"$lte": max_price}, "color": {"$gt": color}, @@ -196,7 +195,7 @@ prices, colors and stock must pass a filter. For example, let's consider that re ``` or -```JSON +```text { "$and": { "price": {"$lte": max_price}, @@ -238,7 +237,7 @@ for embedding, price, color, stock, score in zip( ) ``` -This would print: +This will print: ```console Embeddings Approximate Nearest Neighbours with "price" at most 7, "color" blue and "stock" False: @@ -334,7 +333,7 @@ for embedding, score in zip( print(f' embedding={embedding},\t score={score["score"].value}') ``` -This would print: +This will print: ```console Embeddings Approximate Nearest Neighbours: From b43c081013dcb8e42c5383ba24b45fd69856b4da Mon Sep 17 00:00:00 2001 From: AnneY Date: Wed, 7 Sep 2022 18:49:05 +0800 Subject: [PATCH 10/11] refactor: reformat code --- docarray/array/storage/redis/find.py | 1 - 1 file changed, 1 deletion(-) diff --git a/docarray/array/storage/redis/find.py b/docarray/array/storage/redis/find.py index a9d5109ff52..cd38fd98fb1 100644 --- a/docarray/array/storage/redis/find.py +++ b/docarray/array/storage/redis/find.py @@ -154,4 +154,3 @@ def _build_query_nodes(filter): nodes.append(child) return nodes - From 834e09e7fc8ae0591bd51d5fcb68594226b6adec Mon Sep 17 00:00:00 2001 From: Alaeddine Abdessalem Date: Thu, 8 Sep 2022 17:11:23 +0100 Subject: [PATCH 11/11] chore: lint --- tests/unit/array/mixins/test_find.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/array/mixins/test_find.py b/tests/unit/array/mixins/test_find.py index c5b66975d53..38d4e500603 100644 --- a/tests/unit/array/mixins/test_find.py +++ b/tests/unit/array/mixins/test_find.py @@ -519,7 +519,7 @@ def test_weaviate_filter_query(start_storage, columns): 'columns', [ [('price', 'int'), ('category', 'str'), ('size', 'int'), ('isfake', 'bool')], - {'price': 'int', 'category': 'str', 'size': 'int', 'isfake': 'bool'} + {'price': 'int', 'category': 'str', 'size': 'int', 'isfake': 'bool'}, ], ) @pytest.mark.parametrize(