"""
Implementation of models from paper.
"""
import torch
import torch.nn as nn
import torch.nn.init as init
from .layers import AdditiveCouplingLayer

def _build_relu_network(latent_dim, hidden_dim, num_layers):
    """Helper function to construct a ReLU network of varying number of layers."""
    _modules = [ nn.Linear(latent_dim, hidden_dim) ]
    for _ in range(num_layers):
        _modules.append( nn.Linear(hidden_dim, hidden_dim) )
        _modules.append( nn.ReLU() )
        _modules.append( nn.BatchNorm1d(hidden_dim) )
    _modules.append( nn.Linear(hidden_dim, latent_dim) )
    return nn.Sequential( *_modules )
    

class NICEModel(nn.Module):
    """
    Replication of model from the paper:
      "Nonlinear Independent Components Estimation",
      Laurent Dinh, David Krueger, Yoshua Bengio (2014)
      https://arxiv.org/abs/1410.8516

    Contains the following components:
    * four additive coupling layers with nonlinearity functions consisting of
      five-layer RELUs
    * a diagonal scaling matrix output layer
    """
    def __init__(self, input_dim, hidden_dim, num_layers):
        super(NICEModel, self).__init__()
        assert (input_dim % 2 == 0), "[NICEModel] only even input dimensions supported for now"
        assert (num_layers > 2), "[NICEModel] num_layers must be at least 3"
        self.input_dim = input_dim
        half_dim = int(input_dim / 2)
        self.layer1 = AdditiveCouplingLayer(input_dim, 'odd', _build_relu_network(half_dim, hidden_dim, num_layers))
        self.layer2 = AdditiveCouplingLayer(input_dim, 'even', _build_relu_network(half_dim, hidden_dim, num_layers))
        self.layer3 = AdditiveCouplingLayer(input_dim, 'odd', _build_relu_network(half_dim, hidden_dim, num_layers))
        self.layer4 = AdditiveCouplingLayer(input_dim, 'even', _build_relu_network(half_dim, hidden_dim, num_layers))
        self.scaling_diag = nn.Parameter(torch.ones(input_dim))

        # randomly initialize weights:
        for p in self.layer1.parameters():
            if len(p.shape) > 1:
                init.kaiming_uniform_(p, nonlinearity='relu')
            else:
                init.normal_(p, mean=0., std=0.001)
        for p in self.layer2.parameters():
            if len(p.shape) > 1:
                init.kaiming_uniform_(p, nonlinearity='relu')
            else:
                init.normal_(p, mean=0., std=0.001)
        for p in self.layer3.parameters():
            if len(p.shape) > 1:
                init.kaiming_uniform_(p, nonlinearity='relu')
            else:
                init.normal_(p, mean=0., std=0.001)
        for p in self.layer4.parameters():
            if len(p.shape) > 1:
                init.kaiming_uniform_(p, nonlinearity='relu')
            else:
                init.normal_(p, mean=0., std=0.001)        


    def forward(self, xs):
        """
        Forward pass through all invertible coupling layers.
        
        Args:
        * xs: float tensor of shape (B,dim).

        Returns:
        * ys: float tensor of shape (B,dim).
        """
        ys = self.layer1(xs)
        ys = self.layer2(ys)
        ys = self.layer3(ys)
        ys = self.layer4(ys)
        ys = torch.matmul(ys, torch.diag(torch.exp(self.scaling_diag)))
        return ys


    def inverse(self, ys):
        """Invert a set of draws from gaussians"""
        with torch.no_grad():
            xs = torch.matmul(ys, torch.diag(torch.reciprocal(torch.exp(self.scaling_diag))))
            xs = self.layer4.inverse(xs)
            xs = self.layer3.inverse(xs)
            xs = self.layer2.inverse(xs)
            xs = self.layer1.inverse(xs)
        return xs