# Copyright 2020 The Magenta Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""LSTM-based encoders and decoders for MusicVAE."""
import abc

from magenta.common import flatten_maybe_padded_sequences
from magenta.common import Nade
import magenta.contrib.rnn as contrib_rnn
import magenta.contrib.seq2seq as contrib_seq2seq
import magenta.contrib.training as contrib_training
from magenta.models.music_vae import base_model
from magenta.models.music_vae import lstm_utils
import numpy as np
import tensorflow.compat.v1 as tf
import tensorflow_probability as tfp


# ENCODERS


class LstmEncoder(base_model.BaseEncoder):
  """Unidirectional LSTM Encoder."""

  @property
  def output_depth(self):
    return self._cell.output_size

  def build(self, hparams, is_training=True, name_or_scope='encoder'):
    if hparams.use_cudnn:
      tf.logging.warning('cuDNN LSTM no longer supported. Using regular LSTM.')

    self._is_training = is_training
    self._name_or_scope = name_or_scope

    tf.logging.info('\nEncoder Cells (unidirectional):\n'
                    '  units: %s\n',
                    hparams.enc_rnn_size)
    self._cell = lstm_utils.rnn_cell(
        hparams.enc_rnn_size, hparams.dropout_keep_prob,
        hparams.residual_encoder, is_training)

  def encode(self, sequence, sequence_length):
    # Convert to time-major.
    sequence = tf.transpose(sequence, [1, 0, 2])
    outputs, _ = tf.nn.dynamic_rnn(
        self._cell, sequence, sequence_length, dtype=tf.float32,
        time_major=True, scope=self._name_or_scope)
    return outputs[-1]


class BidirectionalLstmEncoder(base_model.BaseEncoder):
  """Bidirectional LSTM Encoder."""

  @property
  def output_depth(self):
    return self._cells[0][-1].output_size + self._cells[1][-1].output_size

  def build(self, hparams, is_training=True, name_or_scope='encoder'):
    self._is_training = is_training
    self._name_or_scope = name_or_scope
    if hparams.use_cudnn:
      tf.logging.warning('cuDNN LSTM no longer supported. Using regular LSTM.')

    tf.logging.info('\nEncoder Cells (bidirectional):\n'
                    '  units: %s\n',
                    hparams.enc_rnn_size)

    self._cells = lstm_utils.build_bidirectional_lstm(
        layer_sizes=hparams.enc_rnn_size,
        dropout_keep_prob=hparams.dropout_keep_prob,
        residual=hparams.residual_encoder,
        is_training=is_training)

  def encode(self, sequence, sequence_length):
    cells_fw, cells_bw = self._cells

    _, states_fw, states_bw = contrib_rnn.stack_bidirectional_dynamic_rnn(
        cells_fw,
        cells_bw,
        sequence,
        sequence_length=sequence_length,
        time_major=False,
        dtype=tf.float32,
        scope=self._name_or_scope)
    # Note we access the outputs (h) from the states since the backward
    # ouputs are reversed to the input order in the returned outputs.
    last_h_fw = states_fw[-1][-1].h
    last_h_bw = states_bw[-1][-1].h

    return tf.concat([last_h_fw, last_h_bw], 1)


class HierarchicalLstmEncoder(base_model.BaseEncoder):
  """Hierarchical LSTM encoder wrapper.

  Input sequences will be split into segments based on the first value of
  `level_lengths` and encoded. At subsequent levels, the embeddings will be
  grouped based on `level_lengths` and encoded until a single embedding is
  produced.

  See the `encode` method for details on the expected arrangement the sequence
  tensors.

  Args:
    core_encoder_cls: A single BaseEncoder class to use for each level of the
      hierarchy.
    level_lengths: A list of the (maximum) lengths of the segments at each
      level of the hierarchy. The product must equal `hparams.max_seq_len`.
  """

  def __init__(self, core_encoder_cls, level_lengths):
    self._core_encoder_cls = core_encoder_cls
    self._level_lengths = level_lengths

  @property
  def output_depth(self):
    return self._hierarchical_encoders[-1][1].output_depth

  @property
  def level_lengths(self):
    return list(self._level_lengths)

  def level(self, l):
    """Returns the BaseEncoder at level `l`."""
    return self._hierarchical_encoders[l][1]

  def build(self, hparams, is_training=True):
    self._total_length = hparams.max_seq_len
    if self._total_length != np.prod(self._level_lengths):
      raise ValueError(
          'The product of the HierarchicalLstmEncoder level lengths (%d) must '
          'equal the padded input sequence length (%d).' % (
              np.prod(self._level_lengths), self._total_length))
    tf.logging.info('\nHierarchical Encoder:\n'
                    '  input length: %d\n'
                    '  level lengths: %s\n',
                    self._total_length,
                    self._level_lengths)
    self._hierarchical_encoders = []
    num_splits = int(np.prod(self._level_lengths))
    for i, l in enumerate(self._level_lengths):
      num_splits //= l
      tf.logging.info('Level %d splits: %d', i, num_splits)
      h_encoder = self._core_encoder_cls()
      h_encoder.build(
          hparams, is_training,
          name_or_scope=tf.VariableScope(
              tf.AUTO_REUSE, 'encoder/hierarchical_level_%d' % i))
      self._hierarchical_encoders.append((num_splits, h_encoder))

  def encode(self, sequence, sequence_length):
    """Hierarchically encodes the input sequences, returning a single embedding.

    Each sequence should be padded per-segment. For example, a sequence with
    three segments [1, 2, 3], [4, 5], [6, 7, 8 ,9] and a `max_seq_len` of 12
    should be input as `sequence = [1, 2, 3, 0, 4, 5, 0, 0, 6, 7, 8, 9]` with
    `sequence_length = [3, 2, 4]`.

    Args:
      sequence: A batch of (padded) sequences, sized
        `[batch_size, max_seq_len, input_depth]`.
      sequence_length: A batch of sequence lengths. May be sized
        `[batch_size, level_lengths[0]]` or `[batch_size]`. If the latter,
        each length must either equal `max_seq_len` or 0. In this case, the
        segment lengths are assumed to be constant and the total length will be
        evenly divided amongst the segments.

    Returns:
      embedding: A batch of embeddings, sized `[batch_size, N]`.
    """
    batch_size = int(sequence.shape[0])
    sequence_length = lstm_utils.maybe_split_sequence_lengths(
        sequence_length, np.prod(self._level_lengths[1:]),
        self._total_length)

    for level, (num_splits, h_encoder) in enumerate(
        self._hierarchical_encoders):
      split_seqs = tf.split(sequence, num_splits, axis=1)
      # In the first level, we use the input `sequence_length`. After that,
      # we use the full embedding sequences.
      if level:
        sequence_length = tf.fill(
            [batch_size, num_splits], split_seqs[0].shape[1])
      split_lengths = tf.unstack(sequence_length, axis=1)
      embeddings = [
          h_encoder.encode(s, l) for s, l in zip(split_seqs, split_lengths)]
      sequence = tf.stack(embeddings, axis=1)

    with tf.control_dependencies([tf.assert_equal(tf.shape(sequence)[1], 1)]):
      return sequence[:, 0]


