Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion docarray/documents/audio.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,19 @@
from typing import Optional, TypeVar
from typing import Any, Optional, Type, TypeVar, Union

import numpy as np

from docarray.base_document import BaseDocument
from docarray.typing import AnyEmbedding, AudioUrl
from docarray.typing.tensor.abstract_tensor import AbstractTensor
from docarray.typing.tensor.audio.audio_tensor import AudioTensor

# torch is an optional dependency: record whether it can be imported so
# later code can guard any use of ``torch`` behind ``torch_available``.
try:
    import torch
except ImportError:
    torch_available = False
else:
    torch_available = True

T = TypeVar('T', bound='Audio')


Expand Down Expand Up @@ -76,3 +86,17 @@ class MultiModalDoc(Document):
url: Optional[AudioUrl]
tensor: Optional[AudioTensor]
embedding: Optional[AnyEmbedding]

@classmethod
def validate(
    cls: Type[T],
    value: Union[str, AbstractTensor, Any],
) -> T:
    """
    Wrap shorthand inputs in a document, then defer to the parent validation.

    :param value: a ``str`` becomes ``cls(url=value)``; an ``AbstractTensor``,
        an ``np.ndarray`` or (only when torch could be imported) a
        ``torch.Tensor`` becomes ``cls(tensor=value)``; any other value is
        forwarded unchanged.
    :return: the result of the parent class' ``validate`` on the (possibly
        wrapped) value.
    """
    if isinstance(value, str):
        return super().validate(cls(url=value))

    looks_like_tensor = isinstance(value, (AbstractTensor, np.ndarray))
    if not looks_like_tensor and torch_available:
        # Only touch ``torch`` when the optional import succeeded.
        looks_like_tensor = isinstance(value, torch.Tensor)

    if looks_like_tensor:
        value = cls(tensor=value)
    return super().validate(value)
28 changes: 27 additions & 1 deletion docarray/documents/image.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,19 @@
from typing import Optional
from typing import Any, Optional, Type, TypeVar, Union

import numpy as np

from docarray.base_document import BaseDocument
from docarray.typing import AnyEmbedding, AnyTensor, ImageUrl
from docarray.typing.tensor.abstract_tensor import AbstractTensor

T = TypeVar('T', bound='Image')

# torch is an optional dependency: record whether it can be imported so
# later code can guard any use of ``torch`` behind ``torch_available``.
try:
    import torch
except ImportError:
    torch_available = False
else:
    torch_available = True


class Image(BaseDocument):
Expand Down Expand Up @@ -67,3 +79,17 @@ class MultiModalDoc(BaseDocument):
url: Optional[ImageUrl]
tensor: Optional[AnyTensor]
embedding: Optional[AnyEmbedding]

@classmethod
def validate(
    cls: Type[T],
    value: Union[str, AbstractTensor, Any],
) -> T:
    """
    Wrap shorthand inputs in a document, then defer to the parent validation.

    :param value: a ``str`` becomes ``cls(url=value)``; an ``AbstractTensor``,
        an ``np.ndarray`` or (only when torch could be imported) a
        ``torch.Tensor`` becomes ``cls(tensor=value)``; any other value is
        forwarded unchanged.
    :return: the result of the parent class' ``validate`` on the (possibly
        wrapped) value.
    """
    if isinstance(value, str):
        return super().validate(cls(url=value))

    looks_like_tensor = isinstance(value, (AbstractTensor, np.ndarray))
    if not looks_like_tensor and torch_available:
        # Only touch ``torch`` when the optional import succeeded.
        looks_like_tensor = isinstance(value, torch.Tensor)

    if looks_like_tensor:
        value = cls(tensor=value)
    return super().validate(value)
13 changes: 12 additions & 1 deletion docarray/documents/mesh.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import Optional
from typing import Any, Optional, Type, TypeVar, Union

from docarray.base_document import BaseDocument
from docarray.typing import AnyEmbedding, AnyTensor, Mesh3DUrl

T = TypeVar('T', bound='Mesh3D')


