by Nitish Shirish Keskar, Dheevatsa Mudigere, Jorge Nocedal, Mikhail Smelyanskiy and Peter Tang
Paper link: [arXiv preprint](https://arxiv.org/abs/1609.04836)
Our code was written in Keras 1.X; a number of API changes in Keras 2.X have broken it. We are working on updating the code to support Keras 2.X, but in the meantime we provide a preliminary PyTorch implementation (refer to the PyTorch folder for details). As always, we welcome any questions, suggestions, requests, or bug reports.
This repository contains the (Python) code needed to reproduce some of the figures in our paper. The plots illustrate the relative sharpness of the minima obtained when training with small-batch (SB) and large-batch (LB) methods. For ease of exposition we use a Keras/Theano setup, but owing to the simplicity of the code, translating it into other frameworks should be straightforward.
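The parametric plots rest on a simple idea: evaluate the loss along the line segment connecting the SB and LB minimizers in weight space. Below is a minimal sketch of that computation, assuming `model` is a compiled Keras model and `sb_weights`/`lb_weights` are the weight lists (as returned by `model.get_weights()`) after small-batch and large-batch training; `x_test` and `y_test` are placeholder names for your evaluation data, not part of the repository's API.

```python
import numpy as np
import matplotlib
matplotlib.use('Agg')  # render to file without a display
import matplotlib.pyplot as plt

def interpolated(sb_weights, lb_weights, alpha):
    """Point on the line segment between the two minimizers:
    alpha = 0 is the SB solution, alpha = 1 is the LB solution."""
    return [(1.0 - alpha) * w_sb + alpha * w_lb
            for w_sb, w_lb in zip(sb_weights, lb_weights)]

# Sweep alpha past both endpoints so the relative sharpness of the
# two minima is visible in the resulting curve.
alphas = np.linspace(-1.0, 2.0, 25)
losses = []
for alpha in alphas:
    model.set_weights(interpolated(sb_weights, lb_weights, alpha))
    # evaluate() returns [loss, metrics...] when metrics are compiled in;
    # keep only the loss here.
    out = model.evaluate(x_test, y_test, verbose=0)
    losses.append(out[0] if isinstance(out, list) else out)

plt.plot(alphas, losses)
plt.xlabel('alpha')
plt.ylabel('loss')
plt.savefig('parametric_plot.pdf')
```

Plotted this way, a sharp minimizer appears as a narrow valley around its endpoint, while a flat one produces a wide basin.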
If you use this code or our results in your research, please cite:
@article{Keskar2016,
author = {Nitish Shirish Keskar and Dheevatsa Mudigere and Jorge Nocedal and Mikhail Smelyanskiy and Ping Tak Peter Tang},
title = {On Large-Batch Training for Deep Learning: Generalization Gap and Sharp Minima},
journal = {arXiv preprint arXiv:1609.04836},
year = {2016}
}
To reproduce the parametric plots, you only need two Python files: `plot_parametric_plot.py` and `network_zoo.py`. The latter contains the model configurations for the C1-C4 networks; the former trains the model imported from `network_zoo` using the SB and LB methods and plots the parametric plot connecting the two minimizers. The network is chosen via the command-line argument `-n` (or `--network`), and the generated plot is saved in PDF form. For instance, to plot for the C1 network, one can simply run:
```
KERAS_BACKEND=theano python plot_parametric_plot.py -n C1
```
with the necessary Theano flags depending on your setup.
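For concreteness, on a GPU machine the invocation might look like the following; the flags shown are illustrative and depend on your Theano version and hardware (older Theano releases use `device=gpu`, newer ones `device=cuda`):

```
THEANO_FLAGS='floatX=float32,device=gpu' KERAS_BACKEND=theano python plot_parametric_plot.py -n C1
```

The generated plot should resemble the corresponding figure in the `Figures/` folder.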