import torch
import torch.nn as nn
import torch.nn.init
import torchvision.models as models
from torch.autograd import Variable
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import torch.backends.cudnn as cudnn
from torch.nn.utils.clip_grad import clip_grad_norm
import torch.nn.functional as F
import numpy as np
import bottleneck as bn

def rnn_mask(context_lens, max_step):
    """
    Creates a mask for variable length sequences
    """
    num_batches = len(context_lens)

    mask = torch.FloatTensor(num_batches, max_step).zero_()
    if torch.cuda.is_available():
        mask = mask.cuda()
    for b, batch_l in enumerate(context_lens):
        mask[b, :batch_l] = 1.0
    mask = Variable(mask)
    return mask
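
# Example (a minimal sketch; the lengths below are made up for illustration):
#   rnn_mask([3, 1], max_step=4)
#   # -> [[1., 1., 1., 0.],
#   #     [1., 0., 0., 0.]]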

def top_n_indexes(arr, n):
    """Returns the (row, col) indexes of the n largest entries of a 2-D array."""
    idx = bn.argpartition(arr, arr.size - n, axis=None)[-n:]
    width = arr.shape[1]
    return [divmod(i, width) for i in idx]
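
# Example (a minimal sketch; the array values are made up for illustration):
#   arr = np.array([[0.1, 0.9],
#                   [0.8, 0.2]])
#   top_n_indexes(arr, 2)
#   # -> the (row, col) pairs of the two largest entries, here (0, 1) and (1, 0),
#   #    in no particular order (argpartition does not sort them)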

class Seq2seqAttention(nn.Module):
    def __init__(self, args):
        super(Seq2seqAttention, self).__init__()
        self.args = args
        self.enable_cuda = args.cuda
        self.vid_dim = args.vid_dim
        self.embed_size = args.embed
        self.hidden_dim = args.hid
        self.vocab_size = args.max_vocab_size
        self.num_layers = args.num_layers
        self.birnn = args.birnn
        self.encoder = EncoderFrames(self.args)
        self.decoder = DecoderRNN(self.args)

    def forward(self, frames, flengths, captions, lengths):
        video_features = self.encoder(frames, flengths)
        outputs = self.decoder(video_features, flengths, captions, lengths)

        return outputs

    def sample(self, frames, flengths):
        video_features = self.encoder.forward(frames, flengths)
        predicted_target = self.decoder.sample(video_features, flengths)
        return predicted_target

    def sample_rl(self, frames, flengths, sampling='multinomial'):
        video_features = self.encoder.forward(frames, flengths)
        predicted_target, outputs = self.decoder.rl_sample(video_features, flengths, sampling=sampling)
        return predicted_target, outputs

    def beam_search(self, frames, flengths, beam_size=5):
        video_features = self.encoder.forward(frames, flengths)
        predicted_target = self.decoder.beam_search(video_features, flengths, beam_size=beam_size)

        return predicted_target
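
    # Rough usage sketch (argparse.Namespace stands in for the real training-script
    # args; all sizes below are made-up placeholders, not values from this repo):
    #   args = argparse.Namespace(cuda=False, vid_dim=2048, embed=256, hid=512,
    #                             max_vocab_size=10000, num_layers=1, birnn=True,
    #                             dropout=0.5)
    #   model = Seq2seqAttention(args)
    #   outputs = model(frames, flengths, captions, lengths)  # logits for training
    #   (note: DecoderRNN currently calls .cuda() unconditionally, so a GPU is assumed)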


# Based on tutorials/08 - Language Model
# RNN Based Language Model
class EncoderFrames(nn.Module):
    def __init__(self, args):
        super(EncoderFrames, self).__init__()
        # self.use_abs = use_abs
        self.vid_dim = args.vid_dim
        self.embed_size = args.embed
        self.hidden_dim = args.hid
        self.enable_cuda = args.cuda
        self.num_layers = args.num_layers
        self.args = args
        if args.birnn:
            self.birnn = 2
        else:
            self.birnn = 1
        # projection layer
        self.linear = nn.Linear(self.vid_dim, self.embed_size, bias=False)
        # video embedding
        self.rnn = nn.LSTM(self.embed_size, self.hidden_dim, self.num_layers, batch_first=True, bidirectional=self.args.birnn, dropout=args.dropout)
        self.dropout = nn.Dropout(args.dropout)
        self.init_weights()

    def init_weights(self):
        self.rnn.weight_hh_l0.data.uniform_(-0.08, 0.08)
        self.rnn.weight_ih_l0.data.uniform_(-0.08, 0.08)
        self.rnn.bias_ih_l0.data.fill_(0)
        self.rnn.bias_hh_l0.data.fill_(0)
        self.linear.weight.data.uniform_(-0.08, 0.08)
        #self.linear.bias.data.fill_(0)

    def init_hidden(self, batch_size):
        # self.birnn is 2 for a bidirectional encoder and 1 otherwise, so the
        # initial hidden/cell states cover birnn * num_layers layer-directions.
        return (Variable(torch.zeros(self.birnn*self.num_layers, batch_size, self.hidden_dim)),
                Variable(torch.zeros(self.birnn*self.num_layers, batch_size, self.hidden_dim)))

    def forward(self, frames, flengths):
        """Handles variable size frames
           frame_embed: video features
           flengths: frame lengths
        """
        batch_size = flengths.shape[0]
        #frames = self.linear(frames)
        #frames = self.dropout(frames) # adding dropout layer
        self.init_rnn = self.init_hidden(batch_size)
        if self.enable_cuda:
            self.init_rnn = self.init_rnn[0].cuda(), self.init_rnn[1].cuda()

        if batch_size > 1:
            # Sort by length (keep idx)
            flengths, idx_sort = np.sort(flengths)[::-1], np.argsort(-flengths)
            if self.enable_cuda:
                frames = frames.index_select(0, Variable(torch.cuda.LongTensor(idx_sort)))
            else:
                frames = frames.index_select(0, Variable(torch.LongTensor(idx_sort)))



        frames = self.linear(frames)
        frame_packed = nn.utils.rnn.pack_padded_sequence(frames, flengths, batch_first=True)
        outputs, (ht, ct) = self.rnn(frame_packed, self.init_rnn)
        outputs,_ = pad_packed_sequence(outputs,batch_first=True)

        if batch_size > 1:
            # Un-sort by length
            idx_unsort = np.argsort(idx_sort)
            if self.enable_cuda:
                outputs = outputs.index_select(0, Variable(torch.cuda.LongTensor(idx_unsort)))
            else:
                outputs = outputs.index_select(0, Variable(torch.LongTensor(idx_unsort)))

        # print 'Encoder Outputs:',outputs.size()
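        # `outputs` is batch_size x max(flengths) x (2*hidden_dim if birnn else
        # hidden_dim), restored to the original (unsorted) batch order.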

        return outputs


# Based on tutorials/03 - Image Captioning
class DecoderRNN(nn.Module):
    def __init__(self, args):
        """Set the hyper-parameters and build the layers."""
        super(DecoderRNN, self).__init__()
        self.enable_cuda = args.cuda
        self.embed_size = args.embed
        self.hidden_size = args.hid
        self.vocab_size = args.max_vocab_size
        if args.birnn:
            self.birnn = 2
        else:
            self.birnn = 1

        self.num_layers = args.num_layers
        self.args = args
        self.input_proj = nn.Linear(self.birnn*self.hidden_size+self.embed_size, self.embed_size)
        self.embed = nn.Embedding(self.vocab_size, self.embed_size)
        self.atten = Attention(args, self.birnn*self.hidden_size, self.hidden_size)  
        self.lstm = nn.LSTM(self.embed_size+self.birnn*self.hidden_size, self.hidden_size, self.num_layers, batch_first=True, dropout=args.dropout)

        self.linear = nn.Linear(self.hidden_size, self.vocab_size)
        self.init_weights()

    def init_weights(self):
        """Initialize weights."""
        self.lstm.weight_hh_l0.data.uniform_(-0.08, 0.08)
        self.lstm.weight_ih_l0.data.uniform_(-0.08, 0.08)
        self.lstm.bias_ih_l0.data.fill_(0)
        self.lstm.bias_hh_l0.data.fill_(0)
        self.embed.weight.data.uniform_(-0.08, 0.08)
        self.input_proj.weight.data.uniform_(-0.08, 0.08)
        self.input_proj.bias.data.fill_(0)
        self.linear.weight.data.uniform_(-0.08, 0.08)
        self.linear.bias.data.fill_(0)

    def forward(self, video_features, flengths, captions, lengths):
        """Decodes video feature vectors and generates captions (teacher forcing).

        :param video_features:
            video encoder output hidden states of size batch_size x max_enc_steps x hidden_dim
        :param flengths:
            video frame lengths of size batch_size
        :param captions:
            input target captions of size batch_size x max_dec_steps
        :param lengths:
            input caption lengths of size batch_size
        """
        # print features.size(), captions.size(), self.embed_size
        # print 'Input features, captions, lengths', features.size(), captions.size(), lengths, np.sum(lengths)
        # prepending the <start> token to the input captions
        batch_size,step_size = captions.shape
        max_enc_steps = video_features.shape[1]
        context_mask = rnn_mask(flengths,max_enc_steps)
        captions = torch.cat((Variable(torch.LongTensor(np.ones([batch_size,1]))).cuda(),captions), 1)
        embeddings = self.embed(captions)
        hidden_output = Variable(torch.FloatTensor(batch_size,self.hidden_size).zero_()).cuda()
        state = None
        outputs = []
        for i in range(step_size):
            c_t, alpha = self.atten(hidden_output, video_features, context_mask)
            inp = torch.cat((embeddings[:,i,:], c_t), 1).unsqueeze(1)
            #inp = self.input_proj(inp)
            hidden_output,state = self.lstm(inp,state)
            hidden_output = hidden_output.squeeze(1)
            outputs.append(hidden_output)

        outputs = torch.transpose(torch.stack(outputs), 0, 1) # converting from step_size x batch_size x hidden_size to batch_size x step_size x hidden_size
        outputs = pack_padded_sequence(outputs, lengths, batch_first=True)[0]
        outputs = self.linear(outputs)
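        # `outputs` is packed to shape (sum(lengths), vocab_size); the training loop
        # presumably packs the target captions the same way before computing the loss.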

        return outputs


    def sample(self, video_features, flengths, max_len=30, state=None):
        """Samples captions for given image features (Greedy search)."""
        sampled_ids = []
        state = None
        batch_size, _, _ = video_features.shape
        max_enc_steps = video_features.shape[1]
        context_mask = rnn_mask(flengths, max_enc_steps)
        hidden_output = Variable(torch.FloatTensor(batch_size,self.hidden_size).zero_()).cuda()
        inputs = self.embed(Variable(torch.LongTensor(np.ones([batch_size,1]))).cuda()).squeeze(1)
        for i in range(max_len + 1):  # maximum sampling length
            c_t, alpha = self.atten(hidden_output, video_features, context_mask)
            inp = torch.cat((inputs, c_t), 1).unsqueeze(1)
            #inp = self.input_proj(inp)
            hidden_output,state = self.lstm(inp,state)
            hidden_output = hidden_output.squeeze(1)
            output = self.linear(hidden_output)  # (batch_size, vocab_size)
            predicted = output.max(1)[1]
            sampled_ids.append(predicted.unsqueeze(1))
            inputs = self.embed(predicted)
        sampled_ids = torch.cat(sampled_ids, 1)  # (batch_size, max_len + 1)
        return sampled_ids.squeeze()
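
    # Greedy decoding sketch (video_features / flengths come from EncoderFrames;
    # token id 1 is assumed to be the <start> token, as in forward() above):
    #   ids = decoder.sample(video_features, flengths)
    #   # (batch_size, max_len + 1) token ids, squeezed when batch_size == 1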

    def beam_search(self, video_features, flengths, max_len=20, beam_size=5):
        """ Beam search Implementation during Inference"""
        prev_state = None
        outputs = []
        batch_size, max_enc_steps, _ = video_features.shape
        context_mask = rnn_mask(flengths, max_enc_steps)
        hidden_output = Variable(torch.FloatTensor(batch_size,self.hidden_size).zero_()).cuda()
        inputs = self.embed(Variable(torch.LongTensor(np.ones([batch_size,1]))).cuda()).squeeze(1)
        # handle the zeroth step separately
        c_t, alpha = self.atten(hidden_output, video_features, context_mask)
        inp = torch.cat((inputs,c_t),1).unsqueeze(1)
        next_hidden, next_state = self.lstm(inp, prev_state)
        next_hidden = next_hidden.squeeze(1)
        output = self.linear(next_hidden)
        output = F.softmax(output,1)
        next_probs, next_words = torch.topk(output,beam_size)
        prev_words = torch.t(next_words)
        prev_state = []
        prev_hidden = []
        #print next_state

        for i in range(beam_size):
            prev_state.append(next_state)
            prev_hidden.append(next_hidden)
        #print prev_state
        all_probs = next_probs.cpu().data.numpy()

        generated_sequence = np.zeros((batch_size,beam_size,max_len),dtype=np.int32)
        generated_sequence[:,:,0] = next_words.cpu().data.numpy()

        # variables for final results storing
        final_results = np.zeros((batch_size,beam_size,max_len), dtype=np.int32)
        final_all_probs = np.zeros((batch_size,beam_size))
        final_results_counter = np.zeros((batch_size),dtype=np.int32) # to check for overflow of the beam in the final results


        for i in range(1,max_len):
            probs = []
            state = []
            hidden = []
            words = []

            for j in range(beam_size):
                inputs = self.embed(prev_words[j])
                #print inputs
                c_t, alpha = self.atten(prev_hidden[j], video_features, context_mask)
                inp = torch.cat((inputs,c_t),1).unsqueeze(1)
                next_hidden, next_state = self.lstm(inp, prev_state[j])
                next_hidden = next_hidden.squeeze(1)
                output = self.linear(next_hidden)
                output = F.softmax(output,1)
                next_probs, next_words = torch.topk(output, beam_size)
                probs.append(next_probs)
                words.append(next_words)
                state.append(next_state)
                hidden.append(next_hidden)


            probs = np.transpose(np.array(torch.stack(probs).cpu().data.numpy()),(1,0,2))
            #state = np.transpose(np.array(state.cpu().data.numpy()),(1,0,2))
            hidden = np.transpose(np.array(torch.stack(hidden).cpu().data.numpy()),(1,0,2))
            words = np.transpose(np.array(torch.stack(words).cpu().data.numpy()),(1,0,2))
            state = [torch.cat(s,0) for s in state]
            state = torch.stack(state)
            #print state

            prev_state = []
            prev_words = []
            prev_hidden = []
            
            for k in range(batch_size):
                probs[k] = np.transpose(np.transpose(probs[k])*all_probs[k]) # scale each beam's word probabilities by that beam's accumulated probability so far
                top_indices = top_n_indexes(probs[k],beam_size)
                beam_idx,top_choice_idx = zip(*top_indices)
                all_probs[k] = (probs[k])[beam_idx,top_choice_idx]
                prev_state.append([state[idx,:,k,:] for idx in beam_idx])
                prev_hidden.append([hidden[k,idx,:] for idx in beam_idx])
                prev_words.append([words[k,idx,idy] for idx,idy in top_indices])
                generated_sequence[k] = generated_sequence[k,beam_idx,:]
                generated_sequence[k,:,i] = [words[k,idx,idy] for idx,idy in top_indices]



                # extract completed sequences ending with the [EOS] / [STOP] / [END] token

                for b in range(beam_size):
                    if generated_sequence[k,b,i] == 2 and final_results_counter[k] < beam_size: # [EOS]/[STOP]/[END] token id, and the beam still has room
                        # print generated_sequence[k,b]
                        final_results[k,final_results_counter[k],:] = generated_sequence[k,b,:]
                        final_all_probs[k,final_results_counter[k]] = all_probs[k,b]
                        final_results_counter[k] += 1
                        all_probs[k,b] = 0.0 # suppress this sequence so it does not propagate further through the beam


            if np.sum(final_results_counter) == batch_size*beam_size: # stop once sufficient hypotheses (beam_size per example) have been collected
                # print "Encounter a case"
                break

            # reorder the collected LSTM states so they are indexed by beam again for the next step
            #print prev_state
            prev_state = [torch.stack(s,0) for s in prev_state]
            prev_state = torch.stack(prev_state,0)
            prev_state = torch.transpose(prev_state,0,1)
            tmp_state = torch.transpose(prev_state,1,2)
            prev_state = []
            for k in range(beam_size):
                prev_state.append(tuple((tmp_state[k,0,:,:].unsqueeze(0).contiguous(),tmp_state[k,1,:,:].unsqueeze(0).contiguous())))

            #print prev_state
            prev_words = np.transpose(np.array(prev_words),(1,0)) # set order [beam_size, batch_size]
            prev_words = Variable(torch.LongTensor(prev_words)).cuda()
            prev_hidden = np.transpose(np.array(prev_hidden),(1,0,2))
            prev_hidden = Variable(torch.FloatTensor(prev_hidden)).cuda()
            #print prev_hidden[0]
            #print prev_state[0]
            #print generated_sequence
            


        sampled_ids = []
        for k in range(batch_size):
            avg_log_probs = []
            for j in range(beam_size):
                try:
                    num_tokens = final_results[k,j,:].tolist().index(2)+1 # find the end token and derive the sequence length from it
                except ValueError:
                    num_tokens = 1 # fewer than beam_size complete hypotheses were found during the search, so this slot was never filled
                probs = np.log(final_all_probs[k][j])/num_tokens

                avg_log_probs.append(probs)
            avg_log_probs = np.array(avg_log_probs)
            sort_order = np.argsort(avg_log_probs)
            sort_order[:] = sort_order[::-1]
            sort_generated_sequence  = final_results[k,sort_order,:]
            sampled_ids.append(sort_generated_sequence[0])
            #print sort_generated_sequence


        return np.asarray(sampled_ids)
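
    # beam_search returns a numpy array of shape (batch_size, max_len): for each
    # video, the token ids of its highest-scoring completed hypothesis, ranked by
    # length-normalised log-probability; id 2 is taken to be the end-of-sequence token.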


    def rl_sample(self, video_features, flengths, max_len=20, sampling='multinomial'):
        """Samples captions token-by-token for RL training ('multinomial' or 'argmax')."""
        sampled_ids = []
        state = None
        outputs = []
        batch_size, max_enc_steps, _ = video_features.shape
        context_mask = rnn_mask(flengths,max_enc_steps)
        hidden_output = Variable(torch.FloatTensor(batch_size,self.hidden_size).zero_()).cuda()
        inputs = self.embed(Variable(torch.LongTensor(np.ones([batch_size,1]))).cuda()).squeeze(1)
        for i in range(max_len):  # maximum sampling length
            c_t, alpha = self.atten(hidden_output, video_features, context_mask)
            inp = torch.cat((inputs, c_t), 1).unsqueeze(1)
            # inp = self.input_proj(inp)
            hidden_output,state = self.lstm(inp,state)
            hidden_output = hidden_output.squeeze(1)
            output = self.linear(hidden_output)  # (batch_size, vocab_size)
            outputs.append(output)
            prob = F.softmax(output, 1)
            if sampling == 'multinomial':
                predicted = torch.multinomial(prob, 1)
                predicted = predicted.squeeze(1)
            elif sampling == 'argmax':
                predicted = prob.max(1)[1]
            else:
                raise ValueError('Unknown sampling strategy: %s' % sampling)
            sampled_ids.append(predicted.unsqueeze(1))
            inputs = self.embed(predicted)
        sampled_ids = torch.cat(sampled_ids, 1)  # (batch_size, max_len)
        outputs = torch.transpose(torch.stack(outputs), 0, 1) # max_len x batch_size x vocab_size -> batch_size x max_len x vocab_size
        return sampled_ids.squeeze(), outputs
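
    # rl_sample is presumably used for REINFORCE-style training: 'multinomial'
    # draws each token from the softmax distribution (exploration), while 'argmax'
    # reproduces greedy decoding; `outputs` keeps the per-step logits so the caller
    # can compute log-probabilities of the sampled ids.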



class Attention(nn.Module):
    def __init__(self, args, enc_dim, dec_dim, attn_dim=None):
        super(Attention, self).__init__()
        
        self.args = args
        self.enc_dim = enc_dim
        self.dec_dim = dec_dim
        self.attn_dim = self.dec_dim if attn_dim is None else attn_dim
        if self.args.birnn:
            self.birnn = 2
        else:
            self.birnn = 1

        self.encoder_in = nn.Linear(self.enc_dim, self.attn_dim, bias=True)
        self.decoder_in = nn.Linear(self.dec_dim, self.attn_dim, bias=False)
        self.attn_linear = nn.Linear(self.attn_dim, 1, bias=False)
        self.init_weights()
 

    def init_weights(self):
        self.encoder_in.weight.data.uniform_(-0.08, 0.08)
        self.encoder_in.bias.data.fill_(0)
        self.decoder_in.weight.data.uniform_(-0.08, 0.08)
        self.attn_linear.weight.data.uniform_(-0.08, 0.08)


    def forward(self, dec_state, enc_states, mask, dag=None):
        """
        :param dec_state: 
            decoder hidden state of size batch_size x dec_dim
        :param enc_states:
            all encoder hidden states of size batch_size x max_enc_steps x enc_dim
        :param flengths:
            encoder video frame lengths of size batch_size
        """
        dec_contrib = self.decoder_in(dec_state)
        batch_size, max_enc_steps, _  = enc_states.size()
        enc_contrib = self.encoder_in(enc_states.contiguous().view(-1, self.enc_dim)).contiguous().view(batch_size, max_enc_steps, self.attn_dim)
        pre_attn = F.tanh(enc_contrib + dec_contrib.unsqueeze(1).expand_as(enc_contrib))
       
        
        energy = self.attn_linear(pre_attn.view(-1, self.attn_dim)).view(batch_size, max_enc_steps)
        alpha = F.softmax(energy, 1)
        # mask alpha and renormalize it
        alpha = alpha * mask
        alpha = torch.div(alpha, alpha.sum(1).unsqueeze(1).expand_as(alpha))

        context_vector = torch.bmm(alpha.unsqueeze(1), enc_states).squeeze(1) # (batch_size, enc_dim)

        return context_vector, alpha
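
# Attention smoke-test sketch (illustrative only; the namespace and sizes below are
# made up, and Variables are used to match the rest of this file):
#   args = argparse.Namespace(birnn=False)
#   attn = Attention(args, enc_dim=512, dec_dim=512)
#   enc_states = Variable(torch.randn(2, 4, 512))   # batch_size x max_enc_steps x enc_dim
#   dec_state = Variable(torch.randn(2, 512))       # batch_size x dec_dim
#   mask = Variable(torch.FloatTensor([[1, 1, 1, 1], [1, 1, 0, 0]]))
#   context, alpha = attn(dec_state, enc_states, mask)
#   # context: (2, 512), alpha: (2, 4) with padded positions zeroed and re-normalised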