import numpy as np
import os
import random
import torch
from torch.autograd import Variable
from nltk.translate.bleu_score import SmoothingFunction
import nltk


def calc_bleu(reference, hypothesis):
    """Return sentence-level BLEU-4 of ``hypothesis`` against ``reference``.

    Uses uniform 1- to 4-gram weights and NLTK's ``method1`` smoothing.
    ``reference`` is forwarded unchanged as NLTK's list of references.
    """
    uniform_weights = (0.25,) * 4
    smoother = SmoothingFunction().method1
    bleu_module = nltk.translate.bleu_score
    return bleu_module.sentence_bleu(reference, hypothesis, uniform_weights,
                                     smoothing_function=smoother)


def load_human_answer(data_path):
    """Read the human reference id sequences from ``reference.0`` and ``reference.1``.

    Each line must contain at least two tab-separated fields. The second field
    is a whitespace-separated list of integer token ids. The parsed lists from
    both files are returned in file order (``reference.0`` first).

    ``data_path`` is concatenated directly with the file name, so it should end
    with a path separator.
    """
    answers = []
    for suffix in ('0', '1'):
        with open(data_path + 'reference.' + suffix) as handle:
            for raw_line in handle:
                id_field = raw_line.strip().split('\t')[1]
                answers.append([int(tok) for tok in id_field.split()])
    return answers


def subsequent_mask(size):
    """Return a ``(1, size, size)`` boolean causal mask.

    Entry ``[0, i, j]`` is True when ``j <= i`` and False for strictly later
    positions ``j > i``.
    """
    strictly_upper = np.triu(np.ones((1, size, size)), k=1)
    return torch.from_numpy(strictly_upper.astype('uint8')) == 0


def id2text_sentence(sen_id, id_to_word, eos_id=3, unk_id=1):
    """Convert a sequence of token ids into a space-joined sentence.

    Decoding stops at the first ``eos_id``; the EOS token itself is not emitted.
    Ids outside the vocabulary (``>= len(id_to_word)`` or negative) are
    rendered as ``id_to_word[unk_id]``.

    Args:
        sen_id: iterable of integer token ids.
        id_to_word: list used as an id -> word lookup.
        eos_id: id that terminates decoding (default 3, the previously
            hard-coded value).
        unk_id: id substituted for out-of-range ids (default 1, the previously
            hard-coded value).

    Returns:
        The decoded words joined by single spaces.
    """
    vocab_size = len(id_to_word)
    words = []
    for token in sen_id:
        if token == eos_id:
            break
        if token < 0 or token >= vocab_size:
            # Negative ids would otherwise silently wrap around list indexing.
            token = unk_id
        words.append(id_to_word[token])
    return ' '.join(words)


def to_var(x, volatile=False):
    """Move ``x`` to the GPU when CUDA is available and return it detached.

    This replaces the legacy ``Variable(x, volatile=...)`` wrapper. Since
    PyTorch 0.4, ``Variable(x)`` returns a tensor detached from the autograd
    graph. The ``volatile`` flag has no effect beyond a deprecation warning,
    so it is no longer forwarded.

    Args:
        x: a ``torch.Tensor``.
        volatile: kept for backward compatibility and ignored. Wrap inference
            code in ``torch.no_grad()`` to get the old "volatile" behaviour.

    Returns:
        ``x``, moved to CUDA if available and detached (``requires_grad=False``).
    """
    if torch.cuda.is_available():
        x = x.cuda()
    return x.detach()


def get_cuda(tensor):
    """Return ``tensor`` on the GPU when CUDA is available, otherwise unchanged.

    Calling ``.cuda()`` unconditionally raises on CPU-only machines. This
    applies the same availability check that ``to_var`` uses.
    """
    if torch.cuda.is_available():
        return tensor.cuda()
    return tensor


def load_word_dict_info(word_dict_file, max_num):
    """Read the vocabulary file, stopping at the first word counted below ``max_num``.

    Each line holds a word, optionally followed by a tab and an integer count.
    Lines without a count are always kept. The first line whose count is
    ``< max_num`` ends reading, and that word is excluded. This presumably
    relies on the file being sorted by descending count (confirm against the
    data).

    Returns:
        (id_to_word, vocab_size): the kept words in file order, usable as an
        id -> word lookup, and their number.
    """
    vocabulary = []
    with open(word_dict_file, 'r') as handle:
        all_lines = handle.readlines()
    for raw_line in all_lines:
        fields = raw_line.strip().split('\t')
        if len(fields) > 1 and int(fields[1]) < max_num:
            break
        vocabulary.append(fields[0])
    print("Load word-dict with %d size and %d max_num." % (len(vocabulary), max_num))
    return vocabulary, len(vocabulary)


def load_data1(file1):
    """Parse a file with one whitespace-separated integer id sequence per line.

    Returns a list of id lists in file order. Blank lines yield empty lists.
    """
    with open(file1, 'r') as handle:
        return [list(map(int, row.split())) for row in handle]


