# @File  : piston_train_arcface.py
# @Author: X.Yang
# @Contact : pistonyang@gmail.com
# @Date  : 18-10-31

import argparse
import mxnet as mx
import sklearn.preprocessing  # explicit submodule import needed for sklearn.preprocessing.normalize in validate()
from mxnet import gluon, autograd
import os
import time
from mxnet.gluon.data import DataLoader
from mxnet.gluon.data.vision import transforms
import logging
from gluonfr.loss import ArcLoss, L2Softmax, get_loss
from gluonfr.model_zoo import get_model
from gluonfr.data import get_recognition_dataset
from gluonfr.metrics.verification import FaceVerification
from gluonfr.utils.lr_scheduler import IterLRScheduler

parser = argparse.ArgumentParser(description='Train a margin based model for face recognition.')
parser.add_argument('--dataset', type=str, default='emore',
                    help='Training dataset, one of emore, vgg, webface. Default is emore.')
parser.add_argument('--batch-size', type=int, default=512,
                    help='Training batch size.')
parser.add_argument('--ctx', type=str, default="0, 1, 2, 3",
                    help='GPU ids to train on, e.g. "0,1,2,3".')
parser.add_argument('--dtype', type=str, default='float32',
                    help='Data type for training. Default is float32.')
parser.add_argument('-j', '--num-data-workers', dest='num_workers', default=4, type=int,
                    help='Number of preprocessing workers.')
parser.add_argument('--lr', type=float, default=0.1,
                    help='Initial learning rate. Default is 0.1.')
parser.add_argument('--lr-mode', type=str, default='step',
                    help='Learning rate decay mode (lr scheduler type). Default is step.')
parser.add_argument('--lr-decay', type=float, default=0.1,
                    help='Decay rate of the learning rate. Default is 0.1.')
parser.add_argument('--lr-decay-iter', type=str, default="60000, 120000",
                    help='Iterations at which the learning rate decays. Default is "60000, 120000".')
parser.add_argument('--lr-warmup-iters', type=int, default=0,
                    help='Number of learning rate warm-up iterations. Default 0 disables warm-up.')
parser.add_argument('--wd', type=float, default=5e-4,
                    help='Weight decay of network, default is 0.0005 for l_se_resnet.')
parser.add_argument('--no-wd', action='store_true',
                    help='Whether to remove weight decay on bias and beta/gamma of batchnorm layers.')
parser.add_argument('-s', '--margin-s', type=float, default=64, help='Feature scale s of the margin loss.')
parser.add_argument('-m', '--margin-m', type=float, default=0.5, help='Angular margin m of the margin loss.')
parser.add_argument('-t', '--val-dataset', dest='target', type=str, default='lfw',
                    help='Validation datasets, default is lfw.')
parser.add_argument('-n', '--model', type=str, default='l_se_resnet50v2',
                    help='Network to train.')
parser.add_argument('--logging-file', type=str, default='',
                    help='Log file name. Generated from the dataset, model and loss names if not specified.')
parser.add_argument('--niters', type=int, default=int(180e3),
                    help='Total training iterations. Set 0 to enable auto train mode.')
parser.add_argument('--loss', type=str, default='arcface',
                    help='Which loss to use for training.')
parser.add_argument('--loss-warmup-iters', type=int, default=0,
                    help='Number of loss warm-up iterations. Default 0 disables warm-up.')
parser.add_argument('--cat-interval', type=int, default=int(1e3),
                    help='Logging/checkpoint interval in iterations. Default is 1000.')
parser.add_argument('--save-dir', type=str, default='params',
                    help='Directory for saved models.')
parser.add_argument('--hybrid', action='store_true',
                    help='Whether to use hybrid.')
parser.add_argument('--auto-epochs', dest='epochs', type=int, default=25,
                    help='Auto train mode epochs.')
opt = parser.parse_args()
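
# Example invocation (illustrative only; GPU ids, dataset and flags depend on your setup):
#   python piston_train_arcface.py --dataset emore --model l_se_resnet50v2 \
#       --batch-size 512 --ctx "0,1,2,3" --loss arcface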

assert opt.batch_size % len(opt.ctx.split(",")) == 0, "Batch size must be divisible by the number of GPUs."
assert opt.dtype in ('float32', 'float16'), "Only float32 and float16 are supported."
os.makedirs(opt.save_dir, exist_ok=True)

logging_file = opt.logging_file
if opt.logging_file == '':
    logging_file = '%s_%s_%s.log' % (opt.dataset, opt.model.replace('_', ''), opt.loss)

filehandler = logging.FileHandler(logging_file)
streamhandler = logging.StreamHandler()

logger = logging.getLogger('')
logger.setLevel(logging.INFO)
logger.addHandler(filehandler)
logger.addHandler(streamhandler)

logger.info(opt)
if opt.dataset == 'emore' and opt.batch_size < 512:
    logger.info("Warning: If you train a model on emore with batch size < 512 may lead to not converge."
                "You may try a smaller dataset.")

transform_test = transforms.Compose([
    transforms.ToTensor()
])

_transform_train = transforms.Compose([
    transforms.RandomBrightness(0.3),
    transforms.RandomContrast(0.3),
    transforms.RandomSaturation(0.3),
    transforms.RandomFlipLeftRight(),
    transforms.ToTensor()
])
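
# Note: the test transform only converts images to tensors, while the train transform adds mild
# colour jitter and a random horizontal flip before ToTensor; both are assumed to receive
# already-aligned face crops.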


def transform_train(data, label):
    """Apply the training augmentation to the image and pass the label through unchanged."""
    im = _transform_train(data)
    return im, label


def inf_train_gen(loader):
    """Yield batches from the DataLoader indefinitely so the training loop can be iteration-driven."""
    while True:
        for batch in loader:
            yield batch


def auto_train_setting(dataset, epochs=25, lr_rate=0.05, loss_rate=0.35):
    """Derive iteration counts from the dataset size: train for `epochs` + 1 epochs at the global
    batch size, with lr/loss warm-up set to `lr_rate`/`loss_rate` of the total iterations."""
    epochs = epochs + 1
    num_train_samples = len(dataset)
    num_iterations = round(int(num_train_samples // batch_size) * epochs, -3)
    lr_warmup_iters = round(int(num_iterations * lr_rate), -2)
    loss_warmup_iters = round(int(num_iterations * loss_rate), -2)

    logger.info('Enable Auto train mode. Following params have been reset. '
                'num_iterations={}, lr_warmup_iters={}, loss_warmup_iters={}.'.format(
        num_iterations, lr_warmup_iters, loss_warmup_iters))
    return num_iterations, lr_warmup_iters, loss_warmup_iters
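
# Rough worked example of the auto settings (assuming ~5.8M training samples, a size commonly
# quoted for emore, and the default batch size of 512):
#   ~5.8e6 // 512 ~= 11.3k iters/epoch, so 26 "epochs" ~= 295k iterations,
#   lr warm-up ~= 5% ~= 14.8k iters, loss warm-up ~= 35% ~= 103k iters.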


ctx = [mx.gpu(int(i)) for i in opt.ctx.split(",")]

batch_size = opt.batch_size
num_iterations = opt.niters
lr_warmup_iters = opt.lr_warmup_iters
loss_warmup_iters = opt.loss_warmup_iters

margin_s = opt.margin_s
margin_m = opt.margin_m

train_set = get_recognition_dataset(opt.dataset, transform=transform_train)
train_data = DataLoader(train_set, batch_size, shuffle=True, num_workers=opt.num_workers, last_batch='discard')
batch_generator = inf_train_gen(train_data)

if num_iterations == 0:
    # Auto setting. A large batch size (512 or larger) is recommended to enable this.
    # Defaults: loss warm-up 35% of iterations, lr warm-up 5% (see auto_train_setting above).
    num_iterations, lr_warmup_iters, loss_warmup_iters = auto_train_setting(train_set,
                                                                            epochs=opt.epochs)

targets = opt.target
val_sets = [get_recognition_dataset(name, transform=transform_test) for name in targets.split(",")]
val_datas = [DataLoader(dataset, batch_size, last_batch='keep') for dataset in val_sets]

dtype = opt.dtype
train_net = get_model(opt.model, classes=train_set.num_classes, weight_norm=True, feature_norm=True)
train_net.initialize(init=mx.init.MSRAPrelu(), ctx=ctx)

lr_period = [int(i) for i in opt.lr_decay_iter.split(",")]
lr_scheduler = IterLRScheduler(mode=opt.lr_mode, baselr=opt.lr, step=lr_period,
                               step_factor=opt.lr_decay, power=2,
                               niters=num_iterations, warmup_iters=lr_warmup_iters)
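# With the defaults this is a step schedule: the lr starts at opt.lr and is multiplied by
# opt.lr_decay at iterations 60k and 120k over the 180k total iterations, optionally preceded by
# a warm-up phase of lr_warmup_iters iterations (the warm-up curve is defined by IterLRScheduler).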
optimizer = 'nag'
optimizer_params = {'wd': opt.wd, 'momentum': 0.9, 'lr_scheduler': lr_scheduler}
if opt.dtype != 'float32':
    train_net.cast(dtype)
    optimizer_params['multi_precision'] = True
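# For float16 training the network weights are cast to FP16, while 'multi_precision' tells the
# optimizer to keep an FP32 master copy of the weights for the update step (the usual recipe for
# stable mixed-precision training in MXNet).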

# TODO(PistonYang): We will support more losses as we train them. For now only ArcFace supports FP16.
Loss = None
AFL = ArcLoss(train_set.num_classes, margin_m, margin_s, easy_margin=False, dtype=dtype)
SML = L2Softmax(train_set.num_classes, alpha=margin_s, from_normx=True)
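# Loss warm-up (used in the training loop below): for the first loss_warmup_iters iterations the
# model is trained with L2Softmax (norm-softmax) and only afterwards with ArcLoss; training with
# the angular margin from scratch tends to converge poorly, so the warm-up builds a reasonable
# embedding first.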


def train():
    train_net.collect_params().reset_ctx(ctx)
    trainer = gluon.Trainer(train_net.collect_params(), optimizer, optimizer_params)

    metric = mx.metric.Accuracy()
    train_loss = mx.metric.Loss()

    metric.reset()
    train_loss.reset()
    sample_time = time.time()
    for iteration in range(1, int(num_iterations + 1)):
        Loss = SML if iteration < loss_warmup_iters else AFL
        batch = next(batch_generator)
        trans = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0, even_split=False)
        labels = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0, even_split=False)

        with autograd.record():
            # The network returns (embedding, classifier output); index 1 is fed to the margin loss.
            outputs = [train_net(X.astype(dtype, copy=False))[1] for X in trans]
            losses = [Loss(yhat, y.astype(dtype, copy=False)) for yhat, y in zip(outputs, labels)]
        for loss in losses:
            loss.backward()
        trainer.step(batch_size)

        train_loss.update(0, losses)
        metric.update(labels, outputs)
        if iteration % opt.cat_interval == 0:
            throughput = (opt.cat_interval * batch_size) // (time.time() - sample_time)
            _, train_acc = metric.get()
            _, epoch_loss = train_loss.get()
            metric.reset()
            train_loss.reset()
            iter_str = ("Iter %d. Loss: %.5f, Train acc %f, %d samples/s. "
                        % (iteration, epoch_loss, train_acc, throughput))
            logger.info(iter_str + 'lr ' + str(trainer.learning_rate))
            train_net.save_parameters("%s/%s-it-%d.params" % (opt.save_dir, opt.model, iteration))
            trainer.save_states("%s/%s-it-%d.states" % (opt.save_dir, opt.model, iteration))
            results = validate()
            for result in results:
                logger.info('{}'.format(result))
            sample_time = time.time()


def validate(nfolds=10, norm=True):
    """Run face verification on each validation set and return one accuracy summary string per dataset."""
    metric = FaceVerification(nfolds)
    results = []
    for loader, name in zip(val_datas, targets.split(",")):
        metric.reset()
        for i, batch in enumerate(loader):
            data0s = gluon.utils.split_and_load(batch[0][0], ctx, even_split=False)
            data1s = gluon.utils.split_and_load(batch[0][1], ctx, even_split=False)
            issame_list = gluon.utils.split_and_load(batch[1], ctx, even_split=False)

            embedding0s = [train_net(X.astype(dtype, copy=False))[0] for X in data0s]
            embedding1s = [train_net(X.astype(dtype, copy=False))[0] for X in data1s]
            if norm:
                embedding0s = [sklearn.preprocessing.normalize(e.asnumpy()) for e in embedding0s]
                embedding1s = [sklearn.preprocessing.normalize(e.asnumpy()) for e in embedding1s]

            for embedding0, embedding1, issame in zip(embedding0s, embedding1s, issame_list):
                metric.update(issame, embedding0, embedding1)

        tpr, fpr, accuracy, val, val_std, far, accuracy_std = metric.get()
        results.append("{}: {:.6f}+-{:.6f}".format(name, accuracy, accuracy_std))
    return results


if __name__ == '__main__':
    if opt.hybrid:
        train_net.hybridize()
    if opt.model == 'mobilefacenet':
        for k, v in train_net.collect_params('.*output_').items():
            v.wd_mult = 10.0
    if opt.no_wd:
        for k, v in train_net.collect_params('.*beta|.*gamma|.*bias').items():
            v.wd_mult = 0.0
    train()