"""Functions for plotting EEG processing results.
"""

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import sklearn.metrics as skm
import seaborn as sns
import pandas as pd
import pywt
import scipy.signal
import sklearn.decomposition
from matplotlib.gridspec import GridSpec
from pylab import rcParams
import itertools


def plot_confusion_matrix(path, cm, target_names, title='Confusion matrix ', cmap=None, normalize=True):
    """Produce a plot of a confusion matrix and save it to file.

    Args:
        path (str): Filename of the produced plot.
        cm (ndarray): Confusion matrix, e.g. from
            sklearn.metrics.confusion_matrix (rows = true labels,
            columns = predicted labels).
        target_names ([str]): Names of the classification classes, in the
            order of the matrix rows/columns, for example
            ['high', 'medium', 'low']. If None, no tick labels are drawn.
        title (str): The text to display at the top of the matrix.
        cmap: Colormap for the cells, e.g. plt.get_cmap('jet') or
            plt.cm.Blues (see
            http://matplotlib.org/examples/color/colormaps_reference.html).
            Defaults to 'Blues'.
        normalize (bool): If False, plot the raw counts. If True, plot the
            per-row proportions (rows summing to zero are shown as 0).

    Example:
        plot_confusion_matrix('cm.png',
                              cm,                          # from sklearn.metrics.confusion_matrix
                              target_names=y_labels_vals,  # list of class names
                              title=best_estimator_name,   # title of graph
                              normalize=True)              # show proportions

    References:
        http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html

    """
    cm = np.asarray(cm)
    total = float(np.sum(cm))
    # Overall accuracy = fraction of all samples that lie on the diagonal.
    accuracy = np.trace(cm) / total if total else 0.0
    misclass = 1 - accuracy

    if cmap is None:
        cmap = plt.get_cmap('Blues')

    if normalize:
        # Normalize each row (true class) before drawing, so the colours, the
        # cell text and the text-colour threshold all use the same values.
        # All-zero rows stay 0 instead of becoming NaN.
        row_sums = cm.sum(axis=1, keepdims=True)
        cm = np.divide(cm.astype('float'), row_sums,
                       out=np.zeros(cm.shape, dtype=float),
                       where=row_sums != 0)

    fig = plt.figure(figsize=(8, 6))
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()

    if target_names is not None:
        tick_marks = np.arange(len(target_names))
        plt.xticks(tick_marks, target_names, rotation=45)
        plt.yticks(tick_marks, target_names)

    # Use white text on dark cells so the numbers remain readable.
    thresh = cm.max() / 1.5 if normalize else cm.max() / 2
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        text = "{:0.4f}".format(cm[i, j]) if normalize else "{:,}".format(cm[i, j])
        plt.text(j, i, text,
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label\naccuracy={:0.4f}; misclass={:0.4f}'.format(accuracy, misclass))
    # Compute the layout only after all labels exist, so none are clipped.
    plt.tight_layout()
    # Save before showing: once the window is dismissed, some backends have
    # already closed the figure.
    fig.savefig(path)
    plt.show()


# TODO: check formatting (whitespaces, etc)
# TODO: check all variable names and improve them
def ROC_curve(Y_pred, Y_test, fig=None):
    """Plot the receiver operating characteristic (ROC) curve of a classifier.

    Labels and scores are both shifted down by one before being passed to
    ``sklearn.metrics.roc_curve``, so classes encoded as {1, 2} become
    {0, 1}. The caller's inputs are not modified. The legend shows the area
    under the curve.

    Args:
        Y_pred: Predicted labels/scores, one per sample (array-like).
        Y_test: True labels, one per sample (array-like; presumably encoded
            as 1/2, given the shift by one - TODO confirm).
        fig: Optional matplotlib figure to draw into. A new one is created
            when None.
    """
    # Shift on fresh arrays so the caller's inputs are never mutated. This
    # also accepts plain lists, for which ``list - 1`` would fail.
    y_score = np.asarray(Y_pred) - 1
    y_true = np.asarray(Y_test) - 1

    fpr, tpr, _ = skm.roc_curve(y_true, y_score)
    roc_auc = skm.auc(fpr, tpr)

    if fig is None:
        fig = plt.figure()
    # Draw on the requested figure, not whichever figure is current.
    ax = fig.gca()
    ax.plot(fpr, tpr, color='red', lw=2,
            label='ROC curve (AUC = {:0.3f})'.format(roc_auc))
    # The diagonal corresponds to a classifier that guesses at random.
    ax.plot([0, 1], [0, 1], color='navy', lw=2, label='Chance')
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('Roc curve')
    ax.legend(loc="lower right")
    plt.show()


def confusion_matrix(true_labels, predicted_labels, cmap=plt.cm.Blues):
    """Compute a confusion matrix and display it as a colour-coded image.

    The matrix is built with ``sklearn.metrics.confusion_matrix``. It is drawn
    with ``plt.matshow``, which opens a new figure window, and is then
    displayed with ``plt.show()``. The y axis is labelled with the true
    labels and the x axis with the predicted labels.

    Args:
        true_labels: Ground-truth class labels.
        predicted_labels: Class labels produced by the classifier.
        cmap: Colormap for the cells (default ``plt.cm.Blues``).
    """
    counts = skm.confusion_matrix(true_labels, predicted_labels)

    # TODO: decide whether the raw matrix should also be printed.
    plt.matshow(counts, cmap=cmap)
    plt.title('Confusion matrix')
    plt.colorbar()

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.show()


def accuracy_results_plot(data_path, x_label=None, y_label=None, fig=None):
    """Draw a box plot of accuracy results stored in a CSV file.

    The CSV is read with its first column as the index. Each remaining
    column becomes one box.

    Args:
        data_path (str): Path to the CSV file.
        x_label (str): Optional x-axis label (font size 15).
        y_label (str): Optional y-axis label (font size 15).
        fig: Optional matplotlib figure to draw into. When None, seaborn
            draws on the current axes.
    """
    data = pd.read_csv(data_path, index_col=0)
    # Set the default figure size before any figure is created so it applies.
    sns.set(rc={"figure.figsize": (9, 6)})

    target_ax = fig.gca() if fig is not None else None
    ax = sns.boxplot(data=data, ax=target_ax)
    if x_label is not None:
        ax.set_xlabel(x_label, fontsize=15)
    if y_label is not None:
        ax.set_ylabel(y_label, fontsize=15)
    plt.show()


def reconstruct_without_approx(xs, labels, level, fig=None, wavelet='db4'):
    """Reconstruct signals from detail coefficients only and plot their power.

    Each entry of ``xs`` is passed to ``pywt.upcoef('d', ...)``, i.e. treated
    as detail coefficients at decomposition ``level``. No approximation
    coefficients are used.

    Args:
        xs: Sequence of detail-coefficient arrays, one per signal.
        labels: Legend label for each entry of ``xs``.
        level (int): Decomposition level the coefficients belong to.
        fig: Optional matplotlib figure to draw into. A new one is created
            when None.
        wavelet: Wavelet used for the reconstruction (default 'db4').

    Returns:
        (rs, fig): list of reconstructed signals and the figure drawn on.
    """
    rs = [pywt.upcoef('d', x, wavelet, level=level) for x in xs]

    if fig is None:
        fig = plt.figure()
    # Draw on the requested figure rather than whichever one is current.
    ax = fig.gca()
    # Plot the power of the reconstructed signals, matching the legend text.
    for i, r in enumerate(rs):
        ax.plot(np.abs(r) ** 2, label="Power of reconstructed signal ({})".format(labels[i]))

    ax.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
    return rs, fig


def reconstruct_with_approx(cDs, labels, wavelet, fig=None):
    """Reconstruct signals with ``pywt.idwt`` and plot their power.

    Each entry of ``cDs`` is passed as ``cD`` to ``pywt.idwt`` with
    ``cA=None``.

    Args:
        cDs: Sequence of detail-coefficient arrays, one per signal.
        labels: Legend label for each entry of ``cDs``.
        wavelet: Wavelet passed to ``pywt.idwt``.
        fig: Optional matplotlib figure to draw into. A new one is created
            when None.

    Returns:
        (rs, fig): list of reconstructed signals and the figure drawn on.
    """
    # NOTE(review): despite the function name, no approximation coefficients
    # are given to idwt (cA=None). Confirm this is intended.
    rs = [pywt.idwt(cA=None, cD=cD, wavelet=wavelet) for cD in cDs]

    if fig is None:
        fig = plt.figure()
    # Draw on the requested figure rather than whichever one is current.
    ax = fig.gca()

    for i, r in enumerate(rs):
        ax.plot(np.abs(r) ** 2, label="Power of reconstructed signal ({})".format(labels[i]))
    ax.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)

    return rs, fig


