# Subscribe to my channel, share and like my videos at
# http://youtube.com/tkorting
#
# Feel free to use and share this code.
#
# Thales Sehn Körting

# import libraries
from osgeo import gdal
import math
import numpy as np
import matplotlib.pyplot as plt
import matplotlib

# gdal constants
from gdalconst import *

# make GDAL raise Python exceptions instead of silently returning None on errors
gdal.UseExceptions()

from matplotlib.colors import ListedColormap, BoundaryNorm

# SLIC distance between two pixels, given:
# - pixel_k, pixel_i: 5-element sequences (R, G, B, x, y)
# - m: the compactness weight (expected in [1, 20])
# - S: the expected grid interval between cluster centers
def distance_slic(pixel_k, pixel_i, m, S):
    """Return d_rgb + m * d_xy / S for the two given pixels.

    d_rgb is the Euclidean distance over the first three components and
    d_xy the Euclidean distance over the last two. The original SLIC
    algorithm uses the CIE-L*a*b color space; plain RGB is used here.
    """
    # squared differences, accumulated component by component
    color_squares = sum((pixel_k[c] - pixel_i[c]) ** 2 for c in range(3))
    spatial_squares = sum((pixel_k[c] - pixel_i[c]) ** 2 for c in (3, 4))
    # the spatial term is scaled by m / S so it is comparable across grid sizes
    return math.sqrt(color_squares) + m * math.sqrt(spatial_squares) / S

# return the color used to draw cluster k out of K clusters
def get_color(k, K):
    """Return the 'Spectral' colormap entry at position k / K.

    k is the cluster index and K the total number of clusters, so the
    colors of clusters 0..K-1 are spread along the colormap. The value
    returned is whatever the colormap yields for a float in [0, 1)
    (an RGBA tuple for Matplotlib colormaps).
    """
    try:
        # colormap registry; matplotlib.cm.get_cmap was removed in Matplotlib 3.9
        color_vector = matplotlib.colormaps['Spectral']
    except AttributeError:
        # older Matplotlib releases without the registry
        color_vector = matplotlib.cm.get_cmap('Spectral')
    return color_vector(k / K)

# draw the image with every cluster center and its 2S x 2S search window
def plot_clusters(array, clusters, K, S, output_figure = ''):
    """Plot cluster centers and their search windows on top of the image.

    array         -- image as a (rows, columns, bands) array, shown with imshow
    clusters      -- per-cluster rows; column 3 is the center x, column 4 the center y
    K             -- number of clusters to draw (rows 0..K-1 of clusters)
    S             -- grid interval; each window spans center +/- S in x and y
    output_figure -- PNG path to save to; when empty the figure is shown instead
    """
    fig = plt.figure(figsize = (8, 6))
    # a single Axes for the whole figure: calling add_subplot(111) once per
    # cluster adds a fresh Axes each time on current Matplotlib, which hides
    # the image and everything drawn before it
    ax = fig.add_subplot(111)
    ax.imshow(array, vmin=0, vmax=255)
    for k in range(K):
        center_x = clusters[k, 3]
        center_y = clusters[k, 4]
        # closed outline of the search window around the center
        x = [center_x - S, center_x + S, center_x + S, center_x - S, center_x - S]
        y = [center_y - S, center_y - S, center_y + S, center_y + S, center_y - S]
        cluster_color = get_color(k, K)
        ax.fill(x, y, color=cluster_color, fill = None, linewidth=1, alpha = 0.5)
        # plot center
        ax.scatter(center_x, center_y, s = 35, facecolors=cluster_color, edgecolors='white', linewidth=1, alpha = 0.75)

    # leave a margin of S around the image so border windows stay visible
    (rows, columns, bands) = array.shape
    ax.set_xlim([0 - S, columns + S])
    ax.set_ylim([0 - S, rows + S])

    if output_figure != '':
        fig.savefig(output_figure, format='png', dpi=1000)
        # release the figure: this function is called once per iteration and
        # unclosed figures accumulate in memory
        plt.close(fig)
    else:
        plt.show()

