#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Mar 30 15:13:27 2018

@author: shirhe-lyh
"""

"""Functions to export inference graph.

Modified from: TensorFlow models/research/object_detection/export.py
"""

import logging
import os
import tempfile
import tensorflow as tf

from tensorflow.core.protobuf import saver_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.client import session
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.training import saver as saver_lib

# Short alias for the TF-Slim package under tf.contrib; used further down in
# this module to get or create the global step.
slim = tf.contrib.slim


# TODO: Replace with freeze_graph.freeze_graph_with_def_protos when
# newer version of Tensorflow becomes more common.
def freeze_graph_with_def_protos(
    input_graph_def,
    input_saver_def,
    input_checkpoint,
    output_node_names,
    restore_op_name,
    filename_tensor_name,
    clear_devices,
    initializer_nodes,
    variable_names_blacklist=''):
    """Converts all variables in a graph and checkpoint into constants.

    Args:
        input_graph_def: GraphDef proto of the graph to freeze. When
            `clear_devices` is true its nodes are modified in place.
        input_saver_def: SaverDef proto used to build the restoring saver.
            If falsy, a saver is built from the checkpoint variables whose
            names match tensors in the graph.
        input_checkpoint: Checkpoint path (may be a prefix for Saver V2).
        output_node_names: Comma-separated string of output node names.
        restore_op_name: Unused; kept for interface compatibility.
        filename_tensor_name: Unused; kept for interface compatibility.
        clear_devices: If true, the device field of every node in
            `input_graph_def` is reset to an empty string.
        initializer_nodes: Nodes run with `sess.run` after restoring. Only
            used when `input_saver_def` is falsy.
        variable_names_blacklist: Comma-separated string of variable names
            forwarded to `convert_variables_to_constants` as a blacklist;
            an empty string means no blacklist.

    Returns:
        The GraphDef returned by `graph_util.convert_variables_to_constants`.

    Raises:
        ValueError: If the checkpoint does not exist or no output node names
            were supplied.
    """
    del restore_op_name, filename_tensor_name  # Unused by updated loading code.

    # 'input_checkpoint' may be a prefix if we're using Saver V2 format
    if not saver_lib.checkpoint_exists(input_checkpoint):
        raise ValueError(
            "Input checkpoint '{}' does not exist!".format(input_checkpoint))

    if not output_node_names:
        raise ValueError(
            'You must supply the name of a node to --output_node_names.')

    # Remove all the explicit device specifications for this node. This helps
    # to make the graph more portable.
    if clear_devices:
        for node in input_graph_def.node:
            node.device = ''

    with tf.Graph().as_default():
        tf.import_graph_def(input_graph_def, name='')
        config = tf.ConfigProto(graph_options=tf.GraphOptions())
        with session.Session(config=config) as sess:
            if input_saver_def:
                saver = saver_lib.Saver(saver_def=input_saver_def)
                saver.restore(sess, input_checkpoint)
            else:
                # No saver definition: restore every checkpoint variable that
                # has a same-named tensor in the imported graph.
                var_list = {}
                reader = pywrap_tensorflow.NewCheckpointReader(
                    input_checkpoint)
                var_to_shape_map = reader.get_variable_to_shape_map()
                for key in var_to_shape_map:
                    try:
                        tensor = sess.graph.get_tensor_by_name(key + ':0')
                    except KeyError:
                        # This tensor doesn't exist in the graph (for example
                        # it's 'global_step' or a similar housekeeping element)
                        # so skip it.
                        continue
                    var_list[key] = tensor
                saver = saver_lib.Saver(var_list=var_list)
                saver.restore(sess, input_checkpoint)
                if initializer_nodes:
                    sess.run(initializer_nodes)

            variable_names_blacklist = (variable_names_blacklist.split(',') if
                                        variable_names_blacklist else None)
            output_graph_def = graph_util.convert_variables_to_constants(
                sess,
                input_graph_def,
                output_node_names.split(','),
                variable_names_blacklist=variable_names_blacklist)
    return output_graph_def


def replace_variable_values_with_moving_averages(graph,
                                                 current_checkpoint_file,
                                                 new_checkpoint_file):
    """Writes a checkpoint whose variables hold their moving-average values.

    The variables of `graph` are restored from `current_checkpoint_file`
    using the name mapping returned by
    `ExponentialMovingAverage.variables_to_restore()`, and are then saved
    under their regular names to `new_checkpoint_file`.

    Args:
        graph: A tf.Graph object.
        current_checkpoint_file: A checkpoint containing both the original
            variables and their moving averages.
        new_checkpoint_file: File path to write the new checkpoint to.
    """
    with graph.as_default():
        # The averaging object is only used to obtain the restore mapping;
        # no averaging op is applied in this function.
        ema = tf.train.ExponentialMovingAverage(0.0)
        restore_mapping = ema.variables_to_restore()
        with tf.Session() as sess:
            tf.train.Saver(restore_mapping).restore(
                sess, current_checkpoint_file)
            tf.train.Saver().save(sess, new_checkpoint_file)


def _image_tensor_input_placeholder(input_shape=None):
    """Creates the uint8 'image_tensor' placeholder.

    Args:
        input_shape: Shape of the placeholder. Defaults to
            (None, None, None, 3) when None.

    Returns:
        A tuple (placeholder, image tensor). Both elements are the same
        placeholder tensor.
    """
    shape = (None, None, None, 3) if input_shape is None else input_shape
    image_placeholder = tf.placeholder(
        dtype=tf.uint8, shape=shape, name='image_tensor')
    # The placeholder itself serves as the image tensor.
    return image_placeholder, image_placeholder


def _encoded_image_string_tensor_input_placeholder():
    """Creates an input accepting a batch of encoded image strings.

    Each string (e.g. a PNG or JPEG encoded image) is decoded into a
    3-channel uint8 image with `tf.image.decode_image`.

    Returns:
        A tuple of the string placeholder and the decoded images tensor.
    """
    def _decode_single_image(encoded_image):
        decoded = tf.image.decode_image(encoded_image, channels=3)
        # decode_image leaves the static shape unset; pin rank and channels.
        decoded.set_shape((None, None, 3))
        return decoded

    encoded_images = tf.placeholder(
        dtype=tf.string,
        shape=[None],
        name='encoded_image_string_tensor')
    decoded_images = tf.map_fn(
        _decode_single_image,
        elems=encoded_images,
        dtype=tf.uint8,
        parallel_iterations=32,
        back_prop=False)
    return encoded_images, decoded_images


# Maps each supported input type name to the function that builds its
# placeholder. The 'tf_example' input type is not registered in this module.
input_placeholder_fn_map = {
    'image_tensor': _image_tensor_input_placeholder,
    'encoded_image_string_tensor':
        _encoded_image_string_tensor_input_placeholder,
}


def _add_output_tensor_nodes(postprocessed_tensors,
                             output_collection_name='inference_op'):
    """Adds output nodes.
    
    Adjust according to specified implementations.
    
    Adds the following nodes for output tensors:
        * classes: A float32 tensor of shape [batch_size] containing class
            predictions.
    
    Args:
        postprocessed_tensors: A dictionary containing the following fields:
            'classes': [batch_size].
        output_collection_name: Name of collection to add output tensors to.
        
    Returns:
        A tensor dict containing the added output tensor nodes.
    """
    outputs = {}
    for class_name, logits in postprocessed_tensors.items():
        outputs[class_name] = tf.identity(logits, name=class_name)
    for output_key in outputs:
        tf.add_to_collection(output_collection_name, outputs[output_key])
    return outputs


def write_frozen_graph(frozen_graph_path, frozen_graph_def):
    """Serializes a frozen graph to disk and logs its node count.

    Args:
        frozen_graph_path: Path to write the inference graph to.
        frozen_graph_def: tf.GraphDef holding the frozen graph.
    """
    with gfile.GFile(frozen_graph_path, 'wb') as graph_file:
        serialized_graph = frozen_graph_def.SerializeToString()
        graph_file.write(serialized_graph)
    num_ops = len(frozen_graph_def.node)
    logging.info('%d ops in the final graph.', num_ops)
    
    
def write_saved_model(saved_model_path,
                      frozen_graph_def,
                      inputs,
                      outputs):
    """Writes a SavedModel to disk.

    The frozen graph is imported into a fresh graph and exported with the
    SERVING tag. A single predict signature, registered under the default
    serving signature key, exposes `inputs` under the key 'inputs' and every
    entry of `outputs` under its own key.

    Args:
        saved_model_path: Path to write SavedModel.
        frozen_graph_def: tf.GraphDef holding frozen graph.
        inputs: The input image tensor.
        outputs: A tensor dictionary containing the outputs of a slim model.
    """
    with tf.Graph().as_default():
        with session.Session() as sess:
            tf.import_graph_def(frozen_graph_def, name='')

            builder = tf.saved_model.builder.SavedModelBuilder(
                saved_model_path)

            make_tensor_info = tf.saved_model.utils.build_tensor_info
            signature_inputs = {'inputs': make_tensor_info(inputs)}
            signature_outputs = {
                key: make_tensor_info(tensor)
                for key, tensor in outputs.items()
            }

            predict_signature = (
                tf.saved_model.signature_def_utils.build_signature_def(
                    inputs=signature_inputs,
                    outputs=signature_outputs,
                    method_name=signature_constants.PREDICT_METHOD_NAME))
            signature_def_map = {
                signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                    predict_signature,
            }

            builder.add_meta_graph_and_variables(
                sess,
                [tf.saved_model.tag_constants.SERVING],
                signature_def_map=signature_def_map)
            builder.save()


def write_graph_and_checkpoint(inference_graph_def,
                               model_path,
                               input_saver_def,
                               trained_checkpoint_prefix):
    """Writes the graph and the checkpoint to disk.

    Args:
        inference_graph_def: GraphDef of the inference graph. The device
            field of each of its nodes is cleared in place.
        model_path: Path prefix under which the checkpoint is saved.
        input_saver_def: SaverDef proto used to build the saver.
        trained_checkpoint_prefix: Checkpoint to restore the variables from.
    """
    # Drop explicit device placements before importing the graph.
    for graph_node in inference_graph_def.node:
        graph_node.device = ''

    export_graph = tf.Graph()
    with export_graph.as_default():
        tf.import_graph_def(inference_graph_def, name='')
        with session.Session() as sess:
            checkpoint_saver = saver_lib.Saver(
                saver_def=input_saver_def, save_relative_paths=True)
            checkpoint_saver.restore(sess, trained_checkpoint_prefix)
            checkpoint_saver.save(sess, model_path)


def _get_outputs_from_inputs(input_tensors, model,
                             output_collection_name):
    """Runs the model pipeline on the inputs and adds the output nodes.

    The inputs are cast to float and passed through the model's
    `preprocess`, `predict` and `postprocess` methods in that order.

    Args:
        input_tensors: Input tensor fed to the model after the float cast.
        model: Object providing `preprocess`, `predict` and `postprocess`.
        output_collection_name: Name of collection to add output tensors to.

    Returns:
        The tensor dict returned by `_add_output_tensor_nodes`.
    """
    float_inputs = tf.to_float(input_tensors)
    predictions = model.predict(model.preprocess(float_inputs))
    return _add_output_tensor_nodes(
        model.postprocess(predictions), output_collection_name)
    
    
def _build_model_graph(input_type, model, input_shape,
                       output_collection_name, graph_hook_fn):
    """Builds the inference graph in the current default graph.

    Args:
        input_type: Key of `input_placeholder_fn_map` selecting the input
            placeholder to create.
        model: Model object used to compute the outputs.
        input_shape: Optional fixed shape; only allowed for 'image_tensor'.
        output_collection_name: Name of collection to add output tensors to.
        graph_hook_fn: Optional callable invoked without arguments once the
            outputs and the global step have been added.

    Returns:
        A tuple (outputs, placeholder_tensor).

    Raises:
        ValueError: If `input_type` is unknown, or `input_shape` is given
            for an input type other than 'image_tensor'.
    """
    if input_type not in input_placeholder_fn_map:
        raise ValueError('Unknown input type: {}'.format(input_type))
    if input_shape is not None and input_type != 'image_tensor':
        raise ValueError("Can only specify input shape for 'image_tensor' "
                         'inputs.')

    placeholder_kwargs = (
        {} if input_shape is None else {'input_shape': input_shape})
    build_placeholder = input_placeholder_fn_map[input_type]
    placeholder_tensor, input_tensors = build_placeholder(
        **placeholder_kwargs)

    outputs = _get_outputs_from_inputs(
        input_tensors=input_tensors,
        model=model,
        output_collection_name=output_collection_name)

    # Add global step to the graph
    slim.get_or_create_global_step()

    if graph_hook_fn:
        graph_hook_fn()

    return outputs, placeholder_tensor


def export_inference_graph(input_type,
                           model,
                           trained_checkpoint_prefix,
                           output_directory,
                           input_shape=None,
                           use_moving_averages=None,
                           output_collection_name='inference_op',
                           additional_output_tensor_names=None,
                           graph_hook_fn=None):
    """Exports inference graph for the desired graph.

    Writes three artifacts into `output_directory`: a checkpoint with prefix
    'model.ckpt', the frozen graph 'frozen_inference_graph.pb', and a
    SavedModel in the 'saved_model' subdirectory.

    Args:
        input_type: Type of input for the graph. Must be a key of
            `input_placeholder_fn_map`, i.e. 'image_tensor' or
            'encoded_image_string_tensor' ('tf_example' is not registered
            in this file).
        model: A model defined by model.py.
        trained_checkpoint_prefix: Path to the trained checkpoint file.
        output_directory: Path to write outputs.
        input_shape: Sets a fixed shape for an 'image_tensor' input. If not
            specified, will default to [None, None, None, 3].
        use_moving_averages: A boolean indicating whether the
            tf.train.ExponentialMovingAverage should be used or not.
        output_collection_name: Name of collection to add output tensors to.
        additional_output_tensor_names: List of additional output tensors to
            include in the frozen graph.
        graph_hook_fn: Optional callable forwarded to `_build_model_graph`,
            which invokes it without arguments after the graph is built.
    """
    tf.gfile.MakeDirs(output_directory)
    frozen_graph_path = os.path.join(output_directory,
                                     'frozen_inference_graph.pb')
    saved_model_path = os.path.join(output_directory, 'saved_model')
    model_path = os.path.join(output_directory, 'model.ckpt')

    outputs, placeholder_tensor = _build_model_graph(
        input_type=input_type,
        model=model,
        input_shape=input_shape,
        output_collection_name=output_collection_name,
        graph_hook_fn=graph_hook_fn)

    saver_kwargs = {}
    if use_moving_averages:
        # This check is to be compatible with both version of SaverDef.
        # NOTE: the temporary checkpoint written here is not removed by this
        # function.
        if os.path.isfile(trained_checkpoint_prefix):
            saver_kwargs['write_version'] = saver_pb2.SaverDef.V1
            temp_checkpoint_prefix = tempfile.NamedTemporaryFile().name
        else:
            temp_checkpoint_prefix = tempfile.mkdtemp()
        replace_variable_values_with_moving_averages(
            tf.get_default_graph(), trained_checkpoint_prefix,
            temp_checkpoint_prefix)
        checkpoint_to_use = temp_checkpoint_prefix
    else:
        checkpoint_to_use = trained_checkpoint_prefix

    saver = tf.train.Saver(**saver_kwargs)
    input_saver_def = saver.as_saver_def()

    write_graph_and_checkpoint(
        inference_graph_def=tf.get_default_graph().as_graph_def(),
        model_path=model_path,
        input_saver_def=input_saver_def,
        trained_checkpoint_prefix=checkpoint_to_use)

    # dict.keys() is a view on Python 3 and does not support '+', so build
    # an explicit list before appending the extra output names.
    output_names = list(outputs.keys())
    if additional_output_tensor_names is not None:
        output_names.extend(additional_output_tensor_names)
    output_node_names = ','.join(output_names)

    frozen_graph_def = freeze_graph_with_def_protos(
        input_graph_def=tf.get_default_graph().as_graph_def(),
        input_saver_def=input_saver_def,
        input_checkpoint=checkpoint_to_use,
        output_node_names=output_node_names,
        restore_op_name='save/restore_all',
        filename_tensor_name='save/Const:0',
        clear_devices=True,
        initializer_nodes='')
    write_frozen_graph(frozen_graph_path, frozen_graph_def)
    write_saved_model(saved_model_path, frozen_graph_def,
                      placeholder_tensor, outputs)