Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 30 additions & 16 deletions docarray/array/abstract_array.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
from abc import abstractmethod
from typing import TYPE_CHECKING, Any, Generic, List, Sequence, Type, TypeVar, Union
from typing import (
TYPE_CHECKING,
Any,
Dict,
Generic,
List,
Sequence,
Type,
TypeVar,
Union,
)

from docarray.base_document import BaseDocument
from docarray.display.document_array_summary import DocumentArraySummary
Expand All @@ -17,6 +27,7 @@
class AnyDocumentArray(Sequence[BaseDocument], Generic[T_doc], AbstractType):
document_type: Type[BaseDocument]
tensor_type: Type['AbstractTensor'] = NdArray
__typed_da__: Dict[Type[BaseDocument], Type] = {}

def __repr__(self):
return f'<{self.__class__.__name__} (length={len(self)})>'
Expand All @@ -27,28 +38,31 @@ def __class_getitem__(cls, item: Type[BaseDocument]):
f'{cls.__name__}[item] item should be a Document not a {item} '
)

class _DocumentArrayTyped(cls): # type: ignore
document_type: Type[BaseDocument] = item
if item not in cls.__typed_da__:

for field in _DocumentArrayTyped.document_type.__fields__.keys():
class _DocumentArrayTyped(cls): # type: ignore
document_type: Type[BaseDocument] = item

def _property_generator(val: str):
def _getter(self):
return self._get_array_attribute(val)
for field in _DocumentArrayTyped.document_type.__fields__.keys():

def _setter(self, value):
self._set_array_attribute(val, value)
def _property_generator(val: str):
def _getter(self):
return self._get_array_attribute(val)

# need docstring for the property
return property(fget=_getter, fset=_setter)
def _setter(self, value):
self._set_array_attribute(val, value)

setattr(_DocumentArrayTyped, field, _property_generator(field))
# this generates property on the fly based on the schema of the item
# need docstring for the property
return property(fget=_getter, fset=_setter)

_DocumentArrayTyped.__name__ = f'{cls.__name__}[{item.__name__}]'
_DocumentArrayTyped.__qualname__ = f'{cls.__name__}[{item.__name__}]'
setattr(_DocumentArrayTyped, field, _property_generator(field))
# this generates property on the fly based on the schema of the item

return _DocumentArrayTyped
_DocumentArrayTyped.__name__ = f'{cls.__name__}[{item.__name__}]'
_DocumentArrayTyped.__qualname__ = f'{cls.__name__}[{item.__name__}]'
cls.__typed_da__[item] = _DocumentArrayTyped

return cls.__typed_da__[item]

@abstractmethod
def _get_array_attribute(
Expand Down
13 changes: 12 additions & 1 deletion docarray/array/array.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
from contextlib import contextmanager
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Type, TypeVar, Union
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Type,
TypeVar,
Union,
)

from typing_inspect import is_union_type

Expand Down Expand Up @@ -73,6 +83,7 @@ class Image(BaseDocument):
"""

document_type: Type[BaseDocument] = AnyDocument
__typed_da__: Dict[Type[BaseDocument], Type] = {}

def __init__(
self,
Expand Down
1 change: 1 addition & 0 deletions docarray/array/array_stacked.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class DocumentArrayStacked(AnyDocumentArray):

document_type: Type[BaseDocument] = AnyDocument
_docs: DocumentArray
__typed_da__: Dict[Type[BaseDocument], Type] = {}

def __init__(
self: T,
Expand Down
Empty file.
33 changes: 33 additions & 0 deletions tests/units/typing/da/test_relations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from docarray import BaseDocument, DocumentArray


def test_instance_and_equivalence():
class MyDoc(BaseDocument):
text: str

docs = DocumentArray[MyDoc]([MyDoc(text='hello')])

assert issubclass(DocumentArray[MyDoc], DocumentArray[MyDoc])
assert issubclass(docs.__class__, DocumentArray[MyDoc])

assert isinstance(docs, DocumentArray[MyDoc])


def test_subclassing():
class MyDoc(BaseDocument):
text: str

class MyDocArray(DocumentArray[MyDoc]):
pass

docs = MyDocArray([MyDoc(text='hello')])

assert issubclass(MyDocArray, DocumentArray[MyDoc])
assert issubclass(docs.__class__, DocumentArray[MyDoc])

assert isinstance(docs, MyDocArray)
assert isinstance(docs, DocumentArray[MyDoc])

assert issubclass(MyDoc, BaseDocument)
assert not issubclass(DocumentArray[MyDoc], DocumentArray[BaseDocument])
assert not issubclass(MyDocArray, DocumentArray[BaseDocument])