Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 35 additions & 20 deletions docarray/array/mixins/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def plot_embeddings(
start_server: bool = True,
host: str = '127.0.0.1',
port: Optional[int] = None,
image_source: str = 'tensor',
) -> str:
"""Interactively visualize :attr:`.embeddings` using the Embedding Projector.

Expand All @@ -121,6 +122,7 @@ def plot_embeddings(
:param min_image_size: only used when `image_sprites=True`. the minimum size of the image
:param channel_axis: only used when `image_sprites=True`. the axis id of the color channel, ``-1`` indicates the color channel info at the last axis
:param start_server: if set, start a HTTP server and open the frontend directly. Otherwise, you need to rely on ``return`` path and serve by yourself.
:param image_source: specify where the image comes from, can be ``uri`` or ``tensor``. empty tensor will fallback to uri
:return: the path to the embeddings visualization info.
"""
from ...helper import random_port, __resources_path__
Expand Down Expand Up @@ -154,6 +156,7 @@ def plot_embeddings(
canvas_size=canvas_size,
min_size=min_image_size,
channel_axis=channel_axis,
image_source=image_source,
)

self.save_embeddings_csv(os.path.join(path, emb_fn), delimiter='\t')
Expand Down Expand Up @@ -287,6 +290,7 @@ def plot_image_sprites(
canvas_size: int = 512,
min_size: int = 16,
channel_axis: int = -1,
image_source: str = 'tensor',
) -> None:
"""Generate a sprite image for all image tensors in this DocumentArray-like object.

Expand All @@ -297,6 +301,7 @@ def plot_image_sprites(
:param canvas_size: the size of the canvas
:param min_size: the minimum size of the image
:param channel_axis: the axis id of the color channel, ``-1`` indicates the color channel info at the last axis
:param image_source: specify where the image comes from, can be ``uri`` or ``tensor``. empty tensor will fallback to uri
"""
if not self:
raise ValueError(f'{self!r} is empty')
Expand All @@ -316,26 +321,36 @@ def plot_image_sprites(
[img_size * img_per_row, img_size * img_per_row, 3], dtype='uint8'
)
img_id = 0
for d in self:
_d = copy.deepcopy(d)
if _d.content_type != 'tensor':
_d.load_uri_to_image_tensor()
channel_axis = -1

_d.set_image_tensor_channel_axis(channel_axis, -1).set_image_tensor_shape(
shape=(img_size, img_size)
)

row_id = floor(img_id / img_per_row)
col_id = img_id % img_per_row
sprite_img[
(row_id * img_size) : ((row_id + 1) * img_size),
(col_id * img_size) : ((col_id + 1) * img_size),
] = _d.tensor

img_id += 1
if img_id >= max_num_img:
break
try:
for d in self:
_d = copy.deepcopy(d)

if image_source == 'uri' or (
image_source == 'tensor' and _d.content_type != 'tensor'
):
_d.load_uri_to_image_tensor()
channel_axis = -1
elif image_source not in ('uri', 'tensor'):
raise ValueError(f'image_source can be only `uri` or `tensor`')

_d.set_image_tensor_channel_axis(
channel_axis, -1
).set_image_tensor_shape(shape=(img_size, img_size))

row_id = floor(img_id / img_per_row)
col_id = img_id % img_per_row
sprite_img[
(row_id * img_size) : ((row_id + 1) * img_size),
(col_id * img_size) : ((col_id + 1) * img_size),
] = _d.tensor

img_id += 1
if img_id >= max_num_img:
break
except Exception as ex:
raise ValueError(
'Bad image tensor. Try different `image_source` or `channel_axis`'
) from ex

from PIL import Image

Expand Down
23 changes: 21 additions & 2 deletions tests/unit/array/mixins/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,34 @@
from docarray import DocumentArray, Document


def test_sprite_image_generator(pytestconfig, tmpdir):
def test_sprite_fail_tensor_success_uri(pytestconfig, tmpdir):
da = DocumentArray.from_files(
[
f'{pytestconfig.rootdir}/**/*.png',
f'{pytestconfig.rootdir}/**/*.jpg',
f'{pytestconfig.rootdir}/**/*.jpeg',
]
)
da.plot_image_sprites(tmpdir / 'sprint_da.png')
da.apply(
lambda d: d.load_uri_to_image_tensor().set_image_tensor_channel_axis(-1, 0)
)
with pytest.raises(ValueError):
da.plot_image_sprites()
da.plot_image_sprites(tmpdir / 'sprint_da.png', image_source='uri')
assert os.path.exists(tmpdir / 'sprint_da.png')


@pytest.mark.parametrize('image_source', ['tensor', 'uri'])
def test_sprite_image_generator(pytestconfig, tmpdir, image_source):
da = DocumentArray.from_files(
[
f'{pytestconfig.rootdir}/**/*.png',
f'{pytestconfig.rootdir}/**/*.jpg',
f'{pytestconfig.rootdir}/**/*.jpeg',
]
)
da.apply(lambda d: d.load_uri_to_image_tensor())
da.plot_image_sprites(tmpdir / 'sprint_da.png', image_source=image_source)
assert os.path.exists(tmpdir / 'sprint_da.png')


Expand Down