import math
import gc
import multiprocessing as mp
from sklearn import preprocessing

import base.batch as bat
from utils import *
import base.evaluation as eva
from data_model import DataModel
from MultiKE_model import MultiKE
from predicate_alignment import PredicateAlignModel


def valid(model, embed_choice='avg', w=(1, 1, 1)):
    """Evaluate one choice of entity embeddings on the validation entities.

    Source entities are ``model.kgs.valid_entities1``; the candidate pool is
    ``model.kgs.valid_entities2 + model.kgs.test_entities2``.

    :param model: object exposing ``session``, ``kgs``, ``args`` and the
        embedding tensors ``name_embeds``, ``rv_ent_embeds``,
        ``av_ent_embeds`` and ``ent_embeds``.
    :param embed_choice: which embeddings to evaluate:
        ``'nv'`` -> ``name_embeds``, ``'rv'`` -> ``rv_ent_embeds``,
        ``'av'`` -> ``av_ent_embeds``,
        ``'avg'`` -> ``w``-weighted sum of the three above,
        ``'final'`` (and any other value) -> ``ent_embeds``.
    :param w: three coefficients used only when ``embed_choice == 'avg'``.
    :return: the second value returned by ``eva.valid`` (named ``mrr_12``).
    """
    session = model.session
    if embed_choice == 'nv':
        ent_embeds = model.name_embeds.eval(session=session)
    elif embed_choice == 'rv':
        ent_embeds = model.rv_ent_embeds.eval(session=session)
    elif embed_choice == 'av':
        ent_embeds = model.av_ent_embeds.eval(session=session)
    elif embed_choice == 'avg':
        ent_embeds = w[0] * model.name_embeds.eval(session=session) + \
                     w[1] * model.rv_ent_embeds.eval(session=session) + \
                     w[2] * model.av_ent_embeds.eval(session=session)
    else:
        # 'final', which is also the fallback for unrecognised choices.
        # The tensor must be evaluated here: the list-based indexing below and
        # eva.valid both need a concrete array, like every other branch yields.
        ent_embeds = model.ent_embeds.eval(session=session)
    print(embed_choice, 'valid results:')
    embeds1 = ent_embeds[model.kgs.valid_entities1,]
    embeds2 = ent_embeds[model.kgs.valid_entities2 + model.kgs.test_entities2,]
    _, mrr_12 = eva.valid(embeds1, embeds2, None, model.args.top_k, model.args.test_threads_num,
                          normalize=True)
    # Drop the gathered slices before collecting to keep peak memory down.
    del embeds1, embeds2
    gc.collect()
    return mrr_12


def test(model, embed_choice='avg', w=(1, 1, 1)):
    """Evaluate one choice of entity embeddings on the test entities.

    Source entities are ``model.kgs.test_entities1``; candidates are
    ``model.kgs.test_entities2``.

    :param model: object exposing ``session``, ``kgs``, ``args`` and the
        embedding tensors ``name_embeds``, ``rv_ent_embeds``,
        ``av_ent_embeds`` and ``ent_embeds``.
    :param embed_choice: which embeddings to evaluate:
        ``'nv'`` -> ``name_embeds``, ``'rv'`` -> ``rv_ent_embeds``,
        ``'av'`` -> ``av_ent_embeds``,
        ``'avg'`` -> ``w``-weighted sum of the three above,
        ``'final'`` (and any other value) -> ``ent_embeds``.
    :param w: three coefficients used only when ``embed_choice == 'avg'``.
    :return: the second value returned by ``eva.valid`` (named ``mrr_12``).
    """
    session = model.session
    if embed_choice == 'nv':
        ent_embeds = model.name_embeds.eval(session=session)
    elif embed_choice == 'rv':
        ent_embeds = model.rv_ent_embeds.eval(session=session)
    elif embed_choice == 'av':
        ent_embeds = model.av_ent_embeds.eval(session=session)
    elif embed_choice == 'avg':
        ent_embeds = w[0] * model.name_embeds.eval(session=session) + \
                     w[1] * model.rv_ent_embeds.eval(session=session) + \
                     w[2] * model.av_ent_embeds.eval(session=session)
    else:
        # 'final', which is also the fallback for unrecognised choices.
        # The tensor must be evaluated here: the list-based indexing below and
        # eva.valid both need a concrete array, like every other branch yields.
        ent_embeds = model.ent_embeds.eval(session=session)
    print(embed_choice, 'test results:')
    embeds1 = ent_embeds[model.kgs.test_entities1,]
    embeds2 = ent_embeds[model.kgs.test_entities2,]
    _, mrr_12 = eva.valid(embeds1, embeds2, None, model.args.top_k, model.args.test_threads_num,
                          normalize=True)
    # Drop the gathered slices before collecting to keep peak memory down.
    del embeds1, embeds2
    gc.collect()
    return mrr_12