# draw the label image, painting every superpixel with its cluster mean color
def plot_slic(array, clusters, K, S, output_figure = ''):
    """Plot a SLIC label image using the cluster colors.

    array         -- (rows, columns) array of cluster labels in 0..K-1;
                     larger values (the caller's no-data label) fall above
                     the last boundary and get the colormap's "over" color
    clusters      -- per-cluster rows whose first three values are the R, G, B
                     color used for that label
    K             -- number of clusters
    S             -- grid interval, used as the margin around the image
    output_figure -- PNG path to save to; when empty the figure is shown instead
    """
    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot(111)
    # create colormap based on cluster RGB centers
    slic_colormap = [(c[0], c[1], c[2], 1.0) for c in clusters]
    slic_listed_colormap = ListedColormap(slic_colormap)
    # K + 1 boundaries give exactly K bins, so label k maps to color k; with
    # only K boundaries (K - 1 bins) BoundaryNorm stretches the bin indices
    # over the K colors and shifts many labels to a neighbouring color
    slic_norm = BoundaryNorm(range(K + 1), K)
    ax.imshow(array, norm=slic_norm, cmap=slic_listed_colormap)
    # adjust image: keep a margin of S around it
    (rows, columns) = array.shape
    ax.set_xlim([0 - S, columns + S])
    ax.set_ylim([0 - S, rows + S])

    if output_figure != '':
        fig.savefig(output_figure, format='png', dpi=1000)
        # release the figure so repeated calls do not accumulate figures
        plt.close(fig)
    else:
        plt.show()

# input raster, opened read-only
filename = "slic_test.tif"
dataset = gdal.Open(filename, GA_ReadOnly)

# raster geometry: image size, number of bands and total pixel count
columns, rows, bands = dataset.RasterXSize, dataset.RasterYSize, dataset.RasterCount
N = columns * rows

# numbers of regions to split the image into; SLIC is run once per value
set_of_K = (400, 1200)
# compactness constant: weight of the spatial distance in the SLIC measure
m = 10

# compute, for every pixel at once, the index of the closest cluster center
def _nearest_cluster_map(centers, band_R, band_G, band_B, grid_x, grid_y, m, S):
    """Return an integer array, shaped like the bands, that holds for each
    pixel the index of the cluster center with the smallest SLIC distance.

    centers                -- (K, 5) array of (R, G, B, x, y) cluster centers
    band_R, band_G, band_B -- 2-D color bands of the image
    grid_x, grid_y         -- 2-D arrays with the column / row index of each pixel
    m, S                   -- compactness constant and grid interval

    The measure is the same as in distance_slic (d_rgb + m * d_xy / S), but
    evaluated with NumPy on whole bands instead of one pixel at a time.
    Ties go to the lowest cluster index, as np.argmin would do.
    """
    best_distance = np.full(grid_x.shape, np.inf)
    best_cluster = np.zeros(grid_x.shape, dtype=np.int64)
    for k in range(centers.shape[0]):
        d_rgb = np.sqrt((centers[k, 0] - band_R) ** 2 +
                        (centers[k, 1] - band_G) ** 2 +
                        (centers[k, 2] - band_B) ** 2)
        d_xy = np.sqrt((centers[k, 3] - grid_x) ** 2 +
                       (centers[k, 4] - grid_y) ** 2)
        distance = d_rgb + m * d_xy / S
        # strict comparison keeps the first (lowest) index on ties
        closer = distance < best_distance
        best_distance[closer] = distance[closer]
        best_cluster[closer] = k
    return best_cluster


# get RGB numpy arrays (band values divided by 255); they do not depend on K,
# so they are read from the dataset only once
array_R = dataset.GetRasterBand(1).ReadAsArray() / 255
array_G = dataset.GetRasterBand(2).ReadAsArray() / 255
array_B = dataset.GetRasterBand(3).ReadAsArray() / 255
array_RGB = np.dstack((array_R, array_G, array_B))

# row (y) and column (x) index of every pixel
grid_y, grid_x = np.mgrid[0:rows, 0:columns]

# superpixels matrix: one (R, G, B, x, y) row per pixel, in row-major order
array_superpixels = np.column_stack((array_R.ravel(), array_G.ravel(), array_B.ravel(),
                                     grid_x.ravel(), grid_y.ravel()))

