# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Model architecture for predictive model, including CDNA, DNA, and STP."""

import numpy as np
import tensorflow as tf

import tensorflow.contrib.slim as slim
from tensorflow.contrib.layers.python import layers as tf_layers
from lstm_ops import basic_conv_lstm_cell

# Amount to use when lower bounding tensors
RELU_SHIFT = 1e-12

# kernel size for DNA and CDNA.
DNA_KERN_SIZE = 5


def construct_model(images,
                    actions=None,
                    states=None,
                    iter_num=-1.0,
                    k=-1,
                    use_state=True,
                    num_masks=10,
                    stp=False,
                    cdna=True,
                    dna=False,
                    context_frames=2):
  """Build convolutional lstm video predictor using STP, CDNA, or DNA.

  Args:
    images: tensor of ground truth image sequences
    actions: tensor of action sequences
    states: tensor of ground truth state sequences
    iter_num: tensor of the current training iteration (for sched. sampling)
    k: constant used for scheduled sampling. -1 to feed in own prediction.
    use_state: True to include state and action in prediction
    num_masks: the number of different pixel motion predictions (and
               the number of masks for each of those predictions)
    stp: True to use Spatial Transformer Predictor (STP)
    cdna: True to use Convoluational Dynamic Neural Advection (CDNA)
    dna: True to use Dynamic Neural Advection (DNA)
    context_frames: number of ground truth frames to pass in before
                    feeding in own predictions
  Returns:
    gen_images: predicted future image frames
    gen_states: predicted future states

  Raises:
    ValueError: if more than one network option specified or more than 1 mask
    specified for DNA model.
  """
  if stp + cdna + dna != 1:
    raise ValueError('More than one, or no network option specified.')
  batch_size, img_height, img_width, color_channels = images[0].get_shape()[0:4]
  lstm_func = basic_conv_lstm_cell

  # Generated robot states and images.
  gen_states, gen_images = [], []
  current_state = states[0]

  if k == -1:
    feedself = True
  else:
    # Scheduled sampling:
    # Calculate number of ground-truth frames to pass in.
    num_ground_truth = tf.to_int32(
        tf.round(tf.to_float(batch_size) * (k / (k + tf.exp(iter_num / k)))))
    feedself = False

  # LSTM state sizes and states.
  lstm_size = np.int32(np.array([32, 32, 64, 64, 128, 64, 32]))
  lstm_state1, lstm_state2, lstm_state3, lstm_state4 = None, None, None, None
  lstm_state5, lstm_state6, lstm_state7 = None, None, None

  for image, action in zip(images[:-1], actions[:-1]):
    # Reuse variables after the first timestep.
    reuse = bool(gen_images)

    done_warm_start = len(gen_images) > context_frames - 1
    with slim.arg_scope(
        [lstm_func, slim.layers.conv2d, slim.layers.fully_connected,
         tf_layers.layer_norm, slim.layers.conv2d_transpose],
        reuse=reuse):

      if feedself and done_warm_start:
        # Feed in generated image.
        prev_image = gen_images[-1]
      elif done_warm_start:
        # Scheduled sampling
        prev_image = scheduled_sample(image, gen_images[-1], batch_size,
                                      num_ground_truth)
      else:
        # Always feed in ground_truth
        prev_image = image

      # Predicted state is always fed back in
      state_action = tf.concat(1, [action, current_state])

      enc0 = slim.layers.conv2d(
          prev_image,
          32, [5, 5],
          stride=2,
          scope='scale1_conv1',
          normalizer_fn=tf_layers.layer_norm,
          normalizer_params={'scope': 'layer_norm1'})

      hidden1, lstm_state1 = lstm_func(
          enc0, lstm_state1, lstm_size[0], scope='state1')
      hidden1 = tf_layers.layer_norm(hidden1, scope='layer_norm2')
      hidden2, lstm_state2 = lstm_func(
          hidden1, lstm_state2, lstm_size[1], scope='state2')
      hidden2 = tf_layers.layer_norm(hidden2, scope='layer_norm3')
      enc1 = slim.layers.conv2d(
          hidden2, hidden2.get_shape()[3], [3, 3], stride=2, scope='conv2')

      hidden3, lstm_state3 = lstm_func(
          enc1, lstm_state3, lstm_size[2], scope='state3')
      hidden3 = tf_layers.layer_norm(hidden3, scope='layer_norm4')
      hidden4, lstm_state4 = lstm_func(
          hidden3, lstm_state4, lstm_size[3], scope='state4')
      hidden4 = tf_layers.layer_norm(hidden4, scope='layer_norm5')
      enc2 = slim.layers.conv2d(
          hidden4, hidden4.get_shape()[3], [3, 3], stride=2, scope='conv3')

      # Pass in state and action.
      smear = tf.reshape(
          state_action,
          [int(batch_size), 1, 1, int(state_action.get_shape()[1])])
      smear = tf.tile(
          smear, [1, int(enc2.get_shape()[1]), int(enc2.get_shape()[2]), 1])
      if use_state:
        enc2 = tf.concat(3, [enc2, smear])
      enc3 = slim.layers.conv2d(
          enc2, hidden4.get_shape()[3], [1, 1], stride=1, scope='conv4')

      hidden5, lstm_state5 = lstm_func(
          enc3, lstm_state5, lstm_size[4], scope='state5')  # last 8x8
      hidden5 = tf_layers.layer_norm(hidden5, scope='layer_norm6')
      enc4 = slim.layers.conv2d_transpose(
          hidden5, hidden5.get_shape()[3], 3, stride=2, scope='convt1')

      hidden6, lstm_state6 = lstm_func(
          enc4, lstm_state6, lstm_size[5], scope='state6')  # 16x16
      hidden6 = tf_layers.layer_norm(hidden6, scope='layer_norm7')
      # Skip connection.
      hidden6 = tf.concat(3, [hidden6, enc1])  # both 16x16

      enc5 = slim.layers.conv2d_transpose(
          hidden6, hidden6.get_shape()[3], 3, stride=2, scope='convt2')
      hidden7, lstm_state7 = lstm_func(
          enc5, lstm_state7, lstm_size[6], scope='state7')  # 32x32
      hidden7 = tf_layers.layer_norm(hidden7, scope='layer_norm8')

      # Skip connection.
      hidden7 = tf.concat(3, [hidden7, enc0])  # both 32x32

      enc6 = slim.layers.conv2d_transpose(
          hidden7,
          hidden7.get_shape()[3], 3, stride=2, scope='convt3',
          normalizer_fn=tf_layers.layer_norm,
          normalizer_params={'scope': 'layer_norm9'})

      if dna:
        # Using largest hidden state for predicting untied conv kernels.
        enc7 = slim.layers.conv2d_transpose(
            enc6, DNA_KERN_SIZE**2, 1, stride=1, scope='convt4')
      else:
        # Using largest hidden state for predicting a new image layer.
        enc7 = slim.layers.conv2d_transpose(
            enc6, color_channels, 1, stride=1, scope='convt4')
        # This allows the network to also generate one image from scratch,
        # which is useful when regions of the image become unoccluded.
        transformed = [tf.nn.sigmoid(enc7)]

      if stp:
        stp_input0 = tf.reshape(hidden5, [int(batch_size), -1])
        stp_input1 = slim.layers.fully_connected(
            stp_input0, 100, scope='fc_stp')
        transformed += stp_transformation(prev_image, stp_input1, num_masks)
      elif cdna:
        cdna_input = tf.reshape(hidden5, [int(batch_size), -1])
        transformed += cdna_transformation(prev_image, cdna_input, num_masks,
                                           int(color_channels))
      elif dna:
        # Only one mask is supported (more should be unnecessary).
        if num_masks != 1:
          raise ValueError('Only one mask is supported for DNA model.')
        transformed = [dna_transformation(prev_image, enc7)]

      masks = slim.layers.conv2d_transpose(
          enc6, num_masks + 1, 1, stride=1, scope='convt7')
      masks = tf.reshape(
          tf.nn.softmax(tf.reshape(masks, [-1, num_masks + 1])),
          [int(batch_size), int(img_height), int(img_width), num_masks + 1])
      mask_list = tf.split(3, num_masks + 1, masks)
      output = mask_list[0] * prev_image
      for layer, mask in zip(transformed, mask_list[1:]):
        output += layer * mask
      gen_images.append(output)

      current_state = slim.layers.fully_connected(
          state_action,
          int(current_state.get_shape()[1]),
          scope='state_pred',
          activation_fn=None)
      gen_states.append(current_state)

  return gen_images, gen_states


