"""
Visualization of Neural Turing Machines.
"""
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import numpy as np


def show(w, w_title):
    """
    Show a weight matrix with a colorbar on the current axes, then call ``plt.show()``.

    The title is suffixed with the matrix size written as
    ``<len(w[0])> x <len(w)>`` (columns x rows for a 2-D matrix).

    :param w: the weight matrix (anything ``plt.imshow`` accepts whose first
        element supports ``len``).
    :param w_title: the title of the weight matrix.
    :return: None.
    """
    axes_w = plt.gca()
    plt.imshow(w)
    plt.colorbar()
    plt.xlabel("$w_{1}$")
    plt.ylabel("$w_{2}$")
    axes_w.set_xticks([])
    axes_w.set_yticks([])
    # Raw string: "\ " is a mathtext space, not a Python escape. The former
    # non-raw literal relied on an invalid escape sequence (SyntaxWarning on
    # Python 3.12+); the resulting text is unchanged.
    matrix_size = r"$:\ %d \times\ %d$" % (len(w[0]), len(w))
    plt.title(w_title + matrix_size)

    # show the matrix.
    plt.show()


def make_tick_labels_invisible(fig):
    """
    Hide every x and y tick label on all axes of a figure.

    :param fig: a figure-like object exposing ``axes``; each axes must provide
        ``get_xticklabels()`` and ``get_yticklabels()`` returning lists of
        objects with ``set_visible``.
    :return: None.
    """
    for ax in fig.axes:
        for tl in ax.get_xticklabels() + ax.get_yticklabels():
            tl.set_visible(False)


def show_copy_data(input_sequence, output_sequence, input_name, output_name, image_file):
    """
    Render an input/output pair for the copy task and save it to a file.

    The input is drawn on the top row and the output on the bottom row, with
    row labels, a time arrow and a colorbar on the right.

    :param input_sequence: data drawn (with ``imshow``) on the top row.
    :param output_sequence: data drawn on the bottom row.
    :param input_name: label placed left of the top row.
    :param output_name: label placed left of the bottom row.
    :param image_file: destination passed to ``Figure.savefig`` (75 dpi).
    :return: None.
    """
    # set figure size
    fig = plt.figure(figsize=(7, 3))
    try:
        # draw first line
        axes_input_10 = plt.subplot2grid((2, 1), (0, 0), colspan=1)
        axes_input_10.set_aspect('equal')
        plt.imshow(input_sequence, interpolation='none')
        axes_input_10.set_xticks([])
        axes_input_10.set_yticks([])

        # draw second line
        axes_output_10 = plt.subplot2grid((2, 1), (1, 0), colspan=1)
        plt.imshow(output_sequence, interpolation='none')
        axes_output_10.set_xticks([])
        axes_output_10.set_yticks([])

        # add text; positions are data coordinates of the current (bottom) axes
        plt.text(-2, -4.5, input_name, ha='right')
        plt.text(-2, 4, output_name, ha='right')
        # raw string: "\l" is LaTeX, not a Python escape sequence
        plt.text(6, 10, r'Time $\longrightarrow$', ha='right')

        # set tick labels invisible
        make_tick_labels_invisible(plt.gcf())
        # adjust spaces
        plt.subplots_adjust(hspace=0.05, wspace=0.05, bottom=0.1, right=0.8, top=0.9)
        # add color bars; the colorbar follows the most recent image (output row)
        # *rect* = [left, bottom, width, height]
        cax = plt.axes([0.85, 0.125, 0.015, 0.75])
        plt.colorbar(cax=cax)

        # save image
        fig.savefig(image_file, dpi=75)
    finally:
        # Close this figure (not whichever is current), even if drawing or
        # saving failed, so repeated calls don't leave figures open.
        plt.close(fig)


