"""
This script generates embedding visualizations for features produced by the apply_model script,
either as a 2D/3D scatter-plot image or as a TensorBoard projector log.
"""
import argparse
import inspect
import os
import pickle
import sys
from multiprocessing import Pool

import matplotlib as mpl
import torch

# To facilitate plotting on a headless server
mpl.use('Agg')
import matplotlib.pyplot as plt
from tensorboardX import SummaryWriter
import numpy as np
from sklearn.manifold import TSNE, Isomap, MDS
from sklearn.decomposition import PCA
from sklearn.preprocessing import LabelEncoder
from util.misc import load_numpy_image, save_numpy_image
from PIL import Image
########################################################################################################################
def tsne(features, n_components=2):
    """
    Project feature vectors into a low-dimensional space with t-SNE.

    Parameters
    ----------
    features: numpy.ndarray
        input feature vectors, presumably (n_samples, n_features) as scikit-learn expects.
    n_components: int
        dimensionality of the embedded space (2 for x,y; 3 for x,y,z)

    Returns
    -------
    embedding: numpy.ndarray
        embedded coordinates, one row per input feature vector
    """
    reducer = TSNE(n_components=n_components)
    return reducer.fit_transform(features)


def isomap(features, n_components=2):
    """
    Project feature vectors into a low-dimensional space with Isomap.

    Parameters
    ----------
    features: numpy.ndarray
        input feature vectors, presumably (n_samples, n_features) as scikit-learn expects.
    n_components: int
        dimensionality of the embedded space (2 for x,y; 3 for x,y,z)

    Returns
    -------
    embedding: numpy.ndarray
        embedded coordinates, one row per input feature vector
    """
    # n_jobs=-1 lets scikit-learn use all available processors.
    model = Isomap(n_jobs=-1, n_components=n_components)
    return model.fit_transform(features)

def mds(features, n_components=2):
    """
    Project feature vectors into a low-dimensional space with multidimensional scaling (MDS).

    Parameters
    ----------
    features: numpy.ndarray
        input feature vectors, presumably (n_samples, n_features) as scikit-learn expects.
    n_components: int
        dimensionality of the embedded space (2 for x,y; 3 for x,y,z)

    Returns
    -------
    embedding: numpy.ndarray
        embedded coordinates, one row per input feature vector
    """
    # n_jobs=-1 lets scikit-learn use all available processors.
    model = MDS(n_jobs=-1, n_components=n_components)
    return model.fit_transform(features)

def pca(features, n_components=2):
    """
    Project feature vectors onto their leading principal components (PCA).

    Parameters
    ----------
    features: numpy.ndarray
        input feature vectors, presumably (n_samples, n_features) as scikit-learn expects.
    n_components: int
        number of principal components to keep (2 for x,y; 3 for x,y,z)

    Returns
    -------
    embedding: numpy.ndarray
        projected coordinates, one row per input feature vector
    """
    return PCA(n_components=n_components).fit_transform(features)

########################################################################################################################
def _resolve_styles(style_names):
    """
    Return the subset of ``style_names`` that this matplotlib install provides.

    matplotlib 3.6 renamed the bundled seaborn styles to ``seaborn-v0_8-*``, and later
    releases dropped the old names entirely, so ``plt.style.use('seaborn-white')`` raises
    there. Each legacy ``seaborn-*`` name is mapped to its ``seaborn-v0_8-*`` spelling when
    only that one exists. Names available under neither spelling are skipped, so plotting
    still works with default styling.

    Parameters
    ----------
    style_names: list of str
        requested matplotlib style names

    Returns
    -------
    resolved: list of str
        style names accepted by ``plt.style.context``
    """
    available = set(plt.style.available)
    resolved = []
    for name in style_names:
        if name in available:
            resolved.append(name)
            continue
        if name.startswith('seaborn-'):
            renamed = 'seaborn-v0_8-' + name[len('seaborn-'):]
            if renamed in available:
                resolved.append(renamed)
    return resolved


def _make_embedding(features, labels, embedding, three_d=False):
    """
    Generate an embedding image using features from a model.

    Adapted from https://indico.io/blog/visualizing-with-t-sne/
    Parameters
    ----------
    features: numpy.ndarray
        contains the feature array generated by the apply_model runner class
    labels:  numpy.ndarray
        contains labels for corresponding feature vectors; any values accepted by
        sklearn's LabelEncoder (they are re-encoded to 0..K-1 for colouring)
    embedding: str
        name of a reducer function in this module (e.g. 'tsne', 'isomap', 'mds', 'pca')
    three_d: bool
        specify whether to generate 2d or 3d visualization

    Returns
    -------
    data: numpy.ndarray
        uint8 RGB image of the plotted visualization, shape (height, width, 3)

    Raises
    ------
    AttributeError
        if ``embedding`` does not name a function in this module
    """
    # Resolve the reducer up front so a bad name fails before any figure is created.
    reducer = getattr(sys.modules[__name__], embedding)
    styles = _resolve_styles(['seaborn-white', 'seaborn-paper'])

    # Scope the style and font to this plot instead of mutating global rcParams.
    with plt.style.context(styles), mpl.rc_context({'font.family': 'Times New Roman'}):
        fig = plt.figure(figsize=(8, 8))
        try:
            # Map arbitrary label values (strings, ints, ...) to 0..K-1 for colour mapping.
            encoded_labels = LabelEncoder().fit_transform(labels)
            # One discrete colour per class; pyplot.get_cmap survives the removal of cm.get_cmap.
            cmap = plt.get_cmap('jet', len(np.unique(encoded_labels)))

            if three_d:
                # Registers the '3d' projection on matplotlib < 3.2; a harmless no-op afterwards.
                from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
                ax = fig.add_subplot(111, projection='3d')
                points = reducer(features, n_components=3)
                ax.scatter3D(points[:, 0], points[:, 1], points[:, 2], c=encoded_labels, cmap=cmap)
            else:
                ax = fig.add_subplot(111)
                points = reducer(features)
                ax.scatter(points[:, 0], points[:, 1], c=encoded_labels, cmap=cmap)

            # Render off-screen (Agg) and copy the RGB channels out of the RGBA buffer;
            # tostring_rgb/np.fromstring are deprecated or removed in current releases.
            fig.canvas.draw()
            rgba = np.asarray(fig.canvas.buffer_rgba())
            data = rgba[:, :, :3].copy()
        finally:
            # Close this specific figure even if the reducer or rendering fails.
            plt.close(fig)
    return data