def prepare_data(data_path, max_num, task_type):
    """Load the vocabulary and list the sentiment training files with their labels.

    Args:
        data_path: directory prefix, concatenated directly with file names, so
            it should end with a path separator.
        max_num: minimum word count forwarded to ``load_word_dict_info``.
        task_type: currently unused.

    Returns:
        (id_to_word, vocab_size, train_file_list, train_label_list).
        ``train_label_list[k]`` is a one-element list holding the label of
        ``train_file_list[k]``: 0 for the ``.0`` files and 1 for the ``.1``
        files. Files are ordered train.0, train.1, dev.0, dev.1.
    """
    print("prepare data ...")
    id_to_word, vocab_size = load_word_dict_info(data_path + 'word_to_id.txt', max_num)

    train_file_list = []
    train_label_list = []
    for split in ('train', 'dev'):
        for label in (0, 1):
            train_file_list.append(data_path + 'sentiment.' + split + '.' + str(label))
            train_label_list.append([label])

    return id_to_word, vocab_size, train_file_list, train_label_list


def pad_batch_seuqences(origin_seq, sos_id, eos_id, unk_id, max_seq_length, vocab_size):
    """Pad a batch of id sequences into encoder/decoder arrays.

    Padding uses 0, and any id ``>= vocab_size`` is replaced by ``unk_id``.
    With ``width = min(max_seq_length, longest + 1)``:

    * encoder input (``width - 1`` columns): the sequence, truncated to fit;
    * decoder input (``width`` columns): ``sos_id`` followed by the sequence;
    * decoder target (``width`` columns): the sequence followed by ``eos_id``.
      When the sequence is truncated, the EOS lands in the last column.

    Returns:
        (encoder_input, decoder_input, decoder_target, encoder_lengths,
        decoder_lengths) as integer numpy arrays.
    """
    batch_size = len(origin_seq)
    longest = max((len(seq) for seq in origin_seq), default=0)
    width = min(max_seq_length, longest + 1)
    enc_width = width - 1

    encoder_input_seq = np.zeros((batch_size, enc_width), dtype=int)
    decoder_input_seq = np.zeros((batch_size, width), dtype=int)
    decoder_target_seq = np.zeros((batch_size, width), dtype=int)
    encoder_input_seq_length = np.zeros(batch_size, dtype=int)
    decoder_input_seq_length = np.zeros(batch_size, dtype=int)

    for row, seq in enumerate(origin_seq):
        kept = min(enc_width, len(seq))
        ids = [unk_id if tok >= vocab_size else tok for tok in seq[:kept]]

        decoder_input_seq[row, 0] = sos_id
        encoder_input_seq[row, :kept] = ids
        decoder_input_seq[row, 1:kept + 1] = ids
        decoder_target_seq[row, :kept] = ids

        dec_len = min(width, len(seq) + 1)
        encoder_input_seq_length[row] = kept
        decoder_input_seq_length[row] = dec_len
        # dec_len - 1 == kept, so the EOS never overwrites a copied token.
        decoder_target_seq[row, dec_len - 1] = eos_id

    return encoder_input_seq, decoder_input_seq, decoder_target_seq, encoder_input_seq_length, decoder_input_seq_length


