import torch.nn.functional as F

import src.train.utils as train_utils
import src.models.utils as model_utils
import src.evaluate.utils as eval_utils


##############################################################################
#                                       BATCH
##############################################################################


def batch_atomic_generate(opt, nums, losses, batch_variables, eval_mode=False):
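    """Sample a batch of ATOMIC sequences and compute the masked
    language-modeling loss for it.

    Returns a dict with the summed per-example loss, the updated loss
    counts, and a ``reset`` flag marking the end of an epoch.
    """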
    data_loader = batch_variables["data"]
    model = batch_variables["model"]
    split = batch_variables["split"]

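    # reset indicates the loader has exhausted this split and restarted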
    batch, reset = data_loader.sample_batch(split, bs=opt.train.dynamic.bs)

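    # Pair each token id with a position id so the model receives
    # (token, position) inputs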
    input_ = model_utils.prepare_position_embeddings(
        opt, data_loader.vocab_encoder, batch["sequences"].unsqueeze(-1))
    attention_mask = batch["attention_mask"]
    loss_mask = batch["loss_mask"]

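    # Targets are the input tokens shifted left by one (next-token prediction)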
    targets = input_.squeeze(0)[:, 1:, 0].contiguous().view(-1)

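    # Per-token NLL with reduction deferred so the loss mask can be applied below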
    loss, dist = mle_steps(
        opt.net.model, model, input_[:, :-1, :], targets,
        attention_mask[:, :-1], loss_reduction="none")

    # ATOMIC batches always count toward the total loss
    micro_name = "total_micro"
    macro_name = "total_macro"

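    # Number of target tokens in each example, for per-sequence normalization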
    length = loss_mask.sum(1)
    bs = input_.size(0)

    # Sum the masked token losses over each sequence
    final_loss = (loss * loss_mask).sum(1)

    update_generation_losses(losses, nums, micro_name, macro_name, bs,
                             length, final_loss, split)

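    # Length-normalize so each example contributes its mean token loss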
    final_loss = final_loss / length

    outputs = {"loss": final_loss.sum(), "nums": nums, "reset": reset}

    return outputs


def batch_conceptnet_generate(opt, nums, losses, batch_variables,
                              eval_mode=False, tracking_mode=False):
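    """ConceptNet counterpart of ``batch_atomic_generate``: samples a batch
    from the given relation category and computes the masked loss. In eval
    mode, negative-category batches are tracked under separate loss names;
    with ``tracking_mode``, per-example losses are also returned.
    """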
    data_loader = batch_variables["data"]
    model = batch_variables["model"]
    split = batch_variables["split"]
    category = batch_variables["category"]

    batch, reset = data_loader.sample_batch(
        split, bs=opt.train.dynamic.bs, cat=category)

    input_ = model_utils.prepare_position_embeddings(
        opt, data_loader.vocab_encoder, batch["sequences"].unsqueeze(-1))
    attention_mask = batch["attention_mask"]
    loss_mask = batch["loss_mask"]

    targets = input_.squeeze(0)[:, 1:, 0].contiguous().view(-1)

    loss, dist = mle_steps(
        opt.net.model, model, input_[:, :-1, :], targets,
        attention_mask[:, :-1], loss_reduction="none")

    # During evaluation, negative-category batches get their own loss names
    if not eval_mode or batch_variables["category"] == "positive":
        micro_name = "total_micro"
        macro_name = "total_macro"
    else:
        micro_name = "negative_micro"
        macro_name = "negative_macro"

    length = loss_mask.sum(1)
    bs = input_.size(0)

    # Sum the masked token losses over each sequence
    final_loss = (loss * loss_mask).sum(1)

    update_generation_losses(losses, nums, micro_name, macro_name, bs,
                             length, final_loss, split)

    final_loss = final_loss / length

    outputs = {"loss": final_loss.sum(), "nums": nums, "reset": reset}

    if tracking_mode:
        # view(-1) keeps the result a list even for a single-example batch
        # (squeeze() would collapse it to a bare float)
        outputs["tracking"] = final_loss.view(-1).tolist()

    return outputs


def mle_steps(key, model, input_, targets, attention_mask,
              loss_reduction="mean", i=None):
    word_acts = decode(model, input_.unsqueeze(1),
                       attention_mask, i)

    word_dist = train_utils.modify_output_for_loss_fn(
        "nll", word_acts, dim=-1)

    # Compute losses
    loss = F.nll_loss(
        word_dist.view(-1, word_dist.size(-1)),
        targets, reduction=loss_reduction)

    if loss_reduction != "mean":
        return loss.view(word_dist.size(0), -1), word_dist
    else:
        return loss, word_dist


def decode(model, input_, attention_mask, i=None):
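    """Forward pass through the model; ``i`` is accepted but unused."""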
    return model(input_, sequence_mask=attention_mask)


def update_generation_losses(losses, nums, micro, macro, bs,
                             length, loss, split):
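    """Dispatch micro/macro loss bookkeeping to the train or eval helper."""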
    if split == "train":
        train_utils.update_generation_losses(
            losses, nums, micro, macro, bs, length, loss)
    else:
        eval_utils.update_generation_losses(
            losses, nums, micro, macro, bs, length, loss)