def show_repeat_copy_data(input_sequence, output_sequence, input_name, output_name, image_file, repeat_times):
    """
    Render an input/output pair for the repeat-copy task and save it as PDF.

    :param input_sequence: data drawn (with ``imshow``) on the upper row.
    :param output_sequence: data drawn on the row below it.
    :param input_name: label placed left of the input row.
    :param output_name: label placed left of the output row.
    :param image_file: destination passed to ``Figure.savefig`` (PDF, 75 dpi).
    :param repeat_times: number shown in the "Repeat times = N" caption.
    :return: None.
    """
    # set figure size
    fig = plt.figure(figsize=(16, 1.5))
    try:
        # draw first line (grid rows 1 and 2 are used; rows 0 and 3 stay empty)
        axes_input_10 = plt.subplot2grid((4, 1), (1, 0), colspan=1)
        axes_input_10.set_aspect('equal')
        plt.imshow(input_sequence, interpolation='none')
        axes_input_10.set_xticks([])
        axes_input_10.set_yticks([])

        # draw second line
        axes_output_10 = plt.subplot2grid((4, 1), (2, 0), colspan=1)
        plt.imshow(output_sequence, interpolation='none')
        axes_output_10.set_xticks([])
        axes_output_10.set_yticks([])

        # add text; positions are data coordinates of the current (output) axes
        plt.text(-2, -2.2, input_name, ha='right')
        plt.text(-2, 2.2, output_name, ha='right')
        # raw string: "\l" is LaTeX, not a Python escape sequence
        plt.text(13, 6.5, r'Time $t$ $\longrightarrow$', ha='right')

        title = "Repeat times = %d" % repeat_times
        plt.text(55, -6.5, title, ha='center')

        # save image
        fig.savefig(image_file, dpi=75, format='pdf')
    finally:
        # Close this figure even on failure so repeated calls don't leak figures.
        plt.close(fig)


def show_multi_copy_data(target_sequence_10, output_sequence_10,
                   target_sequence_20, output_sequence_20,
                   target_sequence_30, output_sequence_30,
                   target_sequence_50, output_sequence_50,
                   target_sequence_120, output_sequence_120,
                   image_file):
    """
    Plot copy-task targets and outputs for several sequence lengths, show the
    figure, then save it.

    Rows 0/1 of a 4x11 grid hold the 10/20/30/50-step targets and outputs side
    by side, with widths 1/2/3/5 columns (proportional to length). Rows 2/3
    hold the 120-step target/output pair spanning the full width.

    :param target_sequence_10: ... :param output_sequence_120: data drawn with
        ``imshow`` in the corresponding panel.
    :param image_file: destination passed to ``Figure.savefig`` (75 dpi).
    :return: None.
    """
    grid = (4, 11)
    # (data, row, first column, colspan) for every panel, in drawing order.
    # The order matters: the text and colorbar below use the last axes/image.
    panels = [
        (target_sequence_10, 0, 0, 1),
        (target_sequence_20, 0, 1, 2),
        (target_sequence_30, 0, 3, 3),
        (target_sequence_50, 0, 6, 5),
        (output_sequence_10, 1, 0, 1),
        (output_sequence_20, 1, 1, 2),
        (output_sequence_30, 1, 3, 3),
        (output_sequence_50, 1, 6, 5),
        (target_sequence_120, 2, 0, 11),
        (output_sequence_120, 3, 0, 11),
    ]

    # set figure size
    fig = plt.figure(figsize=(12, 4))
    try:
        for index, (data, row, col, span) in enumerate(panels):
            axes = plt.subplot2grid(grid, (row, col), colspan=span)
            if index == 0:
                # only the first panel forces an equal aspect ratio explicitly
                axes.set_aspect('equal')
            plt.imshow(data, interpolation='none')
            axes.set_xticks([])
            axes.set_yticks([])

        # add text; positions are data coordinates of the current (bottom,
        # 120-step output) axes, so negative y values reach the upper rows
        plt.text(-2, 5, 'Outputs', ha='right')
        plt.text(-2, -7.5, 'Targets', ha='right')
        plt.text(-2, -20, 'Outputs', ha='right')
        plt.text(-2, -32.5, 'Targets', ha='right')
        # raw string: "\l" is LaTeX, not a Python escape sequence
        plt.text(10, 12, r'Time $\longrightarrow$', ha='right')

        # set tick labels invisible
        make_tick_labels_invisible(plt.gcf())
        # adjust spaces
        plt.subplots_adjust(hspace=0.05, wspace=0.05, bottom=0.1, right=0.8, top=0.9)
        # add color bars; the colorbar follows the most recent image
        # *rect* = [left, bottom, width, height]
        cax = plt.axes([0.85, 0.125, 0.015, 0.75])
        plt.colorbar(cax=cax)

        # show figure
        plt.show()

        # save image
        fig.savefig(image_file, dpi=75)
    finally:
        # Close this figure (not whichever is current), even on failure.
        plt.close(fig)


