"""Data-loading helpers built on torchvision transforms and NVIDIA DALI pipelines."""
import time

import torch.utils.data
import nvidia.dali.ops as ops
import nvidia.dali.types as types
import torchvision.datasets as datasets
from nvidia.dali.pipeline import Pipeline
import torchvision.transforms as transforms
from nvidia.dali.plugin.pytorch import DALIClassificationIterator, DALIGenericIterator

# Per-channel normalisation constants used by the transforms below.
_CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR_STD = (0.2023, 0.1994, 0.2010)
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)


def _scaled_to_255(values):
    """Return ``values`` multiplied by 255 as a new list.

    The DALI pipelines pass their mean/std in this scaled form
    (presumably because decoded images are in the 0-255 range -- confirm
    against the DALI operator documentation).
    """
    return [v * 255 for v in values]


def _imagenet_normalize():
    """Return a fresh ``transforms.Normalize`` using the ImageNet constants."""
    return transforms.Normalize(mean=list(_IMAGENET_MEAN), std=list(_IMAGENET_STD))


def cifar_transform(is_training=True):
    """Build the torchvision transform used for CIFAR-style 32x32 inputs.

    Args:
        is_training: When true, prepend the augmentation steps (random
            horizontal flip, 4-pixel reflect padding, random 32x32 crop).
            Otherwise only tensor conversion and normalisation are applied.

    Returns:
        A ``transforms.Compose`` instance.
    """
    steps = []
    if is_training:
        steps.extend([
            transforms.RandomHorizontalFlip(),
            transforms.Pad(padding=4, padding_mode='reflect'),
            transforms.RandomCrop(32, padding=0),
        ])
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD))
    return transforms.Compose(steps)


def imgnet_transform(is_training=True):
    """Build the torchvision transform used for 224x224 ImageNet-style inputs.

    Args:
        is_training: When true, use random resized crop, random horizontal
            flip and colour jitter. Otherwise resize to 256 and take a
            224x224 centre crop.

    Returns:
        A ``transforms.Compose`` instance.
    """
    if is_training:
        steps = [
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.3),
        ]
    else:
        steps = [
            transforms.Resize(256),
            transforms.CenterCrop(224),
        ]
    steps.append(transforms.ToTensor())
    steps.append(_imagenet_normalize())
    return transforms.Compose(steps)


class HybridTrainPipe(Pipeline):
    """DALI training pipeline: read files, decode, random-resized-crop, normalise.

    The reader is sharded with ``shard_id=local_rank`` / ``num_shards=world_size``
    and shuffles randomly. A coin flip (p=0.5) drives the ``mirror`` argument
    of the crop/mirror/normalise step.

    NOTE(review): ``dali_cpu`` is accepted for interface compatibility but is
    not read anywhere; the operators are always created with the
    "mixed"/"gpu" devices.
    """

    def __init__(self, batch_size, num_threads, device_id, data_dir, crop,
                 dali_cpu=False, local_rank=0, world_size=1):
        super(HybridTrainPipe, self).__init__(
            batch_size, num_threads, device_id, seed=12 + device_id)
        dali_device = "gpu"
        self.input = ops.FileReader(
            file_root=data_dir,
            shard_id=local_rank,
            num_shards=world_size,
            random_shuffle=True,
        )
        self.decode = ops.ImageDecoder(device="mixed", output_type=types.RGB)
        self.res = ops.RandomResizedCrop(
            device="gpu", size=crop, random_area=[0.08, 1.25])
        self.cmnp = ops.CropMirrorNormalize(
            device="gpu",
            output_dtype=types.FLOAT,
            output_layout=types.NCHW,
            image_type=types.RGB,
            mean=_scaled_to_255(_IMAGENET_MEAN),
            std=_scaled_to_255(_IMAGENET_STD),
        )
        self.coin = ops.CoinFlip(probability=0.5)
        print('DALI "{0}" variant'.format(dali_device))

    def define_graph(self):
        """Wire the operators together and return ``[images, labels]``."""
        rng = self.coin()
        # The reader is named so callers can query epoch_size("Reader").
        self.jpegs, self.labels = self.input(name="Reader")
        images = self.decode(self.jpegs)
        images = self.res(images)
        output = self.cmnp(images, mirror=rng)
        return [output, self.labels]


