# coding=utf-8
from __future__ import absolute_import, division, print_function
from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, TensorBoard
from matplotlib import pyplot as plt

import numpy as np
import os

from src.data import utils as data_utils, datasets
from src.models import select_model
from src.models.objectives import w_categorical_crossentropy


class Experiment(object):
    def __init__(self, data, experiment, model, **kwargs):
        # set dataset attributes
        self.data_config = data
        self.DatasetClass = getattr(datasets, data['dataset_name'])

        # set basic model attributes
        self.model_config = model
        self.model_name = model['name']

        # set basic experiment attributes/directories
        self.experiment_config = experiment
        experiment_dir = os.path.join(
            self.experiment_config['root_dir'],
            self.data_config['dataset_name'],
            self.model_name
        )
        log_dir = os.path.join(experiment_dir, 'logs')
        checkpoint_dir = os.path.join(experiment_dir, 'weights')
        data_utils.ensure_dir(log_dir)
        data_utils.ensure_dir(checkpoint_dir)
        self.experiment_dir = experiment_dir
        self.log_dir = log_dir
        self.checkpoint_dir = checkpoint_dir

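        # Filename template expanded by ModelCheckpoint at save time,
        # e.g. weights.<dataset>.<model>.02-2.59.h5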
        checkpoint_filename = 'weights.{}.{}.{}-{}.h5'.format(
            self.data_config['dataset_name'],
            self.model_name,
            '{epoch:02d}',
            '{val_loss:.2f}'
        )
        self.checkpoint_file = os.path.join(checkpoint_dir, checkpoint_filename)

    def callbacks(self):
        """
        :return:
        """
        # TODO: Add ReduceLROnPlateau callback
        cbs = []

        tb = TensorBoard(log_dir=self.log_dir,
                         write_graph=True,
                         write_images=True)
        cbs.append(tb)

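        # Keep a separate copy of the best model seen so far; with
        # save_best_only, ModelCheckpoint monitors val_loss by default.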
        best_model_filename = self.model_name + '_best.h5'
        best_model = os.path.join(self.checkpoint_dir, best_model_filename)
        save_best = ModelCheckpoint(best_model, save_best_only=True)
        cbs.append(save_best)

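        # Also write a checkpoint after every epoch, tagged with the epoch
        # number and val_loss via the template built in __init__.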
        checkpoints = ModelCheckpoint(filepath=self.checkpoint_file, verbose=1)
        cbs.append(checkpoints)

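        # Shrink the learning rate (by the default factor of 0.1) after one
        # epoch without improvement in val_loss.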
        reduce_lr = ReduceLROnPlateau(patience=1, verbose=1)
        cbs.append(reduce_lr)
        return cbs

    def dataset(self, data_split='train'):
        """

        :param data_split: 'train', 'val' or 'test'
        :return:
        """
        valid_data_splits = ['train', 'val', 'test']
        if data_split not in valid_data_splits:
            errmsg = 'Invalid data split: {} instead of {}'.format(
                data_split,
                '/'.join(valid_data_splits)
            )
            raise ValueError(errmsg)

        # Copy before merging so that split-specific settings do not leak
        # into the shared data config on subsequent calls.
        kwargs = dict(self.data_config)
        supplementary_args = kwargs[data_split]
        kwargs.update(supplementary_args)

        return self.DatasetClass(**kwargs)

    def model(self):
        model = select_model(model_name=self.model_name)
        # Copy the config so derived settings (nc, loss) do not leak back
        # into the shared model config.
        kwargs = dict(self.model_config)
        kwargs['nc'] = self.DatasetClass.num_classes()
        weights = self.DatasetClass.class_weights()
        # `weights` is only consumed by the class-weighted loss; swap in the
        # commented entry below to train with it.
        kwargs['loss'] = [
            'categorical_crossentropy',
            # w_categorical_crossentropy(weights=weights)
        ]
        autoencoder, _ = model.build(**kwargs)
        if self.model_config['print_summary']:
            autoencoder.summary()
        try:
            h5file = self.model_config['h5file']
        except KeyError:
            # No checkpoint configured: initialise via transferred weights.
            autoencoder = model.transfer_weights(autoencoder)
        else:
            print('Loading weights from file {}'.format(h5file))
            autoencoder.load_weights(h5file)
        return autoencoder

    def run(self):
        dataset_name = self.data_config['dataset_name']
        print('Preparing to train on {} data...'.format(dataset_name))
        train_dataset = self.dataset(data_split='train')
        val_dataset = self.dataset(data_split='val')
        model = self.model()

        print('Preparing to start training')
        model.fit_generator(
            generator=train_dataset.flow(),
            steps_per_epoch=train_dataset.steps,
            epochs=self.experiment_config['epochs'],
            verbose=1,
            callbacks=self.callbacks(),
            validation_data=val_dataset.flow(),
            validation_steps=val_dataset.steps,
            initial_epoch=self.experiment_config['completed_epochs'],
        )


class InferenceExperiment(Experiment):

    def model(self):
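        # NOTE: inference is pinned to the enet_unpooling checkpoint below,
        # regardless of the model name in the experiment config.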
        model_name = 'enet_unpooling'
        # model_name = 'enet_unpooling_weights_simple_setup'
        # model_name = 'enet_unpooling_no_weights'
        dataset_name = self.data_config['dataset_name']
        root_dir = 'experiments'
        pw = os.path.join(
            root_dir, dataset_name,
            model_name,
            'weights',
            # 'weights.enet_unpooling.02-2.59.h5'
            '{}_best.h5'.format(model_name)
        )

        nc = getattr(datasets, dataset_name).num_classes()
        self.model_config['nc'] = nc

        autoencoder = select_model(model_name=model_name)
        segmenter, model_name = autoencoder.build(**self.model_config)
        segmenter.load_weights(pw)
        return segmenter

    def run(self):
        model = self.model()
        dataset = self.dataset()
        for image_batch, target_batch in dataset.flow():
            image_batch = image_batch['image']
            target_batch = target_batch['output']
            for image, target in zip(image_batch, target_batch):
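                # NOTE: the 512x512 output shape is hard-coded to the dataset
                # resolution assumed here.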
                output = model.predict(np.expand_dims(image, axis=0))[0]
                output = np.reshape(np.argmax(output, axis=-1), newshape=(512, 512))

                target = np.reshape(np.argmax(target, axis=-1), newshape=(512, 512))

                plt.rcParams["figure.figsize"] = [4 * 3, 4]

                fig = plt.figure()

                subplot1 = fig.add_subplot(131)
                subplot1.imshow(image.astype(np.uint8))
                subplot1.set_title('rgb image')
                subplot1.axis('off')

                subplot2 = fig.add_subplot(132)
                subplot2.imshow(output, cmap='gray')
                subplot2.set_title('Prediction')
                subplot2.axis('off')

                subplot3 = fig.add_subplot(133)
                subplot3.imshow(target, cmap='gray')
                subplot3.set_title('Targets')
                subplot3.axis('off')

                fig.tight_layout()
                plt.show()


class CaptioningExperiment(Experiment):
    def __init__(self, data, model, experiment, **kwargs):
        # Pass by keyword: the base class expects (data, experiment, model).
        super(CaptioningExperiment, self).__init__(
            data=data,
            experiment=experiment,
            model=model,
            **kwargs
        )
        # +2 makes room for the start- and end-of-sequence tokens.
        max_seq_len = self.data_config['max_caption_length'] + 2
        self.model_config['max_token_length'] = max_seq_len

    def run(self):
        dataset_name = self.data_config['dataset_name']
        print('Preparing to train on {} data...'.format(dataset_name))
        train_dataset = self.dataset(data_split='train')
        val_dataset = self.dataset(data_split='val')
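        # Mirrors the base run(), except the vocabulary size is only known
        # once the training dataset has been built.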
        self.model_config['vocab_size'] = train_dataset.vocab.size()
        model = self.model()
        model.fit_generator(
            generator=train_dataset.flow(),
            steps_per_epoch=train_dataset.steps,
            epochs=self.experiment_config['epochs'],
            verbose=1,
            callbacks=self.callbacks(),
            validation_data=val_dataset.flow(),
            validation_steps=val_dataset.steps,
            initial_epoch=self.experiment_config['completed_epochs'],
        )


class DryDatasetExperiment(Experiment):
    def __init__(self, data, model, experiment, **kwargs):
        super(DryDatasetExperiment, self).__init__(
            data=data,
            experiment=experiment,
            model=model,
            **kwargs
        )

    def split_label_channels(self, label):
        """Return per-class binary masks, keyed by channel index, for every
        non-empty channel of a one-hot label volume."""
        binary_masks = {}
        for i in range(label.shape[-1]):
            # Copy the channel so binarising it cannot mutate `label` in place.
            binary_mask = label[..., i].copy()
            if not np.any(binary_mask > 0):
                continue
            binary_mask[binary_mask > 0] = 1
            binary_masks[i] = binary_mask.astype(np.uint8)
        return binary_masks

    def run(self):
        import sys

        np.random.seed(1337)  # for reproducibility

        dataset = self.dataset(data_split='val')

        for idx, item in enumerate(dataset.flow()):
            img, lbl = item[0]['image'].astype(np.uint8), item[1]['output']
            batch_size = img.shape[0]
            h = img.shape[1]
            w = img.shape[2]
            nc = lbl.shape[-1]
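            # Restore the spatial layout, in case labels arrive flattened
            # to (batch, h * w, nc).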
            lbl = np.reshape(lbl, (batch_size, h, w, nc))
            # batch_size = dataset.config.batch_size
            for batch_index in range(batch_size):
                binary_masks = self.split_label_channels(lbl[batch_index, ...])
                img_item = img[batch_index, ...]
                for class_idx, binary_mask in binary_masks.items():
                    # class_name = dataset.CATEGORIES[dataset.IDS[class_idx]]
                    class_name = dataset.CATEGORIES[class_idx]

                    plt.rcParams["figure.figsize"] = [4 * 3, 4]

                    fig = plt.figure()

                    subplot1 = fig.add_subplot(131)
                    subplot1.imshow(img_item)
                    subplot1.set_title('rgb image')
                    subplot1.axis('off')

                    subplot2 = fig.add_subplot(132)
                    subplot2.imshow(binary_mask, cmap='gray')
                    subplot2.set_title('{} binary mask'.format(class_name))
                    subplot2.axis('off')

                    subplot3 = fig.add_subplot(133)
                    masked = np.array(img_item)
                    masked[binary_mask == 0] = 0
                    subplot3.imshow(masked)
                    subplot3.set_title('{} label'.format(class_name))
                    subplot3.axis('off')

                    fig.tight_layout()
                    plt.show()
                item_idx = batch_size * idx + batch_index + 1
                print('Processed {} items: ({})'.format(item_idx, type(item)),
                      end='\r')
            sys.stdout.flush()


class DryDatasetCaptioningExperiment(CaptioningExperiment):
    def __init__(self, data, model, experiment, **kwargs):
        super(DryDatasetCaptioningExperiment, self).__init__(
            data=data,
            experiment=experiment,
            model=model,
            **kwargs
        )

    def run(self):
        dataset = self.dataset(data_split='val')
        for item in dataset.flow():
            for item_idx in range(dataset.config.batch_size):
                text = [dataset.vocab.decode(idx)
                        for idx in np.argmax(item[0]['text'][item_idx], axis=-1)]
                output = [dataset.vocab.decode(idx)
                          for idx in np.argmax(item[1]['output'][item_idx], axis=-1)]
                print(' '.join(text))
                print(' '.join(output))


class OverfittingExperiment(Experiment):
    def __init__(self, data, model, experiment, **kwargs):
        data['sample_size'] = 50
        experiment['epochs'] = 2
        super(OverfittingExperiment, self).__init__(
            data=data,
            experiment=experiment,
            model=model,
            **kwargs
        )

    def run(self):
        dataset = self.dataset(data_split='train')
        model = self.model()
        model.fit_generator(
            generator=dataset.flow(),
            steps_per_epoch=dataset.steps,
            epochs=self.experiment_config['epochs'],
            verbose=1,
            # validation_data=dataset.flow(),
            # validation_steps=dataset.steps,
            initial_epoch=self.experiment_config['completed_epochs'],
        )
        for inputs, outputs in dataset.flow(single_pass=True):
            predictions = model.predict(inputs)
            for prediction, output in zip(predictions, outputs['output']):
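                # Placeholder: inspect each prediction against its target here.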
                pass
        print('End of overfitting experiment')


class SemanticSegmentationExperiment(Experiment):
    def model(self):
        # Identical to the base implementation, but reports completion.
        autoencoder = super(SemanticSegmentationExperiment, self).model()
        print('Done loading {} model!'.format(self.model_name))
        return autoencoder
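

if __name__ == '__main__':
    # Minimal usage sketch. The config keys mirror what the classes above
    # read; the concrete values (dataset name, directories, epoch counts)
    # are hypothetical placeholders, not project defaults.
    example_data = {
        'dataset_name': 'CamVid',  # hypothetical; must name a class in src.data.datasets
        'train': {},
        'val': {},
        'test': {},
    }
    example_experiment = {
        'root_dir': 'experiments',
        'epochs': 10,
        'completed_epochs': 0,
    }
    example_model = {
        'name': 'enet_unpooling',
        'print_summary': True,
    }
    Experiment(
        data=example_data,
        experiment=example_experiment,
        model=example_model,
    ).run()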