"""Functions to plot and compute results for the literature review.
"""

import os
import logging
import logging.config
from collections import OrderedDict
import subprocess

import pandas as pd
import geopandas as gpd
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
import matplotlib.patches as patches
import matplotlib as mpl
import seaborn as sns
import numpy as np
from PIL import Image
from wordcloud import WordCloud, STOPWORDS
from graphviz import Digraph

import config as cfg
import utils as ut


# Apply the project-wide seaborn style, context and colour palette.
sns.set_style(rc=cfg.axes_styles)
sns.set_context(rc=cfg.plotting_context)
sns.set_palette(cfg.palette)

# seaborn only honours the `rc` keys it considers part of a style/context
# definition, so every configured entry is also written straight into
# matplotlib's rcParams to make sure none of them are dropped.
for rc_dict in (cfg.axes_styles, cfg.plotting_context):
    for key, val in rc_dict.items():
        mpl.rcParams[key] = val


# Initialize logger for saving results and stats. Use `logger.info('message')`
# to log results.
# `disable_existing_loggers: True` disables every logger that already exists
# when this module is imported (e.g. loggers created by libraries imported
# earlier); loggers created afterwards are unaffected.
logging.config.dictConfig({
    'version': 1,
    'disable_existing_loggers': True,
})
# The handler is attached to the root logger, so records from any enabled
# logger that propagates to the root end up in the same results file.
logger = logging.getLogger()
# Results file: <savepath>/results.log. mode='w' truncates it every time this
# module is imported. NOTE(review): assumes the savepath directory already
# exists -- FileHandler opens the file immediately and raises otherwise.
log_savename = os.path.join(cfg.saving_config['savepath'], 'results') + '.log'
handler = logging.FileHandler(log_savename, mode='w')
formatter = logging.Formatter(
        '%(asctime)s %(name)-12s %(levelname)-8s %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)
# INFO is the level used by the result-reporting `logger.info` calls.
logger.setLevel(logging.INFO)


def plot_prisma_diagram(save_cfg=cfg.saving_config):
    """Plot diagram showing the number of selected articles.

    The article counts shown in the boxes are hard-coded. The diagram is
    always built in PDF format.

    Keyword Args:
        save_cfg (dict or None): saving configuration. 'page_width' and
            'page_height' set the graph size and 'savepath' the output
            folder. If None, graphviz chooses the size and nothing is
            rendered to disk.

    Returns:
        (graphviz.Digraph): graphviz object

    TODO:
    - Use first two colors of colormap instead of gray
    - Reduce white space
    - Reduce arrow width
    """
    # save_format = save_cfg['format'] if isinstance(save_cfg, dict) else 'svg'
    save_format = 'pdf'
    # save_format = 'eps'

    # The size depends on the page dimensions, which are only known when a
    # saving configuration is given. The trailing '!' makes graphviz scale
    # the drawing up until one dimension reaches this size.
    size_attr = dict()
    if save_cfg is not None:
        size_attr['size'] = '{},{}!'.format(
            0.5 * save_cfg['page_width'], 0.2 * save_cfg['page_height'])

    dot = Digraph(format=save_format)
    dot.attr('graph', rankdir='TB', overlap='false', **size_attr, margin='0')
    dot.attr('node', fontname='Liberation Sans', fontsize=str(9), shape='box', 
             style='filled', margin='0.15,0.07', penwidth='0.1')
    # dot.attr('edge', arrowsize=0.5)

    fillcolor = 'gray98'

    dot.node('A', 'PubMed (n=39)\nGoogle Scholar (n=409)\narXiv (n=105)', 
             fillcolor='gray95')
    dot.node('B', 'Articles identified\nthrough database\nsearching\n(n=553)', 
             fillcolor=fillcolor)
    # dot.node('B2', 'Excluded\n(n=446)', fillcolor=fillcolor)
    dot.node('C', 'Articles after content\nscreening and\nduplicate removal\n(n=105) ', 
             fillcolor=fillcolor)
    dot.node('D', 'Articles included in\nthe analysis\n(n=154)', 
             fillcolor=fillcolor)
    dot.node('E', 'Additional articles\nidentified through\nbibliography search\n(n=49)', 
             fillcolor=fillcolor)

    dot.edge('B', 'C')
    # dot.edge('B', 'B2')
    dot.edge('C', 'D')
    dot.edge('E', 'D')

    if save_cfg is not None:
        fname = os.path.join(save_cfg['savepath'], 'prisma_diagram')
        # cleanup=False keeps the DOT source file next to the rendered PDF.
        dot.render(filename=fname, view=False, cleanup=False)

    return dot


def plot_domain_tree(df, first_box='DL + EEG studies', min_font_size=10, 
                     max_font_size=14, max_char=16, min_n_items=2, 
                     postprocess=True, save_cfg=cfg.saving_config):
    """Plot tree graph showing the breakdown of study domains.

    Args:
        df (pd.DataFrame): data items table; must contain the columns
            'Domain 1' to 'Domain 4'.

    Keyword Args:
        first_box (str): text of the first box
        min_font_size (int): minimum font size
        max_font_size (int): maximum font size
        max_char (int): maximum number of characters per line
        min_n_items (int): at the third and fourth levels, categories with
            fewer than this number of papers are grouped inside a node
            called "Others", whose count is the total number of papers
            grouped in it.
        postprocess (bool): if True, also lay the saved DOT file out with
            `neato` and write the result to PDF in a system call.
        save_cfg (dict or None): saving configuration ('format',
            'page_width', 'page_height', 'savepath'). If None, the graph is
            not sized from the configuration and nothing is written to disk.

    Returns:
        (graphviz.Digraph): graphviz object

    NOTES:
    - To unflatten automatically, apply the following on the .dot file:
        >> unflatten -l 3 -c 10 dom_domains_tree | dot -Teps -o domains_unflattened.eps
    - To produce a circular version instead (uses space more efficiently):
        >> neato -Tpdf dom_domains_tree -o domains_neato.pdf
    """
    df = df[['Domain 1', 'Domain 2', 'Domain 3', 'Domain 4']].copy()
    df = df[~df['Domain 1'].isnull()]
    df[df == ' '] = None

    n_samples, n_levels = df.shape
    if isinstance(save_cfg, dict):
        save_format = save_cfg['format']
        size_attr = {'size': '{},{}!'.format(
            save_cfg['page_width'], 0.7 * save_cfg['page_height'])}
    else:
        save_format = 'svg'
        size_attr = dict()

    dot = Digraph(format=save_format)
    # rankdir: LR (left to right), TB (top to bottom)
    dot.attr('graph', rankdir='TB', overlap='false', ratio='fill', **size_attr)
    dot.attr('node', fontname='Liberation Sans', fontsize=str(max_font_size), 
             shape='box', style='filled, rounded',  margin='0.2,0.01', 
             penwidth='0.5')
    dot.node('A', '{}\n({})'.format(first_box, len(df)), 
             fillcolor='azure')

    min_sat, max_sat = 0.05, 0.4

    level1_counts = df['Domain 1'].value_counts()
    n_categories = len(level1_counts)

    for i, (d1, count1) in enumerate(level1_counts.items()):
        node1, hue = ut.make_box(
            dot, d1, max_char, count1, n_samples, 0, n_levels, min_sat, max_sat, 
            min_font_size, max_font_size, 'A', counter=i, 
            n_categories=n_categories)

        # Each level only looks at the papers of its parent, so categories
        # sharing a name under different parents are counted separately.
        df1 = df[df['Domain 1'] == d1]
        for d2, count2 in df1['Domain 2'].value_counts().items():
            node2, _ = ut.make_box(
                dot, d2, max_char, count2, n_samples, 1, n_levels, min_sat, 
                max_sat, min_font_size, max_font_size, node1, hue=hue)
            df2 = df1[df1['Domain 2'] == d2]

            n_others3 = 0
            for d3, count3 in df2['Domain 3'].value_counts().items():
                if not isinstance(d3, str) or d3 == 'TBD':
                    continue
                if count3 < min_n_items:
                    # Number of papers grouped under "Others".
                    n_others3 += count3
                    continue

                node3, _ = ut.make_box(
                    dot, d3, max_char, count3, n_samples, 2, n_levels,
                    min_sat, max_sat, min_font_size, max_font_size, 
                    node2, hue=hue)
                df3 = df2[df2['Domain 3'] == d3]

                n_others4 = 0
                for d4, count4 in df3['Domain 4'].value_counts().items():
                    if not isinstance(d4, str) or d4 == 'TBD':
                        continue
                    if count4 < min_n_items:
                        n_others4 += count4
                    else:
                        ut.make_box(
                            dot, d4, max_char, count4, n_samples, 3, 
                            n_levels, min_sat, max_sat, min_font_size, 
                            max_font_size, node3, hue=hue)

                if n_others4 > 0:
                    ut.make_box(
                        dot, 'Others', max_char, n_others4, n_samples, 
                        3, n_levels, min_sat, max_sat, min_font_size, 
                        max_font_size, node3, hue=hue, 
                        node_name=node3 + 'others')

            if n_others3 > 0:
                ut.make_box(
                    dot, 'Others', max_char, n_others3, n_samples, 2, n_levels,
                    min_sat, max_sat, min_font_size, max_font_size, node2, hue=hue, 
                    node_name=node2 + 'others')

    if save_cfg is not None:
        fname = os.path.join(save_cfg['savepath'], 'dom_domains_tree')
        # cleanup=False keeps the DOT source at `fname` for the neato call.
        dot.render(filename=fname, cleanup=False)
        if postprocess:
            returncode = subprocess.call(
                ['neato', '-Tpdf', fname, '-o', fname + '.pdf'])
            if returncode != 0:
                logger.warning(
                    'neato exited with code {} for {}'.format(returncode, fname))

    return dot


def plot_model_comparison(df, save_cfg=cfg.saving_config):
    """Plot bar graph showing the types of baseline models used.

    Also logs, as percentages of all rows in `df`, how many studies used a
    traditional baseline, a deep learning baseline, or no baseline.

    Args:
        df (DataFrame): data items table with a 'Baseline model type' column.

    Keyword Args:
        save_cfg (dict or None): saving configuration; if None, the figure
            uses the default size and is not saved.

    Returns:
        (matplotlib.axes.Axes)
    """
    figsize = None if save_cfg is None else (
        save_cfg['text_width'] / 4 * 2, save_cfg['text_height'] / 5)
    fig, ax = plt.subplots(figsize=figsize)
    sns.countplot(y=df['Baseline model type'].dropna(axis=0), ax=ax)
    ax.set_xlabel('Number of papers')
    ax.set_ylabel('')
    fig.tight_layout()

    # Percentages are relative to all rows, including those without a
    # baseline type; categories absent from the data count as 0.
    model_prcts = df['Baseline model type'].value_counts() / df.shape[0] * 100

    def prct(category):
        return model_prcts.get(category, 0)

    logger.info('% of studies that used at least one traditional baseline: {}'.format(
        prct('Traditional pipeline') + prct('DL & Trad.')))
    logger.info('% of studies that used at least one deep learning baseline: {}'.format(
        prct('DL') + prct('DL & Trad.')))
    logger.info('% of studies that did not report baseline comparisons: {}'.format(
        prct('None')))

    if save_cfg is not None:
        fname = os.path.join(save_cfg['savepath'], 'model_comparison')
        fig.savefig(fname + '.' + save_cfg['format'], **save_cfg)

    return ax


def plot_performance_metrics(df, cutoff=3, eeg_clf=None, 
                             save_cfg=cfg.saving_config):
    """Plot bar graph showing the types of performance metrics used.

    Args:
        df (DataFrame): data items table with a comma-separated
            'Performance metrics (clean)' column and a 'Domain 1' column.

    Keyword Args:
        cutoff (int): Metrics with less than this number of papers will be cut
            off from the bar graph.
        eeg_clf (bool or None): If True, only use studies that focus on EEG 
            classification. If False, only use studies that did not focus on 
            EEG classification. If None, use all studies.
        save_cfg (dict or None): saving configuration; if None, the figure
            uses the default size and is not saved.

    Returns:
        (matplotlib.axes.Axes)

    Raises:
        ValueError: if `eeg_clf` is not True, False or None.

    Assumptions, simplifications:
    - Rates have been simplified (e.g., "false positive rate" -> "false positives")
    - RMSE and MSE have been merged under MSE
    - Training/testing times have been simplified to "time"
    - Macro f1-score === f1-score
    - A paper listing several synonyms of the same metric (e.g., RMSE and
      MSE) is counted once for that metric.
    """
    column = 'Performance metrics (clean)'
    if eeg_clf is None:
        metrics = df[column]
    elif eeg_clf is True or eeg_clf is False:
        is_eeg_clf = df['Domain 1'] == 'Classification of EEG signals'
        metrics = df.loc[is_eeg_clf if eeg_clf else ~is_eeg_clf, column]
    else:
        raise ValueError(
            'eeg_clf must be True, False or None, got {!r}'.format(eeg_clf))

    # Papers without any reported metric are skipped.
    metrics = metrics.dropna().str.split(',').apply(ut.lstrip)

    # One row per (paper, metric) pair; empty entries (e.g. left by a
    # trailing comma) are ignored.
    metric_per_article = [
        [i, m] for i, metric_list in metrics.items() for m in metric_list if m]
    metrics_df = pd.DataFrame(metric_per_article, columns=['paper nb', 'metric'])

    # Replace equivalent terms by standardized term
    equivalences = {'selectivity': 'specificity',
                    'true negative rate': 'specificity',
                    'sensitivitiy': 'sensitivity',
                    'sensitivy': 'sensitivity',
                    'recall': 'sensitivity',
                    'hit rate': 'sensitivity', 
                    'true positive rate': 'sensitivity',
                    'sensibility': 'sensitivity',
                    'positive predictive value': 'precision',
                    'f-measure': 'f1-score',
                    'f-score': 'f1-score',
                    'f1-measure': 'f1-score',
                    'macro f1-score': 'f1-score',
                    'macro-averaging f1-score': 'f1-score',
                    'kappa': 'cohen\'s kappa',
                    'mae': 'mean absolute error',
                    'false negative rate': 'false negatives',
                    'fpr': 'false positives',
                    'false positive rate': 'false positives',
                    'false prediction rate': 'false positives',
                    'roc': 'ROC curves',
                    'roc auc': 'ROC AUC',
                    'rmse': 'mean squared error',
                    'mse': 'mean squared error',
                    'training time': 'time',
                    'testing time': 'time',
                    'test error': 'error'}
    metrics_df['metric'] = metrics_df['metric'].replace(equivalences)
    metrics_df['metric'] = metrics_df['metric'].apply(lambda x: x[0].upper() + x[1:])

    # After merging synonyms a paper can list the same metric twice; the
    # plot counts papers, so keep one row per paper and metric.
    metrics_df = metrics_df.drop_duplicates(['paper nb', 'metric'])

    # Removing low count categories
    metrics_counts = metrics_df['metric'].value_counts()
    metrics_df = metrics_df[metrics_df['metric'].isin(
        metrics_counts[metrics_counts >= cutoff].index)]

    figsize = None if save_cfg is None else (
        save_cfg['text_width'] / 2, save_cfg['text_height'] / 5)
    fig, ax = plt.subplots(figsize=figsize)
    sns.countplot(y='metric', data=metrics_df, 
                  order=metrics_df['metric'].value_counts().index, ax=ax)
    ax.set_xlabel('Number of papers')
    ax.set_ylabel('')
    fig.tight_layout()

    if save_cfg is not None:
        savename = 'performance_metrics'
        if eeg_clf is True:
            savename += '_eeg_clf'
        elif eeg_clf is False:
            savename += '_not_eeg_clf'
        fname = os.path.join(save_cfg['savepath'], savename)
        fig.savefig(fname + '.' + save_cfg['format'], **save_cfg)

    return ax


def plot_reported_results(df, data_items_df=None, save_cfg=cfg.saving_config):
    """Plot figures to describe the reported results in the studies.

    Only rows whose 'Metric' is 'accuracy' are used. For every citation/task
    pair reporting a traditional baseline, the best proposed accuracy minus
    the best traditional baseline accuracy ('acc_diff') is computed and
    plotted.

    Args:
        df (DataFrame): contains reported results (second tab in spreadsheet)

    Keyword Args:
        data_items_df (DataFrame): contains data items (first tab in
            spreadsheet). If given, the accuracy differences are also
            analyzed per domain, per data item and for preprints.
        save_cfg (dict or None)

    Returns:
        (list): outputs of the plotting/analysis helpers, in call order
            (mostly axes of the created figures).
    """
    acc_df = _prepare_accuracy_df(df)

    # Only keep 2 best per task and model type
    acc2_df = acc_df.sort_values(
        ['Citation', 'Task', 'model_type', 'Result'], ascending=True).groupby(
            ['Citation', 'Task', 'model_type'], observed=True).tail(2)

    axes = list()
    axes.append(_plot_results_per_citation_task(acc2_df, save_cfg))

    best_df = _best_results_with_trad_baseline(acc_df)
    diff_df = _compute_accuracy_diff(best_df)

    axes.append(_plot_results_accuracy_diff_scatter(diff_df, save_cfg))
    axes.append(_plot_results_accuracy_diff_distr(diff_df, save_cfg))

    # Pivot dataframe to plot proposed vs. baseline accuracy as a scatterplot
    best_df['citation_task'] = best_df[['Citation', 'Task']].apply(
        lambda x: ' ['.join(x) + ']', axis=1)
    acc_comparison_df = best_df.pivot_table(
        index='citation_task', columns='model_type', values='Result')

    axes.append(_plot_results_accuracy_comparison(acc_comparison_df, save_cfg))

    if data_items_df is not None:
        diff_domain_df = _merge_data_items(diff_df, data_items_df)

        axes.append(_plot_results_accuracy_per_domain(
            diff_domain_df, diff_df, save_cfg))
        axes.append(_plot_results_stats_impact_on_acc_diff(
            diff_domain_df, save_cfg))
        axes.append(_compute_acc_diff_for_preprints(diff_domain_df, save_cfg))

    return axes


def _prepare_accuracy_df(df):
    """Return the accuracy rows of `df`, ordered by mean proposed accuracy.

    Adds a 'citation_task' label column and turns 'Citation' into a
    categorical whose categories are the citations reporting a 'Proposed'
    model, sorted by their mean proposed accuracy. Citations without a
    proposed accuracy become NaN, which groupby drops by default.
    """
    acc_df = df[df['Metric'] == 'accuracy'].copy()

    # Create new column that contains both citation and task information
    acc_df['citation_task'] = acc_df[['Citation', 'Task']].apply(
        lambda x: ' ['.join(x) + ']', axis=1)

    # Order by average proposed model accuracy
    acc_ind = acc_df[acc_df['model_type'] == 'Proposed'].groupby(
        'Citation')['Result'].mean().sort_values().index
    acc_df['Citation'] = pd.Categorical(acc_df['Citation'], categories=acc_ind)
    return acc_df.sort_values(['Citation'])


def _best_results_with_trad_baseline(acc_df):
    """Best result per citation/task/model type, with its architecture(s).

    Only citation/task pairs that report a traditional baseline are kept.
    Several architectures tied for the best result each yield a row.
    """
    # observed=True: only real combinations, not the cartesian product of
    # the categorical 'Citation' with the other keys.
    best_df = acc_df.groupby(
        ['Citation', 'Task', 'model_type'], observed=True)[
            'Result'].max().reset_index()

    # Only keep citations/tasks that have a traditional baseline
    best_df = best_df.groupby(['Citation', 'Task'], observed=True).filter(
        lambda x: (x['model_type'] == 'Baseline (traditional)').any())
    best_df = best_df.reset_index(drop=True)

    # Add back architecture. acc_df rows that only differ in columns not
    # selected here produce identical merged rows, hence drop_duplicates.
    best_df = pd.merge(
        best_df,
        acc_df[['Citation', 'Task', 'model_type', 'Result', 'Architecture']],
        how='inner').drop_duplicates()
    return best_df


def _compute_accuracy_diff(best_df):
    """Best proposed accuracy minus best traditional baseline accuracy.

    Returns a DataFrame with columns 'Citation', 'Task', 'Architecture' and
    'acc_diff', with one row per architecture of the best proposed model of
    each citation/task pair. Pairs lacking either a proposed model or a
    traditional baseline are skipped.
    """
    rows = list()
    for (citation, task), group in best_df.groupby(
            ['Citation', 'Task'], observed=True):
        proposed = group[group['model_type'] == 'Proposed']
        baseline = group[group['model_type'] == 'Baseline (traditional)']
        if proposed.empty or baseline.empty:
            continue
        diff = proposed['Result'].iloc[0] - baseline['Result'].iloc[0]
        rows.extend(
            [citation, task, arch, diff] for arch in proposed['Architecture'])
    return pd.DataFrame(
        rows, columns=['Citation', 'Task', 'Architecture', 'acc_diff'])


def _merge_data_items(diff_df, data_items_df):
    """Attach the data items of each citation to the accuracy differences.

    Adds a 'domain' column and an 'arxiv' flag, and sorts rows by 'domain'.
    'domain' joins with '/' the string values of the columns matched by
    `Domain*`, skipping the first matched column (presumably 'Domain 1').
    """
    domains_df = data_items_df.filter(
        regex=(r'(?=Domain*|Citation|Main domain|Journal / Origin|Dataset name|'
               r'Data - samples|Data - time|Data - subjects|Preprocessing \(clean\)|'
               r'Artefact handling \(clean\)|Features \(clean\)|Architecture \(clean\)|'
               r'Layers \(clean\)|Regularization \(clean\)|Optimizer \(clean\)|'
               r'Intra/Inter subject|Training procedure)')).copy()

    def concat_domains(row):
        # Join the string entries after the first one, e.g. 'a/b/c'.
        return '/'.join(item for item in row.iloc[1:] if isinstance(item, str))

    domains_df['domain'] = data_items_df.filter(
        regex=r'(?=Domain*)').apply(concat_domains, axis=1)
    diff_domain_df = diff_df.merge(domains_df, on='Citation', how='left')
    diff_domain_df = diff_domain_df.sort_values(by='domain')
    diff_domain_df['arxiv'] = diff_domain_df['Journal / Origin'] == 'Arxiv'
    return diff_domain_df


def _plot_results_per_citation_task(results_df, save_cfg):
    """Plot scatter plot of accuracy for each condition and task.

    Args:
        results_df (DataFrame): needs 'citation_task', 'Result' and
            'model_type' columns.
        save_cfg (dict or None): if None, the figure uses the default size
            and is not saved.

    Returns:
        (matplotlib.axes.Axes)
    """
    # The graph is made tall, otherwise the y axis labels are on top of
    # each other.
    figsize = None if save_cfg is None else (
        save_cfg['text_width'], save_cfg['text_height'] * 1.3)
    fig, ax = plt.subplots(figsize=figsize)
    # `catplot` is figure-level and ignores `ax` (it opens its own figure,
    # leaving this one empty); `stripplot` is the axes-level function behind
    # catplot's default kind ('strip').
    sns.stripplot(y='citation_task', x='Result', hue='model_type',
                  data=results_df, ax=ax)
    ax.set_xlabel('accuracy')
    ax.set_ylabel('')
    fig.tight_layout()

    if save_cfg is not None:
        savename = 'reported_results'
        fname = os.path.join(save_cfg['savepath'], savename)
        fig.savefig(fname + '.' + save_cfg['format'], **save_cfg)

    return ax


def _plot_results_accuracy_diff_scatter(results_df, save_cfg):
    """Plot difference in accuracy for each condition/task as a scatter plot.

    Args:
        results_df (DataFrame): needs 'Task' and 'acc_diff' columns.
        save_cfg (dict or None): if None, the figure uses the default size
            and is not saved.

    Returns:
        (matplotlib.axes.Axes)
    """
    figsize = None if save_cfg is None else (
        save_cfg['text_width'], save_cfg['text_height'] * 1.3)
    fig, ax = plt.subplots(figsize=figsize)
    # `catplot` is figure-level and ignores `ax`; `stripplot` (catplot's
    # default kind) draws on the intended axes.
    sns.stripplot(y='Task', x='acc_diff', data=results_df, ax=ax)
    ax.set_xlabel('Accuracy difference')
    ax.set_ylabel('')
    # Reference line: points right of it are improvements over the baseline.
    ax.axvline(0, c='k', alpha=0.2)
    fig.tight_layout()

    if save_cfg is not None:
        savename = 'reported_accuracy_diff_scatter'
        fname = os.path.join(save_cfg['savepath'], savename)
        fig.savefig(fname + '.' + save_cfg['format'], **save_cfg)

    return ax


def _plot_results_accuracy_diff_distr(results_df, save_cfg):
    """Plot the distribution of difference in accuracy.

    Draws a histogram (no KDE) with a rug of the individual values.

    Args:
        results_df (DataFrame): needs an 'acc_diff' column.
        save_cfg (dict or None): if None, the figure uses the default size
            and is not saved.

    Returns:
        (matplotlib.axes.Axes)
    """
    figsize = None if save_cfg is None else (
        save_cfg['text_width'], save_cfg['text_height'] * 0.5)
    fig, ax = plt.subplots(figsize=figsize)
    sns.distplot(results_df['acc_diff'], kde=False, rug=True, ax=ax)
    ax.set_xlabel('Accuracy difference')
    ax.set_ylabel('Number of studies')
    fig.tight_layout()

    if save_cfg is not None:
        savename = 'reported_accuracy_diff_distr'
        fname = os.path.join(save_cfg['savepath'], savename)
        fig.savefig(fname + '.' + save_cfg['format'], **save_cfg)

    return ax


def _plot_results_accuracy_comparison(results_df, save_cfg):
    """Plot the comparison between the best model and best baseline.

    Args:
        results_df (DataFrame): one row per citation/task with
            'Baseline (traditional)' and 'Proposed' accuracy columns.
        save_cfg (dict or None): if None, the figure uses the default size
            and is not saved.

    Returns:
        (matplotlib.axes.Axes)
    """
    figsize = None if save_cfg is None else (
        save_cfg['text_width'], save_cfg['text_height'] * 0.5)
    fig, ax = plt.subplots(figsize=figsize)
    sns.scatterplot(data=results_df, x='Baseline (traditional)', y='Proposed', 
                    ax=ax)
    # Identity line: points above it are results where the proposed model
    # beat the traditional baseline.
    ax.plot([0, 1.1], [0, 1.1], c='k', alpha=0.2)
    ax.axis('square')
    ax.set_xlim([0, 1.1])
    ax.set_ylim([0, 1.1])
    fig.tight_layout()

    if save_cfg is not None:
        savename = 'reported_accuracy_comparison'
        fname = os.path.join(save_cfg['savepath'], savename)
        fig.savefig(fname + '.' + save_cfg['format'], **save_cfg)

    return ax


def _plot_results_accuracy_per_domain(results_df, diff_df, save_cfg):
    """Make scatterplot + boxplot to show accuracy difference by domain.

    The top panel shows each accuracy difference grouped by main domain. The
    bottom panel shows a boxplot and swarmplot of all accuracy differences.
    Also logs the median and interquartile range of the differences and the
    (up to) three largest improvements.

    Args:
        results_df (DataFrame): accuracy differences merged with data items
            ('Main domain', 'acc_diff'). It is not modified.
        diff_df (DataFrame): accuracy differences ('acc_diff', 'Citation').
        save_cfg (dict or None): if None, the figure uses the default size
            and is not saved.

    Returns:
        (array of matplotlib.axes.Axes): the two axes.
    """
    figsize = None if save_cfg is None else (
        save_cfg['text_width'], save_cfg['text_height'] / 3)
    fig, axes = plt.subplots(
        nrows=2, ncols=1, sharex=True, figsize=figsize,
        gridspec_kw={'height_ratios': [5, 1]})

    # Work on a copy so that wrapping the labels does not alter the caller's
    # DataFrame.
    results_df = results_df.copy()
    results_df['Main domain'] = results_df['Main domain'].apply(
        ut.wrap_text, max_char=20)

    # `catplot` is figure-level and ignores `ax`; `stripplot` (catplot's
    # default kind) draws on the intended axes.
    sns.stripplot(y='Main domain', x='acc_diff', s=3, jitter=True, 
                  data=results_df, ax=axes[0])
    axes[0].set_xlabel('')
    axes[0].set_ylabel('')
    axes[0].axvline(0, c='k', alpha=0.2)

    sns.boxplot(x='acc_diff', data=diff_df, ax=axes[1])
    sns.swarmplot(x='acc_diff', data=diff_df, color="0", size=2, ax=axes[1])
    axes[1].axvline(0, c='k', alpha=0.2)
    axes[1].set_xlabel('Accuracy difference')

    # NOTE: tight_layout below recomputes the subplot spacing and therefore
    # overrides this hspace.
    fig.subplots_adjust(wspace=0, hspace=0.02)
    fig.tight_layout()

    logger.info('Number of studies included in the accuracy improvement analysis: {}'.format(
        results_df.shape[0]))
    median = diff_df['acc_diff'].median()
    iqr = diff_df['acc_diff'].quantile(.75) - diff_df['acc_diff'].quantile(.25)
    logger.info('Median gain in accuracy: {:.6f}'.format(median))
    logger.info('Interquartile range of the gain in accuracy: {:.6f}'.format(iqr))

    # zip stops early when fewer than three results are available.
    best_improvement = diff_df.nlargest(3, 'acc_diff')
    ranks = ['Best', 'Second best', 'Third best']
    for rank, acc_diff, citation in zip(
            ranks, best_improvement['acc_diff'].values,
            best_improvement['Citation'].values):
        logger.info('{} improvement in accuracy: {}, in {}'.format(
            rank, acc_diff, citation))

    if save_cfg is not None:
        savename = 'reported_accuracy_per_domain'
        fname = os.path.join(save_cfg['savepath'], savename)
        fig.savefig(fname + '.' + save_cfg['format'], **save_cfg)

    return axes


def _plot_results_stats_impact_on_acc_diff(results_df, save_cfg):
    """Run statistical analysis to see which data items correlate with acc diff.

    Binary data items are tested with `ut.run_mannwhitneyu`, multiclass items
    with `ut.run_kruskal` and continuous items with `ut.run_spearmanr`. The
    table of results is logged, and the figure of every item with
    p < 0.05 is saved when `save_cfg` is given.

    NOTE: This analysis is not perfectly accurate as there are several papers 
        which contrasted results based on data items (e.g., testing the impact
        of number of layers on performance), but our summaries are not at this
        level of granularity. Therefore the results are not to be taken at face
        value.

    Returns:
        None
    """
    binary_data_items = {'Preprocessing (clean)': ['Yes', 'No'],
                         'Artefact handling (clean)': ['Yes', 'No'],
                         'Features (clean)': ['Raw EEG', 'Frequency-domain'],
                         'Regularization (clean)': ['Yes', 'N/M'],
                         'Intra/Inter subject': ['Intra', 'Inter']}
    multiclass_data_items = ['Architecture',  # Architecture (clean)',
                             'Optimizer (clean)']
    # Value: flag passed as `log=` to `ut.run_spearmanr` for that item.
    continuous_data_items = {'Layers (clean)': False,
                             'Data - subjects': True,
                             'Data - time': True,
                             'Data - samples': True}

    results = dict()
    for item, groups in binary_data_items.items():
        results[item] = ut.run_mannwhitneyu(results_df, item, groups, plot=True)

    for item in multiclass_data_items:
        results[item] = ut.run_kruskal(results_df, item, plot=True)

    # Each continuous item uses its own `log` flag (previously the stale
    # loop variable of the binary loop was passed instead).
    for item, use_log in continuous_data_items.items():
        single_df = ut.keep_single_valued_rows(results_df, item)
        single_df = single_df[single_df[item] != 'N/M'].copy()
        single_df[item] = single_df[item].astype(float)
        results[item] = ut.run_spearmanr(single_df, item, log=use_log, plot=True)

    stats_df = pd.DataFrame(results).T
    logger.info('Results of statistical tests on impact of data items:\n{}'.format(
        stats_df))

    # Categorical plot for each "significant" data item
    significant_items = stats_df[stats_df['pvalue'] < 0.05].index
    if save_cfg is not None and len(significant_items) > 0:
        for item in significant_items:
            savename = 'stat_impact_{}_on_acc_diff'.format(
                item.replace(' ', '_').replace('/', '_'))
            fname = os.path.join(save_cfg['savepath'], savename)
            stats_df.loc[item, 'fig'].savefig(
                fname + '.' + save_cfg['format'], **save_cfg)

    return None


def _compute_acc_diff_for_preprints(results_df, save_cfg):
    """Analyze the acc diff for preprints vs. peer-reviewed articles.

    A result counts as a preprint when its 'Journal / Origin' is 'Arxiv' or
    'BioarXiv'. Logs the number of preprints, the median accuracy difference
    per group and the p-value of `ut.run_mannwhitneyu` between the groups.

    Args:
        results_df (DataFrame): needs 'Citation', 'Journal / Origin' and
            'acc_diff' columns. It is not modified.
        save_cfg: unused; kept for a signature consistent with the other
            result helpers.

    Returns:
        (dict-like): the result of `ut.run_mannwhitneyu`.
    """
    results_df = results_df.copy()
    results_df['preprint'] = results_df['Journal / Origin'].isin(['Arxiv', 'BioarXiv'])
    # One entry per paper, not per result row.
    preprints = results_df.groupby('Citation')['preprint'].first().value_counts()

    logger.info(
        'Number of preprints included in the accuracy difference comparison: '
        '{}/{} papers'.format(preprints.get(True, 0), preprints.sum()))

    logger.info('Median acc diff for preprints vs. non-preprint:\n{}'.format(
        results_df.groupby('preprint')['acc_diff'].median()))
    results = ut.run_mannwhitneyu(results_df, 'preprint', [True, False])
    logger.info('Mann-Whitney test on preprint vs. not preprint: {:0.3f}'.format(
        results['pvalue']))

    return results


def generate_wordcloud(df, save_cfg=cfg.saving_config):
    """Generate a brain-shaped word cloud from the paper titles.

    Args:
        df (DataFrame): must contain a 'Title' column.

    Keyword Args:
        save_cfg (dict or None): if not None, the image is written to
            '<savepath>/DL-EEG_WordCloud.<format>'.

    Returns:
        (wordcloud.WordCloud): the generated word cloud (previously None).
    """
    # NOTE(review): this path is relative to the current working directory.
    with Image.open("./img/brain_stencil.png") as stencil:
        brain_mask = np.array(stencil)

    # Join the titles themselves: `Series.to_string()` would also put the
    # index labels (and 'NaN' for missing titles) into the text.
    text = ' '.join(df['Title'].dropna().astype(str))

    stopwords = set(STOPWORDS)
    stopwords.add("using")
    stopwords.add("based")

    wc = WordCloud(
        background_color="white", max_words=2000, max_font_size=50, mask=brain_mask,
        stopwords=stopwords, contour_width=1, contour_color='steelblue')
    wc.generate(text)

    # store to file
    if save_cfg is not None:
        fname = os.path.join(save_cfg['savepath'], 'DL-EEG_WordCloud')
        wc.to_file(fname + '.' + save_cfg['format'])

    return wc


def plot_model_inspection_and_table(df, cutoff=1, save_cfg=cfg.saving_config):
    """Make bar graph and table listing method inspection techniques.

    NOTE: adds an 'inspection_list' column to the caller's `df`.

    Args:
        df (DataFrame): needs 'Citation' and comma-separated
            'Model inspection (clean)' columns.

    Keyword Args:
        cutoff (int): Metrics with less than this number of papers will be cut
            off from the bar graph.
        save_cfg (dict or None): if not None, the LaTeX table is written to
            'table_savepath' and the figure to 'savepath'. If None, the
            figure uses the default size and nothing is written.

    Returns:
        (matplotlib.axes.Axes)
    """
    df['inspection_list'] = df[
        'Model inspection (clean)'].str.split(',').apply(ut.lstrip)

    # One row per (paper, method); empty entries are ignored.
    inspection_per_article = list()
    for i, items in df[['Citation', 'inspection_list']].iterrows():
        for m in items['inspection_list']:
            if m:
                inspection_per_article.append([i, items['Citation'], m])

    inspection_df = pd.DataFrame(
        inspection_per_article, 
        columns=['paper nb', 'Citation', 'inspection method'])
    # The plot counts papers, so a method listed twice by a paper counts once.
    inspection_df = inspection_df.drop_duplicates(['paper nb', 'inspection method'])

    # Remove "no" entries, because they make it really hard to see the 
    # actual distribution. Papers are counted by their identifier, not by
    # rows, since a paper may list several methods.
    n_nos = inspection_df['inspection method'].value_counts().get('no', 0)
    n_papers = inspection_df['paper nb'].nunique()
    logger.info('Number of papers without model inspection method: {}'.format(n_nos))
    inspection_df = inspection_df[inspection_df['inspection method'] != 'no']

    # Removing low count categories
    inspection_counts = inspection_df['inspection method'].value_counts()
    inspection_df = inspection_df[inspection_df['inspection method'].isin(
        inspection_counts[inspection_counts >= cutoff].index)].copy()

    inspection_df['inspection method'] = inspection_df['inspection method'].apply(
        lambda x: x.capitalize())

    # Methods ordered from most to least used (shared by table and plot).
    order = inspection_df['inspection method'].value_counts().index

    # Making table
    inspection_table = inspection_df.groupby(['inspection method'])[
        'Citation'].apply(list)
    inspection_table = inspection_table.reindex(order)
    inspection_table = inspection_table.apply(lambda x: r'\cite{' + ', '.join(x) + '}')

    if save_cfg is not None:
        table_fname = os.path.join(save_cfg['table_savepath'], 'inspection_methods.tex')
        with open(table_fname, 'w') as f:
            with pd.option_context("max_colwidth", 1000):
                f.write(inspection_table.to_latex(escape=False))

    figsize = None if save_cfg is None else (
        save_cfg['text_width'] / 4 * 3, save_cfg['text_height'] / 2)
    fig, ax = plt.subplots(figsize=figsize)
    sns.countplot(y='inspection method', data=inspection_df, order=order, ax=ax)
    ax.set_xlabel('Number of papers')
    ax.set_ylabel('')
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))
    fig.tight_layout()

    logger.info('% of studies that used model inspection techniques: {}'.format(
        100 - 100 * (n_nos / n_papers)))

    if save_cfg is not None:
        savename = 'model_inspection'
        fname = os.path.join(save_cfg['savepath'], savename)
        fig.savefig(fname + '.' + save_cfg['format'], **save_cfg)

    return ax


