import numpy as np
import cv2
import pickle
import os

class LoadData:
    '''
    Class to laod the data
    '''
    def __init__(self, data_dir, classes, cached_data_file, normVal=1.10, additional=None):
        '''
        :param data_dir: directory where the dataset is kept
        :param classes: number of classes in the dataset
        :param cached_data_file: location where cached file has to be stored
        :param normVal: normalization value, as defined in ERFNet paper
        '''
        self.data_dir = data_dir
        self.classes = classes
        self.classWeights = np.ones(self.classes, dtype=np.float32)
        self.normVal = normVal
        self.mean = np.zeros(3, dtype=np.float32)
        self.std = np.zeros(3, dtype=np.float32)
        self.trainImList = list()
        self.valImList = list()
        self.addvalImList = list()
        self.trainAnnotList = list()
        self.valAnnotList = list()
        self.addvalAnnotList = list()
        self.cached_data_file = cached_data_file
        self.train_txt=list()
        self.val_txt=list()
        self.additional = additional

    def compute_class_weights(self, histogram):
        '''
        Helper function to compute the class weights
        :param histogram: distribution of class samples
        :return: None, but updates the classWeights variable
        '''
        normHist = histogram / np.sum(histogram)
        for i in range(self.classes):
            self.classWeights[i] = 1 / (np.log(self.normVal + normHist[i]))

    def readFile(self, fileName, trainStg=False, addtional=None):
        '''
        Function to read the data
        :param fileName: file that stores the image locations
        :param trainStg: if processing training or validation data
        :return: 0 if successful
        '''
        if trainStg == True:
            global_hist = np.zeros(self.classes, dtype=np.float32)

        no_files = 0
        min_val_al = 0
        max_val_al = 0
        with open(self.data_dir + '/Portrait/' + fileName, 'r') as textFile:
            for line in textFile:
                # we expect the text file to contain the data in following format
                # <RGB Image>, <Label Image>
                # line_arr = line.split(',')
                name_num = int(line)
                img_file = ((self.data_dir).strip() + '/Portrait/images_data_crop/' + str(name_num).zfill(5) +'.jpg')
                label_file = ((self.data_dir).strip() + '/Portrait/GT_png/' + str(name_num).zfill(5) + '_mask.png')

                # print(label_file)
                label_img = cv2.imread(label_file, 0)
                label_img = label_img/255
                unique_values = np.unique(label_img)
                max_val = max(unique_values)
                min_val = min(unique_values)

                max_val_al = max(max_val, max_val_al)
                min_val_al = min(min_val, min_val_al)

                if trainStg == True:
                    hist = np.histogram(label_img, self.classes)
                    global_hist += hist[0]
                    try:
                        rgb_img = cv2.imread(img_file)
                        self.mean[0] += np.mean(rgb_img[:, :, 0])
                        self.mean[1] += np.mean(rgb_img[:, :, 1])
                        self.mean[2] += np.mean(rgb_img[:, :, 2])

                        self.std[0] += np.std(rgb_img[:, :, 0])
                        self.std[1] += np.std(rgb_img[:, :, 1])
                        self.std[2] += np.std(rgb_img[:, :, 2])

                        self.trainImList.append(img_file)
                        self.trainAnnotList.append(label_file)
                        no_files += 1
                        self.train_txt.append(str(name_num).zfill(5))
                    except:
                        print("Train has problem" + img_file)
                else:

                    rgb_img = cv2.imread(img_file)
                    try:
                        if len(rgb_img.shape) >2:
                            self.valImList.append(img_file)
                            self.valAnnotList.append(label_file)
                            self.val_txt.append(str(name_num).zfill(5))

                        else:
                            print("Val has problem" + img_file)
                    except:
                        print("Val has problem" + img_file)

                if max_val > (self.classes - 1) or min_val < 0:
                    print('Labels can take value between 0 and number of classes.')
                    print('Some problem with labels. Please check.')
                    print('Label Image ID: ' + label_file)

     ############ add additional dataset with ##################################

        if addtional !=None:
            for i in range(len(addtional)):
                this_additoinal = addtional[i]
                print(this_additoinal)
                with open(self.data_dir + this_additoinal + fileName, 'r') as textFile:
                    for line in textFile:
                        # we expect the text file to contain the data in following format
                        # <RGB Image>, <Label Image>
                        # line_arr = line.split(',')
                        img_file = ((self.data_dir).strip() + this_additoinal+'input/' + line.strip())
                        label_file = ((self.data_dir).strip() + this_additoinal+'target/' + line.strip())
                        label_img = cv2.imread(label_file, 0)
                        if os.path.isfile(label_file) == True:

                            label_bool = 255*((label_img >200).astype(np.uint8))
                            label_img = label_bool / 255
                            unique_values = np.unique(label_img)
                            max_val = max(unique_values)
                            min_val = min(unique_values)

                            max_val_al = max(max_val, max_val_al)
                            min_val_al = min(min_val, min_val_al)

                            if trainStg == True:
                                hist = np.histogram(label_img, self.classes)
                                global_hist += hist[0]
                                try:
                                    rgb_img = cv2.imread(img_file)
                                    self.mean[0] += np.mean(rgb_img[:, :, 0])
                                    self.mean[1] += np.mean(rgb_img[:, :, 1])
                                    self.mean[2] += np.mean(rgb_img[:, :, 2])

                                    self.std[0] += np.std(rgb_img[:, :, 0])
                                    self.std[1] += np.std(rgb_img[:, :, 1])
                                    self.std[2] += np.std(rgb_img[:, :, 2])

                                    self.trainImList.append(img_file)
                                    self.trainAnnotList.append(label_file)
                                    no_files += 1

                                except:
                                    print("Train has problem" + img_file)
                            else:

                                rgb_img = cv2.imread(img_file)
                                try:
                                    if len(rgb_img.shape) > 2:
                                        self.addvalImList.append(img_file)
                                        self.addvalAnnotList.append(label_file)
                                    else:
                                        print("add Val has problem" + img_file)
                                except:
                                    print("add Val has problem" + img_file)

                            if max_val > (self.classes - 1) or min_val < 0:
                                print('Labels can take value between 0 and number of classes.')
                                print('Some problem with labels. Please check.')
                                print('Label Image ID: ' + label_file)
                        else:
                            print(label_file)

        if trainStg == True:
            # divide the mean and std values by the sample space size
            self.mean /= no_files
            self.std /= no_files

            #compute the class imbalance information
            self.compute_class_weights(global_hist)
        return 0

    def processDataAug(self):
        '''
        main.py calls this function
        We expect train.txt and val.txt files to be inside the data directory.
        :return:
        '''
        print('Processing training data')

        return_val1 = self.readFile('train.txt', True, addtional=self.additional)

        print('Processing validation data')
        return_val2 = self.readFile('val.txt', addtional=self.additional )

        print('Pickling data')
        if (return_val1 ==0 and return_val2 ==0 ):
            data_dict = dict()
            data_dict['trainIm'] = self.trainImList
            data_dict['trainAnnot'] = self.trainAnnotList
            data_dict['valIm'] = self.valImList
            data_dict['valAnnot'] = self.valAnnotList
            data_dict['addvalIm'] = self.addvalImList
            data_dict['addvalAnnot'] = self.addvalAnnotList
            data_dict['mean'] = self.mean
            data_dict['std'] = self.std
            data_dict['classWeights'] = self.classWeights
            if not os.path.isdir("./pickle_file"):
                os.mkdir("./pickle_file")
            pickle.dump(data_dict, open(self.cached_data_file, "wb"))
            return data_dict
        else:
            print("There is problem")
            exit(0)
            return None

    def processData(self):
        '''
        main.py calls this function
        We expect train.txt and val.txt files to be inside the data directory.
        :return:
        '''
        print('Processing training data')

        return_val1 = self.readFile('train.txt', True)

        print('Processing validation data')
        return_val2 = self.readFile('val.txt')

        print('Pickling data')
        if (return_val1 == 0 and return_val2 == 0):
            data_dict = dict()
            data_dict['trainIm'] = self.trainImList
            data_dict['trainAnnot'] = self.trainAnnotList
            data_dict['valIm'] = self.valImList
            data_dict['valAnnot'] = self.valAnnotList

            data_dict['mean'] = self.mean
            data_dict['std'] = self.std
            data_dict['classWeights'] = self.classWeights
            if not os.path.isdir("./pickle_file"):
                os.mkdir("./pickle_file")
            pickle.dump(data_dict, open(self.cached_data_file, "wb"))
            # with open('EG1800_train.txt', 'w') as f:
            #     for item in self.train_txt:
            #         f.write("%s\n" % item)
            # with open('EG1800_val.txt', 'w') as f:
            #     for item in self.val_txt:
            #         f.write("%s\n" % item)
            return data_dict
        else:
            print("There is problem")
            exit(0)
            return None