From 9ee61e75694eecfa3a0a67210efd280f5b3d8a28 Mon Sep 17 00:00:00 2001 From: AnneY Date: Thu, 13 Apr 2023 16:04:43 +0800 Subject: [PATCH] fix: default dims=-1 for elastic index Signed-off-by: AnneY --- docarray/index/backends/elastic.py | 11 ++++++++++- tests/index/elastic/fixture.py | 5 +++++ tests/index/elastic/v7/test_index_get_del.py | 7 ++++++- tests/index/elastic/v8/test_index_get_del.py | 9 +++++++-- 4 files changed, 28 insertions(+), 4 deletions(-) diff --git a/docarray/index/backends/elastic.py b/docarray/index/backends/elastic.py index c2c1c6646a2..52f60e1d098 100644 --- a/docarray/index/backends/elastic.py +++ b/docarray/index/backends/elastic.py @@ -88,8 +88,17 @@ def __init__(self, db_config=None, **kwargs): mappings.update(self._db_config.index_mappings) for col_name, col in self._column_infos.items(): + if col.db_type == 'dense_vector' and ( + not col.n_dim and col.config['dims'] < 0 + ): + self._logger.info( + f'Not indexing column {col_name}, the dimensionality is not specified' + ) + continue + mappings['properties'][col_name] = self._create_index_mapping(col) + # print(mappings['properties']) if self._client.indices.exists(index=self._index_name): self._client_put_mapping(mappings) else: @@ -231,8 +240,8 @@ def __post_init__(self): def dense_vector_config(self): config = { + 'dims': -1, 'index': True, - 'dims': 128, 'similarity': 'cosine', # 'l2_norm', 'dot_product', 'cosine' 'm': 16, 'ef_construction': 100, diff --git a/tests/index/elastic/fixture.py b/tests/index/elastic/fixture.py index 315078d6269..812f0f09d51 100644 --- a/tests/index/elastic/fixture.py +++ b/tests/index/elastic/fixture.py @@ -6,6 +6,7 @@ from pydantic import Field from docarray import BaseDoc +from docarray.documents import ImageDoc from docarray.typing import NdArray pytestmark = [pytest.mark.slow, pytest.mark.index] @@ -58,6 +59,10 @@ class DeepNestedDoc(BaseDoc): d: NestedDoc +class MyImageDoc(ImageDoc): + embedding: NdArray = Field(dims=128) + + @pytest.fixture(scope='function') def ten_simple_docs(): return [SimpleDoc(tens=np.random.randn(10)) for _ in range(10)] diff --git a/tests/index/elastic/v7/test_index_get_del.py b/tests/index/elastic/v7/test_index_get_del.py index 7124d5d61bd..d5ead493c03 100644 --- a/tests/index/elastic/v7/test_index_get_del.py +++ b/tests/index/elastic/v7/test_index_get_del.py @@ -10,6 +10,7 @@ from tests.index.elastic.fixture import ( # noqa: F401 DeepNestedDoc, FlatDoc, + MyImageDoc, NestedDoc, SimpleDoc, start_storage_v7, @@ -247,7 +248,7 @@ class MySchema(BaseDoc): def test_index_multi_modal_doc(): class MyMultiModalDoc(BaseDoc): - image: ImageDoc + image: MyImageDoc text: TextDoc store = ElasticV7DocIndex[MyMultiModalDoc]() @@ -263,3 +264,7 @@ class MyMultiModalDoc(BaseDoc): assert store[id_].id == id_ assert np.all(store[id_].image.embedding == doc[0].image.embedding) assert store[id_].text.text == doc[0].text.text + + query = doc[0] + docs, _ = store.find(query, limit=10, search_field='image__embedding') + assert len(docs) > 0 diff --git a/tests/index/elastic/v8/test_index_get_del.py b/tests/index/elastic/v8/test_index_get_del.py index db2df925ebb..03560caae7d 100644 --- a/tests/index/elastic/v8/test_index_get_del.py +++ b/tests/index/elastic/v8/test_index_get_del.py @@ -10,6 +10,7 @@ from tests.index.elastic.fixture import ( # noqa: F401 DeepNestedDoc, FlatDoc, + MyImageDoc, NestedDoc, SimpleDoc, start_storage_v8, @@ -234,7 +235,7 @@ class MyDoc(BaseDoc): tensor: Union[NdArray, str] class MySchema(BaseDoc): - tensor: NdArray + tensor: NdArray[128] store = ElasticDocIndex[MySchema]() doc = [MyDoc(tensor=np.random.randn(128))] @@ -247,7 +248,7 @@ class MySchema(BaseDoc): def test_index_multi_modal_doc(): class MyMultiModalDoc(BaseDoc): - image: ImageDoc + image: MyImageDoc text: TextDoc store = ElasticDocIndex[MyMultiModalDoc]() @@ -264,6 +265,10 @@ class MyMultiModalDoc(BaseDoc): assert np.all(store[id_].image.embedding == doc[0].image.embedding) assert store[id_].text.text == doc[0].text.text + query = doc[0] + docs, _ = store.find(query, limit=10, search_field='image__embedding') + assert len(docs) > 0 + def test_elasticv7_version_check(): with pytest.raises(ImportError):