def _load_thumbnail(path, size=(16, 16)):
    """
    Return a thumbnail version of any image.

    Parameters
    ----------
    path: str
        path to an image
    size: tuple of int
        maximum (width, height) of the thumbnail; defaults to (16, 16)

    Returns
    -------
    img: numpy.ndarray
        downscaled image pixels. ``PIL.Image.thumbnail`` preserves the aspect ratio,
        so each side is at most the requested size rather than exactly equal to it.
    """
    # Image.thumbnail resizes in place and returns None, so it must not be chained into
    # the assignment (doing so handed None to np.array). The context manager closes the file.
    with Image.open(path) as img:
        img.thumbnail(size)
        return np.array(img)

def _make_folder_if_not_exists(path):
    """
    Create ``path`` (including missing parents) if it does not already exist.

    Parameters
    ----------
    path: str
        directory to create. An empty string (e.g. ``os.path.dirname('out.png')``)
        or None means the current directory and is a no-op.
    """
    if not path:
        return
    # exist_ok avoids the race between an existence check and makedirs.
    os.makedirs(path, exist_ok=True)

def _main(args):
    """
    Main routine of script to generate embeddings.

    Loads the pickled results and either writes a TensorBoard projector log
    (``--tensorboard``) or saves a static scatter-plot image.

    Parameters
    ----------
    args : argparse.Namespace
        contains all arguments parsed from input. Reads ``results_file``, ``output``
        (``output_file`` is accepted as a fallback), ``tensorboard``, ``embedding``
        and ``three_d``.

    Returns
    -------
    None

    """
    # NOTE(review): pickle.load can execute arbitrary code; only load results files you trust.
    with open(args.results_file, 'rb') as f:
        results = pickle.load(f)

    # The results pickle is unpacked as a 4-tuple; only features and labels are used here.
    features, preds, labels, filenames = results

    # The CLI defines ``--output``; reading ``args.output_file`` raised AttributeError.
    # Keep ``output_file`` as a fallback for programmatic callers that set it explicitly.
    output = getattr(args, 'output', None) or args.output_file

    if args.tensorboard:
        # TensorBoard writes a log directory, so drop a trailing image file name.
        if output.endswith('.png'):
            output_loc = os.path.dirname(output) or '.'
        else:
            output_loc = output
        writer = SummaryWriter(log_dir=output_loc)
        try:
            # TODO: pass thumbnails via label_img (see _load_thumbnail) once they can be stacked.
            writer.add_embedding(torch.from_numpy(np.asarray(features)), metadata=labels,
                                 label_img=None)
        finally:
            # Release the writer so pending event data is flushed to disk.
            writer.close()
        return

    output_dir = os.path.dirname(output)
    if output_dir:  # a bare file name lives in the current directory; nothing to create
        _make_folder_if_not_exists(output_dir)
    viz_img = _make_embedding(features=features, labels=labels, embedding=args.embedding, three_d=args.three_d)
    save_numpy_image(output, viz_img)


if __name__ == "__main__":
    # Embedding options: only the public reducer functions defined in this module
    # (currently tsne, isomap, mds and pca). Filtering on __module__ and the underscore
    # prefix keeps private helpers such as _main and imported utilities such as
    # save_numpy_image out of the CLI choices.
    this_module = sys.modules[__name__]
    embedding_options = [
        name
        for name, func in inspect.getmembers(this_module, inspect.isfunction)
        if not name.startswith('_') and func.__module__ == __name__
    ]

    parser = argparse.ArgumentParser()

    # Required: without it the script would fail later on open(None).
    parser.add_argument('--results-file',
                        type=str,
                        required=True,
                        help='path to a results pickle file')

    # Defaults to tsne so image mode works without an explicit choice
    # (previously None, which crashed in _make_embedding).
    parser.add_argument('--embedding',
                        help='which embedding to use for the features',
                        choices=embedding_options,
                        default='tsne',
                        type=str)
    parser.add_argument('--output',
                        type=str,
                        default='./output.png',
                        help='path to generate output image')
    parser.add_argument('--3d',
                        dest='three_d',
                        action='store_true',
                        default=False,
                        help='enable 3d plots')
    parser.add_argument('--tensorboard',
                        action='store_true',
                        default=False,
                        help='store embeddings to tensorboard')

    args = parser.parse_args()

    _main(args)