########################################################################################
#
# Hierarchical Attentive Recurrent Tracking
# Copyright (C) 2017 Adam R. Kosiorek, Oxford Robotics Institute, University of Oxford
# email:   adamk@robots.ox.ac.uk
# webpage: http://ori.ox.ac.uk
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
########################################################################################

import numpy as np
import tensorflow as tf

from tensorflow.contrib.rnn.python.ops.core_rnn_cell_impl import LSTMCell
from tensorflow.python.ops.rnn_cell_impl import _RNNCell as RNNCell
from tensorflow.python.util import nest

from hart.model import tensor_ops
from hart.model.nn import DynamicFilterModel
from hart.model.rnn import ZoneoutWrapper, IdentityLSTMCell
from neurocity.component.layer import AffineLayer, ConvLayer
from neurocity.tensor_ops import convert_shape


def gaussian_mask(params, R, C):
    """Defines a batch of (C, R) masks in which each of the R columns is a 1-D
    Gaussian over the C input positions; the r-th Gaussian is centred at u + r * d.

    :param params: tensor whose last axis holds (u, s, d) with u - mean, s - std,
        d - stride; each must be a 1-dimensional (batch-sized) vector
    :param R: number of Gaussians (glimpse dimension)
    :param C: number of positions covered by each Gaussian (input dimension)
    :return: tensor of shape (batch_size, C, R)
    """
    u, s, d = (params[..., i] for i in range(3))
    for i in (u, s, d):
        assert len(i.get_shape()) == 1, i

    batch_size = tf.to_int32(tf.shape(u)[0])
    R = tf.range(tf.to_int32(R))
    C = tf.range(tf.to_int32(C))
    R = tf.to_float(R)[tf.newaxis, tf.newaxis, :]
    C = tf.to_float(C)[tf.newaxis, :, tf.newaxis]
    C = tf.tile(C, (batch_size, 1, 1))

    u, d = u[:, tf.newaxis, tf.newaxis], d[:, tf.newaxis, tf.newaxis]
    s = s[:, tf.newaxis, tf.newaxis]

    # centres of the R Gaussians, one per output column
    ur = u + R * d
    sr = tf.ones_like(ur) * s

    mask = C - ur
    mask = tf.exp(-.5 * (mask / sr) ** 2)
    # normalise every Gaussian over the input dimension
    mask /= tf.reduce_sum(mask, 1, keep_dims=True) + 1e-8
    return mask
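
# Illustrative sketch, not part of the original module: builds a single mask and
# checks its shape; all numbers below are hypothetical.
def _demo_gaussian_mask():
    params = tf.constant([[1., 1.5, 2.]])       # one example: u=1, s=1.5, d=2
    mask = gaussian_mask(params, R=4, C=10)     # -> (1, 10, 4)
    with tf.Session() as sess:
        m = sess.run(mask)
    assert m.shape == (1, 10, 4)
    # each of the 4 columns sums to ~1 over the 10 input positions
    return m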

def extract_glimpse(inpt, attention_params, glimpse_size):
    """Extracts an attention glimpse

    :param inpt: tensor of shape == (batch_size, img_height, img_width) or
        (batch_size, img_height, img_width, n_channels)
    :param attention_params: tensor of shape == (batch_size, 6) as
        [uy, ux, sy, sx, dy, dx] with u - mean, s - std, d - stride
    :param glimpse_size: tuple of ints as (height, width, n_channels),
        size of the extracted glimpse
    :return: tensor of shape == (batch_size,) + glimpse_size
    """
    ap = attention_params
    shape = inpt.get_shape()
    rank = len(shape)
    assert rank in (3, 4), "Input must be 3 or 4 dimensional tensor"

    inpt_H, inpt_W = shape[1:3]
    if rank == 3:
        inpt = inpt[..., tf.newaxis]
        rank += 1

    Fy = gaussian_mask(ap[..., 0::2], glimpse_size[0], inpt_H)
    Fx = gaussian_mask(ap[..., 1::2], glimpse_size[1], inpt_W)

    gs = []
    for channel in tf.unstack(inpt, axis=rank - 1):
        g = tf.matmul(tf.matmul(Fy, channel, adjoint_a=True), Fx)
        gs.append(g)
    g = tf.stack(gs, axis=rank - 1)

    g.set_shape([shape[0]] + list(glimpse_size))
    return g
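
# Illustrative sketch, not part of the original module: crops a 5x5 glimpse from a
# batch of single-channel images; attention params are in pixels, ordered as
# [uy, ux, sy, sx, dy, dx], and all numbers are hypothetical.
def _demo_extract_glimpse():
    imgs = tf.random_uniform((2, 20, 30, 1))
    att = tf.constant(2 * [[4., 6., 1., 1., 3., 5.]])   # (uy, ux, sy, sx, dy, dx)
    glimpse = extract_glimpse(imgs, att, glimpse_size=(5, 5, 1))
    with tf.Session() as sess:
        return sess.run(glimpse)                        # -> (2, 5, 5, 1)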

class Attention(object):
    n_params = None

    def __init__(self, inpt_size, glimpse_size):
        self.inpt_size = np.asarray(inpt_size)
        self.glimpse_size = np.asarray(glimpse_size)

    def extract_glimpse(self, inpt, raw_att, return_all=False):
        raw_att_flat = tf.reshape(raw_att, (-1, self.n_params), 'flat_raw_att')
        att_flat = self._to_attention(raw_att_flat)

        shape = raw_att.get_shape().as_list()
        n_glimpses = int(raw_att.get_shape()[-1]) // self.n_params
        att = tf.reshape(att_flat, shape[:-1] + [n_glimpses, int(att_flat.get_shape()[-1])])

        glimpse = []
        for a in tf.unstack(att, axis=1):
            glimpse.append(self._extract_glimpse(inpt, a))

        glimpse = tf.stack(glimpse, 1)
        glimpse = tf.reshape(glimpse, (-1,) + tuple(self.glimpse_size))

        if return_all:
            return raw_att_flat, att_flat, glimpse
        return glimpse

    def attention_to_bbox(self, att):
        with tf.variable_scope('attention_to_bbox'):
            yx = att[..., :2] * self.inpt_size[np.newaxis, :2]
            hw = att[..., 2:4] * (self.inpt_size[np.newaxis, :2] - 1)
            bbox = tf.concat(axis=tf.rank(att) - 1, values=(yx, hw))
            bbox.set_shape(att.get_shape()[:-1].concatenate((4,)))
            return bbox

    def attention_region(self, att):
        return self.attention_to_bbox(att)

    def _extract_glimpse(self, inpt, att_flat):
        return extract_glimpse(inpt, att_flat, self.glimpse_size)


class RATMAttention(Attention):
    """Attention mechanism implemented after https://arxiv.org/abs/1510.08660"""
    n_params = 6

    def bbox_to_attention(self, bbox):
        with tf.variable_scope('ratm_bbox_to_attention'):
            us = bbox[..., :2] / self.inpt_size[np.newaxis, :2]
            ss = 0.5 * bbox[..., 2:] / self.inpt_size[np.newaxis, :2]
            ds = bbox[..., 2:] / (self.inpt_size[np.newaxis, :2] - 1.)
            att = tf.concat(axis=tf.rank(bbox) - 1, values=(us, ss, ds))
            return att

    @staticmethod
    def _to_axis_attention(params, glimpse_dim, inpt_dim):
        u, s, d = (params[..., i] for i in range(RATMAttention.n_params // 2))
        u = u * inpt_dim
        s = (s + 1e-5) * float(inpt_dim) / glimpse_dim
        d = d * float(inpt_dim - 1) / (glimpse_dim - 1)
        return u, s, d

    def _to_attention(self, params):
        (y, x), (u, v) = self.inpt_size[:2], self.glimpse_size[:2]
        uy, sy, dy = self._to_axis_attention(params[..., ::2], u, y)
        ux, sx, dx = self._to_axis_attention(params[..., 1::2], v, x)

        ap = (uy, ux, sy, sx, dy, dx)
        ap = tf.transpose(tf.stack(ap), name='attention')
        assert ap.get_shape()[-1] == self.n_params, 'Invalid attention shape={}!'.format(ap.get_shape())
        return ap
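
# Illustrative sketch, not part of the original module: a RATMAttention for 100x100
# single-channel images and 28x28 glimpses, initialised from a pixel-space bounding
# box; all numbers are hypothetical.
def _demo_ratm_attention():
    att_module = RATMAttention(inpt_size=(100, 100, 1), glimpse_size=(28, 28, 1))
    bbox = tf.constant([[10., 20., 50., 40.]])      # (y, x, h, w) in pixels
    att = att_module.bbox_to_attention(bbox)        # (1, 6) normalised raw params
    with tf.Session() as sess:
        return sess.run(att)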

class FixedStdAttention(Attention):
    """Like RATMAttention, but the std of the Gaussian mask depends directly and
    exclusively on the stride between Gaussians. I used a small neural net to compute
    the std that approximates bicubic interpolation and then fitted a 4th-order
    polynomial surface to the predictions of that net. There is also an additive
    (learnt) bias, in pixels, applied to the upper-left corner of the attention
    window."""
    n_params = 4

    offset_bias = np.asarray([0.00809737, 0.50086582], dtype=np.float32).reshape(1, 2)
    weights = np.asarray([[6.12598441e-01, 9.25613308e-01],
                          [-1.05801568e-02, -2.18224973e-03],
                          [1.32131897e-04, -6.09307166e-06],
                          [-2.87635530e-07, 9.08051012e-08],
                          [1.94529164e-10, -9.47235313e-11],
                          [1.44468477e-04, -1.19733592e-02],
                          [-4.30590720e-06, 7.71485474e-05],
                          [1.05376852e-08, -1.05474865e-07],
                          [-6.49625282e-12, 3.43567810e-11],
                          [9.85685680e-06, 6.57580098e-05],
                          [-1.41991381e-08, -9.46024867e-08],
                          [-9.81123812e-09, -1.56167932e-07],
                          [-3.61024557e-12, 7.68954027e-11],
                          [2.65848501e-11, 4.39530612e-11],
                          [-1.01850187e-11, 9.85289183e-11]], dtype=np.float32)

    def bbox_to_attention(self, bbox):
        with tf.variable_scope('fixed_std_bbox_to_attention'):
            us = bbox[..., :2] / self.inpt_size[np.newaxis, :2]
            ds = bbox[..., 2:] / (self.inpt_size[np.newaxis, :2] - 1.)
            att = tf.concat(axis=tf.rank(bbox) - 1, values=(us, ds))
            att.set_shape(bbox.get_shape()[:-1].concatenate([4]))
            return att

    def _stride_to_std(self, stride):
        shape = convert_shape(stride.get_shape())
        stride_flat = tf.reshape(stride, (-1, shape[-1]))
        y, x = stride_flat[..., 0], stride_flat[..., 1]
        features = [
            tf.ones_like(y),
            y, y ** 2, y ** 3, y ** 4,
            x, x ** 2, x ** 3, x ** 4,
            y * x, y * x ** 2, y ** 2 * x,
            y * x ** 3, y ** 2 * x ** 2, y ** 3 * x
        ]

        features = tf.concat(axis=1, values=[f[..., tf.newaxis] for f in features])
        sigma_flat = tf.matmul(features, self.weights)
        return tf.reshape(sigma_flat, shape)

    def _to_attention(self, raw_att, with_bias=True):
        bbox = FixedStdAttention.attention_to_bbox(self, raw_att)
        us = bbox[..., :2]
        if with_bias:
            us += self.offset_bias

        ds = bbox[..., 2:4] / (self.glimpse_size[np.newaxis, :2] - 1)
        ss = self._stride_to_std(ds)

        ap = tf.concat(axis=tf.rank(raw_att) - 1, values=(us, ss, ds), name='attention')
        ap.set_shape(raw_att.get_shape()[:-1].concatenate((6,)))
        return ap
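
# Illustrative sketch, not part of the original module: evaluates the fitted
# polynomial surface by poking the private _stride_to_std helper; all numbers
# are hypothetical.
def _demo_fixed_std_attention():
    att_module = FixedStdAttention(inpt_size=(100, 100, 1), glimpse_size=(28, 28, 1))
    stride = tf.constant([[2., 3.]])                # (dy, dx) between Gaussians
    std = att_module._stride_to_std(stride)         # (1, 2), one std per axis
    with tf.Session() as sess:
        return sess.run(std)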

class AttentionCell(RNNCell):
    def __init__(self, feature_extractor, n_units, att_gain, glimpse_size, input_size=None,
                 batch_size=None, zoneout_prob=0., attention_module=RATMAttention,
                 normalize_glimpse=False, identity_init=True, debug=False,
                 predict_appearance=False, feature_shape=None, is_training=True):

        assert len(glimpse_size) in (2, 3), 'Invalid size'
        assert input_size is None or len(input_size) == len(glimpse_size), 'Invalid size'

        self.feature_extractor = feature_extractor
        self.n_units = n_units
        self.att_gain = att_gain
        self.glimpse_size = glimpse_size
        self.input_size = input_size
        self.batch_size = batch_size
        self.normalize_glimpse = normalize_glimpse
        self.identity_init = identity_init
        self.debug = debug
        self.predict_appearance = predict_appearance
        self.attention = attention_module(self.input_size, self.glimpse_size)

        if not isinstance(zoneout_prob, (tuple, list)):
            zoneout_prob = (zoneout_prob, 0.)
        self.zoneout_prob = zoneout_prob

        self.cell = self._make_cell(is_training)
        self._rec_init = tf.random_uniform_initializer(-1e-3, 1e-3)

        self._att_size = self.att_size
        self._state_size = (self._att_size, 1, self.cell.state_size)
        self._output_size = (self.cell.output_size, self._att_size, 1)

        if self.debug:
            self._output_size += (np.prod(self.glimpse_size),)

        if self.predict_appearance:
            self._state_size += (self.n_units,)
            self._output_size += (np.prod(feature_shape), 10, 1)

    @property
    def att_size(self):
        return self.attention.n_params

    @property
    def state_size(self):
        return self._state_size

    @property
    def output_size(self):
        return self._output_size

    def zero_state(self, batch_size, dtype, bbox0=None, presence=None, img0=None,
                   transform_features=False, transform_state=False):

        if bbox0 is None:
            return super(AttentionCell, self).zero_state(batch_size, dtype)

        self.att0 = self.attention.bbox_to_attention(bbox0)
        att0 = tf.reshape(self.att0, (-1, self.att_size), 'shape_first_att')
        presence = tf.to_float(tf.reshape(presence, (-1, 1)), 'shape_first_presence') * 1e3
        state = att0, presence

        zero_state = self.cell.zero_state(batch_size, dtype)
        if img0 is not None:
            zero_state, rnn_outputs = self._zero_state(img0, att0, presence, zero_state,
                                                       transform_features, transform_state)

            # learnable bias, in raw attention units, applied to the initial attention
            att_bias = tf.get_variable('att_bias', (1, 1, 1, self.att_size),
                                       initializer=self._rec_init)
            self.att_bias = .25 * tf.nn.tanh(att_bias)

            att0 += tf.reshape(tf.tile(self.att_bias, (1, 1, 1, 1)), (1, -1))
            self.att0 += self.att_bias[0]

            state += (zero_state,)
            if self.predict_appearance:
                rnn_outputs = tf.reshape(rnn_outputs, (-1, self.n_units))
                state += (rnn_outputs,)
        else:
            state += (zero_state,)

        return state

    def _zero_state(self, img, att, presence, state, transform_features, transform_state=False):
        with tf.variable_scope(self.__class__.__name__) as vs:
            features = self.extract_features(img, att)[1]

            if transform_features:
                features_flat = tf.reshape(features, (-1, self.n_units))
                features_flat = AffineLayer(features_flat, self.n_units,
                                            name='init_feature_transform').output
                features = tf.reshape(features_flat, tf.shape(features))

            rnn_outputs, hidden_state = self._propagate(features, state)
            hidden_state = nest.flatten(hidden_state)

            if transform_state:
                for i, hs in enumerate(hidden_state):
                    name = 'init_state_transform_{}'.format(i)
                    hidden_state[i] = AffineLayer(hs, self.n_units, name=name).output

            state = nest.pack_sequence_as(structure=state, flat_sequence=hidden_state)
        self.rnn_vs = vs
        return state, rnn_outputs

    def _make_cell(self, is_training):
        raw_cell = IdentityLSTMCell if self.identity_init else LSTMCell
        if self.zoneout_prob[0] > 0.:
            cell = lambda: ZoneoutWrapper(raw_cell(self.n_units), self.zoneout_prob, is_training)
        else:
            cell = lambda: raw_cell(self.n_units)
        return cell()

    def _propagate(self, inpt, state):
        features = tf.reshape(inpt, (self.batch_size, self.n_units))
        outputs, hidden_state = self.cell(features, state)
        return tf.reshape(outputs, (self.batch_size, 1, self.n_units)), hidden_state
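
    # extract_features crops a glimpse at the current attention, optionally
    # contrast-normalises its colour channels, runs the convolutional feature
    # extractor on it and projects the flattened features (concatenated with the
    # attention parameters) to n_units; given an appearance vector, two Dynamic
    # Filter Network stages predict an object mask that gates the features.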
    def extract_features(self, inpt, raw_att, appearance_vec=None, reuse=False):
        raw_att_flat, att_flat, glimpse_flat = self.attention.extract_glimpse(inpt, raw_att,
                                                                              return_all=True)
        if self.normalize_glimpse:
            # do not normalize depth, only the colour channels
            colour = tensor_ops.normalize_contrast(glimpse_flat[..., :3])
            if glimpse_flat.get_shape()[-1] == 4:
                ax = len(glimpse_flat.get_shape()) - 1
                glimpse_flat = tf.concat(axis=ax, values=(colour, glimpse_flat[..., 3:]))
            else:
                glimpse_flat = colour

        features = self.feature_extractor(glimpse_flat, reuse=reuse)

        def flatten_features(f, name='', more_feats=None):
            with tf.variable_scope(tf.get_variable_scope(), reuse=reuse):
                f_flat = tf.reshape(f, (self.batch_size * 1, -1), 'reshape' + name)
                to_concat = (f_flat, att_flat)
                if more_feats is not None:
                    to_concat += more_feats

                f_flat = tf.concat(axis=1, values=to_concat, name='concat_att' + name)
                return AffineLayer(f_flat, self.n_units, transfer=tf.nn.elu,
                                   name='fc_after_conv' + name)

        if appearance_vec is not None:
            before_features = self.feature_extractor.orig_output
            self.dfn_inpt = DynamicFilterModel(before_features, appearance_vec,
                                               ksize=(1, 1), n_channels=5, name='pre_DFN')
            self.dfn = DynamicFilterModel(self.dfn_inpt.features, appearance_vec,
                                          ksize=(3, 3), n_channels=5)
            dfn = self.dfn.features

            obj_mask_logit = ConvLayer(dfn, (1, 1), 1, transfer=None, name='obj_mask').output
            obj_mask = tf.nn.sigmoid(obj_mask_logit)
            features *= obj_mask
            flat_mask = tf.reshape(obj_mask, (self.batch_size * 1, -1))

        features_flat = flatten_features(features)
        features = tf.reshape(features_flat, (self.batch_size, 1, self.n_units), 'to_features')

        output = raw_att_flat, features, glimpse_flat
        if appearance_vec is not None:
            output += (obj_mask_logit, flat_mask)
        return output

    def __call__(self, inpt, state, scope=None):
        raw_att, presence, hidden_state = state[:3]

        appearance_vec = None
        if self.predict_appearance:
            appearance_vec = tf.reshape(state[3], (self.batch_size * 1, self.n_units))

        if self.batch_size is None:
            self.batch_size = int(inpt.get_shape()[0])

        if self.input_size is None:
            self.input_size = [tf.to_float(i) for i in inpt.get_shape()[-2:]]

        with tf.variable_scope(self.__class__.__name__):
            all_features = self.extract_features(inpt, raw_att, appearance_vec=appearance_vec,
                                                 reuse=True)
            raw_att_flat, features, glimpses = all_features[:3]

            with tf.variable_scope(self.rnn_vs, initializer=self._rec_init, reuse=True):
                rnn_outputs, hidden_state = self._propagate(features, hidden_state)

            # delta-update of the raw attention params
            zero_init = tf.constant_initializer()
            outputs_flat = tf.reshape(rnn_outputs, (-1, self.n_units), 'outputs_flat')

            att_inpt = outputs_flat
            if self.predict_appearance:
                flat_mask = all_features[-1]
                mask_features = AffineLayer(flat_mask, 10, transfer=tf.nn.elu,
                                            name='mask_features')
                att_inpt = tf.concat(axis=1, values=(att_inpt, mask_features))

            att_readout = AffineLayer(att_inpt, self.n_units, transfer=tf.nn.elu,
                                      name='att_readout_1')
            att_diff_flat = AffineLayer(att_readout, self.att_size, transfer=tf.nn.tanh,
                                        weight_init=self._rec_init, bias_init=zero_init,
                                        name='att_readout')

            att_delta_scale = tf.Variable(self.att_gain, name='att_delta_scale')
            new_att_flat = raw_att_flat + tf.nn.sigmoid(att_delta_scale) * att_diff_flat.output

            new_att = tf.reshape(new_att_flat, (-1, self.att_size), 'new_att_shape')
            rnn_outputs = tf.reshape(rnn_outputs, (-1, self.n_units), 'outputs_shape')

        outputs, state = (rnn_outputs, new_att, presence), (new_att, presence, hidden_state)

        if self.debug:
            glimpse_flat = tf.reshape(glimpses, (self.batch_size, -1))
            outputs += (glimpse_flat,)

        if self.predict_appearance:
            rnn_outputs = tf.reshape(rnn_outputs, (self.batch_size, self.n_units))
            state += (rnn_outputs,)

            # concat flat obj_mask to outputs
            flat_obj_mask = tf.reshape(all_features[-2], (self.batch_size, -1))
            flat_mask_features = tf.reshape(mask_features, (self.batch_size, -1))

            # weight decay penalty on the dynamically-predicted DFN weights
            def weight_decay(w):
                l = tf.reshape(w, (self.batch_size, 1, -1))
                return tf.reduce_sum(l ** 2, (1, 2))[..., tf.newaxis] / (2 * 1)

            dynamic_weights = (self.dfn.dynamic_weights, self.dfn.dynamic_bias,
                               self.dfn_inpt.dynamic_weights, self.dfn_inpt.dynamic_bias)
            dfn_weight_decay = sum(weight_decay(i) for i in dynamic_weights)
            outputs += (flat_obj_mask, flat_mask_features, dfn_weight_decay)

        return outputs, state
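
# Schematic end-to-end sketch, not part of the original module: it mirrors how
# hart.model.tracker unrolls the cell with tf.nn.dynamic_rnn, but the shapes, the
# one-layer stand-in feature extractor and all numbers below are hypothetical, and
# predict_appearance/debug are left off for brevity.
def _demo_attention_cell():
    n_timesteps, batch_size = 5, 2
    inpt_size, glimpse_size = (60, 80, 3), (20, 20, 3)

    imgs = tf.placeholder(tf.float32, (n_timesteps, batch_size) + inpt_size, 'imgs')
    bbox0 = tf.placeholder(tf.float32, (batch_size, 4), 'bbox0')    # (y, x, h, w)
    presence0 = tf.ones((batch_size, 1))

    def feature_extractor(glimpse, reuse=False):
        # stand-in for the convnet used by the real tracker
        with tf.variable_scope('features', reuse=reuse):
            return ConvLayer(glimpse, (3, 3), 8, transfer=tf.nn.elu, name='conv').output

    cell = AttentionCell(feature_extractor, n_units=64, att_gain=-3.,
                         glimpse_size=glimpse_size, input_size=inpt_size,
                         batch_size=batch_size, attention_module=FixedStdAttention)

    # initial state from the first frame and bounding box; passing the current
    # variable scope to dynamic_rnn keeps variable reuse consistent between
    # zero_state and the per-step __call__
    state0 = cell.zero_state(batch_size, tf.float32, bbox0, presence0, imgs[0])
    outputs, state = tf.nn.dynamic_rnn(cell, imgs, initial_state=state0,
                                       time_major=True, scope=tf.get_variable_scope())
    return outputs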