import collections
import math
import torch
import random
import numpy as np
import numbers
import cv2
from PIL import Image
import torchvision.transforms.functional as F


def resize(video, size, interpolation):
  if interpolation == 'bilinear':
    inter = cv2.INTER_LINEAR
  elif interpolation == 'nearest':
    inter = cv2.INTER_NEAREST
  else:
    raise NotImplementedError

  shape = video.shape[:-3]
  video = video.reshape((-1, *video.shape[-3:]))
  resized_video = np.zeros((video.shape[0], size[1], size[0], video.shape[-1]))
  for i in range(video.shape[0]):
    img = cv2.resize(video[i], size, inter)
    if len(img.shape) == 2:
      img = img[:, :, np.newaxis]
    resized_video[i] = img
  return resized_video.reshape((*shape, size[1], size[0], video.shape[-1]))


class ToTensor(object):
  """Converts a numpy.ndarray (... x H x W x C) in the range
  [0, 255] to a torch.FloatTensor of shape (... x C x H x W) in the range [0.0, 1.0].
  """
  def __init__(self, scale=True):
    self.scale = scale

  def __call__(self, arr):
    if isinstance(arr, np.ndarray):
      video = torch.from_numpy(np.rollaxis(arr, axis=-1, start=-3))

      if self.scale:
        return video.float().div(255)
      else:
        return video.float()
    else:
      raise NotImplementedError


class Normalize(object):
  """Given mean: (R, G, B) and std: (R, G, B),
  will normalize each channel of the torch.*Tensor, i.e.
  channel = (channel - mean) / std
  """

  def __init__(self, mean, std):
    if not isinstance(mean, list):
      mean = [mean]
    if not isinstance(std, list):
      std = [std]

    self.mean = torch.FloatTensor(mean).unsqueeze(1).unsqueeze(2)
    self.std = torch.FloatTensor(std).unsqueeze(1).unsqueeze(2)

  def __call__(self, tensor):
    return tensor.sub_(self.mean).div_(self.std)


class Scale(object):
  """Rescale the input numpy.ndarray to the given size.
  Args:
      size (sequence or int): Desired output size. If size is a sequence like
          (w, h), output size will be matched to this. If size is an int,
          smaller edge of the image will be matched to this number.
          i.e, if height > width, then image will be rescaled to
          (size * height / width, size)
      interpolation (int, optional): Desired interpolation. Default is
          ``bilinear``
  """

  def __init__(self, size, interpolation='bilinear'):
    assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)
    self.size = size
    self.interpolation = interpolation

  def __call__(self, video):
    """
    Args:
        video (numpy.ndarray): Video to be scaled.
    Returns:
        numpy.ndarray: Rescaled video.
    """
    if isinstance(self.size, int):
      w, h = video.shape[-2], video.shape[-3]
      if (w <= h and w == self.size) or (h <= w and h == self.size):
        return video
      if w < h:
        ow = self.size
        oh = int(self.size*h/w)
        return resize(video, (ow, oh), self.interpolation)
      else:
        oh = self.size
        ow = int(self.size*w/h)
        return resize(video, (ow, oh), self.interpolation)
    else:
      return resize(video, self.size, self.interpolation)


class CenterCrop(object):
  """Crops the given numpy.ndarray at the center to have a region of
  the given size. size can be a tuple (target_height, target_width)
  or an integer, in which case the target will be of a square shape (size, size)
  """

  def __init__(self, size):
    if isinstance(size, numbers.Number):
      self.size = (int(size), int(size))
    else:
      self.size = size

  def __call__(self, video):
    h, w = video.shape[-3:-1]
    th, tw = self.size
    x1 = int(round((w-tw)/2.))
    y1 = int(round((h-th)/2.))

    return video[..., y1:y1+th, x1:x1+tw, :]


class Pad(object):
  """Pad the given np.ndarray on all sides with the given "pad" value.
  Args:
      padding (int or sequence): Padding on each border. If a sequence of
          length 4, it is used to pad left, top, right and bottom borders respectively.
      fill: Pixel fill value. Default is 0.
  """

  def __init__(self, padding, fill=0):
    assert isinstance(padding, numbers.Number)
    assert isinstance(fill, numbers.Number) or isinstance(fill, str) or isinstance(fill, tuple)
    self.padding = padding
    self.fill = fill

  def __call__(self, video):
    """
    Args:
        video (np.ndarray): Video to be padded.
    Returns:
        np.ndarray: Padded video.
    """
    pad_width = ((0, 0), (self.padding, self.padding), (self.padding, self.padding), (0, 0))
    return np.pad(video, pad_width=pad_width, mode='constant', constant_values=self.fill)


