"""Conll Evaluation - Scoring"""

import os
import subprocess
import io
import pickle

import torch
from torch.utils.data import DataLoader

from neuralcoref.train.conllparser import FEATURES_NAMES
from neuralcoref.train.dataset import NCBatchSampler, padder_collate
from neuralcoref.train.compat import unicode_

PACKAGE_DIRECTORY = os.path.dirname(os.path.abspath(__file__))

OUT_PATH = os.path.join(PACKAGE_DIRECTORY, "test_corefs.txt")
ALL_MENTIONS_PATH = os.path.join(PACKAGE_DIRECTORY, "test_mentions.txt")
# KEY_PATH = os.path.join(PACKAGE_DIRECTORY, "conll-2012-test-test-key.txt")
SCORING_SCRIPT = os.path.join(PACKAGE_DIRECTORY, "scorer_wrapper.pl")

METRICS = ["muc", "bcub", "ceafm", "ceafe", "blanc"]
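# The official CoNLL-2012 score is the unweighted average of the MUC, B-cubed (bcub)
# and CEAF-e F1 scores.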
CONLL_METRICS = ["muc", "bcub", "ceafe"]


class ConllEvaluator(object):
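    """Evaluate a coreference model against a CoNLL key file via scorer_wrapper.pl."""
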
    def __init__(
        self, model, dataset, test_data_path, test_key_file, embed_path, args
    ):
        """ Evaluate the pytorch model that is currently being build
            We take the embedding vocabulary currently being trained
        """
        self.test_key_file = test_key_file
        self.cuda = args.cuda
        self.model = model
        batch_sampler = NCBatchSampler(
            dataset.mentions_pair_length, batchsize=args.batchsize, shuffle=False
        )
        self.dataloader = DataLoader(
            dataset,
            collate_fn=padder_collate,
            batch_sampler=batch_sampler,
            num_workers=args.numworkers,
            pin_memory=args.cuda,
        )
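        # get_batch_info() returns, for each batch, the flat indices of its mentions
        # and the number of candidate-antecedent pairs per mention; build_test_file()
        # uses both to map pairwise scores back to mentions.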
        self.mentions_idx, self.n_pairs = batch_sampler.get_batch_info()
        self.load_meta(test_data_path)

    def load_meta(self, test_data_path):
        # Load meta files
        datas = {}
        if not os.listdir(test_data_path):
            raise ValueError("Empty test_data_path")
        bin_files_found = False
        print("Reading ", end="")
        for file_name in os.listdir(test_data_path):
            if ".bin" not in file_name:
                continue
            bin_files_found = True
            print(file_name, end=", ")
            with open(os.path.join(test_data_path, file_name), "rb") as f:
                datas[file_name.split(".")[0]] = pickle.load(f)
        if not bin_files_found:
            raise ValueError(f"Can't find bin files in {test_data_path}")
        print("Done")
        self.m_loc = datas[FEATURES_NAMES[9]]
        self.tokens = datas[FEATURES_NAMES[10]]
        self.lookup = datas[FEATURES_NAMES[11]]
        self.docs = datas[FEATURES_NAMES[12]]
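        # Flat (corpus-wide) mention index -> (document index, mention index in doc)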
        self.flat_m_idx = list(
            (doc_i, m_i) for doc_i, l in enumerate(self.m_loc) for m_i in range(len(l))
        )

    ###########################
    #### CLUSTER FUNCTIONS ####
    ###########################

    def _prepare_clusters(self):
        """
        Clean up and prepare one cluster for each mention
        """
        self.mention_to_cluster = list(
            list(range(len(doc_mentions))) for doc_mentions in self.m_loc
        )
        self.clusters = list(
            dict((i, [i]) for i in doc_mentions)
            for doc_mentions in self.mention_to_cluster
        )

    def _merge_coreference_clusters(self, ant_flat_idx, mention_flat_idx):
        """
        Merge two clusters together
        """
        doc_idx, ant_idx = self.flat_m_idx[ant_flat_idx]
        doc_idx2, mention_idx = self.flat_m_idx[mention_flat_idx]
        assert doc_idx2 == doc_idx
        if (
            self.mention_to_cluster[doc_idx][ant_idx]
            == self.mention_to_cluster[doc_idx][mention_idx]
        ):
            return
        remove_id = self.mention_to_cluster[doc_idx][ant_idx]
        keep_id = self.mention_to_cluster[doc_idx][mention_idx]
        for idx in self.clusters[doc_idx][remove_id]:
            self.mention_to_cluster[doc_idx][idx] = keep_id
            self.clusters[doc_idx][keep_id].append(idx)
        del self.clusters[doc_idx][remove_id]

    def remove_singletons_clusters(self, debug=False):
        for doc_idx in range(len(self.docs)):
            remove_id = []
            kept = False
            for key, mentions in self.clusters[doc_idx].items():
                if len(mentions) == 1:
                    remove_id.append(key)
                    self.mention_to_cluster[doc_idx][key] = None
                else:
                    kept = True
                    if debug:
                        l = list(self.m_loc[doc_idx][m][3] for m in mentions)
                        print("Cluster found", key)
                        print(
                            "Corefs:",
                            "|".join(
                                str(self.docs[doc_idx]["mentions"][m_idx])
                                + " ("
                                + str(m_idx)
                                + ")"
                                for m_idx in l
                            ),
                        )
            if not kept and debug:
                print("❄️ No coreference found")
            for rem in remove_id:
                del self.clusters[doc_idx][rem]

    def display_clusters(self, doc_idx=None):
        """
        Print clusters informations
        """
        doc_it = range(len(self.docs)) if doc_idx is None else [doc_idx]
        for d_i in doc_it:
            print(
                "Clusters in doc:",
                d_i,
                self.docs[d_i]["name"],
                self.docs[d_i]["part"],
            )
            print(self.clusters[d_i])
            for key, mentions in self.clusters[d_i].items():
                l = list(self.m_loc[d_i][m][3] for m in mentions)
                print(
                    "cluster",
                    key,
                    "(",
                    ", ".join(self.docs[d_i]["mentions"][m_idx] for m_idx in l),
                    ")",
                )

    ########################
    #### MAIN FUNCTIONS ####
    ########################
    def get_max_score(self, batch, debug=False):
        inputs, mask = batch
        if self.cuda:
            inputs = tuple(i.cuda() for i in inputs)
            mask = mask.cuda()
        self.model.eval()
        with torch.no_grad():
            scores = self.model(inputs, concat_axis=1)
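            # Mask out padded pair slots so the argmax can never select them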
            scores.masked_fill_(mask, -float("Inf"))
            _, max_idx = scores.max(
                dim=1
            )  # We may want to weight the single score with coref.greedyness
        if debug:
            print("Max_idx", max_idx)
        return scores.cpu().numpy(), max_idx.cpu().numpy()

    def test_model(self):
        print("🌋 Test evaluator / print all mentions")
        self.build_test_file(out_path=ALL_MENTIONS_PATH, print_all_mentions=True)
        self.get_score(file_path=ALL_MENTIONS_PATH)

    def build_test_file(
        self,
        out_path=OUT_PATH,
        remove_singleton=True,
        print_all_mentions=False,
        debug=None,
    ):
        """ Build a test file to supply to the coreference scoring perl script
        """
        print("🌋 Building test file")
        self._prepare_clusters()
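        # Switch the dataset to inputs-only mode (no gold labels) while the model
        # is run to build predicted clusters; restored after the loop below.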
        self.dataloader.dataset.no_targets = True
        if not print_all_mentions:
            print("🌋 Build coreference clusters")
            for sample_batched, mentions_idx, n_pairs_l in zip(
                self.dataloader, self.mentions_idx, self.n_pairs
            ):
                scores, max_i = self.get_max_score(sample_batched)
                for m_idx, ind, n_pairs in zip(mentions_idx, max_i, n_pairs_l):
                    if ind < n_pairs:
                        # the single score is not the highest: we have a match!
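                        # Column `ind` indexes one of the `n_pairs` candidate antecedents,
                        # which are the mentions immediately preceding `m_idx` in the flat
                        # mention ordering.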
                        prev_idx = m_idx - n_pairs + ind
                        if debug is not None and (
                            debug == -1 or debug == prev_idx or debug == m_idx
                        ):
                            m1_doc, m1_idx = self.flat_m_idx[m_idx]
                            m1 = self.docs[m1_doc]["mentions"][m1_idx]
                            m2_doc, m2_idx = self.flat_m_idx[prev_idx]
                            m2 = self.docs[m2_doc]["mentions"][m2_idx]
                            print(
                                "We have a match between:",
                                m1,
                                "(" + str(m1_idx) + ")",
                                "and:",
                                m2,
                                "(" + str(m2_idx) + ")",
                            )
                        self._merge_coreference_clusters(prev_idx, m_idx)
            if remove_singleton:
                self.remove_singletons_clusters()
        self.dataloader.dataset.no_targets = False

        print("🌋 Construct test file")
        out_str = ""
        for doc, d_tokens, d_lookup, d_m_loc, d_m_to_c in zip(
            self.docs, self.tokens, self.lookup, self.m_loc, self.mention_to_cluster
        ):
            out_str += (
                "#begin document (" + doc["name"] + "); part " + doc["part"] + "\n"
            )
            for utt_idx, (c_tokens, c_lookup) in enumerate(zip(d_tokens, d_lookup)):
                for i, (token, lookup) in enumerate(zip(c_tokens, c_lookup)):
                    out_coref = ""
                    for m_str, mention, mention_cluster in zip(
                        doc["mentions"], d_m_loc, d_m_to_c
                    ):
                        m_start, m_end, m_utt, m_idx, m_doc = mention
                        if mention_cluster is None:
                            continue
                        if m_utt == utt_idx:
                            if m_start in lookup:
                                out_coref += "|" if out_coref else ""
                                out_coref += "(" + unicode_(mention_cluster)
                                if (m_end - 1) in lookup:
                                    out_coref += ")"
                            elif (m_end - 1) in lookup:
                                out_coref += "|" if out_coref else ""
                                out_coref += unicode_(mention_cluster) + ")"
                    out_line = (
                        doc["name"]
                        + " "
                        + doc["part"]
                        + " "
                        + unicode_(i)
                        + " "
                        + token
                        + " "
                    )
                    out_line += "-" if len(out_coref) == 0 else out_coref
                    out_str += out_line + "\n"
                out_str += "\n"
            out_str += "#end document\n"

        # Write test file
        print("Writing in", out_path)
        with io.open(out_path, "w", encoding="utf-8") as out_file:
            out_file.write(out_str)

    def get_score(self, file_path=OUT_PATH, debug=False):
        """ Call the coreference scoring perl script on the created test file
        """
        print("🌋 Computing score")
        score = {}
        ident = None
        for metric_name in CONLL_METRICS:
            if debug:
                print("Computing metric:", metric_name)
            try:
                scorer_out = subprocess.check_output(
                    [
                        "perl",
                        SCORING_SCRIPT,
                        metric_name,
                        self.test_key_file,
                        file_path,
                    ],
                    stderr=subprocess.STDOUT,
                    encoding="utf-8",
                )
            except subprocess.CalledProcessError as err:
                print("Error during the scoring")
                print(err)
                print(err.output)
                raise
            if debug:
                print("scorer_out", scorer_out)
            value, ident = scorer_out.split("\n")[-2], scorer_out.split("\n")[-1]
            if debug:
                print("value", value, "identification", ident)
            NR, DR, NP, DP = [float(x) for x in value.split(" ")]
            ident_NR, ident_DR, ident_NP, ident_DP = [
                float(x) for x in ident.split(" ")
            ]
            precision = NP / DP if DP else 0
            recall = NR / DR if DR else 0
            F1 = (
                2 * precision * recall / (precision + recall)
                if precision + recall > 0
                else 0
            )
            ident_precision = ident_NP / ident_DP if ident_DP else 0
            ident_recall = ident_NR / ident_DR if ident_DR else 0
            ident_F1 = (
                2 * ident_precision * ident_recall / (ident_precision + ident_recall)
                if ident_precision + ident_recall > 0
                else 0
            )
            score[metric_name] = (precision, recall, F1)
            ident = (
                ident_precision,
                ident_recall,
                ident_F1,
                ident_NR,
                ident_DR,
                ident_NP,
                ident_DP,
            )
        F1_conll = sum([score[metric][2] for metric in CONLL_METRICS]) / len(
            CONLL_METRICS
        )
        print(
            "Mention identification recall",
            ident[1],
            "<= Detected mentions",
            ident[3],
            "True mentions",
            ident[4],
        )
        print("Scores", score)
        print("F1_conll", F1_conll)
        return score, F1_conll, ident
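

# Example usage (sketch): how this evaluator is typically wired up during training.
# The construction of `model`, `dataset` (from the conllparser output) and `args`
# is assumed to follow the training scripts and is not defined in this module.
#
#     evaluator = ConllEvaluator(model, dataset, test_data_path, test_key_file,
#                                embed_path, args)
#     evaluator.test_model()        # sanity check: score with every mention kept
#     evaluator.build_test_file()   # run the model and write OUT_PATH
#     score, f1_conll, ident = evaluator.get_score()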