import torch
from torchvision import transforms
from torch.autograd import Variable

class NormalizeImageDict(object):
    """
    
    Normalizes Tensor images in dictionary
    
    Args:
        image_keys (list): dict. keys of the images to be normalized
        normalizeRange (bool): if True the image is divided by 255.0s
    
    """
    
    def __init__(self,image_keys,normalizeRange=True):
        self.image_keys = image_keys
        self.normalizeRange=normalizeRange
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
        
    def __call__(self, sample):
        for key in self.image_keys:
            if self.normalizeRange:
                sample[key] /= 255.0                
            sample[key] = self.normalize(sample[key])
        return  sample
    
def normalize_image(image, forward=True, mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]):
        im_size = image.size()
        mean=torch.FloatTensor(mean).unsqueeze(1).unsqueeze(2)
        std=torch.FloatTensor(std).unsqueeze(1).unsqueeze(2)
        if image.is_cuda:
            mean = mean.cuda()
            std = std.cuda()
        if isinstance(image,torch.autograd.variable.Variable):
            mean = Variable(mean,requires_grad=False)
            std = Variable(std,requires_grad=False)
        if forward:
            if len(im_size)==3:
                result = image.sub(mean.expand(im_size)).div(std.expand(im_size))
            elif len(im_size)==4:
                result = image.sub(mean.unsqueeze(0).expand(im_size)).div(std.unsqueeze(0).expand(im_size))
        else:
            if len(im_size)==3:
                result = image.mul(std.expand(im_size)).add(mean.expand(im_size))
            elif len(im_size)==4:
                result = image.mul(std.unsqueeze(0).expand(im_size)).add(mean.unsqueeze(0).expand(im_size))
                
        return  result