diff --git a/docarray/array/abstract_array.py b/docarray/array/abstract_array.py
index fd81a9caa6a..51fcc9102f7 100644
--- a/docarray/array/abstract_array.py
+++ b/docarray/array/abstract_array.py
@@ -1,5 +1,15 @@
 from abc import abstractmethod
-from typing import TYPE_CHECKING, Any, Generic, List, Sequence, Type, TypeVar, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Generic,
+    List,
+    Sequence,
+    Type,
+    TypeVar,
+    Union,
+)
 
 from docarray.base_document import BaseDocument
 from docarray.display.document_array_summary import DocumentArraySummary
@@ -17,6 +27,7 @@ class AnyDocumentArray(Sequence[BaseDocument], Generic[T_doc], AbstractType):
 
     document_type: Type[BaseDocument]
     tensor_type: Type['AbstractTensor'] = NdArray
+    __typed_da__: Dict[Type[BaseDocument], Type] = {}
 
     def __repr__(self):
         return f'<{self.__class__.__name__} (length={len(self)})>'
@@ -27,28 +38,40 @@ def __class_getitem__(cls, item: Type[BaseDocument]):
                 f'{cls.__name__}[item] item should be a Document not a {item} '
             )
 
-        class _DocumentArrayTyped(cls):  # type: ignore
-            document_type: Type[BaseDocument] = item
+        # Read the cache from this class's own namespace instead of going
+        # through inherited attribute lookup: a subclass that does not declare
+        # `__typed_da__` would otherwise share its parent's dict and be handed
+        # (or overwrite) the parent's parametrized classes.
+        typed_cache = cls.__dict__.get('__typed_da__')
+        if typed_cache is None:
+            typed_cache = {}
+            cls.__typed_da__ = typed_cache
+
+        if item not in typed_cache:
 
-        for field in _DocumentArrayTyped.document_type.__fields__.keys():
+            class _DocumentArrayTyped(cls):  # type: ignore
+                document_type: Type[BaseDocument] = item
 
-            def _property_generator(val: str):
-                def _getter(self):
-                    return self._get_array_attribute(val)
+            for field in _DocumentArrayTyped.document_type.__fields__.keys():
 
-                def _setter(self, value):
-                    self._set_array_attribute(val, value)
+                def _property_generator(val: str):
+                    def _getter(self):
+                        return self._get_array_attribute(val)
 
-                # need docstring for the property
-                return property(fget=_getter, fset=_setter)
+                    def _setter(self, value):
+                        self._set_array_attribute(val, value)
 
-            setattr(_DocumentArrayTyped, field, _property_generator(field))
-            # this generates property on the fly based on the schema of the item
+                    # need docstring for the property
+                    return property(fget=_getter, fset=_setter)
 
-        _DocumentArrayTyped.__name__ = f'{cls.__name__}[{item.__name__}]'
-        _DocumentArrayTyped.__qualname__ = f'{cls.__name__}[{item.__name__}]'
+                setattr(_DocumentArrayTyped, field, _property_generator(field))
+                # this generates property on the fly based on the schema of the item
 
-        return _DocumentArrayTyped
+            _DocumentArrayTyped.__name__ = f'{cls.__name__}[{item.__name__}]'
+            _DocumentArrayTyped.__qualname__ = f'{cls.__name__}[{item.__name__}]'
+            typed_cache[item] = _DocumentArrayTyped
+
+        return typed_cache[item]
 
     @abstractmethod
     def _get_array_attribute(
diff --git a/docarray/array/array.py b/docarray/array/array.py
index faa5915dbf0..253e053f26f 100644
--- a/docarray/array/array.py
+++ b/docarray/array/array.py
@@ -1,6 +1,16 @@
 from contextlib import contextmanager
 from functools import wraps
-from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Type, TypeVar, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    Type,
+    TypeVar,
+    Union,
+)
 
 from typing_inspect import is_union_type
 
@@ -73,6 +83,7 @@ class Image(BaseDocument):
     """
 
     document_type: Type[BaseDocument] = AnyDocument
+    __typed_da__: Dict[Type[BaseDocument], Type] = {}
 
     def __init__(
         self,
diff --git a/docarray/array/array_stacked.py b/docarray/array/array_stacked.py
index 5a443a427e5..abd100ad616 100644
--- a/docarray/array/array_stacked.py
+++ b/docarray/array/array_stacked.py
@@ -56,6 +56,7 @@ class DocumentArrayStacked(AnyDocumentArray):
 
     document_type: Type[BaseDocument] = AnyDocument
     _docs: DocumentArray
+    __typed_da__: Dict[Type[BaseDocument], Type] = {}
 
     def __init__(
         self: T,
diff --git a/tests/units/typing/da/__init__.py b/tests/units/typing/da/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/tests/units/typing/da/test_relations.py b/tests/units/typing/da/test_relations.py
new file mode 100644
index 00000000000..424d22b633b
--- /dev/null
+++ b/tests/units/typing/da/test_relations.py
@@ -0,0 +1,50 @@
+from docarray import BaseDocument, DocumentArray
+
+
+def test_instance_and_equivalence():
+    class MyDoc(BaseDocument):
+        text: str
+
+    docs = DocumentArray[MyDoc]([MyDoc(text='hello')])
+
+    assert issubclass(DocumentArray[MyDoc], DocumentArray[MyDoc])
+    assert issubclass(docs.__class__, DocumentArray[MyDoc])
+
+    assert isinstance(docs, DocumentArray[MyDoc])
+
+
+def test_subclassing():
+    class MyDoc(BaseDocument):
+        text: str
+
+    class MyDocArray(DocumentArray[MyDoc]):
+        pass
+
+    docs = MyDocArray([MyDoc(text='hello')])
+
+    assert issubclass(MyDocArray, DocumentArray[MyDoc])
+    assert issubclass(docs.__class__, DocumentArray[MyDoc])
+
+    assert isinstance(docs, MyDocArray)
+    assert isinstance(docs, DocumentArray[MyDoc])
+
+    assert issubclass(MyDoc, BaseDocument)
+    assert not issubclass(DocumentArray[MyDoc], DocumentArray[BaseDocument])
+    assert not issubclass(MyDocArray, DocumentArray[BaseDocument])
+
+
+def test_parametrized_cache_is_per_class():
+    class MyDoc(BaseDocument):
+        text: str
+
+    class MyDocArray(DocumentArray):
+        pass
+
+    # repeated parametrization returns the same class object
+    assert DocumentArray[MyDoc] is DocumentArray[MyDoc]
+    assert MyDocArray[MyDoc] is MyDocArray[MyDoc]
+
+    # a subclass must not be handed the class cached for its parent
+    assert MyDocArray[MyDoc] is not DocumentArray[MyDoc]
+    assert issubclass(MyDocArray[MyDoc], MyDocArray)
+    assert not issubclass(DocumentArray[MyDoc], MyDocArray)