#----------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License. See License.txt in the project root for license information.
#----------------------------------------------------------------------------------------------

import os
from six import string_types as _string_types
import keras as _keras
from keras import backend as _K

from mmdnn.conversion.keras.keras2_graph import Keras2Graph
import mmdnn.conversion.common.IR.graph_pb2 as graph_pb2
from mmdnn.conversion.common.IR.graph_pb2 import NodeDef, GraphDef, DataType
from mmdnn.conversion.common.DataStructure.parser import Parser
from mmdnn.conversion.common.utils import *


class Keras2Parser(Parser):

    dtype_map = {
        "float16" : graph_pb2.DT_FLOAT16,
        "float32" : graph_pb2.DT_FLOAT32,
        "float64" : graph_pb2.DT_FLOAT64,
        "int16"   : graph_pb2.DT_INT16,
        "int32"   : graph_pb2.DT_INT32,
        "int64"   : graph_pb2.DT_INT64,
        "uint8"   : graph_pb2.DT_UINT8,
        "uint16"  : graph_pb2.DT_UINT16
    }

    activation_map = {
        "relu"          : "Relu",
        'softmax'       : "Softmax",
        'sigmoid'       : "Sigmoid",
        "tanh"          : "Tanh",
        "elu"           : "Elu",
        "relu6"         : "Relu6",
        'softplus'      : 'Softplus',
        'softsign'      : 'Softsign',
        'hard_sigmoid'  : 'HardSigmoid'
    }


    def _load_model(self, model_network_path, model_weight_path):
        """Load a keras model from disk

        Parameters
        ----------
        model_network_path: str
            Path where the model network path is (json file)

        model_weight_path: str
            Path where the model network weights are (hd5 file)

        Returns
        -------
        model: A keras model
        """
        from keras.models import model_from_json

        # Load the model network
        json_file = open(model_network_path, 'r')
        loaded_model_json = json_file.read()
        json_file.close()

        # Load the model weights

        try:
            from keras.applications.mobilenet import relu6
            from keras.applications.mobilenet import DepthwiseConv2D
            loaded_model = model_from_json(loaded_model_json, custom_objects={
                'relu6': _keras.applications.mobilenet.relu6,
                'DepthwiseConv2D': _keras.applications.mobilenet.DepthwiseConv2D})
        except:
            import keras.layers as layers
            loaded_model = model_from_json(loaded_model_json, custom_objects={
                'relu6': layers.ReLU(6, name='relu6'),
                'DepthwiseConv2D': layers.DepthwiseConv2D})


        if model_weight_path:
            if os.path.isfile(model_weight_path):
                loaded_model.load_weights(model_weight_path)
                self.weight_loaded = True
                print("Network file [{}] and [{}] is loaded successfully.".format(model_network_path, model_weight_path))

            else:
                print("Warning: Weights File [%s] is not found." % (model_weight_path))

        return loaded_model

    @property
    def src_graph(self):
        return self.keras_graph


    def __init__(self, model):
        super(Keras2Parser, self).__init__()

        # load model files into Keras graph
        if isinstance(model, _string_types):
            try:
                # Keras 2.1.6
                from keras.applications.mobilenet import relu6
                from keras.applications.mobilenet import DepthwiseConv2D
                model = _keras.models.load_model(
                    model,
                    custom_objects={
                        'relu6': _keras.applications.mobilenet.relu6,
                        'DepthwiseConv2D': _keras.applications.mobilenet.DepthwiseConv2D
                    }
                )
            except:
                # Keras. 2.2.2
                import keras.layers as layers
                model = _keras.models.load_model(
                    model,
                    custom_objects={
                        'relu6': layers.ReLU(6, name='relu6'),
                        'DepthwiseConv2D': layers.DepthwiseConv2D
                    }
                )
            self.weight_loaded = True

        elif isinstance(model, tuple):
            model = self._load_model(model[0], model[1])

        else:
            assert False

        # _keras.utils.plot_model(model, "model.png", show_shapes = True)

        # Build network graph
        self.data_format = _keras.backend.image_data_format()
        self.keras_graph = Keras2Graph(model)
        self.keras_graph.build()
        self.lambda_layer_count = 0


    def gen_IR(self):
        for layer in self.keras_graph.topological_sort:
            current_node = self.keras_graph.get_node(layer)
            node_type = current_node.type
  
            if hasattr(self, "rename_" + node_type):
                func = getattr(self, "rename_" + node_type)
                func(current_node)
            else:
                print("KerasParser has not supported operator [%s]." % (node_type))
                self.rename_UNKNOWN(current_node)

        _K.clear_session()


    @staticmethod
    def _set_output_shape(source_node, IR_node):
        shape = graph_pb2.TensorShape()
        for dim in source_node.layer.output_shape:
            new_dim = shape.dim.add()
            new_dim.size = dim if dim else -1

        IR_node.attr["_output_shapes"].list.shape.extend([shape])


    @staticmethod
    def _copy_and_reop(source_node, IR_node, new_op = None):
        IR_node.name = source_node.name
        IR_node.op = source_node.type if new_op == None else new_op

        if hasattr(source_node.layer, "dtype"):
            IR_node.attr["dtype"].type = Keras2Parser.dtype_map[source_node.layer.dtype]

        Keras2Parser._set_output_shape(source_node, IR_node)


    @staticmethod
    def _copy_shape(source_node, target_node):
        if hasattr(source_node, "output_shape"):
            for dim in source_node.output_shape:
                new_dim = target_node.attr["shape"].shape.dim.add()
                new_dim.size = -1 if dim == None else dim

        else:
            target_node.attr["shape"].shape.unknown_rank = True


    @staticmethod
    def _convert_dataformat(source_node, target_node):
        if source_node.keras_layer.data_format == 'channels_last':
            target_node.attr["data_format"].s = "NHWC"
        elif source_node.keras_layer.data_format == 'channels_first':
            target_node.attr["data_format"].s = "NCHW"
        else:
            print("Warning: [%s] don't have data format info." % (source_node.keras_layer.name))


    @staticmethod
    def _convert_padding(source_node, IR_node):
        # TODO: Fused conv and pool with padding is different from defused operators
        dims = len(source_node.layer.input_shape)
        if source_node.layer.padding == 'valid':
            assign_IRnode_values(IR_node, {'auto_pad' : "VALID", 'pads' : [0, 0] * dims})

        elif source_node.layer.padding == 'same':
            kernel_shape = source_node.layer.kernel_size if hasattr(source_node.layer, 'kernel_size') else source_node.layer.pool_size
            padding = compute_tf_same_padding(
                source_node.layer.input_shape,
                kernel_shape,
                list(source_node.layer.strides))
            assign_IRnode_values(IR_node, {'auto_pad' : "SAME_LOWER", 'pads' : padding})

        else:
            assert False


    def _defuse_activation(self, source_node):
        if source_node.layer.activation is None or source_node.layer.activation.__name__ == "linear":
            return

        IR_node = self.IR_graph.node.add()
        IR_node.name = source_node.real_name + "_activation"
        IR_node.op = Keras2Parser.activation_map[source_node.layer.activation.__name__]
        IR_node.input.append(source_node.real_name)
        Keras2Parser._set_output_shape(source_node, IR_node)

        # TODO: More activation functions
        # for ELU
        if hasattr(source_node.layer, 'alpha'):
            assign_attr_value(IR_node['alpha'], source_node.layer.alpha)

        source_node.real_name = IR_node.name


    def _convert_convolution(self, source_node, dim):
        IR_node = self.IR_graph.node.add()

        # input edge
        self.convert_inedge(source_node, IR_node)

        # name, op
        if source_node.type.startswith('Separable'):
            Keras2Parser._copy_and_reop(source_node, IR_node, "SeparableConv")
            if self.weight_loaded:
                self.set_weight(source_node.name, 'depthwise_filter', source_node.layer.get_weights()[0])
                self.set_weight(source_node.name, 'pointwise_filter', source_node.layer.get_weights()[1])

        else:
            if source_node.type.startswith('Conv'):
                if source_node.type.endswith('Transpose'):
                    Keras2Parser._copy_and_reop(source_node, IR_node, "ConvTranspose")
                else:
                    Keras2Parser._copy_and_reop(source_node, IR_node, "Conv")
            elif source_node.type.startswith('Deconv'):
                Keras2Parser._copy_and_reop(source_node, IR_node, "ConvTranspose")

            elif source_node.type.startswith('Depthwise'):
                Keras2Parser._copy_and_reop(source_node, IR_node, "DepthwiseConv")

            else:
                raise NotImplementedError("Convolution layer [{}] is not supported.".format(source_node.type))

            # weights
            if self.weight_loaded:
                self.set_weight(source_node.name, "weights", source_node.layer.get_weights()[0])
                if source_node.layer.use_bias:
                    self.set_weight(source_node.name, "bias", source_node.layer.get_weights()[1])

        if isinstance(source_node.layer.kernel_size, int):
            source_node.layer.kernel_size = (source_node.layer.kernel_size) * dim

        if isinstance(source_node.layer.strides, int):
            source_node.layer.strides = (source_node.layer.strides) * dim

        if isinstance(source_node.layer.dilation_rate, int):
            source_node.layer.dilation_rate = (source_node.layer.dilation_rate) * dim

        kwargs = dict()

        # pads
        Keras2Parser._convert_padding(source_node, IR_node)

        # filter
        # [kd, kh, kw, channel_size, filter number]
        in_channel = source_node.layer.input_shape[-1] if self.data_format == "channels_last" else source_node.layer.input_shape[1]
        out_channel = source_node.layer.filters or source_node.layer.depth_multiplier

        if source_node.type.startswith("Deconv"):
            kwargs['kernel_shape'] = list(source_node.layer.kernel_size) + [out_channel, in_channel]
        else:
            kwargs['kernel_shape'] = list(source_node.layer.kernel_size) + [in_channel, out_channel]

        # use_bias
        kwargs['use_bias'] = source_node.keras_layer.use_bias

        # strides
        # [1, sd, sh, sw, 1]
        kwargs['strides'] = [1] + list(source_node.layer.strides) + [1]

        # dilations
        # [1, dd, dh, dw, 1]
        kwargs['dilations'] = [1] + list(source_node.layer.dilation_rate) + [1]

        assign_IRnode_values(IR_node, kwargs)

        # activation
        self._defuse_activation(source_node)


    def _convert_pooling(self, source_node, dim, pooling_type, is_global):
        IR_node = self.IR_graph.node.add()

        # name, op
        Keras2Parser._copy_and_reop(source_node, IR_node, "Pool")

        # input edge
        self.convert_inedge(source_node, IR_node)

        kwargs = {}

        kwargs['pooling_type'] = pooling_type


        if is_global:
            kwargs['global_pooling'] = True
            kwargs['strides'] = [1] * (dim + 2)
        else:
            if isinstance(source_node.layer.pool_size, int):
                source_node.layer.pool_size = (source_node.layer.pool_size) * dim

            if isinstance(source_node.layer.strides, int):
                source_node.layer.strides = (source_node.layer.strides) * dim

            # padding
            self._convert_padding(source_node, IR_node)

            # strides
            # [1, sd, sh, sw, 1]
            kwargs['strides'] = [1] + list(source_node.layer.strides) + [1]

            # window_shape
            # [1, pd, ph, pw, 1]
            kwargs['kernel_shape'] = [1] + list(source_node.layer.pool_size) + [1]

        assign_IRnode_values(IR_node, kwargs)

        if is_global:
            flatten_node = self.IR_graph.node.add()
            flatten_node.name = source_node.name + '_flatten'
            flatten_node.op = 'Flatten'
            flatten_node.input.append(source_node.name)
            Keras2Parser._set_output_shape(source_node, flatten_node)
            source_node.real_name = flatten_node.name


    def _convert_merge(self, source_node, new_name = None):
        IR_node = self.IR_graph.node.add()

        # name, op
        Keras2Parser._copy_and_reop(source_node, IR_node, new_name)

        # input edge
        self.convert_inedge(source_node, IR_node)

        # For concat axis
        if hasattr(source_node.layer, 'axis'):
            axis = source_node.layer.axis
            if int(axis) == -1:
                axis = 3 if self.data_format == "channels_last" else 2
            IR_node.attr['axis'].i = axis

        return IR_node


    def _convert_padding_api(self, source_node, IR_node, mode):
         # name, op
        Keras2Parser._copy_and_reop(source_node, IR_node, "Pad")

        # input edge
        self.convert_inedge(source_node, IR_node)

        kwargs = dict()
        kwargs['mode'] = mode

        # padding
        kwargs['pads'] = [0, 0]
        for padding_pair in source_node.layer.padding:
            kwargs['pads'].extend(padding_pair)
        kwargs['pads'] += [0, 0]
        kwargs['pads'] = convert_tf_pad_to_onnx(kwargs['pads'])
        assign_IRnode_values(IR_node, kwargs)


    def rename_UNKNOWN(self, source_node):
        print (source_node.layer.get_config())

        # only for training
        IR_node = self.IR_graph.node.add()

        # name, op
        Keras2Parser._copy_and_reop(source_node, IR_node)

        # input edge
        self.convert_inedge(source_node, IR_node)


    def rename_Activation(self, keras_node):
        IR_node = self.IR_graph.node.add()
        # name, op
        try:
            Keras2Parser._copy_and_reop(keras_node, IR_node, self.activation_map[keras_node.keras_layer.activation.__name__])
        except:
            Keras2Parser._copy_and_reop(keras_node, IR_node, self.activation_map[keras_node.keras_layer.activation.name])

        # input edge
        self.convert_inedge(keras_node, IR_node)


    # Merge Layers
    def rename_Add(self, source_node):
        self._convert_merge(source_node)


    def rename_Conv1D(self, source_node):
        self._convert_convolution(source_node, 1)

    def rename_Conv1DTranspose(self, source_node):
        self._convert_convolution(source_node, 1)

    def rename_Conv2D(self, source_node):
        self._convert_convolution(source_node, 2)

    def rename_Conv2DTranspose(self, source_node):
        self._convert_convolution(source_node, 2)

    def rename_Conv3D(self, source_node):
        self._convert_convolution(source_node, 3)

    def rename_Conv3DTranspose(self, source_node):
        self._convert_convolution(source_node, 3)

    def rename_InputLayer(self, source_node):
        # only for training
        IR_node = self.IR_graph.node.add()

        # name, op
        Keras2Parser._copy_and_reop(source_node, IR_node, "DataInput")

        # input edge
        self.convert_inedge(source_node, IR_node)

        # shape
        Keras2Parser._copy_shape(source_node.keras_layer, IR_node)



    def rename_GlobalMaxPooling1D(self, source_node):
        self._convert_pooling(source_node, 1, "MAX", True)


    def rename_GlobalMaxPooling2D(self, source_node):
        self._convert_pooling(source_node, 2, "MAX", True)


    def rename_GlobalMaxPooling3D(self, source_node):
        self._convert_pooling(source_node, 3, "MAX", True)


    def rename_GlobalAveragePooling1D(self, source_node):
        self._convert_pooling(source_node, 1, "AVG", True)


    def rename_GlobalAveragePooling2D(self, source_node):
        self._convert_pooling(source_node, 2, "AVG", True)


    def rename_GlobalAveragePooling3D(self, source_node):
        self._convert_pooling(source_node, 3, "AVG", True)


    def rename_MaxPooling1D(self, source_node):
        self._convert_pooling(source_node, 1, "MAX", False)


    def rename_MaxPooling2D(self, source_node):
        self._convert_pooling(source_node, 2, "MAX", False)


    def rename_MaxPooling3D(self, source_node):
        self._convert_pooling(source_node, 3, "MAX", False)


    def rename_AveragePooling1D(self, source_node):
        self._convert_pooling(source_node, 1, "AVG", False)


    def rename_AveragePooling2D(self, source_node):
        self._convert_pooling(source_node, 2, "AVG", False)


    def rename_AveragePooling3D(self, source_node):
        self._convert_pooling(source_node, 3, "AVG", False)


    def rename_Dropout(self, source_node):
        # only for training
        IR_node = self.IR_graph.node.add()

        # name, op
        Keras2Parser._copy_and_reop(source_node, IR_node)

        # input edge
        self.convert_inedge(source_node, IR_node)

        IR_node.attr["keep_prob"].f = source_node.keras_layer.rate
        if source_node.keras_layer.seed != None:
            IR_node.attr["seed"].i = source_node.keras_layer.seed


    # Core Layers
    def rename_Dense(self, source_node):
        IR_node = self.IR_graph.node.add()

        # name, op
        Keras2Parser._copy_and_reop(source_node, IR_node, "FullyConnected")

        # input edge
        self.convert_inedge(source_node, IR_node)

        # units
        IR_node.attr["units"].i = source_node.keras_layer.units

        # use_bias
        IR_node.attr["use_bias"].b = source_node.keras_layer.use_bias

        # weights
        if self.weight_loaded == True:
            self.set_weight(source_node.name, 'weights', source_node.layer.get_weights()[0])
            if IR_node.attr["use_bias"].b == True:
                self.set_weight(source_node.name, 'bias', source_node.layer.get_weights()[1])

        # activation
        self._defuse_activation(source_node)


    def rename_Flatten(self, source_node):
        IR_node = self.IR_graph.node.add()

        # name, op
        Keras2Parser._copy_and_reop(source_node, IR_node)

        # input edge
        self.convert_inedge(source_node, IR_node)

    def rename_UpSampling2D(self, source_node):
        IR_node = self.IR_graph.node.add()

        # name, op
        Keras2Parser._copy_and_reop(source_node, IR_node)

        # input edge
        self.convert_inedge(source_node, IR_node)

        # size
        IR_node.attr["scales"].list.i.extend(source_node.keras_layer.size)


    def rename_Embedding(self, source_node):
        IR_node = self.IR_graph.node.add()

        # name, op
        Keras2Parser._copy_and_reop(source_node, IR_node)

        # input edge
        self.convert_inedge(source_node, IR_node)

        # input_dim
        IR_node.attr["input_dim"].i = source_node.keras_layer.input_dim

        # output_dim
        IR_node.attr["output_dim"].i = source_node.keras_layer.output_dim

        # mask_zero
        IR_node.attr["mask_zero"].b = source_node.keras_layer.mask_zero

        # weights
        if self.weight_loaded:
            self.set_weight(source_node.name, 'embedding_weights', source_node.layer.get_weights()[0])


    def rename_LSTM(self, keras_node):
        IR_node = self.IR_graph.node.add()

        # name, op
        Keras2Parser._copy_and_reop(keras_node, IR_node)

        # input edge
        self.convert_inedge(keras_node, IR_node)

        # units
        IR_node.attr["units"].i = keras_node.keras_layer.units

        # use_bias
        IR_node.attr["use_bias"].b = keras_node.keras_layer.use_bias

        # for Keras, drop_out and recurrent_dropout
        IR_node.attr["dropout"].f = keras_node.keras_layer.dropout
        IR_node.attr["recurrent_dropout"].f = keras_node.keras_layer.recurrent_dropout

        # activation
        self._defuse_activation(keras_node)


    def rename_GRU(self, source_node):
        IR_node = self.IR_graph.node.add()

        # name, op
        Keras2Parser._copy_and_reop(source_node, IR_node)

        # input edge
        self.convert_inedge(source_node, IR_node)

        # units
        IR_node.attr["units"].i = source_node.keras_layer.units

        # activation
        self._defuse_activation(source_node)

        # weights
        if self.weight_loaded:
            self.set_weight(source_node.name, 'gru_weights', source_node.layer.get_weights()[0])
            self.set_weight(source_node.name, 'gru_recurrent_weights', source_node.layer.get_weights()[1])
            if source_node.layer.use_bias:
                self.set_weight(source_node.name, "gru_bias", source_node.layer.get_weights()[2])


    def rename_Multiply(self, source_node):
        self._convert_merge(source_node, 'Mul')


    def rename_Average(self, source_node):
        # Kit TODO : need to search the tf
        self._convert_merge(source_node, 'Avg')


    def rename_Maximum(self, source_node):
        self._convert_merge(source_node)


    def rename_Concatenate(self, source_node):
        IR_node = self._convert_merge(source_node, 'Concat')


    def rename_Reshape(self, source_node):
        IR_node = self.IR_graph.node.add()

        # name, op
        Keras2Parser._copy_and_reop(source_node, IR_node, 'Reshape')

        # input edge
        self.convert_inedge(source_node, IR_node)

        # for target shape
        IR_node.attr["shape"].list.i.append(-1)
        IR_node.attr["shape"].list.i.extend(source_node.layer.target_shape)


    def rename_Lambda(self, source_node):
        node_type = source_node.layer.name
        if hasattr(self, "rename_" + node_type):
            print ("Try to convert Lambda function [{}]".format(source_node.layer.name))
            func = getattr(self, "rename_" + node_type)
            func(source_node)
        else:
            raise NotImplementedError("Lambda layer [{}] in keras is not supported yet.".format(node_type))


    def rename_BatchNormalization(self, keras_node):
        IR_node = self.IR_graph.node.add()

        # name, op
        Keras2Parser._copy_and_reop(keras_node, IR_node, 'BatchNorm')

        # input edge
        self.convert_inedge(keras_node, IR_node)

        # axis
        IR_node.attr['axis'].i = keras_node.keras_layer.axis

        IR_node.attr['scale'].b = keras_node.keras_layer.scale

        IR_node.attr['bias'].b = keras_node.keras_layer.center

        IR_node.attr['epsilon'].f = keras_node.layer.epsilon

        if self.weight_loaded:
            # Parameter arrangement in Keras: gamma, beta, mean, variance
            idx = 0

            # scale
            if IR_node.attr['scale'].b:
                self.set_weight(keras_node.name, "scale", keras_node.layer.get_weights()[idx])
                idx += 1

            # beta
            if IR_node.attr['bias'].b:
                self.set_weight(keras_node.name, "bias", keras_node.layer.get_weights()[idx])
                idx += 1

            # mean
            self.set_weight(keras_node.name, "mean", keras_node.layer.get_weights()[idx])

            # var
            self.set_weight(keras_node.name, "var", keras_node.layer.get_weights()[idx + 1])


    def rename_ZeroPadding2D(self, keras_node):
        IR_node = self.IR_graph.node.add()
        self._convert_padding_api(keras_node, IR_node, "constant")


    def rename_SeparableConv2D(self, source_node):
        self._convert_convolution(source_node, 2)


    def rename_DepthwiseConv2D(self, source_node):
        self._convert_convolution(source_node, 2)


    def custom_relu6(x):
        return _keras.relu(x, max_value=6)


    def _convert_crop(self, source_node):
        IR_node = self.IR_graph.node.add()

        Keras2Parser._copy_and_reop(source_node, IR_node, "Crop")

        self.convert_inedge(source_node, IR_node)

        border = []
        for i in source_node.layer.cropping:
            for j in i:
                border.append(j)

        assign_IRnode_values(IR_node, {'border' : border})



    def rename_Cropping1D(self, source_node):
        self._convert_crop(source_node)


    def rename_Cropping2D(self, source_node):
        self._convert_crop(source_node)


    def rename_Cropping3D(self, source_node):
        self._convert_crop(source_node)


    def rename_LeakyReLU(self, source_node):
        IR_node = self.IR_graph.node.add()
        Keras2Parser._copy_and_reop(source_node, IR_node, 'LeakyRelu')
        self.convert_inedge(source_node, IR_node)
        assign_IRnode_values(IR_node, {'alpha' : source_node.layer.alpha.tolist()})


    def rename_ReLU(self, source_node):
        IR_node = self.IR_graph.node.add()
        max_value = source_node.layer.max_value
        if max_value == 6.0:
            Keras2Parser._copy_and_reop(source_node, IR_node, 'Relu6')
        else:
            Keras2Parser._copy_and_reop(source_node, IR_node, 'Relu')

        assign_IRnode_values(IR_node, {'max_value' : max_value})
        self.convert_inedge(source_node, IR_node)


    def rename_space_to_depth_x2(self, source_node):
        IR_node = self.IR_graph.node.add()

        # name, op
        Keras2Parser._copy_and_reop(source_node, IR_node, 'SpaceToDepth')
        IR_node.name = "Lambda_{}".format(self.lambda_layer_count)

        # input edge
        self.convert_inedge(source_node, IR_node)

        # for target shape
        IR_node.attr["blocksize"].i = 2
        self.lambda_layer_count = self.lambda_layer_count + 1
        source_node.real_name = IR_node.name