diff --git a/docarray/document/mixins/mesh.py b/docarray/document/mixins/mesh.py index f4fedbd3b44..3b26489f030 100644 --- a/docarray/document/mixins/mesh.py +++ b/docarray/document/mixins/mesh.py @@ -112,3 +112,34 @@ def load_vertices_and_faces_to_point_cloud(self: 'T', samples: int) -> 'T': ) return self + + def load_uris_to_rgbd_tensor(self: 'T') -> 'T': + """Load RGB image from :attr:`.uri` of :attr:`.chunks[0]` and depth image from :attr:`.uri` of :attr:`.chunks[1]` and merge them into :attr:`.tensor`. + + :return: itself after processed + """ + from PIL import Image + + if len(self.chunks) != 2: + raise ValueError( + f'The provided Document does not have two chunks but instead {len(self.chunks)}. To load uris to RGBD tensor, the Document needs to have two chunks, with the first one providing the RGB image uri, and the second one providing the depth image uri.' + ) + for chunk in self.chunks: + if chunk.uri == '': + raise ValueError( + 'A chunk of the given Document does not provide a uri.' + ) + + rgb_img = np.array(Image.open(self.chunks[0].uri).convert('RGB')) + depth_img = np.array(Image.open(self.chunks[1].uri)) + + if rgb_img.shape[0:2] != depth_img.shape: + raise ValueError( + f'The provided RGB image and depth image are not of the same shapes: {rgb_img.shape[0:2]} != {depth_img.shape}' + ) + + self.tensor = np.concatenate( + (rgb_img, np.expand_dims(depth_img, axis=2)), axis=-1 + ) + + return self diff --git a/docarray/document/mixins/plot.py b/docarray/document/mixins/plot.py index e2aff4f392f..af1779c2362 100644 --- a/docarray/document/mixins/plot.py +++ b/docarray/document/mixins/plot.py @@ -76,8 +76,12 @@ def display(self, from_: Optional[str] = None): Plot image data from :attr:`.uri` or from :attr:`.tensor` if :attr:`.uri` is empty . :param from_: an optional string to decide if a document should display using either the uri or the tensor field. """ - if self._is_3d(): - self.display_3d() + if self._is_3d_point_cloud(): + self.display_point_cloud_tensor() + elif self._is_3d_rgbd(): + self.display_rgbd_tensor() + elif self._is_3d_vertices_and_faces(): + self.display_vertices_and_faces() else: if not from_: if self.uri: @@ -94,60 +98,46 @@ def display(self, from_: Optional[str] = None): else: self.summary() - def _is_3d(self) -> bool: + def _is_3d_point_cloud(self): """ - Tells if Document stores a 3D object saved as point cloud or vertices and face. + Tells if Document stores a 3D object saved as point cloud tensor. :return: bool. """ - if self.uri and self.uri.endswith(tuple(Mesh.FILE_EXTENSIONS)): - return True - elif ( + if ( self.tensor is not None - and self.tensor.shape[1] == 3 and self.tensor.ndim == 2 + and self.tensor.shape[-1] == 3 + ): + return True + else: + return False + + def _is_3d_rgbd(self): + """ + Tells if Document stores a 3D object saved as RGB-D image tensor. + :return: bool. + """ + if ( + self.tensor is not None + and self.tensor.ndim == 3 + and self.tensor.shape[-1] == 4 ): return True - elif self.chunks is not None: + else: + return False + + def _is_3d_vertices_and_faces(self): + """ + Tells if Document stores a 3D object saved as vertices and faces. + :return: bool. + """ + if self.chunks is not None: name_tags = [c.tags['name'] for c in self.chunks] if Mesh.VERTICES in name_tags and Mesh.FACES in name_tags: return True else: return False - def display_3d(self) -> None: - """Plot 3d data.""" - from IPython.display import display - import trimesh - - if self.tensor is not None: - # point cloud from tensor - from hubble.utils.notebook import is_notebook - - if is_notebook(): - pc = trimesh.points.PointCloud( - vertices=self.tensor, - colors=np.tile(np.array([0, 0, 0, 1]), (len(self.tensor), 1)), - ) - s = trimesh.Scene(geometry=pc) - display(s.show()) - else: - pc = trimesh.points.PointCloud(vertices=self.tensor) - display(pc.show()) - - elif self.uri: - # mesh from uri - mesh = self._load_mesh() - display(mesh.show()) - - elif self.chunks is not None: - # mesh from chunks - vertices = [ - c.tensor for c in self.chunks if c.tags['name'] == Mesh.VERTICES - ][-1] - faces = [c.tensor for c in self.chunks if c.tags['name'] == Mesh.FACES][-1] - mesh = trimesh.Trimesh(vertices=vertices, faces=faces) - display(mesh.show()) - def display_tensor(self) -> None: """Plot image data from :attr:`.tensor`""" if self.tensor is None: @@ -169,6 +159,68 @@ def display_tensor(self) -> None: plt.matshow(self.tensor) + def display_vertices_and_faces(self): + """Plot mesh consisting of vertices and faces.""" + from IPython.display import display + + if self.uri: + # mesh from uri + mesh = self._load_mesh() + display(mesh.show()) + + else: + # mesh from chunks + import trimesh + + vertices = [ + c.tensor for c in self.chunks if c.tags['name'] == Mesh.VERTICES + ][-1] + faces = [c.tensor for c in self.chunks if c.tags['name'] == Mesh.FACES][-1] + mesh = trimesh.Trimesh(vertices=vertices, faces=faces) + display(mesh.show()) + + def display_point_cloud_tensor(self) -> None: + """Plot interactive point cloud from :attr:`.tensor`""" + import trimesh + from IPython.display import display + from hubble.utils.notebook import is_notebook + + if is_notebook(): + pc = trimesh.points.PointCloud( + vertices=self.tensor, + colors=np.tile(np.array([0, 0, 0, 1]), (len(self.tensor), 1)), + ) + s = trimesh.Scene(geometry=pc) + display(s.show()) + else: + pc = trimesh.points.PointCloud(vertices=self.tensor) + display(pc.show()) + + def display_rgbd_tensor(self) -> None: + """Plot an RGB-D image and a corresponding depth image from :attr:`.tensor`""" + import matplotlib.pyplot as plt + from mpl_toolkits.axes_grid1 import make_axes_locatable + + rgb_img = self.tensor[:, :, :3] + + depth_img = self.tensor[:, :, -1] + depth_img = depth_img / (np.max(depth_img) + 1e-08) * 255 + depth_img = depth_img.astype(np.uint8) + + f, ax = plt.subplots(1, 2, figsize=(16, 6)) + + ax[0].imshow(rgb_img, interpolation='None') + ax[0].set_title('RGB image\n', fontsize=16) + + im2 = ax[1].imshow(self.tensor[:, :, -1], cmap='gray') + cax = make_axes_locatable(ax[1]).append_axes('right', size='5%', pad=0.05) + f.colorbar(im2, cax=cax, orientation='vertical', label='Depth') + + ax[1].imshow(depth_img, cmap='gray') + ax[1].set_title('Depth image\n', fontsize=16) + + plt.show() + def display_uri(self): """Plot image data from :attr:`.uri`""" diff --git a/docs/datatypes/mesh/index.md b/docs/datatypes/mesh/index.md index 4a11eafb913..0e3c68e75c9 100644 --- a/docs/datatypes/mesh/index.md +++ b/docs/datatypes/mesh/index.md @@ -51,7 +51,7 @@ chunk.tags = {'name': 'faces'} ``` -You can display your 3d object and interact with it via: +You can display your 3D object and interact with it via: ```python doc.display() ``` @@ -1321,7 +1321,7 @@ print(doc.tensor.shape) (1000, 3) ``` -You can display your 3d object and interact with it via: +You can display your 3D object and interact with it via: ```python doc.display() @@ -2574,3 +2574,35 @@ function animate(){requestAnimationFrame(animate);controls.update();} function render(){tracklight.position.copy(camera.position);renderer.render(scene,camera);} init(); " width="100%" height="500px" style="border:none;"> + └─ chunks + ├─ + └─ +``` + +To display the RGB image and its corresponding depth image: + +```python +doc.display() +``` + +```{figure} rgbd_chair.png +``` diff --git a/docs/datatypes/mesh/rgbd_chair.png b/docs/datatypes/mesh/rgbd_chair.png new file mode 100644 index 00000000000..5a908851637 Binary files /dev/null and b/docs/datatypes/mesh/rgbd_chair.png differ diff --git a/docs/fundamentals/notebook-support/image-rgbd.png b/docs/fundamentals/notebook-support/image-rgbd.png new file mode 100644 index 00000000000..dec053b9375 Binary files /dev/null and b/docs/fundamentals/notebook-support/image-rgbd.png differ diff --git a/docs/fundamentals/notebook-support/index.md b/docs/fundamentals/notebook-support/index.md index f1e89bfb9ff..38e40f3f36e 100644 --- a/docs/fundamentals/notebook-support/index.md +++ b/docs/fundamentals/notebook-support/index.md @@ -41,11 +41,16 @@ Video and audio Document can be displayed as well, you can play them in the cell ```{figure} audio-video.png ``` -You can also display your 3d object and interact with it in the cell, whether you stored it as a point cloud or vertices and faces. +You can also display your 3D object and interact with it in the cell, whether you stored it as a point cloud or vertices and faces. ```{figure} mesh-point-cloud.png ``` +Additionally, you can also display an RGB image and its corresponding depth image: + +```{figure} image-rgbd.png +``` + ## Display DocumentArray A cell with a DocumentArray object can be pretty-printed automatically. diff --git a/tests/unit/document/test_converters.py b/tests/unit/document/test_converters.py index c67a507a3b7..f1ddf50b3d6 100644 --- a/tests/unit/document/test_converters.py +++ b/tests/unit/document/test_converters.py @@ -330,3 +330,61 @@ def test_load_to_point_cloud_without_vertices_faces_set_raise_warning(uri): AttributeError, match='vertices and faces chunk tensor have not been set' ): doc.load_vertices_and_faces_to_point_cloud(100) + + +@pytest.mark.parametrize( + 'uri_rgb, uri_depth', + [ + ( + os.path.join(cur_dir, 'toydata/test_rgb.jpg'), + os.path.join(cur_dir, 'toydata/test_depth.png'), + ) + ], +) +def test_load_uris_to_rgbd_tensor(uri_rgb, uri_depth): + doc = Document( + chunks=[ + Document(uri=uri_rgb), + Document(uri=uri_depth), + ] + ) + doc.load_uris_to_rgbd_tensor() + + assert doc.tensor.shape[-1] == 4 + + +@pytest.mark.parametrize( + 'uri_rgb, uri_depth', + [ + ( + os.path.join(cur_dir, 'toydata/test.png'), + os.path.join(cur_dir, 'toydata/test_depth.png'), + ) + ], +) +def test_load_uris_to_rgbd_tensor_different_shapes_raise_exception(uri_rgb, uri_depth): + doc = Document( + chunks=[ + Document(uri=uri_rgb), + Document(uri=uri_depth), + ] + ) + with pytest.raises( + ValueError, + match='The provided RGB image and depth image are not of the same shapes', + ): + doc.load_uris_to_rgbd_tensor() + + +def test_load_uris_to_rgbd_tensor_doc_wo_uri_raise_exception(): + doc = Document( + chunks=[ + Document(), + Document(), + ] + ) + with pytest.raises( + ValueError, + match='A chunk of the given Document does not provide a uri.', + ): + doc.load_uris_to_rgbd_tensor() diff --git a/tests/unit/document/toydata/test_depth.png b/tests/unit/document/toydata/test_depth.png new file mode 100644 index 00000000000..1ceb90a1b70 Binary files /dev/null and b/tests/unit/document/toydata/test_depth.png differ diff --git a/tests/unit/document/toydata/test_rgb.jpg b/tests/unit/document/toydata/test_rgb.jpg new file mode 100644 index 00000000000..8327ee2904d Binary files /dev/null and b/tests/unit/document/toydata/test_rgb.jpg differ