import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator
import matplotlib.patches as patches
import mnist, math

# Canvas dimensions in pixels (one axes unit per pixel).
image_height = image_width = 28
# 2-D pixel buffer the brush paints into; row 0 is the top row of the canvas.
image = np.zeros((image_height, image_width))
# Flattened form of `image` that is fed to the network.
image_1d = np.zeros((1, image_height * image_width))
# Intensity written into every painted pixel.
gray = 0.9
# True while the left mouse button is held down (i.e. a stroke is in progress).
isMouseDown = False
# Session shared by the checkpoint restore and every later inference call.
sess = tf.InteractiveSession()

def UpdateImage(x, y):
    """Paint a 3x3 brush stamp at the data coordinates (x, y).

    Draws a 3x3 rectangle patch on the global ``axes`` and writes ``gray``
    into the matching 3x3 window of the global ``image`` buffer.

    Args:
        x: Horizontal data coordinate of the pointer, or None.
        y: Vertical data coordinate of the pointer, or None.

    The call is a no-op when either coordinate is None or when the stamp
    would not fit entirely inside the image.
    """
    # Matplotlib reports None for xdata/ydata when the pointer is outside the
    # axes; check BOTH coordinates so round() never receives None.
    if x is None or y is None:
        return
    x = int(round(x))
    y = int(round(y))
    # The y axis runs from -image_height (bottom) to 0 (top), so convert the
    # axes coordinate into a row index counted from the top of the image.
    image_y = -(y + 1)
    # The stamp covers rows image_y-2..image_y and columns x..x+2. Reject
    # anything that would fall outside the buffer; a negative x would
    # otherwise wrap around through negative indexing.
    if (x < 0 or (image_y - 2) < 0 or image_y >= image_height
            or (x + 2) >= image_width):
        return

    axes.add_patch(patches.Rectangle((x, y), 3, 3))
    plt.draw()

    # In-place write so that existing views of `image` see the new pixels.
    image[image_y - 2:image_y + 1, x:x + 3] = gray
    
def update_figure(result):
    """Show the prediction as the x-axis label of the drawing axes.

    Args:
        result: Value to display, or -1 to clear the label.
    """
    label = "" if result == -1 else "Predict: {0}".format(result)
    axes.set_xlabel(label)
    plt.draw()

def OnClick(event):
    """Button-press handler: a left click starts a stroke at the pointer."""
    global isMouseDown
    # Only the left button (1) draws; every other button is ignored here.
    if event.button != 1:
        return
    isMouseDown = True
    UpdateImage(event.xdata, event.ydata)

def OnRelease(event):
    """Button-release handler.

    * Right button (3): replace the pixel buffer with a blank one, reset the
      drawing axes and clear the prediction label.
    * Left button (1): end the current stroke and run recognition.

    In every case the flattened ``image_1d`` is refreshed from ``image``.
    """
    global image, image_1d, isMouseDown
    if event.button == 3:  # right
        image = np.zeros(shape=(image_height, image_width))
        reset_axis(axes)
        update_figure(-1)
    # Refresh the flattened buffer BEFORE recognition: recognize() reads
    # image_1d, so updating it afterwards would classify stale data (on the
    # very first stroke, the initial all-zero buffer).
    image_1d = image.ravel()
    if event.button == 1:  # left
        isMouseDown = False
        recognize()

def OnMotion(event):
    """Mouse-move handler: extend the stroke while the left button is held."""
    if not isMouseDown:
        return
    UpdateImage(event.xdata, event.ydata)
    # The drawing changed, so any previously shown prediction is cleared.
    update_figure(-1)

def recognize():
    """Classify the current drawing and visualise both conv-layer outputs.

    Reads the global ``image_1d`` buffer, runs it through the restored graph
    as a batch of one, shows the predicted class on the drawing axes and
    plots the two convolution outputs in figures 2 and 3.
    """
    # Feed the drawing as a single-row batch.
    pixels = np.reshape(image_1d, (1, -1))
    predict, conv1, conv2 = sess.run(
        [predict_op, conv1_op, conv2_op], feed_dict={x: pixels})
    # argmax over axis 1 returns a length-1 array for the single-image batch;
    # unwrap it so the label reads "Predict: 7" instead of "Predict: [7]" and
    # update_figure compares a plain int against its -1 sentinel.
    digit = int(np.argmax(predict, 1)[0])
    update_figure(digit)
    plot_conv_cout(conv1, 2, "conv layer 1")
    plot_conv_cout(conv2, 3, "conv layer 2")