def plot_type_of_paper(df, save_cfg=cfg.saving_config):
    """Plot bar graph showing the type of each paper (journal, conference, etc.).

    Also logs the number of journal papers, conference papers and preprints,
    and how many non-preprint papers were first published as preprints.

    Args:
        df (DataFrame): needs 'Type of paper' and 'Preprint first' columns.

    Keyword Args:
        save_cfg (dict or None): if None, the figure uses the default size
            and is not saved.

    Returns:
        (matplotlib.axes.Axes)
    """
    # Move supplements to journal paper category for the plot (a value of one is
    # not visible on a bar graph). Only the plotted column is changed.
    paper_types = df['Type of paper'].replace('Supplement', 'Journal')

    figsize = None if save_cfg is None else (
        save_cfg['text_width'] / 4, save_cfg['text_height'] / 5)
    fig, ax = plt.subplots(figsize=figsize)
    sns.countplot(x=paper_types, ax=ax)
    ax.set_xlabel('')
    ax.set_ylabel('Number of papers')
    ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
    fig.tight_layout()

    # Missing categories are reported as 0 instead of raising KeyError.
    counts = df['Type of paper'].value_counts()
    logger.info('Number of journal papers: {}'.format(counts.get('Journal', 0)))
    logger.info('Number of conference papers: {}'.format(counts.get('Conference', 0)))
    logger.info('Number of preprints: {}'.format(counts.get('Preprint', 0)))
    preprint_first = df.loc[
        df['Type of paper'] != 'Preprint', 'Preprint first'].value_counts()
    logger.info('Number of papers that were initially published as preprints: '
                '{}'.format(preprint_first.get('Yes', 0)))

    if save_cfg is not None:
        fname = os.path.join(save_cfg['savepath'], 'type_of_paper')
        fig.savefig(fname + '.' + save_cfg['format'], **save_cfg)

    return ax


