import numpy as np
from scipy.ndimage import rotate as scp_rotate
import cv2
import torch
from PIL import Image
from torchvision.transforms import functional as TF
from skimage.morphology import skeletonize


class Representation(object):
    """
    Intermediate representation object.

    Base class wrapping a payload (``self.data``) and a ``name``. The default
    augmentations below operate on a numpy array whose first two axes are
    treated as (rows, columns); subclasses override them when the payload
    needs special handling.
    """

    def __init__(self, data=None, name=None):
        self.data = data
        self.name = name

    def set_data(self, data):
        """Replace the wrapped payload."""
        self.data = data

    def shape(self):
        """Return ``self.data.shape``."""
        return self.data.shape

    def rotate(self, angle, cval=0):
        """
        Rotate ``self.data`` by ``angle`` degrees in the plane of its first two axes.

        Uses nearest-neighbour sampling (``order=0``), keeps the original shape
        (``reshape=False``) and fills out-of-frame pixels with ``mode='wrap'``.
        ``cval`` is accepted for interface compatibility with subclasses but is
        not used by this wrap mode.
        """
        self.data = scp_rotate(self.data, angle, reshape=False, order=0, mode='wrap', prefilter=False)

    def scale(self, ratio, interpolation='NEAREST'):
        """
        Resize ``self.data`` by ``ratio`` with ``cv2.resize``.

        The target size is ``int(ratio * w) x int(ratio * h)`` (truncated).
        ``interpolation='NEAREST'`` selects ``cv2.INTER_NEAREST``; any other
        value selects ``cv2.INTER_LINEAR`` when shrinking (``ratio < 1``) and
        ``cv2.INTER_CUBIC`` otherwise.
        """
        h, w = self.data.shape[:2]
        tw = int(ratio * w)
        th = int(ratio * h)

        if interpolation == 'NEAREST':
            interpolation = cv2.INTER_NEAREST
        elif ratio < 1:
            interpolation = cv2.INTER_LINEAR
        else:
            interpolation = cv2.INTER_CUBIC

        self.data = cv2.resize(self.data, dsize=(tw, th), interpolation=interpolation)

    def crop(self, x1, y1, tw, th):
        """Keep rows ``y1:y1 + th`` and columns ``x1:x1 + tw`` (numpy slicing, so a view)."""
        self.data = self.data[y1:y1 + th, x1:x1 + tw]

    def fliplr(self):
        """Mirror ``self.data`` along its second axis (``np.fliplr``)."""
        self.data = np.fliplr(self.data)

    def to_tensor(self):
        """Replace ``self.data`` with an int64 ``torch.LongTensor`` copy of it."""
        # ``np.int`` was removed in NumPy 1.24; ``np.int64`` is also the exact
        # element type ``LongTensor`` stores.
        self.data = torch.LongTensor(np.array(self.data, dtype=np.int64))

    def normalize(self):
        """Placeholder: performs no normalisation and returns 1."""
        return 1


