import os
import re
import zipfile
import revtok
import torch
import io
import csv
import json
import glob
import hashlib
import unicodedata


from . import sst
from . import imdb
from . import snli
from . import translation

from .. import data


# Literal marker strings placed in front of the context and of the question.
CONTEXT_SPECIAL = 'Context:'
QUESTION_SPECIAL = 'Question:'


def get_context_question(context, question):
    """Return the context and question joined into one marked-up string.

    The result is ``CONTEXT_SPECIAL``, ``context``, ``QUESTION_SPECIAL`` and
    ``question`` separated by single spaces. Neither argument is stripped
    or otherwise normalised.
    """
    pieces = (CONTEXT_SPECIAL, context, QUESTION_SPECIAL, question)
    return ' '.join(pieces)


class CQA(data.Dataset):
    """Base dataset whose examples carry context/question/answer columns."""

    # Column names; subclasses pair each one with the shared field object and
    # build examples from value lists given in this same order.
    fields = [
        'context',
        'question',
        'answer',
        'context_special',
        'question_special',
        'context_question',
    ]

    @staticmethod
    def sort_key(ex):
        """Return a sort key that interleaves context and answer lengths."""
        context_len = len(ex.context)
        answer_len = len(ex.answer)
        return data.interleave_keys(context_len, answer_len)
 

class IMDb(CQA, imdb.IMDb):
    """IMDb reviews posed as 'Is this review negative or positive?'.

    Every ``*.txt`` file under ``<path>/pos`` and ``<path>/neg`` contributes
    one example: the context is the first line of the file and the answer is
    ``'positive'`` or ``'negative'`` according to the directory.
    """

    @staticmethod
    def sort_key(ex):
        """Return a sort key that interleaves context and answer lengths."""
        return data.interleave_keys(len(ex.context), len(ex.answer))

    def __init__(self, path, field, subsample=None, **kwargs):
        """Load (or build and cache) the examples stored under ``path``.

        Arguments:
            path: Directory containing the ``pos`` and ``neg`` sub-directories.
            field: Field object applied to every column in ``self.fields``.
            subsample: If not None, stop after this many examples. Files are
                read label by label ('pos' first), so a small subsample only
                contains positive reviews.
            Remaining keyword arguments: Passed to the dataset constructor.
        """
        fields = [(x, field) for x in self.fields]
        labels = {'neg': 'negative', 'pos': 'positive'}
        question = 'Is this review negative or positive?'

        cache_name = os.path.join(os.path.dirname(path), '.cache', os.path.basename(path), str(subsample))
        if os.path.exists(cache_name):
            print(f'Loading cached data from {cache_name}')
            # NOTE(review): torch.load unpickles the file; only use caches
            # that this code wrote itself.
            examples = torch.load(cache_name)
        else:
            examples = []
            for label in ['pos', 'neg']:
                answer = labels[label]
                for fname in glob.iglob(os.path.join(path, label, '*.txt')):
                    # Explicit encoding so the result does not depend on the
                    # platform default (the SST loader in this file also
                    # reads its data as utf8).
                    with open(fname, 'r', encoding='utf-8') as f:
                        context = f.readline()
                    context_question = get_context_question(context, question)
                    examples.append(data.Example.fromlist(
                        [context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question],
                        fields))
                    if subsample is not None and len(examples) >= subsample:
                        break
                # The inner break only leaves the file loop; without this
                # second check the next label would keep adding examples.
                if subsample is not None and len(examples) >= subsample:
                    break
            os.makedirs(os.path.dirname(cache_name), exist_ok=True)
            print(f'Caching data to {cache_name}')
            torch.save(examples, cache_name)
        # Start the MRO lookup after imdb.IMDb so that its own __init__
        # (which would re-read the raw files) is skipped.
        super(imdb.IMDb, self).__init__(examples, fields, **kwargs)

    @classmethod
    def splits(cls, fields, root='.data',
               train='train', validation=None, test='test', **kwargs):
        """Download the data if needed and return the requested splits.

        Arguments:
            fields: Field object handed to the constructor as ``field``.
            root: Directory passed to ``cls.download``.
            train: Sub-directory of the training data, or None to skip it.
            validation: Must be None; there is no validation split.
            test: Sub-directory of the test data, or None to skip it.
            Remaining keyword arguments: Passed to the constructor.

        Returns:
            A tuple with the non-None datasets in (train, test) order.
        """
        assert validation is None
        path = cls.download(root)
        train_data = None if train is None else cls(
            os.path.join(path, f'{train}'), fields, **kwargs)
        test_data = None if test is None else cls(
            os.path.join(path, f'{test}'), fields, **kwargs)
        return tuple(d for d in (train_data, test_data)
                     if d is not None)


class SST(CQA):
    """Binary sentiment sentences posed as 'Is this review negative or positive?'.

    Each CSV row supplies an integer label in its first column (0 selects
    'negative', 1 selects 'positive') and the sentence in its last column.
    """

    urls = ['https://raw.githubusercontent.com/openai/generating-reviews-discovering-sentiment/master/data/train_binary_sent.csv',
        'https://raw.githubusercontent.com/openai/generating-reviews-discovering-sentiment/master/data/dev_binary_sent.csv',
        'https://raw.githubusercontent.com/openai/generating-reviews-discovering-sentiment/master/data/test_binary_sent.csv']
    name = 'sst'
    dirname = ''

    @staticmethod
    def sort_key(ex):
        """Return a sort key that interleaves context and answer lengths."""
        return data.interleave_keys(len(ex.context), len(ex.answer))

    def __init__(self, path, field, subsample=None, **kwargs):
        """Load (or build and cache) the examples of one CSV file.

        Arguments:
            path: Path of the CSV file; its first line is skipped as a header.
            field: Field object applied to every column in ``self.fields``.
            subsample: If not None, stop after this many examples.
            Remaining keyword arguments: Passed to the dataset constructor.
        """
        fields = [(x, field) for x in self.fields]
        cache_name = os.path.join(os.path.dirname(path), '.cache', os.path.basename(path), str(subsample))

        examples = []
        if os.path.exists(cache_name):
            print(f'Loading cached data from {cache_name}')
            # NOTE(review): torch.load unpickles the file; only use caches
            # that this code wrote itself.
            examples = torch.load(cache_name)
        else:
            labels = ['negative', 'positive']
            question = 'Is this review ' + labels[0] + ' or ' + labels[1] + '?'

            with io.open(os.path.expanduser(path), encoding='utf8') as f:
                next(f)  # skip the header row
                for line in f:
                    # Each physical line is parsed as one CSV record.
                    parsed = list(csv.reader([line.rstrip('\n')]))[0]
                    context = parsed[-1]
                    answer = labels[int(parsed[0])]
                    context_question = get_context_question(context, question)
                    examples.append(data.Example.fromlist(
                        [context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question],
                        fields))

                    # '>=' so that exactly `subsample` examples are kept,
                    # like the translation/summarization loaders do.
                    if subsample is not None and len(examples) >= subsample:
                        break

            os.makedirs(os.path.dirname(cache_name), exist_ok=True)
            print(f'Caching data to {cache_name}')
            torch.save(examples, cache_name)

        self.examples = examples
        super().__init__(examples, fields, **kwargs)

    @classmethod
    def splits(cls, fields, root='.data',
               train='train', validation='dev', test='test', **kwargs):
        """Download the data if needed and return the requested splits.

        Arguments:
            fields: Field object handed to the constructor as ``field``.
            root: Directory passed to ``cls.download``.
            train, validation, test: File-name prefixes placed in front of
                ``_binary_sent.csv``; None skips that split.
            Remaining keyword arguments: Passed to the constructor.

        Returns:
            A tuple with the non-None datasets in (train, validation, test)
            order.
        """
        path = cls.download(root)
        postfix = '_binary_sent.csv'
        train_data = None if train is None else cls(
            os.path.join(path, f'{train}{postfix}'), fields, **kwargs)
        validation_data = None if validation is None else cls(
            os.path.join(path, f'{validation}{postfix}'), fields, **kwargs)
        test_data = None if test is None else cls(
            os.path.join(path, f'{test}{postfix}'), fields, **kwargs)
        return tuple(d for d in (train_data, validation_data, test_data)
                     if d is not None)


class TranslationDataset(translation.TranslationDataset):
    """Parallel-text dataset posed as 'Translate from <source> to <target>'.

    The source sentence becomes the context and the target sentence the
    answer. ``self.fields`` is not defined here; the concrete subclasses in
    this file obtain it by also inheriting from ``CQA``.
    """

    @staticmethod
    def sort_key(ex):
        """Return a sort key that interleaves context and answer lengths."""
        return data.interleave_keys(len(ex.context), len(ex.answer))

    def __init__(self, path, exts, field, subsample=None, **kwargs):
        """Create a TranslationDataset given paths and fields.

        Arguments:
            path: Common prefix of paths to the data files for both languages.
            exts: A tuple containing the extension to path for each language.
                Both must be one of '.de', '.en', '.fr', '.ar' or '.cs';
                anything else raises KeyError.
            field: Field object applied to every column in ``self.fields``.
            subsample: If not None, stop after this many examples.
            Remaining keyword arguments: Passed to the constructor of
                data.Dataset.
        """
        fields = [(x, field) for x in self.fields]
        cache_name = os.path.join(os.path.dirname(path), '.cache', os.path.basename(path), str(subsample))

        if os.path.exists(cache_name):
            print(f'Loading cached data from {cache_name}')
            # NOTE(review): torch.load unpickles the file; only use caches
            # that this code wrote itself.
            examples = torch.load(cache_name)
        else:
            langs = {'.de': 'German', '.en': 'English', '.fr': 'French', '.ar': 'Arabic', '.cs': 'Czech'}
            source, target = langs[exts[0]], langs[exts[1]]
            src_path, trg_path = tuple(os.path.expanduser(path + x) for x in exts)
            question = f'Translate from {source} to {target}'

            examples = []
            # The corpora contain non-ASCII text, so decode them as UTF-8
            # instead of relying on the platform default encoding.
            with open(src_path, encoding='utf-8') as src_file, \
                    open(trg_path, encoding='utf-8') as trg_file:
                for src_line, trg_line in zip(src_file, trg_file):
                    src_line, trg_line = src_line.strip(), trg_line.strip()
                    # Pairs with an empty side are dropped.
                    if src_line != '' and trg_line != '':
                        context = src_line
                        answer = trg_line
                        context_question = get_context_question(context, question)
                        examples.append(data.Example.fromlist(
                            [context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question],
                            fields))
                        if subsample is not None and len(examples) >= subsample:
                            break

            os.makedirs(os.path.dirname(cache_name), exist_ok=True)
            print(f'Caching data to {cache_name}')
            torch.save(examples, cache_name)
        # Start the MRO lookup after translation.TranslationDataset so that
        # its own __init__ (which would re-read the files) is skipped.
        super(translation.TranslationDataset, self).__init__(examples, fields, **kwargs)


