from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import os
import cv2
import h5py
import numpy as np
import tensorflow as tf

from glob import glob
from tqdm import tqdm
from util import split
from config import get_config
from multiprocessing import Pool


# Configuration
# get_config() returns a 2-tuple; only the first element (the parsed
# config object) is kept here — TODO confirm what the discarded part is.
config, _ = get_config()

# Global RNG seed, used below for the deterministic data shuffle.
seed = config.seed


class DataSetLoader:
    """Generic image data-set loader/converter.

    Loads a data set stored as raw images (jpg/png), TFRecords (tfr),
    HDF5 (h5) or NumPy (npy) files into ``self.raw_data`` and can convert
    it back out to any of those formats.

    ``raw_data`` ends up as a (N, H * W * C) uint8 numpy array for the
    img/h5/npy sources, or as a ``tf.data.Dataset`` for the tfr source.
    """

    @staticmethod
    def get_extension(ext):
        """Map a file extension to one of the supported loader types.

        :param ext: file extension without the dot, e.g. 'jpg'.
        :return: one of 'img', 'tfr', 'h5', 'npy'.
        :raises ValueError: for an unsupported extension.
        """
        # jpg & png both map to the generic 'img' loader.
        ext_to_type = {'jpg': 'img', 'png': 'img', 'tfr': 'tfr', 'h5': 'h5', 'npy': 'npy'}
        try:
            return ext_to_type[ext]
        except KeyError:
            raise ValueError("[-] There's no supporting file... [%s] :(" % ext)

    @staticmethod
    def get_img(path, size=(64, 64), interp=cv2.INTER_CUBIC):
        """Read an image from `path` as RGB, resized to `size` ((height, width)).

        :raises ValueError: if the file cannot be read as an image.
        """
        img = cv2.imread(path, cv2.IMREAD_COLOR)
        if img is None:  # cv2.imread returns None on failure instead of raising
            raise ValueError("[-] Could not read image : %s" % path)
        img = img[..., ::-1]  # BGR to RGB

        # Fix: compare (H, W) against `size`; the old `shape[:1]` was a
        # 1-tuple and could never equal a 2-tuple, so every image was resized.
        if img.shape[:2] == tuple(size):
            return img
        # cv2.resize takes dsize as (width, height), and its 3rd *positional*
        # argument is `dst`, so the interpolation flag must be passed by keyword.
        return cv2.resize(img, (size[1], size[0]), interpolation=interp)

    @staticmethod
    def parse_tfr_tf(record):
        """Parse one serialized TFRecord example into an image tensor (TF graph mode)."""
        features = tf.parse_single_example(record, features={
            'shape': tf.FixedLenFeature([3], tf.int64),
            'data': tf.FixedLenFeature([], tf.string)})
        data = tf.decode_raw(features['data'], tf.uint8)
        return tf.reshape(data, features['shape'])

    @staticmethod
    def parse_tfr_np(record):
        """Parse one serialized TFRecord example into a numpy uint8 array."""
        ex = tf.train.Example()
        ex.ParseFromString(record)
        shape = ex.features.feature['shape'].int64_list.value
        data = ex.features.feature['data'].bytes_list.value[0]
        # np.fromstring is deprecated for binary input; frombuffer is the
        # supported equivalent (note: it returns a read-only view of `data`).
        return np.frombuffer(data, np.uint8).reshape(shape)

    @staticmethod
    def img_scaling(img, scale='0,1'):
        """Scale pixel data into [0, 1] or [-1, 1].

        :param img: numpy array of pixels (typically uint8 in [0, 255]).
        :param scale: '0,1' or '-1,1'.
        :raises ValueError: for any other scale string.
        """
        if scale == '0,1':
            try:
                img /= 255.
            except TypeError:  # in-place true-divide is illegal on an int array
                img = np.true_divide(img, 255.0, casting='unsafe')
        elif scale == '-1,1':
            try:
                img = (img / 127.5) - 1.
            except TypeError:
                img = np.true_divide(img, 127.5, casting='unsafe') - 1.
        else:
            raise ValueError("[-] Only '0,1' or '-1,1' please - (%s)" % scale)

        return img

    def __init__(self, path, size=None, name='to_tfr', use_save=False, save_file_name='',
                 buffer_size=4096, n_threads=8,
                 use_image_scaling=False, image_scale='0,1', img_save_method=cv2.INTER_LINEAR, debug=True):
        """
        :param path: a single file or a directory holding the data set.
        :param size: target (height, width, channel); required.
        :param name: operation name like 'to_tfr'; the part after '_' is the save format.
        :param use_save: convert & save `raw_data` into the target format.
        :param save_file_name: output file/directory name (required when use_save).
        :param buffer_size: TFRecord read buffer size.
        :param n_threads: parallel calls for the tfr dataset map.
        :param use_image_scaling: scale pixels after loading.
        :param image_scale: '0,1' or '-1,1'.
        :param img_save_method: cv2 interpolation flag used when resizing.
        :param debug: print one-time diagnostics.
        """
        self.op = name.split('_')  # e.g. 'to_tfr' -> ['to', 'tfr']
        self.debug = debug

        if len(self.op) != 2:
            raise AssertionError("[-] Invalid Target Types :(")

        self.size = size

        if not self.size:
            raise AssertionError("[-] Invalid Target Sizes :(")

        # To-DO
        # Supporting 4D Image
        self.height = size[0]
        self.width = size[1]
        self.channel = size[2]

        self.path = path

        if not os.path.exists(self.path):
            raise AssertionError("[-] Path(%s) does not exist :(" % self.path)

        self.buffer_size = buffer_size
        self.n_threads = n_threads

        if os.path.isfile(self.path):
            self.file_list = [self.path]
            self.file_ext = self.path.split('.')[-1]
            self.file_names = [self.path]
        else:
            self.file_list = sorted(os.listdir(self.path))
            self.file_ext = self.file_list[0].split('.')[-1]
            # sorted for a deterministic order — glob order is filesystem-dependent
            self.file_names = sorted(glob(self.path + '/*'))
        self.raw_data = np.ndarray([], dtype=np.uint8)  # placeholder; becomes (N, H * W * C)

        if self.debug:
            print("[*] Detected Path            is [%s]" % self.path)
            print("[*] Detected File Extension  is [%s]" % self.file_ext)
            print("[*] Detected First File Name is [%s] (%d File(s))" % (self.file_names[0], len(self.file_names)))

        self.types = ('img', 'tfr', 'h5', 'npy')  # Supporting Data Types
        self.op_src = self.get_extension(self.file_ext)
        self.op_dst = self.op[1]

        if self.op_src not in self.types or self.op_dst not in self.types:
            raise AssertionError("[-] Invalid Operation Types (%s, %s) :(" % (self.op_src, self.op_dst))

        self.img_save_method = img_save_method

        if self.op_src == self.types[0]:
            self.load_img()
        elif self.op_src == self.types[1]:
            self.load_tfr()
        elif self.op_src == self.types[2]:
            self.load_h5()
        elif self.op_src == self.types[3]:
            self.load_npy()
        else:
            raise NotImplementedError("[-] Not Supported Type :(")

        # The tfr loader yields a tf.data.Dataset, which has no .shape and
        # cannot be shuffled/clipped/scaled in-memory — only numpy data is.
        if isinstance(self.raw_data, np.ndarray):
            # Deterministic, seeded shuffle so separate loaders (e.g. the HR
            # and LR sets below) permute their samples identically.
            order = np.arange(self.raw_data.shape[0])
            np.random.RandomState(seed).shuffle(order)
            self.raw_data = self.raw_data[order]

            # Clip to the valid pixel range [0, 255]
            self.raw_data = np.rint(self.raw_data).clip(0, 255).astype(np.uint8)

        self.use_save = use_save
        self.save_file_name = save_file_name

        if self.use_save:
            if not self.save_file_name:
                raise AssertionError("[-] Empty save-file name :(")

            if self.op_dst == self.types[0]:
                self.convert_to_img()
            elif self.op_dst == self.types[1]:
                self.tfr_opt = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.NONE)
                self.tfr_writer = tf.python_io.TFRecordWriter(self.save_file_name + ".tfrecords", self.tfr_opt)
                self.convert_to_tfr()
            elif self.op_dst == self.types[2]:
                self.convert_to_h5()
            elif self.op_dst == self.types[3]:
                self.convert_to_npy()
            else:
                raise NotImplementedError("[-] Not Supported Type :(")

        self.use_image_scaling = use_image_scaling
        self.img_scale = image_scale

        if self.use_image_scaling and isinstance(self.raw_data, np.ndarray):
            self.raw_data = self.img_scaling(self.raw_data, self.img_scale)

    def load_img(self):
        """Load every image file into `raw_data` as flattened uint8 rows."""
        self.raw_data = np.zeros((len(self.file_list), self.height * self.width * self.channel),
                                 dtype=np.uint8)

        for i, fn in tqdm(enumerate(self.file_names), total=len(self.file_names)):
            self.raw_data[i] = self.get_img(fn, (self.height, self.width), self.img_save_method).flatten()
            if self.debug:  # print diagnostics just once
                print("[*] Image Shape   : ", self.raw_data[i].shape)
                print("[*] Image Size    : ", self.raw_data[i].size)
                print("[*] Image MIN/MAX :  (%d, %d)" % (np.min(self.raw_data[i]), np.max(self.raw_data[i])))
                self.debug = False

    def load_tfr(self):
        """Build a tf.data.Dataset that decodes the TFRecord files lazily."""
        self.raw_data = tf.data.TFRecordDataset(self.file_names, compression_type='', buffer_size=self.buffer_size)
        self.raw_data = self.raw_data.map(self.parse_tfr_tf, num_parallel_calls=self.n_threads)

    def load_h5(self, size=0, offset=0):
        """Load the 'images' data set from one or more .h5 files.

        :param size: chunk size to read per file; 0 means the whole file.
        :param offset: chunk index to start at (wraps around if too large).
        """
        chunks = []

        # Iterate full paths (`file_names`); the bare `file_list` entries from
        # os.listdir() are not openable when `path` is a directory.
        for fn in self.file_names:
            with h5py.File(fn, 'r') as hf:
                data = hf['images']
                full_size = len(data)

                if size == 0:
                    size = full_size

                n_chunks = int(np.ceil(full_size / size))
                if offset >= n_chunks:
                    print("[*] Looping from back to start.")
                    offset %= n_chunks
                if offset == n_chunks - 1:
                    # Last chunk may be short; take everything up to the end.
                    print("[-] Not enough data available, clipping to end.")
                    data = data[offset * size:]
                else:
                    data = data[offset * size:(offset + 1) * size]

                data = np.array(data, dtype=np.uint8)
                print("[+] ", fn, " => Image size : ", data.shape)

                chunks.append(data)

                if self.debug:  # print diagnostics just once
                    print("[*] Image Shape   : ", data[0].shape)
                    print("[*] Image Size    : ", data[0].size)
                    print("[*] Image MIN/MAX :  (%d, %d)" % (np.min(data[0]), np.max(data[0])))
                    self.debug = False

        if chunks:
            self.raw_data = chunks[0] if len(chunks) == 1 else np.concatenate(chunks)

    def load_npy(self):
        """Load a single .npy file (np.load takes one path, not the whole list)."""
        self.raw_data = np.rollaxis(np.squeeze(np.load(self.file_names[0]), axis=0), 0, 3)

        if self.debug:  # print diagnostics just once
            print("[*] Image Shape   : ", self.raw_data[0].shape)
            print("[*] Image Size    : ", self.raw_data[0].size)
            print("[*] Image MIN/MAX :  (%d, %d)" % (np.min(self.raw_data[0]), np.max(self.raw_data[0])))
            self.debug = False

    def convert_to_img(self):
        """Dump `raw_data` as individual .png files inside `save_file_name`/.

        Written sequentially: the previous multiprocessing.Pool version used a
        local closure, which cannot be pickled, and wrote the constant
        `cv2.COLOR_BGR2RGB` instead of the image data.
        """
        if not os.path.exists(self.save_file_name):
            print("[-] There's no %s :(" % self.save_file_name)
            print("[*] Make directory at %s... " % self.save_file_name)
            os.mkdir(self.save_file_name)

        for i in tqdm(range(self.raw_data.shape[0])):
            img = self.raw_data[i].reshape((self.height, self.width, self.channel))
            # raw_data is RGB; cv2.imwrite expects BGR, so flip the channel axis.
            cv2.imwrite(os.path.join(self.save_file_name, 'imgHQ%05d.png' % i), img[..., ::-1])

    def convert_to_tfr(self):
        """Serialize every image into the TFRecord writer opened in __init__."""
        for data in self.raw_data:
            ex = tf.train.Example(features=tf.train.Features(feature={
                'shape': tf.train.Feature(int64_list=tf.train.Int64List(value=data.shape)),
                'data': tf.train.Feature(bytes_list=tf.train.BytesList(value=[data.tostring()]))
            }))
            self.tfr_writer.write(ex.SerializeToString())
        self.tfr_writer.close()  # flush & release the file handle

    def convert_to_h5(self):
        """Save `raw_data` into a single .h5 file under the 'images' key."""
        with h5py.File(self.save_file_name, 'w') as f:
            f.create_dataset("images", data=self.raw_data)

    def convert_to_npy(self):
        """Save `raw_data` as a single .npy file."""
        np.save(self.save_file_name, self.raw_data)


