"""
Dataset and dataloader for imsitu experiments.

This allows us to:
1) Finetune on Imsitu
2) Finetune on a zero shot setting
"""
import spacy
import torch
import os
from config import IMSITU_TRAIN_LIST, IMSITU_VAL_LIST, IMSITU_TEST_LIST, IMSITU_IMGS
from torchvision.transforms import Scale, RandomCrop, CenterCrop, ToTensor, Normalize, Compose, \
    RandomHorizontalFlip
from PIL import Image
from data.attribute_loader import Attributes
from collections import namedtuple
from torch.autograd import Variable

# Maps a split name to the path of the list file describing that split.
LISTS = {
    'train': IMSITU_TRAIN_LIST,
    'val': IMSITU_VAL_LIST,
    'test': IMSITU_TEST_LIST,
}


def _load_imsitu_file(mode):
    """
    Helper fn that loads the imsitu list file for one split.

    Each row is read as ``<filename> <index>`` separated by a single space.
    The part of the filename before the first underscore is used as the
    sub-directory of IMSITU_IMGS that holds the image.

    :param mode: split name; must be a key of LISTS ('train', 'val' or 'test')
    :return: list of (image path, integer index) tuples, in file order
    :raises ValueError: if ``mode`` is not a key of LISTS
    """
    if mode not in LISTS:
        raise ValueError("Invalid mode {}, must be train val or test".format(mode))

    dps = []
    with open(LISTS[mode], 'r') as f:
        for row in f.read().splitlines():
            fields = row.split(' ')
            fn_ext = fields[0]
            ind = int(fields[1])

            # This has "ing" on it, so we can't use it for the word label.
            # But it is needed to construct the filename.
            label = fn_ext.split('_')[0]

            dps.append((os.path.join(IMSITU_IMGS, label, fn_ext), ind))
    return dps


class ImSitu(torch.utils.data.Dataset):
    """
    Dataset yielding (transformed image, label index) pairs for imsitu.

    The verb flags choose which verbs are passed to ``Attributes``; the image
    flags choose which split list files the examples are read from.
    """

    def __init__(self,
                 use_train_verbs=False,
                 use_val_verbs=False,
                 use_test_verbs=False,
                 use_train_images=False,
                 use_val_images=False,
                 use_test_images=False,
                 vector_type='glove',
                 word_type='lemma',
                 ):
        """
        :param use_train_verbs: forwarded to Attributes as ``use_train``
        :param use_val_verbs: forwarded to Attributes as ``use_val``
        :param use_test_verbs: forwarded to Attributes as ``use_test``
        :param use_train_images: include examples from the train list file
        :param use_val_images: include examples from the val list file
        :param use_test_images: include examples from the test list file
        :param vector_type: forwarded to Attributes
        :param word_type: forwarded to Attributes
        :raises ValueError: if no verb flag is set, or if no image flag is set
        """
        self.vector_type = vector_type
        self.word_type = word_type

        self.use_train_verbs = use_train_verbs
        self.use_val_verbs = use_val_verbs
        self.use_test_verbs = use_test_verbs
        if not (self.use_train_verbs or self.use_val_verbs or self.use_test_verbs):
            raise ValueError("No verbs selected!")

        self.use_train_images = use_train_images
        self.use_val_images = use_val_images
        self.use_test_images = use_test_images
        # This guard must look at the *image* flags. It previously re-tested
        # the verb flags, so it could never fire and a dataset with no image
        # split selected was silently empty.
        if not (self.use_train_images or self.use_val_images or self.use_test_images):
            raise ValueError("No images selected!")

        self.attributes = Attributes(
            vector_type=vector_type,
            word_type=word_type,
            use_train=self.use_train_verbs,
            use_val=self.use_val_verbs,
            use_test=self.use_test_verbs,
            imsitu_only=True)

        # ind_perm is used as a mapping from the index stored in the list
        # file to the label index; rows whose index is not in it are dropped.
        ind_perm = self.attributes.ind_perm
        image_flags = (
            ('train', self.use_train_images),
            ('val', self.use_val_images),
            ('test', self.use_test_images),
        )
        self.examples = []
        for mode, to_use in image_flags:
            if to_use:
                self.examples += [(fn, ind_perm[ind])
                                  for fn, ind in _load_imsitu_file(mode)
                                  if ind in ind_perm]

        # NOTE(review): augmentation is keyed off use_test_verbs rather than
        # the image split, so non-zeroshot val/test datasets get the random
        # (training) transform. Kept as-is; confirm whether this is intended.
        self.transform = transform(is_train=not self.use_test_verbs)

    def __getitem__(self, index):
        """
        :param index: position in ``self.examples``
        :return: (transformed RGB image, label index)
        """
        fn, ind = self.examples[index]
        img = self.transform(Image.open(fn).convert('RGB'))
        return img, ind

    @classmethod
    def splits(cls, zeroshot=False, **kwargs):
        """
        Gets splits

        :param zeroshot: True if we're transferring to zeroshot classes
        :param kwargs: extra keyword arguments forwarded to every dataset
        :return: train, val, test datasets
        """
        if zeroshot:
            # Train and val both draw from the train+val images but use
            # different verb sets; test uses test verbs on test images.
            train_cls = cls(use_train_verbs=True, use_train_images=True,
                            use_val_images=True, **kwargs)
            val_cls = cls(use_val_verbs=True, use_train_images=True,
                          use_val_images=True, **kwargs)
            test_cls = cls(use_test_verbs=True, use_test_images=True, **kwargs)
        else:
            # Same (train) verbs everywhere; only the image split differs.
            train_cls = cls(use_train_verbs=True, use_train_images=True, **kwargs)
            val_cls = cls(use_train_verbs=True, use_val_images=True, **kwargs)
            test_cls = cls(use_train_verbs=True, use_test_images=True, **kwargs)
        return train_cls, val_cls, test_cls

    def __len__(self):
        """Number of examples in the dataset."""
        return len(self.examples)