class Multi30k(TranslationDataset, CQA, translation.Multi30k):
    """Multi30k wired to the question-answering style TranslationDataset.

    Adds nothing of its own: loading comes from ``TranslationDataset``, the
    column names from ``CQA`` and everything else from
    ``translation.Multi30k``.
    """


class IWSLT(TranslationDataset, CQA, translation.IWSLT):
    """IWSLT wired to the question-answering style TranslationDataset.

    Adds nothing of its own: loading comes from ``TranslationDataset``, the
    column names from ``CQA`` and everything else from ``translation.IWSLT``.
    """


class SQuAD(CQA, data.Dataset):
    """SQuAD examples with token-level answer spans.

    Besides the usual context/question/answer columns every example carries
    ``context_spans``, ``answer_start``, ``answer_end`` and ``squad_id``.
    Questions without answers get the answer ``'unanswerable'`` and spans
    of -1. ``all_answers[squad_id]`` lists every reference answer text and
    ``q_ids`` holds the original question ids in the same order.
    """

    @staticmethod
    def sort_key(ex):
        """Return a sort key that interleaves context and answer lengths."""
        return data.interleave_keys(len(ex.context), len(ex.answer))

    urls = ['https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json',
            'https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json',
            'https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json',
            'https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json',]
    name = 'squad'
    dirname = ''

    def __init__(self, path, field, subsample=None, **kwargs):
        """Load (or build and cache) the examples of one SQuAD JSON file.

        Arguments:
            path: Path of the SQuAD JSON file.
            field: Field object applied to every column in ``self.fields``;
                its ``eos_token`` / ``init_token`` are reused for the
                numerical span fields.
            subsample: If not None, stop after this many examples.
            Remaining keyword arguments: Passed to the dataset constructor.
        """
        fields = [(x, field) for x in self.fields]
        cache_name = os.path.join(os.path.dirname(path), '.cache', os.path.basename(path), str(subsample))

        examples, all_answers, q_ids = [], [], []
        if os.path.exists(cache_name):
            print(f'Loading cached data from {cache_name}')
            # NOTE(review): torch.load unpickles the file; only use caches
            # that this code wrote itself.
            examples, all_answers, q_ids = torch.load(cache_name)
        else:
            with open(os.path.expanduser(path)) as f:
                squad = json.load(f)['data']
                for document in squad:
                    paragraphs = document['paragraphs']
                    for paragraph in paragraphs:
                        context = paragraph['context']
                        qas = paragraph['qas']
                        for qa in qas:
                            question = ' '.join(qa['question'].split())
                            q_ids.append(qa['id'])
                            squad_id = len(all_answers)
                            context_question = get_context_question(context, question)
                            if len(qa['answers']) == 0:
                                answer = 'unanswerable'
                                all_answers.append(['unanswerable'])
                                # Use a separate name: rebinding `context`
                                # here would leak the normalised text into
                                # the following questions of this paragraph,
                                # whose character offsets refer to the raw
                                # context.
                                normalized_context = ' '.join(context.split())
                                ex = data.Example.fromlist(
                                    [normalized_context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question],
                                    fields)
                                ex.context_spans = [-1, -1]
                                ex.answer_start = -1
                                ex.answer_end = -1
                            else:
                                answer = qa['answers'][0]['text']
                                all_answers.append([a['text'] for a in qa['answers']])
                                # Character offsets of the first reference answer.
                                answer_start = qa['answers'][0]['answer_start']
                                answer_end = answer_start + len(answer)
                                context_before_answer = context[:answer_start]
                                context_after_answer = context[answer_end:]
                                # Sentinel words wrapped around the answer so
                                # its position can be found again after the
                                # field has tokenized the context.
                                BEGIN = 'beginanswer '
                                END = ' endanswer'

                                tagged_context = context_before_answer + BEGIN + answer + END + context_after_answer
                                ex = data.Example.fromlist(
                                    [tagged_context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question],
                                    fields)

                                tokenized_answer = ex.answer
                                # Find the sentinels, strip them from their
                                # tokens and turn the character offsets into
                                # token indices.
                                for xi, x in enumerate(ex.context):
                                    if BEGIN in x:
                                        answer_start = xi + 1
                                        ex.context[xi] = x.replace(BEGIN, '')
                                    if END in x:
                                        answer_end = xi
                                        ex.context[xi] = x.replace(END, '')
                                new_context = []
                                original_answer_start = answer_start
                                original_answer_end = answer_end
                                indexed_with_spaces = ex.context[answer_start:answer_end]
                                # NOTE(review): the pdb.set_trace() calls
                                # below are debugging leftovers that fire
                                # when a consistency check fails; consider
                                # raising an exception instead.
                                if len(indexed_with_spaces) != len(tokenized_answer):
                                    import pdb; pdb.set_trace()

                                # remove spaces
                                for xi, x in enumerate(ex.context):
                                    if len(x.strip()) == 0:
                                        # Dropping a blank token shifts every
                                        # later index down by one.
                                        if xi <= original_answer_start:
                                            answer_start -= 1
                                        if xi < original_answer_end:
                                            answer_end -= 1
                                    else:
                                        new_context.append(x)
                                ex.context = new_context
                                ex.answer = [x for x in ex.answer if len(x.strip()) > 0]
                                if len(ex.context[answer_start:answer_end]) != len(ex.answer):
                                    import pdb; pdb.set_trace()
                                ex.context_spans = list(range(answer_start, answer_end))
                                indexed_answer = ex.context[ex.context_spans[0]:ex.context_spans[-1]+1]
                                if len(indexed_answer) != len(ex.answer):
                                    import pdb; pdb.set_trace()
                                # With an eos token configured, one extra
                                # index just past the last context token is
                                # appended to the span.
                                if field.eos_token is not None:
                                    ex.context_spans += [len(ex.context)]
                                for context_idx, answer_word in zip(ex.context_spans, ex.answer):
                                    if context_idx == len(ex.context):
                                        continue
                                    if ex.context[context_idx] != answer_word:
                                        import pdb; pdb.set_trace()
                                ex.answer_start = ex.context_spans[0]
                                ex.answer_end = ex.context_spans[-1]
                            ex.squad_id = squad_id
                            examples.append(ex)
                            # '>=' so that exactly `subsample` examples are
                            # kept; the check is repeated to leave all three
                            # nested loops.
                            if subsample is not None and len(examples) >= subsample:
                                break
                        if subsample is not None and len(examples) >= subsample:
                            break
                    if subsample is not None and len(examples) >= subsample:
                        break

            os.makedirs(os.path.dirname(cache_name), exist_ok=True)
            print(f'Caching data to {cache_name}')
            torch.save((examples, all_answers, q_ids), cache_name)

        # Numerical field shared by the span and id columns.
        FIELD = data.Field(batch_first=True, use_vocab=False, sequential=False,
            lower=False, numerical=True, eos_token=field.eos_token, init_token=field.init_token)
        fields.append(('context_spans', FIELD))
        fields.append(('answer_start', FIELD))
        fields.append(('answer_end', FIELD))
        fields.append(('squad_id', FIELD))

        super(SQuAD, self).__init__(examples, fields, **kwargs)
        self.all_answers = all_answers
        self.q_ids = q_ids

    @classmethod
    def splits(cls, fields, root='.data', description='squad1.1',
               train='train', validation='dev', test=None, **kwargs):
        """Create dataset objects for splits of the SQuAD dataset.
        Arguments:
            fields: Field object handed to the constructor as ``field``.
            root: directory containing SQuAD data
            description: Files of version 2.0 are used when this contains
                '2.0', files of version 1.1 otherwise.
            train: The prefix of the train data. Default: 'train'.
            validation: The prefix of the validation data. Default: 'dev'.
            test: Must be None; there is no test split.
            Remaining keyword arguments: Passed to the constructor.
        Returns:
            A tuple with the non-None datasets in (train, validation) order.
        """
        assert test is None
        path = cls.download(root)

        extension = 'v2.0.json' if '2.0' in description else 'v1.1.json'
        train = '-'.join([train, extension]) if train is not None else None
        validation = '-'.join([validation, extension]) if validation is not None else None

        train_data = None if train is None else cls(
            os.path.join(path, train), fields, **kwargs)
        validation_data = None if validation is None else cls(
            os.path.join(path, validation), fields, **kwargs)
        return tuple(d for d in (train_data, validation_data)
                     if d is not None)


# https://github.com/abisee/cnn-dailymail/blob/8eace60f306dcbab30d1f1d715e379f07a3782db/make_datafiles.py
dm_single_close_quote = u'\u2019'
dm_double_close_quote = u'\u201d'
# Acceptable ways to end a sentence.
END_TOKENS = [
    '.', '!', '?', '...', "'", "`", '"',
    dm_single_close_quote, dm_double_close_quote, ")",
]


def fix_missing_period(line):
    """Return ``line`` with a trailing period unless it needs none.

    The line is returned unchanged when it contains "@highlight", is empty,
    or already ends with one of the characters in ``END_TOKENS``.
    """
    needs_no_period = (
        "@highlight" in line
        or line == ""
        or line[-1] in END_TOKENS
    )
    if needs_no_period:
        return line
    return line + "."


