from copy import deepcopy
from sklearn.metrics import f1_score
from sklearn.preprocessing import LabelBinarizer, MultiLabelBinarizer
from sklearn.preprocessing import LabelEncoder
import numpy as np
import pdb

def binarize_labels(true_labels, pred_labels):
    """Encode per-srcid label collections as row-aligned binary indicator matrices.

    Rows follow the key order of ``pred_labels``. The same srcid is used to
    look up the ground truth, so row ``i`` of both matrices describes the
    same sample. Columns are the union of labels seen in either input, in
    the order chosen by ``MultiLabelBinarizer`` (its ``classes_``).

    Args:
        true_labels: dict mapping srcid -> iterable of ground-truth labels.
            Must contain every srcid present in ``pred_labels``.
        pred_labels: dict mapping srcid -> iterable of predicted labels.

    Returns:
        ``(true_mat, pred_mat)``: two indicator arrays with identical shape
        ``(len(pred_labels), n_labels)``.

    Raises:
        KeyError: if a srcid in ``pred_labels`` is missing from ``true_labels``.
    """
    srcids = list(pred_labels.keys())
    # Build both sides from one srcid sequence. Relying on the two dicts'
    # independent .values() orders can silently misalign (or mis-size) rows.
    pred_rows = [list(pred_labels[srcid]) for srcid in srcids]
    true_rows = [list(true_labels[srcid]) for srcid in srcids]
    mlb = MultiLabelBinarizer().fit(pred_rows + true_rows)
    true_mat = mlb.transform(true_rows)
    pred_mat = mlb.transform(pred_rows)
    return true_mat, pred_mat

def get_micro_f1(true_labels, pred_labels):
    """Return the micro-averaged F1 of multi-label predictions keyed by srcid.

    Both dicts are first converted to indicator matrices by
    ``binarize_labels``; the score is then computed by ``get_micro_f1_mat``.
    """
    matrices = binarize_labels(true_labels, pred_labels)
    return get_micro_f1_mat(*matrices)

def get_macro_f1(true_labels, pred_labels):
    """Return the macro-averaged F1 of multi-label predictions keyed by srcid.

    Both dicts are first converted to indicator matrices by
    ``binarize_labels``; the score is then computed by ``get_macro_f1_mat``.
    """
    matrices = binarize_labels(true_labels, pred_labels)
    return get_macro_f1_mat(*matrices)

def get_macro_f1_mat(true_mat, pred_mat):
    """Macro-average F1 over the columns (labels) of two indicator matrices.

    Each column is scored as an independent binary problem with ``1`` as the
    positive class. The per-column F1 values are then averaged. Columns with
    no positive ground-truth entry are skipped.

    Args:
        true_mat: 2-D array-like of ground-truth indicators (rows = samples,
            columns = labels).
        pred_mat: 2-D array-like of predicted indicators, same shape.

    Returns:
        The mean per-column F1 as a float, or ``nan`` when no column
        contains a positive ground-truth entry (nothing to score).

    Raises:
        ValueError: if the two matrices differ in shape.
    """
    true_mat = np.asarray(true_mat)
    pred_mat = np.asarray(pred_mat)
    if true_mat.shape != pred_mat.shape:
        raise ValueError('shape mismatch: true %s vs pred %s'
                         % (true_mat.shape, pred_mat.shape))
    true_pos = true_mat == 1
    pred_pos = pred_mat == 1
    scored = true_pos.any(axis=0)
    if not scored.any():
        # Avoid np.mean([]) (RuntimeWarning); the average is undefined.
        return float('nan')
    tp = np.sum(true_pos & pred_pos, axis=0)[scored]
    fp = np.sum(~true_pos & pred_pos, axis=0)[scored]
    fn = np.sum(true_pos & ~pred_pos, axis=0)[scored]
    # F1 = 2TP / (2TP + FP + FN). Every scored column has TP + FN >= 1, so
    # the denominator is never zero. When TP == 0 this yields 0.0.
    f1s = 2 * tp / (2 * tp + fp + fn)
    return float(np.mean(f1s))

