"""
Copyright 2019, ETH Zurich

This file is part of L3C-PyTorch.

L3C-PyTorch is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
any later version.

L3C-PyTorch is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with L3C-PyTorch.  If not, see <https://www.gnu.org/licenses/>.

--------------------------------------------------------------------------------

This class is based on the TensorFlow code of PixelCNN++:
    https://github.com/openai/pixel-cnn/blob/master/pixel_cnn_pp/nn.py
In contrast to that code, we predict mixture weights pi for each channel, i.e., mixture weights are "non-shared".
Also, x_min, x_max and L are parameters, and we implement a function to get the CDF of a channel.

# ------
# Naming
# ------

Note that we use the following names throughout the code, following the PixelCNN++ code:
    - x: targets, e.g., the RGB image for scale 0
    - l: the output of the network;
      In Fig. 2 in our paper, l is the final output, denoted p(z^(s-1) | f^(s)), i.e., it contains the parameters
      of the mixture model.
"""

from collections import namedtuple

import torch
import torch.nn.functional as F
import torchvision
from fjcommon import functools_ext as ft

import vis.grid
import vis.summarizable_module
from modules import quantizer

# Note that for RGB, we predict the parameters mu, sigma, pi and lambda. Since RGB has C==3 channels, it so happens that
# the total number of channels needed to predict the 4 parameters is 4 * C * K (for K mixtures, see final paragraphs of
# Section 3.4 in the paper). Note that for an input of, e.g., C == 4 channels, we would need 3 * C * K + 6 * K channels
# to predict all parameters. To understand this, see Eq. (7) in the paper, where it can be seen that for \tilde \mu_4,
# we would need 3 lambdas.
# We do not implement this case here, since it would complicate the code unnecessarily.
_NUM_PARAMS_RGB = 4  # mu, sigma, pi, lambda
_NUM_PARAMS_OTHER = 3  # mu, sigma, pi

_LOG_SCALES_MIN = -7.
_MAX_K_FOR_VIS = 10


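# Output of cdf_step_non_shared (see below): softmaxed mixture weights, means and log scales for
# one channel (each NKHW), the number of mixtures K, and the targets tensor moved to the device of l.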
CDFOut = namedtuple('CDFOut', ['logit_probs_c_sm',
                               'means_c',
                               'log_scales_c',
                               'K',
                               'targets'])


def non_shared_get_Kp(K, C):
    """ Get Kp=number of channels to predict. See note where we define _NUM_PARAMS_RGB above """
    if C == 3:  # finest scale
        return _NUM_PARAMS_RGB * C * K
    else:
        return _NUM_PARAMS_OTHER * C * K


def non_shared_get_K(Kp, C):
    """ Inverse of non_shared_get_Kp, get back K=number of mixtures """
    if C == 3:
        return Kp // (_NUM_PARAMS_RGB * C)
    else:
        return Kp // (_NUM_PARAMS_OTHER * C)
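

# Worked example (illustrative): for the RGB scale (C == 3) with K == 10 mixtures,
# non_shared_get_Kp(K=10, C=3) == 4 * 3 * 10 == 120, i.e., the network must output 120 channels;
# non_shared_get_K(Kp=120, C=3) == 120 // (4 * 3) == 10 recovers the number of mixtures.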


# --------------------------------------------------------------------------------


class DiscretizedMixLogisticLoss(vis.summarizable_module.SummarizableModule):
    def __init__(self, rgb_scale: bool, x_min=0, x_max=255, L=256):
        """
        :param rgb_scale: Whether this is the loss for the RGB scale. In that case,
            use_coeffs=True
            _num_params=_NUM_PARAMS_RGB == 4, since we predict coefficients lambda. See note above.
        :param x_min: minimum value in targets x
        :param x_max: maximum value in targets x
        :param L: number of symbols
        """
        super(DiscretizedMixLogisticLoss, self).__init__()
        self.rgb_scale = rgb_scale
        self.x_min = x_min
        self.x_max = x_max
        self.L = L
        # whether to use coefficients lambda to make the means depend on the values of previously predicted channels.
        self.use_coeffs = rgb_scale
        # P == number of different parameter types contained in l (the output of the network)
        self._num_params = _NUM_PARAMS_RGB if rgb_scale else _NUM_PARAMS_OTHER

        # NOTE: in contrast to the original code, we use a sigmoid (instead of a tanh).
        # The optimizer does not seem to care, but it would probably be more principled to use a tanh.
        # Compare with L55 here: https://github.com/openai/pixel-cnn/blob/master/pixel_cnn_pp/nn.py#L55
        self._nonshared_coeffs_act = torch.sigmoid

        # Adapted bounds for our case.
        self.bin_width = (x_max - x_min) / (L-1)
        self.x_lower_bound = x_min + 0.001
        self.x_upper_bound = x_max - 0.001

        self._extra_repr = 'DMLL: x={}, L={}, coeffs={}, P={}, bin_width={}'.format(
                (self.x_min, self.x_max), self.L, self.use_coeffs, self._num_params, self.bin_width)

    def to_sym(self, x):
        return quantizer.to_sym(x, self.x_min, self.x_max, self.L)

    def to_bn(self, S):
        return quantizer.to_bn(S, self.x_min, self.x_max, self.L)

    def extra_repr(self):
        return self._extra_repr

    @staticmethod
    def to_per_pixel(entropy, C):
        N, H, W = entropy.shape
        return entropy.sum() / (N*C*H*W)  # NHW -> scalar

    def cdf_step_non_shared(self, l, targets, c_cur, C, x_c=None) -> CDFOut:
        assert c_cur < C

        # NKHW         NKHW     NKHW
        logit_probs_c, means_c, log_scales_c, K = self._extract_non_shared_c(c_cur, C, l, x_c)

        logit_probs_c_softmax = F.softmax(logit_probs_c, dim=1)  # NKHW, pi_k
        return CDFOut(logit_probs_c_softmax, means_c, log_scales_c, K, targets.to(l.device))

    def sample(self, l, C):
        return self._non_shared_sample(l, C)

    def forward(self, x, l, scale=0):
        """
        :param x: targets, i.e., NCHW, float
        :param l: predicted distribution parameters, i.e., NKpHW, see above
        :return: negative log-likelihood, as NHW if shared, NCHW if non_shared pis
        """
        assert x.min() >= self.x_min and x.max() <= self.x_max, '{},{} not in {},{}'.format(
                x.min(), x.max(), self.x_min, self.x_max)

        # Extract ---
        #  NCKHW      NCKHW  NCKHW
        x, logit_pis, means, log_scales, K = self._extract_non_shared(x, l)

        # visualize pi, means, variances
        self.summarizer.register_images(
                'val', {f'dmll/{scale}/c{c}': lambda c=c: _visualize_params(logit_pis, means, log_scales, c)
                        for c in range(x.shape[1])})

        centered_x = x - means  # NCKHW

        # Calc P = cdf_delta
        # all of the following is NCKHW
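        # Underlying math: for a logistic distribution with mean mu and scale sigma,
        # CDF(x) = sigmoid((x - mu) / sigma), so the probability mass of the bin of width
        # bin_width centered at x is P(x) = CDF(x + bin_width/2) - CDF(x - bin_width/2),
        # which is what cdf_plus - cdf_min computes below (edge bins are handled separately).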
        inv_stdv = torch.exp(-log_scales)  # == 1/sigma, <= exp(7) because of the clamp on log_scales
        plus_in = inv_stdv * (centered_x + self.bin_width/2)  # (x - mu + bin_width/2) / sigma
        cdf_plus = torch.sigmoid(plus_in)  # CDF at the upper edge of the bin around x
        min_in = inv_stdv * (centered_x - self.bin_width/2)  # (x - mu - bin_width/2) / sigma
        cdf_min = torch.sigmoid(min_in)  # CDF at the lower edge of the bin around x
        # the following two follow from the definition of the logistic distribution
        log_cdf_plus = plus_in - F.softplus(plus_in)  # log probability for edge case x == x_min
        log_one_minus_cdf_min = -F.softplus(min_in)  # log probability for edge case x == x_max
        # NCKHW, P^k(c)
        cdf_delta = cdf_plus - cdf_min  # probability mass of the bin around x, used for all other cases

        # NOTE: the original code has another condition here:
        #   tf.where(cdf_delta > 1e-5,
        #            tf.log(tf.maximum(cdf_delta, 1e-12)),
        #            log_pdf_mid - np.log(127.5)
        #            )
        # which handles the extremely low probability case. Since this is only there to stabilize training,
        # and training works fine without it, I decided to drop it.
        #
        # So we end up with the following conditional, where I put in the x_lower_bound and x_upper_bound values for RGB:
        # if x < 0.001:                         cond_C
        #       log_cdf_plus                    out_C
        # elif x > 254.999:                     cond_B
        #       log_one_minus_cdf_min           out_B
        # else:
        #       log(cdf_delta)                  out_A
        out_A = torch.log(torch.clamp(cdf_delta, min=1e-12))
        # NOTE, we adapt the bounds for our case
        cond_B = (x > self.x_upper_bound).float()
        out_B = (cond_B * log_one_minus_cdf_min + (1. - cond_B) * out_A)
        cond_C = (x < self.x_lower_bound).float()
        # NCKHW, =log(P^k(c))
        log_probs = cond_C * log_cdf_plus + (1. - cond_C) * out_B

        # combine with pi, NCKHW, (-inf, 0]
        log_probs_weighted = log_probs.add(
                log_softmax(logit_pis, dim=2))  # (-inf, 0]

        # final negative log-likelihood -log(P), NCHW
        return -log_sum_exp(log_probs_weighted, dim=2)  # NCHW

    def _extract_non_shared(self, x, l):
        """
        :param x: targets, NCHW
        :param l: output of net, NKpHW, see above
        :return:
            x NC1HW,
            logit_probs NCKHW (logits of the mixture weights \pi_k),
            means NCKHW,
            log_scales NCKHW (log of the logistic scale parameters),
            K (number of mixtures)
        """
        N, C, H, W = x.shape
        Kp = l.shape[1]

        K = non_shared_get_K(Kp, C)

        # we have, for each channel: K pi / K mu / K sigma / [K coeffs]
        # note that this layout only holds for C == 3; for other C, more than 3*K coeffs would be needed
        # (see note at the top of the file), but coefficients are only supported for the C == 3 case anyway
        l = l.reshape(N, self._num_params, C, K, H, W)

        logit_probs = l[:, 0, ...]  # NCKHW
        means = l[:, 1, ...]  # NCKHW
        log_scales = torch.clamp(l[:, 2, ...], min=_LOG_SCALES_MIN)  # NCKHW, is >= -7
        x = x.reshape(N, C, 1, H, W)

        if self.use_coeffs:
            assert C == 3  # Coefficients only supported for C==3, see note where we define _NUM_PARAMS_RGB
            coeffs = self._nonshared_coeffs_act(l[:, 3, ...])  # NCKHW, basically coeffs_g_r, coeffs_b_r, coeffs_b_g
            means_r, means_g, means_b = means[:, 0, ...], means[:, 1, ...], means[:, 2, ...]  # each NKHW
            coeffs_g_r,  coeffs_b_r, coeffs_b_g = coeffs[:, 0, ...], coeffs[:, 1, ...], coeffs[:, 2, ...]  # each NKHW
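            # This implements Eq. (7) in the paper (autoregression over the RGB channels):
            #   mu_r~ = mu_r
            #   mu_g~ = mu_g + lambda_gr * x_r
            #   mu_b~ = mu_b + lambda_br * x_r + lambda_bg * x_g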
            means = torch.stack(
                    (means_r,
                     means_g + coeffs_g_r * x[:, 0, ...],
                     means_b + coeffs_b_r * x[:, 0, ...] + coeffs_b_g * x[:, 1, ...]), dim=1)  # NCKHW again

        assert means.shape == (N, C, K, H, W), (means.shape, (N, C, K, H, W))
        return x, logit_probs, means, log_scales, K

    def _extract_non_shared_c(self, c, C, l, x=None):
        """
        Same as _extract_non_shared but only for c-th channel, used to get CDF
        """
        assert c < C, f'{c} >= {C}'

        N, Kp, H, W = l.shape
        K = non_shared_get_K(Kp, C)

        l = l.reshape(N, self._num_params, C, K, H, W)
        logit_probs_c = l[:, 0, c, ...]  # NKHW
        means_c = l[:, 1, c, ...]  # NKHW
        log_scales_c = torch.clamp(l[:, 2, c, ...], min=_LOG_SCALES_MIN)  # NKHW, is >= -7

        if self.use_coeffs and c != 0:
            unscaled_coeffs = l[:, 3, ...]  # NCKHW, coeffs_g_r, coeffs_b_r, coeffs_b_g
            if c == 1:
                assert x is not None
                coeffs_g_r = torch.sigmoid(unscaled_coeffs[:, 0, ...])  # NKHW
                means_c += coeffs_g_r * x[:, 0, ...]
            elif c == 2:
                assert x is not None
                coeffs_b_r = torch.sigmoid(unscaled_coeffs[:, 1, ...])  # NKHW
                coeffs_b_g = torch.sigmoid(unscaled_coeffs[:, 2, ...])  # NKHW
                means_c += coeffs_b_r * x[:, 0, ...] + coeffs_b_g * x[:, 1, ...]

        #      NKHW           NKHW     NKHW
        return logit_probs_c, means_c, log_scales_c, K

    def _non_shared_sample(self, l, C):
        """ sample from model """
        N, Kp, H, W = l.shape
        K = non_shared_get_K(Kp, C)
        l = l.reshape(N, self._num_params, C, K, H, W)

        logit_probs = l[:, 0, ...]  # NCKHW

        # sample mixture indicator from softmax
        u = torch.zeros_like(logit_probs).uniform_(1e-5, 1. - 1e-5)  # NCKHW
        sel = torch.argmax(
                logit_probs - torch.log(-torch.log(u)),  # gumbel sampling
                dim=2)  # argmax over K, results in NCHW, specifies for each c: which of the K mixtures to take
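        # Gumbel-max trick: argmax_k(logit_k - log(-log(u_k))) with u_k ~ Uniform(0, 1) is a sample
        # from the categorical distribution with probabilities softmax(logits).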
        assert sel.shape == (N, C, H, W), (sel.shape, (N, C, H, W))

        sel = sel.unsqueeze(2)  # NC1HW

        means = torch.gather(l[:, 1, ...], 2, sel).squeeze(2)
        log_scales = torch.clamp(torch.gather(l[:, 2, ...], 2, sel).squeeze(2), min=_LOG_SCALES_MIN)

        # sample from the resulting logistic, which now has essentially 1 mixture component only.
        # We use inverse transform sampling, i.e., X ~ Logistic; draw u ~ Uniform(0, 1); x = CDF^-1(u),
        #  where for the logistic, CDF^-1(y) = \mu + \sigma * log(y / (1-y))
        u = torch.zeros_like(means).uniform_(1e-5, 1. - 1e-5)  # NCHW
        x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1. - u))  # NCHW

        if self.use_coeffs:
            assert C == 3

            clamp = lambda x_: torch.clamp(x_, 0, 255.)

            # Be careful about coefficients! We need to use the correct selection mask, namely the one for the G and
            #  B channels, as we update the G and B means! Doing torch.gather(l[:, 3, ...], 2, sel) would be completely
            #  wrong.
            coeffs = torch.sigmoid(l[:, 3, ...])
            sel_g, sel_b = sel[:, 1, ...], sel[:, 2, ...]
            coeffs_g_r = torch.gather(coeffs[:, 0, ...], 1, sel_g).squeeze(1)
            coeffs_b_r = torch.gather(coeffs[:, 1, ...], 1, sel_b).squeeze(1)
            coeffs_b_g = torch.gather(coeffs[:, 2, ...], 1, sel_b).squeeze(1)

            # Note: In theory, we should go step by step over the channels and update means with previously sampled
            # xs. But because of the math above (x = means + ...), we can just update the means here and it's all good.
            x0 = clamp(x[:, 0, ...])
            x1 = clamp(x[:, 1, ...] + coeffs_g_r * x0)
            x2 = clamp(x[:, 2, ...] + coeffs_b_r * x0 + coeffs_b_g * x1)
            x = torch.stack((x0, x1, x2), dim=1)
        return x


def log_prob_from_logits(logit_probs):
    """ numerically stable log_softmax implementation that prevents overflow """
    # logit_probs is NKHW
    m, _ = torch.max(logit_probs, dim=1, keepdim=True)
    return logit_probs - m - torch.log(torch.sum(torch.exp(logit_probs - m), dim=1, keepdim=True))


# TODO(pytorch): replace with pytorch internal in 1.0, there is a bug in 0.4.1
def log_softmax(logit_probs, dim):
    """ numerically stable log_softmax implementation that prevents overflow """
    m, _ = torch.max(logit_probs, dim=dim, keepdim=True)
    return logit_probs - m - torch.log(torch.sum(torch.exp(logit_probs - m), dim=dim, keepdim=True))


def log_sum_exp(log_probs, dim):
    """ numerically stable log_sum_exp implementation that prevents overflow """
    m, _        = torch.max(log_probs, dim=dim)
    m_keep, _   = torch.max(log_probs, dim=dim, keepdim=True)
    # == m + torch.log(torch.sum(torch.exp(log_probs - m_keep), dim=dim))
    return log_probs.sub_(m_keep).exp_().sum(dim=dim).log_().add(m)


def _visualize_params(logits_pis, means, log_scales, channel):
    """
    :param logits_pis:  NCKHW
    :param means: NCKHW
    :param log_scales: NCKHW
    :param channel: int
    :return:
    """
    assert logits_pis.shape == means.shape == log_scales.shape
    logits_pis = logits_pis[0, channel, ...].detach()
    means = means[0, channel, ...].detach()
    log_scales = log_scales[0, channel, ...].detach()

    pis = torch.softmax(logits_pis, dim=0)  # Kdim==0 -> KHW

    mixtures = ft.lconcat(
            zip(_iter_Kdim_normalized(pis, normalize=False),
                _iter_Kdim_normalized(means),
                _iter_Kdim_normalized(log_scales)))
    grid = vis.grid.prep_for_grid(mixtures)
    img = torchvision.utils.make_grid(grid, nrow=3)
    return img


def _iter_Kdim_normalized(t, normalize=True):
    """ normalizes t, then iterates over Kdim (1st dimension) """
    K = t.shape[0]

    if normalize:
        lo, hi = float(t.min()), float(t.max())
        t = t.clamp(min=lo, max=hi).add_(-lo).div_(hi - lo + 1e-5)

    for k in range(min(_MAX_K_FOR_VIS, K)):
        yield t[k, ...]  # HW
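

# Minimal standalone sanity sketch (illustrative only, not used by the module): numerically check
# that the discretized logistic used in DiscretizedMixLogisticLoss.forward assigns a total
# probability of ~1 to the L == 256 symbols when the two edge cases (x == x_min and x == x_max)
# are folded into the first and last bin. The values of mu and log_scale below are arbitrary.
if __name__ == '__main__':
    mu, log_scale, bin_width = 100., -2., 1.
    xs = torch.arange(256, dtype=torch.float32)
    inv_stdv = torch.exp(torch.tensor(-log_scale))
    cdf_plus = torch.sigmoid(inv_stdv * (xs - mu + bin_width / 2))  # CDF at upper bin edges
    cdf_min = torch.sigmoid(inv_stdv * (xs - mu - bin_width / 2))   # CDF at lower bin edges
    p = cdf_plus - cdf_min       # probability mass per bin, cf. cdf_delta in forward()
    p[0] = cdf_plus[0]           # edge case x == x_min: all mass below the upper edge
    p[-1] = 1. - cdf_min[-1]     # edge case x == x_max: all mass above the lower edge
    print('total probability mass: {:.6f}'.format(p.sum().item()))  # should be ~1.0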