from sklearn.datasets import load_digits
from sklearn.preprocessing import scale
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
from sklearn import metrics
import numpy as np


# ---------------------------------------------------------------------------------------------------------- #
#                  A script using K-Means Clustering to classify handwritten digits.                         #
#                                           Written by @tobinatore                                           #
#                                                                                                            #
#                                 This script uses the following dataset:                                    #
#                                  Sklearn's own handwritten digits dataset                                  #
# ---------------------------------------------------------------------------------------------------------- #


def bench_k_means(estimator, name, data, labels=None):
    """Fit ``estimator`` on ``data`` and print one line of clustering scores.

    Args:
        estimator: Clusterer to benchmark. It is fitted in place and must
            expose ``inertia_`` and ``labels_`` after fitting.
        name: Short label printed at the start of the report line.
        data: Samples to cluster; also used for the silhouette score.
        labels: Ground-truth class of every sample. When omitted, the
            module-level ``y`` is used so existing three-argument calls
            keep working.

    Returns:
        None. The report line is written to standard output.
    """
    if labels is None:
        # Backward compatibility: earlier callers relied on the global ``y``.
        labels = y

    estimator.fit(data)
    predicted = estimator.labels_

    # A short explanation for every score:
    # homogeneity:          each cluster contains only members of a single class (range 0 - 1)
    # completeness:         all members of a given class are assigned to the same cluster (range 0 - 1)
    # v_measure:            harmonic mean of homogeneity and completeness
    # adjusted_rand:        similarity of the actual values and their predictions,
    #                       ignoring permutations and with chance normalization
    #                       (1 is a perfect match, about 0 is random labelling,
    #                       negative values are worse than random)
    # adjusted_mutual_info: agreement of the actual values and predictions, ignoring permutations
    #                       (1 is perfect agreement, about 0 is random agreement;
    #                       it can be slightly negative)
    # silhouette:           uses the mean distance between a sample and all other points in the same cluster,
    #                       as well as the mean distance between a sample and all points in the nearest
    #                       other cluster, to calculate a score (range: -1 to 1, with the former
    #                       indicating wrong assignments and the latter dense, well separated clusters).
    #                       0 indicates overlapping clusters.
    # NOTE: '%i' truncates the (float) inertia to an integer in the report.
    print('%-9s \t%i \thomogeneity: %.3f \tcompleteness: %.3f \tv-measure: %.3f \tadjusted-rand: %.3f \t'
          'adjusted-mutual-info: %.3f \tsilhouette: %.3f'
          % (name, estimator.inertia_,
             metrics.homogeneity_score(labels, predicted),
             metrics.completeness_score(labels, predicted),
             metrics.v_measure_score(labels, predicted),
             metrics.adjusted_rand_score(labels, predicted),
             metrics.adjusted_mutual_info_score(labels, predicted),
             metrics.silhouette_score(data, predicted,
                                      metric='euclidean')))


def plot(kmeans, data):
    """Show the K-Means partition of ``data`` after projecting it to 2-D.

    The data is reduced to two principal components, ``kmeans`` is refitted
    in place on that projection, and the resulting cluster regions, samples
    and centroids are drawn in matplotlib figure 1.

    Args:
        kmeans: Clusterer to refit on the projected data; it must provide
            ``fit``, ``predict`` and ``cluster_centers_``.
        data: Samples to project and cluster.
    """
    projected = PCA(n_components=2).fit_transform(data)
    kmeans.fit(projected)

    first_axis = projected[:, 0]
    second_axis = projected[:, 1]

    # Spacing between neighbouring mesh points; a smaller step gives a finer
    # picture of the cluster regions at the cost of more points to predict.
    step = .01

    # Bounding box of the projection, padded by one unit on every side.
    x_min, x_max = first_axis.min() - 1, first_axis.max() + 1
    y_min, y_max = second_axis.min() - 1, second_axis.max() + 1

    # To draw the decision boundary, every point of the mesh
    # [x_min, x_max] x [y_min, y_max] is coloured by its predicted cluster.
    grid_x, grid_y = np.meshgrid(np.arange(x_min, x_max, step),
                                 np.arange(y_min, y_max, step))
    mesh_points = np.c_[grid_x.ravel(), grid_y.ravel()]
    regions = kmeans.predict(mesh_points).reshape(grid_x.shape)

    plt.figure(1)
    plt.clf()

    # Cluster regions as a coloured background image.
    mesh_extent = (grid_x.min(), grid_x.max(), grid_y.min(), grid_y.max())
    plt.imshow(regions, interpolation='nearest', extent=mesh_extent,
               cmap=plt.cm.Paired, aspect='auto', origin='lower')

    # The projected samples as small black dots.
    plt.plot(first_axis, second_axis, 'k.', markersize=2)

    # The cluster centres as white crosses drawn on top of everything else.
    centers = kmeans.cluster_centers_
    plt.scatter(centers[:, 0], centers[:, 1],
                marker='x', s=169, linewidths=3,
                color='w', zorder=10)

    plt.title('K-means clustering on the digits dataset (PCA-reduced data)\n'
              'Centroids are marked with white cross')
    plt.xlim(x_min, x_max)
    plt.ylim(y_min, y_max)

    # Hide the axis ticks.
    plt.xticks(())
    plt.yticks(())
    plt.show()


# Load the digits dataset; keep the true digit of every sample in `y`
# (read by bench_k_means) and pass the pixel features through sklearn's `scale`.
digits = load_digits()
y = digits.target
data = scale(digits.data)

# One cluster per distinct digit class found in the targets
k = np.unique(y).size

# Dimensions of the feature matrix (rows = samples, columns = features)
samples, features = data.shape

# K-Means with k-means++ seeding, 10 restarts and at most 300 iterations each
classifier = KMeans(
    n_clusters=k,
    init='k-means++',
    n_init=10,
    max_iter=300,
)

# Print the clustering scores for the full-dimensional data
bench_k_means(classifier, name="kmeans++", data=data)

# Visualise the clustering on a 2-D PCA projection (refits the classifier)
plot(classifier, data)