#
#     Copywrite 2017 Alan Steremberg and Arthur Conner
#

import argparse
from keras import backend as K
from keras.models import load_model
#from tensorflow_serving.session_bundle import exporter
from keras.models import model_from_config
from keras.models import Sequential,Model
import tensorflow as tf
import os


def convert(prevmodel,export_path,freeze_graph_binary):

   # open up a Tensorflow session
   sess = tf.Session()
   # tell Keras to use the session
   K.set_session(sess)

   # From this document: https://blog.keras.io/keras-as-a-simplified-interface-to-tensorflow-tutorial.html
   
   # let's convert the model for inference
   K.set_learning_phase(0)  # all new operations will be in test mode from now on
   # serialize the model and get its weights, for quick re-building
   previous_model = load_model(prevmodel)
   previous_model.summary()

   config = previous_model.get_config()
   weights = previous_model.get_weights()

   # re-build a model where the learning phase is now hard-coded to 0
   try:
     model= Sequential.from_config(config) 
   except:
     model= Model.from_config(config) 
   #model= model_from_config(config)
   model.set_weights(weights)

   print("Input name:")
   print(model.input.name)
   print("Output name:")
   print(model.output.name)
   output_name=model.output.name.split(':')[0]

   #  not sure what this is for
   export_version = 1 # version number (integer)

   graph_file=export_path+"_graph.pb"
   ckpt_file=export_path+".ckpt"
   # create a saver 
   saver = tf.train.Saver(sharded=True)
   tf.train.write_graph(sess.graph_def, '', graph_file)
   save_path = saver.save(sess, ckpt_file)
#~/tensorflow/bazel-bin/tensorflow/python/tools/freeze_graph --input_graph=./graph.pb  --input_checkpoint=./model.ckpt --output_node_names=add_72 --output_graph=frozen.pb
   command = freeze_graph_binary +" --input_graph=./"+graph_file+" --input_checkpoint=./"+ckpt_file+" --output_node_names="+output_name+" --output_graph=./"+export_path+".pb"
   print(command)
   os.system(command)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Keras Tensorflow Converter')
    parser.add_argument(
        'model',
        type=str,
        help='Path to the keras model'
    )
    parser.add_argument(
        'frozen',
        type=str,
        help='Path to the frozen output'
    )
    parser.add_argument(
        'freezegraph',
        type=str,
        help='Path to the freeze_graph binary'
    )
    args = parser.parse_args()

    convert(args.model,args.frozen,args.freezegraph)