import tensorflow as tf
from tensorflow.contrib import layers
from tensorflow.python.layers.core import Dense
from common import *
from data_generator_ctc_joint_attention import *


class CtcPlusAttModel(object):
    """
    Class CtcPlusAttModel
    """
    def __init__(self):
        """
        Initialize global variables and compute graph
        """
        # vocabulary parameters
        self.start_token = START_TOKEN
        self.end_token = END_TOKEN
        self.unk_token = UNK_TOKEN
        self.vocab_att = VOCAB_ATT
        self.vocab_att_size = VOCAB_ATT_SIZE
        self.vocab_ctc = VOCAB_CTC
        self.vocab_ctc_size = VOCAB_CTC_SIZE

        # training parameters
        self.batch_size = BATCH_SIZE
        self.rnn_units = RNN_UNITS
        self.max_train_steps = TRAIN_STEP
        self.image_height = IMAGE_HEIGHT
        self.att_embed_dim = ATT_EMBED_DIM
        self.max_dec_iteration = MAXIMUM__DECODE_ITERATIONS
        # loss weights, reference: https://arxiv.org/pdf/1609.06773v1.pdf
        self.ctc_loss_weights = 0.8
        self.att_loss_weights = 1 - self.ctc_loss_weights
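        # joint objective: total_loss = ctc_loss_weights * L_ctc + att_loss_weights * L_att,
        # a convex combination of the two branch losses as in the hybrid CTC/attention paper above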
        # attention mode: 0 uses Bahdanau attention, 1 uses Luong attention
        self.attention_mode = 1

        # visualization path and model saved path
        self.logs_path = LOGS_PATH
        self.save_model_dir = CKPT_DIR

        # input image batch: [batch_size, image_height, width, 1]; the width is variable
        self.input_image = tf.placeholder(tf.float32, shape=(None, self.image_height, None, 1), name='img_data')

        # attention part placeholder
        self.att_train_output = tf.placeholder(tf.int64, shape=[None, None], name='att_train_output')
        self.att_train_length = tf.placeholder(tf.int32, shape=[None], name='att_train_length')
        self.att_target_output = tf.placeholder(tf.int64, shape=[None, None], name='att_target_output')

        # ctc part placeholder
        self.ctc_label = tf.sparse_placeholder(tf.int32, name='ctc_label')
        self.ctc_feature_length = tf.placeholder(tf.int32, shape=[None], name='ctc_feature_length')

        # TensorFlow session
        self.sess = tf.Session()

    def __shared_encoder(self):
        """
        Image features encoded by CNN and bidirectional GRU
        :return: encoded features
        """
        convolution1 = layers.conv2d(inputs=self.input_image,
                                     num_outputs=64,
                                     kernel_size=[3, 3],
                                     padding='SAME',
                                     activation_fn=tf.nn.relu)
        pool1 = layers.max_pool2d(inputs=convolution1, kernel_size=[2, 2], stride=[2, 2])

        convolution2 = layers.conv2d(inputs=pool1,
                                     num_outputs=128,
                                     kernel_size=[3, 3],
                                     padding='SAME',
                                     activation_fn=tf.nn.relu)
        pool2 = layers.max_pool2d(inputs=convolution2, kernel_size=[2, 2], stride=[2, 2])

        convolution3 = layers.conv2d(inputs=pool2,
                                     num_outputs=256,
                                     kernel_size=[3, 3],
                                     padding='SAME',
                                     activation_fn=tf.nn.relu)

        convolution4 = layers.conv2d(inputs=convolution3,
                                     num_outputs=256,
                                     kernel_size=[3, 3],
                                     padding='SAME',
                                     activation_fn=tf.nn.relu)
        pool3 = layers.max_pool2d(inputs=convolution4, kernel_size=[2, 1], stride=[2, 1])
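        # from this point pooling uses a [2, 1] kernel/stride: the height keeps shrinking
        # while the width (the time axis seen by the CTC and attention branches) is preserved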

        convolution5 = layers.conv2d(inputs=pool3,
                                     num_outputs=512,
                                     kernel_size=[3, 3],
                                     padding='SAME',
                                     activation_fn=tf.nn.relu)
        n1 = layers.batch_norm(convolution5)

        convolution6 = layers.conv2d(inputs=n1,
                                     num_outputs=512,
                                     kernel_size=[3, 3],
                                     padding='SAME',
                                     activation_fn=tf.nn.relu)
        n2 = layers.batch_norm(convolution6)
        pool4 = layers.max_pool2d(inputs=n2, kernel_size=[2, 1], stride=[2, 1])

        convolution7 = layers.conv2d(inputs=pool4,
                                     num_outputs=512,
                                     kernel_size=[2, 2],
                                     padding='VALID',
                                     activation_fn=tf.nn.relu)
        cnn_out = tf.squeeze(convolution7, axis=1)
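        # squeezing axis 1 assumes the stacked pooling plus the final VALID convolution has
        # reduced the feature-map height to 1, leaving a [batch, width, 512] sequence for the RNN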

        # use separate cells for the forward and backward directions so they do not share weights
        cell_fw = tf.contrib.rnn.GRUCell(num_units=self.rnn_units)
        cell_bw = tf.contrib.rnn.GRUCell(num_units=self.rnn_units)
        enc_outputs, enc_state = tf.nn.bidirectional_dynamic_rnn(cell_fw=cell_fw,
                                                                 cell_bw=cell_bw,
                                                                 inputs=cnn_out,
                                                                 dtype=tf.float32)
        encoder_outputs = tf.concat(enc_outputs, -1)
        return encoder_outputs

    def __ctc_loss_branch(self, rnn_features):
        """
        Ctc loss compute graph
        :param rnn_features: encoded features and self.ctc_feature_length、self.ctc_label
        :return: loss matrix
        """
        project_output = layers.fully_connected(inputs=rnn_features,
                                                num_outputs=self.vocab_ctc_size + 1,
                                                activation_fn=None)
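        # vocab_ctc_size + 1 logits: the extra class is the CTC blank label
        # (tf.nn.ctc_loss reserves the last index, num_classes - 1, for blank)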
        # with time_major=False the logits must be shaped [batch_size, max_time, num_classes]
        # (the default time_major=True expects [max_time, batch_size, num_classes])
        ctc_loss = tf.nn.ctc_loss(labels=self.ctc_label,
                                  inputs=project_output,
                                  sequence_length=self.ctc_feature_length,
                                  time_major=False)
        return ctc_loss

    def __attention_loss_branch(self, rnn_features):
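        """
        Attention (seq2seq) loss compute graph
        :param rnn_features: encoded features used as the attention memory
        :return: sequence cross-entropy loss of the attention decoder
        """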
        output_embed = layers.embed_sequence(self.att_train_output,
                                             vocab_size=self.vocab_att_size,
                                             embed_dim=self.att_embed_dim, scope='embed')
        #  with tf.device('/cpu:0'):
        embeddings = tf.Variable(tf.truncated_normal(shape=[self.vocab_att_size, self.att_embed_dim],
                                                     stddev=0.1), name='decoder_embedding')
        start_tokens = tf.zeros([self.batch_size], dtype=tf.int64)

        train_helper = tf.contrib.seq2seq.TrainingHelper(output_embed, self.att_train_length)
        pred_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(embeddings,
                                                               start_tokens=tf.to_int32(start_tokens),
                                                               end_token=1)
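        # start_tokens / end_token are hard-coded to 0 and 1 here; this assumes the attention
        # vocabulary maps START_TOKEN to id 0 and END_TOKEN to id 1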

        train_outputs = self.__att_decode(train_helper, rnn_features, 'decode')
        pred_outputs = self.__att_decode(pred_helper, rnn_features, 'decode', reuse=True)

        # train_decode_result = train_outputs[0].rnn_output[0, :-1, :]
        # pred_decode_result = pred_outputs[0].rnn_output[0, :, :]
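        # NOTE: pred_outputs is the greedy-decoding branch (its sample_id field holds the
        # predicted token ids); it is built here but not used in the training loss below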

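        # weight the first att_train_length[0] - 1 decode steps and zero out the final step;
        # this assumes every sequence in the batch is padded to att_train_length[0] steps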
        mask = tf.cast(tf.sequence_mask(self.batch_size * [self.att_train_length[0] - 1],
                                        self.att_train_length[0]),
                       tf.float32)
        att_loss = tf.contrib.seq2seq.sequence_loss(train_outputs[0].rnn_output,
                                                    self.att_target_output,
                                                    weights=mask)

        return att_loss

    def __att_decode(self, helper, rnn_features, scope, reuse=None):
        """
        Attention decode part
        :param helper: train or inference
        :param rnn_features: encoded features
        :param scope: name scope
        :param reuse: reuse or not
        :return: attention decode output
        """
        with tf.variable_scope(scope, reuse=reuse):
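            # mode 1: Luong (multiplicative) attention; mode 0: Bahdanau (additive) attention.
            # Both attend over the encoder outputs passed in as memory.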
            if self.attention_mode == 1:
                attention_mechanism = tf.contrib.seq2seq.LuongAttention(num_units=self.rnn_units,
                                                                        memory=rnn_features)
            else:
                attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(num_units=self.rnn_units,
                                                                           memory=rnn_features)

            cell = tf.contrib.rnn.GRUCell(num_units=self.rnn_units)
            attn_cell = tf.contrib.seq2seq.AttentionWrapper(cell, attention_mechanism,
                                                            attention_layer_size=self.rnn_units,
                                                            output_attention=True)
            output_layer = Dense(units=self.vocab_att_size)

            decoder = tf.contrib.seq2seq.BasicDecoder(
                cell=attn_cell, helper=helper,
                initial_state=attn_cell.zero_state(dtype=tf.float32, batch_size=self.batch_size),
                output_layer=output_layer)

            att_outputs = tf.contrib.seq2seq.dynamic_decode(
                decoder=decoder, output_time_major=False,
                impute_finished=True, maximum_iterations=self.max_dec_iteration)
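            # impute_finished=True zeroes the outputs emitted after a sequence has finished;
            # decoding runs for at most max_dec_iteration steps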

            return att_outputs

    def build_model(self):
        """
        build compute graph
        :return: model
        """
        # share part
        encode_features = self.__shared_encoder()

        # attention part
        attention_loss = tf.reduce_mean(self.__attention_loss_branch(encode_features))

        # ctc part
        ctc_loss = tf.reduce_mean(self.__ctc_loss_branch(encode_features))

        # merge part
        t_loss = attention_loss*self.att_loss_weights + ctc_loss*self.ctc_loss_weights
        train_step = tf.train.AdadeltaOptimizer().minimize(t_loss)
        return train_step, t_loss

    def load_data(self):
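        """
        Create the training batch generator
        :return: generator yielding dicts with keys 'input_image', 'ctc_label',
                 'ctc_feature_length', 'att_train_output', 'att_train_length', 'att_target_output'
        """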
        data_gen = gen_training_data(self.batch_size)
        return data_gen

    def train_process(self):
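        """
        Build the graph, initialize variables and run the training loop,
        printing the loss every DISPLAY_STEPS steps
        """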
        train_step, loss = self.build_model()
        with self.sess.as_default():
            self.sess.run(tf.global_variables_initializer())
            data_gen = self.load_data()
            for step in range(self.max_train_steps):
                input_data = next(data_gen)
                feed_dict = {self.input_image: input_data['input_image'],
                             self.ctc_label: input_data['ctc_label'],
                             self.ctc_feature_length: input_data['ctc_feature_length'],
                             self.att_train_output: input_data['att_train_output'],
                             self.att_train_length: input_data['att_train_length'],
                             self.att_target_output: input_data['att_target_output']}
                self.sess.run(train_step, feed_dict=feed_dict)
                if step % DISPLAY_STEPS == 0:
                    loss_print = self.sess.run(loss, feed_dict=feed_dict)
                    print("step: {}\t loss:\t{}".format(step, loss_print))

    def visualize_log(self):
        pass


if __name__ == '__main__':
    ctc_att_model = CtcPlusAttModel()
    ctc_att_model.train_process()

    # loss = ctc_att_model.build_model()
    # print(loss.get_shape())