import os
import sys
import json
import scipy.io
import scipy.misc
from PIL import Image
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
import argparse

sys.path.append('libs')

from config import Config
import model as modellib
from model import log
import visualize
from SketchDataset import SketchDataset
from edgelist_utils import refine_mask_with_edgelist

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
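
# Note: ConfigProto / Session and keras.backend.tensorflow_backend below are
# TensorFlow 1.x (standalone-Keras) APIs; under TF 2.x the equivalents live
# under tf.compat.v1.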
from keras.backend.tensorflow_backend import set_session

tf_config = tf.ConfigProto()
tf_config.gpu_options.allow_growth = True
set_session(tf.Session(config=tf_config))


class SkeSegConfig(Config):
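    """Mask R-CNN inference configuration for SketchyScene, overriding the
    defaults of the base Config."""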
    # Give the configuration a recognizable name
    NAME = "sketchyscene"

    # Run inference on 1 GPU with 1 image per GPU. Batch size is 1
    # (GPUs * images/GPU).
    GPU_COUNT = 1
    IMAGES_PER_GPU = 1

    # Number of classes (including background)
    NUM_CLASSES = 1 + 46  # background + 46 classes

    # If enabled, resizes instance masks to a smaller size to reduce memory
    # load. Disabled here so the predicted masks stay at full resolution.
    USE_MINI_MASK = False

    # Input image shape
    IMAGE_MIN_DIM = 768
    IMAGE_MAX_DIM = 768

    # Anchor side lengths in pixels
    RPN_ANCHOR_SCALES = (32, 64, 128, 256, 512)

    # If True, use the binary sketch input to filter background pixels out of
    # the predicted masks
    IGNORE_BG = True


def segment_data_generation(mode, data_base_dir, use_edgelist=False, debug=False):
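    """Run Mask R-CNN inference on the chosen split(s) ('val', 'test', or
    both) and save per-image instance segmentation results as .npz files.

    For each image: detect instances, optionally zero out background pixels
    and refine the masks with pre-computed edge lists, sort instances by
    category group (CLASS_ORDERS), crop every mask to its bounding box, and
    save the class ids, boxes, and cropped masks.
    """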
    if mode == 'both':
        dataset_types = ['val', 'test']
    else:
        dataset_types = [mode]

    caption_base_dir = 'data'
    outputs_base_dir = 'outputs'
    trained_model_dir = os.path.join(outputs_base_dir, 'snapshot')
    edgelist_result_dir = os.path.join(outputs_base_dir, 'edgelist')
    seg_data_save_base_dir = os.path.join(outputs_base_dir, 'inst_segm_output_data')
    epochs = '0100'
    model_path = os.path.join(trained_model_dir, 'mask_rcnn_sketchyscene_' + epochs + '.h5')

    dataset_class_names = ['bg']
    color_map_mat_path = os.path.join(data_base_dir, 'colorMapC46.mat')
    colorMap = scipy.io.loadmat(color_map_mat_path)['colorMap']
    for i in range(46):
        cat_name = colorMap[i][0][0]
        dataset_class_names.append(cat_name)

    ROAD_LABEL = dataset_class_names.index('road')

    # Category groups used to order instances before saving: instances whose
    # class falls in an earlier group are placed first (see the
    # instance-sorting step below).
    CLASS_ORDERS = [[dataset_class_names.index('sun'), dataset_class_names.index('moon'),
                     dataset_class_names.index('star'), dataset_class_names.index('road')],
                    [dataset_class_names.index('tree')],
                    [dataset_class_names.index('cloud')],
                    [dataset_class_names.index('house')],
                    [dataset_class_names.index('bus'), dataset_class_names.index('car'),
                     dataset_class_names.index('truck')]]

    config = SkeSegConfig()
    model = modellib.MaskRCNN(mode="inference", config=config, model_dir='', log_dir='')

    assert model_path != "", "Provide path to trained weights"
    print("Loading weights from ", model_path)
    model.load_weights(model_path, by_name=True)

    for dataset_type in dataset_types:
        caption_json_path = os.path.join(caption_base_dir, 'sentence_instance_' + dataset_type + '.json')

        with open(caption_json_path, "r") as fp:
            json_data = json.load(fp)
        print('data_len', len(json_data))

        # val/test dataset
        dataset = SketchDataset(data_base_dir)
        dataset.load_sketches(dataset_type)
        dataset.prepare()

        split_seg_data_save_base_dir = os.path.join(seg_data_save_base_dir, dataset_type)
        os.makedirs(split_seg_data_save_base_dir, exist_ok=True)

        for data_idx in range(len(json_data)):
            img_idx = json_data[data_idx]['key']
            print('Processing', dataset_type, data_idx + 1, '/', len(json_data))

            original_image, _, gt_class_id, gt_bbox, gt_mask, _ = \
                modellib.load_image_gt(dataset, config, img_idx - 1, use_mini_mask=False)

            ## 1. inference
            results = model.detect([original_image])
            r = results[0]

            pred_boxes = r["rois"]  # (nRoIs, (y1, x1, y2, x2))
            pred_class_ids = r["class_ids"]  # (nRoIs)
            pred_scores = r["scores"]
            pred_masks = r["masks"]  # (768, 768, nRoIs)

            log("pred_boxes", pred_boxes)
            log("pred_class_ids", pred_class_ids)
            log("pred_masks", pred_masks)

            ## 2. Use the binary input original_image (768, 768, 3), {0, 255},
            ##    to remove predicted mask pixels that fall on the background
            if config.IGNORE_BG:
                pred_masks = np.transpose(pred_masks, (2, 0, 1))  # (nRoIs, 768, 768)
                bin_input = original_image[:, :, 0] == 255  # white = background
                pred_masks[:, bin_input] = 0  # (nRoIs, 768, 768)
                pred_masks = np.transpose(pred_masks, (1, 2, 0))  # (768, 768, nRoIs)

            if debug:
                visualize.display_instances(original_image, pred_boxes, pred_masks, pred_class_ids,
                                            dataset.class_names, pred_scores, figsize=(8, 8))

            ## 3. refine pred_masks(768, 768, nRoIs) with edge-list
            if use_edgelist:
                pred_masks = \
                    refine_mask_with_edgelist(img_idx, dataset_type, data_base_dir, edgelist_result_dir,
                                              pred_masks.copy(), pred_boxes)

            ## 4. TODO: remove road prediction
            # pred_boxes = pred_boxes.tolist()
            # pred_masks = np.transpose(pred_masks, (2, 0, 1)).tolist()
            # pred_scores = pred_scores.tolist()
            # pred_class_ids = pred_class_ids.tolist()
            #
            # while ROAD_LABEL in pred_class_ids:
            #     road_idx = pred_class_ids.index(ROAD_LABEL)
            #     del pred_boxes[road_idx]
            #     del pred_masks[road_idx]
            #     del pred_scores[road_idx]
            #     del pred_class_ids[road_idx]

            ## 5. TODO: add road from semantic prediction
            # sem_label_base_path = '../../../../Sketch-Segmentation-TF/Segment-Sketch-DeepLab-v2/edge-list/pred_semantic_label_edgelist/'
            # sem_label_base_path = os.path.join(sem_label_base_path, dataset_type, 'mat')
            # sem_label_path = os.path.join(sem_label_base_path, 'L0_sample' + str(img_idx) + '.mat')
            # sem_label = scipy.io.loadmat(sem_label_path)['pred_label_edgelist']  # (750, 750), [0, 46]
            #
            # if ROAD_LABEL in sem_label:
            #     road_mask_img = np.zeros([sem_label.shape[0], sem_label.shape[1], 3], dtype=np.uint8)
            #     road_mask_img[sem_label == ROAD_LABEL] = [255, 255, 255]  # (750, 750, 3), {0, 255}
            #     road_mask_img = scipy.misc.imresize(
            #         road_mask_img, (config.IMAGE_MAX_DIM, config.IMAGE_MAX_DIM), interp='nearest')  # (768, 768, 3)
            #     road_mask = np.zeros(road_mask_img[:, :, 0].shape, dtype=np.uint8)
            #     road_mask[road_mask_img[:, :, 0] == 255] = 1  # (768, 768), {0, 1}
            #     # plt.imshow(road_mask)
            #     # plt.show()
            #
            #     # NOTE: utils here refers to the Mask R-CNN utils module,
            #     #       which is not imported above
            #     road_bbox = utils.extract_bboxes(np.expand_dims(road_mask, axis=2))  # [num_instances, (y1, x1, y2, x2)]
            #     road_bbox = road_bbox[0]
            #     pred_boxes.append(road_bbox)
            #     pred_masks.append(road_mask)
            #     pred_scores.append(1.)
            #     pred_class_ids.append(ROAD_LABEL)

            # pred_boxes = np.array(pred_boxes, dtype=np.int32)
            # pred_class_ids = np.array(pred_class_ids, dtype=np.int32)
            # pred_scores = np.array(pred_scores, dtype=np.float32)
            # pred_masks = np.array(pred_masks, dtype=np.uint8)
            # pred_masks = np.transpose(pred_masks, [1, 2, 0])  # (768, 768, nRoIs?)

            if debug:
                visualize.display_instances(original_image, pred_boxes, pred_masks, pred_class_ids,
                                            dataset.class_names, pred_scores, figsize=(8, 8))

            ## 6. sort instances by category group (CLASS_ORDERS)
            instance_sorted_index = []

            for order_idx in range(len(CLASS_ORDERS)):
                order_ids = CLASS_ORDERS[order_idx]
                for cate_idx in range(pred_class_ids.shape[0]):
                    if pred_class_ids[cate_idx] in order_ids:
                        instance_sorted_index.append(cate_idx)

            for cate_idx in range(pred_class_ids.shape[0]):
                if cate_idx not in instance_sorted_index:
                    instance_sorted_index.append(cate_idx)
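            # e.g. with pred_class_ids = [tree, car, sun], the group passes
            # yield instance_sorted_index = [2, 0, 1]: sun's group first,
            # then tree's group, then the vehicle group.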

            # print('pred_class_ids', pred_class_ids)
            # print('instance_sorted_index', instance_sorted_index)
            assert len(instance_sorted_index) == pred_class_ids.shape[0]

            pred_class_ids_list = []
            pred_masks_list = []
            pred_boxes_list = []

            for cate_idx_i in range(len(instance_sorted_index)):
                pred_class_ids_list.append(pred_class_ids[instance_sorted_index[cate_idx_i]])

                pred_box = pred_boxes[instance_sorted_index[cate_idx_i]]
                y1, x1, y2, x2 = pred_box
                pred_mask_large = pred_masks[:, :, instance_sorted_index[cate_idx_i]]
                pred_mask = pred_mask_large[y1: y2 + 1, x1: x2 + 1]

                pred_masks_list.append(pred_mask)
                pred_boxes_list.append(pred_box)

            # print('pred_class_ids_list', pred_class_ids_list)
            assert len(pred_class_ids_list) == pred_class_ids.shape[0]

            ## 7. generate .npz data
            npz_name = os.path.join(split_seg_data_save_base_dir, str(img_idx) + '_datas.npz')
            np.savez(npz_name, pred_class_ids=pred_class_ids_list, pred_masks=pred_masks_list,
                     pred_boxes=pred_boxes_list)
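            # NOTE: the cropped masks differ in size per instance, so NumPy
            # stores pred_masks as an object array inside the .npz; reloading
            # may require np.load(..., allow_pickle=True) on NumPy >= 1.16.3,
            # and building the ragged array may need dtype=object on newer
            # releases.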


def expand_small_segmentation_mask(pred_masks_small_list, pred_boxes):
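    """Paste each box-cropped mask back onto a full 768 x 768 canvas at its
    bounding-box location, inverting the per-box cropping performed in
    segment_data_generation(). For a box (y1, x1, y2, x2) = (10, 30, 19, 39),
    a (10, 10) crop is written back to canvas[10:20, 30:40].

    Returns an array of shape (N, 768, 768).
    """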
    pred_masks = []

    for i in range(len(pred_masks_small_list)):
        pred_mask_small = pred_masks_small_list[i]
        y1, x1, y2, x2 = pred_boxes[i]

        pred_mask_exp = np.zeros((768, 768), dtype=np.uint8)
        pred_mask_exp[y1: y2 + 1, x1: x2 + 1] = pred_mask_small
        pred_masks.append(pred_mask_exp)

    pred_masks = np.stack(pred_masks, axis=0)  # (N, IMAGE_SIZE, IMAGE_SIZE)
    return pred_masks


def debug_saved_npz(dataset_type, img_idx, data_base_dir):
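    """Reload a saved .npz result, rebuild the full-size masks, and visualize
    the instances on top of the original sketch as a sanity check."""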
    outputs_base_dir = 'outputs'
    seg_data_save_base_dir = os.path.join(outputs_base_dir, 'inst_segm_output_data', dataset_type)

    npz_name = os.path.join(seg_data_save_base_dir, str(img_idx) + '_datas.npz')
    npz = np.load(npz_name)  # may need allow_pickle=True on NumPy >= 1.16.3

    pred_class_ids = np.array(npz['pred_class_ids'], dtype=np.int32)
    pred_boxes = np.array(npz['pred_boxes'], dtype=np.int32)
    pred_masks_s = npz['pred_masks']
    pred_masks = expand_small_segmentation_mask(pred_masks_s, pred_boxes)  # [N, H, W]

    pred_masks = np.transpose(pred_masks, (1, 2, 0))
    print(pred_class_ids.shape)
    print(pred_masks.shape)
    print(pred_boxes.shape)

    image_name = 'L0_sample' + str(img_idx) + '.png'
    images_base_dir = os.path.join(data_base_dir, dataset_type, 'DRAWING_GT')
    image_path = os.path.join(images_base_dir, image_name)
    original_image = Image.open(image_path).convert("RGB")
    original_image = original_image.resize((768, 768), resample=Image.NEAREST)
    original_image = np.array(original_image, dtype=np.float32)  # shape = [H, W, 3]

    dataset_class_names = ['bg']
    color_map_mat_path = os.path.join(data_base_dir, 'colorMapC46.mat')
    colorMap = scipy.io.loadmat(color_map_mat_path)['colorMap']
    for i in range(46):
        cat_name = colorMap[i][0][0]
        dataset_class_names.append(cat_name)

    visualize.display_instances(original_image, pred_boxes, pred_masks, pred_class_ids,
                                dataset_class_names, figsize=(8, 8))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_basedir', '-db', type=str, default='../data', help="set the data base dir")
    parser.add_argument('--dataset', '-ds', type=str, choices=['val', 'test', 'both'],
                        default='both', help="choose a dataset")
    parser.add_argument('--use_edgelist', '-el', type=int, choices=[0, 1],
                        default=1, help="use edgelist or not")
    parser.add_argument('--image_id', '-id', type=int, default=-1, help="choose an image for debug")
    args = parser.parse_args()

    segment_data_generation(mode=args.dataset,
                            data_base_dir=args.data_basedir,
                            use_edgelist=args.use_edgelist)

    ## For debugging
    # assert args.dataset in ['val', 'test']
    # assert args.image_id != -1
    # debug_saved_npz(dataset_type=args.dataset,
    #                 img_idx=args.image_id,
    #                 data_base_dir=args.data_basedir)
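
## Example invocations (the script filename here is assumed):
#   python3 segment_object.py --data_basedir ../data --dataset both --use_edgelist 1
#   python3 segment_object.py --data_basedir ../data --dataset val --use_edgelist 0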