import torch
import models.modules.network as network
from util.util import tensor2im
from models.base_model import BaseModel

# Fixed target captions used to demo caption-conditioned re-synthesis of a
# single input image in `translation`.
TEST_SEQ = ['this small bird has a blue crown and white belly',
            'this small yellow bird has gray wings and a black bill',
            'a small brown bird with a brown crown has a white belly',
            'this black bird has no other colors with a short bill',
            'an orange bird with green wings and blue head',
            'a black bird with a red head',
            'this particular bird with a red head and breast and features grey wings']


################## SemanticImageSynthesis #############################
class SemanticImageSynthesisModel(BaseModel):
    """Semantic image synthesis conditioned on text captions.

    Encodes captions with a pretrained text encoder (`enc_attribute`) and,
    at inference time (`translation`), re-renders one input image under the
    captions in ``TEST_SEQ``.

    NOTE(review): `translation` also uses ``self.enc_style``,
    ``self.enc_content`` and ``self.dec`` — presumably created elsewhere
    (e.g. a subclass or a setup method not visible in this file).
    """

    def __init__(self, opt):
        BaseModel.__init__(self, opt)
        # Text (attribute) encoder for CUB captions; frozen in eval mode —
        # this class only ever reads sentence embeddings from it.
        self.enc_attribute = network.get_attribute_encoder('cub_text', opt)
        self.enc_attribute.to(self.device)
        self.enc_attribute.eval()

    def prepare_data(self, data):
        """Sort a batch by descending caption length and encode captions.

        Args:
            data: ``[img, captions, captions_lens]`` — images, tokenized
                captions (with a trailing singleton dim squeezed out here),
                and per-caption lengths.

        Returns:
            ``[img, sent_emb, index_target, weight_source]`` — all in
            **sorted** (descending caption length) order, NOT the caller's
            original order. Also cached on ``self.current_data``.
        """
        img, captions, captions_lens = data
        batch_size = img.size(0)
        # Descending-length order is presumably required by the recurrent
        # text encoder (cf. init_hidden below) — TODO confirm.
        captions_lens, sorted_cap_indices = torch.sort(captions_lens, 0, True)
        # Keep images aligned with their (now reordered) captions.
        img = img[sorted_cap_indices].to(self.device)
        captions = captions[sorted_cap_indices].squeeze(dim=2).to(self.device)
        captions_lens = captions_lens.to(self.device)
        hidden = self.enc_attribute.init_hidden(batch_size)
        _, sent_emb = self.enc_attribute(captions, captions_lens, hidden)
        # Circular shift-by-one pairing: sample i is paired with sample i-1
        # (index -1 wraps to the last sample) as its mismatched target.
        index_target = torch.tensor(range(-1, batch_size - 1)).to(self.device)
        weight_source = torch.ones([batch_size, 1]).to(self.device)
        self.current_data = [img, sent_emb, index_target, weight_source]
        return self.current_data

    def translation(self, data):
        """Render one input image under its own caption and each TEST_SEQ caption.

        Args:
            data: ``(img, cap_ori, cap_len_ori)`` for a single image
                (``img.size(0) == 1``).

        Returns:
            List of ``(label, HWC uint8 image)`` pairs: ``'input'``, ``'rec'``
            (re-synthesis with the original caption), then ``'seq_1'`` ...
            ``'seq_N'`` in ``TEST_SEQ`` order.
        """
        with torch.no_grad():
            img, cap_ori, cap_len_ori = data
            assert img.size(0) == 1
            # One copy of the image per caption (original + TEST_SEQ).
            img = img.repeat(len(TEST_SEQ) + 1, 1, 1, 1)
            cap_tar, cap_len_tar = [cap_ori], [cap_len_ori]
            for seq in TEST_SEQ:
                cap, cap_len = self.opt.txt_dataset.cap2ix(seq)
                cap_tar.append(torch.LongTensor(cap).unsqueeze(0))
                cap_len_tar.append(torch.LongTensor([cap_len]))
            cap_tar = torch.cat(cap_tar, dim=0)
            cap_len_tar = torch.cat(cap_len_tar, dim=0)
            # BUG FIX: prepare_data sorts the batch by descending caption
            # length, so row i of its outputs (and hence of `fakes`) no longer
            # corresponds to caption i. Recompute the identical, deterministic
            # sort here and build the inverse permutation so results can be
            # labeled in the original [cap_ori] + TEST_SEQ order. The repeated
            # image rows are all identical, so only `fakes` needs unsorting.
            _, sort_idx = torch.sort(cap_len_tar, 0, True)
            inv_idx = torch.empty_like(sort_idx)
            inv_idx[sort_idx] = torch.arange(sort_idx.size(0))
            img, sent_emb, _, _ = self.prepare_data([img, cap_tar, cap_len_tar])
            style_enc, _, _ = self.enc_style(img)
            content = self.enc_content(img)
            fakes = self.dec(content, torch.cat([sent_emb, style_enc], dim=1))
            # Restore original caption order before labeling.
            fakes = fakes[inv_idx.to(fakes.device)]
            results = [('input', tensor2im(img[0].data)),
                       ('rec', tensor2im(fakes[0].data))]
            for i in range(len(TEST_SEQ)):
                results.append(('seq_{}'.format(i + 1),
                                tensor2im(fakes[i + 1].data)))
            return results