# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Model architecture for predictive model, including CDNA, DNA, and STP."""

import itertools

import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.contrib.layers.python import layers as tf_layers
from video_prediction.models import VideoPredictionModel

from .sna_model import basic_conv_lstm_cell

# Amount to use when lower bounding tensors
RELU_SHIFT = 1e-12


def construct_model(images,
                    actions=None,
                    states=None,
                    iter_num=-1.0,
                    kernel_size=(5, 5),
                    k=-1,
                    use_state=True,
                    num_masks=10,
                    stp=False,
                    cdna=True,
                    dna=False,
                    context_frames=2,
                    pix_distributions=None):
    """Build convolutional lstm video predictor using STP, CDNA, or DNA.

    Args:
        images: sequence (indexed per timestep) of ground truth image tensors
        actions: sequence of action tensors; one prediction step is built per
            entry
        states: sequence of ground truth state tensors; only `states[0]` is
            read, later states are always the model's own predictions
        iter_num: tensor of the current training iteration (for sched. sampling)
        kernel_size: (height, width) of the CDNA / DNA kernels
        k: constant used for scheduled sampling. -1 to feed in own prediction.
        use_state: True to include state and action in prediction
        num_masks: the number of different pixel motion predictions (and
            the number of masks for each of those predictions)
        stp: True to use Spatial Transformer Predictor (STP)
        cdna: True to use Convolutional Dynamic Neural Advection (CDNA)
        dna: True to use Dynamic Neural Advection (DNA)
        context_frames: number of ground truth frames to pass in before
            feeding in own predictions
        pix_distributions: optional sequence of per-timestep pixel
            distribution tensors. When given, they are pushed through the
            same transformations and masks as the images.

    Returns:
        gen_images: predicted future image frames
        gen_states: predicted future states
        gen_masks: per timestep, the list of compositing masks
        gen_pix_distrib: predicted pixel distributions (empty list when
            `pix_distributions` is None)

    Raises:
        ValueError: if more than one network option specified or more than 1
            mask specified for DNA model.
    """
    if stp + cdna + dna != 1:
        raise ValueError('More than one, or no network option specified.')
    # Only one mask is supported for DNA (more should be unnecessary). The
    # check does not depend on the timestep, so do it once up front.
    if dna and num_masks != 1:
        raise ValueError('Only one mask is supported for DNA model.')

    batch_size, img_height, img_width, color_channels = \
        images[0].get_shape()[0:4]
    lstm_func = basic_conv_lstm_cell

    # Generated robot states and images.
    gen_states, gen_images = [], []
    gen_pix_distrib = []
    gen_masks = []
    current_state = states[0]

    if k == -1:
        feedself = True
    else:
        # Scheduled sampling:
        # Calculate number of ground-truth frames to pass in.
        num_ground_truth = tf.to_int32(
            tf.round(tf.to_float(batch_size) * (k / (k + tf.exp(iter_num / k)))))
        feedself = False

    # LSTM state sizes and states.
    lstm_size = np.int32(np.array([32, 32, 64, 64, 128, 64, 32]))
    lstm_state1, lstm_state2, lstm_state3, lstm_state4 = None, None, None, None
    lstm_state5, lstm_state6, lstm_state7 = None, None, None

    for t, action in enumerate(actions):
        # Reuse variables after the first timestep.
        reuse = bool(gen_images)

        done_warm_start = len(gen_images) > context_frames - 1
        with slim.arg_scope(
                [lstm_func, slim.layers.conv2d, slim.layers.fully_connected,
                 tf_layers.layer_norm, slim.layers.conv2d_transpose],
                reuse=reuse):

            if feedself and done_warm_start:
                # Feed in generated image.
                prev_image = gen_images[-1]
                if pix_distributions is not None:
                    prev_pix_distrib = gen_pix_distrib[-1]
            elif done_warm_start:
                # Scheduled sampling
                prev_image = scheduled_sample(images[t], gen_images[-1],
                                              batch_size, num_ground_truth)
                if pix_distributions is not None:
                    # Propagate our own predicted distribution, as in the
                    # feed-self branch; otherwise the distribution of the
                    # last warm-start frame would silently be reused at
                    # every later step.
                    prev_pix_distrib = gen_pix_distrib[-1]
            else:
                # Always feed in ground_truth
                prev_image = images[t]
                if pix_distributions is not None:
                    prev_pix_distrib = pix_distributions[t]

            # Predicted state is always fed back in
            state_action = tf.concat(axis=1, values=[action, current_state])

            enc0 = slim.layers.conv2d(
                prev_image,
                32, [5, 5],
                stride=2,
                scope='scale1_conv1',
                normalizer_fn=tf_layers.layer_norm,
                normalizer_params={'scope': 'layer_norm1'})

            hidden1, lstm_state1 = lstm_func(
                enc0, lstm_state1, lstm_size[0], scope='state1')
            hidden1 = tf_layers.layer_norm(hidden1, scope='layer_norm2')
            hidden2, lstm_state2 = lstm_func(
                hidden1, lstm_state2, lstm_size[1], scope='state2')
            hidden2 = tf_layers.layer_norm(hidden2, scope='layer_norm3')
            enc1 = slim.layers.conv2d(
                hidden2, hidden2.get_shape()[3], [3, 3], stride=2,
                scope='conv2')

            hidden3, lstm_state3 = lstm_func(
                enc1, lstm_state3, lstm_size[2], scope='state3')
            hidden3 = tf_layers.layer_norm(hidden3, scope='layer_norm4')
            hidden4, lstm_state4 = lstm_func(
                hidden3, lstm_state4, lstm_size[3], scope='state4')
            hidden4 = tf_layers.layer_norm(hidden4, scope='layer_norm5')
            enc2 = slim.layers.conv2d(
                hidden4, hidden4.get_shape()[3], [3, 3], stride=2,
                scope='conv3')

            # Pass in state and action: broadcast the state/action vector
            # over every spatial location of enc2.
            smear = tf.reshape(
                state_action,
                [int(batch_size), 1, 1, int(state_action.get_shape()[1])])
            smear = tf.tile(
                smear,
                [1, int(enc2.get_shape()[1]), int(enc2.get_shape()[2]), 1])
            if use_state:
                enc2 = tf.concat(axis=3, values=[enc2, smear])
            enc3 = slim.layers.conv2d(
                enc2, hidden4.get_shape()[3], [1, 1], stride=1, scope='conv4')

            # Spatial sizes in the comments below assume 64x64 inputs.
            hidden5, lstm_state5 = lstm_func(
                enc3, lstm_state5, lstm_size[4], scope='state5')  # last 8x8
            hidden5 = tf_layers.layer_norm(hidden5, scope='layer_norm6')
            enc4 = slim.layers.conv2d_transpose(
                hidden5, hidden5.get_shape()[3], 3, stride=2, scope='convt1')

            hidden6, lstm_state6 = lstm_func(
                enc4, lstm_state6, lstm_size[5], scope='state6')  # 16x16
            hidden6 = tf_layers.layer_norm(hidden6, scope='layer_norm7')
            # Skip connection.
            hidden6 = tf.concat(axis=3, values=[hidden6, enc1])  # both 16x16

            enc5 = slim.layers.conv2d_transpose(
                hidden6, hidden6.get_shape()[3], 3, stride=2, scope='convt2')
            hidden7, lstm_state7 = lstm_func(
                enc5, lstm_state7, lstm_size[6], scope='state7')  # 32x32
            hidden7 = tf_layers.layer_norm(hidden7, scope='layer_norm8')

            # Skip connection.
            hidden7 = tf.concat(axis=3, values=[hidden7, enc0])  # both 32x32

            enc6 = slim.layers.conv2d_transpose(
                hidden7,
                hidden7.get_shape()[3], 3, stride=2, scope='convt3',
                normalizer_fn=tf_layers.layer_norm,
                normalizer_params={'scope': 'layer_norm9'})

            if dna:
                # Using largest hidden state for predicting untied conv
                # kernels: one channel per kernel tap, which is what
                # dna_transformation normalizes over.
                enc7 = slim.layers.conv2d_transpose(
                    enc6, kernel_size[0] * kernel_size[1], 1, stride=1,
                    scope='convt4')
            else:
                # Using largest hidden state for predicting a new image layer.
                enc7 = slim.layers.conv2d_transpose(
                    enc6, color_channels, 1, stride=1, scope='convt4')
                # This allows the network to also generate one image from
                # scratch, which is useful when regions of the image become
                # unoccluded.
                transformed = [tf.nn.sigmoid(enc7)]

            if stp:
                stp_input0 = tf.reshape(hidden5, [int(batch_size), -1])
                stp_input1 = slim.layers.fully_connected(
                    stp_input0, 100, scope='fc_stp')

                # The capability to generate pixels from scratch is disabled
                # for STP: the sigmoid(enc7) layer is replaced, not extended.
                # Note stp_transformation yields num_masks - 1 images, so the
                # zip below leaves the last mask unused.
                reuse_stp = reuse or None
                transformed = stp_transformation(
                    prev_image, stp_input1, num_masks, reuse_stp)

                if pix_distributions is not None:
                    # Same STP parameters as for the image, hence reuse=True.
                    transf_distrib = stp_transformation(
                        prev_pix_distrib, stp_input1, num_masks, reuse=True)

            elif cdna:
                cdna_input = tf.reshape(hidden5, [int(batch_size), -1])
                new_transformed, cdna_kerns = cdna_transformation(
                    prev_image,
                    cdna_input,
                    num_masks,
                    int(color_channels),
                    kernel_size,
                    reuse_sc=reuse)
                transformed += new_transformed

                if pix_distributions is not None:
                    # Same kernels as for the image, hence reuse_sc=True.
                    new_transf_distrib, _ = cdna_transformation(
                        prev_pix_distrib,
                        cdna_input,
                        num_masks,
                        prev_pix_distrib.shape[-1].value,
                        kernel_size,
                        reuse_sc=True)
                    transf_distrib = [prev_pix_distrib]
                    transf_distrib += new_transf_distrib

            elif dna:
                transformed = [
                    dna_transformation(prev_image, enc7, kernel_size)]

            # Composite the previous image and the transformed images with
            # per-pixel softmax masks.
            masks = slim.layers.conv2d_transpose(
                enc6, num_masks + 1, 1, stride=1, scope='convt7')
            masks = tf.reshape(
                tf.nn.softmax(tf.reshape(masks, [-1, num_masks + 1])),
                [int(batch_size), int(img_height), int(img_width),
                 num_masks + 1])
            mask_list = tf.split(masks, num_masks + 1, axis=3)
            output = mask_list[0] * prev_image
            for layer, mask in zip(transformed, mask_list[1:]):
                output += layer * mask
            gen_images.append(output)
            gen_masks.append(mask_list)

            if dna and pix_distributions is not None:
                transf_distrib = [
                    dna_transformation(prev_pix_distrib, enc7, kernel_size)]

            if pix_distributions is not None:
                pix_distrib_output = mask_list[0] * prev_pix_distrib
                for layer, mask in zip(transf_distrib, mask_list[1:]):
                    pix_distrib_output += layer * mask
                # Renormalize so each map sums to 1 over the spatial axes.
                pix_distrib_output /= tf.reduce_sum(
                    pix_distrib_output, axis=(1, 2), keepdims=True)
                gen_pix_distrib.append(pix_distrib_output)

            if int(current_state.get_shape()[1]) == 0:
                current_state = tf.zeros_like(state_action)
            else:
                current_state = slim.layers.fully_connected(
                    state_action,
                    int(current_state.get_shape()[1]),
                    scope='state_pred',
                    activation_fn=None)
            gen_states.append(current_state)

    return gen_images, gen_states, gen_masks, gen_pix_distrib