def fft(x, fs, fig_fft=None, fig_psd=None):
    """Compute the FFT of a signal and plot its magnitude and power.

    Args:
        x: Signal samples (assumed 1-D - TODO confirm for multi-channel use).
        fs: Sampling frequency. The frequency axis uses the same unit.
        fig_fft: Optional figure for the FFT magnitude plot. Created when None.
        fig_psd: Optional figure for the power plot. Created when None.

    Returns:
        (signal_fft, signal_psd, fig_fft, fig_psd): the complex FFT, its
        squared magnitude, and the two figures.
    """
    signal_fft = np.fft.fft(x)
    signal_psd = np.abs(signal_fft) ** 2

    # Bin k of an N-point DFT lies at k * fs / N, so the last bin is just
    # below fs. np.linspace(0, fs, N) would end exactly at fs and stretch
    # the axis.
    n = len(signal_fft)
    freq = np.arange(n) * fs / n if n else np.zeros(0)

    if fig_fft is None:
        fig_fft = plt.figure()
    # matplotlib discards the imaginary part of complex input, so plot |FFT|.
    fig_fft.gca().plot(freq, np.abs(signal_fft), label="fft")

    if fig_psd is None:
        fig_psd = plt.figure()
    fig_psd.gca().plot(freq, signal_psd, label="PSD")

    return signal_fft, signal_psd, fig_fft, fig_psd


