from IPython.core.display import display, HTML
import matplotlib.pylab as plt
import matplotlib
import pandas as pd

from utils import get_grouping

# Not used in this file, but useful for jupyter notebooks
import os.path as op
import filenames
import seaborn as sns

# Shared look for every figure produced from this module: start from the
# 'ggplot' style sheet, then override the axes background to white, bump the
# default font size, and turn on figure.autolayout.
matplotlib.style.use('ggplot')
for _rc_key, _rc_value in (('axes.facecolor', 'white'),
                           ('font.size', 14),
                           ('figure.autolayout', True)):
    # Item assignment (not dict.update) so matplotlib validates each value.
    matplotlib.rcParams[_rc_key] = _rc_value
del _rc_key, _rc_value

# (task identifier, human-readable label) pairs, in display order.
task_pretty_names = [
    ('predict_number_only_nouns',
     'Number prediction baseline (common nouns)'),
    ('predict_number_only_generalized_nouns',
     'Number prediction baseline (all nouns)'),
    ('predict_number', 'Number prediction (LSTM)'),
    ('predict_number_ensemble', 'Number prediction (LSTM: ensemble)'),
    ('predict_number_srn', 'Number prediction (SRN)'),
    ('predict_number_targeted', 'Number prediction (LSTM: targeted)'),
    ('grammaticality', 'Grammaticality judgments (LSTM)'),
    ('grammaticality_targeted', 'Grammaticality judgments (LSTM: targeted)'),
    ('inflect_verb', 'Verb inflection (LSTM)'),
    ('language_model', 'Language modeling (LSTM)'),
]
# Lookup form of the same mapping: task identifier -> label.
task_pretty_names_dict = {task: label for task, label in task_pretty_names}

# (task identifier, short label) pairs, in display order.
# NOTE(review): 'predict_number_ensemble' has an entry in task_pretty_names
# but none here, so task_shortcut_names_dict has no key for it -- confirm
# that is intentional.
task_shortcut_names = [
    ('predict_number_only_nouns', 'Baseline (common nouns)'),
    ('predict_number_only_generalized_nouns', 'Baseline (all nouns)'),
    ('predict_number', 'NumPred'),
    ('predict_number_srn', 'NumPred (SRN)'),
    ('predict_number_targeted', 'NumPredNP (targeted)'),
    ('grammaticality', 'GramJudg'),
    ('grammaticality_targeted', 'GramJudg (targeted)'),
    ('inflect_verb', 'VerbInfl'),
    ('language_model', 'LangMod'),
]
# Lookup form of the same mapping: task identifier -> short label.
task_shortcut_names_dict = {task: short for task, short in task_shortcut_names}

def highlight_dep(dep, show_pos=False):
    '''Display a dependency's sentence as colored HTML (Jupyter notebooks).

    Each whitespace token of ``dep.orig_sentence`` is wrapped in a colored
    ``<span>``:

    * blue  -- the subject and the verb tokens (``dep.subj_index`` and
      ``dep.verb_index`` are treated as 1-based positions);
    * red   -- a token tagged ``NN``/``NNS`` strictly between subject and
      verb whose tag differs from ``dep.subj_pos``;
    * green -- such an intervening token whose tag equals ``dep.subj_pos``;
    * black -- every other token.

    Parameters
    ----------
    dep : row-like object
        Must expose ``orig_sentence``, ``sentence`` and ``pos_sentence``
        (whitespace-tokenized strings) plus ``subj_index``, ``verb_index``
        and ``subj_pos``. ``sentence`` is only zipped alongside the others,
        so rendering stops at the shortest of the three token lists.
    show_pos : bool, optional
        If true, append a second line with the POS tags, colored the same
        way as the corresponding words.

    Returns
    -------
    None. The HTML is shown through IPython's ``display``.

    Words and tags are HTML-escaped. Corpus tokens such as ``<``, ``&`` or
    ``<br>`` are therefore shown literally instead of being parsed as markup.
    '''
    from html import escape

    span = '<span style="color: %s">%s</span>'
    # Convert the 1-based positions to 0-based token indices once.
    subj_i = dep.subj_index - 1
    verb_i = dep.verb_index - 1

    words = []
    tags = []
    tokens = zip(dep.orig_sentence.split(), dep.sentence.split(),
                 dep.pos_sentence.split())
    for i, (tok, _mixed, pos) in enumerate(tokens):
        color = 'black'
        if i == subj_i or i == verb_i:
            color = 'blue'
        elif subj_i < i < verb_i and pos in ['NN', 'NNS']:
            # Intervening common noun: green if it shares the subject's tag.
            color = 'red' if pos != dep.subj_pos else 'green'
        words.append(span % (color, escape(tok)))
        if show_pos:
            tags.append(span % (color, escape(pos)))

    res = ' '.join(words)
    if show_pos:
        res += '<br>' + ' '.join(tags)
    display(HTML(res))


def clean_ticks():
    '''Keep tick marks only on the bottom x axis and the left y axis.

    Acts on the current axes (``plt.gca()``) and returns nothing.
    '''
    axes = plt.gca()
    for axis, side in ((axes.xaxis, 'bottom'), (axes.yaxis, 'left')):
        axis.set_ticks_position(side)

def percent_ylabel():
    '''Label the current axes' y ticks as whole percentages.

    Tick values are read as fractions, so 0.25 is shown as "25%".

    A formatter is installed instead of fixed label strings. The labels
    therefore stay correct if the y limits or tick locations change after
    this call. Values are rounded, not truncated: ``0.57 * 100`` is
    ``56.99999999999999`` in floating point, and plain ``'%d'`` would show
    "56%".
    '''
    from matplotlib.ticker import FuncFormatter

    def _as_percent(y, _pos=None):
        return '%d%%' % int(round(y * 100))

    plt.gca().yaxis.set_major_formatter(FuncFormatter(_as_percent))

