import torch
import json
import os
from torch.utils.data import DataLoader,Dataset
import torchvision.transforms as transforms
from PIL import Image
import numpy as np

data_folder = "./dataset/images"

# Mapping from image file name to its press time; JumpDataset looks entries up
# by the image's base file name. A context manager guarantees the file is
# closed, and JSON is specified as UTF-8, so don't depend on the locale default.
with open("./dataset/dataset.json", encoding="utf-8") as _labels_file:
    press_times = json.load(_labels_file)

# Full path of every entry in data_folder. os.listdir returns entries in an
# arbitrary, platform-dependent order; sorting keeps dataset indices stable.
image_roots = sorted(
    os.path.join(data_folder, image_file) for image_file in os.listdir(data_folder)
)


class JumpDataset(Dataset):
    """Map-style dataset pairing screenshot images with their press times.

    Each item is ``(image, press_time)``. The image file is opened, converted
    to RGB, resized to 224x224 with Lanczos resampling and, if given, passed
    through ``transform``. ``press_time`` is looked up in the press-time
    mapping by the image's base file name.

    Args:
        transform: Optional callable applied to the resized PIL image.
        image_paths: Optional sequence of image file paths. Defaults to the
            module-level ``image_roots`` (every entry in ``data_folder``).
        press_time_map: Optional mapping from image file name to press time.
            Defaults to the module-level ``press_times`` loaded from
            ``dataset.json``.
    """

    def __init__(self, transform=None, image_paths=None, press_time_map=None):
        # Fall back to the module-level data so existing callers that only
        # pass ``transform`` keep working unchanged.
        self.image_roots = image_roots if image_paths is None else image_paths
        self.press_times = press_times if press_time_map is None else press_time_map
        self.transform = transform

    def __len__(self):
        return len(self.image_roots)

    def __getitem__(self, idx):
        image_root = self.image_roots[idx]
        # os.path.basename respects the platform separator; splitting on "/"
        # breaks on Windows, where os.path.join inserts "\\".
        image_name = os.path.basename(image_root)
        # The context manager closes the underlying file. convert() returns a
        # new, fully loaded image, so it remains valid after the file closes.
        with Image.open(image_root) as raw:
            image = raw.convert('RGB')
        image = image.resize((224, 224), resample=Image.LANCZOS)
        if self.transform is not None:
            image = self.transform(image)
        # Raises KeyError if the file has no entry in the press-time mapping.
        press_time = self.press_times[image_name]

        return image, press_time

def jump_data_loader(batch_size=32, shuffle=True, num_workers=0):
    """Build a DataLoader over ``JumpDataset`` with the project's normalization.

    Images are converted to tensors and normalized with the same mean
    (0.92206) and std (0.08426) on all three channels. These are presumably
    statistics computed over the screenshots; TODO confirm.

    Args:
        batch_size: Samples per batch (default 32, the previous fixed value).
        shuffle: Whether to reshuffle the data every epoch (default True).
        num_workers: Loader worker processes (default 0 = load in the main
            process, same as DataLoader's own default).

    Returns:
        A ``torch.utils.data.DataLoader`` yielding ``(images, press_times)``
        batches.
    """
    normalize = transforms.Normalize(
        mean=[0.92206, 0.92206, 0.92206], std=[0.08426, 0.08426, 0.08426]
    )
    transform = transforms.Compose([transforms.ToTensor(), normalize])
    dataset = JumpDataset(transform=transform)
    return DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers
    )