def reset_axis(axes):
    """Clear ``axes`` and restore the drawing grid.

    After clearing, the x axis spans 0..image_width and the y axis spans
    -image_height..0, each with major ticks every 4 units, minor ticks every
    unit and grid lines on the minor ticks.

    Args:
        axes: The Matplotlib axes to reset.
    """
    # Clear the axes that was passed in, not whichever axes pyplot currently
    # considers "current" (that may belong to another figure).
    axes.cla()

    axes.set_xlim(0, image_width)
    axes.xaxis.set_major_locator(MultipleLocator(4))
    axes.xaxis.set_minor_locator(MultipleLocator(1))
    axes.xaxis.grid(True, which='minor')

    axes.set_ylim(-image_height, 0)
    axes.yaxis.set_major_locator(MultipleLocator(4))
    axes.yaxis.set_minor_locator(MultipleLocator(1))
    axes.yaxis.grid(True, which='minor')

def plot_conv_cout(values, index, title):
    """Show every channel of a convolution output as a grid of gray images.

    Args:
        values: 4-D array; only the first entry along axis 0 is shown and
            axis 3 indexes the filters (presumably batch x height x width x
            filters - confirm against the model).
        index: Matplotlib figure number to draw into. The same figure is
            reused on every call, so it should differ from the number of the
            drawing figure.
        title: Window title for the figure.
    """
    num_filters = values.shape[3]
    # Smallest square grid that holds all filters. int() keeps the subplot
    # geometry integral even where math.ceil returns a float (Python 2).
    num_grids = int(math.ceil(math.sqrt(num_filters)))

    # Reuse and clear figure `index` instead of opening a fresh figure per
    # call: this stops figures piling up after each recognition and
    # guarantees that the title lands on the figure that was just drawn.
    fig = plt.figure(index)
    fig.clf()
    # Adding subplots one by one also works for a 1x1 grid, where
    # plt.subplots would return a bare Axes without a `.flat` attribute.
    for i in range(num_grids * num_grids):
        ax = fig.add_subplot(num_grids, num_grids, i + 1)
        if i < num_filters:
            img = values[0, :, :, i]
            ax.imshow(img, interpolation='nearest', cmap='gray')
        # Cells beyond the last filter stay empty but still lose their ticks.
        ax.set_xticks([])
        ax.set_yticks([])

    # The manager can be missing on non-GUI backends.
    manager = getattr(fig.canvas, "manager", None)
    if manager is not None:
        manager.set_window_title(title)
    # Request a repaint so a reused, already visible window shows the new
    # activations.
    fig.canvas.draw_idle()
    plt.show()

# Import the saved graph definition and restore its variables into the shared
# session. Both paths are built by plain string concatenation onto mnist.LOGDIR.
# NOTE(review): this assumes LOGDIR ends with a path separator - confirm.
saver = tf.train.import_meta_graph(mnist.LOGDIR + "model.ckpt.meta")
saver.restore(sess, mnist.LOGDIR + "model.ckpt")
graph = tf.get_default_graph()  

# Handles used by recognize(): the tensor named "x:0" is what the drawing is
# fed into (presumably the model's input placeholder - confirm), and the three
# ops are read from graph collections under the keys below.
x = graph.get_tensor_by_name("x:0")
predict_op = graph.get_collection('predict_op')[0]
conv1_op = graph.get_collection('conv1_op')[0]
conv2_op = graph.get_collection('conv2_op')[0]

# Drawing window: hook up the mouse callbacks defined above.
fig = plt.figure('Input')
fig.canvas.mpl_connect('button_press_event', OnClick)
fig.canvas.mpl_connect('button_release_event', OnRelease)
fig.canvas.mpl_connect('motion_notify_event', OnMotion)

# Equal aspect ratio so the drawing grid cells are square.
plt.gca().set_aspect('equal', adjustable='box')
# `axes` is the module-level drawing axes used by the callbacks.
axes = plt.gca()
reset_axis(axes)
plt.show()