Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 111 additions & 45 deletions docarray/doc_index/abstract_doc_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,16 @@
NamedTuple,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
cast,
)

import numpy as np
from typing_inspect import is_union_type
from pydantic.error_wrappers import ValidationError
from typing_inspect import get_args, is_union_type

from docarray import BaseDocument, DocumentArray
from docarray.array.abstract_array import AnyDocumentArray
Expand Down Expand Up @@ -89,7 +91,9 @@ def __init__(self, db_config=None, **kwargs):
if not isinstance(self._db_config, self.DBConfig):
raise ValueError(f'db_config must be of type {self.DBConfig}')
self._runtime_config = self.RuntimeConfig()
self._column_infos: Dict[str, _ColumnInfo] = self._create_columns(self._schema)
self._column_infos: Dict[str, _ColumnInfo] = self._create_column_infos(
self._schema
)

###############################################
# Inner classes for query builder and configs #
Expand Down Expand Up @@ -367,9 +371,12 @@ def configure(self, runtime_config=None, **kwargs):
def index(self, docs: Union[BaseDocument, Sequence[BaseDocument]], **kwargs):
"""index Documents into the index.

:param docs: Documents to index
:param docs: Documents to index. NOTE: passing a Sequence of Documents that is
not a DocumentArray comes at a performance penalty, since compatibility
with the Index's schema needs to be checked for every Document individually.
"""
data_by_columns = self._get_col_value_dict(docs)
docs_validated = self._validate_docs(docs)
data_by_columns = self._get_col_value_dict(docs_validated)
self._index(data_by_columns, **kwargs)

