import tensorflow as tf
from config import Config
import numpy as np
import random
from tensorflow.python.training import moving_averages
from tensorflow.python.ops import control_flow_ops
from main import HEIGHT, WIDTH

MOVING_AVERAGE_DECAY = 0.99
BN_DECAY = MOVING_AVERAGE_DECAY
BN_EPSILON = 0.001
CONV_WEIGHT_DECAY = 0.00001
CONV_WEIGHT_STDDEV = 0.05
GC_VARIABLES = 'gc_variables'
UPDATE_OPS_COLLECTION = 'gc_update_ops'  # training ops
# HEIGHT and WIDTH are imported from main (previously hard-coded here as
# 256 and 512).
DISPARITY = 192


def conv(x, c):
    """2-D convolution (SAME padding) followed by a bias add.

    Reads c['ksize'], c['stride'] and c['conv_filters_out']. Creates the
    variables 'weights' (Xavier init, L2 weight decay CONV_WEIGHT_DECAY) and
    'bias' (initialised to 0.05) in the current variable scope.
    """
    ksize = c['ksize']
    stride = c['stride']
    filters_out = c['conv_filters_out']
    filters_in = x.get_shape()[-1]
    shape = [ksize, ksize, filters_in, filters_out]
    initializer = tf.contrib.layers.xavier_initializer()
    weights = _get_variable('weights',
                            shape=shape,
                            initializer=initializer,
                            weight_decay=CONV_WEIGHT_DECAY)
    bias = tf.get_variable('bias', [filters_out], 'float',
                           tf.constant_initializer(0.05, dtype='float'))
    x = tf.nn.conv2d(x, weights, [1, stride, stride, 1], padding='SAME')
    return tf.nn.bias_add(x, bias)


def conv_3d(x, c):
    """3-D convolution (SAME padding, same stride on all three spatial axes)
    followed by a bias add.

    Reads c['ksize'], c['stride'] and c['conv_filters_out']. Creates
    'weights' and 'bias' in the current variable scope, like `conv`.
    """
    ksize = c['ksize']
    stride = c['stride']
    filters_out = c['conv_filters_out']
    filters_in = x.get_shape()[-1]
    shape = [ksize, ksize, ksize, filters_in, filters_out]
    initializer = tf.contrib.layers.xavier_initializer()
    weights = _get_variable('weights',
                            shape=shape,
                            initializer=initializer,
                            weight_decay=CONV_WEIGHT_DECAY)
    bias = tf.get_variable('bias', [filters_out], 'float',
                           tf.constant_initializer(0.05, dtype='float'))
    x = tf.nn.conv3d(x, weights, [1, stride, stride, stride, 1],
                     padding='SAME')
    return tf.nn.bias_add(x, bias)


def deconv_3d(x, c):
    """3-D transposed convolution that upsamples the three spatial axes of a
    5-D input by c['stride'], followed by a bias add.

    Reads c['ksize'], c['stride'] and c['conv_filters_out']. The input's
    spatial dimensions must be statically known because output_shape is
    built from them.
    """
    ksize = c['ksize']
    stride = c['stride']
    filters_out = c['conv_filters_out']
    # as_list() is required: output_shape must be built from plain ints.
    x_shape = x.get_shape().as_list()
    filters_in = x_shape[-1]
    # Use the static batch size when known; fall back to 1, the batch size
    # the rest of this file assumes.
    batch = x_shape[0] if x_shape[0] is not None else 1
    output_shape = [batch,
                    x_shape[1] * stride,
                    x_shape[2] * stride,
                    x_shape[3] * stride,
                    filters_out]
    strides = [1, stride, stride, stride, 1]
    # conv3d_transpose filters are laid out [d, h, w, out_channels, in_channels].
    shape = [ksize, ksize, ksize, filters_out, filters_in]
    initializer = tf.contrib.layers.xavier_initializer()
    weights = _get_variable('weights',
                            shape=shape,
                            dtype='float32',
                            initializer=initializer,
                            weight_decay=CONV_WEIGHT_DECAY)
    bias = tf.get_variable('bias', [filters_out], 'float32',
                           tf.constant_initializer(0.05, dtype='float32'))
    x = tf.nn.conv3d_transpose(x, weights, output_shape=output_shape,
                               strides=strides, padding='SAME')
    return tf.nn.bias_add(x, bias)


