""" CelebA Dataloader implementation, used in DCGAN """ import numpy as np import imageio import torch import torchvision.transforms as v_transforms import torchvision.utils as v_utils import torchvision.datasets as v_datasets from torch.utils.data import DataLoader, TensorDataset, Dataset class CelebADataLoader: def __init__(self, config): self.config = config if config.data_mode == "imgs": transform = v_transforms.Compose( [v_transforms.ToTensor(), v_transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))]) dataset = v_datasets.ImageFolder(self.config.data_folder, transform=transform) self.dataset_len = len(dataset) self.num_iterations = (self.dataset_len + config.batch_size - 1) // config.batch_size self.loader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True, num_workers=config.data_loader_workers, pin_memory=config.pin_memory) elif config.data_mode == "numpy": raise NotImplementedError("This mode is not implemented YET") else: raise Exception("Please specify in the json a specified mode in data_mode") def plot_samples_per_epoch(self, fake_batch, epoch): """ Plotting the fake batch :param fake_batch: Tensor of shape (B,C,H,W) :param epoch: the number of current epoch :return: img_epoch: which will contain the image of this epoch """ img_epoch = '{}samples_epoch_{:d}.png'.format(self.config.out_dir, epoch) v_utils.save_image(fake_batch, img_epoch, nrow=4, padding=2, normalize=True) return imageio.imread(img_epoch) def make_gif(self, epochs): """ Make a gif from a multiple images of epochs :param epochs: num_epochs till now :return: """ gen_image_plots = [] for epoch in range(epochs + 1): img_epoch = '{}samples_epoch_{:d}.png'.format(self.config.out_dir, epoch) try: gen_image_plots.append(imageio.imread(img_epoch)) except OSError as e: pass imageio.mimsave(self.config.out_dir + 'animation_epochs_{:d}.gif'.format(epochs), gen_image_plots, fps=2) def finalize(self): pass