From c88c87b55a7fca559bef3eb98a37d3fe17c4fbe9 Mon Sep 17 00:00:00 2001 From: samsja Date: Thu, 26 Jan 2023 14:42:41 +0100 Subject: [PATCH 1/7] feat: add validate on Text document Signed-off-by: samsja --- docarray/documents/text.py | 13 ++++++++++- .../predefined_document/test_text.py | 22 +++++++++++++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) create mode 100644 tests/integrations/predefined_document/test_text.py diff --git a/docarray/documents/text.py b/docarray/documents/text.py index 958dbb22021..e7276fef76a 100644 --- a/docarray/documents/text.py +++ b/docarray/documents/text.py @@ -1,9 +1,11 @@ -from typing import Optional +from typing import Any, Optional, Type, TypeVar, Union from docarray.base_document import BaseDocument from docarray.typing import TextUrl from docarray.typing.tensor.embedding import AnyEmbedding +T = TypeVar('T', bound='Text') + class Text(BaseDocument): """ @@ -68,3 +70,12 @@ class MultiModalDoc(BaseDocument): text: Optional[str] = None url: Optional[TextUrl] = None embedding: Optional[AnyEmbedding] = None + + @classmethod + def validate( + cls: Type[T], + value: Union[str, Any], + ) -> T: + if isinstance(value, str): + value = cls(text=value) + return super().validate(value) diff --git a/tests/integrations/predefined_document/test_text.py b/tests/integrations/predefined_document/test_text.py new file mode 100644 index 00000000000..80955166682 --- /dev/null +++ b/tests/integrations/predefined_document/test_text.py @@ -0,0 +1,22 @@ +from pydantic import parse_obj_as + +from docarray import BaseDocument +from docarray.documents import Text + + +def test_simple_init(): + t = Text(text='hello') + t.text == 'hello' + + +def test_str_init(): + t = parse_obj_as(Text, 'hello') + t.text == 'hello' + + +def test_doc(): + class MyDoc(BaseDocument): + text1: Text + text2: Text + + MyDoc(text1='hello', text2=Text(text='world')) From 7fd944645c4c66b457c047f2c794e8d5779b6cd3 Mon Sep 17 00:00:00 2001 From: samsja Date: Thu, 26 Jan 2023 15:34:15 +0100 Subject: [PATCH 2/7] fix: complete test Signed-off-by: samsja --- tests/integrations/predefined_document/test_text.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/integrations/predefined_document/test_text.py b/tests/integrations/predefined_document/test_text.py index 80955166682..3281caf59c8 100644 --- a/tests/integrations/predefined_document/test_text.py +++ b/tests/integrations/predefined_document/test_text.py @@ -19,4 +19,7 @@ class MyDoc(BaseDocument): text1: Text text2: Text - MyDoc(text1='hello', text2=Text(text='world')) + doc = MyDoc(text1='hello', text2=Text(text='world')) + + assert doc.text1.text == 'hello' + assert doc.text2.text == 'world' From 130dcc2b44b80ca5e0dc52fbcc4e443ed457afd4 Mon Sep 17 00:00:00 2001 From: samsja Date: Thu, 26 Jan 2023 16:20:38 +0100 Subject: [PATCH 3/7] feat: add image shortcut Signed-off-by: samsja --- docarray/documents/image.py | 28 ++++++++++++++- .../predefined_document/test_image.py | 35 ++++++++++++++++++- 2 files changed, 61 insertions(+), 2 deletions(-) diff --git a/docarray/documents/image.py b/docarray/documents/image.py index a8f7fbbef20..394634c85ad 100644 --- a/docarray/documents/image.py +++ b/docarray/documents/image.py @@ -1,7 +1,19 @@ -from typing import Optional +from typing import Any, Optional, Type, TypeVar, Union + +import numpy as np from docarray.base_document import BaseDocument from docarray.typing import AnyEmbedding, AnyTensor, ImageUrl +from docarray.typing.tensor.abstract_tensor import AbstractTensor + +T = TypeVar('T', bound='Image') + +try: + import torch + + torch_available = True +except ImportError: + torch_available = False class Image(BaseDocument): @@ -67,3 +79,17 @@ class MultiModalDoc(BaseDocument): url: Optional[ImageUrl] tensor: Optional[AnyTensor] embedding: Optional[AnyEmbedding] + + @classmethod + def validate( + cls: Type[T], + value: Union[str, AbstractTensor, Any], + ) -> T: + if isinstance(value, str): + value = cls(url=value) + elif isinstance(value, (AbstractTensor, np.ndarray)) or ( + torch_available and isinstance(value, torch.Tensor) + ): + value = cls(tensor=value) + + return super().validate(value) diff --git a/tests/integrations/predefined_document/test_image.py b/tests/integrations/predefined_document/test_image.py index c6a2a448553..e047ca37f5d 100644 --- a/tests/integrations/predefined_document/test_image.py +++ b/tests/integrations/predefined_document/test_image.py @@ -1,6 +1,9 @@ import numpy as np import pytest +import torch +from pydantic import parse_obj_as +from docarray import BaseDocument from docarray.documents import Image REMOTE_JPG = ( @@ -12,9 +15,39 @@ @pytest.mark.slow @pytest.mark.internet def test_image(): - image = Image(url=REMOTE_JPG) image.tensor = image.url.load() assert isinstance(image.tensor, np.ndarray) + + +def test_image_str(): + image = parse_obj_as(Image, 'http://myurl.jpg') + assert image.url == 'http://myurl.jpg' + + +def test_image_np(): + image = parse_obj_as(Image, np.zeros((10, 10, 3))) + assert (image.tensor == np.zeros((10, 10, 3))).all() + + +def test_image_torch(): + image = parse_obj_as(Image, torch.zeros(10, 10, 3)) + assert (image.tensor == torch.zeros(10, 10, 3)).all() + + +def test_image_shortcut_doc(): + class MyDoc(BaseDocument): + image: Image + image2: Image + image3: Image + + doc = MyDoc( + image='http://myurl.jpg', + image2=np.zeros((10, 10, 3)), + image3=torch.zeros(10, 10, 3), + ) + assert doc.image.url == 'http://myurl.jpg' + assert (doc.image2.tensor == np.zeros((10, 10, 3))).all() + assert (doc.image3.tensor == torch.zeros(10, 10, 3)).all() From f7ff6342765dba6def7a2b56341e3f36ee529c1b Mon Sep 17 00:00:00 2001 From: samsja Date: Thu, 26 Jan 2023 16:30:29 +0100 Subject: [PATCH 4/7] feat: add video Signed-off-by: samsja --- docarray/documents/video.py | 26 +++++++++++++++- .../predefined_document/test_video.py | 30 +++++++++++++++++++ 2 files changed, 55 insertions(+), 1 deletion(-) diff --git a/docarray/documents/video.py b/docarray/documents/video.py index 4a6547b6013..69f13a67d54 100644 --- a/docarray/documents/video.py +++ b/docarray/documents/video.py @@ -1,11 +1,21 @@ -from typing import Optional, TypeVar +from typing import Any, Optional, Type, TypeVar, Union + +import numpy as np from docarray.base_document import BaseDocument from docarray.documents import Audio from docarray.typing import AnyEmbedding, AnyTensor +from docarray.typing.tensor.abstract_tensor import AbstractTensor from docarray.typing.tensor.video.video_tensor import VideoTensor from docarray.typing.url.video_url import VideoUrl +try: + import torch + + torch_available = True +except ImportError: + torch_available = False + T = TypeVar('T', bound='Video') @@ -83,3 +93,17 @@ class MultiModalDoc(BaseDocument): tensor: Optional[VideoTensor] key_frame_indices: Optional[AnyTensor] embedding: Optional[AnyEmbedding] + + @classmethod + def validate( + cls: Type[T], + value: Union[str, AbstractTensor, Any], + ) -> T: + if isinstance(value, str): + value = cls(url=value) + elif isinstance(value, (AbstractTensor, np.ndarray)) or ( + torch_available and isinstance(value, torch.Tensor) + ): + value = cls(tensor=value) + + return super().validate(value) diff --git a/tests/integrations/predefined_document/test_video.py b/tests/integrations/predefined_document/test_video.py index 43963c36e76..aecf4a7091a 100644 --- a/tests/integrations/predefined_document/test_video.py +++ b/tests/integrations/predefined_document/test_video.py @@ -1,5 +1,9 @@ +import numpy as np import pytest +import torch +from pydantic import parse_obj_as +from docarray import BaseDocument from docarray.documents import Video from docarray.typing import AudioNdArray, NdArray, VideoNdArray from tests import TOYDATA_DIR @@ -18,3 +22,29 @@ def test_video(file_url): assert isinstance(vid.tensor, VideoNdArray) assert isinstance(vid.audio.tensor, AudioNdArray) assert isinstance(vid.key_frame_indices, NdArray) + + +def test_image_np(): + image = parse_obj_as(Video, np.zeros((10, 10, 3))) + assert (image.tensor == np.zeros((10, 10, 3))).all() + + +def test_image_torch(): + image = parse_obj_as(Video, torch.zeros(10, 10, 3)) + assert (image.tensor == torch.zeros(10, 10, 3)).all() + + +def test_image_shortcut_doc(): + class MyDoc(BaseDocument): + image: Video + image2: Video + image3: Video + + doc = MyDoc( + image='http://myurl.mp4', + image2=np.zeros((10, 10, 3)), + image3=torch.zeros(10, 10, 3), + ) + assert doc.image.url == 'http://myurl.mp4' + assert (doc.image2.tensor == np.zeros((10, 10, 3))).all() + assert (doc.image3.tensor == torch.zeros(10, 10, 3)).all() From 91b56d4a4104f52979a851665070b2015a8c6e83 Mon Sep 17 00:00:00 2001 From: samsja Date: Thu, 26 Jan 2023 16:35:34 +0100 Subject: [PATCH 5/7] docs: we to u Signed-off-by: samsja --- docarray/documents/point_cloud.py | 28 +++++++++++++++++- .../predefined_document/test_point_cloud.py | 29 +++++++++++++++++++ .../predefined_document/test_video.py | 6 ++-- 3 files changed, 59 insertions(+), 4 deletions(-) diff --git a/docarray/documents/point_cloud.py b/docarray/documents/point_cloud.py index 4215d89dccb..9006db858fb 100644 --- a/docarray/documents/point_cloud.py +++ b/docarray/documents/point_cloud.py @@ -1,7 +1,19 @@ -from typing import Optional +from typing import Any, Optional, Type, TypeVar, Union + +import numpy as np from docarray.base_document import BaseDocument from docarray.typing import AnyEmbedding, AnyTensor, PointCloud3DUrl +from docarray.typing.tensor.abstract_tensor import AbstractTensor + +try: + import torch + + torch_available = True +except ImportError: + torch_available = False + +T = TypeVar('T', bound='PointCloud3D') class PointCloud3D(BaseDocument): @@ -75,3 +87,17 @@ class MultiModalDoc(BaseDocument): url: Optional[PointCloud3DUrl] tensor: Optional[AnyTensor] embedding: Optional[AnyEmbedding] + + @classmethod + def validate( + cls: Type[T], + value: Union[str, AbstractTensor, Any], + ) -> T: + if isinstance(value, str): + value = cls(url=value) + elif isinstance(value, (AbstractTensor, np.ndarray)) or ( + torch_available and isinstance(value, torch.Tensor) + ): + value = cls(tensor=value) + + return super().validate(value) diff --git a/tests/integrations/predefined_document/test_point_cloud.py b/tests/integrations/predefined_document/test_point_cloud.py index 3ce1fd32eb6..ad8f0c4a80f 100644 --- a/tests/integrations/predefined_document/test_point_cloud.py +++ b/tests/integrations/predefined_document/test_point_cloud.py @@ -1,6 +1,9 @@ import numpy as np import pytest +import torch +from pydantic import parse_obj_as +from docarray import BaseDocument from docarray.documents import PointCloud3D from tests import TOYDATA_DIR @@ -18,3 +21,29 @@ def test_point_cloud(file_url): point_cloud.tensor = point_cloud.url.load(samples=100) assert isinstance(point_cloud.tensor, np.ndarray) + + +def test_point_cloud_np(): + image = parse_obj_as(PointCloud3D, np.zeros((10, 10, 3))) + assert (image.tensor == np.zeros((10, 10, 3))).all() + + +def test_point_cloud_torch(): + image = parse_obj_as(PointCloud3D, torch.zeros(10, 10, 3)) + assert (image.tensor == torch.zeros(10, 10, 3)).all() + + +def test_point_cloud_shortcut_doc(): + class MyDoc(BaseDocument): + image: PointCloud3D + image2: PointCloud3D + image3: PointCloud3D + + doc = MyDoc( + image='http://myurl.ply', + image2=np.zeros((10, 10, 3)), + image3=torch.zeros(10, 10, 3), + ) + assert doc.image.url == 'http://myurl.ply' + assert (doc.image2.tensor == np.zeros((10, 10, 3))).all() + assert (doc.image3.tensor == torch.zeros(10, 10, 3)).all() diff --git a/tests/integrations/predefined_document/test_video.py b/tests/integrations/predefined_document/test_video.py index aecf4a7091a..4dd806afbc1 100644 --- a/tests/integrations/predefined_document/test_video.py +++ b/tests/integrations/predefined_document/test_video.py @@ -24,17 +24,17 @@ def test_video(file_url): assert isinstance(vid.key_frame_indices, NdArray) -def test_image_np(): +def test_video_np(): image = parse_obj_as(Video, np.zeros((10, 10, 3))) assert (image.tensor == np.zeros((10, 10, 3))).all() -def test_image_torch(): +def test_video_torch(): image = parse_obj_as(Video, torch.zeros(10, 10, 3)) assert (image.tensor == torch.zeros(10, 10, 3)).all() -def test_image_shortcut_doc(): +def test_video_shortcut_doc(): class MyDoc(BaseDocument): image: Video image2: Video From 231e9b930d5b6165d826b34af87ea8faf603ac1d Mon Sep 17 00:00:00 2001 From: samsja Date: Thu, 26 Jan 2023 16:40:12 +0100 Subject: [PATCH 6/7] feat: add mesh Signed-off-by: samsja --- docarray/documents/mesh.py | 13 ++++++++++++- .../predefined_document/test_mesh.py | 18 ++++++++++++++++++ .../predefined_document/test_text.py | 4 ++-- 3 files changed, 32 insertions(+), 3 deletions(-) diff --git a/docarray/documents/mesh.py b/docarray/documents/mesh.py index 6061da941c3..f7f0e1bcaf7 100644 --- a/docarray/documents/mesh.py +++ b/docarray/documents/mesh.py @@ -1,8 +1,10 @@ -from typing import Optional +from typing import Any, Optional, Type, TypeVar, Union from docarray.base_document import BaseDocument from docarray.typing import AnyEmbedding, AnyTensor, Mesh3DUrl +T = TypeVar('T', bound='Mesh3D') + class Mesh3D(BaseDocument): """ @@ -77,3 +79,12 @@ class MultiModalDoc(BaseDocument): vertices: Optional[AnyTensor] faces: Optional[AnyTensor] embedding: Optional[AnyEmbedding] + + @classmethod + def validate( + cls: Type[T], + value: Union[str, Any], + ) -> T: + if isinstance(value, str): + value = cls(url=value) + return super().validate(value) diff --git a/tests/integrations/predefined_document/test_mesh.py b/tests/integrations/predefined_document/test_mesh.py index a194f20396a..8177233a433 100644 --- a/tests/integrations/predefined_document/test_mesh.py +++ b/tests/integrations/predefined_document/test_mesh.py @@ -1,6 +1,8 @@ import numpy as np import pytest +from pydantic import parse_obj_as +from docarray import BaseDocument from docarray.documents import Mesh3D from tests import TOYDATA_DIR @@ -19,3 +21,19 @@ def test_mesh(file_url): assert isinstance(mesh.vertices, np.ndarray) assert isinstance(mesh.faces, np.ndarray) + + +def test_str_init(): + t = parse_obj_as(Mesh3D, 'http://hello.ply') + assert t.url == 'http://hello.ply' + + +def test_doc(): + class MyDoc(BaseDocument): + mesh1: Mesh3D + mesh2: Mesh3D + + doc = MyDoc(mesh1='http://hello.ply', mesh2=Mesh3D(url='http://hello.ply')) + + assert doc.mesh1.url == 'http://hello.ply' + assert doc.mesh2.url == 'http://hello.ply' diff --git a/tests/integrations/predefined_document/test_text.py b/tests/integrations/predefined_document/test_text.py index 3281caf59c8..36c4913114e 100644 --- a/tests/integrations/predefined_document/test_text.py +++ b/tests/integrations/predefined_document/test_text.py @@ -6,12 +6,12 @@ def test_simple_init(): t = Text(text='hello') - t.text == 'hello' + assert t.text == 'hello' def test_str_init(): t = parse_obj_as(Text, 'hello') - t.text == 'hello' + assert t.text == 'hello' def test_doc(): From 457c5d1c16e18a727dedf0fb6bf7d050156dcf23 Mon Sep 17 00:00:00 2001 From: samsja Date: Thu, 26 Jan 2023 17:51:12 +0100 Subject: [PATCH 7/7] feat: add audio Signed-off-by: samsja --- docarray/documents/audio.py | 26 +++++++++++++++++- .../predefined_document/test_audio.py | 27 +++++++++++++++++++ 2 files changed, 52 insertions(+), 1 deletion(-) diff --git a/docarray/documents/audio.py b/docarray/documents/audio.py index 776020bc964..473986ef89b 100644 --- a/docarray/documents/audio.py +++ b/docarray/documents/audio.py @@ -1,9 +1,19 @@ -from typing import Optional, TypeVar +from typing import Any, Optional, Type, TypeVar, Union + +import numpy as np from docarray.base_document import BaseDocument from docarray.typing import AnyEmbedding, AudioUrl +from docarray.typing.tensor.abstract_tensor import AbstractTensor from docarray.typing.tensor.audio.audio_tensor import AudioTensor +try: + import torch + + torch_available = True +except ImportError: + torch_available = False + T = TypeVar('T', bound='Audio') @@ -76,3 +86,17 @@ class MultiModalDoc(Document): url: Optional[AudioUrl] tensor: Optional[AudioTensor] embedding: Optional[AnyEmbedding] + + @classmethod + def validate( + cls: Type[T], + value: Union[str, AbstractTensor, Any], + ) -> T: + if isinstance(value, str): + value = cls(url=value) + elif isinstance(value, (AbstractTensor, np.ndarray)) or ( + torch_available and isinstance(value, torch.Tensor) + ): + value = cls(tensor=value) + + return super().validate(value) diff --git a/tests/integrations/predefined_document/test_audio.py b/tests/integrations/predefined_document/test_audio.py index a4c7dafbde5..b13bc7f856f 100644 --- a/tests/integrations/predefined_document/test_audio.py +++ b/tests/integrations/predefined_document/test_audio.py @@ -6,6 +6,7 @@ import torch from pydantic import parse_obj_as +from docarray import BaseDocument from docarray.documents import Audio from docarray.typing import AudioUrl from docarray.typing.tensor.audio import AudioNdArray, AudioTorchTensor @@ -83,3 +84,29 @@ class MyAudio(Audio): assert isinstance(my_audio.tensor, AudioNdArray) assert isinstance(my_audio.url, AudioUrl) + + +def test_audio_np(): + audio = parse_obj_as(Audio, np.zeros((10, 10, 3))) + assert (audio.tensor == np.zeros((10, 10, 3))).all() + + +def test_audio_torch(): + audio = parse_obj_as(Audio, torch.zeros(10, 10, 3)) + assert (audio.tensor == torch.zeros(10, 10, 3)).all() + + +def test_audio_shortcut_doc(): + class MyDoc(BaseDocument): + audio: Audio + audio2: Audio + audio3: Audio + + doc = MyDoc( + audio='http://myurl.wav', + audio2=np.zeros((10, 10, 3)), + audio3=torch.zeros(10, 10, 3), + ) + assert doc.audio.url == 'http://myurl.wav' + assert (doc.audio2.tensor == np.zeros((10, 10, 3))).all() + assert (doc.audio3.tensor == torch.zeros(10, 10, 3)).all()