Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docarray/index/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import types
from typing import TYPE_CHECKING

from docarray.index.backends.in_memory import InMemoryDocIndex
from docarray.utils._internal.misc import (
_get_path_from_docarray_root_level,
import_library,
Expand All @@ -13,7 +14,7 @@
from docarray.index.backends.qdrant import QdrantDocumentIndex # noqa: F401
from docarray.index.backends.weaviate import WeaviateDocumentIndex # noqa: F401

__all__ = []
__all__ = ['InMemoryDocIndex']


def __getattr__(name: str):
Expand Down
58 changes: 58 additions & 0 deletions docarray/index/backends/helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from typing import Any, Dict, List, Tuple, Type, cast

from docarray import BaseDoc, DocList
from docarray.index.abstract import BaseDocIndex
from docarray.utils.filter import filter_docs
from docarray.utils.find import FindResult


def _collect_query_args(method_name: str): # TODO: use partialmethod instead
def inner(self, *args, **kwargs):
if args:
raise ValueError(
f'Positional arguments are not supported for '
f'`{type(self)}.{method_name}`.'
f' Use keyword arguments instead.'
)
updated_query = self._queries + [(method_name, kwargs)]
return type(self)(updated_query)

return inner


def _execute_find_and_filter_query(
doc_index: BaseDocIndex, query: List[Tuple[str, Dict]]
) -> FindResult:
"""
Executes all find calls from query first using `doc_index.find()`,
and filtering queries after that using DocArray's `filter_docs()`.

Text search is not supported.
"""
docs_found = DocList.__class_getitem__(cast(Type[BaseDoc], doc_index._schema))([])
filter_conditions = []
doc_to_score: Dict[BaseDoc, Any] = {}
for op, op_kwargs in query:
if op == 'find':
docs, scores = doc_index.find(**op_kwargs)
docs_found.extend(docs)
doc_to_score.update(zip(docs.__getattribute__('id'), scores))
elif op == 'filter':
filter_conditions.append(op_kwargs['filter_query'])
else:
raise ValueError(f'Query operation is not supported: {op}')

doc_index._logger.debug(f'Executing query {query}')
docs_filtered = docs_found
for cond in filter_conditions:
docs_cls = DocList.__class_getitem__(cast(Type[BaseDoc], doc_index._schema))
docs_filtered = docs_cls(filter_docs(docs_filtered, cond))

doc_index._logger.debug(f'{len(docs_filtered)} results found')
docs_and_scores = zip(
docs_filtered, (doc_to_score[doc.id] for doc in docs_filtered)
)
docs_sorted = sorted(docs_and_scores, key=lambda x: x[1])
out_docs, out_scores = zip(*docs_sorted)

return FindResult(documents=out_docs, scores=out_scores)
49 changes: 9 additions & 40 deletions docarray/index/backends/hnswlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,14 @@
_raise_not_composable,
_raise_not_supported,
)
from docarray.index.backends.helper import (
_collect_query_args,
_execute_find_and_filter_query,
)
from docarray.proto import DocProto
from docarray.typing.tensor.abstract_tensor import AbstractTensor
from docarray.typing.tensor.ndarray import NdArray
from docarray.utils._internal.misc import import_library, is_np_int
from docarray.utils.filter import filter_docs
from docarray.utils.find import _FindResult, _FindResultBatched

if TYPE_CHECKING:
Expand Down Expand Up @@ -61,20 +64,6 @@
T = TypeVar('T', bound='HnswDocumentIndex')


def _collect_query_args(method_name: str): # TODO: use partialmethod instead
def inner(self, *args, **kwargs):
if args:
raise ValueError(
f'Positional arguments are not supported for '
f'`{type(self)}.{method_name}`.'
f' Use keyword arguments instead.'
)
updated_query = self._queries + [(method_name, kwargs)]
return type(self)(updated_query)

return inner


class HnswDocumentIndex(BaseDocIndex, Generic[TSchema]):
def __init__(self, db_config=None, **kwargs):
"""Initialize HnswDocumentIndex"""
Expand Down Expand Up @@ -232,7 +221,7 @@ def index(self, docs: Union[BaseDoc, Sequence[BaseDoc]], **kwargs):

def execute_query(self, query: List[Tuple[str, Dict]], *args, **kwargs) -> Any:
"""
Execute a query on the WeaviateDocumentIndex.
Execute a query on the HnswDocumentIndex.

Can take two kinds of inputs:

Expand All @@ -249,31 +238,11 @@ def execute_query(self, query: List[Tuple[str, Dict]], *args, **kwargs) -> Any:
raise ValueError(
f'args and kwargs not supported for `execute_query` on {type(self)}'
)

ann_docs = DocList.__class_getitem__(cast(Type[BaseDoc], self._schema))([])
filter_conditions = []
doc_to_score: Dict[BaseDoc, Any] = {}
for op, op_kwargs in query:
if op == 'find':
docs, scores = self.find(**op_kwargs)
ann_docs.extend(docs)
doc_to_score.update(zip(docs.__getattribute__('id'), scores))
elif op == 'filter':
filter_conditions.append(op_kwargs['filter_query'])

self._logger.debug(f'Executing query {query}')
docs_filtered = ann_docs
for cond in filter_conditions:
docs_cls = DocList.__class_getitem__(cast(Type[BaseDoc], self._schema))
docs_filtered = docs_cls(filter_docs(docs_filtered, cond))

self._logger.debug(f'{len(docs_filtered)} results found')
docs_and_scores = zip(
docs_filtered, (doc_to_score[doc.id] for doc in docs_filtered)
find_res = _execute_find_and_filter_query(
doc_index=self,
query=query,
)
docs_sorted = sorted(docs_and_scores, key=lambda x: x[1])
out_docs, out_scores = zip(*docs_sorted)
return _FindResult(documents=out_docs, scores=out_scores)
return find_res

def _find_batched(
self,
Expand Down
Loading