## Utility functions
def stp_transformation(prev_image, stp_input, num_masks, reuse=None):
    """Apply spatial transformer predictor (STP) to previous image.

    Args:
        prev_image: previous image to be transformed.
        stp_input: hidden layer to be used for computing STN parameters.
        num_masks: number of masks; `num_masks - 1` STP transformations are
            produced.
        reuse: forwarded to the fully connected layers that predict the
            transformation parameters, so the same parameters can be applied
            to a second input (e.g. a pixel distribution).

    Returns:
        List of images transformed by the predicted STP parameters.
    """
    # Only import spatial transformer if needed.
    from spatial_transformer import transformer

    identity_params = tf.convert_to_tensor(
        np.array([1.0, 0.0, 0.0, 0.0, 1.0, 0.0], np.float32))
    transformed = []
    for i in range(num_masks - 1):
        # Predict an offset from the identity transform.
        params = slim.layers.fully_connected(
            stp_input, 6, scope='stp_params' + str(i),
            activation_fn=None, reuse=reuse) + identity_params
        transformed.append(transformer(prev_image, params))

    return transformed


def cdna_transformation(prev_image, cdna_input, num_masks, color_channels,
                        kernel_size, reuse_sc=None):
    """Apply convolutional dynamic neural advection to previous image.

    Args:
        prev_image: previous image to be transformed.
        cdna_input: hidden layer to be used for computing CDNA kernels.
        num_masks: the number of masks and hence the number of CDNA
            transformations.
        color_channels: the number of color channels in the images.
        kernel_size: (height, width) of each CDNA kernel.
        reuse_sc: forwarded as `reuse` to the fully connected layer that
            predicts the kernels.

    Returns:
        transformed: list of `num_masks` images transformed by the predicted
            CDNA kernels.
        cdna_kerns: the normalized kernels, reshaped to
            [kernel_size[0], kernel_size[1], batch_size, num_masks].
    """
    batch_size = int(cdna_input.get_shape()[0])
    height = int(prev_image.get_shape()[1])
    width = int(prev_image.get_shape()[2])

    # Predict kernels using linear function of last hidden layer.
    cdna_kerns = slim.layers.fully_connected(
        cdna_input,
        kernel_size[0] * kernel_size[1] * num_masks,
        scope='cdna_params',
        activation_fn=None,
        reuse=reuse_sc)

    # Reshape and normalize: lower-bound at RELU_SHIFT, then make each
    # kernel sum to 1.
    cdna_kerns = tf.reshape(
        cdna_kerns,
        [batch_size, kernel_size[0], kernel_size[1], 1, num_masks])
    cdna_kerns = tf.nn.relu(cdna_kerns - RELU_SHIFT) + RELU_SHIFT
    norm_factor = tf.reduce_sum(cdna_kerns, [1, 2, 3], keepdims=True)
    cdna_kerns /= norm_factor

    # Treat the color channel dimension as the batch dimension since the same
    # transformation is applied to each color channel.
    # Treat the batch dimension as the channel dimension so that
    # depthwise_conv2d can apply a different transformation to each sample.
    cdna_kerns = tf.transpose(cdna_kerns, [1, 2, 0, 4, 3])
    cdna_kerns = tf.reshape(
        cdna_kerns, [kernel_size[0], kernel_size[1], batch_size, num_masks])
    # Swap the batch and channel dimensions.
    prev_image = tf.transpose(prev_image, [3, 1, 2, 0])

    # Transform image.
    transformed = tf.nn.depthwise_conv2d(
        prev_image, cdna_kerns, [1, 1, 1, 1], 'SAME')

    # Transpose the dimensions to where they belong.
    transformed = tf.reshape(
        transformed, [color_channels, height, width, batch_size, num_masks])
    transformed = tf.transpose(transformed, [3, 1, 2, 0, 4])
    transformed = tf.unstack(transformed, axis=-1)
    return transformed, cdna_kerns


