# Implementation of the paper 'Adaptively Aligned Image Captioning via Adaptive Attention Time'
# https://arxiv.org/abs/1909.09060
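#
# AATCore implements the adaptive-attention-time decoder core from that paper: for each
# generated word the decoder takes a variable number of attention + language LSTM steps,
# halting once its accumulated confidence is high enough. AATModel wires this core into
# the AttModel captioning framework used elsewhere in this codebase.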

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F
import misc.utils as utils

from .AttModel import AttModel, Attention
from .TransformerModel import LayerNorm
from .AoAModel import MultiHeadedDotAttention

class AATCore(nn.Module):
    # Decoder core with Adaptive Attention Time: one attention-LSTM step per word,
    # followed by a variable number of attention + language-LSTM steps whose outputs
    # are combined with confidence-based halting weights.
    def __init__(self, opt):
        super(AATCore, self).__init__()
        
        self.drop_prob_lm = opt.drop_prob_lm
        self.rnn_size = opt.rnn_size
        self.epsilon = opt.epsilon              # halting threshold: stop once accumulated confidence >= 1 - epsilon
        self.max_att_steps = opt.max_att_steps  # hard cap on the number of attention steps per word
        self.use_multi_head = opt.use_multi_head

        # attention LSTM: input is the word embedding concatenated with
        # (global image feature + previous language-LSTM hidden state)
        self.att_lstm = nn.LSTMCell(opt.input_encoding_size + opt.rnn_size, opt.rnn_size)

        # confidence network: maps a hidden state to a halting confidence in (0, 1)
        self.confidence = nn.Sequential(nn.Linear(opt.rnn_size, opt.rnn_size),
                                        nn.ReLU(),
                                        nn.Linear(opt.rnn_size, 1),
                                        nn.Sigmoid())

        # builds the attention query from the current language hidden state and the attention-LSTM output
        self.h2query = nn.Sequential(nn.Linear(opt.rnn_size * 2, opt.rnn_size),
                                     nn.ReLU())

        # if opt.use_multi_head == 1: # TODO, not implemented for now           
        #     self.attention = MultiHeadedAddAttention(opt.num_heads, opt.d_model, scale=opt.multi_head_scale)
        if opt.use_multi_head == 2:
            self.attention = MultiHeadedDotAttention(opt.num_heads, opt.rnn_size, project_k_v=0,
                                                     scale=opt.multi_head_scale, use_output_layer=0,
                                                     do_aoa=0, norm_q=1)
        else:
            self.attention = Attention(opt)

        # language LSTM: input is the attended context concatenated with the attention query
        self.lang_lstm = nn.LSTMCell(opt.rnn_size + opt.rnn_size, opt.rnn_size)

        self.norm_h = LayerNorm(opt.rnn_size)
        self.norm_c = LayerNorm(opt.rnn_size)

    def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None):
        # xt: embedding of the current word
        # state: (torch.stack([h_att, h_lang]), torch.stack([c_att, c_lang])) from the previous step
        batch_size = fc_feats.size(0)
        # accumulated halting confidence plus per-example bookkeeping exposed to the model
        accum_conf = fc_feats.new_zeros(batch_size, 1)
        self.att_step = fc_feats.new_zeros(batch_size)  # number of attention steps actually taken
        self.att_cost = fc_feats.new_zeros(batch_size)  # differentiable attention-time cost

        h_lang = fc_feats.new_zeros(batch_size, self.rnn_size)
        c_lang = fc_feats.new_zeros(batch_size, self.rnn_size)

        # attention-LSTM step: fuse the global feature with the previous language hidden state
        att_lstm_input = torch.cat([fc_feats + state[0][-1], xt], 1)

        h_att, c_att = self.att_lstm(att_lstm_input, (state[0][0], state[1][0]))
        h_att = self.norm_h(h_att)

        # halting confidence of the zeroth (no-attention) step
        p = self.confidence(h_att)

        # attention-time cost of the zeroth step: 1 + (1 - p)
        self.att_cost += (1 + (1 - p)).squeeze(1)

        # examples whose confidence is still below 1 - epsilon keep attending
        selector = (p < 1 - self.epsilon).data

        if selector.any():
            accum_conf += p

            # confidence-weighted contribution of the zeroth step (attention-LSTM hidden
            # state and the previous language-LSTM cell state)
            h_lang += p * h_att
            c_lang += p * state[1][1]

            # running language-LSTM state, initialised from the previous time step
            h_lang_, c_lang_ = (state[0][1], state[1][1])

            for i in range(self.max_att_steps):
                # build the attention query from the running language state and the attention-LSTM output
                att_query = self.h2query(torch.cat([h_lang_, h_att], 1))
                if self.use_multi_head == 2:
                    # p_att_feats holds pre-projected keys and values, concatenated along the last dim
                    att_ = self.attention(att_query,
                                          p_att_feats.narrow(2, 0, self.rnn_size),
                                          p_att_feats.narrow(2, self.rnn_size, self.rnn_size),
                                          att_masks)
                else:
                    att_ = self.attention(att_query, att_feats, p_att_feats, att_masks)

                # language-LSTM step over the attended context
                lang_lstm_input_ = torch.cat([att_, att_query], 1)
                h_lang_, c_lang_ = self.lang_lstm(lang_lstm_input_, (h_lang_, c_lang_))
                h_lang_ = self.norm_h(h_lang_)
                c_lang_ = self.norm_c(c_lang_)

                # count this attention step only for examples that are still running
                self.att_step += selector.squeeze(1).float()
                p_ = self.confidence(h_lang_)

                # halting weight: new confidence scaled by the remaining confidence mass
                beta = p_ * (1 - accum_conf)
                accum_conf += beta * selector.float()
                h_lang += beta * h_lang_ * selector.float()
                c_lang += beta * c_lang_ * selector.float()

                self.att_cost += ((1 + (i + 2) * (1 - p_)) * selector.float()).squeeze(1)
                # keep attending only while the accumulated confidence stays below 1 - epsilon
                selector = (accum_conf < 1 - self.epsilon).data * selector

                if not selector.any():
                    break

            # normalise the weighted sums by the total accumulated confidence
            h_lang /= accum_conf
            c_lang /= accum_conf

        else:
            # every example was already confident after the attention LSTM:
            # skip attention and carry over the previous language cell state
            h_lang += h_att
            c_lang += state[1][1]

        output = F.dropout(h_lang, self.drop_prob_lm, self.training)
        state = (torch.stack([h_att, h_lang]), torch.stack([c_att, c_lang]))

        return output, state
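
# A compact summary of the halting scheme in AATCore.forward (summarising the code above;
# the notation here is ours, not necessarily the paper's): with confidence p_0 from the
# attention LSTM and p_k from the k-th language-LSTM step, the halting weights are
#     beta_0 = p_0,    beta_k = p_k * (1 - sum_{j<k} beta_j),
# the returned language state is the normalised weighted sum
#     h_lang = sum_k beta_k * h_k / sum_k beta_k,
# where h_0 = h_att and h_k is the k-th language-LSTM hidden state. An example stops
# attending once sum_k beta_k >= 1 - epsilon (or after max_att_steps), and att_cost
# accumulates 1 + (k + 1) * (1 - p_k) at step k.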

class AATModel(AttModel):
    def __init__(self, opt):
        super(AATModel, self).__init__(opt)
        self.num_layers = 2
        if opt.use_multi_head == 2:
            # re-project attention features so p_att_feats carries both keys and values
            del self.ctx2att
            self.ctx2att = nn.Linear(opt.rnn_size, 2 * opt.rnn_size)
        self.core = AATCore(opt)

    def forward(self, *args, **kwargs):
        # reset the per-sequence attention-step / attention-cost logs before dispatching
        self.all_att_step = []
        self.all_att_cost = []
        mode = kwargs.get('mode', 'forward')
        if 'mode' in kwargs:
            del kwargs['mode']
        return getattr(self, '_' + mode)(*args, **kwargs)
    
    def get_logprobs_state(self, it, fc_feats, att_feats, p_att_feats, att_masks, state):
        # 'it' contains a word index
        xt = self.embed(it)

        output, state = self.core(xt, fc_feats, att_feats, p_att_feats, state, att_masks)
        logprobs = F.log_softmax(self.logit(output), dim=1)

        # att_step is detached to numpy for logging; att_cost is kept as a tensor
        # so it remains differentiable
        self.all_att_step.append(self.core.att_step.cpu().numpy())
        self.all_att_cost.append(self.core.att_cost)

        return logprobs, state
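

if __name__ == '__main__':
    # Minimal smoke test for AATCore -- not part of the original implementation.
    # The opt fields and tensor shapes below are assumptions inferred from the code
    # above (single-head attention path, i.e. use_multi_head == 0). Because of the
    # relative imports, run it as a module from the repo root, e.g.
    # `python -m models.AATModel` (the module path is an assumption).
    from argparse import Namespace

    opt = Namespace(input_encoding_size=512, rnn_size=512, att_hid_size=256,
                    drop_prob_lm=0.5, epsilon=0.1, max_att_steps=4, use_multi_head=0)
    core = AATCore(opt)

    B, L = 2, 36  # batch size and number of image regions (arbitrary)
    xt = torch.randn(B, opt.input_encoding_size)        # current word embeddings
    fc_feats = torch.randn(B, opt.rnn_size)              # embedded global image feature
    att_feats = torch.randn(B, L, opt.rnn_size)          # embedded region features
    p_att_feats = torch.randn(B, L, opt.att_hid_size)    # features pre-projected for Attention
    state = (torch.zeros(2, B, opt.rnn_size), torch.zeros(2, B, opt.rnn_size))

    out, new_state = core(xt, fc_feats, att_feats, p_att_feats, state)
    print(out.shape, core.att_step, core.att_cost)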