def eb(x, field, **kwargs):
    '''Plot ``x['errorprob']`` against ``x[field]`` with asymmetric error bars.

    The bars run from ``x['minconf']`` to ``x['maxconf']``. matplotlib
    expects them as distances below and above the plotted value, so they
    are converted first. Extra keyword arguments go straight to
    ``plt.errorbar``.
    '''
    center = x['errorprob']
    below = center - x['minconf']
    above = x['maxconf'] - center
    plt.errorbar(x[field], center, yerr=[below, above], **kwargs)

def errorplot(x, y, minconf, maxconf, **kwargs):
    '''Plot ``y`` against ``x`` as connected markers ('o-') with error bars.

    The error bars span ``minconf`` to ``maxconf``. They are converted to
    the below/above distances that ``plt.errorbar`` expects. Extra keyword
    arguments are forwarded unchanged, which makes the function usable as a
    ``FacetGrid.map`` callback, e.g.

    g = sns.FacetGrid(attr, col='run', hue='subj_pos', col_wrap=5)
    g = g.map(errorplot, 'n_diff_intervening', 'errorprob',
        'minconf', 'maxconf').add_legend()
    '''
    lower_err = y - minconf
    upper_err = maxconf - y
    plt.errorbar(x, y, yerr=[lower_err, upper_err], fmt='o-', **kwargs)


def add_relclause_annotation(res):
    '''Group dependencies by the kind of relative clause before the verb.

    Steps:

    1. Keep the rows of ``res`` where ``n_diff_intervening`` equals
       ``n_intervening``, working on a copy.
    2. For each row, take the words and POS tags strictly between the
       subject and the verb (``subj_index``/``verb_index`` are treated as
       1-based) and add three boolean columns:

       * ``blacklisted`` -- an ``NNP`` or ``PRP`` tag occurs in that span;
       * ``has_early_relpron`` -- a ``WDT``/``WP``/``WRB``/``WP$`` tag is
         among the first two tags, or the first word is ``that``;
       * ``has_relpron`` -- such a tag occurs anywhere in the span, or the
         first word is ``that``.
    3. Drop blacklisted rows and rows with ``n_intervening >= 4``. Pass the
       rest to ``get_grouping`` with the keys ``n_intervening``,
       ``has_rel``, ``has_relpron`` and ``has_early_relpron``, then reset
       the result's index.
    4. Add a ``condition`` label per group. Drop groups labelled
       ``'Error'``, i.e. those that have ``has_relpron`` but not
       ``has_rel``.

    Parameters
    ----------
    res : pandas.DataFrame
        Needs the columns ``orig_sentence``, ``pos_sentence``,
        ``subj_index``, ``verb_index``, ``n_intervening``,
        ``n_diff_intervening`` and ``has_rel``. ``has_rel`` is read but not
        computed here; verify its origin against the caller.

    Returns
    -------
    pandas.DataFrame
        The reset-index ``get_grouping`` output plus a ``condition`` column,
        excluding ``'Error'`` rows. It presumably includes ``count`` and
        ``errorprob`` columns from ``get_grouping``; TODO confirm.

    ``res`` is not modified. The filtered subset is copied before the flag
    columns are added, so no SettingWithCopyWarning is triggered and no
    global pandas option needs changing.
    '''
    blacklist_tags = frozenset(['NNP', 'PRP'])
    relpron_tags = frozenset(['WDT', 'WP', 'WRB', 'WP$'])
    flag_columns = ('blacklisted', 'has_early_relpron', 'has_relpron')

    def flags_for(row):
        # Tokens strictly between the subject (0-based subj_index - 1) and
        # the verb (0-based verb_index - 1).
        start, stop = row['subj_index'], row['verb_index'] - 1
        words = row['orig_sentence'].split()[start:stop]
        tags = row['pos_sentence'].split()[start:stop]
        first_is_that = words[:1] == ['that']
        return (bool(blacklist_tags.intersection(tags)),
                bool(relpron_tags.intersection(tags[:2])) or first_is_that,
                bool(relpron_tags.intersection(tags)) or first_is_that)

    def condition_for(row):
        if not row['has_rel']:
            # A relative pronoun without a relative clause is inconsistent.
            return 'Error' if row['has_relpron'] else 'No rel'
        if not row['has_relpron']:
            return 'Rel without pronoun'
        if row['has_early_relpron']:
            return 'Rel with early pronoun'
        return 'Rel with late pronoun'

    uniform = res[res.n_diff_intervening == res.n_intervening].copy()
    # A list comprehension, unlike DataFrame.apply, also behaves on an empty
    # frame. Each flag column is built with explicit bool dtype and the
    # frame's own index, so it lines up positionally.
    flags = [flags_for(row) for _, row in uniform.iterrows()]
    for j, name in enumerate(flag_columns):
        uniform[name] = pd.Series([f[j] for f in flags],
                                  index=uniform.index, dtype=bool)

    df = uniform[~uniform.blacklisted]
    rel_groups = get_grouping(df[df.n_intervening < 4],
                              ['n_intervening', 'has_rel', 'has_relpron',
                               'has_early_relpron'])
    rel_groups = rel_groups.reset_index()
    rel_groups['condition'] = [condition_for(row)
                               for _, row in rel_groups.iterrows()]
    return rel_groups[rel_groups.condition != 'Error']