diff --git a/docarray/array/storage/annlite/backend.py b/docarray/array/storage/annlite/backend.py index 3c4261e5b04..89bb9e94bde 100644 --- a/docarray/array/storage/annlite/backend.py +++ b/docarray/array/storage/annlite/backend.py @@ -27,7 +27,7 @@ class AnnliteConfig: ef_construction: Optional[int] = None ef_search: Optional[int] = None max_connection: Optional[int] = None - columns: Optional[List[Tuple[str, str]]] = None + columns: Optional[Union[List[Tuple[str, str]], Dict[str, str]]] = None class BackendMixin(BaseBackendMixin): @@ -53,11 +53,8 @@ def _map_embedding(self, embedding: 'ArrayType') -> 'ArrayType': def _normalize_columns(self, columns): columns = super()._normalize_columns(columns) - for i in range(len(columns)): - columns[i] = ( - columns[i][0], - self._map_type(columns[i][1]), - ) + for key in columns.keys(): + columns[key] = self._map_type(columns[key]) return columns def _ensure_unique_config( diff --git a/docarray/array/storage/base/backend.py b/docarray/array/storage/base/backend.py index 68d77611d80..0fcb2416df0 100644 --- a/docarray/array/storage/base/backend.py +++ b/docarray/array/storage/base/backend.py @@ -1,7 +1,8 @@ from abc import ABC, abstractmethod +import warnings from collections import namedtuple from dataclasses import is_dataclass, asdict -from typing import Dict, Optional, TYPE_CHECKING +from typing import Dict, Optional, TYPE_CHECKING, Union, List, Tuple if TYPE_CHECKING: from docarray.typing import DocumentArraySourceType, ArrayType @@ -77,7 +78,14 @@ def _map_embedding(self, embedding: 'ArrayType') -> 'ArrayType': def _map_type(self, col_type: str) -> str: return self.TYPE_MAP[col_type].type - def _normalize_columns(self, columns): + def _normalize_columns( + self, columns: Optional[Union[List[Tuple[str, str]], Dict[str, str]]] + ) -> Dict[str, str]: if columns is None: - return [] + return {} + if isinstance(columns, list): + warnings.warn( + 'Using "columns" as a List of Tuples will be deprecated soon. Please provide a Dictionary.' + ) + columns = {col_desc[0]: col_desc[1] for col_desc in columns} return columns diff --git a/docarray/array/storage/elastic/backend.py b/docarray/array/storage/elastic/backend.py index 882dca65e5f..27279eff94a 100644 --- a/docarray/array/storage/elastic/backend.py +++ b/docarray/array/storage/elastic/backend.py @@ -44,7 +44,7 @@ class ElasticConfig: batch_size: int = 64 ef_construction: Optional[int] = None m: Optional[int] = None - columns: Optional[List[Tuple[str, str]]] = None + columns: Optional[Union[List[Tuple[str, str]], Dict[str, str]]] = None _banned_indexname_chars = ['[', ' ', '"', '*', '\\', '<', '|', ',', '>', '/', '?', ']'] @@ -150,7 +150,7 @@ def _build_schema_from_elastic_config(self, elastic_config): 'index': True, } - for col, coltype in self._config.columns: + for col, coltype in self._config.columns.items(): da_schema['mappings']['properties'][col] = { 'type': self._map_type(coltype), 'index': True, diff --git a/docarray/array/storage/elastic/getsetdel.py b/docarray/array/storage/elastic/getsetdel.py index 7c20a0d2693..fcc93dc6924 100644 --- a/docarray/array/storage/elastic/getsetdel.py +++ b/docarray/array/storage/elastic/getsetdel.py @@ -12,7 +12,9 @@ class GetSetDelMixin(BaseGetSetDelMixin): MAX_ES_RETURNED_DOCS = 10000 def _document_to_elastic(self, doc: 'Document') -> Dict: - extra_columns = {col: doc.tags.get(col) for col, _ in self._config.columns} + extra_columns = { + col: doc.tags.get(col) for col, _ in self._config.columns.items() + } request = { '_op_type': 'index', '_id': doc.id, diff --git a/docarray/array/storage/qdrant/backend.py b/docarray/array/storage/qdrant/backend.py index b24df4171aa..5561def8801 100644 --- a/docarray/array/storage/qdrant/backend.py +++ b/docarray/array/storage/qdrant/backend.py @@ -42,7 +42,7 @@ class QdrantConfig: ef_construct: Optional[int] = None full_scan_threshold: Optional[int] = None m: Optional[int] = None - columns: Optional[List[Tuple[str, str]]] = None + columns: Optional[Union[List[Tuple[str, str]], Dict[str, str]]] = None class BackendMixin(BaseBackendMixin): diff --git a/docarray/array/storage/qdrant/getsetdel.py b/docarray/array/storage/qdrant/getsetdel.py index fdc2f1b1069..17e5194ca49 100644 --- a/docarray/array/storage/qdrant/getsetdel.py +++ b/docarray/array/storage/qdrant/getsetdel.py @@ -65,7 +65,9 @@ def _qdrant_to_document(self, qdrant_record: dict) -> 'Document': ) def _document_to_qdrant(self, doc: 'Document') -> 'PointStruct': - extra_columns = {col: doc.tags.get(col) for col, _ in self._config.columns} + extra_columns = { + col: doc.tags.get(col) for col, _ in self._config.columns.items() + } return PointStruct( id=self._map_id(doc.id), diff --git a/docarray/array/storage/redis/backend.py b/docarray/array/storage/redis/backend.py index 41728834714..b9f54a821a3 100644 --- a/docarray/array/storage/redis/backend.py +++ b/docarray/array/storage/redis/backend.py @@ -31,7 +31,7 @@ class RedisConfig: ef_runtime: int = field(default=10) block_size: int = field(default=1048576) initial_cap: Optional[int] = None - columns: Optional[List[Tuple[str, str]]] = None + columns: Optional[Union[List[Tuple[str, str]], Dict[str, str]]] = None class BackendMixin(BaseBackendMixin): @@ -146,7 +146,7 @@ def _build_schema_from_redis_config(self): index_param['INITIAL_CAP'] = self._config.initial_cap schema = [VectorField('embedding', self._config.method, index_param)] - for col, coltype in self._config.columns: + for col, coltype in self._config.columns.items(): schema.append(self._map_column(col, coltype)) return schema diff --git a/docarray/array/storage/redis/getsetdel.py b/docarray/array/storage/redis/getsetdel.py index 709d404d45e..d201a164c8e 100644 --- a/docarray/array/storage/redis/getsetdel.py +++ b/docarray/array/storage/redis/getsetdel.py @@ -90,7 +90,7 @@ def _del_doc_by_id(self, _id: str): def _document_to_redis(self, doc: 'Document') -> Dict: extra_columns = {} - for col, _ in self._config.columns: + for col, _ in self._config.columns.items(): tag = doc.tags.get(col) if tag is not None: extra_columns[col] = int(tag) if isinstance(tag, bool) else tag diff --git a/docarray/array/storage/sqlite/backend.py b/docarray/array/storage/sqlite/backend.py index 10e48ecbfe7..7422738a1e7 100644 --- a/docarray/array/storage/sqlite/backend.py +++ b/docarray/array/storage/sqlite/backend.py @@ -1,16 +1,8 @@ import sqlite3 import warnings -from dataclasses import dataclass, field, asdict +from dataclasses import dataclass, field from tempfile import NamedTemporaryFile -from typing import ( - Iterable, - Dict, - Optional, - TYPE_CHECKING, - Union, - List, - Tuple, -) +from typing import Iterable, Dict, Optional, TYPE_CHECKING, Union from docarray.array.storage.sqlite.helper import initialize_table from docarray.array.storage.base.backend import BaseBackendMixin diff --git a/docarray/array/storage/weaviate/backend.py b/docarray/array/storage/weaviate/backend.py index ba25b12d280..9d747908562 100644 --- a/docarray/array/storage/weaviate/backend.py +++ b/docarray/array/storage/weaviate/backend.py @@ -45,7 +45,7 @@ class WeaviateConfig: flat_search_cutoff: Optional[int] = None cleanup_interval_seconds: Optional[int] = None skip: Optional[bool] = None - columns: Optional[List[Tuple[str, str]]] = None + columns: Optional[Union[List[Tuple[str, str]], Dict[str, str]]] = None distance: Optional[str] = None @@ -215,7 +215,7 @@ def _get_schema_by_name(self, cls_name: str) -> Dict: }, ] } - for col, coltype in self._config.columns: + for col, coltype in self._config.columns.items(): new_property = { 'dataType': [self._map_type(coltype)], 'name': col, @@ -352,10 +352,9 @@ def _doc2weaviate_create_payload(self, value: 'Document'): :param value: document to create a payload for :return: the payload dictionary """ - columns_dict = {key: val for [key, val] in self._config.columns} extra_columns = { - col: self._map_column(value.tags.get(col), columns_dict[col]) - for col, _ in self._config.columns + col: self._map_column(value.tags.get(col), col_type) + for col, col_type in self._config.columns.items() } return dict( diff --git a/docs/advanced/document-store/annlite.md b/docs/advanced/document-store/annlite.md index a63f1fafd8b..09add22d81f 100644 --- a/docs/advanced/document-store/annlite.md +++ b/docs/advanced/document-store/annlite.md @@ -72,7 +72,7 @@ da = DocumentArray( storage='annlite', config={ 'n_dim': n_dim, - 'columns': [('price', 'float')], + 'columns': {'price': 'float'}, }, ) @@ -125,7 +125,7 @@ metric = 'Euclidean' da = DocumentArray( storage='annlite', - config={'n_dim': n_dim, 'columns': [('price', 'float')], 'metric': metric}, + config={'n_dim': n_dim, 'columns': {'price': 'float'}, 'metric': metric}, ) with da: diff --git a/docs/advanced/document-store/elasticsearch.md b/docs/advanced/document-store/elasticsearch.md index db76e23fd34..67f069cd122 100644 --- a/docs/advanced/document-store/elasticsearch.md +++ b/docs/advanced/document-store/elasticsearch.md @@ -132,7 +132,7 @@ n_dim = 3 da = DocumentArray( storage='elasticsearch', - config={'n_dim': 3, 'columns': [('price', 'int')], 'distance': 'l2_norm'}, + config={'n_dim': 3, 'columns': {'price': 'int'}, 'distance': 'l2_norm'}, ) with da: @@ -172,7 +172,7 @@ n_dim = 3 da = DocumentArray( storage='elasticsearch', - config={'n_dim': n_dim, 'columns': [('price', 'int')], 'distance': 'l2_norm'}, + config={'n_dim': n_dim, 'columns': {'price': 'int'}, 'distance': 'l2_norm'}, ) with da: @@ -248,7 +248,7 @@ da = DocumentArray( storage='elasticsearch', config={ 'n_dim': n_dim, - 'columns': [('price', 'float')], + 'columns': {'price': 'float'}, }, ) diff --git a/docs/advanced/document-store/index.md b/docs/advanced/document-store/index.md index 19ec7d84f6e..5d5ae4cc148 100644 --- a/docs/advanced/document-store/index.md +++ b/docs/advanced/document-store/index.md @@ -244,7 +244,7 @@ metric = 'Euclidean' da = DocumentArray( storage='annlite', - config={'n_dim': n_dim, 'columns': [('price', 'float')], 'metric': metric}, + config={'n_dim': n_dim, 'columns': {'price': 'float'}, 'metric': metric}, ) with da: @@ -276,7 +276,7 @@ metric = 'Euclidean' da = DocumentArray( storage='annlite', - config={'n_dim': n_dim, 'columns': [('price', 'float')], 'metric': metric}, + config={'n_dim': n_dim, 'columns': {'price': 'float'}, 'metric': metric}, ) with da: @@ -317,7 +317,7 @@ metric = 'Euclidean' da = DocumentArray( storage='annlite', - config={'n_dim': n_dim, 'columns': [('price', 'float')], 'metric': metric}, + config={'n_dim': n_dim, 'columns': {'price': 'float'}, 'metric': metric}, ) with da: diff --git a/docs/advanced/document-store/qdrant.md b/docs/advanced/document-store/qdrant.md index c6fb2c86040..d5a8b1b35c5 100644 --- a/docs/advanced/document-store/qdrant.md +++ b/docs/advanced/document-store/qdrant.md @@ -155,7 +155,7 @@ distance = 'euclidean' da = DocumentArray( storage='qdrant', - config={'n_dim': n_dim, 'columns': [('price', 'float')], 'distance': distance}, + config={'n_dim': n_dim, 'columns': {'price': 'float'}, 'distance': distance}, ) print(f'\nDocumentArray distance: {distance}') diff --git a/docs/advanced/document-store/redis.md b/docs/advanced/document-store/redis.md index 7aad6e82956..763d6378606 100644 --- a/docs/advanced/document-store/redis.md +++ b/docs/advanced/document-store/redis.md @@ -111,7 +111,7 @@ da2.summary() │ ef_runtime 10 │ │ block_size 1048576 │ │ initial_cap None │ -│ columns [] │ +│ columns {} │ │ │ ╰─────────────────────────────────╯ ``` @@ -146,7 +146,7 @@ da = DocumentArray( storage='redis', config={ 'n_dim': n_dim, - 'columns': [('price', 'int'), ('color', 'str')], + 'columns': {'price': 'int', 'color': 'str'}, 'flush': True, 'distance': 'L2', }, diff --git a/docs/advanced/document-store/weaviate.md b/docs/advanced/document-store/weaviate.md index 9fadcec01d9..f94b341d685 100644 --- a/docs/advanced/document-store/weaviate.md +++ b/docs/advanced/document-store/weaviate.md @@ -191,7 +191,7 @@ da = DocumentArray( storage='weaviate', config={ 'n_dim': n_dim, - 'columns': [('price', 'float')], + 'columns': {'price': 'float'}, }, ) @@ -243,7 +243,7 @@ n_dim = 3 da = DocumentArray( storage='weaviate', - config={'n_dim': n_dim, 'columns': [('price', 'int')], 'distance': 'l2-squared'}, + config={'n_dim': n_dim, 'columns': {'price': 'int'}, 'distance': 'l2-squared'}, ) with da: @@ -317,7 +317,7 @@ da = DocumentArray( storage='weaviate', config={ 'n_dim': n_dim, - 'columns': [('price', 'float')], + 'columns': {'price': 'float'}, 'distance': 'l2-squared', "name": "Persisted", "host": "localhost", diff --git a/docs/fundamentals/documentarray/subindex.md b/docs/fundamentals/documentarray/subindex.md index f6e24395856..405f7b0f19b 100644 --- a/docs/fundamentals/documentarray/subindex.md +++ b/docs/fundamentals/documentarray/subindex.md @@ -71,7 +71,7 @@ da = DocumentArray( │ ef_construction None │ │ ef_search None │ │ max_connection None │ -│ columns [] │ +│ columns {} │ │ │ ╰─────────────────────────────────────────╯ ``` @@ -129,7 +129,7 @@ da = DocumentArray( │ ef_construction None │ │ ef_search None │ │ max_connection None │ -│ columns [] │ +│ columns {} │ │ │ ╰─────────────────────────────────────────╯ ``` @@ -231,4 +231,4 @@ top_level_matches = da[top_image_matches[:, 'parent_id']] top_image_matches = da.find(query=np.random.rand(512), on='@c') top_level_matches = da[top_image_matches[:, 'parent_id']] ``` -```` \ No newline at end of file +```` diff --git a/setup.py b/setup.py index 6b1ef0cc4b3..a9f72500dfc 100644 --- a/setup.py +++ b/setup.py @@ -70,7 +70,7 @@ 'qdrant-client~=0.7.3', ], 'annlite': [ - 'annlite>=0.3.2', + 'annlite>=0.3.10', ], 'weaviate': [ 'weaviate-client~=3.3.0', @@ -105,7 +105,7 @@ 'jupyterlab', 'transformers>=4.16.2', 'weaviate-client~=3.3.0', - 'annlite>=0.3.2', + 'annlite>=0.3.10', 'elasticsearch>=8.2.0', 'redis>=4.3.0', 'jina', diff --git a/tests/unit/array/mixins/test_find.py b/tests/unit/array/mixins/test_find.py index 6ad66f22a5b..6ce3d4d1fbc 100644 --- a/tests/unit/array/mixins/test_find.py +++ b/tests/unit/array/mixins/test_find.py @@ -1,5 +1,3 @@ -from itertools import product - import numpy as np import pytest @@ -361,8 +359,9 @@ def test_find_by_tag(storage, config, start_storage): ], ], ) +@pytest.mark.parametrize('columns', [[('price', 'int')], {'price': 'int'}]) def test_search_pre_filtering( - storage, filter_gen, operator, numeric_operators, start_storage + storage, filter_gen, operator, numeric_operators, start_storage, columns ): np.random.seed(0) n_dim = 128 @@ -370,12 +369,10 @@ def test_search_pre_filtering( if storage == 'redis': da = DocumentArray( storage=storage, - config={'n_dim': n_dim, 'columns': [('price', 'int')], 'flush': True}, + config={'n_dim': n_dim, 'columns': columns, 'flush': True}, ) else: - da = DocumentArray( - storage=storage, config={'n_dim': n_dim, 'columns': [('price', 'int')]} - ) + da = DocumentArray(storage=storage, config={'n_dim': n_dim, 'columns': columns}) da.extend( [ @@ -468,18 +465,19 @@ def test_search_pre_filtering( ], ], ) -def test_filtering(storage, filter_gen, operator, numeric_operators, start_storage): +@pytest.mark.parametrize('columns', [[('price', 'float')], {'price': 'float'}]) +def test_filtering( + storage, filter_gen, operator, numeric_operators, start_storage, columns +): n_dim = 128 if storage == 'redis': da = DocumentArray( storage=storage, - config={'n_dim': n_dim, 'columns': [('price', 'float')], 'flush': True}, + config={'n_dim': n_dim, 'columns': columns, 'flush': True}, ) else: - da = DocumentArray( - storage=storage, config={'n_dim': n_dim, 'columns': [('price', 'float')]} - ) + da = DocumentArray(storage=storage, config={'n_dim': n_dim, 'columns': columns}) da.extend([Document(id=f'r{i}', tags={'price': i}) for i in range(50)]) thresholds = [10, 20, 30] @@ -496,11 +494,10 @@ def test_filtering(storage, filter_gen, operator, numeric_operators, start_stora ) -def test_weaviate_filter_query(start_storage): +@pytest.mark.parametrize('columns', [[('price', 'int')], {'price': 'int'}]) +def test_weaviate_filter_query(start_storage, columns): n_dim = 128 - da = DocumentArray( - storage='weaviate', config={'n_dim': n_dim, 'columns': [('price', 'int')]} - ) + da = DocumentArray(storage='weaviate', config={'n_dim': n_dim, 'columns': columns}) da.extend( [ @@ -518,13 +515,17 @@ def test_weaviate_filter_query(start_storage): assert isinstance(da._filter(filter={}), type(da)) -def test_redis_category_filter(start_storage): +@pytest.mark.parametrize( + 'columns', + [[('color', 'str'), ('isfake', 'bool')], {'color': 'str', 'isfake': 'bool'}], +) +def test_redis_category_filter(start_storage, columns): n_dim = 128 da = DocumentArray( storage='redis', config={ 'n_dim': n_dim, - 'columns': [('color', 'str'), ('isfake', 'bool')], + 'columns': columns, 'flush': True, }, ) @@ -580,12 +581,11 @@ def test_redis_category_filter(start_storage): @pytest.mark.parametrize('storage', ['memory']) -def test_unsupported_pre_filtering(storage, start_storage): +@pytest.mark.parametrize('columns', [[('price', 'int')], {'price': 'int'}]) +def test_unsupported_pre_filtering(storage, start_storage, columns): n_dim = 128 - da = DocumentArray( - storage=storage, config={'n_dim': n_dim, 'columns': [('price', 'int')]} - ) + da = DocumentArray(storage=storage, config={'n_dim': n_dim, 'columns': columns}) da.extend( [ diff --git a/tests/unit/array/mixins/test_match.py b/tests/unit/array/mixins/test_match.py index f246a928e04..70cd79902bb 100644 --- a/tests/unit/array/mixins/test_match.py +++ b/tests/unit/array/mixins/test_match.py @@ -698,20 +698,19 @@ def test_match_ensure_scores_unique(): ], ], ) +@pytest.mark.parametrize('columns', [[('price', 'int')], {'price': 'int'}]) def test_match_pre_filtering( - storage, filter_gen, operator, numeric_operators, start_storage + storage, filter_gen, operator, numeric_operators, start_storage, columns ): n_dim = 128 if storage == 'redis': da = DocumentArray( storage=storage, - config={'n_dim': n_dim, 'columns': [('price', 'int')], 'flush': True}, + config={'n_dim': n_dim, 'columns': columns, 'flush': True}, ) else: - da = DocumentArray( - storage=storage, config={'n_dim': n_dim, 'columns': [('price', 'int')]} - ) + da = DocumentArray(storage=storage, config={'n_dim': n_dim, 'columns': columns}) da.extend( [ diff --git a/tests/unit/array/storage/elastic/test_add.py b/tests/unit/array/storage/elastic/test_add.py index 86f53d6031f..775bfcfec6a 100644 --- a/tests/unit/array/storage/elastic/test_add.py +++ b/tests/unit/array/storage/elastic/test_add.py @@ -5,12 +5,13 @@ @pytest.mark.filterwarnings('ignore::UserWarning') -def test_add_ignore_existing_doc_id(start_storage): +@pytest.mark.parametrize('columns', [[('price', 'int')], {'price': 'int'}]) +def test_add_ignore_existing_doc_id(start_storage, columns): elastic_doc = DocumentArray( storage='elasticsearch', config={ 'n_dim': 3, - 'columns': [('price', 'int')], + 'columns': columns, 'distance': 'l2_norm', 'index_name': 'test_add_ignore_existing_doc_id', }, @@ -48,12 +49,13 @@ def test_add_ignore_existing_doc_id(start_storage): @pytest.mark.filterwarnings('ignore::UserWarning') -def test_add_skip_wrong_data_type_and_fix_offset(start_storage): +@pytest.mark.parametrize('columns', [[('price', 'int')], {'price': 'int'}]) +def test_add_skip_wrong_data_type_and_fix_offset(start_storage, columns): elastic_doc = DocumentArray( storage='elasticsearch', config={ 'n_dim': 3, - 'columns': [('price', 'int')], + 'columns': columns, 'index_name': 'test_add_skip_wrong_data_type_and_fix_offset', }, ) @@ -91,8 +93,19 @@ def test_add_skip_wrong_data_type_and_fix_offset(start_storage): @pytest.mark.filterwarnings('ignore::UserWarning') @pytest.mark.parametrize("assert_customization_propagation", [True, False]) +@pytest.mark.parametrize( + 'columns', + [ + [ + ('is_true', 'bool'), + ('test_long', 'long'), + ('test_double', 'double'), + ], + {'is_true': 'bool', 'test_long': 'long', 'test_double': 'double'}, + ], +) def test_succes_add_bulk_custom_params( - monkeypatch, start_storage, assert_customization_propagation + monkeypatch, start_storage, assert_customization_propagation, columns ): bulk_custom_params = { 'thread_count': 4, @@ -117,11 +130,7 @@ def _mock_send_requests(requests, **kwargs): storage='elasticsearch', config={ 'n_dim': 3, - 'columns': [ - ('is_true', 'bool'), - ('test_long', 'long'), - ('test_double', 'double'), - ], + 'columns': columns, 'distance': 'l2_norm', 'index_name': 'test_succes_add_bulk_custom_params', }, diff --git a/tests/unit/array/storage/elastic/test_data_type.py b/tests/unit/array/storage/elastic/test_data_type.py index 6813fa0177b..c849d4600d6 100644 --- a/tests/unit/array/storage/elastic/test_data_type.py +++ b/tests/unit/array/storage/elastic/test_data_type.py @@ -1,16 +1,24 @@ +import pytest from docarray import DocumentArray, Document -def test_data_type(start_storage): +@pytest.mark.parametrize( + 'columns', + [ + [ + ('is_true', 'bool'), + ('test_long', 'long'), + ('test_double', 'double'), + ], + {'is_true': 'bool', 'test_long': 'long', 'test_double': 'double'}, + ], +) +def test_data_type(start_storage, columns): elastic_doc = DocumentArray( storage='elasticsearch', config={ 'n_dim': 3, - 'columns': [ - ('is_true', 'bool'), - ('test_long', 'long'), - ('test_double', 'double'), - ], + 'columns': columns, 'distance': 'l2_norm', 'index_name': 'test_data_type', }, diff --git a/tests/unit/array/storage/elastic/test_del.py b/tests/unit/array/storage/elastic/test_del.py index fd8e54b1843..8e646239017 100644 --- a/tests/unit/array/storage/elastic/test_del.py +++ b/tests/unit/array/storage/elastic/test_del.py @@ -1,17 +1,20 @@ -from docarray import Document, DocumentArray import pytest +from docarray import Document, DocumentArray + @pytest.mark.filterwarnings('ignore::UserWarning') @pytest.mark.parametrize('deleted_elmnts', [[0, 1], ['r0', 'r1']]) -def test_delete_offset_success_sync_es_offset_index(deleted_elmnts, start_storage): +@pytest.mark.parametrize('columns', [[('price', 'int')], {'price': 'int'}]) +def test_delete_offset_success_sync_es_offset_index( + deleted_elmnts, start_storage, columns +): elastic_doc = DocumentArray( storage='elasticsearch', config={ 'n_dim': 3, - 'columns': [('price', 'int')], + 'columns': columns, 'distance': 'l2_norm', - 'index_name': 'test_delete_offset_success_sync_es_offset_index', }, ) @@ -51,14 +54,14 @@ def test_delete_offset_success_sync_es_offset_index(deleted_elmnts, start_storag @pytest.mark.filterwarnings('ignore::UserWarning') -def test_success_handle_bulk_delete_not_found(start_storage): +@pytest.mark.parametrize('columns', [[('price', 'int')], {'price': 'int'}]) +def test_success_handle_bulk_delete_not_found(start_storage, columns): elastic_doc = DocumentArray( storage='elasticsearch', config={ 'n_dim': 3, - 'columns': [('price', 'int')], + 'columns': columns, 'distance': 'l2_norm', - 'index_name': 'test_bulk_delete_not_found', }, ) with elastic_doc: diff --git a/tests/unit/array/storage/elastic/test_get.py b/tests/unit/array/storage/elastic/test_get.py index b40b2eaaf5c..b27e129a071 100644 --- a/tests/unit/array/storage/elastic/test_get.py +++ b/tests/unit/array/storage/elastic/test_get.py @@ -1,17 +1,18 @@ -from docarray import Document, DocumentArray import numpy as np import pytest +from docarray import Document, DocumentArray + @pytest.mark.parametrize('nrof_docs', [10, 100, 10_000, 10_100, 20_000, 20_100]) -def test_success_get_bulk_data(start_storage, nrof_docs): +@pytest.mark.parametrize('columns', [[('price', 'int')], {'price': 'int'}]) +def test_success_get_bulk_data(start_storage, nrof_docs, columns): elastic_doc = DocumentArray( storage='elasticsearch', config={ 'n_dim': 3, - 'columns': [('price', 'int')], + 'columns': columns, 'distance': 'l2_norm', - 'index_name': 'test_get_bulk_data', }, ) @@ -26,16 +27,16 @@ def test_success_get_bulk_data(start_storage, nrof_docs): assert len(elastic_doc[:, 'id']) == nrof_docs -def test_error_get_bulk_data_id_not_exist(start_storage): +@pytest.mark.parametrize('columns', [[('price', 'int')], {'price': 'int'}]) +def test_error_get_bulk_data_id_not_exist(start_storage, columns): nrof_docs = 10 elastic_doc = DocumentArray( storage='elasticsearch', config={ 'n_dim': 3, - 'columns': [('price', 'int')], + 'columns': columns, 'distance': 'l2_norm', - 'index_name': 'test_error_get_bulk_data_id_not_exist', }, ) diff --git a/tests/unit/array/storage/redis/__init__.py b/tests/unit/array/storage/redis/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/unit/array/storage/redis/test_backend.py b/tests/unit/array/storage/redis/test_backend.py index 04de55febba..ca57f99915b 100644 --- a/tests/unit/array/storage/redis/test_backend.py +++ b/tests/unit/array/storage/redis/test_backend.py @@ -53,6 +53,9 @@ def da_redis(): [('attr1', 'str'), ('attr2', 'bytes')], [('attr1', 'int'), ('attr2', 'float')], [('attr1', 'double'), ('attr2', 'long'), ('attr3', 'bool')], + {'attr1': 'str', 'attr2': 'bytes'}, + {'attr1': 'int', 'attr2': 'float'}, + {'attr1': 'double', 'attr2': 'long', 'attr3': 'bool'}, ], ) @pytest.mark.parametrize( @@ -92,26 +95,16 @@ def test_init_storage( assert redis_da._client.ft().info()['attributes'][0][1] == b'embedding' assert redis_da._client.ft().info()['attributes'][0][5] == b'VECTOR' - for i in range(len(columns)): - assert redis_da._client.ft().info()['attributes'][i + 1][1] == bytes( - redis_da._config.columns[i][0], 'utf-8' - ) - assert ( - redis_da._client.ft().info()['attributes'][i + 1][5] - == type_convert[redis_da._config.columns[i][1]] - ) - def test_init_storage_update_schema(start_storage): - - cfg = RedisConfig(n_dim=128, columns=[('attr1', 'str')], flush=True) + cfg = RedisConfig(n_dim=128, columns={'attr1': 'str'}, flush=True) redis_da = DocumentArrayDummy(storage='redis', config=cfg) assert redis_da._client.ft().info()['attributes'][1][1] == b'attr1' - cfg = RedisConfig(n_dim=128, columns=[('attr2', 'str')], update_schema=False) + cfg = RedisConfig(n_dim=128, columns={'attr2': 'str'}, update_schema=False) redis_da = DocumentArrayDummy(storage='redis', config=cfg) assert redis_da._client.ft().info()['attributes'][1][1] == b'attr1' - cfg = RedisConfig(n_dim=128, columns=[('attr2', 'str')], update_schema=True) + cfg = RedisConfig(n_dim=128, columns={'attr2': 'str'}, update_schema=True) redis_da = DocumentArrayDummy(storage='redis', config=cfg) assert redis_da._client.ft().info()['attributes'][1][1] == b'attr2' diff --git a/tests/unit/array/storage/redis/test_getsetdel.py b/tests/unit/array/storage/redis/test_getsetdel.py index cd2b3f3d43c..dfe7d9b0cdc 100644 --- a/tests/unit/array/storage/redis/test_getsetdel.py +++ b/tests/unit/array/storage/redis/test_getsetdel.py @@ -26,14 +26,15 @@ def _save_offset2ids(self): @pytest.fixture(scope='function') def columns(): - columns = [ - ('col_str', 'str'), - ('col_bytes', 'bytes'), - ('col_int', 'int'), - ('col_float', 'float'), - ('col_long', 'long'), - ('col_double', 'double'), - ] + columns = { + 'col_str': 'str', + 'col_bytes': 'bytes', + 'col_int': 'int', + 'col_float': 'float', + 'col_long': 'long', + 'col_double': 'double', + } + return columns @@ -92,7 +93,7 @@ def test_document_to_embedding( else: assert payload['text'] == text - for col, _ in columns: + for col, _ in columns.items(): if col in tags: assert payload[col] == tags[col] else: @@ -100,7 +101,7 @@ def test_document_to_embedding( payload[col] for key in tags: - if key not in (col[0] for col in columns): + if key not in (col for col in columns.keys()): assert key not in payload diff --git a/tests/unit/array/test_backend_configuration.py b/tests/unit/array/test_backend_configuration.py index 86ddc969a52..cec6326c4eb 100644 --- a/tests/unit/array/test_backend_configuration.py +++ b/tests/unit/array/test_backend_configuration.py @@ -1,8 +1,5 @@ -from typing import Tuple, Iterator - import pytest import requests -import itertools from docarray import DocumentArray, Document @@ -51,15 +48,15 @@ def test_weaviate_hnsw(start_storage): assert main_class.get('vectorIndexConfig', {}).get('distance') == 'l2-squared' -def test_weaviate_da_w_protobuff(start_storage): +@pytest.mark.parametrize('columns', [[('price', 'int')], {'price': 'int'}]) +def test_weaviate_da_w_protobuff(start_storage, columns): N = 10 index = DocumentArray( storage='weaviate', config={ - 'name': 'Test', - 'columns': [('price', 'int')], + 'columns': columns, }, ) @@ -86,7 +83,7 @@ def test_cast_columns_weaviate(start_storage, type_da, type_column, request): storage='weaviate', config={ 'name': f'Test{test_id}', - 'columns': [('price', type_column)], + 'columns': {'price': type_column}, }, ) @@ -107,7 +104,7 @@ def test_cast_columns_annlite(start_storage, type_da, type_column): storage='annlite', config={ 'n_dim': 3, - 'columns': [('price', type_column)], + 'columns': {'price': type_column}, }, ) @@ -132,7 +129,7 @@ def test_cast_columns_qdrant(start_storage, type_da, type_column, request): config={ 'collection_name': f'test{test_id}', 'n_dim': 3, - 'columns': [('price', type_column)], + 'columns': {'price': type_column}, }, )