This repo contains a PyTorch (0.4.1) implementation of the EKFAC and K-FAC preconditioners. If you find this software useful, please check the references below and cite accordingly!
We implemented K-FAC and EKFAC as preconditioners. Preconditioners are similar to PyTorch's `Optimizer` class, except that they do not perform the parameter update themselves: they only modify the gradients of the parameters. They can thus be used in combination with your favorite optimizer (we used SGD in our experiments). Note that we only implemented them for `Linear` and `Conv2d` modules, so they will silently skip all the other modules of your network.
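The split between a preconditioner (which rescales gradients in place) and an optimizer (which applies the update) can be sketched with plain-Python toy stand-ins. The class names `ScalePreconditioner` and `SGD` below are illustrative only and the scaling is a trivial placeholder, not the actual K-FAC/EKFAC computation:

```python
class Param:
    """Toy stand-in for a parameter tensor with a .grad attribute."""
    def __init__(self, value):
        self.value = value
        self.grad = 0.0

class ScalePreconditioner:
    """Toy preconditioner: step() rescales gradients, never touches values."""
    def __init__(self, params, scale):
        self.params = params
        self.scale = scale

    def step(self):
        for p in self.params:
            p.grad *= self.scale  # only the gradient changes

class SGD:
    """Toy optimizer: step() applies the (possibly preconditioned) gradient."""
    def __init__(self, params, lr):
        self.params = params
        self.lr = lr

    def step(self):
        for p in self.params:
            p.value -= self.lr * p.grad

p = Param(1.0)
p.grad = 2.0
preconditioner = ScalePreconditioner([p], scale=0.5)
optimizer = SGD([p], lr=0.1)
preconditioner.step()  # grad: 2.0 -> 1.0
optimizer.step()       # value: 1.0 - 0.1 * 1.0 = 0.9
print(p.value)
```

This is why the call order matters: `preconditioner.step()` must run after `loss.backward()` (so gradients exist) and before `optimizer.step()` (so the optimizer consumes the modified gradients).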
Here is a simple example showing how to add K-FAC or EKFAC to your code:
```python
# 1. Instantiate the preconditioner:
preconditioner = EKFAC(network, 0.1, update_freq=100)

# 2. In the training loop, call preconditioner.step() before
#    optimizer.step(). The optimizer is usually SGD.
for i, (inputs, targets) in enumerate(train_loader):
    optimizer.zero_grad()
    outputs = network(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    preconditioner.step()  # Precondition the gradients before the optimizer step.
    optimizer.step()
```