from __future__ import print_function, division

from keras.models import load_model
from keras.applications.imagenet_utils import preprocess_input as pinput

import cv2 as cv
import numpy as np
import os
import argparse
from metric import *
import glob
from model.fast_scnn import resize_image
from segmentation_models.losses import *

import warnings

# Suppress every Python warning raised from this point on. The filter is
# installed before the explicit `import tensorflow` below so it is already
# active while that import statement runs.
warnings.filterwarnings('ignore')

import tensorflow as tf

# Restrict TensorFlow's (v1-compat) logger to ERROR-level messages.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

# Placeholder only: the `__main__` block overwrites this with the value of
# the --image_size command-line option.
IMG_SIZE = None


def vis_parsing_maps(im, parsing_anno, data_name):
    """Overlay a colour-coded label map on top of an image.

    Every pixel of ``parsing_anno`` whose value equals ``k`` is painted with
    the ``k``-th palette entry; values with no palette entry stay black.
    The painted map is then blended with the image at fixed weights.

    Args:
        im: Source image (anything ``np.array`` accepts); it is converted
            to ``uint8`` before blending. It is expected to have the same
            height and width as ``parsing_anno`` and three channels, since
            it is blended with a 3-channel overlay of that size.
        parsing_anno: Array of integer labels; its first two dimensions
            give the overlay height and width.
        data_name: Dataset identifier. ``'figaro1k'`` selects a two-entry
            palette, anything else the three-entry palette.

    Returns:
        The result of ``cv.addWeighted`` with weight 0.7 on the image and
        0.3 on the colour overlay.
    """
    if data_name == 'figaro1k':
        palette = [[255, 255, 255], [255, 0, 0]]
    else:
        palette = [[255, 255, 255], [0, 255, 0], [255, 0, 0]]

    base = np.array(im).astype(np.uint8)

    height, width = parsing_anno.shape[0], parsing_anno.shape[1]
    overlay = np.zeros((height, width, 3), dtype=np.uint8)
    for label, colour in enumerate(palette):
        hits = np.nonzero(parsing_anno == label)
        # Only the first two index arrays (row, column) address the overlay.
        overlay[hits[0], hits[1], :] = colour

    # An edge-aware smoothing pass over the overlay
    # (cv.ximgproc.guidedFilter, guided by the image) was tried at this
    # point and is currently disabled.
    return cv.addWeighted(base, 0.7, overlay, 0.3, 0)


if __name__ == '__main__':
    # Demo entry point: load a trained segmentation model, run it on every
    # ./demo/<dataset>/*.jpg image and write a blended visualisation next to
    # each input as <name>.png.

    parser = argparse.ArgumentParser()
    parser.add_argument("--image_size",
                        help="size of image", type=int, default=256)
    parser.add_argument("--model_path",
                        help="the path of model", type=str,
                        default='./weights/celebhair/exper/fastscnn/model.h5')
    args = parser.parse_args()

    IMG_SIZE = args.image_size
    MODEL_PATH = args.model_path

    # Both the architecture and the dataset name are taken from fixed
    # positions in the '/'-separated path, matching the default
    # './weights/<dataset>/<experiment>/<architecture>/<file>'.
    # NOTE(review): a path without the leading './' (or using another
    # separator) shifts these positions -- confirm the expected layout.
    path_parts = MODEL_PATH.split('/')
    if len(path_parts) < 3:
        # Fail with a clear message before any model is loaded instead of
        # an IndexError further down.
        parser.error(
            "--model_path must look like "
            "./weights/<dataset>/.../<architecture>/<file>, got: {}".format(
                MODEL_PATH))

    if path_parts[-2] == 'lednet':
        from model.lednet import LEDNet

        # NOTE(review): the input shape is fixed to 256x256 here regardless
        # of --image_size; presumably the two must agree -- confirm.
        model = LEDNet(2, 3, (256, 256, 3)).model()
        model.load_weights(MODEL_PATH)

    else:
        # Custom metrics/loss/layer helpers must be supplied by name so the
        # saved model can be deserialised.
        model = load_model(MODEL_PATH, custom_objects={'mean_accuracy': mean_accuracy,
                                                       'mean_iou': mean_iou,
                                                       'frequency_weighted_iou': frequency_weighted_iou,
                                                       'pixel_accuracy': pixel_accuracy,
                                                       'categorical_crossentropy_plus_dice_loss': cce_dice_loss,
                                                       'resize_image': resize_image})

    data_name = path_parts[2]
    demo_dir = os.path.join("./demo", data_name)

    for img_path in glob.glob(os.path.join(demo_dir, "*.jpg")):
        name = os.path.splitext(os.path.basename(img_path))[0]

        org_img = cv.imread(img_path)
        # cv.imread signals an unreadable file by returning None rather
        # than raising, so test for that explicitly.
        if org_img is None:
            raise IOError("Reading image error...")
        h, w = org_img.shape[:2]

        img_resize = cv.resize(org_img, (IMG_SIZE, IMG_SIZE))
        # Add a leading batch axis for model.predict.
        img = img_resize[np.newaxis, :]
        # pre-processing
        img = pinput(img)

        # Index of the highest-scoring entry along the last axis of the
        # first (only) prediction in the batch.
        result_map = np.argmax(model.predict(img)[0], axis=-1)
        out = vis_parsing_maps(img_resize, result_map, data_name)
        # Bring the visualisation back to the original image size.
        out = cv.resize(out, (w, h), interpolation=cv.INTER_NEAREST)

        # Format only the file name, so braces in the directory part of the
        # path cannot be misread as format fields.
        out_path = os.path.join(demo_dir, "{}.png".format(name))
        if not cv.imwrite(out_path, out):
            # imwrite reports failure through its return value.
            print("Failed to write {}".format(out_path))