import argparse

import torch
from allennlp.common.params import Params
from allennlp.data.dataset_readers import DatasetReader
from transformers import BertTokenizer
from razdel import sentenize

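# Star import runs the register() decorators on the project's DatasetReader
# subclasses so that DatasetReader.from_params can resolve them by name.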
from summarus.readers import *


class BertData:
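    """Converts (document sentences, summary sentences) into PreSumm-style BERT
    inputs: token ids, alternating segment ids, and per-sentence [CLS] positions."""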
    def __init__(self, bert_model, lower, max_src_tokens, max_tgt_tokens):
        self.max_src_tokens = max_src_tokens
        self.max_tgt_tokens = max_tgt_tokens
        self.tokenizer = BertTokenizer.from_pretrained(bert_model, do_lower_case=lower)
        self.sep_token = '[SEP]'
        self.cls_token = '[CLS]'
        self.pad_token = '[PAD]'
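        # Reserved [unusedN] vocab entries act as target-side BOS, EOS and sentence separator.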
        self.tgt_bos = '[unused1] '
        self.tgt_eos = ' [unused2]'
        self.tgt_sent_split = ' [unused3] '
        self.sep_vid = self.tokenizer.vocab[self.sep_token]
        self.cls_vid = self.tokenizer.vocab[self.cls_token]
        self.pad_vid = self.tokenizer.vocab[self.pad_token]

    def preprocess(self, src, tgt):
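        # Interleave "[SEP] [CLS]" between sentences so that, after the leading [CLS]
        # and trailing [SEP] are added, every sentence starts with its own [CLS].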
        src_txt = [' '.join(s) for s in src]
        text = ' {} {} '.format(self.sep_token, self.cls_token).join(src_txt)
        src_tokens = self.tokenizer.tokenize(text)[:self.max_src_tokens]
        src_tokens.insert(0, self.cls_token)
        src_tokens.append(self.sep_token)
        src_indices = self.tokenizer.convert_tokens_to_ids(src_tokens)

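        # Segment ids alternate 0/1 per [SEP]-delimited sentence;
        # cls_ids records the position of each sentence's [CLS] token.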
        _segs = [-1] + [i for i, t in enumerate(src_indices) if t == self.sep_vid]
        segs = [_segs[i] - _segs[i - 1] for i in range(1, len(_segs))]
        segments_ids = []
        for i, s in enumerate(segs):
            if i % 2 == 0:
                segments_ids += s * [0]
            else:
                segments_ids += s * [1]
        cls_ids = [i for i, t in enumerate(src_indices) if t == self.cls_vid]

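        # Target side: sentences are joined with <q> for the plain-text field, and with
        # the [unused3] separator (wrapped in [unused1] ... [unused2]) for the token ids.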
        tgt_txt = ' <q> '.join([' '.join(sentence) for sentence in tgt])
        tgt_tokens = [' '.join(self.tokenizer.tokenize(' '.join(sentence))) for sentence in tgt]
        tgt_tokens_str = self.tgt_bos + self.tgt_sent_split.join(tgt_tokens) + self.tgt_eos
        tgt_tokens = tgt_tokens_str.split()[:self.max_tgt_tokens]
        tgt_indices = self.tokenizer.convert_tokens_to_ids(tgt_tokens)

        return src_indices, tgt_indices, segments_ids, cls_ids, src_txt, tgt_txt


def preprocess(config_path, file_path, save_path, bert_path, max_src_tokens, max_tgt_tokens, lower=False, nrows=None):
    bert = BertData(bert_path, lower, max_src_tokens, max_tgt_tokens)
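    # Build the dataset reader described in the config; parse_set yields (text, summary) pairs.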
    params = Params.from_file(config_path)
    reader_params = params.pop("reader", default=Params({}))
    reader = DatasetReader.from_params(reader_params)
    data = []
    for i, (text, summary) in enumerate(reader.parse_set(file_path)):
        if nrows is not None and i >= nrows:
            break
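        # Sentence-split source and target with razdel, then whitespace-tokenize each sentence.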
        src = [(s.text.lower() if lower else s.text).split() for s in sentenize(text)]
        tgt = [(s.text.lower() if lower else s.text).split() for s in sentenize(summary)]
        src_indices, tgt_indices, segments_ids, cls_ids, src_txt, tgt_txt = bert.preprocess(src, tgt)
        b_data_dict = {
            "src": src_indices, "tgt": tgt_indices,
            "segs": segments_ids, 'clss': cls_ids,
            'src_txt': src_txt, "tgt_txt": tgt_txt
        }
        data.append(b_data_dict)
    torch.save(data, save_path)


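# Example invocation (script name and paths are placeholders):
#   python preprocess_bert_data.py --config-path <config.json> --file-path <dataset-file> \
#       --save-path <output.pt> --bert-path <bert-model-dir-or-name> --lower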
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--config-path', type=str, required=True)
    parser.add_argument('--file-path', type=str, required=True)
    parser.add_argument('--save-path', type=str, required=True)
    parser.add_argument('--bert-path', type=str, required=True)
    parser.add_argument('--lower', action='store_true')
    parser.add_argument('--max-src-tokens', type=int, default=600)
    parser.add_argument('--max-tgt-tokens', type=int, default=200)
    parser.add_argument('--nrows', type=int, default=None)
    args = parser.parse_args()
    preprocess(**vars(args))