class Mesh3D(BaseDocument):
"""
Expand Down Expand Up @@ -77,3 +79,12 @@ class MultiModalDoc(BaseDocument):
vertices: Optional[AnyTensor]
faces: Optional[AnyTensor]
embedding: Optional[AnyEmbedding]

@classmethod
def validate(
    cls: Type[T],
    value: Union[str, Any],
) -> T:
    """
    Accept a bare string as shorthand for a document with that ``url``.

    :param value: a ``str`` is wrapped as ``cls(url=value)``; any other
        value is forwarded unchanged.
    :return: the result of the parent class' ``validate`` on the (possibly
        wrapped) value.
    """
    candidate = cls(url=value) if isinstance(value, str) else value
    return super().validate(candidate)
28 changes: 27 additions & 1 deletion docarray/documents/point_cloud.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,19 @@
from typing import Optional
from typing import Any, Optional, Type, TypeVar, Union

import numpy as np

from docarray.base_document import BaseDocument
from docarray.typing import AnyEmbedding, AnyTensor, PointCloud3DUrl
from docarray.typing.tensor.abstract_tensor import AbstractTensor

# torch is an optional dependency: record whether it can be imported so
# later code can guard any use of ``torch`` behind ``torch_available``.
try:
    import torch
except ImportError:
    torch_available = False
else:
    torch_available = True

T = TypeVar('T', bound='PointCloud3D')


class PointCloud3D(BaseDocument):
Expand Down Expand Up @@ -75,3 +87,17 @@ class MultiModalDoc(BaseDocument):
url: Optional[PointCloud3DUrl]
tensor: Optional[AnyTensor]
embedding: Optional[AnyEmbedding]

@classmethod
def validate(
    cls: Type[T],
    value: Union[str, AbstractTensor, Any],
) -> T:
    """
    Wrap shorthand inputs in a document, then defer to the parent validation.

    :param value: a ``str`` becomes ``cls(url=value)``; an ``AbstractTensor``,
        an ``np.ndarray`` or (only when torch could be imported) a
        ``torch.Tensor`` becomes ``cls(tensor=value)``; any other value is
        forwarded unchanged.
    :return: the result of the parent class' ``validate`` on the (possibly
        wrapped) value.
    """
    if isinstance(value, str):
        return super().validate(cls(url=value))

    looks_like_tensor = isinstance(value, (AbstractTensor, np.ndarray))
    if not looks_like_tensor and torch_available:
        # Only touch ``torch`` when the optional import succeeded.
        looks_like_tensor = isinstance(value, torch.Tensor)

    if looks_like_tensor:
        value = cls(tensor=value)
    return super().validate(value)
13 changes: 12 additions & 1 deletion docarray/documents/text.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from typing import Optional
from typing import Any, Optional, Type, TypeVar, Union

from docarray.base_document import BaseDocument
from docarray.typing import TextUrl
from docarray.typing.tensor.embedding import AnyEmbedding

T = TypeVar('T', bound='Text')


class Text(BaseDocument):
"""
Expand Down Expand Up @@ -68,3 +70,12 @@ class MultiModalDoc(BaseDocument):
text: Optional[str] = None
url: Optional[TextUrl] = None
embedding: Optional[AnyEmbedding] = None

@classmethod
def validate(
    cls: Type[T],
    value: Union[str, Any],
) -> T:
    """
    Accept a bare string as shorthand for a document with that ``text``.

    :param value: a ``str`` is wrapped as ``cls(text=value)``; any other
        value is forwarded unchanged.
    :return: the result of the parent class' ``validate`` on the (possibly
        wrapped) value.
    """
    candidate = cls(text=value) if isinstance(value, str) else value
    return super().validate(candidate)
26 changes: 25 additions & 1 deletion docarray/documents/video.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
from typing import Optional, TypeVar
from typing import Any, Optional, Type, TypeVar, Union

import numpy as np

from docarray.base_document import BaseDocument
from docarray.documents import Audio
from docarray.typing import AnyEmbedding, AnyTensor
from docarray.typing.tensor.abstract_tensor import AbstractTensor
from docarray.typing.tensor.video.video_tensor import VideoTensor
from docarray.typing.url.video_url import VideoUrl

# torch is an optional dependency: record whether it can be imported so
# later code can guard any use of ``torch`` behind ``torch_available``.
try:
    import torch
except ImportError:
    torch_available = False
else:
    torch_available = True

T = TypeVar('T', bound='Video')


Expand Down Expand Up @@ -83,3 +93,17 @@ class MultiModalDoc(BaseDocument):
tensor: Optional[VideoTensor]
key_frame_indices: Optional[AnyTensor]
embedding: Optional[AnyEmbedding]

