diff --git a/docarray/doc_index/abstract_doc_index.py b/docarray/doc_index/abstract_doc_index.py index e55096d1e60..6a527b43bbb 100644 --- a/docarray/doc_index/abstract_doc_index.py +++ b/docarray/doc_index/abstract_doc_index.py @@ -11,6 +11,7 @@ NamedTuple, Optional, Sequence, + Tuple, Type, TypeVar, Union, @@ -18,7 +19,8 @@ ) import numpy as np -from typing_inspect import is_union_type +from pydantic.error_wrappers import ValidationError +from typing_inspect import get_args, is_union_type from docarray import BaseDocument, DocumentArray from docarray.array.abstract_array import AnyDocumentArray @@ -89,7 +91,9 @@ def __init__(self, db_config=None, **kwargs): if not isinstance(self._db_config, self.DBConfig): raise ValueError(f'db_config must be of type {self.DBConfig}') self._runtime_config = self.RuntimeConfig() - self._column_infos: Dict[str, _ColumnInfo] = self._create_columns(self._schema) + self._column_infos: Dict[str, _ColumnInfo] = self._create_column_infos( + self._schema + ) ############################################### # Inner classes for query builder and configs # @@ -367,9 +371,12 @@ def configure(self, runtime_config=None, **kwargs): def index(self, docs: Union[BaseDocument, Sequence[BaseDocument]], **kwargs): """index Documents into the index. - :param docs: Documents to index + :param docs: Documents to index. NOTE: passing a Sequence of Documents that is + not a DocumentArray comes at a performance penalty, since compatibility + with the Index's schema needs to be checked for every Document individually. 
""" - data_by_columns = self._get_col_value_dict(docs) + docs_validated = self._validate_docs(docs) + data_by_columns = self._get_col_value_dict(docs_validated) self._index(data_by_columns, **kwargs) def find( @@ -585,11 +592,6 @@ def _get_col_value_dict( docs_seq: Sequence[BaseDocument] = [docs] else: docs_seq = docs - if not self._is_schema_compatible(docs_seq): - raise ValueError( - 'The schema of the documents to be indexed is not compatible' - ' with the schema of the index.' - ) def _col_gen(col_name: str): return (self._get_values_by_column([doc], col_name)[0] for doc in docs_seq) @@ -626,31 +628,68 @@ def build_query(self) -> QueryBuilder: """ return self.QueryBuilder() # type: ignore - def _create_columns(self, schema: Type[BaseDocument]) -> Dict[str, _ColumnInfo]: - columns: Dict[str, _ColumnInfo] = dict() + @classmethod + def _flatten_schema( + cls, schema: Type[BaseDocument], name_prefix: str = '' + ) -> List[Tuple[str, Type, 'ModelField']]: + """Flatten the schema of a Document into a list of column names and types. + Nested Documents are handled in a recursive manner by adding `'__'` as a prefix to the column name. + + :param schema: The schema to flatten + :param name_prefix: prefix to append to the column names. Used for recursive calls to handle nesting. 
+ :return: A list of column names, types, and fields + """ + names_types_fields: List[Tuple[str, Type, 'ModelField']] = [] for field_name, field_ in schema.__fields__.items(): t_ = schema._get_field_type(field_name) + inner_prefix = name_prefix + field_name + '__' + if is_union_type(t_): + union_args = get_args(t_) + if len(union_args) == 2 and type(None) in union_args: + # simple "Optional" type, treat as special case: + # treat as if it was a single non-optional type + for t_arg in union_args: + if t_arg is type(None): + pass + elif issubclass(t_arg, BaseDocument): + names_types_fields.extend( + cls._flatten_schema(t_arg, name_prefix=inner_prefix) + ) + else: + names_types_fields.append((field_name, t_, field_)) + elif issubclass(t_, BaseDocument): + names_types_fields.extend( + cls._flatten_schema(t_, name_prefix=inner_prefix) + ) + else: + names_types_fields.append((name_prefix + field_name, t_, field_)) + return names_types_fields + + def _create_column_infos( + self, schema: Type[BaseDocument] + ) -> Dict[str, _ColumnInfo]: + """Collects information about every column that is implied by a given schema. + + :param schema: The schema (subclass of BaseDocument) to analyze and parse + columns from + :returns: A dictionary mapping from column names to column information. + """ + column_infos: Dict[str, _ColumnInfo] = dict() + for field_name, type_, field_ in self._flatten_schema(schema): + if is_union_type(type_): raise ValueError( 'Union types are not supported in the schema of a DocumentIndex.' - f' Instead of using type {t_} use a single specific type.' + f' Instead of using type {type_} use a single specific type.' ) - elif issubclass(t_, AnyDocumentArray): + elif issubclass(type_, AnyDocumentArray): raise ValueError( 'Indexing field of DocumentArray type (=subindex)' 'is not yet supported.' 
) - elif issubclass(t_, BaseDocument): - columns = dict( - columns, - **{ - f'{field_name}__{nested_name}': t - for nested_name, t in self._create_columns(t_).items() - }, - ) else: - columns[field_name] = self._create_single_column(field_, t_) - return columns + column_infos[field_name] = self._create_single_column(field_, type_) + return column_infos def _create_single_column(self, field: 'ModelField', type_: Type) -> _ColumnInfo: custom_config = field.field_info.extra @@ -682,30 +721,57 @@ def _create_single_column(self, field: 'ModelField', type_: Type) -> _ColumnInfo docarray_type=type_, db_type=db_type, config=config, n_dim=n_dim ) - def _is_schema_compatible(self, docs: Sequence[BaseDocument]) -> bool: - """Flatten a DocumentArray into a DocumentArray of the schema type.""" - reference_col_db_types = [ - (name, col.db_type) for name, col in self._column_infos.items() - ] - if isinstance(docs, AnyDocumentArray): - input_columns = self._create_columns(docs.document_type) - input_col_db_types = [ - (name, col.db_type) for name, col in input_columns.items() - ] + def _validate_docs( + self, docs: Union[BaseDocument, Sequence[BaseDocument]] + ) -> DocumentArray[BaseDocument]: + """Validates Documents against the schema of the Document Index. + For validation to pass, the schema of `docs` and the schema of the Document + Index need to evaluate to the same flattened columns. + If validation fails, a ValueError is raised. + + :param docs: Documents to evaluate. If this is a DocumentArray, validation is + performed using its `doc_type` (parametrization), without having to check + every Document in `docs`. If this check fails, or if `docs` is not a + DocumentArray, evaluation is performed for every Document in `docs`. 
+ :return: A DocumentArray containing the Documents in `docs` + """ + if isinstance(docs, BaseDocument): + docs = [docs] + if isinstance(docs, DocumentArray): + # validation shortcut for DocumentArray; only look at the schema + reference_schema_flat = self._flatten_schema( + cast(Type[BaseDocument], self._schema) + ) + reference_names = [name for (name, _, _) in reference_schema_flat] + reference_types = [t_ for (_, t_, _) in reference_schema_flat] + + input_schema_flat = self._flatten_schema(docs.document_type) + input_names = [name for (name, _, _) in input_schema_flat] + input_types = [t_ for (_, t_, _) in input_schema_flat] # this could be relaxed in the future, # see schema translation ideas in the design doc - return reference_col_db_types == input_col_db_types - else: - for d in docs: - input_columns = self._create_columns(type(d)) - input_col_db_types = [ - (name, col.db_type) for name, col in input_columns.items() - ] - # this could be relaxed in the future, - # see schema translation ideas in the design doc - if reference_col_db_types != input_col_db_types: - return False - return True + names_compatible = reference_names == input_names + types_compatible = all( + (not is_union_type(t2) and issubclass(t1, t2)) + for (t1, t2) in zip(reference_types, input_types) + ) + if names_compatible and types_compatible: + return docs + out_docs = [] + for i in range(len(docs)): + # validate the data + try: + out_docs.append( + cast(Type[BaseDocument], self._schema).parse_obj(docs[i]) + ) + except (ValueError, ValidationError): + raise ValueError( + 'The schema of the input Documents is not compatible with the schema of the Document Index.' + ' Ensure that the field names of your data match the field names of the Document Index schema,' + ' and that the types of your data match the types of the Document Index schema.' 
+ ) + + return DocumentArray[BaseDocument].construct(out_docs) def _to_numpy(self, val: Any) -> Any: if isinstance(val, np.ndarray): diff --git a/docarray/doc_index/backends/hnswlib_doc_index.py b/docarray/doc_index/backends/hnswlib_doc_index.py index a560249a6dc..3c8c2df14e3 100644 --- a/docarray/doc_index/backends/hnswlib_doc_index.py +++ b/docarray/doc_index/backends/hnswlib_doc_index.py @@ -161,9 +161,9 @@ def index(self, docs: Union[BaseDocument, Sequence[BaseDocument]], **kwargs): """index a document into the store""" if kwargs: raise ValueError(f'{list(kwargs.keys())} are not valid keyword arguments') - doc_seq = docs if isinstance(docs, Sequence) else [docs] - data_by_columns = self._get_col_value_dict(doc_seq) - hashed_ids = tuple(self._to_hashed_id(doc.id) for doc in doc_seq) + docs_validated = self._validate_docs(docs) + data_by_columns = self._get_col_value_dict(docs_validated) + hashed_ids = tuple(self._to_hashed_id(doc.id) for doc in docs_validated) # indexing into HNSWLib and SQLite sequentially # could be improved by processing in parallel @@ -174,7 +174,7 @@ def index(self, docs: Union[BaseDocument, Sequence[BaseDocument]], **kwargs): index.add_items(data_stacked, ids=hashed_ids) index.save_index(self._hnsw_locations[col_name]) - self._send_docs_to_sqlite(doc_seq) + self._send_docs_to_sqlite(docs_validated) self._sqlite_conn.commit() def execute_query(self, query: List[Tuple[str, Dict]], *args, **kwargs) -> Any: diff --git a/docarray/typing/tensor/abstract_tensor.py b/docarray/typing/tensor/abstract_tensor.py index 278645509dc..049151d3a47 100644 --- a/docarray/typing/tensor/abstract_tensor.py +++ b/docarray/typing/tensor/abstract_tensor.py @@ -32,12 +32,14 @@ class _ParametrizedMeta(type): """ - This metaclass ensures that instance and subclass checks on parametrized Tensors + This metaclass ensures that instance, subclass and equality checks on parametrized Tensors are handled as expected: assert issubclass(TorchTensor[128], TorchTensor[128]) t = 
parse_obj_as(TorchTensor[128], torch.zeros(128)) assert isinstance(t, TorchTensor[128]) + TorchTensor[128] == TorchTensor[128] + hash(TorchTensor[128]) == hash(TorchTensor[128]) etc. This special handling is needed because every call to `AbstractTensor.__getitem__` @@ -45,11 +47,12 @@ class _ParametrizedMeta(type): We want technically distinct but identical classes to be considered equal. """ - def __subclasscheck__(cls, subclass): - is_tensor = AbstractTensor in subclass.mro() - same_parents = is_tensor and cls.mro()[1:] == subclass.mro()[1:] + def _equals_special_case(cls, other): + is_type = isinstance(other, type) + is_tensor = is_type and AbstractTensor in other.mro() + same_parents = is_tensor and cls.mro()[1:] == other.mro()[1:] - subclass_target_shape = getattr(subclass, '__docarray_target_shape__', False) + subclass_target_shape = getattr(other, '__docarray_target_shape__', False) self_target_shape = getattr(cls, '__docarray_target_shape__', False) same_shape = ( same_parents @@ -58,7 +61,10 @@ def __subclasscheck__(cls, subclass): and subclass_target_shape == self_target_shape ) - if same_shape: + return same_shape + + def __subclasscheck__(cls, subclass): + if cls._equals_special_case(subclass): return True return super().__subclasscheck__(subclass) @@ -81,6 +87,22 @@ def __instancecheck__(cls, instance): return any(issubclass(candidate, cls) for candidate in type(instance).mro()) return super().__instancecheck__(instance) + def __eq__(cls, other): + if cls._equals_special_case(other): + return True + return NotImplemented + + def __hash__(cls): + try: + cls_ = cast(AbstractTensor, cls) + return hash((cls_.__docarray_target_shape__, cls_.__unparametrizedcls__)) + except AttributeError: + raise NotImplementedError( + '`hash()` is not implemented for this class. The `_ParametrizedMeta` ' + 'metaclass should only be used for `AbstractTensor` subclasses. ' + 'Otherwise, you have to implement `__hash__` for your class yourself.' 
+ ) + class AbstractTensor(Generic[TTensor, T], AbstractType, ABC, Sized): __parametrized_meta__: type = _ParametrizedMeta diff --git a/tests/doc_index/base_classes/test_base_doc_store.py b/tests/doc_index/base_classes/test_base_doc_store.py index e7e8357ef42..f50c2d8cd97 100644 --- a/tests/doc_index/base_classes/test_base_doc_store.py +++ b/tests/doc_index/base_classes/test_base_doc_store.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Any, Dict, Type +from typing import Any, Dict, Optional, Type, Union import numpy as np import pytest @@ -145,6 +145,43 @@ def test_create_columns(): assert store._column_infos['d__tens'].config == {'dim': 1000, 'hi': 'there'} +def test_flatten_schema(): + store = DummyDocIndex[SimpleDoc]() + fields = SimpleDoc.__fields__ + assert set(store._flatten_schema(SimpleDoc)) == { + ('id', ID, fields['id']), + ('tens', NdArray[10], fields['tens']), + } + + store = DummyDocIndex[FlatDoc]() + fields = FlatDoc.__fields__ + assert set(store._flatten_schema(FlatDoc)) == { + ('id', ID, fields['id']), + ('tens_one', NdArray, fields['tens_one']), + ('tens_two', NdArray, fields['tens_two']), + } + + store = DummyDocIndex[NestedDoc]() + fields = NestedDoc.__fields__ + fields_nested = SimpleDoc.__fields__ + assert set(store._flatten_schema(NestedDoc)) == { + ('id', ID, fields['id']), + ('d__id', ID, fields_nested['id']), + ('d__tens', NdArray[10], fields_nested['tens']), + } + + store = DummyDocIndex[DeepNestedDoc]() + fields = DeepNestedDoc.__fields__ + fields_nested = NestedDoc.__fields__ + fields_nested_nested = SimpleDoc.__fields__ + assert set(store._flatten_schema(DeepNestedDoc)) == { + ('id', ID, fields['id']), + ('d__id', ID, fields_nested['id']), + ('d__d__id', ID, fields_nested_nested['id']), + ('d__d__tens', NdArray[10], fields_nested_nested['tens']), + } + + def test_columns_db_type_with_user_defined_mapping(tmp_path): class MyDoc(BaseDocument): tens: NdArray[10] = Field(dim=1000, col_type=np.ndarray) @@ 
-174,7 +211,7 @@ class MyDoc(BaseDocument): DummyDocIndex[MyDoc](work_dir=str(tmp_path)) -def test_is_schema_compatible(): +def test_docs_validation(): class OtherSimpleDoc(SimpleDoc): ... @@ -184,81 +221,115 @@ class OtherFlatDoc(FlatDoc): class OtherNestedDoc(NestedDoc): ... + # SIMPLE store = DummyDocIndex[SimpleDoc]() - assert store._is_schema_compatible([SimpleDoc(tens=np.random.random((10,)))]) - assert store._is_schema_compatible( - DocumentArray[SimpleDoc]([SimpleDoc(tens=np.random.random((10,)))]) - ) - assert store._is_schema_compatible([OtherSimpleDoc(tens=np.random.random((10,)))]) - assert store._is_schema_compatible( - DocumentArray[OtherSimpleDoc]([OtherSimpleDoc(tens=np.random.random((10,)))]) - ) - assert not store._is_schema_compatible( - [FlatDoc(tens_one=np.random.random((10,)), tens_two=np.random.random((50,)))] - ) - assert not store._is_schema_compatible( - DocumentArray[FlatDoc]( + in_list = [SimpleDoc(tens=np.random.random((10,)))] + assert isinstance(store._validate_docs(in_list), DocumentArray[BaseDocument]) + in_da = DocumentArray[SimpleDoc](in_list) + assert store._validate_docs(in_da) == in_da + in_other_list = [OtherSimpleDoc(tens=np.random.random((10,)))] + assert isinstance(store._validate_docs(in_other_list), DocumentArray[BaseDocument]) + in_other_da = DocumentArray[OtherSimpleDoc](in_other_list) + assert store._validate_docs(in_other_da) == in_other_da + + with pytest.raises(ValueError): + store._validate_docs( [ FlatDoc( tens_one=np.random.random((10,)), tens_two=np.random.random((50,)) ) ] ) - ) + with pytest.raises(ValueError): + store._validate_docs( + DocumentArray[FlatDoc]( + [ + FlatDoc( + tens_one=np.random.random((10,)), + tens_two=np.random.random((50,)), + ) + ] + ) + ) + # FLAT store = DummyDocIndex[FlatDoc]() - assert store._is_schema_compatible( + in_list = [ + FlatDoc(tens_one=np.random.random((10,)), tens_two=np.random.random((50,))) + ] + assert isinstance(store._validate_docs(in_list), 
DocumentArray[BaseDocument]) + in_da = DocumentArray[FlatDoc]( [FlatDoc(tens_one=np.random.random((10,)), tens_two=np.random.random((50,)))] ) - assert store._is_schema_compatible( - DocumentArray[FlatDoc]( - [ - FlatDoc( - tens_one=np.random.random((10,)), tens_two=np.random.random((50,)) - ) - ] - ) - ) - assert store._is_schema_compatible( + assert store._validate_docs(in_da) == in_da + in_other_list = [ + OtherFlatDoc(tens_one=np.random.random((10,)), tens_two=np.random.random((50,))) + ] + assert isinstance(store._validate_docs(in_other_list), DocumentArray[BaseDocument]) + in_other_da = DocumentArray[OtherFlatDoc]( [ OtherFlatDoc( tens_one=np.random.random((10,)), tens_two=np.random.random((50,)) ) ] ) - assert store._is_schema_compatible( - DocumentArray[OtherFlatDoc]( - [ - OtherFlatDoc( - tens_one=np.random.random((10,)), tens_two=np.random.random((50,)) - ) - ] + assert store._validate_docs(in_other_da) == in_other_da + with pytest.raises(ValueError): + store._validate_docs([SimpleDoc(tens=np.random.random((10,)))]) + with pytest.raises(ValueError): + assert not store._validate_docs( + DocumentArray[SimpleDoc]([SimpleDoc(tens=np.random.random((10,)))]) ) - ) - assert not store._is_schema_compatible([SimpleDoc(tens=np.random.random((10,)))]) - assert not store._is_schema_compatible( - DocumentArray[SimpleDoc]([SimpleDoc(tens=np.random.random((10,)))]) - ) + # NESTED store = DummyDocIndex[NestedDoc]() - assert store._is_schema_compatible( + in_list = [NestedDoc(d=SimpleDoc(tens=np.random.random((10,))))] + assert isinstance(store._validate_docs(in_list), DocumentArray[BaseDocument]) + in_da = DocumentArray[NestedDoc]( [NestedDoc(d=SimpleDoc(tens=np.random.random((10,))))] ) - assert store._is_schema_compatible( - DocumentArray[NestedDoc]([NestedDoc(d=SimpleDoc(tens=np.random.random((10,))))]) - ) - assert store._is_schema_compatible( + assert store._validate_docs(in_da) == in_da + in_other_list = 
[OtherNestedDoc(d=OtherSimpleDoc(tens=np.random.random((10,))))] + assert isinstance(store._validate_docs(in_other_list), DocumentArray[BaseDocument]) + in_other_da = DocumentArray[OtherNestedDoc]( [OtherNestedDoc(d=OtherSimpleDoc(tens=np.random.random((10,))))] ) - assert store._is_schema_compatible( - DocumentArray[OtherNestedDoc]( - [OtherNestedDoc(d=OtherSimpleDoc(tens=np.random.random((10,))))] + + assert store._validate_docs(in_other_da) == in_other_da + with pytest.raises(ValueError): + store._validate_docs([SimpleDoc(tens=np.random.random((10,)))]) + with pytest.raises(ValueError): + store._validate_docs( + DocumentArray[SimpleDoc]([SimpleDoc(tens=np.random.random((10,)))]) ) - ) - assert not store._is_schema_compatible([SimpleDoc(tens=np.random.random((10,)))]) - assert not store._is_schema_compatible( - DocumentArray[SimpleDoc]([SimpleDoc(tens=np.random.random((10,)))]) - ) + + +def test_docs_validation_unions(): + class OptionalDoc(BaseDocument): + tens: Optional[NdArray[10]] = Field(dim=1000) + + class UnionDoc(BaseDocument): + tens: Union[NdArray[10], str] = Field(dim=1000) + + # OPTIONAL + store = DummyDocIndex[SimpleDoc]() + in_list = [OptionalDoc(tens=np.random.random((10,)))] + assert isinstance(store._validate_docs(in_list), DocumentArray[BaseDocument]) + in_da = DocumentArray[OptionalDoc](in_list) + assert store._validate_docs(in_da) == in_da + + with pytest.raises(ValueError): + store._validate_docs([OptionalDoc(tens=None)]) + + # OTHER UNION + store = DummyDocIndex[SimpleDoc]() + in_list = [UnionDoc(tens=np.random.random((10,)))] + assert isinstance(store._validate_docs(in_list), DocumentArray[BaseDocument]) + in_da = DocumentArray[UnionDoc](in_list) + assert isinstance(store._validate_docs(in_da), DocumentArray[BaseDocument]) + + with pytest.raises(ValueError): + store._validate_docs([UnionDoc(tens='hello')]) def test_get_value(): diff --git a/tests/units/typing/tensor/test_tensor.py b/tests/units/typing/tensor/test_tensor.py index 
99457e96588..a72ba18769d 100644 --- a/tests/units/typing/tensor/test_tensor.py +++ b/tests/units/typing/tensor/test_tensor.py @@ -5,7 +5,7 @@ from pydantic.tools import parse_obj_as, schema_json_of from docarray.base_document.io.json import orjson_dumps -from docarray.typing import NdArray, TorchTensor +from docarray.typing import AudioNdArray, NdArray, TorchTensor from docarray.typing.tensor import NdArrayEmbedding @@ -206,3 +206,35 @@ def test_parametrized_operations(): assert isinstance(t_result, np.ndarray) assert isinstance(t_result, NdArray) assert isinstance(t_result, NdArray[128]) + + +def test_class_equality(): + assert NdArray == NdArray + assert NdArray[128] == NdArray[128] + assert NdArray[128] != NdArray[256] + assert NdArray[128] != NdArray[2, 64] + assert not NdArray[128] == NdArray[2, 64] + + assert NdArrayEmbedding == NdArrayEmbedding + assert NdArrayEmbedding[128] == NdArrayEmbedding[128] + assert NdArrayEmbedding[128] != NdArrayEmbedding[256] + + assert AudioNdArray == AudioNdArray + assert AudioNdArray[128] == AudioNdArray[128] + assert AudioNdArray[128] != AudioNdArray[256] + + +def test_class_hash(): + assert hash(NdArray) == hash(NdArray) + assert hash(NdArray[128]) == hash(NdArray[128]) + assert hash(NdArray[128]) != hash(NdArray[256]) + assert hash(NdArray[128]) != hash(NdArray[2, 64]) + assert not hash(NdArray[128]) == hash(NdArray[2, 64]) + + assert hash(NdArrayEmbedding) == hash(NdArrayEmbedding) + assert hash(NdArrayEmbedding[128]) == hash(NdArrayEmbedding[128]) + assert hash(NdArrayEmbedding[128]) != hash(NdArrayEmbedding[256]) + + assert hash(AudioNdArray) == hash(AudioNdArray) + assert hash(AudioNdArray[128]) == hash(AudioNdArray[128]) + assert hash(AudioNdArray[128]) != hash(AudioNdArray[256])