def dwt(approx, details, labels, level, sampling_freq, class_str=None):
    """
    Plot the results of a DWT transform.

    The first row of axes shows the approximation coefficients. It is
    followed by one row per detail level.

    Args:
        approx: Approximation coefficients, one array per entry of ``labels``.
        details: For each entry of ``labels``, a sequence of at least
            ``level`` detail-coefficient arrays. Row ``j`` (1-based) shows
            ``details[i][j-1]``.
        labels: Legend label for each signal.
        level (int): Decomposition level, i.e. the number of detail rows.
        sampling_freq: Sampling frequency shown in the titles (Hz).
        class_str (str): Optional class name inserted into the titles.

    Returns:
        1-D array of the ``level + 1`` matplotlib axes.
    """
    # squeeze=False keeps a 2-D array of axes even when level == 0.
    fig, axes = plt.subplots(level + 1, 1, figsize=(8, 8), squeeze=False)
    axis = axes[:, 0]

    suffix = '' if class_str is None else ', {}'.format(class_str)

    # plot the approximation
    for i, label in enumerate(labels):
        axis[0].plot(approx[i], label=label)
    axis[0].legend()
    axis[0].set_title('DWT approximations{} (level={}, sampling-freq={}Hz)'.format(suffix, level, sampling_freq))
    axis[0].set_ylabel('(A={})'.format(level))

    # build the rows of detail coefficients
    for j in range(1, level + 1):
        for i, label in enumerate(labels):
            axis[j].plot(details[i][j - 1], label=label)
        axis[j].set_title('DWT Coeffs{} (level={}, sampling-freq={}Hz)'.format(suffix, level, sampling_freq))
        axis[j].legend()
        axis[j].set_ylabel('(D={})'.format(j))

    # Compute the layout only after titles and labels exist, so they don't overlap.
    fig.tight_layout()
    return axis


def welch_psd(xs, labels, sampling_freq, fig=None):
    """Compute and plot the power spectrum (PSD) of signals using Welch's method.

    Each signal goes through ``scipy.signal.welch`` with a flat-top window
    and ``scaling='spectrum'``. The result is drawn on a logarithmic y axis
    against that signal's own frequency bins, as returned by welch.

    Args:
        xs: Sequence of signals.
        labels: Legend label for each entry of ``xs``.
        sampling_freq: Sampling frequency in Hz.
        fig: Optional matplotlib figure to draw into. Created when None.

    Returns:
        (ps, fig): list of power spectra (one per signal) and the figure.
    """
    freqs = []
    ps = []
    for x in xs:
        f, p = scipy.signal.welch(x, sampling_freq, 'flattop', scaling='spectrum')
        freqs.append(f)
        ps.append(p)

    if fig is None:
        fig = plt.figure()

    fig.subplots_adjust(hspace=0.4)
    # Draw on the requested figure rather than whichever one is current.
    ax = fig.gca()
    # Pair each spectrum with its own frequency bins. welch already returns
    # them in the unit of sampling_freq (Hz), so no rescaling is applied.
    for i, (f, p) in enumerate(zip(freqs, ps)):
        ax.semilogy(f, p.T, label=labels[i])

    ax.set_xlabel('frequency [Hz]')
    ax.set_ylabel('PSD')

    ax.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
    ax.grid()
    plt.show()

    return ps, fig



