import torch
import torch.nn as nn
import torch.distributions as dist
import numpy as np
from utils import *

# (1) dtype is passed explicitly where tensors are built from numpy inputs:
# numpy defaults to float64, so without it the resulting tensors would not
# match torch's default float32 parameters.
dtype = torch.float32

class GP(nn.Module):
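    """Exact Gaussian process regression with a zero prior mean.

    The noise variance is a learnable parameter, `kernel` must be callable as
    kernel(X) / kernel(X, Xtest), and `N_max` enables minibatching over
    observations. `y` may be a dense array or a Sparse1DTensor of partially
    observed targets.
    """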

    def __init__(self, dim, X, y, kernel, variance=1.0, N_max=None):
        super(GP, self).__init__()

        self.dim = torch.tensor([dim], requires_grad=False)
        self.kernel = kernel
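        # Noise variance is stored unconstrained; transform_forward maps it
        # back to the positive scale wherever it is used.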
        self.variance = torch.nn.Parameter(
                                transform_backward(torch.tensor([variance])))

        if torch.is_tensor(X):
            self.X = X
        else:
            self.X = torch.tensor(X, requires_grad=False, dtype=dtype)

        self.N_max = N_max
        self.N = self.X.size()[0]

        if isinstance(y, Sparse1DTensor):
            self.y = y
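            # Minibatch only over the indices at which y is actually observed.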
            ix = torch.tensor([k for k in y.ix.keys()], dtype=torch.int64)
            self.get_batch = BatchIndices(None, ix, self.N_max)
        else:
            # NOTE: see (1)
            self.y = torch.tensor(y.squeeze(), dtype=dtype,
                                  requires_grad=False)
            self.get_batch = BatchIndices(self.N, None, self.N_max)

    def get_cov(self, ix=None):
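        """Lower Cholesky factor of K(X[ix], X[ix]) + variance * I.

        (torch.potrf is the old spelling of torch.cholesky.)
        """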

        if ix is None:
            ix = torch.arange(0, self.N)

        return torch.potrf(self.kernel(self.X[ix])
                           + torch.eye(ix.numel())
                           * transform_forward(self.variance),
                           upper=False)

    def forward(self, ix=None):
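        """Negative log-likelihood of y[ix] under the marginal N(0, K + variance * I)."""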

        if ix is None:
            ix = self.get_batch()

        mn = torch.zeros(ix.numel())
        cov = self.get_cov(ix=ix)
        pdf = dist.multivariate_normal.MultivariateNormal(mn, scale_tril=cov)

        return -pdf.log_prob(self.y[ix])

    def posterior(self, Xtest):
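        """Posterior predictive mean and variance of f at Xtest (no gradients)."""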
        # assumes a stationary kernel, so the prior variance k(x*, x*) equals
        # kernel.variance at every test point

        with torch.no_grad():
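            # Standard predictive equations via the Cholesky factor L:
            #   alpha = L^{-1} K_*,  fmean = alpha^T L^{-1} y,
            #   fvar  = k(x*, x*) - sum(alpha^2, dim=0)
            # (torch.trtrs is the old spelling of torch.triangular_solve.)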
            if isinstance(self.y, Sparse1DTensor):
                ix = self.get_batch.ix
                Ks = self.kernel(self.X[ix], Xtest)
                L = self.get_cov(ix)
                alpha = torch.trtrs(Ks, L, upper=False)[0]
                fmean = torch.matmul(torch.t(alpha),
                                     torch.trtrs(self.y.v.squeeze(), L,
                                                 upper=False)[0])
            else:
                Ks = self.kernel(self.X, Xtest)
                L = self.get_cov()
                alpha = torch.trtrs(Ks, L, upper=False)[0]
                fmean = torch.matmul(torch.t(alpha),
                                     torch.trtrs(self.y, L, upper=False)[0])
            fvar = transform_forward(self.kernel.variance) - (alpha**2).sum(0)

            return fmean, fvar.reshape((-1,1))

class GPLVM(nn.Module):
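    """Gaussian process latent variable model.

    The latent locations X are a learnable parameter shared by one GP per
    output dimension; all GPs share the same kernel and a single noise
    variance. Y may be a dense array with NaNs marking missing entries, or a
    COO-style list [values, row indices, column indices]. `D_max` enables
    minibatching over output dimensions.
    """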

    def __init__(self, dim, X, Y, kernel, D_max=None, **kwargs):
        super(GPLVM, self).__init__()

        if torch.is_tensor(X):
            self.X = torch.nn.Parameter(X)
        else:
            # NOTE: see (1)
            self.X = torch.nn.Parameter(torch.tensor(X, dtype=dtype))

        self.GPs = nn.ModuleList([])
        if isinstance(Y, np.ndarray):
            self.D = Y.shape[1]
            for d in range(self.D):
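                # Keep only the observed (non-NaN) entries of column d.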
                ix = np.where(np.invert(np.isnan(Y[:,d])))[0]
                y = Sparse1DTensor(Y[ix,d], torch.tensor(ix))
                self.GPs.append(GP(dim, self.X, y, kernel, **kwargs))
        elif isinstance(Y, list):
            # assumes column indices start at 0 and form a contiguous range of integers
            self.D = int(np.max(Y[2])) + 1
            for d in range(self.D):
                ix = np.where(Y[2]==d)[0]
                y = Sparse1DTensor(Y[0][ix], torch.tensor(Y[1][ix]))
                self.GPs.append(GP(dim, self.X, y, kernel, **kwargs))
        else:
            raise TypeError('Y must be a numpy array or a '
                            '[values, row indices, column indices] list')

        if D_max is None:
            self.D_max = self.D
        else:
            self.D_max = D_max

        self.dim = dim
        self.kernel = kernel
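        # Tie the noise variance of every per-dimension GP to a single shared parameter.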
        for j in range(1, self.D):
            self.GPs[j].variance = self.GPs[0].variance
        self.variance = self.GPs[0].variance

        self.get_batch = BatchIndices(self.D, None, self.D_max)

    def forward(self, ix=None):
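        """Sum of per-dimension GP losses over a minibatch of output dimensions,
        rescaled by D/D_max to estimate the sum over all D dimensions."""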

        if ix is None:
            ix = self.get_batch()

        lp = torch.tensor([0.])
        for j in ix:
            lp += self.GPs[j]()

        return lp*self.D/self.D_max
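
# Example usage (a minimal sketch; `RBF` is an assumed kernel class, not part
# of this module -- any object that is callable as kernel(X) / kernel(X, Xtest)
# and exposes a `variance` parameter should work):
#
#   import numpy as np
#   import torch
#
#   Y = np.random.randn(100, 5)      # observations, possibly with np.nan holes
#   X0 = np.random.randn(100, 2)     # initial latent locations
#   model = GPLVM(dim=2, X=X0, Y=Y, kernel=RBF(2), N_max=50, D_max=2)
#   opt = torch.optim.Adam(model.parameters(), lr=1e-2)
#   for _ in range(1000):
#       opt.zero_grad()
#       loss = model()               # stochastic negative log-likelihood
#       loss.backward()
#       opt.step()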