def _encode_multiclass(true_labels, pred_labels):
    """Integer-encode single-label ground truth and predictions, aligned by srcid.

    Rows follow the key order of ``true_labels``. Each dict value is treated
    as one class label. Every srcid in ``true_labels`` must also be present
    in ``pred_labels``, otherwise a ``KeyError`` is raised.
    """
    srcids = list(true_labels.keys())
    true_label_list = [true_labels[srcid] for srcid in srcids]
    pred_label_list = [pred_labels[srcid] for srcid in srcids]
    # Fit on the union so labels that only occur in predictions still encode.
    le = LabelEncoder().fit(true_label_list + pred_label_list)
    return le.transform(true_label_list), le.transform(pred_label_list)


def get_multiclass_micro_f1(true_labels, pred_labels):
    """Micro-averaged F1 for single-label multiclass predictions keyed by srcid.

    Args:
        true_labels: dict mapping srcid -> ground-truth class label.
        pred_labels: dict mapping srcid -> predicted class label; must cover
            every srcid in ``true_labels``.

    Returns:
        ``sklearn.metrics.f1_score(..., average='micro')`` on the encoded labels.
    """
    true_encoded, pred_encoded = _encode_multiclass(true_labels, pred_labels)
    return f1_score(true_encoded, pred_encoded, average='micro')


def get_multiclass_macro_f1(true_labels, pred_labels):
    """Macro-averaged F1 for single-label multiclass predictions keyed by srcid.

    Args:
        true_labels: dict mapping srcid -> ground-truth class label.
        pred_labels: dict mapping srcid -> predicted class label; must cover
            every srcid in ``true_labels``.

    Returns:
        ``sklearn.metrics.f1_score(..., average='macro')`` on the encoded labels.
    """
    true_encoded, pred_encoded = _encode_multiclass(true_labels, pred_labels)
    return f1_score(true_encoded, pred_encoded, average='macro')



def get_micro_f1_mat(true_mat, pred_mat):
    """Micro-averaged F1 over all cells of two binary indicator matrices.

    True positives, false positives and false negatives are pooled across
    every sample and label before a single F1 is computed.

    Args:
        true_mat: array-like of ground-truth indicators (1 = label present).
        pred_mat: array-like of predicted indicators, same shape.

    Returns:
        The micro F1 as a float. It is 0.0 when there are no true positives,
        including the degenerate case where neither matrix has any positive.

    Raises:
        ValueError: if the two matrices differ in shape (instead of silently
            broadcasting).
    """
    true_arr = np.asarray(true_mat)
    pred_arr = np.asarray(pred_mat)
    if true_arr.shape != pred_arr.shape:
        raise ValueError('shape mismatch: true %s vs pred %s'
                         % (true_arr.shape, pred_arr.shape))
    tp = int(np.sum((true_arr == 1) & (pred_arr == 1)))
    fp = int(np.sum((true_arr == 0) & (pred_arr == 1)))
    fn = int(np.sum((true_arr == 1) & (pred_arr == 0)))
    if tp == 0:
        # Precision and recall are both 0 (or undefined), so F1 is 0.
        # The P/R formulation would evaluate 0/0 here.
        return 0.0
    # Algebraically equal to 2PR / (P + R).
    return 2 * tp / (2 * tp + fp + fn)

def get_point_accuracy(true_tagsets, pred_tagsets):
    """Fraction of predicted srcids whose tagset matches the truth, ignoring case.

    Args:
        true_tagsets: dict mapping srcid -> ground-truth tagset string.
        pred_tagsets: dict mapping srcid -> predicted tagset string. Its keys
            define which srcids are scored.

    Returns:
        Number of case-insensitive (``str.lower``) matches divided by the
        number of srcids in ``pred_tagsets``. An empty ``pred_tagsets``
        raises ``ZeroDivisionError``.
    """
    hits = 0
    for srcid, predicted in pred_tagsets.items():
        if true_tagsets[srcid].lower() == predicted.lower():
            hits += 1
    return hits / len(pred_tagsets)

