import copy
import math
import time
import pdb

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchtext import data, datasets


def subsequent_mask(size):
    """Return a (1, size, size) boolean mask that hides future positions.

    Entry [0, i, j] is True when j <= i, i.e. position i may attend to
    position j but not to any later one.
    """
    attn_shape = (1, size, size)
    # k=1 keeps only the strictly-upper triangle (the "future") as ones.
    future = np.triu(np.ones(attn_shape), k=1).astype('uint8')
    return torch.from_numpy(future) == 0


def seq1toseq2_mask(seq1, seq2, pad):
    """Build a pairwise non-padding mask between two token sequences.

    Returns a boolean tensor of shape
    (seq1.shape[0], seq1.shape[1], seq2.shape[-1]) whose [b, i, j] entry is
    True when both seq1[b, i] and seq2[b, j] differ from `pad`.
    Assumes seq1 and seq2 are (batch, length) tensors sharing the batch
    size — TODO confirm against callers.
    """
    rows = (seq1 != pad).unsqueeze(-1).expand(
        (seq1.shape[0], seq1.shape[1], seq2.shape[-1]))
    cols = (seq2 != pad).unsqueeze(-2).expand(
        (seq2.shape[0], seq1.shape[1], seq2.shape[-1]))
    return rows & cols


class Batch:
    """Object for holding a batch of data with masks during training.

    Attributes set here: query/his/his_st (as given), fts/fts_mask (None when
    no features are given), query_mask/his_mask (non-pad masks with an axis
    inserted at -2), cap/cap_mask (None when no caption is given) and, only
    when `trg` is given, trg/trg_y/trg_mask/ntokens.
    """

    def __init__(self, query, his, his_st, fts=None, cap=None, trg=None,
                 trg_y=None, pad=0):
        self.query = query
        self.his = his
        self.his_st = his_st
        if fts is not None:
            # Each numpy feature array is moved to the GPU as float and its
            # first two axes are swapped (permute(1, 0, 2)).
            permuted_fts = [torch.from_numpy(ft).float().cuda().permute(1, 0, 2)
                            for ft in fts]
            # A position is "real" when any value along the last axis differs
            # from 1; presumably 1 is the feature padding value — TODO confirm.
            self.fts_mask = [(torch.sum(permuted_ft != 1, dim=2) != 0).unsqueeze(-2)
                             for permuted_ft in permuted_fts]
            # Zero out padded positions. squeeze(1) removes only the axis that
            # unsqueeze(-2) just added, so a size-1 batch or time axis is kept
            # and expand_as still lines up with `ft`.
            self.fts = [ft * self.fts_mask[i].squeeze(1).unsqueeze(-1).expand_as(ft).float()
                        for i, ft in enumerate(permuted_fts)]
        else:
            self.fts = None
            self.fts_mask = None
        self.query_mask = (query != pad).unsqueeze(-2)
        self.his_mask = (his != pad).unsqueeze(-2)
        if cap is not None:
            self.cap = cap
            self.cap_mask = (cap != pad).unsqueeze(-2)
        else:
            self.cap = None
            self.cap_mask = None
        if trg is not None:
            self.trg = trg
            self.trg_y = trg_y
            self.trg_mask = self.make_std_mask(self.trg, pad)
            # Number of non-pad target tokens, used to normalise the loss.
            self.ntokens = (self.trg_y != pad).data.sum()

    @staticmethod
    def make_std_mask(tgt, pad):
        """Create a mask that hides both padding and future target tokens."""
        tgt_mask = (tgt != pad).unsqueeze(-2)
        tgt_mask = tgt_mask & Variable(
            subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data))
        return tgt_mask


# Running maxima shared by successive batch_size_fn calls for one batch.
# Initialised here so the first call does not depend on count == 1.
max_src_in_batch = 0
max_tgt_in_batch = 0


def batch_size_fn(new, count, sofar):
    """Return the padded token count of a batch after adding `new`.

    The counters reset when `count == 1` (first example of a new batch).
    The target length gets +2, presumably for start/end tokens — TODO confirm.
    `sofar` is unused but required by the torchtext batch_size_fn protocol.
    """
    global max_src_in_batch, max_tgt_in_batch
    if count == 1:
        max_src_in_batch = 0
        max_tgt_in_batch = 0
    max_src_in_batch = max(max_src_in_batch, len(new.src))
    max_tgt_in_batch = max(max_tgt_in_batch, len(new.trg) + 2)
    src_elements = count * max_src_in_batch
    tgt_elements = count * max_tgt_in_batch
    return max(src_elements, tgt_elements)


class MyIterator(data.Iterator):
    """torchtext Iterator that groups similarly-sized examples into batches."""

    def create_batches(self):
        if self.train:
            def pool(d, random_shuffler):
                # Take chunks of 100 batches' worth of examples, sort each
                # chunk by sort_key, split it into batches, then shuffle the
                # batches so the sort order does not leak into training.
                for p in data.batch(d, self.batch_size * 100):
                    p_batch = data.batch(
                        sorted(p, key=self.sort_key),
                        self.batch_size, self.batch_size_fn)
                    for b in random_shuffler(list(p_batch)):
                        yield b
            self.batches = pool(self.data(), self.random_shuffler)
        else:
            self.batches = []
            for b in data.batch(self.data(), self.batch_size,
                                self.batch_size_fn):
                self.batches.append(sorted(b, key=self.sort_key))