## Utility functions
def stp_transformation(prev_image, stp_input, num_masks):
  """Apply spatial transformer predictor (STP) to previous image.

  Args:
    prev_image: previous image to be transformed.
    stp_input: hidden layer to be used for computing STN parameters.
    num_masks: number of masks and hence the number of STP transformations.
  Returns:
    List of images transformed by the predicted STP parameters.
  """
  # Only import spatial transformer if needed.
  from spatial_transformer import transformer

  identity_params = tf.convert_to_tensor(
      np.array([1.0, 0.0, 0.0, 0.0, 1.0, 0.0], np.float32))
  transformed = []
  for i in range(num_masks - 1):
    params = slim.layers.fully_connected(
        stp_input, 6, scope='stp_params' + str(i),
        activation_fn=None) + identity_params
    transformed.append(transformer(prev_image, params))

  return transformed


def cdna_transformation(prev_image, cdna_input, num_masks, color_channels):
  """Apply convolutional dynamic neural advection to previous image.

  Args:
    prev_image: previous image to be transformed.
    cdna_input: hidden lyaer to be used for computing CDNA kernels.
    num_masks: the number of masks and hence the number of CDNA transformations.
    color_channels: the number of color channels in the images.
  Returns:
    List of images transformed by the predicted CDNA kernels.
  """
  batch_size = int(cdna_input.get_shape()[0])

  # Predict kernels using linear function of last hidden layer.
  cdna_kerns = slim.layers.fully_connected(
      cdna_input,
      DNA_KERN_SIZE * DNA_KERN_SIZE * num_masks,
      scope='cdna_params',
      activation_fn=None)

  # Reshape and normalize.
  cdna_kerns = tf.reshape(
      cdna_kerns, [batch_size, DNA_KERN_SIZE, DNA_KERN_SIZE, 1, num_masks])
  cdna_kerns = tf.nn.relu(cdna_kerns - RELU_SHIFT) + RELU_SHIFT
  norm_factor = tf.reduce_sum(cdna_kerns, [1, 2, 3], keep_dims=True)
  cdna_kerns /= norm_factor

  cdna_kerns = tf.tile(cdna_kerns, [1, 1, 1, color_channels, 1])
  cdna_kerns = tf.split(0, batch_size, cdna_kerns)
  prev_images = tf.split(0, batch_size, prev_image)

  # Transform image.
  transformed = []
  for kernel, preimg in zip(cdna_kerns, prev_images):
    kernel = tf.squeeze(kernel)
    if len(kernel.get_shape()) == 3:
      kernel = tf.expand_dims(kernel, -1)
    transformed.append(
        tf.nn.depthwise_conv2d(preimg, kernel, [1, 1, 1, 1], 'SAME'))
  transformed = tf.concat(0, transformed)
  transformed = tf.split(3, num_masks, transformed)
  return transformed


