"""Neural Shuffle-Exchange Network.

Implementation of
"Neural Shuffle-Exchange Networks - Sequence Processing in O(n log n) Time"
paper by K.Freivalds, E.Ozolins, A.Sostaks.

Paper: https://papers.nips.cc/paper/

Original code: https://github.com/LUMII-Syslab/shuffle-exchange
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
from tensor2tensor.layers import common_hparams
from tensor2tensor.utils import registry
from tensor2tensor.utils import t2t_model
import tensorflow.compat.v1 as tf

def ror(x, n, p=1):
  """Bitwise right rotation.

  Args:
    x: Input tensor
    n: Bit count to represent x
    p: Bit positions to rotate by

  Returns:
    tf.Tensor: x rotated right by p positions within n bits
  """
  # Everything above the p lowest bits moves down by p positions.
  a = tf.bitwise.right_shift(x, p)
  # Mask for the p lowest bits, which wrap around to the top of the word.
  b = tf.bitwise.left_shift(1, p) - 1
  c = tf.bitwise.bitwise_and(x, b)
  d = tf.bitwise.left_shift(c, n - p)

  # When x fits in n bits, a and d share no set bits, so the sum acts as OR.
  return a + d

def rol(x, n, p=1):
  """Bitwise left rotation.

  Args:
    x: Input tensor
    n: Bit count to represent x
    p: Bit positions to rotate by

  Returns:
    tf.Tensor: x rotated left by p positions within n bits
  """
  a = tf.bitwise.left_shift(x, p)
  # Mask with the n lowest bits set; drops whatever was shifted past bit n.
  b = tf.bitwise.left_shift(1, n) - 1
  c = tf.bitwise.bitwise_and(a, b)
  # The p highest bits of x wrap around to the bottom of the word.
  d = tf.bitwise.right_shift(x, n - p)

  return tf.bitwise.bitwise_or(c, d)

def shuffle_layer(inputs, shuffle_fn=rol):
  """Shuffles the elements according to bitwise left or right rotation.

  Args:
    inputs: Tensor input from previous layer
    shuffle_fn: Rotation function rol or ror

  Returns:
    tf.Tensor: Inputs gathered along axis 1 at the indices produced by
      applying shuffle_fn to every position index
  """
  length = tf.shape(inputs)[1]
  # Bits needed to represent the largest index, length - 1.
  n_bits = tf.log(tf.cast(length - 1, tf.float32)) / tf.log(2.0)
  n_bits = tf.cast(n_bits, tf.int32) + 1

  indices = tf.range(0, length)
  rev_indices = shuffle_fn(indices, n_bits)
  return tf.gather(inputs, rev_indices, axis=1)

def reverse_shuffle_layer(inputs):
  """Reverse shuffle of inputs.

  Used in the second half of Benes block.

  Args:
    inputs: Inputs that should be shuffled

  Returns:
    tf.Tensor: Inputs shuffled according to bitwise right rotation
  """
  return shuffle_layer(inputs, ror)

def conv_linear_map(inputs, nin, nout, bias_start, prefix):
  """Convolutional linear map.

  Maps 3D tensor by last dimension.

  Args:
    inputs: 3D tensor whose last dimension holds nin feature maps
    nin: Input feature map count
    nout: Output feature map count
    bias_start: Bias start value
    prefix: Name prefix

  Returns:
    tf.Tensor: Inputs with applied convolution; the first two dimensions are
      kept and the last one has size nout
  """
  with tf.variable_scope(prefix):
    inp_shape = tf.shape(inputs)

    initializer = tf.variance_scaling_initializer(
        scale=1.0, mode="fan_avg", distribution="uniform")
    kernel = tf.get_variable("CvK", [nin, nout], initializer=initializer)
    bias_term = tf.get_variable(
        "CvB", [nout], initializer=tf.constant_initializer(0.0))

    # Collapse the two leading dimensions so the map is a single matmul.
    mul_shape = [inp_shape[0] * inp_shape[1], nin]
    res = tf.matmul(tf.reshape(inputs, mul_shape), kernel)
    res = tf.reshape(res, [inp_shape[0], inp_shape[1], nout])
    # bias_start is a fixed offset added on top of the trainable bias, which
    # itself starts at zero.
    return res + bias_start + bias_term

# pylint: disable=useless-object-inheritance
class SwitchLayer(object):
  """Switch layer of Neural Shuffle-Exchange network."""

  def __init__(self, prefix, dropout, mode):
    """Initialize switch layer.

    Args:
      prefix: Name prefix for switch layer
      dropout: Dropout rate
      mode: Training mode
    """
    self.prefix = prefix
    self.dropout = dropout
    self.mode = mode
    # The fields below are filled in by __call__ from the input tensor.
    self.batch_size = None
    self.length = None
    self.num_units = None
    self.n_bits = None

  def linear_map(self, inputs, suffix, bias_start, in_units, out_units):
    """2 input to 2 output linear map.

    Args:
      inputs: Input tensor
      suffix: Linear map name suffix
      bias_start: Bias start value
      in_units: Size of input tensor feature map count
      out_units: Size of output tensor feature map count

    Returns:
      tf.Tensor: Convolution applied to input tensor, reshaped to
        [batch_size, length, out_units]
    """
    # Put each pair of adjacent positions side by side on the feature axis.
    in_shape = [self.batch_size, self.length // 2, in_units * 2]
    inputs = tf.reshape(inputs, in_shape)
    res = conv_linear_map(inputs, in_units * 2, out_units * 2, bias_start,
                          self.prefix + "/" + suffix)
    return tf.reshape(res, [self.batch_size, self.length, out_units])

  def gated_linear_map(self, inputs, suffix, bias_start_reset, in_units,
                       out_units):
    """Linear mapping with two reset gates.

    Args:
      inputs: Input tensor
      suffix: Linear map name suffix
      bias_start_reset: Bias start value for reset gate
      in_units: Size of input tensor feature map count
      out_units: Size of output tensor feature map count

    Returns:
      tf.Tensor: tanh of the gated convolution applied to input tensor,
        reshaped to [batch_size, length, out_units]
    """

    def reset_gate(name):
      # Reads `inputs` when called, i.e. after the pairwise reshape below.
      prefix = self.prefix + name + suffix
      reset = conv_linear_map(inputs, in_units * 2, in_units * 2,
                              bias_start_reset, prefix)
      return tf.nn.sigmoid(reset)

    # Put each pair of adjacent positions side by side on the feature axis.
    in_shape = [self.batch_size, self.length // 2, in_units * 2]
    inputs = tf.reshape(inputs, in_shape)

    reset1 = reset_gate("/reset1/")
    reset2 = reset_gate("/reset2/")
    res1 = conv_linear_map(inputs * reset1, in_units * 2, out_units, 0.0,
                           self.prefix + "/cand1/" + suffix)
    res2 = conv_linear_map(inputs * reset2, in_units * 2, out_units, 0.0,
                           self.prefix + "/cand2/" + suffix)

    # One candidate per element of the pair; the reshape splits the pair back
    # into two consecutive positions.
    res = tf.concat([res1, res2], axis=2)
    res = tf.reshape(res, [self.batch_size, self.length, out_units])
    return tf.nn.tanh(res)

  def __call__(self, inputs, residual_inputs):
    """Apply SwitchLayer to inputs.

    Args:
      inputs: Input tensor
      residual_inputs: Residual connections from previous block

    Returns:
      tf.Tensor: New candidate value
    """
    input_shape = tf.shape(inputs)
    self.batch_size = input_shape[0]
    self.length = input_shape[1]
    self.num_units = inputs.shape.as_list()[2]

    # Bits needed to represent the largest index, length - 1 (kept as float).
    self.n_bits = tf.log(tf.cast(self.length - 1, tf.float32)) / tf.log(2.0)
    self.n_bits = tf.floor(self.n_bits) + 1

    initializer = tf.constant_initializer(0.5)
    residual_scale = tf.get_variable(
        self.prefix + "/residual_scale", [self.num_units],
        initializer=initializer)

    shuffled_input = self.swap_halves(inputs)
    mem_all = inputs + residual_inputs * residual_scale

    # calculate the new value
    candidate = self.gated_linear_map(mem_all, "c", 0.5, self.num_units,
                                      self.num_units)
    gate = tf.nn.sigmoid(
        self.linear_map(mem_all, "g", 0.5, self.num_units, self.num_units))

    candidate = gate * shuffled_input + (1 - gate) * candidate

    if self.dropout > 0:
      # The dropout rate is divided by the bit count computed above.
      candidate = tf.nn.dropout(candidate, rate=self.dropout / self.n_bits)
    if self.dropout != 0.0 and self.mode == tf.estimator.ModeKeys.TRAIN:
      # Multiplicative noise centred on 1, applied in training mode only.
      noise = tf.random_normal(tf.shape(candidate), mean=1.0, stddev=0.001)
      candidate = candidate * noise

    return candidate

  def swap_halves(self, inputs):
    """Split inputs in half and then shuffle them as described in paper.

    Args:
      inputs: ShuffleLayer inputs

    Returns:
      tf.Tensor: Inputs with swapped halves
    """
    # XOR with 1 flips the lowest index bit, exchanging positions 2k and 2k+1.
    x = tf.range(0, self.length)
    xor_indices = tf.bitwise.bitwise_xor(x, 1)
    # Only the first half of the feature maps is exchanged between neighbours;
    # the second half stays in place.
    input_xor = tf.gather(
        inputs[:, :, :self.num_units // 2], xor_indices, axis=1)
    return tf.concat([input_xor, inputs[:, :, self.num_units // 2:]], axis=2)

def shuffle_network(inputs, hparams):
  """Neural Shuffle-Network with skip connections between blocks.

  Args:
    inputs: inputs to the Shuffle-Exchange network. Should be in length of power
      of 2.
    hparams: Model configuration

  Returns:
    tf.Tensor: Outputs of the Shuffle-Exchange last layer
  """

  def forward_step(state, layer_nr):
    """One tf.scan step: switch layer followed by a left-rotation shuffle."""
    with tf.variable_scope("forward"):
      last_state, residuals = state
      prev = residuals[layer_nr, :, :, :]
      switch = SwitchLayer("switch", hparams.dropout, hparams.mode)
      cur = switch(last_state, prev)
      return shuffle_layer(cur), residuals

  def reverse_step(state, layer_nr):
    """One tf.scan step: switch layer followed by a right-rotation shuffle."""
    with tf.variable_scope("reverse"):
      last_state, residuals = state
      prev = residuals[layer_nr, :, :, :]
      switch = SwitchLayer("reverse_switch", hparams.dropout, hparams.mode)
      cur = switch(last_state, prev)
      return reverse_shuffle_layer(cur), residuals

  input_shape = tf.shape(inputs)
  # Bits needed to represent the largest position index, length - 1.
  n_bits = tf.log(tf.cast(input_shape[1] - 1, tf.float32)) / tf.log(2.0)
  n_bits = tf.cast(n_bits, tf.int32) + 1

  # The first block has no previous block, so its residuals are all zeros.
  queue_shape = [n_bits * 2, input_shape[0], input_shape[1], input_shape[2]]
  residuals_queue = tf.zeros(queue_shape)
  block_out = tf.tanh(inputs)

  for k in range(hparams.num_hidden_layers):
    with tf.variable_scope("benes_block_" + str(k), reuse=tf.AUTO_REUSE):
      # NOTE(review): only the scan arguments visible in the source are
      # passed; all other tf.scan options keep their defaults - confirm.
      forward_outputs, _ = tf.scan(
          forward_step,
          tf.range(0, n_bits),
          initializer=(block_out, residuals_queue))

      # Prepend the block input so residual index i is the input of layer i.
      forward_tensors = [tf.expand_dims(block_out, axis=0), forward_outputs]
      forward_outputs = tf.concat(forward_tensors, axis=0)
      forward_last = forward_outputs[-1, :, :, :]

      reverse_outputs, _ = tf.scan(
          reverse_step,
          tf.range(n_bits, n_bits * 2),
          initializer=(forward_last, residuals_queue))

      block_out = reverse_outputs[-1, :, :, :]
      # n_bits + 1 forward entries followed by n_bits reverse entries.
      residuals_queue = tf.concat([forward_outputs, reverse_outputs], axis=0)

  last_layer = SwitchLayer("last_layer", hparams.dropout, hparams.mode)
  # Index n_bits * 2 is the final entry of the queue built by the last block.
  return last_layer(block_out, residuals_queue[n_bits * 2, :, :, :])

class ShuffleNetwork(t2t_model.T2TModel):
  """Seq2Seq model for sequence processing in O(n log n) time."""
  # NOTE(review): `registry` is imported by this file but no registration
  # decorator is visible on this class - confirm whether one is expected.

  def bottom(self, features):
    """We add padding to the input and output so they are the same.

    Length of input and output should be power of 2.

    Args:
      features: Dictionary of inputs and targets

    Returns:
      dictionary: Result of the parent class bottom() on the features whose
      inputs and targets were padded with 0 to the length of power of 2.
      Both are same length.
    """
    pad_len = self.max_pad_length(features)
    features["inputs"] = self.pad(features["inputs"], pad_len)

    if features.get("targets") is not None:
      features["targets"] = self.pad(features["targets"], pad_len)

    return super(ShuffleNetwork, self).bottom(features)

  @staticmethod
  def pad(tensor, pad_len):
    """Pad tensor on first dimension to pad_len.

    Args:
      tensor: input tensor of shape length >= 2
      pad_len: pad length

    Returns:
      tf.Tensor: Padded input tensor.
    """
    assert len(tensor.shape) >= 2  # tensor of shape [batch, length, ...]
    length = tf.shape(tensor)[1]

    # Pad only at the end of axis 1; all other axes are left untouched.
    padding = [[0, 0], [0, pad_len - length]]
    padding += [[0, 0]] * (len(tensor.shape) - 2)
    return tf.pad(tensor, padding)

  def max_pad_length(self, features):
    """Finds max padding length.

    If target length not specified use fixed padding
    length from hparams.max_length.

    Args:
      features: Dictionary with input and target tensors

    Returns:
      tf.Tensor:  Length of input and output sequence. Length is power of 2.
      On the fixed-length path this is hparams.max_length itself.
    """
    if self.hparams.force_max_length or features.get("targets") is None:
      max_length = self.hparams.max_length
      # Exact integer power-of-two test; comparing a floating point log2
      # against an integer can misreport large powers of two.
      as_int = int(max_length)
      is_power_of_two = (
          as_int == max_length and as_int > 0 and
          (as_int & (as_int - 1)) == 0)
      assert is_power_of_two, "hparams.max_length should be power of 2"

      return max_length

    length = tf.shape(features["inputs"])[1]
    targets_length = tf.shape(features["targets"])[1]
    length = tf.maximum(length, targets_length)

    # Round the longer of the two lengths up to a power of two.
    p = tf.log(tf.cast(length, tf.float32)) / tf.log(2.0)
    p = tf.cast(tf.ceil(p), tf.int32)
    return tf.pow(2, p)

  def infer(self, features=None, **kwargs):
    """Custom infer method for Shuffle-Exchange network.

    Args:
      features: Dictionary of inputs and targets
      **kwargs: SE network currently doesn't support auto-regressive output

    Returns:
      dict: Dictionary of outputs with keys "outputs" (argmax over the last
        logits axis), "logits" and "scores" (always None).
    """
    del kwargs
    targets = features.get("targets")
    infer_targets = features.get("infer_targets")

    if targets is None and infer_targets is not None:
      features["targets"] = infer_targets

    # Run the model
    self.hparams.force_full_predict = True
    with tf.variable_scope(self.name):
      logits, _ = self.model_fn(features)

    assert len(logits.shape) == 5  # [batch, time, 1, 1, vocab]
    logits = tf.squeeze(logits, [2, 3])
    outputs = tf.argmax(logits, axis=2)

    return {"outputs": outputs, "logits": logits, "scores": None}

  def loss(self, logits, features):
    """Loss function for Neural Shuffle-Exchange network.

    We use custom loss function as default loss function doesn't
    use padding for calculating loss. We assume that output string is same
    length as the input. If you need other type of output please feel
    free to modify this.

    Args:
      logits: Logits from model
      features: Features, not in one-hot format

    Returns:
       tf.Tensor: Loss value
    """
    # The labels must have the same shape as the logits for the
    # cross-entropy below, so the one-hot depth is the size of the last
    # logits axis.
    # NOTE(review): confirm this matches the target vocabulary size.
    onehot_labels = tf.one_hot(features["targets"], tf.shape(logits)[-1])
    cost_vector = tf.nn.softmax_cross_entropy_with_logits_v2(
        logits=logits, labels=onehot_labels)
    return tf.reduce_mean(cost_vector)

  def body(self, features):
    """Body of Neural Shuffle-Exchange network.

    Args:
      features: dictionary of inputs and targets

    Returns:
      tf.Tensor: Output of shuffle_network with a size-1 axis re-inserted at
        position 2.
    """
    inputs = tf.squeeze(features["inputs"], axis=2)
    logits = shuffle_network(inputs, self._hparams)
    return tf.expand_dims(logits, axis=2)

def shuffle_network_baseline():
  """Large Shuffle-Exchange configuration.

  Returns:
    Neural Shuffle-Exchange configuration: the hyperparameter set produced by
    common_hparams.basic_params1() with the overrides below applied.
  """
  # NOTE(review): `registry` is imported by this file but no registration
  # decorator is visible on this function - confirm whether one is expected.
  hparams = common_hparams.basic_params1()
  hparams.hidden_size = 48 * 8  # feature maps
  hparams.num_hidden_layers = 2  # block count

  hparams.clip_grad_norm = 0.  # no gradient clipping

  hparams.optimizer = "adam"
  hparams.optimizer_adam_epsilon = 1e-5
  hparams.learning_rate_schedule = "legacy"
  hparams.learning_rate_decay_scheme = "noam"
  hparams.learning_rate = 0.1
  hparams.initializer_gain = 1.0
  hparams.initializer = "uniform_unit_scaling"
  hparams.optimizer_adam_beta1 = 0.9
  hparams.optimizer_adam_beta2 = 0.999
  hparams.add_hparam("force_max_length", False)  # use fixed max length
  hparams.max_length = 256  # use when targets are not known

  hparams.dropout = 0.1
  hparams.label_smoothing = 0.
  hparams.weight_decay = 0.

  return hparams