# DECODERS


class BaseLstmDecoder(base_model.BaseDecoder):
  """Abstract LSTM Decoder class.

  Implementations must define the following abstract methods:
      -`_sample`
      -`_flat_reconstruction_loss`
  """

  def build(self, hparams, output_depth, is_training=True):
    if hparams.use_cudnn:
      tf.logging.warning('cuDNN LSTM no longer supported. Using regular LSTM.')

    self._is_training = is_training

    tf.logging.info('\nDecoder Cells:\n'
                    '  units: %s\n',
                    hparams.dec_rnn_size)

    self._sampling_probability = lstm_utils.get_sampling_probability(
        hparams, is_training)
    self._output_depth = output_depth
    self._output_layer = tf.layers.Dense(
        output_depth, name='output_projection')
    self._dec_cell = lstm_utils.rnn_cell(
        hparams.dec_rnn_size, hparams.dropout_keep_prob,
        hparams.residual_decoder, is_training)

  @property
  def state_size(self):
    return self._dec_cell.state_size

  @abc.abstractmethod
  def _sample(self, rnn_output, temperature):
    """Core sampling method for a single time step.

    Args:
      rnn_output: The output from a single timestep of the RNN, sized
          `[batch_size, rnn_output_size]`.
      temperature: A scalar float specifying a sampling temperature.
    Returns:
      A batch of samples from the model.
    """
    pass

  @abc.abstractmethod
  def _flat_reconstruction_loss(self, flat_x_target, flat_rnn_output):
    """Core loss calculation method for flattened outputs.

    Args:
      flat_x_target: The flattened ground truth vectors, sized
        `[sum(x_length), self._output_depth]`.
      flat_rnn_output: The flattened output from all timeputs of the RNN,
        sized `[sum(x_length), rnn_output_size]`.
    Returns:
      r_loss: The unreduced reconstruction losses, sized `[sum(x_length)]`.
      metric_map: A map of metric names to tuples, each of which contain the
        pair of (value_tensor, update_op) from a tf.metrics streaming metric.
    """
    pass

  def _decode(self, z, helper, input_shape, max_length=None):
    """Decodes the given batch of latent vectors vectors, which may be 0-length.

    Args:
      z: Batch of latent vectors, sized `[batch_size, z_size]`, where `z_size`
        may be 0 for unconditioned decoding.
      helper: A seq2seq.Helper to use.
      input_shape: The shape of each model input vector passed to the decoder.
      max_length: (Optional) The maximum iterations to decode.

    Returns:
      results: The LstmDecodeResults.
    """
    initial_state = lstm_utils.initial_cell_state_from_embedding(
        self._dec_cell, z, name='decoder/z_to_initial_state')

    decoder = lstm_utils.Seq2SeqLstmDecoder(
        self._dec_cell,
        helper,
        initial_state=initial_state,
        input_shape=input_shape,
        output_layer=self._output_layer)
    final_output, final_state, final_lengths = contrib_seq2seq.dynamic_decode(
        decoder,
        maximum_iterations=max_length,
        swap_memory=True,
        scope='decoder')
    results = lstm_utils.LstmDecodeResults(
        rnn_input=final_output.rnn_input[:, :, :self._output_depth],
        rnn_output=final_output.rnn_output,
        samples=final_output.sample_id,
        final_state=final_state,
        final_sequence_lengths=final_lengths)

    return results

  def reconstruction_loss(self, x_input, x_target, x_length, z=None,
                          c_input=None):
    """Reconstruction loss calculation.

    Args:
      x_input: Batch of decoder input sequences for teacher forcing, sized
        `[batch_size, max(x_length), output_depth]`.
      x_target: Batch of expected output sequences to compute loss against,
        sized `[batch_size, max(x_length), output_depth]`.
      x_length: Length of input/output sequences, sized `[batch_size]`.
      z: (Optional) Latent vectors. Required if model is conditional. Sized
        `[n, z_size]`.
      c_input: (Optional) Batch of control sequences, sized
          `[batch_size, max(x_length), control_depth]`. Required if conditioning
          on control sequences.

    Returns:
      r_loss: The reconstruction loss for each sequence in the batch.
      metric_map: Map from metric name to tf.metrics return values for logging.
      decode_results: The LstmDecodeResults.
    """
    batch_size = int(x_input.shape[0])

    has_z = z is not None
    z = tf.zeros([batch_size, 0]) if z is None else z
    repeated_z = tf.tile(
        tf.expand_dims(z, axis=1), [1, tf.shape(x_input)[1], 1])

    has_control = c_input is not None
    if c_input is None:
      c_input = tf.zeros([batch_size, tf.shape(x_input)[1], 0])

    sampling_probability_static = tf.get_static_value(
        self._sampling_probability)
    if sampling_probability_static == 0.0:
      # Use teacher forcing.
      x_input = tf.concat([x_input, repeated_z, c_input], axis=2)
      helper = contrib_seq2seq.TrainingHelper(x_input, x_length)
    else:
      # Use scheduled sampling.
      if has_z or has_control:
        auxiliary_inputs = tf.zeros([batch_size, tf.shape(x_input)[1], 0])
        if has_z:
          auxiliary_inputs = tf.concat([auxiliary_inputs, repeated_z], axis=2)
        if has_control:
          auxiliary_inputs = tf.concat([auxiliary_inputs, c_input], axis=2)
      else:
        auxiliary_inputs = None
      helper = contrib_seq2seq.ScheduledOutputTrainingHelper(
          inputs=x_input,
          sequence_length=x_length,
          auxiliary_inputs=auxiliary_inputs,
          sampling_probability=self._sampling_probability,
          next_inputs_fn=self._sample)

    decode_results = self._decode(
        z, helper=helper, input_shape=helper.inputs.shape[2:])
    flat_x_target = flatten_maybe_padded_sequences(x_target, x_length)
    flat_rnn_output = flatten_maybe_padded_sequences(
        decode_results.rnn_output, x_length)
    r_loss, metric_map = self._flat_reconstruction_loss(
        flat_x_target, flat_rnn_output)

    # Sum loss over sequences.
    cum_x_len = tf.concat([(0,), tf.cumsum(x_length)], axis=0)
    r_losses = []
    for i in range(batch_size):
      b, e = cum_x_len[i], cum_x_len[i + 1]
      r_losses.append(tf.reduce_sum(r_loss[b:e]))
    r_loss = tf.stack(r_losses)

    return r_loss, metric_map, decode_results

  def sample(self, n, max_length=None, z=None, c_input=None, temperature=1.0,
             start_inputs=None, end_fn=None):
    """Sample from decoder with an optional conditional latent vector `z`.

    Args:
      n: Scalar number of samples to return.
      max_length: (Optional) Scalar maximum sample length to return. Required if
        data representation does not include end tokens.
      z: (Optional) Latent vectors to sample from. Required if model is
        conditional. Sized `[n, z_size]`.
      c_input: (Optional) Control sequence, sized `[max_length, control_depth]`.
      temperature: (Optional) The softmax temperature to use when sampling, if
        applicable.
      start_inputs: (Optional) Initial inputs to use for batch.
        Sized `[n, output_depth]`.
      end_fn: (Optional) A callable that takes a batch of samples (sized
        `[n, output_depth]` and emits a `bool` vector
        shaped `[batch_size]` indicating whether each sample is an end token.
    Returns:
      samples: Sampled sequences. Sized `[n, max_length, output_depth]`.
      final_state: The final states of the decoder.
    Raises:
      ValueError: If `z` is provided and its first dimension does not equal `n`.
    """
    if z is not None and int(z.shape[0]) != n:
      raise ValueError(
          '`z` must have a first dimension that equals `n` when given. '
          'Got: %d vs %d' % (z.shape[0], n))

    # Use a dummy Z in unconditional case.
    z = tf.zeros((n, 0), tf.float32) if z is None else z

    if c_input is not None:
      # Tile control sequence across samples.
      c_input = tf.tile(tf.expand_dims(c_input, 1), [1, n, 1])

    # If not given, start with zeros.
    if start_inputs is None:
      start_inputs = tf.zeros([n, self._output_depth], dtype=tf.float32)
    # In the conditional case, also concatenate the Z.
    start_inputs = tf.concat([start_inputs, z], axis=-1)
    if c_input is not None:
      start_inputs = tf.concat([start_inputs, c_input[0]], axis=-1)
    initialize_fn = lambda: (tf.zeros([n], tf.bool), start_inputs)

    sample_fn = lambda time, outputs, state: self._sample(outputs, temperature)
    end_fn = end_fn or (lambda x: False)

    def next_inputs_fn(time, outputs, state, sample_ids):
      del outputs
      finished = end_fn(sample_ids)
      next_inputs = tf.concat([sample_ids, z], axis=-1)
      if c_input is not None:
        # We need to stop if we've run out of control input.
        finished = tf.cond(tf.less(time, tf.shape(c_input)[0] - 1),
                           lambda: finished,
                           lambda: True)
        next_inputs = tf.concat([
            next_inputs,
            tf.cond(tf.less(time, tf.shape(c_input)[0] - 1),
                    lambda: c_input[time + 1],
                    lambda: tf.zeros_like(c_input[0]))  # should be unused
        ], axis=-1)
      return (finished, next_inputs, state)

    sampler = contrib_seq2seq.CustomHelper(
        initialize_fn=initialize_fn, sample_fn=sample_fn,
        next_inputs_fn=next_inputs_fn, sample_ids_shape=[self._output_depth],
        sample_ids_dtype=tf.float32)

    decode_results = self._decode(
        z, helper=sampler, input_shape=start_inputs.shape[1:],
        max_length=max_length)

    return decode_results.samples, decode_results


