#!/usr/bin/env python
# -*- coding: utf-8 -*-

# -----------------------------------------------------------------------------
#
# Prepare for Trimap Segmentation
#
#   Create trimap from alpha
#
# -----------------------------------------------------------------------------

import argparse
import os
import cv2
import numpy as np
import scipy.sparse

# modules
import log_initializer
import config
from datasets import get_valid_names

# logging
from logging import getLogger, INFO
log_initializer.set_fmt()
log_initializer.set_root_level(INFO)
logger = getLogger(__name__)


def compute_trimap_from_alpha(name, src_dir, dst_dir, open_size=10,
                              alpha_margin=10):
    src_path = os.path.join(src_dir, name)
    if not os.path.exists(src_path):
        logger.error('"%s" dose not exist', src_path)
        return

    dst_path = os.path.join(dst_dir, name)
    if os.path.exists(dst_path):
        logger.info('"%s" exists, Skip...', dst_path)
        return

    logger.info('Trimap for "%s"', src_path)
    alpha = cv2.imread(src_path, 0)
    assert alpha.ndim == 2

    # Compute each region
    fore = ((255 - alpha_margin) < alpha)
    back = (alpha < alpha_margin)
    unknown = ~(fore + back)
    unknown = cv2.dilate(
        unknown.astype(np.uint8),
        cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (open_size, open_size))
    ).astype(np.bool)

    # Draw
    trimap = np.zeros_like(alpha)
    trimap[fore] = 255
    trimap[unknown] = 127

    cv2.imwrite(dst_path, trimap)


def main():
    # Argument
    parser = argparse.ArgumentParser(description='Dataset Preparing Script')
    parser.add_argument('--config', '-c', default='config.json',
                        help='Load config from given json file')
    args = parser.parse_args()

    # Load config
    config.load(args.config)

    # Get valid names for alpha matting
    names = get_valid_names(config.img_crop_dir, config.img_mask_dir,
                            config.img_mean_mask_dir, config.img_mean_grid_dir,
                            config.img_alpha_dir,
                            rm_exts=[False, False, False, True, False])

    # Compute trimap
    logger.info('Compute weight matrix for each image')
    os.makedirs(config.img_trimap_dir, exist_ok=True)
    for name in names:
        compute_trimap_from_alpha(name, config.img_alpha_dir,
                                  config.img_trimap_dir)


if __name__ == '__main__':
    main()