import keras.backend as K
import numpy as np
from keras.callbacks import Callback, ModelCheckpoint


class HistoryCache:
    """Fixed-size ring buffer holding the most recently stored values.

    # Arguments
        his_len: Maximum number of values kept. Once the buffer is full,
            each new value overwrites the oldest one.
    """

    def __init__(self, his_len=10):
        self.history = [0] * his_len
        self.history_len = his_len
        self.cursor = 0  # Index of the slot the next `put` will write to.
        self.len = 0     # Number of slots that hold real values so far.

    def put(self, value):
        """Store `value`, overwriting the oldest entry once the buffer is full."""
        self.history[self.cursor] = value
        # Wrap around to the first slot after the last one has been written.
        self.cursor = (self.cursor + 1) % self.history_len
        if self.len < self.history_len:
            self.len += 1

    def mean(self):
        """Return the mean of the stored values.

        While the buffer is still empty this is the mean of an empty array,
        which NumPy evaluates to `nan`.
        """
        return np.array(self.history[:self.len]).mean()


class SGDRScheduler(Callback):
    """Cosine annealing learning rate scheduler with periodic restarts.

    # Usage
        ```python
            schedule = SGDRScheduler(min_lr=1e-5,
                                     max_lr=1e-2,
                                     steps_per_epoch=np.ceil(epoch_size/batch_size),
                                     lr_decay=0.9,
                                     cycle_length=5,
                                     mult_factor=1.5)
            model.fit(X_train, Y_train, epochs=100, callbacks=[schedule])
        ```

    # Arguments
        min_lr: The lower bound of the learning rate range for the experiment.
        max_lr: The upper bound of the learning rate range for the experiment.
        steps_per_epoch: Number of mini-batches in the dataset.
            Calculated as `np.ceil(epoch_size/batch_size)`.
        lr_decay: Reduce the max_lr after the completion of each cycle.
            Ex. To reduce the max_lr by 20% after each cycle, set this
            value to 0.8.
        cycle_length: Initial number of epochs in a cycle.
        mult_factor: Scale the cycle length after each full cycle completion.
        initial_epoch: Used to resume training,
            **note**: Other args must be same as last training.

    # Raises
        ValueError: If `cycle_length` or `mult_factor` is not positive.

    # References
        Blog post: jeremyjordan.me/nn-learning-rate
        Original paper: http://arxiv.org/abs/1608.03983
    """

    def __init__(self,
                 min_lr,
                 max_lr,
                 steps_per_epoch,
                 lr_decay=1,
                 cycle_length=10,
                 mult_factor=2,
                 initial_epoch=0):
        super().__init__()

        # Non-positive values would make the cycle length collapse to zero,
        # which means a division by zero in `clr` and a non-terminating
        # replay loop in `recovery_status`.
        if cycle_length <= 0:
            raise ValueError("cycle_length must be positive")
        if mult_factor <= 0:
            raise ValueError("mult_factor must be positive")

        self.min_lr = min_lr
        self.max_lr = max_lr
        self.lr_decay = lr_decay

        self.batch_since_restart = 0
        self.next_restart = cycle_length

        self.steps_per_epoch = steps_per_epoch

        self.cycle_length = cycle_length
        self.mult_factor = mult_factor

        self.history = {}
        # Weights captured at the end of the most recent completed cycle;
        # stays None until the first cycle finishes.
        self.best_weights = None

        self.recovery_status(initial_epoch)

    def recovery_status(self, initial_epoch):
        """Restore the scheduler state reached after `initial_epoch` epochs.

        Replays the same per-cycle updates `on_epoch_end` applies (decay
        `max_lr`, grow `cycle_length` with `np.ceil`) for every cycle that
        fits entirely within `initial_epoch`, then positions the scheduler
        inside the cycle that was in progress.
        """
        done_epochs = 0
        while initial_epoch >= done_epochs + self.cycle_length:
            done_epochs += self.cycle_length
            self.max_lr *= self.lr_decay
            self.cycle_length = np.ceil(self.cycle_length * self.mult_factor)

        # Epoch count at which the cycle in progress ends; without this the
        # restart check in `on_epoch_end` would never fire after a resume.
        self.next_restart = done_epochs + self.cycle_length
        self.batch_since_restart = (initial_epoch - done_epochs) * self.steps_per_epoch

    def clr(self):
        """Calculate the learning rate for the current position in the cycle."""
        fraction_to_restart = self.batch_since_restart / (self.steps_per_epoch * self.cycle_length)
        lr = self.min_lr + 0.5 * (self.max_lr - self.min_lr) * (1 + np.cos(fraction_to_restart * np.pi))
        return lr

    def on_train_begin(self, logs=None):
        """Set the learning rate for the first batch of training.

        At the start of a cycle `clr()` evaluates to `max_lr`; when training
        is resumed mid-cycle it yields the rate for the restored position.
        """
        K.set_value(self.model.optimizer.lr, self.clr())

    def on_batch_end(self, batch, logs=None):
        """Record previous batch statistics and update the learning rate."""
        logs = logs or {}
        self.history.setdefault('lr', []).append(K.get_value(self.model.optimizer.lr))
        for k, v in logs.items():
            self.history.setdefault(k, []).append(v)

        self.batch_since_restart += 1
        K.set_value(self.model.optimizer.lr, self.clr())

    def on_epoch_end(self, epoch, logs=None):
        """Check for end of current cycle, apply restarts when necessary."""
        if epoch + 1 == self.next_restart:
            self.batch_since_restart = 0
            self.cycle_length = np.ceil(self.cycle_length * self.mult_factor)
            self.next_restart += self.cycle_length
            self.max_lr *= self.lr_decay
            self.best_weights = self.model.get_weights()

    def on_train_end(self, logs=None):
        """Set weights to the values from the end of the most recent cycle.

        If no cycle completed during training there is nothing to restore,
        so the model keeps its current weights.
        """
        if self.best_weights is not None:
            self.model.set_weights(self.best_weights)


