-
Notifications
You must be signed in to change notification settings - Fork 238
feat(redis): implement Redis storage backend and unit tests #452
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
fc3f93b
703ea09
6f4f8b1
06fc38a
f21e2da
3a3e78b
5ae1409
9e716d1
90e5ae9
5300e23
a6e1968
ca8efb9
0dab0c4
0276073
f64ee39
94f7e73
86323c2
f23b73a
15d5ff6
5369de0
58a365b
56aab5d
0cb15df
902fede
f918fcb
e29061d
bcc6ae3
bf84282
0358083
926ecc5
897465b
15c1c17
bdf5529
f827c2b
4626ffd
be1413d
096b303
3d4f411
a3144cb
13e6b6c
6949fe2
08d6644
a8f6eeb
efd052d
60b9526
5da0d89
0d32d68
6faa2ab
cfbb543
a91d2c5
8bc105a
8e89d14
e36ae39
f4b19b7
aa2f8b2
e46174c
634b62d
5ddaeb7
57feb9b
cad75f7
75e7a90
12a9985
61360d0
d4a82d1
8a6168d
7d3de6c
4d0505f
911b648
baeccd2
c0a1f35
b2f3a77
ddd1eba
2d7c75c
e85f1f7
53cbaa9
97bc99e
3bba4b0
c0f083f
582b56b
1e2682b
80cff7a
549939a
1fce045
cfd9aa8
b718087
c23f630
4e1e933
8dbc5ee
be3f14a
55a6c1c
173066a
59dc17f
1dce140
4f5ff4a
e23e7ef
ac6352c
5f8ff6f
e74bca8
73b6dbe
be72851
fdd9462
0cf9618
8d0f251
fec9cbf
9063420
fe800d8
5f94757
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,19 @@ | ||
from .document import DocumentArray
from .storage.redis import RedisConfig, StorageMixins

__all__ = ['DocumentArrayRedis', 'RedisConfig']


class DocumentArrayRedis(StorageMixins, DocumentArray):
    """A :class:`DocumentArray` variant that keeps its documents in Redis
    and uses the Redis search module as its vector search engine.
    """

    def __new__(cls, *args, **kwargs):
        """Create a new, uninitialized :class:`DocumentArrayRedis`.

        The arguments are accepted only so the signature matches the
        constructor call; all configuration is consumed later by the
        storage mixins during ``__init__``.

        :param args: positional arguments of the constructor call
        :param kwargs: keyword arguments of the constructor call
        :return: the newly allocated :class:`DocumentArrayRedis` object
        """
        return super().__new__(cls)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,12 @@ | ||
from abc import ABC

from .backend import BackendMixin, RedisConfig
from .find import FindMixin
from .getsetdel import GetSetDelMixin
from .seqlike import SequenceLikeMixin

__all__ = ['StorageMixins', 'RedisConfig']


class StorageMixins(FindMixin, BackendMixin, GetSetDelMixin, SequenceLikeMixin, ABC):
    """Bundle of all mixins required for a Redis-backed DocumentArray:
    vector search, backend setup, item access, and sequence behavior.
    """
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,180 @@ | ||
| from dataclasses import dataclass, field | ||
| from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union | ||
|
|
||
| import numpy as np | ||
| from docarray import Document | ||
| from docarray.array.storage.base.backend import BaseBackendMixin, TypeMap | ||
| from docarray.helper import dataclass_from_dict | ||
|
|
||
| from redis import Redis | ||
| from redis.commands.search.field import NumericField, TextField, VectorField | ||
| from redis.commands.search.indexDefinition import IndexDefinition | ||
|
|
||
| if TYPE_CHECKING: | ||
| from docarray.typing import ArrayType, DocumentArraySourceType | ||
|
|
||
|
|
||
@dataclass
class RedisConfig:
    """Configuration of a Redis storage backend.

    :param n_dim: dimensionality of the stored embedding vectors (required)
    :param host: Redis server host
    :param port: Redis server port
    :param index_name: name of the RediSearch index
    :param flush: if True, flush the whole Redis database on connect
    :param update_schema: if True, drop and recreate the index schema
    :param distance: vector distance metric, one of COSINE / L2 / IP
    :param redis_config: extra keyword arguments forwarded to ``redis.Redis``
    :param batch_size: batch size used by bulk operations
    :param method: vector index algorithm, HNSW or FLAT
    :param ef_construction: HNSW build-time exploration factor
    :param m: HNSW maximum number of outgoing edges per node
    :param ef_runtime: HNSW query-time exploration factor
    :param block_size: FLAT index block size in bytes
    :param initial_cap: optional initial vector index capacity
    :param columns: optional ``(name, type)`` pairs of extra indexed columns
    """

    n_dim: int
    host: str = 'localhost'
    port: int = 6379
    index_name: str = 'idx'
    flush: bool = False
    update_schema: bool = True
    distance: str = 'COSINE'
    # default_factory keeps each instance's dict independent
    redis_config: Dict[str, Any] = field(default_factory=dict)
    batch_size: int = 64
    method: str = 'HNSW'
    ef_construction: int = 200
    m: int = 16
    ef_runtime: int = 10
    block_size: int = 1048576
    initial_cap: Optional[int] = None
    columns: Optional[List[Tuple[str, str]]] = None
