plot_embeddings with image_sprites from uri #64
Closed
Description
To reproduce:
```python
from docarray import DocumentArray, Document
import torchvision


def preproc(d: Document):
    # Load the image as an (H, W, C) tensor, resize it to 200x200, normalize it,
    # then move the channel axis to the front, giving (C, 200, 200) for the CNN.
    return (
        d.load_uri_to_image_tensor()
        .set_image_tensor_shape((200, 200))
        .set_image_tensor_normalization()
        .set_image_tensor_channel_axis(-1, 0)
    )


left_da = DocumentArray.from_files('left/*.jpg')
left_da.apply(preproc)

model = torchvision.models.resnet50(pretrained=True)
left_da.embed(model)

left_da.plot_embeddings(image_sprites=True)
```

Error:
```
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Input In [1], in <module>
     14 model = torchvision.models.resnet50(pretrained=True)
     15 left_da.embed(model)
---> 17 left_da.plot_embeddings(image_sprites=True)

File ~/workplaces/docarray-env/docarray/docarray/array/mixins/plot.py:152, in PlotMixin.plot_embeddings(self, title, path, image_sprites, min_image_size, channel_axis, start_server, host, port)
    140 if len(self) > max_docs:
    141 warnings.warn(
    142 f'''
    143 {self!r} has more than {max_docs} elements, which is the maximum number of image sprites can support.
    (...)
    149 '''
    150 )
--> 152 self.plot_image_sprites(
    153 os.path.join(path, sprite_fn),
    154 canvas_size=canvas_size,
    155 min_size=min_image_size,
    156 channel_axis=channel_axis,
    157 )
    159 self.save_embeddings_csv(os.path.join(path, emb_fn), delimiter='\t')
    161 _exclude_fields = ('embedding', 'tensor', 'scores')

File ~/workplaces/docarray-env/docarray/docarray/array/mixins/plot.py:331, in PlotMixin.plot_image_sprites(self, output, canvas_size, min_size, channel_axis)
    329 row_id = floor(img_id / img_per_row)
    330 col_id = img_id % img_per_row
--> 331 sprite_img[
    332 (row_id * img_size) : ((row_id + 1) * img_size),
    333 (col_id * img_size) : ((col_id + 1) * img_size),
    334 ] = _d.tensor
    336 img_id += 1
    337 if img_id >= max_num_img:

ValueError: could not broadcast input array from shape (16,16,200) into shape (16,16,3)
```
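The shapes in the error point to a channel-axis mismatch: `preproc` leaves each `tensor` channel-first, `(3, 200, 200)`, while `plot_image_sprites` appears to treat the last axis as channels by default, so the 200-pixel width ends up where 3 colour channels are expected. The NumPy sketch here is not docarray's code; `resize_nearest` and `tile_into_sprite` are hypothetical helpers with a simplified one-row layout, written only to reproduce the same broadcast error and to show that moving the channel axis to the end first avoids it.

```python
import numpy as np


def resize_nearest(img: np.ndarray, size: int) -> np.ndarray:
    """Hypothetical nearest-neighbour resize of an (H, W, C) array to (size, size, C).

    It always treats the first two axes as spatial and the last axis as channels.
    """
    h, w = img.shape[:2]
    rows = np.arange(size) * h // size
    cols = np.arange(size) * w // size
    return img[rows][:, cols]


def tile_into_sprite(tensors, img_size: int, channel_axis: int = -1) -> np.ndarray:
    """Hypothetical helper: paste images side by side into a one-row RGB sprite."""
    sprite = np.zeros((img_size, img_size * len(tensors), 3), dtype=np.float32)
    for i, t in enumerate(tensors):
        # Move the declared channel axis to the end so the image is (H, W, C).
        t = np.moveaxis(t, channel_axis, -1)
        sprite[:, i * img_size:(i + 1) * img_size] = resize_nearest(t, img_size)
    return sprite


# After preproc, each tensor is channel-first: (3, 200, 200).
chw = np.random.rand(3, 200, 200).astype(np.float32)

err_msg = None
try:
    # Default channel_axis=-1 wrongly treats the width (200) as channels.
    tile_into_sprite([chw], img_size=16)
except ValueError as e:
    err_msg = str(e)
print(err_msg)

# Declaring the real channel axis moves it to the end before resizing.
sprite = tile_into_sprite([chw], img_size=16, channel_axis=0)
print(sprite.shape)  # (16, 16, 3)
```

Because `plot_embeddings` already takes a `channel_axis` argument (visible in its signature in the traceback), calling it with `channel_axis=0` may sidestep the crash for this data, though the sprites would then show normalized tensors rather than the original pictures.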
Suggested fix:
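Add a `use_uri` flag to `plot_embeddings` so that the image sprites for the points are drawn from each Document's `uri` attribute (the original image file) instead of from its preprocessed `tensor`. I would be happy to raise a PR for this.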
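One way the proposed flag could behave, sketched under assumptions: `SimpleDoc` is a hypothetical stand-in for a docarray `Document` holding only `uri` and `tensor`, and `sprite_sources` is a hypothetical helper, not an existing docarray function. The image loader is passed in so the sketch runs without docarray; the point is only that `use_uri=True` re-reads the raw file and ignores whatever preprocessing was applied to `tensor`.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional

import numpy as np


@dataclass
class SimpleDoc:
    """Hypothetical stand-in for a docarray Document (only uri and tensor)."""
    uri: Optional[str] = None
    tensor: Optional[np.ndarray] = None


def sprite_sources(
    docs: List[SimpleDoc],
    load_image: Callable[[str], np.ndarray],
    use_uri: bool = False,
) -> List[np.ndarray]:
    """Hypothetical helper: choose the array each sprite is drawn from.

    With use_uri=True the raw image is re-read from `uri`, so the sprite does not
    depend on the resizing, normalization or channel reordering done to `tensor`.
    """
    images = []
    for d in docs:
        if use_uri:
            if d.uri is None:
                raise ValueError('use_uri=True but a Document has no uri')
            images.append(load_image(d.uri))
        else:
            if d.tensor is None:
                raise ValueError('Document has no tensor to draw a sprite from')
            images.append(d.tensor)
    return images


def fake_loader(uri: str) -> np.ndarray:
    # Stand-in loader for the example: a blank channel-last (H, W, 3) image.
    return np.zeros((240, 320, 3), dtype=np.uint8)


# 'left/0.jpg' is a hypothetical file name used only for illustration.
docs = [SimpleDoc(uri='left/0.jpg', tensor=np.zeros((3, 200, 200), dtype=np.float32))]
print(sprite_sources(docs, fake_loader, use_uri=False)[0].shape)  # (3, 200, 200)
print(sprite_sources(docs, fake_loader, use_uri=True)[0].shape)   # (240, 320, 3)
```

With docarray 0.x installed, the loader could plausibly be `lambda uri: Document(uri=uri).load_uri_to_image_tensor().tensor`, reusing the method `preproc` already calls; images loaded that way are channel-last, which appears to be what the sprite code expects by default.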