The Flat Minimum
Standard practice: decay the learning rate over the course of pre-training so the optimizer converges, and it converges to a sharp, deep minimum of the loss function. The loss goes down. The training metrics look better. The model appears to have learned more.
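The two regimes can be sketched as schedules. This is a minimal illustration, not the paper's setup: `peak_lr` and `warmup_steps` are arbitrary illustrative values, and the decay branch uses a standard cosine shape.

```python
import math

def lr_with_decay(step, total_steps, peak_lr=3e-4, warmup_steps=1000):
    """Standard schedule: linear warmup, then cosine decay to zero."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return peak_lr * 0.5 * (1 + math.cos(math.pi * progress))

def lr_constant(step, total_steps, peak_lr=3e-4, warmup_steps=1000):
    """Same warmup, then a constant rate for the rest of training."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    return peak_lr
```

Both schedules are identical through warmup; they diverge only afterward, which is exactly where the sharpness difference originates.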
arXiv:2603.16127 shows this makes the model worse at the thing you actually care about. Models pre-trained with a constant learning rate — no decay after warmup — outperform decay-trained models on downstream tasks after supervised fine-tuning, despite having worse pre-training metrics.
The mechanism is geometric. Learning rate decay pushes the model into sharper minima of the pre-training loss landscape. These minima have steep walls — small parameter changes produce large loss increases. When fine-tuning begins, the optimizer must climb out of this sharp basin to reach the new task’s good region. The steepness resists movement. The model is stuck near where pre-training left it.
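"Small parameter changes produce large loss increases" can be made concrete with a crude flatness probe: perturb the parameters randomly and average the loss increase. The probe below is a hypothetical sketch on toy quadratic basins, not the paper's measurement; the curvature constants 100 and 1 are arbitrary.

```python
import random

def sharpness_probe(loss_fn, w, eps=1e-2, n_trials=100, seed=0):
    """Average loss increase under small random perturbations of w.
    Larger values indicate a sharper basin around w."""
    rng = random.Random(seed)
    base = loss_fn(w)
    total = 0.0
    for _ in range(n_trials):
        delta = [rng.uniform(-eps, eps) for _ in w]
        total += loss_fn([wi + di for wi, di in zip(w, delta)]) - base
    return total / n_trials

# Two toy basins with the same minimum value but different curvature.
sharp = lambda w: 100.0 * sum(wi * wi for wi in w)  # steep walls
flat  = lambda w:   1.0 * sum(wi * wi for wi in w)  # gentle walls
```

At the shared minimum, the same perturbation costs roughly 100x more loss in the sharp basin, which is the resistance to movement described above.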
A constant learning rate keeps the model in a flatter region of the loss landscape. The pre-training loss is slightly higher, but the surrounding geometry is gentle — the model can move freely toward new minima during fine-tuning. The “worse” pre-training state is a better starting point because it preserves adaptability.
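One way to see why flatness preserves adaptability is a toy one-dimensional landscape, assumed here purely for illustration: the pre-training basin is a quadratic of curvature `a` centered at 0, the fine-tuning task wants the parameter at 1, and the model settles where the two pulls balance.

```python
def finetune_minimizer(pretrain_curvature):
    """Minimizer of a*w^2 + (w - 1)^2: a toy combined landscape where
    the pre-training basin (curvature a, minimum at 0) pulls against
    the fine-tuning objective (minimum at 1).
    Closed form: w* = 1 / (1 + a)."""
    return 1.0 / (1.0 + pretrain_curvature)
```

With a sharp pre-training basin (a = 100) the model ends up at w ~ 0.01, barely moving; with a flat one (a = 1) it reaches w = 0.5, halfway to the new task's optimum. The sharper the hole, the less the fine-tuning signal can move the weights.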
This decouples two quantities usually conflated: pre-training loss and downstream performance. Lower pre-training loss does not mean better representations for downstream tasks. It means the optimizer found a deeper hole that is harder to escape. The metric that measures pre-training quality (loss) is anti-correlated with the property that enables downstream quality (loss landscape flatness).
The pre-training metric optimizes for one thing. The deployment requires another. And the two are in tension: getting better at the metric means getting worse at the actual use case.