From 9ec1b46545a0c949d445940a8b05283782dbaa18 Mon Sep 17 00:00:00 2001 From: Jackmin801 <56836461+Jackmin801@users.noreply.github.com> Date: Tue, 24 Jan 2023 13:47:48 +0800 Subject: [PATCH 1/4] test: test isinstance and issubclass for typed das Signed-off-by: Jackmin801 <56836461+Jackmin801@users.noreply.github.com> --- tests/units/typing/da/__init__.py | 0 tests/units/typing/da/test_relations.py | 33 +++++++++++++++++++++++++ 2 files changed, 33 insertions(+) create mode 100644 tests/units/typing/da/__init__.py create mode 100644 tests/units/typing/da/test_relations.py diff --git a/tests/units/typing/da/__init__.py b/tests/units/typing/da/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/units/typing/da/test_relations.py b/tests/units/typing/da/test_relations.py new file mode 100644 index 00000000000..424d22b633b --- /dev/null +++ b/tests/units/typing/da/test_relations.py @@ -0,0 +1,33 @@ +from docarray import BaseDocument, DocumentArray + + +def test_instance_and_equivalence(): + class MyDoc(BaseDocument): + text: str + + docs = DocumentArray[MyDoc]([MyDoc(text='hello')]) + + assert issubclass(DocumentArray[MyDoc], DocumentArray[MyDoc]) + assert issubclass(docs.__class__, DocumentArray[MyDoc]) + + assert isinstance(docs, DocumentArray[MyDoc]) + + +def test_subclassing(): + class MyDoc(BaseDocument): + text: str + + class MyDocArray(DocumentArray[MyDoc]): + pass + + docs = MyDocArray([MyDoc(text='hello')]) + + assert issubclass(MyDocArray, DocumentArray[MyDoc]) + assert issubclass(docs.__class__, DocumentArray[MyDoc]) + + assert isinstance(docs, MyDocArray) + assert isinstance(docs, DocumentArray[MyDoc]) + + assert issubclass(MyDoc, BaseDocument) + assert not issubclass(DocumentArray[MyDoc], DocumentArray[BaseDocument]) + assert not issubclass(MyDocArray, DocumentArray[BaseDocument]) From dfc05c2aa9719bfb01d3c07d491683fbebc70fae Mon Sep 17 00:00:00 2001 From: Jackmin801 <56836461+Jackmin801@users.noreply.github.com> Date: Tue, 24 Jan 2023 13:49:08 +0800 Subject: [PATCH 2/4] fix: have a class level cache of typed das so they arent different Signed-off-by: Jackmin801 <56836461+Jackmin801@users.noreply.github.com> --- docarray/array/abstract_array.py | 46 +++++++++++++++++++++----------- 1 file changed, 30 insertions(+), 16 deletions(-) diff --git a/docarray/array/abstract_array.py b/docarray/array/abstract_array.py index 4482a553989..a039a0cc458 100644 --- a/docarray/array/abstract_array.py +++ b/docarray/array/abstract_array.py @@ -1,5 +1,15 @@ from abc import abstractmethod -from typing import TYPE_CHECKING, Any, Generic, List, Sequence, Type, TypeVar, Union +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Generic, + List, + Sequence, + Type, + TypeVar, + Union, +) from docarray.base_document import BaseDocument from docarray.typing import NdArray @@ -16,6 +26,7 @@ class AnyDocumentArray(Sequence[BaseDocument], Generic[T_doc], AbstractType): document_type: Type[BaseDocument] tensor_type: Type['AbstractTensor'] = NdArray + __typed_da__: Dict[BaseDocument, Type] = {} def __class_getitem__(cls, item: Type[BaseDocument]): if not issubclass(item, BaseDocument): @@ -23,28 +34,31 @@ def __class_getitem__(cls, item: Type[BaseDocument]): f'{cls.__name__}[item] item should be a Document not a {item} ' ) - class _DocumentArrayTyped(cls): # type: ignore - document_type: Type[BaseDocument] = item + if item not in cls.__typed_da__: - for field in _DocumentArrayTyped.document_type.__fields__.keys(): + class _DocumentArrayTyped(cls): # type: ignore + document_type: Type[BaseDocument] = item - def _property_generator(val: str): - def _getter(self): - return self._get_array_attribute(val) + for field in _DocumentArrayTyped.document_type.__fields__.keys(): - def _setter(self, value): - self._set_array_attribute(val, value) + def _property_generator(val: str): + def _getter(self): + return self._get_array_attribute(val) - # need docstring for the property - return property(fget=_getter, fset=_setter) + def _setter(self, value): + self._set_array_attribute(val, value) - setattr(_DocumentArrayTyped, field, _property_generator(field)) - # this generates property on the fly based on the schema of the item + # need docstring for the property + return property(fget=_getter, fset=_setter) - _DocumentArrayTyped.__name__ = f'{cls.__name__}[{item.__name__}]' - _DocumentArrayTyped.__qualname__ = f'{cls.__name__}[{item.__name__}]' + setattr(_DocumentArrayTyped, field, _property_generator(field)) + # this generates property on the fly based on the schema of the item - return _DocumentArrayTyped + _DocumentArrayTyped.__name__ = f'{cls.__name__}[{item.__name__}]' + _DocumentArrayTyped.__qualname__ = f'{cls.__name__}[{item.__name__}]' + cls.__typed_da__[item] = _DocumentArrayTyped + + return cls.__typed_da__[item] @abstractmethod def _get_array_attribute( From dfe5c8a7824e3ca095a3b23a78d47c741879a0df Mon Sep 17 00:00:00 2001 From: Jackmin801 <56836461+Jackmin801@users.noreply.github.com> Date: Tue, 24 Jan 2023 14:54:39 +0800 Subject: [PATCH 3/4] fix: small type fix Signed-off-by: Jackmin801 <56836461+Jackmin801@users.noreply.github.com> --- docarray/array/abstract_array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docarray/array/abstract_array.py b/docarray/array/abstract_array.py index a039a0cc458..8c225226408 100644 --- a/docarray/array/abstract_array.py +++ b/docarray/array/abstract_array.py @@ -26,7 +26,7 @@ class AnyDocumentArray(Sequence[BaseDocument], Generic[T_doc], AbstractType): document_type: Type[BaseDocument] tensor_type: Type['AbstractTensor'] = NdArray - __typed_da__: Dict[BaseDocument, Type] = {} + __typed_da__: Dict[Type[BaseDocument], Type] = {} def __class_getitem__(cls, item: Type[BaseDocument]): if not issubclass(item, BaseDocument): From 01e3232adb84bd18d22b8e54be017c8de6ced0ec Mon Sep 17 00:00:00 2001 From: Jackmin801 <56836461+Jackmin801@users.noreply.github.com> Date: Tue, 24 Jan 2023 15:08:57 +0800 Subject: [PATCH 4/4] fix: subclasses of AnyDocumentArray need their own cache Signed-off-by: Jackmin801 <56836461+Jackmin801@users.noreply.github.com> --- docarray/array/array.py | 13 ++++++++++++- docarray/array/array_stacked.py | 1 + 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/docarray/array/array.py b/docarray/array/array.py index faa5915dbf0..253e053f26f 100644 --- a/docarray/array/array.py +++ b/docarray/array/array.py @@ -1,6 +1,16 @@ from contextlib import contextmanager from functools import wraps -from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Type, TypeVar, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterable, + List, + Type, + TypeVar, + Union, +) from typing_inspect import is_union_type @@ -73,6 +83,7 @@ class Image(BaseDocument): """ document_type: Type[BaseDocument] = AnyDocument + __typed_da__: Dict[Type[BaseDocument], Type] = {} def __init__( self, diff --git a/docarray/array/array_stacked.py b/docarray/array/array_stacked.py index 5a443a427e5..abd100ad616 100644 --- a/docarray/array/array_stacked.py +++ b/docarray/array/array_stacked.py @@ -56,6 +56,7 @@ class DocumentArrayStacked(AnyDocumentArray): document_type: Type[BaseDocument] = AnyDocument _docs: DocumentArray + __typed_da__: Dict[Type[BaseDocument], Type] = {} def __init__( self: T,