def find(
Expand Down Expand Up @@ -585,11 +592,6 @@ def _get_col_value_dict(
docs_seq: Sequence[BaseDocument] = [docs]
else:
docs_seq = docs
if not self._is_schema_compatible(docs_seq):
raise ValueError(
'The schema of the documents to be indexed is not compatible'
' with the schema of the index.'
)

def _col_gen(col_name: str):
return (self._get_values_by_column([doc], col_name)[0] for doc in docs_seq)
Expand Down Expand Up @@ -626,31 +628,68 @@ def build_query(self) -> QueryBuilder:
"""
return self.QueryBuilder() # type: ignore

def _create_columns(self, schema: Type[BaseDocument]) -> Dict[str, _ColumnInfo]:
columns: Dict[str, _ColumnInfo] = dict()
@classmethod
def _flatten_schema(
    cls, schema: Type[BaseDocument], name_prefix: str = ''
) -> List[Tuple[str, Type, 'ModelField']]:
    """Flatten the schema of a Document into a list of column names, types and fields.

    Nested Documents are handled recursively; their column names are prefixed
    with the parent field name followed by `'__'` (e.g. `'image__url'`).

    :param schema: The schema (subclass of BaseDocument) to flatten
    :param name_prefix: prefix to prepend to the column names.
        Used for recursive calls to handle nesting.
    :return: A list of (column name, type, field) tuples
    """
    names_types_fields: List[Tuple[str, Type, 'ModelField']] = []
    for field_name, field_ in schema.__fields__.items():
        t_ = schema._get_field_type(field_name)
        inner_prefix = name_prefix + field_name + '__'

        if is_union_type(t_):
            union_args = get_args(t_)
            if len(union_args) == 2 and type(None) in union_args:
                # simple "Optional" type, treat as special case:
                # treat as if it was a single non-optional type
                for t_arg in union_args:
                    if t_arg is type(None):
                        pass
                    elif issubclass(t_arg, BaseDocument):
                        # nested (optional) Document: recurse into its fields
                        names_types_fields.extend(
                            cls._flatten_schema(t_arg, name_prefix=inner_prefix)
                        )
                    else:
                        # BUGFIX: prepend `name_prefix` so that nested optional
                        # fields keep their fully qualified column name,
                        # consistent with the non-union branch below
                        names_types_fields.append(
                            (name_prefix + field_name, t_, field_)
                        )
            # NOTE(review): unions that are not a simple Optional fall through
            # and are silently dropped here, so the union check in
            # `_create_column_infos` only ever sees Optional types — confirm
            # that this is intended.
        elif issubclass(t_, BaseDocument):
            # nested Document: recurse into its fields
            names_types_fields.extend(
                cls._flatten_schema(t_, name_prefix=inner_prefix)
            )
        else:
            names_types_fields.append((name_prefix + field_name, t_, field_))
    return names_types_fields

def _create_column_infos(
    self, schema: Type[BaseDocument]
) -> Dict[str, _ColumnInfo]:
    """Collects information about every column that is implied by a given schema.

    :param schema: The schema (subclass of BaseDocument) to analyze and parse
        columns from
    :returns: A dictionary mapping from column names to column information.
    :raises ValueError: if the schema contains a Union-typed field or a
        DocumentArray (subindex) field, neither of which is supported.
    """
    column_infos: Dict[str, _ColumnInfo] = dict()
    for field_name, type_, field_ in self._flatten_schema(schema):
        if is_union_type(type_):
            raise ValueError(
                'Union types are not supported in the schema of a DocumentIndex.'
                f' Instead of using type {type_} use a single specific type.'
            )
        elif issubclass(type_, AnyDocumentArray):
            # BUGFIX: added missing space between the concatenated string
            # literals; previously rendered as '(=subindex)is not yet supported.'
            raise ValueError(
                'Indexing field of DocumentArray type (=subindex)'
                ' is not yet supported.'
            )
        else:
            column_infos[field_name] = self._create_single_column(field_, type_)
    return column_infos

def _create_single_column(self, field: 'ModelField', type_: Type) -> _ColumnInfo:
custom_config = field.field_info.extra
Expand Down Expand Up @@ -682,30 +721,57 @@ def _create_single_column(self, field: 'ModelField', type_: Type) -> _ColumnInfo
docarray_type=type_, db_type=db_type, config=config, n_dim=n_dim
)

def _is_schema_compatible(self, docs: Sequence[BaseDocument]) -> bool:
"""Flatten a DocumentArray into a DocumentArray of the schema type."""
reference_col_db_types = [
(name, col.db_type) for name, col in self._column_infos.items()
]
if isinstance(docs, AnyDocumentArray):
input_columns = self._create_columns(docs.document_type)
input_col_db_types = [
(name, col.db_type) for name, col in input_columns.items()
]
def _validate_docs(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would be nice to have docstrings for this and _create_column_infos

self, docs: Union[BaseDocument, Sequence[BaseDocument]]
) -> DocumentArray[BaseDocument]:
"""Validates Document against the schema of the Document Index.
For validation to pass, the schema of `docs` and the schema of the Document
Index need to evaluate to the same flattened columns.
If Validation fails, a ValueError is raised.

:param docs: Document to evaluate. If this is a DocumentArray, validation is
performed using its `document_type` (parametrization), without having to check
every Document in `docs`. If this check fails, or if `docs` is not a
DocumentArray, evaluation is performed for every Document in `docs`.
:return: A DocumentArray containing the Documents in `docs`
"""
if isinstance(docs, BaseDocument):
docs = [docs]
if isinstance(docs, DocumentArray):
# validation shortcut for DocumentArray; only look at the schema
reference_schema_flat = self._flatten_schema(
cast(Type[BaseDocument], self._schema)
)
reference_names = [name for (name, _, _) in reference_schema_flat]
reference_types = [t_ for (_, t_, _) in reference_schema_flat]

input_schema_flat = self._flatten_schema(docs.document_type)
input_names = [name for (name, _, _) in input_schema_flat]
input_types = [t_ for (_, t_, _) in input_schema_flat]
# this could be relaxed in the future,
# see schema translation ideas in the design doc
return reference_col_db_types == input_col_db_types
else:
for d in docs:
input_columns = self._create_columns(type(d))
input_col_db_types = [
(name, col.db_type) for name, col in input_columns.items()
]
# this could be relaxed in the future,
# see schema translation ideas in the design doc
if reference_col_db_types != input_col_db_types:
return False
return True
names_compatible = reference_names == input_names
types_compatible = all(
(not is_union_type(t2) and issubclass(t1, t2))
for (t1, t2) in zip(reference_types, input_types)
)
if names_compatible and types_compatible:
Copy link
Copy Markdown
Contributor

@jupyterjazz jupyterjazz Mar 15, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so if this is False, docs get checked again and only after that the error is thrown, right? should we somehow separate these cases so that we don't have to check docs twice? especially since the second check is a costly operation as we do it for each doc

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is false we check every doc individually, yes. But in order to calculate this truth value in line 729 we don't actually look at the documents, we only look at the schema of the DocumentArray, and all of the fields there.
Only if this is not successful we iterate over the docs. So I don't see the potential for optimization. Or am I misunderstanding your point?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only if this is not successful we iterate over the docs

Do we need to iterate over docs if we already know that it's not successful? this was my point

But I guess if the schema is not compatible, error will be thrown right away for the first doc at line 767 so it's not really an iteration over all docs

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, we still need to check because the actual data might be compatible even if the schema says otherwise.

For example, the input could be of type DocumentArray[BaseDocument], but contain only ImageDoc. This is valid since ImageDoc is a subclass of BaseDocument.
If the Document Index has schema ImageDoc, then the fast validation will fail, because BaseDoc != ImageDoc, but actually it should work, since all Documents are actually ImageDoc. So we have to do the exhaustive check over the data.

So basically the validation check based on the DocumentArray parametrization can be seen as just a shortcut that may or may not work.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

makes sense, thanks for clarification!

return docs
out_docs = []
for i in range(len(docs)):
# validate the data
try:
out_docs.append(
cast(Type[BaseDocument], self._schema).parse_obj(docs[i])
)
except (ValueError, ValidationError):
raise ValueError(
'The schema of the input Documents is not compatible with the schema of the Document Index.'
' Ensure that the field names of your data match the field names of the Document Index schema,'
' and that the types of your data match the types of the Document Index schema.'
)

return DocumentArray[BaseDocument].construct(out_docs)

def _to_numpy(self, val: Any) -> Any:
if isinstance(val, np.ndarray):
Expand Down
8 changes: 4 additions & 4 deletions docarray/doc_index/backends/hnswlib_doc_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,9 @@ def index(self, docs: Union[BaseDocument, Sequence[BaseDocument]], **kwargs):
"""index a document into the store"""
if kwargs:
raise ValueError(f'{list(kwargs.keys())} are not valid keyword arguments')
doc_seq = docs if isinstance(docs, Sequence) else [docs]
data_by_columns = self._get_col_value_dict(doc_seq)
hashed_ids = tuple(self._to_hashed_id(doc.id) for doc in doc_seq)
docs_validated = self._validate_docs(docs)
data_by_columns = self._get_col_value_dict(docs_validated)
hashed_ids = tuple(self._to_hashed_id(doc.id) for doc in docs_validated)

# indexing into HNSWLib and SQLite sequentially
# could be improved by processing in parallel
Expand All @@ -174,7 +174,7 @@ def index(self, docs: Union[BaseDocument, Sequence[BaseDocument]], **kwargs):
index.add_items(data_stacked, ids=hashed_ids)
index.save_index(self._hnsw_locations[col_name])

self._send_docs_to_sqlite(doc_seq)
self._send_docs_to_sqlite(docs_validated)
self._sqlite_conn.commit()

def execute_query(self, query: List[Tuple[str, Dict]], *args, **kwargs) -> Any:
Expand Down
34 changes: 28 additions & 6 deletions docarray/typing/tensor/abstract_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,24 +32,27 @@

class _ParametrizedMeta(type):
"""
This metaclass ensures that instance and subclass checks on parametrized Tensors
This metaclass ensures that instance, subclass and equality checks on parametrized Tensors
are handled as expected:

assert issubclass(TorchTensor[128], TorchTensor[128])
t = parse_obj_as(TorchTensor[128], torch.zeros(128))
assert isinstance(t, TorchTensor[128])
TorchTensor[128] == TorchTensor[128]
hash(TorchTensor[128]) == hash(TorchTensor[128])
etc.

This special handling is needed because every call to `AbstractTensor.__getitem__`
creates a new class on the fly.
We want technically distinct but identical classes to be considered equal.
"""

def __subclasscheck__(cls, subclass):
is_tensor = AbstractTensor in subclass.mro()
same_parents = is_tensor and cls.mro()[1:] == subclass.mro()[1:]
def _equals_special_case(cls, other):
is_type = isinstance(other, type)
is_tensor = is_type and AbstractTensor in other.mro()
same_parents = is_tensor and cls.mro()[1:] == other.mro()[1:]

subclass_target_shape = getattr(subclass, '__docarray_target_shape__', False)
subclass_target_shape = getattr(other, '__docarray_target_shape__', False)
self_target_shape = getattr(cls, '__docarray_target_shape__', False)
same_shape = (
same_parents
Expand All @@ -58,7 +61,10 @@ def __subclasscheck__(cls, subclass):
and subclass_target_shape == self_target_shape
)

if same_shape:
return same_shape

def __subclasscheck__(cls, subclass):
    """Subclass check that treats identical parametrizations as subclasses.

    Falls back to the default metaclass subclass check when the
    parametrization fast path does not apply.
    """
    if not cls._equals_special_case(subclass):
        return super().__subclasscheck__(subclass)
    return True

Expand All @@ -81,6 +87,22 @@ def __instancecheck__(cls, instance):
return any(issubclass(candidate, cls) for candidate in type(instance).mro())
return super().__instancecheck__(instance)

def __eq__(cls, other):
    """Consider two technically distinct but identically parametrized
    Tensor classes equal; otherwise defer to the other operand by
    returning ``NotImplemented``.
    """
    return True if cls._equals_special_case(other) else NotImplemented

def __hash__(cls):
    """Hash a parametrized Tensor class by its target shape and its
    unparametrized base class, so that identical parametrizations
    hash equally (consistent with ``__eq__``).
    """
    cls_ = cast(AbstractTensor, cls)
    try:
        target_shape = cls_.__docarray_target_shape__
        unparametrized = cls_.__unparametrizedcls__
    except AttributeError:
        # class was not produced by AbstractTensor parametrization
        raise NotImplementedError(
            '`hash()` is not implemented for this class. The `_ParametrizedMeta` '
            'metaclass should only be used for `AbstractTensor` subclasses. '
            'Otherwise, you have to implement `__hash__` for your class yourself.'
        )
    return hash((target_shape, unparametrized))


class AbstractTensor(Generic[TTensor, T], AbstractType, ABC, Sized):
__parametrized_meta__: type = _ParametrizedMeta
Expand Down
Loading