The Practical Guide to Advanced PyTorch

Published on January 28, 2026

Introduction

Mastering PyTorch no longer means familiarity with features; it’s about operating a repeatable engineering workflow where training code stays fast, scalable, and recoverable under realistic production workloads. Features like torch.compile, torch.profiler, DDP/FSDP, and Distributed Checkpointing are powerful weapons in the ML and infra teams’ arsenal for keeping training fast. However, they’re meaningless if not applied in the right sequence and validated rigorously.

In this article, we share a recommended workflow path: baseline → compile → profile → scale → checkpoint. We’ll cover what to measure before optimizing, common pitfalls to avoid with the compiler & profiler, decision rules to follow for DDP vs FSDP, and how to implement fault-tolerant checkpointing for multi-node training runs.

Key Takeaways

  • Approach PyTorch performance tuning as an iterative engineering process rather than a feature checklist: Iteratively applying baseline → compile → profile → scale → checkpoint leads to more sustainable performance than enabling optimizations in isolation.
  • Never optimize before establishing a correct baseline: Without a stable single-GPU eager baseline with measured throughput and verified correctness, you cannot meaningfully compare performance or debug issues.
  • Use torch.compile deliberately, not blindly: Keep track of graph breaks, manage shapes, warm up before measuring, and validate that steady-state performance actually improves over eager execution.
  • Profile to guide decisions, not to confirm assumptions: Use torch.profiler to identify CPU stalls, kernel hotspots, and shape-driven recompilation, and to validate communication overhead in distributed runs.
  • Design checkpointing for failure from day one: Distributed Checkpointing (with optional async saves) must capture full training state, support resharding across GPU counts, and be routinely tested via restore drills to ensure real fault tolerance.

Baseline: Establish a Reference Point

Start with a working single-GPU training example. This will be your reference point for both functionality and performance. Define the model, dataloader, and the training loop. Ensure everything runs end-to-end in eager mode (no compilation, one process). Here’s an example of what a single training step might look like:

import torch
import torch.nn as nn
# Dummy model and data for illustration
model = nn.Sequential(nn.Linear(100, 50), nn.ReLU(), nn.Linear(50, 10))
data = torch.randn(32, 100)
targets = torch.randint(0, 10, (32,))
# Baseline forward + backward
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
outputs = model(data)
loss = nn.CrossEntropyLoss()(outputs, targets)
loss.backward()
optimizer.step()

Execute a few iterations of training and measure throughput (for example, samples/sec). Having a correct baseline ensures your model converges and provides a point of comparison for future optimizations. At this point, correctness and baseline performance are the priorities: ensure the GPU is being utilized (look at nvidia-smi or torch logs), and that there are no obvious bottlenecks (the data loader stalling the GPU, excessive CPU ops, etc.). Don’t move on to compilation and other optimizations until you have a stable baseline.
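A minimal sketch of how the throughput measurement might look, assuming a synthetic batch and a hypothetical helper named measure_throughput (it is not part of PyTorch). The warm-up and iteration counts are illustrative. The torch.cuda.synchronize() calls make sure queued GPU work has finished before the timer is read.

```python
import time

import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Sequential(nn.Linear(100, 50), nn.ReLU(), nn.Linear(50, 10)).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

batch_size = 32
data = torch.randn(batch_size, 100, device=device)
targets = torch.randint(0, 10, (batch_size,), device=device)


def train_step():
    """One eager forward + backward + optimizer update on the synthetic batch."""
    optimizer.zero_grad(set_to_none=True)
    loss = loss_fn(model(data), targets)
    loss.backward()
    optimizer.step()


def measure_throughput(step_fn, batch_size, warmup=3, iters=20):
    """Hypothetical helper: returns samples/sec over `iters` timed steps."""
    for _ in range(warmup):  # exclude one-time initialization from the timing
        step_fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for queued kernels before starting the clock
    start = time.perf_counter()
    for _ in range(iters):
        step_fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for the timed kernels to finish
    elapsed = time.perf_counter() - start
    return batch_size * iters / elapsed


samples_per_sec = measure_throughput(train_step, batch_size)
print(f"Eager baseline: {samples_per_sec:.1f} samples/sec")
```

Recording this number alongside the commit and hardware gives later compile and scaling experiments something concrete to compare against.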

Baseline Checklist

  • Functional correctness: Model trains and produces expected results on a single GPU.
  • Basic performance logged: e.g., time per batch or GPU utilization percentage.
  • No obvious bottlenecks: Data pipeline keeps GPU busy (no long idle times in baseline).

Once you’ve implemented a good baseline, you can start adding advanced features incrementally, starting with PyTorch 2.x’s flagship compiler to accelerate training.

Compile: Accelerate with torch.compile in PyTorch

PyTorch 2 introduced torch.compile, which JIT-compiles models for performance. Wrap your model or function with torch.compile, and PyTorch generates optimized code behind the scenes.

Check out how easy it is below:

# Switch to compiled mode
model = torch.compile(model)  # uses default backend 'inductor'

Calling model(data) will automatically use an optimized code path. The first few iterations will compile your model just-in-time; future iterations will use the optimized kernels directly. PyTorch 2.9 automatically caches compilation results to improve future runs (even between processes).

However, if you want to get the most performance out of compiles, there are a couple of things you should know about graph breaks and dynamic shapes.

Graph breaks

Graph breaks happen whenever the compiler is unable to capture a portion of your code into a single graph (such as data-dependent Python control flow). By default, torch.compile will automatically fall back to eager execution for unsupported parts (breaking the graph) and compile the rest. This ensures that your code won’t crash. However, each graph break prevents part of the code from being optimized (that portion runs purely in Python, not in the fused graph). torch.compile(fullgraph=True) is helpful for development-time debugging of graph breaks, as it throws an error if any part of the model cannot be compiled. For instance:

model = torch.compile(model, fullgraph=True)
try:
    model(data)
except Exception as e:
    print("Graph break:", e)

This raises an error on the first unsupported operation and points you towards the pieces that need to be refactored. The error message usually includes hints (or a URL) on what caused the break and how to avoid it. PyTorch logs are also useful: if you run your script with the environment variable TORCH_LOGS="graph_breaks", it prints the reasons and locations where graph breaks occur. Use this information to rewrite or remove the Python-side operations that prevent compilation (for example, replace Python list manipulation or data-dependent if/else code with tensor ops, or remove it entirely).

Dynamic shapes

By default, torch.compile specializes the compiled graph to the shapes it sees first. If your model runs with inputs of varying sizes, a new shape triggers a recompilation, and each recompilation adds overhead. torch.compile exposes a dynamic flag to control this behavior. Specifying dynamic=True up front attempts to generate a single kernel that can handle multiple shapes (symbolic shapes). For instance:

model = torch.compile(model, dynamic=True)

This instructs the compiler to trace a generalized graph, trading some specialization for fewer recompilations as sizes change (such as variable sequence lengths). On the other hand, passing dynamic=False never produces dynamic kernels and forces specialization to exact shapes (which can yield faster code when the shapes really are static). dynamic=None is the default in recent 2.x releases: the compiler starts by specializing and, once it detects a recompilation caused by a changed size, switches to a dynamic kernel on the fly.

  • Leave dynamic at default or set to False if the input shapes are constant or change rarely for maximal specialization.
  • Set dynamic=True if you’re getting frequent recompiles due to shape variations (NLP models running sentences of varying length, etc.) and want to avoid hitting the recompile limit (which is 8 by default before giving up and reverting to eager execution).

Compiled Example

Let’s compile a simple model and demonstrate handling graph breaks and dynamic shapes:

import torch
# Sample model with a potential graph break (data-dependent control flow)
class ToyModel(torch.nn.Module):
    def forward(self, x):
        # Example: data-dependent branching (causes a graph break in Dynamo)
        if x.sum() > 0:
            return x * 2
        else:
            return x
# Keep a handle on the uncompiled module so it can be recompiled cleanly
eager_model = ToyModel()
# Attempt full-graph compilation to catch breaks
model = torch.compile(eager_model, fullgraph=True)
try:
    # Compilation is lazy: the error surfaces on the first call, not in torch.compile()
    output = model(torch.randn(4, 4))
except Exception as e:
    print("Graph break detected:", e)
    # Rewrite model or accept partial graph compilation
    model = torch.compile(eager_model, fullgraph=False)  # fallback to allow breaks
    output = model(torch.randn(4, 4))

In the code above, the data-dependent if raises an error when fullgraph=True is set. We catch that error, print it, then recompile with the default behavior of allowing graph breaks. In practice, you would fix your model (remove the data-dependent branching in this case) until it compiles with fullgraph=True without exceptions, meaning the entire model can be captured as one graph for maximum speed.
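One way the branching in ToyModel could be removed is to express the choice as a tensor operation. The sketch below is an illustration under that assumption (BranchyToyModel and BranchFreeToyModel are names introduced here); it runs in eager mode and only checks that both versions agree. Passing the branch-free module to torch.compile(..., fullgraph=True) is the follow-up check you would run yourself.

```python
import torch


class BranchyToyModel(torch.nn.Module):
    def forward(self, x):
        # Python `if` on a tensor value: Dynamo cannot capture this in one graph
        if x.sum() > 0:
            return x * 2
        return x


class BranchFreeToyModel(torch.nn.Module):
    def forward(self, x):
        # Same result expressed as a tensor op; the 0-dim condition broadcasts over x
        return torch.where(x.sum() > 0, x * 2, x)


branchy, branch_free = BranchyToyModel(), BranchFreeToyModel()
x_pos = torch.ones(4, 4)   # sum > 0, so the output is doubled
x_neg = -torch.ones(4, 4)  # sum < 0, so the output is unchanged

same_pos = torch.equal(branchy(x_pos), branch_free(x_pos))
same_neg = torch.equal(branchy(x_neg), branch_free(x_neg))
print("Outputs match:", same_pos and same_neg)
```

The trade-off is that torch.where evaluates both candidates, so this rewrite suits cheap branches; expensive branches may be better handled by restructuring the model.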

Compile Checklist

  • Wrap the model with torch.compile(): Small code change that often leads to significant speedups. Keep the default "inductor" backend unless you have a reason not to.
  • Warm up: Run a few iterations before timing, because the first runs include compile time.
  • Check graph breaks: In development, try running with fullgraph=True and set the environment variable TORCH_LOGS="graph_breaks" to identify pieces of code that aren’t supported. Remove or refactor them.
  • Tune dynamic shapes: If you see many recompilations (look for recompilation messages in stdout/stderr/logs, for example with TORCH_LOGS="recompiles"), consider dynamic=True to generate a shape-flexible graph.
  • Monitor speedup: Measure throughput against your baseline. The first iteration will likely be slower, but steady-state speed should exceed your eager baseline. If not, run with performance hints enabled (e.g., set TORCH_LOGS="perf_hints").

After compiling our model and achieving faster performance, we should analyze its execution to identify any remaining bottlenecks and determine further areas for tuning.

Profile: Diagnose Bottlenecks with torch.profiler

Performance bottlenecks can still exist even after compilation - whether that be an underutilized GPU, I/O bottlenecks, or unoptimized kernels. PyTorch comes with a built-in Profiler to collect execution traces and identify these bottlenecks. In PyTorch 2.9, torch.profiler can trace CPU and GPU activities, record shapes, and integrate with Chrome Trace Viewer or TensorBoard for visualization. torch.profiler can also capture traces asynchronously. This way, you can profile portions of training without interrupting the program flow.

import torch
import torch.nn as nn
import torch.optim as optim
import torch.profiler
# -----------------------------
# Device setup
# -----------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
# -----------------------------
# Model definition
# -----------------------------
model = nn.Sequential(
    nn.Linear(100, 256),
    nn.ReLU(),
    nn.Linear(256, 10)
).to(device)
# -----------------------------
# Optimizer and loss
# -----------------------------
optimizer = optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()
# -----------------------------
# Data iterator (synthetic)
# -----------------------------
def data_generator(batch_size=32):
    while True:
        inputs = torch.randn(batch_size, 100, device=device)
        targets = torch.randint(0, 10, (batch_size,), device=device)
        yield inputs, targets
data_iter = data_generator()
# -----------------------------
# Training step
# -----------------------------
def train_step(batch):
    model.train()
    inputs, targets = batch
    optimizer.zero_grad(set_to_none=True)
    outputs = model(inputs)
    loss = loss_fn(outputs, targets)
    loss.backward()
    optimizer.step()
    return loss
# -----------------------------
# Warm-up phase
# -----------------------------
for _ in range(5):
    train_step(next(data_iter))
# -----------------------------
# Profiling phase
# -----------------------------
with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ],
    record_shapes=True,
    profile_memory=True,
) as prof:
    for _ in range(3):
        train_step(next(data_iter))
        prof.step()
# -----------------------------
# Export trace
# -----------------------------
prof.export_chrome_trace("trace.json")
print("Profiler trace saved to trace.json")

This example script shows how to define a minimal (but end-to-end) PyTorch training loop and use torch.profiler to capture steady-state CPU and GPU performance. It builds a toy neural network, creates synthetic data with an iterator, and wraps a training iteration into the train_step() function.

We run a few warm-up iterations first (so we don’t profile one-time initialization/cache overhead). Then we profile a few training steps, recording operator shapes and memory use. We also export a Chrome trace (trace.json), which can be loaded in the Chrome browser to inspect GPU utilization, kernel launches, CPU–GPU overlap, and potential performance bottlenecks. record_shapes=True and profile_memory=True are enabled to track input shapes and memory allocations, which helps when debugging OOM errors or inefficiencies.

Tip: Open chrome://tracing in Google Chrome and load the trace.json to get a timeline view.

For automatic bottleneck detection, you can use torch.utils.bottleneck or torch.profiler.schedule to capture snapshots. For example, profiling a few steps every epoch using a scheduler:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import torch.profiler
# -----------------------------
# Device
# -----------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
# -----------------------------
# Model
# -----------------------------
model = nn.Sequential(
    nn.Linear(100, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
).to(device)
# -----------------------------
# Optimizer and loss
# -----------------------------
optimizer = optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()
# -----------------------------
# Dataset and DataLoader
# -----------------------------
num_samples = 10_000
batch_size = 32
inputs = torch.randn(num_samples, 100)
targets = torch.randint(0, 10, (num_samples,))
dataset = TensorDataset(inputs, targets)
data_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)
# -----------------------------
# Training step
# -----------------------------
def train_step(batch):
    model.train()
    inputs, targets = batch
    inputs = inputs.to(device, non_blocking=True)
    targets = targets.to(device, non_blocking=True)
    optimizer.zero_grad(set_to_none=True)
    outputs = model(inputs)
    loss = loss_fn(outputs, targets)
    loss.backward()
    optimizer.step()
    return loss
# -----------------------------
# Warm-up (outside profiler)
# -----------------------------
for i, batch in enumerate(data_loader):
    train_step(batch)
    if i >= 5:
        break
# -----------------------------
# Profiling with schedule
# -----------------------------
with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ],
    schedule=torch.profiler.schedule(
        wait=1,
        warmup=1,
        active=3,
        repeat=2,
    ),
    on_trace_ready=torch.profiler.tensorboard_trace_handler("./prof_log"),
    record_shapes=True,
    profile_memory=True,
    with_stack=True,
) as prof:
    for step, batch in enumerate(data_loader):
        loss = train_step(batch)
        prof.step()

        if step >= 50:
            break

This captures two trace windows of 3 active steps each. Every cycle first skips 1 wait step and 1 warmup step, and after the second cycle the profiler stays idle for the rest of the 50-step loop. The on_trace_ready callback writes each completed trace to ./prof_log, where you can view it in TensorBoard’s profile tab without stopping training. Capturing short windows like this reduces profiler overhead while still collecting representative performance data.
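To see which steps a schedule actually records, the schedule object can be called directly with a step index. A minimal sketch, assuming the same wait=1, warmup=1, active=3, repeat=2 settings as above; the 12-step range is arbitrary.

```python
from torch.profiler import ProfilerAction, schedule

# Same settings as the profiling example above
sched = schedule(wait=1, warmup=1, active=3, repeat=2)

# The schedule is a callable mapping a step index to a ProfilerAction
actions = [sched(step) for step in range(12)]
for step, action in enumerate(actions):
    print(f"step {step:2d}: {action.name}")

# Steps on which the profiler is actively recording
recorded_steps = [
    step
    for step, action in enumerate(actions)
    if action in (ProfilerAction.RECORD, ProfilerAction.RECORD_AND_SAVE)
]
print("Recorded steps:", recorded_steps)
```

Printing the plan before a long run is a cheap way to confirm that the windows fall on steady-state steps rather than on warm-up or compilation.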

After profiling, interpret the results to take action:

  • CPU-bound (many gaps where the GPU is waiting): Consider whether you can push more work to the GPU, or overlap data loading with computation using asynchronous data loading (e.g., passing num_workers to DataLoader or preprocessing on the GPU). Also check for Python loops doing work during training; if you find any, try to remove them or cover them with torch.compile.
  • GPU kernels are the bottleneck: You can optimize those operations (fused ops, lower precision, etc.). PyTorch’s performance guide recommends sticking to high-level ops (such as torch.nn.functional calls) to allow the libraries to optimize them. However, if you notice a particular op that is orders of magnitude slower, determine whether that is expected (large FFTs can be slow) or whether there are alternative ways to compute it.
  • Multiple shapes causing retracing: Consider bucketing inputs by size, or use dynamic=True as mentioned above.
  • Ensure there aren’t external bottlenecks: For example, if using distributed training, verify that communication (all-reduce) isn’t the bottleneck - this would be indicated by NCCL kernels or CPU waiting on the network.
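The bucketing suggestion in the list above can be sketched as follows. The bucket boundaries, the pad id, and the helper names bucket_length and pad_to_bucket are assumptions made for illustration; real boundaries should come from your own length distribution.

```python
import bisect

import torch
import torch.nn.functional as F

BUCKETS = [32, 64, 128, 256]  # illustrative boundaries, sorted ascending


def bucket_length(length, buckets=BUCKETS):
    """Return the smallest bucket boundary that is >= length."""
    idx = bisect.bisect_left(buckets, length)
    if idx == len(buckets):
        raise ValueError(f"length {length} exceeds the largest bucket {buckets[-1]}")
    return buckets[idx]


def pad_to_bucket(tokens, pad_id=0):
    """Right-pad the last dimension of `tokens` up to its bucket boundary."""
    current = tokens.shape[-1]
    target = bucket_length(current)
    return F.pad(tokens, (0, target - current), value=pad_id)


# Six distinct sequence lengths collapse into a handful of padded shapes
lengths = [17, 30, 33, 64, 100, 129]
padded_shapes = {
    tuple(pad_to_bucket(torch.ones(2, n, dtype=torch.long)).shape) for n in lengths
}
print(f"{len(lengths)} raw lengths -> {len(padded_shapes)} padded shapes")
```

Fewer distinct shapes means fewer specializations for the compiler to produce, at the cost of some wasted compute on padding.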

Profiling Checklist

  • Warm up before profiling: Avoid measuring compile or lazy init overhead.
  • Capture both CPU & GPU: Pass activities=[CPU, CUDA] to understand how the two interact.
  • Record shapes & memory: Helps identify shape-related issues and memory usage spikes.
  • Use scheduling for long runs: Use profiler.schedule to capture short time windows periodically. This reduces profiler overhead and still provides visibility into program behavior.
  • Analyze utilization: Ensure your GPUs are being fully utilized. If they aren’t, diagnose whether CPU or I/O is the limiter.
  • Identify top ops: The profiler can output a list of operations sorted by important metrics such as self-time. Determine which operations are taking the most time and optimize them (better algorithms, increase batch size if the GPUs are underutilized).
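As a sketch of the "identify top ops" step, the profiler's aggregated statistics can be printed as a table sorted by self time. This example profiles on CPU only so that it runs anywhere; on a GPU you would add ProfilerActivity.CUDA and sort by the corresponding device-time key, whose exact name varies between PyTorch releases.

```python
import torch
import torch.nn as nn
from torch.profiler import ProfilerActivity, profile

model = nn.Sequential(nn.Linear(100, 256), nn.ReLU(), nn.Linear(256, 10))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()
inputs = torch.randn(32, 100)
targets = torch.randint(0, 10, (32,))


def train_step():
    optimizer.zero_grad(set_to_none=True)
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    optimizer.step()


# Warm up outside the profiler so one-time setup cost is not measured
for _ in range(3):
    train_step()

with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    for _ in range(3):
        train_step()

# Aggregate events by operator name and show the ten largest by self CPU time
op_stats = prof.key_averages()
table = op_stats.table(sort_by="self_cpu_time_total", row_limit=10)
print(table)
```

Self time excludes time spent in child operators, which makes it a better guide to where work is actually done than total time.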

With the application optimized on a single machine, the next challenge is scaling out to leverage multiple GPUs or nodes, which introduces PyTorch’s distributed training paradigms.

Scale: Distributed Training via DDP or FSDP

When your workload or model grows beyond a single GPU, you’ll need to scale training to multiple GPUs. PyTorch provides two primary methods for training models across GPUs: Distributed Data Parallel and Fully Sharded Data Parallel. While both methods leverage data parallelism under the hood (each process trains on a different subset of your data), they have different approaches to managing model parameters in memory. We’ll explore when to choose each method and how to configure them for both single-node and multi-node training with torchrun.

Distributed Data Parallel

DDP maintains a copy of the entire model on each GPU and keeps the copies in sync by all-reducing gradients during each backward pass. It is ideally suited for models that fit comfortably into a single GPU’s memory, and it is conceptually straightforward: initialize a process group, then wrap the model with torch.nn.parallel.DistributedDataParallel.

import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler
# -----------------------------
# Setup distributed (safe)
# -----------------------------
def setup_distributed():
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    if world_size > 1:
        dist.init_process_group(backend="nccl")
        local_rank = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(local_rank)
        device = torch.device(f"cuda:{local_rank}")
        distributed = True
    else:
        local_rank = 0
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        distributed = False

    return distributed, local_rank, device

def cleanup_distributed(distributed):
    if distributed:
        dist.destroy_process_group()

# -----------------------------
# Main training
# -----------------------------
def main():
    distributed, local_rank, device = setup_distributed()

    # -----------------------------
    # Model
    # -----------------------------
    model = nn.Sequential(
        nn.Linear(100, 256),
        nn.ReLU(),
        nn.Linear(256, 10),
    ).to(device)
    if distributed:
        model = DDP(model, device_ids=[local_rank])
    # -----------------------------
    # Optimizer and loss
    # -----------------------------
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()
    # -----------------------------
    # Dataset
    # -----------------------------
    num_samples = 20_000
    batch_size = 32
    inputs = torch.randn(num_samples, 100)
    targets = torch.randint(0, 10, (num_samples,))
    dataset = TensorDataset(inputs, targets)
    if distributed:
        sampler = DistributedSampler(dataset, shuffle=True)
    else:
        sampler = None
    train_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        shuffle=(sampler is None),
        num_workers=2,
        pin_memory=True,
    )

    # -----------------------------
    # Training loop
    # -----------------------------
    epochs = 3
    for epoch in range(epochs):
        if distributed:
            sampler.set_epoch(epoch)

        for step, (x, y) in enumerate(train_loader):
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

            optimizer.zero_grad(set_to_none=True)
            outputs = model(x)
            loss = loss_fn(outputs, y)
            loss.backward()
            optimizer.step()
            if step % 50 == 0 and local_rank == 0:
                print(
                    f"[Epoch {epoch}] Step {step} | "
                    f"Loss {loss.item():.4f}"
                )
    cleanup_distributed(distributed)
# -----------------------------
# Entry point
# -----------------------------
if __name__ == "__main__":
    main()

Launching a DDP job is done with torchrun. For example, if you want to launch a job on a single node with 4 GPUs:

torchrun --nproc_per_node=4 train.py

The command will spawn 4 processes (one for each GPU) and initialize the appropriate environment variables (rank, world size, etc). Specifying multiple nodes requires you to pass --nnodes and network details:

# On node 0:
torchrun --nnodes=2 --nproc_per_node=4 --node_rank=0 --master_addr="<IP of node0>" --master_port=12345 train.py
# On node 1:
torchrun --nnodes=2 --nproc_per_node=4 --node_rank=1 --master_addr="<IP of node0>" --master_port=12345 train.py

Fully Sharded Data Parallel (FSDP)

FSDP goes a step further by sharding model parameters and optimizer states across GPUs, rather than replicating them fully on each GPU. Using FSDP in PyTorch 2.9 is straightforward with the FullyShardedDataParallel wrapper.

Key points

  • Wrap the model (or submodules) with FSDP before optimizer creation.
  • Use torch.cuda.set_device(local_rank) and move the model to the GPU as with DDP.
  • Optional: provide an auto-wrap policy if you’d like to shard at module granularity (e.g., wrap each transformer block separately). By default, FSDP(model) treats the entire model as a single FSDP unit, which is suboptimal for very deep models.
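A minimal sketch of size-based wrapping, assuming a 100k-parameter threshold and a script launched with torchrun. The names build_model and shard_model are hypothetical helpers, and the FSDP call is guarded so that the snippet also imports cleanly in a single process where no process group exists.

```python
import functools
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

MIN_PARAMS_PER_UNIT = 100_000  # illustrative threshold

# Policy: submodules with at least this many parameters become their own FSDP unit
auto_wrap_policy = functools.partial(
    size_based_auto_wrap_policy, min_num_params=MIN_PARAMS_PER_UNIT
)


def build_model():
    return nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 10))


def shard_model(model, local_rank):
    """Wrap with FSDP; every rank must call this after init_process_group()."""
    return FSDP(model, auto_wrap_policy=auto_wrap_policy, device_id=local_rank)


model = build_model()

# Direct children large enough to be candidates for their own unit
large_children = [
    name
    for name, child in model.named_children()
    if sum(p.numel() for p in child.parameters()) >= MIN_PARAMS_PER_UNIT
]
print("Children above threshold:", large_children)

if dist.is_available() and dist.is_initialized():
    model = shard_model(model, int(os.environ["LOCAL_RANK"]))

# Create the optimizer only after wrapping so it references the sharded parameters
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
```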

A size-based auto-wrap policy with a 100k-parameter threshold, for example, recursively wraps any submodule with more than 100k parameters into its own FSDP unit. If you have too few units (e.g., the whole model is a single unit), you won’t get memory savings between layers. However, too many units increase communication overhead. Let’s summarize the differences between DDP and FSDP:

| Aspect | DDP (Distributed Data Parallel) | FSDP (Fully Sharded Data Parallel) |
| --- | --- | --- |
| Launch method | Launched with torchrun, one process per GPU. | Also launched with torchrun, one process per GPU. |
| Model replication | Each rank holds a full copy of the model. | The model is sharded across ranks; each rank holds only a subset of parameters. |
| Memory usage per GPU | High, proportional to full model size. | Much lower, approximately model size divided by the number of GPUs. |
| Communication pattern | Gradients are all-reduced during the backward pass. | Parameters are all-gathered before forward, and gradients are reduce-scattered after backward. |
| Model wrapping requirement | All ranks wrap the entire model in DistributedDataParallel. | All ranks must apply the exact same FSDP wrapping logic so each rank knows which shard it owns. |
| Ease of use | Simple to set up; minimal code changes. | More complex; requires careful wrapping policies and optimizer setup. |
| Scalability | Limited by per-GPU memory; not suitable for very large models. | Designed for very large models that do not fit on a single GPU. |
| Typical use case | Models that comfortably fit in GPU memory and need faster training. | Extremely large models or scenarios where memory efficiency is critical. |
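The memory row can be made concrete with back-of-the-envelope arithmetic. The sketch below assumes fp32 parameters, fp32 gradients, and Adam's two fp32 moment buffers (16 bytes per parameter in total), and it ignores activations and the temporary buffers FSDP needs while parameters are all-gathered. The helper name state_gb_per_gpu is hypothetical.

```python
# Assumed bytes per parameter for persistent training state (fp32 + Adam)
BYTES_PER_PARAM = {
    "params_fp32": 4,
    "grads_fp32": 4,
    "adam_exp_avg": 4,
    "adam_exp_avg_sq": 4,
}


def state_gb_per_gpu(num_params, world_size, sharded):
    """Rough per-GPU memory (decimal GB) for parameters, gradients, optimizer state."""
    total_bytes = num_params * sum(BYTES_PER_PARAM.values())
    if sharded:
        total_bytes = total_bytes / world_size  # FSDP splits the state across ranks
    return total_bytes / 1e9


num_params = 1_000_000_000  # a 1B-parameter model
world_size = 8

ddp_gb = state_gb_per_gpu(num_params, world_size, sharded=False)
fsdp_gb = state_gb_per_gpu(num_params, world_size, sharded=True)
print(f"DDP:  ~{ddp_gb:.1f} GB of training state per GPU")
print(f"FSDP: ~{fsdp_gb:.1f} GB of training state per GPU")
```

Under these assumptions the replicated state is about 16 GB per GPU with DDP and about 2 GB per GPU with FSDP on 8 GPUs; real usage will be higher once activations and communication buffers are included.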

As a rule of thumb, start with DDP. If you run out of memory or need larger models, consider FSDP. Note that FSDP in PyTorch 2.9 has evolved; by default, it "shards everything" (the FULL_SHARD strategy, comparable to ZeRO Stage 3) and enables other defaults, such as limit_all_gathers=True, to prevent unexpected memory spikes.

Scaling Checklist

  • Ensure determinism: If you need reproducibility, set the same random seed on all ranks (for example, torch.manual_seed(seed)), adding a rank offset only where per-rank randomness is intended.
  • Use DistributedSampler: Split your data across ranks and call set_epoch() on the sampler every epoch so that each epoch is shuffled differently.
  • DDP: Gradients are automatically averaged across GPUs during backward() on models wrapped in DDP. The global batch size grows with the number of GPUs, so revisit the learning rate (or the per-GPU batch size) when scaling up.
  • FSDP: Wrap your modules before creating the optimizer. Experiment with auto_wrap_policy to avoid one huge shard (likely if you have a very deep model). Keep an eye on GPU memory and ensure each shard actually fits.
  • Communication backend: For multi-node runs, make sure the network and communication backend are configured (typically NCCL). NCCL environment variables such as NCCL_P2P_LEVEL=NVL (if you have NVLink) can be used to tune NCCL.
  • Gradient accumulation: If you accumulate gradients, wrap the non-final micro-batches in model.no_sync() to skip gradient synchronization on those iterations (this applies to FSDP in the same way as to DDP).
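The gradient accumulation item can be sketched as below. The helpers sync_context and accumulate_and_step are hypothetical names; the code falls back to a no-op context when the model has no no_sync() method, so the same loop runs on a plain module in a single process. The check at the end confirms that accumulating four equal micro-batches reproduces the full-batch gradient.

```python
import contextlib
import copy

import torch
import torch.nn as nn


def sync_context(model, should_sync):
    """Skip gradient synchronization on wrappers that expose no_sync()."""
    if not should_sync and hasattr(model, "no_sync"):
        return model.no_sync()
    return contextlib.nullcontext()


def accumulate_and_step(model, optimizer, loss_fn, micro_batches):
    optimizer.zero_grad(set_to_none=True)
    num_micro = len(micro_batches)
    for i, (x, y) in enumerate(micro_batches):
        is_last = i == num_micro - 1
        # Forward and backward both run inside the context; only the last
        # micro-batch triggers communication on a distributed wrapper
        with sync_context(model, should_sync=is_last):
            loss = loss_fn(model(x), y) / num_micro  # keep the average consistent
            loss.backward()
    optimizer.step()


torch.manual_seed(0)
model = nn.Linear(4, 2)
reference = copy.deepcopy(model)  # same initial weights, used for the full-batch check
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

x, y = torch.randn(8, 4), torch.randn(8, 2)
micro_batches = list(zip(x.chunk(4), y.chunk(4)))
accumulate_and_step(model, optimizer, loss_fn, micro_batches)

# Gradient of the full batch on the untouched copy
loss_fn(reference(x), y).backward()
grads_match = torch.allclose(model.weight.grad, reference.weight.grad, atol=1e-6)
print("Accumulated gradient matches full-batch gradient:", grads_match)
```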

Checkpoint: Recover Training Reliably with Distributed Checkpoints

In long or large-scale training runs, it is important to save checkpoints. You may want to resume training after a crash, fine-tune from the middle of training, or inspect intermediate results. PyTorch 2.9 includes convenient primitives for distributed checkpointing (DCP). This allows users to save/load more efficiently and robustly than traditional approaches with torch.save. Let’s compare both approaches before diving into the recommended approach.

Traditional torch.save / torch.load

The traditional method in a single-GPU or even DDP training script saves the model’s state_dict() to a file (typically only on rank 0, since under DDP every rank holds identical weights). For example:

# On rank 0 only:
torch.save(model.state_dict(), "checkpoint.pt")
# And to load:
model.load_state_dict(torch.load("checkpoint.pt"))

While this is fine for small models/configs where one process can keep the entire model state in memory (usually rank 0), there are issues with using this approach for distributed training where the model may be sharded across processes.

Distributed Checkpoint (DCP): The torch.distributed.checkpoint module (available since the v2.1 release line and matured by 2.9) addresses these issues. It parallelizes saving: each rank writes out its portion of the model state, producing multiple files (at least one per rank) that together comprise the checkpoint. It also reshards on load: you can save on N ranks and load on M ranks, and DCP will gather and shard the state as appropriate.

The typical usage pattern involves wrapping your model (and optimizer, if needed) in a stateful container that provides a state_dict. You can do this using the provided utility functions get_state_dict and set_state_dict, which are FSDP-aware:

import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful

# Define a stateful container for model & optimizer
class AppState(Stateful):
    def __init__(self, model, optimizer=None):
        self.model = model
        self.optimizer = optimizer
    def state_dict(self):
        model_sd, opt_sd = get_state_dict(self.model, self.optimizer)
        return {"model": model_sd, "optim": opt_sd}
    def load_state_dict(self, state):
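        # model_state_dict and optim_state_dict are keyword-only arguments
        set_state_dict(self.model, self.optimizer, model_state_dict=state["model"], optim_state_dict=state.get("optim"))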

AppState here implements the Stateful interface, so DCP knows how to get and set state. Note that get_state_dict automatically uses FSDP’s sharded state dict if the model is sharded (each rank only gets its chunk of the model).

Saving with DCP

All ranks execute:

app_state = AppState(model, optimizer)
# dcp.save expects a dict; Stateful values are expanded via their state_dict()
dcp.save({"app": app_state}, checkpoint_id="mycheckpoint")

This produces a directory named "mycheckpoint" containing a metadata file plus one or more data files per rank (with the default filesystem writer these typically have a .distcp extension). Each rank writes out its shard in parallel, which can reduce save time dramatically compared to a single rank writing everything.

Loading with DCP

To resume from a checkpoint, you’ll need a model with the same topology (potentially on a different number of ranks). Build your model (and optimizer), wrap in AppState, and then:

app_state = AppState(model, optimizer)
# Loads in place: DCP calls app_state.load_state_dict() with the restored state
dcp.load({"app": app_state}, checkpoint_id="mycheckpoint")

If the number of ranks differs from when you saved, DCP takes care of redistributing the shards internally (for instance, going from 8 to 4 GPUs means each rank loads the equivalent of two saved shards; when increasing ranks, the state is split accordingly). This load-time resharding removes the need for any manual checkpoint conversion.

Asynchronous checkpointing

A particularly useful feature in PyTorch 2.9 is asynchronous checkpoint saving, which overlaps checkpoint writing with training. dcp.async_save returns a future and performs the heavy I/O in background threads. The basic idea:

save_future = dcp.async_save({"app": app_state}, checkpoint_id="chkpt_epoch10")
# ... training can continue immediately ...
save_future.result()  # later, block until the save completes (or poll save_future.done())

async_save first stages the data locally on CPU (copies the model/optimizer state to pinned memory buffers), then writes those buffers to disk/network asynchronously. The GPUs can resume compute work (i.e., training) while the checkpoint is flushing.

Asynchronous saving temporarily increases host memory use. Each rank allocates a CPU buffer roughly equal in size to its shard of the checkpoint, so the state exists twice (on the device and in the staging buffer) until the write completes. Saving very large models asynchronously across many ranks can therefore increase memory usage considerably (CPU RAM plus pinned memory of roughly checkpoint_size_per_rank * num_ranks in aggregate, worst case).
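Because every in-flight save holds its own staging buffer, it helps to allow only one asynchronous save at a time. The sketch below shows one possible guard. SingleFlightCheckpointer is a hypothetical wrapper, and fake_async_save is a stand-in built on ThreadPoolExecutor so that the example runs without a distributed setup; in real training you would pass a function that calls dcp.async_save instead.

```python
import threading
import time
from concurrent.futures import ThreadPoolExecutor


class SingleFlightCheckpointer:
    """Hypothetical guard: waits for the previous async save before starting a new one."""

    def __init__(self, async_save_fn):
        self._async_save_fn = async_save_fn  # must return a future with .result()
        self._pending = None

    def save(self, state, checkpoint_id):
        self.wait()  # avoid stacking staging buffers from overlapping saves
        self._pending = self._async_save_fn(state, checkpoint_id=checkpoint_id)

    def wait(self):
        if self._pending is not None:
            self._pending.result()  # re-raises any exception from the background write
            self._pending = None


# Stand-in for dcp.async_save that tracks how many writes overlap
executor = ThreadPoolExecutor(max_workers=4)
lock = threading.Lock()
stats = {"in_flight": 0, "max_in_flight": 0}
completed = []


def fake_async_save(state, checkpoint_id):
    def write():
        with lock:
            stats["in_flight"] += 1
            stats["max_in_flight"] = max(stats["max_in_flight"], stats["in_flight"])
        time.sleep(0.01)  # simulate slow storage
        with lock:
            stats["in_flight"] -= 1
            completed.append(checkpoint_id)

    return executor.submit(write)


checkpointer = SingleFlightCheckpointer(fake_async_save)
for step in range(3):
    checkpointer.save({"step": step}, checkpoint_id=f"step_{step}")
checkpointer.wait()  # always drain the last save before exiting
executor.shutdown()
print("Max overlapping saves:", stats["max_in_flight"])
```

Waiting on the previous future also surfaces write errors promptly instead of letting a failed checkpoint go unnoticed.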

Now, let’s compile a comparison of checkpointing methods with guidance on when to use each:

| Checkpoint Method | Description | When to Use |
| --- | --- | --- |
| torch.save/torch.load (rank 0) | Single-file checkpoint on one process. The model’s state_dict is gathered (if sharded) and saved as one file, typically by rank 0. Simple and works out of the box for non-sharded models. | Use for small-to-medium models or single-GPU training. Also acceptable for DDP if the model size is moderate (each rank has a full model, so rank 0 can save it). Not suitable when the model is extremely large or distributed (may OOM or be very slow). |
| DCP (synchronous) | Distributed Checkpoint (parallel, synced). Each rank saves its shard of the state_dict. Blocks training until complete (all ranks finish writing). Handles FSDP, ShardedTensor, etc., and allows loading on different world sizes (automatic resharding). Produces multiple files (at least one per rank). | Use for large models on multi-GPU (DDP or FSDP). Recommended when a single rank’s checkpoint would be too large or slow. Ensures faster saves (parallel I/O) and avoids a single point of failure. Use synchronous save if training can afford to pause for a checkpoint (e.g., between epochs) or if system I/O is fast enough that the pause is short. |
| DCP + Async (async_save) | Distributed async checkpoint. Same as synchronous DCP, but returns immediately and performs saving in background threads. Training can continue while the checkpoint is being written, at the cost of extra memory for buffers. Requires careful handling of overlapping saves (do not start multiple without waits, to avoid memory bloat). | Use for very large models or when checkpoint time is significant relative to training time. Ideal in production training where even a minute of pause is too long, because async saving hides checkpoint latency. Ensure sufficient CPU RAM/pinned memory. Test in your environment because async I/O performance can vary (e.g., slow filesystems may still indirectly slow training if the disk is saturated). |

Checkpointing Checklist

  • Include all necessary state: Save the model weights and optimizer state (plus scheduler, RNG state, etc., if necessary) so that training can be fully resumed later. If using DCP’s Stateful protocol, you can easily pass in the optimizer via get_state_dict(…, model, optimizer).
  • Test restore: As soon as you’ve saved a checkpoint, load it back in a clean environment (preferably on both the same and a different number of GPUs if distributed) to ensure everything restores as expected. This helps you catch errors early, such as forgetting to save some part of the training state.
  • Manage storage: Checkpoints can get large very quickly (especially with sharded distributed checkpoints). Implement a retention strategy (keep the last k checkpoints, etc.) to avoid running out of disk.
  • Async usage: async_save is a useful feature, but avoid having multiple async saves in flight at a time (unless you have budgeted the memory for overlapping saves). Always call .result() on the returned future eventually, both to surface any exceptions and to confirm the save succeeded.
  • Consistency: When resuming from a checkpoint, make sure all ranks call the same dcp.load. This is especially important for FSDP, because the load synchronizes across ranks internally. Additionally, if you use CPU offloading or similar mechanisms, ensure the model is on the expected devices before loading.
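
The "keep the last k checkpoints" item from the checklist can be sketched with the standard library alone. The step_<N> directory naming and the prune_checkpoints helper are assumptions made for illustration; adapt the pattern to however your checkpoint IDs are laid out.

```python
import re
import shutil
from pathlib import Path

# Assumption: each checkpoint lives in its own directory named step_<N>.
_STEP_DIR = re.compile(r"^step_(\d+)$")


def prune_checkpoints(root: Path, keep_last: int) -> list[Path]:
    """Delete all but the `keep_last` highest-numbered checkpoint directories."""
    if keep_last < 1:
        raise ValueError("keep_last must be at least 1")
    found = []
    for child in root.iterdir():
        match = _STEP_DIR.match(child.name)
        if child.is_dir() and match:
            found.append((int(match.group(1)), child))
    found.sort()  # numeric order, so step_1000 sorts after step_200
    stale = [path for _, path in found[:-keep_last]]
    for path in stale:
        shutil.rmtree(path)
    return stale
```

In a distributed job, it is safest to run this on a single rank only, and only after the newest save has fully completed, so that a failed or in-progress checkpoint never becomes the sole survivor.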

Conclusion

By following this workflow, you’ll develop PyTorch training code that performs well and is robust to failures:

  • Start from a working baseline.
  • Compile for maximum speed.
  • Profile to identify opportunities for optimization.
  • Scale out with an appropriate data or model parallel strategy.
  • Checkpoint effectively to allow recovery from failure.

There are playbooks available for each step. Use them iteratively (e.g., profile again after scaling out in step 4 changes your performance characteristics, or revisit the compile settings from step 2 to tune for your new environment). PyTorch has evolved rapidly and now offers capabilities to help at each step. Learn when and how to use each, and you’ll be ready to train large models at speed and scale, with good recovery strategies.

FAQs

Why is establishing a baseline so important before optimization?

You need to have a correct single-GPU eager baseline before optimizing so that you have something you trust for correctness and performance. Otherwise, optimizations may hide bugs that cause incorrect results. Only with a correct baseline can you attribute speedups or regressions to specific code changes rather than cryptic bugs or measurement noise.

When should I use torch.compile, and what should I watch out for?

Use torch.compile once you have a solid baseline, to accelerate steady-state execution. Monitor graph breaks, warm up before measuring, and ensure dynamic shapes are handled deliberately to avoid excessive recompilation.

How does torch.profiler help beyond confirming assumptions?

torch.profiler exposes your true bottlenecks, whether those are CPU stalls, inefficient GPU kernels, shape retracing, or communication overhead. This allows you to focus optimization efforts where they matter, based on evidence rather than intuition.

How do I choose between DDP and FSDP for distributed training?

If your model comfortably fits on each GPU you’d like to use, go with DDP for its simplicity and relative speed. If you’re bound by GPU memory or working with very large models, FSDP will give increased scalability at the cost of added complexity.

Why is Distributed Checkpointing preferred over torch.save in large runs?

Distributed Checkpointing parallelizes saves across ranks, supports resharding on load, and enables async checkpointing, making recovery reliable and efficient for large, multi-GPU or multi-node training jobs.

About the author(s)

Adrien Payong
Author
AI consultant and technical writer

I am a skilled AI consultant and technical writer with over four years of experience. I have a master’s degree in AI and have written innovative articles that provide developers and researchers with actionable insights. As a thought leader, I specialize in simplifying complex AI concepts through practical content, positioning myself as a trusted voice in the tech community.

Shaoni Mukherjee
Editor
Technical Writer

With a strong background in data science and over six years of experience, I am passionate about creating in-depth content on technologies. Currently focused on AI, machine learning, and GPU computing, working on topics ranging from deep learning frameworks to optimizing GPU-based workloads.

This work is licensed under a Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.