import argparse
import os
import shutil
import time
import json

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms

from nets.vgg_based_network import HPE_with_PIL_VGG_MSRAInit
from nets.hourglass_based_network import HPE_with_PIL_HG_MSRAInit

from utils.data_loader import LIPDataset
from utils.calc_pckh import calc_pck_lip_dataset 
import utils.eval_util as eval_util

parser = argparse.ArgumentParser(description='PyTorch Human Pose Estimation with Parsing Induced Learner on LIP dataset')
parser.add_argument('--train-data', default='dataset/lip/train_images/', metavar='DIR', help='path to training dataset')
parser.add_argument('--train-pose-anno', default='dataset/lip/jsons/LIP_SP_TRAIN_annotations.json', type=str, metavar='PATH', help='path to training pose annotations')
parser.add_argument('--train-parsing-anno', default='dataset/lip/train_segmentations', metavar='DIR', help='path to training parsing annotations')
parser.add_argument('--eval-data', default='dataset/lip/val_images', metavar='DIR', help='path to eval dataset')
parser.add_argument('--eval-pose-anno', default='dataset/lip/jsons/LIP_SP_VAL_annotations.json', type=str, metavar='PATH', help='path to eval pose annotations')
parser.add_argument('--eval-parsing-anno', default='dataset/lip/val_segmentations', metavar='DIR', help='path to eval parsing annotations')

parser.add_argument('--arch', default='HG', type=str, metavar='ARCH', help='network architecture: VGG or HG (Hourglass) (default: HG)')

parser.add_argument('-b', '--batch_size', default=10, type=int, metavar='N', help='mini-batch size (default: 10)')
parser.add_argument('--lr', '--learning-rate', default=0.0015, type=float, metavar='LR', help='initial learning rate')
parser.add_argument('--epochs', default=250, type=int, metavar='N', help='number of total epochs to run (default: 250)')
parser.add_argument('--snapshot-fname-prefix', default='exps/snapshots/pil_lip', type=str, metavar='PATH', help='path to snapshot')
parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)')
parser.add_argument('-j', '--workers', default=8, type=int, metavar='N', help='number of data loading workers (default: 8)')
parser.add_argument('--print-freq', '-p', default=10, type=int, metavar='N', help='print frequency (default: 10)')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N', help='manual epoch number (useful on restarts)')

parser.add_argument('--evaluate', action='store_true', help='evaluate instead of train')
parser.add_argument('--calc-pck', action='store_true', help='calculate PCK after evaluation')
parser.add_argument('--pred-path', default='exps/preds/csv_results/pred_keypoints_lip.csv', type=str, metavar='PATH', help='path to save the prediction results in .csv format')
parser.add_argument('--visualization', action='store_true', help='visualize predictions')
parser.add_argument('--vis-dir', default='exps/preds/vis_results', metavar='DIR', help='path to save visualization results')
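
# Example invocations (the script name is illustrative; adjust paths to your layout):
#   training:    python main.py --arch HG -b 10 --lr 0.0015
#   evaluation:  python main.py --evaluate --calc-pck --resume exps/snapshots/pil_lip_best.pth.tar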

best_pck = 0
pck_avg_list = []
pck_all_list = []

def main():
    # Global variables
    global args, best_pck, pck_avg_list, pck_all_list
    args = parser.parse_args()

    # Welcome msg
    phase_str = '[Train and Val Phase]'
    if args.evaluate:
        phase_str = '[Testing Phase]'
    print('Human Pose Estimation with Parsing Induced Learner: {0}'.format(phase_str))

    # Create network
    if args.arch == 'VGG':
        hpe_with_pil_net = HPE_with_PIL_VGG_MSRAInit()
        pose_net_stride = 8
    elif args.arch == 'HG':
        hpe_with_pil_net = HPE_with_PIL_HG_MSRAInit()
        pose_net_stride = 4
    else:
        raise RuntimeError('Unknown network architecture!')
        
    # Multi-GPU setting
    hpe_with_pil_net = nn.DataParallel(hpe_with_pil_net).cuda()

    # CUDNN setting
    cudnn.benchmark = True
    cudnn.enabled = True

    # Optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print('=> Loading checkpoint {0}'.format(args.resume))
            checkpoint = torch.load(args.resume)
            hpe_with_pil_net.load_state_dict(checkpoint['state_dict'])
            args.start_epoch = checkpoint['epoch']
            best_pck = checkpoint['best_pck']
            pck_avg_list = checkpoint['pck_avg_list']
            pck_all_list = checkpoint['pck_all_list']
        else:
            print('=> No checkpoint found at {0}'.format(args.resume))

    hpe_with_pil_net_params = hpe_with_pil_net.parameters()

    # Snapshot file names
    snapshot_fname = '{0}.pth.tar'.format(args.snapshot_fname_prefix)
    snapshot_best_fname = '{0}_best.pth.tar'.format(args.snapshot_fname_prefix)

    # Image normalization
    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[1, 1, 1])
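    # With mean 0.5 and std 1, ToTensor's [0, 1] output is shifted to [-0.5, 0.5]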

    # Data transform
    data_transform = transforms.Compose([transforms.ToTensor(), normalize,])

    # LIP dataset
    lip_ds = LIPDataset(args.train_data,
                        args.train_pose_anno,
                        args.train_parsing_anno,
                        transform=data_transform,
                        pose_net_stride=pose_net_stride,
                        parsing_net_stride=1,
                        crop_size=256,
                        target_dist=1.171, scale_min=0.8, scale_max=1.5,
                        max_rotate_degree=40,
                        max_center_trans=40,
                        flip_prob=0.5,
                        is_visualization=False)

    # Load training data 
    train_loader = torch.utils.data.DataLoader(lip_ds, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)

    # Load validation data
    print('Loading evaluation json file: {0}...'.format(args.eval_pose_anno))
    with open(args.eval_pose_anno) as data_file:
        eval_list = json.load(data_file)['root']
    eval_im_name_list = [eval_item['im_name'] for eval_item in eval_list]
    print('Finished loading evaluation json file')

    # MSE loss for the pose heatmaps; NLL loss over the parsing head's per-pixel
    # log-probabilities (equivalent to cross entropy when the head ends in log-softmax)
    pose_criterion = nn.MSELoss().cuda()
    parsing_criterion = nn.NLLLoss().cuda()

    # RMSProp as the optimizer
    optimizer = torch.optim.RMSprop(hpe_with_pil_net_params, lr=args.lr)

    # Testing
    if args.evaluate:
        evaluate(hpe_with_pil_net,
                 args.eval_data,
                 eval_im_name_list,
                 transform=data_transform,
                 stride=pose_net_stride,
                 crop_size=256,
                 scale_multiplier=[1],
                 visualization=args.visualization,
                 vis_result_dir=args.vis_dir,
                 pred_path=args.pred_path,
                 is_calc_pck=args.calc_pck)

        return

    for epoch in range(args.start_epoch, args.epochs):

        # Training
        train(train_loader, hpe_with_pil_net, pose_criterion, parsing_criterion, optimizer, epoch)

        # Save snapshot
        torch.save({ 
            'epoch': epoch + 1,
            'state_dict': hpe_with_pil_net.state_dict(),
            'best_pck': best_pck,
            'pck_avg_list': pck_avg_list,
            'pck_all_list': pck_all_list,
        }, snapshot_fname)

        # Validation: evaluate more frequently as training converges
        if epoch < 100:
            val_freq = 10
        elif epoch < 150:
            val_freq = 2
        else:
            val_freq = 1

        if (epoch + 1) % val_freq == 0:
            pck_avg = evaluate(hpe_with_pil_net,
                               args.eval_data,
                               eval_im_name_list,
                               transform=data_transform,
                               stride=pose_net_stride,
                               crop_size=256,
                               scale_multiplier=[1],
                               visualization=args.visualization,
                               vis_result_dir=args.vis_dir,
                               pred_path=args.pred_path,
                               is_calc_pck=True)
            
            is_best = pck_avg > best_pck
            best_pck = max(pck_avg, best_pck)

            # Re-save the snapshot so it records the updated best_pck and PCK history
            torch.save({
                'epoch': epoch + 1,
                'state_dict': hpe_with_pil_net.state_dict(),
                'best_pck': best_pck,
                'pck_avg_list': pck_avg_list,
                'pck_all_list': pck_all_list,
            }, snapshot_fname)
            if is_best:
                shutil.copyfile(snapshot_fname, snapshot_best_fname)
		
def train(train_loader, model, pose_criterion, parsing_criterion, optimizer, epoch):

    cur_lr = adjust_learning_rate(optimizer, epoch)

    losses = AverageMeter()
    cost_time = AverageMeter()
    train_acc = AverageMeter()

    model.train()

    iter_start_time = time.time()
    for i, (im, pose_target, parsing_target) in enumerate(train_loader):

        # Move inputs and targets to the GPU; non_blocking=True lets the copies
        # overlap with computation since the DataLoader uses pinned memory
        im = im.cuda(non_blocking=True)
        pose_target = pose_target.float().cuda(non_blocking=True)
        parsing_target = parsing_target.long().cuda(non_blocking=True)

        # Network forward
        pose_output, parsing_output = model(im)

        # Parsing loss, down-weighted relative to the pose loss
        total_loss = 0.01 * parsing_criterion(parsing_output, parsing_target)

        # Pose loss
        # Case 1: pose output is a list of per-stack heatmaps from the Hourglass network
        # Case 2: pose output is a single tensor from the VGG network
        if isinstance(pose_output, list):
            avg_acc = cal_train_acc(pose_output[-1].detach(), pose_target)
            for s in range(len(pose_output)):
                total_loss += pose_criterion(pose_output[s], pose_target)
        else:
            avg_acc = cal_train_acc(pose_output.detach(), pose_target)
            total_loss += pose_criterion(pose_output, pose_target)

        train_acc.update(avg_acc, 1)
        losses.update(total_loss.item(), im.size(0))

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        cost_time.update(time.time() - iter_start_time)
        iter_start_time = time.time()
        
        if i == 0 or (i + 1) % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}] \t'
                  'CurLR: {3} \t'
                  'Loss {loss.val:.6f} ({loss.avg:.6f}) \t'
                  'Acc {accuracy.val:.3f} ({accuracy.avg:.3f}) \t'
                  'BatchTime {cost_time.val:.3f} ({cost_time.avg:.3f}) \t'.format(
                  epoch + 1, i + 1, len(train_loader), 
                  cur_lr, 
                  loss=losses, 
                  accuracy=train_acc, 
                  cost_time=cost_time))

def evaluate(model,
             eval_im_root_dir,
             eval_im_name_list,
             transform=None,
             stride=4,
             crop_size=256,
             scale_multiplier=[1],
             num_of_joints=16,
             visualization=False,
             vis_result_dir='exps/preds/vis_results',
             gt_path='dataset/lip/val_gt/lip_val_groundtruth.csv',
             pred_path='exps/preds/csv_results/pred_keypoints_lip.csv',
             is_calc_pck=True):

    model.eval()
    # Inference only: no_grad avoids building autograd graphs during evaluation
    with torch.no_grad():
        pose_list = eval_util.multi_image_testing_on_lip_dataset(
            model,
            eval_im_root_dir,
            eval_im_name_list,
            transform=transform,
            stride=stride,
            crop_size=crop_size,
            scale_multiplier=scale_multiplier,
            num_of_joints=num_of_joints,
            visualization=visualization,
            vis_result_dir=vis_result_dir)

    eval_util.save_hpe_results_to_lip_format(eval_im_name_list, pose_list, save_path=pred_path)

    pck_avg = 0.0
    if is_calc_pck:
        pck_all = calc_pck_lip_dataset(gt_path, pred_path, method_name='hpe_with_pil', eval_num=len(eval_im_name_list))
        # The last entry of the last row holds the overall average PCKh (see utils.calc_pckh)
        pck_avg = pck_all[-1][-1]
        pck_all_list.append(pck_all)
        pck_avg_list.append(pck_avg)

    return pck_avg

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
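
# Example: m = AverageMeter(); m.update(2.0); m.update(4.0) leaves
# m.val == 4.0, m.sum == 6.0, m.count == 2 and m.avg == 3.0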

def adjust_learning_rate(optimizer, epoch):
    # Step decay keyed on the 1-indexed epoch number
    if epoch + 1 >= 230:
        decay = 0.05
    elif epoch + 1 >= 200:
        decay = 0.1
    elif epoch + 1 >= 170:
        decay = 0.25
    elif epoch + 1 >= 150:
        decay = 0.5
    else:
        decay = 1

    lr = args.lr * decay

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr
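
# Resulting schedule (1-indexed epochs): 1-149 -> lr, 150-169 -> 0.5*lr,
# 170-199 -> 0.25*lr, 200-229 -> 0.1*lr, 230+ -> 0.05*lr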

# Decode (N, K, H, W) heatmaps into (N, K, 2) peak coordinates
def get_preds(heatmaps):
    if heatmaps.dim() != 4:
        raise ValueError('Input must be 4-D tensor')
    # Argmax over the flattened spatial dimensions of each heatmap
    max_val, max_idx = torch.max(heatmaps.view(heatmaps.size(0), heatmaps.size(1), heatmaps.size(2) * heatmaps.size(3)), 2)
    preds = torch.Tensor(max_idx.size(0), max_idx.size(1), 2)
    preds[:, :, 0] = max_idx % heatmaps.size(3)  # x: column index
    preds[:, :, 1] = torch.div(max_idx, heatmaps.size(3), rounding_mode='floor')  # y: row index
    return preds
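
# Example (hypothetical shapes): for heatmaps of shape (N, 16, 64, 64), get_preds
# returns an (N, 16, 2) tensor of (x, y) = (idx % 64, idx // 64) peak locations
# in heatmap coordinates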

def calc_dists(preds, labels, normalize):
    dists = torch.Tensor(preds.size(1), preds.size(0))
    for i in range(preds.size(0)):
        for j in range(preds.size(1)):
            if labels[i, j, 0] == 0 and labels[i, j, 1] == 0:
                dists[j, i] = -1
            else:
                dists[j, i] = torch.dist(labels[i, j, :], preds[i, j, :]) / normalize
    return dists
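
# calc_dists returns a (num_joints, batch) matrix of normalized distances; entries
# are -1 where the ground-truth peak decodes to the origin (treated as a missing
# annotation), and dist_accuracy below skips those sentinel entries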

def dist_accuracy(dists, th=0.5):
    # Fraction of valid joints whose normalized distance falls within th;
    # returns -1 when the batch contains no valid joint
    valid = dists.ne(-1)
    if valid.sum() > 0:
        return (dists.le(th) & valid).sum().item() / valid.sum().item()
    else:
        return -1

def cal_train_acc(output, target):

    # The last target channel is assumed to hold the background heatmap, so it is skipped
    num_of_joints = target.size(1) - 1

    preds = get_preds(output)
    gt = get_preds(target)
    # Normalize distances by one tenth of the heatmap width (PCK-style threshold)
    dists = calc_dists(preds, gt, output.size(3) / 10.0)

    avg_acc = 0.0
    bad_idx_count = 0
    for ji in range(num_of_joints):
        acc = dist_accuracy(dists[ji, :])
        if acc >= 0:
            avg_acc += acc
        else:
            bad_idx_count += 1
    if bad_idx_count != num_of_joints:
        avg_acc = avg_acc / (num_of_joints - bad_idx_count)
    return avg_acc

if __name__ == '__main__':
    main()