############################################################################### # BSD 3-Clause License # # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. # # Copyright (c) 2017, Facebook, inc. All rights reserved. ############################################################################### ''' Code adapted from https://github.com/pytorch/fairseq/blob/master/fairseq/models/transformer.py Introduced optimal gradient checkpointing for intermediate layers in ./transformer.py ''' import math import torch import torch.nn as nn import torch.nn.functional as F from collections import defaultdict class MultiheadAttention(nn.Module): """Multi-headed attention. See "Attention Is All You Need" for more details. """ def __init__(self, embed_dim, num_heads, dropout=0., bias=True): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.dropout = dropout self.head_dim = embed_dim // num_heads assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" self.scaling = self.head_dim**-0.5 self._mask = None self.in_proj_weight = nn.Parameter(torch.Tensor(3*embed_dim, embed_dim)) if bias: self.in_proj_bias = nn.Parameter(torch.Tensor(3*embed_dim)) else: self.register_parameter('in_proj_bias', None) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.reset_parameters() def reset_parameters(self): nn.init.xavier_uniform_(self.in_proj_weight) nn.init.xavier_uniform_(self.out_proj.weight) if self.in_proj_bias is not None: nn.init.constant_(self.in_proj_bias, 0.) nn.init.constant_(self.out_proj.bias, 0.) def forward(self, query, key, value, mask_future_timesteps=False, key_padding_mask=None, incremental_state=None, need_weights=True, static_kv=False): """Input shape: Time x Batch x Channel Self-attention can be implemented by passing in the same arguments for query, key and value. Future timesteps can be masked with the `mask_future_timesteps` argument. Padding elements can be excluded from the key by passing a binary ByteTensor (`key_padding_mask`) with shape: batch x src_len, where padding elements are indicated by 1s. """ qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr() kv_same = key.data_ptr() == value.data_ptr() tgt_len, bsz, embed_dim = query.size() assert embed_dim == self.embed_dim assert list(query.size()) == [tgt_len, bsz, embed_dim] assert key.size() == value.size() if incremental_state is not None: saved_state = self._get_input_buffer(incremental_state) if 'prev_key' in saved_state: # previous time steps are cached - no need to recompute # key and value if they are static if static_kv: assert kv_same and not qkv_same key = value = None else: saved_state = None if qkv_same: # self-attention q, k, v = self.in_proj_qkv(query) elif kv_same: # encoder-decoder attention q = self.in_proj_q(query) if key is None: assert value is None # this will allow us to concat it with previous value and get # just get the previous value k = v = q.new(0) else: k, v = self.in_proj_kv(key) else: q = self.in_proj_q(query) k = self.in_proj_k(key) v = self.in_proj_v(value) q *= self.scaling if saved_state is not None: if 'prev_key' in saved_state: k = torch.cat((saved_state['prev_key'], k), dim=0) if 'prev_value' in saved_state: v = torch.cat((saved_state['prev_value'], v), dim=0) saved_state['prev_key'] = k saved_state['prev_value'] = v self._set_input_buffer(incremental_state, saved_state) src_len = k.size(0) if key_padding_mask is not None: assert key_padding_mask.size(0) == bsz assert key_padding_mask.size(1) == src_len q = q.contiguous().view(tgt_len, bsz*self.num_heads, self.head_dim).transpose(0, 1) k = k.contiguous().view(src_len, bsz*self.num_heads, self.head_dim).transpose(0, 1) v = v.contiguous().view(src_len, bsz*self.num_heads, self.head_dim).transpose(0, 1) attn_weights = torch.bmm(q, k.transpose(1, 2)) assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] # only apply masking at training time (when incremental state is None) if mask_future_timesteps and incremental_state is None: assert query.size() == key.size(), \ 'mask_future_timesteps only applies to self-attention' attn_weights += self.buffered_mask(attn_weights).unsqueeze(0) if key_padding_mask is not None: # don't attend to padding symbols attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) attn_weights = attn_weights.float().masked_fill( key_padding_mask.unsqueeze(1).unsqueeze(2), float('-inf'), ).type_as(attn_weights) # FP16 support: cast to float and back attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(attn_weights) attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training) attn = torch.bmm(attn_weights, v) assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) attn = self.out_proj(attn) if need_weights: # average attention weights over heads attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) attn_weights = attn_weights.sum(dim=1) / self.num_heads else: attn_weights = None return attn, attn_weights def in_proj_k(self, key): return self._in_proj(key, start=self.embed_dim, end=2*self.embed_dim) def in_proj_v(self, value): return self._in_proj(value, start=2*self.embed_dim) def _in_proj(self, input, start=None, end=None): weight = self.in_proj_weight bias = self.in_proj_bias if end is not None: weight = weight[:end, :] if bias is not None: bias = bias[:end] if start is not None: weight = weight[start:, :] if bias is not None: bias = bias[start:] return F.linear(input.type_as(weight), weight, bias) def buffered_mask(self, tensor): attn = self.out_proj(attn) if need_weights: # average attention weights over heads attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) attn_weights = attn_weights.sum(dim=1) / self.num_heads else: attn_weights = None return attn, attn_weights def in_proj_qkv(self, query): return self._in_proj(query).chunk(3, dim=-1) def in_proj_kv(self, key): return self._in_proj(key, start=self.embed_dim).chunk(2, dim=-1) def in_proj_q(self, query): return self._in_proj(query, end=self.embed_dim) def in_proj_k(self, key): return self._in_proj(key, start=self.embed_dim, end=2*self.embed_dim) def in_proj_v(self, value): return self._in_proj(value, start=2*self.embed_dim) def _in_proj(self, input, start=None, end=None): weight = self.in_proj_weight bias = self.in_proj_bias if end is not None: weight = weight[:end, :] if bias is not None: bias = bias[:end] if start is not None: weight = weight[start:, :] if bias is not None: bias = bias[start:] return F.linear(input.type_as(weight), weight, bias) def buffered_mask(self, tensor): dim = tensor.size(-1) if self._mask is None: self._mask = torch.triu(fill_with_neg_inf(tensor.new(dim, dim)), 1) if self._mask.size(0) < dim: self._mask = torch.triu(fill_with_neg_inf(self._mask.resize_(dim, dim)), 1) return self._mask[:dim, :dim] def reorder_incremental_state(self, incremental_state, new_order): """Reorder buffered internal state (for incremental generation).""" input_buffer = self._get_input_buffer(incremental_state) if input_buffer is not None: for k in input_buffer.keys(): input_buffer[k] = input_buffer[k].index_select(1, new_order) self._set_input_buffer(incremental_state, input_buffer) def _get_input_buffer(self, incremental_state): return get_incremental_state( self, incremental_state, 'attn_state', ) or {} def _set_input_buffer(self, incremental_state, buffer): set_incremental_state( self, incremental_state, 'attn_state', buffer, ) class LearnedPositionalEmbedding(nn.Embedding): """This module learns positional embeddings up to a fixed maximum size. Padding symbols are ignored, but it is necessary to specify whether padding is added on the left side (left_pad=True) or right side (left_pad=False). """ def __init__(self, num_embeddings, embedding_dim, padding_idx, left_pad): super().__init__(num_embeddings, embedding_dim, padding_idx) self.left_pad = left_pad def forward(self, input, incremental_state=None): """Input is expected to be of size [bsz x seqlen].""" if incremental_state is not None: # positions is the same for every token when decoding a single step positions = input.data.new(1, 1).fill_(self.padding_idx + input.size(1)) else: positions = make_positions(input.data, self.padding_idx, self.left_pad) return super().forward(positions) def max_positions(self): """Maximum number of supported positions.""" return self.num_embeddings - self.padding_idx - 1 class SinusoidalPositionalEmbedding(nn.Module): """This module produces sinusoidal positional embeddings of any length. Padding symbols are ignored, but it is necessary to specify whether padding is added on the left side (left_pad=True) or right side (left_pad=False). """ def __init__(self, embedding_dim, padding_idx, left_pad, init_size=1024): super().__init__() self.embedding_dim = embedding_dim self.padding_idx = padding_idx self.left_pad = left_pad self.weights = SinusoidalPositionalEmbedding.get_embedding( init_size, embedding_dim, padding_idx, ) self.register_buffer('_float_tensor', torch.FloatTensor()) @staticmethod def get_embedding(num_embeddings, embedding_dim, padding_idx=None): """Build sinusoidal embeddings. This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of "Attention Is All You Need". """ half_dim = embedding_dim // 2 emb = math.log(10000) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb) emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0) emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1) if embedding_dim % 2 == 1: # zero pad emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) if padding_idx is not None: emb[padding_idx, :] = 0 return emb def forward(self, input, incremental_state=None): """Input is expected to be of size [bsz x seqlen].""" # recompute/expand embeddings if needed bsz, seq_len = input.size() max_pos = self.padding_idx + 1 + seq_len if self.weights is None or max_pos > self.weights.size(0): self.weights = SinusoidalPositionalEmbedding.get_embedding( max_pos, self.embedding_dim, self.padding_idx, ) self.weights = self.weights.type_as(self._float_tensor) if incremental_state is not None: # positions is the same for every token when decoding a single step return self.weights[self.padding_idx + seq_len, :].expand(bsz, 1, -1) positions = make_positions(input.data, self.padding_idx, self.left_pad) return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach() def max_positions(self): """Maximum number of supported positions.""" return int(1e5) # an arbitrary large number class GeLU(nn.Module): def __init__(self): super().__init__() def forward(self, x): return 0.5 * x * (1 + torch.tanh(0.79788456080 * (x + 0.044715 * x * x * x))) def Embedding(num_embeddings, embedding_dim, padding_idx): m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5) return m def LayerNorm(embedding_dim): m = nn.LayerNorm(embedding_dim) return m def Linear(in_features, out_features, bias=True): m = nn.Linear(in_features, out_features, bias) nn.init.xavier_uniform_(m.weight) nn.init.constant_(m.bias, 0.) return m def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad, learned=False): if learned: m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad) nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5) nn.init.constant_(m.weight[padding_idx], 0) else: m = SinusoidalPositionalEmbedding(embedding_dim, padding_idx, left_pad, num_embeddings) return m INCREMENTAL_STATE_INSTANCE_ID = defaultdict(lambda: 0) def _get_full_incremental_state_key(module_instance, key): module_name = module_instance.__class__.__name__ # assign a unique ID to each module instance, so that incremental state is # not shared across module instances if not hasattr(module_instance, '_fairseq_instance_id'): INCREMENTAL_STATE_INSTANCE_ID[module_name] += 1 module_instance._fairseq_instance_id = INCREMENTAL_STATE_INSTANCE_ID[module_name] return '{}.{}.{}'.format(module_name, module_instance._fairseq_instance_id, key) def get_incremental_state(module, incremental_state, key): """Helper for getting incremental state for an nn.Module.""" full_key = _get_full_incremental_state_key(module, key) if incremental_state is None or full_key not in incremental_state: return None return incremental_state[full_key] def set_incremental_state(module, incremental_state, key, value): """Helper for setting incremental state for an nn.Module.""" if incremental_state is not None: full_key = _get_full_incremental_state_key(module, key) incremental_state[full_key] = value def fill_with_neg_inf(t): """FP16-compatible function that fills a tensor with -inf.""" return t.float().fill_(float('-inf')).type_as(t) def make_positions(tensor, padding_idx, left_pad): """Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols are ignored, but it is necessary to specify whether padding is added on the left side (left_pad=True) or right side (left_pad=False). """ max_pos = padding_idx + 1 + tensor.size(1) if not hasattr(make_positions, 'range_buf'): make_positions.range_buf = tensor.new() make_positions.range_buf = make_positions.range_buf.type_as(tensor) if make_positions.range_buf.numel() < max_pos: torch.arange(padding_idx + 1, max_pos, out=make_positions.range_buf) mask = tensor.ne(padding_idx) positions = make_positions.range_buf[:tensor.size(1)].expand_as(tensor) if left_pad: positions = positions - mask.size(1) + mask.long().sum(dim=1).unsqueeze(1) return tensor.clone().masked_scatter_(mask, positions[mask])