If you come from classical ML, you probably have a working mental model of gradient descent: compute a loss, take a step in the negative gradient direction, repeat. For logistic regression, this is a well-defined convex optimization problem, and with proper learning-rate annealing it is guaranteed to converge to a global optimum. You likely also know that deep learning is a non-convex optimization problem with no such guarantee of reaching a global optimum.
AdamW has established itself as the de facto standard optimizer for most deep learning applications (outside of some natural-gradient methods in RL). It may seem like a complex beast to understand or implement, but it is actually the combination of four fairly simple ideas: Stochastic Gradient Descent (SGD), Momentum, RMSProp, and L2 regularization (applied in a decoupled form as weight decay).
Beyond non-convexity, deep learning optimization suffers from training stability issues that classical models rarely encounter. Stacking many layers creates exploding and vanishing gradient pitfalls on top of the local-minima problem. We will cover the details of the AdamW optimizer, the largest sources of training instability in deep learning, and how AdamW improves training stability (and where it falls short).
AdamW as a combination of Stochastic-Gradient-Descent (SGD), Momentum, RMSProp, and L2-Regularization
AdamW wasn't designed from scratch: it's the result of stacking these four ideas. Adam is the simpler combination of the first three: SGD, Momentum, and RMSProp.
Building AdamW up step by step is the clearest way to understand what each hyperparameter controls and why the defaults are what they are.
SGD        theta -= lr * grad
   │
   │  + smooth noisy gradient direction
   ▼
Momentum   theta -= lr * v
           v = β₁v + (1-β₁)*grad
   │
   │  + per-parameter adaptive learning rate
   ▼
RMSProp    theta -= lr * grad / (sqrt(s) + ε)
           s = β₂s + (1-β₂)*grad²
   │
   │  + combine both, add bias correction
   ▼
Adam       theta -= lr * v̂ / (sqrt(ŝ) + ε)
   │
   │  + decouple weight decay from gradient
   ▼
AdamW      theta -= lr * v̂ / (sqrt(ŝ) + ε)
           theta -= lr * λ * theta
Stochastic Gradient Descent
Plain stochastic gradient descent is the foundation of Adam and AdamW.
x_sample, y_sample = sample_minibatch()
loss = compute_loss(y_sample, model(x_sample))
grad = d(loss)/d(theta)
theta = theta - lr * grad
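The four-line loop above can be made concrete with a minimal sketch. It is written in pure Python and fits a one-parameter linear model y ≈ theta * x on synthetic minibatches. The dataset, the `sample_minibatch` helper, and the hyperparameters are illustrative assumptions. The gradient is derived by hand rather than by autograd.

```python
import random

rng = random.Random(0)
TRUE_THETA = 3.0

# Hypothetical synthetic dataset: y = 3x plus small Gaussian noise
xs = [rng.uniform(-1.0, 1.0) for _ in range(1000)]
data = [(x, TRUE_THETA * x + rng.gauss(0.0, 0.1)) for x in xs]


def sample_minibatch(batch_size=16):
    """Hypothetical helper: draw a random minibatch of (x, y) pairs."""
    return rng.sample(data, batch_size)


def grad_mse(theta, batch):
    """Hand-derived d/dtheta of mean((theta * x - y)**2) over the batch."""
    return sum(2.0 * (theta * x - y) * x for x, y in batch) / len(batch)


theta, lr = 0.0, 0.1
for step in range(500):
    batch = sample_minibatch()
    grad = grad_mse(theta, batch)   # grad = d(loss)/d(theta)
    theta = theta - lr * grad       # plain SGD update

print(f"learned theta = {theta:.3f} (true value {TRUE_THETA})")
```

Each minibatch gives a slightly different gradient, so theta jitters around 3.0 rather than settling exactly. This is the gradient noise discussed next.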
However, on both toy and real-world problems, pure SGD runs into two major issues:
- Gradient noise makes the update direction unreliable from step to step: adjacent minibatches can point in very different directions, causing oscillation. Momentum addresses this.
- A single global learning rate doesn't fit all parameters at all times: a step size appropriate for one layer can be too large or too small for another. RMSProp addresses this.
Momentum — Smooth the gradient direction
Momentum maintains an exponential moving average of past gradients, often called the velocity v, and updates the weights using v rather than the raw gradient:
v = beta1 * v + (1 - beta1) * grad
theta = theta - lr * v
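With beta1=0.9, each velocity update keeps 90% of the previous velocity and mixes in 10% of the new gradient, so v behaves roughly like an average of the last 1/(1 - beta1) = 10 gradients. This damps oscillations in noisy directions while keeping movement along consistent directions at full strength, so progress along them dominates (the ball-rolling-downhill analogy). The gradient direction is now stable, but the per-parameter scale problem remains.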
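To see the damping numerically, here is a small pure-Python sketch with illustrative numbers only. It feeds the velocity update a hypothetical gradient stream: a consistent drift of +1 plus noise that flips sign every step with magnitude 5. The raw gradient swings between +6 and -4, while v settles into a narrow band around the drift.

```python
beta1 = 0.9
v = 0.0
raw, smoothed = [], []
for t in range(200):
    grad = 1.0 + (5.0 if t % 2 == 0 else -5.0)  # drift +1, alternating noise of ±5
    v = beta1 * v + (1 - beta1) * grad          # momentum (velocity EMA)
    raw.append(grad)
    smoothed.append(v)

late = smoothed[-20:]
print("raw gradient range:", min(raw), "to", max(raw))
print("velocity range, last 20 steps: %.3f to %.3f" % (min(late), max(late)))
```

In steady state the velocity alternates between about 0.74 and 1.26. The ±5 noise is cut to roughly ±0.26, while the +1 drift passes through unchanged.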
RMSProp — Adapt the learning rate per parameter
RMSProp maintains, for each parameter, an exponential moving average of past squared gradients, often called the second moment s, and uses it to normalize the step size:
s = beta2 * s + (1 - beta2) * grad**2
theta = theta - lr * grad / (sqrt(s) + eps)
Parameters that have been receiving large gradients accumulate large second-moment s values, shrinking their effective learning rate. Parameters with historically small gradients get larger effective steps. One global lr now behaves like per-parameter adaptive rates. Gradient direction is still the raw noisy gradient — no momentum yet.
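Here is a minimal sketch of that per-parameter normalization in pure Python. It assumes two hypothetical parameters whose gradients differ in scale by four orders of magnitude (100 versus 0.01, with random signs). Plain SGD would move the first parameter 10,000 times farther per step. After RMSProp divides by sqrt(s), both move by about lr.

```python
import math
import random

rng = random.Random(0)
lr, beta2, eps = 1e-3, 0.99, 1e-8
scales = [100.0, 0.01]      # hypothetical gradient magnitudes for two parameters
s = [0.0, 0.0]              # per-parameter second-moment EMA
mean_abs_step = [0.0, 0.0]

steps = 2000
measured = steps - steps // 2
for t in range(steps):
    for i, scale in enumerate(scales):
        grad = scale * rng.choice([-1.0, 1.0])        # random sign, fixed magnitude
        s[i] = beta2 * s[i] + (1 - beta2) * grad ** 2
        update = lr * grad / (math.sqrt(s[i]) + eps)  # RMSProp-normalized step
        if t >= steps // 2:                           # skip the warm-up of s
            mean_abs_step[i] += abs(update) / measured

print("plain SGD step sizes:", [lr * sc for sc in scales])
print("RMSProp mean |step|: ", mean_abs_step)
```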
Adam — Combining Momentum and RMSProp
Adam applies Momentum and RMSProp simultaneously: it smooths the gradient direction with momentum (v) and normalizes the step size by the root mean square of past gradients, sqrt(s).
Adam also adds bias-correction terms, 1 / (1 - beta**t), to account for the fact that v and s are initialized at zero. Without them, both moment estimates are biased toward zero during the first steps. Because beta2 is closer to 1 than beta1, s is biased more strongly than v, so the uncorrected early steps would actually be too large. The correction factor fades to ~1 as t grows and the accumulators warm up.
v = beta1 * v + (1 - beta1) * grad      # first moment (momentum)
s = beta2 * s + (1 - beta2) * grad**2   # second moment (mean of squared gradients)
v_hat = v / (1 - beta1**t)              # bias correction
s_hat = s / (1 - beta2**t)              # bias correction
theta = theta - lr * v_hat / (sqrt(s_hat) + eps)
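The effect of bias correction can be checked by hand at t=1 with zero-initialized moments. This pure-Python sketch uses the PyTorch-default betas and a few hypothetical gradient values, and compares the first step with and without the correction terms. Because beta2 is closer to 1 than beta1, s is shrunk toward zero more than v, so the uncorrected step comes out larger, not smaller.

```python
import math

beta1, beta2, eps, lr = 0.9, 0.999, 1e-8, 1e-3


def first_adam_step(grad, bias_correction=True):
    """Magnitude of Adam's first update (t=1, v and s start at zero)."""
    t = 1
    v = (1 - beta1) * grad           # beta1 * 0 + (1 - beta1) * grad
    s = (1 - beta2) * grad ** 2      # beta2 * 0 + (1 - beta2) * grad**2
    if bias_correction:
        v = v / (1 - beta1 ** t)
        s = s / (1 - beta2 ** t)
    return lr * v / (math.sqrt(s) + eps)


for g in (1e-4, 1.0, 1e3):
    print(f"grad={g:g}: corrected={first_adam_step(g):.6f}, "
          f"uncorrected={first_adam_step(g, bias_correction=False):.6f}")
```

At t=1 the uncorrected step is about 3.16x the corrected one, since (1 - 0.9) / sqrt(1 - 0.999) ≈ 3.16. The gap only disappears once beta2**t is small, which takes a few thousand steps with beta2=0.999.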
A common interview question is to define Adam and explain its beta1 and beta2 hyperparameters. We can now answer simply: beta1 and beta2 are the decay rates of the exponential moving averages of the first and second moments of the gradient (not of the loss).
Common values for (beta1, beta2) are (0.9, 0.999), the PyTorch defaults, or (0.9, 0.98) and (0.9, 0.95), both common in Transformer and LLM pre-training publications.
The common refrain on beta2 is to use a higher value (0.999) for dense-gradient problems and a lower value (0.95) for sparse-gradient problems.
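A lower beta2 shortens the moving-average window of s to roughly 1/(1 - beta2) steps: about 20 steps at 0.95 versus about 1,000 at 0.999. As a result, s reacts faster when the gradient scale changes. The pure-Python sketch below uses a hypothetical gradient stream: magnitude 1 for 1,000 steps, then a sudden jump to 10. It counts how many steps the bias-corrected sqrt(s_hat) needs to climb past 5, halfway to the new scale.

```python
import math


def steps_to_adapt(beta2, warm_steps=1000, old=1.0, new=10.0, threshold=5.0):
    """Steps after a gradient-scale jump until sqrt(s_hat) exceeds `threshold`."""
    s, t = 0.0, 0
    for _ in range(warm_steps):           # settle on the old gradient scale
        t += 1
        s = beta2 * s + (1 - beta2) * old ** 2
    k = 0
    while True:                           # now feed the new, larger scale
        t += 1
        k += 1
        s = beta2 * s + (1 - beta2) * new ** 2
        s_hat = s / (1 - beta2 ** t)      # bias-corrected second moment
        if math.sqrt(s_hat) > threshold:
            return k


for beta2 in (0.95, 0.98, 0.999):
    print(f"beta2={beta2}: {steps_to_adapt(beta2)} steps to react")
```

Under these assumptions, beta2=0.95 reacts within a handful of steps, while 0.999 needs well over a hundred. During that lag the denominator still reflects the old, smaller scale, so updates are correspondingly oversized.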
AdamW — Decoupled weight decay
AdamW is simply Adam with a "decoupled" form of L2 regularization (weight decay). Regularization needs to be decoupled because of the RMSProp component. In Adam, a classical L2 penalty is part of the loss, so its gradient becomes part of the loss gradient and is divided by the same adaptive RMSProp denominator. The effective strength of the regularization then scales up and down unpredictably across parameters. Parameters with a large gradient history get less decay, and parameters with a small gradient history get more.
AdamW fixes this by moving the L2 weight decay out of the loss & gradient computation and applying it directly to the parameter update:
# AdamW (correct: decay is applied separately from the adaptive step)
v = beta1 * v + (1 - beta1) * grad_loss
s = beta2 * s + (1 - beta2) * grad_loss**2
v_hat = v / (1 - beta1**t)
s_hat = s / (1 - beta2**t)
theta = theta - lr * v_hat / (sqrt(s_hat) + eps) - lr * weight_decay * theta
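The decoupled form applies the same relative shrinkage to every parameter on every step. Each weight is pulled toward zero by lr * weight_decay times its own value, regardless of its gradient history. The weight decay coefficient (λ, weight_decay) is now a clean, interpretable hyperparameter.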
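To make the difference concrete, here is a minimal scalar sketch in pure Python. It runs 100 Adam steps three ways: no decay, L2 folded into the gradient, and AdamW-style decoupled decay. The "quiet" and "loud" parameters, their constant gradients, and the step count are hypothetical. The sketch reports how much extra each parameter shrank because of the decay.

```python
import math


def train_scalar(grad_fn, mode, steps=100, lr=1e-3, wd=1e-2,
                 beta1=0.9, beta2=0.999, eps=1e-8, theta=1.0):
    """Scalar Adam with mode 'none', 'coupled' (Adam + L2) or 'decoupled' (AdamW)."""
    v = s = 0.0
    for t in range(1, steps + 1):
        g = grad_fn(theta)
        if mode == "coupled":
            g = g + wd * theta               # L2 gradient folded into the loss gradient
        v = beta1 * v + (1 - beta1) * g
        s = beta2 * s + (1 - beta2) * g * g
        v_hat = v / (1 - beta1 ** t)
        s_hat = s / (1 - beta2 ** t)
        theta = theta - lr * v_hat / (math.sqrt(s_hat) + eps)
        if mode == "decoupled":
            theta = theta * (1 - lr * wd)    # decay applied outside the adaptive step
    return theta


def decay_effect(grad_fn, mode):
    """Extra shrinkage caused by the decay, relative to a run with no decay."""
    return train_scalar(grad_fn, "none") - train_scalar(grad_fn, mode)


quiet = lambda theta: 0.0    # hypothetical parameter whose loss gradient is 0
loud = lambda theta: 10.0    # hypothetical parameter with a large, steady gradient

for name, fn in (("quiet", quiet), ("loud", loud)):
    print(f"{name}: coupled L2 shrink = {decay_effect(fn, 'coupled'):.2e}, "
          f"decoupled shrink = {decay_effect(fn, 'decoupled'):.2e}")
```

Under these assumptions, coupled L2 shrinks the quiet parameter roughly a hundred times more than intended. The adaptive denominator normalizes its tiny L2 gradient up to a full-size step. On the loud parameter, the same coupled L2 has essentially no effect. The decoupled version shrinks both by nearly the same amount.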
The following sections cover each of these components in more depth — the failure modes they address, their implementation, and how they interact with training stability.
A minimal AdamW implementation
import torch

class CustomAdamW(torch.optim.Optimizer):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.98), weight_decay=1e-2, eps=1e-8):
        defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay, eps=eps)
        super().__init__(params, defaults)
        for group in self.param_groups:
            for p in group['params']:
                self.state[p]['t'] = 0
                self.state[p]['v'] = torch.zeros_like(p)  # first moment
                self.state[p]['s'] = torch.zeros_like(p)  # second moment

    @torch.no_grad()
    def step(self, closure=None):
        # optional closure re-evaluates the loss, following the torch.optim convention
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            lr = group['lr']
            beta1, beta2 = group['betas']
            for p in group['params']:
                if p.grad is None:
                    continue
                state = self.state[p]
                state['t'] += 1
                v, s, t = state['v'], state['s'], state['t']
                # update the biased first and second moment estimates in place
                v.mul_(beta1).add_(p.grad, alpha=1 - beta1)
                s.mul_(beta2).addcmul_(p.grad, p.grad, value=1 - beta2)
                v_hat = v / (1 - beta1 ** t)
                s_hat = s / (1 - beta2 ** t)
                # adaptive gradient step
                p.addcdiv_(v_hat, s_hat.sqrt().add_(group['eps']), value=-lr)
                # decoupled weight decay applied after adaptive step
                p.mul_(1 - lr * group['weight_decay'])
        return loss
Training Stability
Even with a perfect optimizer, deep networks are prone to training instability. The root cause is structural: gradients are computed via the chain rule, meaning the gradient at an early layer is the product of Jacobians from every layer above it. With L layers, that's roughly L matrix multiplications chained together. If the typical singular value of those Jacobians is less than 1, the product shrinks exponentially with depth — the gradient vanishes. If greater than 1, it grows exponentially — the gradient explodes.
In a transformer with 32–96 layers, small systematic biases in those Jacobians compound dramatically. Without architectural safeguards, the gradient reaching the earliest layers has passed through dozens of weight matrices and activations, and it can arrive vanishingly small or enormously large.
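Here is a back-of-the-envelope version of this argument. Let h_l denote the activations after layer l. Assume, as a simplification, that every layer Jacobian J_l scales the gradient by the same typical singular value σ:

```latex
\begin{aligned}
\frac{\partial \mathcal{L}}{\partial h_0}
  &= J_1^{\top} J_2^{\top} \cdots J_L^{\top}\, \frac{\partial \mathcal{L}}{\partial h_L},
  \qquad J_l = \frac{\partial h_l}{\partial h_{l-1}} \\
\left\lVert \frac{\partial \mathcal{L}}{\partial h_0} \right\rVert
  &\approx \sigma^{L} \left\lVert \frac{\partial \mathcal{L}}{\partial h_L} \right\rVert \\
L = 48:\quad \sigma = 0.9 &\Rightarrow 0.9^{48} \approx 6.4 \times 10^{-3},
  \qquad \sigma = 1.1 \Rightarrow 1.1^{48} \approx 97
\end{aligned}
```

A typical singular value only 10% away from 1 already shrinks or amplifies the gradient by roughly two orders of magnitude over 48 layers.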
Vanishing gradients cause early layers to stop learning. Their parameters barely move because the loss signal is too attenuated by the time it arrives. This shows up as loss plateauing early, with later layers improving while earlier ones stagnate.
Exploding gradients cause large destabilizing parameter updates. A single outlier batch can produce a spike in gradient norms that corrupts the optimizer's moment estimates, leading to an outsized step, which can cause another spike. At scale these cascade into divergence.
Recognizing Instability in Practice
Training instability has a few recognizable signatures in the loss curve:
Loss spikes: a sudden jump followed by partial recovery. Usually indicates an outlier batch or a learning rate that's briefly too high. The danger is cascading: the spike's large gradient contaminates v and s and can produce an outsized update, which can push the model into a region that triggers another spike.
Loss divergence: monotonically increasing loss with no recovery. Typically means the learning rate is too high, warmup is misconfigured, or exploding gradients are propagating NaN/Inf values through the network.
Early plateau: loss stops improving long before convergence. Usually means the learning rate is too low, weight decay is too aggressive, or vanishing gradients are starving the early layers of signal.
Noisy but trending: generally fine. This is normal stochastic behavior, especially with small batch sizes or high dropout. If the noise is excessive, increasing the effective batch size (for example, with gradient accumulation) reduces gradient variance.
Healthy       Spike         Divergence    Early Plateau
loss          loss          loss          loss
│╲            │   ╭╮        │     ╱       │─────────
│ ╲──────     │───╯╰──────  │    ╱        │
│             │             │   ╱         │
└─────────    └─────────    └─────────    └─────────
steps         steps         steps         steps
              ↑             ↑             ↑
              lr spike      lr / NaN      lr / wd
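These signatures are often checked automatically during training. The hedged illustration below is not a standard tool: the function name, the EMA decay, the warmup length, and the 1.5x threshold are all arbitrary assumptions. It flags a loss spike by comparing each step's loss against an exponential moving average of recent losses.

```python
def find_loss_spikes(losses, beta=0.9, ratio=1.5, warmup=10):
    """Hypothetical heuristic: indices where loss exceeds `ratio` x its running EMA.

    Comparisons start after `warmup` steps, and flagged values are not folded
    into the EMA, so one spike does not raise the baseline and mask the next.
    """
    spikes = []
    ema = None
    for i, loss in enumerate(losses):
        if ema is None:
            ema = loss
            continue
        if i >= warmup and loss > ratio * ema:
            spikes.append(i)
            continue                  # keep the baseline uncontaminated
        ema = beta * ema + (1 - beta) * loss
    return spikes


# Hypothetical loss curve: smooth decay with one injected spike at step 60
curve = [2.0 * 0.99 ** i for i in range(100)]
curve[60] *= 3.0
print(find_loss_spikes(curve))  # → [60]
```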
Where AdamW Helps
AdamW addresses several dimensions of training instability, though not all of them.
The adaptive denominator sqrt(s_hat) directly mitigates gradient scale variance across layers. Parameters receiving consistently large gradients get smaller effective steps; parameters with historically small gradients get larger ones. This prevents the loudest layers from overwriting the signal in quieter layers — a problem SGD handles poorly with a single global learning rate.
The momentum term v smooths the update direction across steps, damping oscillation in noisy gradient directions while preserving movement along consistent ones.
Decoupled weight decay ensures regularization pressure is uniform across all parameters regardless of gradient history. With coupled Adam + L2, parameters receiving large gradients get less effective regularization. This unintended interaction can let weights in active layers grow unchecked.
Where AdamW Falls Short
AdamW does not fix vanishing or exploding gradients at the source. The adaptive denominator normalizes each update by its own parameter's gradient history — it cannot compensate for a gradient that has already been multiplied to near-zero or near-infinity before it arrives. That's a structural problem, and it requires structural solutions.
Gradient clipping is the direct mitigation for exploding gradients. It caps the global gradient norm before the optimizer step, rescaling the entire gradient vector when it exceeds a threshold:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
Direction is preserved; only magnitude is bounded. Clipping is reactive (it acts on an explosion only once it has happened), but it stops that explosion from corrupting the optimizer state and cascading. max_norm=1.0 is a common default for transformer training.
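The rescaling rule can be sketched in pure Python on a list of per-parameter gradient vectors. The sketch computes one L2 norm over all gradients and scales everything by max_norm / total_norm when that threshold is exceeded. PyTorch's actual `clip_grad_norm_` also adds a small epsilon to the denominator and supports other norm types, both of which this sketch omits. The gradient values are hypothetical.

```python
import math


def clip_by_global_norm(grads, max_norm=1.0):
    """Rescale gradient vectors so their combined L2 norm is at most max_norm.

    Returns the (possibly rescaled) gradients and the pre-clipping total norm.
    """
    total_norm = math.sqrt(sum(g * g for vec in grads for g in vec))
    scale = min(1.0, max_norm / total_norm) if total_norm > 0 else 1.0
    return [[g * scale for g in vec] for vec in grads], total_norm


# Hypothetical gradients for two parameter tensors, one of them exploding
grads = [[3.0, 4.0], [0.0, 12.0]]   # total norm = sqrt(9 + 16 + 144) = 13
clipped, norm = clip_by_global_norm(grads, max_norm=1.0)
print(norm)      # 13.0
print(clipped)   # every entry divided by 13, so the direction is unchanged
```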
Learning rate warmup addresses an instability that AdamW itself introduces. At step 0, s is initialized to zero. Even with bias correction, the variance estimate is unreliable for the first several hundred steps because it rests on only a handful of gradients. At t=1, for example, bias correction gives v_hat = grad and s_hat = grad**2, so every coordinate moves by almost exactly lr, however small its gradient. A linear ramp from zero over the first few thousand steps avoids these early destabilizing updates:
def lr_with_warmup(step, warmup_steps, base_lr):
    if step < warmup_steps:
        return base_lr * step / warmup_steps  # linear ramp from 0 up to base_lr
    return base_lr
Cosine decay rounds out the schedule. After warmup, decaying the learning rate along a cosine curve to ~10% of peak lets the model make large exploratory updates early and precise refinements late:
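import math

def cosine_lr(step, warmup_steps, total_steps, base_lr, min_lr_fraction=0.1):
    if step < warmup_steps:
        return base_lr * step / warmup_steps  # linear warmup
    # fraction of the post-warmup schedule completed, clamped so the lr
    # stays at min_lr instead of rising again after total_steps
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    cosine_factor = 0.5 * (1 + math.cos(math.pi * progress))  # goes from 1 to 0
    min_lr = base_lr * min_lr_fraction
    return min_lr + cosine_factor * (base_lr - min_lr)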
Residual connections and layer normalization are the architectural solutions to vanishing gradients — they operate at the network design level, not the optimizer level. Residuals create a gradient highway that bypasses the multiplicative chain; layer norm keeps activations in a well-scaled range before each sublayer. We'll cover both in detail in Part 4.
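Here is a rough numerical sketch of the "gradient highway" in pure Python. The width, depth, 0.1 branch gain, and Gaussian stand-in Jacobians are illustrative assumptions; real sublayers are nonlinear and normalized. A residual block y = x + f(x) has Jacobian I + J_f, so backpropagation computes grad + J_f^T grad instead of J_f^T grad alone. The identity term carries the signal through even when the branch is small.

```python
import math
import random


def backprop_norm(depth, gain, residual, width=32, seed=0):
    """Norm of a unit gradient after flowing back through `depth` layers.

    Each hypothetical branch Jacobian J has entries ~ N(0, gain**2 / width).
    Plain layers apply J^T; residual layers apply (I + J)^T.
    """
    rng = random.Random(seed)
    grad = [1.0 / math.sqrt(width)] * width   # unit-norm upstream gradient
    std = gain / math.sqrt(width)
    for _ in range(depth):
        jac = [[rng.gauss(0.0, std) for _ in range(width)] for _ in range(width)]
        # J^T @ grad: the branch's contribution to the backward pass
        branch = [sum(jac[i][j] * grad[i] for i in range(width)) for j in range(width)]
        grad = [g + b for g, b in zip(grad, branch)] if residual else branch
    return math.sqrt(sum(g * g for g in grad))


print("plain:   ", backprop_norm(48, 0.1, residual=False))
print("residual:", backprop_norm(48, 0.1, residual=True))
```

Under these assumptions, the plain stack's gradient collapses to roughly 0.1^48. The residual stack's gradient stays of order 1.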
The Full Picture
| Problem | Mechanism |
|---|---|
| Gradient scale variance across layers | AdamW adaptive rates |
| Noisy gradient direction | Momentum |
| Regularization entangled with gradient scale | AdamW decoupled weight decay |
| Exploding gradients | Gradient clipping |
| Adam cold-start instability | Learning rate warmup |
| Late-training over-stepping | Cosine decay |
| Vanishing gradients (structural) | Residual connections + layer norm |
AdamW makes deep network training tractable, but it's one piece of a system. The architectural solutions in the table are equally load-bearing — and we'll see exactly why in Part 4.
Next: Part 2 — Self-Supervision, Cross-Entropy & the Training Loop