from typing import Dict

import torch

from allennlp.common.checks import ConfigurationError
from allennlp.modules import Seq2SeqEncoder
from allennlp.nn.util import get_text_field_mask, replace_masked_values


def masked_mean(tensor, dim, mask):
    """
    Performs a mean on just the non-masked portions of the ``tensor`` in the
    ``dim`` dimension of the tensor.

    =====================================================================
    From the Decomposable Graph Entailment Model code replicated from the
    SciTail repo: https://github.com/allenai/scitail
    =====================================================================
    """
    if mask is None:
        return torch.mean(tensor, dim)
    if tensor.dim() != mask.dim():
        raise ConfigurationError("tensor.dim() (%d) != mask.dim() (%d)" % (tensor.dim(), mask.dim()))

    masked_tensor = replace_masked_values(tensor, mask, 0.0)
    # total value
    total_tensor = torch.sum(masked_tensor, dim)
    # count of non-masked entries
    count_tensor = torch.sum((mask != 0), dim)
    # set zero counts to 1 to avoid NaNs
    zero_count_mask = (count_tensor == 0)
    count_plus_zeros = (count_tensor + zero_count_mask).float()
    # average
    mean_tensor = total_tensor / count_plus_zeros

    return mean_tensor


def seq2vec_seq_aggregate(seq_tensor, mask, aggregate, bidirectional, dim=1):
    """
    Takes an aggregation over a sequence tensor.

    :param seq_tensor: Batched sequence of shape [batch, seq, hs]
    :param mask: Binary mask of shape [batch, seq_len]
    :param aggregate: Aggregation type - max, min, avg, sum
    :param bidirectional: Whether the encoder that produced ``seq_tensor`` is bidirectional
        (only relevant for "last" aggregation)
    :param dim: The dimension to aggregate over. For [batch, seq, hs] it is 1.
    :return: Aggregated tensor of shape [batch, hs]
    """
    # zero out the masked (padded) positions before aggregating
    seq_tensor_masked = seq_tensor * mask.unsqueeze(-1)

    if aggregate == "last":
        # Would use allennlp.nn.util.get_final_encoder_states(seq_tensor, mask, bidirectional),
        # but this is currently not supported with AllenNLP 0.2.
        raise NotImplementedError("\"last\" aggregation is currently not supported with AllenNLP 0.2.")
    elif aggregate == "max":
        seq, _ = torch.max(seq_tensor_masked, dim=dim)
    elif aggregate == "min":
        seq, _ = torch.min(seq_tensor_masked, dim=dim)
    elif aggregate == "sum":
        seq = torch.sum(seq_tensor_masked, dim=dim)
    elif aggregate == "avg":
        seq = torch.sum(seq_tensor_masked, dim=dim)
        seq_lens = torch.sum(mask, dim=dim)  # shape: [batch_size]
        seq = seq / seq_lens.view([-1, 1])
    else:
        raise ValueError("Unsupported aggregation type: %s" % aggregate)

    return seq


def embed_encode_and_aggregate_text_field(question: Dict[str, torch.LongTensor],
                                          text_field_embedder,
                                          embeddings_dropout,
                                          encoder,
                                          aggregation_type):
    """
    Given batched token ids (2D) runs embedding lookup with dropout, context encoding and aggregation.

    :param question: The input token ids
    :param text_field_embedder: The embedder to be used for embedding lookup
    :param embeddings_dropout: Dropout
    :param encoder: Context encoder
    :param aggregation_type: The type of aggregation - max, sum, avg, last
    :return: Aggregated representation of shape [batch_size, d]
    """
    embedded_question = text_field_embedder(question)
    question_mask = get_text_field_mask(question).float()
    embedded_question = embeddings_dropout(embedded_question)
    encoded_question = encoder(embedded_question, question_mask)

    # aggregate each sequence to a single vector
    encoded_question_aggregated = seq2vec_seq_aggregate(encoded_question, question_mask,
                                                        aggregation_type, None, 1)  # bs X d

    return encoded_question_aggregated


def embed_encode_and_aggregate_list_text_field(texts_list: Dict[str, torch.LongTensor],
                                               text_field_embedder,
                                               embeddings_dropout,
                                               encoder: Seq2SeqEncoder,
                                               aggregation_type,
                                               init_hidden_states=None):
    """
    Given a batched list of token ids (3D) runs embedding lookup with dropout, context encoding and
    aggregation on each of the texts in the list.

    :param texts_list: List of texts
    :param text_field_embedder: The embedder to be used for embedding lookup
    :param embeddings_dropout: Dropout
    :param encoder: Context encoder
    :param aggregation_type: The type of aggregation - max, sum, avg, last
    :param init_hidden_states: Hidden states initialization
    :return: Aggregated representations of shape [batch_size, choices_cnt, d]
    """
    embedded_texts = text_field_embedder(texts_list)
    embedded_texts = embeddings_dropout(embedded_texts)

    batch_size, choices_cnt, choice_tokens_cnt, d = tuple(embedded_texts.shape)

    embedded_texts_flattened = embedded_texts.view([batch_size * choices_cnt, choice_tokens_cnt, -1])

    # masks
    texts_mask_dim_3 = get_text_field_mask(texts_list).float()
    texts_mask_flattened = texts_mask_dim_3.view([-1, choice_tokens_cnt])

    # context encoding
    multiple_texts_init_states = None
    if init_hidden_states is not None:
        if init_hidden_states.shape[0] == batch_size and init_hidden_states.shape[1] != choices_cnt:
            if init_hidden_states.shape[1] != encoder.get_output_dim():
                raise ValueError("The shape of init_hidden_states is {0} "
                                 "but is expected to be {1} or {2}".format(str(init_hidden_states.shape),
                                                                           str([batch_size, encoder.get_output_dim()]),
                                                                           str([batch_size, choices_cnt,
                                                                                encoder.get_output_dim()])))

            # in this case we passed only a 2D tensor, which is the default output from the question encoder
            multiple_texts_init_states = init_hidden_states.unsqueeze(1).expand([batch_size, choices_cnt,
                                                                                 encoder.get_output_dim()]).contiguous()

            # reshape this to match the flattened tokens
            multiple_texts_init_states = multiple_texts_init_states.view([batch_size * choices_cnt,
                                                                          encoder.get_output_dim()])
        else:
            multiple_texts_init_states = init_hidden_states.view([batch_size * choices_cnt,
                                                                  encoder.get_output_dim()])

    encoded_texts_flattened = encoder(embedded_texts_flattened, texts_mask_flattened,
                                      hidden_state=multiple_texts_init_states)

    # aggregate each encoded text to a single vector
    aggregated_choice_flattened = seq2vec_seq_aggregate(encoded_texts_flattened, texts_mask_flattened,
                                                        aggregation_type, encoder, 1)  # bs*ch X d

    aggregated_choice_flattened_reshaped = aggregated_choice_flattened.view([batch_size, choices_cnt, -1])  # bs X ch X d

    return aggregated_choice_flattened_reshaped
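

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module) showing the two
# pure-tensor helpers above on toy shapes. The embed_* helpers additionally
# require an AllenNLP TextFieldEmbedder and Seq2SeqEncoder, so they are not
# exercised here. Assumes the pre-1.0 AllenNLP ``replace_masked_values``
# behavior (float 0/1 masks) that this module targets.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    batch_size, seq_len, hidden_size = 2, 4, 8
    seq_tensor = torch.randn(batch_size, seq_len, hidden_size)
    # the second sequence has only two valid (non-padded) tokens
    mask = torch.tensor([[1., 1., 1., 1.],
                         [1., 1., 0., 0.]])

    # mean over the sequence dimension, counting only non-masked positions;
    # masked_mean expects the mask to have the same number of dims as the tensor
    mean_over_seq = masked_mean(seq_tensor, 1, mask.unsqueeze(-1).expand_as(seq_tensor))
    print(mean_over_seq.shape)  # torch.Size([2, 8])

    # aggregate a [batch, seq, hs] tensor down to [batch, hs]
    max_pooled = seq2vec_seq_aggregate(seq_tensor, mask, "max", bidirectional=False, dim=1)
    avg_pooled = seq2vec_seq_aggregate(seq_tensor, mask, "avg", bidirectional=False, dim=1)
    print(max_pooled.shape, avg_pooled.shape)  # torch.Size([2, 8]) torch.Size([2, 8])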