class Summarization(CQA, data.Dataset):
    """Summarization data posed as the question 'What is the summary?'.

    ``cache_splits`` converts the raw ``.story`` files into one JSON-lines
    file per split; the constructor reads such a file. Subclasses supply
    ``name``, which is used to locate the per-split URL lists.
    """

    @staticmethod
    def sort_key(ex):
        """Return a sort key that interleaves context and answer lengths."""
        return data.interleave_keys(len(ex.context), len(ex.answer))

    def __init__(self, path, field, one_answer=True, subsample=None, **kwargs):
        """Load (or build and cache) the examples of one JSON-lines file.

        Arguments:
            path: Path of a file with one JSON object per line, each holding
                the keys 'context', 'question' and 'answer'.
            field: Field object applied to every column in ``self.fields``.
            one_answer: Accepted for interface compatibility; not used here.
            subsample: If not None, stop after this many examples.
            Remaining keyword arguments: Passed to the dataset constructor.
        """
        fields = [(x, field) for x in self.fields]
        cache_name = os.path.join(os.path.dirname(path), '.cache', os.path.basename(path), str(subsample))

        examples = []
        if os.path.exists(cache_name):
            print(f'Loading cached data from {cache_name}')
            # NOTE(review): torch.load unpickles the file; only use caches
            # that this code wrote itself.
            examples = torch.load(cache_name)
        else:
            with open(os.path.expanduser(path)) as f:
                # Stream the file instead of materialising every line first.
                for line in f:
                    record = json.loads(line)
                    context, question, answer = record['context'], record['question'], record['answer']
                    context_question = get_context_question(context, question)
                    ex = data.Example.fromlist(
                        [context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question],
                        fields)
                    examples.append(ex)
                    if subsample is not None and len(examples) >= subsample:
                        break
            os.makedirs(os.path.dirname(cache_name), exist_ok=True)
            print(f'Caching data to {cache_name}')
            torch.save(examples, cache_name)

        super(Summarization, self).__init__(examples, fields, **kwargs)

    @classmethod
    def cache_splits(cls, path):
        """Write ``<split>.jsonl`` under ``path`` for every split lacking one.

        For each URL in ``<name>_wayback_<split>_urls.txt`` the story file
        ``stories/<sha1 of the url>.story`` is read. Lines before the first
        "@highlight" marker form the article (context), lines after it the
        highlights (answer).

        Raises:
            RuntimeError: If "@highlight" occurs inside a line instead of at
                its start.
        """
        for split in ['training', 'validation', 'test']:
            missing_stories, collected_stories = 0, 0
            split_file_name = os.path.join(path, f'{split}.jsonl')
            if os.path.exists(split_file_name):
                continue
            with open(split_file_name, 'w') as split_file:
                url_file_name = os.path.join(path, f'{cls.name}_wayback_{split}_urls.txt')
                with open(url_file_name) as url_file:
                    for url in url_file:
                        story_file_name = os.path.join(path, 'stories',
                            f"{hashlib.sha1(url.strip().encode('utf-8')).hexdigest()}.story")
                        try:
                            story_file = open(story_file_name)
                        except EnvironmentError as e:
                            missing_stories += 1
                            print(e)
                            # A split with missing stories is not kept on
                            # disk: the file is unlinked as soon as the
                            # first story cannot be opened.
                            # NOTE(review): the file is still open for
                            # writing at this point, which some platforms
                            # refuse; removing it after the `with` block
                            # would be more portable.
                            if os.path.exists(split_file_name):
                                os.remove(split_file_name)
                        else:
                            with story_file:
                                article, highlight = [], []
                                is_highlight = False
                                for line in story_file:
                                    line = line.strip()
                                    if line == "":
                                        continue
                                    line = fix_missing_period(line)
                                    if line.startswith("@highlight"):
                                        is_highlight = True
                                    elif "@highlight" in line:
                                        # A bare `raise` has no active
                                        # exception here and would only
                                        # report "No active exception to
                                        # reraise"; keep the type, add the
                                        # context.
                                        raise RuntimeError(
                                            f'Unexpected "@highlight" inside a line of '
                                            f'{story_file_name}: {line!r}')
                                    elif is_highlight:
                                        highlight.append(line)
                                    else:
                                        article.append(line)
                                example = {'context': unicodedata.normalize('NFKC', ' '.join(article)),
                                           'answer': unicodedata.normalize('NFKC', ' '.join(highlight)),
                                           'question': 'What is the summary?'}
                                split_file.write(json.dumps(example)+'\n')
                                collected_stories += 1
                                # Print a sample every 1000 stories as a
                                # progress indicator.
                                if collected_stories % 1000 == 0:
                                    print(example)
            print(f'Missing {missing_stories} stories')
            print(f'Collected {collected_stories} stories')

    @classmethod
    def splits(cls, fields, root='.data',
               train='training', validation='validation', test='test', **kwargs):
        """Download and convert the data if needed, then return the splits.

        Arguments:
            fields: Field object handed to the constructor as ``field``.
            root: Directory passed to ``cls.download``.
            train, validation, test: Only checked against None to decide
                whether the split is loaded; the file names are fixed to
                'training.jsonl', 'validation.jsonl' and 'test.jsonl'.
            Remaining keyword arguments: Passed to the constructor.

        Returns:
            A tuple with the non-None datasets in (train, validation, test)
            order.
        """
        path = cls.download(root)
        cls.cache_splits(path)

        train_data = None if train is None else cls(
            os.path.join(path, 'training.jsonl'), fields, **kwargs)
        validation_data = None if validation is None else cls(
            os.path.join(path, 'validation.jsonl'), fields, one_answer=False, **kwargs)
        test_data = None if test is None else cls(
            os.path.join(path, 'test.jsonl'), fields, one_answer=False, **kwargs)
        return tuple(d for d in (train_data, validation_data, test_data)
                     if d is not None)


class DailyMail(Summarization):
    """Summarization dataset built from the DailyMail stories archive.

    ``name`` is the prefix ``Summarization.cache_splits`` uses to find the
    per-split URL lists; ``urls`` pairs each download URL with a file name.
    """

    name = 'dailymail'
    dirname = 'dailymail'
    urls = [
        ('https://drive.google.com/uc?export=download&id=0BwmD_VLjROrfM1BxdkxVaTY2bWs',
         'dailymail_stories.tgz'),
        ('https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/dailymail_wayback_training_urls.txt',
         'dailymail/dailymail_wayback_training_urls.txt'),
        ('https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/dailymail_wayback_validation_urls.txt',
         'dailymail/dailymail_wayback_validation_urls.txt'),
        ('https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/dailymail_wayback_test_urls.txt',
         'dailymail/dailymail_wayback_test_urls.txt'),
    ]


class CNN(Summarization):
    """Summarization dataset built from the CNN stories archive.

    ``name`` is the prefix ``Summarization.cache_splits`` uses to find the
    per-split URL lists; ``urls`` pairs each download URL with a file name.
    """

    name = 'cnn'
    dirname = 'cnn'
    urls = [
        ('https://drive.google.com/uc?export=download&id=0BwmD_VLjROrfTHk4NFg2SndKcjQ',
         'cnn_stories.tgz'),
        ('https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/cnn_wayback_training_urls.txt',
         'cnn/cnn_wayback_training_urls.txt'),
        ('https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/cnn_wayback_validation_urls.txt',
         'cnn/cnn_wayback_validation_urls.txt'),
        ('https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/cnn_wayback_test_urls.txt',
         'cnn/cnn_wayback_test_urls.txt'),
    ]



class Query:
    """A WikiSQL-style query: one selected column, an aggregation, conditions.

    ``repr(query)`` renders the query as a single-spaced SQL-like string.
    """
    #https://github.com/salesforce/WikiSQL/blob/c2ed4f9b22db1cc2721805d53e6e76e07e2ccbdc/lib/query.py#L10

    # Operator tables indexed by the integer codes stored in a query.
    agg_ops = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']
    cond_ops = ['=', '>', '<', 'OP']
    syms = ['SELECT', 'WHERE', 'AND', 'COL', 'TABLE', 'CAPTION', 'PAGE', 'SECTION', 'OP', 'COND', 'QUESTION', 'AGG', 'AGGOPS', 'CONDOPS']

    def __init__(self, sel_index, agg_index, columns, conditions=tuple()):
        """Store the parts of a query.

        Arguments:
            sel_index: Index of the selected column.
            agg_index: Index into ``agg_ops``.
            columns: Sequence of column names, or None to render columns
                as ``col<index>``.
            conditions: Iterable of (column index, ``cond_ops`` index, value)
                triples; copied into a list.
        """
        self.sel_index = sel_index
        self.agg_index = agg_index
        self.columns = columns
        self.conditions = list(conditions)

    def _column_name(self, index):
        """Return the name of column ``index``, or ``col<index>`` without names."""
        if self.columns is not None:
            return self.columns[index]
        return 'col{}'.format(index)

    def __repr__(self):
        rep = 'SELECT {agg} {sel} FROM table'.format(
            agg=self.agg_ops[self.agg_index],
            sel=self._column_name(self.sel_index),
        )
        if self.conditions:
            # Conditions use the same fallback as the selected column, so a
            # query without column names can still be rendered.
            rep += ' WHERE ' + ' AND '.join(
                '{} {} {}'.format(self._column_name(i), self.cond_ops[o], v)
                for i, o, v in self.conditions
            )
        # Collapse the double space left behind by an empty aggregation.
        return ' '.join(rep.split())

    @classmethod
    def from_dict(cls, d, t):
        """Build a query from a dict with keys 'sel', 'agg', 'conds' and the column names ``t``."""
        return cls(sel_index=d['sel'], agg_index=d['agg'], columns=t, conditions=d['conds'])


