# imports
import json
import time
import pickle
import scipy.misc
import skimage.io
import cv2
import caffe
import numpy as np
import os.path as osp
from random import shuffle
from PIL import Image
import random
import copy
from human_parse import load_human_annotation
import multiprocessing


class ImageSegDataLayer(caffe.Layer):
    """
    Synchronous Caffe Python data layer for multi-scale segmentation training.

    Every forward pass fills six tops with ``batch_size`` samples pulled one at
    a time from a BatchLoader: ``data`` (image), ``label_1s``, ``label_2s``,
    ``label_3s`` (per-scale label maps), ``label`` (full label map) and
    ``attention`` (per-pixel scale index).
    """

    def setup(self, bottom, top):
        self.top_names = ['data', 'label_1s', 'label_2s', 'label_3s', 'label', 'attention']

        # === Read input parameters ===
        # param_str is expected to contain a Python dict literal.
        # NOTE(review): eval() executes arbitrary code from the prototxt; only
        # use with trusted network definitions (ast.literal_eval would be safer
        # if every parameter is a plain literal).
        params = eval(self.param_str)
        SimpleTransformer.check_params(params)

        # store input as class variables
        self.batch_size = params['batch_size']
        self.input_shape = params['crop_size']

        # Create a batch loader to load the images.
        self.batch_loader = BatchLoader(params)

        # === reshape tops ===
        # The crop size is fixed, so the tops can be shaped once here instead
        # of in reshape().
        crop_h, crop_w = self.input_shape[0], self.input_shape[1]
        top[0].reshape(self.batch_size, 3, crop_h, crop_w)
        # The five label tops are all single-channel maps.
        for top_index in range(1, 6):
            top[top_index].reshape(self.batch_size, 1, crop_h, crop_w)

        print_info("ImageSegDataLayer", params)

    def forward(self, bottom, top):
        """
        Load data.
        """
        for itt in range(self.batch_size):
            # Use the batch loader to load (and preprocess) the next sample.
            im, label_1s, label_2s, label_3s, label, label_at = \
                self.batch_loader.perpare_next_data()

            # Add directly to the caffe data layer
            top[0].data[itt, ...] = im
            top[1].data[itt, ...] = label_1s
            top[2].data[itt, ...] = label_2s
            top[3].data[itt, ...] = label_3s
            top[4].data[itt, ...] = label
            top[5].data[itt, ...] = label_at

    def reshape(self, bottom, top):
        """
        There is no need to reshape the data, since the input is of fixed size
        (rows and columns)
        """
        pass

    def backward(self, top, propagate_down, bottom):
        """
        These layers does not back propagate
        """
        pass


