""" Copyright 2018 Lambda Labs. All Rights Reserved. Licensed under ========================================================================== """ from __future__ import print_function import os import numpy as np import math import tensorflow as tf from .inputter import Inputter from pycocotools.coco import COCO from source.augmenter.external import vgg_preprocessing JSON_TO_IMAGE = { "train2017": "train2017", "val2017": "val2017", "train2014": "train2014", "val2014": "val2014", "valminusminival2014": "val2014", "minival2014": "val2014", "test2014": "test2014", "test-dev2015": "val2017" } class ObjectDetectionMSCOCOInputter(Inputter): def __init__(self, config, augmenter): super(ObjectDetectionMSCOCOInputter, self).__init__(config, augmenter) self.category_id_to_class_id = None self.class_id_to_category_id = None self.cat_names = None # Has to be more than num_gpu * batch_size_per_gpu # Otherwise no valid batch will be produced self.TRAIN_NUM_SAMPLES = 117266 # train2014 + valminusminival2014 self.EVAL_NUM_SAMPLES = 4952 # val2017 (same as test-dev2015) if self.config.mode == "infer": self.test_samples = self.config.test_samples elif self.config.mode == "export": pass else: self.parse_coco() self.num_samples = self.get_num_samples() def parse_coco(self): samples = [] for name_meta in self.config.dataset_meta: annotation_file = os.path.join( self.config.dataset_dir, "annotations", "instances_" + name_meta + ".json") coco = COCO(annotation_file) cat_ids = coco.getCatIds() self.cat_names = [c["name"] for c in coco.loadCats(cat_ids)] # background has class id of 0 self.category_id_to_class_id = { v: i + 1 for i, v in enumerate(cat_ids)} self.class_id_to_category_id = { v: k for k, v in self.category_id_to_class_id.items()} img_ids = coco.getImgIds() img_ids.sort() # list of dict, each has keys: height,width,id,file_name imgs = coco.loadImgs(img_ids) for img in imgs: img["file_name"] = os.path.join( self.config.dataset_dir, JSON_TO_IMAGE[name_meta], img["file_name"]) if self.config.mode == "train": for img in imgs: self.parse_gt(coco, self.category_id_to_class_id, img) samples.extend(imgs) # Filter out images that has no object. if self.config.mode == "train": samples = list(filter( lambda sample: len( sample['boxes'][sample['is_crowd'] == 0]) > 0, samples)) self.samples = samples def get_num_samples(self): if not hasattr(self, 'num_samples'): if self.config.mode == "infer": self.num_samples = len(self.test_samples) elif self.config.mode == "export": self.num_samples = 1 elif self.config.mode == "eval": self.num_samples = self.EVAL_NUM_SAMPLES elif self.config.mode == "train": self.num_samples = self.TRAIN_NUM_SAMPLES return self.num_samples def get_samples_fn(self): # Args: # Returns: # sample["id"]: int64, image id # sample["file_name"]: , string, path to image # sample["class"]: (...,), int64 # sample["boxes"]: (..., 4), float32 # Read image if self.config.mode == "infer": for file_name in self.test_samples: yield (0, file_name, np.empty([1], dtype=np.int32), np.empty([1, 4])) elif self.config.mode == "eval": for sample in self.samples[0:self.num_samples]: yield(sample["id"], sample["file_name"], np.empty([1], dtype=np.int32), np.empty([1, 4])) else: for sample in self.samples[0:self.num_samples]: # remove crowd objects mask = sample['is_crowd'] == 0 sample["class"] = sample["class"][mask] sample["boxes"] = sample["boxes"][mask, :] sample["is_crowd"] = sample["is_crowd"][mask] yield (sample["id"], sample["file_name"], sample["class"], sample["boxes"]) def parse_gt(self, coco, category_id_to_class_id, img): ann_ids = coco.getAnnIds(imgIds=img["id"], iscrowd=None) objs = coco.loadAnns(ann_ids) # clean-up boxes valid_objs = [] width = img["width"] height = img["height"] for obj in objs: if obj.get("ignore", 0) == 1: continue x1, y1, w, h = obj["bbox"] x1 = float(x1) y1 = float(y1) x2 = float(x1 + w) y2 = float(y1 + h) x1 = max(0, min(float(x1), width - 1)) y1 = max(0, min(float(y1), height - 1)) x2 = max(0, min(float(x2), width - 1)) y2 = max(0, min(float(y2), height - 1)) w = x2 - x1 h = y2 - y1 if obj['area'] > 1 and w > 0 and h > 0 and w * h >= 4: # normalize box to [0, 1] obj['bbox'] = [x1 / float(width), y1 / float(height), x2 / float(width), y2 / float(height)] valid_objs.append(obj) boxes = np.asarray([obj['bbox'] for obj in valid_objs], dtype='float32') # (n, 4) cls = np.asarray([ category_id_to_class_id[obj['category_id']] for obj in valid_objs], dtype='int32') # (n,) is_crowd = np.asarray([obj['iscrowd'] for obj in valid_objs], dtype='int8') img['boxes'] = boxes # nx4 img['class'] = cls # n, always >0 img['is_crowd'] = is_crowd # n, def create_nonreplicated_fn(self): batch_size = (self.config.batch_size_per_gpu * self.config.gpu_count) max_step = (self.get_num_samples() * self.config.epochs // batch_size) tf.constant(max_step, name="max_step") def parse_fn(self, image_id, file_name, classes, boxes): """Parse a single input sample """ image = tf.read_file(file_name) image = tf.image.decode_png(image, channels=3) image = tf.to_float(image) scale = [0, 0] translation = [0, 0] if self.augmenter: is_training = (self.config.mode == "train") image, classes, boxes, scale, translation = self.augmenter.augment( image, classes, boxes, self.config.resolution, is_training=is_training, speed_mode=False) return ([image_id], image, classes, boxes, scale, translation, [file_name]) def input_fn(self, test_samples=[]): if self.config.mode == "export": image = tf.placeholder(tf.float32, shape=(self.config.resolution, self.config.resolution, 3), name="input_image") image = tf.to_float(image) image = self.augmenter.preprocess_for_export(image, self.config.resolution) image = tf.expand_dims(image, 0) return ([None], image, None, None, None, None, [None]) else: batch_size = (self.config.batch_size_per_gpu * self.config.gpu_count) dataset = tf.data.Dataset.from_generator( generator=lambda: self.get_samples_fn(), output_types=(tf.int64, tf.string, tf.int64, tf.float32)) if self.config.mode == "train": dataset = dataset.shuffle(self.get_num_samples()) dataset = dataset.repeat(self.config.epochs) dataset = dataset.map( lambda image_id, file_name, classes, boxes: self.parse_fn( image_id, file_name, classes, boxes), num_parallel_calls=12) dataset = dataset.padded_batch( batch_size, padded_shapes=([None], [None, None, 3], [None], [None, 4], [None], [None], [None])) dataset = dataset.prefetch(2) iterator = dataset.make_one_shot_iterator() return iterator.get_next() def build(config, augmenter): return ObjectDetectionMSCOCOInputter(config, augmenter)