class WikiSQL(CQA, data.Dataset):
    """WikiSQL examples whose answer is the rendered SQL query.

    The context lists the table's column names and the SQL key words. The
    natural-language request is either used as the question or appended to
    the context, depending on ``query_as_question``. ``all_answers`` keeps
    the raw sql dict, header, rendered answer and table of every example.
    """

    @staticmethod
    def sort_key(ex):
        """Return a sort key that interleaves context and answer lengths."""
        return data.interleave_keys(len(ex.context), len(ex.answer))

    urls = ['https://github.com/salesforce/WikiSQL/raw/master/data.tar.bz2']
    name = 'wikisql'
    dirname = 'data'

    def __init__(self, path, field, query_as_question=False, subsample=None, **kwargs):
        """Load (or build and cache) the examples of one JSON-lines file.

        Arguments:
            path: Path of the examples file; the tables are read from the
                same path with '.tables' inserted before the extension.
            field: Field object applied to every column in ``self.fields``.
            query_as_question: If True the human request is the question;
                otherwise a fixed question is used and the request is
                appended to the context.
            subsample: If not None, stop after this many examples.
            Remaining keyword arguments: Passed to the dataset constructor.
        """
        fields = [(x, field) for x in self.fields]
        # Numerical field for the running example index.
        FIELD = data.Field(batch_first=True, use_vocab=False, sequential=False,
            lower=False, numerical=True, eos_token=field.eos_token, init_token=field.init_token)
        fields.append(('wikisql_id', FIELD))

        # The two context layouts are cached separately.
        cache_name = os.path.join(os.path.dirname(path), '.cache', 'query_as_question' if query_as_question else 'query_as_context', os.path.basename(path), str(subsample))
        if os.path.exists(cache_name):
            print(f'Loading cached data from {cache_name}')
            # NOTE(review): torch.load unpickles the file; only use caches
            # that this code wrote itself.
            examples, all_answers = torch.load(cache_name)
        else:

            expanded_path = os.path.expanduser(path)
            table_path = os.path.splitext(expanded_path)
            table_path = table_path[0] + '.tables' + table_path[1]

            with open(table_path) as tables_file:
                tables = [json.loads(line) for line in tables_file]
                id_to_tables = {x['id']: x for x in tables}

            all_answers = []
            examples = []
            with open(expanded_path) as example_file:
                for idx, line in enumerate(example_file):
                    entry = json.loads(line)
                    human_query = entry['question']
                    table = id_to_tables[entry['table_id']]
                    sql = entry['sql']
                    header = table['header']
                    answer = repr(Query.from_dict(sql, header))
                    context = (f'The table has columns {", ".join(table["header"])} ' +
                               f'and key words {", ".join(Query.agg_ops[1:] + Query.cond_ops + Query.syms)}')
                    if query_as_question:
                        question = human_query
                    else:
                        question = 'What is the translation from English to SQL?'
                        # Kept exactly as is (no space before '--') so the
                        # generated contexts do not change.
                        context += f'-- {human_query}'
                    context_question = get_context_question(context, question)
                    ex = data.Example.fromlist(
                        [context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question, idx],
                        fields)
                    examples.append(ex)
                    all_answers.append({'sql': sql, 'header': header, 'answer': answer, 'table': table})
                    # '>=' so that exactly `subsample` examples are kept,
                    # like the translation/summarization loaders do.
                    if subsample is not None and len(examples) >= subsample:
                        break

            os.makedirs(os.path.dirname(cache_name), exist_ok=True)
            print(f'Caching data to {cache_name}')
            torch.save((examples, all_answers), cache_name)

        super(WikiSQL, self).__init__(examples, fields, **kwargs)
        self.all_answers = all_answers

    @classmethod
    def splits(cls, fields, root='.data',
               train='train.jsonl', validation='dev.jsonl', test='test.jsonl', **kwargs):
        """Create dataset objects for splits of the WikiSQL dataset.
        Arguments:
            fields: Field object handed to the constructor as ``field``.
            root: directory passed to ``cls.download``
            train: File name of the train data. Default: 'train.jsonl'.
            validation: File name of the validation data.
                Default: 'dev.jsonl'.
            test: File name of the test data. Default: 'test.jsonl'.
            Remaining keyword arguments: Passed to the constructor.
        Returns:
            A tuple with the non-None datasets in (train, validation, test)
            order.
        """
        path = cls.download(root)

        train_data = None if train is None else cls(
            os.path.join(path, train), fields, **kwargs)
        validation_data = None if validation is None else cls(
            os.path.join(path, validation), fields, **kwargs)
        test_data = None if test is None else cls(
            os.path.join(path, test), fields, **kwargs)
        return tuple(d for d in (train_data, validation_data, test_data)
                     if d is not None)


class SRL(CQA, data.Dataset):

    @staticmethod
    def sort_key(ex):
        """Return a sort key that interleaves context and answer lengths."""
        lengths = (len(ex.context), len(ex.answer))
        return data.interleave_keys(*lengths)

    urls = ['https://dada.cs.washington.edu/qasrl/data/wiki1.train.qa',
            'https://dada.cs.washington.edu/qasrl/data/wiki1.dev.qa',
            'https://dada.cs.washington.edu/qasrl/data/wiki1.test.qa']

    name = 'srl'
    dirname = ''

    @classmethod
    def clean(cls, s):
        closing_punctuation = set([ ' .', ' ,', ' ;', ' !', ' ?', ' :', ' )', " 'll", " n't ", " %", " 't", " 's", " 'm", " 'd", " 're"])
        opening_punctuation = set(['( ', '$ '])
        both_sides = set([' - '])
        s = ' '.join(s.split()).strip()
        s = s.replace('-LRB-', '(')
        s = s.replace('-RRB-', ')')
        s = s.replace('-LAB-', '<')
        s = s.replace('-RAB-', '>')
        s = s.replace('-AMP-', '&')
        s = s.replace('%pw', ' ')

        for p in closing_punctuation:
            s = s.replace(p, p.lstrip())
        for p in opening_punctuation:
            s = s.replace(p, p.rstrip())
        for p in both_sides:
            s = s.replace(p, p.strip())
        s = s.replace('``', '')
        s = s.replace('`', '')
        s = s.replace("''", '')
        s = s.replace('“', '')
        s = s.replace('”', '')
        s = s.replace(" '", '')
        return ' '.join(s.split()).strip()

    def __init__(self, path, field, one_answer=True, subsample=None, **kwargs):
        """Load (or build and cache) the examples of one JSON-lines file.

        Each line must be a JSON object with the keys 'type', 'all_answers',
        'context', 'question' and 'answer'. Every example receives a
        ``squad_id`` that indexes ``self.all_answers``. ``one_answer`` is
        accepted but not used here.
        """
        fields = [(column, field) for column in self.fields]
        cache_name = os.path.join(
            os.path.dirname(path), '.cache', os.path.basename(path), str(subsample))

        if os.path.exists(cache_name):
            print(f'Loading cached data from {cache_name}')
            examples, all_answers = torch.load(cache_name)
        else:
            examples, all_answers = [], []
            limited = subsample is not None
            with open(os.path.expanduser(path)) as f:
                for raw_line in f:
                    record = json.loads(raw_line)
                    # 'type' is looked up, as before, but never used.
                    _ = record['type']
                    answer_variants = record['all_answers']
                    context = record['context']
                    question = record['question']
                    answer = record['answer']
                    row = [
                        context,
                        question,
                        answer,
                        CONTEXT_SPECIAL,
                        QUESTION_SPECIAL,
                        get_context_question(context, question),
                    ]
                    example = data.Example.fromlist(row, fields)
                    # Position of this example's answers in all_answers.
                    example.squad_id = len(all_answers)
                    examples.append(example)
                    all_answers.append(answer_variants)
                    if limited and len(examples) >= subsample:
                        break
            os.makedirs(os.path.dirname(cache_name), exist_ok=True)
            print(f'Caching data to {cache_name}')
            torch.save((examples, all_answers), cache_name)

        # Numerical field for the squad_id column.
        id_field = data.Field(
            batch_first=True, use_vocab=False, sequential=False, lower=False,
            numerical=True, eos_token=field.eos_token, init_token=field.init_token)
        fields.append(('squad_id', id_field))

        super(SRL, self).__init__(examples, fields, **kwargs)
        self.all_answers = all_answers


    @classmethod
    def cache_splits(cls, path, path_to_files, train='train', validation='dev', test='test'):
        """Convert the raw ``wiki1.<split>.qa`` files in ``path`` to ``<split>.jsonl``.

        For every split whose ``.jsonl`` file does not exist yet, the matching
        ``wiki1.<split>.qa`` file is read, filtered, and grouped into blocks
        separated by blank lines. The first line of a block is the context and
        each following line is ``<question>?<answer>###<answer>...``. Every
        answer is aligned with the context (see ``align_with_context``) and one
        JSON object per question is written with the keys ``context``,
        ``question``, ``answer``, ``type`` (always ``'wiki'``) and
        ``all_answers``.

        The output file is only created after the whole source file has been
        parsed successfully, so a missing or malformed source never leaves an
        empty or truncated ``.jsonl`` behind that later runs would mistake for a
        finished cache.

        Args:
            path: directory holding the ``.qa`` files; the ``.jsonl`` files are
                written there as well.
            path_to_files: accepted for signature compatibility; not used here.
            train, validation, test: split names used to build the file names.

        Raises:
            ValueError: if a question line does not contain exactly one ``?``,
                or if an answer cannot be located in its context.
        """
        # Hand-written repairs for answers that differ from the context text by
        # more than capitalisation: (marker found in the answer, replacement).
        known_repairs = (
            ('Darcy Burner', 'Darcy Burner and other 2008 Democratic congressional candidates, in cooperation with some retired national security officials'),
            ('E Street Band', 'plan to work with the E Street Band again in the future'),
            ('an electric sender', 'an electronic sender'),
            ('the US army', 'the US Army'),
            ('Rather than name the', 'rather die than name the cause of his disease to his father'),
        )

        def is_int(x):
            try:
                int(x)
            except ValueError:
                return False
            return True

        def align_with_context(answer, context):
            """Return a variant of ``answer`` that occurs verbatim in ``context``."""
            if answer in context:
                return answer
            # First try flipping the case of the leading character only.
            low_answer = answer[0].lower() + answer[1:]
            if low_answer in context:
                return low_answer
            up_answer = answer[0].upper() + answer[1:]
            if up_answer in context:
                return up_answer
            # Then the hand-written repairs, and finally full lower-casing.
            for marker, replacement in known_repairs:
                if marker in answer:
                    aligned = replacement
                    break
            else:
                aligned = answer.lower()
            if aligned not in context:
                raise ValueError(f'Answer {answer!r} could not be aligned with context {context!r}')
            return aligned

        for split in [train, validation, test]:
            split_file_name = os.path.join(path, f'{split}.jsonl')
            if os.path.exists(split_file_name):
                continue
            wiki_file = os.path.join(path, f'wiki1.{split}.qa')

            # Keep blank separator lines, drop lines whose first '_'-separated
            # chunk contains 'WIKI1', and drop lines of at most three tokens
            # that start with an integer.
            lines = []
            with open(os.path.expanduser(wiki_file)) as f:
                for line in f:
                    line = ' '.join(line.split())
                    if len(line) == 0:
                        lines.append(line)
                        continue
                    if 'WIKI1' in line.split('_')[0]:
                        continue
                    tokens = line.split()
                    if not is_int(tokens[0]) or len(tokens) > 3:
                        lines.append(line)

            records = []
            new_example = True
            for line in lines:
                if new_example:
                    context = cls.clean(line)
                    new_example = False
                    continue
                if len(line) == 0:
                    new_example = True
                    continue
                # Exactly one '?' is expected; anything else raises ValueError.
                question, answers = line.split('?')
                question = cls.clean(question.replace(' _', '') + '?')
                candidates = [cls.clean(x) for x in answers.split('###')]
                answer = align_with_context(candidates[0], context)
                modified_all_answers = [align_with_context(a, context) for a in candidates]
                records.append(json.dumps({'context': context, 'question': question, 'answer': answer, 'type': 'wiki', 'all_answers': modified_all_answers})+'\n')

            with open(split_file_name, 'w') as split_file:
                split_file.writelines(records)
            

            

    @classmethod
    def splits(cls, fields, root='.data',
               train='train', validation='dev', test='test', **kwargs):
        """Download the data, build the cached ``.jsonl`` files and load the splits.

        Returns a tuple with one dataset per split name that is not ``None``,
        in the order train, validation, test. The validation and test datasets
        are constructed with ``one_answer=False``.
        """
        path = cls.download(root)
        cls.cache_splits(path, None)

        def jsonl_path(split):
            return os.path.join(path, f'{split}.jsonl')

        loaded = []
        if train is not None:
            loaded.append(cls(jsonl_path(train), fields, **kwargs))
        if validation is not None:
            loaded.append(cls(jsonl_path(validation), fields, one_answer=False, **kwargs))
        if test is not None:
            loaded.append(cls(jsonl_path(test), fields, one_answer=False, **kwargs))
        return tuple(loaded)


