import functools
import math
import numpy as np
import torch
from torch import nn
from torch.autograd import Variable
import torch.optim as optim
import torch.nn.functional as F



class biDafAttn(nn.Module):
    def __init__(self, channel_size):
        """
        BiDAF-style attention from s2 to s1:
            the return value has the same size as s1.
        :param channel_size: Hidden size of the input (currently unused).
        """
        super(biDafAttn, self).__init__()
        # self.mlp = nn.Linear(channel_size * 3, 1, bias=False)

    def similarity(self, s1, l1, s2, l2):
        """
        Compute the similarity matrix S from the BiDAF paper, masking positions
        beyond the valid lengths with -inf.
        :param s1: [B, t1, D]
        :param l1: [B]
        :param s2: [B, t2, D]
        :param l2: [B]
        :return: S with shape [B, t1, t2]
        """
        # [B, t1, D] x [B, D, t2] -> [B, t1, t2]
        S = torch.bmm(s1, s2.transpose(1, 2))

        # Build the mask from the lengths: 1 marks positions to be filled with -inf.
        s_mask = S.data.new(*S.size()).fill_(1).byte()  # [B, t1, t2]
        for i, (l_1, l_2) in enumerate(zip(l1, l2)):
            s_mask[i][:l_1, :l_2] = 0

        s_mask = Variable(s_mask)
        S.data.masked_fill_(s_mask.data, -math.inf)
        return S

    def get_U_tile(self, S, s2):
        a_weight = F.softmax(S, dim=2)  # [B, t1, t2]
        a_weight.data.masked_fill_(a_weight.data != a_weight.data, 0)  # remove nan from softmax on -inf
        U_tile = torch.bmm(a_weight, s2)  # [B, t1, t2] * [B, t2, D] -> [B, t1, D]
        return U_tile

    def get_both_tile(self, S, s1, s2):
        a_weight = F.softmax(S, dim=2)  # [B, t1, t2]
        a_weight.data.masked_fill_(a_weight.data != a_weight.data, 0)  # remove nan from softmax on -inf
        U_tile = torch.bmm(a_weight, s2)  # [B, t1, t2] * [B, t2, D] -> [B, t1, D]

        a1_weight = F.softmax(S, dim=1)  # [B, t1, t2]
        a1_weight.data.masked_fill_(a1_weight.data != a1_weight.data, 0)  # remove nan from softmax on -inf
        U1_tile = torch.bmm(a1_weight.transpose(1, 2), s1)  # [B, t2, t1] * [B, t1, D] -> [B, t2, D]
        return U_tile, U1_tile

    def forward(self, s1, l1, s2, l2):
        S = self.similarity(s1, l1, s2, l2)
        U_tile = self.get_U_tile(S, s2)
        return U_tile
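
# Usage sketch for biDafAttn (illustrative only; shapes and lengths are made up):
#
#     attn = biDafAttn(channel_size=4)       # channel_size is currently unused
#     s1 = torch.randn(2, 5, 4)              # [B=2, t1=5, D=4]
#     s2 = torch.randn(2, 7, 4)              # [B=2, t2=7, D=4]
#     l1, l2 = [5, 3], [7, 4]                # valid lengths per example
#     u_tile = attn(s1, l1, s2, l2)          # [B, t1, D]: s2 summarized for every s1 position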


class CoattMaxPool(nn.Module):
    def __init__(self, args):
        super(CoattMaxPool, self).__init__()
        h_size = [300, 300]
        d = 300
        mlp_d = 300
        v_size = args.max_snli_vocab_size
        max_l = None
        num_of_class = 3
        drop_r = args.dropout
        featurizer = None
        itos = None
        with_emlo = False
        activation_type = 'relu'

        self.h_size = h_size
        self.e_embd = nn.Embedding(v_size, d)
        self.embd_dropout = nn.Dropout(drop_r)
        self.featurizer = featurizer
        self.itos = itos
        self.args = args
        if self.featurizer is not None:
            fcount = self.featurizer.n_context_features()
        else:
            fcount = 0

        self.emlo_embedding_d = 0
        if with_emlo:
            # Note: this branch requires an ELMo embedder (e.g. allennlp's ElmoEmbedder) and a
            # device id; neither is imported or defined here, so with_emlo must stay False.
            self.emlo_ee = ElmoEmbedder(cuda_device=n_device)
            self.emlo_embedding_d = 1024

        self.emlo_gamma = nn.Parameter(torch.FloatTensor([1]))
        self.emlo_s_vector = nn.Parameter(torch.FloatTensor([1, 1, 1]))



        if self.args.cell_type == 'gru':
            self.lstm = nn.GRU(input_size=d + fcount + self.emlo_embedding_d * 1, hidden_size=h_size[0],
                        num_layers=1, bidirectional=True, batch_first=True)

            self.lstm_1 = nn.GRU(input_size=h_size[1] + fcount + self.emlo_embedding_d * 1, hidden_size=h_size[1],
                          num_layers=1, bidirectional=True, batch_first=True)

        else:
            self.lstm = nn.LSTM(input_size=d + fcount + self.emlo_embedding_d * 1, hidden_size=h_size[0],
                        num_layers=1, bidirectional=True, batch_first=True)

            self.lstm_1 = nn.LSTM(input_size=h_size[1] + fcount + self.emlo_embedding_d * 1, hidden_size=h_size[1],
                          num_layers=1, bidirectional=True, batch_first=True)

        self.projection = nn.Linear(h_size[0] * 2 * 4, h_size[1])
        self.projection_dropout = nn.Dropout(drop_r)

        self.max_l = max_l
        self.bidaf = biDafAttn(300)

        self.mlp_1 = nn.Linear(h_size[1] * 2 * 4, mlp_d)
        self.sm = nn.Linear(mlp_d, num_of_class)
        if activation_type == 'relu':
            activation = nn.ReLU()
            # self.classifier = nn.Sequential(*[self.mlp_1, nn.ReLU(), nn.Dropout(drop_r), self.sm])
        elif activation_type == 'tanh':
            activation = nn.Tanh()
        else:
            raise ValueError("Not a valid activation!")

        self.classifier = nn.Sequential(*[nn.Dropout(drop_r), self.mlp_1, activation, nn.Dropout(drop_r), self.sm])





    def count_params(self):
        # Note: only 2-D weight matrices are counted; biases and other 1-D parameters are skipped.
        total_c = 0
        for param in self.parameters():
            if param.dim() == 2:
                d1, d2 = param.size()
                total_c += d1 * d2
        print("Total count:", total_c)

    def display(self):
        for param in self.parameters():
            print(param.data.size())

    def forward(self, s1, l1, s2, l2):  # [B, T]
        if self.max_l:
            max_l = min(s1.size(1), self.max_l)
            max_l = max(1, max_l)
            max_s1_l = min(max(l1), max_l)

            l1 = l1.clamp(min=1, max=max_s1_l)
            if s1.size(1) > max_s1_l:
                s1 = s1[:, :max_s1_l]

        s1_max_l = s1.size(1)

        if self.max_l:
            max_l = min(s2.size(1), self.max_l)
            max_l = max(1, max_l)
            max_s2_l = min(max(l2), max_l)

            l2 = l2.clamp(min=1, max=max_s2_l)
            if s2.size(1) > max_s2_l:
                s2 = s2[:, :max_s2_l]

        s2_max_l = s2.size(1)

        batch_size = s1.size(0)

        th_packed_f_s1, th_packed_f_s2 = None, None
        emlo_s1_sum, emlo_s2_sum = None, None



        p_s1 = self.e_embd(s1)
        p_s2 = self.e_embd(s2)

        p_s1 = self.embd_dropout(p_s1)  # Embedding dropout
        p_s2 = self.embd_dropout(p_s2)  # Embedding dropout

        feature_p_s1 = torch.cat([seq for seq in [p_s1, th_packed_f_s1, emlo_s1_sum] if seq is not None], dim=2)
        feature_p_s2 = torch.cat([seq for seq in [p_s2, th_packed_f_s2, emlo_s2_sum] if seq is not None], dim=2)

        s1_layer1_out = self.auto_rnn(self.lstm, feature_p_s1, l1)
        s2_layer1_out = self.auto_rnn(self.lstm, feature_p_s2, l2)

        S = self.bidaf.similarity(s1_layer1_out, l1, s2_layer1_out, l2)
        s1_att, s2_att = self.bidaf.get_both_tile(S, s1_layer1_out, s2_layer1_out)



        s1_coattentioned = torch.cat([s1_layer1_out, s1_att, s1_layer1_out - s1_att,
                                      s1_layer1_out * s1_att], dim=2)

        s2_coattentioned = torch.cat([s2_layer1_out, s2_att, s2_layer1_out - s2_att,
                                      s2_layer1_out * s2_att], dim=2)

        p_s1_coattentioned = self.projection_dropout(F.relu(self.projection(s1_coattentioned)))
        p_s2_coattentioned = self.projection_dropout(F.relu(self.projection(s2_coattentioned)))

        s1_coatt_features = torch.cat(
            [seq for seq in [p_s1_coattentioned, th_packed_f_s1, emlo_s1_sum] if seq is not None], dim=2)
        s2_coatt_features = torch.cat(
            [seq for seq in [p_s2_coattentioned, th_packed_f_s2, emlo_s2_sum] if seq is not None], dim=2)

        s1_layer2_out = self.auto_rnn(self.lstm_1, s1_coatt_features, l1)
        s2_layer2_out = self.auto_rnn(self.lstm_1, s2_coatt_features, l2)

        s1_lay2_maxout = max_along_time(s1_layer2_out, l1)
        s2_lay2_maxout = max_along_time(s2_layer2_out, l2)

        s1_lay2_avgout = avg_along_time(s1_layer2_out, l1)
        s2_lay2_avgout = avg_along_time(s2_layer2_out, l2)

        features = torch.cat([s1_lay2_maxout, s2_lay2_maxout,
                              s1_lay2_avgout, s2_lay2_avgout], dim=1)

        logits = self.classifier(features)
        probs = F.softmax(logits, 1)
        pred = torch.max(probs, 1)[1]

        return logits, probs, pred


    
    def auto_rnn(self, rnn: nn.RNN, seqs, lengths, batch_first=True, init_state=None, output_last_states=False):
        batch_size = seqs.size(0) if batch_first else seqs.size(1)
        state_shape = get_state_shape(rnn, batch_size, rnn.bidirectional)

        if not init_state:
            h0 = c0 = Variable(seqs.data.new(*state_shape).zero_())
        else:
            h0 = init_state['h0'].expand(state_shape)
            c0 = init_state['c0'].expand(state_shape)

        packed_pinputs, r_index = pack_for_rnn_seq(seqs, lengths, batch_first)
        if self.args.cell_type == 'gru':
            output, hn = rnn(packed_pinputs, h0)
            last_states = hn
        else:
            output, (hn, cn) = rnn(packed_pinputs, (h0, c0))
            last_states = (hn, cn)
        output = unpack_from_rnn_seq(output, r_index, batch_first)

        if not output_last_states:
            return output
        else:
            return output, last_states
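
# Usage sketch for CoattMaxPool (illustrative only; `args` is a hypothetical
# namespace carrying just the fields read in __init__ and auto_rnn):
#
#     from argparse import Namespace
#     args = Namespace(max_snli_vocab_size=10000, dropout=0.1, cell_type='lstm')
#     model = CoattMaxPool(args)
#     s1 = torch.randint(0, 10000, (2, 6))   # [B, T1] token ids
#     s2 = torch.randint(0, 10000, (2, 8))   # [B, T2] token ids
#     l1, l2 = [6, 4], [8, 5]                # valid lengths per example
#     logits, probs, pred = model(s1, l1, s2, l2)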




def pad_1d(seq, pad_l):
    """
    Pad (or truncate) a single, unbatched sequence to a fixed length.

    :param seq: Input sequence with shape [T, ...]
    :param pad_l: The required padded length.
    :return: Output sequence with shape [pad_l, ...]
    """
    l = seq.size(0)
    if l >= pad_l:
        return seq[:pad_l]  # Truncate if the sequence is longer than the required padded length.
    else:
        pad_seq = Variable(seq.data.new(pad_l - l, *seq.size()[1:]).zero_())  # requires_grad is False
        return torch.cat([seq, pad_seq], dim=0)
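
# Example (sketch): padding or truncating a single unbatched sequence with pad_1d.
#
#     x = torch.randn(3, 5)                  # [T=3, D=5]
#     pad_1d(x, 6).size()                    # torch.Size([6, 5]) -- zero-padded at the end
#     pad_1d(x, 2).size()                    # torch.Size([2, 5]) -- truncated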


def pad(seqs, length, batch_first=True):
    # TODO: The method seems useless to me. Delete?
    """
    Pad (or truncate) a batch of sequences to a fixed length along the time dimension.

    :param seqs: [B, T, D] or [B, T] if batch_first else [T, B, D] or [T, B]
    :param length: target length T'
    :param batch_first:
    :return:
    """
    if batch_first:
        # [B, T, D]
        if length <= seqs.size(1):
            return seqs[:, :length]
        else:
            batch_size = seqs.size(0)
            pad_seq = Variable(seqs.data.new(batch_size, length - seqs.size(1), *seqs.size()[2:]).zero_())
            # [B, T, D]
            return torch.cat([seqs, pad_seq], dim=1)
    else:
        # [T, B, D]
        if length <= seqs.size(0):
            return seqs[:length]
        else:
            return torch.cat([seqs, Variable(seqs.data.new(length - seqs.size(0), *seqs.size()[1:]).zero_())])


def batch_first2time_first(inputs):
    """
    Convert input from batch_first to time_first:
    [B, T, D] -> [T, B, D]
    
    :param inputs:
    :return:
    """
    return torch.transpose(inputs, 0, 1)


def time_first2batch_first(inputs):
    """
    Convert input from time_first to batch_first:
    [T, B, D] -> [B, T, D]

    :param inputs:
    :return:
    """
    return torch.transpose(inputs, 0, 1)


def get_state_shape(rnn: nn.RNN, batch_size, bidirectional=False):
    """
    Return the state shape of a given RNN. This is helpful when you want to create a init state for RNN.
    Example:
    c0 = h0 = Variable(src_seq_p.data.new(*get_state_shape([your rnn], 3, bidirectional)).zero_())
    
    :param rnn: nn.LSTM, nn.GRU or subclass of nn.RNN
    :param batch_size:  
    :param bidirectional:  
    :return: 
    """
    if bidirectional:
        return rnn.num_layers * 2, batch_size, rnn.hidden_size
    else:
        return rnn.num_layers, batch_size, rnn.hidden_size
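
# Example (sketch): building the state shape for a 1-layer bidirectional LSTM.
#
#     rnn = nn.LSTM(input_size=8, hidden_size=16, num_layers=1,
#                   bidirectional=True, batch_first=True)
#     get_state_shape(rnn, batch_size=4, bidirectional=rnn.bidirectional)
#     # -> (2, 4, 16), i.e. (num_layers * num_directions, B, hidden_size)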


def pack_list_sequence(inputs, l, max_l=None, batch_first=True):
    """
    Pack a batch of Tensor into one Tensor with max_length.
    :param inputs: 
    :param l: 
    :param max_l: The max_length of the packed sequence.
    :param batch_first: 
    :return: 
    """
    batch_list = []
    max_l = max(list(l)) if not max_l else max_l
    batch_size = len(inputs)

    for b_i in range(batch_size):
        batch_list.append(pad_1d(inputs[b_i], max_l))
    pack_batch_list = torch.stack(batch_list, dim=1) if not batch_first \
        else torch.stack(batch_list, dim=0)
    return pack_batch_list
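
# Example (sketch): stacking variable-length sequences into one padded batch tensor.
#
#     seqs = [torch.randn(4, 3), torch.randn(2, 3)]   # two [T_i, D] tensors
#     pack_list_sequence(seqs, l=[4, 2]).size()       # torch.Size([2, 4, 3])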


def pack_for_rnn_seq(inputs, lengths, batch_first=True):
    """
    Sort a padded batch by decreasing length and pack it for an RNN.

    :param inputs: [B, T, D] if batch_first else [T, B, D]
    :param lengths: [B]
    :param batch_first:
    :return: (packed_seq, reverse_indices); reverse_indices restores the original
             batch order after unpacking.
    """
    if not batch_first:
        _, sorted_indices = lengths.sort()
        # Reverse to decreasing order of length (required by pack_padded_sequence).
        r_index = reversed(list(sorted_indices))

        s_inputs_list = []
        lengths_list = []
        reverse_indices = np.zeros(lengths.size(0), dtype=np.int64)

        for j, i in enumerate(r_index):
            s_inputs_list.append(inputs[:, i, :].unsqueeze(1))
            lengths_list.append(lengths[i])
            reverse_indices[i] = j

        reverse_indices = list(reverse_indices)

        s_inputs = torch.cat(s_inputs_list, 1)
        packed_seq = nn.utils.rnn.pack_padded_sequence(s_inputs, lengths_list)

        return packed_seq, reverse_indices

    else:
        # Reverse to decreasing order of length (required by pack_padded_sequence).
        r_index = reversed(list(np.argsort(lengths)))

        s_inputs_list = []
        lengths_list = []
        reverse_indices = np.zeros(len(lengths), dtype=np.int64)

        for j, i in enumerate(r_index):
            s_inputs_list.append(inputs[i, :, :])
            lengths_list.append(lengths[i])
            reverse_indices[i] = j

        reverse_indices = list(reverse_indices)

        s_inputs = torch.stack(s_inputs_list, dim=0)
        packed_seq = nn.utils.rnn.pack_padded_sequence(s_inputs, lengths_list, batch_first=batch_first)

        return packed_seq, reverse_indices


def unpack_from_rnn_seq(packed_seq, reverse_indices, batch_first=True):
    """
    Unpack an RNN output and restore the original batch order using the
    reverse_indices returned by pack_for_rnn_seq.
    """
    unpacked_seq, _ = nn.utils.rnn.pad_packed_sequence(packed_seq, batch_first=batch_first)
    s_inputs_list = []

    if not batch_first:
        for i in reverse_indices:
            s_inputs_list.append(unpacked_seq[:, i, :].unsqueeze(1))
        return torch.cat(s_inputs_list, 1)
    else:
        for i in reverse_indices:
            s_inputs_list.append(unpacked_seq[i, :, :].unsqueeze(0))
        return torch.cat(s_inputs_list, 0)
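
# Example (sketch): pack_for_rnn_seq / unpack_from_rnn_seq round trip.  Packing sorts
# the batch by decreasing length for pack_padded_sequence; the returned reverse
# indices let unpack_from_rnn_seq restore the original batch order afterwards.
#
#     x = torch.randn(3, 5, 8)                        # [B, T, D], already padded
#     lengths = [5, 2, 4]
#     packed, r_index = pack_for_rnn_seq(x, lengths)
#     rnn = nn.GRU(input_size=8, hidden_size=6, batch_first=True)
#     out, _ = rnn(packed)
#     out = unpack_from_rnn_seq(out, r_index)         # [B, T', D'] in the original order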


def auto_rnn(rnn: nn.RNN, seqs, lengths, batch_first=True, init_state=None, output_last_states=False):
    # Note: this module-level helper assumes an LSTM-style (h, c) state tuple.
    batch_size = seqs.size(0) if batch_first else seqs.size(1)
    state_shape = get_state_shape(rnn, batch_size, rnn.bidirectional)

    if not init_state:
        h0 = c0 = Variable(seqs.data.new(*state_shape).zero_())
    else:
        h0 = init_state['h0'].expand(state_shape)
        c0 = init_state['c0'].expand(state_shape)

    packed_pinputs, r_index = pack_for_rnn_seq(seqs, lengths, batch_first)
    output, (hn, cn) = rnn(packed_pinputs, (h0, c0))
    output = unpack_from_rnn_seq(output, r_index, batch_first)

    if not output_last_states:
        return output
    else:
        return output, (hn, cn)
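
# Example (sketch): running a bidirectional LSTM over a padded batch with the
# module-level auto_rnn helper.
#
#     rnn = nn.LSTM(input_size=8, hidden_size=6, num_layers=1,
#                   bidirectional=True, batch_first=True)
#     x = torch.randn(3, 5, 8)                        # [B, T, D]
#     out = auto_rnn(rnn, x, lengths=[5, 2, 4])       # [B, T', 2 * hidden_size]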


def pack_sequence_for_linear(inputs, lengths, batch_first=True):
    """
    :param inputs: [B, T, D] if batch_first 
    :param lengths:  [B]
    :param batch_first:  
    :return: 
    """
    batch_list = []
    if batch_first:
        for i, l in enumerate(lengths):
            batch_list.append(inputs[i, :l])
        packed_sequence = torch.cat(batch_list, 0)
        return packed_sequence
    else:
        raise NotImplementedError()
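
# Example (sketch): flattening only the valid timesteps so a Linear layer can be
# applied without touching the padded positions.
#
#     x = torch.randn(2, 5, 8)                        # [B, T, D]
#     flat = pack_sequence_for_linear(x, [5, 3])      # [sum(lengths) = 8, D]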


def chucked_forward(inputs, net, chuck=None):
    # "chuck" here means "chunk": optionally apply `net` over the input in chunks.
    if not chuck:
        return net(inputs)
    else:
        output_list = [net(chunk) for chunk in torch.chunk(inputs, chuck, dim=0)]
        return torch.cat(output_list, dim=0)
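
# Example (sketch): applying a module over a flattened batch in chunks to bound
# peak memory; the result matches a single full-batch call.
#
#     net = nn.Linear(8, 3)
#     flat = torch.randn(8, 8)                        # e.g. output of pack_sequence_for_linear
#     out = chucked_forward(flat, net, chuck=4)       # same values as net(flat)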


def unpack_sequence_for_linear(inputs, lengths, batch_first=True):
    batch_list = []
    max_l = max(lengths)

    if not isinstance(inputs, list):
        inputs = [inputs]
    inputs = torch.cat(inputs)

    if batch_first:
        start = 0
        for l in lengths:
            end = start + l
            batch_list.append(pad_1d(inputs[start:end], max_l))
            start = end
        return torch.stack(batch_list)
    else:
        raise NotImplementedError()
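
# Example (sketch): inverse of pack_sequence_for_linear -- re-padding the flat
# per-timestep outputs back into [B, T, D'] form.
#
#     x = torch.randn(2, 5, 8)                        # [B, T, D]
#     flat = pack_sequence_for_linear(x, [5, 3])      # [8, 8]
#     out = nn.Linear(8, 3)(flat)                     # [8, 3]
#     unpack_sequence_for_linear(out, [5, 3]).size()  # torch.Size([2, 5, 3])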


def seq2seq_cross_entropy(logits, label, l, chuck=None, sos_truncate=True):
    """
    :param logits: [exB, V] : exB = sum(l)
    :param label: [B, T] : a batch of padded target indices
    :param l: [B] : a batch of LongTensor indicating the lengths of each input
    :param chuck: Number of chunks to process
    :param sos_truncate: (unused)
    :return: A loss value
    """
    packed_label = pack_sequence_for_linear(label, l)
    cross_entropy_loss = functools.partial(F.cross_entropy, size_average=False)
    total = sum(l)

    assert total == logits.size(0) or packed_label.size(0) == logits.size(0), \
        "logits length mismatch with label length."

    if chuck:
        logits_losses = 0
        for x, y in zip(torch.chunk(logits, chuck, dim=0), torch.chunk(packed_label, chuck, dim=0)):
            logits_losses += cross_entropy_loss(x, y)
        return logits_losses * (1 / total)
    else:
        return cross_entropy_loss(logits, packed_label) * (1 / total)
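
# Example (sketch): token-level cross entropy over the valid positions only; the
# padded [B, T] labels are flattened with pack_sequence_for_linear before the loss.
#
#     logits = torch.randn(8, 100)                    # [sum(l), V]
#     label = torch.randint(0, 100, (2, 5))           # [B, T] padded targets
#     loss = seq2seq_cross_entropy(logits, label, l=[5, 3])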


def avg_along_time(inputs, lengths, list_in=False):
    """
    Average over the valid timesteps of each sequence.

    :param inputs: [B, T, D]
    :param lengths: [B]
    :param list_in: if True, each inputs[i] is used in full (no length slicing)
    :return: [B, D]
    """
    ls = list(lengths)

    if not list_in:
        b_seq_avg_list = []
        for i, l in enumerate(ls):
            seq_i = inputs[i, :l, :]
            seq_i_avg = seq_i.mean(dim=0)
            seq_i_avg = seq_i_avg.squeeze()
            b_seq_avg_list.append(seq_i_avg)

        return torch.stack(b_seq_avg_list)
    else:
        b_seq_avg_list = []
        for i, l in enumerate(ls):
            seq_i = inputs[i]
            seq_i_avg = seq_i.mean(dim=0)
            seq_i_avg = seq_i_avg.squeeze()
            b_seq_avg_list.append(seq_i_avg)

        return torch.stack(b_seq_avg_list)



def max_along_time(inputs, lengths, list_in=False):
    """
    Max-pool over the valid timesteps of each sequence.

    :param inputs: [B, T, D]
    :param lengths: [B]
    :param list_in: if True, each inputs[i] is used in full (no length slicing)
    :return: [B, D]
    """
    ls = list(lengths)

    if not list_in:
        b_seq_max_list = []
        for i, l in enumerate(ls):
            seq_i = inputs[i, :l, :]
            seq_i_max, _ = seq_i.max(dim=0)
            seq_i_max = seq_i_max.squeeze()
            b_seq_max_list.append(seq_i_max)

        return torch.stack(b_seq_max_list)
    else:
        b_seq_max_list = []
        for i, l in enumerate(ls):
            seq_i = inputs[i]
            seq_i_max, _ = seq_i.max(dim=0)
            seq_i_max = seq_i_max.squeeze()
            b_seq_max_list.append(seq_i_max)

        return torch.stack(b_seq_max_list)
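
# Example (sketch): pooling over the valid timesteps only; both helpers slice each
# sequence to its length before reducing over time.
#
#     x = torch.randn(2, 5, 8)                        # [B, T, D]
#     max_along_time(x, [5, 3]).size()                # torch.Size([2, 8])
#     avg_along_time(x, [5, 3]).size()                # torch.Size([2, 8])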