#!/usr/bin/env python
# -*- coding: utf-8 -*-
# -----------------------------------------------------------------------------
#
#   Prepare for Portrait FCN+.
#
#   Create mean mask and warp it.
#
# -----------------------------------------------------------------------------
import argparse
import os

import cv2
import numpy as np

# modules
import log_initializer
import config
from datasets import PortraitSegDataset, split_dataset, get_valid_names
from face_mask import FaceMasker

# logging
from logging import getLogger, INFO
log_initializer.set_fmt()
log_initializer.set_root_level(INFO)
logger = getLogger(__name__)


def align_mask(name, src_dir, dst_mask_dir, dst_grid_dir, face_masker):
    """Align the mean mask to the face in one image and save the results.

    Reads ``src_dir/name``, passes it to ``face_masker.align`` and, on
    success, writes the aligned mask to ``dst_mask_dir/name`` and the grids
    to ``dst_grid_dir/name + '.npz'`` (arrays ``grid_x`` and ``grid_y``,
    cast to float32).

    If both output files already exist the image is skipped, so re-running
    the script only processes images whose outputs are missing. Missing or
    unreadable inputs, images for which ``align`` returns ``None``, and
    failed mask writes are logged and skipped; nothing is raised.

    Args:
        name: Image file name, relative to ``src_dir``.
        src_dir: Directory containing the source images.
        dst_mask_dir: Output directory for the aligned masks.
        dst_grid_dir: Output directory for the ``.npz`` grid files.
        face_masker: Object whose ``align(img)`` returns
            ``(mask, grid_x, grid_y)``, or ``None`` when no face is found.
    """
    src_path = os.path.join(src_dir, name)
    if not os.path.exists(src_path):
        logger.error('"%s" does not exist', src_path)
        return

    dst_mask_path = os.path.join(dst_mask_dir, name)
    dst_grid_path = os.path.join(dst_grid_dir, name + '.npz')
    if os.path.exists(dst_mask_path) and os.path.exists(dst_grid_path):
        logger.info('"%s" exists, Skip...', dst_mask_path)
        return

    logger.info('Align mean mask for "%s"', src_path)
    img = cv2.imread(src_path)
    # cv2.imread signals an unreadable/corrupt file by returning None rather
    # than raising; handing None to the aligner would fail obscurely later.
    if img is None:
        logger.error('Failed to read "%s"', src_path)
        return

    ret_align = face_masker.align(img)
    if ret_align is None:
        logger.debug('Failed to detect a face in "%s"', src_path)
        return
    mask, grid_x, grid_y = ret_align

    # Cast to float32 to reduce the storage size of the saved grids
    grid_x = grid_x.astype(np.float32)
    grid_y = grid_y.astype(np.float32)

    # Save. cv2.imwrite may report failure through a False return value
    # instead of an exception, so check it explicitly.
    if not cv2.imwrite(dst_mask_path, mask):
        logger.error('Failed to write "%s"', dst_mask_path)
        return
    np.savez(dst_grid_path, grid_x=grid_x, grid_y=grid_y)


def main():
    """Generate aligned mean masks and warp grids for every valid image.

    Command-line options:
        --config / -c: JSON config file to load (default ``config.json``).

    Outputs are written to ``config.img_mean_mask_dir`` and
    ``config.img_mean_grid_dir``; both directories are created if needed.
    """
    # Argument
    parser = argparse.ArgumentParser(description='Dataset Preparing Script')
    parser.add_argument('--config', '-c', default='config.json',
                        help='Load config from given json file')
    args = parser.parse_args()

    # Load config
    config.load(args.config)

    # Setup segmentation dataset
    dataset = PortraitSegDataset(config.img_crop_dir, config.img_mask_dir)
    # Split into train and test; only the training split is used below
    train_raw, _ = split_dataset(dataset)

    # Setup mean mask (FaceMasker is given the training split only)
    face_masker = FaceMasker(config.face_predictor_filepath,
                             config.mean_mask_filepath, train_raw)

    # Get valid names in 3 channel segmentation stage
    names = get_valid_names(config.img_crop_dir, config.img_mask_dir)

    # Start alignment
    logger.info('Generate aligned mask and grids')
    os.makedirs(config.img_mean_mask_dir, exist_ok=True)
    os.makedirs(config.img_mean_grid_dir, exist_ok=True)
    for name in names:
        align_mask(name, config.img_crop_dir, config.img_mean_mask_dir,
                   config.img_mean_grid_dir, face_masker)


if __name__ == '__main__':
    main()