def _compute_weight(embeds1, embeds2, embeds3):
    """Score how well ``embeds1`` agrees with the average of all three views.

    Each row of ``embeds1`` and of the element-wise mean
    ``(embeds1 + embeds2 + embeds3) / 3`` is L2-normalised with
    ``preprocessing.normalize``; the per-row dot product of the two (i.e. the
    row-wise cosine similarity) is then averaged over all rows.

    :param embeds1: 2-D array for the view being scored.
    :param embeds2: 2-D array of the same shape for a second view.
    :param embeds3: 2-D array of the same shape for a third view.
    :return: mean of the per-row similarities (a scalar).
    """
    other_embeds = (embeds1 + embeds2 + embeds3) / 3
    other_embeds = preprocessing.normalize(other_embeds)
    embeds1 = preprocessing.normalize(embeds1)
    # Only the diagonal of embeds1 @ other_embeds.T is needed, so compute the
    # row-wise dot products directly instead of materialising an n x n matrix.
    weights = np.einsum('ij,ij->i', embeds1, other_embeds)
    mean_weight = np.mean(weights)
    print(weights.shape, mean_weight)
    return mean_weight


def wva(embeds1, embeds2, embeds3):
    """Compute one un-normalised agreement weight per embedding view.

    Each weight is ``_compute_weight`` applied with the corresponding view in
    the first position. The weights are returned raw; callers sum and
    normalise them themselves.

    :param embeds1: 2-D array for the first view.
    :param embeds2: 2-D array of the same shape for the second view.
    :param embeds3: 2-D array of the same shape for the third view.
    :return: tuple ``(weight1, weight2, weight3)``.
    """
    weight1 = _compute_weight(embeds1, embeds2, embeds3)
    weight2 = _compute_weight(embeds2, embeds1, embeds3)
    weight3 = _compute_weight(embeds3, embeds1, embeds2)
    return weight1, weight2, weight3


def valid_WVA(model):
    """Validate a weighted combination of the three view embeddings.

    Weights come from ``wva`` applied separately to the source side
    (``valid_entities1``) and the candidate side
    (``valid_entities2 + test_entities2``); the two sets are summed per view
    and normalised to sum to one before mixing the views.

    :param model: object exposing ``session``, ``kgs``, ``args`` and the
        tensors ``name_embeds``, ``rv_ent_embeds``, ``av_ent_embeds``.
    :return: the second value returned by ``eva.valid`` (named ``mrr_12``).
    """
    session = model.session

    def fetch_views(entity_ids):
        # Evaluated lookups in fixed order: name, relation, attribute view.
        return tuple(
            tf.nn.embedding_lookup(table, entity_ids).eval(session=session)
            for table in (model.name_embeds, model.rv_ent_embeds, model.av_ent_embeds)
        )

    source_views = fetch_views(model.kgs.valid_entities1)
    source_weights = wva(*source_views)

    candidate_ids = model.kgs.valid_entities2 + model.kgs.test_entities2
    candidate_views = fetch_views(candidate_ids)
    candidate_weights = wva(*candidate_views)

    # Per-view sum of both sides, then normalise so the three sum to one.
    summed = [src + cand for src, cand in zip(source_weights, candidate_weights)]
    total = summed[0] + summed[1] + summed[2]
    nv_w, rv_w, av_w = summed[0] / total, summed[1] / total, summed[2] / total

    print('weights', nv_w, rv_w, av_w)

    embeds1 = nv_w * source_views[0] + rv_w * source_views[1] + av_w * source_views[2]
    embeds2 = nv_w * candidate_views[0] + rv_w * candidate_views[1] + av_w * candidate_views[2]
    print('wvag valid results:')
    _, mrr_12 = eva.valid(embeds1, embeds2, None, model.args.top_k, model.args.test_threads_num,
                          normalize=True)

    # Release every large array before forcing a collection.
    del source_views, candidate_views
    del embeds1, embeds2
    gc.collect()

    return mrr_12


