Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 39 additions & 3 deletions docarray/documents/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def create_doc(
return doc


def create_from_typeddict(
def create_doc_from_typeddict(
typeddict_cls: Type['TypedDict'], # type: ignore
**kwargs: Any,
):
Expand All @@ -91,7 +91,7 @@ def create_from_typeddict(

from docarray import BaseDocument
from docarray.documents import Audio
from docarray.documents.helper import create_from_typeddict
from docarray.documents.helper import create_doc_from_typeddict
from docarray.typing.tensor.audio import AudioNdArray


Expand All @@ -100,7 +100,7 @@ class MyAudio(TypedDict):
tensor: AudioNdArray


Doc = create_from_typeddict(MyAudio, __base__=Audio)
Doc = create_doc_from_typeddict(MyAudio, __base__=Audio)

assert issubclass(Doc, BaseDocument)
assert issubclass(Doc, Audio)
Expand All @@ -118,3 +118,39 @@ class MyAudio(TypedDict):
doc = create_model_from_typeddict(typeddict_cls, **kwargs)

return doc


def create_doc_from_dict(model_name: str, data_dict: Dict[str, Any]) -> Type['T_doc']:
"""
Create a subclass of BaseDocument based on example data given as a dictionary.

In case the example contains None as a value,
corresponding field will be viewed as the type Any.

:param model_name: Name of the new Document class
:param data_dict: Dictionary of field types to their corresponding values.
:return: the new Document class

EXAMPLE USAGE

.. code-block:: python

import numpy as np
from docarray.documents import ImageDoc
from docarray.documents.helper import create_doc_from_dict

data_dict = {'image': ImageDoc(tensor=np.random.rand(3, 224, 224)), 'author': 'me'}

MyDoc = create_doc_from_dict(model_name='MyDoc', data_dict=data_dict)

assert issubclass(MyDoc, BaseDocument)

"""
if not data_dict:
raise ValueError('`data_dict` should contain at least one item')

field_types = {
field: (type(value) if value else Any, ...)
for field, value in data_dict.items()
}
return create_doc(__model_name=model_name, **field_types) # type: ignore
62 changes: 56 additions & 6 deletions tests/integrations/document/test_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@

import numpy as np
import pytest
from pydantic import BaseModel
from pydantic import BaseModel, ValidationError
from typing_extensions import TypedDict

from docarray import BaseDocument, DocumentArray
from docarray.documents import AudioDoc, ImageDoc, TextDoc
from docarray.documents.helper import create_doc, create_from_typeddict
from docarray.documents.helper import (
create_doc,
create_doc_from_typeddict,
create_doc_from_dict,
)
from docarray.typing import AudioNdArray


Expand Down Expand Up @@ -78,23 +82,69 @@ def test_create_doc():
assert issubclass(MyAudio, AudioDoc)


def test_create_from_typeddict():
def test_create_doc_from_typeddict():
class MyMultiModalDoc(TypedDict):
image: ImageDoc
text: TextDoc

with pytest.raises(ValueError):
_ = create_from_typeddict(MyMultiModalDoc, __base__=BaseModel)
_ = create_doc_from_typeddict(MyMultiModalDoc, __base__=BaseModel)

Doc = create_from_typeddict(MyMultiModalDoc)
Doc = create_doc_from_typeddict(MyMultiModalDoc)

assert issubclass(Doc, BaseDocument)

class MyAudio(TypedDict):
title: str
tensor: Optional[AudioNdArray]

Doc = create_from_typeddict(MyAudio, __base__=AudioDoc)
Doc = create_doc_from_typeddict(MyAudio, __base__=AudioDoc)

assert issubclass(Doc, BaseDocument)
assert issubclass(Doc, AudioDoc)


def test_create_doc_from_dict():
data_dict = {
'image': ImageDoc(tensor=np.random.rand(3, 224, 224)),
'text': TextDoc(text='hello'),
'id': 123,
}

MyDoc = create_doc_from_dict(model_name='MyDoc', data_dict=data_dict)

assert issubclass(MyDoc, BaseDocument)

doc = MyDoc(
image=ImageDoc(tensor=np.random.rand(3, 224, 224)),
text=TextDoc(text='hey'),
id=111,
)

assert isinstance(doc, BaseDocument)
assert isinstance(doc.text, TextDoc)
assert isinstance(doc.image, ImageDoc)
assert isinstance(doc.id, int)

# Create a doc with an incorrect type
with pytest.raises(ValidationError):
doc = MyDoc(
image=ImageDoc(tensor=np.random.rand(3, 224, 224)),
text=['some', 'text'], # should be TextDoc
id=111,
)

# Handle empty data_dict
with pytest.raises(ValueError):
MyDoc = create_doc_from_dict(model_name='MyDoc', data_dict={})

# Data with a None value
data_dict = {'text': 'some text', 'other': None}
MyDoc = create_doc_from_dict(model_name='MyDoc', data_dict=data_dict)

assert issubclass(MyDoc, BaseDocument)

doc1 = MyDoc(text='txt', other=10)
doc2 = MyDoc(text='txt', other='also text')

assert isinstance(doc1, BaseDocument) and isinstance(doc2, BaseDocument)