"""Dataset class template

This module provides a template for users to implement custom datasets.
You can specify '--dataset_mode template' to use this dataset.
The class name should be consistent with both the filename and its dataset_mode option.
The filename should be <dataset_mode>_dataset.py
The class name should be <Dataset_mode>Dataset.py
You need to implement the following functions:
    -- <modify_commandline_options>: Add dataset-specific options and rewrite default values for existing options.
    -- <__init__>: Initialize this dataset class.
    -- <__getitem__>: Return a data point and its metadata information.
    -- <__len__>: Return the number of images.
"""
from data.base_dataset import BaseDataset, get_transform
import numpy as np 
import h5py as h5
import os
# from data.image_folder import make_dataset
from PIL import Image


class Hdf5Dataset(BaseDataset):
    """Dataset that reads images (and optional labels) from a single HDF5 file.

    The file must contain an image dataset stored under the top-level key
    ``'data'`` or ``'imgs'`` and may contain a label dataset under ``'label'``
    or ``'labels'``; any other top-level key is rejected. With ``--load_in_mem``
    both datasets are read into memory once at construction time; otherwise
    each item is read from disk on access.
    """

    # Top-level HDF5 keys accepted for the image and label datasets.
    _IMAGE_KEYS = ('data', 'imgs')
    _LABEL_KEYS = ('label', 'labels')

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Add the HDF5-specific command-line options.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- whether training phase or test phase (not used by this dataset)

        Returns:
            the modified parser.
        """
        parser.add_argument('--hdf5_filename', type=str, default='img_align_celeba_128.hdf5', help='the name of hdf5 file')
        parser.add_argument('--load_in_mem', action='store_true', default=False, help='Load all data into memory? (default: %(default)s)')
        return parser

    def __init__(self, opt):
        """Locate the image/label datasets in the HDF5 file and build the transform.

        Parameters:
            opt (Option class) -- experiment flags; ``dataroot``, ``hdf5_filename``
                                  and ``load_in_mem`` are read here, and ``opt`` is
                                  also passed to ``BaseDataset`` and ``get_transform``.

        Raises:
            ValueError -- if the file has a top-level key other than the accepted
                          image/label keys, or contains no image dataset at all.
        """
        BaseDataset.__init__(self, opt)
        self.hdf5_path = os.path.join(opt.dataroot, opt.hdf5_filename)
        self.load_in_mem = opt.load_in_mem
        self.imkey = None
        self.lkey = None

        with h5.File(self.hdf5_path, 'r') as f:
            for key in f.keys():
                if key in self._IMAGE_KEYS:
                    self.imkey = key
                elif key in self._LABEL_KEYS:
                    self.lkey = key
                else:
                    raise ValueError(
                        'Unknown key %r in the HDF5 file %s; expected one of %s.'
                        % (key, self.hdf5_path, self._IMAGE_KEYS + self._LABEL_KEYS))

            # Without an image dataset, num_imgs would never be set and the
            # failure would only surface later in __len__/__getitem__.
            if self.imkey is None:
                raise ValueError(
                    'No image dataset (one of %s) found in the HDF5 file %s.'
                    % (self._IMAGE_KEYS, self.hdf5_path))
            self.num_imgs = len(f[self.imkey])

            # If loading into memory, do so now while the file is open.
            if self.load_in_mem:
                print('Loading %s into memory...' % self.hdf5_path)
                self.data = f[self.imkey][:]
                self.labels = f[self.lkey][:] if self.lkey is not None else None

        # define the default transform function.
        self.transform = get_transform(opt)

    def __getitem__(self, index):
        """Return one transformed image and its label.

        Parameters:
            index -- integer index into the image dataset

        Returns:
            a dict with 'image' (the result of ``self.transform`` on the PIL image)
            and 'target' (the stored label, or -1 if the file has no labels).
        """
        if self.load_in_mem:
            img = self.data[index]
            label = self.labels[index] if self.lkey is not None else -1
        else:
            # The file is reopened for every item instead of holding a handle on
            # self; presumably to avoid sharing an h5py handle across worker
            # processes — confirm before changing.
            with h5.File(self.hdf5_path, 'r') as f:
                img = f[self.imkey][index]
                label = f[self.lkey][index] if self.lkey is not None else -1
        img = self._to_pil(img)
        img = self.transform(img)

        return {'image': img, 'target': label}

    @staticmethod
    def _to_pil(img):
        """Convert a stored image array into a PIL image.

        A 3-D array whose first axis has at most 3 entries is treated as
        channel-first (C, H, W) and moved to channel-last (H, W, C). A single
        remaining channel is dropped, because ``Image.fromarray`` cannot build
        an image from an (H, W, 1) array but accepts (H, W) as grayscale.
        """
        if img.ndim == 3 and img.shape[0] <= 3:
            img = img.transpose(1, 2, 0)
        if img.ndim == 3 and img.shape[2] == 1:
            img = img[:, :, 0]
        return Image.fromarray(img)

    def __len__(self):
        """Return the total number of images in the image dataset."""
        return self.num_imgs