from collections import OrderedDict

import cv2
import numpy as np
import torch
from scipy.ndimage import gaussian_filter, maximum_filter

import pose.models as models
from pose.utils.osutils import isfile
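
# Single-image pose-estimation demo: build a stacked-hourglass model, load a
# trained checkpoint, run it on one image, and display the detected keypoints.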

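# collect the model constructors exported by pose.models (e.g. 'hg')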
model_names = sorted(name for name in models.__dict__
                     if name.islower() and not name.startswith("__")
                     and callable(models.__dict__[name]))

def load_image(imgfile, w, h):
    """Read an image file and preprocess it into a normalized 1x3xHxW float array."""
    image = cv2.imread(imgfile)
    image = cv2.resize(image, (w, h))
    image = image[:, :, ::-1]  # BGR -> RGB
    image = image / 255.0
    image = image - np.array([[[0.4404, 0.4440, 0.4327]]])  # subtract the per-channel RGB mean
    image = image.transpose((2, 0, 1))  # Change data layout from HWC to CHW
    image = image[np.newaxis, :, :, :]
    return image
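
# For example, load_image('data/sample.jpg', 256, 256) returns a float array of
# shape (1, 3, 256, 256), ready to be wrapped with torch.from_numpy().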

def load_model(arch, stacks, blocks, num_classes, mobile, checkpoint_resume):
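    """Build the requested model and load weights from `checkpoint_resume` if the file exists."""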
    # create model
    model = models.__dict__[arch](num_stacks=stacks, num_blocks=blocks, num_classes=num_classes, mobile=mobile)

    # optionally resume from a checkpoint
    if isfile(checkpoint_resume):
        print("=> loading checkpoint '{}'".format(checkpoint_resume))
        checkpoint = torch.load(checkpoint_resume, map_location=lambda storage, loc: storage)
        # build a new state dict without the `module.` prefix that DataParallel adds
        new_state_dict = OrderedDict()
        for k, v in checkpoint['state_dict'].items():
            name = k[7:] if k.startswith('module.') else k  # strip `module.` when present
            new_state_dict[name] = v
        # load params
        model.load_state_dict(new_state_dict)
        print("=> loaded checkpoint '{}' (epoch {})".format(checkpoint_resume, checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(checkpoint_resume))

    print('    Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0))
    model.eval()
    return model
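
# A minimal usage sketch (the checkpoint path is illustrative):
#   model = load_model('hg', 8, 1, 16, False, 'checkpoint/model_best.pth.tar')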

def inference(model, image, device):
    """Run the model on a preprocessed image and return the detected keypoints."""
    input_tensor = torch.from_numpy(image).float().to(device)
    model = model.to(device)
    with torch.no_grad():  # inference only; no gradient bookkeeping needed
        output = model(input_tensor)
    heatmaps = output[-1].cpu().numpy()  # keep the heatmaps from the last hourglass stack
    kps = post_process_heatmap(heatmaps[0])
    return kps


def post_process_heatmap(heat_map):
    """Extract one (x, y, confidence) peak per keypoint channel of the heatmap."""
    kplst = list()
    for i in range(heat_map.shape[0]):
        _map = gaussian_filter(heat_map[i, :, :], sigma=1)  # smooth out spurious single-pixel peaks
        _nms_peaks = non_max_suppression(_map, window_size=3, threshold=1e-6)

        y, x = np.where(_nms_peaks == _nms_peaks.max())
        if len(x) > 0 and len(y) > 0:
            kplst.append((int(x[0]), int(y[0]), _nms_peaks[y[0], x[0]]))
        else:
            kplst.append((0, 0, 0))

    kp = np.array(kplst)
    return kp
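
# Peak coordinates are in heatmap space; with the default 16 keypoint classes
# the result has shape (16, 3), with columns x, y, and peak confidence.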


def non_max_suppression(plain, window_size=3, threshold=1e-6):
    """Zero out sub-threshold values and keep only local maxima within the window."""
    plain[plain < threshold] = 0
    # a pixel survives only if it equals the maximum of its local window
    return plain * (plain == maximum_filter(plain, footprint=np.ones((window_size, window_size))))

def render_kps(cvmat, kps, scale_x, scale_y):
    """Draw keypoints with confidence above 0.2 onto the original image."""
    for _x, _y, _conf in kps:
        if _conf > 0.2:
            # heatmap peaks live at 1/4 of the network input resolution, hence the
            # factor of 4 before rescaling to the original image size
            cv2.circle(cvmat, center=(int(_x * 4 * scale_x), int(_y * 4 * scale_y)),
                       color=(0, 0, 255), radius=5)

    return cvmat


def main(args):
    # load checkpoint
    model = load_model(args.arch, args.stacks, args.blocks, args.num_classes, args.mobile, args.checkpoint)
    in_res_h, in_res_w = args.in_res, args.in_res

    # load image from file and do preprocess
    image = load_image(args.image, in_res_w, in_res_h)

    # do inference
    kps = inference(model, image, args.device)

    # render the detected keypoints
    cvmat = cv2.imread(args.image)
    scale_x = cvmat.shape[1] / in_res_w
    scale_y = cvmat.shape[0] / in_res_h
    render_kps(cvmat, kps, scale_x, scale_y)

    cv2.imshow('keypoints', cvmat)
    cv2.waitKey(0)

if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='Stacked-hourglass pose estimation demo')
    # Model structure
    parser.add_argument('--arch', '-a', metavar='ARCH', default='hg',
                        choices=model_names,
                        help='model architecture: ' +
                             ' | '.join(model_names) +
                             ' (default: hg)')
    parser.add_argument('-s', '--stacks', default=8, type=int, metavar='N',
                        help='Number of hourglasses to stack')
    parser.add_argument('-b', '--blocks', default=1, type=int, metavar='N',
                        help='Number of residual modules at each location in the hourglass')
    parser.add_argument('--num-classes', default=16, type=int, metavar='N',
                        help='Number of keypoints')
    parser.add_argument('--mobile', action='store_true',
                        help='use depthwise convolutions in the bottleneck blocks')
    parser.add_argument('--checkpoint', required=True, type=str, metavar='PATH',
                        help='path to the pre-trained model checkpoint')
    parser.add_argument('--in_res', required=True, type=int, metavar='N',
                        help='network input resolution (128 or 256)')
    parser.add_argument('--image', default='data/sample.jpg', type=str, metavar='PATH',
                        help='input image')
    parser.add_argument('--device', default='cuda', type=str, metavar='DEVICE',
                        help="device to run on, e.g. 'cuda' or 'cpu'")
    main(parser.parse_args())
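
# Example invocation (script name and checkpoint path are illustrative):
#   python demo.py -a hg -s 8 -b 1 --checkpoint checkpoint/model_best.pth.tar \
#       --in_res 256 --image data/sample.jpg --device cpu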