class Div2KDataSet:
    """DIV2K super-resolution data set.

    Loads paired HR/LR image sets via DataSetLoader and, when `n_patch` > 0,
    crops every image into `n_patch` aligned sub-patches.
    """

    def __init__(self, hr_height=768, hr_width=768, lr_height=192, lr_width=192, channel=3,
                 n_patch=16, use_split=False, split_rate=0.1, random_state=42, n_threads=8,
                 ds_path=None, ds_name=None, use_img_scale=True,
                 ds_hr_path=None, ds_lr_path=None,
                 use_save=False, save_type='to_h5', save_file_name=None, debug=False):

        """
        # General Settings
        :param hr_height: input HR image height, default 768
        :param hr_width: input HR image width, default 768
        :param lr_height: input LR image height, default 192
        :param lr_width: input LR image width, default 192
        :param channel: input image channel, default 3 (RGB)
        - in case of Div2K - ds x4, image size is 768 x 768 x 3 (HWC).

        # Pre-Processing Option
        :param n_patch: patch count per image (a perfect square), default 16
        :param use_split: split images into train & test sets?, default False
        :param split_rate: image split rate (into train & test), default 0.1
        :param random_state: random seed for shuffling, default 42
        :param n_threads: the number of threads for multi-threading, default 8

        # DataSet Option
        :param ds_path: DataSet's Path, default None
        :param ds_name: DataSet's Name, default None
        :param use_img_scale: using img scaling?, default True
        :param ds_hr_path: DataSet High Resolution path
        :param ds_lr_path: DataSet Low Resolution path
        :param use_save: saving into another file format
        :param save_type: file format to save
        :param save_file_name: file name to save
        :param debug: debugging messages, default False
        """

        self.hr_height = hr_height
        self.hr_width = hr_width
        self.lr_height = lr_height
        self.lr_width = lr_width
        self.channel = channel
        self.hr_shape = (self.hr_height, self.hr_width, self.channel)
        self.lr_shape = (self.lr_height, self.lr_width, self.channel)

        self.n_patch = n_patch
        self.use_split = use_split
        self.split_rate = split_rate
        self.random_state = random_state
        self.num_threads = n_threads  # change this value to the fitted value for ur system

        """
        Expected ds_path : div2k/...
        Expected ds_name : X4
        """
        self.ds_path = ds_path
        self.ds_name = ds_name
        self.ds_hr_path = ds_hr_path
        self.ds_lr_path = ds_lr_path

        # Either a data-set root path, or both explicit HR & LR paths, is required.
        if not self.ds_path and not (self.ds_hr_path and self.ds_lr_path):
            raise AssertionError("[-] DataSet's path is required!")

        self.use_save = use_save
        self.save_type = save_type
        self.save_file_name = save_file_name
        self.debug = debug

        if self.use_save:
            if not self.save_file_name:
                raise AssertionError("[-] save-file/folder-name is required!")
        else:
            self.save_file_name = ""

        self.n_images = 800       # DIV2K train split size
        self.n_images_val = 100   # DIV2K validation split size

        self.use_img_scaling = use_img_scale

        if self.ds_path:  # like .h5 or .tfr # will be in the same folder
            self.ds_hr_path = self.ds_path + "/DIV2K_train_HR/"
            # NOTE(review): the LR path deliberately reuses the HR path here —
            # per the comment above both files live in the same folder; verify.
            self.ds_lr_path = self.ds_hr_path

        self.hr_images = DataSetLoader(path=self.ds_hr_path,
                                       size=self.hr_shape,
                                       use_save=self.use_save,
                                       name=self.save_type,
                                       save_file_name=self.save_file_name + "-hr.h5",
                                       use_image_scaling=self.use_img_scaling,
                                       image_scale='0,1',
                                       img_save_method=cv2.INTER_LINEAR).raw_data  # numpy arrays
        self.patch_hr_images = None

        self.lr_images = DataSetLoader(path=self.ds_lr_path,
                                       size=self.lr_shape,
                                       use_save=self.use_save,
                                       name=self.save_type,
                                       save_file_name=self.save_file_name + "-lr.h5",
                                       use_image_scaling=self.use_img_scaling,
                                       image_scale='0,1',
                                       img_save_method=cv2.INTER_CUBIC).raw_data  # numpy arrays
        self.patch_lr_images = None

        if self.n_patch > 0:
            patch_size = int(np.sqrt(self.n_patch))  # patches per side

            # Match the loaded images' dtype: with image scaling enabled (the
            # default) the pixels are floats in [0, 1], which the previous
            # hard-coded uint8 buffers silently truncated to all-zeros.
            self.patch_hr_images = np.zeros((self.n_images * self.n_patch,
                                             self.hr_height // patch_size, self.hr_width // patch_size, self.channel),
                                            dtype=self.hr_images.dtype)

            self.patch_lr_images = np.zeros((self.n_images * self.n_patch,
                                             self.lr_height // patch_size, self.lr_width // patch_size, self.channel),
                                            dtype=self.lr_images.dtype)

            for i in tqdm(range(self.n_images)):
                # Reshape each flattened row back to HWC before splitting.
                hr_patches = split(np.reshape(self.hr_images[i, :], self.hr_shape), self.n_patch)
                lr_patches = split(np.reshape(self.lr_images[i, :], self.lr_shape), self.n_patch)

                for n_ps in range(self.n_patch):
                    self.patch_hr_images[i * self.n_patch + n_ps] = hr_patches[n_ps]
                    self.patch_lr_images[i * self.n_patch + n_ps] = lr_patches[n_ps]

                if self.debug:
                    # Preview the first image's patches once, then disable debug.
                    # NOTE(review): the 4x4 subplot grid assumes n_patch == 16.
                    import matplotlib.pyplot as plt

                    fig = plt.figure()
                    for j in range(self.n_patch):
                        fig.add_subplot(4, 4, j + 1)
                        plt.imshow(self.patch_hr_images[j, :, :, :])
                    plt.show()

                    fig = plt.figure()
                    for j in range(self.n_patch):
                        fig.add_subplot(4, 4, j + 1)
                        plt.imshow(self.patch_lr_images[j, :, :, :])
                    plt.show()

                    self.debug = False


class DataIterator:
    """Mini-batch iterator over paired arrays, reshuffling them in unison
    whenever an epoch is exhausted (drop-last semantics)."""

    def __init__(self, x, y, batch_size):
        """
        :param x: inputs; axis 0 is the sample axis (any rank works).
        :param y: targets, aligned with x along axis 0.
        :param batch_size: samples per batch; must not exceed len(x).
        """
        self.x = x
        self.y = y
        self.batch_size = batch_size
        self.num_examples = x.shape[0]
        self.num_batches = self.num_examples // batch_size
        self.pointer = 0  # start index of the next batch

        assert (self.batch_size <= self.num_examples)

    def next_batch(self):
        """Return the next (x, y) batch; reshuffle and restart at epoch end."""
        start = self.pointer
        self.pointer += self.batch_size

        if self.pointer > self.num_examples:
            # Epoch exhausted: permute x & y with the same order so pairs
            # stay aligned, then restart from the beginning.
            perm = np.arange(self.num_examples)
            np.random.shuffle(perm)

            # Index only the first axis so arrays of any rank work
            # (the old `[perm, :, :, :]` form required exactly 4-D data).
            self.x = self.x[perm]
            self.y = self.y[perm]

            start = 0
            self.pointer = self.batch_size

        end = self.pointer
        return self.x[start:end], self.y[start:end]

    def iterate(self):
        """Yield `num_batches` consecutive batches (one epoch)."""
        for _ in range(self.num_batches):
            yield self.next_batch()