from os import path as osp
from torch.utils import data
from torchvision import transforms
from PIL import Image
import numpy as np


class LIP(data.Dataset):
    """LIP (Look Into Person) human-parsing dataset.

    Expects ``<root>/train`` and/or ``<root>/val`` split folders. Each holds an
    ``id.txt`` file (one sample id per line), an ``image/`` folder with
    ``<id>.jpg`` files and a ``gt/`` folder with ``<id>.png`` label maps.

    Each item is an ``(img, gt)`` pair: the image as an RGB PIL image and the
    label map as a single-band PIL image. Each is passed through its optional
    transform.
    """

    def __init__(self, root, train=True, transform=None, gt_transform=None):
        """
        Args:
            root: dataset root containing the ``train`` / ``val`` split folders.
            train: if True load the ``train`` split, otherwise the ``val`` split.
            transform: optional callable applied to the RGB image.
            gt_transform: optional callable applied to the label map.
        """
        self.root = root
        self.transform = transform
        self.gt_transform = gt_transform
        self.train = train  # True -> 'train' split, False -> 'val' split

        if self.train:
            self.train_image_path, self.train_gt_path = self.read_labeled_image_list(osp.join(root, 'train'))
        else:
            self.val_image_path, self.val_gt_path = self.read_labeled_image_list(osp.join(root, 'val'))

    def _active_paths(self):
        """Return ``(image_paths, gt_paths)`` for the split selected by ``self.train``."""
        if self.train:
            return self.train_image_path, self.train_gt_path
        return self.val_image_path, self.val_gt_path

    def __getitem__(self, index):
        image_path, gt_path = self._active_paths()
        return self.get_a_sample(image_path, gt_path, index)

    def __len__(self):
        image_path, _ = self._active_paths()
        return len(image_path)

    def get_a_sample(self, image_path, gt_path, index):
        """Load and transform the ``index``-th (image, label map) pair.

        Args:
            image_path: list of image file paths.
            gt_path: list of label-map file paths, aligned with ``image_path``.
            index: position of the sample in both lists.

        Returns:
            ``(img, gt)`` after the optional ``transform`` / ``gt_transform``.
        """
        img = Image.open(image_path[index])
        # Compare the mode, not the band count: 3-band modes such as
        # YCbCr / LAB / HSV are not RGB and still need converting.
        if img.mode != 'RGB':
            img = img.convert('RGB')
        gt = Image.open(gt_path[index])
        # Only multi-band label files are collapsed to 'L'. Single-band modes
        # are kept untouched so their raw label values survive. For example,
        # converting a 'P' image to 'L' would remap indices through the palette.
        # NOTE(review): assumes label PNGs are stored as 'L' or 'P' — confirm.
        if len(gt.getbands()) != 1:
            gt = gt.convert('L')

        if self.transform is not None:
            img = self.transform(img)
        if self.gt_transform is not None:
            gt = self.gt_transform(gt)
        return img, gt

    def read_labeled_image_list(self, data_dir):
        """Build aligned image / label path lists from ``<data_dir>/id.txt``.

        Each non-blank line ``<id>`` maps to ``<data_dir>/image/<id>.jpg`` and
        ``<data_dir>/gt/<id>.png``. Blank lines (e.g. a trailing empty line) are
        skipped rather than producing bogus ``.jpg`` / ``.png`` paths.

        Returns:
            ``(image_path, gt_path)``: two lists of equal length.
        """
        image_path = []
        gt_path = []
        # ``with`` guarantees the id file is closed once the list is built.
        with open(osp.join(data_dir, 'id.txt'), 'r') as id_file:
            for line in id_file:
                sample_id = line.strip()
                if not sample_id:
                    continue
                image_path.append(osp.join(data_dir, 'image', sample_id + ".jpg"))
                gt_path.append(osp.join(data_dir, 'gt', sample_id + ".png"))
        return image_path, gt_path


