""" Tianwei Shen, HKUST, 2018 - 2019. DeepSlam class defines the training procedure and losses """ from __future__ import division import os import time import math import numpy as np import tensorflow as tf import tensorflow.contrib.slim as slim from data_loader import DataLoader from nets import * from geo_utils import get_relative_pose, projective_inverse_warp, pose_vec2mat, mat2euler, \ fundamental_matrix_from_rt, reprojection_error class DeepSlam(object): def __init__(self): pass def build_train_graph(self): '''[summary] build training graph Returns: data loader and batch sample for train() to initialize undefined placeholders ''' opt = self.opt is_read_pose = opt.with_pose or opt.pose_weight > 0 loader = DataLoader(opt.dataset_dir, opt.batch_size, opt.img_height, opt.img_width, opt.num_source, opt.num_scales, is_read_pose, opt.match_num) with tf.name_scope("data_loading"): batch_sample = loader.load_train_batch() # give additional batch_size info since the input is undetermined placeholder inputs = batch_sample.get_next() tgt_image = inputs[0] src_image_stack = inputs[1] intrinsics = inputs[2] #[bs, 128, 416, 3] tgt_image.set_shape([opt.batch_size, opt.img_height, opt.img_width, 3]) # [bs, 128, 416, 6] src_image_stack.set_shape([opt.batch_size, opt.img_height, opt.img_width, 3*opt.num_source]) # [bs, 4, 3, 3] intrinsics.set_shape([opt.batch_size, opt.num_scales, 3, 3]) if is_read_pose: poses = inputs[3] poses.set_shape([opt.batch_size, 3, 6]) if opt.match_num > 0: matches = inputs[3] matches.set_shape([opt.batch_size, opt.num_source, opt.match_num, 4]) tgt_image = self.preprocess_image(tgt_image) src_image_stack = self.preprocess_image(src_image_stack) with tf.name_scope("depth_prediction"): pred_disp, _ = disp_net_res50(tgt_image, is_training=True) if opt.with_pose: # cannot normalize pose here due to given scale pred_depth = [1. / d for d in pred_disp] else: pred_depth = [1. / self.spatial_normalize(d) for d in pred_disp] with tf.name_scope("pose_and_explainability_prediction"): pred_poses, _ = pose_net(tgt_image, src_image_stack, is_training=True) with tf.name_scope("compute_loss"): pixel_loss = 0 smooth_loss = 0 pose_loss = 0 ssim_loss = 0 match_loss = 0 tgt_image_all = [] src_image_stack_all = [] mask_stack_all = [] proj_image_stack_all = [] proj_error_stack_all = [] for s in range(opt.num_scales): # Scale the source and target images for computing loss at the according scale. curr_tgt_image = tf.image.resize_area(tgt_image, [int(opt.img_height/(2**s)), int(opt.img_width/(2**s))]) curr_src_image_stack = tf.image.resize_area(src_image_stack, [int(opt.img_height/(2**s)), int(opt.img_width/(2**s))]) if opt.smooth_weight > 0: smooth_loss += opt.smooth_weight/(2**s) * \ self.compute_smooth_loss(pred_disp[s], curr_tgt_image) for i in range(opt.num_source): # Inverse warp the source image to the target image frame if is_read_pose: relative_pose = get_relative_pose(poses[:,0,:], poses[:,i+1,:]) relative_rot = tf.slice(relative_pose, [0, 0, 0], [-1, 3, 3]) relative_rot_vec = mat2euler(relative_rot) relative_trans_vec = tf.slice(relative_pose, [0, 0, 3], [-1, 3, 1]) relative_pose_vec = tf.squeeze(tf.concat([relative_rot_vec, relative_trans_vec], axis=1)) if opt.with_pose: warp_pose = relative_pose pose_is_vec = False else: warp_pose = pred_poses[:,i,:] pose_is_vec = True curr_proj_image, mask = projective_inverse_warp( curr_src_image_stack[:,:,:,3*i:3*(i+1)], tf.squeeze(pred_depth[s], axis=3), warp_pose, intrinsics[:,s,:,:], is_vec=pose_is_vec) curr_proj_error = tf.abs(curr_proj_image - curr_tgt_image) curr_proj_error = tf.multiply(curr_proj_error, mask) # below-threshold mask perct_thresh = tf.contrib.distributions.percentile(curr_proj_error, q=99, axis=[1,2]) perct_thresh = tf.expand_dims(tf.expand_dims(perct_thresh, 1), 1) curr_proj_error = tf.clip_by_value(curr_proj_error, 0, perct_thresh) above_perct_thresh_region = tf.reduce_max(tf.cast(tf.equal(curr_proj_error, perct_thresh), 'float32'), axis=3) above_perct_thresh_region = tf.greater_equal(above_perct_thresh_region, 1.0) suppresion_mask = tf.expand_dims(1.0 - tf.cast(above_perct_thresh_region, 'float32'), axis=3) curr_proj_error = tf.multiply(curr_proj_error, suppresion_mask) mask = tf.multiply(mask, suppresion_mask) pixel_loss += tf.reduce_mean(curr_proj_error) # SSIM loss if opt.ssim_weight > 0: ssim_mask = slim.avg_pool2d(mask, 3, 1, 'VALID') ssim_loss += tf.reduce_mean( ssim_mask * self.compute_ssim_loss(curr_proj_image, curr_tgt_image)) # Relative pose error if opt.pose_weight > 0 and s == 0: # only do it for highest resolution pose_loss += tf.reduce_mean(self.compute_pose_loss(relative_pose_vec, pred_poses[:, i, :])) # Matches loss (fundamental matrix) if opt.match_num > 0 and s == 0: # only do it for highest resolution match_loss += self.compute_match_loss(matches[:, i, :, :], tf.squeeze( pred_depth[s], axis=3), pred_poses[:, i, :], intrinsics[:, s, :, :]) # Prepare images for tensorboard summaries if i == 0: proj_image_stack = curr_proj_image mask_stack = mask proj_error_stack = curr_proj_error else: proj_image_stack = tf.concat([proj_image_stack, curr_proj_image], axis=3) mask_stack = tf.concat([mask_stack, mask], axis=3) proj_error_stack = tf.concat([proj_error_stack, curr_proj_error], axis=3) tgt_image_all.append(curr_tgt_image) src_image_stack_all.append(curr_src_image_stack) proj_image_stack_all.append(proj_image_stack) mask_stack_all.append(mask_stack) proj_error_stack_all.append(proj_error_stack) total_loss = opt.ssim_weight * ssim_loss + \ (1 - opt.ssim_weight) * pixel_loss + \ smooth_loss + opt.pose_weight * pose_loss + opt.match_weight * match_loss with tf.name_scope("train_op"): train_vars = [var for var in tf.trainable_variables()] optim = tf.train.AdamOptimizer(opt.learning_rate, opt.beta1) # self.grads_and_vars = optim.compute_gradients(total_loss, # var_list=train_vars) # self.train_op = optim.apply_gradients(self.grads_and_vars) self.train_op = slim.learning.create_train_op(total_loss, optim) self.global_step = tf.Variable(0, name='global_step', trainable=False) self.incr_global_step = tf.assign(self.global_step, self.global_step+1) # Collect tensors that are useful later (e.g. tf summary) self.pred_depth = pred_depth self.pred_poses = pred_poses self.steps_per_epoch = loader.steps_per_epoch self.total_loss = total_loss self.pixel_loss = pixel_loss self.pose_loss = pose_loss self.smooth_loss = smooth_loss self.ssim_loss = ssim_loss self.match_loss = match_loss self.tgt_image_all = tgt_image_all self.src_image_stack_all = src_image_stack_all self.proj_image_stack_all = proj_image_stack_all self.mask_stack_all = mask_stack_all self.proj_error_stack_all = proj_error_stack_all return loader, batch_sample def compute_smooth_loss(self, disp, img): def _gradient(pred): D_dx = pred[:, :, 1:, :] - pred[:, :, :-1, :] D_dy = pred[:, 1:, :, :] - pred[:, :-1, :, :] return D_dx, D_dy disp_gradients_x, disp_gradients_y = _gradient(disp) image_gradients_x, image_gradients_y = _gradient(img) weights_x = tf.exp(-tf.reduce_mean(tf.abs(image_gradients_x), 3, keep_dims=True)) weights_y = tf.exp(-tf.reduce_mean(tf.abs(image_gradients_y), 3, keep_dims=True)) smoothness_x = disp_gradients_x * weights_x smoothness_y = disp_gradients_y * weights_y return tf.reduce_mean(tf.abs(smoothness_x)) + tf.reduce_mean(tf.abs(smoothness_y)) def compute_pose_loss(self, prior_pose_vec, pred_pose_vec): rot_vec_err = tf.norm(prior_pose_vec[:,:3] - pred_pose_vec[:,:3], axis=1) trans_err = tf.norm(tf.nn.l2_normalize( prior_pose_vec[:, 3:], dim=1) - tf.nn.l2_normalize(pred_pose_vec[:, 3:], dim=1), axis=1) return rot_vec_err + trans_err # reference https://github.com/tensorflow/models/tree/master/research/vid2depth/model.py def compute_ssim_loss(self, x, y): """Computes a differentiable structured image similarity measure.""" c1 = 0.01**2 c2 = 0.03**2 mu_x = slim.avg_pool2d(x, 3, 1, 'VALID') mu_y = slim.avg_pool2d(y, 3, 1, 'VALID') sigma_x = slim.avg_pool2d(x**2, 3, 1, 'VALID') - mu_x**2 sigma_y = slim.avg_pool2d(y**2, 3, 1, 'VALID') - mu_y**2 sigma_xy = slim.avg_pool2d(x * y, 3, 1, 'VALID') - mu_x * mu_y ssim_n = (2 * mu_x * mu_y + c1) * (2 * sigma_xy + c2) ssim_d = (mu_x**2 + mu_y**2 + c1) * (sigma_x + sigma_y + c2) ssim = ssim_n / ssim_d return tf.clip_by_value((1 - ssim) / 2, 0, 1) # reference: https://github.com/yzcjtr/GeoNet/blob/master/geonet_model.py # and https://arxiv.org/abs/1712.00175 def spatial_normalize(self, disp): _, curr_h, curr_w, curr_c = disp.get_shape().as_list() disp_mean = tf.reduce_mean(disp, axis=[1,2,3], keep_dims=True) disp_mean = tf.tile(disp_mean, [1, curr_h, curr_w, curr_c]) return disp/disp_mean def normalize_for_show(self, disp, thresh=90): disp_max = tf.contrib.distributions.percentile(disp, q=thresh, axis=[1,2]) disp_max = tf.expand_dims(tf.expand_dims(disp_max, 1), 1) clip_disp = tf.clip_by_value(disp, 0, disp_max) return clip_disp def compute_match_loss(self, matches, pred_depth, pose, intrinsics): points1 = tf.slice(matches, [0, 0, 0], [-1, -1, 2]) points2 = tf.slice(matches, [0, 0, 2], [-1, -1, 2]) ones = tf.ones([self.opt.batch_size, self.opt.match_num, 1]) points1 = tf.concat([points1, ones], axis=2) points2 = tf.concat([points2, ones], axis=2) match_num = matches.get_shape().as_list()[1] # compute fundamental matrix loss fmat = fundamental_matrix_from_rt(pose, intrinsics) fmat = tf.expand_dims(fmat, axis=1) fmat_tiles = tf.tile(fmat, [1, match_num, 1, 1]) epi_lines = tf.matmul(fmat_tiles, tf.expand_dims(points1, axis=3)) dist_p2l = tf.abs(tf.matmul(tf.transpose(epi_lines, perm=[0, 1, 3, 2]), tf.expand_dims(points2, axis=3))) a = tf.slice(epi_lines, [0,0,0,0], [-1,-1,1,-1]) b = tf.slice(epi_lines, [0,0,1,0], [-1,-1,1,-1]) dist_div = tf.sqrt(a*a + b*b) + 1e-6 dist_p2l = tf.reduce_mean(dist_p2l / dist_div) return dist_p2l def collect_summaries(self): opt = self.opt tf.summary.scalar("total_loss", self.total_loss) tf.summary.scalar("pixel_loss", self.pixel_loss) if opt.smooth_weight > 0: tf.summary.scalar("smooth_loss", self.smooth_loss) if opt.pose_weight > 0: tf.summary.scalar("pose_loss", self.pose_loss) if opt.ssim_weight > 0: tf.summary.scalar("ssim_loss", self.ssim_loss) if opt.match_num > 0: tf.summary.scalar("match_loss", self.match_loss) #for s in range(opt.num_scales): s = 0 # only show the error images of the highest resolution (scale 0) tf.summary.histogram("scale%d_depth" % s, self.pred_depth[s]) shown_disparity_image = self.normalize_for_show(1./self.pred_depth[s]) tf.summary.image('scale%d_disparity_image' % s, shown_disparity_image) tf.summary.image('scale%d_target_image' % s, self.deprocess_image(self.tgt_image_all[s])) for i in range(opt.num_source): tf.summary.image( 'scale%d_source_image_%d' % (s, i), self.deprocess_image(self.src_image_stack_all[s][:, :, :, i*3:(i+1)*3])) proj_images = self.deprocess_image(self.proj_image_stack_all[s][:, :, :, i*3:(i+1)*3]) mask_images = self.mask_stack_all[s][:, :, :, i:i+1] proj_error_images = self.deprocess_image(tf.clip_by_value( self.proj_error_stack_all[s][:, :, :, i*3:(i+1)*3] - 1, -1, 1)) tf.summary.image('scale%d_projected_image_%d' % (s, i), proj_images) tf.summary.image('scale%d_proj_error_%d' % (s, i), proj_error_images) tf.summary.image('scale%d_mask_%d' % (s, i), mask_images) def train(self, opt): opt.num_source = opt.seq_length - 1 self.opt = opt if opt.match_num > 0: # don't use match and pose at the same time opt.with_pose = False data_loader, batch_sample = self.build_train_graph() self.collect_summaries() with tf.name_scope("parameter_count"): parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) \ for v in tf.trainable_variables()]) self.saver = tf.train.Saver([var for var in tf.model_variables()] + \ [self.global_step], max_to_keep=None) sv = tf.train.Supervisor(logdir=opt.checkpoint_dir, save_summaries_secs=0, saver=None) config = tf.ConfigProto() config.gpu_options.allow_growth = True with sv.managed_session(config=config) as sess: data_loader.init_data_pipeline(sess, batch_sample) print('Trainable variables: ') for var in tf.trainable_variables(): print(var.name) print("parameter_count =", sess.run(parameter_count)) if opt.continue_train: if opt.init_ckpt_file is None: checkpoint = tf.train.latest_checkpoint(opt.checkpoint_dir) else: checkpoint = opt.init_ckpt_file print("Resume training from previous checkpoint: %s" % checkpoint) self.saver.restore(sess, checkpoint) start_time = time.time() for step in range(0, opt.max_steps): fetches = {"train": self.train_op, "global_step": self.global_step, "incr_global_step": self.incr_global_step} if step % opt.summary_freq == 0: fetches["total_loss"] = self.total_loss fetches["pixel_loss"] = self.pixel_loss fetches["smooth_loss"] = self.smooth_loss fetches["summary"] = sv.summary_op if opt.pose_weight > 0: fetches["pose_loss"] = self.pose_loss results = sess.run(fetches) gs = results["global_step"] if step % opt.summary_freq == 0: sv.summary_writer.add_summary(results["summary"], gs) train_epoch = math.ceil(gs / self.steps_per_epoch) train_step = gs - (train_epoch - 1) * self.steps_per_epoch print("Epoch: [%2d] [%5d/%5d] time: %4.4f" % (train_epoch, train_step, self.steps_per_epoch, (time.time() - start_time)/opt.summary_freq)) print("total/pixel/smooth loss: [%.3f/%.3f/%.3f]\n" % ( results["total_loss"], results["pixel_loss"], results["smooth_loss"])) start_time = time.time() # save model if step != 0 and step % opt.save_freq == 0: self.save(sess, opt.checkpoint_dir, gs-1) def select_tensor_or_placeholder_input(self, input_uint8): if input_uint8 == None: input_uint8 = tf.placeholder(tf.uint8, [self.batch_size, self.img_height, self.img_width, 3], name='raw_input') self.inputs = input_uint8 else: self.inputs = None input_mc = self.preprocess_image(input_uint8) return input_mc def build_depth_test_graph(self, input_uint8): input_mc = self.select_tensor_or_placeholder_input(input_uint8) with tf.name_scope("depth_prediction"): pred_disp, depth_net_endpoints = disp_net_res50(input_mc, is_training=False) pred_depth = [1./disp for disp in pred_disp] pred_depth = pred_depth[0] self.pred_depth = pred_depth self.depth_epts = depth_net_endpoints def build_pose_test_graph(self, input_uint8): input_mc = self.select_tensor_or_placeholder_input(input_uint8) loader = DataLoader() tgt_image, src_image_stack = \ loader.batch_unpack_image_sequence( input_mc, self.img_height, self.img_width, self.num_source) with tf.name_scope("pose_prediction"): pred_poses, _ = pose_net(tgt_image, src_image_stack, is_training=False) self.pred_poses = pred_poses def preprocess_image(self, image): # Assuming input image is uint8 image = tf.image.convert_image_dtype(image, dtype=tf.float32) return image * 2. -1. def deprocess_image(self, image): # Assuming input image is float32 image = (image + 1.)/2. return tf.image.convert_image_dtype(image, dtype=tf.uint8) def setup_inference(self, img_height, img_width, mode, seq_length=3, batch_size=1, input_img_uint8=None): self.img_height = img_height self.img_width = img_width self.mode = mode self.batch_size = batch_size if self.mode == 'depth': self.build_depth_test_graph(input_img_uint8) if self.mode == 'pose': self.seq_length = seq_length self.num_source = seq_length - 1 self.build_pose_test_graph(input_img_uint8) def inference(self, sess, mode, inputs=None): fetches = {} if mode == 'depth': fetches['depth'] = self.pred_depth if mode == 'pose': fetches['pose'] = self.pred_poses if inputs is None: results = sess.run(fetches) else: results = sess.run(fetches, feed_dict={self.inputs:inputs}) return results def save(self, sess, checkpoint_dir, step): model_name = 'model' print(" [*] Saving checkpoint step %d to %s..." % (step, checkpoint_dir)) self.saver.save(sess, os.path.join(checkpoint_dir, model_name), global_step=step)