import torch
import torch.nn as nn
import torch.nn.functional as F


class PairwiseBilinear(nn.Module):
    """Pairwise bilinear map that avoids broadcasting the inputs (version in use).

    Every position of ``input1`` is scored against every position of
    ``input2`` using one matrix multiply followed by one batched matrix
    multiply.

    Input: tensors of sizes (N x L1 x D1) and (N x L2 x D2)
    Output: tensor of size (N x L1 x L2 x O)

    :param input1_size: feature size D1 of the first input.
    :param input2_size: feature size D2 of the second input.
    :param output_size: number of output scores O per (position, position) pair.
    :param bias: when true a zero-initialised ``bias`` parameter of size O is
        registered, otherwise ``self.bias`` is the plain integer ``0``.

    NOTE: ``bias`` is stored but never added in ``forward``; this matches the
    historical behaviour and keeps existing outputs unchanged.
    """

    def __init__(self, input1_size, input2_size, output_size, bias=True):
        super().__init__()

        self.input1_size = input1_size
        self.input2_size = input2_size
        self.output_size = output_size
        # W size: [(head_feat_size + 1), (dep_feat_size + 1), output_size]
        # Original author's note: output_size = 1 for unlabeled arc
        # classification, output_size = len(labels) for label classification.
        #
        # Parameters start at zero instead of uninitialised memory, so a freshly
        # built layer can never contain NaN/inf garbage.
        self.weight = nn.Parameter(torch.zeros(input1_size, input2_size, output_size))
        self.bias = nn.Parameter(torch.zeros(output_size)) if bias else 0

    def forward(self, input1, input2):
        """Return pairwise bilinear scores of size (N x L1 x L2 x O).

        :param input1: tensor of size (N x L1 x D1).
        :param input2: tensor of size (N x L2 x D2).
        """
        batch_size, len1 = input1.size(0), input1.size(1)
        len2, dim2 = input2.size(1), input2.size(2)

        # ((N x L1) x D1) * (D1 x (D2 x O)) -> (N x L1) x (D2 x O)
        # reshape (not view) so non-contiguous inputs are accepted as well.
        intermediate = torch.mm(
            input1.reshape(-1, input1.size(-1)),
            self.weight.view(-1, self.input2_size * self.output_size),
        )
        # (N x L2 x D2) -> (N x D2 x L2)
        input2 = input2.transpose(1, 2)
        # (N x (L1 x O) x D2) * (N x D2 x L2) -> (N x (L1 x O) x L2)
        # The flattened (D2 x O) axis is re-read here as (O x D2), so the
        # effective weight is ``self.weight.view(D1, O, D2)``. This is still a
        # bilinear form; the layout is kept because trained weights depend on it.
        output = intermediate.view(batch_size, len1 * self.output_size, dim2).bmm(input2)
        # (N x (L1 x O) x L2) -> (N x L1 x L2 x O)
        return output.view(batch_size, len1, self.output_size, len2).transpose(2, 3)


class BiaffineScorer(nn.Module):
    """Biaffine scorer backed by ``nn.Bilinear``.

    A constant-1 feature is appended to the last dimension of both inputs
    before they are passed to the bilinear layer, whose weight and bias are
    zeroed at construction.
    """

    def __init__(self, input1_size, input2_size, output_size):
        super().__init__()
        # Why +1: forward() appends a constant 1 to each input, so the bilinear
        # layer must accept one extra feature on both sides. With the appended
        # ones the bilinear form also carries the per-input linear terms.
        bilinear = nn.Bilinear(input1_size + 1, input2_size + 1, output_size)
        bilinear.weight.data.zero_()
        bilinear.bias.data.zero_()
        self.W_bilin = bilinear

    def forward(self, input1, input2):
        """Append a ones column to each input and apply the bilinear layer."""
        # Ones tensor has the input's leading dimensions and a trailing size of 1,
        # e.g. [batch_size, seq_len, 1] for a [batch_size, seq_len, feature_size] input.
        ones1 = input1.new_ones(*input1.shape[:-1], 1)
        augmented1 = torch.cat((input1, ones1), dim=-1)
        # After concatenation: [..., feature_size + 1]
        ones2 = input2.new_ones(*input2.shape[:-1], 1)
        augmented2 = torch.cat((input2, ones2), dim=-1)
        return self.W_bilin(augmented1, augmented2)