def artifact_removal(X, S, S_reconst, fig=None):
    """Plot the results of an artifact removal.

    This function displays the results after artifact removal, for instance
    performed via :func:`gumpy.signal.artifact_removal`. Three stacked
    subplots show, from top to bottom, the observations, the true sources
    and the reconstructed signals. Each column of a 2-D array is drawn as
    its own line.

    Parameters
    ----------
    X:
        Observations (presumably samples x signals - TODO confirm)
    S:
        True sources
    S_reconst:
        The reconstructed signal
    fig:
        Optional matplotlib figure to draw into. A new one is created when
        None.
    """

    if fig is None:
        fig = plt.figure()

    models = [X, S, S_reconst]
    names = ['Observations (mixed signal)',
             'True Sources',
             'ICA recovered signals']
    for ii, (model, name) in enumerate(zip(models, names), 1):
        ax = fig.add_subplot(3, 1, ii)
        ax.set_title(name)
        # Actually draw the data; a 2-D array gets one line per column.
        ax.plot(np.asarray(model))
    fig.subplots_adjust(left=0.09, bottom=0.04, right=0.94, top=0.94,
                        wspace=0.26, hspace=0.46)
    plt.show()


def PCA_2D(X, X_train, Y_train, fig=None, colors=None):
    """Scatter-plot the first two columns of ``X_train``, coloured by label.

    The signature matches ``PCA_3D`` and the ``PCA`` dispatcher, which
    passes ``(X, X_train, Y_train, fig, colors)`` positionally.

    Args:
        X: Data on which a 2-component PCA is fitted.
        X_train: Samples to plot. Columns 0 and 1 are the coordinates.
        Y_train: Label of each sample, mapped onto ``colors``.
        fig: Optional matplotlib figure. Created when None.
        colors: Colours for the labels. Defaults to ['red', 'cyan'].
    """
    # computation - use scikit-learn's PCA explicitly; the bare name ``PCA``
    # in this module is the plotting dispatcher, not the estimator.
    # NOTE(review): the fitted projection is not used by the plot below,
    # which shows the raw columns of X_train (same as PCA_3D). Confirm
    # whether the projected data was meant to be plotted instead.
    pca_2comp = sklearn.decomposition.PCA(n_components=2)
    pca_2comp.fit(X)

    # color and figure initialization
    if colors is None:
        colors = ['red', 'cyan']
    if fig is None:
        fig = plt.figure()

    # plotting
    fig.suptitle('2D - Data')
    ax = fig.add_subplot(1, 1, 1)
    ax.scatter(X_train.T[0], X_train.T[1], alpha=0.5,
               c=Y_train, cmap=mpl.colors.ListedColormap(colors))
    ax.set_xlabel('x1')
    ax.set_ylabel('x2')


def PCA_3D(X, X_train, Y_train, fig=None, colors=None):
    """Scatter-plot the first three columns of ``X_train`` in 3-D, coloured by label.

    A 3-component scikit-learn PCA is fitted on ``X`` first. Its projection
    is not used by the plot.

    Args:
        X: Data on which the PCA is fitted.
        X_train: Samples to plot. Columns 0, 1 and 2 are the coordinates.
        Y_train: Label of each sample, mapped onto ``colors``.
        fig: Optional matplotlib figure. Created when None.
        colors: Colours for the labels. Defaults to ['red', 'cyan'].
    """
    projector = sklearn.decomposition.PCA(n_components=3)
    _ = projector.fit_transform(X)

    palette = ['red', 'cyan'] if colors is None else colors
    figure = plt.figure() if fig is None else fig

    figure.suptitle('3D - Data')
    ax3d = figure.add_subplot(1, 1, 1, projection='3d')
    columns = X_train.T
    ax3d.scatter(columns[0], columns[1], columns[2],
                 alpha=0.5, c=Y_train,
                 cmap=mpl.colors.ListedColormap(palette))
    for set_label, text in ((ax3d.set_xlabel, 'x1'),
                            (ax3d.set_ylabel, 'x2'),
                            (ax3d.set_zlabel, 'x3')):
        set_label(text)


# TODO: allow user to pass formatting control, e.g. colors, cmap, etc
def PCA(ttype, X, X_train, Y_train, fig=None, colors=None):
    """Dispatch to the 2-D or 3-D scatter plot.

    Args:
        ttype (str): Either '2D' (``PCA_2D``) or '3D' (``PCA_3D``).
        X, X_train, Y_train, fig, colors: Forwarded positionally to the
            selected plotting function.

    Returns:
        Whatever the selected plotting function returns.

    Raises:
        ValueError: If ``ttype`` is not '2D' or '3D'. ValueError is a
            subclass of Exception, so existing ``except Exception`` handlers
            keep working.
    """
    plot_fns = {'2D': PCA_2D, '3D': PCA_3D}
    if ttype not in plot_fns:
        raise ValueError("Transformation type '{ttype}' unknown".format(ttype=ttype))
    return plot_fns[ttype](X, X_train, Y_train, fig, colors)



