import gzip
import math
import numpy as np
import os
from PIL import Image
import random
import torch
import torch.utils.data as data

def load_mnist(root):
  '''
  Read the gzipped MNIST training images stored under `root`.

  param root: directory containing 'train-images-idx3-ubyte.gz'.
  return: read-only uint8 array reshaped to (-1, 28, 28), one 28x28 image
    per entry along the first axis.
  '''
  archive = os.path.join(root, 'train-images-idx3-ubyte.gz')
  with gzip.open(archive, 'rb') as stream:
    raw = stream.read()
  # The first 16 bytes are skipped (presumably the IDX file header); the
  # remainder is interpreted as a flat run of uint8 pixels.
  pixels = np.frombuffer(raw, dtype=np.uint8, offset=16)
  return pixels.reshape(-1, 28, 28)

def load_fixed_set(root, is_train):
  '''
  Load the pre-generated sequence file 'mnist_test_seq.npy' from `root`.

  param root: directory containing 'mnist_test_seq.npy'.
  param is_train: accepted but not used; the same file is loaded either way.
  return: the stored array with one extra trailing axis of size 1 appended.
  '''
  sequences = np.load(os.path.join(root, 'mnist_test_seq.npy'))
  # Append a singleton last axis to whatever shape the file holds.
  return sequences[..., np.newaxis]

class MovingMNIST(data.Dataset):
  '''
  Dataset of short videos in which MNIST digits move and bounce on a canvas.

  Two modes are supported:
    * generated: when `is_train` is true or `num_objects[0] != 2`, every
      `__getitem__` call synthesises a new random video from the MNIST
      digits; the reported length is fixed at 10000.
    * fixed: otherwise, videos are read from the array returned by
      `load_fixed_set`, indexed along its second axis.
  '''

  def __init__(self, root, is_train, n_frames_input, n_frames_output, num_objects,
               transform=None):
    '''
    param root: directory holding the data files.
    param is_train: if true, always generate videos on the fly.
    param n_frames_input: number of leading frames returned as the input.
    param n_frames_output: number of following frames returned as the output.
    param num_objects: a list of number of possible objects.
    param transform: optional callable applied to the whole frame stack
      before it is split into input and output.
    '''
    super(MovingMNIST, self).__init__()

    self.dataset = None
    if is_train or num_objects[0] != 2:
      self.mnist = load_mnist(root)
    else:
      self.dataset = load_fixed_set(root, False)
    self.length = int(1e4) if self.dataset is None else self.dataset.shape[1]

    self.is_train = is_train
    self.num_objects = num_objects
    self.n_frames_input = n_frames_input
    self.n_frames_output = n_frames_output
    self.n_frames_total = self.n_frames_input + self.n_frames_output
    self.transform = transform
    # For generating data
    self.image_size_ = 64
    self.digit_size_ = 28
    self.step_length_ = 0.1

  @staticmethod
  def _advance(position, velocity, step):
    ''' Move one step along one axis, reflecting at the [0, 1] borders. '''
    position += velocity * step
    if position <= 0:
      return 0.0, -velocity
    if position >= 1.0:
      return 1.0, -velocity
    return position, velocity

  def get_random_trajectory(self, seq_length):
    '''
    Generate a random bouncing trajectory for one MNIST digit.

    param seq_length: number of time steps to simulate.
    return: (start_y, start_x), two int32 arrays of length `seq_length`
      holding the top-left pixel position of the digit at each step.
    '''
    canvas_size = self.image_size_ - self.digit_size_
    # Positions are tracked in normalised [0, 1] coordinates.
    x = random.random()
    y = random.random()
    theta = random.random() * 2 * np.pi
    v_y = np.sin(theta)
    v_x = np.cos(theta)

    start_y = np.zeros(seq_length)
    start_x = np.zeros(seq_length)
    for i in range(seq_length):
      # Take a step along velocity and bounce off edges.
      y, v_y = self._advance(y, v_y, self.step_length_)
      x, v_x = self._advance(x, v_x, self.step_length_)
      start_y[i] = y
      start_x[i] = x

    # Scale to the size of the canvas.
    start_y = (canvas_size * start_y).astype(np.int32)
    start_x = (canvas_size * start_x).astype(np.int32)
    return start_y, start_x

  def generate_moving_mnist(self, num_digits=2):
    '''
    Get random trajectories for the digits and generate a video.

    param num_digits: how many digits to draw into the video.
    return: float32 array of shape
      (n_frames_total, image_size_, image_size_, 1).
    '''
    frames = np.zeros((self.n_frames_total, self.image_size_, self.image_size_),
                      dtype=np.float32)
    for _ in range(num_digits):
      # Trajectory
      start_y, start_x = self.get_random_trajectory(self.n_frames_total)
      digit_image = self.mnist[random.randint(0, self.mnist.shape[0] - 1)]
      for i in range(self.n_frames_total):
        top = start_y[i]
        left = start_x[i]
        bottom = top + self.digit_size_
        right = left + self.digit_size_
        # Draw digit; overlapping digits keep the brighter pixel.
        patch = frames[i, top:bottom, left:right]
        frames[i, top:bottom, left:right] = np.maximum(patch, digit_image)

    return frames[..., np.newaxis]

  def __getitem__(self, idx):
    '''
    Return one video split into (input, output) frame stacks.

    In generated mode a new random video is produced on every call and
    `idx` only serves as a bounds check. In fixed mode `idx` selects the
    video along the second axis of the loaded array.

    return: (input, output) where input holds the first `n_frames_input`
      frames and output the following `n_frames_output` frames; output is
      an empty list when `n_frames_output` is 0.
    raises IndexError: for an integer `idx` outside [-len(self), len(self)).
    '''
    if self.is_train or self.num_objects[0] != 2:
      # Without this check the legacy sequence-iteration protocol
      # (`for x in dataset`, `list(dataset)`) would never terminate, since
      # nothing else here raises IndexError. Non-integer indices are left
      # unchecked, as they were before.
      if isinstance(idx, (int, np.integer)) and not -self.length <= idx < self.length:
        raise IndexError(
            'index %d is out of range for dataset of length %d' % (idx, self.length))
      # Sample number of objects
      num_digits = random.choice(self.num_objects)
      # Generate data on the fly
      images = self.generate_moving_mnist(num_digits)
    else:
      images = self.dataset[:, idx, ...]

    if self.transform is not None:
      images = self.transform(images)
    input_frames = images[:self.n_frames_input]
    if self.n_frames_output > 0:
      output_frames = images[self.n_frames_input:self.n_frames_total]
    else:
      output_frames = []

    return input_frames, output_frames

  def __len__(self):
    ''' Number of videos: 10000 in generated mode, else the fixed-set size. '''
    return self.length