From a93584a001b956004b062892d50cf7c63ee9bc44 Mon Sep 17 00:00:00 2001 From: Johannes Messner Date: Tue, 7 Mar 2023 14:57:28 +0100 Subject: [PATCH 01/13] refactor: split flattening into separate method Signed-off-by: Johannes Messner --- docarray/doc_index/abstract_doc_index.py | 66 ++++++++++++++++++------ 1 file changed, 49 insertions(+), 17 deletions(-) diff --git a/docarray/doc_index/abstract_doc_index.py b/docarray/doc_index/abstract_doc_index.py index c5a9557689b..e6792d36f17 100644 --- a/docarray/doc_index/abstract_doc_index.py +++ b/docarray/doc_index/abstract_doc_index.py @@ -18,6 +18,7 @@ TypeVar, Union, cast, + get_args, ) import numpy as np @@ -112,7 +113,9 @@ def __init__(self, db_config=None, **kwargs): if not isinstance(self._db_config, self.DBConfig): raise ValueError(f'db_config must be of type {self.DBConfig}') self._runtime_config = self.RuntimeConfig() - self._column_infos: Dict[str, _ColumnInfo] = self._create_columns(self._schema) + self._column_infos: Dict[str, _ColumnInfo] = self._create_column_infos( + self._schema + ) ############################################### # Inner classes for query builder and configs # @@ -613,31 +616,60 @@ def build_query(self) -> QueryBuilder: """ return self.QueryBuilder() # type: ignore - def _create_columns(self, schema: Type[BaseDocument]) -> Dict[str, _ColumnInfo]: - columns: Dict[str, _ColumnInfo] = dict() + def _flatten_schema( + self, schema: Type[BaseDocument], name_prefix: str = '' + ) -> List[Tuple[str, Type, 'ModelField']]: + """Flatten the schema of a Document into a list of column names and types. + + :param schema: The schema to flatten + :param name_prefix: prefix to append to the column names. Used for recursive calls to handle nesting. + :return: A list of column names, types, and fields + """ + names_types_fields: List[Tuple[str, Type, 'ModelField']] = [] for field_name, field_ in schema.__fields__.items(): t_ = schema._get_field_type(field_name) + inner_prefix = name_prefix + field_name + '__' + if is_union_type(t_): + union_args = get_args(t_) + if len(union_args) == 2 and type(None) in union_args: + # simple "Optional" type, treat as special case: + # treat as if it was a single non-optional type + for t_arg in union_args: + if t_arg is type(None): + pass + elif issubclass(t_arg, BaseDocument): + names_types_fields.extend( + self._flatten_schema(t_arg, name_prefix=inner_prefix) + ) + else: + names_types_fields.append((field_name, t_, field_)) + elif issubclass(t_, BaseDocument): + names_types_fields.extend( + self._flatten_schema(t_, name_prefix=inner_prefix) + ) + else: + names_types_fields.append((name_prefix + field_name, t_, field_)) + return names_types_fields + + def _create_column_infos( + self, schema: Type[BaseDocument] + ) -> Dict[str, _ColumnInfo]: + column_infos: Dict[str, _ColumnInfo] = dict() + for field_name, type_, field_ in self._flatten_schema(schema): + if is_union_type(type_): raise ValueError( 'Union types are not supported in the schema of a DocumentIndex.' - f' Instead of using type {t_} use a single specific type.' + f' Instead of using type {type_} use a single specific type.' ) - elif issubclass(t_, AnyDocumentArray): + elif issubclass(type_, AnyDocumentArray): raise ValueError( 'Indexing field of DocumentArray type (=subindex)' 'is not yet supported.' ) - elif issubclass(t_, BaseDocument): - columns = dict( - columns, - **{ - f'{field_name}__{nested_name}': t - for nested_name, t in self._create_columns(t_).items() - }, - ) else: - columns[field_name] = self._create_single_column(field_, t_) - return columns + column_infos[field_name] = self._create_single_column(field_, type_) + return column_infos def _create_single_column(self, field: 'ModelField', type_: Type) -> _ColumnInfo: db_type = self.python_type_to_db_type(type_) @@ -665,7 +697,7 @@ def _is_schema_compatible(self, docs: Sequence[BaseDocument]) -> bool: (name, col.db_type) for name, col in self._column_infos.items() ] if isinstance(docs, AnyDocumentArray): - input_columns = self._create_columns(docs.document_type) + input_columns = self._create_column_infos(docs.document_type) input_col_db_types = [ (name, col.db_type) for name, col in input_columns.items() ] @@ -674,7 +706,7 @@ def _is_schema_compatible(self, docs: Sequence[BaseDocument]) -> bool: return reference_col_db_types == input_col_db_types else: for d in docs: - input_columns = self._create_columns(type(d)) + input_columns = self._create_column_infos(type(d)) input_col_db_types = [ (name, col.db_type) for name, col in input_columns.items() ] From 3fcba13de3497b7500480d4960c994dd76d44090 Mon Sep 17 00:00:00 2001 From: Johannes Messner Date: Thu, 9 Mar 2023 17:30:10 +0100 Subject: [PATCH 02/13] refactor: don't build column info during schema check Signed-off-by: Johannes Messner --- docarray/doc_index/abstract_doc_index.py | 33 +++++++++++++----------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/docarray/doc_index/abstract_doc_index.py b/docarray/doc_index/abstract_doc_index.py index e6792d36f17..96abfe55875 100644 --- a/docarray/doc_index/abstract_doc_index.py +++ b/docarray/doc_index/abstract_doc_index.py @@ -577,8 +577,8 @@ def _get_col_value_dict( docs_seq = docs if not self._is_schema_compatible(docs_seq): raise ValueError( - 'The schema of the documents to be indexed is not compatible' - ' with the schema of the index.' + 'The schema of the input documents is not compatible' + ' with the schema of the Document Index.' ) def _col_gen(col_name: str): @@ -693,27 +693,30 @@ def _create_single_column(self, field: 'ModelField', type_: Type) -> _ColumnInfo def _is_schema_compatible(self, docs: Sequence[BaseDocument]) -> bool: """Flatten a DocumentArray into a DocumentArray of the schema type.""" - reference_col_db_types = [ - (name, col.db_type) for name, col in self._column_infos.items() - ] + reference_schema_flat = self._flatten_schema(self._schema) + reference_names = [name for (name, _, _) in reference_schema_flat] + reference_types = [t_ for (_, t_, _) in reference_schema_flat] if isinstance(docs, AnyDocumentArray): - input_columns = self._create_column_infos(docs.document_type) - input_col_db_types = [ - (name, col.db_type) for name, col in input_columns.items() - ] + input_schema_flat = self._flatten_schema(docs.document_type) + input_names = [name for (name, _, _) in input_schema_flat] + input_types = [t_ for (_, t_, _) in input_schema_flat] # this could be relaxed in the future, # see schema translation ideas in the design doc - return reference_col_db_types == input_col_db_types + return reference_names == input_names and all( + issubclass(t1, t2) for (t1, t2) in zip(reference_types, input_types) + ) else: for d in docs: - input_columns = self._create_column_infos(type(d)) - input_col_db_types = [ - (name, col.db_type) for name, col in input_columns.items() - ] + input_schema_flat = self._flatten_schema(type(d)) + input_names = [name for (name, _, _) in input_schema_flat] + input_types = [t_ for (_, t_, _) in input_schema_flat] # this could be relaxed in the future, # see schema translation ideas in the design doc - if reference_col_db_types != input_col_db_types: + if reference_names != input_names or not all( + issubclass(t1, t2) for (t1, t2) in zip(reference_types, input_types) + ): return False + return True def _to_numpy(self, val: Any) -> Any: From d11a4b44e202fe10dc2557636529431b5c04a00a Mon Sep 17 00:00:00 2001 From: Johannes Messner Date: Fri, 10 Mar 2023 15:51:13 +0100 Subject: [PATCH 03/13] feat: allos unions and optional in indexed data Signed-off-by: Johannes Messner --- docarray/doc_index/abstract_doc_index.py | 66 +++++---- .../doc_index/backends/hnswlib_doc_index.py | 8 +- .../base_classes/test_base_doc_store.py | 138 +++++++++++------- 3 files changed, 128 insertions(+), 84 deletions(-) diff --git a/docarray/doc_index/abstract_doc_index.py b/docarray/doc_index/abstract_doc_index.py index 96abfe55875..b8b1a82738b 100644 --- a/docarray/doc_index/abstract_doc_index.py +++ b/docarray/doc_index/abstract_doc_index.py @@ -22,6 +22,7 @@ ) import numpy as np +from pydantic.error_wrappers import ValidationError from typing_inspect import is_union_type from docarray import BaseDocument, DocumentArray @@ -386,9 +387,12 @@ def configure(self, runtime_config=None, **kwargs): def index(self, docs: Union[BaseDocument, Sequence[BaseDocument]], **kwargs): """Index Documents into the index. - :param docs: Documents to index + :param docs: Documents to index. NOTE: passing a Sequence of Documents that is + not a DocumentArray comes at a performance penalty, since compatibility + with the Index's schema need to be checked for every Document individually. """ - data_by_columns = self._get_col_value_dict(docs) + docs_validated = self._validate_docs(docs) + data_by_columns = self._get_col_value_dict(docs_validated) self._index(data_by_columns, **kwargs) # type: ignore def find( @@ -575,11 +579,6 @@ def _get_col_value_dict( docs_seq: Sequence[BaseDocument] = [docs] else: docs_seq = docs - if not self._is_schema_compatible(docs_seq): - raise ValueError( - 'The schema of the input documents is not compatible' - ' with the schema of the Document Index.' - ) def _col_gen(col_name: str): return (self._get_values_by_column([doc], col_name)[0] for doc in docs_seq) @@ -691,33 +690,44 @@ def _create_single_column(self, field: 'ModelField', type_: Type) -> _ColumnInfo docarray_type=type_, db_type=db_type, config=config, n_dim=n_dim ) - def _is_schema_compatible(self, docs: Sequence[BaseDocument]) -> bool: - """Flatten a DocumentArray into a DocumentArray of the schema type.""" - reference_schema_flat = self._flatten_schema(self._schema) - reference_names = [name for (name, _, _) in reference_schema_flat] - reference_types = [t_ for (_, t_, _) in reference_schema_flat] - if isinstance(docs, AnyDocumentArray): + def _validate_docs( + self, docs: Union[BaseDocument, Sequence[BaseDocument]] + ) -> DocumentArray[BaseDocument]: + if isinstance(docs, BaseDocument): + docs = [docs] + if isinstance(docs, DocumentArray): + # validation shortcut for DocumentArray; only look at the schema + reference_schema_flat = self._flatten_schema(self._schema) + reference_names = [name for (name, _, _) in reference_schema_flat] + reference_types = [t_ for (_, t_, _) in reference_schema_flat] + input_schema_flat = self._flatten_schema(docs.document_type) input_names = [name for (name, _, _) in input_schema_flat] input_types = [t_ for (_, t_, _) in input_schema_flat] # this could be relaxed in the future, # see schema translation ideas in the design doc - return reference_names == input_names and all( - issubclass(t1, t2) for (t1, t2) in zip(reference_types, input_types) + names_compatible = reference_names == input_names + types_compatible = all( + (not is_union_type(t2) and issubclass(t1, t2)) + for (t1, t2) in zip(reference_types, input_types) ) - else: - for d in docs: - input_schema_flat = self._flatten_schema(type(d)) - input_names = [name for (name, _, _) in input_schema_flat] - input_types = [t_ for (_, t_, _) in input_schema_flat] - # this could be relaxed in the future, - # see schema translation ideas in the design doc - if reference_names != input_names or not all( - issubclass(t1, t2) for (t1, t2) in zip(reference_types, input_types) - ): - return False - - return True + if names_compatible and types_compatible: + return docs + out_docs = [] + for i in range(len(docs)): + # validate the data + try: + out_docs.append(self._schema.parse_obj(docs[i])) + except (ValueError, ValidationError): + raise ValueError( + 'The schema of the input Documents is not compatible with the schema of the Document Index.' + ' Ensure that the field names of your data match the field names of the Document Index schema,' + ' and that the types of your data match the types of the Document Index schema.' + ) + + return DocumentArray[BaseDocument]( + out_docs + ) # TODO(johannes): use `construct` here to avoid validating again def _to_numpy(self, val: Any) -> Any: if isinstance(val, np.ndarray): diff --git a/docarray/doc_index/backends/hnswlib_doc_index.py b/docarray/doc_index/backends/hnswlib_doc_index.py index ce98c5d4e47..21e53993c13 100644 --- a/docarray/doc_index/backends/hnswlib_doc_index.py +++ b/docarray/doc_index/backends/hnswlib_doc_index.py @@ -143,9 +143,9 @@ def index(self, docs: Union[BaseDocument, Sequence[BaseDocument]], **kwargs): """Index a document into the store""" if kwargs: raise ValueError(f'{list(kwargs.keys())} are not valid keyword arguments') - doc_seq = docs if isinstance(docs, Sequence) else [docs] - data_by_columns = self._get_col_value_dict(doc_seq) - hashed_ids = tuple(self._to_hashed_id(doc.id) for doc in doc_seq) + docs_validated = self._validate_docs(docs) + data_by_columns = self._get_col_value_dict(docs_validated) + hashed_ids = tuple(self._to_hashed_id(doc.id) for doc in docs_validated) # indexing into HNSWLib and SQLite sequentially # could be improved by processing in parallel @@ -156,7 +156,7 @@ def index(self, docs: Union[BaseDocument, Sequence[BaseDocument]], **kwargs): index.add_items(data_stacked, ids=hashed_ids) index.save_index(self._hnsw_locations[col_name]) - self._send_docs_to_sqlite(doc_seq) + self._send_docs_to_sqlite(docs_validated) self._sqlite_conn.commit() def execute_query(self, query: List[Tuple[str, Dict]], *args, **kwargs) -> Any: diff --git a/tests/doc_index/base_classes/test_base_doc_store.py b/tests/doc_index/base_classes/test_base_doc_store.py index 5bc9fd55116..c9234e215e7 100644 --- a/tests/doc_index/base_classes/test_base_doc_store.py +++ b/tests/doc_index/base_classes/test_base_doc_store.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Any, Dict, Type +from typing import Any, Dict, Optional, Type, Union import numpy as np import pytest @@ -131,7 +131,7 @@ def test_create_columns(): assert store._column_infos['d__tens'].config == {'dim': 1000, 'hi': 'there'} -def test_is_schema_compatible(): +def test_docs_validation(): class OtherSimpleDoc(SimpleDoc): ... @@ -141,81 +141,115 @@ class OtherFlatDoc(FlatDoc): class OtherNestedDoc(NestedDoc): ... + # SIMPLE store = DummyDocIndex[SimpleDoc]() - assert store._is_schema_compatible([SimpleDoc(tens=np.random.random((10,)))]) - assert store._is_schema_compatible( - DocumentArray[SimpleDoc]([SimpleDoc(tens=np.random.random((10,)))]) - ) - assert store._is_schema_compatible([OtherSimpleDoc(tens=np.random.random((10,)))]) - assert store._is_schema_compatible( - DocumentArray[OtherSimpleDoc]([OtherSimpleDoc(tens=np.random.random((10,)))]) - ) - assert not store._is_schema_compatible( - [FlatDoc(tens_one=np.random.random((10,)), tens_two=np.random.random((50,)))] - ) - assert not store._is_schema_compatible( - DocumentArray[FlatDoc]( + in_list = [SimpleDoc(tens=np.random.random((10,)))] + assert isinstance(store._validate_docs(in_list), DocumentArray[BaseDocument]) + in_da = DocumentArray[SimpleDoc](in_list) + assert store._validate_docs(in_da) == in_da + in_other_list = [OtherSimpleDoc(tens=np.random.random((10,)))] + assert isinstance(store._validate_docs(in_other_list), DocumentArray[BaseDocument]) + in_other_da = DocumentArray[OtherSimpleDoc](in_other_list) + assert store._validate_docs(in_other_da) == in_other_da + + with pytest.raises(ValueError): + store._validate_docs( [ FlatDoc( tens_one=np.random.random((10,)), tens_two=np.random.random((50,)) ) ] ) - ) + with pytest.raises(ValueError): + store._validate_docs( + DocumentArray[FlatDoc]( + [ + FlatDoc( + tens_one=np.random.random((10,)), + tens_two=np.random.random((50,)), + ) + ] + ) + ) + # FLAT store = DummyDocIndex[FlatDoc]() - assert store._is_schema_compatible( + in_list = [ + FlatDoc(tens_one=np.random.random((10,)), tens_two=np.random.random((50,))) + ] + assert isinstance(store._validate_docs(in_list), DocumentArray[BaseDocument]) + in_da = DocumentArray[FlatDoc]( [FlatDoc(tens_one=np.random.random((10,)), tens_two=np.random.random((50,)))] ) - assert store._is_schema_compatible( - DocumentArray[FlatDoc]( - [ - FlatDoc( - tens_one=np.random.random((10,)), tens_two=np.random.random((50,)) - ) - ] - ) - ) - assert store._is_schema_compatible( + assert store._validate_docs(in_da) == in_da + in_other_list = [ + OtherFlatDoc(tens_one=np.random.random((10,)), tens_two=np.random.random((50,))) + ] + assert isinstance(store._validate_docs(in_other_list), DocumentArray[BaseDocument]) + in_other_da = DocumentArray[OtherFlatDoc]( [ OtherFlatDoc( tens_one=np.random.random((10,)), tens_two=np.random.random((50,)) ) ] ) - assert store._is_schema_compatible( - DocumentArray[OtherFlatDoc]( - [ - OtherFlatDoc( - tens_one=np.random.random((10,)), tens_two=np.random.random((50,)) - ) - ] + assert store._validate_docs(in_other_da) == in_other_da + with pytest.raises(ValueError): + store._validate_docs([SimpleDoc(tens=np.random.random((10,)))]) + with pytest.raises(ValueError): + assert not store._validate_docs( + DocumentArray[SimpleDoc]([SimpleDoc(tens=np.random.random((10,)))]) ) - ) - assert not store._is_schema_compatible([SimpleDoc(tens=np.random.random((10,)))]) - assert not store._is_schema_compatible( - DocumentArray[SimpleDoc]([SimpleDoc(tens=np.random.random((10,)))]) - ) + # NESTED store = DummyDocIndex[NestedDoc]() - assert store._is_schema_compatible( + in_list = [NestedDoc(d=SimpleDoc(tens=np.random.random((10,))))] + assert isinstance(store._validate_docs(in_list), DocumentArray[BaseDocument]) + in_da = DocumentArray[NestedDoc]( [NestedDoc(d=SimpleDoc(tens=np.random.random((10,))))] ) - assert store._is_schema_compatible( - DocumentArray[NestedDoc]([NestedDoc(d=SimpleDoc(tens=np.random.random((10,))))]) - ) - assert store._is_schema_compatible( + assert store._validate_docs(in_da) == in_da + in_other_list = [OtherNestedDoc(d=OtherSimpleDoc(tens=np.random.random((10,))))] + assert isinstance(store._validate_docs(in_other_list), DocumentArray[BaseDocument]) + in_other_da = DocumentArray[OtherNestedDoc]( [OtherNestedDoc(d=OtherSimpleDoc(tens=np.random.random((10,))))] ) - assert store._is_schema_compatible( - DocumentArray[OtherNestedDoc]( - [OtherNestedDoc(d=OtherSimpleDoc(tens=np.random.random((10,))))] + + assert store._validate_docs(in_other_da) == in_other_da + with pytest.raises(ValueError): + store._validate_docs([SimpleDoc(tens=np.random.random((10,)))]) + with pytest.raises(ValueError): + store._validate_docs( + DocumentArray[SimpleDoc]([SimpleDoc(tens=np.random.random((10,)))]) ) - ) - assert not store._is_schema_compatible([SimpleDoc(tens=np.random.random((10,)))]) - assert not store._is_schema_compatible( - DocumentArray[SimpleDoc]([SimpleDoc(tens=np.random.random((10,)))]) - ) + + +def test_docs_validation_unions(): + class OptionalDoc(BaseDocument): + tens: Optional[NdArray[10]] = Field(dim=1000) + + class UnionDoc(BaseDocument): + tens: Union[NdArray[10], str] = Field(dim=1000) + + # OPTIONAL + store = DummyDocIndex[SimpleDoc]() + in_list = [OptionalDoc(tens=np.random.random((10,)))] + assert isinstance(store._validate_docs(in_list), DocumentArray[BaseDocument]) + in_da = DocumentArray[OptionalDoc](in_list) + assert store._validate_docs(in_da) == in_da + + with pytest.raises(ValueError): + store._validate_docs([OptionalDoc(tens=None)]) + + # OTHER UNION + store = DummyDocIndex[SimpleDoc]() + in_list = [UnionDoc(tens=np.random.random((10,)))] + assert isinstance(store._validate_docs(in_list), DocumentArray[BaseDocument]) + in_da = DocumentArray[UnionDoc](in_list) + assert isinstance(store._validate_docs(in_da), DocumentArray[BaseDocument]) + + with pytest.raises(ValueError): + store._validate_docs([UnionDoc(tens='hello')]) def test_get_value(): From b4c1caeb1deb495921a0586595723fe88f917b21 Mon Sep 17 00:00:00 2001 From: Johannes Messner Date: Fri, 10 Mar 2023 15:56:35 +0100 Subject: [PATCH 04/13] fix: mypy Signed-off-by: Johannes Messner --- docarray/doc_index/abstract_doc_index.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/docarray/doc_index/abstract_doc_index.py b/docarray/doc_index/abstract_doc_index.py index b8b1a82738b..a9aa9e5a94e 100644 --- a/docarray/doc_index/abstract_doc_index.py +++ b/docarray/doc_index/abstract_doc_index.py @@ -697,7 +697,9 @@ def _validate_docs( docs = [docs] if isinstance(docs, DocumentArray): # validation shortcut for DocumentArray; only look at the schema - reference_schema_flat = self._flatten_schema(self._schema) + reference_schema_flat = self._flatten_schema( + cast(Type[BaseDocument], self._schema) + ) reference_names = [name for (name, _, _) in reference_schema_flat] reference_types = [t_ for (_, t_, _) in reference_schema_flat] @@ -717,7 +719,9 @@ def _validate_docs( for i in range(len(docs)): # validate the data try: - out_docs.append(self._schema.parse_obj(docs[i])) + out_docs.append( + cast(Type[BaseDocument], self._schema).parse_obj(docs[i]) + ) except (ValueError, ValidationError): raise ValueError( 'The schema of the input Documents is not compatible with the schema of the Document Index.' From 211bd958c52769c4ca8315d9cf16dd6feabb2580 Mon Sep 17 00:00:00 2001 From: Johannes Messner Date: Fri, 10 Mar 2023 16:13:00 +0100 Subject: [PATCH 05/13] fix: mypy Signed-off-by: Johannes Messner --- docarray/doc_index/abstract_doc_index.py | 1 + 1 file changed, 1 insertion(+) diff --git a/docarray/doc_index/abstract_doc_index.py b/docarray/doc_index/abstract_doc_index.py index b0a3f049f8e..1f2e71eb709 100644 --- a/docarray/doc_index/abstract_doc_index.py +++ b/docarray/doc_index/abstract_doc_index.py @@ -11,6 +11,7 @@ NamedTuple, Optional, Sequence, + Tuple, Type, TypeVar, Union, From a5ec4a68e3db385f525adc82d0a3e34e4974555d Mon Sep 17 00:00:00 2001 From: Johannes Messner Date: Fri, 10 Mar 2023 16:41:16 +0100 Subject: [PATCH 06/13] fix: import from typing inspect instead of typing Signed-off-by: Johannes Messner --- docarray/doc_index/abstract_doc_index.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docarray/doc_index/abstract_doc_index.py b/docarray/doc_index/abstract_doc_index.py index 1f2e71eb709..7956ba8c1e8 100644 --- a/docarray/doc_index/abstract_doc_index.py +++ b/docarray/doc_index/abstract_doc_index.py @@ -16,12 +16,11 @@ TypeVar, Union, cast, - get_args, ) import numpy as np from pydantic.error_wrappers import ValidationError -from typing_inspect import is_union_type +from typing_inspect import get_args, is_union_type from docarray import BaseDocument, DocumentArray from docarray.array.abstract_array import AnyDocumentArray From 29c5b439d3ca0e23ada05b7b70b6d406b999e5da Mon Sep 17 00:00:00 2001 From: Johannes Messner Date: Mon, 20 Mar 2023 11:47:45 +0100 Subject: [PATCH 07/13] fix: equality and hash for parametrized tensors Signed-off-by: Johannes Messner --- docarray/typing/tensor/abstract_tensor.py | 33 ++++++++++++++++++---- tests/units/typing/tensor/test_tensor.py | 34 ++++++++++++++++++++++- 2 files changed, 60 insertions(+), 7 deletions(-) diff --git a/docarray/typing/tensor/abstract_tensor.py b/docarray/typing/tensor/abstract_tensor.py index c2451fa3272..add49b6b9a5 100644 --- a/docarray/typing/tensor/abstract_tensor.py +++ b/docarray/typing/tensor/abstract_tensor.py @@ -31,12 +31,14 @@ class _ParametrizedMeta(type): """ - This metaclass ensures that instance and subclass checks on parametrized Tensors + This metaclass ensures that instance, subclass and equality checks on parametrized Tensors are handled as expected: assert issubclass(TorchTensor[128], TorchTensor[128]) t = parse_obj_as(TorchTensor[128], torch.zeros(128)) assert isinstance(t, TorchTensor[128]) + TorchTensor[128] == TorchTensor[128] + hash(TorchTensor[128]) == hash(TorchTensor[128]) etc. This special handling is needed because every call to `AbstractTensor.__getitem__` @@ -44,11 +46,11 @@ class _ParametrizedMeta(type): We want technically distinct but identical classes to be considered equal. """ - def __subclasscheck__(cls, subclass): - is_tensor = AbstractTensor in subclass.mro() - same_parents = is_tensor and cls.mro()[1:] == subclass.mro()[1:] + def _equals_special_case(cls, other): + is_tensor = AbstractTensor in other.mro() + same_parents = is_tensor and cls.mro()[1:] == other.mro()[1:] - subclass_target_shape = getattr(subclass, '__docarray_target_shape__', False) + subclass_target_shape = getattr(other, '__docarray_target_shape__', False) self_target_shape = getattr(cls, '__docarray_target_shape__', False) same_shape = ( same_parents @@ -57,7 +59,10 @@ def __subclasscheck__(cls, subclass): and subclass_target_shape == self_target_shape ) - if same_shape: + return same_shape + + def __subclasscheck__(cls, subclass): + if cls._equals_special_case(subclass): return True return super().__subclasscheck__(subclass) @@ -80,6 +85,22 @@ def __instancecheck__(cls, instance): return any(issubclass(candidate, cls) for candidate in type(instance).mro()) return super().__instancecheck__(instance) + def __eq__(cls, other): + if cls._equals_special_case(other): + return True + return NotImplemented + + def __hash__(cls): + try: + cls_ = cast(AbstractTensor, cls) + return hash((cls_.__docarray_target_shape__, cls_.__unparametrizedcls__)) + except AttributeError: + raise NotImplementedError( + '`hash()` is not implemented for this class. The `_ParametrizedMeta` ' + 'metaclass should only be used for `AbstractTensor` subclasses. ' + 'Otherwise, you have to implement `__hash__` for your class yourself.' + ) + class AbstractTensor(Generic[TTensor, T], AbstractType, ABC): diff --git a/tests/units/typing/tensor/test_tensor.py b/tests/units/typing/tensor/test_tensor.py index 1d36a4c355e..37e62ece44f 100644 --- a/tests/units/typing/tensor/test_tensor.py +++ b/tests/units/typing/tensor/test_tensor.py @@ -4,7 +4,7 @@ from pydantic.tools import parse_obj_as, schema_json_of from docarray.base_document.io.json import orjson_dumps -from docarray.typing import NdArray +from docarray.typing import AudioNdArray, NdArray from docarray.typing.tensor import NdArrayEmbedding @@ -158,3 +158,35 @@ def test_parametrized_operations(): assert isinstance(t_result, np.ndarray) assert isinstance(t_result, NdArray) assert isinstance(t_result, NdArray[128]) + + +def test_class_equality(): + assert NdArray == NdArray + assert NdArray[128] == NdArray[128] + assert NdArray[128] != NdArray[256] + assert NdArray[128] != NdArray[2, 64] + assert not NdArray[128] == NdArray[2, 64] + + assert NdArrayEmbedding == NdArrayEmbedding + assert NdArrayEmbedding[128] == NdArrayEmbedding[128] + assert NdArrayEmbedding[128] != NdArrayEmbedding[256] + + assert AudioNdArray == AudioNdArray + assert AudioNdArray[128] == AudioNdArray[128] + assert AudioNdArray[128] != AudioNdArray[256] + + +def test_class_hash(): + assert hash(NdArray) == hash(NdArray) + assert hash(NdArray[128]) == hash(NdArray[128]) + assert hash(NdArray[128]) != hash(NdArray[256]) + assert hash(NdArray[128]) != hash(NdArray[2, 64]) + assert not hash(NdArray[128]) == hash(NdArray[2, 64]) + + assert hash(NdArrayEmbedding) == hash(NdArrayEmbedding) + assert hash(NdArrayEmbedding[128]) == hash(NdArrayEmbedding[128]) + assert hash(NdArrayEmbedding[128]) != hash(NdArrayEmbedding[256]) + + assert hash(AudioNdArray) == hash(AudioNdArray) + assert hash(AudioNdArray[128]) == hash(AudioNdArray[128]) + assert hash(AudioNdArray[128]) != hash(AudioNdArray[256]) From e48df95d3ff7af49c5350f9ea2af1a9986c42269 Mon Sep 17 00:00:00 2001 From: Johannes Messner Date: Mon, 20 Mar 2023 11:50:13 +0100 Subject: [PATCH 08/13] test: add test for flatten docs Signed-off-by: Johannes Messner --- .../base_classes/test_base_doc_store.py | 37 +++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/tests/doc_index/base_classes/test_base_doc_store.py b/tests/doc_index/base_classes/test_base_doc_store.py index 1457808aafb..8d2484addb5 100644 --- a/tests/doc_index/base_classes/test_base_doc_store.py +++ b/tests/doc_index/base_classes/test_base_doc_store.py @@ -141,6 +141,43 @@ def test_create_columns(): assert store._column_infos['d__tens'].config == {'dim': 1000, 'hi': 'there'} +def test_flatten_schema(): + store = DummyDocIndex[SimpleDoc]() + fields = SimpleDoc.__fields__ + assert set(store._flatten_schema(SimpleDoc)) == { + ('id', ID, fields['id']), + ('tens', NdArray[10], fields['tens']), + } + + store = DummyDocIndex[FlatDoc]() + fields = FlatDoc.__fields__ + assert set(store._flatten_schema(FlatDoc)) == { + ('id', ID, fields['id']), + ('tens_one', NdArray, fields['tens_one']), + ('tens_two', NdArray, fields['tens_two']), + } + + store = DummyDocIndex[NestedDoc]() + fields = NestedDoc.__fields__ + fields_nested = SimpleDoc.__fields__ + assert set(store._flatten_schema(NestedDoc)) == { + ('id', ID, fields['id']), + ('d__id', ID, fields_nested['id']), + ('d__tens', NdArray[10], fields_nested['tens']), + } + + store = DummyDocIndex[DeepNestedDoc]() + fields = DeepNestedDoc.__fields__ + fields_nested = NestedDoc.__fields__ + fields_nested_nested = SimpleDoc.__fields__ + assert set(store._flatten_schema(DeepNestedDoc)) == { + ('id', ID, fields['id']), + ('d__id', ID, fields_nested['id']), + ('d__d__id', ID, fields_nested_nested['id']), + ('d__d__tens', NdArray[10], fields_nested_nested['tens']), + } + + def test_docs_validation(): class OtherSimpleDoc(SimpleDoc): ... From dfa5b0439a349ae476fdb8169198dc9c4d809119 Mon Sep 17 00:00:00 2001 From: Johannes Messner Date: Mon, 20 Mar 2023 11:50:34 +0100 Subject: [PATCH 09/13] refactor: apply suggestions Signed-off-by: Johannes Messner --- docarray/doc_index/abstract_doc_index.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/docarray/doc_index/abstract_doc_index.py b/docarray/doc_index/abstract_doc_index.py index 7956ba8c1e8..7d879b7daa5 100644 --- a/docarray/doc_index/abstract_doc_index.py +++ b/docarray/doc_index/abstract_doc_index.py @@ -628,10 +628,12 @@ def build_query(self) -> QueryBuilder: """ return self.QueryBuilder() # type: ignore + @classmethod def _flatten_schema( - self, schema: Type[BaseDocument], name_prefix: str = '' + cls, schema: Type[BaseDocument], name_prefix: str = '' ) -> List[Tuple[str, Type, 'ModelField']]: """Flatten the schema of a Document into a list of column names and types. + Nested Documents are handled in a recursive manner by adding `'__'` as a prefix to the column name. :param schema: The schema to flatten :param name_prefix: prefix to append to the column names. Used for recursive calls to handle nesting. @@ -652,13 +654,13 @@ def _flatten_schema( pass elif issubclass(t_arg, BaseDocument): names_types_fields.extend( - self._flatten_schema(t_arg, name_prefix=inner_prefix) + cls._flatten_schema(t_arg, name_prefix=inner_prefix) ) else: names_types_fields.append((field_name, t_, field_)) elif issubclass(t_, BaseDocument): names_types_fields.extend( - self._flatten_schema(t_, name_prefix=inner_prefix) + cls._flatten_schema(t_, name_prefix=inner_prefix) ) else: names_types_fields.append((name_prefix + field_name, t_, field_)) From f2136f1e51e2fa2a00cb321a001bad2bc6bf3424 Mon Sep 17 00:00:00 2001 From: Johannes Messner Date: Mon, 20 Mar 2023 15:48:49 +0100 Subject: [PATCH 10/13] docs: better docstrings Signed-off-by: Johannes Messner --- docarray/doc_index/abstract_doc_index.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/docarray/doc_index/abstract_doc_index.py b/docarray/doc_index/abstract_doc_index.py index 7d879b7daa5..9993bcff5fe 100644 --- a/docarray/doc_index/abstract_doc_index.py +++ b/docarray/doc_index/abstract_doc_index.py @@ -669,6 +669,12 @@ def _flatten_schema( def _create_column_infos( self, schema: Type[BaseDocument] ) -> Dict[str, _ColumnInfo]: + """Collects information about every column that is implied by a given schema. + + :param schema: The schema (subclass of BaseDocument) to analyze and parse + columns from + :returns: A dictionary mapping from column names to column information. + """ column_infos: Dict[str, _ColumnInfo] = dict() for field_name, type_, field_ in self._flatten_schema(schema): if is_union_type(type_): @@ -708,6 +714,17 @@ def _create_single_column(self, field: 'ModelField', type_: Type) -> _ColumnInfo def _validate_docs( self, docs: Union[BaseDocument, Sequence[BaseDocument]] ) -> DocumentArray[BaseDocument]: + """Validates Document against the schema of the Document Index. + For validation to pass, the schema of `docs` and the schema of the Document + Index need to evaluate to the same flattened columns. + If Validation fails, a ValueError is raised. + + :param docs: Document to evaluate. If this is a DocumentArray, validation is + performed using its `doc_type` (parametrization), without having to check + ever Document in `docs`. If this check fails, or if `docs` is not a + DocumentArray, evaluation is performed for every Document in `docs`. + :return: A DocumentArray containing the Documents in `docs` + """ if isinstance(docs, BaseDocument): docs = [docs] if isinstance(docs, DocumentArray): From 25fd665817009c5d49ca349cb7184f3d4ad9c25b Mon Sep 17 00:00:00 2001 From: Johannes Messner Date: Mon, 20 Mar 2023 16:30:29 +0100 Subject: [PATCH 11/13] refactor: use construct to create docarray Signed-off-by: Johannes Messner --- docarray/doc_index/abstract_doc_index.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/docarray/doc_index/abstract_doc_index.py b/docarray/doc_index/abstract_doc_index.py index 5bf174c4cdd..ac19d1a0bcc 100644 --- a/docarray/doc_index/abstract_doc_index.py +++ b/docarray/doc_index/abstract_doc_index.py @@ -761,9 +761,7 @@ def _validate_docs( ' and that the types of your data match the types of the Document Index schema.' ) - return DocumentArray[BaseDocument]( - out_docs - ) # TODO(johannes): use `construct` here to avoid validating again + return DocumentArray[BaseDocument].construct(out_docs) def _to_numpy(self, val: Any) -> Any: if isinstance(val, np.ndarray): From 32264315fd8623ba9afd632e66e47451046b5e95 Mon Sep 17 00:00:00 2001 From: Johannes Messner Date: Mon, 20 Mar 2023 17:12:26 +0100 Subject: [PATCH 12/13] fix: check for nonetype Signed-off-by: Johannes Messner --- docarray/typing/tensor/abstract_tensor.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docarray/typing/tensor/abstract_tensor.py b/docarray/typing/tensor/abstract_tensor.py index 80a96f77f8d..50b76d77b18 100644 --- a/docarray/typing/tensor/abstract_tensor.py +++ b/docarray/typing/tensor/abstract_tensor.py @@ -48,7 +48,9 @@ class _ParametrizedMeta(type): """ def _equals_special_case(cls, other): - is_tensor = AbstractTensor in other.mro() + is_type = isinstance(other, type) + is_none_type = is_type and issubclass(other, type(None)) + is_tensor = (not is_none_type) and AbstractTensor in other.mro() same_parents = is_tensor and cls.mro()[1:] == other.mro()[1:] subclass_target_shape = getattr(other, '__docarray_target_shape__', False) From b45dfea6cf70003c95b3d65302d280a968e99dd0 Mon Sep 17 00:00:00 2001 From: Johannes Messner Date: Mon, 20 Mar 2023 17:32:20 +0100 Subject: [PATCH 13/13] fix: none in equals check Signed-off-by: Johannes Messner --- docarray/typing/tensor/abstract_tensor.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docarray/typing/tensor/abstract_tensor.py b/docarray/typing/tensor/abstract_tensor.py index 50b76d77b18..049151d3a47 100644 --- a/docarray/typing/tensor/abstract_tensor.py +++ b/docarray/typing/tensor/abstract_tensor.py @@ -49,8 +49,7 @@ class _ParametrizedMeta(type): def _equals_special_case(cls, other): is_type = isinstance(other, type) - is_none_type = is_type and issubclass(other, type(None)) - is_tensor = (not is_none_type) and AbstractTensor in other.mro() + is_tensor = is_type and AbstractTensor in other.mro() same_parents = is_tensor and cls.mro()[1:] == other.mro()[1:] subclass_target_shape = getattr(other, '__docarray_target_shape__', False)