def dna_transformation(prev_image, dna_input):
  """Apply dynamic neural advection to previous image.

  Args:
    prev_image: previous image to be transformed.
    dna_input: hidden lyaer to be used for computing DNA transformation.
  Returns:
    List of images transformed by the predicted CDNA kernels.
  """
  # Construct translated images.
  prev_image_pad = tf.pad(prev_image, [[0, 0], [2, 2], [2, 2], [0, 0]])
  image_height = int(prev_image.get_shape()[1])
  image_width = int(prev_image.get_shape()[2])

  inputs = []
  for xkern in range(DNA_KERN_SIZE):
    for ykern in range(DNA_KERN_SIZE):
      inputs.append(
          tf.expand_dims(
              tf.slice(prev_image_pad, [0, xkern, ykern, 0],
                       [-1, image_height, image_width, -1]), [3]))
  inputs = tf.concat(3, inputs)

  # Normalize channels to 1.
  kernel = tf.nn.relu(dna_input - RELU_SHIFT) + RELU_SHIFT
  kernel = tf.expand_dims(
      kernel / tf.reduce_sum(
          kernel, [3], keep_dims=True), [4])
  return tf.reduce_sum(kernel * inputs, [3], keep_dims=False)


def scheduled_sample(ground_truth_x, generated_x, batch_size, num_ground_truth):
  """Sample batch with specified mix of ground truth and generated data points.

  Args:
    ground_truth_x: tensor of ground-truth data points.
    generated_x: tensor of generated data points.
    batch_size: batch size
    num_ground_truth: number of ground-truth examples to include in batch.
  Returns:
    New batch with num_ground_truth sampled from ground_truth_x and the rest
    from generated_x.
  """
  idx = tf.random_shuffle(tf.range(int(batch_size)))
  ground_truth_idx = tf.gather(idx, tf.range(num_ground_truth))
  generated_idx = tf.gather(idx, tf.range(num_ground_truth, int(batch_size)))

  ground_truth_examps = tf.gather(ground_truth_x, ground_truth_idx)
  generated_examps = tf.gather(generated_x, generated_idx)
  return tf.dynamic_stitch([ground_truth_idx, generated_idx],
                           [ground_truth_examps, generated_examps])