#!/usr/bin/python
# -*- coding: utf-8 -*-
# 
# Developed by Shangchen Zhou <shangchenzhou@gmail.com>
'''ref: http://pytorch.org/docs/master/torchvision/transforms.html'''


import cv2
import numpy as np
import torch
import torchvision.transforms.functional as F
from config import cfg
from PIL import Image
import random
import numbers
class Compose(object):
    """Chain several paired transforms into one callable.

    Every transform receives ``(seq_blur, seq_clear)`` and must return the
    updated pair; transforms are applied in list order.

    Example::

        transform = Compose([
            CenterCrop((256, 256)),
            ToTensor(),
        ])
        seq_blur, seq_clear = transform(seq_blur, seq_clear)
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, seq_blur, seq_clear):
        """Run the blurry/sharp pair through every transform in turn."""
        blur, clear = seq_blur, seq_clear
        for step in self.transforms:
            blur, clear = step(blur, clear)
        return blur, clear


class ColorJitter(object):
    """Randomly jitter brightness, contrast, saturation and hue.

    One set of jitter factors is sampled per call and applied to every frame
    of both ``seq_blur`` and ``seq_clear``, so the blurry and sharp sequences
    receive the same photometric change.

    Input: sequences of H x W x C image arrays (converted to PIL internally).
    Output: lists of ``uint8`` numpy arrays.
    """

    def __init__(self, color_adjust_para):
        """Store the jitter strengths.

        Args:
            color_adjust_para: 4-sequence of scalars
                ``(brightness, contrast, saturation, hue)``.
                The brightness, contrast and saturation factors are drawn
                uniformly from ``[max(0, 1 - v), 1 + v]``; the hue shift is
                drawn uniformly from ``[-hue, hue]``. A value <= 0 disables
                that adjustment. torchvision's ``adjust_hue`` rejects shifts
                outside ``[-0.5, 0.5]``, so keep ``hue <= 0.5``.
        """
        self.brightness, self.contrast, self.saturation, self.hue = color_adjust_para

    def __call__(self, seq_blur, seq_clear):
        """Apply the same randomly sampled colour jitter to both sequences.

        Returns:
            ``(seq_blur, seq_clear)`` as lists of writable ``uint8`` arrays.
        """
        # Clip before casting. Converting out-of-range floats to uint8 is
        # undefined in NumPy (values typically wrap, e.g. 256.0 -> 0). For
        # inputs already in [0, 255] this is identical to a plain cast.
        seq_blur  = [Image.fromarray(np.uint8(np.clip(img, 0, 255))) for img in seq_blur]
        seq_clear = [Image.fromarray(np.uint8(np.clip(img, 0, 255))) for img in seq_clear]

        # Each factor is sampled once and shared by all frames of both sequences.
        if self.brightness > 0:
            brightness_factor = np.random.uniform(max(0, 1 - self.brightness), 1 + self.brightness)
            seq_blur  = [F.adjust_brightness(img, brightness_factor) for img in seq_blur]
            seq_clear = [F.adjust_brightness(img, brightness_factor) for img in seq_clear]

        if self.contrast > 0:
            contrast_factor = np.random.uniform(max(0, 1 - self.contrast), 1 + self.contrast)
            seq_blur  = [F.adjust_contrast(img, contrast_factor) for img in seq_blur]
            seq_clear = [F.adjust_contrast(img, contrast_factor) for img in seq_clear]

        if self.saturation > 0:
            saturation_factor = np.random.uniform(max(0, 1 - self.saturation), 1 + self.saturation)
            seq_blur  = [F.adjust_saturation(img, saturation_factor) for img in seq_blur]
            seq_clear = [F.adjust_saturation(img, saturation_factor) for img in seq_clear]

        if self.hue > 0:
            hue_factor = np.random.uniform(-self.hue, self.hue)
            seq_blur  = [F.adjust_hue(img, hue_factor) for img in seq_blur]
            seq_clear = [F.adjust_hue(img, hue_factor) for img in seq_clear]

        # np.array (not np.asarray) returns a writable copy. An array viewed
        # directly from a PIL image may be read-only, and torch.from_numpy
        # warns about that. The pixels are already uint8, so no clip is needed.
        seq_blur  = [np.array(img) for img in seq_blur]
        seq_clear = [np.array(img) for img in seq_clear]

        return seq_blur, seq_clear

class RandomColorChannel(object):
    """Shuffle the three colour channels with one random permutation.

    The permutation is drawn once per call (``np.random.permutation(3)``) and
    applied along axis 2 of every frame in both sequences.
    """

    def __call__(self, seq_blur, seq_clear):
        order = np.random.permutation(3)

        def shuffle(frames):
            return [frame[:, :, order] for frame in frames]

        return shuffle(seq_blur), shuffle(seq_clear)

class RandomGaussianNoise(object):
    """Add Gaussian noise to the blurry frames only, then clip to [0, 1].

    The clip to [0, 1] suggests the frames are expected to be already scaled
    to that range. NOTE(review): confirm this transform runs after
    normalisation in the pipeline.
    """

    def __init__(self, gaussian_para):
        """Args:
            gaussian_para: indexable ``(mean, std)`` of the noise distribution.
        """
        self.mu, self.std_var = gaussian_para[0], gaussian_para[1]

    def __call__(self, seq_blur, seq_clear):
        # A single noise field, shaped like the first blurry frame, is reused
        # for every blurry frame. NOTE(review): confirm per-sequence (rather
        # than per-frame) noise is intended.
        noise = np.random.normal(self.mu, self.std_var, seq_blur[0].shape)
        noisy_blur = [(frame + noise).clip(0, 1) for frame in seq_blur]
        # The sharp sequence is returned untouched.
        return noisy_blur, seq_clear

class Normalize(object):
    """Rescale frames as ``frame / std - mean``.

    The division by ``std`` happens first, so ``mean`` is expressed in the
    rescaled units. This is not the ``(frame - mean) / std`` convention.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, seq_blur, seq_clear):
        mean, std = self.mean, self.std

        def rescale(frame):
            return frame / std - mean

        return [rescale(f) for f in seq_blur], [rescale(f) for f in seq_clear]