def show_associative_recall_data(input_sequence, output_sequence, input_name, output_name, image_file):
    """
    Render an input/output pair for the associative-recall task and save it as PDF.

    :param input_sequence: data drawn (with ``imshow``) on the top row.
    :param output_sequence: data drawn on the second row.
    :param input_name: label placed left of the input row.
    :param output_name: label placed left of the output row.
    :param image_file: destination passed to ``Figure.savefig`` (PDF, 75 dpi).
    :return: None.
    """
    # set figure size
    fig = plt.figure(figsize=(16, 2))
    try:
        # draw first line
        axes_input_10 = plt.subplot2grid((3, 1), (0, 0), colspan=1)
        axes_input_10.set_aspect('equal')
        plt.imshow(input_sequence, interpolation='none')
        axes_input_10.set_xticks([])
        axes_input_10.set_yticks([])

        # draw second line
        axes_output_10 = plt.subplot2grid((3, 1), (1, 0), colspan=1)
        plt.imshow(output_sequence, interpolation='none')
        axes_output_10.set_xticks([])
        axes_output_10.set_yticks([])

        # add text; positions are data coordinates of the current (output) axes
        plt.text(-2, -5, input_name, ha='right')
        plt.text(-2, 5, output_name, ha='right')
        # raw string: "\l" is LaTeX, not a Python escape sequence
        plt.text(14.3, 11, r'Time $t$ $\longrightarrow$', ha='right')

        # save image
        fig.savefig(image_file, dpi=75, format='pdf')
    finally:
        # Close this figure even on failure so repeated calls don't leak figures.
        plt.close(fig)


def show_memory_of_copy_task(
        input_sequence, output_squence,
        adds, reads,
        write_weightings, read_weightings,
        image_file):
    """
    Plot the memory activity of a copy-task run, show it, then save it.

    Layout on a 15x2 grid:
    - inputs and outputs on top (2 rows each);
    - add and read vectors in the middle (4 rows each);
    - write and read weightings at the bottom (9 rows each);
    - a colorbar on the right.

    :param input_sequence: data for the "Inputs" panel.
    :param output_squence: data for the "Outputs" panel (the misspelled name
        is kept so keyword callers keep working).
    :param adds: data for the "Adds" panel.
    :param reads: data for the "Reads" panel.
    :param write_weightings: data for the "Write Weightings" panel.
    :param read_weightings: data for the "Read Weightings" panel.
    :param image_file: destination passed to ``Figure.savefig`` (75 dpi).
    :return: None.
    """
    grid = (15, 2)
    no_interp = {'interpolation': 'none'}
    # (data, top-left cell, rowspan, title, extra imshow kwargs), in drawing
    # order; the text/colorbar below rely on the last axes/image drawn.
    panels = [
        (input_sequence, (0, 0), 2, "Inputs", no_interp),
        (output_squence, (0, 1), 2, "Outputs", no_interp),
        # adds/reads are drawn with imshow's default interpolation
        (adds, (2, 0), 4, "Adds", {}),
        (reads, (2, 1), 4, "Reads", {}),
        (write_weightings, (6, 0), 9, "Write Weightings", no_interp),
        (read_weightings, (6, 1), 9, "Read Weightings", no_interp),
    ]

    # set figure size
    fig = plt.figure(figsize=(10, 8))
    try:
        for data, loc, rowspan, title, imshow_kwargs in panels:
            axes = plt.subplot2grid(grid, loc, rowspan=rowspan)
            plt.imshow(data, **imshow_kwargs)
            axes.set_xticks([])
            axes.set_yticks([])
            plt.title(title)

        # add text; positions are data coordinates of the current
        # ("Read Weightings") axes. Raw strings: "\l" is LaTeX, not an escape.
        plt.text(-45, 20.5, r'Location $\longrightarrow$', fontsize=16, ha='center', rotation=90)
        plt.text(11.5, 39, r'Time $\longrightarrow$', fontsize=16, ha='right')
        plt.text(-30.5, 39, r'Time $\longrightarrow$', fontsize=16, ha='right')

        # set tick labels invisible
        make_tick_labels_invisible(plt.gcf())
        # adjust spaces
        plt.subplots_adjust(hspace=0.05, wspace=0.05, bottom=0.1, right=0.8, top=0.9)
        # add color bars; the colorbar follows the most recent image
        # *rect* = [left, bottom, width, height]
        cax = plt.axes([0.85, 0.15, 0.015, 0.75])
        plt.colorbar(cax=cax)

        # show figure
        plt.show()

        # save image
        fig.savefig(image_file, dpi=75)
    finally:
        # Close this figure (not whichever is current), even on failure.
        plt.close(fig)


