import warnings

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class IRevNetDownsampling(nn.Module):
    '''The invertible spatial downsampling used in i-RevNet, adapted from
    https://github.com/jhjacobsen/pytorch-i-revnet/blob/master/models/model_utils.py'''

    def __init__(self, dims_in):
        super().__init__()
        # dims_in is part of the common layer interface; the block size is
        # fixed at 2, so each 2x2 spatial block maps to 4 channels.
        self.block_size = 2
        self.block_size_sq = self.block_size**2

    def forward(self, x, rev=False):
        inp = x[0]
        if not rev:
            # Squeeze: fold each 2x2 spatial block into the channel
            # dimension, halving height and width and quadrupling the
            # number of channels.
            output = inp.permute(0, 2, 3, 1)
            (batch_size, s_height, s_width, s_depth) = output.size()
            d_depth = s_depth * self.block_size_sq
            d_height = s_height // self.block_size
            t_1 = output.split(self.block_size, 2)
            stack = [t_t.contiguous().view(batch_size, d_height, d_depth)
                     for t_t in t_1]
            output = torch.stack(stack, 1)
            output = output.permute(0, 2, 1, 3)
            output = output.permute(0, 3, 1, 2)
            return [output.contiguous()]
            # (own attempt; equivalent up to the ordering of the output
            # channels)
            # return [torch.cat([inp[:, :,  ::2,  ::2],
            #                    inp[:, :, 1::2,  ::2],
            #                    inp[:, :,  ::2, 1::2],
            #                    inp[:, :, 1::2, 1::2]], dim=1)]
        else:
            # Unsqueeze: the exact inverse, distributing groups of 4
            # channels back over 2x2 spatial blocks.
            output = inp.permute(0, 2, 3, 1)
            (batch_size, d_height, d_width, d_depth) = output.size()
            s_depth = d_depth // self.block_size_sq
            s_width = d_width * self.block_size
            s_height = d_height * self.block_size
            t_1 = output.contiguous().view(batch_size, d_height, d_width,
                                           self.block_size_sq, s_depth)
            spl = t_1.split(self.block_size, 3)
            stack = [t_t.contiguous().view(batch_size, d_height, s_width,
                                           s_depth) for t_t in spl]
            output = torch.stack(stack, 0).transpose(0, 1)
            output = output.permute(0, 2, 1, 3, 4).contiguous()
            output = output.view(batch_size, s_height, s_width, s_depth)
            output = output.permute(0, 3, 1, 2)
            return [output.contiguous()]

    def jacobian(self, x, rev=False):
        # The layer only permutes entries, so log|det J| = 0.
        # TODO respect batch dimension and .cuda()
        return 0

    def output_dims(self, input_dims):
        assert len(input_dims) == 1, "Can only use 1 input"
        c, w, h = input_dims[0]
        c2, w2, h2 = c*4, w//2, h//2
        assert c*h*w == c2*h2*w2, "Uneven input dimensions"
        return [(c2, w2, h2)]


class IRevNetUpsampling(IRevNetDownsampling):
    '''The inverse of IRevNetDownsampling: the same operation with the
    forward and reverse passes exchanged.'''

    def __init__(self, dims_in):
        super().__init__(dims_in)

    def forward(self, x, rev=False):
        return super().forward(x, rev=not rev)

    def jacobian(self, x, rev=False):
        # The layer only permutes entries, so log|det J| = 0.
        # TODO respect batch dimension and .cuda()
        return 0

    def output_dims(self, input_dims):
        assert len(input_dims) == 1, "Can only use 1 input"
        c, w, h = input_dims[0]
        c2, w2, h2 = c//4, w*2, h*2
        assert c*h*w == c2*h2*w2, "Uneven input dimensions"
        return [(c2, w2, h2)]
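

# A minimal usage sketch (assuming the FrEIA convention that layers take and
# return lists of tensors, and that dims_in is a list of per-input shape
# tuples); the tensor sizes and variable names are illustrative only.
def _example_irevnet_roundtrip():
    x = torch.randn(2, 4, 8, 8)
    down = IRevNetDownsampling([(4, 8, 8)])
    y = down([x])[0]                                   # shape (2, 16, 4, 4)
    assert torch.allclose(down([y], rev=True)[0], x)   # self-inverse via rev
    up = IRevNetUpsampling([(16, 4, 4)])
    assert up([y])[0].shape == (2, 4, 8, 8)            # exact opposite op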


class HaarDownsampling(nn.Module):
    '''Uses Haar wavelets to split each channel into 4 channels, with half the
    width and height.'''

    def __init__(self, dims_in, order_by_wavelet=False, rebalance=1.):
        super().__init__()

        self.in_channels = dims_in[0][0]
        self.fac_fwd = 0.5 * rebalance
        self.fac_rev = 0.5 / rebalance

        # One 2x2 Haar kernel per output channel: an average filter and
        # three difference (detail) filters: horizontal, vertical, diagonal.
        self.haar_weights = torch.ones(4, 1, 2, 2)

        self.haar_weights[1, 0, 0, 1] = -1
        self.haar_weights[1, 0, 1, 1] = -1

        self.haar_weights[2, 0, 1, 0] = -1
        self.haar_weights[2, 0, 1, 1] = -1

        self.haar_weights[3, 0, 1, 0] = -1
        self.haar_weights[3, 0, 0, 1] = -1

        # Repeat the filter bank once per input channel (grouped convolution)
        # and freeze it: the wavelet transform has no learnable parameters.
        self.haar_weights = torch.cat([self.haar_weights] * self.in_channels, 0)
        self.haar_weights = nn.Parameter(self.haar_weights, requires_grad=False)

        self.permute = order_by_wavelet
        self.last_jac = None

        if self.permute:
            # Reorder the output channels so they are grouped by wavelet
            # (all averages first, then each detail type) instead of by
            # input channel.
            permutation = []
            for i in range(4):
                permutation += [i + 4 * j for j in range(self.in_channels)]

            self.perm = torch.LongTensor(permutation)
            self.perm_inv = torch.LongTensor(permutation)

            for i, p in enumerate(self.perm):
                self.perm_inv[p] = i

    def forward(self, x, rev=False):
        if not rev:
            # self.elements is set in output_dims(), which must run before
            # the first forward pass.
            self.last_jac = self.elements / 4 * (np.log(16.)
                                                 + 4 * np.log(self.fac_fwd))
            out = F.conv2d(x[0], self.haar_weights,
                           bias=None, stride=2, groups=self.in_channels)
            if self.permute:
                return [out[:, self.perm] * self.fac_fwd]
            else:
                return [out * self.fac_fwd]
        else:
            self.last_jac = self.elements / 4 * (np.log(16.)
                                                 + 4 * np.log(self.fac_rev))
            if self.permute:
                x_perm = x[0][:, self.perm_inv]
            else:
                x_perm = x[0]

            return [F.conv_transpose2d(x_perm * self.fac_rev, self.haar_weights,
                                       bias=None, stride=2,
                                       groups=self.in_channels)]

    def jacobian(self, x, rev=False):
        # Computed and cached during the preceding forward() call.
        # TODO respect batch dimension and .cuda()
        return self.last_jac

    def output_dims(self, input_dims):
        assert len(input_dims) == 1, "Can only use 1 input"
        c, w, h = input_dims[0]
        c2, w2, h2 = c*4, w//2, h//2
        self.elements = c*w*h
        assert c*h*w == c2*h2*w2, "Uneven input dimensions"
        return [(c2, w2, h2)]
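

# A minimal usage sketch (illustrative, following the same list-of-tensors
# convention as above). output_dims() is called first because it sets
# self.elements, which forward() needs for the jacobian bookkeeping.
def _example_haar_downsampling():
    layer = HaarDownsampling([(3, 8, 8)])
    assert layer.output_dims([(3, 8, 8)]) == [(12, 4, 4)]
    x = torch.randn(2, 3, 8, 8)
    y = layer([x])[0]                              # shape (2, 12, 4, 4)
    x_rec = layer([y], rev=True)[0]                # reverse pass inverts
    assert torch.allclose(x, x_rec, atol=1e-6)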


class HaarUpsampling(nn.Module):
    '''Uses Haar wavelets to merge 4 channels into one, with double the
    width and height.'''

    def __init__(self, dims_in):
        super().__init__()

        self.in_channels = dims_in[0][0] // 4

        # Same 2x2 Haar filter bank as in HaarDownsampling: an average
        # filter and three difference (detail) filters.
        self.haar_weights = torch.ones(4, 1, 2, 2)

        self.haar_weights[1, 0, 0, 1] = -1
        self.haar_weights[1, 0, 1, 1] = -1

        self.haar_weights[2, 0, 1, 0] = -1
        self.haar_weights[2, 0, 1, 1] = -1

        self.haar_weights[3, 0, 1, 0] = -1
        self.haar_weights[3, 0, 0, 1] = -1

        # The factor 0.5 normalizes the filter bank so that the transposed
        # convolution below is the exact inverse of the strided convolution.
        self.haar_weights *= 0.5
        self.haar_weights = torch.cat([self.haar_weights] * self.in_channels, 0)
        self.haar_weights = nn.Parameter(self.haar_weights, requires_grad=False)

    def forward(self, x, rev=False):
        if rev:
            return [F.conv2d(x[0], self.haar_weights,
                             bias=None, stride=2, groups=self.in_channels)]
        else:
            return [F.conv_transpose2d(x[0], self.haar_weights,
                                       bias=None, stride=2,
                                       groups=self.in_channels)]

    def jacobian(self, x, rev=False):
        # The normalized Haar filter bank is orthogonal, so log|det J| = 0.
        # TODO respect batch dimension and .cuda()
        return 0

    def output_dims(self, input_dims):
        assert len(input_dims) == 1, "Can only use 1 input"
        c, w, h = input_dims[0]
        c2, w2, h2 = c//4, w*2, h*2
        assert c*h*w == c2*h2*w2, "Uneven input dimensions"
        return [(c2, w2, h2)]
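

# A minimal usage sketch (illustrative): HaarUpsampling expects a channel
# count divisible by 4; with the default rebalance=1., it exactly undoes
# HaarDownsampling, and its own rev pass undoes the upsampling.
def _example_haar_upsampling():
    up = HaarUpsampling([(8, 4, 4)])
    z = torch.randn(2, 8, 4, 4)
    x = up([z])[0]                                 # shape (2, 2, 8, 8)
    assert torch.allclose(up([x], rev=True)[0], z, atol=1e-6)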


class Flatten(nn.Module):
    '''Flattens each sample of an N-D tensor into a 1-D vector; the batch
    dimension is kept.'''
    def __init__(self, dims_in):
        super().__init__()
        self.size = dims_in[0]

    def forward(self, x, rev=False):
        if not rev:
            return [x[0].view(x[0].shape[0], -1)]
        else:
            return [x[0].view(x[0].shape[0], *self.size)]

    def jacobian(self, x, rev=False):
        # Flattening only reorders entries, so log|det J| = 0.
        # TODO respect batch dimension and .cuda()
        return 0

    def output_dims(self, input_dims):
        return [(int(np.prod(input_dims[0])),)]


class Reshape(nn.Module):
    '''Reshapes each sample of an N-D tensor to a given target shape; the
    batch dimension is kept.'''
    def __init__(self, dims_in, target_dim):
        super().__init__()
        self.size = dims_in[0]
        self.target_dim = target_dim
        assert int(np.prod(dims_in[0])) == int(np.prod(self.target_dim)), "Output and input dim don't match."

    def forward(self, x, rev=False):
        if not rev:
            return [x[0].reshape(x[0].shape[0], *self.target_dim)]
        else:
            return [x[0].reshape(x[0].shape[0], *self.size)]

    def jacobian(self, x, rev=False):
        # Reshaping only reorders entries, so log|det J| = 0.
        return 0.

    def output_dims(self, input_dims):
        return [self.target_dim]
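

# A minimal usage sketch (illustrative): both layers only move entries
# around, so the reverse pass restores the input bit for bit.
def _example_flatten_reshape():
    x = torch.randn(2, 3, 4, 4)
    flat = Flatten([(3, 4, 4)])
    v = flat([x])[0]                               # shape (2, 48)
    assert torch.equal(flat([v], rev=True)[0], x)
    resh = Reshape([(3, 4, 4)], target_dim=(12, 2, 2))
    assert resh([x])[0].shape == (2, 12, 2, 2)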


def _deprecated_by(orig_class):
    '''Create a subclass of orig_class that emits a DeprecationWarning when
    constructed, used for the renamed layers below.'''
    class deprecated_class(orig_class):
        def __init__(self, *args, **kwargs):
            warnings.warn(f"{self.__class__.__name__} is deprecated and will be removed in the public release. "
                          f"Use {orig_class.__name__} instead.",
                          DeprecationWarning)
            super().__init__(*args, **kwargs)

    return deprecated_class

# Deprecated aliases, kept for backwards compatibility:
i_revnet_downsampling = _deprecated_by(IRevNetDownsampling)
i_revnet_upsampling = _deprecated_by(IRevNetUpsampling)
haar_multiplex_layer = _deprecated_by(HaarDownsampling)
haar_restore_layer = _deprecated_by(HaarUpsampling)
flattening_layer = _deprecated_by(Flatten)
reshape_layer = _deprecated_by(Reshape)
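

# A minimal usage sketch (illustrative): the lowercase aliases construct the
# same layers as their replacements, but warn on construction.
def _example_deprecated_alias():
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        layer = flattening_layer([(3, 4, 4)])      # works exactly like Flatten
    assert isinstance(layer, Flatten)
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)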