#!/usr/bin/env python
import os
import re
import sys
import glob
import random
import argparse
import datetime
import subprocess

# YAML setup
from ruamel.yaml import YAML
yaml = YAML()
yaml.preserve_quotes = True
yaml.boolean_representation = ['False', 'True']

import torch
import torch.nn as nn
from torch import cuda

import opts
import onmt
import onmt.io
import onmt.Models
import onmt.modules
import onmt.ModelConstructor
from onmt.Utils import use_gpu


parser = argparse.ArgumentParser(
    description='train.py',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)

# opts.py
opts.add_md_help_argument(parser)
opts.model_opts(parser)
opts.train_opts(parser)

opt = parser.parse_args()

if opt.word_vec_size != -1:
    opt.src_word_vec_size = opt.word_vec_size
    opt.tgt_word_vec_size = opt.word_vec_size

if opt.layers != -1:
    opt.enc_layers = opt.layers
    opt.dec_layers = opt.layers

opt.brnn = (opt.encoder_type == "brnn")
if opt.seed > 0:
    random.seed(opt.seed)
    torch.manual_seed(opt.seed)

if torch.cuda.is_available() and not opt.gpuid:
    print("WARNING: You have a CUDA device, should run with -gpuid 0")

if opt.gpuid:
    cuda.set_device(opt.gpuid[0])
    if opt.seed > 0:
        torch.cuda.manual_seed(opt.seed)


def report_func(trnr, epoch, batch, num_batches, start_time, lr, report_stats,
                enc_output_dict, dec_output_dict, mem_dict):
    """
    This is the user-defined batch-level training progress
    report function.

    Args:
        trnr(Trainer): the trainer, used here to run validation.
        epoch(int): current epoch count.
        batch(int): current batch count.
        num_batches(int): total number of batches.
        start_time(float): last report time.
        lr(float): current learning rate.
        report_stats(Statistics): old Statistics instance.
        enc_output_dict(dict): encoder output statistics (unused here).
        dec_output_dict(dict): decoder output statistics (unused here).
        mem_dict(dict): memory usage in bits, if available.
    Returns:
        report_stats(Statistics): updated Statistics instance.
    """
    global iteration_log_loss_file
    global mem_log_file

    # Report every opt.report_every batches: under Python's modulo,
    # batch % N == -1 % N is equivalent to (batch + 1) % N == 0, so this
    # fires on batches report_every - 1, 2 * report_every - 1, ...
    if batch % opt.report_every == -1 % opt.report_every:
        report_stats.output(epoch, batch + 1, num_batches, start_time)

        if mem_dict:
            if opt.separate_buffers:
                enc_used_bits = mem_dict['enc_used_bits']
                dec_used_bits = mem_dict['dec_used_bits']
                enc_optimal_bits = mem_dict['enc_optimal_bits']
                dec_optimal_bits = mem_dict['dec_optimal_bits']
                enc_normal_bits = mem_dict['enc_normal_bits']
                dec_normal_bits = mem_dict['dec_normal_bits']

                enc_actual_ratio = enc_normal_bits / enc_used_bits
                enc_optimal_ratio = enc_normal_bits / enc_optimal_bits
                dec_actual_ratio = dec_normal_bits / dec_used_bits
                dec_optimal_ratio = dec_normal_bits / dec_optimal_bits

                print("Enc actual memory ratio {}".format(enc_actual_ratio))
                print("Enc optimal memory ratio {}".format(enc_optimal_ratio))
                print("Dec actual memory ratio {}".format(dec_actual_ratio))
                print("Dec optimal memory ratio {}".format(dec_optimal_ratio))

                mem_log_file.write('{} {} {} {} {} {}\n'.format(
                                    epoch, batch, enc_actual_ratio, enc_optimal_ratio, dec_actual_ratio, dec_optimal_ratio))
                mem_log_file.flush()
            else:
                used_bits = mem_dict['used_bits']
                optimal_bits = mem_dict['optimal_bits']
                normal_bits = mem_dict['normal_bits']

                actual_ratio = normal_bits / used_bits
                optimal_ratio = normal_bits / optimal_bits
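                # Worked example with hypothetical numbers: if a conventional
                # model would store normal_bits = 32e6 bits of hidden state and
                # the reversible model actually used used_bits = 4e6, then
                # actual_ratio = 8.0, i.e. an 8x memory saving.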

                print("Actual memory ratio {}".format(actual_ratio))
                print("Optimal memory ratio {}".format(optimal_ratio))

                mem_log_file.write('{} {} {} {}\n'.format(
                                    epoch, batch, actual_ratio, optimal_ratio))
                mem_log_file.flush()


        if not opt.no_log_during_epoch:
            valid_stats = trnr.validate()

            iteration_log_loss_file.write('{} {} {} {} {} {}\n'.format(
                                           epoch, batch, report_stats.accuracy(), report_stats.ppl(), valid_stats.accuracy(),
                                           valid_stats.ppl()))
            iteration_log_loss_file.flush()

        report_stats = onmt.Statistics()

    return report_stats


def make_train_data_iter(train_dataset, opt):
    """
    This returns user-defined train data iterator for the trainer
    to iterate over during each train epoch. We implement simple
    ordered iterator strategy here, but more sophisticated strategy
    like curriculum learning is ok too.
    """
    # Sort batch by decreasing lengths of sentence required by pytorch.
    # sort=False means "Use dataset's sortkey instead of iterator's".
    return onmt.io.OrderedIterator(
                dataset=train_dataset, batch_size=opt.batch_size,
                device=opt.gpuid[0] if opt.gpuid else -1,
                sort=False, sort_within_batch=True, repeat=False)


def make_valid_data_iter(valid_dataset, opt):
    """
    This returns user-defined validate data iterator for the trainer
    to iterate over during each validate epoch. We implement simple
    ordered iterator strategy here, but more sophisticated strategy
    is ok too.
    """
    # Sort batch by decreasing lengths of sentence required by pytorch.
    # sort=False means "Use dataset's sortkey instead of iterator's".
    return onmt.io.OrderedIterator(
                dataset=valid_dataset, batch_size=opt.valid_batch_size,
                device=opt.gpuid[0] if opt.gpuid else -1,
                train=False, sort=False, sort_within_batch=True)


iteration_log_loss_file = None
mem_log_file = None


def train_model(model, train_dataset, valid_dataset, fields, optim, model_opt):

    global iteration_log_loss_file
    global mem_log_file

    train_iter = make_train_data_iter(train_dataset, opt)
    valid_iter = make_valid_data_iter(valid_dataset, opt)

    trunc_size = opt.truncated_decoder
    shard_size = opt.max_generator_batches
    data_type = train_dataset.data_type

    trainer = onmt.Trainer(model,
                           train_iter,
                           valid_iter,
                           fields["tgt"].vocab,
                           optim,
                           trunc_size,
                           shard_size,
                           data_type,
                           opt=model_opt)

    log_perp = open(os.path.join(opt.save_dir, 'log_perp'), 'w')
    iteration_log_loss_file = open(os.path.join(opt.save_dir, 'iteration_log_loss'), 'w')
    mem_log_file = open(os.path.join(opt.save_dir, 'mem_log'), 'w')

    best_val_ppl = 1e6
    best_val_checkpoint_path = None

    # Ctrl+C at any time exits training early and still prints the best checkpoint path.
    try:
        for epoch in range(opt.start_epoch, opt.epochs + 1):
            print('')

            # 1. Train for one epoch on the training set.
            train_stats = trainer.train(epoch, report_func)
            print('Train perplexity: %g' % train_stats.ppl())
            print('Train accuracy: %g' % train_stats.accuracy())

            # 2. Validate on the validation set.
            valid_stats = trainer.validate()
            print('Validation perplexity: %g' % valid_stats.ppl())
            print('Validation accuracy: %g' % valid_stats.accuracy())

            # 3. Logging
            # Log to the remote server (assumes an `experiment` logger was set
            # up elsewhere when opt.exp_host is given).
            if opt.exp_host:
                train_stats.log("train", experiment, optim.lr)
                valid_stats.log("valid", experiment, optim.lr)

            # Write train and val perplexities to a file
            log_perp.write("{} {} {}\n".format(epoch, train_stats.ppl(), valid_stats.ppl()))
            log_perp.flush()

            # 4. Update the learning rate
            trainer.epoch_step(valid_stats.ppl(), epoch)

            # 5. Drop a checkpoint if needed.
            if epoch >= opt.start_checkpoint_at:
                checkpoint_path = trainer.drop_checkpoint(model_opt, epoch, fields, valid_stats)

                if valid_stats.ppl() < best_val_ppl:
                    best_val_ppl = valid_stats.ppl()
                    best_val_checkpoint_path = checkpoint_path
    except KeyboardInterrupt:
        print("Exiting from training early!")

    print('\nBest val checkpoint: {}'.format(best_val_checkpoint_path))

    return best_val_checkpoint_path


def extract_bleu_score(output_string):
    # Take the last line of the translate.py output, e.g.
    # ">> BLEU = 9.79, 37.2/13.3/6.6/3.1 (BP=0.978, ratio=0.978, hyp_len=11976, ref_len=12242)"
    output_string = output_string.strip().split('\n')[-1]
    output_string = output_string.split(',')[0]    # Gets ">> BLEU = 9.79"
    m = re.search('>> BLEU = (.+)', output_string)
    return m.group(1)
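    # Illustrative result: for the sample line quoted above, this returns the
    # string "9.79".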


def evaluate(best_val_checkpoint_path):
    # python translate.py -src data/multi30k/test2016.en.atok -output pred.txt \
    #                     -replace_unk -tgt=data/multi30k/test2016.de.atok -report_bleu -gpu 2
    #                     -model saves/2018-02-09-enc:Rev-dec:Rev-et:RevGRU-dt:RevGRU-h:300-el:1-dl:1-em:300-atn:general-cxt:slice_emb-sl:20-ef1:0.875-ef2:0.875-df1:0.875-df2:0.875/best_checkpoint.pt

    base_dir = os.path.dirname(best_val_checkpoint_path)

    # Checkpoint paths containing '600' (e.g. hidden size 600 in the save-dir
    # name) come from the IWSLT'16 En-De setup; everything else is evaluated
    # on Multi30k.
    if '600' in best_val_checkpoint_path:
        test_output = subprocess.run(['python', 'translate.py', '-src', 'data/en-de/IWSLT16.TED.tst2014.en-de.en.tok.low',
                                      '-output', os.path.join(base_dir, 'test_pred.txt'), '-replace_unk', '-tgt', 'data/en-de/IWSLT16.TED.tst2014.en-de.de.tok.low',
                                      '-report_bleu', '-gpu', str(opt.gpuid[0]), '-model', best_val_checkpoint_path], stdout=subprocess.PIPE)

        test_output_string = test_output.stdout.decode('utf-8')
        print(test_output_string)

        # Also save the whole stdout string for reference
        with open(os.path.join(base_dir, 'test_stdout.txt'), 'w') as f:
            f.write('{}\n'.format(test_output_string))

        val_output = subprocess.run(['python', 'translate.py', '-src', 'data/en-de/IWSLT16.TED.tst2013.en-de.en.tok.low',
                                     '-output', os.path.join(base_dir, 'val_pred.txt'), '-replace_unk', '-tgt', 'data/en-de/IWSLT16.TED.tst2013.en-de.de.tok.low',
                                     '-report_bleu', '-gpu', str(opt.gpuid[0]), '-model', best_val_checkpoint_path], stdout=subprocess.PIPE)

        val_output_string = val_output.stdout.decode('utf-8')
        print(val_output_string)
    else:
        test_output = subprocess.run(['python', 'translate.py', '-src', 'data/multi30k/test2016.en.tok.low',
                                      '-output', os.path.join(base_dir, 'test_pred.txt'), '-replace_unk', '-tgt', 'data/multi30k/test2016.de.tok.low',
                                      '-report_bleu', '-gpu', str(opt.gpuid[0]), '-model', best_val_checkpoint_path], stdout=subprocess.PIPE)

        test_output_string = test_output.stdout.decode('utf-8')
        print(test_output_string)

        # Also save the whole stdout string for reference
        with open(os.path.join(base_dir, 'test_stdout.txt'), 'w') as f:
            f.write('{}\n'.format(test_output_string))

        val_output = subprocess.run(['python', 'translate.py', '-src', 'data/multi30k/val.en.tok.low',
                                     '-output', os.path.join(base_dir, 'val_pred.txt'), '-replace_unk', '-tgt', 'data/multi30k/val.de.tok.low',
                                     '-report_bleu', '-gpu', str(opt.gpuid[0]), '-model', best_val_checkpoint_path], stdout=subprocess.PIPE)

        val_output_string = val_output.stdout.decode('utf-8')
        print(val_output_string)

    # Also save the whole stdout string for reference
    with open(os.path.join(base_dir, 'val_stdout.txt'), 'w') as f:
        f.write('{}\n'.format(val_output_string))

    val_bleu = extract_bleu_score(val_output_string)
    test_bleu = extract_bleu_score(test_output_string)

    with open(os.path.join(base_dir, 'result.txt'), 'w') as f:
        f.write('{} {}\n'.format(val_bleu, test_bleu))

    print('Val BLEU: {} | Test BLEU: {}'.format(val_bleu, test_bleu))


def check_save_model_path():
    if not os.path.exists(opt.save_dir):
        os.makedirs(opt.save_dir)


def tally_parameters(model):
    n_params = sum([p.nelement() for p in model.parameters()])
    print('* number of parameters: %d' % n_params)
    enc = 0
    dec = 0
    for name, param in model.named_parameters():
        if 'encoder' in name:
            enc += param.nelement()
        elif 'decoder' in name or 'generator' in name:
            dec += param.nelement()
    print('encoder: ', enc)
    print('decoder: ', dec)


def load_dataset(data_type):
    assert data_type in ["train", "valid"]

    print("Loading %s data from '%s'" % (data_type, opt.data))

    pts = glob.glob(opt.data + '.' + data_type + '.[0-9]*.pt')
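    # e.g. with -data data/demo this matches shard files such as
    # data/demo.train.0.pt, data/demo.train.1.pt, ... (illustrative names).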
    if pts:
        # Multiple onmt.io.*Dataset's, coalesce all.
        # torch.load loads them immediately, which might eat up
        # too much memory. A lazy load would be better, but later,
        # when we create the data iterator, it still requires these
        # data to be loaded. So it seems we don't have a good way
        # to avoid this for now.
        datasets = []
        for pt in pts:
            datasets.append(torch.load(pt))
        dataset = onmt.io.ONMTDatasetBase.coalesce_datasets(datasets)
    else:
        # Only one onmt.io.*Dataset, simple!
        dataset = torch.load(opt.data + '.' + data_type + '.pt')

    print(' * number of %s sentences: %d' % (data_type, len(dataset)))

    return dataset


def load_fields(train_dataset, valid_dataset, checkpoint):
    data_type = train_dataset.data_type

    fields = onmt.io.load_fields_from_vocab(torch.load(opt.data + '.vocab.pt'), data_type)
    fields = dict([(k, f) for (k, f) in fields.items() if k in train_dataset.examples[0].__dict__])

    # We save fields in vocab.pt, so assign them back to dataset here.
    train_dataset.fields = fields
    valid_dataset.fields = fields

    if opt.train_from:
        print('Loading vocab from checkpoint at %s.' % opt.train_from)
        fields = onmt.io.load_fields_from_vocab(checkpoint['vocab'], data_type)

    if data_type == 'text':
        print(' * vocabulary size. source = %d; target = %d' %
              (len(fields['src'].vocab), len(fields['tgt'].vocab)))
    else:
        print(' * vocabulary size. target = %d' %
              (len(fields['tgt'].vocab)))

    return fields


def collect_report_features(fields):
    src_features = onmt.io.collect_features(fields, side='src')
    tgt_features = onmt.io.collect_features(fields, side='tgt')

    for j, feat in enumerate(src_features):
        print(' * src feature %d size = %d' % (j, len(fields[feat].vocab)))
    for j, feat in enumerate(tgt_features):
        print(' * tgt feature %d size = %d' % (j, len(fields[feat].vocab)))


def build_model(model_opt, opt, fields, checkpoint):
    print('Building model...')
    model = onmt.ModelConstructor.make_base_model(model_opt, fields, use_gpu(opt), checkpoint)
    if len(opt.gpuid) > 1:
        print('Multi gpu training: ', opt.gpuid)
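        # Split along dim=1, the batch dimension (ONMT batches are laid out as
        # (seq_len, batch, ...)).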
        model = nn.DataParallel(model, device_ids=opt.gpuid, dim=1)
    print(model)
    return model


def build_optim(model, checkpoint):
    if opt.train_from:
        print('Loading optimizer from checkpoint.')
        optim = checkpoint['optim']
        optim.optimizer.load_state_dict(
            checkpoint['optim'].optimizer.state_dict())
    else:
        optim = onmt.Optim(opt.optim,  # SGD by default
                           opt.learning_rate,
                           opt.max_grad_norm,
                           lr_decay=opt.learning_rate_decay,
                           start_decay_at=opt.start_decay_at,
                           beta1=opt.adam_beta1,
                           beta2=opt.adam_beta2,
                           adagrad_accum=opt.adagrad_accumulator_init,
                           decay_method=opt.decay_method,
                           warmup_steps=opt.warmup_steps,
                           model_size=opt.rnn_size)

    optim.set_parameters(model.parameters())

    return optim


def main():

    #######################################################################
    ### Create save folder
    ### Example: saves/2018-01-21-enc:CUDNN-dec:CUDNN-et:LSTM-dt:LSTM-...
    #######################################################################
    timestamp = '{:%Y-%m-%d}'.format(datetime.datetime.now())

    if opt.encoder_model == 'Rev' or opt.decoder_model == 'Rev':
        exp_name = '{}-enc:{}-dec:{}-et:{}-dt:{}-h:{}-el:{}-dl:{}-em:{}-atn:{}-cxt:{}-sl:{}-ef:{}-df:{}-di:{}-dh:{}-do:{}-ds:{}-lr:{}-init:{}-userev:{}'.format(
                    timestamp, opt.encoder_model, opt.decoder_model, opt.encoder_rnn_type, opt.decoder_rnn_type,
                    opt.rnn_size, opt.enc_layers, opt.dec_layers, opt.word_vec_size,
                    opt.global_attention, opt.context_type, opt.slice_dim,
                    opt.enc_max_forget, opt.dec_max_forget, opt.dropouti, opt.dropouth, opt.dropouto, opt.dropouts, opt.learning_rate, opt.param_init, int(opt.use_reverse))
    else:
        exp_name = '{}-enc:{}-dec:{}-et:{}-dt:{}-h:{}-el:{}-dl:{}-em:{}-atn:{}-cxt:{}-sl:{}-di:{}-dh:{}-do:{}-ds:{}-lr:{}-init:{}-userev:{}'.format(
                    timestamp, opt.encoder_model, opt.decoder_model, opt.encoder_rnn_type, opt.decoder_rnn_type,
                    opt.rnn_size, opt.enc_layers, opt.dec_layers, opt.word_vec_size,
                    opt.global_attention, opt.context_type, opt.slice_dim, opt.dropouti, opt.dropouth, opt.dropouto, opt.dropouts, opt.learning_rate, opt.param_init, int(opt.use_reverse))

    opt.save_dir = os.path.join(opt.save_dir, exp_name)

    if os.path.exists(os.path.join(opt.save_dir, 'result.txt')):
        print('The result file {} exists! Terminating to not overwrite it!'.format(os.path.join(opt.save_dir, 'result.txt')))
        sys.exit(0)

    # Create save dir if it doesn't exist
    if not os.path.exists(opt.save_dir):
        os.makedirs(opt.save_dir)

    # Save command-line arguments
    with open(os.path.join(opt.save_dir, 'args.yaml'), 'w') as f:
        yaml.dump(vars(opt), f)

    # Load train and validate data.
    train_dataset = load_dataset("train")
    valid_dataset = load_dataset("valid")
    print(' * maximum batch size: %d' % opt.batch_size)

    # Load checkpoint if resuming from a previous training run.
    if opt.train_from:
        print('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from, map_location=lambda storage, loc: storage)
        model_opt = checkpoint['opt']
        opt.start_epoch = checkpoint['epoch'] + 1
    else:
        checkpoint = None
        model_opt = opt

    # Load fields generated from preprocess phase.
    fields = load_fields(train_dataset, valid_dataset, checkpoint)

    # Report src/tgt features.
    collect_report_features(fields)

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    tally_parameters(model)

    # Build optimizer.
    optim = build_optim(model, checkpoint)

    # Do training.
    best_val_checkpoint_path = train_model(model, train_dataset, valid_dataset, fields, optim, model_opt)

    # Evaluate the final model on the validation and test sets
    evaluate(best_val_checkpoint_path)


if __name__ == "__main__":
    main()