def EEG_bandwave_visualizer(data, band_wave, n_trial, lo, hi, fig=None):
    """Plot columns 0-2 of a band-filtered signal around one trial.

    The window spans ``data.mi_interval[0] * data.sampling_freq`` samples
    before and after ``data.trials[n_trial]``. Columns 0, 1 and 2 of
    ``band_wave`` are labelled C3, C4 and Cz.

    Args:
        data: Object providing ``trials`` (presumably trial onsets as sample
            indices - TODO confirm), ``mi_interval`` and ``sampling_freq``.
        band_wave: 2-D array indexed as [sample, channel].
        n_trial (int): Index into ``data.trials`` of the trial to show.
        lo, hi: Band edges, used only in the title.
        fig: Optional matplotlib figure. It is cleared and reused. A new
            figure is created when None.
    """
    if fig is None:
        fig = plt.figure()

    # Clear and draw on *this* figure; plt.clf() would clear whichever
    # figure happens to be current.
    fig.clf()
    ax = fig.add_subplot(1, 1, 1)

    # NOTE(review): the window is symmetric and uses mi_interval[0] for both
    # ends. Confirm that mi_interval[1] was not intended for the end.
    half_width = data.mi_interval[0] * data.sampling_freq
    onset = data.trials[n_trial]
    window = slice(onset - half_width, onset + half_width)
    for channel, name in enumerate(['C3', 'C4', 'Cz']):
        ax.plot(band_wave[window, channel], alpha=0.7, label=name)

    ax.legend()
    ax.set_title("Filtered data  (Band wave {}-{})".format(lo, hi))


# TODO: check if this is too specific
# TODO: documentation
# TODO: units missing
def average_power(data_class1, lowcut, highcut, interval, sampling_freq, logarithmic_power):
    fs = sampling_freq
    if logarithmic_power:
        power_c3_c1_a  = np.log(np.power(data_class1[0], 2).mean(axis=0))
        power_c4_c1_a  = np.log(np.power(data_class1[1], 2).mean(axis=0))
        power_cz_c1_a  = np.log(np.power(data_class1[2], 2).mean(axis=0))
        power_c3_c2_a  = np.log(np.power(data_class1[3], 2).mean(axis=0))
        power_c4_c2_a  = np.log(np.power(data_class1[4], 2).mean(axis=0))
        power_cz_c2_a  = np.log(np.power(data_class1[5], 2).mean(axis=0))
    else:
        power_c3_c1_a  = np.power(data_class1[0], 2).mean(axis=0)
        power_c4_c1_a  = np.power(data_class1[1], 2).mean(axis=0)
        power_cz_c1_a  = np.power(data_class1[2], 2).mean(axis=0)
        power_c3_c2_a  = np.power(data_class1[3], 2).mean(axis=0)
        power_c4_c2_a  = np.power(data_class1[4], 2).mean(axis=0)
        power_cz_c2_a  = np.power(data_class1[5], 2).mean(axis=0)

    # time indices
    t = np.linspace(interval[0],interval[1],len(power_c3_c1_a[fs*interval[0]:fs*interval[1]]))

    # first figure, left motor imagery
    plt.figure()
    plt.plot(t, power_c3_c1_a[fs*interval[0]:fs*interval[1]], c='blue',
                    label='C3', alpha=0.7)
    plt.plot(t,power_c4_c1_a [fs*interval[0]:fs*interval[1]],c='red',
                    label='C4', alpha=0.7)
    plt.legend()
    plt.xlabel('Time')
    if logarithmic_power:
        plt.ylabel('Logarithmic Power')
    else:
        plt.ylabel('Power')
    plt.title("Left motor imagery movements ".format(lowcut, highcut))
    plt.show()

    # second figure, right motor imagery
    plt.figure()
    plt.clf()
    plt.plot(t, power_c3_c2_a[fs*interval[0] : fs*interval[1]], c='blue', label='C3', alpha=0.7)
    plt.plot(t, power_c4_c2_a[fs*interval[0] : fs*interval[1]], c='red', label='C4', alpha=0.7)
    plt.legend()
    plt.xlabel('Time')
    if logarithmic_power:
        plt.ylabel('Logarithmic Power')
    else:
        plt.ylabel('Power')
    plt.title("Right motor imagery movements".format(lowcut, highcut))