import multiprocessing as mp
if mp.get_start_method(allow_none=True) != 'spawn':
    mp.set_start_method('spawn', force=True)

import argparse
import os
import time
import logging
from datetime import datetime
import numpy as np
import yaml

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from tensorboardX import SummaryWriter
torch.multiprocessing.set_sharing_strategy('file_system')

import models
from datasets import GivenSizeSampler, BinDataset, FileListLabeledDataset, FileListDataset
from utils import AverageMeter, load_state, save_state, log, normalize, bin_loader
from evaluation import evaluate, test_megaface

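# enumerate the lowercase, callable entries of models.backbones as the set of valid backbone names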
model_names = sorted(name for name in models.backbones.__dict__
    if name.islower() and not name.startswith("__")
    and callable(models.backbones.__dict__[name]))

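# minimal namespace object; nested dicts from the YAML config are attached to args as ArgObj attributes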
class ArgObj(object):
    def __init__(self):
        pass

parser = argparse.ArgumentParser(description='Multi-Task Face Recognition Training')
parser.add_argument('--config', type=str, required=True)
parser.add_argument('--load-path', default='', type=str)
parser.add_argument('--resume', action='store_true')
parser.add_argument('--evaluate', action='store_true')
parser.add_argument('--extract', action='store_true')

def main():

    ## config
    global args
    args = parser.parse_args()

    with open(args.config) as f:
        config = yaml.safe_load(f)

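    # flatten the YAML config onto args: every top-level dict becomes an ArgObj namespace
    # (e.g. args.train, args.model); other values are attached directly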
    for k,v in config.items():    
        if isinstance(v, dict):
            argobj = ArgObj()
            setattr(args, k, argobj)
            for kk,vv in v.items():
                setattr(argobj, kk, vv)
        else:
            setattr(args, k, v)
    args.ngpu = len(args.gpus.split(','))

    ## asserts
    assert args.model.backbone in model_names, "available backbone names: {}".format(model_names)
    num_tasks = len(args.train.data_root)
    assert(num_tasks == len(args.train.loss_weight))
    assert(num_tasks == len(args.train.batch_size))
    assert(num_tasks == len(args.train.data_list))
    #assert(num_tasks == len(args.train.data_meta))
    if args.val.flag:
        assert(num_tasks == len(args.val.batch_size))
        assert(num_tasks == len(args.val.data_root))
        assert(num_tasks == len(args.val.data_list))
        #assert(num_tasks == len(args.val.data_meta))

    ## mkdir
    if not hasattr(args, 'save_path'):
        args.save_path = os.path.dirname(args.config)
    if not os.path.isdir('{}/checkpoints'.format(args.save_path)):
        os.makedirs('{}/checkpoints'.format(args.save_path))
    if not os.path.isdir('{}/logs'.format(args.save_path)):
        os.makedirs('{}/logs'.format(args.save_path))
    if not os.path.isdir('{}/events'.format(args.save_path)):
        os.makedirs('{}/events'.format(args.save_path))

    ## create dataset
    if not (args.extract or args.evaluate): # train + val
        for i in range(num_tasks):
            args.train.batch_size[i] *= args.ngpu

        train_dataset = [FileListLabeledDataset(
            args.train.data_list[i], args.train.data_root[i],
            transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.Resize(args.model.input_size),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),]),
            memcached=args.memcached,
            memcached_client=args.memcached_client) for i in range(num_tasks)]
        args.num_classes = [td.num_class for td in train_dataset]
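        # pad every task's sampler to the longest per-task epoch so all loaders yield the
        # same number of batches and can be zipped together in train()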
        train_longest_size = max([int(np.ceil(len(td) / float(bs))) for td, bs in zip(train_dataset, args.train.batch_size)])
        train_sampler = [GivenSizeSampler(td, total_size=train_longest_size * bs, rand_seed=args.train.rand_seed) for td, bs in zip(train_dataset, args.train.batch_size)]
        train_loader = [DataLoader(
            train_dataset[k], batch_size=args.train.batch_size[k], shuffle=False,
            num_workers=args.workers, pin_memory=False, sampler=train_sampler[k]) for k in range(num_tasks)]
        assert(all([len(train_loader[k]) == len(train_loader[0]) for k in range(num_tasks)]))

        if args.val.flag:
            for i in range(num_tasks):
                args.val.batch_size[i] *= args.ngpu
    
            val_dataset = [FileListLabeledDataset(
                args.val.data_list[i], args.val.data_root[i],
                transforms.Compose([
                    transforms.Resize(args.model.input_size),
                    transforms.ToTensor(),
                    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),]),
                memcached=args.memcached,
                memcached_client=args.memcached_client) for i in range(num_tasks)]
            
            val_longest_size = max([int(np.ceil(len(vd) / float(bs))) for vd, bs in zip(val_dataset, args.val.batch_size)])
            val_sampler = [GivenSizeSampler(vd, total_size=val_longest_size * bs, sequential=True) for vd, bs in zip(val_dataset, args.val.batch_size)]
            val_loader = [DataLoader(
                val_dataset[k], batch_size=args.val.batch_size[k], shuffle=False,
                num_workers=args.workers, pin_memory=False, sampler=val_sampler[k]) for k in range(num_tasks)]
            assert(all([len(val_loader[k]) == len(val_loader[0]) for k in range(num_tasks)]))

    if args.test.flag or args.evaluate: # online or offline evaluate
        args.test.batch_size *= args.ngpu
        test_dataset = []
        for tb in args.test.benchmark:
            if tb == 'megaface':
                test_dataset.append(FileListDataset(args.test.megaface_list,
                    args.test.megaface_root, transforms.Compose([
                    transforms.Resize(args.model.input_size),
                    transforms.ToTensor(),
                    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),])))
            else:
                test_dataset.append(BinDataset("{}/{}.bin".format(args.test.test_root, tb),
                    transforms.Compose([
                    transforms.Resize(args.model.input_size),
                    transforms.ToTensor(),
                    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
                    ])))
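        # round each test set up to a full multiple of the batch size; extract() later
        # truncates the features back to the true dataset length via 'num'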
        test_sampler = [GivenSizeSampler(td,
            total_size=int(np.ceil(len(td) / float(args.test.batch_size)) * args.test.batch_size),
            sequential=True, silent=True) for td in test_dataset]
        test_loader = [DataLoader(
            td, batch_size=args.test.batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=False, sampler=ts)
            for td, ts in zip(test_dataset, test_sampler)]

    if args.extract: # feature extraction
        args.extract_info.batch_size *= args.ngpu
        extract_dataset = FileListDataset(
            args.extract_info.data_list, args.extract_info.data_root,
            transforms.Compose([
                transforms.Resize(args.model.input_size),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),]),
            memcached=args.memcached,
            memcached_client=args.memcached_client)
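        # same padding trick as the test loaders: round up to a full batch, trim in extract()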
        extract_sampler = GivenSizeSampler(
            extract_dataset, total_size=int(np.ceil(len(extract_dataset) / float(args.extract_info.batch_size)) * args.extract_info.batch_size), sequential=True)
        extract_loader = DataLoader(
            extract_dataset, batch_size=args.extract_info.batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=False, sampler=extract_sampler)


    ## create model
    log("Creating model on [{}] gpus: {}".format(args.ngpu, args.gpus))
    if args.evaluate or args.extract:
        args.num_classes = None
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus  # must be set before CUDA is initialized
    model = models.MultiTaskWithLoss(
        backbone=args.model.backbone, num_classes=args.num_classes,
        feature_dim=args.model.feature_dim, spatial_size=args.model.input_size,
        arc_fc=args.model.arc_fc, feat_bn=args.model.feat_bn)
    model = nn.DataParallel(model)
    model.cuda()
    cudnn.benchmark = True

    ## optimizer (the loss is computed inside MultiTaskWithLoss)
    optimizer = torch.optim.SGD(model.parameters(), args.train.base_lr,
                                momentum=args.train.momentum,
                                weight_decay=args.train.weight_decay)

    ## resume / load model
    start_epoch = 0
    count = [0]  # global iteration counter, kept in a one-element list so train() can update it in place
    if args.load_path:
        assert os.path.isfile(args.load_path), "File does not exist: {}".format(args.load_path)
        if args.resume:
            checkpoint = load_state(args.load_path, model, optimizer)
            start_epoch = checkpoint['epoch']
            count[0] = checkpoint['count']
        else:
            load_state(args.load_path, model)

    ## offline evaluate
    if args.evaluate:
        for tb, tl, td in zip(args.test.benchmark, test_loader, test_dataset):
            evaluation(tl, model, num=len(td),
                       outfeat_fn="{}_{}.bin".format(args.load_path[:-8], tb),
                       benchmark=tb)
        return

    ## feature extraction
    if args.extract:
        extract(extract_loader, model, num=len(extract_dataset), output_file="{}_{}.bin".format(args.load_path[:-8], args.extract_info.data_name))
        return

    ######################## train #################
    ## lr scheduler
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.train.lr_decay_steps, gamma=args.train.lr_decay_scale, last_epoch=start_epoch-1)

    ## logger
    logging.basicConfig(
        filename=os.path.join('{}/logs'.format(args.save_path),
                              datetime.today().strftime('log-%Y-%m-%d_%H:%M:%S.txt')),
        level=logging.INFO)
    tb_logger = SummaryWriter('{}/events'.format(args.save_path))

    ## initial validate
    if args.val.flag:
        validate(val_loader, model, start_epoch, args.train.loss_weight, len(train_loader[0]), tb_logger)

    ## initial evaluate
    if args.test.flag and args.test.initial_test:
        log("*************** evaluation epoch [{}] ***************".format(start_epoch))
        for tb, tl, td in zip(args.test.benchmark, test_loader, test_dataset):
            res = evaluation(tl, model, num=len(td),
                             outfeat_fn="{}/checkpoints/ckpt_epoch_{}_{}.bin".format(
                             args.save_path, start_epoch, tb),
                             benchmark=tb)
            tb_logger.add_scalar(tb, res, start_epoch)

    ## training loop
    for epoch in range(start_epoch, args.train.max_epoch):
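        # note: stepping the scheduler at the start of the epoch is the pre-1.1 PyTorch
        # convention; newer versions warn about calling it before optimizer.step(),
        # but the schedule itself is unaffected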
        lr_scheduler.step()
        for ts in train_sampler:
            ts.set_epoch(epoch)
        # train for one epoch
        train(train_loader, model, optimizer, epoch, args.train.loss_weight, tb_logger, count)
        # save checkpoint
        save_state({
            'epoch': epoch + 1,
            'arch': args.model.backbone,
            'state_dict': model.state_dict(),
            'optimizer' : optimizer.state_dict(),
            'count': count[0]
        }, args.save_path + "/checkpoints/ckpt_epoch", epoch + 1, is_last=(epoch + 1 == args.train.max_epoch))

        # validate
        if args.val.flag:
            validate(val_loader, model, epoch, args.train.loss_weight, len(train_loader[0]), tb_logger, count)
        # online evaluate
        if args.test.flag and ((epoch + 1) % args.test.interval == 0 or epoch + 1 == args.train.max_epoch):
            log("*************** evaluation epoch [{}] ***************".format(epoch + 1))
            for tb, tl, td in zip(args.test.benchmark, test_loader, test_dataset):
                res = evaluation(tl, model, num=len(td),
                                 outfeat_fn="{}/checkpoints/ckpt_epoch_{}_{}.bin".format(
                                 args.save_path, epoch + 1, tb),
                                 benchmark=tb)
                tb_logger.add_scalar(tb, res, epoch + 1)


def train(train_loader, model, optimizer, epoch, loss_weight, tb_logger, count):
    num_tasks = len(train_loader)
    batch_time = AverageMeter(args.train.average_stats)
    data_time = AverageMeter(args.train.average_stats)
    losses = [AverageMeter(args.train.average_stats) for k in range(num_tasks)]

    # switch to train mode
    model.train()

    end = time.time()
    for i, all_in in enumerate(zip(*tuple(train_loader))):
        input, target = zip(*[all_in[k] for k in range(num_tasks)])
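        # per-GPU slice boundaries: every GPU gets batch_size[t] // ngpu samples from each task,
        # and the model uses slice_idx to split its chunk back into the individual tasks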
        slice_pt = 0
        slice_idx = [0]
        for l in [p.size(0) for p in input]:
            slice_pt += l // args.ngpu
            slice_idx.append(slice_pt)

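        # interleave the task batches GPU-wise so that DataParallel's scatter hands each GPU
        # one contiguous [task_0 | task_1 | ...] block that matches slice_idx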
        organized_input = []
        organized_target = []
        for ng in range(args.ngpu):
            for t in range(len(input)):
                bs = args.train.batch_size[t] // args.ngpu
                organized_input.append(input[t][ng * bs : (ng + 1) * bs, ...])
                organized_target.append(target[t][ng * bs : (ng + 1) * bs, ...])
        input = torch.cat(organized_input, dim=0)
        target = torch.cat(organized_target, dim=0)

        # measure data loading time
        data_time.update(time.time() - end)

        input_var = input.cuda()
        target_var = target.cuda()

        # measure accuracy and record loss
        loss = model(input_var, target_var, slice_idx)

        for k in range(num_tasks):
            if torch.__version__ >= '0.4.0':  # tensor.item() is available from 0.4
                losses[k].update(loss[k].mean().item())
            else:
                losses[k].update(loss[k].mean().data[0])

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss_total = 0.
        for k in range(num_tasks):
            loss_total = loss_total + loss[k].mean() * loss_weight[k]
        loss_total.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # info
        if i % args.train.print_freq == 0:
            log('Progress: [{0}][{1}/{2}][{3}]    '
                  'Lr: {4:.2g}    '
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})    '
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})'.format(
                   epoch, i, len(train_loader[0]), count[0], 
                   optimizer.param_groups[0]['lr'],
                   batch_time=batch_time,
                   data_time=data_time))
            for k in range(num_tasks):
                log('Task: #{0}\t'
                      'LW: {1:.2g}\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})'.format(
                       k, loss_weight[k], loss=losses[k]))

        # tensorboard logger
        for k in range(num_tasks):
            tb_logger.add_scalar('train_loss_{}'.format(k), losses[k].val, count[0])
        tb_logger.add_scalar('lr', optimizer.param_groups[0]['lr'], count[0])

        count[0] += 1

def validate(val_loader, model, epoch, loss_weight, train_len, tb_logger, count=None):
    raise NotImplementedError  # validation is currently disabled; the body below is kept for reference
    num_tasks = len(val_loader)
    losses = [AverageMeter(args.val.average_stats) for k in range(num_tasks)]

    # switch to evaluate mode
    model.eval()

    start = time.time()
    for i, all_in in enumerate(zip(*tuple(val_loader))):
        input, target = zip(*[all_in[k] for k in range(num_tasks)])

        slice_pt = 0
        slice_idx = [0]
        for l in [p.size(0) for p in input]:
            slice_pt += l
            slice_idx.append(slice_pt)

        input = torch.cat(tuple(input), dim=0)

        target = [tg.cuda() for tg in target]

        # measure accuracy and record loss (volatile Variables are deprecated; use no_grad instead)
        with torch.no_grad():
            loss = model(input.cuda(), target, slice_idx)

        for k in range(num_tasks):
            if torch.__version__ >= '0.4.0':
                losses[k].update(loss[k].item())
            else:
                losses[k].update(loss[k].data[0])

    log('Test epoch #{}    Time {}'.format(epoch, time.time() - start))
    for k in range(num_tasks):
        log(' * Task: #{0}    Loss {loss.avg:.4f}'.format(k, loss=losses[k]))

    if count is not None:
        for k in range(num_tasks):
            tb_logger.add_scalar('val_loss_{}'.format(k), losses[k].val, count[0])

def extract(ext_loader, model, num, output_file, silent=False):
    batch_time = AverageMeter(9999999)
    data_time = AverageMeter(9999999)
    model.eval()
    features = []

    start = time.time()
    end = time.time()
    for i, input in enumerate(ext_loader):
        data_time.update(time.time() - end)
        with torch.no_grad():
            output = model(input.cuda(), extract_mode=True)
        features.append(output.cpu().numpy())
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.train.print_freq == 0 and not silent:
            log("Extracting: {0}/{1}\t"
                    "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                    "Data {data_time.val:.3f} ({data_time.avg:.3f})".format(
                    i, len(ext_loader), batch_time=batch_time, data_time=data_time))

    features = np.concatenate(features, axis=0)[:num, :]
    features.tofile(output_file)
    if not silent:
        log("Extracting Done. Total time: {}".format(time.time() - start))
    return features

def evaluation(test_loader, model, num, outfeat_fn, benchmark):
    load_feat = True  # reuse cached features when the output file already exists
    if load_feat and os.path.isfile(outfeat_fn):
        log("Loading features: {}".format(outfeat_fn))
        features = np.fromfile(outfeat_fn, dtype=np.float32).reshape(-1, args.model.feature_dim)
    else:
        features = extract(test_loader, model, num, outfeat_fn, silent=True)

    if benchmark == "megaface":
        r = test_megaface(features)
        log(' * Megaface: 1e-6 [{}], 1e-5 [{}], 1e-4 [{}]'.format(r[-1], r[-2], r[-3]))
        return r[-1]
    else:
        features = normalize(features)
        _, lbs = bin_loader("{}/{}.bin".format(args.test.test_root, benchmark))
        _, _, acc, val, val_std, far = evaluate(
            features, lbs, nrof_folds=args.test.nfolds, distance_metric=0)
    
        log(" * {}: accuracy: {:.4f}({:.4f})".format(benchmark, acc.mean(), acc.std()))
        return acc.mean()


if __name__ == '__main__':
    main()