#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Han Xiao <artex.xh@gmail.com> <https://hanxiao.github.io>

# NOTE: First install bert-as-service via
# $
# $ pip install bert-serving-server
# $ pip install bert-serving-client
# $

# visualizing a 12-layer BERT

import time
from collections import namedtuple

import numpy as np
import pandas as pd
# from MulticoreTSNE import MulticoreTSNE as TSNE
from bert_serving.client import BertClient
from bert_serving.server import BertServer
from bert_serving.server.helper import get_args_parser
from matplotlib import pyplot as plt
from matplotlib.colors import ListedColormap
from sklearn.decomposition import PCA


#=========================== dump bert vectors ===========================
data = pd.read_csv('/corpus/uci-news-aggregator.csv', usecols=['TITLE', 'CATEGORY'])

# Build a class-balanced, shuffled subset of headlines (adapted from a Kaggle kernel) ->
# Maximum number of headlines kept *per category* (despite the name, this is
# not a count of categories).
num_of_categories = 5000
# Integer class id for each kept category code; these ids are what the plot
# colours points by.
category_to_label = {'e': 0, 'b': 1, 't': 2, 'm': 3}

shuffled = data.reindex(np.random.permutation(data.index))
# Up to num_of_categories rows of each kept category, taken from the shuffled frame.
e, b, t, m = [shuffled[shuffled['CATEGORY'] == code][:num_of_categories]
              for code in category_to_label]
concated = pd.concat([e, b, t, m], ignore_index=True)
# Shuffle again so the categories are interleaved instead of stored in blocks.
concated = concated.reindex(np.random.permutation(concated.index))
# Integer-encode the category codes (label encoding, not one-hot). Every row
# came from one of the four codes above, so no value maps to NaN.
concated['LABEL'] = concated['CATEGORY'].map(category_to_label)

subset_text = list(concated['TITLE'].values)
subset_label = list(concated['LABEL'].values)
num_label = len(set(subset_label))

# <- end of Kaggle-derived preprocessing

# Lengths below are whitespace-split word counts, not BERT tokens.
print('min_seq_len: %d' % min(len(v.split()) for v in subset_text))
print('max_seq_len: %d' % max(len(v.split()) for v in subset_text))
print('unique label: %d' % num_label)

# Number of encoder layers to encode, counted from the top (-1 .. -num_layers);
# the checkpoint below is a 12-layer model.
num_layers = 12
subset_vec_all_layers = []  # one encoded array per layer, layer -1 first
port = 6006
port_out = 6007

# NOTE(review): -model_dir points at a Chinese BERT checkpoint while the
# headlines are English -- confirm this is the intended model.
common = [
    '-model_dir', '/bert_model/chinese_L-12_H-768_A-12/',
    '-num_worker', '2',
    '-port', str(port),
    '-port_out', str(port_out),
    '-max_seq_len', '20',
    # '-client_batch_size', '2048',
    '-max_batch_size', '256',
    # '-num_client', '1',
    '-pooling_strategy', 'REDUCE_MEAN',
    '-pooling_layer', '-2',  # placeholder; overridden per layer in the loop below
    '-gpu_memory_fraction', '0.2',
    '-device','3',
]
args = get_args_parser().parse_args(common)

# The pooling layer is taken from `args` when a BertServer is constructed, so
# a fresh server is started for every layer and the whole subset re-encoded.
for pool_layer in range(1, num_layers + 1):
    args.pooling_layer = [-pool_layer]
    server = BertServer(args)
    server.start()
    try:
        print('wait until server is ready...')
        # Fixed grace period for model loading; raise it if startup is slow.
        time.sleep(20)
        print('encoding...')
        bc = BertClient(port=port, port_out=port_out, show_server_config=True)
        try:
            subset_vec_all_layers.append(bc.encode(subset_text))
        finally:
            bc.close()
    finally:
        # Close the server even if encoding fails, so a failure does not
        # leave a running server behind.
        server.close()
    print('done at layer -%d' % pool_layer)