def rebatch(pad_idx, batch):
    """Transpose a torchtext batch to batch-first order and wrap it in Batch."""
    # NOTE(review): Batch's signature is (query, his, his_st, ..., pad=0), so
    # this call binds trg -> his and pad_idx -> his_st and leaves pad at 0.
    # It looks like a leftover from an older Batch(src, trg, pad) API —
    # confirm the intended mapping before relying on this helper.
    src, trg = batch.src.transpose(0, 1), batch.trg.transpose(0, 1)
    return Batch(src, trg, pad_idx)


class NoamOpt:
    """Optimizer wrapper that sets the learning rate on every step.

    rate(step) = factor * model_size**-0.5 * min(step**-0.5, step * warmup**-1.5)
    """

    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0

    def step(self):
        """Advance the step counter, update every param group's lr, then step."""
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step=None):
        """Return the scheduled learning rate for `step` (default: current step)."""
        if step is None:
            step = self._step
        return self.factor * \
            (self.model_size ** (-0.5) *
             min(step ** (-0.5), step * self.warmup ** (-1.5)))


def get_std_opt(model):
    """Build a NoamOpt (factor 2, warmup 4000) around Adam with lr=0."""
    # Assumes model.src_embed[0] exposes d_model — TODO confirm for this model.
    return NoamOpt(model.src_embed[0].d_model, 2, 4000,
                   torch.optim.Adam(model.parameters(), lr=0,
                                    betas=(0.9, 0.98), eps=1e-9))


class SimpleLossCompute:
    """A simple loss compute and train function.

    Adds the main generator loss to optional auto-encoder losses weighted by
    `l`, backpropagates, and steps/zeroes the optimizer when one is given.
    """

    def __init__(self, generator, ae_generator, criterion, opt=None, l=1.0):
        self.generator = generator
        self.ae_generator = ae_generator
        self.criterion = criterion
        self.opt = opt
        self.l = l

    def _flat_criterion(self, logits, targets):
        """Apply the criterion to logits/targets flattened over all positions."""
        return self.criterion(logits.contiguous().view(-1, logits.size(-1)),
                              targets.contiguous().view(-1))

    def __call__(self, x, y, norm, ae_x=None, ae_y=None, ae_norm=None):
        """Compute the loss, run backward/optimizer step, and return a scaled value.

        Returns loss.item() * norm.float(), i.e. the loss scaled back up
        by `norm`.
        """
        out = self.generator(x)
        loss = self._flat_criterion(out, y) / norm.float()
        if ae_x is not None:
            if isinstance(ae_x, list):
                # One auto-encoder output per entry; each uses its own
                # generator when ae_generator is given (indexed by position).
                for i, ae_in in enumerate(ae_x):
                    if self.ae_generator is not None:
                        ae_out = self.ae_generator[i](ae_in)
                    else:
                        ae_out = self.generator(ae_in)
                    loss = loss + self.l * self._flat_criterion(ae_out, ae_y) / ae_norm.float()
            else:
                if self.ae_generator is not None:
                    ae_out = self.ae_generator(ae_x)
                else:
                    ae_out = self.generator(ae_x)
                loss = loss + self.l * self._flat_criterion(ae_out, ae_y) / ae_norm.float()
        loss.backward()
        if self.opt is not None:
            self.opt.step()
            self.opt.optimizer.zero_grad()
        return loss.item() * norm.float()


def encode(model, his, his_st, his_mask, cap, cap_mask, query, query_mask,
           video_features, video_features_mask):
    """Run model.encode and reorder its outputs.

    Returns (his_memory, cap_memory, query_memory, encoded_vid_features,
    ae_encoded_ft). `his_st` is accepted for interface symmetry but unused.
    """
    query_memory, encoded_vid_features, cap_memory, his_memory, ae_encoded_ft = model.encode(
        query, query_mask, his, his_mask, cap, cap_mask,
        video_features, video_features_mask)
    return his_memory, cap_memory, query_memory, encoded_vid_features, ae_encoded_ft


def _append_token(seq, token, like):
    """Return `seq` with a (1, 1) tensor holding `token` concatenated on dim 1."""
    return torch.cat([seq, torch.ones(1, 1).type_as(like).fill_(token)], dim=1)


