diff --git a/docarray/typing/tensor/image/abstract_image_tensor.py b/docarray/typing/tensor/image/abstract_image_tensor.py index 0d65f72ae9a..a6709920a13 100644 --- a/docarray/typing/tensor/image/abstract_image_tensor.py +++ b/docarray/typing/tensor/image/abstract_image_tensor.py @@ -2,6 +2,8 @@ import warnings from abc import ABC +import numpy as np + from docarray.typing.tensor.abstract_tensor import AbstractTensor from docarray.utils._internal.misc import import_library, is_notebook @@ -31,6 +33,22 @@ def to_bytes(self, format: str = 'PNG') -> bytes: return img_byte_arr + def save(self, file_path: str) -> None: + """ + Save image tensor to an image file. + + :param file_path: path to an image file. If file is a string, open the file by + that name, otherwise treat it as a file-like object. + """ + PIL = import_library('PIL', raise_error=True) # noqa: F841 + from PIL import Image as PILImage + + comp_backend = self.get_comp_backend() + np_img = comp_backend.to_numpy(self).astype(np.uint8) + + pil_img = PILImage.fromarray(np_img) + pil_img.save(file_path) + def display(self) -> None: """ Display image data from tensor in notebook. diff --git a/tests/units/typing/tensor/test_image_tensor.py b/tests/units/typing/tensor/test_image_tensor.py new file mode 100644 index 00000000000..73c7b538d45 --- /dev/null +++ b/tests/units/typing/tensor/test_image_tensor.py @@ -0,0 +1,37 @@ +import os + +import numpy as np +import pytest +import torch +from pydantic import parse_obj_as + +from docarray.typing import ImageNdArray, ImageTorchTensor +from docarray.utils._internal.misc import is_tf_available + +tf_available = is_tf_available() +if tf_available: + import tensorflow as tf + + from docarray.typing.tensor.image import ImageTensorFlowTensor + + +@pytest.mark.parametrize( + 'cls_tensor,tensor', + [ + (ImageTorchTensor, torch.zeros((224, 224, 3))), + (ImageNdArray, np.zeros((224, 224, 3))), + ], +) +def test_save_image_tensor_to_file(cls_tensor, tensor, tmpdir): + tmp_file = str(tmpdir / 'tmp.jpg') + image_tensor = parse_obj_as(cls_tensor, tensor) + image_tensor.save(tmp_file) + assert os.path.isfile(tmp_file) + + +@pytest.mark.tensorflow +def test_save_image_tensorflow_tensor_to_file(tmpdir): + tmp_file = str(tmpdir / 'tmp.jpg') + image_tensor = parse_obj_as(ImageTensorFlowTensor, tf.zeros((224, 224, 3))) + image_tensor.save(tmp_file) + assert os.path.isfile(tmp_file)