class PairwiseBiaffineScorer(nn.Module):
    def __init__(self, input1_size, input2_size, output_size):
        """Build a pairwise biaffine scorer (version in use).

        :param input1_size: feature size of the first input.
        :param input2_size: feature size of the second input.
        :param output_size: size of the biaffine classification space.
        """
        super().__init__()
        # Why +1: forward() appends a constant 1 to each input, so the wrapped
        # bilinear layer sees one extra feature on both sides.
        # Shape walk-through of the wrapped PairwiseBilinear, per its comments:
        # [(batch_size*seq_len), (head_feat_size+1)] * [(head_feat_size+1), (dep_feat_size+1)*output_size]
        #       mm   -> [(batch_size*seq_len), (dep_feat_size+1)*output_size]
        #       view -> [batch_size, (seq_len*output_size), (dep_feat_size+1)]
        # [batch_size, (seq_len*output_size), (dep_feat_size+1)] * [batch_size, (dep_feat_size+1), seq_len]
        #       bmm  -> [batch_size, (seq_len*output_size), seq_len]
        #       view -> [batch_size, seq_len, seq_len, output_size]
        pairwise = PairwiseBilinear(input1_size + 1, input2_size + 1, output_size)
        pairwise.weight.data.zero_()
        pairwise.bias.data.zero_()
        self.W_bilin = pairwise

    def forward(self, input1, input2):
        """Append a ones column to each input and apply the pairwise bilinear layer."""
        # For an input of size [batch_size, seq_len, feature_size] the ones tensor
        # is [batch_size, seq_len, 1]; the concatenation along the last dimension
        # gives [batch_size, seq_len, feature_size + 1].
        augmented = [
            torch.cat([tensor, tensor.new_ones(*tensor.size()[:-1], 1)], tensor.dim() - 1)
            for tensor in (input1, input2)
        ]
        return self.W_bilin(*augmented)


class DirectBiaffineScorer(nn.Module):
    """Biaffine scorer applied directly to the inputs.

    Wraps ``PairwiseBiaffineScorer`` when ``pairwise`` is true and
    ``BiaffineScorer`` otherwise; ``forward`` simply delegates to it.
    """

    def __init__(self, input1_size, input2_size, output_size, pairwise=True):
        super().__init__()
        scorer_cls = PairwiseBiaffineScorer if pairwise else BiaffineScorer
        self.scorer = scorer_cls(input1_size, input2_size, output_size)

    def forward(self, input1, input2):
        """Return the wrapped scorer's output for ``input1`` and ``input2``."""
        scores = self.scorer(input1, input2)
        return scores


class DeepBiaffineScorer(nn.Module):
    def __init__(self, input1_size, input2_size, hidden_size, output_size, hidden_func=F.relu, dropout=0,
                 pairwise=True):
        """Build a biaffine scorer with a hidden projection per input (version in use).

        :param input1_size: feature size of the first input.
        :param input2_size: feature size of the second input.
        :param hidden_size: output size of both linear projections.
        :param output_size: size of the biaffine classification space.
        :param hidden_func: activation applied after each projection (ReLU by default).
        :param dropout: dropout probability applied right before the biaffine scorer.
        :param pairwise: use ``PairwiseBiaffineScorer`` when true, ``BiaffineScorer`` otherwise.
        """
        super().__init__()
        # Two independent linear projections produce the hidden representations
        # of the two inputs (H_head / H_dep in the original author's notes).
        self.W1 = nn.Linear(input1_size, hidden_size)
        self.W2 = nn.Linear(input2_size, hidden_size)
        # Activation applied to the projections; ReLU unless overridden.
        self.hidden_func = hidden_func
        scorer_cls = PairwiseBiaffineScorer if pairwise else BiaffineScorer
        self.scorer = scorer_cls(hidden_size, hidden_size, output_size)
        # Dropout applied to the hidden representations before scoring.
        self.dropout = nn.Dropout(dropout)

    def forward(self, input1, input2):
        """Project, activate and drop out each input, then score the pair."""
        hidden1 = self.dropout(self.hidden_func(self.W1(input1)))
        hidden2 = self.dropout(self.hidden_func(self.W2(input2)))
        return self.scorer(hidden1, hidden2)


if __name__ == "__main__":
    # Smoke run: score two random inputs of feature sizes 4 and 5 through a
    # DeepBiaffineScorer with hidden size 6 and output size 7.
    first_input = torch.randn(2, 3, 4)
    second_input = torch.randn(2, 3, 5)
    model = DeepBiaffineScorer(4, 5, 6, 7)
    # Print the raw scores, then the size of a second forward pass.
    print(model(first_input, second_input))
    scores = model(first_input, second_input)
    print(scores.size())