#  Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

# Evidence dictionaries map each taxonomy name (task, dataset or metric) to a
# list of normalized evidence strings for that name.
from collections import Counter

from axcell.models.linking.acronym_extractor import AcronymExtractor
from axcell.models.linking.probs import get_probs, reverse_probs
from axcell.models.linking.utils import normalize_dataset, normalize_dataset_ws, normalize_cell, normalize_cell_ws
from scipy.special import softmax
import re
import pandas as pd
import numpy as np
import json
import ahocorasick
from numba import njit, typed, types
from pathlib import Path

from axcell.pipeline_logger import pipeline_logger

from axcell.models.linking import manual_dicts


def dummy_item(reason):
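    """Return a one-row placeholder linking result carrying `reason` as the
    task/dataset/metric, with empty evidence and zero confidence."""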
    return pd.DataFrame(dict(dataset=[reason], task=[reason], metric=[reason], evidence=[""], confidence=[0.0]))


class EvidenceFinder:
    single_letter_re = re.compile(r"\b\w\b")  # single-character words
    init_letter_re = re.compile(r"\b\w")      # first character of each word
    end_letter_re = re.compile(r"\w\b")       # last character of each word
    letter_re = re.compile(r"\w")             # any word character

    def __init__(self, taxonomy, abbreviations_path=None, use_manual_dicts=False):
        self.abbreviations_path = abbreviations_path
        self.use_manual_dicts = use_manual_dicts
        self._init_structs(taxonomy)

    @staticmethod
    def evidences_from_name(key):
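        """Derive evidence strings from a taxonomy name: the whitespace-
        normalized full name plus each of its non-stop-words, the latter
        only when the name contains more than one significant word.
        """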
        x = normalize_dataset_ws(key)
        y = [w for w in x.split() if w not in manual_dicts.stop_words]
        return [x] + y if len(y) > 1 else [x]

    @staticmethod
    def get_basic_dicts(taxonomy):
        tasks = {ts: [normalize_dataset_ws(ts)] for ts in taxonomy.tasks}
        datasets = {ds: EvidenceFinder.evidences_from_name(ds) for ds in taxonomy.datasets}
        metrics = {ms: EvidenceFinder.evidences_from_name(ms) for ms in taxonomy.metrics}
        return tasks, datasets, metrics

    @staticmethod
    def merge_evidences(target, source):
        for name, evs in source.items():
            target.setdefault(name, []).extend(evs)

    @staticmethod
    def make_trie(names):
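        """Build an Aho-Corasick automaton over whitespace-stripped names.

        Each entry maps the concatenated form of a name to a tuple of
        (length of the stripped form, original name), which find_names uses
        to check word-boundary alignment and recover the canonical name.
        """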
        trie = ahocorasick.Automaton()
        for name in names:
            norm = name.replace(" ", "")
            trie.add_word(norm, (len(norm), name))
        trie.make_automaton()
        return trie

    @staticmethod
    def get_auto_evidences(name, abbreviations, abbrvs_trie):
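        """Gather evidence strings for `name` from the abbreviations
        dictionary: every abbreviation key found (at word boundaries) in the
        normalized name contributes its listed surface forms; duplicates are
        removed.
        """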
        frags = EvidenceFinder.find_names(normalize_dataset_ws(name), abbrvs_trie)
        evidences = []
        for f in frags:
            evidences.extend(abbreviations[f])
        return list(set(evidences))

    @staticmethod
    def find_names(text, names_trie):
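        """Count occurrences of known names in `text`, ignoring whitespace.

        A character-level profile of the text is built first: each word's
        first letter is marked 'b', its last letter 'e', single-letter words
        'x' and interior letters 'i'. Spaces are then removed from both the
        text and the profile, so trie entries can match across word breaks,
        but a match only counts if it starts at a word beginning ('b'/'x')
        and ends at a word end ('e'/'x').

        Illustrative sketch (assuming a trie built from ["coco"]): the text
        "on coco dataset" yields one hit for "coco", while "cocoa" yields
        none, since that match does not end on a word boundary.

        Returns a Counter mapping matched names to occurrence counts.
        """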
        text = text.lower()
        profile = EvidenceFinder.letter_re.sub("i", text)
        profile = EvidenceFinder.init_letter_re.sub("b", profile)
        profile = EvidenceFinder.end_letter_re.sub("e", profile)
        profile = EvidenceFinder.single_letter_re.sub("x", profile)
        text = text.replace(" ", "")
        profile = profile.replace(" ", "")
        s = Counter()
        for (end, (l, word)) in names_trie.iter(text):
            if profile[end] in ['e', 'x'] and profile[end - l + 1] in ['b', 'x']:
                s[word] += 1
        return s

    def find_datasets(self, text):
        return EvidenceFinder.find_names(text, self.all_datasets_trie)

    def find_metrics(self, text):
        return EvidenceFinder.find_names(text, self.all_metrics_trie)

    def find_tasks(self, text):
        return EvidenceFinder.find_names(text, self.all_tasks_trie)

    def init_evidence_dicts(self, taxonomy):
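        """Build the task/dataset/metric evidence dictionaries from the
        taxonomy, optionally merging in hand-curated dictionaries and
        abbreviation-derived evidences, and extend each dataset's evidences
        with split keywords ('test', or 'validation'/'dev'/'development' for
        datasets whose name mentions a validation split).
        """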
        self.tasks, self.datasets, self.metrics = EvidenceFinder.get_basic_dicts(taxonomy)

        if self.use_manual_dicts:
            EvidenceFinder.merge_evidences(self.tasks, manual_dicts.tasks)
            EvidenceFinder.merge_evidences(self.datasets, manual_dicts.datasets)
            EvidenceFinder.merge_evidences(self.metrics, manual_dicts.metrics)

        if self.abbreviations_path is not None:
            with Path(self.abbreviations_path).open('rt') as f:
                abbreviations = json.load(f)
            abbrvs_trie = EvidenceFinder.make_trie(list(abbreviations.keys()))

            ds_auto = {x: EvidenceFinder.get_auto_evidences(x, abbreviations, abbrvs_trie) for x in taxonomy.datasets}
            ms_auto = {x: EvidenceFinder.get_auto_evidences(x, abbreviations, abbrvs_trie) for x in taxonomy.metrics}

            EvidenceFinder.merge_evidences(self.datasets, ds_auto)
            EvidenceFinder.merge_evidences(self.metrics, ms_auto)

        self.datasets = {
            k: v + (['validation', 'dev', 'development'] if 'val' in k else ['test'])
            for k, v in self.datasets.items()
        }
        if self.use_manual_dicts:
            self.datasets.update({
                'LibriSpeech dev-clean': ['libri speech dev clean', 'libri speech', 'dev', 'clean', 'dev clean', 'development'],
                'LibriSpeech dev-other': ['libri speech dev other', 'libri speech', 'dev', 'other', 'dev other', 'development', 'noisy'],
            })

    def _init_structs(self, taxonomy):
        self.init_evidence_dicts(taxonomy)

        self.datasets = {k: set(v) for k, v in self.datasets.items()}
        self.metrics = {k: set(v) for k, v in self.metrics.items()}
        self.tasks = {k: set(v) for k, v in self.tasks.items()}

        self.all_datasets = set(normalize_cell_ws(normalize_dataset(y)) for x in self.datasets.values() for y in x)
        self.all_metrics = set(normalize_cell_ws(y) for x in self.metrics.values() for y in x)
        self.all_tasks = set(normalize_cell_ws(normalize_dataset(y)) for x in self.tasks.values() for y in x)

        self.all_datasets_trie = EvidenceFinder.make_trie(self.all_datasets)
        self.all_metrics_trie = EvidenceFinder.make_trie(self.all_metrics)
        self.all_tasks_trie = EvidenceFinder.make_trie(self.all_tasks)


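# Log-likelihood of the evidences found in a context for a single taxonomy
# label along one axis (task, dataset or metric). Each distinct evidence
# contributes
#     min(count, max_repetitions) * log(noise * pb + (1 - noise) * p)
# where p is the probability the reversed evidence model assigns to the
# (label, evidence) pair and pb acts as a uniform background probability.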
@njit
def axis_logprobs(evidences_for, reverse_probs, found_evidences, noise, pb, max_repetitions):
    logprob = 0.0
    empty = typed.Dict.empty(types.unicode_type, types.float64)
    short_probs = reverse_probs.get(evidences_for, empty)
    for evidence, count in found_evidences.items():
        logprob += min(count, max_repetitions) * np.log(noise * pb + (1 - noise) * short_probs.get(evidence, 0.0))
    return logprob


# compute log-probabilities of all taxonomy entries (and of the individual
# per-axis labels) given the evidences found in a single context
@njit
def compute_logprobs(taxonomy, tasks, datasets, metrics,
                     reverse_merged_p, reverse_metrics_p, reverse_task_p,
                     dss, mss, tss, noise, ms_noise, ts_noise, ds_pb, ms_pb, ts_pb,
                     max_repetitions):
    task_cache = typed.Dict.empty(types.unicode_type, types.float64)
    dataset_cache = typed.Dict.empty(types.unicode_type, types.float64)
    metric_cache = typed.Dict.empty(types.unicode_type, types.float64)
    logprobs = np.zeros(len(taxonomy))
    axes_logprobs = (
        np.zeros(len(tasks)),
        np.zeros(len(datasets)),
        np.zeros(len(metrics))
    )
    for i, (task, dataset, metric) in enumerate(taxonomy):
        if dataset not in dataset_cache:
            dataset_cache[dataset] = axis_logprobs(dataset, reverse_merged_p, dss, noise, ds_pb, 1)
        if metric not in metric_cache:
            metric_cache[metric] = axis_logprobs(metric, reverse_metrics_p, mss, ms_noise, ms_pb, 1)
        if task not in task_cache:
            task_cache[task] = axis_logprobs(task, reverse_task_p, tss, ts_noise, ts_pb, max_repetitions)

        logprobs[i] += dataset_cache[dataset] + metric_cache[metric] + task_cache[task]
    for i, task in enumerate(tasks):
        axes_logprobs[0][i] += task_cache[task]

    for i, dataset in enumerate(datasets):
        axes_logprobs[1][i] += dataset_cache[dataset]

    for i, metric in enumerate(metrics):
        axes_logprobs[2][i] += metric_cache[metric]
    return logprobs, axes_logprobs


def _to_typed_list(iterable):
    l = typed.List()
    for i in iterable:
        l.append(i)
    return l


class ContextSearch:
    def __init__(self, taxonomy, evidence_finder,
                 context_noise=(0.99, 1.0, 1.0, 0.25, 0.01),
                 metric_noise=(0.99, 1.0, 1.0, 0.25, 0.01),
                 task_noise=(0.1, 1.0, 1.0, 0.1, 0.1),
                 ds_pb=0.001, ms_pb=0.01, ts_pb=0.01,
                 include_independent=True, debug_gold_df=None):
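        """Probabilistic linker from table cells to taxonomy entries.

        The five entries of `context_noise` (and of `metric_noise` and
        `task_noise`) are the noise levels applied to the five contexts
        scored in `match`: paper text, abstract, table mentions, caption and
        the cell query itself.
        """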
        merged_p = get_probs({
            k: Counter([normalize_cell(normalize_dataset(x)) for x in v])
            for k, v in evidence_finder.datasets.items()
        })[1]
        metrics_p = get_probs({
            k: Counter([normalize_cell(normalize_dataset(x)) for x in v])
            for k, v in evidence_finder.metrics.items()
        })[1]
        tasks_p = get_probs({
            k: Counter([normalize_cell(normalize_dataset(x)) for x in v])
            for k, v in evidence_finder.tasks.items()
        })[1]

        # todo: use LRU cache to avoid OOM
        self.queries = {}
        self.logprobs_cache = {}
        self.taxonomy = taxonomy
        self.evidence_finder = evidence_finder

        self._taxonomy = _to_typed_list(self.taxonomy.taxonomy)
        self._taxonomy_tasks = _to_typed_list(self.taxonomy.tasks)
        self._taxonomy_datasets = _to_typed_list(self.taxonomy.datasets)
        self._taxonomy_metrics = _to_typed_list(self.taxonomy.metrics)

        self.extract_acronyms = AcronymExtractor()
        self.context_noise = context_noise
        self.metrics_noise = metric_noise if metric_noise else context_noise
        self.task_noise = task_noise if task_noise else context_noise
        self.ds_pb = ds_pb
        self.ms_pb = ms_pb
        self.ts_pb = ts_pb
        self.reverse_merged_p = self._numba_update_nested_dict(reverse_probs(merged_p))
        self.reverse_metrics_p = self._numba_update_nested_dict(reverse_probs(metrics_p))
        self.reverse_tasks_p = self._numba_update_nested_dict(reverse_probs(tasks_p))
        self.debug_gold_df = debug_gold_df
        self.max_repetitions = 3
        self.include_independent = include_independent

    def _numba_update_nested_dict(self, nested):
        d = typed.Dict()
        for key, dct in nested.items():
            d2 = typed.Dict()
            d2.update(dct)
            d[key] = d2
        return d

    def _numba_extend_list(self, lst):
        l = typed.List.empty_list((types.unicode_type, types.int32))
        for x in lst:
            l.append(x)
        return l

    def _numba_extend_dict(self, dct):
        d = typed.Dict.empty(types.unicode_type, types.int64)
        d.update(dct)
        return d

    def _hash_counter(self, d):
        """Serialize a counter into a deterministic string usable as a cache key."""
        items = sorted(d.items())
        return ";".join(x[0] + ":" + str(x[1]) for x in items)

    def compute_context_logprobs(self, context, noise, ms_noise, ts_noise, logprobs, axes_logprobs):
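        """Accumulate the log-probability contributions of a single context.

        `context` is either raw text, from which dataset/metric/task
        evidences are extracted (metric and task hits are subtracted from the
        dataset counts), or an already-extracted (tasks, datasets, metrics)
        triple of Counters. Per-context results are cached and added in place
        to `logprobs` and `axes_logprobs`.
        """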
        if isinstance(context, str) or context is None:
            context = context or ""
            context = normalize_cell_ws(normalize_dataset_ws(context))
            dss = self.evidence_finder.find_datasets(context)
            mss = self.evidence_finder.find_metrics(context)
            tss = self.evidence_finder.find_tasks(context)

            dss -= mss
            dss -= tss
        else:
            tss, dss, mss = context

        dss = {normalize_cell(ds): count for ds, count in dss.items()}
        mss = {normalize_cell(ms): count for ms, count in mss.items()}
        tss = {normalize_cell(ts): count for ts, count in tss.items()}
        ###print("dss", dss)
        ###print("mss", mss)
        dss = self._numba_extend_dict(dss)
        mss = self._numba_extend_dict(mss)
        tss = self._numba_extend_dict(tss)

        key = (self._hash_counter(tss), self._hash_counter(dss), self._hash_counter(mss), noise, ms_noise, ts_noise)
        if key not in self.logprobs_cache:
            lp, alp = compute_logprobs(self._taxonomy, self._taxonomy_tasks, self._taxonomy_datasets, self._taxonomy_metrics,
                             self.reverse_merged_p, self.reverse_metrics_p, self.reverse_tasks_p,
                             dss, mss, tss, noise, ms_noise, ts_noise, self.ds_pb, self.ms_pb, self.ts_pb,
                             self.max_repetitions)
            self.logprobs_cache[key] = (lp, alp)
        else:
            lp, alp = self.logprobs_cache[key]
        logprobs += lp
        axes_logprobs[0] += alp[0]
        axes_logprobs[1] += alp[1]
        axes_logprobs[2] += alp[2]

    def match(self, contexts):
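        """Score all taxonomy entries against the given contexts and return
        four iterables of (key, probability) pairs: the joint
        (task, dataset, metric) distribution followed by the independent
        per-axis distributions over tasks, datasets and metrics.
        """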
        assert len(contexts) == len(self.context_noise)
        n = len(self._taxonomy)
        context_logprobs = np.zeros(n)
        axes_context_logprobs = _to_typed_list([
            np.zeros(len(self._taxonomy_tasks)),
            np.zeros(len(self._taxonomy_datasets)),
            np.zeros(len(self._taxonomy_metrics)),
        ])

        for context, noise, ms_noise, ts_noise in zip(contexts, self.context_noise, self.metrics_noise, self.task_noise):
            self.compute_context_logprobs(context, noise, ms_noise, ts_noise, context_logprobs, axes_context_logprobs)
        keys = self.taxonomy.taxonomy
        logprobs = context_logprobs
        probs = softmax(np.array(logprobs))
        axes_probs = [softmax(np.array(a)) for a in axes_context_logprobs]
        return (
            zip(keys, probs),
            zip(self._taxonomy_tasks, axes_probs[0]),
            zip(self._taxonomy_datasets, axes_probs[1]),
            zip(self._taxonomy_metrics, axes_probs[2])
        )

    def __call__(self, query, paper_context, abstract_context, table_context, caption, topk=1, debug_info=None):
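        """Link a single cell to taxonomy entries.

        Returns a DataFrame with the top-k (task, dataset, metric) proposals,
        their confidence and the normalized metric name. Results are cached
        on the (contexts, caption, query, topk) key; when
        `include_independent` is set, an extra candidate assembled from the
        per-axis winners is appended with a fixed confidence of 0.79.
        """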
        cellstr = debug_info.cell.cell_ext_id
        pipeline_logger("linking::taxonomy_linking::call", ext_id=cellstr, query=query,
                        paper_context=paper_context, abstract_context=abstract_context, table_context=table_context,
                        caption=caption)

        paper_hash = ";".join(",".join(sorted(s.elements())) for s in paper_context)
        abstract_hash = ";".join(",".join(sorted(s.elements())) for s in abstract_context)
        mentions_hash = ";".join(",".join(sorted(s.elements())) for s in table_context)
        key = (paper_hash, abstract_hash, mentions_hash, caption, query, topk)
        ###print(f"[DEBUG] {cellstr}")
        ###print("[DEBUG]", debug_info)
        ###print("query:", query, caption)
        if key in self.queries:
            p = self.queries[key]
        else:
            dists = self.match((paper_context, abstract_context, table_context, caption, query))

            all_top_results = [sorted(list(dist), key=lambda x: x[1], reverse=True)[:max(topk, 5)] for dist in dists]
            top_results, top_results_t, top_results_d, top_results_m = all_top_results

            entries = []
            for it, prob in top_results:
                task, dataset, metric = it
                entry = dict(task=task, dataset=dataset, metric=metric)
                entry.update({"evidence": "", "confidence": prob})
                entries.append(entry)

            if self.include_independent:
                best_independent = dict(
                    task=top_results_t[0][0],
                    dataset=top_results_d[0][0],
                    metric=top_results_m[0][0])
                best_independent.update({
                    "evidence": "",
                    "confidence": 0.79
                })
                entries.append(best_independent)

            p = pd.DataFrame(entries).sort_values("confidence", ascending=False)

            self.queries[key] = p

        # error analysis only
        if self.debug_gold_df is not None:
            if cellstr in self.debug_gold_df.index:
                gold_record = self.debug_gold_df.loc[cellstr]
                if p.iloc[0].dataset == gold_record.dataset:
                    print("[EA] Matching gold sota record (dataset)")
                else:
                    print(
                        f"[EA] Proposal dataset ({p.iloc[0].dataset}) and gold dataset ({gold_record.dataset}) mismatch")
            else:
                print("[EA] No gold sota record found for the cell")
        # end of error analysis only
        pipeline_logger("linking::taxonomy_linking::topk", ext_id=cellstr, topk=p.head(5))

        q = p.head(topk).copy()
        q["true_metric"] = q.apply(lambda row: self.taxonomy.normalize_metric(row.task, row.dataset, row.metric), axis=1)
        return q


# todo: compare regex approach (old) with find_datasets(.) (current)
# todo: rename it
class DatasetExtractor:
    def __init__(self, evidence_finder):
        self.evidence_finder = evidence_finder
        self.dataset_prefix_re = re.compile(r"[A-Z]|[a-z]+[A-Z]+|[0-9]")
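        # matches phrases like "the <name> dataset" or "the <name> test data set":
        # up to 10 words after "the", with an optional split qualifier
        # (test/validation/dev/train) before "data set"/"dataset"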
        self.dataset_name_re = re.compile(r"\b(the)\b\s*(?P<name>((?!(the)\b)\w+\W+){1,10}?)(test|val(\.|idation)?|dev(\.|elopment)?|train(\.|ing)?\s+)?\bdata\s*set\b", re.IGNORECASE)

    def find_references(self, text, references):
        refs = r"\bxxref-(" + "|".join([re.escape(ref) for ref in references]) + r")\b"
        return set(re.findall(refs, text))

    def get_table_contexts(self, paper, tables):
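        """Aggregate evidence for each table from the text that references it.

        For every fragment mentioning a table via its `xxref-<figure_id>`
        marker, the extracted (task, dataset, metric) Counters are added to
        that table's context; tables without resolvable references get empty
        Counters.
        """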
        ref_tables = [table for table in tables if table.figure_id and table.figure_id.replace(".", "")]
        refs = [table.figure_id.replace(".", "") for table in ref_tables]
        if not refs:
            return [[Counter(), Counter(), Counter()] for table in tables]
        ref_contexts = {ref: [Counter(), Counter(), Counter()] for ref in refs}
        if hasattr(paper.text, "fragments"):
            for fragment in paper.text.fragments:
                found_refs = self.find_references(fragment.text, refs)
                if found_refs:
                    ts, ds, ms = self(fragment.header + "\n" + fragment.text)
                    for ref in found_refs:
                        ref_contexts[ref][0] += ts
                        ref_contexts[ref][1] += ds
                        ref_contexts[ref][2] += ms
        table_contexts = [
            ref_contexts.get(
                table.figure_id.replace(".", ""),
                [Counter(), Counter(), Counter()]
            ) if table.figure_id else [Counter(), Counter(), Counter()]
            for table in tables
        ]
        return table_contexts

    def from_paper(self, paper):
        abstract = paper.text.abstract
        text = ""
        if hasattr(paper.text, "fragments"):
            text += " ".join(f.text for f in paper.text.fragments)
        return self(text), self(abstract)

    def __call__(self, text):
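        """Extract (task, dataset, metric) evidence Counters from raw text,
        subtracting task and metric hits from the dataset counts."""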
        text = normalize_cell_ws(normalize_dataset_ws(text))
        ds = self.evidence_finder.find_datasets(text)
        ts = self.evidence_finder.find_tasks(text)
        ms = self.evidence_finder.find_metrics(text)
        ds -= ts
        ds -= ms
        return ts, ds, ms