import numpy as np
import random
import cv2
from util import rotate_and_crop, AsyncTaskManager


class DataProvider(object):
  """Serves (optionally augmented) image batches from an in-memory array.

  `data` is expected to be laid out as (N, H, W, C) -- the grayscale
  conversion indexes channels 0..2 on the last axis and batch buffers are
  shaped (batch,) + output_size + data.shape[3:]. Labels are not tracked:
  batch methods pair images with an all-zero placeholder array (or None
  for test batches).
  """

  def __init__(self,
               data,
               output_size=-1,
               limit=-1,
               synchronous=False,
               augmentation=0,
               bnw=False,
               blur=False,
               default_batch_size=64,
               train=True,
               seperation=None,
               image_scaling=1.0,
               *args,
               **kwargs):
    """Store (a slice of) `data` and set up batching state.

    Args:
      data: image array, presumably (N, H, W, C) -- TODO confirm at callers.
      output_size: side length of square output images; -1 keeps the
        input (H, W).
      limit: -1 for all images, a float for a fraction of them, or an int
        count.
      synchronous: if True, `get_next_batch` always builds batches directly
        instead of going through AsyncTaskManager.
      augmentation: > 0 enables random crop / horizontal flip (`augment`).
      bnw: collapse the first three channels into one gray channel.
      blur: apply a rotate / un-rotate round trip during augmentation.
      default_batch_size: batch size used for the AsyncTaskManager path.
      train: together with `seperation`, keep the leading (True) or
        trailing (False) part of the data.
      seperation: fraction of the (limited) data assigned to the leading
        split; None disables splitting.
      image_scaling: factor applied to returned pixel values.
      *args, **kwargs: accepted and ignored.
    """
    print(data.shape)
    self.blur = blur
    if limit == -1:
      limit = data.shape[0]
    elif isinstance(limit, float):
      limit = int(data.shape[0] * limit)
    self.image_scaling = image_scaling
    # Basic slicing: self.data shares memory with the caller's `data`
    # (until the bnw conversion below builds a new array).
    self.data = data[:limit]
    if seperation is not None:
      seperator = int(round(len(self.data) * seperation))
      if train:
        self.data = self.data[:seperator]
      else:
        self.data = self.data[seperator:]
    self.bnw = bnw
    if self.bnw:
      # Weighted sum of channels 0, 1, 2, kept as a trailing channel axis.
      gray = (0.27 * self.data[:, :, :, 0] + 0.67 * self.data[:, :, :, 1] +
              0.06 * self.data[:, :, :, 2])
      self.data = gray[:, :, :, None]
    self.num_images = len(self.data)
    self.default_batch_size = default_batch_size
    self.image_size = data.shape[1:3]
    self.augmentation = augmentation
    self.indices = list(range(self.num_images))
    random.shuffle(self.indices)
    self.synchronous = synchronous
    self.async_task = None
    if output_size == -1:
      self.output_size = data.shape[1:3]
    else:
      self.output_size = (output_size, output_size)

  def _resize(self, img):
    """Resize `img` to self.output_size, preserving a singleton channel axis."""
    out_h, out_w = self.output_size
    # cv2.resize takes dsize as (width, height); output_size is (height, width).
    resized = cv2.resize(img, (out_w, out_h))
    # cv2 returns a 2-D array for single-channel input; restore the channel
    # axis so the result matches the layout of self.data.
    if resized.ndim < img.ndim:
      resized = resized[:, :, None]
    return resized

  def augment(self, img, strength):
    """Randomly crop to output_size and randomly flip left-right.

    Args:
      img: one image, (H, W) or (H, W, C); H and W must be at least the
        corresponding output_size entries.
      strength: currently unused (rotation/rescale augmentation is disabled).

    Returns:
      An (out_h, out_w, C) image (a channel axis is added for 2-D input).
    """
    out_h, out_w = self.output_size
    start_x = random.randrange(0, img.shape[0] - out_h + 1)
    start_y = random.randrange(0, img.shape[1] - out_w + 1)
    img = img[start_x:start_x + out_h, start_y:start_y + out_w]
    if random.random() < 0.5:
      img = img[:, ::-1]  # left-right flip
    if img.ndim < 3:
      img = img[:, :, None]
    if self.blur:
      # Rotating and rotating back softens the image via interpolation.
      angle = random.uniform(-1, 1) * 10
      img = rotate_and_crop(img, angle)
      img = rotate_and_crop(img, -angle)
      img = self._resize(img)
      if img.ndim < 3:
        img = img[:, :, None]
    return img

  def get_next_batch_(self, batch_size):
    """Build one batch synchronously, drawing indices without replacement.

    Indices come from a shuffled pool that is refilled and reshuffled once
    exhausted, so every image is used once per pass.

    Returns:
      (images * image_scaling, np.zeros(batch_size)) -- labels are
      placeholders.

    Raises:
      ValueError: if a non-empty batch is requested from an empty dataset
        (the refill loop could otherwise never terminate).
    """
    if batch_size > 0 and self.num_images == 0:
      raise ValueError("DataProvider has no images to draw a batch from")
    batch = []
    while len(batch) < batch_size:
      take = min(len(self.indices), batch_size - len(batch))
      batch += self.indices[:take]
      self.indices = self.indices[take:]
      if not self.indices:
        self.indices = list(range(self.num_images))
        random.shuffle(self.indices)
    batch_images = np.empty(
        (batch_size,) + tuple(self.output_size) + self.data.shape[3:],
        dtype=self.data.dtype)
    for slot, idx in enumerate(batch):
      img = self.data[idx]
      if self.augmentation > 0:
        batch_images[slot] = self.augment(img, self.augmentation)
      else:
        batch_images[slot] = self._resize(img)
    return batch_images * self.image_scaling, np.zeros((batch_size,))

  def get_next_batch(self, batch_size):
    """Return the next batch, using the AsyncTaskManager for the default size.

    Batches of a non-default size (or any batch when `synchronous`) are
    built directly. NOTE(review): the direct path and the AsyncTaskManager
    both call get_next_batch_, which mutates self.indices; if the manager
    runs it on another thread this may race -- confirm.
    """
    if self.synchronous:
      return self.get_next_batch_(batch_size)
    if self.async_task is None:
      self.async_task = AsyncTaskManager(
          target=self.get_next_batch_, args=(self.default_batch_size,))
    if batch_size != self.default_batch_size:
      return self.get_next_batch_(batch_size)
    return self.async_task.get_next()

  def get_random_batch(self, batch_size):
    """Return up to `batch_size` random raw images and matching zero labels.

    Images are returned as stored: no resizing, augmentation or scaling.
    """
    indices = list(range(self.num_images))
    random.shuffle(indices)
    indices = indices[:batch_size]
    # One placeholder label per returned image.
    return self.data[indices], np.zeros((len(indices),))

  def get_test_batches(self, batch_size):
    """Split the whole dataset, in order, into consecutive batches.

    The last batch may be smaller than `batch_size`. Each image is scaled
    by image_scaling, then augmented (if enabled) or resized.

    Returns:
      (list of stacked batches, None).
    """
    batches = []
    for start in range(0, len(self.data), batch_size):
      batch = []
      for img in self.data[start:start + batch_size]:
        # Scale into a fresh array: `img` is a view into self.data, so an
        # in-place `*=` would permanently rescale the dataset (and raise
        # for integer data with a float factor).
        img = img * self.image_scaling
        if self.augmentation > 0:
          batch.append(self.augment(img, self.augmentation))
        else:
          batch.append(self._resize(img))
      batches.append(np.stack(batch, axis=0))
    return batches, None