From 20d648a2c7c93668ec8d775508000aa105236652 Mon Sep 17 00:00:00 2001 From: Joan Fontanals Martinez Date: Thu, 4 May 2023 16:03:29 +0200 Subject: [PATCH 1/9] feat: index or collection name will default to doc-type name Signed-off-by: Joan Fontanals Martinez --- docarray/index/backends/elastic.py | 14 +++--- docarray/index/backends/qdrant.py | 67 +++++++++++++++-------------- docarray/index/backends/weaviate.py | 34 ++++++++------- 3 files changed, 60 insertions(+), 55 deletions(-) diff --git a/docarray/index/backends/elastic.py b/docarray/index/backends/elastic.py index 7edb04ad83e..ad09da4aa6d 100644 --- a/docarray/index/backends/elastic.py +++ b/docarray/index/backends/elastic.py @@ -74,12 +74,6 @@ def __init__(self, db_config=None, **kwargs): self._logger.debug('Elastic Search index is being initialized') # ElasticSearch client creation - if self._db_config.index_name is None: - id = uuid.uuid4().hex - self._db_config.index_name = 'index__' + id - - self._index_name = self._db_config.index_name - self._client = Elasticsearch( hosts=self._db_config.hosts, **self._db_config.es_config, @@ -108,7 +102,7 @@ def __init__(self, db_config=None, **kwargs): mappings['properties'][col_name] = self._create_index_mapping(col) # print(mappings['properties']) - if self._client.indices.exists(index=self._index_name): + if self._client.indices.exists(index=self.index_name): self._client_put_mapping(mappings) else: self._client_create(mappings) @@ -116,7 +110,11 @@ def __init__(self, db_config=None, **kwargs): if len(self._db_config.index_settings): self._client_put_settings(self._db_config.index_settings) - self._refresh(self._index_name) + self._refresh(self.index_name) + + @property + def index_name(self): + return self._db_config.index_name or TSchema.__name__ ############################################### # Inner classes for query builder and configs # diff --git a/docarray/index/backends/qdrant.py b/docarray/index/backends/qdrant.py index d60e9daf7fa..82b4270802a 100644 
--- a/docarray/index/backends/qdrant.py +++ b/docarray/index/backends/qdrant.py @@ -45,7 +45,6 @@ TSchema = TypeVar('TSchema', bound=BaseDoc) - QDRANT_PY_VECTOR_TYPES: List[Any] = [np.ndarray, AbstractTensor] if torch_imported: import torch @@ -86,6 +85,10 @@ def __init__(self, db_config=None, **kwargs): self._initialize_collection() self._logger.info(f'{self.__class__.__name__} has been initialized') + @property + def collection_name(self): + return self._db_config.collection_name or TSchema.__name__ + @dataclass class Query: """Dataclass describing a query.""" @@ -97,11 +100,11 @@ class Query: class QueryBuilder(BaseDocIndex.QueryBuilder): def __init__( - self, - vector_search_field: Optional[str] = None, - vector_filters: Optional[List[NdArray]] = None, - payload_filters: Optional[List[rest.Filter]] = None, - text_search_filters: Optional[List[Tuple[str, str]]] = None, + self, + vector_search_field: Optional[str] = None, + vector_filters: Optional[List[NdArray]] = None, + payload_filters: Optional[List[rest.Filter]] = None, + text_search_filters: Optional[List[Tuple[str, str]]] = None, ): self._vector_search_field: Optional[str] = vector_search_field self._vector_filters: List[NdArray] = vector_filters or [] @@ -140,7 +143,7 @@ def build(self, limit: int) -> 'QdrantDocumentIndex.Query': ) def find( # type: ignore[override] - self, query: NdArray, search_field: str = '' + self, query: NdArray, search_field: str = '' ) -> 'QdrantDocumentIndex.QueryBuilder': """ Find k-nearest neighbors of the query. 
@@ -163,7 +166,7 @@ def find( # type: ignore[override] ) def filter( # type: ignore[override] - self, filter_query: rest.Filter + self, filter_query: rest.Filter ) -> 'QdrantDocumentIndex.QueryBuilder': """Find documents in the index based on a filter query :param filter_query: a filter @@ -177,7 +180,7 @@ def filter( # type: ignore[override] ) def text_search( # type: ignore[override] - self, query: str, search_field: str = '' + self, query: str, search_field: str = '' ) -> 'QdrantDocumentIndex.QueryBuilder': """Find documents in the index based on a text search query @@ -211,7 +214,7 @@ class DBConfig(BaseDocIndex.DBConfig): timeout: Optional[float] = None host: Optional[str] = None path: Optional[str] = None - collection_name: str = 'documents' + collection_name: Optional[str] = None shard_number: Optional[int] = None replication_factor: Optional[int] = None write_consistency_factor: Optional[int] = None @@ -250,7 +253,7 @@ def python_type_to_db_type(self, python_type: Type) -> Any: def _initialize_collection(self): try: - self._client.get_collection(self._db_config.collection_name) + self._client.get_collection(self.collection_name) except (UnexpectedResponse, RpcError, ValueError): vectors_config = { column_name: self._to_qdrant_vector_params(column_info) @@ -258,7 +261,7 @@ def _initialize_collection(self): if column_info.db_type == 'vector' } self._client.create_collection( - collection_name=self._db_config.collection_name, + collection_name=self.collection_name, vectors_config=vectors_config, shard_number=self._db_config.shard_number, replication_factor=self._db_config.replication_factor, @@ -270,7 +273,7 @@ def _initialize_collection(self): quantization_config=self._db_config.quantization_config, ) self._client.create_payload_index( - collection_name=self._db_config.collection_name, + collection_name=self.collection_name, field_name='__generated_vectors', field_schema=rest.PayloadSchemaType.KEYWORD, ) @@ -280,7 +283,7 @@ def _index(self, column_to_data: 
Dict[str, Generator[Any, None, None]]): # TODO: add batching the documents to avoid timeouts points = [self._build_point_from_row(row) for row in rows] self._client.upsert( - collection_name=self._db_config.collection_name, + collection_name=self.collection_name, points=points, ) @@ -288,7 +291,7 @@ def num_docs(self) -> int: """ Get the number of documents. """ - return self._client.count(collection_name=self._db_config.collection_name).count + return self._client.count(collection_name=self.collection_name).count def _del_items(self, doc_ids: Sequence[str]): items = self._get_items(doc_ids) @@ -298,17 +301,17 @@ def _del_items(self, doc_ids: Sequence[str]): raise KeyError('Document keys could not found: %s' % ','.join(missing_keys)) self._client.delete( - collection_name=self._db_config.collection_name, + collection_name=self.collection_name, points_selector=rest.PointIdsList( points=[self._to_qdrant_id(doc_id) for doc_id in doc_ids], ), ) def _get_items( - self, doc_ids: Sequence[str] + self, doc_ids: Sequence[str] ) -> Union[Sequence[TSchema], Sequence[Dict[str, Any]]]: response, _ = self._client.scroll( - collection_name=self._db_config.collection_name, + collection_name=self.collection_name, scroll_filter=rest.Filter( must=[ rest.HasIdCondition( @@ -343,7 +346,7 @@ def execute_query(self, query: Union[Query, RawQuery], *args, **kwargs) -> DocLi # We perform semantic search with some vectors with Qdrant's search method # should be called points = self._client.search( # type: ignore[assignment] - collection_name=self._db_config.collection_name, + collection_name=self.collection_name, query_vector=(query.vector_field, query.vector_query), # type: ignore[arg-type] query_filter=rest.Filter( must=[query.filter], @@ -364,7 +367,7 @@ def execute_query(self, query: Union[Query, RawQuery], *args, **kwargs) -> DocLi else: # Just filtering, so Qdrant's scroll has to be used instead points, _ = self._client.scroll( # type: ignore[assignment] - 
collection_name=self._db_config.collection_name, + collection_name=self.collection_name, scroll_filter=query.filter, limit=query.limit, with_payload=True, @@ -375,7 +378,7 @@ def execute_query(self, query: Union[Query, RawQuery], *args, **kwargs) -> DocLi return self._dict_list_to_docarray(docs) def _execute_raw_query( - self, query: RawQuery + self, query: RawQuery ) -> Sequence[Union[rest.ScoredPoint, rest.Record]]: payload_filter = query.pop('filter', None) if payload_filter: @@ -388,7 +391,7 @@ def _execute_raw_query( if search_params: search_params = rest.SearchParams.parse_obj(search_params) # type: ignore[assignment] points = self._client.search( # type: ignore[assignment] - collection_name=self._db_config.collection_name, + collection_name=self.collection_name, query_vector=query.pop('vector'), query_filter=payload_filter, search_params=search_params, @@ -397,7 +400,7 @@ def _execute_raw_query( else: # Just filtering, so Qdrant's scroll has to be used instead points, _ = self._client.scroll( # type: ignore[assignment] - collection_name=self._db_config.collection_name, + collection_name=self.collection_name, scroll_filter=payload_filter, **query, ) @@ -405,7 +408,7 @@ def _execute_raw_query( return points def _find( - self, query: np.ndarray, limit: int, search_field: str = '' + self, query: np.ndarray, limit: int, search_field: str = '' ) -> _FindResult: query_batched = np.expand_dims(query, axis=0) docs, scores = self._find_batched( @@ -414,10 +417,10 @@ def _find( return _FindResult(documents=docs[0], scores=scores[0]) # type: ignore[arg-type] def _find_batched( - self, queries: np.ndarray, limit: int, search_field: str = '' + self, queries: np.ndarray, limit: int, search_field: str = '' ) -> _FindResultBatched: responses = self._client.search_batch( - collection_name=self._db_config.collection_name, + collection_name=self.collection_name, requests=[ rest.SearchRequest( vector=rest.NamedVector( @@ -456,21 +459,21 @@ def _find_batched( ) def _filter( - 
self, filter_query: rest.Filter, limit: int + self, filter_query: rest.Filter, limit: int ) -> Union[DocList, List[Dict]]: query_batched = [filter_query] docs = self._filter_batched(filter_queries=query_batched, limit=limit) return docs[0] def _filter_batched( - self, filter_queries: Sequence[rest.Filter], limit: int + self, filter_queries: Sequence[rest.Filter], limit: int ) -> Union[List[DocList], List[List[Dict]]]: responses = [] for filter_query in filter_queries: # There is no batch scroll available in Qdrant client yet, so we need to # perform the queries one by one. It will be changed in the future versions. response, _ = self._client.scroll( - collection_name=self._db_config.collection_name, + collection_name=self.collection_name, scroll_filter=filter_query, limit=limit, with_payload=True, @@ -484,7 +487,7 @@ def _filter_batched( ] def _text_search( - self, query: str, limit: int, search_field: str = '' + self, query: str, limit: int, search_field: str = '' ) -> _FindResult: query_batched = [query] docs, scores = self._text_search_batched( @@ -493,7 +496,7 @@ def _text_search( return _FindResult(documents=docs[0], scores=scores[0]) # type: ignore[arg-type] def _text_search_batched( - self, queries: Sequence[str], limit: int, search_field: str = '' + self, queries: Sequence[str], limit: int, search_field: str = '' ) -> _FindResultBatched: filter_queries = [ rest.Filter( @@ -563,7 +566,7 @@ def _to_qdrant_vector_params(self, column_info: _ColumnInfo) -> rest.VectorParam ) def _convert_to_doc( - self, point: Union[rest.ScoredPoint, rest.Record] + self, point: Union[rest.ScoredPoint, rest.Record] ) -> Dict[str, Any]: document = cast(Dict[str, Any], point.payload) generated_vectors = document.pop('__generated_vectors') diff --git a/docarray/index/backends/weaviate.py b/docarray/index/backends/weaviate.py index 368992645e2..cff992e6665 100644 --- a/docarray/index/backends/weaviate.py +++ b/docarray/index/backends/weaviate.py @@ -111,6 +111,10 @@ def 
__init__(self, db_config=None, **kwargs) -> None: self._set_properties() self._create_schema() + @property + def index_name(self): + return self._db_config.index_name or TSchema.__name__ + def _set_properties(self) -> None: field_overwrites = {"id": DOCUMENTID} @@ -207,13 +211,13 @@ def _create_schema(self) -> None: # and configure replication # we will update base on user feedback schema["properties"] = properties - schema["class"] = self._db_config.index_name + schema["class"] = self.index_name # TODO: Use exists() instead of contains() when available # see https://github.com/weaviate/weaviate-python-client/issues/232 if self._client.schema.contains(schema): logging.warning( - f"Found index {self._db_config.index_name} with schema {schema}. Will reuse existing schema." + f"Found index {self.index_name} with schema {schema}. Will reuse existing schema." ) else: self._client.schema.create_class(schema) @@ -223,7 +227,7 @@ class DBConfig(BaseDocIndex.DBConfig): """Dataclass that contains all "static" configurations of WeaviateDocumentIndex.""" host: str = 'http://localhost:8080' - index_name: str = 'Document' + index_name: Optional[str] = None username: Optional[str] = None password: Optional[str] = None scopes: List[str] = field(default_factory=lambda: ["offline_access"]) @@ -269,7 +273,7 @@ def _del_items(self, doc_ids: Sequence[str]): # see: https://weaviate.io/developers/weaviate/api/rest/batch#maximum-number-of-deletes-per-query while has_matches: results = self._client.batch.delete_objects( - class_name=self._db_config.index_name, + class_name=self.index_name, where=where_filter, ) @@ -279,14 +283,14 @@ def _filter(self, filter_query: Any, limit: int) -> Union[DocList, List[Dict]]: self._overwrite_id(filter_query) results = ( - self._client.query.get(self._db_config.index_name, self.properties) + self._client.query.get(self.index_name, self.properties) .with_additional("vector") .with_where(filter_query) .with_limit(limit) .do() ) - docs = 
results["data"]["Get"][self._db_config.index_name] + docs = results["data"]["Get"][self.index_name] return [self._parse_weaviate_result(doc) for doc in docs] @@ -297,7 +301,7 @@ def _filter_batched( self._overwrite_id(filter_query) qs = [ - self._client.query.get(self._db_config.index_name, self.properties) + self._client.query.get(self.index_name, self.properties) .with_additional("vector") .with_where(filter_query) .with_limit(limit) @@ -370,7 +374,7 @@ def _find( score_name: Literal["certainty", "distance"] = "certainty", score_threshold: Optional[float] = None, ) -> _FindResult: - index_name = self._db_config.index_name + index_name = self.index_name if search_field: logging.warning( 'Argument search_field is not supported for WeaviateDocumentIndex. Ignoring.' @@ -474,7 +478,7 @@ def _find_batched( near_vector[score_name] = score_threshold q = ( - self._client.query.get(self._db_config.index_name, self.properties) + self._client.query.get(self.index_name, self.properties) .with_near_vector(near_vector) .with_limit(limit) .with_additional([score_name, "vector"]) @@ -507,7 +511,7 @@ def _get_items(self, doc_ids: Sequence[str]) -> List[Dict]: } results = ( - self._client.query.get(self._db_config.index_name, self.properties) + self._client.query.get(self.index_name, self.properties) .with_where(where_filter) .with_additional("vector") .do() @@ -515,7 +519,7 @@ def _get_items(self, doc_ids: Sequence[str]) -> List[Dict]: docs = [ self._parse_weaviate_result(doc) - for doc in results["data"]["Get"][self._db_config.index_name] + for doc in results["data"]["Get"][self.index_name] ] return docs @@ -554,7 +558,7 @@ def _parse_weaviate_result(self, result: Dict) -> Dict: def _index(self, column_to_data: Dict[str, Generator[Any, None, None]]): docs = self._transpose_col_value_dict(column_to_data) - index_name = self._db_config.index_name + index_name = self.index_name with self._client.batch as batch: for doc in docs: @@ -577,7 +581,7 @@ def _index(self, column_to_data: 
Dict[str, Generator[Any, None, None]]): def _text_search( self, query: str, limit: int, search_field: str = '' ) -> _FindResult: - index_name = self._db_config.index_name + index_name = self.index_name bm25 = {"query": query, "properties": [search_field]} results = ( @@ -602,7 +606,7 @@ def _text_search_batched( bm25 = {"query": query, "properties": [search_field]} q = ( - self._client.query.get(self._db_config.index_name, self.properties) + self._client.query.get(self.index_name, self.properties) .with_bm25(bm25) .with_limit(limit) .with_additional(["score", "vector"]) @@ -663,7 +667,7 @@ def num_docs(self) -> int: """ Get the number of documents. """ - index_name = self._db_config.index_name + index_name = self.index_name result = self._client.query.aggregate(index_name).with_meta_count().do() # TODO: decorator to check for errors total_docs = result["data"]["Aggregate"][index_name][0]["meta"]["count"] From 18ea080e8b0857cccd326838eca255cc16b014d1 Mon Sep 17 00:00:00 2001 From: Joan Fontanals Martinez Date: Thu, 4 May 2023 16:09:10 +0200 Subject: [PATCH 2/9] test: add tests for new default names Signed-off-by: Joan Fontanals Martinez --- docarray/index/backends/elastic.py | 10 ++++- docarray/index/backends/elasticv7.py | 14 +++--- docarray/index/backends/qdrant.py | 45 +++++++++++-------- docarray/index/backends/weaviate.py | 9 +++- tests/index/elastic/v7/test_column_config.py | 9 ++++ tests/index/elastic/v8/test_find.py | 13 +++++- tests/index/qdrant/test_index_get_del.py | 9 ++++ .../weaviate/test_column_config_weaviate.py | 9 ++++ 8 files changed, 89 insertions(+), 29 deletions(-) diff --git a/docarray/index/backends/elastic.py b/docarray/index/backends/elastic.py index ad09da4aa6d..60304504d83 100644 --- a/docarray/index/backends/elastic.py +++ b/docarray/index/backends/elastic.py @@ -1,5 +1,4 @@ # mypy: ignore-errors -import uuid import warnings from collections import defaultdict from dataclasses import dataclass, field @@ -114,7 +113,14 @@ def 
__init__(self, db_config=None, **kwargs): @property def index_name(self): - return self._db_config.index_name or TSchema.__name__ + default_index_name = self._schema.__name__ if self._schema is not None else None + if default_index_name is None: + raise ValueError( + 'A ElasticDocIndex must be typed with a Document type.' + 'To do so, use the syntax: ElasticDocIndex[DocumentType]' + ) + + return self._db_config.index_name or default_index_name ############################################### # Inner classes for query builder and configs # diff --git a/docarray/index/backends/elasticv7.py b/docarray/index/backends/elasticv7.py index 623f11053bb..1782e921f62 100644 --- a/docarray/index/backends/elasticv7.py +++ b/docarray/index/backends/elasticv7.py @@ -119,7 +119,7 @@ def execute_query(self, query: Dict[str, Any], *args, **kwargs) -> Any: f'args and kwargs not supported for `execute_query` on {type(self)}' ) - resp = self._client.search(index=self._index_name, body=query) + resp = self._client.search(index=self.index_name, body=query) docs, scores = self._format_response(resp) return _FindResult(documents=docs, scores=parse_obj_as(NdArray, scores)) @@ -161,20 +161,20 @@ def _form_search_body(self, query: np.ndarray, limit: int, search_field: str = ' ############################################### def _client_put_mapping(self, mappings: Dict[str, Any]): - self._client.indices.put_mapping(index=self._index_name, body=mappings) + self._client.indices.put_mapping(index=self.index_name, body=mappings) def _client_create(self, mappings: Dict[str, Any]): body = {'mappings': mappings} - self._client.indices.create(index=self._index_name, body=body) + self._client.indices.create(index=self.index_name, body=body) def _client_put_settings(self, settings: Dict[str, Any]): - self._client.indices.put_settings(index=self._index_name, body=settings) + self._client.indices.put_settings(index=self.index_name, body=settings) def _client_mget(self, ids: Sequence[str]): - return 
self._client.mget(index=self._index_name, body={'ids': ids}) + return self._client.mget(index=self.index_name, body={'ids': ids}) def _client_search(self, **kwargs): - return self._client.search(index=self._index_name, body=kwargs) + return self._client.search(index=self.index_name, body=kwargs) def _client_msearch(self, request: List[Dict[str, Any]]): - return self._client.msearch(index=self._index_name, body=request) + return self._client.msearch(index=self.index_name, body=request) diff --git a/docarray/index/backends/qdrant.py b/docarray/index/backends/qdrant.py index 82b4270802a..2aebe7eb39f 100644 --- a/docarray/index/backends/qdrant.py +++ b/docarray/index/backends/qdrant.py @@ -87,7 +87,16 @@ def __init__(self, db_config=None, **kwargs): @property def collection_name(self): - return self._db_config.collection_name or TSchema.__name__ + default_collection_name = ( + self._schema.__name__ if self._schema is not None else None + ) + if default_collection_name is None: + raise ValueError( + 'A QdrantDocumentIndex must be typed with a Document type.' 
+ 'To do so, use the syntax: QdrantDocumentIndex[DocumentType]' + ) + + return self._db_config.collection_name or default_collection_name @dataclass class Query: @@ -100,11 +109,11 @@ class Query: class QueryBuilder(BaseDocIndex.QueryBuilder): def __init__( - self, - vector_search_field: Optional[str] = None, - vector_filters: Optional[List[NdArray]] = None, - payload_filters: Optional[List[rest.Filter]] = None, - text_search_filters: Optional[List[Tuple[str, str]]] = None, + self, + vector_search_field: Optional[str] = None, + vector_filters: Optional[List[NdArray]] = None, + payload_filters: Optional[List[rest.Filter]] = None, + text_search_filters: Optional[List[Tuple[str, str]]] = None, ): self._vector_search_field: Optional[str] = vector_search_field self._vector_filters: List[NdArray] = vector_filters or [] @@ -143,7 +152,7 @@ def build(self, limit: int) -> 'QdrantDocumentIndex.Query': ) def find( # type: ignore[override] - self, query: NdArray, search_field: str = '' + self, query: NdArray, search_field: str = '' ) -> 'QdrantDocumentIndex.QueryBuilder': """ Find k-nearest neighbors of the query. 
@@ -166,7 +175,7 @@ def find( # type: ignore[override] ) def filter( # type: ignore[override] - self, filter_query: rest.Filter + self, filter_query: rest.Filter ) -> 'QdrantDocumentIndex.QueryBuilder': """Find documents in the index based on a filter query :param filter_query: a filter @@ -180,7 +189,7 @@ def filter( # type: ignore[override] ) def text_search( # type: ignore[override] - self, query: str, search_field: str = '' + self, query: str, search_field: str = '' ) -> 'QdrantDocumentIndex.QueryBuilder': """Find documents in the index based on a text search query @@ -308,7 +317,7 @@ def _del_items(self, doc_ids: Sequence[str]): ) def _get_items( - self, doc_ids: Sequence[str] + self, doc_ids: Sequence[str] ) -> Union[Sequence[TSchema], Sequence[Dict[str, Any]]]: response, _ = self._client.scroll( collection_name=self.collection_name, @@ -378,7 +387,7 @@ def execute_query(self, query: Union[Query, RawQuery], *args, **kwargs) -> DocLi return self._dict_list_to_docarray(docs) def _execute_raw_query( - self, query: RawQuery + self, query: RawQuery ) -> Sequence[Union[rest.ScoredPoint, rest.Record]]: payload_filter = query.pop('filter', None) if payload_filter: @@ -408,7 +417,7 @@ def _execute_raw_query( return points def _find( - self, query: np.ndarray, limit: int, search_field: str = '' + self, query: np.ndarray, limit: int, search_field: str = '' ) -> _FindResult: query_batched = np.expand_dims(query, axis=0) docs, scores = self._find_batched( @@ -417,7 +426,7 @@ def _find( return _FindResult(documents=docs[0], scores=scores[0]) # type: ignore[arg-type] def _find_batched( - self, queries: np.ndarray, limit: int, search_field: str = '' + self, queries: np.ndarray, limit: int, search_field: str = '' ) -> _FindResultBatched: responses = self._client.search_batch( collection_name=self.collection_name, @@ -459,14 +468,14 @@ def _find_batched( ) def _filter( - self, filter_query: rest.Filter, limit: int + self, filter_query: rest.Filter, limit: int ) -> 
Union[DocList, List[Dict]]: query_batched = [filter_query] docs = self._filter_batched(filter_queries=query_batched, limit=limit) return docs[0] def _filter_batched( - self, filter_queries: Sequence[rest.Filter], limit: int + self, filter_queries: Sequence[rest.Filter], limit: int ) -> Union[List[DocList], List[List[Dict]]]: responses = [] for filter_query in filter_queries: @@ -487,7 +496,7 @@ def _filter_batched( ] def _text_search( - self, query: str, limit: int, search_field: str = '' + self, query: str, limit: int, search_field: str = '' ) -> _FindResult: query_batched = [query] docs, scores = self._text_search_batched( @@ -496,7 +505,7 @@ def _text_search( return _FindResult(documents=docs[0], scores=scores[0]) # type: ignore[arg-type] def _text_search_batched( - self, queries: Sequence[str], limit: int, search_field: str = '' + self, queries: Sequence[str], limit: int, search_field: str = '' ) -> _FindResultBatched: filter_queries = [ rest.Filter( @@ -566,7 +575,7 @@ def _to_qdrant_vector_params(self, column_info: _ColumnInfo) -> rest.VectorParam ) def _convert_to_doc( - self, point: Union[rest.ScoredPoint, rest.Record] + self, point: Union[rest.ScoredPoint, rest.Record] ) -> Dict[str, Any]: document = cast(Dict[str, Any], point.payload) generated_vectors = document.pop('__generated_vectors') diff --git a/docarray/index/backends/weaviate.py b/docarray/index/backends/weaviate.py index cff992e6665..c4b33e028d8 100644 --- a/docarray/index/backends/weaviate.py +++ b/docarray/index/backends/weaviate.py @@ -113,7 +113,14 @@ def __init__(self, db_config=None, **kwargs) -> None: @property def index_name(self): - return self._db_config.index_name or TSchema.__name__ + default_index_name = self._schema.__name__ if self._schema is not None else None + if default_index_name is None: + raise ValueError( + 'A WeaviateDocumentIndex must be typed with a Document type.' 
+ 'To do so, use the syntax: WeaviateDocumentIndex[DocumentType]' + ) + + return self._db_config.index_name or default_index_name def _set_properties(self) -> None: field_overwrites = {"id": DOCUMENTID} diff --git a/tests/index/elastic/v7/test_column_config.py b/tests/index/elastic/v7/test_column_config.py index f1fa93d7748..652d4dbcdf1 100644 --- a/tests/index/elastic/v7/test_column_config.py +++ b/tests/index/elastic/v7/test_column_config.py @@ -129,3 +129,12 @@ class MyDoc(BaseDoc): } docs, _ = index.execute_query(query) assert [doc['id'] for doc in docs] == [doc[0].id, doc[1].id] + + +def test_index_name(): + class MyDoc(BaseDoc): + expected_attendees: dict = Field(col_type='integer_range') + time_frame: dict = Field(col_type='date_range', format='yyyy-MM-dd') + + index = ElasticV7DocIndex[MyDoc]() + assert index.index_name == MyDoc.__name__ diff --git a/tests/index/elastic/v8/test_find.py b/tests/index/elastic/v8/test_find.py index bb87755254c..b923d5ab09b 100644 --- a/tests/index/elastic/v8/test_find.py +++ b/tests/index/elastic/v8/test_find.py @@ -272,7 +272,9 @@ class MyDoc(BaseDoc): index = ElasticDocIndex[MyDoc]() index_docs = [ - MyDoc(id=f'{i}', tens=np.ones(10) * i, num=int(i / 2), text=f'text {int(i/2)}') + MyDoc( + id=f'{i}', tens=np.ones(10) * i, num=int(i / 2), text=f'text {int(i / 2)}' + ) for i in range(10) ] index.index(index_docs) @@ -327,3 +329,12 @@ class MyDoc(BaseDoc): docs, _ = index.execute_query(query) assert [doc['id'] for doc in docs] == ['7', '6', '5', '4'] + + +def test_index_name(): + class MyDoc(BaseDoc): + expected_attendees: dict = Field(col_type='integer_range') + time_frame: dict = Field(col_type='date_range', format='yyyy-MM-dd') + + index = ElasticDocIndex[MyDoc]() + assert index.index_name == MyDoc.__name__ diff --git a/tests/index/qdrant/test_index_get_del.py b/tests/index/qdrant/test_index_get_del.py index a1db816e58c..b409f41b991 100644 --- a/tests/index/qdrant/test_index_get_del.py +++ 
b/tests/index/qdrant/test_index_get_del.py @@ -233,3 +233,12 @@ class MyMultiModalDoc(BaseDoc): 0.0 ) assert index[id_].text.text == doc[0].text.text + + +def test_collection_name(): + class MyDoc(BaseDoc): + expected_attendees: dict = Field(col_type='integer_range') + time_frame: dict = Field(col_type='date_range', format='yyyy-MM-dd') + + index = QdrantDocumentIndex[MyDoc]() + assert index.collection_name == MyDoc.__name__ diff --git a/tests/index/weaviate/test_column_config_weaviate.py b/tests/index/weaviate/test_column_config_weaviate.py index 4789a6d707f..3290cac68bd 100644 --- a/tests/index/weaviate/test_column_config_weaviate.py +++ b/tests/index/weaviate/test_column_config_weaviate.py @@ -34,3 +34,12 @@ class StringDoc(BaseDoc): dbconfig = WeaviateDocumentIndex.DBConfig(index_name="StringDoc") index = WeaviateDocumentIndex[StringDoc](db_config=dbconfig) assert get_text_field_data_type(index, "StringDoc") == "string" + + +def test_index_name(): + class MyDoc(BaseDoc): + expected_attendees: dict = Field(col_type='integer_range') + time_frame: dict = Field(col_type='date_range', format='yyyy-MM-dd') + + index = WeaviateDocumentIndex[MyDoc]() + assert index.index_name == MyDoc.__name__ From 64a07f35749645952f4b6eef460d32a9fc2cb085 Mon Sep 17 00:00:00 2001 From: Johannes Messner Date: Thu, 4 May 2023 16:55:18 +0200 Subject: [PATCH 3/9] docs: add explanation for default table name Signed-off-by: Johannes Messner --- docs/user_guide/storing/docindex.md | 7 ++++++- docs/user_guide/storing/index_weaviate.md | 16 ++++++++-------- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/docs/user_guide/storing/docindex.md b/docs/user_guide/storing/docindex.md index 2240b06cc86..ad2ee2b96d2 100644 --- a/docs/user_guide/storing/docindex.md +++ b/docs/user_guide/storing/docindex.md @@ -161,7 +161,12 @@ db.index(data) For `HnswDocumentIndex` you need to specify a `work_dir` where the data will be stored; for other backends you usually specify a `host` and a `port` 
instead. -Either way, if the location does not yet contain any data, we start from a blank slate. +In addition to a host and a port, most backends can also take an `index_name`, `table_name`, `collection_name` or similar. +This specifies the name of the index/table/collection that will be created in the database. +You don't have to specify this though: By default, this name will be taken from the name of the Document type that you use as schema. +For example, for `WeaviateDocumentIndex[MyDoc](...)` the data will be stored in a Weaviate Class of name `MyDoc`. + +In any case, if the location does not yet contain any data, we start from a blank slate. If the location already contains data from a previous session, it will be accessible through the Document Index. ## Index data diff --git a/docs/user_guide/storing/index_weaviate.md b/docs/user_guide/storing/index_weaviate.md index 83d84dd8fa9..09aaccc69b7 100644 --- a/docs/user_guide/storing/index_weaviate.md +++ b/docs/user_guide/storing/index_weaviate.md @@ -203,18 +203,18 @@ And a list of environment variables is [available on this page](https://weaviate Additionally, you can specify the below settings when you instantiate a configuration object in DocArray. -| name | type | explanation | default | example | -| ---- | ---- | ----------- | ------- | ------- | +| name | type | explanation | default | example | +| ---- | ---- | ----------- |------------------------------------------------------------------------| ------- | | **Category: General** | -| host | str | Weaviate instance url | http://localhost:8080 | +| host | str | Weaviate instance url | http://localhost:8080 | | **Category: Authentication** | -| username | str | Username known to the specified authentication provider (e.g. 
WCS) | None | `jp@weaviate.io` | -| password | str | Corresponding password | None | `p@ssw0rd` | -| auth_api_key | str | API key known to the Weaviate instance | None | `mys3cretk3y` | +| username | str | Username known to the specified authentication provider (e.g. WCS) | None | `jp@weaviate.io` | +| password | str | Corresponding password | None | `p@ssw0rd` | +| auth_api_key | str | API key known to the Weaviate instance | None | `mys3cretk3y` | | **Category: Data schema** | -| index_name | str | Class name to use to store the document | `Document` | +| index_name | str | Class name to use to store the document| The document class name, e.g. `MyDoc` for `WeaviateDocumentIndex[MyDoc]` | `Document` | | **Category: Embedded Weaviate** | -| embedded_options| EmbeddedOptions | Options for embedded weaviate | None | +| embedded_options| EmbeddedOptions | Options for embedded weaviate | None | The type `EmbeddedOptions` can be specified as described [here](https://weaviate.io/developers/weaviate/installation/embedded#embedded-options) From 44fcd80cae5f96c08dcb3fc6f383ba414d92a929 Mon Sep 17 00:00:00 2001 From: Joan Fontanals Martinez Date: Thu, 4 May 2023 17:04:36 +0200 Subject: [PATCH 4/9] docs: update documentation Signed-off-by: Joan Fontanals Martinez --- docarray/index/backends/elastic.py | 34 ++++++++++--------- docarray/index/backends/qdrant.py | 2 +- docarray/index/backends/weaviate.py | 2 +- docs/user_guide/storing/index_elastic.md | 2 +- docs/user_guide/storing/index_qdrant.md | 4 +++ tests/index/elastic/v7/test_column_config.py | 15 +++++--- tests/index/elastic/v8/test_column_config.py | 14 ++++++++ tests/index/qdrant/test_index_get_del.py | 15 +++++--- .../weaviate/test_column_config_weaviate.py | 15 +++++--- 9 files changed, 69 insertions(+), 34 deletions(-) diff --git a/docarray/index/backends/elastic.py b/docarray/index/backends/elastic.py index 60304504d83..c4da3ad5e0d 100644 --- a/docarray/index/backends/elastic.py +++ 
b/docarray/index/backends/elastic.py @@ -113,7 +113,9 @@ def __init__(self, db_config=None, **kwargs): @property def index_name(self): - default_index_name = self._schema.__name__ if self._schema is not None else None + default_index_name = ( + self._schema.__name__.lower() if self._schema is not None else None + ) if default_index_name is None: raise ValueError( 'A ElasticDocIndex must be typed with a Document type.' @@ -337,7 +339,7 @@ def _index( for row in data: request = { - '_index': self._index_name, + '_index': self.index_name, '_id': row['id'], } for col_name, col in self._column_infos.items(): @@ -353,13 +355,13 @@ def _index( warnings.warn(str(info)) if refresh: - self._refresh(self._index_name) + self._refresh(self.index_name) def num_docs(self) -> int: """ Get the number of documents. """ - return self._client.count(index=self._index_name)['count'] + return self._client.count(index=self.index_name)['count'] def _del_items( self, @@ -369,7 +371,7 @@ def _del_items( requests = [] for _id in doc_ids: requests.append( - {'_op_type': 'delete', '_index': self._index_name, '_id': _id} + {'_op_type': 'delete', '_index': self.index_name, '_id': _id} ) _, warning_info = self._send_requests(requests, chunk_size) @@ -379,7 +381,7 @@ def _del_items( ids = [info['delete']['_id'] for info in warning_info] warnings.warn(f'No document with id {ids} found') - self._refresh(self._index_name) + self._refresh(self.index_name) def _get_items(self, doc_ids: Sequence[str]) -> Sequence[TSchema]: accumulated_docs = [] @@ -420,7 +422,7 @@ def execute_query(self, query: Dict[str, Any], *args, **kwargs) -> Any: f'args and kwargs not supported for `execute_query` on {type(self)}' ) - resp = self._client.search(index=self._index_name, **query) + resp = self._client.search(index=self.index_name, **query) docs, scores = self._format_response(resp) return _FindResult(documents=docs, scores=parse_obj_as(NdArray, scores)) @@ -444,7 +446,7 @@ def _find_batched( ) -> _FindResultBatched: 
request = [] for query in queries: - head = {'index': self._index_name} + head = {'index': self.index_name} body = self._form_search_body(query, limit, search_field) request.extend([head, body]) @@ -473,7 +475,7 @@ def _filter_batched( ) -> List[List[Dict]]: request = [] for query in filter_queries: - head = {'index': self._index_name} + head = {'index': self.index_name} body = {'query': query, 'size': limit} request.extend([head, body]) @@ -503,7 +505,7 @@ def _text_search_batched( ) -> _FindResultBatched: request = [] for query in queries: - head = {'index': self._index_name} + head = {'index': self.index_name} body = self._form_text_search_body(query, limit, search_field) request.extend([head, body]) @@ -619,20 +621,20 @@ def _refresh(self, index_name: str): def _client_put_mapping(self, mappings: Dict[str, Any]): self._client.indices.put_mapping( - index=self._index_name, properties=mappings['properties'] + index=self.index_name, properties=mappings['properties'] ) def _client_create(self, mappings: Dict[str, Any]): - self._client.indices.create(index=self._index_name, mappings=mappings) + self._client.indices.create(index=self.index_name, mappings=mappings) def _client_put_settings(self, settings: Dict[str, Any]): - self._client.indices.put_settings(index=self._index_name, settings=settings) + self._client.indices.put_settings(index=self.index_name, settings=settings) def _client_mget(self, ids: Sequence[str]): - return self._client.mget(index=self._index_name, ids=ids) + return self._client.mget(index=self.index_name, ids=ids) def _client_search(self, **kwargs): - return self._client.search(index=self._index_name, **kwargs) + return self._client.search(index=self.index_name, **kwargs) def _client_msearch(self, request: List[Dict[str, Any]]): - return self._client.msearch(index=self._index_name, searches=request) + return self._client.msearch(index=self.index_name, searches=request) diff --git a/docarray/index/backends/qdrant.py 
b/docarray/index/backends/qdrant.py index 2aebe7eb39f..e2e503d593c 100644 --- a/docarray/index/backends/qdrant.py +++ b/docarray/index/backends/qdrant.py @@ -88,7 +88,7 @@ def __init__(self, db_config=None, **kwargs): @property def collection_name(self): default_collection_name = ( - self._schema.__name__ if self._schema is not None else None + self._schema.__name__.lower() if self._schema is not None else None ) if default_collection_name is None: raise ValueError( diff --git a/docarray/index/backends/weaviate.py b/docarray/index/backends/weaviate.py index c4b33e028d8..5179f8cb588 100644 --- a/docarray/index/backends/weaviate.py +++ b/docarray/index/backends/weaviate.py @@ -745,7 +745,7 @@ class QueryBuilder(BaseDocIndex.QueryBuilder): def __init__(self, document_index): self._queries = [ document_index._client.query.get( - document_index._db_config.index_name, document_index.properties + document_index.index_name, document_index.properties ) ] diff --git a/docs/user_guide/storing/index_elastic.md b/docs/user_guide/storing/index_elastic.md index a21528c46b8..eb1186f6d12 100644 --- a/docs/user_guide/storing/index_elastic.md +++ b/docs/user_guide/storing/index_elastic.md @@ -420,7 +420,7 @@ The following configs can be set in `DBConfig`: |-------------------|----------------------------------------------------------------------------------------------------------------------------------------|-------------------------| | `hosts` | Hostname of the Elasticsearch server | `http://localhost:9200` | | `es_config` | Other ES [configuration options](https://www.elastic.co/guide/en/elasticsearch/client/python-api/8.6/config.html) in a Dict and pass to `Elasticsearch` client constructor, e.g. `cloud_id`, `api_key` | None | -| `index_name` | Elasticsearch index name, the name of Elasticsearch index object | None | +| `index_name` | Elasticsearch index name, the name of Elasticsearch index object | None. 
Data will be stored in an index named after the Document type used as schema. | | `index_settings` | Other [index settings](https://www.elastic.co/guide/en/elasticsearch/reference/8.6/index-modules.html#index-modules-settings) in a Dict for creating the index | dict | | `index_mappings` | Other [index mappings](https://www.elastic.co/guide/en/elasticsearch/reference/8.6/mapping.html) in a Dict for creating the index | dict | diff --git a/docs/user_guide/storing/index_qdrant.md b/docs/user_guide/storing/index_qdrant.md index 7d832f1dd67..01249d01b7a 100644 --- a/docs/user_guide/storing/index_qdrant.md +++ b/docs/user_guide/storing/index_qdrant.md @@ -27,6 +27,10 @@ For general usage of a Document Index, see the [general user guide](./docindex.m runtime_config = QdrantDocumentIndex.RuntimeConfig() print(runtime_config) # shows default values ``` + + Note that the collection_name from the DBConfig is an Optional[str] with None as default value. This is because + the QdrantDocumentIndex will take the name of the Document type that you use as schema. For example, for QdrantDocumentIndex[MyDoc](...) + the data will be stored in a collection named MyDoc if no specific collection_name is passed in the DBConfig. 
```python import numpy as np diff --git a/tests/index/elastic/v7/test_column_config.py b/tests/index/elastic/v7/test_column_config.py index 652d4dbcdf1..61e82db00b0 100644 --- a/tests/index/elastic/v7/test_column_config.py +++ b/tests/index/elastic/v7/test_column_config.py @@ -132,9 +132,14 @@ class MyDoc(BaseDoc): def test_index_name(): - class MyDoc(BaseDoc): - expected_attendees: dict = Field(col_type='integer_range') - time_frame: dict = Field(col_type='date_range', format='yyyy-MM-dd') + class TextDoc(BaseDoc): + text: str = Field() - index = ElasticV7DocIndex[MyDoc]() - assert index.index_name == MyDoc.__name__ + class StringDoc(BaseDoc): + text: str = Field(col_type="string") + + index = ElasticV7DocIndex[TextDoc]() + assert index.index_name == TextDoc.__name__.lower() + + index = ElasticV7DocIndex[StringDoc]() + assert index.index_name == StringDoc.__name__.lower() diff --git a/tests/index/elastic/v8/test_column_config.py b/tests/index/elastic/v8/test_column_config.py index 0edd105697d..4853985fb7f 100644 --- a/tests/index/elastic/v8/test_column_config.py +++ b/tests/index/elastic/v8/test_column_config.py @@ -129,3 +129,17 @@ class MyDoc(BaseDoc): } docs, _ = index.execute_query(query) assert [doc['id'] for doc in docs] == [doc[0].id, doc[1].id] + + +def test_index_name(): + class TextDoc(BaseDoc): + text: str = Field() + + class StringDoc(BaseDoc): + text: str = Field(col_type="string") + + index = ElasticDocIndex[TextDoc]() + assert index.index_name == TextDoc.__name__.lower() + + index = ElasticDocIndex[StringDoc]() + assert index.index_name == StringDoc.__name__.lower() diff --git a/tests/index/qdrant/test_index_get_del.py b/tests/index/qdrant/test_index_get_del.py index b409f41b991..7a5f316dc47 100644 --- a/tests/index/qdrant/test_index_get_del.py +++ b/tests/index/qdrant/test_index_get_del.py @@ -236,9 +236,14 @@ class MyMultiModalDoc(BaseDoc): def test_collection_name(): - class MyDoc(BaseDoc): - expected_attendees: dict = 
Field(col_type='integer_range') - time_frame: dict = Field(col_type='date_range', format='yyyy-MM-dd') + class TextDoc(BaseDoc): + text: str = Field() - index = QdrantDocumentIndex[MyDoc]() - assert index.collection_name == MyDoc.__name__ + class StringDoc(BaseDoc): + text: str = Field(col_type="string") + + index = QdrantDocumentIndex[TextDoc]() + assert index.collection_name == TextDoc.__name__.lower() + + index = QdrantDocumentIndex[StringDoc]() + assert index.collection_name == StringDoc.__name__.lower() diff --git a/tests/index/weaviate/test_column_config_weaviate.py b/tests/index/weaviate/test_column_config_weaviate.py index 3290cac68bd..fd5a18d7560 100644 --- a/tests/index/weaviate/test_column_config_weaviate.py +++ b/tests/index/weaviate/test_column_config_weaviate.py @@ -37,9 +37,14 @@ class StringDoc(BaseDoc): def test_index_name(): - class MyDoc(BaseDoc): - expected_attendees: dict = Field(col_type='integer_range') - time_frame: dict = Field(col_type='date_range', format='yyyy-MM-dd') + class TextDoc(BaseDoc): + text: str = Field() + + class StringDoc(BaseDoc): + text: str = Field(col_type="string") + + index = WeaviateDocumentIndex[TextDoc]() + assert index.index_name == TextDoc.__name__ - index = WeaviateDocumentIndex[MyDoc]() - assert index.index_name == MyDoc.__name__ + index = WeaviateDocumentIndex[StringDoc]() + assert index.index_name == StringDoc.__name__ From 4d2c5233f751f75e4f76c2c5b228ca8d1b2e1db8 Mon Sep 17 00:00:00 2001 From: AnneY Date: Fri, 5 May 2023 20:41:11 +0800 Subject: [PATCH 5/9] fix: elastic tests Signed-off-by: AnneY --- tests/index/elastic/fixture.py | 6 ++ tests/index/elastic/v7/test_column_config.py | 2 +- tests/index/elastic/v7/test_find.py | 12 ++-- tests/index/elastic/v7/test_index_get_del.py | 61 ++++++++++++-------- tests/index/elastic/v8/test_column_config.py | 5 +- tests/index/elastic/v8/test_find.py | 22 ++++--- tests/index/elastic/v8/test_index_get_del.py | 61 ++++++++++++-------- 7 files changed, 103 insertions(+), 66 
deletions(-) diff --git a/tests/index/elastic/fixture.py b/tests/index/elastic/fixture.py index 812f0f09d51..ef7766acd0c 100644 --- a/tests/index/elastic/fixture.py +++ b/tests/index/elastic/fixture.py @@ -1,5 +1,6 @@ import os import time +import uuid import numpy as np import pytest @@ -87,3 +88,8 @@ def ten_deep_nested_docs(): DeepNestedDoc(d=NestedDoc(d=SimpleDoc(tens=np.random.randn(10)))) for _ in range(10) ] + + +@pytest.fixture(scope='function') +def tmp_index_name(): + return uuid.uuid4().hex diff --git a/tests/index/elastic/v7/test_column_config.py b/tests/index/elastic/v7/test_column_config.py index 61e82db00b0..812734aef26 100644 --- a/tests/index/elastic/v7/test_column_config.py +++ b/tests/index/elastic/v7/test_column_config.py @@ -136,7 +136,7 @@ class TextDoc(BaseDoc): text: str = Field() class StringDoc(BaseDoc): - text: str = Field(col_type="string") + text: str = Field(col_type='text') index = ElasticV7DocIndex[TextDoc]() assert index.index_name == TextDoc.__name__.lower() diff --git a/tests/index/elastic/v7/test_find.py b/tests/index/elastic/v7/test_find.py index d54b3b0480d..e82eff3015a 100644 --- a/tests/index/elastic/v7/test_find.py +++ b/tests/index/elastic/v7/test_find.py @@ -6,8 +6,12 @@ from docarray import BaseDoc from docarray.index import ElasticV7DocIndex from docarray.typing import NdArray, TorchTensor -from tests.index.elastic.fixture import start_storage_v7 # noqa: F401 -from tests.index.elastic.fixture import FlatDoc, SimpleDoc +from tests.index.elastic.fixture import ( # noqa: F401 + FlatDoc, + SimpleDoc, + start_storage_v7, + tmp_index_name, +) pytestmark = [pytest.mark.slow, pytest.mark.index] @@ -243,13 +247,13 @@ class MyDoc(BaseDoc): assert doc.text.index(query) >= 0 -def test_query_builder(): +def test_query_builder(tmp_index_name): # noqa: F811 class MyDoc(BaseDoc): tens: NdArray[10] num: int text: str - index = ElasticV7DocIndex[MyDoc]() + index = ElasticV7DocIndex[MyDoc](index_name=tmp_index_name) index_docs = [ MyDoc( 
id=f'{i}', tens=np.random.rand(10), num=int(i / 2), text=f'text {int(i/2)}' diff --git a/tests/index/elastic/v7/test_index_get_del.py b/tests/index/elastic/v7/test_index_get_del.py index e6e6baf3e60..27cf3b5642d 100644 --- a/tests/index/elastic/v7/test_index_get_del.py +++ b/tests/index/elastic/v7/test_index_get_del.py @@ -18,14 +18,17 @@ ten_flat_docs, ten_nested_docs, ten_simple_docs, + tmp_index_name, ) pytestmark = [pytest.mark.slow, pytest.mark.index] @pytest.mark.parametrize('use_docarray', [True, False]) -def test_index_simple_schema(ten_simple_docs, use_docarray): # noqa: F811 - index = ElasticV7DocIndex[SimpleDoc]() +def test_index_simple_schema( + ten_simple_docs, use_docarray, tmp_index_name # noqa: F811 +): + index = ElasticV7DocIndex[SimpleDoc](index_name=tmp_index_name) if use_docarray: ten_simple_docs = DocList[SimpleDoc](ten_simple_docs) @@ -34,8 +37,8 @@ def test_index_simple_schema(ten_simple_docs, use_docarray): # noqa: F811 @pytest.mark.parametrize('use_docarray', [True, False]) -def test_index_flat_schema(ten_flat_docs, use_docarray): # noqa: F811 - index = ElasticV7DocIndex[FlatDoc]() +def test_index_flat_schema(ten_flat_docs, use_docarray, tmp_index_name): # noqa: F811 + index = ElasticV7DocIndex[FlatDoc](index_name=tmp_index_name) if use_docarray: ten_flat_docs = DocList[FlatDoc](ten_flat_docs) @@ -44,8 +47,10 @@ def test_index_flat_schema(ten_flat_docs, use_docarray): # noqa: F811 @pytest.mark.parametrize('use_docarray', [True, False]) -def test_index_nested_schema(ten_nested_docs, use_docarray): # noqa: F811 - index = ElasticV7DocIndex[NestedDoc]() +def test_index_nested_schema( + ten_nested_docs, use_docarray, tmp_index_name # noqa: F811 +): + index = ElasticV7DocIndex[NestedDoc](index_name=tmp_index_name) if use_docarray: ten_nested_docs = DocList[NestedDoc](ten_nested_docs) @@ -54,8 +59,10 @@ def test_index_nested_schema(ten_nested_docs, use_docarray): # noqa: F811 @pytest.mark.parametrize('use_docarray', [True, False]) -def 
test_index_deep_nested_schema(ten_deep_nested_docs, use_docarray): # noqa: F811 - index = ElasticV7DocIndex[DeepNestedDoc]() +def test_index_deep_nested_schema( + ten_deep_nested_docs, use_docarray, tmp_index_name # noqa: F811 +): + index = ElasticV7DocIndex[DeepNestedDoc](index_name=tmp_index_name) if use_docarray: ten_deep_nested_docs = DocList[DeepNestedDoc](ten_deep_nested_docs) @@ -63,9 +70,11 @@ def test_index_deep_nested_schema(ten_deep_nested_docs, use_docarray): # noqa: assert index.num_docs() == 10 -def test_get_single(ten_simple_docs, ten_flat_docs, ten_nested_docs): # noqa: F811 +def test_get_single( + ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_index_name # noqa: F811 +): # simple - index = ElasticV7DocIndex[SimpleDoc]() + index = ElasticV7DocIndex[SimpleDoc](index_name=tmp_index_name) index.index(ten_simple_docs) assert index.num_docs() == 10 @@ -75,7 +84,7 @@ def test_get_single(ten_simple_docs, ten_flat_docs, ten_nested_docs): # noqa: F assert np.all(index[id_].tens == d.tens) # flat - index = ElasticV7DocIndex[FlatDoc]() + index = ElasticV7DocIndex[FlatDoc](index_name=tmp_index_name + 'flat') index.index(ten_flat_docs) assert index.num_docs() == 10 @@ -86,7 +95,7 @@ def test_get_single(ten_simple_docs, ten_flat_docs, ten_nested_docs): # noqa: F assert np.all(index[id_].tens_two == d.tens_two) # nested - index = ElasticV7DocIndex[NestedDoc]() + index = ElasticV7DocIndex[NestedDoc](index_name=tmp_index_name + 'nested') index.index(ten_nested_docs) assert index.num_docs() == 10 @@ -97,11 +106,13 @@ def test_get_single(ten_simple_docs, ten_flat_docs, ten_nested_docs): # noqa: F assert np.all(index[id_].d.tens == d.d.tens) -def test_get_multiple(ten_simple_docs, ten_flat_docs, ten_nested_docs): # noqa: F811 +def test_get_multiple( + ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_index_name # noqa: F811 +): docs_to_get_idx = [0, 2, 4, 6, 8] # simple - index = ElasticV7DocIndex[SimpleDoc]() + index = 
ElasticV7DocIndex[SimpleDoc](index_name=tmp_index_name) index.index(ten_simple_docs) assert index.num_docs() == 10 @@ -113,7 +124,7 @@ def test_get_multiple(ten_simple_docs, ten_flat_docs, ten_nested_docs): # noqa: assert np.all(d_out.tens == d_in.tens) # flat - index = ElasticV7DocIndex[FlatDoc]() + index = ElasticV7DocIndex[FlatDoc](index_name=tmp_index_name + 'flat') index.index(ten_flat_docs) assert index.num_docs() == 10 @@ -126,7 +137,7 @@ def test_get_multiple(ten_simple_docs, ten_flat_docs, ten_nested_docs): # noqa: assert np.all(d_out.tens_two == d_in.tens_two) # nested - index = ElasticV7DocIndex[NestedDoc]() + index = ElasticV7DocIndex[NestedDoc](index_name=tmp_index_name + 'nested') index.index(ten_nested_docs) assert index.num_docs() == 10 @@ -147,16 +158,16 @@ def test_get_key_error(ten_simple_docs): # noqa: F811 index['not_a_real_id'] -def test_persisting(ten_simple_docs): # noqa: F811 - index = ElasticV7DocIndex[SimpleDoc](index_name='test_persisting') +def test_persisting(ten_simple_docs, tmp_index_name): # noqa: F811 + index = ElasticV7DocIndex[SimpleDoc](index_name=tmp_index_name) index.index(ten_simple_docs) - index2 = ElasticV7DocIndex[SimpleDoc](index_name='test_persisting') + index2 = ElasticV7DocIndex[SimpleDoc](index_name=tmp_index_name) assert index2.num_docs() == 10 -def test_del_single(ten_simple_docs): # noqa: F811 - index = ElasticV7DocIndex[SimpleDoc]() +def test_del_single(ten_simple_docs, tmp_index_name): # noqa: F811 + index = ElasticV7DocIndex[SimpleDoc](index_name=tmp_index_name) index.index(ten_simple_docs) # delete once assert index.num_docs() == 10 @@ -183,10 +194,10 @@ def test_del_single(ten_simple_docs): # noqa: F811 assert np.all(index[id_].tens == d.tens) -def test_del_multiple(ten_simple_docs): # noqa: F811 +def test_del_multiple(ten_simple_docs, tmp_index_name): # noqa: F811 docs_to_del_idx = [0, 2, 4, 6, 8] - index = ElasticV7DocIndex[SimpleDoc]() + index = ElasticV7DocIndex[SimpleDoc](index_name=tmp_index_name) 
index.index(ten_simple_docs) assert index.num_docs() == 10 @@ -210,8 +221,8 @@ def test_del_key_error(ten_simple_docs): # noqa: F811 del index['not_a_real_id'] -def test_num_docs(ten_simple_docs): # noqa: F811 - index = ElasticV7DocIndex[SimpleDoc]() +def test_num_docs(ten_simple_docs, tmp_index_name): # noqa: F811 + index = ElasticV7DocIndex[SimpleDoc](index_name=tmp_index_name) index.index(ten_simple_docs) assert index.num_docs() == 10 diff --git a/tests/index/elastic/v8/test_column_config.py b/tests/index/elastic/v8/test_column_config.py index 4853985fb7f..852d018db50 100644 --- a/tests/index/elastic/v8/test_column_config.py +++ b/tests/index/elastic/v8/test_column_config.py @@ -3,7 +3,8 @@ from docarray import BaseDoc from docarray.index import ElasticDocIndex -from tests.index.elastic.fixture import start_storage_v8 # noqa: F401 + +# from tests.index.elastic.fixture import start_storage_v8 # noqa: F401 pytestmark = [pytest.mark.slow, pytest.mark.index, pytest.mark.elasticv8] @@ -136,7 +137,7 @@ class TextDoc(BaseDoc): text: str = Field() class StringDoc(BaseDoc): - text: str = Field(col_type="string") + text: str = Field(col_type='text') index = ElasticDocIndex[TextDoc]() assert index.index_name == TextDoc.__name__.lower() diff --git a/tests/index/elastic/v8/test_find.py b/tests/index/elastic/v8/test_find.py index b923d5ab09b..f3cf6d6119a 100644 --- a/tests/index/elastic/v8/test_find.py +++ b/tests/index/elastic/v8/test_find.py @@ -6,18 +6,22 @@ from docarray import BaseDoc from docarray.index import ElasticDocIndex from docarray.typing import NdArray, TorchTensor -from tests.index.elastic.fixture import start_storage_v8 # noqa: F401 -from tests.index.elastic.fixture import FlatDoc, SimpleDoc +from tests.index.elastic.fixture import ( # noqa: F401 + FlatDoc, + SimpleDoc, + start_storage_v8, + tmp_index_name, +) pytestmark = [pytest.mark.slow, pytest.mark.index, pytest.mark.elasticv8] @pytest.mark.parametrize('similarity', ['cosine', 'l2_norm', 'dot_product']) 
-def test_find_simple_schema(similarity): +def test_find_simple_schema(similarity, tmp_index_name): # noqa: F811 class SimpleSchema(BaseDoc): tens: NdArray[10] = Field(similarity=similarity) - index = ElasticDocIndex[SimpleSchema]() + index = ElasticDocIndex[SimpleSchema](index_name=tmp_index_name) index_docs = [] for _ in range(10): @@ -37,12 +41,12 @@ class SimpleSchema(BaseDoc): @pytest.mark.parametrize('similarity', ['cosine', 'l2_norm', 'dot_product']) -def test_find_flat_schema(similarity): +def test_find_flat_schema(similarity, tmp_index_name): # noqa: F811 class FlatSchema(BaseDoc): tens_one: NdArray = Field(dims=10, similarity=similarity) tens_two: NdArray = Field(dims=50, similarity=similarity) - index = ElasticDocIndex[FlatSchema]() + index = ElasticDocIndex[FlatSchema](index_name=tmp_index_name) index_docs = [] for _ in range(10): @@ -75,7 +79,7 @@ class FlatSchema(BaseDoc): @pytest.mark.parametrize('similarity', ['cosine', 'l2_norm', 'dot_product']) -def test_find_nested_schema(similarity): +def test_find_nested_schema(similarity, tmp_index_name): # noqa: F811 class SimpleDoc(BaseDoc): tens: NdArray[10] = Field(similarity=similarity) @@ -87,7 +91,7 @@ class DeepNestedDoc(BaseDoc): d: NestedDoc tens: NdArray = Field(similarity=similarity, dims=10) - index = ElasticDocIndex[DeepNestedDoc]() + index = ElasticDocIndex[DeepNestedDoc](index_name=tmp_index_name) index_docs = [] for _ in range(10): @@ -337,4 +341,4 @@ class MyDoc(BaseDoc): time_frame: dict = Field(col_type='date_range', format='yyyy-MM-dd') index = ElasticDocIndex[MyDoc]() - assert index.index_name == MyDoc.__name__ + assert index.index_name == MyDoc.__name__.lower() diff --git a/tests/index/elastic/v8/test_index_get_del.py b/tests/index/elastic/v8/test_index_get_del.py index 8efd66429b0..4e34b712fcb 100644 --- a/tests/index/elastic/v8/test_index_get_del.py +++ b/tests/index/elastic/v8/test_index_get_del.py @@ -18,14 +18,17 @@ ten_flat_docs, ten_nested_docs, ten_simple_docs, + tmp_index_name, 
) pytestmark = [pytest.mark.slow, pytest.mark.index, pytest.mark.elasticv8] @pytest.mark.parametrize('use_docarray', [True, False]) -def test_index_simple_schema(ten_simple_docs, use_docarray): # noqa: F811 - index = ElasticDocIndex[SimpleDoc]() +def test_index_simple_schema( + ten_simple_docs, use_docarray, tmp_index_name # noqa: F811 +): + index = ElasticDocIndex[SimpleDoc](index_name=tmp_index_name) if use_docarray: ten_simple_docs = DocList[SimpleDoc](ten_simple_docs) @@ -34,8 +37,8 @@ def test_index_simple_schema(ten_simple_docs, use_docarray): # noqa: F811 @pytest.mark.parametrize('use_docarray', [True, False]) -def test_index_flat_schema(ten_flat_docs, use_docarray): # noqa: F811 - index = ElasticDocIndex[FlatDoc]() +def test_index_flat_schema(ten_flat_docs, use_docarray, tmp_index_name): # noqa: F811 + index = ElasticDocIndex[FlatDoc](index_name=tmp_index_name) if use_docarray: ten_flat_docs = DocList[FlatDoc](ten_flat_docs) @@ -44,8 +47,10 @@ def test_index_flat_schema(ten_flat_docs, use_docarray): # noqa: F811 @pytest.mark.parametrize('use_docarray', [True, False]) -def test_index_nested_schema(ten_nested_docs, use_docarray): # noqa: F811 - index = ElasticDocIndex[NestedDoc]() +def test_index_nested_schema( + ten_nested_docs, use_docarray, tmp_index_name # noqa: F811 +): + index = ElasticDocIndex[NestedDoc](index_name=tmp_index_name) if use_docarray: ten_nested_docs = DocList[NestedDoc](ten_nested_docs) @@ -54,8 +59,10 @@ def test_index_nested_schema(ten_nested_docs, use_docarray): # noqa: F811 @pytest.mark.parametrize('use_docarray', [True, False]) -def test_index_deep_nested_schema(ten_deep_nested_docs, use_docarray): # noqa: F811 - index = ElasticDocIndex[DeepNestedDoc]() +def test_index_deep_nested_schema( + ten_deep_nested_docs, use_docarray, tmp_index_name # noqa: F811 +): + index = ElasticDocIndex[DeepNestedDoc](index_name=tmp_index_name) if use_docarray: ten_deep_nested_docs = DocList[DeepNestedDoc](ten_deep_nested_docs) @@ -63,9 +70,11 @@ def 
test_index_deep_nested_schema(ten_deep_nested_docs, use_docarray): # noqa: assert index.num_docs() == 10 -def test_get_single(ten_simple_docs, ten_flat_docs, ten_nested_docs): # noqa: F811 +def test_get_single( + ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_index_name # noqa: F811 +): # simple - index = ElasticDocIndex[SimpleDoc]() + index = ElasticDocIndex[SimpleDoc](index_name=tmp_index_name) index.index(ten_simple_docs) assert index.num_docs() == 10 @@ -75,7 +84,7 @@ def test_get_single(ten_simple_docs, ten_flat_docs, ten_nested_docs): # noqa: F assert np.all(index[id_].tens == d.tens) # flat - index = ElasticDocIndex[FlatDoc]() + index = ElasticDocIndex[FlatDoc](index_name=tmp_index_name + 'flat') index.index(ten_flat_docs) assert index.num_docs() == 10 @@ -86,7 +95,7 @@ def test_get_single(ten_simple_docs, ten_flat_docs, ten_nested_docs): # noqa: F assert np.all(index[id_].tens_two == d.tens_two) # nested - index = ElasticDocIndex[NestedDoc]() + index = ElasticDocIndex[NestedDoc](index_name=tmp_index_name + 'nested') index.index(ten_nested_docs) assert index.num_docs() == 10 @@ -97,11 +106,13 @@ def test_get_single(ten_simple_docs, ten_flat_docs, ten_nested_docs): # noqa: F assert np.all(index[id_].d.tens == d.d.tens) -def test_get_multiple(ten_simple_docs, ten_flat_docs, ten_nested_docs): # noqa: F811 +def test_get_multiple( + ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_index_name # noqa: F811 +): docs_to_get_idx = [0, 2, 4, 6, 8] # simple - index = ElasticDocIndex[SimpleDoc]() + index = ElasticDocIndex[SimpleDoc](index_name=tmp_index_name) index.index(ten_simple_docs) assert index.num_docs() == 10 @@ -113,7 +124,7 @@ def test_get_multiple(ten_simple_docs, ten_flat_docs, ten_nested_docs): # noqa: assert np.all(d_out.tens == d_in.tens) # flat - index = ElasticDocIndex[FlatDoc]() + index = ElasticDocIndex[FlatDoc](index_name=tmp_index_name + 'flat') index.index(ten_flat_docs) assert index.num_docs() == 10 @@ -126,7 +137,7 @@ def 
test_get_multiple(ten_simple_docs, ten_flat_docs, ten_nested_docs): # noqa: assert np.all(d_out.tens_two == d_in.tens_two) # nested - index = ElasticDocIndex[NestedDoc]() + index = ElasticDocIndex[NestedDoc](index_name=tmp_index_name + 'nested') index.index(ten_nested_docs) assert index.num_docs() == 10 @@ -147,16 +158,16 @@ def test_get_key_error(ten_simple_docs): # noqa: F811 index['not_a_real_id'] -def test_persisting(ten_simple_docs): # noqa: F811 - index = ElasticDocIndex[SimpleDoc](index_name='test_persisting') +def test_persisting(ten_simple_docs, tmp_index_name): # noqa: F811 + index = ElasticDocIndex[SimpleDoc](index_name=tmp_index_name) index.index(ten_simple_docs) - index2 = ElasticDocIndex[SimpleDoc](index_name='test_persisting') + index2 = ElasticDocIndex[SimpleDoc](index_name=tmp_index_name) assert index2.num_docs() == 10 -def test_del_single(ten_simple_docs): # noqa: F811 - index = ElasticDocIndex[SimpleDoc]() +def test_del_single(ten_simple_docs, tmp_index_name): # noqa: F811 + index = ElasticDocIndex[SimpleDoc](index_name=tmp_index_name) index.index(ten_simple_docs) # delete once assert index.num_docs() == 10 @@ -183,10 +194,10 @@ def test_del_single(ten_simple_docs): # noqa: F811 assert np.all(index[id_].tens == d.tens) -def test_del_multiple(ten_simple_docs): # noqa: F811 +def test_del_multiple(ten_simple_docs, tmp_index_name): # noqa: F811 docs_to_del_idx = [0, 2, 4, 6, 8] - index = ElasticDocIndex[SimpleDoc]() + index = ElasticDocIndex[SimpleDoc](index_name=tmp_index_name) index.index(ten_simple_docs) assert index.num_docs() == 10 @@ -210,8 +221,8 @@ def test_del_key_error(ten_simple_docs): # noqa: F811 del index['not_a_real_id'] -def test_num_docs(ten_simple_docs): # noqa: F811 - index = ElasticDocIndex[SimpleDoc]() +def test_num_docs(ten_simple_docs, tmp_index_name): # noqa: F811 + index = ElasticDocIndex[SimpleDoc](index_name=tmp_index_name) index.index(ten_simple_docs) assert index.num_docs() == 10 From 
77827ee75e78e68778d94a8f7fbcf74fa98e7a19 Mon Sep 17 00:00:00 2001 From: AnneY Date: Fri, 5 May 2023 20:55:39 +0800 Subject: [PATCH 6/9] fix: weaviate tests Signed-off-by: AnneY --- tests/index/weaviate/test_index_get_del_weaviate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/index/weaviate/test_index_get_del_weaviate.py b/tests/index/weaviate/test_index_get_del_weaviate.py index ba4d6a27f5e..71b1609f03f 100644 --- a/tests/index/weaviate/test_index_get_del_weaviate.py +++ b/tests/index/weaviate/test_index_get_del_weaviate.py @@ -65,7 +65,7 @@ def test_index(weaviate_client, documents): def test_index_simple_schema(weaviate_client, ten_simple_docs): - index = WeaviateDocumentIndex[SimpleDoc]() + index = WeaviateDocumentIndex[SimpleDoc](index_name="Document") index.index(ten_simple_docs) assert index.num_docs() == 10 From 7358f5d83877f05a9f05d80cb282ee7e58b73bd8 Mon Sep 17 00:00:00 2001 From: AnneY Date: Fri, 5 May 2023 21:28:22 +0800 Subject: [PATCH 7/9] fix: qdrant tests Signed-off-by: AnneY --- tests/index/qdrant/docker-compose.yml | 12 +++++ tests/index/qdrant/fixtures.py | 21 ++++++++ tests/index/qdrant/test_find.py | 30 +++++------ tests/index/qdrant/test_index_get_del.py | 66 +++++++++++++++--------- tests/index/qdrant/test_persist_data.py | 22 +++++--- 5 files changed, 105 insertions(+), 46 deletions(-) create mode 100644 tests/index/qdrant/docker-compose.yml diff --git a/tests/index/qdrant/docker-compose.yml b/tests/index/qdrant/docker-compose.yml new file mode 100644 index 00000000000..a3a57b7c87f --- /dev/null +++ b/tests/index/qdrant/docker-compose.yml @@ -0,0 +1,12 @@ +version: '3.8' + +services: + qdrant: + image: qdrant/qdrant:v1.1.2 + ports: + - "6333:6333" + - "6334:6334" + ulimits: # Only required for tests, as there are a lot of collections created + nofile: + soft: 65535 + hard: 65535 \ No newline at end of file diff --git a/tests/index/qdrant/fixtures.py b/tests/index/qdrant/fixtures.py index d44a0950d35..d48372ad0f0 
100644 --- a/tests/index/qdrant/fixtures.py +++ b/tests/index/qdrant/fixtures.py @@ -1,8 +1,29 @@ +import os +import time +import uuid + import pytest import qdrant_client from docarray.index import QdrantDocumentIndex +cur_dir = os.path.dirname(os.path.abspath(__file__)) +qdrant_yml = os.path.abspath(os.path.join(cur_dir, 'docker-compose.yml')) + + +@pytest.fixture(scope='session', autouse=True) +def start_storage(): + os.system(f"docker-compose -f {qdrant_yml} up -d --remove-orphans") + time.sleep(1) + + yield + os.system(f"docker-compose -f {qdrant_yml} down --remove-orphans") + + +@pytest.fixture(scope='function') +def tmp_collection_name(): + return uuid.uuid4().hex + @pytest.fixture def qdrant() -> qdrant_client.QdrantClient: diff --git a/tests/index/qdrant/test_find.py b/tests/index/qdrant/test_find.py index 610695e5c81..4dd2c27ba25 100644 --- a/tests/index/qdrant/test_find.py +++ b/tests/index/qdrant/test_find.py @@ -5,7 +5,7 @@ from docarray import BaseDoc, DocList from docarray.index import QdrantDocumentIndex from docarray.typing import NdArray, TorchTensor -from tests.index.qdrant.fixtures import qdrant, qdrant_config # noqa: F401 +from tests.index.qdrant.fixtures import start_storage, tmp_collection_name # noqa: F401 pytestmark = [pytest.mark.slow, pytest.mark.index] @@ -32,11 +32,11 @@ class TorchDoc(BaseDoc): @pytest.mark.parametrize('space', ['cosine', 'l2', 'ip']) -def test_find_simple_schema(qdrant_config, space): # noqa: F811 +def test_find_simple_schema(space): class SimpleSchema(BaseDoc): tens: NdArray[10] = Field(space=space) # type: ignore[valid-type] - index = QdrantDocumentIndex[SimpleSchema](db_config=qdrant_config) + index = QdrantDocumentIndex[SimpleSchema](host='localhost') index_docs = [SimpleDoc(tens=np.zeros(10)) for _ in range(10)] index_docs.append(SimpleDoc(tens=np.ones(10))) @@ -50,9 +50,8 @@ class SimpleSchema(BaseDoc): assert len(scores) == 5 -@pytest.mark.parametrize('space', ['cosine', 'l2', 'ip']) -def 
test_find_torch(qdrant_config, space): # noqa: F811 - index = QdrantDocumentIndex[TorchDoc](db_config=qdrant_config) +def test_find_torch(): + index = QdrantDocumentIndex[TorchDoc](host='localhost') index_docs = [TorchDoc(tens=np.zeros(10)) for _ in range(10)] index_docs.append(TorchDoc(tens=np.ones(10))) @@ -72,14 +71,13 @@ def test_find_torch(qdrant_config, space): # noqa: F811 @pytest.mark.tensorflow -@pytest.mark.parametrize('space', ['cosine', 'l2', 'ip']) -def test_find_tensorflow(qdrant_config, space): # noqa: F811 +def test_find_tensorflow(): from docarray.typing import TensorFlowTensor class TfDoc(BaseDoc): tens: TensorFlowTensor[10] # type: ignore[valid-type] - index = QdrantDocumentIndex[TfDoc](db_config=qdrant_config) + index = QdrantDocumentIndex[TfDoc](host='localhost') index_docs = [ TfDoc(tens=np.random.rand(10).astype(dtype=np.float32)) for _ in range(10) @@ -101,12 +99,12 @@ class TfDoc(BaseDoc): @pytest.mark.parametrize('space', ['cosine', 'l2', 'ip']) -def test_find_flat_schema(qdrant_config, space): # noqa: F811 +def test_find_flat_schema(space): class FlatSchema(BaseDoc): tens_one: NdArray = Field(dim=10, space=space) tens_two: NdArray = Field(dim=50, space=space) - index = QdrantDocumentIndex[FlatSchema](db_config=qdrant_config) + index = QdrantDocumentIndex[FlatSchema](host='localhost') index_docs = [ FlatDoc(tens_one=np.zeros(10), tens_two=np.zeros(50)) for _ in range(10) @@ -129,7 +127,7 @@ class FlatSchema(BaseDoc): @pytest.mark.parametrize('space', ['cosine', 'l2', 'ip']) -def test_find_nested_schema(qdrant_config, space): # noqa: F811 +def test_find_nested_schema(space): class SimpleDoc(BaseDoc): tens: NdArray[10] = Field(space=space) # type: ignore[valid-type] @@ -141,7 +139,7 @@ class DeepNestedDoc(BaseDoc): d: NestedDoc tens: NdArray = Field(space=space, dim=10) - index = QdrantDocumentIndex[DeepNestedDoc](db_config=qdrant_config) + index = QdrantDocumentIndex[DeepNestedDoc](host='localhost') index_docs = [ DeepNestedDoc( @@ -191,11 
+189,13 @@ class DeepNestedDoc(BaseDoc): @pytest.mark.parametrize('space', ['cosine', 'l2', 'ip']) -def test_find_batched(qdrant_config, space): # noqa: F811 +def test_find_batched(space, tmp_collection_name): # noqa: F811 class SimpleSchema(BaseDoc): tens: NdArray[10] = Field(space=space) # type: ignore[valid-type] - index = QdrantDocumentIndex[SimpleSchema](db_config=qdrant_config) + index = QdrantDocumentIndex[SimpleSchema]( + host='localhost', collection_name=tmp_collection_name + ) index_docs = [SimpleDoc(tens=vector) for vector in np.identity(10)] index.index(index_docs) diff --git a/tests/index/qdrant/test_index_get_del.py b/tests/index/qdrant/test_index_get_del.py index 7a5f316dc47..643b989ceac 100644 --- a/tests/index/qdrant/test_index_get_del.py +++ b/tests/index/qdrant/test_index_get_del.py @@ -10,7 +10,7 @@ from docarray.documents import ImageDoc, TextDoc from docarray.index import QdrantDocumentIndex from docarray.typing import NdArray, NdArrayEmbedding, TorchTensor -from tests.index.qdrant.fixtures import qdrant, qdrant_config # noqa: F401 +from tests.index.qdrant.fixtures import start_storage, tmp_collection_name # noqa: F401 pytestmark = [pytest.mark.slow, pytest.mark.index] @@ -56,9 +56,11 @@ def ten_nested_docs(): @pytest.mark.parametrize('use_docarray', [True, False]) def test_index_simple_schema( - ten_simple_docs, qdrant_config, use_docarray # noqa: F811 + ten_simple_docs, use_docarray, tmp_collection_name # noqa: F811 ): - index = QdrantDocumentIndex[SimpleDoc](db_config=qdrant_config) + index = QdrantDocumentIndex[SimpleDoc]( + host='localhost', collection_name=tmp_collection_name + ) if use_docarray: ten_simple_docs = DocList[SimpleDoc](ten_simple_docs) @@ -67,8 +69,12 @@ def test_index_simple_schema( @pytest.mark.parametrize('use_docarray', [True, False]) -def test_index_flat_schema(ten_flat_docs, qdrant_config, use_docarray): # noqa: F811 - index = QdrantDocumentIndex[FlatDoc](db_config=qdrant_config) +def test_index_flat_schema( + 
ten_flat_docs, use_docarray, tmp_collection_name # noqa: F811 +): + index = QdrantDocumentIndex[FlatDoc]( + host='localhost', collection_name=tmp_collection_name + ) if use_docarray: ten_flat_docs = DocList[FlatDoc](ten_flat_docs) @@ -78,9 +84,11 @@ def test_index_flat_schema(ten_flat_docs, qdrant_config, use_docarray): # noqa: @pytest.mark.parametrize('use_docarray', [True, False]) def test_index_nested_schema( - ten_nested_docs, qdrant_config, use_docarray # noqa: F811 + ten_nested_docs, use_docarray, tmp_collection_name # noqa: F811 ): - index = QdrantDocumentIndex[NestedDoc](db_config=qdrant_config) + index = QdrantDocumentIndex[NestedDoc]( + host='localhost', collection_name=tmp_collection_name + ) if use_docarray: ten_nested_docs = DocList[NestedDoc](ten_nested_docs) @@ -88,24 +96,28 @@ def test_index_nested_schema( assert index.num_docs() == 10 -def test_index_torch(qdrant_config): # noqa: F811 +def test_index_torch(tmp_collection_name): # noqa: F811 docs = [TorchDoc(tens=np.random.randn(10)) for _ in range(10)] assert isinstance(docs[0].tens, torch.Tensor) assert isinstance(docs[0].tens, TorchTensor) - index = QdrantDocumentIndex[TorchDoc](db_config=qdrant_config) + index = QdrantDocumentIndex[TorchDoc]( + host='localhost', collection_name=tmp_collection_name + ) index.index(docs) assert index.num_docs() == 10 @pytest.mark.skip('Qdrant does not support storing image tensors yet') -def test_index_builtin_docs(qdrant_config): # noqa: F811 +def test_index_builtin_docs(tmp_collection_name): # noqa: F811 # TextDoc class TextSchema(TextDoc): embedding: Optional[NdArrayEmbedding] = Field(dim=10) - index = QdrantDocumentIndex[TextSchema](db_config=qdrant_config) + index = QdrantDocumentIndex[TextSchema]( + host='localhost', collection_name=tmp_collection_name + ) index.index( DocList[TextDoc]( @@ -133,16 +145,18 @@ class ImageSchema(ImageDoc): assert index.num_docs() == 10 -def test_get_key_error(ten_simple_docs, qdrant_config): # noqa: F811 - index = 
QdrantDocumentIndex[SimpleDoc](db_config=qdrant_config) +def test_get_key_error(ten_simple_docs): + index = QdrantDocumentIndex[SimpleDoc](host='localhost') index.index(ten_simple_docs) with pytest.raises(KeyError): index['not_a_real_id'] -def test_del_single(ten_simple_docs, qdrant_config): # noqa: F811 - index = QdrantDocumentIndex[SimpleDoc](db_config=qdrant_config) +def test_del_single(ten_simple_docs, tmp_collection_name): # noqa: F811 + index = QdrantDocumentIndex[SimpleDoc]( + host='localhost', collection_name=tmp_collection_name + ) index.index(ten_simple_docs) # delete once assert index.num_docs() == 10 @@ -167,10 +181,12 @@ def test_del_single(ten_simple_docs, qdrant_config): # noqa: F811 assert index[id_].id == id_ -def test_del_multiple(ten_simple_docs, qdrant_config): # noqa: F811 +def test_del_multiple(ten_simple_docs, tmp_collection_name): # noqa: F811 docs_to_del_idx = [0, 2, 4, 6, 8] - index = QdrantDocumentIndex[SimpleDoc](db_config=qdrant_config) + index = QdrantDocumentIndex[SimpleDoc]( + host='localhost', collection_name=tmp_collection_name + ) index.index(ten_simple_docs) assert index.num_docs() == 10 @@ -185,16 +201,18 @@ def test_del_multiple(ten_simple_docs, qdrant_config): # noqa: F811 assert index[doc.id].id == doc.id -def test_del_key_error(ten_simple_docs, qdrant_config): # noqa: F811 - index = QdrantDocumentIndex[SimpleDoc](db_config=qdrant_config) +def test_del_key_error(ten_simple_docs): # noqa: F811 + index = QdrantDocumentIndex[SimpleDoc](host='localhost') index.index(ten_simple_docs) with pytest.raises(KeyError): del index['not_a_real_id'] -def test_num_docs(ten_simple_docs, qdrant_config): # noqa: F811 - index = QdrantDocumentIndex[SimpleDoc](db_config=qdrant_config) +def test_num_docs(ten_simple_docs, tmp_collection_name): # noqa: F811 + index = QdrantDocumentIndex[SimpleDoc]( + host='localhost', collection_name=tmp_collection_name + ) index.index(ten_simple_docs) assert index.num_docs() == 10 @@ -213,12 +231,12 @@ def 
test_num_docs(ten_simple_docs, qdrant_config): # noqa: F811 assert index.num_docs() == 10 -def test_multimodal_doc(qdrant_config): # noqa: F811 +def test_multimodal_doc(): # noqa: F811 class MyMultiModalDoc(BaseDoc): image: ImageDoc text: TextDoc - index = QdrantDocumentIndex[MyMultiModalDoc](db_config=qdrant_config) + index = QdrantDocumentIndex[MyMultiModalDoc](host='localhost') doc = [ MyMultiModalDoc( @@ -240,7 +258,7 @@ class TextDoc(BaseDoc): text: str = Field() class StringDoc(BaseDoc): - text: str = Field(col_type="string") + text: str = Field(col_type='payload') index = QdrantDocumentIndex[TextDoc]() assert index.collection_name == TextDoc.__name__.lower() diff --git a/tests/index/qdrant/test_persist_data.py b/tests/index/qdrant/test_persist_data.py index 9fd54715a30..88ab2ec342a 100644 --- a/tests/index/qdrant/test_persist_data.py +++ b/tests/index/qdrant/test_persist_data.py @@ -5,7 +5,7 @@ from docarray import BaseDoc from docarray.index import QdrantDocumentIndex from docarray.typing import NdArray -from tests.index.qdrant.fixtures import qdrant, qdrant_config # noqa: F401 +from tests.index.qdrant.fixtures import start_storage, tmp_collection_name # noqa: F401 pytestmark = [pytest.mark.slow, pytest.mark.index] @@ -19,18 +19,22 @@ class NestedDoc(BaseDoc): tens: NdArray[50] # type: ignore[valid-type] -def test_persist_and_restore(qdrant_config): # noqa: F811 +def test_persist_and_restore(tmp_collection_name): # noqa: F811 query = SimpleDoc(tens=np.random.random((10,))) # create index - index = QdrantDocumentIndex[SimpleDoc](db_config=qdrant_config) + index = QdrantDocumentIndex[SimpleDoc]( + host='localhost', collection_name=tmp_collection_name + ) index.index([SimpleDoc(tens=np.random.random((10,))) for _ in range(10)]) assert index.num_docs() == 10 find_results_before = index.find(query, search_field='tens', limit=5) # delete and restore del index - index = QdrantDocumentIndex[SimpleDoc](db_config=qdrant_config) + index = 
QdrantDocumentIndex[SimpleDoc]( + host='localhost', collection_name=tmp_collection_name + ) assert index.num_docs() == 10 find_results_after = index.find(query, search_field='tens', limit=5) for doc_before, doc_after in zip(find_results_before[0], find_results_after[0]): @@ -42,13 +46,15 @@ def test_persist_and_restore(qdrant_config): # noqa: F811 assert index.num_docs() == 15 -def test_persist_and_restore_nested(qdrant_config): # noqa: F811 +def test_persist_and_restore_nested(tmp_collection_name): # noqa: F811 query = NestedDoc( tens=np.random.random((50,)), d=SimpleDoc(tens=np.random.random((10,))) ) # create index - index = QdrantDocumentIndex[NestedDoc](db_config=qdrant_config) + index = QdrantDocumentIndex[NestedDoc]( + host='localhost', collection_name=tmp_collection_name + ) index.index( [ NestedDoc( @@ -62,7 +68,9 @@ def test_persist_and_restore_nested(qdrant_config): # noqa: F811 # delete and restore del index - index = QdrantDocumentIndex[NestedDoc](db_config=qdrant_config) + index = QdrantDocumentIndex[NestedDoc]( + host='localhost', collection_name=tmp_collection_name + ) assert index.num_docs() == 10 find_results_after = index.find(query, search_field='d__tens', limit=5) for doc_before, doc_after in zip(find_results_before[0], find_results_after[0]): From 67f418f99b1c9195a5de36fc3a1ce5266b2a285f Mon Sep 17 00:00:00 2001 From: AnneY Date: Fri, 5 May 2023 22:15:27 +0800 Subject: [PATCH 8/9] fix: elastic v8 tests Signed-off-by: AnneY --- tests/index/elastic/v8/test_column_config.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/tests/index/elastic/v8/test_column_config.py b/tests/index/elastic/v8/test_column_config.py index 852d018db50..c2831dfb59e 100644 --- a/tests/index/elastic/v8/test_column_config.py +++ b/tests/index/elastic/v8/test_column_config.py @@ -3,18 +3,17 @@ from docarray import BaseDoc from docarray.index import ElasticDocIndex - -# from tests.index.elastic.fixture import start_storage_v8 # noqa: 
F401 +from tests.index.elastic.fixture import start_storage_v8, tmp_index_name # noqa: F401 pytestmark = [pytest.mark.slow, pytest.mark.index, pytest.mark.elasticv8] -def test_column_config(): +def test_column_config(tmp_index_name): # noqa: F811 class MyDoc(BaseDoc): text: str color: str = Field(col_type='keyword') - index = ElasticDocIndex[MyDoc]() + index = ElasticDocIndex[MyDoc](index_name=tmp_index_name) index_docs = [ MyDoc(id='0', text='hello world', color='red'), MyDoc(id='1', text='never gonna give you up', color='blue'), @@ -31,7 +30,7 @@ class MyDoc(BaseDoc): assert [doc.id for doc in docs] == ['0', '1'] -def test_field_object(): +def test_field_object(tmp_index_name): # noqa: F811 class MyDoc(BaseDoc): manager: dict = Field( properties={ @@ -45,7 +44,7 @@ class MyDoc(BaseDoc): } ) - index = ElasticDocIndex[MyDoc]() + index = ElasticDocIndex[MyDoc](index_name=tmp_index_name) doc = [ MyDoc(manager={'age': 25, 'name': {'first': 'Rachel', 'last': 'Green'}}), MyDoc(manager={'age': 30, 'name': {'first': 'Monica', 'last': 'Geller'}}), @@ -61,11 +60,11 @@ class MyDoc(BaseDoc): assert [doc.id for doc in docs] == [doc[1].id, doc[2].id] -def test_field_geo_point(): +def test_field_geo_point(tmp_index_name): # noqa: F811 class MyDoc(BaseDoc): location: dict = Field(col_type='geo_point') - index = ElasticDocIndex[MyDoc]() + index = ElasticDocIndex[MyDoc](index_name=tmp_index_name) doc = [ MyDoc(location={'lat': 40.12, 'lon': -72.34}), MyDoc(location={'lat': 41.12, 'lon': -73.34}), @@ -88,12 +87,12 @@ class MyDoc(BaseDoc): assert [doc['id'] for doc in docs] == [doc[0].id, doc[1].id] -def test_field_range(): +def test_field_range(tmp_index_name): # noqa: F811 class MyDoc(BaseDoc): expected_attendees: dict = Field(col_type='integer_range') time_frame: dict = Field(col_type='date_range', format='yyyy-MM-dd') - index = ElasticDocIndex[MyDoc]() + index = ElasticDocIndex[MyDoc](index_name=tmp_index_name) doc = [ MyDoc( expected_attendees={'gte': 10, 'lt': 20}, From 
ff172e943720f65f259195fe355b1effa192a8ff Mon Sep 17 00:00:00 2001 From: AnneY Date: Fri, 5 May 2023 23:01:57 +0800 Subject: [PATCH 9/9] test: let more tests use default index name Signed-off-by: AnneY --- tests/index/elastic/v7/test_find.py | 4 +-- tests/index/elastic/v7/test_index_get_del.py | 23 +++++++-------- tests/index/elastic/v8/test_find.py | 4 +-- tests/index/elastic/v8/test_index_get_del.py | 31 ++++++++++---------- tests/index/qdrant/test_index_get_del.py | 20 +++++++------ tests/index/qdrant/test_persist_data.py | 22 +++++--------- 6 files changed, 48 insertions(+), 56 deletions(-) diff --git a/tests/index/elastic/v7/test_find.py b/tests/index/elastic/v7/test_find.py index e82eff3015a..1fe7893e91e 100644 --- a/tests/index/elastic/v7/test_find.py +++ b/tests/index/elastic/v7/test_find.py @@ -171,8 +171,8 @@ class TfDoc(BaseDoc): ) -def test_find_batched(): - index = ElasticV7DocIndex[SimpleDoc]() +def test_find_batched(tmp_index_name): # noqa: F811 + index = ElasticV7DocIndex[SimpleDoc](index_name=tmp_index_name) index_docs = [SimpleDoc(tens=np.random.rand(10)) for _ in range(10)] index.index(index_docs) diff --git a/tests/index/elastic/v7/test_index_get_del.py b/tests/index/elastic/v7/test_index_get_del.py index 27cf3b5642d..050bcb03f54 100644 --- a/tests/index/elastic/v7/test_index_get_del.py +++ b/tests/index/elastic/v7/test_index_get_del.py @@ -70,11 +70,9 @@ def test_index_deep_nested_schema( assert index.num_docs() == 10 -def test_get_single( - ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_index_name # noqa: F811 -): +def test_get_single(ten_simple_docs, ten_flat_docs, ten_nested_docs): # noqa: F811 # simple - index = ElasticV7DocIndex[SimpleDoc](index_name=tmp_index_name) + index = ElasticV7DocIndex[SimpleDoc]() index.index(ten_simple_docs) assert index.num_docs() == 10 @@ -82,9 +80,10 @@ def test_get_single( id_ = d.id assert index[id_].id == id_ assert np.all(index[id_].tens == d.tens) + index._client.indices.delete(index='simpledoc') # 
flat - index = ElasticV7DocIndex[FlatDoc](index_name=tmp_index_name + 'flat') + index = ElasticV7DocIndex[FlatDoc]() index.index(ten_flat_docs) assert index.num_docs() == 10 @@ -93,9 +92,10 @@ def test_get_single( assert index[id_].id == id_ assert np.all(index[id_].tens_one == d.tens_one) assert np.all(index[id_].tens_two == d.tens_two) + index._client.indices.delete(index='flatdoc') # nested - index = ElasticV7DocIndex[NestedDoc](index_name=tmp_index_name + 'nested') + index = ElasticV7DocIndex[NestedDoc]() index.index(ten_nested_docs) assert index.num_docs() == 10 @@ -104,15 +104,14 @@ def test_get_single( assert index[id_].id == id_ assert index[id_].d.id == d.d.id assert np.all(index[id_].d.tens == d.d.tens) + index._client.indices.delete(index='nesteddoc') -def test_get_multiple( - ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_index_name # noqa: F811 -): +def test_get_multiple(ten_simple_docs, ten_flat_docs, ten_nested_docs): # noqa: F811 docs_to_get_idx = [0, 2, 4, 6, 8] # simple - index = ElasticV7DocIndex[SimpleDoc](index_name=tmp_index_name) + index = ElasticV7DocIndex[SimpleDoc]() index.index(ten_simple_docs) assert index.num_docs() == 10 @@ -124,7 +123,7 @@ def test_get_multiple( assert np.all(d_out.tens == d_in.tens) # flat - index = ElasticV7DocIndex[FlatDoc](index_name=tmp_index_name + 'flat') + index = ElasticV7DocIndex[FlatDoc]() index.index(ten_flat_docs) assert index.num_docs() == 10 @@ -137,7 +136,7 @@ def test_get_multiple( assert np.all(d_out.tens_two == d_in.tens_two) # nested - index = ElasticV7DocIndex[NestedDoc](index_name=tmp_index_name + 'nested') + index = ElasticV7DocIndex[NestedDoc]() index.index(ten_nested_docs) assert index.num_docs() == 10 diff --git a/tests/index/elastic/v8/test_find.py b/tests/index/elastic/v8/test_find.py index f3cf6d6119a..cfd27ed0912 100644 --- a/tests/index/elastic/v8/test_find.py +++ b/tests/index/elastic/v8/test_find.py @@ -192,8 +192,8 @@ class TfDoc(BaseDoc): ) -def test_find_batched(): - index = 
ElasticDocIndex[SimpleDoc]() +def test_find_batched(tmp_index_name): # noqa: F811 + index = ElasticDocIndex[SimpleDoc](index_name=tmp_index_name) index_docs = [SimpleDoc(tens=np.random.rand(10)) for _ in range(10)] index.index(index_docs) diff --git a/tests/index/elastic/v8/test_index_get_del.py b/tests/index/elastic/v8/test_index_get_del.py index 4e34b712fcb..8d182dfd19a 100644 --- a/tests/index/elastic/v8/test_index_get_del.py +++ b/tests/index/elastic/v8/test_index_get_del.py @@ -70,11 +70,9 @@ def test_index_deep_nested_schema( assert index.num_docs() == 10 -def test_get_single( - ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_index_name # noqa: F811 -): +def test_get_single(ten_simple_docs, ten_flat_docs, ten_nested_docs): # noqa: F811 # simple - index = ElasticDocIndex[SimpleDoc](index_name=tmp_index_name) + index = ElasticDocIndex[SimpleDoc]() index.index(ten_simple_docs) assert index.num_docs() == 10 @@ -82,9 +80,10 @@ def test_get_single( id_ = d.id assert index[id_].id == id_ assert np.all(index[id_].tens == d.tens) + index._client.indices.delete(index='simpledoc', ignore_unavailable=True) # flat - index = ElasticDocIndex[FlatDoc](index_name=tmp_index_name + 'flat') + index = ElasticDocIndex[FlatDoc]() index.index(ten_flat_docs) assert index.num_docs() == 10 @@ -93,9 +92,10 @@ def test_get_single( assert index[id_].id == id_ assert np.all(index[id_].tens_one == d.tens_one) assert np.all(index[id_].tens_two == d.tens_two) + index._client.indices.delete(index='flatdoc', ignore_unavailable=True) # nested - index = ElasticDocIndex[NestedDoc](index_name=tmp_index_name + 'nested') + index = ElasticDocIndex[NestedDoc]() index.index(ten_nested_docs) assert index.num_docs() == 10 @@ -104,15 +104,14 @@ def test_get_single( assert index[id_].id == id_ assert index[id_].d.id == d.d.id assert np.all(index[id_].d.tens == d.d.tens) + index._client.indices.delete(index='nesteddoc', ignore_unavailable=True) -def test_get_multiple( - ten_simple_docs, ten_flat_docs, 
ten_nested_docs, tmp_index_name # noqa: F811 -): +def test_get_multiple(ten_simple_docs, ten_flat_docs, ten_nested_docs): # noqa: F811 docs_to_get_idx = [0, 2, 4, 6, 8] # simple - index = ElasticDocIndex[SimpleDoc](index_name=tmp_index_name) + index = ElasticDocIndex[SimpleDoc]() index.index(ten_simple_docs) assert index.num_docs() == 10 @@ -124,7 +123,7 @@ def test_get_multiple( assert np.all(d_out.tens == d_in.tens) # flat - index = ElasticDocIndex[FlatDoc](index_name=tmp_index_name + 'flat') + index = ElasticDocIndex[FlatDoc]() index.index(ten_flat_docs) assert index.num_docs() == 10 @@ -137,7 +136,7 @@ def test_get_multiple( assert np.all(d_out.tens_two == d_in.tens_two) # nested - index = ElasticDocIndex[NestedDoc](index_name=tmp_index_name + 'nested') + index = ElasticDocIndex[NestedDoc]() index.index(ten_nested_docs) assert index.num_docs() == 10 @@ -150,8 +149,8 @@ def test_get_multiple( assert np.all(d_out.d.tens == d_in.d.tens) -def test_get_key_error(ten_simple_docs): # noqa: F811 - index = ElasticDocIndex[SimpleDoc]() +def test_get_key_error(ten_simple_docs, tmp_index_name): # noqa: F811 + index = ElasticDocIndex[SimpleDoc](index_name=tmp_index_name) index.index(ten_simple_docs) with pytest.raises(KeyError): @@ -213,8 +212,8 @@ def test_del_multiple(ten_simple_docs, tmp_index_name): # noqa: F811 assert np.all(index[doc.id].tens == doc.tens) -def test_del_key_error(ten_simple_docs): # noqa: F811 - index = ElasticDocIndex[SimpleDoc]() +def test_del_key_error(ten_simple_docs, tmp_index_name): # noqa: F811 + index = ElasticDocIndex[SimpleDoc](index_name=tmp_index_name) index.index(ten_simple_docs) with pytest.warns(UserWarning): diff --git a/tests/index/qdrant/test_index_get_del.py b/tests/index/qdrant/test_index_get_del.py index 643b989ceac..27cd3ee171b 100644 --- a/tests/index/qdrant/test_index_get_del.py +++ b/tests/index/qdrant/test_index_get_del.py @@ -110,14 +110,12 @@ def test_index_torch(tmp_collection_name): # noqa: F811 @pytest.mark.skip('Qdrant 
does not support storing image tensors yet') -def test_index_builtin_docs(tmp_collection_name): # noqa: F811 +def test_index_builtin_docs(): # TextDoc class TextSchema(TextDoc): embedding: Optional[NdArrayEmbedding] = Field(dim=10) - index = QdrantDocumentIndex[TextSchema]( - host='localhost', collection_name=tmp_collection_name - ) + index = QdrantDocumentIndex[TextSchema](host='localhost') index.index( DocList[TextDoc]( @@ -145,8 +143,10 @@ class ImageSchema(ImageDoc): assert index.num_docs() == 10 -def test_get_key_error(ten_simple_docs): - index = QdrantDocumentIndex[SimpleDoc](host='localhost') +def test_get_key_error(ten_simple_docs, tmp_collection_name): # noqa: F811 + index = QdrantDocumentIndex[SimpleDoc]( + host='localhost', collection_name=tmp_collection_name + ) index.index(ten_simple_docs) with pytest.raises(KeyError): @@ -201,8 +201,10 @@ def test_del_multiple(ten_simple_docs, tmp_collection_name): # noqa: F811 assert index[doc.id].id == doc.id -def test_del_key_error(ten_simple_docs): # noqa: F811 - index = QdrantDocumentIndex[SimpleDoc](host='localhost') +def test_del_key_error(ten_simple_docs, tmp_collection_name): # noqa: F811 + index = QdrantDocumentIndex[SimpleDoc]( + host='localhost', collection_name=tmp_collection_name + ) index.index(ten_simple_docs) with pytest.raises(KeyError): @@ -231,7 +233,7 @@ def test_num_docs(ten_simple_docs, tmp_collection_name): # noqa: F811 assert index.num_docs() == 10 -def test_multimodal_doc(): # noqa: F811 +def test_multimodal_doc(): class MyMultiModalDoc(BaseDoc): image: ImageDoc text: TextDoc diff --git a/tests/index/qdrant/test_persist_data.py b/tests/index/qdrant/test_persist_data.py index 88ab2ec342a..cea1fa70295 100644 --- a/tests/index/qdrant/test_persist_data.py +++ b/tests/index/qdrant/test_persist_data.py @@ -5,7 +5,7 @@ from docarray import BaseDoc from docarray.index import QdrantDocumentIndex from docarray.typing import NdArray -from tests.index.qdrant.fixtures import start_storage, 
tmp_collection_name # noqa: F401 +from tests.index.qdrant.fixtures import start_storage # noqa: F401 pytestmark = [pytest.mark.slow, pytest.mark.index] @@ -19,22 +19,18 @@ class NestedDoc(BaseDoc): tens: NdArray[50] # type: ignore[valid-type] -def test_persist_and_restore(tmp_collection_name): # noqa: F811 +def test_persist_and_restore(): query = SimpleDoc(tens=np.random.random((10,))) # create index - index = QdrantDocumentIndex[SimpleDoc]( - host='localhost', collection_name=tmp_collection_name - ) + index = QdrantDocumentIndex[SimpleDoc](host='localhost') index.index([SimpleDoc(tens=np.random.random((10,))) for _ in range(10)]) assert index.num_docs() == 10 find_results_before = index.find(query, search_field='tens', limit=5) # delete and restore del index - index = QdrantDocumentIndex[SimpleDoc]( - host='localhost', collection_name=tmp_collection_name - ) + index = QdrantDocumentIndex[SimpleDoc](host='localhost') assert index.num_docs() == 10 find_results_after = index.find(query, search_field='tens', limit=5) for doc_before, doc_after in zip(find_results_before[0], find_results_after[0]): @@ -46,15 +42,13 @@ def test_persist_and_restore(tmp_collection_name): # noqa: F811 assert index.num_docs() == 15 -def test_persist_and_restore_nested(tmp_collection_name): # noqa: F811 +def test_persist_and_restore_nested(): query = NestedDoc( tens=np.random.random((50,)), d=SimpleDoc(tens=np.random.random((10,))) ) # create index - index = QdrantDocumentIndex[NestedDoc]( - host='localhost', collection_name=tmp_collection_name - ) + index = QdrantDocumentIndex[NestedDoc](host='localhost') index.index( [ NestedDoc( @@ -68,9 +62,7 @@ def test_persist_and_restore_nested(tmp_collection_name): # noqa: F811 # delete and restore del index - index = QdrantDocumentIndex[NestedDoc]( - host='localhost', collection_name=tmp_collection_name - ) + index = QdrantDocumentIndex[NestedDoc](host='localhost') assert index.num_docs() == 10 find_results_after = index.find(query, 
search_field='d__tens', limit=5) for doc_before, doc_after in zip(find_results_before[0], find_results_after[0]):