#
# KTH Royal Institute of Technology
#

import torch.utils.data as data
import torch
from torchvision.transforms import CenterCrop, RandomHorizontalFlip, RandomVerticalFlip, RandomRotation
import numpy as np
import random
from PIL import Image
import src.data_manager as data_manager
import src.config as config


def _identity(x):
    """Return the argument untouched; used as the no-op augmentation."""
    return x


def pil_to_numpy(x_pil):
    """
    Convert a PIL image into a channel-first numpy array scaled by 1/255.

    :param x_pil: PIL.Image object
    :return: Normalized numpy array of shape (channels, height, width)
    """
    scaled = np.asarray(x_pil) / 255.0
    # np.asarray on a PIL image puts the channels on the third axis; bring
    # that axis to the front so the result can be indexed by channel first.
    return np.moveaxis(scaled, 2, 0)


def pil_to_tensor(x_pil):
    """
    Convert a PIL image into a channel-first float torch tensor.

    :param x_pil: PIL.Image object
    :return: Normalized torch tensor of shape (channels, height, width)
    """
    return torch.from_numpy(pil_to_numpy(x_pil)).float()


def numpy_to_pil(x_np):
    """
    Convert a channel-first array back into an RGB PIL image.

    The input is copied first, so the caller's array is never modified.

    :param x_np: Image as a numpy array of shape (channels, height, width)
    :return: PIL.Image object
    """
    pixels = x_np.copy()
    # Undo the 1/255 normalization in place on the private copy, then clamp
    # to the valid 8-bit range before the cast below.
    pixels *= 255.0
    pixels = pixels.clip(0, 255)
    # PIL.Image wants the channel as the last dimension
    channel_last = np.moveaxis(pixels, 0, 2).astype(np.uint8)
    return Image.fromarray(channel_last, mode='RGB')


class PatchDataset(data.Dataset):
    """
    Training dataset yielding ``(input, target)`` tensor pairs built from
    frame triplets.

    Every item is loaded through ``data_manager`` (cached or uncached
    loader), optionally augmented, center-cropped to ``config.CROP_SIZE``
    and converted to tensors. The two outer frames are concatenated along
    dim 0 to form the input; the middle frame is the target.
    """

    def __init__(self, patches, use_cache, augment_data):
        """
        :param patches: Sequence of patch descriptors; each entry is handed
            to the selected ``data_manager`` loader unchanged
        :param use_cache: If truthy, load items with
            ``data_manager.load_cached_patch``, else ``data_manager.load_patch``
        :param augment_data: If truthy, one transform out of
            {90-degree rotation, vertical flip, horizontal flip, identity}
            is drawn at random for every item
        """
        super().__init__()
        self.patches = patches
        self.crop = CenterCrop(config.CROP_SIZE)

        if augment_data:
            # The flips use p=1.0 and the rotation range is (90, 90), so each
            # transform is deterministic once picked; randomness comes only
            # from the choice among the four candidates.
            self.random_transforms = [
                RandomRotation((90, 90)),
                RandomVerticalFlip(1.0),
                RandomHorizontalFlip(1.0),
                _identity,
            ]
            self.get_aug_transform = lambda: random.choice(self.random_transforms)
        else:
            # No augmentation: always hand out the no-op transform.
            self.get_aug_transform = lambda: _identity

        self.load_patch = (
            data_manager.load_cached_patch if use_cache else data_manager.load_patch
        )

        print('Dataset ready with {} tuples.'.format(len(patches)))

    @staticmethod
    def random_temporal_order_swap(x1, x2):
        """
        Randomly exchange the two arguments.

        :param x1: First element of the pair
        :param x2: Second element of the pair
        :return: ``(x2, x1)`` when a uniform draw from ``random.random()`` is
            ``<= config.RANDOM_TEMPORAL_ORDER_SWAP_PROB``, else ``(x1, x2)``
        """
        do_swap = random.random() <= config.RANDOM_TEMPORAL_ORDER_SWAP_PROB
        return (x2, x1) if do_swap else (x1, x2)

    def __getitem__(self, index):
        """
        :param index: Position in ``self.patches``
        :return: Tuple ``(input, target)``; ``input`` is the two outer frames
            concatenated along dim 0, ``target`` is the middle frame
        """
        # NOTE(review): the loader is assumed to return exactly three PIL
        # images per patch (the unpacking below requires three) -- confirm
        # against data_manager.
        frames = self.load_patch(self.patches[index])

        # Draw the transform once so all frames of the item get the same one.
        augment = self.get_aug_transform()
        tensors = [pil_to_tensor(self.crop(augment(frame))) for frame in frames]
        first, target, last = tensors

        first, last = self.random_temporal_order_swap(first, last)
        stacked = torch.cat([first, last], dim=0)
        return stacked, target

    def __len__(self):
        """
        :return: Number of patches in the dataset
        """
        return len(self.patches)


class ValidationDataset(data.Dataset):
    """
    Validation dataset yielding ``(input, target)`` tensor pairs.

    Each frame is loaded with ``data_manager.load_img`` and center-cropped to
    ``config.CROP_SIZE``; no augmentation or order swapping is applied.
    """

    def __init__(self, tuples):
        """
        :param tuples: Sequence of frame triplets; every element of a triplet
            is passed to ``data_manager.load_img``
        """
        super().__init__()
        self.tuples = tuples
        self.crop = CenterCrop(config.CROP_SIZE)

    def __getitem__(self, index):
        """
        :param index: Position in ``self.tuples``
        :return: Tuple ``(input, target)``; ``input`` is the two outer frames
            concatenated along dim 0, ``target`` is the middle frame
        """
        tensors = [
            pil_to_tensor(self.crop(data_manager.load_img(frame)))
            for frame in self.tuples[index]
        ]
        first, target, last = tensors
        stacked = torch.cat([first, last], dim=0)
        return stacked, target

    def __len__(self):
        """
        :return: Number of tuples in the dataset
        """
        return len(self.tuples)


def get_training_set():
    """
    Build the training dataset according to ``config``.

    ``data_manager.prepare_dataset()`` is always called; when
    ``config.CACHE_PATCHES`` is truthy its result is replaced by
    ``data_manager.get_cached_patches()``. The list is then truncated to
    ``config.MAX_TRAINING_SAMPLES`` entries.

    :return: PatchDataset instance
    """
    prepared = data_manager.prepare_dataset()
    use_cache = config.CACHE_PATCHES
    patches = data_manager.get_cached_patches() if use_cache else prepared
    limited = patches[:config.MAX_TRAINING_SAMPLES]
    return PatchDataset(limited, use_cache, config.AUGMENT_DATA)


def get_validation_set():
    """
    Build the validation dataset from the 480p tuples that ``data_manager``
    derives from ``get_davis_17_test(config.DATASET_DIR)``.

    At most ``config.MAX_VALIDATION_SAMPLES`` tuples are kept. They are
    sampled with a dedicated generator seeded with 42, so the selection does
    not depend on the global ``random`` state.

    :return: ValidationDataset instance
    """
    davis_17_test = data_manager.get_davis_17_test(config.DATASET_DIR)
    candidates = data_manager.tuples_from_davis(davis_17_test, res='480p')
    sample_count = min(len(candidates), config.MAX_VALIDATION_SAMPLES)
    rng = random.Random(42)
    return ValidationDataset(rng.sample(candidates, sample_count))


def get_visual_test_set():
    """
    :return: Whatever ``data_manager.get_selected_davis`` returns for
        ``res='480p'``
    """
    return data_manager.get_selected_davis(res='480p')