class LIPWithClass(LIP):
    """LIP dataset variant that also returns per-image class-presence labels.

    Each item is ``(img, gt, gt_cls)``. ``gt_cls`` is a uint8 vector of
    length ``num_cls`` with ``gt_cls[T] == 1`` if label value ``T`` occurs
    in the label map and 0 otherwise. A DataLoader batch stacks these
    into shape ``(batch, num_cls)``.
    """

    def __init__(self, root, num_cls=20, train=True, transform=None, gt_transform=None):
        """
        Args:
            root: dataset root containing the ``train`` / ``val`` split folders.
            num_cls: number of label values considered when building ``gt_cls``.
            train: if True load the ``train`` split, otherwise the ``val`` split.
            transform: optional callable applied to the RGB image.
            gt_transform: optional callable applied to the label map.
        """
        super().__init__(root, train, transform, gt_transform)
        self.num_cls = num_cls

    def __getitem__(self, index):
        if self.train:
            img, gt, gt_cls = self.get_a_sample(self.train_image_path, self.train_gt_path, index)
        else:
            img, gt, gt_cls = self.get_a_sample(self.val_image_path, self.val_gt_path, index)
        return img, gt, gt_cls

    def get_a_sample(self, image_path, gt_path, index):
        """Load a sample and compute its class-presence vector.

        Args:
            image_path: list of image file paths.
            gt_path: list of label-map file paths, aligned with ``image_path``.
            index: position of the sample in both lists.

        Returns:
            ``(img, gt, gt_cls)``:
            - ``img`` and ``gt`` have been through the optional transforms.
            - ``gt_cls`` is a uint8 array of length ``num_cls``.
        """
        img = Image.open(image_path[index])
        # Compare the mode, not the band count: 3-band non-RGB modes
        # (YCbCr / LAB / HSV) must still be converted.
        if img.mode != 'RGB':
            img = img.convert('RGB')
        gt = Image.open(gt_path[index])
        # Single-band label files (e.g. 'P' palette images) are left as-is so
        # their raw label values are kept.
        if len(gt.getbands()) != 1:
            gt = gt.convert('L')

        # Class presence is computed from the raw label map, before
        # gt_transform can resize or rescale it. Label values >= num_cls are
        # not counted.
        gt_np = np.asarray(gt, dtype=np.uint8)
        counts = np.bincount(gt_np.ravel(), minlength=self.num_cls)[:self.num_cls]
        # np.bool was removed in NumPy 1.24; a plain comparison avoids it.
        gt_cls = (counts > 0).astype(np.uint8)

        if self.transform is not None:
            img = self.transform(img)
        if self.gt_transform is not None:
            gt = self.gt_transform(gt)

        return img, gt, gt_cls


if __name__ == "__main__":

    import matplotlib.pyplot as plt

    # Raw string: the Windows path contains backslash sequences ('\D', '\L',
    # '\s') that are invalid escapes in a normal string literal.
    path = r'K:\Dataset\LIP\single'

    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    transform_image_list = [
        transforms.Resize((512, 512), Image.BILINEAR),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]

    transform_gt_list = [
        transforms.Resize((30, 30), Image.NEAREST),
        # Stretch label values 0..19 over 0..255 so the map is visible.
        transforms.Lambda(lambda image: Image.fromarray(np.uint8(np.asarray(image)*(255.0/19.0)))),
        transforms.ToTensor(),
    ]

    data_transforms = {
        'image': transforms.Compose(transform_image_list),
        'gt': transforms.Compose(transform_gt_list),
    }

    loader = data.DataLoader(LIP(path, transform=data_transforms['image'], gt_transform=data_transforms['gt']),
                             batch_size=2, shuffle=False)

    def denormalize(image, mean, std):
        """Undo ``transforms.Normalize`` in place on a (C, H, W) array and return it."""
        for idx in range(image.shape[0]):
            image[idx, :, :] = image[idx, :, :] * std[idx] + mean[idx]
        return image

    for count, (src, lab) in enumerate(loader):
        src = src[0].numpy()
        lab = lab[0].numpy().transpose(1, 2, 0)
        # Denormalised values can drift slightly outside [0, 1]; clip them so
        # imshow does not warn about out-of-range RGB data.
        src = np.clip(denormalize(src, mean, std).transpose(1, 2, 0), 0.0, 1.0)

        plt.subplot(121)
        plt.imshow(src)
        plt.subplot(122)
        plt.imshow(np.concatenate([lab, lab, lab], axis=2), cmap='gray')
        plt.show()
        print(src.shape)
        if count + 1 == 4:
            break