import seaborn as sb
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import numpy as np


def plot_confusion(title, true_labels, predicted_labels, normalized=True):
    labels = list(set(true_labels) | set(predicted_labels))

    if normalized:
        cm = confusion_matrix(true_labels, predicted_labels, labels=labels)
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
    else:
        cm = confusion_matrix(true_labels, predicted_labels, labels=labels)

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    ax.set_title(title)
    # plt.colorbar()
    tick_marks = np.arange(len(labels))
    ax.set_xticks(tick_marks)
    ax.set_xticklabels(labels, rotation=90)
    ax.set_yticks(tick_marks)
    ax.set_yticklabels(labels)
    ax.set_ylabel('True Label')
    ax.set_xlabel('Predicted Label')
    ax.grid(False)
    return fig, ax