def plot_country(df, save_cfg=cfg.saving_config):
    """Plot bar graph showing the country of the first author's affiliation.

    Countries are sorted from most to least frequent, and the three most
    frequent ones are logged.

    Args:
        df (DataFrame): needs a 'Country' column.

    Keyword Args:
        save_cfg (dict or None): if None, the figure uses the default size
            and is not saved.

    Returns:
        (matplotlib.axes.Axes)
    """
    country_counts = df['Country'].value_counts()

    figsize = None if save_cfg is None else (
        save_cfg['text_width'] / 4 * 3, save_cfg['text_height'] / 5)
    fig, ax = plt.subplots(figsize=figsize)
    sns.countplot(x=df['Country'], ax=ax, order=country_counts.index)
    ax.set_ylabel('Number of papers')
    ax.set_xlabel('')
    ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
    fig.tight_layout()

    top3 = country_counts.index[:3]
    logger.info('Top 3 countries of first author affiliation: {}'.format(top3.values))

    if save_cfg is not None:
        fname = os.path.join(save_cfg['savepath'], 'country')
        fig.savefig(fname + '.' + save_cfg['format'], **save_cfg)

    return ax


def plot_countrymap(dfx, postprocess=True, save_cfg=cfg.saving_config):
    """Plot world map with colour indicating number of papers.

    Plot a world map where the colour of each country indicates how many papers
    were published in which the first author's affiliation was from that country.

    When saved as .eps this figure is well over the 6 MB limit allowed by arXiv.
    To solve this, we first save it as a .png (with high enough dpi), then use
    inkscape to convert it to .eps (leading to a file of ~1.6 MB):

    >> inkscape countrymap.png --export-eps=countrymap.eps

    The country shapes are read from
    '../img/countries/ne_10m_admin_0_countries.shp' relative to this module,
    and country names in `dfx` are matched exactly against its 'ADMIN' column
    (after translating the abbreviations 'USA', 'UK' and 'Bosnia').

    Args:
        dfx (pd.DataFrame): one row per paper, with a 'Country' column.

    Keyword Args:
        postprocess (bool): if True (and save_cfg is not None), convert the
            saved PNG to EPS using inkscape in a system call.
        save_cfg (dict or None): saving configuration. If None, the figure is
            sized with the default configuration and not saved.

    Returns:
        matplotlib Axes containing the map.
    """
    size_cfg = save_cfg if save_cfg is not None else cfg.saving_config

    dirname = os.path.dirname(__file__)
    shapefile = os.path.join(
        dirname, '../img/countries/ne_10m_admin_0_countries.shp')

    gdf = gpd.read_file(shapefile)[['ADMIN', 'geometry']]
    gdf.crs = '+init=epsg:4326'

    # Country names used in the review that differ from the shapefile names.
    # They are translated before counting so that e.g. 'USA' and
    # 'United States of America' entries are summed together.
    name_exceptions = {
        'USA': 'United States of America',
        'UK': 'United Kingdom',
        'Bosnia': 'Bosnia and Herzegovina',
    }
    # Series indexed by country name -> number of papers. Built this way to be
    # independent of the column names `value_counts().reset_index()` produces
    # (they changed in pandas 2.0).
    counts = dfx['Country'].replace(name_exceptions).value_counts()

    unhandled = counts.index[~counts.index.isin(gdf['ADMIN'])]
    if len(unhandled) > 0:
        print("## ERROR ## - Unhandled Countries!")
        print(list(unhandled))

    # Exact name matching; countries without any paper get 0. Substring
    # matching would, for instance, give 'Niger' the count of 'Nigeria'.
    gdf['Count'] = gdf['ADMIN'].map(counts).fillna(0).astype(int)

    figsize = (size_cfg['text_width'], size_cfg['text_height'] / 2)
    ax = gdf.plot(column='Count', figsize=figsize, cmap='Blues',
                  scheme='Fisher_Jenks', k=10, legend=True, edgecolor='k',
                  linewidth=0.3, categorical=False, vmin=0,
                  legend_kwds={'loc': 'lower left', 'title': 'Number of studies',
                               'framealpha': 1},
                  rasterized=False)

    # Remove floating points in legend
    leg = ax.get_legend()
    for text in leg.get_texts():
        text.set_text(text.get_text().replace('.00', ''))

    ax.set_axis_off()
    fig = ax.get_figure()
    plt.tight_layout()

    if save_cfg is not None:
        fname = os.path.join(save_cfg['savepath'], 'countrymap')
        save_cfg2 = save_cfg.copy()
        save_cfg2['dpi'] = 1000
        save_cfg2['format'] = 'png'
        fig.savefig(fname + '.png', **save_cfg2)

        if postprocess:
            # NOTE(review): `--export-eps` is the Inkscape 0.92 command line;
            # Inkscape 1.x expects `--export-filename=<name>.eps` instead —
            # confirm against the installed version.
            subprocess.call(
                ['inkscape', fname + '.png', '--export-eps=' + fname + '.eps'])

    return ax


