Keras AdamW


Keras implementation of AdamW, SGDW, NadamW, and Warm Restarts, based on the paper Decoupled Weight Decay Regularization - plus Learning Rate Multipliers

Features

- Decoupled weight decay for Adam, SGD, and Nadam (AdamW, SGDW, NadamW), per Decoupled Weight Decay Regularization
- Warm restarts via cosine annealing of the learning rate
- Per-layer learning rate multipliers

Installation

`pip install keras-adamw` or clone the repository

Usage

If using `tensorflow.keras` imports, set `import os; os.environ["TF_KERAS"] = '1'`.
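A minimal sketch of the intended ordering (assumption: the flag is read when `keras_adamw` is imported, so it must be set first):

```python
import os
os.environ["TF_KERAS"] = '1'  # tell keras-adamw to use tensorflow.keras

from keras_adamw import AdamW  # import only after the flag is set
```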

Weight decay

`AdamW(model=model)`
Three methods to set `weight_decays = {<weight matrix name>: <weight decay value>,}`:

1. Automatically: just pass in `model` (`AdamW(model=model)`) and decays will be extracted automatically.
   Loss-based penalties (`l1`, `l2`, `l1_l2`) are zeroed by default, but can be kept via
   `zero_penalties=False` (NOT recommended, see Use guidelines).

2. Use `keras_adamw.utils`:

```python
from keras_adamw.utils import get_weight_decays, fill_dict_in_order

Dense(.., kernel_regularizer=l2(0))  # set weight decays in layers as usual, but to ZERO
wd_dict = get_weight_decays(model)
# print(wd_dict) to see returned matrix names, note their order
# specify values as (l1, l2) tuples, both for l1_l2 decay
ordered_values = [(0, 1e-3), (1e-4, 2e-4), ..]
weight_decays = fill_dict_in_order(wd_dict, ordered_values)
```

3. Fill manually:

```python
model.layers[1].kernel.name  # get name of kernel weight matrix of layer indexed 1
weight_decays.update({'conv1d_0/kernel:0': (1e-4, 0)})  # example
```
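Whichever method is used, the resulting dict is then handed to the optimizer; a sketch, assuming the `weight_decays` keyword referenced above (the dict values are illustrative):

```python
from keras_adamw import AdamW

# dict built by any of the three methods above
weight_decays = {'conv1d_0/kernel:0': (1e-4, 0)}

optimizer = AdamW(lr=1e-3, model=model, weight_decays=weight_decays)
model.compile(optimizer, loss='binary_crossentropy')
```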

Warm restarts

`AdamW(.., use_cosine_annealing=True, total_iterations=200)` - refer to Use guidelines below
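The learning rate multiplier follows the cosine annealing schedule of Loshchilov & Hutter; a rough sketch of that schedule as a function of the iteration counter (assumption: `t_cur` counts iterations since the last restart, as in the example further below):

```python
import numpy as np

def cosine_annealing(t_cur, total_iterations):
    """Schedule multiplier: 1 at t_cur=0, decaying to 0 at t_cur=total_iterations."""
    return 0.5 * (1 + np.cos(np.pi * t_cur / total_iterations))

print(cosine_annealing(0, 200))    # 1.0   (start of a cycle)
print(cosine_annealing(100, 200))  # ~0.5  (halfway through)
print(cosine_annealing(200, 200))  # ~0.0  (end of cycle)
# a "warm restart" resets the counter (the example below sets t_cur to -1),
# which brings the multiplier back toward 1
```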

LR multipliers

`AdamW(.., lr_multipliers=lr_multipliers)` - to build `lr_multipliers = {<layer name>: <multiplier value>,}`:

  1. (a) Name every layer to be modified (recommended), e.g. `Dense(.., name='dense_1')` - OR
     (b) Get every layer name and note which to modify: `[print(idx, layer.name) for idx, layer in enumerate(model.layers)]`
  2. (a) `lr_multipliers = {'conv1d_0': 0.1}` - target a layer by its full name - OR
     (b) `lr_multipliers = {'conv1d': 0.1}` - target all layers whose names contain the substring 'conv1d' (see the sketch below)
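A minimal illustration of what a multiplier does (the layer name and 0.5 value are just examples): layers whose names match a key train at the base learning rate scaled by that key's value, while unmatched layers keep the base rate.

```python
base_lr = 1e-3
lr_multipliers = {'lstm_1': 0.5}  # layers matching 'lstm_1' train at 0.5 * base_lr

effective_lr = base_lr * lr_multipliers['lstm_1']  # 5e-4 for 'lstm_1' weights
# layers not matched by any key keep base_lr (1e-3)
```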

Example

```python
import numpy as np
from keras.layers import Input, Dense, LSTM
from keras.models import Model
from keras.regularizers import l1, l2, l1_l2
from keras import backend as K
from keras_adamw import AdamW

ipt   = Input(shape=(120, 4))
x     = LSTM(60, activation='relu', name='lstm_1',
             kernel_regularizer=l1(1e-4), recurrent_regularizer=l2(2e-4))(ipt)
out   = Dense(1, activation='sigmoid', kernel_regularizer=l1_l2(1e-4, 2e-4))(x)
model = Model(ipt, out)
```

```python
lr_multipliers = {'lstm_1': 0.5}

optimizer = AdamW(lr=1e-4, model=model, lr_multipliers=lr_multipliers,
                  use_cosine_annealing=True, total_iterations=24)
model.compile(optimizer, loss='binary_crossentropy')
for epoch in range(3):
    for iteration in range(24):
        x = np.random.rand(10, 120, 4) # dummy data
        y = np.random.randint(0, 2, (10, 1)) # dummy labels
        loss = model.train_on_batch(x, y)
        print("Iter {} loss: {}".format(iteration + 1, "%.3f" % loss))
        if iteration == (24 - 2):
            K.set_value(model.optimizer.t_cur, -1)  # WARM RESTART: reset cosine annealing argument
    print("EPOCH {} COMPLETED\n".format(epoch + 1))
```

(Full example + plot code, and an explanation of `lr_t` vs. `lr`: see `example.py`)

Use guidelines

Weight decay

Warm restarts