def test_WVA(model):
    """Test a weighted combination of the three view embeddings.

    Weights come from ``wva`` applied separately to the source side
    (``test_entities1``) and the candidate side (``test_entities2``); the two
    sets are summed per view and normalised to sum to one before mixing.

    :param model: object exposing ``session``, ``kgs``, ``args`` and the
        tensors ``name_embeds``, ``rv_ent_embeds``, ``av_ent_embeds``.
    :return: the second value returned by ``eva.valid`` (named ``mrr_12``).
    """
    session = model.session

    def fetch_views(entity_ids):
        # Evaluated lookups in fixed order: name, relation, attribute view.
        return tuple(
            tf.nn.embedding_lookup(table, entity_ids).eval(session=session)
            for table in (model.name_embeds, model.rv_ent_embeds, model.av_ent_embeds)
        )

    nv_src, rv_src, av_src = fetch_views(model.kgs.test_entities1)
    src_w_nv, src_w_rv, src_w_av = wva(nv_src, rv_src, av_src)

    nv_cand, rv_cand, av_cand = fetch_views(model.kgs.test_entities2)
    cand_w_nv, cand_w_rv, cand_w_av = wva(nv_cand, rv_cand, av_cand)

    # Per-view sum of both sides, then normalise so the three sum to one.
    nv_w = src_w_nv + cand_w_nv
    rv_w = src_w_rv + cand_w_rv
    av_w = src_w_av + cand_w_av
    total = nv_w + rv_w + av_w
    nv_w, rv_w, av_w = nv_w / total, rv_w / total, av_w / total

    print('weights', nv_w, rv_w, av_w)

    embeds1 = nv_w * nv_src + rv_w * rv_src + av_w * av_src
    embeds2 = nv_w * nv_cand + rv_w * rv_cand + av_w * av_cand
    print('wvag test results:')
    _, mrr_12 = eva.valid(embeds1, embeds2, None, model.args.top_k, model.args.test_threads_num,
                          normalize=True)
    del embeds1, embeds2
    gc.collect()
    return mrr_12


