import argparse
import os
import re

import matplotlib.pyplot as plt
import nltk
import numpy
import torch
import torch.nn as nn
from torch.autograd import Variable

import data
import data_ptb
from utils import batchify, get_batch, repackage_hidden, evalb

from parse_comparison import corpus_stats_labeled, corpus_average_depth
from data_ptb import word_tags
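
# Evaluate the parse trees induced from a trained model's syntactic distances
# against Penn Treebank gold trees: unlabeled bracketing precision/recall/F1,
# labeled constituent statistics, average tree depth, and EVALB scoring.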


criterion = nn.CrossEntropyLoss()
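
# Average per-position cross-entropy of the language-modelling head over
# data_source (uses the global model, corpus and args).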
def evaluate(data_source, batch_size=1):
    # Turn on evaluation mode which disables dropout.
    model.eval()
    total_loss = 0
    ntokens = len(corpus.dictionary)
    hidden = model.init_hidden(batch_size)
    for i in range(0, data_source.size(0) - 1, args.bptt):
        data, targets = get_batch(data_source, i, args, evaluation=True)
        output, hidden = model(data, hidden)
        output = model.decoder(output)
        output_flat = output.view(-1, ntokens)
        total_loss += len(data) * criterion(output_flat, targets).data
        hidden = repackage_hidden(hidden)
    return total_loss / len(data_source)

def corpus2idx(sentence):
    # Map a whitespace-tokenized sentence to a (seq_len, 1) LongTensor of word indices.
    arr = numpy.array([corpus.dictionary.word2idx[c] for c in sentence.split()], dtype=numpy.int32)
    return torch.from_numpy(arr[:, None]).long()


# Test model
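# Greedily convert a vector of syntactic distances ("depth") into a tree over
# the tokens in sen: split at the position with the largest distance, attach
# that token, and recurse on the left and right parts.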
def build_tree(depth, sen):
    assert len(depth) == len(sen)

    if len(depth) == 1:
        parse_tree = sen[0]
    else:
        idx_max = numpy.argmax(depth)
        parse_tree = []
        if len(sen[:idx_max]) > 0:
            tree0 = build_tree(depth[:idx_max], sen[:idx_max])
            parse_tree.append(tree0)
        tree1 = sen[idx_max]
        if len(sen[idx_max + 1:]) > 0:
            tree2 = build_tree(depth[idx_max + 1:], sen[idx_max + 1:])
            tree1 = [tree1, tree2]
        if parse_tree == []:
            parse_tree = tree1
        else:
            parse_tree.append(tree1)
    return parse_tree


# def build_tree(depth, sen):
#     assert len(depth) == len(sen)
#     assert len(depth) >= 0
#
#     if len(depth) == 1:
#         parse_tree = sen[0]
#     else:
#         idx_max = numpy.argmax(depth[1:]) + 1
#         parse_tree = []
#         if len(sen[:idx_max]) > 0:
#             tree0 = build_tree(depth[:idx_max], sen[:idx_max])
#             parse_tree.append(tree0)
#         if len(sen[idx_max:]) > 0:
#             tree1 = build_tree(depth[idx_max:], sen[idx_max:])
#             parse_tree.append(tree1)
#     return parse_tree


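# Collect the set of constituent spans (start, end) covering more than one leaf
# from a tree given as nested lists (model output) or an nltk.Tree (gold tree);
# also returns the index one past the last leaf seen.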
def get_brackets(tree, idx=0):
    brackets = set()
    if isinstance(tree, (list, nltk.Tree)):
        for node in tree:
            node_brac, next_idx = get_brackets(node, idx)
            if next_idx - idx > 1:
                brackets.add((idx, next_idx))
                brackets.update(node_brac)
            idx = next_idx
        return brackets, idx
    else:
        return brackets, idx + 1

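# Serialize a nested-list tree as an unlabeled bracketing string,
# e.g. [['a', 'b'], 'c'] -> "( ( a b ) c ) ".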
def MRG(tr):
    if isinstance(tr, str):
        #return '(' + tr + ')'
        return tr + ' '
    else:
        s = '( '
        for subtr in tr:
            s += MRG(subtr)
        s += ') '
        return s

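# Serialize an nltk.Tree as a labeled bracketing string, stripping function tags
# ('-'/'=' suffixes) from labels and replacing POS-level (word_tags) subtrees
# with their terminal word.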
def MRG_labeled(tr):
    if isinstance(tr, nltk.Tree):
        if tr.label() in word_tags:
            return tr.leaves()[0] + ' '
        else:
            s = '(%s ' % (re.split(r'[-=]', tr.label())[0])
            for subtr in tr:
                s += MRG_labeled(subtr)
            s += ') '
            return s
    else:
        return ''

def mean(x):
    return sum(x) / len(x)


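# Parse each sentence of the evaluation set (or the WSJ10 subset when --wsj10 is
# given) with trees induced from the model's syntactic distances (model.distance,
# filled in by the forward pass), score them against the gold trees, and return
# the mean unlabeled bracketing F1.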
def test(model, corpus, cuda, prt=False):
    model.eval()

    prec_list = []
    reca_list = []
    f1_list = []

    pred_tree_list = []
    targ_tree_list = []

    nsens = 0
    word2idx = corpus.dictionary.word2idx
    if args.wsj10:
        dataset = zip(corpus.train_sens, corpus.train_trees, corpus.train_nltktrees)
    else:
        dataset = zip(corpus.test_sens, corpus.test_trees, corpus.test_nltktrees)

    corpus_sys = {}
    corpus_ref = {}
    for sen, sen_tree, sen_nltktree in dataset:
        if args.wsj10 and len(sen) > 12:
            continue
        x = numpy.array([word2idx[w] if w in word2idx else word2idx['<unk>'] for w in sen])
        input = Variable(torch.LongTensor(x[:, None]))
        if cuda:
            input = input.cuda()

        hidden = model.init_hidden(1)
        _, hidden = model(input, hidden)

        distance = model.distance[0].squeeze().data.cpu().numpy()
        distance_in = model.distance[1].squeeze().data.cpu().numpy()

        nsens += 1
        if prt and nsens % 100 == 0:
            for i in range(len(sen)):
                print('%15s\t%s\t%s' % (sen[i], str(distance[:, i]), str(distance_in[:, i])))
            print('Standard output:', sen_tree)

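        # Exclude the first and last tokens (sentence-boundary markers) when
        # building the tree; gates[1:-1] below is aligned accordingly.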
        sen_cut = sen[1:-1]
        # gates = distance.mean(axis=0)
        for gates in [
            # distance[0],
            distance[1],
            # distance[2],
            # distance.mean(axis=0)
        ]:
            depth = gates[1:-1]
            parse_tree = build_tree(depth, sen_cut)

            corpus_sys[nsens] = MRG(parse_tree)
            corpus_ref[nsens] = MRG_labeled(sen_nltktree)

            pred_tree_list.append(parse_tree)
            targ_tree_list.append(sen_tree)

            model_out, _ = get_brackets(parse_tree)
            std_out, _ = get_brackets(sen_tree)
            overlap = model_out.intersection(std_out)

            prec = float(len(overlap)) / (len(model_out) + 1e-8)
            reca = float(len(overlap)) / (len(std_out) + 1e-8)
            if len(std_out) == 0:
                reca = 1.
                if len(model_out) == 0:
                    prec = 1.
            f1 = 2 * prec * reca / (prec + reca + 1e-8)
            prec_list.append(prec)
            reca_list.append(reca)
            f1_list.append(f1)

            if prt and nsens % 100 == 0:
                print('Model output:', parse_tree)
                print('Prec: %f, Reca: %f, F1: %f' % (prec, reca, f1))

        if prt and nsens % 100 == 0:
            print('-' * 80)

            f, axarr = plt.subplots(3, sharex=True, figsize=(distance.shape[1] // 2, 6))
            axarr[0].bar(numpy.arange(distance.shape[1])-0.2, distance[0], width=0.4)
            axarr[0].bar(numpy.arange(distance_in.shape[1])+0.2, distance_in[0], width=0.4)
            axarr[0].set_ylim([0., 1.])
            axarr[0].set_ylabel('1st layer')
            axarr[1].bar(numpy.arange(distance.shape[1]) - 0.2, distance[1], width=0.4)
            axarr[1].bar(numpy.arange(distance_in.shape[1]) + 0.2, distance_in[1], width=0.4)
            axarr[1].set_ylim([0., 1.])
            axarr[1].set_ylabel('2nd layer')
            axarr[2].bar(numpy.arange(distance.shape[1]) - 0.2, distance[2], width=0.4)
            axarr[2].bar(numpy.arange(distance_in.shape[1]) + 0.2, distance_in[2], width=0.4)
            axarr[2].set_ylim([0., 1.])
            axarr[2].set_ylabel('3rd layer')
            plt.sca(axarr[2])
            plt.xlim(xmin=-0.5, xmax=distance.shape[1] - 0.5)
            plt.xticks(numpy.arange(distance.shape[1]), sen, fontsize=10, rotation=45)

            # Create the output directory if needed before saving the per-sentence figure.
            os.makedirs('figure', exist_ok=True)
            plt.savefig('figure/%d.png' % (nsens))
            plt.close()

    prec_list = numpy.array(prec_list).reshape((-1, 1))
    reca_list = numpy.array(reca_list).reshape((-1, 1))
    f1_list = numpy.array(f1_list).reshape((-1, 1))
    if prt:
        print('-' * 80)
        numpy.set_printoptions(precision=4)
        print('Mean Prec:', prec_list.mean(axis=0),
              ', Mean Reca:', reca_list.mean(axis=0),
              ', Mean F1:', f1_list.mean(axis=0))
        print('Number of sentences: %i' % nsens)

        correct, total = corpus_stats_labeled(corpus_sys, corpus_ref)
        print(correct)
        print(total)
        print('ADJP:', correct['ADJP'], total['ADJP'])
        print('NP:', correct['NP'], total['NP'])
        print('PP:', correct['PP'], total['PP'])
        print('INTJ:', correct['INTJ'], total['INTJ'])
        print(corpus_average_depth(corpus_sys))

        evalb(pred_tree_list, targ_tree_list)

    return f1_list.mean(axis=0)


if __name__ == '__main__':
    marks = [' ', '-', '=']

    numpy.set_printoptions(precision=2, suppress=True, linewidth=5000)

    parser = argparse.ArgumentParser(description='PyTorch PTB Language Model')

    # Model parameters.
    parser.add_argument('--data', type=str, default='data/ptb',
                        help='location of the data corpus')
    parser.add_argument('--checkpoint', type=str, default='PTB.pt',
                        help='model checkpoint to use')
    parser.add_argument('--seed', type=int, default=1111,
                        help='random seed')
    parser.add_argument('--cuda', action='store_true',
                        help='use CUDA')
    parser.add_argument('--wsj10', action='store_true',
                        help='use WSJ10')
    args = parser.parse_args()
    args.bptt = 70

    # Set the random seed manually for reproducibility.
    torch.manual_seed(args.seed)

    # Load model
    with open(args.checkpoint, 'rb') as f:
        model, _, _ = torch.load(f)
        model.cpu()
        if args.cuda:
            torch.cuda.manual_seed(args.seed)
            model.cuda()

    # Load the word-level dictionary from the corpus cache produced at training time;
    # the cache file name is keyed by the MD5 hash of the training data path.
    import hashlib

    fn = 'corpus.{}.data'.format(hashlib.md5('data/penn'.encode()).hexdigest())
    print('Loading cached dataset...')
    corpus = torch.load(fn)
    dictionary = corpus.dictionary

    # test_batch_size = 1
    # test_data = batchify(corpus.test, test_batch_size, args)
    # test_loss = evaluate(test_data, test_batch_size)
    # print('=' * 89)
    # print('| End of training | test loss {:5.2f} | test ppl {:8.2f} | test bpc {:8.3f}'.format(
    #     test_loss, math.exp(test_loss), test_loss / math.log(2)))
    # print('=' * 89)

    print('Loading PTB dataset...')
    corpus = data_ptb.Corpus(args.data)
    corpus.dictionary = dictionary

    test(model, corpus, args.cuda, prt=True)