# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model architecture for predictive model, including CDNA, DNA, and STP."""

import itertools

import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.contrib.layers.python import layers as tf_layers
from tensorflow.contrib.slim import add_arg_scope
from tensorflow.contrib.slim import layers

from video_prediction.models import VideoPredictionModel

# Amount to use when lower bounding tensors
RELU_SHIFT = 1e-12


@add_arg_scope
def basic_conv_lstm_cell(inputs,
                         state,
                         num_channels,
                         filter_size=5,
                         forget_bias=1.0,
                         scope=None,
                         reuse=None,
                         ):
    """Basic LSTM recurrent network cell, with 2D convolution connections.

    We add forget_bias (default: 1) to the biases of the forget gate in order to
    reduce the scale of forgetting in the beginning of the training.

    It does not allow cell clipping, a projection layer, and does not
    use peep-hole connections: it is the basic baseline.

    Args:
        inputs: input Tensor, 4D, batch x height x width x channels.
        state: state Tensor, 4D, batch x height x width x channels.
        num_channels: the number of output channels in the layer.
        filter_size: the shape of the each convolution filter.
        forget_bias: the initial value of the forget biases.
        scope: Optional scope for variable_scope.
        reuse: whether or not the layer and the variables should be reused.

    Returns:
        a tuple of tensors representing output and the new state.
    """
    # Lazily initialize the state to zeros; the state tensor packs (c, h)
    # along the channel axis, hence 2 * num_channels.
    if state is None:
        state = tf.zeros(inputs.get_shape().as_list()[:3] + [2 * num_channels], name='init_state')

    with tf.variable_scope(scope, 'BasicConvLstmCell', [inputs, state], reuse=reuse):
        inputs.get_shape().assert_has_rank(4)
        state.get_shape().assert_has_rank(4)
        # Split the packed state into cell state c and hidden state h.
        c, h = tf.split(axis=3, num_or_size_splits=2, value=state)
        inputs_h = tf.concat(values=[inputs, h], axis=3)
        # Parameters of gates are concatenated into one conv for efficiency.
        i_j_f_o = layers.conv2d(inputs_h,
                                4 * num_channels, [filter_size, filter_size],
                                stride=1,
                                activation_fn=None,
                                scope='Gates',
                                )

        # i = input_gate, j = new_input, f = forget_gate, o = output_gate
        i, j, f, o = tf.split(value=i_j_f_o, num_or_size_splits=4, axis=3)

        new_c = c * tf.sigmoid(f + forget_bias) + tf.sigmoid(i) * tf.tanh(j)
        new_h = tf.tanh(new_c) * tf.sigmoid(o)

        # Output is the new hidden state; the packed (c, h) pair is the new state.
        return new_h, tf.concat(values=[new_c, new_h], axis=3)