def dna_transformation(prev_image, dna_input, kernel_size):
    """Apply dynamic neural advection to previous image.

    Args:
        prev_image: previous image to be transformed.
        dna_input: hidden layer to be used for computing DNA transformation;
            its last axis must hold one value per kernel tap.
        kernel_size: (height, width) of the per-pixel kernel, or a single int
            for a square kernel.

    Returns:
        A single tensor: the image transformed by the predicted per-pixel
        DNA kernels.
    """
    try:
        kernel_height, kernel_width = kernel_size
    except TypeError:
        # A bare int means a square kernel.
        kernel_height = kernel_width = kernel_size

    # Construct translated images.
    pad_along_height = kernel_height - 1
    pad_along_width = kernel_width - 1
    pad_top = pad_along_height // 2
    pad_bottom = pad_along_height - pad_top
    pad_left = pad_along_width // 2
    pad_right = pad_along_width - pad_left
    prev_image_pad = tf.pad(
        prev_image,
        [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]])
    image_height = int(prev_image.get_shape()[1])
    image_width = int(prev_image.get_shape()[2])

    # One shifted copy of the image per kernel tap, stacked on a new axis 3.
    inputs = []
    for xkern in range(kernel_height):
        for ykern in range(kernel_width):
            inputs.append(
                tf.expand_dims(
                    tf.slice(prev_image_pad, [0, xkern, ykern, 0],
                             [-1, image_height, image_width, -1]), [3]))
    inputs = tf.concat(axis=3, values=inputs)

    # Normalize channels to 1.
    kernel = tf.nn.relu(dna_input - RELU_SHIFT) + RELU_SHIFT
    kernel = tf.expand_dims(
        kernel / tf.reduce_sum(kernel, [3], keepdims=True), [4])
    return tf.reduce_sum(kernel * inputs, [3], keepdims=False)


