#!/usr/env/python python3
# -*- coding: utf-8 -*-
# @File     : plot_roc.py
# @Time     : 2018/8/22 20:58 
# @Software : PyCharm
import os
import numpy as np
import matplotlib.pylab as plt
from absl import app, flags


def cal_rate(score_dict, thres):
    all_number = len(score_dict)
    # print all_number
    TP = 0
    FP = 0
    FN = 0
    TN = 0
    for score, label in score_dict:
        if score >= thres:
            score = 1
        if score == 1:
            if label == "tp":
                TP += 1
            else:
                FP += 1
        else:
            if label == "n":
                TN += 1
            else:
                FN += 1
    # print TP+FP+TN+FN
    accracy = float(TP + FP) / float(all_number)
    if TP + FP == 0:
        precision = 0
    else:
        precision = float(TP) / float(TP + FP)
    TPR = float(TP) / float(TP + FN)
    TNR = float(TN) / float(FP + TN)
    FNR = float(FN) / float(TP + FN)
    FPR = float(FP) / float(FP + TN)

    return accracy, precision, TPR, TNR, FNR, FPR


def plot_roc(score_list, save_dir, plot_name):

    save_path = os.path.join(save_dir, plot_name + ".jpg")
    # 按照 score 排序
    threshold_value = sorted([score for score, _ in score_list])

    threshold_num = len(threshold_value)
    accracy_array = np.zeros(threshold_num)
    precision_array = np.zeros(threshold_num)
    TPR_array = np.zeros(threshold_num)
    TNR_array = np.zeros(threshold_num)
    FNR_array = np.zeros(threshold_num)
    FPR_array = np.zeros(threshold_num)

    # calculate all the rates
    for thres in range(threshold_num):
        accracy, precision, TPR, TNR, FNR, FPR = cal_rate(score_list, threshold_value[thres])
        accracy_array[thres] = accracy
        precision_array[thres] = precision
        TPR_array[thres] = TPR
        TNR_array[thres] = TNR
        FNR_array[thres] = FNR
        FPR_array[thres] = FPR

    AUC = np.trapz(TPR_array, FPR_array)
    threshold = np.argmin(abs(FNR_array - FPR_array))
    EER = (FNR_array[threshold] + FPR_array[threshold]) / 2
    # print('EER : %f AUC : %f' % (EER, -AUC))
    plt.plot(FPR_array, TPR_array)

    plt.title('ROC')
    plt.xlabel('FPR')
    plt.ylabel('TPR')
    plt.text(0.2, 0, s="EER :{} AUC :{} Threshold:{}".format(round(EER, 4), round(-AUC, 4),
                                                             round(threshold_value[threshold], 4)), fontsize=10)
    plt.legend()
    plt.savefig(save_path)
    plt.show()


FLAGS = flags.FLAGS
root_dir = os.path.abspath(os.path.join(os.getcwd(), "../.."))

flags.DEFINE_string(
    "save_plot_dir", os.path.join(root_dir, "results/plots"),
    "the generate plots image dir")

flags.DEFINE_string(
    "plot_name", "plt_roc_spk-01000-0.99",
    "the roc image's name")

flags.DEFINE_string(
    "score_dir", os.path.join(root_dir, "results/scores"),
    "the score txt dir")


def main(argv):

    if not os.path.exists(FLAGS.save_plot_dir):
        os.makedirs(FLAGS.save_plot_dir)
    if not os.path.exists(FLAGS.score_dir):
        os.makedirs(FLAGS.score_dir)
    score_list = []
    # 读取 score 数据文件
    with open(os.path.join(FLAGS.score_dir, "score.txt"), "r") as f:
        for line in f:
            score, label = line.split(" ")
            score_list.append([float(score), label.rstrip("\n")])
    # 绘制 ROC 曲线
    plot_roc(score_list, FLAGS.save_plot_dir, FLAGS.plot_name)


if __name__ == "__main__":
    app.run(main)