import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F


def convolutionalize(modules, input_size):
    """
    Recast `modules` into fully convolutional form.

    The conversion transfers weights and infers kernel sizes from the
    `input_size` and modules' action on it.

    n.b. This only handles the conversion of linear/fully-connected modules,
    although other module types could require conversion for correctness.

    """
    fully_conv_modules = []
    x = torch.zeros((1, ) + input_size)
    for m in modules:
        if isinstance(m, nn.Linear):
            n = nn.Conv2d(x.size(1), m.weight.size(0), kernel_size=(x.size(2), x.size(3)))
            n.weight.data.view(-1).copy_(m.weight.data.view(-1))
            n.bias.data.view(-1).copy_(m.bias.data.view(-1))
            m = n
        fully_conv_modules.append(m)
        x = m(x)
    return fully_conv_modules


def bilinear_kernel(size, normalize=False):
    """
    Make a 2D bilinear kernel suitable for upsampling/downsampling with
    normalize=False/True. The kernel is size x size square.

    Take
        size: kernel size (square)
        normalize: whether kernel sums to 1 (True) or not

    Give
        kernel: np.array with bilinear kernel coefficient
    """
    factor = (size + 1) // 2
    if size % 2 == 1:
        center = factor - 1
    else:
        center = factor - 0.5
    og = np.ogrid[:size, :size]
    kernel = (1 - abs(og[0] - center) / factor) * \
             (1 - abs(og[1] - center) / factor)
    if normalize:
        kernel /= kernel.sum()
    return kernel


class Interpolator(nn.Module):
    """
    Interpolate by de/up/backward convolution with a bilinear kernel.

    Take
        channel_dim: the input channel dimension
        rate: upsampling rate, that is 4 -> 4x upsampling
        odd: the kernel parity, which is too much to explain here for now, but
             will be handled automagically in the future, promise.
        normalize: whether kernel sums to 1
    """
    def __init__(self, channel_dim, rate, odd=True, normalize=False):
        super().__init__()
        self.rate = rate
        ksize = rate * 2
        if odd:
            ksize -= 1
        # set weights to within-channel bilinear interpolation
        kernel = torch.from_numpy(bilinear_kernel(ksize, normalize))
        weight = torch.zeros(channel_dim, channel_dim, ksize, ksize)
        for k in range(channel_dim):
            weight[k, k] = kernel
        # fix weights
        self.weight = nn.Parameter(weight, requires_grad=False)

    def forward(self, x):
        # no groups (for speed with current pytorch impl.) and no bias
        return F.conv_transpose2d(x, self.weight, stride=self.rate)


class Downsampler(Interpolator):
    '''
    Downsample with a normalized bilinear kernel.
    '''
    def __init__(self, channel_dim, rate, odd=True):
        super().__init__(channel_dim, rate, odd, True)

    def forward(self, x):
        return F.conv2d(x, self.weight, stride=self.rate)