class BidirectionalLstmControlPreprocessingDecoder(base_model.BaseDecoder):
  """Decoder that preprocesses control input with a bidirectional LSTM."""

  def __init__(self, core_decoder):
    super(BidirectionalLstmControlPreprocessingDecoder, self).__init__()
    self._core_decoder = core_decoder

  def build(self, hparams, output_depth, is_training=True):
    self._is_training = is_training

    if hparams.use_cudnn:
      tf.logging.warning('cuDNN LSTM no longer supported. Using regular LSTM.')

    tf.logging.info('\nControl Preprocessing Cells (bidirectional):\n'
                    '  units: %s\n',
                    hparams.control_preprocessing_rnn_size)

    self._control_preprocessing_cells = lstm_utils.build_bidirectional_lstm(
        layer_sizes=hparams.control_preprocessing_rnn_size,
        dropout_keep_prob=hparams.dropout_keep_prob,
        residual=hparams.residual_decoder,
        is_training=is_training)

    self._core_decoder.build(hparams, output_depth, is_training)

  def _preprocess_controls(self, c_input, length):
    cells_fw, cells_bw = self._control_preprocessing_cells

    outputs, _, _ = contrib_rnn.stack_bidirectional_dynamic_rnn(
        cells_fw,
        cells_bw,
        c_input,
        sequence_length=length,
        time_major=False,
        dtype=tf.float32,
        scope='control_preprocessing')

    return outputs

  def reconstruction_loss(self, x_input, x_target, x_length, z=None,
                          c_input=None):
    if c_input is None:
      raise ValueError('Must provide control input to preprocess.')
    preprocessed_c_input = self._preprocess_controls(c_input, x_length)
    return self._core_decoder.reconstruction_loss(
        x_input, x_target, x_length, z, preprocessed_c_input)

  def sample(self, n, max_length=None, z=None, c_input=None, **kwargs):
    if c_input is None:
      raise ValueError('Must provide control input to preprocess.')
    preprocessed_c_input = tf.squeeze(
        self._preprocess_controls(tf.expand_dims(c_input, axis=0),
                                  tf.reshape(max_length, [1])),
        axis=0)
    return self._core_decoder.sample(
        n, max_length, z, preprocessed_c_input, **kwargs)


