# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contains a collection of models which operate on variable-length sequences.
"""
import math

import models
import video_level_models
import tensorflow as tf
import model_utils as utils

import tensorflow.contrib.slim as slim
from tensorflow import flags
from tensorflow import logging

FLAGS = flags.FLAGS


class RangeLogisticModel(models.BaseModel):

  def create_model(self, model_input, vocab_size, num_frames, **unused_params):
    """Creates a model which uses a logistic classifier over the average of the
    frame-level features.

    This class is intended to be an example for implementors of frame level
    models. If you want to train a model over averaged features it is more
    efficient to average them beforehand rather than on the fly.

    Args:
      model_input: A 'batch_size' x 'max_frames' x 'num_features' matrix of
                   input features.
      vocab_size: The number of classes in the dataset.
      num_frames: A vector of length 'batch' which indicates the number of
           frames for each video (before padding).

    Returns:
      A dictionary with a tensor containing the probability predictions of the
      model in the 'predictions' key. The dimensions of the tensor are
      'batch_size' x 'num_classes'.
    """
#    num_frames = tf.cast(tf.expand_dims(num_frames, 1), tf.float32)
#    feature_size = model_input.get_shape().as_list()[2]

#    denominators = tf.reshape(
#        tf.tile(num_frames, [1, feature_size]), [-1, feature_size])
#    avg_pooled = tf.reduce_sum(model_input,
#                               axis=[1]) / denominators
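    # Range pooling: for every feature, take the maximum minus the minimum
    # value across the frame axis.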
    range_pooled = tf.reduce_max(model_input, axis=[1]) - \
                    tf.reduce_min(model_input, axis=[1])
    output = slim.fully_connected(
        range_pooled, vocab_size, activation_fn=tf.nn.sigmoid,
        weights_regularizer=slim.l2_regularizer(1e-4))
    return {"predictions": output}

class FNN_mvt_Model(models.BaseModel):

  def create_model(self, model_input, vocab_size, num_frames,
                   l2_penalty=1e-4, is_training=True, **unused_params):
    """Creates a model which uses a logistic classifier over the average of the
    frame-level features.

    This class is intended to be an example for implementors of frame level
    models. If you want to train a model over averaged features it is more
    efficient to average them beforehand rather than on the fly.

    Args:
      model_input: A 'batch_size' x 'max_frames' x 'num_features' matrix of
                   input features.
      vocab_size: The number of classes in the dataset.
      num_frames: A vector of length 'batch' which indicates the number of
           frames for each video (before padding).

    Returns:
      A dictionary with a tensor containing the probability predictions of the
      model in the 'predictions' key. The dimensions of the tensor are
      'batch_size' x 'num_classes'.
    """
    
    inter_f_mean, inter_f_var = tf.nn.moments(model_input, [1])
    inter_f_std = tf.sqrt(inter_f_var)
    
    kk = 3
    xt = tf.transpose(model_input, perm=[0,2,1])
    tk = tf.nn.top_k(xt, kk).values    

    logging.info('xt:   {}'.format(xt.get_shape().as_list()))
    logging.info('tk:   {}'.format(tk.get_shape().as_list()))

    topk = tf.reshape(tk, [-1, kk * tk.get_shape().as_list()[1]])
    logging.info('topk: {}'.format(topk.get_shape().as_list()))
 
#    inter_f_feats = tf.concat([inter_f_mean, inter_f_std], 1)
    inter_f_feats = tf.concat([inter_f_mean, inter_f_std, topk], 1)
    
    logging.info('inter_f_mean: {}'.format(inter_f_mean.get_shape().as_list()))
    logging.info('feats: {}'.format(inter_f_feats.get_shape().as_list()))
    
    tf.summary.histogram("inter_f_mean", inter_f_mean)
    tf.summary.histogram("inter_f_std", inter_f_std)
    
    with tf.name_scope('FNN_mvt_Model'):
        A0 = slim.batch_norm(
          inter_f_feats,
          center=True,
          scale=True,
          is_training=is_training,
          scope="BN")
        
        h1Units = 3600
        A1 = slim.fully_connected(
                A0, h1Units, activation_fn=tf.nn.relu,
                weights_regularizer=slim.l2_regularizer(l2_penalty),
                scope='FC_H1')
        output = slim.fully_connected(
                A1, vocab_size, activation_fn=tf.nn.sigmoid,
                weights_regularizer=slim.l2_regularizer(l2_penalty),
                scope='FC_P')
    return {"predictions": output}

class DbofModel2(models.BaseModel):
  """Creates a Deep Bag of Frames model.

  The model projects the features for each frame into a higher dimensional
  'clustering' space, pools across frames in that space, and then
  uses a configurable video-level model to classify the now aggregated features.

  The model will randomly sample either frames or sequences of frames during
  training to speed up convergence.

  Args:
    model_input: A 'batch_size' x 'max_frames' x 'num_features' matrix of
                 input features.
    vocab_size: The number of classes in the dataset.
    num_frames: A vector of length 'batch' which indicates the number of
         frames for each video (before padding).

  Returns:
    A dictionary with a tensor containing the probability predictions of the
    model in the 'predictions' key. The dimensions of the tensor are
    'batch_size' x 'num_classes'.
  """

  def create_model(self,
                   model_input,
                   vocab_size,
                   num_frames,
                   iterations=None,
                   add_batch_norm=None,
                   sample_random_frames=None,
                   cluster_size=None,
                   hidden_size=None,
                   is_training=True,
                   **unused_params):
    iterations = iterations or FLAGS.iterations
    add_batch_norm = add_batch_norm or FLAGS.dbof_add_batch_norm
    random_frames = sample_random_frames or FLAGS.sample_random_frames
    cluster_size = cluster_size or FLAGS.dbof_cluster_size
    hidden1_size = hidden_size or FLAGS.dbof_hidden_size

    num_frames = tf.cast(tf.expand_dims(num_frames, 1), tf.float32)
    if random_frames:
      model_input = utils.SampleRandomFrames(model_input, num_frames,
                                             iterations)
    else:
      model_input = utils.SampleRandomSequence(model_input, num_frames,
                                               iterations)
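    # Flatten all sampled frames into a (batch_size * iterations) x feature_size
    # matrix so that the cluster projection below is a single matmul.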
    max_frames = model_input.get_shape().as_list()[1]
    feature_size = model_input.get_shape().as_list()[2]
    reshaped_input = tf.reshape(model_input, [-1, feature_size])
    tf.summary.histogram("input_hist", reshaped_input)

    if add_batch_norm:
      reshaped_input = slim.batch_norm(
          reshaped_input,
          center=True,
          scale=True,
          is_training=is_training,
          scope="input_bn")

    cluster_weights = tf.get_variable("cluster_weights",
      [feature_size, cluster_size],
      initializer = tf.random_normal_initializer(stddev=1 / math.sqrt(feature_size)))
    tf.summary.histogram("cluster_weights", cluster_weights)
    activation = tf.matmul(reshaped_input, cluster_weights)
    if add_batch_norm:
      activation = slim.batch_norm(
          activation,
          center=True,
          scale=True,
          is_training=is_training,
          scope="cluster_bn")
    else:
      cluster_biases = tf.get_variable("cluster_biases",
        [cluster_size],
        initializer=tf.random_normal_initializer(stddev=1 / math.sqrt(feature_size)))
      tf.summary.histogram("cluster_biases", cluster_biases)
      activation += cluster_biases
    activation = tf.nn.relu6(activation)
    tf.summary.histogram("cluster_output", activation)

    # Restore the frame axis, then pool across frames so that every video is
    # reduced to a single cluster_size vector.
    activation = tf.reshape(activation, [-1, max_frames, cluster_size])
    activation = utils.FramePooling(activation, FLAGS.dbof_pooling_method)

    hidden1_weights = tf.get_variable("hidden1_weights",
      [cluster_size, hidden1_size],
      initializer=tf.random_normal_initializer(stddev=1 / math.sqrt(cluster_size)))
    tf.summary.histogram("hidden1_weights", hidden1_weights)
    activation = tf.matmul(activation, hidden1_weights)
    if add_batch_norm:
      activation = slim.batch_norm(
          activation,
          center=True,
          scale=True,
          is_training=is_training,
          scope="hidden1_bn")
    else:
      hidden1_biases = tf.get_variable("hidden1_biases",
        [hidden1_size],
        initializer = tf.random_normal_initializer(stddev=0.01))
      tf.summary.histogram("hidden1_biases", hidden1_biases)
      activation += hidden1_biases
    activation = tf.nn.relu6(activation)
    tf.summary.histogram("hidden1_output", activation)

    aggregated_model = getattr(video_level_models,
                               FLAGS.video_level_classifier_model)
    return aggregated_model().create_model(
        model_input=activation,
        vocab_size=vocab_size,
        **unused_params)

class LstmModel2(models.BaseModel):

  def create_model(self, model_input, vocab_size, num_frames, **unused_params):
    """Creates a model which uses a stack of LSTMs to represent the video.

    Args:
      model_input: A 'batch_size' x 'max_frames' x 'num_features' matrix of
                   input features.
      vocab_size: The number of classes in the dataset.
      num_frames: A vector of length 'batch' which indicates the number of
           frames for each video (before padding).

    Returns:
      A dictionary with a tensor containing the probability predictions of the
      model in the 'predictions' key. The dimensions of the tensor are
      'batch_size' x 'num_classes'.
    """
    lstm_size = FLAGS.lstm_cells
    number_of_layers = FLAGS.lstm_layers

    # Stack 'number_of_layers' LSTM cells into a single multi-layer RNN cell.
    stacked_lstm = tf.contrib.rnn.MultiRNNCell(
            [
                tf.contrib.rnn.BasicLSTMCell(
                    lstm_size, forget_bias=1.0, state_is_tuple=False)
                for _ in range(number_of_layers)
                ], state_is_tuple=False)

    with tf.variable_scope("RNN"):
      outputs, state = tf.nn.dynamic_rnn(stacked_lstm, model_input,
                                         sequence_length=num_frames,
                                         dtype=tf.float32)
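      # 'state' is a single concatenated tensor (state_is_tuple=False) holding
      # the final states of every LSTM layer; it serves as the fixed-length
      # video representation passed to the video-level classifier.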

    aggregated_model = getattr(video_level_models,
                               FLAGS.video_level_classifier_model)
    return aggregated_model().create_model(
        model_input=state,
        vocab_size=vocab_size,
        num_mixtures=2,
        **unused_params)

class FMoeModel1(models.BaseModel):

  def create_model(self, model_input, vocab_size, num_frames,
                   l2_penalty=1e-4, is_training=True, **unused_params):
    """Creates a model which uses a logistic classifier over the average of the
    frame-level features.

    This class is intended to be an example for implementors of frame level
    models. If you want to train a model over averaged features it is more
    efficient to average them beforehand rather than on the fly.

    Args:
      model_input: A 'batch_size' x 'max_frames' x 'num_features' matrix of
                   input features.
      vocab_size: The number of classes in the dataset.
      num_frames: A vector of length 'batch' which indicates the number of
           frames for each video (before padding).

    Returns:
      A dictionary with a tensor containing the probability predictions of the
      model in the 'predictions' key. The dimensions of the tensor are
      'batch_size' x 'num_classes'.
    """

                          
    inter_f_mean, inter_f_var = tf.nn.moments(model_input, [1])
    inter_f_std = tf.sqrt(inter_f_var)
    
    kk = 5
    xt = tf.transpose(model_input, perm=[0,2,1])
    tk = tf.nn.top_k(xt, kk).values    

    logging.info('xt:   {}'.format(xt.get_shape().as_list()))
    logging.info('tk:   {}'.format(tk.get_shape().as_list()))

    topk = tf.reshape(tk, [-1, kk * tk.get_shape().as_list()[1]])
    logging.info('topk: {}'.format(topk.get_shape().as_list()))
 
#    inter_f_feats = tf.concat([inter_f_mean, inter_f_std], 1)
    inter_f_feats = tf.concat([inter_f_mean, inter_f_std, topk], 1)
    
    logging.info('inter_f_mean: {}'.format(inter_f_mean.get_shape().as_list()))
    logging.info('feats: {}'.format(inter_f_feats.get_shape().as_list()))
    
    tf.summary.histogram("inter_f_mean", inter_f_mean)
    tf.summary.histogram("inter_f_std", inter_f_std)
        
    A0 = slim.batch_norm(
          inter_f_feats,
          center=True,
          scale=True,
          is_training=is_training,
          scope="BN")
    
    aggregated_model = getattr(video_level_models,
                               FLAGS.video_level_classifier_model)
    return aggregated_model().create_model(
        model_input=A0,
        vocab_size=vocab_size,
        num_mixtures=2,
        **unused_params)
    
class FMoeModel2(models.BaseModel):

  def create_model(self, model_input, vocab_size, num_frames,
                   l2_penalty=1e-4, **unused_params):
    """Creates a model which uses a logistic classifier over the average of the
    frame-level features.

    This class is intended to be an example for implementors of frame level
    models. If you want to train a model over averaged features it is more
    efficient to average them beforehand rather than on the fly.

    Args:
      model_input: A 'batch_size' x 'max_frames' x 'num_features' matrix of
                   input features.
      vocab_size: The number of classes in the dataset.
      num_frames: A vector of length 'batch' which indicates the number of
           frames for each video (before padding).

    Returns:
      A dictionary with a tensor containing the probability predictions of the
      model in the 'predictions' key. The dimensions of the tensor are
      'batch_size' x 'num_classes'.
    """
#    num_frames = tf.cast(tf.expand_dims(num_frames, 1), tf.float32)
#    feature_size = model_input.get_shape().as_list()[2]
#        
#    logging.info('model_input shape: {}'.format(
#            model_input.get_shape().as_list()))
#
#    denominators = tf.reshape(
#        tf.tile(num_frames, [1, feature_size]), [-1, feature_size])
#    avg_pooled = tf.reduce_sum(model_input, axis=[1]) / denominators
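
    # Mean-pool the frame-level features so that each video is represented by a
    # single num_features vector.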
    
    avg_pooled = utils.FramePooling(model_input, 'average')
    
    logging.info('avg_pooled shape: {}'.format(
        avg_pooled.get_shape().as_list()))
    
    aggregated_model = getattr(video_level_models,
                               FLAGS.video_level_classifier_model)
    return aggregated_model().create_model(
        model_input=avg_pooled,
        vocab_size=vocab_size,
        num_mixtures=2,
        **unused_params)