"""Functions to plot and compute results for the literature review. """ import os import logging import logging.config from collections import OrderedDict import subprocess import pandas as pd import geopandas as gpd import matplotlib.pyplot as plt from matplotlib.ticker import MaxNLocator import matplotlib.patches as patches import matplotlib as mpl import seaborn as sns import numpy as np from PIL import Image from wordcloud import WordCloud, STOPWORDS from graphviz import Digraph import config as cfg import utils as ut # Set style, context and palette sns.set_style(rc=cfg.axes_styles) sns.set_context(rc=cfg.plotting_context) sns.set_palette(cfg.palette) for key, val in cfg.axes_styles.items(): mpl.rcParams[key] = val for key, val in cfg.plotting_context.items(): mpl.rcParams[key] = val # Initialize logger for saving results and stats. Use `logger.info('message')` # to log results. logging.config.dictConfig({ 'version': 1, 'disable_existing_loggers': True, }) logger = logging.getLogger() log_savename = os.path.join(cfg.saving_config['savepath'], 'results') + '.log' handler = logging.FileHandler(log_savename, mode='w') formatter = logging.Formatter( '%(asctime)s %(name)-12s %(levelname)-8s %(message)s') handler.setFormatter(formatter) logger.addHandler(handler) logger.setLevel(logging.INFO) def plot_prisma_diagram(save_cfg=cfg.saving_config): """Plot diagram showing the number of selected articles. 
TODO: - Use first two colors of colormap instead of gray - Reduce white space - Reduce arrow width """ # save_format = save_cfg['format'] if isinstance(save_cfg, dict) else 'svg' save_format = 'pdf' # save_format = 'eps' size = '{},{}!'.format(0.5 * save_cfg['page_width'], 0.2 * save_cfg['page_height']) dot = Digraph(format=save_format) dot.attr('graph', rankdir='TB', overlap='false', size=size, margin='0') dot.attr('node', fontname='Liberation Sans', fontsize=str(9), shape='box', style='filled', margin='0.15,0.07', penwidth='0.1') # dot.attr('edge', arrowsize=0.5) fillcolor = 'gray98' dot.node('A', 'PubMed (n=39)\nGoogle Scholar (n=409)\narXiv (n=105)', fillcolor='gray95') dot.node('B', 'Articles identified\nthrough database\nsearching\n(n=553)', fillcolor=fillcolor) # dot.node('B2', 'Excluded\n(n=446)', fillcolor=fillcolor) dot.node('C', 'Articles after content\nscreening and\nduplicate removal\n(n=105) ', fillcolor=fillcolor) dot.node('D', 'Articles included in\nthe analysis\n(n=154)', fillcolor=fillcolor) dot.node('E', 'Additional articles\nidentified through\nbibliography search\n(n=49)', fillcolor=fillcolor) dot.edge('B', 'C') # dot.edge('B', 'B2') dot.edge('C', 'D') dot.edge('E', 'D') if save_cfg is not None: fname = os.path.join(save_cfg['savepath'], 'prisma_diagram') dot.render(filename=fname, view=False, cleanup=False) return dot def plot_domain_tree(df, first_box='DL + EEG studies', min_font_size=10, max_font_size=14, max_char=16, min_n_items=2, postprocess=True, save_cfg=cfg.saving_config): """Plot tree graph showing the breakdown of study domains. Args: df (pd.DataFrame): data items table Keyword Args: first_box (str): text of the first box min_font_size (int): minimum font size max_font_size (int): maximum font size max_char (int): maximum number of characters per line min_n_items (int): if a node has less than this number of elements, put it inside a node called "Others". postpocess (bool): if True, convert PNG to EPS using inkscape in a system call. 
save_cfg (dict or None): Returns: (graphviz.Digraph): graphviz object NOTES: - To unflatten automatically, apply the following on the .dot file: >> unflatten -l 3 -c 10 dom_domains_tree | dot -Teps -o domains_unflattened.eps - To produce a circular version instead (uses space more efficiently): >> neato -Tpdf dom_domains_tree -o domains_neato.pdf """ df = df[['Domain 1', 'Domain 2', 'Domain 3', 'Domain 4']].copy() df = df[~df['Domain 1'].isnull()] df[df == ' '] = None n_samples, n_levels = df.shape format = save_cfg['format'] if isinstance(save_cfg, dict) else 'svg' size = '{},{}!'.format(save_cfg['page_width'], 0.7 * save_cfg['page_height']) dot = Digraph(format=format) dot.attr('graph', rankdir='TB', overlap='false', ratio='fill', size=size) # LR (left to right), TB (top to bottom) dot.attr('node', fontname='Liberation Sans', fontsize=str(max_font_size), shape='box', style='filled, rounded', margin='0.2,0.01', penwidth='0.5') dot.node('A', '{}\n({})'.format(first_box, len(df)), fillcolor='azure') min_sat, max_sat = 0.05, 0.4 sub_df = df['Domain 1'].value_counts() n_categories = len(sub_df) for i, (d1, count1) in enumerate(sub_df.iteritems()): node1, hue = ut.make_box( dot, d1, max_char, count1, n_samples, 0, n_levels, min_sat, max_sat, min_font_size, max_font_size, 'A', counter=i, n_categories=n_categories) for d2, count2 in df[df['Domain 1'] == d1]['Domain 2'].value_counts().iteritems(): node2, _ = ut.make_box( dot, d2, max_char, count2, n_samples, 1, n_levels, min_sat, max_sat, min_font_size, max_font_size, node1, hue=hue) n_others3 = 0 for d3, count3 in df[df['Domain 2'] == d2]['Domain 3'].value_counts().iteritems(): if isinstance(d3, str) and d3 != 'TBD': if count3 < min_n_items: n_others3 += 1 else: node3, _ = ut.make_box( dot, d3, max_char, count3, n_samples, 2, n_levels, min_sat, max_sat, min_font_size, max_font_size, node2, hue=hue) n_others4 = 0 for d4, count4 in df[df['Domain 3'] == d3]['Domain 4'].value_counts().iteritems(): if isinstance(d4, str) and 
d4 != 'TBD': if count4 < min_n_items: n_others4 += 1 else: ut.make_box( dot, d4, max_char, count4, n_samples, 3, n_levels, min_sat, max_sat, min_font_size, max_font_size, node3, hue=hue) if n_others4 > 0: ut.make_box( dot, 'Others', max_char, n_others4, n_samples, 3, n_levels, min_sat, max_sat, min_font_size, max_font_size, node3, hue=hue, node_name=node3+'others') if n_others3 > 0: ut.make_box( dot, 'Others', max_char, n_others3, n_samples, 2, n_levels, min_sat, max_sat, min_font_size, max_font_size, node2, hue=hue, node_name=node2+'others') if save_cfg is not None: fname = os.path.join(save_cfg['savepath'], 'dom_domains_tree') dot.render(filename=fname, cleanup=False) if postprocess: subprocess.call( ['neato', '-Tpdf', fname, '-o', fname + '.pdf']) return dot def plot_model_comparison(df, save_cfg=cfg.saving_config): """Plot bar graph showing the types of baseline models used. """ fig, ax = plt.subplots(figsize=(save_cfg['text_width'] / 4 * 2, save_cfg['text_height'] / 5)) sns.countplot(y=df['Baseline model type'].dropna(axis=0), ax=ax) ax.set_xlabel('Number of papers') ax.set_ylabel('') plt.tight_layout() model_prcts = df['Baseline model type'].value_counts() / df.shape[0] * 100 logger.info('% of studies that used at least one traditional baseline: {}'.format( model_prcts['Traditional pipeline'] + model_prcts['DL & Trad.'])) logger.info('% of studies that used at least one deep learning baseline: {}'.format( model_prcts['DL'] + model_prcts['DL & Trad.'])) logger.info('% of studies that did not report baseline comparisons: {}'.format( model_prcts['None'])) if save_cfg is not None: fname = os.path.join(save_cfg['savepath'], 'model_comparison') fig.savefig(fname + '.' + save_cfg['format'], **save_cfg) return ax def plot_performance_metrics(df, cutoff=3, eeg_clf=None, save_cfg=cfg.saving_config): """Plot bar graph showing the types of performance metrics used. 
Args: df (DataFrame) Keyword Args: cutoff (int): Metrics with less than this number of papers will be cut off from the bar graph. eeg_clf (bool): If True, only use studies that focus on EEG classification. If False, only use studies that did not focus on EEG classification. If None, use all studies. save_cfg (dict) Assumptions, simplifications: - Rates have been simplified (e.g., "false positive rate" -> "false positives") - RMSE and MSE have been merged under MSE - Training/testing times have been simplified to "time" - Macro f1-score === f1=score """ if eeg_clf is True: metrics = df[df['Domain 1'] == 'Classification of EEG signals'][ 'Performance metrics (clean)'] elif eeg_clf is False: metrics = df[df['Domain 1'] != 'Classification of EEG signals'][ 'Performance metrics (clean)'] elif eeg_clf is None: metrics = df['Performance metrics (clean)'] metrics = metrics.str.split(',').apply(ut.lstrip) metric_per_article = list() for i, metric_list in metrics.iteritems(): for m in metric_list: metric_per_article.append([i, m]) metrics_df = pd.DataFrame(metric_per_article, columns=['paper nb', 'metric']) # Replace equivalent terms by standardized term equivalences = {'selectivity': 'specificity', 'true negative rate': 'specificity', 'sensitivitiy': 'sensitivity', 'sensitivy': 'sensitivity', 'recall': 'sensitivity', 'hit rate': 'sensitivity', 'true positive rate': 'sensitivity', 'sensibility': 'sensitivity', 'positive predictive value': 'precision', 'f-measure': 'f1-score', 'f-score': 'f1-score', 'f1-measure': 'f1-score', 'macro f1-score': 'f1-score', 'macro-averaging f1-score': 'f1-score', 'kappa': 'cohen\'s kappa', 'mae': 'mean absolute error', 'false negative rate': 'false negatives', 'fpr': 'false positives', 'false positive rate': 'false positives', 'false prediction rate': 'false positives', 'roc': 'ROC curves', 'roc auc': 'ROC AUC', 'rmse': 'mean squared error', 'mse': 'mean squared error', 'training time': 'time', 'testing time': 'time', 'test error': 'error'} 
metrics_df = metrics_df.replace(equivalences) metrics_df['metric'] = metrics_df['metric'].apply(lambda x: x[0].upper() + x[1:]) # Removing low count categories metrics_counts = metrics_df['metric'].value_counts() metrics_df = metrics_df[metrics_df['metric'].isin( metrics_counts[(metrics_counts >= cutoff)].index)] fig, ax = plt.subplots(figsize=(save_cfg['text_width'] / 2, save_cfg['text_height'] / 5)) ax = sns.countplot(y='metric', data=metrics_df, order=metrics_df['metric'].value_counts().index) ax.set_xlabel('Number of papers') ax.set_ylabel('') plt.tight_layout() if save_cfg is not None: savename = 'performance_metrics' if eeg_clf is True: savename += '_eeg_clf' elif eeg_clf is False: savename += '_not_eeg_clf' fname = os.path.join(save_cfg['savepath'], savename) fig.savefig(fname + '.' + save_cfg['format'], **save_cfg) return ax def plot_reported_results(df, data_items_df=None, save_cfg=cfg.saving_config): """Plot figures to described the reported results in the studies. Args: df (DataFrame): contains reported results (second tab in spreadsheet) Keyword Args: data_items_df (DataFrame): contains data items (first tab in spreadsheet) save_cfg (dict) Returns: (list): list of axes to created figures TODO: - This function is starting to be a bit too big. Should probably split it up. 
""" acc_df = df[df['Metric'] == 'accuracy'] # Extract accuracy rows only # Create new column that contains both citation and task information acc_df.loc[:, 'citation_task'] = acc_df[['Citation', 'Task']].apply( lambda x: ' ['.join(x) + ']', axis=1) # Create a new column with the year acc_df.loc[:, 'year'] = acc_df['Citation'].apply( lambda x: int(x[x.find('2'):x.find('2') + 4])) # Order by average proposed model accuracy acc_ind = acc_df[acc_df['model_type'] == 'Proposed'].groupby( 'Citation').mean().sort_values(by='Result').index acc_df.loc[:, 'Citation'] = acc_df['Citation'].astype('category') acc_df['Citation'].cat.set_categories(acc_ind, inplace=True) acc_df = acc_df.sort_values(['Citation']) # Only keep 2 best per task and model type acc2_df = acc_df.sort_values( ['Citation', 'Task', 'model_type', 'Result'], ascending=True).groupby( ['Citation', 'Task', 'model_type']).tail(2) axes = list() axes.append(_plot_results_per_citation_task(acc2_df, save_cfg)) # Only keep the maximum accuracy per citation & task best_df = acc_df.groupby( ['Citation', 'Task', 'model_type'])[ 'Result'].max().reset_index() # Only keep citations/tasks that have a traditional baseline best_df = best_df.groupby(['Citation', 'Task']).filter( lambda x: 'Baseline (traditional)' in x.values).reset_index() # Add back architecture best_df = pd.merge( best_df, acc_df[['Citation', 'Task', 'model_type', 'Result', 'Architecture']], how='inner').drop_duplicates() # XXX: why are there duplicates? 
# Compute difference between proposed and traditional baseline def acc_diff_and_arch(x): diff = x[x['model_type'] == 'Proposed']['Result'].iloc[0] - \ x[x['model_type'] == 'Baseline (traditional)']['Result'].iloc[0] arch = x[x['model_type'] == 'Proposed']['Architecture'] return pd.Series(diff, arch) diff_df = best_df.groupby(['Citation', 'Task']).apply( acc_diff_and_arch).reset_index() diff_df = diff_df.rename(columns={0: 'acc_diff'}) axes.append(_plot_results_accuracy_diff_scatter(diff_df, save_cfg)) axes.append(_plot_results_accuracy_diff_distr(diff_df, save_cfg)) # Pivot dataframe to plot proposed vs. baseline accuracy as a scatterplot best_df['citation_task'] = best_df[['Citation', 'Task']].apply( lambda x: ' ['.join(x) + ']', axis=1) acc_comparison_df = best_df.pivot_table( index='citation_task', columns='model_type', values='Result') axes.append(_plot_results_accuracy_comparison(acc_comparison_df, save_cfg)) if data_items_df is not None: domains_df = data_items_df.filter( regex='(?=Domain*|Citation|Main domain|Journal / Origin|Dataset name|' 'Data - samples|Data - time|Data - subjects|Preprocessing \(clean\)|' 'Artefact handling \(clean\)|Features \(clean\)|Architecture \(clean\)|' 'Layers \(clean\)|Regularization \(clean\)|Optimizer \(clean\)|' 'Intra/Inter subject|Training procedure)') # Concatenate domains into one string def concat_domains(x): domain = '' for i in x[1:]: if isinstance(i, str): domain += i + '/' return domain[:-1] domains_df.loc[:, 'domain'] = data_items_df.filter( regex='(?=Domain*)').apply(concat_domains, axis=1) diff_domain_df = diff_df.merge(domains_df, on='Citation', how='left') diff_domain_df = diff_domain_df.sort_values(by='domain') diff_domain_df.loc[:, 'arxiv'] = diff_domain_df['Journal / Origin'] == 'Arxiv' axes.append(_plot_results_accuracy_per_domain( diff_domain_df, diff_df, save_cfg)) axes.append(_plot_results_stats_impact_on_acc_diff( diff_domain_df, save_cfg)) axes.append(_compute_acc_diff_for_preprints(diff_domain_df, 
save_cfg)) return axes def _plot_results_per_citation_task(results_df, save_cfg): """Plot scatter plot of accuracy for each condition and task. """ fig, ax = plt.subplots(figsize=(save_cfg['text_width'], save_cfg['text_height'] * 1.3)) # figsize = plt.rcParams.get('figure.figsize') # fig, ax = plt.subplots(figsize=(figsize[0], figsize[1] * 4)) # Need to make the graph taller otherwise the y axis labels are on top of # each other. sns.catplot(y='citation_task', x='Result', hue='model_type', data=results_df, ax=ax) ax.set_xlabel('accuracy') ax.set_ylabel('') plt.tight_layout() if save_cfg is not None: savename = 'reported_results' fname = os.path.join(save_cfg['savepath'], savename) fig.savefig(fname + '.' + save_cfg['format'], **save_cfg) return ax def _plot_results_accuracy_diff_scatter(results_df, save_cfg): """Plot difference in accuracy for each condition/task as a scatter plot. """ fig, ax = plt.subplots(figsize=(save_cfg['text_width'], save_cfg['text_height'] * 1.3)) # figsize = plt.rcParams.get('figure.figsize') # fig, ax = plt.subplots(figsize=(figsize[0], figsize[1] * 2)) sns.catplot(y='Task', x='acc_diff', data=results_df, ax=ax) ax.set_xlabel('Accuracy difference') ax.set_ylabel('') ax.axvline(0, c='k', alpha=0.2) plt.tight_layout() if save_cfg is not None: savename = 'reported_accuracy_diff_scatter' fname = os.path.join(save_cfg['savepath'], savename) fig.savefig(fname + '.' + save_cfg['format'], **save_cfg) return ax def _plot_results_accuracy_diff_distr(results_df, save_cfg): """Plot the distribution of difference in accuracy. """ fig, ax = plt.subplots(figsize=(save_cfg['text_width'], save_cfg['text_height'] * 0.5)) sns.distplot(results_df['acc_diff'], kde=False, rug=True, ax=ax) ax.set_xlabel('Accuracy difference') ax.set_ylabel('Number of studies') plt.tight_layout() if save_cfg is not None: savename = 'reported_accuracy_diff_distr' fname = os.path.join(save_cfg['savepath'], savename) fig.savefig(fname + '.' 
+ save_cfg['format'], **save_cfg) return ax def _plot_results_accuracy_comparison(results_df, save_cfg): """Plot the comparison between the best model and best baseline. """ fig, ax = plt.subplots(figsize=(save_cfg['text_width'], save_cfg['text_height'] * 0.5)) sns.scatterplot(data=results_df, x='Baseline (traditional)', y='Proposed', ax=ax) ax.plot([0, 1.1], [0, 1.1], c='k', alpha=0.2) plt.axis('square') ax.set_xlim([0, 1.1]) ax.set_ylim([0, 1.1]) plt.tight_layout() if save_cfg is not None: savename = 'reported_accuracy_comparison' fname = os.path.join(save_cfg['savepath'], savename) fig.savefig(fname + '.' + save_cfg['format'], **save_cfg) return ax def _plot_results_accuracy_per_domain(results_df, diff_df, save_cfg): """Make scatterplot + boxplot to show accuracy difference by domain. """ fig, axes = plt.subplots( nrows=2, ncols=1, sharex=True, figsize=(save_cfg['text_width'], save_cfg['text_height'] / 3), gridspec_kw = {'height_ratios':[5, 1]}) results_df['Main domain'] = results_df['Main domain'].apply( ut.wrap_text, max_char=20) sns.catplot(y='Main domain', x='acc_diff', s=3, jitter=True, data=results_df, ax=axes[0]) axes[0].set_xlabel('') axes[0].set_ylabel('') axes[0].axvline(0, c='k', alpha=0.2) sns.boxplot(x='acc_diff', data=diff_df, ax=axes[1]) sns.swarmplot(x='acc_diff', data=diff_df, color="0", size=2, ax=axes[1]) axes[1].axvline(0, c='k', alpha=0.2) axes[1].set_xlabel('Accuracy difference') fig.subplots_adjust(wspace=0, hspace=0.02) plt.tight_layout() logger.info('Number of studies included in the accuracy improvement analysis: {}'.format( results_df.shape[0])) median = diff_df['acc_diff'].median() iqr = diff_df['acc_diff'].quantile(.75) - diff_df['acc_diff'].quantile(.25) logger.info('Median gain in accuracy: {:.6f}'.format(median)) logger.info('Interquartile range of the gain in accuracy: {:.6f}'.format(iqr)) best_improvement = diff_df.nlargest(3, 'acc_diff') logger.info('Best improvement in accuracy: {}, in {}'.format( 
best_improvement['acc_diff'].values[0], best_improvement['Citation'].values[0])) logger.info('Second best improvement in accuracy: {}, in {}'.format( best_improvement['acc_diff'].values[1], best_improvement['Citation'].values[1])) logger.info('Third best improvement in accuracy: {}, in {}'.format( best_improvement['acc_diff'].values[2], best_improvement['Citation'].values[2])) if save_cfg is not None: savename = 'reported_accuracy_per_domain' fname = os.path.join(save_cfg['savepath'], savename) fig.savefig(fname + '.' + save_cfg['format'], **save_cfg) return axes def _plot_results_stats_impact_on_acc_diff(results_df, save_cfg): """Run statistical analysis to see which data items correlate with acc diff. NOTE: This analysis is not perfectly accurate as there are several papers which contrasted results based on data items (e.g., testing the impact of number of layers on performance), but our summaries are not at this level of granularity. Therefore the results are not to be taken at face value. 
""" binary_data_items = {'Preprocessing (clean)': ['Yes', 'No'], 'Artefact handling (clean)': ['Yes', 'No'], 'Features (clean)': ['Raw EEG', 'Frequency-domain'], 'Regularization (clean)': ['Yes', 'N/M'], 'Intra/Inter subject': ['Intra', 'Inter']} multiclass_data_items = ['Architecture', # Architecture (clean)', 'Optimizer (clean)'] continuous_data_items = {'Layers (clean)': False, 'Data - subjects': True, 'Data - time': True, 'Data - samples': True} results = dict() for key, val in binary_data_items.items(): results[key] = ut.run_mannwhitneyu(results_df, key, val, plot=True) for i in multiclass_data_items: results[i] = ut.run_kruskal(results_df, i, plot=True) for i in continuous_data_items: single_df = ut.keep_single_valued_rows(results_df, i) single_df = single_df[single_df[i] != 'N/M'] single_df[i] = single_df[i].astype(float) results[i] = ut.run_spearmanr(single_df, i, log=val, plot=True) stats_df = pd.DataFrame(results).T logger.info('Results of statistical tests on impact of data items:\n{}'.format( stats_df)) # Categorical plot for each "significant" data item significant_items = stats_df[stats_df['pvalue'] < 0.05].index if save_cfg is not None and len(significant_items) > 0: for i in significant_items: savename = 'stat_impact_{}_on_acc_diff'.format( i.replace(' ', '_').replace('/', '_')) fname = os.path.join(save_cfg['savepath'], savename) stats_df.loc[i, 'fig'].savefig( fname + '.' + save_cfg['format'], **save_cfg) return None def _compute_acc_diff_for_preprints(results_df, save_cfg): """Analyze the acc diff for preprints vs. peer-reviewed articles. """ results_df['preprint'] = results_df['Journal / Origin'].isin(['Arxiv', 'BioarXiv']) preprints = results_df.groupby('Citation').first()['preprint'].value_counts() logger.info( 'Number of preprints included in the accuracy difference comparison: ' '{}/{} papers'.format(preprints[True], sum(preprints))) logger.info('Median acc diff for preprints vs. 
non-preprint:\n{}'.format( results_df.groupby('preprint').median()['acc_diff'])) results = ut.run_mannwhitneyu(results_df, 'preprint', [True, False]) logger.info('Mann-Whitney test on preprint vs. not preprint: {:0.3f}'.format( results['pvalue'])) return results def generate_wordcloud(df, save_cfg=cfg.saving_config): brain_mask = np.array(Image.open("./img/brain_stencil.png")) def transform_format(val): if val == 0: return 255 else: return val text = (df['Title']).to_string() stopwords = set(STOPWORDS) stopwords.add("using") stopwords.add("based") wc = WordCloud( background_color="white", max_words=2000, max_font_size=50, mask=brain_mask, stopwords=stopwords, contour_width=1, contour_color='steelblue') wc.generate(text) # store to file if save_cfg is not None: fname = os.path.join(save_cfg['savepath'], 'DL-EEG_WordCloud') wc.to_file(fname + '.' + save_cfg['format']) #, **save_cfg) def plot_model_inspection_and_table(df, cutoff=1, save_cfg=cfg.saving_config): """Make bar graph and table listing method inspection techniques. Args: df (DataFrame) Keyword Args: cutoff (int): Metrics with less than this number of papers will be cut off from the bar graph. 
save_cfg (dict) """ df['inspection_list'] = df[ 'Model inspection (clean)'].str.split(',').apply(ut.lstrip) inspection_per_article = list() for i, items in df[['Citation', 'inspection_list']].iterrows(): for m in items['inspection_list']: inspection_per_article.append([i, items['Citation'], m]) inspection_df = pd.DataFrame( inspection_per_article, columns=['paper nb', 'Citation', 'inspection method']) # Remove "no" entries, because they make it really hard to see the # actual distribution n_nos = inspection_df['inspection method'].value_counts()['no'] n_papers = inspection_df.shape[0] logger.info('Number of papers without model inspection method: {}'.format(n_nos)) inspection_df = inspection_df[inspection_df['inspection method'] != 'no'] # # Replace "no" by "None" # inspection_df['inspection method'][ # inspection_df['inspection method'] == 'no'] = 'None' # Removing low count categories inspection_counts = inspection_df['inspection method'].value_counts() inspection_df = inspection_df[inspection_df['inspection method'].isin( inspection_counts[(inspection_counts >= cutoff)].index)] inspection_df['inspection method'] = inspection_df['inspection method'].apply( lambda x: x.capitalize()) print(inspection_df['inspection method']) # Making table inspection_table = inspection_df.groupby(['inspection method'])[ 'Citation'].apply(list) order = inspection_df['inspection method'].value_counts().index inspection_table = inspection_table.reindex(order) inspection_table = inspection_table.apply(lambda x: r'\cite{' + ', '.join(x) + '}') with open(os.path.join(save_cfg['table_savepath'], 'inspection_methods.tex'), 'w') as f: with pd.option_context("max_colwidth", 1000): f.write(inspection_table.to_latex(escape=False)) fig, ax = plt.subplots(figsize=(save_cfg['text_width'] / 4 * 3, save_cfg['text_height'] / 2)) ax = sns.countplot(y='inspection method', data=inspection_df, order=inspection_df['inspection method'].value_counts().index) ax.set_xlabel('Number of papers') 
ax.set_ylabel('') ax.xaxis.set_major_locator(MaxNLocator(integer=True)) plt.tight_layout() logger.info('% of studies that used model inspection techniques: {}'.format( 100 - 100 * (n_nos / n_papers))) if save_cfg is not None: savename = 'model_inspection' fname = os.path.join(save_cfg['savepath'], savename) fig.savefig(fname + '.' + save_cfg['format'], **save_cfg) return ax def plot_type_of_paper(df, save_cfg=cfg.saving_config): """Plot bar graph showing the type of each paper (journal, conference, etc.). """ # Move supplements to journal paper category for the plot (a value of one is # not visible on a bar graph). df_plot = df.copy() df_plot.loc[df['Type of paper'] == 'Supplement', :] = 'Journal' fig, ax = plt.subplots(figsize=(save_cfg['text_width'] / 4, save_cfg['text_height'] / 5)) sns.countplot(x=df_plot['Type of paper'], ax=ax) ax.set_xlabel('') ax.set_ylabel('Number of papers') ax.set_xticklabels(ax.get_xticklabels(), rotation=90) plt.tight_layout() counts = df['Type of paper'].value_counts() logger.info('Number of journal papers: {}'.format(counts['Journal'])) logger.info('Number of conference papers: {}'.format(counts['Conference'])) logger.info('Number of preprints: {}'.format(counts['Preprint'])) logger.info('Number of papers that were initially published as preprints: ' '{}'.format(df[df['Type of paper'] != 'Preprint'][ 'Preprint first'].value_counts()['Yes'])) if save_cfg is not None: fname = os.path.join(save_cfg['savepath'], 'type_of_paper') fig.savefig(fname + '.' + save_cfg['format'], **save_cfg) return ax def plot_country(df, save_cfg=cfg.saving_config): """Plot bar graph showing the country of the first author's affiliation. 
""" fig, ax = plt.subplots(figsize=(save_cfg['text_width'] / 4 * 3, save_cfg['text_height'] / 5)) sns.countplot(x=df['Country'], ax=ax, order=df['Country'].value_counts().index) ax.set_ylabel('Number of papers') ax.set_xlabel('') ax.set_xticklabels(ax.get_xticklabels(), rotation=90) plt.tight_layout() top3 = df['Country'].value_counts().index[:3] logger.info('Top 3 countries of first author affiliation: {}'.format(top3.values)) if save_cfg is not None: fname = os.path.join(save_cfg['savepath'], 'country') fig.savefig(fname + '.' + save_cfg['format'], **save_cfg) return ax def plot_countrymap(dfx, postprocess=True, save_cfg=cfg.saving_config): """Plot world map with colour indicating number of papers. Plot a world map where the colour of each country indicates how many papers were published in which the first author's affiliation was from that country. When saved as .eps this figure is well over the 6 MB limit allowed by arXiv. To solve this, we first save it as a .png (with high enough dpi), then use inkscape to convert it to .eps (leading to a file of ~1.6 MB): >> inkscape countrymap.png --export-eps=countrymap.eps Keyword Args: postpocess (bool): if True, convert PNG to EPS using inkscape in a system call. 
""" dirname = os.path.dirname(__file__) shapefile = os.path.join(dirname, '../img/countries/ne_10m_admin_0_countries.shp') gdf = gpd.read_file(shapefile)[['ADMIN', 'geometry']] #.to_crs('+proj=robin') # gdf = gdf.to_crs(epsg=4326) gdf.crs = '+init=epsg:4326' dfx = dfx.Country.value_counts().reset_index().rename( columns={'index': 'Country', 'Country': 'Count'}) #print("Renaming Exceptions!") #print(dfx.loc[~dfx['Country'].isin(gdf['ADMIN'])]) # Exception #1 - USA: United States of America dfx.loc[dfx['Country'] == 'USA', 'Country'] = 'United States of America' # Exception #2 - UK: United Kingdom dfx.loc[dfx['Country'] == 'UK', 'Country'] = 'United Kingdom' # Exception #3 - Bosnia: Bosnia and Herzegovina dfx.loc[dfx['Country'] == 'Bosnia', 'Country'] = 'Bosnia and Herzegovina' if len(dfx.loc[~dfx['Country'].isin(gdf['ADMIN'])]) > 0: print("## ERROR ## - Unhandled Countries!") # Adding 0 to all other countries! gdf['Count'] = 0 for c in gdf['ADMIN']: if any(dfx['Country'].str.contains(c)): gdf.loc[gdf['ADMIN'] == c, 'Count'] = dfx[ dfx['Country'].str.contains(c)]['Count'].values[0] else: gdf.loc[gdf['ADMIN'] == c, 'Count'] = 0 # figsize = (16, 10) figsize = (save_cfg['text_width'], save_cfg['text_height'] / 2) ax = gdf.plot(column='Count', figsize=figsize, cmap='Blues', scheme='Fisher_Jenks', k=10, legend=True, edgecolor='k', linewidth=0.3, categorical=False, vmin=0, legend_kwds={'loc': 'lower left', 'title': 'Number of studies', 'framealpha': 1}, rasterized=False) # Remove floating points in legend leg = ax.get_legend() for t in leg.get_texts(): t.set_text(t.get_text().replace('.00', '')) ax.set_axis_off() fig = ax.get_figure() plt.tight_layout() if save_cfg is not None: fname = os.path.join(save_cfg['savepath'], 'countrymap') save_cfg2 = save_cfg.copy() save_cfg2['dpi'] = 1000 save_cfg2['format'] = 'png' fig.savefig(fname + '.png', **save_cfg2) if postprocess: subprocess.call( ['inkscape', fname + '.png', '--export-eps=' + fname + '.eps']) return ax def 
compute_prct_statistical_tests(df): """Compute the number of studies that used statistical tests. """ prct = 100 - 100 * df['Statistical analysis of performance'].value_counts( )['No'] / df.shape[0] logger.info('% of studies that used statistical test: {}'.format(prct)) def make_domain_table(df, save_cfg=cfg.saving_config): """Make domain table that contains every reference. """ # Replace NaNs by ' ' in 'Domain 3' and 'Domain 4' columns df = ut.replace_nans_in_column(df, 'Domain 3', replace_by=' ') df = ut.replace_nans_in_column(df, 'Domain 4', replace_by=' ') cols = ['Domain 1', 'Domain 2', 'Domain 3', 'Domain 4', 'Architecture (clean)'] df[cols] = df[cols].applymap(ut.tex_escape) # Make tuple of first 2 domain levels domains_df = df.groupby(cols)['Citation'].apply(list).apply( lambda x: '\cite{' + ', '.join(x) + '}').unstack() domains_df = domains_df.applymap( lambda x: ' ' if isinstance(x, float) and np.isnan(x) else x) fname = os.path.join(save_cfg['table_savepath'], 'domains_architecture_table.tex') with open(fname, 'w') as f: with pd.option_context("max_colwidth", 1000): f.write(domains_df.to_latex( escape=False, column_format='p{1.5cm}' * 4 + 'p{0.6cm}' * domains_df.shape[1])) def plot_preprocessing_proportions(df, save_cfg=cfg.saving_config): """Plot proportions for preprocessing-related data items. """ data = dict() data['(a) Preprocessing of EEG data'] = df[ 'Preprocessing (clean)'].value_counts().to_dict() data['(b) Artifact handling'] = df[ 'Artefact handling (clean)'].value_counts().to_dict() data['(c) Extracted features'] = df[ 'Features (clean)'].value_counts().to_dict() fig, ax = ut.plot_multiple_proportions( data, print_count=5, respect_order=['Yes', 'No', 'Other', 'N/M'], figsize=(save_cfg['text_width'] / 4 * 4, save_cfg['text_height'] / 7 * 2)) if save_cfg is not None: fname = os.path.join(save_cfg['savepath'], 'preprocessing') fig.savefig(fname + '.' 
# NOTE(review): the next three lines are the tail of a plotting function whose
# `def` lies before this chunk; it finishes saving the figure and returns its
# axes, following the same save pattern as every other function below.
            + save_cfg['format'], **save_cfg)

    return ax


def plot_hyperparams_proportions(df, save_cfg=cfg.saving_config):
    """Plot proportions for hyperparameter-related data items.

    Builds one proportion panel each for the cleaned "Training procedure",
    "Regularization" and "Optimizer" columns and delegates the drawing to
    ``ut.plot_multiple_proportions``.

    Args:
        df (pd.DataFrame): data items table.

    Keyword Args:
        save_cfg (dict): saving configuration; if None the figure is not saved.

    Returns:
        matplotlib Axes as returned by ``ut.plot_multiple_proportions``.
    """
    data = dict()
    # Panel labels double as subplot titles; counts come straight from the
    # cleaned data-item columns.
    data['(a) Training procedure'] = df[
        'Training procedure (clean)'].value_counts().to_dict()
    data['(b) Regularization'] = df[
        'Regularization (clean)'].value_counts().to_dict()
    data['(c) Optimizer'] = df[
        'Optimizer (clean)'].value_counts().to_dict()

    fig, ax = ut.plot_multiple_proportions(
        data, print_count=5, respect_order=['Yes', 'No', 'Other', 'N/M'],
        figsize=(save_cfg['text_width'] / 4 * 4,
                 save_cfg['text_height'] / 7 * 2))

    if save_cfg is not None:
        fname = os.path.join(save_cfg['savepath'], 'hyperparams')
        fig.savefig(fname + '.' + save_cfg['format'], **save_cfg)

    return ax


def plot_reproducibility_proportions(df, save_cfg=cfg.saving_config):
    """Plot proportions for reproducibility-related data items.

    Derives a coarse "reproducibility" level per paper from code and dataset
    availability, logs several reproducibility statistics, and plots four
    proportion panels.

    Args:
        df (pd.DataFrame): data items table.

    Keyword Args:
        save_cfg (dict): saving configuration; if None the figure is not saved.

    Returns:
        matplotlib Axes as returned by ``ut.plot_multiple_proportions``.

    NOTE(review): mutates ``df`` in place (fills NaNs in three columns and
    adds a 'reproducibility' column) — callers see these changes.
    """
    # Normalize missing values to the 'N/M' (not mentioned) category.
    df['Code hosted on'] = df['Code hosted on'].replace(np.nan, 'N/M', regex=True)
    df['Limited data'] = df['Limited data'].replace(np.nan, 'N/M', regex=True)
    df['Code available'] = df['Code available'].replace(np.nan, 'N/M', regex=True)

    data = dict()
    data['(a) Dataset availability'] = df[
        'Dataset accessibility'].value_counts().to_dict()
    data['(b) Code availability'] = df[
        'Code hosted on'].value_counts().to_dict()
    data['(c) Type of baseline'] = df[
        'Baseline model type'].value_counts().to_dict()

    # Reproducibility level: default 'Hard', upgraded/downgraded based on
    # the combination of code availability and dataset accessibility.
    df['reproducibility'] = 'Hard'
    df.loc[(df['Code available'] == 'Yes') &
           (df['Dataset accessibility'] == 'Public'),
           'reproducibility'] = 'Easy'
    df.loc[(df['Code available'] == 'Yes') &
           (df['Dataset accessibility'] == 'Both'),
           'reproducibility'] = 'Medium'
    df.loc[(df['Code available'] == 'No') &
           (df['Dataset accessibility'] == 'Private'),
           'reproducibility'] = 'Impossible'

    data['(d) Reproducibility'] = df[
        'reproducibility'].value_counts().to_dict()

    # Log the underlying counts so they can be quoted in the text.
    logger.info('Stats on reproducibility - Dataset Accessibility: {}'.format(
        data['(a) Dataset availability']))
    logger.info('Stats on reproducibility - Code Accessibility: {}'.format(
        df['Code available'].value_counts().to_dict()))
    logger.info('Stats on reproducibility - Code Hosted On: {}'.format(
        data['(b) Code availability']))
    logger.info('Stats on reproducibility - Baseline: {}'.format(
        data['(c) Type of baseline']))
    logger.info('Stats on reproducibility - Reproducibility Level: {}'.format(
        data['(d) Reproducibility']))
    logger.info('Stats on reproducibility - Limited data: {}'.format(
        df['Limited data'].value_counts().to_dict()))
    logger.info('Stats on reproducibility - Shared their Code: {}'.format(
        df[df['Code available'] == 'Yes']['Citation'].to_dict()))

    fig, ax = ut.plot_multiple_proportions(
        data, print_count=5,
        respect_order=['Easy', 'Medium', 'Hard', 'Impossible'],
        figsize=(save_cfg['text_width'] / 4 * 4,
                 save_cfg['text_height'] * 0.4))

    if save_cfg is not None:
        fname = os.path.join(save_cfg['savepath'], 'reproducibility')
        fig.savefig(fname + '.' + save_cfg['format'], **save_cfg)

    return ax


def plot_domains_per_year(df, save_cfg=cfg.saving_config):
    """Plot stacked bar graph of domains per year.

    Collapses the four domain columns into a single 'Main domain' (first
    match in ``main_domains``, else 'Others') and stacks paper counts per
    publication year.

    Args:
        df (pd.DataFrame): data items table.

    Keyword Args:
        save_cfg (dict): saving configuration; if None the figure is not saved.

    Returns:
        matplotlib Axes containing the stacked bar plot.

    NOTE(review): mutates ``df`` in place (casts 'Year', adds 'Main domain').
    """
    fig, ax = plt.subplots(
        figsize=(save_cfg['text_width'] / 4 * 2, save_cfg['text_height'] / 4))

    df['Year'] = df['Year'].astype('int32')
    main_domains = ['Epilepsy', 'Sleep', 'BCI', 'Affective', 'Cognitive',
                    'Improvement of processing tools', 'Generation of data']
    domains_df = df[['Domain 1', 'Domain 2', 'Domain 3', 'Domain 4']]
    # Pick the first of the four domain columns that is a "main" domain;
    # papers with none of them are grouped under 'Others'.
    df['Main domain'] = [row[row.isin(main_domains)].values[0]
                         if any(row.isin(main_domains)) else 'Others'
                         for ind, row in domains_df.iterrows()]

    df.groupby(['Year', 'Main domain']).size().unstack('Main domain').plot(
        kind='bar', stacked=True, title='', ax=ax)
    ax.set_ylabel('Number of papers')
    ax.set_xlabel('')

    # Wrap long legend entries so the legend stays compact.
    legend = plt.legend()
    for l in legend.get_texts():
        l.set_text(ut.wrap_text(l.get_text(), max_char=14))

    if save_cfg is not None:
        fname = os.path.join(save_cfg['savepath'], 'domains_per_year')
        fig.savefig(fname + '.' + save_cfg['format'], **save_cfg)

    return ax


def plot_hardware(df, save_cfg=cfg.saving_config):
    """Plot bar graph showing the hardware used in the study.

    One bar per EEG device (papers may list several, comma-separated), with
    hue distinguishing consumer-grade ("low-cost") devices.

    Args:
        df (pd.DataFrame): data items table.

    Keyword Args:
        save_cfg (dict): saving configuration; if None the figure is not saved.

    Returns:
        matplotlib Axes containing the count plot.
    """
    col = 'EEG Hardware'
    hardware_df = ut.split_column_with_multiple_entries(
        df, col, ref_col='Citation', sep=',', lower=False)

    # Remove N/Ms because they make it hard to see anything
    hardware_df = hardware_df[hardware_df[col] != 'N/M']

    # Add low cost column
    hardware_df['Low-cost'] = False
    low_cost_devices = ['EPOC (Emotiv)', 'OpenBCI (OpenBCI)',
                        'Muse (InteraXon)', 'Mindwave Mobile (Neurosky)',
                        'Mindset (NeuroSky)']
    hardware_df.loc[hardware_df[col].isin(low_cost_devices),
                    'Low-cost'] = True

    fig, ax = plt.subplots(figsize=(save_cfg['text_width'] / 4 * 2,
                                    save_cfg['text_height'] / 5 * 2))

    # dodge=False keeps a single bar per device, colored by Low-cost.
    sns.countplot(hue=hardware_df['Low-cost'], y=hardware_df[col], ax=ax,
                  order=hardware_df[col].value_counts().index, dodge=False)
    # sns.catplot(row=hardware_df['low_cost'], y=hardware_df['hardware'])

    ax.set_xlabel('Number of papers')
    ax.set_ylabel('')
    plt.tight_layout()

    if save_cfg is not None:
        fname = os.path.join(save_cfg['savepath'], 'hardware')
        fig.savefig(fname + '.' + save_cfg['format'], **save_cfg)

    return ax


def plot_architectures(df, save_cfg=cfg.saving_config):
    """Plot bar graph showing the architectures used in the study.

    Despite the name, this renders a donut (pie with inner hole) chart of
    the cleaned architecture categories.

    Args:
        df (pd.DataFrame): data items table.

    Keyword Args:
        save_cfg (dict): saving configuration; if None the figure is not saved.

    Returns:
        matplotlib Axes containing the pie chart.
    """
    fig, ax = plt.subplots(figsize=(save_cfg['text_width'] / 3,
                                    save_cfg['text_width'] / 3))
    colors = sns.color_palette()
    counts = df['Architecture (clean)'].value_counts()
    # width=0.3 turns the pie into a donut; pctdistance places the percent
    # labels inside the ring.
    _, _, pct = ax.pie(counts.values, labels=counts.index, autopct='%1.1f%%',
                       wedgeprops=dict(width=0.3, edgecolor='w'),
                       colors=colors, pctdistance=0.55)
    for i in pct:
        i.set_fontsize(5)

    # Equal aspect ratio keeps the pie circular.
    ax.axis('equal')
    plt.tight_layout()

    if save_cfg is not None:
        fname = os.path.join(save_cfg['savepath'], 'architectures')
        fig.savefig(fname + '.' + save_cfg['format'], **save_cfg)

    return ax


def plot_architectures_per_year(df, save_cfg=cfg.saving_config):
    """Plot stacked bar graph of architectures per year.

    Args:
        df (pd.DataFrame): data items table.

    Keyword Args:
        save_cfg (dict): saving configuration; if None the figure is not saved.

    Returns:
        matplotlib Axes containing the stacked bar plot.

    NOTE(review): mutates ``df`` in place (casts 'Year', adds 'Arch').
    """
    fig, ax = plt.subplots(
        figsize=(save_cfg['text_width'] / 3 * 2, save_cfg['text_width'] / 3))
    colors = sns.color_palette()

    df['Year'] = df['Year'].astype('int32')
    col_name = 'Architecture (clean)'
    df['Arch'] = df[col_name]
    # Order the stacked segments by overall frequency so colors match the
    # donut chart in plot_architectures.
    order = df[col_name].value_counts().index
    counts = df.groupby(['Year', 'Arch']).size().unstack('Arch')
    counts = counts[order]

    counts.plot(kind='bar', stacked=True, title='', ax=ax, color=colors)
    ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
    ax.set_ylabel('Number of papers')
    ax.set_xlabel('')

    plt.tight_layout()

    if save_cfg is not None:
        fname = os.path.join(save_cfg['savepath'], 'architectures_per_year')
        fig.savefig(fname + '.' + save_cfg['format'], **save_cfg)

    return ax


def plot_architectures_vs_input(df, save_cfg=cfg.saving_config):
    """Plot stacked bar graph of architectures vs input type.

    Bars are architectures (ordered by frequency); stacked segments are the
    cleaned feature/input categories. Also writes a PNG copy next to the
    main output.

    Args:
        df (pd.DataFrame): data items table.

    Keyword Args:
        save_cfg (dict): saving configuration; if None the figure is not saved.

    Returns:
        matplotlib Axes containing the stacked bar plot.

    NOTE(review): mutates ``df`` in place (adds 'Input' and 'Arch' columns).
    """
    fig, ax = plt.subplots(
        figsize=(save_cfg['text_width'] / 4 * 2, save_cfg['text_width'] / 3))

    df['Input'] = df['Features (clean)']
    col_name = 'Architecture (clean)'
    df['Arch'] = df[col_name]
    order = df[col_name].value_counts().index
    counts = df.groupby(['Input', 'Arch']).size().unstack('Input')
    counts = counts.loc[order, :]

    # To reduce the height of the figure, wrap long xticklabels
    counts = counts.rename({'CNN+RNN': 'CNN+\nRNN'}, axis='index')

    counts.plot(kind='bar', stacked=True, title='', ax=ax)
    # ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
    ax.set_ylabel('Number of papers')
    ax.set_xlabel('')
    plt.tight_layout()

    if save_cfg is not None:
        fname = os.path.join(save_cfg['savepath'], 'architectures_vs_input')
        fig.savefig(fname + '.' + save_cfg['format'], **save_cfg)

        # Also save a PNG version (e.g., for web/preview use); copy the
        # config so the caller's save_cfg is not modified.
        save_cfg2 = save_cfg.copy()
        save_cfg2['format'] = 'png'
        fig.savefig(fname + '.png', **save_cfg2)

    return ax


def plot_optimizers_per_year(df, save_cfg=cfg.saving_config):
    """Plot stacked bar graph of optimizers per year.

    Args:
        df (pd.DataFrame): data items table.

    Keyword Args:
        save_cfg (dict): saving configuration; if None the figure is not saved.

    Returns:
        matplotlib Axes containing the stacked bar plot.

    NOTE(review): the 'Input' assignment below appears unused in this
    function (likely copied from plot_architectures_vs_input) — confirm
    before removing. Also mutates ``df`` in place.
    """
    fig, ax = plt.subplots(
        figsize=(save_cfg['text_width'] / 4 * 2,
                 save_cfg['text_width'] / 5 * 2))

    df['Input'] = df['Features (clean)']
    col_name = 'Optimizer (clean)'
    df['Opt'] = df[col_name]
    order = df[col_name].value_counts().index
    counts = df.groupby(['Year', 'Opt']).size().unstack('Opt')
    counts = counts[order]

    counts.plot(kind='bar', stacked=True, title='', ax=ax)
    ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
    ax.set_ylabel('Number of papers')
    ax.set_xlabel('')
    plt.tight_layout()

    if save_cfg is not None:
        fname = os.path.join(save_cfg['savepath'], 'optimizers_per_year')
        fig.savefig(fname + '.' + save_cfg['format'], **save_cfg)

    return ax


def plot_intra_inter_per_year(df, save_cfg=cfg.saving_config):
    """Plot stacked bar graph of intra-/intersubject studies per year.

    Also logs the intra/inter breakdown as percentages of all papers.

    Args:
        df (pd.DataFrame): data items table.

    Keyword Args:
        save_cfg (dict): saving configuration; if None the figure is not saved.

    Returns:
        matplotlib Axes containing the stacked bar plot.

    NOTE(review): mutates ``df`` in place (casts 'Year' to int).
    """
    fig, ax = plt.subplots(
        figsize=(save_cfg['text_width'] / 4 * 2, save_cfg['text_height'] / 4))

    df['Year'] = df['Year'].astype(int)
    col_name = 'Intra/Inter subject'
    order = df[col_name].value_counts().index
    counts = df.groupby(['Year', col_name]).size().unstack(col_name)
    counts = counts[order]

    logger.info('Stats on inter/intra subjects: {}'.format(
        df[col_name].value_counts() / df.shape[0] * 100))

    counts.plot(kind='bar', stacked=True, title='', ax=ax)
    # ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
    ax.set_ylabel('Number of papers')
    ax.set_xlabel('')
    plt.tight_layout()

    if save_cfg is not None:
        fname = os.path.join(save_cfg['savepath'], 'intra_inter_per_year')
        fig.savefig(fname + '.' + save_cfg['format'], **save_cfg)

    return ax


def plot_number_layers(df, save_cfg=cfg.saving_config):
    """Plot histogram of number of layers.

    Bars are the cleaned layer-count categories (1..31 plus 'N/M'), ordered
    numerically via reindex; also writes a 300-dpi PNG copy.

    Args:
        df (pd.DataFrame): data items table.

    Keyword Args:
        save_cfg (dict): saving configuration; if None the figure is not saved.

    Returns:
        matplotlib Axes containing the bar plot.
    """
    fig, ax = plt.subplots(
        figsize=(save_cfg['text_width'] / 4 * 2, save_cfg['text_width'] / 3))

    # Reindex to a fixed numeric ordering ('1'..'31' then 'N/M'), then drop
    # the counts that do not occur.
    n_layers_df = df['Layers (clean)'].value_counts().reindex(
        [str(i) for i in range(1, 32)] + ['N/M'])
    n_layers_df = n_layers_df.dropna().astype(int)

    from matplotlib.colors import ListedColormap

    # Cycle the current seaborn palette across the bars.
    cmap = ListedColormap(sns.color_palette(None).as_hex())
    n_layers_df.plot(kind='bar', width=0.8, rot=0, colormap=cmap, ax=ax)

    ax.set_xlabel('Number of layers')
    ax.set_ylabel('Number of papers')

    plt.tight_layout()

    if save_cfg is not None:
        fname = os.path.join(save_cfg['savepath'], 'number_layers')
        fig.savefig(fname + '.' + save_cfg['format'], **save_cfg)

        save_cfg2 = save_cfg.copy()
        save_cfg2['format'] = 'png'
        save_cfg2['dpi'] = 300
        fig.savefig(fname + '.png', **save_cfg2)

    return ax


def plot_number_subjects_by_domain(df, save_cfg=cfg.saving_config):
    """Plot number of subjects in studies by domain.

    Swarm plot on a log x-axis, one row per main domain, ordered by median
    number of subjects. Invalid entries ('n/m', 'tbd', non-positive) are
    dropped. Also logs descriptive statistics.

    Args:
        df (pd.DataFrame): data items table.

    Keyword Args:
        save_cfg (dict): saving configuration; if None the figure is not saved.

    Returns:
        matplotlib Axes containing the swarm plot.
    """
    # Split values into separate rows and remove invalid values
    col = 'Data - subjects'
    nb_subj_df = ut.split_column_with_multiple_entries(
        df, col, ref_col='Main domain')
    nb_subj_df = nb_subj_df.loc[~nb_subj_df[col].isin(['n/m', 'tbd'])]
    nb_subj_df[col] = nb_subj_df[col].astype(int)
    nb_subj_df = nb_subj_df.loc[nb_subj_df[col] > 0, :]

    nb_subj_df['Main domain'] = nb_subj_df['Main domain'].apply(
        ut.wrap_text, max_char=13)

    fig, ax = plt.subplots(
        figsize=(save_cfg['text_width'] / 3 * 2, save_cfg['text_height'] / 3))
    ax.set(xscale='log', yscale='linear')
    sns.swarmplot(
        y='Main domain', x=col, data=nb_subj_df, ax=ax, size=3,
        order=nb_subj_df.groupby(['Main domain'])[
            col].median().sort_values().index)
    ax.set_xlabel('Number of subjects')
    ax.set_ylabel('')

    logger.info('Stats on number of subjects per model: {}'.format(
        nb_subj_df[col].describe()))

    plt.tight_layout()

    if save_cfg is not None:
        fname = os.path.join(save_cfg['savepath'], 'nb_subject_per_domain')
        fig.savefig(fname + '.' + save_cfg['format'], **save_cfg)

    return ax


def plot_number_channels(df, save_cfg=cfg.saving_config):
    """Plot histogram of number of channels.

    Also logs descriptive statistics of the channel counts.

    Args:
        df (pd.DataFrame): data items table.

    Keyword Args:
        save_cfg (dict): saving configuration; if None the figure is not saved.

    Returns:
        matplotlib Axes containing the histogram.
    """
    nb_channels_df = ut.split_column_with_multiple_entries(
        df, 'Nb Channels', ref_col='Citation', sep=';\n', lower=False)
    nb_channels_df['Nb Channels'] = nb_channels_df['Nb Channels'].astype(int)
    # Non-positive counts are invalid entries; drop them.
    nb_channels_df = nb_channels_df.loc[nb_channels_df['Nb Channels'] > 0, :]

    fig, ax = plt.subplots(
        figsize=(save_cfg['text_width'] / 2, save_cfg['text_height'] / 4))
    sns.distplot(nb_channels_df['Nb Channels'], kde=False, norm_hist=False,
                 ax=ax)
    ax.set_xlabel('Number of EEG channels')
    ax.set_ylabel('Number of papers')

    logger.info('Stats on number of channels per model: {}'.format(
        nb_channels_df['Nb Channels'].describe()))

    plt.tight_layout()

    if save_cfg is not None:
        fname = os.path.join(save_cfg['savepath'], 'nb_channels')
        fig.savefig(fname + '.' + save_cfg['format'], **save_cfg)

    return ax


def compute_stats_sampling_rate(df):
    """Compute the statistics for hardware sampling rate.

    Logs descriptive statistics only; draws nothing and returns None.

    Args:
        df (pd.DataFrame): data items table.
    """
    fs_df = ut.split_column_with_multiple_entries(
        df, 'Sampling rate', ref_col='Citation', sep=';\n', lower=False)
    fs_df['Sampling rate'] = fs_df['Sampling rate'].astype(float)
    fs_df = fs_df.loc[fs_df['Sampling rate'] > 0, :]

    logger.info('Stats on sampling rate per model: {}'.format(
        fs_df['Sampling rate'].describe()))


def plot_cross_validation(df, save_cfg=cfg.saving_config):
    """Plot bar graph of cross validation approaches.

    Args:
        df (pd.DataFrame): data items table.

    Keyword Args:
        save_cfg (dict): saving configuration; if None the figure is not saved.

    Returns:
        matplotlib Axes containing the count plot.

    NOTE(review): mutates ``df`` in place (fills NaNs with 'N/M').
    """
    col = 'Cross validation (clean)'
    df[col] = df[col].fillna('N/M')
    cv_df = ut.split_column_with_multiple_entries(
        df, col, ref_col='Citation', sep=';\n', lower=False)

    fig, ax = plt.subplots(
        figsize=(save_cfg['text_width'] / 2, save_cfg['text_height'] / 5))
    sns.countplot(y=cv_df[col], order=cv_df[col].value_counts().index, ax=ax)
    ax.set_xlabel('Number of papers')
    ax.set_ylabel('')

    plt.tight_layout()

    if save_cfg is not None:
        fname = os.path.join(save_cfg['savepath'], 'cross_validation')
        fig.savefig(fname + '.' + save_cfg['format'], **save_cfg)

    return ax


def make_dataset_table(df, min_n_articles=2, save_cfg=cfg.saving_config):
    """Make table that reports most used datasets.

    Writes a LaTeX table (dataset_table.tex) listing, per main domain, each
    dataset with its article count and \\cite references.

    Args:
        df

    Keyword Args:
        min_n_articles (int): minimum number of times a dataset must have been
            used to be listed in the table. If under that number, will appear
            as 'Other' in the table.
        save_cfg (dict)
    """
    def merge_dataset_names(s):
        # Collapse spelling variants of the same dataset family into one name.
        if 'bci comp' in s.lower():
            s = 'BCI Competition'
        elif 'tuh' in s.lower():
            s = 'TUH'
        elif 'mahnob' in s.lower():
            s = 'MAHNOB'
        return s

    col = 'Dataset name'
    datasets_df = ut.split_column_with_multiple_entries(
        df, col, ref_col=['Main domain', 'Citation'], sep=';\n', lower=False)

    # Remove not mentioned and internal recordings, as readers won't be able to
    # use these datasets anyway
    datasets_df = datasets_df.loc[~datasets_df[col].isin(
        ['N/M', 'Internal Recordings', 'TBD'])]
    datasets_df['Dataset'] = datasets_df[col].apply(
        merge_dataset_names).apply(ut.tex_escape)

    # Replace datasets that were used rarely by 'Other'
    counts = datasets_df['Dataset'].value_counts()
    datasets_df.loc[datasets_df['Dataset'].isin(
        counts[counts < min_n_articles].index), 'Dataset'] = 'Other'

    # Remove duplicates (due to grouping of Others and BCI Comp)
    datasets_df = datasets_df.drop(labels=col, axis=1)
    datasets_df = datasets_df.drop_duplicates()

    # Group by dataset and order by number of articles
    dataset_table = datasets_df.groupby(
        ['Main domain', 'Dataset'], as_index=True)['Citation'].apply(list)
    dataset_table = pd.concat(
        [dataset_table.apply(len), dataset_table], axis=1)
    dataset_table.columns = [r'\# articles', 'References']

    dataset_table = dataset_table.sort_values(
        by=['Main domain', r'\# articles'], ascending=[True, False])
    dataset_table['References'] = dataset_table['References'].apply(
        lambda x: r'\cite{' + ', '.join(x) + '}')

    with open(os.path.join(save_cfg['table_savepath'],
                           'dataset_table.tex'), 'w') as f:
        # Raise max_colwidth so long \cite lists are not truncated by pandas.
        with pd.option_context("max_colwidth", 1000):
            f.write(dataset_table.to_latex(escape=False, multicolumn=False))


def plot_data_quantity(df, save_cfg=cfg.saving_config):
    """Plot the quantity of data used by domain.

    Three swarm-plot panels on log x-axes: recording time (min), number of
    examples, and their ratio (examples/min), one row per main domain.

    Args:
        df (pd.DataFrame): data items table.

    Keyword Args:
        save_cfg (dict): saving configuration; if None the figure is not saved.

    Returns:
        array of the three matplotlib Axes.
    """
    data_df = ut.split_column_with_multiple_entries(
        df, ['Data - samples', 'Data - time'],
        ref_col=['Citation', 'Main domain'], sep=';\n', lower=False)

    # Remove N/M and TBD
    col = 'Data - samples'
    data_df.loc[data_df[col].isin(['N/M', 'TBD', '[TBD]']), col] = np.nan
    data_df[col] = data_df[col].astype(float)

    col2 = 'Data - time'
    data_df.loc[data_df[col2].isin(['N/M', 'TBD', '[TBD]']), col2] = np.nan
    data_df[col2] = data_df[col2].astype(float)

    # Wrap main domain text
    data_df['Main domain'] = data_df['Main domain'].apply(
        ut.wrap_text, max_char=13)

    # Extract ratio
    data_df['data_ratio'] = data_df['Data - samples'] / data_df['Data - time']
    data_df = data_df.sort_values(['Main domain', 'data_ratio'])

    # Plot
    fig, axes = plt.subplots(
        ncols=3,
        figsize=(save_cfg['text_width'], save_cfg['text_height'] / 3))

    # Panel (0): recording time, with decade tick marks up to the data max.
    axes[0].set(xscale='log', yscale='linear')
    sns.swarmplot(y='Main domain', x=col2, data=data_df, ax=axes[0], size=3)
    axes[0].set_xlabel('Recording time (min)')
    axes[0].set_ylabel('')
    max_val = int(np.ceil(np.log10(data_df[col2].max())))
    axes[0].set_xticks(np.power(10, range(0, max_val + 1)))

    # Panel (1): number of examples; y tick labels suppressed (shared rows).
    axes[1].set(xscale='log', yscale='linear')
    sns.swarmplot(y='Main domain', x=col, data=data_df, ax=axes[1], size=3)
    axes[1].set_xlabel('Number of examples')
    axes[1].set_yticklabels('')
    axes[1].set_ylabel('')
    min_val = int(np.floor(np.log10(data_df[col].min())))
    max_val = int(np.ceil(np.log10(data_df[col].max())))
    axes[1].set_xticks(np.power(10, range(min_val, max_val + 1)))

    # Panel (2): examples-per-minute ratio.
    axes[2].set(xscale='log', yscale='linear')
    sns.swarmplot(y='Main domain', x='data_ratio', data=data_df, ax=axes[2],
                  size=3)
    axes[2].set_xlabel('Ratio (examples/min)')
    axes[2].set_ylabel('')
    axes[2].set_yticklabels('')
    min_val = int(np.floor(np.log10(data_df['data_ratio'].min())))
    max_val = int(np.ceil(np.log10(data_df['data_ratio'].max())))
    axes[2].set_xticks(np.power(10, np.arange(min_val, max_val + 1,
                                              dtype=float)))

    plt.tight_layout()

    if save_cfg is not None:
        fname = os.path.join(save_cfg['savepath'], 'data_quantity')
        fig.savefig(fname + '.' + save_cfg['format'], **save_cfg)

    return axes


def plot_eeg_intro(save_cfg=cfg.saving_config):
    """Plot a figure that shows basic EEG notions such as epochs and samples.

    Loads a 4-s, 4-channel snippet of real EEG, stacks the channels
    vertically, draws two overlapping 1-s windows as dashed rectangles, and
    annotates the concepts of window/epoch/trial, network input, sample, and
    window overlap.

    Keyword Args:
        save_cfg (dict): saving configuration; if None the figure is not saved.

    Returns:
        matplotlib Axes containing the annotated EEG snippet.
    """
    # Visualization parameters
    win_len = 1  # in s
    step = 0.5  # in s
    first_epoch = 1

    data, t, fs = ut.get_real_eeg_data(start=30, stop=34,
                                       chans=[0, 10, 20, 30])
    t = t - t[0]  # re-zero the time axis

    # Offset data for visualization
    data -= data.mean(axis=0)
    max_std = np.max(data.std(axis=0))
    # Stack channels top-to-bottom, 4 standard deviations apart.
    offsets = np.arange(data.shape[1])[::-1] * 4 * max_std
    data += offsets

    rect_y_border = 0.6 * max_std
    min_y = data.min() - rect_y_border
    max_y = data.max() + rect_y_border

    # Make figure
    fig, ax = plt.subplots(
        figsize=(save_cfg['text_width'] / 4 * 3, save_cfg['text_height'] / 3))
    ax.plot(t, data)
    ax.set_xlabel('Time (s)')
    ax.set_ylabel(r'Amplitude (e.g., $\mu$V)')
    ax.set_yticks(offsets)
    ax.set_yticklabels(['channel {}'.format(i + 1)
                        for i in range(data.shape[1])])
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    # Display epochs as dashed line rectangles
    rect1 = patches.Rectangle((first_epoch, min_y + rect_y_border / 4),
                              win_len, max_y - min_y,
                              linewidth=1, linestyle='--', edgecolor='k',
                              facecolor='none')
    rect2 = patches.Rectangle((first_epoch + step, min_y - rect_y_border / 4),
                              win_len, max_y - min_y,
                              linewidth=1, linestyle='--', edgecolor='k',
                              facecolor='none')
    ax.add_patch(rect1)
    ax.add_patch(rect2)

    # Annotate epochs
    ax.annotate(
        r'$\bf{Window}$ or $\bf{epoch}$ or $\bf{trial}$' +
        '\n({:.0f} points in a \n1-s window at {:.0f} Hz)'.format(fs, fs),
        #fontsize=14,
        xy=(first_epoch, min_y),
        arrowprops=dict(facecolor='black', shrink=0.05, width=2, headwidth=6),
        xytext=(0, min_y - 3.5 * max_std), xycoords='data', ha='center',
        va='top')

    # Annotate input
    ax.annotate(r'Neural network input' + '\n'
                r'$X_i \in \mathbb{R}^{c \times l}$',
                #fontsize=14,
                xy=(first_epoch+1.5, min_y),
                arrowprops=dict(facecolor='black', shrink=0.05, width=2),
                xytext=(4, min_y - 5.3 * max_std), xycoords='data',
                ha='right', va='bottom')

    # Annotate sample
    special_ind = np.where((t >= 2.4) & (t < 2.5))[0][0]
    special_point = data[special_ind, 0]
    ax.plot(t[special_ind], special_point, '.', c='k')
    ax.annotate(
        r'$\bf{Point}$ or $\bf{sample}$',
        #fontsize=14,
        xy=(t[special_ind], special_point),
        arrowprops=dict(facecolor='black', shrink=0.05, width=2, headwidth=6),
        xytext=(3, max_y), xycoords='data', ha='left', va='bottom')

    # Annotate overlap
    ut.draw_brace(ax, (first_epoch + step, first_epoch + step * 2),
                  r'0.5-s $\bf{overlap}$' + '\nbetween windows',
                  beta_factor=300, y_offset=max_y)

    plt.tight_layout()

    if save_cfg is not None:
        fname = os.path.join(save_cfg['savepath'], 'eeg_intro')
        fig.savefig(fname + '.' + save_cfg['format'], **save_cfg)

    return ax