"""Trains a CNN to detect the current steering wheel angle from images.""" from __future__ import print_function, division import sys import os sys.path.append(os.path.join(os.path.dirname(__file__), '..')) import models from lib import replay_memory from lib import util from lib.util import to_variable, to_cuda, to_numpy from lib import plotting from config import Config from scipy import misc import imgaug as ia from imgaug import augmenters as iaa from imgaug import parameters as iap import time import numpy as np import random import torch import torch.nn as nn import torch.optim as optim import torch.nn.functional as F import multiprocessing import threading import argparse np.random.seed(42) random.seed(42) torch.manual_seed(42) if sys.version_info[0] == 2: import cPickle as pickle elif sys.version_info[0] == 3: import pickle xrange = range # train for N batches NB_BATCHES = 50000 # size of each batch BATCH_SIZE = 128 # save/val/plot every N batches SAVE_EVERY = 500 VAL_EVERY = 500 PLOT_EVERY = 500 # use N batches for validation (loss will be averaged) NB_VAL_BATCHES = 128 # input image height/width MODEL_HEIGHT = 32 MODEL_WIDTH = 64 # size of each bin in degrees ANGLE_BIN_SIZE = 5 def main(): """Function that initializes the training (e.g. models) and runs the batches.""" parser = argparse.ArgumentParser(description="Train steering wheel tracker") parser.add_argument('--nocontinue', default=False, action="store_true", help="Whether to NOT continue the previous experiment", required=False) args = parser.parse_args() if os.path.isfile("steering_wheel.tar") and not args.nocontinue: checkpoint = torch.load("steering_wheel.tar") else: checkpoint = None if checkpoint is not None: history = plotting.History.from_string(checkpoint["history"]) else: history = plotting.History() history.add_group("loss", ["train", "val"], increasing=False) history.add_group("acc", ["train", "val"], increasing=True) loss_plotter = plotting.LossPlotter( history.get_group_names(), history.get_groups_increasing(), save_to_fp="train_plot.jpg" ) loss_plotter.start_batch_idx = 100 tracker_cnn = models.SteeringWheelTrackerCNNModel() tracker_cnn.train() optimizer = optim.Adam(tracker_cnn.parameters()) criterion = nn.CrossEntropyLoss() #criterion = nn.BCELoss() if checkpoint is not None: tracker_cnn.load_state_dict(checkpoint["tracker_cnn_state_dict"]) if Config.GPU >= 0: tracker_cnn.cuda(Config.GPU) criterion.cuda(Config.GPU) # initialize image augmentation cascade rarely = lambda aug: iaa.Sometimes(0.1, aug) sometimes = lambda aug: iaa.Sometimes(0.2, aug) often = lambda aug: iaa.Sometimes(0.4, aug) augseq = iaa.Sequential([ sometimes(iaa.Crop(percent=(0, 0.025))), rarely(iaa.GaussianBlur((0, 1.0))), # blur images with a sigma between 0 and 3.0 rarely(iaa.AdditiveGaussianNoise(loc=0, scale=(0.0, 0.02*255), per_channel=0.5)), # add gaussian noise to images often(iaa.Dropout( iap.FromLowerResolution( other_param=iap.Binomial(1 - 0.2), size_px=(2, 16) ), per_channel=0.2 )), often(iaa.Add((-20, 20), per_channel=0.5)), # change brightness of images (by -10 to 10 of original value) often(iaa.Multiply((0.8, 1.2), per_channel=0.25)), # change brightness of images (50-150% of original value) often(iaa.ContrastNormalization((0.8, 1.2), per_channel=0.5)), # improve or worsen the contrast often(iaa.Affine( scale={"x": (0.8, 1.3), "y": (0.8, 1.3)}, translate_percent={"x": (-0.2, 0.2), "y": (-0.2, 0.2)}, rotate=(-0, 0), shear=(-0, 0), order=[0, 1], cval=(0, 255), mode=["constant", "edge"] )), rarely(iaa.Grayscale(alpha=(0.0, 1.0))) ], random_order=True # do all of the above in random order ) #memory = replay_memory.ReplayMemory.get_instance_supervised() batch_loader_train = BatchLoader(val=False, augseq=augseq, queue_size=15, nb_workers=4) batch_loader_val = BatchLoader(val=True, augseq=iaa.Noop(), queue_size=NB_VAL_BATCHES, nb_workers=2) start_batch_idx = 0 if checkpoint is None else checkpoint["batch_idx"] + 1 for batch_idx in xrange(start_batch_idx, NB_BATCHES): run_batch(batch_idx, False, batch_loader_train, tracker_cnn, criterion, optimizer, history, (batch_idx % 20) == 0) if (batch_idx+1) % VAL_EVERY == 0: for i in xrange(NB_VAL_BATCHES): run_batch(batch_idx, True, batch_loader_val, tracker_cnn, criterion, optimizer, history, i == 0) if (batch_idx+1) % PLOT_EVERY == 0: loss_plotter.plot(history) # every N batches, save a checkpoint if (batch_idx+1) % SAVE_EVERY == 0: torch.save({ "batch_idx": batch_idx, "history": history.to_string(), "tracker_cnn_state_dict": tracker_cnn.state_dict() }, "steering_wheel.tar") def run_batch(batch_idx, val, batch_loader, tracker_cnn, criterion, optimizer, history, save_debug_image): """Train or validate on a single batch.""" train = not val time_cbatch_start = time.time() inputs, outputs_gt = batch_loader.get_batch() if Config.GPU >= 0: inputs = to_cuda(to_variable(inputs, volatile=val), Config.GPU) outputs_gt_bins = to_cuda(to_variable(np.argmax(outputs_gt, axis=1), volatile=val, requires_grad=False), Config.GPU) outputs_gt = to_cuda(to_variable(outputs_gt, volatile=val, requires_grad=False), Config.GPU) time_cbatch_end = time.time() time_fwbw_start = time.time() if train: optimizer.zero_grad() outputs_pred = tracker_cnn(inputs) outputs_pred_sm = F.softmax(outputs_pred) loss = criterion(outputs_pred, outputs_gt_bins) if train: loss.backward() optimizer.step() time_fwbw_end = time.time() loss = loss.data.cpu().numpy()[0] outputs_pred_np = to_numpy(outputs_pred_sm) outputs_gt_np = to_numpy(outputs_gt) acc = np.sum(np.equal(np.argmax(outputs_pred_np, axis=1), np.argmax(outputs_gt_np, axis=1))) / BATCH_SIZE history.add_value("loss", "train" if train else "val", batch_idx, loss, average=val) history.add_value("acc", "train" if train else "val", batch_idx, acc, average=val) print("[%s] Batch %05d | loss %.8f | acc %.2f | cbatch %.04fs | fwbw %.04fs" % ("T" if train else "V", batch_idx, loss, acc, time_cbatch_end - time_cbatch_start, time_fwbw_end - time_fwbw_start)) if save_debug_image: debug_img = generate_debug_image(inputs, outputs_gt, outputs_pred_sm) misc.imsave("debug_img_%s.jpg" % ("train" if train else "val"), debug_img) def generate_debug_image(inputs, outputs_gt, outputs_pred): """Draw an image with current ground truth and predictions for debug purposes.""" current_image = inputs.data[0].cpu().numpy() current_image = np.clip(current_image * 255, 0, 255).astype(np.uint8).transpose((1, 2, 0)) current_image = ia.imresize_single_image(current_image, (32*4, 64*4)) h, w = current_image.shape[0:2] outputs_gt = to_numpy(outputs_gt)[0] outputs_pred = to_numpy(outputs_pred)[0] binwidth = 6 outputs_grid = np.zeros((20+2, outputs_gt.shape[0]*binwidth, 3), dtype=np.uint8) for angle_bin_idx in xrange(outputs_gt.shape[0]): val = outputs_pred[angle_bin_idx] x_start = angle_bin_idx*binwidth x_end = (angle_bin_idx+1)*binwidth fill_start = 1 fill_end = 1 + int(20*val) #print(angle_bin_idx, x_start, x_end, fill_start, fill_end, outputs_grid.shape, outputs_grid[fill_start:fill_end, x_start+1:x_end].shape) if fill_start < fill_end: outputs_grid[fill_start:fill_end, x_start+1:x_end] = [255, 255, 255] bordercol = [128, 128, 128] if outputs_gt[angle_bin_idx] < 1 else [0, 0, 255] outputs_grid[0:22, x_start:x_start+1] = bordercol outputs_grid[0:22, x_end:x_end+1] = bordercol outputs_grid[0, x_start:x_end+1] = bordercol outputs_grid[21, x_start:x_end+1] = bordercol outputs_grid = outputs_grid[::-1, :, :] bin_gt = np.argmax(outputs_gt) bin_pred = np.argmax(outputs_pred) angles = [(binidx*ANGLE_BIN_SIZE) - 180 for binidx in [bin_gt, bin_pred]] #print(outputs_grid.shape) current_image = np.pad(current_image, ((0, 128), (0, 400), (0, 0)), mode="constant") current_image[h+4:h+4+22, 4:4+outputs_grid.shape[1], :] = outputs_grid current_image = util.draw_text(current_image, x=4, y=h+4+22+4, text="GT: %03.2fdeg\nPR: %03.2fdeg" % (angles[0], angles[1]), size=10) return current_image def extract_steering_wheel_image(screenshot_rs): """Extract the part of a screenshot (resized to 180x320 HxW) that usually contains the steering wheel.""" h, w = screenshot_rs.shape[0:2] x1 = int(w * (470/1280)) x2 = int(w * (840/1280)) y1 = int(h * (500/720)) y2 = int(h * (720/720)) return screenshot_rs[y1:y2+1, x1:x2+1, :] def downscale_image(steering_wheel_image): """Downscale an image to the model's input sizes (height, width).""" return ia.imresize_single_image( steering_wheel_image, (MODEL_HEIGHT, MODEL_WIDTH), interpolation="linear" ) def load_random_state(memory, depth=0): """Load a single random state from the replay memory which has a steering wheel position attached (estimated via classical means).""" rndidx = random.randint(memory.id_min, memory.id_max) state = memory.get_state_by_id(rndidx) if state.steering_wheel_classical is None: if depth+1 >= 200: raise Exception("Maximum depth reached in load_random_state(), \ too many states with None in column steering_wheel_classical. \ Use scripts/add_steering_wheel.py to recalculate missing values.") return load_random_state(memory, depth=depth+1) else: return state def load_random_batch(memory, augseq, batch_size): """Load a random batch from the replay memory for training. augseq contains the image augmentation sequence to use.""" inputs = np.zeros((batch_size, MODEL_HEIGHT, MODEL_WIDTH, 3), dtype=np.uint8) outputs = np.zeros((batch_size, 360//ANGLE_BIN_SIZE), dtype=np.float32) for b_idx in xrange(batch_size): state = load_random_state(memory) subimg = extract_steering_wheel_image(state.screenshot_rs) subimg = augseq.augment_image(subimg) subimg = downscale_image(subimg) inputs[b_idx] = subimg deg = state.steering_wheel_classical % 360 if -360 <= deg < -180: deg = 360 - deg elif -180 <= deg < 0: pass elif 0 <= deg < 180: pass elif 180 <= deg < 360: deg = -360 + deg deg = 180 + deg bin_idx = int(deg / ANGLE_BIN_SIZE) outputs[b_idx, bin_idx] = 1 inputs = (inputs / 255).astype(np.float32).transpose((0, 3, 1, 2)) return inputs, outputs class BatchLoader(object): """Class to load batches in multiple background processes.""" def __init__(self, val, queue_size, augseq, nb_workers, threaded=False): self.queue = multiprocessing.Queue(queue_size) self.workers = [] for i in range(nb_workers): seed = random.randint(0, 10**6) augseq_worker = augseq.deepcopy() if threaded: worker = threading.Thread(target=self._load_batches, args=(val, self.queue, augseq_worker, None)) else: worker = multiprocessing.Process(target=self._load_batches, args=(val, self.queue, augseq_worker, seed)) worker.daemon = True worker.start() self.workers.append(worker) def get_batch(self): return pickle.loads(self.queue.get()) def _load_batches(self, val, queue, augseq_worker, seed): if seed is None: random.seed(seed) np.random.seed(seed) augseq_worker.reseed(seed) ia.seed(seed) memory = replay_memory.ReplayMemory.create_instance_reinforced(val=val) while True: batch = load_random_batch(memory, augseq_worker, BATCH_SIZE) queue.put(pickle.dumps(batch, protocol=-1)) if __name__ == "__main__": main()