From 0d3ac91855d6254ab54ce692aeeca7c02a919090 Mon Sep 17 00:00:00 2001 From: AnneY Date: Mon, 27 Mar 2023 22:16:31 +0800 Subject: [PATCH 1/8] fix: flatten schema of abstract index Signed-off-by: AnneY --- docarray/index/abstract.py | 25 ++++--- .../index/base_classes/test_base_doc_store.py | 72 +++++++++++++++++-- 2 files changed, 84 insertions(+), 13 deletions(-) diff --git a/docarray/index/abstract.py b/docarray/index/abstract.py index 6aec9549726..8fb6102626a 100644 --- a/docarray/index/abstract.py +++ b/docarray/index/abstract.py @@ -26,7 +26,7 @@ from docarray import BaseDocument, DocumentArray from docarray.array.abstract_array import AnyDocumentArray from docarray.typing import AnyTensor -from docarray.utils._typing import unwrap_optional_type +from docarray.utils._typing import is_tensor_union, unwrap_optional_type from docarray.utils.find import FindResult, _FindResult from docarray.utils.misc import is_tf_available, torch_imported @@ -682,14 +682,23 @@ def _flatten_schema( # simple "Optional" type, treat as special case: # treat as if it was a single non-optional type for t_arg in union_args: - if t_arg is type(None): - pass - elif issubclass(t_arg, BaseDocument): - names_types_fields.extend( - cls._flatten_schema(t_arg, name_prefix=inner_prefix) - ) + if t_arg is not type(None): + if issubclass(t_arg, BaseDocument): + names_types_fields.extend( + cls._flatten_schema(t_arg, name_prefix=inner_prefix) + ) + else: + names_types_fields.append( + (name_prefix + field_name, t_arg, field_) + ) + + elif is_tensor_union(t_): + names_types_fields.append((name_prefix + field_name, t_, field_)) + else: - names_types_fields.append((field_name, t_, field_)) + raise Exception( + f'Union type {t_} is not supported. Only Union of subclasses of ndarray or Union[type, None] are supported.' 
+ ) elif issubclass(t_, BaseDocument): names_types_fields.extend( cls._flatten_schema(t_, name_prefix=inner_prefix) diff --git a/tests/index/base_classes/test_base_doc_store.py b/tests/index/base_classes/test_base_doc_store.py index f008dff9d6d..b437dd57389 100644 --- a/tests/index/base_classes/test_base_doc_store.py +++ b/tests/index/base_classes/test_base_doc_store.py @@ -6,11 +6,12 @@ from pydantic import Field from docarray import BaseDocument, DocumentArray -from docarray.index.abstract import ( - BaseDocumentIndex, - _raise_not_composable, -) -from docarray.typing import ID, NdArray +from docarray.documents import ImageDoc +from docarray.index.abstract import BaseDocumentIndex, _raise_not_composable +from docarray.typing import ID, ImageBytes, ImageUrl, NdArray +from docarray.typing.tensor.embedding.ndarray import NdArrayEmbedding +from docarray.typing.tensor.image.image_ndarray import ImageNdArray +from docarray.utils.misc import is_tf_available, is_torch_available pytestmark = pytest.mark.index @@ -182,6 +183,67 @@ def test_flatten_schema(): } +def test_flatten_schema_union(): + class MyDoc(BaseDocument): + image: ImageDoc + + store = DummyDocIndex[MyDoc]() + fields = MyDoc.__fields__ + fields_image = ImageDoc.__fields__ + + torch_available = is_torch_available() + if torch_available: + from docarray.typing.tensor.embedding.torch import TorchEmbedding + from docarray.typing.tensor.image.image_torch_tensor import ImageTorchTensor + + tf_available = is_tf_available() + if tf_available: + from docarray.typing.tensor.embedding.tensorflow import ( + TensorFlowEmbedding as TFEmbedding, + ) + from docarray.typing.tensor.image.image_tensorflow_tensor import ( + ImageTensorFlowTensor as ImageTFTensor, + ) + + tensor_type = Union[ImageNdArray, None] + embedding_type = Union[NdArrayEmbedding, None] + if tf_available and torch_available: + tensor_type = Union[ImageNdArray, ImageTorchTensor, ImageTFTensor, None] # type: ignore + embedding_type = Union[NdArrayEmbedding, 
TorchEmbedding, TFEmbedding, None] + elif tf_available: + tensor_type = Union[ImageNdArray, ImageTFTensor, None] + embedding_type = Union[NdArrayEmbedding, TFEmbedding, None] + elif torch_available: + tensor_type = Union[ImageNdArray, ImageTorchTensor, None] + embedding_type = Union[NdArrayEmbedding, TorchEmbedding, None] + + assert set(store._flatten_schema(MyDoc)) == { + ('id', ID, fields['id']), + ('image__id', ID, fields_image['id']), + ('image__url', ImageUrl, fields_image['url']), + ('image__tensor', tensor_type, fields_image['tensor']), + ('image__embedding', embedding_type, fields_image['embedding']), + ('image__bytes_', ImageBytes, fields_image['bytes_']), + } + + class MyDoc2(BaseDocument): + tensor: Union[NdArray, str] + + with pytest.raises(Exception): + _ = DummyDocIndex[MyDoc2]() + + # class MyDoc3(BaseDocument): + # tensor: Union[NdArray, ImageTorchTensor] + + # store = DummyDocIndex[MyDoc3]() + # fields = MyDoc3.__fields__ + # fields_image = ImageDoc.__fields__ + # assert set(store._flatten_schema(MyDoc3)) == { + # ('id', ID, fields['id']), + # ('tensor', Union[NdArray, ImageNdArray], fields_image['tensor']), + # } + + def test_columns_db_type_with_user_defined_mapping(tmp_path): class MyDoc(BaseDocument): tens: NdArray[10] = Field(dim=1000, col_type=np.ndarray) From cd38839dd4858f645fa433229db2bfefc883a372 Mon Sep 17 00:00:00 2001 From: AnneY Date: Tue, 28 Mar 2023 10:46:26 +0800 Subject: [PATCH 2/8] fix: _convert_dict_to_doc Signed-off-by: AnneY --- docarray/index/abstract.py | 32 ++++---- .../index/base_classes/test_base_doc_store.py | 77 ++++++++++--------- 2 files changed, 59 insertions(+), 50 deletions(-) diff --git a/docarray/index/abstract.py b/docarray/index/abstract.py index 8fb6102626a..cbaaaa548cd 100644 --- a/docarray/index/abstract.py +++ b/docarray/index/abstract.py @@ -21,11 +21,12 @@ import numpy as np from pydantic.error_wrappers import ValidationError -from typing_inspect import get_args, is_optional_type, is_union_type +from 
typing_inspect import get_args, is_union_type from docarray import BaseDocument, DocumentArray from docarray.array.abstract_array import AnyDocumentArray from docarray.typing import AnyTensor +from docarray.typing.tensor.abstract_tensor import AbstractTensor from docarray.utils._typing import is_tensor_union, unwrap_optional_type from docarray.utils.find import FindResult, _FindResult from docarray.utils.misc import is_tf_available, torch_imported @@ -693,7 +694,9 @@ def _flatten_schema( ) elif is_tensor_union(t_): - names_types_fields.append((name_prefix + field_name, t_, field_)) + names_types_fields.append( + (name_prefix + field_name, AbstractTensor, field_) + ) else: raise Exception( @@ -718,16 +721,11 @@ def _create_column_infos( """ column_infos: Dict[str, _ColumnInfo] = dict() for field_name, type_, field_ in self._flatten_schema(schema): - if is_optional_type(type_): - column_infos[field_name] = self._create_single_column( - field_, unwrap_optional_type(type_) - ) - elif is_union_type(type_): - raise ValueError( - 'Union types are not supported in the schema of a DocumentIndex.' - f' Instead of using type {type_} use a single specific type.' - ) - elif issubclass(type_, AnyDocumentArray): + # if is_optional_type(type_): # TODO + # column_infos[field_name] = self._create_single_column( + # field_, unwrap_optional_type(type_) + # ) + if issubclass(type_, AnyDocumentArray): raise ValueError( 'Indexing field of DocumentArray type (=subindex)' 'is not yet supported.' @@ -782,6 +780,7 @@ def _validate_docs( """ if isinstance(docs, BaseDocument): docs = [docs] + # TODO List of Docs if isinstance(docs, DocumentArray): # validation shortcut for DocumentArray; only look at the schema reference_schema_flat = self._flatten_schema( @@ -796,6 +795,7 @@ def _validate_docs( # this could be relaxed in the future, # see schema translation ideas in the design doc names_compatible = reference_names == input_names + # TODO change here? 
types_compatible = all( (not is_union_type(t2) and issubclass(t1, t2)) for (t1, t2) in zip(reference_types, input_types) @@ -851,9 +851,13 @@ def _convert_dict_to_doc( :param schema: The schema of the Document object :return: A Document object """ - for field_name, _ in schema.__fields__.items(): - t_ = unwrap_optional_type(schema._get_field_type(field_name)) + t_ = schema._get_field_type(field_name) + if is_tensor_union(t_): + t_ = AbstractTensor + else: + t_ = unwrap_optional_type(t_) + if issubclass(t_, BaseDocument): inner_dict = {} diff --git a/tests/index/base_classes/test_base_doc_store.py b/tests/index/base_classes/test_base_doc_store.py index b437dd57389..c1e136ab3ea 100644 --- a/tests/index/base_classes/test_base_doc_store.py +++ b/tests/index/base_classes/test_base_doc_store.py @@ -9,9 +9,8 @@ from docarray.documents import ImageDoc from docarray.index.abstract import BaseDocumentIndex, _raise_not_composable from docarray.typing import ID, ImageBytes, ImageUrl, NdArray -from docarray.typing.tensor.embedding.ndarray import NdArrayEmbedding -from docarray.typing.tensor.image.image_ndarray import ImageNdArray -from docarray.utils.misc import is_tf_available, is_torch_available +from docarray.typing.tensor.abstract_tensor import AbstractTensor +from docarray.utils.misc import is_torch_available pytestmark = pytest.mark.index @@ -193,36 +192,14 @@ class MyDoc(BaseDocument): torch_available = is_torch_available() if torch_available: - from docarray.typing.tensor.embedding.torch import TorchEmbedding from docarray.typing.tensor.image.image_torch_tensor import ImageTorchTensor - tf_available = is_tf_available() - if tf_available: - from docarray.typing.tensor.embedding.tensorflow import ( - TensorFlowEmbedding as TFEmbedding, - ) - from docarray.typing.tensor.image.image_tensorflow_tensor import ( - ImageTensorFlowTensor as ImageTFTensor, - ) - - tensor_type = Union[ImageNdArray, None] - embedding_type = Union[NdArrayEmbedding, None] - if tf_available and 
torch_available: - tensor_type = Union[ImageNdArray, ImageTorchTensor, ImageTFTensor, None] # type: ignore - embedding_type = Union[NdArrayEmbedding, TorchEmbedding, TFEmbedding, None] - elif tf_available: - tensor_type = Union[ImageNdArray, ImageTFTensor, None] - embedding_type = Union[NdArrayEmbedding, TFEmbedding, None] - elif torch_available: - tensor_type = Union[ImageNdArray, ImageTorchTensor, None] - embedding_type = Union[NdArrayEmbedding, TorchEmbedding, None] - assert set(store._flatten_schema(MyDoc)) == { ('id', ID, fields['id']), ('image__id', ID, fields_image['id']), ('image__url', ImageUrl, fields_image['url']), - ('image__tensor', tensor_type, fields_image['tensor']), - ('image__embedding', embedding_type, fields_image['embedding']), + ('image__tensor', AbstractTensor, fields_image['tensor']), + ('image__embedding', AbstractTensor, fields_image['embedding']), ('image__bytes_', ImageBytes, fields_image['bytes_']), } @@ -232,16 +209,15 @@ class MyDoc2(BaseDocument): with pytest.raises(Exception): _ = DummyDocIndex[MyDoc2]() - # class MyDoc3(BaseDocument): - # tensor: Union[NdArray, ImageTorchTensor] + class MyDoc3(BaseDocument): + tensor: Union[NdArray, ImageTorchTensor] - # store = DummyDocIndex[MyDoc3]() - # fields = MyDoc3.__fields__ - # fields_image = ImageDoc.__fields__ - # assert set(store._flatten_schema(MyDoc3)) == { - # ('id', ID, fields['id']), - # ('tensor', Union[NdArray, ImageNdArray], fields_image['tensor']), - # } + store = DummyDocIndex[MyDoc3]() + fields = MyDoc3.__fields__ + assert set(store._flatten_schema(MyDoc3)) == { + ('id', ID, fields['id']), + ('tensor', AbstractTensor, fields['tensor']), + } def test_columns_db_type_with_user_defined_mapping(tmp_path): @@ -366,6 +342,7 @@ class OtherNestedDoc(NestedDoc): ) +# TODO change here def test_docs_validation_unions(): class OptionalDoc(BaseDocument): tens: Optional[NdArray[10]] = Field(dim=1000) @@ -541,3 +518,31 @@ def test_convert_dict_to_doc(): assert doc.d.id == 
doc_dict_copy['d__id'] assert doc.d.d.id == doc_dict_copy['d__d__id'] assert np.all(doc.d.d.tens == doc_dict_copy['d__d__tens']) + + class MyDoc(BaseDocument): + image: ImageDoc + + store = DummyDocIndex[MyDoc]() + doc_dict = { + 'id': 'root', + 'image__id': 'nested', + 'image__tensor': np.random.random((128,)), + } + doc = store._convert_dict_to_doc(doc_dict, store._schema) + + torch_available = is_torch_available() + if torch_available: + from docarray.typing.tensor.image.image_torch_tensor import ImageTorchTensor + + class MyDoc2(BaseDocument): + tens: Union[NdArray, ImageTorchTensor] + + store = DummyDocIndex[MyDoc2]() + doc_dict = { + 'id': 'root', + 'tens': np.random.random((128,)), + } + doc_dict_copy = doc_dict.copy() + doc = store._convert_dict_to_doc(doc_dict, store._schema) + assert doc.id == doc_dict_copy['id'] + assert np.all(doc.tens == doc_dict_copy['tens']) From 1906604051dc70f0379738f78dda861eed25aa4f Mon Sep 17 00:00:00 2001 From: AnneY Date: Tue, 28 Mar 2023 11:22:27 +0800 Subject: [PATCH 3/8] fix: catch exception when flatten schema Signed-off-by: AnneY --- docarray/index/abstract.py | 30 ++++++++++--------- .../index/base_classes/test_base_doc_store.py | 26 ++++++++++++---- 2 files changed, 37 insertions(+), 19 deletions(-) diff --git a/docarray/index/abstract.py b/docarray/index/abstract.py index cbaaaa548cd..42eb4a31387 100644 --- a/docarray/index/abstract.py +++ b/docarray/index/abstract.py @@ -780,7 +780,6 @@ def _validate_docs( """ if isinstance(docs, BaseDocument): docs = [docs] - # TODO List of Docs if isinstance(docs, DocumentArray): # validation shortcut for DocumentArray; only look at the schema reference_schema_flat = self._flatten_schema( @@ -788,20 +787,23 @@ def _validate_docs( ) reference_names = [name for (name, _, _) in reference_schema_flat] reference_types = [t_ for (_, t_, _) in reference_schema_flat] + try: + input_schema_flat = self._flatten_schema(docs.document_type) + input_names = [name for (name, _, _) in 
input_schema_flat] + input_types = [t_ for (_, t_, _) in input_schema_flat] + # this could be relaxed in the future, + # see schema translation ideas in the design doc + names_compatible = reference_names == input_names + # TODO change here? + types_compatible = all( + (issubclass(t1, t2)) + for (t1, t2) in zip(reference_types, input_types) + ) + if names_compatible and types_compatible: + return docs + except Exception: + pass - input_schema_flat = self._flatten_schema(docs.document_type) - input_names = [name for (name, _, _) in input_schema_flat] - input_types = [t_ for (_, t_, _) in input_schema_flat] - # this could be relaxed in the future, - # see schema translation ideas in the design doc - names_compatible = reference_names == input_names - # TODO change here? - types_compatible = all( - (not is_union_type(t2) and issubclass(t1, t2)) - for (t1, t2) in zip(reference_types, input_types) - ) - if names_compatible and types_compatible: - return docs out_docs = [] for i in range(len(docs)): # validate the data diff --git a/tests/index/base_classes/test_base_doc_store.py b/tests/index/base_classes/test_base_doc_store.py index c1e136ab3ea..c9c7e9445cf 100644 --- a/tests/index/base_classes/test_base_doc_store.py +++ b/tests/index/base_classes/test_base_doc_store.py @@ -347,9 +347,12 @@ def test_docs_validation_unions(): class OptionalDoc(BaseDocument): tens: Optional[NdArray[10]] = Field(dim=1000) - class UnionDoc(BaseDocument): + class MixedUnionDoc(BaseDocument): tens: Union[NdArray[10], str] = Field(dim=1000) + class TensorUnionDoc(BaseDocument): + tens: Union[NdArray[10], AbstractTensor] = Field(dim=1000) + # OPTIONAL store = DummyDocIndex[SimpleDoc]() in_list = [OptionalDoc(tens=np.random.random((10,)))] @@ -360,15 +363,28 @@ class UnionDoc(BaseDocument): with pytest.raises(ValueError): store._validate_docs([OptionalDoc(tens=None)]) - # OTHER UNION + # MIXED UNION store = DummyDocIndex[SimpleDoc]() - in_list = [UnionDoc(tens=np.random.random((10,)))] + in_list 
= [MixedUnionDoc(tens=np.random.random((10,)))] assert isinstance(store._validate_docs(in_list), DocumentArray[BaseDocument]) - in_da = DocumentArray[UnionDoc](in_list) + in_da = DocumentArray[MixedUnionDoc](in_list) assert isinstance(store._validate_docs(in_da), DocumentArray[BaseDocument]) with pytest.raises(ValueError): - store._validate_docs([UnionDoc(tens='hello')]) + store._validate_docs([MixedUnionDoc(tens='hello')]) + + # TENSOR UNION + store = DummyDocIndex[TensorUnionDoc]() + in_list = [SimpleDoc(tens=np.random.random((10,)))] + assert isinstance(store._validate_docs(in_list), DocumentArray[BaseDocument]) + in_da = DocumentArray[SimpleDoc](in_list) + assert isinstance(store._validate_docs(in_da), DocumentArray[BaseDocument]) + + store = DummyDocIndex[SimpleDoc]() + in_list = [TensorUnionDoc(tens=np.random.random((10,)))] + assert isinstance(store._validate_docs(in_list), DocumentArray[BaseDocument]) + in_da = DocumentArray[TensorUnionDoc](in_list) + assert store._validate_docs(in_da) == in_da def test_get_value(): From a10a66fe00e6f523d9a42d50b31510f50df7657a Mon Sep 17 00:00:00 2001 From: AnneY Date: Tue, 28 Mar 2023 11:39:39 +0800 Subject: [PATCH 4/8] refactor: remove useless assignment Signed-off-by: AnneY --- docarray/index/abstract.py | 8 ++------ tests/index/base_classes/test_base_doc_store.py | 1 - 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/docarray/index/abstract.py b/docarray/index/abstract.py index 42eb4a31387..140b9d565ff 100644 --- a/docarray/index/abstract.py +++ b/docarray/index/abstract.py @@ -854,13 +854,9 @@ def _convert_dict_to_doc( :return: A Document object """ for field_name, _ in schema.__fields__.items(): - t_ = schema._get_field_type(field_name) - if is_tensor_union(t_): - t_ = AbstractTensor - else: - t_ = unwrap_optional_type(t_) + t_ = unwrap_optional_type(schema._get_field_type(field_name)) - if issubclass(t_, BaseDocument): + if not is_union_type(t_) and issubclass(t_, BaseDocument): inner_dict = {} fields = 
[ diff --git a/tests/index/base_classes/test_base_doc_store.py b/tests/index/base_classes/test_base_doc_store.py index c9c7e9445cf..bb7c4499694 100644 --- a/tests/index/base_classes/test_base_doc_store.py +++ b/tests/index/base_classes/test_base_doc_store.py @@ -342,7 +342,6 @@ class OtherNestedDoc(NestedDoc): ) -# TODO change here def test_docs_validation_unions(): class OptionalDoc(BaseDocument): tens: Optional[NdArray[10]] = Field(dim=1000) From c0d6b96f8cdc15b0bc77c7b395d14965c8946f23 Mon Sep 17 00:00:00 2001 From: AnneY Date: Tue, 28 Mar 2023 22:52:56 +0800 Subject: [PATCH 5/8] fix: use Abstractensor as tensor doc_type Signed-off-by: AnneY --- docarray/index/abstract.py | 54 ++++++++++--------- docarray/utils/_internal/_typing.py | 16 +----- .../index/base_classes/test_base_doc_store.py | 23 ++++---- 3 files changed, 42 insertions(+), 51 deletions(-) diff --git a/docarray/index/abstract.py b/docarray/index/abstract.py index 46c81b299d2..d3f3fb73378 100644 --- a/docarray/index/abstract.py +++ b/docarray/index/abstract.py @@ -21,13 +21,13 @@ import numpy as np from pydantic.error_wrappers import ValidationError -from typing_inspect import get_args, is_union_type +from typing_inspect import get_args, is_optional_type, is_union_type from docarray import BaseDoc, DocArray from docarray.array.abstract_array import AnyDocArray from docarray.typing import AnyTensor from docarray.typing.tensor.abstract_tensor import AbstractTensor -from docarray.utils._internal._typing import is_tensor_union, unwrap_optional_type +from docarray.utils._internal._typing import is_tensor_union from docarray.utils._internal.misc import is_tf_available, torch_imported from docarray.utils.find import FindResult, _FindResult @@ -677,7 +677,13 @@ def _flatten_schema( if is_union_type(t_): union_args = get_args(t_) - if len(union_args) == 2 and type(None) in union_args: + + if is_tensor_union(t_): + names_types_fields.append( + (name_prefix + field_name, AbstractTensor, field_) + ) + + elif 
len(union_args) == 2 and type(None) in union_args: # simple "Optional" type, treat as special case: # treat as if it was a single non-optional type for t_arg in union_args: @@ -690,20 +696,18 @@ def _flatten_schema( names_types_fields.append( (name_prefix + field_name, t_arg, field_) ) - - elif is_tensor_union(t_): - names_types_fields.append( - (name_prefix + field_name, AbstractTensor, field_) - ) - else: - raise Exception( + raise ValueError( f'Union type {t_} is not supported. Only Union of subclasses of ndarray or Union[type, None] are supported.' ) elif issubclass(t_, BaseDoc): names_types_fields.extend( cls._flatten_schema(t_, name_prefix=inner_prefix) ) + elif issubclass(t_, AbstractTensor): + names_types_fields.append( + (name_prefix + field_name, AbstractTensor, field_) + ) else: names_types_fields.append((name_prefix + field_name, t_, field_)) return names_types_fields @@ -717,10 +721,7 @@ def _create_column_infos(self, schema: Type[BaseDoc]) -> Dict[str, _ColumnInfo]: """ column_infos: Dict[str, _ColumnInfo] = dict() for field_name, type_, field_ in self._flatten_schema(schema): - # if is_optional_type(type_): # TODO - # column_infos[field_name] = self._create_single_column( - # field_, unwrap_optional_type(type_) - # ) + # Union types are handle in _flatten_schema if issubclass(type_, AnyDocArray): raise ValueError( 'Indexing field of DocArray type (=subindex)' @@ -732,7 +733,6 @@ def _create_column_infos(self, schema: Type[BaseDoc]) -> Dict[str, _ColumnInfo]: def _create_single_column(self, field: 'ModelField', type_: Type) -> _ColumnInfo: custom_config = field.field_info.extra - if 'col_type' in custom_config.keys(): db_type = custom_config['col_type'] custom_config.pop('col_type') @@ -747,13 +747,13 @@ def _create_single_column(self, field: 'ModelField', type_: Type) -> _ColumnInfo config.update(custom_config) # parse n_dim from parametrized tensor type if ( - hasattr(type_, '__docarray_target_shape__') - and type_.__docarray_target_shape__ + 
hasattr(field.type_, '__docarray_target_shape__') + and field.type_.__docarray_target_shape__ ): - if len(type_.__docarray_target_shape__) == 1: - n_dim = type_.__docarray_target_shape__[0] + if len(field.type_.__docarray_target_shape__) == 1: + n_dim = field.type_.__docarray_target_shape__[0] else: - n_dim = type_.__docarray_target_shape__ + n_dim = field.type_.__docarray_target_shape__ else: n_dim = None return _ColumnInfo( @@ -785,20 +785,20 @@ def _validate_docs( reference_types = [t_ for (_, t_, _) in reference_schema_flat] try: input_schema_flat = self._flatten_schema(docs.document_type) + except ValueError: + pass + else: input_names = [name for (name, _, _) in input_schema_flat] input_types = [t_ for (_, t_, _) in input_schema_flat] # this could be relaxed in the future, # see schema translation ideas in the design doc names_compatible = reference_names == input_names - # TODO change here? types_compatible = all( - (issubclass(t1, t2)) + (issubclass(t2, t1)) for (t1, t2) in zip(reference_types, input_types) ) if names_compatible and types_compatible: return docs - except Exception: - pass out_docs = [] for i in range(len(docs)): @@ -848,7 +848,11 @@ def _convert_dict_to_doc( :return: A Document object """ for field_name, _ in schema.__fields__.items(): - t_ = unwrap_optional_type(schema._get_field_type(field_name)) + t_ = schema._get_field_type(field_name) + if is_optional_type(t_): + for t_arg in get_args(t_): + if t_arg is not type(None): + t_ = t_arg if not is_union_type(t_) and issubclass(t_, BaseDoc): inner_dict = {} diff --git a/docarray/utils/_internal/_typing.py b/docarray/utils/_internal/_typing.py index 9bbc0162432..62680cf964e 100644 --- a/docarray/utils/_internal/_typing.py +++ b/docarray/utils/_internal/_typing.py @@ -1,6 +1,6 @@ from typing import Any, Optional -from typing_inspect import get_args, is_optional_type, is_union_type +from typing_inspect import get_args, is_union_type from docarray.typing.tensor.abstract_tensor import 
AbstractTensor @@ -32,17 +32,3 @@ def change_cls_name(cls: type, new_name: str, scope: Optional[dict] = None) -> N scope[new_name] = cls cls.__qualname__ = cls.__qualname__[: -len(cls.__name__)] + new_name cls.__name__ = new_name - - -def unwrap_optional_type(type_: Any) -> Any: - """Return the type of an Optional type, e.g. `unwrap_optional(Optional[str]) == str`; - `unwrap_optional(Union[None, int, None]) == int`. - - :param type_: the type to unwrap - :return: the "core" type of an Optional type - """ - if not is_optional_type(type_): - return type_ - for arg in get_args(type_): - if arg is not type(None): - return arg diff --git a/tests/index/base_classes/test_base_doc_store.py b/tests/index/base_classes/test_base_doc_store.py index f794178004d..b5774020524 100644 --- a/tests/index/base_classes/test_base_doc_store.py +++ b/tests/index/base_classes/test_base_doc_store.py @@ -48,6 +48,7 @@ class RuntimeConfig(BaseDocIndex.RuntimeConfig): str: {'hi': 'there'}, np.ndarray: {'you': 'good?'}, 'varchar': {'good': 'bye'}, + AbstractTensor: {'dim': 1000}, } ) @@ -106,7 +107,7 @@ def test_create_columns(): assert store._column_infos['id'].n_dim is None assert store._column_infos['id'].config == {'hi': 'there'} - assert issubclass(store._column_infos['tens'].docarray_type, NdArray) + assert issubclass(store._column_infos['tens'].docarray_type, AbstractTensor) assert store._column_infos['tens'].db_type == str assert store._column_infos['tens'].n_dim == 10 assert store._column_infos['tens'].config == {'dim': 1000, 'hi': 'there'} @@ -120,12 +121,12 @@ def test_create_columns(): assert store._column_infos['id'].n_dim is None assert store._column_infos['id'].config == {'hi': 'there'} - assert issubclass(store._column_infos['tens_one'].docarray_type, NdArray) + assert issubclass(store._column_infos['tens_one'].docarray_type, AbstractTensor) assert store._column_infos['tens_one'].db_type == str assert store._column_infos['tens_one'].n_dim is None assert 
store._column_infos['tens_one'].config == {'dim': 10, 'hi': 'there'} - assert issubclass(store._column_infos['tens_two'].docarray_type, NdArray) + assert issubclass(store._column_infos['tens_two'].docarray_type, AbstractTensor) assert store._column_infos['tens_two'].db_type == str assert store._column_infos['tens_two'].n_dim is None assert store._column_infos['tens_two'].config == {'dim': 50, 'hi': 'there'} @@ -139,7 +140,7 @@ def test_create_columns(): assert store._column_infos['id'].n_dim is None assert store._column_infos['id'].config == {'hi': 'there'} - assert issubclass(store._column_infos['d__tens'].docarray_type, NdArray) + assert issubclass(store._column_infos['d__tens'].docarray_type, AbstractTensor) assert store._column_infos['d__tens'].db_type == str assert store._column_infos['d__tens'].n_dim == 10 assert store._column_infos['d__tens'].config == {'dim': 1000, 'hi': 'there'} @@ -150,15 +151,15 @@ def test_flatten_schema(): fields = SimpleDoc.__fields__ assert set(store._flatten_schema(SimpleDoc)) == { ('id', ID, fields['id']), - ('tens', NdArray[10], fields['tens']), + ('tens', AbstractTensor, fields['tens']), } store = DummyDocIndex[FlatDoc]() fields = FlatDoc.__fields__ assert set(store._flatten_schema(FlatDoc)) == { ('id', ID, fields['id']), - ('tens_one', NdArray, fields['tens_one']), - ('tens_two', NdArray, fields['tens_two']), + ('tens_one', AbstractTensor, fields['tens_one']), + ('tens_two', AbstractTensor, fields['tens_two']), } store = DummyDocIndex[NestedDoc]() @@ -167,7 +168,7 @@ def test_flatten_schema(): assert set(store._flatten_schema(NestedDoc)) == { ('id', ID, fields['id']), ('d__id', ID, fields_nested['id']), - ('d__tens', NdArray[10], fields_nested['tens']), + ('d__tens', AbstractTensor, fields_nested['tens']), } store = DummyDocIndex[DeepNestedDoc]() @@ -178,7 +179,7 @@ def test_flatten_schema(): ('id', ID, fields['id']), ('d__id', ID, fields_nested['id']), ('d__d__id', ID, fields_nested_nested['id']), - ('d__d__tens', NdArray[10], 
fields_nested_nested['tens']), + ('d__d__tens', AbstractTensor, fields_nested_nested['tens']), } @@ -205,7 +206,7 @@ class MyDoc(BaseDoc): class MyDoc2(BaseDoc): tensor: Union[NdArray, str] - with pytest.raises(Exception): + with pytest.raises(ValueError): _ = DummyDocIndex[MyDoc2]() class MyDoc3(BaseDoc): @@ -374,7 +375,7 @@ class TensorUnionDoc(BaseDoc): in_list = [SimpleDoc(tens=np.random.random((10,)))] assert isinstance(store._validate_docs(in_list), DocArray[BaseDoc]) in_da = DocArray[SimpleDoc](in_list) - assert isinstance(store._validate_docs(in_da), DocArray[BaseDoc]) + assert store._validate_docs(in_da) == in_da store = DummyDocIndex[SimpleDoc]() in_list = [TensorUnionDoc(tens=np.random.random((10,)))] From 8d1ce7e71ee644ee36bef3ef142c7ca761b332ee Mon Sep 17 00:00:00 2001 From: AnneY Date: Wed, 29 Mar 2023 12:05:58 +0800 Subject: [PATCH 6/8] fix: add AbstractTensor to hnswlib Signed-off-by: AnneY --- docarray/index/backends/hnswlib.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docarray/index/backends/hnswlib.py b/docarray/index/backends/hnswlib.py index ab8aa6e56c1..5dfbc987741 100644 --- a/docarray/index/backends/hnswlib.py +++ b/docarray/index/backends/hnswlib.py @@ -29,6 +29,7 @@ _raise_not_supported, ) from docarray.proto import DocumentProto +from docarray.typing.tensor.abstract_tensor import AbstractTensor from docarray.utils._internal.misc import is_np_int, is_tf_available, is_torch_available from docarray.utils.filter import filter_docs from docarray.utils.find import _FindResult @@ -36,7 +37,7 @@ TSchema = TypeVar('TSchema', bound=BaseDoc) T = TypeVar('T', bound='HnswDocumentIndex') -HNSWLIB_PY_VEC_TYPES = [list, tuple, np.ndarray] +HNSWLIB_PY_VEC_TYPES = [list, tuple, np.ndarray, AbstractTensor] if is_torch_available(): import torch From 1b90ee06031c71cb0c77c1352b7617fefb3e84a4 Mon Sep 17 00:00:00 2001 From: AnneY Date: Wed, 29 Mar 2023 14:36:55 +0800 Subject: [PATCH 7/8] docs: AbstractTensor as doc_type Signed-off-by: 
AnneY --- docarray/index/abstract.py | 2 +- docs/how_to/add_doc_index.md | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/docarray/index/abstract.py b/docarray/index/abstract.py index d3f3fb73378..3f0ab25a1b0 100644 --- a/docarray/index/abstract.py +++ b/docarray/index/abstract.py @@ -698,7 +698,7 @@ def _flatten_schema( ) else: raise ValueError( - f'Union type {t_} is not supported. Only Union of subclasses of ndarray or Union[type, None] are supported.' + f'Union type {t_} is not supported. Only Union of subclasses of AbstractTensor or Union[type, None] are supported.' ) elif issubclass(t_, BaseDoc): names_types_fields.extend( diff --git a/docs/how_to/add_doc_index.md b/docs/how_to/add_doc_index.md index 8fb03b9978b..b4a477a9db3 100644 --- a/docs/how_to/add_doc_index.md +++ b/docs/how_to/add_doc_index.md @@ -139,7 +139,7 @@ class _ColumnInfo: config: Dict[str, Any] ``` -- `docarray_type` is the type of the column in DocArray, e.g. `NdArray` or `str` +- `docarray_type` is the type of the column in DocArray, e.g. `AbstractTensor` or `str` - `db_type` is the type of the column in the Document Index, e.g. `np.ndarray` or `str`. You can customize the mapping from `docarray_type` to `db_type`, as we will see later. - `config` is a dictionary of configurations for the column. For example, for the `other_tensor` column above, this would contain the `space` and `dim` configurations. - `n_dim` is the dimensionality of the column, e.g. `100` for a 100-dimensional vector. See further guidance on this below. @@ -153,6 +153,9 @@ By default, it holds that `_ColumnInfo.docarray_type == self.python_type_to_db_t However, you should not rely on this, because a user can manually specify a different db_type. Therefore, your implementation should rely on `_ColumnInfo.db_type` and not directly call `python_type_to_db_type()`. 
+**Caution** +`AbstractTensor` will be the `_ColumnInfo.docarray_type` if the field in `self._schema` is a subclass of `AbstractTensor` or a tensor Union. + ### Properly handle `n_dim` `_ColumnInfo.n_dim` is automatically obtained from type parametrizations of the form `NdArray[100]`; @@ -296,7 +299,7 @@ The details of each method should become clear from the docstrings and type hint This method is slightly special, because 1) it is not exposed to the user, and 2) you absolutely have to implement it. -It is intended to do the following: It takes a type of a field in the store's schema (e.g. `NdArray` for `tensor`), and returns the corresponding type in the database (e.g. `np.ndarray`). +It is intended to do the following: It takes a type of a field in the store's schema (e.g. `AbstractTensor` for `tensor`), and returns the corresponding type in the database (e.g. `np.ndarray`). The `BaseDocIndex` class uses this information to create and populate the `_ColumnInfo`s in `self._column_infos`. If the user wants to change the default behaviour, one can set the db type by using the `col_type` field: From 97900769f8619ed4cc281a760562fb361eb060ac Mon Sep 17 00:00:00 2001 From: Anne Yang Date: Wed, 29 Mar 2023 18:43:45 +0800 Subject: [PATCH 8/8] docs: complete description about AbstractTensor Co-authored-by: Johannes Messner <44071807+JohannesMessner@users.noreply.github.com> Signed-off-by: Anne Yang --- docs/how_to/add_doc_index.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/how_to/add_doc_index.md b/docs/how_to/add_doc_index.md index b4a477a9db3..4f0125428e1 100644 --- a/docs/how_to/add_doc_index.md +++ b/docs/how_to/add_doc_index.md @@ -154,7 +154,7 @@ However, you should not rely on this, because a user can manually specify a diff Therefore, your implementation should rely on `_ColumnInfo.db_type` and not directly call `python_type_to_db_type()`. 
**Caution** -`AbstractTensor` will be the `_ColumnInfo.docarray_type` if the field in `self._schema` is a subclass of `AbstractTensor` or a tensor Union. +If a subclass of `AbstractTensor` appears in the Document Index's schema (i.e. `TorchTensor`, `NdArray`, or `TensorFlowTensor`), then `_ColumnInfo.docarray_type` will simply show `AbstractTensor` instead of the specific subclass. This is because the abstract class normalizes all input data of type `AbstractTensor` to `np.ndarray` anyways, which should make your life easier. Just be sure to properly handle `AbstractTensor` as a possible value of `_ColumnInfo.docarray_type`, and you won't have to worry about the differences between torch, tf, and np. ### Properly handle `n_dim`