def _load_mesh(
    self, force: str = None
) -> Union['trimesh.Trimesh', 'trimesh.Scene']:
    """Read the 3D asset referenced by :attr:`.uri` with trimesh.

    :param force: passed through to the trimesh loader. With ``'mesh'`` trimesh tries to
        coerce scenes into a single mesh, with ``'scene'`` it tries to coerce everything
        into a scene. ``None`` (the default) is forwarded unchanged.
    :return: the object produced by the trimesh loader for :attr:`.uri`
    """
    import trimesh
    from urllib.parse import urlparse

    # http(s) locations go through trimesh's remote loader; anything else is
    # handed to the regular loader.
    is_remote = urlparse(self.uri).scheme in ['http', 'https']
    if is_remote:
        return trimesh.load_remote(self.uri, force=force)
    return trimesh.load(self.uri, force=force)
self._load_mesh(force='mesh') self.tensor = np.array(mesh.sample(samples)) return self @@ -47,22 +72,16 @@ def load_uri_to_vertices_and_faces(self: 'T') -> 'T': :return: itself after processed """ - - import trimesh - import urllib.parse from docarray.document import Document - scheme = urllib.parse.urlparse(self.uri).scheme - loader = trimesh.load_remote if scheme in ['http', 'https'] else trimesh.load - - mesh = loader(self.uri, force='mesh') + mesh = self._load_mesh(force='mesh') vertices = mesh.vertices.view(np.ndarray) faces = mesh.faces.view(np.ndarray) self.chunks = [ - Document(name='vertices', tensor=vertices), - Document(name='faces', tensor=faces), + Document(name=Mesh.VERTICES, tensor=vertices), + Document(name=Mesh.FACES, tensor=faces), ] return self @@ -73,18 +92,18 @@ def load_vertices_and_faces_to_point_cloud(self: 'T', samples: int) -> 'T': :param samples: number of points to sample from the mesh :return: itself after processed """ - import trimesh - vertices = None faces = None for chunk in self.chunks: - if chunk.tags['name'] == 'vertices': + if chunk.tags['name'] == Mesh.VERTICES: vertices = chunk.tensor - if chunk.tags['name'] == 'faces': + if chunk.tags['name'] == Mesh.FACES: faces = chunk.tensor if vertices is not None and faces is not None: + import trimesh + mesh = trimesh.Trimesh(vertices=vertices, faces=faces) self.tensor = np.array(mesh.sample(samples)) else: diff --git a/docarray/document/mixins/plot.py b/docarray/document/mixins/plot.py index ad89f95de76..89324cb3292 100644 --- a/docarray/document/mixins/plot.py +++ b/docarray/document/mixins/plot.py @@ -3,7 +3,7 @@ import numpy as np -from docarray.helper import deprecate_by +from docarray.document.mixins.mesh import Mesh class PlotMixin: @@ -73,26 +73,81 @@ def _plot_recursion(self, tree=None): def display(self, from_: Optional[str] = None): """ Plot image data from :attr:`.uri` or from :attr:`.tensor` if :attr:`.uri` is empty . 
def _is_3d(self) -> bool:
    """Tell whether this Document holds 3D data (mesh or point cloud).

    A Document is treated as 3D when any of the following holds:

    - :attr:`.uri` ends with one of ``Mesh.FILE_EXTENSIONS``;
    - :attr:`.tensor` is a 2-D array whose second dimension has length 3;
    - :attr:`.chunks` contains both a chunk whose ``'name'`` tag is ``Mesh.VERTICES``
      and one whose ``'name'`` tag is ``Mesh.FACES``.

    :return: ``True`` if one of the conditions above is met, ``False`` otherwise.
    """
    if self.uri and self.uri.endswith(tuple(Mesh.FILE_EXTENSIONS)):
        return True

    tensor = self.tensor
    # ``ndim`` must be checked before indexing ``shape[1]``: a 0-D or 1-D tensor
    # has no second dimension and would raise IndexError otherwise.
    if tensor is not None and tensor.ndim == 2 and tensor.shape[1] == 3:
        return True

    if self.chunks:
        # ``.get`` keeps chunks without a 'name' tag from raising KeyError; they are
        # simply not mesh parts.
        name_tags = [c.tags.get('name') for c in self.chunks]
        return Mesh.VERTICES in name_tags and Mesh.FACES in name_tags

    # Always answer with a real bool instead of falling through with None.
    return False
trimesh.Trimesh(vertices=vertices, faces=faces) + display(mesh.show()) + - def display_tensor(self): + def display_tensor(self) -> None: """Plot image data from :attr:`.tensor`""" if self.tensor is None: raise ValueError( diff --git a/docs/datatypes/mesh/index.md b/docs/datatypes/mesh/index.md index 27c52825767..41289eb5f1f 100644 --- a/docs/datatypes/mesh/index.md +++ b/docs/datatypes/mesh/index.md @@ -7,16 +7,19 @@ This feature requires `trimesh`. You can install it via `pip install "docarray[f A 3D mesh is the structural build of a 3D model consisting of polygons. Most 3D meshes are created via professional software packages, such as commercial suites like Unity, or the free open source Blender 3D. +DocArray supports .obj, .glb and .ply files. + ## Vertices and faces representation A 3D mesh can be represented by its vertices and faces. Vertices are points in a 3D space, represented as a tensor of shape (n_points, 3). Faces are triangular surfaces that can be defined by three points in 3D space, corresponding to the three vertices of a triangle. Faces can be represented as a tensor of shape (n_faces, 3). Each number in that tensor refers to an index of a vertex in the tensor of vertices. + In DocArray, you can load a mesh and save its vertices and faces to a Document's `.chunks` as follows: ```python from docarray import Document -doc = Document(uri='viking.glb').load_uri_to_vertices_and_faces() +doc = Document(uri='mesh_man.glb').load_uri_to_vertices_and_faces() doc.summary() ``` @@ -28,7 +31,7 @@ doc.summary() └─ ``` -This stores the vertices and faces in `.tensor` of two separate sub-Documents in a Document's `.chunks`. Each sub-Document has a name assigned to it ('vertices' or 'faces'), which is saved in `.tags`: +This stores the vertices and faces in `.tensor` of two separate sub-Documents in a Document's `.chunks`. 
Each sub-Document has a name assigned to it ('vertices' or 'faces'), which is saved in `.tags`: ```python for chunk in doc.chunks: @@ -40,20 +43,1269 @@ chunk.tags = {'name': 'vertices'} chunk.tags = {'name': 'faces'} ``` -The following picture depicts a 3D mesh: -```{figure} 3dmesh-man.gif -:width: 50% +You can display your 3d object and interact with it via: +```python +doc.display() ``` + + + ## Point cloud representation -A point cloud is a representation of a 3D mesh. It is made by repeatedly and uniformly sampling points within the 3D body. Compared to the mesh representation, the point cloud is a fixed size ndarray and hence easier for deep learning algorithms to handle. In DocArray, you can simply load a 3D mesh and convert it into a point cloud of size `samples` via: +A point cloud is a representation of a 3D mesh. It is made by repeatedly and uniformly sampling points within the surface of the 3D body. Compared to the mesh representation, the point cloud is a fixed size ndarray and hence easier for deep learning algorithms to handle. In DocArray, you can simply load a 3D mesh and convert it into a point cloud of size `samples` via: ```python from docarray import Document -doc = Document(uri='viking.glb').load_uri_to_point_cloud_tensor(samples=1000) +doc = Document(uri='mesh_man.glb').load_uri_to_point_cloud_tensor(samples=30000) print(doc.tensor.shape) ``` @@ -62,8 +1314,1256 @@ print(doc.tensor.shape) (1000, 3) ``` -The following picture depicts a point cloud with 1000 samples from the previously depicted 3D mesh. 
+You can display your 3d object and interact with it via: -```{figure} pointcloud-man.gif -:width: 50% +```python +doc.display() ``` + +=12.0.0', 'jina-hubble-sdk>=0.13.1'], + install_requires=['numpy', 'rich>=12.0.0', 'jina-hubble-sdk>=0.24.0'], extras_require={ # req usage, please see https://docarray.jina.ai/#install 'common': [ diff --git a/tests/unit/document/test_converters.py b/tests/unit/document/test_converters.py index d7e24496afb..c67a507a3b7 100644 --- a/tests/unit/document/test_converters.py +++ b/tests/unit/document/test_converters.py @@ -6,9 +6,11 @@ from docarray import Document from docarray.document.generators import from_files +from docarray.document.mixins.mesh import Mesh __windows__ = sys.platform == 'win32' + cur_dir = os.path.dirname(os.path.abspath(__file__)) @@ -258,6 +260,7 @@ def test_convert_uri_to_data_uri(uri, mimetype): @pytest.mark.parametrize( 'uri, chunk_num', [ + (os.path.join(cur_dir, 'toydata/cube.ply'), 1), (os.path.join(cur_dir, 'toydata/test.glb'), 1), ( 'https://github.com/jina-ai/docarray/raw/main/tests/unit/document/toydata/test.glb', @@ -276,19 +279,33 @@ def test_glb_converters(uri, chunk_num): assert doc.chunks[0].tensor.shape == (2000, 3) -@pytest.mark.parametrize('uri', [(os.path.join(cur_dir, 'toydata/test.glb'))]) +@pytest.mark.parametrize( + 'uri', + [ + (os.path.join(cur_dir, 'toydata/cube.ply')), + (os.path.join(cur_dir, 'toydata/test.glb')), + (os.path.join(cur_dir, 'toydata/tetrahedron.obj')), + ], +) def test_load_uri_to_vertices_and_faces(uri): doc = Document(uri=uri) doc.load_uri_to_vertices_and_faces() assert len(doc.chunks) == 2 - assert doc.chunks[0].tags['name'] == 'vertices' + assert doc.chunks[0].tags['name'] == Mesh.VERTICES assert doc.chunks[0].tensor.shape[1] == 3 - assert doc.chunks[1].tags['name'] == 'faces' + assert doc.chunks[1].tags['name'] == Mesh.FACES assert doc.chunks[1].tensor.shape[1] == 3 -@pytest.mark.parametrize('uri', [(os.path.join(cur_dir, 'toydata/test.glb'))]) 
+@pytest.mark.parametrize( + 'uri', + [ + (os.path.join(cur_dir, 'toydata/cube.ply')), + (os.path.join(cur_dir, 'toydata/test.glb')), + (os.path.join(cur_dir, 'toydata/tetrahedron.obj')), + ], +) def test_load_vertices_and_faces_to_point_cloud(uri): doc = Document(uri=uri) doc.load_uri_to_vertices_and_faces() @@ -298,7 +315,14 @@ def test_load_vertices_and_faces_to_point_cloud(uri): assert isinstance(doc.tensor, np.ndarray) -@pytest.mark.parametrize('uri', [(os.path.join(cur_dir, 'toydata/test.glb'))]) +@pytest.mark.parametrize( + 'uri', + [ + (os.path.join(cur_dir, 'toydata/cube.ply')), + (os.path.join(cur_dir, 'toydata/test.glb')), + (os.path.join(cur_dir, 'toydata/tetrahedron.obj')), + ], +) def test_load_to_point_cloud_without_vertices_faces_set_raise_warning(uri): doc = Document(uri=uri) diff --git a/tests/unit/document/toydata/cube.ply b/tests/unit/document/toydata/cube.ply new file mode 100644 index 00000000000..681156a7fc4 --- /dev/null +++ b/tests/unit/document/toydata/cube.ply @@ -0,0 +1,24 @@ +ply +format ascii 1.0 +comment created by platoply +element vertex 8 +property float32 x +property float32 y +property float32 z +element face 6 +property list uint8 int32 vertex_indices +end_header +-1 -1 -1 +1 -1 -1 +1 1 -1 +-1 1 -1 +-1 -1 1 +1 -1 1 +1 1 1 +-1 1 1 +4 0 1 2 3 +4 5 4 7 6 +4 6 2 1 5 +4 3 7 4 0 +4 7 3 2 6 +4 5 1 0 4 diff --git a/tests/unit/document/toydata/tetrahedron.mtl b/tests/unit/document/toydata/tetrahedron.mtl new file mode 100644 index 00000000000..1bccd4474e4 --- /dev/null +++ b/tests/unit/document/toydata/tetrahedron.mtl @@ -0,0 +1,22 @@ + +newmtl red +Ka 0.4449 0.0000 0.0000 +Kd 0.7714 0.0000 0.0000 +Ks 0.8857 0.0000 0.0000 +illum 2 +Ns 136.4300 + +newmtl lime +Ka 0.0000 0.5000 0.0000 +Kd 0.0000 1.0000 0.0000 +Ks 0.0000 0.5000 0.0000 +illum 2 +Ns 65.8900 + +newmtl gold +Ka 0.5265 0.2735 0.0122 +Kd 1.0000 0.5184 0.0286 +Ks 0.3000 0.3000 0.3000 +illum 2 +Ns 123.2600 + diff --git a/tests/unit/document/toydata/tetrahedron.obj 
b/tests/unit/document/toydata/tetrahedron.obj new file mode 100644 index 00000000000..40347bad7b7 --- /dev/null +++ b/tests/unit/document/toydata/tetrahedron.obj @@ -0,0 +1,20 @@ +# tetrahedron.obj +# + +mtllib tetrahedron.mtl + +g tetrahedron + +v 1.00 1.00 1.00 +v 2.00 1.00 1.00 +v 1.00 2.00 1.00 +v 1.00 1.00 2.00 + +usemtl lime +f 1 3 2 +usemtl gold +f 1 4 3 +usemtl lime +f 1 2 4 +usemtl red +f 2 3 4