From 165afc396478654c84af5976b3b673bd1972dc76 Mon Sep 17 00:00:00 2001 From: AnneY Date: Wed, 19 Apr 2023 20:53:15 +0800 Subject: [PATCH 1/2] fix: save index during creation Signed-off-by: AnneY --- docarray/index/backends/hnswlib.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docarray/index/backends/hnswlib.py b/docarray/index/backends/hnswlib.py index 467d3f754fb..66dc82c26bf 100644 --- a/docarray/index/backends/hnswlib.py +++ b/docarray/index/backends/hnswlib.py @@ -116,7 +116,7 @@ def __init__(self, db_config=None, **kwargs): self._hnsw_indices[col_name] = self._load_index(col_name, col) self._logger.info(f'Loading an existing index for column `{col_name}`') else: - self._hnsw_indices[col_name] = self._create_index(col) + self._hnsw_indices[col_name] = self._create_index(col_name, col) self._logger.info(f'Created a new index for column `{col_name}`') # SQLite setup @@ -396,13 +396,14 @@ def _create_index_class(self, col: '_ColumnInfo') -> hnswlib.Index: construct_params['dim'] = col.n_dim return hnswlib.Index(**construct_params) - def _create_index(self, col: '_ColumnInfo') -> hnswlib.Index: + def _create_index(self, col_name: str, col: '_ColumnInfo') -> hnswlib.Index: """Create a new HNSW index for a column, and initialize it.""" index = self._create_index_class(col) init_params = dict((k, col.config[k]) for k in self._index_init_params) index.init_index(**init_params) index.set_ef(col.config['ef']) index.set_num_threads(col.config['num_threads']) + index.save_index(self._hnsw_locations[col_name]) return index # SQLite helpers From a0d5e6bc4da26fa6137a5695793c3e7c4eb9dcd3 Mon Sep 17 00:00:00 2001 From: AnneY Date: Wed, 19 Apr 2023 21:10:43 +0800 Subject: [PATCH 2/2] test: add test for index bin files persist Signed-off-by: AnneY --- tests/index/hnswlib/test_persist_data.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/index/hnswlib/test_persist_data.py b/tests/index/hnswlib/test_persist_data.py index 1ac02d11d42..7724d5408c5 100644 --- a/tests/index/hnswlib/test_persist_data.py +++ b/tests/index/hnswlib/test_persist_data.py @@ -78,3 +78,8 @@ def test_persist_and_restore_nested(tmp_path): ] ) assert store.num_docs() == 15 + + +def test_persist_index_file(tmp_path): + _ = HnswDocumentIndex[SimpleDoc](work_dir=str(tmp_path)) + _ = HnswDocumentIndex[SimpleDoc](work_dir=str(tmp_path))