import cv2
import torch.utils.data

from src.utils import *
import src.config as cfg
from src.dataset.mpi import MPI
from src.model.model import which_model

# MPI dataset built with is_train=False, served one image per batch in
# dataset order (no shuffling) so every run shows the images in the same order.
valid_dset = MPI(cfg.IMG_DIR, cfg.ANNOTATION_PATH, is_train=False)
valid_loader = torch.utils.data.DataLoader(valid_dset, batch_size=1,
                                           shuffle=False, num_workers=cfg.NUM_WORKS)

# map_location remaps the stored tensors onto the device in use, so a
# checkpoint that was saved on a GPU can still be restored on a CPU-only host.
checkpoint = torch.load(cfg.CHECKPOINT_PATH, map_location=device)
net = which_model(is_shallow=cfg.IS_SHALLOW, net_state_dict=checkpoint['net_state_dict'])
net.to(device)


def test():
    """Run the network on every validation image and visualise the result.

    For each image the raw output is decoded into absolute joint boxes,
    candidate joints are filtered by score and non-maximum suppression,
    joints are paired into limbs and the limbs are chained into persons.
    Three windows are shown per image (boxes before NMS, boxes after NMS and
    the assembled skeletons); each of them blocks until a key is pressed.
    """
    net.eval()
    # The loader also yields a target and a mask tensor; inference needs neither.
    for image, _target, _mask in valid_loader:
        image = image.to(device)
        with torch.no_grad():
            output = net(image)

        # Transform relative coordinates to absolute coordinates.
        # The first 96 channels hold 6 values for each of the 16 joint types:
        # offsets 2/3 are the box centre, offsets 4/5 the box width and height.
        output[:, :, :, 2:96:6] = (output[:, :, :, 2:96:6] + x_offset) * cfg.CELL_SIZE
        output[:, :, :, 3:96:6] = (output[:, :, :, 3:96:6] + y_offset) * cfg.CELL_SIZE
        output[:, :, :, 4:96:6] = output[:, :, :, 4:96:6].pow(2) * cfg.IMG_SIZE
        output[:, :, :, 5:96:6] = output[:, :, :, 5:96:6].pow(2) * cfg.IMG_SIZE

        # Channels-last image rescaled into [0, 1] for display.
        # NOTE(review): the constants look like the per-channel std/mean used
        # to normalise the input -- confirm against the dataset transform.
        image = image.cpu().numpy().transpose(0, 2, 3, 1)
        image = image * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
        image = np.clip(image, 0, 1)
        output = output.cpu().numpy()
        # Remaining channels: for every cell of the 12x12 grid, a 9x9 window
        # of scores for each of the 15 limb types.
        output_limb = np.reshape(output[0, :, :, 96:], (12, 12, 9, 9, 15))

        all_joints_before, all_joints = _extract_joints(output)
        draw_box(image, all_joints_before)
        draw_box(image, all_joints)

        connection = _connect_limbs(all_joints, output_limb)
        persons = _group_persons(all_joints, connection)
        draw_limb(image, persons)


def _extract_joints(output):
    """Collect candidate joints per joint type, before and after NMS.

    Args:
        output: decoded network output as a NumPy array whose last axis holds
            6 values per joint type in its first 96 channels.

    Returns:
        Tuple ``(before, after)`` of 16-element lists. Each entry is either
        empty or holds a single 2-D array with one candidate per row; the
        rows carry the 6 predicted values plus a seventh person-id column.
    """
    before = [[] for _ in range(16)]
    after = [[] for _ in range(16)]
    for m in range(16):
        is_exist = output[:, :, :, 6 * m] >= cfg.thres1
        joints = output[:, :, :, 6 * m:6 * m + 6][is_exist, :]
        if joints.size == 0:
            continue
        # The seventh element is the person id which is used to group joints.
        # -1 denotes that it does not belong to any person.
        tagged = np.concatenate((joints, np.full((joints.shape[0], 1), -1)), axis=1)
        before[m].append(tagged)
        # nms works on its own copy, so `tagged` can be shared with `before`.
        after[m].append(nms(tagged))
    return before, after