class WinogradSchema(CQA, data.Dataset):
    """Winograd schema data expressed as context/question/answer examples.

    ``cache_splits`` expands ``schema.txt`` into ``train.jsonl``,
    ``validation.jsonl`` and ``test.jsonl``; ``__init__`` loads one of those
    files (or a previously saved torch cache of it).
    """

    @staticmethod
    def sort_key(ex):
        return data.interleave_keys(len(ex.context), len(ex.answer))

    urls = ['https://s3.amazonaws.com/research.metamind.io/decaNLP/data/schema.txt']

    name = 'schema'
    dirname = ''

    def __init__(self, path, field, subsample=None, **kwargs):
        """Load the examples stored in the ``.jsonl`` file at ``path``.

        Args:
            path: ``.jsonl`` file with ``context``, ``question`` and ``answer``
                keys on every line.
            field: field object applied to every entry of ``self.fields``.
            subsample: if not ``None``, stop reading once this many examples
                have been collected.
            **kwargs: forwarded to the parent constructor.

        The parsed examples are cached with ``torch.save`` under a ``.cache``
        directory next to ``path``, keyed by file name and ``subsample``.
        """
        fields = [(x, field) for x in self.fields]

        cache_name = os.path.join(os.path.dirname(path), '.cache', os.path.basename(path), str(subsample))
        if os.path.exists(cache_name):
            print(f'Loading cached data from {cache_name}')
            examples = torch.load(cache_name)
        else:
            examples = []
            with open(os.path.expanduser(path)) as f:
                for line in f:
                    ex = json.loads(line)
                    context, question, answer = ex['context'], ex['question'], ex['answer']
                    context_question = get_context_question(context, question)
                    ex = data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields)
                    examples.append(ex)
                    if subsample is not None and len(examples) >= subsample:
                        break
            os.makedirs(os.path.dirname(cache_name), exist_ok=True)
            print(f'Caching data to {cache_name}')
            torch.save(examples, cache_name)

        super(WinogradSchema, self).__init__(examples, fields, **kwargs)

    @classmethod
    def cache_splits(cls, path):
        """Expand ``schema.txt`` in ``path`` into train/validation/test ``.jsonl`` files.

        Nothing is done when ``train.jsonl`` already exists. Each schema is a
        group of three non-blank lines (context, question, ``a/b`` answers).
        Bracketed alternatives such as ``[x/y]`` are resolved to produce two
        examples per schema. The last 100 examples become the test split; of the
        remainder the first 80 are the train split and the rest the validation
        split.
        """
        train_jsonl = os.path.expanduser(os.path.join(path, 'train.jsonl'))
        if os.path.exists(train_jsonl):
            return

        # Non-greedy so that several bracket groups on one line are matched
        # separately instead of as one span from the first '[' to the last ']'.
        pattern = re.compile(r'\[.*?\]')

        def get_both_schema(context):
            """Return the two texts obtained by picking option 0 / option 1 in every bracket group."""
            variations = [x[1:-1].split('/') for x in pattern.findall(context)]
            splits = pattern.split(context)
            results = []
            for which_schema in range(2):
                vs = [v[which_schema] for v in variations]
                pieces = []
                for idx, chunk in enumerate(splits):
                    pieces.append(chunk)
                    if idx < len(vs):
                        pieces.append(vs[idx])
                results.append(''.join(pieces))
            return results

        # NOTE: a schema is only recorded when a blank line follows it, so the
        # final schema is kept only if the file ends with a blank line.
        schemas = []
        with open(os.path.expanduser(os.path.join(path, 'schema.txt'))) as schema_file:
            schema = []
            for line in schema_file:
                if len(line.split()) == 0:
                    schemas.append(schema)
                    schema = []
                else:
                    schema.append(line.strip())

        examples = []
        for schema in schemas:
            context, question, answer = schema
            contexts = get_both_schema(context)
            questions = get_both_schema(question)
            answers = answer.split('/')
            for idx in range(2):
                answer = answers[idx]
                question = questions[idx] + f' {answers[0]} or {answers[1]}?'
                examples.append({'context': contexts[idx], 'question': question, 'answer': answer})

        traindev = examples[:-100]
        test = examples[-100:]
        train = traindev[:80]
        dev = traindev[80:]

        # 'w' rather than 'a': a stale validation/test file left over from an
        # earlier run must be replaced, not extended with duplicate examples.
        for split, split_examples in zip(['train', 'validation', 'test'], [train, dev, test]):
            with open(os.path.expanduser(os.path.join(path, f'{split}.jsonl')), 'w') as split_file:
                for ex in split_examples:
                    split_file.write(json.dumps(ex)+'\n')

    @classmethod
    def splits(cls, fields, root='.data',
               train='train', validation='validation', test='test', **kwargs):
        """Download, cache and load the splits; returns a tuple of the non-``None`` ones."""
        path = cls.download(root)
        cls.cache_splits(path)

        train_data = None if train is None else cls(
            os.path.join(path, f'{train}.jsonl'), fields, **kwargs)
        validation_data = None if validation is None else cls(
            os.path.join(path, f'{validation}.jsonl'), fields, **kwargs)
        test_data = None if test is None else cls(
            os.path.join(path, f'{test}.jsonl'), fields, **kwargs)
        return tuple(d for d in (train_data, validation_data, test_data)
                     if d is not None)


class WOZ(CQA, data.Dataset):
    """WOZ dialogue state tracking expressed as context/question/answer examples.

    ``cache_splits`` turns the downloaded ``woz_<split>_<lang>.json`` files into
    one ``<split>.jsonl`` per split; ``__init__`` loads the lines of such a file
    whose language appears in ``description``.
    """

    @staticmethod
    def sort_key(ex):
        return data.interleave_keys(len(ex.context), len(ex.answer))

    # NOTE: woz_train_en.json is listed twice; the list is left untouched.
    urls = ['https://raw.githubusercontent.com/nmrksic/neural-belief-tracker/master/data/woz/woz_train_en.json',
            'https://raw.githubusercontent.com/nmrksic/neural-belief-tracker/master/data/woz/woz_test_de.json',
            'https://raw.githubusercontent.com/nmrksic/neural-belief-tracker/master/data/woz/woz_test_en.json',
            'https://raw.githubusercontent.com/nmrksic/neural-belief-tracker/master/data/woz/woz_train_de.json',
            'https://raw.githubusercontent.com/nmrksic/neural-belief-tracker/master/data/woz/woz_train_en.json',
            'https://raw.githubusercontent.com/nmrksic/neural-belief-tracker/master/data/woz/woz_validate_de.json',
            'https://raw.githubusercontent.com/nmrksic/neural-belief-tracker/master/data/woz/woz_validate_en.json']

    name = 'woz'
    dirname = ''

    def __init__(self, path, field, subsample=None, description='woz.en', **kwargs):
        """Load the examples of ``path`` whose ``lang`` value occurs in ``description``.

        Args:
            path: ``.jsonl`` file produced by ``cache_splits``.
            field: field object applied to every entry of ``self.fields``.
            subsample: if not ``None``, stop reading once this many examples
                have been collected.
            description: string that is searched for each line's ``lang`` value
                (``'woz.en'`` keeps the lines with ``lang == 'en'``).
            **kwargs: forwarded to the parent constructor.

        Besides the examples, ``self.all_answers`` holds one
        ``(lang_dialogue_turn, answer)`` tuple per loaded example, and each
        example's ``woz_id`` is the position of its tuple in that list.
        """
        fields = [(x, field) for x in self.fields]
        FIELD = data.Field(batch_first=True, use_vocab=False, sequential=False,
            lower=False, numerical=True, eos_token=field.eos_token, init_token=field.init_token)
        fields.append(('woz_id', FIELD))

        examples, all_answers = [], []
        cache_name = os.path.join(os.path.dirname(path), '.cache', os.path.basename(path), str(subsample), description)
        if os.path.exists(cache_name):
            print(f'Loading cached data from {cache_name}')
            examples, all_answers = torch.load(cache_name)
        else:
            with open(os.path.expanduser(path)) as f:
                for line in f:
                    ex = json.loads(line)
                    if ex['lang'] in description:
                        context, question, answer = ex['context'], ex['question'], ex['answer']
                        context_question = get_context_question(context, question)
                        # Index into all_answers (same scheme as squad_id elsewhere
                        # in this file). Using the raw line number would skip over
                        # lines filtered out by language and stop matching the list.
                        woz_id = len(all_answers)
                        all_answers.append((ex['lang_dialogue_turn'], answer))
                        ex = data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question, woz_id], fields)
                        examples.append(ex)

                    if subsample is not None and len(examples) >= subsample:
                        break
            os.makedirs(os.path.dirname(cache_name), exist_ok=True)
            print(f'Caching data to {cache_name}')
            torch.save((examples, all_answers), cache_name)

        super(WOZ, self).__init__(examples, fields, **kwargs)
        self.all_answers = all_answers

    @classmethod
    def cache_splits(cls, path, train='train', validation='validate', test='test'):
        """Write ``<split>.jsonl`` for every split from the ``woz_<split>_<lang>.json`` files.

        Nothing is done when ``train.jsonl`` already exists. For every dialogue
        turn one JSON line is written:

        * ``context``: the system acts (joined with ``', '`` and followed by
          ``' -- '`` when present) plus the turn transcript;
        * ``question``: always ``'What is the change in state?'``;
        * ``answer``: the inform slots that were not in the previous turn's
          state as ``name: value`` pairs, then ``';'``, then the requested
          slots; ``'None'`` when there is neither;
        * ``lang`` and ``lang_dialogue_turn`` (``<lang>_<dialogue>_<turn>``).

        English dialogues are written before German ones within each file.
        """
        train_jsonl = os.path.expanduser(os.path.join(path, 'train.jsonl'))
        if os.path.exists(train_jsonl):
            return

        file_name_base = 'woz_{}_{}.json'
        for split in [train, validation, test]:
            # 'w' rather than 'a': a stale split file from an earlier run must be
            # replaced, not extended with duplicate examples.
            with open(os.path.expanduser(os.path.join(path, f'{split}.jsonl')), 'w') as split_file:
                for lang in ['en', 'de']:
                    file_path = file_name_base.format(split, lang)
                    with open(os.path.expanduser(os.path.join(path, file_path))) as src_file:
                        dialogues = json.load(src_file)
                    for di, d in enumerate(dialogues):
                        previous_state = {'inform': [], 'request': []}
                        turns = d['dialogue']
                        for ti, t in enumerate(turns):
                            question = 'What is the change in state?'
                            actions = []
                            for act in t['system_acts']:
                                if isinstance(act, list):
                                    act = ': '.join(act)
                                actions.append(act)
                            actions = ', '.join(actions)
                            if len(actions) > 0:
                                actions += ' -- '
                            context = actions + t['transcript']
                            belief_state = t['belief_state']
                            delta_state = {'inform': [], 'request': []}
                            current_state = {'inform': [], 'request': []}
                            for item in belief_state:
                                if 'slots' in item:
                                    for slot in item['slots']:
                                        act = item['act']
                                        if act == 'inform':
                                            current_state['inform'].append(slot)
                                            # A slot whose value changed compares unequal to
                                            # the old one, so this test covers new and
                                            # updated slots alike.
                                            if slot not in previous_state['inform']:
                                                delta_state['inform'].append(slot)
                                        else:
                                            delta_state['request'].append(slot[1])
                                            current_state['request'].append(slot[1])
                            previous_state = current_state
                            answer = ''
                            if len(delta_state['inform']) > 0:
                                answer = ', '.join([f'{x[0]}: {x[1]}' for x in delta_state['inform']])
                            answer += ';'
                            if len(delta_state['request']) > 0:
                                answer += ' '
                                answer += ', '.join(delta_state['request'])
                            ex = {'context': ' '.join(context.split()),
                                 'question': ' '.join(question.split()), 'lang': lang,
                                 'answer': answer if len(answer) > 1 else 'None',
                                 'lang_dialogue_turn': f'{lang}_{di}_{ti}'}
                            split_file.write(json.dumps(ex)+'\n')

    @classmethod
    def splits(cls, fields, root='.data', train='train', validation='validate', test='test', **kwargs):
        """Download, cache and load the splits; returns a tuple of the non-``None`` ones."""
        path = cls.download(root)
        cls.cache_splits(path)

        train_data = None if train is None else cls(
            os.path.join(path, f'{train}.jsonl'), fields, **kwargs)
        validation_data = None if validation is None else cls(
            os.path.join(path, f'{validation}.jsonl'), fields, **kwargs)
        test_data = None if test is None else cls(
            os.path.join(path, f'{test}.jsonl'), fields, **kwargs)
        return tuple(d for d in (train_data, validation_data, test_data)
                     if d is not None)


