import functools

import numpy as np
import tensorflow as tf

import config
from rfl_net.network import UPDATE_OPS_COLLECTION
from rfl_net.rnn import rnn, DropoutWrapper
from data_input.prepare_targets import create_labels_overlap
from rfl_net.utils import activation_summary
from rfl_net.xz_net import XZNet, ConvXZNet, FilterNet
from rfl_net.conv_lstm import BasicConvLSTMCell, InitLSTMSate


def lazy_property(function):
    """Cache the decorated graph-building method so it runs exactly once."""
    attribute = '_' + function.__name__

    @property
    @functools.wraps(function)
    def wrapper(self):
        if not hasattr(self, attribute):
            setattr(self, attribute, function(self))
        return getattr(self, attribute)
    return wrapper


class RFLNet(object):

    def __init__(self, is_train, z_examplar=None, x_crops=None, y_crops=None, init_z_exemplar=None):
        self._is_train = is_train
        input_shape = z_examplar.get_shape().as_list()
        self._batch_size = input_shape[0]
        self._time_steps = input_shape[1]
        x_shape = x_crops.get_shape().as_list()

        # Fold the time dimension into the batch dimension so the conv nets
        # receive plain 4-D tensors.
        self._z_examplar = tf.reshape(z_examplar, [-1, config.z_exemplar_size, config.z_exemplar_size, 3])
        self._x_crops = tf.reshape(x_crops, [-1] + x_shape[2:])
        self._y_crops = y_crops

        # Augmentation crops 8 pixels per side, which shrinks the response map.
        self._response_size = config.response_size - int(2 * 8 / config.stride) \
            if config.is_augment and is_train else config.response_size
        self._gt_pos = tf.convert_to_tensor(
            np.floor([self._response_size / 2, self._response_size / 2]), tf.float32)

        if init_z_exemplar is not None:
            self.init_z_exemplar = tf.reshape(
                init_z_exemplar, [-1, config.z_exemplar_size, config.z_exemplar_size, 3])

        # Touch the lazy properties so the graph is built at construction time.
        self.filter
        self.response
        if y_crops is not None:
            self.loss
            self.dist_error
        else:
            self.init_state_filter

        if is_train:
            self._global_step = tf.get_variable('global_step', [], tf.int64,
                                                initializer=tf.constant_initializer(0),
                                                trainable=False)
            self._lr = tf.train.exponential_decay(config.learning_rate, self._global_step,
                                                  config.decay_circles, config.lr_decay,
                                                  staircase=True)
            tf.summary.scalar('learning_rate', self._lr)
            self.optimize
            self._summary = tf.summary.merge_all()
        self._saver = tf.train.Saver(tf.global_variables())

    @lazy_property
    def init_state_filter(self):
        # Embed the initial exemplar with the shared z_net, predict the LSTM
        # initial state from it, and emit the corresponding first-frame filter.
        # Note: self._state_size is set inside `filter`, which is built first.
        with tf.variable_scope('z_net', reuse=True):
            z_net = XZNet({'input': self.init_z_exemplar}, self._is_train)
        with tf.variable_scope('init_state'):
            init_state_net = InitLSTMSate({'input': z_net.get_output(),
                                           'state_size': self._state_size}, self._is_train)
            init_state = init_state_net.get_output()
        with tf.variable_scope('z_filter', reuse=True):
            first_output = init_state[1]
            init_filter_net = FilterNet({'output': first_output}, self._is_train)
            init_filter = init_filter_net.get_output()
        return tuple([init_state, init_filter])

    @lazy_property
    def filter(self):
        # build z_net for reference image
        with tf.variable_scope('z_net'):
            z_net = XZNet({'input': self._z_examplar}, self._is_train)

        # build rnn for filter generation
        z_output = z_net.get_output()
        gf_shape = z_output.get_shape().as_list()

        # building rnn cell
        rnn_cell = BasicConvLSTMCell(gf_shape[1:3],
                                     [config.conv_filter_size, config.conv_filter_size],
                                     config.hidden_size, self._is_train,
                                     forget_bias=1.0, activation=tf.nn.tanh)
        if self._is_train and config.keep_prob < 1:
            rnn_cell = DropoutWrapper(rnn_cell, output_keep_prob=config.keep_prob)
        # cell = tf.nn.rnn_cell.MultiRNNCell([rnn_cell] * config.num_rnn_layers)
        cell = rnn_cell
        self._state_size = cell.state_size

        # reorganize rnn input
        rnn_inputs = tf.reshape(z_output, [self._batch_size, self._time_steps] + gf_shape[1:4])
        if self._is_train and config.keep_prob < 1:
            rnn_inputs = tf.nn.dropout(rnn_inputs, config.keep_prob)
        rnn_inputs = [tf.squeeze(input_, [1])
                      for input_ in tf.split(axis=1, num_or_size_splits=self._time_steps,
                                             value=rnn_inputs)]
        if self._is_train or self._y_crops is not None:
            # Predict the initial LSTM state from the first frame, then run
            # the recurrence over the remaining frames.
            with tf.variable_scope('init_state'):
                init_state_net = InitLSTMSate({'input': rnn_inputs[0],
                                               'state_size': self._state_size}, self._is_train)
                initial_state = init_state_net.get_output()
            rnn_inputs_new = rnn_inputs[1:self._time_steps]
            outputs, final_state, input_gates, forget_gates, output_gates = \
                rnn(cell, rnn_inputs_new, initial_state=initial_state)
            first_output = initial_state[1]
            outputs = [first_output] + outputs
        else:
            self._initial_state = cell.zero_state(self._batch_size, tf.float32)
            outputs, final_state, input_gates, forget_gates, output_gates = \
                rnn(cell, rnn_inputs, initial_state=self._initial_state)

        outputs = tf.reshape(tf.concat(axis=1, values=outputs),
                             [-1] + gf_shape[1:3] + [config.hidden_size])
        self._final_state = final_state

        # generate a tracking filter from every time step's hidden output
        with tf.variable_scope('z_filter'):
            f_net = FilterNet({'output': outputs}, self._is_train)
            z_gf = f_net.get_output()

        activation_summary(z_output, 'activation/z_output')
        activation_summary(z_gf, 'activation/z_gf')
        activation_summary(final_state, 'activation/cell_state')
        activation_summary(input_gates, 'gates/input')
        activation_summary(forget_gates, 'gates/forget')
        activation_summary(output_gates, 'gates/output')
        return z_gf

    @lazy_property
    def response(self):
        if self._is_train or self._y_crops is not None:
            self._z_gf = self.filter
        else:
            # At test time the filter is computed elsewhere and fed in.
            self._z_gf = tf.placeholder(tf.float32, [config.num_scale, 6, 6, 256])

        # build x_net for test image
        if config.share_param:
            with tf.variable_scope('z_net', reuse=True):
                x_net = XZNet({'input': self._x_crops}, self._is_train)
        else:
            with tf.variable_scope('x_net'):
                x_net = XZNet({'input': self._x_crops}, self._is_train)

        # convolve filter with test image
        x_output = x_net.get_output()
        conv_xz = ConvXZNet({'z_gf': self._z_gf, 'x_output': x_output}, self._is_train)
        activation_summary(x_output, 'activation/x_output')
        return conv_xz.get_output()

    @lazy_property
    def loss(self):
        response = self.response
        labels, weights = create_labels_overlap(
            np.array([self._response_size, self._response_size]), self._y_crops)
        labels = tf.reshape(labels, [-1])
        weights = tf.reshape(weights, [-1])
        response = tf.reshape(response, [-1])
        # Positions labeled -1 are excluded from the loss.
        keep = tf.where(tf.not_equal(labels, -1))[:, 0]
        logits = tf.gather(response, keep)
        labels = tf.gather(labels, keep)
        weights = tf.gather(weights, keep)
        loss = tf.reduce_sum(weights * tf.nn.sigmoid_cross_entropy_with_logits(
            logits=logits, labels=labels)) / (self._batch_size * self._time_steps)
        if self._is_train:
            tf.summary.scalar('loss/cross_entropy', loss)
        return loss

    @lazy_property
    def dist_error(self):
        response = self.response
        response_shape = response.get_shape().as_list()
        # Mean distance (in response-map pixels) between the response peak
        # and the ground-truth center.
        max_idx = tf.argmax(tf.reshape(response, [response_shape[0], -1]), 1)
        esti_pos = tf.cast(tf.stack([max_idx % self._response_size,
                                     max_idx // self._response_size], 1), tf.float32)
        dist_error = tf.reduce_mean(
            tf.sqrt(tf.reduce_sum(tf.square(esti_pos - tf.expand_dims(self._gt_pos, 0)), 1)))
        if self._is_train:
            tf.summary.scalar('loss/dist_error', dist_error)
        return dist_error

    @lazy_property
    def optimize(self):
        tvars = tf.trainable_variables()
        grads, _ = tf.clip_by_global_norm(tf.gradients(self.loss, tvars), config.max_grad_norm)
        # optimizer = tf.train.GradientDescentOptimizer(self.lr)
        optimizer = tf.train.AdamOptimizer(self._lr)
        apply_gradient_op = optimizer.apply_gradients(zip(grads, tvars), self._global_step)
        # Group the batch-norm moving-average updates with the gradient step.
        batchnorm_updates = tf.get_collection(UPDATE_OPS_COLLECTION)
        batchnorm_updates_op = tf.group(*batchnorm_updates)
        train_op = tf.group(apply_gradient_op, batchnorm_updates_op)
        return train_op

    @property
    def saver(self):
        return self._saver

    @property
    def global_step(self):
        return self._global_step

    @property
    def lr(self):
        return self._lr

    @property
    def summary(self):
        return self._summary

    @property
    def initial_state(self):
        return self._initial_state

    @property
    def final_state(self):
        return self._final_state

    @property
    def z_gf(self):
        return self._z_gf
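

# Minimal usage sketch (illustrative only, not part of the original module):
# builds the training graph from placeholders. config.batch_size,
# config.time_steps, and config.x_instance_size are assumed names; substitute
# whatever the repo's config actually defines. The y_crops layout (one box per
# frame) is likewise an assumption about what create_labels_overlap expects;
# the real pipeline feeds batched sequence crops from data_input.
if __name__ == '__main__':
    z = tf.placeholder(tf.float32, [config.batch_size, config.time_steps,
                                    config.z_exemplar_size, config.z_exemplar_size, 3])
    x = tf.placeholder(tf.float32, [config.batch_size, config.time_steps,
                                    config.x_instance_size, config.x_instance_size, 3])
    y = tf.placeholder(tf.float32, [config.batch_size, config.time_steps, 4])
    model = RFLNet(True, z_examplar=z, x_crops=x, y_crops=y)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # One training step would then look like:
        # sess.run([model.optimize, model.loss], feed_dict={z: ..., x: ..., y: ...})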