class BooleanLstmDecoder(BaseLstmDecoder):
  """LSTM decoder with single Boolean output per time step."""

  def _flat_reconstruction_loss(self, flat_x_target, flat_rnn_output):
    flat_logits = flat_rnn_output
    flat_truth = tf.squeeze(flat_x_target, axis=1)
    flat_predictions = tf.squeeze(flat_logits >= 0, axis=1)
    r_loss = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=flat_x_target, logits=flat_logits)

    metric_map = {
        'metrics/accuracy':
            tf.metrics.accuracy(flat_truth, flat_predictions),
    }
    return r_loss, metric_map

  def _sample(self, rnn_output, temperature=1.0):
    sampler = tfp.distributions.Bernoulli(
        logits=rnn_output / temperature, dtype=tf.float32)
    return sampler.sample()


class CategoricalLstmDecoder(BaseLstmDecoder):
  """LSTM decoder with single categorical output."""

  def _flat_reconstruction_loss(self, flat_x_target, flat_rnn_output):
    flat_logits = flat_rnn_output
    flat_truth = tf.argmax(flat_x_target, axis=1)
    flat_predictions = tf.argmax(flat_logits, axis=1)
    r_loss = tf.nn.softmax_cross_entropy_with_logits(
        labels=flat_x_target, logits=flat_logits)

    metric_map = {
        'metrics/accuracy':
            tf.metrics.accuracy(flat_truth, flat_predictions),
        'metrics/mean_per_class_accuracy':
            tf.metrics.mean_per_class_accuracy(
                flat_truth, flat_predictions, int(flat_x_target.shape[-1])),
    }
    return r_loss, metric_map

  def _sample(self, rnn_output, temperature=1.0):
    sampler = tfp.distributions.OneHotCategorical(
        logits=rnn_output / temperature, dtype=tf.float32)
    return sampler.sample()

  def sample(self, n, max_length=None, z=None, c_input=None, temperature=None,
             start_inputs=None, end_token=None):
    """Overrides BaseLstmDecoder `sample` method to add optional end token.

    Args:
      n: Scalar number of samples to return.
      max_length: (Optional) Scalar maximum sample length to return. Required if
        data representation does not include end tokens.
      z: (Optional) Latent vectors to sample from. Required if model is
        conditional. Sized `[n, z_size]`.
      c_input: (Optional) Control sequence, sized `[max_length, control_depth]`.
      temperature: (Optional) The softmax temperature to use  Defaults to 1.0.
      start_inputs: (Optional) Initial inputs to use for batch.
        Sized `[n, output_depth]`.
      end_token: (Optional) Scalar token signaling the end of the sequence to
        use for early stopping.
    Returns:
      samples: Sampled sequences. Sized `[n, max_length, output_depth]`.
      final_state: The final states of the decoder.
    Raises:
      ValueError: If `z` is provided and its first dimension does not equal `n`.
    """
    if end_token is None:
      end_fn = None
    else:
      end_fn = lambda x: tf.equal(tf.argmax(x, axis=-1), end_token)
    return super(CategoricalLstmDecoder, self).sample(
        n, max_length, z, c_input, temperature, start_inputs, end_fn)


class MultiOutCategoricalLstmDecoder(CategoricalLstmDecoder):
  """LSTM decoder with multiple categorical outputs.

  The final sequence dimension is split before computing the loss or sampling,
  based on the `output_depths`. Reconstruction losses are summed across the
  split and samples are concatenated in the same order as the input.

  Args:
    output_depths: A list of output depths for the in the same order as the are
      concatenated in the final sequence dimension.
  """

  def __init__(self, output_depths):
    super(MultiOutCategoricalLstmDecoder, self).__init__()
    self._output_depths = output_depths

  def build(self, hparams, output_depth, is_training=True):
    if sum(self._output_depths) != output_depth:
      raise ValueError(
          'Decoder output depth does not match sum of sub-decoders: %s vs %d' %
          (self._output_depths, output_depth))
    super(MultiOutCategoricalLstmDecoder, self).build(
        hparams, output_depth, is_training)

  def _flat_reconstruction_loss(self, flat_x_target, flat_rnn_output):
    split_x_target = tf.split(flat_x_target, self._output_depths, axis=-1)
    split_rnn_output = tf.split(
        flat_rnn_output, self._output_depths, axis=-1)

    losses = []
    metric_map = {}
    for i in range(len(self._output_depths)):
      l, m = (
          super(MultiOutCategoricalLstmDecoder, self)._flat_reconstruction_loss(
              split_x_target[i], split_rnn_output[i]))
      losses.append(l)
      for k, v in m.items():
        metric_map['%s/output_%d' % (k, i)] = v

    return tf.reduce_sum(losses, axis=0), metric_map

  def _sample(self, rnn_output, temperature=1.0):
    split_logits = tf.split(rnn_output, self._output_depths, axis=-1)
    samples = []
    for logits, output_depth in zip(split_logits, self._output_depths):
      sampler = tfp.distributions.Categorical(
          logits=logits / temperature)
      sample_label = sampler.sample()
      samples.append(tf.one_hot(sample_label, output_depth, dtype=tf.float32))
    return tf.concat(samples, axis=-1)


