# Copyright (c) Facebook, Inc. and its affiliates.
# 
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import time
import torch
import random
import argparse
from torch import nn
from utils import get_logger
from utils import AverageMeter
from utils import EarlyStopping
from utils import get_lr_scheduler
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader
from listops.models import ReinforceModel
from listops.data_preprocessing import ListOpsDataset
from listops.data_preprocessing import ListOpsBucketSampler


def make_path_preparations(args):
    """Seed the random generators and create the output locations for one run.

    A SHA-256 digest of ``str(args)`` identifies the run. It supplies the
    random seed (digest modulo 1,000,000) and the suffix used for the log
    file, the checkpoint directory and the tensorboard directory.

    Side effects: seeds ``random``, ``torch`` (CPU and CUDA) and sets
    ``ListOpsBucketSampler.random_seed``; creates any missing directories;
    rewrites ``args.model_dir`` to the run-specific checkpoint directory.

    Args:
        args: namespace providing ``logs_path``, ``model_dir`` and
            ``tensorboard_path``.

    Returns:
        ``(logger, summary_writer)`` where ``summary_writer`` maps "train"
        and "valid" to their ``SummaryWriter``.
    """
    import hashlib  # local import: only this function needs it

    # The built-in hash() of a str is salted per interpreter process unless
    # PYTHONHASHSEED is pinned, so it can give neither a reproducible seed
    # nor a stable run identifier. A content digest gives both.
    digest = hashlib.sha256(str(args).encode("utf-8")).hexdigest()
    seed = int(digest, 16) % 1000_000
    args_hash = digest[:16]

    ListOpsBucketSampler.random_seed = seed
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    random.seed(seed)

    # logger path
    # exist_ok=True replaces the racy "check, then create" sequence.
    os.makedirs(args.logs_path, exist_ok=True)
    log_file = f"{args.logs_path}/l{args_hash}.log"
    logger = get_logger(log_file)
    print(log_file)
    logger.info(f"args: {str(args)}")
    logger.info(f"args hash: {args_hash}")
    logger.info(f"random seed: {seed}")

    # model path (args are logged above, before model_dir is rewritten)
    args.model_dir = f"{args.model_dir}/m{args_hash}"
    os.makedirs(args.model_dir, exist_ok=True)
    logger.info(f"checkpoint's dir is: {args.model_dir}")

    # tensorboard path
    tensorboard_path = f"{args.tensorboard_path}/t{args_hash}"
    os.makedirs(tensorboard_path, exist_ok=True)
    summary_writer = {
        split: SummaryWriter(log_dir=os.path.join(tensorboard_path, 'log' + args_hash, split))
        for split in ("train", "valid")
    }

    return logger, summary_writer


def get_data(args):
    """Build the ListOps train, validation and test data loaders.

    The datasets are read from fixed paths under ``data/listops``. The train
    and validation datasets are created with ``max_len`` 130 and 300; the
    test dataset is created without a ``max_len``. The test sampler uses a
    batch size of ``args.batch_size // 4 + 1``.

    Side effects: stores ``vocab_size`` and ``label_size`` of the training
    dataset on ``args``.

    Args:
        args: namespace providing ``batch_size``.

    Returns:
        ``(train_data, valid_data, test_data)`` tuple of ``DataLoader``.
    """
    vocab_path = "data/listops/processed/vocab.txt"

    # Datasets first, then samplers, then loaders (same order as before).
    train_set = ListOpsDataset("data/listops/interim/train.tsv", vocab_path, max_len=130)
    valid_set = ListOpsDataset("data/listops/interim/valid.tsv", vocab_path, max_len=300)
    test_set = ListOpsDataset("data/listops/interim/test.tsv", vocab_path)

    train_sampler = ListOpsBucketSampler(dataset=train_set, batch_size=args.batch_size,
                                         shuffle=True, drop_last=True)
    valid_sampler = ListOpsBucketSampler(dataset=valid_set, batch_size=args.batch_size,
                                         shuffle=False, drop_last=False)
    test_sampler = ListOpsBucketSampler(dataset=test_set, batch_size=args.batch_size//4 + 1,
                                        shuffle=False, drop_last=False)

    # Options shared by all three loaders.
    loader_options = dict(num_workers=6, pin_memory=True, collate_fn=ListOpsDataset.collate_fn)
    pairs = ((train_set, train_sampler), (valid_set, valid_sampler), (test_set, test_sampler))
    loaders = tuple(DataLoader(dataset=split_set, batch_sampler=split_sampler, **loader_options)
                    for split_set, split_sampler in pairs)

    args.vocab_size = train_set.vocab_size
    args.label_size = train_set.label_size
    return loaders


