from pathlib import Path
from typing import Union

from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


def _bicubic_affine_kwargs():
    """Return the keyword argument selecting bicubic interpolation for RandomAffine.

    Newer torchvision takes ``interpolation=InterpolationMode.BICUBIC``; the
    older ``resample=`` keyword was deprecated and later removed. Fall back to
    ``resample=`` only when ``InterpolationMode`` is unavailable.
    """
    interpolation_mode = getattr(transforms, 'InterpolationMode', None)
    if interpolation_mode is not None:
        return {'interpolation': interpolation_mode.BICUBIC}
    return {'resample': Image.BICUBIC}


class CUFED5Dataset(Dataset):
    """Dataset class for CUFED5, a dataset provided by the author of SRNTT.

    ``dataroot`` is scanned for ``*.png`` files named ``<id>_<k>.png``. For
    each ``<id>``, ``<id>_0.png`` is used as the HR image and
    ``<id>_0.png`` .. ``<id>_5.png`` are used as reference images. A seventh
    reference is a randomly affine-warped copy of ``<id>_0.png``.

    Args:
        dataroot: directory containing the CUFED5 PNG files.
        scale_factor: super-resolution factor. HR/reference images are
            resized so both sides are multiples of it; LR images are HR
            images downscaled by it. Defaults to 4.

    Raises:
        ValueError: if ``scale_factor`` is smaller than 1.
    """

    # Reference images <id>_0.png .. <id>_{N-1}.png are loaded per sample.
    _NUM_REFS = 6

    def __init__(self, dataroot: Union[str, Path], scale_factor: int = 4):
        super().__init__()
        if scale_factor < 1:
            raise ValueError(
                f'scale_factor must be >= 1, got {scale_factor}')
        self.dataroot = Path(dataroot)
        self.scale_factor = scale_factor
        # sorted() rather than list(set(...)): iteration order of a set of
        # str depends on hash randomization, which would make the
        # index -> sample mapping differ between runs.
        self.filenames = sorted({
            f.stem.split('_')[0] for f in self.dataroot.glob('*.png')
        })
        # NOTE: Normalize([0.5] * 3, [0.5] * 3) was deliberately left out,
        # so tensors keep ToTensor's value range.
        self.transforms = transforms.Compose([
            transforms.ToTensor(),
        ])
        self.warp = transforms.RandomAffine(
            degrees=(10, 30),
            translate=(0.25, 0.5),
            scale=(1.2, 2.0),
            **_bicubic_affine_kwargs(),
        )

    def _open_rgb(self, name):
        """Load ``dataroot / name`` as an RGB PIL image and close the file."""
        with Image.open(self.dataroot / name) as img:
            # convert() loads the pixel data and returns a new image, so the
            # result stays valid after the file is closed.
            return img.convert('RGB')

    def _fit_to_scale(self, img):
        """Bicubically resize ``img`` so both sides are multiples of the scale.

        Each side is rounded down to the nearest multiple of
        ``scale_factor``; the image is resized (not cropped) to that size.
        """
        s = self.scale_factor
        size = tuple(x - x % s for x in img.size)
        return img.resize(size, Image.BICUBIC)

    def _downscale(self, img):
        """Bicubically shrink ``img`` by ``scale_factor`` on each side."""
        size = tuple(x // self.scale_factor for x in img.size)
        return img.resize(size, Image.BICUBIC)

    def _load_ref(self, filename, index, warp=False):
        """Load reference ``<filename>_<index>.png`` and its blurred version.

        If ``warp`` is true, the reference is passed through the random
        affine transform ``self.warp`` before blurring. The blurred version
        is the reference downscaled by ``scale_factor`` and bicubically
        upscaled back to the reference size.

        Returns:
            dict with tensors under ``'ref'`` and ``'ref_blur'``.
        """
        img_ref = self._fit_to_scale(self._open_rgb(f'{filename}_{index}.png'))
        if warp:
            img_ref = self.warp(img_ref)
        img_ref_blur = self._downscale(img_ref).resize(
            img_ref.size, Image.BICUBIC)
        return {'ref': self.transforms(img_ref),
                'ref_blur': self.transforms(img_ref_blur)}

    def __getitem__(self, index):
        """Return HR/LR images and references for the ``index``-th sample id.

        Returns:
            dict with keys ``'img_hr'``, ``'img_lr'``, ``'img_in_up'`` (LR
            upscaled back to HR size), ``'ref'`` (dict mapping 0..6 to the
            output of ``_load_ref``; key 6 is the warped HR copy) and
            ``'filename'`` (the sample id string).
        """
        filename = self.filenames[index]

        img_hr = self._fit_to_scale(self._open_rgb(f'{filename}_0.png'))
        img_lr = self._downscale(img_hr)
        # LR input bicubically upsampled to HR size, used for feature swapping.
        img_in_up = img_lr.resize(img_hr.size, Image.BICUBIC)

        ref_dict = {i: self._load_ref(filename, i)
                    for i in range(self._NUM_REFS)}
        # Extra reference: the HR image itself under a random affine warp.
        ref_dict[self._NUM_REFS] = self._load_ref(filename, 0, warp=True)

        return {'img_hr': self.transforms(img_hr),
                'img_lr': self.transforms(img_lr),
                'img_in_up': self.transforms(img_in_up),
                'ref': ref_dict,
                'filename': filename}

    def __len__(self):
        return len(self.filenames)


if __name__ == "__main__":
    import sys

    from torch.utils.data import DataLoader

    # Smoke test: optional dataroot override as the first CLI argument.
    root = (sys.argv[1] if len(sys.argv) > 1
            else '/home/ubuntu/srntt-pytorch/data/CUFED5')
    dataset = CUFED5Dataset(dataroot=root)
    dataloader = DataLoader(dataset)
    # Fetch a single batch to check that loading and collation work.
    for batch in dataloader:
        break