import sys
sys.path.append("..")
import numpy as np
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.metrics import silhouette_score
import matplotlib.pyplot as plt
import matplotlib.cm as cm
plt.switch_backend('agg')
import torch
from dataset import MixamoDatasetForFull
from common import config
from model import get_autoencoder
import cv2
import time
import argparse
import os


def tsne_on_pca(arr, is_PCA=True):
    """
    visualize through t-sne on pca reduced data
    :param arr: (nr_examples, nr_features)
    :return:
    """
    if is_PCA:
        pca_50 = PCA(n_components=50)
        arr = pca_50.fit_transform(arr)
    tsne_2 = TSNE(n_components=2)
    res = tsne_2.fit_transform(arr)
    return res


def cluster_body(net, cluster_data, device, save_path):
    data, characters = cluster_data[0], cluster_data[2]
    data = data[:, :, 0, :, :]
    # data = data.reshape(-1, data.shape[2], data.shape[3], data.shape[4])

    nr_mv, nr_char = data.shape[0], data.shape[1]
    labels = np.arange(0, nr_char).reshape(1, -1)
    labels = np.tile(labels, (nr_mv, 1)).reshape(-1)
    
    if hasattr(net, 'static_encoder'):
        features = net.static_encoder(data.contiguous().view(-1, data.shape[2], data.shape[3])[:, :-2, :].to(device))
    else:
        features = net.body_encoder(data.contiguous().view(-1, data.shape[2], data.shape[3])[:, :-2, :].to(device))
    features = features.detach().cpu().numpy().reshape(features.shape[0], -1)

    features_2d = tsne_on_pca(features, is_PCA=False)
    features_2d = features_2d.reshape(nr_mv, nr_char, -1)

    plt.figure(figsize=(7, 4))
    colors = cm.rainbow(np.linspace(0, 1, nr_char))
    for i in range(nr_char):
        x = features_2d[:, i, 0]
        y = features_2d[:, i, 1]
        plt.scatter(x, y, c=colors[i], label=characters[i])

    plt.legend(bbox_to_anchor=(1.04, 1), borderaxespad=0)
    plt.tight_layout(rect=[0,0,0.75,1])
    plt.savefig(save_path)


def cluster_view(net, cluster_data, device, save_path):
    data, views = cluster_data[0], cluster_data[3]
    idx = np.random.randint(data.shape[1] - 1)  # np.linspace(0, data.shape[1] - 1, 4, dtype=int).tolist()
    data = data[:, idx, :, :, :]

    nr_mc, nr_view = data.shape[0], data.shape[1]
    labels = np.arange(0, nr_view).reshape(1, -1)
    labels = np.tile(labels, (nr_mc, 1)).reshape(-1)
    
    if hasattr(net, 'static_encoder'):
        features = net.static_encoder(data.contiguous().view(-1, data.shape[2], data.shape[3])[:, :-2, :].to(device))
    else:
        features = net.view_encoder(data.contiguous().view(-1, data.shape[2], data.shape[3])[:, :-2, :].to(device))
    features = features.detach().cpu().numpy().reshape(features.shape[0], -1)

    features_2d = tsne_on_pca(features, is_PCA=False)
    features_2d = features_2d.reshape(nr_mc, nr_view, -1)

    plt.figure(figsize=(7, 4))
    colors = cm.rainbow(np.linspace(0, 1, nr_view))
    for i in range(nr_view):
        x = features_2d[:, i, 0]
        y = features_2d[:, i, 1]
        plt.scatter(x, y, c=colors[i], label=views[i])

    plt.legend(bbox_to_anchor=(1.04, 1), borderaxespad=0)
    plt.tight_layout(rect=[0, 0, 0.75, 1])
    plt.savefig(save_path)


def cluster_motion(net, cluster_data, device, save_path, nr_anims=15, mode='both'):
    data, animations = cluster_data[0], cluster_data[1]
    idx = np.linspace(0, data.shape[0] - 1, nr_anims, dtype=int).tolist()
    data = data[idx]
    animations = animations[idx]
    if mode == 'body':
        data = data[:, :, 0, :, :].reshape(nr_anims, -1, data.shape[3], data.shape[4])
    elif mode == 'view':
        data = data[:, 3, :, :, :].reshape(nr_anims, -1, data.shape[3], data.shape[4])
    else:
        data = data[:, :4, ::2, :, :].reshape(nr_anims, -1, data.shape[3], data.shape[4])

    nr_anims, nr_cv = data.shape[:2]
    labels = np.arange(0, nr_anims).reshape(-1, 1)
    labels = np.tile(labels, (1, nr_cv)).reshape(-1)
    
    features = net.mot_encoder(data.contiguous().view(-1, data.shape[2], data.shape[3]).to(device))
    features = features.detach().cpu().numpy().reshape(features.shape[0], -1)

    features_2d = tsne_on_pca(features)
    features_2d = features_2d.reshape(nr_anims, nr_cv, -1)
    if features_2d.shape[1] < 5:
        features_2d = np.tile(features_2d, (1, 2, 1))

    plt.figure(figsize=(8, 4))
    colors = cm.rainbow(np.linspace(0, 1, nr_anims))
    for i in range(nr_anims):
        x = features_2d[i, :, 0]
        y = features_2d[i, :, 1]
        plt.scatter(x, y, c=colors[i], label=animations[i])

    plt.legend(bbox_to_anchor=(1.04, 1), borderaxespad=0)
    plt.tight_layout(rect=[0,0,0.8,1])
    plt.savefig(save_path)


def test():
    parser = argparse.ArgumentParser()
    parser.add_argument('-n', '--name', type=str, choices=['skeleton', 'view', 'full'], required=True,
                        help='which structure to use.')
    parser.add_argument('-p', '--model_path', type=str, default="model/pretrained_view.pth")
    parser.add_argument('--phase', type=str, default="test", choices=['train', 'test'])
    parser.add_argument('-g', '--gpu_ids', type=int, default=0, required=False, help="specify gpu ids")
    args = parser.parse_args()

    # set config
    config.initialize(args)
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_ids)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # load trained model
    net = get_autoencoder(config)
    net.load_state_dict(torch.load(args.model_path))
    net.to(config.device)
    net.eval()
    
    # get dataset
    train_ds = MixamoDatasetForFull(args.phase, config)
    cluster_data = train_ds.get_cluster_data()

    # score, img = cluster_body(net, cluster_data, device, './cluster_body.png')
    if args.name == 'view':
        cluster_view(net, cluster_data, device, './cluster_view.png')
        cluster_motion(net, cluster_data, device, './cluster_motion.png')
    elif args.name == 'skeleton':
        cluster_body(net, cluster_data, device, './cluster_body.png')
        cluster_motion(net, cluster_data, device, './cluster_motion.png', mode='body')
    else: 
        cluster_motion(net, cluster_data, device, './cluster_motion.png')


if __name__ == '__main__':
    test()