from glob import glob
import os
import numpy as np
import random
import torch.utils.data as data
import json
import cv2

def make_dataset(root, is_train):
  '''
  Load the stored bouncing-ball trajectories for one data split.

  Args:
    root: directory that contains the per-split dataset folders.
    is_train: truthy selects the 'balls_n4_t60_ex50000' folder, falsy selects
      the 'balls_n4_t60_ex2000' folder.

  Returns:
    The object np.load returns for <root>/<folder>/dataset_info.npy.
    NOTE(review): presumably an array indexed per example as
    (n_frames, n_balls, 4) trajectories -- confirm against the data files.
  '''
  split_folder = 'balls_n4_t60_ex50000' if is_train else 'balls_n4_t60_ex2000'
  npy_path = os.path.join(root, split_folder, 'dataset_info.npy')
  return np.load(npy_path)

class BouncingBalls(data.Dataset):
  '''
  Bouncing balls dataset.

  Each item is a clip of n_frames_input + n_frames_output frames rendered from
  one stored trajectory. Every ball is drawn as a filled circle of value 255
  on a black single-channel image_size x image_size uint8 canvas. Trajectory
  coordinates are multiplied by image_size / 800, i.e. they are presumably
  expressed on an 800-unit reference canvas (confirm against the generator).

  Args:
    root: directory handed to make_dataset to locate the trajectories.
    is_train: selects the training split; clips then start at a random
      frame, otherwise every clip starts at frame 0.
    n_frames_input: number of leading frames returned as the input clip.
    n_frames_output: number of following frames returned as the target clip.
    image_size: side length, in pixels, of the rendered square frames.
    transform: optional callable applied to the whole uint8 frame stack of
      shape (n_frames, image_size, image_size, 1) before it is split.
    return_positions: if True, __getitem__ also returns the ball centres.
  '''
  def __init__(self, root, is_train, n_frames_input, n_frames_output, image_size,
               transform=None, return_positions=False):
    super(BouncingBalls, self).__init__()
    self.n_frames = n_frames_input + n_frames_output
    self.dataset = make_dataset(root, is_train)

    self.size = image_size
    # Rescale factor from trajectory coordinates to output pixels. float()
    # keeps this a true division even under Python 2 (where int / int
    # truncates and would collapse every ball onto the origin).
    self.scale = float(self.size) / 800
    # Base ball radius (60 trajectory units) in output pixels; each ball
    # further multiplies it by its own column-3 value in __getitem__.
    self.radius = int(60 * self.scale)

    self.root = root
    self.is_train = is_train
    self.n_frames_input = n_frames_input
    self.n_frames_output = n_frames_output
    self.transform = transform
    self.return_positions = return_positions

  def __getitem__(self, idx):
    '''
    Render clip idx.

    Returns:
      (input, output) or, when return_positions is set,
      (input, output, positions), where
        input: the first n_frames_input frames of the (optionally
          transformed) stack;
        output: the remaining frames, or [] when n_frames_output is 0;
        positions: float array (n_frames, n_balls, 2) holding each ball's
          rounded pixel centre (x, y) divided by image_size.

    Raises:
      ValueError: if the trajectory has fewer frames than the clip needs.
    '''
    # traj size: (vid_len, n_balls, 4). Columns 0/1 are the centre x/y and
    # column 3 scales the radius; column 2 is not used here.
    traj = self.dataset[idx]
    vid_len, n_balls = traj.shape[:2]
    # Fail early with a clear message; otherwise randint (train) or the frame
    # indexing below (eval) would raise an unhelpful error.
    if vid_len < self.n_frames:
      raise ValueError(
          'trajectory %s has %s frames but %s are required '
          '(n_frames_input + n_frames_output)' % (idx, vid_len, self.n_frames))

    if self.is_train:
      # randint is inclusive on both ends, so the clip never runs past the
      # last stored frame.
      start = random.randint(0, vid_len - self.n_frames)
    else:
      start = 0

    n_channels = 1
    images = np.zeros([self.n_frames, self.size, self.size, n_channels], np.uint8)
    positions = []
    for fid in range(self.n_frames):
      xy = []
      for bid in range(n_balls):
        ball = traj[start + fid, bid]
        x = int(round(self.scale * ball[0]))
        y = int(round(self.scale * ball[1]))
        # Thickness -1 draws a filled circle.
        images[fid] = cv2.circle(images[fid], (x, y), int(self.radius * ball[3]),
                                 255, -1)
        xy.append([x / float(self.size), y / float(self.size)])
      positions.append(xy)

    if self.transform is not None:
      # NOTE(review): assumes the transform's output can still be sliced
      # along the frame axis below.
      images = self.transform(images)

    input_frames = images[:self.n_frames_input]
    if self.n_frames_output > 0:
      output_frames = images[self.n_frames_input:]
    else:
      output_frames = []

    if self.return_positions:
      return input_frames, output_frames, np.array(positions)
    return input_frames, output_frames

  def __len__(self):
    '''Number of stored trajectories (one clip per trajectory).'''
    return len(self.dataset)