def bn(x, c):
    """Batch normalisation over every axis except the last (channel) axis.

    When c['is_training'] is true the batch mean/variance are used, otherwise
    the moving averages. The moving-average update ops are always created and
    added to UPDATE_OPS_COLLECTION but are NOT executed here; the caller has
    to run that collection (presumably together with the train op).
    """
    x_shape = x.get_shape()
    params_shape = x_shape[-1:]
    axis = list(range(len(x_shape) - 1))

    beta = _get_variable('beta', params_shape,
                         initializer=tf.zeros_initializer())
    gamma = _get_variable('gamma', params_shape,
                          initializer=tf.ones_initializer())
    moving_mean = _get_variable('moving_mean', params_shape,
                                initializer=tf.zeros_initializer(),
                                trainable=False)
    moving_variance = _get_variable('moving_variance', params_shape,
                                    initializer=tf.ones_initializer(),
                                    trainable=False)

    # These update ops only take effect when the collection is run.
    mean, variance = tf.nn.moments(x, axis)
    update_moving_mean = moving_averages.assign_moving_average(
        moving_mean, mean, BN_DECAY)
    update_moving_variance = moving_averages.assign_moving_average(
        moving_variance, variance, BN_DECAY)
    tf.add_to_collection(UPDATE_OPS_COLLECTION, update_moving_mean)
    tf.add_to_collection(UPDATE_OPS_COLLECTION, update_moving_variance)

    mean, variance = control_flow_ops.cond(
        c['is_training'], lambda: (mean, variance),
        lambda: (moving_mean, moving_variance))

    return tf.nn.batch_normalization(x, mean, variance, beta, gamma,
                                     BN_EPSILON)


def _get_variable(name, shape, initializer, weight_decay=0.0,
                  dtype='float32', trainable=True):
    """Thin wrapper around tf.get_variable.

    Attaches an L2 regulariser when weight_decay > 0 and registers the
    variable in both GLOBAL_VARIABLES and the GC_VARIABLES collection.
    """
    if weight_decay > 0:
        regularizer = tf.contrib.layers.l2_regularizer(weight_decay)
    else:
        regularizer = None
    collections = [tf.GraphKeys.GLOBAL_VARIABLES, GC_VARIABLES]
    return tf.get_variable(name,
                           shape=shape,
                           initializer=initializer,
                           dtype=dtype,
                           regularizer=regularizer,
                           collections=collections,
                           trainable=trainable)


def stack(x, c):
    """Residual block: two conv+bn layers, identity shortcut, then ReLU.

    Uses the conv settings currently in `c`; the shortcut add requires the
    conv output to have the same shape as `x` (stride 1, same channel count).
    """
    shortcut = x
    with tf.variable_scope('block_A'):
        x = conv(x, c)
        x = bn(x, c)
        x = tf.nn.relu(x)
    with tf.variable_scope('block_B'):
        x = conv(x, c)
        x = bn(x, c)
    x = shortcut + x
    return tf.nn.relu(x)


def _build_resnet(x, c):
    """Feature tower shared by both siamese branches.

    A 5x5 stride-2 conv (24 filters) + bn + ReLU, followed by
    c['num_resblock'] residual blocks of 3x3 stride-1 convs. Leaves
    c['ksize'] = 3, c['stride'] = 1, c['conv_filters_out'] = 24.
    """
    with tf.variable_scope('downsample'):
        c['conv_filters_out'] = 24
        c['ksize'] = 5
        c['stride'] = 2
        x = conv(x, c)
        x = bn(x, c)
        x = tf.nn.relu(x)

    c['ksize'] = 3
    c['stride'] = 1
    with tf.variable_scope('resnet'):
        for i in range(c['num_resblock']):
            with tf.variable_scope('res' + str(i + 1)):
                x = stack(x, c)
    return x


def _conv3d_bn_relu(x, c, scope, filters_out, stride):
    """conv_3d -> bn -> ReLU inside variable scope `scope`.

    Sets c['conv_filters_out'] and c['stride'] before the conv (c['ksize']
    is taken as already set by the caller).
    """
    c['conv_filters_out'] = filters_out
    c['stride'] = stride
    with tf.variable_scope(scope):
        x = conv_3d(x, c)
        x = bn(x, c)
        return tf.nn.relu(x)


def _build_3d_conv(x, c):
    """Stack of c['num_3d'] 3-D conv blocks (not used by `inference`).

    Each block i (scope '<i>3d') is two stride-1 layers followed by a
    stride-2 layer with 64 filters (128 when i == 3). The first block's
    stride-1 layers use 32 filters, later blocks 64.
    """
    c['ksize'] = 3
    filters = 32
    for i in range(c['num_3d']):
        with tf.variable_scope(str(i) + '3d'):
            x = _conv3d_bn_relu(x, c, 'A', filters, 1)
            x = _conv3d_bn_relu(x, c, 'B', filters, 1)
            x = _conv3d_bn_relu(x, c, 'C', 128 if i == 3 else 64, 2)
        filters = 64
    c['stride'] = 1
    c['conv_filters_out'] = filters
    return x


