__author__ = 'max'

from overrides import overrides
from collections import OrderedDict
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter


class BiLinear(nn.Module):
    """
    Bi-linear layer
    """
    def __init__(self, left_features, right_features, out_features, bias=True):
        """

        Args:
            left_features: size of left input
            right_features: size of right input
            out_features: size of output
            bias: If set to False, the layer will not learn an additive bias.
                Default: True
        """
        super(BiLinear, self).__init__()
        self.left_features = left_features
        self.right_features = right_features
        self.out_features = out_features

        self.U = Parameter(torch.Tensor(self.out_features, self.left_features, self.right_features))
        self.weight_left = Parameter(torch.Tensor(self.out_features, self.left_features))
        self.weight_right = Parameter(torch.Tensor(self.out_features, self.right_features))

        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.weight_left)
        nn.init.xavier_uniform_(self.weight_right)
        nn.init.xavier_uniform_(self.U)
        # guard against bias=False, where self.bias is None
        if self.bias is not None:
            nn.init.constant_(self.bias, 0.)

    def forward(self, input_left, input_right):
        """

        Args:
            input_left: Tensor
                the left input tensor with shape = [batch1, batch2, ..., left_features]
            input_right: Tensor
                the right input tensor with shape = [batch1, batch2, ..., right_features]

        Returns: Tensor
            the output tensor with shape = [batch1, batch2, ..., out_features]

        """

        batch_size = input_left.size()[:-1]
        batch = int(np.prod(batch_size))

        # convert left and right input to matrices [batch, left_features], [batch, right_features]
        input_left = input_left.view(batch, self.left_features)
        input_right = input_right.view(batch, self.right_features)

        # output [batch, out_features]
        output = F.bilinear(input_left, input_right, self.U, self.bias)
        output = output + F.linear(input_left, self.weight_left, None) + F.linear(input_right, self.weight_right, None)
        # convert back to [batch1, batch2, ..., out_features]
        return output.view(batch_size + (self.out_features, ))

    def __repr__(self):
        return '{} (left_features={}, right_features={}, out_features={})'.format(
            self.__class__.__name__, self.left_features, self.right_features, self.out_features)
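
# A minimal usage sketch for BiLinear; the sizes below are illustrative
# assumptions, not part of the original module:
#   layer = BiLinear(left_features=128, right_features=128, out_features=64)
#   left = torch.randn(16, 10, 128)    # [batch, length, left_features]
#   right = torch.randn(16, 10, 128)   # [batch, length, right_features]
#   out = layer(left, right)           # out.size() == (16, 10, 64)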


class BiAffine(nn.Module):
    """
    Bi-Affine energy layer: computes energy(q, k) = q^T U k + w_q^T q + w_k^T k + b
    for every query/key position pair.
    """

    def __init__(self, key_dim, query_dim):
        """

        Args:
            key_dim: int
                the dimension of the key.
            query_dim: int
                the dimension of the query.

        """
        super(BiAffine, self).__init__()
        self.key_dim = key_dim
        self.query_dim = query_dim

        self.q_weight = Parameter(torch.Tensor(self.query_dim))
        self.key_weight = Parameter(torch.Tensor(self.key_dim))
        self.b = Parameter(torch.Tensor(1))
        self.U = Parameter(torch.Tensor(self.query_dim, self.key_dim))
        self.reset_parameters()

    def reset_parameters(self):
        bound = 1 / math.sqrt(self.query_dim)
        nn.init.uniform_(self.q_weight, -bound, bound)
        bound = 1 / math.sqrt(self.key_dim)
        nn.init.uniform_(self.key_weight, -bound, bound)
        nn.init.constant_(self.b, 0.)
        nn.init.xavier_uniform_(self.U)

    def forward(self, query, key, mask_query=None, mask_key=None):
        """

        Args:
            query: Tensor
                the query input tensor with shape = [batch, length_query, query_dim]
            key: Tensor
                the key input tensor with shape = [batch, length_key, key_dim]
            mask_query: Tensor or None
                the mask tensor for the query with shape = [batch, length_query]
            mask_key: Tensor or None
                the mask tensor for the key with shape = [batch, length_key]

        Returns: Tensor
            the energy tensor with shape = [batch, length_query, length_key]

        """
        # output shape [batch, length_query, length_key]
        # compute bi-affine part
        # [batch, length_query, query_dim] * [query_dim, key_dim]
        # output shape [batch, length_query, key_dim]
        output = torch.matmul(query, self.U)
        # [batch, length_query, key_dim] * [batch, key_dim, length_key]
        # output shape [batch, length_query, length_key]
        output = torch.matmul(output, key.transpose(1, 2))

        # compute query part: [query_dim] * [batch, query_dim, length_query]
        # the output shape is [batch, length_query, 1]
        out_q = torch.matmul(self.q_weight, query.transpose(1, 2)).unsqueeze(2)
        # compute key part: [key_dim] * [batch, key_dim, length_key]
        # the output shape is [batch, 1, length_key]
        out_k = torch.matmul(self.key_weight, key.transpose(1, 2)).unsqueeze(1)

        output = output + out_q + out_k + self.b

        if mask_query is not None:
            output = output * mask_query.unsqueeze(2)
        if mask_key is not None:
            output = output * mask_key.unsqueeze(1)
        return output

    @overrides
    def extra_repr(self):
        s = '{key_dim}, {query_dim}'
        return s.format(**self.__dict__)
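
# A minimal usage sketch for BiAffine; the sizes below are illustrative
# assumptions, not part of the original module:
#   attn = BiAffine(key_dim=256, query_dim=256)
#   query = torch.randn(8, 12, 256)    # [batch, length_query, query_dim]
#   key = torch.randn(8, 20, 256)      # [batch, length_key, key_dim]
#   mask_key = torch.ones(8, 20)       # 1 for real tokens, 0 for padding
#   energy = attn(query, key, mask_key=mask_key)   # energy.size() == (8, 12, 20)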


class CharCNN(nn.Module):
    """
    CNN layers for characters
    """
    def __init__(self, num_layers, in_channels, out_channels, hidden_channels=None, activation='elu'):
        super(CharCNN, self).__init__()
        assert activation in ['elu', 'tanh']
        # hidden_channels is only optional for a single-layer network
        assert num_layers == 1 or hidden_channels is not None
        if activation == 'elu':
            ACT = nn.ELU
        else:
            ACT = nn.Tanh
        layers = list()
        for i in range(num_layers - 1):
            layers.append(('conv{}'.format(i), nn.Conv1d(in_channels, hidden_channels, kernel_size=3, padding=1)))
            layers.append(('act{}'.format(i), ACT()))
            in_channels = hidden_channels
        layers.append(('conv_top', nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1)))
        layers.append(('act_top', ACT()))
        self.act = ACT
        self.net = nn.Sequential(OrderedDict(layers))

        self.reset_parameters()

    def reset_parameters(self):
        for layer in self.net:
            if isinstance(layer, nn.Conv1d):
                nn.init.xavier_uniform_(layer.weight)
                nn.init.constant_(layer.bias, 0.)
            else:
                assert isinstance(layer, self.act)

    def forward(self, char):
        """

        Args:
            char: Tensor
                the input tensor of character [batch, sent_length, char_length, in_channels]

        Returns: Tensor
            output character encoding with shape [batch, sent_length, out_channels]

        """
        # [batch, sent_length, char_length, in_channels]
        char_size = char.size()
        # first transform to [batch * sent_length, char_length, in_channels]
        # then transpose to [batch * sent_length, in_channels, char_length]
        char = char.view(-1, char_size[2], char_size[3]).transpose(1, 2)
        # conv stack output is [batch * sent_length, out_channels, char_length];
        # max-pool over the character dimension
        char = self.net(char).max(dim=2)[0]
        # [batch, sent_length, out_channels]
        return char.view(char_size[0], char_size[1], -1)
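

# A minimal runnable smoke test for CharCNN; the sizes below are illustrative
# assumptions, not part of the original module:
if __name__ == '__main__':
    cnn = CharCNN(num_layers=2, in_channels=30, out_channels=100, hidden_channels=50)
    # [batch, sent_length, char_length, in_channels]
    chars = torch.randn(8, 12, 15, 30)
    # max-pooled character encodings, one vector per sentence position
    assert cnn(chars).size() == (8, 12, 100)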