class RandomCrop(object):
  """Crop the given numpy.ndarray at a random location.
  Args:
      size (sequence or int): Desired output size of the crop. If size is an
          int instead of sequence like (h, w), a square crop (size, size) is
          made.
      padding (int or sequence, optional): Optional padding on each border
          of the image. Default is 0, i.e no padding. If a sequence of length
          4 is provided, it is used to pad left, top, right, bottom borders
          respectively.
  """

  def __init__(self, size, padding=0):
    if isinstance(size, numbers.Number):
      self.size = (int(size), int(size))
    else:
      self.size = size
    self.padding = padding

  def __call__(self, video):
    """
    Args:
        video (np.ndarray): Video to be cropped.
    Returns:
        np.ndarray: Cropped video.
    """
    if self.padding > 0:
      pad = Pad(self.padding, 0)
      video = pad(video)

    w, h = video.shape[-2], video.shape[-3]
    th, tw = self.size
    if w == tw and h == th:
      return video

    x1 = random.randint(0, w-tw)
    y1 = random.randint(0, h-th)
    return video[..., y1:y1+th, x1:x1+tw, :]


class RandomHorizontalFlip(object):
  """Randomly horizontally flips the given numpy.ndarray with a probability of 0.5
  """

  def __call__(self, video):
    if random.random() < 0.5:
      return video[..., ::-1, :].copy()
    return video


class RandomSizedCrop(object):
  """Crop the given np.ndarray to random size and aspect ratio.
  A crop of random size of (0.08 to 1.0) of the original size and a random
  aspect ratio of 3/4 to 4/3 of the original aspect ratio is made. This crop
  is finally resized to given size.
  This is popularly used to train the Inception networks.
  Args:
      size: size of the smaller edge
      interpolation: Default: 'bilinear'
  """

  def __init__(self, size, interpolation='bilinear'):
    self.size = size
    self.interpolation = interpolation

  def __call__(self, video):
    for attempt in range(10):
      area = video.shape[-3]*video.shape[-2]
      target_area = random.uniform(0.08, 1.0)*area
      aspect_ratio = random.uniform(3./4, 4./3)

      w = int(round(math.sqrt(target_area*aspect_ratio)))
      h = int(round(math.sqrt(target_area/aspect_ratio)))

      if random.random() < 0.5:
        w, h = h, w

      if w <= video.shape[-2] and h <= video.shape[-3]:
        x1 = random.randint(0, video.shape[-2]-w)
        y1 = random.randint(0, video.shape[-3]-h)

        video = video[..., y1:y1+h, x1:x1+w, :]

        return resize(video, (self.size, self.size), self.interpolation)

    # Fallback
    scale = Scale(self.size, interpolation=self.interpolation)
    crop = CenterCrop(self.size)
    return crop(scale(video))

class ColorJitter(object):
  """Randomly change the brightness, contrast and saturation of an image.
  Args:
      brightness (float): How much to jitter brightness. brightness_factor
          is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
      contrast (float): How much to jitter contrast. contrast_factor
          is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
      saturation (float): How much to jitter saturation. saturation_factor
          is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
      hue(float): How much to jitter hue. hue_factor is chosen uniformly from
          [-hue, hue]. Should be >=0 and <= 0.5.
  """
  def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
    self.brightness = brightness
    self.contrast = contrast
    self.saturation = saturation
    self.hue = hue

  @staticmethod
  def get_params(brightness, contrast, saturation, hue):
    """Get a randomized transform to be applied on image.
    Arguments are same as that of __init__.
    Returns:
        Transform which randomly adjusts brightness, contrast and
        saturation in a random order.
    """
    transforms = []
    if brightness > 0:
      brightness_factor = random.uniform(max(0, 1 - brightness), 1 + brightness)
      transforms.append(lambda img: F.adjust_brightness(img, brightness_factor))

    if contrast > 0:
      contrast_factor = random.uniform(max(0, 1 - contrast), 1 + contrast)
      transforms.append(lambda img: F.adjust_contrast(img, contrast_factor))

    if saturation > 0:
      saturation_factor = random.uniform(max(0, 1 - saturation), 1 + saturation)
      transforms.append(lambda img: F.adjust_saturation(img, saturation_factor))

    if hue > 0:
      hue_factor = random.uniform(-hue, hue)
      transforms.append(lambda img: F.adjust_hue(img, hue_factor))

    random.shuffle(transforms)

    return transforms

  def __call__(self, video):
    """
    Args:
        img (numpy array): Input image, shape (... x H x W x C), dtype uint8.
    Returns:
        PIL Image: Color jittered image.
    """
    transforms = self.get_params(self.brightness, self.contrast, self.saturation, self.hue)
    reshaped_video = video.reshape((-1, *video.shape[-3:]))
    n_channels = video.shape[-1]
    for i in range(reshaped_video.shape[0]):
      img = reshaped_video[i]
      if n_channels == 1:
        img = img.squeeze(axis=2)
      img = Image.fromarray(img)
      for t in transforms:
        img = t(img)
      img = np.array(img)
      if n_channels == 1:
        img = img[..., np.newaxis]
      reshaped_video[i] = img
    video = reshaped_video.reshape(video.shape)
    return video