class InputImage(Representation):
    """
    Image class.

    Before ``to_tensor`` the payload is used through the PIL API (``size``,
    ``resize``, ``transpose``, ``crop``); after ``to_tensor`` it is a float
    tensor scaled by 1/255.
    """

    def __init__(self, data):
        super(InputImage, self).__init__(data=data, name='Image')

    def to_tensor(self):
        """
        Convert ``self.data`` to a float tensor divided by 255.

        A numpy array is wrapped as-is, with no axis reordering.
        NOTE(review): confirm that callers pass numpy input already in the
        layout they expect.

        A PIL image is read from its raw bytes, viewed as (H, W, C) and
        transposed to (C, H, W). The view assumes one byte per channel value;
        modes with wider or packed pixels would not match the element count.
        """
        if isinstance(self.data, np.ndarray):
            # handle numpy array
            img = torch.from_numpy(self.data)
        else:
            # handle PIL Image. Go through numpy instead of the deprecated
            # ``torch.ByteStorage.from_buffer``; ``.copy()`` makes the buffer
            # writable so torch does not warn about read-only memory.
            raw = np.frombuffer(self.data.tobytes(), dtype=np.uint8).copy()
            img = torch.from_numpy(raw)
            # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
            if self.data.mode == 'YCbCr':
                nchannel = 3
            else:
                nchannel = len(self.data.mode)
            # PIL size is (width, height).
            img = img.view(self.data.size[1], self.data.size[0], nchannel)
            # (H, W, C) -> (C, H, W)
            img = img.transpose(0, 1).transpose(0, 2).contiguous()
        img = img.float().div(255)

        self.set_data(img)

    def shape(self):
        """Return the PIL ``size`` of the image, i.e. ``(width, height)``."""
        return self.data.size

    def rotate(self, angle, cval=0):
        """Rotate by ``angle`` degrees (nearest neighbour, same size), padding with ``cval``."""
        # np.array already copies the PIL pixels, so no separate copy is needed.
        arr = np.array(self.data)
        arr = scp_rotate(arr, angle, reshape=False, order=0, mode='constant', cval=cval, prefilter=False)
        self.data = Image.fromarray(arr)

    def scale(self, ratio):
        """
        Resize the image by ``ratio`` (target size truncated with ``int``).

        Shrinking uses the Lanczos filter and enlarging uses bicubic.
        """
        w, h = self.shape()
        tw = int(ratio * w)
        th = int(ratio * h)

        # ``Image.ANTIALIAS`` and ``Image.CUBIC`` were removed in Pillow 10.
        # ``LANCZOS`` and ``BICUBIC`` are the same filters under their
        # long-standing names.
        interpolation = Image.LANCZOS if ratio < 1 else Image.BICUBIC

        self.data = self.data.resize((tw, th), interpolation)

    def fliplr(self):
        """Mirror the image left-right."""
        self.data = self.data.transpose(Image.FLIP_LEFT_RIGHT)

    def crop(self, x1, y1, tw, th):
        """Crop the box with top-left ``(x1, y1)``, width ``tw`` and height ``th``."""
        self.data = self.data.crop((x1, y1, x1 + tw, y1 + th))

    def gamma(self, gamma_ratio):
        """Apply ``torchvision`` gamma correction with gain 1."""
        self.data = TF.adjust_gamma(self.data, gamma_ratio, gain=1)

    def normalize(self, mean, std):
        """
        Normalise the tensor in ``self.data`` in place with per-channel mean/std.

        On CPU, the first axis of the image is zipped with ``mean`` and ``std``,
        so it is treated as the channel axis. On any other device, each element
        along the first axis is normalised with ``mean``/``std`` broadcast as
        (C, 1, 1), so that axis is treated as a batch axis.

        Returns 1.
        """
        mean = torch.as_tensor(mean, dtype=torch.float32)
        std = torch.as_tensor(std, dtype=torch.float32)

        image = self.data

        if image.device.type != 'cpu':
            # Move the statistics once, to the image's own device (not merely
            # the current CUDA device), instead of once per batch element.
            mean_b = mean[:, None, None].to(image.device)
            std_b = std[:, None, None].to(image.device)
            for t in image:
                t.sub_(mean_b).div_(std_b)
        else:
            for t, m, s in zip(image, mean, std):
                t.sub_(m).div_(s)

        self.set_data(image)

        return 1