class Prediction_Model(object):
    """Builds the SNA prediction graph for one of the CDNA/DNA/STP variants.

    The transformation variant is selected via conf['model']; generated
    frames, states, masks, and pixel distributions are accumulated on the
    instance lists (gen_images, gen_states, ...) by build().
    """

    def __init__(self, images, actions=None, states=None, iter_num=-1.0, pix_distributions1=None, pix_distributions2=None, conf=None):
        """Stores inputs and configuration; no graph ops are created here.

        Args:
            images: list of image tensors, one per timestep.
            actions: optional list of action tensors, one per timestep.
            states: optional list of low-dimensional state tensors.
            iter_num: scalar (tensor or float) training iteration, used for
                scheduled sampling.
            pix_distributions1: optional list of pixel-distribution tensors
                to be advected alongside the images.
            pix_distributions2: optional second set of pixel distributions
                (used when 'ndesig' is in conf).
            conf: dict of configuration flags (model, num_masks,
                schedsamp_k, context_frames, ...).
        """
        self.pix_distributions1 = pix_distributions1
        self.pix_distributions2 = pix_distributions2
        self.actions = actions
        self.iter_num = iter_num
        self.conf = conf
        self.images = images

        # Exactly one of cdna/stp/dna must be selected by conf['model'].
        self.cdna, self.stp, self.dna = False, False, False
        if self.conf['model'] == 'CDNA':
            self.cdna = True
        elif self.conf['model'] == 'DNA':
            self.dna = True
        elif self.conf['model'] == 'STP':
            self.stp = True
        if self.stp + self.cdna + self.dna != 1:
            raise ValueError("More than one option selected!")

        self.k = conf['schedsamp_k']
        self.use_state = conf['use_state']
        self.num_masks = conf['num_masks']
        self.context_frames = conf['context_frames']

        self.batch_size, self.img_height, self.img_width, self.color_channels = [int(i) for i in images[0].get_shape()[0:4]]
        self.lstm_func = basic_conv_lstm_cell

        # Generated robot states and images.
        self.gen_states = []
        self.gen_images = []
        self.gen_masks = []
        self.moved_images = []
        self.moved_pix_distrib1 = []
        self.moved_pix_distrib2 = []
        self.states = states
        self.gen_distrib1 = []
        self.gen_distrib2 = []
        self.trafos = []

    def build(self):
        """Unrolls the recurrent prediction network over all timesteps.

        For each timestep this builds the conv-LSTM encoder/decoder tower,
        applies the configured transformation (CDNA/DNA/STP) to the previous
        frame, and composites the transformed candidates with predicted
        masks.  Results are appended to self.gen_images / self.gen_states /
        self.gen_masks (and gen_distrib1/2 when pixel distributions are
        given).
        """
        if 'kern_size' in self.conf.keys():
            KERN_SIZE = self.conf['kern_size']
        else:
            KERN_SIZE = 5

        batch_size, img_height, img_width, color_channels = self.images[0].get_shape()[0:4]
        lstm_func = basic_conv_lstm_cell

        if self.states != None:
            current_state = self.states[0]
        else:
            current_state = None

        if self.actions == None:
            self.actions = [None for _ in self.images]

        if self.k == -1:
            # k == -1 disables scheduled sampling: always feed back predictions
            # once past the context frames.
            feedself = True
        else:
            # Scheduled sampling:
            # Calculate number of ground-truth frames to pass in.
            num_ground_truth = tf.to_int32(
                tf.round(tf.to_float(batch_size) * (self.k / (self.k + tf.exp(self.iter_num / self.k)))))
            feedself = False

        # LSTM state sizes and states.
        if 'lstm_size' in self.conf:
            lstm_size = self.conf['lstm_size']
            print('using lstm size', lstm_size)
        else:
            ngf = self.conf['ngf']
            lstm_size = np.int32(np.array([ngf, ngf * 2, ngf * 4, ngf * 2, ngf]))
        # NOTE(review): lstm_state2 and lstm_state4 are initialized but never
        # used below (only states 1, 3, 5, 6, 7 feed LSTM cells).
        lstm_state1, lstm_state2, lstm_state3, lstm_state4 = None, None, None, None
        lstm_state5, lstm_state6, lstm_state7 = None, None, None

        for t, action in enumerate(self.actions):
            print(t)
            # Reuse variables after the first timestep.
            reuse = bool(self.gen_images)

            done_warm_start = len(self.gen_images) > self.context_frames - 1
            with slim.arg_scope(
                    [lstm_func, slim.layers.conv2d, slim.layers.fully_connected,
                     tf_layers.layer_norm, slim.layers.conv2d_transpose],
                    reuse=reuse):

                if feedself and done_warm_start:
                    # Feed in generated image.
                    prev_image = self.gen_images[-1]  # 64x64x6
                    if self.pix_distributions1 != None:
                        prev_pix_distrib1 = self.gen_distrib1[-1]
                        if 'ndesig' in self.conf:
                            prev_pix_distrib2 = self.gen_distrib2[-1]
                elif done_warm_start:
                    # Scheduled sampling
                    prev_image = scheduled_sample(self.images[t], self.gen_images[-1], batch_size,
                                                  num_ground_truth)
                else:
                    # Always feed in ground_truth
                    prev_image = self.images[t]
                    if self.pix_distributions1 != None:
                        prev_pix_distrib1 = self.pix_distributions1[t]
                        if 'ndesig' in self.conf:
                            prev_pix_distrib2 = self.pix_distributions2[t]
                        if len(prev_pix_distrib1.get_shape()) == 3:
                            # Add a trailing channel dim when the distribution is rank-3.
                            prev_pix_distrib1 = tf.expand_dims(prev_pix_distrib1, -1)
                            if 'ndesig' in self.conf:
                                prev_pix_distrib2 = tf.expand_dims(prev_pix_distrib2, -1)

                if 'refeed_firstimage' in self.conf:
                    assert self.conf['model']=='STP'
                    if t > 1:
                        input_image = self.images[1]
                        print('refeed with image 1')
                    else:
                        input_image = prev_image
                else:
                    input_image = prev_image

                # Predicted state is always fed back in
                if not 'ignore_state_action' in self.conf:
                    state_action = tf.concat(axis=1, values=[action, current_state])

                enc0 = slim.layers.conv2d(  # 32x32x32
                    input_image,
                    32, [5, 5],
                    stride=2,
                    scope='scale1_conv1',
                    normalizer_fn=tf_layers.layer_norm,
                    normalizer_params={'scope': 'layer_norm1'})

                hidden1, lstm_state1 = lstm_func(  # 32x32x16
                    enc0, lstm_state1, lstm_size[0], scope='state1')
                hidden1 = tf_layers.layer_norm(hidden1, scope='layer_norm2')

                enc1 = slim.layers.conv2d(  # 16x16x16
                    hidden1, hidden1.get_shape()[3], [3, 3], stride=2, scope='conv2')

                hidden3, lstm_state3 = lstm_func(  # 16x16x32
                    enc1, lstm_state3, lstm_size[1], scope='state3')
                hidden3 = tf_layers.layer_norm(hidden3, scope='layer_norm4')

                enc2 = slim.layers.conv2d(  # 8x8x32
                    hidden3, hidden3.get_shape()[3], [3, 3], stride=2, scope='conv3')

                if not 'ignore_state_action' in self.conf:
                    # Pass in state and action.
                    if 'ignore_state' in self.conf:
                        lowdim = action
                        print('ignoring state')
                    else:
                        lowdim = state_action

                    # Tile the low-dimensional vector over the spatial grid so it
                    # can be concatenated with the bottleneck feature map.
                    smear = tf.reshape(
                        lowdim,
                        [int(batch_size), 1, 1, int(lowdim.get_shape()[1])])
                    smear = tf.tile(
                        smear, [1, int(enc2.get_shape()[1]), int(enc2.get_shape()[2]), 1])
                    enc2 = tf.concat(axis=3, values=[enc2, smear])
                else:
                    print('ignoring states and actions')

                enc3 = slim.layers.conv2d(  # 8x8x32
                    enc2, hidden3.get_shape()[3], [1, 1], stride=1, scope='conv4')

                hidden5, lstm_state5 = lstm_func(  # 8x8x64
                    enc3, lstm_state5, lstm_size[2], scope='state5')
                hidden5 = tf_layers.layer_norm(hidden5, scope='layer_norm6')
                enc4 = slim.layers.conv2d_transpose(  # 16x16x64
                    hidden5, hidden5.get_shape()[3], 3, stride=2, scope='convt1')

                hidden6, lstm_state6 = lstm_func(  # 16x16x32
                    enc4, lstm_state6, lstm_size[3], scope='state6')
                hidden6 = tf_layers.layer_norm(hidden6, scope='layer_norm7')

                if 'noskip' not in self.conf:
                    # Skip connection.
                    hidden6 = tf.concat(axis=3, values=[hidden6, enc1])  # both 16x16

                enc5 = slim.layers.conv2d_transpose(  # 32x32x32
                    hidden6, hidden6.get_shape()[3], 3, stride=2, scope='convt2')
                hidden7, lstm_state7 = lstm_func(  # 32x32x16
                    enc5, lstm_state7, lstm_size[4], scope='state7')
                hidden7 = tf_layers.layer_norm(hidden7, scope='layer_norm8')

                if not 'noskip' in self.conf:
                    # Skip connection.
                    hidden7 = tf.concat(axis=3, values=[hidden7, enc0])  # both 32x32

                enc6 = slim.layers.conv2d_transpose(  # 64x64x16
                    hidden7,
                    hidden7.get_shape()[3], 3, stride=2, scope='convt3',
                    normalizer_fn=tf_layers.layer_norm,
                    normalizer_params={'scope': 'layer_norm9'})

                if 'transform_from_firstimage' in self.conf:
                    # Always warp from frame 1 instead of the previous prediction.
                    prev_image = self.images[1]
                    if self.pix_distributions1 != None:
                        prev_pix_distrib1 = self.pix_distributions1[1]
                        prev_pix_distrib1 = tf.expand_dims(prev_pix_distrib1, -1)
                    print('transform from image 1')

                if self.conf['model'] == 'DNA':
                    # Using largest hidden state for predicting untied conv kernels.
                    trafo_input = slim.layers.conv2d_transpose(
                        enc6, KERN_SIZE ** 2, 1, stride=1, scope='convt4_cam2')

                    # NOTE(review): reads self.conf['kern_size'] directly rather
                    # than KERN_SIZE; raises KeyError when 'kern_size' is absent
                    # even though KERN_SIZE has a default of 5 above — confirm.
                    transformed_l = [self.dna_transformation(prev_image, trafo_input, self.conf['kern_size'])]
                    if self.pix_distributions1 != None:
                        transf_distrib_ndesig1 = [self.dna_transformation(prev_pix_distrib1, trafo_input, KERN_SIZE)]
                        if 'ndesig' in self.conf:
                            transf_distrib_ndesig2 = [
                                self.dna_transformation(prev_pix_distrib2, trafo_input, KERN_SIZE)]
                    extra_masks = 1

                    ## extra_masks = 2 is needed for running singleview_shifted!!
                    # print('using extra masks 2 because of single view shifted!!')
                    # extra_masks = 2

                if self.conf['model'] == 'CDNA':
                    if 'gen_pix' in self.conf:
                        # Using largest hidden state for predicting a new image layer.
                        enc7 = slim.layers.conv2d_transpose(
                            enc6, color_channels, 1, stride=1, scope='convt4', activation_fn=None)
                        # This allows the network to also generate one image from scratch,
                        # which is useful when regions of the image become unoccluded.
                        transformed_l = [tf.nn.sigmoid(enc7)]
                        extra_masks = 2
                    else:
                        transformed_l = []
                        extra_masks = 1

                    cdna_input = tf.reshape(hidden5, [int(batch_size), -1])
                    new_transformed, _ = self.cdna_transformation(prev_image,
                                                                  cdna_input,
                                                                  reuse_sc=reuse)
                    transformed_l += new_transformed
                    self.moved_images.append(transformed_l)

                    if self.pix_distributions1 != None:
                        transf_distrib_ndesig1, _ = self.cdna_transformation(prev_pix_distrib1,
                                                                             cdna_input,
                                                                             reuse_sc=True)
                        self.moved_pix_distrib1.append(transf_distrib_ndesig1)
                        if 'ndesig' in self.conf:
                            transf_distrib_ndesig2, _ = self.cdna_transformation(
                                prev_pix_distrib2,
                                cdna_input,
                                reuse_sc=True)
                            self.moved_pix_distrib2.append(transf_distrib_ndesig2)

                if self.conf['model'] == 'STP':
                    enc7 = slim.layers.conv2d_transpose(enc6, color_channels, 1, stride=1, scope='convt5', activation_fn= None)
                    # This allows the network to also generate one image from scratch,
                    # which is useful when regions of the image become unoccluded.
                    if 'gen_pix' in self.conf:
                        transformed_l = [tf.nn.sigmoid(enc7)]
                        extra_masks = 2
                    else:
                        transformed_l = []
                        extra_masks = 1

                    enc_stp = tf.reshape(hidden5, [int(batch_size), -1])
                    stp_input = slim.layers.fully_connected(
                        enc_stp, 200, scope='fc_stp_cam2')

                    # disabling capability to generete pixels
                    reuse_stp = None
                    if reuse:
                        reuse_stp = reuse
                    # enable the generation of pixels:
                    transformed, trafo = self.stp_transformation(prev_image, stp_input, self.num_masks, reuse_stp, suffix='cam2')
                    transformed_l += transformed

                    self.trafos.append(trafo)
                    self.moved_images.append(transformed_l)

                    if self.pix_distributions1 != None:
                        # NOTE(review): this call omits the required positional
                        # num_masks argument of stp_transformation and would
                        # raise a TypeError if this branch executes — confirm.
                        transf_distrib_ndesig1, _ = self.stp_transformation(prev_pix_distrib1, stp_input, suffix='cam2', reuse=True)
                        self.moved_pix_distrib1.append(transf_distrib_ndesig1)

                if '1stimg_bckgd' in self.conf:
                    background = self.images[0]
                    print('using background from first image..')
                else:
                    background = prev_image
                output, mask_list = self.fuse_trafos(enc6, background,
                                                     transformed_l,
                                                     scope='convt7_cam2',
                                                     extra_masks= extra_masks)
                self.gen_images.append(output)
                self.gen_masks.append(mask_list)

                if self.pix_distributions1!=None:
                    pix_distrib_output = self.fuse_pix_distrib(extra_masks,
                                                               mask_list,
                                                               self.pix_distributions1,
                                                               prev_pix_distrib1,
                                                               transf_distrib_ndesig1)
                    self.gen_distrib1.append(pix_distrib_output)
                    if 'ndesig' in self.conf:
                        pix_distrib_output = self.fuse_pix_distrib(extra_masks,
                                                                   mask_list,
                                                                   self.pix_distributions2,
                                                                   prev_pix_distrib2,
                                                                   transf_distrib_ndesig2)
                        self.gen_distrib2.append(pix_distrib_output)

                # Predict the next low-dimensional state from state_action
                # (zero-width states are passed through as zeros).
                if int(current_state.get_shape()[1]) == 0:
                    current_state = tf.zeros_like(state_action)
                else:
                    current_state = slim.layers.fully_connected(
                        state_action,
                        int(current_state.get_shape()[1]),
                        scope='state_pred',
                        activation_fn=None)
                self.gen_states.append(current_state)

    def fuse_trafos(self, enc6, background_image, transformed, scope, extra_masks):
        """Composites background and transformed images with predicted masks.

        Args:
            enc6: decoder feature map from which the masks are predicted.
            background_image: image used with the first (background) mask.
            transformed: list of candidate transformed images.
            scope: variable scope for the mask-prediction layer.
            extra_masks: number of masks beyond num_masks (background /
                generated-pixel layers).

        Returns:
            Tuple (output image, list of mask tensors).
        """
        masks = slim.layers.conv2d_transpose(
            enc6, (self.conf['num_masks']+ extra_masks), 1, stride=1, activation_fn=None, scope=scope)

        # NOTE(review): image size is hard-coded to 64x64 here rather than
        # taken from self.img_height/self.img_width — confirm inputs are 64x64.
        img_height = 64
        img_width = 64
        num_masks = self.conf['num_masks']

        if self.conf['model']=='DNA':
            if num_masks != 1:
                raise ValueError('Only one mask is supported for DNA model.')

        # the total number of masks is num_masks + extra_masks because of background and generated pixels!
        # Softmax over the mask channel so the masks sum to one per pixel.
        masks = tf.reshape(
            tf.nn.softmax(tf.reshape(masks, [-1, num_masks +extra_masks])),
            [int(self.batch_size), int(img_height), int(img_width), num_masks +extra_masks])
        mask_list = tf.split(axis=3, num_or_size_splits=num_masks +extra_masks, value=masks)
        output = mask_list[0] * background_image

        assert len(transformed) == len(mask_list[1:])
        for layer, mask in zip(transformed, mask_list[1:]):
            output += layer * mask

        return output, mask_list

    def fuse_pix_distrib(self, extra_masks, mask_list, pix_distributions, prev_pix_distrib, transf_distrib):
        """Composites pixel distributions with the same masks used for images.

        Args:
            extra_masks: number of leading masks not tied to transformations.
            mask_list: masks produced by fuse_trafos.
            pix_distributions: per-timestep list of input pixel distributions.
            prev_pix_distrib: distribution fed into the transformation.
            transf_distrib: list of transformed distributions.

        Returns:
            The fused pixel distribution, renormalized over the spatial axes.
        """
        if '1stimg_bckgd' in self.conf:
            background_pix = pix_distributions[0]
            if len(background_pix.get_shape()) == 3:
                background_pix = tf.expand_dims(background_pix, -1)
            print('using pix_distrib-background from first image..')
        else:
            background_pix = prev_pix_distrib
        pix_distrib_output = mask_list[0] * background_pix
        if 'gen_pix' in self.conf:
            # assume pixels don't move when the image is generated from scratch
            pix_distrib_output += mask_list[1] * prev_pix_distrib
        for i in range(self.num_masks):
            pix_distrib_output += transf_distrib[i] * mask_list[i + extra_masks]
        # Renormalize so the distribution sums to one over the image.
        pix_distrib_output /= tf.reduce_sum(pix_distrib_output, axis=(1, 2), keepdims=True)
        return pix_distrib_output

    ## Utility functions
    def stp_transformation(self, prev_image, stp_input, num_masks, reuse= None, suffix = None):
        """Apply spatial transformer predictor (STP) to previous image.

        Args:
            prev_image: previous image to be transformed.
            stp_input: hidden layer to be used for computing STN parameters.
            num_masks: number of masks and hence the number of STP transformations.
            reuse: whether to reuse the stp_params variables.
            suffix: string appended to the stp_params variable scopes.

        Returns:
            Tuple (list of transformed images, list of 6-parameter affine
            transforms), one entry per mask.
        """
        # Only import spatial transformer if needed.
        from spatial_transformer import transformer

        # Identity affine transform; predicted params are residuals on this.
        identity_params = tf.convert_to_tensor(
            np.array([1.0, 0.0, 0.0, 0.0, 1.0, 0.0], np.float32))
        transformed = []
        trafos = []
        for i in range(num_masks):
            params = slim.layers.fully_connected(
                stp_input, 6, scope='stp_params' + str(i) + suffix,
                activation_fn=None,
                reuse= reuse) + identity_params
            outsize = (prev_image.get_shape()[1], prev_image.get_shape()[2])
            transformed.append(transformer(prev_image, params, outsize))
            trafos.append(params)

        return transformed, trafos

    def dna_transformation(self, prev_image, dna_input, DNA_KERN_SIZE):
        """Apply dynamic neural advection to previous image.

        Args:
            prev_image: previous image to be transformed.
            dna_input: hidden layer to be used for computing DNA transformation.
            DNA_KERN_SIZE: side length of the per-pixel advection kernel.

        Returns:
            Image transformed by the predicted per-pixel (untied) kernels.
        """
        # Construct translated images: one shifted copy of the padded image
        # per kernel tap, stacked along a new axis 3.
        pad_len = int(np.floor(DNA_KERN_SIZE / 2))
        prev_image_pad = tf.pad(prev_image, [[0, 0], [pad_len, pad_len], [pad_len, pad_len], [0, 0]])
        image_height = int(prev_image.get_shape()[1])
        image_width = int(prev_image.get_shape()[2])

        inputs = []
        for xkern in range(DNA_KERN_SIZE):
            for ykern in range(DNA_KERN_SIZE):
                inputs.append(
                    tf.expand_dims(
                        tf.slice(prev_image_pad, [0, xkern, ykern, 0],
                                 [-1, image_height, image_width, -1]), [3]))
        inputs = tf.concat(axis=3, values=inputs)

        # Normalize channels to 1. RELU_SHIFT keeps every weight strictly
        # positive so the normalization never divides by zero.
        kernel = tf.nn.relu(dna_input - RELU_SHIFT) + RELU_SHIFT
        kernel = tf.expand_dims(
            kernel / tf.reduce_sum(
                kernel, [3], keepdims=True), [4])

        # Weighted sum over the shifted copies = per-pixel convolution.
        return tf.reduce_sum(kernel * inputs, [3], keepdims=False)

    def cdna_transformation(self, prev_image, cdna_input, reuse_sc=None):
        """Apply convolutional dynamic neural advection to previous image.

        Args:
            prev_image: previous image to be transformed.
            cdna_input: hidden layer to be used for computing CDNA kernels.
            reuse_sc: whether to reuse the cdna_params variables.

        Returns:
            Tuple (list of num_masks transformed images, normalized CDNA
            kernels tensor for summaries).
        """
        batch_size = int(cdna_input.get_shape()[0])
        height = int(prev_image.get_shape()[1])
        width = int(prev_image.get_shape()[2])

        DNA_KERN_SIZE = self.conf['kern_size']
        num_masks = self.conf['num_masks']
        color_channels = int(prev_image.get_shape()[3])

        # Predict kernels using linear function of last hidden layer.
        cdna_kerns = slim.layers.fully_connected(
            cdna_input,
            DNA_KERN_SIZE * DNA_KERN_SIZE * num_masks,
            scope='cdna_params',
            activation_fn=None,
            reuse = reuse_sc)

        # Reshape and normalize. RELU_SHIFT keeps kernels strictly positive
        # so the sum used for normalization is never zero.
        cdna_kerns = tf.reshape(
            cdna_kerns, [batch_size, DNA_KERN_SIZE, DNA_KERN_SIZE, 1, num_masks])
        cdna_kerns = tf.nn.relu(cdna_kerns - RELU_SHIFT) + RELU_SHIFT
        norm_factor = tf.reduce_sum(cdna_kerns, [1, 2, 3], keepdims=True)
        cdna_kerns /= norm_factor
        cdna_kerns_summary = cdna_kerns

        # Transpose and reshape so each batch element's kernels become the
        # per-channel filters of a single depthwise convolution (batch moved
        # into the channel axis).
        cdna_kerns = tf.transpose(cdna_kerns, [1, 2, 0, 4, 3])
        cdna_kerns = tf.reshape(cdna_kerns, [DNA_KERN_SIZE, DNA_KERN_SIZE, batch_size, num_masks])
        prev_image = tf.transpose(prev_image, [3, 1, 2, 0])

        transformed = tf.nn.depthwise_conv2d(prev_image, cdna_kerns, [1, 1, 1, 1], 'SAME')

        # Transpose and reshape back to batch-major, then split per mask.
        transformed = tf.reshape(transformed, [color_channels, height, width, batch_size, num_masks])
        transformed = tf.transpose(transformed, [3, 1, 2, 0, 4])
        transformed = tf.unstack(value=transformed, axis=-1)

        return transformed, cdna_kerns_summary