def _build_cost_volume(left_features, right_features, num_disparities):
    """Build a concatenation cost volume from a pair of feature maps.

    For every candidate disparity d in [0, num_disparities) the right
    features are shifted d columns to the right, so that right[y, x - d] is
    paired with left[y, x]. The d left-most columns have no match and are
    zero-filled.

    Args:
        left_features, right_features: (1, H, W, F) tensors with statically
            known H, W, F (batch must be 1; axis 0 is squeezed).
        num_disparities: number of disparities to evaluate (must not exceed W).

    Returns:
        A (1, num_disparities, H, W, 2F) tensor. Channels are interleaved per
        feature: left_0, right_0, left_1, right_1, ...
    """
    left = tf.squeeze(left_features, [0])
    right = tf.squeeze(right_features, [0])
    height, width, num_features = left.get_shape().as_list()

    volume = []
    for d in range(num_disparities):
        # Keep columns [0, W - d) of the right map and pad d zero columns on
        # the left: shifted[:, x] == right[:, x - d] for x >= d.
        shifted = tf.pad(right[:, :width - d, :],
                         [[0, 0], [d, 0], [0, 0]], "CONSTANT")
        # Stacking on a new last axis and flattening it interleaves the
        # channels as l0, r0, l1, r1, ...
        pair = tf.stack([left, shifted], axis=3)
        volume.append(tf.reshape(pair, [height, width, 2 * num_features]))
    return tf.expand_dims(tf.stack(volume), 0)


def _build_3d_encoder(x, c):
    """3-D conv encoder: four downsampling blocks, then two extra layers.

    Block i lives in scope '<i>3d' and is A, B (stride 1) and C (stride 2).
    The output of B is kept as a skip connection. The last block additionally
    has layers D and E (stride 1).

    Returns:
        (x, skips) where skips are the B outputs, shallowest first.
    """
    c['ksize'] = 3
    # (filters of layers A/B, filters of the stride-2 layer C) per block.
    blocks = ((32, 64), (64, 64), (64, 64), (64, 128))
    last = len(blocks) - 1
    skips = []
    for i, (filters, down_filters) in enumerate(blocks):
        with tf.variable_scope(str(i) + '3d'):
            x = _conv3d_bn_relu(x, c, 'A', filters, 1)
            x = _conv3d_bn_relu(x, c, 'B', filters, 1)
            skips.append(x)
            x = _conv3d_bn_relu(x, c, 'C', down_filters, 2)
            if i == last:
                x = _conv3d_bn_relu(x, c, 'D', down_filters, 1)
                x = _conv3d_bn_relu(x, c, 'E', down_filters, 1)
    return x, skips


def _build_3d_decoder(x, skips, c):
    """3-D deconv decoder mirroring `_build_3d_encoder`.

    Layers A-D each upsample by 2 (deconv + bn + ReLU) and then add the
    matching encoder skip, deepest first. Their filter count is taken from
    that skip so the add is shape-compatible. Layer E upsamples once more to
    a single channel with no bn/ReLU.
    """
    c['ksize'] = 3
    c['stride'] = 2
    for scope, skip in zip('ABCD', reversed(skips)):
        c['conv_filters_out'] = skip.get_shape().as_list()[-1]
        with tf.variable_scope(scope):
            x = deconv_3d(x, c)
            x = bn(x, c)
            x = tf.nn.relu(x)
            x = x + skip
    c['conv_filters_out'] = 1
    with tf.variable_scope('E'):
        x = deconv_3d(x, c)
    return x


def inference(left_x, right_x, is_training):
    """Build the stereo network and return its per-disparity output.

    Args:
        left_x, right_x: left/right input images, batch size 1 (NHWC assumed,
            as required by tf.nn.conv2d defaults — confirm with caller).
        is_training: Python bool or bool tensor; selects batch statistics vs.
            moving averages in every bn layer.

    Returns:
        The negated decoder output with the batch and channel axes squeezed,
        i.e. a (DISPARITY, H, W) tensor for even DISPARITY and input (H, W).
        This presumably requires the half-resolution cost-volume dims to be
        divisible by 16 — confirm. The negation presumably turns costs into
        scores for a softmax over the disparity axis done elsewhere — confirm.
    """
    c = Config()
    c['is_training'] = tf.convert_to_tensor(is_training, dtype='bool',
                                            name='is_training')
    c['num_resblock'] = 8  # residual blocks in each siamese tower
    c['num_3d'] = 2  # only read by _build_3d_conv, which is not used here
    c['conv_filters_out'] = 24

    with tf.variable_scope("siamese") as scope:
        left_features = _build_resnet(left_x, c)
        scope.reuse_variables()
        right_features = _build_resnet(right_x, c)

    # Features are at half resolution, so evaluate half the disparities.
    cost_vol = _build_cost_volume(left_features, right_features,
                                  DISPARITY // 2)

    with tf.variable_scope("3dconv"):
        x, skips = _build_3d_encoder(cost_vol, c)

    with tf.variable_scope("deconv"):
        x = _build_3d_decoder(x, skips, c)

    # Drop the size-1 batch and channel axes explicitly.
    x = tf.squeeze(x, [0, 4])
    return -x