The PyTorch3D Framework
Learn the core features of the PyTorch3D library for PyTorch and its uses for 3D deep learning.
Overview
PyTorch is one of the leading deep learning frameworks for research-grade and production-grade projects. Its ease of use, Pythonic syntax, and extensibility make it a favorite among machine learning practitioners for 3D machine learning (3D ML) research.
Readers with no PyTorch experience should still find it easy to follow along, given a basic understanding of machine learning. We’ll also be using PyTorch3D, a library built atop PyTorch specifically to handle 3D deep learning data.
What is PyTorch3D?
PyTorch3D is an API built atop PyTorch with GPU-optimized implementations of common components used in 3D deep learning and computer vision. It includes efficient heterogeneous batching operators, a differentiable mesh renderer, common loss functions, I/O support for common 3D formats such as OBJ, OFF, PLY, and glTF, and even early support for implicit representations for novel view synthesis like Neural Radiance Fields (NeRFs). We can solve many problems with it, including mesh deformation, bundle adjustment, learning mesh textures from images, fitting NeRFs, and more.
Overview of the PyTorch3D API
The API covers all of the stages of the 3D machine learning lifecycle through the modules described below.
The datasets module
This includes Dataset classes for common data sources like ShapeNetCore and R2N2, which can be wrapped in a PyTorch DataLoader.
Note: A PyTorch DataLoader class is a helpful abstraction over datasets that implements loading and iterating over datasets. ShapeNetCore and R2N2 are two commonly used 3D datasets.
Here we simply list everything contained in this module, including functions, classes, and variables.
import pytorch3d.datasets
print(pytorch3d.datasets.__all__)
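To make the dataset-and-loader pattern described in the note concrete, here is a minimal pure-Python sketch. The names ToyShapeDataset and iterate_batches are hypothetical and invented for illustration; they only mimic the index-then-batch pattern and are not part of PyTorch or PyTorch3D.

```python
class ToyShapeDataset:
    """Hypothetical stand-in for a 3D dataset: each item has a label and vertices."""

    def __init__(self, shapes):
        self._shapes = shapes

    def __len__(self):
        return len(self._shapes)

    def __getitem__(self, idx):
        label, verts = self._shapes[idx]
        return {"label": label, "verts": verts}


def iterate_batches(dataset, batch_size):
    """Yield lists of consecutive items, the way a loader groups samples."""
    for start in range(0, len(dataset), batch_size):
        stop = min(start + batch_size, len(dataset))
        yield [dataset[i] for i in range(start, stop)]


# Three toy shapes with different vertex counts (assumed data, not a real dataset).
shapes = [
    ("triangle", [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 1.0, 0.0)]),
    ("segment", [(0.0, 0.0, 0.0), (1.0, 1.0, 1.0)]),
    ("point", [(0.5, 0.5, 0.5)]),
]
dataset = ToyShapeDataset(shapes)
batches = list(iterate_batches(dataset, batch_size=2))
print([[item["label"] for item in batch] for batch in batches])
```

Because the items have different vertex counts, they cannot simply be stacked into one array. That is why 3D datasets usually need a custom collate step when used with a DataLoader.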
The io module
OBJ, OFF, and PLY are a few of the standard 3D data formats. PyTorch3D implements file loaders such as load_obj and load_ply to support these formats. It also includes a batch loader, load_objs_as_meshes, for loading a list of OBJ files at once. When using external data, at least one of the helper functions in this module will likely be useful.
import pytorch3d.io
print(pytorch3d.io.__all__)
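To give a feel for what an OBJ loader has to do, the following is a minimal pure-Python sketch that reads only vertex ("v") and face ("f") records. The function parse_obj_text is a hypothetical helper written for this illustration. It ignores normals, texture coordinates, materials, and negative indices, all of which a real loader must handle.

```python
def parse_obj_text(text):
    """Parse 'v' and 'f' records from OBJ text into vertices and 0-based faces."""
    verts, faces = [], []
    for line in text.splitlines():
        parts = line.split()
        if not parts or parts[0].startswith("#"):
            continue  # skip blank lines and comments
        if parts[0] == "v":
            verts.append(tuple(float(c) for c in parts[1:4]))
        elif parts[0] == "f":
            # A face token may look like "3", "3/1", or "3/1/2"; keep the
            # vertex index only. OBJ indices start at 1, so shift to 0-based.
            faces.append(tuple(int(tok.split("/")[0]) - 1 for tok in parts[1:]))
    return verts, faces


obj_text = """
# a single triangle
v 0.0 0.0 0.0
v 1.0 0.0 0.0
v 0.0 1.0 0.0
f 1 2 3
"""
verts, faces = parse_obj_text(obj_text)
print(len(verts), faces)
```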
The loss module
Many commonly used 3D ML loss functions are implemented here, such as chamfer_distance, mesh_laplacian_smoothing, point_mesh_distance, and mesh_normal_consistency.
import pytorch3d.loss
print(pytorch3d.loss.__all__)
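As an illustration of what a point cloud loss measures, here is a minimal pure-Python sketch of one common definition of the Chamfer distance: for each point, take the squared distance to its nearest neighbor in the other cloud, average per cloud, and add the two directions. The function chamfer_distance_sketch is a hypothetical name for this illustration. The library version works on batched tensors and exposes reduction options, so its values may differ from this sketch.

```python
def squared_distance(p, q):
    """Squared Euclidean distance between two points."""
    return sum((a - b) ** 2 for a, b in zip(p, q))


def one_sided(src, dst):
    """Mean over src of the squared distance to the nearest point in dst."""
    return sum(min(squared_distance(p, q) for q in dst) for p in src) / len(src)


def chamfer_distance_sketch(cloud_a, cloud_b):
    """Symmetric Chamfer distance: sum of both one-sided terms."""
    return one_sided(cloud_a, cloud_b) + one_sided(cloud_b, cloud_a)


cloud_a = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)]
cloud_b = [(0.0, 0.0, 0.0), (1.0, 0.0, 1.0)]
distance = chamfer_distance_sketch(cloud_a, cloud_b)
print(distance)
```

The brute-force nearest-neighbor search here is quadratic in the number of points, which is one reason GPU-optimized implementations matter for real data.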
The ops module
This module includes many operations that are useful for constructing 3D ML model architectures, such as graph convolution with graph_conv, perspective_n_points (PnP), iterative_closest_point (ICP), and many more.
Note: Graph convolution is a variation of the convolution operation that is applied to graph data. The Perspective-n-Point algorithm estimates the pose of a camera given a set of 3D points and their corresponding 2D pixel coordinates. Iterative closest point is a technique for aligning a pair of point clouds.
import pytorch3d.ops
print(pytorch3d.ops.__all__)
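To show what a convolution over graph data can look like, here is a minimal pure-Python sketch of one common formulation: each vertex's output is a transform of its own features plus the sum of a second transform applied to its neighbors' features. The function graph_conv_sketch and the fixed weight matrices are assumptions made for this illustration. In a real layer the weights are learned, and the library's exact formulation may differ.

```python
def matvec(matrix, vector):
    """Multiply a small matrix (list of rows) by a vector."""
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]


def graph_conv_sketch(features, edges, w_self, w_neighbor):
    """out_i = w_self @ f_i + sum over neighbors j of (w_neighbor @ f_j)."""
    out = [matvec(w_self, f) for f in features]
    for i, j in edges:
        # Undirected graph: each endpoint receives a message from the other.
        for dst, src in ((i, j), (j, i)):
            message = matvec(w_neighbor, features[src])
            out[dst] = [o + m for o, m in zip(out[dst], message)]
    return out


# Three vertices on a path 0 - 1 - 2, each with a 2D feature vector.
features = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
edges = [(0, 1), (1, 2)]
identity = [[1.0, 0.0], [0.0, 1.0]]
half = [[0.5, 0.0], [0.0, 0.5]]
out = graph_conv_sketch(features, edges, identity, half)
print(out)
```

For meshes, the edges typically come from the mesh connectivity, so each vertex aggregates information from the vertices it shares an edge with.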
The renderer module
The renderer module is where much of the magic happens in 3D ML. It provides a set of GPU-optimized differentiable rendering classes that we’ll often use to produce 2D renders of our 3D data. This includes shading, texturing, lighting, cameras, and rasterizers for meshes and point cloud data. Because these renderers are differentiable, we can propagate gradients all the way from the output renders down to the components of our scene, such as object position and rotation, texture colors, lighting values, and more.
import pytorch3d.renderer
print(pytorch3d.renderer.__all__)
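The idea of propagating gradients from an image back to scene parameters can be illustrated with a toy one-dimensional example. The sketch below is a deliberately simplified stand-in for a differentiable renderer, not PyTorch3D code: a single point is projected with a pinhole model, and its lateral position is recovered by gradient descent on a pixel error. The names project and loss_and_grad are hypothetical, and the gradient is derived by hand rather than by autograd.

```python
def project(x, z, focal_length):
    """Pinhole projection of a point at lateral offset x and depth z."""
    return focal_length * x / z


def loss_and_grad(x, z, focal_length, target_pixel):
    """Squared pixel error and its derivative with respect to x."""
    pixel = project(x, z, focal_length)
    error = pixel - target_pixel
    loss = error ** 2
    # Chain rule: d(loss)/dx = d(loss)/d(pixel) * d(pixel)/dx
    grad_x = 2.0 * error * (focal_length / z)
    return loss, grad_x


# Assumed toy setup: the point starts at x = 0 and should land on pixel 0.5.
x, z, focal_length, target_pixel = 0.0, 2.0, 1.0, 0.5
learning_rate = 1.0
for _ in range(50):
    loss, grad_x = loss_and_grad(x, z, focal_length, target_pixel)
    x -= learning_rate * grad_x  # move the point to reduce the pixel error

print(round(x, 4), round(project(x, z, focal_length), 4))
```

A full differentiable renderer applies the same principle to every pixel and every scene parameter at once, with the derivatives computed automatically.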
The structures module
Meshes and point clouds are two of the data structures used to represent different kinds of 3D data. Meshes are often useful for authored 3D models. Point clouds are often produced by 3D capture methods such as LiDAR or photogrammetry. The structures module includes the Meshes and Pointclouds classes used to represent our 3D data. It also contains the bulk of the batching logic, along with helper functions to convert between the various types of heterogeneous batching.
import pytorch3d.structures
print(pytorch3d.structures.__all__)
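Heterogeneous batching means storing items of different sizes in one batch. The following pure-Python sketch shows the three layouts usually involved: a list of per-item arrays, a padded layout where every item is extended to the longest length, and a packed layout where everything is concatenated and the start offsets are recorded. The helpers to_padded and to_packed are hypothetical names for this illustration. The Meshes class exposes analogous views through methods such as verts_list(), verts_padded(), and verts_packed(), which operate on tensors.

```python
def to_padded(verts_list, pad_value=(0.0, 0.0, 0.0)):
    """Pad every item to the length of the longest item."""
    max_len = max(len(verts) for verts in verts_list)
    return [list(verts) + [pad_value] * (max_len - len(verts)) for verts in verts_list]


def to_packed(verts_list):
    """Concatenate all items and record where each item starts."""
    packed, first_idx = [], []
    for verts in verts_list:
        first_idx.append(len(packed))
        packed.extend(verts)
    return packed, first_idx


# List layout: two shapes with 3 and 2 vertices respectively.
verts_list = [
    [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 1.0, 0.0)],
    [(0.0, 0.0, 1.0), (1.0, 1.0, 1.0)],
]
padded = to_padded(verts_list)
packed, first_idx = to_packed(verts_list)
print(len(padded[0]), len(padded[1]), len(packed), first_idx)
```

The padded layout is convenient for operations that expect a regular batch dimension, while the packed layout avoids wasting memory and compute on padding.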
The transforms module
Last but not least, this module is used to manipulate the position (Translate), orientation (Rotate), and scale (Scale) of elements in our 3D scene. It supports many of the common rotation representations, such as Euler angles, quaternions, and SO(3) rotations.
import pytorch3d.transforms
print(pytorch3d.transforms.__all__)
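To show how scale, rotation, and translation combine, here is a minimal pure-Python sketch that converts a unit quaternion to a rotation matrix and applies scale, then rotation, then translation to a point. The functions quaternion_to_matrix_sketch and apply_srt are hypothetical helpers for this illustration. The sketch assumes a real-part-first quaternion (w, x, y, z) and a column-vector convention; the library's own conventions for composing and applying transforms may differ, so check its documentation before porting this.

```python
import math


def quaternion_to_matrix_sketch(q):
    """Convert a unit quaternion (w, x, y, z) to a 3x3 rotation matrix."""
    w, x, y, z = q
    return [
        [1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)],
        [2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)],
        [2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)],
    ]


def apply_srt(point, scale, rotation, translation):
    """Scale, then rotate, then translate a 3D point."""
    scaled = [scale * c for c in point]
    rotated = [sum(r * c for r, c in zip(row, scaled)) for row in rotation]
    return [r + t for r, t in zip(rotated, translation)]


# A 90-degree rotation about the z-axis, expressed as a quaternion.
half_angle = math.pi / 4
q = (math.cos(half_angle), 0.0, 0.0, math.sin(half_angle))
rotation = quaternion_to_matrix_sketch(q)
moved = apply_srt((1.0, 0.0, 0.0), scale=2.0, rotation=rotation, translation=(0.0, 0.0, 1.0))
print([round(c, 6) for c in moved])
```

Here the point (1, 0, 0) is doubled in length, turned from the x-axis onto the y-axis, and then lifted one unit along z.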