import os
import pdb
import sys
import json
import time
import random
from pprint import pprint, pformat

sys.path.append('..')

import config

from anikattu.logger import CMDFilter
import logging

logging.basicConfig(format=config.FORMAT_STRING)
log = logging.getLogger(__name__)
log.setLevel(logging.INFO)

from torch import nn, optim
from torch.nn import functional as F
from torch.autograd import Variable
import torch

from anikattu.trainer import Trainer, Feeder, Predictor
from anikattu.datafeed import DataFeed, MultiplexedDataFeed
from anikattu.utilz import tqdm, ListTable

from functools import partial

from collections import namedtuple, defaultdict
import itertools

from utilz import Sample
from utilz import PAD, word_tokenize
from utilz import VOCAB, LABELS
from utilz import rotate

from anikattu.utilz import initialize_task
from anikattu.utilz import pad_seq
from anikattu.utilz import logger
from anikattu.vocab import Vocab
from anikattu.tokenstring import TokenString
from anikattu.utilz import LongVar, Var, init_hidden
import numpy as np

import re
import glob

SELF_NAME = os.path.basename(__file__).replace('.py', '')

def build_sample(raw_sample):
    # The original left this as a bare 'pass'; a minimal sketch, assuming
    # Sample's fields are (id, sentence, label) and that the sentence
    # arrives as raw text needing tokenization.
    return Sample(raw_sample.id,
                  word_tokenize(raw_sample.sentence.lower()),
                  raw_sample.label)

def prep_samples(dataset):
    ret = []
    vocabulary = defaultdict(int)
    labels = defaultdict(int)

    for i, sample in tqdm(enumerate(dataset)):
        try:
            sample = build_sample(sample)
            if sample.label not in LABELS:
                continue
            for token in sample.sentence:
                vocabulary[token] += 1
            labels[sample.label] += 1
            ret.append(sample)
        except KeyboardInterrupt:
            # return whatever has been processed so far instead of None,
            # so the caller's three-way unpacking does not blow up
            return ret, vocabulary, labels
        except Exception:
            log.exception('at id: {}'.format(i))

    return ret, vocabulary, labels


# ## Loss and accuracy function
def loss(output, batch, loss_function, *args, **kwargs):
    indices, (sentence, ), (label, ) = batch
    output, attn = output
    return loss_function(output, label)

def accuracy(output, batch, *args, **kwargs):
    indices, (sentence, ), (label, ) = batch
    output, attn = output
    return (output.max(dim=1)[1] == label).sum().float()/label.size(0)

def waccuracy(output, batch, *args, **kwargs):
    # Class-weighted accuracy: per-class accuracy averaged over classes,
    # so frequent classes cannot drown out rare ones.
    indices, (sentence, ), (label, ) = batch
    output, attn = output
    index = label

    acc_numer = Var(torch.zeros(output.size(1)))
    # the denominator starts at ones so classes absent from this batch
    # do not cause a division by zero
    acc_denom = Var(torch.ones(output.size(1)))

    acc_denom.scatter_add_(0, index, (label == label).float())
    acc_numer.scatter_add_(0, index, (label == output.max(1)[1]).float())

    accuracy = acc_numer / acc_denom

    return accuracy.mean()
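
# Worked example (hypothetical batch): with labels [0, 0, 1] and argmax
# predictions [0, 1, 1], the numerators become [1, 1] and the
# ones-initialised denominators [3, 2], so waccuracy returns
# mean(1/3, 1/2) ~= 0.42 rather than the unweighted 2/3.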

def f1score(output, input_, *args, **kwargs):
    # Precision/recall/F1 under the assumption of binary {0, 1} labels:
    # with i, t in {0, 1}, (i * t) counts true positives, (i > t) false
    # positives and (i < t) false negatives.
    indices, (seq, ), (target, ) = input_
    output, attn = output
    batch_size, class_size = output.size()

    tp, tn, fp, fn = Var([0]), Var([0]), Var([0]), Var([0])
    p, r, f1 = Var([0]), Var([0]), Var([0])

    i = output.max(dim=1)[1]
    t = target
    log.debug('output:{}'.format(pformat(i)))
    log.debug('target:{}'.format(pformat(t)))

    tp_ = (i * t).sum().float()
    fp_ = (i > t).sum().float()
    fn_ = (i < t).sum().float()
    tn_ = ((i == 0) * (t == 0)).sum().float()

    tp += tp_
    tn += tn_
    fp += fp_
    fn += fn_

    log.debug('tp_: {}\n fp_:{} \n fn_: {}\n tn_: {}'.format(tp_, fp_, fn_, tn_))

    if tp_.data.item() > 0:
        p_ = tp_ / (tp_ + fp_)
        r_ = tp_ / (tp_ + fn_)
        f1 += 2 * p_ * r_ / (p_ + r_)
        p += p_
        r += r_

    return (tp, fn, fp, tn), p, r, f1

def repr_function(output, batch, VOCAB, LABELS):
    indices, (sentence, ), (label, ) = batch

    results = []
    output, attn = output
    output = output.max(1)[1]
    output = output.cpu().numpy()
    for idx, c, a, o in zip(indices, sentence, label, output):
        c = ' '.join([VOCAB[i] for i in c])
        a = LABELS[a]
        o = LABELS[o]
        results.append([c, a, o, str(a == o)])

    return results


def test_repr_function(output, batch, VOCAB, LABELS):
    indices, (sentence,), (label,) = batch
    
    results = []
    score, attn = output
    attn = attn.transpose(0, 1).squeeze(2)
    score, output = score.max(1)
    score = score.exp()
    for idx, c, a, o, s, at in zip(indices, sentence, label, output, score, attn):
        results.append([idx,
                        ' '.join([VOCAB[i] for i in c]),
                        LABELS[o],
                        '{:0.4f}'.format(s),
                        ','.join(['{:0.4f}'.format(i) for i in at.tolist()]),
                        repr([VOCAB[i] for i in c])
        ])
        
    return results

def batchop(datapoints, VOCAB, LABELS, *args, **kwargs):
    indices = [d.id for d in datapoints]
    sentence = []
    label = []

    for d in datapoints:
        # every sentence is EOS-terminated before padding
        sentence.append([VOCAB[w] for w in d.sentence] + [VOCAB['EOS']])
        label.append(LABELS[d.label])

    sentence = LongVar(pad_seq(sentence))
    label = LongVar(label)

    batch = indices, (sentence, ), (label, )
    return batch
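
# Batch layout produced by batchop, as consumed by the functions above:
#   indices  : list of sample ids, length batch_size
#   sentence : LongTensor (batch_size, max_seq_len), EOS-terminated, padded
#   label    : LongTensor (batch_size,)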

class Base(nn.Module):
    def __init__(self, config, name):
        super(Base, self).__init__()
        self._name = name
        self.log = logging.getLogger(self._name)
        size_log_name = '{}.{}'.format(self._name, 'size')
        self.log.info('constructing logger: {}'.format(size_log_name))
        self.size_log = logging.getLogger(size_log_name)
        self.size_log.info('size_log')
        self.log.setLevel(logging.INFO)
        self.size_log.setLevel(logging.INFO)
        self.print_instance = 0
        
    def __(self, tensor, name='', print_instance=False):
        # Size-tracing helper: logs the shape of a tensor (or of every
        # element of a list/tuple of tensors) and passes it through unchanged.
        if isinstance(tensor, (list, tuple)):
            for i in range(len(tensor)):
                self.__(tensor[i], '{}[{}]'.format(name, i), print_instance)
        else:
            self.size_log.debug('{} -> {}'.format(name, tensor.size()))
            if self.print_instance or print_instance:
                self.size_log.debug(tensor)

        return tensor

    def name(self, n):
        return '{}.{}'.format(self._name, n)


class BiLSTMModel(Base):
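    # The original file left this class empty ('pass'); the body below is a
    # minimal sketch inferred from how the rest of this script uses the
    # model, not a reconstruction of the original network:
    #   - forward() takes a (indices, (sentence,), ...) batch as built by
    #     batchop and returns (log_probs, attn)
    #   - log_probs is (batch, num_classes) log-softmax output, matching
    #     nn.NLLLoss in experiment() and output.exp() elsewhere
    #   - attn is (seq_len, batch, 1), matching
    #     attn.transpose(0, 1).squeeze(2) in test_repr_function
    # config.HPCONFIG.embed_dim and config.HPCONFIG.hidden_dim are assumed
    # config fields.
    def __init__(self, config, name, vocab_size, num_classes):
        super(BiLSTMModel, self).__init__(config, name)
        self.embed = nn.Embedding(vocab_size, config.HPCONFIG.embed_dim)
        self.encode = nn.LSTM(config.HPCONFIG.embed_dim,
                              config.HPCONFIG.hidden_dim,
                              bidirectional=True,
                              batch_first=True)
        self.attend = nn.Linear(2 * config.HPCONFIG.hidden_dim, 1)
        self.classify = nn.Linear(2 * config.HPCONFIG.hidden_dim, num_classes)

    def forward(self, input_):
        indices, (sentence, ), _ = input_
        emb = self.__(self.embed(sentence), 'emb')      # (batch, seq, embed)
        states, _ = self.encode(emb)                    # (batch, seq, 2*hidden)
        attn = F.softmax(self.attend(states), dim=1)    # (batch, seq, 1)
        pooled = (attn * states).sum(dim=1)             # (batch, 2*hidden)
        log_probs = F.log_softmax(self.classify(pooled), dim=1)
        return log_probs, attn.transpose(0, 1)          # attn: (seq, batch, 1)

    def clone(self):
        # Predictor below calls model.clone(); a deep copy with identical
        # architecture and weights is assumed to be the intent.
        import copy
        return copy.deepcopy(self)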

    
def experiment(config, ROOT_DIR, model, VOCAB, LABELS, datapoints=[[], [], []], eons=1000, epochs=20, checkpoint=1):
    try:
        name = SELF_NAME
        _batchop = partial(batchop, VOCAB=VOCAB, LABELS=LABELS)
        train_feed     = DataFeed(name, datapoints[0], batchop=_batchop, batch_size=config.HPCONFIG.batch_size)
        test_feed      = DataFeed(name, datapoints[1], batchop=_batchop, batch_size=config.HPCONFIG.batch_size)
        predictor_feed = DataFeed(name, datapoints[2], batchop=_batchop, batch_size=1)

        max_freq = max(LABELS.freq_dict[i] for i in LABELS.index2word)
        # inverse-frequency class weights: the most frequent label gets
        # weight 1.0, rarer labels proportionally more
        loss_weight = [max_freq / LABELS.freq_dict[i] for i in LABELS.index2word]
        print(list((l, w) for l, w in zip(LABELS.index2word, loss_weight)))
        loss_weight = Var(loss_weight)
        
        loss_ = partial(loss, loss_function=nn.NLLLoss(loss_weight))
        trainer = Trainer(name=name,
                          model=model,
                          optimizer = optim.SGD(model.parameters(),
                                                lr=config.HPCONFIG.OPTIM.lr,
                                                momentum=config.HPCONFIG.OPTIM.momentum),
                          loss_function=loss_, accuracy_function=waccuracy, f1score_function=f1score,
                          checkpoint=checkpoint, epochs=epochs,
                          directory = ROOT_DIR,
                          feeder = Feeder(train_feed, test_feed))

        predictor = Predictor(model=model.clone(), feed=predictor_feed,
                              repr_function=partial(test_repr_function, VOCAB=VOCAB, LABELS=LABELS))
        
        for e in range(eons):
            if not trainer.train():
                raise Exception('trainer.train() failed on eon {}'.format(e))

            predictor.model.load_state_dict(trainer.best_model[1])

            log.info('on {}th eon'.format(e))
            results = ListTable()
            for ri in tqdm(range(predictor_feed.num_batch)):
                output, _results = predictor.predict(ri)
                results.extend(_results)

            with open('{}/results/eon_{}.csv'.format(ROOT_DIR, e), 'w') as dump:
                dump.write(repr(results))

    except KeyboardInterrupt:
        return locals()
    except Exception:
        log.exception('####################')
        return locals()
    
import pickle
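
# Command-line modes (inferred from the argv checks below): 'flush'
# rebuilds the dataset cache, 'train' runs the experiment loop,
# 'predict' starts an interactive prompt, and 'service' exposes the
# model as a Flask endpoint; 'show_plot'/'save_plot' optionally render
# attention plots during prediction.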
if __name__ == '__main__':

    if len(sys.argv) > 1 and sys.argv[1]:
        log.addFilter(CMDFilter(sys.argv[1]))

    ROOT_DIR = initialize_task(SELF_NAME)

    print('====================================')
    print(ROOT_DIR)
    print('====================================')
    
    if config.CONFIG.flush or 'flush' in sys.argv:
        log.info('flushing...')
        dataset = []
        # each line is assumed to be '|'-separated with three fields,
        # matching Sample's (id, sentence, label); strip() drops the
        # trailing newline that would otherwise pollute the last field
        with open('../dataset/dataset.csv') as f:
            for line in tqdm(f.readlines()):
                line = line.strip().split('|')
                dataset.append(
                    Sample(line[0], line[1], line[2])
                )
        dataset, vocabulary, labels = prep_samples(dataset)
        pivot = int(config.CONFIG.split_ratio * len(dataset))
        trainset, testset = dataset[:pivot], dataset[pivot:]
        with open('{}__cache.pkl'.format(SELF_NAME), 'wb') as f:
            pickle.dump([trainset, testset, dict(vocabulary), dict(labels)], f)
    else:
        with open('{}__cache.pkl'.format(SELF_NAME), 'rb') as f:
            trainset, testset, _vocabulary, _labels = pickle.load(f)
        vocabulary = defaultdict(int)
        labels = defaultdict(int)
        vocabulary.update(_vocabulary)
        labels.update(_labels)
        
    log.info('trainset size: {}'.format(len(trainset)))
    log.info('trainset[0]: {}'.format(pformat(trainset[0])))

    log.info(pformat(labels))
    """
    log.info('vocabulary: {}'.format(
        pformat(
            sorted(
                vocabulary.items(), key=lambda x: x[1], reverse=True)
        )))
    """
    VOCAB  = Vocab(vocabulary, tokens=VOCAB)
    LABELS = Vocab(labels, tokens=LABELS)
    pprint(LABELS.index2word)


    model = BiLSTMModel(config, 'macnet', len(VOCAB), len(LABELS))
    if config.CONFIG.cuda:
        model = model.cuda()

    # loading saved weights is best-effort: fall back to the freshly
    # initialised model when no checkpoint exists yet
    try:
        model.load_state_dict(torch.load('{}/weights/{}.{}'.format(ROOT_DIR, SELF_NAME, 'pth')))
        log.info('loaded the old image for the model')
    except Exception:
        log.exception('failed to load the model')

    model.eval()
    print('**** the model', model, model.training)
    
    if 'train' in sys.argv:
        model.train()
        train_set = sorted(trainset, key=lambda x: -len(x.sentence))
        test_set  = sorted(testset, key=lambda x: -len(x.sentence))
        exp_image = experiment(config, ROOT_DIR,  model, VOCAB, LABELS, datapoints=[train_set, train_set + test_set, train_set + test_set])
        
    if 'predict' in sys.argv:
        print('=========== PREDICTION ==============')
        model.eval()
        count = 0
        while True:
            count += 1
            sentence = []
            input_string = word_tokenize(input('?').lower())
            sentence.append([VOCAB[w] for w in input_string] + [VOCAB['EOS']])
            dummy_label = LongVar([0])
            sentence = LongVar(sentence)
            input_ = [0], (sentence, ), (dummy_label, )
            output, attn = model(input_)

            print(LABELS[output.max(1)[1]])

            if 'show_plot' in sys.argv or 'save_plot' in sys.argv:
                nwords = len(input_string)

                # matplotlib is imported lazily so it is only required
                # when plotting is actually requested
                from matplotlib import pyplot as plt
                plt.figure(figsize=(20, 10))
                # nwords + 1 bars: one per input token plus the trailing EOS
                plt.bar(range(nwords + 1), attn.squeeze().data.cpu().numpy())
                plt.title('{}\n{}'.format(output.exp().tolist(), LABELS[output.max(1)[1]]))
                plt.xticks(range(nwords), input_string, rotation='vertical')
                if 'show_plot' in sys.argv:
                    plt.show()
                if 'save_plot' in sys.argv:
                    plt.savefig('{}.png'.format(count))
                plt.close()

            print('Done')
                
    if 'service' in sys.argv:
        model.eval()
        from flask import Flask, request, jsonify
        from flask_cors import CORS
        app = Flask(__name__)
        CORS(app)

        @app.route('/ade-genentech', methods=['POST'])
        def _predict():
            print(' requests incoming..')
            sentence = []
            try:
                input_string = word_tokenize(request.json["text"].lower())
                sentence.append([VOCAB[w] for w in input_string] + [VOCAB['EOS']])
                dummy_label = LongVar([0])
                sentence = LongVar(sentence)
                input_ = [0], (sentence, ), (dummy_label, )
                output, attn = model(input_)
                return jsonify({
                    "result": {
                        'sentence': input_string,
                        # drop the attention weight on the trailing EOS token
                        'attn': ['{:0.4f}'.format(i) for i in attn.squeeze().data.cpu().numpy().tolist()[:-1]],
                        'probs': ['{:0.4f}'.format(i) for i in output.exp().squeeze().data.cpu().numpy().tolist()],
                        'label': LABELS[output.max(1)[1].squeeze().data.cpu().numpy()]
                    }
                })

            except Exception:
                log.exception('prediction failed')
                return jsonify({"result": "model failed"})

        print('model running on port:5010')
        app.run(host='0.0.0.0', port=5010)