# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0


import numpy as np
from scipy import special


# frequently used constants
sq2 = np.sqrt(2)  # sqrt(2)
eps = np.finfo(np.float32).eps  # float32 machine epsilon
l2p = np.log(2) + np.log(np.pi)  # log(2 * pi)


def joint_min(mu: np.ndarray, var: np.ndarray, with_derivatives: bool = False):
    """
    Computes the probability of each given point being the minimum,
    based on the EPMGP [1] algorithm.

    [1] J. Cunningham, P. Hennig, and S. Lacoste-Julien.
        Gaussian probabilities and expectation propagation.
        Under review. Preprint at arXiv, November 2011.

    :param mu: Mean value of each of the N points, dims (N,).
    :param var: Covariance matrix for all points, dims (N, N).
    :param with_derivatives: If True, the gradients with respect to mu and
        var are computed as well.
    :returns: log pmin distribution, dims (N,); if with_derivatives is True,
        a tuple (logP, dlogPdMu, dlogPdSigma, dlogPdMudMu) is returned instead.
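
    Example (a minimal illustrative sketch; the toy values below are
    assumptions, not part of this module):

        mu = np.array([0.0, 0.5, 1.0])
        var = 0.1 * np.eye(3)
        log_pmin = joint_min(mu, var)
        # np.exp(log_pmin) sums to 1 and should place most of its mass on
        # the first point, which has the lowest mean.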
    """

    logP = np.zeros(mu.shape)
    D = mu.shape[0]
    if with_derivatives:
        dlogPdMu = np.zeros((D, D))
        dlogPdSigma = np.zeros((D, int(0.5 * D * (D + 1))))
        dlogPdMudMu = np.zeros((D, D, D))
    for i in range(mu.shape[0]):

        # run EP to estimate log p(x_i is the minimum) and its derivatives
        a = min_factor(mu, var, i)

        logP[i] = next(a)
        if with_derivatives:
            dlogPdMu[i, :] = next(a).T
            dlogPdMudMu[i, :, :] = next(a)
            dlogPdSigma[i, :] = next(a).T

    # clip -inf entries so the normalization below stays finite
    logP[np.isinf(logP)] = -500
    # re-normalize at the end to smooth out numerical imbalances:
    logPold = logP
    Z = np.sum(np.exp(logPold))
    maxLogP = np.max(logP)
    s = maxLogP + np.log(np.sum(np.exp(logP - maxLogP)))
    s = maxLogP if np.isinf(s) else s

    logP = logP - s
    if not with_derivatives:
        return logP

    dlogPdMuold = dlogPdMu
    dlogPdSigmaold = dlogPdSigma
    dlogPdMudMuold = dlogPdMudMu
    # adjust derivatives, too. This is a bit tedious.
    Zm = sum(np.rot90((np.exp(logPold) * np.rot90(dlogPdMuold, 1)), 3)) / Z
    Zs = sum(np.rot90((np.exp(logPold) * np.rot90(dlogPdSigmaold, 1)), 3)) / Z

    dlogPdMu = dlogPdMuold - Zm
    dlogPdSigma = dlogPdSigmaold - Zs

    ff = np.einsum('ki,kj->kij', dlogPdMuold, dlogPdMuold)
    gg = np.einsum('kij,k->ij', dlogPdMudMuold + ff, np.exp(logPold)) / Z
    Zij = Zm.T * Zm
    adds = np.reshape(-gg + Zij, (1, D, D))
    dlogPdMudMu = dlogPdMudMuold + adds
    return logP, dlogPdMu, dlogPdSigma, dlogPdMudMu


def min_factor(Mu, Sigma, k, gamma=1):
    """
    Approximates, via expectation propagation with D - 1 truncation factors,
    the probability that point k is the minimum. Yields, in order, the log
    probability logZ and its derivatives with respect to Mu, Mu twice, and
    Sigma.
    """
    D = Mu.shape[0]
    logS = np.zeros((D - 1,))
    # mean times precision, first natural parameter of each EP factor
    MP = np.zeros((D - 1,))

    # precision, second natural parameter of each EP factor
    P = np.zeros((D - 1,))

    M = np.copy(Mu)
    V = np.copy(Sigma)
    b = False
    d = np.nan
    # run at most 50 EP sweeps over the D - 1 truncation factors
    for count in range(50):
        diff = 0
        for i in range(D - 1):
            l = i if i < k else i + 1  # noqa: E741 to be consistent with paper notation
            M, V, P[i], MP[i], logS[i], d = lt_factor(k, l, M, V, MP[i], P[i], gamma)

            if np.isnan(d):
                break
            diff += np.abs(d)
        if np.isnan(d):
            break
        if np.abs(diff) < 0.001:
            b = True
            break
    if np.isnan(d):
        logZ = -np.inf
        yield logZ
        dlogZdMu = np.zeros((D, 1))
        yield dlogZdMu

        dlogZdMudMu = np.zeros((D, D))
        yield dlogZdMudMu
        dlogZdSigma = np.zeros((int(0.5 * (D * (D + 1))), 1))
        yield dlogZdSigma
        mvmin = [Mu[k], Sigma[k, k]]
        yield mvmin
    else:
        # evaluate log Z:
        C = np.eye(D) / sq2
        C[k, :] = -1 / sq2
        C = np.delete(C, k, 1)

        R = np.sqrt(P.T) * C
        r = np.sum(MP.T * C, 1)
        mp_not_zero = np.where(MP != 0)
        mpm = MP[mp_not_zero] * MP[mp_not_zero] / P[mp_not_zero]
        mpm = sum(mpm)

        s = sum(logS)
        IRSR = (np.eye(D - 1) + np.dot(np.dot(R.T, Sigma), R))
        rSr = np.dot(np.dot(r.T, Sigma), r)
        A = np.dot(R, np.linalg.solve(IRSR, R.T))

        A = 0.5 * (A.T + A)  # ensure symmetry.
        b = (Mu + np.dot(Sigma, r))
        Ab = np.dot(A, b)
        # Cholesky of I + R^T Sigma R; if it is not numerically positive
        # definite, retry below with a small amount of jitter on the diagonal
        try:
            cIRSR = np.linalg.cholesky(IRSR)
        except np.linalg.LinAlgError:
            try:
                cIRSR = np.linalg.cholesky(IRSR + 1e-10 * np.eye(IRSR.shape[0]))
            except np.linalg.LinAlgError:
                cIRSR = np.linalg.cholesky(IRSR + 1e-6 * np.eye(IRSR.shape[0]))
        dts = 2 * np.sum(np.log(np.diagonal(cIRSR)))
        logZ = 0.5 * (rSr - np.dot(b.T, Ab) - dts) + np.dot(Mu.T, r) + s - 0.5 * mpm
        yield logZ
        btA = np.dot(b.T, A)

        dlogZdMu = r - Ab
        yield dlogZdMu
        dlogZdMudMu = -A
        yield dlogZdMudMu
        dlogZdSigma = -A - 2 * np.outer(r, Ab.T) + np.outer(r, r.T) \
                      + np.outer(btA.T, Ab.T)
        dlogZdSigma2 = np.zeros_like(dlogZdSigma)
        np.fill_diagonal(dlogZdSigma2, np.diagonal(dlogZdSigma))
        dlogZdSigma = 0.5 * (dlogZdSigma + dlogZdSigma.T - dlogZdSigma2)
        dlogZdSigma = np.rot90(dlogZdSigma, k=2)[np.triu_indices(D)][::-1]
        yield dlogZdSigma


def lt_factor(s, l, M, V, mp, p, gamma):
    """
    EP update for a single 'less than' truncation factor enforcing x_s <= x_l.
    Given the current marginal (M, V) and the factor's natural parameters
    (mp, p), returns the updated marginal (Mnew, Vnew), the new factor
    parameters (pnew, mpnew), the factor's log normalization term logS and
    a convergence measure d.
    """
    cVc = (V[l, l] - 2 * V[s, l] + V[s, s]) / 2.0
    Vc = (V[:, l] - V[:, s]) / sq2
    cM = (M[l] - M[s]) / sq2
    cVnic = np.max([cVc / (1 - p * cVc), 0])
    cmni = cM + cVnic * (p * cM - mp)
    z = cmni / np.sqrt(cVnic + 1e-25)
    if np.isnan(z):
        z = -np.inf
    e, lP, exit_flag = log_relative_gauss(z)
    if exit_flag == 0:
        alpha = e / np.sqrt(cVnic)
        # beta  = alpha * (alpha + cmni / cVnic);
        # r     = beta * cVnic / (1 - cVnic * beta);
        beta = alpha * (alpha * cVnic + cmni)
        r = beta / (1 - beta)
        # new message
        pnew = r / cVnic
        mpnew = r * (alpha + cmni / cVnic) + alpha

        # update terms
        dp = np.max([-p + eps, gamma * (pnew - p)])  # at worst, remove message
        dmp = np.max([-mp + eps, gamma * (mpnew - mp)])
        d = np.max([dmp, dp])  # for convergence measures

        pnew = p + dp
        mpnew = mp + dmp
        # project out to marginal
        Vnew = V - dp / (1 + dp * cVc) * np.outer(Vc, Vc)

        Mnew = M + (dmp - cM * dp) / (1 + dp * cVc) * Vc
        if np.any(np.isnan(Vnew)):
            raise Exception("an error occurs while running expectation "
                            "propagation in entropy search. "
                            "Resulting variance contains NaN")
        # NOTE: this can become numerically problematic when z is very large
        logS = lP - 0.5 * (np.log(beta) - np.log(pnew) - np.log(cVnic)) \
               + (alpha * alpha) / (2 * beta) * cVnic

    elif exit_flag == -1:
        d = np.nan
        Mnew = 0
        Vnew = 0
        pnew = 0
        mpnew = 0
        logS = -np.inf
    elif exit_flag == 1:
        d = 0
        # remove message from marginal:
        # new message
        pnew = 0
        mpnew = 0
        # update terms
        dp = -p  # at worst, remove message
        dmp = -mp
        d = max([dmp, dp])  # for convergence measures
        # project out to marginal
        Vnew = V - dp / (1 + dp * cVc) * (np.outer(Vc, Vc))
        Mnew = M + (dmp - cM * dp) / (1 + dp * cVc) * Vc
        logS = 0
    return Mnew, Vnew, pnew, mpnew, logS, d


def log_relative_gauss(z):
    """
    log_relative_gauss
    """
    if z < -6:
        return 1, -1.0e12, -1
    if z > 6:
        return 0, 0, 1
    else:
        logphi = -0.5 * (z * z + l2p)
        logPhi = np.log(.5 * special.erfc(-z / sq2))
        e = np.exp(logphi - logPhi)
    return e, logPhi, 0
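

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only; the toy posterior below is an
    # assumption, not data from this module). It computes, for each of three
    # points, the probability of being the minimum, plus derivatives.
    mu = np.array([0.0, 0.5, 1.0])
    var = np.array([[0.10, 0.02, 0.00],
                    [0.02, 0.10, 0.02],
                    [0.00, 0.02, 0.10]])
    log_pmin, dmu, dsigma, dmudmu = joint_min(mu, var, with_derivatives=True)
    print("p_min:", np.exp(log_pmin))          # sums to 1 by construction
    print("dlogPdMu shape:", dmu.shape)        # (3, 3)
    print("dlogPdSigma shape:", dsigma.shape)  # (3, 6)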