from collections import deque

from torch import optim

from utils import ceildiv


def create_optimizer(parameters, opt):
    """Build the optimizer named by ``opt.optim_method`` for ``parameters``.

    When ``opt.learning_rate`` is falsy (e.g. 0), each method falls back to its
    default learning rate: sgd - 0.5, adagrad - 0.01, adadelta - 1, adam - 0.001,
    adamax - 0.002, asgd - 0.01, rmsprop - 0.01, rprop - 0.01.
    """
    lr = opt.learning_rate
    optim_method = opt.optim_method.casefold()
    if optim_method == 'sgd':
        optimizer = optim.SGD(parameters, lr=lr if lr else 0.5, weight_decay=opt.weight_decay)
    elif optim_method == 'adagrad':
        optimizer = optim.Adagrad(parameters, lr=lr if lr else 0.01, weight_decay=opt.weight_decay)
    elif optim_method == 'adadelta':
        optimizer = optim.Adadelta(parameters, lr=lr if lr else 1, weight_decay=opt.weight_decay)
    elif optim_method == 'adam':
        optimizer = optim.Adam(parameters, lr=lr if lr else 0.001, weight_decay=opt.weight_decay)
    elif optim_method == 'adamax':
        optimizer = optim.Adamax(parameters, lr=lr if lr else 0.002, weight_decay=opt.weight_decay)
    elif optim_method == 'asgd':
        # t0=5000: iteration at which ASGD starts averaging parameters
        optimizer = optim.ASGD(parameters, lr=lr if lr else 0.01, t0=5000, weight_decay=opt.weight_decay)
    elif optim_method == 'rmsprop':
        optimizer = optim.RMSprop(parameters, lr=lr if lr else 0.01, weight_decay=opt.weight_decay)
    elif optim_method == 'rprop':
        optimizer = optim.Rprop(parameters, lr=lr if lr else 0.01)
    else:
        raise ValueError(f"Invalid optim method: {opt.optim_method}")
    return optimizer


def get_learning_rate(optimizer):
    """Return the learning rate of the first parameter group that defines one."""
    for p in optimizer.param_groups:
        if 'lr' in p:
            return p['lr']


def setup_lr(optimizer, full_log, opt):
    """Attach a ReduceLROnPlateau scheduler to ``optimizer``.

    Returns the scheduler together with a closure that smooths the monitored
    quantity from ``full_log`` over recent evaluations before stepping it.
    """
    # patience/cooldown are given in training iterations; the scheduler is stepped
    # once per evaluation, so convert them to evaluation counts
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, verbose=True, factor=opt.lr_reduce_factor, min_lr=opt.lr_min_value,
        threshold=opt.lr_quantity_epsilon, threshold_mode='rel', mode=opt.lr_quantity_mode,
        patience=ceildiv(opt.lr_patience, opt.eval_iter), cooldown=ceildiv(opt.lr_cooldown, opt.eval_iter))

    # keep a bounded buffer (closed over by anneal_lr_func) that smooths the
    # monitored quantity over the last few evaluations
    averaging_buffer_max_length = max(1, ceildiv(opt.lr_quantity_smoothness, opt.eval_iter))
    averaging_buffer = deque(maxlen=averaging_buffer_max_length)

    def anneal_lr_func(anneal_now=True):
        value_to_monitor = full_log[opt.lr_quantity_to_monitor][-1]
        averaging_buffer.append(value_to_monitor)
        averaged_value = sum(averaging_buffer) / float(len(averaging_buffer))
        # the number of recorded evaluations doubles as the scheduler "epoch"
        counter = len(full_log[opt.lr_quantity_to_monitor])
        if opt.anneal_learning_rate and anneal_now:
            lr_scheduler.step(averaged_value, counter)
        return get_learning_rate(optimizer)

    return lr_scheduler, anneal_lr_func
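

# Hedged usage sketch: shows how create_optimizer and setup_lr might be wired
# together. The concrete option values and the 'val_loss' key are assumptions made
# for this example, not part of the module; only the field names read above are real.
if __name__ == '__main__':
    from types import SimpleNamespace

    import torch

    model = torch.nn.Linear(4, 2)
    opt_example = SimpleNamespace(
        learning_rate=0, optim_method='adam', weight_decay=0.0,
        anneal_learning_rate=True, lr_reduce_factor=0.5, lr_min_value=1e-6,
        lr_quantity_epsilon=1e-4, lr_quantity_mode='min', lr_quantity_to_monitor='val_loss',
        lr_quantity_smoothness=3, lr_patience=10, lr_cooldown=2, eval_iter=1)
    full_log = {'val_loss': []}

    optimizer = create_optimizer(model.parameters(), opt_example)
    lr_scheduler, anneal_lr_func = setup_lr(optimizer, full_log, opt_example)

    # pretend validation losses arrive once per evaluation; anneal_lr_func smooths
    # them, steps the scheduler and reports the (possibly reduced) learning rate
    for fake_loss in [1.0, 0.9, 0.91, 0.92, 0.93]:
        full_log['val_loss'].append(fake_loss)
        print(anneal_lr_func())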