# -*- coding: utf-8 -*-
# Created by li huayong on 2019/10/7
import re
import math
import torch
# from model.optimization import *
import pytorch_transformers.optimization as huggingfaceOptim  # aliased to avoid a name clash with torch.optim
from PyToolkit.PyToolkit import get_logger, debug_print


def get_optimizer_old(name, parameters, lr, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
    if name == 'sgd':
        return torch.optim.SGD(parameters, lr=lr)
    elif name == 'adagrad':
        return torch.optim.Adagrad(parameters, lr=lr)
    elif name == 'adam':
        return torch.optim.Adam(parameters, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
    elif name == 'adamax':
        return torch.optim.Adamax(parameters)  # use default lr
    else:
        raise Exception("Unsupported optimizer: {}".format(name))
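

# Illustrative usage sketch of the legacy helper above; the Linear model is a
# hypothetical stand-in, not part of the original training pipeline.
def _example_get_optimizer_old():
    model = torch.nn.Linear(4, 2)
    return get_optimizer_old('adam', model.parameters(), lr=1e-3, weight_decay=0.01)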


def _get_bertology_optimizer_grouped_parameters(args, model):
    # Bias and LayerNorm weights are conventionally excluded from weight decay.
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
         'weight_decay': args.weight_decay},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}
    ]
    return optimizer_grouped_parameters
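

# Illustrative sketch (the function and model below are assumptions, not part of
# the original pipeline): for a bare torch.nn.Linear, 'bias' matches the
# no_decay list and lands in the zero-decay group, while 'weight' keeps
# args.weight_decay.
def _example_grouped_parameters():
    from argparse import Namespace
    args = Namespace(weight_decay=0.01)
    groups = _get_bertology_optimizer_grouped_parameters(args, torch.nn.Linear(4, 2))
    assert groups[0]['weight_decay'] == 0.01 and len(groups[0]['params']) == 1  # 'weight'
    assert groups[1]['weight_decay'] == 0.0 and len(groups[1]['params']) == 1   # 'bias'
    return groups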


def _get_bertology_different_lr_grouped_parameters(args, model):
    no_decay = ['bias', 'LayerNorm.weight']
    params_bert_no_decay = []
    params_bert_decay = []
    params_other_no_decay = []
    params_other_decay = []
    for n, p in model.named_parameters():
        if 'encoder.bertology.' in n:
            # BERTology (pretrained encoder) parameters
            if any(nd in n for nd in no_decay):
                # no decay
                params_bert_no_decay.append(p)
            else:
                params_bert_decay.append(p)
        else:
            # non-BERTology parameters
            if any(nd in n for nd in no_decay):
                # no decay
                params_other_no_decay.append(p)
            else:
                params_other_decay.append(p)
    optimizer_grouped_parameters = [
        {
            'params': params_bert_decay,
            'weight_decay': args.weight_decay,
            'lr': args.bertology_lr,
        },
        {
            'params': params_bert_no_decay,
            'weight_decay': 0.0,
            'lr': args.bertology_lr,
        },
        {
            'params': params_other_decay,
            'weight_decay': args.weight_decay,
            'lr': args.other_lr,
        },
        {
            'params': params_other_no_decay,
            'weight_decay': 0.0,
            'lr': args.other_lr,
        },
    ]
    return optimizer_grouped_parameters
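

# Illustrative sketch (assumed example code, not part of the original pipeline):
# parameters whose name contains 'encoder.bertology.' use args.bertology_lr and
# the rest use args.other_lr; each set is further split into decay / no-decay,
# giving the four groups returned above.
def _example_different_lr_groups():
    from argparse import Namespace

    class _ExampleModel(torch.nn.Module):  # hypothetical stand-in model
        def __init__(self):
            super().__init__()
            self.encoder = torch.nn.Module()
            self.encoder.bertology = torch.nn.Linear(4, 4)  # name matches 'encoder.bertology.'
            self.head = torch.nn.Linear(4, 2)  # treated as a non-BERTology module

    args = Namespace(weight_decay=0.01, bertology_lr=2e-5, other_lr=1e-3)
    groups = _get_bertology_different_lr_grouped_parameters(args, _ExampleModel())
    assert [g['lr'] for g in groups] == [2e-5, 2e-5, 1e-3, 1e-3]
    assert [g['weight_decay'] for g in groups] == [0.01, 0.0, 0.01, 0.0]
    return groups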


def get_optimizer(args, model):
    logger = get_logger(args.log_name)
    # Derive the number of warmup steps from the warmup proportion of the total training steps.
    args.warmup_steps = math.ceil(args.warmup_prop * args.max_train_steps)
    if args.optimizer == 'adamw-bertology':
        if args.different_lr:
            optimizer_grouped_parameters = _get_bertology_different_lr_grouped_parameters(args, model)
        else:
            optimizer_grouped_parameters = _get_bertology_optimizer_grouped_parameters(args, model)
        optimizer = huggingfaceOptim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon,
                                           betas=(args.beta1, args.beta2))
        scheduler = huggingfaceOptim.WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps,
                                                          t_total=args.max_train_steps)
        if args.local_rank in [-1, 0]:
            logger.info('Use Huggingface\'s AdamW Optimizer')
    elif args.optimizer == 'adamw-torch':
        try:
            from torch.optim import AdamW
        except ImportError as e:
            debug_print(f'torch version: {torch.__version__}')
            raise e
        if args.different_lr:
            optimizer_grouped_parameters = _get_bertology_different_lr_grouped_parameters(args, model)
        else:
            optimizer_grouped_parameters = _get_bertology_optimizer_grouped_parameters(args, model)
        optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon,
                          betas=(args.beta1, args.beta2))
        scheduler = huggingfaceOptim.WarmupLinearSchedule(optimizer,
                                                          warmup_steps=args.warmup_steps,
                                                          t_total=args.max_train_steps)
    elif args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate)
        scheduler = None
    elif args.optimizer == 'adagrad':
        optimizer = torch.optim.Adagrad(model.parameters(), lr=args.learning_rate)
        scheduler = None
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate, betas=(args.beta1, args.beta2),
                                     eps=args.adam_epsilon, weight_decay=args.weight_decay)
        scheduler = None
    elif args.optimizer == 'adamax':
        optimizer = torch.optim.Adamax(model.parameters())  # use default lr
        scheduler = None
    else:
        raise Exception("Unsupported optimizer: {}".format(args.optimizer))
    return optimizer, scheduler
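

# Illustrative sketch of calling get_optimizer (assumed example code, not part
# of the original pipeline): the Namespace below only carries the fields read by
# the 'adamw-torch' branch, and PyToolkit's get_logger is assumed to accept the
# log name string, as it does at the top of get_optimizer.
def _example_get_optimizer():
    from argparse import Namespace
    args = Namespace(
        log_name='example',
        optimizer='adamw-torch',
        different_lr=False,
        learning_rate=1e-3,
        adam_epsilon=1e-8,
        beta1=0.9,
        beta2=0.999,
        weight_decay=0.01,
        warmup_prop=0.1,
        max_train_steps=1000,
    )
    optimizer, scheduler = get_optimizer(args, torch.nn.Linear(4, 2))
    return optimizer, scheduler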


if __name__ == '__main__':
    pass