def prepare_optimisers(args, logger, policy_parameters, environment_parameters):
    """Create the optimizers, LR schedulers and early-stopping tracker.

    ``args.optimizer`` selects Adam ("adam") or Adadelta ("adadelta"); any
    other value falls back to SGD. The policy and environment parameter
    groups get separate optimizers with their own learning rates and a
    shared weight decay.

    Args:
        args: namespace providing ``optimizer``, ``pol_lr``, ``env_lr``,
            ``l2_weight``, ``lr_scheduler_patience``, ``es_patience`` and
            ``es_threshold``.
        logger: logger handed to ``get_lr_scheduler``.
        policy_parameters: parameters optimized by the "policy" optimizer.
        environment_parameters: parameters optimized by the "env" optimizer.

    Returns:
        ``(optimizer, lr_scheduler, es)``: two dicts keyed by "policy" and
        "env", and an ``EarlyStopping`` instance in "max" mode.
    """
    known_optimizers = {"adam": torch.optim.Adam, "adadelta": torch.optim.Adadelta}
    build_optimizer = known_optimizers.get(args.optimizer, torch.optim.SGD)

    policy_optimizer = build_optimizer(params=policy_parameters, lr=args.pol_lr,
                                       weight_decay=args.l2_weight)
    env_optimizer = build_optimizer(params=environment_parameters, lr=args.env_lr,
                                    weight_decay=args.l2_weight)
    optimizer = {"policy": policy_optimizer, "env": env_optimizer}

    lr_scheduler = {
        group: get_lr_scheduler(logger, optimizer[group], patience=args.lr_scheduler_patience)
        for group in ("policy", "env")
    }
    early_stopping = EarlyStopping(mode="max", patience=args.es_patience, threshold=args.es_threshold)
    return optimizer, lr_scheduler, early_stopping


def perform_optimizer_step(optimizer, model, args):
    """Clip gradients (optionally), step and reset both optimizers.

    The "env" group is processed first, then the "policy" group. For each
    group the gradients are clipped by their infinity norm when
    ``args.clip_grad_norm`` is positive, the optimizer takes a step, and its
    gradients are zeroed.

    Args:
        optimizer: dict with "env" and "policy" optimizers.
        model: object exposing ``get_environment_parameters()`` and
            ``get_policy_parameters()``.
        args: namespace providing ``clip_grad_norm``.
    """
    groups = (("env", "get_environment_parameters"),
              ("policy", "get_policy_parameters"))
    for group, getter_name in groups:
        if args.clip_grad_norm > 0:
            # The getter is looked up only when clipping is enabled.
            group_parameters = getattr(model, getter_name)()
            nn.utils.clip_grad_norm_(parameters=group_parameters,
                                     max_norm=args.clip_grad_norm,
                                     norm_type=float("inf"))
        group_optimizer = optimizer[group]
        group_optimizer.step()
        group_optimizer.zero_grad()


def test(test_data, model, device, logger):
    """Evaluate ``model`` on ``test_data`` and log the averaged metrics.

    The model is put in eval mode and run under ``torch.no_grad()``; it is
    not switched back to train mode afterwards. Each batch's metrics are
    recorded in an ``AverageMeter`` with ``mask.shape[0]`` as the count.

    Args:
        test_data: iterable yielding ``(labels, tokens, mask)`` batches.
        model: called as ``model(tokens, mask, labels)``; expected to return
            a 7-tuple whose first, second and last two items are the
            predicted labels, cross-entropy loss, entropy and normalized
            entropy.
        device: device the batch tensors are moved to.
        logger: logger receiving the summary line.

    Returns:
        The ``avg`` of the accuracy meter.
    """
    load_timer = AverageMeter()
    batch_timer = AverageMeter()
    ce_meter = AverageMeter()
    acc_meter = AverageMeter()
    ent_meter = AverageMeter()
    norm_ent_meter = AverageMeter()

    model.eval()
    tick = time.time()
    with torch.no_grad():
        for batch in test_data:
            labels, tokens, mask = (tensor.to(device=device, non_blocking=True) for tensor in batch)

            load_timer.update(time.time() - tick)

            # Rewards, actions and their log-probabilities are not needed here.
            pred_labels, ce_loss, _, _, _, entropy, normalized_entropy = model(tokens, mask, labels)
            mean_entropy = entropy.mean()
            mean_norm_entropy = normalized_entropy.mean()
            accuracy = (labels == pred_labels).to(dtype=torch.float32).mean()

            count = mask.shape[0]
            for meter, value in ((acc_meter, accuracy), (ce_meter, ce_loss),
                                 (ent_meter, mean_entropy), (norm_ent_meter, mean_norm_entropy)):
                meter.update(value.item(), count)
            batch_timer.update(time.time() - tick)
            tick = time.time()

    logger.info(f"Test: ce_loss: {ce_meter.avg:.4f} accuracy: {acc_meter.avg:.4f} "
                f"entropy: {ent_meter.avg:.4f} n_entropy: {norm_ent_meter.avg:.4f} "
                f"loading_time: {load_timer.avg:.4f} batch_time: {batch_timer.avg:.4f}")
    logger.info("done")

    return acc_meter.avg


