# coding=utf-8
# Copyright 2020 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Neural Shuffle-Exchange Network.

Implementation of the paper "Neural Shuffle-Exchange Networks - Sequence
Processing in O(n log n) Time" by K. Freivalds, E. Ozolins and A. Sostaks.

Paper: https://papers.nips.cc/paper/
8889-neural-shuffle-exchange-networks-sequence-processing-in-on-log-n-time.pdf

Original code: https://github.com/LUMII-Syslab/shuffle-exchange
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

from tensor2tensor.layers import common_hparams
from tensor2tensor.utils import registry
from tensor2tensor.utils import t2t_model

import tensorflow.compat.v1 as tf


def ror(x, n, p=1):
  """Bitwise right rotation.

  Args:
    x: Input tensor
    n: Bit count to represent x
    p: Bit positions to shift

  Returns:
    tf.Tensor: x shifted by p positions in n bits
  """
  a = tf.bitwise.right_shift(x, p)
  b = tf.bitwise.left_shift(1, p) - 1
  c = tf.bitwise.bitwise_and(x, b)
  d = tf.bitwise.left_shift(c, n - p)
  return a + d


def rol(x, n, p=1):
  """Bitwise left rotation.

  Args:
    x: Input tensor
    n: Bit count to represent x
    p: Bit positions to shift

  Returns:
    tf.Tensor: x shifted by p positions in n bits
  """
  a = tf.bitwise.left_shift(x, p)
  b = tf.bitwise.left_shift(1, n) - 1
  c = tf.bitwise.bitwise_and(a, b)
  d = tf.bitwise.right_shift(x, n - p)
  return tf.bitwise.bitwise_or(c, d)


def shuffle_layer(inputs, shuffle_fn=rol):
  """Shuffles the elements according to bitwise left or right rotation.

  Args:
    inputs: Tensor input from previous layer
    shuffle_fn: Shift function rol or ror

  Returns:
    tf.Tensor: Inputs shifted according to shuffle_fn
  """
  length = tf.shape(inputs)[1]
  n_bits = tf.log(tf.cast(length - 1, tf.float32)) / tf.log(2.0)
  n_bits = tf.cast(n_bits, tf.int32) + 1

  indices = tf.range(0, length)
  rev_indices = shuffle_fn(indices, n_bits)
  return tf.gather(inputs, rev_indices, axis=1)


def reverse_shuffle_layer(inputs):
  """Reverse shuffle of inputs.

  Used in the second half of Benes block.

  Args:
    inputs: Inputs that should be shuffled

  Returns:
    tf.Tensor: Inputs shuffled according to bitwise right rotation
  """
  return shuffle_layer(inputs, ror)
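
# A minimal sanity check (illustrative; not part of the original module, and
# assumes TF1 graph mode, e.g. after tf.disable_v2_behavior()). For length 8
# (n_bits == 3), `rol` maps the index vector [0..7] to [0, 2, 4, 6, 1, 3, 5, 7],
# so `shuffle_layer` gathers the even-indexed elements followed by the
# odd-indexed ones -- the "perfect shuffle" permutation -- and
# `reverse_shuffle_layer` applies its inverse.
def _example_shuffle_layer():
  """Illustrative only: prints the length-8 shuffle permutation."""
  inputs = tf.reshape(tf.range(8), [1, 8, 1])  # [batch, length, depth]
  shuffled = shuffle_layer(inputs)
  restored = reverse_shuffle_layer(shuffled)
  with tf.Session() as sess:
    print(sess.run(tf.squeeze(shuffled)))  # [0 2 4 6 1 3 5 7]
    print(sess.run(tf.squeeze(restored)))  # [0 1 2 3 4 5 6 7]
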
def conv_linear_map(inputs, nin, nout, bias_start, prefix):
  """Convolutional linear map.

  Maps a 3D tensor by its last dimension.

  Args:
    inputs: Input tensor
    nin: Input feature map count
    nout: Output feature map count
    bias_start: Bias start value
    prefix: Name prefix

  Returns:
    tf.Tensor: Inputs with applied convolution
  """
  with tf.variable_scope(prefix):
    inp_shape = tf.shape(inputs)

    initializer = tf.variance_scaling_initializer(
        scale=1.0, mode="fan_avg", distribution="uniform")
    kernel = tf.get_variable("CvK", [nin, nout], initializer=initializer)
    bias_term = tf.get_variable(
        "CvB", [nout], initializer=tf.constant_initializer(0.0))

    mul_shape = [inp_shape[0] * inp_shape[1], nin]
    res = tf.matmul(tf.reshape(inputs, mul_shape), kernel)
    res = tf.reshape(res, [inp_shape[0], inp_shape[1], nout])
    return res + bias_start + bias_term
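
# A minimal shape sketch (illustrative; not part of the original module):
# `conv_linear_map` is a pointwise map over the last dimension, so a
# [batch, length, nin] tensor comes out as [batch, length, nout]. The scope
# name "demo_map" below is arbitrary.
def _example_conv_linear_map():
  """Illustrative only: shows the input/output shapes of conv_linear_map."""
  inputs = tf.zeros([2, 8, 4])  # [batch, length, nin]
  outputs = conv_linear_map(inputs, nin=4, nout=6, bias_start=0.0,
                            prefix="demo_map")
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(tf.shape(outputs)))  # Expected: [2 8 6]
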
# pylint: disable=useless-object-inheritance
class SwitchLayer(object):
  """Switch layer of Neural Shuffle-Exchange network."""

  def __init__(self, prefix, dropout, mode):
    """Initialize switch layer.

    Args:
      prefix: Name prefix for switch layer
      dropout: Dropout rate
      mode: Training mode
    """
    self.prefix = prefix
    self.dropout = dropout
    self.mode = mode
    self.batch_size = None
    self.length = None
    self.num_units = None
    self.n_bits = None

  def linear_map(self, inputs, suffix, bias_start, in_units, out_units):
    """2 input to 2 output linear map.

    Args:
      inputs: Input tensor
      suffix: Linear map name suffix
      bias_start: Bias start value
      in_units: Size of input tensor feature map count
      out_units: Size of output tensor feature map count

    Returns:
      tf.Tensor: Convolution applied to the input tensor
    """
    in_shape = [self.batch_size, self.length // 2, in_units * 2]
    inputs = tf.reshape(inputs, in_shape)
    res = conv_linear_map(inputs, in_units * 2, out_units * 2, bias_start,
                          self.prefix + "/" + suffix)
    return tf.reshape(res, [self.batch_size, self.length, out_units])

  def gated_linear_map(self, inputs, suffix, bias_start_reset, in_units,
                       out_units):
    """Linear mapping with two reset gates.

    Args:
      inputs: Input tensor
      suffix: Linear map name suffix
      bias_start_reset: Bias start value for reset gate
      in_units: Size of input tensor feature map count
      out_units: Size of output tensor feature map count

    Returns:
      tf.Tensor: Convolution applied to the input tensor
    """

    def reset_gate(name):
      prefix = self.prefix + name + suffix
      reset = conv_linear_map(inputs, in_units * 2, in_units * 2,
                              bias_start_reset, prefix)
      return tf.nn.sigmoid(reset)

    in_shape = [self.batch_size, self.length // 2, in_units * 2]
    inputs = tf.reshape(inputs, in_shape)
    reset1 = reset_gate("/reset1/")
    reset2 = reset_gate("/reset2/")
    res1 = conv_linear_map(inputs * reset1, in_units * 2, out_units, 0.0,
                           self.prefix + "/cand1/" + suffix)
    res2 = conv_linear_map(inputs * reset2, in_units * 2, out_units, 0.0,
                           self.prefix + "/cand2/" + suffix)

    res = tf.concat([res1, res2], axis=2)
    res = tf.reshape(res, [self.batch_size, self.length, out_units])
    return tf.nn.tanh(res)

  def __call__(self, inputs, residual_inputs):
    """Apply SwitchLayer to inputs.

    Args:
      inputs: Input tensor
      residual_inputs: Residual connections from previous block

    Returns:
      tf.Tensor: New candidate value
    """
    input_shape = tf.shape(inputs)
    self.batch_size = input_shape[0]
    self.length = input_shape[1]
    self.num_units = inputs.shape.as_list()[2]

    self.n_bits = tf.log(tf.cast(self.length - 1, tf.float32)) / tf.log(2.0)
    self.n_bits = tf.floor(self.n_bits) + 1

    initializer = tf.constant_initializer(0.5)
    residual_scale = tf.get_variable(
        self.prefix + "/residual_scale", [self.num_units],
        initializer=initializer)

    shuffled_input = self.swap_halves(inputs)
    mem_all = inputs + residual_inputs * residual_scale

    # Calculate the new candidate value and gate it with the shuffled input.
    candidate = self.gated_linear_map(mem_all, "c", 0.5, self.num_units,
                                      self.num_units)
    gate = tf.nn.sigmoid(
        self.linear_map(mem_all, "g", 0.5, self.num_units, self.num_units))
    candidate = gate * shuffled_input + (1 - gate) * candidate

    if self.dropout > 0:
      candidate = tf.nn.dropout(candidate, rate=self.dropout / self.n_bits)

    if self.dropout != 0.0 and self.mode == tf.estimator.ModeKeys.TRAIN:
      noise = tf.random_normal(tf.shape(candidate), mean=1.0, stddev=0.001)
      candidate = candidate * noise

    return candidate

  def swap_halves(self, inputs):
    """Split inputs in half and then shuffle them as described in paper.

    Args:
      inputs: ShuffleLayer inputs

    Returns:
      tf.Tensor: Inputs with swapped halves
    """
    x = tf.range(0, self.length)
    xor_indices = tf.bitwise.bitwise_xor(x, 1)

    input_xor = tf.gather(
        inputs[:, :, :self.num_units // 2], xor_indices, axis=1)
    return tf.concat([input_xor, inputs[:, :, self.num_units // 2:]], axis=2)


def shuffle_network(inputs, hparams):
  """Neural Shuffle-Exchange network with skip connections between blocks.

  Args:
    inputs: Inputs to the Shuffle-Exchange network. The length should be a
      power of 2.
    hparams: Model configuration

  Returns:
    tf.Tensor: Outputs of the last Shuffle-Exchange layer
  """

  def forward_step(state, layer_nr):
    with tf.variable_scope("forward"):
      last_state, residuals = state
      prev = residuals[layer_nr, :, :, :]

      switch = SwitchLayer("switch", hparams.dropout, hparams.mode)
      cur = switch(last_state, prev)
      return shuffle_layer(cur), residuals

  def reverse_step(state, layer_nr):
    with tf.variable_scope("reverse"):
      last_state, residuals = state
      prev = residuals[layer_nr, :, :, :]

      switch = SwitchLayer("reverse_switch", hparams.dropout, hparams.mode)
      cur = switch(last_state, prev)
      return reverse_shuffle_layer(cur), residuals

  input_shape = tf.shape(inputs)
  n_bits = tf.log(tf.cast(input_shape[1] - 1, tf.float32)) / tf.log(2.0)
  n_bits = tf.cast(n_bits, tf.int32) + 1

  queue_shape = [n_bits * 2, input_shape[0], input_shape[1], input_shape[2]]
  residuals_queue = tf.zeros(queue_shape)

  block_out = tf.tanh(inputs)
  for k in range(hparams.num_hidden_layers):
    with tf.variable_scope("benes_block_" + str(k), reuse=tf.AUTO_REUSE):
      forward_outputs, _ = tf.scan(
          forward_step,
          tf.range(0, n_bits),
          initializer=(block_out, residuals_queue),
          parallel_iterations=1,
          swap_memory=True)

      forward_tensors = [tf.expand_dims(block_out, axis=0), forward_outputs]
      forward_outputs = tf.concat(forward_tensors, axis=0)
      forward_last = forward_outputs[-1, :, :, :]

      reverse_outputs, _ = tf.scan(
          reverse_step,
          tf.range(n_bits, n_bits * 2),
          initializer=(forward_last, residuals_queue),
          parallel_iterations=1,
          swap_memory=True)
      block_out = reverse_outputs[-1, :, :, :]
      residuals_queue = tf.concat([forward_outputs, reverse_outputs], axis=0)

  last_layer = SwitchLayer("last_layer", hparams.dropout, hparams.mode)
  return last_layer(block_out, residuals_queue[n_bits * 2, :, :, :])
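
# A minimal usage sketch (illustrative; not part of the original module): it
# wires a dummy batch through `shuffle_network`. The hparams tweaks below
# (small hidden size, zero dropout for determinism, and `mode` added via
# add_hparam) are assumptions made only to keep the example self-contained;
# T2TModel normally sets `mode` itself. `shuffle_network_baseline` is
# defined at the bottom of this file.
def _example_shuffle_network():
  """Illustrative only: runs shuffle_network on a dummy batch."""
  hparams = shuffle_network_baseline()
  hparams.hidden_size = 8
  hparams.dropout = 0.0
  hparams.add_hparam("mode", tf.estimator.ModeKeys.EVAL)

  # Length must be a power of 2; shape is [batch, length, hidden_size].
  inputs = tf.zeros([2, 16, hparams.hidden_size])
  outputs = shuffle_network(inputs, hparams)
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(tf.shape(outputs)))  # Expected: [2 16 8]
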
@registry.register_model
class ShuffleNetwork(t2t_model.T2TModel):
  """Seq2Seq model for sequence processing in O(n log n) time."""

  def bottom(self, features):
    """Pads the inputs and targets so they are the same length.

    The length of the input and output should be a power of 2.

    Args:
      features: Dictionary of inputs and targets

    Returns:
      dictionary: Inputs and targets padded with 0 to the same
        power-of-2 length.
    """
    pad_len = self.max_pad_length(features)
    features["inputs"] = self.pad(features["inputs"], pad_len)
    if features.get("targets") is not None:
      features["targets"] = self.pad(features["targets"], pad_len)

    return super(ShuffleNetwork, self).bottom(features)

  @staticmethod
  def pad(tensor, pad_len):
    """Pad tensor on the length (second) dimension to pad_len.

    Args:
      tensor: Input tensor of rank >= 2, shaped [batch, length, ...]
      pad_len: Pad length

    Returns:
      tf.Tensor: Padded input tensor.
    """
    assert len(tensor.shape) >= 2  # tensor of shape [batch, length, ...]
    length = tf.shape(tensor)[1]
    padding = [[0, 0], [0, pad_len - length]]
    padding += [[0, 0]] * (len(tensor.shape) - 2)
    return tf.pad(tensor, padding)

  def max_pad_length(self, features):
    """Finds the maximum padding length.

    If the target length is not specified, uses the fixed padding length
    from hparams.max_length.

    Args:
      features: Dictionary with input and target tensors

    Returns:
      tf.Tensor: Length of input and output sequence. Length is a power
        of 2.
    """
    if self.hparams.force_max_length or features.get("targets") is None:
      assert math.log(self.hparams.max_length, 2).is_integer(), \
        "hparams.max_length should be a power of 2"
      return self.hparams.max_length

    length = tf.shape(features["inputs"])[1]
    targets_length = tf.shape(features["targets"])[1]
    length = tf.maximum(length, targets_length)

    p = tf.log(tf.cast(length, tf.float32)) / tf.log(2.0)
    p = tf.cast(tf.ceil(p), tf.int32)
    return tf.pow(2, p)

  def infer(self, features=None, **kwargs):
    """Custom infer method for the Shuffle-Exchange network.

    Args:
      features: Dictionary of inputs and targets
      **kwargs: Unused; the SE network currently doesn't support
        auto-regressive output

    Returns:
      dict: Dictionary of outputs.
    """
    del kwargs
    targets = features.get("targets")
    infer_targets = features.get("infer_targets")
    if targets is None and infer_targets is not None:
      features["targets"] = infer_targets

    # Run the model
    self.hparams.force_full_predict = True
    with tf.variable_scope(self.name):
      logits, _ = self.model_fn(features)

    assert len(logits.shape) == 5  # [batch, time, 1, 1, vocab]
    logits = tf.squeeze(logits, [2, 3])
    outputs = tf.argmax(logits, axis=2)

    return {"outputs": outputs, "logits": logits, "scores": None}

  def loss(self, logits, features):
    """Loss function for the Neural Shuffle-Exchange network.

    We use a custom loss function because the default one doesn't include
    padded positions when calculating the loss. We assume that the output
    string is the same length as the input. If you need another type of
    output, feel free to modify this.

    Args:
      logits: Logits from the model
      features: Features, not in one-hot format

    Returns:
      tf.Tensor: Loss value
    """
    onehot_labels = tf.one_hot(features["targets"],
                               self._problem_hparams.vocab_size["targets"])
    cost_vector = tf.nn.softmax_cross_entropy_with_logits_v2(
        logits=logits, labels=onehot_labels)
    return tf.reduce_mean(cost_vector)

  def body(self, features):
    """Body of the Neural Shuffle-Exchange network.

    Args:
      features: Dictionary of inputs and targets

    Returns:
      tf.Tensor: Network output logits
    """
    inputs = tf.squeeze(features["inputs"], axis=2)
    logits = shuffle_network(inputs, self._hparams)
    return tf.expand_dims(logits, axis=2)
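
# A minimal sketch of the padding logic above (illustrative; not part of the
# original module): sequences are padded up to the next power of 2, which is
# the ceil(log2(length)) computation done by `max_pad_length`.
def _example_pad_to_power_of_two():
  """Illustrative only: pads a length-5 batch up to length 8."""
  tensor = tf.ones([2, 5, 3])
  length = tf.shape(tensor)[1]
  p = tf.cast(tf.ceil(tf.log(tf.cast(length, tf.float32)) / tf.log(2.0)),
              tf.int32)
  pad_len = tf.pow(2, p)  # Next power of 2: here 8.
  padded = ShuffleNetwork.pad(tensor, pad_len)
  with tf.Session() as sess:
    print(sess.run(tf.shape(padded)))  # Expected: [2 8 3]
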
@registry.register_hparams
def shuffle_network_baseline():
  """Baseline Shuffle-Exchange configuration.

  Returns:
    HParams: Neural Shuffle-Exchange configuration
  """
  hparams = common_hparams.basic_params1()
  hparams.hidden_size = 48 * 8  # feature maps
  hparams.num_hidden_layers = 2  # block count
  hparams.clip_grad_norm = 0.  # no gradient clipping
  hparams.optimizer = "adam"
  hparams.optimizer_adam_epsilon = 1e-5
  hparams.learning_rate_schedule = "legacy"
  hparams.learning_rate_decay_scheme = "noam"
  hparams.learning_rate = 0.1
  hparams.initializer_gain = 1.0
  hparams.initializer = "uniform_unit_scaling"
  hparams.optimizer_adam_beta1 = 0.9
  hparams.optimizer_adam_beta2 = 0.999

  hparams.add_hparam("force_max_length", False)  # use fixed max length
  hparams.max_length = 256  # used when targets are not known
  hparams.dropout = 0.1
  hparams.label_smoothing = 0.
  hparams.weight_decay = 0.
  return hparams
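
# Example usage (sketch): with the model and hparams set registered above,
# both can be selected by name in t2t-trainer. The problem flag below is a
# placeholder; pick a problem whose inputs and targets are equal-length
# sequences, e.g. an algorithmic task.
#
#   t2t-trainer \
#     --model=shuffle_network \
#     --hparams_set=shuffle_network_baseline \
#     --problem=<problem_name>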