########################################################################################################################
# mnist_cnn1.py
# Author: Zach Harris @jzharris
#
# Dependencies:
#   Python 3.6.0
#   TensorFlow 1.4.x
#   Keras 2.x
#
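# Python packages:
#   os (standard library)
#   argparse (standard library)
#   matplotlib
#   scipy (< 1.2: scipy.misc.imsave was removed in SciPy 1.2 and requires Pillow)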
#
# Datasets in use:
#   MNIST
#
# Notes:
#   Determine version of your TF: python3 -c 'import tensorflow as tf; print(tf.__version__)'

import argparse
import os
import os.path as path
import sys

import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Flatten, Dropout, Conv2D, MaxPooling2D
from keras import backend as K

import tensorflow as tf
from tensorflow.python.tools import freeze_graph
from tensorflow.python.tools import optimize_for_inference_lib

import matplotlib.pyplot as plt
from scipy.misc import imsave

########################################################################################################################
# Set ArgumentParser so you can set all variables from terminal

def _str2bool(value):
    """Convert a command-line string such as 'true'/'False'/'1'/'no' into a bool.

    argparse hands option values over as raw strings, so without this converter
    ``--export_images False`` would produce the (truthy) string ``'False'``.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


parser = argparse.ArgumentParser(description='Test Arguments')

# model params
parser.add_argument('--model_name', default='mnist_cnn1',
                    help='the name of the saved model')

# dataset params
parser.add_argument('--export_images', type=_str2bool, default=False,
                    help="instead of training a model, save MNIST to .png's")
parser.add_argument('--export_number', type=int, default=10,
                    help='number of MNIST images to export (if enabled)')
parser.add_argument('--plot_images', type=_str2bool, default=False,
                    help='instead of training a model, display MNIST images')

# training params
parser.add_argument('--epochs', type=int, default=5,
                    help='number of epochs to train model for')
parser.add_argument('--batch_size', type=int, default=128,
                    help='batch size to use for training')

args = parser.parse_args()

########################################################################################################################
# Global Vars

# Unpack the parsed command-line options into the module-level names the rest of the
# script reads (model naming, dataset export/plot switches, training hyper-parameters).
model_name, export_images, export_number = args.model_name, args.export_images, args.export_number
plot_images = args.plot_images
epochs, batch_size = args.epochs, args.batch_size

########################################################################################################################
# Load data: load the MNIST train/validation sets, and export/display MNIST images

def _export_samples(x_train, x_test, count):
    """Write the first ``count`` train and test digits to ``export/`` as PNG files."""
    os.makedirs('export', exist_ok=True)
    # never index past the end of either split
    count = max(0, min(count, len(x_train), len(x_test)))
    for i in range(count):
        imsave('export/mnist_train_{}.png'.format(i), x_train[i])
        imsave('export/mnist_test_{}.png'.format(i), x_test[i])


def _plot_samples(images):
    """Display up to four images in a 2x2 grid using a grayscale colormap."""
    gray = plt.get_cmap('gray')
    for position, image in enumerate(images[:4], start=1):
        plt.subplot(2, 2, position)
        plt.imshow(image, cmap=gray)
    plt.show()


def _preprocess(images, labels, num_classes=10):
    """Reshape images to (N, 28, 28, 1) float32 scaled by 1/255 and one-hot encode labels."""
    images = images.reshape(images.shape[0], 28, 28, 1).astype('float32') / 255
    labels = keras.utils.to_categorical(labels, num_classes)
    return images, labels


def load_data():
    """Load MNIST, then either export/plot sample digits or return the preprocessed data.

    Behaviour depends on the module-level flags parsed from the command line:

    * ``export_images``: save the first ``export_number`` train and test digits as
      ``export/mnist_train_<i>.png`` / ``export/mnist_test_<i>.png``, then exit the process.
    * ``plot_images``: show the first four training digits in grayscale, then exit the process.

    Returns:
        ``(x_train, y_train, x_test, y_test)``. The images are reshaped to
        ``(N, 28, 28, 1)``, cast to float32 and divided by 255. The labels are
        one-hot encoded over 10 classes via ``keras.utils.to_categorical``.
    """
    (x_train, y_train), (x_test, y_test) = mnist.load_data()

    if export_images:
        # int() guards against the value arriving as a string from the command line
        _export_samples(x_train, x_test, int(export_number))
        sys.exit(0)

    if plot_images:
        _plot_samples(x_train)
        sys.exit(0)

    # if not exporting or plotting, return processed dataset
    x_train, y_train = _preprocess(x_train, y_train)
    x_test, y_test = _preprocess(x_test, y_test)

    return x_train, y_train, x_test, y_test

########################################################################################################################
# Build the basic CNN model:

def build_model():
    """Assemble the (uncompiled) convolutional MNIST classifier.

    Three stages each apply a 3x3, stride-1, 'same'-padded ReLU convolution and then
    a 2x2, stride-2, 'same'-padded max pool. The spatial size goes 28 -> 14 -> 7 -> 4
    and the channels go 64 -> 128 -> 256. The 4*4*256 feature map is flattened and
    fed through Dense(1024, relu), Dropout(0.5) and a 10-way softmax.

    Keep the layer creation order stable: Keras derives default layer names
    (e.g. ``conv2d_1``, ``dense_2``) from the order in which layers are created.
    """
    cnn = Sequential()

    for stage, width in enumerate((64, 128, 256)):
        conv_kwargs = dict(filters=width, kernel_size=3, strides=1,
                           padding='same', activation='relu')
        if stage == 0:
            # only the first layer declares the input shape
            conv_kwargs['input_shape'] = [28, 28, 1]
        cnn.add(Conv2D(**conv_kwargs))
        cnn.add(MaxPooling2D(pool_size=2, strides=2, padding='same'))

    # classifier head: 4*4*256 -> 1024 -> 10
    cnn.add(Flatten())
    cnn.add(Dense(1024, activation='relu'))
    cnn.add(Dropout(0.5))
    cnn.add(Dense(10, activation='softmax'))

    return cnn

########################################################################################################################
# Train the model using Adadelta optimizer from Keras

def train(model, x_train, y_train, x_test, y_test):
    """Compile ``model`` with categorical cross-entropy and Adadelta, then fit it.

    The test split is passed as ``validation_data``, so the per-epoch ``val_*``
    metrics Keras reports are measured on the MNIST test set.

    The module-level ``epochs`` and ``batch_size`` are coerced to ``int``. This keeps
    values that reached this point as strings from the command line working.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    model.compile(loss=keras.losses.categorical_crossentropy,
                  optimizer=keras.optimizers.Adadelta(),
                  metrics=['accuracy'])

    return model.fit(x_train, y_train,
                     batch_size=int(batch_size), epochs=int(epochs),
                     verbose=1, validation_data=(x_test, y_test))

########################################################################################################################
# Export the frozen graph for later use in Unity

def export_model(saver, model, input_node_names, output_node_name):
    """Freeze the current Keras session graph and write inference-ready files to ``out/``.

    Files written, where ``<name>`` is the module-level ``model_name``:

    * ``out/<name>_graph.pbtxt``: text GraphDef of the session graph.
    * ``out/<name>.chkp*``: checkpoint of the variable values, written by ``saver``.
    * ``out/frozen_<name>.bytes``: graph with variables folded in by ``freeze_graph``.
    * ``out/opt_<name>.bytes``: frozen graph after ``optimize_for_inference``.

    Args:
        saver: a ``tf.train.Saver``. Its restore op / filename tensor are looked up as
            "save/restore_all" / "save/Const:0". This presumably relies on the
            default Saver naming; confirm if the Saver is ever customised.
        model: unused here; kept so existing callers keep working.
        input_node_names: list of graph input node names for ``optimize_for_inference``.
        output_node_name: name of the graph output node to keep when freezing.
    """
    out_dir = 'out'
    os.makedirs(out_dir, exist_ok=True)

    sess = K.get_session()
    graph_name = model_name + '_graph.pbtxt'
    graph_path = path.join(out_dir, graph_name)
    checkpoint_path = path.join(out_dir, model_name + '.chkp')
    frozen_path = path.join(out_dir, 'frozen_' + model_name + '.bytes')
    optimized_path = path.join(out_dir, 'opt_' + model_name + '.bytes')

    # 1. dump the graph structure (text format) and the current variable values
    tf.train.write_graph(sess.graph_def, out_dir, graph_name)
    saver.save(sess, checkpoint_path)

    # 2. fold the checkpointed variables into constants (input_binary=False: graph is .pbtxt)
    freeze_graph.freeze_graph(graph_path, None, False,
                              checkpoint_path, output_node_name,
                              "save/restore_all", "save/Const:0",
                              frozen_path, True, "")

    # 3. re-read the frozen graph and run it through optimize_for_inference
    input_graph_def = tf.GraphDef()
    with tf.gfile.GFile(frozen_path, "rb") as f:
        input_graph_def.ParseFromString(f.read())

    output_graph_def = optimize_for_inference_lib.optimize_for_inference(
            input_graph_def, input_node_names, [output_node_name],
            tf.float32.as_datatype_enum)

    # GFile replaces the deprecated FastGFile
    with tf.gfile.GFile(optimized_path, "wb") as f:
        f.write(output_graph_def.SerializeToString())

    print("graph saved!")

########################################################################################################################
# Main program

def main():
    """Run the pipeline: load MNIST, build and train the CNN, then export it for Unity."""

    # 1. load dataset (load_data exits the process early when exporting or plotting images)
    x_train, y_train, x_test, y_test = load_data()

    # 2. build model
    model = build_model()

    # 3. train model
    train(model, x_train, y_train, x_test, y_test)

    # 4. export model to file for Unity.
    #    Read the node names off the model instead of hard-coding Keras' auto-generated
    #    names (previously "conv2d_1_input" / "dense_2/Softmax"). Their numeric suffixes
    #    depend on how many layers were created earlier in the same session.
    input_node_names = [tensor.op.name for tensor in model.inputs]
    output_node_name = model.outputs[0].op.name
    export_model(tf.train.Saver(), model, input_node_names, output_node_name)


# Only run the pipeline when this file is executed directly.
if __name__ == '__main__':
    main()