#!/usr/bin/env python
# -*- coding: utf-8 -*-

# -----------------------------------------------------------------------------
#
# Prepare for Portrait Matting
#
#   Create alpha weight image
#
# -----------------------------------------------------------------------------

import argparse
import os
import cv2
import numpy as np
import scipy.sparse

# modules
import log_initializer
import config
from datasets import get_valid_names

# logging
# Set up logging through the project-local ``log_initializer`` module
# (``set_fmt()`` first, then the root level to INFO) and create the
# module-level logger used by every function in this file.
from logging import getLogger, INFO
log_initializer.set_fmt()
log_initializer.set_root_level(INFO)
logger = getLogger(__name__)


def create_pseudo_alpha(name, src_dir, dst_dir):
    """Create a pseudo alpha image by Gaussian-blurring a mask image.

    Reads ``src_dir/name``, blurs it with a 51x51 Gaussian kernel and writes
    the result to ``dst_dir/name``. Nothing is written when the source file
    is missing or cannot be decoded, or when the destination already exists.
    Failures are logged rather than raised so that a batch run can continue
    with the remaining images.

    Args:
        name: File name of the mask image (used in both directories).
        src_dir: Directory that holds the mask images.
        dst_dir: Directory that receives the pseudo alpha image.
    """
    src_path = os.path.join(src_dir, name)
    if not os.path.exists(src_path):
        logger.error('"%s" dose not exist', src_path)
        return

    dst_path = os.path.join(dst_dir, name)
    if os.path.exists(dst_path):
        logger.info('"%s" exists, Skip...', dst_path)
        return

    logger.info('Pseudo alpha from "%s"', src_path)
    mask = cv2.imread(src_path)
    if mask is None:
        # cv2.imread reports an undecodable file by returning None instead
        # of raising; blurring None would abort the whole batch.
        logger.error('"%s" could not be read as an image', src_path)
        return
    alpha = cv2.GaussianBlur(mask, (51, 51), 0)

    # Save (cv2.imwrite returns False when the image could not be written)
    if not cv2.imwrite(dst_path, alpha):
        logger.error('Failed to write "%s"', dst_path)


class AlphaWeightLut(object):
    """Look-up table from 8-bit alpha values to information-content weights.

    For every intensity ``v`` in 0..255 the table stores ``-log(p(v))``, where
    ``p(v)`` is the relative frequency of ``v`` over a random sample of the
    alpha images. Rare alpha values therefore receive large weights.
    """

    # Number of distinct values of an 8-bit alpha image
    _N_LEVELS = 256

    def __init__(self, names, alpha_dir, n_data_use=300):
        """Estimate the alpha distribution from randomly sampled images.

        Args:
            names: Sequence of alpha image file names.
            alpha_dir: Directory that holds the alpha images.
            n_data_use: Maximum number of images sampled for the statistics.

        Raises:
            ValueError: If not a single image could be loaded, so that no
                distribution can be estimated.
        """
        sum_distrib = np.zeros(self._N_LEVELS, dtype=np.int64)
        n_loaded = 0

        n_data = len(names)
        for idx in np.random.permutation(n_data)[:n_data_use]:
            # Load alpha image
            path = os.path.join(alpha_dir, names[idx])
            if not os.path.exists(path):
                logger.error('"%s" dose not exist', path)
                continue
            alpha = cv2.imread(path, 0)
            if alpha is None:
                # cv2.imread returns None for files it cannot decode
                logger.error('"%s" could not be read as an image', path)
                continue

            # Sum up the per-image histograms
            sum_distrib += self._intensity_histogram(alpha)
            n_loaded += 1

        if n_loaded == 0:
            raise ValueError(
                'No alpha image could be loaded from "%s"' % alpha_dir)

        # Convert to information content
        self.distrib_lut = self._counts_to_lut(sum_distrib)

    @staticmethod
    def _intensity_histogram(alpha):
        """Count how often each intensity 0..255 occurs in a uint8 image.

        ``np.histogram(alpha, bins=256)`` without an explicit range spreads
        its bins between the image's own minimum and maximum, so bin ``i``
        would not correspond to intensity ``i``. ``np.bincount`` keeps the
        index equal to the intensity, which is what ``lookup`` relies on.
        """
        return np.bincount(alpha.ravel(), minlength=AlphaWeightLut._N_LEVELS)

    @staticmethod
    def _counts_to_lut(counts):
        """Turn accumulated intensity counts into ``-log(probability)``.

        Intensities that never occurred in the sample are treated as if they
        had been seen once; otherwise they would map to ``-log(0) = inf`` and
        produce infinite weights for images outside the sample. Entries of
        non-empty bins are unaffected by this clamp.

        Raises:
            ValueError: If ``counts`` sums to zero.
        """
        total = counts.sum()
        if total == 0:
            raise ValueError('Cannot build a weight table from empty counts')
        observed = np.maximum(counts, 1)
        return -np.log(observed / total)

    def lookup(self, alpha):
        """Return the weight of every pixel of a ``uint8`` alpha image.

        The result has the same shape as ``alpha``.
        """
        assert alpha.dtype == np.uint8
        return self.distrib_lut[alpha]


