from __future__ import absolute_import
import scipy.stats

import autograd.numpy as np
from autograd.numpy.numpy_vjps import unbroadcast_f
from autograd.extend import primitive, defvjp


# Wrap the scipy multivariate-normal functions as autograd primitives so that
# vector-Jacobian products can be registered for them with defvjp below.
pdf = primitive(scipy.stats.multivariate_normal.pdf)
logpdf = primitive(scipy.stats.multivariate_normal.logpdf)
entropy = primitive(scipy.stats.multivariate_normal.entropy)

# With thanks to Eric Bresch.
# Some formulas are from
# "An extended collection of matrix derivative results
#  for forward and reverse mode algorithmic differentiation"
# by Mike Giles
# https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf

def generalized_outer_product(x):
    """Return the outer product of ``x`` with itself.

    A 1-D input is handled with ``np.outer``.  Any other input is multiplied
    by itself with its last two axes swapped, so leading axes are carried
    through ``np.matmul``.
    """
    if np.ndim(x) != 1:
        transposed = np.swapaxes(x, -1, -2)
        return np.matmul(x, transposed)
    return np.outer(x, x)

def covgrad(x, mean, cov, allow_singular=False):
    """Return ``0.5 * (J r r^T J - J)`` with ``J = inv(cov)``, ``r = x - mean``.

    This is the factor used by the ``cov`` VJPs of ``logpdf`` and ``pdf``.

    Args:
        x: Point(s) at which the density is evaluated.
        mean: Mean of the distribution.
        cov: Covariance matrix; it is inverted with ``np.linalg.inv``.
        allow_singular: Must be falsy; the singular case is not supported.

    Returns:
        An array whose trailing two axes form a square matrix.

    Raises:
        NotImplementedError: If ``allow_singular`` is truthy.
    """
    if allow_singular:
        raise NotImplementedError("The multivariate normal pdf is not "
                "differentiable w.r.t. a singular covariance matrix")
    J = np.linalg.inv(cov)
    # Turn the residual into a column vector on the trailing axes so that
    # matmul applies J to it (and to every leading-axis entry, if any).
    solved = np.matmul(J, np.expand_dims(x - mean, -1))
    return 1./2 * (generalized_outer_product(solved) - J)

def solve(allow_singular):
    """Pick the linear solver used by the x/mean VJPs.

    Returns ``np.linalg.solve`` when ``allow_singular`` is falsy.  Otherwise
    returns a callable ``(A, x)`` that multiplies ``x`` by the pseudo-inverse
    of ``A``.
    """
    if not allow_singular:
        return np.linalg.solve

    def pinv_solve(A, x):
        return np.dot(np.linalg.pinv(A), x)

    return pinv_solve

# VJP makers for logpdf, one per differentiable argument (x, mean, cov).
# Each returns a function of the incoming gradient g, wrapped in
# unbroadcast_f so the result is reduced back to the argument's shape.
def _logpdf_vjp_x(ans, x, mean, cov, allow_singular=False):
    def vjp(g):
        # g gets a trailing axis so it scales each solved residual vector.
        scale = -np.expand_dims(np.atleast_1d(g), 1)
        return scale * solve(allow_singular)(cov, (x - mean).T).T
    return unbroadcast_f(x, vjp)

def _logpdf_vjp_mean(ans, x, mean, cov, allow_singular=False):
    def vjp(g):
        # Same expression as the x VJP, with the opposite sign.
        scale = np.expand_dims(np.atleast_1d(g), 1)
        return scale * solve(allow_singular)(cov, (x - mean).T).T
    return unbroadcast_f(mean, vjp)

def _logpdf_vjp_cov(ans, x, mean, cov, allow_singular=False):
    def vjp(g):
        # Two trailing unit axes let g broadcast against the matrices
        # returned by covgrad.
        weight = np.reshape(g, np.shape(g) + (1, 1))
        return weight * covgrad(x, mean, cov, allow_singular)
    return unbroadcast_f(cov, vjp)

defvjp(logpdf, _logpdf_vjp_x, _logpdf_vjp_mean, _logpdf_vjp_cov)

# VJP makers for pdf. These mirror the logpdf ones, except that the incoming
# gradient g is first multiplied by the pdf value itself (ans).
def _pdf_vjp_x(ans, x, mean, cov, allow_singular=False):
    def vjp(g):
        scale = -np.expand_dims(np.atleast_1d(ans * g), 1)
        return scale * solve(allow_singular)(cov, (x - mean).T).T
    return unbroadcast_f(x, vjp)

def _pdf_vjp_mean(ans, x, mean, cov, allow_singular=False):
    def vjp(g):
        scale = np.expand_dims(np.atleast_1d(ans * g), 1)
        return scale * solve(allow_singular)(cov, (x - mean).T).T
    return unbroadcast_f(mean, vjp)

def _pdf_vjp_cov(ans, x, mean, cov, allow_singular=False):
    def vjp(g):
        # Two trailing unit axes let ans * g broadcast against the matrices
        # returned by covgrad.
        weight = np.reshape(ans * g, np.shape(g) + (1, 1))
        return weight * covgrad(x, mean, cov, allow_singular)
    return unbroadcast_f(cov, vjp)

defvjp(pdf, _pdf_vjp_x, _pdf_vjp_mean, _pdf_vjp_cov)

# VJP maker for entropy with respect to cov: 0.5 * g * inv(cov)^T.
def _entropy_vjp_cov(ans, mean, cov):
    def vjp(g):
        return 0.5 * g * np.linalg.inv(cov).T
    return unbroadcast_f(cov, vjp)

# None is passed in the first slot, so no VJP is registered for mean.
defvjp(entropy, None, _entropy_vjp_cov)