def compute_prct_statistical_tests(df):
    """Compute the percentage of studies that used statistical tests.

    A study counts as having used a statistical test unless its
    'Statistical analysis of performance' entry is exactly 'No' (missing
    entries therefore count as having used one). The result is written to the
    results log.

    Args:
        df (pd.DataFrame): one row per study.

    Returns:
        float or None: the percentage, or None if `df` has no rows.
    """
    col = 'Statistical analysis of performance'
    n_studies = df.shape[0]
    if n_studies == 0:
        logger.info('% of studies that used statistical test: no studies')
        return None

    # Counting 'No' answers directly works even when no study answered 'No'
    # (indexing `value_counts()['No']` raises KeyError in that case).
    n_no = int((df[col] == 'No').sum())
    prct = 100 - 100 * n_no / n_studies
    logger.info('% of studies that used statistical test: {}'.format(prct))
    return prct


def make_domain_table(df, save_cfg=cfg.saving_config):
    """Make domain table that contains every reference.

    Rows are the unique combinations of the four domain levels ('Domain 1' to
    'Domain 4'), columns are the architectures ('Architecture (clean)'), and
    each cell is a ``\\cite{...}`` of the papers in that combination, or a
    blank when there are none. All these labels are TeX-escaped. The table
    is written to '<table_savepath>/domains_architecture_table.tex'.

    Args:
        df (pd.DataFrame): one row per paper, with the four domain columns,
            'Architecture (clean)' and 'Citation'. The caller's frame is not
            modified (a copy is processed).

    Keyword Args:
        save_cfg (dict): must provide 'table_savepath'.
    """
    # Work on a copy so the TeX escaping below does not leak into the
    # caller's frame (which other plotting functions reuse).
    df = df.copy()

    # Replace NaNs by ' ' in 'Domain 3' and 'Domain 4' columns
    df = ut.replace_nans_in_column(df, 'Domain 3', replace_by=' ')
    df = ut.replace_nans_in_column(df, 'Domain 4', replace_by=' ')

    cols = ['Domain 1', 'Domain 2', 'Domain 3', 'Domain 4', 'Architecture (clean)']
    # Element-wise escaping; `Series.map` avoids the deprecated `applymap`.
    for col in cols:
        df[col] = df[col].map(ut.tex_escape)

    # One row per (Domain 1..4) combination and one column per architecture;
    # each cell cites every paper of that group.
    domains_df = (df.groupby(cols)['Citation'].apply(list)
                  .apply(lambda refs: r'\cite{' + ', '.join(refs) + '}')
                  .unstack())
    # Combinations without a paper for a given architecture are left blank.
    domains_df = domains_df.fillna(' ')

    fname = os.path.join(save_cfg['table_savepath'], 'domains_architecture_table.tex')
    with open(fname, 'w', encoding='utf-8') as f:
        with pd.option_context("max_colwidth", 1000):
            f.write(domains_df.to_latex(
                escape=False,
                column_format='p{1.5cm}' * 4 + 'p{0.6cm}' * domains_df.shape[1]))