def compute_weights(name, src_dir, dst_dir, weight_lut):
    """Compute and save the per-pixel weight map of one alpha image.

    Loads ``src_dir/name`` as a single-channel image, maps every pixel
    through ``weight_lut`` and stores the result as ``float32`` under the key
    ``weight`` in ``dst_dir/<name>.npz``. Nothing is written when the source
    file is missing or cannot be decoded, or when the destination already
    exists; these cases are logged rather than raised.

    Args:
        name: File name of the alpha image.
        src_dir: Directory that holds the alpha images.
        dst_dir: Directory that receives the ``.npz`` weight files.
        weight_lut: Object whose ``lookup(alpha)`` returns the weight array
            for a ``uint8`` image (e.g. ``AlphaWeightLut``).
    """
    src_path = os.path.join(src_dir, name)
    if not os.path.exists(src_path):
        logger.error('"%s" dose not exist', src_path)
        return

    dst_path = os.path.join(dst_dir, name + '.npz')
    if os.path.exists(dst_path):
        logger.info('"%s" exists, Skip...', dst_path)
        return

    logger.info('Alpha weight for "%s"', src_path)
    alpha = cv2.imread(src_path, 0)
    if alpha is None:
        # cv2.imread returns None for files it cannot decode; without this
        # check the attribute access below would abort the whole batch.
        logger.error('"%s" could not be read as an image', src_path)
        return
    assert alpha.ndim == 2
    weight = weight_lut.lookup(alpha)

    # Cast for saving storage
    weight = weight.astype(np.float32)

    # Save
    np.savez(dst_path, weight=weight)


def main():
    """Command-line entry point that prepares the alpha weight files.

    Loads the configuration, optionally generates pseudo alpha images from
    the masks (``--pseudo_alpha``), builds the weight look-up table from the
    alpha images and finally writes one weight file per valid image name.
    """
    # Command-line options
    arg_parser = argparse.ArgumentParser(
        description='Dataset Preparing Script')
    arg_parser.add_argument(
        '--config', '-c', default='config.json',
        help='Load config from given json file')
    arg_parser.add_argument(
        '--pseudo_alpha', action='store_true',
        help='Dummy alpha generation')
    options = arg_parser.parse_args()

    # Read the configuration file
    config.load(options.config)

    if options.pseudo_alpha:
        logger.info('Compute pseudo alpha images')
        # Names that are valid in the 6 channel segmentation stage
        seg_dirs = (config.img_crop_dir, config.img_mask_dir,
                    config.img_mean_mask_dir, config.img_mean_grid_dir)
        seg_names = get_valid_names(*seg_dirs,
                                    rm_exts=[False, False, False, True])
        # One pseudo alpha image per mask
        os.makedirs(config.img_alpha_dir, exist_ok=True)
        for seg_name in seg_names:
            create_pseudo_alpha(seg_name, config.img_mask_dir,
                                config.img_alpha_dir)

    # Names that are valid for alpha matting (alpha directory included)
    matting_dirs = (config.img_crop_dir, config.img_mask_dir,
                    config.img_mean_mask_dir, config.img_mean_grid_dir,
                    config.img_alpha_dir)
    matting_names = get_valid_names(
        *matting_dirs, rm_exts=[False, False, False, True, False])

    # Build the weight look-up table once, before the per-image pass
    logger.info('Compute look up table for weights')
    lut = AlphaWeightLut(matting_names, config.img_alpha_dir)

    # Write the weight matrix of every image
    logger.info('Compute weight matrix for each image')
    os.makedirs(config.img_alpha_weight_dir, exist_ok=True)
    for matting_name in matting_names:
        compute_weights(matting_name, config.img_alpha_dir,
                        config.img_alpha_weight_dir, lut)


if __name__ == '__main__':
    main()