diff --git a/docarray/index/backends/elastic.py b/docarray/index/backends/elastic.py index 7edb04ad83e..c4da3ad5e0d 100644 --- a/docarray/index/backends/elastic.py +++ b/docarray/index/backends/elastic.py @@ -1,5 +1,4 @@ # mypy: ignore-errors -import uuid import warnings from collections import defaultdict from dataclasses import dataclass, field @@ -74,12 +73,6 @@ def __init__(self, db_config=None, **kwargs): self._logger.debug('Elastic Search index is being initialized') # ElasticSearch client creation - if self._db_config.index_name is None: - id = uuid.uuid4().hex - self._db_config.index_name = 'index__' + id - - self._index_name = self._db_config.index_name - self._client = Elasticsearch( hosts=self._db_config.hosts, **self._db_config.es_config, @@ -108,7 +101,7 @@ def __init__(self, db_config=None, **kwargs): mappings['properties'][col_name] = self._create_index_mapping(col) # print(mappings['properties']) - if self._client.indices.exists(index=self._index_name): + if self._client.indices.exists(index=self.index_name): self._client_put_mapping(mappings) else: self._client_create(mappings) @@ -116,7 +109,20 @@ def __init__(self, db_config=None, **kwargs): if len(self._db_config.index_settings): self._client_put_settings(self._db_config.index_settings) - self._refresh(self._index_name) + self._refresh(self.index_name) + + @property + def index_name(self): + default_index_name = ( + self._schema.__name__.lower() if self._schema is not None else None + ) + if default_index_name is None: + raise ValueError( + 'A ElasticDocIndex must be typed with a Document type.' + 'To do so, use the syntax: ElasticDocIndex[DocumentType]' + ) + + return self._db_config.index_name or default_index_name ############################################### # Inner classes for query builder and configs # @@ -333,7 +339,7 @@ def _index( for row in data: request = { - '_index': self._index_name, + '_index': self.index_name, '_id': row['id'], } for col_name, col in self._column_infos.items(): @@ -349,13 +355,13 @@ def _index( warnings.warn(str(info)) if refresh: - self._refresh(self._index_name) + self._refresh(self.index_name) def num_docs(self) -> int: """ Get the number of documents. """ - return self._client.count(index=self._index_name)['count'] + return self._client.count(index=self.index_name)['count'] def _del_items( self, @@ -365,7 +371,7 @@ def _del_items( requests = [] for _id in doc_ids: requests.append( - {'_op_type': 'delete', '_index': self._index_name, '_id': _id} + {'_op_type': 'delete', '_index': self.index_name, '_id': _id} ) _, warning_info = self._send_requests(requests, chunk_size) @@ -375,7 +381,7 @@ def _del_items( ids = [info['delete']['_id'] for info in warning_info] warnings.warn(f'No document with id {ids} found') - self._refresh(self._index_name) + self._refresh(self.index_name) def _get_items(self, doc_ids: Sequence[str]) -> Sequence[TSchema]: accumulated_docs = [] @@ -416,7 +422,7 @@ def execute_query(self, query: Dict[str, Any], *args, **kwargs) -> Any: f'args and kwargs not supported for `execute_query` on {type(self)}' ) - resp = self._client.search(index=self._index_name, **query) + resp = self._client.search(index=self.index_name, **query) docs, scores = self._format_response(resp) return _FindResult(documents=docs, scores=parse_obj_as(NdArray, scores)) @@ -440,7 +446,7 @@ def _find_batched( ) -> _FindResultBatched: request = [] for query in queries: - head = {'index': self._index_name} + head = {'index': self.index_name} body = self._form_search_body(query, limit, search_field) request.extend([head, body]) @@ -469,7 +475,7 @@ def _filter_batched( ) -> List[List[Dict]]: request = [] for query in filter_queries: - head = {'index': self._index_name} + head = {'index': self.index_name} body = {'query': query, 'size': limit} request.extend([head, body]) @@ -499,7 +505,7 @@ def _text_search_batched( ) -> _FindResultBatched: request = [] for query in queries: - head = {'index': self._index_name} + head = {'index': self.index_name} body = self._form_text_search_body(query, limit, search_field) request.extend([head, body]) @@ -615,20 +621,20 @@ def _refresh(self, index_name: str): def _client_put_mapping(self, mappings: Dict[str, Any]): self._client.indices.put_mapping( - index=self._index_name, properties=mappings['properties'] + index=self.index_name, properties=mappings['properties'] ) def _client_create(self, mappings: Dict[str, Any]): - self._client.indices.create(index=self._index_name, mappings=mappings) + self._client.indices.create(index=self.index_name, mappings=mappings) def _client_put_settings(self, settings: Dict[str, Any]): - self._client.indices.put_settings(index=self._index_name, settings=settings) + self._client.indices.put_settings(index=self.index_name, settings=settings) def _client_mget(self, ids: Sequence[str]): - return self._client.mget(index=self._index_name, ids=ids) + return self._client.mget(index=self.index_name, ids=ids) def _client_search(self, **kwargs): - return self._client.search(index=self._index_name, **kwargs) + return self._client.search(index=self.index_name, **kwargs) def _client_msearch(self, request: List[Dict[str, Any]]): - return self._client.msearch(index=self._index_name, searches=request) + return self._client.msearch(index=self.index_name, searches=request) diff --git a/docarray/index/backends/elasticv7.py b/docarray/index/backends/elasticv7.py index 623f11053bb..1782e921f62 100644 --- a/docarray/index/backends/elasticv7.py +++ b/docarray/index/backends/elasticv7.py @@ -119,7 +119,7 @@ def execute_query(self, query: Dict[str, Any], *args, **kwargs) -> Any: f'args and kwargs not supported for `execute_query` on {type(self)}' ) - resp = self._client.search(index=self._index_name, body=query) + resp = self._client.search(index=self.index_name, body=query) docs, scores = self._format_response(resp) return _FindResult(documents=docs, scores=parse_obj_as(NdArray, scores)) @@ -161,20 +161,20 @@ def _form_search_body(self, query: np.ndarray, limit: int, search_field: str = ' ############################################### def _client_put_mapping(self, mappings: Dict[str, Any]): - self._client.indices.put_mapping(index=self._index_name, body=mappings) + self._client.indices.put_mapping(index=self.index_name, body=mappings) def _client_create(self, mappings: Dict[str, Any]): body = {'mappings': mappings} - self._client.indices.create(index=self._index_name, body=body) + self._client.indices.create(index=self.index_name, body=body) def _client_put_settings(self, settings: Dict[str, Any]): - self._client.indices.put_settings(index=self._index_name, body=settings) + self._client.indices.put_settings(index=self.index_name, body=settings) def _client_mget(self, ids: Sequence[str]): - return self._client.mget(index=self._index_name, body={'ids': ids}) + return self._client.mget(index=self.index_name, body={'ids': ids}) def _client_search(self, **kwargs): - return self._client.search(index=self._index_name, body=kwargs) + return self._client.search(index=self.index_name, body=kwargs) def _client_msearch(self, request: List[Dict[str, Any]]): - return self._client.msearch(index=self._index_name, body=request) + return self._client.msearch(index=self.index_name, body=request) diff --git a/docarray/index/backends/qdrant.py b/docarray/index/backends/qdrant.py index d60e9daf7fa..e2e503d593c 100644 --- a/docarray/index/backends/qdrant.py +++ b/docarray/index/backends/qdrant.py @@ -45,7 +45,6 @@ TSchema = TypeVar('TSchema', bound=BaseDoc) - QDRANT_PY_VECTOR_TYPES: List[Any] = [np.ndarray, AbstractTensor] if torch_imported: import torch @@ -86,6 +85,19 @@ def __init__(self, db_config=None, **kwargs): self._initialize_collection() self._logger.info(f'{self.__class__.__name__} has been initialized') + @property + def collection_name(self): + default_collection_name = ( + self._schema.__name__.lower() if self._schema is not None else None + ) + if default_collection_name is None: + raise ValueError( + 'A QdrantDocumentIndex must be typed with a Document type.' + 'To do so, use the syntax: QdrantDocumentIndex[DocumentType]' + ) + + return self._db_config.collection_name or default_collection_name + @dataclass class Query: """Dataclass describing a query.""" @@ -211,7 +223,7 @@ class DBConfig(BaseDocIndex.DBConfig): timeout: Optional[float] = None host: Optional[str] = None path: Optional[str] = None - collection_name: str = 'documents' + collection_name: Optional[str] = None shard_number: Optional[int] = None replication_factor: Optional[int] = None write_consistency_factor: Optional[int] = None @@ -250,7 +262,7 @@ def python_type_to_db_type(self, python_type: Type) -> Any: def _initialize_collection(self): try: - self._client.get_collection(self._db_config.collection_name) + self._client.get_collection(self.collection_name) except (UnexpectedResponse, RpcError, ValueError): vectors_config = { column_name: self._to_qdrant_vector_params(column_info) @@ -258,7 +270,7 @@ def _initialize_collection(self): if column_info.db_type == 'vector' } self._client.create_collection( - collection_name=self._db_config.collection_name, + collection_name=self.collection_name, vectors_config=vectors_config, shard_number=self._db_config.shard_number, replication_factor=self._db_config.replication_factor, @@ -270,7 +282,7 @@ def _initialize_collection(self): quantization_config=self._db_config.quantization_config, ) self._client.create_payload_index( - collection_name=self._db_config.collection_name, + collection_name=self.collection_name, field_name='__generated_vectors', field_schema=rest.PayloadSchemaType.KEYWORD, ) @@ -280,7 +292,7 @@ def _index(self, column_to_data: Dict[str, Generator[Any, None, None]]): # TODO: add batching the documents to avoid timeouts points = [self._build_point_from_row(row) for row in rows] self._client.upsert( - collection_name=self._db_config.collection_name, + collection_name=self.collection_name, points=points, ) @@ -288,7 +300,7 @@ def num_docs(self) -> int: """ Get the number of documents. """ - return self._client.count(collection_name=self._db_config.collection_name).count + return self._client.count(collection_name=self.collection_name).count def _del_items(self, doc_ids: Sequence[str]): items = self._get_items(doc_ids) @@ -298,7 +310,7 @@ def _del_items(self, doc_ids: Sequence[str]): raise KeyError('Document keys could not found: %s' % ','.join(missing_keys)) self._client.delete( - collection_name=self._db_config.collection_name, + collection_name=self.collection_name, points_selector=rest.PointIdsList( points=[self._to_qdrant_id(doc_id) for doc_id in doc_ids], ), @@ -308,7 +320,7 @@ def _get_items( self, doc_ids: Sequence[str] ) -> Union[Sequence[TSchema], Sequence[Dict[str, Any]]]: response, _ = self._client.scroll( - collection_name=self._db_config.collection_name, + collection_name=self.collection_name, scroll_filter=rest.Filter( must=[ rest.HasIdCondition( @@ -343,7 +355,7 @@ def execute_query(self, query: Union[Query, RawQuery], *args, **kwargs) -> DocLi # We perform semantic search with some vectors with Qdrant's search method # should be called points = self._client.search( # type: ignore[assignment] - collection_name=self._db_config.collection_name, + collection_name=self.collection_name, query_vector=(query.vector_field, query.vector_query), # type: ignore[arg-type] query_filter=rest.Filter( must=[query.filter], @@ -364,7 +376,7 @@ def execute_query(self, query: Union[Query, RawQuery], *args, **kwargs) -> DocLi else: # Just filtering, so Qdrant's scroll has to be used instead points, _ = self._client.scroll( # type: ignore[assignment] - collection_name=self._db_config.collection_name, + collection_name=self.collection_name, scroll_filter=query.filter, limit=query.limit, with_payload=True, @@ -388,7 +400,7 @@ def _execute_raw_query( if search_params: search_params = rest.SearchParams.parse_obj(search_params) # type: ignore[assignment] points = self._client.search( # type: ignore[assignment] - collection_name=self._db_config.collection_name, + collection_name=self.collection_name, query_vector=query.pop('vector'), query_filter=payload_filter, search_params=search_params, @@ -397,7 +409,7 @@ def _execute_raw_query( else: # Just filtering, so Qdrant's scroll has to be used instead points, _ = self._client.scroll( # type: ignore[assignment] - collection_name=self._db_config.collection_name, + collection_name=self.collection_name, scroll_filter=payload_filter, **query, ) @@ -417,7 +429,7 @@ def _find_batched( self, queries: np.ndarray, limit: int, search_field: str = '' ) -> _FindResultBatched: responses = self._client.search_batch( - collection_name=self._db_config.collection_name, + collection_name=self.collection_name, requests=[ rest.SearchRequest( vector=rest.NamedVector( @@ -470,7 +482,7 @@ def _filter_batched( # There is no batch scroll available in Qdrant client yet, so we need to # perform the queries one by one. It will be changed in the future versions. response, _ = self._client.scroll( - collection_name=self._db_config.collection_name, + collection_name=self.collection_name, scroll_filter=filter_query, limit=limit, with_payload=True, diff --git a/docarray/index/backends/weaviate.py b/docarray/index/backends/weaviate.py index 368992645e2..5179f8cb588 100644 --- a/docarray/index/backends/weaviate.py +++ b/docarray/index/backends/weaviate.py @@ -111,6 +111,17 @@ def __init__(self, db_config=None, **kwargs) -> None: self._set_properties() self._create_schema() + @property + def index_name(self): + default_index_name = self._schema.__name__ if self._schema is not None else None + if default_index_name is None: + raise ValueError( + 'A WeaviateDocumentIndex must be typed with a Document type.' + 'To do so, use the syntax: WeaviateDocumentIndex[DocumentType]' + ) + + return self._db_config.index_name or default_index_name + def _set_properties(self) -> None: field_overwrites = {"id": DOCUMENTID} @@ -207,13 +218,13 @@ def _create_schema(self) -> None: # and configure replication # we will update base on user feedback schema["properties"] = properties - schema["class"] = self._db_config.index_name + schema["class"] = self.index_name # TODO: Use exists() instead of contains() when available # see https://github.com/weaviate/weaviate-python-client/issues/232 if self._client.schema.contains(schema): logging.warning( - f"Found index {self._db_config.index_name} with schema {schema}. Will reuse existing schema." + f"Found index {self.index_name} with schema {schema}. Will reuse existing schema." ) else: self._client.schema.create_class(schema) @@ -223,7 +234,7 @@ class DBConfig(BaseDocIndex.DBConfig): """Dataclass that contains all "static" configurations of WeaviateDocumentIndex.""" host: str = 'http://localhost:8080' - index_name: str = 'Document' + index_name: Optional[str] = None username: Optional[str] = None password: Optional[str] = None scopes: List[str] = field(default_factory=lambda: ["offline_access"]) @@ -269,7 +280,7 @@ def _del_items(self, doc_ids: Sequence[str]): # see: https://weaviate.io/developers/weaviate/api/rest/batch#maximum-number-of-deletes-per-query while has_matches: results = self._client.batch.delete_objects( - class_name=self._db_config.index_name, + class_name=self.index_name, where=where_filter, ) @@ -279,14 +290,14 @@ def _filter(self, filter_query: Any, limit: int) -> Union[DocList, List[Dict]]: self._overwrite_id(filter_query) results = ( - self._client.query.get(self._db_config.index_name, self.properties) + self._client.query.get(self.index_name, self.properties) .with_additional("vector") .with_where(filter_query) .with_limit(limit) .do() ) - docs = results["data"]["Get"][self._db_config.index_name] + docs = results["data"]["Get"][self.index_name] return [self._parse_weaviate_result(doc) for doc in docs] @@ -297,7 +308,7 @@ def _filter_batched( self._overwrite_id(filter_query) qs = [ - self._client.query.get(self._db_config.index_name, self.properties) + self._client.query.get(self.index_name, self.properties) .with_additional("vector") .with_where(filter_query) .with_limit(limit) @@ -370,7 +381,7 @@ def _find( score_name: Literal["certainty", "distance"] = "certainty", score_threshold: Optional[float] = None, ) -> _FindResult: - index_name = self._db_config.index_name + index_name = self.index_name if search_field: logging.warning( 'Argument search_field is not supported for WeaviateDocumentIndex. Ignoring.' @@ -474,7 +485,7 @@ def _find_batched( near_vector[score_name] = score_threshold q = ( - self._client.query.get(self._db_config.index_name, self.properties) + self._client.query.get(self.index_name, self.properties) .with_near_vector(near_vector) .with_limit(limit) .with_additional([score_name, "vector"]) @@ -507,7 +518,7 @@ def _get_items(self, doc_ids: Sequence[str]) -> List[Dict]: } results = ( - self._client.query.get(self._db_config.index_name, self.properties) + self._client.query.get(self.index_name, self.properties) .with_where(where_filter) .with_additional("vector") .do() @@ -515,7 +526,7 @@ def _get_items(self, doc_ids: Sequence[str]) -> List[Dict]: docs = [ self._parse_weaviate_result(doc) - for doc in results["data"]["Get"][self._db_config.index_name] + for doc in results["data"]["Get"][self.index_name] ] return docs @@ -554,7 +565,7 @@ def _parse_weaviate_result(self, result: Dict) -> Dict: def _index(self, column_to_data: Dict[str, Generator[Any, None, None]]): docs = self._transpose_col_value_dict(column_to_data) - index_name = self._db_config.index_name + index_name = self.index_name with self._client.batch as batch: for doc in docs: @@ -577,7 +588,7 @@ def _index(self, column_to_data: Dict[str, Generator[Any, None, None]]): def _text_search( self, query: str, limit: int, search_field: str = '' ) -> _FindResult: - index_name = self._db_config.index_name + index_name = self.index_name bm25 = {"query": query, "properties": [search_field]} results = ( @@ -602,7 +613,7 @@ def _text_search_batched( bm25 = {"query": query, "properties": [search_field]} q = ( - self._client.query.get(self._db_config.index_name, self.properties) + self._client.query.get(self.index_name, self.properties) .with_bm25(bm25) .with_limit(limit) .with_additional(["score", "vector"]) @@ -663,7 +674,7 @@ def num_docs(self) -> int: """ Get the number of documents. """ - index_name = self._db_config.index_name + index_name = self.index_name result = self._client.query.aggregate(index_name).with_meta_count().do() # TODO: decorator to check for errors total_docs = result["data"]["Aggregate"][index_name][0]["meta"]["count"] @@ -734,7 +745,7 @@ class QueryBuilder(BaseDocIndex.QueryBuilder): def __init__(self, document_index): self._queries = [ document_index._client.query.get( - document_index._db_config.index_name, document_index.properties + document_index.index_name, document_index.properties ) ] diff --git a/docs/user_guide/storing/docindex.md b/docs/user_guide/storing/docindex.md index 2240b06cc86..ad2ee2b96d2 100644 --- a/docs/user_guide/storing/docindex.md +++ b/docs/user_guide/storing/docindex.md @@ -161,7 +161,12 @@ db.index(data) For `HnswDocumentIndex` you need to specify a `work_dir` where the data will be stored; for other backends you usually specify a `host` and a `port` instead. -Either way, if the location does not yet contain any data, we start from a blank slate. +In addition to a host and a port, most backends can also take an `index_name`, `table_name`, `collection_name` or similar. +This specifies the name of the index/table/collection that will be created in the database. +You don't have to specify this though: By default, this name will be taken from the name of the Document type that you use as schema. +For example, for `WeaviateDocumentIndex[MyDoc](...)` the data will be stored in a Weaviate Class of name `MyDoc`. + +In any case, if the location does not yet contain any data, we start from a blank slate. If the location already contains data from a previous session, it will be accessible through the Document Index. ## Index data diff --git a/docs/user_guide/storing/index_elastic.md b/docs/user_guide/storing/index_elastic.md index a21528c46b8..eb1186f6d12 100644 --- a/docs/user_guide/storing/index_elastic.md +++ b/docs/user_guide/storing/index_elastic.md @@ -420,7 +420,7 @@ The following configs can be set in `DBConfig`: |-------------------|----------------------------------------------------------------------------------------------------------------------------------------|-------------------------| | `hosts` | Hostname of the Elasticsearch server | `http://localhost:9200` | | `es_config` | Other ES [configuration options](https://www.elastic.co/guide/en/elasticsearch/client/python-api/8.6/config.html) in a Dict and pass to `Elasticsearch` client constructor, e.g. `cloud_id`, `api_key` | None | -| `index_name` | Elasticsearch index name, the name of Elasticsearch index object | None | +| `index_name` | Elasticsearch index name, the name of Elasticsearch index object | None. Data will be stored in an index named after the Document type used as schema. | | `index_settings` | Other [index settings](https://www.elastic.co/guide/en/elasticsearch/reference/8.6/index-modules.html#index-modules-settings) in a Dict for creating the index | dict | | `index_mappings` | Other [index mappings](https://www.elastic.co/guide/en/elasticsearch/reference/8.6/mapping.html) in a Dict for creating the index | dict | diff --git a/docs/user_guide/storing/index_qdrant.md b/docs/user_guide/storing/index_qdrant.md index 7d832f1dd67..01249d01b7a 100644 --- a/docs/user_guide/storing/index_qdrant.md +++ b/docs/user_guide/storing/index_qdrant.md @@ -27,6 +27,10 @@ For general usage of a Document Index, see the [general user guide](./docindex.m runtime_config = QdrantDocumentIndex.RuntimeConfig() print(runtime_config) # shows default values ``` + + Note that the collection_name from the DBConfig is an Optional[str] with None as default value. This is because + the QdrantDocumentIndex will take the name the Document type that you use as schema. For example, for QdrantDocumentIndex[MyDoc](...) + the data will be stored in a collection name MyDoc if no specific collection_name is passed in the DBConfig. ```python import numpy as np diff --git a/docs/user_guide/storing/index_weaviate.md b/docs/user_guide/storing/index_weaviate.md index 83d84dd8fa9..09aaccc69b7 100644 --- a/docs/user_guide/storing/index_weaviate.md +++ b/docs/user_guide/storing/index_weaviate.md @@ -203,18 +203,18 @@ And a list of environment variables is [available on this page](https://weaviate Additionally, you can specify the below settings when you instantiate a configuration object in DocArray. -| name | type | explanation | default | example | -| ---- | ---- | ----------- | ------- | ------- | +| name | type | explanation | default | example | +| ---- | ---- | ----------- |------------------------------------------------------------------------| ------- | | **Category: General** | -| host | str | Weaviate instance url | http://localhost:8080 | +| host | str | Weaviate instance url | http://localhost:8080 | | **Category: Authentication** | -| username | str | Username known to the specified authentication provider (e.g. WCS) | None | `jp@weaviate.io` | -| password | str | Corresponding password | None | `p@ssw0rd` | -| auth_api_key | str | API key known to the Weaviate instance | None | `mys3cretk3y` | +| username | str | Username known to the specified authentication provider (e.g. WCS) | None | `jp@weaviate.io` | +| password | str | Corresponding password | None | `p@ssw0rd` | +| auth_api_key | str | API key known to the Weaviate instance | None | `mys3cretk3y` | | **Category: Data schema** | -| index_name | str | Class name to use to store the document | `Document` | +| index_name | str | Class name to use to store the document| The document class name, e.g. `MyDoc` for `WeaviateDocumentIndex[MyDoc]` | `Document` | | **Category: Embedded Weaviate** | -| embedded_options| EmbeddedOptions | Options for embedded weaviate | None | +| embedded_options| EmbeddedOptions | Options for embedded weaviate | None | The type `EmbeddedOptions` can be specified as described [here](https://weaviate.io/developers/weaviate/installation/embedded#embedded-options) diff --git a/tests/index/elastic/fixture.py b/tests/index/elastic/fixture.py index 812f0f09d51..ef7766acd0c 100644 --- a/tests/index/elastic/fixture.py +++ b/tests/index/elastic/fixture.py @@ -1,5 +1,6 @@ import os import time +import uuid import numpy as np import pytest @@ -87,3 +88,8 @@ def ten_deep_nested_docs(): DeepNestedDoc(d=NestedDoc(d=SimpleDoc(tens=np.random.randn(10)))) for _ in range(10) ] + + +@pytest.fixture(scope='function') +def tmp_index_name(): + return uuid.uuid4().hex diff --git a/tests/index/elastic/v7/test_column_config.py b/tests/index/elastic/v7/test_column_config.py index f1fa93d7748..812734aef26 100644 --- a/tests/index/elastic/v7/test_column_config.py +++ b/tests/index/elastic/v7/test_column_config.py @@ -129,3 +129,17 @@ class MyDoc(BaseDoc): } docs, _ = index.execute_query(query) assert [doc['id'] for doc in docs] == [doc[0].id, doc[1].id] + + +def test_index_name(): + class TextDoc(BaseDoc): + text: str = Field() + + class StringDoc(BaseDoc): + text: str = Field(col_type='text') + + index = ElasticV7DocIndex[TextDoc]() + assert index.index_name == TextDoc.__name__.lower() + + index = ElasticV7DocIndex[StringDoc]() + assert index.index_name == StringDoc.__name__.lower() diff --git a/tests/index/elastic/v7/test_find.py b/tests/index/elastic/v7/test_find.py index d54b3b0480d..1fe7893e91e 100644 --- a/tests/index/elastic/v7/test_find.py +++ b/tests/index/elastic/v7/test_find.py @@ -6,8 +6,12 @@ from docarray import BaseDoc from docarray.index import ElasticV7DocIndex from docarray.typing import NdArray, TorchTensor -from tests.index.elastic.fixture import start_storage_v7 # noqa: F401 -from tests.index.elastic.fixture import FlatDoc, SimpleDoc +from tests.index.elastic.fixture import ( # noqa: F401 + FlatDoc, + SimpleDoc, + start_storage_v7, + tmp_index_name, +) pytestmark = [pytest.mark.slow, pytest.mark.index] @@ -167,8 +171,8 @@ class TfDoc(BaseDoc): ) -def test_find_batched(): - index = ElasticV7DocIndex[SimpleDoc]() +def test_find_batched(tmp_index_name): # noqa: F811 + index = ElasticV7DocIndex[SimpleDoc](index_name=tmp_index_name) index_docs = [SimpleDoc(tens=np.random.rand(10)) for _ in range(10)] index.index(index_docs) @@ -243,13 +247,13 @@ class MyDoc(BaseDoc): assert doc.text.index(query) >= 0 -def test_query_builder(): +def test_query_builder(tmp_index_name): # noqa: F811 class MyDoc(BaseDoc): tens: NdArray[10] num: int text: str - index = ElasticV7DocIndex[MyDoc]() + index = ElasticV7DocIndex[MyDoc](index_name=tmp_index_name) index_docs = [ MyDoc( id=f'{i}', tens=np.random.rand(10), num=int(i / 2), text=f'text {int(i/2)}' diff --git a/tests/index/elastic/v7/test_index_get_del.py b/tests/index/elastic/v7/test_index_get_del.py index e6e6baf3e60..050bcb03f54 100644 --- a/tests/index/elastic/v7/test_index_get_del.py +++ b/tests/index/elastic/v7/test_index_get_del.py @@ -18,14 +18,17 @@ ten_flat_docs, ten_nested_docs, ten_simple_docs, + tmp_index_name, ) pytestmark = [pytest.mark.slow, pytest.mark.index] @pytest.mark.parametrize('use_docarray', [True, False]) -def test_index_simple_schema(ten_simple_docs, use_docarray): # noqa: F811 - index = ElasticV7DocIndex[SimpleDoc]() +def test_index_simple_schema( + ten_simple_docs, use_docarray, tmp_index_name # noqa: F811 +): + index = ElasticV7DocIndex[SimpleDoc](index_name=tmp_index_name) if use_docarray: ten_simple_docs = DocList[SimpleDoc](ten_simple_docs) @@ -34,8 +37,8 @@ def test_index_simple_schema(ten_simple_docs, use_docarray): # noqa: F811 @pytest.mark.parametrize('use_docarray', [True, False]) -def test_index_flat_schema(ten_flat_docs, use_docarray): # noqa: F811 - index = ElasticV7DocIndex[FlatDoc]() +def test_index_flat_schema(ten_flat_docs, use_docarray, tmp_index_name): # noqa: F811 + index = ElasticV7DocIndex[FlatDoc](index_name=tmp_index_name) if use_docarray: ten_flat_docs = DocList[FlatDoc](ten_flat_docs) @@ -44,8 +47,10 @@ def test_index_flat_schema(ten_flat_docs, use_docarray): # noqa: F811 @pytest.mark.parametrize('use_docarray', [True, False]) -def test_index_nested_schema(ten_nested_docs, use_docarray): # noqa: F811 - index = ElasticV7DocIndex[NestedDoc]() +def test_index_nested_schema( + ten_nested_docs, use_docarray, tmp_index_name # noqa: F811 +): + index = ElasticV7DocIndex[NestedDoc](index_name=tmp_index_name) if use_docarray: ten_nested_docs = DocList[NestedDoc](ten_nested_docs) @@ -54,8 +59,10 @@ def test_index_nested_schema(ten_nested_docs, use_docarray): # noqa: F811 @pytest.mark.parametrize('use_docarray', [True, False]) -def test_index_deep_nested_schema(ten_deep_nested_docs, use_docarray): # noqa: F811 - index = ElasticV7DocIndex[DeepNestedDoc]() +def test_index_deep_nested_schema( + ten_deep_nested_docs, use_docarray, tmp_index_name # noqa: F811 +): + index = ElasticV7DocIndex[DeepNestedDoc](index_name=tmp_index_name) if use_docarray: ten_deep_nested_docs = DocList[DeepNestedDoc](ten_deep_nested_docs) @@ -73,6 +80,7 @@ def test_get_single(ten_simple_docs, ten_flat_docs, ten_nested_docs): # noqa: F id_ = d.id assert index[id_].id == id_ assert np.all(index[id_].tens == d.tens) + index._client.indices.delete(index='simpledoc') # flat index = ElasticV7DocIndex[FlatDoc]() @@ -84,6 +92,7 @@ def test_get_single(ten_simple_docs, ten_flat_docs, ten_nested_docs): # noqa: F assert index[id_].id == id_ assert np.all(index[id_].tens_one == d.tens_one) assert np.all(index[id_].tens_two == d.tens_two) + index._client.indices.delete(index='flatdoc') # nested index = ElasticV7DocIndex[NestedDoc]() @@ -95,6 +104,7 @@ def test_get_single(ten_simple_docs, ten_flat_docs, ten_nested_docs): # noqa: F assert index[id_].id == id_ assert index[id_].d.id == d.d.id assert np.all(index[id_].d.tens == d.d.tens) + index._client.indices.delete(index='nesteddoc') def test_get_multiple(ten_simple_docs, ten_flat_docs, ten_nested_docs): # noqa: F811 @@ -147,16 +157,16 @@ def test_get_key_error(ten_simple_docs): # noqa: F811 index['not_a_real_id'] -def test_persisting(ten_simple_docs): # noqa: F811 - index = ElasticV7DocIndex[SimpleDoc](index_name='test_persisting') +def test_persisting(ten_simple_docs, tmp_index_name): # noqa: F811 + index = ElasticV7DocIndex[SimpleDoc](index_name=tmp_index_name) index.index(ten_simple_docs) - index2 = ElasticV7DocIndex[SimpleDoc](index_name='test_persisting') + index2 = ElasticV7DocIndex[SimpleDoc](index_name=tmp_index_name) assert index2.num_docs() == 10 -def test_del_single(ten_simple_docs): # noqa: F811 - index = ElasticV7DocIndex[SimpleDoc]() +def test_del_single(ten_simple_docs, tmp_index_name): # noqa: F811 + index = ElasticV7DocIndex[SimpleDoc](index_name=tmp_index_name) index.index(ten_simple_docs) # delete once assert index.num_docs() == 10 @@ -183,10 +193,10 @@ def test_del_single(ten_simple_docs): # noqa: F811 assert np.all(index[id_].tens == d.tens) -def test_del_multiple(ten_simple_docs): # noqa: F811 +def test_del_multiple(ten_simple_docs, tmp_index_name): # noqa: F811 docs_to_del_idx = [0, 2, 4, 6, 8] - index = ElasticV7DocIndex[SimpleDoc]() + index = ElasticV7DocIndex[SimpleDoc](index_name=tmp_index_name) index.index(ten_simple_docs) assert index.num_docs() == 10 @@ -210,8 +220,8 @@ def test_del_key_error(ten_simple_docs): # noqa: F811 del index['not_a_real_id'] -def test_num_docs(ten_simple_docs): # noqa: F811 - index = ElasticV7DocIndex[SimpleDoc]() +def test_num_docs(ten_simple_docs, tmp_index_name): # noqa: F811 + index = ElasticV7DocIndex[SimpleDoc](index_name=tmp_index_name) index.index(ten_simple_docs) assert index.num_docs() == 10 diff --git a/tests/index/elastic/v8/test_column_config.py b/tests/index/elastic/v8/test_column_config.py index 0edd105697d..c2831dfb59e 100644 --- a/tests/index/elastic/v8/test_column_config.py +++ b/tests/index/elastic/v8/test_column_config.py @@ -3,17 +3,17 @@ from docarray import BaseDoc from docarray.index import ElasticDocIndex -from tests.index.elastic.fixture import start_storage_v8 # noqa: F401 +from tests.index.elastic.fixture import start_storage_v8, tmp_index_name # noqa: F401 pytestmark = [pytest.mark.slow, pytest.mark.index, pytest.mark.elasticv8] -def test_column_config(): +def test_column_config(tmp_index_name): # noqa: F811 class MyDoc(BaseDoc): text: str color: str = Field(col_type='keyword') - index = ElasticDocIndex[MyDoc]() + index = ElasticDocIndex[MyDoc](index_name=tmp_index_name) index_docs = [ MyDoc(id='0', text='hello world', color='red'), MyDoc(id='1', text='never gonna give you up', color='blue'), @@ -30,7 +30,7 @@ class MyDoc(BaseDoc): assert [doc.id for doc in docs] == ['0', '1'] -def test_field_object(): +def test_field_object(tmp_index_name): # noqa: F811 class MyDoc(BaseDoc): manager: dict = Field( properties={ @@ -44,7 +44,7 @@ class MyDoc(BaseDoc): } ) - index = ElasticDocIndex[MyDoc]() + index = ElasticDocIndex[MyDoc](index_name=tmp_index_name) doc = [ MyDoc(manager={'age': 25, 'name': {'first': 'Rachel', 'last': 'Green'}}), MyDoc(manager={'age': 30, 'name': {'first': 'Monica', 'last': 'Geller'}}), @@ -60,11 +60,11 @@ class MyDoc(BaseDoc): assert [doc.id for doc in docs] == [doc[1].id, doc[2].id] -def test_field_geo_point(): +def test_field_geo_point(tmp_index_name): # noqa: F811 class MyDoc(BaseDoc): location: dict = Field(col_type='geo_point') - index = ElasticDocIndex[MyDoc]() + index = ElasticDocIndex[MyDoc](index_name=tmp_index_name) doc = [ MyDoc(location={'lat': 40.12, 'lon': -72.34}), MyDoc(location={'lat': 41.12, 'lon': -73.34}), @@ -87,12 +87,12 @@ class MyDoc(BaseDoc): assert [doc['id'] for doc in docs] == [doc[0].id, doc[1].id] -def test_field_range(): +def test_field_range(tmp_index_name): # noqa: F811 class MyDoc(BaseDoc): expected_attendees: dict = Field(col_type='integer_range') time_frame: dict = Field(col_type='date_range', format='yyyy-MM-dd') - index = ElasticDocIndex[MyDoc]() + index = ElasticDocIndex[MyDoc](index_name=tmp_index_name) doc = [ MyDoc( expected_attendees={'gte': 10, 'lt': 20}, @@ -129,3 +129,17 @@ class MyDoc(BaseDoc): } docs, _ = index.execute_query(query) assert [doc['id'] for doc in docs] == [doc[0].id, doc[1].id] + + +def test_index_name(): + class TextDoc(BaseDoc): + text: str = Field() + + class StringDoc(BaseDoc): + text: str = Field(col_type='text') + + index = ElasticDocIndex[TextDoc]() + assert index.index_name == TextDoc.__name__.lower() + + index = ElasticDocIndex[StringDoc]() + assert index.index_name == StringDoc.__name__.lower() diff --git a/tests/index/elastic/v8/test_find.py b/tests/index/elastic/v8/test_find.py index bb87755254c..cfd27ed0912 100644 --- a/tests/index/elastic/v8/test_find.py +++ b/tests/index/elastic/v8/test_find.py @@ -6,18 +6,22 @@ from docarray import BaseDoc from docarray.index import ElasticDocIndex from docarray.typing import NdArray, TorchTensor -from tests.index.elastic.fixture import start_storage_v8 # noqa: F401 -from tests.index.elastic.fixture import FlatDoc, SimpleDoc +from tests.index.elastic.fixture import ( # noqa: F401 + FlatDoc, + SimpleDoc, + start_storage_v8, + tmp_index_name, +) pytestmark = [pytest.mark.slow, pytest.mark.index, pytest.mark.elasticv8] @pytest.mark.parametrize('similarity', ['cosine', 'l2_norm', 'dot_product']) -def test_find_simple_schema(similarity): +def test_find_simple_schema(similarity, tmp_index_name): # noqa: F811 class SimpleSchema(BaseDoc): tens: NdArray[10] = Field(similarity=similarity) - index = ElasticDocIndex[SimpleSchema]() + index = ElasticDocIndex[SimpleSchema](index_name=tmp_index_name) index_docs = [] for _ in range(10): @@ -37,12 +41,12 @@ class SimpleSchema(BaseDoc): @pytest.mark.parametrize('similarity', ['cosine', 'l2_norm', 'dot_product']) -def test_find_flat_schema(similarity): +def test_find_flat_schema(similarity, tmp_index_name): # noqa: F811 class FlatSchema(BaseDoc): tens_one: NdArray = Field(dims=10, similarity=similarity) tens_two: NdArray = Field(dims=50, similarity=similarity) - index = ElasticDocIndex[FlatSchema]() + index = ElasticDocIndex[FlatSchema](index_name=tmp_index_name) index_docs = [] for _ in range(10): @@ -75,7 +79,7 @@ class FlatSchema(BaseDoc): @pytest.mark.parametrize('similarity', ['cosine', 'l2_norm', 'dot_product']) -def test_find_nested_schema(similarity): +def test_find_nested_schema(similarity, tmp_index_name): # noqa: F811 class SimpleDoc(BaseDoc): tens: NdArray[10] = Field(similarity=similarity) @@ -87,7 +91,7 @@ class DeepNestedDoc(BaseDoc): d: NestedDoc tens: NdArray = Field(similarity=similarity, dims=10) - index = ElasticDocIndex[DeepNestedDoc]() + index = ElasticDocIndex[DeepNestedDoc](index_name=tmp_index_name) index_docs = [] for _ in range(10): @@ -188,8 +192,8 @@ class TfDoc(BaseDoc): ) -def test_find_batched(): - index = ElasticDocIndex[SimpleDoc]() +def test_find_batched(tmp_index_name): # noqa: F811 + index = ElasticDocIndex[SimpleDoc](index_name=tmp_index_name) index_docs = [SimpleDoc(tens=np.random.rand(10)) for _ in range(10)] index.index(index_docs) @@ -272,7 +276,9 @@ class MyDoc(BaseDoc): index = ElasticDocIndex[MyDoc]() index_docs = [ - MyDoc(id=f'{i}', tens=np.ones(10) * i, num=int(i / 2), text=f'text {int(i/2)}') + MyDoc( + id=f'{i}', tens=np.ones(10) * i, num=int(i / 2), text=f'text {int(i / 2)}' + ) for i in range(10) ] index.index(index_docs) @@ -327,3 +333,12 @@ class MyDoc(BaseDoc): docs, _ = index.execute_query(query) assert [doc['id'] for doc in docs] == ['7', '6', '5', '4'] + + +def test_index_name(): + class MyDoc(BaseDoc): + expected_attendees: dict = Field(col_type='integer_range') + time_frame: dict = Field(col_type='date_range', format='yyyy-MM-dd') + + index = ElasticDocIndex[MyDoc]() + assert index.index_name == MyDoc.__name__.lower() diff --git a/tests/index/elastic/v8/test_index_get_del.py b/tests/index/elastic/v8/test_index_get_del.py index 8efd66429b0..8d182dfd19a 100644 --- a/tests/index/elastic/v8/test_index_get_del.py +++ b/tests/index/elastic/v8/test_index_get_del.py @@ -18,14 +18,17 @@ ten_flat_docs, ten_nested_docs, ten_simple_docs, + tmp_index_name, ) pytestmark = [pytest.mark.slow, pytest.mark.index, pytest.mark.elasticv8] @pytest.mark.parametrize('use_docarray', [True, False]) -def test_index_simple_schema(ten_simple_docs, use_docarray): # noqa: F811 - index = ElasticDocIndex[SimpleDoc]() +def test_index_simple_schema( + ten_simple_docs, use_docarray, tmp_index_name # noqa: F811 +): + index = ElasticDocIndex[SimpleDoc](index_name=tmp_index_name) if use_docarray: ten_simple_docs = DocList[SimpleDoc](ten_simple_docs) @@ -34,8 +37,8 @@ def test_index_simple_schema(ten_simple_docs, use_docarray): # noqa: F811 @pytest.mark.parametrize('use_docarray', [True, False]) -def test_index_flat_schema(ten_flat_docs, use_docarray): # noqa: F811 - index = ElasticDocIndex[FlatDoc]() +def test_index_flat_schema(ten_flat_docs, use_docarray, tmp_index_name): # noqa: F811 + index = ElasticDocIndex[FlatDoc](index_name=tmp_index_name) if use_docarray: ten_flat_docs = DocList[FlatDoc](ten_flat_docs) @@ -44,8 +47,10 @@ def test_index_flat_schema(ten_flat_docs, use_docarray): # noqa: F811 @pytest.mark.parametrize('use_docarray', [True, False]) -def test_index_nested_schema(ten_nested_docs, use_docarray): # noqa: F811 - index = ElasticDocIndex[NestedDoc]() +def test_index_nested_schema( + ten_nested_docs, use_docarray, tmp_index_name # noqa: F811 +): + index = ElasticDocIndex[NestedDoc](index_name=tmp_index_name) if use_docarray: ten_nested_docs = DocList[NestedDoc](ten_nested_docs) @@ -54,8 +59,10 @@ def test_index_nested_schema(ten_nested_docs, use_docarray): # noqa: F811 @pytest.mark.parametrize('use_docarray', [True, False]) -def test_index_deep_nested_schema(ten_deep_nested_docs, use_docarray): # noqa: F811 - index = ElasticDocIndex[DeepNestedDoc]() +def test_index_deep_nested_schema( + ten_deep_nested_docs, use_docarray, tmp_index_name # noqa: F811 +): + index = ElasticDocIndex[DeepNestedDoc](index_name=tmp_index_name) if use_docarray: ten_deep_nested_docs = DocList[DeepNestedDoc](ten_deep_nested_docs) @@ -73,6 +80,7 @@ def test_get_single(ten_simple_docs, ten_flat_docs, ten_nested_docs): # noqa: F id_ = d.id assert index[id_].id == id_ assert np.all(index[id_].tens == d.tens) + index._client.indices.delete(index='simpledoc', ignore_unavailable=True) # flat index = ElasticDocIndex[FlatDoc]() @@ -84,6 +92,7 @@ def test_get_single(ten_simple_docs, ten_flat_docs, ten_nested_docs): # noqa: F assert index[id_].id == id_ assert np.all(index[id_].tens_one == d.tens_one) assert np.all(index[id_].tens_two == d.tens_two) + index._client.indices.delete(index='flatdoc', ignore_unavailable=True) # nested index = ElasticDocIndex[NestedDoc]() @@ -95,6 +104,7 @@ def test_get_single(ten_simple_docs, ten_flat_docs, ten_nested_docs): # noqa: F assert index[id_].id == id_ assert index[id_].d.id == d.d.id assert np.all(index[id_].d.tens == d.d.tens) + index._client.indices.delete(index='nesteddoc', ignore_unavailable=True) def test_get_multiple(ten_simple_docs, ten_flat_docs, ten_nested_docs): # noqa: F811 @@ -139,24 +149,24 @@ def test_get_multiple(ten_simple_docs, ten_flat_docs, ten_nested_docs): # noqa: assert np.all(d_out.d.tens == d_in.d.tens) -def test_get_key_error(ten_simple_docs): # noqa: F811 - index = ElasticDocIndex[SimpleDoc]() +def test_get_key_error(ten_simple_docs, tmp_index_name): # noqa: F811 + index = ElasticDocIndex[SimpleDoc](index_name=tmp_index_name) index.index(ten_simple_docs) with pytest.raises(KeyError): index['not_a_real_id'] -def test_persisting(ten_simple_docs): # noqa: F811 - index = ElasticDocIndex[SimpleDoc](index_name='test_persisting') +def test_persisting(ten_simple_docs, tmp_index_name): # noqa: F811 + index = ElasticDocIndex[SimpleDoc](index_name=tmp_index_name) index.index(ten_simple_docs) - index2 = ElasticDocIndex[SimpleDoc](index_name='test_persisting') + index2 = ElasticDocIndex[SimpleDoc](index_name=tmp_index_name) assert index2.num_docs() == 10 -def test_del_single(ten_simple_docs): # noqa: F811 - index = ElasticDocIndex[SimpleDoc]() +def test_del_single(ten_simple_docs, tmp_index_name): # noqa: F811 + index = ElasticDocIndex[SimpleDoc](index_name=tmp_index_name) index.index(ten_simple_docs) # delete once assert index.num_docs() == 10 @@ -183,10 +193,10 @@ def test_del_single(ten_simple_docs): # noqa: F811 assert np.all(index[id_].tens == d.tens) -def test_del_multiple(ten_simple_docs): # noqa: F811 +def test_del_multiple(ten_simple_docs, tmp_index_name): # noqa: F811 docs_to_del_idx = [0, 2, 4, 6, 8] - index = ElasticDocIndex[SimpleDoc]() + index = ElasticDocIndex[SimpleDoc](index_name=tmp_index_name) index.index(ten_simple_docs) assert index.num_docs() == 10 @@ -202,16 +212,16 @@ def test_del_multiple(ten_simple_docs): # noqa: F811 assert np.all(index[doc.id].tens == doc.tens) -def test_del_key_error(ten_simple_docs): # noqa: F811 - index = ElasticDocIndex[SimpleDoc]() +def test_del_key_error(ten_simple_docs, tmp_index_name): # noqa: F811 + index = ElasticDocIndex[SimpleDoc](index_name=tmp_index_name) index.index(ten_simple_docs) with pytest.warns(UserWarning): del index['not_a_real_id'] -def test_num_docs(ten_simple_docs): # noqa: F811 - index = ElasticDocIndex[SimpleDoc]() +def test_num_docs(ten_simple_docs, tmp_index_name): # noqa: F811 + index = ElasticDocIndex[SimpleDoc](index_name=tmp_index_name) index.index(ten_simple_docs) assert index.num_docs() == 10 diff --git a/tests/index/qdrant/docker-compose.yml b/tests/index/qdrant/docker-compose.yml new file mode 100644 index 00000000000..a3a57b7c87f --- /dev/null +++ b/tests/index/qdrant/docker-compose.yml @@ -0,0 +1,12 @@ +version: '3.8' + +services: + qdrant: + image: qdrant/qdrant:v1.1.2 + ports: + - "6333:6333" + - "6334:6334" + ulimits: # Only required for tests, as there are a lot of collections created + nofile: + soft: 65535 + hard: 65535 \ No newline at end of file diff --git a/tests/index/qdrant/fixtures.py b/tests/index/qdrant/fixtures.py index d44a0950d35..d48372ad0f0 100644 --- a/tests/index/qdrant/fixtures.py +++ b/tests/index/qdrant/fixtures.py @@ -1,8 +1,29 @@ +import os +import time +import uuid + import pytest import qdrant_client from docarray.index import QdrantDocumentIndex +cur_dir = os.path.dirname(os.path.abspath(__file__)) +qdrant_yml = os.path.abspath(os.path.join(cur_dir, 'docker-compose.yml')) + + +@pytest.fixture(scope='session', autouse=True) +def start_storage(): + os.system(f"docker-compose -f {qdrant_yml} up -d --remove-orphans") + time.sleep(1) + + yield + os.system(f"docker-compose -f {qdrant_yml} down --remove-orphans") + + +@pytest.fixture(scope='function') +def tmp_collection_name(): + return uuid.uuid4().hex + @pytest.fixture def qdrant() -> qdrant_client.QdrantClient: diff --git a/tests/index/qdrant/test_find.py b/tests/index/qdrant/test_find.py index 610695e5c81..4dd2c27ba25 100644 --- a/tests/index/qdrant/test_find.py +++ b/tests/index/qdrant/test_find.py @@ -5,7 +5,7 @@ from docarray import BaseDoc, DocList from docarray.index import QdrantDocumentIndex from docarray.typing import NdArray, TorchTensor -from tests.index.qdrant.fixtures import qdrant, qdrant_config # noqa: F401 +from tests.index.qdrant.fixtures import start_storage, tmp_collection_name # noqa: F401 pytestmark = [pytest.mark.slow, pytest.mark.index] @@ -32,11 +32,11 @@ class TorchDoc(BaseDoc): @pytest.mark.parametrize('space', ['cosine', 'l2', 'ip']) -def test_find_simple_schema(qdrant_config, space): # noqa: F811 +def test_find_simple_schema(space): class SimpleSchema(BaseDoc): tens: NdArray[10] = Field(space=space) # type: ignore[valid-type] - index = QdrantDocumentIndex[SimpleSchema](db_config=qdrant_config) + index = QdrantDocumentIndex[SimpleSchema](host='localhost') index_docs = [SimpleDoc(tens=np.zeros(10)) for _ in range(10)] index_docs.append(SimpleDoc(tens=np.ones(10))) @@ -50,9 +50,8 @@ class SimpleSchema(BaseDoc): assert len(scores) == 5 -@pytest.mark.parametrize('space', ['cosine', 'l2', 'ip']) -def test_find_torch(qdrant_config, space): # noqa: F811 - index = QdrantDocumentIndex[TorchDoc](db_config=qdrant_config) +def test_find_torch(): + index = QdrantDocumentIndex[TorchDoc](host='localhost') index_docs = [TorchDoc(tens=np.zeros(10)) for _ in range(10)] index_docs.append(TorchDoc(tens=np.ones(10))) @@ -72,14 +71,13 @@ def test_find_torch(qdrant_config, space): # noqa: F811 @pytest.mark.tensorflow -@pytest.mark.parametrize('space', ['cosine', 'l2', 'ip']) -def test_find_tensorflow(qdrant_config, space): # noqa: F811 +def test_find_tensorflow(): from docarray.typing import TensorFlowTensor class TfDoc(BaseDoc): tens: TensorFlowTensor[10] # type: ignore[valid-type] - index = QdrantDocumentIndex[TfDoc](db_config=qdrant_config) + index = QdrantDocumentIndex[TfDoc](host='localhost') index_docs = [ TfDoc(tens=np.random.rand(10).astype(dtype=np.float32)) for _ in range(10) @@ -101,12 +99,12 @@ class TfDoc(BaseDoc): @pytest.mark.parametrize('space', ['cosine', 'l2', 'ip']) -def test_find_flat_schema(qdrant_config, space): # noqa: F811 +def test_find_flat_schema(space): class FlatSchema(BaseDoc): tens_one: NdArray = Field(dim=10, space=space) tens_two: NdArray = Field(dim=50, space=space) - index = QdrantDocumentIndex[FlatSchema](db_config=qdrant_config) + index = QdrantDocumentIndex[FlatSchema](host='localhost') index_docs = [ FlatDoc(tens_one=np.zeros(10), tens_two=np.zeros(50)) for _ in range(10) @@ -129,7 +127,7 @@ class FlatSchema(BaseDoc): @pytest.mark.parametrize('space', ['cosine', 'l2', 'ip']) -def test_find_nested_schema(qdrant_config, space): # noqa: F811 +def test_find_nested_schema(space): class SimpleDoc(BaseDoc): tens: NdArray[10] = Field(space=space) # type: ignore[valid-type] @@ -141,7 +139,7 @@ class DeepNestedDoc(BaseDoc): d: NestedDoc tens: NdArray = Field(space=space, dim=10) - index = QdrantDocumentIndex[DeepNestedDoc](db_config=qdrant_config) + index = QdrantDocumentIndex[DeepNestedDoc](host='localhost') index_docs = [ DeepNestedDoc( @@ -191,11 +189,13 @@ class DeepNestedDoc(BaseDoc): @pytest.mark.parametrize('space', ['cosine', 'l2', 'ip']) -def test_find_batched(qdrant_config, space): # noqa: F811 +def test_find_batched(space, tmp_collection_name): # noqa: F811 class SimpleSchema(BaseDoc): tens: NdArray[10] = Field(space=space) # type: ignore[valid-type] - index = QdrantDocumentIndex[SimpleSchema](db_config=qdrant_config) + index = QdrantDocumentIndex[SimpleSchema]( + host='localhost', collection_name=tmp_collection_name + ) index_docs = [SimpleDoc(tens=vector) for vector in np.identity(10)] index.index(index_docs) diff --git a/tests/index/qdrant/test_index_get_del.py b/tests/index/qdrant/test_index_get_del.py index a1db816e58c..27cd3ee171b 100644 --- a/tests/index/qdrant/test_index_get_del.py +++ b/tests/index/qdrant/test_index_get_del.py @@ -10,7 +10,7 @@ from docarray.documents import ImageDoc, TextDoc from docarray.index import QdrantDocumentIndex from docarray.typing import NdArray, NdArrayEmbedding, TorchTensor -from tests.index.qdrant.fixtures import qdrant, qdrant_config # noqa: F401 +from tests.index.qdrant.fixtures import start_storage, tmp_collection_name # noqa: F401 pytestmark = [pytest.mark.slow, pytest.mark.index] @@ -56,9 +56,11 @@ def ten_nested_docs(): @pytest.mark.parametrize('use_docarray', [True, False]) def test_index_simple_schema( - ten_simple_docs, qdrant_config, use_docarray # noqa: F811 + ten_simple_docs, use_docarray, tmp_collection_name # noqa: F811 ): - index = QdrantDocumentIndex[SimpleDoc](db_config=qdrant_config) + index = QdrantDocumentIndex[SimpleDoc]( + host='localhost', collection_name=tmp_collection_name + ) if use_docarray: ten_simple_docs = DocList[SimpleDoc](ten_simple_docs) @@ -67,8 +69,12 @@ def test_index_simple_schema( @pytest.mark.parametrize('use_docarray', [True, False]) -def test_index_flat_schema(ten_flat_docs, qdrant_config, use_docarray): # noqa: F811 - index = QdrantDocumentIndex[FlatDoc](db_config=qdrant_config) +def test_index_flat_schema( + ten_flat_docs, use_docarray, tmp_collection_name # noqa: F811 +): + index = QdrantDocumentIndex[FlatDoc]( + host='localhost', collection_name=tmp_collection_name + ) if use_docarray: ten_flat_docs = DocList[FlatDoc](ten_flat_docs) @@ -78,9 +84,11 @@ def test_index_flat_schema(ten_flat_docs, qdrant_config, use_docarray): # noqa: @pytest.mark.parametrize('use_docarray', [True, False]) def test_index_nested_schema( - ten_nested_docs, qdrant_config, use_docarray # noqa: F811 + ten_nested_docs, use_docarray, tmp_collection_name # noqa: F811 ): - index = QdrantDocumentIndex[NestedDoc](db_config=qdrant_config) + index = QdrantDocumentIndex[NestedDoc]( + host='localhost', collection_name=tmp_collection_name + ) if use_docarray: ten_nested_docs = DocList[NestedDoc](ten_nested_docs) @@ -88,24 +96,26 @@ def test_index_nested_schema( assert index.num_docs() == 10 -def test_index_torch(qdrant_config): # noqa: F811 +def test_index_torch(tmp_collection_name): # noqa: F811 docs = [TorchDoc(tens=np.random.randn(10)) for _ in range(10)] assert isinstance(docs[0].tens, torch.Tensor) assert isinstance(docs[0].tens, TorchTensor) - index = QdrantDocumentIndex[TorchDoc](db_config=qdrant_config) + index = QdrantDocumentIndex[TorchDoc]( + host='localhost', collection_name=tmp_collection_name + ) index.index(docs) assert index.num_docs() == 10 @pytest.mark.skip('Qdrant does not support storing image tensors yet') -def test_index_builtin_docs(qdrant_config): # noqa: F811 +def test_index_builtin_docs(): # TextDoc class TextSchema(TextDoc): embedding: Optional[NdArrayEmbedding] = Field(dim=10) - index = QdrantDocumentIndex[TextSchema](db_config=qdrant_config) + index = QdrantDocumentIndex[TextSchema](host='localhost') index.index( DocList[TextDoc]( @@ -133,16 +143,20 @@ class ImageSchema(ImageDoc): assert index.num_docs() == 10 -def test_get_key_error(ten_simple_docs, qdrant_config): # noqa: F811 - index = QdrantDocumentIndex[SimpleDoc](db_config=qdrant_config) +def test_get_key_error(ten_simple_docs, tmp_collection_name): # noqa: F811 + index = QdrantDocumentIndex[SimpleDoc]( + host='localhost', collection_name=tmp_collection_name + ) index.index(ten_simple_docs) with pytest.raises(KeyError): index['not_a_real_id'] -def test_del_single(ten_simple_docs, qdrant_config): # noqa: F811 - index = QdrantDocumentIndex[SimpleDoc](db_config=qdrant_config) +def test_del_single(ten_simple_docs, tmp_collection_name): # noqa: F811 + index = QdrantDocumentIndex[SimpleDoc]( + host='localhost', collection_name=tmp_collection_name + ) index.index(ten_simple_docs) # delete once assert index.num_docs() == 10 @@ -167,10 +181,12 @@ def test_del_single(ten_simple_docs, qdrant_config): # noqa: F811 assert index[id_].id == id_ -def test_del_multiple(ten_simple_docs, qdrant_config): # noqa: F811 +def test_del_multiple(ten_simple_docs, tmp_collection_name): # noqa: F811 docs_to_del_idx = [0, 2, 4, 6, 8] - index = QdrantDocumentIndex[SimpleDoc](db_config=qdrant_config) + index = QdrantDocumentIndex[SimpleDoc]( + host='localhost', collection_name=tmp_collection_name + ) index.index(ten_simple_docs) assert index.num_docs() == 10 @@ -185,16 +201,20 @@ def test_del_multiple(ten_simple_docs, qdrant_config): # noqa: F811 assert index[doc.id].id == doc.id -def test_del_key_error(ten_simple_docs, qdrant_config): # noqa: F811 - index = QdrantDocumentIndex[SimpleDoc](db_config=qdrant_config) +def test_del_key_error(ten_simple_docs, tmp_collection_name): # noqa: F811 + index = QdrantDocumentIndex[SimpleDoc]( + host='localhost', collection_name=tmp_collection_name + ) index.index(ten_simple_docs) with pytest.raises(KeyError): del index['not_a_real_id'] -def test_num_docs(ten_simple_docs, qdrant_config): # noqa: F811 - index = QdrantDocumentIndex[SimpleDoc](db_config=qdrant_config) +def test_num_docs(ten_simple_docs, tmp_collection_name): # noqa: F811 + index = QdrantDocumentIndex[SimpleDoc]( + host='localhost', collection_name=tmp_collection_name + ) index.index(ten_simple_docs) assert index.num_docs() == 10 @@ -213,12 +233,12 @@ def test_num_docs(ten_simple_docs, qdrant_config): # noqa: F811 assert index.num_docs() == 10 -def test_multimodal_doc(qdrant_config): # noqa: F811 +def test_multimodal_doc(): class MyMultiModalDoc(BaseDoc): image: ImageDoc text: TextDoc - index = QdrantDocumentIndex[MyMultiModalDoc](db_config=qdrant_config) + index = QdrantDocumentIndex[MyMultiModalDoc](host='localhost') doc = [ MyMultiModalDoc( @@ -233,3 +253,17 @@ class MyMultiModalDoc(BaseDoc): 0.0 ) assert index[id_].text.text == doc[0].text.text + + +def test_collection_name(): + class TextDoc(BaseDoc): + text: str = Field() + + class StringDoc(BaseDoc): + text: str = Field(col_type='payload') + + index = QdrantDocumentIndex[TextDoc]() + assert index.collection_name == TextDoc.__name__.lower() + + index = QdrantDocumentIndex[StringDoc]() + assert index.collection_name == StringDoc.__name__.lower() diff --git a/tests/index/qdrant/test_persist_data.py b/tests/index/qdrant/test_persist_data.py index 9fd54715a30..cea1fa70295 100644 --- a/tests/index/qdrant/test_persist_data.py +++ b/tests/index/qdrant/test_persist_data.py @@ -5,7 +5,7 @@ from docarray import BaseDoc from docarray.index import QdrantDocumentIndex from docarray.typing import NdArray -from tests.index.qdrant.fixtures import qdrant, qdrant_config # noqa: F401 +from tests.index.qdrant.fixtures import start_storage # noqa: F401 pytestmark = [pytest.mark.slow, pytest.mark.index] @@ -19,18 +19,18 @@ class NestedDoc(BaseDoc): tens: NdArray[50] # type: ignore[valid-type] -def test_persist_and_restore(qdrant_config): # noqa: F811 +def test_persist_and_restore(): query = SimpleDoc(tens=np.random.random((10,))) # create index - index = QdrantDocumentIndex[SimpleDoc](db_config=qdrant_config) + index = QdrantDocumentIndex[SimpleDoc](host='localhost') index.index([SimpleDoc(tens=np.random.random((10,))) for _ in range(10)]) assert index.num_docs() == 10 find_results_before = index.find(query, search_field='tens', limit=5) # delete and restore del index - index = QdrantDocumentIndex[SimpleDoc](db_config=qdrant_config) + index = QdrantDocumentIndex[SimpleDoc](host='localhost') assert index.num_docs() == 10 find_results_after = index.find(query, search_field='tens', limit=5) for doc_before, doc_after in zip(find_results_before[0], find_results_after[0]): @@ -42,13 +42,13 @@ def test_persist_and_restore(qdrant_config): # noqa: F811 assert index.num_docs() == 15 -def test_persist_and_restore_nested(qdrant_config): # noqa: F811 +def test_persist_and_restore_nested(): query = NestedDoc( tens=np.random.random((50,)), d=SimpleDoc(tens=np.random.random((10,))) ) # create index - index = QdrantDocumentIndex[NestedDoc](db_config=qdrant_config) + index = QdrantDocumentIndex[NestedDoc](host='localhost') index.index( [ NestedDoc( @@ -62,7 +62,7 @@ def test_persist_and_restore_nested(qdrant_config): # noqa: F811 # delete and restore del index - index = QdrantDocumentIndex[NestedDoc](db_config=qdrant_config) + index = QdrantDocumentIndex[NestedDoc](host='localhost') assert index.num_docs() == 10 find_results_after = index.find(query, search_field='d__tens', limit=5) for doc_before, doc_after in zip(find_results_before[0], find_results_after[0]): diff --git a/tests/index/weaviate/test_column_config_weaviate.py b/tests/index/weaviate/test_column_config_weaviate.py index 4789a6d707f..fd5a18d7560 100644 --- a/tests/index/weaviate/test_column_config_weaviate.py +++ b/tests/index/weaviate/test_column_config_weaviate.py @@ -34,3 +34,17 @@ class StringDoc(BaseDoc): dbconfig = WeaviateDocumentIndex.DBConfig(index_name="StringDoc") index = WeaviateDocumentIndex[StringDoc](db_config=dbconfig) assert get_text_field_data_type(index, "StringDoc") == "string" + + +def test_index_name(): + class TextDoc(BaseDoc): + text: str = Field() + + class StringDoc(BaseDoc): + text: str = Field(col_type="string") + + index = WeaviateDocumentIndex[TextDoc]() + assert index.index_name == TextDoc.__name__ + + index = WeaviateDocumentIndex[StringDoc]() + assert index.index_name == StringDoc.__name__ diff --git a/tests/index/weaviate/test_index_get_del_weaviate.py b/tests/index/weaviate/test_index_get_del_weaviate.py index ba4d6a27f5e..71b1609f03f 100644 --- a/tests/index/weaviate/test_index_get_del_weaviate.py +++ b/tests/index/weaviate/test_index_get_del_weaviate.py @@ -65,7 +65,7 @@ def test_index(weaviate_client, documents): def test_index_simple_schema(weaviate_client, ten_simple_docs): - index = WeaviateDocumentIndex[SimpleDoc]() + index = WeaviateDocumentIndex[SimpleDoc](index_name="Document") index.index(ten_simple_docs) assert index.num_docs() == 10