def plot_preprocessing_proportions(df, save_cfg=cfg.saving_config):
    """Draw proportion bars for the three preprocessing-related data items.

    The panels show, in order, EEG preprocessing, artifact handling and the
    extracted features, each built from the value counts of the matching
    '(clean)' column of `df`.

    Keyword Args:
        save_cfg (dict): provides the figure size; when not None the figure is
            also written to '<savepath>/preprocessing.<format>'.

    Returns:
        The axes returned by `ut.plot_multiple_proportions`.
    """
    panels = (
        ('(a) Preprocessing of EEG data', 'Preprocessing (clean)'),
        ('(b) Artifact handling', 'Artefact handling (clean)'),
        ('(c) Extracted features', 'Features (clean)'),
    )
    data = {title: df[column].value_counts().to_dict()
            for title, column in panels}

    figsize = (save_cfg['text_width'] / 4 * 4,
               save_cfg['text_height'] / 7 * 2)
    fig, ax = ut.plot_multiple_proportions(
        data, print_count=5, respect_order=['Yes', 'No', 'Other', 'N/M'],
        figsize=figsize)

    if save_cfg is not None:
        target = os.path.join(save_cfg['savepath'], 'preprocessing')
        fig.savefig(target + '.' + save_cfg['format'], **save_cfg)

    return ax


def plot_hyperparams_proportions(df, save_cfg=cfg.saving_config):
    """Draw proportion bars for the three hyperparameter-related data items.

    Panels, in order: training procedure, regularization and optimizer, each
    computed from the value counts of the matching '(clean)' column.

    Keyword Args:
        save_cfg (dict): provides the figure size; when not None the figure is
            also written to '<savepath>/hyperparams.<format>'.

    Returns:
        The axes returned by `ut.plot_multiple_proportions`.
    """
    source_columns = {
        '(a) Training procedure': 'Training procedure (clean)',
        '(b) Regularization': 'Regularization (clean)',
        '(c) Optimizer': 'Optimizer (clean)',
    }
    data = {}
    for title, column in source_columns.items():
        data[title] = df[column].value_counts().to_dict()

    width = save_cfg['text_width'] / 4 * 4
    height = save_cfg['text_height'] / 7 * 2
    fig, ax = ut.plot_multiple_proportions(
        data, print_count=5, respect_order=['Yes', 'No', 'Other', 'N/M'],
        figsize=(width, height))

    if save_cfg is not None:
        out_stem = os.path.join(save_cfg['savepath'], 'hyperparams')
        fig.savefig(out_stem + '.' + save_cfg['format'], **save_cfg)

    return ax


def plot_reproducibility_proportions(df, save_cfg=cfg.saving_config):
    """Plot proportions for reproducibility-related data items.

    Missing entries of 'Code hosted on', 'Limited data' and 'Code available'
    are filled with 'N/M', and a 'reproducibility' column is derived from code
    and dataset availability:

    - 'Easy': code available and dataset 'Public';
    - 'Medium': code available and dataset accessibility 'Both';
    - 'Impossible': code not available and dataset 'Private';
    - 'Hard': every other combination.

    These columns are written into `df` itself (the caller's frame is
    modified). Summary statistics are written to the results log.

    Args:
        df (pd.DataFrame): one row per paper.

    Keyword Args:
        save_cfg (dict or None): saving configuration. If None, the figure is
            sized with the default configuration and not saved.

    Returns:
        The axes returned by `ut.plot_multiple_proportions`.
    """
    size_cfg = save_cfg if save_cfg is not None else cfg.saving_config

    # Plain NaN filling (the previous `replace(np.nan, ..., regex=True)` never
    # used regex matching for a NaN pattern anyway).
    for col in ['Code hosted on', 'Limited data', 'Code available']:
        df[col] = df[col].fillna('N/M')

    data = dict()
    data['(a) Dataset availability'] = df[
         'Dataset accessibility'].value_counts().to_dict()
    data['(b) Code availability'] = df[
         'Code hosted on'].value_counts().to_dict()
    data['(c) Type of baseline'] = df[
         'Baseline model type'].value_counts().to_dict()

    code_yes = df['Code available'] == 'Yes'
    code_no = df['Code available'] == 'No'
    access = df['Dataset accessibility']
    df['reproducibility'] = 'Hard'
    df.loc[code_yes & (access == 'Public'), 'reproducibility'] = 'Easy'
    df.loc[code_yes & (access == 'Both'), 'reproducibility'] = 'Medium'
    df.loc[code_no & (access == 'Private'), 'reproducibility'] = 'Impossible'

    data['(d) Reproducibility'] = df[
         'reproducibility'].value_counts().to_dict()

    logger.info('Stats on reproducibility - Dataset Accessibility: {}'.format(
        data['(a) Dataset availability']))
    logger.info('Stats on reproducibility - Code Accessibility: {}'.format(
        df['Code available'].value_counts().to_dict()))
    logger.info('Stats on reproducibility - Code Hosted On: {}'.format(
        data['(b) Code availability']))
    logger.info('Stats on reproducibility - Baseline: {}'.format(
        data['(c) Type of baseline']))
    logger.info('Stats on reproducibility - Reproducibility Level: {}'.format(
        data['(d) Reproducibility']))
    logger.info('Stats on reproducibility - Limited data: {}'.format(
        df['Limited data'].value_counts().to_dict()))
    logger.info('Stats on reproducibility - Shared their Code: {}'.format(
        df[code_yes]['Citation'].to_dict()))

    fig, ax = ut.plot_multiple_proportions(
        data, print_count=5, respect_order=['Easy', 'Medium', 'Hard', 'Impossible'],
        figsize=(size_cfg['text_width'] / 4 * 4, size_cfg['text_height'] * 0.4))

    if save_cfg is not None:
        fname = os.path.join(save_cfg['savepath'], 'reproducibility')
        fig.savefig(fname + '.' + save_cfg['format'], **save_cfg)

    return ax


