from __future__ import division
import torch
import math
import random
from PIL import Image, ImageOps
import numpy as np
import numbers
import types
import scipy.ndimage as ndimage
try:
    import accimage
except:
    accimage = None


'''Set of random transform routines that take both input and target as arguments,
in order to apply random but coherent transformations to both.
img is a numpy ndarray (a pair of ndarrays for the crop/scale routines) and
target is an ndarray.'''


class Compose(object):
    """Run a sequence of co-transforms one after another.

    Every transform is called with the ``(img, target)`` pair returned by the
    previous one, so both arguments go through the same chain.

    Example::

        co_transforms.Compose([
            co_transforms.CenterCrop(10),
            co_transforms.ToTensor(),
        ])
    """

    def __init__(self, co_transforms):
        # Kept as given; iterated in order on every call.
        self.co_transforms = co_transforms

    def __call__(self, img, target):
        current_img, current_target = img, target
        for step in self.co_transforms:
            # Unpack at every step: each transform must hand back exactly two items.
            current_img, current_target = step(current_img, current_target)
        return current_img, current_target

class Lambda(object):
    """Apply a user-supplied callable as a co-transform.

    Args:
        lambd: callable invoked as ``lambd(input, target)``; its return value
            is returned unchanged (normally the transformed ``(input, target)``).

    Raises:
        TypeError: if ``lambd`` is not callable.
    """

    def __init__(self, lambd):
        # Accept any callable (plain functions, functools.partial, bound
        # methods, callable objects). An explicit raise, unlike ``assert``,
        # is not stripped when Python runs with -O.
        if not callable(lambd):
            raise TypeError("Lambda expects a callable, got %r" % (type(lambd).__name__,))
        self.lambd = lambd

    def __call__(self, input, target):
        return self.lambd(input, target)


class CenterCrop(object):
    """Cut a centred window of a fixed size out of an image pair and its target.

    size: either an int (square ``size x size`` window) or a
    ``(target_height, target_width)`` tuple.

    ``img`` is a pair of HxWxC arrays that may differ in size, so each one is
    centred independently. ``target`` is cut with the window of ``img[0]``.
    ``img`` is updated in place (both entries are reassigned) and also returned.
    """

    def __init__(self, size):
        self.size = (int(size), int(size)) if isinstance(size, numbers.Number) else size

    def __call__(self, img, target):
        rows1, cols1, _ = img[0].shape
        rows2, cols2, _ = img[1].shape
        th, tw = self.size

        def corner(height, width):
            # Top-left (row, col) of a th x tw window centred in height x width.
            return int(round((height - th) / 2.)), int(round((width - tw) / 2.))

        top1, left1 = corner(rows1, cols1)
        top2, left2 = corner(rows2, cols2)

        first_window = (slice(top1, top1 + th), slice(left1, left1 + tw))
        img[0] = img[0][first_window]
        img[1] = img[1][top2:top2 + th, left2:left2 + tw]
        return img, target[first_window]


class Scale(object):
    """Rescale an image pair and its target so the smaller spatial edge equals ``size``.

    For example, if height > width, the arrays are rescaled to
    (size * height / width, size). Only the two spatial axes (0 and 1) are
    zoomed; a trailing channel axis keeps its length.

    The target's values are also multiplied by the zoom ratio. This is
    presumably because targets are displacement (flow) fields, whose vectors
    grow with the image. Confirm this holds for your targets.

    size: length of the smaller spatial edge after rescaling.
    order: spline interpolation order for scipy.ndimage.zoom.
        Default: 2 (quadratic; 1 is bilinear, 0 nearest-neighbour).

    ``img`` is a pair of HxWxC arrays. It is updated in place (both entries are
    reassigned) and also returned.
    """

    def __init__(self, size, order=2):
        self.size = size
        self.order = order

    def __call__(self, img, target):
        h, w, _ = img[0].shape
        if (w <= h and w == self.size) or (h <= w and h == self.size):
            return img, target
        ratio = self.size / w if w < h else self.size / h

        def zoom_spatial(arr):
            # A scalar zoom factor would rescale every axis, including the
            # channel axis (3 channels -> 3 * ratio); zoom H and W only.
            factors = (ratio, ratio) + (1,) * (np.ndim(arr) - 2)
            return ndimage.zoom(arr, factors, order=self.order)

        img[0] = zoom_spatial(img[0])
        img[1] = zoom_spatial(img[1])

        target = zoom_spatial(target)
        target *= ratio
        return img, target