class SplitMultiOutLstmDecoder(base_model.BaseDecoder):
  """Wrapper that splits multiple outputs to different LSTM decoders.

  The final sequence dimension is split and passed to the `core_decoders` based
  on the `output_depths`. `z` is passed directly to all core decoders without
  modification. Reconstruction losses are summed across the split and samples
  are concatenated in the same order as the input.

  Args:
    core_decoders: The BaseDecoder implementation class(es) to use at the
      output layer. Size and order must match `output_depths`.
    output_depths: A list of output depths for the core decoders in the same
      order as the are concatenated in the input. Size and order must match
      `core_decoders`.
  Raises:
    ValueError: If the size of `core_decoders` and `output_depths` are not
      equal.
  """

  def __init__(self, core_decoders, output_depths):
    if len(core_decoders) != len(output_depths):
      raise ValueError(
          'The number of `core_decoders` and `output_depths` provided to a '
          'SplitMultiOutLstmDecoder must be equal. Got: %d != %d' %
          (len(core_decoders), len(output_depths)))
    self._core_decoders = core_decoders
    self._output_depths = output_depths

  @property
  def state_size(self):
    return tf.nest.map_structure(
        lambda *x: sum(x), *(cd.state_size for cd in self._core_decoders))

  def build(self, hparams, output_depth, is_training=True):
    if sum(self._output_depths) != output_depth:
      raise ValueError(
          'Decoder output depth does not match sum of sub-decoders: %s vs %d' %
          (self._output_depths, output_depth))
    self.hparams = hparams
    self._is_training = is_training

    for i, (cd, od) in enumerate(zip(self._core_decoders, self._output_depths)):
      with tf.variable_scope('core_decoder_%d' % i):
        cd.build(hparams, od, is_training)

  def _merge_decode_results(self, decode_results):
    """Merge in the output dimension."""
    output_axis = -1
    assert decode_results
    zipped_results = lstm_utils.LstmDecodeResults(*list(zip(*decode_results)))
    with tf.control_dependencies([
        tf.assert_equal(
            zipped_results.final_sequence_lengths, self.hparams.max_seq_len,
            message='Variable length not supported by '
                    'MultiOutCategoricalLstmDecoder.')]):
      if zipped_results.final_state[0] is None:
        final_state = None
      else:
        final_state = tf.nest.map_structure(
            lambda x: tf.concat(x, axis=output_axis),
            zipped_results.final_state)

      return lstm_utils.LstmDecodeResults(
          rnn_output=tf.concat(zipped_results.rnn_output, axis=output_axis),
          rnn_input=tf.concat(zipped_results.rnn_input, axis=output_axis),
          samples=tf.concat(zipped_results.samples, axis=output_axis),
          final_state=final_state,
          final_sequence_lengths=zipped_results.final_sequence_lengths[0])

  def reconstruction_loss(self, x_input, x_target, x_length, z=None,
                          c_input=None):
    # Split output for each core model.
    split_x_input = tf.split(x_input, self._output_depths, axis=-1)
    split_x_target = tf.split(x_target, self._output_depths, axis=-1)
    loss_outputs = []

    # Compute reconstruction losses for the split output.
    for i, cd in enumerate(self._core_decoders):
      with tf.variable_scope('core_decoder_%d' % i):
        # TODO(adarob): Sample initial inputs when using scheduled sampling.
        loss_outputs.append(
            cd.reconstruction_loss(
                split_x_input[i], split_x_target[i], x_length, z, c_input))

    r_losses, metric_maps, decode_results = list(zip(*loss_outputs))

    # Merge the metric maps by passing through renamed values and taking the
    # mean across the splits.
    merged_metric_map = {}
    for metric_name in metric_maps[0]:
      metric_values = []
      for i, m in enumerate(metric_maps):
        merged_metric_map['%s/output_%d' % (metric_name, i)] = m[metric_name]
        metric_values.append(m[metric_name][0])
      merged_metric_map[metric_name] = (
          tf.reduce_mean(metric_values), tf.no_op())

    return (tf.reduce_sum(r_losses, axis=0),
            merged_metric_map,
            self._merge_decode_results(decode_results))

  def sample(self, n, max_length=None, z=None, c_input=None, temperature=1.0,
             start_inputs=None, **core_sampler_kwargs):
    if z is not None and int(z.shape[0]) != n:
      raise ValueError(
          '`z` must have a first dimension that equals `n` when given. '
          'Got: %d vs %d' % (z.shape[0], n))

    if max_length is None:
      # TODO(adarob): Support variable length outputs.
      raise ValueError(
          'SplitMultiOutLstmDecoder requires `max_length` be provided during '
          'sampling.')

    if start_inputs is None:
      split_start_inputs = [None] * len(self._output_depths)
    else:
      split_start_inputs = tf.split(start_inputs, self._output_depths, axis=-1)

    sample_results = []
    for i, cd in enumerate(self._core_decoders):
      with tf.variable_scope('core_decoder_%d' % i):
        sample_results.append(cd.sample(
            n,
            max_length,
            z=z,
            c_input=c_input,
            temperature=temperature,
            start_inputs=split_start_inputs[i],
            **core_sampler_kwargs))

    sample_ids, decode_results = list(zip(*sample_results))
    return (tf.concat(sample_ids, axis=-1),
            self._merge_decode_results(decode_results))


class MultiLabelRnnNadeDecoder(BaseLstmDecoder):
  """LSTM decoder with multi-label output provided by a NADE."""

  def build(self, hparams, output_depth, is_training=False):
    self._nade = Nade(
        output_depth, hparams.nade_num_hidden, name='decoder/nade')
    super(MultiLabelRnnNadeDecoder, self).build(
        hparams, output_depth, is_training)
    # Overwrite output layer for NADE parameterization.
    self._output_layer = tf.layers.Dense(
        self._nade.num_hidden + output_depth, name='output_projection')

  def _flat_reconstruction_loss(self, flat_x_target, flat_rnn_output):
    b_enc, b_dec = tf.split(
        flat_rnn_output,
        [self._nade.num_hidden, self._output_depth], axis=1)
    ll, cond_probs = self._nade.log_prob(
        flat_x_target, b_enc=b_enc, b_dec=b_dec)
    r_loss = -ll
    flat_truth = tf.cast(flat_x_target, tf.bool)
    flat_predictions = tf.greater_equal(cond_probs, 0.5)

    metric_map = {
        'metrics/accuracy':
            tf.metrics.mean(
                tf.reduce_all(tf.equal(flat_truth, flat_predictions), axis=-1)),
        'metrics/recall':
            tf.metrics.recall(flat_truth, flat_predictions),
        'metrics/precision':
            tf.metrics.precision(flat_truth, flat_predictions),
    }

    return r_loss, metric_map

  def _sample(self, rnn_output, temperature=None):
    """Sample from NADE, returning the argmax if no temperature is provided."""
    b_enc, b_dec = tf.split(
        rnn_output, [self._nade.num_hidden, self._output_depth], axis=1)
    sample, _ = self._nade.sample(
        b_enc=b_enc, b_dec=b_dec, temperature=temperature)
    return sample