def greedy_decode(model, batch, max_len, start_symbol, pad_symbol):
    """Greedily decode up to `max_len` tokens (including the start symbol).

    At each step the highest-scoring token of the last position is appended.
    When the decoder returns a list, the per-output generator scores are
    summed. Reads only element [0] of the argmax, so a batch size of 1 is
    presumably assumed — TODO confirm. `pad_symbol` is unused.
    Returns the (1, length) token tensor, starting with `start_symbol`.
    """
    video_features, video_features_mask = batch.fts, batch.fts_mask
    cap, cap_mask = batch.cap, batch.cap_mask
    his, his_st, his_mask = batch.his, batch.his_st, batch.his_mask
    query, query_mask = batch.query, batch.query_mask
    his_memory, cap_memory, query_memory, encoded_vid_features, ae_encoded_ft = encode(
        model, his, his_st, his_mask, cap, cap_mask, query, query_mask,
        video_features, video_features_mask)
    ys = torch.ones(1, 1).fill_(start_symbol).type_as(query.data)
    for _ in range(max_len - 1):
        cap2res_mask = None
        out = model.decode(encoded_vid_features, his_memory, cap_memory, query_memory,
                           video_features_mask, his_mask, cap_mask, query_mask,
                           Variable(ys),
                           Variable(subsequent_mask(ys.size(1)).type_as(query.data)),
                           cap2res_mask, ae_encoded_ft)
        if isinstance(out, list):
            prob = 0
            for idx, o in enumerate(out):
                prob += model.generator[idx](o[:, -1])
        else:
            prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.data[0]
        ys = _append_token(ys, next_word, query.data)
    return ys


def beam_search_decode(model, batch, max_len, start_symbol, unk_symbol, end_symbol,
                       pad_symbol, beam=5, penalty=1.0, nbest=5, min_len=1):
    """Beam-search decode one example.

    Each hypothesis is (tokens, score, token_tensor). Scores accumulate the
    generator outputs of the chosen tokens; presumably these are
    log-probabilities — TODO confirm. From step `min_len` on, every
    hypothesis is also recorded as finished with score
    `score + out[end_symbol] + penalty * (len(tokens) + 1)`. `unk_symbol` and
    `end_symbol` are never expanded. `pad_symbol` is unused.

    Returns (up to `nbest` finished (tokens, score) pairs sorted by
    descending score, best finished score), or ([([], 0)], None) when no
    hypothesis finished.
    """
    video_features, video_features_mask = batch.fts, batch.fts_mask
    cap, cap_mask = batch.cap, batch.cap_mask
    his, his_st, his_mask = batch.his, batch.his_st, batch.his_mask
    query, query_mask = batch.query, batch.query_mask
    his_memory, cap_memory, query_memory, encoded_vid_features, ae_encoded_ft = encode(
        model, his, his_st, his_mask, cap, cap_mask, query, query_mask,
        video_features, video_features_mask)
    ds = torch.ones(1, 1).fill_(start_symbol).type_as(query.data)
    hyplist = [([], 0., ds)]
    best_state = None
    comp_hyplist = []
    for step in range(max_len):
        new_hyplist = []
        # Index of the lowest-scoring entry once new_hyplist holds `beam` items.
        argmin = 0
        for out, lp, st in hyplist:
            cap2res_mask = None
            # Pass cap2res_mask and ae_encoded_ft in the same positions as
            # greedy_decode; previously cap2res_mask was dropped, which shifted
            # ae_encoded_ft into the cap2res_mask slot.
            output = model.decode(encoded_vid_features, his_memory, cap_memory, query_memory,
                                  video_features_mask, his_mask, cap_mask, query_mask,
                                  Variable(st),
                                  Variable(subsequent_mask(st.size(1)).type_as(query.data)),
                                  cap2res_mask, ae_encoded_ft)
            if isinstance(output, (tuple, list)):
                logp = model.generator(output[0][:, -1])
            else:
                logp = model.generator(output[:, -1])
            lp_vec = np.squeeze(logp.cpu().data.numpy() + lp)
            if step >= min_len:
                new_lp = lp_vec[end_symbol] + penalty * (len(out) + 1)
                comp_hyplist.append((out, new_lp))
                if best_state is None or best_state < new_lp:
                    best_state = new_lp
            # Visit candidate tokens from best to worst score.
            for o in np.argsort(lp_vec)[::-1]:
                if o == unk_symbol or o == end_symbol:
                    continue
                new_lp = lp_vec[o]
                if len(new_hyplist) == beam:
                    if new_hyplist[argmin][1] < new_lp:
                        new_st = _append_token(st, int(o), query.data)
                        new_hyplist[argmin] = (out + [o], new_lp, new_st)
                        argmin = min(enumerate(new_hyplist), key=lambda h: h[1][1])[0]
                    else:
                        # Remaining candidates score no higher; stop early.
                        break
                else:
                    new_st = _append_token(st, int(o), query.data)
                    new_hyplist.append((out + [o], new_lp, new_st))
                    if len(new_hyplist) == beam:
                        argmin = min(enumerate(new_hyplist), key=lambda h: h[1][1])[0]
        hyplist = new_hyplist
    if len(comp_hyplist) > 0:
        maxhyps = sorted(comp_hyplist, key=lambda h: -h[1])[:nbest]
        return maxhyps, best_state
    else:
        return [([], 0)], None