import matplotlib.pyplot as plt
import numpy as np
import os
from rdkit import Chem
from experiment.figure.confusion_plot import find_average_trial
from sklearn.metrics import roc_curve, precision_recall_curve, confusion_matrix
from sklearn.metrics import roc_auc_score, average_precision_score

# Global matplotlib defaults for every figure produced by this module:
# a larger base font, and axes/grid lines rendered beneath the plotted data.
plt.rcParams.update({"font.size": 16, "axes.axisbelow": True})

# Shared color palette (hex RGB strings). ``weave`` is a light grey used for
# the Weave model's curves.
red = "#CC3311"
orange = "#ED7D0F"
green = "#009988"
blue = "#0077BB"
weave = "#cccccc"


def _load_trial_predictions(path, trial):
    """Read labels and scores for one trial from its SDF output.

    If ``<path>/weave_trial_<trial>.sdf`` exists, the labels (``true``), 3DGCN
    scores (``pred``) and Weave scores (``pred_weave``) are read from it, and
    molecules without a ``true`` property are skipped. Otherwise ``true`` and
    ``pred`` are read from ``<path>/test.sdf``.

    Returns:
        ``(true_y, pred_y, weave_y)`` as float numpy arrays. ``weave_y`` is
        ``None`` when no Weave results file exists for the trial.
    """
    weave_file = os.path.join(path, "weave_trial_{}.sdf".format(trial))
    true_y, pred_y = [], []

    if os.path.isfile(weave_file):
        weave_scores = []
        for mol in Chem.SDMolSupplier(weave_file):
            # SDMolSupplier yields None for records RDKit cannot parse.
            if mol is None or not mol.HasProp("true"):
                continue
            true_y.append(float(mol.GetProp("true")))
            pred_y.append(float(mol.GetProp("pred")))
            # NOTE(review): the Weave score is inverted (1 - p); presumably the
            # stored value is the opposite-class probability -- confirm upstream.
            weave_scores.append(1 - float(mol.GetProp("pred_weave")))
        weave_y = np.array(weave_scores, dtype=float)
    else:
        for mol in Chem.SDMolSupplier(os.path.join(path, "test.sdf")):
            if mol is None:
                continue
            true_y.append(float(mol.GetProp("true")))
            pred_y.append(float(mol.GetProp("pred")))
        weave_y = None

    return np.array(true_y, dtype=float), np.array(pred_y, dtype=float), weave_y


def _save_step_plot(curve, weave_curve, color, xlabel, ylabel, fig_name, figsize):
    """Draw a post-step curve on [0, 1] x [0, 1], optionally over a Weave curve, and save it.

    ``curve`` and ``weave_curve`` are ``(x, y)`` pairs; ``weave_curve`` may be
    None. The figure is always closed afterwards so repeated calls do not
    accumulate open figures.
    """
    fig = plt.figure(figsize=figsize)
    try:
        # Weave is drawn first so the main curve is layered on top of it.
        if weave_curve is not None:
            plt.step(*weave_curve, color=weave, alpha=1, where='post')
        plt.step(*curve, color=color, alpha=1, where='post')

        plt.ylabel(ylabel)
        plt.xlabel(xlabel)
        plt.ylim([0.0, 1.0])
        plt.xlim([0.0, 1.0])

        fig.savefig(fig_name, dpi=600)
    finally:
        plt.close(fig)