def plot_domains_per_year(df, save_cfg=cfg.saving_config):
    """Draw a stacked bar chart of papers per year, split by main domain.

    Side effects on `df`: 'Year' is cast to int32, and a 'Main domain' column
    is added holding, for each paper, the first of its 'Domain 1' to
    'Domain 4' values that belongs to the main-domain list below, or 'Others'
    when none does.

    Keyword Args:
        save_cfg (dict): provides the figure size; when not None the figure is
            also written to '<savepath>/domains_per_year.<format>'.

    Returns:
        matplotlib Axes containing the chart.
    """
    main_domains = ['Epilepsy', 'Sleep', 'BCI', 'Affective', 'Cognitive',
                    'Improvement of processing tools', 'Generation of data']

    fig, ax = plt.subplots(
        figsize=(save_cfg['text_width'] / 4 * 2, save_cfg['text_height'] / 4))

    df['Year'] = df['Year'].astype('int32')

    def first_main_domain(levels):
        # Scan the domain levels from left to right; keep the first main one.
        return next((level for level in levels if level in main_domains),
                    'Others')

    domain_levels = df[['Domain 1', 'Domain 2', 'Domain 3', 'Domain 4']]
    df['Main domain'] = [
        first_main_domain(levels)
        for levels in domain_levels.itertuples(index=False, name=None)]

    per_year = df.groupby(['Year', 'Main domain']).size().unstack('Main domain')
    per_year.plot(kind='bar', stacked=True, title='', ax=ax)
    ax.set_ylabel('Number of papers')
    ax.set_xlabel('')

    # Wrap long domain names so that the legend stays narrow.
    for entry in plt.legend().get_texts():
        entry.set_text(ut.wrap_text(entry.get_text(), max_char=14))

    if save_cfg is not None:
        out_stem = os.path.join(save_cfg['savepath'], 'domains_per_year')
        fig.savefig(out_stem + '.' + save_cfg['format'], **save_cfg)

    return ax


def plot_hardware(df, save_cfg=cfg.saving_config):
    """Plot bar graph showing the hardware used in the study.

    'EEG Hardware' entries are split on ',' into one row per device and
    citation, 'N/M' entries are dropped, and bars are coloured according to
    whether the device is in a fixed list of low-cost devices.

    Args:
        df (pd.DataFrame): one row per paper, with 'EEG Hardware' and
            'Citation' columns.

    Keyword Args:
        save_cfg (dict or None): saving configuration. If None, the figure is
            sized with the default configuration and not saved.

    Returns:
        matplotlib Axes containing the bar graph.
    """
    size_cfg = save_cfg if save_cfg is not None else cfg.saving_config

    col = 'EEG Hardware'
    hardware_df = ut.split_column_with_multiple_entries(
        df, col, ref_col='Citation', sep=',', lower=False)

    # Remove N/Ms because they make it hard to see anything. The explicit copy
    # makes the column added below go to an independent frame rather than to a
    # filtered view (avoids pandas' SettingWithCopyWarning).
    hardware_df = hardware_df[hardware_df[col] != 'N/M'].copy()

    # Flag low-cost devices. Matching ignores case because the brand name is
    # capitalised inconsistently ('Neurosky' vs 'NeuroSky').
    low_cost_devices = ['EPOC (Emotiv)', 'OpenBCI (OpenBCI)', 'Muse (InteraXon)',
                        'Mindwave Mobile (Neurosky)', 'Mindset (NeuroSky)']
    low_cost_lower = {device.lower() for device in low_cost_devices}
    hardware_df['Low-cost'] = hardware_df[col].map(
        lambda device: isinstance(device, str)
        and device.lower() in low_cost_lower)

    fig, ax = plt.subplots(figsize=(size_cfg['text_width'] / 4 * 2,
                                    size_cfg['text_height'] / 5 * 2))
    sns.countplot(hue=hardware_df['Low-cost'], y=hardware_df[col], ax=ax,
                  order=hardware_df[col].value_counts().index,
                  dodge=False)
    ax.set_xlabel('Number of papers')
    ax.set_ylabel('')
    plt.tight_layout()

    if save_cfg is not None:
        fname = os.path.join(save_cfg['savepath'], 'hardware')
        fig.savefig(fname + '.' + save_cfg['format'], **save_cfg)

    return ax


def plot_architectures(df, save_cfg=cfg.saving_config):
    """Draw a donut chart of the architectures used across the studies.

    One wedge per value of 'Architecture (clean)', most frequent first,
    labelled with its percentage in a small font.

    Keyword Args:
        save_cfg (dict): provides the figure size; when not None the figure is
            also written to '<savepath>/architectures.<format>'.

    Returns:
        matplotlib Axes containing the chart.
    """
    side = save_cfg['text_width'] / 3
    fig, ax = plt.subplots(figsize=(side, side))

    arch_counts = df['Architecture (clean)'].value_counts()
    wedge_style = {'width': 0.3, 'edgecolor': 'w'}
    pie_parts = ax.pie(arch_counts.values, labels=arch_counts.index,
                       autopct='%1.1f%%', wedgeprops=wedge_style,
                       colors=sns.color_palette(), pctdistance=0.55)
    # The third element holds the percentage labels created by `autopct`.
    for pct_label in pie_parts[2]:
        pct_label.set_fontsize(5)

    ax.axis('equal')
    plt.tight_layout()

    if save_cfg is not None:
        out_stem = os.path.join(save_cfg['savepath'], 'architectures')
        fig.savefig(out_stem + '.' + save_cfg['format'], **save_cfg)

    return ax
    

def plot_architectures_per_year(df, save_cfg=cfg.saving_config):
    """Draw a stacked bar chart of architectures per publication year.

    Side effects on `df`: 'Year' is cast to int32, and an 'Arch' copy of
    'Architecture (clean)' is added (its name becomes the legend title).
    Stacks follow the overall most-to-least frequent architecture order.

    Keyword Args:
        save_cfg (dict): provides the figure size; when not None the figure is
            also written to '<savepath>/architectures_per_year.<format>'.

    Returns:
        matplotlib Axes containing the chart.
    """
    width = save_cfg['text_width'] / 3 * 2
    height = save_cfg['text_width'] / 3
    fig, ax = plt.subplots(figsize=(width, height))
    palette = sns.color_palette()

    source_col = 'Architecture (clean)'
    df['Year'] = df['Year'].astype('int32')
    df['Arch'] = df[source_col]
    overall_order = df[source_col].value_counts().index
    per_year = df.groupby(['Year', 'Arch']).size().unstack('Arch')[overall_order]

    per_year.plot(kind='bar', stacked=True, title='', ax=ax, color=palette)
    ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
    ax.set_ylabel('Number of papers')
    ax.set_xlabel('')

    plt.tight_layout()

    if save_cfg is not None:
        out_stem = os.path.join(save_cfg['savepath'], 'architectures_per_year')
        fig.savefig(out_stem + '.' + save_cfg['format'], **save_cfg)

    return ax


def plot_architectures_vs_input(df, save_cfg=cfg.saving_config):
    """Draw a stacked bar chart of architectures split by input type.

    Side effects on `df`: 'Input' (copy of 'Features (clean)') and 'Arch'
    (copy of 'Architecture (clean)') columns are added; their names become
    the legend title and grouping labels. Bars follow the overall
    most-to-least frequent architecture order.

    Keyword Args:
        save_cfg (dict): provides the figure size; when not None the figure is
            saved as '<savepath>/architectures_vs_input.<format>' and as a PNG.

    Returns:
        matplotlib Axes containing the chart.
    """
    width = save_cfg['text_width'] / 4 * 2
    height = save_cfg['text_width'] / 3
    fig, ax = plt.subplots(figsize=(width, height))

    arch_col = 'Architecture (clean)'
    df['Input'] = df['Features (clean)']
    df['Arch'] = df[arch_col]
    arch_order = df[arch_col].value_counts().index
    table = (df.groupby(['Input', 'Arch']).size()
             .unstack('Input')
             .loc[arch_order, :]
             # Wrap the long tick label to reduce the figure height.
             .rename({'CNN+RNN': 'CNN+\nRNN'}, axis='index'))

    table.plot(kind='bar', stacked=True, title='', ax=ax)
    ax.set_ylabel('Number of papers')
    ax.set_xlabel('')

    plt.tight_layout()

    if save_cfg is not None:
        stem = os.path.join(save_cfg['savepath'], 'architectures_vs_input')
        fig.savefig(stem + '.' + save_cfg['format'], **save_cfg)

        png_cfg = dict(save_cfg, format='png')
        fig.savefig(stem + '.png', **png_cfg)

    return ax


def plot_optimizers_per_year(df, save_cfg=cfg.saving_config):
    """Plot stacked bar graph of optimizers per year.

    Side effects on `df`: 'Year' is cast to int32 (as in the other per-year
    plots, so the x-axis shows integer years), and an 'Opt' copy of
    'Optimizer (clean)' is added (its name becomes the legend title). Stacks
    follow the overall most-to-least frequent optimizer order.

    Keyword Args:
        save_cfg (dict or None): saving configuration. If None, the figure is
            sized with the default configuration and not saved.

    Returns:
        matplotlib Axes containing the chart.
    """
    size_cfg = save_cfg if save_cfg is not None else cfg.saving_config

    fig, ax = plt.subplots(
        figsize=(size_cfg['text_width'] / 4 * 2, size_cfg['text_width'] / 5 * 2))

    df['Year'] = df['Year'].astype('int32')
    col_name = 'Optimizer (clean)'
    df['Opt'] = df[col_name]
    order = df[col_name].value_counts().index
    counts = df.groupby(['Year', 'Opt']).size().unstack('Opt')
    counts = counts[order]

    counts.plot(kind='bar', stacked=True, title='', ax=ax)
    ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
    ax.set_ylabel('Number of papers')
    ax.set_xlabel('')

    plt.tight_layout()

    if save_cfg is not None:
        fname = os.path.join(save_cfg['savepath'], 'optimizers_per_year')
        fig.savefig(fname + '.' + save_cfg['format'], **save_cfg)

    return ax


