import tensorflow as tf
import numpy as np

from config import load_args
from lip_model.losses import cer
from lip_model.modules import embedding, sinusoid_encoding, multihead_attention, \
                              feedforward, label_smoothing
from lip_model.visual_frontend import VisualFrontend
from util.tf_util import shape_list

config = load_args()


class TransformerTrainGraph:

  def __init__(self, x, y, is_training=True, reuse=None, embed_input=False,
               go_token_index=2, chars=None):

    self.is_training = is_training
    self.x = x

    if config.featurizer:
      vid_inp = x[0] if isinstance(x, (tuple, list)) else x
      # Frames that are all zeros are padding: keep a (N, T) mask of the real frames
      istarget = tf.not_equal(vid_inp, 0)
      self.padding_mask = tf.to_float(tf.reduce_any(istarget, axis=[2, 3, 4]))

      with tf.variable_scope('visual_frontend', reuse=reuse):
        self.visual_frontend = VisualFrontend(vid_inp)
        vid_inp = self.visual_frontend.output

      # Zero out the features of the padded frames again
      vid_inp = vid_inp * tf.expand_dims(self.padding_mask, -1)

      if isinstance(x, (tuple, list)):
        x = [vid_inp] + list(x[1:])
      else:
        x = vid_inp

    if is_training:
      self.prev = y
      self.y = y
    else:
      # This is the partial prediction used for the autoregression -
      # augmented by one more element on every step when autoregression is on
      self.prev = y[0]
      self.y = y[1]  # This is the whole ground truth transcription

    self.alignment_history = {}  # to be filled in by the decoder

    self.go_token_idx = go_token_index

    # Define the decoder inputs: prepend the go token (<S>, index 2 by default)
    # and drop the last symbol, i.e. shift the targets right by one
    self.decoder_inputs = tf.concat(
        (tf.ones_like(self.prev[:, :1]) * go_token_index, self.prev[:, :-1]), -1)

    # Encoder
    self.enc = x
    with tf.variable_scope("encoder", reuse=reuse):
      self.enc = self.encoder_body(self.enc, is_training)

    # Decoder
    self.dec = self.decoder_inputs

    top_scope = tf.get_variable_scope()  # this is a hack to be able to use the same model

    self.chars = chars  # needed for decoding with an external LM

    # --------------- index-to-char lookup table for summaries --------------------------
    if chars is not None:
      keys = tf.constant(np.arange(len(chars)), dtype=tf.int64)
      values = tf.constant(chars, dtype=tf.string)
      self.char_table = tf.contrib.lookup.HashTable(
          tf.contrib.lookup.KeyValueTensorInitializer(keys, values), '')

    with tf.variable_scope("decoder", reuse=reuse):
      self.dec = self.decoder_body(self.enc, self.dec, is_training, top_scope=top_scope)
      if isinstance(self.dec, tuple):
        self.preds, self.scores, self.dec = self.dec  # Inference graph output

    self.add_loss_and_metrics(reuse, is_training)

    if config.tb_eval:
      self.add_tb_summaries()

    self.tb_sum = tf.summary.merge_all()

  def project_output(self):
    return True

  def decoder_body(self, enc, dec, is_training, top_scope=None):
    # Initialize the masks for the pads from here,
    # because after positional embeddings are added, nothing will be 0
    if config.mask_pads:  # Guard this for backwards compatibility
      key_masks_enc = tf.sign(tf.abs(tf.reduce_sum(enc, axis=-1)))  # (N, T_k)
      key_masks_dec = tf.cast(tf.sign(tf.abs(dec)), 'float32')      # (N, T_k)
      query_masks_dec = tf.cast(tf.sign(tf.abs(dec)), 'float32')    # (N, T_k)
    else:
      key_masks_enc = key_masks_dec = query_masks_dec = None

    ## Embedding
    dec = self.decoder_embeddings(dec, is_training)

    for i in range(config.num_blocks):
      with tf.variable_scope("num_blocks_{}".format(i)):
        ## Masked self-attention (causal: every step attends only to earlier steps)
        dec, alignments = multihead_attention(
            queries=dec,
            query_masks=query_masks_dec,
            keys=dec,
            key_masks=key_masks_dec,
            num_units=config.hidden_units,
            num_heads=config.num_heads,
            dropout_rate=config.dropout_rate,
            is_training=is_training,
            causality=True,
scope="self_attention") # self.alignment_history["dec_self_att_{}".format(i)] = alignmets # save for tb ## vanilla attention dec, alignmets = multihead_attention(queries=dec, query_masks=query_masks_dec, keys=enc, key_masks=key_masks_enc, num_units=config.hidden_units, num_heads=config.num_heads, dropout_rate=config.dropout_rate, is_training=is_training, causality=False, scope="vanilla_attention") self.alignment_history["enc_dec_attention_{}".format(i)] = alignmets # save for tb ## Feed Forward dec = feedforward(dec, num_units=[4 * config.hidden_units, config.hidden_units]) return dec def decoder_embeddings(self, decoder_inputs, is_training): dec = embedding(decoder_inputs, vocab_size=config.n_labels, num_units=config.hidden_units, scale=True, scope="dec_embed") # if self.is_training: # dec = dec[:,:self.out_last_non_pad_idx] ## Positional Encoding pos = self.positional_encoding(decoder_inputs, scope='dec_pe') # if self.is_training: # pos = pos[:,:self.out_last_non_pad_idx] dec += pos ## Dropout dec = tf.layers.dropout(dec, rate=config.dropout_rate, training=tf.convert_to_tensor(is_training)) return dec def positional_encoding(self, inp, scope): if config.sinusoid: return sinusoid_encoding(inp, num_units=config.hidden_units, zero_pad=False, scale=False, scope=scope, T = config.maxlen ) else: return embedding( tf.tile(tf.expand_dims(tf.range(tf.shape(inp)[1]), 0), [tf.shape(inp)[0], 1]), vocab_size=config.maxlen, num_units=config.hidden_units, zero_pad=False, scale=False, scope="dec_pe") def encoder_body(self, enc, is_training): num_blocks = config.num_blocks if config.mask_pads: # Guard this for backwards compatibility # Initialize the masks for the pads from here, # because after positional embeddings are added, nothing will be 0 key_masks = tf.sign(tf.abs(tf.reduce_sum(enc, axis=-1))) # (N, T_k) query_masks = tf.sign(tf.abs(tf.reduce_sum(enc, axis=-1))) # (N, T_k) else: key_masks = query_masks = None enc = self.encoder_embeddings(enc, is_training) for i in range(num_blocks): with tf.variable_scope("num_blocks_{}".format(i)): ### Multihead Attention enc, alignmets = multihead_attention(queries=enc, query_masks=query_masks, keys=enc, key_masks=key_masks, num_units=config.hidden_units, num_heads=config.num_heads, dropout_rate=config.dropout_rate, is_training=is_training, causality=False) # key_masks = query_masks = None # ### Feed Forward enc = feedforward(enc, num_units=[4 * config.hidden_units, config.hidden_units]) # self.alignment_history["enc_self_att_{}".format(i)] = alignmets # save for tb return enc def encoder_embeddings(self, x, is_training, embed_input=0): # Embedding if embed_input: enc = embedding(x, vocab_size=config.n_input_vocab, num_units=config.hidden_units, scale=True, scope="enc_embed") else: enc = x ## Positional Encoding feat_dim = shape_list(enc)[-1] # if input features are not same size as transformer units, make a linear projection if not feat_dim == config.hidden_units: enc = tf.layers.dense(enc, config.hidden_units) enc += self.positional_encoding(enc, scope='enc_pe') if embed_input: ## Dropout enc = tf.layers.dropout(enc, rate=config.dropout_rate, training=tf.convert_to_tensor(is_training)) return enc def add_loss_and_metrics(self, reuse, is_training): # Final linear projection if self.project_output(): # with tf.variable_scope("ctc_conv1d_net/ctc_probs", reuse=reuse) as scope: self.logits = tf.layers.dense(self.dec, config.n_labels, reuse=reuse) else: assert(self.dec.get_shape()[-1].value == config.n_labels) self.logits = self.dec if config.test_aug_times: 
      # Average the logits over the test-time augmentation copies
      self.logits_aug = self.logits
      self.logits = tf.reduce_mean(self.logits, 0, keep_dims=True)

    self.istarget = tf.to_float(tf.not_equal(self.y, 0))

    # Loss
    self.y_one_hot = tf.one_hot(self.y, depth=config.n_labels)
    self.y_smoothed = label_smoothing(self.y_one_hot)
    self.loss = tf.nn.softmax_cross_entropy_with_logits(logits=self.logits,
                                                        labels=self.y_smoothed)
    # We want to know when to stop, so learn the padding as well
    self.mean_loss = tf.reduce_sum(self.loss) / tf.reduce_sum(self.istarget)

    self.logprobs = tf.log(tf.nn.softmax(self.logits))

    if 'infer' not in config.graph_type:
      self.preds = tf.to_int32(tf.argmax(self.logits, axis=-1))
      self.cer, self.cer_per_sample = cer(self.y_one_hot, self.logits, return_all=True)
    else:
      self.preds = tf.to_int32(self.preds)
      one_hot_from_preds = tf.one_hot(self.preds, depth=config.n_labels)
      self.cer, self.cer_per_sample = cer(self.y_one_hot, one_hot_from_preds,
                                          return_all=True)

  def add_tb_summaries(self):
    from util.tb_util import add_gif_summary, colorize_image

    fps = 10
    timeline = False

    # ---------------- Add video summaries -------------------------------
    bs = int(self.visual_frontend.output.shape[0])
    b_id = 0
    non_pad_inds = tf.cast(tf.where(self.padding_mask[b_id] > 0)[:, 0], tf.int64)
    fr_in, to_in = non_pad_inds[0], non_pad_inds[-1] + 1  # For masking out input paddings
    add_gif_summary('1-video_input',
                    self.visual_frontend.input[b_id][fr_in:to_in],
                    fps=fps,
                    timeline=timeline)

    if not config.test_aug_times:
      add_gif_summary('2-input_to_resnet',
                      self.visual_frontend.aug_out[b_id][fr_in:to_in],
                      fps=fps,
                      timeline=timeline)
    else:
      # Visualize the different test-time augmentations side by side
      add_gif_summary('2-input_to_resnet',
                      tf.concat([self.visual_frontend.aug_out[b_id][fr_in:to_in]
                                 for b_id in range(bs)], axis=2),
                      fps=fps,
                      timeline=timeline)

    # ---------------- Add text summaries -------------------------------
    pred_strings_tf = self.char_table.lookup(tf.cast(self.preds, tf.int64))
    joined_pred = tf.string_join(tf.split(pred_strings_tf, pred_strings_tf.shape[1], 1))[:, 0]
    gt_strings_tf = self.char_table.lookup(tf.cast(self.y, tf.int64))
    joined_gt = tf.string_join(tf.split(gt_strings_tf, gt_strings_tf.shape[1], 1))[:, 0]
    joined_all = tf.string_join([joined_gt, joined_pred], ' --> ')
    tf.summary.text('Predictions', joined_all)

    # ---------------- Add image summaries -------------------------------
    all_atts = []
    for layer_name, alignment_history in self.alignment_history.items():
      for att_head_idx, attention_images in enumerate(alignment_history):
        all_atts.append(attention_images)
    # Geometric mean of the attention maps over all layers and heads
    avg_att = tf.exp(tf.reduce_mean(tf.log(all_atts), axis=0))

    # Permute and reshape (batch, t_dec, t_enc) --> (batch, t_enc, t_dec, 1)
    attention_img = tf.expand_dims(tf.transpose(avg_att, [0, 2, 1]), -1)
    attention_img *= 255  # Scale to range [0, 255]

    b_id = 0  # visualize only the first sample of the batch
    to_out = tf.where(self.preds[b_id] > 0)[-1][0] + 1  # To mask output paddings
    color_img = tf.map_fn(colorize_image, attention_img[:, fr_in:to_in, :to_out])
    tf.summary.image("3-enc_dec_attention", color_img)

    # ---------------- Add image-with-subs summaries -------------------------------
    add_gif_summary('4-subs',
                    self.visual_frontend.input[b_id][fr_in:to_in],
                    fps=fps,
                    timeline=timeline,
                    attention=attention_img[b_id][fr_in:to_in, :to_out, 0],
                    pred=joined_pred[b_id])

  @classmethod
  def get_input_shapes_and_types(cls, batch=0):
    input_types = []
    input_shape = []
    if config.featurizer:
      input_shape += [(config.time_dim,
                       config.img_width,
                       config.img_height,
                       config.img_channels)]
      input_types += ['float32']
    else:
      input_shape += [(config.time_dim, config.feat_dim)]
      input_types += ['float32']
    if batch:
      input_shape = [(config.batch_size,) + shape for shape in input_shape]
    return input_shape, input_types

  @classmethod
  def get_target_shapes_and_types(cls, batch=0):
    target_shape = [(config.time_dim,)]
    if batch:
      target_shape = [(config.batch_size,) + shape for shape in target_shape]
    target_types = ['int64']
    return target_shape, target_types

  @classmethod
  def get_model_input_target_shapes_and_types(cls, batch_dims=1):
    return cls.get_input_shapes_and_types(batch=batch_dims), \
           cls.get_target_shapes_and_types(batch=batch_dims)
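

# --------------------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). It builds the training graph
# with placeholder inputs whose shapes and dtypes come from the classmethod helpers
# above. The exact placeholder shapes depend on the flags returned by load_args()
# (e.g. featurizer, time_dim, batch_size), so treat this as an illustrative assumption
# rather than a supported entry point.
# --------------------------------------------------------------------------------------
if __name__ == '__main__':
  (in_shapes, in_types), (tgt_shapes, tgt_types) = \
      TransformerTrainGraph.get_model_input_target_shapes_and_types(batch_dims=1)

  # One placeholder per declared model input / target
  x = tf.placeholder(in_types[0], shape=in_shapes[0], name='video_input')
  y = tf.placeholder(tgt_types[0], shape=tgt_shapes[0], name='transcription')

  g = TransformerTrainGraph(x, y, is_training=True)
  print('mean loss tensor:', g.mean_loss)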