|
|
||
|
|
||
class BackendMixin(BaseBackendMixin):
    """Provide necessary functions to enable this storage backend."""

    # Mapping from docarray column types to RediSearch field constructors.
    TYPE_MAP = {
        'str': TypeMap(type='text', converter=TextField),
        'bytes': TypeMap(type='text', converter=TextField),
        'int': TypeMap(type='integer', converter=NumericField),
        'float': TypeMap(type='float', converter=NumericField),
        'double': TypeMap(type='double', converter=NumericField),
        'long': TypeMap(type='long', converter=NumericField),
        'bool': TypeMap(type='long', converter=NumericField),
    }

    def _init_storage(
        self,
        _docs: Optional['DocumentArraySourceType'] = None,
        config: Optional[Union[RedisConfig, Dict]] = None,
        **kwargs,
    ):
        """Validate the config, connect to Redis, and load initial documents.

        :param _docs: optional initial content (a Document or an iterable
            of Documents)
        :param config: a :class:`RedisConfig` or a plain dict with the same
            keys; required
        :raises ValueError: when config is missing, or when ``distance`` /
            ``method`` has an unsupported value
        """
        if not config:
            raise ValueError('Empty config is not allowed for Redis storage')
        elif isinstance(config, dict):
            config = dataclass_from_dict(RedisConfig, config)

        if config.distance not in ('L2', 'IP', 'COSINE'):
            raise ValueError(
                f'Expecting distance metric one of COSINE, L2 OR IP, got {config.distance} instead'
            )
        if config.method not in ('HNSW', 'FLAT'):
            raise ValueError(
                f'Expecting search method one of HNSW OR FLAT, got {config.method} instead'
            )

        # Documents are stored as binary blobs, so response decoding must
        # stay off even if the caller asked for it.
        if config.redis_config.get('decode_responses'):
            config.redis_config['decode_responses'] = False

        self._offset2id_key = config.index_name + '__offset2id'
        self._config = config
        self.n_dim = self._config.n_dim
        # All document keys share this prefix; the index is scoped to it.
        self._doc_prefix = config.index_name + ':'
        self._config.columns = self._normalize_columns(self._config.columns)

        self._client = self._build_client()
        super()._init_storage()

        if _docs is None:
            return
        if isinstance(_docs, Iterable):
            self.extend(_docs)
        elif isinstance(_docs, Document):
            self.append(_docs)

    def _build_client(self):
        """Create the Redis client and (re)build the search index as configured.

        :return: a connected :class:`redis.Redis` client
        """
        client = Redis(
            host=self._config.host,
            port=self._config.port,
            **self._config.redis_config,
        )

        if self._config.flush:
            client.flushdb()

        if self._config.update_schema:
            # FT._LIST returns index names as bytes; drop a stale index
            # before recreating it with the current schema.
            if self._config.index_name.encode() in client.execute_command('FT._LIST'):
                client.ft(index_name=self._config.index_name).dropindex()

        if self._config.flush or self._config.update_schema:
            client.ft(index_name=self._config.index_name).create_index(
                self._build_schema_from_redis_config(),
                definition=IndexDefinition(prefix=[self._doc_prefix]),
            )

        return client

    def _ensure_unique_config(
        self,
        config_root: dict,
        config_subindex: dict,
        config_joined: dict,
        subindex_name: str,
    ) -> dict:
        """Derive a subindex config that does not clash with the root index.

        When the subindex does not specify its own ``index_name``, one is
        derived from the root name. ``flush`` is forced off so creating a
        subindex never wipes the shared database.
        NOTE(review): whether subindexes should always override ``flush``
        was debated in review — confirm the intended policy.

        :return: the adjusted joined config dict
        """
        if 'index_name' not in config_subindex:
            config_joined['index_name'] = (
                config_joined['index_name'] + '_subindex_' + subindex_name
            )
        config_joined['flush'] = False

        return config_joined

    def _build_schema_from_redis_config(self):
        """Assemble the RediSearch schema (vector field + extra columns).

        :return: list of RediSearch field definitions
        """
        vector_params = {
            'TYPE': 'FLOAT32',
            'DIM': self.n_dim,
            'DISTANCE_METRIC': self._config.distance,
        }

        if self._config.method == 'HNSW':
            vector_params.update(
                M=self._config.m,
                EF_CONSTRUCTION=self._config.ef_construction,
                EF_RUNTIME=self._config.ef_runtime,
            )

        if self._config.method == 'FLAT':
            vector_params['BLOCK_SIZE'] = self._config.block_size

        if self._config.initial_cap:
            vector_params['INITIAL_CAP'] = self._config.initial_cap

        schema = [VectorField('embedding', self._config.method, vector_params)]
        schema.extend(
            self._map_column(name, kind) for name, kind in self._config.columns
        )

        return schema

    def _doc_id_exists(self, doc_id):
        """Return whether a document with this id is stored in Redis."""
        return self._client.exists(self._doc_prefix + doc_id)

    def _map_embedding(self, embedding: 'ArrayType') -> bytes:
        """Serialize an embedding to the float32 byte string Redis expects.

        ``None`` is mapped to an all-zero vector of ``n_dim`` entries.
        """
        if embedding is None:
            embedding = np.zeros(self.n_dim)
        else:
            from docarray.math.ndarray import to_numpy_array

            embedding = to_numpy_array(embedding)
            # Collapse singleton dimensions, e.g. shape (1, n) -> (n,).
            if embedding.ndim > 1:
                embedding = np.asarray(embedding).squeeze()
        return embedding.astype(np.float32).tobytes()

    def _get_offset2ids_meta(self) -> List[str]:
        """Read the offset-to-id mapping list from Redis (empty if absent)."""
        if not self._client.exists(self._offset2id_key):
            return []
        return [
            raw_id.decode()
            for raw_id in self._client.lrange(self._offset2id_key, 0, -1)
        ]

    def _update_offset2ids_meta(self):
        """Rewrite the offset2ids list in Redis to match in-memory state."""
        if self._client.exists(self._offset2id_key):
            self._client.delete(self._offset2id_key)
        if len(self._offset2ids.ids) > 0:
            self._client.rpush(self._offset2id_key, *self._offset2ids.ids)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,127 @@ | ||
