This repository contains the PyTorch implementation of the paper *Deep Frank-Wolfe For Neural Network Optimization*. If you use this work for your research, please cite the paper:
```
@Article{berrada2019deep,
  author  = {Berrada, Leonard and Zisserman, Andrew and Kumar, M Pawan},
  title   = {Deep Frank-Wolfe For Neural Network Optimization},
  journal = {International Conference on Learning Representations},
  year    = {2019},
}
```
This code should work for PyTorch >= 1.0 in Python 3. Detailed requirements are available in `requirements.txt`.
To install the package and its dependencies:

```bash
git clone --recursive https://github.com/oval-group/dfw
cd dfw && pip install -r requirements.txt
python setup.py install
```

(note that the `--recursive` option is necessary to clone the submodules; these are needed to reproduce the experiments, but not for the DFW implementation itself)
The DFW optimizer can be used as follows:

```python
from dfw import DFW
from dfw.losses import MultiClassHingeLoss

# `model` is a nn.Module
# `x` is an input sample, `y` is a label

svm = MultiClassHingeLoss()
optimizer = DFW(model.parameters(), eta=0.1)

# in the training loop:
optimizer.zero_grad()
loss = svm(model(x), y)
loss.backward()
# DFW needs the current loss value to compute its step-size,
# hence the closure passed to step()
optimizer.step(lambda: float(loss))
```
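Putting the pieces together, here is a minimal, self-contained training sketch based on the snippet above. The toy model, random data, and hyper-parameters (`eta=0.1`, batch size, number of epochs) are illustrative placeholders, not values recommended by the paper:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from dfw import DFW
from dfw.losses import MultiClassHingeLoss

# toy data and model, for illustration only
x_train = torch.randn(256, 20)
y_train = torch.randint(0, 10, (256,))
loader = DataLoader(TensorDataset(x_train, y_train), batch_size=32, shuffle=True)
model = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 10))

svm = MultiClassHingeLoss()
optimizer = DFW(model.parameters(), eta=0.1)

for epoch in range(5):
    for x, y in loader:
        optimizer.zero_grad()
        loss = svm(model(x), y)
        loss.backward()
        # pass the loss value to DFW through a closure
        optimizer.step(lambda: float(loss))
```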
* Technical requirement: DFW uses a custom step-size at each iteration. For this update to make sense, the loss function must be convex and piecewise linear. For instance, one can use a multi-class SVM loss or an l1 regression loss (see the sketch after this list).
* Smoothing: sometimes the multi-class SVM loss does not fare well with a large number of classes. This issue can be alleviated by dual smoothing, which is easy to plug into the code:
```python
from dfw.losses import set_smoothing_enabled
...
with set_smoothing_enabled(True):
    loss = svm(model(x), y)
```
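As a concrete illustration of the piecewise-linear requirement above, the sketch below pairs DFW with an l1 regression loss (`nn.L1Loss`); the linear model and random data are placeholders, and only `DFW` comes from this repository:

```python
import torch
import torch.nn as nn

from dfw import DFW

# toy regression problem, for illustration only
model = nn.Linear(10, 1)
x, y = torch.randn(32, 10), torch.randn(32, 1)

# the l1 loss is convex and piecewise linear, as required by DFW
l1 = nn.L1Loss()
optimizer = DFW(model.parameters(), eta=0.1)

optimizer.zero_grad()
loss = l1(model(x), y)
loss.backward()
optimizer.step(lambda: float(loss))
```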
To reproduce the experiments:

```bash
# CIFAR experiments (point VISION_DATA to your CIFAR data directory)
VISION_DATA=[path/to/your/cifar/data] python reproduce/cifar.py

# SNLI experiments
python reproduce/snli.py
```
DFW largely outperforms all baselines that do not use a manual schedule for the learning rate. The tables below show the performance on the CIFAR data sets when using data augmentation (AMSGrad, a variant of Adam, is the strongest baseline in our experiments), and on the SNLI data set.
CIFAR-10 test accuracy (%):

| Optimizer           | Wide Residual Networks | Densely Connected Networks |
| ------------------- | :--------------------: | :------------------------: |
| AMSGrad             |          90.1          |            91.8            |
| **DFW**             |        **94.7**        |          **94.9**          |
| SGD (with schedule) |          95.4          |            95.3            |
CIFAR-100 test accuracy (%):

| Optimizer           | Wide Residual Networks | Densely Connected Networks |
| ------------------- | :--------------------: | :------------------------: |
| AMSGrad             |          67.8          |            69.6            |
| **DFW**             |        **74.7**        |          **73.2**          |
| SGD (with schedule) |          77.8          |            76.3            |
We use the following third-party implementations: