Snapshot Ensembles

This repository contains a Keras implementation of the paper Snapshot Ensembles: Train 1, get M for free (Huang et al., 2017).

The authors use a modified cyclical learning-rate schedule (cyclic cosine annealing) that drives the model into a local minimum at the end of each cycle, where a snapshot of the model is saved. Models sitting in different local minima tend to make different mistakes, so an ensemble of these snapshots generalizes better than a single model.

[Figure: snapshot ensembling]

[Figure: learning-rate formula]
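For reference, the paper's cyclic cosine annealing schedule is shown below, where $\alpha_0$ is the initial learning rate, $t$ the iteration number, $T$ the total number of training iterations and $M$ the number of cycles. The rate restarts at $\alpha_0$ at the start of each cycle and anneals toward zero by its end, which is when a snapshot is taken. That the callback's nb_epochs and nb_cycles play the roles of $T$ and $M$ (counted in epochs rather than iterations) is an assumption, not something stated here.

```latex
\alpha(t) = \frac{\alpha_0}{2}\left(\cos\left(\frac{\pi \operatorname{mod}\left(t-1, \lceil T/M \rceil\right)}{\lceil T/M \rceil}\right) + 1\right)
```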

Prototype

Snapshot is a Keras callback with the following signature:

Snapshot(folder_path, nb_epochs, nb_cycles=5, verbose=0)

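With:
- folder_path: directory where the snapshots are saved.
- nb_epochs: total number of training epochs; it should match the epochs passed to model.fit.
- nb_cycles: number of learning-rate cycles, and hence of snapshots (default 5).
- verbose: verbosity level (default 0).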
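The sketch below shows how nb_epochs and nb_cycles could map onto the paper's schedule when it is evaluated once per epoch. It assumes T = nb_epochs and M = nb_cycles; the function names snapshot_lr and is_snapshot_epoch and the lr_max default are hypothetical and not taken from the snapshot module.

```python
import math


def snapshot_lr(epoch, nb_epochs, nb_cycles, lr_max=0.1):
    """Cyclic cosine annealing learning rate for a 0-indexed epoch.

    Hypothetical helper: applies the paper's formula per epoch instead of
    per iteration; lr_max is an assumed default, not this repo's value.
    """
    # Length of one cycle in epochs (ceil(T / M) in the paper's notation).
    cycle_len = math.ceil(nb_epochs / nb_cycles)
    # Position inside the current cycle, from 0 to cycle_len - 1.
    t = epoch % cycle_len
    # Anneal from lr_max toward 0, restarting at lr_max each new cycle.
    return lr_max / 2 * (math.cos(math.pi * t / cycle_len) + 1)


def is_snapshot_epoch(epoch, nb_epochs, nb_cycles):
    """Hypothetical helper: True on the last epoch of each cycle."""
    # Snapshots are taken where the learning rate is lowest, i.e. at the
    # end of a cycle, just before the rate restarts.
    cycle_len = math.ceil(nb_epochs / nb_cycles)
    return (epoch + 1) % cycle_len == 0
```

With the values from the usage example (nb_epochs=6, nb_cycles=2), the rate goes 0.1, 0.075, 0.025 over epochs 0 to 2, restarts at 0.1 for epochs 3 to 5, and snapshots fall on epochs 2 and 5.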

Usage

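from snapshot import Snapshot

# model, x_train and y_train are assumed to be defined already.
# Train for 6 epochs split into 2 cycles, i.e. 2 snapshots.
nb_epochs = 6
callback = Snapshot('snapshots', nb_epochs=nb_epochs, verbose=1, nb_cycles=2)
model.fit(
    x=x_train, y=y_train,
    epochs=nb_epochs,  # keep in sync with the callback's nb_epochs
    batch_size=32,
    callbacks=[callback]
)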

The authors recommend averaging the outputs of the snapshot models. The file example.py shows one way to do this.
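A minimal sketch of that averaging, assuming each saved snapshot has already been loaded into a model exposing a Keras-style predict method. ensemble_predict is a hypothetical helper, and example.py may do it differently.

```python
import numpy as np


def ensemble_predict(models, x):
    """Average the predictions of several snapshot models.

    Hypothetical helper: `models` is any sequence of objects with a
    Keras-style `predict(x)` method, e.g. models restored from snapshots.
    """
    # Stack predictions into shape (n_models, n_samples, n_classes).
    preds = np.stack([m.predict(x) for m in models])
    # The mean over the model axis is the ensemble's output.
    return preds.mean(axis=0)
```

Class labels then follow from the averaged probabilities, e.g. `ensemble_predict(snapshot_models, x_test).argmax(axis=1)`, where snapshot_models and x_test are placeholders for your own loaded models and test data.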