| from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, TypeVar, Union | ||
|
|
||
| import numpy as np | ||
| from docarray import Document, DocumentArray | ||
| from docarray.array.mixins.find import FindMixin as BaseFindMixin | ||
| from docarray.math import ndarray | ||
| from docarray.math.ndarray import to_numpy_array | ||
| from docarray.score import NamedScore | ||
|
|
||
| from redis.commands.search.query import NumericFilter, Query | ||
|
|
||
if TYPE_CHECKING:
    import tensorflow
    import torch

    # The alias must live INSIDE the TYPE_CHECKING block: tensorflow and
    # torch are imported only for static analysis, so evaluating
    # ``tensorflow.Tensor`` / ``torch.Tensor`` at module import time would
    # raise NameError on installs without these optional dependencies.
    # The mixin methods reference the alias as the string
    # 'RedisArrayType', so it is never needed at runtime.
    RedisArrayType = TypeVar(
        'RedisArrayType',
        np.ndarray,
        tensorflow.Tensor,
        torch.Tensor,
        Sequence[float],
        Dict,
    )
|
|
||
|
|
||
class FindMixin(BaseFindMixin):
    """Vector search and metadata filtering for a Redis-backed DocumentArray."""

    def _find_similar_vectors(
        self,
        query: 'RedisArrayType',
        filter: Optional[Dict] = None,
        limit: Optional[Union[int, float]] = 20,
        **kwargs,
    ):
        """Run one KNN query against the index, optionally pre-filtered.

        :param query: a single query vector
        :param filter: optional MongoDB-style filter applied before KNN
        :param limit: maximum number of results
        :return: a :class:`DocumentArray` of matches with ``vector_score``
            attached under ``scores['score']``
        """
        prefilter = self._build_query_str(filter) if filter else "*"

        knn_query = (
            Query(f'{prefilter}=>[KNN {limit} @embedding $vec AS vector_score]')
            .sort_by('vector_score')
            .paging(0, limit)
            .dialect(2)
        )

        params = {'vec': to_numpy_array(query).astype(np.float32).tobytes()}
        hits = (
            self._client.ft(index_name=self._config.index_name)
            .search(knn_query, params)
            .docs
        )

        matches = DocumentArray()
        for hit in hits:
            # Documents are stored base64-encoded in the ``blob`` field.
            doc = Document.from_base64(hit.blob.encode())
            doc.scores['score'] = NamedScore(value=hit.vector_score)
            matches.append(doc)
        return matches

    def _find(
        self,
        query: 'RedisArrayType',
        limit: Optional[Union[int, float]] = 20,
        filter: Optional[Dict] = None,
        **kwargs,
    ) -> List['DocumentArray']:
        """Search for every row of ``query``; one result array per row.

        A 1-D query is reshaped into a single-row batch first.
        """
        query = np.array(query)
        num_rows, n_dim = ndarray.get_array_rows(query)
        if n_dim != 2:
            query = query.reshape((num_rows, -1))

        return [
            self._find_similar_vectors(row, filter=filter, limit=limit, **kwargs)
            for row in query
        ]

    def _find_with_filter(self, filter: Dict, limit: Optional[Union[int, float]] = 20):
        """Fetch documents matching ``filter`` only (no vector search).

        :return: a :class:`DocumentArray` of up to ``limit`` matches
        """
        filter_query = Query(self._build_query_str(filter))
        filter_query.paging(0, limit)

        hits = self._client.ft(index_name=self._config.index_name).search(filter_query).docs

        matches = DocumentArray()
        for hit in hits:
            matches.append(Document.from_base64(hit.blob.encode()))
        return matches

    def _filter(
        self, filter: Dict, limit: Optional[Union[int, float]] = 20
    ) -> 'DocumentArray':
        """Public filter entry point; delegates to :meth:`_find_with_filter`."""
        return self._find_with_filter(filter, limit=limit)

    def _build_query_str(self, filter: Dict) -> str:
        """Translate a MongoDB-style filter dict into a RediSearch query string.

        Supports $gt/$gte/$lt/$lte/$eq/$ne. Only the first operator of each
        field is considered, and unrecognized operators are silently skipped.
        NOTE(review): consider validating operators and rejecting unknown ones.
        """
        INF = "+inf"
        NEG_INF = "-inf"

        clauses = []
        for field_name, condition in filter.items():
            operator = list(condition.keys())[0]
            value = condition[operator]
            if operator == '$gt':
                # '(' marks an exclusive bound in RediSearch numeric ranges.
                clauses.append(f"@{field_name}:[({value} {INF}] ")
            elif operator == '$gte':
                clauses.append(f"@{field_name}:[{value} {INF}] ")
            elif operator == '$lt':
                clauses.append(f"@{field_name}:[{NEG_INF} ({value}] ")
            elif operator == '$lte':
                clauses.append(f"@{field_name}:[{NEG_INF} {value}] ")
            elif operator == '$eq':
                # ``type(...) is int`` deliberately excludes bool, which is
                # handled via its integer value in the next branch.
                if type(value) is int:
                    clauses.append(f"@{field_name}:[{value} {value}] ")
                elif type(value) is bool:
                    clauses.append(f"@{field_name}:[{int(value)} {int(value)}] ")
                else:
                    clauses.append(f"@{field_name}:{value} ")
            elif operator == '$ne':
                if type(value) is int:
                    clauses.append(f"-@{field_name}:[{value} {value}] ")
                elif type(value) is bool:
                    clauses.append(f"-@{field_name}:[{int(value)} {int(value)}] ")
                else:
                    clauses.append(f"-@{field_name}:{value} ")

        return "(" + "".join(clauses) + ")"
Uh oh!
There was an error while loading. Please reload this page.