diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 669a919ac54..722755aa6a0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -217,7 +217,7 @@ jobs: - name: Test id: test run: | - poetry run pytest -m 'index' tests + poetry run pytest -m 'index and (not tensorflow)' tests timeout-minutes: 30 docarray-test-tensorflow: diff --git a/docarray/index/abstract.py b/docarray/index/abstract.py index e1fd847c841..6aec9549726 100644 --- a/docarray/index/abstract.py +++ b/docarray/index/abstract.py @@ -1,3 +1,4 @@ +import logging from abc import ABC, abstractmethod from dataclasses import dataclass, field, replace from typing import ( @@ -20,14 +21,14 @@ import numpy as np from pydantic.error_wrappers import ValidationError -from typing_inspect import get_args, is_union_type +from typing_inspect import get_args, is_optional_type, is_union_type from docarray import BaseDocument, DocumentArray from docarray.array.abstract_array import AnyDocumentArray from docarray.typing import AnyTensor +from docarray.utils._typing import unwrap_optional_type from docarray.utils.find import FindResult, _FindResult -from docarray.utils.misc import torch_imported -import logging +from docarray.utils.misc import is_tf_available, torch_imported if TYPE_CHECKING: from pydantic.fields import ModelField @@ -35,6 +36,11 @@ if torch_imported: import torch +if is_tf_available(): + import tensorflow as tf # type: ignore + + from docarray.typing import TensorFlowTensor + TSchema = TypeVar('TSchema', bound=BaseDocument) @@ -614,7 +620,13 @@ def _get_col_value_dict( docs_seq = docs def _col_gen(col_name: str): - return (self._get_values_by_column([doc], col_name)[0] for doc in docs_seq) + return ( + self._to_numpy( + self._get_values_by_column([doc], col_name)[0], + allow_passthrough=True, + ) + for doc in docs_seq + ) return {col_name: _col_gen(col_name) for col_name in self._column_infos} @@ -697,7 +709,11 @@ def _create_column_infos( """ column_infos: Dict[str, 
_ColumnInfo] = dict() for field_name, type_, field_ in self._flatten_schema(schema): - if is_union_type(type_): + if is_optional_type(type_): + column_infos[field_name] = self._create_single_column( + field_, unwrap_optional_type(type_) + ) + elif is_union_type(type_): raise ValueError( 'Union types are not supported in the schema of a DocumentIndex.' f' Instead of using type {type_} use a single specific type.' @@ -793,15 +809,28 @@ def _validate_docs( return DocumentArray[BaseDocument].construct(out_docs) - def _to_numpy(self, val: Any) -> Any: + def _to_numpy(self, val: Any, allow_passthrough=False) -> Any: + """ + Converts a value to a numpy array, if possible. + + :param val: The value to convert + :param allow_passthrough: If True, the value is returned as-is if it is not convertible to a numpy array. + If False, a `ValueError` is raised if the value is not convertible to a numpy array. + :return: The value as a numpy array, or as-is if `allow_passthrough` is True and the value is not convertible + """ if isinstance(val, np.ndarray): return val - elif isinstance(val, (list, tuple)): + if is_tf_available() and isinstance(val, TensorFlowTensor): + return val.unwrap().numpy() + if isinstance(val, (list, tuple)): return np.array(val) - elif torch_imported and isinstance(val, torch.Tensor): + if (torch_imported and isinstance(val, torch.Tensor)) or ( + is_tf_available() and isinstance(val, tf.Tensor) + ): return val.numpy() - else: - raise ValueError(f'Unsupported input type for {type(self)}: {type(val)}') + if allow_passthrough: + return val + raise ValueError(f'Unsupported input type for {type(self)}: {type(val)}') def _convert_dict_to_doc( self, doc_dict: Dict[str, Any], schema: Type[BaseDocument] @@ -815,7 +844,7 @@ def _convert_dict_to_doc( """ for field_name, _ in schema.__fields__.items(): - t_ = schema._get_field_type(field_name) + t_ = unwrap_optional_type(schema._get_field_type(field_name)) if issubclass(t_, BaseDocument): inner_dict = {} diff --git 
a/docarray/index/backends/hnswlib.py b/docarray/index/backends/hnswlib.py index 13f8029e2b5..4f1799fb7d7 100644 --- a/docarray/index/backends/hnswlib.py +++ b/docarray/index/backends/hnswlib.py @@ -20,7 +20,6 @@ import hnswlib import numpy as np -import docarray.typing from docarray import BaseDocument, DocumentArray from docarray.index.abstract import ( BaseDocumentIndex, @@ -32,17 +31,25 @@ from docarray.proto import DocumentProto from docarray.utils.filter import filter_docs from docarray.utils.find import _FindResult -from docarray.utils.misc import is_np_int, torch_imported +from docarray.utils.misc import is_np_int, is_tf_available, is_torch_available TSchema = TypeVar('TSchema', bound=BaseDocument) T = TypeVar('T', bound='HnswDocumentIndex') HNSWLIB_PY_VEC_TYPES = [list, tuple, np.ndarray] -if torch_imported: +if is_torch_available(): import torch HNSWLIB_PY_VEC_TYPES.append(torch.Tensor) +if is_tf_available(): + import tensorflow as tf # type: ignore + + from docarray.typing import TensorFlowTensor + + HNSWLIB_PY_VEC_TYPES.append(tf.Tensor) + HNSWLIB_PY_VEC_TYPES.append(TensorFlowTensor) + def _collect_query_args(method_name: str): # TODO: use partialmethod instead def inner(self, *args, **kwargs): @@ -84,8 +91,14 @@ def __init__(self, db_config=None, **kwargs): self._hnsw_indices = {} for col_name, col in self._column_infos.items(): if not col.config: - self._logger.warning( - f'No index was created for `{col_name}` as it does not have a config' + # non-tensor type; don't create an index + continue + if not load_existing and ( + (not col.n_dim and col.config['dim'] < 0) or not col.config['index'] + ): + # tensor type, but don't index + self._logger.info( + f'Not indexing column {col_name}; either `index=False` is set or no dimensionality is specified' ) continue if load_existing: @@ -133,7 +146,8 @@ class RuntimeConfig(BaseDocumentIndex.RuntimeConfig): default_column_config: Dict[Type, Dict[str, Any]] = field( default_factory=lambda: { np.ndarray: { - 
'dim': 128, + 'dim': -1, + 'index': True, # if False, don't index at all 'space': 'l2', # 'l2', 'ip', 'cosine' 'max_elements': 1024, 'ef_construction': 200, @@ -157,10 +171,7 @@ def python_type_to_db_type(self, python_type: Type) -> Any: if issubclass(python_type, allowed_type): return np.ndarray - if python_type == docarray.typing.ID: - return None - - raise ValueError(f'Unsupported column type for {type(self)}: {python_type}') + return None # all types allowed, but no db type needed def _index(self, column_data_dic, **kwargs): # not needed, we implement `index` directly @@ -328,16 +339,6 @@ def _load_index(self, col_name: str, col: '_ColumnInfo') -> hnswlib.Index: index.load_index(self._hnsw_locations[col_name]) return index - def _to_numpy(self, val: Any) -> Any: - if isinstance(val, np.ndarray): - return val - elif isinstance(val, (list, tuple)): - return np.array(val) - elif torch_imported and isinstance(val, torch.Tensor): - return val.numpy() - else: - raise ValueError(f'Unsupported input type for {type(self)}: {type(val)}') - # HNSWLib helpers def _create_index_class(self, col: '_ColumnInfo') -> hnswlib.Index: """Create an instance of hnswlib.index without initializing it.""" diff --git a/docarray/utils/_typing.py b/docarray/utils/_typing.py index 62680cf964e..9bbc0162432 100644 --- a/docarray/utils/_typing.py +++ b/docarray/utils/_typing.py @@ -1,6 +1,6 @@ from typing import Any, Optional -from typing_inspect import get_args, is_union_type +from typing_inspect import get_args, is_optional_type, is_union_type from docarray.typing.tensor.abstract_tensor import AbstractTensor @@ -32,3 +32,17 @@ def change_cls_name(cls: type, new_name: str, scope: Optional[dict] = None) -> N scope[new_name] = cls cls.__qualname__ = cls.__qualname__[: -len(cls.__name__)] + new_name cls.__name__ = new_name + + +def unwrap_optional_type(type_: Any) -> Any: + """Return the type of an Optional type, e.g. 
`unwrap_optional(Optional[str]) == str`; + `unwrap_optional(Union[None, int, None]) == int`. + + :param type_: the type to unwrap + :return: the "core" type of an Optional type + """ + if not is_optional_type(type_): + return type_ + for arg in get_args(type_): + if arg is not type(None): + return arg diff --git a/tests/index/hnswlib/test_find.py b/tests/index/hnswlib/test_find.py index d8910525098..cfc7679bcea 100644 --- a/tests/index/hnswlib/test_find.py +++ b/tests/index/hnswlib/test_find.py @@ -1,10 +1,11 @@ import numpy as np import pytest +import torch from pydantic import Field from docarray import BaseDocument from docarray.index import HnswDocumentIndex -from docarray.typing import NdArray +from docarray.typing import NdArray, TorchTensor pytestmark = [pytest.mark.slow, pytest.mark.index] @@ -26,6 +27,10 @@ class DeepNestedDoc(BaseDocument): d: NestedDoc +class TorchDoc(BaseDocument): + tens: TorchTensor[10] + + @pytest.mark.parametrize('space', ['cosine', 'l2', 'ip']) def test_find_simple_schema(tmp_path, space): class SimpleSchema(BaseDocument): @@ -49,6 +54,63 @@ class SimpleSchema(BaseDocument): assert np.allclose(result.tens, np.zeros(10)) +@pytest.mark.parametrize('space', ['cosine', 'l2', 'ip']) +def test_find_torch(tmp_path, space): + store = HnswDocumentIndex[TorchDoc](work_dir=str(tmp_path)) + + index_docs = [TorchDoc(tens=np.zeros(10)) for _ in range(10)] + index_docs.append(TorchDoc(tens=np.ones(10))) + store.index(index_docs) + + for doc in index_docs: + assert isinstance(doc.tens, TorchTensor) + + query = TorchDoc(tens=np.ones(10)) + + result_docs, scores = store.find(query, search_field='tens', limit=5) + + assert len(result_docs) == 5 + assert len(scores) == 5 + for doc in result_docs: + assert isinstance(doc.tens, TorchTensor) + assert result_docs[0].id == index_docs[-1].id + assert torch.allclose(result_docs[0].tens, index_docs[-1].tens) + for result in result_docs[1:]: + assert torch.allclose(result.tens, torch.zeros(10, 
dtype=torch.float64)) + + +@pytest.mark.tensorflow +def test_find_tensorflow(tmp_path): + from docarray.typing import TensorFlowTensor + + class TfDoc(BaseDocument): + tens: TensorFlowTensor[10] + + store = HnswDocumentIndex[TfDoc](work_dir=str(tmp_path)) + + index_docs = [TfDoc(tens=np.zeros(10)) for _ in range(10)] + index_docs.append(TfDoc(tens=np.ones(10))) + store.index(index_docs) + + for doc in index_docs: + assert isinstance(doc.tens, TensorFlowTensor) + + query = TfDoc(tens=np.ones(10)) + + result_docs, scores = store.find(query, search_field='tens', limit=5) + + assert len(result_docs) == 5 + assert len(scores) == 5 + for doc in result_docs: + assert isinstance(doc.tens, TensorFlowTensor) + assert result_docs[0].id == index_docs[-1].id + assert np.allclose( + result_docs[0].tens.unwrap().numpy(), index_docs[-1].tens.unwrap().numpy() + ) + for result in result_docs[1:]: + assert np.allclose(result.tens.unwrap().numpy(), np.zeros(10)) + + @pytest.mark.parametrize('space', ['cosine', 'l2', 'ip']) def test_find_flat_schema(tmp_path, space): class FlatSchema(BaseDocument): diff --git a/tests/index/hnswlib/test_index_get_del.py b/tests/index/hnswlib/test_index_get_del.py index df57301fc6a..0d4eb02f537 100644 --- a/tests/index/hnswlib/test_index_get_del.py +++ b/tests/index/hnswlib/test_index_get_del.py @@ -1,10 +1,15 @@ +import os +from typing import Optional + import numpy as np import pytest +import torch from pydantic import Field from docarray import BaseDocument, DocumentArray +from docarray.documents import ImageDoc, TextDoc from docarray.index import HnswDocumentIndex -from docarray.typing import NdArray +from docarray.typing import NdArray, NdArrayEmbedding, TorchTensor pytestmark = [pytest.mark.slow, pytest.mark.index] @@ -26,6 +31,10 @@ class DeepNestedDoc(BaseDocument): d: NestedDoc +class TorchDoc(BaseDocument): + tens: TorchTensor[10] + + @pytest.fixture def ten_simple_docs(): return [SimpleDoc(tens=np.random.randn(10)) for _ in range(10)] @@ -88,6 
+97,77 @@ def test_index_nested_schema(ten_nested_docs, tmp_path, use_docarray): assert index.get_current_count() == 10 +def test_index_torch(tmp_path): + docs = [TorchDoc(tens=np.random.randn(10)) for _ in range(10)] + assert isinstance(docs[0].tens, torch.Tensor) + assert isinstance(docs[0].tens, TorchTensor) + + store = HnswDocumentIndex[TorchDoc](work_dir=str(tmp_path)) + + store.index(docs) + assert store.num_docs() == 10 + for index in store._hnsw_indices.values(): + assert index.get_current_count() == 10 + + +@pytest.mark.tensorflow +def test_index_tf(tmp_path): + from docarray.typing import TensorFlowTensor + + class TfDoc(BaseDocument): + tens: TensorFlowTensor[10] + + docs = [TfDoc(tens=np.random.randn(10)) for _ in range(10)] + # assert isinstance(docs[0].tens, torch.Tensor) + assert isinstance(docs[0].tens, TensorFlowTensor) + + store = HnswDocumentIndex[TfDoc](work_dir=str(tmp_path)) + + store.index(docs) + assert store.num_docs() == 10 + for index in store._hnsw_indices.values(): + assert index.get_current_count() == 10 + + +def test_index_builtin_docs(tmp_path): + # TextDoc + class TextSchema(TextDoc): + embedding: Optional[NdArrayEmbedding] = Field(dim=10) + + store = HnswDocumentIndex[TextSchema](work_dir=str(tmp_path)) + + store.index( + DocumentArray[TextDoc]( + [TextDoc(embedding=np.random.randn(10), text=f'{i}') for i in range(10)] + ) + ) + assert store.num_docs() == 10 + for index in store._hnsw_indices.values(): + assert index.get_current_count() == 10 + + # ImageDoc + class ImageSchema(ImageDoc): + embedding: Optional[NdArrayEmbedding] = Field(dim=10) + + store = HnswDocumentIndex[ImageSchema]( + work_dir=str(os.path.join(tmp_path, 'image')) + ) + + store.index( + DocumentArray[ImageDoc]( + [ + ImageDoc( + embedding=np.random.randn(10), tensor=np.random.randn(3, 224, 224) + ) + for _ in range(10) + ] + ) + ) + assert store.num_docs() == 10 + for index in store._hnsw_indices.values(): + assert index.get_current_count() == 10 + + def 
test_get_single(ten_simple_docs, ten_flat_docs, ten_nested_docs, tmp_path): simple_path = tmp_path / 'simple' flat_path = tmp_path / 'flat' diff --git a/tests/index/hnswlib/test_persist_data.py b/tests/index/hnswlib/test_persist_data.py new file mode 100644 index 00000000000..a0a86eee9ab --- /dev/null +++ b/tests/index/hnswlib/test_persist_data.py @@ -0,0 +1,80 @@ +import numpy as np +import pytest +from pydantic import Field + +from docarray import BaseDocument +from docarray.index import HnswDocumentIndex +from docarray.typing import NdArray + +pytestmark = [pytest.mark.slow, pytest.mark.index] + + +class SimpleDoc(BaseDocument): + tens: NdArray[10] = Field(dim=1000) + + +class NestedDoc(BaseDocument): + d: SimpleDoc + tens: NdArray[50] + + +def test_persist_and_restore(tmp_path): + query = SimpleDoc(tens=np.random.random((10,))) + + # create index + store = HnswDocumentIndex[SimpleDoc](work_dir=str(tmp_path)) + store.index([SimpleDoc(tens=np.random.random((10,))) for _ in range(10)]) + assert store.num_docs() == 10 + find_results_before = store.find(query, search_field='tens', limit=5) + + # delete and restore + del store + store = HnswDocumentIndex[SimpleDoc](work_dir=str(tmp_path)) + assert store.num_docs() == 10 + find_results_after = store.find(query, search_field='tens', limit=5) + for doc_before, doc_after in zip(find_results_before[0], find_results_after[0]): + assert doc_before.id == doc_after.id + assert (doc_before.tens == doc_after.tens).all() + + # add new data + store.index([SimpleDoc(tens=np.random.random((10,))) for _ in range(5)]) + assert store.num_docs() == 15 + + +def test_persist_and_restore_nested(tmp_path): + query = NestedDoc( + tens=np.random.random((50,)), d=SimpleDoc(tens=np.random.random((10,))) + ) + + # create index + store = HnswDocumentIndex[NestedDoc](work_dir=str(tmp_path)) + store.index( + [ + NestedDoc( + tens=np.random.random((50,)), d=SimpleDoc(tens=np.random.random((10,))) + ) + for _ in range(10) + ] + ) + assert 
store.num_docs() == 10 + find_results_before = store.find(query, search_field='d__tens', limit=5) + + # delete and restore + del store + store = HnswDocumentIndex[NestedDoc](work_dir=str(tmp_path)) + assert store.num_docs() == 10 + find_results_after = store.find(query, search_field='d__tens', limit=5) + for doc_before, doc_after in zip(find_results_before[0], find_results_after[0]): + assert doc_before.id == doc_after.id + assert (doc_before.tens == doc_after.tens).all() + + # add new data + store.index( + [ + NestedDoc( + tens=np.random.random((50,)), d=SimpleDoc(tens=np.random.random((10,))) + ) + for _ in range(5) + ] + ) + assert store.num_docs() == 15