class MultiNLI(CQA, data.Dataset):
    """MultiNLI premise/hypothesis pairs expressed as context/question/answer examples."""

    @staticmethod
    def sort_key(ex):
        context_len, answer_len = len(ex.context), len(ex.answer)
        return data.interleave_keys(context_len, answer_len)

    urls = ['http://www.nyu.edu/projects/bowman/multinli/multinli_1.0.zip']

    name = 'multinli'
    dirname = 'multinli_1.0'

    def __init__(self, path, field, subsample=None, description='multinli.in.out', **kwargs):
        """Load the lines of ``path`` whose ``subtask`` value occurs in ``description``.

        Parsed examples are saved with ``torch.save`` under a ``.cache``
        directory next to ``path`` (keyed by file name, ``subsample`` and
        ``description``) and reloaded from there on later calls. Reading stops
        once ``subsample`` examples have been collected, if it is given.
        """
        fields = [(name, field) for name in self.fields]

        cache_name = os.path.join(os.path.dirname(path), '.cache', os.path.basename(path), str(subsample), description)
        if not os.path.exists(cache_name):
            examples = []
            with open(os.path.expanduser(path)) as f:
                for line in f:
                    record = json.loads(line)
                    if record['subtask'] in description:
                        context = record['context']
                        question = record['question']
                        answer = record['answer']
                        context_question = get_context_question(context, question)
                        values = [context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question]
                        examples.append(data.Example.fromlist(values, fields))
                    if subsample is not None and len(examples) >= subsample:
                        break
            os.makedirs(os.path.dirname(cache_name), exist_ok=True)
            print(f'Caching data to {cache_name}')
            torch.save(examples, cache_name)
        else:
            print(f'Loading cached data from {cache_name}')
            examples = torch.load(cache_name)

        super(MultiNLI, self).__init__(examples, fields, **kwargs)

    @classmethod
    def cache_splits(cls, path, train='multinli_1.0_train', validation='mulinli_1.0_dev_{}', test='test'):
        """Write ``train.jsonl`` and ``validation.jsonl`` from the raw MultiNLI files in ``path``.

        Nothing is done when ``train.jsonl`` already exists. The training file
        is tagged with subtask ``'multinli'``; the matched and mismatched dev
        files are appended to one validation file tagged ``'in'`` and ``'out'``
        respectively. The ``train``, ``validation`` and ``test`` arguments are
        not consulted; the source file names are fixed.
        """
        def in_dir(file_name):
            return os.path.expanduser(os.path.join(path, file_name))

        if os.path.exists(in_dir('train.jsonl')):
            return

        def convert(src_name, subtask_tag, out_file):
            # Re-express every premise/hypothesis pair of one source file.
            with open(in_dir(src_name)) as src_file:
                for raw in src_file:
                    pair = json.loads(raw)
                    premise = pair['sentence1']
                    hypothesis = pair['sentence2']
                    record = {
                        'context': f'Premise: "{premise}"',
                        'question': f'Hypothesis: "{hypothesis}" -- entailment, neutral, or contradiction?',
                        'answer': pair['gold_label'],
                        'subtask': subtask_tag,
                    }
                    out_file.write(json.dumps(record) + '\n')

        with open(in_dir('train.jsonl'), 'a') as out_file:
            convert('multinli_1.0_train.jsonl', 'multinli', out_file)

        with open(in_dir('validation.jsonl'), 'a') as out_file:
            for subtask, tag in (('matched', 'in'), ('mismatched', 'out')):
                convert('multinli_1.0_dev_{}.jsonl'.format(subtask), tag, out_file)

    @classmethod
    def splits(cls, fields, root='.data', train='train', validation='validation', test='test', **kwargs):
        """Download, cache and load the splits; returns a tuple of the non-``None`` ones."""
        path = cls.download(root)
        cls.cache_splits(path)

        loaded = []
        for split in (train, validation, test):
            if split is not None:
                loaded.append(cls(os.path.join(path, f'{split}.jsonl'), fields, **kwargs))
        return tuple(loaded)


class ZeroShotRE(CQA, data.Dataset):
    """Zero-shot relation extraction expressed as context/question/answer examples."""

    @staticmethod
    def sort_key(ex):
        return data.interleave_keys(len(ex.context), len(ex.answer))

    urls = ['http://nlp.cs.washington.edu/zeroshot/relation_splits.tar.bz2']
    dirname = 'relation_splits'
    name = 'zre'

    def __init__(self, path, field, subsample=None, **kwargs):
        """Load the examples stored in the ``.jsonl`` file at ``path``.

        Args:
            path: ``.jsonl`` file with ``context``, ``question`` and ``answer``
                keys on every line.
            field: field object applied to every entry of ``self.fields``.
            subsample: if not ``None``, stop reading once this many examples
                have been collected.
            **kwargs: forwarded to the parent constructor.

        The parsed examples are cached with ``torch.save`` under a ``.cache``
        directory next to ``path``, keyed by file name and ``subsample``.
        """
        fields = [(x, field) for x in self.fields]

        cache_name = os.path.join(os.path.dirname(path), '.cache', os.path.basename(path), str(subsample))
        if os.path.exists(cache_name):
            print(f'Loading cached data from {cache_name}')
            examples = torch.load(cache_name)
        else:
            examples = []
            with open(os.path.expanduser(path)) as f:
                for line in f:
                    ex = json.loads(line)
                    context, question, answer = ex['context'], ex['question'], ex['answer']
                    context_question = get_context_question(context, question)
                    ex = data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields)
                    examples.append(ex)

                    if subsample is not None and len(examples) >= subsample:
                        break
            os.makedirs(os.path.dirname(cache_name), exist_ok=True)
            print(f'Caching data to {cache_name}')
            torch.save(examples, cache_name)

        super().__init__(examples, fields, **kwargs)

    @classmethod
    def cache_splits(cls, path, train='train', validation='dev', test='test'):
        """Convert the tab-separated ``<split>.0`` files in ``path`` to ``<split>.jsonl``.

        Nothing is done when the train ``.jsonl`` already exists. Every source
        line holds ``relation, question, subject, context`` followed by zero or
        more answers. ``XXX`` in the question is replaced by the subject, the
        answers are joined with ``', '``, and a line without answers gets the
        answer ``'unanswerable'``.

        Raises:
            ValueError: if a source line has fewer than four tab-separated fields.
        """
        train_jsonl = os.path.expanduser(os.path.join(path, f'{train}.jsonl'))
        if os.path.exists(train_jsonl):
            return

        base_file_name = '{}.0'
        for split in [train, validation, test]:
            src_file_name = base_file_name.format(split)
            # 'w' rather than 'a': a stale split file from an earlier run must be
            # replaced, not extended with duplicate examples.
            with open(os.path.expanduser(os.path.join(path, f'{split}.jsonl')), 'w') as split_file:
                with open(os.path.expanduser(os.path.join(path, src_file_name))) as src_file:
                    for line in src_file:
                        # Drop the line terminator first; otherwise it ends up inside
                        # the last field (the context, or the final answer).
                        split_line = line.rstrip('\r\n').split('\t')
                        _relation, question, subject, context = split_line[:4]
                        answer = ', '.join(split_line[4:])
                        question = question.replace('XXX', subject)
                        ex = {'context': context,
                              'question': question,
                              'answer': answer if len(answer) > 0 else 'unanswerable'}
                        split_file.write(json.dumps(ex)+'\n')

    @classmethod
    def splits(cls, fields, root='.data', train='train', validation='dev', test='test', **kwargs):
        """Download, cache and load the splits; returns a tuple of the non-``None`` ones."""
        path = cls.download(root)
        cls.cache_splits(path)

        train_data = None if train is None else cls(
            os.path.join(path, f'{train}.jsonl'), fields, **kwargs)
        validation_data = None if validation is None else cls(
            os.path.join(path, f'{validation}.jsonl'), fields, **kwargs)
        test_data = None if test is None else cls(
            os.path.join(path, f'{test}.jsonl'), fields, **kwargs)
        return tuple(d for d in (train_data, validation_data, test_data)
                     if d is not None)