class Normals(Representation):
    """
    Normals: overwrite transforms to handle specificity of normals transforms.

    ``self.data`` is expected to be a 3-D array with the normal components on
    its last axis: index 0 for x, 1 for y, 2 for z. The constructor normalises
    along axis 2, and ``to_tensor`` moves that axis to the front.
    """

    def __init__(self, data):
        super(Normals, self).__init__(data=data, name='normals')
        # Normalise normals to unit length (norm clipped below at 1e-4 to avoid /0).
        n = np.linalg.norm(self.data, 2, axis=2)
        self.data = self.data / (np.expand_dims(n, axis=2).clip(1e-4))

    def scale(self, ratio):
        """
        Resize by ``ratio``, then adjust the normals.

        The resize uses nearest neighbour. The z component is then multiplied
        by ``ratio`` and the normals are re-normalised to unit length.
        """
        super(Normals, self).scale(ratio, interpolation='NEAREST')
        self.data[..., 2] *= ratio
        norm = np.linalg.norm(self.data, 2, axis=2)
        self.data = self.data / (np.expand_dims(norm, axis=2).clip(1e-4))

    def rotate(self, angle, cval=0):
        """
        Rotate the normal map by ``angle`` degrees.

        The x/y components of every normal are first rotated by ``angle``. The
        z component is untouched, because the rotation is about the Z axis.
        The map is then spatially rotated with nearest-neighbour sampling and
        ``mode='wrap'``. ``cval`` is accepted for interface compatibility and
        is unused.
        """
        rad_angle = np.deg2rad(angle)
        cos_angle = np.cos(rad_angle)
        sin_angle = np.sin(rad_angle)
        # Compute BOTH rotated components from the original x/y before writing
        # either back. Writing x first and then reading it for y would mix the
        # already-rotated x into the new y.
        x = self.data[..., 0]
        y = self.data[..., 1]
        rot_x = x * cos_angle - y * sin_angle
        rot_y = x * sin_angle + y * cos_angle
        self.data[..., 0] = rot_x
        self.data[..., 1] = rot_y

        self.data = scp_rotate(self.data, angle, reshape=False, order=0, mode='wrap', prefilter=False)

    def crop(self, x1, y1, tw, th):
        """Keep rows ``y1:y1 + th`` and columns ``x1:x1 + tw`` of every component."""
        self.data = self.data[y1:y1 + th, x1:x1 + tw, :]

    def fliplr(self):
        """Mirror the map left-right and negate the x component to match."""
        self.data = np.fliplr(self.data)
        self.data[..., 0] = -1.0 * self.data[..., 0]

    def to_tensor(self):
        """Replace ``self.data`` with a float32 tensor, moving the component axis first."""
        self.data = torch.FloatTensor(np.array((self.data).swapaxes(1, 2).swapaxes(0, 1), dtype=np.float32))


class Depth(Representation):
    """
    Depth map representation.

    Differs from the base class in two ways:

    * ``scale`` divides the stored values by the resize ratio after a
      nearest-neighbour resize.
    * ``to_tensor`` produces float32 rather than int64.
    """

    def __init__(self, data):
        super().__init__(data=data, name='depth')

    def scale(self, ratio):
        """Resize by ``ratio`` (nearest neighbour), then divide every value by ``ratio``."""
        super().scale(ratio, interpolation='NEAREST')
        self.set_data(self.data / ratio)

    def to_tensor(self):
        """Replace ``self.data`` with a float32 ``torch.FloatTensor`` copy of it."""
        as_float32 = np.array(self.data, dtype=np.float32)
        self.data = torch.FloatTensor(as_float32)


class Contours(Representation):
    """
    Contours: overwrite scale to always have contours with 1 pixel width.
    """

    def __init__(self, data):
        super(Contours, self).__init__(data=data, name='contours')

    def scale(self, ratio, interpolation='LINEAR'):
        """
        Resize the contour map by ``ratio`` and thin it back to 1-pixel lines.

        Steps:

        1. Resize with ``cv2.INTER_LINEAR_EXACT``.
        2. Binarise: pixels strictly above the threshold become foreground.
           The threshold is 0.2 when ``ratio > 1`` and 0.4 otherwise.
        3. Skeletonise the result.

        ``self.data`` ends up as the boolean skeleton. ``interpolation`` is
        accepted for interface compatibility and ignored.
        """
        h, w = self.data.shape[:2]
        tw = int(ratio * w)
        th = int(ratio * h)

        resized = cv2.resize(self.data, dsize=(tw, th), interpolation=cv2.INTER_LINEAR_EXACT)
        # A lower threshold is used when upscaling (original note: "solve the
        # missed edges").
        threshold = 0.2 if ratio > 1 else 0.4
        # Binarise explicitly. Recent scikit-image ``skeletonize`` treats every
        # non-zero pixel as foreground. Merely setting above-threshold pixels
        # to 1 would leave faint interpolated pixels in the foreground and make
        # the threshold a no-op.
        self.data = skeletonize(resized > threshold)


class Mask(Representation):
    """
    Mask representation.

    Differs from the base class only in ``rotate``: pixels rotated in from
    outside the frame are filled with a constant instead of being wrapped.
    """

    def __init__(self, data):
        super().__init__(data=data, name='mask')

    def rotate(self, angle, cval=0):
        """Rotate by ``angle`` degrees (nearest neighbour, same shape), padding with ``cval``."""
        options = dict(reshape=False, order=0, mode='constant', cval=cval, prefilter=False)
        self.set_data(scp_rotate(self.data, angle, **options))