import math

import torch
import torch.nn as nn
from torch.nn import Module
from torch.nn.parameter import Parameter


class BertPooler(nn.Module):
    def __init__(self, hidden_size):
        super(BertPooler, self).__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output


class BertLayerNorm(nn.Module):
    def __init__(self, hidden_size, variance_epsilon=1e-12):
        """Construct a layernorm module in the TF style (epsilon inside the square root)."""
        super(BertLayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.ones(hidden_size))
        self.beta = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = variance_epsilon

    def forward(self, x):
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.gamma * x + self.beta


class BertSelfAttention(nn.Module):
    """ Extracted from """

    def __init__(self, hidden_size):
        super(BertSelfAttention, self).__init__()
        # hidden_size must be divisible by the (fixed) number of attention heads.
        self.num_attention_heads = 16
        self.attention_head_size = int(hidden_size / self.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(hidden_size, self.all_head_size)
        self.key = nn.Linear(hidden_size, self.all_head_size)
        self.value = nn.Linear(hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(0.2)

    def transpose_for_scores(self, x):
        # [bsz, len, all_head_size] -> [bsz, num_heads, len, head_size]
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask=None):
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # Apply the additive attention mask (precomputed for all layers in the
        # BertModel forward() function).
        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        return context_layer, attention_probs


class SelfAttentive(nn.Module):
    def __init__(self, hidden_size, att_hops=1, att_unit=200, dropout=0.2):
        super(SelfAttentive, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.ws1 = nn.Linear(hidden_size, att_unit, bias=False)
        self.ws2 = nn.Linear(att_unit, att_hops, bias=False)
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=-1)
        # self.dictionary = config['dictionary']
        # self.init_weights()
        self.attention_hops = att_hops

    def forward(self, rnn_out, mask=None):
        outp = rnn_out
        size = outp.size()                                            # [bsz, len, nhid]
        compressed_embeddings = outp.reshape(-1, size[2])             # [bsz*len, nhid]
        hbar = self.tanh(self.ws1(self.drop(compressed_embeddings)))  # [bsz*len, attention-unit]
        alphas = self.ws2(hbar).view(size[0], size[1], -1)            # [bsz, len, hop]
        alphas = torch.transpose(alphas, 1, 2).contiguous()           # [bsz, hop, len]
        if mask is not None:
            # The additive mask is expected to penalize padded positions
            # (e.g. with large negative values); broadcast it to every hop.
            mask = mask.squeeze(2).unsqueeze(1)                       # [bsz, 1, len]
            concatenated_mask = [mask for _ in range(self.attention_hops)]
            concatenated_mask = torch.cat(concatenated_mask, 1)       # [bsz, hop, len]
            penalized_alphas = alphas + concatenated_mask
        else:
            penalized_alphas = alphas
        alphas = self.softmax(penalized_alphas.view(-1, size[1]))     # [bsz*hop, len]
        alphas = alphas.view(size[0], self.attention_hops, size[1])   # [bsz, hop, len]
        return torch.bmm(alphas, outp), alphas


class AttentionOneParaPerChan(Module):
    """
    Computes a weighted average of the different channels across timesteps.
    Uses one parameter per channel to compute the attention value for a single timestep.
    """

    def __init__(self, attention_size, IS_HALF=False):
        """ Initialize the attention layer

        # Arguments:
            attention_size: Size of the attention vector.
            IS_HALF: If true, the padding mask is built in half precision.
        """
        super(AttentionOneParaPerChan, self).__init__()
        self.attention_size = attention_size
        self.attention_vector = Parameter(torch.FloatTensor(attention_size))
        self.attention_vector.data.normal_(std=0.05)  # Initialize attention vector
        self.is_half = IS_HALF

    def __repr__(self):
        s = '{name}(attention_size={attention_size})'
        return s.format(name=self.__class__.__name__, **self.__dict__)

    def forward(self, inputs, input_lengths):
        """ Forward pass.

        # Arguments:
            inputs (torch.Tensor): Tensor of input sequences [bsz, len, attention_size]
            input_lengths (torch.LongTensor): Lengths of the sequences [bsz]

        # Return:
            Tuple of (representations, attentions).
        """
        logits = inputs.matmul(self.attention_vector)
        unnorm_ai = (logits - logits.max()).exp()

        # Compute a mask for the attention on the padded sequences
        # See e.g. https://discuss.pytorch.org/t/self-attention-on-words-and-masking/5671/5
        # The mask is built on the same device as the inputs (avoids a hard-coded .cuda() call).
        max_len = unnorm_ai.size(1)
        idxes = torch.arange(max_len, device=inputs.device).unsqueeze(0)   # [1, max_len]
        mask = idxes < input_lengths.unsqueeze(1).to(inputs.device)        # [bsz, max_len]
        mask = mask.half() if self.is_half else mask.float()

        masked_weights = unnorm_ai * mask

        # apply mask and renormalize attention scores (weights)
        att_sums = masked_weights.sum(dim=1, keepdim=True)  # sums per sequence
        attentions = masked_weights.div(att_sums)

        # apply attention weights
        weighted = torch.mul(inputs, attentions.unsqueeze(-1).expand_as(inputs))

        # get the final fixed vector representations of the sentences
        representations = weighted.sum(dim=1)

        return representations, attentions
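

if __name__ == "__main__":
    # Minimal smoke test sketching how these modules are typically wired together.
    # The shapes below (batch of 2, sequence length 5, hidden size 64) and the
    # lengths tensor are illustrative assumptions, not values from the original
    # training code.
    bsz, seq_len, hidden = 2, 5, 64

    hidden_states = torch.randn(bsz, seq_len, hidden)   # e.g. encoder outputs
    lengths = torch.tensor([5, 3])                       # true lengths per sequence

    pooled = BertPooler(hidden)(hidden_states)                        # [bsz, hidden]
    normed = BertLayerNorm(hidden)(hidden_states)                     # [bsz, len, hidden]
    context, probs = BertSelfAttention(hidden)(hidden_states)         # [bsz, len, hidden], [bsz, heads, len, len]
    sent, alphas = SelfAttentive(hidden, att_hops=2)(hidden_states)   # [bsz, hop, hidden], [bsz, hop, len]
    rep, att = AttentionOneParaPerChan(hidden)(hidden_states, lengths)  # [bsz, hidden], [bsz, len]

    print(pooled.shape, normed.shape, context.shape, sent.shape, rep.shape)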