import random
import numpy as np
import lmdb
import torch
import torch.utils.data as data
import data.util as util


class LQkerDataset(data.Dataset):
    """Read LR images, paired with their kernel maps, for the Predictor.

    Each item is a dict with the keys ``'LQ'`` (float tensor, CHW layout),
    ``'ker'`` (the entry of ``ker_map_list`` at the same index) and
    ``'LQ_path'`` (the path/key the image was read from).
    """

    def __init__(self, opt, ker_map_list):
        """Collect the LR image paths described by ``opt``.

        Args:
            opt: Option mapping. Keys read by this class: ``'LR_size'``,
                ``'data_type'`` (``'lmdb'`` or ``'img'``), ``'dataroot_LQ'``,
                ``'phase'``, ``'color'`` and, when ``phase`` is ``'train'``,
                ``'use_flip'``, ``'use_rot'`` and ``'mode'``.
            ker_map_list: Indexable collection of kernel maps; element ``i``
                is returned together with image ``i``.

        Raises:
            ValueError: If ``opt['data_type']`` is neither ``'lmdb'`` nor
                ``'img'``.
            AssertionError: If no LR image paths were found.
        """
        super(LQkerDataset, self).__init__()
        self.opt = opt
        self.opt_P = opt
        self.opt_F = opt
        self.LR_paths = None
        self.LR_sizes = None  # per-image size strings, only filled for lmdb
        self.LR_env = None  # environment for lmdb, opened lazily
        self.LR_size = opt['LR_size']
        self.ker_maps = ker_map_list

        # read image list from lmdb or image files
        data_type = opt['data_type']
        if data_type == 'lmdb':
            self.LR_paths, self.LR_sizes = util.get_image_paths(data_type, opt['dataroot_LQ'])
        elif data_type == 'img':
            self.LR_paths = util.get_image_paths(data_type, opt['dataroot_LQ'])  # LR_list
        else:
            # Fail here with the real cause instead of printing and then
            # tripping the unrelated "paths are empty" assertion below
            # (or, under ``python -O``, not failing at all).
            raise ValueError(
                "Error: data_type {!r} is not matched in Dataset; "
                "expected 'lmdb' or 'img'.".format(data_type))
        assert self.LR_paths, 'Error: LR paths are empty.'

    def _init_lmdb(self):
        """Open the read-only lmdb environment at ``opt['dataroot_LQ']``."""
        # https://github.com/chainer/chainermn/issues/129
        self.LR_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False,
                                meminit=False)

    def __getitem__(self, index):
        """Load, optionally crop/augment, and tensorize the LR image at ``index``.

        Returns:
            dict: ``{'LQ': tensor, 'ker': kernel map, 'LQ_path': path}``.
        """
        use_lmdb = self.opt['data_type'] == 'lmdb'
        # The environment is opened on first access rather than in __init__.
        if use_lmdb and self.LR_env is None:
            self._init_lmdb()

        LR_size = self.LR_size

        # get LR image, kernel map
        LR_path = self.LR_paths[index]
        ker_map = self.ker_maps[index]
        # lmdb entries carry their resolution as an underscore-separated string
        resolution = [int(s) for s in self.LR_sizes[index].split('_')] if use_lmdb else None
        img_LR = util.read_img(self.LR_env, LR_path, resolution)
        H, W, C = img_LR.shape

        if self.opt['phase'] == 'train':
            # randomly crop; max(0, ...) keeps the offset at 0 when the image
            # is smaller than LR_size, in which case the crop is the full extent
            rnd_h = random.randint(0, max(0, H - LR_size))
            rnd_w = random.randint(0, max(0, W - LR_size))
            img_LR = img_LR[rnd_h:rnd_h + LR_size, rnd_w:rnd_w + LR_size, :]

            # augmentation - flip, rotate
            img_LR = util.augment(img_LR, self.opt['use_flip'], self.opt['use_rot'], self.opt['mode'])

        # change color space if necessary
        if self.opt['color']:
            img_LR = util.channel_convert(C, self.opt['color'], [img_LR])[0]

        # BGR to RGB, HWC to CHW, numpy to tensor
        if img_LR.shape[2] == 3:
            img_LR = img_LR[:, :, [2, 1, 0]]
        img_LR = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LR, (2, 0, 1)))).float()

        return {'LQ': img_LR, 'ker': ker_map, 'LQ_path': LR_path}

    def __len__(self):
        """Return the number of LR images."""
        return len(self.LR_paths)