Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ jobs:
- name: Test
id: test
run: |
poetry run pytest -m 'index' tests
poetry run pytest -m 'index and (not tensorflow)' tests
timeout-minutes: 30

docarray-test-tensorflow:
Expand Down
51 changes: 40 additions & 11 deletions docarray/index/abstract.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, field, replace
from typing import (
Expand All @@ -20,21 +21,26 @@

import numpy as np
from pydantic.error_wrappers import ValidationError
from typing_inspect import get_args, is_union_type
from typing_inspect import get_args, is_optional_type, is_union_type

from docarray import BaseDocument, DocumentArray
from docarray.array.abstract_array import AnyDocumentArray
from docarray.typing import AnyTensor
from docarray.utils._typing import unwrap_optional_type
from docarray.utils.find import FindResult, _FindResult
from docarray.utils.misc import torch_imported
import logging
from docarray.utils.misc import is_tf_available, torch_imported

if TYPE_CHECKING:
from pydantic.fields import ModelField

if torch_imported:
import torch

if is_tf_available():
import tensorflow as tf # type: ignore

from docarray.typing import TensorFlowTensor

TSchema = TypeVar('TSchema', bound=BaseDocument)


Expand Down Expand Up @@ -614,7 +620,13 @@ def _get_col_value_dict(
docs_seq = docs

def _col_gen(col_name: str):
return (self._get_values_by_column([doc], col_name)[0] for doc in docs_seq)
return (
self._to_numpy(
self._get_values_by_column([doc], col_name)[0],
allow_passthrough=True,
)
for doc in docs_seq
)

return {col_name: _col_gen(col_name) for col_name in self._column_infos}

Expand Down Expand Up @@ -697,7 +709,11 @@ def _create_column_infos(
"""
column_infos: Dict[str, _ColumnInfo] = dict()
for field_name, type_, field_ in self._flatten_schema(schema):
if is_union_type(type_):
if is_optional_type(type_):
column_infos[field_name] = self._create_single_column(
field_, unwrap_optional_type(type_)
)
elif is_union_type(type_):
raise ValueError(
'Union types are not supported in the schema of a DocumentIndex.'
f' Instead of using type {type_} use a single specific type.'
Expand Down Expand Up @@ -793,15 +809,28 @@ def _validate_docs(

return DocumentArray[BaseDocument].construct(out_docs)

def _to_numpy(self, val: Any) -> Any:
def _to_numpy(self, val: Any, allow_passthrough=False) -> Any:
"""
Converts a value to a numpy array, if possible.

:param val: The value to convert
:param allow_passthrough: If True, the value is returned as-is if it is not convertible to a numpy array.
If False, a `ValueError` is raised if the value is not convertible to a numpy array.
:return: The value as a numpy array, or as-is if `allow_passthrough` is True and the value is not convertible
"""
if isinstance(val, np.ndarray):
return val
elif isinstance(val, (list, tuple)):
if is_tf_available() and isinstance(val, TensorFlowTensor):
return val.unwrap().numpy()
if isinstance(val, (list, tuple)):
return np.array(val)
elif torch_imported and isinstance(val, torch.Tensor):
if (torch_imported and isinstance(val, torch.Tensor)) or (
is_tf_available() and isinstance(val, tf.Tensor)
):
return val.numpy()
else:
raise ValueError(f'Unsupported input type for {type(self)}: {type(val)}')
if allow_passthrough:
return val
raise ValueError(f'Unsupported input type for {type(self)}: {type(val)}')

def _convert_dict_to_doc(
self, doc_dict: Dict[str, Any], schema: Type[BaseDocument]
Expand All @@ -815,7 +844,7 @@ def _convert_dict_to_doc(
"""

for field_name, _ in schema.__fields__.items():
t_ = schema._get_field_type(field_name)
t_ = unwrap_optional_type(schema._get_field_type(field_name))
if issubclass(t_, BaseDocument):
inner_dict = {}

Expand Down
41 changes: 21 additions & 20 deletions docarray/index/backends/hnswlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import hnswlib
import numpy as np

import docarray.typing
from docarray import BaseDocument, DocumentArray
from docarray.index.abstract import (
BaseDocumentIndex,
Expand All @@ -32,17 +31,25 @@
from docarray.proto import DocumentProto
from docarray.utils.filter import filter_docs
from docarray.utils.find import _FindResult
from docarray.utils.misc import is_np_int, torch_imported
from docarray.utils.misc import is_np_int, is_tf_available, is_torch_available

TSchema = TypeVar('TSchema', bound=BaseDocument)
T = TypeVar('T', bound='HnswDocumentIndex')

HNSWLIB_PY_VEC_TYPES = [list, tuple, np.ndarray]
if torch_imported:
if is_torch_available():
import torch

HNSWLIB_PY_VEC_TYPES.append(torch.Tensor)

if is_tf_available():
import tensorflow as tf # type: ignore

from docarray.typing import TensorFlowTensor

HNSWLIB_PY_VEC_TYPES.append(tf.Tensor)
HNSWLIB_PY_VEC_TYPES.append(TensorFlowTensor)


def _collect_query_args(method_name: str): # TODO: use partialmethod instead
def inner(self, *args, **kwargs):
Expand Down Expand Up @@ -84,8 +91,14 @@ def __init__(self, db_config=None, **kwargs):
self._hnsw_indices = {}
for col_name, col in self._column_infos.items():
if not col.config:
self._logger.warning(
f'No index was created for `{col_name}` as it does not have a config'
# non-tensor type; don't create an index
continue
if not load_existing and (
(not col.n_dim and col.config['dim'] < 0) or not col.config['index']
):
# tensor type, but don't index
self._logger.info(
f'Not indexing column {col_name}; either `index=False` is set or no dimensionality is specified'
)
continue
if load_existing:
Expand Down Expand Up @@ -133,7 +146,8 @@ class RuntimeConfig(BaseDocumentIndex.RuntimeConfig):
default_column_config: Dict[Type, Dict[str, Any]] = field(
default_factory=lambda: {
np.ndarray: {
'dim': 128,
'dim': -1,
'index': True, # if False, don't index at all
'space': 'l2', # 'l2', 'ip', 'cosine'
'max_elements': 1024,
'ef_construction': 200,
Expand All @@ -157,10 +171,7 @@ def python_type_to_db_type(self, python_type: Type) -> Any:
if issubclass(python_type, allowed_type):
return np.ndarray

if python_type == docarray.typing.ID:
return None

raise ValueError(f'Unsupported column type for {type(self)}: {python_type}')
return None # all types allowed, but no db type needed

def _index(self, column_data_dic, **kwargs):
# not needed, we implement `index` directly
Expand Down Expand Up @@ -328,16 +339,6 @@ def _load_index(self, col_name: str, col: '_ColumnInfo') -> hnswlib.Index:
index.load_index(self._hnsw_locations[col_name])
return index

def _to_numpy(self, val: Any) -> Any:
if isinstance(val, np.ndarray):
return val
elif isinstance(val, (list, tuple)):
return np.array(val)
elif torch_imported and isinstance(val, torch.Tensor):
return val.numpy()
else:
raise ValueError(f'Unsupported input type for {type(self)}: {type(val)}')

# HNSWLib helpers
def _create_index_class(self, col: '_ColumnInfo') -> hnswlib.Index:
"""Create an instance of hnswlib.index without initializing it."""
Expand Down
16 changes: 15 additions & 1 deletion docarray/utils/_typing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Optional

from typing_inspect import get_args, is_union_type
from typing_inspect import get_args, is_optional_type, is_union_type

from docarray.typing.tensor.abstract_tensor import AbstractTensor

Expand Down Expand Up @@ -32,3 +32,17 @@ def change_cls_name(cls: type, new_name: str, scope: Optional[dict] = None) -> N
scope[new_name] = cls
cls.__qualname__ = cls.__qualname__[: -len(cls.__name__)] + new_name
cls.__name__ = new_name


def unwrap_optional_type(type_: Any) -> Any:
    """Return the type wrapped by an Optional type, e.g.
    `unwrap_optional_type(Optional[str]) == str`;
    `unwrap_optional_type(Union[None, int, None]) == int`.

    Types that are not Optional are returned unchanged. If the Optional
    union has several non-None members, only the first one is returned.

    :param type_: the type to unwrap
    :return: the "core" type of an Optional type, or `type_` itself if there
        is nothing to unwrap
    """
    if not is_optional_type(type_):
        return type_
    none_type = type(None)
    for arg in get_args(type_):
        if arg is not none_type:
            return arg
    # No non-None member was found (e.g. `type_` is `type(None)` itself, which
    # `is_optional_type` also accepts). Hand the type back instead of falling
    # off the end and returning the value `None`, which is not a type.
    return type_
64 changes: 63 additions & 1 deletion tests/index/hnswlib/test_find.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import numpy as np
import pytest
import torch
from pydantic import Field

from docarray import BaseDocument
from docarray.index import HnswDocumentIndex
from docarray.typing import NdArray
from docarray.typing import NdArray, TorchTensor

pytestmark = [pytest.mark.slow, pytest.mark.index]

Expand All @@ -26,6 +27,10 @@ class DeepNestedDoc(BaseDocument):
d: NestedDoc


class TorchDoc(BaseDocument):
    """Document schema with a single `TorchTensor[10]` field named `tens`."""

    tens: TorchTensor[10]


@pytest.mark.parametrize('space', ['cosine', 'l2', 'ip'])
def test_find_simple_schema(tmp_path, space):
class SimpleSchema(BaseDocument):
Expand All @@ -49,6 +54,63 @@ class SimpleSchema(BaseDocument):
assert np.allclose(result.tens, np.zeros(10))


@pytest.mark.parametrize('space', ['cosine', 'l2', 'ip'])
def test_find_torch(tmp_path, space):
    """Check that documents with a torch tensor field can be indexed and searched.

    Ten all-zeros documents and one all-ones document are indexed. Querying
    with an all-ones tensor must return the all-ones document first, and all
    tensors must remain `TorchTensor` instances before and after the search.
    """

    class SpaceTorchDoc(BaseDocument):
        # Bind the parametrized metric to the column; without this, all three
        # `space` runs would build exactly the same index.
        tens: TorchTensor[10] = Field(space=space)

    store = HnswDocumentIndex[SpaceTorchDoc](work_dir=str(tmp_path))

    index_docs = [SpaceTorchDoc(tens=np.zeros(10)) for _ in range(10)]
    index_docs.append(SpaceTorchDoc(tens=np.ones(10)))
    store.index(index_docs)

    # indexing must not change the tensor type of the input documents
    for doc in index_docs:
        assert isinstance(doc.tens, TorchTensor)

    query = SpaceTorchDoc(tens=np.ones(10))

    result_docs, scores = store.find(query, search_field='tens', limit=5)

    assert len(result_docs) == 5
    assert len(scores) == 5
    for doc in result_docs:
        assert isinstance(doc.tens, TorchTensor)
    # the all-ones document is the closest match to the all-ones query
    assert result_docs[0].id == index_docs[-1].id
    assert torch.allclose(result_docs[0].tens, index_docs[-1].tens)
    for result in result_docs[1:]:
        assert torch.allclose(result.tens, torch.zeros(10, dtype=torch.float64))


@pytest.mark.tensorflow
def test_find_tensorflow(tmp_path):
    """Check that documents with a TensorFlow tensor field can be indexed and searched.

    Ten all-zeros documents and one all-ones document are indexed. Querying
    with an all-ones tensor must return the all-ones document first, followed
    by all-zeros documents, with every tensor staying a `TensorFlowTensor`.
    """
    # imported lazily so that the module can be collected without tensorflow
    from docarray.typing import TensorFlowTensor

    class TfDoc(BaseDocument):
        tens: TensorFlowTensor[10]

    def _as_numpy(doc):
        # unwrap to the underlying tensor before converting to numpy
        return doc.tens.unwrap().numpy()

    doc_index = HnswDocumentIndex[TfDoc](work_dir=str(tmp_path))

    zero_docs = [TfDoc(tens=np.zeros(10)) for _ in range(10)]
    ones_doc = TfDoc(tens=np.ones(10))
    indexed = zero_docs + [ones_doc]
    doc_index.index(indexed)

    assert all(isinstance(doc.tens, TensorFlowTensor) for doc in indexed)

    found, scores = doc_index.find(
        TfDoc(tens=np.ones(10)), search_field='tens', limit=5
    )

    assert len(found) == 5
    assert len(scores) == 5
    assert all(isinstance(doc.tens, TensorFlowTensor) for doc in found)

    # best hit is the all-ones document; the remaining hits are all-zeros
    top_hit = found[0]
    assert top_hit.id == indexed[-1].id
    assert np.allclose(_as_numpy(top_hit), _as_numpy(indexed[-1]))
    for other in found[1:]:
        assert np.allclose(_as_numpy(other), np.zeros(10))


@pytest.mark.parametrize('space', ['cosine', 'l2', 'ip'])
def test_find_flat_schema(tmp_path, space):
class FlatSchema(BaseDocument):
Expand Down
Loading