class CenterCrop(object):
    """Crop the same central window from every frame of both sequences."""

    def __init__(self, crop_size):
        """Store the output size.

        Args:
            crop_size: ``(height, width)`` of the cropped window.
        """
        self.crop_size_h = crop_size[0]
        self.crop_size_w = crop_size[1]

    def __call__(self, seq_blur, seq_clear):
        """Return both sequences cropped to the central window.

        The window position is computed from the first blurry frame and
        reused for every frame, keeping the pair spatially aligned.

        Raises:
            ValueError: if the frame is smaller than the crop size. Without
                this check a negative start index would silently wrap around
                and return a wrongly sized crop.
        """
        input_size_h, input_size_w, _ = seq_blur[0].shape
        if input_size_h < self.crop_size_h or input_size_w < self.crop_size_w:
            raise ValueError(
                'CenterCrop: input size ({}, {}) is smaller than crop size ({}, {})'.format(
                    input_size_h, input_size_w, self.crop_size_h, self.crop_size_w))

        x_start = int(round((input_size_w - self.crop_size_w) / 2.))
        y_start = int(round((input_size_h - self.crop_size_h) / 2.))

        seq_blur  = [img[y_start: y_start + self.crop_size_h, x_start: x_start + self.crop_size_w] for img in seq_blur]
        seq_clear = [img[y_start: y_start + self.crop_size_h, x_start: x_start + self.crop_size_w] for img in seq_clear]

        return seq_blur, seq_clear

class RandomCrop(object):
    """Crop the same randomly placed window from every frame of both sequences."""

    def __init__(self, crop_size):
        """Store the output size.

        Args:
            crop_size: ``(height, width)`` of the cropped window.
        """
        self.crop_size_h = crop_size[0]
        self.crop_size_w = crop_size[1]

    def __call__(self, seq_blur, seq_clear):
        """Return both sequences cropped at one random offset.

        The offset is drawn once with Python's ``random`` module (x first,
        then y), based on the first blurry frame's size. It is shared by
        every frame so the blurry/sharp pair stays spatially aligned.

        Raises:
            ValueError: if the frame is smaller than the crop size.
        """
        input_size_h, input_size_w, _ = seq_blur[0].shape
        if input_size_h < self.crop_size_h or input_size_w < self.crop_size_w:
            raise ValueError(
                'RandomCrop: input size ({}, {}) is smaller than crop size ({}, {})'.format(
                    input_size_h, input_size_w, self.crop_size_h, self.crop_size_w))

        # randint's upper bound is inclusive, so the last valid offset can be drawn.
        x_start = random.randint(0, input_size_w - self.crop_size_w)
        y_start = random.randint(0, input_size_h - self.crop_size_h)

        seq_blur  = [img[y_start: y_start + self.crop_size_h, x_start: x_start + self.crop_size_w] for img in seq_blur]
        seq_clear = [img[y_start: y_start + self.crop_size_h, x_start: x_start + self.crop_size_w] for img in seq_clear]

        return seq_blur, seq_clear

class RandomHorizontalFlip(object):
    """With probability 0.5, flip every frame of both sequences left-right.

    Frames are numpy arrays (flipped with ``np.fliplr``). The flipped views
    are copied so the results have positive strides, which
    ``torch.from_numpy`` requires.
    """

    def __call__(self, seq_blur, seq_clear):
        if random.random() >= 0.5:
            return seq_blur, seq_clear
        flipped_blur = [np.copy(np.fliplr(frame)) for frame in seq_blur]
        flipped_clear = [np.copy(np.fliplr(frame)) for frame in seq_clear]
        return flipped_blur, flipped_clear


class RandomVerticalFlip(object):
    """With probability 0.5, flip every frame of both sequences up-down.

    Frames are numpy arrays (flipped with ``np.flipud``). The flipped views
    are copied so the results have positive strides.
    """

    def __call__(self, seq_blur, seq_clear):
        if random.random() >= 0.5:
            return seq_blur, seq_clear
        flipped_blur = [np.copy(np.flipud(frame)) for frame in seq_blur]
        flipped_clear = [np.copy(np.flipud(frame)) for frame in seq_clear]
        return flipped_blur, flipped_clear


class ToTensor(object):
    """Convert H x W x C numpy frames into C x H x W ``float32`` torch tensors."""

    def __call__(self, seq_blur, seq_clear):
        # Reorder axes HWC -> CHW for every frame (blur first, then clear).
        blur_chw, clear_chw = ([np.transpose(frame, (2, 0, 1)) for frame in frames]
                               for frames in (seq_blur, seq_clear))

        def as_float_tensor(array):
            return torch.from_numpy(array).float()

        return ([as_float_tensor(a) for a in blur_chw],
                [as_float_tensor(a) for a in clear_chw])