"""Visualize image and text saliency for a trained VSE (visual-semantic embedding) model.

The script back-propagates the image-caption similarity score to the input image pixels
and to the caption word embeddings, saves the resulting image saliency maps under /tmp,
and prints per-word text saliency scores.
"""

import six
import collections
from os.path import join as pjoin

import warnings
warnings.filterwarnings("ignore")  # skimage warnings

from PIL import Image
from skimage.transform import resize as imresize

import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
import numpy as np

import jacinle.io as io
from jacinle.cli.argument import JacArgumentParser
from jacinle.utils.enum import JacEnum
from jactorch.utils.meta import as_variable, as_cuda
from jactorch.graph.variable import var_with

from vocab import Vocabulary  # needed to unpickle the vocabulary file

Record = collections.namedtuple('Record', [
    'raw_image', 'raw_caption', 'raw_caption_ext',
    'image', 'image_embedding', 'image_embedding_precomp',
    'captions', 'caption_embedding', 'caption_ext_embedding'
])

parser = JacArgumentParser()
parser.add_argument('--load', required=True)
parser.add_argument('--encoder', required=True)
parser.add_argument('--load-encoder', required=True, type='checked_file')
parser.add_argument('--images', required=True, type='checked_dir')
parser.add_argument('--image-list', required=True, type='checked_file')
parser.add_argument('--image-embeddings', required=True, type='checked_file')
parser.add_argument('--captions', required=True, type='checked_file')
parser.add_argument('--captions-ext', required=True, type='checked_dir')
args = parser.parse_args()
args.grad_power = 0.5  # exponent used to sharpen the saliency mask


def main():
    encoder = ImageEncoder(args.encoder, args.load_encoder)
    dataset = Dataset(args)
    extractor = FeatureExtractor(args.load, encoder, dataset)

    def e(ind):
        # Compute and save saliency for image `ind` and its ground-truth caption.
        a = extractor(ind)
        pic = plot_saliency(a.raw_image, a.image, a.image_embedding, a.caption_embedding)
        pic.save('/tmp/origin{}.png'.format(ind))
        print('Image saliency saved:', '/tmp/origin{}.png'.format(ind))
        print_txt_saliency(a.captions, 0, a.raw_caption, a.image_embedding, a.caption_embedding)
        return a

    # Drop into an interactive shell; use e(ind) and plot(record, ind) from there.
    from IPython import embed; embed()


def normalize_grad(grad, stat=False):
    # Rescale absolute gradient values to the range [0, 1].
    grad = np.abs(grad)
    if stat:
        print('Grad min={}, max={}'.format(grad.min(), grad.max()))
    grad -= grad.min()
    grad /= grad.max()
    return grad.astype('float32')


def print_txt_saliency(txt, ind, content, img_embedding_var, caption_embedding_var, backward=False):
    if backward:
        # Back-propagate the image-caption similarity to populate txt.grad.
        dis = (caption_embedding_var.squeeze() * img_embedding_var.squeeze()).sum()
        dis.backward(retain_graph=True)

    content = content.split()
    # The projected caption is <start> + words + <end>, hence the +2.
    assert txt.grad.size(1) == 2 + len(content), (txt.grad.size(1), len(content))
    # Per-word saliency: mean absolute gradient over the embedding dimension, skipping <start>/<end>.
    grad_txt = txt.grad[ind, 1:1 + len(content)].data.cpu().squeeze().abs().numpy()
    grad_txt = grad_txt.mean(-1)
    grad_txt /= grad_txt.sum()
    print('Text saliency:', ' '.join(['{}({:.3f})'.format(c, float(g)) for c, g in zip(content, grad_txt)]))


def plot_saliency(raw_img, image_var, img_embedding_var, caption_var):
    # Similarity between the caption embedding and the image embedding; its gradient
    # w.r.t. the input pixels gives the image saliency map.
    dis = (caption_var.squeeze() * img_embedding_var.squeeze()).sum()
    dis.backward(retain_graph=True)

    grad = image_var.grad.data.cpu().squeeze().numpy().transpose((1, 2, 0))
    grad = normalize_grad(grad, stat=True)
    # Upsample the gradient map to the original image resolution and turn it into a grayscale mask.
    grad = imresize((grad * 255).astype('uint8'), (raw_img.height, raw_img.width)) / 255
    grad = normalize_grad(grad.mean(axis=-1, keepdims=True).repeat(3, axis=-1))
    grad = np.float_power(grad, args.grad_power)

    np_img = np.array(raw_img)
    masked_img = np_img * grad
    # Side by side: original image | saliency-masked image | saliency mask.
    final = np.hstack([np_img, masked_img.astype('uint8'), (grad * 255).astype('uint8')])
    return Image.fromarray(final.astype('uint8'))


def plot(record, ind, pic_ind=None):
    # Saliency for the ind-th external caption of a previously extracted Record.
    p = plot_saliency(record.raw_image, record.image, record.image_embedding, record.caption_ext_embedding[ind])
    if pic_ind is not None:
        name = '/tmp/{}_{}.png'.format(pic_ind, ind)
    else:
        name = '/tmp/{}.png'.format(ind)
    p.save(name)
    print('Image saliency saved:', name)
    # ind + 1 because index 0 corresponds to the ground-truth caption; external captions follow.
    print_txt_saliency(record.captions, ind + 1, record.raw_caption_ext[ind],
                       record.image_embedding, record.caption_ext_embedding[ind])


class ImageEncoderType(JacEnum):
    RESNET152 = 'resnet152'


class Identity(nn.Module):
    def forward(self, x):
        return x


def ImageEncoder(encoder, load):
    # Build the CNN image encoder and strip its classification head so it outputs pooled features.
    encoder = ImageEncoderType.from_string(encoder)
    if encoder is ImageEncoderType.RESNET152:
        encoder = models.resnet152()
        encoder.load_state_dict(torch.load(load))
        encoder.fc = Identity()
    encoder.cuda()
    encoder.eval()
    return encoder


class Dataset(object):
    def __init__(self, args):
        self.images = args.images
        self.image_list = args.image_list
        self.image_embeddings = args.image_embeddings
        self.captions = args.captions
        self.captions_ext = args.captions_ext
        self.load_files()
        self.build_image_transforms()

    def get_image(self, ind):
        img_file = self.image_list[ind].strip()
        # COCO images live in the train2014/ and val2014/ subdirectories.
        if 'train2014' in img_file:
            img_file = pjoin('train2014', img_file)
        elif 'val2014' in img_file:
            img_file = pjoin('val2014', img_file)
        img = Image.open(pjoin(self.images, img_file)).convert('RGB')
        return img

    def get_image_embedding(self, ind):
        return self.image_embeddings[ind]

    def get_caption(self, ind):
        return self.captions[ind].strip()

    def get_caption_ext(self, ind):
        with open('{}/{}.txt'.format(self.captions_ext, ind)) as f:
            return [line.strip() for line in f]

    def load_files(self):
        print('Loading captions and precomputed image embeddings')
        self.image_embeddings = io.load(self.image_embeddings)
        self.captions = list(open(self.captions))
        image_list = list(open(self.image_list))
        assert len(image_list) * 5 == len(self.captions)
        # Each image has 5 captions; duplicate every image path so the two lists align index-wise.
        image_list_dup = list()
        for img in image_list:
            for i in range(5):
                image_list_dup.append(img)
        self.image_list = image_list_dup

    def build_image_transforms(self):
        # Standard ImageNet preprocessing.
        self.image_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

    def __getitem__(self, ind):
        img = self.get_image(ind)
        img_embedding = self.get_image_embedding(ind)
        cap = self.get_caption(ind)
        cap_ext = self.get_caption_ext(ind)
        return img, self.image_transform(img), img_embedding, cap, cap_ext


class FeatureExtractor(object):
    def __init__(self, checkpoint, image_encoder, dataset):
        self.load_checkpoint(checkpoint)
        self.image_encoder = image_encoder
        self.dataset = dataset

    def load_checkpoint(self, checkpoint):
        # Rebuild the VSE model from the options stored in the checkpoint and freeze both encoders.
        checkpoint = torch.load(checkpoint)
        opt = checkpoint['opt']
        opt.use_external_captions = False
        vocab = Vocab.from_pickle(pjoin(opt.vocab_path, '%s_vocab.pkl' % opt.data_name))
        opt.vocab_size = len(vocab)

        from model import VSE
        self.model = VSE(opt)
        self.model.load_state_dict(checkpoint['model'])
        self.projector = vocab

        self.model.img_enc.eval()
        self.model.txt_enc.eval()
        for p in self.model.img_enc.parameters():
            p.requires_grad = False
        for p in self.model.txt_enc.parameters():
            p.requires_grad = False

    def __call__(self, ind):
        raw_img, img, img_embedding, cap, cap_ext = self.dataset[ind]
        # Embedding computed from the precomputed CNN features loaded from disk.
        img_embedding_precomp = self.model.img_enc(as_cuda(as_variable(img_embedding).unsqueeze(0)))

        # Re-encode the raw image so gradients can flow back to the pixels.
        img = as_variable(img)
        img.requires_grad = True
        img_embedding = self.image_encoder(as_cuda(img.unsqueeze(0)))
        img_embedding = self.model.img_enc(img_embedding)

        # Encode the ground-truth caption followed by the external captions.
        txt = [cap]
        txt.extend(cap_ext)
        txt_embeddings, txt_var = self.enc_txt(txt)

        return Record(
            raw_img, cap, cap_ext,
            img, img_embedding, img_embedding_precomp,
            txt_var, txt_embeddings[0], txt_embeddings[1:]
        )

    def enc_txt(self, caps):
        # Project caption strings to word indices, encode them, and restore the original order.
        sents, lengths, _, inv = _prepare_batch(caps, self.projector)
        inv = var_with(as_variable(inv), sents)
        out, x = self.model.txt_enc.forward(sents, lengths, True)
        # Undo the length-based sorting on the caption embeddings; x holds the word embeddings.
        return out[inv], x


class Vocab(object):
    def __init__(self, idx2word=None, options=None, sync=None):
        assert options is None
        if sync is not None:
            # Copy the mappings from an existing (unpickled) vocabulary object.
            self.idx2word = sync.idx2word
            self.word2idx = sync.word2idx
        else:
            self.idx2word = idx2word
            self.word2idx = dict([(w, i) for i, w in enumerate(self.idx2word)])
        self.sent_trunc_length = None

    @classmethod
    def from_pickle(cls, path):
        vocab = io.load(path)
        return cls(sync=vocab)

    def project(self, sentence):
        # Tokenize, wrap with <start>/<end>, optionally truncate, and map words to indices.
        sentence = sentence.strip().lower().split()
        sentence = ['<start>'] + sentence + ['<end>']
        if self.sent_trunc_length is not None:
            if len(sentence) > self.sent_trunc_length:
                sentence = sentence[:self.sent_trunc_length]
        # Out-of-vocabulary words fall back to index 3 (assumed to be <unk>).
        return list(map(lambda word: self.word2idx.get(word, 3), sentence))

    def __len__(self):
        return len(self.idx2word)

    def __call__(self, sent):
        return self.project(sent)


def _prepare_batch(sents, projector):
    if isinstance(sents, six.string_types):
        sents = [sents]
    sents = [np.array(projector(s)) for s in sents]
    lengths = [len(s) for s in sents]
    sents = _pad_sequences(sents, 0, max(lengths))
    # Sort captions by length (descending), as the RNN text encoder expects, and keep the
    # inverse permutation so the original order can be restored afterwards.
    idx = np.array(sorted(range(len(lengths)), key=lambda x: lengths[x], reverse=True))
    inv = np.array(sorted(range(len(lengths)), key=lambda x: idx[x]))
    sents = sents[idx]
    lengths = np.array(lengths)[idx].tolist()
    sents = as_variable(sents)
    if torch.cuda.is_available():
        sents = sents.cuda()
    return sents, lengths, idx, inv


def _pad_sequences(sequences, dim, length):
    # Zero-pad every sequence along `dim` up to `length` and stack them into one array.
    seq_shape = list(sequences[0].shape)
    seq_shape[dim] = length
    output = np.zeros((len(sequences), ) + tuple(seq_shape), dtype=sequences[0].dtype)
    output = output.reshape((len(sequences), -1) + tuple(seq_shape[dim:]))
    for i, seq in enumerate(sequences):
        output[i, :, :seq.shape[dim], ...] = seq
    return output.reshape((len(sequences), ) + tuple(seq_shape))


if __name__ == '__main__':
    main()
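
# Usage sketch (the script name and all paths below are illustrative placeholders, not part of the repo):
#
#   python saliency.py \
#       --load vse_checkpoint.pth.tar --encoder resnet152 --load-encoder resnet152.pth \
#       --images /data/coco/images --image-list image_list.txt \
#       --image-embeddings image_feats.pkl \
#       --captions captions.txt --captions-ext captions_ext/
#
# The script drops into an IPython shell where, for example:
#   record = e(0)      # saliency for image 0 and its ground-truth caption
#   plot(record, 2)    # saliency for the third external caption of the same image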