-
Notifications
You must be signed in to change notification settings - Fork 238
feat(index): index data with union types #1220
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
a93584a
3fcba13
d11a4b4
b4c1cae
0a3f93e
211bd95
a5ec4a6
29c5b43
e48df95
dfa5b04
f2136f1
1743b86
25fd665
5b42914
3226431
b45dfea
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,14 +11,16 @@ | |
| NamedTuple, | ||
| Optional, | ||
| Sequence, | ||
| Tuple, | ||
| Type, | ||
| TypeVar, | ||
| Union, | ||
| cast, | ||
| ) | ||
|
|
||
| import numpy as np | ||
| from typing_inspect import is_union_type | ||
| from pydantic.error_wrappers import ValidationError | ||
| from typing_inspect import get_args, is_union_type | ||
|
|
||
| from docarray import BaseDocument, DocumentArray | ||
| from docarray.array.abstract_array import AnyDocumentArray | ||
|
|
@@ -89,7 +91,9 @@ def __init__(self, db_config=None, **kwargs): | |
| if not isinstance(self._db_config, self.DBConfig): | ||
| raise ValueError(f'db_config must be of type {self.DBConfig}') | ||
| self._runtime_config = self.RuntimeConfig() | ||
| self._column_infos: Dict[str, _ColumnInfo] = self._create_columns(self._schema) | ||
| self._column_infos: Dict[str, _ColumnInfo] = self._create_column_infos( | ||
| self._schema | ||
| ) | ||
|
|
||
| ############################################### | ||
| # Inner classes for query builder and configs # | ||
|
|
@@ -367,9 +371,12 @@ def configure(self, runtime_config=None, **kwargs): | |
| def index(self, docs: Union[BaseDocument, Sequence[BaseDocument]], **kwargs): | ||
| """index Documents into the index. | ||
|
|
||
| :param docs: Documents to index | ||
| :param docs: Documents to index. NOTE: passing a Sequence of Documents that is | ||
| not a DocumentArray comes at a performance penalty, since compatibility | ||
| with the Index's schema needs to be checked for every Document individually. | ||
| """ | ||
| data_by_columns = self._get_col_value_dict(docs) | ||
| docs_validated = self._validate_docs(docs) | ||
| data_by_columns = self._get_col_value_dict(docs_validated) | ||
| self._index(data_by_columns, **kwargs) | ||
|
|
||
| def find( | ||
|
|
@@ -585,11 +592,6 @@ def _get_col_value_dict( | |
| docs_seq: Sequence[BaseDocument] = [docs] | ||
| else: | ||
| docs_seq = docs | ||
| if not self._is_schema_compatible(docs_seq): | ||
| raise ValueError( | ||
| 'The schema of the documents to be indexed is not compatible' | ||
| ' with the schema of the index.' | ||
| ) | ||
|
|
||
| def _col_gen(col_name: str): | ||
| return (self._get_values_by_column([doc], col_name)[0] for doc in docs_seq) | ||
|
|
@@ -626,31 +628,68 @@ def build_query(self) -> QueryBuilder: | |
| """ | ||
| return self.QueryBuilder() # type: ignore | ||
|
|
||
| def _create_columns(self, schema: Type[BaseDocument]) -> Dict[str, _ColumnInfo]: | ||
| columns: Dict[str, _ColumnInfo] = dict() | ||
| @classmethod | ||
| def _flatten_schema( | ||
| cls, schema: Type[BaseDocument], name_prefix: str = '' | ||
| ) -> List[Tuple[str, Type, 'ModelField']]: | ||
| """Flatten the schema of a Document into a list of column names and types. | ||
JohannesMessner marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| Nested Documents are handled in a recursive manner by adding `'__'` as a prefix to the column name. | ||
|
|
||
| :param schema: The schema to flatten | ||
| :param name_prefix: prefix to append to the column names. Used for recursive calls to handle nesting. | ||
| :return: A list of column names, types, and fields | ||
| """ | ||
| names_types_fields: List[Tuple[str, Type, 'ModelField']] = [] | ||
| for field_name, field_ in schema.__fields__.items(): | ||
| t_ = schema._get_field_type(field_name) | ||
| inner_prefix = name_prefix + field_name + '__' | ||
|
|
||
| if is_union_type(t_): | ||
| union_args = get_args(t_) | ||
| if len(union_args) == 2 and type(None) in union_args: | ||
| # simple "Optional" type, treat as special case: | ||
| # treat as if it was a single non-optional type | ||
| for t_arg in union_args: | ||
| if t_arg is type(None): | ||
| pass | ||
| elif issubclass(t_arg, BaseDocument): | ||
JohannesMessner marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| names_types_fields.extend( | ||
| cls._flatten_schema(t_arg, name_prefix=inner_prefix) | ||
| ) | ||
| else: | ||
| names_types_fields.append((field_name, t_, field_)) | ||
| elif issubclass(t_, BaseDocument): | ||
| names_types_fields.extend( | ||
| cls._flatten_schema(t_, name_prefix=inner_prefix) | ||
| ) | ||
| else: | ||
| names_types_fields.append((name_prefix + field_name, t_, field_)) | ||
| return names_types_fields | ||
|
|
||
| def _create_column_infos( | ||
| self, schema: Type[BaseDocument] | ||
| ) -> Dict[str, _ColumnInfo]: | ||
| """Collects information about every column that is implied by a given schema. | ||
|
|
||
| :param schema: The schema (subclass of BaseDocument) to analyze and parse | ||
| columns from | ||
| :returns: A dictionary mapping from column names to column information. | ||
| """ | ||
| column_infos: Dict[str, _ColumnInfo] = dict() | ||
| for field_name, type_, field_ in self._flatten_schema(schema): | ||
| if is_union_type(type_): | ||
| raise ValueError( | ||
| 'Union types are not supported in the schema of a DocumentIndex.' | ||
| f' Instead of using type {t_} use a single specific type.' | ||
| f' Instead of using type {type_} use a single specific type.' | ||
| ) | ||
| elif issubclass(t_, AnyDocumentArray): | ||
| elif issubclass(type_, AnyDocumentArray): | ||
| raise ValueError( | ||
| 'Indexing field of DocumentArray type (=subindex)' | ||
| 'is not yet supported.' | ||
| ) | ||
| elif issubclass(t_, BaseDocument): | ||
| columns = dict( | ||
| columns, | ||
| **{ | ||
| f'{field_name}__{nested_name}': t | ||
| for nested_name, t in self._create_columns(t_).items() | ||
| }, | ||
| ) | ||
| else: | ||
| columns[field_name] = self._create_single_column(field_, t_) | ||
| return columns | ||
| column_infos[field_name] = self._create_single_column(field_, type_) | ||
| return column_infos | ||
|
|
||
| def _create_single_column(self, field: 'ModelField', type_: Type) -> _ColumnInfo: | ||
| custom_config = field.field_info.extra | ||
|
|
@@ -682,30 +721,57 @@ def _create_single_column(self, field: 'ModelField', type_: Type) -> _ColumnInfo | |
| docarray_type=type_, db_type=db_type, config=config, n_dim=n_dim | ||
| ) | ||
|
|
||
| def _is_schema_compatible(self, docs: Sequence[BaseDocument]) -> bool: | ||
| """Flatten a DocumentArray into a DocumentArray of the schema type.""" | ||
| reference_col_db_types = [ | ||
| (name, col.db_type) for name, col in self._column_infos.items() | ||
| ] | ||
| if isinstance(docs, AnyDocumentArray): | ||
| input_columns = self._create_columns(docs.document_type) | ||
| input_col_db_types = [ | ||
| (name, col.db_type) for name, col in input_columns.items() | ||
| ] | ||
| def _validate_docs( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. would be nice to have docstrings for this and |
||
| self, docs: Union[BaseDocument, Sequence[BaseDocument]] | ||
| ) -> DocumentArray[BaseDocument]: | ||
| """Validates Documents against the schema of the Document Index. | ||
| For validation to pass, the schema of `docs` and the schema of the Document | ||
| Index need to evaluate to the same flattened columns. | ||
| If validation fails, a ValueError is raised. | ||
|
|
||
| :param docs: Document to evaluate. If this is a DocumentArray, validation is | ||
| performed using its `doc_type` (parametrization), without having to check | ||
| every Document in `docs`, evaluation is performed for every Document in `docs`. | ||
| DocumentArray, evaluation is performed for every Document in `docs`. | ||
| :return: A DocumentArray containing the Documents in `docs` | ||
| """ | ||
| if isinstance(docs, BaseDocument): | ||
| docs = [docs] | ||
| if isinstance(docs, DocumentArray): | ||
| # validation shortcut for DocumentArray; only look at the schema | ||
| reference_schema_flat = self._flatten_schema( | ||
| cast(Type[BaseDocument], self._schema) | ||
| ) | ||
| reference_names = [name for (name, _, _) in reference_schema_flat] | ||
| reference_types = [t_ for (_, t_, _) in reference_schema_flat] | ||
|
|
||
| input_schema_flat = self._flatten_schema(docs.document_type) | ||
| input_names = [name for (name, _, _) in input_schema_flat] | ||
| input_types = [t_ for (_, t_, _) in input_schema_flat] | ||
| # this could be relaxed in the future, | ||
| # see schema translation ideas in the design doc | ||
| return reference_col_db_types == input_col_db_types | ||
| else: | ||
| for d in docs: | ||
| input_columns = self._create_columns(type(d)) | ||
| input_col_db_types = [ | ||
| (name, col.db_type) for name, col in input_columns.items() | ||
| ] | ||
| # this could be relaxed in the future, | ||
| # see schema translation ideas in the design doc | ||
| if reference_col_db_types != input_col_db_types: | ||
| return False | ||
| return True | ||
| names_compatible = reference_names == input_names | ||
| types_compatible = all( | ||
| (not is_union_type(t2) and issubclass(t1, t2)) | ||
| for (t1, t2) in zip(reference_types, input_types) | ||
| ) | ||
| if names_compatible and types_compatible: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. so if this is False, docs get checked again and only after that the error is thrown, right? should we somehow separate these cases so that we don't have to check docs twice? especially since the second check is a costly operation as we do it for each doc
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If this is false we check every doc individually, yes. But in order to calculate this truth value in line 729 we don't actually look at the documents, we only look at the schema of the DocumentArray, and all of the fields there.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Do we need to iterate over docs if we already know that it's not successful? this was my point But I guess if the schema is not compatible, error will be thrown right away for the first doc at line 767 so it's not really an iteration over all docs
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, we still need to check because the actual data might be compatible even if the schema says otherwise. For example, the input could be of type So basically the validation check based on the DocumentArray parametrization can be seen as just a shortcut that may or may not work.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. makes sense, thanks for clarification! |
||
| return docs | ||
| out_docs = [] | ||
| for i in range(len(docs)): | ||
| # validate the data | ||
| try: | ||
| out_docs.append( | ||
| cast(Type[BaseDocument], self._schema).parse_obj(docs[i]) | ||
| ) | ||
| except (ValueError, ValidationError): | ||
| raise ValueError( | ||
| 'The schema of the input Documents is not compatible with the schema of the Document Index.' | ||
JohannesMessner marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| ' Ensure that the field names of your data match the field names of the Document Index schema,' | ||
| ' and that the types of your data match the types of the Document Index schema.' | ||
| ) | ||
|
|
||
| return DocumentArray[BaseDocument].construct(out_docs) | ||
|
|
||
| def _to_numpy(self, val: Any) -> Any: | ||
| if isinstance(val, np.ndarray): | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.