class PlotDynamicalMatrix:
    """
    Interactive (``plt.ion()``) figure that redraws three stacked matrices
    each time :meth:`update` is called.
    """

    def __init__(self, matrix_list, name_list):
        """
        Create the figure, switch pyplot to interactive mode and draw the
        first frame.

        :param matrix_list: a group of matrices; items 0, 1 and 2 are drawn as
            the top, middle and bottom rows.
        :param name_list: row labels, forwarded to :meth:`update`.
        :return: None.
        """
        self.matrix_list = matrix_list
        # set figure size
        self.fig = plt.figure(figsize=(7, 5))

        plt.ion()
        self.update(matrix_list, name_list)

    def update(self, matrix_list, name_list):
        """
        Redraw the figure with new matrices and briefly run the GUI event loop.

        :param matrix_list: a group of matrices; items 0, 1 and 2 are drawn as
            the top, middle and bottom rows (a shorter list raises IndexError).
        :param name_list: row labels; currently not drawn.
        :return: None.
        """
        # Draw on this object's figure (not whichever figure is current) and
        # start from a blank canvas, so repeated updates don't pile new
        # images, axes and colorbars on top of the previous ones.
        plt.figure(self.fig.number)
        plt.clf()

        for row in range(3):
            axes = plt.subplot2grid((3, 1), (row, 0), colspan=1)
            if row == 0:
                axes.set_aspect('equal')
            plt.imshow(matrix_list[row], interpolation='none')
            axes.set_xticks([])
            axes.set_yticks([])

        # set tick labels invisible
        make_tick_labels_invisible(plt.gcf())
        # adjust spaces
        plt.subplots_adjust(hspace=0.05, wspace=0.05, bottom=0.1, right=0.8, top=0.9)
        # add color bars; the colorbar follows the most recent image (bottom row)
        # *rect* = [left, bottom, width, height]
        cax = plt.axes([0.85, 0.125, 0.015, 0.75])
        plt.colorbar(cax=cax)

        # render and let the GUI process events briefly
        plt.draw()
        plt.pause(0.025)

    def save(self, image_file):
        """
        Save the figure.

        :param image_file: destination passed to ``Figure.savefig`` (75 dpi).
        :return: None.
        """
        self.fig.savefig(image_file, dpi=75)

    def close(self):
        """Close this object's figure (not whichever figure happens to be current)."""
        plt.close(self.fig)