def draw_pr_curve(dataset, base_path, trials=(5,)):
    """Plot and save ROC and precision-recall curves for the given trials.

    For each trial ``i``, predictions are read from
    ``<base_path>/trial_<i>/weave_trial_<i>.sdf`` when that file exists
    (3DGCN and Weave scores side by side). Otherwise they are read from
    ``<base_path>/trial_<i>/test.sdf`` (3DGCN only). ROC-AUC and average
    precision are printed. ``ROC_curve_trial<i>.png`` and
    ``PR_curve_trial<i>.png`` are written into the trial directory.

    Args:
        dataset: Dataset name. ``"bace_cla"`` draws the 3DGCN curve in green,
            anything else in blue.
        base_path: Directory containing the ``trial_<i>`` sub-directories
            (joined with ``os.path.join``, so a trailing separator is optional).
        trials: Iterable of trial indices to plot. The default ``(5,)`` matches
            the previously hard-coded ``range(5, 6)``.
    """
    c = green if dataset == "bace_cla" else blue
    w, h = plt.figaspect(1)

    for i in trials:
        path = os.path.join(base_path, "trial_{}".format(i))
        true_y, pred_y, weave_y = _load_trial_predictions(path, i)
        is_weave = weave_y is not None

        # ROC / precision-recall points for the 3DGCN scores.
        precision, recall, _ = precision_recall_curve(true_y, pred_y)
        fpr, tpr, _ = roc_curve(true_y, pred_y)

        roc_w = pr_w = None
        if is_weave:
            precision_w, recall_w, _ = precision_recall_curve(true_y, weave_y)
            fpr_w, tpr_w, _ = roc_curve(true_y, weave_y)
            roc_w = (fpr_w, tpr_w)
            pr_w = (recall_w, precision_w)

        print("ROC 3DGCN: {}".format(roc_auc_score(true_y, pred_y)))
        if is_weave:
            print("ROC Weave: {}".format(roc_auc_score(true_y, weave_y)))
        print("PR 3DGCN: {}".format(average_precision_score(true_y, pred_y)))
        if is_weave:
            print("PR Weave: {}".format(average_precision_score(true_y, weave_y)))

        fig_name = os.path.join(path, "ROC_curve_trial{}.png".format(i))
        _save_step_plot((fpr, tpr), roc_w, c,
                        'False Positive Rate', 'True Positive Rate',
                        fig_name, (w, h))
        print("ROC curve figure saved on {}".format(fig_name))

        fig_name = os.path.join(path, "PR_curve_trial{}.png".format(i))
        _save_step_plot((recall, precision), pr_w, c,
                        'Recall', 'Precision',
                        fig_name, (w, h))
        print("PR curve figure saved on {}".format(fig_name))


def draw_confusion_matrix(dataset, model, set_trial=None, filename="test_results.sdf"):
    """Print the binary confusion-matrix counts (tn, fp, fn, tp) for one trial.

    Despite its name, this function draws nothing; it only prints the counts.

    Args:
        dataset: Dataset name used to locate the result directory.
        model: Model name used to locate the result directory.
        set_trial: Trial directory name. When None, the directory comes from
            ``find_average_trial(dataset, model, metric="test_pr")``. Otherwise
            ``../result/<model>/<dataset>/<set_trial>/`` is used.
        filename: SDF file inside that directory holding ``true`` labels and
            ``pred`` scores.
    """
    if set_trial is None:
        path = find_average_trial(dataset, model, metric="test_pr")
    else:
        path = "../result/{}/{}/{}/".format(model, dataset, set_trial)

    # Load true labels and predicted scores.
    true_y, pred_y = [], []
    for mol in Chem.SDMolSupplier(path + filename):
        # SDMolSupplier yields None for records RDKit cannot parse; skip them
        # instead of failing with an AttributeError.
        if mol is None:
            continue
        true_y.append(float(mol.GetProp("true")))
        pred_y.append(float(mol.GetProp("pred")))

    true_y = np.array(true_y, dtype=float)
    # Round scores to hard class predictions (NumPy rounds exact halves to even,
    # so 0.5 -> 0). NOTE(review): assumes "pred" is a probability in [0, 1].
    pred_y = np.array(pred_y, dtype=float).round()

    # Pin the label set so the matrix is always 2x2. Without it, a result set in
    # which only one class occurs yields a 1x1 matrix and the 4-way unpack fails.
    # NOTE(review): assumes binary labels encoded as 0/1 -- confirm per dataset.
    confusion = confusion_matrix(true_y, pred_y, labels=[0.0, 1.0])
    tn, fp, fn, tp = confusion.ravel()

    print("tn: {}, fp: {}, fn: {}, tp: {}".format(tn, fp, fn, tp))