import os import torch import torchvision import torchvision.transforms as transforms from . import lmdb_dataset from . import torchvision_extension as transforms_extension from .prefetch_data import fast_collate class ImageNet12(object): def __init__(self, trainFolder, testFolder, num_workers=8, pin_memory=True, size_images=224, scaled_size=256, type_of_data_augmentation='rand_scale', data_config=None): self.data_config = data_config self.trainFolder = trainFolder self.testFolder = testFolder self.num_workers = num_workers self.pin_memory = pin_memory self.patch_dataset = self.data_config.patch_dataset #images will be rescaled to match this size if not isinstance(size_images, int): raise ValueError('size_images must be an int. It will be scaled to a square image') self.size_images = size_images self.scaled_size = scaled_size type_of_data_augmentation = type_of_data_augmentation.lower() if type_of_data_augmentation not in ('rand_scale', 'random_sized'): raise ValueError('type_of_data_augmentation must be either rand-scale or random-sized') self.type_of_data_augmentation = type_of_data_augmentation def _getTransformList(self, aug_type): assert aug_type in ['rand_scale', 'random_sized', 'week_train', 'validation'] list_of_transforms = [] if aug_type == 'validation': list_of_transforms.append(transforms.Resize(self.scaled_size)) list_of_transforms.append(transforms.CenterCrop(self.size_images)) elif aug_type == 'week_train': list_of_transforms.append(transforms.Resize(256)) list_of_transforms.append(transforms.RandomCrop(self.size_images)) list_of_transforms.append(transforms.RandomHorizontalFlip()) else: if aug_type == 'rand_scale': list_of_transforms.append(transforms_extension.RandomScale(256, 480)) list_of_transforms.append(transforms.RandomCrop(self.size_images)) list_of_transforms.append(transforms.RandomHorizontalFlip()) elif aug_type == 'random_sized': list_of_transforms.append(transforms.RandomResizedCrop(self.size_images, scale=(self.data_config.random_sized.min_scale, 1.0))) list_of_transforms.append(transforms.RandomHorizontalFlip()) if self.data_config.color: list_of_transforms.append(transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4)) return transforms.Compose(list_of_transforms) def _getTrainSet(self): train_transform = self._getTransformList(self.type_of_data_augmentation) if self.data_config.train_data_type == 'img': train_set = torchvision.datasets.ImageFolder(self.trainFolder, train_transform) elif self.data_config.train_data_type == 'lmdb': train_set = lmdb_dataset.ImageFolder(self.trainFolder, os.path.join(self.trainFolder, '..', 'train_datalist'), train_transform, patch_dataset=self.patch_dataset) self.train_num_examples = train_set.__len__() return train_set def _getWeekTrainSet(self): train_transform = self._getTransformList('week_train') if self.data_config.train_data_type == 'img': train_set = torchvision.datasets.ImageFolder(self.trainFolder, train_transform) elif self.data_config.train_data_type == 'lmdb': train_set = lmdb_dataset.ImageFolder(self.trainFolder, os.path.join(self.trainFolder, '..', 'train_datalist'), train_transform, patch_dataset=self.patch_dataset) self.train_num_examples = train_set.__len__() return train_set def _getTestSet(self): test_transform = self._getTransformList('validation') if self.data_config.val_data_type == 'img': test_set = torchvision.datasets.ImageFolder(self.testFolder, test_transform) elif self.data_config.val_data_type == 'lmdb': test_set = lmdb_dataset.ImageFolder(self.testFolder, os.path.join(self.testFolder, '..', 'val_datalist'), test_transform) self.test_num_examples = test_set.__len__() return test_set def getTrainLoader(self, batch_size, shuffle=True): train_set = self._getTrainSet() train_loader = torch.utils.data.DataLoader( train_set, batch_size=batch_size, shuffle=shuffle, num_workers=self.num_workers, pin_memory=self.pin_memory, sampler=None, collate_fn=fast_collate) return train_loader def getWeekTrainLoader(self, batch_size, shuffle=True): train_set = self._getWeekTrainSet() train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=shuffle, num_workers=self.num_workers, pin_memory=self.pin_memory, collate_fn=fast_collate) return train_loader def getTestLoader(self, batch_size, shuffle=False): test_set = self._getTestSet() test_loader = torch.utils.data.DataLoader( test_set, batch_size=batch_size, shuffle=shuffle, num_workers=self.num_workers, pin_memory=self.pin_memory, sampler=None, collate_fn=fast_collate) return test_loader def getTrainTestLoader(self, batch_size, train_shuffle=True, val_shuffle=False): train_loader = self.getTrainLoader(batch_size, train_shuffle) test_loader = self.getTestLoader(batch_size, val_shuffle) return train_loader, test_loader def getSetTrainTestLoader(self, batch_size, train_shuffle=True, val_shuffle=False): train_loader = self.getTrainLoader(batch_size, train_shuffle) week_train_loader = self.getWeekTrainLoader(batch_size, train_shuffle) test_loader = self.getTestLoader(batch_size, val_shuffle) return (train_loader, week_train_loader), test_loader