@classmethod
def validate(
    cls: Type[T],
    value: Union[str, AbstractTensor, Any],
) -> T:
    """
    Wrap shorthand inputs in a document, then defer to the parent validation.

    :param value: a ``str`` becomes ``cls(url=value)``; an ``AbstractTensor``,
        an ``np.ndarray`` or (only when torch could be imported) a
        ``torch.Tensor`` becomes ``cls(tensor=value)``; any other value is
        forwarded unchanged.
    :return: the result of the parent class' ``validate`` on the (possibly
        wrapped) value.
    """
    if isinstance(value, str):
        return super().validate(cls(url=value))

    looks_like_tensor = isinstance(value, (AbstractTensor, np.ndarray))
    if not looks_like_tensor and torch_available:
        # Only touch ``torch`` when the optional import succeeded.
        looks_like_tensor = isinstance(value, torch.Tensor)

    if looks_like_tensor:
        value = cls(tensor=value)
    return super().validate(value)
27 changes: 27 additions & 0 deletions tests/integrations/predefined_document/test_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch
from pydantic import parse_obj_as

from docarray import BaseDocument
from docarray.documents import Audio
from docarray.typing import AudioUrl
from docarray.typing.tensor.audio import AudioNdArray, AudioTorchTensor
Expand Down Expand Up @@ -83,3 +84,29 @@ class MyAudio(Audio):

assert isinstance(my_audio.tensor, AudioNdArray)
assert isinstance(my_audio.url, AudioUrl)


def test_audio_np():
    """Parsing a NumPy array as ``Audio`` stores it in the ``tensor`` field."""
    shape = (10, 10, 3)
    parsed = parse_obj_as(Audio, np.zeros(shape))
    expected = np.zeros(shape)
    assert (parsed.tensor == expected).all()


def test_audio_torch():
    """Parsing a torch tensor as ``Audio`` stores it in the ``tensor`` field."""
    dims = (10, 10, 3)
    parsed = parse_obj_as(Audio, torch.zeros(*dims))
    expected = torch.zeros(*dims)
    assert (parsed.tensor == expected).all()


def test_audio_shortcut_doc():
    """``Audio`` fields of a document accept a URL string, ndarray or torch tensor."""

    class MyDoc(BaseDocument):
        audio: Audio
        audio2: Audio
        audio3: Audio

    wav_url = 'http://myurl.wav'
    dims = (10, 10, 3)

    doc = MyDoc(
        audio=wav_url,
        audio2=np.zeros(dims),
        audio3=torch.zeros(*dims),
    )

    # The string lands in ``url``; both tensor flavours land in ``tensor``.
    assert doc.audio.url == wav_url
    assert (doc.audio2.tensor == np.zeros(dims)).all()
    assert (doc.audio3.tensor == torch.zeros(*dims)).all()
35 changes: 34 additions & 1 deletion tests/integrations/predefined_document/test_image.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import numpy as np
import pytest
import torch
from pydantic import parse_obj_as

from docarray import BaseDocument
from docarray.documents import Image

