import torch from torch import nn import torch.nn.functional as F from torch.autograd import Variable import numpy as np import math from config import global_config as cfg import copy, random, time, logging def cuda_(var): return var.cuda() if cfg.cuda else var def toss_(p): return random.randint(0, 99) <= p def nan(v): return np.isnan(np.sum(v.data.cpu().numpy())) def get_sparse_input(x_input): """ get a sparse matrix of x_input: [T,B,V] where x_sparse[i][j][k]=1, and others = 1e-8 :param x_input: *Tensor* of [T,B] :return: *Tensor* in shape [B,T,V] """ # indexes that will make no effect in copying sw = time.time() print('sparse input start: %s' % sw) ignore_index = [0] result = torch.normal(mean=0, std=torch.zeros(x_input.size(0), x_input.size(1), cfg.vocab_size)) for t in range(x_input.size(0)): for b in range(x_input.size(1)): if x_input[t][b] not in ignore_index: result[t][b][x_input[t][b]] = 1.0 print('sparse input end %s' % time.time()) return result.transpose(0, 1) def get_sparse_input_efficient(x_input_np): ignore_index = [0] result = np.zeros((x_input_np.shape[0], x_input_np.shape[1], cfg.vocab_size), dtype=np.float32) result.fill(1e-10) for t in range(x_input_np.shape[0]): for b in range(x_input_np.shape[1]): if x_input_np[t][b] not in ignore_index: result[t,b,x_input_np[t][b]] = 1.0 result_np = result.transpose((1, 0, 2)) result = torch.from_numpy(result_np).float() return result def shift(pz_proba): first_input = np.zeros((pz_proba.size(1), pz_proba.size(2))) first_input.fill(1e-12) first_input = cuda_(Variable(torch.from_numpy(first_input)).float()) pz_proba = list(pz_proba)[:-1] pz_proba.insert(0, first_input) pz_proba = torch.stack(pz_proba, 0) return pz_proba.contiguous() class Encoder(nn.Module): def __init__(self, input_size, embed_size, hidden_size, n_layers, dropout): super(Encoder, self).__init__() self.input_size = input_size self.hidden_size = hidden_size self.embed_size = embed_size self.n_layers = n_layers self.dropout = dropout self.embedding = nn.Embedding(input_size, embed_size) self.gru = nn.GRU(embed_size, hidden_size, n_layers, dropout=self.dropout, bidirectional=True) def forward(self, input_seqs, hidden=None): embedded = self.embedding(input_seqs) outputs, hidden = self.gru(embedded, hidden) outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:] # Sum bidirectional outputs return outputs, hidden class DynamicEncoder(nn.Module): def __init__(self, input_size, embed_size, hidden_size, n_layers, dropout): super().__init__() self.input_size = input_size self.hidden_size = hidden_size self.embed_size = embed_size self.n_layers = n_layers self.dropout = dropout self.embedding = nn.Embedding(input_size, embed_size) self.gru = nn.GRU(embed_size, hidden_size, n_layers, dropout=self.dropout, bidirectional=True) def forward(self, input_seqs, input_lens, hidden=None): """ forward procedure. No need for inputs to be sorted :param input_seqs: Variable of [T,B] :param hidden: :param input_lens: *numpy array* of len for each input sequence :return: """ batch_size = input_seqs.size(1) embedded = self.embedding(input_seqs) embedded = embedded.transpose(0, 1) # [B,T,E] sort_idx = np.argsort(-input_lens) unsort_idx = cuda_(torch.LongTensor(np.argsort(sort_idx))) input_lens = input_lens[sort_idx] sort_idx = cuda_(torch.LongTensor(sort_idx)) embedded = embedded[sort_idx].transpose(0, 1) # [T,B,E] packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lens) outputs, hidden = self.gru(packed, hidden) outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs) outputs = outputs[:,:,:self.hidden_size] + outputs[:,:,self.hidden_size:] outputs = outputs.transpose(0, 1)[unsort_idx].transpose(0, 1).contiguous() hidden = hidden.transpose(0, 1)[unsort_idx].transpose(0, 1).contiguous() return outputs, hidden class Attn(nn.Module): def __init__(self, hidden_size): super(Attn, self).__init__() self.hidden_size = hidden_size self.attn = nn.Linear(self.hidden_size * 2, hidden_size) self.v = nn.Linear(self.hidden_size, 1) def forward(self, hidden, encoder_outputs, normalize=True): encoder_outputs = encoder_outputs.transpose(0, 1) # [B,T,H] attn_energies = self.score(hidden, encoder_outputs) normalized_energy = F.softmax(attn_energies, dim=2) # [B,1,T] context = torch.bmm(normalized_energy, encoder_outputs) # [B,1,H] return context.transpose(0, 1) # [1,B,H] def score(self, hidden, encoder_outputs): max_len = encoder_outputs.size(1) H = hidden.repeat(max_len, 1, 1).transpose(0, 1) energy = self.attn(torch.cat([H, encoder_outputs], 2)) # [B,T,2H]->[B,T,H] energy = self.v(F.tanh(energy)).transpose(1,2) # [B,1,T] return energy class MultiTurnInferenceDecoder_Z(nn.Module): """ Inference network: copying version of Q_phi(z_t|s_t,m_t) <- Q_phi(z_ti|s_t,m_t,z_t[1..i-1]) """ def __init__(self, embed_size, hidden_size, vocab_size, dropout_rate): super().__init__() self.gru = nn.GRU(embed_size, hidden_size, dropout=dropout_rate) self.w1 = nn.Linear(hidden_size, vocab_size) self.mu = nn.Linear(vocab_size, embed_size, bias=False) self.log_sigma = nn.Linear(vocab_size, embed_size) self.dropout_rate = dropout_rate self.vocab_size = vocab_size self.proj_copy1 = nn.Linear(hidden_size, hidden_size) self.proj_copy2 = nn.Linear(hidden_size, hidden_size) self.proj_copy3 = nn.Linear(hidden_size, hidden_size) self.dropout = nn.Dropout(self.dropout_rate) def forward(self, u_input, u_enc_out, pv_pz_proba, pv_z_dec_out, m_input, m_enc_out, embed_z, last_hidden, rand_eps, u_input_np, m_input_np): """ Similar to base class method :param m_input: :param u_input: :param u_enc_out: :param m_enc_out: :param embed_z: :param last_hidden: :param rand_eps: :return: """ sparse_u_input = Variable(get_sparse_input_efficient(u_input_np), requires_grad=False) # [B,T,V] sparse_m_input = Variable(get_sparse_input_efficient(m_input_np), requires_grad=False) # [B,T,V] # if cfg.cuda: sparse_m_input = sparse_m_input.cuda() # if cfg.cuda: sparse_u_input = sparse_u_input.cuda() embed_z = self.dropout(embed_z) gru_out, last_hidden = self.gru(embed_z, last_hidden) gen_score = self.w1(gru_out).squeeze(0) # [B,V] u_copy_score = F.tanh(self.proj_copy1(u_enc_out.transpose(0, 1))) # [B,T,H] m_copy_score = F.tanh(self.proj_copy2(m_enc_out.transpose(0, 1))) if not cfg.force_stable: # unstable version of copynet for small dataset u_copy_score = torch.exp(torch.matmul(u_copy_score, gru_out.squeeze(0).unsqueeze(2)).squeeze(2)) # [B,T] m_copy_score = torch.exp(torch.matmul(m_copy_score, gru_out.squeeze(0).unsqueeze(2)).squeeze(2)) # [B,T] u_copy_score, m_copy_score = u_copy_score.cpu(), m_copy_score.cpu() u_copy_score = torch.log(torch.bmm(u_copy_score.unsqueeze(1), sparse_u_input)).squeeze(1) # [B,V] m_copy_score = torch.log(torch.bmm(m_copy_score.unsqueeze(1), sparse_m_input)).squeeze(1) # [B,V] else: # stable version of copynet u_copy_score = torch.matmul(u_copy_score, gru_out.squeeze(0).unsqueeze(2)).squeeze(2) m_copy_score = torch.matmul(m_copy_score, gru_out.squeeze(0).unsqueeze(2)).squeeze(2) u_copy_score, m_copy_score = u_copy_score.cpu(), m_copy_score.cpu() u_copy_score_max, m_copy_score_max = torch.max(u_copy_score, dim=1, keepdim=True)[0], \ torch.max(m_copy_score, dim=1, keepdim=True)[0] u_copy_score = torch.exp(u_copy_score - u_copy_score_max) # [B,T] m_copy_score = torch.exp(m_copy_score - m_copy_score_max) # [B,T] # u_copy_score, m_copy_score = u_copy_score.cpu(), m_copy_score.cpu() u_copy_score = torch.log(torch.bmm(u_copy_score.unsqueeze(1), sparse_u_input)).squeeze( 1) + u_copy_score_max # [B,V] m_copy_score = torch.log(torch.bmm(m_copy_score.unsqueeze(1), sparse_m_input)).squeeze( 1) + m_copy_score_max # [B,V] u_copy_score, m_copy_score = cuda_(u_copy_score), cuda_(m_copy_score) if pv_pz_proba is not None: pv_pz_proba = shift(pv_pz_proba) pv_z_copy_score = F.tanh(self.proj_copy3(pv_z_dec_out.transpose(0, 1))) # [B,T,H] if cfg.force_stable: pv_z_copy_score = torch.exp( torch.matmul(pv_z_copy_score, gru_out.squeeze(0).unsqueeze(2)).squeeze(2)) # [B,T] pv_z_copy_score = torch.log( torch.bmm(pv_z_copy_score.unsqueeze(1), pv_pz_proba.transpose(0, 1))).squeeze( 1) # [B,V] else: pv_z_copy_score = torch.matmul(pv_z_copy_score, gru_out.squeeze(0).unsqueeze(2)).squeeze(2) pv_z_copy_score_max = torch.max(pv_z_copy_score, dim=1, keepdim=True)[0] pv_z_copy_score = torch.exp(pv_z_copy_score - pv_z_copy_score_max) pv_z_copy_score = torch.log( torch.bmm(pv_z_copy_score.unsqueeze(1), pv_pz_proba.transpose(0, 1))).squeeze( 1) + pv_z_copy_score_max # [B,V] scores = F.softmax(torch.cat([gen_score, u_copy_score, m_copy_score, pv_z_copy_score], dim=1), dim=1) gen_score, u_copy_score, m_copy_score, pv_z_copy_score = tuple( torch.split(scores, gen_score.size(1), dim=1)) proba = gen_score + u_copy_score + m_copy_score + pv_z_copy_score else: scores = F.softmax(torch.cat([gen_score, u_copy_score, m_copy_score], dim=1), dim=1) gen_score, u_copy_score, m_copy_score = tuple( torch.split(scores, gen_score.size(1), dim=1)) proba = gen_score + u_copy_score + m_copy_score appr_emb = self.mu(proba).unsqueeze(0) # log_sigma_ae = self.log_sigma(proba) # sigma_ae = torch.exp(log_sigma_ae) # sampled_ae = appr_emb + torch.mul(sigma_ae, rand_eps) return appr_emb, gru_out, last_hidden, proba, appr_emb, None class MultiTurnPriorDecoder_Z(nn.Module): def __init__(self, embed_size, hidden_size, vocab_size, dropout_rate): super().__init__() self.gru = nn.GRU(embed_size, hidden_size, dropout=dropout_rate) self.w1 = nn.Linear(hidden_size, vocab_size) self.proj_copy1 = nn.Linear(hidden_size, hidden_size) self.proj_copy2 = nn.Linear(hidden_size, hidden_size) self.mu = nn.Linear(vocab_size, embed_size, bias=False) self.log_sigma = nn.Linear(vocab_size, embed_size) self.dropout_rate = dropout_rate self.dropout = nn.Dropout(dropout_rate) def forward(self, u_input, u_enc_out, pv_pz_proba, pv_z_dec_out, embed_z, last_hidden, rand_eps, u_input_np, m_input_np): sparse_u_input = Variable(get_sparse_input_efficient(u_input_np), requires_grad=False) embed_z = self.dropout(embed_z) gru_out, last_hidden = self.gru(embed_z, last_hidden) gen_score = self.w1(gru_out).squeeze(0) u_copy_score = F.tanh(self.proj_copy1(u_enc_out.transpose(0, 1))) # [B,T,H] if not cfg.force_stable: u_copy_score = torch.exp(torch.matmul(u_copy_score, gru_out.squeeze(0).unsqueeze(2)).squeeze(2)) # [B,T] u_copy_score = u_copy_score.cpu() u_copy_score = torch.log(torch.bmm(u_copy_score.unsqueeze(1), sparse_u_input)).squeeze(1) # [B,V] else: # stable version of copynet u_copy_score = torch.matmul(u_copy_score, gru_out.squeeze(0).unsqueeze(2)).squeeze(2) u_copy_score = u_copy_score.cpu() u_copy_score_max = torch.max(u_copy_score, dim=1, keepdim=True)[0] u_copy_score = torch.exp(u_copy_score - u_copy_score_max) # [B,T] u_copy_score = torch.log(torch.bmm(u_copy_score.unsqueeze(1), sparse_u_input)).squeeze( 1) + u_copy_score_max # [B,V] u_copy_score = cuda_(u_copy_score) if pv_pz_proba is not None: pv_pz_proba = shift(pv_pz_proba) pv_z_copy_score = F.tanh(self.proj_copy2(pv_z_dec_out.transpose(0, 1))) # [B,T,H] if cfg.force_stable: pv_z_copy_score = torch.exp( torch.matmul(pv_z_copy_score, gru_out.squeeze(0).unsqueeze(2)).squeeze(2)) # [B,T] pv_z_copy_score = torch.log( torch.bmm(pv_z_copy_score.unsqueeze(1), pv_pz_proba.transpose(0, 1))).squeeze( 1) # [B,V] else: pv_z_copy_score = torch.matmul(pv_z_copy_score, gru_out.squeeze(0).unsqueeze(2)).squeeze(2) pv_z_copy_score_max = torch.max(pv_z_copy_score, dim=1, keepdim=True)[0] pv_z_copy_score = torch.exp(pv_z_copy_score - pv_z_copy_score_max) pv_z_copy_score = torch.log( torch.bmm(pv_z_copy_score.unsqueeze(1), pv_pz_proba.transpose(0, 1))).squeeze( 1) + pv_z_copy_score_max # [B,V] scores = F.softmax(torch.cat([gen_score, u_copy_score, pv_z_copy_score], dim=1), dim=1) gen_score, u_copy_score, pv_z_copy_score = tuple(torch.split(scores, gen_score.size(1), dim=1)) proba = gen_score + u_copy_score + pv_z_copy_score # [B,V] else: scores = F.softmax(torch.cat([gen_score, u_copy_score], dim=1), dim=1) gen_score, u_copy_score = tuple(torch.split(scores, gen_score.size(1), dim=1)) proba = gen_score + u_copy_score # [B,V] appr_emb = self.mu(proba).unsqueeze(0) return appr_emb, gru_out, last_hidden, proba, appr_emb, None class ResponseDecoder(nn.Module): """ Response decoder: P_theta(m_t|s_t, z_t) <- P_theta(m_ti|s_t, z_t, m_t[1..i-1]) This is a deterministic decoder. """ def __init__(self, embed_size, hidden_size, vocab_size, degree_size, dropout_rate): super().__init__() self.emb = nn.Embedding(vocab_size, embed_size) self.attn_z = Attn(hidden_size) self.attn_u = Attn(hidden_size) self.w4 = nn.Linear(hidden_size, hidden_size) self.gate_z = nn.Linear(hidden_size, hidden_size) self.w5 = nn.Linear(hidden_size, hidden_size) self.gru = nn.GRU(embed_size + hidden_size + degree_size, hidden_size, dropout=dropout_rate) self.proj = nn.Linear(hidden_size * 3, vocab_size) self.proj_copy1 = nn.Linear(hidden_size, hidden_size) self.dropout_rate = dropout_rate self.dropout = nn.Dropout(dropout_rate) def forward(self, z_enc_out, pz_proba, u_enc_out, m_t_input, degree_input, last_hidden): """ decode the response: P(m|u,z) :param degree_input: [B,D] :param pz_proba: [Tz,B,V], output of the prior decoder :param z_enc_out: [Tz,B,H] :param u_enc_out: [T,B,H] :param m_t_input: [1,B] :param last_hidden: :return: proba: [1,B,V] """ m_embed = self.emb(m_t_input) pz_proba = shift(pz_proba) z_context = self.attn_z(last_hidden, z_enc_out) u_context = self.attn_u(last_hidden, u_enc_out) d_control = z_context + torch.mul(F.sigmoid(self.gate_z(z_context)), u_context) embed = torch.cat([d_control, m_embed, degree_input.unsqueeze(0)], dim=2) embed = self.dropout(embed) gru_out, last_hidden = self.gru(embed, last_hidden) gen_score = self.proj(torch.cat([z_context, u_context, gru_out], 2)).squeeze(0) z_copy_score = F.tanh(self.proj_copy1(z_enc_out.transpose(0, 1))) # [B,T,H] if not cfg.force_stable: z_copy_score = torch.exp(torch.matmul(z_copy_score, gru_out.squeeze(0).unsqueeze(2)).squeeze(2)) # [B,T] z_copy_score = torch.log(torch.bmm(z_copy_score.unsqueeze(1), pz_proba.transpose(0, 1))).squeeze(1) # [B,V] else: z_copy_score = torch.matmul(z_copy_score, gru_out.squeeze(0).unsqueeze(2)).squeeze(2) z_copy_score_max = torch.max(z_copy_score, dim=1, keepdim=True)[0] z_copy_score = torch.exp(z_copy_score - z_copy_score_max) z_copy_score = torch.log(torch.bmm(z_copy_score.unsqueeze(1), pz_proba.transpose(0, 1))) z_copy_score = z_copy_score.squeeze(1) + z_copy_score_max scores = F.softmax(torch.cat([gen_score, z_copy_score], dim=1), dim=1) gen_score, z_copy_score = tuple(torch.split(scores, gen_score.size(1), dim=1)) proba = gen_score + z_copy_score # [B,V] return proba, last_hidden, gru_out class MultinomialKLDivergenceLoss(nn.Module): def __init__(self, special_tokens=[]): super().__init__() self.special_tokens = special_tokens def forward(self, p_proba, q_proba): # [B, T, V] mask = torch.ones(p_proba.size(0), p_proba.size(1)) cnt = 0 for i in range(q_proba.size(0)): flg = False for j in range(q_proba.size(1)): topv, topi = torch.max(q_proba[i,j], -1) if flg: mask[i,j] = 0 else: mask[i,j] = 1 cnt += 1 if topi.item() in self.special_tokens: flg = True mask = cuda_(Variable(mask)) loss = q_proba * (torch.log(q_proba) - torch.log(p_proba)) masked_loss = torch.sum(mask.unsqueeze(-1) * loss) return masked_loss / (cnt + 1e-10) class SemiSupervisedSEDST(nn.Module): def __init__(self, embed_size, hidden_size, vocab_size, degree_size, layer_num, dropout_rate, z_length, alpha, max_ts, beam_search=False, teacher_force=100, **kwargs): super().__init__() self.u_encoder = DynamicEncoder(vocab_size, embed_size, hidden_size, layer_num, dropout_rate) self.m_encoder = DynamicEncoder(vocab_size, embed_size, hidden_size, layer_num, dropout_rate) self.m_decoder = ResponseDecoder(embed_size, hidden_size, vocab_size, degree_size, dropout_rate) self.qz_decoder = MultiTurnInferenceDecoder_Z(embed_size, hidden_size, vocab_size, dropout_rate) # posterior self.pz_decoder = MultiTurnPriorDecoder_Z(embed_size, hidden_size, vocab_size, dropout_rate) # prior self.embed_size = embed_size self.vocab = kwargs['vocab'] self.pr_loss = nn.NLLLoss(ignore_index=0) self.q_loss = nn.NLLLoss(ignore_index=0) self.dec_loss = nn.NLLLoss(ignore_index=0) self.kl_loss = MultinomialKLDivergenceLoss(special_tokens=[self.vocab.encode(x) for x in ['EOS_Z1','EOS_Z2', '</s>', '<pad>']]) self.z_length = z_length self.alpha = alpha self.max_ts = max_ts self.beam_search = beam_search self.teacher_force = teacher_force if self.beam_search: self.beam_size = kwargs['beam_size'] self.eos_token_idx = kwargs['eos_token_idx'] def forward(self, u_input, u_input_np, m_input, m_input_np, z_input, u_len, m_len, turn_states, z_supervised, p_input, p_input_np, p_len, degree_input, mode): if mode == 'train' or mode == 'valid': if not z_supervised: z_input = None pz_proba, qz_proba, pm_dec_proba, pz_mu, pz_log_sigma, qz_mu, qz_log_sigma, turn_states = \ self.forward_turn(u_input, u_len, m_input=m_input, m_len=m_len, z_input=z_input, is_train=True, turn_states=turn_states, degree_input=degree_input, u_input_np=u_input_np, m_input_np=m_input_np,p_input=p_input,p_input_np=p_input_np,p_len=p_len) if z_supervised: loss, pr_loss, m_loss, q_loss = self.supervised_loss(torch.log(pz_proba), torch.log(qz_proba), torch.log(pm_dec_proba), z_input, m_input) return loss, pr_loss, m_loss, q_loss, turn_states else: loss, m_loss, kl_div_loss = self.unsupervised_loss(qz_mu, qz_log_sigma, pz_mu, pz_log_sigma, torch.log(pm_dec_proba), m_input, pz_proba, qz_proba) return loss, m_loss, kl_div_loss, turn_states elif mode == 'test': m_output_index, pz_index, turn_states = self.forward_turn(u_input, u_len=u_len, is_train=False, turn_states=turn_states, degree_input=degree_input, u_input_np=u_input_np, m_input_np=m_input_np, p_input=p_input, p_input_np=p_input_np, p_len=p_len ) return m_output_index, pz_index, turn_states def forward_turn(self, u_input, u_len, turn_states, is_train, degree_input, u_input_np, m_input_np=None, m_input=None, m_len=None, z_input=None, p_input=None, p_input_np=None, p_len=None,test_type='pr'): """ compute required outputs for a single dialogue turn. Turn state{Dict} will be updated in each call. :param u_input_np: :param m_input_np: :param u_len: :param turn_states: :param is_train: :param u_input: [T,B] :param m_input: [T,B] :param z_input: [T,B] :return: """ pv_pz_proba = turn_states.get('pv_pz_proba', None) pv_z_outs = turn_states.get('pv_z_dec_outs', None) pv_qz_proba = turn_states.get('pv_qz_proba', None) pv_qz_outs = turn_states.get('pv_qz_dec_outs', None) batch_size = u_input.size(1) u_enc_out, u_enc_hidden = self.u_encoder(u_input, u_len) last_hidden = u_enc_hidden[:-1] # initial approximate embedding: SOS token initialized with all zero # Pi(z|u) pz_ae = cuda_(Variable(torch.zeros(1, batch_size, self.embed_size))) pz_proba, pz_mu, pz_log_sigma = [], [], [] pz_dec_outs = [] z_length = z_input.size(0) if z_input is not None else self.z_length for t in range(z_length): if cfg.sampling: rand_eps = Variable(torch.normal(means=torch.zeros(1, batch_size, cfg.embedding_size), std=1)) else: rand_eps = Variable(torch.zeros(1, batch_size, cfg.embedding_size)) if cfg.cuda: rand_eps = rand_eps.cuda() pz_ae, last_hidden, pz_dec_out, proba, appr_emb, log_sigma_ae = \ self.pz_decoder(u_input=u_input, u_enc_out=u_enc_out, pv_pz_proba=pv_pz_proba, pv_z_dec_out=pv_z_outs, embed_z=pz_ae, last_hidden=last_hidden, rand_eps=rand_eps, u_input_np=u_input_np, m_input_np=m_input_np) pz_proba.append(proba) pz_mu.append(appr_emb) pz_log_sigma.append(log_sigma_ae) pz_dec_outs.append(pz_dec_out) pz_dec_outs = torch.cat(pz_dec_outs, dim=0) # [Tz,B,H] pz_proba, pz_mu = torch.stack(pz_proba, dim=0), torch.stack(pz_mu, dim=0) # P(m|z,u) m_tm1 = cuda_(Variable(torch.ones(1, batch_size).long())) # GO token pm_dec_proba, m_dec_outs = [],[] turn_states['pv_z_dec_outs'], turn_states['pv_pz_proba'] = pz_dec_outs, pz_proba if is_train or test_type=='post': m_length = m_input.size(0) # Tm for t in range(m_length): teacher_forcing = toss_(self.teacher_force) proba, last_hidden, dec_out = self.m_decoder(pz_dec_outs, pz_proba, u_enc_out, m_tm1, degree_input, last_hidden) if teacher_forcing: m_tm1 = m_input[t].view(1, -1) else: _, m_tm1 = torch.topk(proba, 1) m_tm1 = m_tm1.view(1, -1) pm_dec_proba.append(proba) m_dec_outs.append(dec_out) pm_dec_proba = torch.stack(pm_dec_proba, dim=0) # [T,B,V] # Q(z|u,m) u_enc_out, u_enc_hidden = self.m_encoder(u_input, u_len) m_enc_out, m_enc_hidden = self.m_encoder(m_input, m_len) last_hidden = u_enc_hidden[:-1] qz_ae = cuda_(Variable(torch.zeros(1, batch_size, self.embed_size))) qz_proba, qz_mu, qz_log_sigma, qz_dec_outs = [], [], [], [] for t in range(z_length): if cfg.sampling: rand_eps = self.alpha * Variable(torch.normal(means=torch.zeros(1, batch_size, cfg.embedding_size), std=1)) else: rand_eps = Variable(torch.zeros(1, batch_size, cfg.embedding_size)) if cfg.cuda: rand_eps = rand_eps.cuda() qz_ae, gru_out, last_hidden, proba, appr_emb, log_sigma_ae = \ self.qz_decoder(u_input=u_input, u_enc_out=u_enc_out, pv_pz_proba=pv_qz_proba, pv_z_dec_out=pv_qz_outs, m_input=m_input, m_enc_out=m_enc_out, u_input_np=u_input_np, m_input_np=m_input_np, embed_z=qz_ae, last_hidden=last_hidden, rand_eps=rand_eps) qz_proba.append(proba) qz_mu.append(appr_emb) qz_log_sigma.append(log_sigma_ae) qz_dec_outs.append(gru_out) qz_proba, qz_mu = torch.stack(qz_proba, dim=0), torch.stack(qz_mu, dim=0) qz_dec_outs = torch.cat(qz_dec_outs, dim=0) turn_states['pv_qz_dec_outs'], turn_states['pv_qz_proba'] = qz_dec_outs, qz_proba if is_train: return pz_proba, qz_proba, pm_dec_proba, pz_mu, pz_log_sigma, qz_mu, qz_log_sigma, turn_states else: qz_index = self.pz_max_sampling(qz_proba) return None, qz_index, turn_states else: if not self.beam_search: m_output_index = self.greedy_decode(pz_dec_outs, pz_proba, u_enc_out, m_tm1, last_hidden, degree_input) else: m_output_index = self.beam_search_decode(pz_dec_outs, pz_proba, u_enc_out, m_tm1, last_hidden, degree_input, self.eos_token_idx) pz_index = self.pz_max_sampling(pz_proba) return m_output_index, pz_index, turn_states def greedy_decode(self, pz_dec_outs, pz_proba, u_enc_out, m_tm1, last_hidden, degree_input): """ greedy decoding of the response :param pz_dec_outs: :param u_enc_out: :param m_tm1: :param last_hidden: :return: nested-list """ decoded = [] for t in range(self.max_ts): proba, last_hidden, _ = self.m_decoder(pz_dec_outs, pz_proba, u_enc_out, m_tm1, degree_input, last_hidden) mt_proba, mt_index = torch.topk(proba, 1) # [B,1] mt_index = mt_index.data.view(-1) decoded.append(mt_index) m_tm1 = cuda_(Variable(mt_index).view(1, -1)) decoded = torch.stack(decoded, dim=0).transpose(0, 1) decoded = list(decoded) return [list(_) for _ in decoded] def pz_max_sampling(self, pz_proba): """ Max-sampling procedure of pz during testing. :param pz_proba: # [Tz, B, Vz] :return: nested-list: B * [T] """ pz_proba = pz_proba.data z_proba, z_token = torch.topk(pz_proba, 1, dim=2) # [Tz, B, 1] z_token = list(z_token.squeeze(2).transpose(0, 1)) return [list(_) for _ in z_token] def pz_selective_sampling(self, pz_proba): """ Selective sampling of pz """ if cfg.spv_proportion == 0: return self.pz_max_sampling(pz_proba) pz_proba = pz_proba.data z_proba, z_token = torch.topk(pz_proba, pz_proba.size(0), dim=2) z_token = z_token.transpose(0, 1) # [B,Tz,top_Tz] all_sampled_z = [] for b in range(z_token.size(0)): sampled_z = [] for t in range(z_token.size(1)): for i in range(z_token.size(2)): if z_token[b][t][i] not in sampled_z: sampled_z.append(z_token[b][t][i]) break all_sampled_z.append(sampled_z) return all_sampled_z def beam_search_decode_single(self, pz_dec_outs, pz_proba, u_enc_out, m_tm1, last_hidden, degree_input, eos_token_id): """ Single beam search decoding. Batch size have to be 1. :param eos_token_id: :param degree_input: :param last_hidden: :param m_tm1: :param pz_dec_outs: [T,1,H] :param pz_proba: [T,1,V] :param u_enc_out: [T,1,H] :return: """ eos_token_id = self.vocab.encode(cfg.eos_m_token) batch_size = pz_dec_outs.size(1) if batch_size != 1: raise ValueError('"Beam search single" requires batch size to be 1') class BeamState: def __init__(self, score, last_hidden, decoded, length): """ Beam state in beam decoding :param score: sum of log-probabilities :param last_hidden: last hidden :param decoded: list of *Variable[1*1]* of all decoded words :param length: current decoded sentence length """ self.score = score self.last_hidden = last_hidden self.decoded = decoded self.length = length def update_clone(self, score_incre, last_hidden, decoded_t): decoded = copy.copy(self.decoded) decoded.append(decoded_t) clone = BeamState(self.score + score_incre, last_hidden, decoded, self.length + 1) return clone def beam_result_valid(decoded_t): pz_max_samples = self.pz_selective_sampling(pz_proba) requested, start = [], False t = 0 while t < len(pz_max_samples[0]) and pz_max_samples[0][t] != self.vocab.encode('EOS_Z1'): t += 1 t += 1 while t < len(pz_max_samples[0]) and pz_max_samples[0][t] != self.vocab.encode('EOS_Z2'): requested.append(self.vocab.decode(pz_max_samples[0][t])) t += 1 decoded_t = [_.view(-1).data[0] for _ in decoded_t] decoded_sentence = self.vocab.sentence_decode(decoded_t, cfg.eos_m_token) requested = set(requested).intersection(['address', 'food', 'pricerange', 'phone', 'postcode']) # return True for rq in requested: if '%s SLOT' % rq not in decoded_sentence: #print('Fail %s' % decoded_sentence) return False #print('Success %s' % decoded_sentence) return True def score_bonus(state, decoded): """ bonus scheme: bonus per token, or per new decoded slot. :param state: :return: """ bonus = cfg.beam_len_bonus decoded = self.vocab.decode(decoded) decoded_t = [_.view(-1).data[0] for _ in state.decoded] decoded_sentence = self.vocab.sentence_decode(decoded_t, cfg.eos_m_token) decoded_sentence = decoded_sentence.split() if len(decoded_sentence) >= 1 and decoded_sentence[-1] == decoded: # repeated words bonus -= 10000 if decoded == '**unknown**': bonus -= 3.0 return bonus def soft_score_incre(score, turn): return score finished, failed = [], [] states = [] # sorted by score decreasingly dead_k = 0 states.append(BeamState(0, last_hidden, [m_tm1], 0)) for t in range(self.max_ts): new_states = [] k = 0 while k < len(states) and k < self.beam_size - dead_k: state = states[k] last_hidden, m_tm1 = state.last_hidden, state.decoded[-1] proba, last_hidden = self.m_decoder(pz_dec_outs, pz_proba, u_enc_out, m_tm1, degree_input, last_hidden) proba = torch.log(proba) mt_proba, mt_index = torch.topk(proba, self.beam_size - dead_k) # [1,K] for new_k in range(self.beam_size - dead_k): score_incre = soft_score_incre(mt_proba[0][new_k].data[0], t) + score_bonus(state, mt_index[0][new_k].data[0]) if len(new_states) >= self.beam_size - dead_k and state.score + score_incre < new_states[-1].score: break decoded_t = mt_index[0][new_k] if self.vocab.decode(decoded_t.data[0]) == cfg.eos_m_token: if beam_result_valid(state.decoded): finished.append(state) dead_k += 1 else: failed.append(state) else: decoded_t = decoded_t.view(1, -1) new_state = state.update_clone(score_incre, last_hidden, decoded_t) new_states.append(new_state) #beam_result_valid(new_state.decoded) #print(self.vocab.decode(decoded_t.view(-1).data[0]), t, new_k) k += 1 if self.beam_size - dead_k < 0: break new_states = new_states[:self.beam_size - dead_k] new_states.sort(key=lambda x: -x.score) states = new_states if t == self.max_ts - 1 and not finished: finished = failed if not finished: finished.append(states[0]) finished.sort(key=lambda x: -x.score) decoded_t = finished[0].decoded decoded_t = [_.view(-1).data[0] for _ in decoded_t] decoded_sentence = self.vocab.sentence_decode(decoded_t, cfg.eos_m_token) print(decoded_sentence) generated = torch.cat(finished[0].decoded, dim=1).data # [B=1, T] return generated def beam_search_decode(self, pz_dec_outs, pz_proba, u_enc_out, m_tm1, last_hidden, degree_input, eos_token_id): vars = torch.split(pz_dec_outs, 1, dim=1), torch.split(pz_proba, 1, dim=1), torch.split(u_enc_out, 1, dim=1), torch.split( m_tm1, 1, dim=1), torch.split(last_hidden, 1, dim=1), torch.split(degree_input, 1, dim=0) decoded = [] for pz_dec_out_s, pz_proba_s, u_enc_out_s, m_tm1_s, last_hidden_s, degree_input_s in zip(*vars): decoded_s = self.beam_search_decode_single(pz_dec_out_s, pz_proba_s, u_enc_out_s, m_tm1_s, last_hidden_s, degree_input_s, eos_token_id) decoded.append(decoded_s) return [list(_.view(-1)) for _ in decoded] def supervised_loss(self, pz_proba, qz_proba, pm_dec_proba, z_input, m_input): pr_loss = self.pr_loss(pz_proba.view(-1, pz_proba.size(2)), z_input.view(-1)) m_loss = self.dec_loss(pm_dec_proba.view(-1, pm_dec_proba.size(2)), m_input.view(-1)) q_loss = self.q_loss(qz_proba.view(-1, pz_proba.size(2)), z_input.view(-1)) pr_loss, m_loss, q_loss = pr_loss, m_loss, q_loss if cfg.pretrain: loss = q_loss else: loss = pr_loss + m_loss + q_loss return loss, pr_loss, m_loss, q_loss def unsupervised_loss(self, mu_q, log_sigma_q, mu_p, log_sigma_p, pm_dec_proba, m_input, pz_proba, qz_proba): m_loss = self.dec_loss(pm_dec_proba.view(-1, pm_dec_proba.size(2)), m_input.view(-1)) kl_div_loss = self.kl_loss(pz_proba, qz_proba.detach()) m_loss, kl_div_loss = m_loss, kl_div_loss * self.alpha loss = m_loss + kl_div_loss return loss, m_loss, kl_div_loss def self_adjust(self, epoch): pass