diff --git a/docarray/array/storage/base/backend.py b/docarray/array/storage/base/backend.py
index 0fcb2416df0..8caace91e54 100644
--- a/docarray/array/storage/base/backend.py
+++ b/docarray/array/storage/base/backend.py
@@ -22,7 +22,17 @@ def _init_storage(
     ):
         self._load_offset2ids()
 
-    def _init_subindices(self, *args, **kwargs):
+    def _init_subindices(
+        self, _docs: Optional['DocumentArraySourceType'] = None, *args, **kwargs
+    ):
+        """Build ``self._subindices`` from the ``subindex_configs`` keyword argument.
+
+        :param _docs: documents given at construction time; when truthy, the
+            documents that each subindex path selects from them are added to
+            the matching subindex
+        :param kwargs: ``subindex_configs`` is looked up here; when it is
+            missing or empty, ``self._subindices`` stays an empty dict
+        """
         self._subindices = {}
         subindex_configs = kwargs.get('subindex_configs', None)
         if subindex_configs:
@@ -39,7 +49,12 @@ def _init_subindices(self, *args, **kwargs):
                     config, config_subindex, config_joined, name
                 )
                 self._subindices[name] = self.__class__(config=config_joined)
-                self._subindices[name].extend(self.traverse_flat(name[1:]))
+                if _docs:
+                    from docarray import DocumentArray
+
+                    self._subindices[name].extend(
+                        DocumentArray(_docs).traverse_flat(name[1:])
+                    )
 
     @abstractmethod
     def _ensure_unique_config(
diff --git a/tests/unit/array/test_backend_configuration.py b/tests/unit/array/test_backend_configuration.py
index cec6326c4eb..8255375e6d2 100644
--- a/tests/unit/array/test_backend_configuration.py
+++ b/tests/unit/array/test_backend_configuration.py
@@ -1,7 +1,20 @@
+import os
+import random
+
 import pytest
 import requests
 
-from docarray import DocumentArray, Document
+from docarray import Document, DocumentArray, dataclass
+from docarray.typing import Image, Text
+
+
+@dataclass
+class MyDocument:
+    image: Image
+    paragraph: Text
+
+
+cur_dir = os.path.dirname(os.path.abspath(__file__))
 
 
 def test_weaviate_hnsw(start_storage):
@@ -138,3 +151,54 @@ def test_cast_columns_qdrant(start_storage, type_da, type_column, request):
     index.extend(docs)
 
     assert len(index) == N
+
+
+def test_random_subindices_config(tmp_path):
+    """Smoke-test re-opening a SQLite DocumentArray configured with subindices.
+
+    The array is created, one document is appended, and the array is then
+    constructed a second time with the same configuration. The test has no
+    assertions: it passes when none of these steps raises.
+    """
+    # Keep the database file in pytest's per-test temporary directory so that
+    # runs neither leave .db files behind nor see data from earlier runs.
+    database_name = str(tmp_path / 'jina.db')
+    # Names stay randomised, but distinct prefixes ensure the root table and
+    # the two subindex tables can never end up with the same name.
+    table_name = "test" + str(random.randint(0, 100))
+    subindice_image_name = "test_image" + str(random.randint(0, 100))
+    subindice_paragraph_name = "test_paragraph" + str(random.randint(0, 100))
+    sqlite3_config = {'connection': database_name, 'table_name': table_name}
+
+    common_subindex_config = {
+        '@.[image]': {'connection': database_name, 'table_name': subindice_image_name},
+        '@.[paragraph]': {
+            'connection': database_name,
+            'table_name': subindice_paragraph_name,
+        },
+    }
+    # a single multimodal document with an image field and a paragraph field
+    _docs = [
+        MyDocument(
+            image=os.path.join(cur_dir, '../document/toydata/test.png'),
+            paragraph='hello world',
+        )
+    ]
+
+    # subindices are set up for the image and paragraph fields
+    da = DocumentArray(
+        storage='sqlite',
+        config=sqlite3_config,
+        subindex_configs=common_subindex_config,
+    )
+    da.summary()
+
+    for item in _docs:
+        da.append(Document(item))
+
+    # construct the array a second time with the identical configuration
+    da = DocumentArray(
+        storage='sqlite',
+        config=sqlite3_config,
+        subindex_configs=common_subindex_config,
+    )
+    da.summary()