REMOTE_JPG = (
Expand All @@ -12,9 +15,39 @@
@pytest.mark.slow
@pytest.mark.internet
def test_image():
    """Loading ``REMOTE_JPG`` through the image URL yields an ndarray tensor."""
    doc = Image(url=REMOTE_JPG)
    doc.tensor = doc.url.load()
    # Read the value back from the document rather than from the load() result.
    assert isinstance(doc.tensor, np.ndarray)


def test_image_str():
    """Parsing a plain string as ``Image`` stores it in the ``url`` field."""
    jpg_url = 'http://myurl.jpg'
    parsed = parse_obj_as(Image, jpg_url)
    assert parsed.url == jpg_url


def test_image_np():
    """Parsing a NumPy array as ``Image`` stores it in the ``tensor`` field."""
    shape = (10, 10, 3)
    parsed = parse_obj_as(Image, np.zeros(shape))
    expected = np.zeros(shape)
    assert (parsed.tensor == expected).all()


def test_image_torch():
    """Parsing a torch tensor as ``Image`` stores it in the ``tensor`` field."""
    dims = (10, 10, 3)
    parsed = parse_obj_as(Image, torch.zeros(*dims))
    expected = torch.zeros(*dims)
    assert (parsed.tensor == expected).all()


def test_image_shortcut_doc():
    """``Image`` fields of a document accept a URL string, ndarray or torch tensor."""

    class MyDoc(BaseDocument):
        image: Image
        image2: Image
        image3: Image

    jpg_url = 'http://myurl.jpg'
    dims = (10, 10, 3)

    doc = MyDoc(
        image=jpg_url,
        image2=np.zeros(dims),
        image3=torch.zeros(*dims),
    )

    # The string lands in ``url``; both tensor flavours land in ``tensor``.
    assert doc.image.url == jpg_url
    assert (doc.image2.tensor == np.zeros(dims)).all()
    assert (doc.image3.tensor == torch.zeros(*dims)).all()
18 changes: 18 additions & 0 deletions tests/integrations/predefined_document/test_mesh.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import numpy as np
import pytest
from pydantic import parse_obj_as

from docarray import BaseDocument
from docarray.documents import Mesh3D
from tests import TOYDATA_DIR

Expand All @@ -19,3 +21,19 @@ def test_mesh(file_url):

assert isinstance(mesh.vertices, np.ndarray)
assert isinstance(mesh.faces, np.ndarray)


def test_str_init():
    """Parsing a plain string as ``Mesh3D`` stores it in the ``url`` field."""
    ply_url = 'http://hello.ply'
    mesh = parse_obj_as(Mesh3D, ply_url)
    assert mesh.url == ply_url


def test_doc():
    """A ``Mesh3D`` field accepts either a URL string or a ready ``Mesh3D``."""

    class MyDoc(BaseDocument):
        mesh1: Mesh3D
        mesh2: Mesh3D

    ply_url = 'http://hello.ply'
    explicit_mesh = Mesh3D(url=ply_url)
    doc = MyDoc(mesh1=ply_url, mesh2=explicit_mesh)

    # Both spellings must end up with the same ``url``.
    for mesh in (doc.mesh1, doc.mesh2):
        assert mesh.url == ply_url
29 changes: 29 additions & 0 deletions tests/integrations/predefined_document/test_point_cloud.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import numpy as np
import pytest
import torch
from pydantic import parse_obj_as

from docarray import BaseDocument
from docarray.documents import PointCloud3D
from tests import TOYDATA_DIR

Expand All @@ -18,3 +21,29 @@ def test_point_cloud(file_url):
point_cloud.tensor = point_cloud.url.load(samples=100)

assert isinstance(point_cloud.tensor, np.ndarray)


def test_point_cloud_np():
    """Parsing a NumPy array as ``PointCloud3D`` stores it in ``tensor``."""
    shape = (10, 10, 3)
    point_cloud = parse_obj_as(PointCloud3D, np.zeros(shape))
    expected = np.zeros(shape)
    assert (point_cloud.tensor == expected).all()


def test_point_cloud_torch():
    """Parsing a torch tensor as ``PointCloud3D`` stores it in ``tensor``."""
    dims = (10, 10, 3)
    point_cloud = parse_obj_as(PointCloud3D, torch.zeros(*dims))
    expected = torch.zeros(*dims)
    assert (point_cloud.tensor == expected).all()


def test_point_cloud_shortcut_doc():
    """``PointCloud3D`` fields accept a URL string, ndarray or torch tensor."""

    class MyDoc(BaseDocument):
        image: PointCloud3D
        image2: PointCloud3D
        image3: PointCloud3D

    ply_url = 'http://myurl.ply'
    dims = (10, 10, 3)

    doc = MyDoc(
        image=ply_url,
        image2=np.zeros(dims),
        image3=torch.zeros(*dims),
    )

    # The string lands in ``url``; both tensor flavours land in ``tensor``.
    assert doc.image.url == ply_url
    assert (doc.image2.tensor == np.zeros(dims)).all()
    assert (doc.image3.tensor == torch.zeros(*dims)).all()
25 changes: 25 additions & 0 deletions tests/integrations/predefined_document/test_text.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from pydantic import parse_obj_as

from docarray import BaseDocument
from docarray.documents import Text


def test_simple_init():
    """Constructing ``Text`` with the ``text`` keyword keeps the given string."""
    greeting = 'hello'
    doc = Text(text=greeting)
    assert doc.text == greeting


def test_str_init():
    """Parsing a plain string as ``Text`` stores it in the ``text`` field."""
    greeting = 'hello'
    doc = parse_obj_as(Text, greeting)
    assert doc.text == greeting


def test_doc():
    """A ``Text`` field accepts either a plain string or a ready ``Text``."""

    class MyDoc(BaseDocument):
        text1: Text
        text2: Text

    explicit_text = Text(text='world')
    doc = MyDoc(text1='hello', text2=explicit_text)

    # Compare each field's ``text`` against the string it was built from.
    expected_by_field = ((doc.text1, 'hello'), (doc.text2, 'world'))
    for field_value, expected in expected_by_field:
        assert field_value.text == expected
Loading