class HybridValPipe(Pipeline):
    """DALI validation pipeline: read files, decode, resize, crop, normalise.

    The reader is sharded like the training pipeline but does not shuffle.
    Images are resized so the shorter side equals ``size`` (triangular
    interpolation) and then cropped to ``(crop, crop)`` by the
    crop/mirror/normalise operator (crop position left at the operator's
    default).
    """

    def __init__(self, batch_size, num_threads, device_id, data_dir, crop,
                 size, local_rank=0, world_size=1):
        super(HybridValPipe, self).__init__(
            batch_size, num_threads, device_id, seed=12 + device_id)
        self.input = ops.FileReader(
            file_root=data_dir,
            shard_id=local_rank,
            num_shards=world_size,
            random_shuffle=False,
        )
        self.decode = ops.ImageDecoder(device="mixed", output_type=types.RGB)
        self.res = ops.Resize(
            device="gpu",
            resize_shorter=size,
            interp_type=types.INTERP_TRIANGULAR,
        )
        self.cmnp = ops.CropMirrorNormalize(
            device="gpu",
            output_dtype=types.FLOAT,
            output_layout=types.NCHW,
            crop=(crop, crop),
            image_type=types.RGB,
            mean=_scaled_to_255(_IMAGENET_MEAN),
            std=_scaled_to_255(_IMAGENET_STD),
        )

    def define_graph(self):
        """Wire the operators together and return ``[images, labels]``."""
        # The reader is named so callers can query epoch_size("Reader").
        self.jpegs, self.labels = self.input(name="Reader")
        images = self.decode(self.jpegs)
        images = self.res(images)
        output = self.cmnp(images)
        return [output, self.labels]


def _build_dali_iterator(pipe, world_size):
    """Build ``pipe`` and wrap it in a self-resetting classification iterator.

    The iterator size is the reader's epoch size integer-divided by
    ``world_size``.
    """
    pipe.build()
    return DALIClassificationIterator(
        pipe,
        size=pipe.epoch_size("Reader") // world_size,
        auto_reset=True,
    )


def get_imagenet_iter_dali(type, image_dir, batch_size, num_threads, device_id,
                           num_gpus, crop, val_size=256, world_size=1,
                           local_rank=0):
    """Create a DALI classification iterator over ``image_dir``.

    Args:
        type: ``'train'`` or ``'val'``; selects the pipeline and the
            ``/train`` or ``/val`` sub-directory. (The name shadows the
            builtin but is kept because callers may pass it by keyword.)
        image_dir: Dataset root; the sub-directory is appended with ``+``.
        batch_size: Batch size given to the pipeline.
        num_threads: Thread count given to the pipeline.
        device_id: Not read here; the pipelines receive ``local_rank`` as
            their device id instead.
        num_gpus: Not read here.
        crop: Output crop size.
        val_size: Shorter-side resize target for the validation pipeline.
        world_size: Number of shards; also divides the reported epoch size.
        local_rank: Shard id and pipeline device id.

    Returns:
        A ``DALIClassificationIterator`` created with ``auto_reset=True``.

    Raises:
        ValueError: If ``type`` is neither ``'train'`` nor ``'val'``.
            Previously this case silently returned ``None``.
    """
    if type == 'train':
        pipe = HybridTrainPipe(
            batch_size=batch_size,
            num_threads=num_threads,
            device_id=local_rank,
            data_dir=image_dir + '/train',
            crop=crop,
            world_size=world_size,
            local_rank=local_rank,
        )
    elif type == 'val':
        pipe = HybridValPipe(
            batch_size=batch_size,
            num_threads=num_threads,
            device_id=local_rank,
            data_dir=image_dir + '/val',
            crop=crop,
            size=val_size,
            world_size=world_size,
            local_rank=local_rank,
        )
    else:
        # Fail loudly here rather than handing the caller a None iterator.
        raise ValueError(
            "type must be 'train' or 'val', got {0!r}".format(type))
    return _build_dali_iterator(pipe, world_size)


def get_imagenet_iter_torch(type, image_dir, batch_size, num_threads, device_id,
                            num_gpus, crop, val_size=256, world_size=1,
                            local_rank=0):
    """Create a torchvision ``ImageFolder`` DataLoader over ``image_dir``.

    Args:
        type: ``'train'`` selects the augmented, shuffled training loader;
            any other value selects the validation loader.
        image_dir: Dataset root; ``/train`` or ``/val`` is appended.
        batch_size: DataLoader batch size.
        num_threads: Passed as the DataLoader's ``num_workers``.
        device_id: Not read here.
        num_gpus: Not read here.
        crop: Output crop size.
        val_size: Resize target for the validation transform.
        world_size: Not read here.
        local_rank: Not read here.

    Returns:
        A ``torch.utils.data.DataLoader`` created with ``pin_memory=True``.
    """
    is_train = type == 'train'
    if is_train:
        head = [
            transforms.RandomResizedCrop(crop, scale=(0.08, 1.25)),
            transforms.RandomHorizontalFlip(),
        ]
        subdir = '/train'
    else:
        head = [
            transforms.Resize(val_size),
            transforms.CenterCrop(crop),
        ]
        subdir = '/val'
    transform = transforms.Compose(
        head + [transforms.ToTensor(), _imagenet_normalize()])
    dataset = datasets.ImageFolder(image_dir + subdir, transform)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=is_train,
        num_workers=num_threads,
        pin_memory=True,
    )
    return dataloader