def plot_intra_inter_per_year(df, save_cfg=cfg.saving_config):
    """Draw a stacked bar chart of intra- vs inter-subject studies per year.

    `df['Year']` is cast to int in place. The percentage of studies in each
    'Intra/Inter subject' category is written to the results log, and stacks
    follow the overall most-to-least frequent category order.

    Keyword Args:
        save_cfg (dict): provides the figure size; when not None the figure is
            also written to '<savepath>/intra_inter_per_year.<format>'.

    Returns:
        matplotlib Axes containing the chart.
    """
    fig, ax = plt.subplots(
        figsize=(save_cfg['text_width'] / 4 * 2, save_cfg['text_height'] / 4))

    split_col = 'Intra/Inter subject'
    df['Year'] = df['Year'].astype(int)
    frequencies = df[split_col].value_counts()
    per_year = df.groupby(['Year', split_col]).size().unstack(split_col)
    per_year = per_year[frequencies.index]

    logger.info('Stats on inter/intra subjects: {}'.format(
        frequencies / df.shape[0] * 100))

    per_year.plot(kind='bar', stacked=True, title='', ax=ax)
    ax.set_ylabel('Number of papers')
    ax.set_xlabel('')

    plt.tight_layout()

    if save_cfg is not None:
        out_stem = os.path.join(save_cfg['savepath'], 'intra_inter_per_year')
        fig.savefig(out_stem + '.' + save_cfg['format'], **save_cfg)

    return ax


def plot_number_layers(df, save_cfg=cfg.saving_config):
    """Plot histogram of number of layers.

    One bar per layer count found in 'Layers (clean)' (string labels that are
    canonical positive integers, sorted numerically), followed by an 'N/M' bar
    when that label occurs. Any other label is not shown.

    Keyword Args:
        save_cfg (dict or None): saving configuration. If None, the figure is
            sized with the default configuration and not saved. Otherwise the
            figure is saved in the configured format and as a 300-dpi PNG.

    Returns:
        matplotlib Axes containing the bar chart.
    """
    from matplotlib.colors import ListedColormap

    size_cfg = save_cfg if save_cfg is not None else cfg.saving_config

    fig, ax = plt.subplots(
        figsize=(size_cfg['text_width'] / 4 * 2, size_cfg['text_width'] / 3))

    def is_layer_count(label):
        # '1', '12', ... but not '0', '01' or non-numeric labels.
        return (isinstance(label, str) and label.isdecimal()
                and label == str(int(label)) and int(label) > 0)

    counts = df['Layers (clean)'].value_counts()
    # Use every layer count present instead of a fixed 1-31 range, so that
    # deeper models are not silently dropped from the plot.
    layer_labels = sorted(
        (label for label in counts.index if is_layer_count(label)), key=int)
    if 'N/M' in counts.index:
        layer_labels.append('N/M')
    n_layers_df = counts.reindex(layer_labels).astype(int)

    cmap = ListedColormap(sns.color_palette(None).as_hex())

    n_layers_df.plot(kind='bar', width=0.8, rot=0, colormap=cmap, ax=ax)
    ax.set_xlabel('Number of layers')
    ax.set_ylabel('Number of papers')
    plt.tight_layout()

    if save_cfg is not None:
        fname = os.path.join(save_cfg['savepath'], 'number_layers')
        fig.savefig(fname + '.' + save_cfg['format'], **save_cfg)

        save_cfg2 = save_cfg.copy()
        save_cfg2['format'] = 'png'
        save_cfg2['dpi'] = 300
        fig.savefig(fname + '.png', **save_cfg2)

    return ax


def plot_number_subjects_by_domain(df, save_cfg=cfg.saving_config):
    """Draw a swarm plot of the number of subjects per model, by main domain.

    `df` must already contain a 'Main domain' column (such as the one added
    by `plot_domains_per_year`). 'Data - subjects' entries are split into one
    row each; 'n/m' and 'tbd' entries and non-positive counts are discarded.
    Domains are listed from the lowest to the highest median number of
    subjects, on a logarithmic x-axis. Summary statistics go to the results
    log.

    Keyword Args:
        save_cfg (dict): provides the figure size; when not None the figure is
            also written to '<savepath>/nb_subject_per_domain.<format>'.

    Returns:
        matplotlib Axes containing the plot.
    """
    subj_col = 'Data - subjects'
    subjects = ut.split_column_with_multiple_entries(
        df, subj_col, ref_col='Main domain')
    is_reported = ~subjects[subj_col].isin(['n/m', 'tbd'])
    subjects = subjects.loc[is_reported]
    subjects[subj_col] = subjects[subj_col].astype(int)
    subjects = subjects.loc[subjects[subj_col] > 0, :]

    subjects['Main domain'] = subjects['Main domain'].apply(
        ut.wrap_text, max_char=13)

    domain_order = (subjects.groupby(['Main domain'])[subj_col]
                    .median().sort_values().index)

    fig, ax = plt.subplots(
        figsize=(save_cfg['text_width'] / 3 * 2, save_cfg['text_height'] / 3))
    ax.set(xscale='log', yscale='linear')
    sns.swarmplot(y='Main domain', x=subj_col, data=subjects, ax=ax, size=3,
                  order=domain_order)
    ax.set_xlabel('Number of subjects')
    ax.set_ylabel('')

    logger.info('Stats on number of subjects per model: {}'.format(
        subjects[subj_col].describe()))

    plt.tight_layout()

    if save_cfg is not None:
        out_stem = os.path.join(save_cfg['savepath'], 'nb_subject_per_domain')
        fig.savefig(out_stem + '.' + save_cfg['format'], **save_cfg)

    return ax


def plot_number_channels(df, save_cfg=cfg.saving_config):
    """Plot histogram of number of channels.

    'Nb Channels' entries are split (semicolon-newline separated) into one
    row per value and citation, cast to int, and non-positive values are
    dropped. Summary statistics go to the results log.

    Keyword Args:
        save_cfg (dict or None): saving configuration. If None, the figure is
            sized with the default configuration and not saved.

    Returns:
        matplotlib Axes containing the histogram.
    """
    size_cfg = save_cfg if save_cfg is not None else cfg.saving_config

    nb_channels_df = ut.split_column_with_multiple_entries(
        df, 'Nb Channels', ref_col='Citation', sep=';\n', lower=False)
    nb_channels_df['Nb Channels'] = nb_channels_df['Nb Channels'].astype(int)
    nb_channels_df = nb_channels_df.loc[nb_channels_df['Nb Channels'] > 0, :]
    n_channels = nb_channels_df['Nb Channels']

    fig, ax = plt.subplots(
        figsize=(size_cfg['text_width'] / 2, size_cfg['text_height'] / 4))
    if hasattr(sns, 'histplot'):
        # `distplot` is deprecated since seaborn 0.11 and scheduled for
        # removal; `histplot` draws a count histogram as well (its default
        # binning and styling differ slightly).
        sns.histplot(n_channels, ax=ax)
    else:
        sns.distplot(n_channels, kde=False, norm_hist=False, ax=ax)
    ax.set_xlabel('Number of EEG channels')
    ax.set_ylabel('Number of papers')

    logger.info('Stats on number of channels per model: {}'.format(
        n_channels.describe()))

    plt.tight_layout()

    if save_cfg is not None:
        fname = os.path.join(save_cfg['savepath'], 'nb_channels')
        fig.savefig(fname + '.' + save_cfg['format'], **save_cfg)

    return ax


def compute_stats_sampling_rate(df):
    """Log descriptive statistics of the reported EEG sampling rates.

    Multiple rates in one 'Sampling rate' cell are split into separate rows
    (semicolon-newline separated) and converted to float. Non-positive rates
    are discarded before the `describe()` summary is logged.
    """
    rates = ut.split_column_with_multiple_entries(
        df, 'Sampling rate', ref_col='Citation', sep=';\n', lower=False)
    rates['Sampling rate'] = rates['Sampling rate'].astype(float)
    is_positive = rates['Sampling rate'] > 0
    summary = rates.loc[is_positive, 'Sampling rate'].describe()
    logger.info('Stats on sampling rate per model: {}'.format(summary))


def plot_cross_validation(df, save_cfg=cfg.saving_config):
    """Draw a horizontal bar chart of the cross-validation approaches used.

    Missing 'Cross validation (clean)' entries are set to 'N/M' in `df`
    itself. Multi-valued entries are split into one row each, and bars are
    sorted from the most to the least frequent approach.

    Keyword Args:
        save_cfg (dict): provides the figure size; when not None the figure is
            also written to '<savepath>/cross_validation.<format>'.

    Returns:
        matplotlib Axes containing the chart.
    """
    cv_col = 'Cross validation (clean)'
    df[cv_col] = df[cv_col].fillna('N/M')
    cv_rows = ut.split_column_with_multiple_entries(
        df, cv_col, ref_col='Citation', sep=';\n', lower=False)
    most_common_first = cv_rows[cv_col].value_counts().index

    fig, ax = plt.subplots(
        figsize=(save_cfg['text_width'] / 2, save_cfg['text_height'] / 5))
    sns.countplot(y=cv_rows[cv_col], order=most_common_first, ax=ax)
    ax.set_xlabel('Number of papers')
    ax.set_ylabel('')

    plt.tight_layout()

    if save_cfg is not None:
        out_stem = os.path.join(save_cfg['savepath'], 'cross_validation')
        fig.savefig(out_stem + '.' + save_cfg['format'], **save_cfg)

    return ax


