
plot_embeddings with image_sprites from uri #64

@abduhbm

Description


To reproduce:

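```python
from docarray import DocumentArray, Document
import torchvision


def preproc(d: Document):
    # Load the image from d.uri as an (H, W, C) tensor, resize it to 200x200,
    # normalize it, then move the channel axis to the front: (C, H, W),
    # which is the layout the torchvision model expects.
    return (
        d.load_uri_to_image_tensor()
        .set_image_tensor_shape((200, 200))
        .set_image_tensor_normalization()
        .set_image_tensor_channel_axis(-1, 0)
    )


left_da = DocumentArray.from_files('left/*.jpg')
left_da.apply(preproc)

model = torchvision.models.resnet50(pretrained=True)
left_da.embed(model)

# Fails with the ValueError below.
left_da.plot_embeddings(image_sprites=True)
```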
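The tensor layout that `preproc()` leaves on each Document can be traced with plain NumPy arrays. This is a sketch only: zero arrays stand in for real images, and the idea that the sprite code reads the channel axis as the last one by default is an assumption, though it is consistent with the `(16,16,200)` shape in the error below.

```python
import numpy as np

# Stand-in for an image after load_uri_to_image_tensor() and
# set_image_tensor_shape((200, 200)): (height, width, RGB channels).
hwc = np.zeros((200, 200, 3), dtype=np.float32)

# Stand-in for set_image_tensor_channel_axis(-1, 0): channels move to the front,
# giving the (C, H, W) layout torchvision models expect.
chw = np.moveaxis(hwc, -1, 0)

# A consumer that assumes channels are on the last axis reads this tensor
# as a 3 x 200 image with 200 channels.
assumed_height, assumed_width, assumed_channels = chw.shape
```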

Error:

```
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Input In [1], in <module>
     14 model = torchvision.models.resnet50(pretrained=True)  
     15 left_da.embed(model)  
---> 17 left_da.plot_embeddings(image_sprites=True)

File ~/workplaces/docarray-env/docarray/docarray/array/mixins/plot.py:152, in PlotMixin.plot_embeddings(self, title, path, image_sprites, min_image_size, channel_axis, start_server, host, port)
    140     if len(self) > max_docs:
    141         warnings.warn(
    142             f'''
    143             {self!r} has more than {max_docs} elements, which is the maximum number of image sprites can support. 
   (...)
    149             '''
    150         )
--> 152     self.plot_image_sprites(
    153         os.path.join(path, sprite_fn),
    154         canvas_size=canvas_size,
    155         min_size=min_image_size,
    156         channel_axis=channel_axis,
    157     )
    159 self.save_embeddings_csv(os.path.join(path, emb_fn), delimiter='\t')
    161 _exclude_fields = ('embedding', 'tensor', 'scores')

File ~/workplaces/docarray-env/docarray/docarray/array/mixins/plot.py:331, in PlotMixin.plot_image_sprites(self, output, canvas_size, min_size, channel_axis)
    329 row_id = floor(img_id / img_per_row)
    330 col_id = img_id % img_per_row
--> 331 sprite_img[
    332     (row_id * img_size) : ((row_id + 1) * img_size),
    333     (col_id * img_size) : ((col_id + 1) * img_size),
    334 ] = _d.tensor
    336 img_id += 1
    337 if img_id >= max_num_img:

ValueError: could not broadcast input array from shape (16,16,200) into shape (16,16,3)
```
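The failing line pastes each Document's tensor into a fixed-size cell of the sprite canvas. The simplified NumPy-only sketch below mirrors that step; the function name `tile_sprites` and the 16-pixel cell size (taken from the error message) are illustrative assumptions, not docarray's actual code. It shows that a channel-last cell fits, while a channel-first tensor misread as channel-last produces the same `(16,16,200)` vs `(16,16,3)` mismatch.

```python
import math

import numpy as np


def tile_sprites(images, img_size, img_per_row):
    """Paste (img_size, img_size, 3) images onto one canvas, row by row.

    Uses the same index arithmetic as the traceback:
    row_id = floor(img_id / img_per_row), col_id = img_id % img_per_row.
    """
    n_rows = math.ceil(len(images) / img_per_row)
    canvas = np.zeros((n_rows * img_size, img_per_row * img_size, 3), dtype=np.float32)
    for img_id, img in enumerate(images):
        row_id = math.floor(img_id / img_per_row)
        col_id = img_id % img_per_row
        # NumPy raises ValueError here if img is not (img_size, img_size, 3).
        canvas[
            row_id * img_size : (row_id + 1) * img_size,
            col_id * img_size : (col_id + 1) * img_size,
        ] = img
    return canvas


# Five channel-last 16x16 cells fit on a 2-row x 4-column canvas.
good = [np.ones((16, 16, 3), dtype=np.float32) for _ in range(5)]
canvas = tile_sprites(good, img_size=16, img_per_row=4)

# A (3, 200, 200) channel-first tensor read with the channel axis assumed last
# looks like a 3x200 image with 200 channels; resizing its first two axes to
# 16x16 gives (16, 16, 200), which cannot be pasted into a (16, 16, 3) slot.
bad = [np.ones((16, 16, 200), dtype=np.float32)]
error_message = None
try:
    tile_sprites(bad, img_size=16, img_per_row=4)
except ValueError as err:
    error_message = str(err)

# Moving the channel axis to the end first yields a cell of the right shape.
fixed = np.moveaxis(np.ones((3, 16, 16), dtype=np.float32), 0, -1)
fixed_canvas = tile_sprites([fixed], img_size=16, img_per_row=4)
```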

Suggested fix:
Add a `use_uri` flag to `plot_embeddings` so that the image sprites for each point are rendered from the Document's `uri` attribute instead of its preprocessed `tensor`. I would be happy to raise a PR for this.
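The suggested flag could behave roughly like the hypothetical helper below. `sprite_source`, its `use_uri` and `load_uri` parameters, the `fake_loader` function and the `left/example.jpg` path are illustrative names, not part of docarray's API. The idea is that with `use_uri=True` the sprite is drawn from the original image rather than from the resized, normalized, channel-first tensor that was fed to the model.

```python
from typing import Callable, Optional

import numpy as np


def sprite_source(
    tensor: Optional[np.ndarray],
    uri: Optional[str],
    load_uri: Callable[[str], np.ndarray],
    use_uri: bool = False,
) -> np.ndarray:
    """Hypothetical helper: choose the pixels to draw in a sprite cell.

    With use_uri=True the original image is re-read from the Document's uri,
    so preprocessing applied to .tensor does not leak into the visualization.
    """
    if use_uri:
        if uri is None:
            raise ValueError('use_uri=True but the Document has no uri')
        return load_uri(uri)
    if tensor is None:
        raise ValueError('Document has no tensor to draw')
    return tensor


def fake_loader(uri: str) -> np.ndarray:
    # Stand-in for decoding a JPEG: returns an (H, W, 3) uint8 image.
    return np.full((200, 200, 3), 255, dtype=np.uint8)


# Model input after preproc(): float32, channel-first.
preprocessed = np.zeros((3, 200, 200), dtype=np.float32)
img = sprite_source(preprocessed, 'left/example.jpg', fake_loader, use_uri=True)
```

Until such a flag exists, passing `channel_axis=0` to `plot_embeddings` (the parameter appears in the signature shown in the traceback) may avoid the shape error, although the sprites would then be drawn from the normalized tensor rather than the original pixels.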
