Skip to content
54 changes: 30 additions & 24 deletions docarray/index/backends/elastic.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# mypy: ignore-errors
import uuid
import warnings
from collections import defaultdict
from dataclasses import dataclass, field
Expand Down Expand Up @@ -74,12 +73,6 @@ def __init__(self, db_config=None, **kwargs):
self._logger.debug('Elastic Search index is being initialized')

# ElasticSearch client creation
if self._db_config.index_name is None:
id = uuid.uuid4().hex
self._db_config.index_name = 'index__' + id

self._index_name = self._db_config.index_name

self._client = Elasticsearch(
hosts=self._db_config.hosts,
**self._db_config.es_config,
Expand Down Expand Up @@ -108,15 +101,28 @@ def __init__(self, db_config=None, **kwargs):
mappings['properties'][col_name] = self._create_index_mapping(col)

# print(mappings['properties'])
if self._client.indices.exists(index=self._index_name):
if self._client.indices.exists(index=self.index_name):
self._client_put_mapping(mappings)
else:
self._client_create(mappings)

if len(self._db_config.index_settings):
self._client_put_settings(self._db_config.index_settings)

self._refresh(self._index_name)
self._refresh(self.index_name)

@property
def index_name(self):
    """Return the name of the Elasticsearch index backing this document index.

    The name configured in ``self._db_config.index_name`` is used when it is
    truthy; otherwise the lower-cased ``__name__`` of ``self._schema`` is
    used as the default.

    :return: the configured index name, or the lower-cased schema class name.
    :raises ValueError: if ``self._schema`` is ``None``, i.e. the index was
        not parametrized with a Document type. This is raised even when an
        explicit index name is configured.
    """
    if self._schema is None:
        # Adjacent string literals are concatenated verbatim, so the trailing
        # space after the first sentence is required for a readable message.
        raise ValueError(
            'A ElasticDocIndex must be typed with a Document type. '
            'To do so, use the syntax: ElasticDocIndex[DocumentType]'
        )

    default_index_name = self._schema.__name__.lower()
    return self._db_config.index_name or default_index_name

###############################################
# Inner classes for query builder and configs #
Expand Down Expand Up @@ -333,7 +339,7 @@ def _index(

for row in data:
request = {
'_index': self._index_name,
'_index': self.index_name,
'_id': row['id'],
}
for col_name, col in self._column_infos.items():
Expand All @@ -349,13 +355,13 @@ def _index(
warnings.warn(str(info))

if refresh:
self._refresh(self._index_name)
self._refresh(self.index_name)

def num_docs(self) -> int:
"""
Get the number of documents.
"""
return self._client.count(index=self._index_name)['count']
return self._client.count(index=self.index_name)['count']

def _del_items(
self,
Expand All @@ -365,7 +371,7 @@ def _del_items(
requests = []
for _id in doc_ids:
requests.append(
{'_op_type': 'delete', '_index': self._index_name, '_id': _id}
{'_op_type': 'delete', '_index': self.index_name, '_id': _id}
)

_, warning_info = self._send_requests(requests, chunk_size)
Expand All @@ -375,7 +381,7 @@ def _del_items(
ids = [info['delete']['_id'] for info in warning_info]
warnings.warn(f'No document with id {ids} found')

self._refresh(self._index_name)
self._refresh(self.index_name)

def _get_items(self, doc_ids: Sequence[str]) -> Sequence[TSchema]:
accumulated_docs = []
Expand Down Expand Up @@ -416,7 +422,7 @@ def execute_query(self, query: Dict[str, Any], *args, **kwargs) -> Any:
f'args and kwargs not supported for `execute_query` on {type(self)}'
)

resp = self._client.search(index=self._index_name, **query)
resp = self._client.search(index=self.index_name, **query)
docs, scores = self._format_response(resp)

return _FindResult(documents=docs, scores=parse_obj_as(NdArray, scores))
Expand All @@ -440,7 +446,7 @@ def _find_batched(
) -> _FindResultBatched:
request = []
for query in queries:
head = {'index': self._index_name}
head = {'index': self.index_name}
body = self._form_search_body(query, limit, search_field)
request.extend([head, body])

Expand Down Expand Up @@ -469,7 +475,7 @@ def _filter_batched(
) -> List[List[Dict]]:
request = []
for query in filter_queries:
head = {'index': self._index_name}
head = {'index': self.index_name}
body = {'query': query, 'size': limit}
request.extend([head, body])

Expand Down Expand Up @@ -499,7 +505,7 @@ def _text_search_batched(
) -> _FindResultBatched:
request = []
for query in queries:
head = {'index': self._index_name}
head = {'index': self.index_name}
body = self._form_text_search_body(query, limit, search_field)
request.extend([head, body])

Expand Down Expand Up @@ -615,20 +621,20 @@ def _refresh(self, index_name: str):

def _client_put_mapping(self, mappings: Dict[str, Any]):
self._client.indices.put_mapping(
index=self._index_name, properties=mappings['properties']
index=self.index_name, properties=mappings['properties']
)

def _client_create(self, mappings: Dict[str, Any]):
self._client.indices.create(index=self._index_name, mappings=mappings)
self._client.indices.create(index=self.index_name, mappings=mappings)

def _client_put_settings(self, settings: Dict[str, Any]):
self._client.indices.put_settings(index=self._index_name, settings=settings)
self._client.indices.put_settings(index=self.index_name, settings=settings)

def _client_mget(self, ids: Sequence[str]):
return self._client.mget(index=self._index_name, ids=ids)
return self._client.mget(index=self.index_name, ids=ids)

def _client_search(self, **kwargs):
return self._client.search(index=self._index_name, **kwargs)
return self._client.search(index=self.index_name, **kwargs)

def _client_msearch(self, request: List[Dict[str, Any]]):
return self._client.msearch(index=self._index_name, searches=request)
return self._client.msearch(index=self.index_name, searches=request)
14 changes: 7 additions & 7 deletions docarray/index/backends/elasticv7.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def execute_query(self, query: Dict[str, Any], *args, **kwargs) -> Any:
f'args and kwargs not supported for `execute_query` on {type(self)}'
)

resp = self._client.search(index=self._index_name, body=query)
resp = self._client.search(index=self.index_name, body=query)
docs, scores = self._format_response(resp)

return _FindResult(documents=docs, scores=parse_obj_as(NdArray, scores))
Expand Down Expand Up @@ -161,20 +161,20 @@ def _form_search_body(self, query: np.ndarray, limit: int, search_field: str = '
###############################################

def _client_put_mapping(self, mappings: Dict[str, Any]):
self._client.indices.put_mapping(index=self._index_name, body=mappings)
self._client.indices.put_mapping(index=self.index_name, body=mappings)

def _client_create(self, mappings: Dict[str, Any]):
body = {'mappings': mappings}
self._client.indices.create(index=self._index_name, body=body)
self._client.indices.create(index=self.index_name, body=body)

def _client_put_settings(self, settings: Dict[str, Any]):
self._client.indices.put_settings(index=self._index_name, body=settings)
self._client.indices.put_settings(index=self.index_name, body=settings)

def _client_mget(self, ids: Sequence[str]):
return self._client.mget(index=self._index_name, body={'ids': ids})
return self._client.mget(index=self.index_name, body={'ids': ids})

def _client_search(self, **kwargs):
return self._client.search(index=self._index_name, body=kwargs)
return self._client.search(index=self.index_name, body=kwargs)

def _client_msearch(self, request: List[Dict[str, Any]]):
return self._client.msearch(index=self._index_name, body=request)
return self._client.msearch(index=self.index_name, body=request)
42 changes: 27 additions & 15 deletions docarray/index/backends/qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@

TSchema = TypeVar('TSchema', bound=BaseDoc)


QDRANT_PY_VECTOR_TYPES: List[Any] = [np.ndarray, AbstractTensor]
if torch_imported:
import torch
Expand Down Expand Up @@ -86,6 +85,19 @@ def __init__(self, db_config=None, **kwargs):
self._initialize_collection()
self._logger.info(f'{self.__class__.__name__} has been initialized')

@property
def collection_name(self):
    """Return the name of the Qdrant collection backing this document index.

    The name configured in ``self._db_config.collection_name`` is used when
    it is truthy; otherwise the lower-cased ``__name__`` of ``self._schema``
    is used as the default.

    :return: the configured collection name, or the lower-cased schema class
        name.
    :raises ValueError: if ``self._schema`` is ``None``, i.e. the index was
        not parametrized with a Document type. This is raised even when an
        explicit collection name is configured.
    """
    if self._schema is None:
        # Adjacent string literals are concatenated verbatim, so the trailing
        # space after the first sentence is required for a readable message.
        raise ValueError(
            'A QdrantDocumentIndex must be typed with a Document type. '
            'To do so, use the syntax: QdrantDocumentIndex[DocumentType]'
        )

    default_collection_name = self._schema.__name__.lower()
    return self._db_config.collection_name or default_collection_name

@dataclass
class Query:
"""Dataclass describing a query."""
Expand Down Expand Up @@ -211,7 +223,7 @@ class DBConfig(BaseDocIndex.DBConfig):
timeout: Optional[float] = None
host: Optional[str] = None
path: Optional[str] = None
collection_name: str = 'documents'
collection_name: Optional[str] = None
shard_number: Optional[int] = None
replication_factor: Optional[int] = None
write_consistency_factor: Optional[int] = None
Expand Down Expand Up @@ -250,15 +262,15 @@ def python_type_to_db_type(self, python_type: Type) -> Any:

def _initialize_collection(self):
try:
self._client.get_collection(self._db_config.collection_name)
self._client.get_collection(self.collection_name)
except (UnexpectedResponse, RpcError, ValueError):
vectors_config = {
column_name: self._to_qdrant_vector_params(column_info)
for column_name, column_info in self._column_infos.items()
if column_info.db_type == 'vector'
}
self._client.create_collection(
collection_name=self._db_config.collection_name,
collection_name=self.collection_name,
vectors_config=vectors_config,
shard_number=self._db_config.shard_number,
replication_factor=self._db_config.replication_factor,
Expand All @@ -270,7 +282,7 @@ def _initialize_collection(self):
quantization_config=self._db_config.quantization_config,
)
self._client.create_payload_index(
collection_name=self._db_config.collection_name,
collection_name=self.collection_name,
field_name='__generated_vectors',
field_schema=rest.PayloadSchemaType.KEYWORD,
)
Expand All @@ -280,15 +292,15 @@ def _index(self, column_to_data: Dict[str, Generator[Any, None, None]]):
# TODO: add batching the documents to avoid timeouts
points = [self._build_point_from_row(row) for row in rows]
self._client.upsert(
collection_name=self._db_config.collection_name,
collection_name=self.collection_name,
points=points,
)

def num_docs(self) -> int:
"""
Get the number of documents.
"""
return self._client.count(collection_name=self._db_config.collection_name).count
return self._client.count(collection_name=self.collection_name).count

def _del_items(self, doc_ids: Sequence[str]):
items = self._get_items(doc_ids)
Expand All @@ -298,7 +310,7 @@ def _del_items(self, doc_ids: Sequence[str]):
raise KeyError('Document keys could not found: %s' % ','.join(missing_keys))

self._client.delete(
collection_name=self._db_config.collection_name,
collection_name=self.collection_name,
points_selector=rest.PointIdsList(
points=[self._to_qdrant_id(doc_id) for doc_id in doc_ids],
),
Expand All @@ -308,7 +320,7 @@ def _get_items(
self, doc_ids: Sequence[str]
) -> Union[Sequence[TSchema], Sequence[Dict[str, Any]]]:
response, _ = self._client.scroll(
collection_name=self._db_config.collection_name,
collection_name=self.collection_name,
scroll_filter=rest.Filter(
must=[
rest.HasIdCondition(
Expand Down Expand Up @@ -343,7 +355,7 @@ def execute_query(self, query: Union[Query, RawQuery], *args, **kwargs) -> DocLi
# We perform semantic search with some vectors with Qdrant's search method
# should be called
points = self._client.search( # type: ignore[assignment]
collection_name=self._db_config.collection_name,
collection_name=self.collection_name,
query_vector=(query.vector_field, query.vector_query), # type: ignore[arg-type]
query_filter=rest.Filter(
must=[query.filter],
Expand All @@ -364,7 +376,7 @@ def execute_query(self, query: Union[Query, RawQuery], *args, **kwargs) -> DocLi
else:
# Just filtering, so Qdrant's scroll has to be used instead
points, _ = self._client.scroll( # type: ignore[assignment]
collection_name=self._db_config.collection_name,
collection_name=self.collection_name,
scroll_filter=query.filter,
limit=query.limit,
with_payload=True,
Expand All @@ -388,7 +400,7 @@ def _execute_raw_query(
if search_params:
search_params = rest.SearchParams.parse_obj(search_params) # type: ignore[assignment]
points = self._client.search( # type: ignore[assignment]
collection_name=self._db_config.collection_name,
collection_name=self.collection_name,
query_vector=query.pop('vector'),
query_filter=payload_filter,
search_params=search_params,
Expand All @@ -397,7 +409,7 @@ def _execute_raw_query(
else:
# Just filtering, so Qdrant's scroll has to be used instead
points, _ = self._client.scroll( # type: ignore[assignment]
collection_name=self._db_config.collection_name,
collection_name=self.collection_name,
scroll_filter=payload_filter,
**query,
)
Expand All @@ -417,7 +429,7 @@ def _find_batched(
self, queries: np.ndarray, limit: int, search_field: str = ''
) -> _FindResultBatched:
responses = self._client.search_batch(
collection_name=self._db_config.collection_name,
collection_name=self.collection_name,
requests=[
rest.SearchRequest(
vector=rest.NamedVector(
Expand Down Expand Up @@ -470,7 +482,7 @@ def _filter_batched(
# There is no batch scroll available in Qdrant client yet, so we need to
# perform the queries one by one. It will be changed in the future versions.
response, _ = self._client.scroll(
collection_name=self._db_config.collection_name,
collection_name=self.collection_name,
scroll_filter=filter_query,
limit=limit,
with_payload=True,
Expand Down
Loading