# Save the per-layer bert vectors and the labels. np.save would append '.npy'
# on its own; it is spelled out so the save and load paths visibly match.
vec_file_prefix = 'example7_5k_2'
stacked_subset_vec_all_layers = np.stack(subset_vec_all_layers)
np.save(vec_file_prefix + '.npy', stacked_subset_vec_all_layers)
np_subset_label = np.array(subset_label)
np.save(vec_file_prefix + '_subset_label.npy', np_subset_label)

# Load bert vectors and labels. By default this reads back the dump written
# just above; set load_prefix to another prefix (e.g. 'example7_5k_mxnet') to
# visualise a previously saved dump instead.
load_prefix = vec_file_prefix
subset_vec_all_layers = np.load(load_prefix + '.npy')
np_subset_label = np.load(load_prefix + '_subset_label.npy')
subset_label = np_subset_label.tolist()
#=========================== visualize ===========================
def vis(embed, vis_alg='PCA', pool_alg='REDUCE_MEAN'):
    """Show a grid of 2-D scatter plots, one per pooled BERT layer.

    Args:
        embed: sequence of per-layer projections. ``embed[i]`` is indexed as a
            2-D array whose columns 0 and 1 are plotted as x and y; its subplot
            is titled ``pool_layer=-(i + 1)``.
        vis_alg: name of the projection method, used only in the figure title.
        pool_alg: pooling-strategy name, used only in the figure title.

    Points are coloured by the module-level ``subset_label`` (class ids 0-3),
    so every ``embed[i]`` needs one row per entry of ``subset_label``.

    Raises:
        ValueError: if ``embed`` is empty (there is nothing to attach the
            colour bar to).
    """
    if len(embed) == 0:
        raise ValueError('embed must contain at least one layer projection')
    label_names = ['ent.', 'bus.', 'sci.', 'heal.']
    cmap = ListedColormap(["blue", "green", "yellow", "red"])
    # Two rows of subplots; 12 layers gives a 2 x 6 grid.
    n_cols = (len(embed) + 1) // 2

    plt.close()
    # The size must be given when the figure is created: changing
    # rcParams['figure.figsize'] afterwards only affects later figures.
    fig = plt.figure(figsize=(21, 7))
    points = None
    for idx, ebd in enumerate(embed):
        ax = fig.add_subplot(2, n_cols, idx + 1)
        # Pin vmin/vmax so each class id always gets the same colour, even if
        # some class is absent from subset_label.
        points = ax.scatter(ebd[:, 0], ebd[:, 1], c=subset_label, cmap=cmap,
                            vmin=0, vmax=cmap.N - 1, marker='.', alpha=0.7, s=2)
        ax.set_title('pool_layer=-%d' % (idx + 1))
    fig.tight_layout()
    fig.subplots_adjust(bottom=0.1, right=0.95, top=0.9)
    cax = fig.add_axes([0.96, 0.1, 0.01, 0.3])
    cbar = fig.colorbar(points, cax=cax)
    cbar.ax.get_yaxis().set_ticks([])
    # Write each class name at the vertical centre of its colour band, using
    # axes-fraction coordinates so the placement does not depend on the
    # colour bar's data limits.
    for j, lab in enumerate(label_names):
        cbar.ax.text(.5, (2 * j + 1) / (2.0 * len(label_names)), lab,
                     ha='center', va='center', rotation=270,
                     transform=cbar.ax.transAxes)
    fig.suptitle('%s visualization of BERT layers using "bert-as-service" (-pooling_strategy=%s)' % (vis_alg, pool_alg),
                 fontsize=14)
    plt.show()


def _project_2d(layer_vecs):
    """Fit a fresh 2-component PCA on one layer's vectors and return the projection."""
    return PCA(n_components=2).fit_transform(layer_vecs)


# Each layer gets its own PCA fit, so each subplot uses its own principal axes.
pca_embed = [_project_2d(layer_vecs) for layer_vecs in subset_vec_all_layers]
vis(pca_embed)

# Optional t-SNE view (needs MulticoreTSNE's TSNE imported as TSNE):
# if False:
#     tsne_embed = [TSNE(n_jobs=8).fit_transform(v) for v in subset_vec_all_layers]
#     vis(tsne_embed, 't-SNE')