class RandomCrop(object):
    """Crop an image pair and its target at one shared random location.

    size: either an int (square ``size x size`` crop) or a
    ``(target_height, target_width)`` tuple.

    ``img`` is a pair of HxWxC ndarrays with the same spatial size as ``target``.
    The crop's top-left corner is drawn uniformly so the window fits inside
    ``img[0]``. The inputs are returned untouched when they already have the
    requested size. ``img`` is updated in place (both entries are reassigned).
    """

    def __init__(self, size):
        if not isinstance(size, numbers.Number):
            self.size = size
        else:
            side = int(size)
            self.size = (side, side)

    def __call__(self, img, target):
        height, width, _ = img[0].shape
        crop_h, crop_w = self.size
        if (height, width) == (crop_h, crop_w):
            return img, target

        # Column offset is drawn before the row offset.
        left = random.randint(0, width - crop_w)
        top = random.randint(0, height - crop_h)
        rows = slice(top, top + crop_h)
        cols = slice(left, left + crop_w)
        for idx in (0, 1):
            img[idx] = img[idx][rows, cols]
        return img, target[rows, cols]


class RandomHorizontalFlip(object):
    """Randomly mirror ``img`` and ``target`` left-right, together.

    With probability ``p`` (default 0.5), both ndarrays are flipped along
    axis 1 (``np.fliplr``). Otherwise they are returned unchanged. ``img`` is
    expected to be a single ndarray: a list of frames would first be stacked
    into one array, so the flip would act on the wrong axis. No sign
    correction is applied to ``target``.
    """

    def __init__(self, p=0.5):
        # Probability of flipping; 0.5 reproduces the previous fixed behaviour.
        self.p = p

    def __call__(self, img, target):
        if random.random() < self.p:
            # np.fliplr returns a negative-stride view. Copy it so consumers
            # that require positive strides (e.g. torch.from_numpy) accept it.
            img = np.fliplr(img).copy()
            target = np.fliplr(target).copy()
        return img, target


class RandomVerticalFlip(object):
    """Randomly flip ``img`` and ``target`` upside down, together.

    With probability ``p`` (default 0.5), both ndarrays are flipped along
    axis 0 (``np.flipud``). Otherwise they are returned unchanged. ``img`` is
    expected to be a single ndarray: a list of frames would first be stacked
    into one array, so the flip would act on the wrong axis. No sign
    correction is applied to ``target``.
    """

    def __init__(self, p=0.5):
        # Probability of flipping; 0.5 reproduces the previous fixed behaviour.
        self.p = p

    def __call__(self, img, target):
        if random.random() < self.p:
            # np.flipud returns a negative-stride view. Copy it so consumers
            # that require positive strides (e.g. torch.from_numpy) accept it.
            img = np.flipud(img).copy()
            target = np.flipud(target).copy()
        return img, target


class RandomRotate(object):
    """Rotate ``img`` and a mask ``target`` by the same random angle.

    An angle is drawn uniformly from [-angle, angle] degrees, shifted by half
    of a second random draw ``diff`` from [-diff_angle, diff_angle]. That angle
    is applied to ``img`` (a single array, rotated by scipy.ndimage.rotate in
    the plane of its first two axes) and to ``target``.

    angle: max absolute rotation angle, in degrees.
    diff_angle: Default: 0. Amplitude of ``diff``. Besides shifting the angle,
        it drives an additive correction map (see NOTE in __call__).
    order: spline interpolation order for scipy.ndimage.rotate.
        Default: 2 (quadratic; 1 is bilinear, 0 nearest-neighbour).
    reshape: Default: False. If True, the output grows so every input pixel is kept.

    Returns ``(img, target)``. Here ``target`` is a 2-D uint8 array built from
    channel 0 of the rotated target: values are clipped to [0, 255], and
    anything above 1 is then set to 255.
    """

    def __init__(self, angle, diff_angle=0, order=2, reshape=False):
        self.angle = angle
        self.reshape = reshape
        self.order = order
        self.diff_angle = diff_angle

    def __call__(self, img, target):
        applied_angle = random.uniform(-self.angle, self.angle)
        diff = random.uniform(-self.diff_angle, self.diff_angle)
        angle1 = applied_angle - diff / 2

        # Work on an (H, W, C) array.
        if np.ndim(target) == 2:
            target = np.expand_dims(target, 2)

        h, w, _ = target.shape

        def rotate_flow(i, j, k):
            return -k * (j - w / 2) * (diff * np.pi / 180) + (1 - k) * (i - h / 2) * (diff * np.pi / 180)

        # NOTE(review): this additive map is an optical-flow correction for
        # rotating two frames by different angles. It is all zeros when
        # diff_angle == 0 (the default) but perturbs mask values otherwise.
        # Float64 also keeps interpolation from being truncated to the input dtype.
        target = target.astype(np.float64) + np.fromfunction(rotate_flow, target.shape)

        img = ndimage.rotate(img, angle1, reshape=self.reshape, order=self.order)
        target = ndimage.rotate(target, angle1, reshape=self.reshape, order=self.order)

        # Channel 0 is a scalar mask, so there is no vector component to rotate;
        # it is used as-is. The previous cos+sin mix of channel 0 with itself
        # scaled values below 1 for negative angles, erasing 0/1 masks.
        # Clip before casting: spline interpolation can under/overshoot, and
        # casting out-of-range floats to uint8 does not saturate.
        mask = np.clip(target[:, :, 0], 0, 255).astype(np.uint8)
        mask[mask > 1] = 255
        return img, mask


