import random
import numpy as np
import cv2
import torch
import torch.utils.data as data
import data.util as util


class LRHRSeg_BG_Dataset(data.Dataset):
    '''
    Read HR image, segmentation probability map; generate LR image, category for SFTGAN.
    Also sample general scenes for background during training.
    LR images are generated on-the-fly unless an LR dataroot is given (background
    samples are always downsampled on-the-fly, since they have no paired LR image).
    '''

    # Number of segmentation channels; channel 0 is the background class.
    NUM_SEG_CLASSES = 8
    # Category assigned to background samples and to paths matching no keyword.
    BACKGROUND_CATEGORY = 0
    # Ordered (keyword, category) pairs; the first keyword found in the HR path wins.
    CATEGORY_KEYWORDS = (
        ('building', 1),
        ('plant', 2),
        ('mountain', 3),
        ('water', 4),
        ('sky', 5),
        ('grass', 6),
        ('animal', 7),
    )

    def __init__(self, opt):
        super(LRHRSeg_BG_Dataset, self).__init__()
        self.opt = opt
        self.paths_LR = None
        self.paths_HR = None
        self.paths_HR_bg = None  # HR images for background scenes
        self.LR_env = None  # environment for lmdb
        self.HR_env = None
        self.HR_env_bg = None

        # read image list from lmdb or image files
        self.HR_env, self.paths_HR = util.get_image_paths(opt['data_type'], opt['dataroot_GT'])
        self.LR_env, self.paths_LR = util.get_image_paths(opt['data_type'], opt['dataroot_LR'])
        self.HR_env_bg, self.paths_HR_bg = util.get_image_paths(opt['data_type'],
                                                                opt['dataroot_GT_bg'])

        assert self.paths_HR, 'Error: HR path is empty.'
        if self.paths_LR:
            assert len(self.paths_LR) == len(self.paths_HR), \
                'HR and LR datasets have different number of images - {}, {}.'.format(
                    len(self.paths_HR), len(self.paths_LR))

        self.random_scale_list = [1, 0.9, 0.8, 0.7, 0.6, 0.5]
        # During training, on average one in every `ratio` samples is drawn from the
        # background set (ratio = 10 -> 9 OST samples : 1 general/DIV2K sample).
        self.ratio = 10

    def __getitem__(self, index):
        '''Return one sample as a dict.

        Keys: 'LR', 'HR', 'seg' (float tensors, CHW), 'category' (int; -1 outside
        the training phase), 'LR_path', 'HR_path' (LR_path falls back to HR_path
        when the LR image was generated on-the-fly).
        '''
        scale = self.opt['scale']
        HR_size = self.opt['HR_size']
        is_train = self.opt['phase'] == 'train'

        # get HR image and its segmentation map (both HWC numpy arrays)
        HR_path, img_HR, seg, is_bg = self._read_HR_and_seg(index, is_train)

        # modcrop in the validation / test phase
        if not is_train:
            img_HR = util.modcrop(img_HR, 8)
            # Keep seg spatially aligned with the cropped HR image.
            # NOTE(review): assumes util.modcrop keeps the top-left region — confirm.
            seg = seg[:img_HR.shape[0], :img_HR.shape[1], :]

        # get LR image
        LR_path = None
        if self.paths_LR and not is_bg:
            LR_path = self.paths_LR[index]
            img_LR = util.read_img(self.LR_env, LR_path)
        else:  # down-sampling on-the-fly (background samples have no paired LR file)
            if is_train:
                img_HR, seg = self._random_rescale(img_HR, seg, scale, HR_size)
            # using matlab imresize
            img_LR = util.imresize_np(img_HR, 1 / scale, True)
            if img_LR.ndim == 2:
                img_LR = np.expand_dims(img_LR, axis=2)

        if is_train:
            img_LR, img_HR, seg = self._random_crop(img_LR, img_HR, seg, scale, HR_size)
            # augmentation - flip, rotate
            img_LR, img_HR, seg = util.augment([img_LR, img_HR, seg], self.opt['use_flip'],
                                               self.opt['use_rot'])
            # Background samples carry an all-background seg map, so their category
            # must not be inferred from (possibly coincidental) path substrings.
            category = self.BACKGROUND_CATEGORY if is_bg else self._category_from_path(HR_path)
        else:
            category = -1  # during val, useless

        # BGR to RGB, HWC to CHW, numpy to tensor
        if img_HR.shape[2] == 3:
            img_HR = img_HR[:, :, [2, 1, 0]]
            img_LR = img_LR[:, :, [2, 1, 0]]
        img_HR = self._to_chw_tensor(img_HR)
        img_LR = self._to_chw_tensor(img_LR)
        seg = self._to_chw_tensor(seg)

        if LR_path is None:
            LR_path = HR_path
        return {
            'LR': img_LR,
            'HR': img_HR,
            'seg': seg,
            'category': category,
            'LR_path': LR_path,
            'HR_path': HR_path
        }

    def __len__(self):
        return len(self.paths_HR)

    def _read_HR_and_seg(self, index, is_train):
        '''Load an HR image and its seg map; return (HR_path, img_HR, seg_HWC, is_bg).

        In training, with probability 1 / self.ratio (and only if background paths
        exist) a random background image is used with an all-background seg map.
        '''
        if is_train and self.paths_HR_bg and random.randrange(self.ratio) == 0:
            HR_path = random.choice(self.paths_HR_bg)
            img_HR = util.read_img(self.HR_env_bg, HR_path)
            seg = np.zeros((img_HR.shape[0], img_HR.shape[1], self.NUM_SEG_CLASSES),
                           dtype=np.float32)
            seg[:, :, 0] = 1  # background
            return HR_path, img_HR, seg, True

        HR_path = self.paths_HR[index]
        img_HR = util.read_img(self.HR_env, HR_path)
        # read segmentation files, you should change the path mapping to your settings.
        # map_location='cpu' so .numpy() also works for tensors saved from a GPU.
        seg = torch.load(HR_path.replace('/img/', '/bicseg/').replace('.png', '.pth'),
                         map_location='cpu')
        seg = np.transpose(seg.numpy(), (1, 2, 0))  # CHW -> HWC
        return HR_path, img_HR, seg, False

    def _random_rescale(self, img_HR, seg, scale, HR_size):
        '''Resize HR image and seg by a random factor from self.random_scale_list.

        Each side is rounded down to a multiple of `scale` and clamped to at least
        HR_size. The seg map uses nearest-neighbour interpolation.
        '''
        random_scale = random.choice(self.random_scale_list)
        H_s, W_s = seg.shape[:2]
        H_s = self._scaled_size(H_s, random_scale, scale, HR_size)
        W_s = self._scaled_size(W_s, random_scale, scale, HR_size)
        img_HR = cv2.resize(np.copy(img_HR), (W_s, H_s), interpolation=cv2.INTER_LINEAR)
        if img_HR.ndim == 2:  # cv2.resize drops a singleton channel axis
            img_HR = np.expand_dims(img_HR, axis=2)
        seg = cv2.resize(np.copy(seg), (W_s, H_s), interpolation=cv2.INTER_NEAREST)
        return img_HR, seg

    @staticmethod
    def _scaled_size(n, random_scale, scale, thres):
        '''Scale n, round down to a multiple of `scale`, and clamp to at least `thres`.'''
        rlt = (int(n * random_scale) // scale) * scale
        return max(rlt, thres)

    @staticmethod
    def _random_crop(img_LR, img_HR, seg, scale, HR_size):
        '''Crop aligned random patches: LR of HR_size // scale, HR and seg of HR_size.'''
        LR_size = HR_size // scale
        H, W = img_LR.shape[:2]
        rnd_h = random.randint(0, max(0, H - LR_size))
        rnd_w = random.randint(0, max(0, W - LR_size))
        img_LR = img_LR[rnd_h:rnd_h + LR_size, rnd_w:rnd_w + LR_size, :]
        rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale)
        img_HR = img_HR[rnd_h_HR:rnd_h_HR + HR_size, rnd_w_HR:rnd_w_HR + HR_size, :]
        seg = seg[rnd_h_HR:rnd_h_HR + HR_size, rnd_w_HR:rnd_w_HR + HR_size, :]
        return img_LR, img_HR, seg

    @classmethod
    def _category_from_path(cls, path):
        '''Map an HR path to a category id via the first matching keyword (else 0).'''
        for keyword, category in cls.CATEGORY_KEYWORDS:
            if keyword in path:
                return category
        return cls.BACKGROUND_CATEGORY

    @staticmethod
    def _to_chw_tensor(arr):
        '''Convert an HWC numpy array to a contiguous float32 CHW tensor.'''
        return torch.from_numpy(np.ascontiguousarray(np.transpose(arr, (2, 0, 1)))).float()