Mastering Gradients, `zero_grad`, and Optimizers in PyTorch
A practical guide to what actually happens under the hood when you train a neural network in PyTorch—and how to take full control of it.
Table of Contents
- Autograd Recap
- `loss.backward()` — What Really Happens
- Why Gradients Accumulate
- `optimizer.zero_grad()` vs `model.zero_grad()`
- Setting Gradients to `None` for Speed
- The `optimizer.step()` Update
- Gradient Accumulation for Large Effective Batch Sizes
- Best‑Practice Training Loop Templates
- Common Pitfalls & Debugging Tips
- Cheat Sheet
Autograd Recap
PyTorch builds a dynamic computation graph as you execute tensor operations. If a tensor has `requires_grad=True`, every subsequent operation records a `Function` node in the graph. The result is a directed acyclic graph; calling `loss.backward()` traverses it in reverse to compute gradients via automatic differentiation.
```python
import torch

x = torch.randn(32, 100, requires_grad=True)
w = torch.randn(100, 10, requires_grad=True)
output = x @ w                # graph grows by one op: matmul
loss = output.pow(2).mean()
```
- Leaf tensors (`x`, `w`) store a `.grad` attribute where gradients accumulate.
- Non‑leaf tensors (intermediate results) typically don't hold gradients unless you explicitly call `.retain_grad()`, as the snippet below shows.
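Continuing the example above, a quick check of both behaviors:

```python
output.retain_grad()          # opt in: keep the gradient of a non-leaf tensor
loss.backward()

print(x.grad.shape)           # torch.Size([32, 100]) -- leaf, filled automatically
print(output.grad.shape)      # torch.Size([32, 10])  -- only thanks to retain_grad()
```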
`loss.backward()` — What Really Happens
- Gradient seed: If the loss is a scalar, autograd seeds the backward pass with a gradient of 1 w.r.t. the loss.
- Reverse traversal: PyTorch walks the graph backward, calling each `Function`'s `backward()` to compute ∂output/∂input.
- Accumulation: For every leaf parameter `p`, the computed gradient `dp` is added to `p.grad`:

```python
p.grad = dp if p.grad is None else p.grad + dp
```

- No parameter update yet: `backward()` only fills `.grad`; you still need `optimizer.step()` to change the weights, as the sketch below demonstrates.
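A tiny end‑to‑end check of those last two points, using a single two‑element parameter:

```python
import torch

p = torch.nn.Parameter(torch.ones(2))
opt = torch.optim.SGD([p], lr=0.1)

p.sum().backward()
print(p.grad)   # tensor([1., 1.]) -- gradients are in place...
print(p.data)   # tensor([1., 1.]) -- ...but the weights haven't moved

opt.step()
print(p.data)   # tensor([0.9000, 0.9000]) -- now they have
```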
Why Gradients Accumulate
- Flexibility: Lets you combine gradients from multiple forward passes (e.g. gradient accumulation, multi‑task losses, truncated BPTT).
- Historical context: Mirrors classical deep‑learning frameworks (Theano, Torch 7) where you manually zeroed grads.

If you don't clear `.grad` between mini‑batches, your parameter updates will be incorrect because each step will mix gradients from multiple batches.
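To see the accumulation concretely (a minimal sketch; `retain_graph=True` is needed only because we backprop through the same graph twice):

```python
import torch

w = torch.ones(3, requires_grad=True)
loss = (w * 2).sum()

loss.backward(retain_graph=True)
print(w.grad)    # tensor([2., 2., 2.])

loss.backward()  # gradients are added, not overwritten
print(w.grad)    # tensor([4., 4., 4.])
```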
`optimizer.zero_grad()` vs `model.zero_grad()`
```python
optimizer.zero_grad()   # preferred
model.zero_grad()       # identical effect here
```
Both iterate over parameters and reset `p.grad` to zeros. Use one or the other, not both.

Under the hood, `optimizer.zero_grad()` essentially detaches each gradient and zero‑fills it in place (`p.grad.detach_()` followed by `p.grad.zero_()`) for every parameter in the optimizer's param groups.
When might they differ?
If you pass only a subset of parameters to the optimizer (rare but possible), `model.zero_grad()` clears all parameters, including ones the optimizer doesn't know about. Usually that's fine, but stick to `optimizer.zero_grad()` for clarity.
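A contrived sketch of that subset case (hypothetical two‑layer model; only the second layer is handed to the optimizer):

```python
import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2))
optimizer = torch.optim.SGD(model[1].parameters(), lr=0.01)

model(torch.randn(8, 4)).sum().backward()

optimizer.zero_grad()        # clears only model[1]'s gradients
print(model[0].weight.grad)  # still populated
model.zero_grad()            # clears every parameter's gradients
```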
Setting Gradients to `None` for Speed
Clearing gradients by zeroing writes to every element, wasting memory bandwidth. PyTorch ≥ 1.7 lets you instead drop the gradient tensor entirely and let autograd allocate a fresh one on the next backward pass:
```python
optimizer.zero_grad(set_to_none=True)
```
- Pros: Saves a kernel launch and memory bandwidth.
- Cons: You must check `p.grad is not None` before using `.grad` (e.g. for gradient clipping).
Alternatively, the equivalent manual loop:

```python
for p in model.parameters():
    p.grad = None
```
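Either way, any code that later reads `.grad` should guard against missing gradients, e.g. when logging gradient norms (a minimal sketch; built‑ins like `clip_grad_norm_` already handle this):

```python
for name, p in model.named_parameters():
    if p.grad is not None:   # may be None after zero_grad(set_to_none=True)
        print(name, p.grad.norm().item())
```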
The `optimizer.step()` Update
After fresh gradients sit in `.grad`, call:

```python
optimizer.step()
```
This iterates through the param groups and updates each parameter using the chosen rule (SGD, Adam, etc.). For Adam (bias correction omitted for brevity):

```python
m = beta1 * m + (1 - beta1) * grad       # first-moment (mean) estimate
v = beta2 * v + (1 - beta2) * grad**2    # second-moment estimate
param -= lr * m / (sqrt(v) + eps)
```
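For intuition, plain SGD's `step()` boils down to something like this (a sketch assuming vanilla SGD with no momentum or weight decay):

```python
with torch.no_grad():   # parameter updates must not be traced by autograd
    for p in model.parameters():
        if p.grad is not None:
            p -= lr * p.grad
```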
Order always matters:

1. `zero_grad()`
2. forward pass
3. `loss.backward()`
4. `optimizer.step()`
Gradient Accumulation for Large Effective Batch Sizes
When a full batch won’t fit in GPU RAM:
```python
acc_steps = 4                             # accumulate 4 mini‑batches
optimizer.zero_grad()
for i, batch in enumerate(loader):
    loss = compute_loss(batch) / acc_steps
    loss.backward()
    if (i + 1) % acc_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```
- Divide the loss by `acc_steps` so the total gradient matches that of a real large batch.
- Clip gradients after accumulation but before `step()` (see the placement sketch below).
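For example, the clipping call slots into the accumulation loop like this:

```python
if (i + 1) % acc_steps == 0:
    # .grad now holds the sum over all acc_steps mini-batches
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    optimizer.zero_grad()
```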
Best‑Practice Training Loop Templates
Standard Training Loop
```python
model.train()
for inputs, targets in loader:
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
```
AMP + Grad‑Accumulation (Mixed Precision)
```python
scaler = torch.cuda.amp.GradScaler()
optimizer.zero_grad(set_to_none=True)
for i, batch in enumerate(loader):
    with torch.cuda.amp.autocast():
        loss = compute_loss(batch) / acc_steps
    scaler.scale(loss).backward()
    if (i + 1) % acc_steps == 0:
        scaler.unscale_(optimizer)   # so gradient clipping sees true magnitudes
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)
```
Common Pitfalls & Debugging Tips
| Symptom | Likely Cause | Fix |
|---|---|---|
| Loss oscillates wildly | Forgot to zero grads, so gradients accumulate across batches | Call `optimizer.zero_grad()` each iteration |
| `NoneType` grad error | Using `set_to_none=True` and later assuming `.grad` exists | Check `param.grad is not None` |
| Slow training | Zeroing large gradients on CPU before transfer | Move the model to the GPU before calling `zero_grad()` |
| Out‑of‑memory on backward | Batch too large | Use gradient accumulation or activation checkpointing |
Cheat Sheet
- `loss.backward()`: computes gradients and adds them to `.grad`.
- `optimizer.zero_grad()`: clears `.grad` (zero‑fill or `None`).
- `optimizer.step()`: updates params using the current `.grad`.
- Call zero_grad → forward → backward → step every update, unless intentionally accumulating.
Remember: Clearing gradients isn't a performance hack; it's about correctness. Treat `.grad` as a scratchpad: scribble a fresh set of numbers there every time you call `backward()`, unless you want them to add up.
Happy training! 🎉