diff --git a/docarray/index/__init__.py b/docarray/index/__init__.py index 5fdbf8ad736..4a9986e1a97 100644 --- a/docarray/index/__init__.py +++ b/docarray/index/__init__.py @@ -1,3 +1,4 @@ +from docarray.index.backends.elastic import ElasticV7DocIndex from docarray.index.backends.hnswlib import HnswDocumentIndex -__all__ = ['HnswDocumentIndex'] +__all__ = ['HnswDocumentIndex', 'ElasticV7DocIndex'] diff --git a/docarray/index/abstract.py b/docarray/index/abstract.py index 3f0ab25a1b0..a037b070527 100644 --- a/docarray/index/abstract.py +++ b/docarray/index/abstract.py @@ -9,6 +9,7 @@ Generic, Iterable, List, + Mapping, NamedTuple, Optional, Sequence, @@ -233,13 +234,13 @@ def _find( @abstractmethod def _find_batched( self, - query: np.ndarray, + queries: np.ndarray, limit: int, search_field: str = '', ) -> _FindResultBatched: """Find documents in the index - :param query: query vectors for KNN/ANN search. + :param queries: query vectors for KNN/ANN search. Has shape (batch_size, vector_dim) :param limit: maximum number of documents to return :param search_field: name of the field to search on @@ -593,7 +594,7 @@ def _get_values_by_column(docs: Sequence[BaseDoc], col_name: str) -> List[Any]: @staticmethod def _transpose_col_value_dict( - col_value_dict: Dict[str, Iterable[Any]] + col_value_dict: Mapping[str, Iterable[Any]] ) -> Generator[Dict[str, Any], None, None]: """'Transpose' the output of `_get_col_value_dict()`: Yield rows of columns, where each row represent one Document. Since a generator is returned, this process comes at negligible cost. diff --git a/docarray/index/backends/elastic.py b/docarray/index/backends/elastic.py new file mode 100644 index 00000000000..deefc3b2a86 --- /dev/null +++ b/docarray/index/backends/elastic.py @@ -0,0 +1,546 @@ +import os +import uuid +import warnings +from collections import defaultdict +from dataclasses import dataclass, field +from typing import ( + Any, + Dict, + Generator, + Generic, + Iterable, + List, + Mapping, + Optional, + Sequence, + Tuple, + Type, + TypeVar, + Union, + cast, +) + +import numpy as np +from elasticsearch import Elasticsearch +from elasticsearch.helpers import parallel_bulk +from pydantic import parse_obj_as + +import docarray.typing +from docarray import BaseDoc +from docarray.index.abstract import ( + BaseDocIndex, + _ColumnInfo, + _FindResultBatched, + _raise_not_composable, +) +from docarray.typing import AnyTensor +from docarray.typing.tensor.abstract_tensor import AbstractTensor +from docarray.typing.tensor.ndarray import NdArray +from docarray.utils._internal.misc import is_tf_available, is_torch_available +from docarray.utils.find import _FindResult + +TSchema = TypeVar('TSchema', bound=BaseDoc) +T = TypeVar('T', bound='ElasticV7DocIndex') + +ELASTIC_PY_VEC_TYPES: List[Any] = [list, tuple, np.ndarray, AbstractTensor] + +if is_torch_available(): + import torch + + ELASTIC_PY_VEC_TYPES.append(torch.Tensor) + +if is_tf_available(): + import tensorflow as tf # type: ignore + + from docarray.typing import TensorFlowTensor + + ELASTIC_PY_VEC_TYPES.append(tf.Tensor) + ELASTIC_PY_VEC_TYPES.append(TensorFlowTensor) + + +class ElasticV7DocIndex(BaseDocIndex, Generic[TSchema]): + def __init__(self, db_config=None, **kwargs): + super().__init__(db_config=db_config, **kwargs) + self._db_config = cast(ElasticV7DocIndex.DBConfig, self._db_config) + + if self._db_config.index_name is None: + id = uuid.uuid4().hex + self._db_config.index_name = 'index__' + id + + self._index_name = self._db_config.index_name + + self._client = 
Elasticsearch(
+            hosts=self._db_config.hosts,
+            **self._db_config.es_config,
+        )
+
+        # compatibility: when talking to an 8.x server with this 7.x client,
+        # enable the client's REST API compatibility mode
+        self._server_version = self._client.info()['version']['number']
+        if int(self._server_version.split('.')[0]) >= 8:
+            os.environ['ELASTIC_CLIENT_APIVERSIONING'] = '1'
+
+        body: Dict[str, Any] = {
+            'mappings': {
+                'dynamic': True,
+                '_source': {'enabled': 'true'},
+                'properties': {},
+            }
+        }
+
+        for col_name, col in self._column_infos.items():
+            body['mappings']['properties'][col_name] = self._create_index_mapping(col)
+
+        if self._client.indices.exists(index=self._index_name):
+            self._client.indices.put_mapping(
+                index=self._index_name, body=body['mappings']
+            )
+        else:
+            self._client.indices.create(index=self._index_name, body=body)
+
+        if len(self._db_config.index_settings):
+            self._client.indices.put_settings(
+                index=self._index_name, body=self._db_config.index_settings
+            )
+
+        self._refresh(self._index_name)
+
+    ###############################################
+    # Inner classes for query builder and configs #
+    ###############################################
+
+    class QueryBuilder(BaseDocIndex.QueryBuilder):
+        def __init__(self, outer_instance, **kwargs):
+            super().__init__()
+            self._outer_instance = outer_instance
+            self._query: Dict[str, Any] = {
+                'query': defaultdict(lambda: defaultdict(list))
+            }
+
+        def build(self, *args, **kwargs) -> Any:
+            if (
+                'script_score' in self._query['query']
+                and 'bool' in self._query['query']
+                and len(self._query['query']['bool']) > 0
+            ):
+                self._query['query']['script_score']['query'] = {}
+                self._query['query']['script_score']['query']['bool'] = self._query[
+                    'query'
+                ]['bool']
+                del self._query['query']['bool']
+
+            return self._query
+
+        def find(
+            self,
+            query: Union[AnyTensor, BaseDoc],
+            search_field: str = 'embedding',
+            limit: int = 10,
+        ):
+            if isinstance(query, BaseDoc):
+                query_vec = BaseDocIndex._get_values_by_column([query], search_field)[0]
+            else:
+                query_vec = query
+            query_vec_np = BaseDocIndex._to_numpy(self._outer_instance, query_vec)
+            self._query['size'] = limit
+            self._query['query']['script_score'] = ElasticV7DocIndex._form_search_body(
+                query_vec_np, limit, search_field
+            )['query']['script_score']
+
+            return self
+
+        def filter(self, query: Dict[str, Any], limit: int = 10):
+            self._query['size'] = limit
+            self._query['query']['bool']['filter'].append(query)
+            return self
+
+        def text_search(self, query: str, search_field: str = 'text', limit: int = 10):
+            self._query['size'] = limit
+            self._query['query']['bool']['must'].append(
+                {'match': {search_field: query}}
+            )
+            return self
+
+        find_batched = _raise_not_composable('find_batched')
+        filter_batched = _raise_not_composable('filter_batched')
+        text_search_batched = _raise_not_composable('text_search_batched')
+
+    def build_query(self, **kwargs) -> QueryBuilder:
+        """
+        Build a query for this DocumentIndex.
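+
+        A minimal usage sketch, mirroring the query-builder test in this PR
+        (``MyDoc`` and its fields ``num``, ``tens`` and ``text`` are
+        illustrative placeholders for a user-defined schema; assumes
+        ``import numpy as np``):
+
+        .. code-block:: python
+
+            index = ElasticV7DocIndex[MyDoc]()
+            q = (
+                index.build_query()
+                .filter({'range': {'num': {'lte': 3}}})
+                .find(np.ones(10), search_field='tens')
+                .text_search('world', search_field='text')
+                .build()
+            )
+            docs, scores = index.execute_query(q)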
+ """ + return self.QueryBuilder(self, **kwargs) + + @dataclass + class DBConfig(BaseDocIndex.DBConfig): + hosts: Union[str, List[str], None] = 'http://localhost:9200' + index_name: Optional[str] = None + es_config: Dict[str, Any] = field(default_factory=dict) + index_settings: Dict[str, Any] = field(default_factory=dict) + + @dataclass + class RuntimeConfig(BaseDocIndex.RuntimeConfig): + default_column_config: Dict[Any, Dict[str, Any]] = field( + default_factory=lambda: { + 'binary': {}, + 'boolean': {}, + 'keyword': {}, + 'long': {}, + 'integer': {}, + 'short': {}, + 'byte': {}, + 'double': {}, + 'float': {}, + 'half_float': {}, + 'scaled_float': {}, + 'unsigned_long': {}, + 'dates': {}, + 'alias': {}, + 'object': {}, + 'flattened': {}, + 'nested': {}, + 'join': {}, + 'integer_range': {}, + 'float_range': {}, + 'long_range': {}, + 'double_range': {}, + 'date_range': {}, + 'ip': {}, + 'version': {}, + 'histogram': {}, + 'text': {}, + 'annotated_text': {}, + 'completion': {}, + 'search_as_you_type': {}, + 'token_count': {}, + 'dense_vector': {'dims': 128}, + 'sparse_vector': {}, + 'rank_feature': {}, + 'rank_features': {}, + 'geo_point': {}, + 'geo_shape': {}, + 'point': {}, + 'shape': {}, + 'percolator': {}, + # `None` is not a Type, but we allow it here anyway + None: {}, # type: ignore + } + ) + chunk_size: int = 500 + + ############################################### + # Implementation of abstract methods # + ############################################### + + def python_type_to_db_type(self, python_type: Type) -> Any: + """Map python type to database type.""" + + for allowed_type in ELASTIC_PY_VEC_TYPES: + if issubclass(python_type, allowed_type): + return 'dense_vector' + + elastic_py_types = { + docarray.typing.ID: 'keyword', + docarray.typing.AnyUrl: 'keyword', + bool: 'boolean', + int: 'integer', + float: 'float', + str: 'text', + bytes: 'binary', + dict: 'object', + } + + for type in elastic_py_types.keys(): + if issubclass(python_type, type): + return elastic_py_types[type] + + raise ValueError(f'Unsupported column type for {type(self)}: {python_type}') + + def _index( + self, + column_to_data: Mapping[str, Generator[Any, None, None]], + refresh: bool = True, + chunk_size: Optional[int] = None, + ): + + data = self._transpose_col_value_dict(column_to_data) + requests = [] + + for row in data: + request = { + '_index': self._index_name, + '_id': row['id'], + } + for col_name, col in self._column_infos.items(): + if col.db_type == 'dense_vector' and np.all(row[col_name] == 0): + row[col_name] = row[col_name] + 1.0e-9 + if row[col_name] is None: + continue + request[col_name] = row[col_name] + requests.append(request) + + _, warning_info = self._send_requests(requests, chunk_size) + for info in warning_info: + warnings.warn(str(info)) + + if refresh: + self._refresh(self._index_name) + + def num_docs(self) -> int: + return self._client.count(index=self._index_name)['count'] + + def _del_items( + self, + doc_ids: Sequence[str], + chunk_size: Optional[int] = None, + ): + requests = [] + for _id in doc_ids: + requests.append( + {'_op_type': 'delete', '_index': self._index_name, '_id': _id} + ) + + _, warning_info = self._send_requests(requests, chunk_size) + + # raise warning if some ids are not found + if warning_info: + ids = [info['delete']['_id'] for info in warning_info] + warnings.warn(f'No document with id {ids} found') + + self._refresh(self._index_name) + + def _get_items(self, doc_ids: Sequence[str]) -> Sequence[TSchema]: + accumulated_docs = [] + 
+        accumulated_docs_id_not_found = []
+
+        es_rows = self._client.mget(
+            index=self._index_name,
+            body={'ids': doc_ids},
+        )['docs']
+
+        for row in es_rows:
+            if row['found']:
+                doc_dict = row['_source']
+                accumulated_docs.append(doc_dict)
+            else:
+                accumulated_docs_id_not_found.append(row['_id'])
+
+        # raise a warning if some ids were not found
+        if accumulated_docs_id_not_found:
+            warnings.warn(f'No documents with ids {accumulated_docs_id_not_found} found')
+
+        return accumulated_docs
+
+    def execute_query(self, query: Dict[str, Any], *args, **kwargs) -> Any:
+        if args or kwargs:
+            raise ValueError(
+                f'args and kwargs not supported for `execute_query` on {type(self)}'
+            )
+
+        resp = self._client.search(index=self._index_name, body=query)
+        docs, scores = self._format_response(resp)
+
+        return _FindResult(documents=docs, scores=scores)
+
+    def _find(
+        self, query: np.ndarray, limit: int, search_field: str = ''
+    ) -> _FindResult:
+        if int(self._server_version.split('.')[0]) >= 8:
+            warnings.warn(
+                'You are connecting to an Elasticsearch 8.0+ server with the '
+                '7.10.1 client, which does not support HNSW-based vector search. '
+                '`find` falls back to an exhaustive KNN search using '
+                'cosineSimilarity, which may be slow.'
+            )
+
+        body = self._form_search_body(query, limit, search_field)
+
+        resp = self._client.search(
+            index=self._index_name,
+            body=body,
+        )
+
+        docs, scores = self._format_response(resp)
+
+        return _FindResult(documents=docs, scores=scores)
+
+    def _find_batched(
+        self,
+        queries: np.ndarray,
+        limit: int,
+        search_field: str = '',
+    ) -> _FindResultBatched:
+        request = []
+        for query in queries:
+            head = {'index': self._index_name}
+            body = self._form_search_body(query, limit, search_field)
+            request.extend([head, body])
+
+        responses = self._client.msearch(body=request)
+
+        das, scores = zip(
+            *[self._format_response(resp) for resp in responses['responses']]
+        )
+        return _FindResultBatched(documents=list(das), scores=np.array(scores))
+
+    def _filter(
+        self,
+        filter_query: Dict[str, Any],
+        limit: int,
+    ) -> List[Dict]:
+        body = {
+            'size': limit,
+            'query': filter_query,
+        }
+
+        resp = self._client.search(
+            index=self._index_name,
+            body=body,
+        )
+
+        docs, _ = self._format_response(resp)
+
+        return docs
+
+    def _filter_batched(
+        self,
+        filter_queries: Any,
+        limit: int,
+    ) -> List[List[Dict]]:
+        request = []
+        for query in filter_queries:
+            head = {'index': self._index_name}
+            body = {'query': query, 'size': limit}
+            request.extend([head, body])
+
+        responses = self._client.msearch(body=request)
+        das, _ = zip(*[self._format_response(resp) for resp in responses['responses']])
+
+        return list(das)
+
+    def _text_search(
+        self,
+        query: str,
+        limit: int,
+        search_field: str = '',
+    ) -> _FindResult:
+
+        body = self._form_text_search_body(query, limit, search_field)
+
+        resp = self._client.search(
+            index=self._index_name,
+            body=body,
+        )
+
+        docs, scores = self._format_response(resp)
+
+        return _FindResult(documents=docs, scores=scores)
+
+    def _text_search_batched(
+        self,
+        queries: Sequence[str],
+        limit: int,
+        search_field: str = '',
+    ) -> _FindResultBatched:
+        request = []
+        for query in queries:
+            head = {'index': self._index_name}
+            body = self._form_text_search_body(query, limit, search_field)
+            request.extend([head, body])
+
+        responses = self._client.msearch(body=request)
+
+        das, scores = zip(
+            *[self._format_response(resp) for resp in responses['responses']]
+        )
+        return _FindResultBatched(documents=list(das), scores=np.array(scores))
+
+    ###############################################
+    # Helpers                                     #
+    ###############################################
+
+    # Elasticsearch helpers
+    def _create_index_mapping(self, col: '_ColumnInfo') -> Dict[str, Any]:
+        """Create the Elasticsearch mapping for a single column."""
+
+        index = col.config.copy()
+        if 'type' not in index:
+            index['type'] = col.db_type
+
+        if col.db_type == 'dense_vector' and col.n_dim:
+            index['dims'] = col.n_dim
+
+        return index
+
+    def _send_requests(
+        self,
+        request: Iterable[Dict[str, Any]],
+        chunk_size: Optional[int] = None,
+        **kwargs,
+    ) -> Tuple[List[Dict], List[Any]]:
+        """Send a bulk request to Elasticsearch, separating successful from failed actions."""
+
+        accumulated_info = []
+        warning_info = []
+        for success, info in parallel_bulk(
+            self._client,
+            request,
+            raise_on_error=False,
+            raise_on_exception=False,
+            chunk_size=chunk_size if chunk_size else self._runtime_config.chunk_size,  # type: ignore
+            **kwargs,
+        ):
+            if not success:
+                warning_info.append(info)
+            else:
+                accumulated_info.append(info)
+
+        return accumulated_info, warning_info
+
+    @staticmethod
+    def _form_search_body(
+        query: np.ndarray, limit: int, search_field: str = ''
+    ) -> Dict[str, Any]:
+        # `+ 1.0` shifts cosine similarity from [-1, 1] to [0, 2],
+        # since Elasticsearch does not allow negative scores
+        body = {
+            'size': limit,
+            'query': {
+                'script_score': {
+                    'query': {'match_all': {}},
+                    'script': {
+                        'source': f'cosineSimilarity(params.query_vector, \'{search_field}\') + 1.0',
+                        'params': {'query_vector': query},
+                    },
+                }
+            },
+        }
+        return body
+
+    @staticmethod
+    def _form_text_search_body(
+        query: str, limit: int, search_field: str = ''
+    ) -> Dict[str, Any]:
+        body = {
+            'size': limit,
+            'query': {
+                'bool': {
+                    'must': {'match': {search_field: query}},
+                }
+            },
+        }
+        return body
+
+    def _format_response(self, response: Any) -> Tuple[List[Dict], NdArray]:
+        docs = []
+        scores = []
+        for result in response['hits']['hits']:
+            if not isinstance(result, dict):
+                result = result.to_dict()
+
+            if result.get('_source', None):
+                doc_dict = result['_source']
+            else:
+                doc_dict = result['fields']
+            doc_dict['id'] = result['_id']
+            docs.append(doc_dict)
+            scores.append(result['_score'])
+
+        return docs, parse_obj_as(NdArray, scores)
+
+    def _refresh(self, index_name: str):
+        self._client.indices.refresh(index=index_name)
diff --git a/docarray/index/backends/hnswlib.py b/docarray/index/backends/hnswlib.py
index 5dfbc987741..851fdcda3ef 100644
--- a/docarray/index/backends/hnswlib.py
+++ b/docarray/index/backends/hnswlib.py
@@ -232,12 +232,12 @@ def execute_query(self, query: List[Tuple[str, Dict]], *args, **kwargs) -> Any:
 
     def _find_batched(
         self,
-        query: np.ndarray,
+        queries: np.ndarray,
         limit: int,
         search_field: str = '',
     ) -> _FindResultBatched:
         index = self._hnsw_indices[search_field]
-        labels, distances = index.knn_query(query, k=limit)
+        labels, distances = index.knn_query(queries, k=limit)
         result_das = [
             self._get_docs_sqlite_hashed_id(
                 ids_per_query.tolist(),
@@ -251,7 +251,7 @@ def _find(
     ) -> _FindResult:
         query_batched = np.expand_dims(query, axis=0)
         docs, scores = self._find_batched(
-            query=query_batched, limit=limit, search_field=search_field
+            queries=query_batched, limit=limit, search_field=search_field
         )
         return _FindResult(documents=docs[0], scores=scores[0])
diff --git a/poetry.lock b/poetry.lock
index 2f89d1231dc..93569cb83c4 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.4.1 and should not be changed by hand.
[[package]] name = "aiohttp" @@ -773,6 +773,28 @@ six = ">=1.9.0" gmpy = ["gmpy"] gmpy2 = ["gmpy2"] +[[package]] +name = "elasticsearch" +version = "7.10.1" +description = "Python client for Elasticsearch" +category = "main" +optional = true +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, <4" +files = [ + {file = "elasticsearch-7.10.1-py2.py3-none-any.whl", hash = "sha256:4ebd34fd223b31c99d9f3b6b6236d3ac18b3046191a37231e8235b06ae7db955"}, + {file = "elasticsearch-7.10.1.tar.gz", hash = "sha256:a725dd923d349ca0652cf95d6ce23d952e2153740cf4ab6daf4a2d804feeed48"}, +] + +[package.dependencies] +certifi = "*" +urllib3 = ">=1.21.1,<2" + +[package.extras] +async = ["aiohttp (>=3,<4)"] +develop = ["black", "coverage", "jinja2", "mock", "pytest", "pytest-cov", "pyyaml", "requests (>=2.0.0,<3.0.0)", "sphinx (<1.7)", "sphinx-rtd-theme"] +docs = ["sphinx (<1.7)", "sphinx-rtd-theme"] +requests = ["requests (>=2.4.0,<3.0.0)"] + [[package]] name = "entrypoints" version = "0.4" @@ -4016,8 +4038,9 @@ testing = ["flake8 (<5)", "func-timeout", "jaraco.functools", "jaraco.itertools" [extras] audio = ["pydub"] aws = ["smart-open"] -common = ["protobuf", "lz4"] -full = ["protobuf", "lz4", "pillow", "types-pillow", "av", "pydub", "trimesh"] +common = ["lz4", "protobuf"] +elasticsearch = ["elasticsearch"] +full = ["av", "lz4", "pillow", "protobuf", "pydub", "trimesh", "types-pillow"] hnswlib = ["hnswlib"] image = ["pillow", "types-pillow"] jac = ["jina-hubble-sdk"] @@ -4030,4 +4053,4 @@ web = ["fastapi"] [metadata] lock-version = "2.0" python-versions = ">=3.7,<4.0" -content-hash = "c5b13c9b48aa9edf9d494ce8ba91cfdd9f78d4220ae758ac8a74b69963fc7253" +content-hash = "3388af41b53300637299b44bd9ea94a18a44368bb1006ee16e04d35c99a238b6" diff --git a/pyproject.toml b/pyproject.toml index 9e30c3a804e..78930d3d464 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ hnswlib = {version = ">=0.6.2", optional = true } lz4 = {version= ">=1.0.0", optional = true} pydub = {version = "^0.25.1", optional = true } pandas = {version = ">=1.1.0", optional = true } +elasticsearch = {version = "7.10.1", optional = true } smart-open = {version = ">=6.3.0", extras = ["s3"], optional = true} jina-hubble-sdk = {version = ">=0.34.0", optional = true} @@ -37,6 +38,7 @@ mesh = ["trimesh"] web = ["fastapi"] hnswlib = ["hnswlib"] pandas = ["pandas"] +elasticsearch = ["elasticsearch"] jac = ["jina-hubble-sdk"] aws = ["smart-open"] diff --git a/tests/index/elastic/fixture.py b/tests/index/elastic/fixture.py new file mode 100644 index 00000000000..1caa31da2a6 --- /dev/null +++ b/tests/index/elastic/fixture.py @@ -0,0 +1,58 @@ +import os +import time + +import pytest +from pydantic import Field + +from docarray import BaseDoc +from docarray.typing import NdArray + +pytestmark = [pytest.mark.slow, pytest.mark.doc_index] + + +class SimpleDoc(BaseDoc): + tens: NdArray[10] = Field(dims=1000) + + +class FlatDoc(BaseDoc): + tens_one: NdArray = Field(dims=10) + tens_two: NdArray = Field(dims=50) + + +class NestedDoc(BaseDoc): + d: SimpleDoc + + +class DeepNestedDoc(BaseDoc): + d: NestedDoc + + +cur_dir = os.path.dirname(os.path.abspath(__file__)) +compose_yml_v7 = os.path.abspath(os.path.join(cur_dir, 'v7/docker-compose.yml')) +compose_yml_v8 = os.path.abspath(os.path.join(cur_dir, 'v8/docker-compose.yml')) + + +@pytest.fixture(scope='module', autouse=True) +def start_storage_v7(): + os.system(f"docker-compose -f {compose_yml_v7} up -d --remove-orphans") + _wait_for_es() + + yield + os.system(f"docker-compose -f {compose_yml_v7} 
down --remove-orphans") + + +@pytest.fixture(scope='module', autouse=True) +def start_storage_v8(): + os.system(f"docker-compose -f {compose_yml_v8} up -d --remove-orphans") + _wait_for_es() + + yield + os.system(f"docker-compose -f {compose_yml_v8} down --remove-orphans") + + +def _wait_for_es(): + from elasticsearch import Elasticsearch + + es = Elasticsearch(hosts='http://localhost:9200/') + while not es.ping(): + time.sleep(0.5) diff --git a/tests/index/elastic/v7/docker-compose.yml b/tests/index/elastic/v7/docker-compose.yml new file mode 100644 index 00000000000..f4dd8a49d0b --- /dev/null +++ b/tests/index/elastic/v7/docker-compose.yml @@ -0,0 +1,16 @@ +version: "3.3" +services: + elastic: + image: docker.elastic.co/elasticsearch/elasticsearch:7.10.2 + environment: + - xpack.security.enabled=false + - discovery.type=single-node + - ES_JAVA_OPTS=-Xmx1024m + ports: + - "9200:9200" + networks: + - elastic + +networks: + elastic: + name: elastic \ No newline at end of file diff --git a/tests/index/elastic/v7/test_column_config.py b/tests/index/elastic/v7/test_column_config.py new file mode 100644 index 00000000000..80f42c6a347 --- /dev/null +++ b/tests/index/elastic/v7/test_column_config.py @@ -0,0 +1,128 @@ +from pydantic import Field + +from docarray import BaseDoc +from docarray.index import ElasticV7DocIndex +from tests.index.elastic.fixture import start_storage_v7 # noqa: F401 + + +def test_column_config(): + class MyDoc(BaseDoc): + text: str + color: str = Field(col_type='keyword') + + store = ElasticV7DocIndex[MyDoc]() + index_docs = [ + MyDoc(id='0', text='hello world', color='red'), + MyDoc(id='1', text='never gonna give you up', color='blue'), + MyDoc(id='2', text='we are the world', color='green'), + ] + store.index(index_docs) + + query = 'world' + docs, _ = store.text_search(query, search_field='text') + assert [doc.id for doc in docs] == ['0', '2'] + + filter_query = {'terms': {'color': ['red', 'blue']}} + docs = store.filter(filter_query) + assert [doc.id for doc in docs] == ['0', '1'] + + +def test_field_object(): + class MyDoc(BaseDoc): + manager: dict = Field( + properties={ + 'age': {'type': 'integer'}, + 'name': { + 'properties': { + 'first': {'type': 'keyword'}, + 'last': {'type': 'keyword'}, + } + }, + } + ) + + store = ElasticV7DocIndex[MyDoc]() + doc = [ + MyDoc(manager={'age': 25, 'name': {'first': 'Rachel', 'last': 'Green'}}), + MyDoc(manager={'age': 30, 'name': {'first': 'Monica', 'last': 'Geller'}}), + MyDoc(manager={'age': 35, 'name': {'first': 'Phoebe', 'last': 'Buffay'}}), + ] + store.index(doc) + id_ = doc[0].id + assert store[id_].id == id_ + assert store[id_].manager == doc[0].manager + + filter_query = {'range': {'manager.age': {'gte': 30}}} + docs = store.filter(filter_query) + assert [doc.id for doc in docs] == [doc[1].id, doc[2].id] + + +def test_field_geo_point(): + class MyDoc(BaseDoc): + location: dict = Field(col_type='geo_point') + + store = ElasticV7DocIndex[MyDoc]() + doc = [ + MyDoc(location={'lat': 40.12, 'lon': -72.34}), + MyDoc(location={'lat': 41.12, 'lon': -73.34}), + MyDoc(location={'lat': 42.12, 'lon': -74.34}), + ] + store.index(doc) + + query = { + 'query': { + 'geo_bounding_box': { + 'location': { + 'top_left': {'lat': 42, 'lon': -74}, + 'bottom_right': {'lat': 40, 'lon': -72}, + } + } + }, + } + + docs, _ = store.execute_query(query) + assert [doc['id'] for doc in docs] == [doc[0].id, doc[1].id] + + +def test_field_range(): + class MyDoc(BaseDoc): + expected_attendees: dict = Field(col_type='integer_range') + time_frame: dict = 
Field(col_type='date_range', format='yyyy-MM-dd') + + store = ElasticV7DocIndex[MyDoc]() + doc = [ + MyDoc( + expected_attendees={'gte': 10, 'lt': 20}, + time_frame={'gte': '2023-01-01', 'lt': '2023-02-01'}, + ), + MyDoc( + expected_attendees={'gte': 20, 'lt': 30}, + time_frame={'gte': '2023-02-01', 'lt': '2023-03-01'}, + ), + MyDoc( + expected_attendees={'gte': 30, 'lt': 40}, + time_frame={'gte': '2023-03-01', 'lt': '2023-04-01'}, + ), + ] + store.index(doc) + + query = { + 'query': { + 'bool': { + 'should': [ + {'term': {'expected_attendees': {'value': 15}}}, + { + 'range': { + 'time_frame': { + 'gte': '2023-02-05', + 'lt': '2023-02-10', + 'relation': 'contains', + } + } + }, + ] + } + }, + } + docs, _ = store.execute_query(query) + assert [doc['id'] for doc in docs] == [doc[0].id, doc[1].id] diff --git a/tests/index/elastic/v7/test_find.py b/tests/index/elastic/v7/test_find.py new file mode 100644 index 00000000000..562b6f68326 --- /dev/null +++ b/tests/index/elastic/v7/test_find.py @@ -0,0 +1,318 @@ +import numpy as np +import torch +from pydantic import Field + +from docarray import BaseDoc +from docarray.index import ElasticV7DocIndex +from docarray.typing import NdArray, TorchTensor +from tests.index.elastic.fixture import start_storage_v7 # noqa: F401 +from tests.index.elastic.fixture import FlatDoc, SimpleDoc + + +def test_find_simple_schema(): + class SimpleSchema(BaseDoc): + tens: NdArray[10] + + store = ElasticV7DocIndex[SimpleSchema]() + + index_docs = [SimpleDoc(tens=np.random.rand(10)) for _ in range(10)] + store.index(index_docs) + + query = index_docs[-1] + docs, scores = store.find(query, search_field='tens', limit=5) + + assert len(docs) == 5 + assert len(scores) == 5 + + assert docs[0].id == index_docs[-1].id + assert np.allclose(docs[0].tens, index_docs[-1].tens) + + +def test_find_flat_schema(): + class FlatSchema(BaseDoc): + tens_one: NdArray = Field(dims=10) + tens_two: NdArray = Field(dims=50) + + store = ElasticV7DocIndex[FlatSchema]() + + index_docs = [ + FlatDoc(tens_one=np.random.rand(10), tens_two=np.random.rand(50)) + for _ in range(10) + ] + store.index(index_docs) + + query = index_docs[-1] + + # find on tens_one + docs, scores = store.find(query, search_field='tens_one', limit=5) + assert len(docs) == 5 + assert len(scores) == 5 + assert docs[0].id == index_docs[-1].id + assert np.allclose(docs[0].tens_one, index_docs[-1].tens_one) + assert np.allclose(docs[0].tens_two, index_docs[-1].tens_two) + + # find on tens_two + docs, scores = store.find(query, search_field='tens_two', limit=5) + assert len(docs) == 5 + assert len(scores) == 5 + assert docs[0].id == index_docs[-1].id + assert np.allclose(docs[0].tens_one, index_docs[-1].tens_one) + assert np.allclose(docs[0].tens_two, index_docs[-1].tens_two) + + +def test_find_nested_schema(): + class SimpleDoc(BaseDoc): + tens: NdArray[10] + + class NestedDoc(BaseDoc): + d: SimpleDoc + tens: NdArray[10] + + class DeepNestedDoc(BaseDoc): + d: NestedDoc + tens: NdArray = Field(dims=10) + + store = ElasticV7DocIndex[DeepNestedDoc]() + + index_docs = [ + DeepNestedDoc( + d=NestedDoc(d=SimpleDoc(tens=np.random.rand(10)), tens=np.random.rand(10)), + tens=np.random.rand(10), + ) + for _ in range(10) + ] + store.index(index_docs) + + query = index_docs[-1] + + # find on root level + docs, scores = store.find(query, search_field='tens', limit=5) + assert len(docs) == 5 + assert len(scores) == 5 + assert docs[0].id == index_docs[-1].id + assert np.allclose(docs[0].tens, index_docs[-1].tens) + + # find on first nesting level 
+ docs, scores = store.find(query, search_field='d__tens', limit=5) + assert len(docs) == 5 + assert len(scores) == 5 + assert docs[0].id == index_docs[-1].id + assert np.allclose(docs[0].d.tens, index_docs[-1].d.tens) + + # find on second nesting level + docs, scores = store.find(query, search_field='d__d__tens', limit=5) + assert len(docs) == 5 + assert len(scores) == 5 + assert docs[0].id == index_docs[-1].id + assert np.allclose(docs[0].d.d.tens, index_docs[-1].d.d.tens) + + +def test_find_torch(): + class TorchDoc(BaseDoc): + tens: TorchTensor[10] + + store = ElasticV7DocIndex[TorchDoc]() + + # A dense_vector field stores dense vectors of float values. + index_docs = [ + TorchDoc(tens=np.random.rand(10).astype(dtype=np.float32)) for _ in range(10) + ] + store.index(index_docs) + + for doc in index_docs: + assert isinstance(doc.tens, TorchTensor) + + query = index_docs[-1] + docs, scores = store.find(query, search_field='tens', limit=5) + + assert len(docs) == 5 + assert len(scores) == 5 + for doc in docs: + assert isinstance(doc.tens, TorchTensor) + + assert docs[0].id == index_docs[-1].id + assert torch.allclose(docs[0].tens, index_docs[-1].tens) + + +def test_find_tensorflow(): + from docarray.typing import TensorFlowTensor + + class TfDoc(BaseDoc): + tens: TensorFlowTensor[10] + + store = ElasticV7DocIndex[TfDoc]() + + index_docs = [ + TfDoc(tens=np.random.rand(10).astype(dtype=np.float32)) for _ in range(10) + ] + store.index(index_docs) + + for doc in index_docs: + assert isinstance(doc.tens, TensorFlowTensor) + + query = index_docs[-1] + docs, scores = store.find(query, search_field='tens', limit=5) + + assert len(docs) == 5 + assert len(scores) == 5 + for doc in docs: + assert isinstance(doc.tens, TensorFlowTensor) + + assert docs[0].id == index_docs[-1].id + assert np.allclose( + docs[0].tens.unwrap().numpy(), index_docs[-1].tens.unwrap().numpy() + ) + + +def test_find_batched(): + store = ElasticV7DocIndex[SimpleDoc]() + + index_docs = [SimpleDoc(tens=np.random.rand(10)) for _ in range(10)] + store.index(index_docs) + + queries = index_docs[-2:] + docs_batched, scores_batched = store.find_batched( + queries, search_field='tens', limit=5 + ) + + for docs, scores, query in zip(docs_batched, scores_batched, queries): + assert len(docs) == 5 + assert len(scores) == 5 + assert docs[0].id == query.id + assert np.allclose(docs[0].tens, query.tens) + + +def test_filter(): + class MyDoc(BaseDoc): + A: bool + B: int + C: float + + store = ElasticV7DocIndex[MyDoc]() + + index_docs = [MyDoc(id=f'{i}', A=(i % 2 == 0), B=i, C=i + 0.5) for i in range(10)] + store.index(index_docs) + + filter_query = {'term': {'A': True}} + docs = store.filter(filter_query) + assert len(docs) > 0 + for doc in docs: + assert doc.A + + filter_query = { + 'bool': { + 'filter': [ + {'terms': {'B': [3, 4, 7, 8]}}, + {'range': {'C': {'gte': 3, 'lte': 5}}}, + ] + } + } + docs = store.filter(filter_query) + assert [doc.id for doc in docs] == ['3', '4'] + + +def test_text_search(): + class MyDoc(BaseDoc): + text: str + + store = ElasticV7DocIndex[MyDoc]() + index_docs = [ + MyDoc(text='hello world'), + MyDoc(text='never gonna give you up'), + MyDoc(text='we are the world'), + ] + store.index(index_docs) + + query = 'world' + docs, scores = store.text_search(query, search_field='text') + + assert len(docs) == 2 + assert len(scores) == 2 + assert docs[0].text.index(query) >= 0 + assert docs[1].text.index(query) >= 0 + + queries = ['world', 'never'] + docs, scores = store.text_search_batched(queries, 
search_field='text') + for query, da, score in zip(queries, docs, scores): + assert len(da) > 0 + assert len(score) > 0 + for doc in da: + assert doc.text.index(query) >= 0 + + +def test_query_builder(): + class MyDoc(BaseDoc): + tens: NdArray[10] + num: int + text: str + + store = ElasticV7DocIndex[MyDoc]() + index_docs = [ + MyDoc( + id=f'{i}', tens=np.random.rand(10), num=int(i / 2), text=f'text {int(i/2)}' + ) + for i in range(10) + ] + store.index(index_docs) + + # build_query + q = store.build_query() + assert isinstance(q, store.QueryBuilder) + + # filter + q = store.build_query().filter({'term': {'num': 0}}).build() + docs, _ = store.execute_query(q) + assert [doc['id'] for doc in docs] == ['0', '1'] + + # find + q = store.build_query().find(index_docs[-1], search_field='tens', limit=3).build() + docs, scores = store.execute_query(q) + assert len(docs) == 3 + assert len(scores) == 3 + assert docs[0]['id'] == index_docs[-1].id + assert np.allclose(docs[0]['tens'], index_docs[-1].tens) + + # text search + q = store.build_query().text_search('0', search_field='text').build() + docs, _ = store.execute_query(q) + assert [doc['id'] for doc in docs] == ['0', '1'] + + # combination + q = ( + store.build_query() + .filter({'range': {'num': {'lte': 3}}}) + .find(index_docs[-1], search_field='tens') + .text_search('0', search_field='text') + .build() + ) + docs, _ = store.execute_query(q) + assert sorted([doc['id'] for doc in docs]) == ['0', '1'] + + # direct + index_docs = [ + MyDoc(id=f'{i}', tens=np.ones(10) * i, num=int(i / 2), text=f'text {int(i/2)}') + for i in range(10) + ] + store.index(index_docs) + + query = { + 'query': { + 'script_score': { + 'query': { + 'bool': { + 'filter': [ + {'range': {'num': {'gte': 2}}}, + {'range': {'num': {'lte': 3}}}, + ], + }, + }, + 'script': { + 'source': '1 / (1 + l2norm(params.query_vector, \'tens\'))', + 'params': {'query_vector': index_docs[-1].tens}, + }, + } + } + } + + docs, _ = store.execute_query(query) + assert [doc['id'] for doc in docs] == ['7', '6', '5', '4'] diff --git a/tests/index/elastic/v7/test_index_get_del.py b/tests/index/elastic/v7/test_index_get_del.py new file mode 100644 index 00000000000..8389e28ffbb --- /dev/null +++ b/tests/index/elastic/v7/test_index_get_del.py @@ -0,0 +1,280 @@ +from typing import Union + +import numpy as np +import pytest + +from docarray import BaseDoc, DocArray +from docarray.documents import ImageDoc, TextDoc +from docarray.index import ElasticV7DocIndex +from docarray.typing import NdArray +from tests.index.elastic.fixture import start_storage_v7 # noqa: F401 +from tests.index.elastic.fixture import DeepNestedDoc, FlatDoc, NestedDoc, SimpleDoc + + +@pytest.fixture +def ten_simple_docs(): + return [SimpleDoc(tens=np.random.randn(10)) for _ in range(10)] + + +@pytest.fixture +def ten_flat_docs(): + return [ + FlatDoc(tens_one=np.random.randn(10), tens_two=np.random.randn(50)) + for _ in range(10) + ] + + +@pytest.fixture +def ten_nested_docs(): + return [NestedDoc(d=SimpleDoc(tens=np.random.randn(10))) for _ in range(10)] + + +@pytest.fixture +def ten_deep_nested_docs(): + return [ + DeepNestedDoc(d=NestedDoc(d=SimpleDoc(tens=np.random.randn(10)))) + for _ in range(10) + ] + + +@pytest.mark.parametrize('use_docarray', [True, False]) +def test_index_simple_schema(ten_simple_docs, use_docarray): + store = ElasticV7DocIndex[SimpleDoc]() + if use_docarray: + ten_simple_docs = DocArray[SimpleDoc](ten_simple_docs) + + store.index(ten_simple_docs) + assert store.num_docs() == 10 + + 
+@pytest.mark.parametrize('use_docarray', [True, False]) +def test_index_flat_schema(ten_flat_docs, use_docarray): + store = ElasticV7DocIndex[FlatDoc]() + if use_docarray: + ten_flat_docs = DocArray[FlatDoc](ten_flat_docs) + + store.index(ten_flat_docs) + assert store.num_docs() == 10 + + +@pytest.mark.parametrize('use_docarray', [True, False]) +def test_index_nested_schema(ten_nested_docs, use_docarray): + store = ElasticV7DocIndex[NestedDoc]() + if use_docarray: + ten_nested_docs = DocArray[NestedDoc](ten_nested_docs) + + store.index(ten_nested_docs) + assert store.num_docs() == 10 + + +@pytest.mark.parametrize('use_docarray', [True, False]) +def test_index_deep_nested_schema(ten_deep_nested_docs, use_docarray): + store = ElasticV7DocIndex[DeepNestedDoc]() + if use_docarray: + ten_deep_nested_docs = DocArray[DeepNestedDoc](ten_deep_nested_docs) + + store.index(ten_deep_nested_docs) + assert store.num_docs() == 10 + + +def test_get_single(ten_simple_docs, ten_flat_docs, ten_nested_docs): + # simple + store = ElasticV7DocIndex[SimpleDoc]() + store.index(ten_simple_docs) + + assert store.num_docs() == 10 + for d in ten_simple_docs: + id_ = d.id + assert store[id_].id == id_ + assert np.all(store[id_].tens == d.tens) + + # flat + store = ElasticV7DocIndex[FlatDoc]() + store.index(ten_flat_docs) + + assert store.num_docs() == 10 + for d in ten_flat_docs: + id_ = d.id + assert store[id_].id == id_ + assert np.all(store[id_].tens_one == d.tens_one) + assert np.all(store[id_].tens_two == d.tens_two) + + # nested + store = ElasticV7DocIndex[NestedDoc]() + store.index(ten_nested_docs) + + assert store.num_docs() == 10 + for d in ten_nested_docs: + id_ = d.id + assert store[id_].id == id_ + assert store[id_].d.id == d.d.id + assert np.all(store[id_].d.tens == d.d.tens) + + +def test_get_multiple(ten_simple_docs, ten_flat_docs, ten_nested_docs): + docs_to_get_idx = [0, 2, 4, 6, 8] + + # simple + store = ElasticV7DocIndex[SimpleDoc]() + store.index(ten_simple_docs) + + assert store.num_docs() == 10 + docs_to_get = [ten_simple_docs[i] for i in docs_to_get_idx] + ids_to_get = [d.id for d in docs_to_get] + retrieved_docs = store[ids_to_get] + for id_, d_in, d_out in zip(ids_to_get, docs_to_get, retrieved_docs): + assert d_out.id == id_ + assert np.all(d_out.tens == d_in.tens) + + # flat + store = ElasticV7DocIndex[FlatDoc]() + store.index(ten_flat_docs) + + assert store.num_docs() == 10 + docs_to_get = [ten_flat_docs[i] for i in docs_to_get_idx] + ids_to_get = [d.id for d in docs_to_get] + retrieved_docs = store[ids_to_get] + for id_, d_in, d_out in zip(ids_to_get, docs_to_get, retrieved_docs): + assert d_out.id == id_ + assert np.all(d_out.tens_one == d_in.tens_one) + assert np.all(d_out.tens_two == d_in.tens_two) + + # nested + store = ElasticV7DocIndex[NestedDoc]() + store.index(ten_nested_docs) + + assert store.num_docs() == 10 + docs_to_get = [ten_nested_docs[i] for i in docs_to_get_idx] + ids_to_get = [d.id for d in docs_to_get] + retrieved_docs = store[ids_to_get] + for id_, d_in, d_out in zip(ids_to_get, docs_to_get, retrieved_docs): + assert d_out.id == id_ + assert d_out.d.id == d_in.d.id + assert np.all(d_out.d.tens == d_in.d.tens) + + +def test_get_key_error(ten_simple_docs): + store = ElasticV7DocIndex[SimpleDoc]() + store.index(ten_simple_docs) + + with pytest.raises(KeyError): + store['not_a_real_id'] + + +def test_persisting(ten_simple_docs): + store = ElasticV7DocIndex[SimpleDoc](index_name='test_persisting') + store.index(ten_simple_docs) + + store2 = 
ElasticV7DocIndex[SimpleDoc](index_name='test_persisting') + assert store2.num_docs() == 10 + + +def test_del_single(ten_simple_docs): + store = ElasticV7DocIndex[SimpleDoc]() + store.index(ten_simple_docs) + # delete once + assert store.num_docs() == 10 + del store[ten_simple_docs[0].id] + assert store.num_docs() == 9 + for i, d in enumerate(ten_simple_docs): + id_ = d.id + if i == 0: # deleted + with pytest.raises(KeyError): + store[id_] + else: + assert store[id_].id == id_ + assert np.all(store[id_].tens == d.tens) + # delete again + del store[ten_simple_docs[3].id] + assert store.num_docs() == 8 + for i, d in enumerate(ten_simple_docs): + id_ = d.id + if i in (0, 3): # deleted + with pytest.raises(KeyError): + store[id_] + else: + assert store[id_].id == id_ + assert np.all(store[id_].tens == d.tens) + + +def test_del_multiple(ten_simple_docs): + docs_to_del_idx = [0, 2, 4, 6, 8] + + store = ElasticV7DocIndex[SimpleDoc]() + store.index(ten_simple_docs) + + assert store.num_docs() == 10 + docs_to_del = [ten_simple_docs[i] for i in docs_to_del_idx] + ids_to_del = [d.id for d in docs_to_del] + del store[ids_to_del] + for i, doc in enumerate(ten_simple_docs): + if i in docs_to_del_idx: + with pytest.raises(KeyError): + store[doc.id] + else: + assert store[doc.id].id == doc.id + assert np.all(store[doc.id].tens == doc.tens) + + +def test_del_key_error(ten_simple_docs): + store = ElasticV7DocIndex[SimpleDoc]() + store.index(ten_simple_docs) + + with pytest.warns(UserWarning): + del store['not_a_real_id'] + + +def test_num_docs(ten_simple_docs): + store = ElasticV7DocIndex[SimpleDoc]() + store.index(ten_simple_docs) + + assert store.num_docs() == 10 + + del store[ten_simple_docs[0].id] + assert store.num_docs() == 9 + + del store[ten_simple_docs[3].id, ten_simple_docs[5].id] + assert store.num_docs() == 7 + + more_docs = [SimpleDoc(tens=np.random.rand(10)) for _ in range(5)] + store.index(more_docs) + assert store.num_docs() == 12 + + del store[more_docs[2].id, ten_simple_docs[7].id] + assert store.num_docs() == 10 + + +def test_index_union_doc(): + class MyDoc(BaseDoc): + tensor: Union[NdArray, str] + + class MySchema(BaseDoc): + tensor: NdArray + + store = ElasticV7DocIndex[MySchema]() + doc = [MyDoc(tensor=np.random.randn(128))] + store.index(doc) + + id_ = doc[0].id + assert store[id_].id == id_ + assert np.all(store[id_].tensor == doc[0].tensor) + + +def test_index_multi_modal_doc(): + class MyMultiModalDoc(BaseDoc): + image: ImageDoc + text: TextDoc + + store = ElasticV7DocIndex[MyMultiModalDoc]() + + doc = [ + MyMultiModalDoc( + image=ImageDoc(embedding=np.random.randn(128)), text=TextDoc(text='hello') + ) + ] + store.index(doc) + + id_ = doc[0].id + assert store[id_].id == id_ + assert np.all(store[id_].image.embedding == doc[0].image.embedding) + assert store[id_].text.text == doc[0].text.text