# Copied from Adam Paszke's original tutorial at https://pytorch.org/tutorials/advanced/numpy_extensions_tutorial.html
from scipy.signal import convolve2d, correlate2d
from torch.nn.modules.module import Module
from torch.nn.parameter import Parameter
import torch
from torch.autograd import Function
from numpy.fft import rfft2, irfft2


class BadFFTFunction(Function):
    """Toy autograd Function: magnitude of the 2-D real FFT of its input.

    As its name says, this is deliberately not a correct differentiable op:
    ``backward`` just applies ``irfft2`` to the incoming gradient, which is not
    the true derivative of ``abs(rfft2(x))``. It only demonstrates wrapping
    NumPy code in an autograd Function. Invoke it as
    ``BadFFTFunction.apply(input)``.
    """

    @staticmethod
    def forward(ctx, input):
        """Return ``abs(rfft2(input))`` as a tensor with ``input``'s dtype.

        ``rfft2`` transforms the last two axes, so ``input`` needs at least two
        dimensions; the last output dimension is ``input.shape[-1] // 2 + 1``.
        """
        # irfft2 cannot infer an odd last dimension from the half spectrum,
        # so remember the spatial size for backward.
        ctx.input_shape = tuple(input.shape[-2:])
        numpy_input = input.detach().numpy()  # detach so NumPy conversion is allowed
        result = abs(rfft2(numpy_input))
        # new_tensor keeps input's dtype/device regardless of NumPy's output dtype.
        return input.new_tensor(result)

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``irfft2(grad_output)`` reshaped to the forward input's size."""
        # detach: grad_output may require grad (e.g. create_graph=True), and
        # .numpy() refuses such tensors.
        numpy_go = grad_output.detach().numpy()
        result = irfft2(numpy_go, s=ctx.input_shape)
        return grad_output.new_tensor(result)


def incorrect_fft(input):
    """Apply ``BadFFTFunction`` to ``input`` and return its output.

    Uses ``Function.apply``: instantiating an autograd Function and calling
    the instance is the legacy API, which current PyTorch rejects with a
    RuntimeError.
    """
    return BadFFTFunction.apply(input)


# Demo: run incorrect_fft on a random 8x8 input, print the result, then push a
# random upstream gradient (shaped like the result) through its backward.
# NOTE(review): `input` shadows the builtin of the same name at module level.
input = torch.randn(8, 8, requires_grad=True)
result = incorrect_fft(input)
print(result)
result.backward(torch.randn(result.size()))
# NOTE(review): this prints the input tensor itself, not the gradient that
# backward() accumulated into `input.grad` -- confirm which was intended.
print(input)


class ScipyConv2dFunction(Function):
    """Autograd Function for a 2-D 'valid' cross-correlation computed by SciPy.

    ``ScipyConv2dFunction.apply(input, filter)`` returns
    ``correlate2d(input, filter, mode='valid')``. For an ``(H, W)`` input and
    a ``(kh, kw)`` filter, the output is ``(H - kh + 1, W - kw + 1)``.
    """

    @staticmethod
    def forward(ctx, input, filter):
        """Correlate ``input`` with ``filter`` ('valid' mode) and save both."""
        input, filter = input.detach(), filter.detach()  # detach so we can cast to NumPy
        result = correlate2d(input.numpy(), filter.numpy(), mode='valid')
        ctx.save_for_backward(input, filter)
        return input.new_tensor(result)

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``(grad_input, grad_filter)`` for the valid correlation.

        With ``out[i, j] = sum_{a, b} input[i + a, j + b] * filter[a, b]``:
          * d/d input  is the *full convolution* of grad_output with filter
            (convolve2d already flips the kernel -- no transpose needed);
          * d/d filter is the *valid correlation* of input with grad_output.
        """
        go = grad_output.detach().numpy()
        input, filter = ctx.saved_tensors
        grad_input = convolve2d(go, filter.numpy(), mode='full')
        grad_filter = correlate2d(input.numpy(), go, mode='valid')
        # Return gradients in the dtype/device of the tensors they belong to.
        return input.new_tensor(grad_input), filter.new_tensor(grad_filter)


class ScipyConv2d(Module):
    """Layer applying ``ScipyConv2dFunction`` with a learnable ``(kh, kw)`` filter.

    The filter is drawn from a standard normal distribution at construction.
    """

    def __init__(self, kh, kw):
        super().__init__()
        initial_weights = torch.randn(kh, kw)
        self.filter = Parameter(initial_weights)

    def forward(self, input):
        """Correlate ``input`` with this layer's filter using SciPy."""
        conv = ScipyConv2dFunction.apply
        return conv(input, self.filter)


# Demo: build the SciPy-backed layer, run forward/backward on a random 10x10
# input, and print the gradient accumulated into the input.
module = ScipyConv2d(3, 3)
print(list(module.parameters()))
input = torch.randn(10, 10, requires_grad=True)
output = module(input)
print(output)
# Seed backward with a random upstream gradient shaped like the output
# (8x8 here) instead of hard-coding it; the random draw is the same.
output.backward(torch.randn(output.size()))
print(input.grad)