"""
"""
import os

from keras import backend as K
from keras.layers import Embedding

from keras_exp._mixin_common import mixedomatic

if K.backend() == 'tensorflow':
    import tensorflow as tf
    from tensorflow.contrib.tensorboard.plugins import projector
    from keras.callbacks import TensorBoard


__all__ = ('TensorBoardEmbedding', 'find_embedding_layers', )


def find_embedding_layers(layers):
    '''Recursively find embedding layers.

    Nested models (layers that themselves have a `layers` attribute) are
    searched recursively.

    :param layers: The Keras model layers. Typically obtained via
        model.layers.
    :type layers: list

    :returns: List of Embedding layers found.
    '''
    elayers = []
    for layer in layers:
        if isinstance(layer, Embedding):
            elayers.append(layer)

        slayers = getattr(layer, 'layers', [])
        elayers += find_embedding_layers(slayers)

    return elayers
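
# A minimal usage sketch of find_embedding_layers (the model below is
# hypothetical):
#
#     from keras.models import Sequential
#     model = Sequential([Embedding(input_dim=1000, output_dim=64,
#                                   input_length=10)])
#     elayers = find_embedding_layers(model.layers)  # -> [<Embedding ...>]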


class TensorBoardEmbeddingMixin(object):
    """Tensorboard mixin for Embeddings.

    This has to be mixed in with TensorBoard class or a derived TensoBoard cls.
    Must specify arguments as keywords.

    # Mixin Arguments
        embeddings_freq: frequency (in epochs) at which selected embedding
            layers will be saved.
        embeddings_layer_names: a list of names of layers to keep eye on. If
            None or empty list all the embedding layer will be watched.
        embeddings_metadata: a dictionary which maps layer name to a file name
            in which metadata for this embedding layer is saved. See the
            [details](https://www.tensorflow.org/how_tos/embedding_viz/#metadata_optional)
            about metadata files format. In case if the same metadata file is
            used for all embedding layers, string can be passed.
    """
    def __init__(self,
                 embeddings_freq=1,
                 embeddings_layer_names=None,
                 embeddings_metadata=None):
        self.embeddings_freq = embeddings_freq
        self.embeddings_layer_names = embeddings_layer_names
        # Avoid a mutable default argument; fall back to an empty mapping.
        self.embeddings_metadata = embeddings_metadata or {}

    def set_model(self, model):
        if self.embeddings_freq:
            # Saver checkpoints the embedding weights so the TensorBoard
            # projector can load them.
            self.saver = tf.train.Saver()

            embeddings_layer_names = self.embeddings_layer_names

            elayers = find_embedding_layers(model.layers)
            if not embeddings_layer_names:
                embeddings_layer_names = [layer.name for layer in elayers]

            embeddings = {layer.name: layer.weights[0] for layer in elayers
                          if layer.name in embeddings_layer_names}

            if isinstance(self.embeddings_metadata, str):
                # A single metadata file shared by all watched layers.
                embeddings_metadata = {layer_name: self.embeddings_metadata
                                       for layer_name in embeddings.keys()}
            else:
                embeddings_metadata = self.embeddings_metadata

            config = projector.ProjectorConfig()
            self.embeddings_logs = []

            for layer_name, tensor in embeddings.items():
                embedding = config.embeddings.add()
                embedding.tensor_name = tensor.name

                self.embeddings_logs.append(os.path.join(self.log_dir,
                                                         layer_name + '.ckpt'))

                if layer_name in embeddings_metadata:
                    embedding.metadata_path = embeddings_metadata[layer_name]

            projector.visualize_embeddings(self.writer, config)

    def on_epoch_end(self, epoch, logs=None):
        if self.embeddings_freq and self.embeddings_logs:
            if epoch % self.embeddings_freq == 0:
                for log in self.embeddings_logs:
                    self.saver.save(self.sess, log, epoch)


@mixedomatic()
class TensorBoardEmbedding(TensorBoardEmbeddingMixin, TensorBoard):
    """Tensorboard for Embeddings.

    Refer to classes TensorBoardEmbedding and TensorBoardEmbeddingMixin for
    arguments. Must specify arguments as keywords i.e. kwargs to __init__.
    """

    def set_model(self, model):
        TensorBoard.set_model(self, model)
        TensorBoardEmbeddingMixin.set_model(self, model)

    def on_epoch_end(self, epoch, logs=None):
        TensorBoardEmbeddingMixin.on_epoch_end(self, epoch, logs)
        TensorBoard.on_epoch_end(self, epoch, logs)
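

# Usage sketch (assumes a compiled Keras model containing Embedding layers
# and the TensorFlow backend; 'model', the training data, and the log
# directory below are hypothetical):
#
#     tb_embed = TensorBoardEmbedding(
#         log_dir='./logs',
#         embeddings_freq=1,
#         embeddings_layer_names=None,        # watch all embedding layers
#         embeddings_metadata='metadata.tsv')
#     model.fit(x_train, y_train, epochs=5, callbacks=[tb_embed])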