class HierarchicalLstmDecoder(base_model.BaseDecoder):
  """Hierarchical LSTM decoder."""

  def __init__(self,
               core_decoder,
               level_lengths,
               disable_autoregression=False,
               hierarchical_encoder=None):
    """Initializer for HierarchicalLstmDecoder.

    Hierarchicaly decodes a sequence across time.

    Each sequence is padded per-segment. For example, a sequence with
    three segments [1, 2, 3], [4, 5], [6, 7, 8 ,9] and a `max_seq_len` of 12
    is represented as `sequence = [1, 2, 3, 0, 4, 5, 0, 0, 6, 7, 8, 9]` with
    `sequence_length = [3, 2, 4]`.

    `z` initializes the first level LSTM to produce embeddings used to
    initialize the states of LSTMs at subsequent levels. The lowest-level
    embeddings are then passed to the given `core_decoder` to generate the
    final outputs.

    This decoder has 3 modes for what is used as the inputs to the LSTMs
    (excluding those in the core decoder):
      Autoregressive: (default) The inputs to the level `l` decoder are the
        final states of the level `l+1` decoder.
      Non-autoregressive: (`disable_autoregression=True`) The inputs to the
        hierarchical decoders are 0's.
      Re-encoder: (`hierarchical_encoder` provided) The inputs to the level `l`
        decoder are re-encoded outputs of level `l+1`, using the given encoder's
        matching level.

    Args:
      core_decoder: The BaseDecoder implementation to use at the output level.
      level_lengths: A list of the number of outputs of each level of the
        hierarchy. The final level is the (padded) maximum length. The product
        of the lengths must equal `hparams.max_seq_len`.
      disable_autoregression: Whether to disable the autoregression within the
        hierarchy. May also be a collection of levels on which to disable.
      hierarchical_encoder: (Optional) A HierarchicalLstmEncoder instance to use
        for re-encoding the decoder outputs at each level for use as inputs to
        the next level up in the hierarchy, instead of the final decoder state.
        The encoder level output lengths (except for the final single-output
        level) should be the reverse of `level_output_lengths`.

    Raises:
      ValueError: If `hierarchical_encoder` is given but has incompatible level
        lengths.
    """
    # Check for explicit True/False since lists may be given.
    if disable_autoregression is True:  # pylint:disable=g-bool-id-comparison
      disable_autoregression = list(range(len(level_lengths)))
    elif disable_autoregression is False:  # pylint:disable=g-bool-id-comparison
      disable_autoregression = []
    if (hierarchical_encoder and
        (tuple(hierarchical_encoder.level_lengths[-1::-1]) !=
         tuple(level_lengths))):
      raise ValueError(
          'Incompatible hierarchical encoder level output lengths: ',
          hierarchical_encoder.level_lengths, level_lengths)

    self._core_decoder = core_decoder
    self._level_lengths = level_lengths
    self._disable_autoregression = disable_autoregression
    self._hierarchical_encoder = hierarchical_encoder

  def build(self, hparams, output_depth, is_training=True):
    self.hparams = hparams
    self._output_depth = output_depth
    self._total_length = hparams.max_seq_len
    if self._total_length != np.prod(self._level_lengths):
      raise ValueError(
          'The product of the HierarchicalLstmDecoder level lengths (%d) must '
          'equal the padded input sequence length (%d).' % (
              np.prod(self._level_lengths), self._total_length))
    tf.logging.info('\nHierarchical Decoder:\n'
                    '  input length: %d\n'
                    '  level output lengths: %s\n',
                    self._total_length,
                    self._level_lengths)

    self._hier_cells = [
        lstm_utils.rnn_cell(
            hparams.dec_rnn_size,
            dropout_keep_prob=hparams.dropout_keep_prob,
            residual=hparams.residual_decoder)
        # Subtract 1 for the core decoder level
        for _ in range(len(self._level_lengths) - 1)]

    with tf.variable_scope('core_decoder', reuse=tf.AUTO_REUSE):
      self._core_decoder.build(hparams, output_depth, is_training)

  @property
  def state_size(self):
    return self._core_decoder.state_size

  def _merge_decode_results(self, decode_results):
    """Merge across time."""
    assert decode_results
    time_axis = 1
    zipped_results = lstm_utils.LstmDecodeResults(*list(zip(*decode_results)))
    if zipped_results.rnn_output[0] is None:
      rnn_output = None
      rnn_input = None
    else:
      rnn_output = tf.concat(zipped_results.rnn_output, axis=time_axis)
      rnn_input = tf.concat(zipped_results.rnn_input, axis=time_axis)
    return lstm_utils.LstmDecodeResults(
        rnn_output=rnn_output,
        rnn_input=rnn_input,
        samples=tf.concat(zipped_results.samples, axis=time_axis),
        final_state=zipped_results.final_state[-1],
        final_sequence_lengths=tf.stack(
            zipped_results.final_sequence_lengths, axis=time_axis))

  def _hierarchical_decode(self, z, base_decode_fn):
    """Depth first decoding from `z`, passing final embeddings to base fn."""
    batch_size = z.shape[0]
    # Subtract 1 for the core decoder level.
    num_levels = len(self._level_lengths) - 1

    hparams = self.hparams
    batch_size = hparams.batch_size

    def recursive_decode(initial_input, path=None):
      """Recursive hierarchical decode function."""
      path = path or []
      level = len(path)

      if level == num_levels:
        with tf.variable_scope('core_decoder', reuse=tf.AUTO_REUSE):
          return base_decode_fn(initial_input, path)

      scope = tf.VariableScope(
          tf.AUTO_REUSE, 'decoder/hierarchical_level_%d' % level)
      num_steps = self._level_lengths[level]
      with tf.variable_scope(scope):
        state = lstm_utils.initial_cell_state_from_embedding(
            self._hier_cells[level], initial_input, name='initial_state')
      if level not in self._disable_autoregression:
        # The initial input should be the same size as the tensors returned by
        # next level.
        if self._hierarchical_encoder:
          input_size = self._hierarchical_encoder.level(0).output_depth
        elif level == num_levels - 1:
          input_size = sum(tf.nest.flatten(self._core_decoder.state_size))
        else:
          input_size = sum(
              tf.nest.flatten(self._hier_cells[level + 1].state_size))
        next_input = tf.zeros([batch_size, input_size])
      lower_level_embeddings = []
      for i in range(num_steps):
        if level in self._disable_autoregression:
          next_input = tf.zeros([batch_size, 1])
        else:
          next_input = tf.concat([next_input, initial_input], axis=1)
        with tf.variable_scope(scope):
          output, state = self._hier_cells[level](next_input, state, scope)
        next_input = recursive_decode(output, path + [i])
        lower_level_embeddings.append(next_input)
      if self._hierarchical_encoder:
        # Return the encoding of the outputs using the appropriate level of the
        # hierarchical encoder.
        enc_level = num_levels - level
        return self._hierarchical_encoder.level(enc_level).encode(
            sequence=tf.stack(lower_level_embeddings, axis=1),
            sequence_length=tf.fill([batch_size], num_steps))
      else:
        # Return the final state.
        return tf.concat(tf.nest.flatten(state), axis=-1)

    return recursive_decode(z)

  def _reshape_to_hierarchy(self, t):
    """Reshapes `t` so that its initial dimensions match the hierarchy."""
    # Exclude the final, core decoder length.
    level_lengths = self._level_lengths[:-1]
    t_shape = t.shape.as_list()
    t_rank = len(t_shape)
    batch_size = t_shape[0]
    hier_shape = [batch_size] + level_lengths
    if t_rank == 3:
      hier_shape += [-1] + t_shape[2:]
    elif t_rank != 2:
      # We only expect rank-2 for lengths and rank-3 for sequences.
      raise ValueError('Unexpected shape for tensor: %s' % t)
    hier_t = tf.reshape(t, hier_shape)
    # Move the batch dimension to after the hierarchical dimensions.
    num_levels = len(level_lengths)
    perm = list(range(len(hier_shape)))
    perm.insert(num_levels, perm.pop(0))
    return tf.transpose(hier_t, perm)

  def reconstruction_loss(self, x_input, x_target, x_length, z=None,
                          c_input=None):
    """Reconstruction loss calculation.

    Args:
      x_input: Batch of decoder input sequences of concatenated segmeents for
        teacher forcing, sized `[batch_size, max_seq_len, output_depth]`.
      x_target: Batch of expected output sequences to compute loss against,
        sized `[batch_size, max_seq_len, output_depth]`.
      x_length: Length of input/output sequences, sized
        `[batch_size, level_lengths[0]]` or `[batch_size]`. If the latter,
        each length must either equal `max_seq_len` or 0. In this case, the
        segment lengths are assumed to be constant and the total length will be
        evenly divided amongst the segments.
      z: (Optional) Latent vectors. Required if model is conditional. Sized
        `[n, z_size]`.
      c_input: (Optional) Batch of control sequences, sized
        `[batch_size, max_seq_len, control_depth]`. Required if conditioning on
        control sequences.

    Returns:
      r_loss: The reconstruction loss for each sequence in the batch.
      metric_map: Map from metric name to tf.metrics return values for logging.
      decode_results: The LstmDecodeResults.

    Raises:
      ValueError: If `c_input` is provided in re-encoder mode.
    """
    if self._hierarchical_encoder and c_input is not None:
      raise ValueError(
          'Re-encoder mode unsupported when conditioning on controls.')

    batch_size = int(x_input.shape[0])

    x_length = lstm_utils.maybe_split_sequence_lengths(
        x_length, np.prod(self._level_lengths[:-1]), self._total_length)

    hier_input = self._reshape_to_hierarchy(x_input)
    hier_target = self._reshape_to_hierarchy(x_target)
    hier_length = self._reshape_to_hierarchy(x_length)
    hier_control = (
        self._reshape_to_hierarchy(c_input) if c_input is not None else None)

    loss_outputs = []

    def base_train_fn(embedding, hier_index):
      """Base function for training hierarchical decoder."""
      split_size = self._level_lengths[-1]
      split_input = hier_input[hier_index]
      split_target = hier_target[hier_index]
      split_length = hier_length[hier_index]
      split_control = (
          hier_control[hier_index] if hier_control is not None else None)

      res = self._core_decoder.reconstruction_loss(
          split_input, split_target, split_length, embedding, split_control)
      loss_outputs.append(res)
      decode_results = res[-1]

      if self._hierarchical_encoder:
        # Get the approximate "sample" from the model.
        # Start with the inputs the RNN saw (excluding the start token).
        samples = decode_results.rnn_input[:, 1:]
        # Pad to be the max length.
        samples = tf.pad(
            samples,
            [(0, 0), (0, split_size - tf.shape(samples)[1]), (0, 0)])
        samples.set_shape([batch_size, split_size, self._output_depth])
        # Set the final value based on the target, since the scheduled sampling
        # helper does not sample the final value.
        samples = lstm_utils.set_final(
            samples,
            split_length,
            lstm_utils.get_final(split_target, split_length, time_major=False),
            time_major=False)
        # Return the re-encoded sample.
        return self._hierarchical_encoder.level(0).encode(
            sequence=samples,
            sequence_length=split_length)
      elif self._disable_autoregression:
        return None
      else:
        return tf.concat(tf.nest.flatten(decode_results.final_state), axis=-1)

    z = tf.zeros([batch_size, 0]) if z is None else z
    self._hierarchical_decode(z, base_train_fn)

    # Accumulate the split sequence losses.
    r_losses, metric_maps, decode_results = list(zip(*loss_outputs))

    # Merge the metric maps by passing through renamed values and taking the
    # mean across the splits.
    merged_metric_map = {}
    for metric_name in metric_maps[0]:
      metric_values = []
      for i, m in enumerate(metric_maps):
        merged_metric_map['segment/%03d/%s' % (i, metric_name)] = m[metric_name]
        metric_values.append(m[metric_name][0])
      merged_metric_map[metric_name] = (
          tf.reduce_mean(metric_values), tf.no_op())

    return (tf.reduce_sum(r_losses, axis=0),
            merged_metric_map,
            self._merge_decode_results(decode_results))

  def sample(self, n, max_length=None, z=None, c_input=None,
             **core_sampler_kwargs):
    """Sample from decoder with an optional conditional latent vector `z`.

    Args:
      n: Scalar number of samples to return.
      max_length: (Optional) maximum total length of samples. If given, must
        match `hparams.max_seq_len`.
      z: (Optional) Latent vectors to sample from. Required if model is
        conditional. Sized `[n, z_size]`.
      c_input: (Optional) Control sequence, sized `[max_length, control_depth]`.
      **core_sampler_kwargs: (Optional) Additional keyword arguments to pass to
        core sampler.
    Returns:
      samples: Sampled sequences with concenated, possibly padded segments.
         Sized `[n, max_length, output_depth]`.
      decoder_results: The merged LstmDecodeResults from sampling.
    Raises:
      ValueError: If `z` is provided and its first dimension does not equal `n`,
        or if `c_input` is provided in re-encoder mode.
    """
    if z is not None and int(z.shape[0]) != n:
      raise ValueError(
          '`z` must have a first dimension that equals `n` when given. '
          'Got: %d vs %d' % (z.shape[0], n))
    z = tf.zeros([n, 0]) if z is None else z

    if self._hierarchical_encoder and c_input is not None:
      raise ValueError(
          'Re-encoder mode unsupported when conditioning on controls.')

    if max_length is not None:
      with tf.control_dependencies([
          tf.assert_equal(
              max_length, self._total_length,
              message='`max_length` must equal `hparams.max_seq_len` if given.')
      ]):
        max_length = tf.identity(max_length)

    if c_input is not None:
      # Reshape control sequence to hierarchy.
      c_input = tf.squeeze(
          self._reshape_to_hierarchy(tf.expand_dims(c_input, 0)),
          axis=len(self._level_lengths) - 1)

    core_max_length = self._level_lengths[-1]
    all_samples = []
    all_decode_results = []

    def base_sample_fn(embedding, hier_index):
      """Base function for sampling hierarchical decoder."""
      samples, decode_results = self._core_decoder.sample(
          n,
          max_length=core_max_length,
          z=embedding,
          c_input=c_input[hier_index] if c_input is not None else None,
          start_inputs=all_samples[-1][:, -1] if all_samples else None,
          **core_sampler_kwargs)
      all_samples.append(samples)
      all_decode_results.append(decode_results)
      if self._hierarchical_encoder:
        return self._hierarchical_encoder.level(0).encode(
            samples,
            decode_results.final_sequence_lengths)
      else:
        return tf.concat(tf.nest.flatten(decode_results.final_state), axis=-1)

    # Populate `all_sample_ids`.
    self._hierarchical_decode(z, base_sample_fn)

    all_samples = tf.concat(
        [tf.pad(s, [(0, 0), (0, core_max_length - tf.shape(s)[1]), (0, 0)])
         for s in all_samples],
        axis=1)
    return all_samples, self._merge_decode_results(all_decode_results)