def make_dataset_table(df, min_n_articles=2, save_cfg=cfg.saving_config):
    """Make table that reports most used datasets.

    Dataset names are split into one row per dataset and citation. Variants of
    the BCI Competition, TUH and MAHNOB datasets are merged under one name,
    and unavailable datasets ('N/M', 'Internal Recordings', 'TBD') are
    removed. The table, grouped by main domain and sorted by number of
    articles, is written to '<table_savepath>/dataset_table.tex'.

    Args:
        df (pd.DataFrame): one row per paper, with 'Dataset name',
            'Main domain' and 'Citation' columns.

    Keyword Args:
        min_n_articles (int): minimum number of distinct articles that must
            have used a dataset for it to be listed in the table. If under that
            number, it appears as 'Other' in the table.
        save_cfg (dict): must provide 'table_savepath'.
    """
    def merge_dataset_names(s):
        """Map known dataset families to a single canonical name."""
        lowered = s.lower()
        if 'bci comp' in lowered:
            return 'BCI Competition'
        if 'tuh' in lowered:
            return 'TUH'
        if 'mahnob' in lowered:
            return 'MAHNOB'
        return s

    col = 'Dataset name'
    datasets_df = ut.split_column_with_multiple_entries(
        df, col, ref_col=['Main domain', 'Citation'], sep=';\n', lower=False)

    # Remove not mentioned and internal recordings, as readers won't be able to
    # use these datasets anyway. Copy so that the column added below goes to
    # an independent frame rather than to a filtered view.
    datasets_df = datasets_df.loc[~datasets_df[col].isin(
        ['N/M', 'Internal Recordings', 'TBD'])].copy()

    datasets_df['Dataset'] = datasets_df[col].apply(merge_dataset_names).apply(
        ut.tex_escape)

    # Remove duplicates created by merging names (e.g. one article using two
    # BCI Competition datasets) BEFORE counting, so that an article counts
    # only once per dataset.
    datasets_df = datasets_df.drop(columns=col).drop_duplicates()

    # Replace datasets used by fewer than `min_n_articles` distinct articles
    # by 'Other'
    n_articles = datasets_df.groupby('Dataset')['Citation'].nunique()
    rare_datasets = n_articles[n_articles < min_n_articles].index
    datasets_df.loc[datasets_df['Dataset'].isin(rare_datasets), 'Dataset'] = 'Other'

    # Grouping rare datasets under 'Other' can create new duplicates
    datasets_df = datasets_df.drop_duplicates()

    # Group by dataset and order by number of articles
    dataset_table = datasets_df.groupby(
        ['Main domain', 'Dataset'], as_index=True)['Citation'].apply(list)
    dataset_table = pd.concat([dataset_table.apply(len), dataset_table], axis=1)
    dataset_table.columns = [r'\# articles', 'References']

    dataset_table = dataset_table.sort_values(
        by=['Main domain', r'\# articles'], ascending=[True, False])
    dataset_table['References'] = dataset_table['References'].apply(
        lambda x: r'\cite{' + ', '.join(x) + '}')

    fname = os.path.join(save_cfg['table_savepath'], 'dataset_table.tex')
    with open(fname, 'w', encoding='utf-8') as f:
        with pd.option_context("max_colwidth", 1000):
            f.write(dataset_table.to_latex(escape=False, multicolumn=False))


def _log_decade_ticks(values, first_exponent=None):
    """Return powers of ten spanning the finite, positive entries of `values`.

    Args:
        values (array-like): data displayed on a log-scaled axis.
        first_exponent (int or None): exponent of the first tick. If None,
            the floor of log10 of the smallest usable value is used.

    Returns:
        np.ndarray of floats (possibly empty), or None when `values` has no
        finite positive entry (zeros, NaNs and infs cannot be placed on a
        log axis).
    """
    vals = np.asarray(values, dtype=float)
    vals = vals[np.isfinite(vals) & (vals > 0)]
    if vals.size == 0:
        return None
    if first_exponent is None:
        first_exponent = int(np.floor(np.log10(vals.min())))
    last_exponent = int(np.ceil(np.log10(vals.max())))
    # Float base: NumPy refuses integer bases with negative integer exponents.
    return np.power(10.0, np.arange(first_exponent, last_exponent + 1))


def plot_data_quantity(df, save_cfg=cfg.saving_config):
    """Plot the quantity of data used by domain.

    Three swarm plots share the main-domain axis: recording time (min),
    number of examples, and their ratio (examples per minute), all on log
    scales. 'N/M', 'TBD' and '[TBD]' entries are treated as missing.

    Args:
        df (pd.DataFrame): one row per paper, with 'Data - samples',
            'Data - time', 'Citation' and 'Main domain' columns.

    Keyword Args:
        save_cfg (dict or None): saving configuration. If None, the figure is
            sized with the default configuration and not saved.

    Returns:
        Array of the three matplotlib Axes.
    """
    size_cfg = save_cfg if save_cfg is not None else cfg.saving_config

    data_df = ut.split_column_with_multiple_entries(
        df, ['Data - samples', 'Data - time'], ref_col=['Citation', 'Main domain'],
        sep=';\n', lower=False)

    # Remove N/M and TBD
    col = 'Data - samples'
    col2 = 'Data - time'
    for c in (col, col2):
        data_df.loc[data_df[c].isin(['N/M', 'TBD', '[TBD]']), c] = np.nan
        data_df[c] = data_df[c].astype(float)

    # Wrap main domain text
    data_df['Main domain'] = data_df['Main domain'].apply(
        ut.wrap_text, max_char=13)

    # Extract ratio. A zero recording time would give an infinite ratio, which
    # cannot be shown on a log axis, so it is treated as missing.
    data_df['data_ratio'] = (data_df[col] / data_df[col2]).replace(
        [np.inf, -np.inf], np.nan)
    data_df = data_df.sort_values(['Main domain', 'data_ratio'])

    # Plot
    fig, axes = plt.subplots(
        ncols=3,
        figsize=(size_cfg['text_width'], size_cfg['text_height'] / 3))

    axes[0].set(xscale='log', yscale='linear')
    sns.swarmplot(y='Main domain', x=col2, data=data_df, ax=axes[0], size=3)
    axes[0].set_xlabel('Recording time (min)')
    axes[0].set_ylabel('')
    ticks = _log_decade_ticks(data_df[col2], first_exponent=0)
    if ticks is not None:
        axes[0].set_xticks(ticks)

    axes[1].set(xscale='log', yscale='linear')
    sns.swarmplot(y='Main domain', x=col, data=data_df, ax=axes[1], size=3)
    axes[1].set_xlabel('Number of examples')
    axes[1].set_yticklabels('')
    axes[1].set_ylabel('')
    ticks = _log_decade_ticks(data_df[col])
    if ticks is not None:
        axes[1].set_xticks(ticks)

    axes[2].set(xscale='log', yscale='linear')
    sns.swarmplot(y='Main domain', x='data_ratio', data=data_df, ax=axes[2],
                  size=3)
    axes[2].set_xlabel('Ratio (examples/min)')
    axes[2].set_ylabel('')
    axes[2].set_yticklabels('')
    ticks = _log_decade_ticks(data_df['data_ratio'])
    if ticks is not None:
        axes[2].set_xticks(ticks)

    plt.tight_layout()

    if save_cfg is not None:
        fname = os.path.join(save_cfg['savepath'], 'data_quantity')
        fig.savefig(fname + '.' + save_cfg['format'], **save_cfg)

    return axes


def plot_eeg_intro(save_cfg=cfg.saving_config):
    """Plot a figure that shows basic EEG notions such as epochs and samples.

    Four channels of real EEG (channels 0, 10, 20 and 30 returned by
    `ut.get_real_eeg_data`) are drawn with vertical offsets. Two dashed
    rectangles mark consecutive windows of `win_len` seconds shifted by
    `step` seconds. Annotations name a window, the neural network input, a
    single sample and, when the windows overlap, the overlap between them.
    All labels are derived from `win_len`, `step` and the sampling rate.

    Keyword Args:
        save_cfg (dict or None): saving configuration. If None, the figure is
            sized with the default configuration and not saved.

    Returns:
        matplotlib Axes containing the figure.
    """
    size_cfg = save_cfg if save_cfg is not None else cfg.saving_config

    # Visualization parameters
    win_len = 1  # in s
    step = 0.5  # in s
    first_epoch = 1
    overlap = win_len - step  # in s

    data, t, fs = ut.get_real_eeg_data(start=30, stop=34, chans=[0, 10, 20, 30])
    t = t - t[0]

    # Offset data for visualization
    data -= data.mean(axis=0)
    max_std = np.max(data.std(axis=0))
    offsets = np.arange(data.shape[1])[::-1] * 4 * max_std
    data += offsets

    rect_y_border = 0.6 * max_std
    min_y = data.min() - rect_y_border
    max_y = data.max() + rect_y_border

    # Make figure
    fig, ax = plt.subplots(
        figsize=(size_cfg['text_width'] / 4 * 3, size_cfg['text_height'] / 3))
    ax.plot(t, data)
    ax.set_xlabel('Time (s)')
    ax.set_ylabel(r'Amplitude (e.g., $\mu$V)')
    ax.set_yticks(offsets)
    ax.set_yticklabels(['channel {}'.format(i + 1) for i in range(data.shape[1])])
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    # Display epochs as dashed line rectangles (vertically nudged in opposite
    # directions, presumably so both outlines stay visible where they overlap)
    rect1 = patches.Rectangle((first_epoch, min_y + rect_y_border / 4),
                              win_len, max_y - min_y,
                              linewidth=1, linestyle='--', edgecolor='k',
                              facecolor='none')
    rect2 = patches.Rectangle((first_epoch + step, min_y - rect_y_border / 4),
                              win_len, max_y - min_y,
                              linewidth=1, linestyle='--', edgecolor='k',
                              facecolor='none')

    ax.add_patch(rect1)
    ax.add_patch(rect2)

    # Annotate epochs
    ax.annotate(
        r'$\bf{Window}$ or $\bf{epoch}$ or $\bf{trial}$' +
        '\n({:.0f} points in a \n{:g}-s window at {:.0f} Hz)'.format(
            win_len * fs, win_len, fs),
        xy=(first_epoch, min_y),
        arrowprops=dict(facecolor='black', shrink=0.05, width=2, headwidth=6),
        xytext=(0, min_y - 3.5 * max_std),
        xycoords='data', ha='center', va='top')

    # Annotate input (the arrow points at the right edge of the second window)
    ax.annotate(r'Neural network input' + '\n'
                r'$X_i \in \mathbb{R}^{c \times l}$',
                xy=(first_epoch + step + win_len, min_y),
                arrowprops=dict(facecolor='black', shrink=0.05, width=2),
                xytext=(4, min_y - 5.3 * max_std),
                xycoords='data', ha='right', va='bottom')

    # Annotate sample
    special_ind = np.where((t >= 2.4) & (t < 2.5))[0][0]
    special_point = data[special_ind, 0]
    ax.plot(t[special_ind], special_point, '.', c='k')
    ax.annotate(
        r'$\bf{Point}$ or $\bf{sample}$',
        xy=(t[special_ind], special_point),
        arrowprops=dict(facecolor='black', shrink=0.05, width=2, headwidth=6),
        xytext=(3, max_y),
        xycoords='data', ha='left', va='bottom')

    # Annotate overlap: it spans from the start of the second window to the
    # end of the first one.
    if overlap > 0:
        ut.draw_brace(ax, (first_epoch + step, first_epoch + win_len),
                      '{:g}-s '.format(overlap) + r'$\bf{overlap}$' +
                      '\nbetween windows',
                      beta_factor=300, y_offset=max_y)

    plt.tight_layout()

    if save_cfg is not None:
        fname = os.path.join(save_cfg['savepath'], 'eeg_intro')
        fig.savefig(fname + '.' + save_cfg['format'], **save_cfg)

    return ax