""" Generates synthetic data from corpora consisting of individual sentences, such as the SICK corpus by replacing 
random words in each sentence with one of their synonyms found in WordNet. Implemented extension strategy owes to:
[1] Mueller et al., "Siamese Recurrent Architectures for Learning Sentence Similarity." 
[2] Zhang et al., "Character-level convolutional networks for text classification."
The extensions are, as expected, reasonably noisy.
"""

import os
import numpy as np
import pandas as pd

import nltk
from nltk import word_tokenize
from nltk.corpus import wordnet

from pywsd.lesk import simple_lesk, cosine_lesk, adapted_lesk
import kenlm


class SickExtender(object):
    """ Extends the SICK sentence similarity corpus with synthetic data generated by substituting synonyms for 
    random content words. Synonyms are obtained via WordNet's synset.lemmas() lookup following the sense disambiguation 
    of the word to be replaced which, in turn, relies on the specified Lesk algorithm - simple, cosine, or adapted.
    Refer to the pywsd documentation for further information. """
    def __init__(self, sick_path, target_directory, lm_path=None, wsd_algorithm='cosine', sampling_parameter=0.5,
                 min_substitutions=2, num_candidates=5, concatenate_corpora=True):
        self.sick_path = sick_path
        self.target_directory = target_directory
        self.lm_path = lm_path
        self.wsd_algorithm = wsd_algorithm
        self.sampling_parameter = sampling_parameter
        self.min_substitutions = min_substitutions
        self.num_candidates = num_candidates
        self.concatenate_corpora = concatenate_corpora
        self.filtered_path = os.path.join(self.target_directory, 'filtered_sick.txt')
        self.noscore_path = os.path.join(self.target_directory, 'noscore_sick.txt')
        # Filter the original SICK corpus to match the expected format and create a file for LM training
        if not os.path.exists(self.filtered_path) or not os.path.exists(self.noscore_path):
            self.filter_sick()
        if self.lm_path is None:
            raise ValueError('No language model provided! Use the noscore_sick corpus to train a .klm LM first.')
        else:
            self.language_model = kenlm.LanguageModel(self.lm_path)
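
    # The binary .klm model expected in lm_path can be built from the generated noscore_sick.txt
    # with KenLM's command-line tools, e.g. (a rough sketch; the n-gram order and file names are
    # placeholders):
    #   lmplz -o 3 < noscore_sick.txt > sick_lm.arpa
    #   build_binary sick_lm.arpa sick_lm.klm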

    def create_extension(self):
        """ Replaces some words within each line of the given file with their WordNet synonyms. Replacement 
        limited to noun, verb, adj, and adv, as those are the POS tags utilized by WordNet."""
        # Track the number of sentence pairs processed so far
        counter = 0
        # Create path to the SICK extension corpus
        if self.concatenate_corpora:
            target_path = os.path.join(self.target_directory, 'extended_sick.txt')
        else:
            target_path = os.path.join(self.target_directory, 'sick_extension.txt')
        # Generate paraphrases via thesaurus-based replacement
        print('Commencing with the creation of the synthetic SICK examples.')
        with open(self.filtered_path, 'r') as rf:
            with open(target_path, 'w') as wf:
                for line in rf:
                    # Get tokens and POS tags, i.e. sentences == [sent1, sent2]
                    sentences, sim_score = self.line_prep(line)
                    new_line = list()
                    for sentence in sentences:
                        # Store tokens for subsequent reconstruction
                        tokens = sentence[1]
                        # Get the most likely synset for each token
                        disambiguation = self.disambiguate_synset(sentence)
                        # Replace random words with random synonyms
                        candidate_list = self.replace_with_synonyms(disambiguation)
                        if candidate_list is None:
                            continue
                        paraphrase = self.pick_candidate(tokens, candidate_list)
                        new_line.append(paraphrase)
                    # Skip the pair unless a paraphrase could be generated for both sentences
                    if len(new_line) < 2:
                        continue
                    # Add header
                    # wf.write('sentence_A\tsentence_B\trelatedness_score')
                    if self.concatenate_corpora:
                        wf.write(line)
                        wf.write(new_line[0] + '\t' + new_line[1] + '\t' + sim_score)
                    else:
                        wf.write(new_line[0] + '\t' + new_line[1] + '\t' + sim_score)

                    # Basic bookkeeping
                    counter += 1
                    if counter % 50 == 0 and counter != 0:
                        print('Current progress: Line %d.' % counter)

                    # For quick testing
                    # if counter % 50 == 0 and counter != 0:
                    #    break

        print('The extension sentences for the SICK corpus have been successfully generated.\n'
              'They can be found under %s.\n'
              'Total number of new sentence pairs: %d.' % (target_path, counter))

    def filter_sick(self):
        """ Processes the original S.I.C.K. corpus into a format where each line contains the two compared sentences
        followed by their relatedness score. """
        # Filter the SICK dataset for sentences and relatedness score only
        df_origin = pd.read_table(self.sick_path)
        df_classify = df_origin.loc[:, ['sentence_A', 'sentence_B', 'relatedness_score']]
        # Scale the relatedness score to lie in [0, 1] for the training of the classifier
        df_classify['relatedness_score'] = df_classify['relatedness_score'].apply(
            lambda x: "{:.4f}".format(float(x)/5.0))

        df_noscore = df_origin.loc[:, ['sentence_A', 'sentence_B']]
        df_noscore = df_noscore.stack()

        # Write the filtered set to a tab-separated file
        df_classify.to_csv(self.filtered_path, sep='\t', index=False, header=False)
        print('Filtered corpus saved to %s.' % self.filtered_path)

        # Write a score-free set to a file to be used for training the KenLM language model
        df_noscore.to_csv(self.noscore_path, index=False, header=False)
        print('Filtered corpus saved to %s.' % self.noscore_path)
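        # Each line of filtered_sick.txt then has the form (an illustrative example, not an
        # actual corpus entry):
        #   A man is playing a guitar<TAB>A man is playing an instrument<TAB>0.8400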

    def line_prep(self, line):
        """ Tokenizes and POS-tags a line from the SICK corpus to be compatible with WordNet synset lookup. """
        # Split line into sentences + score
        s1, s2, sim_score = line.split('\t')
        # Tokenize
        s1_tokens = word_tokenize(s1)
        s2_tokens = word_tokenize(s2)
        # Assign part of speech tags
        s1_penn_pos = nltk.pos_tag(s1_tokens)
        s2_penn_pos = nltk.pos_tag(s2_tokens)
        # Convert to WordNet POS tags and store word position in sentence for replacement
        # Each tuple contains (word, WordNet_POS_tag, position)
        s1_wn_pos = list()
        s2_wn_pos = list()
        for idx, item in enumerate(s1_penn_pos):
            if self.get_wordnet_pos(item[1]) != 'OTHER':
                s1_wn_pos.append((item[0], self.get_wordnet_pos(item[1]), idx))
        for idx, item in enumerate(s2_penn_pos):
            if self.get_wordnet_pos(item[1]) != 'OTHER':
                s2_wn_pos.append((item[0], self.get_wordnet_pos(item[1]), idx))

        # Each tuple contains (word, WordNet_POS_tag, position); the token lists are also returned for use as
        # disambiguation context and for paraphrase reconstruction
        return [(s1_wn_pos, s1_tokens), (s2_wn_pos, s2_tokens)], sim_score

    def disambiguate_synset(self, sentence_plus_lemmas):
        """ Picks the most likely synset for a lemma provided the context sentence and target word. Utilizes
        the 'Cosine Lesk' algorithm provided by pywds. """
        # Select the disambiguation algorithm
        if self.wsd_algorithm == 'simple':
            wsd_function = simple_lesk
        elif self.wsd_algorithm == 'cosine':
            wsd_function = cosine_lesk
        elif self.wsd_algorithm == 'adapted':
            wsd_function = adapted_lesk
        else:
            raise ValueError('Please specify the word sense disambiguation algorithm:\n '
                             '\'simple\' for Simple Lesk\n'
                             '\'cosine\' for Cosine Lesk\n'
                             '\'adapted\' for Adapted/Extended Lesk')

        lemmas, context = sentence_plus_lemmas
        context = ' '.join(context)
        disambiguated = list()
        for lemma in lemmas:
            try:
                selection = wsd_function(context, lemma[0], pos=lemma[1])
            # Catch the IndexError that simple_lesk raises when no synsets can be found for the word
            except IndexError:
                selection = None
            disambiguated.append((lemma[0], selection, lemma[2]))
        return disambiguated

    def replace_with_synonyms(self, disambiguated_lemmas):
        """ Calculates the distance between a lemma and each of its synonyms and orders them in a list by increasing 
        distance. Uses the """
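        # (For reference: synset.lemma_names() lists all surface forms of a synset, e.g.
        # wordnet.synset('dog.n.01').lemma_names() -> ['dog', 'domestic_dog', 'Canis_familiaris'];
        # underscores in multi-word lemmas are replaced with spaces below.)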
        all_synonyms = list()
        # Obtain WordNet synonyms for each lemma in the sentence list
        for idx, lemma in enumerate(disambiguated_lemmas):
            if lemma[1] is not None:
                if len(lemma[1].lemma_names()) > 1:
                    synonyms_per_word = ([' '.join(s.split('_')) for s in lemma[1].lemma_names()], idx)
                    all_synonyms.append(synonyms_per_word)

        # If the sentence cannot be modified, skip it
        if len(all_synonyms) == 0:
            return None

        # Model a geometric distribution with parameter p, following Zhang, Zhao, and LeCun (2015)
        lower_bound = max(min(self.min_substitutions, len(all_synonyms)), 1)
        distribution = {i: self.sampling_parameter ** i for i in range(lower_bound, len(all_synonyms) + 1)}
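        # E.g. with sampling_parameter=0.5 and lower_bound=2, the unnormalized weights are
        # {2: 0.25, 3: 0.125, ...}, i.e. candidates with fewer substitutions are sampled more often.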
        sampling_array = list()
        for key, probability in distribution.items():
            # Approximate the distribution by repeating each substitution count proportionally to its weight
            occurrences = int(np.round(probability * 1000))
            sampling_array.extend([key] * occurrences)

        # Sample n substitutions
        outputs = list()
        no_subs = [(l[0], l[2]) for l in disambiguated_lemmas]
        for _ in range(self.num_candidates):
            syn_list = all_synonyms[:]
            candidate = no_subs[:]
            # Randomly pick the number of words to replace with synonyms
            pick = np.random.randint(0, len(sampling_array))
            to_replace = sampling_array[pick]
            # Perform replacement
            for __ in range(to_replace):
                # Randomly pick the word to be replaced
                j = np.random.randint(0, len(syn_list))
                # Randomly pick the synonym to replace the word with
                k = np.random.randint(0, len(syn_list[j][0]))
                candidate[syn_list[j][1]] = (syn_list[j][0][k], disambiguated_lemmas[syn_list[j][1]][2])
                # Remove the sampled synonym set
                del syn_list[j]
            outputs.append(candidate)
        return outputs

    def pick_candidate(self, tokens, candidate_list):
        """ Picks the most probable paraprase candidate according to the provided language model. """
        best_paraphrase = None
        best_score = None

        # Reconstruct and rate paraphrases
        for candidate in candidate_list:
            for replacement in candidate:
                tokens[replacement[1]] = replacement[0]
            paraphrase = ' '.join(tokens)
            # KenLM returns a log10 probability, so a higher (less negative) score means a more probable sentence
            score = self.language_model.score(paraphrase)
            # Keep the most probable candidate
            if best_score is None or score > best_score:
                best_score = score
                best_paraphrase = paraphrase

        return best_paraphrase

    @staticmethod
    def get_wordnet_pos(treebank_tag):
        """ Converts a Penn Tree-Bank part of speech tag into a corresponding WordNet-friendly tag. 
        Borrowed from: http://stackoverflow.com/questions/15586721/wordnet-lemmatization-and-pos-tagging-in-python. """
        if treebank_tag.startswith('J') or treebank_tag.startswith('A'):
            return wordnet.ADJ
        elif treebank_tag.startswith('V'):
            return wordnet.VERB
        elif treebank_tag.startswith('N'):
            return wordnet.NOUN
        elif treebank_tag.startswith('R'):
            return wordnet.ADV
        else:
            return 'OTHER'
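

# Minimal usage sketch (the paths below are placeholders; the KenLM model must be trained
# beforehand on the noscore_sick.txt file produced by filter_sick()):
if __name__ == '__main__':
    extender = SickExtender(sick_path='./SICK.txt', target_directory='./sick_data',
                            lm_path='./sick_lm.klm', wsd_algorithm='cosine')
    extender.create_extension()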