class PlotDynamicalMatrix4Repeat:
    """
    Interactive (``plt.ion()``) figure for the repeat-copy task that redraws
    three stacked matrices each time :meth:`update` is called.
    """

    def __init__(self, matrix_list, name_list, repeat_times):
        """
        Create the figure, switch pyplot to interactive mode and draw the
        first frame.

        :param matrix_list: a group of matrices; items 0, 1 and 2 are drawn as
            the top, middle and bottom rows.
        :param name_list: row labels, forwarded to :meth:`update`.
        :param repeat_times: repeat count, forwarded to :meth:`update`.
        :return: None.
        """
        self.matrix_list = matrix_list
        # set figure size
        self.fig = plt.figure(figsize=(11, 3))

        plt.ion()
        self.update(matrix_list, name_list, repeat_times)

    def update(self, matrix_list, name_list, repeat_times):
        """
        Redraw the figure with new matrices and briefly run the GUI event loop.

        :param matrix_list: a group of matrices; items 0, 1 and 2 are drawn as
            the top, middle and bottom rows (a shorter list raises IndexError).
        :param name_list: row labels; currently not drawn.
        :param repeat_times: repeat count; currently not drawn.
        :return: None.
        """
        # Draw on this object's figure (not whichever figure is current) and
        # start from a blank canvas, so repeated updates don't pile new
        # images, axes and colorbars on top of the previous ones.
        plt.figure(self.fig.number)
        plt.clf()

        for row in range(3):
            axes = plt.subplot2grid((3, 1), (row, 0), colspan=1)
            if row == 0:
                axes.set_aspect('equal')
            plt.imshow(matrix_list[row], interpolation='none')
            axes.set_xticks([])
            axes.set_yticks([])

        # set tick labels invisible
        make_tick_labels_invisible(plt.gcf())
        # adjust spaces
        plt.subplots_adjust(hspace=0.05, wspace=0.05, bottom=0.1, right=0.8, top=0.9)
        # add color bars; the colorbar follows the most recent image (bottom row)
        # *rect* = [left, bottom, width, height]
        cax = plt.axes([0.85, 0.125, 0.015, 0.75])
        plt.colorbar(cax=cax)

        # render and let the GUI process events briefly
        plt.draw()
        plt.pause(0.025)

    def save(self, image_file):
        """
        Save the figure.

        :param image_file: destination passed to ``Figure.savefig`` (75 dpi).
        :return: None.
        """
        self.fig.savefig(image_file, dpi=75)

    def close(self):
        """Close this object's figure (not whichever figure happens to be current)."""
        plt.close(self.fig)


