diff --git a/docarray/index/abstract.py b/docarray/index/abstract.py index c587846c8c4..3f0ab25a1b0 100644 --- a/docarray/index/abstract.py +++ b/docarray/index/abstract.py @@ -26,7 +26,8 @@ from docarray import BaseDoc, DocArray from docarray.array.abstract_array import AnyDocArray from docarray.typing import AnyTensor -from docarray.utils._internal._typing import unwrap_optional_type +from docarray.typing.tensor.abstract_tensor import AbstractTensor +from docarray.utils._internal._typing import is_tensor_union from docarray.utils._internal.misc import is_tf_available, torch_imported from docarray.utils.find import FindResult, _FindResult @@ -676,22 +677,37 @@ def _flatten_schema( if is_union_type(t_): union_args = get_args(t_) - if len(union_args) == 2 and type(None) in union_args: + + if is_tensor_union(t_): + names_types_fields.append( + (name_prefix + field_name, AbstractTensor, field_) + ) + + elif len(union_args) == 2 and type(None) in union_args: # simple "Optional" type, treat as special case: # treat as if it was a single non-optional type for t_arg in union_args: - if t_arg is type(None): - pass - elif issubclass(t_arg, BaseDoc): - names_types_fields.extend( - cls._flatten_schema(t_arg, name_prefix=inner_prefix) - ) + if t_arg is not type(None): + if issubclass(t_arg, BaseDoc): + names_types_fields.extend( + cls._flatten_schema(t_arg, name_prefix=inner_prefix) + ) + else: + names_types_fields.append( + (name_prefix + field_name, t_arg, field_) + ) else: - names_types_fields.append((field_name, t_, field_)) + raise ValueError( + f'Union type {t_} is not supported. Only Union of subclasses of AbstractTensor or Union[type, None] are supported.' 
+ ) elif issubclass(t_, BaseDoc): names_types_fields.extend( cls._flatten_schema(t_, name_prefix=inner_prefix) ) + elif issubclass(t_, AbstractTensor): + names_types_fields.append( + (name_prefix + field_name, AbstractTensor, field_) + ) else: names_types_fields.append((name_prefix + field_name, t_, field_)) return names_types_fields @@ -705,16 +721,8 @@ def _create_column_infos(self, schema: Type[BaseDoc]) -> Dict[str, _ColumnInfo]: """ column_infos: Dict[str, _ColumnInfo] = dict() for field_name, type_, field_ in self._flatten_schema(schema): - if is_optional_type(type_): - column_infos[field_name] = self._create_single_column( - field_, unwrap_optional_type(type_) - ) - elif is_union_type(type_): - raise ValueError( - 'Union types are not supported in the schema of a DocumentIndex.' - f' Instead of using type {type_} use a single specific type.' - ) - elif issubclass(type_, AnyDocArray): + # Union types are handled in _flatten_schema + if issubclass(type_, AnyDocArray): raise ValueError( 'Indexing field of DocArray type (=subindex)' 'is not yet supported.' 
@@ -725,7 +733,6 @@ def _create_column_infos(self, schema: Type[BaseDoc]) -> Dict[str, _ColumnInfo]: def _create_single_column(self, field: 'ModelField', type_: Type) -> _ColumnInfo: custom_config = field.field_info.extra - if 'col_type' in custom_config.keys(): db_type = custom_config['col_type'] custom_config.pop('col_type') @@ -740,13 +747,13 @@ def _create_single_column(self, field: 'ModelField', type_: Type) -> _ColumnInfo config.update(custom_config) # parse n_dim from parametrized tensor type if ( - hasattr(type_, '__docarray_target_shape__') - and type_.__docarray_target_shape__ + hasattr(field.type_, '__docarray_target_shape__') + and field.type_.__docarray_target_shape__ ): - if len(type_.__docarray_target_shape__) == 1: - n_dim = type_.__docarray_target_shape__[0] + if len(field.type_.__docarray_target_shape__) == 1: + n_dim = field.type_.__docarray_target_shape__[0] else: - n_dim = type_.__docarray_target_shape__ + n_dim = field.type_.__docarray_target_shape__ else: n_dim = None return _ColumnInfo( @@ -776,19 +783,23 @@ def _validate_docs( ) reference_names = [name for (name, _, _) in reference_schema_flat] reference_types = [t_ for (_, t_, _) in reference_schema_flat] + try: + input_schema_flat = self._flatten_schema(docs.document_type) + except ValueError: + pass + else: + input_names = [name for (name, _, _) in input_schema_flat] + input_types = [t_ for (_, t_, _) in input_schema_flat] + # this could be relaxed in the future, + # see schema translation ideas in the design doc + names_compatible = reference_names == input_names + types_compatible = all( + (issubclass(t2, t1)) + for (t1, t2) in zip(reference_types, input_types) + ) + if names_compatible and types_compatible: + return docs - input_schema_flat = self._flatten_schema(docs.document_type) - input_names = [name for (name, _, _) in input_schema_flat] - input_types = [t_ for (_, t_, _) in input_schema_flat] - # this could be relaxed in the future, - # see schema translation ideas in the design 
doc - names_compatible = reference_names == input_names - types_compatible = all( - (not is_union_type(t2) and issubclass(t1, t2)) - for (t1, t2) in zip(reference_types, input_types) - ) - if names_compatible and types_compatible: - return docs out_docs = [] for i in range(len(docs)): # validate the data @@ -836,10 +847,14 @@ def _convert_dict_to_doc( :param schema: The schema of the Document object :return: A Document object """ - for field_name, _ in schema.__fields__.items(): - t_ = unwrap_optional_type(schema._get_field_type(field_name)) - if issubclass(t_, BaseDoc): + t_ = schema._get_field_type(field_name) + if is_optional_type(t_): + for t_arg in get_args(t_): + if t_arg is not type(None): + t_ = t_arg + + if not is_union_type(t_) and issubclass(t_, BaseDoc): inner_dict = {} fields = [ diff --git a/docarray/index/backends/hnswlib.py b/docarray/index/backends/hnswlib.py index ab8aa6e56c1..5dfbc987741 100644 --- a/docarray/index/backends/hnswlib.py +++ b/docarray/index/backends/hnswlib.py @@ -29,6 +29,7 @@ _raise_not_supported, ) from docarray.proto import DocumentProto +from docarray.typing.tensor.abstract_tensor import AbstractTensor from docarray.utils._internal.misc import is_np_int, is_tf_available, is_torch_available from docarray.utils.filter import filter_docs from docarray.utils.find import _FindResult @@ -36,7 +37,7 @@ TSchema = TypeVar('TSchema', bound=BaseDoc) T = TypeVar('T', bound='HnswDocumentIndex') -HNSWLIB_PY_VEC_TYPES = [list, tuple, np.ndarray] +HNSWLIB_PY_VEC_TYPES = [list, tuple, np.ndarray, AbstractTensor] if is_torch_available(): import torch diff --git a/docarray/utils/_internal/_typing.py b/docarray/utils/_internal/_typing.py index 9bbc0162432..62680cf964e 100644 --- a/docarray/utils/_internal/_typing.py +++ b/docarray/utils/_internal/_typing.py @@ -1,6 +1,6 @@ from typing import Any, Optional -from typing_inspect import get_args, is_optional_type, is_union_type +from typing_inspect import get_args, is_union_type from 
docarray.typing.tensor.abstract_tensor import AbstractTensor @@ -32,17 +32,3 @@ def change_cls_name(cls: type, new_name: str, scope: Optional[dict] = None) -> N scope[new_name] = cls cls.__qualname__ = cls.__qualname__[: -len(cls.__name__)] + new_name cls.__name__ = new_name - - -def unwrap_optional_type(type_: Any) -> Any: - """Return the type of an Optional type, e.g. `unwrap_optional(Optional[str]) == str`; - `unwrap_optional(Union[None, int, None]) == int`. - - :param type_: the type to unwrap - :return: the "core" type of an Optional type - """ - if not is_optional_type(type_): - return type_ - for arg in get_args(type_): - if arg is not type(None): - return arg diff --git a/docs/how_to/add_doc_index.md b/docs/how_to/add_doc_index.md index 8fb03b9978b..4f0125428e1 100644 --- a/docs/how_to/add_doc_index.md +++ b/docs/how_to/add_doc_index.md @@ -139,7 +139,7 @@ class _ColumnInfo: config: Dict[str, Any] ``` -- `docarray_type` is the type of the column in DocArray, e.g. `NdArray` or `str` +- `docarray_type` is the type of the column in DocArray, e.g. `AbstractTensor` or `str` - `db_type` is the type of the column in the Document Index, e.g. `np.ndarray` or `str`. You can customize the mapping from `docarray_type` to `db_type`, as we will see later. - `config` is a dictionary of configurations for the column. For example, for the `other_tensor` column above, this would contain the `space` and `dim` configurations. - `n_dim` is the dimensionality of the column, e.g. `100` for a 100-dimensional vector. See further guidance on this below. @@ -153,6 +153,9 @@ By default, it holds that `_ColumnInfo.docarray_type == self.python_type_to_db_t However, you should not rely on this, because a user can manually specify a different db_type. Therefore, your implementation should rely on `_ColumnInfo.db_type` and not directly call `python_type_to_db_type()`. +**Caution** +If a subclass of `AbstractTensor` appears in the Document Index's schema (i.e. 
`TorchTensor`, `NdArray`, or `TensorFlowTensor`), then `_ColumnInfo.docarray_type` will simply show `AbstractTensor` instead of the specific subclass. This is because the abstract class normalizes all input data of type `AbstractTensor` to `np.ndarray` anyways, which should make your life easier. Just be sure to properly handle `AbstractTensor` as a possible value or `_ColumnInfo.docarray_type`, and you won't have to worry about the differences between torch, tf, and np. + ### Properly handle `n_dim` `_ColumnInfo.n_dim` is automatically obtained from type parametrizations of the form `NdArray[100]`; @@ -296,7 +299,7 @@ The details of each method should become clear from the docstrings and type hint This method is slightly special, because 1) it is not exposed to the user, and 2) you absolutely have to implement it. -It is intended to do the following: It takes a type of a field in the store's schema (e.g. `NdArray` for `tensor`), and returns the corresponding type in the database (e.g. `np.ndarray`). +It is intended to do the following: It takes a type of a field in the store's schema (e.g. `AbstractTensor` for `tensor`), and returns the corresponding type in the database (e.g. `np.ndarray`). The `BaseDocIndex` class uses this information to create and populate the `_ColumnInfo`s in `self._column_infos`. 
If the user wants to change the default behaviour, one can set the db type by using the `col_type` field: diff --git a/tests/index/base_classes/test_base_doc_store.py b/tests/index/base_classes/test_base_doc_store.py index cd1d8339ab2..b5774020524 100644 --- a/tests/index/base_classes/test_base_doc_store.py +++ b/tests/index/base_classes/test_base_doc_store.py @@ -6,8 +6,11 @@ from pydantic import Field from docarray import BaseDoc, DocArray +from docarray.documents import ImageDoc from docarray.index.abstract import BaseDocIndex, _raise_not_composable -from docarray.typing import ID, NdArray +from docarray.typing import ID, ImageBytes, ImageUrl, NdArray +from docarray.typing.tensor.abstract_tensor import AbstractTensor +from docarray.utils._internal.misc import torch_imported pytestmark = pytest.mark.index @@ -45,6 +48,7 @@ class RuntimeConfig(BaseDocIndex.RuntimeConfig): str: {'hi': 'there'}, np.ndarray: {'you': 'good?'}, 'varchar': {'good': 'bye'}, + AbstractTensor: {'dim': 1000}, } ) @@ -103,7 +107,7 @@ def test_create_columns(): assert store._column_infos['id'].n_dim is None assert store._column_infos['id'].config == {'hi': 'there'} - assert issubclass(store._column_infos['tens'].docarray_type, NdArray) + assert issubclass(store._column_infos['tens'].docarray_type, AbstractTensor) assert store._column_infos['tens'].db_type == str assert store._column_infos['tens'].n_dim == 10 assert store._column_infos['tens'].config == {'dim': 1000, 'hi': 'there'} @@ -117,12 +121,12 @@ def test_create_columns(): assert store._column_infos['id'].n_dim is None assert store._column_infos['id'].config == {'hi': 'there'} - assert issubclass(store._column_infos['tens_one'].docarray_type, NdArray) + assert issubclass(store._column_infos['tens_one'].docarray_type, AbstractTensor) assert store._column_infos['tens_one'].db_type == str assert store._column_infos['tens_one'].n_dim is None assert store._column_infos['tens_one'].config == {'dim': 10, 'hi': 'there'} - assert 
issubclass(store._column_infos['tens_two'].docarray_type, NdArray) + assert issubclass(store._column_infos['tens_two'].docarray_type, AbstractTensor) assert store._column_infos['tens_two'].db_type == str assert store._column_infos['tens_two'].n_dim is None assert store._column_infos['tens_two'].config == {'dim': 50, 'hi': 'there'} @@ -136,7 +140,7 @@ def test_create_columns(): assert store._column_infos['id'].n_dim is None assert store._column_infos['id'].config == {'hi': 'there'} - assert issubclass(store._column_infos['d__tens'].docarray_type, NdArray) + assert issubclass(store._column_infos['d__tens'].docarray_type, AbstractTensor) assert store._column_infos['d__tens'].db_type == str assert store._column_infos['d__tens'].n_dim == 10 assert store._column_infos['d__tens'].config == {'dim': 1000, 'hi': 'there'} @@ -147,15 +151,15 @@ def test_flatten_schema(): fields = SimpleDoc.__fields__ assert set(store._flatten_schema(SimpleDoc)) == { ('id', ID, fields['id']), - ('tens', NdArray[10], fields['tens']), + ('tens', AbstractTensor, fields['tens']), } store = DummyDocIndex[FlatDoc]() fields = FlatDoc.__fields__ assert set(store._flatten_schema(FlatDoc)) == { ('id', ID, fields['id']), - ('tens_one', NdArray, fields['tens_one']), - ('tens_two', NdArray, fields['tens_two']), + ('tens_one', AbstractTensor, fields['tens_one']), + ('tens_two', AbstractTensor, fields['tens_two']), } store = DummyDocIndex[NestedDoc]() @@ -164,7 +168,7 @@ def test_flatten_schema(): assert set(store._flatten_schema(NestedDoc)) == { ('id', ID, fields['id']), ('d__id', ID, fields_nested['id']), - ('d__tens', NdArray[10], fields_nested['tens']), + ('d__tens', AbstractTensor, fields_nested['tens']), } store = DummyDocIndex[DeepNestedDoc]() @@ -175,7 +179,44 @@ def test_flatten_schema(): ('id', ID, fields['id']), ('d__id', ID, fields_nested['id']), ('d__d__id', ID, fields_nested_nested['id']), - ('d__d__tens', NdArray[10], fields_nested_nested['tens']), + ('d__d__tens', AbstractTensor, 
fields_nested_nested['tens']), + } + + +def test_flatten_schema_union(): + class MyDoc(BaseDoc): + image: ImageDoc + + store = DummyDocIndex[MyDoc]() + fields = MyDoc.__fields__ + fields_image = ImageDoc.__fields__ + + if torch_imported: + from docarray.typing.tensor.image.image_torch_tensor import ImageTorchTensor + + assert set(store._flatten_schema(MyDoc)) == { + ('id', ID, fields['id']), + ('image__id', ID, fields_image['id']), + ('image__url', ImageUrl, fields_image['url']), + ('image__tensor', AbstractTensor, fields_image['tensor']), + ('image__embedding', AbstractTensor, fields_image['embedding']), + ('image__bytes_', ImageBytes, fields_image['bytes_']), + } + + class MyDoc2(BaseDoc): + tensor: Union[NdArray, str] + + with pytest.raises(ValueError): + _ = DummyDocIndex[MyDoc2]() + + class MyDoc3(BaseDoc): + tensor: Union[NdArray, ImageTorchTensor] + + store = DummyDocIndex[MyDoc3]() + fields = MyDoc3.__fields__ + assert set(store._flatten_schema(MyDoc3)) == { + ('id', ID, fields['id']), + ('tensor', AbstractTensor, fields['tensor']), } @@ -303,9 +344,12 @@ def test_docs_validation_unions(): class OptionalDoc(BaseDoc): tens: Optional[NdArray[10]] = Field(dim=1000) - class UnionDoc(BaseDoc): + class MixedUnionDoc(BaseDoc): tens: Union[NdArray[10], str] = Field(dim=1000) + class TensorUnionDoc(BaseDoc): + tens: Union[NdArray[10], AbstractTensor] = Field(dim=1000) + # OPTIONAL store = DummyDocIndex[SimpleDoc]() in_list = [OptionalDoc(tens=np.random.random((10,)))] @@ -316,15 +360,28 @@ class UnionDoc(BaseDoc): with pytest.raises(ValueError): store._validate_docs([OptionalDoc(tens=None)]) - # OTHER UNION + # MIXED UNION store = DummyDocIndex[SimpleDoc]() - in_list = [UnionDoc(tens=np.random.random((10,)))] + in_list = [MixedUnionDoc(tens=np.random.random((10,)))] assert isinstance(store._validate_docs(in_list), DocArray[BaseDoc]) - in_da = DocArray[UnionDoc](in_list) + in_da = DocArray[MixedUnionDoc](in_list) assert isinstance(store._validate_docs(in_da), 
DocArray[BaseDoc]) with pytest.raises(ValueError): - store._validate_docs([UnionDoc(tens='hello')]) + store._validate_docs([MixedUnionDoc(tens='hello')]) + + # TENSOR UNION + store = DummyDocIndex[TensorUnionDoc]() + in_list = [SimpleDoc(tens=np.random.random((10,)))] + assert isinstance(store._validate_docs(in_list), DocArray[BaseDoc]) + in_da = DocArray[SimpleDoc](in_list) + assert store._validate_docs(in_da) == in_da + + store = DummyDocIndex[SimpleDoc]() + in_list = [TensorUnionDoc(tens=np.random.random((10,)))] + assert isinstance(store._validate_docs(in_list), DocArray[BaseDoc]) + in_da = DocArray[TensorUnionDoc](in_list) + assert store._validate_docs(in_da) == in_da def test_get_value(): @@ -474,3 +531,30 @@ def test_convert_dict_to_doc(): assert doc.d.id == doc_dict_copy['d__id'] assert doc.d.d.id == doc_dict_copy['d__d__id'] assert np.all(doc.d.d.tens == doc_dict_copy['d__d__tens']) + + class MyDoc(BaseDoc): + image: ImageDoc + + store = DummyDocIndex[MyDoc]() + doc_dict = { + 'id': 'root', + 'image__id': 'nested', + 'image__tensor': np.random.random((128,)), + } + doc = store._convert_dict_to_doc(doc_dict, store._schema) + + if torch_imported: + from docarray.typing.tensor.image.image_torch_tensor import ImageTorchTensor + + class MyDoc2(BaseDoc): + tens: Union[NdArray, ImageTorchTensor] + + store = DummyDocIndex[MyDoc2]() + doc_dict = { + 'id': 'root', + 'tens': np.random.random((128,)), + } + doc_dict_copy = doc_dict.copy() + doc = store._convert_dict_to_doc(doc_dict, store._schema) + assert doc.id == doc_dict_copy['id'] + assert np.all(doc.tens == doc_dict_copy['tens'])