Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions docarray/array/storage/elastic/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,10 @@ def _init_storage(
self._config.columns = self._normalize_columns(self._config.columns)

self.n_dim = self._config.n_dim
self._client = self._build_client()
self._list_like = self._config.list_like

self._client = self._build_client()
self._build_index()
self._build_offset2id_index()

# Note super()._init_storage() calls _load_offset2ids which calls _get_offset2ids_meta
Expand Down Expand Up @@ -167,21 +169,22 @@ def _build_schema_from_elastic_config(self, elastic_config):
return da_schema

def _build_client(self):
    """Create the Elasticsearch client for this storage backend.

    :return: an ``Elasticsearch`` client built from ``self._config.hosts`` and
        the extra keyword arguments held in ``self._config.es_config``.
    """
    return Elasticsearch(
        hosts=self._config.hosts,
        **self._config.es_config,
    )

def _build_index(self):
    """Create the Elasticsearch index for this backend when it is missing.

    The mappings are taken from ``_build_schema_from_elastic_config``. The
    index is refreshed afterwards whether or not it had to be created.
    """
    index_name = self._config.index_name
    schema = self._build_schema_from_elastic_config(self._config)

    # Operate on the already-built client stored on the instance; there is
    # no local ``client`` variable in this method.
    if not self._client.indices.exists(index=index_name):
        self._client.indices.create(index=index_name, mappings=schema['mappings'])

    self._client.indices.refresh(index=index_name)

def _send_requests(self, request, **kwargs) -> List[Dict]:
"""Send bulk request to Elastic and gather the successful info"""
Expand Down
1 change: 1 addition & 0 deletions docarray/array/storage/elastic/getsetdel.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def _del_doc_by_id(self, _id: str):
def _clear_storage(self):
    """Concrete implementation of base class' ``_clear_storage``

    Deletes the Elasticsearch index named by ``self._config.index_name`` and
    then calls ``_build_index`` on the same instance.
    """
    self._client.indices.delete(index=self._config.index_name)
    # The delete call removes the index itself, so it is built again right away.
    self._build_index()

def _load_offset2ids(self):
if self._list_like:
Expand Down
22 changes: 11 additions & 11 deletions docarray/array/storage/redis/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ def _init_storage(
self._config.columns = self._normalize_columns(self._config.columns)

self._client = self._build_client()
self._build_index()

super()._init_storage()

if _docs is None:
Expand All @@ -100,19 +102,21 @@ def _build_client(self):
port=self._config.port,
**self._config.redis_config,
)
return client

if self._config.update_schema:
if self._config.index_name.encode() in client.execute_command('FT._LIST'):
client.ft(index_name=self._config.index_name).dropindex()
# Build the search index for this backend using the client stored on the instance.
# When ``update_schema`` is configured or ``rebuild`` is True, an index that is
# already listed by ``FT._LIST`` under the configured name is dropped first.
def _build_index(self, rebuild: bool = False):
if self._config.update_schema or rebuild:
if self._config.index_name.encode() in self._client.execute_command(
'FT._LIST'
):
self._client.ft(index_name=self._config.index_name).dropindex()

# NOTE(review): indentation is not visible in this view, so it is unclear
# whether the index creation below is nested under the condition above -- confirm.
schema = self._build_schema_from_redis_config()
idef = IndexDefinition(prefix=[self._doc_prefix])
# NOTE(review): the next two lines both open ``create_index(``. The first uses a
# bare ``client`` name that this method never binds and looks like the line the
# change replaced; presumably only the ``self._client`` form is intended -- confirm.
client.ft(index_name=self._config.index_name).create_index(
self._client.ft(index_name=self._config.index_name).create_index(
schema, definition=idef
)

return client

def _ensure_unique_config(
self,
config_root: dict,
Expand Down Expand Up @@ -195,8 +199,4 @@ def __getstate__(self):

def __setstate__(self, state):
    """Restore the instance attributes and attach a fresh client.

    :param state: mapping that becomes the instance ``__dict__``.
    """
    self.__dict__ = state
    # Build the client once through ``_build_client`` instead of also
    # constructing ``Redis(...)`` inline and immediately overwriting it.
    self._client = self._build_client()
1 change: 1 addition & 0 deletions docarray/array/storage/redis/getsetdel.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,4 +125,5 @@ def _clear_storage(self):
self._client.ft(index_name=self._config.index_name).dropindex(
delete_documents=True
)
self._build_index(rebuild=True)
self._client.delete(self._offset2id_key)
55 changes: 38 additions & 17 deletions tests/unit/array/test_advance_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,25 +459,10 @@ def test_path_syntax_indexing_set(storage, config, use_subindex, start_storage):
assert da[2].id == 'new_id'


@pytest.mark.parametrize(
'storage,config',
[
('memory', None),
('sqlite', None),
('weaviate', WeaviateConfig(n_dim=123)),
('annlite', AnnliteConfig(n_dim=123)),
('qdrant', QdrantConfig(n_dim=123)),
('qdrant', QdrantConfig(n_dim=123, prefer_grpc=True)),
('elasticsearch', ElasticConfig(n_dim=123)),
('redis', RedisConfig(n_dim=123)),
('milvus', MilvusConfig(n_dim=123)),
],
)
def test_getset_subindex(storage, config, start_storage):
def test_getset_subindex():
da = DocumentArray(
[Document(chunks=[Document() for _ in range(5)]) for _ in range(3)],
config=config,
subindex_configs={'@c': {'n_dim': 123}} if config else {'@c': None},
subindex_configs={'@c': None},
)
with da:
assert len(da['@c']) == 15
Expand Down Expand Up @@ -509,6 +494,42 @@ def test_getset_subindex(storage, config, start_storage):
assert collected_chunks == new_chunks


# Subindex get/set check run once per storage backend listed below.
# Each case supplies the backend name, its config object, and the config used for
# the '@c' subindex (None for memory/sqlite, {'n_dim': 123} for the other backends).
@pytest.mark.parametrize(
'storage,config,subindex_config',
[
('memory', None, None),
('sqlite', None, None),
('weaviate', WeaviateConfig(n_dim=123), {'n_dim': 123}),
('annlite', AnnliteConfig(n_dim=123), {'n_dim': 123}),
('qdrant', QdrantConfig(n_dim=123), {'n_dim': 123}),
('qdrant', QdrantConfig(n_dim=123, prefer_grpc=True), {'n_dim': 123}),
('elasticsearch', ElasticConfig(n_dim=123), {'n_dim': 123}),
('redis', RedisConfig(n_dim=123), {'n_dim': 123}),
('milvus', MilvusConfig(n_dim=123), {'n_dim': 123}),
],
)
def test_getset_subindex_in_store(storage, config, subindex_config, start_storage):
# 3 root documents with 5 chunks each, so 15 chunks are expected under '@c'.
da = DocumentArray(
[Document(chunks=[Document() for _ in range(5)]) for _ in range(3)],
storage=storage,
config=config,
subindex_configs={'@c': subindex_config},
)
# NOTE(review): indentation is not visible in this view, so it is unclear how many
# of the statements below sit inside the ``with da:`` block -- confirm.
with da:
assert len(da['@c']) == 15
assert len(da._subindices['@c']) == 15

# Replace every chunk with a new Document that keeps the chunk id and carries a
# 123-dimensional embedding filled with the chunk's position.
chunks_ids = [c.id for c in da['@c']]
new_chunks = [
Document(id=cid, embedding=np.ones(123) * i)
for i, cid in enumerate(chunks_ids)
]
da['@c'] = new_chunks

# A search on the '@c' subindex with a random 123-dim query must return something.
res = da.find(np.random.random(123), on='@c')
assert len(res) > 0


@pytest.mark.parametrize('size', [1, 5])
@pytest.mark.parametrize(
'storage,config_gen',
Expand Down