def _connect_limbs(all_joints, output_limb):
    """Greedily pair start and end joints of every limb type.

    Args:
        all_joints: per-joint-type candidates as returned by ``_extract_joints``.
        output_limb: limb scores indexed as
            ``[start_row, start_col, 4 + row_offset, 4 + col_offset, limb_id]``.

    Returns:
        15-element list; entry ``limb_id`` is a list of
        ``(start_index, end_index)`` pairs into the candidate arrays of the
        limb's start and end joint types.
    """
    connection = [[] for _ in range(15)]
    for limb_id in range(15):
        start_group = all_joints[limbs_start[limb_id]]
        end_group = all_joints[limbs_end[limb_id]]
        if not start_group or not end_group:
            continue
        l_start, l_end = start_group[0], end_group[0]
        # Grid cell that contains each joint centre.
        start_x = (l_start[:, 2] // cfg.CELL_SIZE).astype(int)
        start_y = (l_start[:, 3] // cfg.CELL_SIZE).astype(int)
        end_x = (l_end[:, 2] // cfg.CELL_SIZE).astype(int)
        end_y = (l_end[:, 3] // cfg.CELL_SIZE).astype(int)

        # edges[i, j] scores the pairing of end joint i with start joint j.
        edges = np.zeros((len(l_end), len(l_start)))
        for j in range(len(l_start)):
            s_y, s_x = start_y[j], start_x[j]
            if not (0 <= s_y < 12 and 0 <= s_x < 12):
                continue
            for i in range(len(l_end)):
                # Position of the end cell inside the 9x9 window centred on
                # the start cell.
                e_y, e_x = 4 + end_y[i] - s_y, 4 + end_x[i] - s_x
                if not (0 <= e_y < 9 and 0 <= e_x < 9):
                    continue
                limb_score = output_limb[s_y, s_x, e_y, e_x, limb_id]
                if limb_score > 0.1:
                    edges[i, j] = l_end[i, 0] * limb_score * l_start[j, 0]

        # Take the best remaining pair, then retire its row and column so
        # each joint is used at most once.
        for _ in range(min(len(l_end), len(l_start))):
            max_score = np.max(edges)
            if max_score == 0:
                # Nothing left to match; later rounds could not add a pair.
                break
            index_end, index_start = np.nonzero(edges == max_score)
            connection[limb_id].append((index_start[0], index_end[0]))
            edges[index_end[0], :] = 0
            edges[:, index_start[0]] = 0
    return connection


def _group_persons(all_joints, connection):
    """Chain the limb connections into per-person joint coordinates.

    Writes the assigned person id into the last column of the candidate
    arrays in ``all_joints`` as a side effect.

    Args:
        all_joints: NMS-filtered candidates as returned by ``_extract_joints``.
        connection: limb pairs as returned by ``_connect_limbs``.

    Returns:
        List of ``(16, 2)`` arrays, one per person, holding the (x, y) of
        every joint type or ``-1`` where the joint was not found.
    """
    # Group joints into their persons: every connection of limb 0 seeds one.
    # NOTE(review): joint type 0 is indexed with the *end* index of limb 0,
    # so this presumably relies on limbs_end[0] == 0 -- confirm in src.utils.
    persons = []
    for p_id, (_, root_index) in enumerate(connection[0]):
        new_person = np.full((16, 2), -1.0)
        new_person[0] = all_joints[0][0][root_index][2:4]
        all_joints[0][0][root_index, -1] = p_id
        persons.append(new_person)

    # In order to decide which person the right/left hip joint belongs to,
    # we must first know which person the center joint belongs to.
    order = [0, 1, 14, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
    for limb_type in order:
        if not connection[limb_type]:
            continue
        start_id, end_id = limbs_start[limb_type], limbs_end[limb_type]
        l_start = all_joints[start_id][0]
        l_end = all_joints[end_id][0]
        for i, j in connection[limb_type]:
            p = int(l_end[j, -1])
            if p == -1:
                # The end joint was never attached to a person.
                continue
            persons[p][start_id, :] = l_start[i, 2:4]
            # Propagate the person id so later limbs can attach to this joint.
            l_start[i, -1] = l_end[j, -1]
    return persons


def nms(joints, iou_threshold=None):
    """Greedy non-maximum suppression over candidate joint boxes.

    Candidates are visited in descending order of column 0; every later
    candidate whose box overlaps a surviving one by at least
    ``iou_threshold`` (intersection over union) is dropped.

    Args:
        joints: 2-D array with one candidate per row. Column 0 is the score
            used for ranking, column 1 doubles as the keep flag (rows whose
            column 1 is not strictly positive are removed), columns 2-3 are
            the box centre (x, y) and columns 4-5 the box width and height.
            Any further columns are carried through unchanged.
        iou_threshold: overlap at or above which a candidate is suppressed.
            Defaults to ``cfg.thres2`` when omitted.

    Returns:
        A new array with the surviving rows, sorted by descending column 0.
        The input array is not modified.
    """
    if iou_threshold is None:
        iou_threshold = cfg.thres2

    # Fancy indexing copies, so the caller's array stays untouched.
    joints = joints[np.argsort(joints[:, 0])[::-1], :]

    # Box geometry never changes during suppression; compute it once.
    half_w, half_h = 0.5 * joints[:, 4], 0.5 * joints[:, 5]
    left, right = joints[:, 2] - half_w, joints[:, 2] + half_w
    top, bottom = joints[:, 3] - half_h, joints[:, 3] + half_h
    area = joints[:, 4] * joints[:, 5]

    for j in range(len(joints) - 1):
        if joints[j, 1] == 0:
            # Already suppressed; it must not suppress others.
            continue
        rest = slice(j + 1, None)
        delta_x = np.minimum(right[j], right[rest]) - np.maximum(left[j], left[rest])
        delta_y = np.minimum(bottom[j], bottom[rest]) - np.maximum(top[j], top[rest])
        disjoint = np.logical_or(delta_x < 0, delta_y < 0)
        intersection = np.where(disjoint, 0, delta_x * delta_y)
        union = area[j] + area[rest] - intersection

        # Two zero-area boxes give union == 0; skip the division there
        # instead of producing NaN (such pairs are never suppressed).
        valid = union > 0
        iou = np.zeros(union.shape)
        np.divide(intersection, union, out=iou, where=valid)
        joints[rest, 1][np.logical_and(valid, iou >= iou_threshold)] = 0

    return joints[joints[:, 1] > 0]


def draw_box(img, joints):
    """Display the candidate joint boxes blended over the first image.

    Args:
        img: batch of channels-last images; only ``img[0]`` is drawn on.
        joints: per-joint-type list whose entries are either empty or hold
            one 2-D array of candidates (columns 2-3 centre, 4-5 size).

    Blocks until a key is pressed in the 'box' window; pressing 's' also
    writes the rendered picture to 'box.png'.
    """
    canvas = img[0].copy()
    # Joint type 0 is not drawn; type k is filled with colors[k - 1].
    for joint_type in range(1, 16):
        if not joints[joint_type]:
            continue
        for box in joints[joint_type][0]:
            half_w, half_h = 0.5 * box[4], 0.5 * box[5]
            top_left = (int(box[2] - half_w), int(box[3] - half_h))
            bottom_right = (int(box[2] + half_w), int(box[3] + half_h))
            cv2.rectangle(canvas, top_left, bottom_right, colors[joint_type - 1], -1)

    # Blend the filled boxes with the untouched image and reverse the
    # channel order for OpenCV.
    blended = cv2.addWeighted(canvas, alpha, img[0], 1 - alpha, 0)
    img_transparent = blended[:, :, ::-1]

    # White grid lines every cfg.CELL_SIZE pixels mark the cell borders.
    white = np.array([1., 1, 1])
    img_transparent[:, ::cfg.CELL_SIZE, :] = white
    img_transparent[::cfg.CELL_SIZE, :, :] = white

    cv2.namedWindow('box', cv2.WINDOW_NORMAL)
    cv2.resizeWindow('box', 600, 600)
    cv2.imshow('box', img_transparent)
    if cv2.waitKey(0) == ord('s'):
        cv2.imwrite('box.png', img_transparent * 255)


def draw_limb(img, persons):
    """Display the assembled skeletons blended over the first image.

    Args:
        img: batch of channels-last images; only ``img[0]`` is drawn on.
        persons: iterable of ``(16, 2)`` arrays with the (x, y) of every
            joint type, ``-1`` marking a missing coordinate.

    Blocks until a key is pressed in the 'persons' window; pressing 's'
    also writes the rendered picture to 'persons.png'.
    """
    canvas = img[0].copy()

    # First pass: one dot per located joint (joint type 0 is not drawn).
    for person in persons:
        for joint_id in range(1, 16):
            x, y = person[joint_id][0], person[joint_id][1]
            if x != -1 and y != -1:
                cv2.circle(canvas, (int(x), int(y)), 3, colors[joint_id - 1],
                           -1, cv2.LINE_AA)

    # Second pass: limb segments, drawn after all dots so they lie on top.
    for person in persons:
        for limb_id in range(14):
            first = person[limbs1[limb_id]]
            second = person[limbs2[limb_id]]
            if (first == -1).any() or (second == -1).any():
                # At least one end point is missing; nothing to connect.
                continue
            cv2.line(canvas, (int(first[0]), int(first[1])),
                     (int(second[0]), int(second[1])),
                     colors[limb_id], 2, cv2.LINE_AA)

    # Blend with the untouched image and reverse the channel order for OpenCV.
    blended = cv2.addWeighted(canvas, alpha, img[0], 1 - alpha, 0)
    img_dst = blended[:, :, ::-1]

    cv2.namedWindow('persons', cv2.WINDOW_NORMAL)
    cv2.resizeWindow('persons', 600, 600)
    cv2.imshow('persons', img_dst)
    if cv2.waitKey(0) == ord('s'):
        cv2.imwrite('persons.png', img_dst * 255)


# Start the interactive evaluation loop only when the file is run as a script.
if __name__ == '__main__':
    test()