import tensorflow as tf
from keras import Input, Model, backend as K
from keras.layers import Dense, Concatenate, Permute
from keras.layers import LSTM
from keras.layers import Lambda, Reshape
from keras.optimizers import RMSprop

from general_utils import pxy_dim, out_dim
from load_model_config import ModelConfig
from tf_normal_sampler import normal2d_log_pdf
from tf_normal_sampler import normal2d_sample
from grid import tf_grid_mask
from general_utils import get_image_size


class MySocialModel:
    """Social LSTM assembled from per-frame, per-pedestrian Keras layers.

    The embedding layers (``W_e_relu``, ``W_a_relu``), the LSTM and the output
    layer ``W_p`` are shared by every pedestrian and every frame. Construction
    builds two models, both taking ``[x_input, grid_input, zeros_input]``:

    * ``train_model`` outputs the predicted distribution parameters,
      shaped (batch_size, pred_len, max_n_peds, out_dim), and is compiled
      with RMSprop and :meth:`_compute_loss`.
    * ``sample_model`` outputs the sampled prediction frames, shaped
      (batch_size, pred_len, max_n_peds, ...), where each row is the
      pedestrian id followed by the sampled position.
    """

    def __init__(self, config: ModelConfig) -> None:
        # Observed frames. Column 0 of the last axis is the pedestrian id
        # (0 = no pedestrian); the remaining columns are the position.
        self.x_input = Input((config.obs_len, config.max_n_peds, pxy_dim))
        # Per-frame grid masks indexed (self_pid, other_pid, grid_index).
        self.grid_input = Input(
            (config.obs_len, config.max_n_peds, config.max_n_peds,
             config.grid_side_squared))
        # Only frame 0 is read; it supplies the initial hidden and cell states.
        self.zeros_input = Input(
            (config.obs_len, config.max_n_peds, config.lstm_state_dim))

        # Social LSTM layers, shared across pedestrians and frames.
        self.lstm_layer = LSTM(config.lstm_state_dim, return_state=True)
        self.W_e_relu = Dense(config.emb_dim, activation="relu")  # position embedding
        self.W_a_relu = Dense(config.emb_dim, activation="relu")  # social-tensor embedding
        self.W_p = Dense(out_dim)  # LSTM output -> distribution parameters

        self._build_model(config)

    def _compute_loss(self, y_batch, o_batch):
        """Negative 2D-normal log-likelihood of the existing pedestrians.

        :param y_batch: (batch_size, pred_len, max_n_peds, pxy_dim); column 0
            is the pedestrian id, the rest is the ground-truth position.
        :param o_batch: (batch_size, pred_len, max_n_peds, out_dim)
            distribution parameters.
        :return: 1-D tensor with one non-negative loss per row whose
            pedestrian id is not 0.
        """
        not_exist_pid = 0

        y = tf.reshape(y_batch, (-1, pxy_dim))
        o = tf.reshape(o_batch, (-1, out_dim))

        # Keep only the rows that belong to an existing pedestrian.
        pids = y[:, 0]
        exist_rows = tf.not_equal(pids, not_exist_pid)
        y_exist = tf.boolean_mask(y, exist_rows)
        o_exist = tf.boolean_mask(o, exist_rows)
        pos_exist = y_exist[:, 1:]

        # Log-density of the true position under the predicted parameters.
        log_prob_exist = normal2d_log_pdf(o_exist, pos_exist)
        # For numerical stability. NOTE(review): clamping at 0 also zeroes the
        # gradient wherever the density exceeds 1 -- confirm that is intended.
        log_prob_exist = tf.minimum(log_prob_exist, 0.0)

        loss = -log_prob_exist
        return loss

    def _compute_social_tensor(self, grid_t, prev_h_t, config):
        """Compute the social tensor $H_t$ for every pedestrian.

        This mirrors ``getSocialTensor()``: for pedestrian ``i``, each grid
        cell receives the grid-weighted sum of the other pedestrians'
        previous hidden states.

        :param grid_t: (batch_size, max_n_peds, max_n_peds, grid_side ** 2),
            indexed (batch_index, self_pid, other_pid, grid_index).
        :param prev_h_t: (batch_size, max_n_peds, lstm_state_dim)
        :return: H_t (batch_size, max_n_peds, (grid_side ** 2) * lstm_state_dim)
        """
        H_t = []
        for i in range(config.max_n_peds):
            # (batch_size, max_n_peds, grid_side ** 2): pedestrian i's grid.
            grid_it = _take_axis1(grid_t, i)
            # -> (batch_size, grid_side ** 2, max_n_peds)
            grid_it_T = Permute((2, 1))(grid_it)
            # (batch_size, grid_side ** 2, lstm_state_dim)
            H_it = Lambda(lambda x: K.batch_dot(x[0], x[1]))(
                [grid_it_T, prev_h_t])
            H_t.append(H_it)

        # list of (batch_size, grid_side_squared, lstm_state_dim)
        # -> (batch_size, max_n_peds, grid_side_squared, lstm_state_dim)
        H_t = _stack_permute_axis_zero(H_t)
        # -> (batch_size, max_n_peds, grid_side_squared * lstm_state_dim)
        H_t = Reshape(
            (config.max_n_peds,
             config.grid_side_squared * config.lstm_state_dim))(H_t)
        return H_t

    def _social_lstm_frame(self, pos_t, grid_t, prev_h_t, prev_c_t, config, t):
        """Advance every pedestrian's LSTM by one frame.

        :param pos_t: (batch_size, max_n_peds, n_coords) positions fed to the
            position embedding.
        :param grid_t: (batch_size, max_n_peds, max_n_peds, grid_side ** 2)
        :param prev_h_t: (batch_size, max_n_peds, lstm_state_dim)
        :param prev_c_t: (batch_size, max_n_peds, lstm_state_dim)
        :param t: frame index, used only for the progress print.
        :return: (h_t, c_t, o_t); h_t and c_t are
            (batch_size, max_n_peds, lstm_state_dim) and o_t is
            (batch_size, max_n_peds, out_dim).
        """
        # (batch_size, max_n_peds, (grid_side ** 2) * lstm_state_dim)
        H_t = self._compute_social_tensor(grid_t, prev_h_t, config)

        h_t, c_t, o_t = [], [], []
        for i in range(config.max_n_peds):
            print("(t, li):", t, i)
            # Embeddings e_i^t (position) and a_i^t (social tensor),
            # each (batch_size, emb_dim).
            e_it = self.W_e_relu(_take_axis1(pos_t, i))
            a_it = self.W_a_relu(_take_axis1(H_t, i))

            # LSTM input: (batch_size, 1, 2 * emb_dim)
            emb_it = Concatenate()([e_it, a_it])
            emb_it = Reshape((1, 2 * config.emb_dim))(emb_it)

            # The previous frame's (h_i, c_i) is this frame's initial state.
            # The states are sliced through Lambda layers so they stay
            # Keras tensors connected to the model graph.
            initial_state = [_take_axis1(prev_h_t, i), _take_axis1(prev_c_t, i)]
            lstm_output, h_it, c_it = self.lstm_layer(
                emb_it, initial_state=initial_state)

            h_t.append(h_it)
            c_t.append(c_it)
            # (batch_size, out_dim)
            o_t.append(self.W_p(lstm_output))

        return (_stack_permute_axis_zero(h_t),
                _stack_permute_axis_zero(c_t),
                _stack_permute_axis_zero(o_t))

    def _build_model(self, config: ModelConfig):
        """Unroll the observation and prediction phases and build the models."""
        # ------------------------------------------------------------------
        # Observation
        # ------------------------------------------------------------------
        prev_h_t = _take_axis1(self.zeros_input, 0)
        prev_c_t = _take_axis1(self.zeros_input, 0)

        o_obs_batch = []
        for t in range(config.obs_len):
            print("t: ", t)
            x_t = _take_axis1(self.x_input, t)
            grid_t = _take_axis1(self.grid_input, t)
            # Drop the pedestrian-id column, keeping only the position.
            pos_t = Lambda(lambda x: x[:, :, 1:])(x_t)

            prev_h_t, prev_c_t, o_t = self._social_lstm_frame(
                pos_t, grid_t, prev_h_t, prev_c_t, config, t)
            o_obs_batch.append(o_t)

        # (batch_size, obs_len, max_n_peds, out_dim)
        o_obs_batch = _stack_permute_axis_zero(o_obs_batch)

        # ------------------------------------------------------------------
        # Prediction
        # ------------------------------------------------------------------
        # prev_h_t / prev_c_t now hold the states after the last observed frame.
        x_obs_t_final = _take_axis1(self.x_input, -1)
        # (batch_size, max_n_peds, 1): pedestrian ids of the last observed frame.
        pid_obs_t_final = Lambda(lambda x: K.expand_dims(x[:, :, 0], 2))(
            x_obs_t_final)

        prev_o_t = _take_axis1(o_obs_batch, -1)
        x_pred_batch = []
        o_pred_batch = []
        for t in range(config.pred_len):
            # Sample every pedestrian's position once per frame. The same
            # sample feeds the grid, the LSTM input and the recorded output.
            pred_pos_t = normal2d_sample(prev_o_t)

            # Assume every pedestrian in the final observed frame also exists
            # in all prediction frames.
            x_pred_t = Concatenate(axis=2)([pid_obs_t_final, pred_pos_t])
            grid_t = tf_grid_mask(x_pred_t,
                                  get_image_size(config.test_dataset_kind),
                                  config.n_neighbor_pixels, config.grid_side)

            prev_h_t, prev_c_t, prev_o_t = self._social_lstm_frame(
                pred_pos_t, grid_t, prev_h_t, prev_c_t, config, t)

            o_pred_batch.append(prev_o_t)
            x_pred_batch.append(x_pred_t)

        o_pred_batch = _stack_permute_axis_zero(o_pred_batch)
        x_pred_batch = _stack_permute_axis_zero(x_pred_batch)

        # o_concat_batch = Lambda(lambda os: tf.concat(os, axis=1))(
        #     [o_obs_batch, o_pred_batch])

        # Author's note (translated): "this should be the model actually
        # needed for training".
        self.train_model = Model(
            [self.x_input, self.grid_input, self.zeros_input], o_pred_batch
        )
        lr = 0.003
        optimizer = RMSprop(lr=lr)
        self.train_model.compile(optimizer, self._compute_loss)

        self.sample_model = Model(
            [self.x_input, self.grid_input, self.zeros_input], x_pred_batch
        )


def _take_axis1(tensor, index):
    """Return ``tensor[:, index]`` as the output of a Keras ``Lambda`` layer.

    ``index`` is passed through ``arguments`` instead of being captured by a
    closure. A closure over a loop variable reads the variable's final value
    whenever Keras re-runs the layer (e.g. when the model is called on new
    inputs) or serializes it, which would make every slice pick the last index.
    """
    return Lambda(lambda x, idx: x[:, idx], arguments={"idx": index})(tensor)


def _stack_permute_axis_zero(xs):
    """Stack a list of tensors and move the list axis to axis 1.

    A list of N tensors shaped (batch_size, ...) becomes one tensor shaped
    (batch_size, N, ...).
    """
    xs = Lambda(lambda xs: K.stack(xs, axis=0))(xs)
    # Swap axes 0 and 1; keep the remaining axes in order.
    perm = [1, 0] + list(range(2, xs.shape.ndims))
    xs = Lambda(lambda x, perm: K.permute_dimensions(x, perm),
                arguments={"perm": perm})(xs)
    return xs