for K in set_of_K:
    # expected grid interval between cluster centers
    S = math.sqrt(N / K)
    # label of pixels that are not assigned to any cluster
    no_data = K * 100

    # print some metadata
    print ("SLIC test")
    print ("image metadata:")
    print (rows, "rows x", columns, "columns x", bands, "bands")
    print (K, "expected divisions")
    print (S, "is the value of S")

    # compute initial clusters
    # the 5 positions are:
    # 1-R, 2-G, 3-B, 4-x, 5-y
    # in the original algorithm 1-L, 2-a, 3-b (the CIE-L*a*b color space)
    C = np.zeros((K, 5))

    # define center of K clusters (x, y) on a regular grid; the step is kept
    # at 1 or more so range() does not fail on images smaller than the grid
    first_seed = math.floor(S / 2)
    step_y = max(1, math.floor(rows / math.sqrt(K)))
    step_x = max(1, math.floor(columns / math.sqrt(K)))
    seeds = [(y, x)
             for y in range(first_seed, rows, step_y)
             for x in range(first_seed, columns, step_x)]
    # NOTE(review): when the grid yields fewer than K seeds the remaining
    # clusters stay at (0, 0, 0, 0, 0), as before -- confirm this is intended
    for k, (y, x) in enumerate(seeds[:K]):
        C[k, :] = (array_R[y, x], array_G[y, x], array_B[y, x], x, y)

    t = 0
    plot_clusters(array_RGB, C, K, S, 'animation/K' + str(K) + '_cluster_limits_t' + str(t) + '.png')

    # run SLIC k-means until the centers move less than error_threshold in
    # total, for at most max_i + 1 iterations (to avoid an infinite loop)
    error_threshold = 5
    max_i = 20
    for i in range(max_i + 1):
        print ("iteration", i)

        # closest cluster of every pixel, for the current centers
        nearest = _nearest_cluster_map(C, array_R, array_G, array_B, grid_x, grid_y, m, S)

        # set SLIC matrix, initially all no data
        array_SLIC = np.full(nearest.shape, no_data, dtype=np.int64)

        # assign the best matching pixels from a 2S x 2S square neighborhood
        # around the cluster center according to the distance measure:
        # a pixel gets label k when it lies in the window of cluster k and
        # k is its closest cluster
        for k in range(K):
            center_x = C[k, 3]
            center_y = C[k, 4]
            left_y_limit = max(0, math.floor(center_y - S))
            right_y_limit = min(rows, math.floor(center_y + S))
            left_x_limit = max(0, math.floor(center_x - S))
            right_x_limit = min(columns, math.floor(center_x + S))
            print ("cluster", k, "limits y", left_y_limit, "to", right_y_limit, "and x", left_x_limit, "to", right_x_limit)
            # window is a view, so writing to it updates array_SLIC
            window = array_SLIC[left_y_limit:right_y_limit, left_x_limit:right_x_limit]
            window[nearest[left_y_limit:right_y_limit, left_x_limit:right_x_limit] == k] = k

        # per-cluster pixel counts and (R, G, B, x, y) sums in one pass each
        flat_labels = array_SLIC.ravel()
        assigned = flat_labels != no_data
        pixels_per_cluster = np.bincount(flat_labels[assigned], minlength=K)
        sums_per_cluster = np.column_stack(
            [np.bincount(flat_labels[assigned],
                         weights=array_superpixels[assigned, c],
                         minlength=K)
             for c in range(5)])

        # compute new cluster centers and residual error E
        residual_error = 0.0
        for k in range(K):
            total_in_k = int(pixels_per_cluster[k])
            # clusters that own no pixel keep their previous center
            if total_in_k > 0:
                new_center = sums_per_cluster[k] / total_in_k
                print("updating k", k, "old center", C[k, :], "new center", new_center, "total_in_k", total_in_k)
                partial_error = C[k, :] - new_center
                residual_error = residual_error + math.sqrt(partial_error.dot(partial_error))
                C[k, :] = new_center

        t = i + 1
        # plot intermediate clusters and slic
        plot_clusters(array_RGB, C, K, S, 'animation/K' + str(K) + '_cluster_limits_t' + str(t) + '.png')
        plot_slic(array_SLIC, C, K, S, 'animation/K' + str(K) + '_partial_slic_t' + str(t) + '.png')

        print ("residual error is", residual_error, "at iteration", i)
        if residual_error <= error_threshold:
            break

    # make final segmentation: every pixel takes the label of its closest
    # cluster, without the window restriction, so no pixel is left as no data
    array_SLIC = _nearest_cluster_map(C, array_R, array_G, array_B, grid_x, grid_y, m, S)

    plot_slic(array_SLIC, C, K, S, 'animation/K' + str(K) + '_final_slic_t.png')

# close dataset
dataset = None