diff --git a/docarray/array/mixins/plot.py b/docarray/array/mixins/plot.py index 4acc47f0c5a..c9d776cdb5d 100644 --- a/docarray/array/mixins/plot.py +++ b/docarray/array/mixins/plot.py @@ -108,6 +108,7 @@ def plot_embeddings( start_server: bool = True, host: str = '127.0.0.1', port: Optional[int] = None, + image_source: str = 'tensor', ) -> str: """Interactively visualize :attr:`.embeddings` using the Embedding Projector. @@ -121,6 +122,7 @@ def plot_embeddings( :param min_image_size: only used when `image_sprites=True`. the minimum size of the image :param channel_axis: only used when `image_sprites=True`. the axis id of the color channel, ``-1`` indicates the color channel info at the last axis :param start_server: if set, start a HTTP server and open the frontend directly. Otherwise, you need to rely on ``return`` path and serve by yourself. + :param image_source: specify where the image comes from, can be ``uri`` or ``tensor``. empty tensor will fallback to uri :return: the path to the embeddings visualization info. """ from ...helper import random_port, __resources_path__ @@ -154,6 +156,7 @@ def plot_embeddings( canvas_size=canvas_size, min_size=min_image_size, channel_axis=channel_axis, + image_source=image_source, ) self.save_embeddings_csv(os.path.join(path, emb_fn), delimiter='\t') @@ -287,6 +290,7 @@ def plot_image_sprites( canvas_size: int = 512, min_size: int = 16, channel_axis: int = -1, + image_source: str = 'tensor', ) -> None: """Generate a sprite image for all image tensors in this DocumentArray-like object. @@ -297,6 +301,7 @@ def plot_image_sprites( :param canvas_size: the size of the canvas :param min_size: the minimum size of the image :param channel_axis: the axis id of the color channel, ``-1`` indicates the color channel info at the last axis + :param image_source: specify where the image comes from, can be ``uri`` or ``tensor``. empty tensor will fallback to uri """ if not self: raise ValueError(f'{self!r} is empty') @@ -316,26 +321,36 @@ def plot_image_sprites( [img_size * img_per_row, img_size * img_per_row, 3], dtype='uint8' ) img_id = 0 - for d in self: - _d = copy.deepcopy(d) - if _d.content_type != 'tensor': - _d.load_uri_to_image_tensor() - channel_axis = -1 - - _d.set_image_tensor_channel_axis(channel_axis, -1).set_image_tensor_shape( - shape=(img_size, img_size) - ) - - row_id = floor(img_id / img_per_row) - col_id = img_id % img_per_row - sprite_img[ - (row_id * img_size) : ((row_id + 1) * img_size), - (col_id * img_size) : ((col_id + 1) * img_size), - ] = _d.tensor - - img_id += 1 - if img_id >= max_num_img: - break + try: + for d in self: + _d = copy.deepcopy(d) + + if image_source == 'uri' or ( + image_source == 'tensor' and _d.content_type != 'tensor' + ): + _d.load_uri_to_image_tensor() + channel_axis = -1 + elif image_source not in ('uri', 'tensor'): + raise ValueError(f'image_source can be only `uri` or `tensor`') + + _d.set_image_tensor_channel_axis( + channel_axis, -1 + ).set_image_tensor_shape(shape=(img_size, img_size)) + + row_id = floor(img_id / img_per_row) + col_id = img_id % img_per_row + sprite_img[ + (row_id * img_size) : ((row_id + 1) * img_size), + (col_id * img_size) : ((col_id + 1) * img_size), + ] = _d.tensor + + img_id += 1 + if img_id >= max_num_img: + break + except Exception as ex: + raise ValueError( + 'Bad image tensor. Try different `image_source` or `channel_axis`' + ) from ex from PIL import Image diff --git a/tests/unit/array/mixins/test_plot.py b/tests/unit/array/mixins/test_plot.py index 1ec701244b8..cdc3853ba4e 100644 --- a/tests/unit/array/mixins/test_plot.py +++ b/tests/unit/array/mixins/test_plot.py @@ -8,7 +8,7 @@ from docarray import DocumentArray, Document -def test_sprite_image_generator(pytestconfig, tmpdir): +def test_sprite_fail_tensor_success_uri(pytestconfig, tmpdir): da = DocumentArray.from_files( [ f'{pytestconfig.rootdir}/**/*.png', @@ -16,7 +16,26 @@ def test_sprite_image_generator(pytestconfig, tmpdir): f'{pytestconfig.rootdir}/**/*.jpeg', ] ) - da.plot_image_sprites(tmpdir / 'sprint_da.png') + da.apply( + lambda d: d.load_uri_to_image_tensor().set_image_tensor_channel_axis(-1, 0) + ) + with pytest.raises(ValueError): + da.plot_image_sprites() + da.plot_image_sprites(tmpdir / 'sprint_da.png', image_source='uri') + assert os.path.exists(tmpdir / 'sprint_da.png') + + +@pytest.mark.parametrize('image_source', ['tensor', 'uri']) +def test_sprite_image_generator(pytestconfig, tmpdir, image_source): + da = DocumentArray.from_files( + [ + f'{pytestconfig.rootdir}/**/*.png', + f'{pytestconfig.rootdir}/**/*.jpg', + f'{pytestconfig.rootdir}/**/*.jpeg', + ] + ) + da.apply(lambda d: d.load_uri_to_image_tensor()) + da.plot_image_sprites(tmpdir / 'sprint_da.png', image_source=image_source) assert os.path.exists(tmpdir / 'sprint_da.png')