def get_accuracy(true_tagsets_sets, pred_tagsets_sets):
    """Mean per-srcid Jaccard similarity between predicted and true tagsets.

    For each srcid in ``pred_tagsets_sets`` the score is
    ``|pred & true| / |pred | true|``. When both sets are empty, they agree
    perfectly and the sample scores 1.0.

    Args:
        true_tagsets_sets: dict mapping srcid -> iterable of true tagsets.
        pred_tagsets_sets: dict mapping srcid -> iterable of predicted
            tagsets. Its keys define which srcids are scored.

    Returns:
        The average score over ``pred_tagsets_sets``. An empty
        ``pred_tagsets_sets`` raises ``ZeroDivisionError``.
    """
    acc = 0
    for srcid, pred_tagsets in pred_tagsets_sets.items():
        pred = set(pred_tagsets)
        true = set(true_tagsets_sets[srcid])
        union = pred | true
        # Guard the 0/0 case where both sides are empty.
        jaccard = len(pred & true) / len(union) if union else 1.0
        acc += jaccard
    return acc / len(pred_tagsets_sets)

def exclude_common_tagsets(tagsets):
    """Return the tagsets whose leading token is not a common, uninformative one.

    The leading token is the text before the first '-'. Tagsets whose leading
    token is 'networkadapter' or 'building' are dropped. The input order of
    the remaining tagsets is kept, and they are returned as a new list.
    """
    kept = []
    for tagset in tagsets:
        head = tagset.split('-')[0]
        if head in ('networkadapter', 'building'):
            continue
        kept.append(tagset)
    return kept

def get_accuracy_conservative(true_tagsets_sets, pred_tagsets_sets):
    """Mean per-srcid Jaccard similarity after dropping common tagsets.

    Both sides are filtered with ``exclude_common_tagsets`` before comparison.
    A srcid whose filtered ground truth is empty scores 1, regardless of the
    prediction. Otherwise it scores ``|pred & true| / |pred | true|``.

    Args:
        true_tagsets_sets: dict mapping srcid -> iterable of true tagsets.
        pred_tagsets_sets: dict mapping srcid -> iterable of predicted
            tagsets. Its keys define which srcids are scored.

    Returns:
        The average score. An empty ``pred_tagsets_sets`` raises
        ``ZeroDivisionError``.
    """
    scores = []
    for srcid, pred_tagsets in pred_tagsets_sets.items():
        pred = set(exclude_common_tagsets(pred_tagsets))
        true = set(exclude_common_tagsets(true_tagsets_sets[srcid]))
        if not true:
            scores.append(1)
        else:
            scores.append(len(pred & true) / len(pred | true))
    return sum(scores) / len(pred_tagsets_sets)


def get_set_accuracy(true_label_sets, pred_tagset_sets):
    """Average per-sample set accuracy (intersection over union).

    Accuracy per sample = #intersection / #union of its predicted and true
    label sets. If both sets are empty, they agree perfectly and the sample
    scores 1.0. Accuracy over the set = average of the per-sample accuracies.

    Args:
        true_label_sets: dict mapping srcid -> iterable of true labels.
        pred_tagset_sets: dict mapping srcid -> iterable of predicted labels.
            Its keys define which srcids are scored; each must exist in
            ``true_label_sets`` (otherwise ``KeyError``).

    Returns:
        The mean accuracy as a float, or ``nan`` when ``pred_tagset_sets``
        is empty (nothing to average).
    """
    if not pred_tagset_sets:
        return float('nan')
    total = 0.0
    for srcid, pred_tagset_set in pred_tagset_sets.items():
        pred = set(pred_tagset_set)
        true = set(true_label_sets[srcid])
        union = pred | true
        total += (len(pred & true) / len(union)) if union else 1.0
    return total / len(pred_tagset_sets)