diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5c29ec99a19..b9e79e24fbc 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -238,7 +238,7 @@ jobs: pytest --suppress-no-test-exit-code --cov=docarray --cov-report=xml \ -v -s -m "not gpu" ${{ matrix.test-path }} echo "::set-output name=codecov_flag::docarray" - timeout-minutes: 30 + timeout-minutes: 40 env: JINA_AUTH_TOKEN: "${{ secrets.JINA_AUTH_TOKEN }}" - name: Check codecov file diff --git a/GOVERNANCE.md b/GOVERNANCE.md index bbad6eb5f08..1801a0cd16b 100644 --- a/GOVERNANCE.md +++ b/GOVERNANCE.md @@ -28,9 +28,9 @@ Project releases will occur on a scheduled basis as agreed to by the committers. # Communication -This project, just like all of open source, is a global community. In addition to the [Code of Conduct](./.github/CODE_OF_CONDUCT.md), this project will: +This project, just like all open source, is a global community. In addition to the [Code of Conduct](./.github/CODE_OF_CONDUCT.md), this project will: -* Keep all communucation on open channels ( mailing list, forums, chat ). +* Keep all communication on open channels ( mailing list, forums, chat ). * Be respectful of time and language differences between community members ( such as scheduling meetings, email/issue responsiveness, etc ). * Ensure tools are able to be used by community members regardless of their region. diff --git a/docarray/array/storage/qdrant/backend.py b/docarray/array/storage/qdrant/backend.py index 9f5e195283e..c5509e623f0 100644 --- a/docarray/array/storage/qdrant/backend.py +++ b/docarray/array/storage/qdrant/backend.py @@ -1,5 +1,6 @@ import copy import uuid +from abc import abstractmethod from dataclasses import dataclass, field, asdict from typing import ( Optional, @@ -19,6 +20,7 @@ PointsList, PointStruct, HnswConfigDiff, + VectorParams, ) from docarray import Document @@ -38,6 +40,10 @@ class QdrantConfig: collection_name: Optional[str] = None host: Optional[str] = field(default="localhost") port: Optional[int] = field(default=6333) + grpc_port: Optional[int] = field(default=6334) + prefer_grpc: Optional[bool] = field(default=False) + api_key: Optional[str] = field(default=None) + https: Optional[bool] = field(default=None) serialize_config: Dict = field(default_factory=dict) scroll_batch_size: int = 64 ef_construct: Optional[int] = None @@ -47,6 +53,11 @@ class QdrantConfig: class BackendMixin(BaseBackendMixin): + @property + @abstractmethod + def client(self) -> 'QdrantClient': + raise NotImplementedError() + @classmethod def _tmp_collection_name(cls) -> str: return uuid.uuid4().hex @@ -85,7 +96,14 @@ def _init_storage( self._distance = config.distance self._serialize_config = config.serialize_config - self._client = QdrantClient(host=config.host, port=config.port) + self._client = QdrantClient( + host=config.host, + port=config.port, + prefer_grpc=config.prefer_grpc, + grpc_port=config.grpc_port, + api_key=config.api_key, + https=config.https, + ) self._config = config @@ -133,13 +151,13 @@ def _initialize_qdrant_schema(self): full_scan_threshold=self._config.full_scan_threshold, m=self._config.m, ) - self.client.http.collections_api.create_collection( + self.client.recreate_collection( collection_name=self.collection_name, - create_collection=CreateCollection( - vector_size=self._n_dim, - distance=DISTANCES[self._distance], - hnsw_config=hnsw_config, + vectors_config=VectorParams( + size=self.n_dim, + distance=self.distance, ), + hnsw_config=hnsw_config, ) def _collection_exists(self, collection_name): @@ -164,38 +182,36 @@ def __getstate__(self): def __setstate__(self, state): self.__dict__ = state self._client = QdrantClient( - host=state['_config'].host, port=state['_config'].port + host=state['_config'].host, + port=state['_config'].port, + prefer_grpc=state['_config'].prefer_grpc, + grpc_port=state['_config'].grpc_port, + api_key=state['_config'].api_key, + https=state['_config'].https, ) def _get_offset2ids_meta(self) -> List[str]: if not self._collection_exists(self.collection_name_meta): return [] - results = self.client.retrieve( - collection_name=self.collection_name_meta, ids=[1] - ) - if len(results) == 0: - return [] - else: - return results[0].payload.get('offset2id', []) + return self.client.retrieve(self.collection_name_meta, ids=[1])[0].payload[ + 'offset2id' + ] def _update_offset2ids_meta(self): if not self._collection_exists(self.collection_name_meta): self.client.recreate_collection( - self.collection_name_meta, - vector_size=1, - distance=Distance.COSINE, + collection_name=self.collection_name_meta, + vectors_config={}, # no vectors ) - self.client.http.points_api.upsert_points( + self.client.upsert( collection_name=self.collection_name_meta, + points=[ + PointStruct( + id=1, payload={"offset2id": self._offset2ids.ids}, vector={} + ) + ], wait=True, - point_insert_operations=PointsList( - points=[ - PointStruct( - id=1, payload={"offset2id": self._offset2ids.ids}, vector=[1] - ) - ] - ), ) def _map_embedding(self, embedding: 'ArrayType') -> List[float]: @@ -214,4 +230,5 @@ def _map_embedding(self, embedding: 'ArrayType') -> List[float]: if np.all(embedding == 0): embedding = embedding + EPSILON - return embedding.tolist() + + return embedding.astype(float).tolist() diff --git a/docarray/array/storage/qdrant/find.py b/docarray/array/storage/qdrant/find.py index 3086bb286b3..4692b37d394 100644 --- a/docarray/array/storage/qdrant/find.py +++ b/docarray/array/storage/qdrant/find.py @@ -60,7 +60,7 @@ def _find_similar_vectors( search_params=None if not search_params else rest.SearchParams(**search_params), - top=limit, + limit=limit, append_payload=['_serialized'], ) diff --git a/docarray/array/storage/qdrant/getsetdel.py b/docarray/array/storage/qdrant/getsetdel.py index 17e5194ca49..b0974816851 100644 --- a/docarray/array/storage/qdrant/getsetdel.py +++ b/docarray/array/storage/qdrant/getsetdel.py @@ -5,9 +5,8 @@ from qdrant_client.http.exceptions import UnexpectedResponse from qdrant_client.http.models.models import ( PointIdsList, - PointsList, - ScrollRequest, PointStruct, + VectorParams, ) from docarray import Document @@ -46,17 +45,17 @@ def _upload_batch(self, docs: Iterable['Document']): for doc in docs: batch.append(self._document_to_qdrant(doc)) if len(batch) > self.scroll_batch_size: - self.client.http.points_api.upsert_points( + self.client.upsert( collection_name=self.collection_name, + points=batch, wait=True, - point_insert_operations=PointsList(points=batch), ) batch = [] if len(batch) > 0: - self.client.http.points_api.upsert_points( + self.client.upsert( collection_name=self.collection_name, wait=True, - point_insert_operations=PointsList(points=batch), + points=batch, ) def _qdrant_to_document(self, qdrant_record: dict) -> 'Document': @@ -79,49 +78,47 @@ def _document_to_qdrant(self, doc: 'Document') -> 'PointStruct': def _get_doc_by_id(self, _id: str) -> 'Document': try: - resp = self.client.http.points_api.get_point( - collection_name=self.collection_name, id=self._map_id(_id) + resp = self.client.retrieve( + collection_name=self.collection_name, ids=[self._map_id(_id)] ) - return self._qdrant_to_document(resp.result.payload) + if len(resp) == 0: + raise KeyError(_id) + return self._qdrant_to_document(resp[0].payload) except UnexpectedResponse as response_error: if response_error.status_code in [404, 400]: raise KeyError(_id) def _del_doc_by_id(self, _id: str): - self.client.http.points_api.delete_points( + self.client.delete( collection_name=self.collection_name, - wait=True, points_selector=PointIdsList(points=[self._map_id(_id)]), + wait=True, ) def _set_doc_by_id(self, _id: str, value: 'Document'): if _id != value.id: self._del_doc_by_id(_id) - self.client.http.points_api.upsert_points( + self.client.upsert( collection_name=self.collection_name, wait=True, - point_insert_operations=PointsList( - points=[self._document_to_qdrant(value)] - ), + points=[self._document_to_qdrant(value)], ) def scan(self) -> Iterator['Document']: offset = None while True: - response = self.client.http.points_api.scroll_points( + response, next_page = self.client.scroll( collection_name=self.collection_name, - scroll_request=ScrollRequest( - offset=offset, - limit=self.scroll_batch_size, - with_payload=['_serialized'], - with_vector=False, - ), + offset=offset, + limit=self.scroll_batch_size, + with_payload=['_serialized'], + with_vectors=False, ) - for point in response.result.points: + for point in response: yield self._qdrant_to_document(point.payload) - if response.result.next_page_offset: - offset = response.result.next_page_offset + if next_page: + offset = next_page else: break @@ -133,8 +130,10 @@ def _save_offset2ids(self): self._update_offset2ids_meta() def _clear_storage(self): - self._client.recreate_collection( + self.client.recreate_collection( self.collection_name, - vector_size=self.n_dim, - distance=self.distance, + vectors_config=VectorParams( + size=self.n_dim, + distance=self.distance, + ), ) diff --git a/docarray/array/storage/qdrant/seqlike.py b/docarray/array/storage/qdrant/seqlike.py index 7ded158bc4d..92d068997e8 100644 --- a/docarray/array/storage/qdrant/seqlike.py +++ b/docarray/array/storage/qdrant/seqlike.py @@ -43,9 +43,7 @@ def __eq__(self, other): ) def __len__(self): - return self.client.http.collections_api.get_collection( - self.collection_name - ).result.vectors_count + return self.client.get_collection(self.collection_name).points_count def __contains__(self, x: Union[str, 'Document']): if isinstance(x, str): diff --git a/docs/advanced/document-store/qdrant.md b/docs/advanced/document-store/qdrant.md index 2e51bbaeaf5..e6ff683de6f 100644 --- a/docs/advanced/document-store/qdrant.md +++ b/docs/advanced/document-store/qdrant.md @@ -19,9 +19,10 @@ server. Create `docker-compose.yml` as follows: version: '3.4' services: qdrant: - image: qdrant/qdrant:v0.8.0 + image: qdrant/qdrant:v0.10.1 ports: - "6333:6333" + - "6334:6334" ulimits: # Only required for tests, as there are a lot of collections created nofile: soft: 65535 @@ -79,13 +80,19 @@ The following configs can be set: |-----------------------|----------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------| | `n_dim` | Number of dimensions of embeddings to be stored and retrieved | **This is always required** | | `collection_name` | Qdrant collection name client | **Random collection name generated** | -| `host` | Hostname of the Qdrant server | 'localhost' | -| `port` | port of the Qdrant server | 6333 | -| `distance` | Distance metric to be used during search. Can be 'cosine', 'dot' or 'euclidean' | 'cosine' | -| `scroll_batch_size` | batch size used when scrolling over the storage | 64 | -| `ef_construct` | Number of neighbours to consider during the index building. Larger the value - more accurate the search, more time required to build index. | `None`, defaults to the default value in Qdrant* | -| `full_scan_threshold` | Minimal amount of points for additional payload-based indexing. | `None`, defaults to the default value in Qdrant* | -| `m` | Number of edges per node in the index graph. Larger the value - more accurate the search, more space required. | `None`, defaults to the default value in Qdrant* | +| `distance` | Distance metric to be used during search. Can be 'cosine', 'dot' or 'euclidean' | `'cosine'` | +| `host` | Hostname of the Qdrant server | `'localhost'` | +| `port` | Port of the Qdrant server | `6333` | +| `grpc_port` | Port of the Qdrant gRPC interface | `6334` | +| `prefer_grpc` | Set `true` to use gPRC interface whenever possible in custom methods | `False` | +| `api_key` | API key for authentication in Qdrant Cloud | `None` | +| `https` | Set `true` to use HTTPS(SSL) protocol | `None` | +| `serialize_config` | [Serialization config of each Document](../../../fundamentals/document/serialization.md) | `None` | +| `scroll_batch_size` | Batch size used when scrolling over the storage | 64 | +| `ef_construct` | Number of neighbours to consider during the index building. Larger the value - more accurate the search, more time required to build index | `None`, defaults to the default value in Qdrant* | +| `full_scan_threshold` | Minimal amount of points for additional payload-based indexing | `None`, defaults to the default value in Qdrant* | +| `m` | Number of edges per node in the index graph. Larger the value - more accurate the search, more space required | `None`, defaults to the default value in Qdrant* | +| `columns` | Other fields to store in Document | `None` | *You can read more about the HNSW parameters and their default values [here](https://qdrant.tech/documentation/indexing/#vector-index) @@ -98,9 +105,10 @@ Create `docker-compose.yml`: version: '3.4' services: qdrant: - image: qdrant/qdrant:v0.8.0 + image: qdrant/qdrant:v0.10.1 ports: - "6333:6333" + - "6334:6334" ulimits: # Only required for tests, as there are a lot of collections created nofile: soft: 65535 @@ -205,57 +213,4 @@ Embeddings Nearest Neighbours with "price" at most 7: embedding=[6. 6. 6.], price=6 embedding=[5. 5. 5.], price=5 embedding=[4. 4. 4.], price=4 -``` -### Example of `.filter` with a filter -The following example shows how to use DocArray with Qdrant Document Store in order to filter text documents. -Consider Documents have the tag `price` with a value of `i`. We can create these with the following code: -```python -from docarray import Document, DocumentArray -import numpy as np - -n_dim = 3 - -da = DocumentArray( - storage='qdrant', - config={'n_dim': n_dim, 'columns': {'price': 'float'}}, -) - -with da: - da.extend( - [ - Document(id=f'r{i}', embedding=i * np.ones(n_dim), tags={'price': i}) - for i in range(10) - ] - ) - -print('\nIndexed Prices:\n') -for embedding, price in zip(da.embeddings, da[:, 'tags__price']): - print(f'\tembedding={embedding},\t price={price}') -``` -For example, suppose we want to filter results such that -retrieved documents must have a `price` value less than or equal to `max_price`. We can encode -this information in Qdrant using `filter = {'price': {'$lte': max_price}}`. - -Then you can implement and use the search with the proposed filter: -```python -max_price = 7 -n_limit = 4 - -filter = {'must': [{'key': 'price', 'range': {'lte': max_price}}]} -results = da.filter(filter=filter, limit=n_limit) - -print('\nPoints with "price" at most 7:\n') -for embedding, price in zip(results.embeddings, results[:, 'tags__price']): - print(f'\tembedding={embedding},\t price={price}') -``` -This prints: - -``` - -Points with "price" at most 7: - - embedding=[6. 6. 6.], price=6 - embedding=[7. 7. 7.], price=7 - embedding=[1. 1. 1.], price=1 - embedding=[2. 2. 2.], price=2 ``` \ No newline at end of file diff --git a/scripts/docker-compose.yml b/scripts/docker-compose.yml index 60944769bef..1d1e73d24bd 100644 --- a/scripts/docker-compose.yml +++ b/scripts/docker-compose.yml @@ -10,9 +10,10 @@ services: AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED: 'true' PERSISTENCE_DATA_PATH: '/var/lib/weaviate' qdrant: - image: qdrant/qdrant:v0.8.0 + image: qdrant/qdrant:v0.10.1 ports: - - "41233:41233" + - "41237:41237" + - "41238:41238" ulimits: # Only required for tests, as there are a lot of collections created nofile: soft: 65535 diff --git a/setup.py b/setup.py index 6a18f1f848c..9021c56be06 100644 --- a/setup.py +++ b/setup.py @@ -64,7 +64,7 @@ 'strawberry-graphql', ], 'qdrant': [ - 'qdrant-client==0.8.0', + 'qdrant-client~=0.10.3', ], 'annlite': [ 'annlite', @@ -83,6 +83,7 @@ 'seaborn', ], 'test': [ + 'protobuf>=3.13.0,<=3.20.0', # pip dependency resolution does not respect this restriction from paddle 'pytest', 'pytest-timeout', 'pytest-mock', @@ -93,7 +94,7 @@ 'pytest-custom_exit_code', 'black==22.3.0', 'tensorflow==2.7.0', - 'paddlepaddle==2.2.0', + 'paddlepaddle', 'torch==1.9.0', 'torchvision==0.10.0', 'datasets', @@ -106,7 +107,6 @@ 'elasticsearch>=8.2.0', 'redis>=4.3.0', 'jina', - 'rocksdict<=0.2.16', ], }, classifiers=[ diff --git a/tests/unit/array/docker-compose.yml b/tests/unit/array/docker-compose.yml index 7fe783c699e..ddc76f3f1cf 100644 --- a/tests/unit/array/docker-compose.yml +++ b/tests/unit/array/docker-compose.yml @@ -10,9 +10,10 @@ services: AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED: 'true' PERSISTENCE_DATA_PATH: '/var/lib/weaviate' qdrant: - image: qdrant/qdrant:v0.8.0 + image: qdrant/qdrant:v0.10.1 ports: - "6333:6333" + - "6334:6334" ulimits: # Only required for tests, as there are a lot of collections created nofile: soft: 65535 diff --git a/tests/unit/array/mixins/test_eval_class.py b/tests/unit/array/mixins/oldproto/test_eval_class.py similarity index 100% rename from tests/unit/array/mixins/test_eval_class.py rename to tests/unit/array/mixins/oldproto/test_eval_class.py diff --git a/tests/unit/array/test_advance_indexing.py b/tests/unit/array/test_advance_indexing.py index 4586b3edede..df8e005ddc0 100644 --- a/tests/unit/array/test_advance_indexing.py +++ b/tests/unit/array/test_advance_indexing.py @@ -27,6 +27,7 @@ def indices(): ('weaviate', WeaviateConfig(n_dim=123)), ('annlite', AnnliteConfig(n_dim=123)), ('qdrant', QdrantConfig(n_dim=123)), + ('qdrant', QdrantConfig(n_dim=123, prefer_grpc=True)), ('elasticsearch', ElasticConfig(n_dim=123)), ('redis', RedisConfig(n_dim=123)), ], @@ -61,6 +62,7 @@ def test_getter_int_str(docs, storage, config, start_storage): ('weaviate', WeaviateConfig(n_dim=123)), ('annlite', AnnliteConfig(n_dim=123)), ('qdrant', QdrantConfig(n_dim=123)), + ('qdrant', QdrantConfig(n_dim=123, prefer_grpc=True)), ('redis', RedisConfig(n_dim=123)), ], ) @@ -90,6 +92,7 @@ def test_setter_int_str(docs, storage, config, start_storage): ('weaviate', WeaviateConfig(n_dim=123)), ('annlite', AnnliteConfig(n_dim=123)), ('qdrant', QdrantConfig(n_dim=123)), + ('qdrant', QdrantConfig(n_dim=123, prefer_grpc=True)), ('elasticsearch', ElasticConfig(n_dim=123)), ('redis', RedisConfig(n_dim=123)), ], @@ -125,6 +128,7 @@ def test_del_int_str(docs, storage, config, start_storage, indices): ('weaviate', WeaviateConfig(n_dim=123)), ('annlite', AnnliteConfig(n_dim=123)), ('qdrant', QdrantConfig(n_dim=123)), + ('qdrant', QdrantConfig(n_dim=123, prefer_grpc=True)), ('elasticsearch', ElasticConfig(n_dim=123)), ('redis', RedisConfig(n_dim=123)), ], @@ -164,6 +168,7 @@ def test_slice(docs, storage, config, start_storage): ('weaviate', WeaviateConfig(n_dim=123)), ('annlite', AnnliteConfig(n_dim=123)), ('qdrant', QdrantConfig(n_dim=123)), + ('qdrant', QdrantConfig(n_dim=123, prefer_grpc=True)), ('elasticsearch', ElasticConfig(n_dim=123)), ('redis', RedisConfig(n_dim=123)), ], @@ -211,6 +216,7 @@ def test_sequence_bool_index(docs, storage, config, start_storage): ('weaviate', WeaviateConfig(n_dim=123)), ('annlite', AnnliteConfig(n_dim=123)), ('qdrant', QdrantConfig(n_dim=123)), + ('qdrant', QdrantConfig(n_dim=123, prefer_grpc=True)), ('elasticsearch', ElasticConfig(n_dim=123)), ('redis', RedisConfig(n_dim=123)), ], @@ -248,6 +254,7 @@ def test_sequence_int(docs, nparray, storage, config, start_storage): ('weaviate', WeaviateConfig(n_dim=123)), ('annlite', AnnliteConfig(n_dim=123)), ('qdrant', QdrantConfig(n_dim=123)), + ('qdrant', QdrantConfig(n_dim=123, prefer_grpc=True)), ('elasticsearch', ElasticConfig(n_dim=123)), ('redis', RedisConfig(n_dim=123)), ], @@ -283,6 +290,7 @@ def test_sequence_str(docs, storage, config, start_storage): ('weaviate', WeaviateConfig(n_dim=123)), ('annlite', AnnliteConfig(n_dim=123)), ('qdrant', QdrantConfig(n_dim=123)), + ('qdrant', QdrantConfig(n_dim=123, prefer_grpc=True)), ('elasticsearch', ElasticConfig(n_dim=123)), ('redis', RedisConfig(n_dim=123)), ], @@ -304,6 +312,7 @@ def test_docarray_list_tuple(docs, storage, config, start_storage): ('weaviate', WeaviateConfig(n_dim=123)), ('annlite', AnnliteConfig(n_dim=123)), ('qdrant', QdrantConfig(n_dim=123)), + ('qdrant', QdrantConfig(n_dim=123, prefer_grpc=True)), ('elasticsearch', ElasticConfig(n_dim=123)), ('redis', RedisConfig(n_dim=123)), ], @@ -344,6 +353,7 @@ def test_path_syntax_indexing(storage, config, start_storage): ('weaviate', WeaviateConfig(n_dim=123)), ('annlite', AnnliteConfig(n_dim=123)), ('qdrant', QdrantConfig(n_dim=123)), + ('qdrant', QdrantConfig(n_dim=123, prefer_grpc=True)), ('elasticsearch', ElasticConfig(n_dim=123)), ('redis', RedisConfig(n_dim=123)), ], @@ -441,6 +451,7 @@ def test_path_syntax_indexing_set(storage, config, use_subindex, start_storage): ('weaviate', WeaviateConfig(n_dim=123)), ('annlite', AnnliteConfig(n_dim=123)), ('qdrant', QdrantConfig(n_dim=123)), + ('qdrant', QdrantConfig(n_dim=123, prefer_grpc=True)), ('elasticsearch', ElasticConfig(n_dim=123)), ('redis', RedisConfig(n_dim=123)), ], @@ -487,6 +498,7 @@ def test_getset_subindex(storage, config, start_storage): ('weaviate', lambda: WeaviateConfig(n_dim=123)), ('annlite', lambda: AnnliteConfig(n_dim=123)), ('qdrant', lambda: QdrantConfig(n_dim=123)), + ('qdrant', lambda: QdrantConfig(n_dim=123, prefer_grpc=True)), ('elasticsearch', lambda: ElasticConfig(n_dim=123)), ('redis', lambda: RedisConfig(n_dim=123)), ], @@ -519,18 +531,27 @@ def test_attribute_indexing(storage, config_gen, start_storage, size): @pytest.mark.parametrize( - 'storage', - ['memory', 'sqlite', 'weaviate', 'annlite', 'qdrant', 'elasticsearch', 'redis'], + 'storage,config_gen', + [ + ('memory', None), + ('sqlite', None), + ('weaviate', lambda: WeaviateConfig(n_dim=10)), + ('annlite', lambda: AnnliteConfig(n_dim=10)), + ('qdrant', lambda: QdrantConfig(n_dim=10)), + ('qdrant', lambda: QdrantConfig(n_dim=10, prefer_grpc=True)), + ('elasticsearch', lambda: ElasticConfig(n_dim=10)), + ('redis', lambda: RedisConfig(n_dim=10)), + ], ) -def test_tensor_attribute_selector(storage, start_storage): +def test_tensor_attribute_selector(storage, config_gen, start_storage): import scipy.sparse sp_embed = np.random.random([3, 10]) sp_embed[sp_embed > 0.1] = 0 sp_embed = scipy.sparse.coo_matrix(sp_embed) - if storage in ('annlite', 'weaviate', 'qdrant', 'elasticsearch', 'redis'): - da = DocumentArray(storage=storage, config={'n_dim': 10}) + if config_gen: + da = DocumentArray(storage=storage, config=config_gen()) else: da = DocumentArray(storage=storage) @@ -572,12 +593,21 @@ def test_advance_selector_mixed(storage): @pytest.mark.parametrize( - 'storage', - ['memory', 'sqlite', 'weaviate', 'annlite', 'qdrant', 'elasticsearch', 'redis'], + 'storage,config_gen', + [ + ('memory', None), + ('sqlite', None), + ('weaviate', lambda: WeaviateConfig(n_dim=10)), + ('annlite', lambda: AnnliteConfig(n_dim=10)), + ('qdrant', lambda: QdrantConfig(n_dim=10)), + ('qdrant', lambda: QdrantConfig(n_dim=10, prefer_grpc=True)), + ('elasticsearch', lambda: ElasticConfig(n_dim=10)), + ('redis', lambda: RedisConfig(n_dim=10)), + ], ) -def test_single_boolean_and_padding(storage, start_storage): - if storage in ('annlite', 'weaviate', 'qdrant', 'elasticsearch', 'redis'): - da = DocumentArray(storage=storage, config={'n_dim': 10}) +def test_single_boolean_and_padding(storage, config_gen, start_storage): + if config_gen: + da = DocumentArray(storage=storage, config=config_gen()) else: da = DocumentArray(storage=storage) da.extend(DocumentArray.empty(3)) @@ -604,6 +634,7 @@ def test_single_boolean_and_padding(storage, start_storage): ('weaviate', lambda: WeaviateConfig(n_dim=123)), ('annlite', lambda: AnnliteConfig(n_dim=123)), ('qdrant', lambda: QdrantConfig(n_dim=123)), + ('qdrant', lambda: QdrantConfig(n_dim=123, prefer_grpc=True)), ('elasticsearch', lambda: ElasticConfig(n_dim=123)), ('redis', lambda: RedisConfig(n_dim=123)), ], diff --git a/tests/unit/array/test_backend_configuration.py b/tests/unit/array/test_backend_configuration.py index 8255375e6d2..e10b080d9da 100644 --- a/tests/unit/array/test_backend_configuration.py +++ b/tests/unit/array/test_backend_configuration.py @@ -130,7 +130,8 @@ def test_cast_columns_annlite(start_storage, type_da, type_column): @pytest.mark.parametrize('type_da', [int, float, str]) @pytest.mark.parametrize('type_column', ['int', 'float', 'str']) -def test_cast_columns_qdrant(start_storage, type_da, type_column, request): +@pytest.mark.parametrize('prefer_grpc', [False, True]) +def test_cast_columns_qdrant(start_storage, type_da, type_column, prefer_grpc, request): test_id = request.node.callspec.id.replace( '-', '' @@ -143,6 +144,7 @@ def test_cast_columns_qdrant(start_storage, type_da, type_column, request): 'collection_name': f'test{test_id}', 'n_dim': 3, 'columns': {'price': type_column}, + 'prefer_grpc': prefer_grpc, }, )