class BatchLoader(object):
    """
    Loads the (image, label, annotation) entries listed in ``params['source']``
    and preprocesses them with a SimpleTransformer.

    Samples can be produced synchronously with perpare_next_data(), or by a
    background process started with start_batch() and consumed with
    load_next_image().
    """

    def __init__(self, params):
        self.batch_size = params['batch_size']
        self.root_folder = params['root_folder']
        self.source = params['source']
        self.voc_dir = params['voc_dir']
        # Maximum number of samples the background producer may buffer.
        self.queue_size = params.get('queue_size', 2 * self.batch_size)
        # Each line of the list file is split on whitespace; perpare_next_data
        # expects exactly "<image path> <label path>", relative to root_folder.
        with open(self.source) as list_file:
            self.indexlist = [line.strip().split() for line in list_file]
        self._cur = 0  # index of the next entry in indexlist
        self.q = None  # created lazily by start_batch()

        # this class does some simple data-manipulations
        self.transformer = SimpleTransformer(params)

        print("BatchLoader initialized with {} images".format(
            len(self.indexlist)))

    def load_next_image(self):
        """
        Return the next preprocessed sample produced by the background process.

        Blocks until a sample is available.

        Raises:
            RuntimeError: if start_batch() has not been called.
        """
        if self.q is None:
            raise RuntimeError("start_batch() must be called before load_next_image()")
        return self.q.get()

    def overlap(self, boxes):
        """
        Return True if some pair of boxes overlaps by at least 10% of the
        smaller box's area.

        Boxes are unpacked as (x1, y1, x2, y2).
        """
        def is_separated(box1, box2):
            x1, y1, x2, y2 = box1
            m1, n1, m2, n2 = box2
            # Floor the areas at 0.1 to avoid dividing by zero for degenerate boxes.
            area1 = max(0.1, (x2 - x1) * (y2 - y1))
            area2 = max(0.1, (m2 - m1) * (n2 - n1))
            # Clamp each extent separately: if both were negative, their
            # product would be positive and disjoint boxes would look overlapping.
            inter_w = max(0, min(x2, m2) - max(x1, m1))
            inter_h = max(0, min(y2, n2) - max(y1, n1))
            return float(inter_w * inter_h) / min(area1, area2) < 0.1

        for i, box_a in enumerate(boxes):
            for box_b in boxes[i + 1:]:
                if not is_separated(box_a, box_b):
                    return True
        return False

    def is_crowed(self, boxes):
        """Return True when there are at least two boxes and some pair overlaps (see overlap())."""
        return len(boxes) >= 2 and self.overlap(boxes)

    def perpare_next_data(self):
        """
        Read, annotate and preprocess the next entry of the list.

        When the end of the list is reached, it is reshuffled and reading
        restarts from the beginning.

        Returns:
            The 6-tuple produced by SimpleTransformer.preprocess().

        Raises:
            IOError: if the image or label file cannot be read.
        """
        # Did we finish an epoch?
        if self._cur == len(self.indexlist):
            self._cur = 0
            shuffle(self.indexlist)

        image_file_path, label_file_path = self.indexlist[self._cur]
        image_path = osp.join(self.root_folder, image_file_path)
        label_path = osp.join(self.root_folder, label_file_path)
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread returns None (instead of raising) for missing/unreadable files.
        if image is None:
            raise IOError("Could not read image file: {}".format(image_path))
        if label is None:
            raise IOError("Could not read label file: {}".format(label_path))

        # Annotations are looked up by the label file's base name without extension.
        img_id = osp.splitext(osp.basename(label_file_path))[0]
        annotation = load_human_annotation(img_id, self.voc_dir)

        self._cur += 1
        return self.transformer.preprocess(image, label, annotation)

    def start_batch(self):
        """
        Start a daemon process that keeps ``self.q`` filled with samples.

        The queue is bounded by ``queue_size``, so the producer blocks instead
        of buffering without limit. The worker process runs on its own copy of
        this loader's state (e.g. ``_cur``).
        """
        if self.q is None:
            self.q = multiprocessing.Queue(maxsize=self.queue_size)
        worker = multiprocessing.Process(target=self.data_generator_task)
        worker.daemon = True
        worker.start()

    def data_generator_task(self):
        """Producer loop run in the worker process; blocks while the queue is full."""
        while True:
            self.q.put(self.perpare_next_data())