def scheduled_sample(ground_truth_x, generated_x, batch_size,
                     num_ground_truth):
    """Sample batch with specified mix of ground truth and generated data points.

    Args:
        ground_truth_x: tensor of ground-truth data points.
        generated_x: tensor of generated data points.
        batch_size: batch size
        num_ground_truth: number of ground-truth examples to include in batch.

    Returns:
        New batch with num_ground_truth sampled from ground_truth_x and the
        rest from generated_x.
    """
    # Randomly partition the batch indices into the two sources.
    idx = tf.random_shuffle(tf.range(int(batch_size)))
    ground_truth_idx = tf.gather(idx, tf.range(num_ground_truth))
    generated_idx = tf.gather(idx, tf.range(num_ground_truth, int(batch_size)))

    ground_truth_examps = tf.gather(ground_truth_x, ground_truth_idx)
    generated_examps = tf.gather(generated_x, generated_idx)
    # Put every example back at its original batch position.
    return tf.dynamic_stitch([ground_truth_idx, generated_idx],
                             [ground_truth_examps, generated_examps])


def generator_fn(inputs, hparams=None):
    """Build the predictor from a dict of input tensors.

    Args:
        inputs: dict with 'images', 'actions' and 'states' tensors and an
            optional 'pix_distribs' tensor. Each is unstacked along axis 0
            into per-timestep tensors.
        hparams: hyperparameters; `kernel_size`, `schedule_sampling_k`,
            `num_masks`, `transformation` and `context_frames` are read.

    Returns:
        gen_images: the stacked generated images from index
            `context_frames - 1` onwards.
        outputs: dict with 'gen_images', 'gen_states', 'masks' and, when
            'pix_distribs' was given, 'gen_pix_distribs'.
    """
    images = tf.unstack(inputs['images'], axis=0)
    actions = tf.unstack(inputs['actions'], axis=0)
    states = tf.unstack(inputs['states'], axis=0)
    pix_distributions = (tf.unstack(inputs['pix_distribs'], axis=0)
                         if 'pix_distribs' in inputs else None)
    iter_num = tf.to_float(tf.train.get_or_create_global_step())

    gen_images, gen_states, gen_masks, gen_pix_distrib = \
        construct_model(images,
                        actions,
                        states,
                        iter_num=iter_num,
                        kernel_size=hparams.kernel_size,
                        k=hparams.schedule_sampling_k,
                        num_masks=hparams.num_masks,
                        cdna=hparams.transformation == 'cdna',
                        dna=hparams.transformation == 'dna',
                        stp=hparams.transformation == 'stp',
                        context_frames=hparams.context_frames,
                        pix_distributions=pix_distributions)
    outputs = {
        'gen_images': tf.stack(gen_images, axis=0),
        'gen_states': tf.stack(gen_states, axis=0),
        'masks': tf.stack([tf.stack(gen_mask_list, axis=-1)
                           for gen_mask_list in gen_masks], axis=0),
    }
    if 'pix_distribs' in inputs:
        outputs['gen_pix_distribs'] = tf.stack(gen_pix_distrib, axis=0)
    gen_images = outputs['gen_images'][hparams.context_frames - 1:]
    return gen_images, outputs