def get_default_hparams():
  """Returns copy of default HParams for LSTM models."""
  hparams_map = base_model.get_default_hparams().values()
  hparams_map.update({
      'conditional': True,
      'dec_rnn_size': [512],  # Decoder RNN: number of units per layer.
      'enc_rnn_size': [256],  # Encoder RNN: number of units per layer per dir.
      'dropout_keep_prob': 1.0,  # Probability all dropout keep.
      'sampling_schedule': 'constant',  # constant, exponential, inverse_sigmoid
      'sampling_rate': 0.0,  # Interpretation is based on `sampling_schedule`.
      'use_cudnn': False,  # DEPRECATED
      'residual_encoder': False,  # Use residual connections in encoder.
      'residual_decoder': False,  # Use residual connections in decoder.
      'control_preprocessing_rnn_size': [256],  # Decoder control preprocessing.
  })
  return contrib_training.HParams(**hparams_map)


class GrooveLstmDecoder(BaseLstmDecoder):
  """Groove LSTM decoder with MSE loss for continuous values.

  At each timestep, this decoder outputs a vector of length (N_INSTRUMENTS*3).
  The default number of drum instruments is 9, with drum categories defined in
  drums_encoder_decoder.py

  For each instrument, the model outputs a triple of (on/off, velocity, offset),
  with a binary representation for on/off, continuous values between 0 and 1
  for velocity, and continuous values between -0.5 and 0.5 for offset.
  """

  def _activate_outputs(self, flat_rnn_output):
    output_hits, output_velocities, output_offsets = tf.split(
        flat_rnn_output, 3, axis=1)

    output_hits = tf.nn.sigmoid(output_hits)
    output_velocities = tf.nn.sigmoid(output_velocities)
    output_offsets = tf.nn.tanh(output_offsets)

    return output_hits, output_velocities, output_offsets

  def _flat_reconstruction_loss(self, flat_x_target, flat_rnn_output):
    # flat_x_target is by default shape (1,27), [on/offs... vels...offsets...]
    # split into 3 equal length vectors
    target_hits, target_velocities, target_offsets = tf.split(
        flat_x_target, 3, axis=1)

    output_hits, output_velocities, output_offsets = self._activate_outputs(
        flat_rnn_output)

    hits_loss = tf.reduce_sum(tf.losses.log_loss(
        labels=target_hits, predictions=output_hits,
        reduction=tf.losses.Reduction.NONE), axis=1)

    velocities_loss = tf.reduce_sum(tf.losses.mean_squared_error(
        target_velocities, output_velocities,
        reduction=tf.losses.Reduction.NONE), axis=1)

    offsets_loss = tf.reduce_sum(tf.losses.mean_squared_error(
        target_offsets, output_offsets,
        reduction=tf.losses.Reduction.NONE), axis=1)

    loss = hits_loss + velocities_loss + offsets_loss

    metric_map = {
        'metrics/hits_loss':
            tf.metrics.mean(hits_loss),
        'metrics/velocities_loss':
            tf.metrics.mean(velocities_loss),
        'metrics/offsets_loss':
            tf.metrics.mean(offsets_loss)
    }

    return loss, metric_map

  def _sample(self, rnn_output, temperature=1.0):
    output_hits, output_velocities, output_offsets = tf.split(
        rnn_output, 3, axis=1)

    output_velocities = tf.nn.sigmoid(output_velocities)
    output_offsets = tf.nn.tanh(output_offsets)

    hits_sampler = tfp.distributions.Bernoulli(
        logits=output_hits / temperature, dtype=tf.float32)

    output_hits = hits_sampler.sample()
    return tf.concat([output_hits, output_velocities, output_offsets], axis=1)