Skip to content
Merged
95 changes: 55 additions & 40 deletions docarray/index/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
from docarray import BaseDoc, DocArray
from docarray.array.abstract_array import AnyDocArray
from docarray.typing import AnyTensor
from docarray.utils._internal._typing import unwrap_optional_type
from docarray.typing.tensor.abstract_tensor import AbstractTensor
from docarray.utils._internal._typing import is_tensor_union
from docarray.utils._internal.misc import is_tf_available, torch_imported
from docarray.utils.find import FindResult, _FindResult

Expand Down Expand Up @@ -676,22 +677,37 @@ def _flatten_schema(

if is_union_type(t_):
union_args = get_args(t_)
if len(union_args) == 2 and type(None) in union_args:

if is_tensor_union(t_):
names_types_fields.append(
(name_prefix + field_name, AbstractTensor, field_)
)

elif len(union_args) == 2 and type(None) in union_args:
# simple "Optional" type, treat as special case:
# treat as if it was a single non-optional type
for t_arg in union_args:
if t_arg is type(None):
pass
elif issubclass(t_arg, BaseDoc):
names_types_fields.extend(
cls._flatten_schema(t_arg, name_prefix=inner_prefix)
)
if t_arg is not type(None):
if issubclass(t_arg, BaseDoc):
names_types_fields.extend(
cls._flatten_schema(t_arg, name_prefix=inner_prefix)
)
else:
names_types_fields.append(
(name_prefix + field_name, t_arg, field_)
)
else:
names_types_fields.append((field_name, t_, field_))
raise ValueError(
f'Union type {t_} is not supported. Only Union of subclasses of AbstractTensor or Union[type, None] are supported.'
)
elif issubclass(t_, BaseDoc):
names_types_fields.extend(
cls._flatten_schema(t_, name_prefix=inner_prefix)
)
elif issubclass(t_, AbstractTensor):
names_types_fields.append(
(name_prefix + field_name, AbstractTensor, field_)
)
else:
names_types_fields.append((name_prefix + field_name, t_, field_))
return names_types_fields
Expand All @@ -705,16 +721,8 @@ def _create_column_infos(self, schema: Type[BaseDoc]) -> Dict[str, _ColumnInfo]:
"""
column_infos: Dict[str, _ColumnInfo] = dict()
for field_name, type_, field_ in self._flatten_schema(schema):
if is_optional_type(type_):
column_infos[field_name] = self._create_single_column(
field_, unwrap_optional_type(type_)
)
elif is_union_type(type_):
raise ValueError(
'Union types are not supported in the schema of a DocumentIndex.'
f' Instead of using type {type_} use a single specific type.'
)
elif issubclass(type_, AnyDocArray):
# Union types are handled in _flatten_schema
if issubclass(type_, AnyDocArray):
raise ValueError(
'Indexing field of DocArray type (=subindex)'
'is not yet supported.'
Expand All @@ -725,7 +733,6 @@ def _create_column_infos(self, schema: Type[BaseDoc]) -> Dict[str, _ColumnInfo]:

def _create_single_column(self, field: 'ModelField', type_: Type) -> _ColumnInfo:
custom_config = field.field_info.extra

if 'col_type' in custom_config.keys():
db_type = custom_config['col_type']
custom_config.pop('col_type')
Expand All @@ -740,13 +747,13 @@ def _create_single_column(self, field: 'ModelField', type_: Type) -> _ColumnInfo
config.update(custom_config)
# parse n_dim from parametrized tensor type
if (
hasattr(type_, '__docarray_target_shape__')
and type_.__docarray_target_shape__
hasattr(field.type_, '__docarray_target_shape__')
and field.type_.__docarray_target_shape__
):
if len(type_.__docarray_target_shape__) == 1:
n_dim = type_.__docarray_target_shape__[0]
if len(field.type_.__docarray_target_shape__) == 1:
n_dim = field.type_.__docarray_target_shape__[0]
else:
n_dim = type_.__docarray_target_shape__
n_dim = field.type_.__docarray_target_shape__
else:
n_dim = None
return _ColumnInfo(
Expand Down Expand Up @@ -776,19 +783,23 @@ def _validate_docs(
)
reference_names = [name for (name, _, _) in reference_schema_flat]
reference_types = [t_ for (_, t_, _) in reference_schema_flat]
try:
input_schema_flat = self._flatten_schema(docs.document_type)
except ValueError:
pass
else:
input_names = [name for (name, _, _) in input_schema_flat]
input_types = [t_ for (_, t_, _) in input_schema_flat]
# this could be relaxed in the future,
# see schema translation ideas in the design doc
names_compatible = reference_names == input_names
types_compatible = all(
(issubclass(t2, t1))
for (t1, t2) in zip(reference_types, input_types)
)
if names_compatible and types_compatible:
return docs

input_schema_flat = self._flatten_schema(docs.document_type)
input_names = [name for (name, _, _) in input_schema_flat]
input_types = [t_ for (_, t_, _) in input_schema_flat]
# this could be relaxed in the future,
# see schema translation ideas in the design doc
names_compatible = reference_names == input_names
types_compatible = all(
(not is_union_type(t2) and issubclass(t1, t2))
for (t1, t2) in zip(reference_types, input_types)
)
if names_compatible and types_compatible:
return docs
out_docs = []
for i in range(len(docs)):
# validate the data
Expand Down Expand Up @@ -836,10 +847,14 @@ def _convert_dict_to_doc(
:param schema: The schema of the Document object
:return: A Document object
"""

for field_name, _ in schema.__fields__.items():
t_ = unwrap_optional_type(schema._get_field_type(field_name))
if issubclass(t_, BaseDoc):
t_ = schema._get_field_type(field_name)
if is_optional_type(t_):
for t_arg in get_args(t_):
if t_arg is not type(None):
t_ = t_arg

if not is_union_type(t_) and issubclass(t_, BaseDoc):
inner_dict = {}

fields = [
Expand Down
3 changes: 2 additions & 1 deletion docarray/index/backends/hnswlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,15 @@
_raise_not_supported,
)
from docarray.proto import DocumentProto
from docarray.typing.tensor.abstract_tensor import AbstractTensor
from docarray.utils._internal.misc import is_np_int, is_tf_available, is_torch_available
from docarray.utils.filter import filter_docs
from docarray.utils.find import _FindResult

TSchema = TypeVar('TSchema', bound=BaseDoc)
T = TypeVar('T', bound='HnswDocumentIndex')

HNSWLIB_PY_VEC_TYPES = [list, tuple, np.ndarray]
HNSWLIB_PY_VEC_TYPES = [list, tuple, np.ndarray, AbstractTensor]
if is_torch_available():
import torch

Expand Down
16 changes: 1 addition & 15 deletions docarray/utils/_internal/_typing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Optional

from typing_inspect import get_args, is_optional_type, is_union_type
from typing_inspect import get_args, is_union_type

from docarray.typing.tensor.abstract_tensor import AbstractTensor

Expand Down Expand Up @@ -32,17 +32,3 @@ def change_cls_name(cls: type, new_name: str, scope: Optional[dict] = None) -> N
scope[new_name] = cls
cls.__qualname__ = cls.__qualname__[: -len(cls.__name__)] + new_name
cls.__name__ = new_name


def unwrap_optional_type(type_: Any) -> Any:
    """Strip the ``None`` alternative from an Optional type.

    E.g. ``unwrap_optional_type(Optional[str]) == str`` and
    ``unwrap_optional_type(Union[None, int, None]) == int``.
    Non-Optional types are returned unchanged.

    :param type_: the type to unwrap
    :return: the "core" (non-``None``) type of an Optional type
    """
    if is_optional_type(type_):
        # An Optional is a Union containing NoneType; the core type
        # is the first (and, for a simple Optional, only) other member.
        return next(arg for arg in get_args(type_) if arg is not type(None))
    return type_
7 changes: 5 additions & 2 deletions docs/how_to/add_doc_index.md
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ class _ColumnInfo:
config: Dict[str, Any]
```

- `docarray_type` is the type of the column in DocArray, e.g. `NdArray` or `str`
- `docarray_type` is the type of the column in DocArray, e.g. `AbstractTensor` or `str`
- `db_type` is the type of the column in the Document Index, e.g. `np.ndarray` or `str`. You can customize the mapping from `docarray_type` to `db_type`, as we will see later.
- `config` is a dictionary of configurations for the column. For example, for the `other_tensor` column above, this would contain the `space` and `dim` configurations.
- `n_dim` is the dimensionality of the column, e.g. `100` for a 100-dimensional vector. See further guidance on this below.
Expand All @@ -153,6 +153,9 @@ By default, it holds that `_ColumnInfo.docarray_type == self.python_type_to_db_t
However, you should not rely on this, because a user can manually specify a different db_type.
Therefore, your implementation should rely on `_ColumnInfo.db_type` and not directly call `python_type_to_db_type()`.

**Caution**
If a subclass of `AbstractTensor` appears in the Document Index's schema (i.e. `TorchTensor`, `NdArray`, or `TensorFlowTensor`), then `_ColumnInfo.docarray_type` will simply show `AbstractTensor` instead of the specific subclass. This is because the abstract class normalizes all input data of type `AbstractTensor` to `np.ndarray` anyway, which should make your life easier. Just be sure to properly handle `AbstractTensor` as a possible value of `_ColumnInfo.docarray_type`, and you won't have to worry about the differences between torch, tf, and np.

### Properly handle `n_dim`

`_ColumnInfo.n_dim` is automatically obtained from type parametrizations of the form `NdArray[100]`;
Expand Down Expand Up @@ -296,7 +299,7 @@ The details of each method should become clear from the docstrings and type hint

This method is slightly special, because 1) it is not exposed to the user, and 2) you absolutely have to implement it.

It is intended to do the following: It takes a type of a field in the store's schema (e.g. `NdArray` for `tensor`), and returns the corresponding type in the database (e.g. `np.ndarray`).
It is intended to do the following: It takes a type of a field in the store's schema (e.g. `AbstractTensor` for `tensor`), and returns the corresponding type in the database (e.g. `np.ndarray`).
The `BaseDocIndex` class uses this information to create and populate the `_ColumnInfo`s in `self._column_infos`.

If the user wants to change the default behaviour, one can set the db type by using the `col_type` field:
Expand Down
Loading