#class RandomCropRotate(object):
#    """Random rotation of the image from -angle to angle (in degrees)
#    A crop is done to keep same image ratio, and no black pixels
#    angle: max angle of the rotation, cannot be more than 180 degrees
#    interpolation order: Default: 2 (bilinear)
#    """
#
#    def __init__(self, angle, size, diff_angle=0, order=2):
#        self.angle = angle
#        self.order = order
#        self.diff_angle = diff_angle
#        self.size = size
#
#    def __call__(self, img, target):
#        applied_angle = random.uniform(-self.angle, self.angle)
#        diff = random.uniform(-self.diff_angle, self.diff_angle)
#        angle1 = applied_angle - diff / 2
#        angle2 = applied_angle + diff / 2
#
#        angle1_rad = angle1 * np.pi / 180
#        angle2_rad = angle2 * np.pi / 180
#
#        h, w, _ = img[0].shape
#
#        def rotate_flow(i, j, k):
#            return -k * (j - w / 2) * (diff * np.pi / 180) + (1 - k) * (i - h / 2) * (diff * np.pi / 180)
#
#        rotate_flow_map = np.fromfunction(rotate_flow, target.shape)
#        target += rotate_flow_map
#
#        img[0] = ndimage.interpolation.rotate(img[0], angle1, reshape=True, order=self.order)
#        img[1] = ndimage.interpolation.rotate(img[1], angle2, reshape=True, order=self.order)
#        target = ndimage.interpolation.rotate(target, angle1, reshape=True, order=self.order)
#        # flow vectors must be rotated too!
#        target_ = np.array(target, copy=True)
#        target[:, :, 0] = np.cos(angle1_rad) * target_[:, :, 0] - np.sin(angle1_rad) * target_[:, :, 1]
#        target[:, :, 1] = np.sin(angle1_rad) * target_[:, :, 0] + np.cos(angle1_rad) * target_[:, :, 1]
#
#        # keep angle1 and angle2 within [0,pi/2] with a reflection at pi/2: -1rad is 1rad, 2rad is pi - 2 rad
#        angle1_rad = np.pi / 2 - np.abs(angle1_rad % np.pi - np.pi / 2)
#        angle2_rad = np.pi / 2 - np.abs(angle2_rad % np.pi - np.pi / 2)
#
#        c1 = np.cos(angle1_rad)
#        s1 = np.sin(angle1_rad)
#        c2 = np.cos(angle2_rad)
#        s2 = np.sin(angle2_rad)
#        c_diag = h / np.sqrt(h * h + w * w)
#        s_diag = w / np.sqrt(h * h + w * w)
#
#        ratio = c_diag / max(c1 * c_diag + s1 * s_diag, c2 * c_diag + s2 * s_diag)
#
#        crop = CenterCrop((int(h * ratio), int(w * ratio)))
#        scale = Scale(self.size)
#        img, target = crop(img, target)
#        return scale(img, target)
#
#
#class RandomTranslate(object):
#    def __init__(self, translation):
#        if isinstance(translation, numbers.Number):
#            self.translation = (int(translation), int(translation))
#        else:
#            self.translation = translation
#
#    def __call__(self, img, target):
#        h, w, _ = img[0].shape
#        th, tw = self.translation
#        tw = random.randint(-tw, tw)
#        th = random.randint(-th, th)
#        if tw == 0 and th == 0:
#            return img, target
#        # compute x1,x2,y1,y2 for img1 and target, and x3,x4,y3,y4 for img2
#        x1, x2, x3, x4 = max(0, tw), min(w + tw, w), max(0, -tw), min(w - tw, w)
#        y1, y2, y3, y4 = max(0, th), min(h + th, h), max(0, -th), min(h - th, h)
#
#        img[0] = img[0][y1:y2, x1:x2]
#        img[1] = img[1][y3:y4, x3:x4]
#        target = target[y1:y2, x1:x2]
#        target[:, :, 0] += tw
#        target[:, :, 1] += th
#
#        return img, target