class PlotDynamicalMatrix4NGram:
    """
    Interactive (``plt.ion()``) figure that shows the first three rows of the
    input, output and prediction matrices as a single 9-row image.
    """

    def __init__(self, matrix_input, matrix_output, matrix_predict):
        """
        Create the figure, switch pyplot to interactive mode and draw the
        first frame.

        :param matrix_input: input rows (the first three are shown).
        :param matrix_output: output rows (the first three are shown).
        :param matrix_predict: predicted rows (the first three are shown).
        :return: None.
        """
        # set figure size
        self.fig = plt.figure(figsize=(20.5, 1.5))

        plt.ion()
        self.update(matrix_input, matrix_output, matrix_predict)

    @staticmethod
    def _stack_rows(matrix_input, matrix_output, matrix_predict, rows=3):
        """
        Stack the first ``rows`` rows of each argument into one uint8 matrix.

        Rows are ordered input, output, prediction. Each argument must have at
        least ``rows`` rows (otherwise IndexError), and each row should have the
        same length as ``matrix_input[0]``.

        NOTE(review): values are cast to ``np.uint8``, so non-integer entries
        are truncated (e.g. 0.9 -> 0). That is fine for binary data; confirm
        that predictions are already thresholded before they reach this method.

        :return: array of shape ``(3 * rows, len(matrix_input[0]))``.
        """
        groups = (matrix_input, matrix_output, matrix_predict)
        matrix = np.zeros((len(groups) * rows, len(matrix_input[0])), dtype=np.uint8)
        for group_index, group in enumerate(groups):
            for row in range(rows):
                matrix[group_index * rows + row] = group[row]
        return matrix

    def update(self, matrix_input, matrix_output, matrix_predict):
        """
        Redraw the stacked image and briefly run the GUI event loop.

        :param matrix_input: input rows (the first three are shown).
        :param matrix_output: output rows (the first three are shown).
        :param matrix_predict: predicted rows (the first three are shown).
        :return: None.
        """
        matrix = self._stack_rows(matrix_input, matrix_output, matrix_predict)

        # Draw on this object's figure and clear it first, so repeated updates
        # don't stack new images on the same axes.
        plt.figure(self.fig.number)
        plt.clf()

        axes_w = plt.gca()
        plt.imshow(matrix, interpolation='none')
        # raw string: "\l" is LaTeX, not a Python escape sequence
        plt.xlabel(r"$Time \longrightarrow$")
        axes_w.set_yticks([])

        # render and let the GUI process events briefly
        plt.draw()
        plt.pause(0.025)

    def save(self, image_file):
        """
        Save the figure.

        :param image_file: destination passed to ``Figure.savefig`` (75 dpi).
        :return: None.
        """
        self.fig.savefig(image_file, dpi=75)

    def close(self):
        """Close this object's figure (not whichever figure happens to be current)."""
        plt.close(self.fig)



class PlotDynamicalMatrix4PrioritySort:
    """
    Interactive (``plt.ion()``) figure for the priority-sort task that redraws
    the input, output and prediction matrices stacked vertically each time
    :meth:`update` is called.
    """

    def __init__(self, matrix_input, matrix_output, matrix_predict):
        """
        Create the figure, switch pyplot to interactive mode and draw the
        first frame.

        :param matrix_input: matrix drawn on the top row.
        :param matrix_output: matrix drawn on the middle row.
        :param matrix_predict: matrix drawn on the bottom row.
        :return: None.
        """
        # set figure size
        self.fig = plt.figure(figsize=(6, 5))

        plt.ion()
        self.update(matrix_input, matrix_output, matrix_predict)

    def update(self, matrix_input, matrix_output, matrix_predict):
        """
        Redraw the figure with new matrices and briefly run the GUI event loop.

        :param matrix_input: matrix drawn on the top row.
        :param matrix_output: matrix drawn on the middle row.
        :param matrix_predict: matrix drawn on the bottom row.
        :return: None.
        """
        # Draw on this object's figure (not whichever figure is current) and
        # start from a blank canvas, so repeated updates don't pile new
        # images, axes and colorbars on top of the previous ones.
        plt.figure(self.fig.number)
        plt.clf()

        for row, matrix in enumerate((matrix_input, matrix_output, matrix_predict)):
            axes = plt.subplot2grid((3, 1), (row, 0), colspan=1)
            if row == 0:
                axes.set_aspect('equal')
            plt.imshow(matrix, interpolation='none')
            axes.set_xticks([])
            axes.set_yticks([])

        # set tick labels invisible
        make_tick_labels_invisible(plt.gcf())
        # adjust spaces
        plt.subplots_adjust(hspace=0.05, wspace=0.05, bottom=0.1, right=0.8, top=0.9)
        # add color bars; the colorbar follows the most recent image (bottom row)
        # *rect* = [left, bottom, width, height]
        cax = plt.axes([0.85, 0.125, 0.015, 0.75])
        plt.colorbar(cax=cax)

        # render and let the GUI process events briefly
        plt.draw()
        plt.pause(0.025)

    def save(self, image_file):
        """
        Save the figure.

        :param image_file: destination passed to ``Figure.savefig`` (75 dpi).
        :return: None.
        """
        self.fig.savefig(image_file, dpi=75)

    def close(self):
        """Close this object's figure (not whichever figure happens to be current)."""
        plt.close(self.fig)