diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d49d1d603c7..f49442f938a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -225,7 +225,38 @@ jobs: - name: Test id: test run: | - poetry run pytest -m 'index' tests + poetry run pytest -m 'index and not elasticv8' tests + timeout-minutes: 30 + + + docarray-elastic-v8: + needs: [lint-ruff, check-black, import-test] + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: [3.7] + steps: + - uses: actions/checkout@v2.5.0 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Prepare environment + run: | + python -m pip install --upgrade pip + python -m pip install poetry + poetry install --all-extras + poetry run pip install protobuf==3.19.0 + poetry run pip install tensorflow==2.11.0 + poetry run pip install elasticsearch==8.6.2 + sudo apt-get update + sudo apt-get install --no-install-recommends ffmpeg + + - name: Test + id: test + run: | + poetry run pytest -m 'index and elasticv8' tests timeout-minutes: 30 docarray-test-tensorflow: @@ -284,7 +315,7 @@ jobs: # just for blocking the merge until all parallel core-test are successful success-all-test: - needs: [docarray-test, docarray-test-proto3, docarray-doc-index, docarray-test-tensorflow, docarray-test-benchmarks, import-test, check-black, check-mypy, lint-ruff] + needs: [docarray-test, docarray-test-proto3, docarray-doc-index, docarray-elastic-v8, docarray-test-tensorflow, docarray-test-benchmarks, import-test, check-black, check-mypy, lint-ruff] if: always() runs-on: ubuntu-latest steps: diff --git a/docarray/index/__init__.py b/docarray/index/__init__.py index 37ef5e6f611..2c724030fe7 100644 --- a/docarray/index/__init__.py +++ b/docarray/index/__init__.py @@ -7,7 +7,8 @@ ) if TYPE_CHECKING: - from docarray.index.backends.elastic import ElasticV7DocIndex # noqa: F401 + from 
docarray.index.backends.elastic import ElasticDocIndex # noqa: F401 + from docarray.index.backends.elasticv7 import ElasticV7DocIndex # noqa: F401 from docarray.index.backends.hnswlib import HnswDocumentIndex # noqa: F401 __all__ = [] @@ -18,9 +19,13 @@ def __getattr__(name: str): if name == 'HnswDocumentIndex': import_library('hnswlib', raise_error=True) import docarray.index.backends.hnswlib as lib - elif name == 'ElasticV7DocIndex': + elif name == 'ElasticDocIndex': import_library('elasticsearch', raise_error=True) import docarray.index.backends.elastic as lib + elif name == 'ElasticV7DocIndex': + import_library('elasticsearch', raise_error=True) + import docarray.index.backends.elasticv7 as lib + else: raise ImportError( f'cannot import name \'{name}\' from \'{_get_path_from_docarray_root_level(__file__)}\'' diff --git a/docarray/index/backends/elastic.py b/docarray/index/backends/elastic.py index 08c29c150d2..c2c1c6646a2 100644 --- a/docarray/index/backends/elastic.py +++ b/docarray/index/backends/elastic.py @@ -1,4 +1,4 @@ -import os +# mypy: ignore-errors import uuid import warnings from collections import defaultdict @@ -21,6 +21,7 @@ ) import numpy as np +from elastic_transport import NodeConfig from elasticsearch import Elasticsearch from elasticsearch.helpers import parallel_bulk from pydantic import parse_obj_as @@ -40,7 +41,7 @@ from docarray.utils.find import _FindResult TSchema = TypeVar('TSchema', bound=BaseDoc) -T = TypeVar('T', bound='ElasticV7DocIndex') +T = TypeVar('T', bound='ElasticDocIndex') ELASTIC_PY_VEC_TYPES: List[Any] = [list, tuple, np.ndarray, AbstractTensor] @@ -58,11 +59,12 @@ ELASTIC_PY_VEC_TYPES.append(TensorFlowTensor) -class ElasticV7DocIndex(BaseDocIndex, Generic[TSchema]): +class ElasticDocIndex(BaseDocIndex, Generic[TSchema]): def __init__(self, db_config=None, **kwargs): super().__init__(db_config=db_config, **kwargs) - self._db_config = cast(ElasticV7DocIndex.DBConfig, self._db_config) + self._db_config = cast(self.DBConfig, 
self._db_config) + # ElasticSearch client creation if self._db_config.index_name is None: id = uuid.uuid4().hex self._db_config.index_name = 'index__' + id @@ -74,40 +76,33 @@ def __init__(self, db_config=None, **kwargs): **self._db_config.es_config, ) - # compatibility - self._server_version = self._client.info()['version']['number'] - if int(self._server_version.split('.')[0]) >= 8: - os.environ['ELASTIC_CLIENT_APIVERSIONING'] = '1' + # ElasticSearch index setup + self._index_vector_params = ('dims', 'similarity', 'index') + self._index_vector_options = ('m', 'ef_construction') - body: Dict[str, Any] = { - 'mappings': { - 'dynamic': True, - '_source': {'enabled': 'true'}, - 'properties': {}, - } + mappings: Dict[str, Any] = { + 'dynamic': True, + '_source': {'enabled': 'true'}, + 'properties': {}, } + mappings.update(self._db_config.index_mappings) for col_name, col in self._column_infos.items(): - body['mappings']['properties'][col_name] = self._create_index_mapping(col) + mappings['properties'][col_name] = self._create_index_mapping(col) if self._client.indices.exists(index=self._index_name): - self._client.indices.put_mapping( - index=self._index_name, body=body['mappings'] - ) + self._client_put_mapping(mappings) else: - self._client.indices.create(index=self._index_name, body=body) + self._client_create(mappings) if len(self._db_config.index_settings): - self._client.indices.put_settings( - index=self._index_name, body=self._db_config.index_settings - ) + self._client_put_settings(self._db_config.index_settings) self._refresh(self._index_name) ############################################### # Inner classes for query builder and configs # ############################################### - class QueryBuilder(BaseDocIndex.QueryBuilder): def __init__(self, outer_instance, **kwargs): super().__init__() @@ -117,16 +112,11 @@ def __init__(self, outer_instance, **kwargs): } def build(self, *args, **kwargs) -> Any: - if ( - 'script_score' in self._query['query'] - and 
'bool' in self._query['query'] - and len(self._query['query']['bool']) > 0 - ): - self._query['query']['script_score']['query'] = {} - self._query['query']['script_score']['query']['bool'] = self._query[ - 'query' - ]['bool'] - del self._query['query']['bool'] + if len(self._query['query']) == 0: + del self._query['query'] + elif 'knn' in self._query: + self._query['knn']['filter'] = self._query['query'] + del self._query['query'] return self._query @@ -135,6 +125,7 @@ def find( query: Union[AnyTensor, BaseDoc], search_field: str = 'embedding', limit: int = 10, + num_candidates: Optional[int] = None, ): self._outer_instance._validate_search_field(search_field) if isinstance(query, BaseDoc): @@ -142,13 +133,17 @@ def find( else: query_vec = query query_vec_np = BaseDocIndex._to_numpy(self._outer_instance, query_vec) - self._query['size'] = limit - self._query['query']['script_score'] = ElasticV7DocIndex._form_search_body( - query_vec_np, limit, search_field - )['query']['script_score'] + self._query['knn'] = self._outer_instance._form_search_body( + query_vec_np, + limit, + search_field, + num_candidates, + )['knn'] return self + # filter accepts Leaf/Compound query clauses + # https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl.html def filter(self, query: Dict[str, Any], limit: int = 10): self._query['size'] = limit self._query['query']['bool']['filter'].append(query) @@ -163,8 +158,8 @@ def text_search(self, query: str, search_field: str = 'text', limit: int = 10): return self find_batched = _raise_not_composable('find_batched') - filter_batched = _raise_not_composable('find_batched') - text_search_batched = _raise_not_composable('text_search') + filter_batched = _raise_not_composable('filter_batched') + text_search_batched = _raise_not_composable('text_search_batched') def build_query(self, **kwargs) -> QueryBuilder: """ @@ -174,15 +169,21 @@ def build_query(self, **kwargs) -> QueryBuilder: @dataclass class DBConfig(BaseDocIndex.DBConfig): 
- hosts: Union[str, List[str], None] = 'http://localhost:9200' + hosts: Union[ + str, List[Union[str, Mapping[str, Union[str, int]], NodeConfig]], None + ] = 'http://localhost:9200' index_name: Optional[str] = None es_config: Dict[str, Any] = field(default_factory=dict) index_settings: Dict[str, Any] = field(default_factory=dict) + index_mappings: Dict[str, Any] = field(default_factory=dict) @dataclass class RuntimeConfig(BaseDocIndex.RuntimeConfig): - default_column_config: Dict[Any, Dict[str, Any]] = field( - default_factory=lambda: { + default_column_config: Dict[Any, Dict[str, Any]] = field(default_factory=dict) + chunk_size: int = 500 + + def __post_init__(self): + self.default_column_config = { 'binary': {}, 'boolean': {}, 'keyword': {}, @@ -206,6 +207,7 @@ class RuntimeConfig(BaseDocIndex.RuntimeConfig): 'long_range': {}, 'double_range': {}, 'date_range': {}, + 'ip_range': {}, 'ip': {}, 'version': {}, 'histogram': {}, @@ -214,7 +216,6 @@ class RuntimeConfig(BaseDocIndex.RuntimeConfig): 'completion': {}, 'search_as_you_type': {}, 'token_count': {}, - 'dense_vector': {'dims': 128}, 'sparse_vector': {}, 'rank_feature': {}, 'rank_features': {}, @@ -226,8 +227,19 @@ class RuntimeConfig(BaseDocIndex.RuntimeConfig): # `None` is not a Type, but we allow it here anyway None: {}, # type: ignore } - ) - chunk_size: int = 500 + self.default_column_config['dense_vector'] = self.dense_vector_config() + + def dense_vector_config(self): + config = { + 'index': True, + 'dims': 128, + 'similarity': 'cosine', # 'l2_norm', 'dot_product', 'cosine' + 'm': 16, + 'ef_construction': 100, + 'num_candidates': 10000, + } + + return config ############################################### # Implementation of abstract methods # @@ -235,7 +247,6 @@ class RuntimeConfig(BaseDocIndex.RuntimeConfig): def python_type_to_db_type(self, python_type: Type) -> Any: """Map python type to database type.""" - for allowed_type in ELASTIC_PY_VEC_TYPES: if issubclass(python_type, allowed_type): return 
'dense_vector' @@ -313,10 +324,7 @@ def _get_items(self, doc_ids: Sequence[str]) -> Sequence[TSchema]: accumulated_docs = [] accumulated_docs_id_not_found = [] - es_rows = self._client.mget( - index=self._index_name, - body={'ids': doc_ids}, - )['docs'] + es_rows = self._client_mget(doc_ids)['docs'] for row in es_rows: if row['found']: @@ -337,7 +345,7 @@ def execute_query(self, query: Dict[str, Any], *args, **kwargs) -> Any: f'args and kwargs not supported for `execute_query` on {type(self)}' ) - resp = self._client.search(index=self._index_name, body=query) + resp = self._client.search(index=self._index_name, **query) docs, scores = self._format_response(resp) return _FindResult(documents=docs, scores=scores) @@ -345,17 +353,9 @@ def execute_query(self, query: Dict[str, Any], *args, **kwargs) -> Any: def _find( self, query: np.ndarray, limit: int, search_field: str = '' ) -> _FindResult: - if int(self._server_version.split('.')[0]) >= 8: - warnings.warn( - 'You are using Elasticsearch 8.0+ and the current client is 7.10.1. HNSW based vector search is not supported and the find method has a default implementation using exhaustive KNN search with cosineSimilarity, which may result in slow performance.' 
- ) - body = self._form_search_body(query, limit, search_field) - resp = self._client.search( - index=self._index_name, - body=body, - ) + resp = self._client_search(**body) docs, scores = self._format_response(resp) @@ -373,7 +373,7 @@ def _find_batched( body = self._form_search_body(query, limit, search_field) request.extend([head, body]) - responses = self._client.msearch(body=request) + responses = self._client_msearch(request) das, scores = zip( *[self._format_response(resp) for resp in responses['responses']] @@ -385,15 +385,7 @@ def _filter( filter_query: Dict[str, Any], limit: int, ) -> List[Dict]: - body = { - 'size': limit, - 'query': filter_query, - } - - resp = self._client.search( - index=self._index_name, - body=body, - ) + resp = self._client_search(query=filter_query, size=limit) docs, _ = self._format_response(resp) @@ -410,7 +402,7 @@ def _filter_batched( body = {'query': query, 'size': limit} request.extend([head, body]) - responses = self._client.msearch(body=request) + responses = self._client_msearch(request) das, _ = zip(*[self._format_response(resp) for resp in responses['responses']]) return list(das) @@ -422,15 +414,11 @@ def _text_search( search_field: str = '', ) -> _FindResult: body = self._form_text_search_body(query, limit, search_field) - - resp = self._client.search( - index=self._index_name, - body=body, - ) + resp = self._client_search(**body) docs, scores = self._format_response(resp) - return _FindResult(documents=docs, scores=scores) + return _FindResult(documents=docs, scores=np.array(scores)) # type: ignore def _text_search_batched( self, @@ -444,28 +432,32 @@ def _text_search_batched( body = self._form_text_search_body(query, limit, search_field) request.extend([head, body]) - responses = self._client.msearch(body=request) - + responses = self._client_msearch(request) das, scores = zip( *[self._format_response(resp) for resp in responses['responses']] ) - return _FindResultBatched(documents=list(das), 
scores=np.array(scores)) + return _FindResultBatched( + documents=list(das), scores=np.array(scores, dtype=object) + ) ############################################### # Helpers # ############################################### - # ElasticSearch helpers def _create_index_mapping(self, col: '_ColumnInfo') -> Dict[str, Any]: """Create a new HNSW index for a column, and initialize it.""" - index = col.config.copy() - if 'type' not in index: - index['type'] = col.db_type - - if col.db_type == 'dense_vector' and col.n_dim: - index['dims'] = col.n_dim + index = {'type': col.config['type'] if 'type' in col.config else col.db_type} + if col.db_type == 'dense_vector': + for k in self._index_vector_params: + index[k] = col.config[k] + if col.n_dim: + index['dims'] = col.n_dim + index['index_options'] = dict( + (k, col.config[k]) for k in self._index_vector_options + ) + index['index_options']['type'] = 'hnsw' return index def _send_requests( @@ -493,27 +485,30 @@ def _send_requests( return accumulated_info, warning_info - @staticmethod def _form_search_body( - query: np.ndarray, limit: int, search_field: str = '' + self, + query: np.ndarray, + limit: int, + search_field: str = '', + num_candidates: Optional[int] = None, ) -> Dict[str, Any]: + if not num_candidates: + num_candidates = self._runtime_config.default_column_config['dense_vector'][ + 'num_candidates' + ] body = { 'size': limit, - 'query': { - 'script_score': { - 'query': {'match_all': {}}, - 'script': { - 'source': f'cosineSimilarity(params.query_vector, \'{search_field}\') + 1.0', - 'params': {'query_vector': query}, - }, - } + 'knn': { + 'field': search_field, + 'query_vector': query, + 'k': limit, + 'num_candidates': num_candidates, }, } return body - @staticmethod def _form_text_search_body( - query: str, limit: int, search_field: str = '' + self, query: str, limit: int, search_field: str = '' ) -> Dict[str, Any]: body = { 'size': limit, @@ -544,3 +539,27 @@ def _format_response(self, response: Any) -> 
Tuple[List[Dict], NdArray]: def _refresh(self, index_name: str): self._client.indices.refresh(index=index_name) + + ############################################### + # API Wrappers # + ############################################### + + def _client_put_mapping(self, mappings: Dict[str, Any]): + self._client.indices.put_mapping( + index=self._index_name, properties=mappings['properties'] + ) + + def _client_create(self, mappings: Dict[str, Any]): + self._client.indices.create(index=self._index_name, mappings=mappings) + + def _client_put_settings(self, settings: Dict[str, Any]): + self._client.indices.put_settings(index=self._index_name, settings=settings) + + def _client_mget(self, ids: Sequence[str]): + return self._client.mget(index=self._index_name, ids=ids) + + def _client_search(self, **kwargs): + return self._client.search(index=self._index_name, **kwargs) + + def _client_msearch(self, request: List[Dict[str, Any]]): + return self._client.msearch(index=self._index_name, searches=request) diff --git a/docarray/index/backends/elasticv7.py b/docarray/index/backends/elasticv7.py new file mode 100644 index 00000000000..e77aedfc2b4 --- /dev/null +++ b/docarray/index/backends/elasticv7.py @@ -0,0 +1,152 @@ +import warnings +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Sequence, TypeVar, Union + +import numpy as np + +from docarray import BaseDoc +from docarray.index import ElasticDocIndex +from docarray.index.abstract import BaseDocIndex, _ColumnInfo +from docarray.typing import AnyTensor +from docarray.utils.find import _FindResult + +TSchema = TypeVar('TSchema', bound=BaseDoc) +T = TypeVar('T', bound='ElasticV7DocIndex') + + +class ElasticV7DocIndex(ElasticDocIndex): + def __init__(self, db_config=None, **kwargs): + from elasticsearch import __version__ as __es__version__ + + if __es__version__[0] > 7: + raise ImportError( + 'ElasticV7DocIndex requires the elasticsearch library to be version 7.10.1' + ) + + 
super().__init__(db_config, **kwargs) + + ############################################### + # Inner classes for query builder and configs # + ############################################### + + class QueryBuilder(ElasticDocIndex.QueryBuilder): + def build(self, *args, **kwargs) -> Any: + if ( + 'script_score' in self._query['query'] + and 'bool' in self._query['query'] + and len(self._query['query']['bool']) > 0 + ): + self._query['query']['script_score']['query'] = {} + self._query['query']['script_score']['query']['bool'] = self._query[ + 'query' + ]['bool'] + del self._query['query']['bool'] + + return self._query + + def find( + self, + query: Union[AnyTensor, BaseDoc], + search_field: str = 'embedding', + limit: int = 10, + num_candidates: Optional[int] = None, + ): + if num_candidates: + warnings.warn('`num_candidates` is not supported in ElasticV7DocIndex') + + if isinstance(query, BaseDoc): + query_vec = BaseDocIndex._get_values_by_column([query], search_field)[0] + else: + query_vec = query + query_vec_np = BaseDocIndex._to_numpy(self._outer_instance, query_vec) + self._query['size'] = limit + self._query['query'][ + 'script_score' + ] = self._outer_instance._form_search_body( + query_vec_np, limit, search_field + )[ + 'query' + ][ + 'script_score' + ] + + return self + + @dataclass + class DBConfig(ElasticDocIndex.DBConfig): + hosts: Union[str, List[str], None] = 'http://localhost:9200' # type: ignore + + @dataclass + class RuntimeConfig(ElasticDocIndex.RuntimeConfig): + def dense_vector_config(self): + return {'dims': 128} + + ############################################### + # Implementation of abstract methods # + ############################################### + + def execute_query(self, query: Dict[str, Any], *args, **kwargs) -> Any: + if args or kwargs: + raise ValueError( + f'args and kwargs not supported for `execute_query` on {type(self)}' + ) + + resp = self._client.search(index=self._index_name, body=query) + docs, scores = 
self._format_response(resp) + + return _FindResult(documents=docs, scores=scores) + + ############################################### + # Helpers # + ############################################### + + # ElasticSearch helpers + def _create_index_mapping(self, col: '_ColumnInfo') -> Dict[str, Any]: + """Create a new HNSW index for a column, and initialize it.""" + + index = col.config.copy() + if 'type' not in index: + index['type'] = col.db_type + + if col.db_type == 'dense_vector' and col.n_dim: + index['dims'] = col.n_dim + + return index + + def _form_search_body(self, query: np.ndarray, limit: int, search_field: str = '') -> Dict[str, Any]: # type: ignore + body = { + 'size': limit, + 'query': { + 'script_score': { + 'query': {'match_all': {}}, + 'script': { + 'source': f'cosineSimilarity(params.query_vector, \'{search_field}\') + 1.0', + 'params': {'query_vector': query}, + }, + } + }, + } + return body + + ############################################### + # API Wrappers # + ############################################### + + def _client_put_mapping(self, mappings: Dict[str, Any]): + self._client.indices.put_mapping(index=self._index_name, body=mappings) + + def _client_create(self, mappings: Dict[str, Any]): + body = {'mappings': mappings} + self._client.indices.create(index=self._index_name, body=body) + + def _client_put_settings(self, settings: Dict[str, Any]): + self._client.indices.put_settings(index=self._index_name, body=settings) + + def _client_mget(self, ids: Sequence[str]): + return self._client.mget(index=self._index_name, body={'ids': ids}) + + def _client_search(self, **kwargs): + return self._client.search(index=self._index_name, body=kwargs) + + def _client_msearch(self, request: List[Dict[str, Any]]): + return self._client.msearch(index=self._index_name, body=request) diff --git a/poetry.lock b/poetry.lock index 771cafd3e80..cd46e05c897 100644 --- a/poetry.lock +++ b/poetry.lock @@ -803,6 +803,25 @@ six = ">=1.9.0" gmpy = ["gmpy"] gmpy2 = 
["gmpy2"] +[[package]] +name = "elastic-transport" +version = "8.4.0" +description = "Transport classes and utilities shared among Python Elastic client libraries" +category = "main" +optional = true +python-versions = ">=3.6" +files = [ + {file = "elastic-transport-8.4.0.tar.gz", hash = "sha256:b9ad708ceb7fcdbc6b30a96f886609a109f042c0b9d9f2e44403b3133ba7ff10"}, + {file = "elastic_transport-8.4.0-py3-none-any.whl", hash = "sha256:19db271ab79c9f70f8c43f8f5b5111408781a6176b54ab2e54d713b6d9ceb815"}, +] + +[package.dependencies] +certifi = "*" +urllib3 = ">=1.26.2,<2" + +[package.extras] +develop = ["aiohttp", "mock", "pytest", "pytest-asyncio", "pytest-cov", "pytest-httpserver", "pytest-mock", "requests", "trustme"] + [[package]] name = "elasticsearch" version = "7.10.1" @@ -4606,7 +4625,7 @@ testing = ["flake8 (<5)", "func-timeout", "jaraco.functools", "jaraco.itertools" [extras] audio = ["pydub"] aws = ["smart-open"] -elasticsearch = ["elasticsearch"] +elasticsearch = ["elasticsearch", "elastic-transport"] full = ["protobuf", "lz4", "pandas", "pillow", "types-pillow", "av", "pydub", "trimesh"] hnswlib = ["hnswlib"] image = ["pillow", "types-pillow"] @@ -4621,4 +4640,4 @@ web = ["fastapi"] [metadata] lock-version = "2.0" python-versions = ">=3.7,<4.0" -content-hash = "61780ee493f649cc3cc164f8a3585083d69aed63831fad3c3cdcf91609804221" +content-hash = "a5bae8ca8239347d066e7566dfea56f08d42950f7037e50870cee226809f4b01" diff --git a/pyproject.toml b/pyproject.toml index 49a5cd704fb..ecc72c74719 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,9 +24,10 @@ hnswlib = {version = ">=0.6.2", optional = true } lz4 = {version= ">=1.0.0", optional = true} pydub = {version = "^0.25.1", optional = true } pandas = {version = ">=1.1.0", optional = true } -elasticsearch = {version = "7.10.1", optional = true } +elasticsearch = {version = ">=7.10.1", optional = true } smart-open = {version = ">=6.3.0", extras = ["s3"], optional = true} jina-hubble-sdk = {version = ">=0.34.0", 
optional = true} +elastic-transport = {version ="^8.4.0", optional = true } [tool.poetry.extras] proto = ["protobuf", "lz4"] @@ -36,7 +37,7 @@ video = ["av"] audio = ["pydub"] mesh = ["trimesh"] hnswlib = ["hnswlib"] -elasticsearch = ["elasticsearch"] +elasticsearch = ["elasticsearch", "elastic-transport"] jac = ["jina-hubble-sdk"] aws = ["smart-open"] torch = ["torch"] @@ -115,4 +116,5 @@ markers = [ "tensorflow: marks test using tensorflow and proto 3", "index: marks test using a document index", "benchmark: marks slow benchmarking tests", + "elasticv8: marks test that run with ElasticSearch v8", ] diff --git a/tests/integrations/doc_index/__init__.py b/tests/index/elastic/__init__.py similarity index 100% rename from tests/integrations/doc_index/__init__.py rename to tests/index/elastic/__init__.py diff --git a/tests/integrations/doc_index/elastic/fixture.py b/tests/index/elastic/fixture.py similarity index 67% rename from tests/integrations/doc_index/elastic/fixture.py rename to tests/index/elastic/fixture.py index 82462ea96e9..315078d6269 100644 --- a/tests/integrations/doc_index/elastic/fixture.py +++ b/tests/index/elastic/fixture.py @@ -1,6 +1,7 @@ import os import time +import numpy as np import pytest from pydantic import Field @@ -9,24 +10,6 @@ pytestmark = [pytest.mark.slow, pytest.mark.index] - -class SimpleDoc(BaseDoc): - tens: NdArray[10] = Field(dims=1000) - - -class FlatDoc(BaseDoc): - tens_one: NdArray = Field(dims=10) - tens_two: NdArray = Field(dims=50) - - -class NestedDoc(BaseDoc): - d: SimpleDoc - - -class DeepNestedDoc(BaseDoc): - d: NestedDoc - - cur_dir = os.path.dirname(os.path.abspath(__file__)) compose_yml_v7 = os.path.abspath(os.path.join(cur_dir, 'v7/docker-compose.yml')) compose_yml_v8 = os.path.abspath(os.path.join(cur_dir, 'v8/docker-compose.yml')) @@ -56,3 +39,46 @@ def _wait_for_es(): es = Elasticsearch(hosts='http://localhost:9200/') while not es.ping(): time.sleep(0.5) + + +class SimpleDoc(BaseDoc): + tens: NdArray[10] = 
Field(dims=1000) + + +class FlatDoc(BaseDoc): + tens_one: NdArray = Field(dims=10) + tens_two: NdArray = Field(dims=50) + + +class NestedDoc(BaseDoc): + d: SimpleDoc + + +class DeepNestedDoc(BaseDoc): + d: NestedDoc + + +@pytest.fixture(scope='function') +def ten_simple_docs(): + return [SimpleDoc(tens=np.random.randn(10)) for _ in range(10)] + + +@pytest.fixture(scope='function') +def ten_flat_docs(): + return [ + FlatDoc(tens_one=np.random.randn(10), tens_two=np.random.randn(50)) + for _ in range(10) + ] + + +@pytest.fixture(scope='function') +def ten_nested_docs(): + return [NestedDoc(d=SimpleDoc(tens=np.random.randn(10))) for _ in range(10)] + + +@pytest.fixture(scope='function') +def ten_deep_nested_docs(): + return [ + DeepNestedDoc(d=NestedDoc(d=SimpleDoc(tens=np.random.randn(10)))) + for _ in range(10) + ] diff --git a/tests/integrations/doc_index/elastic/__init__.py b/tests/index/elastic/v7/__init__.py similarity index 100% rename from tests/integrations/doc_index/elastic/__init__.py rename to tests/index/elastic/v7/__init__.py diff --git a/tests/integrations/doc_index/elastic/v7/docker-compose.yml b/tests/index/elastic/v7/docker-compose.yml similarity index 100% rename from tests/integrations/doc_index/elastic/v7/docker-compose.yml rename to tests/index/elastic/v7/docker-compose.yml diff --git a/tests/integrations/doc_index/elastic/v7/test_column_config.py b/tests/index/elastic/v7/test_column_config.py similarity index 97% rename from tests/integrations/doc_index/elastic/v7/test_column_config.py rename to tests/index/elastic/v7/test_column_config.py index df927a2c2de..a0d4aa4dec9 100644 --- a/tests/integrations/doc_index/elastic/v7/test_column_config.py +++ b/tests/index/elastic/v7/test_column_config.py @@ -3,7 +3,7 @@ from docarray import BaseDoc from docarray.index import ElasticV7DocIndex -from tests.integrations.doc_index.elastic.fixture import start_storage_v7 # noqa: F401 +from tests.index.elastic.fixture import start_storage_v7 # noqa: F401 
pytestmark = [pytest.mark.slow, pytest.mark.index] diff --git a/tests/integrations/doc_index/elastic/v7/test_find.py b/tests/index/elastic/v7/test_find.py similarity index 98% rename from tests/integrations/doc_index/elastic/v7/test_find.py rename to tests/index/elastic/v7/test_find.py index 1a0503711a7..6665c8b2b60 100644 --- a/tests/integrations/doc_index/elastic/v7/test_find.py +++ b/tests/index/elastic/v7/test_find.py @@ -6,8 +6,8 @@ from docarray import BaseDoc from docarray.index import ElasticV7DocIndex from docarray.typing import NdArray, TorchTensor -from tests.integrations.doc_index.elastic.fixture import start_storage_v7 # noqa: F401 -from tests.integrations.doc_index.elastic.fixture import FlatDoc, SimpleDoc +from tests.index.elastic.fixture import start_storage_v7 # noqa: F401 +from tests.index.elastic.fixture import FlatDoc, SimpleDoc pytestmark = [pytest.mark.slow, pytest.mark.index] diff --git a/tests/integrations/doc_index/elastic/v7/test_index_get_del.py b/tests/index/elastic/v7/test_index_get_del.py similarity index 86% rename from tests/integrations/doc_index/elastic/v7/test_index_get_del.py rename to tests/index/elastic/v7/test_index_get_del.py index 40779116c4e..7124d5d61bd 100644 --- a/tests/integrations/doc_index/elastic/v7/test_index_get_del.py +++ b/tests/index/elastic/v7/test_index_get_del.py @@ -7,45 +7,23 @@ from docarray.documents import ImageDoc, TextDoc from docarray.index import ElasticV7DocIndex from docarray.typing import NdArray -from tests.integrations.doc_index.elastic.fixture import start_storage_v7 # noqa: F401 -from tests.integrations.doc_index.elastic.fixture import ( +from tests.index.elastic.fixture import ( # noqa: F401 DeepNestedDoc, FlatDoc, NestedDoc, SimpleDoc, + start_storage_v7, + ten_deep_nested_docs, + ten_flat_docs, + ten_nested_docs, + ten_simple_docs, ) pytestmark = [pytest.mark.slow, pytest.mark.index] -@pytest.fixture -def ten_simple_docs(): - return [SimpleDoc(tens=np.random.randn(10)) for _ in range(10)] - 
- -@pytest.fixture -def ten_flat_docs(): - return [ - FlatDoc(tens_one=np.random.randn(10), tens_two=np.random.randn(50)) - for _ in range(10) - ] - - -@pytest.fixture -def ten_nested_docs(): - return [NestedDoc(d=SimpleDoc(tens=np.random.randn(10))) for _ in range(10)] - - -@pytest.fixture -def ten_deep_nested_docs(): - return [ - DeepNestedDoc(d=NestedDoc(d=SimpleDoc(tens=np.random.randn(10)))) - for _ in range(10) - ] - - @pytest.mark.parametrize('use_docarray', [True, False]) -def test_index_simple_schema(ten_simple_docs, use_docarray): +def test_index_simple_schema(ten_simple_docs, use_docarray): # noqa: F811 store = ElasticV7DocIndex[SimpleDoc]() if use_docarray: ten_simple_docs = DocList[SimpleDoc](ten_simple_docs) @@ -55,7 +33,7 @@ def test_index_simple_schema(ten_simple_docs, use_docarray): @pytest.mark.parametrize('use_docarray', [True, False]) -def test_index_flat_schema(ten_flat_docs, use_docarray): +def test_index_flat_schema(ten_flat_docs, use_docarray): # noqa: F811 store = ElasticV7DocIndex[FlatDoc]() if use_docarray: ten_flat_docs = DocList[FlatDoc](ten_flat_docs) @@ -65,7 +43,7 @@ def test_index_flat_schema(ten_flat_docs, use_docarray): @pytest.mark.parametrize('use_docarray', [True, False]) -def test_index_nested_schema(ten_nested_docs, use_docarray): +def test_index_nested_schema(ten_nested_docs, use_docarray): # noqa: F811 store = ElasticV7DocIndex[NestedDoc]() if use_docarray: ten_nested_docs = DocList[NestedDoc](ten_nested_docs) @@ -75,7 +53,7 @@ def test_index_nested_schema(ten_nested_docs, use_docarray): @pytest.mark.parametrize('use_docarray', [True, False]) -def test_index_deep_nested_schema(ten_deep_nested_docs, use_docarray): +def test_index_deep_nested_schema(ten_deep_nested_docs, use_docarray): # noqa: F811 store = ElasticV7DocIndex[DeepNestedDoc]() if use_docarray: ten_deep_nested_docs = DocList[DeepNestedDoc](ten_deep_nested_docs) @@ -84,7 +62,7 @@ def test_index_deep_nested_schema(ten_deep_nested_docs, use_docarray): assert 
store.num_docs() == 10 -def test_get_single(ten_simple_docs, ten_flat_docs, ten_nested_docs): +def test_get_single(ten_simple_docs, ten_flat_docs, ten_nested_docs): # noqa: F811 # simple store = ElasticV7DocIndex[SimpleDoc]() store.index(ten_simple_docs) @@ -118,7 +96,7 @@ def test_get_single(ten_simple_docs, ten_flat_docs, ten_nested_docs): assert np.all(store[id_].d.tens == d.d.tens) -def test_get_multiple(ten_simple_docs, ten_flat_docs, ten_nested_docs): +def test_get_multiple(ten_simple_docs, ten_flat_docs, ten_nested_docs): # noqa: F811 docs_to_get_idx = [0, 2, 4, 6, 8] # simple @@ -160,7 +138,7 @@ def test_get_multiple(ten_simple_docs, ten_flat_docs, ten_nested_docs): assert np.all(d_out.d.tens == d_in.d.tens) -def test_get_key_error(ten_simple_docs): +def test_get_key_error(ten_simple_docs): # noqa: F811 store = ElasticV7DocIndex[SimpleDoc]() store.index(ten_simple_docs) @@ -168,7 +146,7 @@ def test_get_key_error(ten_simple_docs): store['not_a_real_id'] -def test_persisting(ten_simple_docs): +def test_persisting(ten_simple_docs): # noqa: F811 store = ElasticV7DocIndex[SimpleDoc](index_name='test_persisting') store.index(ten_simple_docs) @@ -176,7 +154,7 @@ def test_persisting(ten_simple_docs): assert store2.num_docs() == 10 -def test_del_single(ten_simple_docs): +def test_del_single(ten_simple_docs): # noqa: F811 store = ElasticV7DocIndex[SimpleDoc]() store.index(ten_simple_docs) # delete once @@ -204,7 +182,7 @@ def test_del_single(ten_simple_docs): assert np.all(store[id_].tens == d.tens) -def test_del_multiple(ten_simple_docs): +def test_del_multiple(ten_simple_docs): # noqa: F811 docs_to_del_idx = [0, 2, 4, 6, 8] store = ElasticV7DocIndex[SimpleDoc]() @@ -223,7 +201,7 @@ def test_del_multiple(ten_simple_docs): assert np.all(store[doc.id].tens == doc.tens) -def test_del_key_error(ten_simple_docs): +def test_del_key_error(ten_simple_docs): # noqa: F811 store = ElasticV7DocIndex[SimpleDoc]() store.index(ten_simple_docs) @@ -231,7 +209,7 @@ def 
test_del_key_error(ten_simple_docs): del store['not_a_real_id'] -def test_num_docs(ten_simple_docs): +def test_num_docs(ten_simple_docs): # noqa: F811 store = ElasticV7DocIndex[SimpleDoc]() store.index(ten_simple_docs) diff --git a/tests/index/elastic/v8/docker-compose.yml b/tests/index/elastic/v8/docker-compose.yml new file mode 100644 index 00000000000..70eedba34f5 --- /dev/null +++ b/tests/index/elastic/v8/docker-compose.yml @@ -0,0 +1,16 @@ +version: "3.3" +services: + elastic: + image: docker.elastic.co/elasticsearch/elasticsearch:8.6.2 + environment: + - xpack.security.enabled=false + - discovery.type=single-node + - ES_JAVA_OPTS=-Xmx1024m + ports: + - "9200:9200" + networks: + - elastic + +networks: + elastic: + name: elastic \ No newline at end of file diff --git a/tests/index/elastic/v8/test_column_config.py b/tests/index/elastic/v8/test_column_config.py new file mode 100644 index 00000000000..2b3bbcee0f8 --- /dev/null +++ b/tests/index/elastic/v8/test_column_config.py @@ -0,0 +1,131 @@ +import pytest +from pydantic import Field + +from docarray import BaseDoc +from docarray.index import ElasticDocIndex +from tests.index.elastic.fixture import start_storage_v8 # noqa: F401 + +pytestmark = [pytest.mark.slow, pytest.mark.index, pytest.mark.elasticv8] + + +def test_column_config(): + class MyDoc(BaseDoc): + text: str + color: str = Field(col_type='keyword') + + store = ElasticDocIndex[MyDoc]() + index_docs = [ + MyDoc(id='0', text='hello world', color='red'), + MyDoc(id='1', text='never gonna give you up', color='blue'), + MyDoc(id='2', text='we are the world', color='green'), + ] + store.index(index_docs) + + query = 'world' + docs, _ = store.text_search(query, search_field='text') + assert [doc.id for doc in docs] == ['0', '2'] + + filter_query = {'terms': {'color': ['red', 'blue']}} + docs = store.filter(filter_query) + assert [doc.id for doc in docs] == ['0', '1'] + + +def test_field_object(): + class MyDoc(BaseDoc): + manager: dict = Field( + 
properties={ + 'age': {'type': 'integer'}, + 'name': { + 'properties': { + 'first': {'type': 'keyword'}, + 'last': {'type': 'keyword'}, + } + }, + } + ) + + store = ElasticDocIndex[MyDoc]() + doc = [ + MyDoc(manager={'age': 25, 'name': {'first': 'Rachel', 'last': 'Green'}}), + MyDoc(manager={'age': 30, 'name': {'first': 'Monica', 'last': 'Geller'}}), + MyDoc(manager={'age': 35, 'name': {'first': 'Phoebe', 'last': 'Buffay'}}), + ] + store.index(doc) + id_ = doc[0].id + assert store[id_].id == id_ + assert store[id_].manager == doc[0].manager + + filter_query = {'range': {'manager.age': {'gte': 30}}} + docs = store.filter(filter_query) + assert [doc.id for doc in docs] == [doc[1].id, doc[2].id] + + +def test_field_geo_point(): + class MyDoc(BaseDoc): + location: dict = Field(col_type='geo_point') + + store = ElasticDocIndex[MyDoc]() + doc = [ + MyDoc(location={'lat': 40.12, 'lon': -72.34}), + MyDoc(location={'lat': 41.12, 'lon': -73.34}), + MyDoc(location={'lat': 42.12, 'lon': -74.34}), + ] + store.index(doc) + + query = { + 'query': { + 'geo_bounding_box': { + 'location': { + 'top_left': {'lat': 42, 'lon': -74}, + 'bottom_right': {'lat': 40, 'lon': -72}, + } + } + }, + } + + docs, _ = store.execute_query(query) + assert [doc['id'] for doc in docs] == [doc[0].id, doc[1].id] + + +def test_field_range(): + class MyDoc(BaseDoc): + expected_attendees: dict = Field(col_type='integer_range') + time_frame: dict = Field(col_type='date_range', format='yyyy-MM-dd') + + store = ElasticDocIndex[MyDoc]() + doc = [ + MyDoc( + expected_attendees={'gte': 10, 'lt': 20}, + time_frame={'gte': '2023-01-01', 'lt': '2023-02-01'}, + ), + MyDoc( + expected_attendees={'gte': 20, 'lt': 30}, + time_frame={'gte': '2023-02-01', 'lt': '2023-03-01'}, + ), + MyDoc( + expected_attendees={'gte': 30, 'lt': 40}, + time_frame={'gte': '2023-03-01', 'lt': '2023-04-01'}, + ), + ] + store.index(doc) + + query = { + 'query': { + 'bool': { + 'should': [ + {'term': {'expected_attendees': {'value': 15}}}, + { + 
'range': { + 'time_frame': { + 'gte': '2023-02-05', + 'lt': '2023-02-10', + 'relation': 'contains', + } + } + }, + ] + } + }, + } + docs, _ = store.execute_query(query) + assert [doc['id'] for doc in docs] == [doc[0].id, doc[1].id] diff --git a/tests/index/elastic/v8/test_find.py b/tests/index/elastic/v8/test_find.py new file mode 100644 index 00000000000..5ee0956bb87 --- /dev/null +++ b/tests/index/elastic/v8/test_find.py @@ -0,0 +1,329 @@ +import numpy as np +import pytest +import torch +from pydantic import Field + +from docarray import BaseDoc +from docarray.index import ElasticDocIndex +from docarray.typing import NdArray, TorchTensor +from tests.index.elastic.fixture import start_storage_v8 # noqa: F401 +from tests.index.elastic.fixture import FlatDoc, SimpleDoc + +pytestmark = [pytest.mark.slow, pytest.mark.index, pytest.mark.elasticv8] + + +@pytest.mark.parametrize('similarity', ['cosine', 'l2_norm', 'dot_product']) +def test_find_simple_schema(similarity): + class SimpleSchema(BaseDoc): + tens: NdArray[10] = Field(similarity=similarity) + + store = ElasticDocIndex[SimpleSchema]() + + index_docs = [] + for _ in range(10): + vec = np.random.rand(10) + if similarity == 'dot_product': + vec = vec / np.linalg.norm(vec) + index_docs.append(SimpleDoc(tens=vec)) + store.index(index_docs) + + query = index_docs[-1] + docs, scores = store.find(query, search_field='tens', limit=5) + + assert len(docs) == 5 + assert len(scores) == 5 + assert docs[0].id == index_docs[-1].id + assert np.allclose(docs[0].tens, index_docs[-1].tens) + + +@pytest.mark.parametrize('similarity', ['cosine', 'l2_norm', 'dot_product']) +def test_find_flat_schema(similarity): + class FlatSchema(BaseDoc): + tens_one: NdArray = Field(dims=10, similarity=similarity) + tens_two: NdArray = Field(dims=50, similarity=similarity) + + store = ElasticDocIndex[FlatSchema]() + + index_docs = [] + for _ in range(10): + vec_one = np.random.rand(10) + vec_two = np.random.rand(50) + if similarity == 
'dot_product': + vec_one = vec_one / np.linalg.norm(vec_one) + vec_two = vec_two / np.linalg.norm(vec_two) + index_docs.append(FlatDoc(tens_one=vec_one, tens_two=vec_two)) + + store.index(index_docs) + + query = index_docs[-1] + + # find on tens_one + docs, scores = store.find(query, search_field='tens_one', limit=5) + assert len(docs) == 5 + assert len(scores) == 5 + assert docs[0].id == index_docs[-1].id + assert np.allclose(docs[0].tens_one, index_docs[-1].tens_one) + assert np.allclose(docs[0].tens_two, index_docs[-1].tens_two) + + # find on tens_two + docs, scores = store.find(query, search_field='tens_two', limit=5) + assert len(docs) == 5 + assert len(scores) == 5 + assert docs[0].id == index_docs[-1].id + assert np.allclose(docs[0].tens_one, index_docs[-1].tens_one) + assert np.allclose(docs[0].tens_two, index_docs[-1].tens_two) + + +@pytest.mark.parametrize('similarity', ['cosine', 'l2_norm', 'dot_product']) +def test_find_nested_schema(similarity): + class SimpleDoc(BaseDoc): + tens: NdArray[10] = Field(similarity=similarity) + + class NestedDoc(BaseDoc): + d: SimpleDoc + tens: NdArray[10] = Field(similarity=similarity) + + class DeepNestedDoc(BaseDoc): + d: NestedDoc + tens: NdArray = Field(similarity=similarity, dims=10) + + store = ElasticDocIndex[DeepNestedDoc]() + + index_docs = [] + for _ in range(10): + vec_simple = np.random.rand(10) + vec_nested = np.random.rand(10) + vec_deep = np.random.rand(10) + if similarity == 'dot_product': + vec_simple = vec_simple / np.linalg.norm(vec_simple) + vec_nested = vec_nested / np.linalg.norm(vec_nested) + vec_deep = vec_deep / np.linalg.norm(vec_deep) + index_docs.append( + DeepNestedDoc( + d=NestedDoc(d=SimpleDoc(tens=vec_simple), tens=vec_nested), + tens=vec_deep, + ) + ) + + store.index(index_docs) + + query = index_docs[-1] + + # find on root level + docs, scores = store.find(query, search_field='tens', limit=5) + assert len(docs) == 5 + assert len(scores) == 5 + assert docs[0].id == index_docs[-1].id + 
assert np.allclose(docs[0].tens, index_docs[-1].tens) + + # find on first nesting level + docs, scores = store.find(query, search_field='d__tens', limit=5) + assert len(docs) == 5 + assert len(scores) == 5 + assert docs[0].id == index_docs[-1].id + assert np.allclose(docs[0].d.tens, index_docs[-1].d.tens) + + # find on second nesting level + docs, scores = store.find(query, search_field='d__d__tens', limit=5) + assert len(docs) == 5 + assert len(scores) == 5 + assert docs[0].id == index_docs[-1].id + assert np.allclose(docs[0].d.d.tens, index_docs[-1].d.d.tens) + + +def test_find_torch(): + class TorchDoc(BaseDoc): + tens: TorchTensor[10] + + store = ElasticDocIndex[TorchDoc]() + + # A dense_vector field stores dense vectors of float values. + index_docs = [ + TorchDoc(tens=np.random.rand(10).astype(dtype=np.float32)) for _ in range(10) + ] + store.index(index_docs) + + for doc in index_docs: + assert isinstance(doc.tens, TorchTensor) + + query = index_docs[-1] + docs, scores = store.find(query, search_field='tens', limit=5) + + assert len(docs) == 5 + assert len(scores) == 5 + for doc in docs: + assert isinstance(doc.tens, TorchTensor) + + assert docs[0].id == index_docs[-1].id + assert torch.allclose(docs[0].tens, index_docs[-1].tens) + + +def test_find_tensorflow(): + from docarray.typing import TensorFlowTensor + + class TfDoc(BaseDoc): + tens: TensorFlowTensor[10] + + store = ElasticDocIndex[TfDoc]() + + index_docs = [ + TfDoc(tens=np.random.rand(10).astype(dtype=np.float32)) for _ in range(10) + ] + store.index(index_docs) + + for doc in index_docs: + assert isinstance(doc.tens, TensorFlowTensor) + + query = index_docs[-1] + docs, scores = store.find(query, search_field='tens', limit=5) + + assert len(docs) == 5 + assert len(scores) == 5 + for doc in docs: + assert isinstance(doc.tens, TensorFlowTensor) + + assert docs[0].id == index_docs[-1].id + assert np.allclose( + docs[0].tens.unwrap().numpy(), index_docs[-1].tens.unwrap().numpy() + ) + + +def 
test_find_batched(): + store = ElasticDocIndex[SimpleDoc]() + + index_docs = [SimpleDoc(tens=np.random.rand(10)) for _ in range(10)] + store.index(index_docs) + + queries = index_docs[-2:] + docs_batched, scores_batched = store.find_batched( + queries, search_field='tens', limit=5 + ) + + for docs, scores, query in zip(docs_batched, scores_batched, queries): + assert len(docs) == 5 + assert len(scores) == 5 + assert docs[0].id == query.id + assert np.allclose(docs[0].tens, query.tens) + + +def test_filter(): + class MyDoc(BaseDoc): + A: bool + B: int + C: float + + store = ElasticDocIndex[MyDoc]() + + index_docs = [MyDoc(id=f'{i}', A=(i % 2 == 0), B=i, C=i + 0.5) for i in range(10)] + store.index(index_docs) + + filter_query = {'term': {'A': True}} + docs = store.filter(filter_query) + assert len(docs) > 0 + for doc in docs: + assert doc.A + + filter_query = { + 'bool': { + 'filter': [ + {'terms': {'B': [3, 4, 7, 8]}}, + {'range': {'C': {'gte': 3, 'lte': 5}}}, + ] + } + } + docs = store.filter(filter_query) + assert [doc.id for doc in docs] == ['3', '4'] + + +def test_text_search(): + class MyDoc(BaseDoc): + text: str + + store = ElasticDocIndex[MyDoc]() + index_docs = [ + MyDoc(text='hello world'), + MyDoc(text='never gonna give you up'), + MyDoc(text='we are the world'), + ] + store.index(index_docs) + + query = 'world' + docs, scores = store.text_search(query, search_field='text') + + assert len(docs) == 2 + assert len(scores) == 2 + assert docs[0].text.index(query) >= 0 + assert docs[1].text.index(query) >= 0 + + queries = ['world', 'never'] + docs, scores = store.text_search_batched(queries, search_field='text') + for query, da, score in zip(queries, docs, scores): + assert len(da) > 0 + assert len(score) > 0 + for doc in da: + assert doc.text.index(query) >= 0 + + +def test_query_builder(): + class MyDoc(BaseDoc): + tens: NdArray[10] = Field(similarity='l2_norm') + num: int + text: str + + store = ElasticDocIndex[MyDoc]() + index_docs = [ + MyDoc(id=f'{i}', 
tens=np.ones(10) * i, num=int(i / 2), text=f'text {int(i/2)}') + for i in range(10) + ] + store.index(index_docs) + + # build_query + q = store.build_query() + assert isinstance(q, store.QueryBuilder) + + # filter + q = store.build_query().filter({'term': {'num': 0}}).build() + docs, _ = store.execute_query(q) + assert [doc['id'] for doc in docs] == ['0', '1'] + + # find + q = store.build_query().find(index_docs[-1], search_field='tens', limit=3).build() + docs, _ = store.execute_query(q) + assert [doc['id'] for doc in docs] == ['9', '8', '7'] + + # text_search + q = store.build_query().text_search('0', search_field='text').build() + docs, _ = store.execute_query(q) + assert [doc['id'] for doc in docs] == ['0', '1'] + + # combination + q = ( + store.build_query() + .filter({'range': {'num': {'lte': 3}}}) + .find(index_docs[-1], search_field='tens') + .text_search('0', search_field='text') + .build() + ) + docs, _ = store.execute_query(q) + assert [doc['id'] for doc in docs] == ['1', '0'] + + # direct + query = { + 'knn': { + 'field': 'tens', + 'query_vector': [9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0], + 'k': 10, + 'num_candidates': 10000, + 'filter': { + 'bool': { + 'filter': [ + {'range': {'num': {'gte': 2}}}, + {'range': {'num': {'lte': 3}}}, + ] + } + }, + }, + } + + docs, _ = store.execute_query(query) + assert [doc['id'] for doc in docs] == ['7', '6', '5', '4'] diff --git a/tests/index/elastic/v8/test_index_get_del.py b/tests/index/elastic/v8/test_index_get_del.py new file mode 100644 index 00000000000..db2df925ebb --- /dev/null +++ b/tests/index/elastic/v8/test_index_get_del.py @@ -0,0 +1,272 @@ +from typing import Union + +import numpy as np +import pytest + +from docarray import BaseDoc, DocList +from docarray.documents import ImageDoc, TextDoc +from docarray.index import ElasticDocIndex +from docarray.typing import NdArray +from tests.index.elastic.fixture import ( # noqa: F401 + DeepNestedDoc, + FlatDoc, + NestedDoc, + SimpleDoc, + 
start_storage_v8, + ten_deep_nested_docs, + ten_flat_docs, + ten_nested_docs, + ten_simple_docs, +) + +pytestmark = [pytest.mark.slow, pytest.mark.index, pytest.mark.elasticv8] + + +@pytest.mark.parametrize('use_docarray', [True, False]) +def test_index_simple_schema(ten_simple_docs, use_docarray): # noqa: F811 + store = ElasticDocIndex[SimpleDoc]() + if use_docarray: + ten_simple_docs = DocList[SimpleDoc](ten_simple_docs) + + store.index(ten_simple_docs) + assert store.num_docs() == 10 + + +@pytest.mark.parametrize('use_docarray', [True, False]) +def test_index_flat_schema(ten_flat_docs, use_docarray): # noqa: F811 + store = ElasticDocIndex[FlatDoc]() + if use_docarray: + ten_flat_docs = DocList[FlatDoc](ten_flat_docs) + + store.index(ten_flat_docs) + assert store.num_docs() == 10 + + +@pytest.mark.parametrize('use_docarray', [True, False]) +def test_index_nested_schema(ten_nested_docs, use_docarray): # noqa: F811 + store = ElasticDocIndex[NestedDoc]() + if use_docarray: + ten_nested_docs = DocList[NestedDoc](ten_nested_docs) + + store.index(ten_nested_docs) + assert store.num_docs() == 10 + + +@pytest.mark.parametrize('use_docarray', [True, False]) +def test_index_deep_nested_schema(ten_deep_nested_docs, use_docarray): # noqa: F811 + store = ElasticDocIndex[DeepNestedDoc]() + if use_docarray: + ten_deep_nested_docs = DocList[DeepNestedDoc](ten_deep_nested_docs) + + store.index(ten_deep_nested_docs) + assert store.num_docs() == 10 + + +def test_get_single(ten_simple_docs, ten_flat_docs, ten_nested_docs): # noqa: F811 + # simple + store = ElasticDocIndex[SimpleDoc]() + store.index(ten_simple_docs) + + assert store.num_docs() == 10 + for d in ten_simple_docs: + id_ = d.id + assert store[id_].id == id_ + assert np.all(store[id_].tens == d.tens) + + # flat + store = ElasticDocIndex[FlatDoc]() + store.index(ten_flat_docs) + + assert store.num_docs() == 10 + for d in ten_flat_docs: + id_ = d.id + assert store[id_].id == id_ + assert np.all(store[id_].tens_one == 
d.tens_one) + assert np.all(store[id_].tens_two == d.tens_two) + + # nested + store = ElasticDocIndex[NestedDoc]() + store.index(ten_nested_docs) + + assert store.num_docs() == 10 + for d in ten_nested_docs: + id_ = d.id + assert store[id_].id == id_ + assert store[id_].d.id == d.d.id + assert np.all(store[id_].d.tens == d.d.tens) + + +def test_get_multiple(ten_simple_docs, ten_flat_docs, ten_nested_docs): # noqa: F811 + docs_to_get_idx = [0, 2, 4, 6, 8] + + # simple + store = ElasticDocIndex[SimpleDoc]() + store.index(ten_simple_docs) + + assert store.num_docs() == 10 + docs_to_get = [ten_simple_docs[i] for i in docs_to_get_idx] + ids_to_get = [d.id for d in docs_to_get] + retrieved_docs = store[ids_to_get] + for id_, d_in, d_out in zip(ids_to_get, docs_to_get, retrieved_docs): + assert d_out.id == id_ + assert np.all(d_out.tens == d_in.tens) + + # flat + store = ElasticDocIndex[FlatDoc]() + store.index(ten_flat_docs) + + assert store.num_docs() == 10 + docs_to_get = [ten_flat_docs[i] for i in docs_to_get_idx] + ids_to_get = [d.id for d in docs_to_get] + retrieved_docs = store[ids_to_get] + for id_, d_in, d_out in zip(ids_to_get, docs_to_get, retrieved_docs): + assert d_out.id == id_ + assert np.all(d_out.tens_one == d_in.tens_one) + assert np.all(d_out.tens_two == d_in.tens_two) + + # nested + store = ElasticDocIndex[NestedDoc]() + store.index(ten_nested_docs) + + assert store.num_docs() == 10 + docs_to_get = [ten_nested_docs[i] for i in docs_to_get_idx] + ids_to_get = [d.id for d in docs_to_get] + retrieved_docs = store[ids_to_get] + for id_, d_in, d_out in zip(ids_to_get, docs_to_get, retrieved_docs): + assert d_out.id == id_ + assert d_out.d.id == d_in.d.id + assert np.all(d_out.d.tens == d_in.d.tens) + + +def test_get_key_error(ten_simple_docs): # noqa: F811 + store = ElasticDocIndex[SimpleDoc]() + store.index(ten_simple_docs) + + with pytest.raises(KeyError): + store['not_a_real_id'] + + +def test_persisting(ten_simple_docs): # noqa: F811 + store = 
ElasticDocIndex[SimpleDoc](index_name='test_persisting') + store.index(ten_simple_docs) + + store2 = ElasticDocIndex[SimpleDoc](index_name='test_persisting') + assert store2.num_docs() == 10 + + +def test_del_single(ten_simple_docs): # noqa: F811 + store = ElasticDocIndex[SimpleDoc]() + store.index(ten_simple_docs) + # delete once + assert store.num_docs() == 10 + del store[ten_simple_docs[0].id] + assert store.num_docs() == 9 + for i, d in enumerate(ten_simple_docs): + id_ = d.id + if i == 0: # deleted + with pytest.raises(KeyError): + store[id_] + else: + assert store[id_].id == id_ + assert np.all(store[id_].tens == d.tens) + # delete again + del store[ten_simple_docs[3].id] + assert store.num_docs() == 8 + for i, d in enumerate(ten_simple_docs): + id_ = d.id + if i in (0, 3): # deleted + with pytest.raises(KeyError): + store[id_] + else: + assert store[id_].id == id_ + assert np.all(store[id_].tens == d.tens) + + +def test_del_multiple(ten_simple_docs): # noqa: F811 + docs_to_del_idx = [0, 2, 4, 6, 8] + + store = ElasticDocIndex[SimpleDoc]() + store.index(ten_simple_docs) + + assert store.num_docs() == 10 + docs_to_del = [ten_simple_docs[i] for i in docs_to_del_idx] + ids_to_del = [d.id for d in docs_to_del] + del store[ids_to_del] + for i, doc in enumerate(ten_simple_docs): + if i in docs_to_del_idx: + with pytest.raises(KeyError): + store[doc.id] + else: + assert store[doc.id].id == doc.id + assert np.all(store[doc.id].tens == doc.tens) + + +def test_del_key_error(ten_simple_docs): # noqa: F811 + store = ElasticDocIndex[SimpleDoc]() + store.index(ten_simple_docs) + + with pytest.warns(UserWarning): + del store['not_a_real_id'] + + +def test_num_docs(ten_simple_docs): # noqa: F811 + store = ElasticDocIndex[SimpleDoc]() + store.index(ten_simple_docs) + + assert store.num_docs() == 10 + + del store[ten_simple_docs[0].id] + assert store.num_docs() == 9 + + del store[ten_simple_docs[3].id, ten_simple_docs[5].id] + assert store.num_docs() == 7 + + more_docs = 
[SimpleDoc(tens=np.random.rand(10)) for _ in range(5)] + store.index(more_docs) + assert store.num_docs() == 12 + + del store[more_docs[2].id, ten_simple_docs[7].id] + assert store.num_docs() == 10 + + +def test_index_union_doc(): # noqa: F811 + class MyDoc(BaseDoc): + tensor: Union[NdArray, str] + + class MySchema(BaseDoc): + tensor: NdArray + + store = ElasticDocIndex[MySchema]() + doc = [MyDoc(tensor=np.random.randn(128))] + store.index(doc) + + id_ = doc[0].id + assert store[id_].id == id_ + assert np.all(store[id_].tensor == doc[0].tensor) + + +def test_index_multi_modal_doc(): + class MyMultiModalDoc(BaseDoc): + image: ImageDoc + text: TextDoc + + store = ElasticDocIndex[MyMultiModalDoc]() + + doc = [ + MyMultiModalDoc( + image=ImageDoc(embedding=np.random.randn(128)), text=TextDoc(text='hello') + ) + ] + store.index(doc) + + id_ = doc[0].id + assert store[id_].id == id_ + assert np.all(store[id_].image.embedding == doc[0].image.embedding) + assert store[id_].text.text == doc[0].text.text + + +def test_elasticv7_version_check(): + with pytest.raises(ImportError): + from docarray.index import ElasticV7DocIndex + + _ = ElasticV7DocIndex[SimpleDoc]() diff --git a/tests/integrations/doc_index/elastic/v7/__init__.py b/tests/integrations/doc_index/elastic/v7/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000