"""Script to train a model to spot cat faces in images.""" from __future__ import print_function, division from torch.autograd import Variable import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim import imgaug as ia from imgaug import augmenters as iaa from scipy import misc import time import cPickle as pickle import numpy as np import cv2 from create_dataset import Example from plotting import History, LossPlotter from common import draw_heatmap from model import Model, Model2 import multiprocessing import threading import math import random random.seed(42) torch.manual_seed(42) torch.cuda.manual_seed(42) torch.backends.cudnn.benchmark = True GPU = 0 # id of the gpu to use BATCH_SIZE_TRAIN = 32 # training batch size BATCH_SIZE_VAL = 32 # validation batch size BATCHES_PER_VAL = 40 # average validation loss over N batches per validation GRID_SIZE = 28 # output heatmap size VAL_EVERY = 250 # validate every N batches SAVE_EVERY = 250 # save model every N batches PLOT_EVERY = 250 # plot loss curve every N batches TRAIN_WINDOW_NAME = "trainwin" # window/file name of training batch debug output VAL_WINDOW_NAME = "valwin" # window/file name of validation batch debug output NB_BATCHES = 30000 # train for N batches SHOW_DEBUG_WINDOWS = False # whether to show example outputs in windows (True) # or write to files instead (False) def main(): # load datsets, create it via create_dataset.py with open("cats-dataset.pkl", "r") as f: examples = pickle.load(f) examples_val = examples[0:1024] examples_train = examples[1024:] # history of loss values gathered during the experiment history = History() history.add_group("loss", ["train", "val"], increasing=False) # object to generate loss plots loss_plotter = LossPlotter( history.get_group_names(), history.get_groups_increasing(), save_to_fp="plot.jpg" ) loss_plotter.start_batch_idx = 100 # load model, loss and stochastic optimizer model = Model2() if GPU >= 0: model.cuda(GPU) criterion = nn.BCELoss() optimizer = optim.Adam(model.parameters()) # initialize image augmentation cascade rarely = lambda aug: iaa.Sometimes(0.1, aug) sometimes = lambda aug: iaa.Sometimes(0.25, aug) often = lambda aug: iaa.Sometimes(0.5, aug) seq = iaa.Sequential([ iaa.Fliplr(0.5), # horizontally flip 50% of all images iaa.Flipud(0.5), # vertically flip 50% of all images rarely(iaa.Superpixels(p_replace=(0, 1.0), n_segments=(20, 200))), # convert images into their superpixel representation often(iaa.Crop(percent=(0, 0.1))), # crop images by 0-10% of their height/width sometimes(iaa.GaussianBlur((0, 3.0))), # blur images with a sigma between 0 and 3.0 sometimes(iaa.Sharpen(alpha=(0, 1.0), lightness=(0.75, 1.5))), # sharpen images sometimes(iaa.Emboss(alpha=(0, 1.0), strength=(0, 2.0))), # emboss images # search either for all edges or for directed edges rarely(iaa.Sometimes(0.5, iaa.EdgeDetect(alpha=(0, 0.7)), iaa.DirectedEdgeDetect(alpha=(0, 0.7), direction=(0.0, 1.0)), )), often(iaa.AdditiveGaussianNoise(loc=0, scale=(0.0, 0.2), per_channel=0.5)), # add gaussian noise to images often(iaa.Dropout((0.0, 0.1), per_channel=0.5)), # randomly remove up to 10% of the pixels rarely(iaa.Invert(0.25, per_channel=True)), # invert color channels often(iaa.Add((-10, 10), per_channel=0.5)), # change brightness of images (by -10 to 10 of original value) often(iaa.Multiply((0.5, 1.5), per_channel=0.25)), # change brightness of images (50-150% of original value) often(iaa.ContrastNormalization((0.5, 2.0), per_channel=0.5)), # improve or worsen the contrast 
            sometimes(iaa.Grayscale(alpha=(0.0, 1.0))),
            often(iaa.Affine(
                scale={"x": (0.6, 1.4), "y": (0.6, 1.4)},  # scale images to 60-140% of their size, individually per axis
                translate_percent={"x": (-0.3, 0.3), "y": (-0.3, 0.3)},  # translate by -30 to +30 percent (per axis)
                rotate=(-45, 45),  # rotate by -45 to +45 degrees
                shear=(-16, 16),  # shear by -16 to +16 degrees
                order=[0, 1],  # use nearest neighbour or bilinear interpolation
                cval=(0, 255),  # if mode is constant, use a cval between 0 and 255
                mode=["constant", "edge"]  # fill new pixels with a constant value or the nearest edge value
            )),
            sometimes(iaa.ElasticTransformation(alpha=(0.5, 3.5), sigma=0.25))  # apply elastic transformations with random strengths
        ],
        random_order=True  # do all of the above in random order
    )

    # method to generate batches
    def load_training_batch():
        examples_batch = random.sample(examples_train, BATCH_SIZE_TRAIN)
        images = [ex.image for ex in examples_batch]
        bb_coords = [ia.KeypointsOnImage(ex.get_bb_coords_keypoints(), shape=image.shape)
                     for ex, image in zip(examples_batch, images)]
        return Batch(identifiers=None, images=images, keypoints=bb_coords)

    img_loader = ImageLoader(load_training_batch, nb_workers=1)
    bg_augmenter = BackgroundAugmenter(seq, img_loader.queue, nb_workers=8)

    # training loop
    for batch_idx in range(NB_BATCHES):
        # load training batch
        time_cbatch = time.time()
        batch = bg_augmenter.get_batch()
        inputs, outputs_gt = images_coords_to_batch(batch.images_aug, batch.keypoints_aug)
        time_cbatch = time.time() - time_cbatch

        # train on batch
        time_fwbw = time.time()
        loss = run_batch(inputs, outputs_gt, model, criterion, optimizer, train=True)
        time_fwbw = time.time() - time_fwbw

        print("[T] %06d | loss=%.4f | %.4fs cbatch | %.4fs fwbw" % (batch_idx, loss, time_cbatch, time_fwbw))
        history.add_value("loss", "train", batch_idx, loss)

        # every N batches, show the true and generated outputs for the
        # training batch
        if batch_idx % 50 == 0:
            update_window(TRAIN_WINDOW_NAME, inputs[0:8], outputs_gt[0:8], model)

        # Code to generate a video of the training progress
        #time_vid_start = time.time()
        #grid = generate_video_image(batch_idx, examples_val[0:66], model)
        #misc.imsave("training-video/%05d.jpg" % (batch_idx,), grid)

        # every N batches, validate
        if (batch_idx+1) % VAL_EVERY == 0:
            # the validation computes an average over N randomly picked batches
            time_cbatch_total = 0
            time_fwbw_total = 0
            loss_total = 0
            for i in range(BATCHES_PER_VAL):
                # load batch
                time_cbatch = time.time()
                examples_batch = random.sample(examples_val, BATCH_SIZE_VAL)
                inputs, outputs_gt = examples_to_batch(examples_batch, iaa.Noop())
                time_cbatch = time.time() - time_cbatch
                time_cbatch_total += time_cbatch

                # validate on batch (forward + loss)
                time_fwbw = time.time()
                loss = run_batch(inputs, outputs_gt, model, criterion, optimizer, train=False)
                time_fwbw = time.time() - time_fwbw
                time_fwbw_total += time_fwbw

                loss_total += loss
            loss_total_avg = loss_total / BATCHES_PER_VAL

            # check if the average loss over the val batches was the best value so far
            if len(history.line_groups["loss"].lines["val"].ys) == 0:
                is_new_best = True
            else:
                minval = np.min(history.line_groups["loss"].lines["val"].ys)
                is_new_best = (loss_total_avg < minval)

            print("[V] %06d | loss=%.4f | %.4fs cbatch | %.4fs fwbw" % (batch_idx, loss_total_avg, time_cbatch_total, time_fwbw_total))
            history.add_value("loss", "val", batch_idx, loss_total_avg)

            # show true and generated outputs for the first 8 validation
            # examples
            inputs, outputs_gt = examples_to_batch(examples_val[0:8], iaa.Noop())
            update_window(VAL_WINDOW_NAME, inputs, outputs_gt, model)

            # save a checkpoint if the model was the best one so far
            if is_new_best:
                torch.save({
                    "batch_idx": batch_idx,
                    "history": history.to_string(),
                    "state_dict": model.state_dict()
                }, "model.best.tar")

        # every N batches, save a checkpoint
        if (batch_idx+1) % SAVE_EVERY == 0:
            torch.save({
                "batch_idx": batch_idx,
                "history": history.to_string(),
                "state_dict": model.state_dict()
            }, "model.tar")

        # every N batches, plot the current loss curve
        if (batch_idx+1) % PLOT_EVERY == 0:
            loss_plotter.plot(history)

def generate_video_image(batch_idx, examples, model):
    """Generate frames for a video of the training progress.
    Each frame contains N examples shown in a grid. Each example shows
    the input image and the main heatmap predicted by the model."""
    start_time = time.time()
    #print("A", time.time() - start_time)
    model.eval()

    # fw through network
    inputs, outputs_gt = examples_to_batch(examples, iaa.Noop())
    inputs_torch = torch.from_numpy(inputs)
    inputs_torch = Variable(inputs_torch, volatile=True)
    if GPU >= 0:
        inputs_torch = inputs_torch.cuda(GPU)
    outputs_pred_torch = model(inputs_torch)
    #print("B", time.time() - start_time)
    outputs_pred = outputs_pred_torch.cpu().data.numpy()

    inputs = (inputs * 255).astype(np.uint8).transpose(0, 2, 3, 1)
    #print("C", time.time() - start_time)

    heatmaps = []
    for i in range(inputs.shape[0]):
        hm_drawn = draw_heatmap(inputs[i], np.squeeze(outputs_pred[i][0]), alpha=0.5)
        heatmaps.append(hm_drawn)
    #print("D", time.time() - start_time)

    grid = ia.draw_grid(heatmaps, cols=11, rows=6).astype(np.uint8)
    #grid_rs = misc.imresize(grid, (720-32, 1280-32))

    # pad by 42 for the text and to get the image to 720p aspect ratio
    grid_pad = np.pad(grid, ((0, 42), (0, 0), (0, 0)), mode="constant")
    grid_pad_text = ia.draw_text(
        grid_pad,
        x=grid_pad.shape[1]-220,
        y=grid_pad.shape[0]-35,
        text="Batch %05d" % (batch_idx,),
        color=[255, 255, 255]
    )
    #print("E", time.time() - start_time)

    return grid_pad_text

def update_window(win, inputs, outputs_gt, model):
    """Show true and generated outputs/heatmaps for example images."""
    model.eval()

    # prepare inputs and forward through network
    inputs, outputs_gt = torch.from_numpy(inputs), torch.from_numpy(outputs_gt)
    inputs, outputs_gt = Variable(inputs), Variable(outputs_gt)
    if GPU >= 0:
        inputs = inputs.cuda(GPU)
        outputs_gt = outputs_gt.cuda(GPU)
    outputs_pred = model(inputs)

    # draw rows of resulting image
    rows = []
    for i in range(inputs.size()[0]):
        # image, ground truth outputs, predicted outputs
        img_np = (inputs[i].cpu().data.numpy() * 255).astype(np.uint8).transpose(1, 2, 0)
        hm_gt_np = outputs_gt[i].cpu().data.numpy()
        hm_pred_np = outputs_pred[i].cpu().data.numpy()

        # per image
        #   first row: ground truth outputs,
        #   second row: predicted outputs
        # each row starts with the input image, followed by heatmap images
        row_truth = [img_np] + [draw_heatmap(img_np, np.squeeze(hm_gt_np[hm_idx]), alpha=0.5) for hm_idx in range(hm_gt_np.shape[0])]
        row_pred = [img_np] + [draw_heatmap(img_np, np.squeeze(hm_pred_np[hm_idx]), alpha=0.5) for hm_idx in range(hm_pred_np.shape[0])]
        rows.append(np.hstack(row_truth))
        rows.append(np.hstack(row_pred))
    grid = np.vstack(rows)

    if SHOW_DEBUG_WINDOWS:
        # show grid in opencv window
        if cv2.getWindowProperty(win, 0) == -1:
            cv2.namedWindow(win, cv2.WINDOW_NORMAL)
            cv2.resizeWindow(win, 1200, 600)
            time.sleep(1)
        cv2.imshow(win, grid.astype(np.uint8)[:, :, ::-1])
        cv2.waitKey(10)
    else:
        # save grid to file
        misc.imsave("window_%s.jpg" % (win,), grid.astype(np.uint8))

def examples_to_batch(examples, seq=None):
"""Convert examples from the dataset to inputs and ground truth outputs for the model. """ if seq is None: seq = iaa.Noop() seq_det = seq.to_deterministic() inputs = [ex.image for ex in examples] inputs_aug = seq_det.augment_images(inputs) bb_coords = [ex.get_bb_coords_keypoints(seq_det) for ex in examples] return images_coords_to_batch(inputs_aug, bb_coords) def images_coords_to_batch(images, bb_coords): """Convert input images and bounding box coordinates to expected inputs and outputs for the model.""" bb_grids = [bb_coords_to_grid(bb_coords_one, img.shape, GRID_SIZE) for img, bb_coords_one in zip(images, bb_coords)] outputs_gt = bb_grids inputs = (np.array(images)/255.0).astype(np.float32).transpose(0, 3, 1, 2) outputs_gt = np.array(outputs_gt).astype(np.float32).transpose(0, 3, 1, 2) return inputs, outputs_gt def bb_coords_to_grid(bb_coords_one, img_shape, grid_size): """Convert bounding box coordinates (corners) to ground truth heatmaps.""" if isinstance(bb_coords_one, ia.KeypointsOnImage): bb_coords_one = bb_coords_one.keypoints # bb edges after augmentation x1b = min([kp.x for kp in bb_coords_one]) x2b = max([kp.x for kp in bb_coords_one]) y1b = min([kp.y for kp in bb_coords_one]) y2b = max([kp.y for kp in bb_coords_one]) # clip x1c = np.clip(x1b, 0, img_shape[1]-1) y1c = np.clip(y1b, 0, img_shape[0]-1) x2c = np.clip(x2b, 0, img_shape[1]-1) y2c = np.clip(y2b, 0, img_shape[0]-1) # project x1d = int((x1c / img_shape[1]) * grid_size) y1d = int((y1c / img_shape[0]) * grid_size) x2d = int((x2c / img_shape[1]) * grid_size) y2d = int((y2c / img_shape[0]) * grid_size) assert 0 <= x1d < grid_size assert 0 <= y1d < grid_size assert 0 <= x2d < grid_size assert 0 <= y2d < grid_size # output ground truth: # - 1 heatmap that is 1 everywhere where there is a bounding box # - 9 position sensitive heatmaps, # e.g. the first one is 1 everywhere where there is the _top left corner_ # of a bounding box, # the second one is 1 for the top center cell, # the third one is 1 for the top right corner, # ... 
    grids = np.zeros((grid_size, grid_size, 1+9), dtype=np.float32)

    # first heatmap
    grids[y1d:y2d+1, x1d:x2d+1, 0] = 1

    # position sensitive heatmaps
    nb_cells_x = 3
    nb_cells_y = 3
    cell_width = (x2d - x1d) / nb_cells_x
    cell_height = (y2d - y1d) / nb_cells_y
    cell_counter = 0
    for j in range(nb_cells_y):
        cell_y1 = y1d + cell_height * j
        cell_y2 = cell_y1 + cell_height
        # cell borders are in grid coordinates, so clip them to the grid
        cell_y1_int = np.clip(int(math.floor(cell_y1)), 0, grid_size-1)
        cell_y2_int = np.clip(int(math.floor(cell_y2)), 0, grid_size-1)
        for i in range(nb_cells_x):
            cell_x1 = x1d + cell_width * i
            cell_x2 = cell_x1 + cell_width
            cell_x1_int = np.clip(int(math.floor(cell_x1)), 0, grid_size-1)
            cell_x2_int = np.clip(int(math.floor(cell_x2)), 0, grid_size-1)
            grids[cell_y1_int:cell_y2_int+1, cell_x1_int:cell_x2_int+1, 1+cell_counter] = 1
            cell_counter += 1

    return grids

def run_batch(inputs, outputs_gt, model, criterion, optimizer, train):
    """Train or validate on a batch (inputs + outputs)."""
    if train:
        model.train()
    else:
        model.eval()

    val = not train
    inputs, outputs_gt = torch.from_numpy(inputs), torch.from_numpy(outputs_gt)
    inputs, outputs_gt = Variable(inputs, volatile=val), Variable(outputs_gt)
    if GPU >= 0:
        inputs = inputs.cuda(GPU)
        outputs_gt = outputs_gt.cuda(GPU)

    if train:
        optimizer.zero_grad()
    outputs_pred = model(inputs)
    loss = criterion(outputs_pred, outputs_gt)
    if train:
        loss.backward()
        optimizer.step()

    return loss.data[0]

class Batch(object):
    """Class encapsulating a batch before and after augmentation."""
    def __init__(self, identifiers, images, keypoints):
        self.identifiers = identifiers
        self.images = images
        self.images_aug = None
        # keypoints here are the corners of the bounding box
        self.keypoints = keypoints
        self.keypoints_aug = None

class ImageLoader(object):
    """Class to load batches in the background."""
    def __init__(self, load_batch_func, nb_workers=1, queue_size=50, threaded=True):
        self.queue = multiprocessing.Queue(queue_size)
        self.workers = []
        for i in range(nb_workers):
            if threaded:
                worker = threading.Thread(target=self._load_batches, args=(load_batch_func, self.queue))
            else:
                worker = multiprocessing.Process(target=self._load_batches, args=(load_batch_func, self.queue))
            worker.daemon = True
            worker.start()
            self.workers.append(worker)

    def _load_batches(self, load_batch_func, queue):
        while True:
            queue.put(pickle.dumps(load_batch_func(), protocol=-1))

class BackgroundAugmenter(object):
    """Class to augment batches in the background (while training on the GPU)."""
    def __init__(self, augseq, queue_source, nb_workers, queue_size=50, threaded=False):
        assert 0 < queue_size <= 10000
        self.augseq = augseq
        self.queue_source = queue_source
        self.queue_result = multiprocessing.Queue(queue_size)
        self.workers = []
        for i in range(nb_workers):
            augseq.reseed()
            if threaded:
                worker = threading.Thread(target=self._augment_images_worker, args=(self.augseq, self.queue_source, self.queue_result))
            else:
                worker = multiprocessing.Process(target=self._augment_images_worker, args=(self.augseq, self.queue_source, self.queue_result))
            worker.daemon = True
            worker.start()
            self.workers.append(worker)

    def get_batch(self):
        """Returns a batch from the queue of augmented batches."""
        batch_str = self.queue_result.get()
        batch = pickle.loads(batch_str)
        return batch

    def _augment_images_worker(self, augseq, queue_source, queue_result):
        """Worker function that endlessly queries the source queue (input
        batches), augments batches in it and sends the result to the output
        queue."""
        while True:
            # wait for a new batch in the source queue and load it
            batch_str = queue_source.get()
            batch = pickle.loads(batch_str)

            # augment the batch
            if batch.images is not None and batch.keypoints is not None:
                augseq_det = augseq.to_deterministic()
                batch.images_aug = augseq_det.augment_images(batch.images)
                batch.keypoints_aug = augseq_det.augment_keypoints(batch.keypoints)
            elif batch.images is not None:
                batch.images_aug = augseq.augment_images(batch.images)
            elif batch.keypoints is not None:
                batch.keypoints_aug = augseq.augment_keypoints(batch.keypoints)

            # send augmented batch to output queue
            queue_result.put(pickle.dumps(batch, protocol=-1))

if __name__ == "__main__":
    main()
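
# Sketch (not executed as part of training): how the best checkpoint saved by
# main() could be loaded again later for inference, assuming the same Model2
# architecture and the checkpoint format written above.
#
#   checkpoint = torch.load("model.best.tar")
#   model = Model2()
#   model.load_state_dict(checkpoint["state_dict"])
#   if GPU >= 0:
#       model.cuda(GPU)
#   model.eval()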