def scheduled_sample(ground_truth_x, generated_x, batch_size, num_ground_truth):
    """Sample batch with specified mix of ground truth and generated data_files points.

    Args:
        ground_truth_x: tensor of ground-truth data_files points.
        generated_x: tensor of generated data_files points.
        batch_size: batch size
        num_ground_truth: number of ground-truth examples to include in batch.

    Returns:
        New batch with num_ground_truth sampled from ground_truth_x and the rest
        from generated_x.
    """
    # Randomly partition batch indices: the first num_ground_truth shuffled
    # indices take ground truth, the remainder take generated samples.
    idx = tf.random_shuffle(tf.range(int(batch_size)))
    ground_truth_idx = tf.gather(idx, tf.range(num_ground_truth))
    generated_idx = tf.gather(idx, tf.range(num_ground_truth, int(batch_size)))

    ground_truth_examps = tf.gather(ground_truth_x, ground_truth_idx)
    generated_examps = tf.gather(generated_x, generated_idx)
    # Stitch the two subsets back into original batch order.
    return tf.dynamic_stitch([ground_truth_idx, generated_idx],
                             [ground_truth_examps, generated_examps])


def generator_fn(inputs, mode, hparams):
    """Adapter that builds a Prediction_Model from the framework's inputs.

    Args:
        inputs: dict of time-major tensors ('images', 'actions', 'states',
            optionally 'pix_distribs').
        mode: 'train' or eval mode string; scheduled sampling is disabled
            outside of training.
        hparams: hyperparameters object (ngf, num_masks, transformation, ...).

    Returns:
        Dict of stacked, time-major output tensors ('gen_images',
        'gen_states', and 'gen_pix_distribs' when distributions were given).
    """
    images = tf.unstack(inputs['images'], axis=0)
    actions = tf.unstack(inputs['actions'], axis=0)
    states = tf.unstack(inputs['states'], axis=0)
    pix_distributions1 = tf.unstack(inputs['pix_distribs'], axis=0) if 'pix_distribs' in inputs else None
    iter_num = tf.to_float(tf.train.get_or_create_global_step())

    # The model only supports square kernels; accept (k, k) tuples or a scalar.
    if isinstance(hparams.kernel_size, (tuple, list)):
        kernel_height, kernel_width = hparams.kernel_size
        assert kernel_height == kernel_width
        kern_size = kernel_height
    else:
        kern_size = hparams.kernel_size

    # schedsamp_k == -1 disables scheduled sampling (always feed back predictions).
    schedule_sampling_k = hparams.schedule_sampling_k if mode == 'train' else -1
    conf = {
        'context_frames': hparams.context_frames,  # number of frames before predictions begin
        'use_state': 1,  # whether or not to give the state+action to the model
        'ngf': hparams.ngf,
        'model': hparams.transformation.upper(),  # model architecture to use - CDNA, DNA, or STP
        'num_masks': hparams.num_masks,  # number of masks, usually 1 for DNA, 10 for CDNA, STN
        'schedsamp_k': schedule_sampling_k,  # k hyperparameter for scheduled sampling, -1 for none
        'kern_size': kern_size,  # size of DNA kerns
    }
    # These flags are presence-checked ('key in conf'), so the value is unused.
    if hparams.first_image_background:
        conf['1stimg_bckgd'] = ''
    if hparams.generate_scratch_image:
        conf['gen_pix'] = ''

    m = Prediction_Model(images, actions, states,
                         pix_distributions1=pix_distributions1,
                         iter_num=iter_num, conf=conf)
    m.build()
    outputs = {
        'gen_images': tf.stack(m.gen_images, axis=0),
        'gen_states': tf.stack(m.gen_states, axis=0),
    }
    if 'pix_distribs' in inputs:
        outputs['gen_pix_distribs'] = tf.stack(m.gen_distrib1, axis=0)
    return outputs


class SNAVideoPredictionModel(VideoPredictionModel):
    """VideoPredictionModel wrapper exposing the SNA generator."""

    def __init__(self, *args, **kwargs):
        super(SNAVideoPredictionModel, self).__init__(
            generator_fn, *args, **kwargs)

    def get_default_hparams_dict(self):
        # Later entries override the base-class defaults with SNA-specific values.
        default_hparams = super(SNAVideoPredictionModel, self).get_default_hparams_dict()
        hparams = dict(
            batch_size=32,
            l1_weight=0.0,
            l2_weight=1.0,
            ngf=16,
            transformation='cdna',
            kernel_size=(5, 5),
            num_masks=10,
            first_image_background=True,
            generate_scratch_image=True,
            schedule_sampling_k=900.0,
        )
        return dict(itertools.chain(default_hparams.items(), hparams.items()))