#-----------------------------------------------------#
#                   Library imports                   #
#-----------------------------------------------------#
#External libraries
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np
import os

#-----------------------------------------------------#
#             Functions for Visualization             #
#-----------------------------------------------------#
# Visualize loss and metric plot for training
def visualize_training(history, prefix, eva_path):
    # Set up the evaluation directory
    if not os.path.exists(eva_path):
        os.mkdir(eva_path)
    # Plot the generalized dice coefficient
    plt.plot(history.history['dice_classwise'])
    plt.plot(history.history['val_dice_classwise'])
    plt.title('Generalized Dice coefficient')
    plt.ylabel('Dice coefficient')
    plt.xlabel('Epoch')
    plt.legend(['Train Set', 'Test Set'], loc='upper left')
    out_path = os.path.join(eva_path,
                            "dice_classwise." + str(prefix) + ".png")
    plt.savefig(out_path)
    plt.close()
    # Plot the tversky loss
    plt.plot(history.history['loss'])
    plt.plot(history.history['val_loss'])
    plt.title('Tvsersky Loss')
    plt.ylabel('Loss')
    plt.xlabel('Epoch')
    plt.legend(['Train Set', 'Test Set'], loc='upper left')
    out_path = os.path.join(eva_path,
                            "tversky_loss." + str(prefix) + ".png")
    plt.savefig(out_path)
    plt.close()

def visualize_evaluation(case_id, vol, truth, pred, eva_path):
    # Color volumes according to truth and pred segmentation
    vol_truth = overlay_segmentation(vol, truth)
    vol_pred = overlay_segmentation(vol, pred)
    # Create a figure and two axes objects from matplot
    fig, (ax1, ax2) = plt.subplots(1, 2)
    # Initialize the two subplots (axes) with an empty 512x512 image
    data = np.zeros(vol.shape[1:3])
    ax1.set_title("Ground Truth")
    ax2.set_title("Prediction")
    img1 = ax1.imshow(data)
    img2 = ax2.imshow(data)
    # Update function for both images to show the slice for the current frame
    def update(i):
        plt.suptitle("Case ID: " + str(case_id) + " - " + "Slice: " + str(i))
        img1.set_data(vol_truth[i])
        img2.set_data(vol_pred[i])
        return [img1, img2]
    # Compute the animation (gif)
    ani = animation.FuncAnimation(fig, update, frames=len(truth), interval=10,
                                  repeat_delay=0, blit=False)
    # Set up the output path for the gif
    if not os.path.exists(eva_path):
        os.mkdir(eva_path)
    file_name = "visualization.case_" + str(case_id).zfill(5) + ".gif"
    out_path = os.path.join(eva_path, file_name)
    # Save the animation (gif)
    ani.save(out_path, writer='imagemagick', fps=30)
    # Close the matplot
    plt.close()


#-----------------------------------------------------#
#                     Subroutines                     #
#-----------------------------------------------------#
# Based on: https://github.com/neheller/kits19/blob/master/starter_code/visualize.py
def overlay_segmentation(vol, seg):
    # Scale volume to greyscale range
    vol_greyscale = (255*(vol - np.min(vol))/np.ptp(vol)).astype(int)
    # Convert volume to RGB
    vol_rgb = np.stack([vol_greyscale, vol_greyscale, vol_greyscale], axis=-1)
    # Initialize segmentation in RGB
    shp = seg.shape
    seg_rgb = np.zeros((shp[0], shp[1], shp[2], 3), dtype=np.int)
    # Set class to appropriate color
    seg_rgb[np.equal(seg, 1)] = [255, 0,   0]
    seg_rgb[np.equal(seg, 2)] = [0,   0, 255]
    # Get binary array for places where an ROI lives
    segbin = np.greater(seg, 0)
    repeated_segbin = np.stack((segbin, segbin, segbin), axis=-1)
    # Weighted sum where there's a value to overlay
    alpha = 0.3
    vol_overlayed = np.where(
        repeated_segbin,
        np.round(alpha*seg_rgb+(1-alpha)*vol_rgb).astype(np.uint8),
        np.round(vol_rgb).astype(np.uint8)
    )
    # Return final volume with segmentation overlay
    return vol_overlayed