class LRScheduler(Callback):
    """Adjust the learning rate when a watched metric rises above its recent mean.

    # Arguments
        schedule: Callable invoked as `schedule(epoch, lr)`; its return value
            becomes the new learning rate.
        watch: Key of the value to read from the epoch-end `logs`.
        watch_his_len: Number of most recent watched values that are averaged.
    """

    def __init__(self, schedule, watch, watch_his_len=10):
        super().__init__()
        self.schedule = schedule
        self.watch = watch
        self.history_cache = HistoryCache(watch_his_len)

    def on_epoch_begin(self, epoch, logs=None):
        """Write the current learning rate into `logs` under the key 'lr'."""
        logs = logs or {}
        logs['lr'] = K.get_value(self.model.optimizer.lr)

    def on_epoch_end(self, epoch, logs=None):
        """Record the watched value and update the learning rate if needed.

        # Raises
            ValueError: If the watched key is missing from `logs`.
        """
        logs = logs or {}
        lr = float(K.get_value(self.model.optimizer.lr))
        watch_value = logs.get(self.watch)

        if watch_value is None:
            raise ValueError("Watched value '" + self.watch + "' don't exist")

        # The new value is stored first, so it is part of the mean it is
        # compared against.
        self.history_cache.put(watch_value)

        if watch_value > self.history_cache.mean():
            lr = self.schedule(epoch, lr)
            print("Update learning rate: ", lr)
            K.set_value(self.model.optimizer.lr, lr)


class SingleModelCK(ModelCheckpoint):
    """ModelCheckpoint that always saves the model given at construction time.

    Works around weights saved during multi-GPU training not being usable on
    a single GPU.

    # Arguments
        filepath: Forwarded to `ModelCheckpoint`.
        model: The model whose weights should be checkpointed.
        monitor, verbose, save_best_only, save_weights_only, mode, period:
            Forwarded unchanged to `ModelCheckpoint`.
    """

    def __init__(self, filepath, model, monitor='val_loss', verbose=0,
                 save_best_only=False, save_weights_only=False,
                 mode='auto', period=1):
        super().__init__(filepath=filepath, monitor=monitor, verbose=verbose,
                         save_weights_only=save_weights_only,
                         save_best_only=save_best_only,
                         mode=mode, period=period)
        self.model = model

    def set_model(self, model):
        # Intentionally ignore the model passed in here so that the one
        # supplied to the constructor stays in place.
        pass