The problem: the Adam optimizer (and SGD) often converges to sharp minima in the loss landscape. Sharp minima tend to correspond to overfitted solutions: small perturbations of the weights produce large changes in the loss, and such solutions generalize poorly.
The fix (Izmailov et al. 2018): instead of using the final weights from training, average the weights visited late in training. The averaged weights sit in a flatter, wider region of the loss landscape, where small weight perturbations change the loss only slightly. Flatter minima have been shown to generalize better.
How it works here: every 60 seconds, the SWA module snapshots model.weights and updates a running average:
swa_w_new = (n × swa_w_old + new_w) / (n + 1)
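A minimal sketch of that update, assuming the snapshot is exposed as a list of NumPy arrays; the SWATracker class name and the list-of-arrays layout are illustrative, not the module's actual API:

```python
import numpy as np

class SWATracker:
    """Running average of weight snapshots (illustrative sketch)."""

    def __init__(self):
        self.swa_weights = None  # running average, same shapes as the model weights
        self.n_snapshots = 0     # snapshots folded in so far

    def update(self, new_weights):
        # swa_w_new = (n * swa_w_old + new_w) / (n + 1)
        new_weights = [np.asarray(w, dtype=np.float64) for w in new_weights]
        if self.swa_weights is None:
            self.swa_weights = new_weights
        else:
            n = self.n_snapshots
            self.swa_weights = [
                (n * old + new) / (n + 1)
                for old, new in zip(self.swa_weights, new_weights)
            ]
        self.n_snapshots += 1
```

Calling update() on each 60-second snapshot keeps swa_weights equal to the arithmetic mean of all snapshots taken so far, without storing the individual snapshots.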
Once we have ≥ 5 snapshots, SWA.predict() uses the averaged weights for inference. The Unified Predictor blends SWA's prediction (14% weight) with the live model, the multi-horizon ensemble, the bootstrap, and the k-NN components.
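A hedged sketch of how the gate and the blend could look. The 0.14 weight for SWA and the five-snapshot minimum come from the text; the other component weights, the component names, and the blend() helper are placeholders, not the Unified Predictor's actual values:

```python
MIN_SNAPSHOTS = 5        # SWA predictions are ignored until this many snapshots exist
SWA_BLEND_WEIGHT = 0.14  # from the text; the remaining weights below are placeholders

def blend(predictions, weights):
    """Weighted average of component predictions (hypothetical helper)."""
    total = sum(weights.values())
    return sum(predictions[name] * w for name, w in weights.items()) / total

def unified_predict(n_snapshots, predictions):
    """predictions: dict of component name -> scalar prediction."""
    weights = {"live": 0.40, "multi_horizon": 0.20, "bootstrap": 0.13, "knn": 0.13}
    if n_snapshots >= MIN_SNAPSHOTS and "swa" in predictions:
        weights["swa"] = SWA_BLEND_WEIGHT
    return blend({name: predictions[name] for name in weights}, weights)
```

If SWA is not yet ready, blend() divides by the sum of the remaining weights, so the other components implicitly absorb its share.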
Divergence as an uncertainty signal: if the L2 distance between the SWA weights and the live model's weights is large, the model is still drifting fast (high uncertainty). If the divergence is small, the model has settled (lower uncertainty).
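A sketch of the divergence computation under the same list-of-NumPy-arrays assumption; how the raw distance is scaled or thresholded into an uncertainty score is not specified in the text:

```python
import numpy as np

def weight_divergence(live_weights, swa_weights):
    """L2 distance between the live weights and the SWA average, across all layers."""
    diffs = [np.ravel(np.asarray(a) - np.asarray(b))
             for a, b in zip(live_weights, swa_weights)]
    return float(np.linalg.norm(np.concatenate(diffs)))

# A large distance means the live model is still moving away from its recent
# average (high uncertainty); a small one means it has settled (lower uncertainty).
```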