[P] Keras SWA: Stochastic weight averaging callback for Keras
As an exercise for myself I decided to implement SWA, from the paper Averaging Weights Leads to Wider Optima and Better Generalization. I did it with Keras and decided it might make a nice package.
pip install keras-swa
If you are not familiar with SWA, it is a trick to approximate ensembling by taking a running average of your weights towards the end of training a model. You can read more in this nice blog post explaining SWA and its relatives SSE and FGE.
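To make the running average concrete, here is a minimal framework-free sketch. It is not the code from the package: the `update_swa` helper and the toy snapshots are made up for illustration, and real weights would be arrays per layer rather than short lists.

```python
def update_swa(swa_weights, new_weights, n_averaged):
    """Fold one more set of weights into the running average.

    n_averaged is how many weight sets swa_weights already averages.
    """
    return [
        (avg * n_averaged + w) / (n_averaged + 1)
        for avg, w in zip(swa_weights, new_weights)
    ]


# Hypothetical weight snapshots taken at the end of three late epochs.
snapshots = [[1.0, 4.0], [2.0, 6.0], [3.0, 8.0]]

swa = snapshots[0]
for n, weights in enumerate(snapshots[1:], start=1):
    swa = update_swa(swa, weights, n)

print(swa)  # equals the plain mean of the three snapshots
```

Updating incrementally like this means only one extra copy of the weights has to be kept around, instead of every snapshot.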
I currently only implement the constant learning rate schedule from the paper, and I hope to add the cyclic one soon. It is also possible to leave the learning rate to the optimizer or to other schedulers. I have also not implemented the batch normalization fix. It requires a forward pass over the training data with the averaged weights, which I don’t know how to do from a callback. So any help there would be appreciated.
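For anyone wondering what the batch normalization fix has to compute, here is a rough sketch in plain Python. It only shows the statistics step, under the assumption that you already have the activations entering a batch norm layer for every training batch (produced with the averaged weights). The `batch_norm_statistics` name and the toy numbers are hypothetical, and getting those activations from inside a Keras callback is exactly the part that is still open.

```python
def batch_norm_statistics(batches):
    """Per-feature mean and (biased) variance over all rows in all batches."""
    rows = [row for batch in batches for row in batch]
    count = len(rows)
    n_features = len(rows[0])
    means = [sum(r[j] for r in rows) / count for j in range(n_features)]
    variances = [
        sum((r[j] - means[j]) ** 2 for r in rows) / count
        for j in range(n_features)
    ]
    return means, variances


# Hypothetical activations: two batches, two rows each, two features.
batches = [
    [[1.0, 10.0], [3.0, 30.0]],
    [[5.0, 50.0], [7.0, 70.0]],
]

means, variances = batch_norm_statistics(batches)
print(means, variances)
```

The resulting means and variances would then replace the layer's stored moving statistics, since those were accumulated for the weights seen during training and not for the averaged ones.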
I would love for people to try it! Feedback is also welcome! 🙂