class non_pair_data_loader():
    """Batch loader for label-tagged (non-parallel) integer-id sequences.

    ``create_batches`` reads id files, pairs every line with its file's label,
    optionally shuffles, and precomputes padded tensors and masks per batch.
    ``next_batch`` then cycles through those cached batches.
    """

    def __init__(self, batch_size, id_bos, id_eos, id_unk, max_sequence_length, vocab_size):
        self._clear_batches()

        self.num_batch = 0
        self.batch_size = batch_size
        self.pointer = 0
        self.id_bos = id_bos
        self.id_eos = id_eos
        self.id_unk = id_unk
        self.max_sequence_length = max_sequence_length
        self.vocab_size = vocab_size

    def _clear_batches(self):
        """Reset every per-batch cache list."""
        self.sentences_batches = []
        self.labels_batches = []

        self.src_batches = []
        self.src_mask_batches = []
        self.tgt_batches = []
        self.tgt_y_batches = []
        self.tgt_mask_batches = []
        self.ntokens_batches = []

    def create_batches(self, train_file_list, train_label_list, if_shuffle=True):
        """Read the files, split them into fixed-size batches and cache their tensors.

        Args:
            train_file_list: paths of files with one whitespace-separated id
                sequence per line.
            train_label_list: ``train_label_list[k]`` is attached to every line
                of ``train_file_list[k]``.
            if_shuffle: shuffle all (sentence, label) pairs before batching.

        Batches cached by a previous call are discarded first, so repeated calls
        never serve stale data. Pairs that do not fill a complete batch are
        dropped. If ``batch_size`` is None, all pairs form a single batch.
        """
        # Without this reset, a second call would append after the old batches,
        # and next_batch (indices < num_batch) would keep returning the old ones.
        self._clear_batches()

        self.data_label_pairs = []
        for file_index, file_name in enumerate(train_file_list):
            label = train_label_list[file_index]
            with open(file_name) as fin:
                for line in fin:
                    parse_line = [int(x) for x in line.split()]
                    self.data_label_pairs.append([parse_line, label])

        if if_shuffle:
            random.shuffle(self.data_label_pairs)

        if self.batch_size is None:
            self.batch_size = len(self.data_label_pairs)
        self.num_batch = len(self.data_label_pairs) // self.batch_size
        for batch_index in range(self.num_batch):
            start = batch_index * self.batch_size
            item_data_label_pairs = self.data_label_pairs[start:start + self.batch_size]
            item_sentences = [pair[0] for pair in item_data_label_pairs]
            item_labels = [pair[1] for pair in item_data_label_pairs]

            (batch_encoder_input, batch_decoder_input, batch_decoder_target,
             _batch_encoder_length, _batch_decoder_length) = pad_batch_seuqences(
                item_sentences, self.id_bos, self.id_eos, self.id_unk,
                self.max_sequence_length, self.vocab_size)

            src = get_cuda(torch.tensor(batch_encoder_input, dtype=torch.long))
            tgt = get_cuda(torch.tensor(batch_decoder_input, dtype=torch.long))
            tgt_y = get_cuda(torch.tensor(batch_decoder_target, dtype=torch.long))

            # Id 0 is the padding value written by pad_batch_seuqences.
            src_mask = (src != 0).unsqueeze(-2)
            tgt_mask = self.make_std_mask(tgt, 0)
            ntokens = (tgt_y != 0).sum().float()

            self.sentences_batches.append(item_sentences)
            self.labels_batches.append(get_cuda(torch.tensor(item_labels, dtype=torch.float)))
            self.src_batches.append(src)
            self.tgt_batches.append(tgt)
            self.tgt_y_batches.append(tgt_y)
            self.src_mask_batches.append(src_mask)
            self.tgt_mask_batches.append(tgt_mask)
            self.ntokens_batches.append(ntokens)

        self.pointer = 0
        print("Load data from %s !\nCreate %d batches with %d batch_size" % (
            ' '.join(train_file_list), self.num_batch, self.batch_size
        ))

    def make_std_mask(self, tgt, pad):
        "Create a mask to hide padding and future words."
        # For a (batch, seq) tgt as built in create_batches, the (batch, 1, seq)
        # padding mask broadcasts against the (1, seq, seq) causal mask.
        pad_mask = (tgt != pad).unsqueeze(-2)
        # type_as also moves the CPU-built causal mask to tgt's device.
        return pad_mask & subsequent_mask(tgt.size(-1)).type_as(pad_mask)

    def next_batch(self):
        """Return the batch at ``self.pointer`` and advance the pointer cyclically.

        Returns:
            (sentences, labels, src, src_mask, tgt, tgt_y, tgt_mask, ntokens).

        Raises:
            IndexError: if no batches have been created.
        """
        if self.num_batch == 0:
            raise IndexError("no batches available; call create_batches first")
        index = self.pointer
        this_batch = (
            self.sentences_batches[index],
            self.labels_batches[index],
            self.src_batches[index],
            self.src_mask_batches[index],
            self.tgt_batches[index],
            self.tgt_y_batches[index],
            self.tgt_mask_batches[index],
            self.ntokens_batches[index],
        )
        self.pointer = (self.pointer + 1) % self.num_batch
        return this_batch

    def reset_pointer(self):
        """Restart iteration from the first batch."""
        self.pointer = 0


if __name__ == '__main__':
    # Interactive smoke test: build random copy-task batches, print their
    # tensors and masks, and wait for Enter after each batch.

    class Batch:
        "Object for holding a batch of data with mask during training."

        def __init__(self, src, trg=None, pad=0):
            self.src = src
            self.src_mask = (src != pad).unsqueeze(-2)
            if trg is not None:
                # Decoder input drops the last token; the target is trg shifted
                # left by one position.
                self.trg = trg[:, :-1]
                self.trg_y = trg[:, 1:]
                self.trg_mask = self.make_std_mask(self.trg, pad)
                self.ntokens = (self.trg_y != pad).sum()

        @staticmethod
        def make_std_mask(tgt, pad):
            "Create a mask to hide padding and future words."
            tgt_mask = (tgt != pad).unsqueeze(-2)
            return tgt_mask & subsequent_mask(tgt.size(-1)).type_as(tgt_mask)

    def data_gen(V, batch, nbatches):
        "Generate random data for a src-tgt copy task."
        for _ in range(nbatches):
            data = torch.from_numpy(np.random.randint(1, V, size=(batch, 10)))
            data[:, 0] = 1
            # Copy task: source and target are the same ids.
            yield Batch(data, data, 0)

    for i in range(100):
        print("%d ----- " % i)
        data_iter = data_gen(10, 3, 2)
        for j, batch in enumerate(data_iter):
            # Previously print("%d:", j) printed the literal "%d:" followed by j.
            print("%d:" % j)
            print(batch.src)
            print(batch.src_mask)
            print(batch.trg)
            print(batch.trg_y)
            print(batch.trg_mask)
            input("=====")