#!/usr/bin/env python # -*- coding: utf-8 -*- # ----------------------------------------------------------------------------- # # Prepare for Portrait FCN. # # Paper's dataset will be downloaded and converted for training. # # ----------------------------------------------------------------------------- import argparse import os import cv2 import glob import urllib.request import PIL.Image import scipy.io # modules import log_initializer import config # logging from logging import getLogger, INFO log_initializer.set_fmt() log_initializer.set_root_level(INFO) logger = getLogger(__name__) def load_img_urls(filepath): # Load file logger.info('Load image urls from "%s"', filepath) with open(filepath, 'r') as f: lines = f.read().strip().split('\n') # Parse for each line url_pairs = list() for line in lines: items = line.split() if len(items) != 2: logger.error('Invalid line. (%s)', line) continue if items[1] != 'None': # Register url_pairs.append(items) return url_pairs def download_img(url, img_name, base_dir): img_path = os.path.join(base_dir, img_name) if os.path.exists(img_path): logger.info('"%s" exists. Skip...', img_path) return logger.info('Download to "%s"', img_path) try: urllib.request.urlretrieve(url, img_path) except urllib.error.HTTPError: logger.warin('Failed to download') def load_crop_rects(filepath): # Load file logger.info('Load crop rectangles from "%s"', filepath) with open(filepath, 'r') as f: lines = f.read().strip().split('\n') # Parse for each line rect_pairs = list() for line in lines: items = line.split() if len(items) != 5: logger.error('Invalid line. (%s)', line) continue # Register rect_pairs.append((items[0], items[1:5])) return rect_pairs def crop_img(img_name, src_dir, dst_dir, crop_rect, img_size): src_path = os.path.join(src_dir, img_name) if not os.path.exists(src_path): logger.error('"%s" dose not exist', src_path) return dst_path = os.path.join(dst_dir, img_name) if os.path.exists(dst_path): logger.info('"%s" exists, Skip...', dst_path) return logger.info('Crop "%s" to "%s"', src_path, dst_path) img = cv2.imread(src_path) x0, y0 = int(crop_rect[2]), int(crop_rect[0]) x1, y1 = int(crop_rect[3]), int(crop_rect[1]) img = img[y0:y1, x0:x1, :] img = cv2.resize(img, img_size) cv2.imwrite(dst_path, img) def parse_mask(mask_name, src_dir, img_name, dst_dir): src_path = os.path.join(src_dir, mask_name) if not os.path.exists(src_path): logger.error('"%s" dose not exist', src_path) return dst_path = os.path.join(dst_dir, img_name) if os.path.exists(dst_path): logger.info('"%s" exists, Skip...', dst_path) return logger.info('Parse mask "%s" to "%s"', src_path, dst_path) img = scipy.io.loadmat(src_path)['mask'] img *= 255 # [0:1] -> [0:255] cv2.imwrite(dst_path, img) def main(): # Argument parser = argparse.ArgumentParser(description='Dataset Preparing Script') parser.add_argument('--config', '-c', default='config.json', help='Load config from given json file') args = parser.parse_args() # Load config config.load(args.config) # Load image urls url_pairs = load_img_urls(config.org_imgurl_filepath) # Download os.makedirs(config.img_raw_dir, exist_ok=True) for name, url in url_pairs: download_img(url, name, config.img_raw_dir) # Load crop rectangles rect_pairs = load_crop_rects(config.org_crop_filepath) # Crop img_size = (600, 800) # Decided by mask size os.makedirs(config.img_crop_dir, exist_ok=True) for name, rect in rect_pairs: crop_img(name, config.img_raw_dir, config.img_crop_dir, rect, img_size) # Parse masks os.makedirs(config.img_mask_dir, exist_ok=True) for name, _ in rect_pairs: mask_name = '{}_mask.mat'.format(os.path.splitext(name)[0]) parse_mask(mask_name, config.org_mask_dir, name, config.img_mask_dir) if __name__ == '__main__': main()