import torch
import numpy as np
from functools import partial


class Optimizer():
    def __init__(self, parameters, optimizer, lr, eps, lr_scheduler,
                 tf_start=1, tf_end=1, tf_step=1, **kwargs):
        # Setup teacher forcing scheduler (linear decay from tf_start to tf_end over tf_step steps)
        self.tf_type = tf_end != 1
        self.tf_rate = lambda step: max(
            tf_end, tf_start-(tf_start-tf_end)*step/tf_step)

        # Setup torch optimizer
        self.opt_type = optimizer
        self.init_lr = lr
        self.sch_type = lr_scheduler
        opt = getattr(torch.optim, optimizer)
        if lr_scheduler == 'warmup':
            # Noam-style warmup: LR rises for the first warmup_step steps, then decays as step^-0.5
            warmup_step = 4000.0
            init_lr = lr
            self.lr_scheduler = lambda step: init_lr * warmup_step ** 0.5 * \
                np.minimum((step+1)*warmup_step**-1.5, (step+1)**-0.5)
            self.opt = opt(parameters, lr=1.0)
        elif lr_scheduler == 'spec-aug-basic':
            # Basic schedule from https://arxiv.org/pdf/1904.08779.pdf
            self.lr_scheduler = partial(speech_aug_scheduler,
                                        s_r=500, s_i=20000, s_f=80000, peak_lr=lr)
            self.opt = opt(parameters, lr=lr, eps=eps)
        elif lr_scheduler == 'spec-aug-double':
            # Double schedule from https://arxiv.org/pdf/1904.08779.pdf
            self.lr_scheduler = partial(speech_aug_scheduler,
                                        s_r=1000, s_i=40000, s_f=160000, peak_lr=lr)
            self.opt = opt(parameters, lr=lr, eps=eps)
        else:
            self.lr_scheduler = None
            self.opt = opt(parameters, lr=lr, eps=eps)  # ToDo: is eps=1e-8 better?

    def get_opt_state_dict(self):
        return self.opt.state_dict()

    def load_opt_state_dict(self, state_dict):
        self.opt.load_state_dict(state_dict)

    def pre_step(self, step):
        # Apply the scheduled LR, clear gradients, and return the current teacher forcing rate
        if self.lr_scheduler is not None:
            cur_lr = self.lr_scheduler(step)
            for param_group in self.opt.param_groups:
                param_group['lr'] = cur_lr
        self.opt.zero_grad()
        return self.tf_rate(step)

    def step(self):
        self.opt.step()

    def create_msg(self):
        return ['Optim.spec.| Algo. = {}\t| Lr = {}\t (Scheduler = {})| Scheduled sampling = {}'
                .format(self.opt_type, self.init_lr, self.sch_type, self.tf_type)]


def speech_aug_scheduler(step, s_r, s_i, s_f, peak_lr):
    # Starting from 0, ramp up to the peak LR, hold, then decay exponentially to 0.01*peak LR
    final_lr_ratio = 0.01
    exp_decay_lambda = -np.log10(final_lr_ratio)/(s_f-s_i)  # base-10 decay constant
    cur_step = step+1

    if cur_step < s_r:
        # Ramp-up
        return peak_lr*float(cur_step)/s_r
    elif cur_step < s_i:
        # Hold
        return peak_lr
    elif cur_step <= s_f:
        # Decay
        return peak_lr*np.power(10, -exp_decay_lambda*(cur_step-s_i))
    else:
        # Converge
        return peak_lr*final_lr_ratio
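

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module):
# how the Optimizer wrapper might be driven inside a training loop. The toy
# model, dummy loss, and hyperparameter values below are hypothetical.
if __name__ == '__main__':
    model = torch.nn.Linear(10, 2)  # stand-in for a real ASR model
    optim = Optimizer(model.parameters(), optimizer='Adadelta', lr=1.0,
                      eps=1e-8, lr_scheduler='warmup')
    print(optim.create_msg()[0])

    for global_step in range(3):
        # pre_step() applies the scheduled LR, zeroes gradients, and returns
        # the teacher forcing rate (which a seq2seq decoder would consume).
        tf_rate = optim.pre_step(global_step)
        loss = model(torch.randn(4, 10)).pow(2).mean()  # dummy loss
        loss.backward()
        optim.step()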