From 51e1cf1ad0515eddb7e2fb0ae70d6f5449186aea Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Tue, 4 Apr 2023 14:12:03 +0200 Subject: [PATCH 1/2] feat: save image tensor to image file Signed-off-by: anna-charlotte --- .../tensor/image/abstract_image_tensor.py | 21 +++++++++++ .../units/typing/tensor/test_image_tensor.py | 37 +++++++++++++++++++ 2 files changed, 58 insertions(+) create mode 100644 tests/units/typing/tensor/test_image_tensor.py diff --git a/docarray/typing/tensor/image/abstract_image_tensor.py b/docarray/typing/tensor/image/abstract_image_tensor.py index 0d65f72ae9a..547f6b94318 100644 --- a/docarray/typing/tensor/image/abstract_image_tensor.py +++ b/docarray/typing/tensor/image/abstract_image_tensor.py @@ -1,10 +1,15 @@ import io import warnings from abc import ABC +from typing import TypeVar + +import numpy as np from docarray.typing.tensor.abstract_tensor import AbstractTensor from docarray.utils._internal.misc import import_library, is_notebook +T = TypeVar('T', bound='AbstractImageTensor') + class AbstractImageTensor(AbstractTensor, ABC): def to_bytes(self, format: str = 'PNG') -> bytes: @@ -31,6 +36,22 @@ def to_bytes(self, format: str = 'PNG') -> bytes: return img_byte_arr + def save(self: 'T', file_path: str) -> None: + """ + Save image tensor to an image file. + + :param file_path: path to an image file. If file is a string, open the file by + that name, otherwise treat it as a file-like object. + """ + PIL = import_library('PIL', raise_error=True) # noqa: F841 + from PIL import Image as PILImage + + comp_backend = self.get_comp_backend() + np_img = comp_backend.to_numpy(self).astype(np.uint8) + + pil_img = PILImage.fromarray(np_img) + pil_img.save(file_path) + def display(self) -> None: """ Display image data from tensor in notebook. diff --git a/tests/units/typing/tensor/test_image_tensor.py b/tests/units/typing/tensor/test_image_tensor.py new file mode 100644 index 00000000000..73c7b538d45 --- /dev/null +++ b/tests/units/typing/tensor/test_image_tensor.py @@ -0,0 +1,37 @@ +import os + +import numpy as np +import pytest +import torch +from pydantic import parse_obj_as + +from docarray.typing import ImageNdArray, ImageTorchTensor +from docarray.utils._internal.misc import is_tf_available + +tf_available = is_tf_available() +if tf_available: + import tensorflow as tf + + from docarray.typing.tensor.image import ImageTensorFlowTensor + + +@pytest.mark.parametrize( + 'cls_tensor,tensor', + [ + (ImageTorchTensor, torch.zeros((224, 224, 3))), + (ImageNdArray, np.zeros((224, 224, 3))), + ], +) +def test_save_image_tensor_to_file(cls_tensor, tensor, tmpdir): + tmp_file = str(tmpdir / 'tmp.jpg') + image_tensor = parse_obj_as(cls_tensor, tensor) + image_tensor.save(tmp_file) + assert os.path.isfile(tmp_file) + + +@pytest.mark.tensorflow +def test_save_image_tensorflow_tensor_to_file(tmpdir): + tmp_file = str(tmpdir / 'tmp.jpg') + image_tensor = parse_obj_as(ImageTensorFlowTensor, tf.zeros((224, 224, 3))) + image_tensor.save(tmp_file) + assert os.path.isfile(tmp_file) From 8d50227769fcf493b4da78b11e3e72ba8bb46781 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Tue, 4 Apr 2023 15:34:32 +0200 Subject: [PATCH 2/2] fix: apply samis suggestions from code review Signed-off-by: anna-charlotte --- docarray/typing/tensor/image/abstract_image_tensor.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/docarray/typing/tensor/image/abstract_image_tensor.py b/docarray/typing/tensor/image/abstract_image_tensor.py index 547f6b94318..a6709920a13 100644 --- a/docarray/typing/tensor/image/abstract_image_tensor.py +++ b/docarray/typing/tensor/image/abstract_image_tensor.py @@ -1,15 +1,12 @@ import io import warnings from abc import ABC -from typing import TypeVar import numpy as np from docarray.typing.tensor.abstract_tensor import AbstractTensor from docarray.utils._internal.misc import import_library, is_notebook -T = TypeVar('T', bound='AbstractImageTensor') - class AbstractImageTensor(AbstractTensor, ABC): def to_bytes(self, format: str = 'PNG') -> bytes: @@ -36,7 +33,7 @@ def to_bytes(self, format: str = 'PNG') -> bytes: return img_byte_arr - def save(self: 'T', file_path: str) -> None: + def save(self, file_path: str) -> None: """ Save image tensor to an image file.