class SimpleTransformer(object):
    """
    SimpleTransformer is a simple class for preprocessing and deprocessing
    images for caffe.
    """

    # Box-area thresholds (pixels, measured before rescaling) separating the
    # three scales handled by generate_scale_label().
    _SMALL_MAX_AREA = 112 * 112    # 12544: area < this  -> scale 1
    _MEDIUM_MAX_AREA = 224 * 224   # 50176: area <= this -> scale 2, else scale 3
    # Temporary marker for scale-map pixels not yet written by any instance.
    _UNASSIGNED = 254

    def __init__(self, params):
        SimpleTransformer.check_params(params)
        self.mean = params['mean']
        self.is_mirror = params['mirror']
        self.crop_h, self.crop_w = params['crop_size']
        self.scale = params['scale']
        self.phase = params['phase']
        self.ignore_label = params['ignore_label']

    def set_mean(self, mean):
        """
        Set the mean to subtract for centering the data.
        """
        self.mean = mean

    def set_scale(self, scale):
        """
        Set the data scaling.
        """
        self.scale = scale

    def generate_scale_label(self, image, label, annotation):
        """
        Build per-scale label maps from instance annotations, then rescale randomly.

        Each instance is routed by its box area to one scale: 1 (area < 112*112),
        2 (area <= 224*224) or 3 (larger). Inside the box, the instance's crop
        is written into that scale's map wherever the map is still unassigned
        or the instance is present. The instance's pixels are set to 255 in the
        other two scale maps, and ``label_at`` receives the scale index there
        (0 elsewhere). Instances are processed from largest to smallest box, so
        smaller ones can overwrite larger ones. Pixels never written end up 255.

        All outputs are finally resized by a random factor in {0.5, 0.6, ..., 1.5}.

        Args:
            image: image array (as read by cv2; presumably HxWx3).
            label: label map; assumed to share height/width with ``image``.
            annotation: dict with ``'boxes'`` (x1, y1, x2, y2 per instance) and
                ``'instances'`` (one mask per box, indexed like ``label``;
                value 1 on the instance -- assumed from usage, TODO confirm).

        Returns:
            (image, label_1s, label_2s, label_3s, label, label_at)
        """
        f_scale = 0.5 + random.randint(0, 10) / 10.0

        unassigned = self._UNASSIGNED
        label_1s = np.full(label.shape, unassigned, dtype=np.uint8)
        label_2s = np.full(label.shape, unassigned, dtype=np.uint8)
        label_3s = np.full(label.shape, unassigned, dtype=np.uint8)
        label_at = np.zeros_like(label)

        def box_area(item):
            x1, y1, x2, y2 = item[0]
            return (x2 - x1) * (y2 - y1)

        annos = sorted(zip(annotation['boxes'], annotation['instances']),
                       key=box_area, reverse=True)
        for box, ins in annos:
            x1, y1, x2, y2 = np.array(box)
            area = (x2 - x1) * (y2 - y1)
            if area < self._SMALL_MAX_AREA:
                target, others, scale_id = label_1s, (label_2s, label_3s), 1
            elif area <= self._MEDIUM_MAX_AREA:
                target, others, scale_id = label_2s, (label_1s, label_3s), 2
            else:
                target, others, scale_id = label_3s, (label_1s, label_2s), 3

            ins_crop = ins[y1:y2, x1:x2]
            target_crop = target[y1:y2, x1:x2]  # view: writes go into target
            write_mask = (target_crop == unassigned) | (ins_crop == 1)
            target_crop[write_mask] = ins_crop[write_mask]

            instance_mask = (ins == 1)
            for other in others:
                other[instance_mask] = 255
            label_at[instance_mask] = scale_id

        for scale_map in (label_1s, label_2s, label_3s):
            scale_map[scale_map == unassigned] = 255

        def rescale(array, interpolation):
            return cv2.resize(array, None, fx=f_scale, fy=f_scale,
                              interpolation=interpolation)

        image = rescale(image, cv2.INTER_LINEAR)
        # Nearest-neighbour keeps label values from being blended into new ids.
        label, label_1s, label_2s, label_3s, label_at = [
            rescale(m, cv2.INTER_NEAREST)
            for m in (label, label_1s, label_2s, label_3s, label_at)
        ]
        return image, label_1s, label_2s, label_3s, label, label_at

    def show(self, image, label_1s, label_2s, label_3s, label, label_at):
        """Debug helper: display the image and every label map in a 2x3 grid."""
        import matplotlib.pyplot as plt
        from matplotlib import colors

        # make a color map of fixed colors
        cmap = colors.ListedColormap([(0, 0, 0), (0.5, 0, 0), (0, 0.5, 0), (0.5, 0.5, 0),
                                      (0, 0, 0.5), (0.5, 0, 0.5), (0, 0.5, 0.5)])
        bounds = [0, 1, 2, 3, 4, 5, 6, 7]
        norm = colors.BoundaryNorm(bounds, cmap.N)

        _, axes = plt.subplots(2, 3)
        (ax1, ax2, ax3), (ax4, ax5, ax6) = axes
        # NOTE: images read with cv2 are BGR, so colours may look channel-swapped.
        ax1.set_title('image')
        ax1.imshow(image)
        panels = ((ax2, 'label', label),
                  (ax3, 'label 1s', label_1s),
                  (ax4, 'label 2s', label_2s),
                  (ax5, 'label 3s', label_3s),
                  (ax6, 'label at', label_at))
        for ax, title, data in panels:
            ax.set_title(title)
            ax.imshow(data, cmap=cmap, norm=norm)
        plt.show()

    def preprocess(self, image, label, annos):
        """
        Produce one sample from a raw image, label map and annotation.

        Steps:
        1. Multi-scale label generation and random rescaling
           (generate_scale_label).
        2. Mean subtraction and scaling of the image.
        3. Bottom/right padding up to the crop size.
        4. Cropping: random in the 'Train' phase, centred otherwise.
        5. HWC -> CHW transpose of the image.
        6. Optional random horizontal mirroring.

        Returns:
            (image, label_1s, label_2s, label_3s, label, label_at), all float32;
            image is (channels, crop_h, crop_w), label maps are (crop_h, crop_w).
        """
        image, label_1s, label_2s, label_3s, label, label_at = \
            self.generate_scale_label(image, label, annos)

        image = np.asarray(image, np.float32)
        image -= self.mean
        image *= self.scale

        # Same order as the tuple returned at the end (minus the image).
        labels = [label_1s, label_2s, label_3s, label, label_at]

        img_h, img_w = label_1s.shape
        pad_h = max(self.crop_h - img_h, 0)
        pad_w = max(self.crop_w - img_w, 0)
        if pad_h > 0 or pad_w > 0:
            # Padded label pixels get ignore_label (presumably skipped by the loss).
            image = cv2.copyMakeBorder(image, 0, pad_h, 0, pad_w,
                                       cv2.BORDER_CONSTANT, value=(0.0, 0.0, 0.0))
            labels = [cv2.copyMakeBorder(m, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT,
                                         value=(self.ignore_label,))
                      for m in labels]

        img_h, img_w = labels[0].shape
        if self.phase == 'Train':
            h_off = random.randint(0, img_h - self.crop_h)
            w_off = random.randint(0, img_w - self.crop_w)
        else:
            # Floor division keeps the offsets integral (plain / gives floats on Python 3).
            h_off = (img_h - self.crop_h) // 2
            w_off = (img_w - self.crop_w) // 2

        rows = slice(h_off, h_off + self.crop_h)
        cols = slice(w_off, w_off + self.crop_w)
        image = np.asarray(image[rows, cols].copy(), np.float32)
        labels = [np.asarray(m[rows, cols].copy(), np.float32) for m in labels]

        image = image.transpose((2, 0, 1))  # HWC -> CHW
        if self.is_mirror:
            flip = np.random.choice(2) * 2 - 1  # -1 (mirror) or +1 (keep)
            image = image[:, :, ::flip]
            labels = [m[:, ::flip] for m in labels]

        label_1s, label_2s, label_3s, label, label_at = labels
        return image, label_1s, label_2s, label_3s, label, label_at

    @classmethod
    def check_params(cls, params):
        """Fill in (in place) defaults for any optional parameter missing from ``params``."""
        defaults = {
            'crop_size': (505, 505),
            'mean': [128, 128, 128],
            'scale': 1.0,
            'mirror': False,
            'phase': 'Train',
            'ignore_label': 255,
        }
        for key, value in defaults.items():
            params.setdefault(key, value)


