import bisect
import itertools
import os
import random

import torch
import numpy as np
import PIL.Image as Image

from torch.utils.data import Dataset
from torchvision import transforms, utils

class loadedDataset(Dataset):
    """Clip dataset read from a ``root_dir/<class>/<clip>/<frame file>`` tree.

    Every entry of ``root_dir`` is treated as a class; its label is its
    position in sorted order. Every entry inside a class directory is a clip
    directory, and the files inside a clip directory are loaded, in sorted
    order, as that clip's frames. Items are numbered class by class, and
    within a class in sorted clip-name order.

    Each loaded frame is shifted horizontally by an offset that follows a
    clamped random walk across the clip. The shift is applied before
    ``transform``.
    """

    def __init__(self, root_dir, transform=None, max_shift=2):
        """Index the directory tree under ``root_dir``.

        Args:
            root_dir: directory whose entries are the class directories.
            transform: optional. Only ``transform[0]`` is ever called, once per
                frame. Presumably a sequence of transforms built by the
                caller; TODO confirm what the other elements are for.
            max_shift: largest per-frame change (in pixels) of the horizontal
                offset. The cumulative offset is clamped to
                ``[-3 * max_shift, 3 * max_shift]``. Defaults to the
                previously hard-coded value 2.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.max_shift = max_shift
        self.classes = sorted(os.listdir(self.root_dir))
        # Snapshot the sorted clip names once, so __len__ and __getitem__
        # always agree and no directory is re-listed/re-sorted per item.
        self._clips = [
            sorted(os.listdir(os.path.join(self.root_dir, c)))
            for c in self.classes
        ]
        self.count = [len(clips) for clips in self._clips]
        # acc_count[i] = number of clips in classes 0..i (inclusive).
        # Empty when root_dir has no classes, instead of crashing.
        self.acc_count = list(itertools.accumulate(self.count))

    def __len__(self):
        """Return the total number of clips across all classes."""
        return self.acc_count[-1] if self.acc_count else 0

    def __getitem__(self, idx):
        """Load clip ``idx`` and return ``(frames, label, clip_name)``.

        ``frames`` is the list of (shifted, optionally transformed) frames in
        sorted file-name order, ``label`` is the class index and
        ``clip_name`` is the clip directory's name. Negative indices count
        from the end, as for a list.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
                This is also what lets plain ``for`` iteration terminate.
        """
        total = len(self)
        pos = idx + total if idx < 0 else idx
        if not 0 <= pos < total:
            raise IndexError(
                f"index {idx} out of range for dataset of size {total}")

        # First class whose cumulative count exceeds pos (empty classes,
        # which repeat the previous cumulative count, are skipped).
        label = bisect.bisect_right(self.acc_count, pos)
        start = self.acc_count[label - 1] if label else 0
        file_name = self._clips[label][pos - start]
        file_path = os.path.join(self.root_dir, self.classes[label], file_name)

        v = self.max_shift
        limit = 3 * v
        offset = 0
        frames = []
        for frame_name in sorted(os.listdir(file_path)):
            # Random-walk step drawn from [-v, v] inclusive on both ends, so
            # the walk is symmetric (randrange(-v, v) excluded +v and drifted
            # negative); the cumulative offset is clamped to [-3v, 3v].
            offset += random.randint(-v, v)
            offset = max(-limit, min(offset, limit))
            # Affine coefficients (1, 0, offset, 0, 1, 0) sample the input at
            # (x + offset, y), i.e. a pure horizontal shift. transform()
            # returns a new image, so the source file can be closed here.
            with Image.open(os.path.join(file_path, frame_name)) as img:
                frame = img.transform(
                    img.size, Image.AFFINE, (1, 0, offset, 0, 1, 0))
            if self.transform is not None:
                frame = self.transform[0](frame)
            frames.append(frame)

        return frames, label, file_name