The PyTorch3D Framework

Learn the core features of the PyTorch3D library for PyTorch and its uses for 3D deep learning.

Overview

PyTorch is one of the leading deep learning frameworks for research-grade and production-grade projects. Its ease of use, Pythonic syntax, and extensibility make it a popular choice among machine learning practitioners, including those doing 3D machine learning (3D ML) research.

Readers with no PyTorch experience should still be able to follow along, provided they understand the basics of machine learning. We'll also be using PyTorch3D, a library built atop PyTorch specifically for 3D deep learning data.

What is PyTorch3D?

PyTorch3D is a library built atop PyTorch that provides GPU-optimized implementations of common components used in 3D deep learning and computer vision. It includes efficient heterogeneous batching operators (for batches of meshes or point clouds whose sizes differ), a differentiable mesh renderer, common loss functions, I/O support for common 3D formats such as OBJ, OFF, PLY, and glTF, and even early support for implicit representations for novel view synthesis, such as Neural Radiance Fields (NeRFs). We can use it to solve many problems, including mesh deformation, bundle adjustment, learning mesh textures from images, fitting NeRFs, and more.

Overview of the PyTorch3D API

The API covers every stage of the 3D machine learning lifecycle, including dataset loaders (helper classes that support iteration over a dataset), common operations, renderers (functions that convert 3D data into images), and loss functions. Each of the major modules is introduced below.

Figure: Overview of the PyTorch3D API

The datasets module

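This module includes dataset classes for common data sources such as ShapeNetCore and R2N2, along with helpers like collate_batched_meshes for batching their meshes with a PyTorch DataLoader.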

Note: A PyTorch DataLoader class is a helpful abstraction over datasets that implements loading and iterating over datasets. ShapeNetCore and R2N2 are two commonly used 3D datasets.

Here we simply print every public name the module exports (its __all__ list), including functions, classes, and variables.

import pytorch3d.datasets
print(pytorch3d.datasets.__all__)

The io module

OBJ, OFF, and PLY are three common 3D file formats. PyTorch3D implements file loaders such as load_obj and load_ply, as well as an IO class whose load_mesh method handles these formats through a single interface. It also includes a batch loader, load_objs_as_meshes, for loading a list of OBJ files at once. When working with external data, you will likely need at least one of the helper functions in this module.

import pytorch3d.io
print(pytorch3d.io.__all__)

The loss module

Many commonly used 3D ML loss functions are implemented here, such as chamfer_distance, mesh_laplacian_smoothing, mesh_normal_consistency, mesh_edge_loss, and the point-to-mesh distances point_mesh_edge_distance and point_mesh_face_distance.

import pytorch3d.loss
print(pytorch3d.loss.__all__)

The ops module

This module includes many operations that are useful for constructing 3D ML model architectures, such as graph convolution (GraphConv), Perspective-n-Point (efficient_pnp), iterative closest point (iterative_closest_point), nearest-neighbor search (knn_points), and many more.

Note: Graph convolution is a variant of the convolution operation that is applied to graph-structured data, such as the vertices and edges of a mesh. The Perspective-n-Point algorithm estimates the pose of a camera given a set of 3D points and their corresponding 2D pixel coordinates. Iterative closest point is a technique for aligning a pair of point clouds.

import pytorch3d.ops
print(pytorch3d.ops.__all__)

The renderer module

The renderer module is where much of the magic happens in 3D ML. It provides a set of GPU-optimized differentiable rendering classes that we'll often use to produce 2D renders of our 3D data. These include shading, texturing, lighting, cameras, and rasterizers for both meshes and point clouds. Because the renderers are differentiable, we can propagate gradients from the output renders all the way back to the components of our scene, such as object position and rotation, texture colors, lighting values, and more.

import pytorch3d.renderer
print(pytorch3d.renderer.__all__)

The structures module

Meshes and point clouds are two of the data structures used to represent different kinds of 3D data. Meshes are often used for authored 3D models, while point clouds are often produced by 3D capture methods such as LiDAR scanning or photogrammetry. The structures module includes the Meshes and Pointclouds classes used to represent our 3D data. It also contains most of the batching logic, along with helpers that convert between the heterogeneous batch representations (list, packed, and padded).

import pytorch3d.structures
print(pytorch3d.structures.__all__)

The transforms module

Last but not least, this module is used to manipulate the position (Translate), orientation (Rotate), and scale (Scale) of elements in our 3D scene. It also converts between common rotation representations, such as Euler angles, quaternions, axis-angle vectors, and rotation matrices, and provides exponential and logarithm maps for the rotation group SO(3).

import pytorch3d.transforms
print(pytorch3d.transforms.__all__)