import torch
import torch.nn as nn
import numpy as np


class GaussianScaleMixtureOutput(nn.Module):
    def __init__(self, num_gaussians):
        super().__init__()
        self.num_gaussians = num_gaussians
        self.num_channels = 2 * num_gaussians + 1
        self.softmax = nn.Softmax2d()

    def forward(self, x):
        assert x.size(1) == self.num_channels

        weights, variances, mean  = torch.split(x, self.num_gaussians, dim=1)
        variances = torch.exp(variances)
        weights = self.softmax(weights)
        return mean, variances, weights


class PowerExponentialOutput(nn.Module):
    def __init__(self):
        super().__init__()
        self.num_channels = 2
        self.relu = nn.ReLU()

    def forward(self, x):
        assert x.size(1) == 2

        mean, variance = torch.split(x, 1, dim=1)
        #mean = self.relu(mean)

        variance = torch.exp(variance)
        return mean, variance