# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Skip-Thoughts model for learning sentence vectors.

The model is based on the paper:

  "Skip-Thought Vectors"
  Ryan Kiros, Yukun Zhu, Ruslan Salakhutdinov, Richard S. Zemel,
  Antonio Torralba, Raquel Urtasun, Sanja Fidler.
  https://papers.nips.cc/paper/5950-skip-thought-vectors.pdf

Layer normalization is applied based on the paper:

  "Layer Normalization"
  Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton
  https://arxiv.org/abs/1607.06450
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


import tensorflow as tf

import importlib

# The package directory name ("model-skipthoughts") contains a hyphen, so the
# ops modules cannot be imported with a plain `import` statement; load them
# via importlib instead.
gru_cell = importlib.import_module('model-skipthoughts.ops.gru_cell')
input_ops = importlib.import_module('model-skipthoughts.ops.input_ops')


def random_orthonormal_initializer(shape, dtype=tf.float32,
                                   partition_info=None):  # pylint: disable=unused-argument
  """Variable initializer that produces a random orthonormal matrix."""
  if len(shape) != 2 or shape[0] != shape[1]:
    raise ValueError("Expecting square shape, got %s" % shape)
  _, u, _ = tf.svd(tf.random_normal(shape, dtype=dtype), full_matrices=True)
  return u
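
# Quick sanity check of the initializer above (an illustrative sketch, not part
# of the model graph; assumes numpy is available as `np`): the returned matrix
# times its transpose should be close to the identity.
#
#   init = random_orthonormal_initializer([4, 4])
#   with tf.Session() as sess:
#     u = sess.run(init)
#   np.allclose(u.dot(u.T), np.eye(4), atol=1e-5)  # expected: True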


class SkipThoughtsModel(object):
  """Skip-thoughts model."""

  def __init__(self, config, mode="train", input_reader=None):
    """Basic setup. The actual TensorFlow graph is constructed in build().

    Args:
      config: Object containing configuration parameters.
      mode: "train", "eval" or "encode".
      input_reader: Instance of a tf.ReaderBase subclass for reading the input
        serialized tf.Example protocol buffers. Defaults to a TFRecordReader.

    Raises:
      ValueError: If mode is invalid.
    """
    if mode not in ["train", "eval", "encode"]:
      raise ValueError("Unrecognized mode: %s" % mode)

    self.config = config
    self.mode = mode
    self.reader = input_reader if input_reader else tf.TFRecordReader()

    # Initializer used for non-recurrent weights.
    self.uniform_initializer = tf.random_uniform_initializer(
        minval=-self.config.uniform_init_scale,
        maxval=self.config.uniform_init_scale)

    # Input sentences represented as sequences of word ids. "encode" is the
    # source sentence, "decode_pre" is the previous sentence and "decode_post"
    # is the next sentence.
    # Each is an int64 Tensor with shape [batch_size, padded_length].
    self.encode_ids = None
    self.decode_pre_ids = None
    self.decode_post_ids = None

    # Boolean masks distinguishing real words (1) from padded words (0).
    # Each is an int32 Tensor with shape [batch_size, padded_length] (an int8
    # placeholder when mode is "encode"; see build_inputs()).
    self.encode_mask = None
    self.decode_pre_mask = None
    self.decode_post_mask = None

    # Input sentences represented as sequences of word embeddings.
    # Each is a float32 Tensor with shape [batch_size, padded_length, emb_dim].
    self.encode_emb = None
    self.decode_pre_emb = None
    self.decode_post_emb = None

    # The output from the sentence encoder.
    # A float32 Tensor with shape [batch_size, num_gru_units].
    self.thought_vectors = None

    # The cross entropy losses and corresponding weights of the decoders. Used
    # for evaluation.
    self.target_cross_entropy_losses = []
    self.target_cross_entropy_loss_weights = []

    # The total loss to optimize.
    self.total_loss = None

  def build_inputs(self):
    """Builds the ops for reading input data.

    Outputs:
      self.encode_ids
      self.decode_pre_ids
      self.decode_post_ids
      self.encode_mask
      self.decode_pre_mask
      self.decode_post_mask
    """
    if self.mode == "encode":
      # Word embeddings are fed from an external vocabulary which has possibly
      # been expanded (see vocabulary_expansion.py).
      encode_ids = None
      decode_pre_ids = None
      decode_post_ids = None
      encode_mask = tf.placeholder(tf.int8, (None, None), name="encode_mask")
      decode_pre_mask = None
      decode_post_mask = None
    else:
      # Prefetch serialized tf.Example protos.
      input_queue = input_ops.prefetch_input_data(
          self.reader,
          self.config.input_file_pattern,
          shuffle=self.config.shuffle_input_data,
          capacity=self.config.input_queue_capacity,
          num_reader_threads=self.config.num_input_reader_threads)

      # Deserialize a batch.
      serialized = input_queue.dequeue_many(self.config.batch_size)
      encode, decode_pre, decode_post = input_ops.parse_example_batch(
          serialized)

      encode_ids = encode.ids
      decode_pre_ids = decode_pre.ids
      decode_post_ids = decode_post.ids

      encode_mask = encode.mask
      decode_pre_mask = decode_pre.mask
      decode_post_mask = decode_post.mask

    self.encode_ids = encode_ids
    self.decode_pre_ids = decode_pre_ids
    self.decode_post_ids = decode_post_ids

    self.encode_mask = encode_mask
    self.decode_pre_mask = decode_pre_mask
    self.decode_post_mask = decode_post_mask

  def build_word_embeddings(self):
    """Builds the word embeddings.

    Inputs:
      self.encode_ids
      self.decode_pre_ids
      self.decode_post_ids

    Outputs:
      self.encode_emb
      self.decode_pre_emb
      self.decode_post_emb
    """
    if self.mode == "encode":
      # Word embeddings are fed from an external vocabulary which has possibly
      # been expanded (see vocabulary_expansion.py).
      encode_emb = tf.placeholder(tf.float32, (
          None, None, self.config.word_embedding_dim), "encode_emb")
      # No sequences to decode.
      decode_pre_emb = None
      decode_post_emb = None
    else:
      word_emb = tf.get_variable(
          name="word_embedding",
          shape=[self.config.vocab_size, self.config.word_embedding_dim],
          initializer=self.uniform_initializer)

      encode_emb = tf.nn.embedding_lookup(word_emb, self.encode_ids)
      decode_pre_emb = tf.nn.embedding_lookup(word_emb, self.decode_pre_ids)
      decode_post_emb = tf.nn.embedding_lookup(word_emb, self.decode_post_ids)

    self.encode_emb = encode_emb
    self.decode_pre_emb = decode_pre_emb
    self.decode_post_emb = decode_post_emb
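
    # In "encode" mode, `self.encode_emb` (and `self.encode_mask` from
    # build_inputs()) are placeholders that must be fed externally, e.g.
    # (an illustrative sketch; `embeddings` is a numpy array of pre-looked-up
    # word vectors and `mask` marks the non-padded positions):
    #
    #   sess.run(model.thought_vectors,
    #            feed_dict={"encode_emb:0": embeddings, "encode_mask:0": mask})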

  def _initialize_gru_cell(self, num_units):
    """Initializes a GRU cell.

    The Variables of the GRU cell are initialized in a way that exactly matches
    the skip-thoughts paper: recurrent weights are initialized from random
    orthonormal matrices and non-recurrent weights are initialized from random
    uniform matrices.

    Args:
      num_units: Number of output units.

    Returns:
      cell: An instance of RNNCell with variable initializers that match the
        skip-thoughts paper.
    """
    return gru_cell.LayerNormGRUCell(
        num_units,
        w_initializer=self.uniform_initializer,
        u_initializer=random_orthonormal_initializer,
        b_initializer=tf.constant_initializer(0.0))

  def build_encoder(self):
    """Builds the sentence encoder.

    Inputs:
      self.encode_emb
      self.encode_mask

    Outputs:
      self.thought_vectors

    Raises:
      ValueError: if config.bidirectional_encoder is True and config.encoder_dim
        is odd.
    """
    with tf.variable_scope("encoder") as scope:
      length = tf.to_int32(tf.reduce_sum(self.encode_mask, 1), name="length")

      if self.config.bidirectional_encoder:
        if self.config.encoder_dim % 2:
          raise ValueError(
              "encoder_dim must be even when using a bidirectional encoder.")
        num_units = self.config.encoder_dim // 2
        cell_fw = self._initialize_gru_cell(num_units)  # Forward encoder
        cell_bw = self._initialize_gru_cell(num_units)  # Backward encoder
        _, states = tf.nn.bidirectional_dynamic_rnn(
            cell_fw=cell_fw,
            cell_bw=cell_bw,
            inputs=self.encode_emb,
            sequence_length=length,
            dtype=tf.float32,
            scope=scope)
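        # Concatenate the final forward and backward states (each of size
        # encoder_dim / 2) so the thought vector has encoder_dim units.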
        thought_vectors = tf.concat(states, 1, name="thought_vectors")
      else:
        cell = self._initialize_gru_cell(self.config.encoder_dim)
        _, state = tf.nn.dynamic_rnn(
            cell=cell,
            inputs=self.encode_emb,
            sequence_length=length,
            dtype=tf.float32,
            scope=scope)
        # Use an identity operation to name the Tensor in the Graph.
        thought_vectors = tf.identity(state, name="thought_vectors")

    self.thought_vectors = thought_vectors

  def _build_decoder(self, name, embeddings, targets, mask, initial_state,
                     reuse_logits):
    """Builds a sentence decoder.

    Args:
      name: Decoder name.
      embeddings: Batch of sentences to decode; a float32 Tensor with shape
        [batch_size, padded_length, emb_dim].
      targets: Batch of target word ids; an int64 Tensor with shape
        [batch_size, padded_length].
      mask: A 0/1 Tensor with shape [batch_size, padded_length].
      initial_state: Initial state of the GRU. A float32 Tensor with shape
        [batch_size, num_gru_units].
      reuse_logits: Whether to reuse the logits weights.
    """
    # Decoder RNN.
    cell = self._initialize_gru_cell(self.config.encoder_dim)
    with tf.variable_scope(name) as scope:
      # Add a padding word at the start of each sentence (to correspond to the
      # prediction of the first word) and remove the last word.
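      # Illustrative example: for a padded target sentence [w1, w2, w3, <pad>],
      # the decoder inputs are [<zeros>, emb(w1), emb(w2), emb(w3)] and the
      # targets remain [w1, w2, w3, <pad>]; the final padded position is zeroed
      # out of the loss by `mask` below.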
      decoder_input = tf.pad(
          embeddings[:, :-1, :], [[0, 0], [1, 0], [0, 0]], name="input")
      length = tf.reduce_sum(mask, 1, name="length")
      decoder_output, _ = tf.nn.dynamic_rnn(
          cell=cell,
          inputs=decoder_input,
          sequence_length=length,
          initial_state=initial_state,
          scope=scope)

    # Stack batch vertically.
    decoder_output = tf.reshape(decoder_output, [-1, self.config.encoder_dim])
    targets = tf.reshape(targets, [-1])
    weights = tf.to_float(tf.reshape(mask, [-1]))

    # Logits.
    with tf.variable_scope("logits", reuse=reuse_logits) as scope:
      logits = tf.contrib.layers.fully_connected(
          inputs=decoder_output,
          num_outputs=self.config.vocab_size,
          activation_fn=None,
          weights_initializer=self.uniform_initializer,
          scope=scope)

    losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=targets, logits=logits)
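    # Sum the per-word cross entropy over non-padded positions only; padded
    # positions contribute zero because their weights are 0.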
    batch_loss = tf.reduce_sum(losses * weights)
    tf.losses.add_loss(batch_loss)

    tf.summary.scalar("losses/" + name, batch_loss)

    self.target_cross_entropy_losses.append(losses)
    self.target_cross_entropy_loss_weights.append(weights)

  def build_decoders(self):
    """Builds the sentence decoders.

    Inputs:
      self.decode_pre_emb
      self.decode_post_emb
      self.decode_pre_ids
      self.decode_post_ids
      self.decode_pre_mask
      self.decode_post_mask
      self.thought_vectors

    Outputs:
      self.target_cross_entropy_losses
      self.target_cross_entropy_loss_weights
    """
    if self.mode != "encode":
      # Pre-sentence decoder.
      self._build_decoder("decoder_pre", self.decode_pre_emb,
                          self.decode_pre_ids, self.decode_pre_mask,
                          self.thought_vectors, False)

      # Post-sentence decoder. Logits weights are reused.
      self._build_decoder("decoder_post", self.decode_post_emb,
                          self.decode_post_ids, self.decode_post_mask,
                          self.thought_vectors, True)

  def build_loss(self):
    """Builds the loss Tensor.

    Outputs:
      self.total_loss
    """
    if self.mode != "encode":
      total_loss = tf.losses.get_total_loss()
      tf.summary.scalar("losses/total", total_loss)

      self.total_loss = total_loss

  def build_global_step(self):
    """Builds the global step Tensor.

    Outputs:
      self.global_step
    """
    self.global_step = tf.contrib.framework.create_global_step()

  def build(self):
    """Creates all ops for training, evaluation or encoding."""
    self.build_inputs()
    self.build_word_embeddings()
    self.build_encoder()
    self.build_decoders()
    self.build_loss()
    self.build_global_step()
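

# Typical training usage (an illustrative sketch; `model_config` stands for a
# configuration object exposing the hyperparameters read above, e.g.
# vocab_size, word_embedding_dim, encoder_dim, batch_size, input_file_pattern,
# uniform_init_scale -- the helper that constructs it is not defined in this
# file):
#
#   g = tf.Graph()
#   with g.as_default():
#     model = SkipThoughtsModel(model_config, mode="train")
#     model.build()
#     # model.total_loss and model.global_step can now be handed to an
#     # optimizer and a training loop (e.g. tf.contrib.slim.learning.train).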