class OntoNotesNER(CQA, data.Dataset):

    @staticmethod
    def sort_key(ex):
        """Return ``data.interleave_keys`` applied to the context and answer lengths of ``ex``."""
        context_len = len(ex.context)
        answer_len = len(ex.answer)
        return data.interleave_keys(context_len, answer_len)

    urls = ['http://conll.cemantix.org/2012/download/ids/english/all/train.id',
            'http://conll.cemantix.org/2012/download/ids/english/all/development.id',
            'http://conll.cemantix.org/2012/download/ids/english/all/test.id']

    name = 'ontonotes.ner'
    dirname = ''

    @classmethod
    def clean(cls, s):
        closing_punctuation = set([ ' .', ' ,', ' ;', ' !', ' ?', ' :', ' )', " '", " n't ", " %"])
        opening_punctuation = set(['( ', '$ '])
        both_sides = set([' - '])
        s = ' '.join(s.split()).strip()
        s = s.replace(' /.', ' .')
        s = s.replace(' /?', ' ?')
        s = s.replace('-LRB-', '(')
        s = s.replace('-RRB-', ')')
        s = s.replace('-LAB-', '<')
        s = s.replace('-RAB-', '>')
        s = s.replace('-AMP-', '&')
        s = s.replace('%pw', ' ')

        for p in closing_punctuation:
            s = s.replace(p, p.lstrip())
        for p in opening_punctuation:
            s = s.replace(p, p.rstrip())
        for p in both_sides:
            s = s.replace(p, p.strip())
        s = s.replace('``', '"')
        s = s.replace("''", '"')
        quote_is_open = True
        quote_idx = s.find('"')
        raw = ''
        while quote_idx >= 0:
            start_enamex_open_idx = s.find('<ENAMEX')
            if start_enamex_open_idx > -1:
                end_enamex_open_idx = s.find('">') + 2
                if start_enamex_open_idx <= quote_idx  <= end_enamex_open_idx:
                    raw += s[:end_enamex_open_idx]
                    s = s[end_enamex_open_idx:]
                    quote_idx = s.find('"')
                    continue
            if quote_is_open:
                raw += s[:quote_idx+1]
                s = s[quote_idx+1:].strip()
                quote_is_open = False
            else:
                raw += s[:quote_idx].strip() + '"'
                s =  s[quote_idx+1:]
                quote_is_open = True
            quote_idx = s.find('"')
        raw += s

        return ' '.join(raw.split()).strip()

    def __init__(self, path, field, one_answer=True, subsample=None, path_to_files='.data/ontonotes-release-5.0/data/files', subtask='all', nones=True, **kwargs):
        """Load the examples of the ``.jsonl`` file at ``path`` that match ``subtask``.

        A line is kept when ``subtask`` is ``'both'`` or equals the line's
        ``type``, and additionally either its answer is not the string
        ``'None'`` or ``nones`` is true. Reading stops once ``subsample``
        examples have been collected, if it is given. ``one_answer`` and
        ``path_to_files`` are not used in this method.

        Parsed examples are saved with ``torch.save`` under a ``.cache``
        directory next to ``path`` (keyed by file name, ``subsample``,
        ``subtask`` and ``nones``) and reloaded from there on later calls.
        """
        fields = [(name, field) for name in self.fields]

        cache_name = os.path.join(os.path.dirname(path), '.cache', os.path.basename(path), str(subsample), subtask, str(nones))
        if not os.path.exists(cache_name):
            examples = []
            with open(os.path.expanduser(path)) as f:
                for line in f:
                    record = json.loads(line)
                    kind = record['type']
                    gold = record['answer']
                    subtask_matches = subtask == 'both' or kind == subtask
                    if subtask_matches and (gold != 'None' or nones):
                        context = record['context']
                        question = record['question']
                        answer = record['answer']
                        context_question = get_context_question(context, question)
                        values = [context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question]
                        examples.append(data.Example.fromlist(values, fields))

                    if subsample is not None and len(examples) >= subsample:
                        break
            os.makedirs(os.path.dirname(cache_name), exist_ok=True)
            print(f'Caching data to {cache_name}')
            torch.save(examples, cache_name)
        else:
            print(f'Loading cached data from {cache_name}')
            examples = torch.load(cache_name)

        super(OntoNotesNER, self).__init__(examples, fields, **kwargs)


    @classmethod
    def cache_splits(cls, path, path_to_files, train='train', validation='development', test='test'):

        label_to_answer = {'PERSON': 'person',
                           'NORP': 'political',
                           'FAC': 'facility',
                           'ORG': 'organization',
                           'GPE': 'geopolitical',
                           'LOC': 'location',
                           'PRODUCT': 'product',
                           'EVENT': 'event',
                           'WORK_OF_ART': 'artwork',
                           'LAW': 'legal',
                           'LANGUAGE': 'language',
                           'DATE': 'date',
                           'TIME': 'time',
                           'PERCENT': 'percentage',
                           'MONEY': 'monetary',
                           'QUANTITY': 'quantitative',
                           'ORDINAL': 'ordinal',
                           'CARDINAL': 'cardinal'}
        
        pluralize = {'person': 'persons', 'political': 'political', 'facility': 'facilities', 'organization': 'organizations', 
                     'geopolitical': 'geopolitical', 'location': 'locations', 'product': 'products', 'event': 'events',
                     'artwork': 'artworks', 'legal': 'legal', 'language': 'languages', 'date': 'dates', 'time': 'times', 
                     'percentage': 'percentages', 'monetary': 'monetary', 'quantitative': 'quantitative', 'ordinal': 'ordinal',
                     'cardinal': 'cardinal'}

 
        for split in [train, validation, test]:
            split_file_name = os.path.join(path, f'{split}.jsonl')
            if os.path.exists(split_file_name):
                continue
            id_file = os.path.join(path, f'{split}.id')

            num_file_ids = 0
            examples = []
            with open(split_file_name, 'w') as split_file:
                with open(os.path.expanduser(id_file)) as f:
                    for file_id in f:
                        example_file_name = os.path.join(os.path.expanduser(path_to_files), file_id.strip()) + '.name'
                        if not os.path.exists(example_file_name) or 'annotations/tc/ch' in example_file_name:
                            continue
                        num_file_ids += 1
                        with open(example_file_name) as example_file:
                            lines = [x.strip() for x in example_file.readlines() if 'DOC' not in x]
                            for line in lines:
                                original = line
                                line = cls.clean(line)
                                entities = []  
                                while True:
                                    start_enamex_open_idx = line.find('<ENAMEX')
                                    if start_enamex_open_idx == -1:
                                        break
                                    end_enamex_open_idx = line.find('">') + 2
                                    start_enamex_close_idx = line.find('</ENAMEX>')
                                    end_enamex_close_idx = start_enamex_close_idx + len('</ENAMEX>')
    
                                    enamex_open_tag = line[start_enamex_open_idx:end_enamex_open_idx]
                                    enamex_close_tag = line[start_enamex_close_idx:end_enamex_close_idx]
                                    before_entity = line[:start_enamex_open_idx]
                                    entity = line[end_enamex_open_idx:start_enamex_close_idx]
                                    after_entity = line[end_enamex_close_idx:]
    
                                    if 'S_OFF' in enamex_open_tag:
                                        s_off_start = enamex_open_tag.find('S_OFF="')
                                        s_off_end = enamex_open_tag.find('">') if 'E_OFF' not in enamex_open_tag else enamex_open_tag.find('" E_OFF')
                                        s_off = int(enamex_open_tag[s_off_start+len('S_OFF="'):s_off_end])
                                        enamex_open_tag = enamex_open_tag[:s_off_start-2] + '">'
                                        before_entity += entity[:s_off]
                                        entity = entity[s_off:]
    
                                    if 'E_OFF' in enamex_open_tag:
                                        s_off_start = enamex_open_tag.find('E_OFF="')
                                        s_off_end = enamex_open_tag.find('">')
                                        s_off = int(enamex_open_tag[s_off_start+len('E_OFF="'):s_off_end])
                                        enamex_open_tag = enamex_open_tag[:s_off_start-2] + '">'
                                        after_entity = entity[-s_off:] + after_entity
                                        entity = entity[:-s_off]
    
    
                                    label_start = enamex_open_tag.find('TYPE="') + len('TYPE="')
                                    label_end = enamex_open_tag.find('">')
                                    label = enamex_open_tag[label_start:label_end]
                                    assert label in label_to_answer

                                    offsets = (len(before_entity), len(before_entity) + len(entity))
                                    entities.append({'entity': entity, 'char_offsets': offsets, 'label': label})
                                    line = before_entity + entity + after_entity
                                
                                context = line.strip()
                                is_no_good = False
                                for entity_tuple in entities:
                                    entity = entity_tuple['entity']
                                    start, end = entity_tuple['char_offsets']
                                    if not context[start:end] == entity:
                                        is_no_good = True
                                        break
                                if is_no_good:
                                    print('Throwing out example that looks poorly labeled: ', original.strip(), ' (', file_id.strip(), ')')
                                    continue
                                question = 'What are the tags for all entities?'
                                answer = '; '.join([f'{x["entity"]} -- {label_to_answer[x["label"]]}' for x in entities]) 
                                if len(answer) == 0:
                                    answer = 'None'
                                split_file.write(json.dumps({'context': context, 'question': question, 'answer': answer, 'file_id': file_id.strip(), 
                                                             'original': original.strip(), 'entity_list': entities, 'type': 'all'})+'\n')
                                partial_question = 'Which entities are {}?'
 
                                for lab, ans in label_to_answer.items():
                                    question = partial_question.format(pluralize[ans])
                                    entity_of_type_lab = [x['entity'] for x in entities if x['label'] == lab]
                                    answer = ', '.join(entity_of_type_lab)
                                    if len(answer) == 0:
                                        answer = 'None'
                                    split_file.write(json.dumps({'context': context, 
                                                                 'question': question, 
                                                                 'answer': answer, 
                                                                 'file_id': file_id.strip(), 
                                                                 'original': original.strip(), 
                                                                 'entity_list': entities, 
                                                                 'type': 'one', 
                                                                 })+'\n')



    @classmethod
    def splits(cls, fields, root='.data',
               train='train', validation='development', test='test', **kwargs):
        """Build the train / validation / test datasets for this corpus.

        The raw corpus must already be present under
        ``<root>/ontonotes-release-5.0/data/files``; this is asserted before
        anything else happens. ``cls.download(root)`` supplies the directory
        holding the ``.jsonl`` split files, and ``cls.cache_splits`` is called
        with that directory and the raw-files directory before any dataset is
        constructed.

        Arguments:
            fields: passed unchanged to the dataset constructor.
            root: directory containing the corpus release.
            train: base name of the training ``.jsonl`` file, or None to skip it.
            validation: base name of the validation ``.jsonl`` file, or None
                to skip it.
            test: base name of the test ``.jsonl`` file, or None to skip it.
            **kwargs: forwarded to every dataset constructor.

        Returns:
            A tuple with one dataset per requested split, ordered train,
            validation, test; skipped splits are omitted.
        """
        path_to_files = os.path.join(root, 'ontonotes-release-5.0', 'data', 'files')
        assert os.path.exists(path_to_files)
        path = cls.download(root)
        cls.cache_splits(path, path_to_files)

        # Each entry pairs a split name with the extra constructor keywords
        # for that split: validation and test are always built with
        # one_answer=False, training receives only the caller's kwargs.
        split_specs = (
            (train, {}),
            (validation, {'one_answer': False}),
            (test, {'one_answer': False}),
        )

        datasets = []
        for split_name, extra in split_specs:
            if split_name is None:
                continue
            jsonl_path = os.path.join(path, f'{split_name}.jsonl')
            datasets.append(cls(jsonl_path, fields, **extra, **kwargs))
        return tuple(datasets)

