import os

import tensorflow as tf
from tensorflow.python.tools import freeze_graph
from tensorflow.tools.graph_transforms import TransformGraph

import config
import download
import loss


class MSINET:
    """The class representing the MSI-Net based on the VGG16 model. It
       implements a definition of the computational graph, as well as
       functions related to network training.
    """

    def __init__(self):
        self._output = None
        self._mapping = {}

        if config.PARAMS["device"] == "gpu":
            self._data_format = "channels_first"
            self._channel_axis = 1
            self._dims_axis = (2, 3)
        elif config.PARAMS["device"] == "cpu":
            self._data_format = "channels_last"
            self._channel_axis = 3
            self._dims_axis = (1, 2)
        else:
            raise ValueError("The device must be either 'gpu' or 'cpu'")

    def _encoder(self, images):
        """The encoder of the model consists of a pretrained VGG16 architecture
           with 13 convolutional layers. All dense layers are discarded and the
           last 3 layers are dilated at a rate of 2 to account for the omitted
           downsampling. Finally, the activations from 3 layers are combined.

        Args:
            images (tensor, float32): A 4D tensor that holds the RGB image
                                      batches used as input to the network.
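
        Note:
            With an input of spatial size H x W, the combined encoder output
            has spatial size H/8 x W/8 and 256 + 512 + 512 = 1280 channels,
            as only the first 3 pooling layers downsample the activations.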
        """

        # Subtracting the per-channel ImageNet mean matches the input
        # statistics that the pretrained VGG16 weights were trained on.
        imagenet_mean = tf.constant([103.939, 116.779, 123.68])
        imagenet_mean = tf.reshape(imagenet_mean, [1, 1, 1, 3])

        images -= imagenet_mean

        # On the GPU, activations are kept in the NCHW layout, so the NHWC
        # input batches are transposed accordingly.
        if self._data_format == "channels_first":
            images = tf.transpose(images, (0, 3, 1, 2))

        layer01 = tf.layers.conv2d(images, 64, 3,
                                   padding="same",
                                   activation=tf.nn.relu,
                                   data_format=self._data_format,
                                   name="conv1/conv1_1")

        layer02 = tf.layers.conv2d(layer01, 64, 3,
                                   padding="same",
                                   activation=tf.nn.relu,
                                   data_format=self._data_format,
                                   name="conv1/conv1_2")

        layer03 = tf.layers.max_pooling2d(layer02, 2, 2,
                                          data_format=self._data_format)

        layer04 = tf.layers.conv2d(layer03, 128, 3,
                                   padding="same",
                                   activation=tf.nn.relu,
                                   data_format=self._data_format,
                                   name="conv2/conv2_1")

        layer05 = tf.layers.conv2d(layer04, 128, 3,
                                   padding="same",
                                   activation=tf.nn.relu,
                                   data_format=self._data_format,
                                   name="conv2/conv2_2")

        layer06 = tf.layers.max_pooling2d(layer05, 2, 2,
                                          data_format=self._data_format)

        layer07 = tf.layers.conv2d(layer06, 256, 3,
                                   padding="same",
                                   activation=tf.nn.relu,
                                   data_format=self._data_format,
                                   name="conv3/conv3_1")

        layer08 = tf.layers.conv2d(layer07, 256, 3,
                                   padding="same",
                                   activation=tf.nn.relu,
                                   data_format=self._data_format,
                                   name="conv3/conv3_2")

        layer09 = tf.layers.conv2d(layer08, 256, 3,
                                   padding="same",
                                   activation=tf.nn.relu,
                                   data_format=self._data_format,
                                   name="conv3/conv3_3")

        layer10 = tf.layers.max_pooling2d(layer09, 2, 2,
                                          data_format=self._data_format)

        layer11 = tf.layers.conv2d(layer10, 512, 3,
                                   padding="same",
                                   activation=tf.nn.relu,
                                   data_format=self._data_format,
                                   name="conv4/conv4_1")

        layer12 = tf.layers.conv2d(layer11, 512, 3,
                                   padding="same",
                                   activation=tf.nn.relu,
                                   data_format=self._data_format,
                                   name="conv4/conv4_2")

        layer13 = tf.layers.conv2d(layer12, 512, 3,
                                   padding="same",
                                   activation=tf.nn.relu,
                                   data_format=self._data_format,
                                   name="conv4/conv4_3")

        # From here on, pooling uses a stride of 1, so the spatial
        # resolution stays at 1/8 of the input size.
        layer14 = tf.layers.max_pooling2d(layer13, 2, 1,
                                          padding="same",
                                          data_format=self._data_format)

        layer15 = tf.layers.conv2d(layer14, 512, 3,
                                   padding="same",
                                   activation=tf.nn.relu,
                                   dilation_rate=2,
                                   data_format=self._data_format,
                                   name="conv5/conv5_1")

        layer16 = tf.layers.conv2d(layer15, 512, 3,
                                   padding="same",
                                   activation=tf.nn.relu,
                                   dilation_rate=2,
                                   data_format=self._data_format,
                                   name="conv5/conv5_2")

        layer17 = tf.layers.conv2d(layer16, 512, 3,
                                   padding="same",
                                   activation=tf.nn.relu,
                                   dilation_rate=2,
                                   data_format=self._data_format,
                                   name="conv5/conv5_3")

        layer18 = tf.layers.max_pooling2d(layer17, 2, 1,
                                          padding="same",
                                          data_format=self._data_format)

        # The 3 pooled feature maps share the same spatial resolution and
        # can therefore be concatenated along the channel axis.
        encoder_output = tf.concat([layer10, layer14, layer18],
                                   axis=self._channel_axis)

        self._output = encoder_output

    def _aspp(self, features):
        """The ASPP module samples information at multiple spatial scales in
           parallel via convolutional layers with different dilation factors.
           The activations are then combined with global scene context and
           represented as a common tensor.

        Args:
            features (tensor, float32): A 4D tensor that holds the features
                                        produced by the encoder module.
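
        Note:
            Each of the 5 branches yields 256 feature channels; after
            concatenation, a final 1x1 convolution reduces the combined
            1280 channels back to 256.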
        """

        branch1 = tf.layers.conv2d(features, 256, 1,
                                   padding="same",
                                   activation=tf.nn.relu,
                                   data_format=self._data_format,
                                   name="aspp/conv1_1")

        branch2 = tf.layers.conv2d(features, 256, 3,
                                   padding="same",
                                   activation=tf.nn.relu,
                                   dilation_rate=4,
                                   data_format=self._data_format,
                                   name="aspp/conv1_2")

        branch3 = tf.layers.conv2d(features, 256, 3,
                                   padding="same",
                                   activation=tf.nn.relu,
                                   dilation_rate=8,
                                   data_format=self._data_format,
                                   name="aspp/conv1_3")

        branch4 = tf.layers.conv2d(features, 256, 3,
                                   padding="same",
                                   activation=tf.nn.relu,
                                   dilation_rate=12,
                                   data_format=self._data_format,
                                   name="aspp/conv1_4")

        # Global average pooling captures scene context as a 1x1 feature map.
        branch5 = tf.reduce_mean(features,
                                 axis=self._dims_axis,
                                 keepdims=True)

        branch5 = tf.layers.conv2d(branch5, 256, 1,
                                   padding="valid",
                                   activation=tf.nn.relu,
                                   data_format=self._data_format,
                                   name="aspp/conv1_5")

        shape = tf.shape(features)

        # The global context is resized back to the spatial resolution of
        # the other branches before concatenation.
        branch5 = self._upsample(branch5, shape, 1)

        context = tf.concat([branch1, branch2, branch3, branch4, branch5],
                            axis=self._channel_axis)

        aspp_output = tf.layers.conv2d(context, 256, 1,
                                       padding="same",
                                       activation=tf.nn.relu,
                                       data_format=self._data_format,
                                       name="aspp/conv2")
        self._output = aspp_output

    def _decoder(self, features):
        """The decoder model applies a series of 3 upsampling blocks that each
           performs bilinear upsampling followed by a 3x3 convolution to avoid
           checkerboard artifacts in the image space. Unlike all other layers,
           the output of the model is not modified by a ReLU.

        Args:
            features (tensor, float32): A 4D tensor that holds the features
                                        produced by the ASPP module.
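
        Note:
            The 3 upsampling blocks jointly restore the features from 1/8 of
            the input resolution to the full image size, and the final
            convolution reduces the output to a single channel.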
        """

        shape = tf.shape(features)

        layer1 = self._upsample(features, shape, 2)

        layer2 = tf.layers.conv2d(layer1, 128, 3,
                                  padding="same",
                                  activation=tf.nn.relu,
                                  data_format=self._data_format,
                                  name="decoder/conv1")

        shape = tf.shape(layer2)

        layer3 = self._upsample(layer2, shape, 2)

        layer4 = tf.layers.conv2d(layer3, 64, 3,
                                  padding="same",
                                  activation=tf.nn.relu,
                                  data_format=self._data_format,
                                  name="decoder/conv2")

        shape = tf.shape(layer4)

        layer5 = self._upsample(layer4, shape, 2)

        layer6 = tf.layers.conv2d(layer5, 32, 3,
                                  padding="same",
                                  activation=tf.nn.relu,
                                  data_format=self._data_format,
                                  name="decoder/conv3")

        decoder_output = tf.layers.conv2d(layer6, 1, 3,
                                          padding="same",
                                          data_format=self._data_format,
                                          name="decoder/conv4")

        if self._data_format == "channels_first":
            decoder_output = tf.transpose(decoder_output, (0, 2, 3, 1))

        self._output = decoder_output

    def _upsample(self, stack, shape, factor):
        """This function resizes the input to a desired shape via the
           bilinear upsampling method.

        Args:
            stack (tensor, float32): A 4D tensor with the function input.
            shape (tensor, int32): A 1D tensor with the reference shape.
            factor (scalar, int): An integer denoting the upsampling factor.

        Returns:
            tensor, float32: A 4D tensor that holds the activations after
                             bilinear upsampling of the input.
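
        Note:
            On the GPU, for instance, shape holds [N, C, H, W] and the
            output has spatial size (factor * H, factor * W).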
        """

        # tf.image.resize_bilinear expects tensors in the NHWC layout, so
        # channels_first activations are transposed before and after.
        if self._data_format == "channels_first":
            stack = tf.transpose(stack, (0, 2, 3, 1))

        stack = tf.image.resize_bilinear(stack, (shape[self._dims_axis[0]] * factor,
                                                 shape[self._dims_axis[1]] * factor))

        if self._data_format == "channels_first":
            stack = tf.transpose(stack, (0, 3, 1, 2))

        return stack

    def _normalize(self, maps, eps=1e-7):
        """This function normalizes the output values to a range
           between 0 and 1 per saliency map.
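           The operation applied per map is
           maps = (maps - min(maps)) / (eps + max(maps - min(maps))).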

        Args:
            maps (tensor, float32): A 4D tensor that holds the model output.
            eps (scalar, float, optional): A small factor to avoid numerical
                                           instabilities. Defaults to 1e-7.
        """

        # Shift each map to a minimum of 0 and divide it by its maximum
        # plus eps, such that all values lie within the range [0, 1].
        min_per_image = tf.reduce_min(maps, axis=(1, 2, 3), keepdims=True)
        maps -= min_per_image

        max_per_image = tf.reduce_max(maps, axis=(1, 2, 3), keepdims=True)
        maps = tf.divide(maps, eps + max_per_image, name="output")

        self._output = maps

    def _pretraining(self):
        """The first 26 variables of the model here are based on the VGG16
           network. Therefore, their names are matched to the ones of the
           pretrained VGG16 checkpoint for correct initialization.
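           For example, the variable "conv1/conv1_1/kernel:0" is mapped to
           the checkpoint key "conv1_1/weights".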
        """

        for var in tf.global_variables()[:26]:
            # Remove the first scope level from the variable name and rename
            # the tf.layers parameters to match the VGG16 checkpoint keys.
            key = var.name.split("/", 1)[1]
            key = key.replace("kernel:0", "weights")
            key = key.replace("bias:0", "biases")
            self._mapping[key] = var

    def forward(self, images):
        """Public method to forward RGB images through the whole network
           architecture and retrieve the resulting output.

        Args:
            images (tensor, float32): A 4D tensor that holds the values of the
                                      raw input images.

        Returns:
            tensor, float32: A 4D tensor that holds the values of the
                             predicted saliency maps.
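
        Example:
            A minimal sketch of building the graph (the placeholder shape is
            illustrative, not prescribed by the model):

            >>> images = tf.placeholder(tf.float32, (None, 240, 320, 3),
            ...                         name="input")
            >>> model = MSINET()
            >>> saliency = model.forward(images)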
        """

        self._encoder(images)
        self._aspp(self._output)
        self._decoder(self._output)
        self._normalize(self._output)

        return self._output

    def train(self, ground_truth, predicted_maps, learning_rate):
        """Public method to define the loss function and optimization
           algorithm for training the model.

        Args:
            ground_truth (tensor, float32): A 4D tensor with the ground truth.
            predicted_maps (tensor, float32): A 4D tensor with the predictions.
            learning_rate (scalar, float): Defines the learning rate.

        Returns:
            object: The optimization operation used to train the model.
            tensor, float32: A 0D tensor that holds the averaged error.
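
        Example:
            A minimal sketch (tensor names and the learning rate are
            illustrative):

            >>> optimizer, error = model.train(ground_truth,
            ...                                predicted_maps, 1e-5)
            >>> _, batch_error = sess.run([optimizer, error])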
        """

        error = loss.kld(ground_truth, predicted_maps)
        optimizer = tf.train.AdamOptimizer(learning_rate)
        optimizer = optimizer.minimize(error)

        return optimizer, error

    def save(self, saver, sess, dataset, path, device):
        """This saves a model checkpoint to disk and creates
           the folder if it doesn't exist yet.

        Args:
            saver (object): An object for saving the model.
            sess (object): The current TF training session.
            path (str): The path used for saving the model.
            device (str): Represents either "cpu" or "gpu".
        """

        os.makedirs(path, exist_ok=True)

        saver.save(sess, path + "model_%s_%s.ckpt" % (dataset, device),
                   write_meta_graph=False, write_state=False)

    def restore(self, sess, dataset, paths, device):
        """This function allows continued training from a prior checkpoint and
           training from scratch with the pretrained VGG16 weights. In case the
           dataset is either CAT2000 or MIT1003, a prior checkpoint based on
           the SALICON dataset is required.

        Args:
            sess (object): The current TF training session.
            dataset ([type]): The dataset used for training.
            paths (dict, str): A dictionary with all path elements.
            device (str): Represents either "cpu" or "gpu".

        Returns:
            object: A saver object for saving the model.
        """

        model_name = "model_%s_%s" % (dataset, device)
        salicon_name = "model_salicon_%s" % device
        vgg16_name = "vgg16_hybrid"

        ext1 = ".ckpt.data-00000-of-00001"
        ext2 = ".ckpt.index"

        saver = tf.train.Saver()

        if os.path.isfile(paths["latest"] + model_name + ext1) and \
           os.path.isfile(paths["latest"] + model_name + ext2):
            saver.restore(sess, paths["latest"] + model_name + ".ckpt")
        elif dataset in ("mit1003", "cat2000", "dutomron",
                         "pascals", "osie", "fiwi"):
            if os.path.isfile(paths["best"] + salicon_name + ext1) and \
               os.path.isfile(paths["best"] + salicon_name + ext2):
                saver.restore(sess, paths["best"] + salicon_name + ".ckpt")
            else:
                raise FileNotFoundError("Train model on SALICON first")
        else:
            if not (os.path.isfile(paths["weights"] + vgg16_name + ext1) and
                    os.path.isfile(paths["weights"] + vgg16_name + ext2)):
                download.download_pretrained_weights(paths["weights"],
                                                     "vgg16_hybrid")
            self._pretraining()

            loader = tf.train.Saver(self._mapping)
            loader.restore(sess, paths["weights"] + vgg16_name + ".ckpt")

        return saver

    def optimize(self, sess, dataset, path, device):
        """The best performing model is frozen, optimized for inference
           by removing unneeded training operations, and written to disk.

        Args:
            sess (object): The current TF training session.
            dataset (str): The dataset used for training.
            path (str): The path used for saving the model.
            device (str): Represents either "cpu" or "gpu".

        .. seealso:: https://bit.ly/2VBBdqQ and https://bit.ly/2W7YqBa
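
        .. note:: The frozen graph "model_<dataset>_<device>.pb" exposes the
                  input node "input" and the output node "output".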
        """

        model_name = "model_%s_%s" % (dataset, device)
        model_path = path + model_name

        tf.train.write_graph(sess.graph.as_graph_def(),
                             path, model_name + ".pbtxt")

        freeze_graph.freeze_graph(model_path + ".pbtxt", "", False,
                                  model_path + ".ckpt", "output",
                                  "save/restore_all", "save/Const:0",
                                  model_path + ".pb", True, "")

        os.remove(model_path + ".pbtxt")

        graph_def = tf.GraphDef()

        with tf.gfile.Open(model_path + ".pb", "rb") as file:
            graph_def.ParseFromString(file.read())

        # Transforms that remove Identity and unused nodes, merge duplicate
        # nodes, and fold constant expressions for faster inference.
        transforms = ["remove_nodes(op=Identity)",
                      "merge_duplicate_nodes",
                      "strip_unused_nodes",
                      "fold_constants(ignore_errors=true)"]

        optimized_graph_def = TransformGraph(graph_def,
                                             ["input"],
                                             ["output"],
                                             transforms)

        tf.train.write_graph(optimized_graph_def,
                             logdir=path,
                             as_text=False,
                             name=model_name + ".pb")