class MultiKE_Late(MultiKE):
    """MultiKE subclass that builds all of its graphs in ``__init__`` and trains in two stages.

    ``run`` first trains the relation and attribute views (with cross-KG
    inference steps), then runs a separate shared-space mapping stage, and
    finally saves the model and reports test results for every embedding
    choice. All ``_define_*`` / ``train_*`` methods used here are not defined
    in this class; presumably they are inherited from ``MultiKE`` -- verify.
    """
    def __init__(self, data, args, attr_align_model):
        """Initialise the parent, build every graph, then open and initialise a session.

        :param data: forwarded unchanged to ``MultiKE.__init__``.
        :param args: forwarded unchanged to ``MultiKE.__init__``.
        :param attr_align_model: forwarded unchanged to ``MultiKE.__init__``.
        """
        super().__init__(data, args, attr_align_model)

        # NOTE(review): flag1/flag2 are not read anywhere in this class;
        # presumably consumed by inherited code -- verify.
        self.flag1 = -1
        self.flag2 = -1
        # Checked in run(); nothing in this class ever sets it to True.
        self.early_stop = False

        self._define_variables()

        # View-specific graphs.
        self._define_name_view_graph()
        self._define_relation_view_graph()
        self._define_attribute_view_graph()

        # Cross-KG reference graphs.
        self._define_cross_kg_entity_reference_relation_view_graph()
        self._define_cross_kg_entity_reference_attribute_view_graph()
        self._define_cross_kg_relation_reference_graph()
        self._define_cross_kg_attribute_reference_graph()

        self._define_common_space_learning_graph()
        self._define_space_mapping_graph()

        # The session is created only after all graphs are defined, so the
        # initializer below covers every variable created above.
        self.session = load_session()
        tf.global_variables_initializer().run(session=self.session)

    def run(self):
        """Train both views, run the shared-space mapping stage, save, and test.

        Prints validation/test results along the way; returns ``None``.
        """
        # NOTE(review): `t` is never read in this method.
        t = time.time()
        # Number of batches per pass = ceil(total local triples / batch size),
        # split into per-thread task lists by task_divide.
        relation_triples_num = self.kgs.kg1.local_relation_triples_num + self.kgs.kg2.local_relation_triples_num
        attribute_triples_num = self.kgs.kg1.local_attribute_triples_num + self.kgs.kg2.local_attribute_triples_num
        relation_triple_steps = int(math.ceil(relation_triples_num / self.args.batch_size))
        attribute_triple_steps = int(math.ceil(attribute_triples_num / self.args.batch_size))
        relation_step_tasks = task_divide(list(range(relation_triple_steps)), self.args.batch_threads_num)
        attribute_step_tasks = task_divide(list(range(attribute_triple_steps)), self.args.batch_threads_num)
        # Manager-backed queues handed to the per-view training calls below.
        manager = mp.Manager()
        relation_batch_queue = manager.Queue()
        attribute_batch_queue = manager.Queue()
        # Supervision triples from both KGs, concatenated.
        cross_kg_relation_triples = self.kgs.kg1.sup_relation_triples_list + self.kgs.kg2.sup_relation_triples_list
        cross_kg_entity_inference_in_attribute_triples = self.kgs.kg1.sup_attribute_triples_list + \
                                                         self.kgs.kg2.sup_attribute_triples_list
        cross_kg_relation_inference = self.predicate_align_model.sup_relation_alignment_triples1 + \
                                      self.predicate_align_model.sup_relation_alignment_triples2
        cross_kg_attribute_inference = self.predicate_align_model.sup_attribute_alignment_triples1 + \
                                       self.predicate_align_model.sup_attribute_alignment_triples2
        # Stay None until the truncated-sampling branch below generates them.
        neighbors1, neighbors2 = None, None

        entity_list = self.kgs.kg1.entities_list + self.kgs.kg2.entities_list

        # Validation before any training step has run.
        valid(self, embed_choice='nv')
        valid(self, embed_choice='avg')
        # Stage 1: relation view and attribute view training.
        for i in range(1, self.args.max_epoch + 1):
            print('epoch {}:'.format(i))
            self.train_relation_view_1epo(i, relation_triple_steps, relation_step_tasks,
                                          relation_batch_queue, neighbors1, neighbors2)
            self.train_cross_kg_entity_inference_relation_view_1epo(i, cross_kg_relation_triples)
            # Strict '>' here, whereas the alignment refresh below uses '>='.
            if i > self.args.start_predicate_soft_alignment:
                self.train_cross_kg_relation_inference_1epo(i, cross_kg_relation_inference)

            self.train_attribute_view_1epo(i, attribute_triple_steps, attribute_step_tasks, attribute_batch_queue,
                                           neighbors1, neighbors2)
            self.train_cross_kg_entity_inference_attribute_view_1epo(i, cross_kg_entity_inference_in_attribute_triples)
            if i > self.args.start_predicate_soft_alignment:
                self.train_cross_kg_attribute_inference_1epo(i, cross_kg_attribute_inference)

            if i >= self.args.start_valid and i % self.args.eval_freq == 0:
                valid(self, embed_choice='rv')
                valid(self, embed_choice='av')
                valid(self, embed_choice='avg')
                valid_WVA(self)
                if i >= self.args.start_predicate_soft_alignment:
                    # Refresh the predicate alignment from the current relation and
                    # attribute embeddings, then rebuild the inference triple lists.
                    self.predicate_align_model.update_predicate_alignment(self.rel_embeds.eval(session=self.session))
                    self.predicate_align_model.update_predicate_alignment(self.attr_embeds.eval(session=self.session),
                                                                          predicate_type='attribute')
                    cross_kg_relation_inference = self.predicate_align_model.sup_relation_alignment_triples1 + \
                                                  self.predicate_align_model.sup_relation_alignment_triples2
                    cross_kg_attribute_inference = self.predicate_align_model.sup_attribute_alignment_triples1 + \
                                                   self.predicate_align_model.sup_attribute_alignment_triples2

            # Leaving here on the final epoch skips the neighbour regeneration below.
            if self.early_stop or i == self.args.max_epoch:
                break

            if self.args.neg_sampling == 'truncated' and i % self.args.truncated_freq == 0:
                t1 = time.time()
                assert 0.0 < self.args.truncated_epsilon < 1.0
                # Neighbour count per KG = (1 - epsilon) * number of entities in that KG.
                neighbors_num1 = int((1 - self.args.truncated_epsilon) * self.kgs.kg1.entities_num)
                neighbors_num2 = int((1 - self.args.truncated_epsilon) * self.kgs.kg2.entities_num)
                neighbors1 = bat.generate_neighbours(self.eval_kg1_useful_ent_embeddings(),
                                                     self.kgs.useful_entities_list1,
                                                     neighbors_num1, self.args.batch_threads_num)
                neighbors2 = bat.generate_neighbours(self.eval_kg2_useful_ent_embeddings(),
                                                     self.kgs.useful_entities_list2,
                                                     neighbors_num2, self.args.batch_threads_num)
                ent_num = len(self.kgs.kg1.entities_list) + len(self.kgs.kg2.entities_list)
                print('neighbor dict:', len(neighbors1), type(neighbors2))
                print("generating neighbors of {} entities costs {:.3f} s.".format(ent_num, time.time() - t1))
        # Stage 2: shared-space mapping over all entities of both KGs.
        for i in range(1, self.args.shared_learning_max_epoch + 1):
            self.train_shared_space_mapping_1epo(i, entity_list)
            if i >= self.args.start_valid and i % self.args.eval_freq == 0:
                valid(self, embed_choice='final')
        self.save()
        # Final report: test results for every embedding choice.
        test(self, embed_choice='nv')
        test(self, embed_choice='rv')
        test(self, embed_choice='av')
        test(self, embed_choice='avg')
        test_WVA(self)
        test(self, embed_choice='final')


if __name__ == '__main__':
    # Script entry point: load the run configuration, build the data and
    # predicate-alignment models, then construct the model and train it.
    run_args = load_args('args.json')
    kg_data = DataModel(run_args)
    predicate_aligner = PredicateAlignModel(kg_data.kgs, run_args)
    MultiKE_Late(kg_data, run_args, predicate_aligner).run()