class SNLI(CQA, data.Dataset):
    """SNLI recast as context / question / answer examples.

    ``cache_splits`` rewrites the raw release so that the context is
    ``Premise: "<sentence1>"``, the question is
    ``Hypothesis: "<sentence2>" -- entailment, neutral, or contradiction?``
    and the answer is the record's ``gold_label``.
    """

    @staticmethod
    def sort_key(ex):
        """Return ``data.interleave_keys`` of the context and answer lengths."""
        return data.interleave_keys(len(ex.context), len(ex.answer))

    urls = ['http://nlp.stanford.edu/projects/snli/snli_1.0.zip']
    dirname = 'snli_1.0'
    name = 'snli'


    def __init__(self, path, field, subsample=None, **kwargs):
        """Load examples from a JSON-lines file, reusing an on-disk cache if present.

        Arguments:
            path: JSON-lines file whose records carry ``context``,
                ``question`` and ``answer`` keys. A leading ``~`` is expanded.
            field: the field object paired with every name in ``self.fields``.
            subsample: if not None, reading stops as soon as this many
                examples have been collected.
            **kwargs: forwarded to the parent constructor.
        """
        fields = [(x, field) for x in self.fields]

        # Expand ``~`` once, up front, so the cache directory and the data
        # file are resolved against the same location. Building the cache
        # name from the unexpanded path would create a literal ``~``
        # directory under the current working directory.
        path = os.path.expanduser(path)
        cache_name = os.path.join(os.path.dirname(path), '.cache', os.path.basename(path), str(subsample))
        if os.path.exists(cache_name):
            print(f'Loading cached data from {cache_name}')
            # NOTE(review): torch.load unpickles the file; the cache
            # directory must only contain files this code wrote itself.
            examples = torch.load(cache_name)
        else:
            examples = []
            with open(path) as f:
                for line in f:
                    record = json.loads(line)
                    context, question, answer = record['context'], record['question'], record['answer']
                    context_question = get_context_question(context, question)
                    examples.append(data.Example.fromlist(
                        [context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields))

                    if subsample is not None and len(examples) >= subsample:
                        break
            os.makedirs(os.path.dirname(cache_name), exist_ok=True)
            print(f'Caching data to {cache_name}')
            torch.save(examples, cache_name)

        super().__init__(examples, fields, **kwargs)

    @classmethod
    def cache_splits(cls, path, train='train', validation='dev', test='test'):
        """Convert the raw ``snli_1.0_<split>.jsonl`` files into ``<split>.jsonl``.

        Does nothing when ``<train>.jsonl`` already exists in ``path``.

        Arguments:
            path: directory holding the raw files; output is written there too.
            train: name of the training split.
            validation: name of the validation split.
            test: name of the test split.

        Raises:
            FileNotFoundError: if a raw split file is missing. No output file
                is created for that split in this case.
        """
        train_jsonl = os.path.expanduser(os.path.join(path, f'{train}.jsonl'))
        if os.path.exists(train_jsonl):
            return

        base_file_name = 'snli_1.0_{}.jsonl'
        for split in [train, validation, test]:
            src_path = os.path.expanduser(os.path.join(path, base_file_name.format(split)))
            dst_path = os.path.expanduser(os.path.join(path, f'{split}.jsonl'))
            # Open the source first: if it is missing we must not leave an
            # empty output file behind, because an empty ``<train>.jsonl``
            # would make every later call return early above.
            # The output is opened with 'w', not 'a', so that a leftover
            # file from an earlier run is replaced instead of receiving a
            # second copy of every example.
            with open(src_path) as src_file, open(dst_path, 'w') as split_file:
                for line in src_file:
                    ex = json.loads(line)
                    ex = {'context': f'Premise: "{ex["sentence1"]}"',
                          'question': f'Hypothesis: "{ex["sentence2"]}" -- entailment, neutral, or contradiction?',
                          'answer': ex['gold_label']}
                    split_file.write(json.dumps(ex)+'\n')


    @classmethod
    def splits(cls, fields, root='.data', train='train', validation='dev', test='test', **kwargs):
        """Build the train / validation / test datasets.

        ``cls.download(root)`` supplies the data directory and
        ``cls.cache_splits`` is then called on it with its default split
        names, whatever ``train``/``validation``/``test`` are.

        Arguments:
            fields: passed unchanged to the dataset constructor.
            root: passed to ``cls.download``.
            train: base name of the training ``.jsonl`` file, or None to skip it.
            validation: base name of the validation ``.jsonl`` file, or None
                to skip it.
            test: base name of the test ``.jsonl`` file, or None to skip it.
            **kwargs: forwarded to every dataset constructor.

        Returns:
            A tuple with one dataset per requested split, ordered train,
            validation, test; skipped splits are omitted.
        """
        path = cls.download(root)
        cls.cache_splits(path)

        train_data = None if train is None else cls(
            os.path.join(path, f'{train}.jsonl'), fields, **kwargs)
        validation_data = None if validation is None else cls(
            os.path.join(path, f'{validation}.jsonl'), fields, **kwargs)
        test_data = None if test is None else cls(
            os.path.join(path, f'{test}.jsonl'), fields, **kwargs)
        return tuple(d for d in (train_data, validation_data, test_data)
                     if d is not None)


class JSON(CQA, data.Dataset):
    """Generic dataset read from JSON-lines files.

    Every line of a split file is a JSON object with ``context``,
    ``question`` and ``answer`` keys.
    """

    @staticmethod
    def sort_key(ex):
        """Return ``data.interleave_keys`` of the context and answer lengths."""
        return data.interleave_keys(len(ex.context), len(ex.answer))

    def __init__(self, path, field, subsample=None, **kwargs):
        """Load examples from a JSON-lines file, reusing an on-disk cache if present.

        Arguments:
            path: JSON-lines file to read. A leading ``~`` is expanded.
            field: the field object paired with every name in ``self.fields``.
            subsample: if not None, reading stops as soon as this many
                examples have been collected.
            **kwargs: forwarded to the parent constructor.
        """
        fields = [(x, field) for x in self.fields]

        # Expand ``~`` once, up front, so the cache directory and the data
        # file are resolved against the same location. Building the cache
        # name from the unexpanded path would create a literal ``~``
        # directory under the current working directory.
        path = os.path.expanduser(path)
        cache_name = os.path.join(os.path.dirname(path), '.cache', os.path.basename(path), str(subsample))

        examples = []
        if os.path.exists(cache_name):
            print(f'Loading cached data from {cache_name}')
            # NOTE(review): torch.load unpickles the file; the cache
            # directory must only contain files this code wrote itself.
            examples = torch.load(cache_name)
        else:
            with open(path) as f:
                # Iterate lazily: with a small ``subsample`` there is no
                # reason to read the whole file into memory first.
                for line in f:
                    record = json.loads(line)
                    context, question, answer = record['context'], record['question'], record['answer']
                    context_question = get_context_question(context, question)
                    examples.append(data.Example.fromlist(
                        [context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields))
                    if subsample is not None and len(examples) >= subsample:
                        break
            os.makedirs(os.path.dirname(cache_name), exist_ok=True)
            print(f'Caching data to {cache_name}')
            torch.save(examples, cache_name)

        super().__init__(examples, fields, **kwargs)

    @classmethod
    def splits(cls, fields, name, root='.data',
               train='train', validation='val', test='test', **kwargs):
        """Build the train / validation / test datasets from ``<root>/<name>``.

        Arguments:
            fields: passed unchanged to the dataset constructor.
            name: sub-directory of ``root`` that holds the split files.
            root: parent data directory.
            train: base name of the training ``.jsonl`` file, or None to skip it.
            validation: base name of the validation ``.jsonl`` file, or None
                to skip it.
            test: base name of the test ``.jsonl`` file, or None to skip it.
            **kwargs: forwarded to every dataset constructor.

        Returns:
            A tuple with one dataset per requested split, ordered train,
            validation, test; skipped splits are omitted.
        """
        path = os.path.join(root, name)

        # The file names are derived from the arguments so that non-default
        # split names are honoured; the defaults resolve to train.jsonl,
        # val.jsonl and test.jsonl.
        train_data = None if train is None else cls(
            os.path.join(path, f'{train}.jsonl'), fields, **kwargs)
        validation_data = None if validation is None else cls(
            os.path.join(path, f'{validation}.jsonl'), fields, **kwargs)
        test_data = None if test is None else cls(
            os.path.join(path, f'{test}.jsonl'), fields, **kwargs)
        return tuple(d for d in (train_data, validation_data, test_data)
                     if d is not None)