import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import Dataset, DataLoader
from skimage import io, transform
from PIL import Image
import os

class ImageData(Dataset):
    """Map-style image dataset whose defaults point at CACD2000.

    Image file names and labels are read from two ``.npy`` files whose rows
    are matched by index. Each sample is a dict
    ``{'image': Tensor, 'label': Tensor}``.
    """

    # Modes that have a transform pipeline defined in ``__init__``.
    _MODES = ("train", "val")

    def __init__(self, root_path="CACD2000/", label_path="data/label.npy", name_path="data/name.npy", train_mode="train"):
        """
        Load labels and image names, and build the per-mode transforms.

        Args:
            root_path: directory prepended to every stored image name.
            label_path: ``.npy`` file with one label entry per image.
            name_path: ``.npy`` file with one image file name per image,
                relative to ``root_path``.
            train_mode: ``"train"`` (adds a random horizontal flip) or
                ``"val"``; selects the transform used by ``__getitem__``.

        Raises:
            ValueError: if ``train_mode`` is not one of ``_MODES``.
        """
        # Fail fast: an unknown mode would otherwise only surface as a
        # KeyError on the first __getitem__ call.
        if train_mode not in self._MODES:
            raise ValueError(
                f"train_mode must be one of {self._MODES}, got {train_mode!r}"
            )
        self.root_path = root_path
        self.image_labels = np.load(label_path)
        self.image_names = np.load(name_path)
        self.train_mode = train_mode
        # NOTE: Normalize([0.656, 0.487, 0.411], [1., 1., 1.]) is disabled in
        # both pipelines; the ToTensor output is returned unnormalized.
        # NOTE(review): Resize(224) with an int size rescales the shorter edge
        # only, so non-square source images give non-square tensors. Confirm
        # the inputs are square before relying on default batching.
        self.transform = {
            'train': transforms.Compose([
                transforms.Resize(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
            ]),
            'val': transforms.Compose([
                transforms.Resize(224),
                transforms.ToTensor(),
            ]),
        }

    def __len__(self):
        """Return the number of samples (the first dimension of the label array)."""
        return self.image_labels.shape[0]

    def __getitem__(self, idx):
        """
        Load, transform and label the image at ``idx``.

        Returns:
            dict with ``'image'`` (output of the current mode's transform
            pipeline) and ``'label'`` (integer tensor: stored label minus one).
        """
        image_name = os.path.join(self.root_path, self.image_names[idx])
        # PIL opens files lazily and keeps the handle open. Run the transform
        # inside the context manager so the handle is closed once the tensor
        # has been produced.
        with Image.open(image_name) as image:
            transformed_img = self.transform[self.train_mode](image)
        image_label = self._to_label_tensor(self.image_labels[idx])
        return {'image': transformed_img, 'label': image_label}

    @staticmethod
    def _to_label_tensor(raw_label):
        """
        Convert one stored label entry into an integer tensor, shifted by -1.

        The ``- 1`` presumably maps 1-based class ids to 0-based ones. Confirm
        this against how label.npy was produced.
        """
        label = np.asarray(raw_label).astype(int) - 1
        # Arithmetic on a 0-d array yields a NumPy scalar, which
        # torch.from_numpy rejects, so wrap it back into an ndarray. For
        # labels that are already arrays (e.g. one row of a 2-D label file)
        # this is a no-op.
        return torch.from_numpy(np.asarray(label))