import os
import torch
import torch.utils.data as data
import torchvision.transforms as transforms

from PIL import Image


def parse_file(dataset_adr, categories):
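    """Read an image-list file and keep entries whose category appears in ``categories``.

    Each line is assumed to be a path of the form '/<letter>/<category>/<image>'
    (e.g. '/a/abbey/sun_xxx.jpg'); the leading empty field and the one-letter
    directory are dropped, giving a category such as 'abbey' and a file name such
    as 'abbey/sun_xxx.jpg'. Returns a list of [file_name, category] pairs.
    """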
    dataset = []
    with open(dataset_adr) as f:
        for line in f:
            line = line.rstrip('\n').split('/')
            # Drop the leading empty field and the one-letter directory; what
            # remains is the category path and the category-relative file name.
            category = '/'.join(line[2:-1])
            file_name = '/'.join(line[2:])
            if category not in categories:
                continue
            dataset.append([file_name, category])
    return dataset


def get_class_names(path):
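    """Build a class-name -> index dictionary from a class list file.

    Each line is assumed to be a path such as '/a/abbey'; the leading empty field
    and one-letter directory are dropped, so the key becomes e.g. 'abbey' (nested
    categories keep their remaining sub-path). Indices follow the line order.
    """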
    classes = []
    with open(path) as f:
        for line in f:
            categ = '/'.join(line.rstrip('\n').split('/')[2:])
            classes.append(categ)
    class_dic = {name: i for i, name in enumerate(classes)}
    return class_dic


class SunDataset(data.Dataset):

    CLASS_WEIGHTS = None

    def __init__(self, args, train=True):
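        """Build the SUN scene-classification dataset.

        ``args`` is expected to provide ``data`` (the dataset root),
        ``trainset_image_list``, ``testset_image_list``, ``read_features``,
        ``features_dir`` and ``image_size``. The chosen image-list file is parsed
        into (file name, category) pairs and a train/test transform pipeline is
        set up.
        """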
        self.root_dir = args.data
        root_dir = self.root_dir
        if train:
            self.data_set_list = os.path.join(root_dir,
                                              args.trainset_image_list)
        else:
            self.data_set_list = os.path.join(root_dir, args.testset_image_list)

        self.categ_dict = get_class_names(
            os.path.join(root_dir, 'ClassName.txt'))

        self.data_set_list = parse_file(self.data_set_list, self.categ_dict)

        self.args = args
        self.read_features = args.read_features

        self.features_dir = args.features_dir
        if train:
            # Training: random crop and horizontal flip for augmentation, then a
            # resize to a fixed square and normalisation with ImageNet statistics.
            self.transform = transforms.Compose([
                transforms.RandomResizedCrop(args.image_size),
                transforms.RandomHorizontalFlip(),
                transforms.Resize((args.image_size, args.image_size)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
            ])
        else:
            # Evaluation: deterministic resize and the same normalisation.
            self.transform = transforms.Compose([
                transforms.Resize((args.image_size, args.image_size)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
            ])

    def get_relative_centroids(self):
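        # Placeholder: this dataset does not provide relative centroids.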
        return None

    def __len__(self):
        return len(self.data_set_list)

    def load_and_resize(self, img_name):
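        """Open an image as RGB and apply the transform pipeline for this split."""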
        with open(img_name, 'rb') as fp:
            image = Image.open(fp).convert('RGB')
        return self.transform(image)

    def __getitem__(self, idx):
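        """Return (image, label, 0, 0, [file_name]) for the sample at ``idx``.

        The two zeros appear to be placeholders kept for interface compatibility
        with other datasets in this codebase.
        """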
        file_name, categ = self.data_set_list[idx]
        try:
            # Some copies of the data appear to store images with a trailing '~'
            # in the file name; try that variant first, then the plain name.
            image = self.load_and_resize(
                os.path.join(self.root_dir, 'all_data', file_name + '~'))
        except Exception:
            image = self.load_and_resize(
                os.path.join(self.root_dir, 'all_data', file_name))
        if categ not in self.categ_dict:
            raise KeyError('Unknown category: {}'.format(categ))
        label = torch.tensor([self.categ_dict[categ]], dtype=torch.long)
        return (image, label, 0, 0, [file_name])
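

# Minimal usage sketch (illustrative only, not part of the training pipeline):
# build the dataset from an argparse-style namespace and wrap it in a DataLoader.
# The attribute names mirror how they are read in SunDataset.__init__ above; the
# concrete paths and list-file names below are placeholders and must match your setup.
if __name__ == '__main__':
    from argparse import Namespace
    from torch.utils.data import DataLoader

    args = Namespace(
        data='/path/to/SUN397',                 # dataset root (placeholder)
        trainset_image_list='Training_01.txt',  # assumed image-list file names
        testset_image_list='Testing_01.txt',
        read_features=False,
        features_dir='',
        image_size=224,
    )

    dataset = SunDataset(args, train=True)
    loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

    images, labels, _, _, file_names = next(iter(loader))
    print(images.shape, labels.shape)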