def print_info(name, params):
    """
    Output some info regarding the class
    """
    print("{} initialized for split: {}, with bs: {}, im_shape: {}.".format(
        name,
        params['source'],
        params['batch_size'],
        params['crop_size']))


if __name__ == '__main__':
    # Quick manual smoke test of SimpleTransformer on one image/label pair.
    params = {'batch_size': 2,
              'mean': (104.008, 116.669, 122.675),
              'root_folder': 'D:/v-zihuan/segmentation_with_scale/experiment/voc_part/data/',
              'source': 'D:/v-zihuan/segmentation_with_scale/experiment/voc_part/list/train_3s.txt',
              'mirror': True,
              'crop_size': (505, 505)}
    t = SimpleTransformer(params)

    # Read with cv2, as BatchLoader does, so the arrays have the expected type/shape.
    image = cv2.imread(r'D:/v-zihuan/segmentation_with_scale/experiment/voc_part/data/images/2008_000003.jpg',
                       cv2.IMREAD_COLOR)
    label = cv2.imread(r'D:/v-zihuan/segmentation_with_scale/experiment/voc_part/data/part_mask_scale_3/2008_000003.png',
                       cv2.IMREAD_GRAYSCALE)
    # An empty annotation exercises rescaling, padding and cropping without
    # needing load_human_annotation / a voc_dir.
    t.preprocess(image, label, {'boxes': [], 'instances': []})