class DNAVideoPredictionModel(VideoPredictionModel):
    """VideoPredictionModel whose generator is `generator_fn` (CDNA/DNA/STP)."""

    def __init__(self, *args, **kwargs):
        super(DNAVideoPredictionModel, self).__init__(
            generator_fn, *args, **kwargs)

    def get_default_hparams_dict(self):
        """Return the parent defaults overlaid with this model's defaults."""
        default_hparams = super(DNAVideoPredictionModel,
                                self).get_default_hparams_dict()
        hparams = dict(
            batch_size=32,
            l1_weight=0.0,
            l2_weight=1.0,
            transformation='cdna',
            kernel_size=(9, 9),
            num_masks=10,
            schedule_sampling_k=900.0,
        )
        # Later items win, so the values above override the parent defaults.
        return dict(itertools.chain(default_hparams.items(), hparams.items()))

    def parse_hparams(self, hparams_dict, hparams):
        """Parse hparams; in 'test' mode force feeding back own predictions."""
        hparams = super(DNAVideoPredictionModel, self).parse_hparams(
            hparams_dict, hparams)
        if self.mode == 'test':
            def override_hparams_maybe(name, value):
                orig_value = hparams.values()[name]
                if orig_value != value:
                    print('Overriding hparams from %s=%r to %r for mode=%s.' %
                          (name, orig_value, value, self.mode))
                    hparams.set_hparam(name, value)
            # k == -1 disables scheduled sampling in construct_model.
            override_hparams_maybe('schedule_sampling_k', -1)
        return hparams