Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions docarray/typing/tensor/image/abstract_image_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import warnings
from abc import ABC

import numpy as np

from docarray.typing.tensor.abstract_tensor import AbstractTensor
from docarray.utils._internal.misc import import_library, is_notebook

Expand Down Expand Up @@ -31,6 +33,22 @@ def to_bytes(self, format: str = 'PNG') -> bytes:

return img_byte_arr

def save(self, file_path: str) -> None:
"""
Save image tensor to an image file.

:param file_path: path to an image file. If file is a string, open the file by
that name, otherwise treat it as a file-like object.
"""
PIL = import_library('PIL', raise_error=True) # noqa: F841
from PIL import Image as PILImage

comp_backend = self.get_comp_backend()
np_img = comp_backend.to_numpy(self).astype(np.uint8)

pil_img = PILImage.fromarray(np_img)
pil_img.save(file_path)

def display(self) -> None:
"""
Display image data from tensor in notebook.
Expand Down
37 changes: 37 additions & 0 deletions tests/units/typing/tensor/test_image_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import os

import numpy as np
import pytest
import torch
from pydantic import parse_obj_as

from docarray.typing import ImageNdArray, ImageTorchTensor
from docarray.utils._internal.misc import is_tf_available

tf_available = is_tf_available()
if tf_available:
import tensorflow as tf

from docarray.typing.tensor.image import ImageTensorFlowTensor


@pytest.mark.parametrize(
'cls_tensor,tensor',
[
(ImageTorchTensor, torch.zeros((224, 224, 3))),
(ImageNdArray, np.zeros((224, 224, 3))),
],
)
def test_save_image_tensor_to_file(cls_tensor, tensor, tmpdir):
tmp_file = str(tmpdir / 'tmp.jpg')
image_tensor = parse_obj_as(cls_tensor, tensor)
image_tensor.save(tmp_file)
assert os.path.isfile(tmp_file)


@pytest.mark.tensorflow
def test_save_image_tensorflow_tensor_to_file(tmpdir):
tmp_file = str(tmpdir / 'tmp.jpg')
image_tensor = parse_obj_as(ImageTensorFlowTensor, tf.zeros((224, 224, 3)))
image_tensor.save(tmp_file)
assert os.path.isfile(tmp_file)