'''Trains a simple convnet on the MNIST dataset.
Gets to 99.25% test accuracy after 12 epochs
(there is still a lot of margin for parameter tuning).
16 seconds per epoch on a GRID K520 GPU.
'''

#converted for ue4 use from
#https://github.com/fchollet/keras/blob/master/examples/mnist_cnn.py

import json
from pathlib import Path

from tensorflow.python import keras
from tensorflow.python.keras.datasets import mnist
from tensorflow.python.keras.models import Sequential, load_model
from tensorflow.python.keras.layers import Dense, Dropout, Flatten
from tensorflow.python.keras.layers import Conv2D, MaxPooling2D
from tensorflow.python.keras.models import model_from_json
from tensorflow.python.keras import backend as K
import numpy as np
import operator
import sys
import random

import tensorflow as tf
import unreal_engine as ue
from TFPluginAPI import TFPluginAPI

class MnistKeras(TFPluginAPI):

	#keras stop callback
	class StopCallback(keras.callbacks.Callback):
		def __init__(self, outer):
			self.outer = outer

		def on_train_begin(self, logs={}):
			self.losses = []

		def on_batch_end(self, batch, logs={}):
			if(self.outer.shouldStop):
				#notify on first call
				if not (self.model.stop_training):
					ue.log('Early stop called!')
				self.model.stop_training = True

			else:
				if(batch % 5 == 0):
					#json convertible types are float64 not float32
					logs['acc'] = np.float64(logs['acc'])
					logs['loss'] = np.float64(logs['loss'])
					self.outer.callEvent('TrainingUpdateEvent', logs, True)

				#callback an example image from batch to see the actual data we're training on
				if((batch*self.outer.batch_size) % 10000 == 0):
					index = random.randint(0,self.outer.batch_size)*batch
					self.outer.jsonPixels['pixels'] = self.outer.x_train[index].ravel().tolist()
					self.outer.callEvent('PixelEvent', self.outer.jsonPixels, True)

	#expected api: setup your model for training
	def onSetup(self):
		#setup or load your model and pass it into stored
		
		#Usually store session, graph, and model if using keras
		#self.sess = tf.InteractiveSession()
		#self.graph = tf.get_default_graph()

		self.stopcallback = self.StopCallback(self)

	#expected api: storedModel and session, json inputs
	def onJsonInput(self, jsonInput):
		#build the result object
		result = {'prediction':-1}

		#prepare the input
		x_raw = [jsonInput['pixels']]
		x_raw = np.reshape(x_raw, (1, 28, 28))

		ue.log('image shape: ' + str(x_raw.shape))
		#ue.log(stored)

		#convert pixels to N_samples, height, width, N_channels input tensor
		x = np.reshape(x_raw, (len(x_raw), 28, 28, 1))

		ue.log('input shape: ' + str(x.shape))

		#run run the input through our network
		if self.model is None:
			ue.log("Warning! No 'model' found. Did training complete?")
			return result

		#restore our saved session and model
		K.set_session(self.session)

		with self.session.as_default():
			output = self.model.predict(x)

			ue.log(output)

			#convert output array to prediction
			index, value = max(enumerate(output[0]), key=operator.itemgetter(1))

			result['prediction'] = index
			result['pixels'] = jsonInput['pixels'] #unnecessary but useful for round tripping

		return result

	#expected api: no params forwarded for training? TBC
	def onBeginTraining(self):
		ue.log("starting mnist keras cnn training")

		model_file_name = "mnistKerasCNN"
		model_directory = ue.get_content_dir() + "/Scripts/"
		model_sess_path =  model_directory + model_file_name + ".tfsess"
		model_json_path = model_directory + model_file_name + ".json"

		my_file = Path(model_json_path)

		#reset the session each time we get training calls
		K.clear_session()

		#let's train
		batch_size = 128
		num_classes = 10
		epochs = 5 					 # lower default for simple testing
		self.batch_size = batch_size # so that it can be obtained inside keras callbacks

		# input image dimensions
		img_rows, img_cols = 28, 28

		# the data, shuffled and split between train and test sets
		(x_train, y_train), (x_test, y_test) = mnist.load_data()

		if K.image_data_format() == 'channels_first':
			x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
			x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
			input_shape = (1, img_rows, img_cols)
		else:
			x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
			x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
			input_shape = (img_rows, img_cols, 1)

		x_train = x_train.astype('float32')
		x_test = x_test.astype('float32')
		x_train /= 255
		x_test /= 255
		ue.log('x_train shape:' + str(x_train.shape))
		ue.log(str(x_train.shape[0]) + 'train samples')
		ue.log(str(x_test.shape[0]) + 'test samples')

		#pre-fill our callEvent data to optimize callbacks
		jsonPixels = {}
		size = {'x':28, 'y':28}
		jsonPixels['size'] = size
		self.jsonPixels = jsonPixels
		self.x_train = x_train

		# convert class vectors to binary class matrices
		y_train = keras.utils.to_categorical(y_train, num_classes)
		y_test = keras.utils.to_categorical(y_test, num_classes)

		model = Sequential()
		model.add(Conv2D(64, kernel_size=(3, 3),
						  activation='relu',
						  input_shape=input_shape))
		
		# model.add(Dropout(0.2))
		# model.add(Flatten())
		# model.add(Dense(512, activation='relu'))
		# model.add(Dropout(0.2))
		# model.add(Dense(num_classes, activation='softmax'))

		#model.add(Conv2D(64, (3, 3), activation='relu'))
		model.add(MaxPooling2D(pool_size=(2, 2)))
		model.add(Dropout(0.25))
		model.add(Flatten())
		model.add(Dense(128, activation='relu'))
		model.add(Dropout(0.5))
		model.add(Dense(num_classes, activation='softmax'))

		model.compile(loss=keras.losses.categorical_crossentropy,
					  optimizer=keras.optimizers.Adadelta(),
					  metrics=['accuracy'])

		model.fit(x_train, y_train,
				  batch_size=batch_size,
				  epochs=epochs,
				  verbose=1,
				  validation_data=(x_test, y_test),
				  callbacks=[self.stopcallback])
		score = model.evaluate(x_test, y_test, verbose=0)
		ue.log("mnist keras cnn training complete.")
		ue.log('Test loss:' + str(score[0]))
		ue.log('Test accuracy:' + str(score[1]))

		self.session = K.get_session()
		self.model = model

		stored = {'model':model, 'session': self.session}

		#run a test evaluation
		ue.log(x_test.shape)
		result_test = model.predict(np.reshape(x_test[500],(1,28,28,1)))
		ue.log(result_test)

		#flush the architecture model data to disk
		#with open(model_json_path, "w") as json_file:
		#	json_file.write(model.to_json())

		#flush the whole model and weights to disk
		#saver = tf.train.Saver()
		#save_path = saver.save(K.get_session(), model_sess_path)
		#model.save(model_path)

		
		return stored

#required function to get our api
def getApi():
	#return CLASSNAME.getInstance()
	return MnistKeras.getInstance()