def validate(valid_data, model, epoch, device, logger, summary_writer):
    """Evaluate ``model`` on ``valid_data``, then log and record the averages.

    The model is put in eval mode for the pass (run under
    ``torch.no_grad()``) and switched back to train mode before returning.
    Cross-entropy, accuracy and normalized entropy averages are written to
    ``summary_writer["valid"]`` at the current module-level ``global_step``.

    Args:
        valid_data: iterable yielding ``(labels, tokens, mask)`` batches.
        model: called as ``model(tokens, mask, labels)``; expected to return
            a 7-tuple whose first, second and last two items are the
            predicted labels, cross-entropy loss, entropy and normalized
            entropy.
        epoch: epoch number, used only in the log line.
        device: device the batch tensors are moved to.
        logger: logger receiving the summary line.
        summary_writer: dict with a "valid" writer exposing ``add_scalar``.

    Returns:
        The ``avg`` of the accuracy meter.
    """
    load_timer = AverageMeter()
    batch_timer = AverageMeter()
    ce_meter = AverageMeter()
    acc_meter = AverageMeter()
    ent_meter = AverageMeter()
    norm_ent_meter = AverageMeter()

    model.eval()
    tick = time.time()
    with torch.no_grad():
        for batch in valid_data:
            labels, tokens, mask = (tensor.to(device=device, non_blocking=True) for tensor in batch)
            load_timer.update(time.time() - tick)

            # Rewards, actions and their log-probabilities are not needed here.
            pred_labels, ce_loss, _, _, _, entropy, normalized_entropy = model(tokens, mask, labels)
            mean_entropy = entropy.mean()
            mean_norm_entropy = normalized_entropy.mean()
            accuracy = (labels == pred_labels).to(dtype=torch.float32).mean()

            count = mask.shape[0]
            acc_meter.update(accuracy.item(), count)
            ce_meter.update(ce_loss.item(), count)
            ent_meter.update(mean_entropy.item(), count)
            norm_ent_meter.update(mean_norm_entropy.item(), count)
            batch_timer.update(time.time() - tick)
            tick = time.time()

    logger.info(f"Valid: epoch: {epoch} ce_loss: {ce_meter.avg:.4f} accuracy: {acc_meter.avg:.4f} "
                f"entropy: {ent_meter.avg:.4f} n_entropy: {norm_ent_meter.avg:.4f} "
                f"loading_time: {load_timer.avg:.4f} batch_time: {batch_timer.avg:.4f}")

    # global_step is the module-level counter advanced by train().
    valid_writer = summary_writer["valid"]
    for tag, meter in (("ce", ce_meter), ("accuracy", acc_meter), ("n_entropy", norm_ent_meter)):
        valid_writer.add_scalar(tag=tag, scalar_value=meter.avg, global_step=global_step)

    model.train()
    return acc_meter.avg