# One collated batch as yielded by CudaDataLoader.
Batch = namedtuple('Batch', ['img', 'label'])


class CudaDataLoader(torch.utils.data.DataLoader):
    """
    Iterates through the data, but also loads everything as a (cuda) variable
    """

    def __init__(self, *args, volatile=False, **kwargs):
        """
        :param args: positional arguments forwarded to DataLoader
        :param volatile: passed as ``volatile`` to every Variable this loader
                         creates
        :param kwargs: keyword arguments forwarded to DataLoader
        """
        super(CudaDataLoader, self).__init__(*args, **kwargs)
        self.volatile = volatile

    def _load(self, item):
        """
        Wrap one collated (imgs, labels) pair into a Batch of Variables,
        moved to the GPU when CUDA is available.

        NOTE(review): ``Variable(..., volatile=...)`` is the pre-0.4 PyTorch
        API; on newer releases the flag reportedly has no effect -- confirm
        against the PyTorch version this project targets.
        """
        img = Variable(item[0], volatile=self.volatile)
        label = Variable(item[1], volatile=self.volatile)
        if torch.cuda.is_available():
            img = img.cuda()
            label = label.cuda()
        return Batch(img, label)

    def __iter__(self):
        """Yield a Batch for every item produced by the parent DataLoader."""
        return (self._load(x) for x in super(CudaDataLoader, self).__iter__())

    @classmethod
    def splits(cls, train, val, test, batch_size, num_workers=0, **kwargs):
        """
        gets dataloaders given datasets

        The train loader shuffles and uses ``batch_size``; the val and test
        loaders do not shuffle, use ``batch_size * 16`` and are volatile.

        :param train: training dataset
        :param val: validation dataset
        :param test: test dataset
        :param batch_size: batch size of the training loader
        :param num_workers: forwarded to every loader
        :param kwargs: extra keyword arguments forwarded to every loader
        :return: train, val, test dataloaders
        """
        train_dl = cls(
            dataset=train,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            collate_fn=collate_fn,
            **kwargs
        )
        val_dl = cls(
            dataset=val,
            batch_size=batch_size * 16,
            shuffle=False,
            num_workers=num_workers,
            collate_fn=collate_fn,
            volatile=True,
            **kwargs
        )
        test_dl = cls(
            dataset=test,
            batch_size=batch_size * 16,
            shuffle=False,
            num_workers=num_workers,
            collate_fn=collate_fn,
            volatile=True,
            **kwargs
        )
        return train_dl, val_dl, test_dl


def transform(is_train=True, normalize=True):
    """
    Returns a transform object

    The pipeline is: Scale(256), then RandomCrop(224) + RandomHorizontalFlip
    when training or CenterCrop(224) otherwise, then ToTensor, then an
    optional Normalize with fixed per-channel mean/std.

    :param is_train: use the random (augmenting) crop and flip
    :param normalize: append the Normalize step
    :return: a Compose of the selected filters
    """
    # NOTE(review): Scale is the legacy torchvision name for Resize; kept
    # because it is what this file imports.
    filters = [Scale(256)]
    if is_train:
        filters.append(RandomCrop(224))
        filters.append(RandomHorizontalFlip())
    else:
        filters.append(CenterCrop(224))
    filters.append(ToTensor())
    if normalize:
        filters.append(Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]))
    return Compose(filters)


def collate_fn(data):
    """
    Collate a list of (image, label) examples into a batch.

    :param data: sequence of (image tensor, label) pairs
    :return: (images stacked along a new dim 0, LongTensor of labels)
    """
    imgs, labels = zip(*data)
    imgs = torch.stack(imgs, 0)
    labels = torch.LongTensor(labels)
    return imgs, labels


if __name__ == '__main__':
    train, val, test = ImSitu.splits()
    train_dl = CudaDataLoader(
        dataset=train,
        batch_size=32,
        shuffle=True,
        num_workers=0,
        collate_fn=collate_fn
    )