Python torchvision.transforms.ColorJitter() Examples

The following are 30 code examples of torchvision.transforms.ColorJitter(), drawn from open-source projects. The source file and project for each example are noted above it. The snippets assume that torchvision.transforms has been imported as transforms, along with each project's own imports. You may also want to browse the other functions and classes available in the torchvision.transforms module.
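For orientation before the project examples: ColorJitter takes four arguments. brightness, contrast, and saturation each accept a non-negative float f, in which case the factor is sampled uniformly from [max(0, 1 - f), 1 + f], or an explicit (min, max) pair; hue accepts a float h with 0 <= h <= 0.5, sampled from [-h, h], or a (min, max) pair within [-0.5, 0.5]. A minimal, self-contained sketch (the image path is a placeholder):

from PIL import Image
from torchvision import transforms

# each argument is either a single factor or an explicit (min, max) range
jitter = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)

img = Image.open('example.jpg').convert('RGB')   # placeholder path; any RGB PIL image works
img_aug = jitter(img)                            # returns a randomly jittered copy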
Example #1
Source File: data_loader.py    From self-supervised-da with MIT License
def get_rot_train_transformers(args):
    size = args.img_transform.random_resize_crop.size
    scale = args.img_transform.random_resize_crop.scale
    img_tr = [transforms.RandomResizedCrop((int(size[0]), int(size[1])), (scale[0], scale[1]))]
    if args.img_transform.random_horiz_flip > 0.0:
        img_tr.append(transforms.RandomHorizontalFlip(args.img_transform.random_horiz_flip))
    if args.img_transform.jitter > 0.0:
        img_tr.append(transforms.ColorJitter(
            brightness=args.img_transform.jitter, contrast=args.img_transform.jitter,
            saturation=args.img_transform.jitter, hue=min(0.5, args.img_transform.jitter)))

    mean = args.normalize.mean
    std = args.normalize.std
    img_tr += [transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)]

    return transforms.Compose(img_tr) 
Example #2
Source File: datasets.py    From garbageClassifier with MIT License
def __init__(self, data_dir, image_size, is_train=True, **kwargs):
		self.image_size = image_size
		self.image_paths = []
		self.image_labels = []
		self.classes = sorted(os.listdir(data_dir))
		for idx, cls_ in enumerate(self.classes):
			self.image_paths += glob.glob(os.path.join(data_dir, cls_, '*.*'))
			self.image_labels += [idx] * len(glob.glob(os.path.join(data_dir, cls_, '*.*')))
		self.indexes = list(range(len(self.image_paths)))
		if is_train:
			random.shuffle(self.indexes)
			self.transform = transforms.Compose([transforms.RandomResizedCrop(image_size),
												 transforms.RandomHorizontalFlip(),
												 transforms.ColorJitter(brightness=1, contrast=1, saturation=0.5, hue=0.5),
												 transforms.ToTensor(),
												 transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
		else:
			self.transform = transforms.Compose([transforms.ToTensor(),
												 transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 
Example #3
Source File: utils.py    From NAS-Benchmark with GNU General Public License v3.0
def data_transforms_food101():
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(128), # default bilinear for interpolation
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(
            brightness=0.2,
            contrast=0.2,
            saturation=0.2,
            hue=0.2),
        transforms.ToTensor(),
        normalize,
    ])

    valid_transform = transforms.Compose([
                transforms.Resize(128),
                transforms.CenterCrop(128),
                transforms.ToTensor(),
                normalize,
            ])

    return train_transform, valid_transform 
Example #4
Source File: data_loader_stargan.py    From adversarial-object-removal with MIT License
def __init__(self, transform, mode, select_attrs=[], out_img_size=64, bbox_out_size=32, randomrotate=0, scaleRange=[0.1, 0.9], squareAspectRatio=False, use_celeb=False):
        self.image_path = os.path.join('data','mnist')
        self.mode = mode
        self.iouThresh = 0.5
        self.maxDigits = 1
        self.minDigits = 1
        self.use_celeb = use_celeb
        self.scaleRange = scaleRange
        self.squareAspectRatio = squareAspectRatio
        self.nc = 1 if not self.use_celeb else 3
        transList = [transforms.RandomHorizontalFlip(), transforms.RandomRotation(randomrotate, resample=Image.BICUBIC)]  # optionally also: transforms.ColorJitter(0.5, 0.5, 0.5, 0.3)
        self.digitTransforms = transforms.Compose(transList)
        self.dataset = MNIST(self.image_path,train=True, transform=self.digitTransforms) if not use_celeb else CelebDataset('./data/celebA/images', './data/celebA/list_attr_celeba.txt', self.digitTransforms, mode)
        self.num_data = len(self.dataset)
        self.metadata = {'images':[]}
        self.catid2attr = {}
        self.out_img_size = out_img_size
        self.bbox_out_size = bbox_out_size
        self.selected_attrs = select_attrs

        print('Start preprocessing dataset..!')
        self.preprocess()
        print('Finished preprocessing dataset..!')
Example #5
Source File: datasets.py    From amdim-public with MIT License
def __init__(self):
        # random left-right (horizontal) flip
        self.flip_lr = transforms.RandomHorizontalFlip(p=0.5)
        # image augmentation functions
        normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                                         std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
        col_jitter = transforms.RandomApply([
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.2)], p=0.8)
        img_jitter = transforms.RandomApply([
            RandomTranslateWithReflect(4)], p=0.8)
        rnd_gray = transforms.RandomGrayscale(p=0.25)
        # main transform for self-supervised training
        self.train_transform = transforms.Compose([
            img_jitter,
            col_jitter,
            rnd_gray,
            transforms.ToTensor(),
            normalize
        ])
        # transform for testing
        self.test_transform = transforms.Compose([
            transforms.ToTensor(),
            normalize
        ]) 
Example #6
Source File: utils.py    From NAS-Benchmark with GNU General Public License v3.0
def data_transforms_imagenet():
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(128),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(
            brightness=0.4,
            contrast=0.4,
            saturation=0.4,
            hue=0.2),
        transforms.ToTensor(),
        normalize,
    ])

    valid_transform = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(128),
                transforms.ToTensor(),
                normalize,
            ])

    return train_transform, valid_transform 
Example #7
Source File: datasets.py    From amdim-public with MIT License
def __init__(self):
        # image augmentation functions
        self.flip_lr = transforms.RandomHorizontalFlip(p=0.5)
        rand_crop = \
            transforms.RandomResizedCrop(128, scale=(0.3, 1.0), ratio=(0.7, 1.4),
                                         interpolation=INTERP)
        col_jitter = transforms.RandomApply([
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8)
        rnd_gray = transforms.RandomGrayscale(p=0.25)
        post_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        self.test_transform = transforms.Compose([
            transforms.Resize(146, interpolation=INTERP),
            transforms.CenterCrop(128),
            post_transform
        ])
        self.train_transform = transforms.Compose([
            rand_crop,
            col_jitter,
            rnd_gray,
            post_transform
        ]) 
Example #8
Source File: preprocess.py    From pytorch_quantization with MIT License
def imgnet_transform(is_training=True):
    if is_training:
        transform_list = transforms.Compose([transforms.RandomResizedCrop(224),
                                             transforms.RandomHorizontalFlip(),
                                             transforms.ColorJitter(brightness=0.5,
                                                                    contrast=0.5,
                                                                    saturation=0.3),
                                             transforms.ToTensor(),
                                             transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                                  std=[0.229, 0.224, 0.225])])
    else:
        transform_list = transforms.Compose([transforms.Resize(256),
                                             transforms.CenterCrop(224),
                                             transforms.ToTensor(),
                                             transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                                  std=[0.229, 0.224, 0.225])])
    return transform_list 
Example #9
Source File: imagenet.py    From nasnet-pytorch with MIT License
def preprocess(self):
        if self.train:
            return transforms.Compose([
                transforms.RandomResizedCrop(self.image_size),
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2),
                transforms.ToTensor(),
                transforms.Normalize(self.mean, self.std),
            ])
        else:
            return transforms.Compose([
                transforms.Resize((int(self.image_size / 0.875), int(self.image_size / 0.875))),
                transforms.CenterCrop(self.image_size),
                transforms.ToTensor(),
                transforms.Normalize(self.mean, self.std),
            ]) 
Example #10
Source File: data_loader.py    From self-supervised-da with MIT License
def get_jig_train_transformers(args):
    size = args.img_transform.random_resize_crop.size
    scale = args.img_transform.random_resize_crop.scale
    img_tr = [transforms.RandomResizedCrop((int(size[0]), int(size[1])), (scale[0], scale[1]))]
    if args.img_transform.random_horiz_flip > 0.0:
        img_tr.append(transforms.RandomHorizontalFlip(args.img_transform.random_horiz_flip))
    if args.img_transform.jitter > 0.0:
        img_tr.append(transforms.ColorJitter(
            brightness=args.img_transform.jitter, contrast=args.img_transform.jitter,
            saturation=args.img_transform.jitter, hue=min(0.5, args.img_transform.jitter)))

    tile_tr = []
    if args.jig_transform.tile_random_grayscale:
        tile_tr.append(transforms.RandomGrayscale(args.jig_transform.tile_random_grayscale))
    mean = args.normalize.mean
    std = args.normalize.std
    tile_tr = tile_tr + [transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)]

    return transforms.Compose(img_tr), transforms.Compose(tile_tr) 
Example #11
Source File: deep_globe.py    From GLNet with MIT License
def __init__(self, root, ids, label=False, transform=False):
        super(DeepGlobe, self).__init__()
        """
        Args:

        fileDir(string):  directory with all the input images.
        transform(callable, optional): Optional transform to be applied on a sample
        """
        self.root = root
        self.label = label
        self.transform = transform
        self.ids = ids
        self.classdict = {1: "urban", 2: "agriculture", 3: "rangeland", 4: "forest", 5: "water", 6: "barren", 0: "unknown"}
        
        self.color_jitter = transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.04)
        self.resizer = transforms.Resize((2448, 2448)) 
Example #12
Source File: cifar10_cls_dataset.py    From imgclsmob with MIT License
def cifar10_train_transform(ds_metainfo,
                            mean_rgb=(0.4914, 0.4822, 0.4465),
                            std_rgb=(0.2023, 0.1994, 0.2010),
                            jitter_param=0.4):
    assert (ds_metainfo is not None)
    assert (ds_metainfo.input_image_size[0] == 32)
    return transforms.Compose([
        transforms.RandomCrop(
            size=32,
            padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(
            brightness=jitter_param,
            contrast=jitter_param,
            saturation=jitter_param),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=mean_rgb,
            std=std_rgb)
    ]) 
Example #13
Source File: utils.py    From WS-DAN.PyTorch with MIT License
def get_transform(resize, phase='train'):
    if phase == 'train':
        return transforms.Compose([
            transforms.Resize(size=(int(resize[0] / 0.875), int(resize[1] / 0.875))),
            transforms.RandomCrop(resize),
            transforms.RandomHorizontalFlip(0.5),
            transforms.ColorJitter(brightness=0.126, saturation=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    else:
        return transforms.Compose([
            transforms.Resize(size=(int(resize[0] / 0.875), int(resize[1] / 0.875))),
            transforms.CenterCrop(resize),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ]) 
Example #14
Source File: utils_model.py    From HistoGAN with GNU General Public License v3.0
def get_data_transforms():
	
	data_transforms = {
	    'train': transforms.Compose([
	        transforms.CenterCrop(config.patch_size),
	        transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.2),
	        transforms.RandomHorizontalFlip(),
	        transforms.RandomVerticalFlip(),
	        Random90Rotation(),
	        transforms.ToTensor(),
	        transforms.Normalize([0.7, 0.6, 0.7], [0.15, 0.15, 0.15]) #mean and standard deviations for lung adenocarcinoma resection slides
	    ]),
	    'val': transforms.Compose([
	        transforms.CenterCrop(config.patch_size),
	        transforms.ToTensor(),
	        transforms.Normalize([0.7, 0.6, 0.7], [0.15, 0.15, 0.15])
	    ]),
	    'unnormalize': transforms.Compose([
	        # the inverse of Normalize(mean=m, std=s) is Normalize(mean=-m/s, std=1/s)
	        transforms.Normalize([-0.7 / 0.15, -0.6 / 0.15, -0.7 / 0.15], [1 / 0.15, 1 / 0.15, 1 / 0.15])
	    ]),
	}

	return data_transforms

Example #15
Source File: datasets.py    From nni with MIT License
def build_train_transform(self, distort_color, resize_scale):
        print('Color jitter: %s' % distort_color)
        if distort_color == 'strong':
            color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
        elif distort_color == 'normal':
            color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
        else:
            color_transform = None
        if color_transform is None:
            train_transforms = transforms.Compose([
                transforms.RandomResizedCrop(self.image_size, scale=(resize_scale, 1.0)),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                self.normalize,
            ])
        else:
            train_transforms = transforms.Compose([
                transforms.RandomResizedCrop(self.image_size, scale=(resize_scale, 1.0)),
                transforms.RandomHorizontalFlip(),
                color_transform,
                transforms.ToTensor(),
                self.normalize,
            ])
        return train_transforms 
Example #16
Source File: preprocessing.py    From pytorch_DoReFaNet with MIT License
def imgnet_transform(is_training=True):
  if is_training:
    transform_list = transforms.Compose([transforms.RandomResizedCrop(224),
                                         transforms.RandomHorizontalFlip(),
                                         transforms.ColorJitter(brightness=0.5,
                                                                contrast=0.5,
                                                                saturation=0.3),
                                         transforms.ToTensor(),
                                         transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                              std=[0.229, 0.224, 0.225])])
  else:
    transform_list = transforms.Compose([transforms.Resize(256),
                                         transforms.CenterCrop(224),
                                         transforms.ToTensor(),
                                         transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                              std=[0.229, 0.224, 0.225])])
  return transform_list 
Example #17
Source File: utils.py    From NAS-Benchmark with GNU General Public License v3.0
def data_transforms_large(dataset, cutout_length):
    dataset = dataset.lower()
    MEAN = [0.485, 0.456, 0.406]
    STD = [0.229, 0.224, 0.225]
    transf_train = [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(
            brightness=0.4,
            contrast=0.4,
            saturation=0.4,
            hue=0.2)
    ]
    transf_val = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
    ]
    normalize = [
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD)
    ]
    train_transform = transforms.Compose(transf_train + normalize)
    valid_transform = transforms.Compose(
        transf_val + normalize)  # FIXME validation is not set to square proportions, is this an issue?

    if cutout_length > 0:
        train_transform.transforms.append(Cutout(cutout_length))

    return train_transform, valid_transform 
Example #18
Source File: dataset.py    From EAST with MIT License
def __getitem__(self, index):
		with open(self.gt_files[index], 'r') as f:
			lines = f.readlines()
		vertices, labels = extract_vertices(lines)
		
		img = Image.open(self.img_files[index])
		img, vertices = adjust_height(img, vertices) 
		img, vertices = rotate_img(img, vertices)
		img, vertices = crop_img(img, vertices, labels, self.length) 
		transform = transforms.Compose([transforms.ColorJitter(0.5, 0.5, 0.5, 0.25),
		                                transforms.ToTensor(),
		                                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
		
		score_map, geo_map, ignored_map = get_score_geo(img, vertices, labels, self.scale, self.length)
		return transform(img), score_map, geo_map, ignored_map 
Example #19
Source File: utils.py    From instance-segmentation-pytorch with GNU General Public License v3.0
def image_random_color_jitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2):
    return transforms.ColorJitter(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue)
Example #20
Source File: train.py    From Holocron with MIT License
def load_data(datadir):
    # Data loading code
    print("Loading data")
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    print("Loading training data")
    st = time.time()
    dataset = VOCDetection(datadir, image_set='train', download=True,
                           transforms=Compose([VOCTargetTransform(classes),
                                              Resize(512), RandomResizedCrop(416), RandomHorizontalFlip(),
                                              convert_to_relative,
                                              ImageTransform(transforms.ColorJitter(brightness=0.3, contrast=0.3,
                                                                                    saturation=0.1, hue=0.02)),
                                              ImageTransform(transforms.ToTensor()), ImageTransform(normalize)]))

    print("Took", time.time() - st)

    print("Loading validation data")
    st = time.time()
    dataset_test = VOCDetection(datadir, image_set='val', download=True,
                                transforms=Compose([VOCTargetTransform(classes),
                                                    Resize(416), CenterCrop(416),
                                                    convert_to_relative,
                                                    ImageTransform(transforms.ToTensor()), ImageTransform(normalize)]))

    print("Took", time.time() - st)
    print("Creating data loaders")
    train_sampler = torch.utils.data.RandomSampler(dataset)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    return dataset, dataset_test, train_sampler, test_sampler 
Example #21
Source File: train.py    From Holocron with MIT License
def load_data(datadir):
    # Data loading code
    print("Loading data")
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    base_size = 320
    crop_size = 256

    min_size = int(0.5 * base_size)
    max_size = int(2.0 * base_size)

    print("Loading training data")
    st = time.time()
    dataset = VOCSegmentation(datadir, image_set='train', download=True,
                              transforms=Compose([RandomResize(min_size, max_size),
                                                  RandomHorizontalFlip(0.5),
                                                  RandomCrop(crop_size),
                                                  SampleTransform(transforms.ColorJitter(brightness=0.3,
                                                                                         contrast=0.3,
                                                                                         saturation=0.1,
                                                                                         hue=0.02)),
                                                  ToTensor(),
                                                  SampleTransform(normalize)]))

    print("Took", time.time() - st)

    print("Loading validation data")
    st = time.time()
    dataset_test = VOCSegmentation(datadir, image_set='val', download=True,
                                   transforms=Compose([RandomResize(base_size, base_size),
                                                       ToTensor(),
                                                       SampleTransform(normalize)]))

    print("Took", time.time() - st)
    print("Creating data loaders")
    train_sampler = torch.utils.data.RandomSampler(dataset)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    return dataset, dataset_test, train_sampler, test_sampler 
Example #22
Source File: charadesrgb.py    From actor-observer with GNU General Public License v3.0
def get(cls, args, scale=(0.08, 1.0)):
        """ Entry point. Call this function to get all Charades dataloaders """
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        train_file = args.train_file
        val_file = args.val_file
        train_dataset = cls(
            args.data, 'train', train_file, args.cache, args.cache_buster,
            transform=transforms.Compose([
                transforms.RandomResizedCrop(args.inputsize, scale),
                transforms.ColorJitter(
                    brightness=0.4, contrast=0.4, saturation=0.4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),  # missing PCA lighting jitter
                normalize,
            ]))
        val_dataset = cls(
            args.data, 'val', val_file, args.cache, args.cache_buster,
            transform=transforms.Compose([
                transforms.Resize(int(256. / 224 * args.inputsize)),
                transforms.CenterCrop(args.inputsize),
                transforms.ToTensor(),
                normalize,
            ]))
        valvideo_dataset = cls(
            args.data, 'val_video', val_file, args.cache, args.cache_buster,
            transform=transforms.Compose([
                transforms.Resize(int(256. / 224 * args.inputsize)),
                transforms.CenterCrop(args.inputsize),
                transforms.ToTensor(),
                normalize,
            ]))
        return train_dataset, val_dataset, valvideo_dataset 
Example #23
Source File: train.py    From Holocron with MIT License
def load_data(traindir, valdir, half=False):
    # Data loading code
    print("Loading data")
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    print("Loading training data")
    st = time.time()
    dataset = torchvision.datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.1, hue=0.02),
            transforms.ToTensor(),
            normalize,
            # transforms.RandomErasing(p=0.9, value='random')
        ]))
    print("Took", time.time() - st)

    print("Loading validation data")
    dataset_test = torchvision.datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize
        ]))

    print("Creating data loaders")
    train_sampler = torch.utils.data.RandomSampler(dataset)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    return dataset, dataset_test, train_sampler, test_sampler 
Example #24
Source File: datasets.py    From amdim-public with MIT License
def __init__(self):
        # random left-right (horizontal) flip
        self.flip_lr = transforms.RandomHorizontalFlip(p=0.5)
        normalize = transforms.Normalize(mean=(0.43, 0.42, 0.39), std=(0.27, 0.26, 0.27))
        # image augmentation functions
        col_jitter = transforms.RandomApply([
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.2)], p=0.8)
        rnd_gray = transforms.RandomGrayscale(p=0.25)
        rand_crop = \
            transforms.RandomResizedCrop(64, scale=(0.3, 1.0), ratio=(0.7, 1.4),
                                         interpolation=INTERP)

        self.test_transform = transforms.Compose([
            transforms.Resize(70, interpolation=INTERP),
            transforms.CenterCrop(64),
            transforms.ToTensor(),
            normalize
        ])

        self.train_transform = transforms.Compose([
            rand_crop,
            col_jitter,
            rnd_gray,
            transforms.ToTensor(),
            normalize
        ]) 
Example #25
Source File: model.py    From derplearning with MIT License
def compose_transforms(transform_config):
    """ Apply all image transforms """
    transform_list = []
    for perturb_config in transform_config:
        if perturb_config["name"] == "colorjitter":
            transform = transforms.ColorJitter(
                brightness=perturb_config["brightness"],
                contrast=perturb_config["contrast"],
                saturation=perturb_config["saturation"],
                hue=perturb_config["hue"],
            )
            transform_list.append(transform)
    transform_list.append(transforms.ToTensor())
    return transforms.Compose(transform_list) 
Example #26
Source File: transforms.py    From RCRNet-Pytorch with MIT License
def __init__(self, image_mode, **kwargs):
        super(ColorJitter, self).__init__(**kwargs)
        self.transform = None
        self.image_mode = image_mode 
Example #27
Source File: imagenet.py    From nni with MIT License
def _imagenet_dataset(config):
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_dir = os.path.join(config.data_dir, "train")
    test_dir = os.path.join(config.data_dir, "val")
    if hasattr(config, "use_aa") and config.use_aa:
        train_data = dset.ImageFolder(
            train_dir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                ImageNetPolicy(),
                transforms.ToTensor(),
                normalize,
            ]))
    else:
        train_data = dset.ImageFolder(
            train_dir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(
                    brightness=0.4,
                    contrast=0.4,
                    saturation=0.4,
                    hue=0.2),
                transforms.ToTensor(),
                normalize,
            ]))

    test_data = dset.ImageFolder(
        test_dir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))

    return train_data, test_data 
Example #28
Source File: augmentation.py    From nni with MIT License
def test_color_trans():
    img_id = '00abc623a.jpg'
    img = Image.open(os.path.join(settings.TRAIN_IMG_DIR, img_id)).convert('RGB')
    trans = ColorJitter(0.1, 0.1, 0.1, 0.1)

    img2 = trans(img)
    img.show()
    img2.show() 
Example #29
Source File: loader.py    From nni with MIT License
def get_train_loaders(ifold, batch_size=8, dev_mode=False, pad_mode='edge', meta_version=1, pseudo_label=False, depths=False):
    train_shuffle = True
    train_meta, val_meta = get_nfold_split(ifold, nfold=10, meta_version=meta_version)

    if pseudo_label:
        test_meta = get_test_meta()
        train_meta = train_meta.append(test_meta, sort=True)

    if dev_mode:
        train_shuffle = False
        train_meta = train_meta.iloc[:10]
        val_meta = val_meta.iloc[:10]
    #print(val_meta[X_COLUMN].values[:5])
    #print(val_meta[Y_COLUMN].values[:5])
    print(train_meta.shape, val_meta.shape)
    img_mask_aug_train, img_mask_aug_val = get_img_mask_augments(pad_mode, depths)

    train_set = ImageDataset(True, train_meta,
                            augment_with_target=img_mask_aug_train,
                            image_augment=transforms.ColorJitter(0.2, 0.2, 0.2, 0.2),
                            image_transform=get_image_transform(pad_mode),
                            mask_transform=get_mask_transform(pad_mode))

    train_loader = data.DataLoader(train_set, batch_size=batch_size, shuffle=train_shuffle, num_workers=4, collate_fn=train_set.collate_fn, drop_last=True)
    train_loader.num = len(train_set)

    val_set = ImageDataset(True, val_meta,
                            augment_with_target=img_mask_aug_val,
                            image_augment=None,
                            image_transform=get_image_transform(pad_mode),
                            mask_transform=get_mask_transform(pad_mode))
    val_loader = data.DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=4, collate_fn=val_set.collate_fn)
    val_loader.num = len(val_set)
    val_loader.y_true = read_masks(val_meta[ID_COLUMN].values)

    return train_loader, val_loader 
Example #30
Source File: face_sketch_data.py    From Face-Sketch-Wild with MIT License
def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, sharp=0.0):
        super(ColorJitter, self).__init__(brightness, contrast, saturation, hue)
        self.sharp = sharp
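The snippet above shows only the constructor of the project's ColorJitter subclass; its call method is not excerpted here. As a hedged sketch (the class name, the sampling range, and the use of PIL's ImageEnhance are illustrative assumptions, not the project's actual code), the extra sharp parameter could be applied after the inherited jitter like this:

import random

from PIL import ImageEnhance
from torchvision import transforms

class ColorJitterWithSharpness(transforms.ColorJitter):
    """Hypothetical: torchvision's ColorJitter plus a random sharpness adjustment."""

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, sharp=0.0):
        super().__init__(brightness, contrast, saturation, hue)
        self.sharp = sharp

    def __call__(self, img):
        img = super().__call__(img)  # inherited brightness/contrast/saturation/hue jitter
        if self.sharp > 0:
            # sample a sharpness factor around 1.0 (a factor of 1.0 leaves the image unchanged)
            factor = random.uniform(max(0.0, 1.0 - self.sharp), 1.0 + self.sharp)
            img = ImageEnhance.Sharpness(img).enhance(factor)
        return img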