#!/usr/bin/python3
# Author: GMFTBY
# Time: 2019.9.14

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import random
import numpy as np
import ipdb

from .layers import *

'''
WSeq is an HRED-based model which uses cosine similarity as the attention weight
'''


class Utterance_encoder(nn.Module):
    '''
    Bidirectional GRU utterance encoder
    '''

    def __init__(self, input_size, embedding_size, hidden_size,
                 dropout=0.5, n_layer=1, pretrained=None):
        super(Utterance_encoder, self).__init__()
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.n_layer = n_layer

        self.embed = nn.Embedding(input_size, self.embedding_size)
        self.gru = nn.GRU(self.embedding_size, self.hidden_size, num_layers=n_layer,
                          dropout=dropout, bidirectional=True)
        # hidden_project
        # self.hidden_proj = nn.Linear(n_layer * 2 * self.hidden_size, hidden_size)
        # self.bn = nn.BatchNorm1d(num_features=hidden_size)

        self.init_weight()

    def init_weight(self):
        init.xavier_normal_(self.gru.weight_hh_l0)
        init.xavier_normal_(self.gru.weight_ih_l0)
        self.gru.bias_ih_l0.data.fill_(0.0)
        self.gru.bias_hh_l0.data.fill_(0.0)

    def forward(self, inpt, lengths, hidden=None):
        # use pack_padded_sequence to handle variable-length utterances
        # inpt: [seq_len, batch], lengths: [batch_size]
        embedded = self.embed(inpt)    # [seq_len, batch, embedding_size]

        if hidden is None:
            hidden = torch.randn(self.n_layer * 2, len(lengths), self.hidden_size)
            if torch.cuda.is_available():
                hidden = hidden.cuda()

        embedded = nn.utils.rnn.pack_padded_sequence(embedded, lengths, enforce_sorted=False)
        output, hidden = self.gru(embedded, hidden)
        output, _ = nn.utils.rnn.pad_packed_sequence(output)
        # sum the forward and backward directions
        output = output[:, :, :self.hidden_size] + output[:, :, self.hidden_size:]
        hidden = hidden.sum(dim=0)
        hidden = torch.tanh(hidden)    # [batch, hidden_size]
        output = torch.tanh(output)    # [seq, batch, hidden]

        # hidden: [n_layer * bidirection, batch, hidden_size]
        # hidden = hidden.reshape(hidden.shape[1], -1)
        # ipdb.set_trace()
        # hidden = hidden.permute(1, 0, 2)    # [batch, n_layer * bidirectional, hidden_size]
        # hidden = hidden.reshape(hidden.size(0), -1)    # [batch, *]
        # hidden = self.bn(self.hidden_proj(hidden))
        # hidden = torch.tanh(hidden)    # [batch, hidden]

        return output, hidden


class Context_encoder(nn.Module):
    '''
    input_size is 2 * utterance_hidden_size
    '''

    def __init__(self, input_size, hidden_size, dropout=0.5):
        super(Context_encoder, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.gru = nn.GRU(self.input_size, self.hidden_size, bidirectional=True)
        # self.drop = nn.Dropout(p=dropout)
        self.init_weight()

    def init_weight(self):
        init.xavier_normal_(self.gru.weight_hh_l0)
        init.xavier_normal_(self.gru.weight_ih_l0)
        self.gru.bias_ih_l0.data.fill_(0.0)
        self.gru.bias_hh_l0.data.fill_(0.0)

    def forward(self, inpt, hidden=None):
        # inpt: [turn_len, batch, input_size]
        if hidden is None:
            hidden = torch.randn(2, inpt.shape[1], self.hidden_size)
            if torch.cuda.is_available():
                hidden = hidden.cuda()

        # inpt = self.drop(inpt)
        output, hidden = self.gru(inpt, hidden)
        # sum the forward and backward directions
        output = output[:, :, :self.hidden_size] + output[:, :, self.hidden_size:]
        hidden = torch.tanh(hidden)    # [2, batch, hidden_size]
        return output, hidden


class Decoder(nn.Module):
    '''
    Maximum likelihood decoding of the response utterance.
    output_size is the size of the output vocabulary.

    The attention module requires that the decoder hidden size is the same
    as the context encoder hidden size.

    WSeq attention in the decoder part: smaller model size than HRED-attn,
    but it still has the attention part.
    '''

    def __init__(self,
                 utter_hidden, context_hidden, output_size, embed_size, hidden_size,
                 n_layer=2, dropout=0.5, pretrained=None):
        super(Decoder, self).__init__()
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.embed_size = embed_size

        self.embed = nn.Embedding(self.output_size, self.embed_size)
        self.gru = nn.GRU(self.embed_size + self.hidden_size, self.hidden_size,
                          num_layers=n_layer, dropout=(0 if n_layer == 1 else dropout))
        self.out = nn.Linear(hidden_size, output_size)

        # attention over the context encoder states
        self.attn = WSeq_attention(hidden_size)
        self.context_encoder = Context_encoder(utter_hidden, context_hidden, dropout=dropout)
        self.word_level_attn = Attention(hidden_size)

        self.init_weight()

    def init_weight(self):
        init.xavier_normal_(self.gru.weight_hh_l0)
        init.xavier_normal_(self.gru.weight_ih_l0)
        self.gru.bias_ih_l0.data.fill_(0.0)
        self.gru.bias_hh_l0.data.fill_(0.0)

    def forward(self, inpt, last_hidden, encoder_outputs):
        # inpt: [batch_size], last_hidden: [n_layer, batch, hidden_size]
        # encoder_outputs: list of turn_len tensors, each [seq, batch, hidden_size]
        embedded = self.embed(inpt).unsqueeze(0)    # [1, batch_size, embed_size]
        key = last_hidden.sum(dim=0)

        # word-level attention over each turn
        context_output = []
        for turn in encoder_outputs:
            # ipdb.set_trace()
            word_attn_weights = self.word_level_attn(key, turn)
            context = word_attn_weights.bmm(turn.transpose(0, 1))
            context = context.transpose(0, 1).squeeze(0)    # [batch, hidden]
            context_output.append(context)
        context_output = torch.stack(context_output)    # [turn, batch, hidden]

        # context_output: [turn, batch, hidden], hidden: [2, batch, hidden]
        context_output, hidden = self.context_encoder(context_output)

        # WSeq attention, context: [batch, hidden_size]
        if len(encoder_outputs) == 1:
            context = self.attn(context_output[0], context_output)
        else:
            context = self.attn(context_output[-1], context_output[:-1])
        context = context.unsqueeze(0)    # [1, batch, hidden]

        rnn_input = torch.cat([embedded, context], 2)    # [1, batch, embed_size + hidden]

        # output: [1, batch, hidden_size], hidden: [n_layer, batch, hidden_size]
        output, hidden = self.gru(rnn_input, last_hidden)
        output = output.squeeze(0)    # [batch, hidden_size]
        # context = context.squeeze(0)    # [batch, hidden]
        # output = torch.cat([output, context], 1)    # [batch, 2 * hidden]
        output = self.out(output)    # [batch, output_size]
        output = F.log_softmax(output, dim=1)
        return output, hidden


class WSeq_RA(nn.Module):

    def __init__(self, embed_size, input_size, output_size, utter_hidden,
                 context_hidden, decoder_hidden, teach_force=0.5, pad=24745,
                 sos=24742, dropout=0.5, utter_n_layer=1, pretrained=None):
        super(WSeq_RA, self).__init__()
        self.teach_force = teach_force
        self.output_size = output_size
        self.pad, self.sos = pad, sos
        self.utter_n_layer = utter_n_layer
        self.hidden_size = decoder_hidden
        self.utter_encoder = Utterance_encoder(input_size, embed_size, utter_hidden,
                                               dropout=dropout, n_layer=utter_n_layer,
                                               pretrained=pretrained)
        self.decoder = Decoder(utter_hidden, context_hidden, output_size, embed_size,
                               decoder_hidden, n_layer=utter_n_layer, dropout=dropout,
                               pretrained=pretrained)

    def forward(self, src, tgt, lengths):
        # src: [turns, lengths, batch], tgt: [lengths, batch]
        # lengths: [turns, batch]
        turn_size, batch_size, maxlen = len(src), tgt.size(1), tgt.size(0)
        outputs = torch.zeros(maxlen, batch_size, self.output_size)
        if torch.cuda.is_available():
            outputs = outputs.cuda()

        # utterance encoding
        turns = []
        turns_output = []
        for i in range(turn_size):
            # sbatch = src[i].transpose(0, 1)    # [seq_len, batch]
            output, hidden = self.utter_encoder(src[i], lengths[i])    # utter_hidden
            turns.append(hidden)
            turns_output.append(output)
        turns = torch.stack(turns)    # [turn_len, batch, utter_hidden]

        # context encoding
        # output: [seq, batch, hidden], [batch, hidden]
        # context_output, hidden = self.context_encoder(turns)

        # decoding
        # tgt = tgt.transpose(0, 1)    # [seq_len, batch]
        # hidden = hidden.unsqueeze(0)    # [1, batch, hidden_size]
        hidden = torch.randn(self.utter_n_layer, batch_size, self.hidden_size)
        if torch.cuda.is_available():
            hidden = hidden.cuda()
        output = tgt[0, :]    # [batch]

        use_teacher = random.random() < self.teach_force
        if use_teacher:
            for t in range(1, maxlen):
                output, hidden = self.decoder(output, hidden, turns_output)
                outputs[t] = output
                output = tgt[t]
        else:
            for t in range(1, maxlen):
                output, hidden = self.decoder(output, hidden, turns_output)
                outputs[t] = output
                output = output.topk(1)[1].squeeze().detach()

        return outputs    # [maxlen, batch, vocab_size]

    def predict(self, src, maxlen, lengths, loss=False):
        # predict for the test dataset, return outputs: [maxlen, batch_size]
        # src: [turn, max_len, batch_size], lengths: [turn, batch_size]
        with torch.no_grad():
            turn_size, batch_size = len(src), src[0].size(1)
            outputs = torch.zeros(maxlen, batch_size)
            floss = torch.zeros(maxlen, batch_size, self.output_size)
            if torch.cuda.is_available():
                outputs = outputs.cuda()
                floss = floss.cuda()

            # utterance encoding
            turns = []
            turns_output = []
            for i in range(turn_size):
                # sbatch = src[i].transpose(0, 1)    # [seq_len, batch]
                output, hidden = self.utter_encoder(src[i], lengths[i])    # utter_hidden
                turns.append(hidden)
                turns_output.append(output)
            turns = torch.stack(turns)    # [turn_len, batch, utter_hidden]

            # context_output, hidden = self.context_encoder(turns)
            # hidden = hidden.unsqueeze(0)
            hidden = torch.randn(self.utter_n_layer, batch_size, self.hidden_size)
            if torch.cuda.is_available():
                hidden = hidden.cuda()

            output = torch.zeros(batch_size, dtype=torch.long).fill_(self.sos)
            if torch.cuda.is_available():
                output = output.cuda()

            for i in range(1, maxlen):
                output, hidden = self.decoder(output, hidden, turns_output)
                floss[i] = output
                output = output.max(1)[1]
                outputs[i] = output

            if loss:
                return outputs, floss
            else:
                return outputs


if __name__ == "__main__":
    pass
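    # Minimal smoke test (an added sketch, not part of the original file).
    # All sizes, the vocabulary size, and the special token ids below are
    # made-up assumptions. It further assumes that .layers provides
    # `Attention` and `WSeq_attention` with the interfaces used by Decoder
    # above, and that this module is run as part of its package
    # (python -m ...), since the relative import at the top fails when the
    # file is executed as a plain script.
    vocab_size, embed_size, hidden_size = 1000, 128, 256
    model = WSeq_RA(embed_size, vocab_size, vocab_size,
                    hidden_size, hidden_size, hidden_size,
                    pad=0, sos=1, utter_n_layer=1)
    turns, seq_len, batch = 3, 10, 4
    src = [torch.randint(2, vocab_size, (seq_len, batch)) for _ in range(turns)]
    lengths = [torch.full((batch,), seq_len, dtype=torch.long) for _ in range(turns)]
    tgt = torch.randint(2, vocab_size, (seq_len, batch))
    if torch.cuda.is_available():
        model, src, tgt = model.cuda(), [s.cuda() for s in src], tgt.cuda()
    out = model(src, tgt, lengths)    # expected shape: [seq_len, batch, vocab_size]
    print(out.shape)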