import torch
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import os


# ImageNet dataset with Inception-style preprocessing (per-channel mean = std = 0.5).

class ImageNet(datasets.ImageFolder):
    """Image-folder dataset for ImageNet with Inception-style preprocessing.

    Samples are read from ``<root>/train`` when ``train`` is true and from
    ``<root>/val`` otherwise; each split directory is handed to
    ``torchvision.datasets.ImageFolder``.  Unless the caller supplies its own
    ``transform``, the pipeline built by :meth:`preprocess` is used.
    """

    # Per-channel statistics passed to ``transforms.Normalize`` by the
    # default pipelines built in ``preprocess``.
    mean = [0.5, 0.5, 0.5]
    std = [0.5, 0.5, 0.5]

    def __init__(self, root, train=True, image_size=224, transform=None, target_transform=None):
        """Set up the requested split.

        Args:
            root: Directory containing the ``train`` and ``val``
                sub-directories; a leading ``~`` is expanded.
            train: Selects the ``train`` split when true, ``val`` otherwise.
                Also selects which default pipeline ``preprocess`` builds.
            image_size: Side length, in pixels, of the crop produced by the
                default pipelines.
            transform: Optional callable applied to each image.  When
                ``None``, ``self.preprocess()`` is used instead.
            target_transform: Optional callable forwarded unchanged to
                ``ImageFolder``.
        """
        # NOTE(review): the base-class constructor is expected to reassign
        # ``self.root`` to the split directory passed below -- confirm against
        # the installed torchvision version.
        self.root = os.path.expanduser(root)
        self.train = train
        self.image_size = image_size

        # Compare against None explicitly: a caller-supplied transform that
        # happens to be falsy (e.g. an empty ``torch.nn.Sequential()``) must
        # not be silently replaced by the default pipeline.
        if transform is None:
            transform = self.preprocess()

        split_dir = os.path.join(self.root, 'train' if train else 'val')
        super(ImageNet, self).__init__(split_dir, transform=transform,
                                       target_transform=target_transform)

    def preprocess(self):
        """Build the default transform pipeline for the current split.

        Reads ``self.train`` and ``self.image_size``.

        Returns:
            A ``transforms.Compose``.  For the training split: random resized
            crop to ``image_size``, random horizontal flip, colour jitter,
            tensor conversion and normalisation with ``mean``/``std``.  For
            the validation split: resize to a square of side
            ``int(image_size / 0.875)``, centre crop to ``image_size``,
            tensor conversion and the same normalisation.
        """
        if self.train:
            return transforms.Compose([
                transforms.RandomResizedCrop(self.image_size),
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2),
                transforms.ToTensor(),
                transforms.Normalize(self.mean, self.std),
            ])

        # The centre crop keeps 87.5% of each side of the resized image
        # (e.g. 224 is cropped out of 256).  Passing an explicit (h, w) pair
        # makes Resize produce a square regardless of the input aspect ratio.
        resized_edge = int(self.image_size / 0.875)
        return transforms.Compose([
            transforms.Resize((resized_edge, resized_edge)),
            transforms.CenterCrop(self.image_size),
            transforms.ToTensor(),
            transforms.Normalize(self.mean, self.std),
        ])