def train(train_data, valid_data, model, optimizer, lr_scheduler, es, epoch, args, logger, summary_writer):
    """Run one training epoch with periodic validation and checkpointing.

    For every batch the cross-entropy loss and the loss
    ``(rewards * actions_log_prob).mean() - entropy_weight * normalized_entropy``
    are back-propagated one after the other, then both optimizers take a
    step. Per-batch metrics go to ``summary_writer["train"]`` and the
    module-level ``global_step`` counter is advanced.

    Every ``len(train_data) // 3`` batches (at least every batch) the model
    is validated; the validation accuracy drives both LR schedulers and the
    early-stopping tracker. The function returns immediately once early
    stopping reports convergence. On improvement a checkpoint is saved and
    its path is stored in the module-level ``best_model_path``.

    Args:
        train_data: sized iterable yielding ``(labels, tokens, mask)`` batches.
        valid_data: validation loader passed to ``validate``.
        model: the model being trained.
        optimizer: dict with "env" and "policy" optimizers.
        lr_scheduler: dict with "env" and "policy" schedulers, stepped with
            the validation accuracy.
        es: early-stopping tracker exposing ``step``, ``is_converged`` and
            ``is_improved()``.
        epoch: current epoch number (used in logs and checkpoint names).
        args: namespace providing ``gpu_id``, ``entropy_weight``,
            ``model_dir`` and the fields read by ``perform_optimizer_step``.
        logger: logger for progress lines.
        summary_writer: dict with "train" and "valid" writers.
    """
    global global_step, best_model_path

    loading_time_meter = AverageMeter()
    batch_time_meter = AverageMeter()
    ce_loss_meter = AverageMeter()
    accuracy_meter = AverageMeter()
    entropy_meter = AverageMeter()
    n_entropy_meter = AverageMeter()

    # A loader with fewer than 3 batches would make the interval 0 and the
    # modulo below raise ZeroDivisionError, so clamp it to at least 1.
    validation_interval = max(1, len(train_data) // 3)

    device = args.gpu_id
    model.train()
    start = time.time()
    for batch_idx, (labels, tokens, mask) in enumerate(train_data):
        labels = labels.to(device=device, non_blocking=True)
        tokens = tokens.to(device=device, non_blocking=True)
        mask = mask.to(device=device, non_blocking=True)
        loading_time_meter.update(time.time() - start)

        pred_labels, ce_loss, rewards, actions, actions_log_prob, entropy, normalized_entropy = \
            model(tokens, mask, labels)
        entropy = entropy.mean()
        normalized_entropy = normalized_entropy.mean()

        # Both backward passes accumulate gradients before the single step.
        ce_loss.backward()
        rl_loss = (rewards * actions_log_prob).mean() - args.entropy_weight * normalized_entropy
        rl_loss.backward()
        perform_optimizer_step(optimizer, model, args)

        n = mask.shape[0]
        accuracy = (labels == pred_labels).to(dtype=torch.float32).mean()
        accuracy_meter.update(accuracy.item(), n)
        ce_loss_meter.update(ce_loss.item(), n)
        entropy_meter.update(entropy.item(), n)
        n_entropy_meter.update(normalized_entropy.item(), n)
        batch_time_meter.update(time.time() - start)

        summary_writer["train"].add_scalar(tag="ce", scalar_value=ce_loss.item(), global_step=global_step)
        summary_writer["train"].add_scalar(tag="accuracy", scalar_value=accuracy.item(), global_step=global_step)
        summary_writer["train"].add_scalar(tag="n_entropy", scalar_value=normalized_entropy.item(),
                                           global_step=global_step)
        global_step += 1

        if (batch_idx + 1) % validation_interval == 0:
            logger.info(f"Train: epoch: {epoch} batch_idx: {batch_idx + 1} ce_loss: {ce_loss_meter.avg:.4f} "
                        f"accuracy: {accuracy_meter.avg:.4f} entropy: {entropy_meter.avg:.4f} "
                        f"n_entropy: {n_entropy_meter.avg:.4f} loading_time: {loading_time_meter.avg:.4f} "
                        f"batch_time: {batch_time_meter.avg:.4f}")
            val_accuracy = validate(valid_data, model, epoch, device, logger, summary_writer)
            lr_scheduler["env"].step(val_accuracy)
            lr_scheduler["policy"].step(val_accuracy)
            es.step(val_accuracy)
            if es.is_converged:
                return
            if es.is_improved():
                logger.info("saving model...")
                best_model_path = f"{args.model_dir}/{epoch}-{batch_idx}.mdl"
                torch.save({"epoch": epoch, "batch_idx": batch_idx, "state_dict": model.state_dict()}, best_model_path)
            model.train()
        # Restart the timer here so validation time is not counted as batch time.
        start = time.time()


def main(args):
    """Train a ``ReinforceModel`` on ListOps and evaluate it on the test set.

    Prepares the output paths and data loaders, builds the model on GPU
    ``args.gpu_id``, runs an initial validation pass, then trains for up to
    ``args.max_epoch`` epochs or until early stopping converges. Finally the
    best checkpoint (module-level ``best_model_path``) is loaded and tested.

    Args:
        args: parsed command-line namespace.
    """
    logger, summary_writer = make_path_preparations(args)
    train_data, valid_data, test_data = get_data(args)
    model = ReinforceModel(vocab_size=args.vocab_size,
                           word_dim=args.word_dim,
                           hidden_dim=args.hidden_dim,
                           label_dim=args.label_size,
                           parser_leaf_transformation=args.parser_leaf_transformation,
                           parser_trans_hidden_dim=args.parser_trans_hidden_dim,
                           tree_leaf_transformation=args.tree_leaf_transformation,
                           tree_trans_hidden_dim=args.tree_trans_hidden_dim,
                           baseline_type=args.baseline_type,
                           var_normalization=args.var_normalization).cuda(args.gpu_id)
    optimizer, lr_scheduler, es = prepare_optimisers(args, logger,
                                                     policy_parameters=model.get_policy_parameters(),
                                                     environment_parameters=model.get_environment_parameters())

    validate(valid_data, model, 0, args.gpu_id, logger, summary_writer)
    for epoch in range(args.max_epoch):
        train(train_data, valid_data, model, optimizer, lr_scheduler, es, epoch, args, logger, summary_writer)
        if es.is_converged:
            break
    print(best_model_path)
    if best_model_path is None:
        # No checkpoint was written (e.g. max_epoch == 0), so torch.load
        # would fail on None; test the weights currently in memory instead.
        logger.info("no checkpoint was saved during training; testing the current model weights")
    else:
        checkpoint = torch.load(best_model_path)
        model.load_state_dict(checkpoint["state_dict"])
    test(test_data, model, args.gpu_id, logger)


if __name__ == "__main__":
    # Default value of every command-line option, keyed by flag name
    # (without the leading "--").
    defaults = {
        "word-dim": 128,
        "hidden-dim": 128,
        "parser-leaf-transformation": "lstm_transformation",
        "parser-trans-hidden_dim": 128,
        "tree-leaf-transformation": "no_transformation",
        "tree-trans-hidden_dim": 128,
        "baseline-type": "self_critical",
        "var-normalization": "True",
        "entropy-weight": 0.0001,
        "clip-grad-norm": 0.5,
        "optimizer": "adadelta",
        "env-lr": 1.0,
        "pol-lr": 1.0,
        "lr-scheduler-patience": 8,
        "l2-weight": 0.0001,
        "batch-size": 64,
        "max-epoch": 300,
        "es-patience": 20,
        "es-threshold": 0.005,
        "gpu-id": 0,
        "model-dir": "data/listops/reinforce/models/exp0",
        "logs-path": "data/listops/reinforce/logs/exp0",
        "tensorboard-path": "data/listops/reinforce/tensorboard/exp0",
    }

    leaf_transformations = ["no_transformation", "lstm_transformation",
                            "bi_lstm_transformation", "conv_transformation"]

    # (flag, extra add_argument keywords), in the order shown by --help.
    cli_spec = [
        ("word-dim", dict(type=int)),
        ("hidden-dim", dict(type=int)),
        ("parser-leaf-transformation", dict(choices=leaf_transformations)),
        ("parser-trans-hidden_dim", dict(type=int)),
        ("tree-leaf-transformation", dict(choices=leaf_transformations)),
        ("tree-trans-hidden_dim", dict(type=int)),
        ("baseline-type", dict(choices=["no_baseline", "ema", "self_critical"])),
        # Only the exact string "True" enables variance normalization.
        ("var-normalization", dict(type=lambda string: string == "True")),
        ("entropy-weight", dict(type=float)),
        ("clip-grad-norm", dict(type=float,
                                help="If the value is less or equal to zero clipping is not performed.")),
        ("optimizer", dict(choices=["adam", "sgd", "adadelta"])),
        ("env-lr", dict(type=float)),
        ("pol-lr", dict(type=float)),
        ("lr-scheduler-patience", dict(type=int)),
        ("l2-weight", dict(type=float)),
        ("batch-size", dict(type=int)),
        ("max-epoch", dict(type=int)),
        ("es-patience", dict(type=int)),
        ("es-threshold", dict(type=float)),
        ("gpu-id", dict(type=int)),
        ("model-dir", dict(type=str)),
        ("logs-path", dict(type=str)),
        ("tensorboard-path", dict(type=str)),
    ]

    parser = argparse.ArgumentParser()
    for flag, extra in cli_spec:
        parser.add_argument(f"--{flag}", default=defaults[flag], **extra)

    # Module-level state shared by train(), validate() and main().
    global_step = 0
    best_model_path = None
    args = parser.parse_args()
    with torch.cuda.device(args.gpu_id):
        main(args)