# coding=utf-8
#
# created by kpe on 02.Sep.2019 at 11:57
#

from __future__ import absolute_import, division, print_function

import unittest

import os
import re
import tempfile

import bert

import numpy as np
import tensorflow as tf
from tensorflow import keras

from .test_common import AbstractBertTest, MiniBertFactory


class TestExtendSegmentVocab(AbstractBertTest):

    def setUp(self) -> None:
        tf.compat.v1.reset_default_graph()
        tf.compat.v1.enable_eager_execution()
        print("Eager Execution:", tf.executing_eagerly())

    def test_extend_pretrained_segments(self):

        # build a minimal BERT checkpoint (weights + vocab.txt) in a fresh temp dir
        model_dir = tempfile.mkdtemp()
        save_path = MiniBertFactory.create_mini_bert_weights(model_dir)
        tokenizer = bert.bert_tokenization.FullTokenizer(
            vocab_file=os.path.join(model_dir, "vocab.txt"), do_lower_case=True)

        ckpt_dir = os.path.dirname(save_path)
        bert_params = bert.params_from_pretrained_ckpt(ckpt_dir)

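        # the mini checkpoint was written with 2 token types; extend to 4 before
        # instantiating the layer, so the embedding table gets 4 rows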
        self.assertEqual(bert_params.token_type_vocab_size, 2)
        bert_params.token_type_vocab_size = 4

        l_bert = bert.BertModelLayer.from_params(bert_params)

        # call the layer once on dummy input, so that Keras builds its weights
        l_bert([np.array([[1, 1, 0]]),
                np.array([[1, 0, 0]])])

        #
        # - load the weights from a pre-trained model,
        # - expect a mismatch for the token_type embeddings
        # - use the segment/token type id=0 embedding for the missing token types
        #
        mismatched = bert.load_stock_weights(l_bert, save_path)

        self.assertEqual(1, len(mismatched), "expected exactly one shape mismatch: the token_type embeddings")

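        # load_stock_weights() reports the mismatches as (layer_weight, checkpoint_value)
        # pairs; seed the extra token type rows from the pre-trained segment-0 row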
        for weight, value in mismatched:
            if re.match("(.*)embeddings/token_type_embeddings/embeddings:0", weight.name):
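                # the first row of the checkpoint value is the segment-0 embedding;
                # repeat it to fill the rows the checkpoint does not provide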
                seg0_emb = value[:1, :]
                new_segment_embeddings = np.repeat(seg0_emb, (weight.shape[0]-value.shape[0]), axis=0)
                new_value = np.concatenate([value, new_segment_embeddings], axis=0)
                keras.backend.batch_set_value([(weight, new_value)])
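                # (eager alternative: weight.assign(new_value) on the tf.Variable)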

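        # read back the patched token_type embedding table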
        tte = l_bert.embeddings_layer.token_type_embeddings_layer.weights[0]

        if not tf.executing_eagerly():
            # in graph mode, evaluate the variable through the v1 Keras session
            sess = tf.compat.v1.keras.backend.get_session()
            tte = sess.run(tte)

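        # rows 0, 2 and 3 now equal the segment-0 embedding, while row 1 keeps
        # its own pre-trained value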
        self.assertTrue(np.allclose(seg0_emb, tte[0], rtol=1e-6))
        self.assertFalse(np.allclose(seg0_emb, tte[1], rtol=1e-6))
        self.assertTrue(np.allclose(seg0_emb, tte[2], rtol=1e-6))
        self.assertTrue(np.allclose(seg0_emb, tte[3], rtol=1e-6))

        print("token_type_vocab_size", bert_params.token_type_vocab_size)
        print(l_bert.embeddings_layer.trainable_weights[1])